Files
etf/archive/legacy_tests/tests/utils/etf_data_cache.py
aszerW 1fca536c95 refactor: 归档旧代码,保留新框架结构
归档内容:
- core/ (数据源、因子计算、通用工具) → archive/legacy_core/
- strategies/rotation/engine.py, portfolio.py, report.py → archive/legacy_core/
- scripts/ (run_rotation, daily_scheduler) → archive/legacy_scripts/
- examples/ → archive/legacy_examples/
- tests/ (实验、对比测试) → archive/legacy_tests/
- 单独文件 (fetch_*.py, 动量.py, 全球市场.py等) → archive/single_files/

保留新结构:
- framework/ (抽象接口)
- strategies/shared/ (定制组件)
- strategies/rotation/strategy.py (新策略)
- 外层配置: .env, .dockerignore, build-and-push.sh, hk_ecs.pem, README.md, requirements.txt
- Docker相关: Dockerfile, Dockerfile_base, docker-compose.yml

更新README反映新框架架构
2026-05-11 23:34:23 +08:00

281 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Local cache of full-history ETF data
====================================
Downloads, in one pass, the basic information and daily bars of every ETF on
the market (including delisted ones) to local disk, so that a backtest can
slice the history by ``ref_date`` and avoid look-ahead bias.

Usage:
    # First download (roughly 30-60 minutes, depending on API rate limiting)
    python scripts/etf_data_cache.py
    # Incremental update (only downloads the missing new data)
    python scripts/etf_data_cache.py --update
"""
import os
import sys
import time
import logging
from pathlib import Path
from datetime import datetime
import pandas as pd
sys.path.insert(0, str(Path(__file__).parent.parent))
from dotenv import load_dotenv
load_dotenv()
import tushare as ts
# Configure logging at import time: INFO level, "<time> [<level>] <message>" format.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
# Cache directory layout, resolved relative to the parent of this file's directory.
CACHE_DIR = Path(__file__).parent.parent / 'data' / 'etf_cache'
# One "<ts_code>.csv" file of daily bars per ETF is stored here.
DAILY_DIR = CACHE_DIR / 'daily'
# Single CSV holding the fund_basic listing (listed and delisted ETFs).
BASIC_PATH = CACHE_DIR / 'fund_basic.csv'
class ETFDataCache:
    """Manage a local CSV cache of full-history ETF data.

    The cache consists of one ``fund_basic`` listing (``BASIC_PATH``) and one
    daily-bar CSV per ETF under ``DAILY_DIR``. Download methods talk to the
    Tushare client stored in ``self.pro``; the ``load_*`` methods only read
    the local files.
    """

    def __init__(self):
        """Create the Tushare client and make sure the cache directories exist."""
        self.pro = ts.pro_api(os.getenv('TUSHARE_TOKEN'))
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        DAILY_DIR.mkdir(parents=True, exist_ok=True)
        self._basic_df = None  # lazily populated by load_basic()

    # ----------------------------------------------------------
    # API calls (with retry + throttling)
    # ----------------------------------------------------------
    def _api_call(self, func, **kwargs):
        """Call ``func(**kwargs)`` with up to three attempts.

        A successful call is followed by a 0.35 s pause to throttle the
        request rate. A failed attempt waits 2 s, then 4 s, before retrying;
        the exception of the third failed attempt is re-raised.
        """
        for attempt in range(3):
            try:
                result = func(**kwargs)
                time.sleep(0.35)
                return result
            except Exception as e:
                if attempt < 2:
                    wait = 2 * (attempt + 1)
                    logger.warning(f"  API 重试 ({attempt+1}/3): {e}, 等待 {wait}s")
                    time.sleep(wait)
                else:
                    raise

    # ----------------------------------------------------------
    # 1. Download and cache fund_basic
    # ----------------------------------------------------------
    def download_basic(self, force: bool = False):
        """Download the full ETF listing (including delisted funds).

        Args:
            force: Re-download even when the cached file already exists.

        Raises:
            RuntimeError: If neither status query returned any rows.
        """
        if BASIC_PATH.exists() and not force:
            logger.info(f"fund_basic 缓存已存在: {BASIC_PATH}")
            return
        logger.info("下载全量 ETF 基础信息...")
        fields = 'ts_code,name,management,list_date,delist_date,fund_type,invest_type,benchmark,type,trustee,status'
        dfs = []
        for status in ['L', 'D']:  # L = listed, D = delisted
            df = self._api_call(self.pro.fund_basic, market='E', status=status, fields=fields)
            if df is not None and not df.empty:
                dfs.append(df)
                logger.info(f"  status={status}: {len(df)}")
        if not dfs:
            raise RuntimeError("获取 ETF 列表失败")
        basic = pd.concat(dfs, ignore_index=True).drop_duplicates(subset='ts_code')
        basic.to_csv(BASIC_PATH, index=False, encoding='utf-8-sig')
        logger.info(f"fund_basic 已保存: {len(basic)} 只 -> {BASIC_PATH}")

    # ----------------------------------------------------------
    # 2. Bulk download of daily bars
    # ----------------------------------------------------------
    def download_daily(self, force: bool = False):
        """Download the daily bars of every ETF in the cached listing.

        A code whose CSV already exists and contains at least one data row is
        skipped unless ``force`` is set. Per-code failures are logged and
        counted; they do not abort the loop.

        Args:
            force: Re-download every code, overwriting existing CSV files.
        """
        basic = self.load_basic()
        codes = basic['ts_code'].tolist()
        total = len(codes)
        logger.info(f"准备下载 {total} 只 ETF 的日线数据...")
        downloaded = 0
        skipped = 0
        failed = 0
        for i, code in enumerate(codes):
            csv_path = DAILY_DIR / f"{code}.csv"
            if csv_path.exists() and not force:
                try:
                    # Reading a single row is enough to tell whether the file has data.
                    existing = pd.read_csv(csv_path, nrows=1)
                    if not existing.empty:
                        skipped += 1
                        continue
                except Exception as e:
                    # An unreadable cache file is treated as missing and fetched again.
                    logger.warning(f"  {code} 缓存文件无法读取, 将重新下载: {e}")
            if (i - skipped) % 20 == 0:
                logger.info(f"  进度: {i}/{total} (下载={downloaded}, 跳过={skipped}, 失败={failed})")
            try:
                # NOTE(review): the API may cap the number of rows returned per
                # call, which would truncate long histories - confirm against
                # the Tushare fund_daily documentation.
                df = self._api_call(
                    self.pro.fund_daily,
                    ts_code=code,
                    fields='ts_code,trade_date,open,high,low,close,vol,amount'
                )
                if df is not None and not df.empty:
                    df = df.sort_values('trade_date')
                    df.to_csv(csv_path, index=False)
                    downloaded += 1
                else:
                    failed += 1
            except Exception as e:
                logger.warning(f"  {code} 下载失败: {e}")
                failed += 1
        logger.info(f"日线数据下载完成: 下载={downloaded}, 跳过={skipped}, 失败={failed}")

    def update_daily(self):
        """Append newly published bars to every existing daily CSV.

        Codes without a cache file, or with an empty one, are left alone.
        Per-code failures are logged and skipped so one bad file or request
        does not stop the remaining updates.
        """
        basic = self.load_basic()
        codes = basic['ts_code'].tolist()
        today_str = datetime.now().strftime('%Y%m%d')
        updated = 0
        for code in codes:
            csv_path = DAILY_DIR / f"{code}.csv"
            if not csv_path.exists():
                continue
            try:
                # Keep trade_date as text. Left to type inference it comes back
                # as int64, and mixing that with the string dates of the fresh
                # download makes sort_values() raise TypeError.
                existing = pd.read_csv(csv_path, dtype={'trade_date': str})
                if existing.empty:
                    continue
                last_date = existing['trade_date'].max()
                if last_date >= today_str:
                    continue
                # Real calendar arithmetic: "int(last_date) + 1" would turn
                # 20231231 into the invalid date 20231232.
                next_day = pd.to_datetime(last_date, format='%Y%m%d') + pd.Timedelta(days=1)
                new_df = self._api_call(
                    self.pro.fund_daily,
                    ts_code=code,
                    start_date=next_day.strftime('%Y%m%d'),
                    end_date=today_str,
                    fields='ts_code,trade_date,open,high,low,close,vol,amount'
                )
                if new_df is not None and not new_df.empty:
                    new_df = new_df.assign(trade_date=new_df['trade_date'].astype(str))
                    combined = pd.concat([existing, new_df], ignore_index=True)
                    combined = combined.drop_duplicates(subset='trade_date').sort_values('trade_date')
                    combined.to_csv(csv_path, index=False)
                    updated += 1
            except Exception as e:
                # Best effort: report the failure instead of hiding it, then move on.
                logger.warning(f"  {code} 增量更新失败: {e}")
        logger.info(f"增量更新完成: {updated} 只有新数据")

    # ----------------------------------------------------------
    # 3. Read interface (used by backtests)
    # ----------------------------------------------------------
    def load_basic(self) -> pd.DataFrame:
        """Return the cached fund_basic listing, reading it from disk once.

        Raises:
            FileNotFoundError: If the listing has not been downloaded yet.
        """
        if self._basic_df is not None:
            return self._basic_df
        if not BASIC_PATH.exists():
            raise FileNotFoundError("fund_basic 缓存不存在,请先运行: python scripts/etf_data_cache.py")
        self._basic_df = pd.read_csv(BASIC_PATH)
        return self._basic_df

    def load_cached_daily(self, ts_code: str, end_date: str = None) -> pd.DataFrame:
        """Load the cached daily bars of one ETF, up to and including ``end_date``.

        Args:
            ts_code: ETF code; selects the file ``<ts_code>.csv``.
            end_date: Cut-off date as ``YYYYMMDD`` (dashes are stripped);
                ``None`` returns everything.

        Returns:
            The cached rows sorted by ``trade_date`` ascending, with
            ``trade_date`` as strings. An empty DataFrame when no cache file
            exists for the code.
        """
        csv_path = DAILY_DIR / f"{ts_code}.csv"
        if not csv_path.exists():
            return pd.DataFrame()
        df = pd.read_csv(csv_path)
        if df.empty:
            return df
        df['trade_date'] = df['trade_date'].astype(str)
        df = df.sort_values('trade_date')
        if end_date:
            end_str = str(end_date).replace('-', '')
            df = df[df['trade_date'] <= end_str]
        return df

    def load_cached_daily_as_series(self, ts_code: str, end_date: str = None,
                                    column: str = 'close') -> pd.Series:
        """Load a single column of one ETF as a float Series indexed by datetime.

        Returns an empty float Series when there is no cached data.
        """
        df = self.load_cached_daily(ts_code, end_date)
        if df.empty:
            return pd.Series(dtype=float)
        # assign() builds a new frame, so the possibly filtered result of
        # load_cached_daily() is never written to in place.
        df = df.assign(date=pd.to_datetime(df['trade_date']))
        return df.set_index('date')[column].astype(float)

    def load_cached_ohlcv(self, ts_code: str, end_date: str = None) -> pd.DataFrame:
        """Load OHLCV data of one ETF as floats indexed by datetime.

        ``vol`` is renamed to ``volume``. Returns an empty DataFrame when
        there is no cached data.
        """
        df = self.load_cached_daily(ts_code, end_date)
        if df.empty:
            return pd.DataFrame()
        df = df.assign(date=pd.to_datetime(df['trade_date']))
        df = df.set_index('date').sort_index()
        df = df.rename(columns={'vol': 'volume'})
        return df[['open', 'high', 'low', 'close', 'volume']].astype(float)

    def ensure_downloaded(self):
        """Make sure both the listing and the daily bars have been downloaded."""
        self.download_basic()
        self.download_daily()

    @staticmethod
    def _normalize_dates(values: pd.Series, missing: str) -> pd.Series:
        """Render date values as ``YYYYMMDD`` strings, using ``missing`` for blanks.

        Handles the int, float (``20200101.0``) and string forms that the same
        CSV column can take depending on whether it contains empty cells.
        """
        numeric = pd.to_numeric(values, errors='coerce')
        text = numeric.map(lambda v: missing if pd.isna(v) else str(int(v)))
        return text.astype(str)

    def get_available_codes_at(self, ref_date: str) -> list:
        """Return the codes listed on or before ``ref_date`` and not yet delisted.

        A fund whose ``delist_date`` is on or before ``ref_date`` is excluded;
        a fund without a ``list_date`` is never returned.

        Args:
            ref_date: Reference date as ``YYYYMMDD`` (dashes are stripped).
        """
        basic = self.load_basic()
        ref = str(ref_date).replace('-', '')
        # Work on derived Series so the shared cached frame is left unmodified.
        listed = self._normalize_dates(basic['list_date'], missing='99991231')
        mask = listed <= ref
        if 'delist_date' in basic.columns:
            delisted = self._normalize_dates(basic['delist_date'], missing='99991231')
            mask = mask & (delisted > ref)
        return basic.loc[mask, 'ts_code'].tolist()
# ----------------------------------------------------------
# CLI
# ----------------------------------------------------------
if __name__ == '__main__':
    import argparse

    # Command-line switches: incremental update vs. forced full re-download.
    cli = argparse.ArgumentParser(description='ETF 全量历史数据缓存下载')
    cli.add_argument('--update', action='store_true', help='增量更新已有缓存')
    cli.add_argument('--force', action='store_true', help='强制重新下载全部')
    opts = cli.parse_args()

    etf_cache = ETFDataCache()
    if not opts.update:
        # Default mode: download whatever is missing (everything with --force).
        etf_cache.download_basic(force=opts.force)
        etf_cache.download_daily(force=opts.force)
    else:
        # Update mode: always refresh the listing, then append new daily bars.
        etf_cache.download_basic(force=True)
        etf_cache.update_daily()

    # Summary of what is now on disk.
    fund_count = len(etf_cache.load_basic())
    n_daily = sum(1 for _ in DAILY_DIR.glob('*.csv'))
    logger.info(f"\n缓存统计: fund_basic={fund_count} 只, 日线文件={n_daily}")