""" 测试配置加载和验证 验证: 1. 配置文件加载 2. Pydantic Schema 验证 3. 环境变量替换 4. 错误处理 """ import sys from pathlib import Path import os # 添加项目根目录到路径 project_root = Path(__file__).parent.parent.parent if str(project_root) not in sys.path: sys.path.insert(0, str(project_root)) from framework_v2.config import load_config, ConfigLoader from framework_v2.config.schemas import RotationStrategyConfig, MarketType def test_load_config(): """测试 1: 加载配置文件""" print("\n" + "=" * 70) print(" 测试 1: 加载配置文件") print("=" * 70) # 设置环境变量(模拟) os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz' os.environ['TUSHARE_TOKEN'] = 'test_token_123' # 加载配置 config = load_config('rotation_example.yaml') print(f"\n✓ 配置加载成功") print(f" 版本: {config.metadata.version}") print(f" 策略: {config.metadata.strategy}") print(f" 资产池: {len(config.asset_pools.equity)} 股票, " f"{len(config.asset_pools.commodity)} 商品, " f"{len(config.asset_pools.fixed_income)} 债券") # 验证基本字段 assert config.metadata.version == "1.0.0" assert config.factor.n_days == 25 assert config.rotation.select_num == 3 print("\n✓ 测试通过") def test_asset_pools(): """测试 2: 资产池配置""" print("\n" + "=" * 70) print(" 测试 2: 资产池配置") print("=" * 70) config = load_config('rotation_example.yaml') # 验证股票资产 print(f"\n股票资产 ({len(config.asset_pools.equity)} 只):") for code, asset in config.asset_pools.equity.items(): print(f" {code}: {asset.name} ({asset.market.value})") if asset.etf: print(f" ETF: {asset.etf}") # 验证商品资产 print(f"\n商品资产 ({len(config.asset_pools.commodity)} 只):") for code, asset in config.asset_pools.commodity.items(): print(f" {code}: {asset.name}") # 验证固定收益 print(f"\n固定收益 ({len(config.asset_pools.fixed_income)} 只):") for code, asset in config.asset_pools.fixed_income.items(): print(f" {code}: {asset.name} (ETF: {asset.etf})") # 验证市场类型 assert config.asset_pools.equity["399006.SZ"].market == MarketType.CN_EQUITY assert config.asset_pools.equity["NDX"].market == MarketType.US_EQUITY assert config.asset_pools.commodity["GC=F"].market == MarketType.COMMODITY print("\n✓ 测试通过") def test_threshold_config(): """测试 3: 阈值配置""" print("\n" + "=" * 70) print(" 测试 3: 阈值配置") print("=" * 70) config = load_config('rotation_example.yaml') print(f"\n阈值模式: {config.rotation.threshold.mode}") print(f" 参考标的: {config.rotation.threshold.dynamic.reference}") print(f" 倍数: {config.rotation.threshold.dynamic.ratio}") print(f" 回退启用: {config.rotation.threshold.dynamic.fallback_enabled}") assert config.rotation.threshold.mode.value == "dynamic" assert config.rotation.threshold.dynamic.reference == "931862.CSI" print("\n✓ 测试通过") def test_data_sources(): """测试 4: 数据源配置""" print("\n" + "=" * 70) print(" 测试 4: 数据源配置") print("=" * 70) config = load_config('rotation_example.yaml') print(f"\n数据源 ({len(config.data.sources)} 个):") for i, source in enumerate(config.data.sources, 1): print(f" {i}. {source.type.value}") print(f" 启用: {source.enabled}") print(f" 超时: {source.timeout}s") if source.url: print(f" URL: {source.url}") # 验证环境变量替换 flask_api_source = config.data.sources[0] assert flask_api_source.url == 'https://k3s.tokenpluse.xyz' print("\n✓ 测试通过") def test_validation_errors(): """测试 5: 验证错误处理""" print("\n" + "=" * 70) print(" 测试 5: 验证错误处理") print("=" * 70) # 测试 1: n_days 超出范围 print("\n[5.1] 测试 n_days 超出范围...") try: from framework_v2.config.schemas import FactorConfig # n_days = 1000(超出 5-250 范围) invalid_config = { "asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}}, "benchmark": {"code": "000300.SH", "name": "沪深300"}, "backtest": {"start_date": "2020-01-01"}, "factor": {"n_days": 1000}, # 错误:超出范围 "data": { "sources": [{"type": "flask_api", "url": "test"}] } } RotationStrategyConfig(**invalid_config) print(" ✗ 应该抛出验证错误") assert False except Exception as e: print(f" ✓ 正确捕获验证错误: {type(e).__name__}") # 测试 2: 缺少必需字段 print("\n[5.2] 测试缺少必需字段...") try: invalid_config = { "asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}}, # 缺少 benchmark "backtest": {"start_date": "2020-01-01"}, "data": { "sources": [{"type": "flask_api", "url": "test"}] } } RotationStrategyConfig(**invalid_config) print(" ✗ 应该抛出验证错误") assert False except Exception as e: print(f" ✓ 正确捕获验证错误: {type(e).__name__}") # 测试 3: 环境变量未设置 print("\n[5.3] 测试环境变量未设置...") try: # 删除环境变量 old_value = os.environ.pop('FLASK_API_URL', None) invalid_config = { "asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}}, "benchmark": {"code": "000300.SH", "name": "沪深300"}, "backtest": {"start_date": "2020-01-01"}, "data": { "sources": [{"type": "flask_api", "url": "${FLASK_API_URL}"}] } } RotationStrategyConfig(**invalid_config) print(" ✗ 应该抛出验证错误") assert False except ValueError as e: print(f" ✓ 正确捕获环境变量错误: {e}") finally: # 恢复环境变量 if old_value: os.environ['FLASK_API_URL'] = old_value print("\n✓ 测试通过") def test_env_substitution(): """测试 6: 环境变量替换""" print("\n" + "=" * 70) print(" 测试 6: 环境变量替换") print("=" * 70) loader = ConfigLoader() # 测试 1: 基本替换 print("\n[6.1] 基本替换...") os.environ['TEST_VAR'] = 'test_value' config = { "url": "${TEST_VAR}" } result = loader._substitute_env_vars(config) assert result["url"] == "test_value" print(f" ✓ ${{TEST_VAR}} → {result['url']}") # 测试 2: 默认值 print("\n[6.2] 默认值...") config = { "url": "${NON_EXISTENT_VAR:default_value}" } result = loader._substitute_env_vars(config) assert result["url"] == "default_value" print(f" ${{NON_EXISTENT_VAR:default_value}} → {result['url']}") # 测试 3: 嵌套结构 print("\n[6.3] 嵌套结构...") os.environ['API_URL'] = 'https://api.example.com' config = { "data": { "sources": [ {"url": "${API_URL}"} ] } } result = loader._substitute_env_vars(config) assert result["data"]["sources"][0]["url"] == "https://api.example.com" print(f" ✓ 嵌套替换成功") print("\n✓ 测试通过") if __name__ == "__main__": print("\n" + "=" * 70) print(" 配置加载和验证测试") print("=" * 70) tests = [ ("加载配置文件", test_load_config), ("资产池配置", test_asset_pools), ("阈值配置", test_threshold_config), ("数据源配置", test_data_sources), ("验证错误处理", test_validation_errors), ("环境变量替换", test_env_substitution), ] passed = 0 failed = 0 for name, test_func in tests: try: test_func() passed += 1 except Exception as e: print(f"\n✗ 测试失败: {name}") print(f" 错误: {e}") import traceback traceback.print_exc() failed += 1 print("\n" + "=" * 70) print(" 测试总结") print("=" * 70) print(f" ✓ 通过 - {passed}") if failed > 0: print(f" ✗ 失败 - {failed}") print(f"\n总计: {passed}/{passed + failed} 通过") print("=" * 70 + "\n") if failed > 0: sys.exit(1)