experiment(rotation): 同大类扩充与纳指vs标普替换对比实验
技术修复: - SOCKS5代理IPv6问题:socks5:// → socks5h:// (hybrid_source.py, yfinance_source.py) 目录整理: - scripts/ → 仅保留策略入口(daily_scheduler, run_rotation, run_cci_screener) - 实验脚本移至 tests/experiments/ - 工具脚本移至 tests/utils/ - 实验记录新增 docs/experiments/ - results/ 添加到 gitignore 实验结果: 实验001 - 同大类扩充(添加标普500): ├─ 累计收益: 1467.35% → 1176.26% (-291%) ├─ CAGR: 48.10% → 43.82% (-4.28%) ├─ 调仓次数: 459 → 501 (+42次) └─ 结论: 添加同大类标的不增加跨类分散,反而侵蚀收益 实验002 - 纳指vs标普替换对比: ├─ 累计收益: 1467.35% → 1118.77% (-348%) ├─ CAGR: 48.10% → 42.87% (-5.22%) ├─ Sharpe: 2.21 → 2.08 (-0.13) ├─ MaxDD: -17.33% → -15.14% (+2.18%) └─ 结论: 纳指100优于标普500,成长风格更适合动量策略 策略建议: - 保持纳指100作为美股大类代表 - 不添加同大类新标的(避免类内切换成本) - 新增标的应优先考虑新大类(增加跨类分散)
This commit is contained in:
280
tests/utils/etf_data_cache.py
Normal file
280
tests/utils/etf_data_cache.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
ETF 全量历史数据本地缓存
|
||||
========================
|
||||
一次性下载全市场 ETF(含已退市)的基础信息和日线数据到本地,
|
||||
供回测中按 ref_date 截取历史数据,消除前视偏差。
|
||||
|
||||
用法:
|
||||
# 首次下载(约 30-60 分钟,取决于 API 限流)
|
||||
python scripts/etf_data_cache.py
|
||||
|
||||
# 增量更新(只下载缺失的新数据)
|
||||
python scripts/etf_data_cache.py --update
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
import tushare as ts
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存目录
|
||||
CACHE_DIR = Path(__file__).parent.parent / 'data' / 'etf_cache'
|
||||
DAILY_DIR = CACHE_DIR / 'daily'
|
||||
BASIC_PATH = CACHE_DIR / 'fund_basic.csv'
|
||||
|
||||
|
||||
class ETFDataCache:
|
||||
"""ETF 全量历史数据缓存管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.pro = ts.pro_api(os.getenv('TUSHARE_TOKEN'))
|
||||
CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
DAILY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
self._basic_df = None # 懒加载
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# API 调用(带重试 + 限流)
|
||||
# ----------------------------------------------------------
|
||||
def _api_call(self, func, **kwargs):
|
||||
for attempt in range(3):
|
||||
try:
|
||||
result = func(**kwargs)
|
||||
time.sleep(0.35)
|
||||
return result
|
||||
except Exception as e:
|
||||
if attempt < 2:
|
||||
wait = 2 * (attempt + 1)
|
||||
logger.warning(f" API 重试 ({attempt+1}/3): {e}, 等待 {wait}s")
|
||||
time.sleep(wait)
|
||||
else:
|
||||
raise
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# 1. 下载并缓存 fund_basic
|
||||
# ----------------------------------------------------------
|
||||
def download_basic(self, force: bool = False):
|
||||
"""下载全量 ETF 基础信息(含已退市)"""
|
||||
if BASIC_PATH.exists() and not force:
|
||||
logger.info(f"fund_basic 缓存已存在: {BASIC_PATH}")
|
||||
return
|
||||
|
||||
logger.info("下载全量 ETF 基础信息...")
|
||||
fields = 'ts_code,name,management,list_date,delist_date,fund_type,invest_type,benchmark,type,trustee,status'
|
||||
|
||||
dfs = []
|
||||
for status in ['L', 'D']: # L=上市, D=已退市
|
||||
df = self._api_call(self.pro.fund_basic, market='E', status=status, fields=fields)
|
||||
if df is not None and not df.empty:
|
||||
dfs.append(df)
|
||||
logger.info(f" status={status}: {len(df)} 只")
|
||||
|
||||
if not dfs:
|
||||
raise RuntimeError("获取 ETF 列表失败")
|
||||
|
||||
basic = pd.concat(dfs, ignore_index=True).drop_duplicates(subset='ts_code')
|
||||
basic.to_csv(BASIC_PATH, index=False, encoding='utf-8-sig')
|
||||
logger.info(f"fund_basic 已保存: {len(basic)} 只 -> {BASIC_PATH}")
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# 2. 批量下载日线数据
|
||||
# ----------------------------------------------------------
|
||||
def download_daily(self, force: bool = False):
|
||||
"""批量下载所有 ETF 的全历史日线数据"""
|
||||
basic = self.load_basic()
|
||||
codes = basic['ts_code'].tolist()
|
||||
total = len(codes)
|
||||
logger.info(f"准备下载 {total} 只 ETF 的日线数据...")
|
||||
|
||||
downloaded = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
|
||||
for i, code in enumerate(codes):
|
||||
csv_path = DAILY_DIR / f"{code}.csv"
|
||||
|
||||
if csv_path.exists() and not force:
|
||||
# 增量更新: 读取已有数据的最后日期
|
||||
try:
|
||||
existing = pd.read_csv(csv_path, nrows=1) # 只读首行检查
|
||||
if not existing.empty:
|
||||
skipped += 1
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (i - skipped) % 20 == 0:
|
||||
logger.info(f" 进度: {i}/{total} (下载={downloaded}, 跳过={skipped}, 失败={failed})")
|
||||
|
||||
try:
|
||||
df = self._api_call(
|
||||
self.pro.fund_daily,
|
||||
ts_code=code,
|
||||
fields='ts_code,trade_date,open,high,low,close,vol,amount'
|
||||
)
|
||||
if df is not None and not df.empty:
|
||||
df = df.sort_values('trade_date')
|
||||
df.to_csv(csv_path, index=False)
|
||||
downloaded += 1
|
||||
else:
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
logger.warning(f" {code} 下载失败: {e}")
|
||||
failed += 1
|
||||
|
||||
logger.info(f"日线数据下载完成: 下载={downloaded}, 跳过={skipped}, 失败={failed}")
|
||||
|
||||
def update_daily(self):
|
||||
"""增量更新: 只为已有缓存文件追加新数据"""
|
||||
basic = self.load_basic()
|
||||
codes = basic['ts_code'].tolist()
|
||||
today_str = datetime.now().strftime('%Y%m%d')
|
||||
|
||||
updated = 0
|
||||
for code in codes:
|
||||
csv_path = DAILY_DIR / f"{code}.csv"
|
||||
if not csv_path.exists():
|
||||
continue
|
||||
|
||||
try:
|
||||
existing = pd.read_csv(csv_path)
|
||||
if existing.empty:
|
||||
continue
|
||||
last_date = str(existing['trade_date'].max())
|
||||
if last_date >= today_str:
|
||||
continue
|
||||
|
||||
# 下载 last_date 之后的数据
|
||||
new_df = self._api_call(
|
||||
self.pro.fund_daily,
|
||||
ts_code=code,
|
||||
start_date=str(int(last_date) + 1),
|
||||
end_date=today_str,
|
||||
fields='ts_code,trade_date,open,high,low,close,vol,amount'
|
||||
)
|
||||
if new_df is not None and not new_df.empty:
|
||||
combined = pd.concat([existing, new_df], ignore_index=True)
|
||||
combined = combined.drop_duplicates(subset='trade_date').sort_values('trade_date')
|
||||
combined.to_csv(csv_path, index=False)
|
||||
updated += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"增量更新完成: {updated} 只有新数据")
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# 3. 数据读取接口(回测用)
|
||||
# ----------------------------------------------------------
|
||||
def load_basic(self) -> pd.DataFrame:
|
||||
"""加载 fund_basic 缓存"""
|
||||
if self._basic_df is not None:
|
||||
return self._basic_df
|
||||
|
||||
if not BASIC_PATH.exists():
|
||||
raise FileNotFoundError(f"fund_basic 缓存不存在,请先运行: python scripts/etf_data_cache.py")
|
||||
|
||||
self._basic_df = pd.read_csv(BASIC_PATH)
|
||||
return self._basic_df
|
||||
|
||||
def load_cached_daily(self, ts_code: str, end_date: str = None) -> pd.DataFrame:
|
||||
"""
|
||||
加载某只 ETF 的日线数据,截至 end_date(含)。
|
||||
|
||||
Args:
|
||||
ts_code: ETF 代码
|
||||
end_date: 截止日期 YYYYMMDD,None 表示全部
|
||||
|
||||
Returns:
|
||||
DataFrame with columns [trade_date, open, high, low, close, vol, amount]
|
||||
按 trade_date 升序排列
|
||||
"""
|
||||
csv_path = DAILY_DIR / f"{ts_code}.csv"
|
||||
if not csv_path.exists():
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.read_csv(csv_path)
|
||||
if df.empty:
|
||||
return df
|
||||
|
||||
df['trade_date'] = df['trade_date'].astype(str)
|
||||
df = df.sort_values('trade_date')
|
||||
|
||||
if end_date:
|
||||
end_str = str(end_date).replace('-', '')
|
||||
df = df[df['trade_date'] <= end_str]
|
||||
|
||||
return df
|
||||
|
||||
def load_cached_daily_as_series(self, ts_code: str, end_date: str = None,
|
||||
column: str = 'close') -> pd.Series:
|
||||
"""加载某只 ETF 的单列数据,index 为 datetime"""
|
||||
df = self.load_cached_daily(ts_code, end_date)
|
||||
if df.empty:
|
||||
return pd.Series(dtype=float)
|
||||
df['date'] = pd.to_datetime(df['trade_date'])
|
||||
return df.set_index('date')[column].astype(float)
|
||||
|
||||
def load_cached_ohlcv(self, ts_code: str, end_date: str = None) -> pd.DataFrame:
|
||||
"""加载 OHLCV 数据,index 为 datetime(与 动量.py 的 all_data 格式兼容)"""
|
||||
df = self.load_cached_daily(ts_code, end_date)
|
||||
if df.empty:
|
||||
return pd.DataFrame()
|
||||
df['date'] = pd.to_datetime(df['trade_date'])
|
||||
df = df.set_index('date').sort_index()
|
||||
df = df.rename(columns={'vol': 'volume'})
|
||||
return df[['open', 'high', 'low', 'close', 'volume']].astype(float)
|
||||
|
||||
def ensure_downloaded(self):
|
||||
"""确保基础信息和日线数据都已下载"""
|
||||
self.download_basic()
|
||||
self.download_daily()
|
||||
|
||||
def get_available_codes_at(self, ref_date: str) -> list:
|
||||
"""获取在 ref_date 时已上市且未退市的 ETF 代码列表"""
|
||||
basic = self.load_basic()
|
||||
basic['list_date'] = basic['list_date'].astype(str)
|
||||
mask = basic['list_date'] <= ref_date
|
||||
|
||||
# 排除在 ref_date 之前已退市的
|
||||
if 'delist_date' in basic.columns:
|
||||
delist = basic['delist_date'].astype(str).fillna('99991231')
|
||||
mask = mask & (delist > ref_date)
|
||||
|
||||
return basic[mask]['ts_code'].tolist()
|
||||
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# CLI
|
||||
# ----------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='ETF 全量历史数据缓存下载')
|
||||
parser.add_argument('--update', action='store_true', help='增量更新已有缓存')
|
||||
parser.add_argument('--force', action='store_true', help='强制重新下载全部')
|
||||
args = parser.parse_args()
|
||||
|
||||
cache = ETFDataCache()
|
||||
|
||||
if args.update:
|
||||
cache.download_basic(force=True)
|
||||
cache.update_daily()
|
||||
else:
|
||||
cache.download_basic(force=args.force)
|
||||
cache.download_daily(force=args.force)
|
||||
|
||||
# 统计
|
||||
basic = cache.load_basic()
|
||||
n_daily = len(list(DAILY_DIR.glob('*.csv')))
|
||||
logger.info(f"\n缓存统计: fund_basic={len(basic)} 只, 日线文件={n_daily} 个")
|
||||
Reference in New Issue
Block a user