diff --git a/core/datasource/__init__.py b/core/datasource/__init__.py index e69de29..4e3686b 100644 --- a/core/datasource/__init__.py +++ b/core/datasource/__init__.py @@ -0,0 +1,30 @@ +""" +数据源模块 +========== +提供统一的数据获取接口,支持多种资产类型和数据源 + +主要组件: +- UniversalDataFetcher: 统一数据获取器(推荐) +- HybridDataSource: 混合数据源(轮动策略使用) +- YFinanceDataSource: YFinance数据源 +- AssetTypeDetector: 资产类型检测器 +""" + +from .universal_fetcher import ( + UniversalDataFetcher, + AssetTypeDetector, + fetch_kline, + detect_asset_type, +) + +from .hybrid_source import HybridDataSource +from .yfinance_source import YFinanceDataSource + +__all__ = [ + 'UniversalDataFetcher', + 'AssetTypeDetector', + 'fetch_kline', + 'detect_asset_type', + 'HybridDataSource', + 'YFinanceDataSource', +] \ No newline at end of file diff --git a/core/datasource/universal_fetcher.py b/core/datasource/universal_fetcher.py new file mode 100644 index 0000000..0da1fe7 --- /dev/null +++ b/core/datasource/universal_fetcher.py @@ -0,0 +1,453 @@ +""" +统一数据获取接口 +================ +自动识别资产类型并路由到对应的数据源,获取K线数据 + +支持的资产类型: +- A股指数 (代码格式: 000300.SH, 399006.SZ, 931862.CSI) +- A股ETF (代码格式: 510300.SH, 159915.SZ) +- A股股票 (代码格式: 600000.SH, 000001.SZ) +- 港股指数/股票 (代码格式: HSI, HSTECH.HK) +- 美股指数/股票 (代码格式: NDX, SPX, AAPL) +- 期货合约 (代码格式: AU.SHF, CU.SHF) +- 加密货币 (代码格式: BTC, ETH) + +用法: + from core.datasource.universal_fetcher import UniversalDataFetcher + + fetcher = UniversalDataFetcher() + + # 获取单只标的 + df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31") + + # 获取多只标的 + df = fetcher.fetch_multiple(["000300.SH", "NDX", "BTC"], "2024-01-01", "2024-12-31") +""" + +import os +import time +from pathlib import Path +from typing import Optional, Union, List, Dict +from datetime import datetime, timedelta +import pandas as pd +import yfinance as yf + +from .hybrid_source import HybridDataSource + + +class AssetTypeDetector: + """资产类型检测器""" + + # A股指数后缀 + CHINA_INDEX_SUFFIXES = ('.SH', '.SZ', '.SS', '.CSI') + + # 期货后缀 + FUTURES_SUFFIXES = ('.SHF', '.DCE', '.CZC', '.INE', '.GFEX') + + # 港股后缀 + HK_SUFFIXES = ('.HK',) + + # 加密货币代码集合 + CRYPTO_CODES = {'BTC', 'ETH', 'SOL', 'BNB', 'XRP', 'ADA', 'DOGE'} + + # 期货代码映射 (与 HybridDataSource 保持一致) + FUTURES_CODE_MAP = { + "AU.SHF": "AU.SHF", + "CU.SHF": "CU.SHF", + } + + # YFinance 代码映射 (与 HybridDataSource 保持一致) + YF_CODE_MAP = { + "HSTECH.HK": "3033.HK", + "HSI": "^HSI", + "NDX": "^NDX", + "SPX": "^GSPC", + "DJI": "^DJI", + "N225": "^N225", + "GDAXI": "^GDAXI", + "CL.NYM": "CL=F", + } + + @classmethod + def detect(cls, code: str) -> str: + """ + 检测资产类型 + + Returns: + 资产类型字符串: 'china_index', 'china_etf', 'china_stock', + 'hk_index', 'us_index', 'us_stock', + 'futures', 'crypto' + """ + # 加密货币优先判断 + if code.upper() in cls.CRYPTO_CODES: + return 'crypto' + + # 期货判断 + if any(code.endswith(suffix) for suffix in cls.FUTURES_SUFFIXES): + return 'futures' + + # 港股判断(在A股之前,因为HSI可能被误判) + if code.endswith(cls.HK_SUFFIXES): + return 'hk_index' if code in cls.YF_CODE_MAP else 'hk_stock' + + # 特殊处理:不在 YF_CODE_MAP 中的港股指数字符串(如 HSI) + if code in ('HSI', 'HSCEI', 'HSCCI'): + return 'hk_index' + + # A股判断 + if code.endswith(cls.CHINA_INDEX_SUFFIXES): + return cls._classify_china_asset(code) + + # 美股指数判断(在 YF_CODE_MAP 中) + if code in cls.YF_CODE_MAP and cls.YF_CODE_MAP[code].startswith('^'): + return 'us_index' + + # 默认:美股股票 + return 'us_stock' + + @classmethod + def _classify_china_asset(cls, code: str) -> str: + """ + 细分A股资产类型 + + 规则: + - .CSI 后缀:中证指数,直接判定为 china_index + - 指数: 6位数字,以0、1、3、9开头 (如 000300, 399006) + - ETF: 6位数字,以51、52、56、58、15、16开头 + - 股票: 其他 + """ + # .CSI 后缀直接判定为指数 + if code.endswith('.CSI'): + return 'china_index' + + # 提取代码主体 + code_body = code.split('.')[0] + + # 检查是否为6位数字 + if not code_body.isdigit() or len(code_body) != 6: + return 'china_stock' + + # 排除特殊情况:000001 是平安银行(股票),不是指数 + if code_body == '000001': + return 'china_stock' + + # ETF代码段判断 + etf_prefixes = ['51', '52', '56', '58', '15', '16'] + if any(code_body.startswith(prefix) for prefix in etf_prefixes): + return 'china_etf' + + # 指数代码段判断 + index_prefixes = ['000', '001', '002', '399', '930', '931', '932'] + if any(code_body.startswith(prefix) for prefix in index_prefixes): + return 'china_index' + + # 默认为股票 + return 'china_stock' + + +class UniversalDataFetcher: + """ + 统一数据获取器 + + 封装 Tushare、YFinance、CCXT 等数据源,自动识别资产类型并路由 + """ + + def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True): + """ + Args: + ssh_config: SSH隧道配置(用于访问YFinance等受限数据源) + use_cache: 是否使用缓存 + """ + self.ssh_config = ssh_config or {} + self.use_cache = use_cache + self._hybrid_source = HybridDataSource( + ssh_config=ssh_config, + use_cache=use_cache + ) + + def fetch( + self, + code: str, + start_date: str, + end_date: str, + retry: int = 3 + ) -> Optional[pd.DataFrame]: + """ + 获取单只标的的K线数据 + + Args: + code: 标的代码(支持所有类型) + start_date: 开始日期,格式 'YYYY-MM-DD' + end_date: 结束日期,格式 'YYYY-MM-DD' + retry: 重试次数 + + Returns: + DataFrame,包含 columns: [open, high, low, close, volume, code] + 索引为日期(DatetimeIndex) + 失败时返回 None + """ + for attempt in range(retry): + try: + asset_type = AssetTypeDetector.detect(code) + + if asset_type in ('china_index', 'china_etf', 'china_stock'): + return self._fetch_china(code, start_date, end_date, asset_type) + elif asset_type == 'futures': + return self._fetch_futures(code, start_date, end_date) + elif asset_type in ('hk_index', 'hk_stock', 'us_index', 'us_stock'): + return self._fetch_yfinance(code, start_date, end_date, asset_type) + elif asset_type == 'crypto': + return self._fetch_crypto(code, start_date, end_date) + else: + print(f"⚠️ 未知的资产类型: {code}") + return None + + except Exception as e: + if attempt < retry - 1: + time.sleep(2) + else: + print(f"✗ 获取 {code} 数据失败 (尝试 {attempt+1}/{retry}): {e}") + return None + + return None + + def fetch_multiple( + self, + codes: List[str], + start_date: str, + end_date: str, + retry: int = 3 + ) -> Dict[str, Optional[pd.DataFrame]]: + """ + 批量获取多只标的的K线数据 + + Args: + codes: 标的代码列表 + start_date: 开始日期 + end_date: 结束日期 + retry: 重试次数 + + Returns: + 字典 {code: DataFrame} + """ + results = {} + + # 按资产类型分组 + grouped = {} + for code in codes: + asset_type = AssetTypeDetector.detect(code) + if asset_type not in grouped: + grouped[asset_type] = [] + grouped[asset_type].append(code) + + print(f"开始获取 {len(codes)} 只标的的数据...") + print(f" 资产类型分布:") + for asset_type, code_list in grouped.items(): + print(f" - {asset_type}: {len(code_list)} 只") + + # 逐组获取 + for asset_type, code_list in grouped.items(): + for code in code_list: + df = self.fetch(code, start_date, end_date, retry) + if df is not None and len(df) > 0: + results[code] = df + print(f" ✓ {code}: {len(df)} 条") + else: + print(f" ✗ {code}: 无数据") + results[code] = None + + return results + + def _fetch_china( + self, + code: str, + start_date: str, + end_date: str, + asset_type: str + ) -> Optional[pd.DataFrame]: + """获取A股数据(指数/ETF/股票)""" + import tushare as ts + + # 临时清除代理环境变量 + original_proxy = {} + for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]: + original_proxy[key] = os.environ.pop(key, None) + + try: + token = os.getenv("TUSHARE_TOKEN") + if not token: + raise ValueError("请设置环境变量 TUSHARE_TOKEN") + + pro = ts.pro_api(token) + + # 转换代码格式 + ts_code = code.replace(".SS", ".SH") + + # 根据资产类型选择接口 + if asset_type == 'china_index': + df = pro.index_daily( + ts_code=ts_code, + start_date=start_date.replace("-", ""), + end_date=end_date.replace("-", "") + ) + elif asset_type in ('china_etf', 'china_stock'): + df = pro.fund_daily(ts_code=ts_code, + start_date=start_date.replace("-", ""), + end_date=end_date.replace("-", "")) + # 如果 fund_daily 无数据,尝试 stock 接口 + if df is None or df.empty: + df = pro.daily(ts_code=ts_code, + start_date=start_date.replace("-", ""), + end_date=end_date.replace("-", "")) + else: + return None + + if df is None or df.empty: + return None + + # 标准化列名 + df = df.rename(columns={ + "trade_date": "date", + "vol": "volume", + }) + + # 转换日期格式 + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index("date") + df = df.sort_index() + + # 选择需要的列 + cols = ['open', 'high', 'low', 'close', 'volume'] + available = [c for c in cols if c in df.columns] + df = df[available] + df['code'] = code + + return df + + except Exception as e: + print(f"Tushare 下载 {code} 失败: {e}") + return None + finally: + # 恢复代理环境变量 + for key, value in original_proxy.items(): + if value is not None: + os.environ[key] = value + + def _fetch_futures( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """获取期货数据""" + return self._hybrid_source._fetch_futures(code, start_date, end_date) + + def _fetch_yfinance( + self, + code: str, + start_date: str, + end_date: str, + asset_type: str + ) -> Optional[pd.DataFrame]: + """获取港股/美股数据""" + # 转换代码格式 + yf_code = AssetTypeDetector.YF_CODE_MAP.get(code, code) + + # 美股指数需要加 ^ 前缀 + if asset_type == 'us_index' and not yf_code.startswith('^'): + yf_code = f'^{yf_code}' + + # 添加延迟避免限流 + time.sleep(0.5) + + try: + ticker = yf.Ticker(yf_code) + # end_date 需要加一天(yfinance 的 end 是排他的) + end_date_obj = pd.Timestamp(end_date) + timedelta(days=1) + data = ticker.history( + start=start_date, + end=end_date_obj.strftime('%Y-%m-%d'), + auto_adjust=False + ) + + if len(data) == 0: + return None + + # 标准化列名 + data = data.rename(columns={ + "Open": "open", + "High": "high", + "Low": "low", + "Close": "close", + "Volume": "volume", + }) + + # 选择需要的列 + cols = ['open', 'high', 'low', 'close', 'volume'] + available = [c for c in cols if c in data.columns] + data = data[available] + data['code'] = code + + return data + + except Exception as e: + print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}") + return None + + def _fetch_crypto( + self, + code: str, + start_date: str, + end_date: str + ) -> Optional[pd.DataFrame]: + """获取加密货币数据""" + # 直接使用 HybridDataSource 的 CCXT 接口 + return self._hybrid_source._fetch_ccxt(code, start_date, end_date) + + def __enter__(self): + """上下文管理器入口(启动SSH隧道)""" + if self.ssh_config.get("enabled"): + self._hybrid_source.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器出口(关闭SSH隧道)""" + if self.ssh_config.get("enabled"): + self._hybrid_source.__exit__(exc_type, exc_val, exc_tb) + + +# ============================================================ +# 便捷函数 +# ============================================================ + +def fetch_kline( + code: str, + start_date: str, + end_date: str, + ssh_config: Optional[dict] = None +) -> Optional[pd.DataFrame]: + """ + 便捷函数:获取单只标的的K线数据 + + Args: + code: 标的代码 + start_date: 开始日期 'YYYY-MM-DD' + end_date: 结束日期 'YYYY-MM-DD' + ssh_config: SSH隧道配置(可选) + + Returns: + DataFrame with OHLCV data + """ + fetcher = UniversalDataFetcher(ssh_config=ssh_config) + with fetcher: + return fetcher.fetch(code, start_date, end_date) + + +def detect_asset_type(code: str) -> str: + """ + 便捷函数:检测资产类型 + + Returns: + 资产类型字符串 + """ + return AssetTypeDetector.detect(code)