""" 轮动策略完整实现 整合数据获取、因子计算、信号生成、回测执行 """ import pandas as pd import yaml from datetime import datetime from pathlib import Path # 加载环境变量 from dotenv import load_dotenv load_dotenv() from framework.factors import FactorRegistry, FactorCombiner from framework.signals import SignalGenerator from framework.execution import BacktestExecutor from framework.risk import CallbackHook, Position from framework.strategy import StrategyBase # 导入定制组件 from strategies.shared.factors.momentum import MomentumFactor from strategies.shared.signals.selectors import TopNSelector class RotationStrategy(StrategyBase): """ ETF轮动策略(完整实现) 基于动量因子 + Top N选股 + 分散化 使用方式: from strategies.rotation.strategy import RotationStrategy strategy = RotationStrategy.from_yaml('strategies/rotation/config.yaml') result = strategy.run_backtest() """ name = "rotation" select_num = 3 stoploss = -0.05 n_days = 25 rebalance_days = 1 rebalance_threshold = 0.0 trade_cost = 0.001 def __init__(self, config: dict = None): """初始化策略""" # 应用配置 if config: self._apply_config(config) self.config = config else: self.config = {} # 初始化因子 FactorRegistry.clear() FactorRegistry.register(MomentumFactor) self._factor = FactorRegistry.get( 'momentum', n_days=self.n_days, crash_filter=True ) # 构建分组映射(分散化选股) self._group_mapping = self._build_group_mapping() # 初始化信号生成器 self._selector = TopNSelector( select_num=self.select_num, group_mapping=self._group_mapping, min_score=self.min_score, # 从配置读取,支持动态调整阈值 rebalance_days=self.rebalance_days, rebalance_threshold=self.rebalance_threshold, bond_threshold_config=self.config.get('bond_threshold', {}) # V3动态阈值配置 ) @classmethod def from_yaml(cls, config_path: str) -> 'RotationStrategy': """从YAML配置创建策略实例""" with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) # 设置结束日期 if not config.get('end_date'): config['end_date'] = datetime.now().strftime('%Y-%m-%d') return cls(config) def _apply_config(self, config: dict) -> None: """应用配置参数""" self.select_num = config.get('select_num', self.select_num) self.n_days = config.get('n_days', self.n_days) self.rebalance_days = config.get('rebalance_days', self.rebalance_days) self.rebalance_threshold = config.get('rebalance_threshold', self.rebalance_threshold) self.trade_cost = config.get('trade_cost', self.trade_cost) self.min_score = config.get('min_score', 0.0) # 动量最低阈值,默认过滤负动量 self.start_date = config.get('start_date', '2019-01-01') self.end_date = config.get('end_date', datetime.now().strftime('%Y-%m-%d')) def _build_group_mapping(self) -> dict: """构建分组映射(分散化选股)""" group_mapping = {} code_list_config = self.config.get('code_list', {}) for code, cfg in code_list_config.items(): if isinstance(cfg, dict): group_mapping[code] = cfg.get('market', 'default') return group_mapping def get_data(self, use_flask_api: bool = True) -> dict: """ 获取数据 Args: use_flask_api: 是否使用 Flask API 服务获取数据(默认 True) False 则使用本地 HybridDataSource """ code_list_config = self.config.get('code_list', {}) benchmark_config = self.config.get('benchmark', {}) benchmark_code = benchmark_config.get('code', '000300.SH') if not code_list_config: raise ValueError("配置中未找到 code_list") # 获取 Flask API 地址 flask_api_config = self.config.get('flask_api', {}) flask_api_url = flask_api_config.get('url') if flask_api_config.get('enabled') else None if use_flask_api: # 使用 Flask API 服务获取数据(远程调用) return self._get_data_from_flask_api( code_list_config, benchmark_code, flask_api_url ) else: # 使用本地 HybridDataSource(需要本地 SSH 隧道) return self._get_data_from_local( code_list_config, benchmark_code ) def _get_data_from_flask_api( self, code_list_config: dict, benchmark_code: str, flask_api_url: str = None ) -> dict: """通过 Flask API 服务获取数据""" from datasource.flask_api_source import FlaskAPIDataSource # 初始化 Flask API 数据源 api_source = FlaskAPIDataSource(base_url=flask_api_url) # 检查服务状态 health = api_source.get_health() if health.get('status') != 'healthy': print(f"⚠ Flask API 服务状态: {health}") else: print(f"✓ Flask API 服务正常 (SSH: {health.get('ssh_configured', False)})") # 打印回测时间区间说明 print(f"\n回测配置区间: {self.start_date} ~ {self.end_date}") print("注: 各标的实际数据范围可能因上市时间/数据源限制而不同") # 获取指数代码列表 index_codes = list(code_list_config.keys()) # 获取 ETF 代码映射 etf_code_map = {} etf_codes = [] for index_code, cfg in code_list_config.items(): if isinstance(cfg, dict) and cfg.get('etf'): etf_code_map[index_code] = cfg['etf'] etf_codes.append(cfg['etf']) # 获取指数 OHLCV 数据 print(f"\n获取指数数据 ({len(index_codes)} 只)...") index_ohlcv_data = api_source.fetch_batch( index_codes, self.start_date, self.end_date ) # 过滤有效代码 valid_codes = [code for code, df in index_ohlcv_data.items() if df is not None and len(df) > 0] print(f"有效指数: {len(valid_codes)} 只") # 获取 ETF 价格数据(同时获取净值和溢价率) print(f"\n获取 ETF 数据 ({len(etf_codes)} 只)...") etf_ohlcv_data = api_source.fetch_batch( etf_codes, self.start_date, self.end_date ) # 转换为宽格式 DataFrame,并提取净值/溢价率数据 etf_data = None etf_nav_data = {} etf_premium_data = {} if etf_ohlcv_data: etf_close_dict = {} for etf_code, df in etf_ohlcv_data.items(): if df is not None and 'close' in df.columns: etf_close_dict[etf_code] = df['close'] # 从 DataFrame.attrs 中提取净值和溢价率数据 # Flask API 已自动附加这些数据 if 'nav' in df.attrs: etf_nav_data[etf_code] = df.attrs['nav'] if 'premium_series' in df.attrs: etf_premium_data[etf_code] = { 'series': df.attrs['premium_series'], 'latest': df.attrs.get('latest_premium'), 'date': df.attrs.get('premium_date'), 'stats': df.attrs.get('premium_stats'), } if etf_close_dict: etf_data = pd.DataFrame(etf_close_dict) print(f"有效净值: {len(etf_nav_data)} 只") print(f"有效溢价率: {len(etf_premium_data)} 只") # 获取基准数据 print(f"\n获取基准数据 ({benchmark_code})...") benchmark_ohlcv = api_source.fetch(benchmark_code, self.start_date, self.end_date) benchmark_data = None if benchmark_ohlcv is not None: benchmark_data = benchmark_ohlcv['close'] # 构建指数收盘价宽格式 DataFrame(用于因子计算) index_close_dict = {} for code in valid_codes: df = index_ohlcv_data.get(code) if df is not None and 'close' in df.columns: index_close_dict[code] = df['close'] index_close = pd.DataFrame(index_close_dict) if index_close_dict else None return { 'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame} 'index_close': index_close, # 对齐后的收盘价(宽格式) 'etf_data': etf_data, # ETF 收盘价(宽格式) 'etf_nav_data': etf_nav_data, # ETF 净值数据 {code: DataFrame} 'etf_premium_data': etf_premium_data, # ETF 溢价率数据 {code: dict} 'benchmark_data': benchmark_data, # 基准收盘价 Series 'valid_codes': valid_codes, # 有效指数代码列表 'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射 } def _get_data_from_local( self, code_list_config: dict, benchmark_code: str ) -> dict: """使用本地 HybridDataSource 获取数据""" from datasource import HybridDataSource ssh_config = self.config.get('ssh_tunnel', {}) data_source = HybridDataSource( ssh_config=ssh_config, use_cache=self.config.get('use_cache', True) ) # 调用 fetch_all index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map = \ data_source.fetch_all( code_config=code_list_config, benchmark_code=benchmark_code, start_date=self.start_date, end_date=self.end_date ) return { 'index_data': index_ohlcv_data, # 原始OHLCV数据 'index_close': index_data, # 对齐后的收盘价(宽格式) 'etf_data': etf_data, 'etf_nav_data': etf_nav_data, 'benchmark_data': benchmark_data, 'valid_codes': valid_codes, 'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射 } def compute_factors(self, data: dict) -> pd.DataFrame: """计算因子值(匹配原引擎:先计算因子再对齐到A股交易日历) 注意:不剔除数据不足的标的,保留所有标的以暴露策略问题 """ index_data = data['index_data'] valid_codes = data['valid_codes'] # 获取A股交易日历作为基准(使用已有的对齐后数据索引) index_close = data.get('index_close') if index_close is not None: a_share_dates = index_close.index else: for code in valid_codes: if code.endswith('.SH') or code.endswith('.SZ') or code.endswith('.CSI'): a_share_dates = index_data[code].index break else: a_share_dates = index_data[valid_codes[0]].index factor_values = {} final_valid_codes = [] for code in valid_codes: df = index_data[code].copy() # 检查是否有有效的OHLCV数据(列存在且不全为None) ohlcv_cols = ['open', 'high', 'low', 'close', 'volume'] required_cols = ['open', 'high', 'low', 'close'] # 检查列是否存在 cols_exist = all(col in df.columns for col in required_cols) # 检查数据是否有效(不全为None/NaN) if cols_exist: cols_have_data = all(df[col].notna().any() for col in required_cols) else: cols_have_data = False if cols_exist and cols_have_data: # 有完整有效的OHLCV数据,整行dropna()后提取close df_clean = df[ohlcv_cols].dropna() close_series = df_clean['close'] if len(df_clean) > 0 else pd.Series(dtype=float) elif 'close' in df.columns and df['close'].notna().any(): # 只有close列有效数据(如债券指数) close_series = df['close'].dropna() else: # 无有效数据 close_series = pd.Series(dtype=float) # 检查数据长度并警告,但不剔除 if len(close_series) < self.n_days + 1: print(f" ⚠ {code}: 数据不足 ({len(close_series)} < {self.n_days + 1}),保留但因子值可能为NaN") # 原引擎逻辑:先在原始交易日历上计算因子 # rolling窗口使用的是原始交易日数据,不包含ffill填充的重复值 if len(close_series) > 0: close_df = pd.DataFrame({'close': close_series}) factor_series = self._factor.compute(close_df) # 然后对齐因子序列到A股交易日历(匹配原引擎逻辑) factor_aligned = factor_series.reindex(a_share_dates, method='ffill') else: # 没有数据,创建空的因子序列 factor_aligned = pd.Series(index=a_share_dates, dtype=float) factor_values[code] = factor_aligned final_valid_codes.append(code) factor_df = pd.DataFrame(factor_values) # 检查缺失率并警告,但不剔除(保留所有标的以暴露策略问题) total_rows = len(factor_df) for code in final_valid_codes: if code in factor_df.columns: null_pct = factor_df[code].isnull().sum() / total_rows if null_pct > 0.5: print(f" ⚠ {code}: 缺失率 {null_pct:.1%} 较高,保留但信号生成时可能跳过") # 不更新有效代码列表,保留所有原始代码 data['valid_codes'] = final_valid_codes return factor_df def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame: """生成信号""" return self._selector.generate(factor_df) def run_backtest(self, data: dict = None, save_path: str = None) -> dict: """ 完整回测流程 Args: data: 可选,如不提供则自动获取 save_path: 报告保存路径 Returns: 回测结果字典 """ print("\n" + "=" * 60) print(" ETF轮动策略 回测系统") print("=" * 60) # 1. 获取数据 if data is None: data = self.get_data() valid_codes = data['valid_codes'] index_data = data['index_data'] print(f"\n候选标的: {len(valid_codes)} 只") print(f"回测区间: {self.start_date} ~ {self.end_date}") # 2. 计算因子 print("\n计算因子...") factor_df = self.compute_factors(data) print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)} 只") # 3. 生成信号 print("\n生成信号...") signals = self.generate_signals(factor_df) print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)} 天") # 4. 执行回测 print("\n执行回测...") # 获取A股交易日历(从因子数据索引) a_share_dates = signals.index # 计算日收益率:先在原始交易日历计算,再对齐到A股日历 # 关键:与因子计算逻辑一致,避免交易日不对齐导致收益率NaN returns_data = {} for code in valid_codes: if code in index_data: df = index_data[code] # 提取原始收盘价序列 if 'close' in df.columns: close_series = df['close'].dropna() # 先在原始交易日历计算收益率 returns_series = close_series.pct_change(fill_method=None) # 然后对齐到A股交易日历(用ffill填充非共同交易日) returns_aligned = returns_series.reindex(a_share_dates, method='ffill') returns_data[f'日收益率_{code}'] = returns_aligned returns_df = pd.DataFrame(returns_data) # 确保信号和收益率数据日期对齐 common_dates = signals.index.intersection(returns_df.index) signals = signals.loc[common_dates] returns_df = returns_df.loc[common_dates] print(f" 对齐后日期: {len(common_dates)} 天") executor = BacktestExecutor( initial_capital=100000, trade_cost=self.trade_cost, select_num=self.select_num ) portfolio = executor.execute(signals, returns_df) # 5. 输出结果 if hasattr(portfolio, 'backtest_result'): result = portfolio.backtest_result final_nav = result['策略净值'].iloc[-1] total_return = (final_nav - 1) * 100 print("\n回测结果:") print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%") # 获取调仓事件 rebalance_events = getattr(portfolio, 'rebalance_events', pd.DataFrame()) if not rebalance_events.empty: print(f" 调仓次数: {len(rebalance_events)} 次") # 保存报告 if save_path: result[['策略净值']].to_csv(f"{save_path}_nav.csv") signals.to_csv(f"{save_path}_signals.csv") # 保存调仓事件记录 if not rebalance_events.empty: rebalance_events.to_csv(f"{save_path}_rebalances.csv") print(f" 报告保存: {save_path}_*.csv (含调仓记录)") else: print(f" 报告保存: {save_path}_*.csv") return { 'signals': signals, 'result': result, 'portfolio': portfolio, 'total_return': total_return, 'rebalance_events': rebalance_events } return {'signals': signals, 'result': None} # 保留抽象方法实现 def init_factors(self) -> FactorCombiner: return FactorCombiner([self._factor]) def init_signal_generator(self) -> SignalGenerator: return self._selector