- config_loader.py: WeightType 枚举新增 KELLY - simple_rotation.py: compute_position_weights 新增 kelly 分支 - 公式: w_i = max(score_i, 0) / sum(max(score_j, 0)) - 负分自动排除 (Kelly: 不下注负期望) - 全负分时 fallback 到等权 - _generate_signals 传递 scores 给 kelly 模式 - config_simple.yaml: weight 改为 kelly - 新增策略总结文档: kelly_weight.md 回测对比 (2020-2026): - equal: 年化 19.88%, 夏普 1.13, 回撤 -14.65% - rank: 年化 22.90%, 夏普 1.12, 回撤 -16.27% - kelly: 年化 30.13%, 夏普 1.15, 回撤 -20.44%
467 lines
15 KiB
Python
467 lines
15 KiB
Python
"""
|
||
配置加载器 - rotation 目录专用(独立实现)
|
||
|
||
不依赖 framework_v2,完全独立的配置加载和验证机制。
|
||
|
||
使用示例:
|
||
from config_loader import load_rotation_config
|
||
|
||
# 加载配置(自动处理环境变量)
|
||
config = load_rotation_config('config_simple.yaml')
|
||
|
||
# 访问配置
|
||
print(f"策略版本: {config.metadata.version}")
|
||
print(f"动量窗口: {config.factor.n_days}")
|
||
print(f"资产数量: {config.asset_pools.count()}")
|
||
"""
|
||
|
||
import os
|
||
import re
|
||
import yaml
|
||
from pathlib import Path
|
||
from typing import Optional, Dict, List, Any
|
||
from pydantic import BaseModel, Field, field_validator
|
||
from enum import Enum
|
||
|
||
|
||
# ============================================================
|
||
# 枚举类型
|
||
# ============================================================
|
||
|
||
class FactorType(str, Enum):
    """Scoring factor used to rank candidate assets.

    Each value is the literal string accepted by ``FactorConfig.type`` when
    the config is loaded from YAML.
    """

    MOMENTUM = "momentum"
    SLOPE_R2 = "slope_r2"
    WEIGHTED_MOMENTUM = "weighted_momentum"
    VOL_ADJUSTED_MOMENTUM = "vol_adjusted_momentum"
    STANDARDIZED_SLOPE = "standardized_slope"
|
||
|
||
|
||
class PremiumMode(str, Enum):
    """How an asset whose premium exceeds its threshold is treated."""

    # Drop the asset from selection entirely.
    FILTER = "filter"
    # Keep the asset but reduce its weight/score.
    PENALIZE = "penalize"
|
||
|
||
|
||
class ThresholdMode(str, Enum):
    """Whether the selection threshold is a constant or derived dynamically."""

    # Use ThresholdConfig.fixed_value as-is.
    FIXED = "fixed"
    # Derive the threshold from ThresholdConfig.dynamic.
    DYNAMIC = "dynamic"
|
||
|
||
|
||
class WeightType(str, Enum):
    """Position-sizing mode for the selected assets.

    Only the mode name lives here. The weights themselves are computed by
    the strategy code that consumes ``RotationConfig.weight``.
    """

    # Equal weight for every selected slot.
    EQUAL = "equal"
    # Rank-weighted: slot i gets (N-i)/triangular(N).
    RANK = "rank"
    # Kelly-criterion approximation (score-proportional weighting).
    KELLY = "kelly"
|
||
|
||
|
||
class DataSourceType(str, Enum):
    """Kinds of market-data backend that can appear in ``data.sources``."""

    FLASK_API = "flask_api"
    TUSHARE = "tushare"
    YFINANCE = "yfinance"
|
||
|
||
|
||
# ============================================================
|
||
# Schema 定义(独立实现)
|
||
# ============================================================
|
||
|
||
class PremiumConfig(BaseModel):
    """Per-asset premium-control settings (used by ``AssetConfig``)."""

    enabled: bool = Field(True, description="是否启用")
    # Validated to lie in [0, 1]; presumably a fraction (0.10 == 10%) — confirm.
    threshold: float = Field(0.10, ge=0, le=1.0, description="溢价阈值")
|
||
|
||
|
||
class AssetConfig(BaseModel):
    """One asset: the code its signal comes from and the code it trades as."""

    name: str = Field(..., description="标的名称")
    group: str = Field(..., description="策略分组")

    # The two codes may differ; see ``is_cross_market``.
    signal_source: str = Field(..., description="信号来源代码")
    trade_source: str = Field(..., description="交易来源代码")

    description: Optional[str] = Field(default=None, description="标的描述")
    # Kept so that older config files that still carry an ``etf`` key validate.
    etf: Optional[str] = Field(default=None, description="ETF代码(兼容旧配置)")
    premium_control: Optional[PremiumConfig] = Field(default=None, description="溢价控制配置")

    @property
    def is_cross_market(self) -> bool:
        """Return True when the signal code differs from the trade code."""
        same_code = self.signal_source == self.trade_source
        return not same_code
|
||
|
||
|
||
class AssetPool(BaseModel):
    """Flat pool of assets keyed by asset code.

    Groups are not stored separately. Every ``AssetConfig`` carries its own
    ``group`` tag, and the group views below are derived from ``assets`` on
    demand. All views follow the insertion order of ``assets``.

    Every accessor that takes ``group`` uses the same filtering rule:
    - a falsy value (``None`` or ``""``) means "all assets";
    - any other value keeps only assets whose ``group`` equals it exactly.
    """

    assets: Dict[str, AssetConfig] = Field(
        default_factory=dict,
        description="所有标的"
    )

    def _iter_assets(self, group: Optional[str] = None):
        """Yield assets in insertion order, keeping only ``group`` if truthy."""
        for asset in self.assets.values():
            if not group or asset.group == group:
                yield asset

    @property
    def by_group(self) -> Dict[str, Dict[str, AssetConfig]]:
        """Return ``{group: {code: asset}}`` built from ``assets``."""
        grouped: Dict[str, Dict[str, AssetConfig]] = {}
        for code, asset in self.assets.items():
            grouped.setdefault(asset.group, {})[code] = asset
        return grouped

    @property
    def groups(self) -> List[str]:
        """Return the distinct group names in first-seen order."""
        # dict.fromkeys de-duplicates while keeping first-seen order. That is
        # the same order as ``by_group`` keys, without building nested dicts.
        return list(dict.fromkeys(asset.group for asset in self.assets.values()))

    def get_signal_codes(self, group: Optional[str] = None) -> List[str]:
        """Return the signal codes of all assets, or only those in ``group``."""
        return [asset.signal_source for asset in self._iter_assets(group)]

    def get_trade_codes(self, group: Optional[str] = None) -> List[str]:
        """Return the trade codes of all assets, or only those in ``group``."""
        return [asset.trade_source for asset in self._iter_assets(group)]

    def get_signal_to_trade_mapping(self, group: Optional[str] = None) -> Dict[str, str]:
        """Return ``{signal_code: trade_code}``, optionally for one group.

        If several assets share a signal code, the last one in ``assets`` wins.
        """
        return {
            asset.signal_source: asset.trade_source
            for asset in self._iter_assets(group)
        }

    def count(self, group: Optional[str] = None) -> int:
        """Return the number of assets, optionally restricted to ``group``."""
        if not group:
            return len(self.assets)
        return sum(1 for _ in self._iter_assets(group))
|
||
|
||
|
||
class FactorConfig(BaseModel):
    """Scoring-factor settings."""

    # Which factor ranks the assets.
    type: FactorType = Field(FactorType.WEIGHTED_MOMENTUM)
    # Window length in days, validated to [5, 250].
    n_days: int = Field(25, ge=5, le=250)
|
||
|
||
|
||
class DynamicThresholdConfig(BaseModel):
    """Settings for a threshold tied to a reference asset."""

    # Code of the reference asset the threshold is based on (required).
    reference: str = Field(..., description="阈值参考标的")
    # Multiplier applied to the reference; must be non-negative.
    ratio: float = Field(1.0, ge=0, description="阈值倍数")
    # NOTE(review): how the fallback is triggered is decided by the consumer,
    # not here — presumably when the reference has no data; confirm.
    fallback_enabled: bool = Field(True)
    fallback_value: float = Field(0.0)
|
||
|
||
|
||
class ThresholdConfig(BaseModel):
    """Threshold settings: either a fixed value or a dynamic reference."""

    mode: ThresholdMode = Field(ThresholdMode.DYNAMIC)
    fixed_value: float = Field(0.0)
    # NOTE(review): the default mode is DYNAMIC while ``dynamic`` defaults to
    # None, and no validator checks that they agree. Consumers presumably
    # handle a missing ``dynamic`` block — confirm.
    dynamic: Optional[DynamicThresholdConfig] = Field(default=None)
|
||
|
||
|
||
class RotationConfig(BaseModel):
    """Rotation settings: how many assets to pick and how to size them."""

    # Number of assets selected, validated to [1, 10].
    select_num: int = Field(3, ge=1, le=10)
    # NOTE(review): semantics are defined by the strategy code — confirm.
    diversified: bool = Field(True)
    threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
    # Position-sizing mode: equal / rank / kelly.
    weight: WeightType = Field(WeightType.EQUAL)
|
||
|
||
|
||
class RebalanceConfig(BaseModel):
    """Rebalancing settings."""

    # Minimum holding period; at least 1.
    min_hold_days: int = Field(1, ge=1)
    score_threshold: float = Field(0.0)
    # Validated to [0, 0.01]; presumably a per-trade cost rate — confirm.
    trade_cost: float = Field(0.001, ge=0, le=0.01)
|
||
|
||
|
||
class PremiumControlConfig(BaseModel):
    """Strategy-wide premium-control settings (disabled by default)."""

    enabled: bool = Field(False)
    default_threshold: float = Field(0.10)
    # FILTER drops the asset; PENALIZE keeps it with a reduced score.
    mode: PremiumMode = Field(PremiumMode.FILTER)
    # Presumably the multiplier used in PENALIZE mode — confirm with consumer.
    penalty_factor: float = Field(0.5)
    # Optional per-market override dicts; their schema is not validated here.
    market_overrides: Optional[Dict[str, Dict[str, Any]]] = Field(default=None)
|
||
|
||
|
||
class DataSourceConfig(BaseModel):
    """One market-data backend entry."""

    type: DataSourceType = Field(...)
    enabled: bool = Field(True)
    # Validated to [10, 600]; presumably seconds — confirm with the client.
    timeout: int = Field(120, ge=10, le=600)
    url: Optional[str] = Field(default=None)
|
||
|
||
|
||
class DataConfig(BaseModel):
    """List of data sources; validation rejects an empty list."""

    sources: List[DataSourceConfig] = Field(...)

    @field_validator('sources')
    @classmethod
    def check_at_least_one(cls, sources):
        """Reject an empty ``sources`` list; otherwise pass it through."""
        if sources:
            return sources
        raise ValueError("必须配置至少一个数据源")
|
||
|
||
|
||
class BenchmarkConfig(BaseModel):
    """Benchmark identity; both fields are required."""

    code: str = Field(...)
    name: str = Field(...)
|
||
|
||
|
||
class BacktestConfig(BaseModel):
    """Backtest date range, kept as raw strings."""

    start_date: str = Field(...)
    # None means "up to the latest data" (shown as '至今' in the summary).
    end_date: Optional[str] = Field(default=None)
|
||
|
||
|
||
class MetadataConfig(BaseModel):
    """Descriptive metadata about the config file itself."""

    version: str = Field("1.0.0")
    strategy: str = Field("rotation")
    description: str = Field("")
    last_updated: Optional[str] = Field(default=None)
|
||
|
||
|
||
class RotationStrategyConfig(BaseModel):
    """Root model for a complete ETF rotation strategy config.

    Required sections: ``asset_pools``, ``benchmark``, ``backtest``, ``data``.
    Every other section falls back to its model defaults when omitted.
    """

    metadata: MetadataConfig = Field(default_factory=MetadataConfig)

    # --- required sections ---
    asset_pools: AssetPool = Field(...)
    benchmark: BenchmarkConfig = Field(...)
    backtest: BacktestConfig = Field(...)

    # --- optional sections with defaults ---
    factor: FactorConfig = Field(default_factory=FactorConfig)
    rotation: RotationConfig = Field(default_factory=RotationConfig)
    rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig)
    premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig)

    # --- required data sources ---
    data: DataConfig = Field(...)
|
||
|
||
|
||
# ============================================================
|
||
# 配置加载器(独立实现)
|
||
# ============================================================
|
||
|
||
class ConfigLoader:
    """Load a rotation YAML config, substitute env vars, and validate it.

    Standalone implementation with no dependency on framework_v2.
    """

    # Matches ``${VAR_NAME}`` or ``${VAR_NAME:default}``; group 1 is the text
    # between the braces. Compiled once rather than on every string value.
    _ENV_VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')

    def __init__(self, config_dir: Optional[str] = None):
        """Create a loader.

        Args:
            config_dir: Directory searched first for relative config names
                (a ``str`` or ``Path``). Defaults to this module's directory.
        """
        if config_dir is None:
            config_dir = Path(__file__).parent

        self.config_dir = Path(config_dir)

    def load(self, config_file: str) -> "RotationStrategyConfig":
        """Load, env-substitute and validate ``config_file``.

        Args:
            config_file: Absolute path, or a name relative to ``config_dir``
                or to the current working directory.

        Returns:
            The validated ``RotationStrategyConfig``.

        Raises:
            FileNotFoundError: No regular file was found for ``config_file``.
            ValueError: The YAML top level is not a mapping (e.g. an empty
                file), or a required environment variable is unset. Pydantic
                validation errors are also ``ValueError`` subclasses.
        """
        # 1. Resolve the path.
        config_path = self._resolve_path(config_file)

        # 2. Read YAML (safe_load: no arbitrary object construction).
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)

        # An empty document yields None (and a list/scalar document is not a
        # mapping). Without this check, ``**config_dict`` below fails with an
        # opaque TypeError.
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"配置文件顶层必须是 YAML 映射(key: value),实际为 "
                f"{type(config_dict).__name__}: {config_path}"
            )

        # 3. Substitute environment variables.
        config_dict = self._substitute_env_vars(config_dict)

        # 4. Validate with pydantic.
        return RotationStrategyConfig(**config_dict)

    def _resolve_path(self, config_file: str) -> Path:
        """Return the first existing regular file for ``config_file``.

        Absolute paths are used as-is. Relative paths are tried against
        ``config_dir`` first, then the current working directory.
        ``is_file()`` is used so that a same-named directory neither counts
        as a match nor shadows a real file further down the search order.

        Raises:
            FileNotFoundError: No candidate is a regular file.
        """
        path = Path(config_file)

        if path.is_absolute():
            if not path.is_file():
                raise FileNotFoundError(f"配置文件不存在: {path}")
            return path

        # Relative path: look in the config directory first.
        config_path = self.config_dir / path
        if config_path.is_file():
            return config_path

        # Then fall back to the current working directory.
        cwd_path = Path.cwd() / path
        if cwd_path.is_file():
            return cwd_path

        raise FileNotFoundError(
            f"配置文件未找到: {path}\n"
            f"搜索路径:\n"
            f" - {config_path}\n"
            f" - {cwd_path}"
        )

    def _substitute_env_vars(self, config: Any) -> Any:
        """Recursively expand ``${VAR}`` / ``${VAR:default}`` in strings.

        Dicts (values only, not keys) and lists are walked recursively.
        Strings are expanded; any other value is returned unchanged. The
        default is everything after the first ``:``, so it may itself contain
        colons (e.g. URLs).

        Raises:
            ValueError: A referenced variable is unset and has no default.
        """
        if isinstance(config, dict):
            return {
                key: self._substitute_env_vars(value)
                for key, value in config.items()
            }
        if isinstance(config, list):
            return [self._substitute_env_vars(item) for item in config]
        if not isinstance(config, str):
            return config

        def replace_match(match):
            var_name, sep, default = match.group(1).partition(':')
            if not sep:
                # No ':' at all means "no default"; '${VAR:}' means default ''.
                default = None

            value = os.getenv(var_name)
            if value is not None:
                return value
            if default is not None:
                return default
            raise ValueError(
                f"环境变量未设置: {var_name}\n"
                f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}"
            )

        return self._ENV_VAR_PATTERN.sub(replace_match, config)
|
||
|
||
|
||
# ============================================================
|
||
# 快捷函数
|
||
# ============================================================
|
||
|
||
def load_rotation_config(config_file: str = 'config_simple.yaml') -> RotationStrategyConfig:
    """Load the rotation config using a default ``ConfigLoader``.

    Args:
        config_file: Config file name. Relative names are resolved by
            ``ConfigLoader`` (its config directory first, then the cwd).

    Returns:
        RotationStrategyConfig: The validated config object.

    Environment:
        FLASK_API_URL must be set, or the config must supply a default for it
        via ``${FLASK_API_URL:...}``.
    """
    return ConfigLoader().load(config_file)
|
||
|
||
|
||
def get_config_info(config: RotationStrategyConfig) -> dict:
    """Return a flat summary dict of the key config settings.

    Args:
        config: A validated strategy config.

    Returns:
        A dict of display-ready values. Enum fields are reduced to their
        string ``.value``, and a missing ``end_date`` is shown as '至今'
        ("to date").
    """
    return {
        'version': config.metadata.version,
        'strategy': config.metadata.strategy,
        'description': config.metadata.description,
        'asset_count': config.asset_pools.count(),
        'groups': config.asset_pools.groups,
        'factor_type': config.factor.type.value,
        'n_days': config.factor.n_days,
        'select_num': config.rotation.select_num,
        'start_date': config.backtest.start_date,
        'end_date': config.backtest.end_date or '至今',
        'threshold_mode': config.rotation.threshold.mode.value,
        # Position-sizing mode (equal / rank / kelly). It changes capital
        # allocation, so it belongs in the summary alongside select_num.
        'weight_type': config.rotation.weight.value,
    }
|
||
|
||
|
||
def print_config_summary(config: RotationStrategyConfig) -> None:
    """Print a human-readable summary of ``config`` to stdout.

    Most values come from ``get_config_info``. The weighting mode and the
    per-group asset listing are read directly from ``config``.
    """
    info = get_config_info(config)

    print("\n" + "=" * 50)
    print(" ETF 轮动策略配置摘要")
    print("=" * 50)

    print(f"\n版本: {info['version']}")
    print(f"策略: {info['strategy']}")
    print(f"描述: {info['description']}")

    print(f"\n资产池:")
    print(f" 总数: {info['asset_count']}")
    print(f" 分组: {', '.join(info['groups'])}")

    print(f"\n因子配置:")
    print(f" 类型: {info['factor_type']}")
    print(f" 窗口: {info['n_days']} 天")

    print(f"\n轮动配置:")
    print(f" 选股数: {info['select_num']}")
    print(f" 阈值模式: {info['threshold_mode']}")
    # Weighting mode (equal / rank / kelly) materially changes position sizes,
    # so show it next to the other rotation settings.
    print(f" 加权方式: {config.rotation.weight.value}")

    print(f"\n回测配置:")
    print(f" 起始: {info['start_date']}")
    print(f" 结束: {info['end_date']}")

    print(f"\n各分组资产:")
    for group, assets in config.asset_pools.by_group.items():
        print(f"\n [{group}] ({len(assets)} 个):")
        for code, asset in assets.items():
            cross_market = "✓ 跨市场" if asset.is_cross_market else ""
            print(f" - {asset.name} ({code})")
            print(f" 信号: {asset.signal_source} → 交易: {asset.trade_source} {cross_market}")

    print("\n" + "=" * 50)
|
||
|
||
|
||
# ============================================================
|
||
# 测试
|
||
# ============================================================
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test of the config loader: load the default config, print
    # its summary and a few sanity checks. On failure, print the traceback
    # and exit with status 1 so that scripted runs (CI, `&&` chains) notice.

    # Provide a fallback API URL when the environment does not set one.
    if 'FLASK_API_URL' not in os.environ:
        print("提示: FLASK_API_URL 未设置,使用默认值")
        os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'

    try:
        # Load the config.
        print("\n[测试] 加载 config_simple.yaml...")
        config = load_rotation_config()

        # Print the summary.
        print_config_summary(config)

        # Basic sanity checks.
        print("\n[验证] 基本功能测试:")
        print(f" ✓ 配置对象类型: {type(config).__name__}")
        print(f" ✓ 资产池对象类型: {type(config.asset_pools).__name__}")
        print(f" ✓ 资产数量: {config.asset_pools.count()}")
        print(f" ✓ 分组数量: {len(config.asset_pools.groups)}")

        # Data sources.
        print("\n[验证] 数据源配置:")
        for i, source in enumerate(config.data.sources, 1):
            print(f" {i}. {source.type.value}")
            print(f" URL: {source.url}")
            print(f" 启用: {source.enabled}")

        print("\n✓ 所有测试通过!")

    except Exception as e:
        # Top-level boundary: report, then fail the process instead of
        # exiting 0 as if the smoke test had passed.
        print(f"\n✗ 加载失败: {e}")
        import traceback
        traceback.print_exc()
        raise SystemExit(1)