Files
etf/core/datasource/hybrid_source.py
aszerW 6454e6823f fix(datasource): 修正混合数据源导入路径错误
- 修正 strategies.rotation.engine 中 hybrid_source 模块导入路径错误
- 新增 core.datasource 目录下多个数据源实现模块
- 增加 Akshare 数据源支持 A股指数数据拉取
- 实现数据缓存管理机制,支持本地数据缓存读写
- 新增 YFinance 数据源,支持通过 SSH 隧道访问美股和港股数据
- 实现混合数据源支持 A股/Tushare、港美股/YFinance、加密货币/CCXT 的统一访问
- 集成 SSH 隧道管理,支持 SOCKS5 转 HTTP 代理转发
- 新增 socks2http.py 代理转发工具,解决 CCXT 仅支持 HTTP 代理问题
- 修改 rotation.yaml 加密货币注释,明确使用 OKX 现货和 SSH->HTTP 代理访问
- 删除.gitignore中无用的 data/ 忽略规则,保留 test/ 文件夹忽略规则
2026-03-25 01:32:33 +08:00

482 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
混合数据源模块
- 中国A股指数: Tushare
- 港股/美股: YFinance (支持 SSH 隧道)
- 加密货币: CCXT/OKX (支持 SSH->HTTP 代理)
"""
import os
import sys
import time
import subprocess
from pathlib import Path
from typing import Optional, Tuple, Dict
from datetime import datetime
import pandas as pd
import yfinance as yf
import urllib3
# Suppress urllib3's InsecureRequestWarning for the whole process (the warnings
# filter is global). This only hides the warning; it does not by itself turn off
# certificate verification for any request.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class SSHTunnelManager:
    """Run ``ssh -N -D`` as a local SOCKS5 proxy and expose it via proxy env vars.

    Config keys read: ``enabled`` (default False), ``host``, ``port`` (default 22),
    ``username``, ``local_port`` (default 1080) and ``key_path``. A leading ``~``
    in ``key_path`` is expanded; a remaining relative path is resolved against
    the directory three levels above this file (treated as the project root).

    While the tunnel is up, HTTP_PROXY / HTTPS_PROXY / ALL_PROXY point at
    ``socks5://127.0.0.1:<local_port>``. :meth:`stop` restores whatever values
    those variables had before :meth:`start` changed them (or removes them if
    they were unset), and never touches them if this instance did not set them.
    """

    # Environment variables overwritten while the tunnel is active.
    _PROXY_ENV_KEYS = ("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY")

    def __init__(self, config: dict):
        self.enabled = config.get("enabled", False)
        self.host = config.get("host", "")
        self.port = config.get("port", 22)
        self.username = config.get("username", "")
        self.local_port = config.get("local_port", 1080)
        self._process: Optional[subprocess.Popen] = None
        # Proxy env values captured before start() overwrote them.
        # None means this instance has not modified the environment.
        self._saved_env: Optional[Dict[str, Optional[str]]] = None

        # Expand "~" first so "~/.ssh/id_rsa" is not mistaken for a relative path.
        key_path = os.path.expanduser(config.get("key_path", "") or "")
        if key_path and not os.path.isabs(key_path):
            project_root = Path(__file__).parent.parent.parent
            key_path = str(project_root / key_path)
        self.key_path = key_path
        print(f"SSH 私钥路径: {self.key_path}")

    def start(self) -> bool:
        """Start the tunnel and point the proxy env vars at it.

        Returns:
            True if tunnelling is disabled (nothing to do), if a tunnel started by
            this instance is already running, or if the tunnel was established.
            False if the config is incomplete, ssh exited during the startup
            wait, or launching ssh raised.
        """
        if not self.enabled:
            return True
        if self._process is not None and self._process.poll() is None:
            return True  # already running: don't spawn a second ssh
        if not all([self.host, self.username, self.key_path]):
            print("SSH 配置不完整,跳过隧道建立")
            return False

        print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
        cmd = [
            "ssh", "-N", "-D", f"127.0.0.1:{self.local_port}",
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-i", self.key_path,
            "-p", str(self.port),
            f"{self.username}@{self.host}",
        ]
        try:
            # With -N ssh has nothing to print on stdout, so discard it. stderr is
            # piped for early-failure diagnostics and drained in stop().
            # NOTE(review): a very chatty ssh could still fill the stderr pipe
            # while running; -N without -v normally writes very little.
            self._process = subprocess.Popen(
                cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE
            )
            time.sleep(2)  # fixed grace period for ssh to connect or fail
            if self._process.poll() is not None:
                _, stderr = self._process.communicate()
                self._process = None
                print("✗ SSH 隧道建立失败")
                if stderr:
                    print(f"错误: {stderr.decode(errors='replace')}")
                return False

            proxy_url = f"socks5://127.0.0.1:{self.local_port}"
            self._set_proxy_env(proxy_url)
            print(f"✓ SSH 隧道已建立: {proxy_url}")
            time.sleep(1)
            return True
        except Exception as e:
            print(f"✗ SSH 隧道异常: {e}")
            return False

    def stop(self):
        """Stop the ssh process (if any) and restore the proxy env vars.

        Safe to call repeatedly or without a prior start().
        """
        proc, self._process = self._process, None
        if proc is not None:
            if proc.poll() is None:
                proc.terminate()
            try:
                # Reaps the process and drains/closes the stderr pipe.
                proc.communicate(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.communicate()
        self._restore_proxy_env()
        if proc is not None:
            print("SSH 隧道已关闭")

    def _set_proxy_env(self, proxy_url: str) -> None:
        """Point the proxy env vars at *proxy_url*, remembering prior values once."""
        if self._saved_env is None:
            self._saved_env = {key: os.environ.get(key) for key in self._PROXY_ENV_KEYS}
        for key in self._PROXY_ENV_KEYS:
            os.environ[key] = proxy_url

    def _restore_proxy_env(self) -> None:
        """Undo _set_proxy_env(); a no-op if this instance never changed the env."""
        if self._saved_env is None:
            return
        for key, value in self._saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
        self._saved_env = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
class HybridDataSource:
    """Route daily price downloads to a market-specific backend.

    - China A-share indices (code ends with .SH / .SZ / .SS): Tushare
    - Crypto (keys of ``CCXT_CODE_MAP``): CCXT/OKX, optionally through a local
      HTTP proxy (socks2http.py) that is only started while an SSH tunnel is up
    - Everything else (HK/US tickers, futures): YFinance

    Use as a context manager to open/close the optional SSH tunnel described by
    ``ssh_config``.
    """

    # Internal code -> YFinance ticker; unlisted codes are passed through unchanged.
    YF_CODE_MAP = {
        # Hong Kong
        "HSTECH": "3033.HK",  # Hang Seng TECH index ETF
        "HSI": "^HSI",  # Hang Seng Index
        # US indices
        "NDX": "^NDX",  # Nasdaq-100
        "SPX": "^GSPC",  # S&P 500
        "DJI": "^DJI",  # Dow Jones Industrial Average
        # Gold
        "GC=F": "GC=F",  # gold futures
    }

    # Internal code -> CCXT symbol (OKX spot pairs).
    CCXT_CODE_MAP = {
        "BTC": "BTC/USDT",
        "ETH": "ETH/USDT",
    }

    # Address the socks2http.py helper is expected to listen on.
    SOCKS2HTTP_URL = "http://127.0.0.1:8080"

    # One day in milliseconds (CCXT timestamps are epoch milliseconds).
    _DAY_MS = 86_400_000

    # Proxy variables cleared while talking to Tushare (a domestic service).
    _TUSHARE_BYPASS_ENV_KEYS = (
        "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
        "http_proxy", "https_proxy", "all_proxy",
    )

    def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
        self.ssh_config = ssh_config or {}
        # NOTE(review): use_cache is stored but not read anywhere in this class.
        self.use_cache = use_cache
        self._tunnel: Optional[SSHTunnelManager] = None
        self._tushare_token: Optional[str] = None

    # ------------------------------------------------------------------ routing

    def _is_china_index(self, code: str) -> bool:
        """True for China A-share index codes (.SH / .SZ / .SS suffix)."""
        return code.endswith((".SH", ".SZ", ".SS"))

    def _is_crypto(self, code: str) -> bool:
        """True for codes routed to CCXT (keys of CCXT_CODE_MAP)."""
        return code in self.CCXT_CODE_MAP

    def _source_name(self, code: str) -> str:
        """Display name of the backend fetch_single() uses for *code*."""
        if self._is_china_index(code):
            return "Tushare"
        if self._is_crypto(code):
            return "CCXT/OKX"
        return "YFinance"

    @staticmethod
    def _normalize_date_index(index) -> pd.DatetimeIndex:
        """Return *index* as a tz-naive DatetimeIndex truncated to midnight.

        A tz-aware index keeps its own wall-clock date: the tz is dropped rather
        than converted to UTC first, because converting e.g. Asia/Hong_Kong
        midnight bars to UTC would move them to the previous calendar day.
        The index name is preserved.
        """
        idx = pd.DatetimeIndex(pd.to_datetime(index))
        if idx.tz is not None:
            idx = idx.tz_localize(None)
        return idx.normalize()

    # --------------------------------------------------------------- backends

    def _get_tushare_token(self) -> str:
        """Return the Tushare token from $TUSHARE_TOKEN (cached after first read).

        Raises:
            ValueError: if TUSHARE_TOKEN is unset or empty.
        """
        if self._tushare_token is None:
            self._tushare_token = os.getenv("TUSHARE_TOKEN")
        if not self._tushare_token:
            raise ValueError("请设置环境变量 TUSHARE_TOKEN")
        return self._tushare_token

    def _fetch_tushare(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch daily index bars from Tushare ``index_daily`` (start/end inclusive).

        Proxy env vars are removed for the duration of the call so the request
        goes direct, then restored. Returns a frame indexed by ``date`` (sorted)
        with a ``code`` column and ``vol`` renamed to ``volume``, or None on
        error / empty result.
        """
        saved = {key: os.environ.pop(key, None) for key in self._TUSHARE_BYPASS_ENV_KEYS}
        try:
            import tushare as ts

            pro = ts.pro_api(self._get_tushare_token())
            ts_code = code.replace(".SS", ".SH")  # YFinance-style .SS -> Tushare .SH
            df = pro.index_daily(
                ts_code=ts_code,
                start_date=start_date.replace("-", ""),
                end_date=end_date.replace("-", ""),
            )
            if df is None or len(df) == 0:
                return None

            df = df.rename(columns={"trade_date": "date", "vol": "volume"})
            df["date"] = pd.to_datetime(df["date"])
            df = df.set_index("date").sort_index()
            df["code"] = code
            return df
        except Exception as e:
            print(f"Tushare 下载 {code} 失败: {e}")
            return None
        finally:
            for key, value in saved.items():
                if value is not None:
                    os.environ[key] = value

    def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch daily bars from YFinance (code mapped through YF_CODE_MAP).

        Both start_date and end_date are included. Returns yfinance's history()
        frame with OHLCV columns lower-cased and a ``code`` column added, or None
        on error / empty result.
        """
        yf_code = self.YF_CODE_MAP.get(code, code)
        time.sleep(0.5)  # crude throttle between consecutive YFinance requests
        try:
            # yfinance treats ``end`` as exclusive; shift it one day so end_date
            # is included, matching the Tushare and CCXT paths.
            end_exclusive = (pd.Timestamp(end_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
            data = yf.Ticker(yf_code).history(start=start_date, end=end_exclusive)
            if data is None or data.empty:
                return None
            data = data.rename(columns={
                "Open": "open",
                "High": "high",
                "Low": "low",
                "Close": "close",
                "Volume": "volume",
            })
            data["code"] = code
            return data
        except Exception as e:
            print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}")
            return None

    def _fetch_ccxt(
        self, code: str, start_date: str, end_date: str, http_proxy: Optional[str] = None
    ) -> Optional[pd.DataFrame]:
        """Fetch daily OKX candles via CCXT, paging 100 candles per request.

        Args:
            http_proxy: optional HTTP proxy URL used for both http and https.

        Returns a frame indexed by UTC calendar ``date`` with open/high/low/close/
        volume and ``code`` columns, restricted to [start_date, end_date], or
        None on error / no candles.
        """
        ccxt_code = self.CCXT_CODE_MAP.get(code, code)
        config = {"enableRateLimit": True}
        if http_proxy:
            config["proxies"] = {"http": http_proxy, "https": http_proxy}

        try:
            import ccxt

            exchange = ccxt.okx(config)
            start_ts = pd.Timestamp(start_date)
            end_ts = pd.Timestamp(end_date)
            since = int(start_ts.timestamp() * 1000)
            end_ms = int(end_ts.timestamp() * 1000)

            rows = []
            while since < end_ms:
                batch = exchange.fetch_ohlcv(ccxt_code, "1d", since, 100)
                if not batch:
                    break
                rows.extend(batch)
                next_since = batch[-1][0] + self._DAY_MS
                # Guard against an exchange that ignores `since` and keeps
                # returning the same latest candles: without progress, stop.
                if next_since <= since:
                    break
                since = next_since

            if not rows:
                return None

            df = pd.DataFrame(rows, columns=["timestamp", "open", "high", "low", "close", "volume"])
            df.index = (
                pd.DatetimeIndex(pd.to_datetime(df["timestamp"], unit="ms", utc=True))
                .tz_localize(None)
                .normalize()
            )
            df.index.name = "date"
            df = df[["open", "high", "low", "close", "volume"]].sort_index()
            # Overlapping pages can repeat a candle; keep the first occurrence.
            df = df[~df.index.duplicated(keep="first")]
            df = df[(df.index >= start_ts) & (df.index <= end_ts)].copy()
            df["code"] = code
            return df
        except Exception as e:
            print(f"CCXT 下载 {code} ({ccxt_code}) 失败: {e}")
            return None

    def fetch_single(
        self, code: str, start_date: str, end_date: str, http_proxy: Optional[str] = None
    ) -> Optional[pd.DataFrame]:
        """Fetch one code from its backend; http_proxy is only used for crypto."""
        if self._is_china_index(code):
            return self._fetch_tushare(code, start_date, end_date)
        if self._is_crypto(code):
            return self._fetch_ccxt(code, start_date, end_date, http_proxy)
        return self._fetch_yfinance(code, start_date, end_date)

    # ------------------------------------------------------------ socks2http

    def _start_socks2http(self) -> Tuple[Optional[subprocess.Popen], Optional[str]]:
        """Launch socks2http.py (next to this file) as a child process.

        Returns (process, proxy_url), or (None, None) if it could not be launched
        or exited during the startup wait.
        """
        print("\n 启动 socks2http 代理服务(用于加密货币)...")
        try:
            script = Path(__file__).parent / "socks2http.py"
            # Its output is never read: discard it instead of piping, since an
            # unread pipe eventually fills up and blocks the proxy on write.
            proc = subprocess.Popen(
                [sys.executable, str(script)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except Exception as e:
            print(f" ✗ 启动代理失败: {e}")
            return None, None

        time.sleep(2)  # give the proxy time to bind its port
        if proc.poll() is not None:
            print(f" ✗ 启动代理失败: socks2http 已退出 (返回码 {proc.returncode})")
            return None, None
        print(f" ✓ HTTP 代理已启动: {self.SOCKS2HTTP_URL}")
        return proc, self.SOCKS2HTTP_URL

    @staticmethod
    def _stop_socks2http(proc: subprocess.Popen) -> None:
        """Terminate the proxy process, killing it if it ignores SIGTERM for 5 s."""
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        print("\n socks2http 代理已关闭")

    # ------------------------------------------------------------- assembly

    def _align_to_china_calendar(self, etf_data: pd.DataFrame, valid_codes: list) -> pd.DataFrame:
        """Reindex *etf_data* onto the first China index's trading days.

        Non-China columns are forward-filled (carry the last close over their own
        market's holidays) and then back-filled for leading gaps. Returned
        unchanged when no China index is present.
        """
        tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
        if not tushare_codes:
            return etf_data

        primary_dates = etf_data[tushare_codes[0]].dropna().index
        print(f" 主市场交易日: {len(primary_dates)}")
        etf_data = etf_data.reindex(primary_dates)

        other_codes = [c for c in valid_codes if not self._is_china_index(c)]
        for code in other_codes:
            if code in etf_data.columns:
                # NOTE(review): bfill fills leading gaps with *later* prices,
                # i.e. look-ahead in a backtest; kept for compatibility.
                etf_data[code] = etf_data[code].ffill().bfill()
        print(f" 非主市场标的: {len(other_codes)} 只 (已前向填充)")
        return etf_data

    def fetch_all(
        self,
        code_list,  # list of codes, or dict {code: display name}
        benchmark_code: str,
        start_date: str,
        end_date: str,
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
        """Download every code plus the benchmark and align closes into one table.

        Tushare (A-shares), YFinance (HK/US) and CCXT (crypto) follow different
        trading calendars. With a single backend the closes are pivoted as-is.
        With several backends and at least one China index, rows are restricted
        to the first China index's trading days and the other codes are filled
        onto that calendar.

        Returns:
            (etf_data, benchmark_data, valid_codes)
            etf_data: wide DataFrame of close prices; index = tz-naive midnight
                dates named ``date``, one column per code that returned data.
            benchmark_data: the backend frame for benchmark_code with a
                normalised date index, or None if it could not be fetched.
            valid_codes: codes that returned data, in request order.
            (None, None, []) when no code returned data.
        """
        if isinstance(code_list, dict):
            codes = list(code_list.keys())
            code_name_map = code_list
        else:
            codes = list(code_list)
            code_name_map = {c: c for c in codes}

        print(f"开始下载 {len(codes)} 只标的的数据...")
        china_codes = [c for c in codes if self._is_china_index(c)]
        global_codes = [c for c in codes if not self._is_china_index(c)]
        print(f" 中国A股指数: {len(china_codes)}")
        print(f" 港股/美股/加密货币: {len(global_codes)}")

        # socks2http chains onto the SSH SOCKS tunnel, so it is only useful when
        # a tunnel is up and some code (or the benchmark) is crypto.
        socks2http_proc = None
        http_proxy = None
        needs_proxy = any(self._is_crypto(c) for c in codes) or self._is_crypto(benchmark_code)
        if needs_proxy and self._tunnel is not None:
            socks2http_proc, http_proxy = self._start_socks2http()

        all_data = []
        valid_codes = []
        benchmark_data = None
        try:
            for code in codes:
                source = self._source_name(code)
                name = code_name_map.get(code, code)
                print(f" 下载 {code} ({name}) - {source}...", end=" ")
                proxy = http_proxy if self._is_crypto(code) else None
                data = self.fetch_single(code, start_date, end_date, proxy)
                if data is None or len(data) == 0:
                    print("✗ 无数据")
                    continue
                data = data.copy()
                data["source"] = source
                # Uniform index name so concat/reset_index always yields "date"
                # (YFinance names its index "Date").
                data.index = self._normalize_date_index(data.index).rename("date")
                all_data.append(data[["code", "close", "source"]])
                valid_codes.append(code)
                print(f"{len(data)}")

            # Fetch the benchmark while the proxy (if any) is still running.
            if all_data:
                bench_proxy = http_proxy if self._is_crypto(benchmark_code) else None
                benchmark_data = self.fetch_single(benchmark_code, start_date, end_date, bench_proxy)
        finally:
            # Always stop the helper process, even if a download raised.
            if socks2http_proc is not None:
                self._stop_socks2http(socks2http_proc)

        if not all_data:
            return None, None, []

        sources = {frame["source"].iloc[0] for frame in all_data}
        long_df = pd.concat(all_data).reset_index()
        etf_data = long_df.pivot_table(
            index="date",
            columns="code",
            values="close",
            aggfunc="first",
        )

        if len(sources) == 1:
            print(f"\n数据整理完成 (单一数据源 {next(iter(sources))}):")
        else:
            print("\n数据整理完成 (多数据源 - 以A股交易日为基准):")
            etf_data = self._align_to_china_calendar(etf_data, valid_codes)

        if not etf_data.empty:
            print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
        print(f" 交易日数: {len(etf_data)}")
        print(f" 有效标的: {len(etf_data.columns)}")

        if benchmark_data is not None:
            benchmark_data.index = self._normalize_date_index(benchmark_data.index)
            print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)}")
        return etf_data, benchmark_data, valid_codes

    # -------------------------------------------------------- context manager

    def __enter__(self):
        """Open the SSH tunnel when ssh_config["enabled"] is truthy."""
        if self.ssh_config.get("enabled"):
            tunnel = SSHTunnelManager(self.ssh_config)
            if tunnel.start():
                self._tunnel = tunnel
            else:
                # Tunnel unusable: clean up now and leave _tunnel unset so
                # fetch_all() does not chain socks2http onto a missing tunnel.
                tunnel.stop()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the SSH tunnel if one was opened; exceptions are not suppressed."""
        if self._tunnel is not None:
            self._tunnel.stop()
            self._tunnel = None