"""
Tushare data source.

Fetches daily bars for China A-share indices, ETFs and futures, plus ETF
net asset values, through the Tushare Pro API.
"""
import os
from typing import Optional
from datetime import datetime

import pandas as pd


class TushareSource:
    """Tushare Pro data source for A-share indices, ETFs and futures."""

    # Proxy variables temporarily removed around each request: Tushare is a
    # domestic service and does not need a proxy.
    _PROXY_ENV_KEYS = (
        "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY",
        "http_proxy", "https_proxy", "all_proxy",
    )

    # Suffixes that fetch() routes to the A-share index (then ETF) endpoints.
    _CHINA_SUFFIXES = (".SH", ".SZ", ".SS", ".CSI")

    # Substrings that mark a futures code. Tushare ts_codes use .SHF, .DCE,
    # .ZCE, .CFX, .INE and .GFE; .CZC and .NYM are kept so codes that were
    # already recognised keep being recognised.
    _FUTURES_MARKERS = (
        ".SHF", ".DCE", ".ZCE", ".CZC", ".CFX", ".INE", ".GFE", ".NYM",
    )

    # Output columns of the daily-bar fetchers.
    _OHLCV_COLUMNS = ("code", "open", "high", "low", "close", "volume")

    # Tushare -> output column names for the daily-bar endpoints.
    _BAR_RENAME = {"trade_date": "date", "vol": "volume"}

    # Tushare -> output column names for the fund_nav endpoint.
    _NAV_RENAME = {"nav_date": "date", "unit_nav": "nav"}

    def __init__(self, token: Optional[str] = None):
        """
        Initialise the data source.

        Args:
            token: Tushare token. When omitted (or empty), the
                TUSHARE_TOKEN environment variable is used.

        Raises:
            ValueError: if neither ``token`` nor TUSHARE_TOKEN is set.
        """
        self._token = token or os.getenv("TUSHARE_TOKEN")
        if not self._token:
            raise ValueError("请设置环境变量 TUSHARE_TOKEN")

    def _get_pro_api(self):
        """Return a Tushare Pro API client bound to this instance's token."""
        # Imported lazily so constructing the source does not import tushare.
        import tushare as ts
        return ts.pro_api(self._token)

    def _clear_proxy(self) -> dict:
        """
        Remove the proxy environment variables.

        Returns:
            Mapping of every proxy key to its previous value (None when it
            was unset), suitable for passing to ``_restore_proxy``.
        """
        return {key: os.environ.pop(key, None) for key in self._PROXY_ENV_KEYS}

    def _restore_proxy(self, original: dict):
        """Re-set every proxy variable in ``original`` that had a value."""
        for key, value in original.items():
            if value is not None:
                os.environ[key] = value

    @staticmethod
    def _to_ts_code(code: str) -> str:
        """Replace the '.SS' exchange marker with Tushare's '.SH'."""
        return code.replace(".SS", ".SH")

    def _fetch(
        self,
        api_name: str,
        label: str,
        code: str,
        ts_code: str,
        start_date: str,
        end_date: str,
        rename: dict,
        columns: tuple,
    ) -> Optional[pd.DataFrame]:
        """
        Shared request and normalisation path for the fetch_* methods.

        Calls ``pro.<api_name>(ts_code=..., start_date=..., end_date=...)``
        with proxy variables cleared. It then renames columns with ``rename``,
        parses the ``date`` column and uses it as an ascending index. Every
        row is tagged with the caller-facing ``code``, and only ``columns``
        are kept.

        Args:
            api_name: Tushare Pro endpoint name, e.g. 'index_daily'.
            label: data-kind label used in the failure message.
            code: code as given by the caller (stored in the 'code' column).
            ts_code: code as sent to Tushare.
            start_date: 'YYYY-MM-DD'.
            end_date: 'YYYY-MM-DD'.
            rename: Tushare -> output column mapping; must produce 'date'.
            columns: output columns, in order.

        Returns:
            The normalised DataFrame. Returns None when the endpoint returns
            no rows or anything raises; failures are printed, not raised.
        """
        original_proxy = self._clear_proxy()
        try:
            pro = self._get_pro_api()
            df = getattr(pro, api_name)(
                ts_code=ts_code,
                # Tushare takes dates as YYYYMMDD.
                start_date=start_date.replace("-", ""),
                end_date=end_date.replace("-", ""),
            )
            if df is None or len(df) == 0:
                return None

            df = df.rename(columns=rename)
            df["date"] = pd.to_datetime(df["date"], format="%Y%m%d")
            df = df.set_index("date").sort_index()
            df["code"] = code
            return df[list(columns)]
        except Exception as e:
            print(f"Tushare下载{label} {code} 失败: {e}")
            return None
        finally:
            # Always put the caller's proxy settings back, even on failure.
            self._restore_proxy(original_proxy)

    def fetch_index(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Fetch A-share index daily bars (Tushare ``index_daily``).

        Args:
            code: index code, e.g. '000300.SH', '399006.SZ', 'H30269.CSI';
                '.SS' is accepted and sent as '.SH'.
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.

        Returns:
            DataFrame indexed by date with columns code, open, high, low,
            close, volume; None when there is no data or the request fails.
        """
        return self._fetch(
            "index_daily", "指数", code, self._to_ts_code(code),
            start_date, end_date, self._BAR_RENAME, self._OHLCV_COLUMNS,
        )

    def fetch_futures(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Fetch futures daily bars (Tushare ``fut_daily``).

        Args:
            code: futures code, e.g. 'AU.SHF', 'CU.SHF' (sent unchanged).
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.

        Returns:
            Same layout as ``fetch_index``; None on no data or failure.
        """
        return self._fetch(
            "fut_daily", "期货", code, code,
            start_date, end_date, self._BAR_RENAME, self._OHLCV_COLUMNS,
        )

    def fetch_etf(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Fetch ETF daily price bars (Tushare ``fund_daily``).

        Args:
            code: ETF code, e.g. '159915.SZ', '518880.SH'; '.SS' is sent
                as '.SH'.
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.

        Returns:
            Same layout as ``fetch_index``; None on no data or failure.
        """
        return self._fetch(
            "fund_daily", "ETF", code, self._to_ts_code(code),
            start_date, end_date, self._BAR_RENAME, self._OHLCV_COLUMNS,
        )

    def fetch_etf_nav(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Fetch ETF unit net asset values (Tushare ``fund_nav``).

        Args:
            code: ETF code; '.SS' is sent as '.SH'.
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.

        Returns:
            DataFrame indexed by NAV date with columns code, nav (taken
            from ``unit_nav``); None on no data or failure.
        """
        return self._fetch(
            "fund_nav", "ETF净值", code, self._to_ts_code(code),
            start_date, end_date, self._NAV_RENAME, ("code", "nav"),
        )

    def is_china_index(self, code: str) -> bool:
        """
        Return True if ``code`` ends with an A-share exchange suffix.

        The suffixes are .SH, .SZ, .SS and .CSI. Stocks and ETFs listed on
        those exchanges match too, not only indices.
        """
        return code.endswith(self._CHINA_SUFFIXES)

    def is_futures(self, code: str) -> bool:
        """Return True if ``code`` contains a known futures exchange marker."""
        return any(marker in code for marker in self._FUTURES_MARKERS)

    def fetch(self, code: str, start_date: str, end_date: str) -> Optional[pd.DataFrame]:
        """
        Fetch data, choosing the endpoint from the code's suffix.

        A-share-suffixed codes are tried as indices first. If that yields
        nothing and the code is not a '.CSI' index, they are retried as ETFs
        (e.g. '159915.SZ'). Futures codes go to ``fetch_futures``. Anything
        else returns None.

        Args:
            code: instrument code.
            start_date: start date 'YYYY-MM-DD'.
            end_date: end date 'YYYY-MM-DD'.
        """
        if self.is_china_index(code):
            df = self.fetch_index(code, start_date, end_date)
            if df is None and not code.endswith(".CSI"):
                # index_daily has no rows for ETFs that share .SH/.SZ suffixes.
                df = self.fetch_etf(code, start_date, end_date)
            return df
        if self.is_futures(code):
            return self.fetch_futures(code, start_date, end_date)
        return None