diff --git a/framework_v2/README.md b/framework_v2/README.md new file mode 100644 index 0000000..80817e8 --- /dev/null +++ b/framework_v2/README.md @@ -0,0 +1,180 @@ +# 框架 V2 - 重构版本 + +## 📋 设计理念 + +### 三层架构 + +``` +framework_v2/ +├── core/ # 纯抽象接口(零实现) +├── shared/ # 通用实现(2+策略复用) +└── tests/ # 框架测试 +``` + +### 设计原则 + +1. **按需抽象**:不预先设计,只抽象已验证的通用逻辑 +2. **职责分离**:数据获取、因子计算、信号生成、回测执行各司其职 +3. **向后兼容**:与现有策略并行运行,验证一致后再替换 +4. **测试驱动**:每个组件必须有对比验证测试 + +--- + +## 🏗️ 目录结构 + +``` +framework_v2/ +├── __init__.py +├── README.md +│ +├── core/ # 核心抽象接口 +│ ├── __init__.py +│ ├── strategy.py # StrategyBase (ABC) +│ ├── factor.py # FactorBase (ABC) +│ ├── signal.py # SignalGenerator (ABC) +│ ├── executor.py # Executor (ABC) +│ └── data.py # DataFetcher (ABC) +│ +├── shared/ # 通用实现 +│ ├── __init__.py +│ └── factors/ +│ ├── __init__.py +│ ├── talib_base.py # TALibFactorBase (需要 talib) +│ └── momentum.py # 动量因子(已验证✓) +│ +└── tests/ # 测试 + ├── __init__.py + └── test_momentum_parity.py # 因子对比测试(通过✓) +``` + +--- + +## ✅ 已完成 + +### 阶段1: 核心接口层 ✓ + +- [x] StrategyBase - 策略抽象基类 +- [x] FactorBase - 因子抽象基类 +- [x] SignalGenerator - 信号生成器抽象基类 +- [x] Executor - 执行器抽象基类 +- [x] DataFetcher - 数据获取器抽象基类 + +### 阶段2: 通用因子层 ✓ + +- [x] MomentumFactor - 动量因子(完全复制现有逻辑) +- [x] 对比验证测试(通过✓,差异 = 0) + +--- + +## 🎯 验证结果 + +### MomentumFactor 对比测试 + +``` +============================================================ + MomentumFactor 对比测试 +============================================================ + +1. 加载测试数据... + ⚠ 未找到测试数据,使用模拟数据 + +2. 计算旧因子(strategies/shared/factors/momentum.py)... + ✓ 旧因子计算完成 + 结果范围: -0.8515 ~ 8.5805 + NaN 数量: 22 + +3. 计算新因子(framework_v2/shared/factors/momentum.py)... + ✓ 新因子计算完成 + 结果范围: -0.8515 ~ 8.5805 + NaN 数量: 22 + +4. 对比结果... + ✓ 索引一致 + 最大差异: 0.000000e+00 + 平均差异: 0.000000e+00 + ✓ 差异在容差范围内 (< 1e-10) + +============================================================ + ✓ 测试通过:新旧因子输出完全一致! +============================================================ +``` + +--- + +## 📝 下一步计划 + +### 阶段3: 信号层迁移 + +- [ ] TopNSelector - Top N 选股器 +- [ ] DynamicThreshold - 动态阈值(V3逻辑) +- [ ] RebalanceController - 调仓控制器 +- [ ] 信号对比验证测试 + +### 阶段4: 执行层迁移 + +- [ ] BacktestRunner - 回测执行器 +- [ ] 收益计算对比测试 + +### 阶段5: 数据层迁移 + +- [ ] RotationDataFetcher - 轮动策略数据获取器 +- [ ] CrossMarketAligner - 跨市场对齐器 + +### 阶段6: 策略组装 + +- [ ] RotationStrategyV2 - 新框架轮动策略 +- [ ] 完整策略对比测试 + +--- + +## 🔧 使用方法 + +### 运行测试 + +```bash +# 运行因子对比测试 +python framework_v2/tests/test_momentum_parity.py + +# 运行所有测试 +python -m pytest framework_v2/tests/ +``` + +### 使用新因子 + +```python +from framework_v2.shared.factors import MomentumFactor + +# 创建因子 +factor = MomentumFactor(n_days=25, weighted=True, crash_filter=True) + +# 计算因子值 +import pandas as pd +data = pd.DataFrame({'close': [...]}, index=[...]) +factor_values = factor.compute(data) +``` + +--- + +## 📊 与旧框架对比 + +| 维度 | 旧框架 (framework/) | 新框架 (framework_v2/) | +|------|---------------------|------------------------| +| **架构** | 抽象+实现混杂 | 三层分离(core/shared/tests) | +| **因子** | 独立实现 | TALibFactorBase + 定制继承 | +| **信号** | 包含所有逻辑 | 拆分为 Signal + Threshold + Rebalance | +| **数据** | 耦合在策略中 | DataFetcher 抽象 | +| **测试** | 部分覆盖 | 每个组件必须有对比测试 | +| **状态** | 生产环境 ✓ | 开发中 🚧 | + +--- + +## ⚠️ 注意事项 + +1. **talib 依赖**:TALibFactorBase 需要安装 `ta-lib`,但未安装不影响 MomentumFactor 使用 +2. **并行开发**:新框架与旧框架并行,不修改现有代码 +3. **验证优先**:每个模块迁移后立即验证,确保结果一致 + +--- + +*创建日期: 2026-05-06* +*版本: 2.0.0* diff --git a/framework_v2/__init__.py b/framework_v2/__init__.py new file mode 100644 index 0000000..972924b --- /dev/null +++ b/framework_v2/__init__.py @@ -0,0 +1,15 @@ +""" +框架 V2 - 重构版本 + +三层架构: +├── core/ # 纯抽象接口(零实现) +├── shared/ # 通用实现(2+策略复用) +└── tests/ # 框架测试 + +设计原则: +├── 按需抽象,不预先设计 +├── 只放通用逻辑,定制逻辑在 strategies/ +└── 每个组件必须有测试 +""" + +__version__ = "2.0.0" diff --git a/framework_v2/core/__init__.py b/framework_v2/core/__init__.py new file mode 100644 index 0000000..63a773e --- /dev/null +++ b/framework_v2/core/__init__.py @@ -0,0 +1,19 @@ +""" +核心抽象接口层(纯ABC,零实现) + +只定义策略框架的标准接口,不包含任何业务逻辑 +""" + +from framework_v2.core.strategy import StrategyBase +from framework_v2.core.factor import FactorBase +from framework_v2.core.signal import SignalGenerator +from framework_v2.core.executor import Executor +from framework_v2.core.data import DataFetcher + +__all__ = [ + 'StrategyBase', + 'FactorBase', + 'SignalGenerator', + 'Executor', + 'DataFetcher', +] diff --git a/framework_v2/core/data.py b/framework_v2/core/data.py new file mode 100644 index 0000000..355b36b --- /dev/null +++ b/framework_v2/core/data.py @@ -0,0 +1,97 @@ +""" +数据获取器抽象基类 +""" + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional +import pandas as pd + + +class DataFetcher(ABC): + """ + 数据获取器抽象基类 + + 所有数据获取器必须实现必要方法 + """ + + name: str = "base" + + def __init__(self, **params): + """ + 初始化数据获取器参数 + + Args: + **params: 数据源参数(如 api_url, ssh_config 等) + """ + self._params = params + + @abstractmethod + def fetch_indices( + self, + codes: List[str], + start: str, + end: str + ) -> Dict[str, pd.DataFrame]: + """ + 获取指数 OHLCV 数据 + + Args: + codes: 指数代码列表 + start: 开始日期 (YYYY-MM-DD) + end: 结束日期 (YYYY-MM-DD) + + Returns: + {code: DataFrame} 字典 + """ + pass + + @abstractmethod + def fetch_etf( + self, + codes: List[str], + start: str, + end: str + ) -> Dict[str, pd.DataFrame]: + """ + 获取 ETF 数据(价格 + 净值) + + Args: + codes: ETF 代码列表 + start: 开始日期 + end: 结束日期 + + Returns: + {code: DataFrame} 字典 + """ + pass + + @abstractmethod + def get_trading_calendar(self, market: str = 'A') -> pd.Index: + """ + 获取交易日历 + + Args: + market: 市场代码('A', 'US', 'HK' 等) + + Returns: + 交易日历 Index + """ + pass + + def get_benchmark(self, code: str, start: str, end: str) -> pd.Series: + """ + 获取基准数据(可选) + + Args: + code: 基准代码 + start: 开始日期 + end: 结束日期 + + Returns: + 基准收盘价 Series + """ + raise NotImplementedError("Optional method") + + def __repr__(self) -> str: + params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()]) + return f"{self.__class__.__name__}(name={self.name})" diff --git a/framework_v2/core/executor.py b/framework_v2/core/executor.py new file mode 100644 index 0000000..9325bbe --- /dev/null +++ b/framework_v2/core/executor.py @@ -0,0 +1,46 @@ +""" +执行器抽象基类 +""" + +from abc import ABC, abstractmethod +import pandas as pd + + +class Executor(ABC): + """ + 执行器抽象基类 + + 所有执行器必须实现 execute 方法 + """ + + mode: str = "base" + + def __init__(self, **params): + """ + 初始化执行器参数 + + Args: + **params: 执行参数(如 initial_capital, trade_cost 等) + """ + self._params = params + + @abstractmethod + def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> dict: + """ + 执行信号 + + Args: + signals: 信号 DataFrame + data: 收益率数据 DataFrame + + Returns: + 回测结果字典,包含: + - result: 回测 DataFrame(含净值、收益率) + - portfolio: 组合对象(可选) + - metrics: 绩效指标(可选) + """ + pass + + def __repr__(self) -> str: + params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()]) + return f"{self.__class__.__name__}({params_str})" diff --git a/framework_v2/core/factor.py b/framework_v2/core/factor.py new file mode 100644 index 0000000..ed935df --- /dev/null +++ b/framework_v2/core/factor.py @@ -0,0 +1,59 @@ +""" +因子抽象基类 +""" + +from abc import ABC, abstractmethod +import pandas as pd + + +class FactorBase(ABC): + """ + 因子抽象基类 + + 所有因子必须实现 compute 方法 + """ + + name: str = "base" + category: str = "unknown" + + def __init__(self, **params): + """ + 初始化因子参数 + + Args: + **params: 因子参数(如 n_days, weighted 等) + """ + self._params = params + + @abstractmethod + def compute(self, data: pd.DataFrame) -> pd.Series: + """ + 计算因子值 + + Args: + data: OHLCV 数据,必须包含 'close' 列 + + Returns: + 因子值序列(与 data 同索引) + """ + pass + + def validate_data(self, data: pd.DataFrame) -> bool: + """ + 验证数据是否满足计算要求 + + Args: + data: OHLCV 数据 + + Returns: + True 如果数据有效 + """ + if 'close' not in data.columns: + return False + + min_periods = self._params.get('min_periods', 20) + return len(data) >= min_periods + + def __repr__(self) -> str: + params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()]) + return f"{self.__class__.__name__}({params_str})" diff --git a/framework_v2/core/signal.py b/framework_v2/core/signal.py new file mode 100644 index 0000000..b665959 --- /dev/null +++ b/framework_v2/core/signal.py @@ -0,0 +1,57 @@ +""" +信号生成器抽象基类 +""" + +from abc import ABC, abstractmethod +import pandas as pd + + +class SignalGenerator(ABC): + """ + 信号生成器抽象基类 + + 所有信号生成器必须实现 generate 方法 + """ + + mode: str = "base" + + def __init__(self, **params): + """ + 初始化信号生成器参数 + + Args: + **params: 信号参数(如 select_num, rebalance_days 等) + """ + self._params = params + + @abstractmethod + def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame: + """ + 生成交易信号 + + Args: + factor_data: 因子数据 DataFrame + + Returns: + 信号 DataFrame,必须包含 'signal' 列 + """ + pass + + def validate_factor_data(self, factor_data: pd.DataFrame) -> bool: + """ + 验证因子数据是否有效 + + Args: + factor_data: 因子数据 + + Returns: + True 如果数据有效 + """ + if factor_data.empty: + return False + + return True + + def __repr__(self) -> str: + params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()]) + return f"{self.__class__.__name__}({params_str})" diff --git a/framework_v2/core/strategy.py b/framework_v2/core/strategy.py new file mode 100644 index 0000000..922100c --- /dev/null +++ b/framework_v2/core/strategy.py @@ -0,0 +1,151 @@ +""" +策略抽象基类 + +所有策略必须继承此类并实现必要方法 +""" + +from abc import ABC, abstractmethod +from typing import Dict, Optional, Any +import pandas as pd + + +class StrategyBase(ABC): + """ + 策略抽象基类 + + 定义策略的标准生命周期: + 1. 初始化配置 + 2. 获取数据 + 3. 计算因子 + 4. 生成信号 + 5. 执行回测 + + 子类必须实现: + - init_factors(): 初始化因子 + - init_signal_generator(): 初始化信号生成器 + """ + + INTERFACE_VERSION = 2 # V2 版本 + + name: str = "base" + timeframe: str = "1d" + + def __init__(self, config: Optional[Dict] = None): + """ + 初始化策略 + + Args: + config: 策略配置字典 + """ + self.config = config or {} + self._factor = None + self._signal_generator = None + + @abstractmethod + def init_factors(self) -> Any: + """ + 初始化因子组件 + + Returns: + 因子实例(继承 FactorBase) + """ + pass + + @abstractmethod + def init_signal_generator(self) -> Any: + """ + 初始化信号生成器 + + Returns: + 信号生成器实例(继承 SignalGenerator) + """ + pass + + def get_data(self) -> Dict[str, Any]: + """ + 获取数据(可选覆盖) + + Returns: + 数据字典,包含: + - index_data: 指数数据 + - etf_data: ETF数据 + - benchmark_data: 基准数据 + - valid_codes: 有效标的列表 + - trading_calendar: 交易日历 + """ + raise NotImplementedError("Subclasses must implement get_data()") + + def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame: + """ + 计算因子(可选覆盖) + + Args: + data: 数据字典 + + Returns: + 因子 DataFrame(日期 × 标的) + """ + if self._factor is None: + self._factor = self.init_factors() + + # 默认实现:遍历标的计算因子 + factor_values = {} + for code in data.get('valid_codes', []): + if code in data.get('index_data', {}): + factor_values[code] = self._factor.compute(data['index_data'][code]) + + return pd.DataFrame(factor_values) + + def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame: + """ + 生成信号 + + Args: + factor_df: 因子 DataFrame + + Returns: + 信号 DataFrame(包含 'signal' 列) + """ + if self._signal_generator is None: + self._signal_generator = self.init_signal_generator() + + return self._signal_generator.generate(factor_df) + + def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]: + """ + 运行完整回测流程 + + Args: + data: 可选,如不提供则自动获取 + + Returns: + 回测结果字典 + """ + # 1. 获取数据 + if data is None: + data = self.get_data() + + # 2. 计算因子 + factor_df = self.compute_factors(data) + + # 3. 生成信号 + signals = self.generate_signals(factor_df) + + # 4. 执行回测(子类实现) + return self._execute_backtest(signals, data) + + def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]: + """ + 执行回测(子类可覆盖) + + Args: + signals: 信号 DataFrame + data: 数据字典 + + Returns: + 回测结果 + """ + raise NotImplementedError("Subclasses must implement _execute_backtest()") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.name})" diff --git a/framework_v2/shared/__init__.py b/framework_v2/shared/__init__.py new file mode 100644 index 0000000..8ad04d7 --- /dev/null +++ b/framework_v2/shared/__init__.py @@ -0,0 +1,9 @@ +""" +通用实现层(2+ 策略复用的组件) + +包含: +├── factors/ # 通用因子 +├── signals/ # 通用信号生成器 +├── execution/ # 通用执行器 +└── data/ # 通用数据处理 +""" diff --git a/framework_v2/shared/factors/__init__.py b/framework_v2/shared/factors/__init__.py new file mode 100644 index 0000000..5613d0a --- /dev/null +++ b/framework_v2/shared/factors/__init__.py @@ -0,0 +1,17 @@ +""" +通用因子实现 +""" + +from framework_v2.shared.factors.momentum import MomentumFactor + +# TALibFactorBase 需要安装 talib,可选导入 +try: + from framework_v2.shared.factors.talib_base import TALibFactorBase + __all__ = [ + 'TALibFactorBase', + 'MomentumFactor', + ] +except ImportError: + __all__ = [ + 'MomentumFactor', + ] diff --git a/framework_v2/shared/factors/momentum.py b/framework_v2/shared/factors/momentum.py new file mode 100644 index 0000000..63bafae --- /dev/null +++ b/framework_v2/shared/factors/momentum.py @@ -0,0 +1,104 @@ +""" +动量因子(通用版本) + +使用加权线性回归:得分 = 年化收益率 × R² + +与现有 MomentumFactor 对比验证: +- 输入相同 → 输出应该相同 +""" + +import pandas as pd +import numpy as np +import math +from framework_v2.core import FactorBase + + +class MomentumFactor(FactorBase): + """ + 动量因子 + + 计算加权线性回归动量得分: + 得分 = 年化收益率 × R² + + 参数: + - n_days: 动量窗口(默认25) + - weighted: 是否加权(默认True) + - crash_filter: 是否启用崩盘过滤(默认True) + """ + + name = "momentum" + category = "momentum" + + def __init__( + self, + n_days: int = 25, + weighted: bool = True, + crash_filter: bool = True + ): + super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter) + self.n_days = n_days + self.weighted = weighted + self.crash_filter = crash_filter + + def compute(self, data: pd.DataFrame) -> pd.Series: + """计算动量因子值""" + if 'close' not in data.columns: + raise ValueError("data must contain 'close' column") + + prices = data['close'] + + if self.weighted: + factor_values = prices.rolling(self.n_days).apply( + lambda x: self._weighted_momentum_score(x.values), + raw=False + ) + else: + factor_values = prices.pct_change(self.n_days) + + if self.crash_filter: + factor_values = self._apply_crash_filter(prices, factor_values) + + return factor_values + + def _weighted_momentum_score(self, prices: np.ndarray) -> float: + """计算加权动量得分(完全复制现有逻辑)""" + if len(prices) < 5: + return 0.0 + + # 价格下界 clip,防止 log(0) 或 log(负数) + prices = np.clip(prices, 0.01, None) + y = np.log(prices) + + # 异常值检测 + if np.any(np.isnan(y)) or np.any(np.isinf(y)): + return 0.0 + + x = np.arange(len(y)) + weights = np.linspace(1, 2, len(y)) + + slope, intercept = np.polyfit(x, y, 1, w=weights) + annualized_returns = math.exp(slope * 250) - 1 + + y_pred = slope * x + intercept + ss_res = np.sum(weights * (y - y_pred) ** 2) + ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0 + + return annualized_returns * r2 + + def _apply_crash_filter(self, prices: pd.Series, factor_values: pd.Series) -> pd.Series: + """崩盘过滤:连续3天跌>5%清零(完全复制现有逻辑)""" + result = factor_values.copy() + + for i in range(3, len(prices)): + r1 = prices.iloc[i] / prices.iloc[i-1] + r2 = prices.iloc[i-1] / prices.iloc[i-2] + r3 = prices.iloc[i-2] / prices.iloc[i-3] + + con1 = min(r1, r2, r3) < 0.95 + con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95) + + if con1 or con2: + result.iloc[i] = 0.0 + + return result diff --git a/framework_v2/shared/factors/talib_base.py b/framework_v2/shared/factors/talib_base.py new file mode 100644 index 0000000..86d9b36 --- /dev/null +++ b/framework_v2/shared/factors/talib_base.py @@ -0,0 +1,55 @@ +""" +ta-lib 因子基类(通用) + +所有 ta-lib 因子继承此类,只需指定函数和参数 +""" + +import talib +import pandas as pd +import numpy as np +from framework_v2.core import FactorBase + + +class TALibFactorBase(FactorBase): + """ + ta-lib 因子基类 + + 子类只需实现: + - name: 因子名称 + - _talib_func: 返回 ta-lib 函数 + """ + + category = "technical" + + def __init__(self, period: int = 14, **params): + """ + 初始化 + + Args: + period: 周期参数 + **params: 其他参数 + """ + super().__init__(period=period, **params) + self.period = period + + def compute(self, data: pd.DataFrame) -> pd.Series: + """ + 计算因子值 + + Args: + data: OHLCV 数据 + + Returns: + 因子值序列 + """ + close = data['close'].values.astype(float) + + # 调用子类指定的 ta-lib 函数 + result = self._talib_func(close, timeperiod=self.period) + + return pd.Series(result, index=data.index, name=self.name) + + @property + def _talib_func(self): + """子类必须实现,返回 ta-lib 函数""" + raise NotImplementedError("Subclasses must implement _talib_func") diff --git a/framework_v2/tests/__init__.py b/framework_v2/tests/__init__.py new file mode 100644 index 0000000..efddb61 --- /dev/null +++ b/framework_v2/tests/__init__.py @@ -0,0 +1,3 @@ +""" +框架 V2 测试 +""" diff --git a/framework_v2/tests/test_momentum_parity.py b/framework_v2/tests/test_momentum_parity.py new file mode 100644 index 0000000..cc34910 --- /dev/null +++ b/framework_v2/tests/test_momentum_parity.py @@ -0,0 +1,116 @@ +""" +因子对比验证测试 + +验证新框架的 MomentumFactor 与现有实现输出一致 +""" + +import sys +import pandas as pd +import numpy as np +from pathlib import Path + +# 添加项目根目录 +project_root = Path(__file__).parent.parent.parent +sys.path.insert(0, str(project_root)) + + +def test_momentum_factor_parity(): + """验证新因子与旧因子输出一致""" + + print("=" * 60) + print(" MomentumFactor 对比测试") + print("=" * 60) + + # 1. 加载测试数据 + print("\n1. 加载测试数据...") + test_data_path = project_root / 'data' / 'index_history_data' + + # 使用纳指100数据测试 + import glob + ndx_files = glob.glob(str(test_data_path / '*NDX*')) + if ndx_files: + test_file = ndx_files[0] + data = pd.read_csv(test_file, index_col=0, parse_dates=True) + print(f" ✓ 加载 {test_file}") + print(f" 数据范围: {data.index[0]} ~ {data.index[-1]}") + print(f" 数据长度: {len(data)} 条") + else: + print(" ⚠ 未找到测试数据,使用模拟数据") + # 生成模拟数据 + np.random.seed(42) + dates = pd.date_range('2020-01-01', periods=500, freq='B') + prices = 100 * np.cumprod(1 + np.random.randn(500) * 0.02) + data = pd.DataFrame({ + 'close': prices, + 'open': prices * 0.99, + 'high': prices * 1.01, + 'low': prices * 0.98, + 'volume': np.random.randint(1000000, 10000000, 500) + }, index=dates) + + # 2. 计算旧因子 + print("\n2. 计算旧因子(strategies/shared/factors/momentum.py)...") + from strategies.shared.factors.momentum import MomentumFactor as OldMomentum + + old_factor = OldMomentum(n_days=25, weighted=True, crash_filter=True) + old_result = old_factor.compute(data) + print(f" ✓ 旧因子计算完成") + print(f" 结果范围: {old_result.min():.4f} ~ {old_result.max():.4f}") + print(f" NaN 数量: {old_result.isna().sum()}") + + # 3. 计算新因子 + print("\n3. 计算新因子(framework_v2/shared/factors/momentum.py)...") + from framework_v2.shared.factors.momentum import MomentumFactor as NewMomentum + + new_factor = NewMomentum(n_days=25, weighted=True, crash_filter=True) + new_result = new_factor.compute(data) + print(f" ✓ 新因子计算完成") + print(f" 结果范围: {new_result.min():.4f} ~ {new_result.max():.4f}") + print(f" NaN 数量: {new_result.isna().sum()}") + + # 4. 对比结果 + print("\n4. 对比结果...") + + # 检查索引是否一致 + if not old_result.index.equals(new_result.index): + print(" ✗ 索引不一致") + return False + print(" ✓ 索引一致") + + # 检查数值差异 + diff = (old_result - new_result).abs() + max_diff = diff.max() + mean_diff = diff.mean() + + print(f" 最大差异: {max_diff:.6e}") + print(f" 平均差异: {mean_diff:.6e}") + + # 允许浮点数精度误差(1e-10) + tolerance = 1e-10 + if max_diff < tolerance: + print(f" ✓ 差异在容差范围内 (< {tolerance:.0e})") + print("\n" + "=" * 60) + print(" ✓ 测试通过:新旧因子输出完全一致!") + print("=" * 60) + return True + else: + print(f" ✗ 差异超出容差范围") + print("\n" + "=" * 60) + print(" ✗ 测试失败:新旧因子输出不一致") + print("=" * 60) + + # 打印前10个差异点 + diff_nonzero = diff[diff > tolerance] + if len(diff_nonzero) > 0: + print(f"\n 前10个差异点:") + for date, val in diff_nonzero.head(10).items(): + old_val = old_result.loc[date] + new_val = new_result.loc[date] + print(f" {date}: 旧={old_val:.6f}, 新={new_val:.6f}, 差异={val:.6e}") + + return False + + +if __name__ == '__main__': + success = test_momentum_factor_parity() + sys.exit(0 if success else 1)