experiment(rotation): 同大类扩充与纳指vs标普替换对比实验

技术修复:
- SOCKS5代理IPv6问题:socks5:// → socks5h:// (hybrid_source.py, yfinance_source.py)

目录整理:
- scripts/ → 仅保留策略入口(daily_scheduler, run_rotation, run_cci_screener)
- 实验脚本移至 tests/experiments/
- 工具脚本移至 tests/utils/
- 实验记录新增 docs/experiments/
- results/ 添加到 gitignore

实验结果:

实验001 - 同大类扩充(添加标普500):
├─ 累计收益: 1467.35% → 1176.26% (-291%)
├─ CAGR: 48.10% → 43.82% (-4.28%)
├─ 调仓次数: 459 → 501 (+42次)
└─ 结论: 添加同大类标的不增加跨类分散,反而侵蚀收益

实验002 - 纳指vs标普替换对比:
├─ 累计收益: 1467.35% → 1118.77% (-348%)
├─ CAGR: 48.10% → 42.87% (-5.22%)
├─ Sharpe: 2.21 → 2.08 (-0.13)
├─ MaxDD: -17.33% → -15.14% (+2.18%)
└─ 结论: 纳指100优于标普500,成长风格更适合动量策略

策略建议:
- 保持纳指100作为美股大类代表
- 不添加同大类新标的(避免类内切换成本)
- 新增标的应优先考虑新大类(增加跨类分散)
This commit is contained in:
2026-05-06 20:43:38 +08:00
parent a4e8a6050e
commit 6b59855c28
20 changed files with 1086 additions and 2 deletions

View File

@@ -0,0 +1,280 @@
"""
ETF 全量历史数据本地缓存
========================
一次性下载全市场 ETF含已退市的基础信息和日线数据到本地
供回测中按 ref_date 截取历史数据,消除前视偏差。
用法:
# 首次下载(约 30-60 分钟,取决于 API 限流)
python scripts/etf_data_cache.py
# 增量更新(只下载缺失的新数据)
python scripts/etf_data_cache.py --update
"""
import os
import sys
import time
import logging
from pathlib import Path
from datetime import datetime
import pandas as pd
sys.path.insert(0, str(Path(__file__).parent.parent))
from dotenv import load_dotenv
load_dotenv()
import tushare as ts
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')
logger = logging.getLogger(__name__)
# 缓存目录
CACHE_DIR = Path(__file__).parent.parent / 'data' / 'etf_cache'
DAILY_DIR = CACHE_DIR / 'daily'
BASIC_PATH = CACHE_DIR / 'fund_basic.csv'
class ETFDataCache:
"""ETF 全量历史数据缓存管理器"""
def __init__(self):
self.pro = ts.pro_api(os.getenv('TUSHARE_TOKEN'))
CACHE_DIR.mkdir(parents=True, exist_ok=True)
DAILY_DIR.mkdir(parents=True, exist_ok=True)
self._basic_df = None # 懒加载
# ----------------------------------------------------------
# API 调用(带重试 + 限流)
# ----------------------------------------------------------
def _api_call(self, func, **kwargs):
for attempt in range(3):
try:
result = func(**kwargs)
time.sleep(0.35)
return result
except Exception as e:
if attempt < 2:
wait = 2 * (attempt + 1)
logger.warning(f" API 重试 ({attempt+1}/3): {e}, 等待 {wait}s")
time.sleep(wait)
else:
raise
# ----------------------------------------------------------
# 1. 下载并缓存 fund_basic
# ----------------------------------------------------------
def download_basic(self, force: bool = False):
"""下载全量 ETF 基础信息(含已退市)"""
if BASIC_PATH.exists() and not force:
logger.info(f"fund_basic 缓存已存在: {BASIC_PATH}")
return
logger.info("下载全量 ETF 基础信息...")
fields = 'ts_code,name,management,list_date,delist_date,fund_type,invest_type,benchmark,type,trustee,status'
dfs = []
for status in ['L', 'D']: # L=上市, D=已退市
df = self._api_call(self.pro.fund_basic, market='E', status=status, fields=fields)
if df is not None and not df.empty:
dfs.append(df)
logger.info(f" status={status}: {len(df)}")
if not dfs:
raise RuntimeError("获取 ETF 列表失败")
basic = pd.concat(dfs, ignore_index=True).drop_duplicates(subset='ts_code')
basic.to_csv(BASIC_PATH, index=False, encoding='utf-8-sig')
logger.info(f"fund_basic 已保存: {len(basic)} 只 -> {BASIC_PATH}")
# ----------------------------------------------------------
# 2. 批量下载日线数据
# ----------------------------------------------------------
def download_daily(self, force: bool = False):
"""批量下载所有 ETF 的全历史日线数据"""
basic = self.load_basic()
codes = basic['ts_code'].tolist()
total = len(codes)
logger.info(f"准备下载 {total} 只 ETF 的日线数据...")
downloaded = 0
skipped = 0
failed = 0
for i, code in enumerate(codes):
csv_path = DAILY_DIR / f"{code}.csv"
if csv_path.exists() and not force:
# 增量更新: 读取已有数据的最后日期
try:
existing = pd.read_csv(csv_path, nrows=1) # 只读首行检查
if not existing.empty:
skipped += 1
continue
except Exception:
pass
if (i - skipped) % 20 == 0:
logger.info(f" 进度: {i}/{total} (下载={downloaded}, 跳过={skipped}, 失败={failed})")
try:
df = self._api_call(
self.pro.fund_daily,
ts_code=code,
fields='ts_code,trade_date,open,high,low,close,vol,amount'
)
if df is not None and not df.empty:
df = df.sort_values('trade_date')
df.to_csv(csv_path, index=False)
downloaded += 1
else:
failed += 1
except Exception as e:
logger.warning(f" {code} 下载失败: {e}")
failed += 1
logger.info(f"日线数据下载完成: 下载={downloaded}, 跳过={skipped}, 失败={failed}")
def update_daily(self):
"""增量更新: 只为已有缓存文件追加新数据"""
basic = self.load_basic()
codes = basic['ts_code'].tolist()
today_str = datetime.now().strftime('%Y%m%d')
updated = 0
for code in codes:
csv_path = DAILY_DIR / f"{code}.csv"
if not csv_path.exists():
continue
try:
existing = pd.read_csv(csv_path)
if existing.empty:
continue
last_date = str(existing['trade_date'].max())
if last_date >= today_str:
continue
# 下载 last_date 之后的数据
new_df = self._api_call(
self.pro.fund_daily,
ts_code=code,
start_date=str(int(last_date) + 1),
end_date=today_str,
fields='ts_code,trade_date,open,high,low,close,vol,amount'
)
if new_df is not None and not new_df.empty:
combined = pd.concat([existing, new_df], ignore_index=True)
combined = combined.drop_duplicates(subset='trade_date').sort_values('trade_date')
combined.to_csv(csv_path, index=False)
updated += 1
except Exception:
pass
logger.info(f"增量更新完成: {updated} 只有新数据")
# ----------------------------------------------------------
# 3. 数据读取接口(回测用)
# ----------------------------------------------------------
def load_basic(self) -> pd.DataFrame:
"""加载 fund_basic 缓存"""
if self._basic_df is not None:
return self._basic_df
if not BASIC_PATH.exists():
raise FileNotFoundError(f"fund_basic 缓存不存在,请先运行: python scripts/etf_data_cache.py")
self._basic_df = pd.read_csv(BASIC_PATH)
return self._basic_df
def load_cached_daily(self, ts_code: str, end_date: str = None) -> pd.DataFrame:
"""
加载某只 ETF 的日线数据,截至 end_date
Args:
ts_code: ETF 代码
end_date: 截止日期 YYYYMMDDNone 表示全部
Returns:
DataFrame with columns [trade_date, open, high, low, close, vol, amount]
按 trade_date 升序排列
"""
csv_path = DAILY_DIR / f"{ts_code}.csv"
if not csv_path.exists():
return pd.DataFrame()
df = pd.read_csv(csv_path)
if df.empty:
return df
df['trade_date'] = df['trade_date'].astype(str)
df = df.sort_values('trade_date')
if end_date:
end_str = str(end_date).replace('-', '')
df = df[df['trade_date'] <= end_str]
return df
def load_cached_daily_as_series(self, ts_code: str, end_date: str = None,
column: str = 'close') -> pd.Series:
"""加载某只 ETF 的单列数据index 为 datetime"""
df = self.load_cached_daily(ts_code, end_date)
if df.empty:
return pd.Series(dtype=float)
df['date'] = pd.to_datetime(df['trade_date'])
return df.set_index('date')[column].astype(float)
def load_cached_ohlcv(self, ts_code: str, end_date: str = None) -> pd.DataFrame:
"""加载 OHLCV 数据index 为 datetime与 动量.py 的 all_data 格式兼容)"""
df = self.load_cached_daily(ts_code, end_date)
if df.empty:
return pd.DataFrame()
df['date'] = pd.to_datetime(df['trade_date'])
df = df.set_index('date').sort_index()
df = df.rename(columns={'vol': 'volume'})
return df[['open', 'high', 'low', 'close', 'volume']].astype(float)
def ensure_downloaded(self):
"""确保基础信息和日线数据都已下载"""
self.download_basic()
self.download_daily()
def get_available_codes_at(self, ref_date: str) -> list:
"""获取在 ref_date 时已上市且未退市的 ETF 代码列表"""
basic = self.load_basic()
basic['list_date'] = basic['list_date'].astype(str)
mask = basic['list_date'] <= ref_date
# 排除在 ref_date 之前已退市的
if 'delist_date' in basic.columns:
delist = basic['delist_date'].astype(str).fillna('99991231')
mask = mask & (delist > ref_date)
return basic[mask]['ts_code'].tolist()
# ----------------------------------------------------------
# CLI
# ----------------------------------------------------------
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='ETF 全量历史数据缓存下载')
parser.add_argument('--update', action='store_true', help='增量更新已有缓存')
parser.add_argument('--force', action='store_true', help='强制重新下载全部')
args = parser.parse_args()
cache = ETFDataCache()
if args.update:
cache.download_basic(force=True)
cache.update_daily()
else:
cache.download_basic(force=args.force)
cache.download_daily(force=args.force)
# 统计
basic = cache.load_basic()
n_daily = len(list(DAILY_DIR.glob('*.csv')))
logger.info(f"\n缓存统计: fund_basic={len(basic)} 只, 日线文件={n_daily}")