## 核心功能
- CrossMarketAligner: 跨市场数据对齐(解决 ffill 陷阱)
- Pydantic Schema: 数据结构验证(OHLCVInputSchema, AlignedFactorSchema 等)
- 验证装饰器: @validate_factor_after_align, @validate_returns_after_align

## 解决的问题
- 跨市场交易日历不同(美股/港股/A股)
- ffill 收益率陷阱(休市日复制非零收益率)
- NaN 传播问题
- 日期不一致问题

## 测试验证
- 5/5 测试通过(因子对齐、收益率对齐、多标的对齐、信号验证、ffill 陷阱)
- 休市日收益率 = 0%(正确)
- 无 NaN 传播

## 架构设计
- shared/data/alignment.py - 对齐器实现
- shared/data/schemas.py - Pydantic Schema 定义
- tests/test_alignment.py - 完整测试套件
259 lines
7.6 KiB
Python
259 lines
7.6 KiB
Python
"""
|
||
数据对齐 Schema 定义
|
||
|
||
与 CrossMarketAligner 配合使用,提供结构验证
|
||
"""
|
||
|
||
from pydantic import BaseModel, Field, field_validator
|
||
from typing import Optional, List
|
||
import pandas as pd
|
||
import numpy as np
|
||
|
||
|
||
# ============================================================
|
||
# 输入验证 Schema
|
||
# ============================================================
|
||
|
||
class OHLCVInputSchema(BaseModel):
    """
    Validate a single raw OHLCV record before alignment.

    Only ``close`` is required. ``open``/``high``/``low`` are optional
    strictly-positive prices and ``volume`` is an optional non-negative
    number. Unknown keys are ignored rather than rejected.
    """

    # pydantic v2 configuration; replaces the deprecated inner ``class Config``
    # (which triggers PydanticDeprecatedSince20 when combined with v2 APIs
    # such as ``field_validator``).
    model_config = {"extra": "ignore"}

    # Required field
    close: float = Field(..., description="收盘价(必需)", gt=0)

    # Optional fields
    open: Optional[float] = Field(None, description="开盘价", gt=0)
    high: Optional[float] = Field(None, description="最高价", gt=0)
    low: Optional[float] = Field(None, description="最低价", gt=0)
    volume: Optional[float] = Field(None, description="成交量", ge=0)

    @field_validator('close', 'open', 'high', 'low')
    @classmethod
    def check_positive(cls, v):
        """
        Reject non-positive prices.

        NOTE: the ``gt=0`` field constraints already reject such values before
        this (after-mode) validator runs, so this check is a defensive
        duplicate kept for backward compatibility with code that references it.
        """
        if v is not None and v <= 0:
            raise ValueError(f"价格必须为正数,当前值: {v}")
        return v
|
||
|
||
|
||
class FactorInputSchema(BaseModel):
    """
    Raw factor value to be validated before alignment.

    Values whose magnitude exceeds 10 are flagged with a warning but are
    never rejected; the validator always returns the value unchanged.
    """

    value: float = Field(..., description="因子值")
    is_filled: bool = Field(False, description="是否为填充值")

    @field_validator('value')
    @classmethod
    def check_reasonable(cls, v):
        """Emit a warning when |value| > 10, then pass the value through."""
        import warnings

        magnitude = abs(v)
        if magnitude > 10:
            warnings.warn(f"因子值异常: {v}")
        return v
|
||
|
||
|
||
# ============================================================
|
||
# 输出验证 Schema
|
||
# ============================================================
|
||
|
||
class AlignedFactorSchema(BaseModel):
    """
    Validate one row of the output of ``align_factor()``.

    ``value`` is the aligned factor value and ``is_filled`` marks rows whose
    value was filled rather than observed.
    """

    # NaN must be accepted for ``value`` (early rows may not have enough
    # history). ``allow_inf_nan`` is the pydantic setting that governs this;
    # ``arbitrary_types_allowed`` has no effect on plain ``float`` fields.
    # True is also pydantic's default, but it is stated here explicitly.
    model_config = {"allow_inf_nan": True}

    value: float = Field(..., description="对齐后的因子值")
    is_filled: bool = Field(..., description="是否为填充值")
|
||
|
||
|
||
class AlignedReturnsSchema(BaseModel):
    """
    Validate one value of the output of ``align_returns()``.

    Returns must be finite: NaN and +/-inf are rejected, matching the
    no-NaN checks applied to aligned returns elsewhere in this module.
    Magnitudes above 50% are accepted but trigger a warning.
    """

    # allow_inf_nan=False rejects NaN/inf at field-validation time; the old
    # ``arbitrary_types_allowed`` config had no effect on a float field.
    returns: float = Field(..., description="收益率", allow_inf_nan=False)

    @field_validator('returns')
    @classmethod
    def check_returns_range(cls, v):
        """Warn (without failing) when |returns| exceeds 50%."""
        if abs(v) > 0.5:
            import warnings

            warnings.warn(f"收益率异常: {v:.2%}")
        return v
|
||
|
||
|
||
# ============================================================
|
||
# 批量验证 Schema
|
||
# ============================================================
|
||
|
||
class MultiAssetReturnsSchema(BaseModel):
    """
    Container for the per-ticker returns produced by ``align_multi_asset()``.

    ``data`` maps a ticker code to its returns (presumably a ``pd.Series``;
    not enforced by the field type). The mapping is combined column-wise with
    ``pd.DataFrame`` and validation fails if any cell of that frame is NaN.
    """

    data: dict = Field(..., description="{标的代码: 收益率 Series}")

    @field_validator('data')
    @classmethod
    def check_no_nan(cls, v):
        """Reject the mapping if the combined frame has NaN; name the offending columns."""
        nan_mask = pd.DataFrame(v).isna()
        offending = nan_mask.columns[nan_mask.any()]
        if len(offending) > 0:
            raise ValueError(f"收益率包含 NaN 列: {list(offending)}")
        return v
|
||
|
||
|
||
class AlignmentValidationResult(BaseModel):
    """
    Summary record produced by ``validate_alignment()``.

    Construction fails when fewer than 10 common dates are reported.
    """

    signals_aligned: bool = Field(..., description="信号是否已对齐")
    returns_aligned: bool = Field(..., description="收益率是否已对齐")
    common_dates_count: int = Field(..., description="共同日期数量")
    lost_signals: int = Field(0, description="丢失的信号数")
    lost_returns: int = Field(0, description="丢失的收益数")

    @field_validator('common_dates_count')
    @classmethod
    def check_min_dates(cls, v):
        """Require at least 10 common dates."""
        if v >= 10:
            return v
        raise ValueError(f"共同日期太少: {v} 天")
|
||
|
||
|
||
# ============================================================
|
||
# 验证装饰器(与 Aligner 配合)
|
||
# ============================================================
|
||
|
||
def validate_ohlcv_before_align(func):
    """
    Validate the close-price series passed to an aligner method before it runs.

    The series is taken from the first positional argument after ``self`` or,
    if there are no positional arguments, from the ``close_series`` keyword.
    Only ``pd.Series`` inputs are checked; anything else is passed through
    unchanged so the wrapped method decides how to handle it.

    Example::

        class CrossMarketAligner:
            @validate_ohlcv_before_align
            def align_returns(self, close_series, code):
                ...

    Raises:
        TypeError: if the series does not have a numeric dtype.
        ValueError: if the series is empty or contains only NaN.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # ``args`` excludes ``self``, so args[0] is the method's first real
        # parameter (the close series).
        close_series = args[0] if args else kwargs.get('close_series')

        if isinstance(close_series, pd.Series):
            if not pd.api.types.is_numeric_dtype(close_series):
                raise TypeError(
                    f"close_series 必须是数值类型,当前是 {close_series.dtype}"
                )
            # Checked before the all-NaN test: ``Series([]).isna().all()`` is
            # vacuously True and would otherwise be misreported as "all NaN".
            if close_series.empty:
                raise ValueError("close_series 为空")
            if close_series.isna().all():
                raise ValueError("close_series 全为 NaN")

        return func(self, *args, **kwargs)

    return wrapper
|
||
|
||
|
||
def validate_factor_after_align(func):
    """
    Validate the DataFrame returned by an ``align_factor``-style method.

    Example::

        class CrossMarketAligner:
            @validate_factor_after_align
            def align_factor(self, factor_series, source_calendar, code):
                ...

    The result is returned unchanged when valid.

    Raises:
        TypeError: if the result is not a DataFrame, ``value`` is not numeric,
            or ``is_filled`` is not boolean.
        ValueError: if the ``value`` or ``is_filled`` column is missing.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)

        # Return type
        if not isinstance(result, pd.DataFrame):
            raise TypeError(
                f"align_factor 必须返回 DataFrame,当前返回 {type(result)}"
            )

        # Required columns
        required_cols = ['value', 'is_filled']
        missing_cols = [col for col in required_cols if col not in result.columns]
        if missing_cols:
            raise ValueError(f"对齐后 DataFrame 缺少列: {missing_cols}")

        # Column dtypes; messages report the actual dtype, consistent with
        # validate_ohlcv_before_align.
        value_col = result['value']
        if not pd.api.types.is_numeric_dtype(value_col):
            raise TypeError(f"value 列必须是数值类型,当前是 {value_col.dtype}")

        filled_col = result['is_filled']
        if not pd.api.types.is_bool_dtype(filled_col):
            raise TypeError(f"is_filled 列必须是布尔类型,当前是 {filled_col.dtype}")

        return result

    return wrapper
|
||
|
||
|
||
def validate_returns_after_align(func):
    """
    Validate the Series returned by an ``align_returns``-style method.

    Example::

        class CrossMarketAligner:
            @validate_returns_after_align
            def align_returns(self, close_series, code):
                ...

    The result is returned unchanged when valid. Returns whose magnitude
    exceeds 50% produce a warning attributed to the caller.

    Raises:
        TypeError: if the result is not a Series.
        ValueError: if the result contains NaN or +/-inf values.
    """
    from functools import wraps
    import warnings

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)

        # Return type
        if not isinstance(result, pd.Series):
            raise TypeError(
                f"align_returns 必须返回 Series,当前返回 {type(result)}"
            )

        # No NaN allowed
        nan_count = int(result.isna().sum())
        if nan_count:
            raise ValueError(f"收益率包含 {nan_count} 个 NaN")

        # Infinite returns (e.g. from a zero price) poison downstream
        # statistics just like NaN, so they are rejected rather than warned.
        abs_returns = result.abs()
        inf_count = int(np.isinf(abs_returns).sum())
        if inf_count:
            raise ValueError(f"收益率包含 {inf_count} 个 inf")

        # Range check: warn only. On an empty Series max() is NaN and the
        # comparison is False, so no warning is emitted.
        max_return = abs_returns.max()
        if max_return > 0.5:
            # stacklevel=2 attributes the warning to the caller of the wrapped
            # method instead of this module.
            warnings.warn(
                f"发现异常收益率: {max_return:.2%}", UserWarning, stacklevel=2
            )

        return result

    return wrapper
|