Files
etf/framework_v2/core/strategy.py
aszerW e6657bd2cc feat(framework_v2): 对齐 V1 配置,实现指数信号→ETF收益回测
配置对齐:
- config_simple.yaml 严格对齐 V1 config.yaml
  * 11 个标的覆盖 7 个策略分组
  * 回测区间: 2020-01-01 ~ 至今
  * 选股数量: Top-3,强制分散化
  * V3 动态阈值(短债动量参考)
  * 溢价控制启用(HK/US 10%阈值)

策略实现:
- SimpleRotationStrategy 支持 signal_source/trade_source 分离
  * get_codes() 同时获取信号和交易标的
  * compute_factors() 只使用 signal_source 计算因子
  * _execute_backtest() 使用 trade_source 计算收益
  * 支持跨市场场景(指数信号 → ETF收益)

回测验证:
- 成功运行端到端回测
- 获取 21 个标的(11 signal + 10 trade)
- 平均仓位 84.42%
- ⚠️ 已知问题: Flask API 只返回缓存数据(2026年),需修复

修复项:
- StrategyBase.run() 兼容信号矩阵(移除 'weight' 列假设)
2026-05-24 14:58:41 +08:00

234 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略抽象基类
所有策略必须继承此类并实现必要方法
"""
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import pandas as pd
from framework_v2.config.schemas import StrategyConfig
class StrategyBase(ABC):
"""
策略抽象基类V2 增强版)
定义策略的标准生命周期:
1. 初始化配置(使用 Pydantic Schema
2. 获取数据(支持多数据源)
3. 计算因子(使用框架因子库)
4. 生成信号(策略特定逻辑)
5. 执行回测(框架通用执行器)
子类必须实现:
- get_codes(): 获取标的列表
- compute_factors(): 计算因子
- generate_signals(): 生成信号
- manage_positions(): 仓位管理
"""
INTERFACE_VERSION = 2 # V2 版本
def __init__(self, config: StrategyConfig):
"""
初始化策略
Args:
config: 策略配置Pydantic Schema
"""
self.config = config
self.name = config.metadata.strategy
# 组件将在子类中初始化
self._data_fetcher = None
self._executor = None
@abstractmethod
def get_codes(self) -> list:
"""
获取标的列表(策略必须实现)
Returns:
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
"""
pass
@abstractmethod
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
"""
计算因子(策略必须实现)
Args:
data: 数据字典 {code: DataFrame}
Returns:
因子字典 {code: Series}
"""
pass
@abstractmethod
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
"""
生成信号(策略必须实现)
Args:
factors: 因子字典 {code: Series}
Returns:
信号 DataFrame包含 'signal'1=买入0=空仓)
"""
pass
@abstractmethod
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
"""
仓位管理(策略必须实现)
Args:
signals: 信号 DataFrame
Returns:
仓位 DataFrame包含 'weight' 列,权重和为 1
"""
pass
def get_data(self) -> Dict[str, pd.DataFrame]:
"""
获取数据(框架实现,策略可覆盖)
Returns:
数据字典 {code: DataFrame}
"""
if self._data_fetcher is None:
self._data_fetcher = self._create_data_fetcher()
codes = self.get_codes()
# 批量获取数据fetch_indices 返回 {code: DataFrame}
try:
data = self._data_fetcher.fetch_indices(
codes=codes,
start=self.config.backtest.start_date,
end=self.config.backtest.end_date
)
return data
except Exception as e:
print(f" 错误: 数据获取失败 - {e}")
return {}
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
"""
计算因子(可选覆盖)
Args:
data: 数据字典
Returns:
因子 DataFrame日期 × 标的)
"""
if self._factor is None:
self._factor = self.init_factors()
# 默认实现:遍历标的计算因子
factor_values = {}
for code in data.get('valid_codes', []):
if code in data.get('index_data', {}):
factor_values[code] = self._factor.compute(data['index_data'][code])
return pd.DataFrame(factor_values)
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
"""
生成信号
Args:
factor_df: 因子 DataFrame
Returns:
信号 DataFrame包含 'signal' 列)
"""
if self._signal_generator is None:
self._signal_generator = self.init_signal_generator()
return self._signal_generator.generate(factor_df)
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
"""
运行完整回测流程(框架标准流程)
Args:
data: 可选,如不提供则自动获取
Returns:
回测结果字典,包含:
- equity_curve: 净值曲线
- trades: 交易记录
- metrics: 绩效指标
"""
# 1. 获取数据
if data is None:
print("[1/5] 获取数据...")
data = self.get_data()
print(f" 获取 {len(data)} 个标的")
# 2. 计算因子
print("[2/5] 计算因子...")
factors = self.compute_factors(data)
print(f" 计算 {len(factors)} 个因子")
# 3. 生成信号
print("[3/5] 生成信号...")
signals = self.generate_signals(factors)
print(f" 生成 {signals.shape[0]} 个信号")
# 4. 仓位管理
print("[4/5] 仓位管理...")
positions = self.manage_positions(signals)
# positions 可能是信号矩阵或权重矩阵,计算平均仓位
if hasattr(positions, 'sum'):
avg_position = positions.sum(axis=1).mean() if hasattr(positions.sum(axis=1), 'mean') else 0
print(f" 平均仓位: {avg_position:.2%}")
else:
print(f" 仓位管理完成")
# 5. 执行回测
print("[5/5] 执行回测...")
result = self._execute_backtest(positions, data)
print(f" 回测完成")
return result
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行回测(子类可覆盖)
Args:
signals: 信号 DataFrame
data: 数据字典
Returns:
回测结果
"""
raise NotImplementedError("Subclasses must implement _execute_backtest()")
def _create_data_fetcher(self):
"""
创建数据获取器(框架实现)
Returns:
DataFetcher 实例
"""
from framework_v2.shared.data import FlaskAPIFetcher
# 使用配置中的第一个启用的数据源
for source_config in self.config.data.sources:
if source_config.enabled and source_config.type.value == 'flask_api':
return FlaskAPIFetcher(
base_url=source_config.url,
timeout=source_config.timeout
)
raise ValueError("未找到可用的数据源配置")