"""
Configuration loader - dedicated to the rotation directory (standalone implementation).

Does not depend on framework_v2; provides a fully independent mechanism for
loading and validating configuration.

Usage example:

    from config_loader import load_rotation_config

    # Load the config (environment variables are expanded automatically)
    config = load_rotation_config('config_simple.yaml')

    # Access config values
    print(f"Strategy version: {config.metadata.version}")
    print(f"Momentum window: {config.factor.n_days}")
    print(f"Asset count: {config.asset_pools.count()}")
"""

import os
import re
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional

import yaml
from pydantic import BaseModel, Field, field_validator


# ============================================================
# Enum types
# ============================================================


class FactorType(str, Enum):
    """Factor type."""

    MOMENTUM = "momentum"
    SLOPE_R2 = "slope_r2"
    WEIGHTED_MOMENTUM = "weighted_momentum"


class PremiumMode(str, Enum):
    """Premium control mode."""

    FILTER = "filter"  # exclude completely
    PENALIZE = "penalize"  # down-weight


class ThresholdMode(str, Enum):
    """Threshold mode."""

    FIXED = "fixed"  # fixed threshold
    DYNAMIC = "dynamic"  # dynamic threshold


class DataSourceType(str, Enum):
    """Data source type."""

    FLASK_API = "flask_api"
    TUSHARE = "tushare"
    YFINANCE = "yfinance"


# ============================================================
# Schema definitions (standalone implementation)
# ============================================================


class PremiumConfig(BaseModel):
    """Premium control settings attached to a single asset (AssetConfig.premium_control)."""

    enabled: bool = Field(default=True, description="是否启用")
    # Validated to lie in [0, 1].
    threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值")


class AssetConfig(BaseModel):
    """Configuration of one asset in the pool."""

    name: str = Field(..., description="标的名称")
    group: str = Field(..., description="策略分组")
    # Code used to compute the signal; may differ from the traded code.
    signal_source: str = Field(..., description="信号来源代码")
    # Code that is actually traded.
    trade_source: str = Field(..., description="交易来源代码")
    description: Optional[str] = Field(None, description="标的描述")
    # Kept for compatibility with older config files.
    etf: Optional[str] = Field(None, description="ETF代码(兼容旧配置)")
    premium_control: Optional[PremiumConfig] = Field(None, description="溢价控制配置")

    @property
    def is_cross_market(self) -> bool:
        """True when the signal code and the traded code differ."""
        return self.signal_source != self.trade_source


class AssetPool(BaseModel):
    """Flat asset pool: every asset keyed by its code, with the group stored on the asset."""

    assets: Dict[str, AssetConfig] = Field(
        default_factory=dict, description="所有标的"
    )

    def _iter_assets(self, group: Optional[str] = None) -> Iterator[AssetConfig]:
        """Iterate assets in pool order, restricted to ``group`` when it is truthy."""
        # A falsy group (None or "") means "no filter" for every public query below.
        if not group:
            return iter(self.assets.values())
        return (asset for asset in self.assets.values() if asset.group == group)

    @property
    def by_group(self) -> Dict[str, Dict[str, AssetConfig]]:
        """Assets grouped by strategy group: ``{group: {code: asset}}``, in first-seen order."""
        grouped: Dict[str, Dict[str, AssetConfig]] = {}
        for code, asset in self.assets.items():
            grouped.setdefault(asset.group, {})[code] = asset
        return grouped

    @property
    def groups(self) -> List[str]:
        """Distinct strategy group names, in first-seen order."""
        return list(dict.fromkeys(asset.group for asset in self.assets.values()))

    def get_signal_codes(self, group: Optional[str] = None) -> List[str]:
        """Signal-source codes, optionally restricted to one group."""
        return [asset.signal_source for asset in self._iter_assets(group)]

    def get_trade_codes(self, group: Optional[str] = None) -> List[str]:
        """Trade-source codes, optionally restricted to one group."""
        return [asset.trade_source for asset in self._iter_assets(group)]

    def get_signal_to_trade_mapping(self, group: Optional[str] = None) -> Dict[str, str]:
        """Map signal code -> trade code, optionally restricted to one group.

        If several assets share a signal code, the last one in pool order wins.
        """
        return {asset.signal_source: asset.trade_source for asset in self._iter_assets(group)}

    def count(self, group: Optional[str] = None) -> int:
        """Number of assets, optionally restricted to one group."""
        if not group:
            return len(self.assets)
        return sum(1 for _ in self._iter_assets(group))


class FactorConfig(BaseModel):
    """Factor settings."""

    type: FactorType = Field(default=FactorType.WEIGHTED_MOMENTUM)
    # Lookback window in days, validated to lie in [5, 250].
    n_days: int = Field(default=25, ge=5, le=250)


class DynamicThresholdConfig(BaseModel):
    """Dynamic threshold settings."""

    reference: str = Field(..., description="阈值参考标的")
    ratio: float = Field(default=1.0, ge=0, description="阈值倍数")
    fallback_enabled: bool = Field(default=True)
    fallback_value: float = Field(default=0.0)


class ThresholdConfig(BaseModel):
    """Threshold settings.

    NOTE(review): the default mode is DYNAMIC while ``dynamic`` defaults to None,
    so consumers must handle a missing ``dynamic`` section.
    """

    mode: ThresholdMode = Field(default=ThresholdMode.DYNAMIC)
    fixed_value: float = Field(default=0.0)
    dynamic: Optional[DynamicThresholdConfig] = Field(None)


class RotationConfig(BaseModel):
    """Rotation settings."""

    # Number of assets to select, validated to lie in [1, 10].
    select_num: int = Field(default=3, ge=1, le=10)
    diversified: bool = Field(default=True)
    threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)


class RebalanceConfig(BaseModel):
    """Rebalancing settings."""

    min_hold_days: int = Field(default=1, ge=1)
    score_threshold: float = Field(default=0.0)
    # Validated to lie in [0, 0.01]; presumably a per-trade cost rate — confirm with consumer.
    trade_cost: float = Field(default=0.001, ge=0, le=0.01)


class PremiumControlConfig(BaseModel):
    """Strategy-level premium control settings (RotationStrategyConfig.premium_control)."""

    enabled: bool = Field(default=False)
    default_threshold: float = Field(default=0.10)
    mode: PremiumMode = Field(default=PremiumMode.FILTER)
    penalty_factor: float = Field(default=0.5)
    # Per-market overrides; the inner key/value schema is not validated here.
    market_overrides: Optional[Dict[str, Dict[str, Any]]] = Field(None)


class DataSourceConfig(BaseModel):
    """One data source entry."""

    type: DataSourceType = Field(...)
    enabled: bool = Field(default=True)
    # Timeout, validated to lie in [10, 600]; presumably seconds — confirm.
    timeout: int = Field(default=120, ge=10, le=600)
    url: Optional[str] = Field(None)


class DataConfig(BaseModel):
    """Data settings: the list of configured data sources."""

    sources: List[DataSourceConfig] = Field(...)

    @field_validator('sources')
    @classmethod
    def check_at_least_one(cls, v):
        """Reject an empty source list: at least one data source is required."""
        if not v:
            raise ValueError("必须配置至少一个数据源")
        return v


class BenchmarkConfig(BaseModel):
    """Benchmark settings."""

    code: str = Field(...)
    name: str = Field(...)


class BacktestConfig(BaseModel):
    """Backtest settings (dates are kept as strings; their format is not validated here)."""

    start_date: str = Field(...)
    end_date: Optional[str] = Field(None)


class MetadataConfig(BaseModel):
    """Config metadata."""

    version: str = Field(default="1.0.0")
    strategy: str = Field(default="rotation")
    description: str = Field(default="")
    last_updated: Optional[str] = Field(None)


class RotationStrategyConfig(BaseModel):
    """Complete ETF rotation strategy configuration (root of the YAML document)."""

    metadata: MetadataConfig = Field(default_factory=MetadataConfig)
    asset_pools: AssetPool = Field(...)
    benchmark: BenchmarkConfig = Field(...)
    backtest: BacktestConfig = Field(...)
    factor: FactorConfig = Field(default_factory=FactorConfig)
    rotation: RotationConfig = Field(default_factory=RotationConfig)
    rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig)
    premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig)
    data: DataConfig = Field(...)
# ============================================================
# Config loader (standalone implementation)
# ============================================================


class ConfigLoader:
    """Load a rotation strategy YAML file, expand env-var placeholders, and validate it.

    Pipeline: resolve the path -> ``yaml.safe_load`` -> replace ``${VAR}`` /
    ``${VAR:default}`` in every string value -> validate into
    ``RotationStrategyConfig``.
    """

    # Matches ``${VAR_NAME}`` or ``${VAR_NAME:default}``; compiled once for all loads.
    _ENV_VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')

    def __init__(self, config_dir: Optional[str] = None):
        """Create a loader.

        Args:
            config_dir: Directory searched first for relative config paths.
                Defaults to the directory containing this module.
        """
        if config_dir is None:
            config_dir = Path(__file__).parent
        self.config_dir = Path(config_dir)

    def load(self, config_file: str) -> RotationStrategyConfig:
        """Load and validate a config file.

        Args:
            config_file: Absolute path, ``~``-prefixed path, or path relative to
                ``config_dir`` / the current working directory.

        Returns:
            The validated ``RotationStrategyConfig``.

        Raises:
            FileNotFoundError: The file cannot be found.
            ValueError: The file is empty, its top level is not a mapping, or a
                referenced environment variable is unset and has no default.
                Pydantic validation errors are also ValueError subclasses.
        """
        config_path = self._resolve_path(config_file)

        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)

        # An empty document parses to None, and a top-level list/scalar is not a
        # mapping; both would otherwise fail in ``**config_dict`` with an opaque TypeError.
        if config_dict is None:
            raise ValueError(f"配置文件为空: {config_path}")
        if not isinstance(config_dict, dict):
            raise ValueError(
                f"配置文件顶层必须是映射(mapping),实际为 "
                f"{type(config_dict).__name__}: {config_path}"
            )

        config_dict = self._substitute_env_vars(config_dict)
        return RotationStrategyConfig(**config_dict)

    def _resolve_path(self, config_file: str) -> Path:
        """Locate the config file.

        Absolute paths are used as-is. Relative paths are tried against
        ``config_dir`` first, then the current working directory. Only regular
        files match, so a same-named directory does not shadow a real file later
        in the search order.

        Raises:
            FileNotFoundError: No candidate is an existing file.
        """
        path = Path(config_file).expanduser()

        if path.is_absolute():
            if not path.is_file():
                raise FileNotFoundError(f"配置文件不存在: {path}")
            return path

        config_path = self.config_dir / path
        cwd_path = Path.cwd() / path
        for candidate in (config_path, cwd_path):
            if candidate.is_file():
                return candidate

        raise FileNotFoundError(
            f"配置文件未找到: {path}\n"
            "搜索路径:\n"
            f" - {config_path}\n"
            f" - {cwd_path}"
        )

    def _substitute_env_vars(self, config: Any) -> Any:
        """Recursively replace ``${VAR}`` / ``${VAR:default}`` in string values.

        Dict keys and non-string scalars are left untouched. Substituted values
        are always strings; pydantic coerces them to the field types afterwards.
        """
        if isinstance(config, dict):
            return {key: self._substitute_env_vars(value) for key, value in config.items()}
        if isinstance(config, list):
            return [self._substitute_env_vars(item) for item in config]
        if isinstance(config, str):
            return self._ENV_VAR_PATTERN.sub(self._replace_env_match, config)
        return config

    @staticmethod
    def _replace_env_match(match: "re.Match[str]") -> str:
        """Return the replacement text for one ``${...}`` placeholder.

        An environment variable that is set (even to an empty string) wins over
        the default. ``${VAR:}`` yields an empty default.

        Raises:
            ValueError: The variable is unset and the placeholder has no default.
        """
        # Split on the first ':' only, so defaults may themselves contain ':' (e.g. URLs).
        var_name, sep, default = match.group(1).partition(':')
        value = os.getenv(var_name)
        if value is not None:
            return value
        if sep:
            return default
        raise ValueError(
            f"环境变量未设置: {var_name}\n"
            f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}"
        )
# ============================================================
# Convenience functions
# ============================================================


def load_rotation_config(
    config_file: str = 'config_simple.yaml',
    config_dir: Optional[str] = None,
) -> RotationStrategyConfig:
    """Load and validate a rotation config.

    Args:
        config_file: Config file name or path. Relative paths are looked up in
            ``config_dir`` first, then in the current working directory.
        config_dir: Directory searched first for relative paths. ``None`` keeps
            the loader's default, which is the directory containing this module.

    Returns:
        RotationStrategyConfig: The validated configuration object.

    Environment variables:
        Every ``${VAR}`` placeholder in the YAML must be set unless it carries a
        default (``${VAR:default}``). The default config references
        FLASK_API_URL.
    """
    return ConfigLoader(config_dir).load(config_file)


def get_config_info(config: RotationStrategyConfig) -> dict:
    """Return a flat summary dict of the most important config values.

    A missing ``end_date`` is reported as '至今' ("to date").
    """
    metadata = config.metadata
    pool = config.asset_pools
    return {
        'version': metadata.version,
        'strategy': metadata.strategy,
        'description': metadata.description,
        'asset_count': pool.count(),
        'groups': pool.groups,
        'factor_type': config.factor.type.value,
        'n_days': config.factor.n_days,
        'select_num': config.rotation.select_num,
        'start_date': config.backtest.start_date,
        'end_date': config.backtest.end_date or '至今',
        'threshold_mode': config.rotation.threshold.mode.value,
    }


def print_config_summary(config: RotationStrategyConfig) -> None:
    """Print a human-readable summary of the config, followed by per-group asset details."""
    info = get_config_info(config)

    print("\n" + "=" * 50)
    print(" ETF 轮动策略配置摘要")
    print("=" * 50)

    print(f"\n版本: {info['version']}")
    print(f"策略: {info['strategy']}")
    print(f"描述: {info['description']}")

    print("\n资产池:")
    print(f" 总数: {info['asset_count']}")
    print(f" 分组: {', '.join(info['groups'])}")

    print("\n因子配置:")
    print(f" 类型: {info['factor_type']}")
    print(f" 窗口: {info['n_days']} 天")

    print("\n轮动配置:")
    print(f" 选股数: {info['select_num']}")
    print(f" 阈值模式: {info['threshold_mode']}")

    print("\n回测配置:")
    print(f" 起始: {info['start_date']}")
    print(f" 结束: {info['end_date']}")

    print("\n各分组资产:")
    for group, assets in config.asset_pools.by_group.items():
        print(f"\n [{group}] ({len(assets)} 个):")
        for code, asset in assets.items():
            # Flag assets whose signal code differs from the traded code.
            cross_market = "✓ 跨市场" if asset.is_cross_market else ""
            print(f" - {asset.name} ({code})")
            print(f" 信号: {asset.signal_source} → 交易: {asset.trade_source} {cross_market}")

    print("\n" + "=" * 50)


# ============================================================
# Self-test
# ============================================================

if __name__ == "__main__":
    # Smoke-test the loader: load the default config, print a summary, and
    # exercise a few accessors. Exit with status 1 on failure so that shell
    # scripts / CI can detect it.
    import sys
    import traceback

    # Fall back to a default FLASK_API_URL when it is not set in the environment.
    if 'FLASK_API_URL' not in os.environ:
        print("提示: FLASK_API_URL 未设置,使用默认值")
        os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'

    try:
        # Load the config.
        print("\n[测试] 加载 config_simple.yaml...")
        config = load_rotation_config()

        # Print the summary.
        print_config_summary(config)

        # Basic accessor checks.
        print("\n[验证] 基本功能测试:")
        print(f" ✓ 配置对象类型: {type(config).__name__}")
        print(f" ✓ 资产池对象类型: {type(config.asset_pools).__name__}")
        print(f" ✓ 资产数量: {config.asset_pools.count()}")
        print(f" ✓ 分组数量: {len(config.asset_pools.groups)}")

        # Data source configuration.
        print("\n[验证] 数据源配置:")
        for i, source in enumerate(config.data.sources, 1):
            print(f" {i}. {source.type.value}")
            print(f" URL: {source.url}")
            print(f" 启用: {source.enabled}")

        print("\n✓ 所有测试通过!")

    except Exception as e:  # top-level script boundary: report, then fail loudly
        print(f"\n✗ 加载失败: {e}")
        traceback.print_exc()
        sys.exit(1)