Files
etf/strategies/rotation/engine.py
aszerW 9ecc796d36 fix(engine): 修复净值计算中NaN值导致的缺失问题
问题分析:
1. HSTECH.HK数据从2020-08-26才开始,早期数据缺失
2. 原净值计算使用iloc[0]作为基准,若首日价格是NaN则整列净值变成NaN
3. 原日收益率计算中,某品种日收益率NaN会导致mean()返回NaN
4. 导致策略净值62个缺失值(3.53%)

修复内容:
1. strategies/rotation/engine.py - 各品种净值计算
   - 使用第一个有效价格作为基准(而非iloc[0])
   - 若品种无有效数据则净值列全部为NaN

2. strategies/rotation/engine.py - 策略日收益率计算
   - select_num=1时:若日收益率是NaN则返回0.0
   - select_num>1时:忽略NaN值计算mean()
   - 若所有品种日收益率都缺失则返回0.0

修复效果:
- 策略净值缺失:从62个(3.53%)降至0个(0%)
- HSTECH.HK净值缺失:从100%降至21.49%(377/1754)
- 其他品种净值:全部无缺失

说明:
- HSTECH.HK净值377个缺失是正常的(数据从2020-08-26才开始)
- 早期缺失日期(2019年)非HSTECH持仓期,缺失由其他原因导致
2026-05-08 22:52:36 +08:00

348 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
ETF轮动策略引擎
整合信号生成和回测逻辑
使用 YFinance 数据源(支持 SSH 隧道)
"""
import pandas as pd
import numpy as np
from typing import Optional
from strategies.base import BacktestStrategy
from core.datasource.hybrid_source import HybridDataSource
from core.factors.momentum import compute_factors, calculate_daily_return
class RotationStrategy(BacktestStrategy):
    """ETF rotation strategy.

    The pipeline is ``fetch_data`` -> ``generate_signals`` -> ``run_backtest``;
    ``run`` executes all three in order.

    Column conventions used by this class (all visible in the code below):

    * ``得分_{code}``     – factor score of ``code`` on a given row
    * ``日收益率_{code}`` – daily return of ``code``
    * ``{code}``          – price of ``code`` (used to build per-code NAV)
    * ``信号``            – held position; a single code, or several codes
      joined with ``","``
    """

    def __init__(self, config: dict):
        """Create the strategy and its hybrid data source.

        Args:
            config: Strategy configuration. Keys read here are ``ssh_tunnel``
                (optional dict) and ``use_cache`` (optional, default ``True``).
                Other keys are read later via ``self.config``.
                NOTE(review): ``self.config`` is presumably set by
                ``BacktestStrategy.__init__`` — confirm.
        """
        super().__init__("ETF轮动策略", config)

        # Build the hybrid data source.
        ssh_config = config.get("ssh_tunnel", {})
        self.data_source = HybridDataSource(
            ssh_config=ssh_config,
            use_cache=config.get("use_cache", True),
        )
        print("使用混合数据源: Tushare(中国A股) + YFinance(港股/美股/加密货币)")
        print(f"SSH隧道: {ssh_config.get('enabled', False)}")

        self.data = None
        self.signals = None
        self.backtest_result = None

    def fetch_data(self) -> pd.DataFrame:
        """Fetch index/ETF data and compute the factor table.

        Stores the raw inputs on the instance (``index_data``, ``etf_data``,
        ``etf_nav_data``, ``benchmark_data``, ``valid_codes``,
        ``code_config``) and the computed factor table in ``self.data``.

        Returns:
            The factor table returned by ``compute_factors``.
        """
        # Kept as a function-scope import, exactly as in the original code.
        from config.settings import DEFAULT_BENCHMARK_CODE

        # Benchmark code from the config, falling back to the project default.
        benchmark_code = self.config.get("benchmark", {}).get(
            "code", DEFAULT_BENCHMARK_CODE
        )
        # Per-code configuration (name, etf, market).
        code_config = self.config.get("code_list", {})

        # The data source is a context manager that manages the SSH tunnel.
        # NOTE(review): the original indentation was lost; only the fetch call
        # is assumed to need the tunnel — confirm.
        with self.data_source:
            (
                index_data,
                etf_data,
                etf_nav_data,
                benchmark_data,
                valid_codes,
                index_ohlcv_data,
            ) = self.data_source.fetch_all(
                code_config,
                benchmark_code,
                self.config["start_date"],
                self.config["end_date"],
            )

        # Keep the raw inputs and the configuration.
        self.index_data = index_data        # index data (used for factors)
        self.etf_data = etf_data            # ETF prices (used for returns)
        self.etf_nav_data = etf_nav_data    # ETF NAV (used for premium rate)
        self.benchmark_data = benchmark_data
        self.valid_codes = valid_codes
        self.code_config = code_config      # used to look up each code's market

        # Factors use the index data; returns use the ETF data.
        factor_data, valid_codes = compute_factors(
            index_data,
            valid_codes,
            n=self.config["n_days"],
            factor_type=self.config["factor_type"],
            etf_data=etf_data,
            code_config=code_config,
            index_ohlcv_data=index_ohlcv_data,
            auto_day=self.config.get("auto_day", False),
            min_days=self.config.get("min_days", 20),
            max_days=self.config.get("max_days", 60),
        )
        self.data = factor_data
        # compute_factors may narrow the list of usable codes.
        self.valid_codes = valid_codes
        return factor_data

    @staticmethod
    def _positive_scores(row, score_cols):
        """Return the row's scores that are numeric and strictly positive."""
        scores = pd.to_numeric(row[score_cols], errors="coerce").dropna()
        return scores[scores > 0]

    def _pick_target(self, row, score_cols, select_num, diversified):
        """Return the target holding for one row, or ``""`` if none qualifies.

        Only codes with a strictly positive score are eligible. In diversified
        mode at most one code (the best one) per ``market`` category is kept
        before the top ``select_num`` are taken.
        """
        scores = self._positive_scores(row, score_cols)
        if scores.empty:
            return ""

        if diversified:
            # market -> (code, score) of that market's best-scoring code
            best_by_market = {}
            for col_name, score in scores.items():
                code = col_name.replace("得分_", "")
                market = self.code_config.get(code, {}).get("market", "未知")
                if market not in best_by_market or score > best_by_market[market][1]:
                    best_by_market[market] = (code, score)
            # Rank the per-market winners by score, best first.
            ranked = sorted(
                best_by_market.values(), key=lambda item: item[1], reverse=True
            )
            return ",".join(code for code, _ in ranked[:select_num])

        if select_num == 1:
            return scores.idxmax().replace("得分_", "")

        top_cols = scores.nlargest(min(select_num, len(scores))).index
        return ",".join(col.replace("得分_", "") for col in top_cols)

    def generate_signals(self) -> pd.DataFrame:
        """Generate the rotation signals.

        For every row a target holding is chosen from the score columns, then
        a day-by-day pass applies the rebalance period and threshold. The
        resulting holding is shifted by one row, so the position decided on
        one row takes effect on the next. Rows without a signal are dropped.

        Returns:
            The factor table restricted to rows with a signal, with an added
            ``信号`` column. Also stored in ``self.signals``.

        Raises:
            ValueError: If there are no valid codes to score.
        """
        if self.data is None:
            self.fetch_data()

        result = self.data.copy()
        score_cols = [f"得分_{code}" for code in self.valid_codes]
        select_num = self.config["select_num"]
        rebalance_days = self.config["rebalance_days"]
        rebalance_threshold = self.config["rebalance_threshold"]

        if not score_cols:
            raise ValueError("没有有效的指数代码,无法生成信号")

        diversified = self.config.get("diversified", False)

        # Step 1: target holding for every row. Built by iterating rows so an
        # empty frame yields an empty list (DataFrame.apply would return a
        # DataFrame in that case).
        daily_targets = [
            self._pick_target(row, score_cols, select_num, diversified)
            for _, row in result.iterrows()
        ]

        # Step 2: day-by-day pass that enforces the rebalance period.
        held_signals = []
        current_held = None
        last_rebalance_idx = 0
        for i, target in enumerate(daily_targets):
            if current_held is None:
                # Nothing held yet: take the first non-empty target. Until
                # then None is appended as a placeholder to keep lengths equal.
                if target:
                    current_held = target
                    last_rebalance_idx = i
                held_signals.append(current_held)
                continue

            # An empty target never triggers a rebalance; the current holding
            # is simply carried forward.
            if target and i - last_rebalance_idx >= rebalance_days:
                if self._check_rebalance(
                    result.iloc[i], current_held, target,
                    select_num, rebalance_threshold,
                ):
                    current_held = target
                    last_rebalance_idx = i
            held_signals.append(current_held)

        result["信号_raw"] = held_signals
        result["信号"] = result["信号_raw"].shift(1)
        result = result.drop(columns=["信号_raw"])

        # Drop rows whose signal is NaN or an empty string.
        result = result.dropna(subset=["信号"])
        result = result[result["信号"] != ""]

        self.signals = result
        self._print_signal_stats(result, select_num, rebalance_days, rebalance_threshold)
        return result

    def _check_rebalance(self, row, current_held, target, select_num, threshold):
        """Decide whether to switch from ``current_held`` to ``target``.

        Args:
            row: The current row, providing the ``得分_{code}`` values.
            current_held: Currently held code(s), comma-joined when several.
            target: Candidate code(s) in the same format.
            select_num: Number of codes held; ``1`` selects the single-code path.
            threshold: Minimum relative score improvement required.

        Returns:
            ``True`` if the position should be switched.
        """
        if select_num == 1:
            if target == current_held:
                return False
            new_score = float(row[f"得分_{target}"])
            old_score = float(row[f"得分_{current_held}"])
        else:
            new_codes = [c for c in target.split(",") if c]
            old_codes = [c for c in current_held.split(",") if c]
            if not new_codes or not old_codes:
                return True  # one side is empty: a rebalance is required
            if set(new_codes) == set(old_codes):
                return False
            # Compare total scores; a missing score column counts as 0.
            new_score = sum(float(row.get(f"得分_{c}", 0)) for c in new_codes)
            old_score = sum(float(row.get(f"得分_{c}", 0)) for c in old_codes)

        if old_score > 0:
            return (new_score / old_score - 1) >= threshold
        # The old score is not positive (or is NaN): switch whenever the new
        # score is positive.
        return new_score > 0

    def _print_signal_stats(self, result, select_num, rebalance_days, rebalance_threshold):
        """Print a summary of the generated signals (counts and distribution)."""
        total_days = len(result)

        if select_num == 1:
            # The first row always differs from its (NaN) predecessor, so one
            # is subtracted.
            rebalance_count = (result["信号"] != result["信号"].shift(1)).sum() - 1
        else:
            rebalance_count = 0
            prev = None
            for signal in result["信号"]:
                # Only a change in the *set* of held codes counts.
                if (
                    prev is not None
                    and signal != prev
                    and set(signal.split(",")) != set(prev.split(","))
                ):
                    rebalance_count += 1
                prev = signal
        rebalance_count = max(rebalance_count, 0)

        avg_hold = total_days / max(rebalance_count, 1)
        years = total_days / 252
        annual_rebalances = rebalance_count / max(years, 0.1)

        print("\n信号生成完成:")
        print(f" 调仓周期: {rebalance_days} 天 | 阈值: {rebalance_threshold:.1%}")
        print(f" 交易天数: {total_days}")
        print(f" 调仓次数: {rebalance_count} | 平均持仓: {avg_hold:.1f} 天 | 年均调仓: {annual_rebalances:.1f}")

        if select_num == 1:
            signal_counts = result["信号"].value_counts()
            print(" 品种持仓分布 (前10):")
            for code, count in signal_counts.head(10).items():
                pct = count / total_days * 100
                print(f" {code}: {count}天 ({pct:.1f}%)")

    @staticmethod
    def _row_return(row, select_num):
        """Return the strategy's daily return for one row.

        Missing (NaN or absent) per-code returns are ignored; if nothing
        usable remains the return is ``0.0``.
        """
        signal = row["信号"]
        if select_num == 1:
            if not signal or pd.isna(signal):
                return 0.0
            ret = row.get(f"日收益率_{signal}", 0.0)
            return ret if pd.notna(ret) else 0.0

        codes = [c for c in signal.split(",") if c]
        if not codes:
            return 0.0
        # Equal-weight mean over the codes that have a return on this row.
        returns = [
            ret
            for ret in (row.get(f"日收益率_{c}", None) for c in codes)
            if pd.notna(ret)
        ]
        return np.mean(returns) if returns else 0.0

    @staticmethod
    def _turnover(prev, curr):
        """Return the fraction of previously held codes that were dropped."""
        if pd.isna(prev) or curr == prev:
            return 0.0
        old = set(prev.split(","))
        new = set(curr.split(","))
        return len(old - new) / len(old)

    def run_backtest(self) -> pd.DataFrame:
        """Run the backtest on the generated signals.

        Adds the strategy daily return and NAV, a per-code NAV for every
        valid code, and the benchmark daily return and NAV.

        Returns:
            The signal table with the backtest columns added. Also stored in
            ``self.backtest_result``.

        Raises:
            ValueError: If there are no signal rows to backtest.
        """
        if self.signals is None:
            self.generate_signals()

        # Fail fast with a clear message: on an empty frame the row-wise apply
        # and the ``.iloc[-1]`` summary below would fail with obscure errors.
        if self.signals.empty:
            raise ValueError("没有有效的信号,无法执行回测")

        result = self.signals.copy()
        select_num = self.config["select_num"]
        trade_cost = self.config["trade_cost"]

        # Strategy daily return, tolerant of NaN per-code returns.
        result["轮动策略日收益率"] = result.apply(
            self._row_return, axis=1, args=(select_num,)
        )

        # Deduct trading costs on rows where the holding changes.
        if trade_cost > 0:
            prev_signal = result["信号"].shift(1)
            if select_num == 1:
                changed = (result["信号"] != prev_signal) & prev_signal.notna()
                result.loc[changed, "轮动策略日收益率"] -= trade_cost
            else:
                result["换手率"] = [
                    self._turnover(prev, curr)
                    for curr, prev in zip(result["信号"], prev_signal)
                ]
                result["轮动策略日收益率"] -= result["换手率"] * trade_cost

        # Strategy NAV.
        result["轮动策略净值"] = (1 + result["轮动策略日收益率"]).cumprod()

        # Per-code NAV, based on the first valid (non-NaN) price so that a
        # late-starting series does not turn the whole column into NaN.
        for code in self.valid_codes:
            valid_prices = result[code].dropna()
            if valid_prices.empty:
                # No usable price at all: the NAV column is entirely NaN.
                result[f"净值_{code}"] = np.nan
            else:
                result[f"净值_{code}"] = result[code] / valid_prices.iloc[0]

        # Benchmark NAV. benchmark_data may be a DataFrame (a 'close' column,
        # or wide format whose first column is used) or already a Series.
        if isinstance(self.benchmark_data, pd.DataFrame):
            if "close" in self.benchmark_data.columns:
                bench_close = self.benchmark_data["close"]
            else:
                bench_close = self.benchmark_data.iloc[:, 0]
        else:
            bench_close = self.benchmark_data

        bench_ret = bench_close.pct_change().dropna()
        # Align to the backtest dates; dates without a benchmark return get 0.
        result["基准日收益率"] = bench_ret.reindex(result.index, fill_value=0)
        result["基准净值"] = (1 + result["基准日收益率"]).cumprod()

        self.backtest_result = result

        # Print a summary.
        total_days = len(result)
        strategy_total_return = result["轮动策略净值"].iloc[-1] - 1
        benchmark_total_return = result["基准净值"].iloc[-1] - 1
        print("\n回测完成:")
        print(f" 回测区间: {result.index.min().date()} ~ {result.index.max().date()}")
        print(f" 交易天数: {total_days}")
        print(f" 策略累计收益: {strategy_total_return:.2%}")
        print(f" 基准累计收益: {benchmark_total_return:.2%}")
        return result

    def run(self) -> pd.DataFrame:
        """Run the full pipeline and return the backtest result table."""
        self.fetch_data()
        self.generate_signals()
        self.run_backtest()
        return self.backtest_result

    def get_signals(self) -> pd.DataFrame:
        """Return the signal table, generating it first if necessary."""
        if self.signals is None:
            self.generate_signals()
        return self.signals