# Commit notes:
# - 修复 end_date=None 导致 Flask API 返回错误时间范围的 bug
#   * strategy.py: 自动使用今天日期作为 end_date
#   * 验证: 回测区间从 77 天恢复到 1539 天
# - ETF 收益计算从原始价格改为后复权价格
#   * flask_api_fetcher.py: adj='raw' → adj='hfq'
#   * 自动处理 ETF 份额拆分事件, 确保收益率准确
# - V2 简单版添加 A 股交易日过滤
#   * simple.py: 获取 SSE 交易日历, 过滤非交易日
#   * 验证: 1999 天 → 1539 天(与 V1 一致)
# - 配置严格对齐 V1 config.yaml
#   * config_simple.yaml: start_date 从 2020-01-01 改为 2020-01-10
#   * group 字段值严格映射 V1 的 market 字段
# 关键验证:
# - V2 简单版回测: 1539 天, 981.95% 收益(未计入交易成本)
# - V2 正式版回测: 1539 天, 135.63% 收益(已计入交易成本)
# - V1 旧版框架: 1539 天, 103.29% 收益(基准)
# (241 lines · 7.1 KiB · Python)
"""
策略抽象基类
Abstract base class for strategies.

所有策略必须继承此类并实现必要方法
All strategies must inherit from this class and implement the required methods.
"""

from abc import ABC, abstractmethod
from datetime import date
from typing import Any, Dict, Optional

import pandas as pd

from framework_v2.config.schemas import StrategyConfig


class StrategyBase(ABC):
    """
    Abstract strategy base class (V2 enhanced).

    Standard lifecycle, driven by :meth:`run`:

    1. Configuration: a Pydantic ``StrategyConfig`` is passed to ``__init__``.
    2. Data: :meth:`get_data` (framework default, overridable).
    3. Factors: :meth:`compute_factors`.
    4. Signals: :meth:`generate_signals`.
    5. Positions and backtest: :meth:`manage_positions`, then
       :meth:`_execute_backtest`.

    Subclasses must implement:

    - get_codes(): return the list of instrument codes
    - manage_positions(): turn signals into positions
    - _execute_backtest(): run the backtest (the base version raises
      NotImplementedError)

    For factors and signals, a subclass either overrides compute_factors() /
    generate_signals(), or defines the hooks init_factors() /
    init_signal_generator(). The default implementations call those hooks
    lazily, once.
    """

    INTERFACE_VERSION = 2  # V2 interface

    def __init__(self, config: "StrategyConfig"):
        """
        Initialize the strategy.

        Args:
            config: strategy configuration (Pydantic schema). Its
                ``metadata.strategy`` value becomes ``self.name``.
        """
        self.config = config
        self.name = config.metadata.strategy

        # Components are set up later: the data fetcher lazily in get_data(),
        # the rest by subclasses.
        self._data_fetcher = None
        self._executor = None

    @abstractmethod
    def get_codes(self) -> list:
        """
        Return the instrument codes to trade (must be implemented).

        Returns:
            List of codes, e.g. ['399006.SZ', 'NDX', 'GC=F'].
        """
        pass

    @abstractmethod
    def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
        """
        Turn signals into positions (must be implemented).

        Args:
            signals: signal DataFrame produced by generate_signals().

        Returns:
            Position DataFrame (with a 'weight' column whose weights sum to 1).
        """
        pass

    def get_data(self) -> Dict[str, pd.DataFrame]:
        """
        Fetch data for every code from get_codes() (framework default, overridable).

        The window comes from ``config.backtest``. When ``end_date`` is None,
        today's date (``YYYY-MM-DD``) is used, so the data source always
        receives an explicit end date.

        Returns:
            ``{code: DataFrame}`` as returned by the fetcher's ``fetch_indices``.
            If fetching raises, the error is printed and an empty dict is
            returned.
        """
        if self._data_fetcher is None:
            self._data_fetcher = self._create_data_fetcher()

        codes = self.get_codes()

        start = self.config.backtest.start_date
        end = self.config.backtest.end_date
        if end is None:
            end = date.today().strftime('%Y-%m-%d')

        try:
            data = self._data_fetcher.fetch_indices(codes=codes, start=start, end=end)
        except Exception as e:
            # Best-effort by design: report and let run() continue with no data.
            print(f" 错误: 数据获取失败 - {e}")
            return {}
        return data

    def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
        """
        Compute one factor series per instrument (default implementation, overridable).

        The factor object is created once via ``self.init_factors()``. Its
        ``compute(df)`` is then applied to each instrument's DataFrame.

        Two input layouts are accepted:

        - ``{'valid_codes': [...], 'index_data': {code: DataFrame}}``: only
          codes listed in ``valid_codes`` that also appear in ``index_data``
          are used, in ``valid_codes`` order.
        - ``{code: DataFrame}``, the layout returned by :meth:`get_data`:
          every DataFrame value is used.

        Args:
            data: data dictionary in one of the layouts above.

        Returns:
            Factor DataFrame (date × instrument), one column per instrument.

        Raises:
            NotImplementedError: if neither this method is overridden nor
                ``init_factors()`` is defined.
        """
        factor = self._lazy_component('_factor', 'init_factors', 'compute_factors')

        if 'index_data' in data or 'valid_codes' in data:
            frames = data.get('index_data', {})
            selected = [
                (code, frames[code])
                for code in data.get('valid_codes', [])
                if code in frames
            ]
        else:
            selected = [
                (code, frame)
                for code, frame in data.items()
                if isinstance(frame, pd.DataFrame)
            ]

        return pd.DataFrame({code: factor.compute(frame) for code, frame in selected})

    def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
        """
        Generate signals from factors (default implementation, overridable).

        The generator is created once via ``self.init_signal_generator()``.
        Its ``generate(factor_df)`` result is returned unchanged.

        Args:
            factor_df: factor DataFrame produced by compute_factors().

        Returns:
            Signal DataFrame (with a 'signal' column).

        Raises:
            NotImplementedError: if neither this method is overridden nor
                ``init_signal_generator()`` is defined.
        """
        generator = self._lazy_component(
            '_signal_generator', 'init_signal_generator', 'generate_signals'
        )
        return generator.generate(factor_df)

    def _lazy_component(self, attr: str, factory: str, method: str) -> Any:
        """
        Return ``self.<attr>``, creating it with ``self.<factory>()`` on first use.

        ``getattr`` is used so a subclass need not pre-initialise ``attr``, and
        a value the subclass set before ``__init__`` ran is not clobbered.
        """
        component = getattr(self, attr, None)
        if component is None:
            create = getattr(self, factory, None)
            if create is None:
                raise NotImplementedError(
                    f"{type(self).__name__} must override {method}() or define {factory}()"
                )
            component = create()
            setattr(self, attr, component)
        return component

    @staticmethod
    def _average_position(positions: Any) -> Optional[float]:
        """
        Return the mean of the per-row position totals, or None if it cannot be computed.

        ``positions`` may be a signal matrix or a weight matrix. For a
        DataFrame, only numeric columns are summed across each row. A Series,
        or an object without a usable ``sum(axis=1)``, yields None instead of
        raising.
        """
        if isinstance(positions, pd.DataFrame):
            totals = positions.sum(axis=1, numeric_only=True)
        elif isinstance(positions, pd.Series) or not hasattr(positions, 'sum'):
            return None
        else:
            try:
                totals = positions.sum(axis=1)
            except (TypeError, ValueError):  # numpy AxisError subclasses ValueError
                return None
        mean = getattr(totals, 'mean', None)
        return mean() if callable(mean) else 0

    def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
        """
        Run the full backtest pipeline (framework standard flow).

        Args:
            data: optional pre-fetched data. If omitted, get_data() is called.

        Returns:
            Whatever _execute_backtest() returns, expected to contain:
            - equity_curve: equity curve
            - trades: trade records
            - metrics: performance metrics
        """
        # 1. Fetch data
        if data is None:
            print("[1/5] 获取数据...")
            data = self.get_data()
            print(f" 获取 {len(data)} 个标的")

        # 2. Compute factors
        print("[2/5] 计算因子...")
        factors = self.compute_factors(data)
        # A factor DataFrame is (date × instrument): count columns, not rows.
        n_factors = factors.shape[1] if isinstance(factors, pd.DataFrame) else len(factors)
        print(f" 计算 {n_factors} 个因子")

        # 3. Generate signals
        print("[3/5] 生成信号...")
        signals = self.generate_signals(factors)
        print(f" 生成 {signals.shape[0]} 个信号")

        # 4. Position management
        print("[4/5] 仓位管理...")
        positions = self.manage_positions(signals)
        avg_position = self._average_position(positions)
        if avg_position is None:
            print(" 仓位管理完成")
        else:
            print(f" 平均仓位: {avg_position:.2%}")

        # 5. Execute backtest
        print("[5/5] 执行回测...")
        result = self._execute_backtest(positions, data)
        print(" 回测完成")

        return result

    def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the backtest (subclasses must override).

        Args:
            signals: positions returned by manage_positions().
            data: the data dictionary used by the run.

        Returns:
            Backtest result dictionary.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Subclasses must implement _execute_backtest()")

    def _create_data_fetcher(self):
        """
        Create the data fetcher from the first enabled ``flask_api`` source.

        Disabled sources and sources of any other type are skipped.

        Returns:
            A FlaskAPIFetcher built from the source's ``url`` and ``timeout``.

        Raises:
            ValueError: if no enabled ``flask_api`` source is configured.
        """
        from framework_v2.shared.data import FlaskAPIFetcher

        for source_config in self.config.data.sources:
            if not source_config.enabled:
                continue
            # ``type`` is read via ``.value`` when it is an Enum; a plain
            # string is compared as-is.
            source_type = getattr(source_config.type, 'value', source_config.type)
            if source_type == 'flask_api':
                return FlaskAPIFetcher(
                    base_url=source_config.url,
                    timeout=source_config.timeout,
                )

        raise ValueError("未找到可用的数据源配置")