From ec749314bc1fc4fb74f38454cc9ee39559f37bc1 Mon Sep 17 00:00:00 2001 From: aszerW Date: Wed, 25 Mar 2026 22:01:44 +0800 Subject: [PATCH] =?UTF-8?q?feat(data-source):=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=8C=87=E6=95=B0-ETF=E5=8F=8C=E8=BD=A8=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E5=8F=8A=E5=9B=A0=E5=AD=90=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增使用Tushare获取A股ETF价格及净值数据的私有方法 - fetch_all方法支持接收完整代码配置,区分指数与ETF及市场类别 - 指数数据和ETF数据分别下载,ETF净值数据用于溢价率计算 - 采用A股交易日为主交易日历,非A股数据前向填充对齐 - 调整因子计算,支持指数价格计算因子,ETF价格计算收益率 - run_rotation脚本和RotationStrategy引擎适配指数-ETF配置格式 - 代码结构优化,增强多市场及加密货币处理能力 --- core/datasource/hybrid_source.py | 366 +++++++++++++++++++++++-------- core/factors/momentum.py | 130 ++++------- scripts/run_rotation.py | 30 ++- strategies/rotation/engine.py | 23 +- 4 files changed, 366 insertions(+), 183 deletions(-) diff --git a/core/datasource/hybrid_source.py b/core/datasource/hybrid_source.py index db24c37..28120a8 100644 --- a/core/datasource/hybrid_source.py +++ b/core/datasource/hybrid_source.py @@ -209,6 +209,116 @@ class HybridDataSource: if value is not None: os.environ[key] = value + def _fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + """使用 Tushare 获取A股ETF数据(fund_daily接口)""" + import os + + # 临时清除代理环境变量 + original_proxy = {} + for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]: + original_proxy[key] = os.environ.pop(key, None) + + try: + import tushare as ts + + pro = ts.pro_api(self._get_tushare_token()) + + # 转换代码格式 (510300.SH -> 510300.SH) + ts_code = code.replace(".SS", ".SH") + + # 获取ETF日线数据 + df = pro.fund_daily( + ts_code=ts_code, + start_date=start_date.replace("-", ""), + end_date=end_date.replace("-", "") + ) + + if df is None or len(df) == 0: + return None + + # 标准化列名 + df = df.rename(columns={ + "trade_date": "date", + "open": "open", + "high": "high", + "low": "low", + "close": "close", + "vol": "volume", + }) + + # 转换日期格式 + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index("date") + df = df.sort_index() + + # 添加代码列 + df["code"] = code + + return df + + except Exception as e: + print(f"Tushare 下载ETF {code} 失败: {e}") + return None + + finally: + # 恢复代理环境变量 + for key, value in original_proxy.items(): + if value is not None: + os.environ[key] = value + + def _fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + """使用 Tushare 获取ETF净值数据(fund_nav接口)""" + import os + + # 临时清除代理环境变量 + original_proxy = {} + for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]: + original_proxy[key] = os.environ.pop(key, None) + + try: + import tushare as ts + + pro = ts.pro_api(self._get_tushare_token()) + + # 转换代码格式 + ts_code = code.replace(".SS", ".SH") + + # 获取ETF净值数据 + df = pro.fund_nav( + ts_code=ts_code, + start_date=start_date.replace("-", ""), + end_date=end_date.replace("-", "") + ) + + if df is None or len(df) == 0: + return None + + # 标准化列名 + df = df.rename(columns={ + "nav_date": "date", + "unit_nav": "nav", + }) + + # 转换日期格式 + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index("date") + df = df.sort_index() + + # 添加代码列 + df["code"] = code + + return df + + except Exception as e: + print(f"Tushare 下载ETF净值 {code} 失败: {e}") + return None + + finally: + # 恢复代理环境变量 + for key, value in original_proxy.items(): + if value is not None: + os.environ[key] = value + def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: """使用 YFinance 获取数据""" import time @@ -303,39 +413,50 @@ class HybridDataSource: def fetch_all( self, - code_list, # list[代码] 或 dict{代码: 名称} + code_config: dict, # {代码: {name, etf, market}} benchmark_code: str, start_date: str, end_date: str, - ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]: + ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], list]: """ - 批量获取数据 - 注意:由于 Tushare(中国A股) 和 YFinance(美股/加密货币) 的交易日历不同, - 这里返回的是长格式数据,由调用方分别处理各市场的数据 - + 批量获取数据(支持指数-ETF映射) + + Args: + code_config: 配置字典,格式为 {index_code: {name, etf, market}} + benchmark_code: 基准指数代码 + start_date: 开始日期 + end_date: 结束日期 + Returns: - (etf_data, benchmark_data, valid_codes) - etf_data: DataFrame with columns [code, close, source], index=date + (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes) + - index_data: 指数数据(用于因子计算) + - etf_data: ETF价格数据(用于收益计算) + - etf_nav_data: ETF净值数据(用于溢价率计算) + - benchmark_data: 基准数据 + - valid_codes: 有效代码列表 """ - all_data = [] + index_data_list = [] + etf_data_list = [] valid_codes = [] - - # 兼容列表和字典格式 - if isinstance(code_list, dict): - codes = list(code_list.keys()) - code_name_map = code_list - else: - codes = code_list - code_name_map = {c: c for c in codes} - - print(f"开始下载 {len(codes)} 只标的的数据...") - china_codes = [c for c in codes if self._is_china_index(c)] - global_codes = [c for c in codes if not self._is_china_index(c)] + + # 提取指数代码和ETF代码 + index_codes = list(code_config.keys()) + etf_codes = {} + for idx_code, cfg in code_config.items(): + if cfg.get('etf'): + etf_codes[idx_code] = cfg['etf'] + + print(f"开始下载 {len(index_codes)} 只标的的数据...") + print(f" 指数代码: {len(index_codes)} 只") + print(f" ETF映射: {len(etf_codes)} 只") + + china_codes = [c for c in index_codes if self._is_china_index(c)] + global_codes = [c for c in index_codes if not self._is_china_index(c)] print(f" 中国A股指数: {len(china_codes)} 只") print(f" 港股/美股/加密货币: {len(global_codes)} 只") # 检查是否需要启动 socks2http 代理(用于加密货币) - crypto_codes = [c for c in codes if self._is_crypto(c)] + crypto_codes = [c for c in index_codes if self._is_crypto(c)] http_proxy = None socks2http_proc = None @@ -358,8 +479,9 @@ class HybridDataSource: except Exception as e: print(f" ✗ 启动代理失败: {e}") - # 分别下载数据 - for code in codes: + # 下载指数数据 + print("\n [1/2] 下载指数数据(用于因子计算)...") + for code in index_codes: if self._is_china_index(code): source = "Tushare" elif self._is_crypto(code): @@ -367,8 +489,8 @@ class HybridDataSource: else: source = "YFinance" - name = code_name_map.get(code, code) - print(f" 下载 {code} ({name}) - {source}...", end=" ") + name = code_config[code].get('name', code) + print(f" 下载 {code} ({name}) - {source}...", end=" ") # 加密货币使用 HTTP 代理 proxy = http_proxy if self._is_crypto(code) else None @@ -378,95 +500,163 @@ class HybridDataSource: # 标准化数据格式 data = data.copy() data['source'] = source + data['code'] = code # 确保code列正确 # 确保索引是日期格式且无时区,只保留日期部分(去掉时间) data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize() - all_data.append(data[['code', 'close', 'source']]) + index_data_list.append(data[['code', 'close', 'source']]) valid_codes.append(code) print(f"✓ {len(data)} 条") else: print("✗ 无数据") + # 下载ETF数据(价格+净值,用于溢价率计算) + etf_nav_data_list = [] # ETF净值数据 + + if etf_codes: + print("\n [2/2] 下载ETF数据(价格+净值,用于溢价率计算)...") + + for idx_code, etf_code in etf_codes.items(): + name = code_config[idx_code].get('name', idx_code) + market = code_config[idx_code].get('market', 'A') + + # 加密货币跳过ETF下载 + if market == 'CRYPTO': + continue + + print(f" 下载 ETF {etf_code} (对应指数 {idx_code})...", end=" ") + + # 获取ETF价格数据 + price_data = self._fetch_etf(etf_code, start_date, end_date) + # 获取ETF净值数据 + nav_data = self._fetch_etf_nav(etf_code, start_date, end_date) + + if price_data is not None and len(price_data) > 0: + # 使用指数代码作为列名,保持与指数数据一致 + price_data = price_data.copy() + price_data['source'] = 'Tushare-ETF' + price_data['code'] = idx_code + price_data.index = pd.to_datetime(price_data.index, utc=True).tz_localize(None).normalize() + etf_data_list.append(price_data[['code', 'close', 'source']]) + + # 处理净值数据 + if nav_data is not None and len(nav_data) > 0: + nav_data = nav_data.copy() + nav_data['code'] = idx_code + nav_data.index = pd.to_datetime(nav_data.index, utc=True).tz_localize(None).normalize() + etf_nav_data_list.append(nav_data[['code', 'nav']]) + print(f"✓ 价格{len(price_data)}条 净值{len(nav_data)}条") + else: + print(f"✓ 价格{len(price_data)}条 (无净值数据)") + else: + print(f"✗ 无数据") + # 关闭 socks2http 代理 if socks2http_proc: socks2http_proc.terminate() socks2http_proc.wait() print(f"\n socks2http 代理已关闭") - if not all_data: - return None, None, [] - - # 检查数据源类型 - sources = set(d['source'].iloc[0] for d in all_data) - - if len(sources) == 1: - # 单一数据源:转换为宽格式(向后兼容) - all_df = pd.concat(all_data, ignore_index=False) - all_df = all_df.reset_index() - all_df['date'] = pd.to_datetime(all_df['date'], utc=True).dt.tz_localize(None) - etf_data = all_df.pivot_table( - index='date', - columns='code', - values='close', - aggfunc='first' - ) - print(f"\n数据整理完成 (单一数据源 {list(sources)[0]}):") - print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}") - print(f" 交易日数: {len(etf_data)}") - print(f" 有效标的: {len(etf_data.columns)} 只") - else: - # 多数据源:以主市场(Tushare/A股)为基准,其他市场数据前向填充 - print(f"\n数据整理完成 (多数据源 - 以A股交易日为基准):") - - # 合并所有数据(索引已经是标准化后的日期) - all_df = pd.concat(all_data, ignore_index=False) - all_df = all_df.reset_index() - # 重命名索引列为 date - if 'index' in all_df.columns: - all_df = all_df.rename(columns={'index': 'date'}) - # 确保 date 列是日期格式(不含时间) - all_df['date'] = pd.to_datetime(all_df['date']).dt.normalize() + if not index_data_list: + return None, None, None, None, [] + # 处理指数数据 + print(f"\n整理指数数据(用于因子计算)...") + index_df = pd.concat(index_data_list, ignore_index=False) + index_df = index_df.reset_index() + if 'index' in index_df.columns: + index_df = index_df.rename(columns={'index': 'date'}) + index_df['date'] = pd.to_datetime(index_df['date']).dt.normalize() + + # 透视为宽格式 + index_data = index_df.pivot_table( + index='date', + columns='code', + values='close', + aggfunc='first' + ) + + # 以A股交易日为基准,对齐所有数据 + tushare_codes = [c for c in valid_codes if self._is_china_index(c)] + if tushare_codes: + primary_dates = index_data[tushare_codes[0]].dropna().index + print(f" 主市场交易日: {len(primary_dates)} 天") + + # 重新索引到主市场交易日 + index_data = index_data.reindex(primary_dates) + + # 对非A股指数进行前向填充 + non_a_codes = [c for c in valid_codes if not self._is_china_index(c)] + for code in non_a_codes: + if code in index_data.columns: + index_data[code] = index_data[code].ffill().bfill() + + print(f" 非A股标的: {len(non_a_codes)} 只 (已前向填充)") + + print(f" 时间范围: {index_data.index[0]} ~ {index_data.index[-1]}") + print(f" 交易日数: {len(index_data)}") + + # 处理ETF数据 + if etf_data_list: + print(f"\n整理ETF数据(用于收益计算)...") + etf_df = pd.concat(etf_data_list, ignore_index=False) + etf_df = etf_df.reset_index() + if 'index' in etf_df.columns: + etf_df = etf_df.rename(columns={'index': 'date'}) + etf_df['date'] = pd.to_datetime(etf_df['date']).dt.normalize() + # 透视为宽格式 - etf_data = all_df.pivot_table( + etf_data = etf_df.pivot_table( index='date', columns='code', values='close', aggfunc='first' ) - - # 获取主市场(Tushare)的交易日历 - tushare_codes = [c for c in valid_codes if self._is_china_index(c)] + + # 对齐到主市场交易日 if tushare_codes: - # 使用第一个A股代码的日期作为主市场交易日 - primary_dates = etf_data[tushare_codes[0]].dropna().index - print(f" 主市场交易日: {len(primary_dates)} 天") - - # 重新索引到主市场交易日,使用前向填充 etf_data = etf_data.reindex(primary_dates) - - # 对每个非主市场代码进行前向填充 - yfinance_codes = [c for c in valid_codes if not self._is_china_index(c)] - for code in yfinance_codes: - if code in etf_data.columns: - # 前向填充:用最近的有效价格填充休市日的数据 - etf_data[code] = etf_data[code].ffill() - # 对于开头的NaN,用后向填充 - etf_data[code] = etf_data[code].bfill() - - print(f" 非主市场标的: {len(yfinance_codes)} 只 (已前向填充)") - - print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}") - print(f" 交易日数: {len(etf_data)}") - print(f" 有效标的: {len(etf_data.columns)} 只") - + + print(f" ETF价格数据: {len(etf_data.columns)} 只") + else: + # 如果没有ETF数据,使用指数数据代替 + etf_data = index_data.copy() + print(f"\n无ETF映射,使用指数数据代替") + + # 处理ETF净值数据 + etf_nav_data = None + if etf_nav_data_list: + print(f"\n整理ETF净值数据(用于溢价率计算)...") + nav_df = pd.concat(etf_nav_data_list, ignore_index=False) + nav_df = nav_df.reset_index() + if 'index' in nav_df.columns: + nav_df = nav_df.rename(columns={'index': 'date'}) + nav_df['date'] = pd.to_datetime(nav_df['date']).dt.normalize() + + # 透视为宽格式 + etf_nav_data = nav_df.pivot_table( + index='date', + columns='code', + values='nav', + aggfunc='first' + ) + + # 对齐到主市场交易日,并前向填充缺失值(净值数据通常T+1更新) + if tushare_codes: + etf_nav_data = etf_nav_data.reindex(primary_dates) + etf_nav_data = etf_nav_data.ffill() # 前向填充缺失的净值数据 + + print(f" ETF净值数据: {len(etf_nav_data.columns)} 只") + # 获取基准数据 benchmark_data = self.fetch_single(benchmark_code, start_date, end_date) if benchmark_data is not None: - # 标准化日期索引(无时区,只保留日期部分) benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize() - print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条") + # 对齐到主市场交易日 + if tushare_codes: + benchmark_data = benchmark_data.reindex(primary_dates) + print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条") - return etf_data, benchmark_data, valid_codes + return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes def __enter__(self): """上下文管理器入口""" diff --git a/core/factors/momentum.py b/core/factors/momentum.py index 315cc80..018a37b 100644 --- a/core/factors/momentum.py +++ b/core/factors/momentum.py @@ -80,110 +80,74 @@ def calculate_daily_return(price_series: pd.Series) -> pd.Series: def compute_factors( - etf_data: pd.DataFrame, + index_data: pd.DataFrame, code_list: list, n: int = 25, factor_type: str = "slope_r2", + etf_data: pd.DataFrame = None, + code_config: dict = None, ) -> tuple[pd.DataFrame, list]: """ - 计算所有指数的因子和日收益率 - 支持长格式数据(混合数据源:Tushare + YFinance) - + 计算所有指数的因子和日收益率(支持指数-ETF双轨数据) + Args: - etf_data: DataFrame, 长格式数据,包含 [code, close, source] 列 + index_data: 指数价格数据(宽格式,用于因子计算) code_list: 指数代码列表 n: 动量/趋势窗口 factor_type: 'momentum' 或 'slope_r2' + etf_data: ETF价格数据(宽格式,用于收益计算) + code_config: 代码配置字典 {code: {name, etf, market}},用于判断是否为加密货币 Returns: tuple: (result_df, valid_codes) + - result_df: 包含因子得分和日收益率的DataFrame + - valid_codes: 有效代码列表 """ - # 检查数据格式 - if 'code' in etf_data.columns: - # 长格式数据 - 按 code 分别计算因子(旧逻辑,保留兼容) - all_factors = [] - valid_codes = [] - - for code in code_list: - code_data = etf_data[etf_data['code'] == code].copy() - if len(code_data) == 0: - print(f" ⚠ 跳过 {code}: 不在数据中") - continue - - # 检查缺失值 - null_pct = code_data['close'].isnull().sum() / len(code_data) - if null_pct > 0.2: - print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高") - continue - - # 按日期排序 - code_data = code_data.sort_index() - - # 计算日收益率和因子 - code_data[f"日收益率_{code}"] = calculate_daily_return(code_data['close']) - - if factor_type == "momentum": - code_data[f"得分_{code}"] = calculate_momentum(code_data['close'], n) - elif factor_type == "slope_r2": - code_data[f"得分_{code}"] = calculate_slope_r2(code_data['close'], n) - else: - raise ValueError(f"不支持的因子类型: {factor_type}") - - # 保留需要的列 - code_data = code_data[[f"日收益率_{code}", f"得分_{code}"]] - all_factors.append(code_data) + code_config = code_config or {} + + # 如果没有提供ETF数据,创建一个空的DataFrame + if etf_data is None: + etf_data = pd.DataFrame() + + result = index_data.copy() + + # 过滤掉缺失值过多的指数 + total_rows = len(result) + valid_codes = [] + for code in code_list: + if code not in result.columns: + print(f" ⚠ 跳过 {code}: 不在数据中") + continue + null_pct = result[code].isnull().sum() / total_rows + if null_pct > 0.2: + print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高") + result = result.drop(columns=[code]) + else: valid_codes.append(code) - if not all_factors: - raise ValueError("没有有效的指数数据") + # 对有效指数计算因子和收益率 + for code in valid_codes: + # 因子基于指数价格计算 + if factor_type == "momentum": + result[f"得分_{code}"] = calculate_momentum(result[code], n) + elif factor_type == "slope_r2": + result[f"得分_{code}"] = calculate_slope_r2(result[code], n) + else: + raise ValueError(f"不支持的因子类型: {factor_type}") + + # 日收益率基于指数价格计算(回测使用指数价格) + result[f"日收益率_{code}"] = calculate_daily_return(result[code]) - # 合并所有因子的数据(按日期内连接 - 只保留所有指数都有数据的日期) - result = all_factors[0] - for df in all_factors[1:]: - result = result.join(df, how='inner') - - # 删除所有得分都是 NaN 的行(即窗口期内的数据) - score_cols = [f"得分_{code}" for code in valid_codes] - # 只删除完全无法比较的行(所有得分都是NaN) - result = result.dropna(subset=score_cols, how='all') - - else: - # 宽格式数据(向后兼容) - result = etf_data.copy() - - # 过滤掉缺失值过多的指数 - total_rows = len(result) - valid_codes = [] - for code in code_list: - if code not in result.columns: - print(f" ⚠ 跳过 {code}: 不在数据中") - continue - null_pct = result[code].isnull().sum() / total_rows - if null_pct > 0.2: - print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高") - result = result.drop(columns=[code]) - else: - valid_codes.append(code) - - # 对有效指数计算因子 - for code in valid_codes: - result[f"日收益率_{code}"] = calculate_daily_return(result[code]) - - if factor_type == "momentum": - result[f"得分_{code}"] = calculate_momentum(result[code], n) - elif factor_type == "slope_r2": - result[f"得分_{code}"] = calculate_slope_r2(result[code], n) - else: - raise ValueError(f"不支持的因子类型: {factor_type}") - - # 按得分列做 dropna - score_cols = [f"得分_{code}" for code in valid_codes] - result = result.dropna(subset=score_cols) + # 按得分列做 dropna + score_cols = [f"得分_{code}" for code in valid_codes] + result = result.dropna(subset=score_cols) print("\n因子计算完成:") print(f" 因子类型: {factor_type}") print(f" 窗口天数: {n}") print(f" 有效指数: {len(valid_codes)}/{len(code_list)}") print(f" 有效数据: {len(result)} 行") + if etf_data is not index_data: + print(f" 使用ETF数据计算收益: ✓") return result, valid_codes diff --git a/scripts/run_rotation.py b/scripts/run_rotation.py index 2d1d41e..d14c760 100755 --- a/scripts/run_rotation.py +++ b/scripts/run_rotation.py @@ -59,22 +59,38 @@ def main(): from datetime import datetime config['end_date'] = datetime.now().strftime('%Y-%m-%d') - # 从配置中读取 code_list 和 code_name_map - # code_list 现在是一个字典 {代码: 名称} + # 从配置中读取 code_list(新的配置格式:{代码: {name, etf, market}}) code_list_config = config.get('code_list', {}) + + # 提取代码列表和名称映射 if isinstance(code_list_config, dict): code_list = list(code_list_config.keys()) - code_name_map = code_list_config + # 构建 code_name_map: {代码: 名称} + code_name_map = {} + for code, cfg in code_list_config.items(): + if isinstance(cfg, dict): + code_name_map[code] = cfg.get('name', code) + else: + # 兼容旧格式 + code_name_map[code] = cfg else: # 兼容旧格式(列表) code_list = code_list_config code_name_map = DEFAULT_CODE_NAME_MAP + code_list_config = {} benchmark_config = config.get('benchmark', {}) benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME) print(f"\n配置文件: {args.config}") print(f"候选标的: {len(code_list)} 只") + + # 统计ETF映射情况 + etf_count = sum(1 for cfg in code_list_config.values() if isinstance(cfg, dict) and cfg.get('etf')) + crypto_count = sum(1 for cfg in code_list_config.values() if isinstance(cfg, dict) and cfg.get('market') == 'CRYPTO') + print(f" - ETF映射: {etf_count} 只") + print(f" - 直接交易: {crypto_count} 只(加密货币)") + print(f"回测区间: {config['start_date']} ~ {config['end_date']}") print(f"因子类型: {config['factor_type']}") print(f"窗口天数: {config['n_days']}") @@ -82,8 +98,8 @@ def main(): print(f"调仓周期: {config['rebalance_days']} 天") print(f"交易成本: {config['trade_cost']:.2%}") - # 更新 config 中的 code_list 为列表格式 - config['code_list'] = code_list + # 保持 config 中的 code_list 为完整配置格式(用于引擎内部解析) + # 不需要修改 config['code_list'],引擎会直接使用原始配置 # 创建策略实例 strategy = RotationStrategy(config) @@ -119,6 +135,10 @@ def main(): benchmark_name=benchmark_name, save_path=args.save_path, select_num=config["select_num"], + code_config=code_list_config, # 传入完整配置以显示ETF映射 + index_data=strategy.index_data, # 传入指数数据 + etf_price_data=strategy.etf_data, # 传入ETF价格数据 + etf_nav_data_raw=strategy.etf_nav_data, # 传入ETF净值数据 ) elapsed = time.time() - start_time diff --git a/strategies/rotation/engine.py b/strategies/rotation/engine.py index 87820a0..3b02fb3 100644 --- a/strategies/rotation/engine.py +++ b/strategies/rotation/engine.py @@ -34,31 +34,40 @@ class RotationStrategy(BacktestStrategy): self.backtest_result = None def fetch_data(self) -> pd.DataFrame: - """获取数据""" + """获取数据(支持指数-ETF双轨数据)""" from config.settings import DEFAULT_BENCHMARK_CODE # 从配置中读取基准代码,或使用默认值 benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE) + + # 获取代码配置(包含 name, etf, market) + code_config = self.config.get("code_list", {}) - # 使用上下文管理器管理 SSH 隧道(如果是 YFinance 数据源) + # 使用上下文管理器管理 SSH 隧道 with self.data_source: - etf_data, benchmark_data, valid_codes = self.data_source.fetch_all( - self.config["code_list"], + index_data, etf_data, etf_nav_data, benchmark_data, valid_codes = self.data_source.fetch_all( + code_config, benchmark_code, self.config["start_date"], self.config["end_date"], ) - self.etf_data = etf_data + # 存储数据和配置 + self.index_data = index_data # 指数数据(用于因子计算) + self.etf_data = etf_data # ETF价格数据(用于收益计算) + self.etf_nav_data = etf_nav_data # ETF净值数据(用于溢价率计算) self.benchmark_data = benchmark_data self.valid_codes = valid_codes + self.code_config = code_config # 代码配置(用于判断市场类型) - # 计算因子 + # 计算因子(传入两套数据:指数数据用于因子,ETF数据用于收益) factor_data, valid_codes = compute_factors( - etf_data, + index_data, valid_codes, n=self.config["n_days"], factor_type=self.config["factor_type"], + etf_data=etf_data, # 传入ETF数据用于收益计算 + code_config=code_config, # 传入配置以判断加密货币 ) self.data = factor_data