fix(datasource): 修正混合数据源导入路径错误
- 修正 strategies.rotation.engine 中 hybrid_source 模块导入路径错误 - 新增 core.datasource 目录下多个数据源实现模块 - 增加 Akshare 数据源支持 A股指数数据拉取 - 实现数据缓存管理机制,支持本地数据缓存读写 - 新增 YFinance 数据源,支持通过 SSH 隧道访问美股和港股数据 - 实现混合数据源支持 A股/Tushare、港美股/YFinance、加密货币/CCXT 的统一访问 - 集成 SSH 隧道管理,支持 SOCKS5 转 HTTP 代理转发 - 新增 socks2http.py 代理转发工具,解决 CCXT 仅支持 HTTP 代理问题 - 修改 rotation.yaml 加密货币注释,明确使用 OKX 现货和 SSH->HTTP 代理访问 - 删除.gitignore中无用的 data/ 忽略规则,保留 test/ 文件夹忽略规则
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -177,7 +177,6 @@ config_local.py
|
|||||||
*.backup
|
*.backup
|
||||||
|
|
||||||
|
|
||||||
data/
|
|
||||||
test/
|
test/
|
||||||
|
|
||||||
# Cache and generated files
|
# Cache and generated files
|
||||||
|
|||||||
@@ -34,8 +34,10 @@ code_list:
|
|||||||
"HSTECH": "恒生科技" # 港股
|
"HSTECH": "恒生科技" # 港股
|
||||||
"NDX": "纳指100" # 美股
|
"NDX": "纳指100" # 美股
|
||||||
"GC=F": "黄金" # 黄金期货 (COMEX)
|
"GC=F": "黄金" # 黄金期货 (COMEX)
|
||||||
"BTC": "比特币" # 加密货币
|
|
||||||
"ETH": "以太坊" # 加密货币
|
# 加密货币 (使用 CCXT/OKX 现货) - 通过 SSH->HTTP 代理访问
|
||||||
|
"BTC": "比特币" # OKX 现货
|
||||||
|
"ETH": "以太坊" # OKX 现货
|
||||||
|
|
||||||
# 主市场配置(用于确定交易日历)
|
# 主市场配置(用于确定交易日历)
|
||||||
primary_market:
|
primary_market:
|
||||||
|
|||||||
0
core/datasource/__init__.py
Normal file
0
core/datasource/__init__.py
Normal file
128
core/datasource/akshare_source.py
Normal file
128
core/datasource/akshare_source.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""
|
||||||
|
akshare数据源实现(用于CCI筛选等场景)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import pandas as pd
|
||||||
|
import akshare as ak
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from .base import DataSource
|
||||||
|
|
||||||
|
|
||||||
|
class AkshareDataSource(DataSource):
|
||||||
|
"""基于akshare的数据源"""
|
||||||
|
|
||||||
|
def __init__(self, delay: float = 3.0):
|
||||||
|
"""
|
||||||
|
初始化akshare数据源
|
||||||
|
|
||||||
|
Args:
|
||||||
|
delay: 每次请求间隔秒数(避免触发限流)
|
||||||
|
"""
|
||||||
|
self.delay = delay
|
||||||
|
|
||||||
|
def fetch_ohlcv(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
fields: Optional[list] = None,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
获取指数历史数据(使用akshare的东方财富接口)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: 指数代码,如 '000300'(不带后缀)
|
||||||
|
start_date: 起始日期 'YYYY-MM-DD'
|
||||||
|
end_date: 结束日期 'YYYY-MM-DD'
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame,包含 date, open, high, low, close, volume
|
||||||
|
"""
|
||||||
|
# 转换日期格式
|
||||||
|
sd = start_date.replace("-", "")
|
||||||
|
ed = end_date.replace("-", "")
|
||||||
|
|
||||||
|
# 去除后缀
|
||||||
|
symbol = code.replace(".SH", "").replace(".SZ", "")
|
||||||
|
|
||||||
|
time.sleep(self.delay)
|
||||||
|
|
||||||
|
try:
|
||||||
|
df = ak.index_zh_a_hist(
|
||||||
|
symbol=symbol,
|
||||||
|
period="daily",
|
||||||
|
start_date=sd,
|
||||||
|
end_date=ed,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"akshare查询失败 [{code}]: {e}")
|
||||||
|
|
||||||
|
if df is None or df.empty:
|
||||||
|
raise ValueError(f"akshare返回空数据: {code}")
|
||||||
|
|
||||||
|
# 统一列名
|
||||||
|
df = df.rename(
|
||||||
|
columns={
|
||||||
|
"日期": "date",
|
||||||
|
"开盘": "open",
|
||||||
|
"最高": "high",
|
||||||
|
"最低": "low",
|
||||||
|
"收盘": "close",
|
||||||
|
"成交量": "volume",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
df["date"] = pd.to_datetime(df["date"])
|
||||||
|
df["code"] = code
|
||||||
|
df = df[["date", "code", "open", "high", "low", "close", "volume"]]
|
||||||
|
df = df.sort_values("date").reset_index(drop=True)
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
def fetch_multiple(
|
||||||
|
self,
|
||||||
|
codes: list,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> tuple[pd.DataFrame, list]:
|
||||||
|
"""批量获取数据"""
|
||||||
|
df_list = []
|
||||||
|
failed = []
|
||||||
|
|
||||||
|
for code in codes:
|
||||||
|
try:
|
||||||
|
df = self.fetch_ohlcv(code, start_date, end_date)
|
||||||
|
df_list.append(df[["date", "code", "close"]].copy())
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠ 跳过 {code}: {e}")
|
||||||
|
failed.append(code)
|
||||||
|
|
||||||
|
if not df_list:
|
||||||
|
raise RuntimeError("所有数据获取失败")
|
||||||
|
|
||||||
|
all_df = pd.concat(df_list, ignore_index=True)
|
||||||
|
data = all_df.pivot(index="date", columns="code", values="close")
|
||||||
|
data = data.sort_index()
|
||||||
|
|
||||||
|
return data, failed
|
||||||
|
|
||||||
|
def fetch_index_list(self) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
获取所有A股指数列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with index info
|
||||||
|
"""
|
||||||
|
sources = ["沪深重要指数", "上证系列指数", "深证系列指数", "中证系列指数"]
|
||||||
|
df_list = []
|
||||||
|
|
||||||
|
for source in sources:
|
||||||
|
try:
|
||||||
|
df = ak.stock_zh_index_spot_em(symbol=source)
|
||||||
|
df["source"] = source
|
||||||
|
df_list.append(df)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ⚠ 获取 {source} 失败: {e}")
|
||||||
|
|
||||||
|
return pd.concat(df_list, ignore_index=True) if df_list else pd.DataFrame()
|
||||||
53
core/datasource/base.py
Normal file
53
core/datasource/base.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""
|
||||||
|
数据源抽象基类
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
class DataSource(ABC):
|
||||||
|
"""数据源抽象基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fetch_ohlcv(
|
||||||
|
self,
|
||||||
|
code: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
fields: Optional[list] = None,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
获取OHLCV数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: 标的代码
|
||||||
|
start_date: 开始日期 (YYYY-MM-DD)
|
||||||
|
end_date: 结束日期 (YYYY-MM-DD)
|
||||||
|
fields: 指定字段列表,None表示获取全部
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with columns: date, open, high, low, close, volume
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def fetch_multiple(
|
||||||
|
self,
|
||||||
|
codes: list,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
批量获取多只标的收盘价数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
codes: 标的代码列表
|
||||||
|
start_date: 开始日期
|
||||||
|
end_date: 结束日期
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame, index=日期, columns=代码, values=收盘价
|
||||||
|
"""
|
||||||
|
pass
|
||||||
54
core/datasource/cache.py
Normal file
54
core/datasource/cache.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""
|
||||||
|
数据缓存管理模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class DataCache:
|
||||||
|
"""CSV文件缓存管理器"""
|
||||||
|
|
||||||
|
def __init__(self, cache_dir: str = "data_cache"):
|
||||||
|
self.cache_dir = Path(cache_dir)
|
||||||
|
self.cache_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
def _get_cache_path(self, code: str, start_date: str, end_date: str) -> Path:
|
||||||
|
"""生成缓存文件路径"""
|
||||||
|
# 统一日期格式为 YYYYMMDD
|
||||||
|
sd = start_date.replace("-", "")
|
||||||
|
ed = end_date.replace("-", "")
|
||||||
|
safe_code = code.replace(".", "_")
|
||||||
|
return self.cache_dir / f"{safe_code}_{sd}_{ed}.csv"
|
||||||
|
|
||||||
|
def get(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||||
|
"""
|
||||||
|
从缓存读取数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame or None(缓存不存在)
|
||||||
|
"""
|
||||||
|
cache_path = self._get_cache_path(code, start_date, end_date)
|
||||||
|
if cache_path.exists():
|
||||||
|
df = pd.read_csv(cache_path)
|
||||||
|
df["date"] = pd.to_datetime(df["date"])
|
||||||
|
return df
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, code: str, start_date: str, end_date: str, df: pd.DataFrame) -> None:
|
||||||
|
"""保存数据到缓存"""
|
||||||
|
cache_path = self._get_cache_path(code, start_date, end_date)
|
||||||
|
df.to_csv(cache_path, index=False)
|
||||||
|
|
||||||
|
def clear(self, code: str = None) -> None:
|
||||||
|
"""清除缓存"""
|
||||||
|
if code:
|
||||||
|
# 清除指定代码的缓存
|
||||||
|
for f in self.cache_dir.glob(f"{code.replace('.', '_')}*.csv"):
|
||||||
|
f.unlink()
|
||||||
|
else:
|
||||||
|
# 清除所有缓存
|
||||||
|
for f in self.cache_dir.glob("*.csv"):
|
||||||
|
f.unlink()
|
||||||
481
core/datasource/hybrid_source.py
Normal file
481
core/datasource/hybrid_source.py
Normal file
@@ -0,0 +1,481 @@
|
|||||||
|
"""
|
||||||
|
混合数据源模块
|
||||||
|
- 中国A股指数: Tushare
|
||||||
|
- 港股/美股: YFinance (支持 SSH 隧道)
|
||||||
|
- 加密货币: CCXT/OKX (支持 SSH->HTTP 代理)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Tuple, Dict
|
||||||
|
from datetime import datetime
|
||||||
|
import pandas as pd
|
||||||
|
import yfinance as yf
|
||||||
|
import urllib3
|
||||||
|
|
||||||
|
# 禁用 SSL 警告
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelManager:
|
||||||
|
"""SSH 隧道管理器"""
|
||||||
|
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
self.enabled = config.get("enabled", False)
|
||||||
|
self.host = config.get("host", "")
|
||||||
|
self.port = config.get("port", 22)
|
||||||
|
self.username = config.get("username", "")
|
||||||
|
self.local_port = config.get("local_port", 1080)
|
||||||
|
self._process: Optional[subprocess.Popen] = None
|
||||||
|
|
||||||
|
# 处理 key_path:如果是相对路径,转换为绝对路径
|
||||||
|
key_path = config.get("key_path", "")
|
||||||
|
if key_path and not os.path.isabs(key_path):
|
||||||
|
# 相对于项目根目录
|
||||||
|
project_root = Path(__file__).parent.parent.parent
|
||||||
|
key_path = str(project_root / key_path)
|
||||||
|
self.key_path = key_path
|
||||||
|
print(f"SSH 私钥路径: {self.key_path}")
|
||||||
|
|
||||||
|
def start(self) -> bool:
|
||||||
|
"""启动 SSH 隧道"""
|
||||||
|
if not self.enabled:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not all([self.host, self.username, self.key_path]):
|
||||||
|
print("SSH 配置不完整,跳过隧道建立")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ssh", "-N", "-D", f"127.0.0.1:{self.local_port}",
|
||||||
|
"-o", "StrictHostKeyChecking=no",
|
||||||
|
"-o", "UserKnownHostsFile=/dev/null",
|
||||||
|
"-i", self.key_path,
|
||||||
|
"-p", str(self.port),
|
||||||
|
f"{self.username}@{self.host}"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
if self._process.poll() is not None:
|
||||||
|
stdout, stderr = self._process.communicate()
|
||||||
|
print("✗ SSH 隧道建立失败")
|
||||||
|
if stderr:
|
||||||
|
print(f"错误: {stderr.decode()}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 设置代理环境变量
|
||||||
|
proxy_url = f"socks5://127.0.0.1:{self.local_port}"
|
||||||
|
os.environ["HTTP_PROXY"] = proxy_url
|
||||||
|
os.environ["HTTPS_PROXY"] = proxy_url
|
||||||
|
os.environ["ALL_PROXY"] = proxy_url
|
||||||
|
|
||||||
|
print(f"✓ SSH 隧道已建立: {proxy_url}")
|
||||||
|
time.sleep(1)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ SSH 隧道异常: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止 SSH 隧道"""
|
||||||
|
if self._process:
|
||||||
|
self._process.terminate()
|
||||||
|
self._process.wait()
|
||||||
|
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
print("SSH 隧道已关闭")
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class HybridDataSource:
|
||||||
|
"""
|
||||||
|
混合数据源
|
||||||
|
- 中国A股指数 (SH/SZ): Tushare
|
||||||
|
- 港股/美股: YFinance
|
||||||
|
- 加密货币: CCXT/OKX (通过 SSH->HTTP 代理)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# YFinance 代码映射 (代码 -> YFinance格式)
|
||||||
|
YF_CODE_MAP = {
|
||||||
|
# 港股
|
||||||
|
"HSTECH": "3033.HK", # 恒生科技指数 ETF
|
||||||
|
"HSI": "^HSI", # 恒生指数
|
||||||
|
# 美股指数
|
||||||
|
"NDX": "^NDX", # 纳斯达克100
|
||||||
|
"SPX": "^GSPC", # 标普500
|
||||||
|
"DJI": "^DJI", # 道琼斯
|
||||||
|
# 黄金
|
||||||
|
"GC=F": "GC=F", # 黄金期货
|
||||||
|
}
|
||||||
|
|
||||||
|
# CCXT 代码映射 (代码 -> CCXT格式)
|
||||||
|
CCXT_CODE_MAP = {
|
||||||
|
"BTC": "BTC/USDT", # OKX 比特币现货
|
||||||
|
"ETH": "ETH/USDT", # OKX 以太坊现货
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
|
||||||
|
self.ssh_config = ssh_config or {}
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self._tunnel: Optional[SSHTunnelManager] = None
|
||||||
|
self._tushare_token: Optional[str] = None
|
||||||
|
|
||||||
|
def _is_china_index(self, code: str) -> bool:
|
||||||
|
"""判断是否为中国A股指数"""
|
||||||
|
return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS")
|
||||||
|
|
||||||
|
def _is_crypto(self, code: str) -> bool:
|
||||||
|
"""判断是否为加密货币"""
|
||||||
|
return code in self.CCXT_CODE_MAP
|
||||||
|
|
||||||
|
def _get_tushare_token(self) -> str:
|
||||||
|
"""获取 Tushare Token"""
|
||||||
|
if self._tushare_token is None:
|
||||||
|
import os
|
||||||
|
self._tushare_token = os.getenv("TUSHARE_TOKEN")
|
||||||
|
if not self._tushare_token:
|
||||||
|
raise ValueError("请设置环境变量 TUSHARE_TOKEN")
|
||||||
|
return self._tushare_token
|
||||||
|
|
||||||
|
def _fetch_tushare(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||||
|
"""使用 Tushare 获取中国指数数据(不使用代理,直连国内服务器)"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 临时清除代理环境变量(Tushare 是国内服务,不需要代理)
|
||||||
|
original_proxy = {}
|
||||||
|
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
|
||||||
|
original_proxy[key] = os.environ.pop(key, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tushare as ts
|
||||||
|
|
||||||
|
pro = ts.pro_api(self._get_tushare_token())
|
||||||
|
|
||||||
|
# 转换代码格式 (000300.SS -> 000300.SH)
|
||||||
|
ts_code = code.replace(".SS", ".SH")
|
||||||
|
|
||||||
|
# 获取日线数据
|
||||||
|
df = pro.index_daily(
|
||||||
|
ts_code=ts_code,
|
||||||
|
start_date=start_date.replace("-", ""),
|
||||||
|
end_date=end_date.replace("-", "")
|
||||||
|
)
|
||||||
|
|
||||||
|
if df is None or len(df) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 标准化列名
|
||||||
|
df = df.rename(columns={
|
||||||
|
"trade_date": "date",
|
||||||
|
"open": "open",
|
||||||
|
"high": "high",
|
||||||
|
"low": "low",
|
||||||
|
"close": "close",
|
||||||
|
"vol": "volume",
|
||||||
|
})
|
||||||
|
|
||||||
|
# 转换日期格式
|
||||||
|
df["date"] = pd.to_datetime(df["date"])
|
||||||
|
df = df.set_index("date")
|
||||||
|
df = df.sort_index()
|
||||||
|
|
||||||
|
# 添加代码列
|
||||||
|
df["code"] = code
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Tushare 下载 {code} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 恢复代理环境变量
|
||||||
|
for key, value in original_proxy.items():
|
||||||
|
if value is not None:
|
||||||
|
os.environ[key] = value
|
||||||
|
|
||||||
|
def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||||
|
"""使用 YFinance 获取数据"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 转换代码格式
|
||||||
|
yf_code = self.YF_CODE_MAP.get(code, code)
|
||||||
|
|
||||||
|
# 添加延迟以避免限流
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ticker = yf.Ticker(yf_code)
|
||||||
|
data = ticker.history(start=start_date, end=end_date)
|
||||||
|
|
||||||
|
if len(data) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 标准化列名
|
||||||
|
data = data.rename(columns={
|
||||||
|
"Open": "open",
|
||||||
|
"High": "high",
|
||||||
|
"Low": "low",
|
||||||
|
"Close": "close",
|
||||||
|
"Volume": "volume",
|
||||||
|
})
|
||||||
|
|
||||||
|
# 添加代码列
|
||||||
|
data["code"] = code
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fetch_ccxt(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]:
|
||||||
|
"""使用 CCXT/OKX 获取加密货币数据(支持 HTTP 代理)"""
|
||||||
|
import ccxt
|
||||||
|
|
||||||
|
# 转换代码格式
|
||||||
|
ccxt_code = self.CCXT_CODE_MAP.get(code, code)
|
||||||
|
|
||||||
|
# 配置 CCXT
|
||||||
|
config = {'enableRateLimit': True}
|
||||||
|
if http_proxy:
|
||||||
|
config['proxies'] = {'http': http_proxy, 'https': http_proxy}
|
||||||
|
|
||||||
|
try:
|
||||||
|
exchange = ccxt.okx(config)
|
||||||
|
|
||||||
|
# 获取日线数据
|
||||||
|
since = int(pd.Timestamp(start_date).timestamp() * 1000)
|
||||||
|
all_ohlcv = []
|
||||||
|
limit = 100
|
||||||
|
|
||||||
|
while since < int(pd.Timestamp(end_date).timestamp() * 1000):
|
||||||
|
ohlcv = exchange.fetch_ohlcv(ccxt_code, '1d', since, limit)
|
||||||
|
if not ohlcv:
|
||||||
|
break
|
||||||
|
all_ohlcv.extend(ohlcv)
|
||||||
|
since = ohlcv[-1][0] + 86400000
|
||||||
|
|
||||||
|
if not all_ohlcv:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 转换为 DataFrame
|
||||||
|
df = pd.DataFrame(all_ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
|
||||||
|
# 转换时间戳为日期索引
|
||||||
|
df.index = pd.DatetimeIndex(pd.to_datetime(df['timestamp'], unit='ms', utc=True)).tz_localize(None).normalize()
|
||||||
|
df.index.name = 'date'
|
||||||
|
df = df[['open', 'high', 'low', 'close', 'volume']]
|
||||||
|
# 过滤日期范围
|
||||||
|
start_ts = pd.Timestamp(start_date)
|
||||||
|
end_ts = pd.Timestamp(end_date)
|
||||||
|
df = df[(df.index >= start_ts) & (df.index <= end_ts)]
|
||||||
|
df['code'] = code
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"CCXT 下载 {code} ({ccxt_code}) 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def fetch_single(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]:
|
||||||
|
"""获取单个标的的数据"""
|
||||||
|
if self._is_china_index(code):
|
||||||
|
return self._fetch_tushare(code, start_date, end_date)
|
||||||
|
elif self._is_crypto(code):
|
||||||
|
return self._fetch_ccxt(code, start_date, end_date, http_proxy)
|
||||||
|
else:
|
||||||
|
return self._fetch_yfinance(code, start_date, end_date)
|
||||||
|
|
||||||
|
def fetch_all(
|
||||||
|
self,
|
||||||
|
code_list, # list[代码] 或 dict{代码: 名称}
|
||||||
|
benchmark_code: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
|
||||||
|
"""
|
||||||
|
批量获取数据
|
||||||
|
注意:由于 Tushare(中国A股) 和 YFinance(美股/加密货币) 的交易日历不同,
|
||||||
|
这里返回的是长格式数据,由调用方分别处理各市场的数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(etf_data, benchmark_data, valid_codes)
|
||||||
|
etf_data: DataFrame with columns [code, close, source], index=date
|
||||||
|
"""
|
||||||
|
all_data = []
|
||||||
|
valid_codes = []
|
||||||
|
|
||||||
|
# 兼容列表和字典格式
|
||||||
|
if isinstance(code_list, dict):
|
||||||
|
codes = list(code_list.keys())
|
||||||
|
code_name_map = code_list
|
||||||
|
else:
|
||||||
|
codes = code_list
|
||||||
|
code_name_map = {c: c for c in codes}
|
||||||
|
|
||||||
|
print(f"开始下载 {len(codes)} 只标的的数据...")
|
||||||
|
china_codes = [c for c in codes if self._is_china_index(c)]
|
||||||
|
global_codes = [c for c in codes if not self._is_china_index(c)]
|
||||||
|
print(f" 中国A股指数: {len(china_codes)} 只")
|
||||||
|
print(f" 港股/美股/加密货币: {len(global_codes)} 只")
|
||||||
|
|
||||||
|
# 检查是否需要启动 socks2http 代理(用于加密货币)
|
||||||
|
crypto_codes = [c for c in codes if self._is_crypto(c)]
|
||||||
|
http_proxy = None
|
||||||
|
socks2http_proc = None
|
||||||
|
|
||||||
|
# 只有在 SSH 隧道已建立时才启动 socks2http
|
||||||
|
if crypto_codes and self._tunnel is not None:
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
print(f"\n 启动 socks2http 代理服务(用于加密货币)...")
|
||||||
|
try:
|
||||||
|
# 启动 socks2http.py 作为子进程
|
||||||
|
socks2http_path = Path(__file__).parent / "socks2http.py"
|
||||||
|
socks2http_proc = subprocess.Popen(
|
||||||
|
[sys.executable, str(socks2http_path)],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
time.sleep(2) # 等待代理启动
|
||||||
|
http_proxy = "http://127.0.0.1:8080"
|
||||||
|
print(f" ✓ HTTP 代理已启动: {http_proxy}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ✗ 启动代理失败: {e}")
|
||||||
|
|
||||||
|
# 分别下载数据
|
||||||
|
for code in codes:
|
||||||
|
if self._is_china_index(code):
|
||||||
|
source = "Tushare"
|
||||||
|
elif self._is_crypto(code):
|
||||||
|
source = "CCXT/OKX"
|
||||||
|
else:
|
||||||
|
source = "YFinance"
|
||||||
|
|
||||||
|
name = code_name_map.get(code, code)
|
||||||
|
print(f" 下载 {code} ({name}) - {source}...", end=" ")
|
||||||
|
|
||||||
|
# 加密货币使用 HTTP 代理
|
||||||
|
proxy = http_proxy if self._is_crypto(code) else None
|
||||||
|
data = self.fetch_single(code, start_date, end_date, proxy)
|
||||||
|
|
||||||
|
if data is not None and len(data) > 0:
|
||||||
|
# 标准化数据格式
|
||||||
|
data = data.copy()
|
||||||
|
data['source'] = source
|
||||||
|
# 确保索引是日期格式且无时区,只保留日期部分(去掉时间)
|
||||||
|
data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
|
||||||
|
all_data.append(data[['code', 'close', 'source']])
|
||||||
|
valid_codes.append(code)
|
||||||
|
print(f"✓ {len(data)} 条")
|
||||||
|
else:
|
||||||
|
print("✗ 无数据")
|
||||||
|
|
||||||
|
# 关闭 socks2http 代理
|
||||||
|
if socks2http_proc:
|
||||||
|
socks2http_proc.terminate()
|
||||||
|
socks2http_proc.wait()
|
||||||
|
print(f"\n socks2http 代理已关闭")
|
||||||
|
|
||||||
|
if not all_data:
|
||||||
|
return None, None, []
|
||||||
|
|
||||||
|
# 检查数据源类型
|
||||||
|
sources = set(d['source'].iloc[0] for d in all_data)
|
||||||
|
|
||||||
|
if len(sources) == 1:
|
||||||
|
# 单一数据源:转换为宽格式(向后兼容)
|
||||||
|
all_df = pd.concat(all_data, ignore_index=False)
|
||||||
|
all_df = all_df.reset_index()
|
||||||
|
all_df['date'] = pd.to_datetime(all_df['date'], utc=True).dt.tz_localize(None)
|
||||||
|
etf_data = all_df.pivot_table(
|
||||||
|
index='date',
|
||||||
|
columns='code',
|
||||||
|
values='close',
|
||||||
|
aggfunc='first'
|
||||||
|
)
|
||||||
|
print(f"\n数据整理完成 (单一数据源 {list(sources)[0]}):")
|
||||||
|
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
|
||||||
|
print(f" 交易日数: {len(etf_data)}")
|
||||||
|
print(f" 有效标的: {len(etf_data.columns)} 只")
|
||||||
|
else:
|
||||||
|
# 多数据源:以主市场(Tushare/A股)为基准,其他市场数据前向填充
|
||||||
|
print(f"\n数据整理完成 (多数据源 - 以A股交易日为基准):")
|
||||||
|
|
||||||
|
# 合并所有数据(索引已经是标准化后的日期)
|
||||||
|
all_df = pd.concat(all_data, ignore_index=False)
|
||||||
|
all_df = all_df.reset_index()
|
||||||
|
# 重命名索引列为 date
|
||||||
|
if 'index' in all_df.columns:
|
||||||
|
all_df = all_df.rename(columns={'index': 'date'})
|
||||||
|
# 确保 date 列是日期格式(不含时间)
|
||||||
|
all_df['date'] = pd.to_datetime(all_df['date']).dt.normalize()
|
||||||
|
|
||||||
|
# 透视为宽格式
|
||||||
|
etf_data = all_df.pivot_table(
|
||||||
|
index='date',
|
||||||
|
columns='code',
|
||||||
|
values='close',
|
||||||
|
aggfunc='first'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取主市场(Tushare)的交易日历
|
||||||
|
tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
|
||||||
|
if tushare_codes:
|
||||||
|
# 使用第一个A股代码的日期作为主市场交易日
|
||||||
|
primary_dates = etf_data[tushare_codes[0]].dropna().index
|
||||||
|
print(f" 主市场交易日: {len(primary_dates)} 天")
|
||||||
|
|
||||||
|
# 重新索引到主市场交易日,使用前向填充
|
||||||
|
etf_data = etf_data.reindex(primary_dates)
|
||||||
|
|
||||||
|
# 对每个非主市场代码进行前向填充
|
||||||
|
yfinance_codes = [c for c in valid_codes if not self._is_china_index(c)]
|
||||||
|
for code in yfinance_codes:
|
||||||
|
if code in etf_data.columns:
|
||||||
|
# 前向填充:用最近的有效价格填充休市日的数据
|
||||||
|
etf_data[code] = etf_data[code].ffill()
|
||||||
|
# 对于开头的NaN,用后向填充
|
||||||
|
etf_data[code] = etf_data[code].bfill()
|
||||||
|
|
||||||
|
print(f" 非主市场标的: {len(yfinance_codes)} 只 (已前向填充)")
|
||||||
|
|
||||||
|
print(f" 时间范围: {etf_data.index[0]} ~ {etf_data.index[-1]}")
|
||||||
|
print(f" 交易日数: {len(etf_data)}")
|
||||||
|
print(f" 有效标的: {len(etf_data.columns)} 只")
|
||||||
|
|
||||||
|
# 获取基准数据
|
||||||
|
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
|
||||||
|
if benchmark_data is not None:
|
||||||
|
# 标准化日期索引(无时区,只保留日期部分)
|
||||||
|
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
|
||||||
|
print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
|
||||||
|
|
||||||
|
return etf_data, benchmark_data, valid_codes
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""上下文管理器入口"""
|
||||||
|
if self.ssh_config.get("enabled"):
|
||||||
|
self._tunnel = SSHTunnelManager(self.ssh_config)
|
||||||
|
self._tunnel.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""上下文管理器出口"""
|
||||||
|
if self._tunnel:
|
||||||
|
self._tunnel.stop()
|
||||||
168
core/datasource/socks2http.py
Normal file
168
core/datasource/socks2http.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
SOCKS5 转 HTTP 代理转发工具
|
||||||
|
将 SSH 隧道的 SOCKS5 代理 (1080) 转为 HTTP 代理 (8080)
|
||||||
|
供 CCXT 等只支持 HTTP 代理的库使用
|
||||||
|
"""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
import threading
|
||||||
|
import select
|
||||||
|
import sys
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
|
class Socks2Http:
|
||||||
|
def __init__(self, socks_host='127.0.0.1', socks_port=1080, http_port=8080):
|
||||||
|
self.socks_host = socks_host
|
||||||
|
self.socks_port = socks_port
|
||||||
|
self.http_port = http_port
|
||||||
|
self.server = None
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
"""启动 HTTP 代理服务器"""
|
||||||
|
self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
self.server.bind(('127.0.0.1', self.http_port))
|
||||||
|
self.server.listen(5)
|
||||||
|
self.running = True
|
||||||
|
|
||||||
|
print(f"HTTP 代理已启动: http://127.0.0.1:{self.http_port}")
|
||||||
|
print(f"转发到 SOCKS5: {self.socks_host}:{self.socks_port}")
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
client, addr = self.server.accept()
|
||||||
|
thread = threading.Thread(target=self._handle_client, args=(client,))
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
except Exception as e:
|
||||||
|
if self.running:
|
||||||
|
print(f"接受连接错误: {e}")
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止代理服务器"""
|
||||||
|
self.running = False
|
||||||
|
if self.server:
|
||||||
|
self.server.close()
|
||||||
|
print("HTTP 代理已停止")
|
||||||
|
|
||||||
|
def _handle_client(self, client):
|
||||||
|
"""处理客户端连接"""
|
||||||
|
try:
|
||||||
|
# 读取 HTTP 请求
|
||||||
|
request = client.recv(4096)
|
||||||
|
if not request:
|
||||||
|
client.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
# 解析 CONNECT 请求
|
||||||
|
first_line = request.split(b'\r\n')[0].decode('utf-8', errors='ignore')
|
||||||
|
|
||||||
|
if first_line.startswith('CONNECT'):
|
||||||
|
# HTTPS 代理
|
||||||
|
parts = first_line.split()
|
||||||
|
if len(parts) >= 2:
|
||||||
|
target = parts[1]
|
||||||
|
host, port = target.rsplit(':', 1)
|
||||||
|
port = int(port)
|
||||||
|
|
||||||
|
# 连接到 SOCKS5 代理
|
||||||
|
remote = self._connect_via_socks5(host, port)
|
||||||
|
if remote:
|
||||||
|
client.send(b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||||
|
self._relay(client, remote)
|
||||||
|
else:
|
||||||
|
client.send(b'HTTP/1.1 502 Bad Gateway\r\n\r\n')
|
||||||
|
else:
|
||||||
|
# HTTP 代理
|
||||||
|
lines = first_line.split()
|
||||||
|
if len(lines) >= 2:
|
||||||
|
url = lines[1]
|
||||||
|
parsed = urlparse(url)
|
||||||
|
host = parsed.hostname
|
||||||
|
port = parsed.port or 80
|
||||||
|
|
||||||
|
# 连接到 SOCKS5 代理
|
||||||
|
remote = self._connect_via_socks5(host, port)
|
||||||
|
if remote:
|
||||||
|
# 修改请求,去掉完整 URL
|
||||||
|
new_request = request.replace(
|
||||||
|
f'{lines[0]} {url} '.encode(),
|
||||||
|
f'{lines[0]} {parsed.path or "/"}{"?" + parsed.query if parsed.query else ""} '.encode()
|
||||||
|
)
|
||||||
|
remote.send(new_request)
|
||||||
|
self._relay(client, remote)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理客户端错误: {e}")
|
||||||
|
finally:
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
def _connect_via_socks5(self, host, port):
|
||||||
|
"""通过 SOCKS5 代理连接目标服务器"""
|
||||||
|
try:
|
||||||
|
# 连接到 SOCKS5 代理
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.settimeout(30)
|
||||||
|
sock.connect((self.socks_host, self.socks_port))
|
||||||
|
|
||||||
|
# SOCKS5 握手
|
||||||
|
# 1. 发送认证方法
|
||||||
|
sock.send(b'\x05\x01\x00') # VER=5, NMETHODS=1, METHOD=0 (无认证)
|
||||||
|
resp = sock.recv(2)
|
||||||
|
if resp[0] != 0x05 or resp[1] != 0x00:
|
||||||
|
sock.close()
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 2. 发送连接请求
|
||||||
|
req = b'\x05\x01\x00\x03' # VER=5, CMD=CONNECT, RSV=0, ATYP=DOMAIN
|
||||||
|
req += bytes([len(host)]) + host.encode()
|
||||||
|
req += bytes([(port >> 8) & 0xFF, port & 0xFF])
|
||||||
|
sock.send(req)
|
||||||
|
|
||||||
|
# 3. 读取响应
|
||||||
|
resp = sock.recv(10)
|
||||||
|
if len(resp) < 4 or resp[1] != 0x00:
|
||||||
|
sock.close()
|
||||||
|
return None
|
||||||
|
|
||||||
|
return sock
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"SOCKS5 连接错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _relay(self, client, remote):
|
||||||
|
"""双向转发数据"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
readable, _, _ = select.select([client, remote], [], [], 60)
|
||||||
|
if not readable:
|
||||||
|
break
|
||||||
|
|
||||||
|
if client in readable:
|
||||||
|
data = client.recv(4096)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
remote.send(data)
|
||||||
|
|
||||||
|
if remote in readable:
|
||||||
|
data = remote.recv(4096)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
|
client.send(data)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
client.close()
|
||||||
|
remote.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
proxy = Socks2Http(socks_port=1080, http_port=8080)
|
||||||
|
try:
|
||||||
|
proxy.start()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
proxy.stop()
|
||||||
194
core/datasource/yfinance_source.py
Normal file
194
core/datasource/yfinance_source.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""
|
||||||
|
YFinance 数据源模块
|
||||||
|
支持通过 SSH 隧道访问(用于绕过网络限制)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import subprocess
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
import pandas as pd
|
||||||
|
import yfinance as yf
|
||||||
|
import urllib3
|
||||||
|
|
||||||
|
# 禁用 SSL 警告
|
||||||
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
|
||||||
|
class SSHTunnelManager:
|
||||||
|
"""SSH 隧道管理器"""
|
||||||
|
|
||||||
|
def __init__(self, config: dict):
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
self.enabled = config.get("enabled", False)
|
||||||
|
self.host = config.get("host", "")
|
||||||
|
self.port = config.get("port", 22)
|
||||||
|
self.username = config.get("username", "")
|
||||||
|
self.local_port = config.get("local_port", 1080)
|
||||||
|
self._process: Optional[subprocess.Popen] = None
|
||||||
|
|
||||||
|
# 处理 key_path:如果是相对路径,转换为绝对路径
|
||||||
|
key_path = config.get("key_path", "")
|
||||||
|
if key_path and not os.path.isabs(key_path):
|
||||||
|
# 相对于项目根目录
|
||||||
|
project_root = Path(__file__).parent.parent.parent
|
||||||
|
key_path = str(project_root / key_path)
|
||||||
|
self.key_path = key_path
|
||||||
|
|
||||||
|
def start(self) -> bool:
|
||||||
|
"""启动 SSH 隧道"""
|
||||||
|
if not self.enabled:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not all([self.host, self.username, self.key_path]):
|
||||||
|
print("SSH 配置不完整,跳过隧道建立")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"ssh",
|
||||||
|
"-N",
|
||||||
|
"-D", f"127.0.0.1:{self.local_port}",
|
||||||
|
"-o", "StrictHostKeyChecking=no",
|
||||||
|
"-o", "UserKnownHostsFile=/dev/null",
|
||||||
|
"-i", self.key_path,
|
||||||
|
"-p", str(self.port),
|
||||||
|
f"{self.username}@{self.host}"
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
)
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
if self._process.poll() is not None:
|
||||||
|
stdout, stderr = self._process.communicate()
|
||||||
|
print("✗ SSH 隧道建立失败")
|
||||||
|
if stderr:
|
||||||
|
print(f"错误: {stderr.decode()}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 设置代理环境变量
|
||||||
|
proxy_url = f"socks5://127.0.0.1:{self.local_port}"
|
||||||
|
os.environ["HTTP_PROXY"] = proxy_url
|
||||||
|
os.environ["HTTPS_PROXY"] = proxy_url
|
||||||
|
os.environ["ALL_PROXY"] = proxy_url
|
||||||
|
|
||||||
|
print(f"✓ SSH 隧道已建立: {proxy_url}")
|
||||||
|
time.sleep(1)
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ SSH 隧道异常: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止 SSH 隧道"""
|
||||||
|
if self._process:
|
||||||
|
self._process.terminate()
|
||||||
|
self._process.wait()
|
||||||
|
# 清理环境变量
|
||||||
|
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
print("SSH 隧道已关闭")
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class YFinanceDataSource:
|
||||||
|
"""YFinance 数据源"""
|
||||||
|
|
||||||
|
def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
|
||||||
|
self.ssh_config = ssh_config or {}
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self._tunnel: Optional[SSHTunnelManager] = None
|
||||||
|
|
||||||
|
def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||||
|
"""获取单个 ETF 数据"""
|
||||||
|
try:
|
||||||
|
ticker = yf.Ticker(code)
|
||||||
|
data = ticker.history(start=start_date, end=end_date)
|
||||||
|
|
||||||
|
if len(data) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 标准化列名
|
||||||
|
data = data.rename(columns={
|
||||||
|
"Open": "open",
|
||||||
|
"High": "high",
|
||||||
|
"Low": "low",
|
||||||
|
"Close": "close",
|
||||||
|
"Volume": "volume",
|
||||||
|
})
|
||||||
|
|
||||||
|
# 添加代码列
|
||||||
|
data["code"] = code
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"下载 {code} 失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def fetch_all(
|
||||||
|
self,
|
||||||
|
code_list: list,
|
||||||
|
benchmark_code: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
|
||||||
|
"""
|
||||||
|
批量获取 ETF 数据和基准数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(etf_data, benchmark_data, valid_codes)
|
||||||
|
"""
|
||||||
|
all_data = []
|
||||||
|
valid_codes = []
|
||||||
|
|
||||||
|
print(f"开始下载 {len(code_list)} 只 ETF 数据...")
|
||||||
|
|
||||||
|
for code in code_list:
|
||||||
|
data = self.fetch_single(code, start_date, end_date)
|
||||||
|
if data is not None and len(data) > 0:
|
||||||
|
all_data.append(data)
|
||||||
|
valid_codes.append(code)
|
||||||
|
print(f" ✓ {code}: {len(data)} 条")
|
||||||
|
else:
|
||||||
|
print(f" ✗ {code}: 无数据")
|
||||||
|
|
||||||
|
if not all_data:
|
||||||
|
return None, None, []
|
||||||
|
|
||||||
|
# 合并数据
|
||||||
|
etf_data = pd.concat(all_data, ignore_index=False)
|
||||||
|
|
||||||
|
# 获取基准数据
|
||||||
|
benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
|
||||||
|
if benchmark_data is not None:
|
||||||
|
print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")
|
||||||
|
|
||||||
|
return etf_data, benchmark_data, valid_codes
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
"""上下文管理器入口"""
|
||||||
|
if self.ssh_config.get("enabled"):
|
||||||
|
self._tunnel = SSHTunnelManager(self.ssh_config)
|
||||||
|
self._tunnel.start()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
"""上下文管理器出口"""
|
||||||
|
if self._tunnel:
|
||||||
|
self._tunnel.stop()
|
||||||
@@ -10,7 +10,7 @@ import numpy as np
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from strategies.base import BacktestStrategy
|
from strategies.base import BacktestStrategy
|
||||||
from core.data.hybrid_source import HybridDataSource
|
from core.datasource.hybrid_source import HybridDataSource
|
||||||
from core.factors.momentum import compute_factors, calculate_daily_return
|
from core.factors.momentum import compute_factors, calculate_daily_return
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user