From 6454e6823ff12a4024754ffed6adc410043fc50b Mon Sep 17 00:00:00 2001 From: aszerW Date: Wed, 25 Mar 2026 01:32:33 +0800 Subject: [PATCH] =?UTF-8?q?fix(datasource):=20=E4=BF=AE=E6=AD=A3=E6=B7=B7?= =?UTF-8?q?=E5=90=88=E6=95=B0=E6=8D=AE=E6=BA=90=E5=AF=BC=E5=85=A5=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修正 strategies.rotation.engine 中 hybrid_source 模块导入路径错误 - 新增 core.datasource 目录下多个数据源实现模块 - 增加 Akshare 数据源支持 A股指数数据拉取 - 实现数据缓存管理机制,支持本地数据缓存读写 - 新增 YFinance 数据源,支持通过 SSH 隧道访问美股和港股数据 - 实现混合数据源支持 A股/Tushare、港美股/YFinance、加密货币/CCXT 的统一访问 - 集成 SSH 隧道管理,支持 SOCKS5 转 HTTP 代理转发 - 新增 socks2http.py 代理转发工具,解决 CCXT 仅支持 HTTP 代理问题 - 修改 rotation.yaml 加密货币注释,明确使用 OKX 现货和 SSH->HTTP 代理访问 - 删除.gitignore中无用的 data/ 忽略规则,保留 test/ 文件夹忽略规则 --- .gitignore | 1 - config/strategies/rotation.yaml | 6 +- core/datasource/__init__.py | 0 core/datasource/akshare_source.py | 128 ++++++++ core/datasource/base.py | 53 ++++ core/datasource/cache.py | 54 ++++ core/datasource/hybrid_source.py | 481 +++++++++++++++++++++++++++++ core/datasource/socks2http.py | 168 ++++++++++ core/datasource/yfinance_source.py | 194 ++++++++++++ strategies/rotation/engine.py | 2 +- 10 files changed, 1083 insertions(+), 4 deletions(-) create mode 100644 core/datasource/__init__.py create mode 100644 core/datasource/akshare_source.py create mode 100644 core/datasource/base.py create mode 100644 core/datasource/cache.py create mode 100644 core/datasource/hybrid_source.py create mode 100644 core/datasource/socks2http.py create mode 100644 core/datasource/yfinance_source.py diff --git a/.gitignore b/.gitignore index 4e80ccc..81ff4dd 100644 --- a/.gitignore +++ b/.gitignore @@ -177,7 +177,6 @@ config_local.py *.backup -data/ test/ # Cache and generated files diff --git a/config/strategies/rotation.yaml b/config/strategies/rotation.yaml index 61497dd..6d05204 100644 --- a/config/strategies/rotation.yaml +++ b/config/strategies/rotation.yaml @@ -34,8 +34,10 @@ code_list: "HSTECH": "恒生科技" # 港股 "NDX": "纳指100" # 美股 "GC=F": "黄金" # 黄金期货 (COMEX) - "BTC": "比特币" # 加密货币 - "ETH": "以太坊" # 加密货币 + + # 加密货币 (使用 CCXT/OKX 现货) - 通过 SSH->HTTP 代理访问 + "BTC": "比特币" # OKX 现货 + "ETH": "以太坊" # OKX 现货 # 主市场配置(用于确定交易日历) primary_market: diff --git a/core/datasource/__init__.py b/core/datasource/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/core/datasource/akshare_source.py b/core/datasource/akshare_source.py new file mode 100644 index 0000000..e5b8931 --- /dev/null +++ b/core/datasource/akshare_source.py @@ -0,0 +1,128 @@ +""" +akshare数据源实现(用于CCI筛选等场景) +""" + +import time +import pandas as pd +import akshare as ak +from typing import Optional + +from .base import DataSource + + +class AkshareDataSource(DataSource): + """基于akshare的数据源""" + + def __init__(self, delay: float = 3.0): + """ + 初始化akshare数据源 + + Args: + delay: 每次请求间隔秒数(避免触发限流) + """ + self.delay = delay + + def fetch_ohlcv( + self, + code: str, + start_date: str, + end_date: str, + fields: Optional[list] = None, + ) -> pd.DataFrame: + """ + 获取指数历史数据(使用akshare的东方财富接口) + + Args: + code: 指数代码,如 '000300'(不带后缀) + start_date: 起始日期 'YYYY-MM-DD' + end_date: 结束日期 'YYYY-MM-DD' + + Returns: + DataFrame,包含 date, open, high, low, close, volume + """ + # 转换日期格式 + sd = start_date.replace("-", "") + ed = end_date.replace("-", "") + + # 去除后缀 + symbol = code.replace(".SH", "").replace(".SZ", "") + + time.sleep(self.delay) + + try: + df = ak.index_zh_a_hist( + symbol=symbol, + period="daily", + start_date=sd, + end_date=ed, + ) + except Exception as e: + raise RuntimeError(f"akshare查询失败 [{code}]: {e}") + + if df is None or df.empty: + raise ValueError(f"akshare返回空数据: {code}") + + # 统一列名 + df = df.rename( + columns={ + "日期": "date", + "开盘": "open", + "最高": "high", + "最低": "low", + "收盘": "close", + "成交量": "volume", + } + ) + df["date"] = pd.to_datetime(df["date"]) + df["code"] = code + df = df[["date", "code", "open", "high", "low", "close", "volume"]] + df = df.sort_values("date").reset_index(drop=True) + + return df + + def fetch_multiple( + self, + codes: list, + start_date: str, + end_date: str, + ) -> tuple[pd.DataFrame, list]: + """批量获取数据""" + df_list = [] + failed = [] + + for code in codes: + try: + df = self.fetch_ohlcv(code, start_date, end_date) + df_list.append(df[["date", "code", "close"]].copy()) + except Exception as e: + print(f" ⚠ 跳过 {code}: {e}") + failed.append(code) + + if not df_list: + raise RuntimeError("所有数据获取失败") + + all_df = pd.concat(df_list, ignore_index=True) + data = all_df.pivot(index="date", columns="code", values="close") + data = data.sort_index() + + return data, failed + + def fetch_index_list(self) -> pd.DataFrame: + """ + 获取所有A股指数列表 + + Returns: + DataFrame with index info + """ + sources = ["沪深重要指数", "上证系列指数", "深证系列指数", "中证系列指数"] + df_list = [] + + for source in sources: + try: + df = ak.stock_zh_index_spot_em(symbol=source) + df["source"] = source + df_list.append(df) + except Exception as e: + print(f" ⚠ 获取 {source} 失败: {e}") + + return pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame() diff --git a/core/datasource/base.py b/core/datasource/base.py new file mode 100644 index 0000000..764937f --- /dev/null +++ b/core/datasource/base.py @@ -0,0 +1,53 @@ +""" +数据源抽象基类 +""" + +from abc import ABC, abstractmethod +from typing import Optional +import pandas as pd + + +class DataSource(ABC): + """数据源抽象基类""" + + @abstractmethod + def fetch_ohlcv( + self, + code: str, + start_date: str, + end_date: str, + fields: Optional[list] = None, + ) -> pd.DataFrame: + """ + 获取OHLCV数据 + + Args: + code: 标的代码 + start_date: 开始日期 (YYYY-MM-DD) + end_date: 结束日期 (YYYY-MM-DD) + fields: 指定字段列表,None表示获取全部 + + Returns: + DataFrame with columns: date, open, high, low, close, volume + """ + pass + + @abstractmethod + def fetch_multiple( + self, + codes: list, + start_date: str, + end_date: str, + ) -> pd.DataFrame: + """ + 批量获取多只标的收盘价数据 + + Args: + codes: 标的代码列表 + start_date: 开始日期 + end_date: 结束日期 + + Returns: + DataFrame, index=日期, columns=代码, values=收盘价 + """ + pass diff --git a/core/datasource/cache.py b/core/datasource/cache.py new file mode 100644 index 0000000..6343c34 --- /dev/null +++ b/core/datasource/cache.py @@ -0,0 +1,54 @@ +""" +数据缓存管理模块 +""" + +import os +import pandas as pd +from pathlib import Path +from typing import Optional + + +class DataCache: + """CSV文件缓存管理器""" + + def __init__(self, cache_dir: str = "data_cache"): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True) + + def _get_cache_path(self, code: str, start_date: str, end_date: str) -> Path: + """生成缓存文件路径""" + # 统一日期格式为 YYYYMMDD + sd = start_date.replace("-", "") + ed = end_date.replace("-", "") + safe_code = code.replace(".", "_") + return self.cache_dir / f"{safe_code}_{sd}_{ed}.csv" + + def get(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + """ + 从缓存读取数据 + + Returns: + DataFrame or None(缓存不存在) + """ + cache_path = self._get_cache_path(code, start_date, end_date) + if cache_path.exists(): + df = pd.read_csv(cache_path) + df["date"] = pd.to_datetime(df["date"]) + return df + return None + + def set(self, code: str, start_date: str, end_date: str, df: pd.DataFrame) -> None: + """保存数据到缓存""" + cache_path = self._get_cache_path(code, start_date, end_date) + df.to_csv(cache_path, index=False) + + def clear(self, code: str = None) -> None: + """清除缓存""" + if code: + # 清除指定代码的缓存 + for f in self.cache_dir.glob(f"{code.replace('.', '_')}*.csv"): + f.unlink() + else: + # 清除所有缓存 + for f in self.cache_dir.glob("*.csv"): + f.unlink() diff --git a/core/datasource/hybrid_source.py b/core/datasource/hybrid_source.py new file mode 100644 index 0000000..db24c37 --- /dev/null +++ b/core/datasource/hybrid_source.py @@ -0,0 +1,481 @@ +""" +混合数据源模块 +- 中国A股指数: Tushare +- 港股/美股: YFinance (支持 SSH 隧道) +- 加密货币: CCXT/OKX (支持 SSH->HTTP 代理) +""" + +import os +import sys +import time +import subprocess +from pathlib import Path +from typing import Optional, Tuple, Dict +from datetime import datetime +import pandas as pd +import yfinance as yf +import urllib3 + +# 禁用 SSL 警告 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +class SSHTunnelManager: + """SSH 隧道管理器""" + + def __init__(self, config: dict): + self.enabled = config.get("enabled", False) + self.host = config.get("host", "") + self.port = config.get("port", 22) + self.username = config.get("username", "") + self.local_port = config.get("local_port", 1080) + self._process: Optional[subprocess.Popen] = None + + # 处理 key_path:如果是相对路径,转换为绝对路径 + key_path = config.get("key_path", "") + if key_path and not os.path.isabs(key_path): + # 相对于项目根目录 + project_root = Path(__file__).parent.parent.parent + key_path = str(project_root / key_path) + self.key_path = key_path + print(f"SSH 私钥路径: {self.key_path}") + + def start(self) -> bool: + """启动 SSH 隧道""" + if not self.enabled: + return True + + if not all([self.host, self.username, self.key_path]): + print("SSH 配置不完整,跳过隧道建立") + return False + + print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}") + + cmd = [ + "ssh", "-N", "-D", f"127.0.0.1:{self.local_port}", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-i", self.key_path, + "-p", str(self.port), + f"{self.username}@{self.host}" + ] + + try: + self._process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + time.sleep(2) + + if self._process.poll() is not None: + stdout, stderr = self._process.communicate() + print("✗ SSH 隧道建立失败") + if stderr: + print(f"错误: {stderr.decode()}") + return False + + # 设置代理环境变量 + proxy_url = f"socks5://127.0.0.1:{self.local_port}" + os.environ["HTTP_PROXY"] = proxy_url + os.environ["HTTPS_PROXY"] = proxy_url + os.environ["ALL_PROXY"] = proxy_url + + print(f"✓ SSH 隧道已建立: {proxy_url}") + time.sleep(1) + return True + + except Exception as e: + print(f"✗ SSH 隧道异常: {e}") + return False + + def stop(self): + """停止 SSH 隧道""" + if self._process: + self._process.terminate() + self._process.wait() + for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]: + os.environ.pop(key, None) + print("SSH 隧道已关闭") + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + +class HybridDataSource: + """ + 混合数据源 + - 中国A股指数 (SH/SZ): Tushare + - 港股/美股: YFinance + - 加密货币: CCXT/OKX (通过 SSH->HTTP 代理) + """ + + # YFinance 代码映射 (代码 -> YFinance格式) + YF_CODE_MAP = { + # 港股 + "HSTECH": "3033.HK", # 恒生科技指数 ETF + "HSI": "^HSI", # 恒生指数 + # 美股指数 + "NDX": "^NDX", # 纳斯达克100 + "SPX": "^GSPC", # 标普500 + "DJI": "^DJI", # 道琼斯 + # 黄金 + "GC=F": "GC=F", # 黄金期货 + } + + # CCXT 代码映射 (代码 -> CCXT格式) + CCXT_CODE_MAP = { + "BTC": "BTC/USDT", # OKX 比特币现货 + "ETH": "ETH/USDT", # OKX 以太坊现货 + } + + def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True): + self.ssh_config = ssh_config or {} + self.use_cache = use_cache + self._tunnel: Optional[SSHTunnelManager] = None + self._tushare_token: Optional[str] = None + + def _is_china_index(self, code: str) -> bool: + """判断是否为中国A股指数""" + return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS") + + def _is_crypto(self, code: str) -> bool: + """判断是否为加密货币""" + return code in self.CCXT_CODE_MAP + + def _get_tushare_token(self) -> str: + """获取 Tushare Token""" + if self._tushare_token is None: + import os + self._tushare_token = os.getenv("TUSHARE_TOKEN") + if not self._tushare_token: + raise ValueError("请设置环境变量 TUSHARE_TOKEN") + return self._tushare_token + + def _fetch_tushare(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + """使用 Tushare 获取中国指数数据(不使用代理,直连国内服务器)""" + import os + + # 临时清除代理环境变量(Tushare 是国内服务,不需要代理) + original_proxy = {} + for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]: + original_proxy[key] = os.environ.pop(key, None) + + try: + import tushare as ts + + pro = ts.pro_api(self._get_tushare_token()) + + # 转换代码格式 (000300.SS -> 000300.SH) + ts_code = code.replace(".SS", ".SH") + + # 获取日线数据 + df = pro.index_daily( + ts_code=ts_code, + start_date=start_date.replace("-", ""), + end_date=end_date.replace("-", "") + ) + + if df is None or len(df) == 0: + return None + + # 标准化列名 + df = df.rename(columns={ + "trade_date": "date", + "open": "open", + "high": "high", + "low": "low", + "close": "close", + "vol": "volume", + }) + + # 转换日期格式 + df["date"] = pd.to_datetime(df["date"]) + df = df.set_index("date") + df = df.sort_index() + + # 添加代码列 + df["code"] = code + + return df + + except Exception as e: + print(f"Tushare 下载 {code} 失败: {e}") + return None + + finally: + # 恢复代理环境变量 + for key, value in original_proxy.items(): + if value is not None: + os.environ[key] = value + + def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + """使用 YFinance 获取数据""" + import time + + # 转换代码格式 + yf_code = self.YF_CODE_MAP.get(code, code) + + # 添加延迟以避免限流 + time.sleep(0.5) + + try: + ticker = yf.Ticker(yf_code) + data = ticker.history(start=start_date, end=end_date) + + if len(data) == 0: + return None + + # 标准化列名 + data = data.rename(columns={ + "Open": "open", + "High": "high", + "Low": "low", + "Close": "close", + "Volume": "volume", + }) + + # 添加代码列 + data["code"] = code + + return data + + except Exception as e: + print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}") + return None + + def _fetch_ccxt(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]: + """使用 CCXT/OKX 获取加密货币数据(支持 HTTP 代理)""" + import ccxt + + # 转换代码格式 + ccxt_code = self.CCXT_CODE_MAP.get(code, code) + + # 配置 CCXT + config = {'enableRateLimit': True} + if http_proxy: + config['proxies'] = {'http': http_proxy, 'https': http_proxy} + + try: + exchange = ccxt.okx(config) + + # 获取日线数据 + since = int(pd.Timestamp(start_date).timestamp() * 1000) + all_ohlcv = [] + limit = 100 + + while since < int(pd.Timestamp(end_date).timestamp() * 1000): + ohlcv = exchange.fetch_ohlcv(ccxt_code, '1d', since, limit) + if not ohlcv: + break + all_ohlcv.extend(ohlcv) + since = ohlcv[-1][0] + 86400000 + + if not all_ohlcv: + return None + + # 转换为 DataFrame + df = pd.DataFrame(all_ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) + # 转换时间戳为日期索引 + df.index = pd.DatetimeIndex(pd.to_datetime(df['timestamp'], unit='ms', utc=True)).tz_localize(None).normalize() + df.index.name = 'date' + df = df[['open', 'high', 'low', 'close', 'volume']] + # 过滤日期范围 + start_ts = pd.Timestamp(start_date) + end_ts = pd.Timestamp(end_date) + df = df[(df.index >= start_ts) & (df.index <= end_ts)] + df['code'] = code + + return df + + except Exception as e: + print(f"CCXT 下载 {code} ({ccxt_code}) 失败: {e}") + return None + + def fetch_single(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]: + """获取单个标的的数据""" + if self._is_china_index(code): + return self._fetch_tushare(code, start_date, end_date) + elif self._is_crypto(code): + return self._fetch_ccxt(code, start_date, end_date, http_proxy) + else: + return self._fetch_yfinance(code, start_date, end_date) + + def fetch_all( + self, + code_list, # list[代码] 或 dict{代码: 名称} + benchmark_code: str, + start_date: str, + end_date: str, + ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]: + """ + 批量获取数据 + 注意:由于 Tushare(中国A股) 和 YFinance(美股/加密货币) 的交易日历不同, + 这里返回的是长格式数据,由调用方分别处理各市场的数据 + + Returns: + (etf_data, benchmark_data, valid_codes) + etf_data: DataFrame with columns [code, close, source], index=date + """ + all_data = [] + valid_codes = [] + + # 兼容列表和字典格式 + if isinstance(code_list, dict): + codes = list(code_list.keys()) + code_name_map = code_list + else: + codes = code_list + code_name_map = {c: c for c in codes} + + print(f"开始下载 {len(codes)} 只标的的数据...") + china_codes = [c for c in codes if self._is_china_index(c)] + global_codes = [c for c in codes if not self._is_china_index(c)] + print(f" 中国A股指数: {len(china_codes)} 只") + print(f" 港股/美股/加密货币: {len(global_codes)} 只") + + # 检查是否需要启动 socks2http 代理(用于加密货币) + crypto_codes = [c for c in codes if self._is_crypto(c)] + http_proxy = None + socks2http_proc = None + + # 只有在 SSH 隧道已建立时才启动 socks2http + if crypto_codes and self._tunnel is not None: + import subprocess + import time + print(f"\n 启动 socks2http 代理服务(用于加密货币)...") + try: + # 启动 socks2http.py 作为子进程 + socks2http_path = Path(__file__).parent / "socks2http.py" + socks2http_proc = subprocess.Popen( + [sys.executable, str(socks2http_path)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + time.sleep(2) # 等待代理启动 + http_proxy = "http://127.0.0.1:8080" + print(f" ✓ HTTP 代理已启动: {http_proxy}") + except Exception as e: + print(f" ✗ 启动代理失败: {e}") + + # 分别下载数据 + for code in codes: + if self._is_china_index(code): + source = "Tushare" + elif self._is_crypto(code): + source = "CCXT/OKX" + else: + source = "YFinance" + + name = code_name_map.get(code, code) + print(f" 下载 {code} ({name}) - {source}...", end=" ") + + # 加密货币使用 HTTP 代理 + proxy = http_proxy if self._is_crypto(code) else None + data = self.fetch_single(code, start_date, end_date, proxy) + + if data is not None and len(data) > 0: + # 标准化数据格式 + data = data.copy() + data['source'] = source + # 确保索引是日期格式且无时区,只保留日期部分(去掉时间) + data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize() + all_data.append(data[['code', 'close', 'source']]) + valid_codes.append(code) + print(f"✓ {len(data)} 条") + else: + print("✗ 无数据") + + # 关闭 socks2http 代理 + if socks2http_proc: + socks2http_proc.terminate() + socks2http_proc.wait() + print(f"\n socks2http 代理已关闭") + + if not all_data: + return None, None, [] + + # 检查数据源类型 + sources = set(d['source'].iloc[0] for d in all_data) + + if len(sources) == 1: + # 单一数据源:转换为宽格式(向后兼容) + all_df = pd.concat(all_data, ignore_index=False) + all_df = all_df.reset_index() + all_df['date'] = pd.to_datetime(all_df['date'], utc=True).dt.tz_localize(None) + etf_data = all_df.pivot_table( + index='date', + columns='code', + values='close', + aggfunc='first' + ) + print(f"\n数据整理完成 (单一数据源 {list(sources)[0]}):") + print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}") + print(f" 交易日数: {len(etf_data)}") + print(f" 有效标的: {len(etf_data.columns)} 只") + else: + # 多数据源:以主市场(Tushare/A股)为基准,其他市场数据前向填充 + print(f"\n数据整理完成 (多数据源 - 以A股交易日为基准):") + + # 合并所有数据(索引已经是标准化后的日期) + all_df = pd.concat(all_data, ignore_index=False) + all_df = all_df.reset_index() + # 重命名索引列为 date + if 'index' in all_df.columns: + all_df = all_df.rename(columns={'index': 'date'}) + # 确保 date 列是日期格式(不含时间) + all_df['date'] = pd.to_datetime(all_df['date']).dt.normalize() + + # 透视为宽格式 + etf_data = all_df.pivot_table( + index='date', + columns='code', + values='close', + aggfunc='first' + ) + + # 获取主市场(Tushare)的交易日历 + tushare_codes = [c for c in valid_codes if self._is_china_index(c)] + if tushare_codes: + # 使用第一个A股代码的日期作为主市场交易日 + primary_dates = etf_data[tushare_codes[0]].dropna().index + print(f" 主市场交易日: {len(primary_dates)} 天") + + # 重新索引到主市场交易日,使用前向填充 + etf_data = etf_data.reindex(primary_dates) + + # 对每个非主市场代码进行前向填充 + yfinance_codes = [c for c in valid_codes if not self._is_china_index(c)] + for code in yfinance_codes: + if code in etf_data.columns: + # 前向填充:用最近的有效价格填充休市日的数据 + etf_data[code] = etf_data[code].ffill() + # 对于开头的NaN,用后向填充 + etf_data[code] = etf_data[code].bfill() + + print(f" 非主市场标的: {len(yfinance_codes)} 只 (已前向填充)") + + print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}") + print(f" 交易日数: {len(etf_data)}") + print(f" 有效标的: {len(etf_data.columns)} 只") + + # 获取基准数据 + benchmark_data = self.fetch_single(benchmark_code, start_date, end_date) + if benchmark_data is not None: + # 标准化日期索引(无时区,只保留日期部分) + benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize() + print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条") + + return etf_data, benchmark_data, valid_codes + + def __enter__(self): + """上下文管理器入口""" + if self.ssh_config.get("enabled"): + self._tunnel = SSHTunnelManager(self.ssh_config) + self._tunnel.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器出口""" + if self._tunnel: + self._tunnel.stop() diff --git a/core/datasource/socks2http.py b/core/datasource/socks2http.py new file mode 100644 index 0000000..37ba687 --- /dev/null +++ b/core/datasource/socks2http.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 +""" +SOCKS5 转 HTTP 代理转发工具 +将 SSH 隧道的 SOCKS5 代理 (1080) 转为 HTTP 代理 (8080) +供 CCXT 等只支持 HTTP 代理的库使用 +""" + +import socket +import threading +import select +import sys +from urllib.parse import urlparse + + +class Socks2Http: + def __init__(self, socks_host='127.0.0.1', socks_port=1080, http_port=8080): + self.socks_host = socks_host + self.socks_port = socks_port + self.http_port = http_port + self.server = None + self.running = False + + def start(self): + """启动 HTTP 代理服务器""" + self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.server.bind(('127.0.0.1', self.http_port)) + self.server.listen(5) + self.running = True + + print(f"HTTP 代理已启动: http://127.0.0.1:{self.http_port}") + print(f"转发到 SOCKS5: {self.socks_host}:{self.socks_port}") + + while self.running: + try: + client, addr = self.server.accept() + thread = threading.Thread(target=self._handle_client, args=(client,)) + thread.daemon = True + thread.start() + except Exception as e: + if self.running: + print(f"接受连接错误: {e}") + + def stop(self): + """停止代理服务器""" + self.running = False + if self.server: + self.server.close() + print("HTTP 代理已停止") + + def _handle_client(self, client): + """处理客户端连接""" + try: + # 读取 HTTP 请求 + request = client.recv(4096) + if not request: + client.close() + return + + # 解析 CONNECT 请求 + first_line = request.split(b'\r\n')[0].decode('utf-8', errors='ignore') + + if first_line.startswith('CONNECT'): + # HTTPS 代理 + parts = first_line.split() + if len(parts) >= 2: + target = parts[1] + host, port = target.rsplit(':', 1) + port = int(port) + + # 连接到 SOCKS5 代理 + remote = self._connect_via_socks5(host, port) + if remote: + client.send(b'HTTP/1.1 200 Connection established\r\n\r\n') + self._relay(client, remote) + else: + client.send(b'HTTP/1.1 502 Bad Gateway\r\n\r\n') + else: + # HTTP 代理 + lines = first_line.split() + if len(lines) >= 2: + url = lines[1] + parsed = urlparse(url) + host = parsed.hostname + port = parsed.port or 80 + + # 连接到 SOCKS5 代理 + remote = self._connect_via_socks5(host, port) + if remote: + # 修改请求,去掉完整 URL + new_request = request.replace( + f'{lines[0]} {url} '.encode(), + f'{lines[0]} {parsed.path or "/"}{"?" + parsed.query if parsed.query else ""} '.encode() + ) + remote.send(new_request) + self._relay(client, remote) + + except Exception as e: + print(f"处理客户端错误: {e}") + finally: + client.close() + + def _connect_via_socks5(self, host, port): + """通过 SOCKS5 代理连接目标服务器""" + try: + # 连接到 SOCKS5 代理 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(30) + sock.connect((self.socks_host, self.socks_port)) + + # SOCKS5 握手 + # 1. 发送认证方法 + sock.send(b'\x05\x01\x00') # VER=5, NMETHODS=1, METHOD=0 (无认证) + resp = sock.recv(2) + if resp[0] != 0x05 or resp[1] != 0x00: + sock.close() + return None + + # 2. 发送连接请求 + req = b'\x05\x01\x00\x03' # VER=5, CMD=CONNECT, RSV=0, ATYP=DOMAIN + req += bytes([len(host)]) + host.encode() + req += bytes([(port >> 8) & 0xFF, port & 0xFF]) + sock.send(req) + + # 3. 读取响应 + resp = sock.recv(10) + if len(resp) < 4 or resp[1] != 0x00: + sock.close() + return None + + return sock + + except Exception as e: + print(f"SOCKS5 连接错误: {e}") + return None + + def _relay(self, client, remote): + """双向转发数据""" + try: + while True: + readable, _, _ = select.select([client, remote], [], [], 60) + if not readable: + break + + if client in readable: + data = client.recv(4096) + if not data: + break + remote.send(data) + + if remote in readable: + data = remote.recv(4096) + if not data: + break + client.send(data) + except: + pass + finally: + client.close() + remote.close() + + +if __name__ == '__main__': + proxy = Socks2Http(socks_port=1080, http_port=8080) + try: + proxy.start() + except KeyboardInterrupt: + proxy.stop() diff --git a/core/datasource/yfinance_source.py b/core/datasource/yfinance_source.py new file mode 100644 index 0000000..a1e1818 --- /dev/null +++ b/core/datasource/yfinance_source.py @@ -0,0 +1,194 @@ +""" +YFinance 数据源模块 +支持通过 SSH 隧道访问(用于绕过网络限制) +""" + +import os +import time +import subprocess +from typing import Optional, Tuple +from datetime import datetime +import pandas as pd +import yfinance as yf +import urllib3 + +# 禁用 SSL 警告 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + + +class SSHTunnelManager: + """SSH 隧道管理器""" + + def __init__(self, config: dict): + from pathlib import Path + + self.enabled = config.get("enabled", False) + self.host = config.get("host", "") + self.port = config.get("port", 22) + self.username = config.get("username", "") + self.local_port = config.get("local_port", 1080) + self._process: Optional[subprocess.Popen] = None + + # 处理 key_path:如果是相对路径,转换为绝对路径 + key_path = config.get("key_path", "") + if key_path and not os.path.isabs(key_path): + # 相对于项目根目录 + project_root = Path(__file__).parent.parent.parent + key_path = str(project_root / key_path) + self.key_path = key_path + + def start(self) -> bool: + """启动 SSH 隧道""" + if not self.enabled: + return True + + if not all([self.host, self.username, self.key_path]): + print("SSH 配置不完整,跳过隧道建立") + return False + + print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}") + + cmd = [ + "ssh", + "-N", + "-D", f"127.0.0.1:{self.local_port}", + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-i", self.key_path, + "-p", str(self.port), + f"{self.username}@{self.host}" + ] + + try: + self._process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + time.sleep(2) + + if self._process.poll() is not None: + stdout, stderr = self._process.communicate() + print("✗ SSH 隧道建立失败") + if stderr: + print(f"错误: {stderr.decode()}") + return False + + # 设置代理环境变量 + proxy_url = f"socks5://127.0.0.1:{self.local_port}" + os.environ["HTTP_PROXY"] = proxy_url + os.environ["HTTPS_PROXY"] = proxy_url + os.environ["ALL_PROXY"] = proxy_url + + print(f"✓ SSH 隧道已建立: {proxy_url}") + time.sleep(1) + return True + + except Exception as e: + print(f"✗ SSH 隧道异常: {e}") + return False + + def stop(self): + """停止 SSH 隧道""" + if self._process: + self._process.terminate() + self._process.wait() + # 清理环境变量 + for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]: + os.environ.pop(key, None) + print("SSH 隧道已关闭") + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + +class YFinanceDataSource: + """YFinance 数据源""" + + def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True): + self.ssh_config = ssh_config or {} + self.use_cache = use_cache + self._tunnel: Optional[SSHTunnelManager] = None + + def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + """获取单个 ETF 数据""" + try: + ticker = yf.Ticker(code) + data = ticker.history(start=start_date, end=end_date) + + if len(data) == 0: + return None + + # 标准化列名 + data = data.rename(columns={ + "Open": "open", + "High": "high", + "Low": "low", + "Close": "close", + "Volume": "volume", + }) + + # 添加代码列 + data["code"] = code + + return data + + except Exception as e: + print(f"下载 {code} 失败: {e}") + return None + + def fetch_all( + self, + code_list: list, + benchmark_code: str, + start_date: str, + end_date: str, + ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]: + """ + 批量获取 ETF 数据和基准数据 + + Returns: + (etf_data, benchmark_data, valid_codes) + """ + all_data = [] + valid_codes = [] + + print(f"开始下载 {len(code_list)} 只 ETF 数据...") + + for code in code_list: + data = self.fetch_single(code, start_date, end_date) + if data is not None and len(data) > 0: + all_data.append(data) + valid_codes.append(code) + print(f" ✓ {code}: {len(data)} 条") + else: + print(f" ✗ {code}: 无数据") + + if not all_data: + return None, None, [] + + # 合并数据 + etf_data = pd.concat(all_data, ignore_index=False) + + # 获取基准数据 + benchmark_data = self.fetch_single(benchmark_code, start_date, end_date) + if benchmark_data is not None: + print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条") + + return etf_data, benchmark_data, valid_codes + + def __enter__(self): + """上下文管理器入口""" + if self.ssh_config.get("enabled"): + self._tunnel = SSHTunnelManager(self.ssh_config) + self._tunnel.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """上下文管理器出口""" + if self._tunnel: + self._tunnel.stop() diff --git a/strategies/rotation/engine.py b/strategies/rotation/engine.py index 6b7a8f9..87820a0 100644 --- a/strategies/rotation/engine.py +++ b/strategies/rotation/engine.py @@ -10,7 +10,7 @@ import numpy as np from typing import Optional from strategies.base import BacktestStrategy -from core.data.hybrid_source import HybridDataSource +from core.datasource.hybrid_source import HybridDataSource from core.factors.momentum import compute_factors, calculate_daily_return