Files
etf/rotation/config_loader.py
aszerW 921f84cb6a feat: 新增 standardized_slope (t-statistic) 因子并实验验证
- simple_rotation.py: 新增 standardized_slope_score 函数 (slope/SE)
- config_loader.py: FactorType 枚举新增 STANDARDIZED_SLOPE
- 对比实验结果: standardized_slope 年化 13.73% vs slope_r2 19.84%
- 结论: t-statistic 过度惩罚高波动资产的有效趋势信号,不适合本场景
- 文档更新: 动量因子对比调研报告新增 3.3 节详细分析
2026-06-06 16:40:01 +08:00

459 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
配置加载器 - rotation 目录专用(独立实现)
不依赖 framework_v2完全独立的配置加载和验证机制。
使用示例:
from config_loader import load_rotation_config
# 加载配置(自动处理环境变量)
config = load_rotation_config('config_simple.yaml')
# 访问配置
print(f"策略版本: {config.metadata.version}")
print(f"动量窗口: {config.factor.n_days}")
print(f"资产数量: {config.asset_pools.count()}")
"""
import os
import re
import yaml
from pathlib import Path
from typing import Optional, Dict, List, Any
from pydantic import BaseModel, Field, field_validator
from enum import Enum
# ============================================================
# Enum types
# ============================================================
class FactorType(str, Enum):
    """Names of the scoring factors a config may select.

    Members are also plain ``str`` values, so they compare equal to the raw
    strings used in the YAML file.
    """

    MOMENTUM = "momentum"
    SLOPE_R2 = "slope_r2"
    WEIGHTED_MOMENTUM = "weighted_momentum"
    VOL_ADJUSTED_MOMENTUM = "vol_adjusted_momentum"
    STANDARDIZED_SLOPE = "standardized_slope"
class PremiumMode(str, Enum):
    """How an asset that breaches the premium limit is treated."""

    # Exclude the asset entirely.
    FILTER = "filter"
    # Keep the asset but down-weight it.
    PENALIZE = "penalize"
class ThresholdMode(str, Enum):
    """Which kind of threshold the rotation step uses."""

    # A constant threshold value.
    FIXED = "fixed"
    # A threshold derived at run time.
    DYNAMIC = "dynamic"
class DataSourceType(str, Enum):
    """Identifiers of the market-data backends a config may list."""

    FLASK_API = "flask_api"
    TUSHARE = "tushare"
    YFINANCE = "yfinance"
# ============================================================
# Schema definitions (standalone implementation)
# ============================================================
class PremiumConfig(BaseModel):
    """Per-asset premium control: an on/off switch and a threshold in [0, 1]."""

    enabled: bool = Field(True, description="是否启用")
    threshold: float = Field(
        0.10,
        ge=0,
        le=1.0,
        description="溢价阈值",
    )
class AssetConfig(BaseModel):
    """One asset entry: display name, strategy group and signal/trade codes.

    ``name``, ``group``, ``signal_source`` and ``trade_source`` are required;
    the remaining fields default to ``None``.
    """

    name: str = Field(..., description="标的名称")
    group: str = Field(..., description="策略分组")
    signal_source: str = Field(..., description="信号来源代码")
    trade_source: str = Field(..., description="交易来源代码")
    description: Optional[str] = Field(default=None, description="标的描述")
    # Kept for compatibility with older config files.
    etf: Optional[str] = Field(default=None, description="ETF代码兼容旧配置")
    premium_control: Optional["PremiumConfig"] = Field(
        default=None,
        description="溢价控制配置",
    )

    @property
    def is_cross_market(self) -> bool:
        """Return True when the signal code differs from the trade code."""
        signal, trade = self.signal_source, self.trade_source
        return signal != trade
class AssetPool(BaseModel):
    """Flat asset pool keyed by asset code, with optional per-group views.

    Every query method accepts an optional ``group``. A falsy value
    (``None`` or an empty string) means "all assets".
    """

    assets: Dict[str, AssetConfig] = Field(
        default_factory=dict,
        description="所有标的"
    )

    def _select(self, group: Optional[str] = None) -> List[AssetConfig]:
        """Return the assets in ``group``, or every asset when ``group`` is falsy."""
        if not group:
            return list(self.assets.values())
        return [asset for asset in self.assets.values() if asset.group == group]

    @property
    def by_group(self) -> Dict[str, Dict[str, AssetConfig]]:
        """Return ``{group: {code: asset}}``, rebuilt on every access."""
        grouped: Dict[str, Dict[str, AssetConfig]] = {}
        for code, asset in self.assets.items():
            grouped.setdefault(asset.group, {})[code] = asset
        return grouped

    @property
    def groups(self) -> List[str]:
        """Return the group names in order of first appearance in ``assets``."""
        return list(self.by_group)

    def get_signal_codes(self, group: Optional[str] = None) -> List[str]:
        """Return the ``signal_source`` of each selected asset."""
        return [asset.signal_source for asset in self._select(group)]

    def get_trade_codes(self, group: Optional[str] = None) -> List[str]:
        """Return the ``trade_source`` of each selected asset."""
        return [asset.trade_source for asset in self._select(group)]

    def get_signal_to_trade_mapping(self, group: Optional[str] = None) -> Dict[str, str]:
        """Return ``{signal_source: trade_source}`` for the selected assets.

        When two assets share a signal code, the later one wins.
        """
        return {
            asset.signal_source: asset.trade_source
            for asset in self._select(group)
        }

    def count(self, group: Optional[str] = None) -> int:
        """Return how many assets are selected."""
        return len(self._select(group))
class FactorConfig(BaseModel):
    """Factor selection: the factor type and its ``n_days`` window (5-250)."""

    type: FactorType = Field(FactorType.WEIGHTED_MOMENTUM)
    n_days: int = Field(25, ge=5, le=250)
class DynamicThresholdConfig(BaseModel):
    """Settings for a dynamic threshold.

    ``reference`` is required; ``ratio`` must be non-negative.
    """

    reference: str = Field(..., description="阈值参考标的")
    ratio: float = Field(1.0, ge=0, description="阈值倍数")
    # NOTE(review): presumably the value used when the reference cannot be
    # evaluated; confirm against the consumer of this config.
    fallback_enabled: bool = Field(True)
    fallback_value: float = Field(0.0)
class ThresholdConfig(BaseModel):
    """Threshold settings: a mode, a fixed value and optional dynamic settings.

    The default mode is dynamic while ``dynamic`` itself defaults to ``None``.
    """

    mode: ThresholdMode = Field(ThresholdMode.DYNAMIC)
    fixed_value: float = Field(0.0)
    dynamic: Optional[DynamicThresholdConfig] = Field(default=None)
class RotationConfig(BaseModel):
    """Rotation settings: pick count (1-10), diversification flag, threshold."""

    select_num: int = Field(3, ge=1, le=10)
    diversified: bool = Field(True)
    threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
class RebalanceConfig(BaseModel):
    """Rebalancing settings: minimum hold days, score threshold, trade cost.

    ``min_hold_days`` must be at least 1; ``trade_cost`` is bounded to
    [0, 0.01].
    """

    min_hold_days: int = Field(1, ge=1)
    score_threshold: float = Field(0.0)
    trade_cost: float = Field(0.001, ge=0, le=0.01)
class PremiumControlConfig(BaseModel):
    """Strategy-wide premium control settings (disabled by default)."""

    enabled: bool = Field(False)
    default_threshold: float = Field(0.10)
    mode: PremiumMode = Field(PremiumMode.FILTER)
    # NOTE(review): presumably the multiplier applied in penalize mode;
    # confirm against the code that consumes it.
    penalty_factor: float = Field(0.5)
    market_overrides: Optional[Dict[str, Dict[str, Any]]] = Field(default=None)
class DataSourceConfig(BaseModel):
    """One data source entry: required type plus enabled flag, timeout, URL."""

    type: DataSourceType = Field(...)
    enabled: bool = Field(True)
    # Bounded to 10..600; presumably seconds (the unit is not stated here).
    timeout: int = Field(120, ge=10, le=600)
    url: Optional[str] = Field(default=None)
class DataConfig(BaseModel):
    """Data settings: a required, non-empty list of data sources."""

    sources: List[DataSourceConfig] = Field(...)

    @field_validator('sources')
    @classmethod
    def check_at_least_one(cls, v):
        """Reject an empty ``sources`` list; otherwise pass it through."""
        if v:
            return v
        raise ValueError("必须配置至少一个数据源")
class BenchmarkConfig(BaseModel):
    """Benchmark identification: both ``code`` and ``name`` are required."""

    code: str = Field(default=...)
    name: str = Field(default=...)
class BacktestConfig(BaseModel):
    """Backtest date range: a required start date and an optional end date.

    Both dates are kept as plain strings; no format is enforced here.
    """

    start_date: str = Field(default=...)
    end_date: Optional[str] = Field(default=None)
class MetadataConfig(BaseModel):
    """Descriptive metadata about the config file; every field has a default."""

    version: str = Field("1.0.0")
    strategy: str = Field("rotation")
    description: str = Field("")
    last_updated: Optional[str] = Field(default=None)
class RotationStrategyConfig(BaseModel):
    """Root schema of the ETF rotation strategy config.

    ``asset_pools``, ``benchmark``, ``backtest`` and ``data`` are required;
    the remaining sections fall back to their own defaults.
    """

    metadata: MetadataConfig = Field(default_factory=MetadataConfig)
    asset_pools: AssetPool = Field(default=...)
    benchmark: BenchmarkConfig = Field(default=...)
    backtest: BacktestConfig = Field(default=...)
    factor: FactorConfig = Field(default_factory=FactorConfig)
    rotation: RotationConfig = Field(default_factory=RotationConfig)
    rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig)
    premium_control: PremiumControlConfig = Field(
        default_factory=PremiumControlConfig,
    )
    data: DataConfig = Field(default=...)
# ============================================================
# Config loader (standalone implementation)
# ============================================================
class ConfigLoader:
    """Locate, read, env-expand and validate a rotation strategy YAML file.

    Relative file names are searched first under ``config_dir`` and then
    under the current working directory.
    """

    # Matches ``${VAR_NAME}`` and ``${VAR_NAME:default}``. Compiled once
    # rather than for every string visited during the recursive walk.
    _ENV_PLACEHOLDER = re.compile(r'\$\{([^}]+)\}')

    def __init__(self, config_dir: Optional[str] = None):
        """Remember the directory that relative config names resolve against.

        Args:
            config_dir: Base directory; defaults to the directory that
                contains this module.
        """
        if config_dir is None:
            config_dir = Path(__file__).parent
        self.config_dir = Path(config_dir)

    def load(self, config_file: str) -> "RotationStrategyConfig":
        """Load ``config_file`` and return the validated config object.

        Args:
            config_file: Absolute path, or a name relative to ``config_dir``
                or the current working directory.

        Returns:
            The validated ``RotationStrategyConfig``.

        Raises:
            FileNotFoundError: The file cannot be located.
            ValueError: The file is empty, its top level is not a mapping,
                or a ``${VAR}`` placeholder has neither an environment value
                nor a default.
        """
        config_path = self._resolve_path(config_file)
        with open(config_path, 'r', encoding='utf-8') as f:
            raw = yaml.safe_load(f)
        # safe_load yields None for an empty document; without this check
        # the ``**`` unpacking below fails with an unhelpful TypeError.
        config_dict = self._require_mapping(raw, config_path)
        config_dict = self._substitute_env_vars(config_dict)
        return RotationStrategyConfig(**config_dict)

    @staticmethod
    def _require_mapping(raw: Any, config_path: Any) -> dict:
        """Return ``raw`` if it is a dict, else raise a descriptive ValueError."""
        if raw is None:
            raise ValueError(f"配置文件为空: {config_path}")
        if not isinstance(raw, dict):
            raise ValueError(
                f"配置文件顶层必须是映射 (dict), 实际为 "
                f"{type(raw).__name__}: {config_path}"
            )
        return raw

    def _resolve_path(self, config_file: str) -> Path:
        """Turn ``config_file`` into an existing path.

        Absolute paths are only checked for existence. Relative paths are
        tried under ``config_dir`` first, then under the working directory.

        Raises:
            FileNotFoundError: No candidate path exists.
        """
        path = Path(config_file)
        if path.is_absolute():
            if not path.exists():
                raise FileNotFoundError(f"配置文件不存在: {path}")
            return path

        config_path = self.config_dir / path
        if config_path.exists():
            return config_path

        cwd_path = Path.cwd() / path
        if cwd_path.exists():
            return cwd_path

        raise FileNotFoundError(
            f"配置文件未找到: {path}\n"
            f"搜索路径:\n"
            f" - {config_path}\n"
            f" - {cwd_path}"
        )

    def _substitute_env_vars(self, config: Any) -> Any:
        """Recursively expand ``${VAR}`` / ``${VAR:default}`` in string values.

        Dicts and lists are rebuilt with their values or items expanded; dict
        keys are left untouched. Non-string scalars are returned unchanged.

        Raises:
            ValueError: A referenced variable is unset and has no default.
        """
        if isinstance(config, dict):
            return {
                key: self._substitute_env_vars(value)
                for key, value in config.items()
            }
        if isinstance(config, list):
            return [self._substitute_env_vars(item) for item in config]
        if isinstance(config, str):
            return self._ENV_PLACEHOLDER.sub(self._expand_placeholder, config)
        return config

    @staticmethod
    def _expand_placeholder(match) -> str:
        """Resolve one placeholder match to its environment value or default."""
        # Split on the first ':' only, so defaults such as URLs keep theirs.
        var_name, has_default, default = match.group(1).partition(':')
        value = os.getenv(var_name)
        if value is not None:
            return value
        if has_default:
            return default
        raise ValueError(
            f"环境变量未设置: {var_name}\n"
            f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}"
        )
# ============================================================
# Convenience functions
# ============================================================
def load_rotation_config(config_file: str = 'config_simple.yaml') -> RotationStrategyConfig:
    """Load a rotation config through a ``ConfigLoader`` with default settings.

    Args:
        config_file: Name of the config file to load.

    Returns:
        RotationStrategyConfig: The validated config object.

    Environment variables:
        FLASK_API_URL must be set or given a default in the config
        (presumably only when the YAML references it; confirm against the
        config file).
    """
    return ConfigLoader().load(config_file)
def get_config_info(config: RotationStrategyConfig) -> dict:
    """Collect the headline settings of ``config`` into a flat dict.

    Enum-valued settings are reported by their string value, and a missing
    backtest end date is reported as the literal '至今'.
    """
    meta = config.metadata
    pools = config.asset_pools
    factor = config.factor
    rotation = config.rotation
    backtest = config.backtest
    return dict(
        version=meta.version,
        strategy=meta.strategy,
        description=meta.description,
        asset_count=pools.count(),
        groups=pools.groups,
        factor_type=factor.type.value,
        n_days=factor.n_days,
        select_num=rotation.select_num,
        start_date=backtest.start_date,
        end_date=backtest.end_date or '至今',
        threshold_mode=rotation.threshold.mode.value,
    )
def print_config_summary(config: RotationStrategyConfig) -> None:
    """Print a human-readable summary of ``config`` to stdout.

    The summary lists the headline settings first, then every asset grouped
    by strategy group, marking assets whose signal and trade codes differ.
    """
    info = get_config_info(config)
    rule = "=" * 50

    # Fixed-order header section; each entry is printed on its own line.
    header_lines = [
        "\n" + rule,
        " ETF 轮动策略配置摘要",
        rule,
        f"\n版本: {info['version']}",
        f"策略: {info['strategy']}",
        f"描述: {info['description']}",
        "\n资产池:",
        f" 总数: {info['asset_count']}",
        f" 分组: {', '.join(info['groups'])}",
        "\n因子配置:",
        f" 类型: {info['factor_type']}",
        f" 窗口: {info['n_days']}",
        "\n轮动配置:",
        f" 选股数: {info['select_num']}",
        f" 阈值模式: {info['threshold_mode']}",
        "\n回测配置:",
        f" 起始: {info['start_date']}",
        f" 结束: {info['end_date']}",
        "\n各分组资产:",
    ]
    for line in header_lines:
        print(line)

    for group, members in config.asset_pools.by_group.items():
        print(f"\n [{group}] ({len(members)} 个):")
        for code, asset in members.items():
            marker = ""
            if asset.is_cross_market:
                marker = "✓ 跨市场"
            print(f" - {asset.name} ({code})")
            print(f" 信号: {asset.signal_source} → 交易: {asset.trade_source} {marker}")

    print("\n" + rule)
# ============================================================
# Self-test
# ============================================================
if __name__ == "__main__":
    # Smoke test for the config loader: load the default config, print its
    # summary, and exit non-zero if anything fails.

    # Provide a fallback so ${FLASK_API_URL} placeholders can be resolved.
    if 'FLASK_API_URL' not in os.environ:
        print("提示: FLASK_API_URL 未设置,使用默认值")
        os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'

    try:
        # Load the config.
        print("\n[测试] 加载 config_simple.yaml...")
        config = load_rotation_config()

        # Print the summary.
        print_config_summary(config)

        # Basic sanity output.
        print("\n[验证] 基本功能测试:")
        print(f" ✓ 配置对象类型: {type(config).__name__}")
        print(f" ✓ 资产池对象类型: {type(config.asset_pools).__name__}")
        print(f" ✓ 资产数量: {config.asset_pools.count()}")
        print(f" ✓ 分组数量: {len(config.asset_pools.groups)}")

        # Data source listing.
        print("\n[验证] 数据源配置:")
        for i, source in enumerate(config.data.sources, 1):
            print(f" {i}. {source.type.value}")
            print(f" URL: {source.url}")
            print(f" 启用: {source.enabled}")

        print("\n✓ 所有测试通过!")
    except Exception as e:
        print(f"\n✗ 加载失败: {e}")
        import traceback
        traceback.print_exc()
        # Report the failure through the exit status; previously the script
        # still exited with 0 after printing the traceback.
        raise SystemExit(1)