refactor(archive): move unused modules to archive/

Archive legacy framework and utility modules that are no longer
referenced by the active core (datasource/ and rotation/):

- framework/ -> archive/framework/
- framework_v2/ -> archive/framework_v2/
- strategies/ -> archive/strategies/
- config/ -> archive/config/
- visualization/ -> archive/visualization/
- scripts/ -> archive/scripts/
- tests/ -> archive/tests/
- run_rotation.py, run_us_rotation.py -> archive/single_files/
- compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
2026-06-03 23:41:46 +08:00
parent d700bc1dfd
commit c905230a40
98 changed files with 0 additions and 714 deletions

View File

@@ -0,0 +1,3 @@
"""
框架 V2 测试
"""

View File

@@ -0,0 +1,292 @@
"""
数据对齐测试
验证跨市场对齐器的正确性
"""
import sys
import pandas as pd
import numpy as np
from pathlib import Path
# 添加项目根目录
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
def test_factor_alignment():
"""测试因子对齐"""
from framework_v2.shared.data.alignment import CrossMarketAligner
print("=" * 60)
print(" 测试 1: 因子对齐")
print("=" * 60)
# 创建模拟数据
# 美股日历(比 A 股少几天)
us_dates = pd.date_range('2024-01-01', periods=10, freq='B')
us_dates = us_dates.delete([2, 5]) # 删除 2 天(模拟美股假日)
# A 股日历(完整)
cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')
# 因子值(美股日历)
factor_values = [0.1, 0.15, 0.18, 0.16, 0.20, 0.22, 0.25, 0.23]
factor_series = pd.Series(factor_values, index=us_dates)
print(f"\n源日历(美股): {len(us_dates)}")
print(f"目标日历A股: {len(cn_dates)}")
print(f"因子值(源日历):")
print(factor_series)
# 对齐
aligner = CrossMarketAligner(target_calendar=cn_dates)
aligned = aligner.align_factor(factor_series, us_dates, code='^GSPC')
print(f"\n对齐后因子值:")
print(aligned)
# 验证
assert len(aligned) == len(cn_dates), "对齐后长度应该等于目标日历"
assert 'value' in aligned.columns, "应该有 'value'"
assert 'is_filled' in aligned.columns, "应该有 'is_filled'"
# 检查填充值
filled_count = aligned['is_filled'].sum()
print(f"\n填充值数量: {filled_count}")
print("\n✓ 测试通过")
return True
def test_returns_alignment():
"""测试收益率对齐"""
from framework_v2.shared.data.alignment import CrossMarketAligner
print("\n" + "=" * 60)
print(" 测试 2: 收益率对齐")
print("=" * 60)
# 创建模拟价格数据
# 美股日历
us_dates = pd.date_range('2024-01-01', periods=10, freq='B')
us_dates = us_dates.delete([2, 5])
# A 股日历
cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')
# 价格(美股日历)
prices = [100, 101, 102, 101, 103, 104, 105, 104]
close_series = pd.Series(prices, index=us_dates, dtype=float)
print(f"\n源价格(美股日历):")
print(close_series)
# 对齐收益率
aligner = CrossMarketAligner(target_calendar=cn_dates)
returns = aligner.align_returns(close_series, code='^GSPC')
print(f"\n对齐后收益率A股日历:")
print(returns)
# 验证
assert len(returns) == len(cn_dates), "收益率长度应该等于目标日历"
assert returns.index.equals(cn_dates), "收益率索引应该等于目标日历"
assert not returns.isna().any(), "收益率不应该有 NaN"
assert returns.iloc[0] == 0.0, "首日收益率应该为 0"
# 检查休市日收益率(应该为 0
# 美股休市日,价格 ffill收益率 = 0
print(f"\n收益率统计:")
print(f" 最小值: {returns.min():.4f}")
print(f" 最大值: {returns.max():.4f}")
print(f" 均值: {returns.mean():.4f}")
print("\n✓ 测试通过")
return True
def test_multi_asset_alignment():
"""测试多标的对齐"""
from framework_v2.shared.data.alignment import CrossMarketAligner
print("\n" + "=" * 60)
print(" 测试 3: 多标的收益率对齐")
print("=" * 60)
# A 股日历
cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')
# 多个标的(不同日历)
close_dict = {}
# 标的 1: 完整数据
prices1 = 100 + np.cumsum(np.random.randn(10))
close_dict['^GSPC'] = pd.Series(prices1, index=cn_dates, dtype=float)
# 标的 2: 少 2 天
us_dates2 = cn_dates.delete([2, 5])
prices2 = 100 + np.cumsum(np.random.randn(8))
close_dict['^IXIC'] = pd.Series(prices2, index=us_dates2, dtype=float)
# 标的 3: 完整数据
prices3 = 100 + np.cumsum(np.random.randn(10))
close_dict['931862.CSI'] = pd.Series(prices3, index=cn_dates, dtype=float)
print(f"\n标的数量: {len(close_dict)}")
for code, close in close_dict.items():
print(f" {code}: {len(close)}")
# 对齐多标的
aligner = CrossMarketAligner(target_calendar=cn_dates)
returns_df = aligner.align_multi_asset(close_dict)
print(f"\n对齐后收益率 DataFrame:")
print(returns_df.head())
# 验证
assert len(returns_df) == len(cn_dates), "收益率 DataFrame 长度应该等于目标日历"
assert returns_df.index.equals(cn_dates), "索引应该等于目标日历"
assert not returns_df.isna().any().any(), "收益率 DataFrame 不应该有 NaN"
assert len(returns_df.columns) == len(close_dict), "列数应该等于标的数"
print(f"\n✓ 测试通过")
return True
def test_signal_returns_alignment():
"""测试信号与收益率对齐"""
from framework_v2.shared.data.alignment import CrossMarketAligner
print("\n" + "=" * 60)
print(" 测试 4: 信号与收益率对齐验证")
print("=" * 60)
# A 股日历
cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')
# 信号(多 2 天)
signal_dates = cn_dates.union(pd.date_range('2024-01-15', periods=2, freq='B'))
signals = pd.DataFrame({
'signal': ['^GSPC,^IXIC'] * len(signal_dates)
}, index=signal_dates)
# 收益率(正常)
returns_df = pd.DataFrame(
np.random.randn(len(cn_dates), 2) * 0.02,
index=cn_dates,
columns=['^GSPC', '^IXIC']
)
print(f"\n信号: {len(signals)}")
print(f"收益: {len(returns_df)}")
# 验证对齐
aligner = CrossMarketAligner(target_calendar=cn_dates)
aligned_signals, aligned_returns = aligner.validate_alignment(signals, returns_df)
print(f"\n对齐后:")
print(f" 信号: {len(aligned_signals)}")
print(f" 收益: {len(aligned_returns)}")
# 验证
assert len(aligned_signals) == len(aligned_returns), "对齐后长度应该一致"
assert aligned_signals.index.equals(aligned_returns.index), "索引应该一致"
print("\n✓ 测试通过")
return True
def test_ffill_trap():
"""测试 ffill 陷阱(错误 vs 正确)"""
from framework_v2.shared.data.alignment import CrossMarketAligner
print("\n" + "=" * 60)
print(" 测试 5: ffill 陷阱对比")
print("=" * 60)
# 美股日历
us_dates = pd.date_range('2024-01-01', periods=5, freq='B')
us_dates = us_dates.delete([2]) # 第 3 天休市
# A 股日历
cn_dates = pd.date_range('2024-01-01', periods=5, freq='B')
# 价格
prices = pd.Series([100, 101, 102, 103], index=us_dates, dtype=float)
print(f"\n原始价格(美股日历):")
print(prices)
# ❌ 错误做法:先计算收益率,再 ffill
print("\n❌ 错误做法:先 pct_change再 reindex")
returns_wrong = prices.pct_change()
print("步骤 1 - 收益率:")
print(returns_wrong)
returns_wrong_aligned = returns_wrong.reindex(cn_dates, method='ffill')
print("\n步骤 2 - reindex + ffill:")
print(returns_wrong_aligned)
print("⚠ 问题:第 3 天复制了第 2 天的 +1% 收益率!")
# ✅ 正确做法:先 ffill 价格,再计算收益率
print("\n✅ 正确做法:先 reindex 价格,再 pct_change")
prices_aligned = prices.reindex(cn_dates, method='ffill')
print("步骤 1 - 价格 reindex:")
print(prices_aligned)
returns_correct = prices_aligned.pct_change(fill_method=None)
returns_correct.iloc[0] = 0.0
print("\n步骤 2 - pct_change:")
print(returns_correct)
print("✓ 第 3 天收益率 = 0%(价格不变)")
# 验证
aligner = CrossMarketAligner(target_calendar=cn_dates)
returns = aligner.align_returns(prices, code='TEST')
assert returns.iloc[2] == 0.0, "休市日收益率应该为 0"
print("\n✓ 测试通过")
return True
if __name__ == '__main__':
print("\n" + "=" * 60)
print(" 跨市场数据对齐器测试")
print("=" * 60)
tests = [
("因子对齐", test_factor_alignment),
("收益率对齐", test_returns_alignment),
("多标的对齐", test_multi_asset_alignment),
("信号与收益对齐", test_signal_returns_alignment),
("ffill 陷阱", test_ffill_trap),
]
results = []
for name, test_func in tests:
try:
success = test_func()
results.append((name, success))
except Exception as e:
print(f"\n{name} 失败: {e}")
import traceback
traceback.print_exc()
results.append((name, False))
# 总结
print("\n" + "=" * 60)
print(" 测试总结")
print("=" * 60)
passed = sum(1 for _, success in results if success)
total = len(results)
for name, success in results:
status = "✓ 通过" if success else "✗ 失败"
print(f" {status} - {name}")
print(f"\n总计: {passed}/{total} 通过")
sys.exit(0 if passed == total else 1)

View File

@@ -0,0 +1,285 @@
"""
测试配置加载和验证
验证:
1. 配置文件加载
2. Pydantic Schema 验证
3. 环境变量替换
4. 错误处理
"""
import sys
from pathlib import Path
import os
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from framework_v2.config import load_config, ConfigLoader
from framework_v2.config.schemas import RotationStrategyConfig, MarketType
def test_load_config():
"""测试 1: 加载配置文件"""
print("\n" + "=" * 70)
print(" 测试 1: 加载配置文件")
print("=" * 70)
# 设置环境变量(模拟)
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
os.environ['TUSHARE_TOKEN'] = 'test_token_123'
# 加载配置
config = load_config('rotation_example.yaml')
print(f"\n✓ 配置加载成功")
print(f" 版本: {config.metadata.version}")
print(f" 策略: {config.metadata.strategy}")
print(f" 资产池: {len(config.asset_pools.equity)} 股票, "
f"{len(config.asset_pools.commodity)} 商品, "
f"{len(config.asset_pools.fixed_income)} 债券")
# 验证基本字段
assert config.metadata.version == "1.0.0"
assert config.factor.n_days == 25
assert config.rotation.select_num == 3
print("\n✓ 测试通过")
def test_asset_pools():
"""测试 2: 资产池配置"""
print("\n" + "=" * 70)
print(" 测试 2: 资产池配置")
print("=" * 70)
config = load_config('rotation_example.yaml')
# 验证股票资产
print(f"\n股票资产 ({len(config.asset_pools.equity)} 只):")
for code, asset in config.asset_pools.equity.items():
print(f" {code}: {asset.name} ({asset.market.value})")
if asset.etf:
print(f" ETF: {asset.etf}")
# 验证商品资产
print(f"\n商品资产 ({len(config.asset_pools.commodity)} 只):")
for code, asset in config.asset_pools.commodity.items():
print(f" {code}: {asset.name}")
# 验证固定收益
print(f"\n固定收益 ({len(config.asset_pools.fixed_income)} 只):")
for code, asset in config.asset_pools.fixed_income.items():
print(f" {code}: {asset.name} (ETF: {asset.etf})")
# 验证市场类型
assert config.asset_pools.equity["399006.SZ"].market == MarketType.CN_EQUITY
assert config.asset_pools.equity["NDX"].market == MarketType.US_EQUITY
assert config.asset_pools.commodity["GC=F"].market == MarketType.COMMODITY
print("\n✓ 测试通过")
def test_threshold_config():
"""测试 3: 阈值配置"""
print("\n" + "=" * 70)
print(" 测试 3: 阈值配置")
print("=" * 70)
config = load_config('rotation_example.yaml')
print(f"\n阈值模式: {config.rotation.threshold.mode}")
print(f" 参考标的: {config.rotation.threshold.dynamic.reference}")
print(f" 倍数: {config.rotation.threshold.dynamic.ratio}")
print(f" 回退启用: {config.rotation.threshold.dynamic.fallback_enabled}")
assert config.rotation.threshold.mode.value == "dynamic"
assert config.rotation.threshold.dynamic.reference == "931862.CSI"
print("\n✓ 测试通过")
def test_data_sources():
"""测试 4: 数据源配置"""
print("\n" + "=" * 70)
print(" 测试 4: 数据源配置")
print("=" * 70)
config = load_config('rotation_example.yaml')
print(f"\n数据源 ({len(config.data.sources)} 个):")
for i, source in enumerate(config.data.sources, 1):
print(f" {i}. {source.type.value}")
print(f" 启用: {source.enabled}")
print(f" 超时: {source.timeout}s")
if source.url:
print(f" URL: {source.url}")
# 验证环境变量替换
flask_api_source = config.data.sources[0]
assert flask_api_source.url == 'https://k3s.tokenpluse.xyz'
print("\n✓ 测试通过")
def test_validation_errors():
"""测试 5: 验证错误处理"""
print("\n" + "=" * 70)
print(" 测试 5: 验证错误处理")
print("=" * 70)
# 测试 1: n_days 超出范围
print("\n[5.1] 测试 n_days 超出范围...")
try:
from framework_v2.config.schemas import FactorConfig
# n_days = 1000超出 5-250 范围)
invalid_config = {
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
"benchmark": {"code": "000300.SH", "name": "沪深300"},
"backtest": {"start_date": "2020-01-01"},
"factor": {"n_days": 1000}, # 错误:超出范围
"data": {
"sources": [{"type": "flask_api", "url": "test"}]
}
}
RotationStrategyConfig(**invalid_config)
print(" ✗ 应该抛出验证错误")
assert False
except Exception as e:
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
# 测试 2: 缺少必需字段
print("\n[5.2] 测试缺少必需字段...")
try:
invalid_config = {
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
# 缺少 benchmark
"backtest": {"start_date": "2020-01-01"},
"data": {
"sources": [{"type": "flask_api", "url": "test"}]
}
}
RotationStrategyConfig(**invalid_config)
print(" ✗ 应该抛出验证错误")
assert False
except Exception as e:
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
# 测试 3: 环境变量未设置
print("\n[5.3] 测试环境变量未设置...")
try:
# 删除环境变量
old_value = os.environ.pop('FLASK_API_URL', None)
invalid_config = {
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
"benchmark": {"code": "000300.SH", "name": "沪深300"},
"backtest": {"start_date": "2020-01-01"},
"data": {
"sources": [{"type": "flask_api", "url": "${FLASK_API_URL}"}]
}
}
RotationStrategyConfig(**invalid_config)
print(" ✗ 应该抛出验证错误")
assert False
except ValueError as e:
print(f" ✓ 正确捕获环境变量错误: {e}")
finally:
# 恢复环境变量
if old_value:
os.environ['FLASK_API_URL'] = old_value
print("\n✓ 测试通过")
def test_env_substitution():
"""测试 6: 环境变量替换"""
print("\n" + "=" * 70)
print(" 测试 6: 环境变量替换")
print("=" * 70)
loader = ConfigLoader()
# 测试 1: 基本替换
print("\n[6.1] 基本替换...")
os.environ['TEST_VAR'] = 'test_value'
config = {
"url": "${TEST_VAR}"
}
result = loader._substitute_env_vars(config)
assert result["url"] == "test_value"
print(f" ✓ ${{TEST_VAR}}{result['url']}")
# 测试 2: 默认值
print("\n[6.2] 默认值...")
config = {
"url": "${NON_EXISTENT_VAR:default_value}"
}
result = loader._substitute_env_vars(config)
assert result["url"] == "default_value"
print(f" ${{NON_EXISTENT_VAR:default_value}}{result['url']}")
# 测试 3: 嵌套结构
print("\n[6.3] 嵌套结构...")
os.environ['API_URL'] = 'https://api.example.com'
config = {
"data": {
"sources": [
{"url": "${API_URL}"}
]
}
}
result = loader._substitute_env_vars(config)
assert result["data"]["sources"][0]["url"] == "https://api.example.com"
print(f" ✓ 嵌套替换成功")
print("\n✓ 测试通过")
if __name__ == "__main__":
print("\n" + "=" * 70)
print(" 配置加载和验证测试")
print("=" * 70)
tests = [
("加载配置文件", test_load_config),
("资产池配置", test_asset_pools),
("阈值配置", test_threshold_config),
("数据源配置", test_data_sources),
("验证错误处理", test_validation_errors),
("环境变量替换", test_env_substitution),
]
passed = 0
failed = 0
for name, test_func in tests:
try:
test_func()
passed += 1
except Exception as e:
print(f"\n✗ 测试失败: {name}")
print(f" 错误: {e}")
import traceback
traceback.print_exc()
failed += 1
print("\n" + "=" * 70)
print(" 测试总结")
print("=" * 70)
print(f" ✓ 通过 - {passed}")
if failed > 0:
print(f" ✗ 失败 - {failed}")
print(f"\n总计: {passed}/{passed + failed} 通过")
print("=" * 70 + "\n")
if failed > 0:
sys.exit(1)

View File

@@ -0,0 +1,468 @@
"""
端到端集成测试:数据获取 → 因子计算 → 数据对齐 → 信号生成
测试场景:
1. 获取纳指美股和创业板A股数据
2. 计算动量因子
3. 对齐到 A 股交易日历
4. 生成 Top-N 信号
5. 验证完整流程
目标:
- 验证 FlaskAPIFetcher 数据获取
- 验证 MomentumFactor 因子计算
- 验证 CrossMarketAligner 数据对齐
- 验证完整流程无数据泄漏
"""
import sys
from pathlib import Path
import pandas as pd
import numpy as np
from typing import Dict
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from framework_v2.shared.data import FlaskAPIFetcher, CrossMarketAligner
from framework_v2.shared.factors.momentum import MomentumFactor
def test_stage1_data_fetch():
"""
阶段 1: 数据获取
获取纳指(^IXIC和创业板399006.SZ数据
"""
print("\n" + "=" * 70)
print(" 阶段 1: 数据获取")
print("=" * 70)
fetcher = FlaskAPIFetcher()
# 获取纳指数据(美股)
print("\n[1.1] 获取纳斯达克指数数据(美股)...")
us_data = fetcher.fetch_indices(
codes=["^IXIC"],
start="2023-01-01",
end="2024-12-31"
)
assert "^IXIC" in us_data, "纳指数据获取失败"
df_nasdaq = us_data["^IXIC"]
print(f"\n纳指数据:")
print(f" 数据量: {len(df_nasdaq)}")
print(f" 日期范围: {df_nasdaq.index[0]} ~ {df_nasdaq.index[-1]}")
print(f" 列: {list(df_nasdaq.columns)}")
print(f" 前 3 行:")
print(df_nasdaq.head(3).to_string())
# 获取创业板数据A股
print("\n[1.2] 获取创业板指数数据A股...")
cn_data = fetcher.fetch_indices(
codes=["399006.SZ"],
start="2023-01-01",
end="2024-12-31"
)
assert "399006.SZ" in cn_data, "创业板数据获取失败"
df_gem = cn_data["399006.SZ"]
print(f"\n创业板数据:")
print(f" 数据量: {len(df_gem)}")
print(f" 日期范围: {df_gem.index[0]} ~ {df_gem.index[-1]}")
print(f" 列: {list(df_gem.columns)}")
print(f" 前 3 行:")
print(df_gem.head(3).to_string())
# 对比日历差异
print(f"\n[1.3] 交易日历对比:")
nasdaq_dates = set(df_nasdaq.index)
gem_dates = set(df_gem.index)
common_dates = nasdaq_dates & gem_dates
only_nasdaq = nasdaq_dates - gem_dates
only_gem = gem_dates - nasdaq_dates
print(f" 纳指交易日: {len(nasdaq_dates)}")
print(f" 创业板交易日: {len(gem_dates)}")
print(f" 共同交易日: {len(common_dates)}")
print(f" 仅纳指交易: {len(only_nasdaq)}")
print(f" 仅创业板交易: {len(only_gem)}")
if len(only_nasdaq) > 0:
print(f" 纳指独有日期示例: {sorted(list(only_nasdaq))[:3]}")
if len(only_gem) > 0:
print(f" 创业板独有日期示例: {sorted(list(only_gem))[:3]}")
print("\n✓ 阶段 1 通过")
return {
"^IXIC": df_nasdaq,
"399006.SZ": df_gem
}
def test_stage2_factor_calculation(data_dict: Dict[str, pd.DataFrame]):
"""
阶段 2: 因子计算
计算动量因子(在原始日历上)
"""
print("\n" + "=" * 70)
print(" 阶段 2: 因子计算(原始日历)")
print("=" * 70)
factor_calc = MomentumFactor(n_days=20)
factors = {}
for code, df in data_dict.items():
print(f"\n[2.1] 计算 {code} 动量因子...")
# compute 方法接受 DataFrame
factor_series = factor_calc.compute(df)
# 转换为 DataFrame 格式
factor_result = pd.DataFrame({
'value': factor_series,
'is_filled': False
})
factors[code] = factor_result
print(f" 因子值数量: {len(factor_result)}")
print(f" 日期范围: {factor_result.index[0]} ~ {factor_result.index[-1]}")
print(f" 前 3 行:")
print(factor_result.head(3).to_string())
# 统计 NaN
nan_count = factor_result['value'].isna().sum()
print(f" NaN 数量: {nan_count} ({nan_count/len(factor_result):.1%})")
# 验证因子值合理
valid_factors = factor_result['value'].dropna()
if len(valid_factors) > 0:
print(f" 因子值范围: {valid_factors.min():.4f} ~ {valid_factors.max():.4f}")
print("\n✓ 阶段 2 通过")
return factors
def test_stage3_data_alignment(
factors: Dict[str, pd.DataFrame],
data_dict: Dict[str, pd.DataFrame]
):
"""
阶段 3: 数据对齐
将因子和收益率对齐到 A 股交易日历
"""
print("\n" + "=" * 70)
print(" 阶段 3: 数据对齐(到 A 股日历)")
print("=" * 70)
fetcher = FlaskAPIFetcher()
# 获取 A 股交易日历(通过 API
print("\n[3.1] 获取 A 股交易日历(通过 API...")
# 裁剪到数据日期范围
data_start = min(df.index[0] for df in data_dict.values())
data_end = max(df.index[-1] for df in data_dict.values())
# 使用 API 获取准确日历
a_share_calendar = fetcher.get_trading_calendar(
market='A',
start=data_start.strftime('%Y-%m-%d'),
end=data_end.strftime('%Y-%m-%d')
)
print(f" A 股交易日: {len(a_share_calendar)}")
print(f" 日期范围: {a_share_calendar[0]} ~ {a_share_calendar[-1]}")
# 创建对齐器
aligner = CrossMarketAligner(target_calendar=a_share_calendar)
# 对齐因子
print("\n[3.2] 对齐因子到 A 股日历...")
aligned_factors = {}
for code, factor_df in factors.items():
print(f"\n 对齐 {code} 因子...")
# 获取原始日历
original_calendar = factor_df.index
# 对齐因子
aligned = aligner.align_factor(
factor_series=factor_df['value'],
source_calendar=original_calendar,
code=code
)
aligned_factors[code] = aligned
# 统计
filled_count = aligned['is_filled'].sum()
print(f" 对齐后天数: {len(aligned)}")
print(f" 填充天数: {filled_count} ({filled_count/len(aligned):.1%})")
print(f" NaN 数量: {aligned['value'].isna().sum()}")
# 对齐收益率
print("\n[3.3] 对齐收益率到 A 股日历...")
aligned_returns = {}
for code, df in data_dict.items():
print(f"\n 对齐 {code} 收益率...")
returns = aligner.align_returns(
close_series=df['close'],
code=code
)
aligned_returns[code] = returns
# 统计
print(f" 对齐后天数: {len(returns)}")
print(f" 收益率范围: {returns.min():.4%} ~ {returns.max():.4%}")
print(f" NaN 数量: {returns.isna().sum()}")
print(f" 零收益率天数: {(returns == 0).sum()} (休市日)")
# 验证对齐结果
print("\n[3.4] 验证对齐结果...")
# 1. 所有 DataFrame 应该有相同的索引
indices = [df.index for df in aligned_factors.values()]
indices.extend([s.index for s in aligned_returns.values()])
for i, idx1 in enumerate(indices):
for j, idx2 in enumerate(indices):
if i != j:
assert idx1.equals(idx2), f"索引 {i}{j} 不一致"
print(f" ✓ 所有数据对齐到同一日历: {len(indices[0])}")
print(f" ✓ 日期范围: {indices[0][0]} ~ {indices[0][-1]}")
# 2. 验证收益率无 NaN
for code, returns in aligned_returns.items():
assert returns.isna().sum() == 0, f"{code} 收益率包含 NaN"
print(f" ✓ 收益率无 NaN")
# 3. 验证休市日收益率 = 0
for code, returns in aligned_returns.items():
zero_days = (returns == 0).sum()
print(f" {code} 休市日收益率 = 0: {zero_days}")
print("\n✓ 阶段 3 通过")
return aligned_factors, aligned_returns
def test_stage4_signal_generation(
aligned_factors: Dict[str, pd.DataFrame],
aligned_returns: Dict[str, pd.Series]
):
"""
阶段 4: 信号生成
根据对齐后的因子生成 Top-N 信号
"""
print("\n" + "=" * 70)
print(" 阶段 4: 信号生成")
print("=" * 70)
# 合并因子值
print("\n[4.1] 合并因子值...")
factor_values = pd.DataFrame()
for code, factor_df in aligned_factors.items():
factor_values[code] = factor_df['value']
print(f" 合并后形状: {factor_values.shape}")
print(f" 列: {list(factor_values.columns)}")
print(f" 前 3 行:")
print(factor_values.head(3).to_string())
# 简单信号:选择因子值最高的标的
print("\n[4.2] 生成信号Top-1...")
# 跳过全为 NaN 的行
valid_rows = factor_values.dropna(how='all').index
factor_valid = factor_values.loc[valid_rows]
signals = pd.DataFrame()
signals['best'] = factor_valid.idxmax(axis=1)
signals['best_value'] = factor_valid.max(axis=1)
print(f" 信号数量: {len(signals)}")
print(f" 前 10 个信号:")
print(signals.head(10).to_string())
# 统计选择分布
print(f"\n[4.3] 标的选择分布:")
distribution = signals['best'].value_counts()
for code, count in distribution.items():
pct = count / len(signals)
print(f" {code}: {count} 天 ({pct:.1%})")
# 验证信号与收益率对齐
print("\n[4.4] 验证信号与收益率对齐...")
returns_df = pd.DataFrame(aligned_returns)
# 裁剪到共同日期
common_dates = signals.index.intersection(returns_df.index)
signals_aligned = signals.loc[common_dates]
returns_aligned = returns_df.loc[common_dates]
print(f" 信号日期: {len(signals)}{len(signals_aligned)}")
print(f" 收益日期: {len(returns_df)}{len(returns_aligned)}")
print(f" 共同日期: {len(common_dates)}")
assert signals_aligned.index.equals(returns_aligned.index), "信号与收益日期不一致"
print(f" ✓ 信号与收益率日期一致")
print("\n✓ 阶段 4 通过")
return signals_aligned, returns_aligned
def test_stage5_strategy_returns(signals: pd.DataFrame, returns: pd.DataFrame):
"""
阶段 5: 计算策略收益
根据信号计算策略净值曲线
"""
print("\n" + "=" * 70)
print(" 阶段 5: 计算策略收益")
print("=" * 70)
print("\n[5.1] 计算策略日收益...")
strategy_returns = pd.Series(index=returns.index, dtype=float)
for date in returns.index:
if date in signals.index:
best_code = signals.loc[date, 'best']
strategy_returns[date] = returns.loc[date, best_code]
else:
strategy_returns[date] = 0.0
# 填充 NaN
strategy_returns = strategy_returns.fillna(0.0)
print(f" 策略收益天数: {len(strategy_returns)}")
print(f" 收益范围: {strategy_returns.min():.4%} ~ {strategy_returns.max():.4%}")
print("\n[5.2] 计算累计收益...")
cumulative_returns = (1 + strategy_returns).cumprod() - 1
print(f" 最终累计收益: {cumulative_returns.iloc[-1]:.2%}")
print(f" 最大累计收益: {cumulative_returns.max():.2%}")
print(f" 最小累计收益: {cumulative_returns.min():.2%}")
print("\n[5.3] 计算年化收益和最大回撤...")
# 年化收益
total_days = len(strategy_returns)
annual_return = (1 + cumulative_returns.iloc[-1]) ** (252 / total_days) - 1
print(f" 年化收益: {annual_return:.2%}")
# 最大回撤
rolling_max = cumulative_returns.cummax()
drawdown = (cumulative_returns - rolling_max) / (1 + rolling_max)
max_drawdown = drawdown.min()
print(f" 最大回撤: {max_drawdown:.2%}")
print("\n[5.4] 策略收益 vs 基准对比...")
# 基准:等权持有
benchmark_returns = returns.mean(axis=1)
benchmark_cumulative = (1 + benchmark_returns).cumprod() - 1
print(f" 策略累计收益: {cumulative_returns.iloc[-1]:.2%}")
print(f" 基准累计收益: {benchmark_cumulative.iloc[-1]:.2%}")
print(f" 超额收益: {cumulative_returns.iloc[-1] - benchmark_cumulative.iloc[-1]:.2%}")
print("\n✓ 阶段 5 通过")
return strategy_returns, cumulative_returns
def run_full_pipeline():
"""
运行完整流程
"""
print("\n" + "=" * 70)
print(" 端到端集成测试:数据获取 → 因子计算 → 数据对齐 → 信号生成")
print("=" * 70)
print("\n测试标的:")
print(" - 纳斯达克指数 (^IXIC) - 美股")
print(" - 创业板指数 (399006.SZ) - A 股")
print("\n时间范围: 2023-01-01 ~ 2024-12-31")
print("\n" + "=" * 70)
try:
# 阶段 1: 数据获取
data_dict = test_stage1_data_fetch()
# 阶段 2: 因子计算
factors = test_stage2_factor_calculation(data_dict)
# 阶段 3: 数据对齐
aligned_factors, aligned_returns = test_stage3_data_alignment(
factors, data_dict
)
# 阶段 4: 信号生成
signals, returns = test_stage4_signal_generation(
aligned_factors, aligned_returns
)
# 阶段 5: 策略收益
strategy_returns, cumulative_returns = test_stage5_strategy_returns(
signals, returns
)
# 总结
print("\n" + "=" * 70)
print(" 测试总结")
print("=" * 70)
print("\n✅ 所有阶段通过!")
print("\n流程验证:")
print(" ✓ 数据获取: FlaskAPIFetcher 成功获取线上数据")
print(" ✓ 因子计算: MomentumFactor 在原始日历计算")
print(" ✓ 数据对齐: CrossMarketAligner 对齐到 A 股日历")
print(" ✓ 信号生成: Top-N 选择逻辑正确")
print(" ✓ 收益计算: 策略净值曲线生成成功")
print("\n关键验证:")
print(" ✓ 跨市场日历差异已处理")
print(" ✓ 休市日收益率 = 0% (无 ffill 陷阱)")
print(" ✓ 收益率无 NaN")
print(" ✓ 信号与收益日期一致")
print("\n" + "=" * 70 + "\n")
return True
except Exception as e:
print(f"\n✗ 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = run_full_pipeline()
if success:
print("🎉 端到端测试通过!")
sys.exit(0)
else:
print("❌ 端到端测试失败!")
sys.exit(1)

View File

@@ -0,0 +1,198 @@
"""
测试 FlaskAPIFetcher
验证:
1. 获取指数数据
2. 获取 ETF 数据
3. 获取交易日历
4. 健康检查
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from framework_v2.shared.data import FlaskAPIFetcher
def test_health_check():
"""测试 1: 健康检查"""
print("\n" + "=" * 60)
print(" 测试 1: 健康检查")
print("=" * 60)
fetcher = FlaskAPIFetcher()
health = fetcher.get_health()
print(f"\n健康状态: {health}")
assert health.get('available'), "API 服务不可用"
print("\n✓ 测试通过")
def test_fetch_indices():
"""测试 2: 获取指数数据"""
print("\n" + "=" * 60)
print(" 测试 2: 获取指数数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# 获取沪深 300 + 中证 500
codes = ["000300.SH", "000905.SH"]
data = fetcher.fetch_indices(
codes=codes,
start="2024-01-01",
end="2024-03-31"
)
# 验证
assert len(data) == 2, f"应该返回 2 只指数,实际 {len(data)}"
for code, df in data.items():
print(f"\n{code}:")
print(f" 数据量: {len(df)}")
print(f" 列: {list(df.columns)}")
print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
assert len(df) > 0, f"{code} 数据为空"
assert 'close' in df.columns, f"{code} 缺少 close 列"
assert 'volume' in df.columns, f"{code} 缺少 volume 列"
print("\n✓ 测试通过")
def test_fetch_etf():
"""测试 3: 获取 ETF 数据"""
print("\n" + "=" * 60)
print(" 测试 3: 获取 ETF 数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# 获取沪深 300 ETF
codes = ["510300.SH"]
data = fetcher.fetch_etf(
codes=codes,
start="2024-01-01",
end="2024-03-31"
)
# 验证
assert len(data) == 1, f"应该返回 1 只 ETF实际 {len(data)}"
code = "510300.SH"
df = data[code]
print(f"\n{code}:")
print(f" 价格数据: {len(df)}")
print(f" 列: {list(df.columns)}")
# 验证附加信息
nav = df.attrs.get('nav')
if nav is not None:
print(f" 净值数据: {len(nav)}")
premium = df.attrs.get('latest_premium')
if premium is not None:
print(f" 最新溢价率: {premium:.2f}%")
assert len(df) > 0, f"{code} 数据为空"
assert 'close' in df.columns, f"{code} 缺少 close 列"
print("\n✓ 测试通过")
def test_trading_calendar():
"""测试 4: 获取交易日历"""
print("\n" + "=" * 60)
print(" 测试 4: 获取交易日历")
print("=" * 60)
fetcher = FlaskAPIFetcher()
# A股日历
calendar_a = fetcher.get_trading_calendar(market='A')
print(f"\nA股交易日历:")
print(f" 总天数: {len(calendar_a)}")
print(f" 日期范围: {calendar_a[0]} ~ {calendar_a[-1]}")
print(f" 前 5 天: {calendar_a[:5].tolist()}")
assert len(calendar_a) > 0, "A股日历为空"
# 美股日历
calendar_us = fetcher.get_trading_calendar(market='US')
print(f"\n美股交易日历:")
print(f" 总天数: {len(calendar_us)}")
assert len(calendar_us) > 0, "美股日历为空"
print("\n✓ 测试通过")
def test_benchmark():
"""测试 5: 获取基准数据"""
print("\n" + "=" * 60)
print(" 测试 5: 获取基准数据")
print("=" * 60)
fetcher = FlaskAPIFetcher()
benchmark = fetcher.get_benchmark(
code="000300.SH",
start="2024-01-01",
end="2024-03-31"
)
print(f"\n沪深 300 基准:")
print(f" 数据量: {len(benchmark)}")
print(f" 日期范围: {benchmark.index[0]} ~ {benchmark.index[-1]}")
print(f" 价格范围: {benchmark.min():.2f} ~ {benchmark.max():.2f}")
assert len(benchmark) > 0, "基准数据为空"
assert isinstance(benchmark, pd.Series), "基准数据应该是 Series"
print("\n✓ 测试通过")
if __name__ == "__main__":
import pandas as pd
print("\n" + "=" * 60)
print(" FlaskAPIFetcher 测试")
print("=" * 60)
tests = [
("健康检查", test_health_check),
("指数数据", test_fetch_indices),
("ETF 数据", test_fetch_etf),
("交易日历", test_trading_calendar),
("基准数据", test_benchmark),
]
passed = 0
failed = 0
for name, test_func in tests:
try:
test_func()
passed += 1
except Exception as e:
print(f"\n✗ 测试失败: {name}")
print(f" 错误: {e}")
import traceback
traceback.print_exc()
failed += 1
print("\n" + "=" * 60)
print(" 测试总结")
print("=" * 60)
print(f" ✓ 通过 - {passed}")
if failed > 0:
print(f" ✗ 失败 - {failed}")
print(f"\n总计: {passed}/{passed + failed} 通过")
print("=" * 60 + "\n")

View File

@@ -0,0 +1,281 @@
"""
测试扁平化资产池配置
验证:
1. 扁平化配置加载
2. 按市场分组
3. 信号/交易标的获取
4. 跨市场映射
"""
import sys
from pathlib import Path
import os
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from framework_v2.config import load_config
from framework_v2.config.schemas import GroupConfig
def test_flat_config_load():
"""测试 1: 加载扁平化配置"""
print("\n" + "=" * 70)
print(" 测试 1: 加载扁平化配置")
print("=" * 70)
# 设置环境变量
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
os.environ['TUSHARE_TOKEN'] = 'test_token'
# 加载配置
config = load_config('rotation_global.yaml')
print(f"\n✓ 配置加载成功")
print(f" 版本: {config.metadata.version}")
print(f" 策略: {config.metadata.strategy}")
print(f" 总标的数: {config.asset_pools.count()}")
# 验证基本字段
assert config.metadata.version == "2.0.0"
assert config.asset_pools.count() == 13 # 12 个标的
print("\n✓ 测试通过")
def test_market_grouping():
"""测试 2: 按市场分组"""
print("\n" + "=" * 70)
print(" 测试 2: 按市场分组")
print("=" * 70)
config = load_config('rotation_global.yaml')
# 获取所有市场
markets = config.asset_pools.groups
print(f"\n市场类型 ({len(markets)} 个):")
for market in markets:
count = config.asset_pools.count(market)
print(f" {market}: {count}")
# 按市场分组
by_group = config.asset_pools.by_group
print(f"\n市场分组:")
for market, assets in by_group.items():
print(f"\n {market} ({len(assets)} 只):")
for code, asset in assets.items():
print(f" {code}: {asset.name}")
# 验证市场数量
assert len(markets) == 7 # US_TECH, CN_GROWTH, JP_BROAD, EU_BROAD, HK_TECH, COMMODITY, FIXED_INCOME # US_EQUITY, CN_EQUITY, JP_EQUITY, EU_EQUITY, HK_EQUITY, COMMODITY, FIXED_INCOME
assert config.asset_pools.count('US_TECH') == 2
assert config.asset_pools.count('CN_GROWTH') == 3
assert config.asset_pools.count('COMMODITY') == 3
assert config.asset_pools.count('FIXED_INCOME') == 1
print("\n✓ 测试通过")
def test_signal_trade_codes():
"""测试 3: 信号和交易标的"""
print("\n" + "=" * 70)
print(" 测试 3: 信号和交易标的")
print("=" * 70)
config = load_config('rotation_global.yaml')
# 获取所有信号标的
signal_codes = config.asset_pools.get_signal_codes()
print(f"\n信号标的 (13 个):")
for code in signal_codes:
print(f" {code}")
# 获取所有交易标的
trade_codes = config.asset_pools.get_trade_codes()
print(f"\n交易标的 (13 个):")
for code in trade_codes:
print(f" {code}")
# 获取特定市场的信号标的
us_signals = config.asset_pools.get_signal_codes('US_TECH')
print(f"\n美股信号标的: {us_signals}")
# 验证
assert len(signal_codes) == 13
assert len(trade_codes) == 13
assert 'NDX' in signal_codes
assert '513100.SH' in trade_codes
assert len(us_signals) == 2
print("\n✓ 测试通过")
def test_signal_to_trade_mapping():
"""测试 4: 信号→交易映射"""
print("\n" + "=" * 70)
print(" 测试 4: 信号→交易映射")
print("=" * 70)
config = load_config('rotation_global.yaml')
# 获取映射
mapping = config.asset_pools.get_signal_to_trade_mapping()
print(f"\n信号→交易映射:")
for signal, trade in mapping.items():
asset = config.asset_pools.assets.get(signal)
cross_market = "" if asset.signal_source == asset.trade_source else ""
print(f" {cross_market} {signal}{trade}")
# 验证跨市场标的
print(f"\n跨市场标的:")
for code, asset in config.asset_pools.assets.items():
if asset.is_cross_market:
print(f" {code}: {asset.signal_source}{asset.trade_source}")
# 验证映射
assert mapping['NDX'] == '513100.SH'
assert mapping['399006.SZ'] == '159915.SZ'
assert mapping['GC=F'] == '518880.SH'
assert mapping['931862.CSI'] == '931862.CSI' # 非跨市场
# 验证跨市场属性
assert config.asset_pools.assets['NDX'].is_cross_market == True
assert config.asset_pools.assets['931862.CSI'].is_cross_market == False
print("\n✓ 测试通过")
def test_market_specific_mapping():
"""测试 5: 特定市场映射"""
print("\n" + "=" * 70)
print(" 测试 5: 特定市场映射")
print("=" * 70)
config = load_config('rotation_global.yaml')
# 获取美股映射
us_mapping = config.asset_pools.get_signal_to_trade_mapping('US_TECH')
print(f"\n美股映射:")
for signal, trade in us_mapping.items():
print(f" {signal}{trade}")
# 获取商品映射
commodity_mapping = config.asset_pools.get_signal_to_trade_mapping('COMMODITY')
print(f"\n商品映射:")
for signal, trade in commodity_mapping.items():
print(f" {signal}{trade}")
# 验证
assert len(us_mapping) == 2
assert len(commodity_mapping) == 3
assert us_mapping['NDX'] == '513100.SH'
assert commodity_mapping['GC=F'] == '518880.SH'
print("\n✓ 测试通过")
def test_diversification_config():
"""测试 6: 分散化配置"""
print("\n" + "=" * 70)
print(" 测试 6: 分散化配置")
print("=" * 70)
config = load_config('rotation_global.yaml')
print(f"\n轮动配置:")
print(f" 选股数量: {config.rotation.select_num}")
print(f" 分散化: {config.rotation.diversified}")
print(f" 分散化分组: {config.rotation.diversification_groups}")
# 验证默认配置(全局模式)
assert config.rotation.select_num == 5
assert config.rotation.diversified == False
assert config.rotation.diversification_groups is None
print("\n✓ 测试通过")
def test_asset_config_details():
"""测试 7: 标的配置详情"""
print("\n" + "=" * 70)
print(" 测试 7: 标的配置详情")
print("=" * 70)
config = load_config('rotation_global.yaml')
# 检查纳指配置
ndx = config.asset_pools.assets['NDX']
print(f"\n纳指100 配置:")
print(f" 名称: {ndx.name}")
print(f" 市场: {ndx.group}")
print(f" 信号来源: {ndx.signal_source}")
print(f" 交易来源: {ndx.trade_source}")
print(f" 跨市场: {ndx.is_cross_market}")
print(f" 描述: {ndx.description}")
# 检查短债配置
bond = config.asset_pools.assets['931862.CSI']
print(f"\n短债指数 配置:")
print(f" 名称: {bond.name}")
print(f" 市场: {bond.group}")
print(f" 信号来源: {bond.signal_source}")
print(f" 交易来源: {bond.trade_source}")
print(f" 跨市场: {bond.is_cross_market}")
# 验证
assert ndx.name == "纳指100"
assert ndx.group == 'US_TECH'
assert ndx.signal_source == "NDX"
assert ndx.trade_source == "513100.SH"
assert ndx.is_cross_market == True
assert bond.signal_source == bond.trade_source
assert bond.is_cross_market == False
print("\n✓ 测试通过")
if __name__ == "__main__":
print("\n" + "=" * 70)
print(" 扁平化资产池配置测试")
print("=" * 70)
tests = [
("加载扁平化配置", test_flat_config_load),
("按市场分组", test_market_grouping),
("信号和交易标的", test_signal_trade_codes),
("信号→交易映射", test_signal_to_trade_mapping),
("特定市场映射", test_market_specific_mapping),
("分散化配置", test_diversification_config),
("标的配置详情", test_asset_config_details),
]
passed = 0
failed = 0
for name, test_func in tests:
try:
test_func()
passed += 1
except Exception as e:
print(f"\n✗ 测试失败: {name}")
print(f" 错误: {e}")
import traceback
traceback.print_exc()
failed += 1
print("\n" + "=" * 70)
print(" 测试总结")
print("=" * 70)
print(f" ✓ 通过 - {passed}")
if failed > 0:
print(f" ✗ 失败 - {failed}")
print(f"\n总计: {passed}/{passed + failed} 通过")
print("=" * 70 + "\n")
if failed > 0:
sys.exit(1)

View File

@@ -0,0 +1,116 @@
"""
因子对比验证测试
验证新框架的 MomentumFactor 与现有实现输出一致
"""
import sys
import pandas as pd
import numpy as np
from pathlib import Path
# 添加项目根目录
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
def test_momentum_factor_parity():
"""验证新因子与旧因子输出一致"""
print("=" * 60)
print(" MomentumFactor 对比测试")
print("=" * 60)
# 1. 加载测试数据
print("\n1. 加载测试数据...")
test_data_path = project_root / 'data' / 'index_history_data'
# 使用纳指100数据测试
import glob
ndx_files = glob.glob(str(test_data_path / '*NDX*'))
if ndx_files:
test_file = ndx_files[0]
data = pd.read_csv(test_file, index_col=0, parse_dates=True)
print(f" ✓ 加载 {test_file}")
print(f" 数据范围: {data.index[0]} ~ {data.index[-1]}")
print(f" 数据长度: {len(data)}")
else:
print(" ⚠ 未找到测试数据,使用模拟数据")
# 生成模拟数据
np.random.seed(42)
dates = pd.date_range('2020-01-01', periods=500, freq='B')
prices = 100 * np.cumprod(1 + np.random.randn(500) * 0.02)
data = pd.DataFrame({
'close': prices,
'open': prices * 0.99,
'high': prices * 1.01,
'low': prices * 0.98,
'volume': np.random.randint(1000000, 10000000, 500)
}, index=dates)
# 2. 计算旧因子
print("\n2. 计算旧因子strategies/shared/factors/momentum.py...")
from strategies.shared.factors.momentum import MomentumFactor as OldMomentum
old_factor = OldMomentum(n_days=25, weighted=True, crash_filter=True)
old_result = old_factor.compute(data)
print(f" ✓ 旧因子计算完成")
print(f" 结果范围: {old_result.min():.4f} ~ {old_result.max():.4f}")
print(f" NaN 数量: {old_result.isna().sum()}")
# 3. 计算新因子
print("\n3. 计算新因子framework_v2/shared/factors/momentum.py...")
from framework_v2.shared.factors.momentum import MomentumFactor as NewMomentum
new_factor = NewMomentum(n_days=25, weighted=True, crash_filter=True)
new_result = new_factor.compute(data)
print(f" ✓ 新因子计算完成")
print(f" 结果范围: {new_result.min():.4f} ~ {new_result.max():.4f}")
print(f" NaN 数量: {new_result.isna().sum()}")
# 4. 对比结果
print("\n4. 对比结果...")
# 检查索引是否一致
if not old_result.index.equals(new_result.index):
print(" ✗ 索引不一致")
return False
print(" ✓ 索引一致")
# 检查数值差异
diff = (old_result - new_result).abs()
max_diff = diff.max()
mean_diff = diff.mean()
print(f" 最大差异: {max_diff:.6e}")
print(f" 平均差异: {mean_diff:.6e}")
# 允许浮点数精度误差1e-10
tolerance = 1e-10
if max_diff < tolerance:
print(f" ✓ 差异在容差范围内 (< {tolerance:.0e})")
print("\n" + "=" * 60)
print(" ✓ 测试通过:新旧因子输出完全一致!")
print("=" * 60)
return True
else:
print(f" ✗ 差异超出容差范围")
print("\n" + "=" * 60)
print(" ✗ 测试失败:新旧因子输出不一致")
print("=" * 60)
# 打印前10个差异点
diff_nonzero = diff[diff > tolerance]
if len(diff_nonzero) > 0:
print(f"\n 前10个差异点:")
for date, val in diff_nonzero.head(10).items():
old_val = old_result.loc[date]
new_val = new_result.loc[date]
print(f" {date}: 旧={old_val:.6f}, 新={new_val:.6f}, 差异={val:.6e}")
return False
if __name__ == '__main__':
success = test_momentum_factor_parity()
sys.exit(0 if success else 1)