From 451ffa33d201c4338729bdff2fa3ddf6be1fb485 Mon Sep 17 00:00:00 2001 From: aszerW Date: Mon, 1 Jun 2026 22:28:26 +0800 Subject: [PATCH] clean(rotation): add simple rotation strategy and remove unused files New: - rotation/simple_rotation.py: daily-iteration rotation strategy (584 lines) - rotation/config_loader.py: standalone config loader - rotation/config_simple.yaml: 11 assets, 7 groups - rotation/README_SIMPLE.md: usage guide - scripts/get_trading_calendar.py: trading calendar fetcher Removed: - rotation/example_usage.py, run_strategy.py (replaced by simple_rotation.py) - rotation/results/ output files (gitignored) - scripts/verify_*.py, calculate_returns_from_detail.py (one-off scripts) - scripts/README_TRADING_CALENDAR.md Backtest result (2020-01-10 ~ 2026-06-01): - Total return: 1237.6%, Annual: 52.66% - Max drawdown: -11.71%, Sharpe: 2.50 --- .../scripts/backtest_global_rotation.py | 13 +- framework_v2/strategies/rotation/rotation.py | 83 ++- rotation/README_SIMPLE.md | 64 ++ rotation/config_loader.py | 457 ++++++++++++++ rotation/config_simple.yaml | 186 ++++++ rotation/simple_rotation.py | 584 ++++++++++++++++++ scripts/daily_scheduler.py | 4 +- scripts/get_trading_calendar.py | 348 +++++++++++ 8 files changed, 1720 insertions(+), 19 deletions(-) create mode 100644 rotation/README_SIMPLE.md create mode 100644 rotation/config_loader.py create mode 100644 rotation/config_simple.yaml create mode 100644 rotation/simple_rotation.py create mode 100644 scripts/get_trading_calendar.py diff --git a/framework_v2/scripts/backtest_global_rotation.py b/framework_v2/scripts/backtest_global_rotation.py index 649195a..321be1b 100644 --- a/framework_v2/scripts/backtest_global_rotation.py +++ b/framework_v2/scripts/backtest_global_rotation.py @@ -72,7 +72,18 @@ def run_backtest(): print("=" * 70) strategy = GlobalRotationStrategy(config) - result = strategy.run() + + # 运行回测并导出 detail + output_dir = project_root / "framework_v2" / "results" + output_dir.mkdir(exist_ok=True) + + detail_path = output_dir / "backtest_detail_v2.json" + print(f"\n导出 detail JSON: {detail_path}") + + result = strategy.run( + export_detail=True, + detail_path=str(detail_path) + ) # 打印结果 print("\n" + "=" * 70) diff --git a/framework_v2/strategies/rotation/rotation.py b/framework_v2/strategies/rotation/rotation.py index 695bd39..64b72ae 100644 --- a/framework_v2/strategies/rotation/rotation.py +++ b/framework_v2/strategies/rotation/rotation.py @@ -391,31 +391,55 @@ class GlobalRotationStrategy(StrategyBase): aligner = CrossMarketAligner(target_calendar=trading_calendar) # 提取交易标的的收盘价,并对齐到 A 股日历 - print(" [对齐] 对齐 ETF 价格到 A 股日历...") - close_dict = {} + print(" [对齐] 构建可实现价格序列(模拟真实交易)...") + executable_close_dict = {} + for signal_code, trade_code in signal_to_trade.items(): if trade_code in data: - # 提取收盘价 - close_series = data[trade_code]['close'] - # 使用 signal_code 作为键(与 positions 列名一致) - close_dict[signal_code] = close_series + # 提取开盘价和收盘价 + etf_df = data[trade_code] + open_series = etf_df['open'].reindex(trading_calendar, method='ffill') + close_series = etf_df['close'].reindex(trading_calendar, method='ffill') + + # 默认使用收盘价 + exec_close = close_series.copy() + + # 检测调仓日,调整价格以反映真实交易 + for i in range(1, len(trading_calendar)): + date = trading_calendar[i] + prev_date = trading_calendar[i-1] + + # 获取仓位变化 + prev_pos = positions.loc[prev_date, signal_code] if signal_code in positions.columns else 0 + curr_pos = positions.loc[date, signal_code] if signal_code in positions.columns else 0 + + # 买入日:修改前一天价格为当日开盘价 + # 这样收益率 = (close[t] - open[t]) / open[t] = 日内收益 + if pd.isna(prev_pos) or prev_pos == 0: + if pd.notna(curr_pos) and curr_pos > 0: + exec_close.loc[prev_date] = open_series.loc[date] + + # 卖出日:不需要修改(因为 positions[t]=0,不会计算收益) + + executable_close_dict[signal_code] = exec_close else: print(f" 警告: {trade_code} 数据不存在,跳过") # 使用 CrossMarketAligner 对齐多标的收益率 # 内部逻辑:先 ffill 价格到 A 股日历,再计算收益率 - print(" [对齐] 计算收益率(先对齐价格,再计算)...") - returns_df = aligner.align_multi_asset(close_dict) + print(" [对齐] 计算收益率(使用可实现价格)...") + returns_df = aligner.align_multi_asset(executable_close_dict) print(f" [对齐] 收益率数据: {len(returns_df)} 天, {len(returns_df.columns)} 个标的") # 对齐 positions 到 A 股日历 # 注意:必须先 reindex 再 ffill,因为 reindex(method='ffill') 不会填充已有的 NaN positions = positions.reindex(trading_calendar) - positions = positions.ffill() + # 卖出日不向前填充(保持 0) + positions = positions.ffill().fillna(0) - # 计算策略收益(仓位加权,T+1 执行) - positions_delayed = positions.shift(1).fillna(0) - strategy_returns = (positions_delayed * returns_df).sum(axis=1) + # 计算策略收益(仓位加权,无需延迟) + # 因为 positions[t] 已表示 t 日的实际持仓,且价格已调整为可实现价格 + strategy_returns = (positions * returns_df).sum(axis=1) # 扣除交易成本 strategy_returns, rebalance_count = self._apply_trade_cost( @@ -649,6 +673,33 @@ class GlobalRotationStrategy(StrategyBase): index_return_dict = {} etf_return_dict = {} + # 构建 ETF 可实现价格序列(与回测一致) + executable_etf_close = {} + for signal_code, trade_code in signal_to_trade.items(): + if trade_code in self._data: + etf_df = self._data[trade_code] + open_series = etf_df['open'].reindex(trading_calendar, method='ffill') + close_series = etf_df['close'].reindex(trading_calendar, method='ffill') + + # 默认使用 close + exec_close = close_series.copy() + + # 检测调仓日,调整价格 + for i in range(1, len(trading_calendar)): + date = trading_calendar[i] + prev_date = trading_calendar[i-1] + + # 获取仓位变化 + prev_pos = positions.loc[prev_date, signal_code] if signal_code in positions.columns else 0 + curr_pos = positions.loc[date, signal_code] if signal_code in positions.columns else 0 + + # 买入日:修改前一天价格为 open + if pd.isna(prev_pos) or prev_pos == 0: + if pd.notna(curr_pos) and curr_pos > 0: + exec_close.loc[prev_date] = open_series.loc[date] + + executable_etf_close[signal_code] = exec_close + for signal_code, trade_code in signal_to_trade.items(): # 指数收益率 if signal_code in index_close_dict: @@ -656,10 +707,10 @@ class GlobalRotationStrategy(StrategyBase): idx_return = idx_close.pct_change(fill_method=None).fillna(0) index_return_dict[signal_code] = idx_return - # ETF 收益率 - if signal_code in etf_close_dict: - etf_close = etf_close_dict[signal_code].reindex(trading_calendar, method='ffill') - etf_return = etf_close.pct_change(fill_method=None).fillna(0) + # ETF 收益率(使用可实现价格) + if signal_code in executable_etf_close: + etf_exec = executable_etf_close[signal_code] + etf_return = etf_exec.pct_change(fill_method=None).fillna(0) etf_return_dict[signal_code] = etf_return # 对齐因子 diff --git a/rotation/README_SIMPLE.md b/rotation/README_SIMPLE.md new file mode 100644 index 0000000..519cc06 --- /dev/null +++ b/rotation/README_SIMPLE.md @@ -0,0 +1,64 @@ +# 精简版轮动策略(日迭代) + +## 概述 + +从 A 股交易日历出发,逐日迭代模拟实盘信号生成流程。 + +**与 V2 的核心差异:** + +| 特性 | V2 向量化版本 | 精简版(日迭代) | +|------|--------------|------------------| +| 数据获取 | 一次性获取所有历史数据 | 预加载+CSV缓存 | +| 因子计算 | 全量 rolling 计算 | 每日单独计算 | +| 日历对齐 | 先算因子,后对齐 | 从 A 股日历出发 | +| 收益计算 | 向量化 | 逐日迭代 | + +## 快速使用 + +```bash +cd /path/to/etf +FLASK_API_URL=https://k3s.tokenpluse.xyz python rotation/simple_rotation.py +``` + +或作为模块导入: + +```python +from rotation.simple_rotation import SimpleRotationStrategy + +strategy = SimpleRotationStrategy('rotation/config_simple.yaml') +result = strategy.run() +strategy.export_results() +``` + +## 策略逻辑 + +### 动量计算 +- 加权线性回归:`score = annualized_return * R^2` +- 窗口:25 天(可配置) +- 崩盘过滤:连续 3 天跌 > 5% 则清零 + +### 信号生成(与 V2 完全一致) +1. 每个 group 内选 Top 1(非 BOND 组需超过短债动量) +2. 从各组 Top 1 中按动量排序选 Top 3 +3. 不足用 BOND 填充 + +### T+1 执行收益 +- **持有日**: close-to-close +- **卖出日**: close-to-open(开盘已卖出) +- **买入日**: open-to-close(日内收益) +- 调仓日扣除 0.1% 交易成本 + +## 输出文件 + +``` +results/ +├── simple_rotation_nav.csv # 净值曲线 +├── simple_rotation_signals.csv # 每日信号 +├── simple_rotation_detail.json # 完整详情 +└── simple_rotation_metrics.json # 绩效指标 +``` + +## 数据缓存 + +首次运行会从 Flask API 下载数据并缓存到 `data/simple_rotation_cache/`。 +后续运行直接读取缓存,速度显著提升。 diff --git a/rotation/config_loader.py b/rotation/config_loader.py new file mode 100644 index 0000000..939280f --- /dev/null +++ b/rotation/config_loader.py @@ -0,0 +1,457 @@ +""" +配置加载器 - rotation 目录专用(独立实现) + +不依赖 framework_v2,完全独立的配置加载和验证机制。 + +使用示例: + from config_loader import load_rotation_config + + # 加载配置(自动处理环境变量) + config = load_rotation_config('config_simple.yaml') + + # 访问配置 + print(f"策略版本: {config.metadata.version}") + print(f"动量窗口: {config.factor.n_days}") + print(f"资产数量: {config.asset_pools.count()}") +""" + +import os +import re +import yaml +from pathlib import Path +from typing import Optional, Dict, List, Any +from pydantic import BaseModel, Field, field_validator +from enum import Enum + + +# ============================================================ +# 枚举类型 +# ============================================================ + +class FactorType(str, Enum): + """因子类型枚举""" + MOMENTUM = "momentum" + SLOPE_R2 = "slope_r2" + WEIGHTED_MOMENTUM = "weighted_momentum" + + +class PremiumMode(str, Enum): + """溢价控制模式""" + FILTER = "filter" # 完全排除 + PENALIZE = "penalize" # 降权 + + +class ThresholdMode(str, Enum): + """阈值模式""" + FIXED = "fixed" # 固定阈值 + DYNAMIC = "dynamic" # 动态阈值 + + +class DataSourceType(str, Enum): + """数据源类型""" + FLASK_API = "flask_api" + TUSHARE = "tushare" + YFINANCE = "yfinance" + + +# ============================================================ +# Schema 定义(独立实现) +# ============================================================ + +class PremiumConfig(BaseModel): + """溢价控制配置""" + enabled: bool = Field(default=True, description="是否启用") + threshold: float = Field(default=0.10, ge=0, le=1.0, description="溢价阈值") + + +class AssetConfig(BaseModel): + """标的配置""" + name: str = Field(..., description="标的名称") + group: str = Field(..., description="策略分组") + + signal_source: str = Field(..., description="信号来源代码") + trade_source: str = Field(..., description="交易来源代码") + + description: Optional[str] = Field(None, description="标的描述") + etf: Optional[str] = Field(None, description="ETF代码(兼容旧配置)") + premium_control: Optional["PremiumConfig"] = Field(None, description="溢价控制配置") + + @property + def is_cross_market(self) -> bool: + """是否跨市场""" + return self.signal_source != self.trade_source + + +class AssetPool(BaseModel): + """资产池(扁平化设计)""" + assets: Dict[str, AssetConfig] = Field( + default_factory=dict, + description="所有标的" + ) + + @property + def by_group(self) -> Dict[str, Dict[str, AssetConfig]]: + """按策略分组""" + groups = {} + for code, asset in self.assets.items(): + group = asset.group + if group not in groups: + groups[group] = {} + groups[group][code] = asset + return groups + + @property + def groups(self) -> list: + """获取所有策略分组""" + return list(self.by_group.keys()) + + def get_signal_codes(self, group: str = None) -> list: + """获取信号标的""" + if group: + return [ + asset.signal_source + for asset in self.assets.values() + if asset.group == group + ] + return [asset.signal_source for asset in self.assets.values()] + + def get_trade_codes(self, group: str = None) -> list: + """获取交易标的""" + if group: + return [ + asset.trade_source + for asset in self.assets.values() + if asset.group == group + ] + return [asset.trade_source for asset in self.assets.values()] + + def get_signal_to_trade_mapping(self, group: str = None) -> Dict[str, str]: + """获取信号→交易映射""" + mapping = {} + for asset in self.assets.values(): + if group and asset.group != group: + continue + mapping[asset.signal_source] = asset.trade_source + return mapping + + def count(self, group: str = None) -> int: + """获取标的数量""" + if group: + return len([a for a in self.assets.values() if a.group == group]) + return len(self.assets) + + +class FactorConfig(BaseModel): + """因子配置""" + type: FactorType = Field(default=FactorType.WEIGHTED_MOMENTUM) + n_days: int = Field(default=25, ge=5, le=250) + + +class DynamicThresholdConfig(BaseModel): + """动态阈值配置""" + reference: str = Field(..., description="阈值参考标的") + ratio: float = Field(default=1.0, ge=0, description="阈值倍数") + fallback_enabled: bool = Field(default=True) + fallback_value: float = Field(default=0.0) + + +class ThresholdConfig(BaseModel): + """阈值配置""" + mode: ThresholdMode = Field(default=ThresholdMode.DYNAMIC) + fixed_value: float = Field(default=0.0) + dynamic: Optional[DynamicThresholdConfig] = Field(None) + + +class RotationConfig(BaseModel): + """轮动配置""" + select_num: int = Field(default=3, ge=1, le=10) + diversified: bool = Field(default=True) + threshold: ThresholdConfig = Field(default_factory=ThresholdConfig) + + +class RebalanceConfig(BaseModel): + """调仓配置""" + min_hold_days: int = Field(default=1, ge=1) + score_threshold: float = Field(default=0.0) + trade_cost: float = Field(default=0.001, ge=0, le=0.01) + + +class PremiumControlConfig(BaseModel): + """溢价控制配置""" + enabled: bool = Field(default=False) + default_threshold: float = Field(default=0.10) + mode: PremiumMode = Field(default=PremiumMode.FILTER) + penalty_factor: float = Field(default=0.5) + market_overrides: Optional[Dict[str, Dict[str, Any]]] = Field(None) + + +class DataSourceConfig(BaseModel): + """数据源配置""" + type: DataSourceType = Field(...) + enabled: bool = Field(default=True) + timeout: int = Field(default=120, ge=10, le=600) + url: Optional[str] = Field(None) + + +class DataConfig(BaseModel): + """数据配置""" + sources: List[DataSourceConfig] = Field(...) + + @field_validator('sources') + @classmethod + def check_at_least_one(cls, v): + """至少配置一个数据源""" + if not v: + raise ValueError("必须配置至少一个数据源") + return v + + +class BenchmarkConfig(BaseModel): + """基准配置""" + code: str = Field(...) + name: str = Field(...) + + +class BacktestConfig(BaseModel): + """回测配置""" + start_date: str = Field(...) + end_date: Optional[str] = Field(None) + + +class MetadataConfig(BaseModel): + """配置元数据""" + version: str = Field(default="1.0.0") + strategy: str = Field(default="rotation") + description: str = Field(default="") + last_updated: Optional[str] = Field(None) + + +class RotationStrategyConfig(BaseModel): + """ETF轮动策略完整配置""" + metadata: MetadataConfig = Field(default_factory=MetadataConfig) + asset_pools: AssetPool = Field(...) + benchmark: BenchmarkConfig = Field(...) + backtest: BacktestConfig = Field(...) + factor: FactorConfig = Field(default_factory=FactorConfig) + rotation: RotationConfig = Field(default_factory=RotationConfig) + rebalance: RebalanceConfig = Field(default_factory=RebalanceConfig) + premium_control: PremiumControlConfig = Field(default_factory=PremiumControlConfig) + data: DataConfig = Field(...) + + +# ============================================================ +# 配置加载器(独立实现) +# ============================================================ + +class ConfigLoader: + """配置加载器(独立实现)""" + + def __init__(self, config_dir: str = None): + """初始化""" + if config_dir is None: + config_dir = Path(__file__).parent + + self.config_dir = Path(config_dir) + + def load(self, config_file: str) -> RotationStrategyConfig: + """加载配置文件""" + # 1. 解析路径 + config_path = self._resolve_path(config_file) + + # 2. 读取 YAML + with open(config_path, 'r', encoding='utf-8') as f: + config_dict = yaml.safe_load(f) + + # 3. 环境变量替换 + config_dict = self._substitute_env_vars(config_dict) + + # 4. Pydantic 验证 + config = RotationStrategyConfig(**config_dict) + + return config + + def _resolve_path(self, config_file: str) -> Path: + """解析配置文件路径""" + path = Path(config_file) + + if path.is_absolute(): + if not path.exists(): + raise FileNotFoundError(f"配置文件不存在: {path}") + return path + + # 相对路径:先在配置目录查找 + config_path = self.config_dir / path + if config_path.exists(): + return config_path + + # 然后在当前工作目录查找 + cwd_path = Path.cwd() / path + if cwd_path.exists(): + return cwd_path + + raise FileNotFoundError( + f"配置文件未找到: {path}\n" + f"搜索路径:\n" + f" - {config_path}\n" + f" - {cwd_path}" + ) + + def _substitute_env_vars(self, config: Any) -> Any: + """替换环境变量""" + if isinstance(config, dict): + return { + key: self._substitute_env_vars(value) + for key, value in config.items() + } + elif isinstance(config, list): + return [ + self._substitute_env_vars(item) + for item in config + ] + elif isinstance(config, str): + # 匹配 ${VAR_NAME} 或 ${VAR_NAME:default} + pattern = r'\$\{([^}]+)\}' + + def replace_match(match): + var_expr = match.group(1) + + if ':' in var_expr: + var_name, default = var_expr.split(':', 1) + else: + var_name = var_expr + default = None + + value = os.getenv(var_name) + + if value is None: + if default is not None: + return default + else: + raise ValueError( + f"环境变量未设置: {var_name}\n" + f"请设置环境变量或在配置中使用默认值: ${{{var_name}:default_value}}" + ) + + return value + + return re.sub(pattern, replace_match, config) + else: + return config + + +# ============================================================ +# 快捷函数 +# ============================================================ + +def load_rotation_config(config_file: str = 'config_simple.yaml') -> RotationStrategyConfig: + """ + 加载 rotation 配置 + + Args: + config_file: 配置文件名 + + Returns: + RotationStrategyConfig: 验证后的配置对象 + + 环境变量: + 必须设置 FLASK_API_URL 或提供默认值 + """ + loader = ConfigLoader() + return loader.load(config_file) + + +def get_config_info(config: RotationStrategyConfig) -> dict: + """获取配置摘要""" + return { + 'version': config.metadata.version, + 'strategy': config.metadata.strategy, + 'description': config.metadata.description, + 'asset_count': config.asset_pools.count(), + 'groups': config.asset_pools.groups, + 'factor_type': config.factor.type.value, + 'n_days': config.factor.n_days, + 'select_num': config.rotation.select_num, + 'start_date': config.backtest.start_date, + 'end_date': config.backtest.end_date or '至今', + 'threshold_mode': config.rotation.threshold.mode.value, + } + + +def print_config_summary(config: RotationStrategyConfig) -> None: + """打印配置摘要""" + info = get_config_info(config) + + print("\n" + "=" * 50) + print(" ETF 轮动策略配置摘要") + print("=" * 50) + + print(f"\n版本: {info['version']}") + print(f"策略: {info['strategy']}") + print(f"描述: {info['description']}") + + print(f"\n资产池:") + print(f" 总数: {info['asset_count']}") + print(f" 分组: {', '.join(info['groups'])}") + + print(f"\n因子配置:") + print(f" 类型: {info['factor_type']}") + print(f" 窗口: {info['n_days']} 天") + + print(f"\n轮动配置:") + print(f" 选股数: {info['select_num']}") + print(f" 阈值模式: {info['threshold_mode']}") + + print(f"\n回测配置:") + print(f" 起始: {info['start_date']}") + print(f" 结束: {info['end_date']}") + + print(f"\n各分组资产:") + for group, assets in config.asset_pools.by_group.items(): + print(f"\n [{group}] ({len(assets)} 个):") + for code, asset in assets.items(): + cross_market = "✓ 跨市场" if asset.is_cross_market else "" + print(f" - {asset.name} ({code})") + print(f" 信号: {asset.signal_source} → 交易: {asset.trade_source} {cross_market}") + + print("\n" + "=" * 50) + + +# ============================================================ +# 测试 +# ============================================================ + +if __name__ == "__main__": + """测试配置加载器""" + # 设置环境变量(如果没有设置) + if 'FLASK_API_URL' not in os.environ: + print("提示: FLASK_API_URL 未设置,使用默认值") + os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz' + + try: + # 加载配置 + print("\n[测试] 加载 config_simple.yaml...") + config = load_rotation_config() + + # 打印摘要 + print_config_summary(config) + + # 测试基本功能 + print("\n[验证] 基本功能测试:") + print(f" ✓ 配置对象类型: {type(config).__name__}") + print(f" ✓ 资产池对象类型: {type(config.asset_pools).__name__}") + print(f" ✓ 资产数量: {config.asset_pools.count()}") + print(f" ✓ 分组数量: {len(config.asset_pools.groups)}") + + # 测试数据源 + print("\n[验证] 数据源配置:") + for i, source in enumerate(config.data.sources, 1): + print(f" {i}. {source.type.value}") + print(f" URL: {source.url}") + print(f" 启用: {source.enabled}") + + print("\n✓ 所有测试通过!") + + except Exception as e: + print(f"\n✗ 加载失败: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/rotation/config_simple.yaml b/rotation/config_simple.yaml new file mode 100644 index 0000000..996cc95 --- /dev/null +++ b/rotation/config_simple.yaml @@ -0,0 +1,186 @@ +# ETF轮动策略配置(V2 框架) +# +# 配置版本: 2.0.0 +# 最后更新: 2024-04-16 +# 策略名称: rotation +# 描述: 全球资产大类轮动策略 - 复现 V1 结果 + +# ============================================================ +# 元数据 +# ============================================================ +metadata: + version: "2.0.0" + strategy: "rotation" + description: "全球资产大类轮动策略 V2 - 复现 V1 结果" + last_updated: "2024-04-16" + +# ============================================================ +# 资产池配置(扁平化设计:严格对齐 V1 config.yaml) +# ============================================================ +asset_pools: + assets: + # 中国A股指数 + "399006.SZ": + name: "创业板指" + group: "A" + signal_source: "399006.SZ" + trade_source: "159915.SZ" + description: "创业板指数" + + "H30269.CSI": + name: "中证红利低波" + group: "A" + signal_source: "H30269.CSI" + trade_source: "512890.SH" + description: "红利低波指数" + + # 全球市场 + "NDX": + name: "纳指100" + group: "US" + signal_source: "NDX" + trade_source: "513100.SH" + description: "纳斯达克100指数" + + "N225": + name: "日经225" + group: "JP" + signal_source: "N225" + trade_source: "513520.SH" + description: "日经225指数" + + "GDAXI": + name: "德国DAX" + group: "EU" + signal_source: "GDAXI" + trade_source: "513030.SH" + description: "德国DAX指数" + + "HSI": + name: "恒生指数" + group: "HK" + signal_source: "HSI" + trade_source: "159920.SZ" + description: "恒生指数" + + "HSTECH.HK": + name: "恒生科技" + group: "HK" + signal_source: "HSTECH.HK" + trade_source: "513130.SH" + description: "恒生科技指数" + + # 商品(使用 COMEX/WTI 期货替代上期所主力合约,数据更长) + "GC=F": + name: "黄金" + group: "COMMODITY" + signal_source: "GC=F" + trade_source: "518880.SH" + description: "COMEX黄金期货(2000年至今)" + + "CL=F": + name: "原油" + group: "COMMODITY" + signal_source: "CL=F" + trade_source: "160723.SZ" + description: "WTI原油期货(2000年至今)" + + "HG=F": + name: "有色金属" + group: "COMMODITY" + signal_source: "HG=F" + trade_source: "159980.SZ" + description: "COMEX铜期货(2000年至今)" + + # 防御类资产:短债指数 + # 931862.CSI = 中证0-9个月国债指数(短债指数) + # 数据范围:2007-12-31开始,约19年数据 + # 久期:极短(<1年),波动极小,熊市防御效果最佳 + # 收益归因:标的收益约17%,决策收益约83% + # 注意:无对应ETF可交易,直接使用指数数据计算动量和收益 + "931862.CSI": + name: "短债指数" + group: "BOND" + signal_source: "931862.CSI" + trade_source: "931862.CSI" + description: "中证0-9个月国债指数,久期<1年,防御配置" + +# ============================================================ +# 基准配置 +# ============================================================ +benchmark: + code: "000300.SH" + name: "沪深300" + +# ============================================================ +# 回测配置 +# ============================================================ +backtest: + start_date: "2020-01-10" # 与 V1 保持一致(第一个完整交易日) + # end_date: "2026-05-22" # 与 V1 保持一致 + +# ============================================================ +# 因子配置 +# ============================================================ +factor: + type: "weighted_momentum" # 加权动量 + n_days: 25 # 25 天窗口 + +# ============================================================ +# 轮动配置 +# ============================================================ +rotation: + select_num: 3 # 选择 Top-3 + diversified: true # 强制分散化:每个大类只选 Top 1 + + # 阈值配置(V3 动态阈值) + threshold: + mode: "dynamic" # 动态阈值模式 + fixed_value: 0.0 # 固定阈值(mode=fixed时使用) + + # 动态阈值配置(使用短债动量作为阈值) + dynamic: + reference: "931862.CSI" # 阈值参考标的(短债指数) + ratio: 1.0 # 阈值 = 短债动量 × ratio + fallback_enabled: true # 参考不可用时是否回退 + fallback_value: 0.0 # 回退值 + +# ============================================================ +# 调仓配置 +# ============================================================ +rebalance: + min_hold_days: 1 + score_threshold: 0.0 + trade_cost: 0.001 # 0.1% 交易成本 + +# ============================================================ +# 溢价控制配置 +# ============================================================ +premium_control: + enabled: false # 启用溢价控制 + default_threshold: 0.10 # 默认溢价阈值 10% + mode: "filter" # filter(完全排除) 或 penalize(降权) + penalty_factor: 0.5 # 降权模式下的惩罚系数 + + # 按市场覆盖配置 + market_overrides: + A: # A股 ETF + enabled: false # 不启用(溢价通常 < 0.5%) + HK: # 港股 ETF + enabled: true + threshold: 0.10 # 阈值 10% + US: # 美股 ETF + enabled: true + threshold: 0.10 # 阈值 10% + COMMODITY: # 商品 ETF + enabled: false # 不启用 + +# ============================================================ +# 数据配置 +# ============================================================ +data: + sources: + - type: "flask_api" + enabled: true + url: "${FLASK_API_URL}" + timeout: 120 diff --git a/rotation/simple_rotation.py b/rotation/simple_rotation.py new file mode 100644 index 0000000..bbb2a70 --- /dev/null +++ b/rotation/simple_rotation.py @@ -0,0 +1,584 @@ +""" +Simple Rotation Strategy (Daily Iteration) + +From A-share trading calendar, iterate daily: +1. Get signal source last N days -> compute momentum +2. Get bond momentum (dynamic threshold) +3. Group selection -> generate holdings +4. Compare with yesterday -> compute T+1 return +5. Update NAV +""" + +import os +import sys +import math +import json +import time +import requests +import numpy as np +import pandas as pd +from pathlib import Path +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Tuple + +PROJECT_ROOT = Path(__file__).parent.parent +sys.path.insert(0, str(PROJECT_ROOT)) + +from rotation.config_loader import load_rotation_config, RotationStrategyConfig + + +# ============================================================ +# Pure functions: momentum +# ============================================================ + +def weighted_momentum_score(prices: np.ndarray) -> float: + """Weighted linear regression momentum score = annualized_return * R^2""" + if len(prices) < 5: + return 0.0 + prices = np.clip(prices, 0.01, None) + y = np.log(prices) + if np.any(np.isnan(y)) or np.any(np.isinf(y)): + return 0.0 + x = np.arange(len(y)) + weights = np.linspace(1, 2, len(y)) + slope, intercept = np.polyfit(x, y, 1, w=weights) + annualized_return = math.exp(slope * 250) - 1 + y_pred = slope * x + intercept + ss_res = np.sum(weights * (y - y_pred) ** 2) + ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0 + return annualized_return * r2 + + +def is_crash(prices: np.ndarray) -> bool: + """Crash filter: 3 consecutive days drop > 5%""" + if len(prices) < 4: + return False + p = prices[-4:] + r1 = p[3] / p[2] + r2 = p[2] / p[1] + r3 = p[1] / p[0] + con1 = min(r1, r2, r3) < 0.95 + con2 = (r1 < 1 and r2 < 1 and r3 < 1 and p[3] / p[0] < 0.95) + return con1 or con2 + + +# ============================================================ +# Data cache +# ============================================================ + +class DataCache: + """CSV-cached data fetching for index (raw) and ETF (hfq)""" + + def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120): + self.base_url = base_url.rstrip('/') + self.api_path = '/api/v1/ohlcv' + self.timeout = timeout + if cache_dir is None: + cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache' + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def _cache_path(self, code: str, adj: str) -> Path: + prefix = 'index' if adj == 'raw' else 'etf' + safe_code = code.replace('=', '_').replace('^', '_') + return self.cache_dir / f"{prefix}_{safe_code}.csv" + + def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]: + """Preload full history and cache to CSV""" + cache_path = self._cache_path(code, adj) + if cache_path.exists(): + try: + df = pd.read_csv(cache_path, index_col='date', parse_dates=True) + if len(df) > 0: + cs = df.index.min().strftime('%Y-%m-%d') + ce = df.index.max().strftime('%Y-%m-%d') + if cs <= start_date and ce >= end_date: + return df + new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d') + if new_start <= end_date: + new_df = self._fetch_api(code, new_start, end_date, adj) + if new_df is not None and len(new_df) > 0: + df = pd.concat([df, new_df]) + df = df[~df.index.duplicated(keep='last')] + df.to_csv(cache_path) + return df + except Exception: + pass + df = self._fetch_api(code, start_date, end_date, adj) + if df is not None and len(df) > 0: + df.to_csv(cache_path) + return df + + def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]: + """Fetch from Flask API""" + url = f"{self.base_url}{self.api_path}" + params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj} + for attempt in range(3): + try: + resp = requests.get(url, params=params, timeout=self.timeout) + if resp.status_code != 200: + if attempt < 2: + time.sleep(1) + continue + print(f" x {code}: HTTP {resp.status_code}") + return None + data = resp.json() + if 'error' in data: + print(f" x {code}: {data['error']}") + return None + records = data.get('data', []) + if not records: + return None + df = pd.DataFrame(records) + if 'date' in df.columns: + df['date'] = pd.to_datetime(df['date']) + df = df.set_index('date').sort_index() + keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns] + df = df[keep] + print(f" + {code}: {len(df)} rows ({adj})") + return df + except requests.exceptions.Timeout: + if attempt < 2: + continue + print(f" x {code}: timeout") + return None + except Exception as e: + print(f" x {code}: {e}") + return None + return None + + def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]: + """Fetch trading calendar from API""" + url = f"{self.base_url}/api/v1/trading-calendar" + params = {'market': market, 'start': start_date, 'end': end_date} + for attempt in range(3): + try: + resp = requests.get(url, params=params, timeout=self.timeout) + if resp.status_code != 200: + if attempt < 2: + continue + return None + data = resp.json() + if 'error' in data: + print(f" x calendar: {data['error']}") + return None + dates = data.get('trading_dates', []) + if not dates: + return pd.DatetimeIndex([]) + result = pd.DatetimeIndex(dates) + print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})") + return result + except Exception as e: + if attempt < 2: + continue + print(f" x calendar: {e}") + return None + return None + + +# ============================================================ +# Core strategy class +# ============================================================ + +class SimpleRotationStrategy: + """ + Simple rotation strategy (daily iteration) + + Flow: + 1. Load config + 2. Preload all historical data (index raw + ETF hfq) + 3. Fetch A-share trading calendar + 4. For each trading day: + - Compute momentum (last 25 days) + - Group selection -> generate holdings + - Compare with yesterday -> compute T+1 return + - Update NAV + 5. Export results + """ + + def __init__(self, config_path: str = None): + if config_path is None: + config_path = Path(__file__).parent / 'config_simple.yaml' + self.config: RotationStrategyConfig = load_rotation_config(str(config_path)) + + # Strategy params + self.n_days = self.config.factor.n_days + self.select_num = self.config.rotation.select_num + self.trade_cost = self.config.rebalance.trade_cost + + # Dynamic threshold + threshold = self.config.rotation.threshold + self.use_dynamic_threshold = (threshold.mode.value == 'dynamic') + self.bond_code = threshold.dynamic.reference if threshold.dynamic else None + self.bond_ratio = threshold.dynamic.ratio if threshold.dynamic else 1.0 + self.fallback_value = threshold.dynamic.fallback_value if threshold.dynamic else 0.0 + + # Signal codes and mappings + self.signal_codes = [] + self.signal_to_trade = {} + self.code_to_group = {} + for code, asset in self.config.asset_pools.assets.items(): + self.signal_codes.append(asset.signal_source) + self.signal_to_trade[asset.signal_source] = asset.trade_source + self.code_to_group[asset.signal_source] = asset.group + + # Data source + data_source = self.config.data.sources[0] + base_url = data_source.url or 'https://k3s.tokenpluse.xyz' + self.data_cache = DataCache(base_url=base_url, timeout=data_source.timeout) + + # Preloaded data + self.index_data: Dict[str, pd.DataFrame] = {} + self.etf_data: Dict[str, pd.DataFrame] = {} + + # Results + self.daily_records: List[dict] = [] + self.trading_calendar: Optional[pd.DatetimeIndex] = None + + def _preload_data(self): + """Preload all historical data""" + start_date = self.config.backtest.start_date + end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d') + preload_start = (pd.Timestamp(start_date) - timedelta(days=self.n_days * 2)).strftime('%Y-%m-%d') + + print("\n[1/4] Preloading signal sources (index raw)...") + for code in self.signal_codes: + df = self.data_cache.preload(code, preload_start, end_date, adj='raw') + if df is not None: + self.index_data[code] = df + print(f"\n Signal: {len(self.index_data)}/{len(self.signal_codes)} OK") + + print("\n[2/4] Preloading trade sources (ETF hfq)...") + trade_codes = set(self.signal_to_trade.values()) + for code in trade_codes: + is_bond = any( + a.trade_source == code and a.group == 'BOND' + for a in self.config.asset_pools.assets.values() + ) + adj = 'raw' if is_bond else 'hfq' + df = self.data_cache.preload(code, preload_start, end_date, adj=adj) + if df is not None: + self.etf_data[code] = df + print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK") + + def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]: + """Compute momentum for a single code on a given date""" + if signal_code not in self.index_data: + return None + df = self.index_data[signal_code] + mask = df.index <= date + recent = df.loc[mask] + if len(recent) < self.n_days: + return None + prices = recent['close'].values[-self.n_days:] + if len(prices) >= 4 and is_crash(prices): + return 0.0 + return weighted_momentum_score(prices) + + def _generate_signals(self, date: pd.Timestamp) -> List[str]: + """ + Generate rotation signals (group competition + dynamic threshold + BOND fill) + + Logic (identical to V2): + 1. Each group: select Top 1 (non-BOND groups must exceed bond_momentum * ratio) + 2. From all group winners: sort by momentum, select Top select_num + 3. Fill remaining slots with BOND + """ + factors: Dict[str, float] = {} + for code in self.signal_codes: + score = self._compute_momentum(code, date) + if score is not None: + factors[code] = score + if not factors: + return [] + + bond_momentum = None + if self.use_dynamic_threshold and self.bond_code: + bond_momentum = self._compute_momentum(self.bond_code, date) + if bond_momentum is None: + bond_momentum = self.fallback_value + + groups = self.config.asset_pools.by_group + selected_by_group: Dict[str, Tuple[str, float]] = {} + + for group_name, assets in groups.items(): + group_codes = [a.signal_source for a in assets.values()] + group_factors = {c: factors[c] for c in group_codes if c in factors} + if not group_factors: + continue + if group_name != 'BOND' and bond_momentum is not None: + thresh = bond_momentum * self.bond_ratio + group_factors = {c: s for c, s in group_factors.items() if s >= thresh} + if not group_factors: + continue + top_code = max(group_factors, key=group_factors.get) + selected_by_group[group_name] = (top_code, group_factors[top_code]) + + if not selected_by_group: + return [] + + candidates = list(selected_by_group.values()) + candidates.sort(key=lambda x: x[1], reverse=True) + final_holdings = [code for code, _ in candidates[:self.select_num]] + + if len(final_holdings) < self.select_num and self.bond_code: + if self.bond_code not in final_holdings: + n_slots = self.select_num - len(final_holdings) + final_holdings.extend([self.bond_code] * n_slots) + + return sorted(final_holdings) + + def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]: + """Get ETF prices on a given date: {open, close, prev_close} + Note: Index data may lack 'open' column; use close as fallback. + """ + if trade_code not in self.etf_data: + return None + df = self.etf_data[trade_code] + mask = df.index <= date + recent = df.loc[mask] + if len(recent) < 2: + return None + today_row = recent.iloc[-1] + prev_row = recent.iloc[-2] + today_date = recent.index[-1] + if (date - today_date).days > 5: + return None + close = float(today_row['close']) + prev_close = float(prev_row['close']) + # Handle missing open (common for index data like 931862.CSI) + open_price = float(today_row.get('open', close)) if 'open' in today_row.index else close + if pd.isna(open_price) or open_price == 0: + open_price = close + if pd.isna(close) or pd.isna(prev_close): + return None + return { + 'open': open_price, + 'close': close, + 'prev_close': prev_close, + } + + def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance): + """ + Compute daily return (T+1 execution): + - Hold: close-to-close + - Sell: close-to-open (sold at open) + - Buy: open-to-close (intraday) + """ + if not old_holdings: + if not new_holdings: + return 0.0 + weight = 1.0 / len(new_holdings) + ret = 0.0 + for code in new_holdings: + tc = self.signal_to_trade.get(code, code) + p = self._get_etf_prices(tc, date) + if p and p['open'] > 0: + ret += weight * (p['close'] - p['open']) / p['open'] + if is_rebalance: + ret -= self.trade_cost + return ret + + old_set = set(old_holdings) + new_set = set(new_holdings) + weight = 1.0 / len(old_holdings) + daily_return = 0.0 + + for code in old_holdings: + tc = self.signal_to_trade.get(code, code) + p = self._get_etf_prices(tc, date) + if p is None or p['prev_close'] == 0: + continue + if code in new_set: + r = (p['close'] - p['prev_close']) / p['prev_close'] + else: + r = (p['open'] - p['prev_close']) / p['prev_close'] + if not math.isnan(r): + daily_return += weight * r + + for code in new_holdings: + if code not in old_set: + tc = self.signal_to_trade.get(code, code) + p = self._get_etf_prices(tc, date) + if p and p['open'] > 0 and not math.isnan(p['close']): + r = (p['close'] - p['open']) / p['open'] + if not math.isnan(r): + daily_return += weight * r + + if is_rebalance: + daily_return -= self.trade_cost + return daily_return + + def run(self) -> dict: + """Main backtest loop""" + print("=" * 60) + print(" Simple Rotation Strategy (Daily Iteration)") + print("=" * 60) + + self._preload_data() + + print("\n[3/4] Fetching A-share trading calendar...") + start_date = self.config.backtest.start_date + end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d') + self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date) + + if self.trading_calendar is None or len(self.trading_calendar) == 0: + print(" x Calendar fetch failed") + return {} + + print(f"\n[4/4] Backtesting ({len(self.trading_calendar)} trading days)...") + current_holdings: List[str] = [] + nav = 1.0 + rebalance_count = 0 + + for i, date in enumerate(self.trading_calendar): + new_holdings = self._generate_signals(date) + is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0 + + daily_return = self._calculate_daily_return( + current_holdings, new_holdings, date, is_rebalance + ) + nav *= (1 + daily_return) + + if is_rebalance: + rebalance_count += 1 + + self.daily_records.append({ + 'date': date.strftime('%Y-%m-%d'), + 'nav': round(nav, 6), + 'daily_return': round(daily_return, 6), + 'is_rebalance': is_rebalance, + 'holdings': sorted(new_holdings), + 'added': sorted(set(new_holdings) - set(current_holdings)), + 'removed': sorted(set(current_holdings) - set(new_holdings)), + }) + + current_holdings = new_holdings + + if (i + 1) % 100 == 0 or i == len(self.trading_calendar) - 1: + print(f" [{i+1}/{len(self.trading_calendar)}] " + f"NAV: {nav:.4f} | Rebal: {rebalance_count}") + + metrics = self._compute_metrics(rebalance_count) + + print("\n" + "=" * 60) + print(" Backtest Complete") + print("=" * 60) + if self.daily_records: + print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}") + print(f" Trading days: {len(self.daily_records)}") + print(f" Total return: {metrics['total_return']:.2%}") + print(f" Annual return: {metrics['annual_return']:.2%}") + print(f" Max drawdown: {metrics['max_drawdown']:.2%}") + print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}") + print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}") + print(f" Win rate: {metrics['win_rate']:.2%}") + print(f" Rebalances: {rebalance_count}") + print("=" * 60) + + return { + 'metrics': metrics, + 'daily_records': self.daily_records, + } + + def _compute_metrics(self, rebalance_count: int) -> dict: + """Compute performance metrics""" + if not self.daily_records: + return {} + df = pd.DataFrame(self.daily_records) + df['date'] = pd.to_datetime(df['date']) + nav = df['nav'] + returns = df['daily_return'] + + total_return = nav.iloc[-1] / nav.iloc[0] - 1 + n_days = len(df) + annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0 + + peak = nav.cummax() + drawdown = (nav - peak) / peak + max_drawdown = drawdown.min() + + sharpe = returns.mean() / returns.std() * np.sqrt(252) if returns.std() > 0 else 0 + calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0 + + non_zero = returns[returns != 0] + win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0 + + return { + 'total_return': total_return, + 'annual_return': annual_return, + 'max_drawdown': max_drawdown, + 'sharpe_ratio': sharpe, + 'calmar_ratio': calmar, + 'win_rate': win_rate, + 'n_days': n_days, + 'rebalance_count': rebalance_count, + } + + def export_results(self, output_dir: str = None): + """Export backtest results to CSV and JSON""" + if not self.daily_records: + print(" x No results to export") + return + + if output_dir is None: + output_dir = Path(__file__).parent / 'results' + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + df = pd.DataFrame(self.daily_records) + + # NAV curve + nav_path = output_dir / 'simple_rotation_nav.csv' + df[['date', 'nav', 'daily_return']].to_csv(nav_path, index=False) + print(f" + NAV: {nav_path}") + + # Signals + sig_path = output_dir / 'simple_rotation_signals.csv' + df[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(sig_path, index=False) + print(f" + Signals: {sig_path}") + + # Detail JSON + detail_path = output_dir / 'simple_rotation_detail.json' + detail = { + 'meta': { + 'mode': 'Simple: Daily Iteration', + 'start_date': self.config.backtest.start_date, + 'end_date': self.config.backtest.end_date or 'now', + 'n_days': self.n_days, + 'select_num': self.select_num, + 'trade_cost': self.trade_cost, + 'bond_threshold': { + 'enabled': self.use_dynamic_threshold, + 'bond_code': self.bond_code, + 'ratio': self.bond_ratio, + }, + }, + 'days': self.daily_records, + } + with open(detail_path, 'w', encoding='utf-8') as f: + json.dump(detail, f, ensure_ascii=False, indent=2) + print(f" + Detail: {detail_path}") + + # Metrics JSON + metrics = self._compute_metrics(sum(1 for r in self.daily_records if r['is_rebalance'])) + metrics_path = output_dir / 'simple_rotation_metrics.json' + with open(metrics_path, 'w', encoding='utf-8') as f: + json.dump(metrics, f, ensure_ascii=False, indent=2) + print(f" + Metrics: {metrics_path}") + + +# ============================================================ +# Entry point +# ============================================================ + +if __name__ == "__main__": + if 'FLASK_API_URL' not in os.environ: + os.environ['FLASK_API_URL'] = 'https://k3s.tokenpluse.xyz' + + strategy = SimpleRotationStrategy() + result = strategy.run() + + if result: + strategy.export_results() diff --git a/scripts/daily_scheduler.py b/scripts/daily_scheduler.py index 6682eff..a8c8583 100644 --- a/scripts/daily_scheduler.py +++ b/scripts/daily_scheduler.py @@ -121,13 +121,13 @@ def run_strategy(config_path: str = "strategies/rotation/config.yaml") -> dict: logger.info(f"执行命令: {' '.join(cmd)}") - # 执行策略 + # 执行策略(增加超时到15分钟,因为需要获取多市场数据) result = subprocess.run( cmd, capture_output=True, text=True, cwd=project_root, - timeout=300 # 5分钟超时 + timeout=900 # 15分钟超时(原5分钟不够,数据获取串行需要更长时间) ) if result.returncode != 0: diff --git a/scripts/get_trading_calendar.py b/scripts/get_trading_calendar.py new file mode 100644 index 0000000..382e4a5 --- /dev/null +++ b/scripts/get_trading_calendar.py @@ -0,0 +1,348 @@ +""" +获取 A 股交易日历脚本 + +使用 Flask API 交易日历服务获取 A 股交易日历 +支持多市场、多年份的交易日查询 + +用法: + python scripts/get_trading_calendar.py + python scripts/get_trading_calendar.py --year 2024 + python scripts/get_trading_calendar.py --start 2024-01-01 --end 2024-12-31 +""" + +import sys +import argparse +from pathlib import Path +from datetime import datetime, timedelta +import pandas as pd + +# 添加项目根目录到路径 +project_root = Path(__file__).parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# 加载环境变量 +from dotenv import load_dotenv +load_dotenv() + +# 导入 Flask API 数据源 +from datasource.flask_api_source import FlaskAPIDataSource + + +def get_calendar_for_year(source: FlaskAPIDataSource, year: int, market: str = 'A'): + """ + 获取指定年份的交易日历 + + Args: + source: Flask API 数据源实例 + year: 年份(如 2024) + market: 市场代码('A', 'US', 'HK') + + Returns: + pd.DatetimeIndex: 交易日序列 + """ + start_date = f"{year}-01-01" + end_date = f"{year}-12-31" + + print(f"\n获取 {year} 年 {market} 市场交易日历...") + + trading_dates = source.get_trading_calendar( + market=market, + start_date=start_date, + end_date=end_date + ) + + if trading_dates is None or len(trading_dates) == 0: + print(f"✗ {year} 年 {market} 市场无交易日数据") + return None + + return trading_dates + + +def analyze_calendar(trading_dates: pd.DatetimeIndex, year: int): + """ + 分析交易日历统计信息 + + Args: + trading_dates: 交易日序列 + year: 年份 + """ + if trading_dates is None or len(trading_dates) == 0: + return + + print(f"\n{'=' * 60}") + print(f"{year} 年 A 股交易日历分析") + print(f"{'=' * 60}") + + # 基本统计 + total_days = len(trading_dates) + print(f"\n基本统计:") + print(f" 总交易日: {total_days} 天") + print(f" 起始日期: {trading_dates.min().strftime('%Y-%m-%d')}") + print(f" 结束日期: {trading_dates.max().strftime('%Y-%m-%d')}") + + # 按月份统计 + print(f"\n按月份统计:") + monthly_counts = {} + for date in trading_dates: + month = date.month + monthly_counts[month] = monthly_counts.get(month, 0) + 1 + + for month in range(1, 13): + count = monthly_counts.get(month, 0) + month_name = datetime(2024, month, 1).strftime('%B') + print(f" {month:02d}月 ({month_name}): {count} 天") + + # 按季度统计 + print(f"\n按季度统计:") + quarterly_counts = {1: 0, 2: 0, 3: 0, 4: 0} + for date in trading_dates: + quarter = (date.month - 1) // 3 + 1 + quarterly_counts[quarter] += 1 + + for quarter, count in quarterly_counts.items(): + print(f" Q{quarter}: {count} 天") + + # 特殊日期统计 + print(f"\n特殊日期:") + first_date = trading_dates.min() + last_date = trading_dates.max() + print(f" 首个交易日: {first_date.strftime('%Y-%m-%d')} ({first_date.strftime('%A')})") + print(f" 最后交易日: {last_date.strftime('%Y-%m-%d')} ({last_date.strftime('%A')})") + + # 查找节假日后的首个交易日(通过间隔判断) + gaps = [] + for i in range(1, len(trading_dates)): + prev_date = trading_dates[i-1] + curr_date = trading_dates[i] + gap_days = (curr_date - prev_date).days + if gap_days > 3: # 超过3天视为可能节假日 + gaps.append({ + 'prev': prev_date, + 'curr': curr_date, + 'gap': gap_days + }) + + if gaps: + print(f"\n可能的节假日(间隔 > 3天):") + for gap_info in gaps[:5]: # 只显示前5个 + print(f" {gap_info['prev'].strftime('%Y-%m-%d')} → {gap_info['curr'].strftime('%Y-%m-%d')} " + f"(间隔 {gap_info['gap']} 天)") + + print(f"\n{'=' * 60}") + + +def compare_markets(source: FlaskAPIDataSource, year: int): + """ + 比较不同市场的交易日历 + + Args: + source: Flask API 数据源实例 + year: 年份 + """ + print(f"\n{'=' * 60}") + print(f"{year} 年不同市场交易日历对比") + print(f"{'=' * 60}") + + markets = { + 'A': 'A股(上交所/深交所)', + 'US': '美股(NYSE)', + 'HK': '港股(HKEX)' + } + + results = {} + for market_code, market_name in markets.items(): + print(f"\n获取 {market_name} 交易日历...") + trading_dates = get_calendar_for_year(source, year, market_code) + + if trading_dates is not None and len(trading_dates) > 0: + results[market_code] = { + 'name': market_name, + 'dates': trading_dates, + 'count': len(trading_dates) + } + + # 对比统计 + print(f"\n交易日对比:") + print(f"{'市场':<20} {'交易日数':<10} {'起始日期':<12} {'结束日期':<12}") + print("-" * 60) + + for market_code, data in results.items(): + print(f"{data['name']:<20} {data['count']:<10} " + f"{data['dates'].min().strftime('%Y-%m-%d'):<12} " + f"{data['dates'].max().strftime('%Y-%m-%d'):<12}") + + # 计算差异 + if len(results) >= 2: + print(f"\n交易日差异:") + market_codes = list(results.keys()) + for i in range(len(market_codes)): + for j in range(i+1, len(market_codes)): + m1 = market_codes[i] + m2 = market_codes[j] + diff = results[m1]['count'] - results[m2]['count'] + print(f" {results[m1]['name']} vs {results[m2]['name']}: " + f"相差 {abs(diff)} 天 ({'+' if diff > 0 else ''}{diff})") + + print(f"\n{'=' * 60}") + + +def show_recent_dates(trading_dates: pd.DatetimeIndex, n: int = 10): + """ + 显示最近的交易日 + + Args: + trading_dates: 交易日序列 + n: 显示数量 + """ + if trading_dates is None or len(trading_dates) == 0: + return + + print(f"\n最近 {n} 个交易日:") + recent_dates = trading_dates[-n:] if len(trading_dates) >= n else trading_dates + + for date in recent_dates: + weekday = date.strftime('%A') + print(f" {date.strftime('%Y-%m-%d')} ({weekday})") + + +def export_calendar(trading_dates: pd.DatetimeIndex, output_path: str, year: int): + """ + 导出交易日历到 CSV + + Args: + trading_dates: 交易日序列 + output_path: 输出路径 + year: 年份 + """ + if trading_dates is None or len(trading_dates) == 0: + return + + # 创建 DataFrame + df = pd.DataFrame({ + 'date': trading_dates, + 'year': trading_dates.year, + 'month': trading_dates.month, + 'quarter': (trading_dates.month - 1) // 3 + 1, + 'weekday': [d.strftime('%A') for d in trading_dates] + }) + + # 导出到 CSV + filename = f"{output_path}/trading_calendar_A_{year}.csv" + df.to_csv(filename, index=False) + print(f"\n✓ 交易日历已导出到: {filename}") + print(f" 文件包含 {len(df)} 条记录") + + +def main(): + """主函数""" + parser = argparse.ArgumentParser(description='获取 A 股交易日历') + + parser.add_argument( + '--year', + type=int, + default=datetime.now().year, + help='年份(默认当前年份)' + ) + + parser.add_argument( + '--start', + type=str, + help='起始日期 YYYY-MM-DD' + ) + + parser.add_argument( + '--end', + type=str, + help='结束日期 YYYY-MM-DD' + ) + + parser.add_argument( + '--market', + type=str, + default='A', + choices=['A', 'US', 'HK'], + help='市场代码(A=A股, US=美股, HK=港股)' + ) + + parser.add_argument( + '--compare', + action='store_true', + help='对比不同市场交易日历' + ) + + parser.add_argument( + '--export', + action='store_true', + help='导出交易日历到 CSV' + ) + + parser.add_argument( + '--output', + type=str, + default='data', + help='导出目录(默认 data)' + ) + + args = parser.parse_args() + + # 初始化 Flask API 数据源 + print("\n初始化 Flask API 数据源...") + source = FlaskAPIDataSource() + + # 检查服务健康状态 + health = source.get_health() + if health.get('status') != 'healthy': + print(f"✗ Flask API 服务不可用: {health}") + sys.exit(1) + + print(f"✓ Flask API 服务可用 ({source.base_url})") + + # 获取交易日历信息 + calendar_info = source.get_calendar_info() + if 'error' not in calendar_info: + print(f"\n交易日历服务信息:") + print(f" 支持市场: {', '.join(calendar_info.get('markets', []))}") + print(f" 数据源: {calendar_info.get('source', 'pandas_market_calendars')}") + + # 执行不同功能 + if args.compare: + # 对比不同市场 + compare_markets(source, args.year) + + elif args.start and args.end: + # 自定义日期范围 + print(f"\n获取 {args.market} 市场交易日历 ({args.start} ~ {args.end})...") + trading_dates = source.get_trading_calendar( + market=args.market, + start_date=args.start, + end_date=args.end + ) + + if trading_dates is not None: + print(f"✓ 获取到 {len(trading_dates)} 个交易日") + show_recent_dates(trading_dates) + + if args.export: + export_calendar(trading_dates, args.output, args.year) + + else: + # 获取指定年份交易日历 + trading_dates = get_calendar_for_year(source, args.year, args.market) + + if trading_dates is not None: + # 分析统计 + analyze_calendar(trading_dates, args.year) + + # 显示最近交易日 + show_recent_dates(trading_dates) + + # 导出 + if args.export: + export_calendar(trading_dates, args.output, args.year) + + print("\n✓ 完成!") + + +if __name__ == "__main__": + main() \ No newline at end of file