fix(datasource): 修正混合数据源导入路径错误

- 修正 strategies.rotation.engine 中 hybrid_source 模块导入路径错误
- 新增 core.datasource 目录下多个数据源实现模块
- 增加 Akshare 数据源支持 A股指数数据拉取
- 实现数据缓存管理机制,支持本地数据缓存读写
- 新增 YFinance 数据源,支持通过 SSH 隧道访问美股和港股数据
- 实现混合数据源支持 A股/Tushare、港美股/YFinance、加密货币/CCXT 的统一访问
- 集成 SSH 隧道管理,支持 SOCKS5 转 HTTP 代理转发
- 新增 socks2http.py 代理转发工具,解决 CCXT 仅支持 HTTP 代理问题
- 修改 rotation.yaml 加密货币注释,明确使用 OKX 现货和 SSH->HTTP 代理访问
- 删除.gitignore中无用的 data/ 忽略规则,保留 test/ 文件夹忽略规则
This commit is contained in:
2026-03-25 01:32:33 +08:00
parent c104fca693
commit 6454e6823f
10 changed files with 1083 additions and 4 deletions

1
.gitignore vendored
View File

@@ -177,7 +177,6 @@ config_local.py
*.backup
data/
test/
# Cache and generated files

View File

@@ -34,8 +34,10 @@ code_list:
"HSTECH": "恒生科技" # 港股
"NDX": "纳指100" # 美股
"GC=F": "黄金" # 黄金期货 (COMEX)
"BTC": "比特币" # 加密货币
"ETH": "以太坊" # 加密货币
# 加密货币 (使用 CCXT/OKX 现货) - 通过 SSH->HTTP 代理访问
"BTC": "比特币" # OKX 现货
"ETH": "以太坊" # OKX 现货
# 主市场配置(用于确定交易日历)
primary_market:

View File

View File

@@ -0,0 +1,128 @@
"""
akshare数据源实现用于CCI筛选等场景
"""
import time
import pandas as pd
import akshare as ak
from typing import Optional
from .base import DataSource
class AkshareDataSource(DataSource):
"""基于akshare的数据源"""
def __init__(self, delay: float = 3.0):
"""
初始化akshare数据源
Args:
delay: 每次请求间隔秒数(避免触发限流)
"""
self.delay = delay
def fetch_ohlcv(
self,
code: str,
start_date: str,
end_date: str,
fields: Optional[list] = None,
) -> pd.DataFrame:
"""
获取指数历史数据使用akshare的东方财富接口
Args:
code: 指数代码,如 '000300'(不带后缀)
start_date: 起始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DataFrame包含 date, open, high, low, close, volume
"""
# 转换日期格式
sd = start_date.replace("-", "")
ed = end_date.replace("-", "")
# 去除后缀
symbol = code.replace(".SH", "").replace(".SZ", "")
time.sleep(self.delay)
try:
df = ak.index_zh_a_hist(
symbol=symbol,
period="daily",
start_date=sd,
end_date=ed,
)
except Exception as e:
raise RuntimeError(f"akshare查询失败 [{code}]: {e}")
if df is None or df.empty:
raise ValueError(f"akshare返回空数据: {code}")
# 统一列名
df = df.rename(
columns={
"日期": "date",
"开盘": "open",
"最高": "high",
"最低": "low",
"收盘": "close",
"成交量": "volume",
}
)
df["date"] = pd.to_datetime(df["date"])
df["code"] = code
df = df[["date", "code", "open", "high", "low", "close", "volume"]]
df = df.sort_values("date").reset_index(drop=True)
return df
def fetch_multiple(
self,
codes: list,
start_date: str,
end_date: str,
) -> tuple[pd.DataFrame, list]:
"""批量获取数据"""
df_list = []
failed = []
for code in codes:
try:
df = self.fetch_ohlcv(code, start_date, end_date)
df_list.append(df[["date", "code", "close"]].copy())
except Exception as e:
print(f" ⚠ 跳过 {code}: {e}")
failed.append(code)
if not df_list:
raise RuntimeError("所有数据获取失败")
all_df = pd.concat(df_list, ignore_index=True)
data = all_df.pivot(index="date", columns="code", values="close")
data = data.sort_index()
return data, failed
def fetch_index_list(self) -> pd.DataFrame:
"""
获取所有A股指数列表
Returns:
DataFrame with index info
"""
sources = ["沪深重要指数", "上证系列指数", "深证系列指数", "中证系列指数"]
df_list = []
for source in sources:
try:
df = ak.stock_zh_index_spot_em(symbol=source)
df["source"] = source
df_list.append(df)
except Exception as e:
print(f" ⚠ 获取 {source} 失败: {e}")
return pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame()

53
core/datasource/base.py Normal file
View File

@@ -0,0 +1,53 @@
"""
数据源抽象基类
"""
from abc import ABC, abstractmethod
from typing import Optional
import pandas as pd
class DataSource(ABC):
"""数据源抽象基类"""
@abstractmethod
def fetch_ohlcv(
self,
code: str,
start_date: str,
end_date: str,
fields: Optional[list] = None,
) -> pd.DataFrame:
"""
获取OHLCV数据
Args:
code: 标的代码
start_date: 开始日期 (YYYY-MM-DD)
end_date: 结束日期 (YYYY-MM-DD)
fields: 指定字段列表None表示获取全部
Returns:
DataFrame with columns: date, open, high, low, close, volume
"""
pass
@abstractmethod
def fetch_multiple(
self,
codes: list,
start_date: str,
end_date: str,
) -> pd.DataFrame:
"""
批量获取多只标的收盘价数据
Args:
codes: 标的代码列表
start_date: 开始日期
end_date: 结束日期
Returns:
DataFrame, index=日期, columns=代码, values=收盘价
"""
pass

54
core/datasource/cache.py Normal file
View File

@@ -0,0 +1,54 @@
"""
数据缓存管理模块
"""
import os
import pandas as pd
from pathlib import Path
from typing import Optional
class DataCache:
"""CSV文件缓存管理器"""
def __init__(self, cache_dir: str = "data_cache"):
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(exist_ok=True)
def _get_cache_path(self, code: str, start_date: str, end_date: str) -> Path:
"""生成缓存文件路径"""
# 统一日期格式为 YYYYMMDD
sd = start_date.replace("-", "")
ed = end_date.replace("-", "")
safe_code = code.replace(".", "_")
return self.cache_dir / f"{safe_code}_{sd}_{ed}.csv"
def get(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
从缓存读取数据
Returns:
DataFrame or None缓存不存在
"""
cache_path = self._get_cache_path(code, start_date, end_date)
if cache_path.exists():
df = pd.read_csv(cache_path)
df["date"] = pd.to_datetime(df["date"])
return df
return None
def set(self, code: str, start_date: str, end_date: str, df: pd.DataFrame) -> None:
"""保存数据到缓存"""
cache_path = self._get_cache_path(code, start_date, end_date)
df.to_csv(cache_path, index=False)
def clear(self, code: str = None) -> None:
"""清除缓存"""
if code:
# 清除指定代码的缓存
for f in self.cache_dir.glob(f"{code.replace('.', '_')}*.csv"):
f.unlink()
else:
# 清除所有缓存
for f in self.cache_dir.glob("*.csv"):
f.unlink()

View File

@@ -0,0 +1,481 @@
"""
混合数据源模块
- 中国A股指数: Tushare
- 港股/美股: YFinance (支持 SSH 隧道)
- 加密货币: CCXT/OKX (支持 SSH->HTTP 代理)
"""
import os
import sys
import time
import subprocess
from pathlib import Path
from typing import Optional, Tuple, Dict
from datetime import datetime
import pandas as pd
import yfinance as yf
import urllib3
# 禁用 SSL 警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class SSHTunnelManager:
"""SSH 隧道管理器"""
def __init__(self, config: dict):
self.enabled = config.get("enabled", False)
self.host = config.get("host", "")
self.port = config.get("port", 22)
self.username = config.get("username", "")
self.local_port = config.get("local_port", 1080)
self._process: Optional[subprocess.Popen] = None
# 处理 key_path如果是相对路径转换为绝对路径
key_path = config.get("key_path", "")
if key_path and not os.path.isabs(key_path):
# 相对于项目根目录
project_root = Path(__file__).parent.parent.parent
key_path = str(project_root / key_path)
self.key_path = key_path
print(f"SSH 私钥路径: {self.key_path}")
def start(self) -> bool:
"""启动 SSH 隧道"""
if not self.enabled:
return True
if not all([self.host, self.username, self.key_path]):
print("SSH 配置不完整,跳过隧道建立")
return False
print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
cmd = [
"ssh", "-N", "-D", f"127.0.0.1:{self.local_port}",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-i", self.key_path,
"-p", str(self.port),
f"{self.username}@{self.host}"
]
try:
self._process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
time.sleep(2)
if self._process.poll() is not None:
stdout, stderr = self._process.communicate()
print("✗ SSH 隧道建立失败")
if stderr:
print(f"错误: {stderr.decode()}")
return False
# 设置代理环境变量
proxy_url = f"socks5://127.0.0.1:{self.local_port}"
os.environ["HTTP_PROXY"] = proxy_url
os.environ["HTTPS_PROXY"] = proxy_url
os.environ["ALL_PROXY"] = proxy_url
print(f"✓ SSH 隧道已建立: {proxy_url}")
time.sleep(1)
return True
except Exception as e:
print(f"✗ SSH 隧道异常: {e}")
return False
def stop(self):
"""停止 SSH 隧道"""
if self._process:
self._process.terminate()
self._process.wait()
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
os.environ.pop(key, None)
print("SSH 隧道已关闭")
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
class HybridDataSource:
"""
混合数据源
- 中国A股指数 (SH/SZ): Tushare
- 港股/美股: YFinance
- 加密货币: CCXT/OKX (通过 SSH->HTTP 代理)
"""
# YFinance 代码映射 (代码 -> YFinance格式)
YF_CODE_MAP = {
# 港股
"HSTECH": "3033.HK", # 恒生科技指数 ETF
"HSI": "^HSI", # 恒生指数
# 美股指数
"NDX": "^NDX", # 纳斯达克100
"SPX": "^GSPC", # 标普500
"DJI": "^DJI", # 道琼斯
# 黄金
"GC=F": "GC=F", # 黄金期货
}
# CCXT 代码映射 (代码 -> CCXT格式)
CCXT_CODE_MAP = {
"BTC": "BTC/USDT", # OKX 比特币现货
"ETH": "ETH/USDT", # OKX 以太坊现货
}
def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
self.ssh_config = ssh_config or {}
self.use_cache = use_cache
self._tunnel: Optional[SSHTunnelManager] = None
self._tushare_token: Optional[str] = None
def _is_china_index(self, code: str) -> bool:
"""判断是否为中国A股指数"""
return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS")
def _is_crypto(self, code: str) -> bool:
"""判断是否为加密货币"""
return code in self.CCXT_CODE_MAP
def _get_tushare_token(self) -> str:
"""获取 Tushare Token"""
if self._tushare_token is None:
import os
self._tushare_token = os.getenv("TUSHARE_TOKEN")
if not self._tushare_token:
raise ValueError("请设置环境变量 TUSHARE_TOKEN")
return self._tushare_token
def _fetch_tushare(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""使用 Tushare 获取中国指数数据(不使用代理,直连国内服务器)"""
import os
# 临时清除代理环境变量Tushare 是国内服务,不需要代理)
original_proxy = {}
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
original_proxy[key] = os.environ.pop(key, None)
try:
import tushare as ts
pro = ts.pro_api(self._get_tushare_token())
# 转换代码格式 (000300.SS -> 000300.SH)
ts_code = code.replace(".SS", ".SH")
# 获取日线数据
df = pro.index_daily(
ts_code=ts_code,
start_date=start_date.replace("-", ""),
end_date=end_date.replace("-", "")
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
"trade_date": "date",
"open": "open",
"high": "high",
"low": "low",
"close": "close",
"vol": "volume",
})
# 转换日期格式
df["date"] = pd.to_datetime(df["date"])
df = df.set_index("date")
df = df.sort_index()
# 添加代码列
df["code"] = code
return df
except Exception as e:
print(f"Tushare 下载 {code} 失败: {e}")
return None
finally:
# 恢复代理环境变量
for key, value in original_proxy.items():
if value is not None:
os.environ[key] = value
def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""使用 YFinance 获取数据"""
import time
# 转换代码格式
yf_code = self.YF_CODE_MAP.get(code, code)
# 添加延迟以避免限流
time.sleep(0.5)
try:
ticker = yf.Ticker(yf_code)
data = ticker.history(start=start_date, end=end_date)
if len(data) == 0:
return None
# 标准化列名
data = data.rename(columns={
"Open": "open",
"High": "high",
"Low": "low",
"Close": "close",
"Volume": "volume",
})
# 添加代码列
data["code"] = code
return data
except Exception as e:
print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}")
return None
def _fetch_ccxt(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]:
"""使用 CCXT/OKX 获取加密货币数据(支持 HTTP 代理)"""
import ccxt
# 转换代码格式
ccxt_code = self.CCXT_CODE_MAP.get(code, code)
# 配置 CCXT
config = {'enableRateLimit': True}
if http_proxy:
config['proxies'] = {'http': http_proxy, 'https': http_proxy}
try:
exchange = ccxt.okx(config)
# 获取日线数据
since = int(pd.Timestamp(start_date).timestamp() * 1000)
all_ohlcv = []
limit = 100
while since < int(pd.Timestamp(end_date).timestamp() * 1000):
ohlcv = exchange.fetch_ohlcv(ccxt_code, '1d', since, limit)
if not ohlcv:
break
all_ohlcv.extend(ohlcv)
since = ohlcv[-1][0] + 86400000
if not all_ohlcv:
return None
# 转换为 DataFrame
df = pd.DataFrame(all_ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
# 转换时间戳为日期索引
df.index = pd.DatetimeIndex(pd.to_datetime(df['timestamp'], unit='ms', utc=True)).tz_localize(None).normalize()
df.index.name = 'date'
df = df[['open', 'high', 'low', 'close', 'volume']]
# 过滤日期范围
start_ts = pd.Timestamp(start_date)
end_ts = pd.Timestamp(end_date)
df = df[(df.index >= start_ts) & (df.index <= end_ts)]
df['code'] = code
return df
except Exception as e:
print(f"CCXT 下载 {code} ({ccxt_code}) 失败: {e}")
return None
def fetch_single(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]:
"""获取单个标的的数据"""
if self._is_china_index(code):
return self._fetch_tushare(code, start_date, end_date)
elif self._is_crypto(code):
return self._fetch_ccxt(code, start_date, end_date, http_proxy)
else:
return self._fetch_yfinance(code, start_date, end_date)
def fetch_all(
self,
code_list, # list[代码] 或 dict{代码: 名称}
benchmark_code: str,
start_date: str,
end_date: str,
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
"""
批量获取数据
注意:由于 Tushare(中国A股) 和 YFinance(美股/加密货币) 的交易日历不同,
这里返回的是长格式数据,由调用方分别处理各市场的数据
Returns:
(etf_data, benchmark_data, valid_codes)
etf_data: DataFrame with columns [code, close, source], index=date
"""
all_data = []
valid_codes = []
# 兼容列表和字典格式
if isinstance(code_list, dict):
codes = list(code_list.keys())
code_name_map = code_list
else:
codes = code_list
code_name_map = {c: c for c in codes}
print(f"开始下载 {len(codes)} 只标的的数据...")
china_codes = [c for c in codes if self._is_china_index(c)]
global_codes = [c for c in codes if not self._is_china_index(c)]
print(f" 中国A股指数: {len(china_codes)}")
print(f" 港股/美股/加密货币: {len(global_codes)}")
# 检查是否需要启动 socks2http 代理(用于加密货币)
crypto_codes = [c for c in codes if self._is_crypto(c)]
http_proxy = None
socks2http_proc = None
# 只有在 SSH 隧道已建立时才启动 socks2http
if crypto_codes and self._tunnel is not None:
import subprocess
import time
print(f"\n 启动 socks2http 代理服务(用于加密货币)...")
try:
# 启动 socks2http.py 作为子进程
socks2http_path = Path(__file__).parent / "socks2http.py"
socks2http_proc = subprocess.Popen(
[sys.executable, str(socks2http_path)],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
time.sleep(2) # 等待代理启动
http_proxy = "http://127.0.0.1:8080"
print(f" ✓ HTTP 代理已启动: {http_proxy}")
except Exception as e:
print(f" ✗ 启动代理失败: {e}")
# 分别下载数据
for code in codes:
if self._is_china_index(code):
source = "Tushare"
elif self._is_crypto(code):
source = "CCXT/OKX"
else:
source = "YFinance"
name = code_name_map.get(code, code)
print(f" 下载 {code} ({name}) - {source}...", end=" ")
# 加密货币使用 HTTP 代理
proxy = http_proxy if self._is_crypto(code) else None
data = self.fetch_single(code, start_date, end_date, proxy)
if data is not None and len(data) > 0:
# 标准化数据格式
data = data.copy()
data['source'] = source
# 确保索引是日期格式且无时区,只保留日期部分(去掉时间)
data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
all_data.append(data[['code', 'close', 'source']])
valid_codes.append(code)
print(f"{len(data)}")
else:
print("✗ 无数据")
# 关闭 socks2http 代理
if socks2http_proc:
socks2http_proc.terminate()
socks2http_proc.wait()
print(f"\n socks2http 代理已关闭")
if not all_data:
return None, None, []
# 检查数据源类型
sources = set(d['source'].iloc[0] for d in all_data)
if len(sources) == 1:
# 单一数据源:转换为宽格式(向后兼容)
all_df = pd.concat(all_data, ignore_index=False)
all_df = all_df.reset_index()
all_df['date'] = pd.to_datetime(all_df['date'], utc=True).dt.tz_localize(None)
etf_data = all_df.pivot_table(
index='date',
columns='code',
values='close',
aggfunc='first'
)
print(f"\n数据整理完成 (单一数据源 {list(sources)[0]}):")
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
print(f" 交易日数: {len(etf_data)}")
print(f" 有效标的: {len(etf_data.columns)}")
else:
# 多数据源以主市场Tushare/A股为基准其他市场数据前向填充
print(f"\n数据整理完成 (多数据源 - 以A股交易日为基准):")
# 合并所有数据(索引已经是标准化后的日期)
all_df = pd.concat(all_data, ignore_index=False)
all_df = all_df.reset_index()
# 重命名索引列为 date
if 'index' in all_df.columns:
all_df = all_df.rename(columns={'index': 'date'})
# 确保 date 列是日期格式(不含时间)
all_df['date'] = pd.to_datetime(all_df['date']).dt.normalize()
# 透视为宽格式
etf_data = all_df.pivot_table(
index='date',
columns='code',
values='close',
aggfunc='first'
)
# 获取主市场Tushare的交易日历
tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
if tushare_codes:
# 使用第一个A股代码的日期作为主市场交易日
primary_dates = etf_data[tushare_codes[0]].dropna().index
print(f" 主市场交易日: {len(primary_dates)}")
# 重新索引到主市场交易日,使用前向填充
etf_data = etf_data.reindex(primary_dates)
# 对每个非主市场代码进行前向填充
yfinance_codes = [c for c in valid_codes if not self._is_china_index(c)]
for code in yfinance_codes:
if code in etf_data.columns:
# 前向填充:用最近的有效价格填充休市日的数据
etf_data[code] = etf_data[code].ffill()
# 对于开头的NaN用后向填充
etf_data[code] = etf_data[code].bfill()
print(f" 非主市场标的: {len(yfinance_codes)} 只 (已前向填充)")
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
print(f" 交易日数: {len(etf_data)}")
print(f" 有效标的: {len(etf_data.columns)}")
# 获取基准数据
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
if benchmark_data is not None:
# 标准化日期索引(无时区,只保留日期部分)
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)}")
return etf_data, benchmark_data, valid_codes
def __enter__(self):
"""上下文管理器入口"""
if self.ssh_config.get("enabled"):
self._tunnel = SSHTunnelManager(self.ssh_config)
self._tunnel.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器出口"""
if self._tunnel:
self._tunnel.stop()

View File

@@ -0,0 +1,168 @@
#!/usr/bin/env python3
"""
SOCKS5 转 HTTP 代理转发工具
将 SSH 隧道的 SOCKS5 代理 (1080) 转为 HTTP 代理 (8080)
供 CCXT 等只支持 HTTP 代理的库使用
"""
import socket
import threading
import select
import sys
from urllib.parse import urlparse
class Socks2Http:
def __init__(self, socks_host='127.0.0.1', socks_port=1080, http_port=8080):
self.socks_host = socks_host
self.socks_port = socks_port
self.http_port = http_port
self.server = None
self.running = False
def start(self):
"""启动 HTTP 代理服务器"""
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.server.bind(('127.0.0.1', self.http_port))
self.server.listen(5)
self.running = True
print(f"HTTP 代理已启动: http://127.0.0.1:{self.http_port}")
print(f"转发到 SOCKS5: {self.socks_host}:{self.socks_port}")
while self.running:
try:
client, addr = self.server.accept()
thread = threading.Thread(target=self._handle_client, args=(client,))
thread.daemon = True
thread.start()
except Exception as e:
if self.running:
print(f"接受连接错误: {e}")
def stop(self):
"""停止代理服务器"""
self.running = False
if self.server:
self.server.close()
print("HTTP 代理已停止")
def _handle_client(self, client):
"""处理客户端连接"""
try:
# 读取 HTTP 请求
request = client.recv(4096)
if not request:
client.close()
return
# 解析 CONNECT 请求
first_line = request.split(b'\r\n')[0].decode('utf-8', errors='ignore')
if first_line.startswith('CONNECT'):
# HTTPS 代理
parts = first_line.split()
if len(parts) >= 2:
target = parts[1]
host, port = target.rsplit(':', 1)
port = int(port)
# 连接到 SOCKS5 代理
remote = self._connect_via_socks5(host, port)
if remote:
client.send(b'HTTP/1.1 200 Connection established\r\n\r\n')
self._relay(client, remote)
else:
client.send(b'HTTP/1.1 502 Bad Gateway\r\n\r\n')
else:
# HTTP 代理
lines = first_line.split()
if len(lines) >= 2:
url = lines[1]
parsed = urlparse(url)
host = parsed.hostname
port = parsed.port or 80
# 连接到 SOCKS5 代理
remote = self._connect_via_socks5(host, port)
if remote:
# 修改请求,去掉完整 URL
new_request = request.replace(
f'{lines[0]} {url} '.encode(),
f'{lines[0]} {parsed.path or "/"}{"?" + parsed.query if parsed.query else ""} '.encode()
)
remote.send(new_request)
self._relay(client, remote)
except Exception as e:
print(f"处理客户端错误: {e}")
finally:
client.close()
def _connect_via_socks5(self, host, port):
"""通过 SOCKS5 代理连接目标服务器"""
try:
# 连接到 SOCKS5 代理
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(30)
sock.connect((self.socks_host, self.socks_port))
# SOCKS5 握手
# 1. 发送认证方法
sock.send(b'\x05\x01\x00') # VER=5, NMETHODS=1, METHOD=0 (无认证)
resp = sock.recv(2)
if resp[0] != 0x05 or resp[1] != 0x00:
sock.close()
return None
# 2. 发送连接请求
req = b'\x05\x01\x00\x03' # VER=5, CMD=CONNECT, RSV=0, ATYP=DOMAIN
req += bytes([len(host)]) + host.encode()
req += bytes([(port >> 8) & 0xFF, port & 0xFF])
sock.send(req)
# 3. 读取响应
resp = sock.recv(10)
if len(resp) < 4 or resp[1] != 0x00:
sock.close()
return None
return sock
except Exception as e:
print(f"SOCKS5 连接错误: {e}")
return None
def _relay(self, client, remote):
"""双向转发数据"""
try:
while True:
readable, _, _ = select.select([client, remote], [], [], 60)
if not readable:
break
if client in readable:
data = client.recv(4096)
if not data:
break
remote.send(data)
if remote in readable:
data = remote.recv(4096)
if not data:
break
client.send(data)
except:
pass
finally:
client.close()
remote.close()
if __name__ == '__main__':
proxy = Socks2Http(socks_port=1080, http_port=8080)
try:
proxy.start()
except KeyboardInterrupt:
proxy.stop()

View File

@@ -0,0 +1,194 @@
"""
YFinance 数据源模块
支持通过 SSH 隧道访问(用于绕过网络限制)
"""
import os
import time
import subprocess
from typing import Optional, Tuple
from datetime import datetime
import pandas as pd
import yfinance as yf
import urllib3
# 禁用 SSL 警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class SSHTunnelManager:
"""SSH 隧道管理器"""
def __init__(self, config: dict):
from pathlib import Path
self.enabled = config.get("enabled", False)
self.host = config.get("host", "")
self.port = config.get("port", 22)
self.username = config.get("username", "")
self.local_port = config.get("local_port", 1080)
self._process: Optional[subprocess.Popen] = None
# 处理 key_path如果是相对路径转换为绝对路径
key_path = config.get("key_path", "")
if key_path and not os.path.isabs(key_path):
# 相对于项目根目录
project_root = Path(__file__).parent.parent.parent
key_path = str(project_root / key_path)
self.key_path = key_path
def start(self) -> bool:
"""启动 SSH 隧道"""
if not self.enabled:
return True
if not all([self.host, self.username, self.key_path]):
print("SSH 配置不完整,跳过隧道建立")
return False
print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
cmd = [
"ssh",
"-N",
"-D", f"127.0.0.1:{self.local_port}",
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-i", self.key_path,
"-p", str(self.port),
f"{self.username}@{self.host}"
]
try:
self._process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
time.sleep(2)
if self._process.poll() is not None:
stdout, stderr = self._process.communicate()
print("✗ SSH 隧道建立失败")
if stderr:
print(f"错误: {stderr.decode()}")
return False
# 设置代理环境变量
proxy_url = f"socks5://127.0.0.1:{self.local_port}"
os.environ["HTTP_PROXY"] = proxy_url
os.environ["HTTPS_PROXY"] = proxy_url
os.environ["ALL_PROXY"] = proxy_url
print(f"✓ SSH 隧道已建立: {proxy_url}")
time.sleep(1)
return True
except Exception as e:
print(f"✗ SSH 隧道异常: {e}")
return False
def stop(self):
"""停止 SSH 隧道"""
if self._process:
self._process.terminate()
self._process.wait()
# 清理环境变量
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
os.environ.pop(key, None)
print("SSH 隧道已关闭")
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
class YFinanceDataSource:
"""YFinance 数据源"""
def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
self.ssh_config = ssh_config or {}
self.use_cache = use_cache
self._tunnel: Optional[SSHTunnelManager] = None
def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""获取单个 ETF 数据"""
try:
ticker = yf.Ticker(code)
data = ticker.history(start=start_date, end=end_date)
if len(data) == 0:
return None
# 标准化列名
data = data.rename(columns={
"Open": "open",
"High": "high",
"Low": "low",
"Close": "close",
"Volume": "volume",
})
# 添加代码列
data["code"] = code
return data
except Exception as e:
print(f"下载 {code} 失败: {e}")
return None
def fetch_all(
self,
code_list: list,
benchmark_code: str,
start_date: str,
end_date: str,
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
"""
批量获取 ETF 数据和基准数据
Returns:
(etf_data, benchmark_data, valid_codes)
"""
all_data = []
valid_codes = []
print(f"开始下载 {len(code_list)} 只 ETF 数据...")
for code in code_list:
data = self.fetch_single(code, start_date, end_date)
if data is not None and len(data) > 0:
all_data.append(data)
valid_codes.append(code)
print(f"{code}: {len(data)}")
else:
print(f"{code}: 无数据")
if not all_data:
return None, None, []
# 合并数据
etf_data = pd.concat(all_data, ignore_index=False)
# 获取基准数据
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
if benchmark_data is not None:
print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)}")
return etf_data, benchmark_data, valid_codes
def __enter__(self):
"""上下文管理器入口"""
if self.ssh_config.get("enabled"):
self._tunnel = SSHTunnelManager(self.ssh_config)
self._tunnel.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""上下文管理器出口"""
if self._tunnel:
self._tunnel.stop()

View File

@@ -10,7 +10,7 @@ import numpy as np
from typing import Optional
from strategies.base import BacktestStrategy
from core.data.hybrid_source import HybridDataSource
from core.datasource.hybrid_source import HybridDataSource
from core.factors.momentum import compute_factors, calculate_daily_return