From e6b2c8cfb7f50c2c3438e23436ef35d52723c29c Mon Sep 17 00:00:00 2001 From: aszerW Date: Mon, 11 May 2026 23:56:05 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E9=80=82=E9=85=8D=E5=BD=92=E6=A1=A3?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=BA=90=E6=8E=A5=E5=8F=A3=EF=BC=8C=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0dotenv=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 使用 fetch_all() 替代 fetch_batch() - 添加 from dotenv import load_dotenv 加载环境变量 - 返回完整数据结构(index_data, etf_data, nav_data, benchmark) 回测验证成功: - 累计收益: 164.47% - 最终净值: 2.6447 - 信号日期: 1780 天 --- strategies/rotation/strategy.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/strategies/rotation/strategy.py b/strategies/rotation/strategy.py index cfd7939..930c8e2 100644 --- a/strategies/rotation/strategy.py +++ b/strategies/rotation/strategy.py @@ -9,6 +9,10 @@ import yaml from datetime import datetime from pathlib import Path +# 加载环境变量 +from dotenv import load_dotenv +load_dotenv() + from framework.factors import FactorRegistry, FactorCombiner from framework.signals import SignalGenerator from framework.execution import BacktestExecutor @@ -104,9 +108,10 @@ class RotationStrategy(StrategyBase): def get_data(self) -> dict: """获取数据(复用归档的数据源)""" code_list_config = self.config.get('code_list', {}) - code_list = list(code_list_config.keys()) + benchmark_config = self.config.get('benchmark', {}) + benchmark_code = benchmark_config.get('code', '000300.SH') - if not code_list: + if not code_list_config: raise ValueError("配置中未找到 code_list") # 使用归档的HybridDataSource @@ -129,17 +134,22 @@ class RotationStrategy(StrategyBase): use_cache=self.config.get('use_cache', True) ) - # 获取数据 - index_data = {} - all_data = data_source.fetch_batch(code_list, self.start_date, self.end_date) - - for code, df in all_data.items(): - if df is not None and not df.empty: - index_data[code] = df + # 调用 fetch_all(返回元组) + index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \ + data_source.fetch_all( + code_config=code_list_config, + benchmark_code=benchmark_code, + start_date=self.start_date, + end_date=self.end_date + ) return { - 'index_data': index_data, - 'valid_codes': list(index_data.keys()) + 'index_data': index_ohlcv_data, # 原始OHLCV数据 + 'index_close': index_data, # 对齐后的收盘价(宽格式) + 'etf_data': etf_data, + 'etf_nav_data': etf_nav_data, + 'benchmark_data': benchmark_data, + 'valid_codes': valid_codes } def compute_factors(self, data: dict) -> pd.DataFrame: