## 架构设计
- 三层架构:core(抽象接口) → shared(通用实现) → tests(验证测试)
- 5个核心抽象基类:StrategyBase, FactorBase, SignalGenerator, Executor, DataFetcher
- 零侵入:与现有框架并行开发,不修改生产代码

## 已完成
✓ 核心接口层(5个ABC类)
✓ 通用因子层(MomentumFactor完全复制现有逻辑)
✓ 对比验证测试(新旧因子输出差异=0,测试通过)

## 验证结果
- 最大差异: 0.000000e+00
- 平均差异: 0.000000e+00
- 容差: < 1e-10

## 下一步
- 阶段3: 信号层迁移(TopNSelector, DynamicThreshold, RebalanceController)
- 阶段4: 执行层迁移(BacktestRunner)
- 阶段5: 数据层迁移(DataFetcher实现)
- 阶段6: 完整策略对比验证

## 设计原则
- 按需抽象,不预先设计
- 职责分离,避免框架膨胀
- 测试驱动,每个组件必须有对比测试
- 渐进式迁移,验证通过再替换
117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
"""
Factor parity verification test.

Verifies that the new framework's MomentumFactor produces output identical
to the existing implementation.
"""
|
||
|
||
import sys
|
||
import pandas as pd
|
||
import numpy as np
|
||
from pathlib import Path
|
||
|
||
# Put the project root (three directory levels above this file) at the front
# of sys.path so the project-level imports used by the test can resolve.
# resolve() makes the path absolute first: the .parent chain is purely
# lexical, so a relative __file__ would otherwise yield the wrong directory.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(project_root))
|
||
|
||
|
||
def test_momentum_factor_parity():
    """Check that the new MomentumFactor reproduces the old implementation.

    Loads an NDX price file from ``<project_root>/data/index_history_data``
    when one exists (otherwise builds a seeded synthetic OHLCV frame), runs
    both factor implementations with the same parameters, and compares the
    two results index-by-index, NaN-position-by-NaN-position and value-by-value
    against an absolute tolerance of 1e-10. Progress is printed to stdout.

    Returns:
        bool: True when indexes, NaN positions and values all agree within
        the tolerance, False otherwise.

    NOTE(review): the outcome is returned rather than asserted, so a test
    runner that ignores return values will not report a mismatch as a
    failure -- confirm how this function is invoked.
    """
    print("=" * 60)
    print(" MomentumFactor 对比测试")
    print("=" * 60)

    # 1. Load the test data
    print("\n1. 加载测试数据...")
    test_data_path = project_root / 'data' / 'index_history_data'

    # Use Nasdaq-100 data for the comparison
    import glob
    # glob.glob returns matches in arbitrary order; sort so the same file is
    # picked on every run and on every platform.
    ndx_files = sorted(glob.glob(str(test_data_path / '*NDX*')))
    if ndx_files:
        test_file = ndx_files[0]
        data = pd.read_csv(test_file, index_col=0, parse_dates=True)
        print(f" ✓ 加载 {test_file}")
        print(f" 数据范围: {data.index[0]} ~ {data.index[-1]}")
        print(f" 数据长度: {len(data)} 条")
    else:
        print(" ⚠ 未找到测试数据,使用模拟数据")
        # Build a reproducible synthetic price series (fixed seed)
        np.random.seed(42)
        dates = pd.date_range('2020-01-01', periods=500, freq='B')
        prices = 100 * np.cumprod(1 + np.random.randn(500) * 0.02)
        data = pd.DataFrame({
            'close': prices,
            'open': prices * 0.99,
            'high': prices * 1.01,
            'low': prices * 0.98,
            'volume': np.random.randint(1000000, 10000000, 500)
        }, index=dates)

    # 2. Compute the old factor
    print("\n2. 计算旧因子(strategies/shared/factors/momentum.py)...")
    from strategies.shared.factors.momentum import MomentumFactor as OldMomentum

    old_factor = OldMomentum(n_days=25, weighted=True, crash_filter=True)
    old_result = old_factor.compute(data)
    print(" ✓ 旧因子计算完成")
    print(f" 结果范围: {old_result.min():.4f} ~ {old_result.max():.4f}")
    print(f" NaN 数量: {old_result.isna().sum()}")

    # 3. Compute the new factor with identical parameters
    print("\n3. 计算新因子(framework_v2/shared/factors/momentum.py)...")
    from framework_v2.shared.factors.momentum import MomentumFactor as NewMomentum

    new_factor = NewMomentum(n_days=25, weighted=True, crash_filter=True)
    new_result = new_factor.compute(data)
    print(" ✓ 新因子计算完成")
    print(f" 结果范围: {new_result.min():.4f} ~ {new_result.max():.4f}")
    print(f" NaN 数量: {new_result.isna().sum()}")

    # 4. Compare the two results
    print("\n4. 对比结果...")

    # The indexes must be identical before values can be compared
    if not old_result.index.equals(new_result.index):
        print(" ✗ 索引不一致")
        return False
    print(" ✓ 索引一致")

    # NaN positions must match as well: max()/mean() skip NaN by default, so
    # a value that is NaN on only one side would otherwise go unnoticed.
    nan_mismatch = old_result.isna() != new_result.isna()
    if nan_mismatch.any():
        print(f" ✗ NaN 位置不一致: {int(nan_mismatch.sum())} 处")
        return False
    print(" ✓ NaN 位置一致")

    # Absolute element-wise difference
    diff = (old_result - new_result).abs()
    # Any NaN left in diff comes from positions that agree (both NaN, or equal
    # infinities); exclude them so an all-NaN result does not fail spuriously.
    comparable = diff.dropna()
    max_diff = comparable.max() if len(comparable) else 0.0
    mean_diff = comparable.mean() if len(comparable) else 0.0

    print(f" 最大差异: {max_diff:.6e}")
    print(f" 平均差异: {mean_diff:.6e}")

    # Allow for floating-point rounding error (1e-10)
    tolerance = 1e-10
    if max_diff < tolerance:
        print(f" ✓ 差异在容差范围内 (< {tolerance:.0e})")
        print("\n" + "=" * 60)
        print(" ✓ 测试通过:新旧因子输出完全一致!")
        print("=" * 60)
        return True
    else:
        print(" ✗ 差异超出容差范围")
        print("\n" + "=" * 60)
        print(" ✗ 测试失败:新旧因子输出不一致")
        print("=" * 60)

        # Show the first 10 positions whose difference exceeds the tolerance
        diff_nonzero = diff[diff > tolerance]
        if len(diff_nonzero) > 0:
            print("\n 前10个差异点:")
            for date, val in diff_nonzero.head(10).items():
                old_val = old_result.loc[date]
                new_val = new_result.loc[date]
                print(f" {date}: 旧={old_val:.6f}, 新={new_val:.6f}, 差异={val:.6e}")

        return False
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point: run the parity check and translate its boolean
    # result into a process exit status (0 on success, 1 on failure).
    exit_status = 0 if test_momentum_factor_parity() else 1
    sys.exit(exit_status)
|