""" 美股轮动策略 纯美股轮动策略,使用动量因子选股 特点: - 全部使用 yfinance 数据源 - 美股交易日历 - 不分组,直接选 Top 5 - 基准为纳指 NDX """ import sys import yaml import time import pandas as pd from pathlib import Path from datetime import datetime from typing import Dict, List, Optional, Tuple # 添加项目根目录 project_root = Path(__file__).parent.parent.parent sys.path.insert(0, str(project_root)) from datasource.yfinance_source import YFinanceSource from datasource.ssh_tunnel import SSHTunnelManager from strategies.shared.factors.momentum import MomentumFactor from strategies.shared.signals.selectors import TopNSelector from framework.execution import BacktestExecutor class USRotationStrategy: """美股轮动策略""" def __init__(self, config_path: str = None, config: dict = None): """ 初始化策略 Args: config_path: 配置文件路径 config: 配置字典(可选) """ if config_path: with open(config_path, 'r', encoding='utf-8') as f: self.config = yaml.safe_load(f) elif config: self.config = config else: raise ValueError("需要提供 config_path 或 config") # 结束日期默认今天 if not self.config.get('end_date'): self.config['end_date'] = datetime.now().strftime('%Y-%m-%d') # 应用配置 self._apply_config() # 初始化因子 self._factor = MomentumFactor( n_days=self.n_days, weighted=True, crash_filter=True ) # 初始化选择器(不分组,直接选 Top N) self._selector = TopNSelector( select_num=self.select_num, rebalance_days=self.rebalance_days, rebalance_threshold=self.rebalance_threshold ) # 数据源(延迟初始化) self._yfinance: Optional[YFinanceSource] = None self._tunnel: Optional[SSHTunnelManager] = None @classmethod def from_yaml(cls, config_path: str) -> 'USRotationStrategy': """从 YAML 文件创建策略实例""" return cls(config_path=config_path) def _apply_config(self): """应用配置参数""" self.select_num = self.config.get('select_num', 5) self.n_days = self.config.get('n_days', 250) self.rebalance_days = self.config.get('rebalance_days', 1) self.rebalance_threshold = self.config.get('rebalance_threshold', 0.0) self.trade_cost = self.config.get('trade_cost', 0.001) self.start_date = self.config.get('start_date', '2016-01-01') self.end_date = self.config['end_date'] self.use_cache = self.config.get('use_cache', True) def _start_tunnel(self) -> bool: """启动 SSH 隧道""" ssh_config = self.config.get('ssh_tunnel', {}) if not ssh_config.get('enabled', False): return True self._tunnel = SSHTunnelManager(ssh_config) return self._tunnel.start() def _stop_tunnel(self): """停止 SSH 隧道""" if self._tunnel: self._tunnel.stop() self._tunnel = None def _init_yfinance(self): """初始化 YFinance 数据源""" if self._yfinance is None: self._yfinance = YFinanceSource(use_ssh_tunnel=True) def fetch_data(self) -> Dict: """获取数据(全部使用 yfinance)""" print("\n" + "=" * 60) print("获取美股数据") print("=" * 60) code_list_config = self.config.get('code_list', {}) benchmark_config = self.config.get('benchmark', {}) benchmark_code = benchmark_config.get('code', 'NDX') if not code_list_config: raise ValueError("配置中未找到 code_list") codes = list(code_list_config.keys()) print(f"标的池: {len(codes)} 只股票") print(f"基准: {benchmark_code}") print(f"时间范围: {self.start_date} ~ {self.end_date}") # 启动 SSH 隓道 print("\n启动 SSH 隧道...") if not self._start_tunnel(): raise RuntimeError("SSH 隧道启动失败") self._init_yfinance() # 获取数据 all_data: Dict[str, pd.DataFrame] = {} valid_codes: List[str] = [] print("\n获取股票数据...") for i, code in enumerate(codes): print(f" [{i+1}/{len(codes)}] {code}...", end=" ") try: df = self._yfinance.fetch(code, self.start_date, self.end_date) time.sleep(0.5) # 避免限流 if df is not None and len(df) >= self.n_days: all_data[code] = df valid_codes.append(code) print(f"✓ {len(df)} 条") else: print(f"✗ 数据不足") except Exception as e: print(f"✗ 失败: {e}") # 获取基准数据 print(f"\n获取基准 {benchmark_code}...", end=" ") try: benchmark_df = self._yfinance.fetch(benchmark_code, self.start_date, self.end_date) if benchmark_df is not None and len(benchmark_df) > 0: print(f"✓ {len(benchmark_df)} 条") else: print(f"✗ 基准数据获取失败") benchmark_df = None except Exception as e: print(f"✗ 失败: {e}") benchmark_df = None # 停止隧道 self._stop_tunnel() print(f"\n数据获取完成: {len(valid_codes)}/{len(codes)} 只有效") return { 'stock_data': all_data, 'valid_codes': valid_codes, 'benchmark': benchmark_df, 'benchmark_code': benchmark_code } def compute_factors(self, data: Dict) -> pd.DataFrame: """计算动量因子""" print("\n" + "=" * 60) print("计算动量因子") print("=" * 60) stock_data = data['stock_data'] valid_codes = data['valid_codes'] factor_values: Dict[str, pd.Series] = {} for code in valid_codes: df = stock_data[code] if 'close' not in df.columns: continue # 数据长度检查 if len(df) < self.n_days: print(f" ⚠ {code}: 数据不足 {self.n_days} 天,跳过") continue # MomentumFactor.compute 需要DataFrame factor_series = self._factor.compute(df) if factor_series is not None and len(factor_series) > 0: factor_values[code] = factor_series # 合成 DataFrame factor_df = pd.DataFrame(factor_values) print(f"\n因子计算完成: {len(factor_df.columns)} 只标的") print(f" 窗口: {self.n_days} 天") if len(factor_df) > 0: print(f" 日期范围: {factor_df.index.min()} ~ {factor_df.index.max()}") return factor_df def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame: """生成轮动信号(不分组,直接选 Top 5)""" print("\n" + "=" * 60) print("生成轮动信号") print("=" * 60) # 不分组,直接对因子排序选 Top 5 # TopNSelector.generate 会自动处理调仓周期和T+1 signals_df = self._selector.generate(factor_df) print(f"\n信号生成完成:") print(f" 选股数量: {self.select_num}") if 'signal' in signals_df.columns: valid_signals = signals_df[signals_df['signal'] != ''] print(f" 有效信号天数: {len(valid_signals)}") return signals_df def run_backtest(self, data: Dict = None, save_path: str = None) -> Dict: """运行回测""" print("\n" + "=" * 60) print("美股动量轮动策略 回测") print("=" * 60) # 1. 获取数据 if data is None: data = self.fetch_data() valid_codes = data['valid_codes'] # 2. 计算因子 factor_df = self.compute_factors(data) # 3. 生成信号 signals = self.generate_signals(factor_df) # 4. 构建收益数据(BacktestExecutor期望列名格式:日收益率_{code}) print("\n构建收益数据...") stock_data = data['stock_data'] # 计算日收益率,列名格式为 '日收益率_{code}' returns_data: Dict[str, pd.Series] = {} for code in valid_codes: if code in stock_data: df = stock_data[code] if 'close' in df.columns: returns_data[f'日收益率_{code}'] = df['close'].pct_change() returns_df = pd.DataFrame(returns_data) # 对齐日期 common_dates = signals.index.intersection(returns_df.index) signals = signals.loc[common_dates] returns_df = returns_df.loc[common_dates] # 5. 执行回测 print("\n执行回测...") executor = BacktestExecutor( initial_capital=100, trade_cost=self.trade_cost, select_num=self.select_num ) portfolio = executor.execute(signals, returns_df) # 6. 计算基准收益 benchmark_df = data['benchmark'] benchmark_code = data['benchmark_code'] if benchmark_df is not None and 'close' in benchmark_df.columns: benchmark_returns = benchmark_df['close'].pct_change() # 对齐日期 benchmark_returns = benchmark_returns.loc[common_dates] # 7. 输出结果 if hasattr(portfolio, 'backtest_result') and portfolio.backtest_result is not None: result = portfolio.backtest_result # 策略净值(DataFrame列) if '策略净值' in result.columns: strategy_nav = result['策略净值'].values final_nav = strategy_nav[-1] if len(strategy_nav) > 0 else 100 total_return = (final_nav - 1) * 100 # 净值归一化起点为1 else: final_nav = 100 total_return = 0 # 基准收益 if benchmark_df is not None and 'close' in benchmark_df.columns: benchmark_start = benchmark_df['close'].iloc[0] benchmark_end = benchmark_df['close'].iloc[-1] benchmark_return = (benchmark_end / benchmark_start - 1) * 100 else: benchmark_return = 0 print("\n" + "=" * 60) print("回测结果") print("=" * 60) print(f"策略最终净值: {final_nav:.2f}") print(f"策略总收益: {total_return:.2f}%") print(f"基准 ({benchmark_code}) 收益: {benchmark_return:.2f}%") print(f"超额收益: {total_return - benchmark_return:.2f}%") print(f"交易成本: {self.trade_cost * 100:.1f}%") # 保存结果 if save_path: Path(save_path).parent.mkdir(parents=True, exist_ok=True) # 保存净值曲线 if '策略净值' in result.columns: nav_df = pd.DataFrame({ 'date': result.index, 'strategy_nav': result['策略净值'].values }) if benchmark_df is not None and 'close' in benchmark_df.columns: # 重建基准净值 benchmark_nav = (benchmark_df['close'].pct_change() + 1).cumprod() nav_df['benchmark_nav'] = benchmark_nav.reindex(result.index, method='ffill').values nav_df.to_csv(f"{save_path}_nav.csv", index=False) # 保存信号 signals.to_csv(f"{save_path}_signals.csv") print(f"\n报告保存: {save_path}_*.csv") return { 'final_nav': final_nav, 'total_return': total_return, 'benchmark_return': benchmark_return, 'excess_return': total_return - benchmark_return, 'signals': signals, 'result': result } return {'signals': signals}