refactor(rotation): 统一与配置文件代码映射和基准指数使用方式

- 将默认代码映射字典和基准指数改为可被策略配置覆盖的形式
- 修改配置文件rotation.yaml中候选池配置从列表变为代码与名称的字典映射
- 在运行脚本中加载配置时支持字典格式的code_list和benchmark,兼容旧格式列表
- 更新回测策略引擎通过配置动态获取基准指数代码
- 打印输出和函数调用中统一使用从配置加载的代码映射和基准名称数据
This commit is contained in:
2026-03-19 00:33:06 +08:00
parent 9b154a1a25
commit 062f500369
4 changed files with 56 additions and 30 deletions

View File

@@ -55,8 +55,8 @@ def get_db_config() -> dict:
}
# ==================== 代码映射 ====================
CODE_NAME_MAP = {
# ==================== 代码映射(默认,可被策略配置覆盖)====================
DEFAULT_CODE_NAME_MAP = {
# 宽基
"000300.SH": "沪深300",
"000905.SH": "中证500",
@@ -95,6 +95,6 @@ CODE_NAME_MAP = {
"399702.SZ": "中证国债指数",
}
# 基准指数
BENCHMARK_CODE = "000300.SH"
BENCHMARK_NAME = "沪深300指数"
# 基准指数(默认,可被策略配置覆盖)
DEFAULT_BENCHMARK_CODE = "000300.SH"
DEFAULT_BENCHMARK_NAME = "沪深300指数"

View File

@@ -1,33 +1,39 @@
# ETF轮动策略配置
# ==================== 候选池配置 ====================
# A股全行业指数代码列表Tushare格式XXXXXX.SH / XXXXXX.SZ
# A股全行业指数配置Tushare格式XXXXXX.SH / XXXXXX.SZ
# 格式: {代码: 名称}
code_list:
# 宽基指数
- "000300.SH" # 沪深300大盘蓝筹
- "000905.SH" # 中证500中盘成长
- "000852.SH" # 中证1000(小盘)
- "399006.SZ" # 创业板指(创业板龙头)
- "000015.SH" # 上证红利(高股息价值)
"000300.SH": "沪深300"
"000905.SH": "中证500"
"000852.SH": "中证1000"
"399006.SZ": "创业板指"
"000015.SH": "上证红利"
# 金融
- "399986.SZ" # 中证银行
"399986.SZ": "中证银行"
# 消费
- "399997.SZ" # 中证白酒
"399997.SZ": "中证白酒"
# 医药健康
- "399989.SZ" # 中证医疗
"399989.SZ": "中证医疗"
# 科技信息
- "000935.SH" # 中证信息技术
"000935.SH": "中证信息"
# 新能源
- "399976.SZ" # 新能源
"399976.SZ": "新能源车"
# 周期资源
- "399395.SZ" # 国证有色金属
- "399998.SZ" # 中证煤炭
- "399813.SZ" # 细分化工
- "000937.SH" # 中证能源
"399395.SZ": "国证有色"
"399998.SZ": "中证煤炭"
"399813.SZ": "细分化工"
"000937.SH": "中证能源"
# 其他行业
- "399967.SZ" # 中证军工
- "000949.SH" # 中证农业
- "399702.SZ" # 中证国债指数
"399967.SZ": "中证军工"
"000949.SH": "中证农业"
"399702.SZ": "国债指数"
# 基准指数配置
benchmark:
code: "000300.SH"
name: "沪深300指数"
# ==================== 回测参数 ====================
start_date: "2018-01-01"

View File

@@ -20,7 +20,7 @@ sys.path.insert(0, str(project_root))
from strategies.rotation.engine import RotationStrategy
from strategies.rotation.portfolio import track_positions, save_trades
from strategies.rotation.report import generate_performance_report
from config.settings import CODE_NAME_MAP, BENCHMARK_NAME
from config.settings import DEFAULT_CODE_NAME_MAP, DEFAULT_BENCHMARK_NAME
def load_config(config_path: str) -> dict:
@@ -59,8 +59,22 @@ def main():
from datetime import datetime
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
# 从配置中读取 code_list 和 code_name_map
# code_list 现在是一个字典 {代码: 名称}
code_list_config = config.get('code_list', {})
if isinstance(code_list_config, dict):
code_list = list(code_list_config.keys())
code_name_map = code_list_config
else:
# 兼容旧格式(列表)
code_list = code_list_config
code_name_map = DEFAULT_CODE_NAME_MAP
benchmark_config = config.get('benchmark', {})
benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME)
print(f"\n配置文件: {args.config}")
print(f"候选标的: {len(config['code_list'])}")
print(f"候选标的: {len(code_list)}")
print(f"回测区间: {config['start_date']} ~ {config['end_date']}")
print(f"因子类型: {config['factor_type']}")
print(f"窗口天数: {config['n_days']}")
@@ -68,6 +82,9 @@ def main():
print(f"调仓周期: {config['rebalance_days']}")
print(f"交易成本: {config['trade_cost']:.2%}")
# 更新 config 中的 code_list 为列表格式
config['code_list'] = code_list
# 创建策略实例
strategy = RotationStrategy(config)
@@ -85,7 +102,7 @@ def main():
trades_df, summary_df = track_positions(
backtest_result,
code_name_map=CODE_NAME_MAP,
code_name_map=code_name_map,
select_num=config["select_num"],
)
save_trades(trades_df, summary_df, save_path=args.save_path)
@@ -98,8 +115,8 @@ def main():
metrics = generate_performance_report(
backtest_result,
strategy.valid_codes,
code_name_map=CODE_NAME_MAP,
benchmark_name=BENCHMARK_NAME,
code_name_map=code_name_map,
benchmark_name=benchmark_name,
save_path=args.save_path,
select_num=config["select_num"],
)

View File

@@ -25,11 +25,14 @@ class RotationStrategy(BacktestStrategy):
def fetch_data(self) -> pd.DataFrame:
"""获取数据"""
from config.settings import BENCHMARK_CODE
from config.settings import DEFAULT_BENCHMARK_CODE
# 从配置中读取基准代码,或使用默认值
benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE)
etf_data, benchmark_data, valid_codes = self.data_source.fetch_all(
self.config["code_list"],
BENCHMARK_CODE,
benchmark_code,
self.config["start_date"],
self.config["end_date"],
)