refactor(datasource): 分层接口设计,移除HybridDataSource

架构改动:
- 移除 HybridDataSource(功能被 UniversalDataFetcher 覆盖)
- 新增分层接口设计:基础层 + 扩展层

基础层(统一接口):
- fetch(): 统一 OHLCV 接口,自动识别资产类型
- fetch_batch(): 批量获取

扩展层(资产类型特有):
- fetch_etf_adj(): A股 ETF 后复权价格
- fetch_us_adj(): 美股复权价格
- fetch_etf_with_nav(): ETF 价格 + 净值 + 溢价率

其他修改:
- YFinanceSource: 新增 fetch_adj() 方法
- strategy.py: 改用 UniversalDataFetcher 替代 HybridDataSource
- __init__.py: 移除 HybridDataSource 导出
This commit is contained in:
2026-05-23 12:46:48 +08:00
parent 209dd7fd83
commit 1148d3166c
5 changed files with 235 additions and 333 deletions

View File

@@ -4,22 +4,27 @@
核心数据获取能力:
- A股数据Tushare指数、ETF、期货
- 境外数据YFinance港股、美股通过SSH隧道
- 加密货币CCXTOKX通过 socks2http
架构设计:
- 分层架构:对外统一接口,对内各资产类型独立实现
- 分层架构:基础层统一接口,扩展层资产类型特有方法
- Flask APILRU + TTL 双缓存机制
用法:
from datasource import UniversalDataFetcher, AssetType
from datasource import UniversalDataFetcher
# 基础层:统一 OHLCV 接口
fetcher = UniversalDataFetcher()
df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31")
# 扩展层:资产类型特有方法
df_adj = fetcher.fetch_etf_adj("513100.SH", ...) # ETF 后复权
df_adj = fetcher.fetch_us_adj("AAPL", ...) # 美股复权
"""
from .ssh_tunnel import SSHTunnelManager
from .tushare_source import TushareSource
from .yfinance_source import YFinanceSource
from .hybrid_source import HybridDataSource
from .asset_type_detector import AssetTypeDetector, AssetType
from .universal_fetcher import UniversalDataFetcher
@@ -27,7 +32,6 @@ __all__ = [
'SSHTunnelManager',
'TushareSource',
'YFinanceSource',
'HybridDataSource',
'AssetTypeDetector',
'AssetType',
'UniversalDataFetcher',

View File

@@ -1,301 +0,0 @@
"""
混合数据源
整合 TushareA股 + YFinance境外数据获取
"""
import os
import time
from typing import Optional, Tuple, Dict, List
from datetime import datetime
from pathlib import Path
import pandas as pd
from .ssh_tunnel import SSHTunnelManager
from .tushare_source import TushareSource
from .yfinance_source import YFinanceSource
class HybridDataSource:
"""
混合数据源
- A股指数/ETF/期货: Tushare
- 港股/美股/商品: YFinance通过SSH隧道
使用方式:
from datasource import HybridDataSource
source = HybridDataSource.from_yaml('strategies/rotation/config.yaml')
result = source.fetch_all()
"""
def __init__(
self,
ssh_config: Optional[dict] = None,
use_cache: bool = True,
cache_dir: str = "data/etf_cache/daily"
):
"""
初始化混合数据源
Args:
ssh_config: SSH隧道配置
use_cache: 是否使用缓存
cache_dir: 缓存目录
"""
self.ssh_config = ssh_config or {}
self.use_cache = use_cache
self.cache_dir = cache_dir
# 数据源实例
self._tushare = TushareSource()
self._yfinance = YFinanceSource()
# SSH隧道延迟初始化
self._tunnel: Optional[SSHTunnelManager] = None
@classmethod
def from_yaml(cls, config_path: str) -> 'HybridDataSource':
"""从YAML配置创建实例"""
import yaml
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return cls(
ssh_config=config.get('ssh_tunnel', {}),
use_cache=config.get('use_cache', True)
)
def _start_tunnel(self) -> bool:
"""启动SSH隧道"""
if self._tunnel is None and self.ssh_config.get('enabled'):
self._tunnel = SSHTunnelManager(self.ssh_config)
return self._tunnel.start()
return True
def _stop_tunnel(self):
"""停止SSH隧道"""
if self._tunnel:
self._tunnel.stop()
self._tunnel = None
def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
获取单个标的数据
Args:
code: 标的代码
start_date: 开始日期
end_date: 结束日期
Returns:
DataFrame with OHLCV data
"""
# 判断数据源
if self._tushare.is_china_index(code) or self._tushare.is_futures(code):
return self._tushare.fetch(code, start_date, end_date)
else:
# YFinance需要SSH隧道
self._start_tunnel()
return self._yfinance.fetch(code, start_date, end_date)
def fetch_all(
self,
code_config: dict,
benchmark_code: str = "000300.SH",
start_date: str = "2019-01-01",
end_date: str = None
) -> Tuple[
Optional[pd.DataFrame], # index_data: 指数收盘价(宽格式)
Optional[pd.DataFrame], # etf_data: ETF价格宽格式
Optional[pd.DataFrame], # etf_nav_data: ETF净值
Optional[pd.DataFrame], # benchmark_data: 基准数据
List[str], # valid_codes: 有效代码列表
Dict[str, pd.DataFrame], # index_ohlcv_data: 原始OHLCV数据
Dict[str, str] # etf_code_map: {指数代码: ETF代码} 映射
]:
"""
批量获取数据
Args:
code_config: 标的配置 {代码: {name, etf, market}}
benchmark_code: 基准代码
start_date: 开始日期
end_date: 结束日期
Returns:
(index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map)
"""
if end_date is None:
end_date = datetime.now().strftime('%Y-%m-%d')
# 启动SSH隧道
self._start_tunnel()
index_codes = list(code_config.keys())
etf_codes = {idx_code: cfg['etf'] for idx_code, cfg in code_config.items() if cfg.get('etf')}
print(f"开始下载 {len(index_codes)} 只标的的数据...")
print(f" 指数代码: {len(index_codes)}")
print(f" ETF映射: {len(etf_codes)}")
# 分类统计
china_codes = [c for c in index_codes if self._tushare.is_china_index(c)]
futures_codes = [c for c in index_codes if self._tushare.is_futures(c)]
yf_codes = [c for c in index_codes if not self._tushare.is_china_index(c) and not self._tushare.is_futures(c)]
print(f" 中国A股指数: {len(china_codes)}")
print(f" 期货合约: {len(futures_codes)}")
print(f" 港股/美股: {len(yf_codes)}")
# 下载指数数据
print("\n [1/2] 下载指数数据...")
index_data_list = []
index_ohlcv_data = {}
valid_codes = []
for code in index_codes:
name = code_config[code].get('name', code)
source = "Tushare" if self._tushare.is_china_index(code) or self._tushare.is_futures(code) else "YFinance"
print(f" 下载 {code} ({name}) - {source}...", end=" ")
data = self.fetch_single(code, start_date, end_date)
if data is not None and len(data) > 0:
# 标准化
data = data.copy()
data['source'] = source
data['code'] = code
data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
index_ohlcv_data[code] = data.copy()
index_data_list.append(data[['code', 'close', 'source']])
valid_codes.append(code)
print(f"{len(data)}")
else:
print("✗ 无数据")
# 下载ETF数据
etf_data_list = []
etf_nav_data_list = []
if etf_codes:
print("\n [2/2] 下载ETF数据...")
for idx_code, etf_code in etf_codes.items():
name = code_config[idx_code].get('name', idx_code)
print(f" 下载ETF {etf_code} (对应指数 {idx_code})...", end=" ")
# ETF价格
etf_data = self._tushare.fetch_etf(etf_code, start_date, end_date)
# ETF净值
etf_nav = self._tushare.fetch_etf_nav(etf_code, start_date, end_date)
if etf_data is not None and len(etf_data) > 0:
etf_data.index = pd.to_datetime(etf_data.index, utc=True).tz_localize(None).normalize()
etf_data_list.append(etf_data[['code', 'close']])
price_count = len(etf_data)
nav_count = len(etf_nav) if etf_nav is not None else 0
print(f"✓ 价格{price_count}条 净值{nav_count}")
else:
print("✗ 无数据")
if etf_nav is not None and len(etf_nav) > 0:
etf_nav.index = pd.to_datetime(etf_nav.index, utc=True).tz_localize(None).normalize()
etf_nav_data_list.append(etf_nav[['code', 'nav']])
# 整合数据
index_data = None
if index_data_list:
index_data = pd.concat(index_data_list)
if 'code' in index_data.columns and 'close' in index_data.columns:
index_data = index_data.reset_index()
if 'index' in index_data.columns:
index_data = index_data.rename(columns={'index': 'date'})
index_data['date'] = pd.to_datetime(index_data['date']).dt.normalize()
index_data = index_data.pivot_table(index='date', columns='code', values='close')
etf_data = None
if etf_data_list:
etf_data = pd.concat(etf_data_list)
if 'code' in etf_data.columns and 'close' in etf_data.columns:
etf_data = etf_data.reset_index()
if 'index' in etf_data.columns:
etf_data = etf_data.rename(columns={'index': 'date'})
etf_data['date'] = pd.to_datetime(etf_data['date']).dt.normalize()
etf_data = etf_data.pivot_table(index='date', columns='code', values='close')
etf_nav_data = None
if etf_nav_data_list:
etf_nav_data = pd.concat(etf_nav_data_list)
if 'code' in etf_nav_data.columns and 'nav' in etf_nav_data.columns:
etf_nav_data = etf_nav_data.reset_index()
if 'index' in etf_nav_data.columns:
etf_nav_data = etf_nav_data.rename(columns={'index': 'date'})
etf_nav_data['date'] = pd.to_datetime(etf_nav_data['date']).dt.normalize()
etf_nav_data = etf_nav_data.pivot_table(index='date', columns='code', values='nav')
# 基准数据
benchmark_data = self._tushare.fetch_index(benchmark_code, start_date, end_date)
if benchmark_data is not None:
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)}")
return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_codes
def __enter__(self):
self._start_tunnel()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._stop_tunnel()
# 简化接口
def fetch_rotation_data(config_path: str = "strategies/rotation/config.yaml") -> dict:
"""
获取轮动策略数据(简化接口)
Args:
config_path: 配置文件路径
Returns:
{
'index_data': 指数收盘价DataFrame,
'etf_data': ETF价格DataFrame,
'etf_nav_data': ETF净值DataFrame,
'benchmark_data': 基准DataFrame,
'valid_codes': 有效代码列表,
'index_ohlcv_data': 原始OHLCV数据字典
}
"""
import yaml
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
source = HybridDataSource.from_yaml(config_path)
index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \
source.fetch_all(
code_config=config.get('code_list', {}),
benchmark_code=config.get('benchmark', {}).get('code', '000300.SH'),
start_date=config.get('start_date', '2019-01-01'),
end_date=config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
)
return {
'index_data': index_data,
'etf_data': etf_data,
'etf_nav_data': etf_nav_data,
'benchmark_data': benchmark_data,
'valid_codes': valid_codes,
'index_ohlcv_data': index_ohlcv_data
}

View File

@@ -455,4 +455,64 @@ class UniversalDataFetcher:
def is_supported(self, code: str) -> bool:
"""判断是否支持该代码"""
return AssetTypeDetector.detect(code) != AssetType.UNKNOWN
return AssetTypeDetector.detect(code) != AssetType.UNKNOWN
# ============================================================
# 扩展层:资产类型特有方法(复权/净值/溢价率)
# ============================================================
def fetch_etf_adj(
self,
code: str,
start_date: str,
end_date: str
) -> Optional[pd.DataFrame]:
"""
获取 A股 ETF 后复权价格
通过 fund_daily + fund_adj 手动计算后复权价格
- 消除份额折算(拆分)对收益率的影响
- 适用于计算真实收益率
Args:
code: ETF代码'159915.SZ', '513100.SH'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DataFrame with columns: date, open, close, adj_factor, close_hfq
示例:
# 纳指ETF后复权正确计算收益率
df = fetcher.fetch_etf_adj("513100.SH", "2020-01-01", "2024-12-31")
# 使用 close_hfq 计算收益率,而非 close
"""
return self._tushare.fetch_etf_adj(code, start_date, end_date)
def fetch_us_adj(
self,
code: str,
start_date: str,
end_date: str
) -> Optional[pd.DataFrame]:
"""
获取美股复权价格
使用 YFinance auto_adjust=True
- 消除拆分(split)和分红(dividend)对价格的影响
- 适用于美股股票/ETF
Args:
code: 美股代码,如 'AAPL', 'TSLA', 'QQQ'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DataFrame with columns: date, open, high, low, close, volume (复权后)
示例:
# 苹果复权价格(包含分红和拆分调整)
df = fetcher.fetch_us_adj("AAPL", "2020-01-01", "2024-12-31")
"""
self._start_tunnel()
return self._yfinance.fetch_adj(code, start_date, end_date)

View File

@@ -114,6 +114,70 @@ class YFinanceSource:
print(f"YFinance下载 {code} ({yf_code}) 失败: {e}")
return None
def fetch_adj(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
获取复权价格数据
使用 auto_adjust=True 获取复权后的价格
- 消除拆分(split)和分红(dividend)对价格的影响
- 适用于美股股票/ETF
Args:
code: 代码(如 'AAPL', 'TSLA', 'QQQ'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DataFrame with columns: date, open, high, low, close, volume (复权后)
"""
import yfinance as yf
# 添加延迟避免限流
time.sleep(self._delay)
# 转换代码格式
yf_code = self.CODE_MAP.get(code, code)
try:
ticker = yf.Ticker(yf_code)
# end_date 需要加一天yfinance的end是排他的
end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
# auto_adjust=True 获取复权价格
df = ticker.history(
start=start_date,
end=end_dt.strftime("%Y-%m-%d"),
auto_adjust=True
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
"Open": "open",
"High": "high",
"Low": "low",
"Close": "close",
"Volume": "volume",
})
# 确保索引是日期格式
df.index = pd.to_datetime(df.index, utc=True).tz_localize(None).normalize()
df.index.name = "date"
# 添加代码列和标记
df["code"] = code
df.attrs['code'] = code
df.attrs['adjusted'] = True
return df[['code', 'open', 'high', 'low', 'close', 'volume']]
except Exception as e:
print(f"YFinance下载复权数据 {code} ({yf_code}) 失败: {e}")
return None
def is_yfinance_code(self, code: str) -> bool:
"""判断是否需要YFinance获取"""
# 非A股代码

View File

@@ -5,6 +5,7 @@
"""
import pandas as pd
import numpy as np
import yaml
from datetime import datetime
from pathlib import Path
@@ -113,7 +114,7 @@ class RotationStrategy(StrategyBase):
Args:
use_flask_api: 是否使用 Flask API 服务获取数据(默认 True
False 则使用本地 HybridDataSource
False 则使用本地 UniversalDataFetcher
"""
code_list_config = self.config.get('code_list', {})
benchmark_config = self.config.get('benchmark', {})
@@ -237,6 +238,12 @@ class RotationStrategy(StrategyBase):
index_close_dict[code] = df['close']
index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
# 获取 A 股 SSE 官方交易日历
from datasource.tushare_source import TushareSource
tushare = TushareSource()
a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date)
print(f"A股交易日历: {len(a_share_dates)}")
return {
'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame}
'index_close': index_close, # 对齐后的收盘价(宽格式)
@@ -245,7 +252,8 @@ class RotationStrategy(StrategyBase):
'etf_premium_data': etf_premium_data, # ETF 溢价率数据 {code: dict}
'benchmark_data': benchmark_data, # 基准收盘价 Series
'valid_codes': valid_codes, # 有效指数代码列表
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射
'a_share_dates': a_share_dates # A股SSE交易日历
}
def _get_data_from_local(
@@ -253,33 +261,90 @@ class RotationStrategy(StrategyBase):
code_list_config: dict,
benchmark_code: str
) -> dict:
"""使用本地 HybridDataSource 获取数据"""
from datasource import HybridDataSource
"""使用本地 UniversalDataFetcher 获取数据"""
from datasource import UniversalDataFetcher
from datasource.tushare_source import TushareSource
ssh_config = self.config.get('ssh_tunnel', {})
data_source = HybridDataSource(
fetcher = UniversalDataFetcher(
ssh_config=ssh_config,
use_cache=self.config.get('use_cache', True)
)
# 调用 fetch_all
index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map = \
data_source.fetch_all(
code_config=code_list_config,
benchmark_code=benchmark_code,
start_date=self.start_date,
end_date=self.end_date
)
index_codes = list(code_list_config.keys())
etf_code_map = {idx_code: cfg['etf'] for idx_code, cfg in code_list_config.items() if cfg.get('etf')}
# 获取指数数据
index_ohlcv_data = {}
valid_codes = []
with fetcher: # 使用上下文管理器自动管理 SSH 隧道
for code in index_codes:
data = fetcher.fetch(code, self.start_date, self.end_date)
if data is not None and len(data) > 0:
index_ohlcv_data[code] = data
valid_codes.append(code)
print(f"{code}: {len(data)}")
else:
print(f"{code}: 无数据")
# 构建宽格式收盘价
index_close = None
if index_ohlcv_data:
close_list = []
for code, df in index_ohlcv_data.items():
close_df = df[['close']].copy()
close_df.columns = [code]
close_list.append(close_df)
index_close = pd.concat(close_list, axis=1)
# 获取 ETF 数据
etf_data = None
etf_nav_data = None
tushare = TushareSource()
if etf_code_map:
etf_price_list = []
etf_nav_list = []
for idx_code, etf_code in etf_code_map.items():
# ETF 价格
etf_df = tushare.fetch_etf(etf_code, self.start_date, self.end_date)
if etf_df is not None and len(etf_df) > 0:
etf_df = etf_df[['close']].copy()
etf_df.columns = [etf_code]
etf_price_list.append(etf_df)
# ETF 净值
nav_df = tushare.fetch_etf_nav(etf_code, self.start_date, self.end_date)
if nav_df is not None and len(nav_df) > 0:
nav_df = nav_df[['nav']].copy()
nav_df.columns = [etf_code]
etf_nav_list.append(nav_df)
if etf_price_list:
etf_data = pd.concat(etf_price_list, axis=1)
if etf_nav_list:
etf_nav_data = pd.concat(etf_nav_list, axis=1)
# 基准数据
benchmark_data = tushare.fetch_index(benchmark_code, self.start_date, self.end_date)
# A股交易日历
a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date)
print(f"A股交易日历: {len(a_share_dates)}")
return {
'index_data': index_ohlcv_data, # 原始OHLCV数据
'index_close': index_data, # 对齐后的收盘价(宽格式)
'index_close': index_close, # 对齐后的收盘价(宽格式)
'etf_data': etf_data,
'etf_nav_data': etf_nav_data,
'benchmark_data': benchmark_data,
'valid_codes': valid_codes,
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射
'a_share_dates': a_share_dates # A股SSE交易日历
}
def compute_factors(self, data: dict) -> pd.DataFrame:
@@ -290,17 +355,20 @@ class RotationStrategy(StrategyBase):
index_data = data['index_data']
valid_codes = data['valid_codes']
# 获取A股交易日历作为基准使用已有的对齐后数据索引
index_close = data.get('index_close')
if index_close is not None:
a_share_dates = index_close.index
else:
for code in valid_codes:
if code.endswith('.SH') or code.endswith('.SZ') or code.endswith('.CSI'):
a_share_dates = index_data[code].index
break
# 获取 A 股 SSE 官方交易日历(优先使用已获取的
a_share_dates = data.get('a_share_dates')
if a_share_dates is None or len(a_share_dates) == 0:
# 回退:使用已有的对齐后数据索引
index_close = data.get('index_close')
if index_close is not None:
a_share_dates = index_close.index
else:
a_share_dates = index_data[valid_codes[0]].index
for code in valid_codes:
if code.endswith('.SH') or code.endswith('.SZ') or code.endswith('.CSI'):
a_share_dates = index_data[code].index
break
else:
a_share_dates = index_data[valid_codes[0]].index
factor_values = {}
final_valid_codes = []
@@ -408,8 +476,15 @@ class RotationStrategy(StrategyBase):
# 4. 执行回测
print("\n执行回测...")
# 获取A股交易日历从因子数据索引
a_share_dates = signals.index
# 获取 A 股 SSE 官方交易日历(优先使用已获取的
a_share_dates = data.get('a_share_dates')
if a_share_dates is None or len(a_share_dates) == 0:
a_share_dates = signals.index
# 将信号对齐到 A 股日历
if a_share_dates is not signals.index:
signals = signals.reindex(a_share_dates, method='ffill').dropna(subset=[signals.columns[0]])
print(f" 信号对齐到A股日历: {len(signals)}")
# 计算日收益率先在原始交易日历计算再对齐到A股日历
# 关键与因子计算逻辑一致避免交易日不对齐导致收益率NaN