核心逻辑: 1. config.yaml新增bond_threshold配置块 2. selectors.py新增动态阈值逻辑: - _get_dynamic_threshold(): 阈值=短债动量×ratio - _grouped_selection(): BOND不参与竞争,空余仓位填充短债 3. strategy.py传入bond_threshold_config 回测验证: - 最终净值: 292.56 - 累计收益: 29155.96% - 持仓3只: 92.3%(满仓率提升) - 短债填充: 27.7%时间启用(空余仓位) 信号特征: - 短债可重复出现表示仓位占比 - 例如 "NDX,931862.CSI,931862.CSI" → NDX 33%, 短债 67%
487 lines
19 KiB
Python
487 lines
19 KiB
Python
"""
|
||
轮动策略完整实现
|
||
|
||
整合数据获取、因子计算、信号生成、回测执行
|
||
"""
|
||
|
||
import pandas as pd
|
||
import yaml
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
|
||
# 加载环境变量
|
||
from dotenv import load_dotenv
|
||
load_dotenv()
|
||
|
||
from framework.factors import FactorRegistry, FactorCombiner
|
||
from framework.signals import SignalGenerator
|
||
from framework.execution import BacktestExecutor
|
||
from framework.risk import CallbackHook, Position
|
||
from framework.strategy import StrategyBase
|
||
|
||
# 导入定制组件
|
||
from strategies.shared.factors.momentum import MomentumFactor
|
||
from strategies.shared.signals.selectors import TopNSelector
|
||
|
||
|
||
class RotationStrategy(StrategyBase):
    """
    ETF rotation strategy (complete implementation).

    Momentum factor + Top-N selection + group-based diversification.

    Usage:
        from strategies.rotation.strategy import RotationStrategy
        strategy = RotationStrategy.from_yaml('strategies/rotation/config.yaml')
        result = strategy.run_backtest()
    """

    name = "rotation"
    select_num = 3
    stoploss = -0.05
    n_days = 25
    rebalance_days = 1
    rebalance_threshold = 0.0
    trade_cost = 0.001
    # Defaults for attributes that used to be created only inside
    # ``_apply_config``; they keep the class usable with an empty/None config.
    min_score = 0.0          # minimum momentum score; 0.0 filters out negative momentum
    start_date = '2019-01-01'
    end_date = None          # None -> resolved to today's date in _apply_config

    def __init__(self, config: dict = None):
        """Initialize the strategy.

        Args:
            config: Parsed strategy configuration (typically loaded from YAML).
                ``None`` or ``{}`` falls back to the class-level defaults.
        """
        # Always apply the config, even an empty one. Previously an empty/None
        # config skipped _apply_config, leaving min_score/start_date/end_date
        # undefined, so the TopNSelector construction below raised AttributeError.
        self.config = config or {}
        self._apply_config(self.config)

        # Factor: register the momentum factor and build an instance of it
        FactorRegistry.clear()
        FactorRegistry.register(MomentumFactor)
        self._factor = FactorRegistry.get(
            'momentum',
            n_days=self.n_days,
            crash_filter=True,
        )

        # Group mapping (code -> market) used for diversified selection
        self._group_mapping = self._build_group_mapping()

        # Signal generator
        self._selector = TopNSelector(
            select_num=self.select_num,
            group_mapping=self._group_mapping,
            min_score=self.min_score,  # read from config so the threshold can be tuned
            rebalance_days=self.rebalance_days,
            rebalance_threshold=self.rebalance_threshold,
            # V3 dynamic bond-threshold config. A YAML key with an empty value
            # loads as None, so normalize it to an empty dict.
            bond_threshold_config=self.config.get('bond_threshold') or {},
        )

    @classmethod
    def from_yaml(cls, config_path: str) -> 'RotationStrategy':
        """Create a strategy instance from a YAML config file.

        A missing or empty ``end_date`` is filled with today's date (YYYY-MM-DD).
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file; treat that as an empty config
            config = yaml.safe_load(f) or {}

        if not config.get('end_date'):
            config['end_date'] = datetime.now().strftime('%Y-%m-%d')

        return cls(config)

    def _apply_config(self, config: dict) -> None:
        """Copy tunable parameters from ``config`` onto the instance.

        Keys that are absent keep the current (class-level) values.
        """
        self.select_num = config.get('select_num', self.select_num)
        self.n_days = config.get('n_days', self.n_days)
        self.rebalance_days = config.get('rebalance_days', self.rebalance_days)
        self.rebalance_threshold = config.get('rebalance_threshold', self.rebalance_threshold)
        self.trade_cost = config.get('trade_cost', self.trade_cost)
        self.min_score = config.get('min_score', self.min_score)
        # ``or`` so that an explicit null in YAML also falls back to the default
        self.start_date = config.get('start_date') or self.start_date
        self.end_date = (
            config.get('end_date')
            or self.end_date
            or datetime.now().strftime('%Y-%m-%d')
        )

    def _build_group_mapping(self) -> dict:
        """Map each code in ``code_list`` to its ``market`` group for diversified selection.

        Codes whose config entry is not a dict are skipped; a dict without a
        ``market`` key maps to ``'default'``.
        """
        code_list_config = self.config.get('code_list') or {}
        return {
            code: cfg.get('market', 'default')
            for code, cfg in code_list_config.items()
            if isinstance(cfg, dict)
        }

    def get_data(self, use_flask_api: bool = True) -> dict:
        """
        Fetch index, ETF and benchmark data.

        Args:
            use_flask_api: Whether to fetch through the Flask API service (default True).
                False uses the local HybridDataSource instead.

        Returns:
            Dict with keys 'index_data', 'index_close', 'etf_data', 'etf_nav_data',
            'etf_premium_data', 'benchmark_data', 'valid_codes', 'etf_code_map'.

        Raises:
            ValueError: if the config has no ``code_list``.
        """
        code_list_config = self.config.get('code_list') or {}
        benchmark_config = self.config.get('benchmark') or {}
        benchmark_code = benchmark_config.get('code', '000300.SH')

        if not code_list_config:
            raise ValueError("配置中未找到 code_list")

        if use_flask_api:
            # Remote fetch via the Flask API service; the URL is only used when enabled
            flask_api_config = self.config.get('flask_api') or {}
            flask_api_url = flask_api_config.get('url') if flask_api_config.get('enabled') else None
            return self._get_data_from_flask_api(
                code_list_config,
                benchmark_code,
                flask_api_url,
            )

        # Local HybridDataSource (requires a local SSH tunnel)
        return self._get_data_from_local(
            code_list_config,
            benchmark_code,
        )

    def _get_data_from_flask_api(
        self,
        code_list_config: dict,
        benchmark_code: str,
        flask_api_url: str = None,
    ) -> dict:
        """Fetch data through the Flask API service.

        ETF NAV and premium data are read from ``DataFrame.attrs``, which the
        Flask API data source attaches to each ETF frame.
        """
        from datasource.flask_api_source import FlaskAPIDataSource

        api_source = FlaskAPIDataSource(base_url=flask_api_url)

        # Report service health (non-fatal if unhealthy)
        health = api_source.get_health()
        if health.get('status') != 'healthy':
            print(f"⚠ Flask API 服务状态: {health}")
        else:
            print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})")

        print(f"\n回测配置区间: {self.start_date} ~ {self.end_date}")
        print("注: 各标的实际数据范围可能因上市时间/数据源限制而不同")

        index_codes = list(code_list_config.keys())

        # Index code -> ETF code mapping, for entries that configure an ETF
        etf_code_map = {}
        etf_codes = []
        for index_code, cfg in code_list_config.items():
            if isinstance(cfg, dict) and cfg.get('etf'):
                etf_code_map[index_code] = cfg['etf']
                etf_codes.append(cfg['etf'])

        # Index OHLCV data
        print(f"\n获取指数数据 ({len(index_codes)} 只)...")
        index_ohlcv_data = api_source.fetch_batch(
            index_codes,
            self.start_date,
            self.end_date,
        )

        # Keep only codes that returned a non-empty frame
        valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0]
        print(f"有效指数: {len(valid_codes)} 只")

        # ETF prices (NAV / premium come along in attrs)
        print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...")
        etf_ohlcv_data = api_source.fetch_batch(
            etf_codes,
            self.start_date,
            self.end_date,
        )

        etf_data = None
        etf_nav_data = {}
        etf_premium_data = {}

        if etf_ohlcv_data:
            etf_close_dict = {}
            for etf_code, df in etf_ohlcv_data.items():
                if df is None or 'close' not in df.columns:
                    continue
                etf_close_dict[etf_code] = df['close']

                if 'nav' in df.attrs:
                    etf_nav_data[etf_code] = df.attrs['nav']
                if 'premium_series' in df.attrs:
                    etf_premium_data[etf_code] = {
                        'series': df.attrs['premium_series'],
                        'latest': df.attrs.get('latest_premium'),
                        'date': df.attrs.get('premium_date'),
                        'stats': df.attrs.get('premium_stats'),
                    }
            if etf_close_dict:
                etf_data = pd.DataFrame(etf_close_dict)

        print(f"有效净值: {len(etf_nav_data)} 只")
        print(f"有效溢价率: {len(etf_premium_data)} 只")

        # Benchmark
        print(f"\n获取基准数据 ({benchmark_code})...")
        benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date)
        benchmark_data = None
        if benchmark_ohlcv is not None:
            benchmark_data = benchmark_ohlcv['close']

        # Wide-format index close prices (used for the trading calendar in compute_factors)
        index_close_dict = {}
        for code in valid_codes:
            df = index_ohlcv_data.get(code)
            if df is not None and 'close' in df.columns:
                index_close_dict[code] = df['close']
        index_close = pd.DataFrame(index_close_dict) if index_close_dict else None

        return {
            'index_data': index_ohlcv_data,        # raw OHLCV {code: DataFrame}
            'index_close': index_close,            # aligned close prices (wide)
            'etf_data': etf_data,                  # ETF close prices (wide)
            'etf_nav_data': etf_nav_data,          # ETF NAV data {code: ...}
            'etf_premium_data': etf_premium_data,  # ETF premium data {code: dict}
            'benchmark_data': benchmark_data,      # benchmark close Series
            'valid_codes': valid_codes,            # index codes with data
            'etf_code_map': etf_code_map,          # {index code: ETF code}
        }

    def _get_data_from_local(
        self,
        code_list_config: dict,
        benchmark_code: str,
    ) -> dict:
        """Fetch data with the local HybridDataSource (uses ``ssh_tunnel``/``use_cache`` config)."""
        from datasource import HybridDataSource

        ssh_config = self.config.get('ssh_tunnel', {})

        data_source = HybridDataSource(
            ssh_config=ssh_config,
            use_cache=self.config.get('use_cache', True),
        )

        index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map = \
            data_source.fetch_all(
                code_config=code_list_config,
                benchmark_code=benchmark_code,
                start_date=self.start_date,
                end_date=self.end_date,
            )

        return {
            'index_data': index_ohlcv_data,  # raw OHLCV data
            'index_close': index_data,       # aligned close prices (wide)
            'etf_data': etf_data,
            'etf_nav_data': etf_nav_data,
            # fetch_all returns no premium data; an empty dict keeps the key set
            # identical to the Flask API path so consumers can index it safely.
            'etf_premium_data': {},
            'benchmark_data': benchmark_data,
            'valid_codes': valid_codes,
            'etf_code_map': etf_code_map,    # {index code: ETF code}
        }

    def compute_factors(self, data: dict) -> pd.DataFrame:
        """Compute factor values, matching the original engine.

        Each instrument's factor is computed on its own trading calendar and
        then forward-filled onto the A-share calendar. Instruments with too
        little data are kept (with NaN factors) rather than dropped, so that
        strategy problems stay visible.

        Side effect: writes the processed code list back to ``data['valid_codes']``.

        Raises:
            ValueError: if there is neither ``index_close`` nor any valid code
                from which to derive the trading calendar.
        """
        index_data = data['index_data']
        valid_codes = data['valid_codes']

        # Reference calendar: the aligned close-price index if present, else the
        # first A-share-looking code (.SH/.SZ/.CSI), else the first valid code.
        index_close = data.get('index_close')
        if index_close is not None:
            a_share_dates = index_close.index
        else:
            if not valid_codes:
                raise ValueError("没有有效的指数数据,无法确定交易日历")
            calendar_code = next(
                (code for code in valid_codes if code.endswith(('.SH', '.SZ', '.CSI'))),
                valid_codes[0],
            )
            a_share_dates = index_data[calendar_code].index

        ohlcv_cols = ['open', 'high', 'low', 'close', 'volume']
        required_cols = ['open', 'high', 'low', 'close']

        factor_values = {}
        final_valid_codes = []

        for code in valid_codes:
            df = index_data[code]

            # Full OHLC = every required column exists and has at least one value
            has_ohlc = all(col in df.columns and df[col].notna().any() for col in required_cols)

            if has_ohlc:
                # Drop rows with any missing OHLCV field, then take close. 'volume'
                # is optional: selecting it unconditionally raised KeyError for
                # frames without a volume column.
                present_cols = [col for col in ohlcv_cols if col in df.columns]
                df_clean = df[present_cols].dropna()
                close_series = df_clean['close'] if len(df_clean) > 0 else pd.Series(dtype=float)
            elif 'close' in df.columns and df['close'].notna().any():
                # Only close has data (e.g. bond indices)
                close_series = df['close'].dropna()
            else:
                close_series = pd.Series(dtype=float)

            # Warn on short history but keep the instrument
            if len(close_series) < self.n_days + 1:
                print(f" ⚠ {code}: 数据不足 ({len(close_series)} < {self.n_days + 1}),保留但因子值可能为NaN")

            if len(close_series) > 0:
                # Rolling window runs on the instrument's own trading days
                # (no ffill-duplicated values), then the result is aligned.
                factor_series = self._factor.compute(pd.DataFrame({'close': close_series}))
                factor_aligned = factor_series.reindex(a_share_dates, method='ffill')
            else:
                factor_aligned = pd.Series(index=a_share_dates, dtype=float)

            factor_values[code] = factor_aligned
            final_valid_codes.append(code)

        factor_df = pd.DataFrame(factor_values)

        # Warn about high missing rates, but keep every instrument.
        # Guard against an empty calendar (previously ZeroDivisionError).
        total_rows = len(factor_df)
        if total_rows:
            for code in final_valid_codes:
                if code in factor_df.columns:
                    null_pct = factor_df[code].isnull().sum() / total_rows
                    if null_pct > 0.5:
                        print(f" ⚠ {code}: 缺失率 {null_pct:.1%} 较高,保留但信号生成时可能跳过")

        # All original codes are kept; write the list back for downstream steps
        data['valid_codes'] = final_valid_codes

        return factor_df

    def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
        """Generate selection signals from the factor frame via the TopNSelector."""
        return self._selector.generate(factor_df)

    def _build_returns_df(self, index_data: dict, valid_codes: list, dates) -> pd.DataFrame:
        """Build daily returns on ``dates`` with one ``日收益率_<code>`` column per code."""
        returns_data = {}
        for code in valid_codes:
            if code not in index_data:
                continue
            df = index_data[code]
            if df is None or 'close' not in df.columns:
                continue
            close_series = df['close'].dropna()
            # Align PRICES to the calendar first, then take returns. Forward-filling
            # returns (the previous approach) repeated yesterday's return on days
            # the instrument's market was closed and silently dropped returns of
            # its trading days that fall between two calendar dates; aligning the
            # price gives a 0 return on closed days and compounds across gaps.
            aligned_close = close_series.reindex(dates, method='ffill')
            returns_data[f'日收益率_{code}'] = aligned_close.pct_change(fill_method=None)
        return pd.DataFrame(returns_data)

    def run_backtest(self, data: dict = None, save_path: str = None) -> dict:
        """
        Run the full backtest pipeline: data -> factors -> signals -> execution.

        Args:
            data: Optional pre-fetched data dict; fetched via ``get_data`` if omitted.
            save_path: Path prefix for CSV reports (``_nav``, ``_signals``,
                ``_rebalances``); nothing is saved if omitted.

        Returns:
            Dict with 'signals', 'result', 'portfolio', 'total_return',
            'rebalance_events' when the executor produced a ``backtest_result``;
            otherwise ``{'signals': ..., 'result': None}``.
        """
        print("\n" + "=" * 60)
        print(" ETF轮动策略 回测系统")
        print("=" * 60)

        # 1. Data
        if data is None:
            data = self.get_data()

        valid_codes = data['valid_codes']
        index_data = data['index_data']

        print(f"\n候选标的: {len(valid_codes)} 只")
        print(f"回测区间: {self.start_date} ~ {self.end_date}")

        # 2. Factors
        print("\n计算因子...")
        factor_df = self.compute_factors(data)
        print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)} 只")

        # 3. Signals
        print("\n生成信号...")
        signals = self.generate_signals(factor_df)
        print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)} 天")

        # 4. Execution
        print("\n执行回测...")

        # Daily returns on the signal (A-share) calendar
        returns_df = self._build_returns_df(index_data, valid_codes, signals.index)

        # Make sure signals and returns share the same dates
        common_dates = signals.index.intersection(returns_df.index)
        signals = signals.loc[common_dates]
        returns_df = returns_df.loc[common_dates]

        print(f" 对齐后日期: {len(common_dates)} 天")

        executor = BacktestExecutor(
            initial_capital=100000,
            trade_cost=self.trade_cost,
            select_num=self.select_num,
        )

        portfolio = executor.execute(signals, returns_df)

        # 5. Results
        if not hasattr(portfolio, 'backtest_result'):
            return {'signals': signals, 'result': None}

        result = portfolio.backtest_result
        final_nav = result['策略净值'].iloc[-1]
        total_return = (final_nav - 1) * 100

        print("\n回测结果:")
        print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%")

        # Rebalance events (tolerate a missing or None attribute)
        rebalance_events = getattr(portfolio, 'rebalance_events', None)
        if rebalance_events is None:
            rebalance_events = pd.DataFrame()
        if not rebalance_events.empty:
            print(f" 调仓次数: {len(rebalance_events)} 次")

        if save_path:
            result[['策略净值']].to_csv(f"{save_path}_nav.csv")
            signals.to_csv(f"{save_path}_signals.csv")

            if not rebalance_events.empty:
                rebalance_events.to_csv(f"{save_path}_rebalances.csv")
                print(f" 报告保存: {save_path}_*.csv (含调仓记录)")
            else:
                print(f" 报告保存: {save_path}_*.csv")

        return {
            'signals': signals,
            'result': result,
            'portfolio': portfolio,
            'total_return': total_return,
            'rebalance_events': rebalance_events,
        }

    # Abstract-method implementations required by StrategyBase
    def init_factors(self) -> FactorCombiner:
        """Return a FactorCombiner wrapping the momentum factor."""
        return FactorCombiner([self._factor])

    def init_signal_generator(self) -> SignalGenerator:
        """Return the TopNSelector built in ``__init__``."""
        return self._selector