配置文件: - rotation_global.yaml: 扁平化资产池配置示例,演示 group 策略分组 * 13 个标的覆盖 7 个策略分组(US_TECH, CN_GROWTH, JP_BROAD, EU_BROAD, HK_TECH, COMMODITY, FIXED_INCOME) * signal_source/trade_source 分离配置(跨市场场景) * 分散化选股配置示例(注释状态) * 默认使用 Flask API 数据源 测试用例: - test_flat_asset_pool.py: 7/7 测试通过 * 扁平配置加载验证 * 策略分组功能测试(by_group, groups, count) * 信号/交易标的获取(get_signal_codes, get_trade_codes) * 信号→交易映射(get_signal_to_trade_mapping) * 分散化配置验证 * 标的配置详情验证 - test_config.py: 配置加载器测试 - test_simple_rotation.py: 简单轮动策略端到端测试
282 lines
8.5 KiB
Python
282 lines
8.5 KiB
Python
"""
|
|
测试扁平化资产池配置
|
|
|
|
验证:
|
|
1. 扁平化配置加载
|
|
2. 按市场分组
|
|
3. 信号/交易标的获取
|
|
4. 跨市场映射
|
|
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
import os
|
|
|
|
# 添加项目根目录到路径
|
|
project_root = Path(__file__).parent.parent
|
|
if str(project_root) not in sys.path:
|
|
sys.path.insert(0, str(project_root))
|
|
|
|
from framework_v2.config import load_config
|
|
from framework_v2.config.schemas import GroupConfig
|
|
|
|
|
|
def test_flat_config_load():
|
|
"""测试 1: 加载扁平化配置"""
|
|
print("\n" + "=" * 70)
|
|
print(" 测试 1: 加载扁平化配置")
|
|
print("=" * 70)
|
|
|
|
# 设置环境变量
|
|
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
|
os.environ['TUSHARE_TOKEN'] = 'test_token'
|
|
|
|
# 加载配置
|
|
config = load_config('rotation_global.yaml')
|
|
|
|
print(f"\n✓ 配置加载成功")
|
|
print(f" 版本: {config.metadata.version}")
|
|
print(f" 策略: {config.metadata.strategy}")
|
|
print(f" 总标的数: {config.asset_pools.count()}")
|
|
|
|
# 验证基本字段
|
|
assert config.metadata.version == "2.0.0"
|
|
assert config.asset_pools.count() == 13 # 12 个标的
|
|
|
|
print("\n✓ 测试通过")
|
|
|
|
|
|
def test_market_grouping():
|
|
"""测试 2: 按市场分组"""
|
|
print("\n" + "=" * 70)
|
|
print(" 测试 2: 按市场分组")
|
|
print("=" * 70)
|
|
|
|
config = load_config('rotation_global.yaml')
|
|
|
|
# 获取所有市场
|
|
markets = config.asset_pools.groups
|
|
print(f"\n市场类型 ({len(markets)} 个):")
|
|
for market in markets:
|
|
count = config.asset_pools.count(market)
|
|
print(f" {market}: {count} 只")
|
|
|
|
# 按市场分组
|
|
by_group = config.asset_pools.by_group
|
|
print(f"\n市场分组:")
|
|
for market, assets in by_group.items():
|
|
print(f"\n {market} ({len(assets)} 只):")
|
|
for code, asset in assets.items():
|
|
print(f" {code}: {asset.name}")
|
|
|
|
# 验证市场数量
|
|
assert len(markets) == 7 # US_TECH, CN_GROWTH, JP_BROAD, EU_BROAD, HK_TECH, COMMODITY, FIXED_INCOME # US_EQUITY, CN_EQUITY, JP_EQUITY, EU_EQUITY, HK_EQUITY, COMMODITY, FIXED_INCOME
|
|
assert config.asset_pools.count('US_TECH') == 2
|
|
assert config.asset_pools.count('CN_GROWTH') == 3
|
|
assert config.asset_pools.count('COMMODITY') == 3
|
|
assert config.asset_pools.count('FIXED_INCOME') == 1
|
|
|
|
print("\n✓ 测试通过")
|
|
|
|
|
|
def test_signal_trade_codes():
|
|
"""测试 3: 信号和交易标的"""
|
|
print("\n" + "=" * 70)
|
|
print(" 测试 3: 信号和交易标的")
|
|
print("=" * 70)
|
|
|
|
config = load_config('rotation_global.yaml')
|
|
|
|
# 获取所有信号标的
|
|
signal_codes = config.asset_pools.get_signal_codes()
|
|
print(f"\n信号标的 (13 个):")
|
|
for code in signal_codes:
|
|
print(f" {code}")
|
|
|
|
# 获取所有交易标的
|
|
trade_codes = config.asset_pools.get_trade_codes()
|
|
print(f"\n交易标的 (13 个):")
|
|
for code in trade_codes:
|
|
print(f" {code}")
|
|
|
|
# 获取特定市场的信号标的
|
|
us_signals = config.asset_pools.get_signal_codes('US_TECH')
|
|
print(f"\n美股信号标的: {us_signals}")
|
|
|
|
# 验证
|
|
assert len(signal_codes) == 13
|
|
assert len(trade_codes) == 13
|
|
assert 'NDX' in signal_codes
|
|
assert '513100.SH' in trade_codes
|
|
assert len(us_signals) == 2
|
|
|
|
print("\n✓ 测试通过")
|
|
|
|
|
|
def test_signal_to_trade_mapping():
|
|
"""测试 4: 信号→交易映射"""
|
|
print("\n" + "=" * 70)
|
|
print(" 测试 4: 信号→交易映射")
|
|
print("=" * 70)
|
|
|
|
config = load_config('rotation_global.yaml')
|
|
|
|
# 获取映射
|
|
mapping = config.asset_pools.get_signal_to_trade_mapping()
|
|
|
|
print(f"\n信号→交易映射:")
|
|
for signal, trade in mapping.items():
|
|
asset = config.asset_pools.assets.get(signal)
|
|
cross_market = "✗" if asset.signal_source == asset.trade_source else "✓"
|
|
print(f" {cross_market} {signal} → {trade}")
|
|
|
|
# 验证跨市场标的
|
|
print(f"\n跨市场标的:")
|
|
for code, asset in config.asset_pools.assets.items():
|
|
if asset.is_cross_market:
|
|
print(f" {code}: {asset.signal_source} → {asset.trade_source}")
|
|
|
|
# 验证映射
|
|
assert mapping['NDX'] == '513100.SH'
|
|
assert mapping['399006.SZ'] == '159915.SZ'
|
|
assert mapping['GC=F'] == '518880.SH'
|
|
assert mapping['931862.CSI'] == '931862.CSI' # 非跨市场
|
|
|
|
# 验证跨市场属性
|
|
assert config.asset_pools.assets['NDX'].is_cross_market == True
|
|
assert config.asset_pools.assets['931862.CSI'].is_cross_market == False
|
|
|
|
print("\n✓ 测试通过")
|
|
|
|
|
|
def test_market_specific_mapping():
|
|
"""测试 5: 特定市场映射"""
|
|
print("\n" + "=" * 70)
|
|
print(" 测试 5: 特定市场映射")
|
|
print("=" * 70)
|
|
|
|
config = load_config('rotation_global.yaml')
|
|
|
|
# 获取美股映射
|
|
us_mapping = config.asset_pools.get_signal_to_trade_mapping('US_TECH')
|
|
print(f"\n美股映射:")
|
|
for signal, trade in us_mapping.items():
|
|
print(f" {signal} → {trade}")
|
|
|
|
# 获取商品映射
|
|
commodity_mapping = config.asset_pools.get_signal_to_trade_mapping('COMMODITY')
|
|
print(f"\n商品映射:")
|
|
for signal, trade in commodity_mapping.items():
|
|
print(f" {signal} → {trade}")
|
|
|
|
# 验证
|
|
assert len(us_mapping) == 2
|
|
assert len(commodity_mapping) == 3
|
|
assert us_mapping['NDX'] == '513100.SH'
|
|
assert commodity_mapping['GC=F'] == '518880.SH'
|
|
|
|
print("\n✓ 测试通过")
|
|
|
|
|
|
def test_diversification_config():
|
|
"""测试 6: 分散化配置"""
|
|
print("\n" + "=" * 70)
|
|
print(" 测试 6: 分散化配置")
|
|
print("=" * 70)
|
|
|
|
config = load_config('rotation_global.yaml')
|
|
|
|
print(f"\n轮动配置:")
|
|
print(f" 选股数量: {config.rotation.select_num}")
|
|
print(f" 分散化: {config.rotation.diversified}")
|
|
print(f" 分散化分组: {config.rotation.diversification_groups}")
|
|
|
|
# 验证默认配置(全局模式)
|
|
assert config.rotation.select_num == 5
|
|
assert config.rotation.diversified == False
|
|
assert config.rotation.diversification_groups is None
|
|
|
|
print("\n✓ 测试通过")
|
|
|
|
|
|
def test_asset_config_details():
|
|
"""测试 7: 标的配置详情"""
|
|
print("\n" + "=" * 70)
|
|
print(" 测试 7: 标的配置详情")
|
|
print("=" * 70)
|
|
|
|
config = load_config('rotation_global.yaml')
|
|
|
|
# 检查纳指配置
|
|
ndx = config.asset_pools.assets['NDX']
|
|
print(f"\n纳指100 配置:")
|
|
print(f" 名称: {ndx.name}")
|
|
print(f" 市场: {ndx.group}")
|
|
print(f" 信号来源: {ndx.signal_source}")
|
|
print(f" 交易来源: {ndx.trade_source}")
|
|
print(f" 跨市场: {ndx.is_cross_market}")
|
|
print(f" 描述: {ndx.description}")
|
|
|
|
# 检查短债配置
|
|
bond = config.asset_pools.assets['931862.CSI']
|
|
print(f"\n短债指数 配置:")
|
|
print(f" 名称: {bond.name}")
|
|
print(f" 市场: {bond.group}")
|
|
print(f" 信号来源: {bond.signal_source}")
|
|
print(f" 交易来源: {bond.trade_source}")
|
|
print(f" 跨市场: {bond.is_cross_market}")
|
|
|
|
# 验证
|
|
assert ndx.name == "纳指100"
|
|
assert ndx.group == 'US_TECH'
|
|
assert ndx.signal_source == "NDX"
|
|
assert ndx.trade_source == "513100.SH"
|
|
assert ndx.is_cross_market == True
|
|
|
|
assert bond.signal_source == bond.trade_source
|
|
assert bond.is_cross_market == False
|
|
|
|
print("\n✓ 测试通过")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
print("\n" + "=" * 70)
|
|
print(" 扁平化资产池配置测试")
|
|
print("=" * 70)
|
|
|
|
tests = [
|
|
("加载扁平化配置", test_flat_config_load),
|
|
("按市场分组", test_market_grouping),
|
|
("信号和交易标的", test_signal_trade_codes),
|
|
("信号→交易映射", test_signal_to_trade_mapping),
|
|
("特定市场映射", test_market_specific_mapping),
|
|
("分散化配置", test_diversification_config),
|
|
("标的配置详情", test_asset_config_details),
|
|
]
|
|
|
|
passed = 0
|
|
failed = 0
|
|
|
|
for name, test_func in tests:
|
|
try:
|
|
test_func()
|
|
passed += 1
|
|
except Exception as e:
|
|
print(f"\n✗ 测试失败: {name}")
|
|
print(f" 错误: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
failed += 1
|
|
|
|
print("\n" + "=" * 70)
|
|
print(" 测试总结")
|
|
print("=" * 70)
|
|
print(f" ✓ 通过 - {passed}")
|
|
if failed > 0:
|
|
print(f" ✗ 失败 - {failed}")
|
|
print(f"\n总计: {passed}/{passed + failed} 通过")
|
|
print("=" * 70 + "\n")
|
|
|
|
if failed > 0:
|
|
sys.exit(1)
|