clean(rotation): add simple rotation strategy and remove unused files

New:
- rotation/simple_rotation.py: daily-iteration rotation strategy (584 lines)
- rotation/config_loader.py: standalone config loader
- rotation/config_simple.yaml: 11 assets, 7 groups
- rotation/README_SIMPLE.md: usage guide
- scripts/get_trading_calendar.py: trading calendar fetcher

Removed:
- rotation/example_usage.py, run_strategy.py (replaced by simple_rotation.py)
- rotation/results/ output files (gitignored)
- scripts/verify_*.py, calculate_returns_from_detail.py (one-off scripts)
- scripts/README_TRADING_CALENDAR.md

Backtest result (2020-01-10 ~ 2026-06-01):
- Total return: 1237.6%, Annual: 52.66%
- Max drawdown: -11.71%, Sharpe: 2.50
This commit is contained in:
2026-06-01 22:28:26 +08:00
parent 3b0208d7d3
commit 451ffa33d2
8 changed files with 1720 additions and 19 deletions

View File

@@ -72,7 +72,18 @@ def run_backtest():
print("=" * 70)
strategy = GlobalRotationStrategy(config)
result = strategy.run()
# 运行回测并导出 detail
output_dir = project_root / "framework_v2" / "results"
output_dir.mkdir(exist_ok=True)
detail_path = output_dir / "backtest_detail_v2.json"
print(f"\n导出 detail JSON: {detail_path}")
result = strategy.run(
export_detail=True,
detail_path=str(detail_path)
)
# 打印结果
print("\n" + "=" * 70)

View File

@@ -391,31 +391,55 @@ class GlobalRotationStrategy(StrategyBase):
aligner = CrossMarketAligner(target_calendar=trading_calendar)
# 提取交易标的的收盘价,并对齐到 A 股日历
print(" [对齐] 对齐 ETF 价格到 A 股日历...")
close_dict = {}
print(" [对齐] 构建可实现价格序列(模拟真实交易)...")
executable_close_dict = {}
for signal_code, trade_code in signal_to_trade.items():
if trade_code in data:
# 提取收盘价
close_series = data[trade_code]['close']
# 使用 signal_code 作为键(与 positions 列名一致)
close_dict[signal_code] = close_series
# 提取开盘价和收盘价
etf_df = data[trade_code]
open_series = etf_df['open'].reindex(trading_calendar, method='ffill')
close_series = etf_df['close'].reindex(trading_calendar, method='ffill')
# 默认使用收盘价
exec_close = close_series.copy()
# 检测调仓日,调整价格以反映真实交易
for i in range(1, len(trading_calendar)):
date = trading_calendar[i]
prev_date = trading_calendar[i-1]
# 获取仓位变化
prev_pos = positions.loc[prev_date, signal_code] if signal_code in positions.columns else 0
curr_pos = positions.loc[date, signal_code] if signal_code in positions.columns else 0
# 买入日:修改前一天价格为当日开盘价
# 这样收益率 = (close[t] - open[t]) / open[t] = 日内收益
if pd.isna(prev_pos) or prev_pos == 0:
if pd.notna(curr_pos) and curr_pos > 0:
exec_close.loc[prev_date] = open_series.loc[date]
# 卖出日:不需要修改(因为 positions[t]=0不会计算收益
executable_close_dict[signal_code] = exec_close
else:
print(f" 警告: {trade_code} 数据不存在,跳过")
# 使用 CrossMarketAligner 对齐多标的收益率
# 内部逻辑:先 ffill 价格到 A 股日历,再计算收益率
print(" [对齐] 计算收益率(先对齐价格,再计算...")
returns_df = aligner.align_multi_asset(close_dict)
print(" [对齐] 计算收益率(使用可实现价格...")
returns_df = aligner.align_multi_asset(executable_close_dict)
print(f" [对齐] 收益率数据: {len(returns_df)} 天, {len(returns_df.columns)} 个标的")
# 对齐 positions 到 A 股日历
# 注意:必须先 reindex 再 ffill因为 reindex(method='ffill') 不会填充已有的 NaN
positions = positions.reindex(trading_calendar)
positions = positions.ffill()
# 卖出日不向前填充(保持 0
positions = positions.ffill().fillna(0)
# 计算策略收益(仓位加权,T+1 执行
positions_delayed = positions.shift(1).fillna(0)
strategy_returns = (positions_delayed * returns_df).sum(axis=1)
# 计算策略收益(仓位加权,无需延迟
# 因为 positions[t] 已表示 t 日的实际持仓,且价格已调整为可实现价格
strategy_returns = (positions * returns_df).sum(axis=1)
# 扣除交易成本
strategy_returns, rebalance_count = self._apply_trade_cost(
@@ -649,6 +673,33 @@ class GlobalRotationStrategy(StrategyBase):
index_return_dict = {}
etf_return_dict = {}
# 构建 ETF 可实现价格序列(与回测一致)
executable_etf_close = {}
for signal_code, trade_code in signal_to_trade.items():
if trade_code in self._data:
etf_df = self._data[trade_code]
open_series = etf_df['open'].reindex(trading_calendar, method='ffill')
close_series = etf_df['close'].reindex(trading_calendar, method='ffill')
# 默认使用 close
exec_close = close_series.copy()
# 检测调仓日,调整价格
for i in range(1, len(trading_calendar)):
date = trading_calendar[i]
prev_date = trading_calendar[i-1]
# 获取仓位变化
prev_pos = positions.loc[prev_date, signal_code] if signal_code in positions.columns else 0
curr_pos = positions.loc[date, signal_code] if signal_code in positions.columns else 0
# 买入日:修改前一天价格为 open
if pd.isna(prev_pos) or prev_pos == 0:
if pd.notna(curr_pos) and curr_pos > 0:
exec_close.loc[prev_date] = open_series.loc[date]
executable_etf_close[signal_code] = exec_close
for signal_code, trade_code in signal_to_trade.items():
# 指数收益率
if signal_code in index_close_dict:
@@ -656,10 +707,10 @@ class GlobalRotationStrategy(StrategyBase):
idx_return = idx_close.pct_change(fill_method=None).fillna(0)
index_return_dict[signal_code] = idx_return
# ETF 收益率
if signal_code in etf_close_dict:
etf_close = etf_close_dict[signal_code].reindex(trading_calendar, method='ffill')
etf_return = etf_close.pct_change(fill_method=None).fillna(0)
# ETF 收益率(使用可实现价格)
if signal_code in executable_etf_close:
etf_exec = executable_etf_close[signal_code]
etf_return = etf_exec.pct_change(fill_method=None).fillna(0)
etf_return_dict[signal_code] = etf_return
# 对齐因子

64
rotation/README_SIMPLE.md Normal file
View File

@@ -0,0 +1,64 @@
# 精简版轮动策略(日迭代)
## 概述
从 A 股交易日历出发,逐日迭代模拟实盘信号生成流程。
**与 V2 的核心差异:**
| 特性 | V2 向量化版本 | 精简版(日迭代) |
|------|--------------|------------------|
| 数据获取 | 一次性获取所有历史数据 | 预加载+CSV缓存 |
| 因子计算 | 全量 rolling 计算 | 每日单独计算 |
| 日历对齐 | 先算因子,后对齐 | 从 A 股日历出发 |
| 收益计算 | 向量化 | 逐日迭代 |
## 快速使用
```bash
cd /path/to/etf
FLASK_API_URL=https://k3s.tokenpluse.xyz python rotation/simple_rotation.py
```
或作为模块导入:
```python
from rotation.simple_rotation import SimpleRotationStrategy
strategy = SimpleRotationStrategy('rotation/config_simple.yaml')
result = strategy.run()
strategy.export_results()
```
## 策略逻辑
### 动量计算
- 加权线性回归:`score = annualized_return * R^2`
- 窗口25 天(可配置)
- 崩盘过滤:连续 3 天跌 > 5% 则清零
### 信号生成(与 V2 完全一致)
1. 每个 group 内选 Top 1非 BOND 组需超过短债动量)
2. 从各组 Top 1 中按动量排序选 Top 3
3. 不足用 BOND 填充
### T+1 执行收益
- **持有日**: close-to-close
- **卖出日**: close-to-open开盘已卖出
- **买入日**: open-to-close日内收益
- 调仓日扣除 0.1% 交易成本
## 输出文件
```
results/
├── simple_rotation_nav.csv # 净值曲线
├── simple_rotation_signals.csv # 每日信号
├── simple_rotation_detail.json # 完整详情
└── simple_rotation_metrics.json # 绩效指标
```
## 数据缓存
首次运行会从 Flask API 下载数据并缓存到 `data/simple_rotation_cache/`
后续运行直接读取缓存,速度显著提升。

457
rotation/config_loader.py Normal file
View File

@@ -0,0 +1,457 @@
"""
配置加载器 - rotation 目录专用(独立实现)
不依赖 framework_v2完全独立的配置加载和验证机制。
使用示例:
from config_loader import load_rotation_config
# 加载配置(自动处理环境变量)
config = load_rotation_config('config_simple.yaml')
# 访问配置
print(f"策略版本: {config.metadata.version}")
print(f"动量窗口: {config.factor.n_days}")
print(f"资产数量: {config.asset_pools.count()}")
"""
import os
import re
import yaml
from pathlib import Path
from typing import Optional, Dict, List, Any
from pydantic import BaseModel, Field, field_validator
from enum import Enum
# ============================================================
# 枚举类型
# ============================================================
class FactorType(str, Enum):
"""因子类型枚举"""
MOMENTUM = "momentum"
SLOPE_R2 = "slope_r2"
WEIGHTED_MOMENTUM = "weighted_momentum"
class PremiumMode(str, Enum):
"""溢价控制模式"""
FILTER = "filter" # 完全排除
PENALIZE = "penalize" # 降权
class ThresholdMode(str, Enum):
"""阈值模式"""
FIXED = "fixed" # 固定阈值
DYNAMIC = "dynamic" # 动态阈值
class DataSourceType(str, Enum):
"""数据源类型"""
FLASK_API = "flask_api"
TUSHARE = "tushare"
YFINANCE = "yfinance"
# ============================================================
# Schema 定义(独立实现)
# ============================================================
class PremiumConfig(BaseModel):
"""溢价控制配置"""
enabled: bool = Field(default=True, description="是否启用")
threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值")
class AssetConfig(BaseModel):
"""标的配置"""
name: str = Field(..., description="标的名称")
group: str = Field(..., description="策略分组")
signal_source: str = Field(..., description="信号来源代码")
trade_source: str = Field(..., description="交易来源代码")
description: Optional[str] = Field(None, description="标的描述")
etf: Optional[str] = Field(None, description="ETF代码兼容旧配置")
premium_control: Optional["PremiumConfig"] = Field(None, description="溢价控制配置")
@property
def is_cross_market(self) -> bool:
"""是否跨市场"""
return self.signal_source != self.trade_source
class AssetPool(BaseModel):
"""资产池(扁平化设计)"""
assets: Dict[str, AssetConfig] = Field(
default_factory=dict,
description="所有标的"
)
@property
def by_group(self) -> Dict[str, Dict[str, AssetConfig]]:
"""按策略分组"""
groups = {}
for code, asset in self.assets.items():
group = asset.group
if group not in groups:
groups[group] = {}
groups[group][code] = asset
return groups
@property
def groups(self) -> list:
"""获取所有策略分组"""
return list(self.by_group.keys())
def get_signal_codes(self, group: str = None) -> list:
"""获取信号标的"""
if group:
return [
asset.signal_source
for asset in self.assets.values()
if asset.group == group
]
return [asset.signal_source for asset in self.assets.values()]
def get_trade_codes(self, group: str = None) -> list:
"""获取交易标的"""
if group:
return [
asset.trade_source
for asset in self.assets.values()
if asset.group == group
]
return [asset.trade_source for asset in self.assets.values()]
def get_signal_to_trade_mapping(self, group: str = None) -> Dict[str, str]:
"""获取信号→交易映射"""
mapping = {}
for asset in self.assets.values():
if group and asset.group != group:
continue
mapping[asset.signal_source] = asset.trade_source
return mapping
def count(self, group: str = None) -> int:
"""获取标的数量"""
if group:
return len([a for a in self.assets.values() if a.group == group])
return len(self.assets)
class FactorConfig(BaseModel):
"""因子配置"""
type: FactorType = Field(default=FactorType.WEIGHTED_MOMENTUM)
n_days: int = Field(default=25, ge=5, le=250)
class DynamicThresholdConfig(BaseModel):
"""动态阈值配置"""
reference: str = Field(..., description="阈值参考标的")
ratio: float = Field(default=1.0, ge=0, description="阈值倍数")
fallback_enabled: bool = Field(default=True)
fallback_value: float = Field(default=0.0)
class ThresholdConfig(BaseModel):
"""阈值配置"""
mode: ThresholdMode = Field(default=ThresholdMode.DYNAMIC)
fixed_value: float = Field(default=0.0)
dynamic: Optional[DynamicThresholdConfig] = Field(None)
class RotationConfig(BaseModel):
"""轮动配置"""
select_num: int = Field(default=3, ge=1, le=10)
diversified: bool = Field(default=True)
threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
class RebalanceConfig(BaseModel):
"""调仓配置"""
min_hold_days: int = Field(default=1, ge=1)
score_threshold: float = Field(default=0.0)
trade_cost: float = Field(default=0.001, ge=0, le=0.01)
class PremiumControlConfig(BaseModel):
"""溢价控制配置"""
enabled: bool = Field(default=False)
default_threshold: float = Field(default=0.10)
mode: PremiumMode = Field(default=PremiumMode.FILTER)
penalty_factor: float = Field(default=0.5)
market_overrides: Optional[Dict[str, Dict[str, Any]]] = Field(None)
class DataSourceConfig(BaseModel):
"""数据源配置"""
type: DataSourceType = Field(...)
enabled: bool = Field(default=True)
timeout: int = Field(default=120, ge=10, le=600)
url: Optional[str] = Field(None)
class DataConfig(BaseModel):
"""数据配置"""
sources: List[DataSourceConfig] = Field(...)
@field_validator('sources')
@classmethod
def check_at_least_one(cls, v):
"""至少配置一个数据源"""
if not v:
raise ValueError("必须配置至少一个数据源")
return v
class BenchmarkConfig(BaseModel):
"""基准配置"""
code: str = Field(...)
name: str = Field(...)
class BacktestConfig(BaseModel):
"""回测配置"""
start_date: str = Field(...)
end_date: Optional[str] = Field(None)
class MetadataConfig(BaseModel):
"""配置元数据"""
version: str = Field(default="1.0.0")
strategy: str = Field(default="rotation")
description: str = Field(default="")
last_updated: Optional[str] = Field(None)
class RotationStrategyConfig(BaseModel):
"""ETF轮动策略完整配置"""
metadata: MetadataConfig = Field(default_factory=MetadataConfig)
asset_pools: AssetPool = Field(...)
benchmark: BenchmarkConfig = Field(...)
backtest: BacktestConfig = Field(...)
factor: FactorConfig = Field(default_factory=FactorConfig)
rotation: RotationConfig = Field(default_factory=RotationConfig)
rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig)
premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig)
data: DataConfig = Field(...)
# ============================================================
# 配置加载器(独立实现)
# ============================================================
class ConfigLoader:
"""配置加载器(独立实现)"""
def __init__(self, config_dir: str = None):
"""初始化"""
if config_dir is None:
config_dir = Path(__file__).parent
self.config_dir = Path(config_dir)
def load(self, config_file: str) -> RotationStrategyConfig:
"""加载配置文件"""
# 1. 解析路径
config_path = self._resolve_path(config_file)
# 2. 读取 YAML
with open(config_path, 'r', encoding='utf-8') as f:
config_dict = yaml.safe_load(f)
# 3. 环境变量替换
config_dict = self._substitute_env_vars(config_dict)
# 4. Pydantic 验证
config = RotationStrategyConfig(**config_dict)
return config
def _resolve_path(self, config_file: str) -> Path:
"""解析配置文件路径"""
path = Path(config_file)
if path.is_absolute():
if not path.exists():
raise FileNotFoundError(f"配置文件不存在: {path}")
return path
# 相对路径:先在配置目录查找
config_path = self.config_dir / path
if config_path.exists():
return config_path
# 然后在当前工作目录查找
cwd_path = Path.cwd() / path
if cwd_path.exists():
return cwd_path
raise FileNotFoundError(
f"配置文件未找到: {path}\n"
f"搜索路径:\n"
f" - {config_path}\n"
f" - {cwd_path}"
)
def _substitute_env_vars(self, config: Any) -> Any:
"""替换环境变量"""
if isinstance(config, dict):
return {
key: self._substitute_env_vars(value)
for key, value in config.items()
}
elif isinstance(config, list):
return [
self._substitute_env_vars(item)
for item in config
]
elif isinstance(config, str):
# 匹配 ${VAR_NAME} 或 ${VAR_NAME:default}
pattern = r'\$\{([^}]+)\}'
def replace_match(match):
var_expr = match.group(1)
if ':' in var_expr:
var_name, default = var_expr.split(':', 1)
else:
var_name = var_expr
default = None
value = os.getenv(var_name)
if value is None:
if default is not None:
return default
else:
raise ValueError(
f"环境变量未设置: {var_name}\n"
f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}"
)
return value
return re.sub(pattern, replace_match, config)
else:
return config
# ============================================================
# 快捷函数
# ============================================================
def load_rotation_config(config_file: str = 'config_simple.yaml') -> RotationStrategyConfig:
"""
加载 rotation 配置
Args:
config_file: 配置文件名
Returns:
RotationStrategyConfig: 验证后的配置对象
环境变量:
必须设置 FLASK_API_URL 或提供默认值
"""
loader = ConfigLoader()
return loader.load(config_file)
def get_config_info(config: RotationStrategyConfig) -> dict:
"""获取配置摘要"""
return {
'version': config.metadata.version,
'strategy': config.metadata.strategy,
'description': config.metadata.description,
'asset_count': config.asset_pools.count(),
'groups': config.asset_pools.groups,
'factor_type': config.factor.type.value,
'n_days': config.factor.n_days,
'select_num': config.rotation.select_num,
'start_date': config.backtest.start_date,
'end_date': config.backtest.end_date or '至今',
'threshold_mode': config.rotation.threshold.mode.value,
}
def print_config_summary(config: RotationStrategyConfig) -> None:
"""打印配置摘要"""
info = get_config_info(config)
print("\n" + "=" * 50)
print(" ETF 轮动策略配置摘要")
print("=" * 50)
print(f"\n版本: {info['version']}")
print(f"策略: {info['strategy']}")
print(f"描述: {info['description']}")
print(f"\n资产池:")
print(f" 总数: {info['asset_count']}")
print(f" 分组: {', '.join(info['groups'])}")
print(f"\n因子配置:")
print(f" 类型: {info['factor_type']}")
print(f" 窗口: {info['n_days']}")
print(f"\n轮动配置:")
print(f" 选股数: {info['select_num']}")
print(f" 阈值模式: {info['threshold_mode']}")
print(f"\n回测配置:")
print(f" 起始: {info['start_date']}")
print(f" 结束: {info['end_date']}")
print(f"\n各分组资产:")
for group, assets in config.asset_pools.by_group.items():
print(f"\n [{group}] ({len(assets)} 个):")
for code, asset in assets.items():
cross_market = "✓ 跨市场" if asset.is_cross_market else ""
print(f" - {asset.name} ({code})")
print(f" 信号: {asset.signal_source} → 交易: {asset.trade_source} {cross_market}")
print("\n" + "=" * 50)
# ============================================================
# 测试
# ============================================================
if __name__ == "__main__":
"""测试配置加载器"""
# 设置环境变量(如果没有设置)
if 'FLASK_API_URL' not in os.environ:
print("提示: FLASK_API_URL 未设置,使用默认值")
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
try:
# 加载配置
print("\n[测试] 加载 config_simple.yaml...")
config = load_rotation_config()
# 打印摘要
print_config_summary(config)
# 测试基本功能
print("\n[验证] 基本功能测试:")
print(f" ✓ 配置对象类型: {type(config).__name__}")
print(f" ✓ 资产池对象类型: {type(config.asset_pools).__name__}")
print(f" ✓ 资产数量: {config.asset_pools.count()}")
print(f" ✓ 分组数量: {len(config.asset_pools.groups)}")
# 测试数据源
print("\n[验证] 数据源配置:")
for i, source in enumerate(config.data.sources, 1):
print(f" {i}. {source.type.value}")
print(f" URL: {source.url}")
print(f" 启用: {source.enabled}")
print("\n✓ 所有测试通过!")
except Exception as e:
print(f"\n✗ 加载失败: {e}")
import traceback
traceback.print_exc()

186
rotation/config_simple.yaml Normal file
View File

@@ -0,0 +1,186 @@
# ETF轮动策略配置V2 框架)
#
# 配置版本: 2.0.0
# 最后更新: 2024-04-16
# 策略名称: rotation
# 描述: 全球资产大类轮动策略 - 复现 V1 结果
# ============================================================
# 元数据
# ============================================================
metadata:
version: "2.0.0"
strategy: "rotation"
description: "全球资产大类轮动策略 V2 - 复现 V1 结果"
last_updated: "2024-04-16"
# ============================================================
# 资产池配置(扁平化设计:严格对齐 V1 config.yaml
# ============================================================
asset_pools:
assets:
# 中国A股指数
"399006.SZ":
name: "创业板指"
group: "A"
signal_source: "399006.SZ"
trade_source: "159915.SZ"
description: "创业板指数"
"H30269.CSI":
name: "中证红利低波"
group: "A"
signal_source: "H30269.CSI"
trade_source: "512890.SH"
description: "红利低波指数"
# 全球市场
"NDX":
name: "纳指100"
group: "US"
signal_source: "NDX"
trade_source: "513100.SH"
description: "纳斯达克100指数"
"N225":
name: "日经225"
group: "JP"
signal_source: "N225"
trade_source: "513520.SH"
description: "日经225指数"
"GDAXI":
name: "德国DAX"
group: "EU"
signal_source: "GDAXI"
trade_source: "513030.SH"
description: "德国DAX指数"
"HSI":
name: "恒生指数"
group: "HK"
signal_source: "HSI"
trade_source: "159920.SZ"
description: "恒生指数"
"HSTECH.HK":
name: "恒生科技"
group: "HK"
signal_source: "HSTECH.HK"
trade_source: "513130.SH"
description: "恒生科技指数"
# 商品(使用 COMEX/WTI 期货替代上期所主力合约,数据更长)
"GC=F":
name: "黄金"
group: "COMMODITY"
signal_source: "GC=F"
trade_source: "518880.SH"
description: "COMEX黄金期货2000年至今"
"CL=F":
name: "原油"
group: "COMMODITY"
signal_source: "CL=F"
trade_source: "160723.SZ"
description: "WTI原油期货2000年至今"
"HG=F":
name: "有色金属"
group: "COMMODITY"
signal_source: "HG=F"
trade_source: "159980.SZ"
description: "COMEX铜期货2000年至今"
# 防御类资产:短债指数
# 931862.CSI = 中证0-9个月国债指数短债指数
# 数据范围2007-12-31开始约19年数据
# 久期:极短(<1年波动极小熊市防御效果最佳
# 收益归因标的收益约17%决策收益约83%
# 注意无对应ETF可交易直接使用指数数据计算动量和收益
"931862.CSI":
name: "短债指数"
group: "BOND"
signal_source: "931862.CSI"
trade_source: "931862.CSI"
description: "中证0-9个月国债指数久期<1年防御配置"
# ============================================================
# 基准配置
# ============================================================
benchmark:
code: "000300.SH"
name: "沪深300"
# ============================================================
# 回测配置
# ============================================================
backtest:
start_date: "2020-01-10" # 与 V1 保持一致(第一个完整交易日)
# end_date: "2026-05-22" # 与 V1 保持一致
# ============================================================
# 因子配置
# ============================================================
factor:
type: "weighted_momentum" # 加权动量
n_days: 25 # 25 天窗口
# ============================================================
# 轮动配置
# ============================================================
rotation:
select_num: 3 # 选择 Top-3
diversified: true # 强制分散化:每个大类只选 Top 1
# 阈值配置V3 动态阈值)
threshold:
mode: "dynamic" # 动态阈值模式
fixed_value: 0.0 # 固定阈值mode=fixed时使用
# 动态阈值配置(使用短债动量作为阈值)
dynamic:
reference: "931862.CSI" # 阈值参考标的(短债指数)
ratio: 1.0 # 阈值 = 短债动量 × ratio
fallback_enabled: true # 参考不可用时是否回退
fallback_value: 0.0 # 回退值
# ============================================================
# 调仓配置
# ============================================================
rebalance:
min_hold_days: 1
score_threshold: 0.0
trade_cost: 0.001 # 0.1% 交易成本
# ============================================================
# 溢价控制配置
# ============================================================
premium_control:
enabled: false # 启用溢价控制
default_threshold: 0.10 # 默认溢价阈值 10%
mode: "filter" # filter(完全排除) 或 penalize(降权)
penalty_factor: 0.5 # 降权模式下的惩罚系数
# 按市场覆盖配置
market_overrides:
A: # A股 ETF
enabled: false # 不启用(溢价通常 < 0.5%
HK: # 港股 ETF
enabled: true
threshold: 0.10 # 阈值 10%
US: # 美股 ETF
enabled: true
threshold: 0.10 # 阈值 10%
COMMODITY: # 商品 ETF
enabled: false # 不启用
# ============================================================
# 数据配置
# ============================================================
data:
sources:
- type: "flask_api"
enabled: true
url: "${FLASK_API_URL}"
timeout: 120

584
rotation/simple_rotation.py Normal file
View File

@@ -0,0 +1,584 @@
"""
Simple Rotation Strategy (Daily Iteration)
From A-share trading calendar, iterate daily:
1. Get signal source last N days -> compute momentum
2. Get bond momentum (dynamic threshold)
3. Group selection -> generate holdings
4. Compare with yesterday -> compute T+1 return
5. Update NAV
"""
import os
import sys
import math
import json
import time
import requests
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.config_loader import load_rotation_config, RotationStrategyConfig
# ============================================================
# Pure functions: momentum
# ============================================================
def weighted_momentum_score(prices: np.ndarray) -> float:
"""Weighted linear regression momentum score = annualized_return * R^2"""
if len(prices) < 5:
return 0.0
prices = np.clip(prices, 0.01, None)
y = np.log(prices)
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
return 0.0
x = np.arange(len(y))
weights = np.linspace(1, 2, len(y))
slope, intercept = np.polyfit(x, y, 1, w=weights)
annualized_return = math.exp(slope * 250) - 1
y_pred = slope * x + intercept
ss_res = np.sum(weights * (y - y_pred) ** 2)
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
return annualized_return * r2
def is_crash(prices: np.ndarray) -> bool:
"""Crash filter: 3 consecutive days drop > 5%"""
if len(prices) < 4:
return False
p = prices[-4:]
r1 = p[3] / p[2]
r2 = p[2] / p[1]
r3 = p[1] / p[0]
con1 = min(r1, r2, r3) < 0.95
con2 = (r1 < 1 and r2 < 1 and r3 < 1 and p[3] / p[0] < 0.95)
return con1 or con2
# ============================================================
# Data cache
# ============================================================
class DataCache:
"""CSV-cached data fetching for index (raw) and ETF (hfq)"""
def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120):
self.base_url = base_url.rstrip('/')
self.api_path = '/api/v1/ohlcv'
self.timeout = timeout
if cache_dir is None:
cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache'
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _cache_path(self, code: str, adj: str) -> Path:
prefix = 'index' if adj == 'raw' else 'etf'
safe_code = code.replace('=', '_').replace('^', '_')
return self.cache_dir / f"{prefix}_{safe_code}.csv"
def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
"""Preload full history and cache to CSV"""
cache_path = self._cache_path(code, adj)
if cache_path.exists():
try:
df = pd.read_csv(cache_path, index_col='date', parse_dates=True)
if len(df) > 0:
cs = df.index.min().strftime('%Y-%m-%d')
ce = df.index.max().strftime('%Y-%m-%d')
if cs <= start_date and ce >= end_date:
return df
new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d')
if new_start <= end_date:
new_df = self._fetch_api(code, new_start, end_date, adj)
if new_df is not None and len(new_df) > 0:
df = pd.concat([df, new_df])
df = df[~df.index.duplicated(keep='last')]
df.to_csv(cache_path)
return df
except Exception:
pass
df = self._fetch_api(code, start_date, end_date, adj)
if df is not None and len(df) > 0:
df.to_csv(cache_path)
return df
def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
"""Fetch from Flask API"""
url = f"{self.base_url}{self.api_path}"
params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
for attempt in range(3):
try:
resp = requests.get(url, params=params, timeout=self.timeout)
if resp.status_code != 200:
if attempt < 2:
time.sleep(1)
continue
print(f" x {code}: HTTP {resp.status_code}")
return None
data = resp.json()
if 'error' in data:
print(f" x {code}: {data['error']}")
return None
records = data.get('data', [])
if not records:
return None
df = pd.DataFrame(records)
if 'date' in df.columns:
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date').sort_index()
keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
df = df[keep]
print(f" + {code}: {len(df)} rows ({adj})")
return df
except requests.exceptions.Timeout:
if attempt < 2:
continue
print(f" x {code}: timeout")
return None
except Exception as e:
print(f" x {code}: {e}")
return None
return None
def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]:
"""Fetch trading calendar from API"""
url = f"{self.base_url}/api/v1/trading-calendar"
params = {'market': market, 'start': start_date, 'end': end_date}
for attempt in range(3):
try:
resp = requests.get(url, params=params, timeout=self.timeout)
if resp.status_code != 200:
if attempt < 2:
continue
return None
data = resp.json()
if 'error' in data:
print(f" x calendar: {data['error']}")
return None
dates = data.get('trading_dates', [])
if not dates:
return pd.DatetimeIndex([])
result = pd.DatetimeIndex(dates)
print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
return result
except Exception as e:
if attempt < 2:
continue
print(f" x calendar: {e}")
return None
return None
# ============================================================
# Core strategy class
# ============================================================
class SimpleRotationStrategy:
"""
Simple rotation strategy (daily iteration)
Flow:
1. Load config
2. Preload all historical data (index raw + ETF hfq)
3. Fetch A-share trading calendar
4. For each trading day:
- Compute momentum (last 25 days)
- Group selection -> generate holdings
- Compare with yesterday -> compute T+1 return
- Update NAV
5. Export results
"""
def __init__(self, config_path: str = None):
if config_path is None:
config_path = Path(__file__).parent / 'config_simple.yaml'
self.config: RotationStrategyConfig = load_rotation_config(str(config_path))
# Strategy params
self.n_days = self.config.factor.n_days
self.select_num = self.config.rotation.select_num
self.trade_cost = self.config.rebalance.trade_cost
# Dynamic threshold
threshold = self.config.rotation.threshold
self.use_dynamic_threshold = (threshold.mode.value == 'dynamic')
self.bond_code = threshold.dynamic.reference if threshold.dynamic else None
self.bond_ratio = threshold.dynamic.ratio if threshold.dynamic else 1.0
self.fallback_value = threshold.dynamic.fallback_value if threshold.dynamic else 0.0
# Signal codes and mappings
self.signal_codes = []
self.signal_to_trade = {}
self.code_to_group = {}
for code, asset in self.config.asset_pools.assets.items():
self.signal_codes.append(asset.signal_source)
self.signal_to_trade[asset.signal_source] = asset.trade_source
self.code_to_group[asset.signal_source] = asset.group
# Data source
data_source = self.config.data.sources[0]
base_url = data_source.url or 'https://k3s.tokenpluse.xyz'
self.data_cache = DataCache(base_url=base_url, timeout=data_source.timeout)
# Preloaded data
self.index_data: Dict[str, pd.DataFrame] = {}
self.etf_data: Dict[str, pd.DataFrame] = {}
# Results
self.daily_records: List[dict] = []
self.trading_calendar: Optional[pd.DatetimeIndex] = None
def _preload_data(self):
"""Preload all historical data"""
start_date = self.config.backtest.start_date
end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
preload_start = (pd.Timestamp(start_date) - timedelta(days=self.n_days * 2)).strftime('%Y-%m-%d')
print("\n[1/4] Preloading signal sources (index raw)...")
for code in self.signal_codes:
df = self.data_cache.preload(code, preload_start, end_date, adj='raw')
if df is not None:
self.index_data[code] = df
print(f"\n Signal: {len(self.index_data)}/{len(self.signal_codes)} OK")
print("\n[2/4] Preloading trade sources (ETF hfq)...")
trade_codes = set(self.signal_to_trade.values())
for code in trade_codes:
is_bond = any(
a.trade_source == code and a.group == 'BOND'
for a in self.config.asset_pools.assets.values()
)
adj = 'raw' if is_bond else 'hfq'
df = self.data_cache.preload(code, preload_start, end_date, adj=adj)
if df is not None:
self.etf_data[code] = df
print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK")
def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
"""Compute momentum for a single code on a given date"""
if signal_code not in self.index_data:
return None
df = self.index_data[signal_code]
mask = df.index <= date
recent = df.loc[mask]
if len(recent) < self.n_days:
return None
prices = recent['close'].values[-self.n_days:]
if len(prices) >= 4 and is_crash(prices):
return 0.0
return weighted_momentum_score(prices)
def _generate_signals(self, date: pd.Timestamp) -> List[str]:
"""
Generate rotation signals (group competition + dynamic threshold + BOND fill)
Logic (identical to V2):
1. Each group: select Top 1 (non-BOND groups must exceed bond_momentum * ratio)
2. From all group winners: sort by momentum, select Top select_num
3. Fill remaining slots with BOND
"""
factors: Dict[str, float] = {}
for code in self.signal_codes:
score = self._compute_momentum(code, date)
if score is not None:
factors[code] = score
if not factors:
return []
bond_momentum = None
if self.use_dynamic_threshold and self.bond_code:
bond_momentum = self._compute_momentum(self.bond_code, date)
if bond_momentum is None:
bond_momentum = self.fallback_value
groups = self.config.asset_pools.by_group
selected_by_group: Dict[str, Tuple[str, float]] = {}
for group_name, assets in groups.items():
group_codes = [a.signal_source for a in assets.values()]
group_factors = {c: factors[c] for c in group_codes if c in factors}
if not group_factors:
continue
if group_name != 'BOND' and bond_momentum is not None:
thresh = bond_momentum * self.bond_ratio
group_factors = {c: s for c, s in group_factors.items() if s >= thresh}
if not group_factors:
continue
top_code = max(group_factors, key=group_factors.get)
selected_by_group[group_name] = (top_code, group_factors[top_code])
if not selected_by_group:
return []
candidates = list(selected_by_group.values())
candidates.sort(key=lambda x: x[1], reverse=True)
final_holdings = [code for code, _ in candidates[:self.select_num]]
if len(final_holdings) < self.select_num and self.bond_code:
if self.bond_code not in final_holdings:
n_slots = self.select_num - len(final_holdings)
final_holdings.extend([self.bond_code] * n_slots)
return sorted(final_holdings)
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
"""Get ETF prices on a given date: {open, close, prev_close}
Note: Index data may lack 'open' column; use close as fallback.
"""
if trade_code not in self.etf_data:
return None
df = self.etf_data[trade_code]
mask = df.index <= date
recent = df.loc[mask]
if len(recent) < 2:
return None
today_row = recent.iloc[-1]
prev_row = recent.iloc[-2]
today_date = recent.index[-1]
if (date - today_date).days > 5:
return None
close = float(today_row['close'])
prev_close = float(prev_row['close'])
# Handle missing open (common for index data like 931862.CSI)
open_price = float(today_row.get('open', close)) if 'open' in today_row.index else close
if pd.isna(open_price) or open_price == 0:
open_price = close
if pd.isna(close) or pd.isna(prev_close):
return None
return {
'open': open_price,
'close': close,
'prev_close': prev_close,
}
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
"""
Compute daily return (T+1 execution):
- Hold: close-to-close
- Sell: close-to-open (sold at open)
- Buy: open-to-close (intraday)
"""
if not old_holdings:
if not new_holdings:
return 0.0
weight = 1.0 / len(new_holdings)
ret = 0.0
for code in new_holdings:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
if p and p['open'] > 0:
ret += weight * (p['close'] - p['open']) / p['open']
if is_rebalance:
ret -= self.trade_cost
return ret
old_set = set(old_holdings)
new_set = set(new_holdings)
weight = 1.0 / len(old_holdings)
daily_return = 0.0
for code in old_holdings:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
if p is None or p['prev_close'] == 0:
continue
if code in new_set:
r = (p['close'] - p['prev_close']) / p['prev_close']
else:
r = (p['open'] - p['prev_close']) / p['prev_close']
if not math.isnan(r):
daily_return += weight * r
for code in new_holdings:
if code not in old_set:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
if p and p['open'] > 0 and not math.isnan(p['close']):
r = (p['close'] - p['open']) / p['open']
if not math.isnan(r):
daily_return += weight * r
if is_rebalance:
daily_return -= self.trade_cost
return daily_return
def run(self) -> dict:
"""Main backtest loop"""
print("=" * 60)
print(" Simple Rotation Strategy (Daily Iteration)")
print("=" * 60)
self._preload_data()
print("\n[3/4] Fetching A-share trading calendar...")
start_date = self.config.backtest.start_date
end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date)
if self.trading_calendar is None or len(self.trading_calendar) == 0:
print(" x Calendar fetch failed")
return {}
print(f"\n[4/4] Backtesting ({len(self.trading_calendar)} trading days)...")
current_holdings: List[str] = []
nav = 1.0
rebalance_count = 0
for i, date in enumerate(self.trading_calendar):
new_holdings = self._generate_signals(date)
is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0
daily_return = self._calculate_daily_return(
current_holdings, new_holdings, date, is_rebalance
)
nav *= (1 + daily_return)
if is_rebalance:
rebalance_count += 1
self.daily_records.append({
'date': date.strftime('%Y-%m-%d'),
'nav': round(nav, 6),
'daily_return': round(daily_return, 6),
'is_rebalance': is_rebalance,
'holdings': sorted(new_holdings),
'added': sorted(set(new_holdings) - set(current_holdings)),
'removed': sorted(set(current_holdings) - set(new_holdings)),
})
current_holdings = new_holdings
if (i + 1) % 100 == 0 or i == len(self.trading_calendar) - 1:
print(f" [{i+1}/{len(self.trading_calendar)}] "
f"NAV: {nav:.4f} | Rebal: {rebalance_count}")
metrics = self._compute_metrics(rebalance_count)
print("\n" + "=" * 60)
print(" Backtest Complete")
print("=" * 60)
if self.daily_records:
print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}")
print(f" Trading days: {len(self.daily_records)}")
print(f" Total return: {metrics['total_return']:.2%}")
print(f" Annual return: {metrics['annual_return']:.2%}")
print(f" Max drawdown: {metrics['max_drawdown']:.2%}")
print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}")
print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}")
print(f" Win rate: {metrics['win_rate']:.2%}")
print(f" Rebalances: {rebalance_count}")
print("=" * 60)
return {
'metrics': metrics,
'daily_records': self.daily_records,
}
def _compute_metrics(self, rebalance_count: int) -> dict:
"""Compute performance metrics"""
if not self.daily_records:
return {}
df = pd.DataFrame(self.daily_records)
df['date'] = pd.to_datetime(df['date'])
nav = df['nav']
returns = df['daily_return']
total_return = nav.iloc[-1] / nav.iloc[0] - 1
n_days = len(df)
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
peak = nav.cummax()
drawdown = (nav - peak) / peak
max_drawdown = drawdown.min()
sharpe = returns.mean() / returns.std() * np.sqrt(252) if returns.std() > 0 else 0
calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0
non_zero = returns[returns != 0]
win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0
return {
'total_return': total_return,
'annual_return': annual_return,
'max_drawdown': max_drawdown,
'sharpe_ratio': sharpe,
'calmar_ratio': calmar,
'win_rate': win_rate,
'n_days': n_days,
'rebalance_count': rebalance_count,
}
def export_results(self, output_dir: str = None):
"""Export backtest results to CSV and JSON"""
if not self.daily_records:
print(" x No results to export")
return
if output_dir is None:
output_dir = Path(__file__).parent / 'results'
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
df = pd.DataFrame(self.daily_records)
# NAV curve
nav_path = output_dir / 'simple_rotation_nav.csv'
df[['date', 'nav', 'daily_return']].to_csv(nav_path, index=False)
print(f" + NAV: {nav_path}")
# Signals
sig_path = output_dir / 'simple_rotation_signals.csv'
df[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(sig_path, index=False)
print(f" + Signals: {sig_path}")
# Detail JSON
detail_path = output_dir / 'simple_rotation_detail.json'
detail = {
'meta': {
'mode': 'Simple: Daily Iteration',
'start_date': self.config.backtest.start_date,
'end_date': self.config.backtest.end_date or 'now',
'n_days': self.n_days,
'select_num': self.select_num,
'trade_cost': self.trade_cost,
'bond_threshold': {
'enabled': self.use_dynamic_threshold,
'bond_code': self.bond_code,
'ratio': self.bond_ratio,
},
},
'days': self.daily_records,
}
with open(detail_path, 'w', encoding='utf-8') as f:
json.dump(detail, f, ensure_ascii=False, indent=2)
print(f" + Detail: {detail_path}")
# Metrics JSON
metrics = self._compute_metrics(sum(1 for r in self.daily_records if r['is_rebalance']))
metrics_path = output_dir / 'simple_rotation_metrics.json'
with open(metrics_path, 'w', encoding='utf-8') as f:
json.dump(metrics, f, ensure_ascii=False, indent=2)
print(f" + Metrics: {metrics_path}")
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
if 'FLASK_API_URL' not in os.environ:
os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz'
strategy = SimpleRotationStrategy()
result = strategy.run()
if result:
strategy.export_results()

View File

@@ -121,13 +121,13 @@ def run_strategy(config_path: str = "strategies/rotation/config.yaml") -> dict:
logger.info(f"执行命令: {' '.join(cmd)}")
# 执行策略
# 执行策略增加超时到15分钟因为需要获取多市场数据
result = subprocess.run(
cmd,
capture_output=True,
text=True,
cwd=project_root,
timeout=300 # 5分钟超时
timeout=900 # 15分钟超时原5分钟不够数据获取串行需要更长时间
)
if result.returncode != 0:

View File

@@ -0,0 +1,348 @@
"""
获取 A 股交易日历脚本
使用 Flask API 交易日历服务获取 A 股交易日历
支持多市场、多年份的交易日查询
用法:
python scripts/get_trading_calendar.py
python scripts/get_trading_calendar.py --year 2024
python scripts/get_trading_calendar.py --start 2024-01-01 --end 2024-12-31
"""
import sys
import argparse
from pathlib import Path
from datetime import datetime, timedelta
import pandas as pd
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
# 加载环境变量
from dotenv import load_dotenv
load_dotenv()
# 导入 Flask API 数据源
from datasource.flask_api_source import FlaskAPIDataSource
def get_calendar_for_year(source: FlaskAPIDataSource, year: int, market: str = 'A'):
"""
获取指定年份的交易日历
Args:
source: Flask API 数据源实例
year: 年份(如 2024
market: 市场代码('A', 'US', 'HK'
Returns:
pd.DatetimeIndex: 交易日序列
"""
start_date = f"{year}-01-01"
end_date = f"{year}-12-31"
print(f"\n获取 {year}{market} 市场交易日历...")
trading_dates = source.get_trading_calendar(
market=market,
start_date=start_date,
end_date=end_date
)
if trading_dates is None or len(trading_dates) == 0:
print(f"{year}{market} 市场无交易日数据")
return None
return trading_dates
def analyze_calendar(trading_dates: pd.DatetimeIndex, year: int):
"""
分析交易日历统计信息
Args:
trading_dates: 交易日序列
year: 年份
"""
if trading_dates is None or len(trading_dates) == 0:
return
print(f"\n{'=' * 60}")
print(f"{year} 年 A 股交易日历分析")
print(f"{'=' * 60}")
# 基本统计
total_days = len(trading_dates)
print(f"\n基本统计:")
print(f" 总交易日: {total_days}")
print(f" 起始日期: {trading_dates.min().strftime('%Y-%m-%d')}")
print(f" 结束日期: {trading_dates.max().strftime('%Y-%m-%d')}")
# 按月份统计
print(f"\n按月份统计:")
monthly_counts = {}
for date in trading_dates:
month = date.month
monthly_counts[month] = monthly_counts.get(month, 0) + 1
for month in range(1, 13):
count = monthly_counts.get(month, 0)
month_name = datetime(2024, month, 1).strftime('%B')
print(f" {month:02d}月 ({month_name}): {count}")
# 按季度统计
print(f"\n按季度统计:")
quarterly_counts = {1: 0, 2: 0, 3: 0, 4: 0}
for date in trading_dates:
quarter = (date.month - 1) // 3 + 1
quarterly_counts[quarter] += 1
for quarter, count in quarterly_counts.items():
print(f" Q{quarter}: {count}")
# 特殊日期统计
print(f"\n特殊日期:")
first_date = trading_dates.min()
last_date = trading_dates.max()
print(f" 首个交易日: {first_date.strftime('%Y-%m-%d')} ({first_date.strftime('%A')})")
print(f" 最后交易日: {last_date.strftime('%Y-%m-%d')} ({last_date.strftime('%A')})")
# 查找节假日后的首个交易日(通过间隔判断)
gaps = []
for i in range(1, len(trading_dates)):
prev_date = trading_dates[i-1]
curr_date = trading_dates[i]
gap_days = (curr_date - prev_date).days
if gap_days > 3: # 超过3天视为可能节假日
gaps.append({
'prev': prev_date,
'curr': curr_date,
'gap': gap_days
})
if gaps:
print(f"\n可能的节假日(间隔 > 3天:")
for gap_info in gaps[:5]: # 只显示前5个
print(f" {gap_info['prev'].strftime('%Y-%m-%d')}{gap_info['curr'].strftime('%Y-%m-%d')} "
f"(间隔 {gap_info['gap']} 天)")
print(f"\n{'=' * 60}")
def compare_markets(source: FlaskAPIDataSource, year: int):
"""
比较不同市场的交易日历
Args:
source: Flask API 数据源实例
year: 年份
"""
print(f"\n{'=' * 60}")
print(f"{year} 年不同市场交易日历对比")
print(f"{'=' * 60}")
markets = {
'A': 'A股上交所/深交所)',
'US': '美股NYSE',
'HK': '港股HKEX'
}
results = {}
for market_code, market_name in markets.items():
print(f"\n获取 {market_name} 交易日历...")
trading_dates = get_calendar_for_year(source, year, market_code)
if trading_dates is not None and len(trading_dates) > 0:
results[market_code] = {
'name': market_name,
'dates': trading_dates,
'count': len(trading_dates)
}
# 对比统计
print(f"\n交易日对比:")
print(f"{'市场':<20} {'交易日数':<10} {'起始日期':<12} {'结束日期':<12}")
print("-" * 60)
for market_code, data in results.items():
print(f"{data['name']:<20} {data['count']:<10} "
f"{data['dates'].min().strftime('%Y-%m-%d'):<12} "
f"{data['dates'].max().strftime('%Y-%m-%d'):<12}")
# 计算差异
if len(results) >= 2:
print(f"\n交易日差异:")
market_codes = list(results.keys())
for i in range(len(market_codes)):
for j in range(i+1, len(market_codes)):
m1 = market_codes[i]
m2 = market_codes[j]
diff = results[m1]['count'] - results[m2]['count']
print(f" {results[m1]['name']} vs {results[m2]['name']}: "
f"相差 {abs(diff)} 天 ({'+' if diff > 0 else ''}{diff})")
print(f"\n{'=' * 60}")
def show_recent_dates(trading_dates: pd.DatetimeIndex, n: int = 10):
"""
显示最近的交易日
Args:
trading_dates: 交易日序列
n: 显示数量
"""
if trading_dates is None or len(trading_dates) == 0:
return
print(f"\n最近 {n} 个交易日:")
recent_dates = trading_dates[-n:] if len(trading_dates) >= n else trading_dates
for date in recent_dates:
weekday = date.strftime('%A')
print(f" {date.strftime('%Y-%m-%d')} ({weekday})")
def export_calendar(trading_dates: pd.DatetimeIndex, output_path: str, year: int):
"""
导出交易日历到 CSV
Args:
trading_dates: 交易日序列
output_path: 输出路径
year: 年份
"""
if trading_dates is None or len(trading_dates) == 0:
return
# 创建 DataFrame
df = pd.DataFrame({
'date': trading_dates,
'year': trading_dates.year,
'month': trading_dates.month,
'quarter': (trading_dates.month - 1) // 3 + 1,
'weekday': [d.strftime('%A') for d in trading_dates]
})
# 导出到 CSV
filename = f"{output_path}/trading_calendar_A_{year}.csv"
df.to_csv(filename, index=False)
print(f"\n✓ 交易日历已导出到: {filename}")
print(f" 文件包含 {len(df)} 条记录")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description='获取 A 股交易日历')
parser.add_argument(
'--year',
type=int,
default=datetime.now().year,
help='年份(默认当前年份)'
)
parser.add_argument(
'--start',
type=str,
help='起始日期 YYYY-MM-DD'
)
parser.add_argument(
'--end',
type=str,
help='结束日期 YYYY-MM-DD'
)
parser.add_argument(
'--market',
type=str,
default='A',
choices=['A', 'US', 'HK'],
help='市场代码A=A股, US=美股, HK=港股)'
)
parser.add_argument(
'--compare',
action='store_true',
help='对比不同市场交易日历'
)
parser.add_argument(
'--export',
action='store_true',
help='导出交易日历到 CSV'
)
parser.add_argument(
'--output',
type=str,
default='data',
help='导出目录(默认 data'
)
args = parser.parse_args()
# 初始化 Flask API 数据源
print("\n初始化 Flask API 数据源...")
source = FlaskAPIDataSource()
# 检查服务健康状态
health = source.get_health()
if health.get('status') != 'healthy':
print(f"✗ Flask API 服务不可用: {health}")
sys.exit(1)
print(f"✓ Flask API 服务可用 ({source.base_url})")
# 获取交易日历信息
calendar_info = source.get_calendar_info()
if 'error' not in calendar_info:
print(f"\n交易日历服务信息:")
print(f" 支持市场: {', '.join(calendar_info.get('markets', []))}")
print(f" 数据源: {calendar_info.get('source', 'pandas_market_calendars')}")
# 执行不同功能
if args.compare:
# 对比不同市场
compare_markets(source, args.year)
elif args.start and args.end:
# 自定义日期范围
print(f"\n获取 {args.market} 市场交易日历 ({args.start} ~ {args.end})...")
trading_dates = source.get_trading_calendar(
market=args.market,
start_date=args.start,
end_date=args.end
)
if trading_dates is not None:
print(f"✓ 获取到 {len(trading_dates)} 个交易日")
show_recent_dates(trading_dates)
if args.export:
export_calendar(trading_dates, args.output, args.year)
else:
# 获取指定年份交易日历
trading_dates = get_calendar_for_year(source, args.year, args.market)
if trading_dates is not None:
# 分析统计
analyze_calendar(trading_dates, args.year)
# 显示最近交易日
show_recent_dates(trading_dates)
# 导出
if args.export:
export_calendar(trading_dates, args.output, args.year)
print("\n✓ 完成!")
if __name__ == "__main__":
main()