Files
etf/framework_v2/strategies/rotation/simple.py
aszerW de988b919b feat(framework_v2): 实现 StrategyBase 抽象基类和简单轮动策略
StrategyBase ABC:
- 定义标准回测流程:get_data → compute_factors → generate_signals → manage_positions → execute
- 实现通用数据获取(使用 FlaskAPIFetcher.fetch_indices)
- 提供 run() 方法执行完整回测流程

SimpleRotationStrategy:
- 实现 4 个抽象方法:get_codes, compute_factors, generate_signals, manage_positions
- 支持动量因子计算(MomentumFactor)
- 实现全局选股和等权仓位管理
- 修复 int64 → float 转换问题

框架定位:
- 通用量化回测框架,支持轮动、CTA、趋势跟踪等多种策略
- 策略只需实现 4 个抽象方法即可接入框架
2026-05-24 14:25:47 +08:00

242 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
简单轮动策略
基于动量因子的 ETF 轮动策略
- 计算各标的动量得分
- 选择 Top-N 标的
- 等权分配仓位
"""
import pandas as pd
import numpy as np
from typing import Dict
from framework_v2.core.strategy import StrategyBase
from framework_v2.config.schemas import StrategyConfig
from framework_v2.shared.factors import MomentumFactor
class SimpleRotationStrategy(StrategyBase):
"""
简单轮动策略
策略逻辑:
1. 计算各标的动量得分(加权线性回归)
2. 选择得分最高的 Top-N 标的
3. 等权分配仓位
示例:
from framework_v2.config import load_config
from framework_v2.strategies.rotation.simple import SimpleRotationStrategy
config = load_config('rotation_simple.yaml')
strategy = SimpleRotationStrategy(config)
result = strategy.run()
"""
def __init__(self, config: StrategyConfig):
"""
初始化策略
Args:
config: 策略配置
"""
super().__init__(config)
# 初始化动量因子
self.momentum = MomentumFactor(
n_days=config.factor.n_days,
weighted=(config.factor.type.value == 'weighted_momentum')
)
# 策略参数
self.select_num = config.rotation.select_num if config.rotation else 3
self.min_score = config.rotation.threshold.fixed_value if config.rotation else 0.0
def get_codes(self) -> list:
"""
获取标的列表
从配置的资产池中获取所有标的
"""
codes = []
# 股票资产
if self.config.asset_pools.equity:
codes.extend(self.config.asset_pools.equity.keys())
# 商品资产
if self.config.asset_pools.commodity:
codes.extend(self.config.asset_pools.commodity.keys())
# 固定收益资产
if self.config.asset_pools.fixed_income:
codes.extend(self.config.asset_pools.fixed_income.keys())
return codes
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
"""
计算动量因子
Args:
data: 数据字典 {code: DataFrame}
Returns:
因子字典 {code: Series}
"""
factors = {}
for code, df in data.items():
try:
# 计算动量得分
factor_values = self.momentum.compute(df)
factors[code] = factor_values
except Exception as e:
print(f" 警告: {code} 因子计算失败 - {e}")
continue
return factors
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
"""
生成轮动信号
逻辑:
1. 每个交易日选择动量得分最高的 Top-N 标的
2. 过滤得分低于阈值的标的
Args:
factors: 因子字典 {code: Series}
Returns:
信号 DataFrameindex=日期, columns=标的, values=1或0
"""
if not factors:
return pd.DataFrame()
# 对齐所有因子的日期
factor_df = pd.DataFrame(factors)
# 生成信号
signals = pd.DataFrame(index=factor_df.index, columns=factor_df.columns, data=0)
for date in factor_df.index:
# 获取当日因子值
scores = factor_df.loc[date].dropna()
if scores.empty:
continue
# 过滤低分标的
if self.min_score > 0:
scores = scores[scores >= self.min_score]
# 选择 Top-N
if len(scores) > self.select_num:
top_codes = scores.nlargest(self.select_num).index
else:
top_codes = scores.index
# 标记信号
signals.loc[date, top_codes] = 1
return signals.astype(int)
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
"""
仓位管理(等权分配)
Args:
signals: 信号 DataFrame
Returns:
仓位 DataFrame包含 'weight' 列)
"""
positions = signals.astype(float).copy()
# 计算每个日期的权重
for date in positions.index:
signal_row = positions.loc[date]
n_selected = signal_row.sum()
if n_selected > 0:
# 等权分配
positions.loc[date] = signal_row / n_selected
else:
# 空仓
positions.loc[date] = 0
return positions
def _execute_backtest(self, positions: pd.DataFrame, data: Dict[str, pd.DataFrame]) -> Dict[str, any]:
"""
执行回测
Args:
positions: 仓位 DataFrame
data: 数据字典 {code: DataFrame}
Returns:
回测结果字典
"""
# 提取收盘价
close_prices = {}
for code, df in data.items():
if 'close' in df.columns:
close_prices[code] = df['close']
close_df = pd.DataFrame(close_prices)
# 计算收益率
returns = close_df.pct_change()
# 计算策略收益(仓位加权)
# 注意T+1 执行,今天的信号明天生效
positions_delayed = positions.shift(1).fillna(0)
strategy_returns = (positions_delayed * returns).sum(axis=1)
# 计算净值曲线
equity_curve = (1 + strategy_returns).cumprod()
# 检查是否有数据
if len(equity_curve) == 0:
return {
'equity_curve': equity_curve,
'strategy_returns': strategy_returns,
'positions': positions,
'metrics': {
'total_return': 0,
'annual_return': 0,
'max_drawdown': 0,
'sharpe_ratio': 0,
'n_days': 0,
}
}
# 计算绩效指标
total_return = equity_curve.iloc[-1] / equity_curve.iloc[0] - 1
n_days = len(strategy_returns)
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
# 最大回撤
cumulative_max = equity_curve.cummax()
drawdown = (equity_curve - cumulative_max) / cumulative_max
max_drawdown = drawdown.min()
# 夏普比率
sharpe = strategy_returns.mean() / strategy_returns.std() * np.sqrt(252) if strategy_returns.std() > 0 else 0
return {
'equity_curve': equity_curve,
'strategy_returns': strategy_returns,
'positions': positions,
'metrics': {
'total_return': total_return,
'annual_return': annual_return,
'max_drawdown': max_drawdown,
'sharpe_ratio': sharpe,
'n_days': n_days,
}
}