""" 混合数据源 整合 Tushare(A股) + YFinance(境外)数据获取 """ import os import time from typing import Optional, Tuple, Dict, List from datetime import datetime from pathlib import Path import pandas as pd from .ssh_tunnel import SSHTunnelManager from .tushare_source import TushareSource from .yfinance_source import YFinanceSource class HybridDataSource: """ 混合数据源 - A股指数/ETF/期货: Tushare - 港股/美股/商品: YFinance(通过SSH隧道) 使用方式: from datasource import HybridDataSource source = HybridDataSource.from_yaml('strategies/rotation/config.yaml') result = source.fetch_all() """ def __init__( self, ssh_config: Optional[dict] = None, use_cache: bool = True, cache_dir: str = "data/etf_cache/daily" ): """ 初始化混合数据源 Args: ssh_config: SSH隧道配置 use_cache: 是否使用缓存 cache_dir: 缓存目录 """ self.ssh_config = ssh_config or {} self.use_cache = use_cache self.cache_dir = cache_dir # 数据源实例 self._tushare = TushareSource() self._yfinance = YFinanceSource() # SSH隧道(延迟初始化) self._tunnel: Optional[SSHTunnelManager] = None @classmethod def from_yaml(cls, config_path: str) -> 'HybridDataSource': """从YAML配置创建实例""" import yaml with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) return cls( ssh_config=config.get('ssh_tunnel', {}), use_cache=config.get('use_cache', True) ) def _start_tunnel(self) -> bool: """启动SSH隧道""" if self._tunnel is None and self.ssh_config.get('enabled'): self._tunnel = SSHTunnelManager(self.ssh_config) return self._tunnel.start() return True def _stop_tunnel(self): """停止SSH隧道""" if self._tunnel: self._tunnel.stop() self._tunnel = None def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: """ 获取单个标的数据 Args: code: 标的代码 start_date: 开始日期 end_date: 结束日期 Returns: DataFrame with OHLCV data """ # 判断数据源 if self._tushare.is_china_index(code) or self._tushare.is_futures(code): return self._tushare.fetch(code, start_date, end_date) else: # YFinance需要SSH隧道 self._start_tunnel() return self._yfinance.fetch(code, start_date, end_date) def fetch_all( self, code_config: dict, benchmark_code: str = "000300.SH", start_date: str = "2019-01-01", end_date: str = None ) -> Tuple[ Optional[pd.DataFrame], # index_data: 指数收盘价(宽格式) Optional[pd.DataFrame], # etf_data: ETF价格(宽格式) Optional[pd.DataFrame], # etf_nav_data: ETF净值 Optional[pd.DataFrame], # benchmark_data: 基准数据 List[str], # valid_codes: 有效代码列表 Dict[str, pd.DataFrame], # index_ohlcv_data: 原始OHLCV数据 Dict[str, str] # etf_code_map: {指数代码: ETF代码} 映射 ]: """ 批量获取数据 Args: code_config: 标的配置 {代码: {name, etf, market}} benchmark_code: 基准代码 start_date: 开始日期 end_date: 结束日期 Returns: (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map) """ if end_date is None: end_date = datetime.now().strftime('%Y-%m-%d') # 启动SSH隧道 self._start_tunnel() index_codes = list(code_config.keys()) etf_codes = {idx_code: cfg['etf'] for idx_code, cfg in code_config.items() if cfg.get('etf')} print(f"开始下载 {len(index_codes)} 只标的的数据...") print(f" 指数代码: {len(index_codes)} 只") print(f" ETF映射: {len(etf_codes)} 只") # 分类统计 china_codes = [c for c in index_codes if self._tushare.is_china_index(c)] futures_codes = [c for c in index_codes if self._tushare.is_futures(c)] yf_codes = [c for c in index_codes if not self._tushare.is_china_index(c) and not self._tushare.is_futures(c)] print(f" 中国A股指数: {len(china_codes)} 只") print(f" 期货合约: {len(futures_codes)} 只") print(f" 港股/美股: {len(yf_codes)} 只") # 下载指数数据 print("\n [1/2] 下载指数数据...") index_data_list = [] index_ohlcv_data = {} valid_codes = [] for code in index_codes: name = code_config[code].get('name', code) source = "Tushare" if self._tushare.is_china_index(code) or self._tushare.is_futures(code) else "YFinance" print(f" 下载 {code} ({name}) - {source}...", end=" ") data = self.fetch_single(code, start_date, end_date) if data is not None and len(data) > 0: # 标准化 data = data.copy() data['source'] = source data['code'] = code data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize() index_ohlcv_data[code] = data.copy() index_data_list.append(data[['code', 'close', 'source']]) valid_codes.append(code) print(f"✓ {len(data)} 条") else: print("✗ 无数据") # 下载ETF数据 etf_data_list = [] etf_nav_data_list = [] if etf_codes: print("\n [2/2] 下载ETF数据...") for idx_code, etf_code in etf_codes.items(): name = code_config[idx_code].get('name', idx_code) print(f" 下载ETF {etf_code} (对应指数 {idx_code})...", end=" ") # ETF价格 etf_data = self._tushare.fetch_etf(etf_code, start_date, end_date) # ETF净值 etf_nav = self._tushare.fetch_etf_nav(etf_code, start_date, end_date) if etf_data is not None and len(etf_data) > 0: etf_data.index = pd.to_datetime(etf_data.index, utc=True).tz_localize(None).normalize() etf_data_list.append(etf_data[['code', 'close']]) price_count = len(etf_data) nav_count = len(etf_nav) if etf_nav is not None else 0 print(f"✓ 价格{price_count}条 净值{nav_count}条") else: print("✗ 无数据") if etf_nav is not None and len(etf_nav) > 0: etf_nav.index = pd.to_datetime(etf_nav.index, utc=True).tz_localize(None).normalize() etf_nav_data_list.append(etf_nav[['code', 'nav']]) # 整合数据 index_data = None if index_data_list: index_data = pd.concat(index_data_list) if 'code' in index_data.columns and 'close' in index_data.columns: index_data = index_data.reset_index() if 'index' in index_data.columns: index_data = index_data.rename(columns={'index': 'date'}) index_data['date'] = pd.to_datetime(index_data['date']).dt.normalize() index_data = index_data.pivot_table(index='date', columns='code', values='close') etf_data = None if etf_data_list: etf_data = pd.concat(etf_data_list) if 'code' in etf_data.columns and 'close' in etf_data.columns: etf_data = etf_data.reset_index() if 'index' in etf_data.columns: etf_data = etf_data.rename(columns={'index': 'date'}) etf_data['date'] = pd.to_datetime(etf_data['date']).dt.normalize() etf_data = etf_data.pivot_table(index='date', columns='code', values='close') etf_nav_data = None if etf_nav_data_list: etf_nav_data = pd.concat(etf_nav_data_list) if 'code' in etf_nav_data.columns and 'nav' in etf_nav_data.columns: etf_nav_data = etf_nav_data.reset_index() if 'index' in etf_nav_data.columns: etf_nav_data = etf_nav_data.rename(columns={'index': 'date'}) etf_nav_data['date'] = pd.to_datetime(etf_nav_data['date']).dt.normalize() etf_nav_data = etf_nav_data.pivot_table(index='date', columns='code', values='nav') # 基准数据 benchmark_data = self._tushare.fetch_index(benchmark_code, start_date, end_date) if benchmark_data is not None: benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize() print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条") return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_codes def __enter__(self): self._start_tunnel() return self def __exit__(self, exc_type, exc_val, exc_tb): self._stop_tunnel() # 简化接口 def fetch_rotation_data(config_path: str = "strategies/rotation/config.yaml") -> dict: """ 获取轮动策略数据(简化接口) Args: config_path: 配置文件路径 Returns: { 'index_data': 指数收盘价DataFrame, 'etf_data': ETF价格DataFrame, 'etf_nav_data': ETF净值DataFrame, 'benchmark_data': 基准DataFrame, 'valid_codes': 有效代码列表, 'index_ohlcv_data': 原始OHLCV数据字典 } """ import yaml with open(config_path, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) source = HybridDataSource.from_yaml(config_path) index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \ source.fetch_all( code_config=config.get('code_list', {}), benchmark_code=config.get('benchmark', {}).get('code', '000300.SH'), start_date=config.get('start_date', '2019-01-01'), end_date=config.get('end_date', datetime.now().strftime('%Y-%m-%d')) ) return { 'index_data': index_data, 'etf_data': etf_data, 'etf_nav_data': etf_nav_data, 'benchmark_data': benchmark_data, 'valid_codes': valid_codes, 'index_ohlcv_data': index_ohlcv_data }