## 核心功能
- CrossMarketAligner: 跨市场数据对齐(解决 ffill 陷阱)
- Pydantic Schema: 数据结构验证(OHLCVInputSchema, AlignedFactorSchema 等)
- 验证装饰器: @validate_factor_after_align, @validate_returns_after_align

## 解决的问题
- 跨市场交易日历不同(美股/港股/A股)
- ffill 收益率陷阱(休市日复制非零收益率)
- NaN 传播问题
- 日期不一致问题

## 测试验证
- 5/5 测试通过(因子对齐、收益率对齐、多标的对齐、信号验证、ffill陷阱)
- 休市日收益率 = 0%(正确)
- 无 NaN 传播

## 架构设计
- shared/data/alignment.py - 对齐器实现
- shared/data/schemas.py - Pydantic Schema 定义
- tests/test_alignment.py - 完整测试套件
293 lines
9.2 KiB
Python
293 lines
9.2 KiB
Python
"""
数据对齐测试

验证跨市场对齐器的正确性
"""
|
||
|
||
import sys
|
||
import pandas as pd
|
||
import numpy as np
|
||
from pathlib import Path
|
||
|
||
# 添加项目根目录
|
||
project_root = Path(__file__).parent.parent.parent
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
|
||
def test_factor_alignment():
    """Align a factor series from a sparse source calendar onto the target calendar.

    Builds a ten-business-day target calendar and a source calendar with two
    of those days removed, aligns an eight-point factor series through
    ``CrossMarketAligner.align_factor``, and asserts that the result has one
    row per target date and exposes the ``value`` and ``is_filled`` columns.

    Returns:
        bool: ``True`` once every assertion has passed; the ``__main__``
        runner records this value as the test outcome.
    """
    from framework_v2.shared.data.alignment import CrossMarketAligner

    rule = "=" * 60
    print(rule)
    print(" 测试 1: 因子对齐")
    print(rule)

    # Target (A-share) calendar: the full run of business days.
    cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')

    # Source (US) calendar: the same run with positions 2 and 5 dropped to
    # simulate US market holidays.
    us_dates = pd.date_range('2024-01-01', periods=10, freq='B').delete([2, 5])

    # One factor observation per source-calendar day.
    factor_series = pd.Series(
        [0.1, 0.15, 0.18, 0.16, 0.20, 0.22, 0.25, 0.23],
        index=us_dates,
    )

    source_days = len(us_dates)
    target_days = len(cn_dates)
    print(f"\n源日历(美股): {source_days} 天")
    print(f"目标日历(A股): {target_days} 天")
    print(f"因子值(源日历):")
    print(factor_series)

    # Project the factor onto the target calendar.
    aligned = CrossMarketAligner(target_calendar=cn_dates).align_factor(
        factor_series, us_dates, code='^GSPC'
    )

    print(f"\n对齐后因子值:")
    print(aligned)

    # Shape and schema checks on the aligned frame.
    assert len(aligned) == target_days, "对齐后长度应该等于目标日历"
    for required in ('value', 'is_filled'):
        assert required in aligned.columns, f"应该有 '{required}' 列"

    # Report (but do not assert) how many rows were flagged as filled.
    filled_count = aligned['is_filled'].sum()
    print(f"\n填充值数量: {filled_count}")

    print("\n✓ 测试通过")
    return True
|
||
|
||
|
||
def test_returns_alignment():
    """Align returns from a sparse source calendar and verify closed-day returns are zero.

    A price series defined on a source calendar that lacks two of the target
    calendar's business days is passed to ``CrossMarketAligner.align_returns``.
    The test asserts that the result is indexed exactly by the target
    calendar, contains no NaN, starts at 0.0, and is exactly 0.0 on every
    target date that is missing from the source calendar.

    Returns:
        bool: ``True`` once every assertion has passed; the ``__main__``
        runner records this value as the test outcome.
    """
    from framework_v2.shared.data.alignment import CrossMarketAligner

    print("\n" + "=" * 60)
    print(" 测试 2: 收益率对齐")
    print("=" * 60)

    # Target (A-share) calendar: the full run of business days.
    cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')

    # Source (US) calendar: positions 2 and 5 dropped to simulate holidays.
    us_dates = pd.date_range('2024-01-01', periods=10, freq='B')
    us_dates = us_dates.delete([2, 5])

    # Closing prices on the source calendar.
    prices = [100, 101, 102, 101, 103, 104, 105, 104]
    close_series = pd.Series(prices, index=us_dates, dtype=float)

    print("\n源价格(美股日历):")
    print(close_series)

    # Align returns onto the target calendar.
    aligner = CrossMarketAligner(target_calendar=cn_dates)
    returns = aligner.align_returns(close_series, code='^GSPC')

    print("\n对齐后收益率(A股日历):")
    print(returns)

    # Structural checks.
    assert len(returns) == len(cn_dates), "收益率长度应该等于目标日历"
    assert returns.index.equals(cn_dates), "收益率索引应该等于目标日历"
    assert not returns.isna().any(), "收益率不应该有 NaN"
    assert returns.iloc[0] == 0.0, "首日收益率应该为 0"

    # Days the source market was closed: the price is carried forward, so the
    # return must be exactly 0. Without this check a forward-filled non-zero
    # return (the ffill trap) would pass this test unnoticed.
    closed_days = cn_dates.difference(us_dates)
    assert len(closed_days) == 2, "应该有 2 个休市日"
    assert (returns.loc[closed_days] == 0.0).all(), "休市日收益率应该为 0"

    print("\n收益率统计:")
    print(f" 最小值: {returns.min():.4f}")
    print(f" 最大值: {returns.max():.4f}")
    print(f" 均值: {returns.mean():.4f}")

    print("\n✓ 测试通过")
    return True
|
||
|
||
|
||
def test_multi_asset_alignment():
    """Align several assets with different calendars into one returns DataFrame.

    Three synthetic random-walk price series are built: two on the full
    target calendar and one with two days missing. They are passed to
    ``CrossMarketAligner.align_multi_asset`` and the test asserts that the
    result is indexed by the target calendar, has one column per asset and
    contains no NaN.

    The random walks come from a locally seeded generator so that every run
    (and any failure) is reproducible and the global NumPy RNG state is left
    untouched.

    Returns:
        bool: ``True`` once every assertion has passed; the ``__main__``
        runner records this value as the test outcome.
    """
    from framework_v2.shared.data.alignment import CrossMarketAligner

    print("\n" + "=" * 60)
    print(" 测试 3: 多标的收益率对齐")
    print("=" * 60)

    # Fixed seed: the test data must not vary between runs.
    rng = np.random.default_rng(42)

    # Target (A-share) calendar.
    cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')

    # Assets keyed by code, each on its own calendar.
    close_dict = {}

    # Asset 1: full calendar.
    prices1 = 100 + np.cumsum(rng.standard_normal(10))
    close_dict['^GSPC'] = pd.Series(prices1, index=cn_dates, dtype=float)

    # Asset 2: two days missing (positions 2 and 5).
    us_dates2 = cn_dates.delete([2, 5])
    prices2 = 100 + np.cumsum(rng.standard_normal(8))
    close_dict['^IXIC'] = pd.Series(prices2, index=us_dates2, dtype=float)

    # Asset 3: full calendar.
    prices3 = 100 + np.cumsum(rng.standard_normal(10))
    close_dict['931862.CSI'] = pd.Series(prices3, index=cn_dates, dtype=float)

    print(f"\n标的数量: {len(close_dict)}")
    for code, close in close_dict.items():
        print(f" {code}: {len(close)} 天")

    # Align all assets onto the target calendar.
    aligner = CrossMarketAligner(target_calendar=cn_dates)
    returns_df = aligner.align_multi_asset(close_dict)

    print("\n对齐后收益率 DataFrame:")
    print(returns_df.head())

    # Structural checks on the combined frame.
    assert len(returns_df) == len(cn_dates), "收益率 DataFrame 长度应该等于目标日历"
    assert returns_df.index.equals(cn_dates), "索引应该等于目标日历"
    assert not returns_df.isna().any().any(), "收益率 DataFrame 不应该有 NaN"
    assert len(returns_df.columns) == len(close_dict), "列数应该等于标的数"

    print("\n✓ 测试通过")
    return True
|
||
|
||
|
||
def test_signal_returns_alignment():
    """Check that signals and returns with different date ranges come back aligned.

    The signal frame covers the target calendar plus two extra business days;
    the returns frame covers the target calendar only. Both are handed to
    ``CrossMarketAligner.validate_alignment`` and the test asserts that the
    two returned objects have equal length and identical indexes.

    Returns:
        bool: ``True`` once every assertion has passed; the ``__main__``
        runner records this value as the test outcome.
    """
    from framework_v2.shared.data.alignment import CrossMarketAligner

    rule = "=" * 60
    print("\n" + rule)
    print(" 测试 4: 信号与收益率对齐验证")
    print(rule)

    # Target (A-share) calendar.
    cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')

    # Signals: target calendar extended by two further business days.
    extra_days = pd.date_range('2024-01-15', periods=2, freq='B')
    signal_dates = cn_dates.union(extra_days)
    signal_column = ['^GSPC,^IXIC'] * len(signal_dates)
    signals = pd.DataFrame({'signal': signal_column}, index=signal_dates)

    # Returns: random values on the target calendar only.
    return_values = np.random.randn(len(cn_dates), 2) * 0.02
    returns_df = pd.DataFrame(
        return_values, index=cn_dates, columns=['^GSPC', '^IXIC']
    )

    print(f"\n信号: {len(signals)} 天")
    print(f"收益: {len(returns_df)} 天")

    # Reconcile the two frames.
    aligner = CrossMarketAligner(target_calendar=cn_dates)
    aligned_pair = aligner.validate_alignment(signals, returns_df)
    aligned_signals, aligned_returns = aligned_pair

    n_signals = len(aligned_signals)
    n_returns = len(aligned_returns)
    print(f"\n对齐后:")
    print(f" 信号: {n_signals} 天")
    print(f" 收益: {n_returns} 天")

    # Both outputs must share the same length and index.
    assert n_signals == n_returns, "对齐后长度应该一致"
    assert aligned_signals.index.equals(aligned_returns.index), "索引应该一致"

    print("\n✓ 测试通过")
    return True
|
||
|
||
|
||
def test_ffill_trap():
    """Demonstrate the forward-fill trap and check the aligner avoids it.

    Prints two computations side by side for a source calendar that lacks
    the third target day:

    * wrong: ``pct_change`` first, then ``reindex(..., method='ffill')``,
      which forward-fills the return values themselves;
    * right: ``reindex(..., method='ffill')`` on prices first, then
      ``pct_change``, with the first return set to 0.0.

    Finally asserts that ``CrossMarketAligner.align_returns`` yields 0.0 at
    position 2, the day the source market was closed.

    Returns:
        bool: ``True`` once the assertion has passed; the ``__main__``
        runner records this value as the test outcome.
    """
    from framework_v2.shared.data.alignment import CrossMarketAligner

    rule = "=" * 60
    print("\n" + rule)
    print(" 测试 5: ffill 陷阱对比")
    print(rule)

    # Target (A-share) calendar: five business days.
    cn_dates = pd.date_range('2024-01-01', periods=5, freq='B')

    # Source (US) calendar: same window with the third day closed.
    us_dates = pd.date_range('2024-01-01', periods=5, freq='B').delete([2])

    # Prices on the source calendar.
    prices = pd.Series([100, 101, 102, 103], index=us_dates, dtype=float)

    print(f"\n原始价格(美股日历):")
    print(prices)

    # Wrong order: returns are computed first and then forward-filled.
    print("\n❌ 错误做法:先 pct_change,再 reindex")
    naive_returns = prices.pct_change()
    print("步骤 1 - 收益率:")
    print(naive_returns)

    naive_aligned = naive_returns.reindex(cn_dates, method='ffill')
    print("\n步骤 2 - reindex + ffill:")
    print(naive_aligned)
    print("⚠ 问题:第 3 天复制了第 2 天的 +1% 收益率!")

    # Right order: prices are forward-filled first, then returns computed.
    print("\n✅ 正确做法:先 reindex 价格,再 pct_change")
    filled_prices = prices.reindex(cn_dates, method='ffill')
    print("步骤 1 - 价格 reindex:")
    print(filled_prices)

    proper_returns = filled_prices.pct_change(fill_method=None)
    proper_returns.iloc[0] = 0.0  # first day has no prior price
    print("\n步骤 2 - pct_change:")
    print(proper_returns)
    print("✓ 第 3 天收益率 = 0%(价格不变)")

    # The aligner must report a zero return on the closed day.
    returns = CrossMarketAligner(target_calendar=cn_dates).align_returns(
        prices, code='TEST'
    )
    closed_day_return = returns.iloc[2]
    assert closed_day_return == 0.0, "休市日收益率应该为 0"

    print("\n✓ 测试通过")
    return True
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point: run every test in order, print a per-test summary,
    # and exit with status 0 only if all of them passed.
    import traceback

    rule = "=" * 60

    print("\n" + rule)
    print(" 跨市场数据对齐器测试")
    print(rule)

    # (display name, callable) pairs, executed in this order.
    tests = [
        ("因子对齐", test_factor_alignment),
        ("收益率对齐", test_returns_alignment),
        ("多标的对齐", test_multi_asset_alignment),
        ("信号与收益对齐", test_signal_returns_alignment),
        ("ffill 陷阱", test_ffill_trap),
    ]

    results = []
    for name, test_func in tests:
        try:
            outcome = test_func()
        except Exception as e:
            # Any failure (including AssertionError) is reported with its
            # traceback and recorded, so the remaining tests still run.
            print(f"\n✗ {name} 失败: {e}")
            traceback.print_exc()
            outcome = False
        results.append((name, outcome))

    # Summary.
    print("\n" + rule)
    print(" 测试总结")
    print(rule)

    total = len(results)
    passed = len([label for label, ok in results if ok])

    for label, ok in results:
        if ok:
            status = "✓ 通过"
        else:
            status = "✗ 失败"
        print(f" {status} - {label}")

    print(f"\n总计: {passed}/{total} 通过")

    exit_code = 0 if passed == total else 1
    sys.exit(exit_code)
|