Files
etf/datasource/tushare_source.py
aszerW c0195c5bca refactor(tushare): 合并ETF复权方法,消除冗余设计
- 合并 fetch_etf_adj 和 _fetch_etf_adj 为单一方法
- 删除 _fetch_etf_qfq 转发方法
- 减少~26行代码,优化代码结构(从3个方法→1个方法)
- 保持公共接口签名不变,完全向后兼容
- 全面测试通过:raw/qfq/hfq三种模式数据正确
- 更新 VALID_ADJ_BY_TYPE 配置,ETF支持前复权/后复权
2026-05-25 19:59:49 +08:00

662 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Tushare数据源
获取A股指数、ETF、期货数据
"""
import os
import re
import traceback
from datetime import datetime
from typing import Optional, Tuple

import pandas as pd
class TushareSource:
    """Tushare Pro data source for A-share indices, stocks, ETFs and Chinese futures.

    Every ``fetch*`` method returns a DataFrame indexed by an ascending
    ``DatetimeIndex`` named ``date``, or ``None`` when Tushare returns no rows
    or the request fails (request failures are printed, not raised; invalid
    ``adj`` arguments raise ``ValueError``). Dates may be given as
    ``'YYYY-MM-DD'`` or ``'YYYYMMDD'``.
    """

    # Columns returned by the unadjusted bar fetchers.
    _PRICE_COLUMNS = ('code', 'open', 'high', 'low', 'close', 'volume')
    # Columns returned by the adjusted-price fetchers.
    _ADJ_COLUMNS = _PRICE_COLUMNS + ('adj_factor',)
    # Price fields that get multiplied by the adjustment ratio.
    _ADJUSTED_PRICE_FIELDS = ('open', 'high', 'low', 'close')
    # Tushare bar column names -> standard names.
    _BAR_RENAME = {'trade_date': 'date', 'vol': 'volume'}
    # Tushare fund_nav column names -> standard names.
    _NAV_RENAME = {'nav_date': 'date', 'unit_nav': 'nav'}
    # fund_adj returns at most 2000 rows per call; 5 years of trading days
    # (roughly 1250 rows) stays below that.
    _FUND_ADJ_CHUNK_YEARS = 5
    # When a price date has no same-day NAV, the most recent earlier NAV is
    # used if it is at most this much older than the day before the price
    # date (leaves room for long exchange holidays).
    _NAV_FALLBACK_TOLERANCE = pd.Timedelta(days=14)

    _STOCK_CODE_RE = re.compile(r'^\d{6}\.(SZ|SH|SS)$')
    _ETF_CODE_RE = re.compile(r'^\d{6}\.(SZ|SH)$')
    # Shenzhen main board / SME (000-003) and ChiNext (300, 301).
    _SZ_STOCK_PREFIXES = ('000', '001', '002', '003', '300', '301')
    # Shanghai main board (600-605) and STAR Market (688, 689).
    _SH_STOCK_PREFIXES = ('600', '601', '603', '605', '688', '689')
    # Exchange-traded fund prefixes: SSE 51/52/56/58, SZSE 15/16.
    _ETF_PREFIXES = ('51', '52', '56', '58', '15', '16')

    def __init__(self, token: Optional[str] = None):
        """Create the data source.

        Args:
            token: Tushare token. Falls back to the ``TUSHARE_TOKEN``
                environment variable when omitted.

        Raises:
            ValueError: if no token is available.
        """
        self._token = token or os.getenv("TUSHARE_TOKEN")
        if not self._token:
            raise ValueError("请设置环境变量 TUSHARE_TOKEN")

    def _get_pro_api(self):
        """Return a Tushare Pro API client (tushare is imported on first use)."""
        import tushare as ts
        return ts.pro_api(self._token)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _standardize(df: pd.DataFrame, code: str, rename: dict, columns) -> pd.DataFrame:
        """Rename Tushare columns, index by ascending date and tag rows with ``code``."""
        out = df.rename(columns=rename)
        out['date'] = pd.to_datetime(out['date'])
        out = out.set_index('date').sort_index()
        out['code'] = code
        # copy() so later column assignments do not trigger SettingWithCopyWarning.
        return out[list(columns)].copy()

    def _fetch_bars(self, api_name: str, ts_code: str, code: str,
                    start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Call a Tushare daily-bar endpoint by name and standardize the result.

        Exceptions propagate; each public caller catches them and prints its
        own message.
        """
        pro = self._get_pro_api()
        df = getattr(pro, api_name)(
            ts_code=ts_code,
            start_date=start_date.replace('-', ''),
            end_date=end_date.replace('-', ''),
        )
        if df is None or len(df) == 0:
            return None
        return self._standardize(df, code, self._BAR_RENAME, self._PRICE_COLUMNS)

    @staticmethod
    def _factor_series(adj_df: pd.DataFrame) -> Optional[pd.Series]:
        """Turn a Tushare adjustment-factor frame into a sorted, de-duplicated Series.

        Returns None when no non-NaN factor remains.
        """
        factors = adj_df.rename(columns={'trade_date': 'date'})
        factors['date'] = pd.to_datetime(factors['date'])
        factors = factors.set_index('date').sort_index()['adj_factor'].dropna()
        factors = factors[~factors.index.duplicated(keep='last')]
        return factors if len(factors) > 0 else None

    @staticmethod
    def _align_adj_factors(factors: pd.Series, index: pd.DatetimeIndex) -> pd.Series:
        """Map factors onto bar dates.

        Each bar uses the latest factor on or before its date. Bars before the
        first known factor use the earliest known factor (instead of 1.0, which
        would mix raw and adjusted price scales). 1.0 is only the last resort.
        """
        aligned = factors.reindex(index, method='ffill')
        return aligned.bfill().fillna(1.0)

    def _attach_adjusted_prices(self, df: pd.DataFrame, factors: pd.Series, adj: str) -> pd.DataFrame:
        """Add ``adj_factor`` to ``df`` and scale OHLC in place.

        hfq: price * factor. qfq: price * factor / factors.iloc[-1] (the last
        factor in ``factors``). If that base factor is not positive, prices
        are left unadjusted.
        """
        df['adj_factor'] = self._align_adj_factors(factors, df.index).to_numpy()
        base = 1.0 if adj == 'hfq' else factors.iloc[-1]
        if base > 0:
            ratio = df['adj_factor'] / base
            for col in self._ADJUSTED_PRICE_FIELDS:
                df[col] = (df[col] * ratio).round(4)
        return df

    def _fetch_fund_adj_factors(self, pro, ts_code: str, start_date: str,
                                end_date: str) -> Optional[pd.Series]:
        """Fetch ETF adjustment factors for [start_date, end_date] in 5-year chunks.

        Handles single-day ranges and Feb 29 start dates. Consecutive chunks
        never request the same day twice. Returns None when no factor is found.
        """
        start = pd.Timestamp(start_date)
        end = pd.Timestamp(end_date)
        chunks = []
        chunk_start = start
        while chunk_start <= end:
            # DateOffset clamps Feb 29 to Feb 28 in non-leap target years.
            chunk_end = min(chunk_start + pd.DateOffset(years=self._FUND_ADJ_CHUNK_YEARS), end)
            part = pro.fund_adj(
                ts_code=ts_code,
                start_date=chunk_start.strftime('%Y%m%d'),
                end_date=chunk_end.strftime('%Y%m%d'),
            )
            if part is not None and len(part) > 0:
                chunks.append(part)
            # Start the next chunk after chunk_end so boundary days are not fetched twice.
            chunk_start = chunk_end + pd.Timedelta(days=1)
        if not chunks:
            return None
        return self._factor_series(pd.concat(chunks, ignore_index=True))

    # ------------------------------------------------------------------
    # Unadjusted bars
    # ------------------------------------------------------------------

    def fetch_index(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch A-share index bars via ``index_daily``.

        Args:
            code: index code, e.g. '000300.SH', '399006.SZ', 'H30269.CSI'
                ('.SS' is accepted as an alias of '.SH').
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.

        Returns:
            DataFrame indexed by date with columns code, open, high, low,
            close, volume; None if empty or on failure.
        """
        try:
            return self._fetch_bars('index_daily', code.replace(".SS", ".SH"), code, start_date, end_date)
        except Exception as e:
            print(f"Tushare下载指数 {code} 失败: {e}")
            return None

    def fetch_futures(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch Chinese futures bars via ``fut_daily``.

        Args:
            code: futures code, e.g. 'AU.SHF', 'CU.SHF' (passed through unchanged).
            start_date: start date.
            end_date: end date.

        Returns:
            DataFrame indexed by date with columns code, open, high, low,
            close, volume; None if empty or on failure.
        """
        try:
            return self._fetch_bars('fut_daily', code, code, start_date, end_date)
        except Exception as e:
            print(f"Tushare下载期货 {code} 失败: {e}")
            return None

    def _fetch_etf_raw(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch unadjusted ETF bars via ``fund_daily`` (internal)."""
        try:
            return self._fetch_bars('fund_daily', code.replace(".SS", ".SH"), code, start_date, end_date)
        except Exception as e:
            print(f"Tushare下载ETF {code} 失败: {e}")
            return None

    def _fetch_stock_raw(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch unadjusted A-share stock bars via ``daily`` (internal)."""
        try:
            return self._fetch_bars('daily', code.replace('.SS', '.SH'), code, start_date, end_date)
        except Exception as e:
            print(f"Tushare下载股票 {code} 失败: {e}")
            return None

    # ------------------------------------------------------------------
    # ETF
    # ------------------------------------------------------------------

    def fetch_etf(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
        """Unified ETF fetch: prices plus NAV and premium metadata.

        Args:
            code: ETF code, e.g. '159915.SZ', '518880.SH'.
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.
            adj: 'raw' (unadjusted), 'qfq' (forward-adjusted) or 'hfq'
                (backward-adjusted). Default 'raw'.

        Returns:
            DataFrame indexed by date with columns code, open, high, low,
            close, volume (plus adj_factor for qfq/hfq), or None.
            ``attrs`` always carries:
              - ``attrs['nav']``: NAV DataFrame from fetch_etf_nav, or None.
              - ``attrs['premium']``: premium Series computed from the
                *unadjusted* close, or None when it cannot be computed.

        Raises:
            ValueError: if ``adj`` is not one of the supported values.
        """
        if adj not in ['raw', 'qfq', 'hfq']:
            raise ValueError(f"ETF 仅支持 adj='raw', 'qfq''hfq',当前: {adj}")

        if adj == 'raw':
            price_df = self._fetch_etf_raw(code, start_date, end_date)
        else:
            price_df = self.fetch_etf_adj(code, start_date, end_date, adj)
        if price_df is None:
            return None

        nav_df = self.fetch_etf_nav(code, start_date, end_date)
        price_df.attrs['nav'] = nav_df
        # Always present so callers can read attrs['premium'] without a KeyError.
        price_df.attrs['premium'] = None

        if nav_df is not None and len(nav_df) > 0:
            # Premium is defined on traded (unadjusted) prices, so adjusted
            # requests need a second, raw download.
            raw_df = price_df if adj == 'raw' else self._fetch_etf_raw(code, start_date, end_date)
            if raw_df is not None:
                price_df.attrs['premium'] = self._calculate_premium_series(raw_df, nav_df)
        return price_df

    def fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """Fetch ETF unit NAV via ``fund_nav``.

        Args:
            code: ETF code ('.SS' is accepted as an alias of '.SH').
            start_date: start date.
            end_date: end date.

        Returns:
            DataFrame indexed by NAV date with columns code, nav (duplicate
            dates are kept as returned), or None if empty or on failure.
        """
        try:
            pro = self._get_pro_api()
            df = pro.fund_nav(
                ts_code=code.replace(".SS", ".SH"),
                start_date=start_date.replace("-", ""),
                end_date=end_date.replace("-", ""),
            )
            if df is None or len(df) == 0:
                return None
            return self._standardize(df, code, self._NAV_RENAME, ('code', 'nav'))
        except Exception as e:
            print(f"Tushare下载ETF净值 {code} 失败: {e}")
            return None

    def fetch_etf_adj(self, code: str, start_date: str, end_date: str, adj: str = 'hfq') -> Optional[pd.DataFrame]:
        """Fetch adjusted ETF prices (own implementation, no ``pro_bar``).

        Steps:
            1. ``fund_daily`` for raw prices.
            2. ``fund_adj`` for adjustment factors, requested in 5-year chunks
               (the endpoint returns at most 2000 rows per call).
            3. Scale open/high/low/close:
               - hfq: price * adj_factor
               - qfq: price * adj_factor / last factor in [start_date, end_date]

        NOTE: qfq here is anchored at the last factor *inside the requested
        range*, so prices on the range's last trading day equal the raw prices.
        fetch_stock_adj instead anchors at the latest factor of the full
        history.

        Args:
            code: ETF code, e.g. '159915.SZ', '518880.SH'.
            start_date: start date 'YYYY-MM-DD' (or 'YYYYMMDD').
            end_date: end date 'YYYY-MM-DD' (or 'YYYYMMDD').
            adj: 'hfq' (default) or 'qfq'.

        Returns:
            DataFrame indexed by date with columns code, open, high, low,
            close, volume, adj_factor. If no factors are found, raw prices
            with adj_factor = 1.0 are returned. None if no bars or on failure.

        Raises:
            ValueError: if ``adj`` is not 'qfq' or 'hfq'.
        """
        if adj not in ['qfq', 'hfq']:
            raise ValueError(f"ETF adj 参数必须是 'qfq''hfq',当前: {adj}")
        try:
            pro = self._get_pro_api()
            ts_code = code.replace('.SS', '.SH')
            df_daily = pro.fund_daily(
                ts_code=ts_code,
                start_date=start_date.replace('-', ''),
                end_date=end_date.replace('-', ''),
            )
            if df_daily is None or len(df_daily) == 0:
                return None

            factors = self._fetch_fund_adj_factors(pro, ts_code, start_date, end_date)
            df = self._standardize(df_daily, code, self._BAR_RENAME, self._PRICE_COLUMNS)
            if factors is None:
                print(f"警告: {code} 无复权因子数据,返回原始价格")
                df['adj_factor'] = 1.0
                return df
            return self._attach_adjusted_prices(df, factors, adj)
        except Exception as e:
            print(f"Tushare下载ETF复权数据 {code} 失败: {e}")
            traceback.print_exc()
            return None

    def _calculate_premium_series(
        self,
        price_df: pd.DataFrame,
        nav_df: pd.DataFrame
    ) -> Optional[pd.Series]:
        """Compute the historical premium series.

        premium = (ETF close - ETF NAV) / ETF NAV

        Funds publish NAV on different schedules (some QDII funds publish NAV
        a day later than others). For each price date:
          - use the same-day NAV when it exists;
          - otherwise use the most recent earlier NAV, i.e. the previous
            trading day's NAV, including across weekends and holidays. It must
            be no more than ``_NAV_FALLBACK_TOLERANCE`` older than the day
            before the price date.
        NaN or non-positive NAVs are ignored. Duplicate dates keep the last row.

        Args:
            price_df: ETF prices with a ``close`` column, indexed by date.
            nav_df: ETF NAV with a ``nav`` column, indexed by date.

        Returns:
            Premium Series indexed by price date (ascending), or None when no
            premium can be computed.
        """
        close = price_df['close']
        if close.index.has_duplicates:
            close = close[~close.index.duplicated(keep='last')]
        close = close.sort_index()

        nav = nav_df['nav']
        if nav.index.has_duplicates:
            nav = nav[~nav.index.duplicated(keep='last')]
        nav = nav.sort_index()
        nav = nav[nav > 0]  # drops NaN and invalid NAVs
        if close.empty or nav.empty:
            return None

        same_day = nav.reindex(close.index)
        # Latest NAV strictly before each price date: look up "date - 1 day"
        # with forward fill, so a Monday finds Friday's NAV.
        previous = nav.reindex(
            close.index - pd.Timedelta(days=1),
            method='ffill',
            tolerance=self._NAV_FALLBACK_TOLERANCE,
        )
        previous.index = close.index
        nav_used = same_day.fillna(previous)

        premium = ((close - nav_used) / nav_used).dropna()
        if premium.empty:
            return None
        return premium.rename_axis(None)

    # ------------------------------------------------------------------
    # Code classification
    # ------------------------------------------------------------------

    def is_china_index(self, code: str) -> bool:
        """Return True for A-share index-style codes, excluding ETFs.

        NOTE: any non-ETF code ending in .SH/.SZ/.SS/.CSI returns True,
        including individual stocks; ``fetch`` checks ``is_china_stock`` first.
        """
        if self._is_etf_code(code):
            return False
        return code.endswith(('.SH', '.SZ', '.SS', '.CSI'))

    def is_futures(self, code: str) -> bool:
        """Return True for Chinese exchange futures (.SHF SHFE, .DCE DCE, .CZC CZCE).

        NYMEX (.NYM) and ICE (.ICE) contracts are not handled here (they go to YFinance).
        """
        return any(tag in code for tag in ('.SHF', '.DCE', '.CZC'))

    def is_china_stock(self, code: str) -> bool:
        """Return True for A-share stock codes (6 digits + .SZ/.SH/.SS).

        Stocks are told apart from indices by prefix: Shenzhen stocks start
        with 000-003/300/301 (399xxx.SZ are indices); Shanghai stocks start
        with 600/601/603/605/688/689 (000xxx.SH are indices). '.SS' is treated
        as an alias of '.SH'.
        """
        match = self._STOCK_CODE_RE.match(code)
        if match is None:
            return False
        prefix = code[:3]
        if match.group(1) == 'SZ':
            return prefix in self._SZ_STOCK_PREFIXES
        return prefix in self._SH_STOCK_PREFIXES

    def _is_etf_code(self, code: str) -> bool:
        """Return True for exchange-traded fund codes (51/52/56/58xxxx.SH, 15/16xxxx.SZ)."""
        return bool(self._ETF_CODE_RE.match(code)) and code[:2] in self._ETF_PREFIXES

    # ------------------------------------------------------------------
    # Dispatch
    # ------------------------------------------------------------------

    def fetch(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
        """Generic fetch: detect the instrument type and dispatch.

        Classification order: ETF, stock, index, futures. Stocks must be
        checked before indices because ``is_china_index`` also matches stock
        codes.

        Args:
            code: instrument code.
            start_date: start date.
            end_date: end date.
            adj: 'raw' (default), 'qfq' or 'hfq'. Adjustment is only
                available for stocks and ETFs.

        Returns:
            DataFrame indexed by date with columns code, open, high, low,
            close, volume. qfq/hfq results for stocks and ETFs also include
            adj_factor. None for unrecognised codes (raw) or failures.

        Raises:
            ValueError: for an invalid ``adj``, or qfq/hfq on a code that is
                neither a stock nor an ETF.
        """
        if adj not in ['raw', 'qfq', 'hfq']:
            raise ValueError(f"adj 参数必须是 'raw', 'qfq''hfq',当前: {adj}")

        if adj == 'raw':
            if self._is_etf_code(code):
                return self.fetch_etf(code, start_date, end_date)
            if self.is_china_stock(code):
                return self._fetch_stock_raw(code, start_date, end_date)
            if self.is_china_index(code):
                return self.fetch_index(code, start_date, end_date)
            if self.is_futures(code):
                return self.fetch_futures(code, start_date, end_date)
            return None

        # qfq / hfq
        if self.is_china_stock(code):
            return self.fetch_stock_adj(code, start_date, end_date, adj)
        if self._is_etf_code(code):
            return self.fetch_etf_adj(code, start_date, end_date, adj)
        raise ValueError(f"指数/期货不支持复权adj='{adj}' 仅适用于股票/ETF")

    # ------------------------------------------------------------------
    # Calendar / stocks
    # ------------------------------------------------------------------

    def fetch_trade_cal(self, start_date: str, end_date: str) -> pd.DatetimeIndex:
        """Fetch the SSE trading calendar (open days only).

        Args:
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.

        Returns:
            Sorted DatetimeIndex of trading days; empty on no data or failure.
        """
        try:
            pro = self._get_pro_api()
            df = pro.trade_cal(
                exchange='SSE',
                start_date=start_date.replace('-', ''),
                end_date=end_date.replace('-', ''),
                is_open='1'
            )
            if df is None or len(df) == 0:
                return pd.DatetimeIndex([])
            return pd.DatetimeIndex(pd.to_datetime(df['cal_date']).sort_values())
        except Exception as e:
            print(f"Tushare下载交易日历失败: {e}")
            return pd.DatetimeIndex([])

    def fetch_stock_adj(self, code: str, start_date: str, end_date: str, adj: str = 'hfq') -> Optional[pd.DataFrame]:
        """Fetch adjusted A-share stock prices (own implementation, no ``pro_bar``).

        Steps:
            1. ``pro.daily()`` for raw prices.
            2. ``pro.adj_factor()`` for the *full* factor history (no date
               range), so qfq uses the latest factor overall.
            3. Scale open/high/low/close:
               - hfq: price * adj_factor
               - qfq: price * adj_factor / latest factor of the full history

        Args:
            code: stock code, e.g. '000001.SZ', '600000.SH'.
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.
            adj: 'qfq' or 'hfq' (default 'hfq').

        Returns:
            DataFrame indexed by date with columns code, open, high, low,
            close, volume, adj_factor. If no factors are found, raw prices
            with adj_factor = 1.0 are returned. None if no bars or on failure.

        Raises:
            ValueError: if ``adj`` is not 'qfq' or 'hfq'.
        """
        if adj not in ['qfq', 'hfq']:
            raise ValueError(f"adj 参数必须是 'qfq''hfq',当前: {adj}")
        try:
            pro = self._get_pro_api()
            ts_code = code.replace('.SS', '.SH')
            daily_df = pro.daily(
                ts_code=ts_code,
                start_date=start_date.replace('-', ''),
                end_date=end_date.replace('-', '')
            )
            if daily_df is None or len(daily_df) == 0:
                return None

            # Full history on purpose: qfq must divide by the latest factor
            # overall, not the latest one inside the requested window.
            adj_df = pro.adj_factor(ts_code=ts_code)
            df = self._standardize(daily_df, code, self._BAR_RENAME, self._PRICE_COLUMNS)
            factors = None if adj_df is None or len(adj_df) == 0 else self._factor_series(adj_df)
            if factors is None:
                print(f"警告: {code} 无复权因子数据,返回原始价格")
                df['adj_factor'] = 1.0
                return df
            return self._attach_adjusted_prices(df, factors, adj)
        except Exception as e:
            print(f"Tushare下载股票复权数据 {code} 失败: {e}")
            traceback.print_exc()
            return None