"""
Data-alignment schema definitions.

Used together with ``CrossMarketAligner`` to provide structural validation of
raw inputs before alignment and of aligned outputs afterwards.
"""

import warnings
from typing import ClassVar, List, Optional

import numpy as np
import pandas as pd
from pydantic import BaseModel, ConfigDict, Field, field_validator


# ============================================================
# Input validation schemas
# ============================================================


class OHLCVInputSchema(BaseModel):
    """
    Validate one raw OHLCV record before alignment.

    ``close`` is required; ``open``/``high``/``low``/``volume`` are optional.
    Unknown keys in the input are ignored.
    """

    # pydantic v2 configuration (replaces the deprecated v1-style inner
    # ``class Config``): extra keys are dropped instead of rejected.
    model_config = ConfigDict(extra="ignore")

    # Required field
    close: float = Field(..., description="收盘价(必需)", gt=0)

    # Optional fields
    open: Optional[float] = Field(None, description="开盘价", gt=0)
    high: Optional[float] = Field(None, description="最高价", gt=0)
    low: Optional[float] = Field(None, description="最低价", gt=0)
    volume: Optional[float] = Field(None, description="成交量", ge=0)

    @field_validator("close", "open", "high", "low")
    @classmethod
    def check_positive(cls, v):
        """Reject non-positive prices (``None`` is allowed for optional fields).

        The ``gt=0`` Field constraints already enforce this before an
        ``after``-mode validator runs; this is kept as a defensive check.
        """
        if v is not None and v <= 0:
            raise ValueError(f"价格必须为正数,当前值: {v}")
        return v


class FactorInputSchema(BaseModel):
    """
    Validate a single factor value.

    Out-of-range values only trigger a warning; they are not rejected.
    """

    # Absolute factor values above this threshold emit a warning.
    # Subclasses may override it to tune sensitivity.
    WARN_ABS_VALUE: ClassVar[float] = 10.0

    value: float = Field(..., description="因子值")
    is_filled: bool = Field(False, description="是否为填充值")

    @field_validator("value")
    @classmethod
    def check_reasonable(cls, v):
        """Warn when ``abs(value)`` exceeds ``WARN_ABS_VALUE`` (default 10)."""
        if abs(v) > cls.WARN_ABS_VALUE:
            warnings.warn(f"因子值异常: {v}")
        return v


# ============================================================
# Output validation schemas
# ============================================================


class AlignedFactorSchema(BaseModel):
    """Validate one row of aligned factor output from ``align_factor()``."""

    # NaN must be accepted because early rows lack enough history. NaN
    # acceptance for float fields is governed by ``allow_inf_nan`` (True by
    # default in pydantic v2, stated explicitly here), not by
    # ``arbitrary_types_allowed``.
    model_config = ConfigDict(allow_inf_nan=True)

    value: float = Field(..., description="对齐后的因子值")
    is_filled: bool = Field(..., description="是否为填充值")


class AlignedReturnsSchema(BaseModel):
    """
    Validate one aligned return value from ``align_returns()``.

    Returns whose magnitude exceeds ``WARN_ABS_RETURN`` only trigger a
    warning; they are not rejected.
    """

    # Absolute return above this threshold (0.5 == 50%) emits a warning.
    WARN_ABS_RETURN: ClassVar[float] = 0.5

    returns: float = Field(..., description="收益率")

    @field_validator("returns")
    @classmethod
    def check_returns_range(cls, v):
        """Warn when ``abs(returns)`` exceeds ``WARN_ABS_RETURN`` (default 50%)."""
        if abs(v) > cls.WARN_ABS_RETURN:
            warnings.warn(f"收益率异常: {v:.2%}")
        return v
# ============================================================
# Batch validation schemas
# ============================================================


class MultiAssetReturnsSchema(BaseModel):
    """
    Validate multi-asset aligned returns (output of ``align_multi_asset()``).

    ``data`` maps asset code -> returns (presumably a ``pd.Series`` per asset;
    verify against caller). The mapping is assembled into a DataFrame, and any
    column containing NaN causes rejection.
    """

    data: dict = Field(..., description="{标的代码: 收益率 Series}")

    @field_validator("data")
    @classmethod
    def check_no_nan(cls, v):
        """Raise ``ValueError`` listing every column of the DataFrame that has NaN."""
        df = pd.DataFrame(v)
        # Compute the per-column NaN mask once and reuse it.
        nan_by_col = df.isna().any()
        if nan_by_col.any():
            nan_cols = df.columns[nan_by_col]
            raise ValueError(f"收益率包含 NaN 列: {list(nan_cols)}")
        return v


class AlignmentValidationResult(BaseModel):
    """Result record produced by ``validate_alignment()``."""

    signals_aligned: bool = Field(..., description="信号是否已对齐")
    returns_aligned: bool = Field(..., description="收益率是否已对齐")
    common_dates_count: int = Field(..., description="共同日期数量")
    lost_signals: int = Field(0, description="丢失的信号数")
    lost_returns: int = Field(0, description="丢失的收益数")

    @field_validator("common_dates_count")
    @classmethod
    def check_min_dates(cls, v):
        """Require at least 10 common dates."""
        if v < 10:
            raise ValueError(f"共同日期太少: {v} 天")
        return v


# ============================================================
# Validation decorators (used with the aligner)
# ============================================================


def validate_ohlcv_before_align(func):
    """
    Validate the close-price Series passed to an aligner method before it runs.

    The Series is read from the first positional argument after ``self``, or
    from the ``close_series`` keyword argument. Decorate only methods whose
    first parameter is the close series, e.g.::

        class CrossMarketAligner:
            @validate_ohlcv_before_align
            def align_returns(self, close_series, code):
                ...

    Values that are not a ``pd.Series`` are passed through unchecked.

    Raises:
        TypeError: if the Series does not have a numeric dtype.
        ValueError: if the Series is empty or entirely NaN.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # close_series is the first argument after self, positional or keyword.
        close_series = args[0] if args else kwargs.get("close_series")

        if isinstance(close_series, pd.Series):
            if not pd.api.types.is_numeric_dtype(close_series):
                raise TypeError(
                    f"close_series 必须是数值类型,当前是 {close_series.dtype}"
                )
            # isna().all() is vacuously True for an empty Series, so an empty
            # input would otherwise be misreported as "all NaN".
            if close_series.empty:
                raise ValueError("close_series 为空")
            if close_series.isna().all():
                raise ValueError("close_series 全为 NaN")

        return func(self, *args, **kwargs)

    return wrapper


def validate_factor_after_align(func):
    """
    Validate the DataFrame returned by an ``align_factor``-style method.

    Example::

        class CrossMarketAligner:
            @validate_factor_after_align
            def align_factor(self, factor_series, source_calendar, code):
                ...

    Raises:
        TypeError: if the result is not a DataFrame, the ``value`` column is
            not numeric, or the ``is_filled`` column is not boolean.
        ValueError: if the ``value`` or ``is_filled`` column is missing.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)

        if not isinstance(result, pd.DataFrame):
            raise TypeError(
                f"align_factor 必须返回 DataFrame,当前返回 {type(result)}"
            )

        required_cols = ["value", "is_filled"]
        missing_cols = [col for col in required_cols if col not in result.columns]
        if missing_cols:
            raise ValueError(f"对齐后 DataFrame 缺少列: {missing_cols}")

        if not pd.api.types.is_numeric_dtype(result["value"]):
            raise TypeError("value 列必须是数值类型")

        if not pd.api.types.is_bool_dtype(result["is_filled"]):
            raise TypeError("is_filled 列必须是布尔类型")

        return result

    return wrapper


def validate_returns_after_align(func):
    """
    Validate the returns Series produced by an ``align_returns``-style method.

    Example::

        class CrossMarketAligner:
            @validate_returns_after_align
            def align_returns(self, close_series, code):
                ...

    Emits a warning (does not raise) when the largest absolute return
    exceeds 0.5 (50%).

    Raises:
        TypeError: if the result is not a Series.
        ValueError: if the Series contains any NaN.
    """
    import warnings
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)

        if not isinstance(result, pd.Series):
            raise TypeError(
                f"align_returns 必须返回 Series,当前返回 {type(result)}"
            )

        nan_count = result.isna().sum()
        if nan_count:
            raise ValueError(f"收益率包含 {nan_count} 个 NaN")

        # max() of an empty Series is NaN and NaN > 0.5 is False: no warning.
        max_return = result.abs().max()
        if max_return > 0.5:
            warnings.warn(f"发现异常收益率: {max_return:.2%}")

        return result

    return wrapper