refactor(universal_fetcher): SSH隧道按资产类型统一启动
定义 SSH_REQUIRED_TYPES 常量集合,在 fetch() 入口处统一启动隧道 改进: - 新增 SSH_REQUIRED_TYPES 常量(港美股/加密货币需要 SSH) - fetch() 入口统一启动隧道,移除各分支重复调用 - fetch_us_adj/fetch_hk_adj 简化为调用 fetch() - fetch_batch 移除冗余的隧道启动 - 移除废弃的 _fetch_xxx 内部方法(减少 55 行) SSH 调用次数:10次 → 3次(仅保留必要场景)
This commit is contained in:
@@ -120,6 +120,15 @@ class UniversalDataFetcher:
|
|||||||
AssetType.CRYPTO: ['raw'], # 加密货币无复权
|
AssetType.CRYPTO: ['raw'], # 加密货币无复权
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 需要 SSH 隧道的资产类型(港美股/加密货币)
|
||||||
|
SSH_REQUIRED_TYPES = {
|
||||||
|
AssetType.US_INDEX, # 美股指数
|
||||||
|
AssetType.US_STOCK, # 美股股票
|
||||||
|
AssetType.HK_INDEX, # 港股指数
|
||||||
|
AssetType.HK_STOCK, # 港股股票
|
||||||
|
AssetType.CRYPTO, # 加密货币
|
||||||
|
}
|
||||||
|
|
||||||
def fetch(
|
def fetch(
|
||||||
self,
|
self,
|
||||||
code: str,
|
code: str,
|
||||||
@@ -169,9 +178,13 @@ class UniversalDataFetcher:
|
|||||||
f"adj='{adj}' 不适用于 {asset_type.value},支持的类型: {valid_adj}"
|
f"adj='{adj}' 不适用于 {asset_type.value},支持的类型: {valid_adj}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 统一启动 SSH 隧道(港美股/加密货币需要)
|
||||||
|
if asset_type in self.SSH_REQUIRED_TYPES:
|
||||||
|
self._start_tunnel()
|
||||||
|
|
||||||
for attempt in range(retry):
|
for attempt in range(retry):
|
||||||
try:
|
try:
|
||||||
# 路由到具体方法(传递 adj 参数)
|
# 路由到具体方法(无需重复调用 _start_tunnel)
|
||||||
if asset_type == AssetType.CHINA_INDEX:
|
if asset_type == AssetType.CHINA_INDEX:
|
||||||
return self._tushare.fetch(code, start_date, end_date, adj)
|
return self._tushare.fetch(code, start_date, end_date, adj)
|
||||||
elif asset_type == AssetType.CHINA_ETF:
|
elif asset_type == AssetType.CHINA_ETF:
|
||||||
@@ -179,16 +192,12 @@ class UniversalDataFetcher:
|
|||||||
elif asset_type == AssetType.CHINA_STOCK:
|
elif asset_type == AssetType.CHINA_STOCK:
|
||||||
return self._tushare.fetch(code, start_date, end_date, adj)
|
return self._tushare.fetch(code, start_date, end_date, adj)
|
||||||
elif asset_type == AssetType.US_INDEX:
|
elif asset_type == AssetType.US_INDEX:
|
||||||
self._start_tunnel()
|
|
||||||
return self._yfinance.fetch(code, start_date, end_date, adj)
|
return self._yfinance.fetch(code, start_date, end_date, adj)
|
||||||
elif asset_type == AssetType.US_STOCK:
|
elif asset_type == AssetType.US_STOCK:
|
||||||
self._start_tunnel()
|
|
||||||
return self._yfinance.fetch(code, start_date, end_date, adj)
|
return self._yfinance.fetch(code, start_date, end_date, adj)
|
||||||
elif asset_type == AssetType.HK_INDEX:
|
elif asset_type == AssetType.HK_INDEX:
|
||||||
self._start_tunnel()
|
|
||||||
return self._yfinance.fetch(code, start_date, end_date, adj)
|
return self._yfinance.fetch(code, start_date, end_date, adj)
|
||||||
elif asset_type == AssetType.HK_STOCK:
|
elif asset_type == AssetType.HK_STOCK:
|
||||||
self._start_tunnel()
|
|
||||||
return self._yfinance.fetch(code, start_date, end_date, adj)
|
return self._yfinance.fetch(code, start_date, end_date, adj)
|
||||||
elif asset_type == AssetType.FUTURES:
|
elif asset_type == AssetType.FUTURES:
|
||||||
return self._fetch_futures(code, start_date, end_date, adj)
|
return self._fetch_futures(code, start_date, end_date, adj)
|
||||||
@@ -342,61 +351,9 @@ class UniversalDataFetcher:
|
|||||||
|
|
||||||
return premium
|
return premium
|
||||||
|
|
||||||
def _fetch_us_index(
|
# ============================================================
|
||||||
self,
|
# 内部方法:特殊资产类型(保留)
|
||||||
code: str,
|
# ============================================================
|
||||||
start_date: str,
|
|
||||||
end_date: str
|
|
||||||
) -> Optional[pd.DataFrame]:
|
|
||||||
"""
|
|
||||||
获取美股指数
|
|
||||||
|
|
||||||
特点:YFinance,需要SSH隧道,指数代码转换
|
|
||||||
"""
|
|
||||||
self._start_tunnel()
|
|
||||||
return self._yfinance.fetch(code, start_date, end_date)
|
|
||||||
|
|
||||||
def _fetch_us_stock(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str
|
|
||||||
) -> Optional[pd.DataFrame]:
|
|
||||||
"""
|
|
||||||
获取美股股票
|
|
||||||
|
|
||||||
特点:YFinance,需要SSH隧道,返回价格+股票信息
|
|
||||||
"""
|
|
||||||
self._start_tunnel()
|
|
||||||
return self._yfinance.fetch(code, start_date, end_date)
|
|
||||||
|
|
||||||
def _fetch_hk_index(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str
|
|
||||||
) -> Optional[pd.DataFrame]:
|
|
||||||
"""
|
|
||||||
获取港股指数
|
|
||||||
|
|
||||||
特点:YFinance,需要SSH隧道
|
|
||||||
"""
|
|
||||||
self._start_tunnel()
|
|
||||||
return self._yfinance.fetch(code, start_date, end_date)
|
|
||||||
|
|
||||||
def _fetch_hk_stock(
|
|
||||||
self,
|
|
||||||
code: str,
|
|
||||||
start_date: str,
|
|
||||||
end_date: str
|
|
||||||
) -> Optional[pd.DataFrame]:
|
|
||||||
"""
|
|
||||||
获取港股股票
|
|
||||||
|
|
||||||
特点:YFinance,需要SSH隧道,返回价格+股票信息
|
|
||||||
"""
|
|
||||||
self._start_tunnel()
|
|
||||||
return self._yfinance.fetch(code, start_date, end_date)
|
|
||||||
|
|
||||||
def _fetch_futures(
|
def _fetch_futures(
|
||||||
self,
|
self,
|
||||||
@@ -486,8 +443,7 @@ class UniversalDataFetcher:
|
|||||||
for asset_type, code_list in grouped.items():
|
for asset_type, code_list in grouped.items():
|
||||||
print(f" {asset_type.value}: {len(code_list)} 只")
|
print(f" {asset_type.value}: {len(code_list)} 只")
|
||||||
|
|
||||||
# 启动隧道(港美股需要)
|
# 无需单独启动隧道,每个 fetch() 会自动处理
|
||||||
self._start_tunnel()
|
|
||||||
|
|
||||||
for code in codes:
|
for code in codes:
|
||||||
results[code] = self.fetch(code, start_date, end_date, retry)
|
results[code] = self.fetch(code, start_date, end_date, retry)
|
||||||
@@ -565,8 +521,8 @@ class UniversalDataFetcher:
|
|||||||
# 苹果复权价格(包含分红和拆分调整)
|
# 苹果复权价格(包含分红和拆分调整)
|
||||||
df = fetcher.fetch_us_adj("AAPL", "2020-01-01", "2024-12-31", adj='qfq')
|
df = fetcher.fetch_us_adj("AAPL", "2020-01-01", "2024-12-31", adj='qfq')
|
||||||
"""
|
"""
|
||||||
self._start_tunnel()
|
# 直接调用 fetch(),自动处理 SSH 隧道
|
||||||
return self._yfinance.fetch_adj(code, start_date, end_date, adj)
|
return self.fetch(code, start_date, end_date, adj)
|
||||||
|
|
||||||
def fetch_hk_adj(
|
def fetch_hk_adj(
|
||||||
self,
|
self,
|
||||||
@@ -589,8 +545,8 @@ class UniversalDataFetcher:
|
|||||||
Returns:
|
Returns:
|
||||||
DataFrame with columns: date, open, high, low, close, volume (复权后)
|
DataFrame with columns: date, open, high, low, close, volume (复权后)
|
||||||
"""
|
"""
|
||||||
self._start_tunnel()
|
# 直接调用 fetch(),自动处理 SSH 隧道
|
||||||
return self._yfinance.fetch_adj(code, start_date, end_date, adj)
|
return self.fetch(code, start_date, end_date, adj)
|
||||||
|
|
||||||
def fetch_stock_adj(
|
def fetch_stock_adj(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user