refactor(datasource): 底层fetch方法添加adj参数
TushareSource.fetch() 和 YFinanceSource.fetch() 新增 adj 参数支持 raw/qfq/hfq - TushareSource.fetch(adj='raw'): 内部路由到 fetch_index/fetch_stock_adj/fetch_etf_adj - YFinanceSource.fetch(adj='raw'): 内部路由到 fetch_adj() 或原始逻辑 - 添加 is_china_stock() 和 _is_etf_code() 方法用于资产类型判断
This commit is contained in:
@@ -196,22 +196,79 @@ class TushareSource:
|
||||
# 只支持中国交易所期货(.SHF上期所、.DCE大商所、.CZC郑商所)
|
||||
# NYMEX (.NYM) 和 ICE (.ICE) 走 YFinance
|
||||
return ".SHF" in code or ".DCE" in code or ".CZC" in code
|
||||
|
||||
def is_china_stock(self, code: str) -> bool:
|
||||
"""判断是否为A股股票(6位数字代码 + .SZ/.SH/.SS)"""
|
||||
# 股票代码:000001.SZ, 600000.SH 等
|
||||
# 区分指数:指数代码通常是 000xxx.SH, 399xxx.SZ, H30xxx.CSI
|
||||
# 股票代码通常是 00xxxx.SZ, 30xxxx.SZ, 60xxxx.SH, 000xxx.SH(部分)
|
||||
import re
|
||||
# 股票代码模式:6位数字 + .SZ/.SH/.SS
|
||||
# 排除指数:000xxx.SH (指数), 399xxx.SZ (指数), Hxxxxx.CSI (指数)
|
||||
if not re.match(r'^\d{6}\.(SZ|SH|SS)$', code):
|
||||
return False
|
||||
# 000xxx.SH 可能是指数也可能是股票,需要更细致判断
|
||||
# 简化处理:000/001/002/003 开头 + .SZ 是股票,600/601/603 开头 + .SH 是股票
|
||||
prefix = code[:3]
|
||||
suffix = code.split('.')[1]
|
||||
if suffix == 'SZ' and prefix in ['000', '001', '002', '003', '300']:
|
||||
return True
|
||||
if suffix == 'SH' and prefix in ['600', '601', '603', '605', '688']:
|
||||
return True
|
||||
return False
|
||||
|
||||
def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
def fetch(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
通用数据获取(自动判断类型)
|
||||
通用数据获取(自动判断类型,支持 adj 参数)
|
||||
|
||||
Args:
|
||||
code: 代码
|
||||
start_date: 开始日期
|
||||
end_date: 结束日期
|
||||
adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw'
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: date, open, high, low, close, volume
|
||||
adj='hfq' 时 A股 ETF 会额外返回 adj_factor, close_hfq
|
||||
"""
|
||||
if self.is_china_index(code):
|
||||
return self.fetch_index(code, start_date, end_date)
|
||||
elif self.is_futures(code):
|
||||
return self.fetch_futures(code, start_date, end_date)
|
||||
else:
|
||||
return None
|
||||
# 校验 adj 参数
|
||||
if adj not in ['raw', 'qfq', 'hfq']:
|
||||
raise ValueError(f"adj 参数必须是 'raw', 'qfq' 或 'hfq',当前: {adj}")
|
||||
|
||||
# 原始数据
|
||||
if adj == 'raw':
|
||||
if self.is_china_index(code):
|
||||
return self.fetch_index(code, start_date, end_date)
|
||||
elif self.is_futures(code):
|
||||
return self.fetch_futures(code, start_date, end_date)
|
||||
elif self.is_china_stock(code):
|
||||
return self.fetch_stock_adj(code, start_date, end_date, adj='raw')
|
||||
else:
|
||||
return None
|
||||
|
||||
# 复权数据
|
||||
if adj in ['qfq', 'hfq']:
|
||||
# A股股票复权
|
||||
if self.is_china_stock(code):
|
||||
return self.fetch_stock_adj(code, start_date, end_date, adj)
|
||||
# A股 ETF 仅支持 hfq
|
||||
elif self._is_etf_code(code):
|
||||
if adj == 'hfq':
|
||||
return self.fetch_etf_adj(code, start_date, end_date)
|
||||
else:
|
||||
raise ValueError(f"ETF 仅支持 adj='hfq'(后复权),当前: {adj}")
|
||||
else:
|
||||
# 指数/期货不支持复权
|
||||
raise ValueError(f"指数/期货不支持复权,adj='{adj}' 仅适用于股票/ETF")
|
||||
|
||||
def _is_etf_code(self, code: str) -> bool:
|
||||
"""判断是否为ETF代码"""
|
||||
# ETF代码:51xxxx.SH, 52xxxx.SH, 15xxxx.SZ, 16xxxx.SZ
|
||||
import re
|
||||
if not re.match(r'^\d{6}\.(SZ|SH)$', code):
|
||||
return False
|
||||
prefix = code[:2]
|
||||
return prefix in ['51', '52', '15', '16']
|
||||
|
||||
def fetch_etf_adj(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
@@ -332,4 +389,65 @@ class TushareSource:
|
||||
|
||||
except Exception as e:
|
||||
print(f"Tushare下载交易日历失败: {e}")
|
||||
return pd.DatetimeIndex([])
|
||||
return pd.DatetimeIndex([])
|
||||
|
||||
def fetch_stock_adj(self, code: str, start_date: str, end_date: str, adj: str = 'hfq') -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
获取 A股股票复权价格数据
|
||||
|
||||
使用 pro_bar 接口获取前复权(qfq)或后复权(hfq)价格。
|
||||
|
||||
Args:
|
||||
code: 股票代码,如 '000001.SZ', '600000.SH'
|
||||
start_date: 开始日期 'YYYY-MM-DD'
|
||||
end_date: 结束日期 'YYYY-MM-DD'
|
||||
adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'hfq'
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: date, code, open, high, low, close, volume, adj_factor
|
||||
"""
|
||||
import tushare as ts
|
||||
|
||||
if adj not in ['qfq', 'hfq']:
|
||||
raise ValueError(f"adj 参数必须是 'qfq' 或 'hfq',当前: {adj}")
|
||||
|
||||
try:
|
||||
ts_code = code.replace('.SS', '.SH')
|
||||
|
||||
# 使用 pro_bar 接口获取复权数据
|
||||
df = ts.pro_bar(
|
||||
ts_code=ts_code,
|
||||
adj=adj,
|
||||
start_date=start_date.replace('-', ''),
|
||||
end_date=end_date.replace('-', ''),
|
||||
adjfactor=True # 返回复权因子
|
||||
)
|
||||
|
||||
if df is None or len(df) == 0:
|
||||
return None
|
||||
|
||||
# 标准化列名
|
||||
df = df.rename(columns={
|
||||
'ts_code': 'code',
|
||||
'trade_date': 'date',
|
||||
'vol': 'volume',
|
||||
})
|
||||
|
||||
# 转换日期格式
|
||||
df['date'] = pd.to_datetime(df['date'])
|
||||
df = df.set_index('date')
|
||||
df = df.sort_index()
|
||||
|
||||
# 恢复原始代码格式(.SS -> .SH 反转)
|
||||
df['code'] = code
|
||||
|
||||
# 标准化返回字段
|
||||
columns = ['code', 'open', 'high', 'low', 'close', 'volume']
|
||||
if 'adj_factor' in df.columns:
|
||||
columns.append('adj_factor')
|
||||
|
||||
return df[columns]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Tushare下载股票复权数据 {code} 失败: {e}")
|
||||
return None
|
||||
@@ -44,19 +44,30 @@ class YFinanceSource:
|
||||
self.use_ssh_tunnel = use_ssh_tunnel
|
||||
self._delay = 0.5 # 请求延迟(避免限流)
|
||||
|
||||
def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
def fetch(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
获取数据
|
||||
获取数据(支持 adj 参数)
|
||||
|
||||
Args:
|
||||
code: 代码(如 'NDX', 'N225', 'HSI')
|
||||
code: 代码(如 'NDX', 'N225', 'HSI', 'AAPL')
|
||||
start_date: 开始日期 'YYYY-MM-DD'
|
||||
end_date: 结束日期 'YYYY-MM-DD'
|
||||
adj: 复权类型 'raw'(原始) / 'qfq'(前复权) / 'hfq'(后复权),默认 'raw'
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: date, open, high, low, close, volume
|
||||
股票元信息存储在 df.attrs['info'] 中
|
||||
adj='qfq/hfq' 时 df.attrs['adj'] 会标记复权类型
|
||||
"""
|
||||
# 校验 adj 参数
|
||||
if adj not in ['raw', 'qfq', 'hfq']:
|
||||
raise ValueError(f"adj 参数必须是 'raw', 'qfq' 或 'hfq',当前: {adj}")
|
||||
|
||||
# 复权数据:调用 fetch_adj
|
||||
if adj in ['qfq', 'hfq']:
|
||||
return self.fetch_adj(code, start_date, end_date, adj)
|
||||
|
||||
# 原始数据:以下为原有逻辑
|
||||
import yfinance as yf
|
||||
|
||||
# 添加延迟避免限流
|
||||
@@ -107,6 +118,7 @@ class YFinanceSource:
|
||||
# 将股票信息存储到 DataFrame.attrs 中(最外层结构)
|
||||
df.attrs['info'] = stock_info
|
||||
df.attrs['code'] = code
|
||||
df.attrs['adj'] = 'raw'
|
||||
|
||||
return df[['code', 'open', 'high', 'low', 'close', 'volume']]
|
||||
|
||||
@@ -114,41 +126,55 @@ class YFinanceSource:
|
||||
print(f"YFinance下载 {code} ({yf_code}) 失败: {e}")
|
||||
return None
|
||||
|
||||
def fetch_adj(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
|
||||
def fetch_adj(self, code: str, start_date: str, end_date: str, adj: str = 'qfq') -> Optional[pd.DataFrame]:
|
||||
"""
|
||||
获取复权价格数据
|
||||
|
||||
使用 auto_adjust=True 获取复权后的价格
|
||||
- 消除拆分(split)和分红(dividend)对价格的影响
|
||||
- 适用于美股股票/ETF
|
||||
统一 adj 参数设计:
|
||||
- 'qfq': 前复权 → yfinance auto_adjust=True (当前价不变)
|
||||
- 'hfq': 后复权 → yfinance back_adjust=True (历史价不变)
|
||||
|
||||
Args:
|
||||
code: 代码(如 'AAPL', 'TSLA', 'QQQ')
|
||||
code: 代码(如 'AAPL', 'TSLA', 'QQQ', '00700.HK')
|
||||
start_date: 开始日期 'YYYY-MM-DD'
|
||||
end_date: 结束日期 'YYYY-MM-DD'
|
||||
adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'qfq'
|
||||
|
||||
Returns:
|
||||
DataFrame with columns: date, open, high, low, close, volume (复权后)
|
||||
DataFrame with columns: date, code, open, high, low, close, volume (复权后)
|
||||
"""
|
||||
import yfinance as yf
|
||||
|
||||
if adj not in ['qfq', 'hfq']:
|
||||
raise ValueError(f"adj 参数必须是 'qfq' 或 'hfq',当前: {adj}")
|
||||
|
||||
# 添加延迟避免限流
|
||||
time.sleep(self._delay)
|
||||
|
||||
# 转换代码格式
|
||||
yf_code = self.CODE_MAP.get(code, code)
|
||||
|
||||
# adj 参数映射到 yfinance 参数
|
||||
# qfq(前复权) = auto_adjust=True, back_adjust=False (当前价不变)
|
||||
# hfq(后复权) = auto_adjust=False, back_adjust=True (历史价不变)
|
||||
adjust_params = {
|
||||
'qfq': {'auto_adjust': True, 'back_adjust': False},
|
||||
'hfq': {'auto_adjust': False, 'back_adjust': True},
|
||||
}
|
||||
|
||||
try:
|
||||
ticker = yf.Ticker(yf_code)
|
||||
|
||||
# end_date 需要加一天(yfinance的end是排他的)
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1)
|
||||
|
||||
# auto_adjust=True 获取复权价格
|
||||
# 根据 adj 参数设置复权方式
|
||||
params = adjust_params[adj]
|
||||
df = ticker.history(
|
||||
start=start_date,
|
||||
end=end_dt.strftime("%Y-%m-%d"),
|
||||
auto_adjust=True
|
||||
auto_adjust=params['auto_adjust'],
|
||||
back_adjust=params['back_adjust']
|
||||
)
|
||||
|
||||
if df is None or len(df) == 0:
|
||||
@@ -170,12 +196,12 @@ class YFinanceSource:
|
||||
# 添加代码列和标记
|
||||
df["code"] = code
|
||||
df.attrs['code'] = code
|
||||
df.attrs['adjusted'] = True
|
||||
df.attrs['adj'] = adj
|
||||
|
||||
return df[['code', 'open', 'high', 'low', 'close', 'volume']]
|
||||
|
||||
except Exception as e:
|
||||
print(f"YFinance下载复权数据 {code} ({yf_code}) 失败: {e}")
|
||||
print(f"YFinance下载复权数据 {code} ({yf_code}) adj={adj} 失败: {e}")
|
||||
return None
|
||||
|
||||
def is_yfinance_code(self, code: str) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user