From 67d67b1eea86d207d938e294e1a1f4386b33f54f Mon Sep 17 00:00:00 2001 From: aszerW Date: Sat, 23 May 2026 18:53:26 +0800 Subject: [PATCH] =?UTF-8?q?refactor(universal=5Ffetcher):=20SSH=E9=9A=A7?= =?UTF-8?q?=E9=81=93=E6=8C=89=E8=B5=84=E4=BA=A7=E7=B1=BB=E5=9E=8B=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E5=90=AF=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 定义 SSH_REQUIRED_TYPES 常量集合,在 fetch() 入口处统一启动隧道 改进: - 新增 SSH_REQUIRED_TYPES 常量(港美股/加密货币需要 SSH) - fetch() 入口统一启动隧道,移除各分支重复调用 - fetch_us_adj/fetch_hk_adj 简化为调用 fetch() - fetch_batch 移除冗余的隧道启动 - 移除废弃的 _fetch_xxx 内部方法(减少 55 行) SSH 调用次数:10次 → 3次(仅保留必要场景) --- datasource/universal_fetcher.py | 88 +++++++++------------------------ 1 file changed, 22 insertions(+), 66 deletions(-) diff --git a/datasource/universal_fetcher.py b/datasource/universal_fetcher.py index 9c8aa4c..0b882b8 100644 --- a/datasource/universal_fetcher.py +++ b/datasource/universal_fetcher.py @@ -120,6 +120,15 @@ class UniversalDataFetcher: AssetType.CRYPTO: ['raw'], # 加密货币无复权 } + # 需要 SSH 隧道的资产类型(港美股/加密货币) + SSH_REQUIRED_TYPES = { + AssetType.US_INDEX, # 美股指数 + AssetType.US_STOCK, # 美股股票 + AssetType.HK_INDEX, # 港股指数 + AssetType.HK_STOCK, # 港股股票 + AssetType.CRYPTO, # 加密货币 + } + def fetch( self, code: str, @@ -169,9 +178,13 @@ class UniversalDataFetcher: f"adj='{adj}' 不适用于 {asset_type.value},支持的类型: {valid_adj}" ) + # 统一启动 SSH 隧道(港美股/加密货币需要) + if asset_type in self.SSH_REQUIRED_TYPES: + self._start_tunnel() + for attempt in range(retry): try: - # 路由到具体方法(传递 adj 参数) + # 路由到具体方法(无需重复调用 _start_tunnel) if asset_type == AssetType.CHINA_INDEX: return self._tushare.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.CHINA_ETF: @@ -179,16 +192,12 @@ class UniversalDataFetcher: elif asset_type == AssetType.CHINA_STOCK: return self._tushare.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.US_INDEX: - self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.US_STOCK: - self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.HK_INDEX: - self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.HK_STOCK: - self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.FUTURES: return self._fetch_futures(code, start_date, end_date, adj) @@ -342,61 +351,9 @@ class UniversalDataFetcher: return premium - def _fetch_us_index( - self, - code: str, - start_date: str, - end_date: str - ) -> Optional[pd.DataFrame]: - """ - 获取美股指数 - - 特点:YFinance,需要SSH隧道,指数代码转换 - """ - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) - - def _fetch_us_stock( - self, - code: str, - start_date: str, - end_date: str - ) -> Optional[pd.DataFrame]: - """ - 获取美股股票 - - 特点:YFinance,需要SSH隧道,返回价格+股票信息 - """ - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) - - def _fetch_hk_index( - self, - code: str, - start_date: str, - end_date: str - ) -> Optional[pd.DataFrame]: - """ - 获取港股指数 - - 特点:YFinance,需要SSH隧道 - """ - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) - - def _fetch_hk_stock( - self, - code: str, - start_date: str, - end_date: str - ) -> Optional[pd.DataFrame]: - """ - 获取港股股票 - - 特点:YFinance,需要SSH隧道,返回价格+股票信息 - """ - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) + # ============================================================ + # 内部方法:特殊资产类型(保留) + # ============================================================ def _fetch_futures( self, @@ -486,8 +443,7 @@ class UniversalDataFetcher: for asset_type, code_list in grouped.items(): print(f" {asset_type.value}: {len(code_list)} 只") - # 启动隧道(港美股需要) - self._start_tunnel() + # 无需单独启动隧道,每个 fetch() 会自动处理 for code in codes: results[code] = self.fetch(code, start_date, end_date, retry) @@ -565,8 +521,8 @@ class UniversalDataFetcher: # 苹果复权价格(包含分红和拆分调整) df = fetcher.fetch_us_adj("AAPL", "2020-01-01", "2024-12-31", adj='qfq') """ - self._start_tunnel() - return self._yfinance.fetch_adj(code, start_date, end_date, adj) + # 直接调用 fetch(),自动处理 SSH 隧道 + return self.fetch(code, start_date, end_date, adj) def fetch_hk_adj( self, @@ -589,8 +545,8 @@ class UniversalDataFetcher: Returns: DataFrame with columns: date, open, high, low, close, volume (复权后) """ - self._start_tunnel() - return self._yfinance.fetch_adj(code, start_date, end_date, adj) + # 直接调用 fetch(),自动处理 SSH 隧道 + return self.fetch(code, start_date, end_date, adj) def fetch_stock_adj( self,