Files
etf/strategies/rotation/strategy.py
aszerW 893a75a27f refactor: 将回测逻辑整合到策略类,简化执行入口
重构 RotationStrategy:
- 添加 from_yaml() 从配置创建实例
- 添加 get_data() 获取数据
- 添加 compute_factors() 计算因子
- 添加 generate_signals() 生成信号
- 添加 run_backtest() 完整回测流程

简化 run_rotation.py:
- 从 264 行简化为 9 行
- 只做策略调用入口

执行方式:
  python run_rotation.py --config config/strategies/rotation.yaml
  python run_rotation.py --save-path results/my_rotation

代码方式:
  strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml')
  result = strategy.run_backtest()
2026-05-11 23:50:40 +08:00

246 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Complete implementation of the rotation strategy.

Brings together data fetching, factor computation, signal generation and
backtest execution in a single strategy class.
"""
import pandas as pd
import yaml
from datetime import datetime
from pathlib import Path
from framework.factors import FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.execution import BacktestExecutor
from framework.risk import CallbackHook, Position
from framework.strategy import StrategyBase
# Strategy-specific (custom) components
from strategies.shared.factors.momentum import MomentumFactor
from strategies.shared.signals.selectors import TopNSelector
class RotationStrategy(StrategyBase):
    """ETF rotation strategy: momentum factor + grouped Top-N selection.

    The class wires together data fetching, factor computation, signal
    generation and backtest execution.

    Usage:
        from strategies.rotation.strategy import RotationStrategy
        strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml')
        result = strategy.run_backtest()
    """

    name = "rotation"

    # Class-level defaults; each may be overridden by the same-named config
    # key (see _apply_config).
    select_num = 3
    # NOTE(review): stoploss is declared here but never read in this class
    # and is not loaded from the config - confirm whether that is intended.
    stoploss = -0.05
    n_days = 25
    rebalance_days = 1
    rebalance_threshold = 0.0
    trade_cost = 0.001

    def __init__(self, config: dict = None):
        """Initialise the strategy.

        Args:
            config: Optional configuration mapping. Keys read by this class:
                select_num, n_days, rebalance_days, rebalance_threshold,
                trade_cost, start_date, end_date, code_list, ssh_tunnel,
                use_cache. ``None`` or an empty mapping means "use defaults".

        Side effects:
            Clears ``FactorRegistry`` and re-registers ``MomentumFactor``.
        """
        self.config = config or {}
        # Always apply the config (even an empty one) so that start_date and
        # end_date exist on every instance; previously they were only set
        # when a non-empty config was supplied.
        self._apply_config(self.config)

        # Factor set-up.
        FactorRegistry.clear()
        FactorRegistry.register(MomentumFactor)
        self._factor = FactorRegistry.get(
            'momentum',
            n_days=self.n_days,
            crash_filter=True,
        )

        # code -> group mapping handed to the selector (diversified picking).
        self._group_mapping = self._build_group_mapping()

        # Signal generator.
        self._selector = TopNSelector(
            select_num=self.select_num,
            group_mapping=self._group_mapping,
            min_score=0.0,
            rebalance_days=self.rebalance_days,
            rebalance_threshold=self.rebalance_threshold,
        )

    @classmethod
    def from_yaml(cls, config_path: str) -> 'RotationStrategy':
        """Create a strategy instance from a YAML configuration file.

        Args:
            config_path: Path of the YAML file (read as UTF-8).

        Returns:
            A new instance built from the parsed configuration. A missing or
            empty ``end_date`` is replaced by today's date (YYYY-MM-DD).
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load yields None for an empty document; fall back to {}.
            config = yaml.safe_load(f) or {}
        if not config.get('end_date'):
            config['end_date'] = datetime.now().strftime('%Y-%m-%d')
        return cls(config)

    def _apply_config(self, config: dict) -> None:
        """Copy recognised settings from ``config`` onto the instance.

        Numeric parameters fall back to their current (class-level) values.
        ``start_date`` falls back to '2019-01-01' and ``end_date`` to today;
        a key that is present but empty is treated as missing, which matches
        the handling in ``from_yaml``.
        """
        self.select_num = config.get('select_num', self.select_num)
        self.n_days = config.get('n_days', self.n_days)
        self.rebalance_days = config.get('rebalance_days', self.rebalance_days)
        self.rebalance_threshold = config.get(
            'rebalance_threshold', self.rebalance_threshold
        )
        self.trade_cost = config.get('trade_cost', self.trade_cost)
        self.start_date = config.get('start_date') or '2019-01-01'
        self.end_date = (
            config.get('end_date') or datetime.now().strftime('%Y-%m-%d')
        )

    def _build_group_mapping(self) -> dict:
        """Build the code -> group mapping used for diversified selection.

        Only ``code_list`` entries whose value is a dict contribute; the group
        is that dict's ``market`` value, or 'default' when absent.
        """
        # `or {}` also covers a YAML key that is present but empty (None).
        code_list_config = self.config.get('code_list') or {}
        return {
            code: cfg.get('market', 'default')
            for code, cfg in code_list_config.items()
            if isinstance(cfg, dict)
        }

    def get_data(self) -> dict:
        """Fetch price data for every configured code.

        Uses the archived ``HybridDataSource`` (imported lazily).

        Returns:
            ``{'index_data': {code: DataFrame}, 'valid_codes': [code, ...]}``
            containing only codes whose frame is neither None nor empty.

        Raises:
            ValueError: If the configuration has no ``code_list`` entries.
        """
        code_list_config = self.config.get('code_list') or {}
        code_list = list(code_list_config.keys())
        if not code_list:
            raise ValueError("配置中未找到 code_list")

        # Reuse the archived HybridDataSource.
        from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource

        tunnel_cfg = self.config.get('ssh_tunnel') or {}
        if tunnel_cfg.get('enabled'):
            ssh_config = {
                'host': tunnel_cfg.get('host'),
                'port': tunnel_cfg.get('port', 22),
                'username': tunnel_cfg.get('username', 'root'),
                'key_path': tunnel_cfg.get('key_path', 'hk_ecs.pem'),
                'local_port': tunnel_cfg.get('local_port', 1080),
            }
        else:
            ssh_config = None

        data_source = HybridDataSource(
            ssh_config=ssh_config,
            use_cache=self.config.get('use_cache', True),
        )

        all_data = data_source.fetch_batch(
            code_list, self.start_date, self.end_date
        )
        # Drop codes for which nothing usable came back.
        index_data = {
            code: df
            for code, df in all_data.items()
            if df is not None and not df.empty
        }
        return {
            'index_data': index_data,
            'valid_codes': list(index_data.keys()),
        }

    def compute_factors(self, data: dict) -> pd.DataFrame:
        """Compute the factor values for every code with enough history.

        Args:
            data: Mapping with ``index_data`` and ``valid_codes`` as returned
                by ``get_data``.

        Returns:
            DataFrame with one column per code; codes with fewer than
            ``n_days`` rows are skipped.
        """
        index_data = data['index_data']
        factor_values = {
            code: self._factor.compute(index_data[code])
            for code in data['valid_codes']
            if len(index_data[code]) >= self.n_days
        }
        return pd.DataFrame(factor_values)

    def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
        """Turn the factor frame into signals via the configured selector."""
        return self._selector.generate(factor_df)

    def run_backtest(self, data: dict = None, save_path: str = None) -> dict:
        """Run the full backtest pipeline.

        Args:
            data: Optional pre-fetched data (same shape as ``get_data``
                returns); fetched automatically when omitted.
            save_path: Optional path prefix; when given and a result is
                available, ``<save_path>_nav.csv`` and
                ``<save_path>_signals.csv`` are written.

        Returns:
            Dict with keys ``signals``, ``result``, ``portfolio`` and
            ``total_return``. ``result`` and ``total_return`` are None when
            the portfolio returned by the executor has no
            ``backtest_result`` attribute.

        Raises:
            ValueError: If there is no valid code to backtest.
        """
        print("\n" + "=" * 60)
        print(" ETF轮动策略 回测系统")
        print("=" * 60)

        # 1. Fetch data.
        if data is None:
            data = self.get_data()
        valid_codes = data['valid_codes']
        index_data = data['index_data']
        if not valid_codes:
            # Fail with a clear message instead of an IndexError further down.
            raise ValueError("未获取到任何有效标的数据,无法执行回测")
        print(f"\n候选标的: {len(valid_codes)}")
        print(f"回测区间: {self.start_date} ~ {self.end_date}")

        # 2. Compute factors.
        print("\n计算因子...")
        factor_df = self.compute_factors(data)
        print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)}")

        # 3. Generate signals.
        print("\n生成信号...")
        signals = self.generate_signals(factor_df)
        print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)}")

        # 4. Execute the backtest on daily close-to-close returns.
        print("\n执行回测...")
        returns_df = pd.DataFrame({
            f'日收益率_{code}': index_data[code]['close'].pct_change()
            for code in valid_codes
        })
        # NOTE(review): the index is overwritten with the first code's index;
        # this presumes every frame shares that index - confirm with the
        # data source.
        returns_df.index = index_data[valid_codes[0]].index

        executor = BacktestExecutor(
            initial_capital=100000,
            trade_cost=self.trade_cost,
            select_num=self.select_num,
        )
        portfolio = executor.execute(signals, returns_df)

        # 5. Report.
        if not hasattr(portfolio, 'backtest_result'):
            # Same keys as the success path so callers can index uniformly.
            return {
                'signals': signals,
                'result': None,
                'portfolio': portfolio,
                'total_return': None,
            }

        result = portfolio.backtest_result
        final_nav = result['策略净值'].iloc[-1]
        total_return = (final_nav - 1) * 100
        print("\n回测结果:")
        print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%")

        # Persist the report.
        if save_path:
            nav_file = f"{save_path}_nav.csv"
            # Create the target directory so a fresh checkout can save.
            Path(nav_file).parent.mkdir(parents=True, exist_ok=True)
            result[['策略净值']].to_csv(nav_file)
            signals.to_csv(f"{save_path}_signals.csv")
            print(f" 报告保存: {save_path}_*.csv")

        return {
            'signals': signals,
            'result': result,
            'portfolio': portfolio,
            'total_return': total_return,
        }

    # Implementations of the base-class hooks.
    def init_factors(self) -> FactorCombiner:
        """Return a combiner wrapping the single momentum factor."""
        return FactorCombiner([self._factor])

    def init_signal_generator(self) -> SignalGenerator:
        """Return the selector built in ``__init__``."""
        return self._selector