归档内容: - core/ (数据源、因子计算、通用工具) → archive/legacy_core/ - strategies/rotation/engine.py, portfolio.py, report.py → archive/legacy_core/ - scripts/ (run_rotation, daily_scheduler) → archive/legacy_scripts/ - examples/ → archive/legacy_examples/ - tests/ (实验、对比测试) → archive/legacy_tests/ - 单独文件 (fetch_*.py, 动量.py, 全球市场.py等) → archive/single_files/ 保留新结构: - framework/ (抽象接口) - strategies/shared/ (定制组件) - strategies/rotation/strategy.py (新策略) - 外层配置: .env, .dockerignore, build-and-push.sh, hk_ecs.pem, README.md, requirements.txt - Docker相关: Dockerfile, Dockerfile_base, docker-compose.yml 更新README反映新框架架构
281 lines
10 KiB
Python
281 lines
10 KiB
Python
"""
|
||
ETF 全量历史数据本地缓存
|
||
========================
|
||
一次性下载全市场 ETF(含已退市)的基础信息和日线数据到本地,
|
||
供回测中按 ref_date 截取历史数据,消除前视偏差。
|
||
|
||
用法:
|
||
# 首次下载(约 30-60 分钟,取决于 API 限流)
|
||
python scripts/etf_data_cache.py
|
||
|
||
# 增量更新(只下载缺失的新数据)
|
||
python scripts/etf_data_cache.py --update
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
import logging
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
|
||
import pandas as pd
|
||
|
||
# Repository root (two levels above this file); put it first on sys.path so the
# script can be executed directly.
_PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(_PROJECT_ROOT))

# Load variables from a local .env file into the environment (ETFDataCache reads
# TUSHARE_TOKEN from os.environ when it is constructed).
from dotenv import load_dotenv
load_dotenv()

import tushare as ts

logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Cache layout on disk:
#   <root>/data/etf_cache/fund_basic.csv        ETF master list
#   <root>/data/etf_cache/daily/<ts_code>.csv   one daily-bar file per ETF
CACHE_DIR = _PROJECT_ROOT / 'data' / 'etf_cache'
DAILY_DIR = CACHE_DIR / 'daily'
BASIC_PATH = CACHE_DIR / 'fund_basic.csv'
|
||
|
||
|
||
class ETFDataCache:
    """Local cache of full-history ETF data, including delisted funds.

    Downloads the Tushare ``fund_basic`` master list and per-ETF ``fund_daily``
    bars into CSV files so that backtests can cut history at a reference date
    without look-ahead bias.
    """

    # Columns requested from fund_daily for both full and incremental downloads.
    _DAILY_FIELDS = 'ts_code,trade_date,open,high,low,close,vol,amount'
    # Retry / rate-limit settings used by _api_call.
    _MAX_ATTEMPTS = 3
    _CALL_INTERVAL = 0.35  # seconds slept after every successful API call (throttling)

    def __init__(self):
        # Tushare token comes from the TUSHARE_TOKEN environment variable.
        self.pro = ts.pro_api(os.getenv('TUSHARE_TOKEN'))
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        DAILY_DIR.mkdir(parents=True, exist_ok=True)
        self._basic_df = None  # lazily populated by load_basic()

    # ----------------------------------------------------------
    # Helpers
    # ----------------------------------------------------------
    def _api_call(self, func, **kwargs):
        """Call a Tushare API function with retries and a fixed post-call delay.

        Failed calls are retried up to ``_MAX_ATTEMPTS`` times in total with a
        linearly growing back-off (2s, 4s, ...); the final exception is re-raised.
        """
        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                result = func(**kwargs)
            except Exception as e:
                if attempt >= self._MAX_ATTEMPTS:
                    raise
                wait = 2 * attempt
                logger.warning(f" API 重试 ({attempt}/{self._MAX_ATTEMPTS}): {e}, 等待 {wait}s")
                time.sleep(wait)
            else:
                time.sleep(self._CALL_INTERVAL)
                return result

    @staticmethod
    def _normalize_date(value):
        """Normalize a date value to a 'YYYYMMDD' string, or None if missing.

        Accepts ints (20200101), floats produced by CSV columns that contain NaN
        (20200101.0) and dashed strings ('2020-01-01').
        """
        if pd.isna(value):
            return None
        text = str(value).strip().replace('-', '')
        if text.endswith('.0'):
            text = text[:-2]
        return text

    @staticmethod
    def _next_date_str(date_str: str) -> str:
        """Return the calendar day after a 'YYYYMMDD' string, formatted the same way."""
        day = pd.to_datetime(date_str, format='%Y%m%d') + pd.Timedelta(days=1)
        return day.strftime('%Y%m%d')

    @staticmethod
    def _has_cached_rows(csv_path) -> bool:
        """Return True if csv_path exists and holds at least one data row.

        Unreadable or empty files count as missing so they get re-downloaded.
        """
        if not csv_path.exists():
            return False
        try:
            return not pd.read_csv(csv_path, nrows=1).empty
        except Exception:
            return False

    # ----------------------------------------------------------
    # 1. Download and cache fund_basic
    # ----------------------------------------------------------
    def download_basic(self, force: bool = False):
        """Download the full ETF master list (listed + delisted) to BASIC_PATH.

        Args:
            force: re-download even if the CSV already exists.

        Raises:
            RuntimeError: if neither status query returned any rows.
        """
        if BASIC_PATH.exists() and not force:
            logger.info(f"fund_basic 缓存已存在: {BASIC_PATH}")
            return

        logger.info("下载全量 ETF 基础信息...")
        fields = 'ts_code,name,management,list_date,delist_date,fund_type,invest_type,benchmark,type,trustee,status'

        dfs = []
        for status in ['L', 'D']:  # L = listed, D = delisted
            df = self._api_call(self.pro.fund_basic, market='E', status=status, fields=fields)
            if df is not None and not df.empty:
                dfs.append(df)
                logger.info(f" status={status}: {len(df)} 只")

        if not dfs:
            raise RuntimeError("获取 ETF 列表失败")

        basic = pd.concat(dfs, ignore_index=True).drop_duplicates(subset='ts_code')
        basic.to_csv(BASIC_PATH, index=False, encoding='utf-8-sig')
        # Drop the in-memory copy so the next load_basic() sees the fresh file.
        self._basic_df = None
        logger.info(f"fund_basic 已保存: {len(basic)} 只 -> {BASIC_PATH}")

    # ----------------------------------------------------------
    # 2. Bulk daily-bar download
    # ----------------------------------------------------------
    def download_daily(self, force: bool = False):
        """Download the full daily history of every ETF in fund_basic.

        One CSV per ETF is written to DAILY_DIR. Existing non-empty files are
        skipped unless ``force`` is set; no date check is done here (use
        update_daily to append newer bars).
        """
        basic = self.load_basic()
        codes = basic['ts_code'].tolist()
        total = len(codes)
        logger.info(f"准备下载 {total} 只 ETF 的日线数据...")

        downloaded = skipped = failed = 0
        for i, code in enumerate(codes):
            csv_path = DAILY_DIR / f"{code}.csv"

            if not force and self._has_cached_rows(csv_path):
                skipped += 1
                continue

            # i - skipped == number of codes attempted so far; log every 20 attempts.
            if (i - skipped) % 20 == 0:
                logger.info(f" 进度: {i}/{total} (下载={downloaded}, 跳过={skipped}, 失败={failed})")

            try:
                # NOTE(review): no date range is passed, so this relies on a single
                # call returning the whole history -- confirm Tushare's per-call row
                # limit for fund_daily does not truncate long-lived ETFs.
                df = self._api_call(self.pro.fund_daily, ts_code=code, fields=self._DAILY_FIELDS)
                if df is not None and not df.empty:
                    df.sort_values('trade_date').to_csv(csv_path, index=False)
                    downloaded += 1
                else:
                    failed += 1
            except Exception as e:
                logger.warning(f" {code} 下载失败: {e}")
                failed += 1

        logger.info(f"日线数据下载完成: 下载={downloaded}, 跳过={skipped}, 失败={failed}")

    def update_daily(self):
        """Append bars newer than each file's last trade_date to existing caches.

        Only ETFs that already have a cache file are updated; codes without a
        file are left to download_daily. Per-code failures are logged rather
        than silently ignored.
        """
        basic = self.load_basic()
        codes = basic['ts_code'].tolist()
        today_str = datetime.now().strftime('%Y%m%d')

        updated = 0
        for code in codes:
            csv_path = DAILY_DIR / f"{code}.csv"
            if not csv_path.exists():
                continue
            try:
                if self._append_new_bars(code, csv_path, today_str):
                    updated += 1
            except Exception as e:
                logger.warning(f" {code} 增量更新失败: {e}")

        logger.info(f"增量更新完成: {updated} 只有新数据")

    def _append_new_bars(self, code, csv_path, today_str: str) -> bool:
        """Fetch bars after the file's last trade_date and merge them in; True if the file changed."""
        # Keep trade_date as str on both sides: read_csv would otherwise parse it as
        # int while the API returns str, and concat + sort_values on a mixed
        # int/str column raises TypeError (and dedup by trade_date would not match).
        existing = pd.read_csv(csv_path, dtype={'trade_date': str})
        if existing.empty:
            return False
        last_date = self._normalize_date(existing['trade_date'].max())
        if last_date is None or last_date >= today_str:
            return False

        new_df = self._api_call(
            self.pro.fund_daily,
            ts_code=code,
            # Real next calendar day; int(last_date) + 1 yields invalid dates such as 20240132.
            start_date=self._next_date_str(last_date),
            end_date=today_str,
            fields=self._DAILY_FIELDS,
        )
        if new_df is None or new_df.empty:
            return False

        new_df = new_df.astype({'trade_date': str})
        combined = (
            pd.concat([existing, new_df], ignore_index=True)
            .drop_duplicates(subset='trade_date')
            .sort_values('trade_date')
        )
        combined.to_csv(csv_path, index=False)
        return True

    # ----------------------------------------------------------
    # 3. Read API (for backtests)
    # ----------------------------------------------------------
    def load_basic(self) -> pd.DataFrame:
        """Return the cached fund_basic table, reading BASIC_PATH on first use.

        Raises:
            FileNotFoundError: if the fund_basic CSV has not been downloaded yet.
        """
        if self._basic_df is not None:
            return self._basic_df

        if not BASIC_PATH.exists():
            raise FileNotFoundError("fund_basic 缓存不存在,请先运行: python scripts/etf_data_cache.py")

        # The file is written with a UTF-8 BOM (utf-8-sig); read it symmetrically.
        self._basic_df = pd.read_csv(BASIC_PATH, encoding='utf-8-sig')
        return self._basic_df

    def load_cached_daily(self, ts_code: str, end_date: str = None) -> pd.DataFrame:
        """Load one ETF's cached daily bars up to and including end_date.

        Args:
            ts_code: ETF code.
            end_date: cut-off date 'YYYYMMDD' or 'YYYY-MM-DD'; None means all rows.

        Returns:
            DataFrame with the cached columns
            (ts_code, trade_date, open, high, low, close, vol, amount),
            trade_date as 'YYYYMMDD' strings, sorted ascending. Empty if the file
            is missing or empty.
        """
        csv_path = DAILY_DIR / f"{ts_code}.csv"
        if not csv_path.exists():
            return pd.DataFrame()

        # Read trade_date as str so YYYYMMDD comparisons below are exact.
        df = pd.read_csv(csv_path, dtype={'trade_date': str})
        if df.empty:
            return df

        df = df.sort_values('trade_date')
        if end_date:
            end_str = str(end_date).replace('-', '')
            df = df[df['trade_date'] <= end_str]

        # Independent copy so callers can add columns without SettingWithCopyWarning.
        return df.copy()

    def load_cached_daily_as_series(self, ts_code: str, end_date: str = None,
                                    column: str = 'close') -> pd.Series:
        """Load one column as a float Series indexed by trade date (DatetimeIndex named 'date')."""
        df = self.load_cached_daily(ts_code, end_date)
        if df.empty:
            return pd.Series(dtype=float)
        dates = pd.to_datetime(df['trade_date'], format='%Y%m%d')
        return df.assign(date=dates).set_index('date')[column].astype(float)

    def load_cached_ohlcv(self, ts_code: str, end_date: str = None) -> pd.DataFrame:
        """Load open/high/low/close/volume as floats indexed by trade date.

        'vol' is renamed to 'volume' (format compatible with all_data in 动量.py).
        """
        df = self.load_cached_daily(ts_code, end_date)
        if df.empty:
            return pd.DataFrame()
        dates = pd.to_datetime(df['trade_date'], format='%Y%m%d')
        df = df.assign(date=dates).set_index('date').sort_index()
        df = df.rename(columns={'vol': 'volume'})
        return df[['open', 'high', 'low', 'close', 'volume']].astype(float)

    def ensure_downloaded(self):
        """Make sure both fund_basic and the daily files have been downloaded."""
        self.download_basic()
        self.download_daily()

    def get_available_codes_at(self, ref_date: str) -> list:
        """Return codes listed on/before ref_date and not yet delisted at ref_date.

        ref_date may be 'YYYYMMDD' or 'YYYY-MM-DD'. An ETF whose delist_date equals
        ref_date is treated as unavailable. The cached fund_basic frame is not
        modified.
        """
        basic = self.load_basic()
        ref = self._normalize_date(ref_date)

        # Missing list_date -> never eligible.
        list_dates = basic['list_date'].map(self._normalize_date).fillna('99999999')
        mask = list_dates <= ref

        # Exclude ETFs delisted on or before ref_date; missing delist_date -> still listed.
        if 'delist_date' in basic.columns:
            delist_dates = basic['delist_date'].map(self._normalize_date).fillna('99991231')
            mask = mask & (delist_dates > ref)

        return basic.loc[mask, 'ts_code'].tolist()
|
||
|
||
|
||
# ----------------------------------------------------------
|
||
# CLI
|
||
# ----------------------------------------------------------
|
||
if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser(description='ETF 全量历史数据缓存下载')
    cli.add_argument('--update', action='store_true', help='增量更新已有缓存')
    cli.add_argument('--force', action='store_true', help='强制重新下载全部')
    opts = cli.parse_args()

    cache = ETFDataCache()

    # --update always refreshes the ETF list, then appends new bars to existing
    # files; otherwise run a full (optionally forced) download.
    cache.download_basic(force=opts.update or opts.force)
    if opts.update:
        cache.update_daily()
    else:
        cache.download_daily(force=opts.force)

    # Summary of what is on disk now.
    n_basic = len(cache.load_basic())
    n_daily = sum(1 for _ in DAILY_DIR.glob('*.csv'))
    logger.info(f"\n缓存统计: fund_basic={n_basic} 只, 日线文件={n_daily} 个")
|