Files
etf/core/datasource/hybrid_source.py
aszerW b7478bf2ef fix(datasource): 修正非A股指数前向填充逻辑
- 将前向填充范围扩大至所有非A股指数
- 说明所有市场(港股、美股、黄金、加密货币)在T+1日09:00前已收盘
- 保障数据针对多市场的时效性和完整性
2026-03-25 22:25:48 +08:00

673 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Hybrid data source module.

- China A-share indices and A-share ETFs (prices and NAVs): Tushare
- HK / US equities and other non-crypto symbols (e.g. gold futures): YFinance,
  optionally routed through an SSH SOCKS5 tunnel
- Crypto: CCXT/OKX, optionally routed through an SSH -> HTTP proxy chain
"""
import os
import sys
import time
import subprocess
from pathlib import Path
from typing import Optional, Tuple, Dict
from datetime import datetime
import pandas as pd
import yfinance as yf
import urllib3
# Silence urllib3's InsecureRequestWarning for the whole process (runs at import time).
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)
class SSHTunnelManager:
    """Run an ``ssh -N -D`` SOCKS5 tunnel and route this process's HTTP(S) traffic through it.

    Config keys read (defaults in brackets): ``enabled`` [False], ``host``,
    ``port`` [22], ``username``, ``local_port`` [1080], ``key_path``. A
    relative ``key_path`` is resolved against the project root
    (``Path(__file__).parent.parent.parent``).

    While the tunnel is up, HTTP_PROXY / HTTPS_PROXY / ALL_PROXY are set to
    ``socks5://127.0.0.1:<local_port>``; :meth:`stop` restores whatever values
    those variables had before :meth:`start` overwrote them.
    """

    # Environment variables pointed at the tunnel while it is running.
    _PROXY_ENV_KEYS = ("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY")
    # Seconds to wait for ssh to exit after terminate() before killing it.
    _STOP_TIMEOUT = 5

    def __init__(self, config: dict):
        self.enabled = config.get("enabled", False)
        self.host = config.get("host", "")
        self.port = config.get("port", 22)
        self.username = config.get("username", "")
        self.local_port = config.get("local_port", 1080)
        self._process: Optional[subprocess.Popen] = None
        # Temp file receiving ssh's stderr while the tunnel runs (see start()).
        self._stderr_file = None
        # Proxy env values captured before start() overwrote them;
        # None means there is nothing to restore.
        self._saved_env: Optional[dict] = None
        key_path = config.get("key_path", "")
        if key_path and not os.path.isabs(key_path):
            # Relative paths are taken relative to the project root.
            project_root = Path(__file__).parent.parent.parent
            key_path = str(project_root / key_path)
        self.key_path = key_path
        print(f"SSH 私钥路径: {self.key_path}")

    def start(self) -> bool:
        """Start the tunnel and point the proxy environment variables at it.

        Returns:
            True if the tunnel is disabled, already running, or came up;
            False if the config is incomplete, ssh exited during the 2 s
            startup wait, or launching it raised.
        """
        if not self.enabled:
            return True
        if self._process is not None and self._process.poll() is None:
            # Already running: starting again would orphan the current ssh.
            return True
        if not all([self.host, self.username, self.key_path]):
            print("SSH 配置不完整,跳过隧道建立")
            return False
        print(f"建立 SSH 隧道: {self.host}:{self.port} -> 本地 SOCKS5 端口 {self.local_port}")
        # NOTE(review): host-key verification is disabled below; the tunnel
        # therefore trusts whatever host answers at self.host.
        cmd = [
            "ssh", "-N", "-D", f"127.0.0.1:{self.local_port}",
            "-o", "StrictHostKeyChecking=no",
            "-o", "UserKnownHostsFile=/dev/null",
            "-i", self.key_path,
            "-p", str(self.port),
            f"{self.username}@{self.host}",
        ]
        import tempfile
        try:
            # ssh keeps running (and logging) for the tunnel's whole lifetime.
            # An unread stderr PIPE would eventually fill up and block it, so
            # spool stderr to a temp file (read back only on failure) and
            # discard stdout.
            self._stderr_file = tempfile.TemporaryFile()
            self._process = subprocess.Popen(
                cmd, stdout=subprocess.DEVNULL, stderr=self._stderr_file
            )
            time.sleep(2)
            if self._process.poll() is not None:
                print("✗ SSH 隧道建立失败")
                self._stderr_file.seek(0)
                stderr = self._stderr_file.read()
                if stderr:
                    print(f"错误: {stderr.decode(errors='replace')}")
                self._discard_process()
                return False
            proxy_url = f"socks5://127.0.0.1:{self.local_port}"
            self._apply_proxy_env(proxy_url)
            print(f"✓ SSH 隧道已建立: {proxy_url}")
            time.sleep(1)
            return True
        except Exception as e:
            print(f"✗ SSH 隧道异常: {e}")
            self._discard_process()
            self._restore_proxy_env()
            return False

    def stop(self):
        """Stop the tunnel started by this manager (if any) and restore the proxy env vars."""
        if self._process is None:
            # Never started (or start() failed): leave the environment alone.
            return
        self._discard_process()
        self._restore_proxy_env()
        print("SSH 隧道已关闭")

    def _discard_process(self) -> None:
        """Reap the ssh child (terminating it if still alive) and close its stderr spool."""
        proc, self._process = self._process, None
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=self._STOP_TIMEOUT)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()
        if self._stderr_file is not None:
            self._stderr_file.close()
            self._stderr_file = None

    def _apply_proxy_env(self, proxy_url: str) -> None:
        """Remember the current proxy env values (once) and set them all to ``proxy_url``."""
        if self._saved_env is None:
            self._saved_env = {key: os.environ.get(key) for key in self._PROXY_ENV_KEYS}
        for key in self._PROXY_ENV_KEYS:
            os.environ[key] = proxy_url

    def _restore_proxy_env(self) -> None:
        """Put back the proxy env values captured by _apply_proxy_env (no-op if none)."""
        if self._saved_env is None:
            return
        for key, value in self._saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
        self._saved_env = None

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()
class HybridDataSource:
    """
    Hybrid market-data source.

    - China A-share indices (.SH/.SZ/.SS) and A-share ETFs: Tushare
    - HK / US / gold and any other non-crypto code: YFinance
    - Crypto (codes in CCXT_CODE_MAP): CCXT/OKX, optionally through a local
      HTTP proxy (socks2http) chained onto the SSH tunnel

    Use as a context manager: when ``ssh_config["enabled"]`` is truthy, an
    SSH tunnel is opened on entry and closed on exit.
    """

    # YFinance symbol map (code -> YFinance ticker)
    YF_CODE_MAP = {
        # Hong Kong
        "HSTECH": "3033.HK",  # Hang Seng TECH index ETF
        "HSI": "^HSI",  # Hang Seng Index
        # US indices
        "NDX": "^NDX",  # Nasdaq-100
        "SPX": "^GSPC",  # S&P 500
        "DJI": "^DJI",  # Dow Jones
        # Gold
        "GC=F": "GC=F",  # gold futures
    }
    # CCXT symbol map (code -> CCXT market symbol)
    CCXT_CODE_MAP = {
        "BTC": "BTC/USDT",  # OKX bitcoin spot
        "ETH": "ETH/USDT",  # OKX ether spot
    }

    # Proxy variables cleared while calling Tushare (a domestic service that
    # must not go through the tunnel).
    _PROXY_ENV_KEYS = (
        "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
        "http_proxy", "https_proxy", "all_proxy",
    )
    # Address the socks2http helper is expected to listen on.
    _SOCKS2HTTP_URL = "http://127.0.0.1:8080"
    # Seconds to wait for the socks2http helper to exit before killing it.
    _HELPER_STOP_TIMEOUT = 5

    def __init__(self, ssh_config: Optional[dict] = None, use_cache: bool = True):
        self.ssh_config = ssh_config or {}
        self.use_cache = use_cache  # stored only; nothing in this class reads it
        self._tunnel: Optional[SSHTunnelManager] = None
        self._tushare_token: Optional[str] = None

    # ------------------------------------------------------------------ #
    # Code classification
    # ------------------------------------------------------------------ #
    def _is_china_index(self, code: str) -> bool:
        """Return True for A-share codes (suffix .SH, .SZ or .SS)."""
        return code.endswith((".SH", ".SZ", ".SS"))

    def _is_crypto(self, code: str) -> bool:
        """Return True for codes routed to CCXT (keys of CCXT_CODE_MAP)."""
        return code in self.CCXT_CODE_MAP

    # ------------------------------------------------------------------ #
    # Tushare
    # ------------------------------------------------------------------ #
    def _get_tushare_token(self) -> str:
        """Return the Tushare token from $TUSHARE_TOKEN (cached after the first read).

        Raises:
            ValueError: if TUSHARE_TOKEN is unset or empty.
        """
        if self._tushare_token is None:
            self._tushare_token = os.getenv("TUSHARE_TOKEN")
        if not self._tushare_token:
            raise ValueError("请设置环境变量 TUSHARE_TOKEN")
        return self._tushare_token

    def _fetch_tushare_table(
        self,
        api_name: str,
        code: str,
        start_date: str,
        end_date: str,
        rename: Dict[str, str],
        error_label: str,
    ) -> Optional[pd.DataFrame]:
        """Query one Tushare pro endpoint with proxies disabled; index the result by date.

        ``.SS`` suffixes are rewritten to ``.SH`` and dashes are stripped from
        the dates before the call. ``rename`` must map the endpoint's date
        column to ``"date"``. Any error (including a missing token) is printed
        as ``f"{error_label} {code} 失败: {e}"`` and yields None; an empty
        result also yields None.
        """
        # Tushare is a domestic service: bypass any tunnel proxy for this call.
        saved = {key: os.environ.pop(key, None) for key in self._PROXY_ENV_KEYS}
        try:
            import tushare as ts
            pro = ts.pro_api(self._get_tushare_token())
            endpoint = getattr(pro, api_name)
            df = endpoint(
                ts_code=code.replace(".SS", ".SH"),
                start_date=start_date.replace("-", ""),
                end_date=end_date.replace("-", ""),
            )
            if df is None or len(df) == 0:
                return None
            df = df.rename(columns=rename)
            df["date"] = pd.to_datetime(df["date"])
            df = df.set_index("date").sort_index()
            df["code"] = code
            return df
        except Exception as e:
            print(f"{error_label} {code} 失败: {e}")
            return None
        finally:
            for key, value in saved.items():
                if value is not None:
                    os.environ[key] = value

    def _fetch_tushare(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """A-share index daily bars via Tushare ``index_daily`` (direct connection, no proxy)."""
        return self._fetch_tushare_table(
            "index_daily", code, start_date, end_date,
            {"trade_date": "date", "vol": "volume"}, "Tushare 下载",
        )

    def _fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """A-share ETF daily bars via Tushare ``fund_daily`` (direct connection, no proxy)."""
        return self._fetch_tushare_table(
            "fund_daily", code, start_date, end_date,
            {"trade_date": "date", "vol": "volume"}, "Tushare 下载ETF",
        )

    def _fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """ETF unit NAV via Tushare ``fund_nav``; ``unit_nav`` is renamed to ``nav``."""
        return self._fetch_tushare_table(
            "fund_nav", code, start_date, end_date,
            {"nav_date": "date", "unit_nav": "nav"}, "Tushare 下载ETF净值",
        )

    # ------------------------------------------------------------------ #
    # YFinance / CCXT
    # ------------------------------------------------------------------ #
    def _fetch_yfinance(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Daily bars via YFinance; ``end_date`` is inclusive, like the other sources."""
        yf_code = self.YF_CODE_MAP.get(code, code)
        # Small delay to avoid YFinance rate limiting.
        time.sleep(0.5)
        try:
            # YFinance treats ``end`` as exclusive; ask for one extra day so
            # that end_date itself is included (Tushare/CCXT include it).
            end_exclusive = (pd.Timestamp(end_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
            data = yf.Ticker(yf_code).history(start=start_date, end=end_exclusive)
            if len(data) == 0:
                return None
            data = data.rename(columns={
                "Open": "open",
                "High": "high",
                "Low": "low",
                "Close": "close",
                "Volume": "volume",
            })
            data["code"] = code
            return data
        except Exception as e:
            print(f"YFinance 下载 {code} ({yf_code}) 失败: {e}")
            return None

    def _fetch_ccxt(self, code: str, start_date: str, end_date: str, http_proxy: Optional[str] = None) -> Optional[pd.DataFrame]:
        """Daily OHLCV from OKX via CCXT, optionally through an HTTP proxy.

        Candles are paged forward from start_date; the result is indexed by
        the naive UTC candle date and filtered to [start_date, end_date].
        """
        import ccxt
        ccxt_code = self.CCXT_CODE_MAP.get(code, code)
        config = {'enableRateLimit': True}
        if http_proxy:
            config['proxies'] = {'http': http_proxy, 'https': http_proxy}
        try:
            exchange = ccxt.okx(config)
            day_ms = 86_400_000
            since = int(pd.Timestamp(start_date).timestamp() * 1000)
            until = int(pd.Timestamp(end_date).timestamp() * 1000)
            all_ohlcv = []
            limit = 100
            while since < until:
                ohlcv = exchange.fetch_ohlcv(ccxt_code, '1d', since, limit)
                if not ohlcv:
                    break
                all_ohlcv.extend(ohlcv)
                next_since = ohlcv[-1][0] + day_ms
                # A reply that does not move past ``since`` would make this
                # loop request the same page forever.
                if next_since <= since:
                    break
                since = next_since
            if not all_ohlcv:
                return None
            df = pd.DataFrame(all_ohlcv, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
            df.index = pd.DatetimeIndex(pd.to_datetime(df['timestamp'], unit='ms', utc=True)).tz_localize(None).normalize()
            df.index.name = 'date'
            df = df[['open', 'high', 'low', 'close', 'volume']]
            start_ts = pd.Timestamp(start_date)
            end_ts = pd.Timestamp(end_date)
            # copy() so adding 'code' does not write into a slice of df.
            df = df[(df.index >= start_ts) & (df.index <= end_ts)].copy()
            df['code'] = code
            return df
        except Exception as e:
            print(f"CCXT 下载 {code} ({ccxt_code}) 失败: {e}")
            return None

    def fetch_single(self, code: str, start_date: str, end_date: str, http_proxy: str = None) -> Optional[pd.DataFrame]:
        """Fetch one code from the source its code format selects (Tushare, CCXT or YFinance)."""
        if self._is_china_index(code):
            return self._fetch_tushare(code, start_date, end_date)
        if self._is_crypto(code):
            return self._fetch_ccxt(code, start_date, end_date, http_proxy)
        return self._fetch_yfinance(code, start_date, end_date)

    # ------------------------------------------------------------------ #
    # Frame helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _normalize_index(index) -> pd.DatetimeIndex:
        """Turn a (possibly tz-aware) datetime index into naive midnight dates named ``date``.

        The calendar date in the data's own timezone is kept. Converting to
        UTC first would move a bar stamped 00:00+08:00 (e.g. a Hong Kong
        daily bar) onto the previous day, putting it on the wrong A-share row.
        """
        idx = pd.DatetimeIndex(pd.to_datetime(index))
        if idx.tz is not None:
            idx = idx.tz_localize(None)
        return idx.normalize().rename("date")

    @staticmethod
    def _pivot_frames(frames: list, value: str) -> pd.DataFrame:
        """Stack per-code long frames (date index, 'code' column) and pivot ``value`` to date x code."""
        long_df = pd.concat(frames, ignore_index=False).reset_index()
        if "index" in long_df.columns:
            long_df = long_df.rename(columns={"index": "date"})
        long_df["date"] = pd.to_datetime(long_df["date"]).dt.normalize()
        return long_df.pivot_table(index="date", columns="code", values=value, aggfunc="first")

    @staticmethod
    def _align_to_calendar(frame: pd.DataFrame, calendar, fill_cols=()) -> pd.DataFrame:
        """Reindex ``frame`` onto ``calendar``, forward-filling ``fill_cols`` across off-calendar dates.

        The forward fill runs on the union of the frame's dates and the
        calendar *before* reindexing, so a value observed on a date missing
        from the calendar (e.g. a US close during an A-share holiday) is
        carried to the next calendar date instead of being dropped.
        """
        cols = [c for c in fill_cols if c in frame.columns]
        if cols:
            frame = frame.reindex(frame.index.union(calendar))
            frame[cols] = frame[cols].ffill()
        return frame.reindex(calendar)

    # ------------------------------------------------------------------ #
    # socks2http helper process
    # ------------------------------------------------------------------ #
    def _start_socks2http(self):
        """Spawn the socks2http.py helper next to this file.

        Returns:
            (process, proxy_url); proxy_url is None when the helper failed to start.
        """
        print(f"\n 启动 socks2http 代理服务(用于加密货币)...")
        proc = None
        try:
            socks2http_path = Path(__file__).parent / "socks2http.py"
            # The helper's output is never read: discard it instead of piping,
            # since a full pipe buffer would block the helper.
            proc = subprocess.Popen(
                [sys.executable, str(socks2http_path)],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            time.sleep(2)  # give the proxy time to start listening
            if proc.poll() is not None:
                print(f" ✗ 启动代理失败: socks2http 已退出 (returncode={proc.returncode})")
                return None, None
            http_proxy = self._SOCKS2HTTP_URL
            print(f" ✓ HTTP 代理已启动: {http_proxy}")
            return proc, http_proxy
        except Exception as e:
            print(f" ✗ 启动代理失败: {e}")
            return proc, None

    def _stop_socks2http(self, proc) -> None:
        """Terminate the socks2http helper (kill it if it ignores terminate)."""
        if proc is None:
            return
        proc.terminate()
        try:
            proc.wait(timeout=self._HELPER_STOP_TIMEOUT)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        print(f"\n socks2http 代理已关闭")

    # ------------------------------------------------------------------ #
    # Batch download
    # ------------------------------------------------------------------ #
    def _download_index_data(self, index_codes, code_config, start_date, end_date, http_proxy):
        """Download every index series.

        Returns:
            (frames, valid_codes): per-code frames with columns code/close/source
            on a normalized date index, and the codes that returned data.
        """
        frames = []
        valid_codes = []
        print("\n [1/2] 下载指数数据(用于因子计算)...")
        for code in index_codes:
            if self._is_china_index(code):
                source = "Tushare"
            elif self._is_crypto(code):
                source = "CCXT/OKX"
            else:
                source = "YFinance"
            name = code_config[code].get('name', code)
            print(f" 下载 {code} ({name}) - {source}...", end=" ")
            # Only crypto goes through the HTTP proxy.
            proxy = http_proxy if self._is_crypto(code) else None
            data = self.fetch_single(code, start_date, end_date, proxy)
            if data is None or len(data) == 0:
                print("✗ 无数据")
                continue
            data = data.copy()
            data['source'] = source
            data['code'] = code
            data.index = self._normalize_index(data.index)
            frames.append(data[['code', 'close', 'source']])
            valid_codes.append(code)
            print(f"{len(data)}")
        return frames, valid_codes

    def _download_etf_data(self, etf_codes, code_config, start_date, end_date):
        """Download ETF prices and NAVs, relabelling each with its index code.

        Returns:
            (price_frames, nav_frames): frames with columns code/close/source
            and code/nav respectively, on a normalized date index.
        """
        price_frames = []
        nav_frames = []
        if not etf_codes:
            return price_frames, nav_frames
        print("\n [2/2] 下载ETF数据价格+净值,用于溢价率计算)...")
        for idx_code, etf_code in etf_codes.items():
            # Crypto entries have no ETF to download.
            if code_config[idx_code].get('market', 'A') == 'CRYPTO':
                continue
            print(f" 下载 ETF {etf_code} (对应指数 {idx_code})...", end=" ")
            price_data = self._fetch_etf(etf_code, start_date, end_date)
            nav_data = self._fetch_etf_nav(etf_code, start_date, end_date)
            if price_data is None or len(price_data) == 0:
                print(f"✗ 无数据")
                continue
            price_data = price_data.copy()
            price_data['source'] = 'Tushare-ETF'
            # Key ETF columns by the index code so they line up with index_data.
            price_data['code'] = idx_code
            price_data.index = self._normalize_index(price_data.index)
            price_frames.append(price_data[['code', 'close', 'source']])
            if nav_data is not None and len(nav_data) > 0:
                nav_data = nav_data.copy()
                nav_data['code'] = idx_code
                nav_data.index = self._normalize_index(nav_data.index)
                nav_frames.append(nav_data[['code', 'nav']])
                print(f"✓ 价格{len(price_data)}条 净值{len(nav_data)}")
            else:
                print(f"✓ 价格{len(price_data)}条 (无净值数据)")
        return price_frames, nav_frames

    def fetch_all(
        self,
        code_config: dict,  # {code: {name, etf, market}}
        benchmark_code: str,
        start_date: str,
        end_date: str,
    ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], Optional[pd.DataFrame], list]:
        """
        Download index, ETF, ETF-NAV and benchmark data in one pass.

        Args:
            code_config: {index_code: {"name", "etf", "market"}}. "etf" (optional)
                names the ETF tracking the index; market == "CRYPTO" skips the
                ETF download for that entry.
            benchmark_code: benchmark code, fetched via fetch_single without a proxy.
            start_date / end_date: date strings, presumably "YYYY-MM-DD"
                (dashes are stripped for Tushare).

        Returns:
            (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes)
            - index_data: date x code table of index closes (factor input)
            - etf_data: date x index-code table of ETF closes (return calculation);
              a copy of index_data when no ETF prices were fetched
            - etf_nav_data: date x index-code table of ETF unit NAVs
              (premium calculation), or None
            - benchmark_data: benchmark frame, or None
            - valid_codes: index codes that returned data
            (None, None, None, None, []) when no index data was fetched.

        When at least one A-share index has data, every table is aligned to
        the trading dates of the first such index, and non-A-share series are
        forward-filled onto that calendar.
        """
        index_codes = list(code_config.keys())
        etf_codes = {code: cfg['etf'] for code, cfg in code_config.items() if cfg.get('etf')}

        print(f"开始下载 {len(index_codes)} 只标的的数据...")
        print(f" 指数代码: {len(index_codes)}")
        print(f" ETF映射: {len(etf_codes)}")
        china_codes = [c for c in index_codes if self._is_china_index(c)]
        global_codes = [c for c in index_codes if not self._is_china_index(c)]
        print(f" 中国A股指数: {len(china_codes)}")
        print(f" 港股/美股/加密货币: {len(global_codes)}")

        # socks2http chains onto the SSH tunnel, so only start it when a
        # tunnel exists and there is crypto to download.
        crypto_codes = [c for c in index_codes if self._is_crypto(c)]
        socks2http_proc, http_proxy = None, None
        if crypto_codes and self._tunnel is not None:
            socks2http_proc, http_proxy = self._start_socks2http()
        try:
            index_frames, valid_codes = self._download_index_data(
                index_codes, code_config, start_date, end_date, http_proxy
            )
            etf_frames, nav_frames = self._download_etf_data(
                etf_codes, code_config, start_date, end_date
            )
        finally:
            # Always reap the helper, even if a download raised.
            self._stop_socks2http(socks2http_proc)

        if not index_frames:
            return None, None, None, None, []

        print(f"\n整理指数数据(用于因子计算)...")
        index_data = self._pivot_frames(index_frames, "close")

        # Use the A-share trading calendar as the master index.
        tushare_codes = [c for c in valid_codes if self._is_china_index(c)]
        primary_dates = None
        if tushare_codes:
            primary_dates = index_data[tushare_codes[0]].dropna().index
            print(f" 主市场交易日: {len(primary_dates)}")
            # Every non-A-share market (HK, US, gold, crypto) has closed
            # before 09:00 on A-share day T+1, so its latest close is carried
            # forward onto the A-share calendar.
            non_a_codes = [c for c in valid_codes if not self._is_china_index(c)]
            index_data = self._align_to_calendar(index_data, primary_dates, non_a_codes)
            fill_cols = [c for c in non_a_codes if c in index_data.columns]
            if fill_cols:
                # Leading gaps (series starting after the calendar) are back-filled.
                # NOTE(review): bfill copies later values into earlier rows,
                # i.e. look-ahead in a backtest; kept for compatibility.
                index_data[fill_cols] = index_data[fill_cols].bfill()
            print(f" 非A股标的: {len(non_a_codes)} 只 (已前向填充)")
        print(f" 时间范围: {index_data.index[0]} ~ {index_data.index[-1]}")
        print(f" 交易日数: {len(index_data)}")

        if etf_frames:
            print(f"\n整理ETF数据用于收益计算...")
            etf_data = self._pivot_frames(etf_frames, "close")
            if primary_dates is not None:
                etf_data = etf_data.reindex(primary_dates)
            print(f" ETF价格数据: {len(etf_data.columns)}")
        else:
            # No ETF prices: fall back to the index closes.
            etf_data = index_data.copy()
            print(f"\n无ETF映射使用指数数据代替")

        etf_nav_data = None
        if nav_frames:
            print(f"\n整理ETF净值数据用于溢价率计算...")
            etf_nav_data = self._pivot_frames(nav_frames, "nav")
            if primary_dates is not None:
                # NAVs are usually published T+1: carry the last one forward.
                etf_nav_data = self._align_to_calendar(etf_nav_data, primary_dates, etf_nav_data.columns)
            print(f" ETF净值数据: {len(etf_nav_data.columns)}")

        benchmark_data = self.fetch_single(benchmark_code, start_date, end_date)
        if benchmark_data is not None:
            benchmark_data.index = self._normalize_index(benchmark_data.index)
            if primary_dates is not None:
                benchmark_data = benchmark_data.reindex(primary_dates)
            print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)}")

        return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes

    def __enter__(self):
        """Open the SSH tunnel when ssh_config enables it; keep it only if it came up."""
        if self.ssh_config.get("enabled"):
            tunnel = SSHTunnelManager(self.ssh_config)
            # fetch_all uses self._tunnel to decide whether to start
            # socks2http, so a tunnel that failed to start is not kept.
            self._tunnel = tunnel if tunnel.start() else None
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the SSH tunnel, if one was opened."""
        if self._tunnel:
            self._tunnel.stop()
            self._tunnel = None