- 在 universal_fetcher._fetch_yfinance 中获取公司信息 - 包含 sector、industry、market_cap 字段 - 将信息存储在 DataFrame.attrs 中 - Flask API 自动提取并返回 info 字段
472 lines
15 KiB
Python
472 lines
15 KiB
Python
"""
|
||
统一数据获取接口
|
||
================
|
||
自动识别资产类型并路由到对应的数据源,获取K线数据
|
||
|
||
支持的资产类型:
|
||
- A股指数 (代码格式: 000300.SH, 399006.SZ, 931862.CSI)
|
||
- A股ETF (代码格式: 510300.SH, 159915.SZ)
|
||
- A股股票 (代码格式: 600000.SH, 000001.SZ)
|
||
- 港股指数/股票 (代码格式: HSI, HSTECH.HK)
|
||
- 美股指数/股票 (代码格式: NDX, SPX, AAPL)
|
||
- 期货合约 (代码格式: AU.SHF, CU.SHF)
|
||
- 加密货币 (代码格式: BTC, ETH)
|
||
|
||
用法:
|
||
from core.datasource.universal_fetcher import UniversalDataFetcher
|
||
|
||
fetcher = UniversalDataFetcher()
|
||
|
||
# 获取单只标的
|
||
df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31")
|
||
|
||
# 获取多只标的
|
||
df = fetcher.fetch_multiple(["000300.SH", "NDX", "BTC"], "2024-01-01", "2024-12-31")
|
||
"""
|
||
|
||
import os
|
||
import time
|
||
from pathlib import Path
|
||
from typing import Optional, Union, List, Dict
|
||
from datetime import datetime, timedelta
|
||
import pandas as pd
|
||
import yfinance as yf
|
||
|
||
from .hybrid_source import HybridDataSource
|
||
|
||
|
||
class AssetTypeDetector:
    """Classify a ticker code into the asset-type label used to route data requests."""

    # Suffixes of mainland China exchange codes (indices, ETFs and stocks).
    # '.SS' is the Yahoo-style alias of the Shanghai exchange.
    CHINA_INDEX_SUFFIXES = ('.SH', '.SZ', '.SS', '.CSI')

    # Domestic futures exchange suffixes.
    FUTURES_SUFFIXES = ('.SHF', '.DCE', '.CZC', '.INE', '.GFEX')

    # Hong Kong suffixes.
    HK_SUFFIXES = ('.HK',)

    # Codes treated as crypto (compared case-insensitively in detect()).
    CRYPTO_CODES = {'BTC', 'ETH', 'SOL', 'BNB', 'XRP', 'ADA', 'DOGE'}

    # Futures code map (kept in sync with HybridDataSource).
    FUTURES_CODE_MAP = {
        "AU.SHF": "AU.SHF",
        "CU.SHF": "CU.SHF",
    }

    # YFinance symbol map (kept in sync with HybridDataSource).
    YF_CODE_MAP = {
        "HSTECH.HK": "3033.HK",
        "HSI": "^HSI",
        "NDX": "^NDX",
        "SPX": "^GSPC",
        "DJI": "^DJI",
        "N225": "^N225",
        "GDAXI": "^GDAXI",
        "CL.NYM": "CL=F",
    }

    # 6-digit prefixes of exchange-traded funds (Shanghai 51/52/56/58, Shenzhen 15/16).
    CHINA_ETF_PREFIXES = ('51', '52', '56', '58', '15', '16')

    # Index prefixes for Shanghai ('.SH'/'.SS') codes.
    CHINA_INDEX_PREFIXES = ('000', '001', '002', '399', '930', '931', '932')

    # On Shenzhen only 399xxx codes are indices. The 000/001/002 ranges there
    # are stocks (e.g. 000001.SZ, listed as a stock example in the module
    # docstring), so the Shanghai prefix list must not be applied to '.SZ'.
    SZ_INDEX_PREFIXES = ('399',)

    @classmethod
    def detect(cls, code: str) -> str:
        """
        Detect the asset type of ``code``.

        Checks run in priority order: crypto, futures, Hong Kong, mainland
        China, US index (code mapped to a '^'-prefixed Yahoo symbol in
        YF_CODE_MAP), and finally US stock as the fallback.

        Returns:
            One of: 'china_index', 'china_etf', 'china_stock',
            'hk_index', 'hk_stock', 'us_index', 'us_stock',
            'futures', 'crypto'
        """
        # Crypto first (case-insensitive).
        if code.upper() in cls.CRYPTO_CODES:
            return 'crypto'

        # Futures: checked before China so '.SHF' is never treated as '.SH'-like.
        if code.endswith(cls.FUTURES_SUFFIXES):
            return 'futures'

        # Hong Kong before mainland China.
        if code.endswith(cls.HK_SUFFIXES):
            return 'hk_index' if code in cls.YF_CODE_MAP else 'hk_stock'

        # Bare Hong Kong index symbols without a suffix.
        if code in ('HSI', 'HSCEI', 'HSCCI'):
            return 'hk_index'

        # Mainland China codes.
        if code.endswith(cls.CHINA_INDEX_SUFFIXES):
            return cls._classify_china_asset(code)

        # US index: known code whose Yahoo symbol is a '^' index symbol.
        if code in cls.YF_CODE_MAP and cls.YF_CODE_MAP[code].startswith('^'):
            return 'us_index'

        # Default: US stock.
        return 'us_stock'

    @classmethod
    def _classify_china_asset(cls, code: str) -> str:
        """
        Refine a mainland China code into 'china_index', 'china_etf' or 'china_stock'.

        Rules:
            - '.CSI' suffix: CSI index, always 'china_index'.
            - Code body that is not 6 digits: 'china_stock'.
            - ETF prefixes (51/52/56/58/15/16): 'china_etf'.
            - '.SZ' codes: only 399xxx is an index; everything else is a stock.
            - Other suffixes ('.SH', '.SS'): CHINA_INDEX_PREFIXES mark an
              index (e.g. 000001.SH, 000300.SH); everything else is a stock.
        """
        # '.CSI' suffix is always an index.
        if code.endswith('.CSI'):
            return 'china_index'

        # Numeric body before the exchange suffix.
        code_body = code.split('.')[0]

        if not code_body.isdigit() or len(code_body) != 6:
            return 'china_stock'

        if code_body.startswith(cls.CHINA_ETF_PREFIXES):
            return 'china_etf'

        # Index ranges depend on the exchange: the same 000xxx body is an index
        # on Shanghai (000001.SH) but a stock on Shenzhen (000001.SZ).
        index_prefixes = (
            cls.SZ_INDEX_PREFIXES if code.endswith('.SZ') else cls.CHINA_INDEX_PREFIXES
        )
        if code_body.startswith(index_prefixes):
            return 'china_index'

        return 'china_stock'
|
||
|
||
|
||
class UniversalDataFetcher:
    """
    Unified market-data fetcher.

    Wraps the Tushare, YFinance and CCXT data sources. Each code is classified
    with ``AssetTypeDetector.detect`` and routed to the matching source:
    mainland China via Tushare, HK/US via yfinance, futures and crypto via
    ``HybridDataSource``.
    """

    def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
        """
        Args:
            ssh_config: SSH tunnel configuration (for reaching restricted data
                sources such as YFinance). Forwarded to HybridDataSource; this
                class only enters/exits the tunnel when
                ``ssh_config.get("enabled")`` is truthy.
            use_cache: Whether to use caching; forwarded to HybridDataSource.
        """
        # Always a dict so __enter__/__exit__ can call .get() safely.
        self.ssh_config = ssh_config or {}
        self.use_cache = use_cache
        self._hybrid_source = HybridDataSource(
            ssh_config=ssh_config,
            use_cache=use_cache
        )

    def fetch(
        self,
        code: str,
        start_date: str,
        end_date: str,
        retry: int = 3
    ) -> Optional[pd.DataFrame]:
        """
        Fetch daily K-line data for a single code.

        Args:
            code: Code of any supported asset type.
            start_date: Start date, format 'YYYY-MM-DD'.
            end_date: End date, format 'YYYY-MM-DD'.
            retry: Total number of attempts; values below 1 are treated as 1
                so at least one fetch is always made.

        Returns:
            DataFrame with columns [open, high, low, close, volume, code]
            (only the OHLCV columns the source provides), indexed by date
            (DatetimeIndex). None on failure.

        Note:
            Only exceptions that escape the per-source helper trigger a retry
            (after a 2-second pause). The Tushare and YFinance helpers catch
            their own errors and return None, which is returned immediately.
        """
        attempts = max(1, retry)
        for attempt in range(attempts):
            try:
                asset_type = AssetTypeDetector.detect(code)

                if asset_type in ('china_index', 'china_etf', 'china_stock'):
                    return self._fetch_china(code, start_date, end_date, asset_type)
                if asset_type == 'futures':
                    return self._fetch_futures(code, start_date, end_date)
                if asset_type in ('hk_index', 'hk_stock', 'us_index', 'us_stock'):
                    return self._fetch_yfinance(code, start_date, end_date, asset_type)
                if asset_type == 'crypto':
                    return self._fetch_crypto(code, start_date, end_date)

                print(f"⚠️ 未知的资产类型: {code}")
                return None

            except Exception as e:
                if attempt < attempts - 1:
                    time.sleep(2)
                else:
                    print(f"✗ 获取 {code} 数据失败 (尝试 {attempt+1}/{attempts}): {e}")

        return None

    def fetch_multiple(
        self,
        codes: List[str],
        start_date: str,
        end_date: str,
        retry: int = 3
    ) -> Dict[str, Optional[pd.DataFrame]]:
        """
        Fetch daily K-line data for several codes.

        Codes are grouped by asset type (for the summary printout) and then
        fetched one by one through ``fetch``.

        Args:
            codes: List of codes.
            start_date: Start date, 'YYYY-MM-DD'.
            end_date: End date, 'YYYY-MM-DD'.
            retry: Attempts per code, passed through to ``fetch``.

        Returns:
            Dict {code: DataFrame}; the value is None when no data was obtained.
        """
        results: Dict[str, Optional[pd.DataFrame]] = {}

        # Group codes by asset type, preserving first-seen order.
        grouped: Dict[str, List[str]] = {}
        for code in codes:
            grouped.setdefault(AssetTypeDetector.detect(code), []).append(code)

        print(f"开始获取 {len(codes)} 只标的的数据...")
        print(" 资产类型分布:")
        for asset_type, code_list in grouped.items():
            print(f" - {asset_type}: {len(code_list)} 只")

        # Fetch group by group.
        for code_list in grouped.values():
            for code in code_list:
                df = self.fetch(code, start_date, end_date, retry)
                if df is not None and len(df) > 0:
                    results[code] = df
                    print(f" ✓ {code}: {len(df)} 条")
                else:
                    print(f" ✗ {code}: 无数据")
                    results[code] = None

        return results

    @staticmethod
    def _select_ohlcv(df: pd.DataFrame, code: str) -> pd.DataFrame:
        """
        Keep the available OHLCV columns (in canonical order) and add a ``code`` column.

        The column subset is copied before the ``code`` column is added, so the
        assignment never targets a slice of ``df`` (which would raise pandas'
        SettingWithCopyWarning). ``df`` itself is not modified.
        """
        wanted = ['open', 'high', 'low', 'close', 'volume']
        out = df[[c for c in wanted if c in df.columns]].copy()
        out['code'] = code
        return out

    def _fetch_china(
        self,
        code: str,
        start_date: str,
        end_date: str,
        asset_type: str
    ) -> Optional[pd.DataFrame]:
        """
        Fetch mainland China data (index / ETF / stock) from Tushare.

        Proxy environment variables are removed for the duration of the call
        and restored afterwards. Requires the TUSHARE_TOKEN environment variable.
        Any error is printed and None is returned.
        """
        import tushare as ts

        # Temporarily clear proxy environment variables.
        original_proxy = {}
        for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
            original_proxy[key] = os.environ.pop(key, None)

        try:
            token = os.getenv("TUSHARE_TOKEN")
            if not token:
                raise ValueError("请设置环境变量 TUSHARE_TOKEN")

            pro = ts.pro_api(token)

            # Tushare uses '.SH' for Shanghai; accept Yahoo-style '.SS' too.
            query = {
                "ts_code": code.replace(".SS", ".SH"),
                "start_date": start_date.replace("-", ""),
                "end_date": end_date.replace("-", ""),
            }

            # Endpoints to try, in order, until one returns rows. Stocks hit
            # `daily` first (saving a rate-limited call); funds/ETFs hit
            # `fund_daily` first. Each keeps the other as a fallback for codes
            # whose prefix does not match the actual instrument type.
            if asset_type == 'china_index':
                endpoints = (pro.index_daily,)
            elif asset_type == 'china_etf':
                endpoints = (pro.fund_daily, pro.daily)
            elif asset_type == 'china_stock':
                endpoints = (pro.daily, pro.fund_daily)
            else:
                return None

            df = None
            for endpoint in endpoints:
                df = endpoint(**query)
                if df is not None and not df.empty:
                    break

            if df is None or df.empty:
                return None

            # Normalise column names.
            df = df.rename(columns={
                "trade_date": "date",
                "vol": "volume",
            })

            # Date column becomes a sorted DatetimeIndex.
            df["date"] = pd.to_datetime(df["date"])
            df = df.set_index("date").sort_index()

            return self._select_ohlcv(df, code)

        except Exception as e:
            print(f"Tushare 下载 {code} 失败: {e}")
            return None

        finally:
            # Restore the proxy variables that were set before the call.
            for key, value in original_proxy.items():
                if value is not None:
                    os.environ[key] = value

    def _fetch_futures(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch futures data by delegating to HybridDataSource._fetch_futures."""
        return self._hybrid_source._fetch_futures(code, start_date, end_date)

    def _fetch_yfinance(
        self,
        code: str,
        start_date: str,
        end_date: str,
        asset_type: str
    ) -> Optional[pd.DataFrame]:
        """
        Fetch Hong Kong / US data via yfinance.

        For 'us_stock' / 'hk_stock' the sector, industry and market cap are
        also requested (best effort) and stored in ``DataFrame.attrs['info']``.
        Any download error is printed and None is returned.
        """
        # Map to the Yahoo symbol.
        yf_code = AssetTypeDetector.YF_CODE_MAP.get(code, code)

        # US indices need the '^' prefix on Yahoo.
        if asset_type == 'us_index' and not yf_code.startswith('^'):
            yf_code = f'^{yf_code}'

        # Small delay to avoid rate limiting.
        time.sleep(0.5)

        try:
            ticker = yf.Ticker(yf_code)

            # Company metadata (stocks only). This is optional, so a failure
            # here must not prevent the price download.
            info = {}
            if asset_type in ('us_stock', 'hk_stock'):
                try:
                    stock_info = ticker.info
                    info = {
                        'sector': stock_info.get('sector'),
                        'industry': stock_info.get('industry'),
                        'market_cap': stock_info.get('marketCap'),
                    }
                except Exception:
                    pass

            # yfinance's `end` is exclusive, so add one day to include end_date.
            end_exclusive = pd.Timestamp(end_date) + timedelta(days=1)
            data = ticker.history(
                start=start_date,
                end=end_exclusive.strftime('%Y-%m-%d'),
                auto_adjust=False
            )

            if data is None or len(data) == 0:
                return None

            # Normalise column names.
            data = data.rename(columns={
                "Open": "open",
                "High": "high",
                "Low": "low",
                "Close": "close",
                "Volume": "volume",
            })

            # NOTE(review): yfinance may return a timezone-aware index while the
            # Tushare path yields naive dates — confirm whether callers that
            # combine frames need the index normalised.
            data = self._select_ohlcv(data, code)

            # Attach company metadata to the frame's attrs.
            if info:
                data.attrs['info'] = info

            return data

        except Exception as e:
            print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}")
            return None

    def _fetch_crypto(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch crypto data by delegating to HybridDataSource's CCXT helper."""
        return self._hybrid_source._fetch_ccxt(code, start_date, end_date)

    def __enter__(self):
        """Context-manager entry: enter HybridDataSource (SSH tunnel) when enabled."""
        if self.ssh_config.get("enabled"):
            self._hybrid_source.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context-manager exit: exit HybridDataSource (SSH tunnel) when enabled."""
        if self.ssh_config.get("enabled"):
            self._hybrid_source.__exit__(exc_type, exc_val, exc_tb)
|
||
|
||
|
||
# ============================================================
|
||
# 便捷函数
|
||
# ============================================================
|
||
|
||
def fetch_kline(
    code: str,
    start_date: str,
    end_date: str,
    ssh_config: Optional[dict] = None
) -> Optional[pd.DataFrame]:
    """
    Convenience wrapper: fetch K-line data for one code.

    Creates a fresh UniversalDataFetcher and enters its context, which starts
    the SSH tunnel only when ``ssh_config`` has a truthy "enabled" key. It then
    delegates to ``fetch`` with the default retry count.

    Args:
        code: Ticker code of any supported asset type.
        start_date: Start date, 'YYYY-MM-DD'.
        end_date: End date, 'YYYY-MM-DD'.
        ssh_config: Optional SSH tunnel configuration.

    Returns:
        DataFrame with OHLCV data, or None on failure.
    """
    with UniversalDataFetcher(ssh_config=ssh_config) as session:
        return session.fetch(code, start_date, end_date)
|
||
|
||
|
||
def detect_asset_type(code: str) -> str:
    """
    Convenience wrapper around ``AssetTypeDetector.detect``.

    Args:
        code: Ticker code to classify.

    Returns:
        Asset-type label, e.g. 'china_index', 'us_stock' or 'crypto'.
    """
    classify = AssetTypeDetector.detect
    return classify(code)
|