Files
etf/framework_v2/core/strategy.py
aszerW de988b919b feat(framework_v2): 实现 StrategyBase 抽象基类和简单轮动策略
StrategyBase ABC:
- 定义标准回测流程:get_data → compute_factors → generate_signals → manage_positions → execute
- 实现通用数据获取(使用 FlaskAPIFetcher.fetch_indices)
- 提供 run() 方法执行完整回测流程

SimpleRotationStrategy:
- 实现 4 个抽象方法:get_codes, compute_factors, generate_signals, manage_positions
- 支持动量因子计算(MomentumFactor)
- 实现全局选股和等权仓位管理
- 修复 int64 → float 转换问题

框架定位:
- 通用量化回测框架,支持轮动、CTA、趋势跟踪等多种策略
- 策略只需实现 4 个抽象方法即可接入框架
2026-05-24 14:25:47 +08:00

229 lines
6.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略抽象基类
所有策略必须继承此类并实现必要方法
"""
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import pandas as pd
from framework_v2.config.schemas import StrategyConfig
class StrategyBase(ABC):
"""
策略抽象基类V2 增强版)
定义策略的标准生命周期:
1. 初始化配置(使用 Pydantic Schema
2. 获取数据(支持多数据源)
3. 计算因子(使用框架因子库)
4. 生成信号(策略特定逻辑)
5. 执行回测(框架通用执行器)
子类必须实现:
- get_codes(): 获取标的列表
- compute_factors(): 计算因子
- generate_signals(): 生成信号
- manage_positions(): 仓位管理
"""
INTERFACE_VERSION = 2 # V2 版本
def __init__(self, config: StrategyConfig):
"""
初始化策略
Args:
config: 策略配置Pydantic Schema
"""
self.config = config
self.name = config.metadata.strategy
# 组件将在子类中初始化
self._data_fetcher = None
self._executor = None
@abstractmethod
def get_codes(self) -> list:
"""
获取标的列表(策略必须实现)
Returns:
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
"""
pass
@abstractmethod
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
"""
计算因子(策略必须实现)
Args:
data: 数据字典 {code: DataFrame}
Returns:
因子字典 {code: Series}
"""
pass
@abstractmethod
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
"""
生成信号(策略必须实现)
Args:
factors: 因子字典 {code: Series}
Returns:
信号 DataFrame包含 'signal'1=买入0=空仓)
"""
pass
@abstractmethod
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
"""
仓位管理(策略必须实现)
Args:
signals: 信号 DataFrame
Returns:
仓位 DataFrame包含 'weight' 列,权重和为 1
"""
pass
def get_data(self) -> Dict[str, pd.DataFrame]:
"""
获取数据(框架实现,策略可覆盖)
Returns:
数据字典 {code: DataFrame}
"""
if self._data_fetcher is None:
self._data_fetcher = self._create_data_fetcher()
codes = self.get_codes()
# 批量获取数据fetch_indices 返回 {code: DataFrame}
try:
data = self._data_fetcher.fetch_indices(
codes=codes,
start=self.config.backtest.start_date,
end=self.config.backtest.end_date
)
return data
except Exception as e:
print(f" 错误: 数据获取失败 - {e}")
return {}
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
"""
计算因子(可选覆盖)
Args:
data: 数据字典
Returns:
因子 DataFrame日期 × 标的)
"""
if self._factor is None:
self._factor = self.init_factors()
# 默认实现:遍历标的计算因子
factor_values = {}
for code in data.get('valid_codes', []):
if code in data.get('index_data', {}):
factor_values[code] = self._factor.compute(data['index_data'][code])
return pd.DataFrame(factor_values)
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
"""
生成信号
Args:
factor_df: 因子 DataFrame
Returns:
信号 DataFrame包含 'signal' 列)
"""
if self._signal_generator is None:
self._signal_generator = self.init_signal_generator()
return self._signal_generator.generate(factor_df)
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
"""
运行完整回测流程(框架标准流程)
Args:
data: 可选,如不提供则自动获取
Returns:
回测结果字典,包含:
- equity_curve: 净值曲线
- trades: 交易记录
- metrics: 绩效指标
"""
# 1. 获取数据
if data is None:
print("[1/5] 获取数据...")
data = self.get_data()
print(f" 获取 {len(data)} 个标的")
# 2. 计算因子
print("[2/5] 计算因子...")
factors = self.compute_factors(data)
print(f" 计算 {len(factors)} 个因子")
# 3. 生成信号
print("[3/5] 生成信号...")
signals = self.generate_signals(factors)
print(f" 生成 {signals.shape[0]} 个信号")
# 4. 仓位管理
print("[4/5] 仓位管理...")
positions = self.manage_positions(signals)
print(f" 平均持仓: {positions['weight'].sum().mean():.2%}")
# 5. 执行回测
print("[5/5] 执行回测...")
result = self._execute_backtest(positions, data)
print(f" 回测完成")
return result
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行回测(子类可覆盖)
Args:
signals: 信号 DataFrame
data: 数据字典
Returns:
回测结果
"""
raise NotImplementedError("Subclasses must implement _execute_backtest()")
def _create_data_fetcher(self):
"""
创建数据获取器(框架实现)
Returns:
DataFetcher 实例
"""
from framework_v2.shared.data import FlaskAPIFetcher
# 使用配置中的第一个启用的数据源
for source_config in self.config.data.sources:
if source_config.enabled and source_config.type.value == 'flask_api':
return FlaskAPIFetcher(
base_url=source_config.url,
timeout=source_config.timeout
)
raise ValueError("未找到可用的数据源配置")