diff --git a/archive/legacy_core/core/datasource/hybrid_source.py b/archive/legacy_core/core/datasource/hybrid_source.py index 6399ce9..1fa18a9 100644 --- a/archive/legacy_core/core/datasource/hybrid_source.py +++ b/archive/legacy_core/core/datasource/hybrid_source.py @@ -34,8 +34,8 @@ class SSHTunnelManager: # 处理 key_path:如果是相对路径,转换为绝对路径 key_path = config.get("key_path", "") if key_path and not os.path.isabs(key_path): - # 相对于项目根目录 - project_root = Path(__file__).parent.parent.parent + # 相对于项目根目录(需要跳5层:datasource->core->legacy_core->archive->etf) + project_root = Path(__file__).parent.parent.parent.parent.parent key_path = str(project_root / key_path) self.key_path = key_path print(f"SSH 私钥路径: {self.key_path}") diff --git a/archive/legacy_core/core/datasource/yfinance_source.py b/archive/legacy_core/core/datasource/yfinance_source.py index ef1d8ce..4f464e7 100644 --- a/archive/legacy_core/core/datasource/yfinance_source.py +++ b/archive/legacy_core/core/datasource/yfinance_source.py @@ -32,8 +32,8 @@ class SSHTunnelManager: # 处理 key_path:如果是相对路径,转换为绝对路径 key_path = config.get("key_path", "") if key_path and not os.path.isabs(key_path): - # 相对于项目根目录 - project_root = Path(__file__).parent.parent.parent + # 相对于项目根目录(需要跳5层:datasource->core->legacy_core->archive->etf) + project_root = Path(__file__).parent.parent.parent.parent.parent key_path = str(project_root / key_path) self.key_path = key_path diff --git a/datasource/hybrid_source.py b/datasource/hybrid_source.py index be411a7..ceab9d2 100644 --- a/datasource/hybrid_source.py +++ b/datasource/hybrid_source.py @@ -113,7 +113,8 @@ class HybridDataSource: Optional[pd.DataFrame], # etf_nav_data: ETF净值 Optional[pd.DataFrame], # benchmark_data: 基准数据 List[str], # valid_codes: 有效代码列表 - Dict[str, pd.DataFrame] # index_ohlcv_data: 原始OHLCV数据 + Dict[str, pd.DataFrame], # index_ohlcv_data: 原始OHLCV数据 + Dict[str, str] # etf_code_map: {指数代码: ETF代码} 映射 ]: """ 批量获取数据 @@ -125,7 +126,7 @@ class HybridDataSource: end_date: 结束日期 Returns: - (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data) + (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map) """ if end_date is None: end_date = datetime.now().strftime('%Y-%m-%d') @@ -247,7 +248,7 @@ class HybridDataSource: benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize() print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条") - return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data + return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_codes def __enter__(self): self._start_tunnel() diff --git a/datasource/tushare_source.py b/datasource/tushare_source.py index 0ba9ce4..c0e7ce5 100644 --- a/datasource/tushare_source.py +++ b/datasource/tushare_source.py @@ -226,8 +226,10 @@ class TushareSource: return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS") or code.endswith(".CSI") def is_futures(self, code: str) -> bool: - """判断是否为期货""" - return ".SHF" in code or ".NYM" in code or ".DCE" in code or ".CZC" in code + """判断是否为中国期货(仅支持上期所、大商所、郑商所)""" + # 只支持中国交易所期货(.SHF上期所、.DCE大商所、.CZC郑商所) + # NYMEX (.NYM) 和 ICE (.ICE) 走 YFinance + return ".SHF" in code or ".DCE" in code or ".CZC" in code def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: """ diff --git a/strategies/rotation/strategy.py b/strategies/rotation/strategy.py index a475e4e..7b27941 100644 --- a/strategies/rotation/strategy.py +++ b/strategies/rotation/strategy.py @@ -125,7 +125,7 @@ class RotationStrategy(StrategyBase): ) # 调用 fetch_all - index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \ + index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map = \ data_source.fetch_all( code_config=code_list_config, benchmark_code=benchmark_code, @@ -139,7 +139,8 @@ class RotationStrategy(StrategyBase): 'etf_data': etf_data, 'etf_nav_data': etf_nav_data, 'benchmark_data': benchmark_data, - 'valid_codes': valid_codes + 'valid_codes': valid_codes, + 'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射 } def compute_factors(self, data: dict) -> pd.DataFrame: @@ -148,13 +149,42 @@ class RotationStrategy(StrategyBase): valid_codes = data['valid_codes'] factor_values = {} + final_valid_codes = [] + for code in valid_codes: df = index_data[code] - if len(df) >= self.n_days: - values = self._factor.compute(df) - factor_values[code] = values + # 只使用 close 列计算因子(匹配原引擎逻辑:部分指数只有收盘价) + if 'close' in df.columns: + close_series = df['close'].dropna() + else: + close_series = df.dropna() + + # 原引擎剔除逻辑:close 数据需要至少 n_days + 1 条 + if len(close_series) < self.n_days + 1: + print(f" ⚠ 剔除 {code}: 数据不足 ({len(close_series)} < {self.n_days + 1})") + continue + + # 只传入 close 列给因子计算器 + close_df = pd.DataFrame({'close': close_series}) + values = self._factor.compute(close_df) + factor_values[code] = values + final_valid_codes.append(code) - return pd.DataFrame(factor_values) + factor_df = pd.DataFrame(factor_values) + + # 过滤缺失率过高的标的 + total_rows = len(factor_df) + for code in final_valid_codes: + if code in factor_df.columns: + null_pct = factor_df[code].isnull().sum() / total_rows + if null_pct > 0.5: + print(f" ⚠ 剔除 {code}: 缺失率 {null_pct:.1%} 过高") + factor_df = factor_df.drop(columns=[code]) + + # 更新有效代码列表 + data['valid_codes'] = [c for c in final_valid_codes if c in factor_df.columns] + + return factor_df def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame: """生成信号""" @@ -198,25 +228,35 @@ class RotationStrategy(StrategyBase): # 4. 执行回测 print("\n执行回测...") - # 使用对齐后的指数收盘价数据获取日期基准 - index_close = data.get('index_close') + # 获取ETF数据和代码映射 + etf_data = data.get('etf_data') + etf_code_map = data.get('etf_code_map', {}) # {指数代码: ETF代码} - # 计算日收益率(使用对齐后的收盘价数据) - if index_close is not None and not index_close.empty: - returns_df = index_close.pct_change() - returns_df.columns = [f'日收益率_{col}' for col in returns_df.columns] - else: - # 回退到原始数据 + # 计算日收益率(使用ETF价格数据,匹配原引擎逻辑) + if etf_data is not None and not etf_data.empty: + # 使用ETF价格计算收益,列名保持指数代码格式 returns_data = {} - for code in valid_codes: - if code in index_data: - df = index_data[code] - returns_data[f'日收益率_{code}'] = df['close'].pct_change() + for idx_code in valid_codes: + etf_code = etf_code_map.get(idx_code, idx_code) + if etf_code in etf_data.columns: + returns_data[f'日收益率_{idx_code}'] = etf_data[etf_code].pct_change() returns_df = pd.DataFrame(returns_data) - - if valid_codes: - first_code = valid_codes[0] - returns_df.index = index_data[first_code].index + else: + # 回退到指数收盘价数据 + index_close = data.get('index_close') + if index_close is not None and not index_close.empty: + returns_df = index_close.pct_change() + returns_df.columns = [f'日收益率_{col}' for col in returns_df.columns] + else: + returns_data = {} + for code in valid_codes: + if code in index_data: + df = index_data[code] + returns_data[f'日收益率_{code}'] = df['close'].pct_change() + returns_df = pd.DataFrame(returns_data) + if valid_codes: + first_code = valid_codes[0] + returns_df.index = index_data[first_code].index # 确保信号和收益率数据日期对齐 common_dates = signals.index.intersection(returns_df.index) diff --git a/strategies/shared/signals/selectors.py b/strategies/shared/signals/selectors.py index c74dd90..2e81484 100644 --- a/strategies/shared/signals/selectors.py +++ b/strategies/shared/signals/selectors.py @@ -164,10 +164,6 @@ class TopNSelector(SignalGenerator): factor_cols: List[str] ) -> bool: """检查是否应该调仓(得分阈值检查)""" - if self.rebalance_threshold <= 0: - # 无阈值,直接调仓 - return target != current_held - # 提取当前持仓和目标持仓的代码 old_codes = [c for c in current_held.split(',') if c] new_codes = [c for c in target.split(',') if c] @@ -176,13 +172,14 @@ class TopNSelector(SignalGenerator): return True if set(new_codes) == set(old_codes): - return False + return False # 组合完全相同,不调仓 # 计算新旧组合的总得分 old_total = sum(float(row.get(col, 0)) for col in factor_cols if col in old_codes) new_total = sum(float(row.get(col, 0)) for col in factor_cols if col in new_codes) # 新组合得分需超过当前组合一定比例才调仓 + # 即使 threshold=0,也要确保 new_total >= old_total if old_total > 0: return (new_total / old_total - 1) >= self.rebalance_threshold