feat: 创建数据源模块 datasource/
核心功能:
- ssh_tunnel.py: SSH隧道管理器(连接香港ECS)
- tushare_source.py: A股数据获取(指数、ETF、期货)
- yfinance_source.py: 境外数据获取(港股、美股)
- hybrid_source.py: 混合数据源(整合所有)
使用方式:
from datasource import HybridDataSource
source = HybridDataSource.from_yaml('config/strategies/rotation.yaml')
result = source.fetch_all(code_config)  # code_config: {代码: {name, etf, market}}
更新 RotationStrategy 使用新数据源模块
This commit is contained in:
19
datasource/__init__.py
Normal file
19
datasource/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
数据源模块
|
||||
|
||||
核心数据获取能力:
|
||||
- A股数据:Tushare(指数、ETF、期货)
|
||||
- 境外数据:YFinance(港股、美股)通过SSH隧道
|
||||
"""
|
||||
|
||||
from .ssh_tunnel import SSHTunnelManager
|
||||
from .tushare_source import TushareSource
|
||||
from .yfinance_source import YFinanceSource
|
||||
from .hybrid_source import HybridDataSource
|
||||
|
||||
__all__ = [
|
||||
'SSHTunnelManager',
|
||||
'TushareSource',
|
||||
'YFinanceSource',
|
||||
'HybridDataSource',
|
||||
]
|
||||
285
datasource/hybrid_source.py
Normal file
285
datasource/hybrid_source.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
混合数据源
|
||||
|
||||
整合 Tushare(A股) + YFinance(境外)数据获取
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
from .ssh_tunnel import SSHTunnelManager
|
||||
from .tushare_source import TushareSource
|
||||
from .yfinance_source import YFinanceSource
|
||||
|
||||
|
||||
class HybridDataSource:
    """Data source that routes each code to Tushare or YFinance.

    - Codes recognised by ``TushareSource.is_china_index`` or
      ``TushareSource.is_futures`` are fetched through Tushare.
    - Every other code is fetched through YFinance, after the optional SSH
      tunnel has been started.

    Usage:
        from datasource import HybridDataSource

        with HybridDataSource.from_yaml('config/strategies/rotation.yaml') as source:
            result = source.fetch_all(code_config)
    """

    def __init__(
        self,
        ssh_config: Optional[dict] = None,
        use_cache: bool = True,
        cache_dir: str = "data/etf_cache/daily"
    ):
        """Create the underlying Tushare / YFinance sources.

        Args:
            ssh_config: SSH tunnel settings; the tunnel is only started when
                its ``enabled`` key is truthy.
            use_cache: Stored on the instance (not read inside this class).
            cache_dir: Stored on the instance (not read inside this class).
        """
        self.ssh_config = ssh_config or {}
        self.use_cache = use_cache
        self.cache_dir = cache_dir

        # Concrete data sources.
        self._tushare = TushareSource()
        self._yfinance = YFinanceSource()

        # The tunnel is created lazily on first use.
        self._tunnel: Optional[SSHTunnelManager] = None
        # Outcome of the single start attempt, so later calls report the truth
        # without spawning another ssh process.
        self._tunnel_ok = False

    @classmethod
    def from_yaml(cls, config_path: str) -> 'HybridDataSource':
        """Build an instance from the ``ssh_tunnel`` / ``use_cache`` YAML keys."""
        import yaml

        with open(config_path, 'r', encoding='utf-8') as f:
            # An empty YAML document loads as None; treat it as "no settings".
            config = yaml.safe_load(f) or {}

        return cls(
            ssh_config=config.get('ssh_tunnel', {}),
            use_cache=config.get('use_cache', True)
        )

    def _start_tunnel(self) -> bool:
        """Start the SSH tunnel once, if enabled.

        Returns:
            True when no tunnel is required or the tunnel started; False when
            the (single) start attempt failed.
        """
        if not self.ssh_config.get('enabled'):
            return True

        if self._tunnel is None:
            self._tunnel = SSHTunnelManager(self.ssh_config)
            self._tunnel_ok = self._tunnel.start()
        return self._tunnel_ok

    def _stop_tunnel(self):
        """Stop the SSH tunnel if one was created."""
        if self._tunnel:
            self._tunnel.stop()
            self._tunnel = None
            self._tunnel_ok = False

    def _uses_tushare(self, code: str) -> bool:
        """Return True when ``code`` is routed to Tushare rather than YFinance."""
        return self._tushare.is_china_index(code) or self._tushare.is_futures(code)

    @staticmethod
    def _naive_dates(index) -> pd.DatetimeIndex:
        """Return ``index`` as tz-naive timestamps truncated to midnight."""
        return pd.to_datetime(index, utc=True).tz_localize(None).normalize()

    @staticmethod
    def _to_wide(frames: list, value_col: str) -> Optional[pd.DataFrame]:
        """Stack long frames and pivot to one column per code; None if empty."""
        if not frames:
            return None
        return pd.concat(frames).pivot(columns='code', values=value_col)

    def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch data for a single code from the matching source.

        Args:
            code: Instrument code.
            start_date: Start date.
            end_date: End date.

        Returns:
            Whatever the selected source returns (a DataFrame or None).
        """
        if self._uses_tushare(code):
            return self._tushare.fetch(code, start_date, end_date)

        # YFinance goes through the SSH tunnel when one is configured.
        self._start_tunnel()
        return self._yfinance.fetch(code, start_date, end_date)

    def fetch_all(
        self,
        code_config: dict,
        benchmark_code: str = "000300.SH",
        start_date: str = "2019-01-01",
        end_date: Optional[str] = None
    ) -> Tuple[
        Optional[pd.DataFrame],   # index_data: index close prices (wide)
        Optional[pd.DataFrame],   # etf_data: ETF prices (wide)
        Optional[pd.DataFrame],   # etf_nav_data: ETF NAV (wide)
        Optional[pd.DataFrame],   # benchmark_data: benchmark series
        List[str],                # valid_codes: codes that returned data
        Dict[str, pd.DataFrame]   # index_ohlcv_data: per-code frames
    ]:
        """Download every configured index, its mapped ETF, and the benchmark.

        Args:
            code_config: Mapping ``{code: {name, etf, market}}``; ``etf`` is
                optional and names the ETF to download for that index.
            benchmark_code: Benchmark index code.
            start_date: Start date.
            end_date: End date; defaults to today when None.

        Returns:
            (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes,
            index_ohlcv_data). The three wide frames are None when nothing was
            downloaded for them.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')

        # Start the SSH tunnel up front.
        self._start_tunnel()

        index_codes = list(code_config.keys())
        etf_codes = {idx_code: cfg['etf'] for idx_code, cfg in code_config.items() if cfg.get('etf')}

        print(f"开始下载 {len(index_codes)} 只标的的数据...")
        print(f" 指数代码: {len(index_codes)} 只")
        print(f" ETF映射: {len(etf_codes)} 只")

        # Per-source counts, for the progress log only.
        china_codes = [c for c in index_codes if self._tushare.is_china_index(c)]
        futures_codes = [c for c in index_codes if self._tushare.is_futures(c)]
        yf_codes = [c for c in index_codes if not self._uses_tushare(c)]

        print(f" 中国A股指数: {len(china_codes)} 只")
        print(f" 期货合约: {len(futures_codes)} 只")
        print(f" 港股/美股: {len(yf_codes)} 只")

        # --- index data -----------------------------------------------------
        print("\n [1/2] 下载指数数据...")
        index_data_list = []
        index_ohlcv_data = {}
        valid_codes = []

        for code in index_codes:
            name = code_config[code].get('name', code)
            source = "Tushare" if self._uses_tushare(code) else "YFinance"

            print(f" 下载 {code} ({name}) - {source}...", end=" ")

            data = self.fetch_single(code, start_date, end_date)

            if data is not None and len(data) > 0:
                # Work on a copy so the source's frame is not modified.
                data = data.copy()
                data['source'] = source
                data['code'] = code
                data.index = self._naive_dates(data.index)

                index_ohlcv_data[code] = data.copy()
                index_data_list.append(data[['code', 'close', 'source']])
                valid_codes.append(code)
                print(f"✓ {len(data)} 条")
            else:
                print("✗ 无数据")

        # --- ETF data -------------------------------------------------------
        etf_data_list = []
        etf_nav_data_list = []

        if etf_codes:
            print("\n [2/2] 下载ETF数据...")

            for idx_code, etf_code in etf_codes.items():
                print(f" 下载ETF {etf_code} (对应指数 {idx_code})...", end=" ")

                etf_data = self._tushare.fetch_etf(etf_code, start_date, end_date)
                etf_nav = self._tushare.fetch_etf_nav(etf_code, start_date, end_date)

                if etf_data is not None and len(etf_data) > 0:
                    etf_data.index = self._naive_dates(etf_data.index)
                    etf_data_list.append(etf_data[['code', 'close']])

                    price_count = len(etf_data)
                    # Compare against None explicitly: the truth value of a
                    # DataFrame is ambiguous and raises ValueError.
                    nav_count = len(etf_nav) if etf_nav is not None else 0

                    print(f"✓ 价格{price_count}条 净值{nav_count}条")
                else:
                    print("✗ 无数据")

                # NAV is kept even when the price download failed.
                if etf_nav is not None and len(etf_nav) > 0:
                    etf_nav.index = self._naive_dates(etf_nav.index)
                    etf_nav_data_list.append(etf_nav[['code', 'nav']])

        # --- reshape to wide format -----------------------------------------
        index_data = self._to_wide(index_data_list, 'close')
        etf_data = self._to_wide(etf_data_list, 'close')
        etf_nav_data = self._to_wide(etf_nav_data_list, 'nav')

        # --- benchmark ------------------------------------------------------
        benchmark_data = self._tushare.fetch_index(benchmark_code, start_date, end_date)
        if benchmark_data is not None:
            benchmark_data.index = self._naive_dates(benchmark_data.index)
            print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")

        return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data

    def __enter__(self):
        """Start the tunnel (if enabled) and return this instance."""
        self._start_tunnel()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the tunnel on leaving the ``with`` block."""
        self._stop_tunnel()
|
||||
|
||||
|
||||
# 简化接口
|
||||
def fetch_rotation_data(config_path: str = "config/strategies/rotation.yaml") -> dict:
    """Load the rotation-strategy YAML and download all data it references.

    Reads ``code_list``, ``benchmark.code``, ``start_date``, ``end_date``,
    ``ssh_tunnel`` and ``use_cache`` from the config. The SSH tunnel (if
    enabled) is closed again before returning, so no ssh subprocess or proxy
    environment variable is left behind.

    Args:
        config_path: Path of the YAML configuration file.

    Returns:
        A dict with the keys ``index_data``, ``etf_data``, ``etf_nav_data``,
        ``benchmark_data``, ``valid_codes`` and ``index_ohlcv_data``, holding
        the six values returned by ``HybridDataSource.fetch_all``.
    """
    import yaml

    def _as_date_str(value):
        # Unquoted YAML dates load as datetime.date, but the data sources call
        # str.replace on the value; convert date-like objects to 'YYYY-MM-DD'.
        return value.strftime('%Y-%m-%d') if hasattr(value, 'strftime') else value

    with open(config_path, 'r', encoding='utf-8') as f:
        # An empty YAML document loads as None.
        config = yaml.safe_load(f) or {}

    benchmark_cfg = config.get('benchmark') or {}
    today = datetime.now().strftime('%Y-%m-%d')

    # Same constructor arguments HybridDataSource.from_yaml would derive,
    # taken from the config that is already loaded.
    source = HybridDataSource(
        ssh_config=config.get('ssh_tunnel') or {},
        use_cache=config.get('use_cache', True)
    )

    # The context manager guarantees the tunnel is stopped even on errors.
    with source:
        index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \
            source.fetch_all(
                code_config=config.get('code_list') or {},
                benchmark_code=benchmark_cfg.get('code', '000300.SH'),
                start_date=_as_date_str(config.get('start_date')) or '2019-01-01',
                end_date=_as_date_str(config.get('end_date')) or today
            )

    return {
        'index_data': index_data,
        'etf_data': etf_data,
        'etf_nav_data': etf_nav_data,
        'benchmark_data': benchmark_data,
        'valid_codes': valid_codes,
        'index_ohlcv_data': index_ohlcv_data
    }
|
||||
116
datasource/ssh_tunnel.py
Normal file
116
datasource/ssh_tunnel.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
SSH隧道管理器
|
||||
|
||||
通过SSH隧道建立本地SOCKS5代理,用于访问境外数据源
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class SSHTunnelManager:
    """Run ``ssh -N -D`` as a subprocess to expose a local SOCKS5 proxy.

    While the tunnel is up, HTTP_PROXY / HTTPS_PROXY / ALL_PROXY point at the
    local SOCKS port; ``stop()`` puts back whatever values they had before.
    """

    _PROXY_ENV_KEYS = ("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY")

    def __init__(self, config: dict):
        """Read the tunnel settings.

        Args:
            config: SSH settings dict with the keys
                - enabled: whether to open a tunnel at all (default False)
                - host: SSH server address
                - port: SSH port (default 22)
                - username: SSH user (default "root")
                - key_path: private key path; ``~`` is expanded and a relative
                  path is resolved against the parent of this file's directory
                - local_port: local SOCKS5 port (default 1080)
        """
        self.enabled = config.get("enabled", False)
        self.host = config.get("host", "")
        self.port = config.get("port", 22)
        self.username = config.get("username", "root")
        self.local_port = config.get("local_port", 1080)
        self._process: Optional[subprocess.Popen] = None
        # Proxy variables as they were before start(); restored by stop().
        self._saved_env: dict = {}

        key_path = config.get("key_path", "")
        if key_path:
            # Expand "~" first, otherwise it would be treated as relative.
            key_path = os.path.expanduser(key_path)
            if not os.path.isabs(key_path):
                project_root = Path(__file__).parent.parent
                key_path = str(project_root / key_path)
        self.key_path = key_path

    def start(self) -> bool:
        """Open the tunnel and export the proxy environment variables.

        Returns:
            True if the tunnel is disabled, already running, or started;
            False if the configuration is incomplete or ssh exited/failed.
        """
        if not self.enabled:
            return True

        if not all([self.host, self.username, self.key_path]):
            print("SSH配置不完整,跳过隧道建立")
            return False

        # Already running: do not spawn (and leak) a second ssh process.
        if self._process is not None and self._process.poll() is None:
            return True

        print(f"建立SSH隧道: {self.host}:{self.port} -> 本地SOCKS5端口 {self.local_port}")

        # NOTE(security): host-key verification is disabled below, so the
        # server identity is not checked. Kept as-is because enabling it
        # changes connection behaviour; review before production use.
        cmd = [
            "ssh", "-N", "-D", f"127.0.0.1:{self.local_port}",
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-i", self.key_path,
            "-p", str(self.port),
            f"{self.username}@{self.host}"
        ]

        try:
            # stdout is never read, so it is discarded instead of piped;
            # stderr stays piped for the failure report below.
            self._process = subprocess.Popen(
                cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE
            )
            # Give ssh a moment to either connect or exit with an error.
            time.sleep(2)

            if self._process.poll() is not None:
                _, stderr = self._process.communicate()
                self._process = None
                print("✗ SSH隧道建立失败")
                if stderr:
                    print(f"错误: {stderr.decode(errors='replace')}")
                return False

            # socks5h:// makes the proxy resolve DNS remotely (avoids IPv6 issues).
            proxy_url = f"socks5h://127.0.0.1:{self.local_port}"
            self._saved_env = {key: os.environ.get(key) for key in self._PROXY_ENV_KEYS}
            for key in self._PROXY_ENV_KEYS:
                os.environ[key] = proxy_url

            print(f"✓ SSH隧道已建立: {proxy_url}")
            time.sleep(1)
            return True

        except Exception as e:
            print(f"✗ SSH隧道异常: {e}")
            return False

    def stop(self):
        """Terminate ssh and restore the previous proxy environment variables."""
        if self._process:
            self._process.terminate()
            try:
                self._process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # ssh ignored SIGTERM; force it so stop() cannot hang forever.
                self._process.kill()
                self._process.wait()
            self._process = None

            for key in self._PROXY_ENV_KEYS:
                previous = self._saved_env.get(key)
                if previous is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = previous
            self._saved_env = {}
            print("SSH隧道已关闭")

    def __enter__(self):
        """Start the tunnel and return this manager."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the tunnel on leaving the ``with`` block."""
        self.stop()
|
||||
|
||||
|
||||
def create_ssh_tunnel_from_yaml(config_path: str) -> SSHTunnelManager:
    """Build an SSHTunnelManager from the ``ssh_tunnel`` section of a YAML file.

    Args:
        config_path: Path of the YAML configuration file.

    Returns:
        A manager configured from the section; an empty file or a missing /
        null ``ssh_tunnel`` section yields a manager built from ``{}``.
    """
    import yaml

    with open(config_path, 'r', encoding='utf-8') as f:
        # An empty YAML document loads as None.
        config = yaml.safe_load(f) or {}

    ssh_config = config.get('ssh_tunnel') or {}
    return SSHTunnelManager(ssh_config)
|
||||
245
datasource/tushare_source.py
Normal file
245
datasource/tushare_source.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""
|
||||
Tushare数据源
|
||||
|
||||
获取A股指数、ETF、期货数据
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class TushareSource:
    """Tushare-backed source for A-share indices, ETFs and futures."""

    _PROXY_ENV_KEYS = (
        "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
        "http_proxy", "https_proxy", "all_proxy",
    )
    _CHINA_SUFFIXES = (".SH", ".SZ", ".SS", ".CSI")
    _FUTURES_MARKERS = (".SHF", ".NYM", ".DCE", ".CZC")
    _OHLCV_COLUMNS = ['code', 'open', 'high', 'low', 'close', 'volume']
    _PRICE_RENAME = {"trade_date": "date", "vol": "volume"}
    _NAV_RENAME = {"nav_date": "date", "unit_nav": "nav"}

    def __init__(self, token: Optional[str] = None):
        """Store the API token.

        Args:
            token: Tushare token; falls back to the TUSHARE_TOKEN environment
                variable.

        Raises:
            ValueError: If neither the argument nor the variable provides one.
        """
        self._token = token or os.getenv("TUSHARE_TOKEN")
        if not self._token:
            raise ValueError("请设置环境变量 TUSHARE_TOKEN")

    def _get_pro_api(self):
        """Return a Tushare Pro API client for the stored token."""
        import tushare as ts
        return ts.pro_api(self._token)

    def _clear_proxy(self) -> dict:
        """Remove the proxy environment variables and return their old values.

        Tushare requests are made with the proxy variables unset; the returned
        dict is what ``_restore_proxy`` expects.
        """
        return {key: os.environ.pop(key, None) for key in self._PROXY_ENV_KEYS}

    def _restore_proxy(self, original: dict):
        """Put back the proxy variables saved by ``_clear_proxy``."""
        for key, value in original.items():
            if value is not None:
                os.environ[key] = value

    def _query(
        self,
        api_name: str,
        label: str,
        code: str,
        ts_code: str,
        start_date: str,
        end_date: str,
        rename: dict,
        columns: list,
        **extra
    ) -> Optional[pd.DataFrame]:
        """Call one Tushare endpoint and return a date-indexed, sorted frame.

        Args:
            api_name: Name of the method to call on the Pro client.
            label: Instrument kind used in the failure message.
            code: Code written into the ``code`` column of the result.
            ts_code: Code sent to Tushare.
            start_date: 'YYYY-MM-DD'; dashes are stripped for the request.
            end_date: 'YYYY-MM-DD'; dashes are stripped for the request.
            rename: Column renames applied to the response.
            columns: Columns kept in the result.
            **extra: Additional keyword arguments for the endpoint.

        Returns:
            The frame, or None when the response is empty or any error occurs
            (the error is printed, not raised).
        """
        original_proxy = self._clear_proxy()
        try:
            pro = self._get_pro_api()
            df = getattr(pro, api_name)(
                ts_code=ts_code,
                start_date=start_date.replace("-", ""),
                end_date=end_date.replace("-", ""),
                **extra
            )

            if df is None or len(df) == 0:
                return None

            df = df.rename(columns=rename)
            df["date"] = pd.to_datetime(df["date"])
            df = df.set_index("date").sort_index()
            df["code"] = code
            return df[columns]

        except Exception as e:
            print(f"Tushare下载{label} {code} 失败: {e}")
            return None

        finally:
            # Always restore the proxy, also on failure.
            self._restore_proxy(original_proxy)

    def fetch_index(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch daily index bars via ``index_daily``.

        Args:
            code: Index code, e.g. '000300.SH', '399006.SZ', 'H30269.CSI';
                a '.SS' suffix is sent to Tushare as '.SH'.
            start_date: Start date 'YYYY-MM-DD'.
            end_date: End date 'YYYY-MM-DD'.

        Returns:
            Frame indexed by date with columns code, open, high, low, close,
            volume; None on empty response or error.
        """
        return self._query(
            "index_daily", "指数", code, code.replace(".SS", ".SH"),
            start_date, end_date, self._PRICE_RENAME, self._OHLCV_COLUMNS
        )

    def fetch_futures(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch daily futures bars via ``fut_daily``.

        Args:
            code: Futures code, e.g. 'AU.SHF', 'CU.SHF'.
            start_date: Start date 'YYYY-MM-DD'.
            end_date: End date 'YYYY-MM-DD'.

        Returns:
            Frame indexed by date with columns code, open, high, low, close,
            volume; None on empty response or error.
        """
        # Tushare's futures daily endpoint is named "fut_daily".
        return self._query(
            "fut_daily", "期货", code, code,
            start_date, end_date, self._PRICE_RENAME, self._OHLCV_COLUMNS,
            exchange=''
        )

    def fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch daily ETF price bars via ``fund_daily``.

        Args:
            code: ETF code, e.g. '159915.SZ', '518880.SH'; a '.SS' suffix is
                sent to Tushare as '.SH'.
            start_date: Start date 'YYYY-MM-DD'.
            end_date: End date 'YYYY-MM-DD'.

        Returns:
            Frame indexed by date with columns code, open, high, low, close,
            volume; None on empty response or error.
        """
        return self._query(
            "fund_daily", "ETF", code, code.replace(".SS", ".SH"),
            start_date, end_date, self._PRICE_RENAME, self._OHLCV_COLUMNS
        )

    def fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch ETF net asset values via ``fund_nav``.

        Args:
            code: ETF code; a '.SS' suffix is sent to Tushare as '.SH'.
            start_date: Start date 'YYYY-MM-DD'.
            end_date: End date 'YYYY-MM-DD'.

        Returns:
            Frame indexed by date with columns code, nav; None on empty
            response or error.
        """
        return self._query(
            "fund_nav", "ETF净值", code, code.replace(".SS", ".SH"),
            start_date, end_date, self._NAV_RENAME, ['code', 'nav']
        )

    def is_china_index(self, code: str) -> bool:
        """Return True if ``code`` ends with .SH, .SZ, .SS or .CSI."""
        return code.endswith(self._CHINA_SUFFIXES)

    def is_futures(self, code: str) -> bool:
        """Return True if ``code`` contains .SHF, .NYM, .DCE or .CZC."""
        return any(marker in code for marker in self._FUTURES_MARKERS)

    def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch ``code`` as an index or a futures contract, by its suffix.

        Args:
            code: Instrument code.
            start_date: Start date 'YYYY-MM-DD'.
            end_date: End date 'YYYY-MM-DD'.

        Returns:
            The index or futures frame, or None when the code is neither.
        """
        if self.is_china_index(code):
            return self.fetch_index(code, start_date, end_date)
        if self.is_futures(code):
            return self.fetch_futures(code, start_date, end_date)
        return None
|
||||
112
datasource/yfinance_source.py
Normal file
112
datasource/yfinance_source.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
YFinance数据源
|
||||
|
||||
获取港股、美股数据(通过SSH隧道)
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
import pandas as pd
|
||||
import urllib3
|
||||
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
|
||||
class YFinanceSource:
    """YFinance-backed source for Hong Kong, US and other overseas symbols."""

    # Project code -> YFinance ticker.
    CODE_MAP = {
        # Hong Kong
        "HSTECH.HK": "3033.HK",   # Hang Seng TECH index
        "HSI": "^HSI",            # Hang Seng index
        # US indices
        "NDX": "^NDX",            # Nasdaq 100
        "SPX": "^GSPC",           # S&P 500
        "DJI": "^DJI",            # Dow Jones
        # Japan / Europe
        "N225": "^N225",          # Nikkei 225
        "GDAXI": "^GDAXI",        # German DAX
        # Commodities
        # NOTE(review): is_yfinance_code() below returns False for '.NYM'
        # codes, so this entry looks unreachable through that check — confirm
        # which source is meant to serve WTI crude.
        "CL.NYM": "CL=F",         # WTI crude oil futures
    }

    _CHINA_SUFFIXES = ('.SH', '.SZ', '.SS', '.CSI')
    _FUTURES_SUFFIXES = ('.SHF', '.NYM', '.DCE', '.CZC')

    def __init__(self, use_ssh_tunnel: bool = False):
        """Store the settings.

        Args:
            use_ssh_tunnel: Stored on the instance only; this class does not
                start a tunnel itself (start SSHTunnelManager beforehand).
        """
        self.use_ssh_tunnel = use_ssh_tunnel
        self._delay = 0.5  # seconds slept before each request (rate limiting)

    @staticmethod
    def _to_trading_dates(index) -> pd.DatetimeIndex:
        """Return ``index`` as tz-naive dates in the index's own timezone.

        The timezone is dropped while keeping the local wall-clock time.
        Converting to UTC first would move a local-midnight timestamp of any
        timezone ahead of UTC onto the previous calendar day.
        """
        index = pd.to_datetime(index)
        if index.tz is not None:
            index = index.tz_localize(None)
        return index.normalize()

    def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Download daily unadjusted bars for one code.

        Args:
            code: Project code (e.g. 'NDX', 'N225', 'HSI'); translated through
                CODE_MAP, unknown codes are passed to YFinance unchanged.
            start_date: Start date 'YYYY-MM-DD'.
            end_date: End date 'YYYY-MM-DD' (inclusive).

        Returns:
            Frame indexed by ``date`` with columns code, open, high, low,
            close, volume; None when nothing is returned or on any error
            (the error is printed, not raised).
        """
        import yfinance as yf

        # Throttle requests to avoid rate limiting.
        time.sleep(self._delay)

        yf_code = self.CODE_MAP.get(code, code)

        try:
            ticker = yf.Ticker(yf_code)

            # yfinance treats `end` as exclusive, so ask for one more day.
            end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)

            # auto_adjust=False returns unadjusted prices.
            df = ticker.history(
                start=start_date,
                end=end_dt.strftime("%Y-%m-%d"),
                auto_adjust=False
            )

            if df is None or len(df) == 0:
                return None

            df = df.rename(columns={
                "Open": "open",
                "High": "high",
                "Low": "low",
                "Close": "close",
                "Volume": "volume",
            })

            df.index = self._to_trading_dates(df.index)
            df.index.name = "date"
            df["code"] = code

            return df[['code', 'open', 'high', 'low', 'close', 'volume']]

        except Exception as e:
            print(f"YFinance下载 {code} ({yf_code}) 失败: {e}")
            return None

    def is_yfinance_code(self, code: str) -> bool:
        """Return True unless ``code`` ends with an A-share or futures suffix."""
        return not code.endswith(self._CHINA_SUFFIXES + self._FUTURES_SUFFIXES)
|
||||
Reference in New Issue
Block a user