fix(datasource): 修正混合数据源导入路径错误
- 修正 strategies.rotation.engine 中 hybrid_source 模块导入路径错误 - 新增 core.datasource 目录下多个数据源实现模块 - 增加 Akshare 数据源支持 A股指数数据拉取 - 实现数据缓存管理机制,支持本地数据缓存读写 - 新增 YFinance 数据源,支持通过 SSH 隧道访问美股和港股数据 - 实现混合数据源支持 A股/Tushare、港美股/YFinance、加密货币/CCXT 的统一访问 - 集成 SSH 隧道管理,支持 SOCKS5 转 HTTP 代理转发 - 新增 socks2http.py 代理转发工具,解决 CCXT 仅支持 HTTP 代理问题 - 修改 rotation.yaml 加密货币注释,明确使用 OKX 现货和 SSH->HTTP 代理访问 - 删除.gitignore中无用的 data/ 忽略规则,保留 test/ 文件夹忽略规则
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -177,7 +177,6 @@ config_local.py
|
||||
*.backup
|
||||
|
||||
|
||||
data/
|
||||
test/
|
||||
|
||||
# Cache and generated files
|
||||
|
||||
@@ -34,8 +34,10 @@ code_list:
|
||||
"HSTECH": "恒生科技" # 港股
|
||||
"NDX": "纳指100" # 美股
|
||||
"GC=F": "黄金" # 黄金期货 (COMEX)
|
||||
"BTC": "比特币" # 加密货币
|
||||
"ETH": "以太坊" # 加密货币
|
||||
|
||||
# 加密货币 (使用 CCXT/OKX 现货) - 通过 SSH->HTTP 代理访问
|
||||
"BTC": "比特币" # OKX 现货
|
||||
"ETH": "以太坊" # OKX 现货
|
||||
|
||||
# 主市场配置(用于确定交易日历)
|
||||
primary_market:
|
||||
|
||||
0
core/datasource/__init__.py
Normal file
0
core/datasource/__init__.py
Normal file
128
core/datasource/akshare_source.py
Normal file
128
core/datasource/akshare_source.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
akshare数据源实现(用于CCI筛选等场景)
|
||||
"""
|
||||
|
||||
import time
|
||||
import pandas as pd
|
||||
import akshare as ak
|
||||
from typing import Optional
|
||||
|
||||
from .base import DataSource
|
||||
|
||||
|
||||
class AkshareDataSource(DataSource):
|
||||
"""基于akshare的数据源"""
|
||||
|
||||
def __init__(self, delay: float = 3.0):
|
||||
"""
|
||||
初始化akshare数据源
|
||||
|
||||
Args:
|
||||
delay: 每次请求间隔秒数(避免触发限流)
|
||||
"""
|
||||
self.delay = delay
|
||||
|
||||
def fetch_ohlcv(
|
||||
self,
|
||||
code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
fields: Optional[list] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
获取指数历史数据(使用akshare的东方财富接口)
|
||||
|
||||
Args:
|
||||
code: 指数代码,如 '000300'(不带后缀)
|
||||
start_date: 起始日期 'YYYY-MM-DD'
|
||||
end_date: 结束日期 'YYYY-MM-DD'
|
||||
|
||||
Returns:
|
||||
DataFrame,包含 date, open, high, low, close, volume
|
||||
"""
|
||||
# 转换日期格式
|
||||
sd = start_date.replace("-", "")
|
||||
ed = end_date.replace("-", "")
|
||||
|
||||
# 去除后缀
|
||||
symbol = code.replace(".SH", "").replace(".SZ", "")
|
||||
|
||||
time.sleep(self.delay)
|
||||
|
||||
try:
|
||||
df = ak.index_zh_a_hist(
|
||||
symbol=symbol,
|
||||
period="daily",
|
||||
start_date=sd,
|
||||
end_date=ed,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"akshare查询失败 [{code}]: {e}")
|
||||
|
||||
if df is None or df.empty:
|
||||
raise ValueError(f"akshare返回空数据: {code}")
|
||||
|
||||
# 统一列名
|
||||
df = df.rename(
|
||||
columns={
|
||||
"日期": "date",
|
||||
"开盘": "open",
|
||||
"最高": "high",
|
||||
"最低": "low",
|
||||
"收盘": "close",
|
||||
"成交量": "volume",
|
||||
}
|
||||
)
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
df["code"] = code
|
||||
df = df[["date", "code", "open", "high", "low", "close", "volume"]]
|
||||
df = df.sort_values("date").reset_index(drop=True)
|
||||
|
||||
return df
|
||||
|
||||
def fetch_multiple(
|
||||
self,
|
||||
codes: list,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> tuple[pd.DataFrame, list]:
|
||||
"""批量获取数据"""
|
||||
df_list = []
|
||||
failed = []
|
||||
|
||||
for code in codes:
|
||||
try:
|
||||
df = self.fetch_ohlcv(code, start_date, end_date)
|
||||
df_list.append(df[["date", "code", "close"]].copy())
|
||||
except Exception as e:
|
||||
print(f" ⚠ 跳过 {code}: {e}")
|
||||
failed.append(code)
|
||||
|
||||
if not df_list:
|
||||
raise RuntimeError("所有数据获取失败")
|
||||
|
||||
all_df = pd.concat(df_list, ignore_index=True)
|
||||
data = all_df.pivot(index="date", columns="code", values="close")
|
||||
data = data.sort_index()
|
||||
|
||||
return data, failed
|
||||
|
||||
def fetch_index_list(self) -> pd.DataFrame:
|
||||
"""
|
||||
获取所有A股指数列表
|
||||
|
||||
Returns:
|
||||
DataFrame with index info
|
||||
"""
|
||||
sources = ["沪深重要指数", "上证系列指数", "深证系列指数", "中证系列指数"]
|
||||
df_list = []
|
||||
|
||||
for source in sources:
|
||||
try:
|
||||
df = ak.stock_zh_index_spot_em(symbol=source)
|
||||
df["source"] = source
|
||||
df_list.append(df)
|
||||
except Exception as e:
|
||||
print(f" ⚠ 获取 {source} 失败: {e}")
|
||||
|
||||
return pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame()
|
||||
53
core/datasource/base.py
Normal file
53
core/datasource/base.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
数据源抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class DataSource(ABC):
|
||||
"""数据源抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
def fetch_ohlcv(
|
||||
self,
|
||||
code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
fields: Optional[list] = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
获取OHLCV数据
|
||||
|
||||
Args:
|
||||
code: 标的代码
|
||||
start_date: 开始日期 (YYYY-MM-DD)
|
||||
end_date: 结束日期 (YYYY-MM-DD)
|
||||
fields: 指定字段列表,None表示获取全部
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: date, open, high, low, close, volume
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fetch_multiple(
|
||||
self,
|
||||
codes: list,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
批量获取多只标的收盘价数据
|
||||
|
||||
Args:
|
||||
codes: 标的代码列表
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
|
||||
Returns:
|
||||
DataFrame, index=日期, columns=代码, values=收盘价
|
||||
"""
|
||||
pass
|
||||
54
core/datasource/cache.py
Normal file
54
core/datasource/cache.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
数据缓存管理模块
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class DataCache:
|
||||
"""CSV文件缓存管理器"""
|
||||
|
||||
def __init__(self, cache_dir: str = "data_cache"):
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(exist_ok=True)
|
||||
|
||||
def _get_cache_path(self, code: str, start_date: str, end_date: str) -> Path:
|
||||
"""生成缓存文件路径"""
|
||||
# 统一日期格式为 YYYYMMDD
|
||||
sd = start_date.replace("-", "")
|
||||
ed = end_date.replace("-", "")
|
||||
safe_code = code.replace(".", "_")
|
||||
return self.cache_dir / f"{safe_code}_{sd}_{ed}.csv"
|
||||
|
||||
def get(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
从缓存读取数据
|
||||
|
||||
Returns:
|
||||
DataFrame or None(缓存不存在)
|
||||
"""
|
||||
cache_path = self._get_cache_path(code, start_date, end_date)
|
||||
if cache_path.exists():
|
||||
df = pd.read_csv(cache_path)
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
return df
|
||||
return None
|
||||
|
||||
def set(self, code: str, start_date: str, end_date: str, df: pd.DataFrame) -> None:
|
||||
"""保存数据到缓存"""
|
||||
cache_path = self._get_cache_path(code, start_date, end_date)
|
||||
df.to_csv(cache_path, index=False)
|
||||
|
||||
def clear(self, code: str = None) -> None:
|
||||
"""清除缓存"""
|
||||
if code:
|
||||
# 清除指定代码的缓存
|
||||
for f in self.cache_dir.glob(f"{code.replace('.', '_')}*.csv"):
|
||||
f.unlink()
|
||||
else:
|
||||
# 清除所有缓存
|
||||
for f in self.cache_dir.glob("*.csv"):
|
||||
f.unlink()
|
||||
481
core/datasource/hybrid_source.py
Normal file
481
core/datasource/hybrid_source.py
Normal file
@@ -0,0 +1,481 @@
|
||||
"""
|
||||
混合数据源模块
|
||||
- 中国A股指数: Tushare
|
||||
- 港股/美股: YFinance (支持 SSH 隧道)
|
||||
- 加密货币: CCXT/OKX (支持 SSH->HTTP 代理)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Dict
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
import urllib3
|
||||
|
||||
# 禁用 SSL 警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
class SSHTunnelManager:
|
||||
"""SSH 隧道管理器"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
self.enabled = config.get("enabled", False)
|
||||
self.host = config.get("host", "")
|
||||
self.port = config.get("port", 22)
|
||||
self.username = config.get("username", "")
|
||||
self.local_port = config.get("local_port", 1080)
|
||||
self._process: Optional[subprocess.Popen] = None
|
||||
|
||||
# 处理 key_path:如果是相对路径,转换为绝对路径
|
||||
key_path = config.get("key_path", "")
|
||||
if key_path and not os.path.isabs(key_path):
|
||||
# 相对于项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
key_path = str(project_root / key_path)
|
||||
self.key_path = key_path
|
||||
print(f"SSH 私钥路径: {self.key_path}")
|
||||
|
||||
def start(self) -> bool:
|
||||
"""启动 SSH 隧道"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
if not all([self.host, self.username, self.key_path]):
|
||||
print("SSH 配置不完整,跳过隧道建立")
|
||||
return False
|
||||
|
||||
print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
|
||||
|
||||
cmd = [
|
||||
"ssh", "-N", "-D", f"127.0.0.1:{self.local_port}",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-i", self.key_path,
|
||||
"-p", str(self.port),
|
||||
f"{self.username}@{self.host}"
|
||||
]
|
||||
|
||||
try:
|
||||
self._process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
time.sleep(2)
|
||||
|
||||
if self._process.poll() is not None:
|
||||
stdout, stderr = self._process.communicate()
|
||||
print("✗ SSH 隧道建立失败")
|
||||
if stderr:
|
||||
print(f"错误: {stderr.decode()}")
|
||||
return False
|
||||
|
||||
# 设置代理环境变量
|
||||
proxy_url = f"socks5://127.0.0.1:{self.local_port}"
|
||||
os.environ["HTTP_PROXY"] = proxy_url
|
||||
os.environ["HTTPS_PROXY"] = proxy_url
|
||||
os.environ["ALL_PROXY"] = proxy_url
|
||||
|
||||
print(f"✓ SSH 隧道已建立: {proxy_url}")
|
||||
time.sleep(1)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ SSH 隧道异常: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""停止 SSH 隧道"""
|
||||
if self._process:
|
||||
self._process.terminate()
|
||||
self._process.wait()
|
||||
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
|
||||
os.environ.pop(key, None)
|
||||
print("SSH 隧道已关闭")
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
|
||||
class HybridDataSource:
|
||||
"""
|
||||
混合数据源
|
||||
- 中国A股指数 (SH/SZ): Tushare
|
||||
- 港股/美股: YFinance
|
||||
- 加密货币: CCXT/OKX (通过 SSH->HTTP 代理)
|
||||
"""
|
||||
|
||||
# YFinance 代码映射 (代码 -> YFinance格式)
|
||||
YF_CODE_MAP = {
|
||||
# 港股
|
||||
"HSTECH": "3033.HK", # 恒生科技指数 ETF
|
||||
"HSI": "^HSI", # 恒生指数
|
||||
# 美股指数
|
||||
"NDX": "^NDX", # 纳斯达克100
|
||||
"SPX": "^GSPC", # 标普500
|
||||
"DJI": "^DJI", # 道琼斯
|
||||
# 黄金
|
||||
"GC=F": "GC=F", # 黄金期货
|
||||
}
|
||||
|
||||
# CCXT 代码映射 (代码 -> CCXT格式)
|
||||
CCXT_CODE_MAP = {
|
||||
"BTC": "BTC/USDT", # OKX 比特币现货
|
||||
"ETH": "ETH/USDT", # OKX 以太坊现货
|
||||
}
|
||||
|
||||
def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
|
||||
self.ssh_config = ssh_config or {}
|
||||
self.use_cache = use_cache
|
||||
self._tunnel: Optional[SSHTunnelManager] = None
|
||||
self._tushare_token: Optional[str] = None
|
||||
|
||||
def _is_china_index(self, code: str) -> bool:
|
||||
"""判断是否为中国A股指数"""
|
||||
return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS")
|
||||
|
||||
def _is_crypto(self, code: str) -> bool:
|
||||
"""判断是否为加密货币"""
|
||||
return code in self.CCXT_CODE_MAP
|
||||
|
||||
def _get_tushare_token(self) -> str:
|
||||
"""获取 Tushare Token"""
|
||||
if self._tushare_token is None:
|
||||
import os
|
||||
self._tushare_token = os.getenv("TUSHARE_TOKEN")
|
||||
if not self._tushare_token:
|
||||
raise ValueError("请设置环境变量 TUSHARE_TOKEN")
|
||||
return self._tushare_token
|
||||
|
||||
def _fetch_tushare(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""使用 Tushare 获取中国指数数据(不使用代理,直连国内服务器)"""
|
||||
import os
|
||||
|
||||
# 临时清除代理环境变量(Tushare 是国内服务,不需要代理)
|
||||
original_proxy = {}
|
||||
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
|
||||
original_proxy[key] = os.environ.pop(key, None)
|
||||
|
||||
try:
|
||||
import tushare as ts
|
||||
|
||||
pro = ts.pro_api(self._get_tushare_token())
|
||||
|
||||
# 转换代码格式 (000300.SS -> 000300.SH)
|
||||
ts_code = code.replace(".SS", ".SH")
|
||||
|
||||
# 获取日线数据
|
||||
df = pro.index_daily(
|
||||
ts_code=ts_code,
|
||||
start_date=start_date.replace("-", ""),
|
||||
end_date=end_date.replace("-", "")
|
||||
)
|
||||
|
||||
if df is None or len(df) == 0:
|
||||
return None
|
||||
|
||||
# 标准化列名
|
||||
df = df.rename(columns={
|
||||
"trade_date": "date",
|
||||
"open": "open",
|
||||
"high": "high",
|
||||
"low": "low",
|
||||
"close": "close",
|
||||
"vol": "volume",
|
||||
})
|
||||
|
||||
# 转换日期格式
|
||||
df["date"] = pd.to_datetime(df["date"])
|
||||
df = df.set_index("date")
|
||||
df = df.sort_index()
|
||||
|
||||
# 添加代码列
|
||||
df["code"] = code
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
print(f"Tushare 下载 {code} 失败: {e}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
# 恢复代理环境变量
|
||||
for key, value in original_proxy.items():
|
||||
if value is not None:
|
||||
os.environ[key] = value
|
||||
|
||||
def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""使用 YFinance 获取数据"""
|
||||
import time
|
||||
|
||||
# 转换代码格式
|
||||
yf_code = self.YF_CODE_MAP.get(code, code)
|
||||
|
||||
# 添加延迟以避免限流
|
||||
time.sleep(0.5)
|
||||
|
||||
try:
|
||||
ticker = yf.Ticker(yf_code)
|
||||
data = ticker.history(start=start_date, end=end_date)
|
||||
|
||||
if len(data) == 0:
|
||||
return None
|
||||
|
||||
# 标准化列名
|
||||
data = data.rename(columns={
|
||||
"Open": "open",
|
||||
"High": "high",
|
||||
"Low": "low",
|
||||
"Close": "close",
|
||||
"Volume": "volume",
|
||||
})
|
||||
|
||||
# 添加代码列
|
||||
data["code"] = code
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}")
|
||||
return None
|
||||
|
||||
def _fetch_ccxt(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]:
|
||||
"""使用 CCXT/OKX 获取加密货币数据(支持 HTTP 代理)"""
|
||||
import ccxt
|
||||
|
||||
# 转换代码格式
|
||||
ccxt_code = self.CCXT_CODE_MAP.get(code, code)
|
||||
|
||||
# 配置 CCXT
|
||||
config = {'enableRateLimit': True}
|
||||
if http_proxy:
|
||||
config['proxies'] = {'http': http_proxy, 'https': http_proxy}
|
||||
|
||||
try:
|
||||
exchange = ccxt.okx(config)
|
||||
|
||||
# 获取日线数据
|
||||
since = int(pd.Timestamp(start_date).timestamp() * 1000)
|
||||
all_ohlcv = []
|
||||
limit = 100
|
||||
|
||||
while since < int(pd.Timestamp(end_date).timestamp() * 1000):
|
||||
ohlcv = exchange.fetch_ohlcv(ccxt_code, '1d', since, limit)
|
||||
if not ohlcv:
|
||||
break
|
||||
all_ohlcv.extend(ohlcv)
|
||||
since = ohlcv[-1][0] + 86400000
|
||||
|
||||
if not all_ohlcv:
|
||||
return None
|
||||
|
||||
# 转换为 DataFrame
|
||||
df = pd.DataFrame(all_ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
|
||||
# 转换时间戳为日期索引
|
||||
df.index = pd.DatetimeIndex(pd.to_datetime(df['timestamp'], unit='ms', utc=True)).tz_localize(None).normalize()
|
||||
df.index.name = 'date'
|
||||
df = df[['open', 'high', 'low', 'close', 'volume']]
|
||||
# 过滤日期范围
|
||||
start_ts = pd.Timestamp(start_date)
|
||||
end_ts = pd.Timestamp(end_date)
|
||||
df = df[(df.index >= start_ts) & (df.index <= end_ts)]
|
||||
df['code'] = code
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
print(f"CCXT 下载 {code} ({ccxt_code}) 失败: {e}")
|
||||
return None
|
||||
|
||||
def fetch_single(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]:
|
||||
"""获取单个标的的数据"""
|
||||
if self._is_china_index(code):
|
||||
return self._fetch_tushare(code, start_date, end_date)
|
||||
elif self._is_crypto(code):
|
||||
return self._fetch_ccxt(code, start_date, end_date, http_proxy)
|
||||
else:
|
||||
return self._fetch_yfinance(code, start_date, end_date)
|
||||
|
||||
def fetch_all(
|
||||
self,
|
||||
code_list, # list[代码] 或 dict{代码: 名称}
|
||||
benchmark_code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
|
||||
"""
|
||||
批量获取数据
|
||||
注意:由于 Tushare(中国A股) 和 YFinance(美股/加密货币) 的交易日历不同,
|
||||
这里返回的是长格式数据,由调用方分别处理各市场的数据
|
||||
|
||||
Returns:
|
||||
(etf_data, benchmark_data, valid_codes)
|
||||
etf_data: DataFrame with columns [code, close, source], index=date
|
||||
"""
|
||||
all_data = []
|
||||
valid_codes = []
|
||||
|
||||
# 兼容列表和字典格式
|
||||
if isinstance(code_list, dict):
|
||||
codes = list(code_list.keys())
|
||||
code_name_map = code_list
|
||||
else:
|
||||
codes = code_list
|
||||
code_name_map = {c: c for c in codes}
|
||||
|
||||
print(f"开始下载 {len(codes)} 只标的的数据...")
|
||||
china_codes = [c for c in codes if self._is_china_index(c)]
|
||||
global_codes = [c for c in codes if not self._is_china_index(c)]
|
||||
print(f" 中国A股指数: {len(china_codes)} 只")
|
||||
print(f" 港股/美股/加密货币: {len(global_codes)} 只")
|
||||
|
||||
# 检查是否需要启动 socks2http 代理(用于加密货币)
|
||||
crypto_codes = [c for c in codes if self._is_crypto(c)]
|
||||
http_proxy = None
|
||||
socks2http_proc = None
|
||||
|
||||
# 只有在 SSH 隧道已建立时才启动 socks2http
|
||||
if crypto_codes and self._tunnel is not None:
|
||||
import subprocess
|
||||
import time
|
||||
print(f"\n 启动 socks2http 代理服务(用于加密货币)...")
|
||||
try:
|
||||
# 启动 socks2http.py 作为子进程
|
||||
socks2http_path = Path(__file__).parent / "socks2http.py"
|
||||
socks2http_proc = subprocess.Popen(
|
||||
[sys.executable, str(socks2http_path)],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
time.sleep(2) # 等待代理启动
|
||||
http_proxy = "http://127.0.0.1:8080"
|
||||
print(f" ✓ HTTP 代理已启动: {http_proxy}")
|
||||
except Exception as e:
|
||||
print(f" ✗ 启动代理失败: {e}")
|
||||
|
||||
# 分别下载数据
|
||||
for code in codes:
|
||||
if self._is_china_index(code):
|
||||
source = "Tushare"
|
||||
elif self._is_crypto(code):
|
||||
source = "CCXT/OKX"
|
||||
else:
|
||||
source = "YFinance"
|
||||
|
||||
name = code_name_map.get(code, code)
|
||||
print(f" 下载 {code} ({name}) - {source}...", end=" ")
|
||||
|
||||
# 加密货币使用 HTTP 代理
|
||||
proxy = http_proxy if self._is_crypto(code) else None
|
||||
data = self.fetch_single(code, start_date, end_date, proxy)
|
||||
|
||||
if data is not None and len(data) > 0:
|
||||
# 标准化数据格式
|
||||
data = data.copy()
|
||||
data['source'] = source
|
||||
# 确保索引是日期格式且无时区,只保留日期部分(去掉时间)
|
||||
data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
|
||||
all_data.append(data[['code', 'close', 'source']])
|
||||
valid_codes.append(code)
|
||||
print(f"✓ {len(data)} 条")
|
||||
else:
|
||||
print("✗ 无数据")
|
||||
|
||||
# 关闭 socks2http 代理
|
||||
if socks2http_proc:
|
||||
socks2http_proc.terminate()
|
||||
socks2http_proc.wait()
|
||||
print(f"\n socks2http 代理已关闭")
|
||||
|
||||
if not all_data:
|
||||
return None, None, []
|
||||
|
||||
# 检查数据源类型
|
||||
sources = set(d['source'].iloc[0] for d in all_data)
|
||||
|
||||
if len(sources) == 1:
|
||||
# 单一数据源:转换为宽格式(向后兼容)
|
||||
all_df = pd.concat(all_data, ignore_index=False)
|
||||
all_df = all_df.reset_index()
|
||||
all_df['date'] = pd.to_datetime(all_df['date'], utc=True).dt.tz_localize(None)
|
||||
etf_data = all_df.pivot_table(
|
||||
index='date',
|
||||
columns='code',
|
||||
values='close',
|
||||
aggfunc='first'
|
||||
)
|
||||
print(f"\n数据整理完成 (单一数据源 {list(sources)[0]}):")
|
||||
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
|
||||
print(f" 交易日数: {len(etf_data)}")
|
||||
print(f" 有效标的: {len(etf_data.columns)} 只")
|
||||
else:
|
||||
# 多数据源:以主市场(Tushare/A股)为基准,其他市场数据前向填充
|
||||
print(f"\n数据整理完成 (多数据源 - 以A股交易日为基准):")
|
||||
|
||||
# 合并所有数据(索引已经是标准化后的日期)
|
||||
all_df = pd.concat(all_data, ignore_index=False)
|
||||
all_df = all_df.reset_index()
|
||||
# 重命名索引列为 date
|
||||
if 'index' in all_df.columns:
|
||||
all_df = all_df.rename(columns={'index': 'date'})
|
||||
# 确保 date 列是日期格式(不含时间)
|
||||
all_df['date'] = pd.to_datetime(all_df['date']).dt.normalize()
|
||||
|
||||
# 透视为宽格式
|
||||
etf_data = all_df.pivot_table(
|
||||
index='date',
|
||||
columns='code',
|
||||
values='close',
|
||||
aggfunc='first'
|
||||
)
|
||||
|
||||
# 获取主市场(Tushare)的交易日历
|
||||
tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
|
||||
if tushare_codes:
|
||||
# 使用第一个A股代码的日期作为主市场交易日
|
||||
primary_dates = etf_data[tushare_codes[0]].dropna().index
|
||||
print(f" 主市场交易日: {len(primary_dates)} 天")
|
||||
|
||||
# 重新索引到主市场交易日,使用前向填充
|
||||
etf_data = etf_data.reindex(primary_dates)
|
||||
|
||||
# 对每个非主市场代码进行前向填充
|
||||
yfinance_codes = [c for c in valid_codes if not self._is_china_index(c)]
|
||||
for code in yfinance_codes:
|
||||
if code in etf_data.columns:
|
||||
# 前向填充:用最近的有效价格填充休市日的数据
|
||||
etf_data[code] = etf_data[code].ffill()
|
||||
# 对于开头的NaN,用后向填充
|
||||
etf_data[code] = etf_data[code].bfill()
|
||||
|
||||
print(f" 非主市场标的: {len(yfinance_codes)} 只 (已前向填充)")
|
||||
|
||||
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
|
||||
print(f" 交易日数: {len(etf_data)}")
|
||||
print(f" 有效标的: {len(etf_data.columns)} 只")
|
||||
|
||||
# 获取基准数据
|
||||
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
|
||||
if benchmark_data is not None:
|
||||
# 标准化日期索引(无时区,只保留日期部分)
|
||||
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
|
||||
print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
|
||||
|
||||
return etf_data, benchmark_data, valid_codes
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器入口"""
|
||||
if self.ssh_config.get("enabled"):
|
||||
self._tunnel = SSHTunnelManager(self.ssh_config)
|
||||
self._tunnel.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""上下文管理器出口"""
|
||||
if self._tunnel:
|
||||
self._tunnel.stop()
|
||||
168
core/datasource/socks2http.py
Normal file
168
core/datasource/socks2http.py
Normal file
@@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SOCKS5 转 HTTP 代理转发工具
|
||||
将 SSH 隧道的 SOCKS5 代理 (1080) 转为 HTTP 代理 (8080)
|
||||
供 CCXT 等只支持 HTTP 代理的库使用
|
||||
"""
|
||||
|
||||
import socket
|
||||
import threading
|
||||
import select
|
||||
import sys
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
class Socks2Http:
|
||||
def __init__(self, socks_host='127.0.0.1', socks_port=1080, http_port=8080):
|
||||
self.socks_host = socks_host
|
||||
self.socks_port = socks_port
|
||||
self.http_port = http_port
|
||||
self.server = None
|
||||
self.running = False
|
||||
|
||||
def start(self):
|
||||
"""启动 HTTP 代理服务器"""
|
||||
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.server.bind(('127.0.0.1', self.http_port))
|
||||
self.server.listen(5)
|
||||
self.running = True
|
||||
|
||||
print(f"HTTP 代理已启动: http://127.0.0.1:{self.http_port}")
|
||||
print(f"转发到 SOCKS5: {self.socks_host}:{self.socks_port}")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
client, addr = self.server.accept()
|
||||
thread = threading.Thread(target=self._handle_client, args=(client,))
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
except Exception as e:
|
||||
if self.running:
|
||||
print(f"接受连接错误: {e}")
|
||||
|
||||
def stop(self):
|
||||
"""停止代理服务器"""
|
||||
self.running = False
|
||||
if self.server:
|
||||
self.server.close()
|
||||
print("HTTP 代理已停止")
|
||||
|
||||
def _handle_client(self, client):
|
||||
"""处理客户端连接"""
|
||||
try:
|
||||
# 读取 HTTP 请求
|
||||
request = client.recv(4096)
|
||||
if not request:
|
||||
client.close()
|
||||
return
|
||||
|
||||
# 解析 CONNECT 请求
|
||||
first_line = request.split(b'\r\n')[0].decode('utf-8', errors='ignore')
|
||||
|
||||
if first_line.startswith('CONNECT'):
|
||||
# HTTPS 代理
|
||||
parts = first_line.split()
|
||||
if len(parts) >= 2:
|
||||
target = parts[1]
|
||||
host, port = target.rsplit(':', 1)
|
||||
port = int(port)
|
||||
|
||||
# 连接到 SOCKS5 代理
|
||||
remote = self._connect_via_socks5(host, port)
|
||||
if remote:
|
||||
client.send(b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
self._relay(client, remote)
|
||||
else:
|
||||
client.send(b'HTTP/1.1 502 Bad Gateway\r\n\r\n')
|
||||
else:
|
||||
# HTTP 代理
|
||||
lines = first_line.split()
|
||||
if len(lines) >= 2:
|
||||
url = lines[1]
|
||||
parsed = urlparse(url)
|
||||
host = parsed.hostname
|
||||
port = parsed.port or 80
|
||||
|
||||
# 连接到 SOCKS5 代理
|
||||
remote = self._connect_via_socks5(host, port)
|
||||
if remote:
|
||||
# 修改请求,去掉完整 URL
|
||||
new_request = request.replace(
|
||||
f'{lines[0]} {url} '.encode(),
|
||||
f'{lines[0]} {parsed.path or "/"}{"?" + parsed.query if parsed.query else ""} '.encode()
|
||||
)
|
||||
remote.send(new_request)
|
||||
self._relay(client, remote)
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理客户端错误: {e}")
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
def _connect_via_socks5(self, host, port):
|
||||
"""通过 SOCKS5 代理连接目标服务器"""
|
||||
try:
|
||||
# 连接到 SOCKS5 代理
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(30)
|
||||
sock.connect((self.socks_host, self.socks_port))
|
||||
|
||||
# SOCKS5 握手
|
||||
# 1. 发送认证方法
|
||||
sock.send(b'\x05\x01\x00') # VER=5, NMETHODS=1, METHOD=0 (无认证)
|
||||
resp = sock.recv(2)
|
||||
if resp[0] != 0x05 or resp[1] != 0x00:
|
||||
sock.close()
|
||||
return None
|
||||
|
||||
# 2. 发送连接请求
|
||||
req = b'\x05\x01\x00\x03' # VER=5, CMD=CONNECT, RSV=0, ATYP=DOMAIN
|
||||
req += bytes([len(host)]) + host.encode()
|
||||
req += bytes([(port >> 8) & 0xFF, port & 0xFF])
|
||||
sock.send(req)
|
||||
|
||||
# 3. 读取响应
|
||||
resp = sock.recv(10)
|
||||
if len(resp) < 4 or resp[1] != 0x00:
|
||||
sock.close()
|
||||
return None
|
||||
|
||||
return sock
|
||||
|
||||
except Exception as e:
|
||||
print(f"SOCKS5 连接错误: {e}")
|
||||
return None
|
||||
|
||||
def _relay(self, client, remote):
|
||||
"""双向转发数据"""
|
||||
try:
|
||||
while True:
|
||||
readable, _, _ = select.select([client, remote], [], [], 60)
|
||||
if not readable:
|
||||
break
|
||||
|
||||
if client in readable:
|
||||
data = client.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
remote.send(data)
|
||||
|
||||
if remote in readable:
|
||||
data = remote.recv(4096)
|
||||
if not data:
|
||||
break
|
||||
client.send(data)
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
client.close()
|
||||
remote.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
proxy = Socks2Http(socks_port=1080, http_port=8080)
|
||||
try:
|
||||
proxy.start()
|
||||
except KeyboardInterrupt:
|
||||
proxy.stop()
|
||||
194
core/datasource/yfinance_source.py
Normal file
194
core/datasource/yfinance_source.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
YFinance 数据源模块
|
||||
支持通过 SSH 隧道访问(用于绕过网络限制)
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import subprocess
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
import yfinance as yf
|
||||
import urllib3
|
||||
|
||||
# 禁用 SSL 警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
class SSHTunnelManager:
|
||||
"""SSH 隧道管理器"""
|
||||
|
||||
def __init__(self, config: dict):
|
||||
from pathlib import Path
|
||||
|
||||
self.enabled = config.get("enabled", False)
|
||||
self.host = config.get("host", "")
|
||||
self.port = config.get("port", 22)
|
||||
self.username = config.get("username", "")
|
||||
self.local_port = config.get("local_port", 1080)
|
||||
self._process: Optional[subprocess.Popen] = None
|
||||
|
||||
# 处理 key_path:如果是相对路径,转换为绝对路径
|
||||
key_path = config.get("key_path", "")
|
||||
if key_path and not os.path.isabs(key_path):
|
||||
# 相对于项目根目录
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
key_path = str(project_root / key_path)
|
||||
self.key_path = key_path
|
||||
|
||||
def start(self) -> bool:
|
||||
"""启动 SSH 隧道"""
|
||||
if not self.enabled:
|
||||
return True
|
||||
|
||||
if not all([self.host, self.username, self.key_path]):
|
||||
print("SSH 配置不完整,跳过隧道建立")
|
||||
return False
|
||||
|
||||
print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
|
||||
|
||||
cmd = [
|
||||
"ssh",
|
||||
"-N",
|
||||
"-D", f"127.0.0.1:{self.local_port}",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-i", self.key_path,
|
||||
"-p", str(self.port),
|
||||
f"{self.username}@{self.host}"
|
||||
]
|
||||
|
||||
try:
|
||||
self._process = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
time.sleep(2)
|
||||
|
||||
if self._process.poll() is not None:
|
||||
stdout, stderr = self._process.communicate()
|
||||
print("✗ SSH 隧道建立失败")
|
||||
if stderr:
|
||||
print(f"错误: {stderr.decode()}")
|
||||
return False
|
||||
|
||||
# 设置代理环境变量
|
||||
proxy_url = f"socks5://127.0.0.1:{self.local_port}"
|
||||
os.environ["HTTP_PROXY"] = proxy_url
|
||||
os.environ["HTTPS_PROXY"] = proxy_url
|
||||
os.environ["ALL_PROXY"] = proxy_url
|
||||
|
||||
print(f"✓ SSH 隧道已建立: {proxy_url}")
|
||||
time.sleep(1)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ SSH 隧道异常: {e}")
|
||||
return False
|
||||
|
||||
def stop(self):
|
||||
"""停止 SSH 隧道"""
|
||||
if self._process:
|
||||
self._process.terminate()
|
||||
self._process.wait()
|
||||
# 清理环境变量
|
||||
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
|
||||
os.environ.pop(key, None)
|
||||
print("SSH 隧道已关闭")
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop()
|
||||
|
||||
|
||||
class YFinanceDataSource:
|
||||
"""YFinance 数据源"""
|
||||
|
||||
def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
|
||||
self.ssh_config = ssh_config or {}
|
||||
self.use_cache = use_cache
|
||||
self._tunnel: Optional[SSHTunnelManager] = None
|
||||
|
||||
def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""获取单个 ETF 数据"""
|
||||
try:
|
||||
ticker = yf.Ticker(code)
|
||||
data = ticker.history(start=start_date, end=end_date)
|
||||
|
||||
if len(data) == 0:
|
||||
return None
|
||||
|
||||
# 标准化列名
|
||||
data = data.rename(columns={
|
||||
"Open": "open",
|
||||
"High": "high",
|
||||
"Low": "low",
|
||||
"Close": "close",
|
||||
"Volume": "volume",
|
||||
})
|
||||
|
||||
# 添加代码列
|
||||
data["code"] = code
|
||||
|
||||
return data
|
||||
|
||||
except Exception as e:
|
||||
print(f"下载 {code} 失败: {e}")
|
||||
return None
|
||||
|
||||
def fetch_all(
|
||||
self,
|
||||
code_list: list,
|
||||
benchmark_code: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
|
||||
"""
|
||||
批量获取 ETF 数据和基准数据
|
||||
|
||||
Returns:
|
||||
(etf_data, benchmark_data, valid_codes)
|
||||
"""
|
||||
all_data = []
|
||||
valid_codes = []
|
||||
|
||||
print(f"开始下载 {len(code_list)} 只 ETF 数据...")
|
||||
|
||||
for code in code_list:
|
||||
data = self.fetch_single(code, start_date, end_date)
|
||||
if data is not None and len(data) > 0:
|
||||
all_data.append(data)
|
||||
valid_codes.append(code)
|
||||
print(f" ✓ {code}: {len(data)} 条")
|
||||
else:
|
||||
print(f" ✗ {code}: 无数据")
|
||||
|
||||
if not all_data:
|
||||
return None, None, []
|
||||
|
||||
# 合并数据
|
||||
etf_data = pd.concat(all_data, ignore_index=False)
|
||||
|
||||
# 获取基准数据
|
||||
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
|
||||
if benchmark_data is not None:
|
||||
print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
|
||||
|
||||
return etf_data, benchmark_data, valid_codes
|
||||
|
||||
def __enter__(self):
|
||||
"""上下文管理器入口"""
|
||||
if self.ssh_config.get("enabled"):
|
||||
self._tunnel = SSHTunnelManager(self.ssh_config)
|
||||
self._tunnel.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""上下文管理器出口"""
|
||||
if self._tunnel:
|
||||
self._tunnel.stop()
|
||||
@@ -10,7 +10,7 @@ import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
from strategies.base import BacktestStrategy
|
||||
from core.data.hybrid_source import HybridDataSource
|
||||
from core.datasource.hybrid_source import HybridDataSource
|
||||
from core.factors.momentum import compute_factors, calculate_daily_return
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user