Files
etf/framework_v2/tests/test_config.py
aszerW 0954458114 test(framework_v2): 添加配置系统测试和策略示例
配置文件:
- rotation_global.yaml: 扁平化资产池配置示例,演示 group 策略分组
  * 13 个标的覆盖 7 个策略分组(US_TECH, CN_GROWTH, JP_BROAD, EU_BROAD, HK_TECH, COMMODITY, FIXED_INCOME)
  * signal_source/trade_source 分离配置(跨市场场景)
  * 分散化选股配置示例(注释状态)
  * 默认使用 Flask API 数据源

测试用例:
- test_flat_asset_pool.py: 7/7 测试通过
  * 扁平配置加载验证
  * 策略分组功能测试(by_group, groups, count)
  * 信号/交易标的获取(get_signal_codes, get_trade_codes)
  * 信号→交易映射(get_signal_to_trade_mapping)
  * 分散化配置验证
  * 标的配置详情验证

- test_config.py: 配置加载器测试
- test_simple_rotation.py: 简单轮动策略端到端测试
2026-05-24 14:26:09 +08:00

286 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试配置加载和验证
验证:
1. 配置文件加载
2. Pydantic Schema 验证
3. 环境变量替换
4. 错误处理
"""
import sys
from pathlib import Path
import os
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
from framework_v2.config import load_config, ConfigLoader
from framework_v2.config.schemas import RotationStrategyConfig, MarketType
def test_load_config():
"""测试 1: 加载配置文件"""
print("\n" + "=" * 70)
print(" 测试 1: 加载配置文件")
print("=" * 70)
# 设置环境变量(模拟)
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
os.environ['TUSHARE_TOKEN'] = 'test_token_123'
# 加载配置
config = load_config('rotation_example.yaml')
print(f"\n✓ 配置加载成功")
print(f" 版本: {config.metadata.version}")
print(f" 策略: {config.metadata.strategy}")
print(f" 资产池: {len(config.asset_pools.equity)} 股票, "
f"{len(config.asset_pools.commodity)} 商品, "
f"{len(config.asset_pools.fixed_income)} 债券")
# 验证基本字段
assert config.metadata.version == "1.0.0"
assert config.factor.n_days == 25
assert config.rotation.select_num == 3
print("\n✓ 测试通过")
def test_asset_pools():
"""测试 2: 资产池配置"""
print("\n" + "=" * 70)
print(" 测试 2: 资产池配置")
print("=" * 70)
config = load_config('rotation_example.yaml')
# 验证股票资产
print(f"\n股票资产 ({len(config.asset_pools.equity)} 只):")
for code, asset in config.asset_pools.equity.items():
print(f" {code}: {asset.name} ({asset.market.value})")
if asset.etf:
print(f" ETF: {asset.etf}")
# 验证商品资产
print(f"\n商品资产 ({len(config.asset_pools.commodity)} 只):")
for code, asset in config.asset_pools.commodity.items():
print(f" {code}: {asset.name}")
# 验证固定收益
print(f"\n固定收益 ({len(config.asset_pools.fixed_income)} 只):")
for code, asset in config.asset_pools.fixed_income.items():
print(f" {code}: {asset.name} (ETF: {asset.etf})")
# 验证市场类型
assert config.asset_pools.equity["399006.SZ"].market == MarketType.CN_EQUITY
assert config.asset_pools.equity["NDX"].market == MarketType.US_EQUITY
assert config.asset_pools.commodity["GC=F"].market == MarketType.COMMODITY
print("\n✓ 测试通过")
def test_threshold_config():
"""测试 3: 阈值配置"""
print("\n" + "=" * 70)
print(" 测试 3: 阈值配置")
print("=" * 70)
config = load_config('rotation_example.yaml')
print(f"\n阈值模式: {config.rotation.threshold.mode}")
print(f" 参考标的: {config.rotation.threshold.dynamic.reference}")
print(f" 倍数: {config.rotation.threshold.dynamic.ratio}")
print(f" 回退启用: {config.rotation.threshold.dynamic.fallback_enabled}")
assert config.rotation.threshold.mode.value == "dynamic"
assert config.rotation.threshold.dynamic.reference == "931862.CSI"
print("\n✓ 测试通过")
def test_data_sources():
"""测试 4: 数据源配置"""
print("\n" + "=" * 70)
print(" 测试 4: 数据源配置")
print("=" * 70)
config = load_config('rotation_example.yaml')
print(f"\n数据源 ({len(config.data.sources)} 个):")
for i, source in enumerate(config.data.sources, 1):
print(f" {i}. {source.type.value}")
print(f" 启用: {source.enabled}")
print(f" 超时: {source.timeout}s")
if source.url:
print(f" URL: {source.url}")
# 验证环境变量替换
flask_api_source = config.data.sources[0]
assert flask_api_source.url == 'https://k3s.tokenpluse.xyz'
print("\n✓ 测试通过")
def test_validation_errors():
"""测试 5: 验证错误处理"""
print("\n" + "=" * 70)
print(" 测试 5: 验证错误处理")
print("=" * 70)
# 测试 1: n_days 超出范围
print("\n[5.1] 测试 n_days 超出范围...")
try:
from framework_v2.config.schemas import FactorConfig
# n_days = 1000超出 5-250 范围)
invalid_config = {
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
"benchmark": {"code": "000300.SH", "name": "沪深300"},
"backtest": {"start_date": "2020-01-01"},
"factor": {"n_days": 1000}, # 错误:超出范围
"data": {
"sources": [{"type": "flask_api", "url": "test"}]
}
}
RotationStrategyConfig(**invalid_config)
print(" ✗ 应该抛出验证错误")
assert False
except Exception as e:
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
# 测试 2: 缺少必需字段
print("\n[5.2] 测试缺少必需字段...")
try:
invalid_config = {
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
# 缺少 benchmark
"backtest": {"start_date": "2020-01-01"},
"data": {
"sources": [{"type": "flask_api", "url": "test"}]
}
}
RotationStrategyConfig(**invalid_config)
print(" ✗ 应该抛出验证错误")
assert False
except Exception as e:
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
# 测试 3: 环境变量未设置
print("\n[5.3] 测试环境变量未设置...")
try:
# 删除环境变量
old_value = os.environ.pop('FLASK_API_URL', None)
invalid_config = {
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
"benchmark": {"code": "000300.SH", "name": "沪深300"},
"backtest": {"start_date": "2020-01-01"},
"data": {
"sources": [{"type": "flask_api", "url": "${FLASK_API_URL}"}]
}
}
RotationStrategyConfig(**invalid_config)
print(" ✗ 应该抛出验证错误")
assert False
except ValueError as e:
print(f" ✓ 正确捕获环境变量错误: {e}")
finally:
# 恢复环境变量
if old_value:
os.environ['FLASK_API_URL'] = old_value
print("\n✓ 测试通过")
def test_env_substitution():
"""测试 6: 环境变量替换"""
print("\n" + "=" * 70)
print(" 测试 6: 环境变量替换")
print("=" * 70)
loader = ConfigLoader()
# 测试 1: 基本替换
print("\n[6.1] 基本替换...")
os.environ['TEST_VAR'] = 'test_value'
config = {
"url": "${TEST_VAR}"
}
result = loader._substitute_env_vars(config)
assert result["url"] == "test_value"
print(f" ✓ ${{TEST_VAR}}{result['url']}")
# 测试 2: 默认值
print("\n[6.2] 默认值...")
config = {
"url": "${NON_EXISTENT_VAR:default_value}"
}
result = loader._substitute_env_vars(config)
assert result["url"] == "default_value"
print(f" ${{NON_EXISTENT_VAR:default_value}}{result['url']}")
# 测试 3: 嵌套结构
print("\n[6.3] 嵌套结构...")
os.environ['API_URL'] = 'https://api.example.com'
config = {
"data": {
"sources": [
{"url": "${API_URL}"}
]
}
}
result = loader._substitute_env_vars(config)
assert result["data"]["sources"][0]["url"] == "https://api.example.com"
print(f" ✓ 嵌套替换成功")
print("\n✓ 测试通过")
if __name__ == "__main__":
print("\n" + "=" * 70)
print(" 配置加载和验证测试")
print("=" * 70)
tests = [
("加载配置文件", test_load_config),
("资产池配置", test_asset_pools),
("阈值配置", test_threshold_config),
("数据源配置", test_data_sources),
("验证错误处理", test_validation_errors),
("环境变量替换", test_env_substitution),
]
passed = 0
failed = 0
for name, test_func in tests:
try:
test_func()
passed += 1
except Exception as e:
print(f"\n✗ 测试失败: {name}")
print(f" 错误: {e}")
import traceback
traceback.print_exc()
failed += 1
print("\n" + "=" * 70)
print(" 测试总结")
print("=" * 70)
print(f" ✓ 通过 - {passed}")
if failed > 0:
print(f" ✗ 失败 - {failed}")
print(f"\n总计: {passed}/{passed + failed} 通过")
print("=" * 70 + "\n")
if failed > 0:
sys.exit(1)