Files
etf/framework_v2/tests/test_momentum_parity.py
aszerW 908b28473f feat(framework_v2): 创建框架V2骨架 - 三层架构+因子验证通过
## 架构设计
- 三层架构:core(抽象接口) → shared(通用实现) → tests(验证测试)
- 5个核心抽象基类:StrategyBase, FactorBase, SignalGenerator, Executor, DataFetcher
- 零侵入:与现有框架并行开发,不修改生产代码

## 已完成
✓ 核心接口层(5个ABC类)
✓ 通用因子层(MomentumFactor完全复制现有逻辑)
✓ 对比验证测试(新旧因子输出差异=0,测试通过)

## 验证结果
- 最大差异: 0.000000e+00
- 平均差异: 0.000000e+00
- 容差: < 1e-10

## 下一步
- 阶段3: 信号层迁移(TopNSelector, DynamicThreshold, RebalanceController)
- 阶段4: 执行层迁移(BacktestRunner)
- 阶段5: 数据层迁移(DataFetcher实现)
- 阶段6: 完整策略对比验证

## 设计原则
- 按需抽象,不预先设计
- 职责分离,避免框架膨胀
- 测试驱动,每个组件必须有对比测试
- 渐进式迁移,验证通过再替换
2026-05-24 09:12:29 +08:00

117 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Factor parity verification test.

Verifies that the new framework's MomentumFactor produces the same output as the existing implementation.
"""
import sys
import pandas as pd
import numpy as np
from pathlib import Path
# Put the project root (three directory levels above this file) at the front
# of sys.path so the `strategies` and `framework_v2` packages used by the test
# below can be imported without installing the project.
# resolve() makes the path absolute: without it a relative __file__ would make
# `.parent.parent.parent` collapse to "." and the sys.path entry would depend
# on the current working directory.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
def _load_parity_test_data():
    """Load the price frame that is fed to both factor implementations.

    Prefers the alphabetically first file matching ``*NDX*`` under
    ``<project_root>/data/index_history_data``. When no such file exists, a
    seeded 500-business-day random walk with ``close``/``open``/``high``/
    ``low``/``volume`` columns is generated instead.

    Returns:
        pandas.DataFrame: indexed by the first CSV column (parsed as dates
        where possible) or by the generated business-day range.
    """
    import glob

    test_data_path = project_root / 'data' / 'index_history_data'
    # glob order is filesystem-dependent; sort so the same file is picked on
    # every run when several NDX files are present.
    ndx_files = sorted(glob.glob(str(test_data_path / '*NDX*')))
    if ndx_files:
        test_file = ndx_files[0]
        data = pd.read_csv(test_file, index_col=0, parse_dates=True)
        print(f" ✓ 加载 {test_file}")
        print(f" 数据范围: {data.index[0]} ~ {data.index[-1]}")
        print(f" 数据长度: {len(data)}")
        return data

    print(" ⚠ 未找到测试数据,使用模拟数据")
    # Fixed seed keeps the synthetic series reproducible between runs.
    np.random.seed(42)
    dates = pd.date_range('2020-01-01', periods=500, freq='B')
    prices = 100 * np.cumprod(1 + np.random.randn(500) * 0.02)
    return pd.DataFrame({
        'close': prices,
        'open': prices * 0.99,
        'high': prices * 1.01,
        'low': prices * 0.98,
        'volume': np.random.randint(1000000, 10000000, 500)
    }, index=dates)


def _print_verdict(passed):
    """Print the closing pass/fail banner."""
    print("\n" + "=" * 60)
    if passed:
        print(" ✓ 测试通过:新旧因子输出完全一致!")
    else:
        print(" ✗ 测试失败:新旧因子输出不一致")
    print("=" * 60)


def _compare_factor_outputs(old_result, new_result, tolerance):
    """Report how two factor series differ and decide whether they match.

    Args:
        old_result: Series produced by the existing implementation.
        new_result: Series produced by the new implementation.
        tolerance: largest absolute difference still counted as equal.

    Returns:
        bool: True when the indexes are equal, NaNs sit at the same
        positions, and every comparable value differs by less than
        ``tolerance``.
    """
    if not old_result.index.equals(new_result.index):
        print(" ✗ 索引不一致")
        return False
    print(" ✓ 索引一致")

    # max()/mean() skip NaN, so a value that is missing on only one side would
    # otherwise go unnoticed. Compare the NaN masks explicitly first.
    nan_mismatch = old_result.isna().to_numpy() != new_result.isna().to_numpy()
    if nan_mismatch.any():
        print(f" ✗ NaN 位置不一致: {int(nan_mismatch.sum())} 处")
        _print_verdict(False)
        print("\n 前10个 NaN 不一致位置:")
        for date in old_result.index[nan_mismatch][:10]:
            print(f" {date}")
        return False

    diff = (old_result - new_result).abs()
    comparable = diff.dropna()
    if comparable.empty:
        # Nothing but matching NaNs (or empty outputs): the results are
        # identical, whereas a NaN max would fail the `<` check below.
        max_diff = mean_diff = 0.0
    else:
        max_diff = comparable.max()
        mean_diff = comparable.mean()
    print(f" 最大差异: {max_diff:.6e}")
    print(f" 平均差异: {mean_diff:.6e}")

    if max_diff < tolerance:
        print(f" ✓ 差异在容差范围内 (< {tolerance:.0e})")
        _print_verdict(True)
        return True

    print(" ✗ 差异超出容差范围")
    _print_verdict(False)
    # Show the first 10 offending points to make the mismatch debuggable.
    diff_nonzero = comparable[comparable > tolerance]
    if len(diff_nonzero) > 0:
        print("\n 前10个差异点:")
        for date, val in diff_nonzero.head(10).items():
            old_val = old_result.loc[date]
            new_val = new_result.loc[date]
            print(f" {date}: 旧={old_val:.6f}, 新={new_val:.6f}, 差异={val:.6e}")
    return False


def test_momentum_factor_parity():
    """Check that the new MomentumFactor reproduces the old one's output.

    Both implementations are built with ``n_days=25, weighted=True,
    crash_filter=True`` and run on the same price frame; the outputs must
    share an index and NaN layout and agree within 1e-10.

    Returns:
        bool: True when the outputs match, False otherwise.

    NOTE(review): the verdict is returned rather than asserted, so when this
    is collected by pytest a False result presumably does not fail the test;
    confirm how this file is meant to be run.
    """
    print("=" * 60)
    print(" MomentumFactor 对比测试")
    print("=" * 60)

    # 1. Load test data (NDX history if available, synthetic otherwise).
    print("\n1. 加载测试数据...")
    data = _load_parity_test_data()

    # 2. Compute the reference (old) factor.
    print("\n2. 计算旧因子strategies/shared/factors/momentum.py...")
    from strategies.shared.factors.momentum import MomentumFactor as OldMomentum
    old_factor = OldMomentum(n_days=25, weighted=True, crash_filter=True)
    old_result = old_factor.compute(data)
    print(" ✓ 旧因子计算完成")
    print(f" 结果范围: {old_result.min():.4f} ~ {old_result.max():.4f}")
    print(f" NaN 数量: {old_result.isna().sum()}")

    # 3. Compute the new factor with identical parameters.
    print("\n3. 计算新因子framework_v2/shared/factors/momentum.py...")
    from framework_v2.shared.factors.momentum import MomentumFactor as NewMomentum
    new_factor = NewMomentum(n_days=25, weighted=True, crash_filter=True)
    new_result = new_factor.compute(data)
    print(" ✓ 新因子计算完成")
    print(f" 结果范围: {new_result.min():.4f} ~ {new_result.max():.4f}")
    print(f" NaN 数量: {new_result.isna().sum()}")

    # 4. Compare; the tolerance only absorbs floating-point rounding error.
    print("\n4. 对比结果...")
    return _compare_factor_outputs(old_result, new_result, tolerance=1e-10)
if __name__ == '__main__':
    # Script mode: turn the boolean verdict into a process exit status
    # (0 when the parity check passed, 1 otherwise).
    success = test_momentum_factor_parity()
    exit_status = 1 if not success else 0
    sys.exit(exit_status)