refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
9
archive/framework_v2/shared/__init__.py
Normal file
9
archive/framework_v2/shared/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
通用实现层(2+ 策略复用的组件)
|
||||
|
||||
包含:
|
||||
├── factors/ # 通用因子
|
||||
├── signals/ # 通用信号生成器
|
||||
├── execution/ # 通用执行器
|
||||
└── data/ # 通用数据处理
|
||||
"""
|
||||
21
archive/framework_v2/shared/data/__init__.py
Normal file
21
archive/framework_v2/shared/data/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
通用数据处理
|
||||
"""
|
||||
|
||||
from framework_v2.shared.data.alignment import CrossMarketAligner
|
||||
from framework_v2.shared.data.schemas import (
|
||||
OHLCVInputSchema,
|
||||
AlignedFactorSchema,
|
||||
AlignedReturnsSchema,
|
||||
AlignmentValidationResult,
|
||||
)
|
||||
from framework_v2.shared.data.flask_api_fetcher import FlaskAPIFetcher
|
||||
|
||||
__all__ = [
|
||||
'CrossMarketAligner',
|
||||
'OHLCVInputSchema',
|
||||
'AlignedFactorSchema',
|
||||
'AlignedReturnsSchema',
|
||||
'AlignmentValidationResult',
|
||||
'FlaskAPIFetcher',
|
||||
]
|
||||
334
archive/framework_v2/shared/data/alignment.py
Normal file
334
archive/framework_v2/shared/data/alignment.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
跨市场数据对齐器
|
||||
|
||||
核心原则:
|
||||
1. 因子在原始交易日历计算,再对齐到目标日历(A股)
|
||||
2. 价格先对齐到目标日历,再计算收益率
|
||||
3. 显式标记 ffill 填充的值
|
||||
4. 严格验证对齐结果(Pydantic Schema + 内置验证)
|
||||
|
||||
解决的问题:
|
||||
- 跨市场交易日历不同(美股/港股/A股假日不同)
|
||||
- ffill 陷阱(收益率 vs 价格)
|
||||
- NaN 传播
|
||||
- 日期不一致
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import warnings
|
||||
from functools import wraps
|
||||
|
||||
# 导入 Schema 验证
|
||||
from framework_v2.shared.data.schemas import (
|
||||
OHLCVInputSchema,
|
||||
AlignedFactorSchema,
|
||||
AlignedReturnsSchema,
|
||||
MultiAssetReturnsSchema,
|
||||
AlignmentValidationResult,
|
||||
validate_ohlcv_before_align,
|
||||
validate_factor_after_align,
|
||||
validate_returns_after_align
|
||||
)
|
||||
|
||||
|
||||
class CrossMarketAligner:
|
||||
"""
|
||||
跨市场数据对齐器
|
||||
|
||||
使用示例:
|
||||
>>> aligner = CrossMarketAligner(target_calendar=a_share_dates)
|
||||
>>>
|
||||
>>> # 对齐因子值
|
||||
>>> aligned = aligner.align_factor(factor_series, source_calendar=us_dates)
|
||||
>>>
|
||||
>>> # 对齐收益率
|
||||
>>> returns = aligner.align_returns(close_series, code='^GSPC')
|
||||
>>>
|
||||
>>> # 对齐多标的
|
||||
>>> returns_df = aligner.align_multi_asset(close_dict)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_calendar: pd.Index,
|
||||
max_nan_ratio: float = 0.1,
|
||||
max_single_day_return: float = 0.5
|
||||
):
|
||||
"""
|
||||
初始化
|
||||
|
||||
Args:
|
||||
target_calendar: 目标交易日历(A股)
|
||||
max_nan_ratio: 最大允许 NaN 比例(默认 10%)
|
||||
max_single_day_return: 最大单日收益率(默认 50%,用于检测异常)
|
||||
"""
|
||||
self.target_calendar = target_calendar
|
||||
self.max_nan_ratio = max_nan_ratio
|
||||
self.max_single_day_return = max_single_day_return
|
||||
|
||||
# 统计信息
|
||||
self._stats = {
|
||||
'aligned_factors': 0,
|
||||
'aligned_returns': 0,
|
||||
'warnings': []
|
||||
}
|
||||
|
||||
@validate_factor_after_align # ← Pydantic Schema 验证
|
||||
def align_factor(
|
||||
self,
|
||||
factor_series: pd.Series,
|
||||
source_calendar: pd.Index,
|
||||
code: str = ''
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
对齐因子值到目标日历
|
||||
|
||||
规则:
|
||||
- 因子在 source_calendar 计算
|
||||
- 对齐到 target_calendar(ffill)
|
||||
- 标记哪些是填充值(is_filled 列)
|
||||
|
||||
Args:
|
||||
factor_series: 因子值序列(source_calendar 索引)
|
||||
source_calendar: 原始交易日历
|
||||
code: 标的代码(用于日志)
|
||||
|
||||
Returns:
|
||||
DataFrame with columns:
|
||||
- value: 对齐后的因子值
|
||||
- is_filled: 是否为 ffill 填充值
|
||||
"""
|
||||
# 1. reindex + ffill
|
||||
aligned = factor_series.reindex(self.target_calendar, method='ffill')
|
||||
|
||||
# 2. 标记填充值(不在 source_calendar 中的日期)
|
||||
is_filled = ~aligned.index.isin(source_calendar)
|
||||
|
||||
# 3. 验证
|
||||
self._validate_factor_alignment(aligned, is_filled, code)
|
||||
|
||||
# 4. 统计
|
||||
self._stats['aligned_factors'] += 1
|
||||
|
||||
return pd.DataFrame({
|
||||
'value': aligned,
|
||||
'is_filled': is_filled
|
||||
}, index=self.target_calendar)
|
||||
|
||||
@validate_returns_after_align # ← Pydantic Schema 验证
|
||||
def align_returns(
|
||||
self,
|
||||
close_series: pd.Series,
|
||||
code: str
|
||||
) -> pd.Series:
|
||||
"""
|
||||
对齐收益率到目标日历
|
||||
|
||||
规则:
|
||||
- 价格先 ffill 到 target_calendar
|
||||
- 再计算 pct_change
|
||||
- 休市日收益率 = 0%(价格不变)
|
||||
|
||||
重要:
|
||||
❌ 错误:先计算收益率,再 ffill(会复制非零收益率)
|
||||
✅ 正确:先 ffill 价格,再计算收益率(休市日收益率 = 0%)
|
||||
|
||||
Args:
|
||||
close_series: 收盘价序列(原始日历)
|
||||
code: 标的代码(用于日志和错误信息)
|
||||
|
||||
Returns:
|
||||
收益率序列(target_calendar 索引)
|
||||
"""
|
||||
# 1. 价格对齐到目标日历
|
||||
close_aligned = close_series.reindex(
|
||||
self.target_calendar,
|
||||
method='ffill'
|
||||
)
|
||||
|
||||
# 2. 计算收益率(关键:fill_method=None,不填充 NaN)
|
||||
returns = close_aligned.pct_change(fill_method=None)
|
||||
|
||||
# 3. 填充首日 NaN(首日无前一日,收益率 = 0)
|
||||
if len(returns) > 0:
|
||||
returns.iloc[0] = 0.0
|
||||
|
||||
# 4. 填充剩余 NaN(如果价格全 NaN,收益率也全 NaN)
|
||||
nan_ratio = returns.isna().sum() / len(returns)
|
||||
if nan_ratio > 0:
|
||||
# 用 0 填充(表示"无数据,收益率为 0")
|
||||
returns = returns.fillna(0.0)
|
||||
warnings.warn(
|
||||
f"{code}: 收益率 NaN 比例 {nan_ratio:.1%},已填充为 0"
|
||||
)
|
||||
|
||||
# 5. 验证
|
||||
self._validate_returns(returns, code)
|
||||
|
||||
# 6. 统计
|
||||
self._stats['aligned_returns'] += 1
|
||||
|
||||
return returns
|
||||
|
||||
def align_multi_asset(
|
||||
self,
|
||||
close_dict: Dict[str, pd.Series]
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
对齐多标的收益率
|
||||
|
||||
Args:
|
||||
close_dict: {标的代码: 收盘价序列}
|
||||
|
||||
Returns:
|
||||
收益率 DataFrame(所有标的同索引 = target_calendar)
|
||||
"""
|
||||
returns_dict = {}
|
||||
|
||||
for code, close_series in close_dict.items():
|
||||
try:
|
||||
returns_dict[code] = self.align_returns(close_series, code)
|
||||
except Exception as e:
|
||||
warnings.warn(f"{code}: 收益率对齐失败 - {e}")
|
||||
# 填充全 0
|
||||
returns_dict[code] = pd.Series(
|
||||
0.0,
|
||||
index=self.target_calendar,
|
||||
name=code
|
||||
)
|
||||
|
||||
# 合并为 DataFrame
|
||||
returns_df = pd.DataFrame(returns_dict, index=self.target_calendar)
|
||||
|
||||
# 最终验证:不能有 NaN
|
||||
if returns_df.isna().any().any():
|
||||
nan_cols = returns_df.columns[returns_df.isna().any()]
|
||||
raise ValueError(
|
||||
f"多标的收益率对齐后仍包含 NaN\n"
|
||||
f"NaN 列: {list(nan_cols)}\n"
|
||||
f"这不应该发生,请检查 align_returns 逻辑"
|
||||
)
|
||||
|
||||
return returns_df
|
||||
|
||||
def validate_alignment(
|
||||
self,
|
||||
signals: pd.DataFrame,
|
||||
returns_df: pd.DataFrame
|
||||
) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
"""
|
||||
验证信号与收益率对齐,并返回对齐后的结果
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
returns_df: 收益率 DataFrame
|
||||
|
||||
Returns:
|
||||
(aligned_signals, aligned_returns)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果对齐后日期太少
|
||||
"""
|
||||
# 1. 找共同日期
|
||||
common_dates = signals.index.intersection(returns_df.index)
|
||||
|
||||
# 2. 检查丢失的日期
|
||||
lost_signals = len(signals) - len(common_dates)
|
||||
lost_returns = len(returns_df) - len(common_dates)
|
||||
|
||||
if lost_signals > 0 or lost_returns > 0:
|
||||
warnings.warn(
|
||||
f"信号与收益率对齐丢失日期\n"
|
||||
f"信号: {len(signals)} → {len(common_dates)} (丢失 {lost_signals})\n"
|
||||
f"收益: {len(returns_df)} → {len(common_dates)} (丢失 {lost_returns})"
|
||||
)
|
||||
|
||||
# 3. 检查对齐后日期是否太少
|
||||
if len(common_dates) < 10:
|
||||
raise ValueError(
|
||||
f"对齐后日期太少: {len(common_dates)} 天\n"
|
||||
f"信号和收益率可能使用了不同的日历"
|
||||
)
|
||||
|
||||
# 4. 裁剪到共同日期
|
||||
aligned_signals = signals.loc[common_dates]
|
||||
aligned_returns = returns_df.loc[common_dates]
|
||||
|
||||
# 5. 使用 Pydantic Schema 验证结果
|
||||
validation_result = AlignmentValidationResult(
|
||||
signals_aligned=True,
|
||||
returns_aligned=True,
|
||||
common_dates_count=len(common_dates),
|
||||
lost_signals=lost_signals,
|
||||
lost_returns=lost_returns
|
||||
)
|
||||
|
||||
# 6. 如果验证失败,会抛出异常
|
||||
# (Pydantic 自动验证 field_validator)
|
||||
|
||||
return aligned_signals, aligned_returns
|
||||
|
||||
def _validate_factor_alignment(
|
||||
self,
|
||||
aligned: pd.Series,
|
||||
is_filled: pd.Series,
|
||||
code: str
|
||||
):
|
||||
"""验证因子对齐结果"""
|
||||
# 1. 检查 NaN 比例
|
||||
nan_ratio = aligned.isna().sum() / len(aligned)
|
||||
if nan_ratio > self.max_nan_ratio:
|
||||
warnings.warn(
|
||||
f"{code}: 因子 NaN 比例过高 ({nan_ratio:.1%} > {self.max_nan_ratio:.1%})"
|
||||
)
|
||||
|
||||
# 2. 检查填充比例
|
||||
fill_ratio = is_filled.sum() / len(is_filled)
|
||||
if fill_ratio > 0.3:
|
||||
warnings.warn(
|
||||
f"{code}: 因子填充比例过高 ({fill_ratio:.1%})\n"
|
||||
f"可能源日历与目标日历差异太大"
|
||||
)
|
||||
|
||||
def _validate_returns(
|
||||
self,
|
||||
returns: pd.Series,
|
||||
code: str
|
||||
):
|
||||
"""验证收益率数据"""
|
||||
# 1. 检查 NaN 比例
|
||||
nan_ratio = returns.isna().sum() / len(returns)
|
||||
if nan_ratio > self.max_nan_ratio:
|
||||
raise ValueError(
|
||||
f"{code}: 收益率 NaN 比例过高 ({nan_ratio:.1%} > {self.max_nan_ratio:.1%})"
|
||||
)
|
||||
|
||||
# 2. 检查异常值
|
||||
max_return = returns.abs().max()
|
||||
if max_return > self.max_single_day_return:
|
||||
warnings.warn(
|
||||
f"{code}: 发现异常收益率 ({max_return:.1%} > {self.max_single_day_return:.1%})\n"
|
||||
f"可能数据有问题"
|
||||
)
|
||||
|
||||
# 3. 检查索引是否匹配目标日历
|
||||
if not returns.index.equals(self.target_calendar):
|
||||
raise ValueError(
|
||||
f"{code}: 收益率索引与目标日历不匹配\n"
|
||||
f"收益率长度: {len(returns)}\n"
|
||||
f"目标日历长度: {len(self.target_calendar)}"
|
||||
)
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""获取对齐统计信息"""
|
||||
return self._stats.copy()
|
||||
|
||||
def reset_stats(self):
|
||||
"""重置统计信息"""
|
||||
self._stats = {
|
||||
'aligned_factors': 0,
|
||||
'aligned_returns': 0,
|
||||
'warnings': []
|
||||
}
|
||||
269
archive/framework_v2/shared/data/flask_api_fetcher.py
Normal file
269
archive/framework_v2/shared/data/flask_api_fetcher.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Flask API 数据获取器(framework_v2 实现)
|
||||
|
||||
继承 DataFetcher 抽象基类,使用 FlaskAPIDataSource 获取线上数据
|
||||
支持指数、ETF 数据获取
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.core.data import DataFetcher
|
||||
from datasource.flask_api_source import FlaskAPIDataSource
|
||||
|
||||
|
||||
class FlaskAPIFetcher(DataFetcher):
|
||||
"""
|
||||
Flask API 数据获取器
|
||||
|
||||
通过 HTTP API 获取线上数据(指数、ETF)
|
||||
无需本地 SSH 隧道配置
|
||||
|
||||
用法:
|
||||
fetcher = FlaskAPIFetcher(base_url="https://k3s.tokenpluse.xyz")
|
||||
data = fetcher.fetch_indices(["000300.SH"], "2024-01-01", "2024-12-31")
|
||||
"""
|
||||
|
||||
name = "flask_api"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = None,
|
||||
timeout: int = 120,
|
||||
retries: int = 3
|
||||
):
|
||||
"""
|
||||
初始化
|
||||
|
||||
Args:
|
||||
base_url: API 服务地址(默认从环境变量读取)
|
||||
timeout: 请求超时时间(秒)
|
||||
retries: 重试次数
|
||||
"""
|
||||
super().__init__(base_url=base_url, timeout=timeout, retries=retries)
|
||||
|
||||
# 创建底层数据源
|
||||
self._source = FlaskAPIDataSource(
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
retries=retries
|
||||
)
|
||||
|
||||
def fetch_indices(
|
||||
self,
|
||||
codes: List[str],
|
||||
start: str,
|
||||
end: str,
|
||||
adj: str = 'raw'
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取指数 OHLCV 数据
|
||||
|
||||
Args:
|
||||
codes: 指数代码列表(如 ["000300.SH", "000905.SH"])
|
||||
start: 开始日期 (YYYY-MM-DD)
|
||||
end: 结束日期 (YYYY-MM-DD)
|
||||
adj: 复权类型,默认 'raw'(指数通常用原始价格)
|
||||
|
||||
Returns:
|
||||
{code: DataFrame} 字典,DataFrame 包含 OHLCV 列
|
||||
|
||||
示例:
|
||||
>>> fetcher = FlaskAPIFetcher()
|
||||
>>> data = fetcher.fetch_indices(
|
||||
... ["000300.SH", "000905.SH"],
|
||||
... "2024-01-01",
|
||||
... "2024-12-31"
|
||||
... )
|
||||
>>> print(data["000300.SH"].head())
|
||||
"""
|
||||
print(f"\n[FlaskAPI] 获取 {len(codes)} 只指数数据(adj='{adj}')...")
|
||||
|
||||
results = {}
|
||||
for i, code in enumerate(codes, 1):
|
||||
print(f" [{i}/{len(codes)}] {code}...")
|
||||
|
||||
df = self._source.fetch(
|
||||
code=code,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
adj=adj # 使用传入的 adj 参数
|
||||
)
|
||||
|
||||
if df is not None:
|
||||
results[code] = df
|
||||
print(f" ✓ {len(df)} 条数据")
|
||||
else:
|
||||
print(f" ✗ 获取失败")
|
||||
|
||||
success = len(results)
|
||||
print(f"\n[FlaskAPI] 指数数据获取完成: {success}/{len(codes)} 成功")
|
||||
|
||||
return results
|
||||
|
||||
def fetch_etf(
|
||||
self,
|
||||
codes: List[str],
|
||||
start: str,
|
||||
end: str,
|
||||
adj: str = 'hfq'
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取 ETF 数据(价格 + 净值)
|
||||
|
||||
Args:
|
||||
codes: ETF 代码列表(如 ["510300.SH", "159919.SZ"])
|
||||
start: 开始日期 (YYYY-MM-DD)
|
||||
end: 结束日期 (YYYY-MM-DD)
|
||||
adj: 复权类型,默认 'hfq'(ETF 收益计算推荐后复权)
|
||||
|
||||
Returns:
|
||||
{code: DataFrame} 字典
|
||||
DataFrame 包含 OHLCV 列
|
||||
df.attrs['nav'] 包含净值数据
|
||||
df.attrs['premium_series'] 包含溢价率序列
|
||||
|
||||
示例:
|
||||
>>> fetcher = FlaskAPIFetcher()
|
||||
>>> # 默认使用 hfq(后复权)
|
||||
>>> data = fetcher.fetch_etf(
|
||||
... ["510300.SH", "159919.SZ"],
|
||||
... "2024-01-01",
|
||||
... "2024-12-31"
|
||||
... )
|
||||
>>> # 或者显式指定 raw(原始价格,用于计算溢价率)
|
||||
>>> data_raw = fetcher.fetch_etf(
|
||||
... ["510300.SH"],
|
||||
... "2024-01-01",
|
||||
... "2024-12-31",
|
||||
... adj='raw'
|
||||
... )
|
||||
>>> # 访问净值
|
||||
>>> nav = data["510300.SH"].attrs.get('nav')
|
||||
"""
|
||||
print(f"\n[FlaskAPI] 获取 {len(codes)} 只 ETF 数据(adj='{adj}')...")
|
||||
|
||||
results = {}
|
||||
for i, code in enumerate(codes, 1):
|
||||
print(f" [{i}/{len(codes)}] {code}...")
|
||||
|
||||
df = self._source.fetch(
|
||||
code=code,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
adj=adj, # 使用传入的 adj 参数
|
||||
asset_type='china_etf' # 强制指定 ETF 类型
|
||||
)
|
||||
|
||||
if df is not None:
|
||||
results[code] = df
|
||||
|
||||
# 显示附加信息
|
||||
nav_count = len(df.attrs.get('nav', pd.DataFrame()))
|
||||
premium = df.attrs.get('latest_premium', 'N/A')
|
||||
|
||||
print(f" ✓ {len(df)} 条价格, {nav_count} 条净值, 溢价率: {premium}%")
|
||||
else:
|
||||
print(f" ✗ 获取失败")
|
||||
|
||||
success = len(results)
|
||||
print(f"\n[FlaskAPI] ETF 数据获取完成: {success}/{len(codes)} 成功")
|
||||
|
||||
return results
|
||||
|
||||
def get_trading_calendar(
|
||||
self,
|
||||
market: str = 'A',
|
||||
start: str = None,
|
||||
end: str = None
|
||||
) -> pd.Index:
|
||||
"""
|
||||
获取交易日历(通过 API)
|
||||
|
||||
Args:
|
||||
market: 市场代码
|
||||
- 'A' 或 'china': A股(上交所/深交所)
|
||||
- 'US' 或 'us': 美股(NYSE)
|
||||
- 'HK' 或 'hk': 港股(HKEX)
|
||||
start: 开始日期 YYYY-MM-DD(默认 2020-01-01)
|
||||
end: 结束日期 YYYY-MM-DD(默认 2025-12-31)
|
||||
|
||||
Returns:
|
||||
交易日历 DatetimeIndex
|
||||
|
||||
示例:
|
||||
>>> fetcher = FlaskAPIFetcher()
|
||||
>>> # 获取 A 股 2024 年交易日历
|
||||
>>> calendar = fetcher.get_trading_calendar('A', '2024-01-01', '2024-12-31')
|
||||
>>> # 获取美股交易日历
|
||||
>>> calendar = fetcher.get_trading_calendar('US', '2024-01-01', '2024-12-31')
|
||||
"""
|
||||
# 默认日期范围
|
||||
if start is None:
|
||||
start = '2020-01-01'
|
||||
if end is None:
|
||||
end = '2025-12-31'
|
||||
|
||||
# 调用 API 获取准确日历
|
||||
calendar = self._source.get_trading_calendar(
|
||||
market=market,
|
||||
start_date=start,
|
||||
end_date=end
|
||||
)
|
||||
|
||||
if calendar is None:
|
||||
# API 失败,抛出异常(不应静默降级)
|
||||
raise ValueError(
|
||||
f"交易日历获取失败: market={market}, {start} ~ {end}。"
|
||||
f"请检查 API 服务是否可用。"
|
||||
)
|
||||
|
||||
return calendar
|
||||
|
||||
def get_benchmark(
|
||||
self,
|
||||
code: str = "000300.SH",
|
||||
start: str = "2020-01-01",
|
||||
end: str = "2025-12-31"
|
||||
) -> pd.Series:
|
||||
"""
|
||||
获取基准数据
|
||||
|
||||
Args:
|
||||
code: 基准代码(默认沪深 300)
|
||||
start: 开始日期
|
||||
end: 结束日期
|
||||
|
||||
Returns:
|
||||
基准收盘价 Series
|
||||
"""
|
||||
df = self._source.fetch(
|
||||
code=code,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
adj='raw'
|
||||
)
|
||||
|
||||
if df is None:
|
||||
raise ValueError(f"基准数据获取失败: {code}")
|
||||
|
||||
return df['close']
|
||||
|
||||
def get_health(self) -> Dict:
|
||||
"""
|
||||
检查 API 服务健康状态
|
||||
|
||||
Returns:
|
||||
健康状态字典
|
||||
"""
|
||||
return self._source.get_health()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"FlaskAPIFetcher(base_url={self._source.base_url})"
|
||||
258
archive/framework_v2/shared/data/schemas.py
Normal file
258
archive/framework_v2/shared/data/schemas.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
数据对齐 Schema 定义
|
||||
|
||||
与 CrossMarketAligner 配合使用,提供结构验证
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Optional, List
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 输入验证 Schema
|
||||
# ============================================================
|
||||
|
||||
class OHLCVInputSchema(BaseModel):
|
||||
"""
|
||||
OHLCV 输入数据验证
|
||||
|
||||
用于对齐前验证原始数据
|
||||
"""
|
||||
# 必需字段
|
||||
close: float = Field(..., description="收盘价(必需)", gt=0)
|
||||
|
||||
# 可选字段
|
||||
open: Optional[float] = Field(None, description="开盘价", gt=0)
|
||||
high: Optional[float] = Field(None, description="最高价", gt=0)
|
||||
low: Optional[float] = Field(None, description="最低价", gt=0)
|
||||
volume: Optional[float] = Field(None, description="成交量", ge=0)
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # 忽略额外字段
|
||||
|
||||
@field_validator('close', 'open', 'high', 'low')
|
||||
@classmethod
|
||||
def check_positive(cls, v):
|
||||
"""价格必须为正数"""
|
||||
if v is not None and v <= 0:
|
||||
raise ValueError(f"价格必须为正数,当前值: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class FactorInputSchema(BaseModel):
|
||||
"""
|
||||
因子输入数据验证
|
||||
|
||||
用于验证因子值在合理范围内
|
||||
"""
|
||||
value: float = Field(..., description="因子值")
|
||||
is_filled: bool = Field(False, description="是否为填充值")
|
||||
|
||||
@field_validator('value')
|
||||
@classmethod
|
||||
def check_reasonable(cls, v):
|
||||
"""因子值应在合理范围内(-10 ~ 10)"""
|
||||
if abs(v) > 10:
|
||||
import warnings
|
||||
warnings.warn(f"因子值异常: {v}")
|
||||
return v
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 输出验证 Schema
|
||||
# ============================================================
|
||||
|
||||
class AlignedFactorSchema(BaseModel):
|
||||
"""
|
||||
对齐后的因子数据验证
|
||||
|
||||
用于验证 align_factor() 的输出
|
||||
"""
|
||||
value: float = Field(..., description="对齐后的因子值")
|
||||
is_filled: bool = Field(..., description="是否为填充值")
|
||||
|
||||
class Config:
|
||||
# 允许 NaN(早期数据不足)
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class AlignedReturnsSchema(BaseModel):
|
||||
"""
|
||||
对齐后的收益率数据验证
|
||||
|
||||
用于验证 align_returns() 的输出
|
||||
"""
|
||||
returns: float = Field(..., description="收益率")
|
||||
|
||||
@field_validator('returns')
|
||||
@classmethod
|
||||
def check_returns_range(cls, v):
|
||||
"""收益率应在合理范围内(-50% ~ 50%)"""
|
||||
if abs(v) > 0.5:
|
||||
import warnings
|
||||
warnings.warn(f"收益率异常: {v:.2%}")
|
||||
return v
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 批量验证 Schema
|
||||
# ============================================================
|
||||
|
||||
class MultiAssetReturnsSchema(BaseModel):
|
||||
"""
|
||||
多标的收益率数据验证
|
||||
|
||||
用于验证 align_multi_asset() 的输出
|
||||
"""
|
||||
data: dict = Field(..., description="{标的代码: 收益率 Series}")
|
||||
|
||||
@field_validator('data')
|
||||
@classmethod
|
||||
def check_no_nan(cls, v):
|
||||
"""收益率 DataFrame 不能有 NaN"""
|
||||
df = pd.DataFrame(v)
|
||||
if df.isna().any().any():
|
||||
nan_cols = df.columns[df.isna().any()]
|
||||
raise ValueError(f"收益率包含 NaN 列: {list(nan_cols)}")
|
||||
return v
|
||||
|
||||
|
||||
class AlignmentValidationResult(BaseModel):
|
||||
"""
|
||||
对齐验证结果
|
||||
|
||||
用于 validate_alignment() 的输出
|
||||
"""
|
||||
signals_aligned: bool = Field(..., description="信号是否已对齐")
|
||||
returns_aligned: bool = Field(..., description="收益率是否已对齐")
|
||||
common_dates_count: int = Field(..., description="共同日期数量")
|
||||
lost_signals: int = Field(0, description="丢失的信号数")
|
||||
lost_returns: int = Field(0, description="丢失的收益数")
|
||||
|
||||
@field_validator('common_dates_count')
|
||||
@classmethod
|
||||
def check_min_dates(cls, v):
|
||||
"""共同日期至少 10 天"""
|
||||
if v < 10:
|
||||
raise ValueError(f"共同日期太少: {v} 天")
|
||||
return v
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 验证装饰器(与 Aligner 配合)
|
||||
# ============================================================
|
||||
|
||||
def validate_ohlcv_before_align(func):
|
||||
"""
|
||||
验证 OHLCV 数据在对齐前符合要求
|
||||
|
||||
使用示例:
|
||||
class CrossMarketAligner:
|
||||
@validate_ohlcv_before_align
|
||||
def align_factor(self, factor_series, source_calendar, code):
|
||||
...
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
# 提取 close_series(第二个参数)
|
||||
if len(args) >= 1:
|
||||
close_series = args[0]
|
||||
else:
|
||||
close_series = kwargs.get('close_series')
|
||||
|
||||
if isinstance(close_series, pd.Series):
|
||||
# 验证 close 列
|
||||
if not pd.api.types.is_numeric_dtype(close_series):
|
||||
raise TypeError(
|
||||
f"close_series 必须是数值类型,当前是 {close_series.dtype}"
|
||||
)
|
||||
|
||||
if close_series.isna().all():
|
||||
raise ValueError("close_series 全为 NaN")
|
||||
|
||||
return func(self, *args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def validate_factor_after_align(func):
|
||||
"""
|
||||
验证因子对齐后符合要求
|
||||
|
||||
使用示例:
|
||||
class CrossMarketAligner:
|
||||
@validate_factor_after_align
|
||||
def align_factor(self, factor_series, source_calendar, code):
|
||||
...
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
result = func(self, *args, **kwargs)
|
||||
|
||||
# 验证返回类型
|
||||
if not isinstance(result, pd.DataFrame):
|
||||
raise TypeError(
|
||||
f"align_factor 必须返回 DataFrame,当前返回 {type(result)}"
|
||||
)
|
||||
|
||||
# 验证列
|
||||
required_cols = ['value', 'is_filled']
|
||||
missing_cols = [col for col in required_cols if col not in result.columns]
|
||||
if missing_cols:
|
||||
raise ValueError(f"对齐后 DataFrame 缺少列: {missing_cols}")
|
||||
|
||||
# 验证 value 列类型
|
||||
if not pd.api.types.is_numeric_dtype(result['value']):
|
||||
raise TypeError(f"value 列必须是数值类型")
|
||||
|
||||
# 验证 is_filled 列类型
|
||||
if not pd.api.types.is_bool_dtype(result['is_filled']):
|
||||
raise TypeError(f"is_filled 列必须是布尔类型")
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
|
||||
def validate_returns_after_align(func):
|
||||
"""
|
||||
验证收益率对齐后符合要求
|
||||
|
||||
使用示例:
|
||||
class CrossMarketAligner:
|
||||
@validate_returns_after_align
|
||||
def align_returns(self, close_series, code):
|
||||
...
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
result = func(self, *args, **kwargs)
|
||||
|
||||
# 验证返回类型
|
||||
if not isinstance(result, pd.Series):
|
||||
raise TypeError(
|
||||
f"align_returns 必须返回 Series,当前返回 {type(result)}"
|
||||
)
|
||||
|
||||
# 验证无 NaN
|
||||
if result.isna().any():
|
||||
nan_count = result.isna().sum()
|
||||
raise ValueError(f"收益率包含 {nan_count} 个 NaN")
|
||||
|
||||
# 验证收益率范围
|
||||
max_return = result.abs().max()
|
||||
if max_return > 0.5:
|
||||
import warnings
|
||||
warnings.warn(f"发现异常收益率: {max_return:.2%}")
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
17
archive/framework_v2/shared/factors/__init__.py
Normal file
17
archive/framework_v2/shared/factors/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
通用因子实现
|
||||
"""
|
||||
|
||||
from framework_v2.shared.factors.momentum import MomentumFactor
|
||||
|
||||
# TALibFactorBase 需要安装 talib,可选导入
|
||||
try:
|
||||
from framework_v2.shared.factors.talib_base import TALibFactorBase
|
||||
__all__ = [
|
||||
'TALibFactorBase',
|
||||
'MomentumFactor',
|
||||
]
|
||||
except ImportError:
|
||||
__all__ = [
|
||||
'MomentumFactor',
|
||||
]
|
||||
104
archive/framework_v2/shared/factors/momentum.py
Normal file
104
archive/framework_v2/shared/factors/momentum.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
动量因子(通用版本)
|
||||
|
||||
使用加权线性回归:得分 = 年化收益率 × R²
|
||||
|
||||
与现有 MomentumFactor 对比验证:
|
||||
- 输入相同 → 输出应该相同
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import math
|
||||
from framework_v2.core import FactorBase
|
||||
|
||||
|
||||
class MomentumFactor(FactorBase):
|
||||
"""
|
||||
动量因子
|
||||
|
||||
计算加权线性回归动量得分:
|
||||
得分 = 年化收益率 × R²
|
||||
|
||||
参数:
|
||||
- n_days: 动量窗口(默认25)
|
||||
- weighted: 是否加权(默认True)
|
||||
- crash_filter: 是否启用崩盘过滤(默认True)
|
||||
"""
|
||||
|
||||
name = "momentum"
|
||||
category = "momentum"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_days: int = 25,
|
||||
weighted: bool = True,
|
||||
crash_filter: bool = True
|
||||
):
|
||||
super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
|
||||
self.n_days = n_days
|
||||
self.weighted = weighted
|
||||
self.crash_filter = crash_filter
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算动量因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.weighted:
|
||||
factor_values = prices.rolling(self.n_days).apply(
|
||||
lambda x: self._weighted_momentum_score(x.values),
|
||||
raw=False
|
||||
)
|
||||
else:
|
||||
factor_values = prices.pct_change(self.n_days)
|
||||
|
||||
if self.crash_filter:
|
||||
factor_values = self._apply_crash_filter(prices, factor_values)
|
||||
|
||||
return factor_values
|
||||
|
||||
def _weighted_momentum_score(self, prices: np.ndarray) -> float:
|
||||
"""计算加权动量得分(完全复制现有逻辑)"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
|
||||
# 价格下界 clip,防止 log(0) 或 log(负数)
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
y = np.log(prices)
|
||||
|
||||
# 异常值检测
|
||||
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
|
||||
return 0.0
|
||||
|
||||
x = np.arange(len(y))
|
||||
weights = np.linspace(1, 2, len(y))
|
||||
|
||||
slope, intercept = np.polyfit(x, y, 1, w=weights)
|
||||
annualized_returns = math.exp(slope * 250) - 1
|
||||
|
||||
y_pred = slope * x + intercept
|
||||
ss_res = np.sum(weights * (y - y_pred) ** 2)
|
||||
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
|
||||
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||
|
||||
return annualized_returns * r2
|
||||
|
||||
def _apply_crash_filter(self, prices: pd.Series, factor_values: pd.Series) -> pd.Series:
|
||||
"""崩盘过滤:连续3天跌>5%清零(完全复制现有逻辑)"""
|
||||
result = factor_values.copy()
|
||||
|
||||
for i in range(3, len(prices)):
|
||||
r1 = prices.iloc[i] / prices.iloc[i-1]
|
||||
r2 = prices.iloc[i-1] / prices.iloc[i-2]
|
||||
r3 = prices.iloc[i-2] / prices.iloc[i-3]
|
||||
|
||||
con1 = min(r1, r2, r3) < 0.95
|
||||
con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95)
|
||||
|
||||
if con1 or con2:
|
||||
result.iloc[i] = 0.0
|
||||
|
||||
return result
|
||||
55
archive/framework_v2/shared/factors/talib_base.py
Normal file
55
archive/framework_v2/shared/factors/talib_base.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
ta-lib 因子基类(通用)
|
||||
|
||||
所有 ta-lib 因子继承此类,只需指定函数和参数
|
||||
"""
|
||||
|
||||
import talib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from framework_v2.core import FactorBase
|
||||
|
||||
|
||||
class TALibFactorBase(FactorBase):
|
||||
"""
|
||||
ta-lib 因子基类
|
||||
|
||||
子类只需实现:
|
||||
- name: 因子名称
|
||||
- _talib_func: 返回 ta-lib 函数
|
||||
"""
|
||||
|
||||
category = "technical"
|
||||
|
||||
def __init__(self, period: int = 14, **params):
|
||||
"""
|
||||
初始化
|
||||
|
||||
Args:
|
||||
period: 周期参数
|
||||
**params: 其他参数
|
||||
"""
|
||||
super().__init__(period=period, **params)
|
||||
self.period = period
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
计算因子值
|
||||
|
||||
Args:
|
||||
data: OHLCV 数据
|
||||
|
||||
Returns:
|
||||
因子值序列
|
||||
"""
|
||||
close = data['close'].values.astype(float)
|
||||
|
||||
# 调用子类指定的 ta-lib 函数
|
||||
result = self._talib_func(close, timeperiod=self.period)
|
||||
|
||||
return pd.Series(result, index=data.index, name=self.name)
|
||||
|
||||
@property
|
||||
def _talib_func(self):
|
||||
"""子类必须实现,返回 ta-lib 函数"""
|
||||
raise NotImplementedError("Subclasses must implement _talib_func")
|
||||
Reference in New Issue
Block a user