Files
etf/framework_v2/shared/data/schemas.py
aszerW a16681bda9 feat(framework_v2): 添加跨市场数据对齐器 + Pydantic Schema 验证
## 核心功能
- CrossMarketAligner: 跨市场数据对齐(解决 ffill 陷阱)
- Pydantic Schema: 数据结构验证(OHLCVInputSchema, AlignedFactorSchema 等)
- 验证装饰器: @validate_factor_after_align, @validate_returns_after_align

## 解决的问题
- 跨市场交易日历不同(美股/港股/A股)
- ffill 收益率陷阱(休市日复制非零收益率)
- NaN 传播问题
- 日期不一致问题

## 测试验证
- 5/5 测试通过(因子对齐、收益率对齐、多标的对齐、信号验证、ffill陷阱)
- 休市日收益率 = 0%(正确)
- 无 NaN 传播

## 架构设计
- shared/data/alignment.py - 对齐器实现
- shared/data/schemas.py - Pydantic Schema 定义
- tests/test_alignment.py - 完整测试套件
2026-05-24 10:28:35 +08:00

259 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据对齐 Schema 定义
与 CrossMarketAligner 配合使用,提供结构验证
"""
from typing import List, Optional

import numpy as np
import pandas as pd
from pydantic import BaseModel, ConfigDict, Field, field_validator
# ============================================================
# 输入验证 Schema
# ============================================================
class OHLCVInputSchema(BaseModel):
    """Validate one raw OHLCV bar before it is aligned.

    Only ``close`` is required; every other field may be omitted or ``None``.
    Keys that are not declared here are dropped (``extra="ignore"``).

    Attributes:
        close: Closing price, strictly positive.
        open: Opening price, strictly positive when given.
        high: High price, strictly positive when given.
        low: Low price, strictly positive when given.
        volume: Traded volume, non-negative when given.
    """

    # Pydantic v2 configuration. The nested ``class Config`` form is
    # deprecated in v2, and this module already relies on the v2-only
    # ``field_validator`` API.
    model_config = ConfigDict(extra="ignore")

    # Required field
    close: float = Field(..., description="收盘价(必需)", gt=0)
    # Optional fields
    open: Optional[float] = Field(None, description="开盘价", gt=0)
    high: Optional[float] = Field(None, description="最高价", gt=0)
    low: Optional[float] = Field(None, description="最低价", gt=0)
    volume: Optional[float] = Field(None, description="成交量", ge=0)

    @field_validator('close', 'open', 'high', 'low')
    @classmethod
    def check_positive(cls, v):
        """Reject non-positive prices; ``None`` is let through.

        This duplicates the ``gt=0`` constraints declared on the fields and
        is kept as a defensive second check.

        Raises:
            ValueError: If ``v`` is not ``None`` and ``v <= 0``.
        """
        if v is not None and v <= 0:
            raise ValueError(f"价格必须为正数,当前值: {v}")
        return v
class FactorInputSchema(BaseModel):
    """Validate a single factor observation.

    Attributes:
        value: The factor value.
        is_filled: Whether the value was produced by filling; defaults to
            ``False``.
    """

    value: float = Field(..., description="因子值")
    is_filled: bool = Field(False, description="是否为填充值")

    @field_validator('value')
    @classmethod
    def check_reasonable(cls, v):
        """Warn when the factor value lies outside the range -10 to 10.

        The value is never rejected or changed; an out-of-range value only
        triggers a warning.
        """
        import warnings

        # Written as "> 10" on purpose: NaN compares False and stays silent.
        looks_abnormal = abs(v) > 10
        if looks_abnormal:
            warnings.warn(f"因子值异常: {v}")
        return v
# ============================================================
# 输出验证 Schema
# ============================================================
class AlignedFactorSchema(BaseModel):
    """Validate one row of the output of ``align_factor()``.

    Attributes:
        value: The aligned factor value. NaN is accepted, because early rows
            may not have enough history to compute the factor.
        is_filled: Whether the value was produced by filling.
    """

    # Pydantic v2 configuration (the nested ``class Config`` form is
    # deprecated in v2).
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        # What actually permits NaN in ``value`` is ``allow_inf_nan``, not
        # ``arbitrary_types_allowed``; spell it out so the intent is visible.
        allow_inf_nan=True,
    )

    value: float = Field(..., description="对齐后的因子值")
    is_filled: bool = Field(..., description="是否为填充值")
class AlignedReturnsSchema(BaseModel):
    """Validate one value of the output of ``align_returns()``.

    Attributes:
        returns: The return for one period, as a fraction (0.01 == 1%).
    """

    # Pydantic v2 configuration (the nested ``class Config`` form is
    # deprecated in v2).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    returns: float = Field(..., description="收益率")

    @field_validator('returns')
    @classmethod
    def check_returns_range(cls, v):
        """Warn when the return lies outside the range -50% to 50%.

        The value is never rejected or changed; an out-of-range value only
        triggers a warning.
        """
        if abs(v) > 0.5:
            import warnings
            warnings.warn(f"收益率异常: {v:.2%}")
        return v
# ============================================================
# 批量验证 Schema
# ============================================================
class MultiAssetReturnsSchema(BaseModel):
    """Validate the multi-asset output of ``align_multi_asset()``.

    Attributes:
        data: Mapping of asset code to its return Series.
    """

    data: dict = Field(..., description="{标的代码: 收益率 Series}")

    @field_validator('data')
    @classmethod
    def check_no_nan(cls, v):
        """Require that the returns, viewed as one DataFrame, contain no NaN.

        Raises:
            ValueError: If any column holds at least one NaN; the message
                lists the offending columns.
        """
        frame = pd.DataFrame(v)
        # Per-column flag: does this column contain any NaN?
        column_has_nan = frame.isna().any()
        if not column_has_nan.any():
            return v
        offending = frame.columns[column_has_nan]
        raise ValueError(f"收益率包含 NaN 列: {list(offending)}")
class AlignmentValidationResult(BaseModel):
    """Result record produced by ``validate_alignment()``.

    Attributes:
        signals_aligned: Whether the signals are aligned.
        returns_aligned: Whether the returns are aligned.
        common_dates_count: Number of dates shared by both sides.
        lost_signals: Number of signals lost; defaults to 0.
        lost_returns: Number of returns lost; defaults to 0.
    """

    signals_aligned: bool = Field(..., description="信号是否已对齐")
    returns_aligned: bool = Field(..., description="收益率是否已对齐")
    common_dates_count: int = Field(..., description="共同日期数量")
    lost_signals: int = Field(0, description="丢失的信号数")
    lost_returns: int = Field(0, description="丢失的收益数")

    @field_validator('common_dates_count')
    @classmethod
    def check_min_dates(cls, v):
        """Require at least 10 common dates.

        Raises:
            ValueError: If fewer than 10 dates are shared.
        """
        if v >= 10:
            return v
        raise ValueError(f"共同日期太少: {v}")
# ============================================================
# 验证装饰器(与 Aligner 配合)
# ============================================================
def validate_ohlcv_before_align(func):
    """Decorate an aligner method so its close-price Series is checked first.

    The wrapper inspects the first positional argument after ``self`` or,
    when no positional argument is given, the keyword ``close_series``.
    A value that is not a ``pd.Series`` is passed through unchecked.

    Example:
        class CrossMarketAligner:
            @validate_ohlcv_before_align
            def align_returns(self, close_series, code):
                ...

    Args:
        func: The method to wrap; it is called with the original arguments
            once validation passes.

    Returns:
        The wrapped method.

    Raises:
        TypeError: (from the wrapper) if the Series is not numeric, or is
            boolean.
        ValueError: (from the wrapper) if every entry of the Series is NaN.
            An empty Series also lands here, because ``all()`` over no
            elements is True.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # First argument after ``self``, falling back to the keyword form.
        close_series = args[0] if args else kwargs.get('close_series')

        if isinstance(close_series, pd.Series):
            dtype = close_series.dtype
            # is_numeric_dtype() reports True for bool, but booleans are not
            # prices, so they are rejected explicitly.
            is_price_like = (
                pd.api.types.is_numeric_dtype(dtype)
                and not pd.api.types.is_bool_dtype(dtype)
            )
            if not is_price_like:
                raise TypeError(
                    f"close_series 必须是数值类型,当前是 {close_series.dtype}"
                )
            if close_series.isna().all():
                raise ValueError("close_series 全为 NaN")

        return func(self, *args, **kwargs)

    return wrapper
def validate_factor_after_align(func):
    """Decorate an aligner method so its aligned-factor output is checked.

    Example:
        class CrossMarketAligner:
            @validate_factor_after_align
            def align_factor(self, factor_series, source_calendar, code):
                ...

    Args:
        func: The method to wrap; it is called first and its result is then
            validated.

    Returns:
        The wrapped method, which returns ``func``'s result unchanged when
        every check passes.

    Raises:
        TypeError: (from the wrapper) if the result is not a DataFrame, if
            the ``value`` column is not numeric, or if the ``is_filled``
            column is not boolean.
        ValueError: (from the wrapper) if ``value`` or ``is_filled`` is
            missing from the result's columns.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)

        # Return type
        if not isinstance(result, pd.DataFrame):
            raise TypeError(
                f"align_factor 必须返回 DataFrame,当前返回 {type(result)}"
            )

        # Required columns
        required_cols = ['value', 'is_filled']
        missing_cols = [col for col in required_cols if col not in result.columns]
        if missing_cols:
            raise ValueError(f"对齐后 DataFrame 缺少列: {missing_cols}")

        # Column dtypes (plain strings: these messages have no placeholders)
        if not pd.api.types.is_numeric_dtype(result['value']):
            raise TypeError("value 列必须是数值类型")
        if not pd.api.types.is_bool_dtype(result['is_filled']):
            raise TypeError("is_filled 列必须是布尔类型")

        return result

    return wrapper
def validate_returns_after_align(func):
    """Decorate an aligner method so its aligned-returns output is checked.

    Example:
        class CrossMarketAligner:
            @validate_returns_after_align
            def align_returns(self, close_series, code):
                ...

    Args:
        func: The method to wrap; it is called first and its result is then
            validated.

    Returns:
        The wrapped method, which returns ``func``'s result unchanged when
        every check passes. A return whose absolute value exceeds 50% only
        triggers a warning.

    Raises:
        TypeError: (from the wrapper) if the result is not a Series.
        ValueError: (from the wrapper) if the result contains any NaN.
    """
    import warnings
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)

        # Return type
        if not isinstance(result, pd.Series):
            raise TypeError(
                f"align_returns 必须返回 Series,当前返回 {type(result)}"
            )

        # No NaN allowed; count once and reuse the count for the message.
        nan_count = int(result.isna().sum())
        if nan_count:
            raise ValueError(f"收益率包含 {nan_count} 个 NaN")

        # Range check. For an empty Series max() is NaN, which compares False.
        max_return = result.abs().max()
        if max_return > 0.5:
            # stacklevel=2 attributes the warning to the code that called the
            # decorated method rather than to this wrapper.
            warnings.warn(f"发现异常收益率: {max_return:.2%}", stacklevel=2)

        return result

    return wrapper