refactor(datasource): 分层接口设计,移除HybridDataSource
架构改动:
- 移除 HybridDataSource(功能被 UniversalDataFetcher 覆盖)
- 新增分层接口设计:基础层 + 扩展层

基础层(统一接口):
- fetch(): 统一 OHLCV 接口,自动识别资产类型
- fetch_batch(): 批量获取

扩展层(资产类型特有):
- fetch_etf_adj(): A股 ETF 后复权价格
- fetch_us_adj(): 美股复权价格
- fetch_etf_with_nav(): ETF 价格 + 净值 + 溢价率

其他修改:
- YFinanceSource: 新增 fetch_adj() 方法
- strategy.py: 改用 UniversalDataFetcher 替代 HybridDataSource
- __init__.py: 移除 HybridDataSource 导出
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import yaml
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
@@ -113,7 +114,7 @@ class RotationStrategy(StrategyBase):
|
||||
|
||||
Args:
|
||||
use_flask_api: 是否使用 Flask API 服务获取数据(默认 True)
|
||||
False 则使用本地 HybridDataSource
|
||||
False 则使用本地 UniversalDataFetcher
|
||||
"""
|
||||
code_list_config = self.config.get('code_list', {})
|
||||
benchmark_config = self.config.get('benchmark', {})
|
||||
@@ -237,6 +238,12 @@ class RotationStrategy(StrategyBase):
|
||||
index_close_dict[code] = df['close']
|
||||
index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
|
||||
|
||||
# 获取 A 股 SSE 官方交易日历
|
||||
from datasource.tushare_source import TushareSource
|
||||
tushare = TushareSource()
|
||||
a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date)
|
||||
print(f"A股交易日历: {len(a_share_dates)} 天")
|
||||
|
||||
return {
|
||||
'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame}
|
||||
'index_close': index_close, # 对齐后的收盘价(宽格式)
|
||||
@@ -245,7 +252,8 @@ class RotationStrategy(StrategyBase):
|
||||
'etf_premium_data': etf_premium_data, # ETF 溢价率数据 {code: dict}
|
||||
'benchmark_data': benchmark_data, # 基准收盘价 Series
|
||||
'valid_codes': valid_codes, # 有效指数代码列表
|
||||
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
|
||||
'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射
|
||||
'a_share_dates': a_share_dates # A股SSE交易日历
|
||||
}
|
||||
|
||||
def _get_data_from_local(
|
||||
@@ -253,33 +261,90 @@ class RotationStrategy(StrategyBase):
|
||||
code_list_config: dict,
|
||||
benchmark_code: str
|
||||
) -> dict:
|
||||
"""使用本地 HybridDataSource 获取数据"""
|
||||
from datasource import HybridDataSource
|
||||
"""使用本地 UniversalDataFetcher 获取数据"""
|
||||
from datasource import UniversalDataFetcher
|
||||
from datasource.tushare_source import TushareSource
|
||||
|
||||
ssh_config = self.config.get('ssh_tunnel', {})
|
||||
|
||||
data_source = HybridDataSource(
|
||||
fetcher = UniversalDataFetcher(
|
||||
ssh_config=ssh_config,
|
||||
use_cache=self.config.get('use_cache', True)
|
||||
)
|
||||
|
||||
# 调用 fetch_all
|
||||
index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map = \
|
||||
data_source.fetch_all(
|
||||
code_config=code_list_config,
|
||||
benchmark_code=benchmark_code,
|
||||
start_date=self.start_date,
|
||||
end_date=self.end_date
|
||||
)
|
||||
index_codes = list(code_list_config.keys())
|
||||
etf_code_map = {idx_code: cfg['etf'] for idx_code, cfg in code_list_config.items() if cfg.get('etf')}
|
||||
|
||||
# 获取指数数据
|
||||
index_ohlcv_data = {}
|
||||
valid_codes = []
|
||||
|
||||
with fetcher: # 使用上下文管理器自动管理 SSH 隧道
|
||||
for code in index_codes:
|
||||
data = fetcher.fetch(code, self.start_date, self.end_date)
|
||||
if data is not None and len(data) > 0:
|
||||
index_ohlcv_data[code] = data
|
||||
valid_codes.append(code)
|
||||
print(f"✓ {code}: {len(data)} 条")
|
||||
else:
|
||||
print(f"✗ {code}: 无数据")
|
||||
|
||||
# 构建宽格式收盘价
|
||||
index_close = None
|
||||
if index_ohlcv_data:
|
||||
close_list = []
|
||||
for code, df in index_ohlcv_data.items():
|
||||
close_df = df[['close']].copy()
|
||||
close_df.columns = [code]
|
||||
close_list.append(close_df)
|
||||
index_close = pd.concat(close_list, axis=1)
|
||||
|
||||
# 获取 ETF 数据
|
||||
etf_data = None
|
||||
etf_nav_data = None
|
||||
|
||||
tushare = TushareSource()
|
||||
|
||||
if etf_code_map:
|
||||
etf_price_list = []
|
||||
etf_nav_list = []
|
||||
|
||||
for idx_code, etf_code in etf_code_map.items():
|
||||
# ETF 价格
|
||||
etf_df = tushare.fetch_etf(etf_code, self.start_date, self.end_date)
|
||||
if etf_df is not None and len(etf_df) > 0:
|
||||
etf_df = etf_df[['close']].copy()
|
||||
etf_df.columns = [etf_code]
|
||||
etf_price_list.append(etf_df)
|
||||
|
||||
# ETF 净值
|
||||
nav_df = tushare.fetch_etf_nav(etf_code, self.start_date, self.end_date)
|
||||
if nav_df is not None and len(nav_df) > 0:
|
||||
nav_df = nav_df[['nav']].copy()
|
||||
nav_df.columns = [etf_code]
|
||||
etf_nav_list.append(nav_df)
|
||||
|
||||
if etf_price_list:
|
||||
etf_data = pd.concat(etf_price_list, axis=1)
|
||||
if etf_nav_list:
|
||||
etf_nav_data = pd.concat(etf_nav_list, axis=1)
|
||||
|
||||
# 基准数据
|
||||
benchmark_data = tushare.fetch_index(benchmark_code, self.start_date, self.end_date)
|
||||
|
||||
# A股交易日历
|
||||
a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date)
|
||||
print(f"A股交易日历: {len(a_share_dates)} 天")
|
||||
|
||||
return {
|
||||
'index_data': index_ohlcv_data, # 原始OHLCV数据
|
||||
'index_close': index_data, # 对齐后的收盘价(宽格式)
|
||||
'index_close': index_close, # 对齐后的收盘价(宽格式)
|
||||
'etf_data': etf_data,
|
||||
'etf_nav_data': etf_nav_data,
|
||||
'benchmark_data': benchmark_data,
|
||||
'valid_codes': valid_codes,
|
||||
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射
|
||||
'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射
|
||||
'a_share_dates': a_share_dates # A股SSE交易日历
|
||||
}
|
||||
|
||||
def compute_factors(self, data: dict) -> pd.DataFrame:
|
||||
@@ -290,17 +355,20 @@ class RotationStrategy(StrategyBase):
|
||||
index_data = data['index_data']
|
||||
valid_codes = data['valid_codes']
|
||||
|
||||
# 获取A股交易日历作为基准(使用已有的对齐后数据索引)
|
||||
index_close = data.get('index_close')
|
||||
if index_close is not None:
|
||||
a_share_dates = index_close.index
|
||||
else:
|
||||
for code in valid_codes:
|
||||
if code.endswith('.SH') or code.endswith('.SZ') or code.endswith('.CSI'):
|
||||
a_share_dates = index_data[code].index
|
||||
break
|
||||
# 获取 A 股 SSE 官方交易日历(优先使用已获取的)
|
||||
a_share_dates = data.get('a_share_dates')
|
||||
if a_share_dates is None or len(a_share_dates) == 0:
|
||||
# 回退:使用已有的对齐后数据索引
|
||||
index_close = data.get('index_close')
|
||||
if index_close is not None:
|
||||
a_share_dates = index_close.index
|
||||
else:
|
||||
a_share_dates = index_data[valid_codes[0]].index
|
||||
for code in valid_codes:
|
||||
if code.endswith('.SH') or code.endswith('.SZ') or code.endswith('.CSI'):
|
||||
a_share_dates = index_data[code].index
|
||||
break
|
||||
else:
|
||||
a_share_dates = index_data[valid_codes[0]].index
|
||||
|
||||
factor_values = {}
|
||||
final_valid_codes = []
|
||||
@@ -408,8 +476,15 @@ class RotationStrategy(StrategyBase):
|
||||
# 4. 执行回测
|
||||
print("\n执行回测...")
|
||||
|
||||
# 获取A股交易日历(从因子数据索引)
|
||||
a_share_dates = signals.index
|
||||
# 获取 A 股 SSE 官方交易日历(优先使用已获取的)
|
||||
a_share_dates = data.get('a_share_dates')
|
||||
if a_share_dates is None or len(a_share_dates) == 0:
|
||||
a_share_dates = signals.index
|
||||
|
||||
# 将信号对齐到 A 股日历
|
||||
if a_share_dates is not signals.index:
|
||||
signals = signals.reindex(a_share_dates, method='ffill').dropna(subset=[signals.columns[0]])
|
||||
print(f" 信号对齐到A股日历: {len(signals)} 天")
|
||||
|
||||
# 计算日收益率:先在原始交易日历计算,再对齐到A股日历
|
||||
# 关键:与因子计算逻辑一致,避免交易日不对齐导致收益率NaN
|
||||
|
||||
Reference in New Issue
Block a user