feat(framework_v2): 创建框架V2骨架 - 三层架构+因子验证通过
## 架构设计 - 三层架构:core(抽象接口) → shared(通用实现) → tests(验证测试) - 5个核心抽象基类:StrategyBase, FactorBase, SignalGenerator, Executor, DataFetcher - 零侵入:与现有框架并行开发,不修改生产代码 ## 已完成 ✓ 核心接口层(5个ABC类) ✓ 通用因子层(MomentumFactor完全复制现有逻辑) ✓ 对比验证测试(新旧因子输出差异=0,测试通过) ## 验证结果 - 最大差异: 0.000000e+00 - 平均差异: 0.000000e+00 - 容差: < 1e-10 ## 下一步 - 阶段3: 信号层迁移(TopNSelector, DynamicThreshold, RebalanceController) - 阶段4: 执行层迁移(BacktestRunner) - 阶段5: 数据层迁移(DataFetcher实现) - 阶段6: 完整策略对比验证 ## 设计原则 - 按需抽象,不预先设计 - 职责分离,避免框架膨胀 - 测试驱动,每个组件必须有对比测试 - 渐进式迁移,验证通过再替换
This commit is contained in:
116
framework_v2/tests/test_momentum_parity.py
Normal file
116
framework_v2/tests/test_momentum_parity.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
因子对比验证测试
|
||||
|
||||
验证新框架的 MomentumFactor 与现有实现输出一致
|
||||
"""
|
||||
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
def test_momentum_factor_parity():
|
||||
"""验证新因子与旧因子输出一致"""
|
||||
|
||||
print("=" * 60)
|
||||
print(" MomentumFactor 对比测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 加载测试数据
|
||||
print("\n1. 加载测试数据...")
|
||||
test_data_path = project_root / 'data' / 'index_history_data'
|
||||
|
||||
# 使用纳指100数据测试
|
||||
import glob
|
||||
ndx_files = glob.glob(str(test_data_path / '*NDX*'))
|
||||
if ndx_files:
|
||||
test_file = ndx_files[0]
|
||||
data = pd.read_csv(test_file, index_col=0, parse_dates=True)
|
||||
print(f" ✓ 加载 {test_file}")
|
||||
print(f" 数据范围: {data.index[0]} ~ {data.index[-1]}")
|
||||
print(f" 数据长度: {len(data)} 条")
|
||||
else:
|
||||
print(" ⚠ 未找到测试数据,使用模拟数据")
|
||||
# 生成模拟数据
|
||||
np.random.seed(42)
|
||||
dates = pd.date_range('2020-01-01', periods=500, freq='B')
|
||||
prices = 100 * np.cumprod(1 + np.random.randn(500) * 0.02)
|
||||
data = pd.DataFrame({
|
||||
'close': prices,
|
||||
'open': prices * 0.99,
|
||||
'high': prices * 1.01,
|
||||
'low': prices * 0.98,
|
||||
'volume': np.random.randint(1000000, 10000000, 500)
|
||||
}, index=dates)
|
||||
|
||||
# 2. 计算旧因子
|
||||
print("\n2. 计算旧因子(strategies/shared/factors/momentum.py)...")
|
||||
from strategies.shared.factors.momentum import MomentumFactor as OldMomentum
|
||||
|
||||
old_factor = OldMomentum(n_days=25, weighted=True, crash_filter=True)
|
||||
old_result = old_factor.compute(data)
|
||||
print(f" ✓ 旧因子计算完成")
|
||||
print(f" 结果范围: {old_result.min():.4f} ~ {old_result.max():.4f}")
|
||||
print(f" NaN 数量: {old_result.isna().sum()}")
|
||||
|
||||
# 3. 计算新因子
|
||||
print("\n3. 计算新因子(framework_v2/shared/factors/momentum.py)...")
|
||||
from framework_v2.shared.factors.momentum import MomentumFactor as NewMomentum
|
||||
|
||||
new_factor = NewMomentum(n_days=25, weighted=True, crash_filter=True)
|
||||
new_result = new_factor.compute(data)
|
||||
print(f" ✓ 新因子计算完成")
|
||||
print(f" 结果范围: {new_result.min():.4f} ~ {new_result.max():.4f}")
|
||||
print(f" NaN 数量: {new_result.isna().sum()}")
|
||||
|
||||
# 4. 对比结果
|
||||
print("\n4. 对比结果...")
|
||||
|
||||
# 检查索引是否一致
|
||||
if not old_result.index.equals(new_result.index):
|
||||
print(" ✗ 索引不一致")
|
||||
return False
|
||||
print(" ✓ 索引一致")
|
||||
|
||||
# 检查数值差异
|
||||
diff = (old_result - new_result).abs()
|
||||
max_diff = diff.max()
|
||||
mean_diff = diff.mean()
|
||||
|
||||
print(f" 最大差异: {max_diff:.6e}")
|
||||
print(f" 平均差异: {mean_diff:.6e}")
|
||||
|
||||
# 允许浮点数精度误差(1e-10)
|
||||
tolerance = 1e-10
|
||||
if max_diff < tolerance:
|
||||
print(f" ✓ 差异在容差范围内 (< {tolerance:.0e})")
|
||||
print("\n" + "=" * 60)
|
||||
print(" ✓ 测试通过:新旧因子输出完全一致!")
|
||||
print("=" * 60)
|
||||
return True
|
||||
else:
|
||||
print(f" ✗ 差异超出容差范围")
|
||||
print("\n" + "=" * 60)
|
||||
print(" ✗ 测试失败:新旧因子输出不一致")
|
||||
print("=" * 60)
|
||||
|
||||
# 打印前10个差异点
|
||||
diff_nonzero = diff[diff > tolerance]
|
||||
if len(diff_nonzero) > 0:
|
||||
print(f"\n 前10个差异点:")
|
||||
for date, val in diff_nonzero.head(10).items():
|
||||
old_val = old_result.loc[date]
|
||||
new_val = new_result.loc[date]
|
||||
print(f" {date}: 旧={old_val:.6f}, 新={new_val:.6f}, 差异={val:.6e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
success = test_momentum_factor_parity()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user