配置文件: - rotation_global.yaml: 扁平化资产池配置示例,演示 group 策略分组 * 13 个标的覆盖 7 个策略分组(US_TECH, CN_GROWTH, JP_BROAD, EU_BROAD, HK_TECH, COMMODITY, FIXED_INCOME) * signal_source/trade_source 分离配置(跨市场场景) * 分散化选股配置示例(注释状态) * 默认使用 Flask API 数据源 测试用例: - test_flat_asset_pool.py: 7/7 测试通过 * 扁平配置加载验证 * 策略分组功能测试(by_group, groups, count) * 信号/交易标的获取(get_signal_codes, get_trade_codes) * 信号→交易映射(get_signal_to_trade_mapping) * 分散化配置验证 * 标的配置详情验证 - test_config.py: 配置加载器测试 - test_simple_rotation.py: 简单轮动策略端到端测试
286 lines
8.7 KiB
Python
286 lines
8.7 KiB
Python
"""
|
||
测试配置加载和验证
|
||
|
||
验证:
|
||
1. 配置文件加载
|
||
2. Pydantic Schema 验证
|
||
3. 环境变量替换
|
||
4. 错误处理
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
import os
|
||
|
||
# 添加项目根目录到路径
|
||
project_root = Path(__file__).parent.parent.parent
|
||
if str(project_root) not in sys.path:
|
||
sys.path.insert(0, str(project_root))
|
||
|
||
from framework_v2.config import load_config, ConfigLoader
|
||
from framework_v2.config.schemas import RotationStrategyConfig, MarketType
|
||
|
||
|
||
def test_load_config():
|
||
"""测试 1: 加载配置文件"""
|
||
print("\n" + "=" * 70)
|
||
print(" 测试 1: 加载配置文件")
|
||
print("=" * 70)
|
||
|
||
# 设置环境变量(模拟)
|
||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||
os.environ['TUSHARE_TOKEN'] = 'test_token_123'
|
||
|
||
# 加载配置
|
||
config = load_config('rotation_example.yaml')
|
||
|
||
print(f"\n✓ 配置加载成功")
|
||
print(f" 版本: {config.metadata.version}")
|
||
print(f" 策略: {config.metadata.strategy}")
|
||
print(f" 资产池: {len(config.asset_pools.equity)} 股票, "
|
||
f"{len(config.asset_pools.commodity)} 商品, "
|
||
f"{len(config.asset_pools.fixed_income)} 债券")
|
||
|
||
# 验证基本字段
|
||
assert config.metadata.version == "1.0.0"
|
||
assert config.factor.n_days == 25
|
||
assert config.rotation.select_num == 3
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_asset_pools():
|
||
"""测试 2: 资产池配置"""
|
||
print("\n" + "=" * 70)
|
||
print(" 测试 2: 资产池配置")
|
||
print("=" * 70)
|
||
|
||
config = load_config('rotation_example.yaml')
|
||
|
||
# 验证股票资产
|
||
print(f"\n股票资产 ({len(config.asset_pools.equity)} 只):")
|
||
for code, asset in config.asset_pools.equity.items():
|
||
print(f" {code}: {asset.name} ({asset.market.value})")
|
||
if asset.etf:
|
||
print(f" ETF: {asset.etf}")
|
||
|
||
# 验证商品资产
|
||
print(f"\n商品资产 ({len(config.asset_pools.commodity)} 只):")
|
||
for code, asset in config.asset_pools.commodity.items():
|
||
print(f" {code}: {asset.name}")
|
||
|
||
# 验证固定收益
|
||
print(f"\n固定收益 ({len(config.asset_pools.fixed_income)} 只):")
|
||
for code, asset in config.asset_pools.fixed_income.items():
|
||
print(f" {code}: {asset.name} (ETF: {asset.etf})")
|
||
|
||
# 验证市场类型
|
||
assert config.asset_pools.equity["399006.SZ"].market == MarketType.CN_EQUITY
|
||
assert config.asset_pools.equity["NDX"].market == MarketType.US_EQUITY
|
||
assert config.asset_pools.commodity["GC=F"].market == MarketType.COMMODITY
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_threshold_config():
|
||
"""测试 3: 阈值配置"""
|
||
print("\n" + "=" * 70)
|
||
print(" 测试 3: 阈值配置")
|
||
print("=" * 70)
|
||
|
||
config = load_config('rotation_example.yaml')
|
||
|
||
print(f"\n阈值模式: {config.rotation.threshold.mode}")
|
||
print(f" 参考标的: {config.rotation.threshold.dynamic.reference}")
|
||
print(f" 倍数: {config.rotation.threshold.dynamic.ratio}")
|
||
print(f" 回退启用: {config.rotation.threshold.dynamic.fallback_enabled}")
|
||
|
||
assert config.rotation.threshold.mode.value == "dynamic"
|
||
assert config.rotation.threshold.dynamic.reference == "931862.CSI"
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_data_sources():
|
||
"""测试 4: 数据源配置"""
|
||
print("\n" + "=" * 70)
|
||
print(" 测试 4: 数据源配置")
|
||
print("=" * 70)
|
||
|
||
config = load_config('rotation_example.yaml')
|
||
|
||
print(f"\n数据源 ({len(config.data.sources)} 个):")
|
||
for i, source in enumerate(config.data.sources, 1):
|
||
print(f" {i}. {source.type.value}")
|
||
print(f" 启用: {source.enabled}")
|
||
print(f" 超时: {source.timeout}s")
|
||
if source.url:
|
||
print(f" URL: {source.url}")
|
||
|
||
# 验证环境变量替换
|
||
flask_api_source = config.data.sources[0]
|
||
assert flask_api_source.url == 'https://k3s.tokenpluse.xyz'
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_validation_errors():
|
||
"""测试 5: 验证错误处理"""
|
||
print("\n" + "=" * 70)
|
||
print(" 测试 5: 验证错误处理")
|
||
print("=" * 70)
|
||
|
||
# 测试 1: n_days 超出范围
|
||
print("\n[5.1] 测试 n_days 超出范围...")
|
||
try:
|
||
from framework_v2.config.schemas import FactorConfig
|
||
|
||
# n_days = 1000(超出 5-250 范围)
|
||
invalid_config = {
|
||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||
"benchmark": {"code": "000300.SH", "name": "沪深300"},
|
||
"backtest": {"start_date": "2020-01-01"},
|
||
"factor": {"n_days": 1000}, # 错误:超出范围
|
||
"data": {
|
||
"sources": [{"type": "flask_api", "url": "test"}]
|
||
}
|
||
}
|
||
|
||
RotationStrategyConfig(**invalid_config)
|
||
print(" ✗ 应该抛出验证错误")
|
||
assert False
|
||
except Exception as e:
|
||
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
|
||
|
||
# 测试 2: 缺少必需字段
|
||
print("\n[5.2] 测试缺少必需字段...")
|
||
try:
|
||
invalid_config = {
|
||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||
# 缺少 benchmark
|
||
"backtest": {"start_date": "2020-01-01"},
|
||
"data": {
|
||
"sources": [{"type": "flask_api", "url": "test"}]
|
||
}
|
||
}
|
||
|
||
RotationStrategyConfig(**invalid_config)
|
||
print(" ✗ 应该抛出验证错误")
|
||
assert False
|
||
except Exception as e:
|
||
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
|
||
|
||
# 测试 3: 环境变量未设置
|
||
print("\n[5.3] 测试环境变量未设置...")
|
||
try:
|
||
# 删除环境变量
|
||
old_value = os.environ.pop('FLASK_API_URL', None)
|
||
|
||
invalid_config = {
|
||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||
"benchmark": {"code": "000300.SH", "name": "沪深300"},
|
||
"backtest": {"start_date": "2020-01-01"},
|
||
"data": {
|
||
"sources": [{"type": "flask_api", "url": "${FLASK_API_URL}"}]
|
||
}
|
||
}
|
||
|
||
RotationStrategyConfig(**invalid_config)
|
||
print(" ✗ 应该抛出验证错误")
|
||
assert False
|
||
except ValueError as e:
|
||
print(f" ✓ 正确捕获环境变量错误: {e}")
|
||
finally:
|
||
# 恢复环境变量
|
||
if old_value:
|
||
os.environ['FLASK_API_URL'] = old_value
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
def test_env_substitution():
|
||
"""测试 6: 环境变量替换"""
|
||
print("\n" + "=" * 70)
|
||
print(" 测试 6: 环境变量替换")
|
||
print("=" * 70)
|
||
|
||
loader = ConfigLoader()
|
||
|
||
# 测试 1: 基本替换
|
||
print("\n[6.1] 基本替换...")
|
||
os.environ['TEST_VAR'] = 'test_value'
|
||
|
||
config = {
|
||
"url": "${TEST_VAR}"
|
||
}
|
||
result = loader._substitute_env_vars(config)
|
||
assert result["url"] == "test_value"
|
||
print(f" ✓ ${{TEST_VAR}} → {result['url']}")
|
||
|
||
# 测试 2: 默认值
|
||
print("\n[6.2] 默认值...")
|
||
config = {
|
||
"url": "${NON_EXISTENT_VAR:default_value}"
|
||
}
|
||
result = loader._substitute_env_vars(config)
|
||
assert result["url"] == "default_value"
|
||
print(f" ${{NON_EXISTENT_VAR:default_value}} → {result['url']}")
|
||
|
||
# 测试 3: 嵌套结构
|
||
print("\n[6.3] 嵌套结构...")
|
||
os.environ['API_URL'] = 'https://api.example.com'
|
||
|
||
config = {
|
||
"data": {
|
||
"sources": [
|
||
{"url": "${API_URL}"}
|
||
]
|
||
}
|
||
}
|
||
result = loader._substitute_env_vars(config)
|
||
assert result["data"]["sources"][0]["url"] == "https://api.example.com"
|
||
print(f" ✓ 嵌套替换成功")
|
||
|
||
print("\n✓ 测试通过")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
print("\n" + "=" * 70)
|
||
print(" 配置加载和验证测试")
|
||
print("=" * 70)
|
||
|
||
tests = [
|
||
("加载配置文件", test_load_config),
|
||
("资产池配置", test_asset_pools),
|
||
("阈值配置", test_threshold_config),
|
||
("数据源配置", test_data_sources),
|
||
("验证错误处理", test_validation_errors),
|
||
("环境变量替换", test_env_substitution),
|
||
]
|
||
|
||
passed = 0
|
||
failed = 0
|
||
|
||
for name, test_func in tests:
|
||
try:
|
||
test_func()
|
||
passed += 1
|
||
except Exception as e:
|
||
print(f"\n✗ 测试失败: {name}")
|
||
print(f" 错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
failed += 1
|
||
|
||
print("\n" + "=" * 70)
|
||
print(" 测试总结")
|
||
print("=" * 70)
|
||
print(f" ✓ 通过 - {passed}")
|
||
if failed > 0:
|
||
print(f" ✗ 失败 - {failed}")
|
||
print(f"\n总计: {passed}/{passed + failed} 通过")
|
||
print("=" * 70 + "\n")
|
||
|
||
if failed > 0:
|
||
sys.exit(1)
|