refactor(universal_fetcher): SSH隧道按资产类型统一启动
定义 SSH_REQUIRED_TYPES 常量集合,在 fetch() 入口处统一启动隧道 改进: - 新增 SSH_REQUIRED_TYPES 常量(港美股/加密货币需要 SSH) - fetch() 入口统一启动隧道,移除各分支重复调用 - fetch_us_adj/fetch_hk_adj 简化为调用 fetch() - fetch_batch 移除冗余的隧道启动 - 移除废弃的 _fetch_xxx 内部方法(减少 55 行) SSH 调用次数:10次 → 3次(仅保留必要场景)
This commit is contained in:
@@ -120,6 +120,15 @@ class UniversalDataFetcher:
|
||||
AssetType.CRYPTO: ['raw'], # 加密货币无复权
|
||||
}
|
||||
|
||||
# 需要 SSH 隧道的资产类型(港美股/加密货币)
|
||||
SSH_REQUIRED_TYPES = {
|
||||
AssetType.US_INDEX, # 美股指数
|
||||
AssetType.US_STOCK, # 美股股票
|
||||
AssetType.HK_INDEX, # 港股指数
|
||||
AssetType.HK_STOCK, # 港股股票
|
||||
AssetType.CRYPTO, # 加密货币
|
||||
}
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
code: str,
|
||||
@@ -169,9 +178,13 @@ class UniversalDataFetcher:
|
||||
f"adj='{adj}' 不适用于 {asset_type.value},支持的类型: {valid_adj}"
|
||||
)
|
||||
|
||||
# 统一启动 SSH 隧道(港美股/加密货币需要)
|
||||
if asset_type in self.SSH_REQUIRED_TYPES:
|
||||
self._start_tunnel()
|
||||
|
||||
for attempt in range(retry):
|
||||
try:
|
||||
# 路由到具体方法(传递 adj 参数)
|
||||
# 路由到具体方法(无需重复调用 _start_tunnel)
|
||||
if asset_type == AssetType.CHINA_INDEX:
|
||||
return self._tushare.fetch(code, start_date, end_date, adj)
|
||||
elif asset_type == AssetType.CHINA_ETF:
|
||||
@@ -179,16 +192,12 @@ class UniversalDataFetcher:
|
||||
elif asset_type == AssetType.CHINA_STOCK:
|
||||
return self._tushare.fetch(code, start_date, end_date, adj)
|
||||
elif asset_type == AssetType.US_INDEX:
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch(code, start_date, end_date, adj)
|
||||
elif asset_type == AssetType.US_STOCK:
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch(code, start_date, end_date, adj)
|
||||
elif asset_type == AssetType.HK_INDEX:
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch(code, start_date, end_date, adj)
|
||||
elif asset_type == AssetType.HK_STOCK:
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch(code, start_date, end_date, adj)
|
||||
elif asset_type == AssetType.FUTURES:
|
||||
return self._fetch_futures(code, start_date, end_date, adj)
|
||||
@@ -342,61 +351,9 @@ class UniversalDataFetcher:
|
||||
|
||||
return premium
|
||||
|
||||
def _fetch_us_index(
|
||||
self,
|
||||
code: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
获取美股指数
|
||||
|
||||
特点:YFinance,需要SSH隧道,指数代码转换
|
||||
"""
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch(code, start_date, end_date)
|
||||
|
||||
def _fetch_us_stock(
|
||||
self,
|
||||
code: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
获取美股股票
|
||||
|
||||
特点:YFinance,需要SSH隧道,返回价格+股票信息
|
||||
"""
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch(code, start_date, end_date)
|
||||
|
||||
def _fetch_hk_index(
|
||||
self,
|
||||
code: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
获取港股指数
|
||||
|
||||
特点:YFinance,需要SSH隧道
|
||||
"""
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch(code, start_date, end_date)
|
||||
|
||||
def _fetch_hk_stock(
|
||||
self,
|
||||
code: str,
|
||||
start_date: str,
|
||||
end_date: str
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
获取港股股票
|
||||
|
||||
特点:YFinance,需要SSH隧道,返回价格+股票信息
|
||||
"""
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch(code, start_date, end_date)
|
||||
# ============================================================
|
||||
# 内部方法:特殊资产类型(保留)
|
||||
# ============================================================
|
||||
|
||||
def _fetch_futures(
|
||||
self,
|
||||
@@ -486,8 +443,7 @@ class UniversalDataFetcher:
|
||||
for asset_type, code_list in grouped.items():
|
||||
print(f" {asset_type.value}: {len(code_list)} 只")
|
||||
|
||||
# 启动隧道(港美股需要)
|
||||
self._start_tunnel()
|
||||
# 无需单独启动隧道,每个 fetch() 会自动处理
|
||||
|
||||
for code in codes:
|
||||
results[code] = self.fetch(code, start_date, end_date, retry)
|
||||
@@ -565,8 +521,8 @@ class UniversalDataFetcher:
|
||||
# 苹果复权价格(包含分红和拆分调整)
|
||||
df = fetcher.fetch_us_adj("AAPL", "2020-01-01", "2024-12-31", adj='qfq')
|
||||
"""
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch_adj(code, start_date, end_date, adj)
|
||||
# 直接调用 fetch(),自动处理 SSH 隧道
|
||||
return self.fetch(code, start_date, end_date, adj)
|
||||
|
||||
def fetch_hk_adj(
|
||||
self,
|
||||
@@ -589,8 +545,8 @@ class UniversalDataFetcher:
|
||||
Returns:
|
||||
DataFrame with columns: date, open, high, low, close, volume (复权后)
|
||||
"""
|
||||
self._start_tunnel()
|
||||
return self._yfinance.fetch_adj(code, start_date, end_date, adj)
|
||||
# 直接调用 fetch(),自动处理 SSH 隧道
|
||||
return self.fetch(code, start_date, end_date, adj)
|
||||
|
||||
def fetch_stock_adj(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user