From 341611c32b337faa7044ba6d4d24fcaa0852279c Mon Sep 17 00:00:00 2001
From: aszerW
Date: Sun, 24 May 2026 14:25:25 +0800
Subject: [PATCH] =?UTF-8?q?feat(framework=5Fv2):=20=E5=AE=9E=E7=8E=B0?=
 =?UTF-8?q?=E9=80=9A=E7=94=A8=E9=85=8D=E7=BD=AE=E7=B3=BB=E7=BB=9F=EF=BC=8C?=
 =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=89=81=E5=B9=B3=E5=8C=96=E8=B5=84=E4=BA=A7?=
 =?UTF-8?q?=E6=B1=A0=E5=92=8C=E7=AD=96=E7=95=A5=E5=88=86=E7=BB=84?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- 使用 Pydantic Schema 验证配置类型安全
- 实现扁平化 AssetPool,移除预设分类(equity/commodity/fixed_income)
- 移除 MarketType 枚举,改用 group 字符串字段实现策略分组
- AssetConfig 引入 signal_source/trade_source 分离,支持跨市场场景
- ConfigLoader 支持通用 StrategyConfig,向后兼容 RotationStrategyConfig
- 新增 GroupConfig 替代 MarketGroupConfig,支持分散化选股

重构核心:
- market → group(策略分组语义,组内竞争强制分散)
- by_market → by_group
- MarketGroupConfig → GroupConfig
---
 framework_v2/config/__init__.py |  74 ++++++
 framework_v2/config/loader.py   | 283 ++++++++++++++++++++
 framework_v2/config/schemas.py  | 455 ++++++++++++++++++++++++++++++++
 3 files changed, 812 insertions(+)
 create mode 100644 framework_v2/config/__init__.py
 create mode 100644 framework_v2/config/loader.py
 create mode 100644 framework_v2/config/schemas.py

diff --git a/framework_v2/config/__init__.py b/framework_v2/config/__init__.py
new file mode 100644
index 0000000..e8111e1
--- /dev/null
+++ b/framework_v2/config/__init__.py
@@ -0,0 +1,74 @@
+"""
+Configuration package.
+
+Provides configuration loading, validation and management.
+"""
+
+from framework_v2.config.schemas import (
+    # Top-level configs
+    RotationStrategyConfig,
+    StrategyConfig,  # generic config
+
+    # Sub-configs
+    AssetPool,
+    AssetConfig,
+    PremiumConfig,
+    GroupConfig,
+    FactorConfig,
+    RotationConfig,
+    ThresholdConfig,
+    DynamicThresholdConfig,
+    RebalanceConfig,
+    MarketPremiumOverride,
+    PremiumControlConfig,
+    DataConfig,
+    DataSourceConfig,
+    BenchmarkConfig,
+    BacktestConfig,
+    MetadataConfig,
+
+    # Enums (MarketType was removed; the string `group` field is used instead)
+    FactorType,
+    PremiumMode,
+    ThresholdMode,
+    DataSourceType,
+)
+
+from framework_v2.config.loader import (
+    ConfigLoader,
+    get_config_loader,
+    load_config,
+)
+
+__all__ = [
+    # Config schemas
+    'RotationStrategyConfig',
+    'StrategyConfig',
+    'AssetPool',
+    'AssetConfig',
+    'PremiumConfig',
+    'GroupConfig',
+    'FactorConfig',
+    'RotationConfig',
+    'ThresholdConfig',
+    'DynamicThresholdConfig',
+    'RebalanceConfig',
+    'MarketPremiumOverride',
+    'PremiumControlConfig',
+    'DataConfig',
+    'DataSourceConfig',
+    'BenchmarkConfig',
+    'BacktestConfig',
+    'MetadataConfig',
+
+    # Enums (MarketType was removed; the string `group` field is used instead)
+    'FactorType',
+    'PremiumMode',
+    'ThresholdMode',
+    'DataSourceType',
+
+    # Loader
+    'ConfigLoader',
+    'get_config_loader',
+    'load_config',
+]
diff --git a/framework_v2/config/loader.py b/framework_v2/config/loader.py
new file mode 100644
index 0000000..772d5e6
--- /dev/null
+++ b/framework_v2/config/loader.py
@@ -0,0 +1,283 @@
+"""
+Configuration loader.
+
+Supports:
+1. Loading YAML configuration files
+2. Pydantic schema validation
+3. Environment-variable substitution
+4. Config merging (schema defaults + user config)
+"""
+
+import os
+import re
+from collections.abc import Mapping
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import yaml
+
+from framework_v2.config.schemas import StrategyConfig, RotationStrategyConfig
+
+# Matches ${VAR_NAME} and ${VAR_NAME:default}; group 1 is the text inside the braces.
+_ENV_VAR_PATTERN = re.compile(r'\$\{([^}]+)\}')
+
+
+class ConfigLoader:
+    """
+    Load YAML configuration files and validate them against the Pydantic schemas.
+
+    Usage:
+        loader = ConfigLoader()
+        config = loader.load('config/rotation.yaml')
+    """
+
+    def __init__(self, config_dir: Optional[str] = None):
+        """
+        Initialise the loader.
+
+        Args:
+            config_dir: Directory searched first for relative config paths.
+                Defaults to the directory that contains this module.
+        """
+        if config_dir is None:
+            # Default config directory: next to this file.
+            config_dir = Path(__file__).parent
+
+        self.config_dir = Path(config_dir)
+
+    def load(
+        self, config_file: str, config_type: str = 'strategy'
+    ) -> Union[StrategyConfig, RotationStrategyConfig]:
+        """
+        Load and validate a YAML configuration file.
+
+        Args:
+            config_file: Path to the config file (relative or absolute).
+            config_type: 'rotation' selects RotationStrategyConfig; any other
+                value selects StrategyConfig.
+
+        Returns:
+            The validated configuration object.
+
+        Raises:
+            FileNotFoundError: The file cannot be found.
+            TypeError: The YAML document is empty or its top level is not a mapping.
+            ValueError: A referenced environment variable is unset and has no default.
+
+        Example:
+            >>> loader = ConfigLoader()
+            >>> config = loader.load('rotation_example.yaml')
+            >>> print(config.factor.n_days)
+            25
+        """
+        config_path = self._resolve_path(config_file)
+
+        with open(config_path, 'r', encoding='utf-8') as f:
+            raw_config = yaml.safe_load(f)
+
+        return self._build(raw_config, config_type, source=str(config_path))
+
+    def load_dict(
+        self, config_dict: Dict[str, Any], config_type: str = 'strategy'
+    ) -> Union[StrategyConfig, RotationStrategyConfig]:
+        """
+        Validate a configuration that is already a dictionary.
+
+        Args:
+            config_dict: Configuration dictionary.
+            config_type: 'rotation' selects RotationStrategyConfig; any other
+                value selects StrategyConfig.
+
+        Returns:
+            The validated configuration object.
+
+        Raises:
+            TypeError: config_dict is not a mapping.
+            ValueError: A referenced environment variable is unset and has no default.
+        """
+        return self._build(config_dict, config_type, source='config_dict')
+
+    def _build(
+        self, raw_config: Any, config_type: str, source: str
+    ) -> Union[StrategyConfig, RotationStrategyConfig]:
+        """
+        Substitute environment variables in raw_config and validate the result.
+
+        Shared by load() and load_dict() so both paths apply the same checks.
+
+        Args:
+            raw_config: Parsed configuration data.
+            config_type: 'rotation' or anything else (see load()).
+            source: Label used in the error message (file path or 'config_dict').
+
+        Raises:
+            TypeError: raw_config is not a mapping.
+        """
+        # yaml.safe_load returns None for an empty document and a list or scalar
+        # when the top level is not a mapping. Unpacking those with ** fails with
+        # an opaque TypeError, so fail early with a message naming the source.
+        if not isinstance(raw_config, Mapping):
+            raise TypeError(
+                f"配置内容必须是字典(mapping): {source} "
+                f"(实际类型: {type(raw_config).__name__})"
+            )
+
+        substituted = self._substitute_env_vars(raw_config)
+
+        # 'rotation' keeps compatibility with the older rotation-only schema;
+        # every other value uses the generic (recommended) schema.
+        schema = RotationStrategyConfig if config_type == 'rotation' else StrategyConfig
+        return schema(**substituted)
+
+    def _resolve_path(self, config_file: str) -> Path:
+        """
+        Locate the configuration file.
+
+        Absolute paths are used as given. Relative paths are looked up in
+        config_dir first and then in the current working directory.
+
+        Args:
+            config_file: Path to the config file.
+
+        Returns:
+            Path of the existing file.
+
+        Raises:
+            FileNotFoundError: No candidate path exists.
+        """
+        path = Path(config_file)
+
+        if path.is_absolute():
+            if not path.exists():
+                raise FileNotFoundError(f"配置文件不存在: {path}")
+            return path
+
+        # Relative path: config directory first, then the working directory.
+        config_path = self.config_dir / path
+        if config_path.exists():
+            return config_path
+
+        cwd_path = Path.cwd() / path
+        if cwd_path.exists():
+            return cwd_path
+
+        raise FileNotFoundError(
+            f"配置文件未找到: {path}\n"
+            f"搜索路径:\n"
+            f" - {config_path}\n"
+            f" - {cwd_path}"
+        )
+
+    def _substitute_env_vars(self, config: Any) -> Any:
+        """
+        Recursively replace environment-variable references in config.
+
+        Supported forms:
+        - ${VAR_NAME}
+        - ${VAR_NAME:default_value}
+
+        Args:
+            config: Config value (dict/list/str); other types are returned as is.
+
+        Returns:
+            The config with references replaced (dicts and lists are rebuilt).
+
+        Raises:
+            ValueError: A referenced variable is unset and has no default.
+        """
+        if isinstance(config, dict):
+            return {
+                key: self._substitute_env_vars(value)
+                for key, value in config.items()
+            }
+        if isinstance(config, list):
+            return [self._substitute_env_vars(item) for item in config]
+        if isinstance(config, str):
+            return _ENV_VAR_PATTERN.sub(self._expand_env_var, config)
+        return config
+
+    @staticmethod
+    def _expand_env_var(match) -> str:
+        """Return the replacement text for one ${...} match."""
+        # Split on the first ':' only, so the default itself may contain ':'
+        # (for example a URL).
+        var_name, separator, default = match.group(1).partition(':')
+
+        value = os.getenv(var_name)
+        if value is not None:
+            return value
+        if separator:
+            # A default was given (possibly the empty string).
+            return default
+
+        raise ValueError(
+            f"环境变量未设置: {var_name}\n"
+            f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}"
+        )
+
+    def get_available_configs(self) -> List[str]:
+        """
+        List the YAML config files in config_dir.
+
+        Returns:
+            Sorted file names matching '*.yaml'; empty if config_dir does not exist.
+        """
+        if not self.config_dir.exists():
+            return []
+
+        # Sorted because directory listing order is OS-dependent.
+        return sorted(
+            f.name
+            for f in self.config_dir.glob('*.yaml')
+            if f.is_file()
+        )
+
+
+# Module-level loader shared by get_config_loader().
+_config_loader: Optional[ConfigLoader] = None
+
+
+def get_config_loader(config_dir: Optional[str] = None) -> ConfigLoader:
+    """
+    Return the shared ConfigLoader, creating it on first use.
+
+    Args:
+        config_dir: Config directory. When given and different from the shared
+            loader's directory, the shared loader is replaced by one for this
+            directory; when None, the existing loader is reused.
+
+    Returns:
+        The ConfigLoader instance.
+    """
+    global _config_loader
+
+    # An explicit config_dir must not be ignored just because a loader already
+    # exists; otherwise the caller silently gets another directory's loader.
+    if _config_loader is None or (
+        config_dir is not None and Path(config_dir) != _config_loader.config_dir
+    ):
+        _config_loader = ConfigLoader(config_dir)
+
+    return _config_loader
+
+
+def load_config(
+    config_file: str, config_type: str = 'strategy'
+) -> Union[StrategyConfig, RotationStrategyConfig]:
+    """
+    Convenience function: load a configuration file with the shared loader.
+
+    Args:
+        config_file: Path to the config file.
+        config_type: 'rotation' selects RotationStrategyConfig; any other
+            value selects StrategyConfig.
+
+    Returns:
+        The validated configuration object.
+
+    Example:
+        >>> from framework_v2.config import load_config
+        >>> config = load_config('rotation_example.yaml')
+    """
+    loader = get_config_loader()
+    return loader.load(config_file, config_type=config_type)
diff --git a/framework_v2/config/schemas.py b/framework_v2/config/schemas.py
new file mode 100644
index 0000000..cbf02c1
--- /dev/null
+++ b/framework_v2/config/schemas.py
@@ -0,0 +1,455 @@
+"""
+Configuration schema definitions.
+
+Uses Pydantic to validate that configuration files are type-safe.
+"""
+
+import warnings
+from datetime import date
+from enum import Enum
+from typing import Optional, Dict, List, Literal
+
+from pydantic import BaseModel, Field, field_validator, model_validator
+
+
+# ============================================================
+# Enums
+# ============================================================
+
+# Note: the MarketType enum is no longer used; assets carry a string `group` field.
+# `group` assigns an asset to a strategy group (assets compete within a group;
+# diversification is enforced).
+
+
+class FactorType(str, Enum):
+    """Factor types."""
+    MOMENTUM = "momentum"
+    SLOPE_R2 = "slope_r2"
+    WEIGHTED_MOMENTUM = "weighted_momentum"
+
+
+class PremiumMode(str, Enum):
+    """Premium-control modes."""
+    FILTER = "filter"  # exclude entirely
+    PENALIZE = "penalize"  # down-weight
+
+
+class ThresholdMode(str, Enum):
+    """Threshold modes."""
+    FIXED = "fixed"  # fixed threshold
+    DYNAMIC = "dynamic"  # dynamic threshold
+
+
+class DataSourceType(str, Enum):
+    """Data source types."""
+    FLASK_API = "flask_api"
+    TUSHARE = "tushare"
+    YFINANCE = "yfinance"
+
+
+# ============================================================
+# Asset pool schema (flat design)
+# ============================================================
+
+class PremiumConfig(BaseModel):
+    """Premium-control settings (used per asset via AssetConfig.premium_control)."""
+    enabled: bool = Field(default=True, description="是否启用")
+    threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值")
+
+
+class AssetConfig(BaseModel):
+    """
+    Asset configuration (generic; covers all scenarios).
+
+    Scenario 1: index signal -> index return
+        signal_source: "NDX"
+        trade_source: "NDX"
+
+    Scenario 2: index signal -> ETF return
+        signal_source: "NDX"
+        trade_source: "513100.SH"
+
+    Scenario 3: ETF signal -> ETF return
+        signal_source: "518880.SH"
+        trade_source: "518880.SH"
+
+    Scenario 4: stock signal -> stock return
+        signal_source: "AAPL"
+        trade_source: "AAPL"
+    """
+    name: str = Field(..., description="标的名称")
+    group: str = Field(..., description="策略分组(组内竞争,强制分散)")
+
+    # Core fields: where the signal comes from and what is actually traded.
+    signal_source: str = Field(..., description="信号来源代码(计算因子用)")
+    trade_source: str = Field(..., description="交易来源代码(计算收益用)")
+
+    # Optional fields
+    description: Optional[str] = Field(None, description="标的描述")
+    etf: Optional[str] = Field(None, description="ETF代码(兼容旧配置)")
+    premium_control: Optional["PremiumConfig"] = Field(None, description="溢价控制配置")
+
+    @property
+    def is_cross_market(self) -> bool:
+        """Whether the signal and trade instruments differ."""
+        return self.signal_source != self.trade_source
+
+
+class AssetPool(BaseModel):
+    """
+    Asset pool (flat design).
+
+    Advantages:
+    1. No preset categories; any strategy grouping is supported
+    2. Assets are grouped automatically through their `group` field
+    3. Configuration is simple and direct
+    4. New groups are easy to add
+
+    Example:
+        asset_pools:
+            assets:
+                "NDX":
+                    name: "Nasdaq 100"
+                    group: "US_TECH"
+                    signal_source: "NDX"
+                    trade_source: "513100.SH"
+
+                "BTC":
+                    name: "Bitcoin"
+                    group: "CRYPTO"
+                    signal_source: "BTC"
+                    trade_source: "BTC"
+    """
+    assets: Dict[str, AssetConfig] = Field(
+        default_factory=dict,
+        description="所有标的(flat结构)"
+    )
+
+    def _select(self, group: Optional[str] = None) -> List[AssetConfig]:
+        """
+        Return the assets in `group`, or all assets when `group` is falsy.
+
+        Single place for the group filter used by the public accessors below.
+        """
+        assets = list(self.assets.values())
+        if group:
+            return [asset for asset in assets if asset.group == group]
+        return assets
+
+    @property
+    def by_group(self) -> Dict[str, Dict[str, AssetConfig]]:
+        """Assets keyed by group name, then by asset code."""
+        grouped: Dict[str, Dict[str, AssetConfig]] = {}
+        for code, asset in self.assets.items():
+            grouped.setdefault(asset.group, {})[code] = asset
+        return grouped
+
+    @property
+    def groups(self) -> List[str]:
+        """All group names, in order of first appearance in `assets`."""
+        return list(self.by_group)
+
+    def get_signal_codes(self, group: Optional[str] = None) -> List[str]:
+        """
+        Return the signal-source codes.
+
+        Args:
+            group: Optional group filter (e.g. 'US_TECH').
+        """
+        return [asset.signal_source for asset in self._select(group)]
+
+    def get_trade_codes(self, group: Optional[str] = None) -> List[str]:
+        """Return the trade-source codes, optionally limited to one group."""
+        return [asset.trade_source for asset in self._select(group)]
+
+    def get_signal_to_trade_mapping(self, group: Optional[str] = None) -> Dict[str, str]:
+        """
+        Return a signal_source -> trade_source mapping.
+
+        If several assets share a signal_source, the last one in `assets` wins.
+        """
+        return {
+            asset.signal_source: asset.trade_source
+            for asset in self._select(group)
+        }
+
+    def count(self, group: Optional[str] = None) -> int:
+        """Return the number of assets, optionally limited to one group."""
+        return len(self._select(group))
+
+    @field_validator('assets')
+    @classmethod
+    def check_non_empty(cls, v):
+        """Warn (without failing) when the asset pool is empty."""
+        if not v:
+            warnings.warn("资产池为空")
+        return v
+
+
+# ============================================================
+# Factor schema
+# ============================================================
+
+class FactorConfig(BaseModel):
+    """Factor configuration."""
+    type: FactorType = Field(default=FactorType.WEIGHTED_MOMENTUM, description="因子类型")
+    n_days: int = Field(default=25, ge=5, le=250, description="动量窗口期(天数)")
+
+    # Dynamic-period parameters
+    auto_day: bool = Field(default=False, description="是否启用动态周期")
+    min_days: int = Field(default=20, ge=5, le=100, description="最小周期")
+    max_days: int = Field(default=60, ge=20, le=250, description="最大周期")
+
+    @model_validator(mode='after')
+    def validate_day_range(self):
+        """Reject an inverted range; the field bounds alone allow min_days > max_days."""
+        if self.min_days > self.max_days:
+            raise ValueError(
+                f"min_days ({self.min_days}) 不能大于 max_days ({self.max_days})"
+            )
+        return self
+
+
+# ============================================================
+# Threshold schema
+# ============================================================
+
+class DynamicThresholdConfig(BaseModel):
+    """Dynamic threshold configuration."""
+    reference: str = Field(..., description="参考标的代码")
+    ratio: float = Field(default=1.0, gt=0, description="倍数")
+    fallback_enabled: bool = Field(default=True, description="参考不可用时是否回退")
+    fallback_value: float = Field(default=0.0, description="回退值")
+
+
+class ThresholdConfig(BaseModel):
+    """Threshold configuration (unified across V2/V3)."""
+    mode: ThresholdMode = Field(default=ThresholdMode.DYNAMIC, description="阈值模式")
+    fixed_value: float = Field(default=0.0, description="固定阈值(mode=fixed时使用)")
+    dynamic: Optional[DynamicThresholdConfig] = Field(None, description="动态阈值(mode=dynamic时使用)")
+
+    @model_validator(mode='after')
+    def validate_dynamic_config(self):
+        """Require the `dynamic` section when mode is dynamic."""
+        if self.mode == ThresholdMode.DYNAMIC and self.dynamic is None:
+            raise ValueError("dynamic 模式必须配置 dynamic 字段")
+        return self
+
+
+# ============================================================
+# Rotation schema
+# ============================================================
+
+class GroupConfig(BaseModel):
+    """Strategy-group configuration (used for diversified selection)."""
+    group: str = Field(..., description="策略分组名称")
+    select_num: int = Field(default=1, ge=1, le=10, description="该分组选股数量")
+
+
+class RotationConfig(BaseModel):
+    """
+    Rotation configuration.
+
+    Two selection modes are supported:
+    1. Global selection: diversified=false, pick the Top-N directly
+    2. Diversified selection: diversified=true, pick per strategy group
+    """
+    select_num: int = Field(default=3, ge=1, le=20, description="选股数量(全局模式)")
+    diversified: bool = Field(default=False, description="是否分散化选股")
+
+    # Diversification settings (optional)
+    diversification_groups: Optional[List[GroupConfig]] = Field(
+        None,
+        description="分散化分组配置(diversified=true时使用)"
+    )
+
+    # The default must be an explicit fixed-mode threshold: a bare ThresholdConfig()
+    # defaults to dynamic mode without a `dynamic` section and fails its own
+    # validator, which made RotationConfig() unusable without a threshold.
+    threshold: ThresholdConfig = Field(
+        default_factory=lambda: ThresholdConfig(mode=ThresholdMode.FIXED),
+        description="动量阈值"
+    )
+
+
+# ============================================================
+# Rebalance schema
+# ============================================================
+
+class RebalanceConfig(BaseModel):
+    """Rebalance configuration."""
+    min_hold_days: int = Field(default=1, ge=1, le=30, description="最低持有天数")
+    score_threshold: float = Field(default=0.0, ge=0, le=0.5, description="调仓得分阈值(%)")
+    trade_cost: float = Field(default=0.001, ge=0, le=0.01, description="单次换仓成本")
+
+
+# ============================================================
+# Premium control schema
+# ============================================================
+
+class MarketPremiumOverride(BaseModel):
+    """Per-market premium override configuration."""
+    enabled: bool = Field(default=True, description="是否启用")
+    threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值")
+
+
+class PremiumControlConfig(BaseModel):
+    """Premium-control configuration."""
+    enabled: bool = Field(default=True, description="是否启用溢价控制")
+    default_threshold: float = Field(default=0.10, ge=0, le=1.0, description="默认溢价阈值")
+    mode: PremiumMode = Field(default=PremiumMode.FILTER, description="控制模式")
+    penalty_factor: float = Field(default=0.5, ge=0, le=1.0, description="降权惩罚系数")
+
+    # Per-market overrides
+    market_overrides: Dict[str, MarketPremiumOverride] = Field(
+        default_factory=dict,
+        description="按市场覆盖配置"
+    )
+
+
+# ============================================================
+# Data source schema
+# ============================================================
+
+class DataSourceConfig(BaseModel):
+    """Configuration of a single data source."""
+    type: DataSourceType = Field(..., description="数据源类型")
+    enabled: bool = Field(default=True, description="是否启用")
+    timeout: int = Field(default=120, ge=10, le=600, description="超时时间(秒)")
+
+    # Flask API specific
+    url: Optional[str] = Field(None, description="API地址(flask_api使用)")
+
+    # Tushare specific
+    token: Optional[str] = Field(None, description="Tushare Token(tushare使用)")
+
+
+class DataConfig(BaseModel):
+    """Data configuration."""
+    sources: List[DataSourceConfig] = Field(..., description="数据源列表(按优先级排序)")
+    use_cache: bool = Field(default=True, description="是否使用本地缓存")
+    cache_dir: str = Field(default="data_cache", description="缓存目录")
+
+    @field_validator('sources')
+    @classmethod
+    def check_at_least_one(cls, v):
+        """Require at least one data source."""
+        if not v:
+            raise ValueError("必须配置至少一个数据源")
+        return v
+
+
+# ============================================================
+# Benchmark schema
+# ============================================================
+
+class BenchmarkConfig(BaseModel):
+    """Benchmark configuration."""
+    code: str = Field(..., description="基准代码")
+    name: str = Field(..., description="基准名称")
+
+
+# ============================================================
+# Backtest schema
+# ============================================================
+
+class BacktestConfig(BaseModel):
+    """Backtest configuration."""
+    # NOTE(review): dates are plain strings and their format is not validated here.
+    start_date: str = Field(..., description="回测起始日期(YYYY-MM-DD)")
+    end_date: Optional[str] = Field(None, description="回测结束日期(None表示至今)")
+
+
+# ============================================================
+# Metadata schema
+# ============================================================
+
+class MetadataConfig(BaseModel):
+    """Configuration metadata."""
+    version: str = Field(default="1.0.0", description="配置版本")
+    strategy: str = Field(default="rotation", description="策略名称")
+    description: str = Field(default="", description="配置描述")
+    last_updated: Optional[str] = Field(None, description="最后更新日期")
+
+
+# ============================================================
+# Full strategy schema
+# ============================================================
+
+class RotationStrategyConfig(BaseModel):
+    """
+    Complete configuration for the ETF rotation strategy.
+
+    Example:
+        from framework_v2.config.schemas import RotationStrategyConfig
+        import yaml
+
+        with open('config/rotation.yaml') as f:
+            config_dict = yaml.safe_load(f)
+
+        config = RotationStrategyConfig(**config_dict)
+    """
+    metadata: MetadataConfig = Field(default_factory=MetadataConfig, description="元数据")
+
+    # Asset pool
+    asset_pools: AssetPool = Field(..., description="资产池配置")
+
+    # Benchmark
+    benchmark: BenchmarkConfig = Field(..., description="基准配置")
+
+    # Backtest
+    backtest: BacktestConfig = Field(..., description="回测配置")
+
+    # Factor
+    factor: FactorConfig = Field(default_factory=FactorConfig, description="因子配置")
+
+    # Rotation
+    rotation: RotationConfig = Field(default_factory=RotationConfig, description="轮动配置")
+
+    # Rebalance
+    rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig, description="调仓配置")
+
+    # Premium control
+    premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig, description="溢价控制")
+
+    # Data
+    data: DataConfig = Field(..., description="数据配置")
+
+
+# ============================================================
+# Generic strategy schema (V2 architecture)
+# ============================================================
+
+class StrategyConfig(BaseModel):
+    """
+    Generic strategy configuration (supports all strategy types).
+
+    Example:
+        from framework_v2.config import load_config
+        config = load_config('rotation.yaml')
+    """
+    metadata: MetadataConfig = Field(default_factory=MetadataConfig, description="元数据")
+
+    # Asset pool
+    asset_pools: AssetPool = Field(..., description="资产池配置")
+
+    # Benchmark
+    benchmark: BenchmarkConfig = Field(..., description="基准配置")
+
+    # Backtest
+    backtest: BacktestConfig = Field(..., description="回测配置")
+
+    # Factor
+    factor: FactorConfig = Field(default_factory=FactorConfig, description="因子配置")
+
+    # Rotation (optional)
+    rotation: Optional[RotationConfig] = Field(None, description="轮动配置")
+
+    # Rebalance
+    rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig, description="调仓配置")
+
+    # Premium control
+    premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig, description="溢价控制")
+
+    # Data
+    data: DataConfig = Field(..., description="数据配置")