Files
etf/datasource/tushare_source.py
aszerW feb7c78e68 refactor: 统一ETF获取接口为单个DataFrame返回
重构说明:
- TushareSource.fetch_etf(): 新增 adj 参数,统一接口
  - 返回单个 DataFrame
  - df.attrs['nav']: 净值 DataFrame
  - df.attrs['premium']: 溢价率 Series
- 移除冗余方法:
  - fetch_etf_with_nav() → 合并到 fetch_etf()
  - fetch_etf_adj() → 重命名为 _fetch_etf_hfq()(内部方法)
- UniversalDataFetcher: 适配新接口
  - fetch_etf_with_nav(): 从 df.attrs 提取元数据(兼容旧接口)
  - fetch_etf_adj(): 调用 fetch_etf(adj='hfq')
- Flask: 更新注释说明

架构优势:
- 单一接口:一个方法搞定所有 ETF 数据获取
- 数据一致:所有数据在一个 DataFrame 对象中
- 缓存友好:只需缓存一个 DataFrame
- 扩展性强:新增数据直接添加到 attrs
2026-05-23 22:36:23 +08:00

571 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Tushare数据源
获取A股指数、ETF、期货数据
"""
import os
from typing import Optional, Tuple
from datetime import datetime
import pandas as pd
class TushareSource:
"""Tushare数据源"""
def __init__(self, token: Optional[str] = None):
"""
初始化Tushare数据源
Args:
token: Tushare Token可选默认从环境变量读取
"""
self._token = token or os.getenv("TUSHARE_TOKEN")
if not self._token:
raise ValueError("请设置环境变量 TUSHARE_TOKEN")
def _get_pro_api(self):
"""获取Tushare Pro API"""
import tushare as ts
return ts.pro_api(self._token)
def fetch_index(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
获取A股指数数据
Args:
code: 指数代码,如 '000300.SH', '399006.SZ', 'H30269.CSI'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DataFrame with columns: date, open, high, low, close, volume
"""
try:
pro = self._get_pro_api()
# 转换代码格式 (.SS -> .SH)
ts_code = code.replace(".SS", ".SH")
df = pro.index_daily(
ts_code=ts_code,
start_date=start_date.replace("-", ""),
end_date=end_date.replace("-", "")
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
"trade_date": "date",
"vol": "volume",
})
# 转换日期格式
df["date"] = pd.to_datetime(df["date"])
df = df.set_index("date")
df = df.sort_index()
df["code"] = code
return df[['code', 'open', 'high', 'low', 'close', 'volume']]
except Exception as e:
print(f"Tushare下载指数 {code} 失败: {e}")
return None
def fetch_futures(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
获取期货数据
Args:
code: 期货代码,如 'AU.SHF', 'CU.SHF'
start_date: 开始日期
end_date: 结束日期
"""
try:
pro = self._get_pro_api()
# 使用 fut_daily 接口
df = pro.fut_daily(
ts_code=code,
start_date=start_date.replace("-", ""),
end_date=end_date.replace("-", "")
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
"trade_date": "date",
"vol": "volume",
})
df["date"] = pd.to_datetime(df["date"])
df = df.set_index("date")
df = df.sort_index()
df["code"] = code
return df[['code', 'open', 'high', 'low', 'close', 'volume']]
except Exception as e:
print(f"Tushare下载期货 {code} 失败: {e}")
return None
def fetch_etf(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
"""
统一 ETF 获取接口
Args:
code: ETF代码'159915.SZ', '518880.SH'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'raw'(原始) / 'hfq'(后复权),默认 'raw'
Returns:
DataFrame with columns: date, open, high, low, close, volume
adj='hfq' 时额外返回 adj_factor, close_hfq
DataFrame.attrs 附加元数据:
- attrs['nav']: 净值 DataFrame
- attrs['premium']: 溢价率 Series始终基于原始价格计算
"""
# 校验 adj 参数
if adj not in ['raw', 'hfq']:
raise ValueError(f"ETF 仅支持 adj='raw''hfq',当前: {adj}")
# 1. 获取价格数据
if adj == 'hfq':
price_df = self._fetch_etf_hfq(code, start_date, end_date)
else:
price_df = self._fetch_etf_raw(code, start_date, end_date)
if price_df is None:
return None
# 2. 获取净值(附加到 attrs
nav_df = self.fetch_etf_nav(code, start_date, end_date)
price_df.attrs['nav'] = nav_df
# 3. 计算溢价率(始终使用原始价格)
if nav_df is not None and len(nav_df) > 0:
# hfq 时需要获取原始价格来计算溢价率
price_for_premium = price_df if adj == 'raw' else self._fetch_etf_raw(code, start_date, end_date)
if price_for_premium is not None:
premium_series = self._calculate_premium_series(price_for_premium, nav_df)
price_df.attrs['premium'] = premium_series
return price_df
def _fetch_etf_raw(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""获取 ETF 原始价格数据(内部方法)"""
try:
pro = self._get_pro_api()
ts_code = code.replace(".SS", ".SH")
df = pro.fund_daily(
ts_code=ts_code,
start_date=start_date.replace("-", ""),
end_date=end_date.replace("-", "")
)
if df is None or len(df) == 0:
return None
df = df.rename(columns={
"trade_date": "date",
"vol": "volume",
})
df["date"] = pd.to_datetime(df["date"])
df = df.set_index("date")
df = df.sort_index()
df["code"] = code
return df[['code', 'open', 'high', 'low', 'close', 'volume']]
except Exception as e:
print(f"Tushare下载ETF {code} 失败: {e}")
return None
def fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
获取ETF净值数据
Args:
code: ETF代码
"""
try:
pro = self._get_pro_api()
ts_code = code.replace(".SS", ".SH")
df = pro.fund_nav(
ts_code=ts_code,
start_date=start_date.replace("-", ""),
end_date=end_date.replace("-", "")
)
if df is None or len(df) == 0:
return None
df = df.rename(columns={
"nav_date": "date",
"unit_nav": "nav",
})
df["date"] = pd.to_datetime(df["date"])
df = df.set_index("date")
df = df.sort_index()
df["code"] = code
return df[['code', 'nav']]
except Exception as e:
print(f"Tushare下载ETF净值 {code} 失败: {e}")
return None
def is_china_index(self, code: str) -> bool:
"""判断是否为A股指数排除 ETF"""
# 先排除 ETF
if self._is_etf_code(code):
return False
return code.endswith(".SH") or code.endswith(".SZ") or code.endswith(".SS") or code.endswith(".CSI")
def is_futures(self, code: str) -> bool:
"""判断是否为中国期货(仅支持上期所、大商所、郑商所)"""
# 只支持中国交易所期货(.SHF上期所、.DCE大商所、.CZC郑商所
# NYMEX (.NYM) 和 ICE (.ICE) 走 YFinance
return ".SHF" in code or ".DCE" in code or ".CZC" in code
def is_china_stock(self, code: str) -> bool:
"""判断是否为A股股票6位数字代码 + .SZ/.SH/.SS"""
# 股票代码000001.SZ, 600000.SH 等
# 区分指数:指数代码通常是 000xxx.SH, 399xxx.SZ, H30xxx.CSI
# 股票代码通常是 00xxxx.SZ, 30xxxx.SZ, 60xxxx.SH, 000xxx.SH部分
import re
# 股票代码模式6位数字 + .SZ/.SH/.SS
# 排除指数000xxx.SH (指数), 399xxx.SZ (指数), Hxxxxx.CSI (指数)
if not re.match(r'^\d{6}\.(SZ|SH|SS)$', code):
return False
# 000xxx.SH 可能是指数也可能是股票,需要更细致判断
# 简化处理000/001/002/003 开头 + .SZ 是股票600/601/603 开头 + .SH 是股票
prefix = code[:3]
suffix = code.split('.')[1]
if suffix == 'SZ' and prefix in ['000', '001', '002', '003', '300']:
return True
if suffix == 'SH' and prefix in ['600', '601', '603', '605', '688']:
return True
return False
def fetch(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
"""
通用数据获取(自动判断类型,支持 adj 参数)
Args:
code: 代码
start_date: 开始日期
end_date: 结束日期
adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw'
Returns:
DataFrame with columns: date, open, high, low, close, volume
adj='hfq' 时 A股 ETF 会额外返回 adj_factor, close_hfq
"""
# 校验 adj 参数
if adj not in ['raw', 'qfq', 'hfq']:
raise ValueError(f"adj 参数必须是 'raw', 'qfq''hfq',当前: {adj}")
# 原始数据
if adj == 'raw':
# 优先判断 ETF修复ETF 原始数据获取)
if self._is_etf_code(code):
return self.fetch_etf(code, start_date, end_date)
elif self.is_china_index(code):
return self.fetch_index(code, start_date, end_date)
elif self.is_futures(code):
return self.fetch_futures(code, start_date, end_date)
elif self.is_china_stock(code):
return self.fetch_stock_adj(code, start_date, end_date, adj='raw')
else:
return None
# 复权数据
if adj in ['qfq', 'hfq']:
# A股股票复权
if self.is_china_stock(code):
return self.fetch_stock_adj(code, start_date, end_date, adj)
# A股 ETF 仅支持 hfq
elif self._is_etf_code(code):
if adj == 'hfq':
return self.fetch_etf_adj(code, start_date, end_date)
else:
raise ValueError(f"ETF 仅支持 adj='hfq'(后复权),当前: {adj}")
else:
# 指数/期货不支持复权
raise ValueError(f"指数/期货不支持复权adj='{adj}' 仅适用于股票/ETF")
def _is_etf_code(self, code: str) -> bool:
"""判断是否为ETF代码"""
# ETF代码51xxxx.SH, 52xxxx.SH, 15xxxx.SZ, 16xxxx.SZ
import re
if not re.match(r'^\d{6}\.(SZ|SH)$', code):
return False
prefix = code[:2]
return prefix in ['51', '52', '15', '16']
def _calculate_premium_series(
self,
price_df: pd.DataFrame,
nav_df: pd.DataFrame
) -> Optional[pd.Series]:
"""
计算历史溢价率序列
溢价率 = (ETF收盘价 - ETF净值) / ETF净值
关键不同QDII基金净值披露规则不同
- 部分基金净值当天披露如日经ETF价格日期=净值日期
- 部分基金净值T+1披露如纳指ETF价格日期配T-1日净值
集思录做法:根据基金特性选择匹配方式
- 如果有当天净值数据,优先使用当天净值
- 如果当天净值不存在使用T-1日净值
Args:
price_df: ETF价格数据索引为日期
nav_df: ETF净值数据索引为日期
Returns:
溢价率Series索引为价格日期值为溢价率
"""
# 去除重复日期
price_index = price_df.index
if price_index.has_duplicates:
price_df = price_df[~price_df.index.duplicated(keep='last')]
nav_index = nav_df.index
if nav_index.has_duplicates:
nav_df = nav_df[~nav_df.index.duplicated(keep='last')]
# 优先尝试使用当天净值如日经ETF
same_day_dates = price_df.index.intersection(nav_df.index)
# 对于没有当天净值的日期使用T-1日净值如纳指ETF
nav_df_shifted = nav_df.copy()
nav_df_shifted.index = nav_df_shifted.index + pd.Timedelta(days=1)
shifted_dates = price_df.index.intersection(nav_df_shifted.index)
# 排除已有当天净值的日期
t1_dates = shifted_dates.difference(same_day_dates)
premium_data = {}
# 使用当天净值计算
if len(same_day_dates) > 0:
close_same = price_df.loc[same_day_dates, 'close']
nav_same = nav_df.loc[same_day_dates, 'nav']
for date in same_day_dates:
if pd.notna(close_same.loc[date]) and pd.notna(nav_same.loc[date]):
premium_data[date] = (close_same.loc[date] - nav_same.loc[date]) / nav_same.loc[date]
# 使用T-1日净值计算仅用于没有当天净值的日期
if len(t1_dates) > 0:
close_t1 = price_df.loc[t1_dates, 'close']
nav_t1 = nav_df_shifted.loc[t1_dates, 'nav']
for date in t1_dates:
if pd.notna(close_t1.loc[date]) and pd.notna(nav_t1.loc[date]):
premium_data[date] = (close_t1.loc[date] - nav_t1.loc[date]) / nav_t1.loc[date]
if len(premium_data) == 0:
return None
# 构建Series并按日期排序
premium = pd.Series(premium_data)
premium = premium.sort_index()
premium = premium.dropna()
return premium
def _fetch_etf_hfq(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
"""
获取 ETF 后复权价格数据(内部方法)
通过 fund_daily + fund_adj 手动计算后复权价格,消除份额折算(拆分)对收益率的影响。
fund_adj 单次限 2000 条,按 5 年分段请求再拼接。
Args:
code: ETF代码'159915.SZ', '518880.SH'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DataFrame with columns: date, open, close, adj_factor, close_hfq
"""
try:
pro = self._get_pro_api()
ts_code = code.replace('.SS', '.SH')
# 获取 fund_daily 数据
df_daily = pro.fund_daily(
ts_code=ts_code,
start_date=start_date.replace('-', ''),
end_date=end_date.replace('-', '')
)
if df_daily is None or len(df_daily) == 0:
return None
# 获取 fund_adj 数据分段请求单次限2000条
# 按5年分段
start_dt = datetime.strptime(start_date, '%Y-%m-%d')
end_dt = datetime.strptime(end_date, '%Y-%m-%d')
adj_chunks = []
chunk_start = start_dt
while chunk_start < end_dt:
chunk_end = min(chunk_start.replace(year=chunk_start.year + 5), end_dt)
chunk_start_str = chunk_start.strftime('%Y%m%d')
chunk_end_str = chunk_end.strftime('%Y%m%d')
df_adj_chunk = pro.fund_adj(
ts_code=ts_code,
start_date=chunk_start_str,
end_date=chunk_end_str
)
if df_adj_chunk is not None and len(df_adj_chunk) > 0:
adj_chunks.append(df_adj_chunk)
chunk_start = chunk_end
if not adj_chunks:
# 无复权因子,返回原始数据
df = df_daily.rename(columns={'trade_date': 'date', 'vol': 'volume'})
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date').sort_index()
df['adj_factor'] = 1.0
df['close_hfq'] = df['close']
df['code'] = code
return df[['code', 'open', 'close', 'adj_factor', 'close_hfq']]
# 合并所有复权因子
df_adj = pd.concat(adj_chunks, ignore_index=True)
df_adj = df_adj.rename(columns={'trade_date': 'date'})
df_adj['date'] = pd.to_datetime(df_adj['date'])
df_adj = df_adj.set_index('date').sort_index()
# 合并 daily 和 adj
df_daily = df_daily.rename(columns={'trade_date': 'date', 'vol': 'volume'})
df_daily['date'] = pd.to_datetime(df_daily['date'])
df_daily = df_daily.set_index('date').sort_index()
# 复权因子对齐(用最新值)
df_adj_aligned = df_adj.reindex(df_daily.index, method='ffill')
df_adj_aligned['adj_factor'] = df_adj_aligned['adj_factor'].fillna(1.0)
# 计算后复权价格
df = df_daily.copy()
df['adj_factor'] = df_adj_aligned['adj_factor']
df['close_hfq'] = df['close'] * df['adj_factor']
df['code'] = code
return df[['code', 'open', 'close', 'adj_factor', 'close_hfq']]
except Exception as e:
print(f"Tushare下载ETF复权数据 {code} 失败: {e}")
return None
def fetch_trade_cal(self, start_date: str, end_date: str) -> pd.DatetimeIndex:
"""
获取 A 股(上交所 SSE官方交易日历
Args:
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
Returns:
DatetimeIndex: A股交易日日期序列
"""
try:
pro = self._get_pro_api()
df = pro.trade_cal(
exchange='SSE',
start_date=start_date.replace('-', ''),
end_date=end_date.replace('-', ''),
is_open='1'
)
if df is None or len(df) == 0:
return pd.DatetimeIndex([])
# 提取交易日并转换为 DatetimeIndex
trade_dates = pd.to_datetime(df['cal_date'])
return pd.DatetimeIndex(trade_dates.sort_values())
except Exception as e:
print(f"Tushare下载交易日历失败: {e}")
return pd.DatetimeIndex([])
def fetch_stock_adj(self, code: str, start_date: str, end_date: str, adj: str = 'hfq') -> Optional[pd.DataFrame]:
"""
获取 A股股票复权价格数据
使用 pro_bar 接口获取前复权(qfq)或后复权(hfq)价格。
Args:
code: 股票代码,如 '000001.SZ', '600000.SH'
start_date: 开始日期 'YYYY-MM-DD'
end_date: 结束日期 'YYYY-MM-DD'
adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'hfq'
Returns:
DataFrame with columns: date, code, open, high, low, close, volume, adj_factor
"""
import tushare as ts
if adj not in ['qfq', 'hfq']:
raise ValueError(f"adj 参数必须是 'qfq''hfq',当前: {adj}")
try:
ts_code = code.replace('.SS', '.SH')
# 使用 pro_bar 接口获取复权数据
df = ts.pro_bar(
ts_code=ts_code,
adj=adj,
start_date=start_date.replace('-', ''),
end_date=end_date.replace('-', ''),
adjfactor=True # 返回复权因子
)
if df is None or len(df) == 0:
return None
# 标准化列名
df = df.rename(columns={
'ts_code': 'code',
'trade_date': 'date',
'vol': 'volume',
})
# 转换日期格式
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
df = df.sort_index()
# 恢复原始代码格式(.SS -> .SH 反转)
df['code'] = code
# 标准化返回字段
columns = ['code', 'open', 'high', 'low', 'close', 'volume']
if 'adj_factor' in df.columns:
columns.append('adj_factor')
return df[columns]
except Exception as e:
print(f"Tushare下载股票复权数据 {code} 失败: {e}")
return None