Files
etf/rotation/enrich_etf_data.py
aszerW adb83d8cd7 feat: 实现贪心分配模式(greedy)
- config_loader.py: 添加 etf_pool 字段和 GREEDY 枚举
- config_simple.yaml: 每个资产添加 etf_pool 列表
- simple_rotation.py:
  - 添加 _compute_greedy_weights 方法
  - _calculate_daily_return 支持 greedy 模式
  - 向后兼容原有 rank/equal 模式

贪心算法:按 ETF 池容量分配仓位,装不下的顺延给下一名
- 有色金属(1 ETF): 吸收25%,顺延75%
- 原油(3 ETF): 吸收75%
- 黄金(4 ETF): 吸收100%

回测对比 (select_num=3):
- rank: 326.60% 累计收益, 1.24 夏普
- greedy: 421.35% 累计收益, 1.03 夏普
2026-06-21 12:40:40 +08:00

174 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
补充 ETF 丰富信息到 etf_basic_full.csv（批量版）
- fund_share: 最新基金份额
- fund_daily: 近20日日均成交额
- fund_nav: 最新累计净值
"""
import os
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd

PROJECT_ROOT = Path(__file__).parent.parent
# Make modules under PROJECT_ROOT importable when the script is run directly.
sys.path.insert(0, str(PROJECT_ROOT))

from dotenv import load_dotenv  # noqa: E402

# Load .env before creating the Tushare client so TUSHARE_TOKEN is available.
load_dotenv(PROJECT_ROOT / '.env')

import tushare as ts  # noqa: E402

pro = ts.pro_api(os.getenv('TUSHARE_TOKEN'))

csv_path = PROJECT_ROOT / 'rotation' / 'results' / 'etf_basic_full.csv'
df = pd.read_csv(csv_path)
print(f"加载 {len(df)} 只ETF")

# Initialise the output columns up front so they exist (as NaN / '') even if a
# later fetch stage returns nothing.
for _col in ('fund_scale_yi', 'avg_amount_wan', 'avg_vol_wan', 'latest_nav'):
    df[_col] = np.nan
df['nav_date'] = ''

# Take one timestamp so every query window ends on the same day; separate
# datetime.now() calls could straddle midnight and disagree.
now = datetime.now()
end_date = now.strftime('%Y%m%d')
start_date = (now - timedelta(days=60)).strftime('%Y%m%d')   # fund_share / fund_nav window
daily_start = (now - timedelta(days=40)).strftime('%Y%m%d')  # fund_daily window
all_codes = df['ts_code'].tolist()
# ============================================================
# Shared batch fetcher for the three Tushare endpoints below
# ============================================================
batch_size = 100  # number of ts_codes comma-joined into one request


def fetch_in_batches(api_func, codes, label, **params):
    """Query a Tushare endpoint for ``codes`` in chunks of ``batch_size``.

    Failures are best-effort: a failed batch is reported and skipped so the
    remaining batches still run.

    Args:
        api_func: bound endpoint callable, e.g. ``pro.fund_share``.
        codes: list of ts_code strings.
        label: short endpoint name used in warning messages.
        **params: extra query parameters (start_date, end_date, ...).

    Returns:
        All non-empty responses concatenated into one DataFrame, or None if
        no batch returned any rows.
    """
    # NOTE(review): Tushare endpoints cap the number of rows returned per call;
    # batch_size codes x N days may exceed that cap and silently truncate the
    # response -- verify against each endpoint's documented limit.
    frames = []
    total = len(codes)
    for i in range(0, total, batch_size):
        batch = codes[i:i + batch_size]
        try:
            part = api_func(ts_code=','.join(batch), **params)
        except Exception as e:
            # Keep going, but do not hide the failure.
            print(f" ⚠ {label} 第 {i + 1}-{i + len(batch)} 只请求失败: {e}")
        else:
            if part is not None and len(part) > 0:
                frames.append(part)
        if (i + batch_size) % 500 == 0:
            print(f" 进度: {min(i + batch_size, total)}/{total}")
        time.sleep(0.3)  # pause between requests (presumably for API rate limits)
    return pd.concat(frames, ignore_index=True) if frames else None


# ============================================================
# 1. fund_share - latest share count per fund
# ============================================================
print(f"\n[1/3] 获取基金份额...")
all_shares = fetch_in_batches(pro.fund_share, all_codes, 'fund_share',
                              start_date=start_date, end_date=end_date)
if all_shares is not None:
    # Keep only the most recent record per fund.
    latest_shares = (all_shares.sort_values('trade_date', ascending=False)
                     .drop_duplicates(subset=['ts_code'], keep='first'))
    share_map = dict(zip(latest_shares['ts_code'], latest_shares['fd_share']))
    df['fund_scale_yi'] = df['ts_code'].map(share_map) / 10000  # 10k shares -> 100M shares
    print(f" ✓ 份额数据: {len(share_map)}")
else:
    print(f" ✗ 份额数据获取失败")

# ============================================================
# 2. fund_daily - mean turnover / volume over the last 20 trading days
# ============================================================
AVG_WINDOW_DAYS = 20  # trading days averaged per fund

print(f"\n[2/3] 获取日线行情...")
all_daily = fetch_in_batches(pro.fund_daily, all_codes, 'fund_daily',
                             start_date=daily_start, end_date=end_date)
if all_daily is not None:
    # The 40-calendar-day query window usually holds more than 20 trading
    # days, so keep only each fund's most recent AVG_WINDOW_DAYS rows.
    recent_daily = (all_daily.sort_values('trade_date', ascending=False)
                    .groupby('ts_code', sort=False)
                    .head(AVG_WINDOW_DAYS))
    daily_stats = recent_daily.groupby('ts_code').agg(
        avg_amount=('amount', 'mean'),
        avg_vol=('vol', 'mean'),
    ).reset_index()
    daily_map_amt = dict(zip(daily_stats['ts_code'], daily_stats['avg_amount']))
    daily_map_vol = dict(zip(daily_stats['ts_code'], daily_stats['avg_vol']))
    df['avg_amount_wan'] = df['ts_code'].map(daily_map_amt) / 10  # 1k yuan -> 10k yuan
    df['avg_vol_wan'] = df['ts_code'].map(daily_map_vol) / 10000  # lots -> 10k lots
    print(f" ✓ 行情数据: {len(daily_stats)}")
else:
    print(f" ✗ 行情数据获取失败")

# ============================================================
# 3. fund_nav - latest NAV per fund
# ============================================================
print(f"\n[3/3] 获取基金净值...")
all_nav = fetch_in_batches(pro.fund_nav, all_codes, 'fund_nav',
                           start_date=start_date, end_date=end_date)
unit_nav_map = {}
if all_nav is not None:
    latest_nav = (all_nav.sort_values('nav_date', ascending=False)
                  .drop_duplicates(subset=['ts_code'], keep='first'))
    # latest_nav column reports the accumulated NAV (as before).
    nav_map = dict(zip(latest_nav['ts_code'], latest_nav['accum_nav']))
    nav_date_map = dict(zip(latest_nav['ts_code'], latest_nav['nav_date'].astype(str)))
    # Scale must use the unit NAV: accumulated NAV = unit NAV + cumulative
    # per-share dividends, which overstates dividend-paying funds.
    nav_col = 'unit_nav' if 'unit_nav' in latest_nav.columns else 'accum_nav'
    unit_nav_map = dict(zip(latest_nav['ts_code'], latest_nav[nav_col]))
    df['latest_nav'] = df['ts_code'].map(nav_map)
    # Unmatched funds get '' to match the column's initial value.
    df['nav_date'] = df['ts_code'].map(nav_date_map).fillna('')
    print(f" ✓ 净值数据: {len(nav_map)}")
else:
    print(f" ✗ 净值数据获取失败")

# Fund scale (100M yuan) = shares (100M) x unit NAV. Funds without a NAV keep
# the raw share count.
# NOTE(review): that fallback mixes units (100M shares vs 100M yuan) in one
# column -- consider NaN instead.
unit_nav = df['ts_code'].map(unit_nav_map)
mask = df['fund_scale_yi'].notna() & unit_nav.notna()
df.loc[mask, 'fund_scale_yi'] = (df.loc[mask, 'fund_scale_yi'] * unit_nav[mask]).round(2)
# Write the enriched table back over the source CSV. 'utf-8-sig' prepends a
# BOM (presumably so spreadsheet tools detect the encoding).
df.to_csv(csv_path, index=False, encoding='utf-8-sig')
print(f"\n已保存: {csv_path}")
new_fields = ['fund_scale_yi', 'avg_amount_wan', 'avg_vol_wan', 'latest_nav', 'nav_date']
print(f"新增字段: {', '.join(new_fields)}")
# Data-quality summary: how many rows received a non-null value per field.
separator = '=' * 80
print(f"\n{separator}")
print("数据质量")
print(separator)
for label, column in (('基金规模', 'fund_scale_yi'),
                      ('日均成交额', 'avg_amount_wan'),
                      ('最新净值', 'latest_nav')):
    filled = df[column].notna()
    print(f"{label}: {filled.sum()}/{len(df)} ({filled.mean() * 100:.1f}%)")
# Rotation-strategy ETF pool: one summary row per tracked fund that exists in df.
rule = '=' * 80
print(f"\n{rule}")
print("轮动策略标的池")
print(rule)
pool_etfs = {
    '159915.SZ': '创业板指', '512890.SH': '红利低波',
    '159920.SZ': '恒生指数', '513130.SH': '恒生科技',
    '513100.SH': '纳指100', '513520.SH': '日经225',
    '513030.SH': '德国DAX', '518880.SH': '黄金',
    '159980.SZ': '有色金属', '160723.SZ': '原油',
}


def _fmt_or_na(value, spec='', suffix=''):
    """Format ``value`` with format spec ``spec`` plus ``suffix``; 'N/A' if missing."""
    return f"{value:{spec}}{suffix}" if pd.notna(value) else 'N/A'


print(f"{'代码':<12} {'名称':<8} {'规模(亿)':>10} {'日成交额(万)':>14} {'净值':>8} {'费率':>6} {'类型':<6}")
print('-' * 75)
for code, name in pool_etfs.items():
    matches = df.loc[df['ts_code'] == code]
    if matches.empty:
        # Codes absent from the CSV are skipped silently.
        continue
    rec = matches.iloc[0]
    scale = _fmt_or_na(rec['fund_scale_yi'], '.1f')
    amt = _fmt_or_na(rec['avg_amount_wan'], ',.0f')
    nav = _fmt_or_na(rec['latest_nav'], '.4f')
    fee = _fmt_or_na(rec['mgt_fee'], suffix='%')
    etype = str(rec['etf_type'])
    print(f"{code:<12} {name:<8} {scale:>10} {amt:>14} {nav:>8} {fee:>6} {etype:<6}")