Files
etf/framework_v2/core/strategy.py
aszerW 908b28473f feat(framework_v2): 创建框架V2骨架 - 三层架构+因子验证通过
## 架构设计
- 三层架构:core(抽象接口) → shared(通用实现) → tests(验证测试)
- 5个核心抽象基类:StrategyBase, FactorBase, SignalGenerator, Executor, DataFetcher
- 零侵入:与现有框架并行开发,不修改生产代码

## 已完成
✓ 核心接口层(5个ABC类)
✓ 通用因子层(MomentumFactor完全复制现有逻辑)
✓ 对比验证测试(新旧因子输出差异=0,测试通过)

## 验证结果
- 最大差异: 0.000000e+00
- 平均差异: 0.000000e+00
- 容差: < 1e-10

## 下一步
- 阶段3: 信号层迁移(TopNSelector, DynamicThreshold, RebalanceController)
- 阶段4: 执行层迁移(BacktestRunner)
- 阶段5: 数据层迁移(DataFetcher实现)
- 阶段6: 完整策略对比验证

## 设计原则
- 按需抽象,不预先设计
- 职责分离,避免框架膨胀
- 测试驱动,每个组件必须有对比测试
- 渐进式迁移,验证通过再替换
2026-05-24 09:12:29 +08:00

152 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略抽象基类
所有策略必须继承此类并实现必要方法
"""
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import pandas as pd
class StrategyBase(ABC):
"""
策略抽象基类
定义策略的标准生命周期:
1. 初始化配置
2. 获取数据
3. 计算因子
4. 生成信号
5. 执行回测
子类必须实现:
- init_factors(): 初始化因子
- init_signal_generator(): 初始化信号生成器
"""
INTERFACE_VERSION = 2 # V2 版本
name: str = "base"
timeframe: str = "1d"
def __init__(self, config: Optional[Dict] = None):
"""
初始化策略
Args:
config: 策略配置字典
"""
self.config = config or {}
self._factor = None
self._signal_generator = None
@abstractmethod
def init_factors(self) -> Any:
"""
初始化因子组件
Returns:
因子实例(继承 FactorBase
"""
pass
@abstractmethod
def init_signal_generator(self) -> Any:
"""
初始化信号生成器
Returns:
信号生成器实例(继承 SignalGenerator
"""
pass
def get_data(self) -> Dict[str, Any]:
"""
获取数据(可选覆盖)
Returns:
数据字典,包含:
- index_data: 指数数据
- etf_data: ETF数据
- benchmark_data: 基准数据
- valid_codes: 有效标的列表
- trading_calendar: 交易日历
"""
raise NotImplementedError("Subclasses must implement get_data()")
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
"""
计算因子(可选覆盖)
Args:
data: 数据字典
Returns:
因子 DataFrame日期 × 标的)
"""
if self._factor is None:
self._factor = self.init_factors()
# 默认实现:遍历标的计算因子
factor_values = {}
for code in data.get('valid_codes', []):
if code in data.get('index_data', {}):
factor_values[code] = self._factor.compute(data['index_data'][code])
return pd.DataFrame(factor_values)
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
"""
生成信号
Args:
factor_df: 因子 DataFrame
Returns:
信号 DataFrame包含 'signal' 列)
"""
if self._signal_generator is None:
self._signal_generator = self.init_signal_generator()
return self._signal_generator.generate(factor_df)
def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
"""
运行完整回测流程
Args:
data: 可选,如不提供则自动获取
Returns:
回测结果字典
"""
# 1. 获取数据
if data is None:
data = self.get_data()
# 2. 计算因子
factor_df = self.compute_factors(data)
# 3. 生成信号
signals = self.generate_signals(factor_df)
# 4. 执行回测(子类实现)
return self._execute_backtest(signals, data)
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行回测(子类可覆盖)
Args:
signals: 信号 DataFrame
data: 数据字典
Returns:
回测结果
"""
raise NotImplementedError("Subclasses must implement _execute_backtest()")
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"