## 架构设计
- 三层架构:core(抽象接口) → shared(通用实现) → tests(验证测试)
- 5个核心抽象基类:StrategyBase, FactorBase, SignalGenerator, Executor, DataFetcher
- 零侵入:与现有框架并行开发,不修改生产代码

## 已完成
- ✓ 核心接口层(5个ABC类)
- ✓ 通用因子层(MomentumFactor完全复制现有逻辑)
- ✓ 对比验证测试(新旧因子输出差异=0,测试通过)

## 验证结果
- 最大差异: 0.000000e+00
- 平均差异: 0.000000e+00
- 容差: < 1e-10

## 下一步
- 阶段3: 信号层迁移(TopNSelector, DynamicThreshold, RebalanceController)
- 阶段4: 执行层迁移(BacktestRunner)
- 阶段5: 数据层迁移(DataFetcher实现)
- 阶段6: 完整策略对比验证

## 设计原则
- 按需抽象,不预先设计
- 职责分离,避免框架膨胀
- 测试驱动,每个组件必须有对比测试
- 渐进式迁移,验证通过再替换
152 lines
3.9 KiB
Python
152 lines
3.9 KiB
Python
"""
|
||
策略抽象基类
|
||
|
||
所有策略必须继承此类并实现必要方法
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, Optional, Any
|
||
import pandas as pd
|
||
|
||
|
||
class StrategyBase(ABC):
    """
    Abstract base class for strategies.

    Defines the standard strategy lifecycle:
        1. initialise configuration
        2. fetch data
        3. compute factors
        4. generate signals
        5. run the backtest

    Subclasses must implement:
        - init_factors(): build the factor component
        - init_signal_generator(): build the signal generator

    The base get_data() and _execute_backtest() only raise
    NotImplementedError, so a subclass that calls run_backtest() must
    override _execute_backtest(). It must also override get_data(), unless
    callers always pass ``data`` in.
    """

    INTERFACE_VERSION = 2  # V2 interface

    # Strategy identifier; shown by __repr__.
    name: str = "base"
    # Presumably the bar frequency; not read anywhere in this base class.
    timeframe: str = "1d"

    def __init__(self, config: Optional[Dict] = None):
        """
        Initialise the strategy.

        Args:
            config: strategy configuration dict. ``None`` (or any falsy
                value) is replaced by a new empty dict.
        """
        self.config = config or {}
        # Components are built lazily on first use by compute_factors() /
        # generate_signals().
        self._factor = None
        self._signal_generator = None

    @abstractmethod
    def init_factors(self) -> Any:
        """
        Build the factor component.

        Returns:
            A factor instance (intended to subclass FactorBase). The default
            compute_factors() calls its ``compute()`` method once per code.
        """
        pass

    @abstractmethod
    def init_signal_generator(self) -> Any:
        """
        Build the signal generator.

        Returns:
            A signal-generator instance (intended to subclass
            SignalGenerator). generate_signals() calls its ``generate()``
            method with the factor DataFrame.
        """
        pass

    def get_data(self) -> Dict[str, Any]:
        """
        Fetch the input data for the backtest.

        The base implementation only raises. Override it unless callers
        always pass ``data`` to ``run_backtest()``.

        Returns:
            Data dict. By the documented convention it contains:
                - index_data: index data, looked up by code in the default
                  compute_factors()
                - etf_data: ETF data
                - benchmark_data: benchmark data
                - valid_codes: list of tradable codes
                - trading_calendar: trading calendar

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement get_data()")

    def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
        """
        Compute factor values for every valid code (may be overridden).

        Args:
            data: data dict. Reads ``data['valid_codes']`` (an iterable of
                codes) and ``data['index_data']`` (a mapping from code to the
                input passed to the factor's ``compute()``). Missing or
                ``None`` entries are treated as empty.

        Returns:
            DataFrame with one column per valid code that has index data,
            built from the factor outputs. The result is dates x codes when
            each ``compute()`` returns a date-indexed Series. It is empty
            when no code matches.
        """
        if self._factor is None:
            self._factor = self.init_factors()

        # Resolve the inputs once instead of once per code. Use explicit
        # None checks rather than `or`, so array-like containers such as a
        # pandas Index are not subjected to ambiguous truth testing.
        valid_codes = data.get('valid_codes')
        if valid_codes is None:
            valid_codes = []
        index_data = data.get('index_data')
        if index_data is None:
            index_data = {}

        factor_values = {
            code: self._factor.compute(index_data[code])
            for code in valid_codes
            if code in index_data
        }
        return pd.DataFrame(factor_values)

    def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
        """
        Turn factor values into trading signals.

        Args:
            factor_df: factor DataFrame, e.g. the output of compute_factors().

        Returns:
            Whatever the signal generator's ``generate()`` returns. The
            intended contract is a signal DataFrame containing a 'signal'
            column; this is not enforced here.
        """
        if self._signal_generator is None:
            self._signal_generator = self.init_signal_generator()

        return self._signal_generator.generate(factor_df)

    def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
        """
        Run the full pipeline: data -> factors -> signals -> backtest.

        Args:
            data: optional data dict. When ``None``, get_data() is called to
                obtain it.

        Returns:
            The backtest result dict produced by _execute_backtest().

        Raises:
            NotImplementedError: if get_data() (when ``data`` is None) or
                _execute_backtest() has not been overridden.
        """
        # 1. Fetch data only when the caller did not supply it.
        if data is None:
            data = self.get_data()

        # 2. Compute factors.
        factor_df = self.compute_factors(data)

        # 3. Generate signals.
        signals = self.generate_signals(factor_df)

        # 4. Execute the backtest (subclass hook).
        return self._execute_backtest(signals, data)

    def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the backtest (override in subclasses).

        Args:
            signals: signal DataFrame from generate_signals().
            data: the data dict used for this run.

        Returns:
            Backtest result dict.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement _execute_backtest()")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(name={self.name})"
|