Files
etf/strategies/rotation/strategy.py
aszerW 982fbe250b fix: 修复跨市场收益率计算Bug
Bug机制:
- 先pct_change再ffill对齐,导致海外标的休市日复制前一天非零收益率
- 例如:美股圣诞节放假,但ffill填入前一天的+1.2%,重复计算收益

修复方案:
- 先ffill价格对齐到A股日历(休市日价格不变)
- 再pct_change计算收益率(休市日自然为0%)

影响:
- 修复前净值: 394.80(高估)
- 修复后净值: 16.23(真实)
- 该Bug导致海外标的在A股独有的交易日被重复计算收益

验证:
- 语法检查通过
- 回测运行正常
2026-05-20 22:34:12 +08:00

488 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
轮动策略完整实现
整合数据获取、因子计算、信号生成、回测执行
"""
import pandas as pd
import yaml
from datetime import datetime
from pathlib import Path
# 加载环境变量
from dotenv import load_dotenv
load_dotenv()
from framework.factors import FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.execution import BacktestExecutor
from framework.risk import CallbackHook, Position
from framework.strategy import StrategyBase
# 导入定制组件
from strategies.shared.factors.momentum import MomentumFactor
from strategies.shared.signals.selectors import TopNSelector
class RotationStrategy(StrategyBase):
    """
    Complete ETF rotation strategy.

    Built on a momentum factor, Top-N selection and diversification
    across groups.

    Usage:
        from strategies.rotation.strategy import RotationStrategy
        strategy = RotationStrategy.from_yaml('strategies/rotation/config.yaml')
        result = strategy.run_backtest()
    """
    # Class-level defaults. _apply_config() replaces select_num, n_days,
    # rebalance_days, rebalance_threshold and trade_cost with values from
    # the configuration when the corresponding keys are present.
    name = "rotation"
    select_num = 3  # handed to TopNSelector and BacktestExecutor
    stoploss = -0.05  # NOTE(review): not read by any method visible in this file
    n_days = 25  # momentum window; also used for the "not enough data" warning
    rebalance_days = 1  # handed to TopNSelector
    rebalance_threshold = 0.0  # handed to TopNSelector
    trade_cost = 0.001  # handed to BacktestExecutor
def __init__(self, config: dict = None):
    """Initialise the strategy from an optional configuration mapping.

    Args:
        config: Strategy configuration (typically the parsed YAML file).
            ``None`` or an empty mapping falls back to the defaults.
    """
    # Always apply the configuration, even when it is empty: _apply_config
    # is the only place that assigns min_score, start_date and end_date,
    # and min_score is read a few lines below. Skipping it for a missing
    # config left those attributes undefined (unless a base class happens
    # to provide them) and made the constructor fail with AttributeError.
    config = config or {}
    self._apply_config(config)
    self.config = config

    # Factor setup: reset the registry, register the momentum factor and
    # build an instance with the configured window.
    FactorRegistry.clear()
    FactorRegistry.register(MomentumFactor)
    self._factor = FactorRegistry.get(
        'momentum',
        n_days=self.n_days,
        crash_filter=True
    )

    # Group mapping used by the selector for diversified selection.
    self._group_mapping = self._build_group_mapping()

    # Signal generator.
    self._selector = TopNSelector(
        select_num=self.select_num,
        group_mapping=self._group_mapping,
        min_score=self.min_score,  # read from config so the threshold is tunable
        rebalance_days=self.rebalance_days,
        rebalance_threshold=self.rebalance_threshold,
        bond_threshold_config=self.config.get('bond_threshold', {})  # V3 dynamic threshold config
    )
@classmethod
def from_yaml(cls, config_path: str) -> 'RotationStrategy':
    """Create a strategy instance from a YAML configuration file.

    Args:
        config_path: Path of the YAML file (read as UTF-8).

    Returns:
        A new instance built from the parsed configuration.

    Raises:
        ValueError: If the top level of the YAML document is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load yields None for an empty document; treat that as an
        # empty configuration instead of failing on `None.get` below.
        config = yaml.safe_load(f) or {}
    if not isinstance(config, dict):
        raise ValueError(f"配置文件顶层必须是映射(dict): {config_path}")
    # Default the end date to today when it is missing or empty.
    if not config.get('end_date'):
        config['end_date'] = datetime.now().strftime('%Y-%m-%d')
    return cls(config)
def _apply_config(self, config: dict) -> None:
    """Copy tunable parameters from ``config`` onto the instance.

    Keys that are absent keep the current attribute value for the five
    overridable parameters; min_score, start_date and end_date fall back
    to fixed defaults (0.0, '2019-01-01' and today's date).
    """
    overridable = (
        'select_num',
        'n_days',
        'rebalance_days',
        'rebalance_threshold',
        'trade_cost',
    )
    for key in overridable:
        current = getattr(self, key)
        setattr(self, key, config.get(key, current))

    # Minimum momentum score; the 0.0 default filters out negative momentum.
    self.min_score = config.get('min_score', 0.0)
    self.start_date = config.get('start_date', '2019-01-01')
    today = datetime.now().strftime('%Y-%m-%d')
    self.end_date = config.get('end_date', today)
def _build_group_mapping(self) -> dict:
    """Map every configured code to its group for diversified selection.

    Only ``code_list`` entries whose value is a dict are included; the
    group is the entry's ``market`` value, or 'default' when it is absent.
    """
    entries = self.config.get('code_list', {})
    return {
        code: settings.get('market', 'default')
        for code, settings in entries.items()
        if isinstance(settings, dict)
    }
def get_data(self, use_flask_api: bool = True) -> dict:
    """
    Fetch the data needed by the strategy.

    Args:
        use_flask_api: When True (default) the data comes from the Flask
            API service; when False the local HybridDataSource is used.

    Returns:
        The data dict produced by the selected backend.

    Raises:
        ValueError: If the configuration has no ``code_list``.
    """
    settings = self.config
    universe = settings.get('code_list', {})
    bench_code = settings.get('benchmark', {}).get('code', '000300.SH')
    if not universe:
        raise ValueError("配置中未找到 code_list")

    # The API URL is only taken from the config when the service is enabled.
    api_settings = settings.get('flask_api', {})
    api_url = None
    if api_settings.get('enabled'):
        api_url = api_settings.get('url')

    if not use_flask_api:
        # Local HybridDataSource path.
        return self._get_data_from_local(universe, bench_code)
    # Remote call through the Flask API service.
    return self._get_data_from_flask_api(universe, bench_code, api_url)
def _get_data_from_flask_api(
    self,
    code_list_config: dict,
    benchmark_code: str,
    flask_api_url: str = None
) -> dict:
    """Fetch index, ETF and benchmark data through the Flask API service.

    Args:
        code_list_config: Mapping of index code to its settings; a dict
            entry with an ``etf`` value links the index to an ETF code.
        benchmark_code: Code of the benchmark to fetch.
        flask_api_url: Base URL handed to FlaskAPIDataSource (may be None).

    Returns:
        A dict with the keys 'index_data', 'index_close', 'etf_data',
        'etf_nav_data', 'etf_premium_data', 'benchmark_data',
        'valid_codes' and 'etf_code_map'.
    """
    from datasource.flask_api_source import FlaskAPIDataSource
    # Initialise the Flask API data source
    api_source = FlaskAPIDataSource(base_url=flask_api_url)
    # Check the service status; an unhealthy status is only printed,
    # the fetch continues either way
    health = api_source.get_health()
    if health.get('status') != 'healthy':
        print(f"⚠ Flask API 服务状态: {health}")
    else:
        print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})")
    # Print the configured backtest window
    print(f"\n回测配置区间: {self.start_date} ~ {self.end_date}")
    print("注: 各标的实际数据范围可能因上市时间/数据源限制而不同")
    # Index codes to fetch
    index_codes = list(code_list_config.keys())
    # Index code -> ETF code mapping, plus the flat list of ETF codes
    etf_code_map = {}
    etf_codes = []
    for index_code, cfg in code_list_config.items():
        if isinstance(cfg, dict) and cfg.get('etf'):
            etf_code_map[index_code] = cfg['etf']
            etf_codes.append(cfg['etf'])
    # Fetch index OHLCV data
    print(f"\n获取指数数据 ({len(index_codes)} 只)...")
    index_ohlcv_data = api_source.fetch_batch(
        index_codes,
        self.start_date,
        self.end_date
    )
    # Keep only codes that came back with a non-empty frame
    valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0]
    print(f"有效指数: {len(valid_codes)}")
    # Fetch ETF price data (NAV and premium data arrive attached to the frames)
    print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...")
    etf_ohlcv_data = api_source.fetch_batch(
        etf_codes,
        self.start_date,
        self.end_date
    )
    # Convert to a wide DataFrame and extract NAV / premium data
    etf_data = None
    etf_nav_data = {}
    etf_premium_data = {}
    if etf_ohlcv_data:
        etf_close_dict = {}
        for etf_code, df in etf_ohlcv_data.items():
            if df is not None and 'close' in df.columns:
                etf_close_dict[etf_code] = df['close']
                # NAV and premium data are read from DataFrame.attrs,
                # where the Flask API source is expected to have put them
                if 'nav' in df.attrs:
                    etf_nav_data[etf_code] = df.attrs['nav']
                if 'premium_series' in df.attrs:
                    etf_premium_data[etf_code] = {
                        'series': df.attrs['premium_series'],
                        'latest': df.attrs.get('latest_premium'),
                        'date': df.attrs.get('premium_date'),
                        'stats': df.attrs.get('premium_stats'),
                    }
        if etf_close_dict:
            etf_data = pd.DataFrame(etf_close_dict)
        # NOTE(review): the original indentation was lost; these two summary
        # prints are assumed to sit inside the `if etf_ohlcv_data:` block.
        print(f"有效净值: {len(etf_nav_data)}")
        print(f"有效溢价率: {len(etf_premium_data)}")
    # Fetch benchmark data
    print(f"\n获取基准数据 ({benchmark_code})...")
    benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date)
    benchmark_data = None
    if benchmark_ohlcv is not None:
        benchmark_data = benchmark_ohlcv['close']
    # Build the wide index close-price DataFrame used for factor computation
    index_close_dict = {}
    for code in valid_codes:
        df = index_ohlcv_data.get(code)
        if df is not None and 'close' in df.columns:
            index_close_dict[code] = df['close']
    index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
    return {
        'index_data': index_ohlcv_data,  # raw OHLCV data {code: DataFrame}
        'index_close': index_close,  # wide close-price frame, one column per valid code
        'etf_data': etf_data,  # ETF close prices (wide format)
        'etf_nav_data': etf_nav_data,  # ETF NAV data {code: value of df.attrs['nav']}
        'etf_premium_data': etf_premium_data,  # ETF premium data {code: dict}
        'benchmark_data': benchmark_data,  # benchmark close-price Series
        'valid_codes': valid_codes,  # list of valid index codes
        'etf_code_map': etf_code_map  # {index code: ETF code} mapping
    }
def _get_data_from_local(
    self,
    code_list_config: dict,
    benchmark_code: str
) -> dict:
    """Fetch data with the local HybridDataSource.

    Args:
        code_list_config: Mapping of index code to its settings.
        benchmark_code: Code of the benchmark to fetch.

    Returns:
        A dict with the same keys as the Flask API path: 'index_data',
        'index_close', 'etf_data', 'etf_nav_data', 'etf_premium_data',
        'benchmark_data', 'valid_codes' and 'etf_code_map'.
    """
    from datasource import HybridDataSource

    ssh_config = self.config.get('ssh_tunnel', {})
    data_source = HybridDataSource(
        ssh_config=ssh_config,
        use_cache=self.config.get('use_cache', True)
    )
    (
        index_data,
        etf_data,
        etf_nav_data,
        benchmark_data,
        valid_codes,
        index_ohlcv_data,
        etf_code_map,
    ) = data_source.fetch_all(
        code_config=code_list_config,
        benchmark_code=benchmark_code,
        start_date=self.start_date,
        end_date=self.end_date
    )
    return {
        'index_data': index_ohlcv_data,  # raw OHLCV data
        'index_close': index_data,  # close prices in wide format
        'etf_data': etf_data,
        'etf_nav_data': etf_nav_data,
        # fetch_all returns no premium data; an empty mapping keeps the
        # result schema identical to the Flask API path so consumers can
        # read this key regardless of the data source.
        'etf_premium_data': {},
        'benchmark_data': benchmark_data,
        'valid_codes': valid_codes,
        'etf_code_map': etf_code_map  # {index code: ETF code} mapping
    }
def compute_factors(self, data: dict) -> pd.DataFrame:
    """Compute the factor per code, then align it to the reference calendar.

    The factor is computed on each instrument's own trading days (so the
    rolling window never sees forward-filled duplicates) and only the
    resulting factor series is forward-filled onto the reference dates.
    Instruments with too little data are kept, not dropped, so that the
    problem stays visible downstream.

    Args:
        data: Data dict with 'index_data' ({code: OHLCV DataFrame}),
            'valid_codes' and optionally 'index_close'. Its 'valid_codes'
            entry is overwritten with the list of processed codes.

    Returns:
        DataFrame of factor values, one column per code, indexed by the
        reference dates.
    """
    index_data = data['index_data']
    valid_codes = data['valid_codes']

    # Reference calendar: the index of the aligned close frame if present,
    # otherwise the first A-share code, otherwise simply the first code.
    index_close = data.get('index_close')
    if index_close is not None:
        a_share_dates = index_close.index
    else:
        for code in valid_codes:
            if code.endswith(('.SH', '.SZ', '.CSI')):
                a_share_dates = index_data[code].index
                break
        else:
            a_share_dates = index_data[valid_codes[0]].index

    ohlcv_cols = ['open', 'high', 'low', 'close', 'volume']
    required_cols = ['open', 'high', 'low', 'close']
    min_rows = self.n_days + 1

    factor_values = {}
    final_valid_codes = []
    for code in valid_codes:
        df = index_data[code].copy()

        # Full OHLC present, each column with at least one non-null value?
        cols_exist = all(col in df.columns for col in required_cols)
        cols_have_data = cols_exist and all(
            df[col].notna().any() for col in required_cols
        )

        if cols_have_data:
            # Drop incomplete rows, then take close. Only 'volume' can be
            # missing here; selecting it unconditionally raised KeyError
            # for frames that carry OHLC but no volume column.
            # NOTE(review): a volume column that exists but is entirely
            # NaN still removes every row, as before.
            present_cols = [col for col in ohlcv_cols if col in df.columns]
            close_series = df[present_cols].dropna()['close']
        elif 'close' in df.columns and df['close'].notna().any():
            # Only close has data (e.g. bond indices).
            close_series = df['close'].dropna()
        else:
            # No usable data at all.
            close_series = pd.Series(dtype=float)

        # Warn about short histories but keep the instrument.
        if len(close_series) < min_rows:
            print(f"{code}: 数据不足 ({len(close_series)} < {min_rows})保留但因子值可能为NaN")

        if len(close_series) > 0:
            # Compute on the instrument's own calendar first ...
            close_df = pd.DataFrame({'close': close_series})
            factor_series = self._factor.compute(close_df)
            # ... then forward-fill the factor onto the reference calendar.
            factor_aligned = factor_series.reindex(a_share_dates, method='ffill')
        else:
            # No data: an all-NaN factor series on the reference calendar.
            factor_aligned = pd.Series(index=a_share_dates, dtype=float)

        factor_values[code] = factor_aligned
        final_valid_codes.append(code)

    factor_df = pd.DataFrame(factor_values)

    # Report (but keep) columns with a high share of missing values.
    for code in final_valid_codes:
        if code in factor_df.columns:
            null_pct = factor_df[code].isnull().mean()
            if null_pct > 0.5:
                print(f"{code}: 缺失率 {null_pct:.1%} 较高,保留但信号生成时可能跳过")

    # Every input code is retained; the list is written back unchanged in content.
    data['valid_codes'] = final_valid_codes
    return factor_df
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
    """Turn the factor table into signals by delegating to the selector."""
    selector = self._selector
    signals = selector.generate(factor_df)
    return signals
def run_backtest(self, data: dict = None, save_path: str = None) -> dict:
    """
    Run the full backtest: data -> factors -> signals -> execution.

    Args:
        data: Optional pre-fetched data dict; fetched with get_data()
            when omitted.
        save_path: Optional path prefix for the CSV reports
            (``<prefix>_nav.csv``, ``<prefix>_signals.csv`` and, when
            rebalances were recorded, ``<prefix>_rebalances.csv``).

    Returns:
        Dict with the keys 'signals', 'result', 'portfolio',
        'total_return' and 'rebalance_events'. When the portfolio exposes
        no ``backtest_result``, 'result' and 'total_return' are None and
        'rebalance_events' is an empty DataFrame.
    """
    print("\n" + "=" * 60)
    print(" ETF轮动策略 回测系统")
    print("=" * 60)

    # 1. Data
    if data is None:
        data = self.get_data()
    valid_codes = data['valid_codes']
    index_data = data['index_data']
    print(f"\n候选标的: {len(valid_codes)}")
    print(f"回测区间: {self.start_date} ~ {self.end_date}")

    # 2. Factors
    print("\n计算因子...")
    factor_df = self.compute_factors(data)
    print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)}")

    # 3. Signals
    print("\n生成信号...")
    signals = self.generate_signals(factor_df)
    print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)}")

    # 4. Execution
    print("\n执行回测...")
    # The signal index serves as the reference trading calendar.
    a_share_dates = signals.index

    # Daily returns: forward-fill PRICES onto the reference calendar first,
    # then take pct_change. On a day the instrument's own market is closed
    # the price is unchanged, so the return is 0%. Computing pct_change
    # first and forward-filling the returns would repeat the previous
    # day's non-zero return on such days.
    returns_data = {}
    for code in valid_codes:
        df = index_data.get(code)
        if df is None or 'close' not in df.columns:
            continue
        close_series = df['close'].dropna()
        close_aligned = close_series.reindex(a_share_dates, method='ffill')
        returns_data[f'日收益率_{code}'] = close_aligned.pct_change(fill_method=None)
    returns_df = pd.DataFrame(returns_data)

    # Restrict signals and returns to their common dates.
    common_dates = signals.index.intersection(returns_df.index)
    signals = signals.loc[common_dates]
    returns_df = returns_df.loc[common_dates]
    print(f" 对齐后日期: {len(common_dates)}")

    executor = BacktestExecutor(
        initial_capital=100000,
        trade_cost=self.trade_cost,
        select_num=self.select_num
    )
    portfolio = executor.execute(signals, returns_df)

    # 5. Results
    if not hasattr(portfolio, 'backtest_result'):
        # Same key set as the success path so callers can index the result
        # without first checking which branch produced it.
        return {
            'signals': signals,
            'result': None,
            'portfolio': portfolio,
            'total_return': None,
            'rebalance_events': pd.DataFrame()
        }

    result = portfolio.backtest_result
    final_nav = result['策略净值'].iloc[-1]
    total_return = (final_nav - 1) * 100
    print("\n回测结果:")
    print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%")

    # Rebalance events; a missing or None attribute counts as "no events".
    rebalance_events = getattr(portfolio, 'rebalance_events', None)
    if rebalance_events is None:
        rebalance_events = pd.DataFrame()
    if not rebalance_events.empty:
        print(f" 调仓次数: {len(rebalance_events)}")

    # Reports
    if save_path:
        # Create the target directory up front so a finished backtest is
        # not lost to a missing folder when writing the CSV files.
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        result[['策略净值']].to_csv(f"{save_path}_nav.csv")
        signals.to_csv(f"{save_path}_signals.csv")
        if not rebalance_events.empty:
            rebalance_events.to_csv(f"{save_path}_rebalances.csv")
            print(f" 报告保存: {save_path}_*.csv (含调仓记录)")
        else:
            print(f" 报告保存: {save_path}_*.csv")

    return {
        'signals': signals,
        'result': result,
        'portfolio': portfolio,
        'total_return': total_return,
        'rebalance_events': rebalance_events
    }
# 保留抽象方法实现
def init_factors(self) -> FactorCombiner:
    """Return a FactorCombiner wrapping the strategy's single factor."""
    factors = [self._factor]
    combiner = FactorCombiner(factors)
    return combiner
def init_signal_generator(self) -> SignalGenerator:
    """Return the selector built in the constructor as the signal generator."""
    selector = self._selector
    return selector