def is_china_stock(self, code: str) -> bool:
    """Return True when *code* looks like an A-share stock ticker.

    A ticker is six digits followed by ``.SZ``, ``.SH`` or ``.SS`` (``.SS``
    is treated as an alias of ``.SH``).  Index codes that share the same
    shape (000xxx.SH, 399xxx.SZ) are rejected by only accepting a fixed set
    of three-digit prefixes per exchange.

    Args:
        code: Ticker such as '000001.SZ' or '600000.SH'.

    Returns:
        True for a recognised stock prefix/exchange pair, otherwise False.
    """
    import re

    matched = re.fullmatch(r'(\d{6})\.(SZ|SH|SS)', code)
    if matched is None:
        return False

    digits, suffix = matched.groups()
    if suffix == 'SS':
        # '.SS' is the alternative spelling of the Shanghai suffix; the
        # original check compared against 'SH' only and so never matched it.
        suffix = 'SH'

    prefix = digits[:3]
    # NOTE(review): the prefix lists are a deliberate simplification; newer
    # boards (e.g. 301xxx.SZ, 689xxx.SH) may be missing - confirm with data.
    if suffix == 'SZ':
        return prefix in ('000', '001', '002', '003', '300')
    return prefix in ('600', '601', '603', '605', '688')


def fetch(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
    """Fetch price data, dispatching on the kind of instrument *code* is.

    Args:
        code: Instrument code.
        start_date: Start date.
        end_date: End date.
        adj: Price adjustment: 'raw' (unadjusted), 'qfq' (forward-adjusted)
            or 'hfq' (backward-adjusted).  Defaults to 'raw'.

    Returns:
        Whatever the selected fetcher returns (a DataFrame or None).  With
        adj='raw', codes that are not an index, future or stock yield None.
        NOTE(review): the original docstring states ETF 'hfq' results also
        carry adj_factor and close_hfq; that comes from fetch_etf_adj, which
        is not visible here.

    Raises:
        ValueError: If *adj* is not one of the three accepted values, if an
            ETF is requested with 'qfq', or if an adjusted series is
            requested for a code that is neither a stock nor an ETF.
    """
    if adj not in ('raw', 'qfq', 'hfq'):
        raise ValueError(f"adj 参数必须是 'raw', 'qfq' 或 'hfq',当前: {adj}")

    if adj == 'raw':
        if self.is_china_index(code):
            return self.fetch_index(code, start_date, end_date)
        if self.is_futures(code):
            return self.fetch_futures(code, start_date, end_date)
        if self.is_china_stock(code):
            # fetch_stock_adj accepts 'raw' and returns unadjusted bars.
            return self.fetch_stock_adj(code, start_date, end_date, adj='raw')
        return None

    # From here on adj is 'qfq' or 'hfq'.
    if self.is_china_stock(code):
        return self.fetch_stock_adj(code, start_date, end_date, adj)
    if self._is_etf_code(code):
        if adj != 'hfq':
            raise ValueError(f"ETF 仅支持 adj='hfq'(后复权),当前: {adj}")
        return self.fetch_etf_adj(code, start_date, end_date)
    raise ValueError(f"指数/期货不支持复权,adj='{adj}' 仅适用于股票/ETF")


def _is_etf_code(self, code: str) -> bool:
    """Return True when *code* looks like an exchange-traded fund ticker.

    Accepted shapes are 51xxxx.SH, 52xxxx.SH, 15xxxx.SZ and 16xxxx.SZ.  The
    prefix is checked against the exchange it belongs to, so a Shanghai
    prefix with a Shenzhen suffix (or vice versa) is rejected.
    """
    import re

    matched = re.fullmatch(r'(\d{6})\.(SZ|SH)', code)
    if matched is None:
        return False

    digits, suffix = matched.groups()
    prefix = digits[:2]
    if suffix == 'SH':
        return prefix in ('51', '52')
    return prefix in ('15', '16')


def fetch_stock_adj(self, code: str, start_date: str, end_date: str, adj: str = 'hfq') -> Optional[pd.DataFrame]:
    """Fetch daily bars for an A-share stock through tushare ``pro_bar``.

    Args:
        code: Stock code such as '000001.SZ' or '600000.SH'; a '.SS' suffix
            is sent to tushare as '.SH'.
        start_date: Start date 'YYYY-MM-DD'.
        end_date: End date 'YYYY-MM-DD'.
        adj: 'qfq' (forward-adjusted), 'hfq' (backward-adjusted) or 'raw'
            (unadjusted).  Defaults to 'hfq'.

    Returns:
        DataFrame indexed by ``date`` (ascending) with columns code, open,
        high, low, close, volume, plus adj_factor when tushare returns it;
        None when nothing is returned or the download fails.

    Raises:
        ValueError: If *adj* is not 'raw', 'qfq' or 'hfq'.
    """
    # Validate before importing tushare so bad arguments fail fast.
    if adj not in ('raw', 'qfq', 'hfq'):
        raise ValueError(f"adj 参数必须是 'raw', 'qfq' 或 'hfq',当前: {adj}")

    import tushare as ts

    # pro_bar expresses "no adjustment" as adj=None.
    pro_adj = None if adj == 'raw' else adj

    try:
        ts_code = code.replace('.SS', '.SH')

        df = ts.pro_bar(
            ts_code=ts_code,
            adj=pro_adj,
            start_date=start_date.replace('-', ''),
            end_date=end_date.replace('-', ''),
            adjfactor=pro_adj is not None,  # only meaningful when adjusting
        )

        if df is None or len(df) == 0:
            return None

        # Normalise tushare column names to the names used by this module.
        df = df.rename(columns={
            'ts_code': 'code',
            'trade_date': 'date',
            'vol': 'volume',
        })

        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date').sort_index()

        # Report the code exactly as the caller spelled it (e.g. keep '.SS').
        df['code'] = code

        columns = ['code', 'open', 'high', 'low', 'close', 'volume']
        if 'adj_factor' in df.columns:
            columns.append('adj_factor')

        return df[columns]

    except Exception as e:
        # Best-effort download: report and signal "no data" to the caller.
        print(f"Tushare下载股票复权数据 {code} 失败: {e}")
        return None
fetch_adj(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]: + def fetch_adj(self, code: str, start_date: str, end_date: str, adj: str = 'qfq') -> Optional[pd.DataFrame]: """ 获取复权价格数据 - 使用 auto_adjust=True 获取复权后的价格 - - 消除拆分(split)和分红(dividend)对价格的影响 - - 适用于美股股票/ETF + 统一 adj 参数设计: + - 'qfq': 前复权 → yfinance auto_adjust=True (当前价不变) + - 'hfq': 后复权 → yfinance back_adjust=True (历史价不变) Args: - code: 代码(如 'AAPL', 'TSLA', 'QQQ') + code: 代码(如 'AAPL', 'TSLA', 'QQQ', '00700.HK') start_date: 开始日期 'YYYY-MM-DD' end_date: 结束日期 'YYYY-MM-DD' + adj: 复权类型 'qfq'(前复权) 或 'hfq'(后复权),默认 'qfq' Returns: - DataFrame with columns: date, open, high, low, close, volume (复权后) + DataFrame with columns: date, code, open, high, low, close, volume (复权后) """ import yfinance as yf + if adj not in ['qfq', 'hfq']: + raise ValueError(f"adj 参数必须是 'qfq' 或 'hfq',当前: {adj}") + # 添加延迟避免限流 time.sleep(self._delay) # 转换代码格式 yf_code = self.CODE_MAP.get(code, code) + # adj 参数映射到 yfinance 参数 + # qfq(前复权) = auto_adjust=True, back_adjust=False (当前价不变) + # hfq(后复权) = auto_adjust=False, back_adjust=True (历史价不变) + adjust_params = { + 'qfq': {'auto_adjust': True, 'back_adjust': False}, + 'hfq': {'auto_adjust': False, 'back_adjust': True}, + } + try: ticker = yf.Ticker(yf_code) # end_date 需要加一天(yfinance的end是排他的) end_dt = datetime.strptime(end_date, "%Y-%m-%d") + timedelta(days=1) - # auto_adjust=True 获取复权价格 + # 根据 adj 参数设置复权方式 + params = adjust_params[adj] df = ticker.history( start=start_date, end=end_dt.strftime("%Y-%m-%d"), - auto_adjust=True + auto_adjust=params['auto_adjust'], + back_adjust=params['back_adjust'] ) if df is None or len(df) == 0: @@ -170,12 +196,12 @@ class YFinanceSource: # 添加代码列和标记 df["code"] = code df.attrs['code'] = code - df.attrs['adjusted'] = True + df.attrs['adj'] = adj return df[['code', 'open', 'high', 'low', 'close', 'volume']] except Exception as e: - print(f"YFinance下载复权数据 {code} ({yf_code}) 失败: {e}") + print(f"YFinance下载复权数据 {code} ({yf_code}) adj={adj} 失败: {e}") 
return None def is_yfinance_code(self, code: str) -> bool: