refactor: 将回测逻辑整合到策略类,简化执行入口

重构 RotationStrategy:
- 添加 from_yaml() 从配置创建实例
- 添加 get_data() 获取数据
- 添加 compute_factors() 计算因子
- 添加 generate_signals() 生成信号
- 添加 run_backtest() 完整回测流程

简化 run_rotation.py:
- 从 264 行简化为 9 行
- 只做策略调用入口

执行方式:
  python run_rotation.py --config config/strategies/rotation.yaml
  python run_rotation.py --save-path results/my_rotation

代码方式:
  strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml')
  result = strategy.run_backtest()
This commit is contained in:
2026-05-11 23:50:40 +08:00
parent 6231401b71
commit 893a75a27f
2 changed files with 218 additions and 307 deletions

View File

@@ -1,255 +1,22 @@
#!/usr/bin/env python3
"""
ETF轮动策略回测入口(新框架)
ETF轮动策略回测入口
用法:
python run_rotation.py
python run_rotation.py --config config/strategies/rotation.yaml
python run_rotation.py --save-path results/my_rotation
"""
import sys
import time
import yaml
import argparse
from pathlib import Path
import time
from datetime import datetime
# 添加项目根目录到路径
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
def load_config(config_path: str) -> dict:
"""加载配置文件"""
with open(config_path, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
def get_data_from_archive(code_list: list, config: dict) -> dict:
"""
从归档的HybridDataSource获取数据
暂时复用旧数据源,后续迁移到新框架
"""
print("\n" + "=" * 60)
print("获取数据...")
print("=" * 60)
# 使用归档的HybridDataSource
from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource
ssh_config = config.get('ssh_tunnel', {})
if ssh_config.get('enabled'):
ssh_config = {
'host': ssh_config.get('host'),
'port': ssh_config.get('port', 22),
'username': ssh_config.get('username', 'root'),
'key_path': ssh_config.get('key_path', 'hk_ecs.pem'),
'local_port': ssh_config.get('local_port', 1080)
}
else:
ssh_config = None
data_source = HybridDataSource(
ssh_config=ssh_config,
use_cache=config.get('use_cache', True)
)
start_date = config.get('start_date', '2019-01-01')
end_date = config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
# 获取指数数据
print(f" 回测区间: {start_date} ~ {end_date}")
print(f" 候选标的: {len(code_list)}")
index_data = {}
etf_data = {}
# 获取数据
all_data = data_source.fetch_batch(code_list, start_date, end_date)
# 分离指数和ETF数据
code_list_config = config.get('code_list', {})
for code, df in all_data.items():
if df is not None and not df.empty:
index_data[code] = df
# 获取对应的ETF数据
if code in code_list_config:
etf_code = code_list_config[code].get('etf')
if etf_code:
etf_df = data_source.fetch(etf_code, start_date, end_date)
if etf_df is not None and not etf_df.empty:
etf_data[etf_code] = etf_df
print(f" 指数数据: {len(index_data)}")
print(f" ETF数据: {len(etf_data)}")
return {
'index_data': index_data,
'etf_data': etf_data,
'valid_codes': list(index_data.keys())
}
def run_backtest(config: dict, data: dict) -> dict:
"""
使用新框架运行回测
因子 → 信号 → 执行
"""
print("\n" + "=" * 60)
print("计算因子...")
print("=" * 60)
from framework import FactorRegistry, FactorCombiner, BacktestExecutor
from strategies.shared.factors.momentum import MomentumFactor
from strategies.shared.signals.selectors import TopNSelector
# 清空注册表
FactorRegistry.clear()
FactorRegistry.register(MomentumFactor)
# 初始化因子
n_days = config.get('n_days', 25)
factor = FactorRegistry.get('momentum', n_days=n_days, crash_filter=True)
combiner = FactorCombiner([factor])
print(f" 因子类型: momentum (weighted)")
print(f" 窗口天数: {n_days}")
print(f" 崩盘过滤: True")
# 计算因子值
index_data = data['index_data']
valid_codes = data['valid_codes']
factor_values = {}
for code in valid_codes:
df = index_data[code]
if len(df) >= n_days:
values = factor.compute(df)
factor_values[code] = values
print(f" 计算完成: {len(factor_values)}")
# 生成信号
print("\n" + "=" * 60)
print("生成信号...")
print("=" * 60)
select_num = config.get('select_num', 3)
rebalance_days = config.get('rebalance_days', 1)
rebalance_threshold = config.get('rebalance_threshold', 0.0)
# 构建分组映射(分散化选股)
code_list_config = config.get('code_list', {})
group_mapping = {}
for code, cfg in code_list_config.items():
if isinstance(cfg, dict):
group_mapping[code] = cfg.get('market', 'default')
selector = TopNSelector(
select_num=select_num,
group_mapping=group_mapping,
min_score=0.0,
rebalance_days=rebalance_days,
rebalance_threshold=rebalance_threshold
)
print(f" 选股数量: {select_num}")
print(f" 分组选股: {len(set(group_mapping.values()))} 个大类")
print(f" 调仓周期: {rebalance_days}")
print(f" 调仓阈值: {rebalance_threshold:.2%}")
# 合并因子数据为DataFrame
factor_df = pd.DataFrame(factor_values)
# 生成信号
signals = selector.generate(factor_df)
print(f" 信号日期: {len(signals)}")
# 计算日收益率数据
print("\n" + "=" * 60)
print("执行回测...")
print("=" * 60)
# 准备收益率数据
returns_data = {}
for code in valid_codes:
df = index_data[code]
returns_data[f'日收益率_{code}'] = df['close'].pct_change()
returns_df = pd.DataFrame(returns_data)
returns_df.index = index_data[valid_codes[0]].index
trade_cost = config.get('trade_cost', 0.001)
executor = BacktestExecutor(
initial_capital=100000,
trade_cost=trade_cost,
select_num=select_num
)
print(f" 初始资金: 100,000")
print(f" 交易成本: {trade_cost:.2%}")
portfolio = executor.execute(signals, returns_df)
if hasattr(portfolio, 'backtest_result'):
result = portfolio.backtest_result
# 计算绩效
final_nav = result['策略净值'].iloc[-1]
total_return = (final_nav - 1) * 100
print(f"\n回测结果:")
print(f" 最终净值: {final_nav:.4f}")
print(f" 累计收益: {total_return:.2f}%")
return {
'signals': signals,
'result': result,
'portfolio': portfolio,
'total_return': total_return
}
return {'signals': signals, 'result': None}
def generate_report(backtest_result: dict, config: dict, data: dict, save_path: str):
"""生成报告"""
print("\n" + "=" * 60)
print("生成报告...")
print("=" * 60)
import pandas as pd
result = backtest_result.get('result')
if result is None:
print(" 无回测结果,跳过报告生成")
return
# 保存净值曲线
nav_df = result[['策略净值']].copy()
nav_df.to_csv(f"{save_path}_nav.csv")
# 保存信号记录
signals = backtest_result.get('signals')
if signals is not None:
signals.to_csv(f"{save_path}_signals.csv")
# 简单统计
total_return = backtest_result.get('total_return', 0)
print(f" 报告保存至: {save_path}_*.csv")
print(f" 累计收益: {total_return:.2f}%")
# TODO: 使用 visualization/report_generator 生成完整HTML报告
from strategies.rotation.strategy import RotationStrategy
def main():
parser = argparse.ArgumentParser(description="ETF轮动策略回测(新框架)")
parser = argparse.ArgumentParser(description="ETF轮动策略回测")
parser.add_argument(
"--config",
type=str,
@@ -266,39 +33,17 @@ def main():
start_time = time.time()
print("=" * 60)
print(" ETF轮动策略 回测系统(新框架)")
print("=" * 60)
# 加载配置
config = load_config(args.config)
# 设置结束日期
if not config.get('end_date'):
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
# 获取代码列表
code_list_config = config.get('code_list', {})
code_list = list(code_list_config.keys())
print(f"\n配置文件: {args.config}")
print(f"候选标的: {len(code_list)}")
# 获取数据
data = get_data_from_archive(code_list, config)
# 从配置创建策略
strategy = RotationStrategy.from_yaml(args.config)
# 运行回测
backtest_result = run_backtest(config, data)
# 生成报告
generate_report(backtest_result, config, data, args.save_path)
result = strategy.run_backtest(save_path=args.save_path)
elapsed = time.time() - start_time
print(f"\n总耗时: {elapsed:.1f}")
return backtest_result
return result
if __name__ == "__main__":
import pandas as pd # 确保pd在全局可用
main()

View File

@@ -1,80 +1,246 @@
"""
轮动策略定制实现
轮动策略完整实现
使用framework通用能力 + 定制组件
整合数据获取、因子计算、信号生成、回测执行
"""
import pandas as pd
import yaml
from datetime import datetime
from pathlib import Path
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.factors import FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.execution import BacktestExecutor
from framework.risk import CallbackHook, Position
from framework.strategy import StrategyBase
from framework.config import ConfigLoader
# 导入定制组件
from strategies.shared.factors.momentum import MomentumFactor
from strategies.shared.signals.selectors import TopNSelector
from strategies.shared.risk.controls import premium_filter_callback, holding_time_stoploss_callback
class RotationStrategy(StrategyBase):
"""
ETF轮动策略定制实现)
ETF轮动策略完整实现)
基于动量因子 + Top N选股 + 溢价过滤
基于动量因子 + Top N选股 + 分散化
使用方式:
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml')
result = strategy.run_backtest()
"""
name = "rotation"
select_num = 3
stoploss = -0.05
n_days = 25
rebalance_days = 1
rebalance_threshold = 0.0
trade_cost = 0.001
def init_factors(self) -> FactorCombiner:
"""初始化动量因子"""
# 清空注册表(避免重复注册)
def __init__(self, config: dict = None):
"""初始化策略"""
# 应用配置
if config:
self._apply_config(config)
self.config = config
else:
self.config = {}
# 初始化因子
FactorRegistry.clear()
# 注册定制因子
FactorRegistry.register(MomentumFactor)
self._factor = FactorRegistry.get(
'momentum',
n_days=self.n_days,
crash_filter=True
)
return FactorCombiner([
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
])
def init_signal_generator(self) -> SignalGenerator:
"""初始化Top N选股器定制"""
return TopNSelector(
# 构建分组映射(分散化选股)
self._group_mapping = self._build_group_mapping()
# 初始化信号生成器
self._selector = TopNSelector(
select_num=self.select_num,
group_mapping=self._group_mapping,
min_score=0.0,
group_by='market' # 定制:按大类分组
rebalance_days=self.rebalance_days,
rebalance_threshold=self.rebalance_threshold
)
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""入场前:溢价过滤(定制)"""
premium = kwargs.get('premium', 0)
@classmethod
def from_yaml(cls, config_path: str) -> 'RotationStrategy':
"""从YAML配置创建策略实例"""
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# 定制阈值10%
if premium > 0.10:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
# 设置结束日期
if not config.get('end_date'):
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
return True
return cls(config)
def dynamic_stoploss(self, position: Position) -> float:
"""动态止损:根据持仓时间调整(定制)"""
# 定制规则5天/10天阈值
if position.holding_days >= 10:
return -0.03
elif position.holding_days >= 5:
return -0.05
return -0.10
def _apply_config(self, config: dict) -> None:
"""应用配置参数"""
self.select_num = config.get('select_num', self.select_num)
self.n_days = config.get('n_days', self.n_days)
self.rebalance_days = config.get('rebalance_days', self.rebalance_days)
self.rebalance_threshold = config.get('rebalance_threshold', self.rebalance_threshold)
self.trade_cost = config.get('trade_cost', self.trade_cost)
self.start_date = config.get('start_date', '2019-01-01')
self.end_date = config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
def custom_exit(self, position: Position) -> bool:
"""自定义出场条件(定制"""
# 定制规则:亏损超过阈值强制出场
if position.profit_ratio < -0.10:
print(f"亏损超阈值,强制出场: {position.code}")
return True
return False
def _build_group_mapping(self) -> dict:
"""构建分组映射(分散化选股"""
group_mapping = {}
code_list_config = self.config.get('code_list', {})
for code, cfg in code_list_config.items():
if isinstance(cfg, dict):
group_mapping[code] = cfg.get('market', 'default')
return group_mapping
def get_data(self) -> dict:
"""获取数据(复用归档的数据源)"""
code_list_config = self.config.get('code_list', {})
code_list = list(code_list_config.keys())
if not code_list:
raise ValueError("配置中未找到 code_list")
# 使用归档的HybridDataSource
from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource
ssh_config = self.config.get('ssh_tunnel', {})
if ssh_config.get('enabled'):
ssh_config = {
'host': ssh_config.get('host'),
'port': ssh_config.get('port', 22),
'username': ssh_config.get('username', 'root'),
'key_path': ssh_config.get('key_path', 'hk_ecs.pem'),
'local_port': ssh_config.get('local_port', 1080)
}
else:
ssh_config = None
data_source = HybridDataSource(
ssh_config=ssh_config,
use_cache=self.config.get('use_cache', True)
)
# 获取数据
index_data = {}
all_data = data_source.fetch_batch(code_list, self.start_date, self.end_date)
for code, df in all_data.items():
if df is not None and not df.empty:
index_data[code] = df
return {
'index_data': index_data,
'valid_codes': list(index_data.keys())
}
def compute_factors(self, data: dict) -> pd.DataFrame:
"""计算因子值"""
index_data = data['index_data']
valid_codes = data['valid_codes']
factor_values = {}
for code in valid_codes:
df = index_data[code]
if len(df) >= self.n_days:
values = self._factor.compute(df)
factor_values[code] = values
return pd.DataFrame(factor_values)
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
"""生成信号"""
return self._selector.generate(factor_df)
def run_backtest(self, data: dict = None, save_path: str = None) -> dict:
"""
完整回测流程
Args:
data: 可选,如不提供则自动获取
save_path: 报告保存路径
Returns:
回测结果字典
"""
print("\n" + "=" * 60)
print(" ETF轮动策略 回测系统")
print("=" * 60)
# 1. 获取数据
if data is None:
data = self.get_data()
valid_codes = data['valid_codes']
index_data = data['index_data']
print(f"\n候选标的: {len(valid_codes)}")
print(f"回测区间: {self.start_date} ~ {self.end_date}")
# 2. 计算因子
print("\n计算因子...")
factor_df = self.compute_factors(data)
print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)}")
# 3. 生成信号
print("\n生成信号...")
signals = self.generate_signals(factor_df)
print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)}")
# 4. 执行回测
print("\n执行回测...")
returns_data = {}
first_code = valid_codes[0]
for code in valid_codes:
df = index_data[code]
returns_data[f'日收益率_{code}'] = df['close'].pct_change()
returns_df = pd.DataFrame(returns_data)
returns_df.index = index_data[first_code].index
executor = BacktestExecutor(
initial_capital=100000,
trade_cost=self.trade_cost,
select_num=self.select_num
)
portfolio = executor.execute(signals, returns_df)
# 5. 输出结果
if hasattr(portfolio, 'backtest_result'):
result = portfolio.backtest_result
final_nav = result['策略净值'].iloc[-1]
total_return = (final_nav - 1) * 100
print("\n回测结果:")
print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%")
# 保存报告
if save_path:
result[['策略净值']].to_csv(f"{save_path}_nav.csv")
signals.to_csv(f"{save_path}_signals.csv")
print(f" 报告保存: {save_path}_*.csv")
return {
'signals': signals,
'result': result,
'portfolio': portfolio,
'total_return': total_return
}
return {'signals': signals, 'result': None}
# 保留抽象方法实现
def init_factors(self) -> FactorCombiner:
return FactorCombiner([self._factor])
def init_signal_generator(self) -> SignalGenerator:
return self._selector