""" ETF 全量历史数据本地缓存 ======================== 一次性下载全市场 ETF(含已退市)的基础信息和日线数据到本地, 供回测中按 ref_date 截取历史数据,消除前视偏差。 用法: # 首次下载(约 30-60 分钟,取决于 API 限流) python scripts/etf_data_cache.py # 增量更新(只下载缺失的新数据) python scripts/etf_data_cache.py --update """ import os import sys import time import logging from pathlib import Path from datetime import datetime import pandas as pd sys.path.insert(0, str(Path(__file__).parent.parent)) from dotenv import load_dotenv load_dotenv() import tushare as ts logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s') logger = logging.getLogger(__name__) # 缓存目录 CACHE_DIR = Path(__file__).parent.parent / 'data' / 'etf_cache' DAILY_DIR = CACHE_DIR / 'daily' BASIC_PATH = CACHE_DIR / 'fund_basic.csv' class ETFDataCache: """ETF 全量历史数据缓存管理器""" def __init__(self): self.pro = ts.pro_api(os.getenv('TUSHARE_TOKEN')) CACHE_DIR.mkdir(parents=True, exist_ok=True) DAILY_DIR.mkdir(parents=True, exist_ok=True) self._basic_df = None # 懒加载 # ---------------------------------------------------------- # API 调用(带重试 + 限流) # ---------------------------------------------------------- def _api_call(self, func, **kwargs): for attempt in range(3): try: result = func(**kwargs) time.sleep(0.35) return result except Exception as e: if attempt < 2: wait = 2 * (attempt + 1) logger.warning(f" API 重试 ({attempt+1}/3): {e}, 等待 {wait}s") time.sleep(wait) else: raise # ---------------------------------------------------------- # 1. 下载并缓存 fund_basic # ---------------------------------------------------------- def download_basic(self, force: bool = False): """下载全量 ETF 基础信息(含已退市)""" if BASIC_PATH.exists() and not force: logger.info(f"fund_basic 缓存已存在: {BASIC_PATH}") return logger.info("下载全量 ETF 基础信息...") fields = 'ts_code,name,management,list_date,delist_date,fund_type,invest_type,benchmark,type,trustee,status' dfs = [] for status in ['L', 'D']: # L=上市, D=已退市 df = self._api_call(self.pro.fund_basic, market='E', status=status, fields=fields) if df is not None and not df.empty: dfs.append(df) logger.info(f" status={status}: {len(df)} 只") if not dfs: raise RuntimeError("获取 ETF 列表失败") basic = pd.concat(dfs, ignore_index=True).drop_duplicates(subset='ts_code') basic.to_csv(BASIC_PATH, index=False, encoding='utf-8-sig') logger.info(f"fund_basic 已保存: {len(basic)} 只 -> {BASIC_PATH}") # ---------------------------------------------------------- # 2. 批量下载日线数据 # ---------------------------------------------------------- def download_daily(self, force: bool = False): """批量下载所有 ETF 的全历史日线数据""" basic = self.load_basic() codes = basic['ts_code'].tolist() total = len(codes) logger.info(f"准备下载 {total} 只 ETF 的日线数据...") downloaded = 0 skipped = 0 failed = 0 for i, code in enumerate(codes): csv_path = DAILY_DIR / f"{code}.csv" if csv_path.exists() and not force: # 增量更新: 读取已有数据的最后日期 try: existing = pd.read_csv(csv_path, nrows=1) # 只读首行检查 if not existing.empty: skipped += 1 continue except Exception: pass if (i - skipped) % 20 == 0: logger.info(f" 进度: {i}/{total} (下载={downloaded}, 跳过={skipped}, 失败={failed})") try: df = self._api_call( self.pro.fund_daily, ts_code=code, fields='ts_code,trade_date,open,high,low,close,vol,amount' ) if df is not None and not df.empty: df = df.sort_values('trade_date') df.to_csv(csv_path, index=False) downloaded += 1 else: failed += 1 except Exception as e: logger.warning(f" {code} 下载失败: {e}") failed += 1 logger.info(f"日线数据下载完成: 下载={downloaded}, 跳过={skipped}, 失败={failed}") def update_daily(self): """增量更新: 只为已有缓存文件追加新数据""" basic = self.load_basic() codes = basic['ts_code'].tolist() today_str = datetime.now().strftime('%Y%m%d') updated = 0 for code in codes: csv_path = DAILY_DIR / f"{code}.csv" if not csv_path.exists(): continue try: existing = pd.read_csv(csv_path) if existing.empty: continue last_date = str(existing['trade_date'].max()) if last_date >= today_str: continue # 下载 last_date 之后的数据 new_df = self._api_call( self.pro.fund_daily, ts_code=code, start_date=str(int(last_date) + 1), end_date=today_str, fields='ts_code,trade_date,open,high,low,close,vol,amount' ) if new_df is not None and not new_df.empty: combined = pd.concat([existing, new_df], ignore_index=True) combined = combined.drop_duplicates(subset='trade_date').sort_values('trade_date') combined.to_csv(csv_path, index=False) updated += 1 except Exception: pass logger.info(f"增量更新完成: {updated} 只有新数据") # ---------------------------------------------------------- # 3. 数据读取接口(回测用) # ---------------------------------------------------------- def load_basic(self) -> pd.DataFrame: """加载 fund_basic 缓存""" if self._basic_df is not None: return self._basic_df if not BASIC_PATH.exists(): raise FileNotFoundError(f"fund_basic 缓存不存在,请先运行: python scripts/etf_data_cache.py") self._basic_df = pd.read_csv(BASIC_PATH) return self._basic_df def load_cached_daily(self, ts_code: str, end_date: str = None) -> pd.DataFrame: """ 加载某只 ETF 的日线数据,截至 end_date(含)。 Args: ts_code: ETF 代码 end_date: 截止日期 YYYYMMDD,None 表示全部 Returns: DataFrame with columns [trade_date, open, high, low, close, vol, amount] 按 trade_date 升序排列 """ csv_path = DAILY_DIR / f"{ts_code}.csv" if not csv_path.exists(): return pd.DataFrame() df = pd.read_csv(csv_path) if df.empty: return df df['trade_date'] = df['trade_date'].astype(str) df = df.sort_values('trade_date') if end_date: end_str = str(end_date).replace('-', '') df = df[df['trade_date'] <= end_str] return df def load_cached_daily_as_series(self, ts_code: str, end_date: str = None, column: str = 'close') -> pd.Series: """加载某只 ETF 的单列数据,index 为 datetime""" df = self.load_cached_daily(ts_code, end_date) if df.empty: return pd.Series(dtype=float) df['date'] = pd.to_datetime(df['trade_date']) return df.set_index('date')[column].astype(float) def load_cached_ohlcv(self, ts_code: str, end_date: str = None) -> pd.DataFrame: """加载 OHLCV 数据,index 为 datetime(与 动量.py 的 all_data 格式兼容)""" df = self.load_cached_daily(ts_code, end_date) if df.empty: return pd.DataFrame() df['date'] = pd.to_datetime(df['trade_date']) df = df.set_index('date').sort_index() df = df.rename(columns={'vol': 'volume'}) return df[['open', 'high', 'low', 'close', 'volume']].astype(float) def ensure_downloaded(self): """确保基础信息和日线数据都已下载""" self.download_basic() self.download_daily() def get_available_codes_at(self, ref_date: str) -> list: """获取在 ref_date 时已上市且未退市的 ETF 代码列表""" basic = self.load_basic() basic['list_date'] = basic['list_date'].astype(str) mask = basic['list_date'] <= ref_date # 排除在 ref_date 之前已退市的 if 'delist_date' in basic.columns: delist = basic['delist_date'].astype(str).fillna('99991231') mask = mask & (delist > ref_date) return basic[mask]['ts_code'].tolist() # ---------------------------------------------------------- # CLI # ---------------------------------------------------------- if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='ETF 全量历史数据缓存下载') parser.add_argument('--update', action='store_true', help='增量更新已有缓存') parser.add_argument('--force', action='store_true', help='强制重新下载全部') args = parser.parse_args() cache = ETFDataCache() if args.update: cache.download_basic(force=True) cache.update_daily() else: cache.download_basic(force=args.force) cache.download_daily(force=args.force) # 统计 basic = cache.load_basic() n_daily = len(list(DAILY_DIR.glob('*.csv'))) logger.info(f"\n缓存统计: fund_basic={len(basic)} 只, 日线文件={n_daily} 个")