Files
etf/rotation/simple_rotation.py
aszerW 5e11b6b690 fix(rotation): 溢价率缓存增加增量更新逻辑
- preload_premium: 检查缓存日期范围,不足时增量拉取
- 新增 _fetch_premium_api: 拉取并合并新溢价率数据
- 调用时传入 end_date 触发增量检查

修复前: premium CSV存在即返回旧数据,明天9点运行时拿不到最新
修复后: 检测 latest_cached < end_date 时自动拉取增量
2026-06-01 23:56:18 +08:00

917 lines
38 KiB
Python

"""
Simple Rotation Strategy (Daily Iteration)
From A-share trading calendar, iterate daily:
1. Get signal source last N days -> compute momentum
2. Get bond momentum (dynamic threshold)
3. Group selection -> generate holdings
4. Compare with yesterday -> compute T+1 return
5. Update NAV
"""
import os
import sys
import math
import json
import time
import requests
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
# Two directories above this file; also the base for DataCache's default
# cache directory. It is prepended to sys.path so the absolute import of
# `rotation.config_loader` below resolves when this file is executed directly
# as a script (presumably PROJECT_ROOT is the directory containing the
# `rotation` package — confirm against the repo layout).
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.config_loader import load_rotation_config, RotationStrategyConfig
# ============================================================
# Pure functions: momentum
# ============================================================
def weighted_momentum_score(prices: np.ndarray) -> float:
    """Score a price window as annualized_return * R^2 of a weighted log-linear fit.

    Args:
        prices: price series, oldest first. Fewer than 5 points scores 0.0.

    Returns:
        ``(exp(250 * slope) - 1) * R^2``. Here ``slope`` is the daily slope of a
        linear fit to ``log(prices)``, with weights rising linearly from 1
        (oldest) to 2 (newest), so recent days count more. Prices are floored
        at 0.01 before the log. Any non-finite log price scores 0.0, and a
        flat window (zero weighted variance) gets R^2 = 0.

    Note:
        ``np.polyfit`` applies ``w`` to the *unsquared* residuals, so the fit
        effectively weights squared residuals by ``w**2``. The R^2 below uses
        ``w``. This is kept as-is so scores stay comparable with earlier versions.
    """
    if len(prices) < 5:
        return 0.0
    log_prices = np.log(np.clip(prices, 0.01, None))
    if not np.all(np.isfinite(log_prices)):
        return 0.0

    n = len(log_prices)
    t = np.arange(n)
    w = np.linspace(1, 2, n)
    slope, intercept = np.polyfit(t, log_prices, 1, w=w)
    annual = math.exp(slope * 250) - 1

    residuals = log_prices - (slope * t + intercept)
    weighted_mean = np.average(log_prices, weights=w)
    ss_res = np.sum(w * residuals ** 2)
    ss_tot = np.sum(w * (log_prices - weighted_mean) ** 2)
    r_squared = 1 - ss_res / ss_tot if ss_tot > 0 else 0
    return annual * r_squared
def is_crash(prices: np.ndarray) -> bool:
    """Crash filter over the last 4 prices (the 3 most recent daily moves).

    Returns True when either:
      * any of the last 3 daily moves is a drop of more than 5%, or
      * all 3 daily moves are down and together they lose more than 5%.
    Fewer than 4 prices never counts as a crash.
    """
    if len(prices) < 4:
        return False
    p0, p1, p2, p3 = prices[-4:]
    moves = (p3 / p2, p2 / p1, p1 / p0)
    sharp_single_day = min(moves) < 0.95
    slow_slide = all(m < 1 for m in moves) and p3 / p0 < 0.95
    return sharp_single_day or slow_slide
# ============================================================
# Data cache
# ============================================================
class DataCache:
    """CSV-backed cache for OHLCV bars and ETF premium ratios served by a Flask API.

    Bars and ``premium_series`` come from ``{base_url}/api/v1/ohlcv``. Trading
    calendars come from ``{base_url}/api/v1/trading-calendar`` and are not cached.
    Each code gets one bar CSV under ``cache_dir``: ``index_*`` for
    ``adj='raw'``, ``etf_*`` otherwise. Each code also gets one ``premium_*``
    CSV. Later calls only fetch the dates after the newest cached entry.
    """

    def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120):
        self.base_url = base_url.rstrip('/')
        self.api_path = '/api/v1/ohlcv'
        self.timeout = timeout
        if cache_dir is None:
            cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache'
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # premium data cache: {trade_code: {date_str: premium_ratio}}
        self.premium_data: Dict[str, Dict[str, float]] = {}

    @staticmethod
    def _safe_code(code: str) -> str:
        """Make a code usable inside a file name ('=' and '^' become '_')."""
        return code.replace('=', '_').replace('^', '_')

    def _cache_path(self, code: str, adj: str) -> Path:
        prefix = 'index' if adj == 'raw' else 'etf'
        return self.cache_dir / f"{prefix}_{self._safe_code(code)}.csv"

    def _premium_cache_path(self, code: str) -> Path:
        return self.cache_dir / f"premium_{self._safe_code(code)}.csv"

    # ------------------------------------------------------------------
    # OHLCV bars
    # ------------------------------------------------------------------
    def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
        """Return bars for ``code`` up to ``end_date``, cached to CSV.

        A readable, non-empty cache is returned directly when it already spans
        [start_date, end_date]. Otherwise it is extended forward with only the
        missing dates. History before the cached start is not backfilled. An
        unreadable or empty cache falls back to a full fetch. Returns None
        when the API yields nothing.
        """
        cache_path = self._cache_path(code, adj)
        if cache_path.exists():
            try:
                df = self._extend_cached_bars(code, cache_path, start_date, end_date, adj)
            except Exception as e:
                # Corrupt/unparsable cache: log it and refetch everything below.
                print(f" ! {code}: unreadable cache {cache_path.name} ({e}), refetching")
                df = None
            if df is not None:
                return df
        df = self._fetch_api(code, start_date, end_date, adj)
        if df is not None and len(df) > 0:
            df.to_csv(cache_path)
        return df

    def _extend_cached_bars(self, code: str, cache_path: Path, start_date: str,
                            end_date: str, adj: str) -> Optional[pd.DataFrame]:
        """Read the bar cache and append any dates after its last row (None if empty)."""
        df = pd.read_csv(cache_path, index_col='date', parse_dates=True)
        if len(df) == 0:
            return None
        cs = df.index.min().strftime('%Y-%m-%d')
        ce = df.index.max().strftime('%Y-%m-%d')
        if cs <= start_date and ce >= end_date:
            return df
        new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d')
        if new_start <= end_date:
            new_df = self._fetch_api(code, new_start, end_date, adj)
            if new_df is not None and len(new_df) > 0:
                df = pd.concat([df, new_df])
                df = df[~df.index.duplicated(keep='last')].sort_index()
                df.to_csv(cache_path)
        return df

    def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
        """Fetch bars from the Flask API (up to 3 attempts).

        Keeps only open/high/low/close/volume columns. If the response carries
        ``premium_series``, it is exposed in ``df.attrs['premium_series']`` for
        this window. It is also *merged* into the premium cache; an incremental
        bar fetch must not wipe older premium history.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
        for attempt in range(3):
            try:
                resp = requests.get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1)
                        continue
                    print(f" x {code}: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x {code}: {data['error']}")
                    return None
                records = data.get('data', [])
                if not records:
                    return None
                df = pd.DataFrame(records)
                if 'date' in df.columns:
                    df['date'] = pd.to_datetime(df['date'])
                    df = df.set_index('date').sort_index()
                keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
                df = df[keep]
                # Extract and cache premium_series (ETF only)
                premium_series = data.get('premium_series', [])
                if premium_series:
                    df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series}
                    self._merge_premium(code, df.attrs['premium_series'])
                print(f" + {code}: {len(df)} rows ({adj})")
                return df
            except requests.exceptions.Timeout:
                if attempt < 2:
                    continue
                print(f" x {code}: timeout")
                return None
            except Exception as e:
                print(f" x {code}: {e}")
                return None
        return None

    # ------------------------------------------------------------------
    # Premium ratios
    # ------------------------------------------------------------------
    def _read_premium_cache(self, code: str) -> Dict[str, float]:
        """Load {date_str: premium} from the premium CSV ({} if missing/unusable)."""
        cache_path = self._premium_cache_path(code)
        if not cache_path.exists():
            return {}
        try:
            df = pd.read_csv(cache_path)
        except Exception as e:
            print(f" ! premium {code}: unreadable cache ({e})")
            return {}
        if len(df) == 0 or 'date' not in df.columns or 'premium' not in df.columns:
            return {}
        return dict(zip(df['date'].astype(str), df['premium']))

    def _save_premium_cache(self, code: str, premium_dict: Dict[str, float]):
        """Write premium data to CSV, sorted by date (best-effort: failures are logged)."""
        try:
            rows = [
                {'date': d, 'premium': v}
                for d, v in sorted(premium_dict.items(), key=lambda kv: str(kv[0]))
            ]
            pd.DataFrame(rows, columns=['date', 'premium']).to_csv(
                self._premium_cache_path(code), index=False)
        except Exception as e:
            print(f" ! premium {code}: cache write failed ({e})")

    def _merge_premium(self, code: str, new_data: Dict[str, float]) -> Dict[str, float]:
        """Merge new premium points into memory and disk; return the merged dict.

        Existing on-disk history is loaded first when the code is not yet in
        memory, so a partial fetch never overwrites older dates.
        """
        merged = self.premium_data.get(code)
        if merged is None:
            merged = self._read_premium_cache(code)
            self.premium_data[code] = merged
        merged.update(new_data)
        self._save_premium_cache(code, merged)
        return merged

    def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]:
        """Return premium data for an ETF code, updating the cache incrementally.

        Uses the in-memory copy, or else the CSV cache. When ``end_date`` is
        given and the newest cached date is earlier, only the gap is fetched.
        With no cache at all, the full history is fetched. May return None if
        nothing is available.
        """
        cached = self.premium_data.get(code) or self._read_premium_cache(code)
        if cached:
            self.premium_data[code] = cached
            latest_cached = max(cached)
            if not end_date or latest_cached >= end_date:
                return cached
            try:
                fetch_start = (pd.Timestamp(latest_cached) + timedelta(days=1)).strftime('%Y-%m-%d')
            except (ValueError, TypeError):
                fetch_start = None  # unparsable cached date: refetch full history below
            if fetch_start is not None:
                self._fetch_premium_api(code, fetch_start, end_date)
                return self.premium_data[code]
        # No usable cache: fetch full history from API
        self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
        return self.premium_data.get(code)

    def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
        """Fetch premium_series for [start_date, end_date] and merge it into the cache.

        Best-effort: on failure the existing cache is left untouched.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
        for attempt in range(3):
            try:
                resp = requests.get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1)
                        continue
                    return
                data = resp.json()
                if 'error' in data:
                    return
                premium_series = data.get('premium_series', [])
                if premium_series:
                    new_data = {item['date']: item['premium'] for item in premium_series}
                    merged = self._merge_premium(code, new_data)
                    print(f" + premium {code}: +{len(new_data)} days (total {len(merged)})")
                return
            except requests.exceptions.Timeout:
                if attempt < 2:
                    continue
                return
            except Exception as e:
                print(f" x premium {code}: {e}")
                return

    # ------------------------------------------------------------------
    # Trading calendar
    # ------------------------------------------------------------------
    def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]:
        """Fetch trading dates for ``market`` in [start_date, end_date] (not cached).

        Returns an empty DatetimeIndex when the API lists no dates. Returns
        None on an API error or after 3 failed attempts.
        """
        url = f"{self.base_url}/api/v1/trading-calendar"
        params = {'market': market, 'start': start_date, 'end': end_date}
        for attempt in range(3):
            try:
                resp = requests.get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1)
                        continue
                    print(f" x calendar: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x calendar: {data['error']}")
                    return None
                dates = data.get('trading_dates', [])
                if not dates:
                    return pd.DatetimeIndex([])
                result = pd.DatetimeIndex(dates)
                print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
                return result
            except Exception as e:
                if attempt < 2:
                    continue
                print(f" x calendar: {e}")
                return None
        return None
# ============================================================
# Core strategy class
# ============================================================
class SimpleRotationStrategy:
    """
    Simple rotation strategy (daily iteration)

    Flow:
    1. Load config
    2. Preload all historical data (index raw + ETF hfq; BOND trade codes raw)
    3. Fetch A-share trading calendar
    4. For each trading day T:
       - Compute momentum from bars up to T-1 (last n_days bars)
       - Group selection -> generate holdings
       - Compare with yesterday -> compute T's return (trades at T's open)
       - Update NAV
    5. Export results
    """

    # Starting NAV. run() compounds daily returns onto it, and
    # _compute_metrics() measures total return and drawdown against it.
    INITIAL_NAV = 1.0

    def __init__(self, config_path: str = None):
        if config_path is None:
            config_path = Path(__file__).parent / 'config_simple.yaml'
        self.config: RotationStrategyConfig = load_rotation_config(str(config_path))
        # Strategy params
        self.n_days = self.config.factor.n_days
        self.select_num = self.config.rotation.select_num
        self.trade_cost = self.config.rebalance.trade_cost
        # Dynamic threshold: non-BOND candidates must reach bond momentum * ratio
        threshold = self.config.rotation.threshold
        self.use_dynamic_threshold = (threshold.mode.value == 'dynamic')
        self.bond_code = threshold.dynamic.reference if threshold.dynamic else None
        self.bond_ratio = threshold.dynamic.ratio if threshold.dynamic else 1.0
        self.fallback_value = threshold.dynamic.fallback_value if threshold.dynamic else 0.0
        # Signal codes and mappings
        self.signal_codes = []
        self.signal_to_trade = {}
        self.code_to_group = {}
        self.trade_code_to_group = {}
        for asset in self.config.asset_pools.assets.values():
            self.signal_codes.append(asset.signal_source)
            self.signal_to_trade[asset.signal_source] = asset.trade_source
            self.code_to_group[asset.signal_source] = asset.group
            self.trade_code_to_group[asset.trade_source] = asset.group
        # Data source: first configured source, default host if it has no URL
        data_source = self.config.data.sources[0]
        base_url = data_source.url or 'https://k3s.tokenpluse.xyz'
        self.data_cache = DataCache(base_url=base_url, timeout=data_source.timeout)
        # Preloaded data
        self.index_data: Dict[str, pd.DataFrame] = {}
        self.etf_data: Dict[str, pd.DataFrame] = {}
        # Results
        self.daily_records: List[dict] = []
        self.trading_calendar: Optional[pd.DatetimeIndex] = None

    # ------------------------------------------------------------------
    # Setup helpers
    # ------------------------------------------------------------------
    def _backtest_window(self) -> Tuple[str, str]:
        """Return (start_date, end_date) from config; a missing end_date means today."""
        start_date = self.config.backtest.start_date
        end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
        return start_date, end_date

    @staticmethod
    def _signal_date(calendar, i: int) -> pd.Timestamp:
        """Date whose close drives the decision for trading day ``calendar[i]``.

        Signals are produced at 9:00 on day T, before T opens, so only bars up
        to T-1 may be used. For i > 0 that is the previous calendar entry. For
        the first day the previous trading day is not in ``calendar``, so the
        previous calendar day is used instead. Assuming bars are stamped at
        midnight, every ``<= signal_date`` lookup then sees only bars strictly
        before T, which avoids look-ahead on day one.
        """
        if i > 0:
            return calendar[i - 1]
        return calendar[0] - pd.Timedelta(days=1)

    def _preload_data(self):
        """Preload signal indices, trade ETFs and ETF premium data."""
        start_date, end_date = self._backtest_window()
        # Calendar-day buffer so the first signal already has n_days bars.
        preload_start = (pd.Timestamp(start_date) - timedelta(days=self.n_days * 2)).strftime('%Y-%m-%d')
        print("\n[1/4] Preloading signal sources (index raw)...")
        for code in self.signal_codes:
            df = self.data_cache.preload(code, preload_start, end_date, adj='raw')
            if df is not None:
                self.index_data[code] = df
        print(f"\n Signal: {len(self.index_data)}/{len(self.signal_codes)} OK")
        print("\n[2/4] Preloading trade sources (ETF hfq)...")
        trade_codes = sorted(set(self.signal_to_trade.values()))
        # Trade codes of BOND-group assets are loaded unadjusted, the rest hfq.
        bond_trade_codes = {
            a.trade_source for a in self.config.asset_pools.assets.values() if a.group == 'BOND'
        }
        for code in trade_codes:
            adj = 'raw' if code in bond_trade_codes else 'hfq'
            df = self.data_cache.preload(code, preload_start, end_date, adj=adj)
            if df is not None:
                self.etf_data[code] = df
        # Load premium data cache for all ETF trade codes
        for code in trade_codes:
            self.data_cache.preload_premium(code, end_date=end_date)
        print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK, premium: {len(self.data_cache.premium_data)} loaded")

    # ------------------------------------------------------------------
    # Signals
    # ------------------------------------------------------------------
    def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
        """Momentum of ``signal_code`` from its last n_days closes on or before ``date``.

        Returns None without data or with fewer than n_days bars. Returns 0.0
        when the crash filter fires.
        """
        df = self.index_data.get(signal_code)
        if df is None:
            return None
        recent = df.loc[df.index <= date]
        if len(recent) < self.n_days:
            return None
        prices = recent['close'].values[-self.n_days:]
        if len(prices) >= 4 and is_crash(prices):
            return 0.0
        return weighted_momentum_score(prices)

    def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
        """
        Generate rotation signals (group competition + dynamic threshold + BOND fill)

        Returns:
            (holdings, factors, bond_momentum): the sorted selected holdings,
            all computed factor scores, and the bond momentum used for the
            threshold (None when the dynamic threshold is inactive).

        Logic (identical to V2):
        1. Each group: select Top 1 (non-BOND groups must reach bond_momentum * ratio)
        2. From all group winners: sort by momentum, select Top select_num
        3. Fill remaining slots with BOND. The bond code may then appear several
           times; each occurrence is one equal-weight slot in
           _calculate_daily_return. No fill happens if the bond code already won.
        """
        factors: Dict[str, float] = {}
        for code in self.signal_codes:
            score = self._compute_momentum(code, date)
            if score is not None:
                factors[code] = score
        if not factors:
            return [], {}, None
        bond_momentum = None
        if self.use_dynamic_threshold and self.bond_code:
            bond_momentum = self._compute_momentum(self.bond_code, date)
            if bond_momentum is None:
                bond_momentum = self.fallback_value
        groups = self.config.asset_pools.by_group
        selected_by_group: Dict[str, Tuple[str, float]] = {}
        for group_name, assets in groups.items():
            group_codes = [a.signal_source for a in assets.values()]
            group_factors = {c: factors[c] for c in group_codes if c in factors}
            if not group_factors:
                continue
            if group_name != 'BOND' and bond_momentum is not None:
                thresh = bond_momentum * self.bond_ratio
                group_factors = {c: s for c, s in group_factors.items() if s >= thresh}
                if not group_factors:
                    continue
            top_code = max(group_factors, key=group_factors.get)
            selected_by_group[group_name] = (top_code, group_factors[top_code])
        if not selected_by_group:
            return [], factors, bond_momentum
        candidates = sorted(selected_by_group.values(), key=lambda x: x[1], reverse=True)
        final_holdings = [code for code, _ in candidates[:self.select_num]]
        if len(final_holdings) < self.select_num and self.bond_code:
            if self.bond_code not in final_holdings:
                n_slots = self.select_num - len(final_holdings)
                final_holdings.extend([self.bond_code] * n_slots)
        return sorted(final_holdings), factors, bond_momentum

    # ------------------------------------------------------------------
    # Returns
    # ------------------------------------------------------------------
    def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
        """Get the trade bar for ``date``: {open, close, prev_close}.

        Returns None unless a bar dated exactly ``date`` exists, plus at least
        one earlier bar. Falling back to an older bar would replay that bar's
        move as today's return. This covers suspensions, data gaps, and a 9:00
        run before day T's bar is published.
        Index-like trade codes may lack 'open', so close is used as a fallback.
        """
        df = self.etf_data.get(trade_code)
        if df is None:
            return None
        recent = df.loc[df.index <= date]
        if len(recent) < 2:
            return None
        if recent.index[-1].normalize() != date.normalize():
            return None
        today_row = recent.iloc[-1]
        prev_row = recent.iloc[-2]
        close = float(today_row['close'])
        prev_close = float(prev_row['close'])
        if pd.isna(close) or pd.isna(prev_close):
            return None
        # Handle missing open (common for index data like 931862.CSI)
        open_price = float(today_row['open']) if 'open' in today_row.index else close
        if pd.isna(open_price) or open_price == 0:
            open_price = close
        return {
            'open': open_price,
            'close': close,
            'prev_close': prev_close,
        }

    def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
        """
        Compute day T's portfolio return (T+1 execution of the T-1 signal):
        - Hold (in old and new): close-to-close (prev_close -> close)
        - Sell (old only): prev_close -> open (sold at T's open)
        - Buy (new only): open -> close (bought at T's open)

        Each list entry is one slot weighted 1/len(old_holdings), or
        1/len(new_holdings) for the very first allocation. Codes without a
        bar on ``date`` contribute 0. ``trade_cost`` is subtracted once when
        ``is_rebalance`` is True.
        """
        if not old_holdings:
            if not new_holdings:
                return 0.0
            weight = 1.0 / len(new_holdings)
            ret = 0.0
            for code in new_holdings:
                tc = self.signal_to_trade.get(code, code)
                p = self._get_etf_prices(tc, date)
                if p and p['open'] > 0:
                    ret += weight * (p['close'] - p['open']) / p['open']
            if is_rebalance:
                ret -= self.trade_cost
            return ret
        old_set = set(old_holdings)
        new_set = set(new_holdings)
        weight = 1.0 / len(old_holdings)
        daily_return = 0.0
        for code in old_holdings:
            tc = self.signal_to_trade.get(code, code)
            p = self._get_etf_prices(tc, date)
            if p is None or p['prev_close'] == 0:
                continue
            if code in new_set:
                r = (p['close'] - p['prev_close']) / p['prev_close']
            else:
                r = (p['open'] - p['prev_close']) / p['prev_close']
            if not math.isnan(r):
                daily_return += weight * r
        for code in new_holdings:
            if code not in old_set:
                tc = self.signal_to_trade.get(code, code)
                p = self._get_etf_prices(tc, date)
                if p and p['open'] > 0 and not math.isnan(p['close']):
                    r = (p['close'] - p['open']) / p['open']
                    if not math.isnan(r):
                        daily_return += weight * r
        if is_rebalance:
            daily_return -= self.trade_cost
        return daily_return

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------
    def run(self) -> dict:
        """Main backtest loop.

        Returns {'metrics': ..., 'daily_records': ...}, or {} when the trading
        calendar cannot be fetched.
        """
        print("=" * 60)
        print(" Simple Rotation Strategy (Daily Iteration)")
        print("=" * 60)
        self._preload_data()
        print("\n[3/4] Fetching A-share trading calendar...")
        start_date, end_date = self._backtest_window()
        self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date)
        if self.trading_calendar is None or len(self.trading_calendar) == 0:
            print(" x Calendar fetch failed")
            return {}
        n_total = len(self.trading_calendar)
        print(f"\n[4/4] Backtesting ({n_total} trading days)...")
        # Fresh results, so calling run() twice does not duplicate records.
        self.daily_records = []
        current_holdings: List[str] = []
        nav = self.INITIAL_NAV
        rebalance_count = 0
        for i, date in enumerate(self.trading_calendar):
            # Signal at 9:00 on T (market not open yet) -> only T-1 closes exist.
            # Execution happens at 9:30 using T's ETF prices.
            signal_date = self._signal_date(self.trading_calendar, i)
            new_holdings, factors, bond_momentum = self._generate_signals(signal_date)
            # The initial allocation (no current holdings) does not count as a rebalance.
            is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0
            # Return uses T's ETF prices (open for buy/sell, close for hold)
            daily_return = self._calculate_daily_return(
                current_holdings, new_holdings, date, is_rebalance
            )
            nav *= (1 + daily_return)
            if is_rebalance:
                rebalance_count += 1
            added = set(new_holdings) - set(current_holdings)
            removed = set(current_holdings) - set(new_holdings)
            # Bond threshold value for the detail record
            threshold_val = 0.0
            if self.use_dynamic_threshold and bond_momentum is not None:
                threshold_val = round(bond_momentum * self.bond_ratio, 6)
            self.daily_records.append({
                'date': date.strftime('%Y-%m-%d'),
                'nav': round(nav, 6),
                'daily_return': round(daily_return, 6),
                'is_rebalance': is_rebalance,
                'holdings': sorted(new_holdings),
                'added': sorted(added),
                'removed': sorted(removed),
                'factors': {k: round(v, 6) for k, v in factors.items()},
                'threshold': threshold_val,
            })
            current_holdings = new_holdings
            if (i + 1) % 100 == 0 or i == n_total - 1:
                print(f" [{i+1}/{n_total}] "
                      f"NAV: {nav:.4f} | Rebal: {rebalance_count}")
        metrics = self._compute_metrics(rebalance_count)
        print("\n" + "=" * 60)
        print(" Backtest Complete")
        print("=" * 60)
        if self.daily_records:
            print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}")
            print(f" Trading days: {len(self.daily_records)}")
            print(f" Total return: {metrics['total_return']:.2%}")
            print(f" Annual return: {metrics['annual_return']:.2%}")
            print(f" Max drawdown: {metrics['max_drawdown']:.2%}")
            print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}")
            print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}")
            print(f" Win rate: {metrics['win_rate']:.2%}")
            print(f" Rebalances: {rebalance_count}")
        print("=" * 60)
        return {
            'metrics': metrics,
            'daily_records': self.daily_records,
        }

    def _compute_metrics(self, rebalance_count: int) -> dict:
        """Compute performance metrics from ``self.daily_records`` ({} if empty).

        Total return and drawdown are measured from INITIAL_NAV, the capital
        before day 1. The first recorded NAV already includes day 1's return.
        Annualization uses 252 trading days; Sharpe assumes a zero risk-free rate.
        """
        if not self.daily_records:
            return {}
        df = pd.DataFrame(self.daily_records)
        nav = df['nav']
        returns = df['daily_return']
        n_days = len(df)
        growth = nav.iloc[-1] / self.INITIAL_NAV
        total_return = growth - 1
        # A non-positive growth factor would make the fractional power complex.
        annual_return = growth ** (252 / n_days) - 1 if growth > 0 else -1.0
        peak = nav.cummax().clip(lower=self.INITIAL_NAV)
        max_drawdown = ((nav - peak) / peak).min()
        std = returns.std()
        sharpe = returns.mean() / std * np.sqrt(252) if std > 0 else 0
        calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0
        non_zero = returns[returns != 0]
        win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0
        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe,
            'calmar_ratio': calmar,
            'win_rate': win_rate,
            'n_days': n_days,
            'rebalance_count': rebalance_count,
        }

    # ------------------------------------------------------------------
    # Detail JSON helpers (V2-compatible)
    # ------------------------------------------------------------------
    def _build_meta_codes(self) -> dict:
        """Build meta.codes mapping: signal_code -> {name, etf, market}"""
        codes = {}
        for code, asset in self.config.asset_pools.assets.items():
            codes[asset.signal_source] = {
                'name': getattr(asset, 'name', code),
                'etf': asset.trade_source,
                'market': asset.group,
            }
        return codes

    @staticmethod
    def _last_close(df: Optional[pd.DataFrame], date: pd.Timestamp) -> Optional[float]:
        """Close of the last bar on or before ``date`` (None if no such bar)."""
        if df is None:
            return None
        rows = df.loc[df.index <= date]
        if len(rows) == 0:
            return None
        return float(rows.iloc[-1]['close'])

    @staticmethod
    def _close_to_close(df: Optional[pd.DataFrame], date: pd.Timestamp) -> Optional[float]:
        """Return between the last two bars on or before ``date``, rounded to 6 dp."""
        if df is None:
            return None
        rows = df.loc[df.index <= date]
        if len(rows) < 2:
            return None
        c_today = float(rows.iloc[-1]['close'])
        c_prev = float(rows.iloc[-2]['close'])
        if c_prev > 0 and not pd.isna(c_today):
            return round((c_today - c_prev) / c_prev, 6)
        return None

    def _get_index_close(self, code: str, date: pd.Timestamp) -> Optional[float]:
        """Get index close price on or before date."""
        return self._last_close(self.index_data.get(code), date)

    def _get_etf_close(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
        """Get ETF close price on or before date."""
        return self._last_close(self.etf_data.get(trade_code), date)

    def _get_daily_returns(self, code: str, date: pd.Timestamp) -> Tuple[Optional[float], Optional[float]]:
        """Get (index_return, etf_return_ctc) for a signal code on a given date."""
        trade_code = self.signal_to_trade.get(code, code)
        idx_ret = self._close_to_close(self.index_data.get(code), date)
        etf_ret = self._close_to_close(self.etf_data.get(trade_code), date)
        return idx_ret, etf_ret

    def _compute_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
        """Get real premium from API data cache: (ETF_price - NAV) / NAV.

        Returns None for BOND, or when premium data for the date is missing.
        Missing includes NaN, which is what empty values become after a CSV
        round-trip; NaN would also be invalid in the exported JSON.
        """
        group = self.trade_code_to_group.get(trade_code, '')
        if group == 'BOND':
            return None
        premium_dict = self.data_cache.premium_data.get(trade_code)
        if not premium_dict:
            return None
        val = premium_dict.get(date.strftime('%Y-%m-%d'))
        if val is None or pd.isna(val):
            return None
        return round(float(val), 6)

    def _build_day_assets(self, record: dict, date: pd.Timestamp,
                          entry_info: Dict[str, dict]) -> dict:
        """Build V2-compatible per-asset detail dict for one day.

        Args:
            record: one entry of ``self.daily_records``.
            date: the trading day of ``record``.
            entry_info: signal_code -> {entry_date, entry_price_etf, entry_price_idx}
                for currently held codes.
        """
        factors = record.get('factors', {})
        holdings = set(record['holdings'])
        threshold = record.get('threshold', 0.0)
        # As in _generate_signals, the bond threshold only filters when dynamic
        # mode is on and a bond reference exists. A threshold of exactly 0.0
        # still filters, because negative momentum is below it.
        threshold_active = bool(self.use_dynamic_threshold and self.bond_code)
        # Rank: sort all codes by momentum descending, rank 1 = highest
        sorted_codes = sorted(factors.keys(), key=lambda c: factors[c], reverse=True)
        rank_map = {c: i + 1 for i, c in enumerate(sorted_codes)}
        assets = {}
        for code in self.signal_codes:
            momentum = factors.get(code)
            rank = rank_map.get(code)
            if momentum is None:
                above_thresh = False
            elif threshold_active:
                above_thresh = momentum >= threshold
            else:
                above_thresh = True
            is_held = code in holdings
            trade_code = self.signal_to_trade.get(code, code)
            idx_close = self._get_index_close(code, date)
            etf_close = self._get_etf_close(trade_code, date)
            idx_ret, etf_ret_ctc = self._get_daily_returns(code, date)
            premium = self._compute_premium(trade_code, date)
            # Entry / holding info
            ei = entry_info.get(code) if is_held else None
            entry_date = ei['entry_date'] if ei else None
            entry_price_etf = ei['entry_price_etf'] if ei else None
            entry_price_idx = ei['entry_price_idx'] if ei else None
            holding_days = 0
            cum_ret_etf = None
            cum_ret_idx = None
            if ei and is_held:
                entry_dt = pd.Timestamp(ei['entry_date'])
                # Calendar days since entry, counting the entry day as 1.
                holding_days = max(int((date - entry_dt).days), 1)
                ep_etf = ei.get('entry_price_etf')
                if ep_etf and ep_etf > 0 and etf_close is not None and not pd.isna(etf_close):
                    cum_ret_etf = round((etf_close - ep_etf) / ep_etf, 6)
                ep_idx = ei.get('entry_price_idx')
                if ep_idx and ep_idx > 0 and idx_close is not None and not pd.isna(idx_close):
                    cum_ret_idx = round((idx_close - ep_idx) / ep_idx, 6)
            assets[code] = {
                'momentum': float(momentum) if momentum is not None else None,
                'rank': rank,
                'threshold': float(threshold),
                'above_threshold': bool(above_thresh),
                'index_close': float(idx_close) if idx_close is not None else None,
                'etf_close': float(etf_close) if etf_close is not None else None,
                'index_return': float(idx_ret) if idx_ret is not None else None,
                'etf_return_ctc': float(etf_ret_ctc) if etf_ret_ctc is not None else None,
                'premium': float(premium) if premium is not None else None,
                'is_held': bool(is_held),
                'entry_date': entry_date,
                'entry_price_etf': float(entry_price_etf) if entry_price_etf is not None else None,
                'entry_price_idx': float(entry_price_idx) if entry_price_idx is not None else None,
                'holding_days': holding_days,
                'cum_return_etf': float(cum_ret_etf) if cum_ret_etf is not None else None,
                'cum_return_idx': float(cum_ret_idx) if cum_ret_idx is not None else None,
            }
        return assets

    def _track_entries(self, tracked: Dict[str, dict], added, removed,
                       date: pd.Timestamp, signal_date: pd.Timestamp):
        """Record entry info for newly added codes and drop removed ones (in place).

        The ETF entry price is T's open, the actual buy price. The index entry
        price is the close on ``signal_date``, i.e. the data behind the decision.
        """
        for code in added:
            trade_code = self.signal_to_trade.get(code, code)
            etf_prices = self._get_etf_prices(trade_code, date)
            tracked[code] = {
                'entry_date': date.strftime('%Y-%m-%d'),
                'entry_price_etf': etf_prices['open'] if etf_prices else None,
                'entry_price_idx': self._get_index_close(code, signal_date),
            }
        for code in removed:
            tracked.pop(code, None)

    # ------------------------------------------------------------------
    # Export
    # ------------------------------------------------------------------
    def export_results(self, output_dir: str = None):
        """Export backtest results to CSV and JSON (V2-compatible detail format).

        Writes simple_rotation_{nav,signals}.csv and
        simple_rotation_{detail,metrics}.json into ``output_dir``. The default
        is a ``results`` directory next to this file.
        """
        if not self.daily_records:
            print(" x No results to export")
            return
        if output_dir is None:
            output_dir = Path(__file__).parent / 'results'
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        df = pd.DataFrame(self.daily_records)
        # NAV curve
        nav_path = output_dir / 'simple_rotation_nav.csv'
        df[['date', 'nav', 'daily_return']].to_csv(nav_path, index=False)
        print(f" + NAV: {nav_path}")
        # Signals
        sig_path = output_dir / 'simple_rotation_signals.csv'
        df[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(sig_path, index=False)
        print(f" + Signals: {sig_path}")
        # Detail JSON (V2-compatible format)
        detail_path = output_dir / 'simple_rotation_detail.json'
        days_out = []
        # Rebuild entry info day by day, using the same signal dates as run()
        tracked_entry: Dict[str, dict] = {}
        prev_holdings: List[str] = []
        date_list = [pd.Timestamp(rec['date']) for rec in self.daily_records]
        for i, (rec, date) in enumerate(zip(self.daily_records, date_list)):
            signal_date = self._signal_date(date_list, i)
            holdings = rec['holdings']
            added = set(holdings) - set(prev_holdings)
            removed = set(prev_holdings) - set(holdings)
            self._track_entries(tracked_entry, added, removed, date, signal_date)
            # Build signals dict: {code: 1} for selected holdings
            signals = {c: 1 for c in holdings}
            assets = self._build_day_assets(rec, date, tracked_entry)
            days_out.append({
                'date': rec['date'],
                'nav': rec['nav'],
                'daily_return': rec['daily_return'],
                'is_rebalance': rec['is_rebalance'],
                'signals': signals,
                'holdings': holdings,
                'added': rec['added'],
                'removed': rec['removed'],
                'assets': assets,
            })
            prev_holdings = holdings
        detail = {
            'meta': {
                'mode': 'Simple: Daily Iteration',
                'start_date': self.config.backtest.start_date,
                'end_date': self.daily_records[-1]['date'],
                'total_days': len(self.daily_records),
                'select_num': self.select_num,
                'n_days': self.n_days,
                'trade_cost': self.trade_cost,
                'bond_threshold': {
                    'enabled': self.use_dynamic_threshold,
                    'bond_code': self.bond_code,
                    'ratio': self.bond_ratio,
                },
                'codes': self._build_meta_codes(),
            },
            'days': days_out,
        }
        with open(detail_path, 'w', encoding='utf-8') as f:
            json.dump(detail, f, ensure_ascii=False, indent=2)
        print(f" + Detail: {detail_path} ({len(days_out)} days)")
        # Metrics JSON
        metrics = self._compute_metrics(sum(1 for r in self.daily_records if r['is_rebalance']))
        metrics_path = output_dir / 'simple_rotation_metrics.json'
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(metrics, f, ensure_ascii=False, indent=2)
        print(f" + Metrics: {metrics_path}")
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    # Provide a default API endpoint unless one is already configured.
    # NOTE(review): FLASK_API_URL is not read anywhere in this file; presumably
    # it is consumed by the config loader or other modules — confirm.
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')
    strategy = SimpleRotationStrategy()
    # run() returns {} when the trading calendar cannot be fetched.
    if strategy.run():
        strategy.export_results()