""" 测试扁平化资产池配置 验证: 1. 扁平化配置加载 2. 按市场分组 3. 信号/交易标的获取 4. 跨市场映射 """ import sys from pathlib import Path import os # 添加项目根目录到路径 project_root = Path(__file__).parent.parent if str(project_root) not in sys.path: sys.path.insert(0, str(project_root)) from framework_v2.config import load_config from framework_v2.config.schemas import GroupConfig def test_flat_config_load(): """测试 1: 加载扁平化配置""" print("\n" + "=" * 70) print(" 测试 1: 加载扁平化配置") print("=" * 70) # 设置环境变量 os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz' os.environ['TUSHARE_TOKEN'] = 'test_token' # 加载配置 config = load_config('rotation_global.yaml') print(f"\n✓ 配置加载成功") print(f" 版本: {config.metadata.version}") print(f" 策略: {config.metadata.strategy}") print(f" 总标的数: {config.asset_pools.count()}") # 验证基本字段 assert config.metadata.version == "2.0.0" assert config.asset_pools.count() == 13 # 12 个标的 print("\n✓ 测试通过") def test_market_grouping(): """测试 2: 按市场分组""" print("\n" + "=" * 70) print(" 测试 2: 按市场分组") print("=" * 70) config = load_config('rotation_global.yaml') # 获取所有市场 markets = config.asset_pools.groups print(f"\n市场类型 ({len(markets)} 个):") for market in markets: count = config.asset_pools.count(market) print(f" {market}: {count} 只") # 按市场分组 by_group = config.asset_pools.by_group print(f"\n市场分组:") for market, assets in by_group.items(): print(f"\n {market} ({len(assets)} 只):") for code, asset in assets.items(): print(f" {code}: {asset.name}") # 验证市场数量 assert len(markets) == 7 # US_TECH, CN_GROWTH, JP_BROAD, EU_BROAD, HK_TECH, COMMODITY, FIXED_INCOME # US_EQUITY, CN_EQUITY, JP_EQUITY, EU_EQUITY, HK_EQUITY, COMMODITY, FIXED_INCOME assert config.asset_pools.count('US_TECH') == 2 assert config.asset_pools.count('CN_GROWTH') == 3 assert config.asset_pools.count('COMMODITY') == 3 assert config.asset_pools.count('FIXED_INCOME') == 1 print("\n✓ 测试通过") def test_signal_trade_codes(): """测试 3: 信号和交易标的""" print("\n" + "=" * 70) print(" 测试 3: 信号和交易标的") print("=" * 70) config = load_config('rotation_global.yaml') # 获取所有信号标的 signal_codes = config.asset_pools.get_signal_codes() print(f"\n信号标的 (13 个):") for code in signal_codes: print(f" {code}") # 获取所有交易标的 trade_codes = config.asset_pools.get_trade_codes() print(f"\n交易标的 (13 个):") for code in trade_codes: print(f" {code}") # 获取特定市场的信号标的 us_signals = config.asset_pools.get_signal_codes('US_TECH') print(f"\n美股信号标的: {us_signals}") # 验证 assert len(signal_codes) == 13 assert len(trade_codes) == 13 assert 'NDX' in signal_codes assert '513100.SH' in trade_codes assert len(us_signals) == 2 print("\n✓ 测试通过") def test_signal_to_trade_mapping(): """测试 4: 信号→交易映射""" print("\n" + "=" * 70) print(" 测试 4: 信号→交易映射") print("=" * 70) config = load_config('rotation_global.yaml') # 获取映射 mapping = config.asset_pools.get_signal_to_trade_mapping() print(f"\n信号→交易映射:") for signal, trade in mapping.items(): asset = config.asset_pools.assets.get(signal) cross_market = "✗" if asset.signal_source == asset.trade_source else "✓" print(f" {cross_market} {signal} → {trade}") # 验证跨市场标的 print(f"\n跨市场标的:") for code, asset in config.asset_pools.assets.items(): if asset.is_cross_market: print(f" {code}: {asset.signal_source} → {asset.trade_source}") # 验证映射 assert mapping['NDX'] == '513100.SH' assert mapping['399006.SZ'] == '159915.SZ' assert mapping['GC=F'] == '518880.SH' assert mapping['931862.CSI'] == '931862.CSI' # 非跨市场 # 验证跨市场属性 assert config.asset_pools.assets['NDX'].is_cross_market == True assert config.asset_pools.assets['931862.CSI'].is_cross_market == False print("\n✓ 测试通过") def test_market_specific_mapping(): """测试 5: 特定市场映射""" print("\n" + "=" * 70) print(" 测试 5: 特定市场映射") print("=" * 70) config = load_config('rotation_global.yaml') # 获取美股映射 us_mapping = config.asset_pools.get_signal_to_trade_mapping('US_TECH') print(f"\n美股映射:") for signal, trade in us_mapping.items(): print(f" {signal} → {trade}") # 获取商品映射 commodity_mapping = config.asset_pools.get_signal_to_trade_mapping('COMMODITY') print(f"\n商品映射:") for signal, trade in commodity_mapping.items(): print(f" {signal} → {trade}") # 验证 assert len(us_mapping) == 2 assert len(commodity_mapping) == 3 assert us_mapping['NDX'] == '513100.SH' assert commodity_mapping['GC=F'] == '518880.SH' print("\n✓ 测试通过") def test_diversification_config(): """测试 6: 分散化配置""" print("\n" + "=" * 70) print(" 测试 6: 分散化配置") print("=" * 70) config = load_config('rotation_global.yaml') print(f"\n轮动配置:") print(f" 选股数量: {config.rotation.select_num}") print(f" 分散化: {config.rotation.diversified}") print(f" 分散化分组: {config.rotation.diversification_groups}") # 验证默认配置(全局模式) assert config.rotation.select_num == 5 assert config.rotation.diversified == False assert config.rotation.diversification_groups is None print("\n✓ 测试通过") def test_asset_config_details(): """测试 7: 标的配置详情""" print("\n" + "=" * 70) print(" 测试 7: 标的配置详情") print("=" * 70) config = load_config('rotation_global.yaml') # 检查纳指配置 ndx = config.asset_pools.assets['NDX'] print(f"\n纳指100 配置:") print(f" 名称: {ndx.name}") print(f" 市场: {ndx.group}") print(f" 信号来源: {ndx.signal_source}") print(f" 交易来源: {ndx.trade_source}") print(f" 跨市场: {ndx.is_cross_market}") print(f" 描述: {ndx.description}") # 检查短债配置 bond = config.asset_pools.assets['931862.CSI'] print(f"\n短债指数 配置:") print(f" 名称: {bond.name}") print(f" 市场: {bond.group}") print(f" 信号来源: {bond.signal_source}") print(f" 交易来源: {bond.trade_source}") print(f" 跨市场: {bond.is_cross_market}") # 验证 assert ndx.name == "纳指100" assert ndx.group == 'US_TECH' assert ndx.signal_source == "NDX" assert ndx.trade_source == "513100.SH" assert ndx.is_cross_market == True assert bond.signal_source == bond.trade_source assert bond.is_cross_market == False print("\n✓ 测试通过") if __name__ == "__main__": print("\n" + "=" * 70) print(" 扁平化资产池配置测试") print("=" * 70) tests = [ ("加载扁平化配置", test_flat_config_load), ("按市场分组", test_market_grouping), ("信号和交易标的", test_signal_trade_codes), ("信号→交易映射", test_signal_to_trade_mapping), ("特定市场映射", test_market_specific_mapping), ("分散化配置", test_diversification_config), ("标的配置详情", test_asset_config_details), ] passed = 0 failed = 0 for name, test_func in tests: try: test_func() passed += 1 except Exception as e: print(f"\n✗ 测试失败: {name}") print(f" 错误: {e}") import traceback traceback.print_exc() failed += 1 print("\n" + "=" * 70) print(" 测试总结") print("=" * 70) print(f" ✓ 通过 - {passed}") if failed > 0: print(f" ✗ 失败 - {failed}") print(f"\n总计: {passed}/{passed + failed} 通过") print("=" * 70 + "\n") if failed > 0: sys.exit(1)