""" 策略抽象基类 所有策略必须继承此类并实现必要方法 """ from abc import ABC, abstractmethod from typing import Dict, Optional, Any import pandas as pd from framework_v2.config.schemas import StrategyConfig class StrategyBase(ABC): """ 策略抽象基类(V2 增强版) 定义策略的标准生命周期: 1. 初始化配置(使用 Pydantic Schema) 2. 获取数据(支持多数据源) 3. 计算因子(使用框架因子库) 4. 生成信号(策略特定逻辑) 5. 执行回测(框架通用执行器) 子类必须实现: - get_codes(): 获取标的列表 - compute_factors(): 计算因子 - generate_signals(): 生成信号 - manage_positions(): 仓位管理 """ INTERFACE_VERSION = 2 # V2 版本 def __init__(self, config: StrategyConfig): """ 初始化策略 Args: config: 策略配置(Pydantic Schema) """ self.config = config self.name = config.metadata.strategy # 组件将在子类中初始化 self._data_fetcher = None self._executor = None @abstractmethod def get_codes(self) -> list: """ 获取标的列表(策略必须实现) Returns: 标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F'] """ pass @abstractmethod def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]: """ 计算因子(策略必须实现) Args: data: 数据字典 {code: DataFrame} Returns: 因子字典 {code: Series} """ pass @abstractmethod def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame: """ 生成信号(策略必须实现) Args: factors: 因子字典 {code: Series} Returns: 信号 DataFrame(包含 'signal' 列,1=买入,0=空仓) """ pass @abstractmethod def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame: """ 仓位管理(策略必须实现) Args: signals: 信号 DataFrame Returns: 仓位 DataFrame(包含 'weight' 列,权重和为 1) """ pass def get_data(self) -> Dict[str, pd.DataFrame]: """ 获取数据(框架实现,策略可覆盖) Returns: 数据字典 {code: DataFrame} """ if self._data_fetcher is None: self._data_fetcher = self._create_data_fetcher() codes = self.get_codes() # 批量获取数据(fetch_indices 返回 {code: DataFrame}) try: data = self._data_fetcher.fetch_indices( codes=codes, start=self.config.backtest.start_date, end=self.config.backtest.end_date ) return data except Exception as e: print(f" 错误: 数据获取失败 - {e}") return {} def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame: """ 计算因子(可选覆盖) Args: data: 数据字典 Returns: 因子 DataFrame(日期 × 标的) """ if self._factor is None: self._factor = self.init_factors() # 默认实现:遍历标的计算因子 factor_values = {} for code in data.get('valid_codes', []): if code in data.get('index_data', {}): factor_values[code] = self._factor.compute(data['index_data'][code]) return pd.DataFrame(factor_values) def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame: """ 生成信号 Args: factor_df: 因子 DataFrame Returns: 信号 DataFrame(包含 'signal' 列) """ if self._signal_generator is None: self._signal_generator = self.init_signal_generator() return self._signal_generator.generate(factor_df) def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]: """ 运行完整回测流程(框架标准流程) Args: data: 可选,如不提供则自动获取 Returns: 回测结果字典,包含: - equity_curve: 净值曲线 - trades: 交易记录 - metrics: 绩效指标 """ # 1. 获取数据 if data is None: print("[1/5] 获取数据...") data = self.get_data() print(f" 获取 {len(data)} 个标的") # 2. 计算因子 print("[2/5] 计算因子...") factors = self.compute_factors(data) print(f" 计算 {len(factors)} 个因子") # 3. 生成信号 print("[3/5] 生成信号...") signals = self.generate_signals(factors) print(f" 生成 {signals.shape[0]} 个信号") # 4. 仓位管理 print("[4/5] 仓位管理...") positions = self.manage_positions(signals) print(f" 平均持仓: {positions['weight'].sum().mean():.2%}") # 5. 执行回测 print("[5/5] 执行回测...") result = self._execute_backtest(positions, data) print(f" 回测完成") return result def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]: """ 执行回测(子类可覆盖) Args: signals: 信号 DataFrame data: 数据字典 Returns: 回测结果 """ raise NotImplementedError("Subclasses must implement _execute_backtest()") def _create_data_fetcher(self): """ 创建数据获取器(框架实现) Returns: DataFetcher 实例 """ from framework_v2.shared.data import FlaskAPIFetcher # 使用配置中的第一个启用的数据源 for source_config in self.config.data.sources: if source_config.enabled and source_config.type.value == 'flask_api': return FlaskAPIFetcher( base_url=source_config.url, timeout=source_config.timeout ) raise ValueError("未找到可用的数据源配置")