refactor: 将回测逻辑整合到策略类,简化执行入口
重构 RotationStrategy:
- 添加 from_yaml() 从配置创建实例
- 添加 get_data() 获取数据
- 添加 compute_factors() 计算因子
- 添加 generate_signals() 生成信号
- 添加 run_backtest() 完整回测流程
简化 run_rotation.py:
- 从 264 行简化为 9 行
- 只做策略调用入口
执行方式:
python run_rotation.py --config config/strategies/rotation.yaml
python run_rotation.py --save-path results/my_rotation
代码方式:
strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml')
result = strategy.run_backtest()
This commit is contained in:
273
run_rotation.py
273
run_rotation.py
@@ -1,255 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
ETF轮动策略回测入口(新框架)
|
||||
ETF轮动策略回测入口
|
||||
|
||||
用法:
|
||||
python run_rotation.py
|
||||
python run_rotation.py --config config/strategies/rotation.yaml
|
||||
python run_rotation.py --save-path results/my_rotation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import yaml
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
|
||||
"""加载配置文件"""
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
def get_data_from_archive(code_list: list, config: dict) -> dict:
|
||||
"""
|
||||
从归档的HybridDataSource获取数据
|
||||
|
||||
暂时复用旧数据源,后续迁移到新框架
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("获取数据...")
|
||||
print("=" * 60)
|
||||
|
||||
# 使用归档的HybridDataSource
|
||||
from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource
|
||||
|
||||
ssh_config = config.get('ssh_tunnel', {})
|
||||
if ssh_config.get('enabled'):
|
||||
ssh_config = {
|
||||
'host': ssh_config.get('host'),
|
||||
'port': ssh_config.get('port', 22),
|
||||
'username': ssh_config.get('username', 'root'),
|
||||
'key_path': ssh_config.get('key_path', 'hk_ecs.pem'),
|
||||
'local_port': ssh_config.get('local_port', 1080)
|
||||
}
|
||||
else:
|
||||
ssh_config = None
|
||||
|
||||
data_source = HybridDataSource(
|
||||
ssh_config=ssh_config,
|
||||
use_cache=config.get('use_cache', True)
|
||||
)
|
||||
|
||||
start_date = config.get('start_date', '2019-01-01')
|
||||
end_date = config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
|
||||
|
||||
# 获取指数数据
|
||||
print(f" 回测区间: {start_date} ~ {end_date}")
|
||||
print(f" 候选标的: {len(code_list)} 只")
|
||||
|
||||
index_data = {}
|
||||
etf_data = {}
|
||||
|
||||
# 获取数据
|
||||
all_data = data_source.fetch_batch(code_list, start_date, end_date)
|
||||
|
||||
# 分离指数和ETF数据
|
||||
code_list_config = config.get('code_list', {})
|
||||
for code, df in all_data.items():
|
||||
if df is not None and not df.empty:
|
||||
index_data[code] = df
|
||||
|
||||
# 获取对应的ETF数据
|
||||
if code in code_list_config:
|
||||
etf_code = code_list_config[code].get('etf')
|
||||
if etf_code:
|
||||
etf_df = data_source.fetch(etf_code, start_date, end_date)
|
||||
if etf_df is not None and not etf_df.empty:
|
||||
etf_data[etf_code] = etf_df
|
||||
|
||||
print(f" 指数数据: {len(index_data)} 只")
|
||||
print(f" ETF数据: {len(etf_data)} 只")
|
||||
|
||||
return {
|
||||
'index_data': index_data,
|
||||
'etf_data': etf_data,
|
||||
'valid_codes': list(index_data.keys())
|
||||
}
|
||||
|
||||
|
||||
def run_backtest(config: dict, data: dict) -> dict:
|
||||
"""
|
||||
使用新框架运行回测
|
||||
|
||||
因子 → 信号 → 执行
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("计算因子...")
|
||||
print("=" * 60)
|
||||
|
||||
from framework import FactorRegistry, FactorCombiner, BacktestExecutor
|
||||
from strategies.shared.factors.momentum import MomentumFactor
|
||||
from strategies.shared.signals.selectors import TopNSelector
|
||||
|
||||
# 清空注册表
|
||||
FactorRegistry.clear()
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
|
||||
# 初始化因子
|
||||
n_days = config.get('n_days', 25)
|
||||
factor = FactorRegistry.get('momentum', n_days=n_days, crash_filter=True)
|
||||
combiner = FactorCombiner([factor])
|
||||
|
||||
print(f" 因子类型: momentum (weighted)")
|
||||
print(f" 窗口天数: {n_days}")
|
||||
print(f" 崩盘过滤: True")
|
||||
|
||||
# 计算因子值
|
||||
index_data = data['index_data']
|
||||
valid_codes = data['valid_codes']
|
||||
|
||||
factor_values = {}
|
||||
for code in valid_codes:
|
||||
df = index_data[code]
|
||||
if len(df) >= n_days:
|
||||
values = factor.compute(df)
|
||||
factor_values[code] = values
|
||||
|
||||
print(f" 计算完成: {len(factor_values)} 只")
|
||||
|
||||
# 生成信号
|
||||
print("\n" + "=" * 60)
|
||||
print("生成信号...")
|
||||
print("=" * 60)
|
||||
|
||||
select_num = config.get('select_num', 3)
|
||||
rebalance_days = config.get('rebalance_days', 1)
|
||||
rebalance_threshold = config.get('rebalance_threshold', 0.0)
|
||||
|
||||
# 构建分组映射(分散化选股)
|
||||
code_list_config = config.get('code_list', {})
|
||||
group_mapping = {}
|
||||
for code, cfg in code_list_config.items():
|
||||
if isinstance(cfg, dict):
|
||||
group_mapping[code] = cfg.get('market', 'default')
|
||||
|
||||
selector = TopNSelector(
|
||||
select_num=select_num,
|
||||
group_mapping=group_mapping,
|
||||
min_score=0.0,
|
||||
rebalance_days=rebalance_days,
|
||||
rebalance_threshold=rebalance_threshold
|
||||
)
|
||||
|
||||
print(f" 选股数量: {select_num}")
|
||||
print(f" 分组选股: {len(set(group_mapping.values()))} 个大类")
|
||||
print(f" 调仓周期: {rebalance_days} 天")
|
||||
print(f" 调仓阈值: {rebalance_threshold:.2%}")
|
||||
|
||||
# 合并因子数据为DataFrame
|
||||
factor_df = pd.DataFrame(factor_values)
|
||||
|
||||
# 生成信号
|
||||
signals = selector.generate(factor_df)
|
||||
|
||||
print(f" 信号日期: {len(signals)} 天")
|
||||
|
||||
# 计算日收益率数据
|
||||
print("\n" + "=" * 60)
|
||||
print("执行回测...")
|
||||
print("=" * 60)
|
||||
|
||||
# 准备收益率数据
|
||||
returns_data = {}
|
||||
for code in valid_codes:
|
||||
df = index_data[code]
|
||||
returns_data[f'日收益率_{code}'] = df['close'].pct_change()
|
||||
|
||||
returns_df = pd.DataFrame(returns_data)
|
||||
returns_df.index = index_data[valid_codes[0]].index
|
||||
|
||||
trade_cost = config.get('trade_cost', 0.001)
|
||||
|
||||
executor = BacktestExecutor(
|
||||
initial_capital=100000,
|
||||
trade_cost=trade_cost,
|
||||
select_num=select_num
|
||||
)
|
||||
|
||||
print(f" 初始资金: 100,000")
|
||||
print(f" 交易成本: {trade_cost:.2%}")
|
||||
|
||||
portfolio = executor.execute(signals, returns_df)
|
||||
|
||||
if hasattr(portfolio, 'backtest_result'):
|
||||
result = portfolio.backtest_result
|
||||
|
||||
# 计算绩效
|
||||
final_nav = result['策略净值'].iloc[-1]
|
||||
total_return = (final_nav - 1) * 100
|
||||
|
||||
print(f"\n回测结果:")
|
||||
print(f" 最终净值: {final_nav:.4f}")
|
||||
print(f" 累计收益: {total_return:.2f}%")
|
||||
|
||||
return {
|
||||
'signals': signals,
|
||||
'result': result,
|
||||
'portfolio': portfolio,
|
||||
'total_return': total_return
|
||||
}
|
||||
|
||||
return {'signals': signals, 'result': None}
|
||||
|
||||
|
||||
def generate_report(backtest_result: dict, config: dict, data: dict, save_path: str):
|
||||
"""生成报告"""
|
||||
print("\n" + "=" * 60)
|
||||
print("生成报告...")
|
||||
print("=" * 60)
|
||||
|
||||
import pandas as pd
|
||||
|
||||
result = backtest_result.get('result')
|
||||
if result is None:
|
||||
print(" 无回测结果,跳过报告生成")
|
||||
return
|
||||
|
||||
# 保存净值曲线
|
||||
nav_df = result[['策略净值']].copy()
|
||||
nav_df.to_csv(f"{save_path}_nav.csv")
|
||||
|
||||
# 保存信号记录
|
||||
signals = backtest_result.get('signals')
|
||||
if signals is not None:
|
||||
signals.to_csv(f"{save_path}_signals.csv")
|
||||
|
||||
# 简单统计
|
||||
total_return = backtest_result.get('total_return', 0)
|
||||
|
||||
print(f" 报告保存至: {save_path}_*.csv")
|
||||
print(f" 累计收益: {total_return:.2f}%")
|
||||
|
||||
# TODO: 使用 visualization/report_generator 生成完整HTML报告
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="ETF轮动策略回测(新框架)")
|
||||
parser = argparse.ArgumentParser(description="ETF轮动策略回测")
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
@@ -266,39 +33,17 @@ def main():
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
print("=" * 60)
|
||||
print(" ETF轮动策略 回测系统(新框架)")
|
||||
print("=" * 60)
|
||||
|
||||
# 加载配置
|
||||
config = load_config(args.config)
|
||||
|
||||
# 设置结束日期
|
||||
if not config.get('end_date'):
|
||||
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
# 获取代码列表
|
||||
code_list_config = config.get('code_list', {})
|
||||
code_list = list(code_list_config.keys())
|
||||
|
||||
print(f"\n配置文件: {args.config}")
|
||||
print(f"候选标的: {len(code_list)} 只")
|
||||
|
||||
# 获取数据
|
||||
data = get_data_from_archive(code_list, config)
|
||||
# 从配置创建策略
|
||||
strategy = RotationStrategy.from_yaml(args.config)
|
||||
|
||||
# 运行回测
|
||||
backtest_result = run_backtest(config, data)
|
||||
|
||||
# 生成报告
|
||||
generate_report(backtest_result, config, data, args.save_path)
|
||||
result = strategy.run_backtest(save_path=args.save_path)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
print(f"\n总耗时: {elapsed:.1f}秒")
|
||||
|
||||
return backtest_result
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pandas as pd # 确保pd在全局可用
|
||||
main()
|
||||
@@ -1,80 +1,246 @@
|
||||
"""
|
||||
轮动策略定制实现
|
||||
轮动策略完整实现
|
||||
|
||||
使用framework通用能力 + 定制组件
|
||||
整合数据获取、因子计算、信号生成、回测执行
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.factors import FactorRegistry, FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.execution import BacktestExecutor
|
||||
from framework.risk import CallbackHook, Position
|
||||
from framework.strategy import StrategyBase
|
||||
from framework.config import ConfigLoader
|
||||
|
||||
# 导入定制组件
|
||||
from strategies.shared.factors.momentum import MomentumFactor
|
||||
from strategies.shared.signals.selectors import TopNSelector
|
||||
from strategies.shared.risk.controls import premium_filter_callback, holding_time_stoploss_callback
|
||||
|
||||
|
||||
class RotationStrategy(StrategyBase):
|
||||
"""
|
||||
ETF轮动策略(定制实现)
|
||||
ETF轮动策略(完整实现)
|
||||
|
||||
基于动量因子 + Top N选股 + 溢价过滤
|
||||
基于动量因子 + Top N选股 + 分散化
|
||||
|
||||
使用方式:
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml')
|
||||
result = strategy.run_backtest()
|
||||
"""
|
||||
|
||||
name = "rotation"
|
||||
select_num = 3
|
||||
stoploss = -0.05
|
||||
n_days = 25
|
||||
rebalance_days = 1
|
||||
rebalance_threshold = 0.0
|
||||
trade_cost = 0.001
|
||||
|
||||
def init_factors(self) -> FactorCombiner:
|
||||
"""初始化动量因子"""
|
||||
# 清空注册表(避免重复注册)
|
||||
def __init__(self, config: dict = None):
|
||||
"""初始化策略"""
|
||||
# 应用配置
|
||||
if config:
|
||||
self._apply_config(config)
|
||||
self.config = config
|
||||
else:
|
||||
self.config = {}
|
||||
|
||||
# 初始化因子
|
||||
FactorRegistry.clear()
|
||||
|
||||
# 注册定制因子
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
self._factor = FactorRegistry.get(
|
||||
'momentum',
|
||||
n_days=self.n_days,
|
||||
crash_filter=True
|
||||
)
|
||||
|
||||
return FactorCombiner([
|
||||
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
|
||||
])
|
||||
|
||||
def init_signal_generator(self) -> SignalGenerator:
|
||||
"""初始化Top N选股器(定制)"""
|
||||
return TopNSelector(
|
||||
# 构建分组映射(分散化选股)
|
||||
self._group_mapping = self._build_group_mapping()
|
||||
|
||||
# 初始化信号生成器
|
||||
self._selector = TopNSelector(
|
||||
select_num=self.select_num,
|
||||
group_mapping=self._group_mapping,
|
||||
min_score=0.0,
|
||||
group_by='market' # 定制:按大类分组
|
||||
rebalance_days=self.rebalance_days,
|
||||
rebalance_threshold=self.rebalance_threshold
|
||||
)
|
||||
|
||||
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||||
"""入场前:溢价过滤(定制)"""
|
||||
premium = kwargs.get('premium', 0)
|
||||
@classmethod
|
||||
def from_yaml(cls, config_path: str) -> 'RotationStrategy':
|
||||
"""从YAML配置创建策略实例"""
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 定制阈值:10%
|
||||
if premium > 0.10:
|
||||
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||
return False
|
||||
# 设置结束日期
|
||||
if not config.get('end_date'):
|
||||
config['end_date'] = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
return True
|
||||
return cls(config)
|
||||
|
||||
def dynamic_stoploss(self, position: Position) -> float:
|
||||
"""动态止损:根据持仓时间调整(定制)"""
|
||||
# 定制规则:5天/10天阈值
|
||||
if position.holding_days >= 10:
|
||||
return -0.03
|
||||
elif position.holding_days >= 5:
|
||||
return -0.05
|
||||
return -0.10
|
||||
def _apply_config(self, config: dict) -> None:
|
||||
"""应用配置参数"""
|
||||
self.select_num = config.get('select_num', self.select_num)
|
||||
self.n_days = config.get('n_days', self.n_days)
|
||||
self.rebalance_days = config.get('rebalance_days', self.rebalance_days)
|
||||
self.rebalance_threshold = config.get('rebalance_threshold', self.rebalance_threshold)
|
||||
self.trade_cost = config.get('trade_cost', self.trade_cost)
|
||||
self.start_date = config.get('start_date', '2019-01-01')
|
||||
self.end_date = config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
|
||||
|
||||
def custom_exit(self, position: Position) -> bool:
|
||||
"""自定义出场条件(定制)"""
|
||||
# 定制规则:亏损超过阈值强制出场
|
||||
if position.profit_ratio < -0.10:
|
||||
print(f"亏损超阈值,强制出场: {position.code}")
|
||||
return True
|
||||
return False
|
||||
def _build_group_mapping(self) -> dict:
|
||||
"""构建分组映射(分散化选股)"""
|
||||
group_mapping = {}
|
||||
code_list_config = self.config.get('code_list', {})
|
||||
for code, cfg in code_list_config.items():
|
||||
if isinstance(cfg, dict):
|
||||
group_mapping[code] = cfg.get('market', 'default')
|
||||
return group_mapping
|
||||
|
||||
def get_data(self) -> dict:
|
||||
"""获取数据(复用归档的数据源)"""
|
||||
code_list_config = self.config.get('code_list', {})
|
||||
code_list = list(code_list_config.keys())
|
||||
|
||||
if not code_list:
|
||||
raise ValueError("配置中未找到 code_list")
|
||||
|
||||
# 使用归档的HybridDataSource
|
||||
from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource
|
||||
|
||||
ssh_config = self.config.get('ssh_tunnel', {})
|
||||
if ssh_config.get('enabled'):
|
||||
ssh_config = {
|
||||
'host': ssh_config.get('host'),
|
||||
'port': ssh_config.get('port', 22),
|
||||
'username': ssh_config.get('username', 'root'),
|
||||
'key_path': ssh_config.get('key_path', 'hk_ecs.pem'),
|
||||
'local_port': ssh_config.get('local_port', 1080)
|
||||
}
|
||||
else:
|
||||
ssh_config = None
|
||||
|
||||
data_source = HybridDataSource(
|
||||
ssh_config=ssh_config,
|
||||
use_cache=self.config.get('use_cache', True)
|
||||
)
|
||||
|
||||
# 获取数据
|
||||
index_data = {}
|
||||
all_data = data_source.fetch_batch(code_list, self.start_date, self.end_date)
|
||||
|
||||
for code, df in all_data.items():
|
||||
if df is not None and not df.empty:
|
||||
index_data[code] = df
|
||||
|
||||
return {
|
||||
'index_data': index_data,
|
||||
'valid_codes': list(index_data.keys())
|
||||
}
|
||||
|
||||
def compute_factors(self, data: dict) -> pd.DataFrame:
|
||||
"""计算因子值"""
|
||||
index_data = data['index_data']
|
||||
valid_codes = data['valid_codes']
|
||||
|
||||
factor_values = {}
|
||||
for code in valid_codes:
|
||||
df = index_data[code]
|
||||
if len(df) >= self.n_days:
|
||||
values = self._factor.compute(df)
|
||||
factor_values[code] = values
|
||||
|
||||
return pd.DataFrame(factor_values)
|
||||
|
||||
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""生成信号"""
|
||||
return self._selector.generate(factor_df)
|
||||
|
||||
def run_backtest(self, data: dict = None, save_path: str = None) -> dict:
|
||||
"""
|
||||
完整回测流程
|
||||
|
||||
Args:
|
||||
data: 可选,如不提供则自动获取
|
||||
save_path: 报告保存路径
|
||||
|
||||
Returns:
|
||||
回测结果字典
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" ETF轮动策略 回测系统")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取数据
|
||||
if data is None:
|
||||
data = self.get_data()
|
||||
|
||||
valid_codes = data['valid_codes']
|
||||
index_data = data['index_data']
|
||||
|
||||
print(f"\n候选标的: {len(valid_codes)} 只")
|
||||
print(f"回测区间: {self.start_date} ~ {self.end_date}")
|
||||
|
||||
# 2. 计算因子
|
||||
print("\n计算因子...")
|
||||
factor_df = self.compute_factors(data)
|
||||
print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)} 只")
|
||||
|
||||
# 3. 生成信号
|
||||
print("\n生成信号...")
|
||||
signals = self.generate_signals(factor_df)
|
||||
print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)} 天")
|
||||
|
||||
# 4. 执行回测
|
||||
print("\n执行回测...")
|
||||
returns_data = {}
|
||||
first_code = valid_codes[0]
|
||||
for code in valid_codes:
|
||||
df = index_data[code]
|
||||
returns_data[f'日收益率_{code}'] = df['close'].pct_change()
|
||||
|
||||
returns_df = pd.DataFrame(returns_data)
|
||||
returns_df.index = index_data[first_code].index
|
||||
|
||||
executor = BacktestExecutor(
|
||||
initial_capital=100000,
|
||||
trade_cost=self.trade_cost,
|
||||
select_num=self.select_num
|
||||
)
|
||||
|
||||
portfolio = executor.execute(signals, returns_df)
|
||||
|
||||
# 5. 输出结果
|
||||
if hasattr(portfolio, 'backtest_result'):
|
||||
result = portfolio.backtest_result
|
||||
final_nav = result['策略净值'].iloc[-1]
|
||||
total_return = (final_nav - 1) * 100
|
||||
|
||||
print("\n回测结果:")
|
||||
print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%")
|
||||
|
||||
# 保存报告
|
||||
if save_path:
|
||||
result[['策略净值']].to_csv(f"{save_path}_nav.csv")
|
||||
signals.to_csv(f"{save_path}_signals.csv")
|
||||
print(f" 报告保存: {save_path}_*.csv")
|
||||
|
||||
return {
|
||||
'signals': signals,
|
||||
'result': result,
|
||||
'portfolio': portfolio,
|
||||
'total_return': total_return
|
||||
}
|
||||
|
||||
return {'signals': signals, 'result': None}
|
||||
|
||||
# 保留抽象方法实现
|
||||
def init_factors(self) -> FactorCombiner:
|
||||
return FactorCombiner([self._factor])
|
||||
|
||||
def init_signal_generator(self) -> SignalGenerator:
|
||||
return self._selector
|
||||
Reference in New Issue
Block a user