Files
etf/rotation/simple_rotation.py
aszerW 6d0b928894 fix(rotation): 消除前视偏差 + V2兼容detail导出
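Timing-alignment fix:
- Signal generation now uses T-1 close data (signals are generated at 9 AM, before day T opens).
- entry_price_etf now uses day T's open (the actual buy price).
- Annualized return: 52.66% → 25.12% (removes the look-ahead inflation: over this period the earlier total return was roughly 4× the corrected figure, and the annualized return about 2.1×).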

V2-compatible detail JSON:
- _generate_signals returns (holdings, factors, bond_momentum)
- 6 helper methods: build_meta_codes, get_index/etf_close, daily_returns, premium, day_assets
- Full daily record of 11 assets × 16 fields (momentum/rank/holding_days/cum_return, etc.)
- export_results applies the same entry_info timing fix

Backtest (2020-01-10 ~ 2026-06-01, 1545 trading days):
- Total return 295.14%, annualized 25.12%
- Max drawdown -14.74%, Sharpe 1.33, Calmar 1.70
2026-06-01 23:13:43 +08:00

823 lines
33 KiB
Python

"""
Simple Rotation Strategy (Daily Iteration)
From A-share trading calendar, iterate daily:
1. Get signal source last N days -> compute momentum
2. Get bond momentum (dynamic threshold)
3. Group selection -> generate holdings
4. Compare with yesterday -> compute T+1 return
5. Update NAV
"""
import os
import sys
import math
import json
import time
import requests
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
# Directory two levels above this file. It is put first on sys.path so the
# `rotation.*` package import below resolves when this module is executed
# directly as a script; DataCache also places its default CSV cache under it.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.config_loader import load_rotation_config, RotationStrategyConfig
# ============================================================
# Pure functions: momentum
# ============================================================
def weighted_momentum_score(prices: np.ndarray) -> float:
    """Score a price window by the strength of its weighted log-linear trend.

    A straight line is fit to ``log(prices)`` with weights rising linearly
    from 1 (oldest bar) to 2 (newest bar). The slope is annualized over 250
    periods and multiplied by the weighted R^2 of the fit.

    Args:
        prices: 1-D sequence of prices, oldest first. Values below 0.01 are
            clipped to 0.01 before the log is taken.

    Returns:
        ``annualized_return * R^2``; 0.0 when fewer than 5 prices are given
        or the log-prices contain NaN/inf.

    Note:
        ``np.polyfit`` applies ``w`` to the *unsquared* residuals (the fit
        effectively weights squared residuals by ``w**2``) whereas the R^2
        below weights squared residuals by ``w``. Kept as-is so results do
        not change.
    """
    if len(prices) < 5:
        return 0.0
    log_p = np.log(np.clip(prices, 0.01, None))
    if not np.all(np.isfinite(log_p)):
        return 0.0

    t = np.arange(len(log_p))
    w = np.linspace(1, 2, len(log_p))
    slope, intercept = np.polyfit(t, log_p, 1, w=w)
    yearly = math.exp(slope * 250) - 1

    fitted = slope * t + intercept
    resid_ss = np.sum(w * (log_p - fitted) ** 2)
    total_ss = np.sum(w * (log_p - np.average(log_p, weights=w)) ** 2)
    r_squared = 1 - resid_ss / total_ss if total_ss > 0 else 0
    return yearly * r_squared
def is_crash(prices: np.ndarray) -> bool:
    """Flag a crash in the most recent three daily moves.

    Only the last four prices are inspected (three day-over-day ratios).
    The window is flagged when either:

    * any one of the three daily ratios is below 0.95 (a single-day drop of
      more than 5%), or
    * all three days closed lower and the price over the whole window fell
      below 0.95 of its starting value (more than 5% over three down days).

    Returns False when fewer than four prices are supplied.
    """
    if len(prices) < 4:
        return False
    p0, p1, p2, p3 = prices[-4:]
    # Newest ratio first, matching the historical evaluation order.
    ratios = (p3 / p2, p2 / p1, p1 / p0)
    single_day_drop = min(ratios) < 0.95
    if single_day_drop:
        return single_day_drop
    return all(r < 1 for r in ratios) and p3 / p0 < 0.95
# ============================================================
# Data cache
# ============================================================
class DataCache:
    """CSV-cached data fetching for index (raw) and ETF (hfq) series.

    Each (code, adjustment) pair is cached in one CSV under ``cache_dir``;
    the file prefix is ``index_`` for ``adj='raw'`` and ``etf_`` otherwise.

    NOTE: the cache is only ever extended forward in time. If a later call
    asks for an earlier start than the cached file holds, the returned frame
    still begins at the cached first date; delete the CSV to refetch.
    """

    # Attempts per HTTP request, and the pause between attempts (seconds).
    _MAX_ATTEMPTS = 3
    _RETRY_DELAY_SEC = 1

    def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120):
        """
        Args:
            base_url: API root; a trailing '/' is stripped.
            cache_dir: Directory for cached CSVs (created if missing);
                defaults to ``PROJECT_ROOT/data/simple_rotation_cache``.
            timeout: Per-request timeout in seconds.
        """
        self.base_url = base_url.rstrip('/')
        self.api_path = '/api/v1/ohlcv'
        self.timeout = timeout
        if cache_dir is None:
            cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache'
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _cache_path(self, code: str, adj: str) -> Path:
        """CSV path for ``code``; '=' and '^' are replaced to keep names file-safe."""
        prefix = 'index' if adj == 'raw' else 'etf'
        safe_code = code.replace('=', '_').replace('^', '_')
        return self.cache_dir / f"{prefix}_{safe_code}.csv"

    def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
        """Return history for ``code`` over [start_date, end_date], using the CSV cache.

        Dates are 'YYYY-MM-DD' strings. A cache that already covers the range
        is returned as-is; otherwise rows after the cached end are fetched and
        appended. An empty or unusable cache falls back to a full fetch.

        Returns:
            OHLCV frame indexed by date, or None when nothing could be fetched.
        """
        cache_path = self._cache_path(code, adj)
        if cache_path.exists():
            try:
                cached = self._load_or_extend_cache(code, cache_path, start_date, end_date, adj)
            except Exception as e:
                # Best effort: a corrupt/unreadable cache is replaced by a full refetch.
                print(f" ! {code}: cache unusable ({e}); refetching")
                cached = None
            if cached is not None:
                return cached
        df = self._fetch_api(code, start_date, end_date, adj)
        if df is not None and len(df) > 0:
            df.to_csv(cache_path)
        return df

    def _load_or_extend_cache(self, code: str, cache_path: Path, start_date: str,
                              end_date: str, adj: str) -> Optional[pd.DataFrame]:
        """Read the cached CSV and append any rows after its last date (None if empty)."""
        df = pd.read_csv(cache_path, index_col='date', parse_dates=True)
        if len(df) == 0:
            return None
        cached_start = df.index.min().strftime('%Y-%m-%d')
        cached_end = df.index.max().strftime('%Y-%m-%d')
        if cached_start <= start_date and cached_end >= end_date:
            return df
        new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d')
        if new_start <= end_date:
            new_df = self._fetch_api(code, new_start, end_date, adj)
            if new_df is not None and len(new_df) > 0:
                df = pd.concat([df, new_df])
                df = df[~df.index.duplicated(keep='last')]
                df.to_csv(cache_path)
        return df

    def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
        """Fetch OHLCV rows from the Flask API.

        Timeouts, connection errors and non-200 responses are retried (with a
        short pause) up to ``_MAX_ATTEMPTS`` times. Malformed payloads are not
        retried. Never raises; returns None on failure or an empty result.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
        for attempt in range(self._MAX_ATTEMPTS):
            retries_left = attempt < self._MAX_ATTEMPTS - 1
            try:
                resp = requests.get(url, params=params, timeout=self.timeout)
            except requests.exceptions.Timeout:
                if retries_left:
                    time.sleep(self._RETRY_DELAY_SEC)
                    continue
                print(f" x {code}: timeout")
                return None
            except requests.exceptions.RequestException as e:
                # Connection resets / DNS hiccups are as transient as timeouts.
                if retries_left:
                    time.sleep(self._RETRY_DELAY_SEC)
                    continue
                print(f" x {code}: {e}")
                return None
            if resp.status_code != 200:
                if retries_left:
                    time.sleep(self._RETRY_DELAY_SEC)
                    continue
                print(f" x {code}: HTTP {resp.status_code}")
                return None
            try:
                return self._parse_ohlcv(code, resp.json(), adj)
            except Exception as e:
                # Malformed payload: retrying would return the same body.
                print(f" x {code}: {e}")
                return None
        return None

    @staticmethod
    def _parse_ohlcv(code: str, data, adj: str) -> Optional[pd.DataFrame]:
        """Turn an API JSON payload into a date-indexed, date-sorted OHLCV frame.

        Keeps only the open/high/low/close/volume columns that are present.
        Returns None for an error payload, no records, or records without a
        'date' field (a frame without a date index would break every later
        ``index <= date`` lookup).
        """
        if not isinstance(data, dict):
            print(f" x {code}: unexpected payload type {type(data).__name__}")
            return None
        if 'error' in data:
            print(f" x {code}: {data['error']}")
            return None
        records = data.get('data', [])
        if not records:
            return None
        df = pd.DataFrame(records)
        if 'date' not in df.columns:
            print(f" x {code}: response has no 'date' column")
            return None
        df['date'] = pd.to_datetime(df['date'])
        df = df.set_index('date').sort_index()
        keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
        df = df[keep]
        print(f" + {code}: {len(df)} rows ({adj})")
        return df

    def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]:
        """Fetch trading dates for ``market`` between the two dates from the API.

        Returns:
            A sorted, de-duplicated DatetimeIndex (empty if the API lists no
            dates), or None when the request fails or the API reports an error.
        """
        url = f"{self.base_url}/api/v1/trading-calendar"
        params = {'market': market, 'start': start_date, 'end': end_date}
        for attempt in range(self._MAX_ATTEMPTS):
            retries_left = attempt < self._MAX_ATTEMPTS - 1
            try:
                resp = requests.get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if retries_left:
                        time.sleep(self._RETRY_DELAY_SEC)
                        continue
                    print(f" x calendar: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x calendar: {data['error']}")
                    return None
                dates = data.get('trading_dates', [])
                if not dates:
                    return pd.DatetimeIndex([])
                # The backtest treats calendar[i - 1] as "the previous trading
                # day", so enforce ascending order without duplicates.
                result = pd.DatetimeIndex(dates).unique().sort_values()
                print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
                return result
            except Exception as e:
                if retries_left:
                    time.sleep(self._RETRY_DELAY_SEC)
                    continue
                print(f" x calendar: {e}")
                return None
        return None
# ============================================================
# Core strategy class
# ============================================================
class SimpleRotationStrategy:
    """
    Simple rotation strategy (daily iteration).

    Flow:
    1. Load config.
    2. Preload all historical data (signal indices raw, trade ETFs hfq;
       BOND-group trade sources raw).
    3. Fetch the A-share trading calendar.
    4. For each trading day T:
       - Compute momentum (last ``n_days`` closes) from data before T.
       - Group selection -> generate holdings.
       - Compare with the previous holdings -> compute T's return
         (trades execute at T's open).
       - Update NAV.
    5. Export results.
    """

    # NAV at the start of the backtest. run() compounds from this value and
    # _compute_metrics() measures total return and drawdown against it.
    _INITIAL_NAV = 1.0

    def __init__(self, config_path: str = None):
        """Load the strategy config and derive the backtest parameters.

        Args:
            config_path: YAML config path; defaults to ``config_simple.yaml``
                next to this file.
        """
        if config_path is None:
            config_path = Path(__file__).parent / 'config_simple.yaml'
        self.config: RotationStrategyConfig = load_rotation_config(str(config_path))
        # Strategy params
        self.n_days = self.config.factor.n_days
        self.select_num = self.config.rotation.select_num
        self.trade_cost = self.config.rebalance.trade_cost
        # Dynamic threshold: non-BOND winners must reach bond momentum * ratio.
        threshold = self.config.rotation.threshold
        self.use_dynamic_threshold = (threshold.mode.value == 'dynamic')
        self.bond_code = threshold.dynamic.reference if threshold.dynamic else None
        self.bond_ratio = threshold.dynamic.ratio if threshold.dynamic else 1.0
        self.fallback_value = threshold.dynamic.fallback_value if threshold.dynamic else 0.0
        # Signal codes and mappings, all keyed by signal source.
        self.signal_codes: List[str] = []
        self.signal_to_trade: Dict[str, str] = {}
        self.code_to_group: Dict[str, str] = {}
        for asset in self.config.asset_pools.assets.values():
            self.signal_codes.append(asset.signal_source)
            self.signal_to_trade[asset.signal_source] = asset.trade_source
            self.code_to_group[asset.signal_source] = asset.group
        # Data source
        data_source = self.config.data.sources[0]
        base_url = data_source.url or 'https://k3s.tokenpluse.xyz'
        self.data_cache = DataCache(base_url=base_url, timeout=data_source.timeout)
        # Preloaded data: signal code -> index frame, trade code -> ETF frame.
        self.index_data: Dict[str, pd.DataFrame] = {}
        self.etf_data: Dict[str, pd.DataFrame] = {}
        # Results
        self.daily_records: List[dict] = []
        self.trading_calendar: Optional[pd.DatetimeIndex] = None

    def _preload_data(self):
        """Preload all historical data into ``index_data`` and ``etf_data``.

        History starts ``2 * n_days`` calendar days before the backtest start
        so momentum windows are full on the first day. The dynamic-threshold
        reference code is loaded as a signal series even when it is not an
        asset signal source; otherwise its momentum could never be computed
        and the fallback value would always be used.
        """
        start_date = self.config.backtest.start_date
        end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
        preload_start = (pd.Timestamp(start_date) - timedelta(days=self.n_days * 2)).strftime('%Y-%m-%d')

        signal_codes = list(self.signal_codes)
        if self.bond_code and self.bond_code not in signal_codes:
            signal_codes.append(self.bond_code)

        print("\n[1/4] Preloading signal sources (index raw)...")
        for code in signal_codes:
            df = self.data_cache.preload(code, preload_start, end_date, adj='raw')
            if df is not None:
                self.index_data[code] = df
        print(f"\n Signal: {len(self.index_data)}/{len(signal_codes)} OK")

        print("\n[2/4] Preloading trade sources (ETF hfq)...")
        # Trade sources of BOND-group assets are fetched unadjusted.
        bond_trade_codes = {
            a.trade_source for a in self.config.asset_pools.assets.values()
            if a.group == 'BOND'
        }
        trade_codes = set(self.signal_to_trade.values())
        for code in trade_codes:
            adj = 'raw' if code in bond_trade_codes else 'hfq'
            df = self.data_cache.preload(code, preload_start, end_date, adj=adj)
            if df is not None:
                self.etf_data[code] = df
        print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK")

    def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
        """Momentum from the last ``n_days`` closes of ``signal_code`` on or before ``date``.

        Returns None when the code is not loaded or has fewer than ``n_days``
        bars by ``date``, and 0.0 when the crash filter fires on the window.
        """
        df = self.index_data.get(signal_code)
        if df is None:
            return None
        recent = df.loc[df.index <= date]
        if len(recent) < self.n_days:
            return None
        prices = recent['close'].values[-self.n_days:]
        if len(prices) >= 4 and is_crash(prices):
            return 0.0
        return weighted_momentum_score(prices)

    def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
        """
        Generate rotation signals (group competition + dynamic threshold + BOND fill).

        Args:
            date: Signal date; only data on or before it is used.

        Returns:
            (holdings, factors, bond_momentum):
            - holdings: sorted selected signal codes. May be empty, and may
              list ``bond_code`` several times when several slots are filled
              with it (each occurrence is one equal-weight slot).
            - factors: signal_code -> momentum for every code that could be scored.
            - bond_momentum: bond momentum before ``bond_ratio`` when the
              dynamic threshold is active (fallback value if it cannot be
              computed), else None.

        Logic (identical to V2):
        1. Each group: select Top 1 (non-BOND groups must reach bond_momentum * ratio).
        2. From all group winners: sort by momentum, select Top select_num.
        3. Fill remaining slots with BOND (only when bond_code is not already selected).
        """
        factors: Dict[str, float] = {}
        for code in self.signal_codes:
            score = self._compute_momentum(code, date)
            if score is not None:
                factors[code] = score
        if not factors:
            return [], {}, None

        bond_momentum = None
        if self.use_dynamic_threshold and self.bond_code:
            # Reuse the score when the reference is itself a signal code.
            bond_momentum = factors.get(self.bond_code)
            if bond_momentum is None:
                bond_momentum = self._compute_momentum(self.bond_code, date)
            if bond_momentum is None:
                bond_momentum = self.fallback_value

        selected_by_group: Dict[str, Tuple[str, float]] = {}
        for group_name, assets in self.config.asset_pools.by_group.items():
            group_factors = {
                a.signal_source: factors[a.signal_source]
                for a in assets.values() if a.signal_source in factors
            }
            if not group_factors:
                continue
            if group_name != 'BOND' and bond_momentum is not None:
                thresh = bond_momentum * self.bond_ratio
                group_factors = {c: s for c, s in group_factors.items() if s >= thresh}
                if not group_factors:
                    continue
            top_code = max(group_factors, key=group_factors.get)
            selected_by_group[group_name] = (top_code, group_factors[top_code])
        if not selected_by_group:
            return [], factors, bond_momentum

        candidates = sorted(selected_by_group.values(), key=lambda x: x[1], reverse=True)
        final_holdings = [code for code, _ in candidates[:self.select_num]]
        if (len(final_holdings) < self.select_num and self.bond_code
                and self.bond_code not in final_holdings):
            final_holdings.extend([self.bond_code] * (self.select_num - len(final_holdings)))
        return sorted(final_holdings), factors, bond_momentum

    def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
        """Prices of ``trade_code`` on ``date``: {open, close, prev_close}.

        - Missing 'open' (common for index data like 931862.CSI), NaN or 0
          opens fall back to the close.
        - If the ETF has no bar on ``date`` but traded within the last 5
          days (suspension, data gap, not yet published), it is treated as
          flat at its last close. Reusing the last bar as "today" would
          count that bar's return a second time.
        - Returns None when the code is not loaded, fewer than 2 bars exist
          by ``date``, the last bar is more than 5 days old, or closes are NaN.
        """
        df = self.etf_data.get(trade_code)
        if df is None:
            return None
        recent = df.loc[df.index <= date]
        if len(recent) < 2:
            return None
        last_date = recent.index[-1]
        if (date - last_date).days > 5:
            return None
        last_row = recent.iloc[-1]
        close = float(last_row['close'])
        if last_date.normalize() != date.normalize():
            if pd.isna(close):
                return None
            return {'open': close, 'close': close, 'prev_close': close}
        prev_close = float(recent.iloc[-2]['close'])
        open_price = float(last_row['open']) if 'open' in last_row.index else close
        if pd.isna(open_price) or open_price == 0:
            open_price = close
        if pd.isna(close) or pd.isna(prev_close):
            return None
        return {
            'open': open_price,
            'close': close,
            'prev_close': prev_close,
        }

    @staticmethod
    def _slot_weights(holdings) -> Dict[str, float]:
        """Map each code to its equal-weight share (a code listed k times gets k/len)."""
        if not holdings:
            return {}
        counts: Dict[str, int] = {}
        for code in holdings:
            counts[code] = counts.get(code, 0) + 1
        n = len(holdings)
        return {code: k / n for code, k in counts.items()}

    def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
        """
        Portfolio return for ``date`` when moving from ``old_holdings`` to
        ``new_holdings`` at ``date``'s open (T+1 execution of T-1 signals).

        Each code's old and new weights are its share of the respective book.
        For each code:
        - Held (min of old/new weight): close-to-close.
        - Sold (old weight above new): prev_close-to-open (sold at open).
        - Bought (new weight above old): open-to-close (intraday).

        Weighting per book (rather than 1/len(old_holdings) for everything)
        keeps total exposure at 100% when the number of holdings changes or
        a duplicated bond slot is added or removed.

        Codes without usable prices contribute 0. ``trade_cost`` is deducted
        once per rebalance day, independent of how many slots changed.
        """
        old_w = self._slot_weights(old_holdings)
        new_w = self._slot_weights(new_holdings)
        daily_return = 0.0
        for code in dict.fromkeys([*old_holdings, *new_holdings]):
            w_old = old_w.get(code, 0.0)
            w_new = new_w.get(code, 0.0)
            held = min(w_old, w_new)
            sold = w_old - held
            bought = w_new - held
            p = self._get_etf_prices(self.signal_to_trade.get(code, code), date)
            if p is None:
                continue
            if w_old > 0 and p['prev_close'] > 0:
                prev = p['prev_close']
                daily_return += held * (p['close'] - prev) / prev
                daily_return += sold * (p['open'] - prev) / prev
            if bought > 0 and p['open'] > 0:
                daily_return += bought * (p['close'] - p['open']) / p['open']
        if is_rebalance:
            daily_return -= self.trade_cost
        return daily_return

    @staticmethod
    def _signal_date(dates, i: int) -> pd.Timestamp:
        """Close date whose data may drive trading day ``dates[i]``.

        Signals are produced before T opens, so only data strictly before T
        may be used: the previous calendar entry, or, for the first day
        (which has none), one day before T. All lookups are "on or before",
        so this selects the last bar before T.
        """
        if i > 0:
            return dates[i - 1]
        return dates[0] - pd.Timedelta(days=1)

    def _entry_record(self, code: str, date: pd.Timestamp, signal_date: pd.Timestamp) -> dict:
        """Entry info for ``code`` bought on ``date``.

        ``entry_price_etf`` is the ETF open on ``date`` (the buy price; the
        last close when the ETF has no bar that day). ``entry_price_idx`` is
        the index close on ``signal_date``, the data the decision used.
        """
        prices = self._get_etf_prices(self.signal_to_trade.get(code, code), date)
        return {
            'entry_date': date.strftime('%Y-%m-%d'),
            'entry_price_etf': prices['open'] if prices else None,
            'entry_price_idx': self._get_index_close(code, signal_date),
        }

    def run(self) -> dict:
        """Run the daily backtest loop.

        Returns:
            ``{'metrics': ..., 'daily_records': [...]}``, or ``{}`` when the
            trading calendar cannot be fetched.
        """
        print("=" * 60)
        print(" Simple Rotation Strategy (Daily Iteration)")
        print("=" * 60)
        self._preload_data()
        print("\n[3/4] Fetching A-share trading calendar...")
        start_date = self.config.backtest.start_date
        end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
        self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date)
        if self.trading_calendar is None or len(self.trading_calendar) == 0:
            print(" x Calendar fetch failed")
            return {}
        n_total = len(self.trading_calendar)
        print(f"\n[4/4] Backtesting ({n_total} trading days)...")

        # Start from a clean slate so repeated run() calls do not mix results.
        self.daily_records = []
        current_holdings: List[str] = []
        nav = self._INITIAL_NAV
        rebalance_count = 0
        for i, date in enumerate(self.trading_calendar):
            # Signals are generated at 9:00 AM on T, before T opens, so only
            # closes strictly before T are used. Execution happens at T's open
            # and is valued with T's ETF prices.
            signal_date = self._signal_date(self.trading_calendar, i)
            new_holdings, factors, bond_momentum = self._generate_signals(signal_date)
            # Buying from an empty book is neither counted nor charged as a rebalance.
            is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0
            daily_return = self._calculate_daily_return(
                current_holdings, new_holdings, date, is_rebalance
            )
            nav *= (1 + daily_return)
            if is_rebalance:
                rebalance_count += 1
            added = set(new_holdings) - set(current_holdings)
            removed = set(current_holdings) - set(new_holdings)
            # Bond threshold value for the detail record (0.0 when inactive).
            threshold_val = 0.0
            if self.use_dynamic_threshold and bond_momentum is not None:
                threshold_val = round(bond_momentum * self.bond_ratio, 6)
            self.daily_records.append({
                'date': date.strftime('%Y-%m-%d'),
                'nav': round(nav, 6),
                'daily_return': round(daily_return, 6),
                'is_rebalance': is_rebalance,
                'holdings': sorted(new_holdings),
                'added': sorted(added),
                'removed': sorted(removed),
                'factors': {k: round(v, 6) for k, v in factors.items()},
                'threshold': threshold_val,
            })
            current_holdings = new_holdings
            if (i + 1) % 100 == 0 or i == n_total - 1:
                print(f" [{i+1}/{n_total}] "
                      f"NAV: {nav:.4f} | Rebal: {rebalance_count}")

        metrics = self._compute_metrics(rebalance_count)
        print("\n" + "=" * 60)
        print(" Backtest Complete")
        print("=" * 60)
        if self.daily_records:
            print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}")
            print(f" Trading days: {len(self.daily_records)}")
        print(f" Total return: {metrics['total_return']:.2%}")
        print(f" Annual return: {metrics['annual_return']:.2%}")
        print(f" Max drawdown: {metrics['max_drawdown']:.2%}")
        print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}")
        print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}")
        print(f" Win rate: {metrics['win_rate']:.2%}")
        print(f" Rebalances: {rebalance_count}")
        print("=" * 60)
        return {
            'metrics': metrics,
            'daily_records': self.daily_records,
        }

    def _compute_metrics(self, rebalance_count: int) -> dict:
        """Compute performance metrics from ``daily_records``.

        Total return and drawdown are measured from ``_INITIAL_NAV`` (the
        NAV run() starts from), so a loss on the first day is included.
        Annualization and Sharpe use 252 periods per year; Sharpe assumes a
        zero risk-free rate. Returns ``{}`` when there are no records.
        """
        if not self.daily_records:
            return {}
        df = pd.DataFrame(self.daily_records)
        nav = df['nav']
        returns = df['daily_return']
        n_days = len(df)
        total_return = nav.iloc[-1] / self._INITIAL_NAV - 1
        annual_return = (1 + total_return) ** (252 / n_days) - 1
        # The starting capital is the first high-water mark.
        peak = nav.cummax().clip(lower=self._INITIAL_NAV)
        drawdown = (nav - peak) / peak
        max_drawdown = drawdown.min()
        std = returns.std()
        sharpe = returns.mean() / std * np.sqrt(252) if std > 0 else 0
        calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0
        non_zero = returns[returns != 0]
        win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0
        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'max_drawdown': max_drawdown,
            'sharpe_ratio': sharpe,
            'calmar_ratio': calmar,
            'win_rate': win_rate,
            'n_days': n_days,
            'rebalance_count': rebalance_count,
        }

    # ============================================================
    # Detail JSON helpers (V2-compatible)
    # ============================================================
    def _build_meta_codes(self) -> dict:
        """Build meta.codes mapping: signal_code -> {name, etf, market}.

        'name' falls back to the config key when the asset has no name;
        'market' holds the asset's group.
        """
        return {
            asset.signal_source: {
                'name': getattr(asset, 'name', code),
                'etf': asset.trade_source,
                'market': asset.group,
            }
            for code, asset in self.config.asset_pools.assets.items()
        }

    @staticmethod
    def _last_close(df: Optional[pd.DataFrame], date: pd.Timestamp) -> Optional[float]:
        """Last close on or before ``date`` (None if no frame or no bar by then)."""
        if df is None:
            return None
        rows = df.loc[df.index <= date]
        if len(rows) == 0:
            return None
        return float(rows.iloc[-1]['close'])

    @staticmethod
    def _last_close_return(df: Optional[pd.DataFrame], date: pd.Timestamp) -> Optional[float]:
        """Close-to-close return of the last two bars on or before ``date`` (6 dp)."""
        if df is None:
            return None
        rows = df.loc[df.index <= date]
        if len(rows) < 2:
            return None
        c_today = float(rows.iloc[-1]['close'])
        c_prev = float(rows.iloc[-2]['close'])
        if c_prev > 0 and not pd.isna(c_today):
            return round((c_today - c_prev) / c_prev, 6)
        return None

    def _get_index_close(self, code: str, date: pd.Timestamp) -> Optional[float]:
        """Get index close price on or before date."""
        return self._last_close(self.index_data.get(code), date)

    def _get_etf_close(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
        """Get ETF close price on or before date."""
        return self._last_close(self.etf_data.get(trade_code), date)

    def _get_daily_returns(self, code: str, date: pd.Timestamp) -> Tuple[Optional[float], Optional[float]]:
        """Get (index_return, etf_return_ctc) for a signal code on a given date."""
        trade_code = self.signal_to_trade.get(code, code)
        idx_ret = self._last_close_return(self.index_data.get(code), date)
        etf_ret = self._last_close_return(self.etf_data.get(trade_code), date)
        return idx_ret, etf_ret

    def _compute_premium(self, code: str, idx_close: float, etf_close: float) -> Optional[float]:
        """Compute premium = (etf_close - index_close) / index_close (6 dp).

        Returns None for BOND-group codes and when either close is missing or
        the index close is 0.
        NOTE(review): ETFs are loaded hfq-adjusted while indices are raw, so
        this ratio is likely dominated by price scale rather than a true
        premium; confirm how V2 consumers interpret it.
        """
        if self.code_to_group.get(code, '') == 'BOND':
            return None
        if idx_close is None or etf_close is None or idx_close == 0:
            return None
        return round((etf_close - idx_close) / idx_close, 6)

    @staticmethod
    def _opt_float(value) -> Optional[float]:
        """float(value), passing None through (keeps numpy scalars JSON-friendly)."""
        return float(value) if value is not None else None

    def _build_day_assets(self, record: dict, date: pd.Timestamp,
                          entry_info: Dict[str, dict]) -> dict:
        """Build the V2-compatible per-asset detail dict for one day.

        Args:
            record: One entry of ``daily_records``.
            date: The record's trading date.
            entry_info: signal_code -> entry dict (see _entry_record) for
                currently held codes.

        Returns:
            signal_code -> 16 detail fields, for every signal code.
        """
        factors = record.get('factors', {})
        holdings = set(record['holdings'])
        threshold = record.get('threshold', 0.0)
        # The bond threshold only filters candidates in dynamic mode with a
        # reference code (see _generate_signals); otherwise every scored code
        # passes. Checking the mode (not the value's truthiness) keeps a
        # threshold of exactly 0.0 effective.
        threshold_active = bool(self.use_dynamic_threshold and self.bond_code)
        # Rank: all scored codes by momentum, descending; rank 1 = highest.
        ranked = sorted(factors, key=factors.get, reverse=True)
        rank_map = {c: i + 1 for i, c in enumerate(ranked)}
        to_f = self._opt_float

        assets = {}
        for code in self.signal_codes:
            momentum = factors.get(code)
            if momentum is None:
                above_thresh = False
            elif threshold_active:
                above_thresh = momentum >= threshold
            else:
                above_thresh = True
            is_held = code in holdings
            trade_code = self.signal_to_trade.get(code, code)
            idx_close = self._get_index_close(code, date)
            etf_close = self._get_etf_close(trade_code, date)
            idx_ret, etf_ret_ctc = self._get_daily_returns(code, date)
            premium = self._compute_premium(code, idx_close, etf_close)

            # Entry / holding info (only for held codes)
            ei = entry_info.get(code) if is_held else None
            holding_days = 0
            cum_ret_etf = None
            cum_ret_idx = None
            if ei:
                # Calendar days since entry, floored at 1.
                holding_days = max(1, int((date - pd.Timestamp(ei['entry_date'])).days))
                ep_etf = ei.get('entry_price_etf')
                if ep_etf and ep_etf > 0 and etf_close is not None and not pd.isna(etf_close):
                    cum_ret_etf = round((etf_close - ep_etf) / ep_etf, 6)
                ep_idx = ei.get('entry_price_idx')
                if ep_idx and ep_idx > 0 and idx_close is not None and not pd.isna(idx_close):
                    cum_ret_idx = round((idx_close - ep_idx) / ep_idx, 6)

            assets[code] = {
                'momentum': to_f(momentum),
                'rank': rank_map.get(code),
                'threshold': float(threshold),
                'above_threshold': bool(above_thresh),
                'index_close': to_f(idx_close),
                'etf_close': to_f(etf_close),
                'index_return': to_f(idx_ret),
                'etf_return_ctc': to_f(etf_ret_ctc),
                'premium': to_f(premium),
                'is_held': bool(is_held),
                'entry_date': ei['entry_date'] if ei else None,
                'entry_price_etf': to_f(ei['entry_price_etf']) if ei else None,
                'entry_price_idx': to_f(ei['entry_price_idx']) if ei else None,
                'holding_days': holding_days,
                'cum_return_etf': to_f(cum_ret_etf),
                'cum_return_idx': to_f(cum_ret_idx),
            }
        return assets

    # ============================================================
    # Export
    # ============================================================
    def export_results(self, output_dir: str = None):
        """Export backtest results to CSV and JSON (V2-compatible detail format).

        Writes NAV CSV, signals CSV, per-day detail JSON and metrics JSON into
        ``output_dir`` (default: ``results/`` next to this file, created if
        missing). Only prints a message when run() produced no records.
        """
        if not self.daily_records:
            print(" x No results to export")
            return
        if output_dir is None:
            output_dir = Path(__file__).parent / 'results'
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        df = pd.DataFrame(self.daily_records)
        # NAV curve
        nav_path = output_dir / 'simple_rotation_nav.csv'
        df[['date', 'nav', 'daily_return']].to_csv(nav_path, index=False)
        print(f" + NAV: {nav_path}")
        # Signals
        sig_path = output_dir / 'simple_rotation_signals.csv'
        df[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(sig_path, index=False)
        print(f" + Signals: {sig_path}")

        # Detail JSON (V2-compatible format)
        detail_path = output_dir / 'simple_rotation_detail.json'
        days_out = []
        # Entry info is rebuilt day by day with the same timing rule run() uses.
        tracked_entry: Dict[str, dict] = {}
        prev_holdings: List[str] = []
        dates = [pd.Timestamp(rec['date']) for rec in self.daily_records]
        for i, rec in enumerate(self.daily_records):
            date = dates[i]
            signal_date = self._signal_date(dates, i)
            holdings = rec['holdings']
            for code in set(holdings) - set(prev_holdings):
                tracked_entry[code] = self._entry_record(code, date, signal_date)
            for code in set(prev_holdings) - set(holdings):
                tracked_entry.pop(code, None)
            days_out.append({
                'date': rec['date'],
                'nav': rec['nav'],
                'daily_return': rec['daily_return'],
                'is_rebalance': rec['is_rebalance'],
                'signals': {c: 1 for c in holdings},
                'holdings': holdings,
                'added': rec['added'],
                'removed': rec['removed'],
                'assets': self._build_day_assets(rec, date, tracked_entry),
            })
            prev_holdings = holdings

        detail = {
            'meta': {
                'mode': 'Simple: Daily Iteration',
                'start_date': self.config.backtest.start_date,
                'end_date': self.daily_records[-1]['date'],
                'total_days': len(self.daily_records),
                'select_num': self.select_num,
                'n_days': self.n_days,
                'trade_cost': self.trade_cost,
                'bond_threshold': {
                    'enabled': self.use_dynamic_threshold,
                    'bond_code': self.bond_code,
                    'ratio': self.bond_ratio,
                },
                'codes': self._build_meta_codes(),
            },
            'days': days_out,
        }
        with open(detail_path, 'w', encoding='utf-8') as f:
            json.dump(detail, f, ensure_ascii=False, indent=2)
        print(f" + Detail: {detail_path} ({len(days_out)} days)")

        # Metrics JSON
        metrics = self._compute_metrics(sum(1 for r in self.daily_records if r['is_rebalance']))
        metrics_path = output_dir / 'simple_rotation_metrics.json'
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(metrics, f, ensure_ascii=False, indent=2)
        print(f" + Metrics: {metrics_path}")
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    # Provide the default API endpoint unless the environment already sets one.
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')
    strategy = SimpleRotationStrategy()
    result = strategy.run()
    # run() returns {} when the calendar fetch fails; nothing to export then.
    if result:
        strategy.export_results()