Files
etf/strategies/rotation/engine.py
aszerW 70bb69fd98 fix(core): 修复计算与数据对齐等多处逻辑问题
- 修正CAGR计算,去除NaN并检查起始值有效性以避免异常结果
- 优化混合数据源的数据对齐逻辑,使用配置结束日期与A股最新数据日期的较早者
- 计算因子时对齐A股交易日历,重新基于对齐价格计算日收益率,改进因子对齐准确度
- 轮动策略中跳过空信号,避免空信号影响持仓和调仓逻辑
- 调整信号处理,过滤空字符串和NaN,保证轮动信号数据有效性
- 多品种轮动持仓中加入空信号判断,避免无效信号导致错误
- 调整调仓明细和品种汇总保存逻辑,增加空文件创建以保证输出路径文件稳定生成
- 完善多处打印信息和注释,增强代码可读性与调试便利性
2026-03-26 22:21:38 +08:00

306 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
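ETF rotation strategy engine.

Combines signal generation and backtest logic.
Uses a hybrid data source: Tushare for China A-shares and YFinance for
Hong Kong/US stocks and crypto (an SSH tunnel is supported).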
"""
import pandas as pd
import numpy as np
from typing import Optional
from strategies.base import BacktestStrategy
from core.datasource.hybrid_source import HybridDataSource
from core.factors.momentum import compute_factors, calculate_daily_return
class RotationStrategy(BacktestStrategy):
    """ETF rotation strategy.

    Each row of the factor table is ranked by its ``得分_<code>`` score
    columns; the top ``select_num`` codes form the target holding. The held
    signal is only re-evaluated once ``rebalance_days`` rows have passed since
    the last switch, and only replaced when the score improvement reaches
    ``rebalance_threshold``. The backtest then applies the previous row's
    signal to the ``日收益率_<code>`` return columns.
    """

    def __init__(self, config: dict):
        """Create the strategy and its hybrid data source.

        Args:
            config: Strategy configuration. Keys read in this class:
                ``ssh_tunnel``, ``use_cache``, ``benchmark``, ``code_list``,
                ``start_date``, ``end_date``, ``n_days``, ``factor_type``,
                ``select_num``, ``rebalance_days``, ``rebalance_threshold``
                and ``trade_cost``.
        """
        super().__init__("ETF轮动策略", config)

        # Hybrid data source; the SSH tunnel settings are passed straight through.
        ssh_config = config.get("ssh_tunnel", {})
        self.data_source = HybridDataSource(
            ssh_config=ssh_config,
            use_cache=config.get("use_cache", True),
        )
        print("使用混合数据源: Tushare(中国A股) + YFinance(港股/美股/加密货币)")
        print(f"SSH隧道: {ssh_config.get('enabled', False)}")

        # Filled in by fetch_data(); declared here so the attributes always exist.
        self.index_data = None
        self.etf_data = None
        self.etf_nav_data = None
        self.benchmark_data = None
        self.valid_codes = None
        self.code_config = None

        # Pipeline outputs: factor table, signal table, backtest table.
        self.data = None
        self.signals = None
        self.backtest_result = None

    def fetch_data(self) -> pd.DataFrame:
        """Fetch raw data and compute the factor table.

        Stores the five values returned by ``data_source.fetch_all`` on the
        instance, then calls ``compute_factors`` and keeps its result in
        ``self.data`` and ``self.valid_codes``.

        Returns:
            The factor table returned by ``compute_factors``.
        """
        # Imported locally, as in the original code.
        from config.settings import DEFAULT_BENCHMARK_CODE

        # Benchmark code comes from the config, falling back to the default.
        benchmark_code = self.config.get("benchmark", {}).get(
            "code", DEFAULT_BENCHMARK_CODE
        )
        # Per-code configuration (original comment: contains name, etf, market).
        code_config = self.config.get("code_list", {})

        # The data source is a context manager (original comment: it manages
        # the SSH tunnel), so the fetch happens inside the ``with`` block.
        with self.data_source:
            (
                index_data,
                etf_data,
                etf_nav_data,
                benchmark_data,
                valid_codes,
            ) = self.data_source.fetch_all(
                code_config,
                benchmark_code,
                self.config["start_date"],
                self.config["end_date"],
            )

        self.index_data = index_data        # index data (used for factors)
        self.etf_data = etf_data            # ETF prices (used for returns)
        self.etf_nav_data = etf_nav_data    # ETF NAV (original comment: premium rate)
        self.benchmark_data = benchmark_data
        self.valid_codes = valid_codes
        self.code_config = code_config      # used to tell market types apart

        # Factors are computed from index data; ETF data is passed for returns.
        factor_data, valid_codes = compute_factors(
            index_data,
            valid_codes,
            n=self.config["n_days"],
            factor_type=self.config["factor_type"],
            etf_data=etf_data,
            code_config=code_config,
        )
        self.data = factor_data
        self.valid_codes = valid_codes
        return factor_data

    @staticmethod
    def _is_valid_signal(signal) -> bool:
        """Return True only for a non-empty string signal (rejects None/NaN/"")."""
        return isinstance(signal, str) and signal != ""

    @staticmethod
    def _daily_targets(frame: pd.DataFrame, score_cols: list, select_num: int) -> pd.Series:
        """Build the per-row target holding from the score columns.

        Rows without any valid score get an empty string, so callers can skip
        them with a simple validity check.

        Returns:
            A Series aligned to ``frame.index``: a single code when
            ``select_num == 1``, otherwise comma-joined codes.
        """
        scores = frame[score_cols].apply(pd.to_numeric, errors="coerce")

        if select_num == 1:
            # idxmax must not see all-NaN rows: it yields NaN there (and is
            # deprecated for that case), and NaN would pass as a truthy signal.
            has_score = scores.notna().any(axis=1).to_numpy()
            targets = np.full(len(frame), "", dtype=object)
            if has_score.any():
                best = scores.loc[has_score].idxmax(axis=1)
                targets[has_score] = best.str.replace(
                    "得分_", "", regex=False
                ).to_numpy()
            return pd.Series(targets, index=frame.index)

        def top_codes(row):
            valid = row.dropna()
            if valid.empty:
                return ""
            best = valid.nlargest(min(select_num, len(valid))).index
            return ",".join(col.replace("得分_", "") for col in best)

        # Built from a list so the result is a Series even for an empty frame.
        targets = [top_codes(row) for _, row in scores.iterrows()]
        return pd.Series(targets, index=frame.index, dtype=object)

    def generate_signals(self) -> pd.DataFrame:
        """Generate the rotation signal for every row of the factor table.

        Fetches data first if ``self.data`` is unset. The resulting ``信号``
        column is the held signal shifted by one row, so each row carries the
        signal decided on the previous row. Rows without a signal are dropped.

        Returns:
            The factor table with the ``信号`` column added.

        Raises:
            ValueError: If there are no valid codes to score.
        """
        if self.data is None:
            self.fetch_data()

        result = self.data.copy()
        score_cols = [f"得分_{code}" for code in self.valid_codes]
        select_num = self.config["select_num"]
        rebalance_days = self.config["rebalance_days"]
        rebalance_threshold = self.config["rebalance_threshold"]

        if not score_cols:
            raise ValueError("没有有效的指数代码,无法生成信号")

        # Step 1: target holding for every row.
        daily_target = self._daily_targets(result, score_cols, select_num)

        # Step 2: walk the rows and apply the rebalance-period control.
        held_signals = []
        current_held = None
        last_rebalance_idx = 0
        for i, target in enumerate(daily_target):
            has_target = self._is_valid_signal(target)

            if current_held is None:
                # Nothing held yet: wait for the first valid target. None is
                # appended as a placeholder to keep the list aligned with rows.
                if has_target:
                    current_held = target
                    last_rebalance_idx = i
                held_signals.append(current_held)
                continue

            # An empty target never triggers a rebalance.
            if has_target and i - last_rebalance_idx >= rebalance_days:
                if self._check_rebalance(
                    result.iloc[i], current_held, target,
                    select_num, rebalance_threshold,
                ):
                    current_held = target
                    last_rebalance_idx = i
            held_signals.append(current_held)

        held = pd.Series(held_signals, index=result.index, dtype=object)
        result["信号"] = held.shift(1)

        # Drop rows whose signal is missing or empty.
        result = result.dropna(subset=["信号"])
        result = result[result["信号"] != ""]

        self.signals = result
        self._print_signal_stats(result, select_num, rebalance_days, rebalance_threshold)
        return result

    def _check_rebalance(self, row, current_held, target, select_num, threshold):
        """Decide whether to switch from ``current_held`` to ``target``.

        Args:
            row: The current row of the factor table (holds ``得分_<code>``).
            current_held: Currently held signal (code or comma-joined codes).
            target: Candidate signal in the same format.
            select_num: Number of codes held at once.
            threshold: Minimum relative score improvement needed to switch.

        Returns:
            True when the holding should be replaced by ``target``.
        """
        if select_num == 1:
            if target == current_held:
                return False
            new_score = float(row[f"得分_{target}"])
            old_score = float(row[f"得分_{current_held}"])
        else:
            new_codes = [c for c in target.split(",") if c]
            old_codes = [c for c in current_held.split(",") if c]
            if not new_codes or not old_codes:
                return True  # one side is empty, so a rebalance is needed
            if set(new_codes) == set(old_codes):
                return False  # same basket, order does not matter
            new_score = sum(float(row.get(f"得分_{c}", 0)) for c in new_codes)
            old_score = sum(float(row.get(f"得分_{c}", 0)) for c in old_codes)

        # A relative improvement is only meaningful for a positive old score;
        # otherwise switch whenever the new score is positive.
        if old_score > 0:
            return (new_score / old_score - 1) >= threshold
        return new_score > 0

    def _print_signal_stats(self, result, select_num, rebalance_days, rebalance_threshold):
        """Print rebalance counts and, for single holdings, the code distribution."""
        total_days = len(result)

        if select_num == 1:
            # The first row always differs from its (missing) predecessor,
            # hence the "- 1".
            rebalance_count = (result["信号"] != result["信号"].shift(1)).sum() - 1
        else:
            prev = None
            rebalance_count = 0
            for s in result["信号"]:
                if prev is not None and s != prev:
                    # Only a change in basket membership counts, not order.
                    if set(s.split(",")) != set(prev.split(",")):
                        rebalance_count += 1
                prev = s

        rebalance_count = max(rebalance_count, 0)
        avg_hold = total_days / max(rebalance_count, 1)
        years = total_days / 252
        annual_rebalances = rebalance_count / max(years, 0.1)

        print(f"\n信号生成完成:")
        print(f" 调仓周期: {rebalance_days} 天 | 阈值: {rebalance_threshold:.1%}")
        print(f" 交易天数: {total_days}")
        print(f" 调仓次数: {rebalance_count} | 平均持仓: {avg_hold:.1f} 天 | 年均调仓: {annual_rebalances:.1f}")

        if select_num == 1:
            signal_counts = result["信号"].value_counts()
            print(f" 品种持仓分布 (前10):")
            for code, count in signal_counts.head(10).items():
                pct = count / total_days * 100
                print(f" {code}: {count}天 ({pct:.1f}%)")

    def run_backtest(self) -> pd.DataFrame:
        """Run the backtest on the generated signals.

        Adds strategy return and net-value columns, a net-value column per
        code, and benchmark return and net-value columns. Generates signals
        first if ``self.signals`` is unset.

        Returns:
            The signal table extended with the backtest columns.

        Raises:
            ValueError: If the signal table is empty.
        """
        if self.signals is None:
            self.generate_signals()

        result = self.signals.copy()
        if result.empty:
            # Fail with a clear message instead of an IndexError further down.
            raise ValueError("没有有效的轮动信号,无法执行回测")

        select_num = self.config["select_num"]
        trade_cost = self.config["trade_cost"]

        def asset_return(row, code):
            # A missing column or a NaN value both count as a zero return.
            value = row.get(f"日收益率_{code}", 0.0)
            return 0.0 if pd.isna(value) else float(value)

        # Strategy daily return.
        if select_num == 1:
            def calc_return(row):
                signal = row["信号"]
                if not self._is_valid_signal(signal):
                    return 0.0
                return asset_return(row, signal)

            result["轮动策略日收益率"] = result.apply(calc_return, axis=1)
        else:
            def calc_multi_return(row):
                codes = [c for c in row["信号"].split(",") if c]
                if not codes:
                    return 0.0
                # Equal-weighted mean over the held codes.
                return float(np.mean([asset_return(row, c) for c in codes]))

            result["轮动策略日收益率"] = result.apply(calc_multi_return, axis=1)

        # Deduct trading cost on rows where the signal changed.
        if trade_cost > 0:
            prev_signal = result["信号"].shift(1)
            if select_num == 1:
                changed = (result["信号"] != prev_signal) & prev_signal.notna()
                result.loc[changed, "轮动策略日收益率"] -= trade_cost
            else:
                turnover_list = []
                for curr, prev in zip(result["信号"], prev_signal):
                    if pd.isna(prev) or curr == prev:
                        turnover_list.append(0.0)
                    else:
                        old = set(prev.split(","))
                        new = set(curr.split(","))
                        # Share of the old basket that was swapped out.
                        turnover_list.append(len(old - new) / len(old))
                result["换手率"] = turnover_list
                result["轮动策略日收益率"] -= result["换手率"] * trade_cost

        # Strategy net value.
        result["轮动策略净值"] = (1 + result["轮动策略日收益率"]).cumprod()

        # Net value per code, normalised by its first valid price so a NaN in
        # the first row does not blank the whole column.
        for code in self.valid_codes:
            prices = result[code]
            valid_prices = prices.dropna()
            base = valid_prices.iloc[0] if not valid_prices.empty else np.nan
            result[f"净值_{code}"] = prices / base

        # Benchmark: use the 'close' column when present, otherwise the first
        # column of a wide frame; non-DataFrame input is used as is.
        if isinstance(self.benchmark_data, pd.DataFrame):
            if "close" in self.benchmark_data.columns:
                bench_close = self.benchmark_data["close"]
            else:
                bench_close = self.benchmark_data.iloc[:, 0]
        else:
            bench_close = self.benchmark_data

        bench_ret = bench_close.pct_change().dropna()
        # Rows with no benchmark return get 0.
        result["基准日收益率"] = bench_ret.reindex(result.index, fill_value=0)
        result["基准净值"] = (1 + result["基准日收益率"]).cumprod()

        self.backtest_result = result

        # Summary.
        total_days = len(result)
        strategy_total_return = result["轮动策略净值"].iloc[-1] - 1
        benchmark_total_return = result["基准净值"].iloc[-1] - 1
        print(f"\n回测完成:")
        print(f" 回测区间: {result.index.min().date()} ~ {result.index.max().date()}")
        print(f" 交易天数: {total_days}")
        print(f" 策略累计收益: {strategy_total_return:.2%}")
        print(f" 基准累计收益: {benchmark_total_return:.2%}")
        return result

    def run(self) -> pd.DataFrame:
        """Run the full pipeline: fetch data, generate signals, backtest.

        Returns:
            The backtest result table (same object as ``self.backtest_result``).
        """
        self.fetch_data()
        self.generate_signals()
        self.run_backtest()
        return self.backtest_result

    def get_signals(self) -> pd.DataFrame:
        """Return the signal table, generating it first if needed."""
        if self.signals is None:
            self.generate_signals()
        return self.signals