diff --git a/datasource/universal_fetcher.py b/datasource/universal_fetcher.py index 9c8aa4c..0b882b8 100644 --- a/datasource/universal_fetcher.py +++ b/datasource/universal_fetcher.py @@ -120,6 +120,15 @@ class UniversalDataFetcher: AssetType.CRYPTO: ['raw'], # 加密货币无复权 } + # 需要 SSH 隧道的资产类型(港美股/加密货币) + SSH_REQUIRED_TYPES = { + AssetType.US_INDEX, # 美股指数 + AssetType.US_STOCK, # 美股股票 + AssetType.HK_INDEX, # 港股指数 + AssetType.HK_STOCK, # 港股股票 + AssetType.CRYPTO, # 加密货币 + } + def fetch( self, code: str, @@ -169,9 +178,13 @@ class UniversalDataFetcher: f"adj='{adj}' 不适用于 {asset_type.value},支持的类型: {valid_adj}" ) + # 统一启动 SSH 隧道(港美股/加密货币需要) + if asset_type in self.SSH_REQUIRED_TYPES: + self._start_tunnel() + for attempt in range(retry): try: - # 路由到具体方法(传递 adj 参数) + # 路由到具体方法(无需重复调用 _start_tunnel) if asset_type == AssetType.CHINA_INDEX: return self._tushare.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.CHINA_ETF: @@ -179,16 +192,12 @@ class UniversalDataFetcher: elif asset_type == AssetType.CHINA_STOCK: return self._tushare.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.US_INDEX: - self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.US_STOCK: - self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.HK_INDEX: - self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.HK_STOCK: - self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date, adj) elif asset_type == AssetType.FUTURES: return self._fetch_futures(code, start_date, end_date, adj) @@ -342,61 +351,9 @@ class UniversalDataFetcher: return premium - def _fetch_us_index( - self, - code: str, - start_date: str, - end_date: str - ) -> Optional[pd.DataFrame]: - """ - 获取美股指数 - - 特点:YFinance,需要SSH隧道,指数代码转换 - """ - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) - - def _fetch_us_stock( - self, - code: str, - start_date: str, - end_date: str - ) -> Optional[pd.DataFrame]: - """ - 获取美股股票 - - 特点:YFinance,需要SSH隧道,返回价格+股票信息 - """ - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) - - def _fetch_hk_index( - self, - code: str, - start_date: str, - end_date: str - ) -> Optional[pd.DataFrame]: - """ - 获取港股指数 - - 特点:YFinance,需要SSH隧道 - """ - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) - - def _fetch_hk_stock( - self, - code: str, - start_date: str, - end_date: str - ) -> Optional[pd.DataFrame]: - """ - 获取港股股票 - - 特点:YFinance,需要SSH隧道,返回价格+股票信息 - """ - self._start_tunnel() - return self._yfinance.fetch(code, start_date, end_date) + # ============================================================ + # 内部方法:特殊资产类型(保留) + # ============================================================ def _fetch_futures( self, @@ -486,8 +443,7 @@ class UniversalDataFetcher: for asset_type, code_list in grouped.items(): print(f" {asset_type.value}: {len(code_list)} 只") - # 启动隧道(港美股需要) - self._start_tunnel() + # 无需单独启动隧道,每个 fetch() 会自动处理 for code in codes: results[code] = self.fetch(code, start_date, end_date, retry) @@ -565,8 +521,8 @@ class UniversalDataFetcher: # 苹果复权价格(包含分红和拆分调整) df = fetcher.fetch_us_adj("AAPL", "2020-01-01", "2024-12-31", adj='qfq') """ - self._start_tunnel() - return self._yfinance.fetch_adj(code, start_date, end_date, adj) + # 直接调用 fetch(),自动处理 SSH 隧道 + return self.fetch(code, start_date, end_date, adj) def fetch_hk_adj( self, @@ -589,8 +545,8 @@ class UniversalDataFetcher: Returns: DataFrame with columns: date, open, high, low, close, volume (复权后) """ - self._start_tunnel() - return self._yfinance.fetch_adj(code, start_date, end_date, adj) + # 直接调用 fetch(),自动处理 SSH 隧道 + return self.fetch(code, start_date, end_date, adj) def fetch_stock_adj( self,