diff --git a/datasource/__init__.py b/datasource/__init__.py
new file mode 100644
index 0000000..46cc232
--- /dev/null
+++ b/datasource/__init__.py
@@ -0,0 +1,19 @@
+"""
+Data source package.
+
+Core data-fetching capabilities:
+- A-share data: Tushare (indices, ETFs, futures)
+- Overseas data: YFinance (Hong Kong / US markets) through an SSH tunnel
+"""
+
+from .ssh_tunnel import SSHTunnelManager
+from .tushare_source import TushareSource
+from .yfinance_source import YFinanceSource
+from .hybrid_source import HybridDataSource
+
+__all__ = [
+    'SSHTunnelManager',
+    'TushareSource',
+    'YFinanceSource',
+    'HybridDataSource',
+]
\ No newline at end of file
diff --git a/datasource/hybrid_source.py b/datasource/hybrid_source.py
new file mode 100644
index 0000000..0be74d0
--- /dev/null
+++ b/datasource/hybrid_source.py
@@ -0,0 +1,292 @@
+"""
+Hybrid data source.
+
+Combines Tushare (A-shares) and YFinance (overseas markets) data fetching.
+"""
+
+import os
+import time
+from typing import Optional, Tuple, Dict, List
+from datetime import datetime
+from pathlib import Path
+import pandas as pd
+
+from .ssh_tunnel import SSHTunnelManager
+from .tushare_source import TushareSource
+from .yfinance_source import YFinanceSource
+
+
+class HybridDataSource:
+    """
+    Hybrid data source.
+
+    - A-share indices / ETFs / futures: Tushare
+    - Hong Kong / US markets / commodities: YFinance (through the SSH tunnel)
+
+    Usage:
+        from datasource import HybridDataSource
+
+        source = HybridDataSource.from_yaml('config/strategies/rotation.yaml')
+        result = source.fetch_all(code_config)
+    """
+
+    def __init__(
+        self,
+        ssh_config: Optional[dict] = None,
+        use_cache: bool = True,
+        cache_dir: str = "data/etf_cache/daily"
+    ):
+        """
+        Initialise the hybrid data source.
+
+        Args:
+            ssh_config: SSH tunnel configuration (passed to SSHTunnelManager)
+            use_cache: Whether to use the cache (only stored by this class)
+            cache_dir: Cache directory (only stored by this class)
+        """
+        self.ssh_config = ssh_config or {}
+        self.use_cache = use_cache
+        self.cache_dir = cache_dir
+
+        # Data source instances
+        self._tushare = TushareSource()
+        self._yfinance = YFinanceSource()
+
+        # SSH tunnel (created lazily)
+        self._tunnel: Optional[SSHTunnelManager] = None
+
+    @classmethod
+    def from_yaml(cls, config_path: str) -> 'HybridDataSource':
+        """Create an instance from a YAML config file."""
+        import yaml
+
+        with open(config_path, 'r', encoding='utf-8') as f:
+            config = yaml.safe_load(f) or {}  # an empty file loads as None
+
+        return cls(
+            ssh_config=config.get('ssh_tunnel', {}),
+            use_cache=config.get('use_cache', True)
+        )
+
+    def _start_tunnel(self) -> bool:
+        """
+        Start the SSH tunnel on first use when the config enables it.
+
+        Returns:
+            The result of SSHTunnelManager.start() on the call that creates the
+            tunnel; True on every other call.
+        """
+        if self._tunnel is None and self.ssh_config.get('enabled'):
+            self._tunnel = SSHTunnelManager(self.ssh_config)
+            return self._tunnel.start()
+        return True
+
+    def _stop_tunnel(self):
+        """Stop the SSH tunnel if one was created."""
+        if self._tunnel:
+            self._tunnel.stop()
+            self._tunnel = None
+
+    def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
+        """
+        Fetch data for a single instrument.
+
+        Args:
+            code: Instrument code
+            start_date: Start date
+            end_date: End date
+
+        Returns:
+            DataFrame with OHLCV data
+        """
+        # Pick the data source
+        if self._tushare.is_china_index(code) or self._tushare.is_futures(code):
+            return self._tushare.fetch(code, start_date, end_date)
+        else:
+            # YFinance needs the SSH tunnel
+            self._start_tunnel()
+            return self._yfinance.fetch(code, start_date, end_date)
+
+    def fetch_all(
+        self,
+        code_config: dict,
+        benchmark_code: str = "000300.SH",
+        start_date: str = "2019-01-01",
+        end_date: str = None
+    ) -> Tuple[
+        Optional[pd.DataFrame],  # index_data: index close prices (wide format)
+        Optional[pd.DataFrame],  # etf_data: ETF prices (wide format)
+        Optional[pd.DataFrame],  # etf_nav_data: ETF net asset values
+        Optional[pd.DataFrame],  # benchmark_data: benchmark data
+        List[str],  # valid_codes: codes that returned data
+        Dict[str, pd.DataFrame]  # index_ohlcv_data: raw OHLCV data
+    ]:
+        """
+        Fetch data in bulk.
+
+        Args:
+            code_config: Instrument config {code: {name, etf, market}}
+            benchmark_code: Benchmark code
+            start_date: Start date
+            end_date: End date (today when None)
+
+        Returns:
+            (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data)
+        """
+        if end_date is None:
+            end_date = datetime.now().strftime('%Y-%m-%d')
+
+        # Start the SSH tunnel
+        self._start_tunnel()
+
+        index_codes = list(code_config.keys())
+        etf_codes = {idx_code: cfg['etf'] for idx_code, cfg in code_config.items() if cfg.get('etf')}
+
+        print(f"开始下载 {len(index_codes)} 只标的的数据...")
+        print(f" 指数代码: {len(index_codes)} 只")
+        print(f" ETF映射: {len(etf_codes)} 只")
+
+        # Per-category statistics
+        china_codes = [c for c in index_codes if self._tushare.is_china_index(c)]
+        futures_codes = [c for c in index_codes if self._tushare.is_futures(c)]
+        yf_codes = [c for c in index_codes if not self._tushare.is_china_index(c) and not self._tushare.is_futures(c)]
+
+        print(f" 中国A股指数: {len(china_codes)} 只")
+        print(f" 期货合约: {len(futures_codes)} 只")
+        print(f" 港股/美股: {len(yf_codes)} 只")
+
+        # Download index data
+        print("\n [1/2] 下载指数数据...")
+        index_data_list = []
+        index_ohlcv_data = {}
+        valid_codes = []
+
+        for code in index_codes:
+            name = code_config[code].get('name', code)
+            source = "Tushare" if self._tushare.is_china_index(code) or self._tushare.is_futures(code) else "YFinance"
+
+            print(f" 下载 {code} ({name}) - {source}...", end=" ")
+
+            data = self.fetch_single(code, start_date, end_date)
+
+            if data is not None and len(data) > 0:
+                # Normalise
+                data = data.copy()
+                data['source'] = source
+                data['code'] = code
+                data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
+
+                index_ohlcv_data[code] = data.copy()
+                index_data_list.append(data[['code', 'close', 'source']])
+                valid_codes.append(code)
+                print(f"✓ {len(data)} 条")
+            else:
+                print("✗ 无数据")
+
+        # Download ETF data
+        etf_data_list = []
+        etf_nav_data_list = []
+
+        if etf_codes:
+            print("\n [2/2] 下载ETF数据...")
+
+            for idx_code, etf_code in etf_codes.items():
+                print(f" 下载ETF {etf_code} (对应指数 {idx_code})...", end=" ")
+
+                # ETF prices
+                etf_data = self._tushare.fetch_etf(etf_code, start_date, end_date)
+
+                # ETF net asset values
+                etf_nav = self._tushare.fetch_etf_nav(etf_code, start_date, end_date)
+
+                if etf_data is not None and len(etf_data) > 0:
+                    etf_data.index = pd.to_datetime(etf_data.index, utc=True).tz_localize(None).normalize()
+                    etf_data_list.append(etf_data[['code', 'close']])
+
+                    price_count = len(etf_data)
+                    # A DataFrame has no truth value, so compare against None.
+                    nav_count = len(etf_nav) if etf_nav is not None else 0
+
+                    print(f"✓ 价格{price_count}条 净值{nav_count}条")
+                else:
+                    print("✗ 无数据")
+
+                if etf_nav is not None and len(etf_nav) > 0:
+                    etf_nav.index = pd.to_datetime(etf_nav.index, utc=True).tz_localize(None).normalize()
+                    etf_nav_data_list.append(etf_nav[['code', 'nav']])
+
+        # Combine into wide frames
+        index_data = None
+        if index_data_list:
+            index_data = pd.concat(index_data_list)
+            index_data = index_data.pivot(columns='code', values='close')
+
+        etf_data = None
+        if etf_data_list:
+            etf_data = pd.concat(etf_data_list)
+            etf_data = etf_data.pivot(columns='code', values='close')
+
+        etf_nav_data = None
+        if etf_nav_data_list:
+            etf_nav_data = pd.concat(etf_nav_data_list)
+            etf_nav_data = etf_nav_data.pivot(columns='code', values='nav')
+
+        # Benchmark data
+        benchmark_data = self._tushare.fetch_index(benchmark_code, start_date, end_date)
+        if benchmark_data is not None:
+            benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
+            print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
+
+        return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data
+
+    def __enter__(self):
+        """Start the tunnel (when enabled) and return this instance."""
+        self._start_tunnel()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Stop the tunnel when the with-block exits."""
+        self._stop_tunnel()
+
+
+# Simplified interface
+def fetch_rotation_data(config_path: str = "config/strategies/rotation.yaml") -> dict:
+    """
+    Fetch the data for the rotation strategy (simplified interface).
+
+    Args:
+        config_path: Path to the config file
+
+    Returns:
+        {
+            'index_data': DataFrame of index close prices,
+            'etf_data': DataFrame of ETF prices,
+            'etf_nav_data': DataFrame of ETF net asset values,
+            'benchmark_data': benchmark DataFrame,
+            'valid_codes': list of codes that returned data,
+            'index_ohlcv_data': dict of raw OHLCV data
+        }
+    """
+    import yaml
+
+    with open(config_path, 'r', encoding='utf-8') as f:
+        config = yaml.safe_load(f) or {}  # an empty file loads as None
+
+    # Context manager: the SSH tunnel is shut down once the download is done.
+    with HybridDataSource.from_yaml(config_path) as source:
+        index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \
+            source.fetch_all(
+                code_config=config.get('code_list', {}),
+                benchmark_code=config.get('benchmark', {}).get('code', '000300.SH'),
+                start_date=config.get('start_date', '2019-01-01'),
+                end_date=config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
+            )
+
+    return {
+        'index_data': index_data,
+        'etf_data': etf_data,
+        'etf_nav_data': etf_nav_data,
+        'benchmark_data': benchmark_data,
+        'valid_codes': valid_codes,
+        'index_ohlcv_data': index_ohlcv_data
+    }
\ No newline at end of file
diff --git a/datasource/ssh_tunnel.py b/datasource/ssh_tunnel.py
new file mode 100644
index 0000000..fb54a61
--- /dev/null
+++ b/datasource/ssh_tunnel.py
@@ -0,0 +1,129 @@
+"""
+SSH tunnel manager.
+
+Opens a local SOCKS5 proxy through an SSH tunnel for reaching overseas data sources.
+"""
+
+import os
+import sys
+import time
+import subprocess
+from pathlib import Path
+from typing import Optional
+
+
+class SSHTunnelManager:
+    """SSH tunnel manager."""
+
+    def __init__(self, config: dict):
+        """
+        Initialise the SSH tunnel.
+
+        Args:
+            config: SSH config dict
+                - enabled: whether the tunnel is used at all (default False)
+                - host: SSH server address
+                - port: SSH port (default 22)
+                - username: SSH user name (default "root")
+                - key_path: SSH private key path (relative or absolute)
+                - local_port: local SOCKS5 port (default 1080)
+        """
+        self.enabled = config.get("enabled", False)
+        self.host = config.get("host", "")
+        self.port = config.get("port", 22)
+        self.username = config.get("username", "root")
+        self.local_port = config.get("local_port", 1080)
+        self._process: Optional[subprocess.Popen] = None
+
+        # key_path handling: turn a relative path into an absolute one
+        key_path = config.get("key_path", "")
+        if key_path and not os.path.isabs(key_path):
+            # Relative to the project root
+            project_root = Path(__file__).parent.parent
+            key_path = str(project_root / key_path)
+        self.key_path = key_path
+
+    def start(self) -> bool:
+        """
+        Start the SSH tunnel and export the SOCKS5 proxy environment variables.
+
+        Returns:
+            True if the tunnel is disabled or was established; False if the
+            config is incomplete or the ssh process could not be started.
+        """
+        if not self.enabled:
+            return True
+
+        if not all([self.host, self.username, self.key_path]):
+            print("SSH配置不完整,跳过隧道建立")
+            return False
+
+        print(f"建立SSH隧道: {self.host}:{self.port} -> 本地SOCKS5端口 {self.local_port}")
+
+        cmd = [
+            "ssh", "-N", "-D", f"127.0.0.1:{self.local_port}",
+            # NOTE(security): the next two options turn off host-key checking,
+            # so the server's identity is not verified (MITM risk) - review.
+            "-o", "StrictHostKeyChecking=no",
+            "-o", "UserKnownHostsFile=/dev/null",
+            "-i", self.key_path,
+            "-p", str(self.port),
+            f"{self.username}@{self.host}"
+        ]
+
+        try:
+            self._process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            time.sleep(2)
+
+            if self._process.poll() is not None:
+                _, stderr = self._process.communicate()
+                self._process = None  # do not keep a handle to the dead process
+                print("✗ SSH隧道建立失败")
+                if stderr:
+                    print(f"错误: {stderr.decode()}")
+                return False
+
+            # Export the proxy environment variables.
+            # socks5h:// makes the proxy resolve DNS remotely, avoiding IPv6 problems.
+            proxy_url = f"socks5h://127.0.0.1:{self.local_port}"
+            os.environ["HTTP_PROXY"] = proxy_url
+            os.environ["HTTPS_PROXY"] = proxy_url
+            os.environ["ALL_PROXY"] = proxy_url
+
+            print(f"✓ SSH隧道已建立: {proxy_url}")
+            time.sleep(1)
+            return True
+
+        except Exception as e:
+            print(f"✗ SSH隧道异常: {e}")
+            return False
+
+    def stop(self):
+        """Terminate the ssh process and remove the proxy environment variables."""
+        if self._process:
+            self._process.terminate()
+            self._process.wait()
+            self._process = None  # makes a repeated stop() a no-op
+            for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
+                os.environ.pop(key, None)
+            print("SSH隧道已关闭")
+
+    def __enter__(self):
+        """Start the tunnel and return this manager."""
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Stop the tunnel when the with-block exits."""
+        self.stop()
+
+
+def create_ssh_tunnel_from_yaml(config_path: str) -> SSHTunnelManager:
+    """Create an SSH tunnel manager from the 'ssh_tunnel' section of a YAML config."""
+    import yaml
+
+    with open(config_path, 'r', encoding='utf-8') as f:
+        config = yaml.safe_load(f) or {}  # an empty file loads as None
+
+    ssh_config = config.get('ssh_tunnel', {})
+    return SSHTunnelManager(ssh_config)
\ No newline at end of file
diff --git a/datasource/tushare_source.py b/datasource/tushare_source.py
new file mode 100644
index 0000000..71551c4
--- /dev/null
+++ b/datasource/tushare_source.py
@@ -0,0 +1,271 @@
+"""
+Tushare data source.
+
+Fetches A-share index, ETF and futures data.
+"""
+
+import os
+from typing import Optional
+from datetime import datetime
+import pandas as pd
+
+
+class TushareSource:
+    """Tushare data source."""
+
+    def __init__(self, token: Optional[str] = None):
+        """
+        Initialise the Tushare data source.
+
+        Args:
+            token: Tushare token (optional; read from TUSHARE_TOKEN by default)
+
+        Raises:
+            ValueError: no token was passed and TUSHARE_TOKEN is not set
+        """
+        self._token = token or os.getenv("TUSHARE_TOKEN")
+        if not self._token:
+            raise ValueError("请设置环境变量 TUSHARE_TOKEN")
+
+    def _get_pro_api(self):
+        """Return a Tushare Pro API client."""
+        import tushare as ts
+        return ts.pro_api(self._token)
+
+    def _clear_proxy(self) -> dict:
+        """Remove the proxy env vars (Tushare is a domestic service, no proxy needed)."""
+        original = {}
+        for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
+            original[key] = os.environ.pop(key, None)
+        return original
+
+    def _restore_proxy(self, original: dict):
+        """Restore the proxy env vars saved by _clear_proxy."""
+        for key, value in original.items():
+            if value is not None:
+                os.environ[key] = value
+
+    def fetch_index(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
+        """
+        Fetch A-share index data.
+
+        Args:
+            code: Index code, e.g. '000300.SH', '399006.SZ', 'H30269.CSI'
+            start_date: Start date 'YYYY-MM-DD'
+            end_date: End date 'YYYY-MM-DD'
+
+        Returns:
+            DataFrame indexed by date (code, open, high, low, close, volume), or None
+        """
+        original_proxy = self._clear_proxy()
+
+        try:
+            pro = self._get_pro_api()
+
+            # Convert the code format (.SS -> .SH)
+            ts_code = code.replace(".SS", ".SH")
+
+            df = pro.index_daily(
+                ts_code=ts_code,
+                start_date=start_date.replace("-", ""),
+                end_date=end_date.replace("-", "")
+            )
+
+            if df is None or len(df) == 0:
+                return None
+
+            # Normalise column names
+            df = df.rename(columns={
+                "trade_date": "date",
+                "vol": "volume",
+            })
+
+            # Convert the date format
+            df["date"] = pd.to_datetime(df["date"])
+            df = df.set_index("date")
+            df = df.sort_index()
+            df["code"] = code
+
+            return df[['code', 'open', 'high', 'low', 'close', 'volume']]
+
+        except Exception as e:
+            print(f"Tushare下载指数 {code} 失败: {e}")
+            return None
+
+        finally:
+            self._restore_proxy(original_proxy)
+
+    def fetch_futures(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
+        """
+        Fetch futures data.
+
+        Args:
+            code: Futures code, e.g. 'AU.SHF', 'CU.SHF'
+            start_date: Start date 'YYYY-MM-DD'
+            end_date: End date 'YYYY-MM-DD'
+
+        Returns:
+            DataFrame indexed by date with columns code, open, high, low,
+            close, volume; None when no rows come back or the request fails.
+        """
+        original_proxy = self._clear_proxy()
+
+        try:
+            pro = self._get_pro_api()
+
+            # Tushare Pro publishes futures daily bars under the name fut_daily.
+            df = pro.fut_daily(
+                ts_code=code,
+                start_date=start_date.replace("-", ""),
+                end_date=end_date.replace("-", ""),
+                exchange=''
+            )
+
+            if df is None or len(df) == 0:
+                return None
+
+            # Normalise column names
+            df = df.rename(columns={
+                "trade_date": "date",
+                "vol": "volume",
+            })
+
+            df["date"] = pd.to_datetime(df["date"])
+            df = df.set_index("date")
+            df = df.sort_index()
+            df["code"] = code
+
+            return df[['code', 'open', 'high', 'low', 'close', 'volume']]
+
+        except Exception as e:
+            print(f"Tushare下载期货 {code} 失败: {e}")
+            return None
+
+        finally:
+            self._restore_proxy(original_proxy)
+
+    def fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
+        """
+        Fetch ETF price data.
+
+        Args:
+            code: ETF code, e.g. '159915.SZ', '518880.SH'
+            start_date: Start date 'YYYY-MM-DD'
+            end_date: End date 'YYYY-MM-DD'
+
+        Returns:
+            DataFrame indexed by date with columns code, open, high, low,
+            close, volume; None when no rows come back or the request fails.
+        """
+        original_proxy = self._clear_proxy()
+
+        try:
+            pro = self._get_pro_api()
+
+            ts_code = code.replace(".SS", ".SH")
+
+            df = pro.fund_daily(
+                ts_code=ts_code,
+                start_date=start_date.replace("-", ""),
+                end_date=end_date.replace("-", "")
+            )
+
+            if df is None or len(df) == 0:
+                return None
+
+            df = df.rename(columns={
+                "trade_date": "date",
+                "vol": "volume",
+            })
+
+            df["date"] = pd.to_datetime(df["date"])
+            df = df.set_index("date")
+            df = df.sort_index()
+            df["code"] = code
+
+            return df[['code', 'open', 'high', 'low', 'close', 'volume']]
+
+        except Exception as e:
+            print(f"Tushare下载ETF {code} 失败: {e}")
+            return None
+
+        finally:
+            self._restore_proxy(original_proxy)
+
+    def fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
+        """
+        Fetch ETF net-asset-value data.
+
+        Args:
+            code: ETF code
+            start_date: Start date 'YYYY-MM-DD'
+            end_date: End date 'YYYY-MM-DD'
+
+        Returns:
+            DataFrame indexed by date with columns code, nav; None when no rows
+            come back or the request fails.
+        """
+        original_proxy = self._clear_proxy()
+
+        try:
+            pro = self._get_pro_api()
+
+            ts_code = code.replace(".SS", ".SH")
+
+            df = pro.fund_nav(
+                ts_code=ts_code,
+                start_date=start_date.replace("-", ""),
+                end_date=end_date.replace("-", "")
+            )
+
+            if df is None or len(df) == 0:
+                return None
+
+            df = df.rename(columns={
+                "nav_date": "date",
+                "unit_nav": "nav",
+            })
+
+            df["date"] = pd.to_datetime(df["date"])
+            df = df.set_index("date")
+            df = df.sort_index()
+            df["code"] = code
+
+            return df[['code', 'nav']]
+
+        except Exception as e:
+            print(f"Tushare下载ETF净值 {code} 失败: {e}")
+            return None
+
+        finally:
+            self._restore_proxy(original_proxy)
+
+    def is_china_index(self, code: str) -> bool:
+        """Return True when the code carries an A-share index suffix."""
+        return code.endswith((".SH", ".SZ", ".SS", ".CSI"))
+
+    def is_futures(self, code: str) -> bool:
+        """Return True for futures codes that are fetched from Tushare."""
+        # ".NYM" is deliberately not listed: YFinanceSource.CODE_MAP maps
+        # 'CL.NYM' to 'CL=F', so NYMEX contracts must fall through to YFinance.
+        return ".SHF" in code or ".DCE" in code or ".CZC" in code
+
+    def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
+        """
+        Generic fetch (picks the endpoint from the code).
+
+        Args:
+            code: Instrument code
+            start_date: Start date
+            end_date: End date
+
+        Returns:
+            The result of fetch_index or fetch_futures; None when the code
+            matches neither kind.
+        """
+        if self.is_china_index(code):
+            return self.fetch_index(code, start_date, end_date)
+        elif self.is_futures(code):
+            return self.fetch_futures(code, start_date, end_date)
+        else:
+            return None
\ No newline at end of file
diff --git a/datasource/yfinance_source.py b/datasource/yfinance_source.py
new file mode 100644
index 0000000..f0fa0f3
--- /dev/null
+++ b/datasource/yfinance_source.py
@@ -0,0 +1,114 @@
+"""
+YFinance data source.
+
+Fetches Hong Kong and US market data (through the SSH tunnel).
+"""
+
+import os
+import time
+from typing import Optional
+from datetime import datetime, timedelta
+import pandas as pd
+import urllib3
+
+# Silence SSL warnings
+urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+
+class YFinanceSource:
+    """YFinance data source."""
+
+    # Code mapping (project code -> YFinance symbol)
+    CODE_MAP = {
+        # Hong Kong
+        "HSTECH.HK": "3033.HK",  # Hang Seng TECH index
+        "HSI": "^HSI",  # Hang Seng index
+        # US indices
+        "NDX": "^NDX",  # Nasdaq 100
+        "SPX": "^GSPC",  # S&P 500
+        "DJI": "^DJI",  # Dow Jones
+        # Japan / Europe
+        "N225": "^N225",  # Nikkei 225
+        "GDAXI": "^GDAXI",  # German DAX
+        # Commodities
+        "CL.NYM": "CL=F",  # WTI crude oil futures
+    }
+
+    def __init__(self, use_ssh_tunnel: bool = False):
+        """
+        Initialise the YFinance data source.
+
+        Args:
+            use_ssh_tunnel: Whether an SSH tunnel is used (start SSHTunnelManager first)
+        """
+        self.use_ssh_tunnel = use_ssh_tunnel
+        self._delay = 0.5  # delay before each request (avoids rate limiting)
+
+    def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
+        """
+        Fetch daily data for one code.
+
+        Args:
+            code: Code (e.g. 'NDX', 'N225', 'HSI')
+            start_date: Start date 'YYYY-MM-DD'
+            end_date: End date 'YYYY-MM-DD'
+
+        Returns:
+            DataFrame indexed by date (code, open, high, low, close, volume), or None
+        """
+        import yfinance as yf
+
+        # Sleep first to avoid rate limiting
+        time.sleep(self._delay)
+
+        # Convert the code format
+        yf_code = self.CODE_MAP.get(code, code)
+
+        try:
+            ticker = yf.Ticker(yf_code)
+
+            # end_date needs one extra day (yfinance treats end as exclusive)
+            end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
+
+            # auto_adjust=False returns unadjusted prices
+            df = ticker.history(
+                start=start_date,
+                end=end_dt.strftime("%Y-%m-%d"),
+                auto_adjust=False
+            )
+
+            if df is None or len(df) == 0:
+                return None
+
+            # Normalise column names
+            df = df.rename(columns={
+                "Open": "open",
+                "High": "high",
+                "Low": "low",
+                "Close": "close",
+                "Volume": "volume",
+            })
+
+            # Drop the timezone but keep the exchange-local calendar date.
+            # Converting to UTC first would move midnight bars of exchanges
+            # east of UTC (Hong Kong, Tokyo, Frankfurt) to the previous day.
+            df.index = pd.to_datetime(df.index).tz_localize(None).normalize()
+            df.index.name = "date"
+
+            # Add the code column
+            df["code"] = code
+
+            return df[['code', 'open', 'high', 'low', 'close', 'volume']]
+
+        except Exception as e:
+            print(f"YFinance下载 {code} ({yf_code}) 失败: {e}")
+            return None
+
+    def is_yfinance_code(self, code: str) -> bool:
+        """Return True when the code should be fetched from YFinance."""
+        # Suffixes served by Tushare ('.NYM' is not one: CODE_MAP sends it to CL=F)
+        china_suffixes = ['.SH', '.SZ', '.SS', '.CSI']
+        futures_suffixes = ['.SHF', '.DCE', '.CZC']
+
+        # A-shares and domestic futures use Tushare, everything else YFinance
+        return not any(code.endswith(s) for s in china_suffixes + futures_suffixes)
\ No newline at end of file
diff --git a/strategies/rotation/strategy.py b/strategies/rotation/strategy.py
index 930c8e2..ad86c44 100644
--- a/strategies/rotation/strategy.py
+++ b/strategies/rotation/strategy.py
@@ -106,7 +106,7 @@ class RotationStrategy(StrategyBase):
         return group_mapping
 
     def get_data(self) -> dict:
-        """获取数据(复用归档的数据源)"""
+        """Fetch data (using the new datasource module)."""
         code_list_config = self.config.get('code_list', {})
         benchmark_config = self.config.get('benchmark', {})
         benchmark_code = benchmark_config.get('code', '000300.SH')
@@ -114,27 +114,17 @@ class RotationStrategy(StrategyBase):
         if not code_list_config:
             raise ValueError("配置中未找到 code_list")
 
-        # 使用归档的HybridDataSource
-        from archive.legacy_core.core.datasource.hybrid_source import HybridDataSource
+        # Use the new datasource module
+        from datasource import HybridDataSource
 
         ssh_config = self.config.get('ssh_tunnel', {})
-        if ssh_config.get('enabled'):
-            ssh_config = {
-                'host': ssh_config.get('host'),
-                'port': ssh_config.get('port', 22),
-                'username': ssh_config.get('username', 'root'),
-                'key_path': ssh_config.get('key_path', 'hk_ecs.pem'),
-                'local_port': ssh_config.get('local_port', 1080)
-            }
-        else:
-            ssh_config = None
 
         data_source = HybridDataSource(
             ssh_config=ssh_config,
             use_cache=self.config.get('use_cache', True)
         )
 
-        # 调用 fetch_all(返回元组)
+        # Call fetch_all
         index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \
             data_source.fetch_all(
                 code_config=code_list_config,