feat(framework_v2): 实现 StrategyBase 抽象基类和简单轮动策略

StrategyBase ABC:
- 定义标准回测流程:get_data → compute_factors → generate_signals → manage_positions → execute
- 实现通用数据获取(使用 FlaskAPIFetcher.fetch_indices)
- 提供 run() 方法执行完整回测流程

SimpleRotationStrategy:
- 实现 4 个抽象方法:get_codes, compute_factors, generate_signals, manage_positions
- 支持动量因子计算(MomentumFactor)
- 实现全局选股和等权仓位管理
- 修复 int64 → float 转换问题

框架定位:
- 通用量化回测框架,支持轮动、CTA、趋势跟踪等多种策略
- 策略只需实现 4 个抽象方法即可接入框架
This commit is contained in:
2026-05-24 14:25:47 +08:00
parent 341611c32b
commit de988b919b
3 changed files with 450 additions and 40 deletions

View File

@@ -8,72 +8,115 @@ from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import pandas as pd
from framework_v2.config.schemas import StrategyConfig
class StrategyBase(ABC):
"""
策略抽象基类
策略抽象基类V2 增强版)
定义策略的标准生命周期:
1. 初始化配置
2. 获取数据
3. 计算因子
4. 生成信号
5. 执行回测
1. 初始化配置(使用 Pydantic Schema
2. 获取数据(支持多数据源)
3. 计算因子(使用框架因子库)
4. 生成信号(策略特定逻辑)
5. 执行回测(框架通用执行器)
子类必须实现:
- init_factors(): 初始化因子
- init_signal_generator(): 初始化信号生成器
- get_codes(): 获取标的列表
- compute_factors(): 计算因子
- generate_signals(): 生成信号
- manage_positions(): 仓位管理
"""
INTERFACE_VERSION = 2 # V2 版本
name: str = "base"
timeframe: str = "1d"
def __init__(self, config: Optional[Dict] = None):
def __init__(self, config: StrategyConfig):
"""
初始化策略
Args:
config: 策略配置字典
config: 策略配置Pydantic Schema
"""
self.config = config or {}
self._factor = None
self._signal_generator = None
self.config = config
self.name = config.metadata.strategy
# 组件将在子类中初始化
self._data_fetcher = None
self._executor = None
@abstractmethod
def init_factors(self) -> Any:
def get_codes(self) -> list:
"""
初始化因子组件
获取标的列表(策略必须实现)
Returns:
因子实例(继承 FactorBase
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
"""
pass
@abstractmethod
def init_signal_generator(self) -> Any:
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
"""
初始化信号生成器
计算因子(策略必须实现)
Args:
data: 数据字典 {code: DataFrame}
Returns:
信号生成器实例(继承 SignalGenerator
因子字典 {code: Series}
"""
pass
def get_data(self) -> Dict[str, Any]:
@abstractmethod
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
"""
获取数据(可选覆盖
生成信号(策略必须实现
Args:
factors: 因子字典 {code: Series}
Returns:
数据字典,包含:
- index_data: 指数数据
- etf_data: ETF数据
- benchmark_data: 基准数据
- valid_codes: 有效标的列表
- trading_calendar: 交易日历
信号 DataFrame包含 'signal'1=买入0=空仓)
"""
raise NotImplementedError("Subclasses must implement get_data()")
pass
@abstractmethod
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
"""
仓位管理(策略必须实现)
Args:
signals: 信号 DataFrame
Returns:
仓位 DataFrame包含 'weight' 列,权重和为 1
"""
pass
def get_data(self) -> Dict[str, pd.DataFrame]:
"""
获取数据(框架实现,策略可覆盖)
Returns:
数据字典 {code: DataFrame}
"""
if self._data_fetcher is None:
self._data_fetcher = self._create_data_fetcher()
codes = self.get_codes()
# 批量获取数据fetch_indices 返回 {code: DataFrame}
try:
data = self._data_fetcher.fetch_indices(
codes=codes,
start=self.config.backtest.start_date,
end=self.config.backtest.end_date
)
return data
except Exception as e:
print(f" 错误: 数据获取失败 - {e}")
return {}
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
"""
@@ -111,28 +154,46 @@ class StrategyBase(ABC):
return self._signal_generator.generate(factor_df)
def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
"""
运行完整回测流程
运行完整回测流程(框架标准流程)
Args:
data: 可选,如不提供则自动获取
Returns:
回测结果字典
回测结果字典,包含:
- equity_curve: 净值曲线
- trades: 交易记录
- metrics: 绩效指标
"""
# 1. 获取数据
if data is None:
print("[1/5] 获取数据...")
data = self.get_data()
print(f" 获取 {len(data)} 个标的")
# 2. 计算因子
factor_df = self.compute_factors(data)
print("[2/5] 计算因子...")
factors = self.compute_factors(data)
print(f" 计算 {len(factors)} 个因子")
# 3. 生成信号
signals = self.generate_signals(factor_df)
print("[3/5] 生成信号...")
signals = self.generate_signals(factors)
print(f" 生成 {signals.shape[0]} 个信号")
# 4. 执行回测(子类实现)
return self._execute_backtest(signals, data)
# 4. 仓位管理
print("[4/5] 仓位管理...")
positions = self.manage_positions(signals)
print(f" 平均持仓: {positions['weight'].sum().mean():.2%}")
# 5. 执行回测
print("[5/5] 执行回测...")
result = self._execute_backtest(positions, data)
print(f" 回测完成")
return result
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -147,5 +208,21 @@ class StrategyBase(ABC):
"""
raise NotImplementedError("Subclasses must implement _execute_backtest()")
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"
def _create_data_fetcher(self):
"""
创建数据获取器(框架实现)
Returns:
DataFetcher 实例
"""
from framework_v2.shared.data import FlaskAPIFetcher
# 使用配置中的第一个启用的数据源
for source_config in self.config.data.sources:
if source_config.enabled and source_config.type.value == 'flask_api':
return FlaskAPIFetcher(
base_url=source_config.url,
timeout=source_config.timeout
)
raise ValueError("未找到可用的数据源配置")