diff --git a/strategies/rotation/strategy.py b/strategies/rotation/strategy.py index cfd7939..930c8e2 100644 --- a/strategies/rotation/strategy.py +++ b/strategies/rotation/strategy.py @@ -9,6 +9,10 @@ import yaml from datetime import datetime from pathlib import Path +# 加载环境变量 +from dotenv import load_dotenv +load_dotenv() + from framework.factors import FactorRegistry, FactorCombiner from framework.signals import SignalGenerator from framework.execution import BacktestExecutor @@ -104,9 +108,10 @@ class RotationStrategy(StrategyBase): def get_data(self) -> dict: """获取数据(复用归档的数据源)""" code_list_config = self.config.get('code_list', {}) - code_list = list(code_list_config.keys()) + benchmark_config = self.config.get('benchmark', {}) + benchmark_code = benchmark_config.get('code', '000300.SH') - if not code_list: + if not code_list_config: raise ValueError("配置中未找到 code_list") # 使用归档的HybridDataSource @@ -129,17 +134,22 @@ class RotationStrategy(StrategyBase): use_cache=self.config.get('use_cache', True) ) - # 获取数据 - index_data = {} - all_data = data_source.fetch_batch(code_list, self.start_date, self.end_date) - - for code, df in all_data.items(): - if df is not None and not df.empty: - index_data[code] = df + # 调用 fetch_all(返回元组) + index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \ + data_source.fetch_all( + code_config=code_list_config, + benchmark_code=benchmark_code, + start_date=self.start_date, + end_date=self.end_date + ) return { - 'index_data': index_data, - 'valid_codes': list(index_data.keys()) + 'index_data': index_ohlcv_data, # 原始OHLCV数据 + 'index_close': index_data, # 对齐后的收盘价(宽格式) + 'etf_data': etf_data, + 'etf_nav_data': etf_nav_data, + 'benchmark_data': benchmark_data, + 'valid_codes': valid_codes } def compute_factors(self, data: dict) -> pd.DataFrame: