feat(framework_v2): 实现 StrategyBase 抽象基类和简单轮动策略
StrategyBase ABC: - 定义标准回测流程:get_data → compute_factors → generate_signals → manage_positions → execute - 实现通用数据获取(使用 FlaskAPIFetcher.fetch_indices) - 提供 run() 方法执行完整回测流程 SimpleRotationStrategy: - 实现 4 个抽象方法:get_codes, compute_factors, generate_signals, manage_positions - 支持动量因子计算(MomentumFactor) - 实现全局选股和等权仓位管理 - 修复 int64 → float 转换问题 框架定位: - 通用量化回测框架,支持轮动、CTA、趋势跟踪等多种策略 - 策略只需实现 4 个抽象方法即可接入框架
This commit is contained in:
@@ -8,72 +8,115 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
from framework_v2.config.schemas import StrategyConfig
|
||||
|
||||
|
||||
class StrategyBase(ABC):
|
||||
"""
|
||||
策略抽象基类
|
||||
策略抽象基类(V2 增强版)
|
||||
|
||||
定义策略的标准生命周期:
|
||||
1. 初始化配置
|
||||
2. 获取数据
|
||||
3. 计算因子
|
||||
4. 生成信号
|
||||
5. 执行回测
|
||||
1. 初始化配置(使用 Pydantic Schema)
|
||||
2. 获取数据(支持多数据源)
|
||||
3. 计算因子(使用框架因子库)
|
||||
4. 生成信号(策略特定逻辑)
|
||||
5. 执行回测(框架通用执行器)
|
||||
|
||||
子类必须实现:
|
||||
- init_factors(): 初始化因子
|
||||
- init_signal_generator(): 初始化信号生成器
|
||||
- get_codes(): 获取标的列表
|
||||
- compute_factors(): 计算因子
|
||||
- generate_signals(): 生成信号
|
||||
- manage_positions(): 仓位管理
|
||||
"""
|
||||
|
||||
INTERFACE_VERSION = 2 # V2 版本
|
||||
|
||||
name: str = "base"
|
||||
timeframe: str = "1d"
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
def __init__(self, config: StrategyConfig):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置字典
|
||||
config: 策略配置(Pydantic Schema)
|
||||
"""
|
||||
self.config = config or {}
|
||||
self._factor = None
|
||||
self._signal_generator = None
|
||||
self.config = config
|
||||
self.name = config.metadata.strategy
|
||||
|
||||
# 组件将在子类中初始化
|
||||
self._data_fetcher = None
|
||||
self._executor = None
|
||||
|
||||
@abstractmethod
|
||||
def init_factors(self) -> Any:
|
||||
def get_codes(self) -> list:
|
||||
"""
|
||||
初始化因子组件
|
||||
获取标的列表(策略必须实现)
|
||||
|
||||
Returns:
|
||||
因子实例(继承 FactorBase)
|
||||
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def init_signal_generator(self) -> Any:
|
||||
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
|
||||
"""
|
||||
初始化信号生成器
|
||||
计算因子(策略必须实现)
|
||||
|
||||
Args:
|
||||
data: 数据字典 {code: DataFrame}
|
||||
|
||||
Returns:
|
||||
信号生成器实例(继承 SignalGenerator)
|
||||
因子字典 {code: Series}
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self) -> Dict[str, Any]:
|
||||
@abstractmethod
|
||||
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
|
||||
"""
|
||||
获取数据(可选覆盖)
|
||||
生成信号(策略必须实现)
|
||||
|
||||
Args:
|
||||
factors: 因子字典 {code: Series}
|
||||
|
||||
Returns:
|
||||
数据字典,包含:
|
||||
- index_data: 指数数据
|
||||
- etf_data: ETF数据
|
||||
- benchmark_data: 基准数据
|
||||
- valid_codes: 有效标的列表
|
||||
- trading_calendar: 交易日历
|
||||
信号 DataFrame(包含 'signal' 列,1=买入,0=空仓)
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement get_data()")
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
仓位管理(策略必须实现)
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
|
||||
Returns:
|
||||
仓位 DataFrame(包含 'weight' 列,权重和为 1)
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取数据(框架实现,策略可覆盖)
|
||||
|
||||
Returns:
|
||||
数据字典 {code: DataFrame}
|
||||
"""
|
||||
if self._data_fetcher is None:
|
||||
self._data_fetcher = self._create_data_fetcher()
|
||||
|
||||
codes = self.get_codes()
|
||||
|
||||
# 批量获取数据(fetch_indices 返回 {code: DataFrame})
|
||||
try:
|
||||
data = self._data_fetcher.fetch_indices(
|
||||
codes=codes,
|
||||
start=self.config.backtest.start_date,
|
||||
end=self.config.backtest.end_date
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
print(f" 错误: 数据获取失败 - {e}")
|
||||
return {}
|
||||
|
||||
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
|
||||
"""
|
||||
@@ -111,28 +154,46 @@ class StrategyBase(ABC):
|
||||
|
||||
return self._signal_generator.generate(factor_df)
|
||||
|
||||
def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
运行完整回测流程
|
||||
运行完整回测流程(框架标准流程)
|
||||
|
||||
Args:
|
||||
data: 可选,如不提供则自动获取
|
||||
|
||||
Returns:
|
||||
回测结果字典
|
||||
回测结果字典,包含:
|
||||
- equity_curve: 净值曲线
|
||||
- trades: 交易记录
|
||||
- metrics: 绩效指标
|
||||
"""
|
||||
# 1. 获取数据
|
||||
if data is None:
|
||||
print("[1/5] 获取数据...")
|
||||
data = self.get_data()
|
||||
print(f" 获取 {len(data)} 个标的")
|
||||
|
||||
# 2. 计算因子
|
||||
factor_df = self.compute_factors(data)
|
||||
print("[2/5] 计算因子...")
|
||||
factors = self.compute_factors(data)
|
||||
print(f" 计算 {len(factors)} 个因子")
|
||||
|
||||
# 3. 生成信号
|
||||
signals = self.generate_signals(factor_df)
|
||||
print("[3/5] 生成信号...")
|
||||
signals = self.generate_signals(factors)
|
||||
print(f" 生成 {signals.shape[0]} 个信号")
|
||||
|
||||
# 4. 执行回测(子类实现)
|
||||
return self._execute_backtest(signals, data)
|
||||
# 4. 仓位管理
|
||||
print("[4/5] 仓位管理...")
|
||||
positions = self.manage_positions(signals)
|
||||
print(f" 平均持仓: {positions['weight'].sum().mean():.2%}")
|
||||
|
||||
# 5. 执行回测
|
||||
print("[5/5] 执行回测...")
|
||||
result = self._execute_backtest(positions, data)
|
||||
print(f" 回测完成")
|
||||
|
||||
return result
|
||||
|
||||
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -147,5 +208,21 @@ class StrategyBase(ABC):
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement _execute_backtest()")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
def _create_data_fetcher(self):
|
||||
"""
|
||||
创建数据获取器(框架实现)
|
||||
|
||||
Returns:
|
||||
DataFetcher 实例
|
||||
"""
|
||||
from framework_v2.shared.data import FlaskAPIFetcher
|
||||
|
||||
# 使用配置中的第一个启用的数据源
|
||||
for source_config in self.config.data.sources:
|
||||
if source_config.enabled and source_config.type.value == 'flask_api':
|
||||
return FlaskAPIFetcher(
|
||||
base_url=source_config.url,
|
||||
timeout=source_config.timeout
|
||||
)
|
||||
|
||||
raise ValueError("未找到可用的数据源配置")
|
||||
|
||||
Reference in New Issue
Block a user