From b89e975aedb119f03902bc61385cb7cd564bd1b0 Mon Sep 17 00:00:00 2001 From: aszerW Date: Mon, 25 May 2026 01:33:23 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=88=A0=E9=99=A4=20SimpleRotation?= =?UTF-8?q?Strategy=20=E7=AE=80=E5=8C=96=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 删除 simple.py(已被 GlobalRotationStrategy 替代) - 删除 backtest_simple_rotation.py 回测脚本 - 删除 test_simple_rotation.py 测试脚本 - 更新 __init__.py 移除 SimpleRotationStrategy 导出 - 现在只保留 GlobalRotationStrategy 正式版 --- .../scripts/backtest_simple_rotation.py | 98 ------ framework_v2/strategies/rotation/__init__.py | 3 +- framework_v2/strategies/rotation/simple.py | 304 ------------------ framework_v2/tests/test_simple_rotation.py | 128 -------- 4 files changed, 1 insertion(+), 532 deletions(-) delete mode 100644 framework_v2/scripts/backtest_simple_rotation.py delete mode 100644 framework_v2/strategies/rotation/simple.py delete mode 100644 framework_v2/tests/test_simple_rotation.py diff --git a/framework_v2/scripts/backtest_simple_rotation.py b/framework_v2/scripts/backtest_simple_rotation.py deleted file mode 100644 index 3bc708a..0000000 --- a/framework_v2/scripts/backtest_simple_rotation.py +++ /dev/null @@ -1,98 +0,0 @@ -""" -简单轮动策略回测脚本 - -测试场景:指数信号 → ETF收益 -- 使用指数计算动量信号 -- 使用 ETF 计算收益 -""" - -import sys -from pathlib import Path - -# 添加项目根目录到 Python 路径 -project_root = Path(__file__).parent.parent.parent -sys.path.insert(0, str(project_root)) - -from framework_v2.config import load_config -from framework_v2.strategies.rotation.simple import SimpleRotationStrategy - - -def run_backtest(): - """运行回测""" - print("=" * 70) - print(" ETF轮动策略回测(V2 框架)") - print(" 场景:指数信号 → ETF收益,复现 V1 结果") - print("=" * 70) - - # 加载配置 - config_file = project_root / "framework_v2" / "strategies" / "rotation" / "config_simple.yaml" - print(f"\n配置文件: {config_file}") - - config = load_config(str(config_file)) - - # 打印配置摘要 - print("\n" + "=" * 70) - print(" 配置摘要") - print("=" * 70) - print(f"策略名称: {config.metadata.strategy}") - print(f"回测区间: {config.backtest.start_date} ~ {config.backtest.end_date or '至今'}") - print(f"因子类型: {config.factor.type.value}") - print(f"动量窗口: {config.factor.n_days} 天") - print(f"选股数量: {config.rotation.select_num}") - - # 打印资产池 - print(f"\n资产池 ({config.asset_pools.count()} 个标的):") - for code, asset in config.asset_pools.assets.items(): - print(f" {code}: {asset.name}") - print(f" 分组: {asset.group}") - print(f" 信号: {asset.signal_source}") - print(f" 交易: {asset.trade_source}") - print(f" 跨市场: {'是' if asset.is_cross_market else '否'}") - - # 创建策略 - print("\n" + "=" * 70) - print(" 运行回测...") - print("=" * 70) - - strategy = SimpleRotationStrategy(config) - result = strategy.run() - - # 打印结果 - print("\n" + "=" * 70) - print(" 回测结果") - print("=" * 70) - - metrics = result['metrics'] - print(f"总收益: {metrics['total_return']:.2%}") - print(f"年化收益: {metrics['annual_return']:.2%}") - print(f"最大回撤: {metrics['max_drawdown']:.2%}") - print(f"夏普比率: {metrics['sharpe_ratio']:.2f}") - print(f"交易天数: {metrics['n_days']}") - - # 打印净值曲线 - equity_curve = result['equity_curve'] - print(f"\n净值曲线:") - print(f" 起始净值: {equity_curve.iloc[0]:.4f}") - print(f" 结束净值: {equity_curve.iloc[-1]:.4f}") - print(f" 数据点数: {len(equity_curve)}") - - # 保存结果 - output_dir = project_root / "framework_v2" / "results" - output_dir.mkdir(exist_ok=True) - - # 保存净值曲线 - equity_curve.to_csv(output_dir / "simple_rotation_equity.csv") - print(f"\n净值曲线已保存: {output_dir / 'simple_rotation_equity.csv'}") - - # 保存持仓记录 - positions = result['positions'] - positions.to_csv(output_dir / "simple_rotation_positions.csv") - print(f"持仓记录已保存: {output_dir / 'simple_rotation_positions.csv'}") - - print("\n" + "=" * 70) - print(" 回测完成!") - print("=" * 70) - - -if __name__ == "__main__": - run_backtest() diff --git a/framework_v2/strategies/rotation/__init__.py b/framework_v2/strategies/rotation/__init__.py index cf1dab2..a0534c2 100644 --- a/framework_v2/strategies/rotation/__init__.py +++ b/framework_v2/strategies/rotation/__init__.py @@ -2,7 +2,6 @@ 轮动策略模块 """ -from framework_v2.strategies.rotation.simple import SimpleRotationStrategy from framework_v2.strategies.rotation.rotation import GlobalRotationStrategy -__all__ = ['SimpleRotationStrategy', 'GlobalRotationStrategy'] +__all__ = ['GlobalRotationStrategy'] diff --git a/framework_v2/strategies/rotation/simple.py b/framework_v2/strategies/rotation/simple.py deleted file mode 100644 index b74f1a4..0000000 --- a/framework_v2/strategies/rotation/simple.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -简单轮动策略 - -基于动量因子的 ETF 轮动策略 -- 计算各标的动量得分 -- 选择 Top-N 标的 -- 等权分配仓位 -""" - -import pandas as pd -import numpy as np -from typing import Dict - -from framework_v2.core.strategy import StrategyBase -from framework_v2.config.schemas import StrategyConfig -from framework_v2.shared.factors import MomentumFactor - - -class SimpleRotationStrategy(StrategyBase): - """ - 简单轮动策略 - - 策略逻辑: - 1. 计算各标的动量得分(加权线性回归) - 2. 选择得分最高的 Top-N 标的 - 3. 等权分配仓位 - - 示例: - from framework_v2.config import load_config - from framework_v2.strategies.rotation.simple import SimpleRotationStrategy - - config = load_config('rotation_simple.yaml') - strategy = SimpleRotationStrategy(config) - result = strategy.run() - """ - - def __init__(self, config: StrategyConfig): - """ - 初始化策略 - - Args: - config: 策略配置 - """ - super().__init__(config) - - # 初始化动量因子 - self.momentum = MomentumFactor( - n_days=config.factor.n_days, - weighted=(config.factor.type.value == 'weighted_momentum') - ) - - # 策略参数 - self.select_num = config.rotation.select_num if config.rotation else 3 - self.min_score = config.rotation.threshold.fixed_value if config.rotation else 0.0 - - def get_codes(self) -> list: - """ - 获取标的列表(信号标的 + 交易标的) - - 返回所有需要的数据标的: - - signal_source: 用于计算因子和信号 - - trade_source: 用于计算收益 - """ - codes = set() - - # 添加所有信号标的 - codes.update(self.config.asset_pools.get_signal_codes()) - - # 添加所有交易标的 - codes.update(self.config.asset_pools.get_trade_codes()) - - return list(codes) - - def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]: - """ - 计算动量因子(只使用信号标的的数据) - - Args: - data: 数据字典 {code: DataFrame}(包含 signal_source 和 trade_source) - - Returns: - 因子字典 {signal_source: Series} - """ - factors = {} - - # 只使用信号标的计算因子 - signal_codes = self.config.asset_pools.get_signal_codes() - - for code in signal_codes: - if code not in data: - print(f" 警告: {code} 数据不存在,跳过") - continue - - try: - df = data[code] - # 计算动量得分(使用信号标的的数据) - factor_values = self.momentum.compute(df) - factors[code] = factor_values - except Exception as e: - print(f" 警告: {code} 因子计算失败 - {e}") - continue - - return factors - - def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame: - """ - 生成轮动信号 - - 逻辑: - 1. 每个交易日选择动量得分最高的 Top-N 标的 - 2. 过滤得分低于阈值的标的 - - Args: - factors: 因子字典 {code: Series} - - Returns: - 信号 DataFrame(index=日期, columns=标的, values=1或0) - """ - if not factors: - return pd.DataFrame() - - # 对齐所有因子的日期 - factor_df = pd.DataFrame(factors) - - # 生成信号 - signals = pd.DataFrame(index=factor_df.index, columns=factor_df.columns, data=0) - - for date in factor_df.index: - # 获取当日因子值 - scores = factor_df.loc[date].dropna() - - if scores.empty: - continue - - # 过滤低分标的 - if self.min_score > 0: - scores = scores[scores >= self.min_score] - - # 选择 Top-N - if len(scores) > self.select_num: - top_codes = scores.nlargest(self.select_num).index - else: - top_codes = scores.index - - # 标记信号 - signals.loc[date, top_codes] = 1 - - return signals.astype(int) - - def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame: - """ - 仓位管理(等权分配) - - Args: - signals: 信号 DataFrame - - Returns: - 仓位 DataFrame(包含 'weight' 列) - """ - positions = signals.astype(float).copy() - - # 计算每个日期的权重 - for date in positions.index: - signal_row = positions.loc[date] - n_selected = signal_row.sum() - - if n_selected > 0: - # 等权分配 - positions.loc[date] = signal_row / n_selected - else: - # 空仓 - positions.loc[date] = 0 - - return positions - - def _get_trading_calendar(self) -> pd.DatetimeIndex: - """ - 获取 A 股交易日历 - - Returns: - A 股交易日历 DatetimeIndex - """ - from datetime import date - - # 获取回测区间 - start = self.config.backtest.start_date - end = self.config.backtest.end_date - if end is None: - end = date.today().strftime('%Y-%m-%d') - - # 创建临时数据获取器来获取交易日历 - if self._data_fetcher is None: - self._data_fetcher = self._create_data_fetcher() - - try: - # 调用 get_trading_calendar 方法 - calendar = self._data_fetcher.get_trading_calendar( - market='A', - start=start, - end=end - ) - print(f" [日历] A 股交易日: {len(calendar)} 天 ({calendar[0]} ~ {calendar[-1]})") - return calendar - except Exception as e: - print(f" [警告] 无法获取 A 股交易日历,使用所有日期: {e}") - # 降级方案:使用 pandas 生成工作日 - from pandas.tseries.offsets import BDay - start_dt = pd.Timestamp(start) - end_dt = pd.Timestamp(end) - return pd.date_range(start=start_dt, end=end_dt, freq='B') # 工作日 - - def _execute_backtest(self, positions: pd.DataFrame, data: Dict[str, pd.DataFrame]) -> Dict[str, any]: - """ - 执行回测 - - 核心逻辑: - 1. 使用 signal_source 计算信号(positions 的 columns 是 signal_source) - 2. 使用 trade_source 计算收益(通过 signal→trade 映射) - 3. T+1 执行:今天的信号明天生效 - 4. 过滤非交易日:只保留 A 股交易日 - - Args: - positions: 仓位 DataFrame(columns=signal_source) - data: 数据字典 {code: DataFrame}(包含 signal_source 和 trade_source) - - Returns: - 回测结果字典 - """ - # 获取信号→交易映射 - signal_to_trade = self.config.asset_pools.get_signal_to_trade_mapping() - - # 提取交易标的的收盘价 - close_prices = {} - for signal_code, trade_code in signal_to_trade.items(): - if trade_code in data: - # 使用交易标的的数据计算收益 - close_prices[signal_code] = data[trade_code]['close'] - else: - print(f" 警告: {trade_code} 数据不存在,跳过") - - close_df = pd.DataFrame(close_prices) - - # 计算收益率 - returns = close_df.pct_change() - - # 获取 A 股交易日历并过滤 - print("\n [过滤] 获取 A 股交易日历...") - trading_calendar = self._get_trading_calendar() - - # 过滤到 A 股交易日 - original_days = len(returns) - returns = returns[returns.index.isin(trading_calendar)] - positions = positions[positions.index.isin(trading_calendar)] - filtered_days = len(returns) - print(f" [过滤] 原始数据: {original_days} 天 -> A 股交易日: {filtered_days} 天 (过滤 {original_days - filtered_days} 天)") - - # 计算策略收益(仓位加权) - # 注意:T+1 执行,今天的信号明天生效 - positions_delayed = positions.shift(1).fillna(0) - strategy_returns = (positions_delayed * returns).sum(axis=1) - - # 计算净值曲线 - equity_curve = (1 + strategy_returns).cumprod() - - # 检查是否有数据 - if len(equity_curve) == 0: - return { - 'equity_curve': equity_curve, - 'strategy_returns': strategy_returns, - 'positions': positions, - 'metrics': { - 'total_return': 0, - 'annual_return': 0, - 'max_drawdown': 0, - 'sharpe_ratio': 0, - 'n_days': 0, - } - } - - # 计算绩效指标 - total_return = equity_curve.iloc[-1] / equity_curve.iloc[0] - 1 - n_days = len(strategy_returns) - annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0 - - # 最大回撤 - cumulative_max = equity_curve.cummax() - drawdown = (equity_curve - cumulative_max) / cumulative_max - max_drawdown = drawdown.min() - - # 夏普比率 - sharpe = strategy_returns.mean() / strategy_returns.std() * np.sqrt(252) if strategy_returns.std() > 0 else 0 - - return { - 'equity_curve': equity_curve, - 'strategy_returns': strategy_returns, - 'positions': positions, - 'metrics': { - 'total_return': total_return, - 'annual_return': annual_return, - 'max_drawdown': max_drawdown, - 'sharpe_ratio': sharpe, - 'n_days': n_days, - } - } diff --git a/framework_v2/tests/test_simple_rotation.py b/framework_v2/tests/test_simple_rotation.py deleted file mode 100644 index 17de139..0000000 --- a/framework_v2/tests/test_simple_rotation.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -测试简单轮动策略 - -验证完整流程: -1. 配置加载 -2. 策略初始化 -3. 数据获取 -4. 因子计算 -5. 信号生成 -6. 回测执行 -""" - -import sys -from pathlib import Path -import os - -# 添加项目根目录到路径 -project_root = Path(__file__).parent.parent -if str(project_root) not in sys.path: - sys.path.insert(0, str(project_root)) - -from framework_v2.config import load_config -from framework_v2.strategies.rotation.simple import SimpleRotationStrategy - - -def test_simple_rotation(): - """测试简单轮动策略完整流程""" - - print("\n" + "=" * 70) - print(" 简单轮动策略端到端测试") - print("=" * 70) - - # 设置环境变量 - os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz' - - # 1. 加载配置 - print("\n[1/6] 加载配置...") - config_path = Path(__file__).parent.parent / 'strategies' / 'rotation' / 'config_simple.yaml' - config = load_config(str(config_path)) - print(f" ✓ 配置加载成功") - print(f" 策略: {config.metadata.strategy}") - print(f" 标的: {list(config.asset_pools.equity.keys())}") - print(f" 回测: {config.backtest.start_date} ~ {config.backtest.end_date}") - - # 2. 初始化策略 - print("\n[2/6] 初始化策略...") - strategy = SimpleRotationStrategy(config) - print(f" ✓ 策略初始化成功") - print(f" 名称: {strategy.name}") - print(f" 动量窗口: {config.factor.n_days} 天") - print(f" 选股数量: {strategy.select_num}") - - # 3. 获取数据 - print("\n[3/6] 获取数据...") - codes = strategy.get_codes() - print(f" 标的列表: {codes}") - - data = strategy.get_data() - print(f" ✓ 获取 {len(data)} 个标的") - for code, df in data.items(): - print(f" {code}: {len(df)} 天 ({df.index[0].date()} ~ {df.index[-1].date()})") - - # 4. 计算因子 - print("\n[4/6] 计算因子...") - factors = strategy.compute_factors(data) - print(f" ✓ 计算 {len(factors)} 个因子") - for code, factor in factors.items(): - print(f" {code}: {len(factor)} 值, 范围 [{factor.min():.4f}, {factor.max():.4f}]") - - # 5. 生成信号 - print("\n[5/6] 生成信号...") - signals = strategy.generate_signals(factors) - n_signals = signals.sum().sum() - print(f" ✓ 生成 {signals.shape[0]} 个交易日信号") - print(f" 总信号数: {n_signals}") - print(f" 平均每日持仓: {signals.mean().mean():.2%}") - - # 6. 仓位管理 - print("\n[6/6] 仓位管理...") - positions = strategy.manage_positions(signals) - print(f" ✓ 仓位分配完成") - print(f" 权重和: {positions.sum(axis=1).mean():.2%}") - - # 7. 执行回测 - print("\n执行回测...") - result = strategy._execute_backtest(positions, data) - - # 打印结果 - print("\n" + "=" * 70) - print(" 回测结果") - print("=" * 70) - - metrics = result['metrics'] - print(f"\n 总收益率: {metrics['total_return']:.2%}") - print(f" 年化收益: {metrics['annual_return']:.2%}") - print(f" 最大回撤: {metrics['max_drawdown']:.2%}") - print(f" 夏普比率: {metrics['sharpe_ratio']:.2f}") - print(f" 交易天数: {metrics['n_days']}") - - # 验证结果 - print("\n" + "=" * 70) - print(" 验证") - print("=" * 70) - - assert metrics['total_return'] != 0, "总收益率不应为 0" - print(" ✓ 总收益率有效") - - assert len(result['equity_curve']) > 0, "净值曲线不应为空" - print(" ✓ 净值曲线有效") - - assert positions.sum(axis=1).max() <= 1.01, "权重和不应超过 100%" - print(" ✓ 仓位权重有效") - - print("\n" + "=" * 70) - print(" ✓ 所有测试通过") - print("=" * 70 + "\n") - - return result - - -if __name__ == "__main__": - try: - result = test_simple_rotation() - except Exception as e: - print(f"\n✗ 测试失败: {e}") - import traceback - traceback.print_exc() - sys.exit(1)