"""
Unified data-fetching interface
===============================

Detects the asset type of a ticker and routes the request to the matching
data source (Tushare, YFinance, or the CCXT/futures paths of
``HybridDataSource``) to download daily K-line (OHLCV) data.

Supported asset types:
    - China A-share indices     (e.g. 000300.SH, 399006.SZ, 931862.CSI)
    - China A-share ETFs        (e.g. 510300.SH, 159915.SZ)
    - China A-share stocks      (e.g. 600000.SH, 000001.SZ)
    - Hong Kong indices/stocks  (e.g. HSI, HSTECH.HK)
    - US indices/stocks         (e.g. NDX, SPX, AAPL)
    - Futures contracts         (e.g. AU.SHF, CU.SHF)
    - Crypto currencies         (e.g. BTC, ETH)

Usage:
    from core.datasource.universal_fetcher import UniversalDataFetcher

    fetcher = UniversalDataFetcher()

    # Fetch a single instrument
    df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31")

    # Fetch several instruments (returns a dict {code: DataFrame | None})
    results = fetcher.fetch_multiple(["000300.SH", "NDX", "BTC"], "2024-01-01", "2024-12-31")
"""
import os
import time
from pathlib import Path
from typing import Optional, Union, List, Dict
from datetime import datetime, timedelta

import pandas as pd
import yfinance as yf

from .hybrid_source import HybridDataSource


class AssetTypeDetector:
    """Classify a ticker string into one of the supported asset types."""

    # Exchange suffixes of China A-share instruments. ".SS" is Yahoo's
    # spelling of Shanghai; it is rewritten to ".SH" before calling Tushare.
    CHINA_INDEX_SUFFIXES = ('.SH', '.SZ', '.SS', '.CSI')

    # Futures exchange suffixes.
    FUTURES_SUFFIXES = ('.SHF', '.DCE', '.CZC', '.INE', '.GFEX')

    # Hong Kong suffixes.
    HK_SUFFIXES = ('.HK',)

    # Tickers treated as crypto currencies (compared upper-cased).
    CRYPTO_CODES = {'BTC', 'ETH', 'SOL', 'BNB', 'XRP', 'ADA', 'DOGE'}

    # Futures code mapping (kept in sync with HybridDataSource).
    FUTURES_CODE_MAP = {
        "AU.SHF": "AU.SHF",
        "CU.SHF": "CU.SHF",
    }

    # YFinance ticker mapping (kept in sync with HybridDataSource).
    YF_CODE_MAP = {
        "HSTECH.HK": "3033.HK",
        "HSI": "^HSI",
        "NDX": "^NDX",
        "SPX": "^GSPC",
        "DJI": "^DJI",
        "N225": "^N225",
        "GDAXI": "^GDAXI",
        "CL.NYM": "CL=F",
    }

    # Bare Hong Kong index tickers that carry no ".HK" suffix. They must be
    # checked before the YF_CODE_MAP lookup, otherwise "HSI" (mapped to a
    # "^" ticker) would be labelled a US index.
    HK_INDEX_CODES = ('HSI', 'HSCEI', 'HSCCI')

    # 6-digit A-share code prefixes that denote ETFs / LOFs.
    CHINA_ETF_PREFIXES = ('51', '52', '56', '58', '15', '16')

    # Index prefixes accepted on every A-share exchange suffix.
    CHINA_INDEX_PREFIXES = ('399', '930', '931', '932')

    # Prefixes that denote indices only on Shanghai (.SH/.SS). On Shenzhen
    # the same prefixes are ordinary stocks (e.g. 000001.SZ, 002594.SZ).
    SH_INDEX_PREFIXES = ('000', '001', '002')

    @classmethod
    def detect(cls, code: str) -> str:
        """
        Detect the asset type of ``code``.

        Args:
            code: Ticker string, e.g. "000300.SH", "HSI", "AAPL", "BTC".

        Returns:
            One of 'china_index', 'china_etf', 'china_stock', 'hk_index',
            'hk_stock', 'us_index', 'us_stock', 'futures', 'crypto'.
        """
        # Crypto first: bare tickers such as "BTC" would otherwise fall
        # through to the US-stock default.
        if code.upper() in cls.CRYPTO_CODES:
            return 'crypto'

        if code.endswith(cls.FUTURES_SUFFIXES):
            return 'futures'

        # Hong Kong before China/US so that HK indices are not misclassified.
        if code.endswith(cls.HK_SUFFIXES):
            return 'hk_index' if code in cls.YF_CODE_MAP else 'hk_stock'
        if code in cls.HK_INDEX_CODES:
            return 'hk_index'

        if code.endswith(cls.CHINA_INDEX_SUFFIXES):
            return cls._classify_china_asset(code)

        # US indices are the YF_CODE_MAP entries whose Yahoo ticker starts with "^".
        if cls.YF_CODE_MAP.get(code, '').startswith('^'):
            return 'us_index'

        # Default: US stock.
        return 'us_stock'

    @classmethod
    def _classify_china_asset(cls, code: str) -> str:
        """
        Refine an A-share ticker into index / ETF / stock.

        Rules:
            - ".CSI" suffix: CSI index -> 'china_index'.
            - Code body that is not 6 digits -> 'china_stock'.
            - ETF/LOF prefixes (51, 52, 56, 58, 15, 16) -> 'china_etf'.
            - 399/930/931/932 on any exchange, plus 000/001/002 on Shanghai
              only (e.g. 000300.SH, 000001.SH) -> 'china_index'.
            - Everything else -> 'china_stock' (e.g. 600000.SH, 000001.SZ).

        The exchange matters: 000001.SH is the SSE Composite index while
        000001.SZ is a bank stock, and 002xxx.SZ are Shenzhen stocks.
        """
        if code.endswith('.CSI'):
            return 'china_index'

        code_body = code.split('.')[0]
        exchange = code.rsplit('.', 1)[-1]

        if not (code_body.isdigit() and len(code_body) == 6):
            return 'china_stock'

        if code_body.startswith(cls.CHINA_ETF_PREFIXES):
            return 'china_etf'

        index_prefixes = cls.CHINA_INDEX_PREFIXES
        if exchange in ('SH', 'SS'):
            index_prefixes = index_prefixes + cls.SH_INDEX_PREFIXES
        if code_body.startswith(index_prefixes):
            return 'china_index'

        return 'china_stock'


class UniversalDataFetcher:
    """
    Unified data fetcher.

    Wraps Tushare, YFinance and the CCXT/futures paths of HybridDataSource,
    detecting each ticker's asset type and routing accordingly.
    """

    def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
        """
        Args:
            ssh_config: SSH tunnel configuration, forwarded to HybridDataSource.
                Its "enabled" key controls whether the context manager opens
                and closes the tunnel.
            use_cache: Whether HybridDataSource should use its cache.
        """
        self.ssh_config = ssh_config or {}
        self.use_cache = use_cache
        self._hybrid_source = HybridDataSource(
            ssh_config=ssh_config,
            use_cache=use_cache
        )

    def fetch(
        self,
        code: str,
        start_date: str,
        end_date: str,
        retry: int = 3
    ) -> Optional[pd.DataFrame]:
        """
        Fetch daily K-line data for a single instrument.

        Args:
            code: Ticker of any supported type.
            start_date: Start date, 'YYYY-MM-DD'.
            end_date: End date, 'YYYY-MM-DD' (inclusive).
            retry: Number of attempts when a fetch raises. Values below 1
                are treated as 1 so that at least one request is made.

        Returns:
            A DataFrame indexed by date, with the available columns among
            [open, high, low, close, volume] plus a ``code`` column, or None
            on failure. Only exceptions that escape the per-source helpers
            trigger a retry; the helpers themselves log and return None for
            most errors.
        """
        attempts = max(1, retry)
        for attempt in range(attempts):
            try:
                asset_type = AssetTypeDetector.detect(code)
                if asset_type in ('china_index', 'china_etf', 'china_stock'):
                    return self._fetch_china(code, start_date, end_date, asset_type)
                if asset_type == 'futures':
                    return self._fetch_futures(code, start_date, end_date)
                if asset_type in ('hk_index', 'hk_stock', 'us_index', 'us_stock'):
                    return self._fetch_yfinance(code, start_date, end_date, asset_type)
                if asset_type == 'crypto':
                    return self._fetch_crypto(code, start_date, end_date)
                print(f"⚠️ 未知的资产类型: {code}")
                return None
            except Exception as e:
                if attempt < attempts - 1:
                    # Short back-off before the next attempt.
                    time.sleep(2)
                else:
                    print(f"✗ 获取 {code} 数据失败 (尝试 {attempt+1}/{attempts}): {e}")
        return None

    def fetch_multiple(
        self,
        codes: List[str],
        start_date: str,
        end_date: str,
        retry: int = 3
    ) -> Dict[str, Optional[pd.DataFrame]]:
        """
        Fetch daily K-line data for several instruments.

        Args:
            codes: Tickers to fetch. Duplicates are fetched only once.
            start_date: Start date, 'YYYY-MM-DD'.
            end_date: End date, 'YYYY-MM-DD'.
            retry: Attempts per ticker (see ``fetch``).

        Returns:
            Dict {code: DataFrame or None}. The value is None when no data
            was obtained for that code.
        """
        # The result is keyed by code, so a repeated code would only be
        # downloaded twice and overwritten; keep first-seen order.
        unique_codes = list(dict.fromkeys(codes))

        # Group by asset type so that requests hitting the same source run together.
        grouped: Dict[str, List[str]] = {}
        for code in unique_codes:
            grouped.setdefault(AssetTypeDetector.detect(code), []).append(code)

        print(f"开始获取 {len(unique_codes)} 只标的的数据...")
        print(" 资产类型分布:")
        for asset_type, code_list in grouped.items():
            print(f" - {asset_type}: {len(code_list)} 只")

        results: Dict[str, Optional[pd.DataFrame]] = {}
        for code_list in grouped.values():
            for code in code_list:
                df = self.fetch(code, start_date, end_date, retry)
                if df is not None and len(df) > 0:
                    results[code] = df
                    print(f" ✓ {code}: {len(df)} 条")
                else:
                    print(f" ✗ {code}: 无数据")
                    results[code] = None

        return results

    def _fetch_china(
        self,
        code: str,
        start_date: str,
        end_date: str,
        asset_type: str
    ) -> Optional[pd.DataFrame]:
        """
        Fetch China A-share data (index / ETF / stock) from Tushare.

        Requires the TUSHARE_TOKEN environment variable. Proxy environment
        variables are removed for the duration of the call and restored
        afterwards. Errors are printed and reported as None.
        """
        import tushare as ts

        # Temporarily clear proxy variables so Tushare requests go direct.
        original_proxy = {}
        for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
                    "http_proxy", "https_proxy", "all_proxy"]:
            original_proxy[key] = os.environ.pop(key, None)

        try:
            token = os.getenv("TUSHARE_TOKEN")
            if not token:
                raise ValueError("请设置环境变量 TUSHARE_TOKEN")

            pro = ts.pro_api(token)

            # Tushare uses ".SH" for Shanghai; accept Yahoo's ".SS" too.
            ts_code = code.replace(".SS", ".SH")
            query = {
                "ts_code": ts_code,
                "start_date": start_date.replace("-", ""),
                "end_date": end_date.replace("-", ""),
            }

            if asset_type == 'china_index':
                df = pro.index_daily(**query)
            elif asset_type in ('china_etf', 'china_stock'):
                # Query the endpoint that matches the detected type first and
                # fall back to the other one; for stocks this avoids a wasted,
                # rate-limited fund_daily call.
                if asset_type == 'china_etf':
                    primary, fallback = pro.fund_daily, pro.daily
                else:
                    primary, fallback = pro.daily, pro.fund_daily
                df = primary(**query)
                if df is None or df.empty:
                    df = fallback(**query)
            else:
                return None

            if df is None or df.empty:
                return None

            # Normalise column names.
            df = df.rename(columns={
                "trade_date": "date",
                "vol": "volume",
            })

            # trade_date is 'YYYYMMDD'; cast to str so an integer column is
            # not misread as epoch nanoseconds.
            df["date"] = pd.to_datetime(df["date"].astype(str), format="%Y%m%d")
            df = df.set_index("date").sort_index()

            # Keep the OHLCV columns that are present; copy so the assignment
            # below does not write into a slice of the previous frame.
            cols = ['open', 'high', 'low', 'close', 'volume']
            df = df[[c for c in cols if c in df.columns]].copy()
            df['code'] = code

            return df

        except Exception as e:
            print(f"Tushare 下载 {code} 失败: {e}")
            return None

        finally:
            # Restore the proxy variables that were set before the call.
            for key, value in original_proxy.items():
                if value is not None:
                    os.environ[key] = value

    def _fetch_futures(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch futures data via HybridDataSource._fetch_futures."""
        return self._hybrid_source._fetch_futures(code, start_date, end_date)

    def _fetch_yfinance(
        self,
        code: str,
        start_date: str,
        end_date: str,
        asset_type: str
    ) -> Optional[pd.DataFrame]:
        """
        Fetch Hong Kong / US data from YFinance.

        For 'us_stock' and 'hk_stock' a best-effort dict with sector, industry
        and market_cap is stored in ``DataFrame.attrs['info']``. The returned
        index is made tz-naive so that it aligns with the Tushare frames.
        Errors are printed and reported as None.
        """
        # Map to the Yahoo ticker.
        yf_code = AssetTypeDetector.YF_CODE_MAP.get(code, code)

        # US indices need the "^" prefix on Yahoo.
        if asset_type == 'us_index' and not yf_code.startswith('^'):
            yf_code = f'^{yf_code}'

        # Small delay to reduce the chance of rate limiting.
        time.sleep(0.5)

        try:
            ticker = yf.Ticker(yf_code)

            # Company metadata (stocks only).
            info = {}
            if asset_type in ('us_stock', 'hk_stock'):
                try:
                    stock_info = ticker.info
                    info = {
                        'sector': stock_info.get('sector'),
                        'industry': stock_info.get('industry'),
                        'market_cap': stock_info.get('marketCap'),
                    }
                except Exception:
                    # Metadata is optional; the price history is still useful without it.
                    pass

            # yfinance treats `end` as exclusive, so add one day to include end_date.
            end_exclusive = pd.Timestamp(end_date) + timedelta(days=1)
            data = ticker.history(
                start=start_date,
                end=end_exclusive.strftime('%Y-%m-%d'),
                auto_adjust=False
            )

            if data is None or data.empty:
                return None

            # Normalise column names.
            data = data.rename(columns={
                "Open": "open",
                "High": "high",
                "Low": "low",
                "Close": "close",
                "Volume": "volume",
            })

            cols = ['open', 'high', 'low', 'close', 'volume']
            data = data[[c for c in cols if c in data.columns]].copy()

            # yfinance returns an exchange-local tz-aware index; drop the tz
            # (keeping local wall time) so frames from different sources can
            # be aligned/joined with the tz-naive Tushare output.
            if isinstance(data.index, pd.DatetimeIndex) and data.index.tz is not None:
                data.index = data.index.tz_localize(None)

            data['code'] = code

            if info:
                data.attrs['info'] = info

            return data

        except Exception as e:
            print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}")
            return None

    def _fetch_crypto(
        self,
        code: str,
        start_date: str,
        end_date: str
    ) -> Optional[pd.DataFrame]:
        """Fetch crypto data via HybridDataSource._fetch_ccxt."""
        return self._hybrid_source._fetch_ccxt(code, start_date, end_date)

    def __enter__(self):
        """Enter the context; opens the SSH tunnel when ssh_config['enabled'] is truthy."""
        if self.ssh_config.get("enabled"):
            self._hybrid_source.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the context; closes the SSH tunnel when ssh_config['enabled'] is truthy."""
        if self.ssh_config.get("enabled"):
            self._hybrid_source.__exit__(exc_type, exc_val, exc_tb)


# ============================================================
# Convenience functions
# ============================================================

def fetch_kline(
    code: str,
    start_date: str,
    end_date: str,
    ssh_config: Optional[dict] = None
) -> Optional[pd.DataFrame]:
    """
    Convenience wrapper: fetch daily K-line data for one instrument.

    Args:
        code: Ticker of any supported type.
        start_date: Start date, 'YYYY-MM-DD'.
        end_date: End date, 'YYYY-MM-DD'.
        ssh_config: Optional SSH tunnel configuration.

    Returns:
        DataFrame with OHLCV data, or None on failure.
    """
    fetcher = UniversalDataFetcher(ssh_config=ssh_config)
    with fetcher:
        return fetcher.fetch(code, start_date, end_date)


def detect_asset_type(code: str) -> str:
    """
    Convenience wrapper: return the asset type string for ``code``.

    See ``AssetTypeDetector.detect`` for the possible values.
    """
    return AssetTypeDetector.detect(code)