Files
etf/strategies/rotation/strategy.py
aszerW 19131c41dd fix: 数据源路由修复与因子计算改进
1. 修复期货路由逻辑:NYMEX期货(.NYM)走YFinance而非Tushare
2. 添加SSH隧道路径修复(原引擎)
3. 因子计算只使用close列(处理部分指数只有收盘价的情况)
4. 添加数据不足和缺失率剔除日志

收益对比:
- 原引擎(剔除国债): 累计1804%, 调仓459次
- 新框架: 累计772%, 调仓1276次

差异原因待查:
- 国债剔除逻辑不同
- 调仓频率差异
2026-05-12 00:47:43 +08:00

305 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
轮动策略完整实现
整合数据获取、因子计算、信号生成、回测执行
"""
import pandas as pd
import yaml
from datetime import datetime
from pathlib import Path
# 加载环境变量
from dotenv import load_dotenv
load_dotenv()
from framework.factors import FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.execution import BacktestExecutor
from framework.risk import CallbackHook, Position
from framework.strategy import StrategyBase
# 导入定制组件
from strategies.shared.factors.momentum import MomentumFactor
from strategies.shared.signals.selectors import TopNSelector
class RotationStrategy(StrategyBase):
    """
    Complete ETF rotation strategy.

    Pipeline: momentum factor -> Top-N selection (diversified across the
    groups given by ``code_list[<code>]['market']``) -> backtest execution.

    Usage::

        from strategies.rotation.strategy import RotationStrategy
        strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml')
        result = strategy.run_backtest()
    """
    name = "rotation"
    select_num = 3
    # NOTE(review): stoploss is not read anywhere in this class; confirm
    # whether the framework consumes it.
    stoploss = -0.05
    n_days = 25
    rebalance_days = 1
    rebalance_threshold = 0.0
    # Passed straight to BacktestExecutor; presumably a per-trade cost
    # fraction - verify against the executor.
    trade_cost = 0.001

    def __init__(self, config: dict = None):
        """
        Initialise the strategy.

        Args:
            config: Optional configuration dict (e.g. loaded from YAML).
                Keys read by this class: select_num, n_days, rebalance_days,
                rebalance_threshold, trade_cost, start_date, end_date,
                code_list, benchmark, ssh_tunnel, use_cache.
        """
        # Always apply the (possibly empty) config so that start_date and
        # end_date are set. Previously a strategy built without a config had
        # no dates, and get_data()/run_backtest() raised AttributeError.
        self.config = config or {}
        self._apply_config(self.config)

        # Factor set-up. NOTE: FactorRegistry.clear() resets the registry
        # globally, not just for this instance.
        FactorRegistry.clear()
        FactorRegistry.register(MomentumFactor)
        self._factor = FactorRegistry.get(
            'momentum',
            n_days=self.n_days,
            crash_filter=True,
        )

        # code -> group mapping used for diversified selection.
        self._group_mapping = self._build_group_mapping()

        # Signal generator. min_score=0.0 presumably excludes non-positive
        # scores; confirm against TopNSelector.
        self._selector = TopNSelector(
            select_num=self.select_num,
            group_mapping=self._group_mapping,
            min_score=0.0,
            rebalance_days=self.rebalance_days,
            rebalance_threshold=self.rebalance_threshold,
        )

    @classmethod
    def from_yaml(cls, config_path: str) -> 'RotationStrategy':
        """
        Create a strategy from a YAML configuration file.

        An empty file is treated as an empty config. A missing or blank
        ``end_date`` defaults to today.

        Raises:
            ValueError: if the YAML document is not a mapping.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}
        if not isinstance(config, dict):
            raise ValueError(
                f"{config_path}: expected a YAML mapping, got {type(config).__name__}"
            )
        if not config.get('end_date'):
            config['end_date'] = datetime.now().strftime('%Y-%m-%d')
        return cls(config)

    def _apply_config(self, config: dict) -> None:
        """
        Copy tunable parameters from *config* onto the instance.

        Missing keys keep the current (class-level) values. ``start_date``
        defaults to '2019-01-01'. A missing or empty ``end_date`` defaults to
        today, matching from_yaml().
        """
        self.select_num = config.get('select_num', self.select_num)
        self.n_days = config.get('n_days', self.n_days)
        self.rebalance_days = config.get('rebalance_days', self.rebalance_days)
        self.rebalance_threshold = config.get('rebalance_threshold', self.rebalance_threshold)
        self.trade_cost = config.get('trade_cost', self.trade_cost)
        self.start_date = config.get('start_date', '2019-01-01')
        self.end_date = config.get('end_date') or datetime.now().strftime('%Y-%m-%d')

    def _build_group_mapping(self) -> dict:
        """
        Build the {code: group} mapping used for diversified selection.

        Only entries whose config is a dict are mapped. Their group is the
        'market' value, or 'default' when it is absent.
        """
        # `or {}` also covers a present-but-empty `code_list:` in YAML (None).
        code_list_config = self.config.get('code_list') or {}
        return {
            code: cfg.get('market', 'default')
            for code, cfg in code_list_config.items()
            if isinstance(cfg, dict)
        }

    def get_data(self) -> dict:
        """
        Fetch all data through the HybridDataSource module.

        Returns:
            dict with keys 'index_data' (raw per-index OHLCV), 'index_close'
            (aligned wide close table), 'etf_data', 'etf_nav_data',
            'benchmark_data', 'valid_codes' and 'etf_code_map'
            ({index code: ETF code}).

        Raises:
            ValueError: if the config has no (or an empty) code_list.
        """
        code_list_config = self.config.get('code_list') or {}
        # `or {}` guards against a blank `benchmark:` entry (None) in YAML.
        benchmark_config = self.config.get('benchmark') or {}
        benchmark_code = benchmark_config.get('code', '000300.SH')
        if not code_list_config:
            raise ValueError("配置中未找到 code_list")

        # Imported lazily so the module can be loaded without the datasource.
        from datasource import HybridDataSource
        data_source = HybridDataSource(
            ssh_config=self.config.get('ssh_tunnel', {}),
            use_cache=self.config.get('use_cache', True),
        )

        (index_data, etf_data, etf_nav_data, benchmark_data,
         valid_codes, index_ohlcv_data, etf_code_map) = data_source.fetch_all(
            code_config=code_list_config,
            benchmark_code=benchmark_code,
            start_date=self.start_date,
            end_date=self.end_date,
        )
        return {
            'index_data': index_ohlcv_data,   # raw OHLCV per index
            'index_close': index_data,        # aligned close prices (wide format)
            'etf_data': etf_data,
            'etf_nav_data': etf_nav_data,
            'benchmark_data': benchmark_data,
            'valid_codes': valid_codes,
            'etf_code_map': etf_code_map,     # {index code: ETF code}
        }

    @staticmethod
    def _close_column(df) -> pd.Series:
        """
        Return the raw (NaNs kept) close-price series of one instrument.

        Some indices only provide a close price. *df* may therefore be a
        DataFrame with a 'close' column, a single-column DataFrame, or a bare
        Series.

        Raises:
            ValueError: if *df* has several columns and none is named 'close'.
        """
        if isinstance(df, pd.Series):
            return df
        if 'close' in df.columns:
            return df['close']
        if df.shape[1] == 1:
            return df.iloc[:, 0]
        raise ValueError(
            f"cannot determine close column: no 'close' among {list(df.columns)}"
        )

    def compute_factors(self, data: dict) -> pd.DataFrame:
        """
        Compute the momentum factor for every valid code.

        Only close prices are used, matching the original engine, because some
        indices only have a close series. Codes are dropped (with a log line) when:
        * they have fewer than n_days + 1 non-NaN closes, or
        * more than 50% of their factor values are NaN.

        Side effect: ``data['valid_codes']`` is replaced with the surviving codes.

        Returns:
            Wide DataFrame with one factor column per surviving code.
        """
        index_data = data['index_data']
        factor_values = {}
        for code in data['valid_codes']:
            close_series = self._close_column(index_data[code]).dropna()
            # Original-engine rule: at least n_days + 1 closes are required.
            if len(close_series) < self.n_days + 1:
                print(f" ⚠ 剔除 {code}: 数据不足 ({len(close_series)} < {self.n_days + 1})")
                continue
            factor_values[code] = self._factor.compute(pd.DataFrame({'close': close_series}))

        factor_df = pd.DataFrame(factor_values)

        # Drop codes whose factor series is mostly missing.
        null_pct = factor_df.isnull().mean()
        too_sparse = null_pct > 0.5
        for code, pct in null_pct[too_sparse].items():
            print(f" ⚠ 剔除 {code}: 缺失率 {pct:.1%} 过高")
        factor_df = factor_df.loc[:, ~too_sparse]

        # Publish the surviving codes, in their original order.
        data['valid_codes'] = [c for c in factor_values if c in factor_df.columns]
        return factor_df

    def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
        """Generate selection signals from the factor table."""
        return self._selector.generate(factor_df)

    def _build_returns(self, data: dict, valid_codes: list) -> pd.DataFrame:
        """
        Build the daily-return table with one '日收益率_<index code>' column per code.

        Sources, in order of preference (matching the original engine):
        1. ETF prices from ``data['etf_data']``, mapped via ``data['etf_code_map']``;
           codes without an ETF column are skipped.
        2. The aligned wide close table ``data['index_close']``.
        3. Per-instrument close series from ``data['index_data']``.
        """
        etf_data = data.get('etf_data')
        if etf_data is not None and not etf_data.empty:
            etf_code_map = data.get('etf_code_map') or {}
            returns_data = {}
            for idx_code in valid_codes:
                etf_code = etf_code_map.get(idx_code, idx_code)
                if etf_code in etf_data.columns:
                    # Column keeps the index-code name so it matches the signals.
                    returns_data[f'日收益率_{idx_code}'] = etf_data[etf_code].pct_change()
            return pd.DataFrame(returns_data)

        index_close = data.get('index_close')
        if index_close is not None and not index_close.empty:
            returns_df = index_close.pct_change()
            returns_df.columns = [f'日收益率_{col}' for col in returns_df.columns]
            return returns_df

        # Last resort: per-instrument closes. Each series keeps its own date
        # index, and pd.DataFrame aligns them. The old code overwrote the result
        # index with the first code's index, which mislabelled rows (or raised
        # a length mismatch) whenever the instruments' dates differed.
        index_data = data['index_data']
        return pd.DataFrame({
            f'日收益率_{code}': self._close_column(index_data[code]).pct_change()
            for code in valid_codes
            if code in index_data
        })

    def run_backtest(self, data: dict = None, save_path: str = None) -> dict:
        """
        Run the full backtest: data -> factors -> signals -> execution.

        Args:
            data: Optional pre-fetched data in the shape returned by get_data();
                fetched automatically when omitted. Its 'valid_codes' entry is
                replaced by compute_factors().
            save_path: Optional path prefix. When given, '<prefix>_nav.csv' and
                '<prefix>_signals.csv' are written.

        Returns:
            dict with keys 'signals', 'result', 'portfolio' and 'total_return'
            (percent). 'total_return' is None, and 'result' is None or empty,
            when the executor produced no usable backtest_result.
        """
        print("\n" + "=" * 60)
        print(" ETF轮动策略 回测系统")
        print("=" * 60)

        # 1. Data
        if data is None:
            data = self.get_data()
        print(f"\n候选标的: {len(data['valid_codes'])}")
        print(f"回测区间: {self.start_date} ~ {self.end_date}")

        # 2. Factors
        print("\n计算因子...")
        factor_df = self.compute_factors(data)
        print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)}")
        # compute_factors() prunes codes and stores the survivors in
        # data['valid_codes']. Read it *after* the call so returns are only
        # built for codes that can actually receive a signal.
        valid_codes = data['valid_codes']

        # 3. Signals
        print("\n生成信号...")
        signals = self.generate_signals(factor_df)
        print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)}")

        # 4. Execution
        print("\n执行回测...")
        returns_df = self._build_returns(data, valid_codes)

        # Align signals and returns on their common dates.
        common_dates = signals.index.intersection(returns_df.index)
        signals = signals.loc[common_dates]
        returns_df = returns_df.loc[common_dates]
        print(f" 对齐后日期: {len(common_dates)}")

        executor = BacktestExecutor(
            initial_capital=100000,
            trade_cost=self.trade_cost,
            select_num=self.select_num,
        )
        portfolio = executor.execute(signals, returns_df)

        # 5. Results
        result = getattr(portfolio, 'backtest_result', None)
        if result is None or len(result) == 0:
            # Same keys as the success path, so callers can index uniformly.
            return {
                'signals': signals,
                'result': result,
                'portfolio': portfolio,
                'total_return': None,
            }

        final_nav = result['策略净值'].iloc[-1]
        total_return = (final_nav - 1) * 100
        print("\n回测结果:")
        print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%")

        if save_path:
            result[['策略净值']].to_csv(f"{save_path}_nav.csv")
            signals.to_csv(f"{save_path}_signals.csv")
            print(f" 报告保存: {save_path}_*.csv")

        return {
            'signals': signals,
            'result': result,
            'portfolio': portfolio,
            'total_return': total_return,
        }

    # Implementations of the StrategyBase abstract hooks.
    def init_factors(self) -> FactorCombiner:
        """Return the factor combiner wrapping the single momentum factor."""
        return FactorCombiner([self._factor])

    def init_signal_generator(self) -> SignalGenerator:
        """Return the Top-N selector used as this strategy's signal generator."""
        return self._selector