迁移内容: - config/strategies/rotation.yaml → strategies/rotation/config.yaml 路径更新(核心文件): - strategies/rotation/strategy.py(注释示例) - scripts/generate_legacy_report.py(config_path) - run_rotation.py(注释和默认参数) - datasource/hybrid_source.py(from_yaml示例和fetch_rotation_data) 保留: - config/strategies/cci.yaml(无对应策略目录,暂保留) 设计原则:策略模块自包含,配置与实现同目录,方便移植和复制 验证:策略加载成功(候选池11只,回测区间2019-01-01 ~ 2026-05-12)
301 lines
11 KiB
Python
301 lines
11 KiB
Python
"""
Hybrid data source.

Combines Tushare (China A-share) and YFinance (overseas) data retrieval.
"""
|
||
|
||
import os
|
||
import time
|
||
from typing import Optional, Tuple, Dict, List
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
import pandas as pd
|
||
|
||
from .ssh_tunnel import SSHTunnelManager
|
||
from .tushare_source import TushareSource
|
||
from .yfinance_source import YFinanceSource
|
||
|
||
|
||
class HybridDataSource:
    """
    Hybrid market-data source.

    - China A-share indices / ETFs / futures: Tushare
    - Hong Kong / US equities and commodities: YFinance (through an SSH tunnel)

    Usage:
        from datasource import HybridDataSource

        with HybridDataSource.from_yaml('strategies/rotation/config.yaml') as source:
            result = source.fetch_all(code_config)
    """

    def __init__(
        self,
        ssh_config: Optional[dict] = None,
        use_cache: bool = True,
        cache_dir: str = "data/etf_cache/daily"
    ):
        """
        Initialise the hybrid data source.

        Args:
            ssh_config: SSH tunnel configuration; the tunnel is only created
                when its ``enabled`` entry is truthy.
            use_cache: Whether caching is requested (stored on the instance).
            cache_dir: Cache directory (stored on the instance).
        """
        self.ssh_config = ssh_config or {}
        self.use_cache = use_cache
        self.cache_dir = cache_dir

        # Underlying data-source instances.
        self._tushare = TushareSource()
        self._yfinance = YFinanceSource()

        # SSH tunnel, created lazily by _start_tunnel().
        self._tunnel: Optional[SSHTunnelManager] = None

    @classmethod
    def from_yaml(cls, config_path: str) -> 'HybridDataSource':
        """
        Create an instance from a YAML configuration file.

        Reads the ``ssh_tunnel`` and ``use_cache`` keys; an empty file is
        treated as an empty configuration.
        """
        import yaml

        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document.
            config = yaml.safe_load(f) or {}

        return cls(
            ssh_config=config.get('ssh_tunnel', {}),
            use_cache=config.get('use_cache', True)
        )

    def _start_tunnel(self) -> bool:
        """
        Start the SSH tunnel if it is enabled and not yet created.

        Returns:
            The result of ``SSHTunnelManager.start()`` when a tunnel is created
            by this call; ``True`` when the tunnel is disabled or a tunnel
            object already exists.
        """
        if self._tunnel is None and self.ssh_config.get('enabled'):
            self._tunnel = SSHTunnelManager(self.ssh_config)
            return self._tunnel.start()
        return True

    def _stop_tunnel(self):
        """Stop the SSH tunnel, if one exists, and forget it."""
        if self._tunnel:
            self._tunnel.stop()
            self._tunnel = None

    def _uses_tushare(self, code: str) -> bool:
        """Return True when *code* is a China index or a futures contract."""
        return self._tushare.is_china_index(code) or self._tushare.is_futures(code)

    @staticmethod
    def _normalize_index(index):
        """
        Convert an index to naive, midnight-normalised timestamps.

        NOTE(review): ``utc=True`` converts tz-aware timestamps to UTC before
        the timezone is dropped; for indexes aware of a zone east of UTC this
        could shift the calendar date - confirm what the sources return.
        """
        return pd.to_datetime(index, utc=True).tz_localize(None).normalize()

    @staticmethod
    def _to_wide(frames: List[pd.DataFrame], value_col: str) -> Optional[pd.DataFrame]:
        """
        Concatenate long-format frames and pivot them to date x code.

        Returns None for an empty list. The index is renamed to ``date``
        before ``reset_index()`` so the pivot works whatever the source frames
        called their index.
        """
        if not frames:
            return None

        long_df = pd.concat(frames)
        if 'code' not in long_df.columns or value_col not in long_df.columns:
            return long_df

        long_df = long_df.rename_axis('date').reset_index()
        long_df['date'] = pd.to_datetime(long_df['date']).dt.normalize()
        return long_df.pivot_table(index='date', columns='code', values=value_col)

    def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Fetch data for a single instrument.

        Args:
            code: Instrument code.
            start_date: Start date.
            end_date: End date.

        Returns:
            DataFrame with OHLCV data
        """
        if self._uses_tushare(code):
            return self._tushare.fetch(code, start_date, end_date)

        # YFinance needs the SSH tunnel.
        self._start_tunnel()
        return self._yfinance.fetch(code, start_date, end_date)

    def fetch_all(
        self,
        code_config: dict,
        benchmark_code: str = "000300.SH",
        start_date: str = "2019-01-01",
        end_date: Optional[str] = None
    ) -> Tuple[
        Optional[pd.DataFrame],   # index_data: index close prices (wide format)
        Optional[pd.DataFrame],   # etf_data: ETF prices (wide format)
        Optional[pd.DataFrame],   # etf_nav_data: ETF net asset values
        Optional[pd.DataFrame],   # benchmark_data: benchmark data
        List[str],                # valid_codes: codes that returned data
        Dict[str, pd.DataFrame],  # index_ohlcv_data: raw OHLCV data
        Dict[str, str]            # etf_code_map: {index code: ETF code}
    ]:
        """
        Fetch data for a batch of instruments.

        The SSH tunnel is started here but not stopped; use the instance as a
        context manager (or call ``_stop_tunnel``) to close it.

        Args:
            code_config: Instrument configuration {code: {name, etf, market}}.
            benchmark_code: Benchmark code.
            start_date: Start date.
            end_date: End date; defaults to today when None.

        Returns:
            (index_data, etf_data, etf_nav_data, benchmark_data, valid_codes,
            index_ohlcv_data, etf_code_map)
        """
        if end_date is None:
            end_date = datetime.now().strftime('%Y-%m-%d')

        # Start the SSH tunnel.
        self._start_tunnel()

        index_codes = list(code_config.keys())
        etf_codes = {idx_code: cfg['etf'] for idx_code, cfg in code_config.items() if cfg.get('etf')}

        print(f"开始下载 {len(index_codes)} 只标的的数据...")
        print(f" 指数代码: {len(index_codes)} 只")
        print(f" ETF映射: {len(etf_codes)} 只")

        # Per-category counts.
        china_codes = [c for c in index_codes if self._tushare.is_china_index(c)]
        futures_codes = [c for c in index_codes if self._tushare.is_futures(c)]
        yf_codes = [c for c in index_codes if not self._uses_tushare(c)]

        print(f" 中国A股指数: {len(china_codes)} 只")
        print(f" 期货合约: {len(futures_codes)} 只")
        print(f" 港股/美股: {len(yf_codes)} 只")

        # Download index data.
        print("\n [1/2] 下载指数数据...")
        index_data_list = []
        index_ohlcv_data = {}
        valid_codes = []

        for code in index_codes:
            name = code_config[code].get('name', code)
            source = "Tushare" if self._uses_tushare(code) else "YFinance"

            print(f" 下载 {code} ({name}) - {source}...", end=" ")

            data = self.fetch_single(code, start_date, end_date)

            if data is not None and len(data) > 0:
                # Standardise on a copy so the source's frame is untouched.
                data = data.copy()
                data['source'] = source
                data['code'] = code
                data.index = self._normalize_index(data.index)

                index_ohlcv_data[code] = data.copy()
                index_data_list.append(data[['code', 'close', 'source']])
                valid_codes.append(code)
                print(f"✓ {len(data)} 条")
            else:
                print("✗ 无数据")

        # Download ETF data.
        etf_data_list = []
        etf_nav_data_list = []

        if etf_codes:
            print("\n [2/2] 下载ETF数据...")

            for idx_code, etf_code in etf_codes.items():
                print(f" 下载ETF {etf_code} (对应指数 {idx_code})...", end=" ")

                # ETF prices and net asset values.
                etf_px = self._tushare.fetch_etf(etf_code, start_date, end_date)
                etf_nav = self._tushare.fetch_etf_nav(etf_code, start_date, end_date)

                if etf_px is not None and len(etf_px) > 0:
                    # Column selection yields a new frame, so re-indexing does
                    # not mutate the object returned by the source.
                    prices = etf_px[['code', 'close']]
                    prices.index = self._normalize_index(prices.index)
                    etf_data_list.append(prices)

                    price_count = len(etf_px)
                    nav_count = len(etf_nav) if etf_nav is not None else 0

                    print(f"✓ 价格{price_count}条 净值{nav_count}条")
                else:
                    print("✗ 无数据")

                if etf_nav is not None and len(etf_nav) > 0:
                    navs = etf_nav[['code', 'nav']]
                    navs.index = self._normalize_index(navs.index)
                    etf_nav_data_list.append(navs)

        # Consolidate into wide (date x code) tables.
        index_data = self._to_wide(index_data_list, 'close')
        etf_data = self._to_wide(etf_data_list, 'close')
        etf_nav_data = self._to_wide(etf_nav_data_list, 'nav')

        # Benchmark data.
        benchmark_data = self._tushare.fetch_index(benchmark_code, start_date, end_date)
        if benchmark_data is not None:
            benchmark_data.index = self._normalize_index(benchmark_data.index)
            print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)} 条")

        return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_codes

    def __enter__(self):
        """Start the tunnel (if enabled) and return the instance."""
        self._start_tunnel()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the tunnel on leaving the ``with`` block."""
        self._stop_tunnel()
|
||
|
||
|
||
# 简化接口
|
||
def fetch_rotation_data(config_path: str = "strategies/rotation/config.yaml") -> dict:
    """
    Fetch rotation-strategy data (simplified interface).

    Loads the YAML config, builds a HybridDataSource from it, runs
    ``fetch_all`` and repackages the result as a dict. The data source is used
    as a context manager so its SSH tunnel is stopped afterwards.

    Args:
        config_path: Path to the configuration file.

    Returns:
        {
            'index_data': index close-price DataFrame,
            'etf_data': ETF price DataFrame,
            'etf_nav_data': ETF net-asset-value DataFrame,
            'benchmark_data': benchmark DataFrame,
            'valid_codes': list of codes that returned data,
            'index_ohlcv_data': dict of raw OHLCV data,
            'etf_code_map': {index code: ETF code} mapping
        }
    """
    import yaml

    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty document.
        config = yaml.safe_load(f) or {}

    benchmark_cfg = config.get('benchmark') or {}

    with HybridDataSource.from_yaml(config_path) as source:
        # fetch_all returns seven values; the last one is the ETF code map.
        (
            index_data,
            etf_data,
            etf_nav_data,
            benchmark_data,
            valid_codes,
            index_ohlcv_data,
            etf_code_map,
        ) = source.fetch_all(
            code_config=config.get('code_list', {}),
            benchmark_code=benchmark_cfg.get('code', '000300.SH'),
            start_date=config.get('start_date', '2019-01-01'),
            end_date=config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
        )

    return {
        'index_data': index_data,
        'etf_data': etf_data,
        'etf_nav_data': etf_nav_data,
        'benchmark_data': benchmark_data,
        'valid_codes': valid_codes,
        'index_ohlcv_data': index_ohlcv_data,
        'etf_code_map': etf_code_map
    }