Compare commits

...

7 Commits

Author SHA1 Message Date
7f2af6b470 refactor(flask_api): fetch添加adj参数,fetch_with_adj简化
FlaskAPIDataSource.fetch() 新增 adj 参数,fetch_with_adj() 简化

- FlaskAPIDataSource.fetch(adj='raw'): 请求参数包含 adj
- fetch_with_adj(): 简化为 return self.fetch(adj=adj)(减少 ~120行)
- flask_server.py: 缓存逻辑已支持 adj 参数,无需修改
2026-05-23 18:32:20 +08:00
c319fd42be refactor(universal_fetcher): fetch添加adj参数,fetch_with_adj简化
UniversalDataFetcher.fetch() 新增 adj 参数,直接传递给底层

- fetch(adj='raw/qfq/hfq'): 统一入口,参数校验和路由
- fetch_with_adj(): 简化为 return self.fetch(adj=adj)
- 删除重复的 VALID_ADJ_BY_TYPE 定义和路由逻辑(~70行)
- VALID_ADJ_BY_TYPE 移到类级别作为静态配置
2026-05-23 18:32:10 +08:00
02dbc7bd7d refactor(datasource): 底层fetch方法添加adj参数
TushareSource.fetch() 和 YFinanceSource.fetch() 新增 adj 参数支持 raw/qfq/hfq

- TushareSource.fetch(adj='raw'): 内部路由到 fetch_index/fetch_stock_adj/fetch_etf_adj
- YFinanceSource.fetch(adj='raw'): 内部路由到 fetch_adj() 或原始逻辑
- 添加 is_china_stock() 和 _is_etf_code() 方法用于资产类型判断
2026-05-23 18:32:00 +08:00
1148d3166c refactor(datasource): 分层接口设计,移除HybridDataSource
架构改动:
- 移除 HybridDataSource(功能被 UniversalDataFetcher 覆盖)
- 新增分层接口设计:基础层 + 扩展层

基础层(统一接口):
- fetch(): 统一 OHLCV 接口,自动识别资产类型
- fetch_batch(): 批量获取

扩展层(资产类型特有):
- fetch_etf_adj(): A股 ETF 后复权价格
- fetch_us_adj(): 美股复权价格
- fetch_etf_with_nav(): ETF 价格 + 净值 + 溢价率

其他修改:
- YFinanceSource: 新增 fetch_adj() 方法
- strategy.py: 改用 UniversalDataFetcher 替代 HybridDataSource
- __init__.py: 移除 HybridDataSource 导出
2026-05-23 12:46:48 +08:00
209dd7fd83 refactor(tushare): 移除代理清除/恢复逻辑
Tushare是国内服务,不需要代理切换操作

移除内容:
- _clear_proxy 方法
- _restore_proxy 方法
- 所有方法中的代理清除和恢复调用

代码精简37行,保持原有功能不变
2026-05-23 11:55:02 +08:00
b066b23495 feat(tushare): 新增ETF后复权价格和交易日历获取方法
新增方法:
- fetch_etf_adj: 获取ETF后复权价格数据,消除份额拆分对收益率的影响
  通过 fund_daily + fund_adj 手动计算后复权价格
  fund_adj 单次限2000条,按5年分段请求
- fetch_trade_cal: 获取A股SSE官方交易日历

验证结果:
- 纳指ETF后复权正确识别2022-01-14拆分(复权因子5.0)
- 累计收益100.54%与纳指100指数一致
2026-05-23 11:51:32 +08:00
8e8093e0fd chore(config): 调整回测起始日期为2020-01-01
配合Mode B(指数信号+ETF收益)回测需求,缩短回测区间以提高ETF数据可用性
2026-05-23 11:18:00 +08:00
9 changed files with 766 additions and 407 deletions

View File

@@ -4,22 +4,27 @@
核心数据获取能力: 核心数据获取能力:
- A股数据Tushare指数、ETF、期货 - A股数据Tushare指数、ETF、期货
- 境外数据YFinance港股、美股通过SSH隧道 - 境外数据YFinance港股、美股通过SSH隧道
- 加密货币CCXTOKX通过 socks2http
架构设计: 架构设计:
- 分层架构:对外统一接口,对内各资产类型独立实现 - 分层架构:基础层统一接口,扩展层资产类型特有方法
- Flask APILRU + TTL 双缓存机制 - Flask APILRU + TTL 双缓存机制
用法: 用法:
from datasource import UniversalDataFetcher, AssetType from datasource import UniversalDataFetcher
# 基础层:统一 OHLCV 接口
fetcher = UniversalDataFetcher() fetcher = UniversalDataFetcher()
df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31") df = fetcher.fetch("000300.SH", "2024-01-01", "2024-12-31")
# 扩展层:资产类型特有方法
df_adj = fetcher.fetch_etf_adj("513100.SH", ...) # ETF 后复权
df_adj = fetcher.fetch_us_adj("AAPL", ...) # 美股复权
""" """
from .ssh_tunnel import SSHTunnelManager from .ssh_tunnel import SSHTunnelManager
from .tushare_source import TushareSource from .tushare_source import TushareSource
from .yfinance_source import YFinanceSource from .yfinance_source import YFinanceSource
from .hybrid_source import HybridDataSource
from .asset_type_detector import AssetTypeDetector, AssetType from .asset_type_detector import AssetTypeDetector, AssetType
from .universal_fetcher import UniversalDataFetcher from .universal_fetcher import UniversalDataFetcher
@@ -27,7 +32,6 @@ __all__ = [
'SSHTunnelManager', 'SSHTunnelManager',
'TushareSource', 'TushareSource',
'YFinanceSource', 'YFinanceSource',
'HybridDataSource',
'AssetTypeDetector', 'AssetTypeDetector',
'AssetType', 'AssetType',
'UniversalDataFetcher', 'UniversalDataFetcher',

View File

@@ -61,30 +61,41 @@ class FlaskAPIDataSource:
code: str, code: str,
start_date: str, start_date: str,
end_date: str, end_date: str,
adj: str = 'raw',
asset_type: str = None, asset_type: str = None,
timeframe: str = '1d' timeframe: str = '1d'
) -> Optional[pd.DataFrame]: ) -> Optional[pd.DataFrame]:
""" """
获取单只标的 OHLCV 数据 获取单只标的 OHLCV 数据(支持 adj 参数)
Args: Args:
code: 标的代码 code: 标的代码
start_date: 开始日期 YYYY-MM-DD start_date: 开始日期 YYYY-MM-DD
end_date: 结束日期 YYYY-MM-DD end_date: 结束日期 YYYY-MM-DD
adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw'
asset_type: 资产类型(可选,用于覆盖自动检测) asset_type: 资产类型(可选,用于覆盖自动检测)
timeframe: K线周期加密货币需要 timeframe: K线周期加密货币需要
Returns: Returns:
DataFrame with columns: date, open, high, low, close, volume DataFrame with columns: date, open, high, low, close, volume
adj='hfq' 时 A股 ETF 会额外返回 adj_factor, close_hfq
示例:
# 原始价格
df = source.fetch("000300.SH", "2020-01-01", "2024-12-31")
# A股股票后复权
df = source.fetch("000001.SZ", "2020-01-01", "2024-12-31", adj='hfq')
""" """
# 构建请求 URL # 构建请求 URL
url = f"{self.base_url}{self.api_path}" url = f"{self.base_url}{self.api_path}"
# 构建请求参数 # 构建请求参数(包含 adj
params = { params = {
'code': code, 'code': code,
'start': start_date, 'start': start_date,
'end': end_date, 'end': end_date,
'adj': adj, # 添加 adj 参数
} }
# 加密货币需要 timeframe 参数 # 加密货币需要 timeframe 参数
@@ -296,6 +307,38 @@ class FlaskAPIDataSource:
print(f"{code} 净值获取失败: {e}") print(f"{code} 净值获取失败: {e}")
return None return None
def fetch_with_adj(
self,
code: str,
start_date: str,
end_date: str,
adj: str = 'raw',
asset_type: str = None,
timeframe: str = '1d'
) -> Optional[pd.DataFrame]:
"""
获取 OHLCV 数据(支持复权参数)- 简化版
直接调用 fetch(adj=adj),无需重复实现。
Args:
code: 标的代码
start_date: 开始日期 YYYY-MM-DD
end_date: 结束日期 YYYY-MM-DD
adj: 复权参数raw/qfq/hfq默认 'raw'
asset_type: 资产类型(可选)
timeframe: K线周期加密货币需要
Returns:
DataFrame结构因 adj 参数略有不同
示例:
# A股股票后复权
df = source.fetch_with_adj("000001.SZ", "2020-01-01", "2024-12-31", adj='hfq')
"""
# 直接调用 fetch传递 adj 参数
return self.fetch(code, start_date, end_date, adj, asset_type, timeframe)
def get_health(self) -> Dict: def get_health(self) -> Dict:
"""获取服务健康状态""" """获取服务健康状态"""
# 先尝试 ohlcv 端点检查服务是否可用 # 先尝试 ohlcv 端点检查服务是否可用

View File

@@ -119,16 +119,18 @@ def get_fetcher() -> UniversalDataFetcher:
# ============================================================ # ============================================================
@lru_cache(maxsize=CACHE_MAXSIZE) @lru_cache(maxsize=CACHE_MAXSIZE)
def _fetch_full_data_cached(code: str, today: str) -> Optional[str]: def _fetch_full_data_cached(code: str, today: str, adj: str = 'raw') -> Optional[str]:
""" """
缓存全量数据(仅日级别数据) 缓存全量数据(仅日级别数据)
缓存策略: 缓存策略:
- 日级别数据(股票/指数/ETF/期货): 从 DEFAULT_START_DATE 到 today - 日级别数据(股票/指数/ETF/期货): 从 DEFAULT_START_DATE 到 today
- 加密货币: 不缓存,每次实时下载 - 加密货币: 不缓存,每次实时下载
- 不同 adj 参数raw/qfq/hfq独立缓存
缓存Key: (code, today_date) 缓存Key: (code, today_date, adj)
- today: 实际的今天日期,用于每日更新缓存 - today: 实际的今天日期,用于每日更新缓存
- adj: 复权参数,不同复权类型独立缓存
Returns: Returns:
JSON 序列化的全量数据(仅日级别数据) JSON 序列化的全量数据(仅日级别数据)
@@ -142,19 +144,25 @@ def _fetch_full_data_cached(code: str, today: str) -> Optional[str]:
if asset_type == AssetType.CRYPTO: if asset_type == AssetType.CRYPTO:
return None # 不缓存加密货币 return None # 不缓存加密货币
# 校验 adj 参数是否适用于该资产类型
valid_adj = UniversalDataFetcher.VALID_ADJ_BY_TYPE.get(asset_type, ['raw'])
if adj not in valid_adj:
return json.dumps({"error": f"adj='{adj}' 不适用于 {asset_type.value}"})
try: try:
with f: with f:
# 下载数据:从默认起点到今天 # 使用 fetch_with_adj 获取数据(支持复权)
df = f.fetch(code, DEFAULT_START_DATE, today) df = f.fetch_with_adj(code, DEFAULT_START_DATE, today, adj)
if df is None or len(df) == 0: if df is None or len(df) == 0:
return None return None
# 保存为 DataFrame 格式(方便后续切片) # 保存为 DataFrame 格式(方便后续切片)
result = { result = {
'df_json': dataframe_to_json(df), 'df_json': dataframe_to_json(df, asset_type.value),
'code': code, 'code': code,
'asset_type': asset_type.value, 'asset_type': asset_type.value,
'adj': adj,
'data_start': df.index.min().strftime('%Y-%m-%d') if len(df) > 0 else None, 'data_start': df.index.min().strftime('%Y-%m-%d') if len(df) > 0 else None,
'data_end': df.index.max().strftime('%Y-%m-%d') if len(df) > 0 else None, 'data_end': df.index.max().strftime('%Y-%m-%d') if len(df) > 0 else None,
'cache_strategy': 'full_history', 'cache_strategy': 'full_history',
@@ -190,6 +198,7 @@ def _slice_data_from_cache(cached_data: Dict, start: str, end: str) -> Dict:
'count': 0, 'count': 0,
'code': cached_data['code'], 'code': cached_data['code'],
'asset_type': cached_data['asset_type'], 'asset_type': cached_data['asset_type'],
'adj': cached_data.get('adj', 'raw'),
'requested_range': {'start': start, 'end': end}, 'requested_range': {'start': start, 'end': end},
'available_range': {'start': cached_data['data_start'], 'end': cached_data['data_end']}, 'available_range': {'start': cached_data['data_start'], 'end': cached_data['data_end']},
} }
@@ -222,6 +231,7 @@ def _slice_data_from_cache(cached_data: Dict, start: str, end: str) -> Dict:
result = dataframe_to_json(sliced_df) result = dataframe_to_json(sliced_df)
result['code'] = cached_data['code'] result['code'] = cached_data['code']
result['asset_type'] = cached_data['asset_type'] result['asset_type'] = cached_data['asset_type']
result['adj'] = cached_data.get('adj', 'raw')
result['requested_range'] = {'start': start, 'end': end} result['requested_range'] = {'start': start, 'end': end}
result['available_range'] = {'start': cached_data['data_start'], 'end': cached_data['data_end']} result['available_range'] = {'start': cached_data['data_start'], 'end': cached_data['data_end']}
@@ -233,14 +243,16 @@ def fetch_data_with_ttl(
start: str, start: str,
end: str, end: str,
nocache: bool = False, nocache: bool = False,
timeframe: str = '1d' timeframe: str = '1d',
adj: str = 'raw'
) -> Tuple[Optional[Dict], bool]: ) -> Tuple[Optional[Dict], bool]:
""" """
获取数据,支持 TTL 缓存(加密货币不缓存) 获取数据,支持 TTL 缓存(加密货币不缓存)
缓存策略: 缓存策略:
- 日级别数据(股票/指数/ETF/期货): Key=(code, today), 缓存全量数据,切片返回 - 日级别数据(股票/指数/ETF/期货): Key=(code, today, adj), 缓存全量数据,切片返回
- 加密货币: 每次实时下载,不缓存,必须指定 timeframe - 加密货币: 每次实时下载,不缓存,必须指定 timeframe
- 不同 adj 参数独立缓存
Args: Args:
code: 标的代码 code: 标的代码
@@ -248,6 +260,7 @@ def fetch_data_with_ttl(
end: 用户请求的结束日期 end: 用户请求的结束日期
nocache: 是否跳过缓存 nocache: 是否跳过缓存
timeframe: K线周期仅加密货币需要 timeframe: K线周期仅加密货币需要
adj: 复权参数raw/qfq/hfq
Returns: Returns:
(data, is_cached): 数据和是否命中缓存 (data, is_cached): 数据和是否命中缓存
@@ -269,6 +282,7 @@ def fetch_data_with_ttl(
result = dataframe_to_json(df, asset_type.value) result = dataframe_to_json(df, asset_type.value)
result['code'] = code result['code'] = code
result['asset_type'] = asset_type.value result['asset_type'] = asset_type.value
result['adj'] = 'raw' # 加密货币无复权
result['cache_strategy'] = 'no_cache_crypto' result['cache_strategy'] = 'no_cache_crypto'
result['requested_range'] = {'start': start, 'end': end} result['requested_range'] = {'start': start, 'end': end}
result['timeframe'] = timeframe result['timeframe'] = timeframe
@@ -276,15 +290,20 @@ def fetch_data_with_ttl(
except Exception as e: except Exception as e:
return {'error': str(e), 'code': code, 'asset_type': asset_type.value}, False return {'error': str(e), 'code': code, 'asset_type': asset_type.value}, False
# 日级别数据:使用缓存 # 校验 adj 参数
full_cache_key = (code, today) valid_adj = UniversalDataFetcher.VALID_ADJ_BY_TYPE.get(asset_type, ['raw'])
if adj not in valid_adj:
return {'error': f"adj='{adj}' 不适用于 {asset_type.value},支持: {valid_adj}", 'code': code, 'asset_type': asset_type.value}, False
# 日级别数据:使用缓存(缓存 Key 包含 adj
full_cache_key = (code, today, adj)
# 跳过缓存:清理缓存后重新下载 # 跳过缓存:清理缓存后重新下载
if nocache: if nocache:
_fetch_full_data_cached.cache_clear() _fetch_full_data_cached.cache_clear()
global _ttl_cache global _ttl_cache
_ttl_cache.clear() _ttl_cache.clear()
result_json = _fetch_full_data_cached(code, today) result_json = _fetch_full_data_cached(code, today, adj)
if result_json is None: if result_json is None:
return None, False return None, False
full_data = json.loads(result_json) full_data = json.loads(result_json)
@@ -301,7 +320,7 @@ def fetch_data_with_ttl(
del _ttl_cache[full_cache_key] del _ttl_cache[full_cache_key]
# 从 LRU 缓存获取全量数据 # 从 LRU 缓存获取全量数据
result_json = _fetch_full_data_cached(code, today) result_json = _fetch_full_data_cached(code, today, adj)
if result_json is None: if result_json is None:
return None, False return None, False
@@ -552,11 +571,19 @@ def get_ohlcv():
asset_type: 资产类型 (optional, 强制覆盖自动检测结果) asset_type: 资产类型 (optional, 强制覆盖自动检测结果)
- china_index: 中国指数 - china_index: 中国指数
- china_etf: 中国ETF - china_etf: 中国ETF
- china_stock: 中国股票
- us_index: 美股指数 - us_index: 美股指数
- us_stock: 美股股票
- hk_index: 港股指数 - hk_index: 港股指数
- hk_stock: 港股股票
- futures: 期货 - futures: 期货
- crypto: 加密货币 - crypto: 加密货币
注:指定后会覆盖自动检测,用于修复检测逻辑问题 注:指定后会覆盖自动检测,用于修复检测逻辑问题
adj: 复权参数 (optional, 默认raw)
- raw: 原始价格(所有资产类型)
- qfq: 前复权A股股票/美股股票/港股股票)
- hfq: 后复权A股股票/ETF/美股股票/港股股票)
不同资产类型支持的adj值不同非法组合返回400错误
timeframe: K线周期 (optional, 仅加密货币需要) timeframe: K线周期 (optional, 仅加密货币需要)
- 1d: 日线(默认) - 1d: 日线(默认)
- 1h: 小时线 - 1h: 小时线
@@ -569,6 +596,7 @@ def get_ohlcv():
start = request.args.get('start', '').strip() start = request.args.get('start', '').strip()
end = request.args.get('end', '').strip() end = request.args.get('end', '').strip()
asset_type_param = request.args.get('asset_type', '').strip().lower() asset_type_param = request.args.get('asset_type', '').strip().lower()
adj = request.args.get('adj', 'raw').strip().lower()
timeframe = request.args.get('timeframe', '1d').strip().lower() timeframe = request.args.get('timeframe', '1d').strip().lower()
nocache = request.args.get('nocache', 'false').lower() == 'true' nocache = request.args.get('nocache', 'false').lower() == 'true'
@@ -577,7 +605,15 @@ def get_ohlcv():
return jsonify({ return jsonify({
"error": "Missing required parameter: code", "error": "Missing required parameter: code",
"example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31", "example": "/api/v1/ohlcv?code=000300.SH&start=2024-01-01&end=2024-03-31",
"asset_type_hint": "可选 asset_type 参数强制指定类型", "adj_hint": "可选 adj 参数获取复权数据raw/qfq/hfq",
}), 400
# adj 参数验证
if adj not in ['raw', 'qfq', 'hfq']:
return jsonify({
"error": f"Invalid adj parameter: {adj}",
"valid_adj": ['raw', 'qfq', 'hfq'],
"hint": "adj 必须是 raw/qfq/hfq",
}), 400 }), 400
# 设置默认日期 # 设置默认日期
@@ -607,6 +643,15 @@ def get_ohlcv():
"valid_types": [t.value for t in AssetType], "valid_types": [t.value for t in AssetType],
}), 400 }), 400
# 校验 adj 是否适用于该资产类型
valid_adj = UniversalDataFetcher.VALID_ADJ_BY_TYPE.get(final_type, ['raw'])
if adj not in valid_adj:
return jsonify({
"error": f"adj='{adj}' 不适用于 {final_type.value}",
"valid_adj": valid_adj,
"hint": f"{final_type.value} 仅支持复权类型: {valid_adj}",
}), 400
# 加密货币必须指定 timeframe无论自动检测还是手动指定 # 加密货币必须指定 timeframe无论自动检测还是手动指定
if final_type == AssetType.CRYPTO: if final_type == AssetType.CRYPTO:
valid_timeframes = ['1d', '1h', '4h', '15m', '1m', 'daily', 'hourly'] valid_timeframes = ['1d', '1h', '4h', '15m', '1m', 'daily', 'hourly']
@@ -618,12 +663,13 @@ def get_ohlcv():
}), 400 }), 400
# 使用缓存获取数据(加密货币不缓存) # 使用缓存获取数据(加密货币不缓存)
result, is_cached = fetch_data_with_ttl(code, start, end, nocache, timeframe) result, is_cached = fetch_data_with_ttl(code, start, end, nocache, timeframe, adj)
if result is None: if result is None:
return jsonify({ return jsonify({
"code": code, "code": code,
"asset_type": final_type.value, "asset_type": final_type.value,
"adj": adj,
"detected_type": detected_type.value if asset_type_param else None, # 仅当用户指定时显示 "detected_type": detected_type.value if asset_type_param else None, # 仅当用户指定时显示
"error": "No data available", "error": "No data available",
"start": start, "start": start,
@@ -634,15 +680,17 @@ def get_ohlcv():
return jsonify({ return jsonify({
"code": code, "code": code,
"asset_type": final_type.value, "asset_type": final_type.value,
"adj": adj,
"detected_type": detected_type.value if asset_type_param else None, "detected_type": detected_type.value if asset_type_param else None,
"error": result["error"], "error": result["error"],
}), 500 }), 500
result['cached'] = is_cached result['cached'] = is_cached
result['asset_type'] = final_type.value # 使用最终类型 result['asset_type'] = final_type.value # 使用最终类型
result['adj'] = adj # 返回使用的 adj 参数
# 如果是中国 ETF自动附加净值和溢价率数据 # 如果是中国 ETF 且 adj=raw,自动附加净值和溢价率数据
if final_type == AssetType.CHINA_ETF: if final_type == AssetType.CHINA_ETF and adj == 'raw':
try: try:
f = get_fetcher() f = get_fetcher()
with f: with f:

View File

@@ -1,301 +0,0 @@
"""
混合数据源
整合 TushareA股 + YFinance境外数据获取
"""
import os
import time
from typing import Optional, Tuple, Dict, List
from datetime import datetime
from pathlib import Path
import pandas as pd
from .ssh_tunnel import SSHTunnelManager
from .tushare_source import TushareSource
from .yfinance_source import YFinanceSource
class HybridDataSource:
"""
混合数据源
- A股指数/ETF/期货: Tushare
- 港股/美股/商品: YFinance通过SSH隧道
使用方式:
from datasource import HybridDataSource
source = HybridDataSource.from_yaml('strategies/rotation/config.yaml')
result = source.fetch_all()
"""
def __init__(
self,
ssh_config: Optional[dict] = None,
use_cache: bool = True,
cache_dir: str = "data/etf_cache/daily"
):
"""
初始化混合数据源
Args:
ssh_config: SSH隧道配置
use_cache: 是否使用缓存
cache_dir: 缓存目录
"""
self.ssh_config = ssh_config or {}
self.use_cache = use_cache
self.cache_dir = cache_dir
# 数据源实例
self._tushare = TushareSource()
self._yfinance = YFinanceSource()
# SSH隧道延迟初始化
self._tunnel: Optional[SSHTunnelManager] = None
@classmethod
def from_yaml(cls, config_path: str) -> 'HybridDataSource':
"""从YAML配置创建实例"""
import yaml
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return cls(
ssh_config=config.get('ssh_tunnel', {}),
use_cache=config.get('use_cache', True)
)
def _start_tunnel(self) -> bool:
"""启动SSH隧道"""
if self._tunnel is None and self.ssh_config.get('enabled'):
self._tunnel = SSHTunnelManager(self.ssh_config)
return self._tunnel.start()
return True
def _stop_tunnel(self):
"""停止SSH隧道"""
if self._tunnel:
self._tunnel.stop()
self._tunnel = None
def fetch_single(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
获取单个标的数据
Args:
code: 标的代码
start_date: 开始日期
end_date: 结束日期
Returns:
DataFrame with OHLCV data
"""
# 判断数据源
if self._tushare.is_china_index(code) or self._tushare.is_futures(code):
return self._tushare.fetch(code, start_date, end_date)
else:
# YFinance需要SSH隧道
self._start_tunnel()
return self._yfinance.fetch(code, start_date, end_date)
def fetch_all(
self,
code_config: dict,
benchmark_code: str = "000300.SH",
start_date: str = "2019-01-01",
end_date: str = None
) -> Tuple[
Optional[pd.DataFrame], # index_data: 指数收盘价(宽格式)
Optional[pd.DataFrame], # etf_data: ETF价格宽格式
Optional[pd.DataFrame], # etf_nav_data: ETF净值
Optional[pd.DataFrame], # benchmark_data: 基准数据
List[str], # valid_codes: 有效代码列表
Dict[str, pd.DataFrame], # index_ohlcv_data: 原始OHLCV数据
Dict[str, str] # etf_code_map: {指数代码: ETF代码} 映射
]:
"""
批量获取数据
Args:
code_config: 标的配置 {代码: {name, etf, market}}
benchmark_code: 基准代码
start_date: 开始日期
end_date: 结束日期
Returns:
(index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map)
"""
if end_date is None:
end_date = datetime.now().strftime('%Y-%m-%d')
# 启动SSH隧道
self._start_tunnel()
index_codes = list(code_config.keys())
etf_codes = {idx_code: cfg['etf'] for idx_code, cfg in code_config.items() if cfg.get('etf')}
print(f"开始下载 {len(index_codes)} 只标的的数据...")
print(f" 指数代码: {len(index_codes)}")
print(f" ETF映射: {len(etf_codes)}")
# 分类统计
china_codes = [c for c in index_codes if self._tushare.is_china_index(c)]
futures_codes = [c for c in index_codes if self._tushare.is_futures(c)]
yf_codes = [c for c in index_codes if not self._tushare.is_china_index(c) and not self._tushare.is_futures(c)]
print(f" 中国A股指数: {len(china_codes)}")
print(f" 期货合约: {len(futures_codes)}")
print(f" 港股/美股: {len(yf_codes)}")
# 下载指数数据
print("\n [1/2] 下载指数数据...")
index_data_list = []
index_ohlcv_data = {}
valid_codes = []
for code in index_codes:
name = code_config[code].get('name', code)
source = "Tushare" if self._tushare.is_china_index(code) or self._tushare.is_futures(code) else "YFinance"
print(f" 下载 {code} ({name}) - {source}...", end=" ")
data = self.fetch_single(code, start_date, end_date)
if data is not None and len(data) > 0:
# 标准化
data = data.copy()
data['source'] = source
data['code'] = code
data.index = pd.to_datetime(data.index, utc=True).tz_localize(None).normalize()
index_ohlcv_data[code] = data.copy()
index_data_list.append(data[['code', 'close', 'source']])
valid_codes.append(code)
print(f"{len(data)}")
else:
print("✗ 无数据")
# 下载ETF数据
etf_data_list = []
etf_nav_data_list = []
if etf_codes:
print("\n [2/2] 下载ETF数据...")
for idx_code, etf_code in etf_codes.items():
name = code_config[idx_code].get('name', idx_code)
print(f" 下载ETF {etf_code} (对应指数 {idx_code})...", end=" ")
# ETF价格
etf_data = self._tushare.fetch_etf(etf_code, start_date, end_date)
# ETF净值
etf_nav = self._tushare.fetch_etf_nav(etf_code, start_date, end_date)
if etf_data is not None and len(etf_data) > 0:
etf_data.index = pd.to_datetime(etf_data.index, utc=True).tz_localize(None).normalize()
etf_data_list.append(etf_data[['code', 'close']])
price_count = len(etf_data)
nav_count = len(etf_nav) if etf_nav is not None else 0
print(f"✓ 价格{price_count}条 净值{nav_count}")
else:
print("✗ 无数据")
if etf_nav is not None and len(etf_nav) > 0:
etf_nav.index = pd.to_datetime(etf_nav.index, utc=True).tz_localize(None).normalize()
etf_nav_data_list.append(etf_nav[['code', 'nav']])
# 整合数据
index_data = None
if index_data_list:
index_data = pd.concat(index_data_list)
if 'code' in index_data.columns and 'close' in index_data.columns:
index_data = index_data.reset_index()
if 'index' in index_data.columns:
index_data = index_data.rename(columns={'index': 'date'})
index_data['date'] = pd.to_datetime(index_data['date']).dt.normalize()
index_data = index_data.pivot_table(index='date', columns='code', values='close')
etf_data = None
if etf_data_list:
etf_data = pd.concat(etf_data_list)
if 'code' in etf_data.columns and 'close' in etf_data.columns:
etf_data = etf_data.reset_index()
if 'index' in etf_data.columns:
etf_data = etf_data.rename(columns={'index': 'date'})
etf_data['date'] = pd.to_datetime(etf_data['date']).dt.normalize()
etf_data = etf_data.pivot_table(index='date', columns='code', values='close')
etf_nav_data = None
if etf_nav_data_list:
etf_nav_data = pd.concat(etf_nav_data_list)
if 'code' in etf_nav_data.columns and 'nav' in etf_nav_data.columns:
etf_nav_data = etf_nav_data.reset_index()
if 'index' in etf_nav_data.columns:
etf_nav_data = etf_nav_data.rename(columns={'index': 'date'})
etf_nav_data['date'] = pd.to_datetime(etf_nav_data['date']).dt.normalize()
etf_nav_data = etf_nav_data.pivot_table(index='date', columns='code', values='nav')
# 基准数据
benchmark_data = self._tushare.fetch_index(benchmark_code, start_date, end_date)
if benchmark_data is not None:
benchmark_data.index = pd.to_datetime(benchmark_data.index, utc=True).tz_localize(None).normalize()
print(f"\n✓ 基准 {benchmark_code}: {len(benchmark_data)}")
return index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_codes
def __enter__(self):
self._start_tunnel()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._stop_tunnel()
# 简化接口
def fetch_rotation_data(config_path: str = "strategies/rotation/config.yaml") -> dict:
"""
获取轮动策略数据(简化接口)
Args:
config_path: 配置文件路径
Returns:
{
'index_data': 指数收盘价DataFrame,
'etf_data': ETF价格DataFrame,
'etf_nav_data': ETF净值DataFrame,
'benchmark_data': 基准DataFrame,
'valid_codes': 有效代码列表,
'index_ohlcv_data': 原始OHLCV数据字典
}
"""
import yaml
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
source = HybridDataSource.from_yaml(config_path)
index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data = \
source.fetch_all(
code_config=config.get('code_list', {}),
benchmark_code=config.get('benchmark', {}).get('code', '000300.SH'),
start_date=config.get('start_date', '2019-01-01'),
end_date=config.get('end_date', datetime.now().strftime('%Y-%m-%d'))
)
return {
'index_data': index_data,
'etf_data': etf_data,
'etf_nav_data': etf_nav_data,
'benchmark_data': benchmark_data,
'valid_codes': valid_codes,
'index_ohlcv_data': index_ohlcv_data
}

View File

@@ -29,19 +29,6 @@ class TushareSource:
import tushare as ts import tushare as ts
return ts.pro_api(self._token) return ts.pro_api(self._token)
def _clear_proxy(self) -> dict:
"""清除代理环境变量Tushare是国内服务不需要代理"""
original = {}
for key in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
original[key] = os.environ.pop(key, None)
return original
def _restore_proxy(self, original: dict):
"""恢复代理环境变量"""
for key, value in original.items():
if value is not None:
os.environ[key] = value
def fetch_index(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: def fetch_index(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
""" """
获取A股指数数据 获取A股指数数据
@@ -54,8 +41,6 @@ class TushareSource:
Returns: Returns:
DataFrame with columns: date, open, high, low, close, volume DataFrame with columns: date, open, high, low, close, volume
""" """
original_proxy = self._clear_proxy()
try: try:
pro = self._get_pro_api() pro = self._get_pro_api()
@@ -89,9 +74,6 @@ class TushareSource:
print(f"Tushare下载指数 {code} 失败: {e}") print(f"Tushare下载指数 {code} 失败: {e}")
return None return None
finally:
self._restore_proxy(original_proxy)
def fetch_futures(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: def fetch_futures(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
""" """
获取期货数据 获取期货数据
@@ -101,10 +83,7 @@ class TushareSource:
start_date: 开始日期 start_date: 开始日期
end_date: 结束日期 end_date: 结束日期
""" """
original_proxy = self._clear_proxy()
try: try:
import tushare as ts
pro = self._get_pro_api() pro = self._get_pro_api()
# 使用 fut_daily 接口 # 使用 fut_daily 接口
@@ -134,9 +113,6 @@ class TushareSource:
print(f"Tushare下载期货 {code} 失败: {e}") print(f"Tushare下载期货 {code} 失败: {e}")
return None return None
finally:
self._restore_proxy(original_proxy)
def fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: def fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
""" """
获取ETF价格数据 获取ETF价格数据
@@ -144,8 +120,6 @@ class TushareSource:
Args: Args:
code: ETF代码'159915.SZ', '518880.SH' code: ETF代码'159915.SZ', '518880.SH'
""" """
original_proxy = self._clear_proxy()
try: try:
pro = self._get_pro_api() pro = self._get_pro_api()
@@ -176,9 +150,6 @@ class TushareSource:
print(f"Tushare下载ETF {code} 失败: {e}") print(f"Tushare下载ETF {code} 失败: {e}")
return None return None
finally:
self._restore_proxy(original_proxy)
def fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: def fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
""" """
获取ETF净值数据 获取ETF净值数据
@@ -186,8 +157,6 @@ class TushareSource:
Args: Args:
code: ETF代码 code: ETF代码
""" """
original_proxy = self._clear_proxy()
try: try:
pro = self._get_pro_api() pro = self._get_pro_api()
@@ -218,9 +187,6 @@ class TushareSource:
print(f"Tushare下载ETF净值 {code} 失败: {e}") print(f"Tushare下载ETF净值 {code} 失败: {e}")
return None return None
finally:
self._restore_proxy(original_proxy)
def is_china_index(self, code: str) -> bool: def is_china_index(self, code: str) -> bool:
"""判断是否为A股指数""" """判断是否为A股指数"""
return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS") or code.endswith(".CSI") return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS") or code.endswith(".CSI")
@@ -231,18 +197,257 @@ class TushareSource:
# NYMEX (.NYM) 和 ICE (.ICE) 走 YFinance # NYMEX (.NYM) 和 ICE (.ICE) 走 YFinance
return ".SHF" in code or ".DCE" in code or ".CZC" in code return ".SHF" in code or ".DCE" in code or ".CZC" in code
def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: def is_china_stock(self, code: str) -> bool:
"""判断是否为A股股票6位数字代码 + .SZ/.SH/.SS"""
# 股票代码000001.SZ, 600000.SH 等
# 区分指数:指数代码通常是 000xxx.SH, 399xxx.SZ, H30xxx.CSI
# 股票代码通常是 00xxxx.SZ, 30xxxx.SZ, 60xxxx.SH, 000xxx.SH部分
import re
# 股票代码模式6位数字 + .SZ/.SH/.SS
# 排除指数000xxx.SH (指数), 399xxx.SZ (指数), Hxxxxx.CSI (指数)
if not re.match(r'^\d{6}\.(SZ|SH|SS)$', code):
return False
# 000xxx.SH 可能是指数也可能是股票,需要更细致判断
# 简化处理000/001/002/003 开头 + .SZ 是股票600/601/603 开头 + .SH 是股票
prefix = code[:3]
suffix = code.split('.')[1]
if suffix == 'SZ' and prefix in ['000', '001', '002', '003', '300']:
return True
if suffix == 'SH' and prefix in ['600', '601', '603', '605', '688']:
return True
return False
def fetch(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
""" """
通用数据获取(自动判断类型) 通用数据获取(自动判断类型,支持 adj 参数
Args: Args:
code: 代码 code: 代码
start_date: 开始日期 start_date: 开始日期
end_date: 结束日期 end_date: 结束日期
adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw'
Returns:
DataFrame with columns: date, open, high, low, close, volume
adj='hfq' 时 A股 ETF 会额外返回 adj_factor, close_hfq
""" """
# 校验 adj 参数
if adj not in ['raw', 'qfq', 'hfq']:
raise ValueError(f"adj 参数必须是 'raw', 'qfq''hfq',当前: {adj}")
# 原始数据
if adj == 'raw':
if self.is_china_index(code): if self.is_china_index(code):
return self.fetch_index(code, start_date, end_date) return self.fetch_index(code, start_date, end_date)
elif self.is_futures(code): elif self.is_futures(code):
return self.fetch_futures(code, start_date, end_date) return self.fetch_futures(code, start_date, end_date)
elif self.is_china_stock(code):
return self.fetch_stock_adj(code, start_date, end_date, adj='raw')
else: else:
return None return None
# 复权数据
if adj in ['qfq', 'hfq']:
# A股股票复权
if self.is_china_stock(code):
return self.fetch_stock_adj(code, start_date, end_date, adj)
# A股 ETF 仅支持 hfq
elif self._is_etf_code(code):
if adj == 'hfq':
return self.fetch_etf_adj(code, start_date, end_date)
else:
raise ValueError(f"ETF 仅支持 adj='hfq'(后复权),当前: {adj}")
else:
# 指数/期货不支持复权
raise ValueError(f"指数/期货不支持复权adj='{adj}' 仅适用于股票/ETF")
def _is_etf_code(self, code: str) -> bool:
"""判断是否为ETF代码"""
# ETF代码51xxxx.SH, 52xxxx.SH, 15xxxx.SZ, 16xxxx.SZ
import re
if not re.match(r'^\d{6}\.(SZ|SH)$', code):
return False
prefix = code[:2]
return prefix in ['51', '52', '15', '16']
def fetch_etf_adj(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
获取 ETF 后复权价格数据
通过 fund_daily + fund_adj 手动计算后复权价格,消除份额折算(拆分)对收益率的影响。
fund_adj 单次限 2000 条,按 5 年分段请求再拼接。
Args:
code: ETF代码'159915.SZ', '518880.SH'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DataFrame with columns: date, open, close, adj_factor, close_hfq
"""
try:
pro = self._get_pro_api()
ts_code = code.replace('.SS', '.SH')
# 获取 fund_daily 数据
df_daily = pro.fund_daily(
ts_code=ts_code,
start_date=start_date.replace('-', ''),
end_date=end_date.replace('-', '')
)
if df_daily is None or len(df_daily) == 0:
return None
# 获取 fund_adj 数据分段请求单次限2000条
# 按5年分段
start_dt = datetime.strptime(start_date, '%Y-%m-%d')
end_dt = datetime.strptime(end_date, '%Y-%m-%d')
adj_chunks = []
chunk_start = start_dt
while chunk_start < end_dt:
chunk_end = min(chunk_start.replace(year=chunk_start.year + 5), end_dt)
chunk_start_str = chunk_start.strftime('%Y%m%d')
chunk_end_str = chunk_end.strftime('%Y%m%d')
df_adj_chunk = pro.fund_adj(
ts_code=ts_code,
start_date=chunk_start_str,
end_date=chunk_end_str
)
if df_adj_chunk is not None and len(df_adj_chunk) > 0:
adj_chunks.append(df_adj_chunk)
chunk_start = chunk_end
if not adj_chunks:
# 无复权因子,返回原始数据
df = df_daily.rename(columns={'trade_date': 'date', 'vol': 'volume'})
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date').sort_index()
df['adj_factor'] = 1.0
df['close_hfq'] = df['close']
df['code'] = code
return df[['code', 'open', 'close', 'adj_factor', 'close_hfq']]
# 合并所有复权因子
df_adj = pd.concat(adj_chunks, ignore_index=True)
df_adj = df_adj.rename(columns={'trade_date': 'date'})
df_adj['date'] = pd.to_datetime(df_adj['date'])
df_adj = df_adj.set_index('date').sort_index()
# 合并 daily 和 adj
df_daily = df_daily.rename(columns={'trade_date': 'date', 'vol': 'volume'})
df_daily['date'] = pd.to_datetime(df_daily['date'])
df_daily = df_daily.set_index('date').sort_index()
# 复权因子对齐(用最新值)
df_adj_aligned = df_adj.reindex(df_daily.index, method='ffill')
df_adj_aligned['adj_factor'] = df_adj_aligned['adj_factor'].fillna(1.0)
# 计算后复权价格
df = df_daily.copy()
df['adj_factor'] = df_adj_aligned['adj_factor']
df['close_hfq'] = df['close'] * df['adj_factor']
df['code'] = code
return df[['code', 'open', 'close', 'adj_factor', 'close_hfq']]
except Exception as e:
print(f"Tushare下载ETF复权数据 {code} 失败: {e}")
return None
def fetch_trade_cal(self, start_date: str, end_date: str) -> pd.DatetimeIndex:
"""
获取 A 股(上交所 SSE官方交易日历
Args:
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DatetimeIndex: A股交易日日期序列
"""
try:
pro = self._get_pro_api()
df = pro.trade_cal(
exchange='SSE',
start_date=start_date.replace('-', ''),
end_date=end_date.replace('-', ''),
is_open='1'
)
if df is None or len(df) == 0:
return pd.DatetimeIndex([])
# 提取交易日并转换为 DatetimeIndex
trade_dates = pd.to_datetime(df['cal_date'])
return pd.DatetimeIndex(trade_dates.sort_values())
except Exception as e:
print(f"Tushare下载交易日历失败: {e}")
return pd.DatetimeIndex([])
def fetch_stock_adj(self, code: str, start_date: str, end_date: str, adj: str = 'hfq') -> Optional[pd.DataFrame]:
"""
获取 A股股票复权价格数据
使用 pro_bar 接口获取前复权(qfq)或后复权(hfq)价格。
Args:
code: 股票代码,如 '000001.SZ', '600000.SH'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'hfq'
Returns:
DataFrame with columns: date, code, open, high, low, close, volume, adj_factor
"""
import tushare as ts
if adj not in ['qfq', 'hfq']:
raise ValueError(f"adj 参数必须是 'qfq''hfq',当前: {adj}")
try:
ts_code = code.replace('.SS', '.SH')
# 使用 pro_bar 接口获取复权数据
df = ts.pro_bar(
ts_code=ts_code,
adj=adj,
start_date=start_date.replace('-', ''),
end_date=end_date.replace('-', ''),
adjfactor=True # 返回复权因子
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
'ts_code': 'code',
'trade_date': 'date',
'vol': 'volume',
})
# 转换日期格式
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
df = df.sort_index()
# 恢复原始代码格式(.SS -> .SH 反转)
df['code'] = code
# 标准化返回字段
columns = ['code', 'open', 'high', 'low', 'close', 'volume']
if 'adj_factor' in df.columns:
columns.append('adj_factor')
return df[columns]
except Exception as e:
print(f"Tushare下载股票复权数据 {code} 失败: {e}")
return None

View File

@@ -107,16 +107,30 @@ class UniversalDataFetcher:
# 统一入口(自动路由) # 统一入口(自动路由)
# ============================================================ # ============================================================
# 各资产类型支持的 adj 参数
VALID_ADJ_BY_TYPE = {
AssetType.CHINA_INDEX: ['raw'], # 指数无复权
AssetType.CHINA_ETF: ['raw', 'hfq'], # ETF 仅支持后复权
AssetType.CHINA_STOCK: ['raw', 'qfq', 'hfq'],
AssetType.US_INDEX: ['raw'], # 指数无复权
AssetType.US_STOCK: ['raw', 'qfq', 'hfq'],
AssetType.HK_INDEX: ['raw'], # 指数无复权
AssetType.HK_STOCK: ['raw', 'qfq', 'hfq'],
AssetType.FUTURES: ['raw'], # 期货无复权
AssetType.CRYPTO: ['raw'], # 加密货币无复权
}
def fetch( def fetch(
self, self,
code: str, code: str,
start_date: str, start_date: str,
end_date: str, end_date: str,
adj: str = 'raw',
retry: int = 3, retry: int = 3,
timeframe: str = '1d' timeframe: str = '1d'
) -> Optional[pd.DataFrame]: ) -> Optional[pd.DataFrame]:
""" """
统一数据获取入口 统一数据获取入口(支持 adj 参数)
自动识别资产类型并路由到对应方法 自动识别资产类型并路由到对应方法
@@ -124,31 +138,60 @@ class UniversalDataFetcher:
code: 标的代码 code: 标的代码
start_date: 开始日期 'YYYY-MM-DD' start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD' end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw'
retry: 重试次数 retry: 重试次数
timeframe: K线周期仅加密货币需要默认1d timeframe: K线周期仅加密货币需要默认1d
Returns: Returns:
DataFrame with columns: date, open, high, low, close, volume DataFrame with columns: date, open, high, low, close, volume
adj='hfq' 时 A股 ETF 会额外返回 adj_factor, close_hfq
示例:
# 原始价格
df = fetcher.fetch("000300.SH", "2020-01-01", "2024-12-31")
# A股股票后复权
df = fetcher.fetch("000001.SZ", "2020-01-01", "2024-12-31", adj='hfq')
# 美股股票前复权
df = fetcher.fetch("AAPL", "2020-01-01", "2024-12-31", adj='qfq')
""" """
# 校验 adj 参数
if adj not in ['raw', 'qfq', 'hfq']:
raise ValueError(f"adj 参数必须是 'raw', 'qfq''hfq',当前: {adj}")
asset_type = AssetTypeDetector.detect(code) asset_type = AssetTypeDetector.detect(code)
# 校验 adj 是否适用于该资产类型
valid_adj = self.VALID_ADJ_BY_TYPE.get(asset_type, ['raw'])
if adj not in valid_adj:
raise ValueError(
f"adj='{adj}' 不适用于 {asset_type.value},支持的类型: {valid_adj}"
)
for attempt in range(retry): for attempt in range(retry):
try: try:
# 路由到具体方法 # 路由到具体方法(传递 adj 参数)
if asset_type == AssetType.CHINA_INDEX: if asset_type == AssetType.CHINA_INDEX:
return self._fetch_china_index(code, start_date, end_date) return self._tushare.fetch(code, start_date, end_date, adj)
elif asset_type == AssetType.CHINA_ETF: elif asset_type == AssetType.CHINA_ETF:
return self._fetch_china_etf(code, start_date, end_date) return self._tushare.fetch(code, start_date, end_date, adj)
elif asset_type == AssetType.CHINA_STOCK:
return self._tushare.fetch(code, start_date, end_date, adj)
elif asset_type == AssetType.US_INDEX: elif asset_type == AssetType.US_INDEX:
return self._fetch_us_index(code, start_date, end_date) self._start_tunnel()
return self._yfinance.fetch(code, start_date, end_date, adj)
elif asset_type == AssetType.US_STOCK: elif asset_type == AssetType.US_STOCK:
return self._fetch_us_stock(code, start_date, end_date) self._start_tunnel()
return self._yfinance.fetch(code, start_date, end_date, adj)
elif asset_type == AssetType.HK_INDEX: elif asset_type == AssetType.HK_INDEX:
return self._fetch_hk_index(code, start_date, end_date) self._start_tunnel()
return self._yfinance.fetch(code, start_date, end_date, adj)
elif asset_type == AssetType.HK_STOCK: elif asset_type == AssetType.HK_STOCK:
return self._fetch_hk_stock(code, start_date, end_date) self._start_tunnel()
return self._yfinance.fetch(code, start_date, end_date, adj)
elif asset_type == AssetType.FUTURES: elif asset_type == AssetType.FUTURES:
return self._fetch_futures(code, start_date, end_date) return self._fetch_futures(code, start_date, end_date, adj)
elif asset_type == AssetType.CRYPTO: elif asset_type == AssetType.CRYPTO:
return self._fetch_crypto(code, start_date, end_date, timeframe) return self._fetch_crypto(code, start_date, end_date, timeframe)
else: else:
@@ -159,7 +202,7 @@ class UniversalDataFetcher:
if attempt < retry - 1: if attempt < retry - 1:
time.sleep(2) time.sleep(2)
else: else:
print(f"✗ 获取 {code} 失败 (尝试 {attempt+1}/{retry}): {e}") print(f"✗ 获取 {code} adj={adj} 失败 (尝试 {attempt+1}/{retry}): {e}")
return None return None
return None return None
@@ -359,7 +402,8 @@ class UniversalDataFetcher:
self, self,
code: str, code: str,
start_date: str, start_date: str,
end_date: str end_date: str,
adj: str = 'raw'
) -> Optional[pd.DataFrame]: ) -> Optional[pd.DataFrame]:
""" """
获取期货数据 获取期货数据
@@ -367,11 +411,16 @@ class UniversalDataFetcher:
特点: 特点:
- 中国期货(.SHF/.DCE/.CZC): Tushare - 中国期货(.SHF/.DCE/.CZC): Tushare
- NYMEX(.NYM): YFinance - NYMEX(.NYM): YFinance
- 期货不支持复权adj 只能为 'raw'
""" """
# 期货不支持复权
if adj != 'raw':
raise ValueError(f"期货不支持复权adj='{adj}' 仅适用于股票/ETF")
if code.endswith('.NYM'): if code.endswith('.NYM'):
# NYMEX期货走YFinance # NYMEX期货走YFinance
self._start_tunnel() self._start_tunnel()
return self._yfinance.fetch(code, start_date, end_date) return self._yfinance.fetch(code, start_date, end_date, adj='raw')
else: else:
# 中国期货走Tushare # 中国期货走Tushare
return self._tushare.fetch_futures(code, start_date, end_date) return self._tushare.fetch_futures(code, start_date, end_date)
@@ -456,3 +505,149 @@ class UniversalDataFetcher:
def is_supported(self, code: str) -> bool: def is_supported(self, code: str) -> bool:
"""判断是否支持该代码""" """判断是否支持该代码"""
return AssetTypeDetector.detect(code) != AssetType.UNKNOWN return AssetTypeDetector.detect(code) != AssetType.UNKNOWN
# ============================================================
# 扩展层:资产类型特有方法(复权/净值/溢价率)
# ============================================================
def fetch_etf_adj(
self,
code: str,
start_date: str,
end_date: str
) -> Optional[pd.DataFrame]:
"""
获取 A股 ETF 后复权价格
通过 fund_daily + fund_adj 手动计算后复权价格
- 消除份额折算(拆分)对收益率的影响
- 适用于计算真实收益率
Args:
code: ETF代码'159915.SZ', '513100.SH'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DataFrame with columns: date, open, close, adj_factor, close_hfq
示例:
# 纳指ETF后复权正确计算收益率
df = fetcher.fetch_etf_adj("513100.SH", "2020-01-01", "2024-12-31")
# 使用 close_hfq 计算收益率,而非 close
"""
return self._tushare.fetch_etf_adj(code, start_date, end_date)
def fetch_us_adj(
self,
code: str,
start_date: str,
end_date: str,
adj: str = 'qfq'
) -> Optional[pd.DataFrame]:
"""
获取美股复权价格
使用 YFinance支持前复权(qfq)和后复权(hfq)
- 消除拆分(split)和分红(dividend)对价格的影响
- 适用于美股股票/ETF
Args:
code: 美股代码,如 'AAPL', 'TSLA', 'QQQ'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'qfq'
Returns:
DataFrame with columns: date, open, high, low, close, volume (复权后)
示例:
# 苹果复权价格(包含分红和拆分调整)
df = fetcher.fetch_us_adj("AAPL", "2020-01-01", "2024-12-31", adj='qfq')
"""
self._start_tunnel()
return self._yfinance.fetch_adj(code, start_date, end_date, adj)
def fetch_hk_adj(
self,
code: str,
start_date: str,
end_date: str,
adj: str = 'qfq'
) -> Optional[pd.DataFrame]:
"""
获取港股股票复权价格
使用 YFinance支持前复权(qfq)和后复权(hfq)
Args:
code: 港股代码,如 '00700.HK', '00941.HK'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'qfq'
Returns:
DataFrame with columns: date, open, high, low, close, volume (复权后)
"""
self._start_tunnel()
return self._yfinance.fetch_adj(code, start_date, end_date, adj)
def fetch_stock_adj(
self,
code: str,
start_date: str,
end_date: str,
adj: str = 'hfq'
) -> Optional[pd.DataFrame]:
"""
获取 A股股票复权价格
使用 Tushare pro_bar 接口,支持前复权(qfq)和后复权(hfq)
Args:
code: A股股票代码'000001.SZ', '600000.SH'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'hfq'
Returns:
DataFrame with columns: date, open, high, low, close, volume, adj_factor
"""
return self._tushare.fetch_stock_adj(code, start_date, end_date, adj)
# ============================================================
# 统一复权入口(简化版,直接调用 fetch
# ============================================================
def fetch_with_adj(
self,
code: str,
start_date: str,
end_date: str,
adj: str = 'raw',
retry: int = 3
) -> Optional[pd.DataFrame]:
"""
统一复权入口(简化版)
直接调用 fetch(adj=adj),无需重复实现路由逻辑。
Args:
code: 标的代码
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型,默认 'raw'
retry: 重试次数
Returns:
DataFrame结构因资产类型和 adj 参数略有不同
示例:
# A股股票后复权
df = fetcher.fetch_with_adj("000001.SZ", "2020-01-01", "2024-12-31", adj='hfq')
# 美股股票前复权
df = fetcher.fetch_with_adj("AAPL", "2020-01-01", "2024-12-31", adj='qfq')
"""
# 直接调用 fetch传递 adj 参数
return self.fetch(code, start_date, end_date, adj, retry)

View File

@@ -44,19 +44,30 @@ class YFinanceSource:
self.use_ssh_tunnel = use_ssh_tunnel self.use_ssh_tunnel = use_ssh_tunnel
self._delay = 0.5 # 请求延迟(避免限流) self._delay = 0.5 # 请求延迟(避免限流)
def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: def fetch(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
""" """
获取数据 获取数据(支持 adj 参数)
Args: Args:
code: 代码(如 'NDX', 'N225', 'HSI' code: 代码(如 'NDX', 'N225', 'HSI', 'AAPL'
start_date: 开始日期 'YYYY-MM-DD' start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD' end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw'
Returns: Returns:
DataFrame with columns: date, open, high, low, close, volume DataFrame with columns: date, open, high, low, close, volume
股票元信息存储在 df.attrs['info'] 中 股票元信息存储在 df.attrs['info'] 中
adj='qfq/hfq' 时 df.attrs['adj'] 会标记复权类型
""" """
# 校验 adj 参数
if adj not in ['raw', 'qfq', 'hfq']:
raise ValueError(f"adj 参数必须是 'raw', 'qfq''hfq',当前: {adj}")
# 复权数据:调用 fetch_adj
if adj in ['qfq', 'hfq']:
return self.fetch_adj(code, start_date, end_date, adj)
# 原始数据:以下为原有逻辑
import yfinance as yf import yfinance as yf
# 添加延迟避免限流 # 添加延迟避免限流
@@ -107,6 +118,7 @@ class YFinanceSource:
# 将股票信息存储到 DataFrame.attrs 中(最外层结构) # 将股票信息存储到 DataFrame.attrs 中(最外层结构)
df.attrs['info'] = stock_info df.attrs['info'] = stock_info
df.attrs['code'] = code df.attrs['code'] = code
df.attrs['adj'] = 'raw'
return df[['code', 'open', 'high', 'low', 'close', 'volume']] return df[['code', 'open', 'high', 'low', 'close', 'volume']]
@@ -114,6 +126,84 @@ class YFinanceSource:
print(f"YFinance下载 {code} ({yf_code}) 失败: {e}") print(f"YFinance下载 {code} ({yf_code}) 失败: {e}")
return None return None
def fetch_adj(self, code: str, start_date: str, end_date: str, adj: str = 'qfq') -> Optional[pd.DataFrame]:
"""
获取复权价格数据
统一 adj 参数设计:
- 'qfq': 前复权 → yfinance auto_adjust=True (当前价不变)
- 'hfq': 后复权 → yfinance back_adjust=True (历史价不变)
Args:
code: 代码(如 'AAPL', 'TSLA', 'QQQ', '00700.HK'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'qfq'
Returns:
DataFrame with columns: date, code, open, high, low, close, volume (复权后)
"""
import yfinance as yf
if adj not in ['qfq', 'hfq']:
raise ValueError(f"adj 参数必须是 'qfq''hfq',当前: {adj}")
# 添加延迟避免限流
time.sleep(self._delay)
# 转换代码格式
yf_code = self.CODE_MAP.get(code, code)
# adj 参数映射到 yfinance 参数
# qfq(前复权) = auto_adjust=True, back_adjust=False (当前价不变)
# hfq(后复权) = auto_adjust=False, back_adjust=True (历史价不变)
adjust_params = {
'qfq': {'auto_adjust': True, 'back_adjust': False},
'hfq': {'auto_adjust': False, 'back_adjust': True},
}
try:
ticker = yf.Ticker(yf_code)
# end_date 需要加一天yfinance的end是排他的
end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
# 根据 adj 参数设置复权方式
params = adjust_params[adj]
df = ticker.history(
start=start_date,
end=end_dt.strftime("%Y-%m-%d"),
auto_adjust=params['auto_adjust'],
back_adjust=params['back_adjust']
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
"Open": "open",
"High": "high",
"Low": "low",
"Close": "close",
"Volume": "volume",
})
# 确保索引是日期格式
df.index = pd.to_datetime(df.index, utc=True).tz_localize(None).normalize()
df.index.name = "date"
# 添加代码列和标记
df["code"] = code
df.attrs['code'] = code
df.attrs['adj'] = adj
return df[['code', 'open', 'high', 'low', 'close', 'volume']]
except Exception as e:
print(f"YFinance下载复权数据 {code} ({yf_code}) adj={adj} 失败: {e}")
return None
def is_yfinance_code(self, code: str) -> bool: def is_yfinance_code(self, code: str) -> bool:
"""判断是否需要YFinance获取""" """判断是否需要YFinance获取"""
# 非A股代码 # 非A股代码

View File

@@ -104,7 +104,7 @@ benchmark:
name: "沪深300" name: "沪深300"
# ==================== 回测参数 ==================== # ==================== 回测参数 ====================
start_date: "2002-01-01" start_date: "2020-01-01"
# ==================== 因子参数 ==================== # ==================== 因子参数 ====================
# 动量/趋势窗口期(天数) # 动量/趋势窗口期(天数)

View File

@@ -5,6 +5,7 @@
""" """
import pandas as pd import pandas as pd
import numpy as np
import yaml import yaml
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -113,7 +114,7 @@ class RotationStrategy(StrategyBase):
Args: Args:
use_flask_api: 是否使用 Flask API 服务获取数据(默认 True use_flask_api: 是否使用 Flask API 服务获取数据(默认 True
False 则使用本地 HybridDataSource False 则使用本地 UniversalDataFetcher
""" """
code_list_config = self.config.get('code_list', {}) code_list_config = self.config.get('code_list', {})
benchmark_config = self.config.get('benchmark', {}) benchmark_config = self.config.get('benchmark', {})
@@ -237,6 +238,12 @@ class RotationStrategy(StrategyBase):
index_close_dict[code] = df['close'] index_close_dict[code] = df['close']
index_close = pd.DataFrame(index_close_dict) if index_close_dict else None index_close = pd.DataFrame(index_close_dict) if index_close_dict else None
# 获取 A 股 SSE 官方交易日历
from datasource.tushare_source import TushareSource
tushare = TushareSource()
a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date)
print(f"A股交易日历: {len(a_share_dates)}")
return { return {
'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame} 'index_data': index_ohlcv_data, # 原始 OHLCV 数据 {code: DataFrame}
'index_close': index_close, # 对齐后的收盘价(宽格式) 'index_close': index_close, # 对齐后的收盘价(宽格式)
@@ -245,7 +252,8 @@ class RotationStrategy(StrategyBase):
'etf_premium_data': etf_premium_data, # ETF 溢价率数据 {code: dict} 'etf_premium_data': etf_premium_data, # ETF 溢价率数据 {code: dict}
'benchmark_data': benchmark_data, # 基准收盘价 Series 'benchmark_data': benchmark_data, # 基准收盘价 Series
'valid_codes': valid_codes, # 有效指数代码列表 'valid_codes': valid_codes, # 有效指数代码列表
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射 'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射
'a_share_dates': a_share_dates # A股SSE交易日历
} }
def _get_data_from_local( def _get_data_from_local(
@@ -253,33 +261,90 @@ class RotationStrategy(StrategyBase):
code_list_config: dict, code_list_config: dict,
benchmark_code: str benchmark_code: str
) -> dict: ) -> dict:
"""使用本地 HybridDataSource 获取数据""" """使用本地 UniversalDataFetcher 获取数据"""
from datasource import HybridDataSource from datasource import UniversalDataFetcher
from datasource.tushare_source import TushareSource
ssh_config = self.config.get('ssh_tunnel', {}) ssh_config = self.config.get('ssh_tunnel', {})
data_source = HybridDataSource( fetcher = UniversalDataFetcher(
ssh_config=ssh_config, ssh_config=ssh_config,
use_cache=self.config.get('use_cache', True) use_cache=self.config.get('use_cache', True)
) )
# 调用 fetch_all index_codes = list(code_list_config.keys())
index_data, etf_data, etf_nav_data, benchmark_data, valid_codes, index_ohlcv_data, etf_code_map = \ etf_code_map = {idx_code: cfg['etf'] for idx_code, cfg in code_list_config.items() if cfg.get('etf')}
data_source.fetch_all(
code_config=code_list_config, # 获取指数数据
benchmark_code=benchmark_code, index_ohlcv_data = {}
start_date=self.start_date, valid_codes = []
end_date=self.end_date
) with fetcher: # 使用上下文管理器自动管理 SSH 隧道
for code in index_codes:
data = fetcher.fetch(code, self.start_date, self.end_date)
if data is not None and len(data) > 0:
index_ohlcv_data[code] = data
valid_codes.append(code)
print(f"{code}: {len(data)}")
else:
print(f"{code}: 无数据")
# 构建宽格式收盘价
index_close = None
if index_ohlcv_data:
close_list = []
for code, df in index_ohlcv_data.items():
close_df = df[['close']].copy()
close_df.columns = [code]
close_list.append(close_df)
index_close = pd.concat(close_list, axis=1)
# 获取 ETF 数据
etf_data = None
etf_nav_data = None
tushare = TushareSource()
if etf_code_map:
etf_price_list = []
etf_nav_list = []
for idx_code, etf_code in etf_code_map.items():
# ETF 价格
etf_df = tushare.fetch_etf(etf_code, self.start_date, self.end_date)
if etf_df is not None and len(etf_df) > 0:
etf_df = etf_df[['close']].copy()
etf_df.columns = [etf_code]
etf_price_list.append(etf_df)
# ETF 净值
nav_df = tushare.fetch_etf_nav(etf_code, self.start_date, self.end_date)
if nav_df is not None and len(nav_df) > 0:
nav_df = nav_df[['nav']].copy()
nav_df.columns = [etf_code]
etf_nav_list.append(nav_df)
if etf_price_list:
etf_data = pd.concat(etf_price_list, axis=1)
if etf_nav_list:
etf_nav_data = pd.concat(etf_nav_list, axis=1)
# 基准数据
benchmark_data = tushare.fetch_index(benchmark_code, self.start_date, self.end_date)
# A股交易日历
a_share_dates = tushare.fetch_trade_cal(self.start_date, self.end_date)
print(f"A股交易日历: {len(a_share_dates)}")
return { return {
'index_data': index_ohlcv_data, # 原始OHLCV数据 'index_data': index_ohlcv_data, # 原始OHLCV数据
'index_close': index_data, # 对齐后的收盘价(宽格式) 'index_close': index_close, # 对齐后的收盘价(宽格式)
'etf_data': etf_data, 'etf_data': etf_data,
'etf_nav_data': etf_nav_data, 'etf_nav_data': etf_nav_data,
'benchmark_data': benchmark_data, 'benchmark_data': benchmark_data,
'valid_codes': valid_codes, 'valid_codes': valid_codes,
'etf_code_map': etf_code_map # {指数代码: ETF代码} 映射 'etf_code_map': etf_code_map, # {指数代码: ETF代码} 映射
'a_share_dates': a_share_dates # A股SSE交易日历
} }
def compute_factors(self, data: dict) -> pd.DataFrame: def compute_factors(self, data: dict) -> pd.DataFrame:
@@ -290,7 +355,10 @@ class RotationStrategy(StrategyBase):
index_data = data['index_data'] index_data = data['index_data']
valid_codes = data['valid_codes'] valid_codes = data['valid_codes']
# 获取A股交易日历作为基准使用已有的对齐后数据索引 # 获取 A 股 SSE 官方交易日历(优先使用已获取的
a_share_dates = data.get('a_share_dates')
if a_share_dates is None or len(a_share_dates) == 0:
# 回退:使用已有的对齐后数据索引
index_close = data.get('index_close') index_close = data.get('index_close')
if index_close is not None: if index_close is not None:
a_share_dates = index_close.index a_share_dates = index_close.index
@@ -408,9 +476,16 @@ class RotationStrategy(StrategyBase):
# 4. 执行回测 # 4. 执行回测
print("\n执行回测...") print("\n执行回测...")
# 获取A股交易日历从因子数据索引 # 获取 A 股 SSE 官方交易日历(优先使用已获取的
a_share_dates = data.get('a_share_dates')
if a_share_dates is None or len(a_share_dates) == 0:
a_share_dates = signals.index a_share_dates = signals.index
# 将信号对齐到 A 股日历
if a_share_dates is not signals.index:
signals = signals.reindex(a_share_dates, method='ffill').dropna(subset=[signals.columns[0]])
print(f" 信号对齐到A股日历: {len(signals)}")
# 计算日收益率先在原始交易日历计算再对齐到A股日历 # 计算日收益率先在原始交易日历计算再对齐到A股日历
# 关键与因子计算逻辑一致避免交易日不对齐导致收益率NaN # 关键与因子计算逻辑一致避免交易日不对齐导致收益率NaN
returns_data = {} returns_data = {}