Compare commits
2 Commits
72df18a28b
...
908b28473f
| Author | SHA1 | Date | |
|---|---|---|---|
| 908b28473f | |||
| 226a27361f |
@@ -14,6 +14,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .models import OHLCVResponse, validate_ohlcv_response
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@@ -132,14 +134,20 @@ class FlaskAPIDataSource:
|
||||
print(f"✗ API返回错误: {data['error']}")
|
||||
return None
|
||||
|
||||
# 解析数据
|
||||
records = data.get('data', [])
|
||||
if not records:
|
||||
# ✅ 使用 Pydantic 模型验证响应(类型安全)
|
||||
try:
|
||||
validated = validate_ohlcv_response(data)
|
||||
except Exception as e:
|
||||
print(f"✗ {code}: 响应数据验证失败 - {e}")
|
||||
return None
|
||||
|
||||
# 检查数据是否为空
|
||||
if not validated.data:
|
||||
print(f"⚠ {code}: 无数据返回")
|
||||
return None
|
||||
|
||||
# 转换为 DataFrame
|
||||
df = pd.DataFrame(records)
|
||||
df = pd.DataFrame(validated.data)
|
||||
|
||||
# 处理日期列
|
||||
if 'date' in df.columns:
|
||||
@@ -153,37 +161,37 @@ class FlaskAPIDataSource:
|
||||
df = df[standard_cols]
|
||||
|
||||
# 使用 API 返回的实际数据范围(而非请求参数)
|
||||
actual_start = data.get('date_range', {}).get('start', start_date)
|
||||
actual_end = data.get('date_range', {}).get('end', end_date)
|
||||
actual_count = data.get('count', len(df))
|
||||
actual_start = validated.date_range.start if validated.date_range else start_date
|
||||
actual_end = validated.date_range.end if validated.date_range else end_date
|
||||
actual_count = validated.count
|
||||
|
||||
# 缓存 info 信息(如果有)
|
||||
if 'info' in data:
|
||||
df.attrs['info'] = data['info']
|
||||
if validated.info:
|
||||
df.attrs['info'] = validated.info
|
||||
|
||||
# ETF 数据自动附加净值和溢价率信息
|
||||
if data.get('asset_type') == 'china_etf':
|
||||
if validated.asset_type == 'china_etf':
|
||||
# 净值数据
|
||||
nav_section = data.get('nav', {})
|
||||
if nav_section.get('data'):
|
||||
nav_df = pd.DataFrame(nav_section['data'])
|
||||
if validated.nav and validated.nav.data:
|
||||
nav_df = pd.DataFrame(validated.nav.data)
|
||||
if 'date' in nav_df.columns:
|
||||
nav_df['date'] = pd.to_datetime(nav_df['date'])
|
||||
nav_df = nav_df.set_index('date')
|
||||
df.attrs['nav'] = nav_df
|
||||
|
||||
# 溢价率序列
|
||||
if 'premium_series' in data:
|
||||
df.attrs['premium_series'] = data['premium_series']
|
||||
if validated.premium_series:
|
||||
premium_dict = {item.date: item.premium for item in validated.premium_series}
|
||||
df.attrs['premium_series'] = premium_dict
|
||||
|
||||
# 最新溢价率
|
||||
if 'latest_premium' in data:
|
||||
df.attrs['latest_premium'] = data['latest_premium']
|
||||
df.attrs['premium_date'] = data.get('premium_date')
|
||||
if validated.latest_premium is not None:
|
||||
df.attrs['latest_premium'] = validated.latest_premium
|
||||
df.attrs['premium_date'] = validated.premium_date
|
||||
|
||||
# 溢价率统计
|
||||
if 'premium_stats' in data:
|
||||
df.attrs['premium_stats'] = data['premium_stats']
|
||||
if validated.premium_stats:
|
||||
df.attrs['premium_stats'] = validated.premium_stats.model_dump()
|
||||
|
||||
print(f"✓ {code}: {actual_count} 条数据 ({actual_start} ~ {actual_end})")
|
||||
return df
|
||||
|
||||
@@ -44,6 +44,7 @@ import pandas as pd
|
||||
|
||||
from datasource.universal_fetcher import UniversalDataFetcher
|
||||
from datasource.asset_type_detector import AssetTypeDetector, AssetType
|
||||
from datasource.models import dataframe_to_ohlcv_response, OHLCVResponse, ErrorResponse
|
||||
|
||||
|
||||
# ============================================================
|
||||
@@ -730,91 +731,71 @@ def get_ohlcv():
|
||||
result, is_cached = fetch_data_with_ttl(code, start, end, nocache, timeframe, adj, final_type)
|
||||
|
||||
if result is None:
|
||||
return jsonify({
|
||||
"code": code,
|
||||
"asset_type": final_type.value,
|
||||
"adj": adj,
|
||||
"detected_type": detected_type.value if asset_type_param else None, # 仅当用户指定时显示
|
||||
"error": "No data available",
|
||||
"start": start,
|
||||
"end": end,
|
||||
}), 404
|
||||
error_response = ErrorResponse(
|
||||
error="No data available",
|
||||
code=code,
|
||||
asset_type=final_type.value,
|
||||
adj=adj,
|
||||
detected_type=detected_type.value if asset_type_param else None,
|
||||
)
|
||||
return error_response.model_dump(mode='json'), 404
|
||||
|
||||
if "error" in result:
|
||||
return jsonify({
|
||||
"code": code,
|
||||
"asset_type": final_type.value,
|
||||
"adj": adj,
|
||||
"detected_type": detected_type.value if asset_type_param else None,
|
||||
"error": result["error"],
|
||||
}), 500
|
||||
error_response = ErrorResponse(
|
||||
error=result["error"],
|
||||
code=code,
|
||||
asset_type=final_type.value,
|
||||
adj=adj,
|
||||
detected_type=detected_type.value if asset_type_param else None,
|
||||
)
|
||||
return error_response.model_dump(mode='json'), 500
|
||||
|
||||
result['cached'] = is_cached
|
||||
result['asset_type'] = final_type.value # 使用最终类型
|
||||
result['adj'] = adj # 返回使用的 adj 参数
|
||||
# ✅ 使用 Pydantic 模型构建响应(类型安全)
|
||||
# 从 result 中提取数据
|
||||
df_data = result.get('data', [])
|
||||
attrs = result.get('attrs', {})
|
||||
|
||||
# API 层职责:决定如何使用 attrs 中的业务数据
|
||||
if 'attrs' in result:
|
||||
attrs = result['attrs']
|
||||
|
||||
# 根据资产类型决定日期格式精度
|
||||
# 加密货币使用分钟级,其他使用天级
|
||||
date_format = '%Y-%m-%d %H:%M:%S' if final_type == AssetType.CRYPTO else '%Y-%m-%d'
|
||||
|
||||
# 提取净值到顶层(方便调用方使用)
|
||||
if 'nav' in attrs:
|
||||
nav_df = attrs['nav']
|
||||
if isinstance(nav_df, pd.DataFrame):
|
||||
# 将 DataFrame 转换为列表格式(JSON 可序列化)
|
||||
nav_df_copy = nav_df.reset_index().copy()
|
||||
nav_df_copy['date'] = nav_df_copy['date'].dt.strftime(date_format)
|
||||
nav_dict = {
|
||||
'data': nav_df_copy.to_dict(orient='records'),
|
||||
'count': len(nav_df_copy)
|
||||
}
|
||||
result['nav'] = nav_dict
|
||||
else:
|
||||
result['nav'] = nav_df
|
||||
|
||||
# 提取溢价率到顶层(调用业务函数处理格式)
|
||||
if 'premium' in attrs:
|
||||
premium_result = build_premium_result_from_attrs(attrs['premium'])
|
||||
if premium_result:
|
||||
result.update(premium_result)
|
||||
|
||||
# 将 attrs 中的 DataFrame/Series 转换为字典格式(用于 JSON 序列化)
|
||||
attrs_serializable = {}
|
||||
for key, value in attrs.items():
|
||||
if isinstance(value, pd.DataFrame):
|
||||
df_copy = value.reset_index().copy()
|
||||
if 'date' in df_copy.columns:
|
||||
df_copy['date'] = df_copy['date'].dt.strftime(date_format)
|
||||
attrs_serializable[key] = {
|
||||
'data': df_copy.to_dict(orient='records'),
|
||||
'count': len(df_copy)
|
||||
}
|
||||
elif isinstance(value, pd.Series):
|
||||
# 将 Series 索引转换为字符串
|
||||
series_copy = value.copy()
|
||||
series_copy.index = series_copy.index.strftime(date_format)
|
||||
attrs_serializable[key] = {
|
||||
'type': 'series',
|
||||
'data': series_copy.to_dict(),
|
||||
'name': value.name
|
||||
}
|
||||
else:
|
||||
attrs_serializable[key] = value
|
||||
|
||||
result['attrs'] = attrs_serializable
|
||||
# 重建 DataFrame(用于转换函数)
|
||||
if df_data:
|
||||
df = pd.DataFrame(df_data)
|
||||
if 'date' in df.columns:
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df = df.set_index('date')
|
||||
else:
|
||||
df = pd.DataFrame()
|
||||
|
||||
# 如果用户指定了类型但与自动检测不同,显示提示
|
||||
if asset_type_param and detected_type != final_type:
|
||||
result['type_override'] = {
|
||||
# 提取 nav DataFrame
|
||||
nav_df = attrs.get('nav') if isinstance(attrs.get('nav'), pd.DataFrame) else None
|
||||
|
||||
# 提取 premium Series
|
||||
premium_series = attrs.get('premium') if isinstance(attrs.get('premium'), pd.Series) else None
|
||||
|
||||
# 构建响应模型
|
||||
response = dataframe_to_ohlcv_response(
|
||||
df=df if len(df) > 0 else None,
|
||||
code=code,
|
||||
asset_type=final_type.value,
|
||||
adj=adj,
|
||||
cached=is_cached,
|
||||
nav_df=nav_df,
|
||||
premium_series=premium_series,
|
||||
info=attrs.get('info'),
|
||||
attrs=attrs,
|
||||
columns=result.get('columns'),
|
||||
date_range=result.get('date_range'),
|
||||
requested_range=result.get('requested_range'),
|
||||
available_range=result.get('available_range'),
|
||||
cache_strategy=result.get('cache_strategy'),
|
||||
timeframe=result.get('timeframe'),
|
||||
type_override={
|
||||
"detected": detected_type.value,
|
||||
"specified": final_type.value,
|
||||
"hint": "用户强制覆盖了自动检测结果",
|
||||
}
|
||||
return jsonify(result)
|
||||
} if (asset_type_param and detected_type != final_type) else None,
|
||||
)
|
||||
|
||||
# ✅ 自动序列化为 JSON
|
||||
return response.model_dump(mode='json')
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -328,12 +328,13 @@ def validate_ohlcv_response(data: Dict[str, Any]) -> OHLCVResponse:
|
||||
return OHLCVResponse.model_validate(data)
|
||||
|
||||
|
||||
def dataframe_to_records(df) -> List[Dict[str, Any]]:
|
||||
def dataframe_to_records(df, date_format: str = '%Y-%m-%d') -> List[Dict[str, Any]]:
|
||||
"""
|
||||
将 DataFrame 转换为 OHLCVRecord 兼容的字典列表
|
||||
|
||||
Args:
|
||||
df: pandas DataFrame
|
||||
date_format: 日期格式字符串
|
||||
|
||||
Returns:
|
||||
字典列表(可直接用于 OHLCVResponse.data)
|
||||
@@ -349,7 +350,7 @@ def dataframe_to_records(df) -> List[Dict[str, Any]]:
|
||||
if col in df_reset.columns:
|
||||
try:
|
||||
import pandas as pd
|
||||
df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime('%Y-%m-%d')
|
||||
df_reset[col] = pd.to_datetime(df_reset[col]).dt.strftime(date_format)
|
||||
if col != 'date':
|
||||
df_reset = df_reset.rename(columns={col: 'date'})
|
||||
break
|
||||
@@ -357,3 +358,179 @@ def dataframe_to_records(df) -> List[Dict[str, Any]]:
|
||||
pass
|
||||
|
||||
return df_reset.to_dict(orient='records')
|
||||
|
||||
|
||||
# ============================================================
|
||||
# DataFrame → Pydantic Model 转换函数
|
||||
# ============================================================
|
||||
|
||||
def _serialize_public_attrs(attrs: Dict[str, Any], date_format: str) -> Dict[str, Any]:
    """Convert a DataFrame ``attrs`` mapping into a JSON-friendly dict.

    Args:
        attrs: Metadata mapping (for example ``df.attrs``).
        date_format: ``strftime`` pattern applied to date-like values.

    Returns:
        A new dict in which
        - keys starting with ``_cache_`` (internal cache metadata) are dropped,
        - DataFrame values become ``{'data': <records>, 'count': <rows>}``,
        - Series values become ``{'type': 'series', 'data': {...}, 'name': ...}``,
        - every other value is passed through unchanged.
    """
    import pandas as pd

    serialized: Dict[str, Any] = {}
    for key, value in attrs.items():
        if key.startswith('_cache_'):
            continue

        if isinstance(value, pd.DataFrame):
            serialized[key] = {
                'data': dataframe_to_records(value, date_format),
                'count': len(value),
            }
        elif isinstance(value, pd.Series):
            series_copy = value.copy()
            index = series_copy.index
            if hasattr(index, 'strftime'):
                series_copy.index = index.strftime(date_format)
            else:
                # Non-datetime indexes have no strftime; fall back to plain
                # string keys instead of raising AttributeError.
                series_copy.index = index.map(str)
            serialized[key] = {
                'type': 'series',
                'data': series_copy.to_dict(),
                'name': value.name,
            }
        else:
            serialized[key] = value

    return serialized


def dataframe_to_ohlcv_response(
    df: Any,  # pd.DataFrame
    code: str,
    asset_type: str,
    adj: str = 'raw',
    cached: bool = False,
    nav_df: Optional[Any] = None,  # Optional[pd.DataFrame]
    premium_series: Optional[Any] = None,  # Optional[pd.Series]
    info: Optional[Dict[str, Any]] = None,
    attrs: Optional[Dict[str, Any]] = None,
    date_format: Optional[str] = None,
    **kwargs
) -> 'OHLCVResponse':
    """Build an ``OHLCVResponse`` model from a DataFrame and its side data.

    Intended uses:
        - Flask API: one uniform response structure.
        - Local calls: a typed response object instead of a raw dict.

    Args:
        df: Main OHLCV DataFrame; ``None`` yields an empty ``data`` list.
        code: Instrument code.
        asset_type: Asset type string.
        adj: Price-adjustment type.
        cached: Whether the data came from cache.
        nav_df: Optional ETF net-asset-value DataFrame.
        premium_series: Optional premium-rate Series. Its index entries are
            formatted with ``strftime``, so a datetime-like index is expected.
        info: Optional instrument info. When ``None``, ``df.attrs['info']``
            is used if present.
        attrs: Optional metadata mapping. When ``None``, ``df.attrs`` is used
            if it is non-empty.
        date_format: ``strftime`` pattern. When ``None`` it defaults to
            ``'%Y-%m-%d %H:%M:%S'`` for ``asset_type == 'crypto'`` and
            ``'%Y-%m-%d'`` otherwise.
        **kwargs: Extra response fields (columns, date_range, timeframe, ...).
            They are merged last and therefore override keys built here.

    Returns:
        The validated ``OHLCVResponse`` instance.

    Example:
        # Flask API
        df = fetcher.fetch("META", start, end)
        response = dataframe_to_ohlcv_response(
            df, code="META", asset_type="us_stock", adj="raw"
        )
        return response.model_dump(mode='json')

        # Local call
        df = fetcher.fetch("513100.SH", start, end)
        response = dataframe_to_ohlcv_response(
            df,
            code="513100.SH",
            asset_type="china_etf",
            nav_df=df.attrs.get('nav'),
            premium_series=df.attrs.get('premium'),
        )
        print(response.nav.count)
    """
    import pandas as pd

    # Pick the date format from the asset type unless the caller fixed it.
    if date_format is None:
        date_format = '%Y-%m-%d %H:%M:%S' if asset_type == 'crypto' else '%Y-%m-%d'

    # Main OHLCV records.
    data = dataframe_to_records(df, date_format) if df is not None else []

    response_data = {
        "code": code,
        "asset_type": asset_type,
        "adj": adj,
        "count": len(data),
        "data": data,
        "cached": cached,
    }

    # info: the explicit argument wins, then df.attrs['info'].
    if info is not None:
        response_data['info'] = info
    elif hasattr(df, 'attrs') and df.attrs and 'info' in df.attrs:
        response_data['info'] = df.attrs['info']

    # nav section (only for a real DataFrame).
    if nav_df is not None and isinstance(nav_df, pd.DataFrame):
        nav_records = dataframe_to_records(nav_df, date_format)
        response_data['nav'] = {
            "data": nav_records,
            "count": len(nav_records)
        }

    # premium section (only for a non-empty Series).
    if premium_series is not None and isinstance(premium_series, pd.Series) and len(premium_series) > 0:
        # Latest premium value and its date.
        latest_premium = float(premium_series.iloc[-1])
        response_data['latest_premium'] = round(latest_premium, 6)
        response_data['premium_date'] = premium_series.index[-1].strftime(date_format)

        # Full premium series as a list of {date, premium} records.
        response_data['premium_series'] = [
            {"date": date.strftime(date_format), "premium": round(float(premium), 6)}
            for date, premium in premium_series.items()
        ]

        # Summary statistics of the premium series.
        response_data['premium_stats'] = {
            "mean": round(float(premium_series.mean()), 6),
            "std": round(float(premium_series.std()), 6),
            "min": round(float(premium_series.min()), 6),
            "max": round(float(premium_series.max()), 6),
            "median": round(float(premium_series.median()), 6),
        }

    # attrs: an explicit mapping (even an empty one) wins over df.attrs.
    if attrs is not None:
        attrs_source = attrs
    elif hasattr(df, 'attrs') and df.attrs:
        attrs_source = df.attrs
    else:
        attrs_source = {}

    attrs_serializable = _serialize_public_attrs(attrs_source, date_format)
    if attrs_serializable:
        response_data['attrs'] = attrs_serializable

    # Extra caller-supplied fields are merged last.
    response_data.update(kwargs)

    # Validate and return the model.
    return OHLCVResponse.model_validate(response_data)
|
||||
|
||||
180
framework_v2/README.md
Normal file
180
framework_v2/README.md
Normal file
@@ -0,0 +1,180 @@
|
||||
# 框架 V2 - 重构版本
|
||||
|
||||
## 📋 设计理念
|
||||
|
||||
### 三层架构
|
||||
|
||||
```
|
||||
framework_v2/
|
||||
├── core/ # 纯抽象接口(零实现)
|
||||
├── shared/ # 通用实现(2+策略复用)
|
||||
└── tests/ # 框架测试
|
||||
```
|
||||
|
||||
### 设计原则
|
||||
|
||||
1. **按需抽象**:不预先设计,只抽象已验证的通用逻辑
|
||||
2. **职责分离**:数据获取、因子计算、信号生成、回测执行各司其职
|
||||
3. **向后兼容**:与现有策略并行运行,验证一致后再替换
|
||||
4. **测试驱动**:每个组件必须有对比验证测试
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ 目录结构
|
||||
|
||||
```
|
||||
framework_v2/
|
||||
├── __init__.py
|
||||
├── README.md
|
||||
│
|
||||
├── core/ # 核心抽象接口
|
||||
│ ├── __init__.py
|
||||
│ ├── strategy.py # StrategyBase (ABC)
|
||||
│ ├── factor.py # FactorBase (ABC)
|
||||
│ ├── signal.py # SignalGenerator (ABC)
|
||||
│ ├── executor.py # Executor (ABC)
|
||||
│ └── data.py # DataFetcher (ABC)
|
||||
│
|
||||
├── shared/ # 通用实现
|
||||
│ ├── __init__.py
|
||||
│ └── factors/
|
||||
│ ├── __init__.py
|
||||
│ ├── talib_base.py # TALibFactorBase (需要 talib)
|
||||
│ └── momentum.py # 动量因子(已验证✓)
|
||||
│
|
||||
└── tests/ # 测试
|
||||
├── __init__.py
|
||||
└── test_momentum_parity.py # 因子对比测试(通过✓)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ✅ 已完成
|
||||
|
||||
### 阶段1: 核心接口层 ✓
|
||||
|
||||
- [x] StrategyBase - 策略抽象基类
|
||||
- [x] FactorBase - 因子抽象基类
|
||||
- [x] SignalGenerator - 信号生成器抽象基类
|
||||
- [x] Executor - 执行器抽象基类
|
||||
- [x] DataFetcher - 数据获取器抽象基类
|
||||
|
||||
### 阶段2: 通用因子层 ✓
|
||||
|
||||
- [x] MomentumFactor - 动量因子(完全复制现有逻辑)
|
||||
- [x] 对比验证测试(通过✓,差异 = 0)
|
||||
|
||||
---
|
||||
|
||||
## 🎯 验证结果
|
||||
|
||||
### MomentumFactor 对比测试
|
||||
|
||||
```
|
||||
============================================================
|
||||
MomentumFactor 对比测试
|
||||
============================================================
|
||||
|
||||
1. 加载测试数据...
|
||||
⚠ 未找到测试数据,使用模拟数据
|
||||
|
||||
2. 计算旧因子(strategies/shared/factors/momentum.py)...
|
||||
✓ 旧因子计算完成
|
||||
结果范围: -0.8515 ~ 8.5805
|
||||
NaN 数量: 22
|
||||
|
||||
3. 计算新因子(framework_v2/shared/factors/momentum.py)...
|
||||
✓ 新因子计算完成
|
||||
结果范围: -0.8515 ~ 8.5805
|
||||
NaN 数量: 22
|
||||
|
||||
4. 对比结果...
|
||||
✓ 索引一致
|
||||
最大差异: 0.000000e+00
|
||||
平均差异: 0.000000e+00
|
||||
✓ 差异在容差范围内 (< 1e-10)
|
||||
|
||||
============================================================
|
||||
✓ 测试通过:新旧因子输出完全一致!
|
||||
============================================================
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📝 下一步计划
|
||||
|
||||
### 阶段3: 信号层迁移
|
||||
|
||||
- [ ] TopNSelector - Top N 选股器
|
||||
- [ ] DynamicThreshold - 动态阈值(V3逻辑)
|
||||
- [ ] RebalanceController - 调仓控制器
|
||||
- [ ] 信号对比验证测试
|
||||
|
||||
### 阶段4: 执行层迁移
|
||||
|
||||
- [ ] BacktestRunner - 回测执行器
|
||||
- [ ] 收益计算对比测试
|
||||
|
||||
### 阶段5: 数据层迁移
|
||||
|
||||
- [ ] RotationDataFetcher - 轮动策略数据获取器
|
||||
- [ ] CrossMarketAligner - 跨市场对齐器
|
||||
|
||||
### 阶段6: 策略组装
|
||||
|
||||
- [ ] RotationStrategyV2 - 新框架轮动策略
|
||||
- [ ] 完整策略对比测试
|
||||
|
||||
---
|
||||
|
||||
## 🔧 使用方法
|
||||
|
||||
### 运行测试
|
||||
|
||||
```bash
|
||||
# 运行因子对比测试
|
||||
python framework_v2/tests/test_momentum_parity.py
|
||||
|
||||
# 运行所有测试
|
||||
python -m pytest framework_v2/tests/
|
||||
```
|
||||
|
||||
### 使用新因子
|
||||
|
||||
```python
|
||||
from framework_v2.shared.factors import MomentumFactor
|
||||
|
||||
# 创建因子
|
||||
factor = MomentumFactor(n_days=25, weighted=True, crash_filter=True)
|
||||
|
||||
# 计算因子值
|
||||
import pandas as pd
|
||||
data = pd.DataFrame({'close': [...]}, index=[...])
|
||||
factor_values = factor.compute(data)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 与旧框架对比
|
||||
|
||||
| 维度 | 旧框架 (framework/) | 新框架 (framework_v2/) |
|------|---------------------|------------------------|
| **架构** | 抽象+实现混杂 | 三层分离(core/shared/tests) |
| **因子** | 独立实现 | TALibFactorBase + 定制继承 |
| **信号** | 包含所有逻辑 | 拆分为 Signal + Threshold + Rebalance |
| **数据** | 耦合在策略中 | DataFetcher 抽象 |
| **测试** | 部分覆盖 | 每个组件必须有对比测试 |
| **状态** | 生产环境 ✓ | 开发中 🚧 |
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ 注意事项
|
||||
|
||||
1. **talib 依赖**:TALibFactorBase 需要安装 `ta-lib`,但未安装不影响 MomentumFactor 使用
|
||||
2. **并行开发**:新框架与旧框架并行,不修改现有代码
|
||||
3. **验证优先**:每个模块迁移后立即验证,确保结果一致
|
||||
|
||||
---
|
||||
|
||||
*创建日期: 2026-05-06*
|
||||
*版本: 2.0.0*
|
||||
15
framework_v2/__init__.py
Normal file
15
framework_v2/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
框架 V2 - 重构版本
|
||||
|
||||
三层架构:
|
||||
├── core/ # 纯抽象接口(零实现)
|
||||
├── shared/ # 通用实现(2+策略复用)
|
||||
└── tests/ # 框架测试
|
||||
|
||||
设计原则:
|
||||
├── 按需抽象,不预先设计
|
||||
├── 只放通用逻辑,定制逻辑在 strategies/
|
||||
└── 每个组件必须有测试
|
||||
"""
|
||||
|
||||
__version__ = "2.0.0"
|
||||
19
framework_v2/core/__init__.py
Normal file
19
framework_v2/core/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
核心抽象接口层(纯ABC,零实现)
|
||||
|
||||
只定义策略框架的标准接口,不包含任何业务逻辑
|
||||
"""
|
||||
|
||||
from framework_v2.core.strategy import StrategyBase
|
||||
from framework_v2.core.factor import FactorBase
|
||||
from framework_v2.core.signal import SignalGenerator
|
||||
from framework_v2.core.executor import Executor
|
||||
from framework_v2.core.data import DataFetcher
|
||||
|
||||
__all__ = [
|
||||
'StrategyBase',
|
||||
'FactorBase',
|
||||
'SignalGenerator',
|
||||
'Executor',
|
||||
'DataFetcher',
|
||||
]
|
||||
97
framework_v2/core/data.py
Normal file
97
framework_v2/core/data.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
数据获取器抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class DataFetcher(ABC):
    """Abstract base class for data fetchers.

    Concrete fetchers must implement ``fetch_indices``, ``fetch_etf`` and
    ``get_trading_calendar``. ``get_benchmark`` is optional and raises
    ``NotImplementedError`` unless overridden.
    """

    name: str = "base"

    def __init__(self, **params):
        """Store the data-source parameters.

        Args:
            **params: Data-source parameters (e.g. api_url, ssh_config).
        """
        self._params = params

    @abstractmethod
    def fetch_indices(
        self,
        codes: List[str],
        start: str,
        end: str
    ) -> Dict[str, pd.DataFrame]:
        """Fetch OHLCV data for indices.

        Args:
            codes: List of index codes.
            start: Start date (YYYY-MM-DD).
            end: End date (YYYY-MM-DD).

        Returns:
            A ``{code: DataFrame}`` dict.
        """
        pass

    @abstractmethod
    def fetch_etf(
        self,
        codes: List[str],
        start: str,
        end: str
    ) -> Dict[str, pd.DataFrame]:
        """Fetch ETF data (price plus net asset value).

        Args:
            codes: List of ETF codes.
            start: Start date.
            end: End date.

        Returns:
            A ``{code: DataFrame}`` dict.
        """
        pass

    @abstractmethod
    def get_trading_calendar(self, market: str = 'A') -> pd.Index:
        """Return the trading calendar.

        Args:
            market: Market code ('A', 'US', 'HK', ...).

        Returns:
            The trading calendar as an Index.
        """
        pass

    def get_benchmark(self, code: str, start: str, end: str) -> pd.Series:
        """Fetch benchmark data (optional hook).

        Args:
            code: Benchmark code.
            start: Start date.
            end: End date.

        Returns:
            Benchmark close-price Series.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this.
        """
        raise NotImplementedError("Optional method")

    def __repr__(self) -> str:
        # Only the class name and `name` are shown. The previous version also
        # formatted every parameter into an unused local, which could raise
        # while formatting a value without affecting the result.
        return f"{self.__class__.__name__}(name={self.name})"
|
||||
46
framework_v2/core/executor.py
Normal file
46
framework_v2/core/executor.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
执行器抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class Executor(ABC):
    """Abstract base class for executors.

    Every concrete executor has to provide ``execute``.
    """

    mode: str = "base"

    def __init__(self, **params):
        """Keep the execution parameters for later use.

        Args:
            **params: Execution parameters (e.g. initial_capital, trade_cost).
        """
        self._params = params

    @abstractmethod
    def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> dict:
        """Run the given signals.

        Args:
            signals: Signal DataFrame.
            data: Returns DataFrame.

        Returns:
            Backtest result dict containing:
                - result: backtest DataFrame (net value, returns)
                - portfolio: portfolio object (optional)
                - metrics: performance metrics (optional)
        """
        pass

    def __repr__(self) -> str:
        # Render as ClassName(key=value, ...) using the stored parameters.
        rendered = [f"{key}={value}" for key, value in self._params.items()]
        joined = ', '.join(rendered)
        return f"{self.__class__.__name__}({joined})"
|
||||
59
framework_v2/core/factor.py
Normal file
59
framework_v2/core/factor.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
因子抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class FactorBase(ABC):
    """Abstract base class for factors.

    Every concrete factor has to provide ``compute``.
    """

    name: str = "base"
    category: str = "unknown"

    def __init__(self, **params):
        """Keep the factor parameters for later use.

        Args:
            **params: Factor parameters (e.g. n_days, weighted).
        """
        self._params = params

    @abstractmethod
    def compute(self, data: pd.DataFrame) -> pd.Series:
        """Compute the factor values.

        Args:
            data: OHLCV data; must contain a 'close' column.

        Returns:
            Factor value Series (same index as ``data``).
        """
        pass

    def validate_data(self, data: pd.DataFrame) -> bool:
        """Check whether ``data`` is usable for this factor.

        Args:
            data: OHLCV data.

        Returns:
            True when a 'close' column exists and the number of rows is at
            least the ``min_periods`` parameter (20 when not configured).
        """
        return (
            'close' in data.columns
            and len(data) >= self._params.get('min_periods', 20)
        )

    def __repr__(self) -> str:
        # Render as ClassName(key=value, ...) using the stored parameters.
        rendered = [f"{key}={value}" for key, value in self._params.items()]
        joined = ', '.join(rendered)
        return f"{self.__class__.__name__}({joined})"
|
||||
57
framework_v2/core/signal.py
Normal file
57
framework_v2/core/signal.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
信号生成器抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class SignalGenerator(ABC):
    """Abstract base class for signal generators.

    Every concrete generator has to provide ``generate``.
    """

    mode: str = "base"

    def __init__(self, **params):
        """Keep the signal parameters for later use.

        Args:
            **params: Signal parameters (e.g. select_num, rebalance_days).
        """
        self._params = params

    @abstractmethod
    def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
        """Produce trading signals.

        Args:
            factor_data: Factor DataFrame.

        Returns:
            Signal DataFrame; must contain a 'signal' column.
        """
        pass

    def validate_factor_data(self, factor_data: pd.DataFrame) -> bool:
        """Check whether the factor data is usable.

        Args:
            factor_data: Factor data.

        Returns:
            True when ``factor_data`` is not empty.
        """
        return not factor_data.empty

    def __repr__(self) -> str:
        # Render as ClassName(key=value, ...) using the stored parameters.
        rendered = [f"{key}={value}" for key, value in self._params.items()]
        joined = ', '.join(rendered)
        return f"{self.__class__.__name__}({joined})"
|
||||
151
framework_v2/core/strategy.py
Normal file
151
framework_v2/core/strategy.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
策略抽象基类
|
||||
|
||||
所有策略必须继承此类并实现必要方法
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class StrategyBase(ABC):
    """Abstract base class for strategies.

    Standard lifecycle:
        1. Initialise configuration
        2. Get data
        3. Compute factors
        4. Generate signals
        5. Run the backtest

    Abstract methods every subclass must implement:
        - init_factors(): create the factor component
        - init_signal_generator(): create the signal generator

    ``get_data()`` and ``_execute_backtest()`` raise ``NotImplementedError``
    in this base class.
    """

    INTERFACE_VERSION = 2  # V2 interface

    name: str = "base"
    timeframe: str = "1d"

    def __init__(self, config: Optional[Dict] = None):
        """Initialise the strategy.

        Args:
            config: Strategy configuration dict; a falsy value becomes ``{}``.
        """
        self.config = config or {}
        # Both components are created lazily on first use.
        self._factor = None
        self._signal_generator = None

    @abstractmethod
    def init_factors(self) -> Any:
        """Create the factor component.

        Returns:
            A factor instance (derived from FactorBase).
        """
        pass

    @abstractmethod
    def init_signal_generator(self) -> Any:
        """Create the signal generator.

        Returns:
            A signal generator instance (derived from SignalGenerator).
        """
        pass

    def get_data(self) -> Dict[str, Any]:
        """Fetch the strategy data (to be overridden).

        Returns:
            Data dict containing:
                - index_data: index data
                - etf_data: ETF data
                - benchmark_data: benchmark data
                - valid_codes: list of valid instruments
                - trading_calendar: trading calendar

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement get_data()")

    def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
        """Compute factors (may be overridden).

        Args:
            data: Data dict.

        Returns:
            Factor DataFrame with one column per code that is listed in
            ``valid_codes`` and present in ``index_data``.
        """
        if self._factor is None:
            self._factor = self.init_factors()

        # Default implementation: one factor series per valid instrument.
        index_data = data.get('index_data', {})
        per_code = {
            code: self._factor.compute(index_data[code])
            for code in data.get('valid_codes', [])
            if code in index_data
        }
        return pd.DataFrame(per_code)

    def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
        """Generate signals from the factor table.

        Args:
            factor_df: Factor DataFrame.

        Returns:
            Signal DataFrame (contains a 'signal' column).
        """
        if self._signal_generator is None:
            self._signal_generator = self.init_signal_generator()

        generator = self._signal_generator
        return generator.generate(factor_df)

    def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
        """Run the whole backtest pipeline.

        Args:
            data: Optional data dict; fetched via ``get_data()`` when ``None``.

        Returns:
            Backtest result dict.
        """
        # 1. data -> 2. factors -> 3. signals -> 4. execution
        dataset = self.get_data() if data is None else data
        signals = self.generate_signals(self.compute_factors(dataset))
        return self._execute_backtest(signals, dataset)

    def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the backtest (to be overridden).

        Args:
            signals: Signal DataFrame.
            data: Data dict.

        Returns:
            Backtest result.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("Subclasses must implement _execute_backtest()")

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(name={self.name})"
|
||||
9
framework_v2/shared/__init__.py
Normal file
9
framework_v2/shared/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
通用实现层(2+ 策略复用的组件)
|
||||
|
||||
包含:
|
||||
├── factors/ # 通用因子
|
||||
├── signals/ # 通用信号生成器
|
||||
├── execution/ # 通用执行器
|
||||
└── data/ # 通用数据处理
|
||||
"""
|
||||
17
framework_v2/shared/factors/__init__.py
Normal file
17
framework_v2/shared/factors/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
通用因子实现
|
||||
"""
|
||||
|
||||
from framework_v2.shared.factors.momentum import MomentumFactor
|
||||
|
||||
# TALibFactorBase 需要安装 talib,可选导入
|
||||
try:
|
||||
from framework_v2.shared.factors.talib_base import TALibFactorBase
|
||||
__all__ = [
|
||||
'TALibFactorBase',
|
||||
'MomentumFactor',
|
||||
]
|
||||
except ImportError:
|
||||
__all__ = [
|
||||
'MomentumFactor',
|
||||
]
|
||||
104
framework_v2/shared/factors/momentum.py
Normal file
104
framework_v2/shared/factors/momentum.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
动量因子(通用版本)
|
||||
|
||||
使用加权线性回归:得分 = 年化收益率 × R²
|
||||
|
||||
与现有 MomentumFactor 对比验证:
|
||||
- 输入相同 → 输出应该相同
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import math
|
||||
from framework_v2.core import FactorBase
|
||||
|
||||
|
||||
class MomentumFactor(FactorBase):
|
||||
"""
|
||||
动量因子
|
||||
|
||||
计算加权线性回归动量得分:
|
||||
得分 = 年化收益率 × R²
|
||||
|
||||
参数:
|
||||
- n_days: 动量窗口(默认25)
|
||||
- weighted: 是否加权(默认True)
|
||||
- crash_filter: 是否启用崩盘过滤(默认True)
|
||||
"""
|
||||
|
||||
name = "momentum"
|
||||
category = "momentum"
|
||||
|
||||
def __init__(
    self,
    n_days: int = 25,
    weighted: bool = True,
    crash_filter: bool = True
):
    """Configure the momentum factor.

    Args:
        n_days: Momentum window length.
        weighted: Whether to use the weighted-regression score.
        crash_filter: Whether to apply the crash filter.
    """
    # The same three settings are handed to the base class as parameters
    # and mirrored as plain attributes.
    settings = dict(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
    super().__init__(**settings)
    self.n_days = n_days
    self.weighted = weighted
    self.crash_filter = crash_filter
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算动量因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.weighted:
|
||||
factor_values = prices.rolling(self.n_days).apply(
|
||||
lambda x: self._weighted_momentum_score(x.values),
|
||||
raw=False
|
||||
)
|
||||
else:
|
||||
factor_values = prices.pct_change(self.n_days)
|
||||
|
||||
if self.crash_filter:
|
||||
factor_values = self._apply_crash_filter(prices, factor_values)
|
||||
|
||||
return factor_values
|
||||
|
||||
def _weighted_momentum_score(self, prices: np.ndarray) -> float:
|
||||
"""计算加权动量得分(完全复制现有逻辑)"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
|
||||
# 价格下界 clip,防止 log(0) 或 log(负数)
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
y = np.log(prices)
|
||||
|
||||
# 异常值检测
|
||||
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
|
||||
return 0.0
|
||||
|
||||
x = np.arange(len(y))
|
||||
weights = np.linspace(1, 2, len(y))
|
||||
|
||||
slope, intercept = np.polyfit(x, y, 1, w=weights)
|
||||
annualized_returns = math.exp(slope * 250) - 1
|
||||
|
||||
y_pred = slope * x + intercept
|
||||
ss_res = np.sum(weights * (y - y_pred) ** 2)
|
||||
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
|
||||
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||
|
||||
return annualized_returns * r2
|
||||
|
||||
def _apply_crash_filter(self, prices: pd.Series, factor_values: pd.Series) -> pd.Series:
|
||||
"""崩盘过滤:连续3天跌>5%清零(完全复制现有逻辑)"""
|
||||
result = factor_values.copy()
|
||||
|
||||
for i in range(3, len(prices)):
|
||||
r1 = prices.iloc[i] / prices.iloc[i-1]
|
||||
r2 = prices.iloc[i-1] / prices.iloc[i-2]
|
||||
r3 = prices.iloc[i-2] / prices.iloc[i-3]
|
||||
|
||||
con1 = min(r1, r2, r3) < 0.95
|
||||
con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95)
|
||||
|
||||
if con1 or con2:
|
||||
result.iloc[i] = 0.0
|
||||
|
||||
return result
|
||||
55
framework_v2/shared/factors/talib_base.py
Normal file
55
framework_v2/shared/factors/talib_base.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
ta-lib 因子基类(通用)
|
||||
|
||||
所有 ta-lib 因子继承此类,只需指定函数和参数
|
||||
"""
|
||||
|
||||
import talib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from framework_v2.core import FactorBase
|
||||
|
||||
|
||||
class TALibFactorBase(FactorBase):
    """
    Base class for factors backed by a ta-lib function.

    Subclasses only need to provide:
    - name: the factor name
    - _talib_func: a property returning the ta-lib function to call
    """

    category = "technical"

    def __init__(self, period: int = 14, **params):
        """
        Initialise the factor.

        Args:
            period: look-back period, later passed to the ta-lib function
                as ``timeperiod``
            **params: any further parameters, forwarded to FactorBase
        """
        super().__init__(period=period, **params)
        self.period = period

    @property
    def _talib_func(self):
        """Must be overridden by subclasses to return the ta-lib function."""
        raise NotImplementedError("Subclasses must implement _talib_func")

    def compute(self, data: pd.DataFrame) -> pd.Series:
        """
        Compute the factor values.

        Args:
            data: OHLCV data; only the 'close' column is read here

        Returns:
            Series of factor values indexed like ``data`` and named after the factor
        """
        close_prices = data['close'].values.astype(float)

        # Delegate the actual calculation to the ta-lib function chosen by the subclass
        indicator = self._talib_func
        values = indicator(close_prices, timeperiod=self.period)

        return pd.Series(values, index=data.index, name=self.name)
|
||||
3
framework_v2/tests/__init__.py
Normal file
3
framework_v2/tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
框架 V2 测试
|
||||
"""
|
||||
116
framework_v2/tests/test_momentum_parity.py
Normal file
116
framework_v2/tests/test_momentum_parity.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
因子对比验证测试
|
||||
|
||||
验证新框架的 MomentumFactor 与现有实现输出一致
|
||||
"""
|
||||
|
||||
import sys
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# 添加项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
def _run_momentum_parity_check() -> bool:
    """
    Compute the legacy and framework_v2 momentum factors on the same data and
    report whether they agree.

    Returns:
        True when both results share the same index, have NaN in the same
        positions and differ by less than 1e-10 everywhere else; False otherwise.
    """
    print("=" * 60)
    print(" MomentumFactor 对比测试")
    print("=" * 60)

    # 1. Load test data
    print("\n1. 加载测试数据...")
    test_data_path = project_root / 'data' / 'index_history_data'

    # Prefer real Nasdaq-100 history; sorted() keeps the chosen file stable across runs
    import glob
    ndx_files = sorted(glob.glob(str(test_data_path / '*NDX*')))
    if ndx_files:
        test_file = ndx_files[0]
        data = pd.read_csv(test_file, index_col=0, parse_dates=True)
        print(f" ✓ 加载 {test_file}")
        print(f" 数据范围: {data.index[0]} ~ {data.index[-1]}")
        print(f" 数据长度: {len(data)} 条")
    else:
        print(" ⚠ 未找到测试数据,使用模拟数据")
        # Fall back to a seeded random walk so the run stays reproducible
        np.random.seed(42)
        dates = pd.date_range('2020-01-01', periods=500, freq='B')
        prices = 100 * np.cumprod(1 + np.random.randn(500) * 0.02)
        data = pd.DataFrame({
            'close': prices,
            'open': prices * 0.99,
            'high': prices * 1.01,
            'low': prices * 0.98,
            'volume': np.random.randint(1000000, 10000000, 500)
        }, index=dates)

    # 2. Legacy factor
    print("\n2. 计算旧因子(strategies/shared/factors/momentum.py)...")
    from strategies.shared.factors.momentum import MomentumFactor as OldMomentum

    old_factor = OldMomentum(n_days=25, weighted=True, crash_filter=True)
    old_result = old_factor.compute(data)
    print(" ✓ 旧因子计算完成")
    print(f" 结果范围: {old_result.min():.4f} ~ {old_result.max():.4f}")
    print(f" NaN 数量: {old_result.isna().sum()}")

    # 3. New factor
    print("\n3. 计算新因子(framework_v2/shared/factors/momentum.py)...")
    from framework_v2.shared.factors.momentum import MomentumFactor as NewMomentum

    new_factor = NewMomentum(n_days=25, weighted=True, crash_filter=True)
    new_result = new_factor.compute(data)
    print(" ✓ 新因子计算完成")
    print(f" 结果范围: {new_result.min():.4f} ~ {new_result.max():.4f}")
    print(f" NaN 数量: {new_result.isna().sum()}")

    # 4. Compare
    print("\n4. 对比结果...")

    # The indexes must match before values can be compared element-wise
    if not old_result.index.equals(new_result.index):
        print(" ✗ 索引不一致")
        return False
    print(" ✓ 索引一致")

    # NaN must sit in the same places: Series.max() skips NaN, so a value that
    # is NaN in only one result would otherwise go unnoticed
    nan_mismatch = old_result.isna() != new_result.isna()
    if nan_mismatch.any():
        print(f" ✗ NaN 位置不一致: {int(nan_mismatch.sum())} 处")
        return False

    # Numeric differences on the positions where a difference is defined
    diff = (old_result - new_result).abs()
    comparable = diff.dropna()
    # With nothing left to compare the results are trivially equal
    max_diff = comparable.max() if not comparable.empty else 0.0
    mean_diff = comparable.mean() if not comparable.empty else 0.0

    print(f" 最大差异: {max_diff:.6e}")
    print(f" 平均差异: {mean_diff:.6e}")

    # Allow for floating-point rounding noise only (1e-10)
    tolerance = 1e-10
    if max_diff < tolerance:
        print(f" ✓ 差异在容差范围内 (< {tolerance:.0e})")
        print("\n" + "=" * 60)
        print(" ✓ 测试通过:新旧因子输出完全一致!")
        print("=" * 60)
        return True

    print(" ✗ 差异超出容差范围")
    print("\n" + "=" * 60)
    print(" ✗ 测试失败:新旧因子输出不一致")
    print("=" * 60)

    # Show the first 10 offending points to help debugging
    diff_nonzero = diff[diff > tolerance]
    if len(diff_nonzero) > 0:
        print("\n 前10个差异点:")
        for date, val in diff_nonzero.head(10).items():
            old_val = old_result.loc[date]
            new_val = new_result.loc[date]
            print(f" {date}: 旧={old_val:.6f}, 新={new_val:.6f}, 差异={val:.6e}")

    return False


def test_momentum_factor_parity():
    """Verify that the new factor produces the same output as the legacy one."""
    # pytest ignores a test's return value, so the verdict has to be asserted
    assert _run_momentum_parity_check(), (
        "MomentumFactor output differs between the legacy and framework_v2 implementations"
    )


if __name__ == '__main__':
    success = _run_momentum_parity_check()
    sys.exit(0 if success else 1)
|
||||
Reference in New Issue
Block a user