refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
3
archive/framework_v2/tests/__init__.py
Normal file
3
archive/framework_v2/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
框架 V2 测试
|
||||
"""
|
||||
292
archive/framework_v2/tests/test_alignment.py
Normal file
292
archive/framework_v2/tests/test_alignment.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
数据对齐测试
|
||||
|
||||
验证跨市场对齐器的正确性
|
||||
"""
|
||||
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
def test_factor_alignment():
|
||||
"""测试因子对齐"""
|
||||
from framework_v2.shared.data.alignment import CrossMarketAligner
|
||||
|
||||
print("=" * 60)
|
||||
print(" 测试 1: 因子对齐")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建模拟数据
|
||||
# 美股日历(比 A 股少几天)
|
||||
us_dates = pd.date_range('2024-01-01', periods=10, freq='B')
|
||||
us_dates = us_dates.delete([2, 5]) # 删除 2 天(模拟美股假日)
|
||||
|
||||
# A 股日历(完整)
|
||||
cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')
|
||||
|
||||
# 因子值(美股日历)
|
||||
factor_values = [0.1, 0.15, 0.18, 0.16, 0.20, 0.22, 0.25, 0.23]
|
||||
factor_series = pd.Series(factor_values, index=us_dates)
|
||||
|
||||
print(f"\n源日历(美股): {len(us_dates)} 天")
|
||||
print(f"目标日历(A股): {len(cn_dates)} 天")
|
||||
print(f"因子值(源日历):")
|
||||
print(factor_series)
|
||||
|
||||
# 对齐
|
||||
aligner = CrossMarketAligner(target_calendar=cn_dates)
|
||||
aligned = aligner.align_factor(factor_series, us_dates, code='^GSPC')
|
||||
|
||||
print(f"\n对齐后因子值:")
|
||||
print(aligned)
|
||||
|
||||
# 验证
|
||||
assert len(aligned) == len(cn_dates), "对齐后长度应该等于目标日历"
|
||||
assert 'value' in aligned.columns, "应该有 'value' 列"
|
||||
assert 'is_filled' in aligned.columns, "应该有 'is_filled' 列"
|
||||
|
||||
# 检查填充值
|
||||
filled_count = aligned['is_filled'].sum()
|
||||
print(f"\n填充值数量: {filled_count}")
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
return True
|
||||
|
||||
|
||||
def test_returns_alignment():
|
||||
"""测试收益率对齐"""
|
||||
from framework_v2.shared.data.alignment import CrossMarketAligner
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 2: 收益率对齐")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建模拟价格数据
|
||||
# 美股日历
|
||||
us_dates = pd.date_range('2024-01-01', periods=10, freq='B')
|
||||
us_dates = us_dates.delete([2, 5])
|
||||
|
||||
# A 股日历
|
||||
cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')
|
||||
|
||||
# 价格(美股日历)
|
||||
prices = [100, 101, 102, 101, 103, 104, 105, 104]
|
||||
close_series = pd.Series(prices, index=us_dates, dtype=float)
|
||||
|
||||
print(f"\n源价格(美股日历):")
|
||||
print(close_series)
|
||||
|
||||
# 对齐收益率
|
||||
aligner = CrossMarketAligner(target_calendar=cn_dates)
|
||||
returns = aligner.align_returns(close_series, code='^GSPC')
|
||||
|
||||
print(f"\n对齐后收益率(A股日历):")
|
||||
print(returns)
|
||||
|
||||
# 验证
|
||||
assert len(returns) == len(cn_dates), "收益率长度应该等于目标日历"
|
||||
assert returns.index.equals(cn_dates), "收益率索引应该等于目标日历"
|
||||
assert not returns.isna().any(), "收益率不应该有 NaN"
|
||||
assert returns.iloc[0] == 0.0, "首日收益率应该为 0"
|
||||
|
||||
# 检查休市日收益率(应该为 0)
|
||||
# 美股休市日,价格 ffill,收益率 = 0
|
||||
print(f"\n收益率统计:")
|
||||
print(f" 最小值: {returns.min():.4f}")
|
||||
print(f" 最大值: {returns.max():.4f}")
|
||||
print(f" 均值: {returns.mean():.4f}")
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
return True
|
||||
|
||||
|
||||
def test_multi_asset_alignment():
|
||||
"""测试多标的对齐"""
|
||||
from framework_v2.shared.data.alignment import CrossMarketAligner
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 3: 多标的收益率对齐")
|
||||
print("=" * 60)
|
||||
|
||||
# A 股日历
|
||||
cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')
|
||||
|
||||
# 多个标的(不同日历)
|
||||
close_dict = {}
|
||||
|
||||
# 标的 1: 完整数据
|
||||
prices1 = 100 + np.cumsum(np.random.randn(10))
|
||||
close_dict['^GSPC'] = pd.Series(prices1, index=cn_dates, dtype=float)
|
||||
|
||||
# 标的 2: 少 2 天
|
||||
us_dates2 = cn_dates.delete([2, 5])
|
||||
prices2 = 100 + np.cumsum(np.random.randn(8))
|
||||
close_dict['^IXIC'] = pd.Series(prices2, index=us_dates2, dtype=float)
|
||||
|
||||
# 标的 3: 完整数据
|
||||
prices3 = 100 + np.cumsum(np.random.randn(10))
|
||||
close_dict['931862.CSI'] = pd.Series(prices3, index=cn_dates, dtype=float)
|
||||
|
||||
print(f"\n标的数量: {len(close_dict)}")
|
||||
for code, close in close_dict.items():
|
||||
print(f" {code}: {len(close)} 天")
|
||||
|
||||
# 对齐多标的
|
||||
aligner = CrossMarketAligner(target_calendar=cn_dates)
|
||||
returns_df = aligner.align_multi_asset(close_dict)
|
||||
|
||||
print(f"\n对齐后收益率 DataFrame:")
|
||||
print(returns_df.head())
|
||||
|
||||
# 验证
|
||||
assert len(returns_df) == len(cn_dates), "收益率 DataFrame 长度应该等于目标日历"
|
||||
assert returns_df.index.equals(cn_dates), "索引应该等于目标日历"
|
||||
assert not returns_df.isna().any().any(), "收益率 DataFrame 不应该有 NaN"
|
||||
assert len(returns_df.columns) == len(close_dict), "列数应该等于标的数"
|
||||
|
||||
print(f"\n✓ 测试通过")
|
||||
return True
|
||||
|
||||
|
||||
def test_signal_returns_alignment():
|
||||
"""测试信号与收益率对齐"""
|
||||
from framework_v2.shared.data.alignment import CrossMarketAligner
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 4: 信号与收益率对齐验证")
|
||||
print("=" * 60)
|
||||
|
||||
# A 股日历
|
||||
cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')
|
||||
|
||||
# 信号(多 2 天)
|
||||
signal_dates = cn_dates.union(pd.date_range('2024-01-15', periods=2, freq='B'))
|
||||
signals = pd.DataFrame({
|
||||
'signal': ['^GSPC,^IXIC'] * len(signal_dates)
|
||||
}, index=signal_dates)
|
||||
|
||||
# 收益率(正常)
|
||||
returns_df = pd.DataFrame(
|
||||
np.random.randn(len(cn_dates), 2) * 0.02,
|
||||
index=cn_dates,
|
||||
columns=['^GSPC', '^IXIC']
|
||||
)
|
||||
|
||||
print(f"\n信号: {len(signals)} 天")
|
||||
print(f"收益: {len(returns_df)} 天")
|
||||
|
||||
# 验证对齐
|
||||
aligner = CrossMarketAligner(target_calendar=cn_dates)
|
||||
aligned_signals, aligned_returns = aligner.validate_alignment(signals, returns_df)
|
||||
|
||||
print(f"\n对齐后:")
|
||||
print(f" 信号: {len(aligned_signals)} 天")
|
||||
print(f" 收益: {len(aligned_returns)} 天")
|
||||
|
||||
# 验证
|
||||
assert len(aligned_signals) == len(aligned_returns), "对齐后长度应该一致"
|
||||
assert aligned_signals.index.equals(aligned_returns.index), "索引应该一致"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
return True
|
||||
|
||||
|
||||
def test_ffill_trap():
|
||||
"""测试 ffill 陷阱(错误 vs 正确)"""
|
||||
from framework_v2.shared.data.alignment import CrossMarketAligner
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 5: ffill 陷阱对比")
|
||||
print("=" * 60)
|
||||
|
||||
# 美股日历
|
||||
us_dates = pd.date_range('2024-01-01', periods=5, freq='B')
|
||||
us_dates = us_dates.delete([2]) # 第 3 天休市
|
||||
|
||||
# A 股日历
|
||||
cn_dates = pd.date_range('2024-01-01', periods=5, freq='B')
|
||||
|
||||
# 价格
|
||||
prices = pd.Series([100, 101, 102, 103], index=us_dates, dtype=float)
|
||||
|
||||
print(f"\n原始价格(美股日历):")
|
||||
print(prices)
|
||||
|
||||
# ❌ 错误做法:先计算收益率,再 ffill
|
||||
print("\n❌ 错误做法:先 pct_change,再 reindex")
|
||||
returns_wrong = prices.pct_change()
|
||||
print("步骤 1 - 收益率:")
|
||||
print(returns_wrong)
|
||||
|
||||
returns_wrong_aligned = returns_wrong.reindex(cn_dates, method='ffill')
|
||||
print("\n步骤 2 - reindex + ffill:")
|
||||
print(returns_wrong_aligned)
|
||||
print("⚠ 问题:第 3 天复制了第 2 天的 +1% 收益率!")
|
||||
|
||||
# ✅ 正确做法:先 ffill 价格,再计算收益率
|
||||
print("\n✅ 正确做法:先 reindex 价格,再 pct_change")
|
||||
prices_aligned = prices.reindex(cn_dates, method='ffill')
|
||||
print("步骤 1 - 价格 reindex:")
|
||||
print(prices_aligned)
|
||||
|
||||
returns_correct = prices_aligned.pct_change(fill_method=None)
|
||||
returns_correct.iloc[0] = 0.0
|
||||
print("\n步骤 2 - pct_change:")
|
||||
print(returns_correct)
|
||||
print("✓ 第 3 天收益率 = 0%(价格不变)")
|
||||
|
||||
# 验证
|
||||
aligner = CrossMarketAligner(target_calendar=cn_dates)
|
||||
returns = aligner.align_returns(prices, code='TEST')
|
||||
|
||||
assert returns.iloc[2] == 0.0, "休市日收益率应该为 0"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("\n" + "=" * 60)
|
||||
print(" 跨市场数据对齐器测试")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
("因子对齐", test_factor_alignment),
|
||||
("收益率对齐", test_returns_alignment),
|
||||
("多标的对齐", test_multi_asset_alignment),
|
||||
("信号与收益对齐", test_signal_returns_alignment),
|
||||
("ffill 陷阱", test_ffill_trap),
|
||||
]
|
||||
|
||||
results = []
|
||||
for name, test_func in tests:
|
||||
try:
|
||||
success = test_func()
|
||||
results.append((name, success))
|
||||
except Exception as e:
|
||||
print(f"\n✗ {name} 失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
results.append((name, False))
|
||||
|
||||
# 总结
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试总结")
|
||||
print("=" * 60)
|
||||
|
||||
passed = sum(1 for _, success in results if success)
|
||||
total = len(results)
|
||||
|
||||
for name, success in results:
|
||||
status = "✓ 通过" if success else "✗ 失败"
|
||||
print(f" {status} - {name}")
|
||||
|
||||
print(f"\n总计: {passed}/{total} 通过")
|
||||
|
||||
sys.exit(0 if passed == total else 1)
|
||||
285
archive/framework_v2/tests/test_config.py
Normal file
285
archive/framework_v2/tests/test_config.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
测试配置加载和验证
|
||||
|
||||
验证:
|
||||
1. 配置文件加载
|
||||
2. Pydantic Schema 验证
|
||||
3. 环境变量替换
|
||||
4. 错误处理
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.config import load_config, ConfigLoader
|
||||
from framework_v2.config.schemas import RotationStrategyConfig, MarketType
|
||||
|
||||
|
||||
def test_load_config():
|
||||
"""测试 1: 加载配置文件"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 1: 加载配置文件")
|
||||
print("=" * 70)
|
||||
|
||||
# 设置环境变量(模拟)
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
os.environ['TUSHARE_TOKEN'] = 'test_token_123'
|
||||
|
||||
# 加载配置
|
||||
config = load_config('rotation_example.yaml')
|
||||
|
||||
print(f"\n✓ 配置加载成功")
|
||||
print(f" 版本: {config.metadata.version}")
|
||||
print(f" 策略: {config.metadata.strategy}")
|
||||
print(f" 资产池: {len(config.asset_pools.equity)} 股票, "
|
||||
f"{len(config.asset_pools.commodity)} 商品, "
|
||||
f"{len(config.asset_pools.fixed_income)} 债券")
|
||||
|
||||
# 验证基本字段
|
||||
assert config.metadata.version == "1.0.0"
|
||||
assert config.factor.n_days == 25
|
||||
assert config.rotation.select_num == 3
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_asset_pools():
|
||||
"""测试 2: 资产池配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 2: 资产池配置")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_example.yaml')
|
||||
|
||||
# 验证股票资产
|
||||
print(f"\n股票资产 ({len(config.asset_pools.equity)} 只):")
|
||||
for code, asset in config.asset_pools.equity.items():
|
||||
print(f" {code}: {asset.name} ({asset.market.value})")
|
||||
if asset.etf:
|
||||
print(f" ETF: {asset.etf}")
|
||||
|
||||
# 验证商品资产
|
||||
print(f"\n商品资产 ({len(config.asset_pools.commodity)} 只):")
|
||||
for code, asset in config.asset_pools.commodity.items():
|
||||
print(f" {code}: {asset.name}")
|
||||
|
||||
# 验证固定收益
|
||||
print(f"\n固定收益 ({len(config.asset_pools.fixed_income)} 只):")
|
||||
for code, asset in config.asset_pools.fixed_income.items():
|
||||
print(f" {code}: {asset.name} (ETF: {asset.etf})")
|
||||
|
||||
# 验证市场类型
|
||||
assert config.asset_pools.equity["399006.SZ"].market == MarketType.CN_EQUITY
|
||||
assert config.asset_pools.equity["NDX"].market == MarketType.US_EQUITY
|
||||
assert config.asset_pools.commodity["GC=F"].market == MarketType.COMMODITY
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_threshold_config():
|
||||
"""测试 3: 阈值配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 3: 阈值配置")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_example.yaml')
|
||||
|
||||
print(f"\n阈值模式: {config.rotation.threshold.mode}")
|
||||
print(f" 参考标的: {config.rotation.threshold.dynamic.reference}")
|
||||
print(f" 倍数: {config.rotation.threshold.dynamic.ratio}")
|
||||
print(f" 回退启用: {config.rotation.threshold.dynamic.fallback_enabled}")
|
||||
|
||||
assert config.rotation.threshold.mode.value == "dynamic"
|
||||
assert config.rotation.threshold.dynamic.reference == "931862.CSI"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_data_sources():
|
||||
"""测试 4: 数据源配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 4: 数据源配置")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_example.yaml')
|
||||
|
||||
print(f"\n数据源 ({len(config.data.sources)} 个):")
|
||||
for i, source in enumerate(config.data.sources, 1):
|
||||
print(f" {i}. {source.type.value}")
|
||||
print(f" 启用: {source.enabled}")
|
||||
print(f" 超时: {source.timeout}s")
|
||||
if source.url:
|
||||
print(f" URL: {source.url}")
|
||||
|
||||
# 验证环境变量替换
|
||||
flask_api_source = config.data.sources[0]
|
||||
assert flask_api_source.url == 'https://k3s.tokenpluse.xyz'
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_validation_errors():
|
||||
"""测试 5: 验证错误处理"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 5: 验证错误处理")
|
||||
print("=" * 70)
|
||||
|
||||
# 测试 1: n_days 超出范围
|
||||
print("\n[5.1] 测试 n_days 超出范围...")
|
||||
try:
|
||||
from framework_v2.config.schemas import FactorConfig
|
||||
|
||||
# n_days = 1000(超出 5-250 范围)
|
||||
invalid_config = {
|
||||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||||
"benchmark": {"code": "000300.SH", "name": "沪深300"},
|
||||
"backtest": {"start_date": "2020-01-01"},
|
||||
"factor": {"n_days": 1000}, # 错误:超出范围
|
||||
"data": {
|
||||
"sources": [{"type": "flask_api", "url": "test"}]
|
||||
}
|
||||
}
|
||||
|
||||
RotationStrategyConfig(**invalid_config)
|
||||
print(" ✗ 应该抛出验证错误")
|
||||
assert False
|
||||
except Exception as e:
|
||||
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
|
||||
|
||||
# 测试 2: 缺少必需字段
|
||||
print("\n[5.2] 测试缺少必需字段...")
|
||||
try:
|
||||
invalid_config = {
|
||||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||||
# 缺少 benchmark
|
||||
"backtest": {"start_date": "2020-01-01"},
|
||||
"data": {
|
||||
"sources": [{"type": "flask_api", "url": "test"}]
|
||||
}
|
||||
}
|
||||
|
||||
RotationStrategyConfig(**invalid_config)
|
||||
print(" ✗ 应该抛出验证错误")
|
||||
assert False
|
||||
except Exception as e:
|
||||
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
|
||||
|
||||
# 测试 3: 环境变量未设置
|
||||
print("\n[5.3] 测试环境变量未设置...")
|
||||
try:
|
||||
# 删除环境变量
|
||||
old_value = os.environ.pop('FLASK_API_URL', None)
|
||||
|
||||
invalid_config = {
|
||||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||||
"benchmark": {"code": "000300.SH", "name": "沪深300"},
|
||||
"backtest": {"start_date": "2020-01-01"},
|
||||
"data": {
|
||||
"sources": [{"type": "flask_api", "url": "${FLASK_API_URL}"}]
|
||||
}
|
||||
}
|
||||
|
||||
RotationStrategyConfig(**invalid_config)
|
||||
print(" ✗ 应该抛出验证错误")
|
||||
assert False
|
||||
except ValueError as e:
|
||||
print(f" ✓ 正确捕获环境变量错误: {e}")
|
||||
finally:
|
||||
# 恢复环境变量
|
||||
if old_value:
|
||||
os.environ['FLASK_API_URL'] = old_value
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_env_substitution():
|
||||
"""测试 6: 环境变量替换"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 6: 环境变量替换")
|
||||
print("=" * 70)
|
||||
|
||||
loader = ConfigLoader()
|
||||
|
||||
# 测试 1: 基本替换
|
||||
print("\n[6.1] 基本替换...")
|
||||
os.environ['TEST_VAR'] = 'test_value'
|
||||
|
||||
config = {
|
||||
"url": "${TEST_VAR}"
|
||||
}
|
||||
result = loader._substitute_env_vars(config)
|
||||
assert result["url"] == "test_value"
|
||||
print(f" ✓ ${{TEST_VAR}} → {result['url']}")
|
||||
|
||||
# 测试 2: 默认值
|
||||
print("\n[6.2] 默认值...")
|
||||
config = {
|
||||
"url": "${NON_EXISTENT_VAR:default_value}"
|
||||
}
|
||||
result = loader._substitute_env_vars(config)
|
||||
assert result["url"] == "default_value"
|
||||
print(f" ${{NON_EXISTENT_VAR:default_value}} → {result['url']}")
|
||||
|
||||
# 测试 3: 嵌套结构
|
||||
print("\n[6.3] 嵌套结构...")
|
||||
os.environ['API_URL'] = 'https://api.example.com'
|
||||
|
||||
config = {
|
||||
"data": {
|
||||
"sources": [
|
||||
{"url": "${API_URL}"}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = loader._substitute_env_vars(config)
|
||||
assert result["data"]["sources"][0]["url"] == "https://api.example.com"
|
||||
print(f" ✓ 嵌套替换成功")
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "=" * 70)
|
||||
print(" 配置加载和验证测试")
|
||||
print("=" * 70)
|
||||
|
||||
tests = [
|
||||
("加载配置文件", test_load_config),
|
||||
("资产池配置", test_asset_pools),
|
||||
("阈值配置", test_threshold_config),
|
||||
("数据源配置", test_data_sources),
|
||||
("验证错误处理", test_validation_errors),
|
||||
("环境变量替换", test_env_substitution),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for name, test_func in tests:
|
||||
try:
|
||||
test_func()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {name}")
|
||||
print(f" 错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试总结")
|
||||
print("=" * 70)
|
||||
print(f" ✓ 通过 - {passed}")
|
||||
if failed > 0:
|
||||
print(f" ✗ 失败 - {failed}")
|
||||
print(f"\n总计: {passed}/{passed + failed} 通过")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
if failed > 0:
|
||||
sys.exit(1)
|
||||
468
archive/framework_v2/tests/test_end_to_end.py
Normal file
468
archive/framework_v2/tests/test_end_to_end.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""
|
||||
端到端集成测试:数据获取 → 因子计算 → 数据对齐 → 信号生成
|
||||
|
||||
测试场景:
|
||||
1. 获取纳指(美股)和创业板(A股)数据
|
||||
2. 计算动量因子
|
||||
3. 对齐到 A 股交易日历
|
||||
4. 生成 Top-N 信号
|
||||
5. 验证完整流程
|
||||
|
||||
目标:
|
||||
- 验证 FlaskAPIFetcher 数据获取
|
||||
- 验证 MomentumFactor 因子计算
|
||||
- 验证 CrossMarketAligner 数据对齐
|
||||
- 验证完整流程无数据泄漏
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.shared.data import FlaskAPIFetcher, CrossMarketAligner
|
||||
from framework_v2.shared.factors.momentum import MomentumFactor
|
||||
|
||||
|
||||
def test_stage1_data_fetch():
|
||||
"""
|
||||
阶段 1: 数据获取
|
||||
|
||||
获取纳指(^IXIC)和创业板(399006.SZ)数据
|
||||
"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 阶段 1: 数据获取")
|
||||
print("=" * 70)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
# 获取纳指数据(美股)
|
||||
print("\n[1.1] 获取纳斯达克指数数据(美股)...")
|
||||
us_data = fetcher.fetch_indices(
|
||||
codes=["^IXIC"],
|
||||
start="2023-01-01",
|
||||
end="2024-12-31"
|
||||
)
|
||||
|
||||
assert "^IXIC" in us_data, "纳指数据获取失败"
|
||||
df_nasdaq = us_data["^IXIC"]
|
||||
|
||||
print(f"\n纳指数据:")
|
||||
print(f" 数据量: {len(df_nasdaq)} 条")
|
||||
print(f" 日期范围: {df_nasdaq.index[0]} ~ {df_nasdaq.index[-1]}")
|
||||
print(f" 列: {list(df_nasdaq.columns)}")
|
||||
print(f" 前 3 行:")
|
||||
print(df_nasdaq.head(3).to_string())
|
||||
|
||||
# 获取创业板数据(A股)
|
||||
print("\n[1.2] 获取创业板指数数据(A股)...")
|
||||
cn_data = fetcher.fetch_indices(
|
||||
codes=["399006.SZ"],
|
||||
start="2023-01-01",
|
||||
end="2024-12-31"
|
||||
)
|
||||
|
||||
assert "399006.SZ" in cn_data, "创业板数据获取失败"
|
||||
df_gem = cn_data["399006.SZ"]
|
||||
|
||||
print(f"\n创业板数据:")
|
||||
print(f" 数据量: {len(df_gem)} 条")
|
||||
print(f" 日期范围: {df_gem.index[0]} ~ {df_gem.index[-1]}")
|
||||
print(f" 列: {list(df_gem.columns)}")
|
||||
print(f" 前 3 行:")
|
||||
print(df_gem.head(3).to_string())
|
||||
|
||||
# 对比日历差异
|
||||
print(f"\n[1.3] 交易日历对比:")
|
||||
nasdaq_dates = set(df_nasdaq.index)
|
||||
gem_dates = set(df_gem.index)
|
||||
|
||||
common_dates = nasdaq_dates & gem_dates
|
||||
only_nasdaq = nasdaq_dates - gem_dates
|
||||
only_gem = gem_dates - nasdaq_dates
|
||||
|
||||
print(f" 纳指交易日: {len(nasdaq_dates)} 天")
|
||||
print(f" 创业板交易日: {len(gem_dates)} 天")
|
||||
print(f" 共同交易日: {len(common_dates)} 天")
|
||||
print(f" 仅纳指交易: {len(only_nasdaq)} 天")
|
||||
print(f" 仅创业板交易: {len(only_gem)} 天")
|
||||
|
||||
if len(only_nasdaq) > 0:
|
||||
print(f" 纳指独有日期示例: {sorted(list(only_nasdaq))[:3]}")
|
||||
if len(only_gem) > 0:
|
||||
print(f" 创业板独有日期示例: {sorted(list(only_gem))[:3]}")
|
||||
|
||||
print("\n✓ 阶段 1 通过")
|
||||
|
||||
return {
|
||||
"^IXIC": df_nasdaq,
|
||||
"399006.SZ": df_gem
|
||||
}
|
||||
|
||||
|
||||
def test_stage2_factor_calculation(data_dict: Dict[str, pd.DataFrame]):
|
||||
"""
|
||||
阶段 2: 因子计算
|
||||
|
||||
计算动量因子(在原始日历上)
|
||||
"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 阶段 2: 因子计算(原始日历)")
|
||||
print("=" * 70)
|
||||
|
||||
factor_calc = MomentumFactor(n_days=20)
|
||||
|
||||
factors = {}
|
||||
|
||||
for code, df in data_dict.items():
|
||||
print(f"\n[2.1] 计算 {code} 动量因子...")
|
||||
|
||||
# compute 方法接受 DataFrame
|
||||
factor_series = factor_calc.compute(df)
|
||||
|
||||
# 转换为 DataFrame 格式
|
||||
factor_result = pd.DataFrame({
|
||||
'value': factor_series,
|
||||
'is_filled': False
|
||||
})
|
||||
|
||||
factors[code] = factor_result
|
||||
|
||||
print(f" 因子值数量: {len(factor_result)}")
|
||||
print(f" 日期范围: {factor_result.index[0]} ~ {factor_result.index[-1]}")
|
||||
print(f" 前 3 行:")
|
||||
print(factor_result.head(3).to_string())
|
||||
|
||||
# 统计 NaN
|
||||
nan_count = factor_result['value'].isna().sum()
|
||||
print(f" NaN 数量: {nan_count} ({nan_count/len(factor_result):.1%})")
|
||||
|
||||
# 验证因子值合理
|
||||
valid_factors = factor_result['value'].dropna()
|
||||
if len(valid_factors) > 0:
|
||||
print(f" 因子值范围: {valid_factors.min():.4f} ~ {valid_factors.max():.4f}")
|
||||
|
||||
print("\n✓ 阶段 2 通过")
|
||||
|
||||
return factors
|
||||
|
||||
|
||||
def test_stage3_data_alignment(
|
||||
factors: Dict[str, pd.DataFrame],
|
||||
data_dict: Dict[str, pd.DataFrame]
|
||||
):
|
||||
"""
|
||||
阶段 3: 数据对齐
|
||||
|
||||
将因子和收益率对齐到 A 股交易日历
|
||||
"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 阶段 3: 数据对齐(到 A 股日历)")
|
||||
print("=" * 70)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
# 获取 A 股交易日历(通过 API)
|
||||
print("\n[3.1] 获取 A 股交易日历(通过 API)...")
|
||||
|
||||
# 裁剪到数据日期范围
|
||||
data_start = min(df.index[0] for df in data_dict.values())
|
||||
data_end = max(df.index[-1] for df in data_dict.values())
|
||||
|
||||
# 使用 API 获取准确日历
|
||||
a_share_calendar = fetcher.get_trading_calendar(
|
||||
market='A',
|
||||
start=data_start.strftime('%Y-%m-%d'),
|
||||
end=data_end.strftime('%Y-%m-%d')
|
||||
)
|
||||
|
||||
print(f" A 股交易日: {len(a_share_calendar)} 天")
|
||||
print(f" 日期范围: {a_share_calendar[0]} ~ {a_share_calendar[-1]}")
|
||||
|
||||
# 创建对齐器
|
||||
aligner = CrossMarketAligner(target_calendar=a_share_calendar)
|
||||
|
||||
# 对齐因子
|
||||
print("\n[3.2] 对齐因子到 A 股日历...")
|
||||
aligned_factors = {}
|
||||
|
||||
for code, factor_df in factors.items():
|
||||
print(f"\n 对齐 {code} 因子...")
|
||||
|
||||
# 获取原始日历
|
||||
original_calendar = factor_df.index
|
||||
|
||||
# 对齐因子
|
||||
aligned = aligner.align_factor(
|
||||
factor_series=factor_df['value'],
|
||||
source_calendar=original_calendar,
|
||||
code=code
|
||||
)
|
||||
|
||||
aligned_factors[code] = aligned
|
||||
|
||||
# 统计
|
||||
filled_count = aligned['is_filled'].sum()
|
||||
print(f" 对齐后天数: {len(aligned)}")
|
||||
print(f" 填充天数: {filled_count} ({filled_count/len(aligned):.1%})")
|
||||
print(f" NaN 数量: {aligned['value'].isna().sum()}")
|
||||
|
||||
# 对齐收益率
|
||||
print("\n[3.3] 对齐收益率到 A 股日历...")
|
||||
aligned_returns = {}
|
||||
|
||||
for code, df in data_dict.items():
|
||||
print(f"\n 对齐 {code} 收益率...")
|
||||
|
||||
returns = aligner.align_returns(
|
||||
close_series=df['close'],
|
||||
code=code
|
||||
)
|
||||
|
||||
aligned_returns[code] = returns
|
||||
|
||||
# 统计
|
||||
print(f" 对齐后天数: {len(returns)}")
|
||||
print(f" 收益率范围: {returns.min():.4%} ~ {returns.max():.4%}")
|
||||
print(f" NaN 数量: {returns.isna().sum()}")
|
||||
print(f" 零收益率天数: {(returns == 0).sum()} (休市日)")
|
||||
|
||||
# 验证对齐结果
|
||||
print("\n[3.4] 验证对齐结果...")
|
||||
|
||||
# 1. 所有 DataFrame 应该有相同的索引
|
||||
indices = [df.index for df in aligned_factors.values()]
|
||||
indices.extend([s.index for s in aligned_returns.values()])
|
||||
|
||||
for i, idx1 in enumerate(indices):
|
||||
for j, idx2 in enumerate(indices):
|
||||
if i != j:
|
||||
assert idx1.equals(idx2), f"索引 {i} 和 {j} 不一致"
|
||||
|
||||
print(f" ✓ 所有数据对齐到同一日历: {len(indices[0])} 天")
|
||||
print(f" ✓ 日期范围: {indices[0][0]} ~ {indices[0][-1]}")
|
||||
|
||||
# 2. 验证收益率无 NaN
|
||||
for code, returns in aligned_returns.items():
|
||||
assert returns.isna().sum() == 0, f"{code} 收益率包含 NaN"
|
||||
print(f" ✓ 收益率无 NaN")
|
||||
|
||||
# 3. 验证休市日收益率 = 0
|
||||
for code, returns in aligned_returns.items():
|
||||
zero_days = (returns == 0).sum()
|
||||
print(f" {code} 休市日收益率 = 0: {zero_days} 天")
|
||||
|
||||
print("\n✓ 阶段 3 通过")
|
||||
|
||||
return aligned_factors, aligned_returns
|
||||
|
||||
|
||||
def test_stage4_signal_generation(
|
||||
aligned_factors: Dict[str, pd.DataFrame],
|
||||
aligned_returns: Dict[str, pd.Series]
|
||||
):
|
||||
"""
|
||||
阶段 4: 信号生成
|
||||
|
||||
根据对齐后的因子生成 Top-N 信号
|
||||
"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 阶段 4: 信号生成")
|
||||
print("=" * 70)
|
||||
|
||||
# 合并因子值
|
||||
print("\n[4.1] 合并因子值...")
|
||||
|
||||
factor_values = pd.DataFrame()
|
||||
for code, factor_df in aligned_factors.items():
|
||||
factor_values[code] = factor_df['value']
|
||||
|
||||
print(f" 合并后形状: {factor_values.shape}")
|
||||
print(f" 列: {list(factor_values.columns)}")
|
||||
print(f" 前 3 行:")
|
||||
print(factor_values.head(3).to_string())
|
||||
|
||||
# 简单信号:选择因子值最高的标的
|
||||
print("\n[4.2] 生成信号(Top-1)...")
|
||||
|
||||
# 跳过全为 NaN 的行
|
||||
valid_rows = factor_values.dropna(how='all').index
|
||||
factor_valid = factor_values.loc[valid_rows]
|
||||
|
||||
signals = pd.DataFrame()
|
||||
signals['best'] = factor_valid.idxmax(axis=1)
|
||||
signals['best_value'] = factor_valid.max(axis=1)
|
||||
|
||||
print(f" 信号数量: {len(signals)}")
|
||||
print(f" 前 10 个信号:")
|
||||
print(signals.head(10).to_string())
|
||||
|
||||
# 统计选择分布
|
||||
print(f"\n[4.3] 标的选择分布:")
|
||||
distribution = signals['best'].value_counts()
|
||||
for code, count in distribution.items():
|
||||
pct = count / len(signals)
|
||||
print(f" {code}: {count} 天 ({pct:.1%})")
|
||||
|
||||
# 验证信号与收益率对齐
|
||||
print("\n[4.4] 验证信号与收益率对齐...")
|
||||
|
||||
returns_df = pd.DataFrame(aligned_returns)
|
||||
|
||||
# 裁剪到共同日期
|
||||
common_dates = signals.index.intersection(returns_df.index)
|
||||
signals_aligned = signals.loc[common_dates]
|
||||
returns_aligned = returns_df.loc[common_dates]
|
||||
|
||||
print(f" 信号日期: {len(signals)} → {len(signals_aligned)}")
|
||||
print(f" 收益日期: {len(returns_df)} → {len(returns_aligned)}")
|
||||
print(f" 共同日期: {len(common_dates)}")
|
||||
|
||||
assert signals_aligned.index.equals(returns_aligned.index), "信号与收益日期不一致"
|
||||
print(f" ✓ 信号与收益率日期一致")
|
||||
|
||||
print("\n✓ 阶段 4 通过")
|
||||
|
||||
return signals_aligned, returns_aligned
|
||||
|
||||
|
||||
def test_stage5_strategy_returns(signals: pd.DataFrame, returns: pd.DataFrame):
|
||||
"""
|
||||
阶段 5: 计算策略收益
|
||||
|
||||
根据信号计算策略净值曲线
|
||||
"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 阶段 5: 计算策略收益")
|
||||
print("=" * 70)
|
||||
|
||||
print("\n[5.1] 计算策略日收益...")
|
||||
|
||||
strategy_returns = pd.Series(index=returns.index, dtype=float)
|
||||
|
||||
for date in returns.index:
|
||||
if date in signals.index:
|
||||
best_code = signals.loc[date, 'best']
|
||||
strategy_returns[date] = returns.loc[date, best_code]
|
||||
else:
|
||||
strategy_returns[date] = 0.0
|
||||
|
||||
# 填充 NaN
|
||||
strategy_returns = strategy_returns.fillna(0.0)
|
||||
|
||||
print(f" 策略收益天数: {len(strategy_returns)}")
|
||||
print(f" 收益范围: {strategy_returns.min():.4%} ~ {strategy_returns.max():.4%}")
|
||||
|
||||
print("\n[5.2] 计算累计收益...")
|
||||
|
||||
cumulative_returns = (1 + strategy_returns).cumprod() - 1
|
||||
|
||||
print(f" 最终累计收益: {cumulative_returns.iloc[-1]:.2%}")
|
||||
print(f" 最大累计收益: {cumulative_returns.max():.2%}")
|
||||
print(f" 最小累计收益: {cumulative_returns.min():.2%}")
|
||||
|
||||
print("\n[5.3] 计算年化收益和最大回撤...")
|
||||
|
||||
# 年化收益
|
||||
total_days = len(strategy_returns)
|
||||
annual_return = (1 + cumulative_returns.iloc[-1]) ** (252 / total_days) - 1
|
||||
print(f" 年化收益: {annual_return:.2%}")
|
||||
|
||||
# 最大回撤
|
||||
rolling_max = cumulative_returns.cummax()
|
||||
drawdown = (cumulative_returns - rolling_max) / (1 + rolling_max)
|
||||
max_drawdown = drawdown.min()
|
||||
print(f" 最大回撤: {max_drawdown:.2%}")
|
||||
|
||||
print("\n[5.4] 策略收益 vs 基准对比...")
|
||||
|
||||
# 基准:等权持有
|
||||
benchmark_returns = returns.mean(axis=1)
|
||||
benchmark_cumulative = (1 + benchmark_returns).cumprod() - 1
|
||||
|
||||
print(f" 策略累计收益: {cumulative_returns.iloc[-1]:.2%}")
|
||||
print(f" 基准累计收益: {benchmark_cumulative.iloc[-1]:.2%}")
|
||||
print(f" 超额收益: {cumulative_returns.iloc[-1] - benchmark_cumulative.iloc[-1]:.2%}")
|
||||
|
||||
print("\n✓ 阶段 5 通过")
|
||||
|
||||
return strategy_returns, cumulative_returns
|
||||
|
||||
|
||||
def run_full_pipeline():
|
||||
"""
|
||||
运行完整流程
|
||||
"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 端到端集成测试:数据获取 → 因子计算 → 数据对齐 → 信号生成")
|
||||
print("=" * 70)
|
||||
print("\n测试标的:")
|
||||
print(" - 纳斯达克指数 (^IXIC) - 美股")
|
||||
print(" - 创业板指数 (399006.SZ) - A 股")
|
||||
print("\n时间范围: 2023-01-01 ~ 2024-12-31")
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
try:
|
||||
# 阶段 1: 数据获取
|
||||
data_dict = test_stage1_data_fetch()
|
||||
|
||||
# 阶段 2: 因子计算
|
||||
factors = test_stage2_factor_calculation(data_dict)
|
||||
|
||||
# 阶段 3: 数据对齐
|
||||
aligned_factors, aligned_returns = test_stage3_data_alignment(
|
||||
factors, data_dict
|
||||
)
|
||||
|
||||
# 阶段 4: 信号生成
|
||||
signals, returns = test_stage4_signal_generation(
|
||||
aligned_factors, aligned_returns
|
||||
)
|
||||
|
||||
# 阶段 5: 策略收益
|
||||
strategy_returns, cumulative_returns = test_stage5_strategy_returns(
|
||||
signals, returns
|
||||
)
|
||||
|
||||
# 总结
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试总结")
|
||||
print("=" * 70)
|
||||
print("\n✅ 所有阶段通过!")
|
||||
print("\n流程验证:")
|
||||
print(" ✓ 数据获取: FlaskAPIFetcher 成功获取线上数据")
|
||||
print(" ✓ 因子计算: MomentumFactor 在原始日历计算")
|
||||
print(" ✓ 数据对齐: CrossMarketAligner 对齐到 A 股日历")
|
||||
print(" ✓ 信号生成: Top-N 选择逻辑正确")
|
||||
print(" ✓ 收益计算: 策略净值曲线生成成功")
|
||||
print("\n关键验证:")
|
||||
print(" ✓ 跨市场日历差异已处理")
|
||||
print(" ✓ 休市日收益率 = 0% (无 ffill 陷阱)")
|
||||
print(" ✓ 收益率无 NaN")
|
||||
print(" ✓ 信号与收益日期一致")
|
||||
print("\n" + "=" * 70 + "\n")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_full_pipeline()
|
||||
|
||||
if success:
|
||||
print("🎉 端到端测试通过!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("❌ 端到端测试失败!")
|
||||
sys.exit(1)
|
||||
198
archive/framework_v2/tests/test_flask_api_fetcher.py
Normal file
198
archive/framework_v2/tests/test_flask_api_fetcher.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
测试 FlaskAPIFetcher
|
||||
|
||||
验证:
|
||||
1. 获取指数数据
|
||||
2. 获取 ETF 数据
|
||||
3. 获取交易日历
|
||||
4. 健康检查
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.shared.data import FlaskAPIFetcher
|
||||
|
||||
|
||||
def test_health_check():
|
||||
"""测试 1: 健康检查"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 1: 健康检查")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
health = fetcher.get_health()
|
||||
|
||||
print(f"\n健康状态: {health}")
|
||||
|
||||
assert health.get('available'), "API 服务不可用"
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_fetch_indices():
|
||||
"""测试 2: 获取指数数据"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 2: 获取指数数据")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
# 获取沪深 300 + 中证 500
|
||||
codes = ["000300.SH", "000905.SH"]
|
||||
data = fetcher.fetch_indices(
|
||||
codes=codes,
|
||||
start="2024-01-01",
|
||||
end="2024-03-31"
|
||||
)
|
||||
|
||||
# 验证
|
||||
assert len(data) == 2, f"应该返回 2 只指数,实际 {len(data)}"
|
||||
|
||||
for code, df in data.items():
|
||||
print(f"\n{code}:")
|
||||
print(f" 数据量: {len(df)} 条")
|
||||
print(f" 列: {list(df.columns)}")
|
||||
print(f" 日期范围: {df.index[0]} ~ {df.index[-1]}")
|
||||
|
||||
assert len(df) > 0, f"{code} 数据为空"
|
||||
assert 'close' in df.columns, f"{code} 缺少 close 列"
|
||||
assert 'volume' in df.columns, f"{code} 缺少 volume 列"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_fetch_etf():
|
||||
"""测试 3: 获取 ETF 数据"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 3: 获取 ETF 数据")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
# 获取沪深 300 ETF
|
||||
codes = ["510300.SH"]
|
||||
data = fetcher.fetch_etf(
|
||||
codes=codes,
|
||||
start="2024-01-01",
|
||||
end="2024-03-31"
|
||||
)
|
||||
|
||||
# 验证
|
||||
assert len(data) == 1, f"应该返回 1 只 ETF,实际 {len(data)}"
|
||||
|
||||
code = "510300.SH"
|
||||
df = data[code]
|
||||
|
||||
print(f"\n{code}:")
|
||||
print(f" 价格数据: {len(df)} 条")
|
||||
print(f" 列: {list(df.columns)}")
|
||||
|
||||
# 验证附加信息
|
||||
nav = df.attrs.get('nav')
|
||||
if nav is not None:
|
||||
print(f" 净值数据: {len(nav)} 条")
|
||||
|
||||
premium = df.attrs.get('latest_premium')
|
||||
if premium is not None:
|
||||
print(f" 最新溢价率: {premium:.2f}%")
|
||||
|
||||
assert len(df) > 0, f"{code} 数据为空"
|
||||
assert 'close' in df.columns, f"{code} 缺少 close 列"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_trading_calendar():
|
||||
"""测试 4: 获取交易日历"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 4: 获取交易日历")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
# A股日历
|
||||
calendar_a = fetcher.get_trading_calendar(market='A')
|
||||
print(f"\nA股交易日历:")
|
||||
print(f" 总天数: {len(calendar_a)}")
|
||||
print(f" 日期范围: {calendar_a[0]} ~ {calendar_a[-1]}")
|
||||
print(f" 前 5 天: {calendar_a[:5].tolist()}")
|
||||
|
||||
assert len(calendar_a) > 0, "A股日历为空"
|
||||
|
||||
# 美股日历
|
||||
calendar_us = fetcher.get_trading_calendar(market='US')
|
||||
print(f"\n美股交易日历:")
|
||||
print(f" 总天数: {len(calendar_us)}")
|
||||
|
||||
assert len(calendar_us) > 0, "美股日历为空"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_benchmark():
|
||||
"""测试 5: 获取基准数据"""
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试 5: 获取基准数据")
|
||||
print("=" * 60)
|
||||
|
||||
fetcher = FlaskAPIFetcher()
|
||||
|
||||
benchmark = fetcher.get_benchmark(
|
||||
code="000300.SH",
|
||||
start="2024-01-01",
|
||||
end="2024-03-31"
|
||||
)
|
||||
|
||||
print(f"\n沪深 300 基准:")
|
||||
print(f" 数据量: {len(benchmark)} 条")
|
||||
print(f" 日期范围: {benchmark.index[0]} ~ {benchmark.index[-1]}")
|
||||
print(f" 价格范围: {benchmark.min():.2f} ~ {benchmark.max():.2f}")
|
||||
|
||||
assert len(benchmark) > 0, "基准数据为空"
|
||||
assert isinstance(benchmark, pd.Series), "基准数据应该是 Series"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pandas as pd
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" FlaskAPIFetcher 测试")
|
||||
print("=" * 60)
|
||||
|
||||
tests = [
|
||||
("健康检查", test_health_check),
|
||||
("指数数据", test_fetch_indices),
|
||||
("ETF 数据", test_fetch_etf),
|
||||
("交易日历", test_trading_calendar),
|
||||
("基准数据", test_benchmark),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for name, test_func in tests:
|
||||
try:
|
||||
test_func()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {name}")
|
||||
print(f" 错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" 测试总结")
|
||||
print("=" * 60)
|
||||
print(f" ✓ 通过 - {passed}")
|
||||
if failed > 0:
|
||||
print(f" ✗ 失败 - {failed}")
|
||||
print(f"\n总计: {passed}/{passed + failed} 通过")
|
||||
print("=" * 60 + "\n")
|
||||
281
archive/framework_v2/tests/test_flat_asset_pool.py
Normal file
281
archive/framework_v2/tests/test_flat_asset_pool.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
测试扁平化资产池配置
|
||||
|
||||
验证:
|
||||
1. 扁平化配置加载
|
||||
2. 按市场分组
|
||||
3. 信号/交易标的获取
|
||||
4. 跨市场映射
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.config import load_config
|
||||
from framework_v2.config.schemas import GroupConfig
|
||||
|
||||
|
||||
def test_flat_config_load():
|
||||
"""测试 1: 加载扁平化配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 1: 加载扁平化配置")
|
||||
print("=" * 70)
|
||||
|
||||
# 设置环境变量
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
os.environ['TUSHARE_TOKEN'] = 'test_token'
|
||||
|
||||
# 加载配置
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
print(f"\n✓ 配置加载成功")
|
||||
print(f" 版本: {config.metadata.version}")
|
||||
print(f" 策略: {config.metadata.strategy}")
|
||||
print(f" 总标的数: {config.asset_pools.count()}")
|
||||
|
||||
# 验证基本字段
|
||||
assert config.metadata.version == "2.0.0"
|
||||
assert config.asset_pools.count() == 13 # 12 个标的
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_market_grouping():
|
||||
"""测试 2: 按市场分组"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 2: 按市场分组")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 获取所有市场
|
||||
markets = config.asset_pools.groups
|
||||
print(f"\n市场类型 ({len(markets)} 个):")
|
||||
for market in markets:
|
||||
count = config.asset_pools.count(market)
|
||||
print(f" {market}: {count} 只")
|
||||
|
||||
# 按市场分组
|
||||
by_group = config.asset_pools.by_group
|
||||
print(f"\n市场分组:")
|
||||
for market, assets in by_group.items():
|
||||
print(f"\n {market} ({len(assets)} 只):")
|
||||
for code, asset in assets.items():
|
||||
print(f" {code}: {asset.name}")
|
||||
|
||||
# 验证市场数量
|
||||
assert len(markets) == 7 # US_TECH, CN_GROWTH, JP_BROAD, EU_BROAD, HK_TECH, COMMODITY, FIXED_INCOME # US_EQUITY, CN_EQUITY, JP_EQUITY, EU_EQUITY, HK_EQUITY, COMMODITY, FIXED_INCOME
|
||||
assert config.asset_pools.count('US_TECH') == 2
|
||||
assert config.asset_pools.count('CN_GROWTH') == 3
|
||||
assert config.asset_pools.count('COMMODITY') == 3
|
||||
assert config.asset_pools.count('FIXED_INCOME') == 1
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_signal_trade_codes():
|
||||
"""测试 3: 信号和交易标的"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 3: 信号和交易标的")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 获取所有信号标的
|
||||
signal_codes = config.asset_pools.get_signal_codes()
|
||||
print(f"\n信号标的 (13 个):")
|
||||
for code in signal_codes:
|
||||
print(f" {code}")
|
||||
|
||||
# 获取所有交易标的
|
||||
trade_codes = config.asset_pools.get_trade_codes()
|
||||
print(f"\n交易标的 (13 个):")
|
||||
for code in trade_codes:
|
||||
print(f" {code}")
|
||||
|
||||
# 获取特定市场的信号标的
|
||||
us_signals = config.asset_pools.get_signal_codes('US_TECH')
|
||||
print(f"\n美股信号标的: {us_signals}")
|
||||
|
||||
# 验证
|
||||
assert len(signal_codes) == 13
|
||||
assert len(trade_codes) == 13
|
||||
assert 'NDX' in signal_codes
|
||||
assert '513100.SH' in trade_codes
|
||||
assert len(us_signals) == 2
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_signal_to_trade_mapping():
|
||||
"""测试 4: 信号→交易映射"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 4: 信号→交易映射")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 获取映射
|
||||
mapping = config.asset_pools.get_signal_to_trade_mapping()
|
||||
|
||||
print(f"\n信号→交易映射:")
|
||||
for signal, trade in mapping.items():
|
||||
asset = config.asset_pools.assets.get(signal)
|
||||
cross_market = "✗" if asset.signal_source == asset.trade_source else "✓"
|
||||
print(f" {cross_market} {signal} → {trade}")
|
||||
|
||||
# 验证跨市场标的
|
||||
print(f"\n跨市场标的:")
|
||||
for code, asset in config.asset_pools.assets.items():
|
||||
if asset.is_cross_market:
|
||||
print(f" {code}: {asset.signal_source} → {asset.trade_source}")
|
||||
|
||||
# 验证映射
|
||||
assert mapping['NDX'] == '513100.SH'
|
||||
assert mapping['399006.SZ'] == '159915.SZ'
|
||||
assert mapping['GC=F'] == '518880.SH'
|
||||
assert mapping['931862.CSI'] == '931862.CSI' # 非跨市场
|
||||
|
||||
# 验证跨市场属性
|
||||
assert config.asset_pools.assets['NDX'].is_cross_market == True
|
||||
assert config.asset_pools.assets['931862.CSI'].is_cross_market == False
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_market_specific_mapping():
|
||||
"""测试 5: 特定市场映射"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 5: 特定市场映射")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 获取美股映射
|
||||
us_mapping = config.asset_pools.get_signal_to_trade_mapping('US_TECH')
|
||||
print(f"\n美股映射:")
|
||||
for signal, trade in us_mapping.items():
|
||||
print(f" {signal} → {trade}")
|
||||
|
||||
# 获取商品映射
|
||||
commodity_mapping = config.asset_pools.get_signal_to_trade_mapping('COMMODITY')
|
||||
print(f"\n商品映射:")
|
||||
for signal, trade in commodity_mapping.items():
|
||||
print(f" {signal} → {trade}")
|
||||
|
||||
# 验证
|
||||
assert len(us_mapping) == 2
|
||||
assert len(commodity_mapping) == 3
|
||||
assert us_mapping['NDX'] == '513100.SH'
|
||||
assert commodity_mapping['GC=F'] == '518880.SH'
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_diversification_config():
|
||||
"""测试 6: 分散化配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 6: 分散化配置")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
print(f"\n轮动配置:")
|
||||
print(f" 选股数量: {config.rotation.select_num}")
|
||||
print(f" 分散化: {config.rotation.diversified}")
|
||||
print(f" 分散化分组: {config.rotation.diversification_groups}")
|
||||
|
||||
# 验证默认配置(全局模式)
|
||||
assert config.rotation.select_num == 5
|
||||
assert config.rotation.diversified == False
|
||||
assert config.rotation.diversification_groups is None
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_asset_config_details():
|
||||
"""测试 7: 标的配置详情"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 7: 标的配置详情")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 检查纳指配置
|
||||
ndx = config.asset_pools.assets['NDX']
|
||||
print(f"\n纳指100 配置:")
|
||||
print(f" 名称: {ndx.name}")
|
||||
print(f" 市场: {ndx.group}")
|
||||
print(f" 信号来源: {ndx.signal_source}")
|
||||
print(f" 交易来源: {ndx.trade_source}")
|
||||
print(f" 跨市场: {ndx.is_cross_market}")
|
||||
print(f" 描述: {ndx.description}")
|
||||
|
||||
# 检查短债配置
|
||||
bond = config.asset_pools.assets['931862.CSI']
|
||||
print(f"\n短债指数 配置:")
|
||||
print(f" 名称: {bond.name}")
|
||||
print(f" 市场: {bond.group}")
|
||||
print(f" 信号来源: {bond.signal_source}")
|
||||
print(f" 交易来源: {bond.trade_source}")
|
||||
print(f" 跨市场: {bond.is_cross_market}")
|
||||
|
||||
# 验证
|
||||
assert ndx.name == "纳指100"
|
||||
assert ndx.group == 'US_TECH'
|
||||
assert ndx.signal_source == "NDX"
|
||||
assert ndx.trade_source == "513100.SH"
|
||||
assert ndx.is_cross_market == True
|
||||
|
||||
assert bond.signal_source == bond.trade_source
|
||||
assert bond.is_cross_market == False
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "=" * 70)
|
||||
print(" 扁平化资产池配置测试")
|
||||
print("=" * 70)
|
||||
|
||||
tests = [
|
||||
("加载扁平化配置", test_flat_config_load),
|
||||
("按市场分组", test_market_grouping),
|
||||
("信号和交易标的", test_signal_trade_codes),
|
||||
("信号→交易映射", test_signal_to_trade_mapping),
|
||||
("特定市场映射", test_market_specific_mapping),
|
||||
("分散化配置", test_diversification_config),
|
||||
("标的配置详情", test_asset_config_details),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for name, test_func in tests:
|
||||
try:
|
||||
test_func()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {name}")
|
||||
print(f" 错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试总结")
|
||||
print("=" * 70)
|
||||
print(f" ✓ 通过 - {passed}")
|
||||
if failed > 0:
|
||||
print(f" ✗ 失败 - {failed}")
|
||||
print(f"\n总计: {passed}/{passed + failed} 通过")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
if failed > 0:
|
||||
sys.exit(1)
|
||||
116
archive/framework_v2/tests/test_momentum_parity.py
Normal file
116
archive/framework_v2/tests/test_momentum_parity.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
因子对比验证测试
|
||||
|
||||
验证新框架的 MomentumFactor 与现有实现输出一致
|
||||
"""
|
||||
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
def test_momentum_factor_parity():
|
||||
"""验证新因子与旧因子输出一致"""
|
||||
|
||||
print("=" * 60)
|
||||
print(" MomentumFactor 对比测试")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 加载测试数据
|
||||
print("\n1. 加载测试数据...")
|
||||
test_data_path = project_root / 'data' / 'index_history_data'
|
||||
|
||||
# 使用纳指100数据测试
|
||||
import glob
|
||||
ndx_files = glob.glob(str(test_data_path / '*NDX*'))
|
||||
if ndx_files:
|
||||
test_file = ndx_files[0]
|
||||
data = pd.read_csv(test_file, index_col=0, parse_dates=True)
|
||||
print(f" ✓ 加载 {test_file}")
|
||||
print(f" 数据范围: {data.index[0]} ~ {data.index[-1]}")
|
||||
print(f" 数据长度: {len(data)} 条")
|
||||
else:
|
||||
print(" ⚠ 未找到测试数据,使用模拟数据")
|
||||
# 生成模拟数据
|
||||
np.random.seed(42)
|
||||
dates = pd.date_range('2020-01-01', periods=500, freq='B')
|
||||
prices = 100 * np.cumprod(1 + np.random.randn(500) * 0.02)
|
||||
data = pd.DataFrame({
|
||||
'close': prices,
|
||||
'open': prices * 0.99,
|
||||
'high': prices * 1.01,
|
||||
'low': prices * 0.98,
|
||||
'volume': np.random.randint(1000000, 10000000, 500)
|
||||
}, index=dates)
|
||||
|
||||
# 2. 计算旧因子
|
||||
print("\n2. 计算旧因子(strategies/shared/factors/momentum.py)...")
|
||||
from strategies.shared.factors.momentum import MomentumFactor as OldMomentum
|
||||
|
||||
old_factor = OldMomentum(n_days=25, weighted=True, crash_filter=True)
|
||||
old_result = old_factor.compute(data)
|
||||
print(f" ✓ 旧因子计算完成")
|
||||
print(f" 结果范围: {old_result.min():.4f} ~ {old_result.max():.4f}")
|
||||
print(f" NaN 数量: {old_result.isna().sum()}")
|
||||
|
||||
# 3. 计算新因子
|
||||
print("\n3. 计算新因子(framework_v2/shared/factors/momentum.py)...")
|
||||
from framework_v2.shared.factors.momentum import MomentumFactor as NewMomentum
|
||||
|
||||
new_factor = NewMomentum(n_days=25, weighted=True, crash_filter=True)
|
||||
new_result = new_factor.compute(data)
|
||||
print(f" ✓ 新因子计算完成")
|
||||
print(f" 结果范围: {new_result.min():.4f} ~ {new_result.max():.4f}")
|
||||
print(f" NaN 数量: {new_result.isna().sum()}")
|
||||
|
||||
# 4. 对比结果
|
||||
print("\n4. 对比结果...")
|
||||
|
||||
# 检查索引是否一致
|
||||
if not old_result.index.equals(new_result.index):
|
||||
print(" ✗ 索引不一致")
|
||||
return False
|
||||
print(" ✓ 索引一致")
|
||||
|
||||
# 检查数值差异
|
||||
diff = (old_result - new_result).abs()
|
||||
max_diff = diff.max()
|
||||
mean_diff = diff.mean()
|
||||
|
||||
print(f" 最大差异: {max_diff:.6e}")
|
||||
print(f" 平均差异: {mean_diff:.6e}")
|
||||
|
||||
# 允许浮点数精度误差(1e-10)
|
||||
tolerance = 1e-10
|
||||
if max_diff < tolerance:
|
||||
print(f" ✓ 差异在容差范围内 (< {tolerance:.0e})")
|
||||
print("\n" + "=" * 60)
|
||||
print(" ✓ 测试通过:新旧因子输出完全一致!")
|
||||
print("=" * 60)
|
||||
return True
|
||||
else:
|
||||
print(f" ✗ 差异超出容差范围")
|
||||
print("\n" + "=" * 60)
|
||||
print(" ✗ 测试失败:新旧因子输出不一致")
|
||||
print("=" * 60)
|
||||
|
||||
# 打印前10个差异点
|
||||
diff_nonzero = diff[diff > tolerance]
|
||||
if len(diff_nonzero) > 0:
|
||||
print(f"\n 前10个差异点:")
|
||||
for date, val in diff_nonzero.head(10).items():
|
||||
old_val = old_result.loc[date]
|
||||
new_val = new_result.loc[date]
|
||||
print(f" {date}: 旧={old_val:.6f}, 新={new_val:.6f}, 差异={val:.6e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
success = test_momentum_factor_parity()
|
||||
sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user