diff --git a/config/settings.py b/config/settings.py index 3659044..5039d33 100644 --- a/config/settings.py +++ b/config/settings.py @@ -55,8 +55,8 @@ def get_db_config() -> dict: } -# ==================== 代码映射 ==================== -CODE_NAME_MAP = { +# ==================== 代码映射(默认,可被策略配置覆盖)==================== +DEFAULT_CODE_NAME_MAP = { # 宽基 "000300.SH": "沪深300", "000905.SH": "中证500", @@ -95,6 +95,6 @@ CODE_NAME_MAP = { "399702.SZ": "中证国债指数", } -# 基准指数 -BENCHMARK_CODE = "000300.SH" -BENCHMARK_NAME = "沪深300指数" +# 基准指数(默认,可被策略配置覆盖) +DEFAULT_BENCHMARK_CODE = "000300.SH" +DEFAULT_BENCHMARK_NAME = "沪深300指数" diff --git a/config/strategies/rotation.yaml b/config/strategies/rotation.yaml index fa99bd8..ab271e1 100644 --- a/config/strategies/rotation.yaml +++ b/config/strategies/rotation.yaml @@ -1,33 +1,39 @@ # ETF轮动策略配置 # ==================== 候选池配置 ==================== -# A股全行业指数代码列表(Tushare格式:XXXXXX.SH / XXXXXX.SZ) +# A股全行业指数配置(Tushare格式:XXXXXX.SH / XXXXXX.SZ) +# 格式: {代码: 名称} code_list: # 宽基指数 - - "000300.SH" # 沪深300(大盘蓝筹) - - "000905.SH" # 中证500(中盘成长) - - "000852.SH" # 中证1000(小盘) - - "399006.SZ" # 创业板指(创业板龙头) - - "000015.SH" # 上证红利(高股息价值) + "000300.SH": "沪深300" + "000905.SH": "中证500" + "000852.SH": "中证1000" + "399006.SZ": "创业板指" + "000015.SH": "上证红利" # 金融 - - "399986.SZ" # 中证银行 + "399986.SZ": "中证银行" # 消费 - - "399997.SZ" # 中证白酒 + "399997.SZ": "中证白酒" # 医药健康 - - "399989.SZ" # 中证医疗 + "399989.SZ": "中证医疗" # 科技信息 - - "000935.SH" # 中证信息技术 + "000935.SH": "中证信息" # 新能源 - - "399976.SZ" # 新能源汽车 + "399976.SZ": "新能源车" # 周期资源 - - "399395.SZ" # 国证有色金属 - - "399998.SZ" # 中证煤炭 - - "399813.SZ" # 细分化工 - - "000937.SH" # 中证能源 + "399395.SZ": "国证有色" + "399998.SZ": "中证煤炭" + "399813.SZ": "细分化工" + "000937.SH": "中证能源" # 其他行业 - - "399967.SZ" # 中证军工 - - "000949.SH" # 中证农业 - - "399702.SZ" # 中证国债指数 + "399967.SZ": "中证军工" + "000949.SH": "中证农业" + "399702.SZ": "国债指数" + +# 基准指数配置 +benchmark: + code: "000300.SH" + name: "沪深300指数" # ==================== 回测参数 ==================== start_date: "2018-01-01" diff --git a/scripts/run_rotation.py b/scripts/run_rotation.py index 4e14e78..2d1d41e 100755 --- a/scripts/run_rotation.py +++ b/scripts/run_rotation.py @@ -20,7 +20,7 @@ sys.path.insert(0, str(project_root)) from strategies.rotation.engine import RotationStrategy from strategies.rotation.portfolio import track_positions, save_trades from strategies.rotation.report import generate_performance_report -from config.settings import CODE_NAME_MAP, BENCHMARK_NAME +from config.settings import DEFAULT_CODE_NAME_MAP, DEFAULT_BENCHMARK_NAME def load_config(config_path: str) -> dict: @@ -59,8 +59,22 @@ def main(): from datetime import datetime config['end_date'] = datetime.now().strftime('%Y-%m-%d') + # 从配置中读取 code_list 和 code_name_map + # code_list 现在是一个字典 {代码: 名称} + code_list_config = config.get('code_list', {}) + if isinstance(code_list_config, dict): + code_list = list(code_list_config.keys()) + code_name_map = code_list_config + else: + # 兼容旧格式(列表) + code_list = code_list_config + code_name_map = DEFAULT_CODE_NAME_MAP + + benchmark_config = config.get('benchmark', {}) + benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME) + print(f"\n配置文件: {args.config}") - print(f"候选标的: {len(config['code_list'])} 只") + print(f"候选标的: {len(code_list)} 只") print(f"回测区间: {config['start_date']} ~ {config['end_date']}") print(f"因子类型: {config['factor_type']}") print(f"窗口天数: {config['n_days']}") @@ -68,6 +82,9 @@ def main(): print(f"调仓周期: {config['rebalance_days']} 天") print(f"交易成本: {config['trade_cost']:.2%}") + # 更新 config 中的 code_list 为列表格式 + config['code_list'] = code_list + # 创建策略实例 strategy = RotationStrategy(config) @@ -85,7 +102,7 @@ def main(): trades_df, summary_df = track_positions( backtest_result, - code_name_map=CODE_NAME_MAP, + code_name_map=code_name_map, select_num=config["select_num"], ) save_trades(trades_df, summary_df, save_path=args.save_path) @@ -98,8 +115,8 @@ def main(): metrics = generate_performance_report( backtest_result, strategy.valid_codes, - code_name_map=CODE_NAME_MAP, - benchmark_name=BENCHMARK_NAME, + code_name_map=code_name_map, + benchmark_name=benchmark_name, save_path=args.save_path, select_num=config["select_num"], ) diff --git a/strategies/rotation/engine.py b/strategies/rotation/engine.py index 4c64221..8d8a970 100644 --- a/strategies/rotation/engine.py +++ b/strategies/rotation/engine.py @@ -25,11 +25,14 @@ class RotationStrategy(BacktestStrategy): def fetch_data(self) -> pd.DataFrame: """获取数据""" - from config.settings import BENCHMARK_CODE + from config.settings import DEFAULT_BENCHMARK_CODE + + # 从配置中读取基准代码,或使用默认值 + benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE) etf_data, benchmark_data, valid_codes = self.data_source.fetch_all( self.config["code_list"], - BENCHMARK_CODE, + benchmark_code, self.config["start_date"], self.config["end_date"], )