Files
etf/strategies/rotation/engine.py
aszerW 9ea84f0e57 feat(rotation): 支持混合数据源并优化因子计算和策略逻辑
- 删除旧的Tushare Token环境变量函数,简化配置
- 在配置文件中新增全市场指数及SSH隧道配置支持YFinance数据访问
- 更新compute_factors函数,支持长格式混合数据源,兼容旧宽格式数据
- 修改RotationStrategy使用HybridDataSource,支持Tushare与YFinance数据源混合
- 添加SSH隧道支持,实现安全访问非主市场数据
- 优化因子计算逻辑,提升缺失值处理和因子合并的鲁棒性
- 修正基准净值计算,兼容长宽格式基准数据处理
- 增强信号生成逻辑,处理因子得分中的NaN情况防止异常
2026-03-19 20:38:13 +08:00

282 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
ETF轮动策略引擎
整合信号生成和回测逻辑
使用 YFinance 数据源(支持 SSH 隧道)
"""
import pandas as pd
import numpy as np
from typing import Optional
from strategies.base import BacktestStrategy
from core.data.hybrid_source import HybridDataSource
from core.factors.momentum import compute_factors, calculate_daily_return
class RotationStrategy(BacktestStrategy):
"""ETF轮动策略"""
def __init__(self, config: dict):
super().__init__("ETF轮动策略", config)
# 初始化混合数据源
ssh_config = config.get("ssh_tunnel", {})
self.data_source = HybridDataSource(
ssh_config=ssh_config,
use_cache=config.get("use_cache", True)
)
print(f"使用混合数据源: Tushare(中国A股) + YFinance(港股/美股/加密货币)")
print(f"SSH隧道: {ssh_config.get('enabled', False)}")
self.data = None
self.signals = None
self.backtest_result = None
def fetch_data(self) -> pd.DataFrame:
"""获取数据"""
from config.settings import DEFAULT_BENCHMARK_CODE
# 从配置中读取基准代码,或使用默认值
benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE)
# 使用上下文管理器管理 SSH 隧道(如果是 YFinance 数据源)
with self.data_source:
etf_data, benchmark_data, valid_codes = self.data_source.fetch_all(
self.config["code_list"],
benchmark_code,
self.config["start_date"],
self.config["end_date"],
)
self.etf_data = etf_data
self.benchmark_data = benchmark_data
self.valid_codes = valid_codes
# 计算因子
factor_data, valid_codes = compute_factors(
etf_data,
valid_codes,
n=self.config["n_days"],
factor_type=self.config["factor_type"],
)
self.data = factor_data
self.valid_codes = valid_codes
return factor_data
def generate_signals(self) -> pd.DataFrame:
"""生成轮动信号"""
if self.data is None:
self.fetch_data()
result = self.data.copy()
score_cols = [f"得分_{code}" for code in self.valid_codes]
select_num = self.config["select_num"]
rebalance_days = self.config["rebalance_days"]
rebalance_threshold = self.config["rebalance_threshold"]
# Step 1: 每日目标组合
if not score_cols:
raise ValueError("没有有效的指数代码,无法生成信号")
if select_num == 1:
daily_target = (
result[score_cols]
.idxmax(axis=1)
.str.replace("得分_", "", regex=False)
)
else:
def top_n_codes(row):
scores = pd.to_numeric(row[score_cols], errors="coerce")
# 过滤掉 NaN 值
scores = scores.dropna()
if len(scores) == 0:
return ""
top = scores.nlargest(min(select_num, len(scores))).index.tolist()
return ",".join([c.replace("得分_", "") for c in top])
daily_target = result.apply(top_n_codes, axis=1)
# Step 2: 逐日生成信号(调仓周期控制)
held_signals = []
current_held = None
last_rebalance_idx = 0
for i in range(len(result)):
target = daily_target.iloc[i]
if current_held is None:
current_held = target
last_rebalance_idx = i
held_signals.append(current_held)
continue
days_since = i - last_rebalance_idx
if days_since >= rebalance_days:
should = self._check_rebalance(
result.iloc[i], current_held, target,
select_num, rebalance_threshold
)
if should:
current_held = target
last_rebalance_idx = i
held_signals.append(current_held)
result["信号_raw"] = held_signals
result["信号"] = result["信号_raw"].shift(1)
result = result.drop(columns=["信号_raw"])
result = result.dropna(subset=["信号"])
self.signals = result
self._print_signal_stats(result, select_num, rebalance_days, rebalance_threshold)
return result
def _check_rebalance(self, row, current_held, target, select_num, threshold):
"""检查是否应该调仓"""
if select_num == 1:
if target == current_held:
return False
new_score = float(row[f"得分_{target}"])
old_score = float(row[f"得分_{current_held}"])
if old_score > 0:
return (new_score / old_score - 1) >= threshold
return new_score > 0
else:
new_codes = target.split(",")
old_codes = current_held.split(",")
if set(new_codes) == set(old_codes):
return False
new_total = sum(float(row[f"得分_{c}"]) for c in new_codes)
old_total = sum(float(row[f"得分_{c}"]) for c in old_codes)
if old_total > 0:
return (new_total / old_total - 1) >= threshold
return new_total > 0
def _print_signal_stats(self, result, select_num, rebalance_days, rebalance_threshold):
"""打印信号统计"""
total_days = len(result)
if select_num == 1:
rebalance_count = (result["信号"] != result["信号"].shift(1)).sum() - 1
else:
prev = None
rebalance_count = 0
for s in result["信号"]:
if prev is not None and s != prev:
if set(s.split(",")) != set(prev.split(",")):
rebalance_count += 1
prev = s
rebalance_count = max(rebalance_count, 0)
avg_hold = total_days / max(rebalance_count, 1)
years = total_days / 252
annual_rebalances = rebalance_count / max(years, 0.1)
print(f"\n信号生成完成:")
print(f" 调仓周期: {rebalance_days} 天 | 阈值: {rebalance_threshold:.1%}")
print(f" 交易天数: {total_days}")
print(f" 调仓次数: {rebalance_count} | 平均持仓: {avg_hold:.1f} 天 | 年均调仓: {annual_rebalances:.1f}")
if select_num == 1:
signal_counts = result["信号"].value_counts()
print(f" 品种持仓分布 (前10):")
for code, count in signal_counts.head(10).items():
pct = count / total_days * 100
print(f" {code}: {count}天 ({pct:.1f}%)")
def run_backtest(self) -> pd.DataFrame:
"""执行回测"""
if self.signals is None:
self.generate_signals()
result = self.signals.copy()
select_num = self.config["select_num"]
trade_cost = self.config["trade_cost"]
# 计算策略日收益率
if select_num == 1:
def calc_return(row):
return row[f"日收益率_{row['信号']}"]
result["轮动策略日收益率"] = result.apply(calc_return, axis=1)
else:
def calc_multi_return(row):
codes = row["信号"].split(",")
returns = [row[f"日收益率_{c}"] for c in codes]
return np.mean(returns)
result["轮动策略日收益率"] = result.apply(calc_multi_return, axis=1)
# 扣除交易成本
if trade_cost > 0:
prev_signal = result["信号"].shift(1)
if select_num == 1:
changed = (result["信号"] != prev_signal) & prev_signal.notna()
result.loc[changed, "轮动策略日收益率"] -= trade_cost
else:
turnover_list = []
for curr, prev in zip(result["信号"], prev_signal):
if pd.isna(prev) or curr == prev:
turnover_list.append(0.0)
else:
old = set(prev.split(","))
new = set(curr.split(","))
swapped = len(old - new)
turnover_list.append(swapped / len(old))
result["换手率"] = turnover_list
result["轮动策略日收益率"] -= result["换手率"] * trade_cost
# 计算净值
result["轮动策略净值"] = (1 + result["轮动策略日收益率"]).cumprod()
# 各ETF单独净值
for code in self.valid_codes:
first_price = result[code].iloc[0]
result[f"净值_{code}"] = result[code] / first_price
# 基准净值
# benchmark_data 是 DataFrame需要提取 close 列
if isinstance(self.benchmark_data, pd.DataFrame):
if 'close' in self.benchmark_data.columns:
bench_close = self.benchmark_data['close']
else:
# 宽格式数据
bench_close = self.benchmark_data.iloc[:, 0]
else:
bench_close = self.benchmark_data
bench_ret = bench_close.pct_change().dropna()
common_dates = result.index.intersection(bench_ret.index)
bench_ret = bench_ret.loc[common_dates]
result["基准日收益率"] = bench_ret.reindex(result.index, fill_value=0)
result["基准净值"] = (1 + result["基准日收益率"]).cumprod()
self.backtest_result = result
# 打印摘要
total_days = len(result)
strategy_total_return = result["轮动策略净值"].iloc[-1] - 1
benchmark_total_return = result["基准净值"].iloc[-1] - 1
print(f"\n回测完成:")
print(f" 回测区间: {result.index.min().date()} ~ {result.index.max().date()}")
print(f" 交易天数: {total_days}")
print(f" 策略累计收益: {strategy_total_return:.2%}")
print(f" 基准累计收益: {benchmark_total_return:.2%}")
return result
def run(self) -> dict:
"""运行完整流程"""
self.fetch_data()
self.generate_signals()
self.run_backtest()
return self.backtest_result
def get_signals(self) -> pd.DataFrame:
"""获取当前信号"""
if self.signals is None:
self.generate_signals()
return self.signals