"""
YFinance data source module.

Supports routing HTTP traffic through an SSH tunnel (used to get around
network restrictions).
"""

import os
import time
import socket
import subprocess
import tempfile
from pathlib import Path
from typing import Optional, Tuple
from datetime import datetime

import pandas as pd
import yfinance as yf
import urllib3

# Silence urllib3's InsecureRequestWarning (emitted for unverified HTTPS requests).
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Proxy environment variables this module points at the tunnel while it is up.
_PROXY_ENV_KEYS = ("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY")


class SSHTunnelManager:
    """Manage an ``ssh -N -D`` process that exposes a local SOCKS proxy.

    While the tunnel is up, HTTP_PROXY / HTTPS_PROXY / ALL_PROXY in
    ``os.environ`` point at the local SOCKS port so HTTP clients in this
    process that honour those variables use it. ``stop()`` restores whatever
    values those variables held before ``start()``.

    Config keys read (all optional):
        enabled (default False): when falsy, ``start()`` is a no-op.
        host, port (default 22), username: SSH target.
        key_path: private key file. A relative path is resolved against
            ``Path(__file__).parent.parent.parent`` (assumed project root).
        local_port (default 1080): local port ssh listens on for SOCKS.
        proxy_scheme (default "socks5"): URL scheme written to the proxy env
            vars. "socks5h" asks clients that support it to resolve hostnames
            through the proxy as well.
        startup_timeout (default 10): seconds to wait for the SOCKS port to
            start accepting connections before giving up.
    """

    def __init__(self, config: dict):
        self.enabled = config.get("enabled", False)
        self.host = config.get("host", "")
        self.port = config.get("port", 22)
        self.username = config.get("username", "")
        self.local_port = config.get("local_port", 1080)
        self.proxy_scheme = config.get("proxy_scheme", "socks5")
        self.startup_timeout = float(config.get("startup_timeout", 10))
        self._process: Optional[subprocess.Popen] = None
        # Temp file receiving ssh's stderr; only read when startup fails.
        self._stderr_file = None
        # Proxy env values present before start(); None while we have not
        # modified the environment.
        self._saved_env: Optional[dict] = None

        # A relative key_path is taken relative to the project root.
        key_path = config.get("key_path", "")
        if key_path and not os.path.isabs(key_path):
            project_root = Path(__file__).parent.parent.parent
            key_path = str(project_root / key_path)
        self.key_path = key_path

    def is_running(self) -> bool:
        """Return True if the ssh process was started and has not exited."""
        return self._process is not None and self._process.poll() is None

    def start(self) -> bool:
        """Start the SSH tunnel and point the proxy env vars at it.

        Returns:
            True if tunnelling is disabled, the tunnel is already running, or
            it came up successfully; False if the configuration is incomplete,
            the local port is busy, or ssh failed to establish the forward.
        """
        if not self.enabled:
            return True
        if self.is_running():
            # Avoid spawning a second ssh and leaking the first one.
            return True

        if not all([self.host, self.username, self.key_path]):
            print("SSH 配置不完整,跳过隧道建立")
            return False

        if self._port_accepts_connections():
            # Something else already listens there; ssh could not bind it,
            # and our readiness probe would be fooled by the other listener.
            print(f"✗ 本地端口 {self.local_port} 已被占用,无法建立 SSH 隧道")
            return False

        print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")

        cmd = [
            "ssh",
            "-N",
            "-D", f"127.0.0.1:{self.local_port}",
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            # Exit instead of staying up without a SOCKS listener when the
            # dynamic forward cannot be set up.
            "-o", "ExitOnForwardFailure=yes",
            "-i", self.key_path,
            "-p", str(self.port),
            f"{self.username}@{self.host}",
        ]

        try:
            # ssh keeps writing to stderr for the tunnel's whole lifetime.
            # An unread PIPE would eventually fill and block ssh, so stderr
            # goes to a temp file that is only read on failure.
            self._stderr_file = tempfile.TemporaryFile()
            self._process = subprocess.Popen(
                cmd,
                stdout=subprocess.DEVNULL,
                stderr=self._stderr_file,
            )
            ready = self._wait_until_ready()
        except Exception as e:
            print(f"✗ SSH 隧道异常: {e}")
            self._terminate_process()
            return False
        except BaseException:
            # e.g. KeyboardInterrupt while waiting: don't leave ssh running.
            self._terminate_process()
            raise

        if not ready:
            print("✗ SSH 隧道建立失败")
            stderr = self._read_stderr()
            if stderr:
                print(f"错误: {stderr}")
            self._terminate_process()
            return False

        # Remember the previous values so stop() can restore them.
        proxy_url = f"{self.proxy_scheme}://127.0.0.1:{self.local_port}"
        self._saved_env = {key: os.environ.get(key) for key in _PROXY_ENV_KEYS}
        for key in _PROXY_ENV_KEYS:
            os.environ[key] = proxy_url

        print(f"✓ SSH 隧道已建立: {proxy_url}")
        return True

    def stop(self):
        """Stop the tunnel and restore the proxy env vars changed by start().

        Safe to call repeatedly, or after a failed/never-called start(): in
        those cases nothing is terminated and the environment is untouched.
        """
        if self._process is None:
            return
        self._terminate_process()
        self._restore_env()
        print("SSH 隧道已关闭")

    def _port_accepts_connections(self) -> bool:
        """Return True if something accepts TCP connections on local_port."""
        try:
            with socket.create_connection(("127.0.0.1", self.local_port), timeout=0.5):
                return True
        except OSError:
            return False

    def _wait_until_ready(self) -> bool:
        """Poll until ssh's SOCKS port accepts connections.

        Returns False if ssh exits first or startup_timeout elapses.
        """
        deadline = time.monotonic() + self.startup_timeout
        while time.monotonic() < deadline:
            if self._process.poll() is not None:
                return False
            if self._port_accepts_connections():
                return self._process.poll() is None
            time.sleep(0.2)
        return False

    def _read_stderr(self) -> str:
        """Return everything ssh wrote to stderr so far, decoded leniently."""
        if self._stderr_file is None:
            return ""
        self._stderr_file.seek(0)
        return self._stderr_file.read().decode(errors="replace").strip()

    def _terminate_process(self, timeout: float = 5.0) -> None:
        """Terminate ssh (killing it if it ignores SIGTERM) and free resources."""
        if self._process is not None:
            if self._process.poll() is None:
                self._process.terminate()
                try:
                    self._process.wait(timeout=timeout)
                except subprocess.TimeoutExpired:
                    self._process.kill()
                    self._process.wait()
            self._process = None
        if self._stderr_file is not None:
            self._stderr_file.close()
            self._stderr_file = None

    def _restore_env(self) -> None:
        """Put the proxy env vars back to their values from before start()."""
        if self._saved_env is None:
            return
        for key, value in self._saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
        self._saved_env = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()


class YFinanceDataSource:
    """Download price history from Yahoo Finance through yfinance.

    When used as a context manager and ``ssh_config["enabled"]`` is truthy,
    an SSH tunnel is kept up for the duration of the ``with`` block.
    """

    # yfinance column name -> normalized (lower-case) column name.
    _COLUMN_MAP = {
        "Open": "open",
        "High": "high",
        "Low": "low",
        "Close": "close",
        "Volume": "volume",
    }

    def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
        self.ssh_config = ssh_config or {}
        # NOTE(review): stored but not read anywhere in this class.
        self.use_cache = use_cache
        self._tunnel: Optional[SSHTunnelManager] = None

    def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Download the price history of one ticker.

        Args:
            code: Yahoo Finance ticker symbol.
            start_date: passed unchanged as ``start`` to ``Ticker.history``.
            end_date: passed unchanged as ``end`` to ``Ticker.history``.

        Returns:
            The history DataFrame with Open/High/Low/Close/Volume renamed to
            lower case (other columns kept as returned) plus a ``code``
            column, or None if the download raised or returned no rows.
        """
        try:
            data = yf.Ticker(code).history(start=start_date, end=end_date)
        except Exception as e:
            # Best effort: one failing ticker must not abort a batch.
            print(f"下载 {code} 失败: {e}")
            return None

        if data is None or data.empty:
            return None

        data = data.rename(columns=self._COLUMN_MAP)
        data["code"] = code
        return data

    def fetch_all(
        self,
        code_list: list,
        benchmark_code: str,
        start_date: str,
        end_date: str,
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
        """Download several tickers plus a benchmark.

        Returns:
            ``(etf_data, benchmark_data, valid_codes)``: the per-ticker frames
            concatenated (original index kept), the benchmark frame or None,
            and the codes that returned data. If no ticker returned data the
            benchmark is not fetched and ``(None, None, [])`` is returned.
        """
        all_data = []
        valid_codes = []

        print(f"开始下载 {len(code_list)} 只 ETF 数据...")

        for code in code_list:
            data = self.fetch_single(code, start_date, end_date)
            if data is not None:
                all_data.append(data)
                valid_codes.append(code)
                print(f"  ✓ {code}: {len(data)} 条")
            else:
                print(f"  ✗ {code}: 无数据")

        if not all_data:
            return None, None, []

        etf_data = pd.concat(all_data, ignore_index=False)

        benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
        if benchmark_data is not None:
            print(f"  ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
        else:
            print(f"  ✗ 基准 {benchmark_code}: 无数据")

        return etf_data, benchmark_data, valid_codes

    def __enter__(self):
        """Start the SSH tunnel if ``ssh_config["enabled"]`` is truthy."""
        if self.ssh_config.get("enabled"):
            if self._tunnel is not None:
                # Re-entered without __exit__: don't leak the previous tunnel.
                self._tunnel.stop()
            self._tunnel = SSHTunnelManager(self.ssh_config)
            # A failed start is reported by start(); downloads then go direct.
            self._tunnel.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the SSH tunnel started by ``__enter__``, if any."""
        if self._tunnel:
            self._tunnel.stop()
            self._tunnel = None