Files
etf/scripts/run_rotation.py
aszerW 062f500369 refactor(rotation): 统一与配置文件代码映射和基准指数使用方式
- 将默认代码映射字典和基准指数改为可被策略配置覆盖的形式
- 修改配置文件rotation.yaml中候选池配置从列表变为代码与名称的字典映射
- 在运行脚本中加载配置时支持字典格式的code_list和benchmark,兼容旧格式列表
- 更新回测策略引擎通过配置动态获取基准指数代码
- 打印输出和函数调用中统一使用从配置加载的代码映射和基准名称数据
2026-03-19 00:33:06 +08:00

132 lines
3.6 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
ETF轮动策略回测入口
用法:
python scripts/run_rotation.py
python scripts/run_rotation.py --config config/strategies/rotation.yaml
"""
import sys
import time
import yaml
import argparse
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from strategies.rotation.engine import RotationStrategy
from strategies.rotation.portfolio import track_positions, save_trades
from strategies.rotation.report import generate_performance_report
from config.settings import DEFAULT_CODE_NAME_MAP, DEFAULT_BENCHMARK_NAME
def load_config(config_path: str) -> dict:
"""加载配置文件"""
with open(config_path, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
def main():
parser = argparse.ArgumentParser(description="ETF轮动策略回测")
parser.add_argument(
"--config",
type=str,
default="config/strategies/rotation.yaml",
help="配置文件路径",
)
parser.add_argument(
"--save-path",
type=str,
default="results/report",
help="报告保存路径前缀",
)
args = parser.parse_args()
start_time = time.time()
print("=" * 60)
print(" ETF轮动策略 回测系统")
print("=" * 60)
# 加载配置
config = load_config(args.config)
# 如果未设置 end_date默认使用最新日期
if not config.get('end_date'):
from datetime import datetime
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
# 从配置中读取 code_list 和 code_name_map
# code_list 现在是一个字典 {代码: 名称}
code_list_config = config.get('code_list', {})
if isinstance(code_list_config, dict):
code_list = list(code_list_config.keys())
code_name_map = code_list_config
else:
# 兼容旧格式(列表)
code_list = code_list_config
code_name_map = DEFAULT_CODE_NAME_MAP
benchmark_config = config.get('benchmark', {})
benchmark_name = benchmark_config.get('name', DEFAULT_BENCHMARK_NAME)
print(f"\n配置文件: {args.config}")
print(f"候选标的: {len(code_list)}")
print(f"回测区间: {config['start_date']} ~ {config['end_date']}")
print(f"因子类型: {config['factor_type']}")
print(f"窗口天数: {config['n_days']}")
print(f"选中数量: {config['select_num']}")
print(f"调仓周期: {config['rebalance_days']}")
print(f"交易成本: {config['trade_cost']:.2%}")
# 更新 config 中的 code_list 为列表格式
config['code_list'] = code_list
# 创建策略实例
strategy = RotationStrategy(config)
# 运行回测
print("\n" + "=" * 60)
print("开始回测...")
print("=" * 60)
backtest_result = strategy.run()
# 持仓跟踪
print("\n" + "=" * 60)
print("持仓跟踪...")
print("=" * 60)
trades_df, summary_df = track_positions(
backtest_result,
code_name_map=code_name_map,
select_num=config["select_num"],
)
save_trades(trades_df, summary_df, save_path=args.save_path)
# 生成绩效报告
print("\n" + "=" * 60)
print("生成绩效报告...")
print("=" * 60)
metrics = generate_performance_report(
backtest_result,
strategy.valid_codes,
code_name_map=code_name_map,
benchmark_name=benchmark_name,
save_path=args.save_path,
select_num=config["select_num"],
)
elapsed = time.time() - start_time
print(f"\n总耗时: {elapsed:.1f}")
return metrics
if __name__ == "__main__":
main()