refactor(rotation): 统一与配置文件代码映射和基准指数使用方式
- 将默认代码映射字典和基准指数改为可被策略配置覆盖的形式 - 修改配置文件rotation.yaml中候选池配置从列表变为代码与名称的字典映射 - 在运行脚本中加载配置时支持字典格式的code_list和benchmark,兼容旧格式列表 - 更新回测策略引擎通过配置动态获取基准指数代码 - 打印输出和函数调用中统一使用从配置加载的代码映射和基准名称数据
This commit is contained in:
@@ -55,8 +55,8 @@ def get_db_config() -> dict:
|
||||
}
|
||||
|
||||
|
||||
# ==================== 代码映射 ====================
|
||||
CODE_NAME_MAP = {
|
||||
# ==================== 代码映射(默认,可被策略配置覆盖)====================
|
||||
DEFAULT_CODE_NAME_MAP = {
|
||||
# 宽基
|
||||
"000300.SH": "沪深300",
|
||||
"000905.SH": "中证500",
|
||||
@@ -95,6 +95,6 @@ CODE_NAME_MAP = {
|
||||
"399702.SZ": "中证国债指数",
|
||||
}
|
||||
|
||||
# 基准指数
|
||||
BENCHMARK_CODE = "000300.SH"
|
||||
BENCHMARK_NAME = "沪深300指数"
|
||||
# 基准指数(默认,可被策略配置覆盖)
|
||||
DEFAULT_BENCHMARK_CODE = "000300.SH"
|
||||
DEFAULT_BENCHMARK_NAME = "沪深300指数"
|
||||
|
||||
@@ -1,33 +1,39 @@
|
||||
# ETF轮动策略配置
|
||||
|
||||
# ==================== 候选池配置 ====================
|
||||
# A股全行业指数代码列表(Tushare格式:XXXXXX.SH / XXXXXX.SZ)
|
||||
# A股全行业指数配置(Tushare格式:XXXXXX.SH / XXXXXX.SZ)
|
||||
# 格式: {代码: 名称}
|
||||
code_list:
|
||||
# 宽基指数
|
||||
- "000300.SH" # 沪深300(大盘蓝筹)
|
||||
- "000905.SH" # 中证500(中盘成长)
|
||||
- "000852.SH" # 中证1000(小盘)
|
||||
- "399006.SZ" # 创业板指(创业板龙头)
|
||||
- "000015.SH" # 上证红利(高股息价值)
|
||||
"000300.SH": "沪深300"
|
||||
"000905.SH": "中证500"
|
||||
"000852.SH": "中证1000"
|
||||
"399006.SZ": "创业板指"
|
||||
"000015.SH": "上证红利"
|
||||
# 金融
|
||||
- "399986.SZ" # 中证银行
|
||||
"399986.SZ": "中证银行"
|
||||
# 消费
|
||||
- "399997.SZ" # 中证白酒
|
||||
"399997.SZ": "中证白酒"
|
||||
# 医药健康
|
||||
- "399989.SZ" # 中证医疗
|
||||
"399989.SZ": "中证医疗"
|
||||
# 科技信息
|
||||
- "000935.SH" # 中证信息技术
|
||||
"000935.SH": "中证信息"
|
||||
# 新能源
|
||||
- "399976.SZ" # 新能源汽车
|
||||
"399976.SZ": "新能源车"
|
||||
# 周期资源
|
||||
- "399395.SZ" # 国证有色金属
|
||||
- "399998.SZ" # 中证煤炭
|
||||
- "399813.SZ" # 细分化工
|
||||
- "000937.SH" # 中证能源
|
||||
"399395.SZ": "国证有色"
|
||||
"399998.SZ": "中证煤炭"
|
||||
"399813.SZ": "细分化工"
|
||||
"000937.SH": "中证能源"
|
||||
# 其他行业
|
||||
- "399967.SZ" # 中证军工
|
||||
- "000949.SH" # 中证农业
|
||||
- "399702.SZ" # 中证国债指数
|
||||
"399967.SZ": "中证军工"
|
||||
"000949.SH": "中证农业"
|
||||
"399702.SZ": "国债指数"
|
||||
|
||||
# 基准指数配置
|
||||
benchmark:
|
||||
code: "000300.SH"
|
||||
name: "沪深300指数"
|
||||
|
||||
# ==================== 回测参数 ====================
|
||||
start_date: "2018-01-01"
|
||||
|
||||
@@ -20,7 +20,7 @@ sys.path.insert(0, str(project_root))
|
||||
from strategies.rotation.engine import RotationStrategy
|
||||
from strategies.rotation.portfolio import track_positions, save_trades
|
||||
from strategies.rotation.report import generate_performance_report
|
||||
from config.settings import CODE_NAME_MAP, BENCHMARK_NAME
|
||||
from config.settings import DEFAULT_CODE_NAME_MAP, DEFAULT_BENCHMARK_NAME
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
|
||||
@@ -59,8 +59,22 @@ def main():
|
||||
from datetime import datetime
|
||||
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
# 从配置中读取 code_list 和 code_name_map
|
||||
# code_list 现在是一个字典 {代码: 名称}
|
||||
code_list_config = config.get('code_list', {})
|
||||
if isinstance(code_list_config, dict):
|
||||
code_list = list(code_list_config.keys())
|
||||
code_name_map = code_list_config
|
||||
else:
|
||||
# 兼容旧格式(列表)
|
||||
code_list = code_list_config
|
||||
code_name_map = DEFAULT_CODE_NAME_MAP
|
||||
|
||||
benchmark_config = config.get('benchmark', {})
|
||||
benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME)
|
||||
|
||||
print(f"\n配置文件: {args.config}")
|
||||
print(f"候选标的: {len(config['code_list'])} 只")
|
||||
print(f"候选标的: {len(code_list)} 只")
|
||||
print(f"回测区间: {config['start_date']} ~ {config['end_date']}")
|
||||
print(f"因子类型: {config['factor_type']}")
|
||||
print(f"窗口天数: {config['n_days']}")
|
||||
@@ -68,6 +82,9 @@ def main():
|
||||
print(f"调仓周期: {config['rebalance_days']} 天")
|
||||
print(f"交易成本: {config['trade_cost']:.2%}")
|
||||
|
||||
# 更新 config 中的 code_list 为列表格式
|
||||
config['code_list'] = code_list
|
||||
|
||||
# 创建策略实例
|
||||
strategy = RotationStrategy(config)
|
||||
|
||||
@@ -85,7 +102,7 @@ def main():
|
||||
|
||||
trades_df, summary_df = track_positions(
|
||||
backtest_result,
|
||||
code_name_map=CODE_NAME_MAP,
|
||||
code_name_map=code_name_map,
|
||||
select_num=config["select_num"],
|
||||
)
|
||||
save_trades(trades_df, summary_df, save_path=args.save_path)
|
||||
@@ -98,8 +115,8 @@ def main():
|
||||
metrics = generate_performance_report(
|
||||
backtest_result,
|
||||
strategy.valid_codes,
|
||||
code_name_map=CODE_NAME_MAP,
|
||||
benchmark_name=BENCHMARK_NAME,
|
||||
code_name_map=code_name_map,
|
||||
benchmark_name=benchmark_name,
|
||||
save_path=args.save_path,
|
||||
select_num=config["select_num"],
|
||||
)
|
||||
|
||||
@@ -25,11 +25,14 @@ class RotationStrategy(BacktestStrategy):
|
||||
|
||||
def fetch_data(self) -> pd.DataFrame:
|
||||
"""获取数据"""
|
||||
from config.settings import BENCHMARK_CODE
|
||||
from config.settings import DEFAULT_BENCHMARK_CODE
|
||||
|
||||
# 从配置中读取基准代码,或使用默认值
|
||||
benchmark_code = self.config.get("benchmark", {}).get("code", DEFAULT_BENCHMARK_CODE)
|
||||
|
||||
etf_data, benchmark_data, valid_codes = self.data_source.fetch_all(
|
||||
self.config["code_list"],
|
||||
BENCHMARK_CODE,
|
||||
benchmark_code,
|
||||
self.config["start_date"],
|
||||
self.config["end_date"],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user