StrategyBase ABC: - 定义标准回测流程:get_data → compute_factors → generate_signals → manage_positions → execute - 实现通用数据获取(使用 FlaskAPIFetcher.fetch_indices) - 提供 run() 方法执行完整回测流程 SimpleRotationStrategy: - 实现 4 个抽象方法:get_codes, compute_factors, generate_signals, manage_positions - 支持动量因子计算(MomentumFactor) - 实现全局选股和等权仓位管理 - 修复 int64 → float 转换问题 框架定位: - 通用量化回测框架,支持轮动、CTA、趋势跟踪等多种策略 - 策略只需实现 4 个抽象方法即可接入框架
242 lines
7.2 KiB
Python
242 lines
7.2 KiB
Python
"""
|
||
简单轮动策略
|
||
|
||
基于动量因子的 ETF 轮动策略
|
||
- 计算各标的动量得分
|
||
- 选择 Top-N 标的
|
||
- 等权分配仓位
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
from typing import Dict
|
||
|
||
from framework_v2.core.strategy import StrategyBase
|
||
from framework_v2.config.schemas import StrategyConfig
|
||
from framework_v2.shared.factors import MomentumFactor
|
||
|
||
|
||
class SimpleRotationStrategy(StrategyBase):
|
||
"""
|
||
简单轮动策略
|
||
|
||
策略逻辑:
|
||
1. 计算各标的动量得分(加权线性回归)
|
||
2. 选择得分最高的 Top-N 标的
|
||
3. 等权分配仓位
|
||
|
||
示例:
|
||
from framework_v2.config import load_config
|
||
from framework_v2.strategies.rotation.simple import SimpleRotationStrategy
|
||
|
||
config = load_config('rotation_simple.yaml')
|
||
strategy = SimpleRotationStrategy(config)
|
||
result = strategy.run()
|
||
"""
|
||
|
||
def __init__(self, config: StrategyConfig):
|
||
"""
|
||
初始化策略
|
||
|
||
Args:
|
||
config: 策略配置
|
||
"""
|
||
super().__init__(config)
|
||
|
||
# 初始化动量因子
|
||
self.momentum = MomentumFactor(
|
||
n_days=config.factor.n_days,
|
||
weighted=(config.factor.type.value == 'weighted_momentum')
|
||
)
|
||
|
||
# 策略参数
|
||
self.select_num = config.rotation.select_num if config.rotation else 3
|
||
self.min_score = config.rotation.threshold.fixed_value if config.rotation else 0.0
|
||
|
||
def get_codes(self) -> list:
|
||
"""
|
||
获取标的列表
|
||
|
||
从配置的资产池中获取所有标的
|
||
"""
|
||
codes = []
|
||
|
||
# 股票资产
|
||
if self.config.asset_pools.equity:
|
||
codes.extend(self.config.asset_pools.equity.keys())
|
||
|
||
# 商品资产
|
||
if self.config.asset_pools.commodity:
|
||
codes.extend(self.config.asset_pools.commodity.keys())
|
||
|
||
# 固定收益资产
|
||
if self.config.asset_pools.fixed_income:
|
||
codes.extend(self.config.asset_pools.fixed_income.keys())
|
||
|
||
return codes
|
||
|
||
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
|
||
"""
|
||
计算动量因子
|
||
|
||
Args:
|
||
data: 数据字典 {code: DataFrame}
|
||
|
||
Returns:
|
||
因子字典 {code: Series}
|
||
"""
|
||
factors = {}
|
||
|
||
for code, df in data.items():
|
||
try:
|
||
# 计算动量得分
|
||
factor_values = self.momentum.compute(df)
|
||
factors[code] = factor_values
|
||
except Exception as e:
|
||
print(f" 警告: {code} 因子计算失败 - {e}")
|
||
continue
|
||
|
||
return factors
|
||
|
||
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
|
||
"""
|
||
生成轮动信号
|
||
|
||
逻辑:
|
||
1. 每个交易日选择动量得分最高的 Top-N 标的
|
||
2. 过滤得分低于阈值的标的
|
||
|
||
Args:
|
||
factors: 因子字典 {code: Series}
|
||
|
||
Returns:
|
||
信号 DataFrame(index=日期, columns=标的, values=1或0)
|
||
"""
|
||
if not factors:
|
||
return pd.DataFrame()
|
||
|
||
# 对齐所有因子的日期
|
||
factor_df = pd.DataFrame(factors)
|
||
|
||
# 生成信号
|
||
signals = pd.DataFrame(index=factor_df.index, columns=factor_df.columns, data=0)
|
||
|
||
for date in factor_df.index:
|
||
# 获取当日因子值
|
||
scores = factor_df.loc[date].dropna()
|
||
|
||
if scores.empty:
|
||
continue
|
||
|
||
# 过滤低分标的
|
||
if self.min_score > 0:
|
||
scores = scores[scores >= self.min_score]
|
||
|
||
# 选择 Top-N
|
||
if len(scores) > self.select_num:
|
||
top_codes = scores.nlargest(self.select_num).index
|
||
else:
|
||
top_codes = scores.index
|
||
|
||
# 标记信号
|
||
signals.loc[date, top_codes] = 1
|
||
|
||
return signals.astype(int)
|
||
|
||
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
|
||
"""
|
||
仓位管理(等权分配)
|
||
|
||
Args:
|
||
signals: 信号 DataFrame
|
||
|
||
Returns:
|
||
仓位 DataFrame(包含 'weight' 列)
|
||
"""
|
||
positions = signals.astype(float).copy()
|
||
|
||
# 计算每个日期的权重
|
||
for date in positions.index:
|
||
signal_row = positions.loc[date]
|
||
n_selected = signal_row.sum()
|
||
|
||
if n_selected > 0:
|
||
# 等权分配
|
||
positions.loc[date] = signal_row / n_selected
|
||
else:
|
||
# 空仓
|
||
positions.loc[date] = 0
|
||
|
||
return positions
|
||
|
||
def _execute_backtest(self, positions: pd.DataFrame, data: Dict[str, pd.DataFrame]) -> Dict[str, any]:
|
||
"""
|
||
执行回测
|
||
|
||
Args:
|
||
positions: 仓位 DataFrame
|
||
data: 数据字典 {code: DataFrame}
|
||
|
||
Returns:
|
||
回测结果字典
|
||
"""
|
||
# 提取收盘价
|
||
close_prices = {}
|
||
for code, df in data.items():
|
||
if 'close' in df.columns:
|
||
close_prices[code] = df['close']
|
||
|
||
close_df = pd.DataFrame(close_prices)
|
||
|
||
# 计算收益率
|
||
returns = close_df.pct_change()
|
||
|
||
# 计算策略收益(仓位加权)
|
||
# 注意:T+1 执行,今天的信号明天生效
|
||
positions_delayed = positions.shift(1).fillna(0)
|
||
strategy_returns = (positions_delayed * returns).sum(axis=1)
|
||
|
||
# 计算净值曲线
|
||
equity_curve = (1 + strategy_returns).cumprod()
|
||
|
||
# 检查是否有数据
|
||
if len(equity_curve) == 0:
|
||
return {
|
||
'equity_curve': equity_curve,
|
||
'strategy_returns': strategy_returns,
|
||
'positions': positions,
|
||
'metrics': {
|
||
'total_return': 0,
|
||
'annual_return': 0,
|
||
'max_drawdown': 0,
|
||
'sharpe_ratio': 0,
|
||
'n_days': 0,
|
||
}
|
||
}
|
||
|
||
# 计算绩效指标
|
||
total_return = equity_curve.iloc[-1] / equity_curve.iloc[0] - 1
|
||
n_days = len(strategy_returns)
|
||
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
|
||
|
||
# 最大回撤
|
||
cumulative_max = equity_curve.cummax()
|
||
drawdown = (equity_curve - cumulative_max) / cumulative_max
|
||
max_drawdown = drawdown.min()
|
||
|
||
# 夏普比率
|
||
sharpe = strategy_returns.mean() / strategy_returns.std() * np.sqrt(252) if strategy_returns.std() > 0 else 0
|
||
|
||
return {
|
||
'equity_curve': equity_curve,
|
||
'strategy_returns': strategy_returns,
|
||
'positions': positions,
|
||
'metrics': {
|
||
'total_return': total_return,
|
||
'annual_return': annual_return,
|
||
'max_drawdown': max_drawdown,
|
||
'sharpe_ratio': sharpe,
|
||
'n_days': n_days,
|
||
}
|
||
}
|