Files
etf/run_rotation.py
aszerW f70aa1d3d1 feat: 创建新框架执行入口 run_rotation.py
使用新框架的因子、信号、执行器:
- FactorRegistry + MomentumFactor(因子层)
- TopNSelector(信号层,支持分散化选股)
- BacktestExecutor(执行层,完整回测)

暂时复用归档的HybridDataSource获取数据

执行方式:
  python run_rotation.py
  python run_rotation.py --config config/strategies/rotation.yaml
2026-05-11 23:39:51 +08:00

304 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
ETF轮动策略回测入口新框架
用法:
python run_rotation.py
python run_rotation.py --config config/strategies/rotation.yaml
"""
import sys
import time
import yaml
import argparse
from pathlib import Path
from datetime import datetime
# Prepend the directory that contains this script (treated as the project
# root) to sys.path so top-level packages such as `framework`, `strategies`
# and `archive` can be imported.
project_root = Path(__file__).parent
sys.path[0:0] = [str(project_root)]
def load_config(config_path: str) -> dict:
    """Load a YAML configuration file and return it as a dict.

    Args:
        config_path: Path to the YAML file (read as UTF-8).

    Returns:
        The parsed top-level mapping. An empty file yields ``{}`` rather than
        ``None``, so callers can use ``config.get(...)`` unconditionally.

    Raises:
        TypeError: If the top-level YAML value is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load only constructs plain Python objects (no arbitrary tags).
        loaded = yaml.safe_load(f)
    if loaded is None:
        # yaml.safe_load returns None for an empty document.
        return {}
    if not isinstance(loaded, dict):
        raise TypeError(
            f"config file {config_path!r} must contain a mapping at the top "
            f"level, got {type(loaded).__name__}"
        )
    return loaded
def _build_ssh_config(config: dict):
    """Translate config['ssh_tunnel'] into HybridDataSource ssh kwargs, or None when not enabled."""
    raw = config.get('ssh_tunnel') or {}
    if not raw.get('enabled'):
        return None
    return {
        'host': raw.get('host'),
        'port': raw.get('port', 22),
        'username': raw.get('username', 'root'),
        'key_path': raw.get('key_path', 'hk_ecs.pem'),
        'local_port': raw.get('local_port', 1080),
    }


def get_data_from_archive(code_list: list, config: dict) -> dict:
    """Fetch index data, plus each index's mapped ETF, via the archived HybridDataSource.

    Temporary bridge: reuses the legacy data source until data loading is
    migrated into the new framework.

    Args:
        code_list: Index codes to fetch.
        config: Strategy config. Reads ``ssh_tunnel``, ``use_cache``,
            ``start_date``, ``end_date`` and ``code_list``. Per-code entries
            in ``code_list`` may be dicts with an optional ``etf`` key.

    Returns:
        dict with ``index_data`` (code -> non-empty DataFrame), ``etf_data``
        (ETF code -> non-empty DataFrame) and ``valid_codes`` (codes whose
        index data came back non-empty).
    """
    print("\n" + "=" * 60)
    print("获取数据...")
    print("=" * 60)

    # Legacy data source from the archived package.
    from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource

    data_source = HybridDataSource(
        ssh_config=_build_ssh_config(config),
        use_cache=config.get('use_cache', True)
    )

    # `or` rather than a .get() default, so an explicit null/empty value in the
    # YAML still falls back instead of passing None to the data source.
    start_date = config.get('start_date') or '2019-01-01'
    end_date = config.get('end_date') or datetime.now().strftime('%Y-%m-%d')

    print(f" 回测区间: {start_date} ~ {end_date}")
    print(f" 候选标的: {len(code_list)}")

    index_data = {}
    etf_data = {}
    code_list_config = config.get('code_list') or {}

    all_data = data_source.fetch_batch(code_list, start_date, end_date)

    # Split into index data and the ETF mapped to each index.
    for code, df in all_data.items():
        if df is None or df.empty:
            continue
        index_data[code] = df

        code_cfg = code_list_config.get(code)
        # Only dict entries can carry an 'etf' key (bare entries have no options).
        etf_code = code_cfg.get('etf') if isinstance(code_cfg, dict) else None
        if not etf_code:
            continue
        etf_df = data_source.fetch(etf_code, start_date, end_date)
        if etf_df is not None and not etf_df.empty:
            etf_data[etf_code] = etf_df

    print(f" 指数数据: {len(index_data)}")
    print(f" ETF数据: {len(etf_data)}")
    return {
        'index_data': index_data,
        'etf_data': etf_data,
        'valid_codes': list(index_data.keys())
    }
def _compute_factor_values(factor, index_data: dict, codes: list, min_length: int) -> dict:
    """Run ``factor.compute`` on every code whose data has at least ``min_length`` rows."""
    return {
        code: factor.compute(index_data[code])
        for code in codes
        if len(index_data[code]) >= min_length
    }


def _build_returns_df(index_data: dict, codes: list):
    """Build daily close-to-close returns, one '日收益率_<code>' column per code.

    Columns are aligned by index label (pandas takes the union of the Series
    indexes). They are deliberately NOT relabelled with the first code's
    index: that either raised a length mismatch or silently shifted rows
    when histories differed.
    """
    import pandas as pd

    return pd.DataFrame(
        {f'日收益率_{code}': index_data[code]['close'].pct_change() for code in codes}
    )


def run_backtest(config: dict, data: dict) -> dict:
    """Run the backtest with the new framework: factor -> signal -> execution.

    Args:
        config: Strategy config. Reads ``n_days``, ``crash_filter``,
            ``select_num``, ``rebalance_days``, ``rebalance_threshold``,
            ``code_list``, ``trade_cost`` and ``initial_capital``.
        data: Output of ``get_data_from_archive``. Uses ``index_data`` and
            ``valid_codes``.

    Returns:
        ``{'signals', 'result', 'portfolio', 'total_return'}`` when the
        executor's portfolio exposes a non-None ``backtest_result``;
        otherwise a dict whose ``'result'`` is None.
    """
    # Local import: `pd` must not depend on the __main__ guard binding it
    # globally, otherwise importing this module and calling run_backtest
    # raises NameError.
    import pandas as pd

    print("\n" + "=" * 60)
    print("计算因子...")
    print("=" * 60)

    from framework import FactorRegistry, BacktestExecutor
    from strategies.shared.factors.momentum import MomentumFactor
    from strategies.shared.signals.selectors import TopNSelector

    # Reset the registry and register only the momentum factor.
    FactorRegistry.clear()
    FactorRegistry.register(MomentumFactor)

    n_days = config.get('n_days', 25)
    crash_filter = config.get('crash_filter', True)
    factor = FactorRegistry.get('momentum', n_days=n_days, crash_filter=crash_filter)
    print(" 因子类型: momentum (weighted)")
    print(f" 窗口天数: {n_days}")
    print(f" 崩盘过滤: {crash_filter}")

    index_data = data['index_data']
    valid_codes = data['valid_codes']
    factor_values = _compute_factor_values(factor, index_data, valid_codes, n_days)
    print(f" 计算完成: {len(factor_values)}")

    if not factor_values:
        # Nothing to rank: no data, or every series is shorter than n_days.
        print(" 无可用因子数据,跳过回测")
        return {'signals': None, 'result': None}

    print("\n" + "=" * 60)
    print("生成信号...")
    print("=" * 60)

    select_num = config.get('select_num', 3)
    rebalance_days = config.get('rebalance_days', 1)
    rebalance_threshold = config.get('rebalance_threshold', 0.0)

    # Group codes by their 'market' so the selector can diversify across groups.
    code_list_config = config.get('code_list') or {}
    group_mapping = {
        code: cfg.get('market', 'default')
        for code, cfg in code_list_config.items()
        if isinstance(cfg, dict)
    }

    selector = TopNSelector(
        select_num=select_num,
        group_mapping=group_mapping,
        min_score=0.0,
        rebalance_days=rebalance_days,
        rebalance_threshold=rebalance_threshold
    )
    print(f" 选股数量: {select_num}")
    print(f" 分组选股: {len(set(group_mapping.values()))} 个大类")
    print(f" 调仓周期: {rebalance_days}")
    print(f" 调仓阈值: {rebalance_threshold:.2%}")

    factor_df = pd.DataFrame(factor_values)
    signals = selector.generate(factor_df)
    print(f" 信号日期: {len(signals)}")

    print("\n" + "=" * 60)
    print("执行回测...")
    print("=" * 60)

    returns_df = _build_returns_df(index_data, valid_codes)

    trade_cost = config.get('trade_cost', 0.001)
    initial_capital = config.get('initial_capital', 100000)
    executor = BacktestExecutor(
        initial_capital=initial_capital,
        trade_cost=trade_cost,
        select_num=select_num
    )
    print(f" 初始资金: {initial_capital:,}")
    print(f" 交易成本: {trade_cost:.2%}")

    portfolio = executor.execute(signals, returns_df)
    result = getattr(portfolio, 'backtest_result', None)
    if result is None:
        return {'signals': signals, 'result': None}

    final_nav = result['策略净值'].iloc[-1]
    # NOTE(review): assumes the NAV series starts at 1.0 — confirm with BacktestExecutor.
    total_return = (final_nav - 1) * 100
    print(f"\n回测结果:")
    print(f" 最终净值: {final_nav:.4f}")
    print(f" 累计收益: {total_return:.2f}%")
    return {
        'signals': signals,
        'result': result,
        'portfolio': portfolio,
        'total_return': total_return
    }
def generate_report(backtest_result: dict, config: dict, data: dict, save_path: str):
    """Write the backtest NAV curve and the signal history to CSV files.

    Writes ``<save_path>_nav.csv`` (the '策略净值' column) and, when signals
    are present, ``<save_path>_signals.csv``. The parent directory of
    ``save_path`` is created if it does not exist. Nothing is written when
    ``backtest_result['result']`` is None.

    Args:
        backtest_result: Dict returned by ``run_backtest``.
        config: Strategy config (unused here; kept for the call signature).
        data: Loaded market data (unused here; kept for the call signature).
        save_path: Output path prefix, e.g. ``results/rotation``.
    """
    print("\n" + "=" * 60)
    print("生成报告...")
    print("=" * 60)

    result = backtest_result.get('result')
    if result is None:
        print(" 无回测结果,跳过报告生成")
        return

    # to_csv does not create missing directories (e.g. the default results/).
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # NAV curve
    result[['策略净值']].to_csv(f"{save_path}_nav.csv")

    # Signal history
    signals = backtest_result.get('signals')
    if signals is not None:
        signals.to_csv(f"{save_path}_signals.csv")

    total_return = backtest_result.get('total_return', 0)
    print(f" 报告保存至: {save_path}_*.csv")
    print(f" 累计收益: {total_return:.2f}%")
    # TODO: generate a full HTML report with visualization/report_generator
def main(argv=None):
    """CLI entry point: load config, fetch data, run the backtest and write the report.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``; passing a list lets the pipeline be driven
            programmatically.

    Returns:
        The dict returned by ``run_backtest``.
    """
    parser = argparse.ArgumentParser(description="ETF轮动策略回测新框架")
    parser.add_argument(
        "--config",
        type=str,
        default="config/strategies/rotation.yaml",
        help="配置文件路径",
    )
    parser.add_argument(
        "--save-path",
        type=str,
        default="results/rotation",
        help="报告保存路径前缀",
    )
    args = parser.parse_args(argv)

    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    start_time = time.perf_counter()

    print("=" * 60)
    print(" ETF轮动策略 回测系统(新框架)")
    print("=" * 60)

    # An empty YAML file loads as None; treat it as an empty config.
    config = load_config(args.config) or {}
    # Fill end_date once here; the same config dict is passed to every stage.
    if not config.get('end_date'):
        config['end_date'] = datetime.now().strftime('%Y-%m-%d')

    # `or {}` also covers an explicit `code_list:` with no entries (None).
    code_list = list(config.get('code_list') or {})
    print(f"\n配置文件: {args.config}")
    print(f"候选标的: {len(code_list)}")

    data = get_data_from_archive(code_list, config)
    backtest_result = run_backtest(config, data)
    generate_report(backtest_result, config, data, args.save_path)

    elapsed = time.perf_counter() - start_time
    print(f"\n总耗时: {elapsed:.1f}")
    return backtest_result
if __name__ == "__main__":
    # Bind `pd` as a module-level global when run as a script, so code in this
    # module that references `pd` without importing it locally still works.
    # NOTE(review): this only helps script runs; importing the module skips it.
    import pandas as pd
    main()