Files
etf/framework_v2/core/strategy.py
aszerW 94b9ef165b feat(v2): 增强框架核心功能与ETF复权修复
- 修复 end_date=None 导致 Flask API 返回错误时间范围的 bug
  * strategy.py: 自动使用今天日期作为 end_date
  * 验证:回测区间从 77 天恢复到 1539 天

- ETF 收益计算从原始价格改为后复权价格
  * flask_api_fetcher.py: adj='raw' → adj='hfq'
  * 自动处理 ETF 份额拆分事件,确保收益率准确

- V2 简单版添加 A 股交易日过滤
  * simple.py: 获取 SSE 交易日历,过滤非交易日
  * 验证:1999 天 → 1539 天(与 V1 一致)

- 配置严格对齐 V1 config.yaml
  * config_simple.yaml: start_date 从 2020-01-01 改为 2020-01-10
  * group 字段值严格映射 V1 的 market 字段

关键验证:
- V2 简单版回测:1539 天,981.95% 收益(未计入交易成本)
- V2 正式版回测:1539 天,135.63% 收益(已计入交易成本)
- V1 旧版框架:1539 天,103.29% 收益(基准)
2026-05-24 22:53:45 +08:00

241 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略抽象基类
所有策略必须继承此类并实现必要方法
"""
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import pandas as pd
from framework_v2.config.schemas import StrategyConfig
class StrategyBase(ABC):
"""
策略抽象基类V2 增强版)
定义策略的标准生命周期:
1. 初始化配置(使用 Pydantic Schema
2. 获取数据(支持多数据源)
3. 计算因子(使用框架因子库)
4. 生成信号(策略特定逻辑)
5. 执行回测(框架通用执行器)
子类必须实现:
- get_codes(): 获取标的列表
- compute_factors(): 计算因子
- generate_signals(): 生成信号
- manage_positions(): 仓位管理
"""
INTERFACE_VERSION = 2 # V2 版本
def __init__(self, config: StrategyConfig):
"""
初始化策略
Args:
config: 策略配置Pydantic Schema
"""
self.config = config
self.name = config.metadata.strategy
# 组件将在子类中初始化
self._data_fetcher = None
self._executor = None
@abstractmethod
def get_codes(self) -> list:
"""
获取标的列表(策略必须实现)
Returns:
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
"""
pass
@abstractmethod
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
"""
计算因子(策略必须实现)
Args:
data: 数据字典 {code: DataFrame}
Returns:
因子字典 {code: Series}
"""
pass
@abstractmethod
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
"""
生成信号(策略必须实现)
Args:
factors: 因子字典 {code: Series}
Returns:
信号 DataFrame包含 'signal'1=买入0=空仓)
"""
pass
@abstractmethod
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
"""
仓位管理(策略必须实现)
Args:
signals: 信号 DataFrame
Returns:
仓位 DataFrame包含 'weight' 列,权重和为 1
"""
pass
def get_data(self) -> Dict[str, pd.DataFrame]:
"""
获取数据(框架实现,策略可覆盖)
Returns:
数据字典 {code: DataFrame}
"""
if self._data_fetcher is None:
self._data_fetcher = self._create_data_fetcher()
codes = self.get_codes()
# 处理 end_date 为 None 的情况(使用今天)
from datetime import date
start = self.config.backtest.start_date
end = self.config.backtest.end_date
if end is None:
end = date.today().strftime('%Y-%m-%d')
# 批量获取数据fetch_indices 返回 {code: DataFrame}
try:
data = self._data_fetcher.fetch_indices(
codes=codes,
start=start,
end=end
)
return data
except Exception as e:
print(f" 错误: 数据获取失败 - {e}")
return {}
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
"""
计算因子(可选覆盖)
Args:
data: 数据字典
Returns:
因子 DataFrame日期 × 标的)
"""
if self._factor is None:
self._factor = self.init_factors()
# 默认实现:遍历标的计算因子
factor_values = {}
for code in data.get('valid_codes', []):
if code in data.get('index_data', {}):
factor_values[code] = self._factor.compute(data['index_data'][code])
return pd.DataFrame(factor_values)
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
"""
生成信号
Args:
factor_df: 因子 DataFrame
Returns:
信号 DataFrame包含 'signal' 列)
"""
if self._signal_generator is None:
self._signal_generator = self.init_signal_generator()
return self._signal_generator.generate(factor_df)
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
"""
运行完整回测流程(框架标准流程)
Args:
data: 可选,如不提供则自动获取
Returns:
回测结果字典,包含:
- equity_curve: 净值曲线
- trades: 交易记录
- metrics: 绩效指标
"""
# 1. 获取数据
if data is None:
print("[1/5] 获取数据...")
data = self.get_data()
print(f" 获取 {len(data)} 个标的")
# 2. 计算因子
print("[2/5] 计算因子...")
factors = self.compute_factors(data)
print(f" 计算 {len(factors)} 个因子")
# 3. 生成信号
print("[3/5] 生成信号...")
signals = self.generate_signals(factors)
print(f" 生成 {signals.shape[0]} 个信号")
# 4. 仓位管理
print("[4/5] 仓位管理...")
positions = self.manage_positions(signals)
# positions 可能是信号矩阵或权重矩阵,计算平均仓位
if hasattr(positions, 'sum'):
avg_position = positions.sum(axis=1).mean() if hasattr(positions.sum(axis=1), 'mean') else 0
print(f" 平均仓位: {avg_position:.2%}")
else:
print(f" 仓位管理完成")
# 5. 执行回测
print("[5/5] 执行回测...")
result = self._execute_backtest(positions, data)
print(f" 回测完成")
return result
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行回测(子类可覆盖)
Args:
signals: 信号 DataFrame
data: 数据字典
Returns:
回测结果
"""
raise NotImplementedError("Subclasses must implement _execute_backtest()")
def _create_data_fetcher(self):
"""
创建数据获取器(框架实现)
Returns:
DataFetcher 实例
"""
from framework_v2.shared.data import FlaskAPIFetcher
# 使用配置中的第一个启用的数据源
for source_config in self.config.data.sources:
if source_config.enabled and source_config.type.value == 'flask_api':
return FlaskAPIFetcher(
base_url=source_config.url,
timeout=source_config.timeout
)
raise ValueError("未找到可用的数据源配置")