clean(rotation): add simple rotation strategy and remove unused files
New: - rotation/simple_rotation.py: daily-iteration rotation strategy (584 lines) - rotation/config_loader.py: standalone config loader - rotation/config_simple.yaml: 11 assets, 7 groups - rotation/README_SIMPLE.md: usage guide - scripts/get_trading_calendar.py: trading calendar fetcher Removed: - rotation/example_usage.py, run_strategy.py (replaced by simple_rotation.py) - rotation/results/ output files (gitignored) - scripts/verify_*.py, calculate_returns_from_detail.py (one-off scripts) - scripts/README_TRADING_CALENDAR.md Backtest result (2020-01-10 ~ 2026-06-01): - Total return: 1237.6%, Annual: 52.66% - Max drawdown: -11.71%, Sharpe: 2.50
This commit is contained in:
@@ -72,7 +72,18 @@ def run_backtest():
|
||||
print("=" * 70)
|
||||
|
||||
strategy = GlobalRotationStrategy(config)
|
||||
result = strategy.run()
|
||||
|
||||
# 运行回测并导出 detail
|
||||
output_dir = project_root / "framework_v2" / "results"
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
detail_path = output_dir / "backtest_detail_v2.json"
|
||||
print(f"\n导出 detail JSON: {detail_path}")
|
||||
|
||||
result = strategy.run(
|
||||
export_detail=True,
|
||||
detail_path=str(detail_path)
|
||||
)
|
||||
|
||||
# 打印结果
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
@@ -391,31 +391,55 @@ class GlobalRotationStrategy(StrategyBase):
|
||||
aligner = CrossMarketAligner(target_calendar=trading_calendar)
|
||||
|
||||
# 提取交易标的的收盘价,并对齐到 A 股日历
|
||||
print(" [对齐] 对齐 ETF 价格到 A 股日历...")
|
||||
close_dict = {}
|
||||
print(" [对齐] 构建可实现价格序列(模拟真实交易)...")
|
||||
executable_close_dict = {}
|
||||
|
||||
for signal_code, trade_code in signal_to_trade.items():
|
||||
if trade_code in data:
|
||||
# 提取收盘价
|
||||
close_series = data[trade_code]['close']
|
||||
# 使用 signal_code 作为键(与 positions 列名一致)
|
||||
close_dict[signal_code] = close_series
|
||||
# 提取开盘价和收盘价
|
||||
etf_df = data[trade_code]
|
||||
open_series = etf_df['open'].reindex(trading_calendar, method='ffill')
|
||||
close_series = etf_df['close'].reindex(trading_calendar, method='ffill')
|
||||
|
||||
# 默认使用收盘价
|
||||
exec_close = close_series.copy()
|
||||
|
||||
# 检测调仓日,调整价格以反映真实交易
|
||||
for i in range(1, len(trading_calendar)):
|
||||
date = trading_calendar[i]
|
||||
prev_date = trading_calendar[i-1]
|
||||
|
||||
# 获取仓位变化
|
||||
prev_pos = positions.loc[prev_date, signal_code] if signal_code in positions.columns else 0
|
||||
curr_pos = positions.loc[date, signal_code] if signal_code in positions.columns else 0
|
||||
|
||||
# 买入日:修改前一天价格为当日开盘价
|
||||
# 这样收益率 = (close[t] - open[t]) / open[t] = 日内收益
|
||||
if pd.isna(prev_pos) or prev_pos == 0:
|
||||
if pd.notna(curr_pos) and curr_pos > 0:
|
||||
exec_close.loc[prev_date] = open_series.loc[date]
|
||||
|
||||
# 卖出日:不需要修改(因为 positions[t]=0,不会计算收益)
|
||||
|
||||
executable_close_dict[signal_code] = exec_close
|
||||
else:
|
||||
print(f" 警告: {trade_code} 数据不存在,跳过")
|
||||
|
||||
# 使用 CrossMarketAligner 对齐多标的收益率
|
||||
# 内部逻辑:先 ffill 价格到 A 股日历,再计算收益率
|
||||
print(" [对齐] 计算收益率(先对齐价格,再计算)...")
|
||||
returns_df = aligner.align_multi_asset(close_dict)
|
||||
print(" [对齐] 计算收益率(使用可实现价格)...")
|
||||
returns_df = aligner.align_multi_asset(executable_close_dict)
|
||||
print(f" [对齐] 收益率数据: {len(returns_df)} 天, {len(returns_df.columns)} 个标的")
|
||||
|
||||
# 对齐 positions 到 A 股日历
|
||||
# 注意:必须先 reindex 再 ffill,因为 reindex(method='ffill') 不会填充已有的 NaN
|
||||
positions = positions.reindex(trading_calendar)
|
||||
positions = positions.ffill()
|
||||
# 卖出日不向前填充(保持 0)
|
||||
positions = positions.ffill().fillna(0)
|
||||
|
||||
# 计算策略收益(仓位加权,T+1 执行)
|
||||
positions_delayed = positions.shift(1).fillna(0)
|
||||
strategy_returns = (positions_delayed * returns_df).sum(axis=1)
|
||||
# 计算策略收益(仓位加权,无需延迟)
|
||||
# 因为 positions[t] 已表示 t 日的实际持仓,且价格已调整为可实现价格
|
||||
strategy_returns = (positions * returns_df).sum(axis=1)
|
||||
|
||||
# 扣除交易成本
|
||||
strategy_returns, rebalance_count = self._apply_trade_cost(
|
||||
@@ -649,6 +673,33 @@ class GlobalRotationStrategy(StrategyBase):
|
||||
index_return_dict = {}
|
||||
etf_return_dict = {}
|
||||
|
||||
# 构建 ETF 可实现价格序列(与回测一致)
|
||||
executable_etf_close = {}
|
||||
for signal_code, trade_code in signal_to_trade.items():
|
||||
if trade_code in self._data:
|
||||
etf_df = self._data[trade_code]
|
||||
open_series = etf_df['open'].reindex(trading_calendar, method='ffill')
|
||||
close_series = etf_df['close'].reindex(trading_calendar, method='ffill')
|
||||
|
||||
# 默认使用 close
|
||||
exec_close = close_series.copy()
|
||||
|
||||
# 检测调仓日,调整价格
|
||||
for i in range(1, len(trading_calendar)):
|
||||
date = trading_calendar[i]
|
||||
prev_date = trading_calendar[i-1]
|
||||
|
||||
# 获取仓位变化
|
||||
prev_pos = positions.loc[prev_date, signal_code] if signal_code in positions.columns else 0
|
||||
curr_pos = positions.loc[date, signal_code] if signal_code in positions.columns else 0
|
||||
|
||||
# 买入日:修改前一天价格为 open
|
||||
if pd.isna(prev_pos) or prev_pos == 0:
|
||||
if pd.notna(curr_pos) and curr_pos > 0:
|
||||
exec_close.loc[prev_date] = open_series.loc[date]
|
||||
|
||||
executable_etf_close[signal_code] = exec_close
|
||||
|
||||
for signal_code, trade_code in signal_to_trade.items():
|
||||
# 指数收益率
|
||||
if signal_code in index_close_dict:
|
||||
@@ -656,10 +707,10 @@ class GlobalRotationStrategy(StrategyBase):
|
||||
idx_return = idx_close.pct_change(fill_method=None).fillna(0)
|
||||
index_return_dict[signal_code] = idx_return
|
||||
|
||||
# ETF 收益率
|
||||
if signal_code in etf_close_dict:
|
||||
etf_close = etf_close_dict[signal_code].reindex(trading_calendar, method='ffill')
|
||||
etf_return = etf_close.pct_change(fill_method=None).fillna(0)
|
||||
# ETF 收益率(使用可实现价格)
|
||||
if signal_code in executable_etf_close:
|
||||
etf_exec = executable_etf_close[signal_code]
|
||||
etf_return = etf_exec.pct_change(fill_method=None).fillna(0)
|
||||
etf_return_dict[signal_code] = etf_return
|
||||
|
||||
# 对齐因子
|
||||
|
||||
64
rotation/README_SIMPLE.md
Normal file
64
rotation/README_SIMPLE.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# 精简版轮动策略(日迭代)
|
||||
|
||||
## 概述
|
||||
|
||||
从 A 股交易日历出发,逐日迭代模拟实盘信号生成流程。
|
||||
|
||||
**与 V2 的核心差异:**
|
||||
|
||||
| 特性 | V2 向量化版本 | 精简版(日迭代) |
|
||||
|------|--------------|------------------|
|
||||
| 数据获取 | 一次性获取所有历史数据 | 预加载+CSV缓存 |
|
||||
| 因子计算 | 全量 rolling 计算 | 每日单独计算 |
|
||||
| 日历对齐 | 先算因子,后对齐 | 从 A 股日历出发 |
|
||||
| 收益计算 | 向量化 | 逐日迭代 |
|
||||
|
||||
## 快速使用
|
||||
|
||||
```bash
|
||||
cd /path/to/etf
|
||||
FLASK_API_URL=https://k3s.tokenpluse.xyz python rotation/simple_rotation.py
|
||||
```
|
||||
|
||||
或作为模块导入:
|
||||
|
||||
```python
|
||||
from rotation.simple_rotation import SimpleRotationStrategy
|
||||
|
||||
strategy = SimpleRotationStrategy('rotation/config_simple.yaml')
|
||||
result = strategy.run()
|
||||
strategy.export_results()
|
||||
```
|
||||
|
||||
## 策略逻辑
|
||||
|
||||
### 动量计算
|
||||
- 加权线性回归:`score = annualized_return * R^2`
|
||||
- 窗口:25 天(可配置)
|
||||
- 崩盘过滤:连续 3 天跌 > 5% 则清零
|
||||
|
||||
### 信号生成(与 V2 完全一致)
|
||||
1. 每个 group 内选 Top 1(非 BOND 组需超过短债动量)
|
||||
2. 从各组 Top 1 中按动量排序选 Top 3
|
||||
3. 不足用 BOND 填充
|
||||
|
||||
### T+1 执行收益
|
||||
- **持有日**: close-to-close
|
||||
- **卖出日**: close-to-open(开盘已卖出)
|
||||
- **买入日**: open-to-close(日内收益)
|
||||
- 调仓日扣除 0.1% 交易成本
|
||||
|
||||
## 输出文件
|
||||
|
||||
```
|
||||
results/
|
||||
├── simple_rotation_nav.csv # 净值曲线
|
||||
├── simple_rotation_signals.csv # 每日信号
|
||||
├── simple_rotation_detail.json # 完整详情
|
||||
└── simple_rotation_metrics.json # 绩效指标
|
||||
```
|
||||
|
||||
## 数据缓存
|
||||
|
||||
首次运行会从 Flask API 下载数据并缓存到 `data/simple_rotation_cache/`。
|
||||
后续运行直接读取缓存,速度显著提升。
|
||||
457
rotation/config_loader.py
Normal file
457
rotation/config_loader.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
配置加载器 - rotation 目录专用(独立实现)
|
||||
|
||||
不依赖 framework_v2,完全独立的配置加载和验证机制。
|
||||
|
||||
使用示例:
|
||||
from config_loader import load_rotation_config
|
||||
|
||||
# 加载配置(自动处理环境变量)
|
||||
config = load_rotation_config('config_simple.yaml')
|
||||
|
||||
# 访问配置
|
||||
print(f"策略版本: {config.metadata.version}")
|
||||
print(f"动量窗口: {config.factor.n_days}")
|
||||
print(f"资产数量: {config.asset_pools.count()}")
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Any
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 枚举类型
|
||||
# ============================================================
|
||||
|
||||
class FactorType(str, Enum):
|
||||
"""因子类型枚举"""
|
||||
MOMENTUM = "momentum"
|
||||
SLOPE_R2 = "slope_r2"
|
||||
WEIGHTED_MOMENTUM = "weighted_momentum"
|
||||
|
||||
|
||||
class PremiumMode(str, Enum):
|
||||
"""溢价控制模式"""
|
||||
FILTER = "filter" # 完全排除
|
||||
PENALIZE = "penalize" # 降权
|
||||
|
||||
|
||||
class ThresholdMode(str, Enum):
|
||||
"""阈值模式"""
|
||||
FIXED = "fixed" # 固定阈值
|
||||
DYNAMIC = "dynamic" # 动态阈值
|
||||
|
||||
|
||||
class DataSourceType(str, Enum):
|
||||
"""数据源类型"""
|
||||
FLASK_API = "flask_api"
|
||||
TUSHARE = "tushare"
|
||||
YFINANCE = "yfinance"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Schema 定义(独立实现)
|
||||
# ============================================================
|
||||
|
||||
class PremiumConfig(BaseModel):
|
||||
"""溢价控制配置"""
|
||||
enabled: bool = Field(default=True, description="是否启用")
|
||||
threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值")
|
||||
|
||||
|
||||
class AssetConfig(BaseModel):
|
||||
"""标的配置"""
|
||||
name: str = Field(..., description="标的名称")
|
||||
group: str = Field(..., description="策略分组")
|
||||
|
||||
signal_source: str = Field(..., description="信号来源代码")
|
||||
trade_source: str = Field(..., description="交易来源代码")
|
||||
|
||||
description: Optional[str] = Field(None, description="标的描述")
|
||||
etf: Optional[str] = Field(None, description="ETF代码(兼容旧配置)")
|
||||
premium_control: Optional["PremiumConfig"] = Field(None, description="溢价控制配置")
|
||||
|
||||
@property
|
||||
def is_cross_market(self) -> bool:
|
||||
"""是否跨市场"""
|
||||
return self.signal_source != self.trade_source
|
||||
|
||||
|
||||
class AssetPool(BaseModel):
|
||||
"""资产池(扁平化设计)"""
|
||||
assets: Dict[str, AssetConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="所有标的"
|
||||
)
|
||||
|
||||
@property
|
||||
def by_group(self) -> Dict[str, Dict[str, AssetConfig]]:
|
||||
"""按策略分组"""
|
||||
groups = {}
|
||||
for code, asset in self.assets.items():
|
||||
group = asset.group
|
||||
if group not in groups:
|
||||
groups[group] = {}
|
||||
groups[group][code] = asset
|
||||
return groups
|
||||
|
||||
@property
|
||||
def groups(self) -> list:
|
||||
"""获取所有策略分组"""
|
||||
return list(self.by_group.keys())
|
||||
|
||||
def get_signal_codes(self, group: str = None) -> list:
|
||||
"""获取信号标的"""
|
||||
if group:
|
||||
return [
|
||||
asset.signal_source
|
||||
for asset in self.assets.values()
|
||||
if asset.group == group
|
||||
]
|
||||
return [asset.signal_source for asset in self.assets.values()]
|
||||
|
||||
def get_trade_codes(self, group: str = None) -> list:
|
||||
"""获取交易标的"""
|
||||
if group:
|
||||
return [
|
||||
asset.trade_source
|
||||
for asset in self.assets.values()
|
||||
if asset.group == group
|
||||
]
|
||||
return [asset.trade_source for asset in self.assets.values()]
|
||||
|
||||
def get_signal_to_trade_mapping(self, group: str = None) -> Dict[str, str]:
|
||||
"""获取信号→交易映射"""
|
||||
mapping = {}
|
||||
for asset in self.assets.values():
|
||||
if group and asset.group != group:
|
||||
continue
|
||||
mapping[asset.signal_source] = asset.trade_source
|
||||
return mapping
|
||||
|
||||
def count(self, group: str = None) -> int:
|
||||
"""获取标的数量"""
|
||||
if group:
|
||||
return len([a for a in self.assets.values() if a.group == group])
|
||||
return len(self.assets)
|
||||
|
||||
|
||||
class FactorConfig(BaseModel):
|
||||
"""因子配置"""
|
||||
type: FactorType = Field(default=FactorType.WEIGHTED_MOMENTUM)
|
||||
n_days: int = Field(default=25, ge=5, le=250)
|
||||
|
||||
|
||||
class DynamicThresholdConfig(BaseModel):
|
||||
"""动态阈值配置"""
|
||||
reference: str = Field(..., description="阈值参考标的")
|
||||
ratio: float = Field(default=1.0, ge=0, description="阈值倍数")
|
||||
fallback_enabled: bool = Field(default=True)
|
||||
fallback_value: float = Field(default=0.0)
|
||||
|
||||
|
||||
class ThresholdConfig(BaseModel):
|
||||
"""阈值配置"""
|
||||
mode: ThresholdMode = Field(default=ThresholdMode.DYNAMIC)
|
||||
fixed_value: float = Field(default=0.0)
|
||||
dynamic: Optional[DynamicThresholdConfig] = Field(None)
|
||||
|
||||
|
||||
class RotationConfig(BaseModel):
|
||||
"""轮动配置"""
|
||||
select_num: int = Field(default=3, ge=1, le=10)
|
||||
diversified: bool = Field(default=True)
|
||||
threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
|
||||
|
||||
|
||||
class RebalanceConfig(BaseModel):
|
||||
"""调仓配置"""
|
||||
min_hold_days: int = Field(default=1, ge=1)
|
||||
score_threshold: float = Field(default=0.0)
|
||||
trade_cost: float = Field(default=0.001, ge=0, le=0.01)
|
||||
|
||||
|
||||
class PremiumControlConfig(BaseModel):
|
||||
"""溢价控制配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
default_threshold: float = Field(default=0.10)
|
||||
mode: PremiumMode = Field(default=PremiumMode.FILTER)
|
||||
penalty_factor: float = Field(default=0.5)
|
||||
market_overrides: Optional[Dict[str, Dict[str, Any]]] = Field(None)
|
||||
|
||||
|
||||
class DataSourceConfig(BaseModel):
|
||||
"""数据源配置"""
|
||||
type: DataSourceType = Field(...)
|
||||
enabled: bool = Field(default=True)
|
||||
timeout: int = Field(default=120, ge=10, le=600)
|
||||
url: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class DataConfig(BaseModel):
|
||||
"""数据配置"""
|
||||
sources: List[DataSourceConfig] = Field(...)
|
||||
|
||||
@field_validator('sources')
|
||||
@classmethod
|
||||
def check_at_least_one(cls, v):
|
||||
"""至少配置一个数据源"""
|
||||
if not v:
|
||||
raise ValueError("必须配置至少一个数据源")
|
||||
return v
|
||||
|
||||
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""基准配置"""
|
||||
code: str = Field(...)
|
||||
name: str = Field(...)
|
||||
|
||||
|
||||
class BacktestConfig(BaseModel):
|
||||
"""回测配置"""
|
||||
start_date: str = Field(...)
|
||||
end_date: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class MetadataConfig(BaseModel):
|
||||
"""配置元数据"""
|
||||
version: str = Field(default="1.0.0")
|
||||
strategy: str = Field(default="rotation")
|
||||
description: str = Field(default="")
|
||||
last_updated: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class RotationStrategyConfig(BaseModel):
|
||||
"""ETF轮动策略完整配置"""
|
||||
metadata: MetadataConfig = Field(default_factory=MetadataConfig)
|
||||
asset_pools: AssetPool = Field(...)
|
||||
benchmark: BenchmarkConfig = Field(...)
|
||||
backtest: BacktestConfig = Field(...)
|
||||
factor: FactorConfig = Field(default_factory=FactorConfig)
|
||||
rotation: RotationConfig = Field(default_factory=RotationConfig)
|
||||
rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig)
|
||||
premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig)
|
||||
data: DataConfig = Field(...)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 配置加载器(独立实现)
|
||||
# ============================================================
|
||||
|
||||
class ConfigLoader:
|
||||
"""配置加载器(独立实现)"""
|
||||
|
||||
def __init__(self, config_dir: str = None):
|
||||
"""初始化"""
|
||||
if config_dir is None:
|
||||
config_dir = Path(__file__).parent
|
||||
|
||||
self.config_dir = Path(config_dir)
|
||||
|
||||
def load(self, config_file: str) -> RotationStrategyConfig:
|
||||
"""加载配置文件"""
|
||||
# 1. 解析路径
|
||||
config_path = self._resolve_path(config_file)
|
||||
|
||||
# 2. 读取 YAML
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
# 3. 环境变量替换
|
||||
config_dict = self._substitute_env_vars(config_dict)
|
||||
|
||||
# 4. Pydantic 验证
|
||||
config = RotationStrategyConfig(**config_dict)
|
||||
|
||||
return config
|
||||
|
||||
def _resolve_path(self, config_file: str) -> Path:
|
||||
"""解析配置文件路径"""
|
||||
path = Path(config_file)
|
||||
|
||||
if path.is_absolute():
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"配置文件不存在: {path}")
|
||||
return path
|
||||
|
||||
# 相对路径:先在配置目录查找
|
||||
config_path = self.config_dir / path
|
||||
if config_path.exists():
|
||||
return config_path
|
||||
|
||||
# 然后在当前工作目录查找
|
||||
cwd_path = Path.cwd() / path
|
||||
if cwd_path.exists():
|
||||
return cwd_path
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"配置文件未找到: {path}\n"
|
||||
f"搜索路径:\n"
|
||||
f" - {config_path}\n"
|
||||
f" - {cwd_path}"
|
||||
)
|
||||
|
||||
def _substitute_env_vars(self, config: Any) -> Any:
|
||||
"""替换环境变量"""
|
||||
if isinstance(config, dict):
|
||||
return {
|
||||
key: self._substitute_env_vars(value)
|
||||
for key, value in config.items()
|
||||
}
|
||||
elif isinstance(config, list):
|
||||
return [
|
||||
self._substitute_env_vars(item)
|
||||
for item in config
|
||||
]
|
||||
elif isinstance(config, str):
|
||||
# 匹配 ${VAR_NAME} 或 ${VAR_NAME:default}
|
||||
pattern = r'\$\{([^}]+)\}'
|
||||
|
||||
def replace_match(match):
|
||||
var_expr = match.group(1)
|
||||
|
||||
if ':' in var_expr:
|
||||
var_name, default = var_expr.split(':', 1)
|
||||
else:
|
||||
var_name = var_expr
|
||||
default = None
|
||||
|
||||
value = os.getenv(var_name)
|
||||
|
||||
if value is None:
|
||||
if default is not None:
|
||||
return default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"环境变量未设置: {var_name}\n"
|
||||
f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
return re.sub(pattern, replace_match, config)
|
||||
else:
|
||||
return config
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 快捷函数
|
||||
# ============================================================
|
||||
|
||||
def load_rotation_config(config_file: str = 'config_simple.yaml') -> RotationStrategyConfig:
|
||||
"""
|
||||
加载 rotation 配置
|
||||
|
||||
Args:
|
||||
config_file: 配置文件名
|
||||
|
||||
Returns:
|
||||
RotationStrategyConfig: 验证后的配置对象
|
||||
|
||||
环境变量:
|
||||
必须设置 FLASK_API_URL 或提供默认值
|
||||
"""
|
||||
loader = ConfigLoader()
|
||||
return loader.load(config_file)
|
||||
|
||||
|
||||
def get_config_info(config: RotationStrategyConfig) -> dict:
|
||||
"""获取配置摘要"""
|
||||
return {
|
||||
'version': config.metadata.version,
|
||||
'strategy': config.metadata.strategy,
|
||||
'description': config.metadata.description,
|
||||
'asset_count': config.asset_pools.count(),
|
||||
'groups': config.asset_pools.groups,
|
||||
'factor_type': config.factor.type.value,
|
||||
'n_days': config.factor.n_days,
|
||||
'select_num': config.rotation.select_num,
|
||||
'start_date': config.backtest.start_date,
|
||||
'end_date': config.backtest.end_date or '至今',
|
||||
'threshold_mode': config.rotation.threshold.mode.value,
|
||||
}
|
||||
|
||||
|
||||
def print_config_summary(config: RotationStrategyConfig) -> None:
|
||||
"""打印配置摘要"""
|
||||
info = get_config_info(config)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print(" ETF 轮动策略配置摘要")
|
||||
print("=" * 50)
|
||||
|
||||
print(f"\n版本: {info['version']}")
|
||||
print(f"策略: {info['strategy']}")
|
||||
print(f"描述: {info['description']}")
|
||||
|
||||
print(f"\n资产池:")
|
||||
print(f" 总数: {info['asset_count']}")
|
||||
print(f" 分组: {', '.join(info['groups'])}")
|
||||
|
||||
print(f"\n因子配置:")
|
||||
print(f" 类型: {info['factor_type']}")
|
||||
print(f" 窗口: {info['n_days']} 天")
|
||||
|
||||
print(f"\n轮动配置:")
|
||||
print(f" 选股数: {info['select_num']}")
|
||||
print(f" 阈值模式: {info['threshold_mode']}")
|
||||
|
||||
print(f"\n回测配置:")
|
||||
print(f" 起始: {info['start_date']}")
|
||||
print(f" 结束: {info['end_date']}")
|
||||
|
||||
print(f"\n各分组资产:")
|
||||
for group, assets in config.asset_pools.by_group.items():
|
||||
print(f"\n [{group}] ({len(assets)} 个):")
|
||||
for code, asset in assets.items():
|
||||
cross_market = "✓ 跨市场" if asset.is_cross_market else ""
|
||||
print(f" - {asset.name} ({code})")
|
||||
print(f" 信号: {asset.signal_source} → 交易: {asset.trade_source} {cross_market}")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 测试
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""测试配置加载器"""
|
||||
# 设置环境变量(如果没有设置)
|
||||
if 'FLASK_API_URL' not in os.environ:
|
||||
print("提示: FLASK_API_URL 未设置,使用默认值")
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
|
||||
try:
|
||||
# 加载配置
|
||||
print("\n[测试] 加载 config_simple.yaml...")
|
||||
config = load_rotation_config()
|
||||
|
||||
# 打印摘要
|
||||
print_config_summary(config)
|
||||
|
||||
# 测试基本功能
|
||||
print("\n[验证] 基本功能测试:")
|
||||
print(f" ✓ 配置对象类型: {type(config).__name__}")
|
||||
print(f" ✓ 资产池对象类型: {type(config.asset_pools).__name__}")
|
||||
print(f" ✓ 资产数量: {config.asset_pools.count()}")
|
||||
print(f" ✓ 分组数量: {len(config.asset_pools.groups)}")
|
||||
|
||||
# 测试数据源
|
||||
print("\n[验证] 数据源配置:")
|
||||
for i, source in enumerate(config.data.sources, 1):
|
||||
print(f" {i}. {source.type.value}")
|
||||
print(f" URL: {source.url}")
|
||||
print(f" 启用: {source.enabled}")
|
||||
|
||||
print("\n✓ 所有测试通过!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ 加载失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
186
rotation/config_simple.yaml
Normal file
186
rotation/config_simple.yaml
Normal file
@@ -0,0 +1,186 @@
|
||||
# ETF轮动策略配置(V2 框架)
|
||||
#
|
||||
# 配置版本: 2.0.0
|
||||
# 最后更新: 2024-04-16
|
||||
# 策略名称: rotation
|
||||
# 描述: 全球资产大类轮动策略 - 复现 V1 结果
|
||||
|
||||
# ============================================================
|
||||
# 元数据
|
||||
# ============================================================
|
||||
metadata:
|
||||
version: "2.0.0"
|
||||
strategy: "rotation"
|
||||
description: "全球资产大类轮动策略 V2 - 复现 V1 结果"
|
||||
last_updated: "2024-04-16"
|
||||
|
||||
# ============================================================
|
||||
# 资产池配置(扁平化设计:严格对齐 V1 config.yaml)
|
||||
# ============================================================
|
||||
asset_pools:
|
||||
assets:
|
||||
# 中国A股指数
|
||||
"399006.SZ":
|
||||
name: "创业板指"
|
||||
group: "A"
|
||||
signal_source: "399006.SZ"
|
||||
trade_source: "159915.SZ"
|
||||
description: "创业板指数"
|
||||
|
||||
"H30269.CSI":
|
||||
name: "中证红利低波"
|
||||
group: "A"
|
||||
signal_source: "H30269.CSI"
|
||||
trade_source: "512890.SH"
|
||||
description: "红利低波指数"
|
||||
|
||||
# 全球市场
|
||||
"NDX":
|
||||
name: "纳指100"
|
||||
group: "US"
|
||||
signal_source: "NDX"
|
||||
trade_source: "513100.SH"
|
||||
description: "纳斯达克100指数"
|
||||
|
||||
"N225":
|
||||
name: "日经225"
|
||||
group: "JP"
|
||||
signal_source: "N225"
|
||||
trade_source: "513520.SH"
|
||||
description: "日经225指数"
|
||||
|
||||
"GDAXI":
|
||||
name: "德国DAX"
|
||||
group: "EU"
|
||||
signal_source: "GDAXI"
|
||||
trade_source: "513030.SH"
|
||||
description: "德国DAX指数"
|
||||
|
||||
"HSI":
|
||||
name: "恒生指数"
|
||||
group: "HK"
|
||||
signal_source: "HSI"
|
||||
trade_source: "159920.SZ"
|
||||
description: "恒生指数"
|
||||
|
||||
"HSTECH.HK":
|
||||
name: "恒生科技"
|
||||
group: "HK"
|
||||
signal_source: "HSTECH.HK"
|
||||
trade_source: "513130.SH"
|
||||
description: "恒生科技指数"
|
||||
|
||||
# 商品(使用 COMEX/WTI 期货替代上期所主力合约,数据更长)
|
||||
"GC=F":
|
||||
name: "黄金"
|
||||
group: "COMMODITY"
|
||||
signal_source: "GC=F"
|
||||
trade_source: "518880.SH"
|
||||
description: "COMEX黄金期货(2000年至今)"
|
||||
|
||||
"CL=F":
|
||||
name: "原油"
|
||||
group: "COMMODITY"
|
||||
signal_source: "CL=F"
|
||||
trade_source: "160723.SZ"
|
||||
description: "WTI原油期货(2000年至今)"
|
||||
|
||||
"HG=F":
|
||||
name: "有色金属"
|
||||
group: "COMMODITY"
|
||||
signal_source: "HG=F"
|
||||
trade_source: "159980.SZ"
|
||||
description: "COMEX铜期货(2000年至今)"
|
||||
|
||||
# 防御类资产:短债指数
|
||||
# 931862.CSI = 中证0-9个月国债指数(短债指数)
|
||||
# 数据范围:2007-12-31开始,约19年数据
|
||||
# 久期:极短(<1年),波动极小,熊市防御效果最佳
|
||||
# 收益归因:标的收益约17%,决策收益约83%
|
||||
# 注意:无对应ETF可交易,直接使用指数数据计算动量和收益
|
||||
"931862.CSI":
|
||||
name: "短债指数"
|
||||
group: "BOND"
|
||||
signal_source: "931862.CSI"
|
||||
trade_source: "931862.CSI"
|
||||
description: "中证0-9个月国债指数,久期<1年,防御配置"
|
||||
|
||||
# ============================================================
|
||||
# 基准配置
|
||||
# ============================================================
|
||||
benchmark:
|
||||
code: "000300.SH"
|
||||
name: "沪深300"
|
||||
|
||||
# ============================================================
|
||||
# 回测配置
|
||||
# ============================================================
|
||||
backtest:
|
||||
start_date: "2020-01-10" # 与 V1 保持一致(第一个完整交易日)
|
||||
# end_date: "2026-05-22" # 与 V1 保持一致
|
||||
|
||||
# ============================================================
|
||||
# 因子配置
|
||||
# ============================================================
|
||||
factor:
|
||||
type: "weighted_momentum" # 加权动量
|
||||
n_days: 25 # 25 天窗口
|
||||
|
||||
# ============================================================
|
||||
# 轮动配置
|
||||
# ============================================================
|
||||
rotation:
|
||||
select_num: 3 # 选择 Top-3
|
||||
diversified: true # 强制分散化:每个大类只选 Top 1
|
||||
|
||||
# 阈值配置(V3 动态阈值)
|
||||
threshold:
|
||||
mode: "dynamic" # 动态阈值模式
|
||||
fixed_value: 0.0 # 固定阈值(mode=fixed时使用)
|
||||
|
||||
# 动态阈值配置(使用短债动量作为阈值)
|
||||
dynamic:
|
||||
reference: "931862.CSI" # 阈值参考标的(短债指数)
|
||||
ratio: 1.0 # 阈值 = 短债动量 × ratio
|
||||
fallback_enabled: true # 参考不可用时是否回退
|
||||
fallback_value: 0.0 # 回退值
|
||||
|
||||
# ============================================================
|
||||
# 调仓配置
|
||||
# ============================================================
|
||||
rebalance:
|
||||
min_hold_days: 1
|
||||
score_threshold: 0.0
|
||||
trade_cost: 0.001 # 0.1% 交易成本
|
||||
|
||||
# ============================================================
|
||||
# 溢价控制配置
|
||||
# ============================================================
|
||||
premium_control:
|
||||
enabled: false # 启用溢价控制
|
||||
default_threshold: 0.10 # 默认溢价阈值 10%
|
||||
mode: "filter" # filter(完全排除) 或 penalize(降权)
|
||||
penalty_factor: 0.5 # 降权模式下的惩罚系数
|
||||
|
||||
# 按市场覆盖配置
|
||||
market_overrides:
|
||||
A: # A股 ETF
|
||||
enabled: false # 不启用(溢价通常 < 0.5%)
|
||||
HK: # 港股 ETF
|
||||
enabled: true
|
||||
threshold: 0.10 # 阈值 10%
|
||||
US: # 美股 ETF
|
||||
enabled: true
|
||||
threshold: 0.10 # 阈值 10%
|
||||
COMMODITY: # 商品 ETF
|
||||
enabled: false # 不启用
|
||||
|
||||
# ============================================================
|
||||
# 数据配置
|
||||
# ============================================================
|
||||
data:
|
||||
sources:
|
||||
- type: "flask_api"
|
||||
enabled: true
|
||||
url: "${FLASK_API_URL}"
|
||||
timeout: 120
|
||||
584
rotation/simple_rotation.py
Normal file
584
rotation/simple_rotation.py
Normal file
@@ -0,0 +1,584 @@
|
||||
"""
|
||||
Simple Rotation Strategy (Daily Iteration)
|
||||
|
||||
From A-share trading calendar, iterate daily:
|
||||
1. Get signal source last N days -> compute momentum
|
||||
2. Get bond momentum (dynamic threshold)
|
||||
3. Group selection -> generate holdings
|
||||
4. Compare with yesterday -> compute T+1 return
|
||||
5. Update NAV
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import json
|
||||
import time
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from rotation.config_loader import load_rotation_config, RotationStrategyConfig
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Pure functions: momentum
|
||||
# ============================================================
|
||||
|
||||
def weighted_momentum_score(prices: np.ndarray) -> float:
|
||||
"""Weighted linear regression momentum score = annualized_return * R^2"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
y = np.log(prices)
|
||||
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
|
||||
return 0.0
|
||||
x = np.arange(len(y))
|
||||
weights = np.linspace(1, 2, len(y))
|
||||
slope, intercept = np.polyfit(x, y, 1, w=weights)
|
||||
annualized_return = math.exp(slope * 250) - 1
|
||||
y_pred = slope * x + intercept
|
||||
ss_res = np.sum(weights * (y - y_pred) ** 2)
|
||||
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
|
||||
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||
return annualized_return * r2
|
||||
|
||||
|
||||
def is_crash(prices: np.ndarray) -> bool:
|
||||
"""Crash filter: 3 consecutive days drop > 5%"""
|
||||
if len(prices) < 4:
|
||||
return False
|
||||
p = prices[-4:]
|
||||
r1 = p[3] / p[2]
|
||||
r2 = p[2] / p[1]
|
||||
r3 = p[1] / p[0]
|
||||
con1 = min(r1, r2, r3) < 0.95
|
||||
con2 = (r1 < 1 and r2 < 1 and r3 < 1 and p[3] / p[0] < 0.95)
|
||||
return con1 or con2
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Data cache
|
||||
# ============================================================
|
||||
|
||||
class DataCache:
|
||||
"""CSV-cached data fetching for index (raw) and ETF (hfq)"""
|
||||
|
||||
def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.api_path = '/api/v1/ohlcv'
|
||||
self.timeout = timeout
|
||||
if cache_dir is None:
|
||||
cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache'
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _cache_path(self, code: str, adj: str) -> Path:
|
||||
prefix = 'index' if adj == 'raw' else 'etf'
|
||||
safe_code = code.replace('=', '_').replace('^', '_')
|
||||
return self.cache_dir / f"{prefix}_{safe_code}.csv"
|
||||
|
||||
def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
|
||||
"""Preload full history and cache to CSV"""
|
||||
cache_path = self._cache_path(code, adj)
|
||||
if cache_path.exists():
|
||||
try:
|
||||
df = pd.read_csv(cache_path, index_col='date', parse_dates=True)
|
||||
if len(df) > 0:
|
||||
cs = df.index.min().strftime('%Y-%m-%d')
|
||||
ce = df.index.max().strftime('%Y-%m-%d')
|
||||
if cs <= start_date and ce >= end_date:
|
||||
return df
|
||||
new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d')
|
||||
if new_start <= end_date:
|
||||
new_df = self._fetch_api(code, new_start, end_date, adj)
|
||||
if new_df is not None and len(new_df) > 0:
|
||||
df = pd.concat([df, new_df])
|
||||
df = df[~df.index.duplicated(keep='last')]
|
||||
df.to_csv(cache_path)
|
||||
return df
|
||||
except Exception:
|
||||
pass
|
||||
df = self._fetch_api(code, start_date, end_date, adj)
|
||||
if df is not None and len(df) > 0:
|
||||
df.to_csv(cache_path)
|
||||
return df
|
||||
|
||||
def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
|
||||
"""Fetch from Flask API"""
|
||||
url = f"{self.base_url}{self.api_path}"
|
||||
params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
|
||||
for attempt in range(3):
|
||||
try:
|
||||
resp = requests.get(url, params=params, timeout=self.timeout)
|
||||
if resp.status_code != 200:
|
||||
if attempt < 2:
|
||||
time.sleep(1)
|
||||
continue
|
||||
print(f" x {code}: HTTP {resp.status_code}")
|
||||
return None
|
||||
data = resp.json()
|
||||
if 'error' in data:
|
||||
print(f" x {code}: {data['error']}")
|
||||
return None
|
||||
records = data.get('data', [])
|
||||
if not records:
|
||||
return None
|
||||
df = pd.DataFrame(records)
|
||||
if 'date' in df.columns:
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df = df.set_index('date').sort_index()
|
||||
keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
|
||||
df = df[keep]
|
||||
print(f" + {code}: {len(df)} rows ({adj})")
|
||||
return df
|
||||
except requests.exceptions.Timeout:
|
||||
if attempt < 2:
|
||||
continue
|
||||
print(f" x {code}: timeout")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f" x {code}: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]:
|
||||
"""Fetch trading calendar from API"""
|
||||
url = f"{self.base_url}/api/v1/trading-calendar"
|
||||
params = {'market': market, 'start': start_date, 'end': end_date}
|
||||
for attempt in range(3):
|
||||
try:
|
||||
resp = requests.get(url, params=params, timeout=self.timeout)
|
||||
if resp.status_code != 200:
|
||||
if attempt < 2:
|
||||
continue
|
||||
return None
|
||||
data = resp.json()
|
||||
if 'error' in data:
|
||||
print(f" x calendar: {data['error']}")
|
||||
return None
|
||||
dates = data.get('trading_dates', [])
|
||||
if not dates:
|
||||
return pd.DatetimeIndex([])
|
||||
result = pd.DatetimeIndex(dates)
|
||||
print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
|
||||
return result
|
||||
except Exception as e:
|
||||
if attempt < 2:
|
||||
continue
|
||||
print(f" x calendar: {e}")
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Core strategy class
|
||||
# ============================================================
|
||||
|
||||
class SimpleRotationStrategy:
|
||||
"""
|
||||
Simple rotation strategy (daily iteration)
|
||||
|
||||
Flow:
|
||||
1. Load config
|
||||
2. Preload all historical data (index raw + ETF hfq)
|
||||
3. Fetch A-share trading calendar
|
||||
4. For each trading day:
|
||||
- Compute momentum (last 25 days)
|
||||
- Group selection -> generate holdings
|
||||
- Compare with yesterday -> compute T+1 return
|
||||
- Update NAV
|
||||
5. Export results
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str = None):
|
||||
if config_path is None:
|
||||
config_path = Path(__file__).parent / 'config_simple.yaml'
|
||||
self.config: RotationStrategyConfig = load_rotation_config(str(config_path))
|
||||
|
||||
# Strategy params
|
||||
self.n_days = self.config.factor.n_days
|
||||
self.select_num = self.config.rotation.select_num
|
||||
self.trade_cost = self.config.rebalance.trade_cost
|
||||
|
||||
# Dynamic threshold
|
||||
threshold = self.config.rotation.threshold
|
||||
self.use_dynamic_threshold = (threshold.mode.value == 'dynamic')
|
||||
self.bond_code = threshold.dynamic.reference if threshold.dynamic else None
|
||||
self.bond_ratio = threshold.dynamic.ratio if threshold.dynamic else 1.0
|
||||
self.fallback_value = threshold.dynamic.fallback_value if threshold.dynamic else 0.0
|
||||
|
||||
# Signal codes and mappings
|
||||
self.signal_codes = []
|
||||
self.signal_to_trade = {}
|
||||
self.code_to_group = {}
|
||||
for code, asset in self.config.asset_pools.assets.items():
|
||||
self.signal_codes.append(asset.signal_source)
|
||||
self.signal_to_trade[asset.signal_source] = asset.trade_source
|
||||
self.code_to_group[asset.signal_source] = asset.group
|
||||
|
||||
# Data source
|
||||
data_source = self.config.data.sources[0]
|
||||
base_url = data_source.url or 'https://k3s.tokenpluse.xyz'
|
||||
self.data_cache = DataCache(base_url=base_url, timeout=data_source.timeout)
|
||||
|
||||
# Preloaded data
|
||||
self.index_data: Dict[str, pd.DataFrame] = {}
|
||||
self.etf_data: Dict[str, pd.DataFrame] = {}
|
||||
|
||||
# Results
|
||||
self.daily_records: List[dict] = []
|
||||
self.trading_calendar: Optional[pd.DatetimeIndex] = None
|
||||
|
||||
def _preload_data(self):
|
||||
"""Preload all historical data"""
|
||||
start_date = self.config.backtest.start_date
|
||||
end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
|
||||
preload_start = (pd.Timestamp(start_date) - timedelta(days=self.n_days * 2)).strftime('%Y-%m-%d')
|
||||
|
||||
print("\n[1/4] Preloading signal sources (index raw)...")
|
||||
for code in self.signal_codes:
|
||||
df = self.data_cache.preload(code, preload_start, end_date, adj='raw')
|
||||
if df is not None:
|
||||
self.index_data[code] = df
|
||||
print(f"\n Signal: {len(self.index_data)}/{len(self.signal_codes)} OK")
|
||||
|
||||
print("\n[2/4] Preloading trade sources (ETF hfq)...")
|
||||
trade_codes = set(self.signal_to_trade.values())
|
||||
for code in trade_codes:
|
||||
is_bond = any(
|
||||
a.trade_source == code and a.group == 'BOND'
|
||||
for a in self.config.asset_pools.assets.values()
|
||||
)
|
||||
adj = 'raw' if is_bond else 'hfq'
|
||||
df = self.data_cache.preload(code, preload_start, end_date, adj=adj)
|
||||
if df is not None:
|
||||
self.etf_data[code] = df
|
||||
print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK")
|
||||
|
||||
def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
|
||||
"""Compute momentum for a single code on a given date"""
|
||||
if signal_code not in self.index_data:
|
||||
return None
|
||||
df = self.index_data[signal_code]
|
||||
mask = df.index <= date
|
||||
recent = df.loc[mask]
|
||||
if len(recent) < self.n_days:
|
||||
return None
|
||||
prices = recent['close'].values[-self.n_days:]
|
||||
if len(prices) >= 4 and is_crash(prices):
|
||||
return 0.0
|
||||
return weighted_momentum_score(prices)
|
||||
|
||||
def _generate_signals(self, date: pd.Timestamp) -> List[str]:
|
||||
"""
|
||||
Generate rotation signals (group competition + dynamic threshold + BOND fill)
|
||||
|
||||
Logic (identical to V2):
|
||||
1. Each group: select Top 1 (non-BOND groups must exceed bond_momentum * ratio)
|
||||
2. From all group winners: sort by momentum, select Top select_num
|
||||
3. Fill remaining slots with BOND
|
||||
"""
|
||||
factors: Dict[str, float] = {}
|
||||
for code in self.signal_codes:
|
||||
score = self._compute_momentum(code, date)
|
||||
if score is not None:
|
||||
factors[code] = score
|
||||
if not factors:
|
||||
return []
|
||||
|
||||
bond_momentum = None
|
||||
if self.use_dynamic_threshold and self.bond_code:
|
||||
bond_momentum = self._compute_momentum(self.bond_code, date)
|
||||
if bond_momentum is None:
|
||||
bond_momentum = self.fallback_value
|
||||
|
||||
groups = self.config.asset_pools.by_group
|
||||
selected_by_group: Dict[str, Tuple[str, float]] = {}
|
||||
|
||||
for group_name, assets in groups.items():
|
||||
group_codes = [a.signal_source for a in assets.values()]
|
||||
group_factors = {c: factors[c] for c in group_codes if c in factors}
|
||||
if not group_factors:
|
||||
continue
|
||||
if group_name != 'BOND' and bond_momentum is not None:
|
||||
thresh = bond_momentum * self.bond_ratio
|
||||
group_factors = {c: s for c, s in group_factors.items() if s >= thresh}
|
||||
if not group_factors:
|
||||
continue
|
||||
top_code = max(group_factors, key=group_factors.get)
|
||||
selected_by_group[group_name] = (top_code, group_factors[top_code])
|
||||
|
||||
if not selected_by_group:
|
||||
return []
|
||||
|
||||
candidates = list(selected_by_group.values())
|
||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
final_holdings = [code for code, _ in candidates[:self.select_num]]
|
||||
|
||||
if len(final_holdings) < self.select_num and self.bond_code:
|
||||
if self.bond_code not in final_holdings:
|
||||
n_slots = self.select_num - len(final_holdings)
|
||||
final_holdings.extend([self.bond_code] * n_slots)
|
||||
|
||||
return sorted(final_holdings)
|
||||
|
||||
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
|
||||
"""Get ETF prices on a given date: {open, close, prev_close}
|
||||
Note: Index data may lack 'open' column; use close as fallback.
|
||||
"""
|
||||
if trade_code not in self.etf_data:
|
||||
return None
|
||||
df = self.etf_data[trade_code]
|
||||
mask = df.index <= date
|
||||
recent = df.loc[mask]
|
||||
if len(recent) < 2:
|
||||
return None
|
||||
today_row = recent.iloc[-1]
|
||||
prev_row = recent.iloc[-2]
|
||||
today_date = recent.index[-1]
|
||||
if (date - today_date).days > 5:
|
||||
return None
|
||||
close = float(today_row['close'])
|
||||
prev_close = float(prev_row['close'])
|
||||
# Handle missing open (common for index data like 931862.CSI)
|
||||
open_price = float(today_row.get('open', close)) if 'open' in today_row.index else close
|
||||
if pd.isna(open_price) or open_price == 0:
|
||||
open_price = close
|
||||
if pd.isna(close) or pd.isna(prev_close):
|
||||
return None
|
||||
return {
|
||||
'open': open_price,
|
||||
'close': close,
|
||||
'prev_close': prev_close,
|
||||
}
|
||||
|
||||
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
|
||||
"""
|
||||
Compute daily return (T+1 execution):
|
||||
- Hold: close-to-close
|
||||
- Sell: close-to-open (sold at open)
|
||||
- Buy: open-to-close (intraday)
|
||||
"""
|
||||
if not old_holdings:
|
||||
if not new_holdings:
|
||||
return 0.0
|
||||
weight = 1.0 / len(new_holdings)
|
||||
ret = 0.0
|
||||
for code in new_holdings:
|
||||
tc = self.signal_to_trade.get(code, code)
|
||||
p = self._get_etf_prices(tc, date)
|
||||
if p and p['open'] > 0:
|
||||
ret += weight * (p['close'] - p['open']) / p['open']
|
||||
if is_rebalance:
|
||||
ret -= self.trade_cost
|
||||
return ret
|
||||
|
||||
old_set = set(old_holdings)
|
||||
new_set = set(new_holdings)
|
||||
weight = 1.0 / len(old_holdings)
|
||||
daily_return = 0.0
|
||||
|
||||
for code in old_holdings:
|
||||
tc = self.signal_to_trade.get(code, code)
|
||||
p = self._get_etf_prices(tc, date)
|
||||
if p is None or p['prev_close'] == 0:
|
||||
continue
|
||||
if code in new_set:
|
||||
r = (p['close'] - p['prev_close']) / p['prev_close']
|
||||
else:
|
||||
r = (p['open'] - p['prev_close']) / p['prev_close']
|
||||
if not math.isnan(r):
|
||||
daily_return += weight * r
|
||||
|
||||
for code in new_holdings:
|
||||
if code not in old_set:
|
||||
tc = self.signal_to_trade.get(code, code)
|
||||
p = self._get_etf_prices(tc, date)
|
||||
if p and p['open'] > 0 and not math.isnan(p['close']):
|
||||
r = (p['close'] - p['open']) / p['open']
|
||||
if not math.isnan(r):
|
||||
daily_return += weight * r
|
||||
|
||||
if is_rebalance:
|
||||
daily_return -= self.trade_cost
|
||||
return daily_return
|
||||
|
||||
def run(self) -> dict:
|
||||
"""Main backtest loop"""
|
||||
print("=" * 60)
|
||||
print(" Simple Rotation Strategy (Daily Iteration)")
|
||||
print("=" * 60)
|
||||
|
||||
self._preload_data()
|
||||
|
||||
print("\n[3/4] Fetching A-share trading calendar...")
|
||||
start_date = self.config.backtest.start_date
|
||||
end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
|
||||
self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date)
|
||||
|
||||
if self.trading_calendar is None or len(self.trading_calendar) == 0:
|
||||
print(" x Calendar fetch failed")
|
||||
return {}
|
||||
|
||||
print(f"\n[4/4] Backtesting ({len(self.trading_calendar)} trading days)...")
|
||||
current_holdings: List[str] = []
|
||||
nav = 1.0
|
||||
rebalance_count = 0
|
||||
|
||||
for i, date in enumerate(self.trading_calendar):
|
||||
new_holdings = self._generate_signals(date)
|
||||
is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0
|
||||
|
||||
daily_return = self._calculate_daily_return(
|
||||
current_holdings, new_holdings, date, is_rebalance
|
||||
)
|
||||
nav *= (1 + daily_return)
|
||||
|
||||
if is_rebalance:
|
||||
rebalance_count += 1
|
||||
|
||||
self.daily_records.append({
|
||||
'date': date.strftime('%Y-%m-%d'),
|
||||
'nav': round(nav, 6),
|
||||
'daily_return': round(daily_return, 6),
|
||||
'is_rebalance': is_rebalance,
|
||||
'holdings': sorted(new_holdings),
|
||||
'added': sorted(set(new_holdings) - set(current_holdings)),
|
||||
'removed': sorted(set(current_holdings) - set(new_holdings)),
|
||||
})
|
||||
|
||||
current_holdings = new_holdings
|
||||
|
||||
if (i + 1) % 100 == 0 or i == len(self.trading_calendar) - 1:
|
||||
print(f" [{i+1}/{len(self.trading_calendar)}] "
|
||||
f"NAV: {nav:.4f} | Rebal: {rebalance_count}")
|
||||
|
||||
metrics = self._compute_metrics(rebalance_count)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" Backtest Complete")
|
||||
print("=" * 60)
|
||||
if self.daily_records:
|
||||
print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}")
|
||||
print(f" Trading days: {len(self.daily_records)}")
|
||||
print(f" Total return: {metrics['total_return']:.2%}")
|
||||
print(f" Annual return: {metrics['annual_return']:.2%}")
|
||||
print(f" Max drawdown: {metrics['max_drawdown']:.2%}")
|
||||
print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}")
|
||||
print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}")
|
||||
print(f" Win rate: {metrics['win_rate']:.2%}")
|
||||
print(f" Rebalances: {rebalance_count}")
|
||||
print("=" * 60)
|
||||
|
||||
return {
|
||||
'metrics': metrics,
|
||||
'daily_records': self.daily_records,
|
||||
}
|
||||
|
||||
def _compute_metrics(self, rebalance_count: int) -> dict:
|
||||
"""Compute performance metrics"""
|
||||
if not self.daily_records:
|
||||
return {}
|
||||
df = pd.DataFrame(self.daily_records)
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
nav = df['nav']
|
||||
returns = df['daily_return']
|
||||
|
||||
total_return = nav.iloc[-1] / nav.iloc[0] - 1
|
||||
n_days = len(df)
|
||||
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
|
||||
|
||||
peak = nav.cummax()
|
||||
drawdown = (nav - peak) / peak
|
||||
max_drawdown = drawdown.min()
|
||||
|
||||
sharpe = returns.mean() / returns.std() * np.sqrt(252) if returns.std() > 0 else 0
|
||||
calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0
|
||||
|
||||
non_zero = returns[returns != 0]
|
||||
win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0
|
||||
|
||||
return {
|
||||
'total_return': total_return,
|
||||
'annual_return': annual_return,
|
||||
'max_drawdown': max_drawdown,
|
||||
'sharpe_ratio': sharpe,
|
||||
'calmar_ratio': calmar,
|
||||
'win_rate': win_rate,
|
||||
'n_days': n_days,
|
||||
'rebalance_count': rebalance_count,
|
||||
}
|
||||
|
||||
def export_results(self, output_dir: str = None):
|
||||
"""Export backtest results to CSV and JSON"""
|
||||
if not self.daily_records:
|
||||
print(" x No results to export")
|
||||
return
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent / 'results'
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
df = pd.DataFrame(self.daily_records)
|
||||
|
||||
# NAV curve
|
||||
nav_path = output_dir / 'simple_rotation_nav.csv'
|
||||
df[['date', 'nav', 'daily_return']].to_csv(nav_path, index=False)
|
||||
print(f" + NAV: {nav_path}")
|
||||
|
||||
# Signals
|
||||
sig_path = output_dir / 'simple_rotation_signals.csv'
|
||||
df[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(sig_path, index=False)
|
||||
print(f" + Signals: {sig_path}")
|
||||
|
||||
# Detail JSON
|
||||
detail_path = output_dir / 'simple_rotation_detail.json'
|
||||
detail = {
|
||||
'meta': {
|
||||
'mode': 'Simple: Daily Iteration',
|
||||
'start_date': self.config.backtest.start_date,
|
||||
'end_date': self.config.backtest.end_date or 'now',
|
||||
'n_days': self.n_days,
|
||||
'select_num': self.select_num,
|
||||
'trade_cost': self.trade_cost,
|
||||
'bond_threshold': {
|
||||
'enabled': self.use_dynamic_threshold,
|
||||
'bond_code': self.bond_code,
|
||||
'ratio': self.bond_ratio,
|
||||
},
|
||||
},
|
||||
'days': self.daily_records,
|
||||
}
|
||||
with open(detail_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(detail, f, ensure_ascii=False, indent=2)
|
||||
print(f" + Detail: {detail_path}")
|
||||
|
||||
# Metrics JSON
|
||||
metrics = self._compute_metrics(sum(1 for r in self.daily_records if r['is_rebalance']))
|
||||
metrics_path = output_dir / 'simple_rotation_metrics.json'
|
||||
with open(metrics_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metrics, f, ensure_ascii=False, indent=2)
|
||||
print(f" + Metrics: {metrics_path}")
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Entry point
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
if 'FLASK_API_URL' not in os.environ:
|
||||
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
|
||||
|
||||
strategy = SimpleRotationStrategy()
|
||||
result = strategy.run()
|
||||
|
||||
if result:
|
||||
strategy.export_results()
|
||||
@@ -121,13 +121,13 @@ def run_strategy(config_path: str = "strategies/rotation/config.yaml") -> dict:
|
||||
|
||||
logger.info(f"执行命令: {' '.join(cmd)}")
|
||||
|
||||
# 执行策略
|
||||
# 执行策略(增加超时到15分钟,因为需要获取多市场数据)
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=project_root,
|
||||
timeout=300 # 5分钟超时
|
||||
timeout=900 # 15分钟超时(原5分钟不够,数据获取串行需要更长时间)
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
|
||||
348
scripts/get_trading_calendar.py
Normal file
348
scripts/get_trading_calendar.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
获取 A 股交易日历脚本
|
||||
|
||||
使用 Flask API 交易日历服务获取 A 股交易日历
|
||||
支持多市场、多年份的交易日查询
|
||||
|
||||
用法:
|
||||
python scripts/get_trading_calendar.py
|
||||
python scripts/get_trading_calendar.py --year 2024
|
||||
python scripts/get_trading_calendar.py --start 2024-01-01 --end 2024-12-31
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
import pandas as pd
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = Path(__file__).parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# 加载环境变量
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# 导入 Flask API 数据源
|
||||
from datasource.flask_api_source import FlaskAPIDataSource
|
||||
|
||||
|
||||
def get_calendar_for_year(source: FlaskAPIDataSource, year: int, market: str = 'A'):
|
||||
"""
|
||||
获取指定年份的交易日历
|
||||
|
||||
Args:
|
||||
source: Flask API 数据源实例
|
||||
year: 年份(如 2024)
|
||||
market: 市场代码('A', 'US', 'HK')
|
||||
|
||||
Returns:
|
||||
pd.DatetimeIndex: 交易日序列
|
||||
"""
|
||||
start_date = f"{year}-01-01"
|
||||
end_date = f"{year}-12-31"
|
||||
|
||||
print(f"\n获取 {year} 年 {market} 市场交易日历...")
|
||||
|
||||
trading_dates = source.get_trading_calendar(
|
||||
market=market,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
if trading_dates is None or len(trading_dates) == 0:
|
||||
print(f"✗ {year} 年 {market} 市场无交易日数据")
|
||||
return None
|
||||
|
||||
return trading_dates
|
||||
|
||||
|
||||
def analyze_calendar(trading_dates: pd.DatetimeIndex, year: int):
|
||||
"""
|
||||
分析交易日历统计信息
|
||||
|
||||
Args:
|
||||
trading_dates: 交易日序列
|
||||
year: 年份
|
||||
"""
|
||||
if trading_dates is None or len(trading_dates) == 0:
|
||||
return
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"{year} 年 A 股交易日历分析")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
# 基本统计
|
||||
total_days = len(trading_dates)
|
||||
print(f"\n基本统计:")
|
||||
print(f" 总交易日: {total_days} 天")
|
||||
print(f" 起始日期: {trading_dates.min().strftime('%Y-%m-%d')}")
|
||||
print(f" 结束日期: {trading_dates.max().strftime('%Y-%m-%d')}")
|
||||
|
||||
# 按月份统计
|
||||
print(f"\n按月份统计:")
|
||||
monthly_counts = {}
|
||||
for date in trading_dates:
|
||||
month = date.month
|
||||
monthly_counts[month] = monthly_counts.get(month, 0) + 1
|
||||
|
||||
for month in range(1, 13):
|
||||
count = monthly_counts.get(month, 0)
|
||||
month_name = datetime(2024, month, 1).strftime('%B')
|
||||
print(f" {month:02d}月 ({month_name}): {count} 天")
|
||||
|
||||
# 按季度统计
|
||||
print(f"\n按季度统计:")
|
||||
quarterly_counts = {1: 0, 2: 0, 3: 0, 4: 0}
|
||||
for date in trading_dates:
|
||||
quarter = (date.month - 1) // 3 + 1
|
||||
quarterly_counts[quarter] += 1
|
||||
|
||||
for quarter, count in quarterly_counts.items():
|
||||
print(f" Q{quarter}: {count} 天")
|
||||
|
||||
# 特殊日期统计
|
||||
print(f"\n特殊日期:")
|
||||
first_date = trading_dates.min()
|
||||
last_date = trading_dates.max()
|
||||
print(f" 首个交易日: {first_date.strftime('%Y-%m-%d')} ({first_date.strftime('%A')})")
|
||||
print(f" 最后交易日: {last_date.strftime('%Y-%m-%d')} ({last_date.strftime('%A')})")
|
||||
|
||||
# 查找节假日后的首个交易日(通过间隔判断)
|
||||
gaps = []
|
||||
for i in range(1, len(trading_dates)):
|
||||
prev_date = trading_dates[i-1]
|
||||
curr_date = trading_dates[i]
|
||||
gap_days = (curr_date - prev_date).days
|
||||
if gap_days > 3: # 超过3天视为可能节假日
|
||||
gaps.append({
|
||||
'prev': prev_date,
|
||||
'curr': curr_date,
|
||||
'gap': gap_days
|
||||
})
|
||||
|
||||
if gaps:
|
||||
print(f"\n可能的节假日(间隔 > 3天):")
|
||||
for gap_info in gaps[:5]: # 只显示前5个
|
||||
print(f" {gap_info['prev'].strftime('%Y-%m-%d')} → {gap_info['curr'].strftime('%Y-%m-%d')} "
|
||||
f"(间隔 {gap_info['gap']} 天)")
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
|
||||
|
||||
def compare_markets(source: FlaskAPIDataSource, year: int):
|
||||
"""
|
||||
比较不同市场的交易日历
|
||||
|
||||
Args:
|
||||
source: Flask API 数据源实例
|
||||
year: 年份
|
||||
"""
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"{year} 年不同市场交易日历对比")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
markets = {
|
||||
'A': 'A股(上交所/深交所)',
|
||||
'US': '美股(NYSE)',
|
||||
'HK': '港股(HKEX)'
|
||||
}
|
||||
|
||||
results = {}
|
||||
for market_code, market_name in markets.items():
|
||||
print(f"\n获取 {market_name} 交易日历...")
|
||||
trading_dates = get_calendar_for_year(source, year, market_code)
|
||||
|
||||
if trading_dates is not None and len(trading_dates) > 0:
|
||||
results[market_code] = {
|
||||
'name': market_name,
|
||||
'dates': trading_dates,
|
||||
'count': len(trading_dates)
|
||||
}
|
||||
|
||||
# 对比统计
|
||||
print(f"\n交易日对比:")
|
||||
print(f"{'市场':<20} {'交易日数':<10} {'起始日期':<12} {'结束日期':<12}")
|
||||
print("-" * 60)
|
||||
|
||||
for market_code, data in results.items():
|
||||
print(f"{data['name']:<20} {data['count']:<10} "
|
||||
f"{data['dates'].min().strftime('%Y-%m-%d'):<12} "
|
||||
f"{data['dates'].max().strftime('%Y-%m-%d'):<12}")
|
||||
|
||||
# 计算差异
|
||||
if len(results) >= 2:
|
||||
print(f"\n交易日差异:")
|
||||
market_codes = list(results.keys())
|
||||
for i in range(len(market_codes)):
|
||||
for j in range(i+1, len(market_codes)):
|
||||
m1 = market_codes[i]
|
||||
m2 = market_codes[j]
|
||||
diff = results[m1]['count'] - results[m2]['count']
|
||||
print(f" {results[m1]['name']} vs {results[m2]['name']}: "
|
||||
f"相差 {abs(diff)} 天 ({'+' if diff > 0 else ''}{diff})")
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
|
||||
|
||||
def show_recent_dates(trading_dates: pd.DatetimeIndex, n: int = 10):
|
||||
"""
|
||||
显示最近的交易日
|
||||
|
||||
Args:
|
||||
trading_dates: 交易日序列
|
||||
n: 显示数量
|
||||
"""
|
||||
if trading_dates is None or len(trading_dates) == 0:
|
||||
return
|
||||
|
||||
print(f"\n最近 {n} 个交易日:")
|
||||
recent_dates = trading_dates[-n:] if len(trading_dates) >= n else trading_dates
|
||||
|
||||
for date in recent_dates:
|
||||
weekday = date.strftime('%A')
|
||||
print(f" {date.strftime('%Y-%m-%d')} ({weekday})")
|
||||
|
||||
|
||||
def export_calendar(trading_dates: pd.DatetimeIndex, output_path: str, year: int):
|
||||
"""
|
||||
导出交易日历到 CSV
|
||||
|
||||
Args:
|
||||
trading_dates: 交易日序列
|
||||
output_path: 输出路径
|
||||
year: 年份
|
||||
"""
|
||||
if trading_dates is None or len(trading_dates) == 0:
|
||||
return
|
||||
|
||||
# 创建 DataFrame
|
||||
df = pd.DataFrame({
|
||||
'date': trading_dates,
|
||||
'year': trading_dates.year,
|
||||
'month': trading_dates.month,
|
||||
'quarter': (trading_dates.month - 1) // 3 + 1,
|
||||
'weekday': [d.strftime('%A') for d in trading_dates]
|
||||
})
|
||||
|
||||
# 导出到 CSV
|
||||
filename = f"{output_path}/trading_calendar_A_{year}.csv"
|
||||
df.to_csv(filename, index=False)
|
||||
print(f"\n✓ 交易日历已导出到: {filename}")
|
||||
print(f" 文件包含 {len(df)} 条记录")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(description='获取 A 股交易日历')
|
||||
|
||||
parser.add_argument(
|
||||
'--year',
|
||||
type=int,
|
||||
default=datetime.now().year,
|
||||
help='年份(默认当前年份)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--start',
|
||||
type=str,
|
||||
help='起始日期 YYYY-MM-DD'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--end',
|
||||
type=str,
|
||||
help='结束日期 YYYY-MM-DD'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--market',
|
||||
type=str,
|
||||
default='A',
|
||||
choices=['A', 'US', 'HK'],
|
||||
help='市场代码(A=A股, US=美股, HK=港股)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--compare',
|
||||
action='store_true',
|
||||
help='对比不同市场交易日历'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--export',
|
||||
action='store_true',
|
||||
help='导出交易日历到 CSV'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--output',
|
||||
type=str,
|
||||
default='data',
|
||||
help='导出目录(默认 data)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 初始化 Flask API 数据源
|
||||
print("\n初始化 Flask API 数据源...")
|
||||
source = FlaskAPIDataSource()
|
||||
|
||||
# 检查服务健康状态
|
||||
health = source.get_health()
|
||||
if health.get('status') != 'healthy':
|
||||
print(f"✗ Flask API 服务不可用: {health}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"✓ Flask API 服务可用 ({source.base_url})")
|
||||
|
||||
# 获取交易日历信息
|
||||
calendar_info = source.get_calendar_info()
|
||||
if 'error' not in calendar_info:
|
||||
print(f"\n交易日历服务信息:")
|
||||
print(f" 支持市场: {', '.join(calendar_info.get('markets', []))}")
|
||||
print(f" 数据源: {calendar_info.get('source', 'pandas_market_calendars')}")
|
||||
|
||||
# 执行不同功能
|
||||
if args.compare:
|
||||
# 对比不同市场
|
||||
compare_markets(source, args.year)
|
||||
|
||||
elif args.start and args.end:
|
||||
# 自定义日期范围
|
||||
print(f"\n获取 {args.market} 市场交易日历 ({args.start} ~ {args.end})...")
|
||||
trading_dates = source.get_trading_calendar(
|
||||
market=args.market,
|
||||
start_date=args.start,
|
||||
end_date=args.end
|
||||
)
|
||||
|
||||
if trading_dates is not None:
|
||||
print(f"✓ 获取到 {len(trading_dates)} 个交易日")
|
||||
show_recent_dates(trading_dates)
|
||||
|
||||
if args.export:
|
||||
export_calendar(trading_dates, args.output, args.year)
|
||||
|
||||
else:
|
||||
# 获取指定年份交易日历
|
||||
trading_dates = get_calendar_for_year(source, args.year, args.market)
|
||||
|
||||
if trading_dates is not None:
|
||||
# 分析统计
|
||||
analyze_calendar(trading_dates, args.year)
|
||||
|
||||
# 显示最近交易日
|
||||
show_recent_dates(trading_dates)
|
||||
|
||||
# 导出
|
||||
if args.export:
|
||||
export_calendar(trading_dates, args.output, args.year)
|
||||
|
||||
print("\n✓ 完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user