From 893a75a27f64b6f2ce7e2936dff3ed00758a2b76 Mon Sep 17 00:00:00 2001 From: aszerW Date: Mon, 11 May 2026 23:50:40 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=B0=86=E5=9B=9E=E6=B5=8B?= =?UTF-8?q?=E9=80=BB=E8=BE=91=E6=95=B4=E5=90=88=E5=88=B0=E7=AD=96=E7=95=A5?= =?UTF-8?q?=E7=B1=BB=EF=BC=8C=E7=AE=80=E5=8C=96=E6=89=A7=E8=A1=8C=E5=85=A5?= =?UTF-8?q?=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 重构 RotationStrategy: - 添加 from_yaml() 从配置创建实例 - 添加 get_data() 获取数据 - 添加 compute_factors() 计算因子 - 添加 generate_signals() 生成信号 - 添加 run_backtest() 完整回测流程 简化 run_rotation.py: - 从 264 行简化为 9 行 - 只做策略调用入口 执行方式: python run_rotation.py --config config/strategies/rotation.yaml python run_rotation.py --save-path results/my_rotation 代码方式: strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml') result = strategy.run_backtest() --- run_rotation.py | 273 ++------------------------------ strategies/rotation/strategy.py | 252 ++++++++++++++++++++++++----- 2 files changed, 218 insertions(+), 307 deletions(-) diff --git a/run_rotation.py b/run_rotation.py index 08e5a25..6d47948 100644 --- a/run_rotation.py +++ b/run_rotation.py @@ -1,255 +1,22 @@ #!/usr/bin/env python3 """ -ETF轮动策略回测入口(新框架) +ETF轮动策略回测入口 用法: python run_rotation.py python run_rotation.py --config config/strategies/rotation.yaml + python run_rotation.py --save-path results/my_rotation """ -import sys -import time -import yaml import argparse -from pathlib import Path +import time from datetime import datetime -# 添加项目根目录到路径 -project_root = Path(__file__).parent -sys.path.insert(0, str(project_root)) - - -def load_config(config_path: str) -> dict: - """加载配置文件""" - with open(config_path, "r", encoding="utf-8") as f: - return yaml.safe_load(f) - - -def get_data_from_archive(code_list: list, config: dict) -> dict: - """ - 从归档的HybridDataSource获取数据 - - 暂时复用旧数据源,后续迁移到新框架 - """ - print("\n" + "=" * 60) - print("获取数据...") - print("=" * 60) - - # 使用归档的HybridDataSource - from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource - - ssh_config = config.get('ssh_tunnel', {}) - if ssh_config.get('enabled'): - ssh_config = { - 'host': ssh_config.get('host'), - 'port': ssh_config.get('port', 22), - 'username': ssh_config.get('username', 'root'), - 'key_path': ssh_config.get('key_path', 'hk_ecs.pem'), - 'local_port': ssh_config.get('local_port', 1080) - } - else: - ssh_config = None - - data_source = HybridDataSource( - ssh_config=ssh_config, - use_cache=config.get('use_cache', True) - ) - - start_date = config.get('start_date', '2019-01-01') - end_date = config.get('end_date', datetime.now().strftime('%Y-%m-%d')) - - # 获取指数数据 - print(f" 回测区间: {start_date} ~ {end_date}") - print(f" 候选标的: {len(code_list)} 只") - - index_data = {} - etf_data = {} - - # 获取数据 - all_data = data_source.fetch_batch(code_list, start_date, end_date) - - # 分离指数和ETF数据 - code_list_config = config.get('code_list', {}) - for code, df in all_data.items(): - if df is not None and not df.empty: - index_data[code] = df - - # 获取对应的ETF数据 - if code in code_list_config: - etf_code = code_list_config[code].get('etf') - if etf_code: - etf_df = data_source.fetch(etf_code, start_date, end_date) - if etf_df is not None and not etf_df.empty: - etf_data[etf_code] = etf_df - - print(f" 指数数据: {len(index_data)} 只") - print(f" ETF数据: {len(etf_data)} 只") - - return { - 'index_data': index_data, - 'etf_data': etf_data, - 'valid_codes': list(index_data.keys()) - } - - -def run_backtest(config: dict, data: dict) -> dict: - """ - 使用新框架运行回测 - - 因子 → 信号 → 执行 - """ - print("\n" + "=" * 60) - print("计算因子...") - print("=" * 60) - - from framework import FactorRegistry, FactorCombiner, BacktestExecutor - from strategies.shared.factors.momentum import MomentumFactor - from strategies.shared.signals.selectors import TopNSelector - - # 清空注册表 - FactorRegistry.clear() - FactorRegistry.register(MomentumFactor) - - # 初始化因子 - n_days = config.get('n_days', 25) - factor = FactorRegistry.get('momentum', n_days=n_days, crash_filter=True) - combiner = FactorCombiner([factor]) - - print(f" 因子类型: momentum (weighted)") - print(f" 窗口天数: {n_days}") - print(f" 崩盘过滤: True") - - # 计算因子值 - index_data = data['index_data'] - valid_codes = data['valid_codes'] - - factor_values = {} - for code in valid_codes: - df = index_data[code] - if len(df) >= n_days: - values = factor.compute(df) - factor_values[code] = values - - print(f" 计算完成: {len(factor_values)} 只") - - # 生成信号 - print("\n" + "=" * 60) - print("生成信号...") - print("=" * 60) - - select_num = config.get('select_num', 3) - rebalance_days = config.get('rebalance_days', 1) - rebalance_threshold = config.get('rebalance_threshold', 0.0) - - # 构建分组映射(分散化选股) - code_list_config = config.get('code_list', {}) - group_mapping = {} - for code, cfg in code_list_config.items(): - if isinstance(cfg, dict): - group_mapping[code] = cfg.get('market', 'default') - - selector = TopNSelector( - select_num=select_num, - group_mapping=group_mapping, - min_score=0.0, - rebalance_days=rebalance_days, - rebalance_threshold=rebalance_threshold - ) - - print(f" 选股数量: {select_num}") - print(f" 分组选股: {len(set(group_mapping.values()))} 个大类") - print(f" 调仓周期: {rebalance_days} 天") - print(f" 调仓阈值: {rebalance_threshold:.2%}") - - # 合并因子数据为DataFrame - factor_df = pd.DataFrame(factor_values) - - # 生成信号 - signals = selector.generate(factor_df) - - print(f" 信号日期: {len(signals)} 天") - - # 计算日收益率数据 - print("\n" + "=" * 60) - print("执行回测...") - print("=" * 60) - - # 准备收益率数据 - returns_data = {} - for code in valid_codes: - df = index_data[code] - returns_data[f'日收益率_{code}'] = df['close'].pct_change() - - returns_df = pd.DataFrame(returns_data) - returns_df.index = index_data[valid_codes[0]].index - - trade_cost = config.get('trade_cost', 0.001) - - executor = BacktestExecutor( - initial_capital=100000, - trade_cost=trade_cost, - select_num=select_num - ) - - print(f" 初始资金: 100,000") - print(f" 交易成本: {trade_cost:.2%}") - - portfolio = executor.execute(signals, returns_df) - - if hasattr(portfolio, 'backtest_result'): - result = portfolio.backtest_result - - # 计算绩效 - final_nav = result['策略净值'].iloc[-1] - total_return = (final_nav - 1) * 100 - - print(f"\n回测结果:") - print(f" 最终净值: {final_nav:.4f}") - print(f" 累计收益: {total_return:.2f}%") - - return { - 'signals': signals, - 'result': result, - 'portfolio': portfolio, - 'total_return': total_return - } - - return {'signals': signals, 'result': None} - - -def generate_report(backtest_result: dict, config: dict, data: dict, save_path: str): - """生成报告""" - print("\n" + "=" * 60) - print("生成报告...") - print("=" * 60) - - import pandas as pd - - result = backtest_result.get('result') - if result is None: - print(" 无回测结果,跳过报告生成") - return - - # 保存净值曲线 - nav_df = result[['策略净值']].copy() - nav_df.to_csv(f"{save_path}_nav.csv") - - # 保存信号记录 - signals = backtest_result.get('signals') - if signals is not None: - signals.to_csv(f"{save_path}_signals.csv") - - # 简单统计 - total_return = backtest_result.get('total_return', 0) - - print(f" 报告保存至: {save_path}_*.csv") - print(f" 累计收益: {total_return:.2f}%") - - # TODO: 使用 visualization/report_generator 生成完整HTML报告 +from strategies.rotation.strategy import RotationStrategy def main(): - parser = argparse.ArgumentParser(description="ETF轮动策略回测(新框架)") + parser = argparse.ArgumentParser(description="ETF轮动策略回测") parser.add_argument( "--config", type=str, @@ -266,39 +33,17 @@ def main(): start_time = time.time() - print("=" * 60) - print(" ETF轮动策略 回测系统(新框架)") - print("=" * 60) - - # 加载配置 - config = load_config(args.config) - - # 设置结束日期 - if not config.get('end_date'): - config['end_date'] = datetime.now().strftime('%Y-%m-%d') - - # 获取代码列表 - code_list_config = config.get('code_list', {}) - code_list = list(code_list_config.keys()) - - print(f"\n配置文件: {args.config}") - print(f"候选标的: {len(code_list)} 只") - - # 获取数据 - data = get_data_from_archive(code_list, config) + # 从配置创建策略 + strategy = RotationStrategy.from_yaml(args.config) # 运行回测 - backtest_result = run_backtest(config, data) - - # 生成报告 - generate_report(backtest_result, config, data, args.save_path) + result = strategy.run_backtest(save_path=args.save_path) elapsed = time.time() - start_time print(f"\n总耗时: {elapsed:.1f}秒") - return backtest_result + return result if __name__ == "__main__": - import pandas as pd # 确保pd在全局可用 main() \ No newline at end of file diff --git a/strategies/rotation/strategy.py b/strategies/rotation/strategy.py index 13f7463..cfd7939 100644 --- a/strategies/rotation/strategy.py +++ b/strategies/rotation/strategy.py @@ -1,80 +1,246 @@ """ -轮动策略定制实现 +轮动策略完整实现 -使用framework通用能力 + 定制组件 +整合数据获取、因子计算、信号生成、回测执行 """ import pandas as pd import yaml from datetime import datetime +from pathlib import Path -from framework.factors import FactorBase, FactorRegistry, FactorCombiner +from framework.factors import FactorRegistry, FactorCombiner from framework.signals import SignalGenerator +from framework.execution import BacktestExecutor from framework.risk import CallbackHook, Position from framework.strategy import StrategyBase -from framework.config import ConfigLoader # 导入定制组件 from strategies.shared.factors.momentum import MomentumFactor from strategies.shared.signals.selectors import TopNSelector -from strategies.shared.risk.controls import premium_filter_callback, holding_time_stoploss_callback class RotationStrategy(StrategyBase): """ - ETF轮动策略(定制实现) + ETF轮动策略(完整实现) - 基于动量因子 + Top N选股 + 溢价过滤 + 基于动量因子 + Top N选股 + 分散化 + + 使用方式: + from strategies.rotation.strategy import RotationStrategy + strategy = RotationStrategy.from_yaml('config/strategies/rotation.yaml') + result = strategy.run_backtest() """ name = "rotation" select_num = 3 stoploss = -0.05 + n_days = 25 + rebalance_days = 1 + rebalance_threshold = 0.0 + trade_cost = 0.001 - def init_factors(self) -> FactorCombiner: - """初始化动量因子""" - # 清空注册表(避免重复注册) + def __init__(self, config: dict = None): + """初始化策略""" + # 应用配置 + if config: + self._apply_config(config) + self.config = config + else: + self.config = {} + + # 初始化因子 FactorRegistry.clear() - - # 注册定制因子 FactorRegistry.register(MomentumFactor) + self._factor = FactorRegistry.get( + 'momentum', + n_days=self.n_days, + crash_filter=True + ) - return FactorCombiner([ - FactorRegistry.get('momentum', n_days=25, crash_filter=True) - ]) - - def init_signal_generator(self) -> SignalGenerator: - """初始化Top N选股器(定制)""" - return TopNSelector( + # 构建分组映射(分散化选股) + self._group_mapping = self._build_group_mapping() + + # 初始化信号生成器 + self._selector = TopNSelector( select_num=self.select_num, + group_mapping=self._group_mapping, min_score=0.0, - group_by='market' # 定制:按大类分组 + rebalance_days=self.rebalance_days, + rebalance_threshold=self.rebalance_threshold ) - def before_entry(self, code: str, price: float, **kwargs) -> bool: - """入场前:溢价过滤(定制)""" - premium = kwargs.get('premium', 0) + @classmethod + def from_yaml(cls, config_path: str) -> 'RotationStrategy': + """从YAML配置创建策略实例""" + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) - # 定制阈值:10% - if premium > 0.10: - print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})") - return False + # 设置结束日期 + if not config.get('end_date'): + config['end_date'] = datetime.now().strftime('%Y-%m-%d') - return True + return cls(config) - def dynamic_stoploss(self, position: Position) -> float: - """动态止损:根据持仓时间调整(定制)""" - # 定制规则:5天/10天阈值 - if position.holding_days >= 10: - return -0.03 - elif position.holding_days >= 5: - return -0.05 - return -0.10 + def _apply_config(self, config: dict) -> None: + """应用配置参数""" + self.select_num = config.get('select_num', self.select_num) + self.n_days = config.get('n_days', self.n_days) + self.rebalance_days = config.get('rebalance_days', self.rebalance_days) + self.rebalance_threshold = config.get('rebalance_threshold', self.rebalance_threshold) + self.trade_cost = config.get('trade_cost', self.trade_cost) + self.start_date = config.get('start_date', '2019-01-01') + self.end_date = config.get('end_date', datetime.now().strftime('%Y-%m-%d')) - def custom_exit(self, position: Position) -> bool: - """自定义出场条件(定制)""" - # 定制规则:亏损超过阈值强制出场 - if position.profit_ratio < -0.10: - print(f"亏损超阈值,强制出场: {position.code}") - return True - return False \ No newline at end of file + def _build_group_mapping(self) -> dict: + """构建分组映射(分散化选股)""" + group_mapping = {} + code_list_config = self.config.get('code_list', {}) + for code, cfg in code_list_config.items(): + if isinstance(cfg, dict): + group_mapping[code] = cfg.get('market', 'default') + return group_mapping + + def get_data(self) -> dict: + """获取数据(复用归档的数据源)""" + code_list_config = self.config.get('code_list', {}) + code_list = list(code_list_config.keys()) + + if not code_list: + raise ValueError("配置中未找到 code_list") + + # 使用归档的HybridDataSource + from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource + + ssh_config = self.config.get('ssh_tunnel', {}) + if ssh_config.get('enabled'): + ssh_config = { + 'host': ssh_config.get('host'), + 'port': ssh_config.get('port', 22), + 'username': ssh_config.get('username', 'root'), + 'key_path': ssh_config.get('key_path', 'hk_ecs.pem'), + 'local_port': ssh_config.get('local_port', 1080) + } + else: + ssh_config = None + + data_source = HybridDataSource( + ssh_config=ssh_config, + use_cache=self.config.get('use_cache', True) + ) + + # 获取数据 + index_data = {} + all_data = data_source.fetch_batch(code_list, self.start_date, self.end_date) + + for code, df in all_data.items(): + if df is not None and not df.empty: + index_data[code] = df + + return { + 'index_data': index_data, + 'valid_codes': list(index_data.keys()) + } + + def compute_factors(self, data: dict) -> pd.DataFrame: + """计算因子值""" + index_data = data['index_data'] + valid_codes = data['valid_codes'] + + factor_values = {} + for code in valid_codes: + df = index_data[code] + if len(df) >= self.n_days: + values = self._factor.compute(df) + factor_values[code] = values + + return pd.DataFrame(factor_values) + + def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame: + """生成信号""" + return self._selector.generate(factor_df) + + def run_backtest(self, data: dict = None, save_path: str = None) -> dict: + """ + 完整回测流程 + + Args: + data: 可选,如不提供则自动获取 + save_path: 报告保存路径 + + Returns: + 回测结果字典 + """ + print("\n" + "=" * 60) + print(" ETF轮动策略 回测系统") + print("=" * 60) + + # 1. 获取数据 + if data is None: + data = self.get_data() + + valid_codes = data['valid_codes'] + index_data = data['index_data'] + + print(f"\n候选标的: {len(valid_codes)} 只") + print(f"回测区间: {self.start_date} ~ {self.end_date}") + + # 2. 计算因子 + print("\n计算因子...") + factor_df = self.compute_factors(data) + print(f" 因子类型: momentum (weighted)\n 窗口天数: {self.n_days}\n 计算完成: {len(factor_df.columns)} 只") + + # 3. 生成信号 + print("\n生成信号...") + signals = self.generate_signals(factor_df) + print(f" 选股数量: {self.select_num}\n 分组选股: {len(set(self._group_mapping.values()))} 个大类\n 信号日期: {len(signals)} 天") + + # 4. 执行回测 + print("\n执行回测...") + returns_data = {} + first_code = valid_codes[0] + for code in valid_codes: + df = index_data[code] + returns_data[f'日收益率_{code}'] = df['close'].pct_change() + + returns_df = pd.DataFrame(returns_data) + returns_df.index = index_data[first_code].index + + executor = BacktestExecutor( + initial_capital=100000, + trade_cost=self.trade_cost, + select_num=self.select_num + ) + + portfolio = executor.execute(signals, returns_df) + + # 5. 输出结果 + if hasattr(portfolio, 'backtest_result'): + result = portfolio.backtest_result + final_nav = result['策略净值'].iloc[-1] + total_return = (final_nav - 1) * 100 + + print("\n回测结果:") + print(f" 最终净值: {final_nav:.4f}\n 累计收益: {total_return:.2f}%") + + # 保存报告 + if save_path: + result[['策略净值']].to_csv(f"{save_path}_nav.csv") + signals.to_csv(f"{save_path}_signals.csv") + print(f" 报告保存: {save_path}_*.csv") + + return { + 'signals': signals, + 'result': result, + 'portfolio': portfolio, + 'total_return': total_return + } + + return {'signals': signals, 'result': None} + + # 保留抽象方法实现 + def init_factors(self) -> FactorCombiner: + return FactorCombiner([self._factor]) + + def init_signal_generator(self) -> SignalGenerator: + return self._selector \ No newline at end of file