Compare commits
5 Commits
99d3584d05
...
43ce8056f1
| Author | SHA1 | Date | |
|---|---|---|---|
| 43ce8056f1 | |||
| 5212b004dc | |||
| 0954458114 | |||
| de988b919b | |||
| 341611c32b |
197
docs/experiments/982fbe2后代码变更总结.md
Normal file
197
docs/experiments/982fbe2后代码变更总结.md
Normal file
@@ -0,0 +1,197 @@
|
||||
# 982fbe2 后代码变更总结
|
||||
|
||||
**基准 commit**: `982fbe2` — fix: 修复跨市场收益率计算Bug
|
||||
**变更时间**: 2026-05-22 ~ 2026-05-23
|
||||
**状态**: 未提交(working tree 修改)
|
||||
|
||||
---
|
||||
|
||||
## 变更文件总览
|
||||
|
||||
| 文件 | 操作 | 改动量 | 类别 |
|
||||
|------|------|--------|------|
|
||||
| `datasource/tushare_source.py` | 修改 | +99 行 | 数据层新增方法 |
|
||||
| `strategies/rotation/strategy.py` | 修改 | +25/-15 行 | 核心Bug修复 |
|
||||
| `strategies/rotation/config.yaml` | 修改 | +8/-10 行 | 标的配置修正 |
|
||||
| `strategies/shared/factors/momentum.py` | 修改 | +5/-1 行 | 因子鲁棒性 |
|
||||
| `scripts/export_backtest_detail.py` | 新增 | ~480 行 | 回测数据导出工具 |
|
||||
| `results/backtest_viewer.html` | 新增 | ~650 行 | HTML回放器 |
|
||||
|
||||
---
|
||||
|
||||
## 一、datasource/tushare_source.py
|
||||
|
||||
### 新增方法 1: `fetch_etf_adj()`
|
||||
|
||||
获取 ETF 后复权价格数据。通过 `fund_daily` + `fund_adj` 手动计算后复权价格,消除份额折算(拆分)对收益率计算的影响。
|
||||
|
||||
**关键实现**:
|
||||
- `fund_adj` 单次限 2000 条,按 5 年分段请求再拼接
|
||||
- 输出 `close_hfq = close × adj_factor`
|
||||
- 包含 `open`, `close`, `adj_factor`, `close_hfq` 等列
|
||||
|
||||
### 新增方法 2: `fetch_trade_cal()`
|
||||
|
||||
获取 A 股(上交所 SSE)官方交易日历。
|
||||
|
||||
```python
|
||||
def fetch_trade_cal(self, start_date: str, end_date: str) -> pd.DatetimeIndex:
|
||||
pro.trade_cal(exchange='SSE', start_date=..., end_date=..., is_open='1')
|
||||
```
|
||||
|
||||
**修复的问题**:之前策略使用 `index_close.index`(联合日历)包含海外独有交易日和周日,每年多约 70 天,导致 CAGR 被年数摊薄、回撤虚增。详见 `联合日历Bug详解.md`。
|
||||
|
||||
---
|
||||
|
||||
## 二、strategies/rotation/strategy.py
|
||||
|
||||
### 修复 1: 使用 SSE 官方交易日历
|
||||
|
||||
**问题**:`compute_factors()` 使用 `index_close.index`(所有标的日期的并集),包含非 A 股交易日。
|
||||
|
||||
**修复**:
|
||||
- `_get_data_from_flask_api()` 和 `_get_data_from_local()` 中调用 `fetch_trade_cal()` 获取 SSE 日历,写入 `data['a_share_dates']`
|
||||
- `compute_factors()` 优先使用 `data['a_share_dates']` 而非 `index_close.index`
|
||||
- `run_backtest()` 中用 A 股日历对齐 signals:`signals.reindex(a_share_dates, method='ffill')`
|
||||
|
||||
```python
|
||||
# Before
|
||||
a_share_dates = index_close.index # 联合日历(含海外交易日)
|
||||
|
||||
# After
|
||||
a_share_dates = data.get('a_share_dates') # SSE 官方日历
|
||||
```
|
||||
|
||||
### 修复 2: 信号对齐到 A 股日历
|
||||
|
||||
`run_backtest()` 中增加信号 reindex 逻辑,确保回测只在 A 股交易日执行:
|
||||
|
||||
```python
|
||||
if a_share_dates is not None:
|
||||
signals = signals.reindex(a_share_dates, method='ffill').dropna(subset=[signals.columns[0]])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 三、strategies/rotation/config.yaml
|
||||
|
||||
### 修改: 标的替换
|
||||
|
||||
| 变更 | 原配置 | 新配置 | 原因 |
|
||||
|------|--------|--------|------|
|
||||
| 有色金属信号源 | `HG=F`(COMEX 铜期货) | `399395.SZ`(国证有色金属指数) | HG=F 与 159980.SZ 日相关性仅 0.38,属于标的错配;399395.SZ 相关性 0.56 |
|
||||
| 原油 | `CL=F` + `160723.SZ` | 删除 | CL=F 与 160723.SZ 相关性仅 0.12,160723 是 LOF 有 QDII 额度限制 |
|
||||
|
||||
**影响**:标的从 12 只缩减为 10 只,消除了两对信号-执行严重错配的标的。
|
||||
|
||||
---
|
||||
|
||||
## 四、strategies/shared/factors/momentum.py
|
||||
|
||||
### 修复: 加权动量计算鲁棒性
|
||||
|
||||
```python
|
||||
# 新增:价格下界 clip,防止 log(0) 或 log(负数)
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
y = np.log(prices)
|
||||
# 新增:异常值检测
|
||||
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
|
||||
return 0.0
|
||||
```
|
||||
|
||||
**修复的问题**:某些历史数据中出现 0 值或极端异常值,`np.log(0)` 产生 `-inf` 导致后续回归计算崩溃。
|
||||
|
||||
---
|
||||
|
||||
## 五、scripts/export_backtest_detail.py(新增)
|
||||
|
||||
模式 B(指数信号 + ETF 收益)的逐日回测明细导出脚本。
|
||||
|
||||
**输出**:`results/backtest_detail.json`(~5.2 MB,1542 天 × 10 标的)
|
||||
|
||||
**每日每标的记录字段**:
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| `index_close` | 指数收盘价 |
|
||||
| `momentum` | 加权动量得分 |
|
||||
| `rank` | 动量排名(BOND 不参与) |
|
||||
| `threshold` | V3 动态阈值 |
|
||||
| `above_threshold` | 是否过阈值 |
|
||||
| `etf_close` / `etf_open` | ETF 原始价格 |
|
||||
| `etf_close_hfq` | ETF 后复权收盘价 |
|
||||
| `premium` | ETF 溢价率 |
|
||||
| `etf_return_ctc` | ETF close-to-close 日收益率 |
|
||||
| `index_return` | 指数日收益率 |
|
||||
| `is_held` | 是否在持仓 |
|
||||
| `entry_date` / `entry_price_etf` / `entry_price_idx` | 入场信息 |
|
||||
| `holding_days` | 持有天数 |
|
||||
| `cum_return_etf` / `cum_return_idx` | 累计收益(ETF/指数) |
|
||||
|
||||
**关键实现**:
|
||||
- 因子计算:先在原始日历计算,再 ffill 对齐到 A 股日历(与 strategy.py 一致)
|
||||
- 溢价率:通过 `fetch_etf_nav()` 获取基金净值,`(ETF收盘价 - NAV) / NAV`
|
||||
- NAV 去重:`nav_raw[~nav_raw.index.duplicated(keep='last')]`
|
||||
- 净值扩展:信号前的日期填充 nav=1.0
|
||||
|
||||
---
|
||||
|
||||
## 六、results/backtest_viewer.html(新增)
|
||||
|
||||
单文件 HTML 回放器,用于人工逐日验证回测数据准确性。
|
||||
|
||||
**功能**:
|
||||
- 文件选择器加载 JSON
|
||||
- Canvas 净值曲线(点击/拖拽定位日期)
|
||||
- 策略统计面板(CAGR、最大回撤、夏普、Calmar、胜率、换仓次数、平均持仓天数)
|
||||
- 全标的排名表(按动量排序,含市场分组、溢价率、动态阈值分隔行)
|
||||
- 持仓卡片(ETF + 指数双累计收益)
|
||||
- 调仓明细栏
|
||||
- 键盘快捷键:← → 翻页,Space 播放/暂停
|
||||
- 跳转前/后调仓日按钮
|
||||
|
||||
---
|
||||
|
||||
## 修复前后指标对比
|
||||
|
||||
**回测区间**:2020-01-02 ~ 2026-05-19,使用 `own_strategy_etf_vs_index.py` 四模式对比
|
||||
|
||||
| 指标 | 修复前 Mode A | 修复后 Mode A | 变化 |
|
||||
|------|-------------|-------------|------|
|
||||
| CAGR | 15.63% | 11.80% | -3.83% |
|
||||
| 最大回撤 | — | -29.49% | — |
|
||||
| 夏普 | — | 0.818 | — |
|
||||
|
||||
| 指标 | 修复前 Mode B | 修复后 Mode B | 变化 |
|
||||
|------|-------------|-------------|------|
|
||||
| CAGR | 26.13% | 28.07% | +1.94% |
|
||||
| 最大回撤 | — | -13.34% | — |
|
||||
| 夏普 | — | 1.685 | — |
|
||||
|
||||
| 模式 | CAGR | 最大回撤 | 夏普 | Calmar |
|
||||
|------|------|----------|------|--------|
|
||||
| A (指数→指数) | 11.80% | -29.49% | 0.818 | 0.400 |
|
||||
| B (指数→ETF) | 28.07% | -13.34% | 1.685 | 2.104 |
|
||||
| C (ETF→ETF) | 21.27% | -13.27% | 1.304 | 1.603 |
|
||||
| D (ETF修正) | 17.93% | -13.36% | 1.130 | 1.342 |
|
||||
|
||||
**差异分解**:
|
||||
- 收益差异 B-A: +16.27%/年(ETF 收益 vs 指数收益)
|
||||
- 信号差异 C-B: -6.80%/年(ETF 信号 vs 指数信号)
|
||||
- 隔夜偏差 C-D: +3.34%/年(close 买入虚高 vs open 买入现实)
|
||||
|
||||
---
|
||||
|
||||
## 依赖关系
|
||||
|
||||
```
|
||||
config.yaml (标的配置)
|
||||
↓
|
||||
tushare_source.py (fetch_trade_cal / fetch_etf_adj)
|
||||
↓
|
||||
strategy.py (compute_factors → generate_signals → run_backtest)
|
||||
↓
|
||||
export_backtest_detail.py (导出 JSON)
|
||||
↓
|
||||
backtest_viewer.html (加载 JSON 可视化)
|
||||
```
|
||||
72
framework_v2/config/__init__.py
Normal file
72
framework_v2/config/__init__.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
配置模块
|
||||
|
||||
提供配置加载、验证、管理功能
|
||||
"""
|
||||
|
||||
from framework_v2.config.schemas import (
|
||||
# 完整配置
|
||||
RotationStrategyConfig,
|
||||
StrategyConfig, # 通用配置
|
||||
|
||||
# 子配置
|
||||
AssetPool,
|
||||
AssetConfig,
|
||||
PremiumConfig,
|
||||
GroupConfig,
|
||||
FactorConfig,
|
||||
RotationConfig,
|
||||
ThresholdConfig,
|
||||
DynamicThresholdConfig,
|
||||
RebalanceConfig,
|
||||
PremiumControlConfig,
|
||||
DataConfig,
|
||||
DataSourceConfig,
|
||||
BenchmarkConfig,
|
||||
BacktestConfig,
|
||||
MetadataConfig,
|
||||
|
||||
# 枚举(已移除 MarketType,改用字符串 group)
|
||||
FactorType,
|
||||
PremiumMode,
|
||||
ThresholdMode,
|
||||
DataSourceType,
|
||||
)
|
||||
|
||||
from framework_v2.config.loader import (
|
||||
ConfigLoader,
|
||||
get_config_loader,
|
||||
load_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 配置 Schema
|
||||
'RotationStrategyConfig',
|
||||
'StrategyConfig',
|
||||
'AssetPool',
|
||||
'AssetConfig',
|
||||
'PremiumConfig',
|
||||
'GroupConfig',
|
||||
'FactorConfig',
|
||||
'RotationConfig',
|
||||
'ThresholdConfig',
|
||||
'DynamicThresholdConfig',
|
||||
'RebalanceConfig',
|
||||
'PremiumControlConfig',
|
||||
'DataConfig',
|
||||
'DataSourceConfig',
|
||||
'BenchmarkConfig',
|
||||
'BacktestConfig',
|
||||
'MetadataConfig',
|
||||
|
||||
# 枚举(已移除 MarketType,改用字符串 group)
|
||||
'FactorType',
|
||||
'PremiumMode',
|
||||
'ThresholdMode',
|
||||
'DataSourceType',
|
||||
|
||||
# 加载器
|
||||
'ConfigLoader',
|
||||
'get_config_loader',
|
||||
'load_config',
|
||||
]
|
||||
247
framework_v2/config/loader.py
Normal file
247
framework_v2/config/loader.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
配置加载器
|
||||
|
||||
支持:
|
||||
1. YAML 配置文件加载
|
||||
2. Pydantic Schema 验证
|
||||
3. 环境变量替换
|
||||
4. 配置合并(默认值 + 用户配置)
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from framework_v2.config.schemas import StrategyConfig, RotationStrategyConfig
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""
|
||||
配置加载器
|
||||
|
||||
用法:
|
||||
loader = ConfigLoader()
|
||||
config = loader.load('config/rotation.yaml')
|
||||
"""
|
||||
|
||||
def __init__(self, config_dir: str = None):
|
||||
"""
|
||||
初始化
|
||||
|
||||
Args:
|
||||
config_dir: 配置目录(默认 framework_v2/config)
|
||||
"""
|
||||
if config_dir is None:
|
||||
# 默认配置目录
|
||||
config_dir = Path(__file__).parent
|
||||
|
||||
self.config_dir = Path(config_dir)
|
||||
|
||||
def load(self, config_file: str, config_type: str = 'strategy') -> StrategyConfig:
|
||||
"""
|
||||
加载配置文件
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径(相对路径或绝对路径)
|
||||
config_type: 配置类型('strategy' 或 'rotation')
|
||||
|
||||
Returns:
|
||||
验证后的配置对象
|
||||
|
||||
示例:
|
||||
>>> loader = ConfigLoader()
|
||||
>>> config = loader.load('rotation_example.yaml')
|
||||
>>> print(config.factor.n_days)
|
||||
25
|
||||
"""
|
||||
# 1. 解析文件路径
|
||||
config_path = self._resolve_path(config_file)
|
||||
|
||||
# 2. 读取 YAML
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
# 3. 环境变量替换
|
||||
config_dict = self._substitute_env_vars(config_dict)
|
||||
|
||||
# 4. Pydantic 验证(根据类型选择 Schema)
|
||||
if config_type == 'rotation':
|
||||
# 兼容旧版本:轮动策略专用配置
|
||||
config = RotationStrategyConfig(**config_dict)
|
||||
else:
|
||||
# 通用策略配置(推荐)
|
||||
config = StrategyConfig(**config_dict)
|
||||
|
||||
return config
|
||||
|
||||
def load_dict(self, config_dict: Dict[str, Any], config_type: str = 'strategy') -> StrategyConfig:
|
||||
"""
|
||||
从字典加载配置
|
||||
|
||||
Args:
|
||||
config_dict: 配置字典
|
||||
config_type: 配置类型('strategy' 或 'rotation')
|
||||
|
||||
Returns:
|
||||
验证后的配置对象
|
||||
"""
|
||||
# 环境变量替换
|
||||
config_dict = self._substitute_env_vars(config_dict)
|
||||
|
||||
# Pydantic 验证(根据类型选择 Schema)
|
||||
if config_type == 'rotation':
|
||||
config = RotationStrategyConfig(**config_dict)
|
||||
else:
|
||||
config = StrategyConfig(**config_dict)
|
||||
|
||||
return config
|
||||
|
||||
def _resolve_path(self, config_file: str) -> Path:
|
||||
"""
|
||||
解析配置文件路径
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
|
||||
Returns:
|
||||
绝对路径
|
||||
"""
|
||||
path = Path(config_file)
|
||||
|
||||
# 如果是绝对路径,直接返回
|
||||
if path.is_absolute():
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"配置文件不存在: {path}")
|
||||
return path
|
||||
|
||||
# 相对路径:先在配置目录查找
|
||||
config_path = self.config_dir / path
|
||||
if config_path.exists():
|
||||
return config_path
|
||||
|
||||
# 然后在当前工作目录查找
|
||||
cwd_path = Path.cwd() / path
|
||||
if cwd_path.exists():
|
||||
return cwd_path
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"配置文件未找到: {path}\n"
|
||||
f"搜索路径:\n"
|
||||
f" - {config_path}\n"
|
||||
f" - {cwd_path}"
|
||||
)
|
||||
|
||||
def _substitute_env_vars(self, config: Any) -> Any:
|
||||
"""
|
||||
替换配置中的环境变量
|
||||
|
||||
支持格式:
|
||||
- ${VAR_NAME}
|
||||
- ${VAR_NAME:default_value}
|
||||
|
||||
Args:
|
||||
config: 配置对象(dict/list/str)
|
||||
|
||||
Returns:
|
||||
替换后的配置
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
return {
|
||||
key: self._substitute_env_vars(value)
|
||||
for key, value in config.items()
|
||||
}
|
||||
elif isinstance(config, list):
|
||||
return [
|
||||
self._substitute_env_vars(item)
|
||||
for item in config
|
||||
]
|
||||
elif isinstance(config, str):
|
||||
# 匹配 ${VAR_NAME} 或 ${VAR_NAME:default}
|
||||
pattern = r'\$\{([^}]+)\}'
|
||||
|
||||
def replace_match(match):
|
||||
var_expr = match.group(1)
|
||||
|
||||
# 检查是否有默认值
|
||||
if ':' in var_expr:
|
||||
var_name, default = var_expr.split(':', 1)
|
||||
else:
|
||||
var_name = var_expr
|
||||
default = None
|
||||
|
||||
# 从环境变量读取
|
||||
value = os.getenv(var_name)
|
||||
|
||||
if value is None:
|
||||
if default is not None:
|
||||
return default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"环境变量未设置: {var_name}\n"
|
||||
f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
return re.sub(pattern, replace_match, config)
|
||||
else:
|
||||
return config
|
||||
|
||||
def get_available_configs(self) -> list:
|
||||
"""
|
||||
获取可用的配置文件列表
|
||||
|
||||
Returns:
|
||||
配置文件名列表
|
||||
"""
|
||||
if not self.config_dir.exists():
|
||||
return []
|
||||
|
||||
return [
|
||||
f.name
|
||||
for f in self.config_dir.glob('*.yaml')
|
||||
if f.is_file()
|
||||
]
|
||||
|
||||
|
||||
# 全局实例
|
||||
_config_loader: Optional[ConfigLoader] = None
|
||||
|
||||
|
||||
def get_config_loader(config_dir: str = None) -> ConfigLoader:
|
||||
"""
|
||||
获取配置加载器单例
|
||||
|
||||
Args:
|
||||
config_dir: 配置目录
|
||||
|
||||
Returns:
|
||||
ConfigLoader 实例
|
||||
"""
|
||||
global _config_loader
|
||||
|
||||
if _config_loader is None:
|
||||
_config_loader = ConfigLoader(config_dir)
|
||||
|
||||
return _config_loader
|
||||
|
||||
|
||||
def load_config(config_file: str, config_type: str = 'strategy') -> StrategyConfig:
|
||||
"""
|
||||
快捷函数:加载配置文件
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
config_type: 配置类型('strategy' 或 'rotation')
|
||||
|
||||
Returns:
|
||||
验证后的配置对象
|
||||
|
||||
示例:
|
||||
>>> from framework_v2.config import load_config
|
||||
>>> config = load_config('rotation_example.yaml')
|
||||
"""
|
||||
loader = get_config_loader()
|
||||
return loader.load(config_file, config_type=config_type)
|
||||
268
framework_v2/config/rotation_global.yaml
Normal file
268
framework_v2/config/rotation_global.yaml
Normal file
@@ -0,0 +1,268 @@
|
||||
# 跨市场轮动策略配置(扁平化设计)
|
||||
#
|
||||
# 配置版本: 2.0.0
|
||||
# 最后更新: 2024-04-16
|
||||
# 策略名称: rotation_global
|
||||
# 描述: 全球资产大类轮动 - 扁平化资产池设计
|
||||
|
||||
# ============================================================
|
||||
# 元数据
|
||||
# ============================================================
|
||||
metadata:
|
||||
version: "2.0.0"
|
||||
strategy: "rotation_global"
|
||||
description: "全球资产大类轮动策略 V2 - 扁平化资产池"
|
||||
last_updated: "2024-04-16"
|
||||
|
||||
# ============================================================
|
||||
# 资产池配置(扁平化设计)
|
||||
# ============================================================
|
||||
asset_pools:
|
||||
assets:
|
||||
# ============================================================
|
||||
# 美股指数(通过 A 股 ETF 交易)
|
||||
# ============================================================
|
||||
"NDX":
|
||||
name: "纳指100"
|
||||
group: "US_TECH"
|
||||
signal_source: "NDX" # 纳指信号
|
||||
trade_source: "513100.SH" # A股ETF交易
|
||||
description: "纳斯达克100指数,科技股代表"
|
||||
|
||||
"SPX":
|
||||
name: "标普500"
|
||||
group: "US_TECH"
|
||||
signal_source: "SPX"
|
||||
trade_source: "513500.SH"
|
||||
description: "标普500指数,美股大盘"
|
||||
|
||||
# ============================================================
|
||||
# A股指数(直接交易 ETF)
|
||||
# ============================================================
|
||||
"399006.SZ":
|
||||
name: "创业板指"
|
||||
group: "CN_GROWTH"
|
||||
signal_source: "399006.SZ"
|
||||
trade_source: "159915.SZ"
|
||||
description: "创业板指数,成长股代表"
|
||||
|
||||
"000300.SH":
|
||||
name: "沪深300"
|
||||
group: "CN_GROWTH"
|
||||
signal_source: "000300.SH"
|
||||
trade_source: "510300.SH"
|
||||
description: "沪深300指数,大盘蓝筹"
|
||||
|
||||
"H30269.CSI":
|
||||
name: "中证红利低波"
|
||||
group: "CN_GROWTH"
|
||||
signal_source: "H30269.CSI"
|
||||
trade_source: "512890.SH"
|
||||
description: "红利低波指数,价值股代表"
|
||||
|
||||
# ============================================================
|
||||
# 日本股市(通过 A 股 ETF 交易)
|
||||
# ============================================================
|
||||
"N225":
|
||||
name: "日经225"
|
||||
group: "JP_BROAD"
|
||||
signal_source: "N225"
|
||||
trade_source: "513520.SH"
|
||||
description: "日经225指数,日本股市"
|
||||
|
||||
# ============================================================
|
||||
# 欧洲股市(通过 A 股 ETF 交易)
|
||||
# ============================================================
|
||||
"GDAXI":
|
||||
name: "德国DAX"
|
||||
group: "EU_BROAD"
|
||||
signal_source: "GDAXI"
|
||||
trade_source: "513030.SH"
|
||||
description: "德国DAX指数,欧洲股市"
|
||||
|
||||
# ============================================================
|
||||
# 港股(通过 A 股 ETF 交易)
|
||||
# ============================================================
|
||||
"HSI":
|
||||
name: "恒生指数"
|
||||
group: "HK_TECH"
|
||||
signal_source: "HSI"
|
||||
trade_source: "159920.SZ"
|
||||
description: "恒生指数,香港股市"
|
||||
|
||||
"HSTECH.HK":
|
||||
name: "恒生科技"
|
||||
group: "HK_TECH"
|
||||
signal_source: "HSTECH.HK"
|
||||
trade_source: "513130.SH"
|
||||
description: "恒生科技指数,港股科技"
|
||||
|
||||
# ============================================================
|
||||
# 商品(国际期货信号 → A股ETF交易)
|
||||
# ============================================================
|
||||
"GC=F":
|
||||
name: "黄金"
|
||||
group: "COMMODITY"
|
||||
signal_source: "GC=F" # COMEX黄金期货
|
||||
trade_source: "518880.SH" # A股黄金ETF
|
||||
description: "COMEX黄金期货,避险资产"
|
||||
|
||||
"CL=F":
|
||||
name: "原油"
|
||||
group: "COMMODITY"
|
||||
signal_source: "CL=F" # WTI原油期货
|
||||
trade_source: "160723.SZ" # A股原油基金
|
||||
description: "WTI原油期货,能源商品"
|
||||
|
||||
"HG=F":
|
||||
name: "有色金属"
|
||||
group: "COMMODITY"
|
||||
signal_source: "HG=F" # COMEX铜期货
|
||||
trade_source: "159980.SZ" # A股有色ETF
|
||||
description: "COMEX铜期货,工业金属"
|
||||
|
||||
# ============================================================
|
||||
# 固定收益(直接交易指数)
|
||||
# ============================================================
|
||||
"931862.CSI":
|
||||
name: "短债指数"
|
||||
group: "FIXED_INCOME"
|
||||
signal_source: "931862.CSI"
|
||||
trade_source: "931862.CSI" # 直接交易指数(无ETF)
|
||||
description: "中证0-9个月国债指数,久期<1年,防御配置"
|
||||
|
||||
# ============================================================
|
||||
# 加密货币(未来扩展示例)
|
||||
# ============================================================
|
||||
# "BTC":
|
||||
# name: "比特币"
|
||||
# group: "CRYPTO"
|
||||
# signal_source: "BTC"
|
||||
# trade_source: "BTC"
|
||||
# description: "比特币,数字黄金"
|
||||
|
||||
# ============================================================
|
||||
# 外汇(未来扩展示例)
|
||||
# ============================================================
|
||||
# "EURUSD":
|
||||
# name: "欧元/美元"
|
||||
# group: "FOREX"
|
||||
# signal_source: "EURUSD"
|
||||
# trade_source: "EURUSD"
|
||||
# description: "欧元/美元汇率"
|
||||
|
||||
# ============================================================
|
||||
# 基准配置
|
||||
# ============================================================
|
||||
benchmark:
|
||||
code: "000300.SH"
|
||||
name: "沪深300"
|
||||
|
||||
# ============================================================
|
||||
# 回测配置
|
||||
# ============================================================
|
||||
backtest:
|
||||
start_date: "2020-01-01"
|
||||
# end_date: null # null 表示至今
|
||||
|
||||
# ============================================================
|
||||
# 因子配置
|
||||
# ============================================================
|
||||
factor:
|
||||
type: "weighted_momentum" # 因子类型: momentum / slope_r2 / weighted_momentum
|
||||
n_days: 25 # 动量窗口期(5-250天)
|
||||
|
||||
# 动态周期参数(可选)
|
||||
auto_day: false # 是否启用动态周期
|
||||
min_days: 20 # 最小周期
|
||||
max_days: 60 # 最大周期
|
||||
|
||||
# ============================================================
|
||||
# 轮动配置
|
||||
# ============================================================
|
||||
rotation:
|
||||
# ============================================================
|
||||
# 模式 1:全局选股(默认)
|
||||
# ============================================================
|
||||
select_num: 5 # 全局选 Top-5
|
||||
diversified: false # 不分散化
|
||||
|
||||
# ============================================================
|
||||
# 模式 2:分散化选股(取消注释启用)
|
||||
# ============================================================
|
||||
# diversified: true # 启用分散化
|
||||
# diversification_groups: # 按市场分组选股
|
||||
# - group: "US_TECH"
|
||||
# select_num: 1 # 美股选 1 只
|
||||
# - group: "CN_GROWTH"
|
||||
# select_num: 1 # A股选 1 只
|
||||
# - group: "JP_BROAD"
|
||||
# select_num: 1 # 日本选 1 只
|
||||
# - group: "EU_BROAD"
|
||||
# select_num: 1 # 欧洲选 1 只
|
||||
# - group: "HK_TECH"
|
||||
# select_num: 1 # 港股选 1 只
|
||||
# - group: "COMMODITY"
|
||||
# select_num: 1 # 商品选 1 只
|
||||
# - group: "FIXED_INCOME"
|
||||
# select_num: 1 # 债券选 1 只
|
||||
|
||||
# 阈值配置(统一 V2/V3)
|
||||
threshold:
|
||||
mode: "dynamic" # 阈值模式: fixed / dynamic
|
||||
fixed_value: 0.0 # 固定阈值(mode=fixed时使用)
|
||||
|
||||
# 动态阈值配置(mode=dynamic时使用)
|
||||
dynamic:
|
||||
reference: "931862.CSI" # 参考标的(短债指数)
|
||||
ratio: 1.0 # 阈值 = 短债动量 × ratio
|
||||
fallback_enabled: true # 参考不可用时是否回退
|
||||
fallback_value: 0.0 # 回退值
|
||||
|
||||
# ============================================================
|
||||
# 调仓配置
|
||||
# ============================================================
|
||||
rebalance:
|
||||
min_hold_days: 1 # 最低持有天数(1-30)
|
||||
score_threshold: 0.0 # 调仓得分阈值(0-0.5,表示%)
|
||||
trade_cost: 0.001 # 单次换仓成本(0-0.01,即 0.1%)
|
||||
|
||||
# ============================================================
|
||||
# 溢价控制配置
|
||||
# ============================================================
|
||||
premium_control:
|
||||
enabled: true # 是否启用溢价控制
|
||||
default_threshold: 0.10 # 默认溢价阈值(10%)
|
||||
mode: "filter" # 控制模式: filter(排除)/ penalize(降权)
|
||||
penalty_factor: 0.5 # 降权惩罚系数
|
||||
|
||||
# 按市场覆盖配置
|
||||
market_overrides:
|
||||
CN_EQUITY: # A股 ETF
|
||||
enabled: false # 不启用(溢价通常 < 0.5%)
|
||||
HK_EQUITY: # 港股 ETF
|
||||
enabled: true
|
||||
threshold: 0.10 # 阈值 10%
|
||||
US_EQUITY: # 美股 ETF
|
||||
enabled: true
|
||||
threshold: 0.10 # 阈值 10%
|
||||
JP_EQUITY: # 日本 ETF
|
||||
enabled: true
|
||||
threshold: 0.10 # 阈值 10%
|
||||
EU_EQUITY: # 欧洲 ETF
|
||||
enabled: true
|
||||
threshold: 0.10 # 阈值 10%
|
||||
COMMODITY: # 商品 ETF
|
||||
enabled: false # 不启用
|
||||
|
||||
# ============================================================
|
||||
# 数据配置
|
||||
# ============================================================
|
||||
data:
|
||||
# 数据源列表(按优先级排序)
|
||||
sources:
|
||||
# 主数据源:Flask API
|
||||
- type: "flask_api"
|
||||
enabled: true
|
||||
url: "${FLASK_API_URL}" # 从环境变量读取
|
||||
timeout: 120
|
||||
444
framework_v2/config/schemas.py
Normal file
444
framework_v2/config/schemas.py
Normal file
@@ -0,0 +1,444 @@
|
||||
"""
|
||||
配置 Schema 定义
|
||||
|
||||
使用 Pydantic 验证配置文件的类型安全
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from typing import Optional, Dict, List, Literal
|
||||
from enum import Enum
|
||||
from datetime import date
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 枚举类型
|
||||
# ============================================================
|
||||
|
||||
# 注意:不再使用 MarketType 枚举,改用字符串类型的 group 字段
|
||||
# group 字段用于策略分组(组内竞争,强制分散)
|
||||
|
||||
|
||||
class FactorType(str, Enum):
|
||||
"""因子类型枚举"""
|
||||
MOMENTUM = "momentum"
|
||||
SLOPE_R2 = "slope_r2"
|
||||
WEIGHTED_MOMENTUM = "weighted_momentum"
|
||||
|
||||
|
||||
class PremiumMode(str, Enum):
|
||||
"""溢价控制模式"""
|
||||
FILTER = "filter" # 完全排除
|
||||
PENALIZE = "penalize" # 降权
|
||||
|
||||
|
||||
class ThresholdMode(str, Enum):
|
||||
"""阈值模式"""
|
||||
FIXED = "fixed" # 固定阈值
|
||||
DYNAMIC = "dynamic" # 动态阈值
|
||||
|
||||
|
||||
class DataSourceType(str, Enum):
|
||||
"""数据源类型"""
|
||||
FLASK_API = "flask_api"
|
||||
TUSHARE = "tushare"
|
||||
YFINANCE = "yfinance"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 资产池 Schema(扁平化设计)
|
||||
# ============================================================
|
||||
|
||||
class PremiumConfig(BaseModel):
|
||||
"""溢价控制配置"""
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值")
|
||||
|
||||
|
||||
class AssetConfig(BaseModel):
|
||||
"""
|
||||
标的配置(通用,支持所有场景)
|
||||
|
||||
场景 1:指数信号 → 指数收益
|
||||
signal_source: "NDX"
|
||||
trade_source: "NDX"
|
||||
|
||||
场景 2:指数信号 → ETF收益
|
||||
signal_source: "NDX"
|
||||
trade_source: "513100.SH"
|
||||
|
||||
场景 3:ETF信号 → ETF收益
|
||||
signal_source: "518880.SH"
|
||||
trade_source: "518880.SH"
|
||||
|
||||
场景 4:个股信号 → 个股收益
|
||||
signal_source: "AAPL"
|
||||
trade_source: "AAPL"
|
||||
"""
|
||||
name: str = Field(..., description="标的名称")
|
||||
group: str = Field(..., description="策略分组(组内竞争,强制分散)")
|
||||
|
||||
# 核心字段:信号来源和交易来源
|
||||
signal_source: str = Field(..., description="信号来源代码(计算因子用)")
|
||||
trade_source: str = Field(..., description="交易来源代码(计算收益用)")
|
||||
|
||||
# 可选字段
|
||||
description: Optional[str] = Field(None, description="标的描述")
|
||||
etf: Optional[str] = Field(None, description="ETF代码(兼容旧配置)")
|
||||
premium_control: Optional["PremiumConfig"] = Field(None, description="溢价控制配置")
|
||||
|
||||
@property
|
||||
def is_cross_market(self) -> bool:
|
||||
"""是否跨市场(信号和交易不同)"""
|
||||
return self.signal_source != self.trade_source
|
||||
|
||||
|
||||
class AssetPool(BaseModel):
|
||||
"""
|
||||
资产池(扁平化设计)
|
||||
|
||||
优势:
|
||||
1. 不预设分类,支持任意策略分组
|
||||
2. 通过 group 字段自动分组
|
||||
3. 配置简单直观
|
||||
4. 易于扩展新分组
|
||||
|
||||
示例:
|
||||
asset_pool:
|
||||
assets:
|
||||
"NDX":
|
||||
name: "纳指100"
|
||||
group: "US_TECH"
|
||||
signal_source: "NDX"
|
||||
trade_source: "513100.SH"
|
||||
|
||||
"BTC":
|
||||
name: "比特币"
|
||||
group: "CRYPTO"
|
||||
signal_source: "BTC"
|
||||
trade_source: "BTC"
|
||||
"""
|
||||
assets: Dict[str, AssetConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="所有标的(flat结构)"
|
||||
)
|
||||
|
||||
@property
|
||||
def by_group(self) -> Dict[str, Dict[str, AssetConfig]]:
|
||||
"""按策略分组"""
|
||||
groups = {}
|
||||
for code, asset in self.assets.items():
|
||||
group = asset.group
|
||||
if group not in groups:
|
||||
groups[group] = {}
|
||||
groups[group][code] = asset
|
||||
return groups
|
||||
|
||||
@property
|
||||
def groups(self) -> list:
|
||||
"""获取所有策略分组"""
|
||||
return list(self.by_group.keys())
|
||||
|
||||
def get_signal_codes(self, group: str = None) -> list:
|
||||
"""
|
||||
获取信号标的
|
||||
|
||||
Args:
|
||||
group: 可选,过滤特定分组(如 'US_TECH')
|
||||
"""
|
||||
if group:
|
||||
return [
|
||||
asset.signal_source
|
||||
for asset in self.assets.values()
|
||||
if asset.group == group
|
||||
]
|
||||
return [asset.signal_source for asset in self.assets.values()]
|
||||
|
||||
def get_trade_codes(self, group: str = None) -> list:
|
||||
"""获取交易标的"""
|
||||
if group:
|
||||
return [
|
||||
asset.trade_source
|
||||
for asset in self.assets.values()
|
||||
if asset.group == group
|
||||
]
|
||||
return [asset.trade_source for asset in self.assets.values()]
|
||||
|
||||
def get_signal_to_trade_mapping(self, group: str = None) -> Dict[str, str]:
|
||||
"""获取信号→交易映射"""
|
||||
mapping = {}
|
||||
for asset in self.assets.values():
|
||||
if group and asset.group != group:
|
||||
continue
|
||||
mapping[asset.signal_source] = asset.trade_source
|
||||
return mapping
|
||||
|
||||
def count(self, group: str = None) -> int:
|
||||
"""获取标的数量"""
|
||||
if group:
|
||||
return len([a for a in self.assets.values() if a.group == group])
|
||||
return len(self.assets)
|
||||
|
||||
@field_validator('assets')
|
||||
@classmethod
|
||||
def check_non_empty(cls, v):
|
||||
"""至少配置一个标的"""
|
||||
if not v:
|
||||
import warnings
|
||||
warnings.warn("资产池为空")
|
||||
return v
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 因子配置 Schema
|
||||
# ============================================================
|
||||
|
||||
class FactorConfig(BaseModel):
|
||||
"""因子配置"""
|
||||
type: FactorType = Field(default=FactorType.WEIGHTED_MOMENTUM, description="因子类型")
|
||||
n_days: int = Field(default=25, ge=5, le=250, description="动量窗口期(天数)")
|
||||
|
||||
# 动态周期参数
|
||||
auto_day: bool = Field(default=False, description="是否启用动态周期")
|
||||
min_days: int = Field(default=20, ge=5, le=100, description="最小周期")
|
||||
max_days: int = Field(default=60, ge=20, le=250, description="最大周期")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 阈值配置 Schema
|
||||
# ============================================================
|
||||
|
||||
class DynamicThresholdConfig(BaseModel):
|
||||
"""动态阈值配置"""
|
||||
reference: str = Field(..., description="参考标的代码")
|
||||
ratio: float = Field(default=1.0, gt=0, description="倍数")
|
||||
fallback_enabled: bool = Field(default=True, description="参考不可用时是否回退")
|
||||
fallback_value: float = Field(default=0.0, description="回退值")
|
||||
|
||||
|
||||
class ThresholdConfig(BaseModel):
|
||||
"""阈值配置(统一 V2/V3)"""
|
||||
mode: ThresholdMode = Field(default=ThresholdMode.DYNAMIC, description="阈值模式")
|
||||
fixed_value: float = Field(default=0.0, description="固定阈值(mode=fixed时使用)")
|
||||
dynamic: Optional[DynamicThresholdConfig] = Field(None, description="动态阈值(mode=dynamic时使用)")
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_dynamic_config(self):
|
||||
"""验证动态阈值配置"""
|
||||
if self.mode == ThresholdMode.DYNAMIC and self.dynamic is None:
|
||||
raise ValueError("dynamic 模式必须配置 dynamic 字段")
|
||||
return self
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 轮动配置 Schema
|
||||
# ============================================================
|
||||
|
||||
class GroupConfig(BaseModel):
|
||||
"""策略分组配置(用于分散化选股)"""
|
||||
group: str = Field(..., description="策略分组名称")
|
||||
select_num: int = Field(default=1, ge=1, le=10, description="该分组选股数量")
|
||||
|
||||
|
||||
class RotationConfig(BaseModel):
|
||||
"""
|
||||
轮动配置
|
||||
|
||||
支持两种选股模式:
|
||||
1. 全局选股:diversified=false,直接选 Top-N
|
||||
2. 分散化选股:diversified=true,按策略分组选股
|
||||
"""
|
||||
select_num: int = Field(default=3, ge=1, le=20, description="选股数量(全局模式)")
|
||||
diversified: bool = Field(default=False, description="是否分散化选股")
|
||||
|
||||
# 分散化配置(可选)
|
||||
diversification_groups: Optional[List[GroupConfig]] = Field(
|
||||
None,
|
||||
description="分散化分组配置(diversified=true时使用)"
|
||||
)
|
||||
|
||||
threshold: ThresholdConfig = Field(
|
||||
default_factory=ThresholdConfig,
|
||||
description="动量阈值"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 调仓配置 Schema
|
||||
# ============================================================
|
||||
|
||||
class RebalanceConfig(BaseModel):
|
||||
"""调仓配置"""
|
||||
min_hold_days: int = Field(default=1, ge=1, le=30, description="最低持有天数")
|
||||
score_threshold: float = Field(default=0.0, ge=0, le=0.5, description="调仓得分阈值(%)")
|
||||
trade_cost: float = Field(default=0.001, ge=0, le=0.01, description="单次换仓成本")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 溢价控制 Schema
|
||||
# ============================================================
|
||||
|
||||
class MarketPremiumOverride(BaseModel):
|
||||
"""市场溢价覆盖配置"""
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值")
|
||||
|
||||
|
||||
class PremiumControlConfig(BaseModel):
|
||||
"""溢价控制配置"""
|
||||
enabled: bool = Field(default=True, description="是否启用溢价控制")
|
||||
default_threshold: float = Field(default=0.10, ge=0, le=1.0, description="默认溢价阈值")
|
||||
mode: PremiumMode = Field(default=PremiumMode.FILTER, description="控制模式")
|
||||
penalty_factor: float = Field(default=0.5, ge=0, le=1.0, description="降权惩罚系数")
|
||||
|
||||
# 按市场覆盖
|
||||
market_overrides: Dict[str, MarketPremiumOverride] = Field(
|
||||
default_factory=dict,
|
||||
description="按市场覆盖配置"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 数据源配置 Schema
|
||||
# ============================================================
|
||||
|
||||
class DataSourceConfig(BaseModel):
|
||||
"""单个数据源配置"""
|
||||
type: DataSourceType = Field(..., description="数据源类型")
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
timeout: int = Field(default=120, ge=10, le=600, description="超时时间(秒)")
|
||||
|
||||
# Flask API 特定配置
|
||||
url: Optional[str] = Field(None, description="API地址(flask_api使用)")
|
||||
|
||||
# Tushare 特定配置
|
||||
token: Optional[str] = Field(None, description="Tushare Token(tushare使用)")
|
||||
|
||||
|
||||
class DataConfig(BaseModel):
|
||||
"""数据配置"""
|
||||
sources: List[DataSourceConfig] = Field(..., description="数据源列表(按优先级排序)")
|
||||
use_cache: bool = Field(default=True, description="是否使用本地缓存")
|
||||
cache_dir: str = Field(default="data_cache", description="缓存目录")
|
||||
|
||||
@field_validator('sources')
|
||||
@classmethod
|
||||
def check_at_least_one(cls, v):
|
||||
"""至少配置一个数据源"""
|
||||
if not v:
|
||||
raise ValueError("必须配置至少一个数据源")
|
||||
return v
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 基准配置 Schema
|
||||
# ============================================================
|
||||
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""基准配置"""
|
||||
code: str = Field(..., description="基准代码")
|
||||
name: str = Field(..., description="基准名称")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 回测配置 Schema
|
||||
# ============================================================
|
||||
|
||||
class BacktestConfig(BaseModel):
|
||||
"""回测配置"""
|
||||
start_date: str = Field(..., description="回测起始日期(YYYY-MM-DD)")
|
||||
end_date: Optional[str] = Field(None, description="回测结束日期(None表示至今)")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 元数据 Schema
|
||||
# ============================================================
|
||||
|
||||
class MetadataConfig(BaseModel):
|
||||
"""配置元数据"""
|
||||
version: str = Field(default="1.0.0", description="配置版本")
|
||||
strategy: str = Field(default="rotation", description="策略名称")
|
||||
description: str = Field(default="", description="配置描述")
|
||||
last_updated: Optional[str] = Field(None, description="最后更新日期")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 完整策略配置 Schema
|
||||
# ============================================================
|
||||
|
||||
class RotationStrategyConfig(BaseModel):
|
||||
"""
|
||||
ETF轮动策略完整配置
|
||||
|
||||
使用示例:
|
||||
from framework_v2.config.schemas import RotationStrategyConfig
|
||||
import yaml
|
||||
|
||||
with open('config/rotation.yaml') as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
config = RotationStrategyConfig(**config_dict)
|
||||
"""
|
||||
metadata: MetadataConfig = Field(default_factory=MetadataConfig, description="元数据")
|
||||
|
||||
# 资产池
|
||||
asset_pools: AssetPool = Field(..., description="资产池配置")
|
||||
|
||||
# 基准
|
||||
benchmark: BenchmarkConfig = Field(..., description="基准配置")
|
||||
|
||||
# 回测
|
||||
backtest: BacktestConfig = Field(..., description="回测配置")
|
||||
|
||||
# 因子
|
||||
factor: FactorConfig = Field(default_factory=FactorConfig, description="因子配置")
|
||||
|
||||
# 轮动
|
||||
rotation: RotationConfig = Field(default_factory=RotationConfig, description="轮动配置")
|
||||
|
||||
# 调仓
|
||||
rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig, description="调仓配置")
|
||||
|
||||
# 溢价控制
|
||||
premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig, description="溢价控制")
|
||||
|
||||
# 数据
|
||||
data: DataConfig = Field(..., description="数据配置")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 通用策略配置 Schema(V2 架构)
|
||||
# ============================================================
|
||||
|
||||
class StrategyConfig(BaseModel):
|
||||
"""
|
||||
通用策略配置(支持所有策略类型)
|
||||
|
||||
使用示例:
|
||||
from framework_v2.config import load_config
|
||||
config = load_config('rotation.yaml')
|
||||
"""
|
||||
metadata: MetadataConfig = Field(default_factory=MetadataConfig, description="元数据")
|
||||
|
||||
# 资产池
|
||||
asset_pools: AssetPool = Field(..., description="资产池配置")
|
||||
|
||||
# 基准
|
||||
benchmark: BenchmarkConfig = Field(..., description="基准配置")
|
||||
|
||||
# 回测
|
||||
backtest: BacktestConfig = Field(..., description="回测配置")
|
||||
|
||||
# 因子
|
||||
factor: FactorConfig = Field(default_factory=FactorConfig, description="因子配置")
|
||||
|
||||
# 轮动(可选)
|
||||
rotation: Optional[RotationConfig] = Field(None, description="轮动配置")
|
||||
|
||||
# 调仓
|
||||
rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig, description="调仓配置")
|
||||
|
||||
# 溢价控制
|
||||
premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig, description="溢价控制")
|
||||
|
||||
# 数据
|
||||
data: DataConfig = Field(..., description="数据配置")
|
||||
@@ -8,72 +8,115 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
from framework_v2.config.schemas import StrategyConfig
|
||||
|
||||
|
||||
class StrategyBase(ABC):
|
||||
"""
|
||||
策略抽象基类
|
||||
策略抽象基类(V2 增强版)
|
||||
|
||||
定义策略的标准生命周期:
|
||||
1. 初始化配置
|
||||
2. 获取数据
|
||||
3. 计算因子
|
||||
4. 生成信号
|
||||
5. 执行回测
|
||||
1. 初始化配置(使用 Pydantic Schema)
|
||||
2. 获取数据(支持多数据源)
|
||||
3. 计算因子(使用框架因子库)
|
||||
4. 生成信号(策略特定逻辑)
|
||||
5. 执行回测(框架通用执行器)
|
||||
|
||||
子类必须实现:
|
||||
- init_factors(): 初始化因子
|
||||
- init_signal_generator(): 初始化信号生成器
|
||||
- get_codes(): 获取标的列表
|
||||
- compute_factors(): 计算因子
|
||||
- generate_signals(): 生成信号
|
||||
- manage_positions(): 仓位管理
|
||||
"""
|
||||
|
||||
INTERFACE_VERSION = 2 # V2 版本
|
||||
|
||||
name: str = "base"
|
||||
timeframe: str = "1d"
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
def __init__(self, config: StrategyConfig):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置字典
|
||||
config: 策略配置(Pydantic Schema)
|
||||
"""
|
||||
self.config = config or {}
|
||||
self._factor = None
|
||||
self._signal_generator = None
|
||||
self.config = config
|
||||
self.name = config.metadata.strategy
|
||||
|
||||
# 组件将在子类中初始化
|
||||
self._data_fetcher = None
|
||||
self._executor = None
|
||||
|
||||
@abstractmethod
|
||||
def init_factors(self) -> Any:
|
||||
def get_codes(self) -> list:
|
||||
"""
|
||||
初始化因子组件
|
||||
获取标的列表(策略必须实现)
|
||||
|
||||
Returns:
|
||||
因子实例(继承 FactorBase)
|
||||
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def init_signal_generator(self) -> Any:
|
||||
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
|
||||
"""
|
||||
初始化信号生成器
|
||||
计算因子(策略必须实现)
|
||||
|
||||
Args:
|
||||
data: 数据字典 {code: DataFrame}
|
||||
|
||||
Returns:
|
||||
信号生成器实例(继承 SignalGenerator)
|
||||
因子字典 {code: Series}
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self) -> Dict[str, Any]:
|
||||
@abstractmethod
|
||||
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
|
||||
"""
|
||||
获取数据(可选覆盖)
|
||||
生成信号(策略必须实现)
|
||||
|
||||
Args:
|
||||
factors: 因子字典 {code: Series}
|
||||
|
||||
Returns:
|
||||
数据字典,包含:
|
||||
- index_data: 指数数据
|
||||
- etf_data: ETF数据
|
||||
- benchmark_data: 基准数据
|
||||
- valid_codes: 有效标的列表
|
||||
- trading_calendar: 交易日历
|
||||
信号 DataFrame(包含 'signal' 列,1=买入,0=空仓)
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement get_data()")
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
仓位管理(策略必须实现)
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
|
||||
Returns:
|
||||
仓位 DataFrame(包含 'weight' 列,权重和为 1)
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取数据(框架实现,策略可覆盖)
|
||||
|
||||
Returns:
|
||||
数据字典 {code: DataFrame}
|
||||
"""
|
||||
if self._data_fetcher is None:
|
||||
self._data_fetcher = self._create_data_fetcher()
|
||||
|
||||
codes = self.get_codes()
|
||||
|
||||
# 批量获取数据(fetch_indices 返回 {code: DataFrame})
|
||||
try:
|
||||
data = self._data_fetcher.fetch_indices(
|
||||
codes=codes,
|
||||
start=self.config.backtest.start_date,
|
||||
end=self.config.backtest.end_date
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
print(f" 错误: 数据获取失败 - {e}")
|
||||
return {}
|
||||
|
||||
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
|
||||
"""
|
||||
@@ -111,28 +154,46 @@ class StrategyBase(ABC):
|
||||
|
||||
return self._signal_generator.generate(factor_df)
|
||||
|
||||
def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
运行完整回测流程
|
||||
运行完整回测流程(框架标准流程)
|
||||
|
||||
Args:
|
||||
data: 可选,如不提供则自动获取
|
||||
|
||||
Returns:
|
||||
回测结果字典
|
||||
回测结果字典,包含:
|
||||
- equity_curve: 净值曲线
|
||||
- trades: 交易记录
|
||||
- metrics: 绩效指标
|
||||
"""
|
||||
# 1. 获取数据
|
||||
if data is None:
|
||||
print("[1/5] 获取数据...")
|
||||
data = self.get_data()
|
||||
print(f" 获取 {len(data)} 个标的")
|
||||
|
||||
# 2. 计算因子
|
||||
factor_df = self.compute_factors(data)
|
||||
print("[2/5] 计算因子...")
|
||||
factors = self.compute_factors(data)
|
||||
print(f" 计算 {len(factors)} 个因子")
|
||||
|
||||
# 3. 生成信号
|
||||
signals = self.generate_signals(factor_df)
|
||||
print("[3/5] 生成信号...")
|
||||
signals = self.generate_signals(factors)
|
||||
print(f" 生成 {signals.shape[0]} 个信号")
|
||||
|
||||
# 4. 执行回测(子类实现)
|
||||
return self._execute_backtest(signals, data)
|
||||
# 4. 仓位管理
|
||||
print("[4/5] 仓位管理...")
|
||||
positions = self.manage_positions(signals)
|
||||
print(f" 平均持仓: {positions['weight'].sum().mean():.2%}")
|
||||
|
||||
# 5. 执行回测
|
||||
print("[5/5] 执行回测...")
|
||||
result = self._execute_backtest(positions, data)
|
||||
print(f" 回测完成")
|
||||
|
||||
return result
|
||||
|
||||
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -147,5 +208,21 @@ class StrategyBase(ABC):
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement _execute_backtest()")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
def _create_data_fetcher(self):
|
||||
"""
|
||||
创建数据获取器(框架实现)
|
||||
|
||||
Returns:
|
||||
DataFetcher 实例
|
||||
"""
|
||||
from framework_v2.shared.data import FlaskAPIFetcher
|
||||
|
||||
# 使用配置中的第一个启用的数据源
|
||||
for source_config in self.config.data.sources:
|
||||
if source_config.enabled and source_config.type.value == 'flask_api':
|
||||
return FlaskAPIFetcher(
|
||||
base_url=source_config.url,
|
||||
timeout=source_config.timeout
|
||||
)
|
||||
|
||||
raise ValueError("未找到可用的数据源配置")
|
||||
|
||||
92
framework_v2/strategies/rotation/config_simple.yaml
Normal file
92
framework_v2/strategies/rotation/config_simple.yaml
Normal file
@@ -0,0 +1,92 @@
|
||||
# 简单轮动策略配置
|
||||
#
|
||||
# 配置版本: 1.0.0
|
||||
# 最后更新: 2024-04-16
|
||||
# 策略名称: simple_rotation
|
||||
# 描述: 基于动量因子的简单 ETF 轮动策略
|
||||
|
||||
# ============================================================
|
||||
# 元数据
|
||||
# ============================================================
|
||||
metadata:
|
||||
version: "1.0.0"
|
||||
strategy: "simple_rotation"
|
||||
description: "简单轮动策略 - 等权分配 + Top-N 选择"
|
||||
last_updated: "2024-04-16"
|
||||
|
||||
# ============================================================
|
||||
# 资产池配置(简化版:只选 3 个标的)
|
||||
# ============================================================
|
||||
asset_pools:
|
||||
equity:
|
||||
"399006.SZ":
|
||||
name: "创业板指"
|
||||
etf: "159915.SZ"
|
||||
market: "CN_EQUITY"
|
||||
description: "创业板指数"
|
||||
|
||||
"NDX":
|
||||
name: "纳指100"
|
||||
etf: "513100.SH"
|
||||
market: "US_EQUITY"
|
||||
description: "纳斯达克100指数"
|
||||
|
||||
commodity: {}
|
||||
fixed_income: {}
|
||||
|
||||
# ============================================================
|
||||
# 基准配置
|
||||
# ============================================================
|
||||
benchmark:
|
||||
code: "000300.SH"
|
||||
name: "沪深300"
|
||||
|
||||
# ============================================================
|
||||
# 回测配置
|
||||
# ============================================================
|
||||
backtest:
|
||||
start_date: "2023-01-01"
|
||||
end_date: "2024-12-31"
|
||||
|
||||
# ============================================================
|
||||
# 因子配置
|
||||
# ============================================================
|
||||
factor:
|
||||
type: "weighted_momentum" # 加权动量
|
||||
n_days: 25 # 25 天窗口
|
||||
|
||||
# ============================================================
|
||||
# 轮动配置
|
||||
# ============================================================
|
||||
rotation:
|
||||
select_num: 2 # 选择 Top-2
|
||||
threshold:
|
||||
mode: "fixed"
|
||||
fixed_value: 0.0 # 无阈值过滤
|
||||
|
||||
# ============================================================
|
||||
# 调仓配置
|
||||
# ============================================================
|
||||
rebalance:
|
||||
min_hold_days: 1
|
||||
score_threshold: 0.0
|
||||
trade_cost: 0.001 # 0.1% 交易成本
|
||||
|
||||
# ============================================================
|
||||
# 溢价控制(禁用)
|
||||
# ============================================================
|
||||
premium_control:
|
||||
enabled: false
|
||||
|
||||
# ============================================================
|
||||
# 数据配置
|
||||
# ============================================================
|
||||
data:
|
||||
sources:
|
||||
- type: "flask_api"
|
||||
enabled: true
|
||||
url: "${FLASK_API_URL}"
|
||||
timeout: 120
|
||||
|
||||
use_cache: true
|
||||
cache_dir: "data_cache"
|
||||
241
framework_v2/strategies/rotation/simple.py
Normal file
241
framework_v2/strategies/rotation/simple.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
简单轮动策略
|
||||
|
||||
基于动量因子的 ETF 轮动策略
|
||||
- 计算各标的动量得分
|
||||
- 选择 Top-N 标的
|
||||
- 等权分配仓位
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
|
||||
from framework_v2.core.strategy import StrategyBase
|
||||
from framework_v2.config.schemas import StrategyConfig
|
||||
from framework_v2.shared.factors import MomentumFactor
|
||||
|
||||
|
||||
class SimpleRotationStrategy(StrategyBase):
|
||||
"""
|
||||
简单轮动策略
|
||||
|
||||
策略逻辑:
|
||||
1. 计算各标的动量得分(加权线性回归)
|
||||
2. 选择得分最高的 Top-N 标的
|
||||
3. 等权分配仓位
|
||||
|
||||
示例:
|
||||
from framework_v2.config import load_config
|
||||
from framework_v2.strategies.rotation.simple import SimpleRotationStrategy
|
||||
|
||||
config = load_config('rotation_simple.yaml')
|
||||
strategy = SimpleRotationStrategy(config)
|
||||
result = strategy.run()
|
||||
"""
|
||||
|
||||
def __init__(self, config: StrategyConfig):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# 初始化动量因子
|
||||
self.momentum = MomentumFactor(
|
||||
n_days=config.factor.n_days,
|
||||
weighted=(config.factor.type.value == 'weighted_momentum')
|
||||
)
|
||||
|
||||
# 策略参数
|
||||
self.select_num = config.rotation.select_num if config.rotation else 3
|
||||
self.min_score = config.rotation.threshold.fixed_value if config.rotation else 0.0
|
||||
|
||||
def get_codes(self) -> list:
|
||||
"""
|
||||
获取标的列表
|
||||
|
||||
从配置的资产池中获取所有标的
|
||||
"""
|
||||
codes = []
|
||||
|
||||
# 股票资产
|
||||
if self.config.asset_pools.equity:
|
||||
codes.extend(self.config.asset_pools.equity.keys())
|
||||
|
||||
# 商品资产
|
||||
if self.config.asset_pools.commodity:
|
||||
codes.extend(self.config.asset_pools.commodity.keys())
|
||||
|
||||
# 固定收益资产
|
||||
if self.config.asset_pools.fixed_income:
|
||||
codes.extend(self.config.asset_pools.fixed_income.keys())
|
||||
|
||||
return codes
|
||||
|
||||
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
|
||||
"""
|
||||
计算动量因子
|
||||
|
||||
Args:
|
||||
data: 数据字典 {code: DataFrame}
|
||||
|
||||
Returns:
|
||||
因子字典 {code: Series}
|
||||
"""
|
||||
factors = {}
|
||||
|
||||
for code, df in data.items():
|
||||
try:
|
||||
# 计算动量得分
|
||||
factor_values = self.momentum.compute(df)
|
||||
factors[code] = factor_values
|
||||
except Exception as e:
|
||||
print(f" 警告: {code} 因子计算失败 - {e}")
|
||||
continue
|
||||
|
||||
return factors
|
||||
|
||||
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
|
||||
"""
|
||||
生成轮动信号
|
||||
|
||||
逻辑:
|
||||
1. 每个交易日选择动量得分最高的 Top-N 标的
|
||||
2. 过滤得分低于阈值的标的
|
||||
|
||||
Args:
|
||||
factors: 因子字典 {code: Series}
|
||||
|
||||
Returns:
|
||||
信号 DataFrame(index=日期, columns=标的, values=1或0)
|
||||
"""
|
||||
if not factors:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 对齐所有因子的日期
|
||||
factor_df = pd.DataFrame(factors)
|
||||
|
||||
# 生成信号
|
||||
signals = pd.DataFrame(index=factor_df.index, columns=factor_df.columns, data=0)
|
||||
|
||||
for date in factor_df.index:
|
||||
# 获取当日因子值
|
||||
scores = factor_df.loc[date].dropna()
|
||||
|
||||
if scores.empty:
|
||||
continue
|
||||
|
||||
# 过滤低分标的
|
||||
if self.min_score > 0:
|
||||
scores = scores[scores >= self.min_score]
|
||||
|
||||
# 选择 Top-N
|
||||
if len(scores) > self.select_num:
|
||||
top_codes = scores.nlargest(self.select_num).index
|
||||
else:
|
||||
top_codes = scores.index
|
||||
|
||||
# 标记信号
|
||||
signals.loc[date, top_codes] = 1
|
||||
|
||||
return signals.astype(int)
|
||||
|
||||
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
仓位管理(等权分配)
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
|
||||
Returns:
|
||||
仓位 DataFrame(包含 'weight' 列)
|
||||
"""
|
||||
positions = signals.astype(float).copy()
|
||||
|
||||
# 计算每个日期的权重
|
||||
for date in positions.index:
|
||||
signal_row = positions.loc[date]
|
||||
n_selected = signal_row.sum()
|
||||
|
||||
if n_selected > 0:
|
||||
# 等权分配
|
||||
positions.loc[date] = signal_row / n_selected
|
||||
else:
|
||||
# 空仓
|
||||
positions.loc[date] = 0
|
||||
|
||||
return positions
|
||||
|
||||
def _execute_backtest(self, positions: pd.DataFrame, data: Dict[str, pd.DataFrame]) -> Dict[str, any]:
|
||||
"""
|
||||
执行回测
|
||||
|
||||
Args:
|
||||
positions: 仓位 DataFrame
|
||||
data: 数据字典 {code: DataFrame}
|
||||
|
||||
Returns:
|
||||
回测结果字典
|
||||
"""
|
||||
# 提取收盘价
|
||||
close_prices = {}
|
||||
for code, df in data.items():
|
||||
if 'close' in df.columns:
|
||||
close_prices[code] = df['close']
|
||||
|
||||
close_df = pd.DataFrame(close_prices)
|
||||
|
||||
# 计算收益率
|
||||
returns = close_df.pct_change()
|
||||
|
||||
# 计算策略收益(仓位加权)
|
||||
# 注意:T+1 执行,今天的信号明天生效
|
||||
positions_delayed = positions.shift(1).fillna(0)
|
||||
strategy_returns = (positions_delayed * returns).sum(axis=1)
|
||||
|
||||
# 计算净值曲线
|
||||
equity_curve = (1 + strategy_returns).cumprod()
|
||||
|
||||
# 检查是否有数据
|
||||
if len(equity_curve) == 0:
|
||||
return {
|
||||
'equity_curve': equity_curve,
|
||||
'strategy_returns': strategy_returns,
|
||||
'positions': positions,
|
||||
'metrics': {
|
||||
'total_return': 0,
|
||||
'annual_return': 0,
|
||||
'max_drawdown': 0,
|
||||
'sharpe_ratio': 0,
|
||||
'n_days': 0,
|
||||
}
|
||||
}
|
||||
|
||||
# 计算绩效指标
|
||||
total_return = equity_curve.iloc[-1] / equity_curve.iloc[0] - 1
|
||||
n_days = len(strategy_returns)
|
||||
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
|
||||
|
||||
# 最大回撤
|
||||
cumulative_max = equity_curve.cummax()
|
||||
drawdown = (equity_curve - cumulative_max) / cumulative_max
|
||||
max_drawdown = drawdown.min()
|
||||
|
||||
# 夏普比率
|
||||
sharpe = strategy_returns.mean() / strategy_returns.std() * np.sqrt(252) if strategy_returns.std() > 0 else 0
|
||||
|
||||
return {
|
||||
'equity_curve': equity_curve,
|
||||
'strategy_returns': strategy_returns,
|
||||
'positions': positions,
|
||||
'metrics': {
|
||||
'total_return': total_return,
|
||||
'annual_return': annual_return,
|
||||
'max_drawdown': max_drawdown,
|
||||
'sharpe_ratio': sharpe,
|
||||
'n_days': n_days,
|
||||
}
|
||||
}
|
||||
285
framework_v2/tests/test_config.py
Normal file
285
framework_v2/tests/test_config.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
测试配置加载和验证
|
||||
|
||||
验证:
|
||||
1. 配置文件加载
|
||||
2. Pydantic Schema 验证
|
||||
3. 环境变量替换
|
||||
4. 错误处理
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.config import load_config, ConfigLoader
|
||||
from framework_v2.config.schemas import RotationStrategyConfig, MarketType
|
||||
|
||||
|
||||
def test_load_config():
|
||||
"""测试 1: 加载配置文件"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 1: 加载配置文件")
|
||||
print("=" * 70)
|
||||
|
||||
# 设置环境变量(模拟)
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
os.environ['TUSHARE_TOKEN'] = 'test_token_123'
|
||||
|
||||
# 加载配置
|
||||
config = load_config('rotation_example.yaml')
|
||||
|
||||
print(f"\n✓ 配置加载成功")
|
||||
print(f" 版本: {config.metadata.version}")
|
||||
print(f" 策略: {config.metadata.strategy}")
|
||||
print(f" 资产池: {len(config.asset_pools.equity)} 股票, "
|
||||
f"{len(config.asset_pools.commodity)} 商品, "
|
||||
f"{len(config.asset_pools.fixed_income)} 债券")
|
||||
|
||||
# 验证基本字段
|
||||
assert config.metadata.version == "1.0.0"
|
||||
assert config.factor.n_days == 25
|
||||
assert config.rotation.select_num == 3
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_asset_pools():
|
||||
"""测试 2: 资产池配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 2: 资产池配置")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_example.yaml')
|
||||
|
||||
# 验证股票资产
|
||||
print(f"\n股票资产 ({len(config.asset_pools.equity)} 只):")
|
||||
for code, asset in config.asset_pools.equity.items():
|
||||
print(f" {code}: {asset.name} ({asset.market.value})")
|
||||
if asset.etf:
|
||||
print(f" ETF: {asset.etf}")
|
||||
|
||||
# 验证商品资产
|
||||
print(f"\n商品资产 ({len(config.asset_pools.commodity)} 只):")
|
||||
for code, asset in config.asset_pools.commodity.items():
|
||||
print(f" {code}: {asset.name}")
|
||||
|
||||
# 验证固定收益
|
||||
print(f"\n固定收益 ({len(config.asset_pools.fixed_income)} 只):")
|
||||
for code, asset in config.asset_pools.fixed_income.items():
|
||||
print(f" {code}: {asset.name} (ETF: {asset.etf})")
|
||||
|
||||
# 验证市场类型
|
||||
assert config.asset_pools.equity["399006.SZ"].market == MarketType.CN_EQUITY
|
||||
assert config.asset_pools.equity["NDX"].market == MarketType.US_EQUITY
|
||||
assert config.asset_pools.commodity["GC=F"].market == MarketType.COMMODITY
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_threshold_config():
|
||||
"""测试 3: 阈值配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 3: 阈值配置")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_example.yaml')
|
||||
|
||||
print(f"\n阈值模式: {config.rotation.threshold.mode}")
|
||||
print(f" 参考标的: {config.rotation.threshold.dynamic.reference}")
|
||||
print(f" 倍数: {config.rotation.threshold.dynamic.ratio}")
|
||||
print(f" 回退启用: {config.rotation.threshold.dynamic.fallback_enabled}")
|
||||
|
||||
assert config.rotation.threshold.mode.value == "dynamic"
|
||||
assert config.rotation.threshold.dynamic.reference == "931862.CSI"
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_data_sources():
|
||||
"""测试 4: 数据源配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 4: 数据源配置")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_example.yaml')
|
||||
|
||||
print(f"\n数据源 ({len(config.data.sources)} 个):")
|
||||
for i, source in enumerate(config.data.sources, 1):
|
||||
print(f" {i}. {source.type.value}")
|
||||
print(f" 启用: {source.enabled}")
|
||||
print(f" 超时: {source.timeout}s")
|
||||
if source.url:
|
||||
print(f" URL: {source.url}")
|
||||
|
||||
# 验证环境变量替换
|
||||
flask_api_source = config.data.sources[0]
|
||||
assert flask_api_source.url == 'https://k3s.tokenpluse.xyz'
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_validation_errors():
|
||||
"""测试 5: 验证错误处理"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 5: 验证错误处理")
|
||||
print("=" * 70)
|
||||
|
||||
# 测试 1: n_days 超出范围
|
||||
print("\n[5.1] 测试 n_days 超出范围...")
|
||||
try:
|
||||
from framework_v2.config.schemas import FactorConfig
|
||||
|
||||
# n_days = 1000(超出 5-250 范围)
|
||||
invalid_config = {
|
||||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||||
"benchmark": {"code": "000300.SH", "name": "沪深300"},
|
||||
"backtest": {"start_date": "2020-01-01"},
|
||||
"factor": {"n_days": 1000}, # 错误:超出范围
|
||||
"data": {
|
||||
"sources": [{"type": "flask_api", "url": "test"}]
|
||||
}
|
||||
}
|
||||
|
||||
RotationStrategyConfig(**invalid_config)
|
||||
print(" ✗ 应该抛出验证错误")
|
||||
assert False
|
||||
except Exception as e:
|
||||
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
|
||||
|
||||
# 测试 2: 缺少必需字段
|
||||
print("\n[5.2] 测试缺少必需字段...")
|
||||
try:
|
||||
invalid_config = {
|
||||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||||
# 缺少 benchmark
|
||||
"backtest": {"start_date": "2020-01-01"},
|
||||
"data": {
|
||||
"sources": [{"type": "flask_api", "url": "test"}]
|
||||
}
|
||||
}
|
||||
|
||||
RotationStrategyConfig(**invalid_config)
|
||||
print(" ✗ 应该抛出验证错误")
|
||||
assert False
|
||||
except Exception as e:
|
||||
print(f" ✓ 正确捕获验证错误: {type(e).__name__}")
|
||||
|
||||
# 测试 3: 环境变量未设置
|
||||
print("\n[5.3] 测试环境变量未设置...")
|
||||
try:
|
||||
# 删除环境变量
|
||||
old_value = os.environ.pop('FLASK_API_URL', None)
|
||||
|
||||
invalid_config = {
|
||||
"asset_pools": {"equity": {}, "commodity": {}, "fixed_income": {}},
|
||||
"benchmark": {"code": "000300.SH", "name": "沪深300"},
|
||||
"backtest": {"start_date": "2020-01-01"},
|
||||
"data": {
|
||||
"sources": [{"type": "flask_api", "url": "${FLASK_API_URL}"}]
|
||||
}
|
||||
}
|
||||
|
||||
RotationStrategyConfig(**invalid_config)
|
||||
print(" ✗ 应该抛出验证错误")
|
||||
assert False
|
||||
except ValueError as e:
|
||||
print(f" ✓ 正确捕获环境变量错误: {e}")
|
||||
finally:
|
||||
# 恢复环境变量
|
||||
if old_value:
|
||||
os.environ['FLASK_API_URL'] = old_value
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_env_substitution():
|
||||
"""测试 6: 环境变量替换"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 6: 环境变量替换")
|
||||
print("=" * 70)
|
||||
|
||||
loader = ConfigLoader()
|
||||
|
||||
# 测试 1: 基本替换
|
||||
print("\n[6.1] 基本替换...")
|
||||
os.environ['TEST_VAR'] = 'test_value'
|
||||
|
||||
config = {
|
||||
"url": "${TEST_VAR}"
|
||||
}
|
||||
result = loader._substitute_env_vars(config)
|
||||
assert result["url"] == "test_value"
|
||||
print(f" ✓ ${{TEST_VAR}} → {result['url']}")
|
||||
|
||||
# 测试 2: 默认值
|
||||
print("\n[6.2] 默认值...")
|
||||
config = {
|
||||
"url": "${NON_EXISTENT_VAR:default_value}"
|
||||
}
|
||||
result = loader._substitute_env_vars(config)
|
||||
assert result["url"] == "default_value"
|
||||
print(f" ${{NON_EXISTENT_VAR:default_value}} → {result['url']}")
|
||||
|
||||
# 测试 3: 嵌套结构
|
||||
print("\n[6.3] 嵌套结构...")
|
||||
os.environ['API_URL'] = 'https://api.example.com'
|
||||
|
||||
config = {
|
||||
"data": {
|
||||
"sources": [
|
||||
{"url": "${API_URL}"}
|
||||
]
|
||||
}
|
||||
}
|
||||
result = loader._substitute_env_vars(config)
|
||||
assert result["data"]["sources"][0]["url"] == "https://api.example.com"
|
||||
print(f" ✓ 嵌套替换成功")
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "=" * 70)
|
||||
print(" 配置加载和验证测试")
|
||||
print("=" * 70)
|
||||
|
||||
tests = [
|
||||
("加载配置文件", test_load_config),
|
||||
("资产池配置", test_asset_pools),
|
||||
("阈值配置", test_threshold_config),
|
||||
("数据源配置", test_data_sources),
|
||||
("验证错误处理", test_validation_errors),
|
||||
("环境变量替换", test_env_substitution),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for name, test_func in tests:
|
||||
try:
|
||||
test_func()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {name}")
|
||||
print(f" 错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试总结")
|
||||
print("=" * 70)
|
||||
print(f" ✓ 通过 - {passed}")
|
||||
if failed > 0:
|
||||
print(f" ✗ 失败 - {failed}")
|
||||
print(f"\n总计: {passed}/{passed + failed} 通过")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
if failed > 0:
|
||||
sys.exit(1)
|
||||
281
framework_v2/tests/test_flat_asset_pool.py
Normal file
281
framework_v2/tests/test_flat_asset_pool.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
测试扁平化资产池配置
|
||||
|
||||
验证:
|
||||
1. 扁平化配置加载
|
||||
2. 按市场分组
|
||||
3. 信号/交易标的获取
|
||||
4. 跨市场映射
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.config import load_config
|
||||
from framework_v2.config.schemas import GroupConfig
|
||||
|
||||
|
||||
def test_flat_config_load():
|
||||
"""测试 1: 加载扁平化配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 1: 加载扁平化配置")
|
||||
print("=" * 70)
|
||||
|
||||
# 设置环境变量
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
os.environ['TUSHARE_TOKEN'] = 'test_token'
|
||||
|
||||
# 加载配置
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
print(f"\n✓ 配置加载成功")
|
||||
print(f" 版本: {config.metadata.version}")
|
||||
print(f" 策略: {config.metadata.strategy}")
|
||||
print(f" 总标的数: {config.asset_pools.count()}")
|
||||
|
||||
# 验证基本字段
|
||||
assert config.metadata.version == "2.0.0"
|
||||
assert config.asset_pools.count() == 13 # 12 个标的
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_market_grouping():
|
||||
"""测试 2: 按市场分组"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 2: 按市场分组")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 获取所有市场
|
||||
markets = config.asset_pools.groups
|
||||
print(f"\n市场类型 ({len(markets)} 个):")
|
||||
for market in markets:
|
||||
count = config.asset_pools.count(market)
|
||||
print(f" {market}: {count} 只")
|
||||
|
||||
# 按市场分组
|
||||
by_group = config.asset_pools.by_group
|
||||
print(f"\n市场分组:")
|
||||
for market, assets in by_group.items():
|
||||
print(f"\n {market} ({len(assets)} 只):")
|
||||
for code, asset in assets.items():
|
||||
print(f" {code}: {asset.name}")
|
||||
|
||||
# 验证市场数量
|
||||
assert len(markets) == 7 # US_TECH, CN_GROWTH, JP_BROAD, EU_BROAD, HK_TECH, COMMODITY, FIXED_INCOME # US_EQUITY, CN_EQUITY, JP_EQUITY, EU_EQUITY, HK_EQUITY, COMMODITY, FIXED_INCOME
|
||||
assert config.asset_pools.count('US_TECH') == 2
|
||||
assert config.asset_pools.count('CN_GROWTH') == 3
|
||||
assert config.asset_pools.count('COMMODITY') == 3
|
||||
assert config.asset_pools.count('FIXED_INCOME') == 1
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_signal_trade_codes():
|
||||
"""测试 3: 信号和交易标的"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 3: 信号和交易标的")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 获取所有信号标的
|
||||
signal_codes = config.asset_pools.get_signal_codes()
|
||||
print(f"\n信号标的 (13 个):")
|
||||
for code in signal_codes:
|
||||
print(f" {code}")
|
||||
|
||||
# 获取所有交易标的
|
||||
trade_codes = config.asset_pools.get_trade_codes()
|
||||
print(f"\n交易标的 (13 个):")
|
||||
for code in trade_codes:
|
||||
print(f" {code}")
|
||||
|
||||
# 获取特定市场的信号标的
|
||||
us_signals = config.asset_pools.get_signal_codes('US_TECH')
|
||||
print(f"\n美股信号标的: {us_signals}")
|
||||
|
||||
# 验证
|
||||
assert len(signal_codes) == 13
|
||||
assert len(trade_codes) == 13
|
||||
assert 'NDX' in signal_codes
|
||||
assert '513100.SH' in trade_codes
|
||||
assert len(us_signals) == 2
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_signal_to_trade_mapping():
|
||||
"""测试 4: 信号→交易映射"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 4: 信号→交易映射")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 获取映射
|
||||
mapping = config.asset_pools.get_signal_to_trade_mapping()
|
||||
|
||||
print(f"\n信号→交易映射:")
|
||||
for signal, trade in mapping.items():
|
||||
asset = config.asset_pools.assets.get(signal)
|
||||
cross_market = "✗" if asset.signal_source == asset.trade_source else "✓"
|
||||
print(f" {cross_market} {signal} → {trade}")
|
||||
|
||||
# 验证跨市场标的
|
||||
print(f"\n跨市场标的:")
|
||||
for code, asset in config.asset_pools.assets.items():
|
||||
if asset.is_cross_market:
|
||||
print(f" {code}: {asset.signal_source} → {asset.trade_source}")
|
||||
|
||||
# 验证映射
|
||||
assert mapping['NDX'] == '513100.SH'
|
||||
assert mapping['399006.SZ'] == '159915.SZ'
|
||||
assert mapping['GC=F'] == '518880.SH'
|
||||
assert mapping['931862.CSI'] == '931862.CSI' # 非跨市场
|
||||
|
||||
# 验证跨市场属性
|
||||
assert config.asset_pools.assets['NDX'].is_cross_market == True
|
||||
assert config.asset_pools.assets['931862.CSI'].is_cross_market == False
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_market_specific_mapping():
|
||||
"""测试 5: 特定市场映射"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 5: 特定市场映射")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 获取美股映射
|
||||
us_mapping = config.asset_pools.get_signal_to_trade_mapping('US_TECH')
|
||||
print(f"\n美股映射:")
|
||||
for signal, trade in us_mapping.items():
|
||||
print(f" {signal} → {trade}")
|
||||
|
||||
# 获取商品映射
|
||||
commodity_mapping = config.asset_pools.get_signal_to_trade_mapping('COMMODITY')
|
||||
print(f"\n商品映射:")
|
||||
for signal, trade in commodity_mapping.items():
|
||||
print(f" {signal} → {trade}")
|
||||
|
||||
# 验证
|
||||
assert len(us_mapping) == 2
|
||||
assert len(commodity_mapping) == 3
|
||||
assert us_mapping['NDX'] == '513100.SH'
|
||||
assert commodity_mapping['GC=F'] == '518880.SH'
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_diversification_config():
|
||||
"""测试 6: 分散化配置"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 6: 分散化配置")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
print(f"\n轮动配置:")
|
||||
print(f" 选股数量: {config.rotation.select_num}")
|
||||
print(f" 分散化: {config.rotation.diversified}")
|
||||
print(f" 分散化分组: {config.rotation.diversification_groups}")
|
||||
|
||||
# 验证默认配置(全局模式)
|
||||
assert config.rotation.select_num == 5
|
||||
assert config.rotation.diversified == False
|
||||
assert config.rotation.diversification_groups is None
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
def test_asset_config_details():
|
||||
"""测试 7: 标的配置详情"""
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试 7: 标的配置详情")
|
||||
print("=" * 70)
|
||||
|
||||
config = load_config('rotation_global.yaml')
|
||||
|
||||
# 检查纳指配置
|
||||
ndx = config.asset_pools.assets['NDX']
|
||||
print(f"\n纳指100 配置:")
|
||||
print(f" 名称: {ndx.name}")
|
||||
print(f" 市场: {ndx.group}")
|
||||
print(f" 信号来源: {ndx.signal_source}")
|
||||
print(f" 交易来源: {ndx.trade_source}")
|
||||
print(f" 跨市场: {ndx.is_cross_market}")
|
||||
print(f" 描述: {ndx.description}")
|
||||
|
||||
# 检查短债配置
|
||||
bond = config.asset_pools.assets['931862.CSI']
|
||||
print(f"\n短债指数 配置:")
|
||||
print(f" 名称: {bond.name}")
|
||||
print(f" 市场: {bond.group}")
|
||||
print(f" 信号来源: {bond.signal_source}")
|
||||
print(f" 交易来源: {bond.trade_source}")
|
||||
print(f" 跨市场: {bond.is_cross_market}")
|
||||
|
||||
# 验证
|
||||
assert ndx.name == "纳指100"
|
||||
assert ndx.group == 'US_TECH'
|
||||
assert ndx.signal_source == "NDX"
|
||||
assert ndx.trade_source == "513100.SH"
|
||||
assert ndx.is_cross_market == True
|
||||
|
||||
assert bond.signal_source == bond.trade_source
|
||||
assert bond.is_cross_market == False
|
||||
|
||||
print("\n✓ 测试通过")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "=" * 70)
|
||||
print(" 扁平化资产池配置测试")
|
||||
print("=" * 70)
|
||||
|
||||
tests = [
|
||||
("加载扁平化配置", test_flat_config_load),
|
||||
("按市场分组", test_market_grouping),
|
||||
("信号和交易标的", test_signal_trade_codes),
|
||||
("信号→交易映射", test_signal_to_trade_mapping),
|
||||
("特定市场映射", test_market_specific_mapping),
|
||||
("分散化配置", test_diversification_config),
|
||||
("标的配置详情", test_asset_config_details),
|
||||
]
|
||||
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for name, test_func in tests:
|
||||
try:
|
||||
test_func()
|
||||
passed += 1
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {name}")
|
||||
print(f" 错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" 测试总结")
|
||||
print("=" * 70)
|
||||
print(f" ✓ 通过 - {passed}")
|
||||
if failed > 0:
|
||||
print(f" ✗ 失败 - {failed}")
|
||||
print(f"\n总计: {passed}/{passed + failed} 通过")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
if failed > 0:
|
||||
sys.exit(1)
|
||||
128
framework_v2/tests/test_simple_rotation.py
Normal file
128
framework_v2/tests/test_simple_rotation.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
测试简单轮动策略
|
||||
|
||||
验证完整流程:
|
||||
1. 配置加载
|
||||
2. 策略初始化
|
||||
3. 数据获取
|
||||
4. 因子计算
|
||||
5. 信号生成
|
||||
6. 回测执行
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from framework_v2.config import load_config
|
||||
from framework_v2.strategies.rotation.simple import SimpleRotationStrategy
|
||||
|
||||
|
||||
def test_simple_rotation():
|
||||
"""测试简单轮动策略完整流程"""
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" 简单轮动策略端到端测试")
|
||||
print("=" * 70)
|
||||
|
||||
# 设置环境变量
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
|
||||
# 1. 加载配置
|
||||
print("\n[1/6] 加载配置...")
|
||||
config_path = Path(__file__).parent.parent / 'strategies' / 'rotation' / 'config_simple.yaml'
|
||||
config = load_config(str(config_path))
|
||||
print(f" ✓ 配置加载成功")
|
||||
print(f" 策略: {config.metadata.strategy}")
|
||||
print(f" 标的: {list(config.asset_pools.equity.keys())}")
|
||||
print(f" 回测: {config.backtest.start_date} ~ {config.backtest.end_date}")
|
||||
|
||||
# 2. 初始化策略
|
||||
print("\n[2/6] 初始化策略...")
|
||||
strategy = SimpleRotationStrategy(config)
|
||||
print(f" ✓ 策略初始化成功")
|
||||
print(f" 名称: {strategy.name}")
|
||||
print(f" 动量窗口: {config.factor.n_days} 天")
|
||||
print(f" 选股数量: {strategy.select_num}")
|
||||
|
||||
# 3. 获取数据
|
||||
print("\n[3/6] 获取数据...")
|
||||
codes = strategy.get_codes()
|
||||
print(f" 标的列表: {codes}")
|
||||
|
||||
data = strategy.get_data()
|
||||
print(f" ✓ 获取 {len(data)} 个标的")
|
||||
for code, df in data.items():
|
||||
print(f" {code}: {len(df)} 天 ({df.index[0].date()} ~ {df.index[-1].date()})")
|
||||
|
||||
# 4. 计算因子
|
||||
print("\n[4/6] 计算因子...")
|
||||
factors = strategy.compute_factors(data)
|
||||
print(f" ✓ 计算 {len(factors)} 个因子")
|
||||
for code, factor in factors.items():
|
||||
print(f" {code}: {len(factor)} 值, 范围 [{factor.min():.4f}, {factor.max():.4f}]")
|
||||
|
||||
# 5. 生成信号
|
||||
print("\n[5/6] 生成信号...")
|
||||
signals = strategy.generate_signals(factors)
|
||||
n_signals = signals.sum().sum()
|
||||
print(f" ✓ 生成 {signals.shape[0]} 个交易日信号")
|
||||
print(f" 总信号数: {n_signals}")
|
||||
print(f" 平均每日持仓: {signals.mean().mean():.2%}")
|
||||
|
||||
# 6. 仓位管理
|
||||
print("\n[6/6] 仓位管理...")
|
||||
positions = strategy.manage_positions(signals)
|
||||
print(f" ✓ 仓位分配完成")
|
||||
print(f" 权重和: {positions.sum(axis=1).mean():.2%}")
|
||||
|
||||
# 7. 执行回测
|
||||
print("\n执行回测...")
|
||||
result = strategy._execute_backtest(positions, data)
|
||||
|
||||
# 打印结果
|
||||
print("\n" + "=" * 70)
|
||||
print(" 回测结果")
|
||||
print("=" * 70)
|
||||
|
||||
metrics = result['metrics']
|
||||
print(f"\n 总收益率: {metrics['total_return']:.2%}")
|
||||
print(f" 年化收益: {metrics['annual_return']:.2%}")
|
||||
print(f" 最大回撤: {metrics['max_drawdown']:.2%}")
|
||||
print(f" 夏普比率: {metrics['sharpe_ratio']:.2f}")
|
||||
print(f" 交易天数: {metrics['n_days']}")
|
||||
|
||||
# 验证结果
|
||||
print("\n" + "=" * 70)
|
||||
print(" 验证")
|
||||
print("=" * 70)
|
||||
|
||||
assert metrics['total_return'] != 0, "总收益率不应为 0"
|
||||
print(" ✓ 总收益率有效")
|
||||
|
||||
assert len(result['equity_curve']) > 0, "净值曲线不应为空"
|
||||
print(" ✓ 净值曲线有效")
|
||||
|
||||
assert positions.sum(axis=1).max() <= 1.01, "权重和不应超过 100%"
|
||||
print(" ✓ 仓位权重有效")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" ✓ 所有测试通过")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
result = test_simple_rotation()
|
||||
except Exception as e:
|
||||
print(f"\n✗ 测试失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
477
scripts/export_backtest_detail.py
Normal file
477
scripts/export_backtest_detail.py
Normal file
@@ -0,0 +1,477 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
导出回测逐日明细到 JSON,供 HTML 回放器加载。
|
||||
|
||||
模式 B:指数信号 + ETF 收益(2020-01-01 ~ 2026-05-19)
|
||||
|
||||
用法:
|
||||
python scripts/export_backtest_detail.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import yaml
|
||||
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from datasource.tushare_source import TushareSource
|
||||
from datasource.flask_api_source import FlaskAPIDataSource
|
||||
from strategies.shared.factors.momentum import MomentumFactor
|
||||
from strategies.shared.signals.selectors import TopNSelector
|
||||
from framework.execution import BacktestExecutor
|
||||
|
||||
# ==================== 加载配置 ====================
|
||||
config_path = project_root / 'strategies' / 'rotation' / 'config.yaml'
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
CODE_LIST = config['code_list']
|
||||
SELECT_NUM = config['select_num']
|
||||
N_DAYS = config['n_days']
|
||||
TRADE_COST = config['trade_cost']
|
||||
BOND_THRESHOLD = config.get('bond_threshold', {})
|
||||
BOND_CODE = BOND_THRESHOLD.get('bond_code', '931862.CSI')
|
||||
BOND_RATIO = BOND_THRESHOLD.get('ratio', 1.0)
|
||||
|
||||
|
||||
def fetch_all_data(start_date='2018-01-01', end_date='2026-05-19'):
|
||||
ts = TushareSource()
|
||||
api = FlaskAPIDataSource() # 默认使用 k3s.tokenpluse.xyz
|
||||
|
||||
index_data = {}
|
||||
etf_data = {}
|
||||
etf_code_map = {}
|
||||
|
||||
# 统一使用 Flask API 获取所有指数数据(与 strategy.py 保持一致)
|
||||
print("[指数数据] - 通过 Flask API (k3s服务) 获取")
|
||||
index_codes = list(CODE_LIST.keys())
|
||||
index_ohlcv_data = api.fetch_batch(index_codes, start_date, end_date)
|
||||
|
||||
for code, df in index_ohlcv_data.items():
|
||||
if df is not None and 'close' in df.columns and len(df) > 0:
|
||||
index_data[code] = df
|
||||
name = CODE_LIST.get(code, {}).get('name', code)
|
||||
print(f" {code} ({name})... {len(df)}天")
|
||||
else:
|
||||
name = CODE_LIST.get(code, {}).get('name', code)
|
||||
print(f" {code} ({name})... 失败")
|
||||
|
||||
print("\n[ETF数据]")
|
||||
etf_nav_data = {}
|
||||
for code, cfg in CODE_LIST.items():
|
||||
etf_code = cfg.get('etf')
|
||||
if etf_code is None:
|
||||
continue
|
||||
etf_code_map[code] = etf_code
|
||||
name = cfg['name']
|
||||
print(f" {etf_code} ({name})...", end=' ')
|
||||
|
||||
df = ts.fetch_etf_adj(etf_code, start_date, end_date)
|
||||
if df is not None and 'close_hfq' in df.columns and len(df) > 0:
|
||||
adj_ratio = df['close_hfq'] / df['close']
|
||||
df['open_hfq'] = df['open'] * adj_ratio
|
||||
etf_data[code] = df
|
||||
print(f"{len(df)}天", end='')
|
||||
else:
|
||||
print("失败")
|
||||
continue
|
||||
|
||||
# 获取ETF净值(用于计算溢价率)
|
||||
nav_df = ts.fetch_etf_nav(etf_code, start_date, end_date)
|
||||
if nav_df is not None and 'nav' in nav_df.columns and len(nav_df) > 0:
|
||||
etf_nav_data[code] = nav_df['nav']
|
||||
print(f" nav={len(nav_df)}天")
|
||||
else:
|
||||
print(" nav=无")
|
||||
|
||||
return index_data, etf_data, etf_code_map, etf_nav_data
|
||||
|
||||
|
||||
def compute_factors(price_data, n_days, trade_dates):
|
||||
"""先在原始交易日历上计算因子,再 ffill 对齐到 A 股日历(与 strategy.py 一致)"""
|
||||
factor = MomentumFactor(n_days=n_days, weighted=True, crash_filter=True)
|
||||
factor_values = {}
|
||||
for code, df in price_data.items():
|
||||
if 'close' not in df.columns:
|
||||
continue
|
||||
close_series = df['close'].dropna()
|
||||
if len(close_series) == 0:
|
||||
continue
|
||||
values = factor.compute(pd.DataFrame({'close': close_series}))
|
||||
factor_values[code] = values.reindex(trade_dates, method='ffill')
|
||||
return pd.DataFrame(factor_values)
|
||||
|
||||
|
||||
def generate_signals(factor_df, group_mapping):
|
||||
selector = TopNSelector(
|
||||
select_num=SELECT_NUM,
|
||||
group_mapping=group_mapping,
|
||||
min_score=0.0,
|
||||
rebalance_days=1,
|
||||
rebalance_threshold=0.0,
|
||||
bond_threshold_config=BOND_THRESHOLD
|
||||
)
|
||||
return selector.generate(factor_df)
|
||||
|
||||
|
||||
def safe_val(v, decimals=4):
|
||||
if v is None or (isinstance(v, float) and (math.isnan(v) or math.isinf(v))):
|
||||
return None
|
||||
if isinstance(v, (np.floating, float)):
|
||||
return round(float(v), decimals)
|
||||
if isinstance(v, (np.integer, int)):
|
||||
return int(v)
|
||||
return v
|
||||
|
||||
|
||||
def main():
|
||||
from datetime import datetime
|
||||
backtest_start = '2020-01-01'
|
||||
backtest_end = datetime.now().strftime('%Y-%m-%d') # 动态获取当前日期
|
||||
|
||||
print("=" * 60)
|
||||
print(" 导出回测逐日明细 (模式B: 指数信号 + ETF收益)")
|
||||
print("=" * 60)
|
||||
|
||||
# 1. 获取数据
|
||||
print("\n[1] 获取数据...")
|
||||
index_data, etf_data, etf_code_map, etf_nav_data = fetch_all_data()
|
||||
|
||||
# 2. A股交易日历
|
||||
print("\n[2] 获取A股交易日历...")
|
||||
ts = TushareSource()
|
||||
a_share_dates = ts.fetch_trade_cal(backtest_start, backtest_end)
|
||||
print(f" {len(a_share_dates)} 天")
|
||||
|
||||
# 3. 分组映射
|
||||
group_mapping = {}
|
||||
for code, cfg in CODE_LIST.items():
|
||||
if isinstance(cfg, dict):
|
||||
group_mapping[code] = cfg.get('market', 'default')
|
||||
|
||||
valid_codes = [c for c in CODE_LIST if c in index_data]
|
||||
|
||||
# 4. 计算因子(指数信号)
|
||||
print("\n[3] 计算指数动量因子...")
|
||||
idx_price_data = {}
|
||||
for code in valid_codes:
|
||||
if code in index_data and 'close' in index_data[code].columns:
|
||||
idx_price_data[code] = index_data[code]
|
||||
factor_df = compute_factors(idx_price_data, N_DAYS, a_share_dates)
|
||||
print(f" {len(factor_df.columns)} 只, {len(factor_df)} 天")
|
||||
|
||||
# 5. 生成信号
|
||||
print("\n[4] 生成信号...")
|
||||
signals = generate_signals(factor_df, group_mapping)
|
||||
print(f" {len(signals)} 天")
|
||||
|
||||
# 6. 准备ETF收益率(模式B)
|
||||
print("\n[5] 准备ETF收益率...")
|
||||
etf_close_hfq_aligned = {}
|
||||
etf_close_aligned = {}
|
||||
etf_open_aligned = {}
|
||||
etf_close_hfq_raw = {}
|
||||
index_close_aligned = {}
|
||||
returns_etf = {}
|
||||
returns_idx = {}
|
||||
|
||||
for code in valid_codes:
|
||||
# 指数收盘价和收益率
|
||||
if code in index_data and 'close' in index_data[code].columns:
|
||||
ic = index_data[code]['close'].dropna()
|
||||
ic_a = ic.reindex(a_share_dates, method='ffill')
|
||||
index_close_aligned[code] = ic_a
|
||||
returns_idx[code] = ic_a.pct_change(fill_method=None)
|
||||
|
||||
# ETF价格和收益率
|
||||
etf_code = etf_code_map.get(code)
|
||||
if etf_code and code in etf_data:
|
||||
df = etf_data[code]
|
||||
chfq = df['close_hfq'].dropna()
|
||||
chfq_a = chfq.reindex(a_share_dates, method='ffill')
|
||||
etf_close_hfq_aligned[code] = chfq_a
|
||||
etf_close_hfq_raw[code] = chfq
|
||||
returns_etf[f'日收益率_{code}'] = chfq_a.pct_change(fill_method=None)
|
||||
|
||||
ec = df['close'].reindex(a_share_dates, method='ffill')
|
||||
etf_close_aligned[code] = ec
|
||||
eo = df['open'].reindex(a_share_dates, method='ffill')
|
||||
etf_open_aligned[code] = eo
|
||||
elif code in index_data and 'close' in index_data[code].columns:
|
||||
ic = index_data[code]['close'].dropna()
|
||||
ic_a = ic.reindex(a_share_dates, method='ffill')
|
||||
returns_etf[f'日收益率_{code}'] = ic_a.pct_change(fill_method=None)
|
||||
|
||||
returns_etf_df = pd.DataFrame(returns_etf)
|
||||
|
||||
# 6.5 溢价率:(ETF收盘价 - 单位净值) / 单位净值
|
||||
etf_premium_aligned = {}
|
||||
for code in valid_codes:
|
||||
if code in etf_nav_data and code in etf_close_aligned:
|
||||
nav_raw = etf_nav_data[code]
|
||||
nav_raw = nav_raw[~nav_raw.index.duplicated(keep='last')]
|
||||
nav = nav_raw.reindex(a_share_dates, method='ffill')
|
||||
close = etf_close_aligned[code]
|
||||
premium = (close - nav) / nav
|
||||
etf_premium_aligned[code] = premium
|
||||
|
||||
# 7. 执行回测获取净值
|
||||
print("\n[6] 执行回测...")
|
||||
common_dates = signals.index.intersection(returns_etf_df.index)
|
||||
signals_aligned = signals.loc[common_dates]
|
||||
returns_aligned = returns_etf_df.loc[common_dates]
|
||||
|
||||
executor = BacktestExecutor(
|
||||
initial_capital=100000,
|
||||
trade_cost=TRADE_COST,
|
||||
select_num=SELECT_NUM
|
||||
)
|
||||
portfolio = executor.execute(signals_aligned, returns_aligned)
|
||||
result = portfolio.backtest_result
|
||||
nav_series_raw = result['策略净值']
|
||||
daily_ret_raw = result['策略日收益率']
|
||||
|
||||
# 扩展到所有common_dates,信号前的日期 nav=1.0, return=0.0
|
||||
nav_series = nav_series_raw.reindex(common_dates)
|
||||
daily_ret_series = daily_ret_raw.reindex(common_dates, fill_value=0.0)
|
||||
first_valid = nav_series.first_valid_index()
|
||||
if first_valid is not None:
|
||||
nav_series.loc[:first_valid] = nav_series.loc[:first_valid].fillna(1.0)
|
||||
nav_series = nav_series.ffill()
|
||||
|
||||
print(f" 终值: {nav_series.iloc[-1]:.4f}")
|
||||
|
||||
# 8. 构建逐日明细
|
||||
print("\n[7] 构建逐日明细...")
|
||||
|
||||
# 持仓跟踪状态
|
||||
holdings_state = {} # {code: {'entry_date': str, 'entry_price': float}}
|
||||
prev_holdings = set()
|
||||
|
||||
days_list = []
|
||||
signal_col = 'signal'
|
||||
|
||||
for i, date in enumerate(common_dates):
|
||||
sig_val = signals_aligned.loc[date, signal_col] if signal_col in signals_aligned.columns else ''
|
||||
current_holdings = set(str(sig_val).split(',')) if pd.notna(sig_val) and sig_val else set()
|
||||
current_holdings.discard('')
|
||||
|
||||
# 调仓检测
|
||||
added = list(current_holdings - prev_holdings)
|
||||
removed = list(prev_holdings - current_holdings)
|
||||
is_rebalance = len(added) > 0 or len(removed) > 0
|
||||
|
||||
# 更新持仓状态
|
||||
for code in removed:
|
||||
holdings_state.pop(code, None)
|
||||
for code in added:
|
||||
entry_price_etf = None
|
||||
entry_price_idx = None
|
||||
if code in etf_close_hfq_aligned:
|
||||
ep = etf_close_hfq_aligned[code].get(date)
|
||||
if pd.notna(ep):
|
||||
entry_price_etf = float(ep)
|
||||
if code in index_close_aligned:
|
||||
ep = index_close_aligned[code].get(date)
|
||||
if pd.notna(ep):
|
||||
entry_price_idx = float(ep)
|
||||
holdings_state[code] = {
|
||||
'entry_date': date.strftime('%Y-%m-%d'),
|
||||
'entry_price_etf': entry_price_etf,
|
||||
'entry_price_idx': entry_price_idx,
|
||||
}
|
||||
|
||||
# 动态阈值
|
||||
factor_scores = {}
|
||||
for code in valid_codes:
|
||||
if code in factor_df.columns:
|
||||
v = factor_df.loc[date, code] if date in factor_df.index else np.nan
|
||||
if pd.notna(v):
|
||||
factor_scores[code] = float(v)
|
||||
|
||||
bond_score = factor_scores.get(BOND_CODE)
|
||||
if BOND_THRESHOLD.get('enabled') and bond_score is not None and bond_score >= 0:
|
||||
threshold = bond_score * BOND_RATIO
|
||||
else:
|
||||
threshold = 0.0
|
||||
|
||||
# 排名(按动量降序,排除BOND)
|
||||
non_bond_scores = {k: v for k, v in factor_scores.items()
|
||||
if group_mapping.get(k) != 'BOND'}
|
||||
sorted_codes = sorted(non_bond_scores.keys(),
|
||||
key=lambda c: non_bond_scores[c], reverse=True)
|
||||
rank_map = {c: r + 1 for r, c in enumerate(sorted_codes)}
|
||||
# BOND不参与排名
|
||||
if BOND_CODE in factor_scores:
|
||||
rank_map[BOND_CODE] = None
|
||||
|
||||
# 每标的详情
|
||||
assets = {}
|
||||
for code in valid_codes:
|
||||
asset = {}
|
||||
|
||||
# 指数收盘价
|
||||
if code in index_close_aligned:
|
||||
v = index_close_aligned[code].get(date)
|
||||
asset['index_close'] = safe_val(v, 2)
|
||||
else:
|
||||
asset['index_close'] = None
|
||||
|
||||
# 动量
|
||||
mom = factor_scores.get(code)
|
||||
asset['momentum'] = safe_val(mom, 4)
|
||||
|
||||
# 排名
|
||||
asset['rank'] = rank_map.get(code)
|
||||
|
||||
# 阈值
|
||||
asset['threshold'] = safe_val(threshold, 4)
|
||||
asset['above_threshold'] = mom >= threshold if mom is not None else False
|
||||
|
||||
# ETF价格
|
||||
if code in etf_close_aligned:
|
||||
asset['etf_close'] = safe_val(etf_close_aligned[code].get(date), 3)
|
||||
else:
|
||||
asset['etf_close'] = None
|
||||
|
||||
if code in etf_open_aligned:
|
||||
asset['etf_open'] = safe_val(etf_open_aligned[code].get(date), 3)
|
||||
else:
|
||||
asset['etf_open'] = None
|
||||
|
||||
if code in etf_close_hfq_aligned:
|
||||
asset['etf_close_hfq'] = safe_val(etf_close_hfq_aligned[code].get(date), 4)
|
||||
else:
|
||||
asset['etf_close_hfq'] = None
|
||||
|
||||
# 溢价率
|
||||
if code in etf_premium_aligned:
|
||||
asset['premium'] = safe_val(etf_premium_aligned[code].get(date), 4)
|
||||
else:
|
||||
asset['premium'] = None
|
||||
|
||||
# ETF日收益率
|
||||
ret_col = f'日收益率_{code}'
|
||||
if ret_col in returns_etf_df.columns:
|
||||
asset['etf_return_ctc'] = safe_val(returns_etf_df.loc[date, ret_col], 6)
|
||||
else:
|
||||
asset['etf_return_ctc'] = None
|
||||
|
||||
# 指数日收益率
|
||||
if code in returns_idx:
|
||||
asset['index_return'] = safe_val(returns_idx[code].get(date), 6)
|
||||
else:
|
||||
asset['index_return'] = None
|
||||
|
||||
# 持仓状态
|
||||
is_held = code in current_holdings
|
||||
asset['is_held'] = is_held
|
||||
|
||||
if is_held and code in holdings_state:
|
||||
hs = holdings_state[code]
|
||||
asset['entry_date'] = hs['entry_date']
|
||||
asset['entry_price_etf'] = safe_val(hs['entry_price_etf'], 4)
|
||||
asset['entry_price_idx'] = safe_val(hs['entry_price_idx'], 4)
|
||||
|
||||
entry_dt = pd.Timestamp(hs['entry_date'])
|
||||
trading_days_held = len(common_dates[(common_dates >= entry_dt) & (common_dates <= date)])
|
||||
asset['holding_days'] = trading_days_held
|
||||
|
||||
# ETF累计收益
|
||||
if hs['entry_price_etf'] and hs['entry_price_etf'] > 0:
|
||||
cur = etf_close_hfq_aligned[code].get(date) if code in etf_close_hfq_aligned else None
|
||||
if cur and pd.notna(cur):
|
||||
asset['cum_return_etf'] = safe_val(float(cur) / hs['entry_price_etf'] - 1, 4)
|
||||
else:
|
||||
asset['cum_return_etf'] = None
|
||||
else:
|
||||
asset['cum_return_etf'] = None
|
||||
|
||||
# 指数累计收益
|
||||
if hs['entry_price_idx'] and hs['entry_price_idx'] > 0:
|
||||
cur = index_close_aligned[code].get(date) if code in index_close_aligned else None
|
||||
if cur and pd.notna(cur):
|
||||
asset['cum_return_idx'] = safe_val(float(cur) / hs['entry_price_idx'] - 1, 4)
|
||||
else:
|
||||
asset['cum_return_idx'] = None
|
||||
else:
|
||||
asset['cum_return_idx'] = None
|
||||
else:
|
||||
asset['entry_date'] = None
|
||||
asset['entry_price_etf'] = None
|
||||
asset['entry_price_idx'] = None
|
||||
asset['holding_days'] = 0
|
||||
asset['cum_return_etf'] = None
|
||||
asset['cum_return_idx'] = None
|
||||
|
||||
assets[code] = asset
|
||||
|
||||
# 构建当天记录
|
||||
nav_val = nav_series.loc[date] if date in nav_series.index else None
|
||||
ret_val = daily_ret_series.loc[date] if date in daily_ret_series.index else None
|
||||
|
||||
day_record = {
|
||||
'date': date.strftime('%Y-%m-%d'),
|
||||
'nav': safe_val(nav_val, 4),
|
||||
'daily_return': safe_val(ret_val, 6),
|
||||
'is_rebalance': is_rebalance,
|
||||
'holdings': sorted(list(current_holdings)),
|
||||
'added': sorted(added),
|
||||
'removed': sorted(removed),
|
||||
'assets': assets
|
||||
}
|
||||
days_list.append(day_record)
|
||||
prev_holdings = current_holdings
|
||||
|
||||
# 9. 构建元数据
|
||||
codes_meta = {}
|
||||
for code, cfg in CODE_LIST.items():
|
||||
codes_meta[code] = {
|
||||
'name': cfg['name'],
|
||||
'etf': cfg.get('etf'),
|
||||
'market': cfg.get('market')
|
||||
}
|
||||
|
||||
output = {
|
||||
'meta': {
|
||||
'mode': 'B: 指数信号 + ETF收益',
|
||||
'start_date': common_dates[0].strftime('%Y-%m-%d'),
|
||||
'end_date': common_dates[-1].strftime('%Y-%m-%d'),
|
||||
'total_days': len(common_dates),
|
||||
'select_num': SELECT_NUM,
|
||||
'n_days': N_DAYS,
|
||||
'trade_cost': TRADE_COST,
|
||||
'bond_threshold': {
|
||||
'enabled': BOND_THRESHOLD.get('enabled', False),
|
||||
'bond_code': BOND_CODE,
|
||||
'ratio': BOND_RATIO
|
||||
},
|
||||
'codes': codes_meta
|
||||
},
|
||||
'days': days_list
|
||||
}
|
||||
|
||||
# 10. 输出
|
||||
output_path = project_root / 'results' / 'backtest_detail.json'
|
||||
print(f"\n[8] 写入 {output_path}...")
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(output, f, ensure_ascii=False)
|
||||
|
||||
file_size_mb = output_path.stat().st_size / 1024 / 1024
|
||||
print(f" 大小: {file_size_mb:.1f} MB")
|
||||
print(f" 天数: {len(days_list)}")
|
||||
print(f" 标的: {len(valid_codes)}")
|
||||
print(" 完成!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -62,7 +62,14 @@ class MomentumFactor(FactorBase):
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
|
||||
# 价格下界 clip,防止 log(0) 或 log(负数)
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
y = np.log(prices)
|
||||
|
||||
# 异常值检测
|
||||
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
|
||||
return 0.0
|
||||
|
||||
x = np.arange(len(y))
|
||||
weights = np.linspace(1, 2, len(y))
|
||||
|
||||
|
||||
164
tests/test_trading_calendar.py
Normal file
164
tests/test_trading_calendar.py
Normal file
@@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试交易日历 API
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import requests
|
||||
|
||||
# Flask 服务地址
|
||||
FLASK_API_URL = "http://localhost:80"
|
||||
|
||||
def test_calendar_api():
|
||||
"""测试交易日历 API"""
|
||||
print("\n" + "="*80)
|
||||
print("📅 交易日历 API 测试")
|
||||
print("="*80)
|
||||
|
||||
# 测试 1: A 股
|
||||
print("\n[1] 测试 A 股交易日历...")
|
||||
url = f"{FLASK_API_URL}/api/v1/trading-calendar"
|
||||
params = {"market": "A", "start": "2024-01-01", "end": "2024-01-31"}
|
||||
|
||||
try:
|
||||
response = requests.get(url, params=params, timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f" ✅ 成功: {data['count']} 个交易日")
|
||||
print(f" 市场: {data['market']}")
|
||||
print(f" 交易所: {data['exchange']}")
|
||||
print(f" 日期范围: {data['start']} ~ {data['end']}")
|
||||
print(f" 前5个交易日: {data['trading_dates'][:5]}")
|
||||
else:
|
||||
print(f" ❌ 失败: {response.status_code}")
|
||||
print(f" 响应: {response.json()}")
|
||||
except Exception as e:
|
||||
print(f" ❌ 异常: {e}")
|
||||
|
||||
# 测试 2: 美股
|
||||
print("\n[2] 测试美股交易日历...")
|
||||
params = {"market": "US", "start": "2024-01-01", "end": "2024-01-31"}
|
||||
|
||||
try:
|
||||
response = requests.get(url, params=params, timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f" ✅ 成功: {data['count']} 个交易日")
|
||||
print(f" 市场: {data['market']}")
|
||||
print(f" 交易所: {data['exchange']}")
|
||||
print(f" 前5个交易日: {data['trading_dates'][:5]}")
|
||||
else:
|
||||
print(f" ❌ 失败: {response.status_code}")
|
||||
print(f" 响应: {response.json()}")
|
||||
except Exception as e:
|
||||
print(f" ❌ 异常: {e}")
|
||||
|
||||
# 测试 3: 港股
|
||||
print("\n[3] 测试港股交易日历...")
|
||||
params = {"market": "HK", "start": "2024-01-01", "end": "2024-01-31"}
|
||||
|
||||
try:
|
||||
response = requests.get(url, params=params, timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f" ✅ 成功: {data['count']} 个交易日")
|
||||
print(f" 市场: {data['market']}")
|
||||
print(f" 交易所: {data['exchange']}")
|
||||
print(f" 前5个交易日: {data['trading_dates'][:5]}")
|
||||
else:
|
||||
print(f" ❌ 失败: {response.status_code}")
|
||||
print(f" 响应: {response.json()}")
|
||||
except Exception as e:
|
||||
print(f" ❌ 异常: {e}")
|
||||
|
||||
# 测试 4: 日历信息
|
||||
print("\n[4] 测试日历信息...")
|
||||
url_info = f"{FLASK_API_URL}/api/v1/calendar/info"
|
||||
|
||||
try:
|
||||
response = requests.get(url_info, timeout=10)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
print(f" ✅ 成功")
|
||||
print(f" 支持的市场:")
|
||||
for market, info in data.get('supported_markets', {}).items():
|
||||
print(f" {market}: {info['name']} ({info['method']})")
|
||||
print(f" pandas_market_calendars: {'✅ 已安装' if data.get('pandas_market_calendars_installed') else '❌ 未安装'}")
|
||||
else:
|
||||
print(f" ❌ 失败: {response.status_code}")
|
||||
except Exception as e:
|
||||
print(f" ❌ 异常: {e}")
|
||||
|
||||
def test_local_fetcher():
|
||||
"""测试本地 UniversalDataFetcher"""
|
||||
print("\n" + "="*80)
|
||||
print("🧪 本地 UniversalDataFetcher 测试")
|
||||
print("="*80)
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
try:
|
||||
from datasource.universal_fetcher import UniversalDataFetcher
|
||||
|
||||
fetcher = UniversalDataFetcher()
|
||||
|
||||
# 测试 A 股
|
||||
print("\n[1] A 股交易日历 (2024年)...")
|
||||
cal_a = fetcher.get_trading_calendar('A', '2024-01-01', '2024-12-31')
|
||||
print(f" ✅ {len(cal_a)} 个交易日")
|
||||
print(f" 前5天: {list(cal_a[:5])}")
|
||||
|
||||
# 测试美股
|
||||
print("\n[2] 美股交易日历 (2024年)...")
|
||||
cal_us = fetcher.get_trading_calendar('US', '2024-01-01', '2024-12-31')
|
||||
print(f" ✅ {len(cal_us)} 个交易日")
|
||||
print(f" 前5天: {list(cal_us[:5])}")
|
||||
|
||||
# 测试港股
|
||||
print("\n[3] 港股交易日历 (2024年)...")
|
||||
cal_hk = fetcher.get_trading_calendar('HK', '2024-01-01', '2024-12-31')
|
||||
print(f" ✅ {len(cal_hk)} 个交易日")
|
||||
print(f" 前5天: {list(cal_hk[:5])}")
|
||||
|
||||
# 日历信息
|
||||
print("\n[4] 日历支持信息...")
|
||||
info = fetcher.get_calendar_info()
|
||||
print(f" ✅ 支持 {len(info['supported_markets'])} 个市场")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ 失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
def main():
|
||||
print("\n" + "="*80)
|
||||
print("📅 交易日历功能测试")
|
||||
print("="*80)
|
||||
|
||||
# 测试 1: 本地 fetcher
|
||||
test_local_fetcher()
|
||||
|
||||
# 测试 2: Flask API(如果服务在运行)
|
||||
print("\n" + "="*80)
|
||||
print("🌐 测试 Flask API 端点")
|
||||
print("="*80)
|
||||
print(f"\nAPI 地址: {FLASK_API_URL}")
|
||||
print("注意: 需要 Flask 服务正在运行")
|
||||
|
||||
try:
|
||||
response = requests.get(f"{FLASK_API_URL}/health", timeout=3)
|
||||
if response.status_code == 200:
|
||||
print("✅ Flask 服务可访问")
|
||||
test_calendar_api()
|
||||
else:
|
||||
print(f"⚠️ Flask 服务返回 {response.status_code},跳过 API 测试")
|
||||
except:
|
||||
print("⚠️ Flask 服务未运行,跳过 API 测试")
|
||||
|
||||
print("\n" + "="*80)
|
||||
print("✅ 测试完成")
|
||||
print("="*80)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
143
tests/verify_fix_result.py
Normal file
143
tests/verify_fix_result.py
Normal file
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
验证修复后的回测结果是否与文档一致
|
||||
|
||||
文档预期结果 (Mode A - 指数信号+指数收益):
|
||||
CAGR: 11.80%, 最大回撤: -29.49%, 夏普: 0.818, Calmar: 0.400
|
||||
|
||||
文档预期结果 (Mode B - 指数信号+ETF收益):
|
||||
CAGR: 28.07%, 最大回撤: -13.34%, 夏普: 1.685, Calmar: 2.104
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
|
||||
def calculate_metrics(nav: pd.Series) -> dict:
|
||||
"""计算绩效指标"""
|
||||
start_date = nav.index[0]
|
||||
end_date = nav.index[-1]
|
||||
days = (end_date - start_date).days
|
||||
years = days / 365
|
||||
|
||||
total_return = nav.iloc[-1] - 1
|
||||
cagr = (nav.iloc[-1] / nav.iloc[0]) ** (1/years) - 1
|
||||
|
||||
daily_ret = nav.pct_change().dropna()
|
||||
sharpe = daily_ret.mean() / daily_ret.std() * np.sqrt(252) if daily_ret.std() > 0 else 0
|
||||
|
||||
peak = nav.cummax()
|
||||
drawdown = (nav - peak) / peak
|
||||
max_dd = drawdown.min()
|
||||
|
||||
calmar = cagr / abs(max_dd) if max_dd != 0 else 0
|
||||
win_rate = (daily_ret > 0).sum() / len(daily_ret)
|
||||
|
||||
return {
|
||||
'start_date': start_date.strftime('%Y-%m-%d'),
|
||||
'end_date': end_date.strftime('%Y-%m-%d'),
|
||||
'years': years,
|
||||
'days': len(nav),
|
||||
'total_return': total_return,
|
||||
'cagr': cagr,
|
||||
'max_dd': max_dd,
|
||||
'sharpe': sharpe,
|
||||
'calmar': calmar,
|
||||
'win_rate': win_rate
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
# 加载配置
|
||||
config_path = project_root / 'strategies/rotation/config.yaml'
|
||||
with open(config_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 设置回测区间(文档中的测试区间)
|
||||
config['start_date'] = '2020-01-02'
|
||||
config['end_date'] = '2026-05-19'
|
||||
|
||||
print('='*70)
|
||||
print('修复后回测结果验证')
|
||||
print('='*70)
|
||||
print(f'回测区间: {config["start_date"]} ~ {config["end_date"]}')
|
||||
|
||||
# 初始化策略
|
||||
strategy = RotationStrategy(config)
|
||||
|
||||
# 获取数据并执行回测
|
||||
print('\n获取数据...')
|
||||
data = strategy.get_data(use_flask_api=False)
|
||||
|
||||
print('\n执行回测...')
|
||||
result = strategy.run_backtest(data=data)
|
||||
|
||||
if result.get('result') is None:
|
||||
print('❌ 回测未生成结果')
|
||||
return
|
||||
|
||||
# 计算指标
|
||||
nav = result['result']['策略净值']
|
||||
metrics = calculate_metrics(nav)
|
||||
|
||||
# 输出结果
|
||||
print('\n' + '='*70)
|
||||
print('修复后回测结果')
|
||||
print('='*70)
|
||||
print(f"回测区间: {metrics['start_date']} ~ {metrics['end_date']}")
|
||||
print(f"回测年数: {metrics['years']:.2f} 年")
|
||||
print(f"交易天数: {metrics['days']} 天")
|
||||
print('-'*70)
|
||||
print(f"CAGR: {metrics['cagr']:.2%}")
|
||||
print(f"最大回撤: {metrics['max_dd']:.2%}")
|
||||
print(f"夏普比率: {metrics['sharpe']:.3f}")
|
||||
print(f"Calmar比率: {metrics['calmar']:.3f}")
|
||||
print(f"日胜率: {metrics['win_rate']:.2%}")
|
||||
print(f"累计收益: {metrics['total_return']:.2%}")
|
||||
print(f"调仓次数: {len(result.get('rebalance_events', []))} 次")
|
||||
print('='*70)
|
||||
|
||||
# 文档预期结果对比
|
||||
print('\n' + '='*70)
|
||||
print('文档预期结果对比')
|
||||
print('='*70)
|
||||
print("\nMode A (指数信号 → 指数收益):")
|
||||
print(" 预期: CAGR 11.80%, MaxDD -29.49%, Sharpe 0.818, Calmar 0.400")
|
||||
|
||||
print("\nMode B (指数信号 → ETF收益):")
|
||||
print(" 预期: CAGR 28.07%, MaxDD -13.34%, Sharpe 1.685, Calmar 2.104")
|
||||
|
||||
# 判断当前模式
|
||||
print('\n' + '-'*70)
|
||||
cagr_diff_a = abs(metrics['cagr'] - 0.1180)
|
||||
cagr_diff_b = abs(metrics['cagr'] - 0.2807)
|
||||
|
||||
if cagr_diff_a < 0.03:
|
||||
print(f"✓ 当前结果接近 Mode A (CAGR差异: {cagr_diff_a:.2%})")
|
||||
print(" 说明: 当前回测使用指数收盘价计算收益")
|
||||
elif cagr_diff_b < 0.03:
|
||||
print(f"✓ 当前结果接近 Mode B (CAGR差异: {cagr_diff_b:.2%})")
|
||||
print(" 说明: 当前回测使用ETF价格计算收益")
|
||||
else:
|
||||
print(f"⚠ 当前结果与文档预期有差异")
|
||||
print(f" Mode A CAGR差异: {cagr_diff_a:.2%}")
|
||||
print(f" Mode B CAGR差异: {cagr_diff_b:.2%}")
|
||||
|
||||
print('='*70)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
208
tests/verify_mode_b.py
Normal file
208
tests/verify_mode_b.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
验证 Mode B: 指数信号 → ETF收益
|
||||
|
||||
文档预期结果:
|
||||
CAGR: 28.07%, 最大回撤: -13.34%, 夏普: 1.685, Calmar: 2.104
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
|
||||
def calculate_metrics(nav: pd.Series) -> dict:
|
||||
"""计算绩效指标"""
|
||||
start_date = nav.index[0]
|
||||
end_date = nav.index[-1]
|
||||
days = (end_date - start_date).days
|
||||
years = days / 365
|
||||
|
||||
total_return = nav.iloc[-1] - 1
|
||||
cagr = (nav.iloc[-1] / nav.iloc[0]) ** (1/years) - 1
|
||||
|
||||
daily_ret = nav.pct_change().dropna()
|
||||
sharpe = daily_ret.mean() / daily_ret.std() * np.sqrt(252) if daily_ret.std() > 0 else 0
|
||||
|
||||
peak = nav.cummax()
|
||||
drawdown = (nav - peak) / peak
|
||||
max_dd = drawdown.min()
|
||||
|
||||
calmar = cagr / abs(max_dd) if max_dd != 0 else 0
|
||||
win_rate = (daily_ret > 0).sum() / len(daily_ret)
|
||||
|
||||
return {
|
||||
'start_date': start_date.strftime('%Y-%m-%d'),
|
||||
'end_date': end_date.strftime('%Y-%m-%d'),
|
||||
'years': years,
|
||||
'days': len(nav),
|
||||
'total_return': total_return,
|
||||
'cagr': cagr,
|
||||
'max_dd': max_dd,
|
||||
'sharpe': sharpe,
|
||||
'calmar': calmar,
|
||||
'win_rate': win_rate
|
||||
}
|
||||
|
||||
|
||||
def run_mode_b_backtest(data: dict, signals: pd.DataFrame, valid_codes: list,
|
||||
etf_code_map: dict, a_share_dates: pd.DatetimeIndex,
|
||||
trade_cost: float, select_num: int) -> dict:
|
||||
"""
|
||||
Mode B: 使用ETF价格计算收益
|
||||
|
||||
Args:
|
||||
data: 包含 etf_data 的数据字典
|
||||
signals: 指数生成的信号
|
||||
valid_codes: 指数代码列表
|
||||
etf_code_map: {指数代码: ETF代码} 映射
|
||||
a_share_dates: A股交易日历
|
||||
trade_cost: 交易成本
|
||||
select_num: 选股数量
|
||||
"""
|
||||
from framework.execution import BacktestExecutor
|
||||
|
||||
etf_data = data.get('etf_data')
|
||||
if etf_data is None:
|
||||
print("❌ ETF数据不可用")
|
||||
return {'result': None}
|
||||
|
||||
# 将信号对齐到 A 股日历
|
||||
if a_share_dates is not signals.index:
|
||||
signals = signals.reindex(a_share_dates, method='ffill').dropna(subset=[signals.columns[0]])
|
||||
|
||||
# 使用ETF收盘价计算收益率
|
||||
returns_data = {}
|
||||
for code in valid_codes:
|
||||
etf_code = etf_code_map.get(code)
|
||||
if etf_code and etf_code in etf_data.columns:
|
||||
etf_close = etf_data[etf_code].dropna()
|
||||
# 对齐到A股日历
|
||||
etf_aligned = etf_close.reindex(a_share_dates, method='ffill')
|
||||
returns_aligned = etf_aligned.pct_change(fill_method=None)
|
||||
# 使用指数代码作为列名(与信号匹配)
|
||||
returns_data[f'日收益率_{code}'] = returns_aligned
|
||||
else:
|
||||
# 没有ETF映射的标的,回退使用指数数据
|
||||
index_data = data.get('index_data', {})
|
||||
if code in index_data and 'close' in index_data[code].columns:
|
||||
close_series = index_data[code]['close'].dropna()
|
||||
close_aligned = close_series.reindex(a_share_dates, method='ffill')
|
||||
returns_data[f'日收益率_{code}'] = close_aligned.pct_change(fill_method=None)
|
||||
|
||||
returns_df = pd.DataFrame(returns_data)
|
||||
|
||||
# 对齐日期
|
||||
common_dates = signals.index.intersection(returns_df.index)
|
||||
signals = signals.loc[common_dates]
|
||||
returns_df = returns_df.loc[common_dates]
|
||||
|
||||
print(f" Mode B 对齐后日期: {len(common_dates)} 天")
|
||||
print(f" 使用ETF计算收益: {len([c for c in valid_codes if etf_code_map.get(c)])} 只")
|
||||
|
||||
executor = BacktestExecutor(
|
||||
initial_capital=100000,
|
||||
trade_cost=trade_cost,
|
||||
select_num=select_num
|
||||
)
|
||||
|
||||
portfolio = executor.execute(signals, returns_df)
|
||||
|
||||
if hasattr(portfolio, 'backtest_result'):
|
||||
return {'result': portfolio.backtest_result, 'portfolio': portfolio}
|
||||
|
||||
return {'result': None}
|
||||
|
||||
|
||||
def main():
|
||||
# 加载配置
|
||||
config_path = project_root / 'strategies/rotation/config.yaml'
|
||||
with open(config_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# 设置回测区间
|
||||
config['start_date'] = '2020-01-02'
|
||||
config['end_date'] = '2026-05-19'
|
||||
|
||||
print('='*70)
|
||||
print('Mode B 验证: 指数信号 → ETF收益')
|
||||
print('='*70)
|
||||
|
||||
# 初始化策略
|
||||
strategy = RotationStrategy(config)
|
||||
|
||||
# 获取数据
|
||||
print('\n获取数据...')
|
||||
data = strategy.get_data(use_flask_api=False)
|
||||
|
||||
# 计算因子(使用指数数据)
|
||||
print('\n计算因子(指数信号)...')
|
||||
factor_df = strategy.compute_factors(data)
|
||||
|
||||
# 生成信号
|
||||
print('\n生成信号...')
|
||||
signals = strategy.generate_signals(factor_df)
|
||||
|
||||
# 执行 Mode B 回测
|
||||
print('\n执行 Mode B 回测(ETF收益)...')
|
||||
result_b = run_mode_b_backtest(
|
||||
data=data,
|
||||
signals=signals,
|
||||
valid_codes=data['valid_codes'],
|
||||
etf_code_map=data['etf_code_map'],
|
||||
a_share_dates=data.get('a_share_dates'),
|
||||
trade_cost=config.get('trade_cost', 0.001),
|
||||
select_num=config.get('select_num', 3)
|
||||
)
|
||||
|
||||
if result_b.get('result') is None:
|
||||
print('❌ Mode B 回测未生成结果')
|
||||
return
|
||||
|
||||
# 计算指标
|
||||
nav_b = result_b['result']['策略净值']
|
||||
metrics_b = calculate_metrics(nav_b)
|
||||
|
||||
# 输出结果
|
||||
print('\n' + '='*70)
|
||||
print('Mode B 回测结果')
|
||||
print('='*70)
|
||||
print(f"回测区间: {metrics_b['start_date']} ~ {metrics_b['end_date']}")
|
||||
print(f"回测年数: {metrics_b['years']:.2f} 年")
|
||||
print(f"交易天数: {metrics_b['days']} 天")
|
||||
print('-'*70)
|
||||
print(f"CAGR: {metrics_b['cagr']:.2%}")
|
||||
print(f"最大回撤: {metrics_b['max_dd']:.2%}")
|
||||
print(f"夏普比率: {metrics_b['sharpe']:.3f}")
|
||||
print(f"Calmar比率: {metrics_b['calmar']:.3f}")
|
||||
print(f"日胜率: {metrics_b['win_rate']:.2%}")
|
||||
print(f"累计收益: {metrics_b['total_return']:.2%}")
|
||||
print('='*70)
|
||||
|
||||
# 文档预期对比
|
||||
print('\n文档预期 (Mode B):')
|
||||
print(' CAGR: 28.07%, MaxDD -13.34%, Sharpe 1.685, Calmar 2.104')
|
||||
|
||||
cagr_diff = abs(metrics_b['cagr'] - 0.2807)
|
||||
print(f'\nCAGR差异: {cagr_diff:.2%}')
|
||||
|
||||
if cagr_diff < 0.05:
|
||||
print('✓ 结果与文档预期基本一致')
|
||||
else:
|
||||
print('⚠ 结果与文档预期有差异')
|
||||
|
||||
return metrics_b
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user