Files
etf/core/datasource/universal_fetcher.py
aszerW e319426c10 feat(datasource): 实现统一数据获取接口 UniversalDataFetcher
- 新增 AssetTypeDetector 自动识别8种资产类型
- 支持 A股指数/ETF/股票、港股、美股、期货、加密货币
- 自动路由到 Tushare/YFinance/CCXT 数据源
- 集成 SSH 隧道支持港美股数据获取
- 提供便捷函数 fetch_kline 和 detect_asset_type
- 修复资产类型检测边界情况(.CSI后缀、000001股票等)
2026-05-07 21:19:19 +08:00

454 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
统一数据获取接口
================
自动识别资产类型并路由到对应的数据源获取K线数据
支持的资产类型:
- A股指数 (代码格式: 000300.SH, 399006.SZ, 931862.CSI)
- A股ETF (代码格式: 510300.SH, 159915.SZ)
- A股股票 (代码格式: 600000.SH, 000001.SZ)
- 港股指数/股票 (代码格式: HSI, HSTECH.HK)
- 美股指数/股票 (代码格式: NDX, SPX, AAPL)
- 期货合约 (代码格式: AU.SHF, CU.SHF)
- 加密货币 (代码格式: BTC, ETH)
用法:
from core.datasource.universal_fetcher import UniversalDataFetcher
fetcher = UniversalDataFetcher()
# 获取单只标的
df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31")
# 获取多只标的
df = fetcher.fetch_multiple(["000300.SH", "NDX", "BTC"], "2024-01-01", "2024-12-31")
"""
import os
import time
from pathlib import Path
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta
import pandas as pd
import yfinance as yf
from .hybrid_source import HybridDataSource
class AssetTypeDetector:
    """Syntactic asset-type classifier for instrument codes.

    The type is derived only from the code string (suffixes, lookup tables and
    A-share code ranges); no data source is queried.
    """

    # Suffixes routed to the A-share classifier. '.SS' is accepted as an alias
    # of '.SH'; '.CSI' marks China Securities Index codes.
    CHINA_INDEX_SUFFIXES = ('.SH', '.SZ', '.SS', '.CSI')
    # Futures exchange suffixes
    FUTURES_SUFFIXES = ('.SHF', '.DCE', '.CZC', '.INE', '.GFEX')
    # Hong Kong suffixes
    HK_SUFFIXES = ('.HK',)
    # Crypto codes (compared against the upper-cased input code)
    CRYPTO_CODES = {'BTC', 'ETH', 'SOL', 'BNB', 'XRP', 'ADA', 'DOGE'}
    # Futures code map (keep in sync with HybridDataSource)
    FUTURES_CODE_MAP = {
        "AU.SHF": "AU.SHF",
        "CU.SHF": "CU.SHF",
    }
    # YFinance symbol map (keep in sync with HybridDataSource)
    YF_CODE_MAP = {
        "HSTECH.HK": "3033.HK",
        "HSI": "^HSI",
        "NDX": "^NDX",
        "SPX": "^GSPC",
        "DJI": "^DJI",
        "N225": "^N225",
        "GDAXI": "^GDAXI",
        "CL.NYM": "CL=F",
    }

    # 6-digit A-share prefixes classified as exchange-traded funds
    _CHINA_ETF_PREFIXES = ('51', '52', '56', '58', '15', '16')
    # CSI index code ranges, classified as indices regardless of exchange suffix
    _CSI_INDEX_PREFIXES = ('930', '931', '932')
    # Shanghai (.SH / .SS) index ranges, e.g. 000001.SH, 000300.SH
    _SH_INDEX_PREFIXES = ('000', '001', '002')
    # Shenzhen (.SZ) index range, e.g. 399006.SZ. Shenzhen 000/001/002 codes
    # are listed stocks (e.g. 000001.SZ, 002594.SZ), not indices.
    _SZ_INDEX_PREFIXES = ('399',)

    @classmethod
    def detect(cls, code: str) -> str:
        """
        Detect the asset type of an instrument code.

        Checks run in priority order: crypto, futures, Hong Kong, A-share,
        index symbols mapped in YF_CODE_MAP; anything else is a US stock.

        Returns:
            One of: 'china_index', 'china_etf', 'china_stock',
            'hk_index', 'hk_stock', 'us_index', 'us_stock',
            'futures', 'crypto'
        """
        # Crypto first; this is the only case-insensitive check
        if code.upper() in cls.CRYPTO_CODES:
            return 'crypto'
        if code.endswith(cls.FUTURES_SUFFIXES):
            return 'futures'
        # Hong Kong before the other markets so HK index names are not misrouted
        if code.endswith(cls.HK_SUFFIXES):
            return 'hk_index' if code in cls.YF_CODE_MAP else 'hk_stock'
        # Bare Hang Seng index names. HSI is also in YF_CODE_MAP with a '^'
        # symbol, so without this check it would be classified as 'us_index'.
        if code in ('HSI', 'HSCEI', 'HSCCI'):
            return 'hk_index'
        if code.endswith(cls.CHINA_INDEX_SUFFIXES):
            return cls._classify_china_asset(code)
        # Index symbols: mapped to a '^'-prefixed Yahoo symbol
        # (this also covers non-US indices such as N225 / GDAXI)
        if code in cls.YF_CODE_MAP and cls.YF_CODE_MAP[code].startswith('^'):
            return 'us_index'
        return 'us_stock'

    @classmethod
    def _classify_china_asset(cls, code: str) -> str:
        """
        Refine an A-share code into 'china_index', 'china_etf' or 'china_stock'.

        Rules, in order:
          - '.CSI' suffix: index.
          - Code body that is not exactly 6 digits: stock.
          - ETF prefixes 51/52/56/58/15/16: ETF.
          - CSI ranges 930/931/932: index on any exchange suffix.
          - Shanghai (.SH/.SS): 000/001/002 prefixes are indices.
          - Shenzhen (.SZ): only 399xxx is an index; other codes are stocks.

        The exchange suffix is essential because the same 6-digit body can be
        an index in Shanghai and a stock in Shenzhen (000001.SH vs 000001.SZ).
        """
        if code.endswith('.CSI'):
            return 'china_index'

        code_body = code.split('.')[0]
        exchange = code.rsplit('.', 1)[-1]

        if not (code_body.isdigit() and len(code_body) == 6):
            return 'china_stock'
        if code_body.startswith(cls._CHINA_ETF_PREFIXES):
            return 'china_etf'
        if code_body.startswith(cls._CSI_INDEX_PREFIXES):
            return 'china_index'

        if exchange == 'SZ':
            index_prefixes = cls._SZ_INDEX_PREFIXES
        else:
            index_prefixes = cls._SH_INDEX_PREFIXES
        return 'china_index' if code_body.startswith(index_prefixes) else 'china_stock'
class UniversalDataFetcher:
    """
    Unified K-line fetcher.

    Classifies each code with AssetTypeDetector and routes it to Tushare
    (A-shares), YFinance (Hong Kong / US) or HybridDataSource (futures and
    crypto).
    """

    # Proxy variables removed from the environment for the duration of a
    # Tushare request and restored afterwards.
    _PROXY_ENV_KEYS = (
        "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
        "http_proxy", "https_proxy", "all_proxy",
    )

    def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
        """
        Args:
            ssh_config: SSH tunnel configuration forwarded to HybridDataSource.
                Its context is only entered when ssh_config["enabled"] is
                truthy (see __enter__).
            use_cache: Cache flag forwarded to HybridDataSource.
        """
        self.ssh_config = ssh_config or {}
        self.use_cache = use_cache
        # The raw argument (possibly None) is forwarded unchanged
        self._hybrid_source = HybridDataSource(
            ssh_config=ssh_config,
            use_cache=use_cache,
        )

    def fetch(
        self,
        code: str,
        start_date: str,
        end_date: str,
        retry: int = 3
    ) -> Optional[pd.DataFrame]:
        """
        Fetch daily K-line data for a single code.

        Args:
            code: Instrument code of any supported type.
            start_date: Start date, 'YYYY-MM-DD'.
            end_date: End date, 'YYYY-MM-DD' (inclusive).
            retry: Total number of attempts when the data source raises; at
                least one attempt is always made.

        Returns:
            DataFrame with columns [open, high, low, close, volume, code],
            indexed by date; None on failure or when no data is available.
        """
        attempts = max(1, retry)
        for attempt in range(1, attempts + 1):
            try:
                asset_type = AssetTypeDetector.detect(code)
                if asset_type in ('china_index', 'china_etf', 'china_stock'):
                    return self._fetch_china(code, start_date, end_date, asset_type)
                elif asset_type == 'futures':
                    return self._fetch_futures(code, start_date, end_date)
                elif asset_type in ('hk_index', 'hk_stock', 'us_index', 'us_stock'):
                    return self._fetch_yfinance(code, start_date, end_date, asset_type)
                elif asset_type == 'crypto':
                    return self._fetch_crypto(code, start_date, end_date)
                else:
                    print(f"⚠️ 未知的资产类型: {code}")
                    return None
            except Exception as e:
                if attempt < attempts:
                    time.sleep(2)
                else:
                    print(f"✗ 获取 {code} 数据失败 (尝试 {attempt}/{attempts}): {e}")
                    return None
        return None

    def fetch_multiple(
        self,
        codes: List[str],
        start_date: str,
        end_date: str,
        retry: int = 3
    ) -> Dict[str, Optional[pd.DataFrame]]:
        """
        Fetch K-line data for several codes, grouped by asset type.

        Args:
            codes: Instrument codes.
            start_date: Start date, 'YYYY-MM-DD'.
            end_date: End date, 'YYYY-MM-DD'.
            retry: Attempts per code, passed to fetch().

        Returns:
            Dict {code: DataFrame or None}; None marks a failed or empty fetch.
        """
        results = {}
        grouped: Dict[str, List[str]] = {}
        for code in codes:
            grouped.setdefault(AssetTypeDetector.detect(code), []).append(code)

        print(f"开始获取 {len(codes)} 只标的的数据...")
        print(" 资产类型分布:")
        for asset_type, code_list in grouped.items():
            print(f" - {asset_type}: {len(code_list)}")

        for code_list in grouped.values():
            for code in code_list:
                df = self.fetch(code, start_date, end_date, retry)
                if df is not None and len(df) > 0:
                    results[code] = df
                    print(f"{code}: {len(df)}")
                else:
                    print(f"{code}: 无数据")
                    results[code] = None
        return results

    def _fetch_china(
        self,
        code: str,
        start_date: str,
        end_date: str,
        asset_type: str
    ) -> Optional[pd.DataFrame]:
        """
        Fetch A-share index / ETF / stock bars from Tushare.

        The token is read from the TUSHARE_TOKEN environment variable. Proxy
        environment variables are removed for the duration of the request and
        restored afterwards.

        Returns:
            Normalized DataFrame, or None when the token is missing, the asset
            type is not an A-share type, or Tushare returns no rows.

        Raises:
            Exceptions from tushare/pandas are logged and re-raised so that
            fetch() can retry them.
        """
        import tushare as ts

        token = os.getenv("TUSHARE_TOKEN")
        if not token:
            # Misconfiguration is not transient: report once, don't trigger retries
            print(f"Tushare 下载 {code} 失败: 请设置环境变量 TUSHARE_TOKEN")
            return None

        # Temporarily clear proxy environment variables (restored in finally)
        saved_proxies = {key: os.environ.pop(key, None) for key in self._PROXY_ENV_KEYS}
        try:
            pro = ts.pro_api(token)
            ts_code = code.replace(".SS", ".SH")
            start = start_date.replace("-", "")
            end = end_date.replace("-", "")

            if asset_type == 'china_index':
                df = pro.index_daily(ts_code=ts_code, start_date=start, end_date=end)
            elif asset_type in ('china_etf', 'china_stock'):
                # Query the endpoint matching the detected type first, then fall
                # back to the other one in case the code was misclassified.
                if asset_type == 'china_etf':
                    endpoints = (pro.fund_daily, pro.daily)
                else:
                    endpoints = (pro.daily, pro.fund_daily)
                df = None
                for endpoint in endpoints:
                    df = endpoint(ts_code=ts_code, start_date=start, end_date=end)
                    if df is not None and not df.empty:
                        break
            else:
                return None

            if df is None or df.empty:
                return None

            df = df.rename(columns={"trade_date": "date", "vol": "volume"})
            df["date"] = pd.to_datetime(df["date"])
            df = df.set_index("date").sort_index()

            cols = ['open', 'high', 'low', 'close', 'volume']
            # .copy() so the 'code' assignment targets an independent frame
            df = df[[c for c in cols if c in df.columns]].copy()
            df['code'] = code
            return df
        except Exception as e:
            print(f"Tushare 下载 {code} 失败: {e}")
            raise
        finally:
            for key, value in saved_proxies.items():
                if value is not None:
                    os.environ[key] = value

    def _fetch_futures(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch futures bars by delegating to HybridDataSource._fetch_futures."""
        return self._hybrid_source._fetch_futures(code, start_date, end_date)

    def _fetch_yfinance(
        self,
        code: str,
        start_date: str,
        end_date: str,
        asset_type: str
    ) -> Optional[pd.DataFrame]:
        """
        Fetch Hong Kong / US bars from YFinance.

        Returns:
            DataFrame indexed by a tz-naive DatetimeIndex named 'date' (matching
            the Tushare branch), or None when YFinance returns no rows.

        Raises:
            Download errors are logged and re-raised so fetch() can retry them.
        """
        yf_code = AssetTypeDetector.YF_CODE_MAP.get(code, code)
        # Index symbols on Yahoo carry a '^' prefix
        if asset_type == 'us_index' and not yf_code.startswith('^'):
            yf_code = f'^{yf_code}'

        # Small delay to avoid rate limiting
        time.sleep(0.5)
        try:
            ticker = yf.Ticker(yf_code)
            # yfinance's `end` is exclusive; add one day so end_date is included
            end_exclusive = pd.Timestamp(end_date) + timedelta(days=1)
            data = ticker.history(
                start=start_date,
                end=end_exclusive.strftime('%Y-%m-%d'),
                auto_adjust=False,
            )
        except Exception as e:
            print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}")
            raise

        if data is None or len(data) == 0:
            return None

        data = data.rename(columns={
            "Open": "open",
            "High": "high",
            "Low": "low",
            "Close": "close",
            "Volume": "volume",
        })
        cols = ['open', 'high', 'low', 'close', 'volume']
        data = data[[c for c in cols if c in data.columns]].copy()
        # Drop the exchange timezone so bars align with the tz-naive Tushare data
        if isinstance(data.index, pd.DatetimeIndex) and data.index.tz is not None:
            data.index = data.index.tz_localize(None)
        data = data.rename_axis("date")
        data['code'] = code
        return data

    def _fetch_crypto(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch crypto bars by delegating to HybridDataSource._fetch_ccxt."""
        return self._hybrid_source._fetch_ccxt(code, start_date, end_date)

    def __enter__(self):
        """Enter the HybridDataSource context (SSH tunnel) when ssh_config["enabled"] is truthy."""
        if self.ssh_config.get("enabled"):
            self._hybrid_source.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the HybridDataSource context when ssh_config["enabled"] is truthy."""
        if self.ssh_config.get("enabled"):
            self._hybrid_source.__exit__(exc_type, exc_val, exc_tb)
# ============================================================
# 便捷函数
# ============================================================
def fetch_kline(
    code: str,
    start_date: str,
    end_date: str,
    ssh_config: Optional[dict] = None
) -> Optional[pd.DataFrame]:
    """
    Convenience wrapper: fetch K-line data for a single code.

    Builds a one-off UniversalDataFetcher, enters its context, and delegates
    to UniversalDataFetcher.fetch with its default retry count.

    Args:
        code: Instrument code.
        start_date: Start date, 'YYYY-MM-DD'.
        end_date: End date, 'YYYY-MM-DD'.
        ssh_config: Optional SSH tunnel configuration.

    Returns:
        DataFrame with OHLCV data, or None on failure.
    """
    with UniversalDataFetcher(ssh_config=ssh_config) as kline_fetcher:
        return kline_fetcher.fetch(code, start_date, end_date)
def detect_asset_type(code: str) -> str:
    """
    Convenience wrapper around AssetTypeDetector.detect.

    Returns:
        The asset-type string produced by AssetTypeDetector.detect.
    """
    detected = AssetTypeDetector.detect(code)
    return detected