# ======================================================================
# File: framework_v2/shared/data/__init__.py (new file)
# ======================================================================
"""
General-purpose data handling.
"""

from framework_v2.shared.data.alignment import CrossMarketAligner
from framework_v2.shared.data.schemas import (
    OHLCVInputSchema,
    AlignedFactorSchema,
    AlignedReturnsSchema,
    AlignmentValidationResult,
)

__all__ = [
    'CrossMarketAligner',
    'OHLCVInputSchema',
    'AlignedFactorSchema',
    'AlignedReturnsSchema',
    'AlignmentValidationResult',
]


# ======================================================================
# File: framework_v2/shared/data/alignment.py (new file)
# ======================================================================
"""
Cross-market data aligner.

Core principles:
1. Factors are computed on their native trading calendar and only then
   aligned onto the target calendar (A-shares).
2. Prices are aligned onto the target calendar first; returns are computed
   afterwards.
3. Forward-filled values are flagged explicitly.
4. Alignment results are validated strictly (schema decorators plus the
   aligner's own checks).

Problems addressed:
- Different trading calendars across markets (US / HK / A-share holidays).
- The ffill trap (forward-filling returns instead of prices).
- NaN propagation.
- Inconsistent date indexes.
"""

import warnings
from functools import wraps
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

# Schema validation helpers
from framework_v2.shared.data.schemas import (
    OHLCVInputSchema,
    AlignedFactorSchema,
    AlignedReturnsSchema,
    MultiAssetReturnsSchema,
    AlignmentValidationResult,
    validate_ohlcv_before_align,
    validate_factor_after_align,
    validate_returns_after_align
)


class CrossMarketAligner:
    """
    Align factor values and prices from other markets onto a target calendar.

    Example:
        >>> aligner = CrossMarketAligner(target_calendar=a_share_dates)
        >>>
        >>> # Align factor values
        >>> aligned = aligner.align_factor(factor_series, source_calendar=us_dates)
        >>>
        >>> # Align returns
        >>> returns = aligner.align_returns(close_series, code='^GSPC')
        >>>
        >>> # Align several assets at once
        >>> returns_df = aligner.align_multi_asset(close_dict)
    """

    def __init__(
        self,
        target_calendar: pd.Index,
        max_nan_ratio: float = 0.1,
        max_single_day_return: float = 0.5
    ):
        """
        Create an aligner bound to one target calendar.

        Args:
            target_calendar: Target trading calendar (A-shares). Any
                index-like value is accepted and coerced to ``pd.Index``.
            max_nan_ratio: Maximum tolerated NaN ratio (default 10%).
            max_single_day_return: Absolute single-day return above which a
                value is reported as abnormal (default 50%).
        """
        # Coerce so list-like calendars also work with reindex() and
        # Index.equals(), both of which are used below.
        self.target_calendar = pd.Index(target_calendar)
        self.max_nan_ratio = max_nan_ratio
        self.max_single_day_return = max_single_day_return

        # Running statistics
        self._stats = self._new_stats()

    @staticmethod
    def _new_stats() -> dict:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'aligned_factors': 0,
            'aligned_returns': 0,
            'warnings': []
        }

    def _warn(self, message: str) -> None:
        """Emit *message* as a warning and record it in the statistics."""
        self._stats['warnings'].append(message)
        warnings.warn(message, stacklevel=2)

    @validate_factor_after_align  # schema-level validation of the result
    def align_factor(
        self,
        factor_series: pd.Series,
        source_calendar: pd.Index,
        code: str = ''
    ) -> pd.DataFrame:
        """
        Align factor values onto the target calendar.

        The series is reindexed onto ``target_calendar`` with forward fill,
        and every target date that is not part of ``source_calendar`` is
        flagged in ``is_filled``. Target dates earlier than the first
        observation have nothing to fill from: their value stays NaN while
        ``is_filled`` is still True.

        Args:
            factor_series: Factor values indexed by the source calendar.
            source_calendar: The native trading calendar of the factor.
            code: Asset code, used only in warning messages.

        Returns:
            DataFrame indexed by ``target_calendar`` with columns
            ``value`` (aligned factor) and ``is_filled`` (bool flag).
        """
        # 1. reindex + ffill
        aligned = factor_series.reindex(self.target_calendar, method='ffill')

        # 2. Flag dates that do not exist in the source calendar
        is_filled = ~aligned.index.isin(source_calendar)

        # 3. Validate (warnings only)
        self._validate_factor_alignment(aligned, is_filled, code)

        # 4. Statistics
        self._stats['aligned_factors'] += 1

        return pd.DataFrame({
            'value': aligned,
            'is_filled': is_filled
        }, index=self.target_calendar)

    @validate_returns_after_align  # schema-level validation of the result
    def align_returns(
        self,
        close_series: pd.Series,
        code: str
    ) -> pd.Series:
        """
        Compute returns on the target calendar from a close-price series.

        Prices are forward-filled onto ``target_calendar`` first and
        ``pct_change`` is applied afterwards, so a day on which the source
        market was closed gets a 0% return. Doing it the other way round
        (returns first, ffill second) would duplicate non-zero returns.

        Undefined returns (NaN, or +/-inf caused by a zero previous close)
        are replaced by 0.0 and reported through a warning. Because this
        happens before the final validation, ``max_nan_ratio`` never
        rejects a returns series here.

        Args:
            close_series: Close prices on the asset's native calendar.
            code: Asset code, used in warnings and error messages.

        Returns:
            Returns series indexed by ``target_calendar``.

        Raises:
            ValueError: If the resulting index does not match the target
                calendar.
        """
        # 1. Align prices onto the target calendar
        close_aligned = close_series.reindex(
            self.target_calendar,
            method='ffill'
        )

        # 2. Returns; fill_method=None keeps NaN prices from being padded
        returns = close_aligned.pct_change(fill_method=None)

        # A zero previous close yields +/-inf, which isna() does not catch;
        # treat it as missing so it is handled like every other gap.
        returns = returns.replace([np.inf, -np.inf], np.nan)

        # 3. The first day has no predecessor: define its return as 0
        if len(returns) > 0:
            returns.iloc[0] = 0.0

        # 4. Remaining gaps (e.g. prices start after the calendar does)
        missing = int(returns.isna().sum())
        if missing:
            nan_ratio = missing / len(returns)
            # 0 stands for "no data, assume flat"
            returns = returns.fillna(0.0)
            self._warn(
                f"{code}: 收益率 NaN 比例 {nan_ratio:.1%},已填充为 0"
            )

        # 5. Validate
        self._validate_returns(returns, code)

        # 6. Statistics
        self._stats['aligned_returns'] += 1

        return returns

    def align_multi_asset(
        self,
        close_dict: Dict[str, pd.Series]
    ) -> pd.DataFrame:
        """
        Align the returns of several assets onto the target calendar.

        Best effort per asset: if aligning one asset fails, a warning is
        issued and that asset gets an all-zero returns column.

        Args:
            close_dict: Mapping ``{asset code: close-price series}``.

        Returns:
            Returns DataFrame whose index is ``target_calendar`` and whose
            columns are the asset codes.

        Raises:
            ValueError: If the combined frame still contains NaN.
        """
        returns_dict = {}

        for code, close_series in close_dict.items():
            try:
                returns_dict[code] = self.align_returns(close_series, code)
            except Exception as e:
                self._warn(f"{code}: 收益率对齐失败 - {e}")
                # Fall back to an all-zero column
                returns_dict[code] = pd.Series(
                    0.0,
                    index=self.target_calendar,
                    name=code
                )

        returns_df = pd.DataFrame(returns_dict, index=self.target_calendar)

        # Final guard: no NaN may survive
        if returns_df.isna().any().any():
            nan_cols = returns_df.columns[returns_df.isna().any()]
            raise ValueError(
                f"多标的收益率对齐后仍包含 NaN\n"
                f"NaN 列: {list(nan_cols)}\n"
                f"这不应该发生,请检查 align_returns 逻辑"
            )

        return returns_df

    def validate_alignment(
        self,
        signals: pd.DataFrame,
        returns_df: pd.DataFrame
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Restrict signals and returns to their common dates.

        Args:
            signals: Signal DataFrame.
            returns_df: Returns DataFrame.

        Returns:
            ``(aligned_signals, aligned_returns)``, both indexed by the
            intersection of the two input indexes.

        Raises:
            ValueError: If fewer than 10 common dates remain.
        """
        # 1. Common dates
        common_dates = signals.index.intersection(returns_df.index)

        # 2. Report dropped dates
        lost_signals = len(signals) - len(common_dates)
        lost_returns = len(returns_df) - len(common_dates)

        if lost_signals > 0 or lost_returns > 0:
            self._warn(
                f"信号与收益率对齐丢失日期\n"
                f"信号: {len(signals)} → {len(common_dates)} (丢失 {lost_signals})\n"
                f"收益: {len(returns_df)} → {len(common_dates)} (丢失 {lost_returns})"
            )

        # 3. Too few dates left?
        if len(common_dates) < 10:
            raise ValueError(
                f"对齐后日期太少: {len(common_dates)} 天\n"
                f"信号和收益率可能使用了不同的日历"
            )

        # 4. Crop to the common dates
        aligned_signals = signals.loc[common_dates]
        aligned_returns = returns_df.loc[common_dates]

        # 5. Schema-level sanity check; constructing the model runs its
        #    validators and raises if they fail. The instance itself is
        #    not needed afterwards.
        AlignmentValidationResult(
            signals_aligned=True,
            returns_aligned=True,
            common_dates_count=len(common_dates),
            lost_signals=lost_signals,
            lost_returns=lost_returns
        )

        return aligned_signals, aligned_returns

    def _validate_factor_alignment(
        self,
        aligned: pd.Series,
        is_filled: np.ndarray,
        code: str
    ):
        """Warn when an aligned factor has too many NaN or filled values."""
        if len(aligned) == 0:
            # Nothing to measure; also avoids a 0/0 ratio
            return

        # 1. NaN ratio
        nan_ratio = aligned.isna().mean()
        if nan_ratio > self.max_nan_ratio:
            self._warn(
                f"{code}: 因子 NaN 比例过高 ({nan_ratio:.1%} > {self.max_nan_ratio:.1%})"
            )

        # 2. Fill ratio
        fill_ratio = float(np.mean(is_filled))
        if fill_ratio > 0.3:
            self._warn(
                f"{code}: 因子填充比例过高 ({fill_ratio:.1%})\n"
                f"可能源日历与目标日历差异太大"
            )

    def _validate_returns(
        self,
        returns: pd.Series,
        code: str
    ):
        """
        Validate a returns series.

        Raises:
            ValueError: If the NaN ratio exceeds ``max_nan_ratio`` or the
                index differs from the target calendar.
        """
        # 1. NaN ratio (mean() of an empty series is NaN, so no 0/0 here)
        nan_ratio = returns.isna().mean()
        if nan_ratio > self.max_nan_ratio:
            raise ValueError(
                f"{code}: 收益率 NaN 比例过高 ({nan_ratio:.1%} > {self.max_nan_ratio:.1%})"
            )

        # 2. Outliers (warning only)
        max_return = returns.abs().max()
        if max_return > self.max_single_day_return:
            self._warn(
                f"{code}: 发现异常收益率 ({max_return:.1%} > {self.max_single_day_return:.1%})\n"
                f"可能数据有问题"
            )

        # 3. Index must be exactly the target calendar
        if not returns.index.equals(self.target_calendar):
            raise ValueError(
                f"{code}: 收益率索引与目标日历不匹配\n"
                f"收益率长度: {len(returns)}\n"
                f"目标日历长度: {len(self.target_calendar)}"
            )

    def get_stats(self) -> dict:
        """Return a snapshot of the statistics (safe for callers to mutate)."""
        snapshot = self._stats.copy()
        # Copy the list too, otherwise callers would share internal state
        snapshot['warnings'] = list(self._stats['warnings'])
        return snapshot

    def reset_stats(self):
        """Reset all statistics to their initial values."""
        self._stats = self._new_stats()


# ======================================================================
# File: framework_v2/shared/data/schemas.py (new file)
# ======================================================================
"""
Schema definitions for data alignment.

Used together with CrossMarketAligner to provide structural validation.
"""

import warnings
from typing import Optional, List

import numpy as np
import pandas as pd
from pydantic import BaseModel, ConfigDict, Field, field_validator


# ============================================================
# Input validation schemas
# ============================================================

class OHLCVInputSchema(BaseModel):
    """
    Validation of one OHLCV input record.

    Intended for checking raw data before alignment.
    """
    # pydantic v2 style configuration; extra fields are ignored
    model_config = ConfigDict(extra="ignore")

    # Required field
    close: float = Field(..., description="收盘价(必需)", gt=0)

    # Optional fields
    open: Optional[float] = Field(None, description="开盘价", gt=0)
    high: Optional[float] = Field(None, description="最高价", gt=0)
    low: Optional[float] = Field(None, description="最低价", gt=0)
    volume: Optional[float] = Field(None, description="成交量", ge=0)

    @field_validator('close', 'open', 'high', 'low')
    @classmethod
    def check_positive(cls, v):
        """Reject non-positive prices (mirrors the ``gt=0`` constraints)."""
        if v is not None and v <= 0:
            raise ValueError(f"价格必须为正数,当前值: {v}")
        return v


class FactorInputSchema(BaseModel):
    """
    Validation of one factor input record.

    Warns when the factor value is outside the expected range.
    """
    value: float = Field(..., description="因子值")
    is_filled: bool = Field(False, description="是否为填充值")

    @field_validator('value')
    @classmethod
    def check_reasonable(cls, v):
        """Warn (do not fail) when ``abs(value)`` exceeds 10."""
        if abs(v) > 10:
            warnings.warn(f"因子值异常: {v}")
        return v
# ============================================================
# Output validation schemas
# ============================================================

class AlignedFactorSchema(BaseModel):
    """
    Validation of one aligned factor record.

    Describes the row shape produced by align_factor().
    """
    # pydantic v2 configuration (replaces the deprecated inner `class Config`).
    # A plain dict is used so that no additional import is required.
    # NOTE: NaN values (early dates with no data) are accepted because
    # pydantic float fields allow NaN by default, not because of this flag.
    model_config = {"arbitrary_types_allowed": True}

    value: float = Field(..., description="对齐后的因子值")
    is_filled: bool = Field(..., description="是否为填充值")


class AlignedReturnsSchema(BaseModel):
    """
    Validation of one aligned return value.

    Describes the element shape produced by align_returns().
    """
    # pydantic v2 configuration (replaces the deprecated inner `class Config`)
    model_config = {"arbitrary_types_allowed": True}

    returns: float = Field(..., description="收益率")

    @field_validator('returns')
    @classmethod
    def check_returns_range(cls, v):
        """Warn (do not fail) when ``abs(returns)`` exceeds 50%."""
        if abs(v) > 0.5:
            import warnings
            warnings.warn(f"收益率异常: {v:.2%}")
        return v


# ============================================================
# Batch validation schemas
# ============================================================

class MultiAssetReturnsSchema(BaseModel):
    """
    Validation of multi-asset returns.

    Describes the data produced by align_multi_asset().
    """
    data: dict = Field(..., description="{标的代码: 收益率 Series}")

    @field_validator('data')
    @classmethod
    def check_no_nan(cls, v):
        """Reject the mapping if any column contains NaN once framed."""
        df = pd.DataFrame(v)
        has_nan = df.isna().any()
        if has_nan.any():
            nan_cols = df.columns[has_nan]
            raise ValueError(f"收益率包含 NaN 列: {list(nan_cols)}")
        return v


class AlignmentValidationResult(BaseModel):
    """
    Result record of an alignment validation.

    Constructed by validate_alignment(); building it runs the validator below.
    """
    signals_aligned: bool = Field(..., description="信号是否已对齐")
    returns_aligned: bool = Field(..., description="收益率是否已对齐")
    common_dates_count: int = Field(..., description="共同日期数量")
    lost_signals: int = Field(0, description="丢失的信号数")
    lost_returns: int = Field(0, description="丢失的收益数")

    @field_validator('common_dates_count')
    @classmethod
    def check_min_dates(cls, v):
        """Require at least 10 common dates."""
        if v < 10:
            raise ValueError(f"共同日期太少: {v} 天")
        return v


# ============================================================
# Validation decorators (used together with the aligner)
# ============================================================

def validate_ohlcv_before_align(func):
    """
    Check the close-price series before the wrapped method runs.

    The series is taken from the first positional argument after ``self``
    or, failing that, from the ``close_series`` keyword. Values that are
    not a ``pd.Series`` are passed through unchecked.

    Example:
        class CrossMarketAligner:
            @validate_ohlcv_before_align
            def align_returns(self, close_series, code):
                ...

    Raises (from the wrapper):
        TypeError: If the series is not numeric.
        ValueError: If every value in the series is NaN.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # close_series is the first argument after self
        close_series = args[0] if args else kwargs.get('close_series')

        if isinstance(close_series, pd.Series):
            if not pd.api.types.is_numeric_dtype(close_series):
                raise TypeError(
                    f"close_series 必须是数值类型,当前是 {close_series.dtype}"
                )

            if close_series.isna().all():
                raise ValueError("close_series 全为 NaN")

        return func(self, *args, **kwargs)
    return wrapper


def validate_factor_after_align(func):
    """
    Check the DataFrame returned by a factor-alignment method.

    Example:
        class CrossMarketAligner:
            @validate_factor_after_align
            def align_factor(self, factor_series, source_calendar, code):
                ...

    Raises (from the wrapper):
        TypeError: If the result is not a DataFrame, ``value`` is not
            numeric, or ``is_filled`` is not boolean.
        ValueError: If a required column is missing.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)

        # Return type
        if not isinstance(result, pd.DataFrame):
            raise TypeError(
                f"align_factor 必须返回 DataFrame,当前返回 {type(result)}"
            )

        # Required columns
        required_cols = ['value', 'is_filled']
        missing_cols = [col for col in required_cols if col not in result.columns]
        if missing_cols:
            raise ValueError(f"对齐后 DataFrame 缺少列: {missing_cols}")

        # Column dtypes
        if not pd.api.types.is_numeric_dtype(result['value']):
            raise TypeError("value 列必须是数值类型")

        if not pd.api.types.is_bool_dtype(result['is_filled']):
            raise TypeError("is_filled 列必须是布尔类型")

        return result
    return wrapper


def validate_returns_after_align(func):
    """
    Check the Series returned by a returns-alignment method.

    The outlier threshold is read from ``self.max_single_day_return`` when
    the decorated object defines it, so the decorator agrees with the
    aligner's own configuration; otherwise 0.5 (50%) is used.

    Example:
        class CrossMarketAligner:
            @validate_returns_after_align
            def align_returns(self, close_series, code):
                ...

    Raises (from the wrapper):
        TypeError: If the result is not a Series.
        ValueError: If the result contains NaN.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        result = func(self, *args, **kwargs)

        # Return type
        if not isinstance(result, pd.Series):
            raise TypeError(
                f"align_returns 必须返回 Series,当前返回 {type(result)}"
            )

        # No NaN allowed
        if result.isna().any():
            nan_count = result.isna().sum()
            raise ValueError(f"收益率包含 {nan_count} 个 NaN")

        # Outlier check (warning only)
        threshold = getattr(self, 'max_single_day_return', 0.5)
        max_return = result.abs().max()
        if max_return > threshold:
            import warnings
            warnings.warn(f"发现异常收益率: {max_return:.2%}")

        return result
    return wrapper


# ======================================================================
# File: framework_v2/tests/test_alignment.py (new file)
# ======================================================================
"""
Data alignment tests.

Verify the behaviour of the cross-market aligner. Each test signals failure
by raising (typically AssertionError) and returns nothing, so the module
works both under pytest and as a stand-alone script.
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd

# Make the project root importable
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))


def test_factor_alignment():
    """Factor values are forward-filled and the filled rows are flagged."""
    from framework_v2.shared.data.alignment import CrossMarketAligner

    print("=" * 60)
    print("  测试 1: 因子对齐")
    print("=" * 60)

    # US calendar: the A-share calendar minus two days (simulated holidays)
    us_dates = pd.date_range('2024-01-01', periods=10, freq='B')
    us_dates = us_dates.delete([2, 5])

    # A-share calendar (complete)
    cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')

    # Factor values on the US calendar
    factor_values = [0.1, 0.15, 0.18, 0.16, 0.20, 0.22, 0.25, 0.23]
    factor_series = pd.Series(factor_values, index=us_dates)

    print(f"\n源日历(美股): {len(us_dates)} 天")
    print(f"目标日历(A股): {len(cn_dates)} 天")
    print("因子值(源日历):")
    print(factor_series)

    aligner = CrossMarketAligner(target_calendar=cn_dates)
    aligned = aligner.align_factor(factor_series, us_dates, code='^GSPC')

    print("\n对齐后因子值:")
    print(aligned)

    assert len(aligned) == len(cn_dates), "对齐后长度应该等于目标日历"
    assert 'value' in aligned.columns, "应该有 'value' 列"
    assert 'is_filled' in aligned.columns, "应该有 'is_filled' 列"

    # Exactly the two removed days are flagged and carry the previous value
    filled_count = aligned['is_filled'].sum()
    print(f"\n填充值数量: {filled_count}")
    assert filled_count == 2
    assert aligned['is_filled'].iloc[[2, 5]].all()
    assert aligned['value'].iloc[2] == 0.15
    assert aligned['value'].iloc[5] == 0.16

    print("\n✓ 测试通过")


def test_returns_alignment():
    """Returns are computed after price alignment; closed days return 0."""
    from framework_v2.shared.data.alignment import CrossMarketAligner

    print("\n" + "=" * 60)
    print("  测试 2: 收益率对齐")
    print("=" * 60)

    # US calendar
    us_dates = pd.date_range('2024-01-01', periods=10, freq='B')
    us_dates = us_dates.delete([2, 5])

    # A-share calendar
    cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')

    # Prices on the US calendar
    prices = [100, 101, 102, 101, 103, 104, 105, 104]
    close_series = pd.Series(prices, index=us_dates, dtype=float)

    print("\n源价格(美股日历):")
    print(close_series)

    aligner = CrossMarketAligner(target_calendar=cn_dates)
    returns = aligner.align_returns(close_series, code='^GSPC')

    print("\n对齐后收益率(A股日历):")
    print(returns)

    assert len(returns) == len(cn_dates), "收益率长度应该等于目标日历"
    assert returns.index.equals(cn_dates), "收益率索引应该等于目标日历"
    assert not returns.isna().any(), "收益率不应该有 NaN"
    assert returns.iloc[0] == 0.0, "首日收益率应该为 0"

    # On US holidays the price is forward-filled, so the return is 0
    assert returns.iloc[2] == 0.0, "休市日收益率应该为 0"
    assert returns.iloc[5] == 0.0, "休市日收益率应该为 0"

    print("\n收益率统计:")
    print(f"  最小值: {returns.min():.4f}")
    print(f"  最大值: {returns.max():.4f}")
    print(f"  均值: {returns.mean():.4f}")

    print("\n✓ 测试通过")


def test_multi_asset_alignment():
    """Several assets with different calendars end up on one index."""
    from framework_v2.shared.data.alignment import CrossMarketAligner

    print("\n" + "=" * 60)
    print("  测试 3: 多标的收益率对齐")
    print("=" * 60)

    # Seeded generator keeps the test deterministic
    rng = np.random.default_rng(42)

    # A-share calendar
    cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')

    close_dict = {}

    # Asset 1: complete data
    prices1 = 100 + np.cumsum(rng.standard_normal(10))
    close_dict['^GSPC'] = pd.Series(prices1, index=cn_dates, dtype=float)

    # Asset 2: two days missing
    us_dates2 = cn_dates.delete([2, 5])
    prices2 = 100 + np.cumsum(rng.standard_normal(8))
    close_dict['^IXIC'] = pd.Series(prices2, index=us_dates2, dtype=float)

    # Asset 3: complete data
    prices3 = 100 + np.cumsum(rng.standard_normal(10))
    close_dict['931862.CSI'] = pd.Series(prices3, index=cn_dates, dtype=float)

    print(f"\n标的数量: {len(close_dict)}")
    for code, close in close_dict.items():
        print(f"  {code}: {len(close)} 天")

    aligner = CrossMarketAligner(target_calendar=cn_dates)
    returns_df = aligner.align_multi_asset(close_dict)

    print("\n对齐后收益率 DataFrame:")
    print(returns_df.head())

    assert len(returns_df) == len(cn_dates), "收益率 DataFrame 长度应该等于目标日历"
    assert returns_df.index.equals(cn_dates), "索引应该等于目标日历"
    assert not returns_df.isna().any().any(), "收益率 DataFrame 不应该有 NaN"
    assert len(returns_df.columns) == len(close_dict), "列数应该等于标的数"

    # The two missing days of asset 2 must show a flat return
    assert (returns_df['^IXIC'].iloc[[2, 5]] == 0.0).all(), "休市日收益率应该为 0"

    print("\n✓ 测试通过")


def test_signal_returns_alignment():
    """Signals and returns are cropped to their common dates."""
    from framework_v2.shared.data.alignment import CrossMarketAligner

    print("\n" + "=" * 60)
    print("  测试 4: 信号与收益率对齐验证")
    print("=" * 60)

    # A-share calendar
    cn_dates = pd.date_range('2024-01-01', periods=10, freq='B')

    # Signals cover two extra days
    signal_dates = cn_dates.union(pd.date_range('2024-01-15', periods=2, freq='B'))
    signals = pd.DataFrame({
        'signal': ['^GSPC,^IXIC'] * len(signal_dates)
    }, index=signal_dates)

    # Returns on the regular calendar (seeded for reproducibility)
    rng = np.random.default_rng(42)
    returns_df = pd.DataFrame(
        rng.standard_normal((len(cn_dates), 2)) * 0.02,
        index=cn_dates,
        columns=['^GSPC', '^IXIC']
    )

    print(f"\n信号: {len(signals)} 天")
    print(f"收益: {len(returns_df)} 天")

    aligner = CrossMarketAligner(target_calendar=cn_dates)
    aligned_signals, aligned_returns = aligner.validate_alignment(signals, returns_df)

    print("\n对齐后:")
    print(f"  信号: {len(aligned_signals)} 天")
    print(f"  收益: {len(aligned_returns)} 天")

    assert len(aligned_signals) == len(aligned_returns), "对齐后长度应该一致"
    assert aligned_signals.index.equals(aligned_returns.index), "索引应该一致"
    # Only the dates shared by both inputs survive
    assert len(aligned_signals) == len(cn_dates)

    print("\n✓ 测试通过")


def test_ffill_trap():
    """Contrast the wrong (returns-then-ffill) and right (ffill-then-returns) order."""
    from framework_v2.shared.data.alignment import CrossMarketAligner

    print("\n" + "=" * 60)
    print("  测试 5: ffill 陷阱对比")
    print("=" * 60)

    # US calendar: the third day is a holiday
    us_dates = pd.date_range('2024-01-01', periods=5, freq='B')
    us_dates = us_dates.delete([2])

    # A-share calendar
    cn_dates = pd.date_range('2024-01-01', periods=5, freq='B')

    prices = pd.Series([100, 101, 102, 103], index=us_dates, dtype=float)

    print("\n原始价格(美股日历):")
    print(prices)

    # Wrong: compute returns first, forward-fill afterwards
    print("\n❌ 错误做法:先 pct_change,再 reindex")
    returns_wrong = prices.pct_change(fill_method=None)
    print("步骤 1 - 收益率:")
    print(returns_wrong)

    returns_wrong_aligned = returns_wrong.reindex(cn_dates, method='ffill')
    print("\n步骤 2 - reindex + ffill:")
    print(returns_wrong_aligned)
    print("⚠ 问题:第 3 天复制了第 2 天的 +1% 收益率!")
    # The trap: day 3 repeats the non-zero return of day 2
    assert returns_wrong_aligned.iloc[2] == returns_wrong.iloc[1]

    # Right: forward-fill prices first, compute returns afterwards
    print("\n✅ 正确做法:先 reindex 价格,再 pct_change")
    prices_aligned = prices.reindex(cn_dates, method='ffill')
    print("步骤 1 - 价格 reindex:")
    print(prices_aligned)

    returns_correct = prices_aligned.pct_change(fill_method=None)
    returns_correct.iloc[0] = 0.0
    print("\n步骤 2 - pct_change:")
    print(returns_correct)
    print("✓ 第 3 天收益率 = 0%(价格不变)")

    aligner = CrossMarketAligner(target_calendar=cn_dates)
    returns = aligner.align_returns(prices, code='TEST')

    assert returns.iloc[2] == 0.0, "休市日收益率应该为 0"

    print("\n✓ 测试通过")


if __name__ == '__main__':
    import traceback

    print("\n" + "=" * 60)
    print("  跨市场数据对齐器测试")
    print("=" * 60)

    tests = [
        ("因子对齐", test_factor_alignment),
        ("收益率对齐", test_returns_alignment),
        ("多标的对齐", test_multi_asset_alignment),
        ("信号与收益对齐", test_signal_returns_alignment),
        ("ffill 陷阱", test_ffill_trap),
    ]

    # A test passes when it finishes without raising
    results = []
    for name, test_func in tests:
        try:
            test_func()
        except Exception as e:
            print(f"\n✗ {name} 失败: {e}")
            traceback.print_exc()
            results.append((name, False))
        else:
            results.append((name, True))

    # Summary
    print("\n" + "=" * 60)
    print("  测试总结")
    print("=" * 60)

    passed = sum(1 for _, success in results if success)
    total = len(results)

    for name, success in results:
        status = "✓ 通过" if success else "✗ 失败"
        print(f"  {status} - {name}")

    print(f"\n总计: {passed}/{total} 通过")

    sys.exit(0 if passed == total else 1)