Files
etf/datasource/hybrid_source.py
aszerW e56bd39400 feat: 创建数据源模块 datasource/
核心功能:
- ssh_tunnel.py: SSH隧道管理器(连接香港ECS)
- tushare_source.py: A股数据获取(指数、ETF、期货)
- yfinance_source.py: 境外数据获取(港股、美股)
- hybrid_source.py: 混合数据源(整合所有)

使用方式:
  from datasource import HybridDataSource

  source = HybridDataSource.from_yaml('config/strategies/rotation.yaml')
  result = source.fetch_all(code_config=...)

更新 RotationStrategy 使用新数据源模块
2026-05-12 00:03:25 +08:00

285 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
混合数据源
整合 Tushare(A股)+ YFinance(境外)数据获取
"""
import os
import time
from typing import Optional, Tuple, Dict, List
from datetime import datetime
from pathlib import Path
import pandas as pd
from .ssh_tunnel import SSHTunnelManager
from .tushare_source import TushareSource
from .yfinance_source import YFinanceSource
class HybridDataSource:
    """
    Hybrid data source that routes each symbol to the appropriate backend.

    - Mainland China indices / ETFs / futures: Tushare
    - HK / US equities / commodities: YFinance (through an SSH tunnel when
      ``ssh_config['enabled']`` is truthy)

    Usage:
        from datasource import HybridDataSource

        with HybridDataSource.from_yaml('config/strategies/rotation.yaml') as source:
            result = source.fetch_all(code_config={...})
    """

    def __init__(
        self,
        ssh_config: Optional[dict] = None,
        use_cache: bool = True,
        cache_dir: str = "data/etf_cache/daily"
    ):
        """
        Initialize the hybrid data source.

        Args:
            ssh_config: SSH tunnel configuration, passed to SSHTunnelManager.
                The tunnel is only created when ``ssh_config.get('enabled')``
                is truthy.
            use_cache: Whether caching is enabled (stored; not read by this class).
            cache_dir: Cache directory (stored; not read by this class).
        """
        self.ssh_config = ssh_config or {}
        self.use_cache = use_cache
        self.cache_dir = cache_dir
        # Backend instances
        self._tushare = TushareSource()
        self._yfinance = YFinanceSource()
        # The SSH tunnel is created lazily on first use
        self._tunnel: Optional[SSHTunnelManager] = None

    @classmethod
    def from_yaml(cls, config_path: str) -> 'HybridDataSource':
        """
        Create an instance from a YAML config file.

        Only the ``ssh_tunnel`` and ``use_cache`` keys are read here.
        """
        import yaml
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file
            config = yaml.safe_load(f) or {}
        return cls(
            ssh_config=config.get('ssh_tunnel', {}),
            use_cache=config.get('use_cache', True)
        )

    def _start_tunnel(self) -> bool:
        """
        Create and start the SSH tunnel if it is enabled and not yet created.

        Returns:
            The result of ``SSHTunnelManager.start()`` when the tunnel is
            created by this call; ``True`` when the tunnel is disabled or was
            already created. NOTE: a previously failed start is not retried.
        """
        if self._tunnel is None and self.ssh_config.get('enabled'):
            self._tunnel = SSHTunnelManager(self.ssh_config)
            return self._tunnel.start()
        return True

    def _stop_tunnel(self):
        """Stop and discard the SSH tunnel, if one was created."""
        if self._tunnel:
            self._tunnel.stop()
            self._tunnel = None

    def _uses_tushare(self, code: str) -> bool:
        """Return True if ``code`` is served by Tushare (China index or futures)."""
        return self._tushare.is_china_index(code) or self._tushare.is_futures(code)

    @staticmethod
    def _normalize_index(index) -> pd.DatetimeIndex:
        """
        Convert an index to a tz-naive DatetimeIndex normalized to midnight.

        A tz-aware index keeps its *local* calendar date: the timezone is
        dropped in place rather than converting to UTC first, because a UTC
        conversion would shift e.g. Asia/Hong_Kong midnight to the previous
        day and misalign it against tz-naive A-share data.
        """
        try:
            idx = pd.DatetimeIndex(pd.to_datetime(index))
        except (TypeError, ValueError):
            # Mixed UTC offsets cannot form a single-tz index; fall back to UTC
            idx = pd.to_datetime(index, utc=True)
        if idx.tz is not None:
            idx = idx.tz_localize(None)
        return idx.normalize()

    @staticmethod
    def _to_wide(frames: List[pd.DataFrame], value_col: str) -> Optional[pd.DataFrame]:
        """Concatenate long-format frames and pivot to wide format (date x code)."""
        if not frames:
            return None
        return pd.concat(frames).pivot(columns='code', values=value_col)

    def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Fetch data for a single symbol from the backend that serves it.

        Args:
            code: Symbol code.
            start_date: Start date.
            end_date: End date.

        Returns:
            Whatever the selected backend's ``fetch`` returns (presumably an
            OHLCV DataFrame, or None when there is no data).
        """
        if self._uses_tushare(code):
            return self._tushare.fetch(code, start_date, end_date)
        # YFinance symbols may need the SSH tunnel
        self._start_tunnel()
        return self._yfinance.fetch(code, start_date, end_date)

    def fetch_all(
        self,
        code_config: dict,
        benchmark_code: str = "000300.SH",
        start_date: str = "2019-01-01",
        end_date: Optional[str] = None
    ) -> Tuple[
        Optional[pd.DataFrame],  # index_data: index closes (wide format)
        Optional[pd.DataFrame],  # etf_data: ETF closes (wide format)
        Optional[pd.DataFrame],  # etf_nav_data: ETF NAVs (wide format)
        Optional[pd.DataFrame],  # benchmark_data: benchmark data
        List[str],               # valid_codes: codes that returned data
        Dict[str, pd.DataFrame]  # index_ohlcv_data: raw OHLCV per code
    ]:
        """
        Download index, ETF, ETF-NAV and benchmark data in one batch.

        Args:
            code_config: Symbol config ``{code: {name, etf, market}}``.
            benchmark_code: Benchmark code (fetched via Tushare ``fetch_index``).
            start_date: Start date.
            end_date: End date; defaults to today (``%Y-%m-%d``).

        Returns:
            (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes,
            index_ohlcv_data). Wide frames are None when nothing was fetched.
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')

        # Start the SSH tunnel up front (used by YFinance symbols)
        if not self._start_tunnel():
            print("⚠ SSH隧道启动失败，境外数据可能无法获取")

        index_codes = list(code_config.keys())
        etf_codes = {idx_code: cfg['etf'] for idx_code, cfg in code_config.items() if cfg.get('etf')}
        print(f"开始下载 {len(index_codes)} 只标的的数据...")
        print(f" 指数代码: {len(index_codes)}")
        print(f" ETF映射: {len(etf_codes)}")

        # Per-backend breakdown (informational only)
        china_codes = [c for c in index_codes if self._tushare.is_china_index(c)]
        futures_codes = [c for c in index_codes if self._tushare.is_futures(c)]
        yf_codes = [c for c in index_codes if not self._uses_tushare(c)]
        print(f" 中国A股指数: {len(china_codes)}")
        print(f" 期货合约: {len(futures_codes)}")
        print(f" 港股/美股: {len(yf_codes)}")

        # Step 1: index data
        print("\n [1/2] 下载指数数据...")
        index_data_list = []
        index_ohlcv_data = {}
        valid_codes = []
        for code in index_codes:
            name = code_config[code].get('name', code)
            source = "Tushare" if self._uses_tushare(code) else "YFinance"
            print(f" 下载 {code} ({name}) - {source}...", end=" ")
            data = self.fetch_single(code, start_date, end_date)
            if data is None or len(data) == 0:
                print("✗ 无数据")
                continue
            # Standardize: tag rows and align dates across backends
            data = data.copy()
            data['source'] = source
            data['code'] = code
            data.index = self._normalize_index(data.index)
            index_ohlcv_data[code] = data.copy()
            index_data_list.append(data[['code', 'close', 'source']])
            valid_codes.append(code)
            print(f"{len(data)}")

        # Step 2: ETF prices and NAVs
        etf_data_list = []
        etf_nav_data_list = []
        if etf_codes:
            print("\n [2/2] 下载ETF数据...")
            for idx_code, etf_code in etf_codes.items():
                print(f" 下载ETF {etf_code} (对应指数 {idx_code})...", end=" ")
                etf_prices = self._tushare.fetch_etf(etf_code, start_date, end_date)
                etf_nav = self._tushare.fetch_etf_nav(etf_code, start_date, end_date)
                # Explicit checks: a DataFrame cannot be used in a boolean context
                has_nav = etf_nav is not None and len(etf_nav) > 0
                if etf_prices is not None and len(etf_prices) > 0:
                    etf_prices.index = self._normalize_index(etf_prices.index)
                    etf_data_list.append(etf_prices[['code', 'close']])
                    price_count = len(etf_prices)
                    nav_count = len(etf_nav) if has_nav else 0
                    print(f"✓ 价格{price_count}条 净值{nav_count}")
                else:
                    print("✗ 无数据")
                if has_nav:
                    etf_nav.index = self._normalize_index(etf_nav.index)
                    etf_nav_data_list.append(etf_nav[['code', 'nav']])

        # Assemble wide-format frames
        index_data = self._to_wide(index_data_list, 'close')
        etf_data = self._to_wide(etf_data_list, 'close')
        etf_nav_data = self._to_wide(etf_nav_data_list, 'nav')

        # Benchmark
        benchmark_data = self._tushare.fetch_index(benchmark_code, start_date, end_date)
        if benchmark_data is not None:
            benchmark_data.index = self._normalize_index(benchmark_data.index)
            print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)}")

        return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data

    def __enter__(self):
        """Start the SSH tunnel (if enabled) and return self."""
        self._start_tunnel()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the SSH tunnel; exceptions are not suppressed."""
        self._stop_tunnel()
# 简化接口
def fetch_rotation_data(config_path: str = "config/strategies/rotation.yaml") -> dict:
    """
    Fetch all data needed by the rotation strategy (convenience wrapper).

    Reads ``code_list``, ``benchmark``, ``start_date`` and ``end_date`` from
    the YAML config, downloads everything through HybridDataSource, and closes
    the SSH tunnel afterwards (also on error).

    Args:
        config_path: Path to the YAML config file.

    Returns:
        {
            'index_data': wide DataFrame of index closes,
            'etf_data': wide DataFrame of ETF closes,
            'etf_nav_data': wide DataFrame of ETF NAVs,
            'benchmark_data': benchmark DataFrame,
            'valid_codes': list of codes that returned data,
            'index_ohlcv_data': dict of raw OHLCV DataFrames per code,
        }
    """
    import yaml

    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty file
        config = yaml.safe_load(f) or {}

    # ``benchmark`` may be a mapping with a ``code`` key, a bare code string, or null
    benchmark_cfg = config.get('benchmark') or {}
    if isinstance(benchmark_cfg, dict):
        benchmark_code = benchmark_cfg.get('code', '000300.SH')
    else:
        benchmark_code = str(benchmark_cfg)

    # The context manager guarantees the tunnel started by fetch_all is closed
    with HybridDataSource.from_yaml(config_path) as source:
        (index_data, etf_data, etf_nav_data,
         benchmark_data, valid_codes, index_ohlcv_data) = source.fetch_all(
            code_config=config.get('code_list') or {},
            benchmark_code=benchmark_code,
            start_date=config.get('start_date') or '2019-01-01',
            # None (missing or null) makes fetch_all default to today
            end_date=config.get('end_date'),
        )

    return {
        'index_data': index_data,
        'etf_data': etf_data,
        'etf_nav_data': etf_nav_data,
        'benchmark_data': benchmark_data,
        'valid_codes': valid_codes,
        'index_ohlcv_data': index_ohlcv_data
    }