Files
etf/core/datasource/yfinance_source.py
aszerW 6454e6823f fix(datasource): 修正混合数据源导入路径错误
- 修正 strategies.rotation.engine 中 hybrid_source 模块导入路径错误
- 新增 core.datasource 目录下多个数据源实现模块
- 增加 Akshare 数据源支持 A股指数数据拉取
- 实现数据缓存管理机制,支持本地数据缓存读写
- 新增 YFinance 数据源,支持通过 SSH 隧道访问美股和港股数据
- 实现混合数据源支持 A股/Tushare、港美股/YFinance、加密货币/CCXT 的统一访问
- 集成 SSH 隧道管理,支持 SOCKS5 转 HTTP 代理转发
- 新增 socks2http.py 代理转发工具,解决 CCXT 仅支持 HTTP 代理问题
- 修改 rotation.yaml 加密货币注释,明确使用 OKX 现货和 SSH->HTTP 代理访问
- 删除.gitignore中无用的 data/ 忽略规则,保留 test/ 文件夹忽略规则
2026-03-25 01:32:33 +08:00

195 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YFinance 数据源模块
支持通过 SSH 隧道访问(用于绕过网络限制)
"""
import os
import time
import subprocess
from typing import Optional, Tuple
from datetime import datetime
import pandas as pd
import yfinance as yf
import urllib3
# Suppress urllib3's InsecureRequestWarning for the whole process.
# NOTE: this runs at import time, so merely importing this module silences
# the warning for every other user of urllib3 in the same interpreter.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class SSHTunnelManager:
    """Manage an ``ssh -N -D`` subprocess that exposes a local SOCKS5 proxy.

    While the tunnel is up, the ``HTTP_PROXY``, ``HTTPS_PROXY`` and
    ``ALL_PROXY`` environment variables of the current process point at the
    local SOCKS5 endpoint. ``stop()`` puts back whatever values those
    variables had before ``start()`` changed them, and leaves the environment
    alone when no tunnel was actually established.

    The class is usable as a context manager. A failed ``start()`` inside
    ``__enter__`` is reported on stdout but does not raise.
    """

    # Environment variables redirected to the tunnel while it is running.
    _PROXY_ENV_KEYS = ("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY")

    def __init__(self, config: dict):
        """Read the tunnel settings from *config*.

        Args:
            config: Mapping with the optional keys ``enabled`` (default
                ``False``), ``host``, ``port`` (default 22), ``username``,
                ``local_port`` (default 1080) and ``key_path``. A relative
                ``key_path`` is resolved against
                ``Path(__file__).parent.parent.parent``, which this module
                treats as the project root.
        """
        from pathlib import Path

        self.enabled = config.get("enabled", False)
        self.host = config.get("host", "")
        self.port = config.get("port", 22)
        self.username = config.get("username", "")
        self.local_port = config.get("local_port", 1080)
        self._process: Optional[subprocess.Popen] = None
        # Values the proxy variables had before start() overwrote them
        # (None means "was not set"). Empty while no tunnel is active.
        self._saved_env: dict = {}

        # Turn a relative key path into an absolute one under the project root.
        key_path = config.get("key_path", "")
        if key_path and not os.path.isabs(key_path):
            project_root = Path(__file__).parent.parent.parent
            key_path = str(project_root / key_path)
        self.key_path = key_path

    def start(self) -> bool:
        """Start the SSH tunnel and point the proxy variables at it.

        Returns:
            ``True`` when the tunnel is disabled (nothing to do), already
            running, or was started successfully; ``False`` when the
            configuration is incomplete or ssh could not be started.
        """
        if not self.enabled:
            return True
        if not all([self.host, self.username, self.key_path]):
            print("SSH 配置不完整,跳过隧道建立")
            return False
        # Do not spawn (and leak) a second ssh process if one is still alive.
        if self._process is not None and self._process.poll() is None:
            return True

        print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
        # SECURITY NOTE: StrictHostKeyChecking=no together with
        # UserKnownHostsFile=/dev/null disables host-key verification, so the
        # connection is open to man-in-the-middle attacks. Kept as-is because
        # callers may depend on the non-interactive behaviour.
        cmd = [
            "ssh",
            "-N",
            "-D", f"127.0.0.1:{self.local_port}",
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-i", self.key_path,
            "-p", str(self.port),
            f"{self.username}@{self.host}",
        ]
        try:
            self._process = subprocess.Popen(
                cmd,
                # stdout is never read, so do not attach a pipe to it.
                stdout=subprocess.DEVNULL,
                stderr=subprocess.PIPE,
            )
            # Give ssh a moment to either connect or fail.
            time.sleep(2)
            if self._process.poll() is not None:
                _, stderr = self._process.communicate()
                # Forget the dead process so stop() does not treat it as a
                # live tunnel.
                self._process = None
                print("✗ SSH 隧道建立失败")
                if stderr:
                    print(f"错误: {stderr.decode(errors='replace')}")
                return False

            # Remember the caller's proxy settings, then redirect to the tunnel.
            proxy_url = f"socks5://127.0.0.1:{self.local_port}"
            self._saved_env = {
                key: os.environ.get(key) for key in self._PROXY_ENV_KEYS
            }
            for key in self._PROXY_ENV_KEYS:
                os.environ[key] = proxy_url
            print(f"✓ SSH 隧道已建立: {proxy_url}")
            time.sleep(1)
            return True
        except Exception as e:
            print(f"✗ SSH 隧道异常: {e}")
            # Make sure a half-started tunnel leaves nothing behind.
            self._cleanup()
            return False

    def stop(self):
        """Stop the tunnel and restore the proxy environment variables.

        Does nothing (and prints nothing) when no tunnel is active, so it is
        safe to call repeatedly or after a failed ``start()``.
        """
        if self._cleanup():
            print("SSH 隧道已关闭")

    def _cleanup(self) -> bool:
        """Terminate the ssh process and restore the saved environment.

        Returns:
            ``True`` if there was a process or saved environment to clean up.
        """
        process, self._process = self._process, None
        saved_env, self._saved_env = self._saved_env, {}
        if process is None and not saved_env:
            return False

        if process is not None:
            self._terminate(process)

        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
        return True

    @staticmethod
    def _terminate(process: subprocess.Popen) -> None:
        """Stop *process*, escalating to ``kill`` if ``terminate`` is ignored."""
        if process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()
        if process.stderr is not None:
            process.stderr.close()

    def __enter__(self):
        """Start the tunnel (best effort) and return this manager."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the tunnel; exceptions from the ``with`` body propagate."""
        self.stop()
class YFinanceDataSource:
    """Price-history data source backed by ``yfinance``.

    When used as a context manager and ``ssh_config["enabled"]`` is truthy,
    an ``SSHTunnelManager`` is created and started on entry and stopped on
    exit. ``use_cache`` is stored on the instance; nothing in this class
    reads it.
    """

    def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
        """Store the configuration; no tunnel is opened here.

        Args:
            ssh_config: Settings handed to ``SSHTunnelManager``; a falsy value
                is replaced by an empty dict.
            use_cache: Flag kept as ``self.use_cache``.
        """
        self._tunnel: Optional[SSHTunnelManager] = None
        self.use_cache = use_cache
        self.ssh_config = ssh_config if ssh_config else {}

    def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Download the history of one ticker.

        Args:
            code: Ticker symbol passed to ``yf.Ticker``.
            start_date: Forwarded as ``start`` to ``Ticker.history``.
            end_date: Forwarded as ``end`` to ``Ticker.history``.

        Returns:
            The history frame with the OHLCV columns renamed to lower case and
            a ``code`` column added, or ``None`` when the result is empty or
            any exception occurs (the error is printed, not raised).
        """
        # Map yfinance's capitalised column names onto the lower-case ones.
        lower_case_columns = {
            "Open": "open",
            "High": "high",
            "Low": "low",
            "Close": "close",
            "Volume": "volume",
        }
        try:
            history = yf.Ticker(code).history(start=start_date, end=end_date)
            if not len(history):
                return None
            frame = history.rename(columns=lower_case_columns)
            # Tag every row with the ticker it belongs to.
            frame["code"] = code
            return frame
        except Exception as e:
            print(f"下载 {code} 失败: {e}")
            return None

    def fetch_all(
        self,
        code_list: list,
        benchmark_code: str,
        start_date: str,
        end_date: str,
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
        """Download every ticker in *code_list* plus the benchmark.

        Tickers are fetched one by one with ``fetch_single``; those that
        yield no data are reported and skipped. The benchmark is only
        requested when at least one ticker produced data.

        Returns:
            ``(etf_data, benchmark_data, valid_codes)`` where ``etf_data`` is
            the concatenation of all successful frames (original index kept),
            ``benchmark_data`` may be ``None``, and ``valid_codes`` lists the
            tickers that returned data. ``(None, None, [])`` when nothing
            could be downloaded.
        """
        frames = []
        fetched_codes = []
        print(f"开始下载 {len(code_list)} 只 ETF 数据...")

        for code in code_list:
            frame = self.fetch_single(code, start_date, end_date)
            if frame is None or len(frame) == 0:
                print(f"{code}: 无数据")
                continue
            frames.append(frame)
            fetched_codes.append(code)
            print(f"{code}: {len(frame)}")

        if len(frames) == 0:
            return None, None, []

        # Stack the per-ticker frames, keeping each frame's own index.
        etf_data = pd.concat(frames, ignore_index=False)

        # The benchmark goes through the same single-ticker path.
        benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
        if benchmark_data is not None:
            print(f" ✓ 基准 {benchmark_code}: {len(benchmark_data)}")

        return etf_data, benchmark_data, fetched_codes

    def __enter__(self):
        """Open the SSH tunnel if the configuration enables it."""
        if not self.ssh_config.get("enabled"):
            return self
        self._tunnel = SSHTunnelManager(self.ssh_config)
        self._tunnel.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the SSH tunnel if one was created on entry."""
        tunnel = self._tunnel
        if tunnel:
            tunnel.stop()