- Add data integrity check: if any currently held asset is missing from factors, raise RuntimeError immediately to prevent false rebalance - Previously missing data would silently cause incorrect sell signals - Now fails fast with clear error message identifying the missing assets and the date of failure
1256 lines
54 KiB
Python
1256 lines
54 KiB
Python
"""
|
|
Simple Rotation Strategy (Daily Iteration)
|
|
|
|
From A-share trading calendar, iterate daily:
|
|
1. Get signal source last N days -> compute momentum
|
|
2. Get bond momentum (dynamic threshold)
|
|
3. Group selection -> generate holdings
|
|
4. Compare with yesterday -> compute T+1 return
|
|
5. Update NAV
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import math
|
|
import json
|
|
import time
|
|
import requests
|
|
import numpy as np
|
|
import pandas as pd
|
|
from pathlib import Path
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
# Repository root (two levels above this file); prepended to sys.path so the
# `rotation.config_loader` import below resolves when this file is executed.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path[0:0] = [str(PROJECT_ROOT)]
|
|
|
|
from rotation.config_loader import load_rotation_config, RotationStrategyConfig
|
|
|
|
|
|
# ============================================================
|
|
# Pure functions: momentum
|
|
# ============================================================
|
|
|
|
def weighted_momentum_score(prices: np.ndarray) -> float:
    """Momentum score = annualized regression return * weighted R^2.

    log(prices) is regressed on the bar index with ``np.polyfit``, using weights
    rising linearly from 1 (oldest bar) to 2 (newest bar). The slope is
    annualized over 250 bars as ``exp(slope * 250) - 1`` and multiplied by a
    weighted R^2 (0 when the weighted total sum of squares is not positive).

    Prices are floored at 0.01 before the log. Returns 0.0 when fewer than 5
    prices are given or any log-price is NaN/inf.

    NOTE(review): ``np.polyfit``'s ``w`` multiplies the *unsquared* residuals,
    so the fit effectively weights squared residuals by ``w**2``, while the R^2
    below weights them by ``w``. Kept as-is so scores stay comparable with
    earlier runs; confirm this is intended.
    """
    if len(prices) < 5:
        return 0.0
    log_p = np.log(np.clip(prices, 0.01, None))
    if not np.all(np.isfinite(log_p)):
        return 0.0

    n_obs = len(log_p)
    t = np.arange(n_obs)
    w = np.linspace(1, 2, n_obs)
    slope, intercept = np.polyfit(t, log_p, 1, w=w)

    fitted = slope * t + intercept
    resid_ss = np.sum(w * (log_p - fitted) ** 2)
    total_ss = np.sum(w * (log_p - np.average(log_p, weights=w)) ** 2)
    r_squared = 1 - resid_ss / total_ss if total_ss > 0 else 0

    return (math.exp(slope * 250) - 1) * r_squared
|
|
|
|
|
|
def is_crash(prices: np.ndarray) -> bool:
    """Crash filter over the last four prices (the last three daily changes).

    Returns True when either:
      * any of the last three day-over-day ratios is below 0.95 (a single-day
        drop of more than 5%), or
      * all three days closed lower and the three-day ratio is below 0.95.
    Returns False when fewer than four prices are supplied.
    """
    if len(prices) < 4:
        return False
    first, second, third, last = prices[-4:]
    r_last, r_mid, r_first = last / third, third / second, second / first
    one_day_plunge = min(r_last, r_mid, r_first) < 0.95
    three_day_slide = r_last < 1 and r_mid < 1 and r_first < 1 and last / first < 0.95
    return one_day_plunge or three_day_slide
|
|
|
|
|
|
# ============================================================
|
|
# Data cache
|
|
# ============================================================
|
|
|
|
class DataCache:
    """CSV-cached data fetching for index (raw) and ETF (hfq) price series.

    OHLCV frames are cached per code under ``cache_dir`` as ``index_<code>.csv``
    (adj='raw') or ``etf_<code>.csv`` (any other adj). Premium ratios that the
    API returns alongside the OHLCV data are cached as ``premium_<code>.csv``
    and kept in memory in ``premium_data``.
    """

    def __init__(self, base_url: str, cache_dir: str = None, timeout: int = 120):
        self.base_url = base_url.rstrip('/')
        self.api_path = '/api/v1/ohlcv'
        self.timeout = timeout
        if cache_dir is None:
            cache_dir = PROJECT_ROOT / 'data' / 'simple_rotation_cache'
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # premium data cache: {trade_code: {date_str: premium_ratio}}
        self.premium_data: Dict[str, Dict[str, float]] = {}

    def _cache_path(self, code: str, adj: str) -> Path:
        """CSV path of the OHLCV cache for ``code`` ('index_' for raw, else 'etf_')."""
        prefix = 'index' if adj == 'raw' else 'etf'
        safe_code = code.replace('=', '_').replace('^', '_')
        return self.cache_dir / f"{prefix}_{safe_code}.csv"

    def _premium_cache_path(self, code: str) -> Path:
        """CSV path of the premium cache for ``code``."""
        safe_code = code.replace('=', '_').replace('^', '_')
        return self.cache_dir / f"premium_{safe_code}.csv"

    def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
        """Return OHLCV history for ``code``, using and updating the CSV cache.

        If the cache already spans [start_date, end_date] it is returned as-is.
        Otherwise only the days after the cached range are fetched and appended
        (the cache is never back-filled before its first date). Without a
        usable cache the whole window is fetched and cached. Dates are
        'YYYY-MM-DD' strings. Returns None if nothing could be loaded.
        """
        cache_path = self._cache_path(code, adj)
        if cache_path.exists():
            try:
                df = pd.read_csv(cache_path, index_col='date', parse_dates=True)
                if len(df) > 0:
                    cs = df.index.min().strftime('%Y-%m-%d')
                    ce = df.index.max().strftime('%Y-%m-%d')
                    if cs <= start_date and ce >= end_date:
                        return df
                    new_start = (df.index.max() + timedelta(days=1)).strftime('%Y-%m-%d')
                    if new_start <= end_date:
                        new_df = self._fetch_api(code, new_start, end_date, adj)
                        if new_df is not None and len(new_df) > 0:
                            df = pd.concat([df, new_df])
                            df = df[~df.index.duplicated(keep='last')]
                            df.to_csv(cache_path)
                    return df
            except Exception as e:
                # Unusable cache (corrupt file, bad dates, write failure):
                # fall back to a full fetch below instead of aborting.
                print(f" ! {code}: cache unusable ({e}), refetching")
        df = self._fetch_api(code, start_date, end_date, adj)
        if df is not None and len(df) > 0:
            df.to_csv(cache_path)
        return df

    def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
        """Fetch OHLCV for ``code`` from the API (up to 3 attempts).

        Keeps the open/high/low/close/volume columns that are present, indexed
        by date. If the response carries a ``premium_series`` it is stored in
        ``df.attrs['premium_series']`` and merged into the premium cache.
        Returns None on HTTP/API errors or an empty response.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
        for attempt in range(3):
            try:
                resp = requests.get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1)
                        continue
                    print(f" x {code}: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x {code}: {data['error']}")
                    return None
                records = data.get('data', [])
                if not records:
                    return None
                df = pd.DataFrame(records)
                if 'date' in df.columns:
                    df['date'] = pd.to_datetime(df['date'])
                    df = df.set_index('date').sort_index()
                keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
                df = df[keep]
                # Extract and cache premium_series (ETF only)
                premium_series = data.get('premium_series', [])
                if premium_series:
                    df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series}
                    # Merge, never overwrite: an incremental fetch only covers the
                    # gap, and overwriting would drop the cached earlier history.
                    self._merge_premium(code, df.attrs['premium_series'])
                print(f" + {code}: {len(df)} rows ({adj})")
                return df
            except requests.exceptions.Timeout:
                if attempt < 2:
                    continue
                print(f" x {code}: timeout")
                return None
            except requests.exceptions.ConnectionError as e:
                # Transient network failure: retry like a non-200 response.
                if attempt < 2:
                    time.sleep(1)
                    continue
                print(f" x {code}: {e}")
                return None
            except Exception as e:
                print(f" x {code}: {e}")
                return None
        return None

    def _load_premium_cache(self, code: str) -> Dict[str, float]:
        """Read the premium CSV cache for ``code``; {} if missing or unreadable."""
        cache_path = self._premium_cache_path(code)
        if not cache_path.exists():
            return {}
        try:
            df = pd.read_csv(cache_path)
        except Exception:
            return {}
        if len(df) == 0 or 'date' not in df.columns or 'premium' not in df.columns:
            return {}
        return dict(zip(df['date'].astype(str), df['premium']))

    def _merge_premium(self, code: str, new_data: Dict[str, float]) -> Dict[str, float]:
        """Merge ``new_data`` into the premium history of ``code`` and persist it.

        The base is the in-memory dict when present, otherwise the on-disk
        cache, so previously cached dates are preserved. Returns the merged
        dict, which is also stored in ``self.premium_data[code]``.
        """
        merged = self.premium_data.get(code)
        if merged is None:
            merged = self._load_premium_cache(code)
            self.premium_data[code] = merged
        merged.update(new_data)
        self._save_premium_cache(code, merged)
        return merged

    def _save_premium_cache(self, code: str, premium_dict: Dict[str, float]):
        """Write premium data to the CSV cache (best effort: failures are logged)."""
        try:
            cache_path = self._premium_cache_path(code)
            pd.DataFrame(
                [{'date': k, 'premium': v} for k, v in premium_dict.items()]
            ).to_csv(cache_path, index=False)
        except Exception as e:
            print(f" ! premium cache write failed for {code}: {e}")

    def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]:
        """Load premium data for an ETF code from cache, with incremental update.

        Sources, in order: the in-memory dict (returned if no ``end_date`` is
        given or it already reaches ``end_date``); the CSV cache, fetching only
        the missing tail up to ``end_date``; a full fetch from 2000-01-01.
        Returns None when no premium data is available.
        """
        in_memory = self.premium_data.get(code)
        if in_memory is not None:
            if not end_date:
                return in_memory
            if in_memory and max(in_memory) >= end_date:
                return in_memory

        cached = self._load_premium_cache(code)
        if cached:
            self.premium_data[code] = cached
            if not end_date:
                return cached
            latest_cached = max(cached)
            if latest_cached >= end_date:
                return cached
            try:
                fetch_start = (pd.Timestamp(latest_cached) + timedelta(days=1)).strftime('%Y-%m-%d')
            except ValueError:
                # Unparseable cached date: fall through to a full fetch.
                fetch_start = None
            if fetch_start is not None:
                # Cache is stale - fetch gap from latest_cached+1 to end_date
                self._fetch_premium_api(code, fetch_start, end_date)
                return self.premium_data[code]

        # No usable cache: fetch full history from API
        self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
        return self.premium_data.get(code)

    def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
        """Fetch premium_series for [start_date, end_date] and merge it into the cache."""
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
        for attempt in range(3):
            try:
                resp = requests.get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1)
                        continue
                    return
                data = resp.json()
                if 'error' in data:
                    return
                premium_series = data.get('premium_series', [])
                if premium_series:
                    new_data = {item['date']: item['premium'] for item in premium_series}
                    merged = self._merge_premium(code, new_data)
                    print(f" + premium {code}: +{len(new_data)} days (total {len(merged)})")
                return
            except requests.exceptions.Timeout:
                if attempt < 2:
                    continue
                return
            except Exception:
                return

    def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]:
        """Fetch the trading calendar for ``market`` (up to 3 attempts).

        Returns a DatetimeIndex (empty if the API lists no dates), or None on
        failure.
        """
        url = f"{self.base_url}/api/v1/trading-calendar"
        params = {'market': market, 'start': start_date, 'end': end_date}
        for attempt in range(3):
            try:
                resp = requests.get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1)
                        continue
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x calendar: {data['error']}")
                    return None
                dates = data.get('trading_dates', [])
                if not dates:
                    return pd.DatetimeIndex([])
                result = pd.DatetimeIndex(dates)
                print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
                return result
            except Exception as e:
                if attempt < 2:
                    continue
                print(f" x calendar: {e}")
                return None
        return None
|
|
|
|
|
|
# ============================================================
|
|
# Core strategy class
|
|
# ============================================================
|
|
|
|
class SimpleRotationStrategy:
|
|
"""
|
|
Simple rotation strategy (daily iteration)
|
|
|
|
Flow:
|
|
1. Load config
|
|
2. Preload all historical data (index raw + ETF hfq)
|
|
3. Fetch A-share trading calendar
|
|
4. For each trading day:
|
|
- Compute momentum (last 25 days)
|
|
- Group selection -> generate holdings
|
|
- Compare with yesterday -> compute T+1 return
|
|
- Update NAV
|
|
5. Export results
|
|
"""
|
|
|
|
def __init__(self, config_path: str = None):
|
|
if config_path is None:
|
|
config_path = Path(__file__).parent / 'config_simple.yaml'
|
|
self.config: RotationStrategyConfig = load_rotation_config(str(config_path))
|
|
|
|
# Strategy params
|
|
self.n_days = self.config.factor.n_days
|
|
self.select_num = self.config.rotation.select_num
|
|
self.trade_cost = self.config.rebalance.trade_cost
|
|
|
|
# Dynamic threshold
|
|
threshold = self.config.rotation.threshold
|
|
self.use_dynamic_threshold = (threshold.mode.value == 'dynamic')
|
|
self.bond_code = threshold.dynamic.reference if threshold.dynamic else None
|
|
self.bond_ratio = threshold.dynamic.ratio if threshold.dynamic else 1.0
|
|
self.fallback_value = threshold.dynamic.fallback_value if threshold.dynamic else 0.0
|
|
|
|
# Signal codes and mappings
|
|
self.signal_codes = []
|
|
self.signal_to_trade = {}
|
|
self.code_to_group = {}
|
|
self.trade_code_to_group = {}
|
|
for code, asset in self.config.asset_pools.assets.items():
|
|
self.signal_codes.append(asset.signal_source)
|
|
self.signal_to_trade[asset.signal_source] = asset.trade_source
|
|
self.code_to_group[asset.signal_source] = asset.group
|
|
self.trade_code_to_group[asset.trade_source] = asset.group
|
|
|
|
# Data source
|
|
data_source = self.config.data.sources[0]
|
|
base_url = data_source.url or 'https://k3s.tokenpluse.xyz'
|
|
self.data_cache = DataCache(base_url=base_url, timeout=data_source.timeout)
|
|
|
|
# Preloaded data
|
|
self.index_data: Dict[str, pd.DataFrame] = {}
|
|
self.etf_data: Dict[str, pd.DataFrame] = {}
|
|
self.benchmark_data: Optional[pd.DataFrame] = None
|
|
|
|
# Benchmark config
|
|
self.benchmark_code = self.config.benchmark.code
|
|
self.benchmark_name = self.config.benchmark.name
|
|
|
|
# Results
|
|
self.daily_records: List[dict] = []
|
|
self.trading_calendar: Optional[pd.DatetimeIndex] = None
|
|
|
|
def _preload_data(self):
|
|
"""Preload all historical data"""
|
|
start_date = self.config.backtest.start_date
|
|
end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
|
|
preload_start = (pd.Timestamp(start_date) - timedelta(days=self.n_days * 2)).strftime('%Y-%m-%d')
|
|
|
|
print("\n[1/4] Preloading signal sources (index raw)...")
|
|
for code in self.signal_codes:
|
|
df = self.data_cache.preload(code, preload_start, end_date, adj='raw')
|
|
if df is not None:
|
|
self.index_data[code] = df
|
|
print(f"\n Signal: {len(self.index_data)}/{len(self.signal_codes)} OK")
|
|
|
|
print("\n[2/4] Preloading trade sources (ETF hfq)...")
|
|
trade_codes = set(self.signal_to_trade.values())
|
|
for code in trade_codes:
|
|
is_bond = any(
|
|
a.trade_source == code and a.group == 'BOND'
|
|
for a in self.config.asset_pools.assets.values()
|
|
)
|
|
adj = 'raw' if is_bond else 'hfq'
|
|
df = self.data_cache.preload(code, preload_start, end_date, adj=adj)
|
|
if df is not None:
|
|
self.etf_data[code] = df
|
|
# Load premium data cache for all ETF trade codes
|
|
for code in trade_codes:
|
|
self.data_cache.preload_premium(code, end_date=end_date)
|
|
print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK, premium: {len(self.data_cache.premium_data)} loaded")
|
|
|
|
# Load benchmark
|
|
print(f"\n Loading benchmark ({self.benchmark_code})...")
|
|
bm_df = self.data_cache.preload(self.benchmark_code, preload_start, end_date, adj='raw')
|
|
if bm_df is not None:
|
|
self.benchmark_data = bm_df
|
|
print(f" Benchmark: {len(bm_df)} rows")
|
|
|
|
def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
|
|
"""Compute momentum for a single code on a given date"""
|
|
if signal_code not in self.index_data:
|
|
return None
|
|
df = self.index_data[signal_code]
|
|
mask = df.index <= date
|
|
recent = df.loc[mask]
|
|
if len(recent) < self.n_days:
|
|
return None
|
|
prices = recent['close'].values[-self.n_days:]
|
|
if len(prices) >= 4 and is_crash(prices):
|
|
return 0.0
|
|
return weighted_momentum_score(prices)
|
|
|
|
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
|
|
"""
|
|
Generate rotation signals (group competition + dynamic threshold + BOND fill)
|
|
|
|
Returns:
|
|
(holdings, factors, bond_momentum): tuple of selected holdings,
|
|
all computed factor scores, and the bond momentum threshold value.
|
|
|
|
Logic (identical to V2):
|
|
1. Each group: select Top 1 (non-BOND groups must exceed bond_momentum * ratio)
|
|
2. From all group winners: sort by momentum, select Top select_num
|
|
3. Fill remaining slots with BOND
|
|
"""
|
|
factors: Dict[str, float] = {}
|
|
for code in self.signal_codes:
|
|
score = self._compute_momentum(code, date)
|
|
if score is not None:
|
|
factors[code] = score
|
|
if not factors:
|
|
return [], {}, None
|
|
|
|
bond_momentum = None
|
|
if self.use_dynamic_threshold and self.bond_code:
|
|
bond_momentum = self._compute_momentum(self.bond_code, date)
|
|
if bond_momentum is None:
|
|
bond_momentum = self.fallback_value
|
|
|
|
groups = self.config.asset_pools.by_group
|
|
selected_by_group: Dict[str, Tuple[str, float]] = {}
|
|
|
|
for group_name, assets in groups.items():
|
|
group_codes = [a.signal_source for a in assets.values()]
|
|
group_factors = {c: factors[c] for c in group_codes if c in factors}
|
|
if not group_factors:
|
|
continue
|
|
if group_name != 'BOND' and bond_momentum is not None:
|
|
thresh = bond_momentum * self.bond_ratio
|
|
group_factors = {c: s for c, s in group_factors.items() if s >= thresh}
|
|
if not group_factors:
|
|
continue
|
|
top_code = max(group_factors, key=group_factors.get)
|
|
selected_by_group[group_name] = (top_code, group_factors[top_code])
|
|
|
|
if not selected_by_group:
|
|
return [], factors, bond_momentum
|
|
|
|
candidates = list(selected_by_group.values())
|
|
candidates.sort(key=lambda x: x[1], reverse=True)
|
|
final_holdings = [code for code, _ in candidates[:self.select_num]]
|
|
|
|
if len(final_holdings) < self.select_num and self.bond_code:
|
|
if self.bond_code not in final_holdings:
|
|
n_slots = self.select_num - len(final_holdings)
|
|
final_holdings.extend([self.bond_code] * n_slots)
|
|
|
|
return sorted(final_holdings), factors, bond_momentum
|
|
|
|
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
|
|
"""Get ETF prices on a given date: {open, close, prev_close}
|
|
Note: Index data may lack 'open' column; use close as fallback.
|
|
"""
|
|
if trade_code not in self.etf_data:
|
|
return None
|
|
df = self.etf_data[trade_code]
|
|
mask = df.index <= date
|
|
recent = df.loc[mask]
|
|
if len(recent) < 2:
|
|
return None
|
|
today_row = recent.iloc[-1]
|
|
prev_row = recent.iloc[-2]
|
|
today_date = recent.index[-1]
|
|
if (date - today_date).days > 5:
|
|
return None
|
|
close = float(today_row['close'])
|
|
prev_close = float(prev_row['close'])
|
|
# Handle missing open (common for index data like 931862.CSI)
|
|
open_price = float(today_row.get('open', close)) if 'open' in today_row.index else close
|
|
if pd.isna(open_price) or open_price == 0:
|
|
open_price = close
|
|
if pd.isna(close) or pd.isna(prev_close):
|
|
return None
|
|
return {
|
|
'open': open_price,
|
|
'close': close,
|
|
'prev_close': prev_close,
|
|
}
|
|
|
|
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
|
|
"""
|
|
Compute daily return (T+1 execution):
|
|
- Hold: close-to-close
|
|
- Sell: close-to-open (sold at open)
|
|
- Buy: open-to-close (intraday)
|
|
"""
|
|
if not old_holdings:
|
|
if not new_holdings:
|
|
return 0.0
|
|
weight = 1.0 / len(new_holdings)
|
|
ret = 0.0
|
|
for code in new_holdings:
|
|
tc = self.signal_to_trade.get(code, code)
|
|
p = self._get_etf_prices(tc, date)
|
|
if p and p['open'] > 0:
|
|
ret += weight * (p['close'] - p['open']) / p['open']
|
|
if is_rebalance:
|
|
ret -= self.trade_cost
|
|
return ret
|
|
|
|
old_set = set(old_holdings)
|
|
new_set = set(new_holdings)
|
|
weight = 1.0 / len(old_holdings)
|
|
daily_return = 0.0
|
|
|
|
for code in old_holdings:
|
|
tc = self.signal_to_trade.get(code, code)
|
|
p = self._get_etf_prices(tc, date)
|
|
if p is None or p['prev_close'] == 0:
|
|
continue
|
|
if code in new_set:
|
|
r = (p['close'] - p['prev_close']) / p['prev_close']
|
|
else:
|
|
r = (p['open'] - p['prev_close']) / p['prev_close']
|
|
if not math.isnan(r):
|
|
daily_return += weight * r
|
|
|
|
for code in new_holdings:
|
|
if code not in old_set:
|
|
tc = self.signal_to_trade.get(code, code)
|
|
p = self._get_etf_prices(tc, date)
|
|
if p and p['open'] > 0 and not math.isnan(p['close']):
|
|
r = (p['close'] - p['open']) / p['open']
|
|
if not math.isnan(r):
|
|
daily_return += weight * r
|
|
|
|
if is_rebalance:
|
|
daily_return -= self.trade_cost
|
|
return daily_return
|
|
|
|
def run(self) -> dict:
|
|
"""Main backtest loop"""
|
|
print("=" * 60)
|
|
print(" Simple Rotation Strategy (Daily Iteration)")
|
|
print("=" * 60)
|
|
|
|
self._preload_data()
|
|
|
|
print("\n[3/4] Fetching A-share trading calendar...")
|
|
start_date = self.config.backtest.start_date
|
|
end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
|
|
self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date)
|
|
|
|
if self.trading_calendar is None or len(self.trading_calendar) == 0:
|
|
print(" x Calendar fetch failed")
|
|
return {}
|
|
|
|
print(f"\n[4/4] Backtesting ({len(self.trading_calendar)} trading days)...")
|
|
current_holdings: List[str] = []
|
|
nav = 1.0
|
|
rebalance_count = 0
|
|
entry_info: Dict[str, dict] = {} # signal_code -> {entry_date, entry_price_etf, entry_price_idx}
|
|
|
|
for i, date in enumerate(self.trading_calendar):
|
|
# Signal timing: 9:00 AM on day T
|
|
# At this moment, T's market has NOT opened yet.
|
|
# Only T-1 close data is available for all markets.
|
|
# So momentum must be computed from T-1 close (prev_date).
|
|
# Execution happens at 9:30 AM using T's ETF prices.
|
|
if i > 0:
|
|
signal_date = self.trading_calendar[i - 1] # T-1 close
|
|
else:
|
|
signal_date = date # First day: no prior trading day available
|
|
|
|
new_holdings, factors, bond_momentum = self._generate_signals(signal_date)
|
|
|
|
# Data integrity check: if any currently held asset is missing from
|
|
# today's factors, abort immediately to prevent false rebalancing.
|
|
if current_holdings:
|
|
missing = [c for c in current_holdings if c not in factors]
|
|
if missing:
|
|
raise RuntimeError(
|
|
f"Data failure: held assets {missing} missing from factors on "
|
|
f"{date.strftime('%Y-%m-%d')}. Aborting to prevent false rebalance."
|
|
)
|
|
|
|
is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0
|
|
|
|
# Return uses T's ETF prices (open for buy/sell, close for hold)
|
|
daily_return = self._calculate_daily_return(
|
|
current_holdings, new_holdings, date, is_rebalance
|
|
)
|
|
nav *= (1 + daily_return)
|
|
|
|
if is_rebalance:
|
|
rebalance_count += 1
|
|
|
|
# Update entry tracking for held assets
|
|
added = set(new_holdings) - set(current_holdings)
|
|
removed = set(current_holdings) - set(new_holdings)
|
|
for code in added:
|
|
trade_code = self.signal_to_trade.get(code, code)
|
|
etf_prices = self._get_etf_prices(trade_code, date)
|
|
# Entry price is the actual buy price: today's open
|
|
entry_etf = etf_prices['open'] if etf_prices else None
|
|
# Index close at T-1 (the data used to make the decision)
|
|
entry_idx = self._get_index_close(code, signal_date)
|
|
entry_info[code] = {
|
|
'entry_date': date.strftime('%Y-%m-%d'),
|
|
'entry_price_etf': entry_etf,
|
|
'entry_price_idx': entry_idx,
|
|
}
|
|
for code in removed:
|
|
entry_info.pop(code, None)
|
|
|
|
# Compute bond threshold value for detail record
|
|
threshold_val = 0.0
|
|
if self.use_dynamic_threshold and bond_momentum is not None:
|
|
threshold_val = round(bond_momentum * self.bond_ratio, 6)
|
|
|
|
self.daily_records.append({
|
|
'date': date.strftime('%Y-%m-%d'),
|
|
'nav': round(nav, 6),
|
|
'daily_return': round(daily_return, 6),
|
|
'is_rebalance': is_rebalance,
|
|
'holdings': sorted(new_holdings),
|
|
'added': sorted(added),
|
|
'removed': sorted(removed),
|
|
'factors': {k: round(v, 6) for k, v in factors.items()},
|
|
'threshold': threshold_val,
|
|
})
|
|
|
|
current_holdings = new_holdings
|
|
|
|
if (i + 1) % 100 == 0 or i == len(self.trading_calendar) - 1:
|
|
print(f" [{i+1}/{len(self.trading_calendar)}] "
|
|
f"NAV: {nav:.4f} | Rebal: {rebalance_count}")
|
|
|
|
metrics = self._compute_metrics(rebalance_count)
|
|
|
|
print("\n" + "=" * 60)
|
|
print(" Backtest Complete")
|
|
print("=" * 60)
|
|
if self.daily_records:
|
|
print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}")
|
|
print(f" Trading days: {len(self.daily_records)}")
|
|
print(f" Total return: {metrics['total_return']:.2%}")
|
|
print(f" Annual return: {metrics['annual_return']:.2%}")
|
|
print(f" Max drawdown: {metrics['max_drawdown']:.2%}")
|
|
print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}")
|
|
print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}")
|
|
print(f" Win rate: {metrics['win_rate']:.2%}")
|
|
print(f" Rebalances: {rebalance_count}")
|
|
print("=" * 60)
|
|
|
|
return {
|
|
'metrics': metrics,
|
|
'daily_records': self.daily_records,
|
|
}
|
|
|
|
def _compute_metrics(self, rebalance_count: int) -> dict:
|
|
"""Compute performance metrics"""
|
|
if not self.daily_records:
|
|
return {}
|
|
df = pd.DataFrame(self.daily_records)
|
|
df['date'] = pd.to_datetime(df['date'])
|
|
nav = df['nav']
|
|
returns = df['daily_return']
|
|
|
|
total_return = nav.iloc[-1] / nav.iloc[0] - 1
|
|
n_days = len(df)
|
|
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
|
|
|
|
peak = nav.cummax()
|
|
drawdown = (nav - peak) / peak
|
|
max_drawdown = drawdown.min()
|
|
|
|
sharpe = returns.mean() / returns.std() * np.sqrt(252) if returns.std() > 0 else 0
|
|
calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0
|
|
|
|
non_zero = returns[returns != 0]
|
|
win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0
|
|
|
|
return {
|
|
'total_return': total_return,
|
|
'annual_return': annual_return,
|
|
'max_drawdown': max_drawdown,
|
|
'sharpe_ratio': sharpe,
|
|
'calmar_ratio': calmar,
|
|
'win_rate': win_rate,
|
|
'n_days': n_days,
|
|
'rebalance_count': rebalance_count,
|
|
}
|
|
|
|
# ============================================================
|
|
# Detail JSON helpers (V2-compatible)
|
|
# ============================================================
|
|
|
|
def _build_meta_codes(self) -> dict:
|
|
"""Build meta.codes mapping: signal_code -> {name, etf, market}"""
|
|
codes = {}
|
|
for code, asset in self.config.asset_pools.assets.items():
|
|
codes[asset.signal_source] = {
|
|
'name': getattr(asset, 'name', code),
|
|
'etf': asset.trade_source,
|
|
'market': asset.group,
|
|
}
|
|
return codes
|
|
|
|
def _get_index_close(self, code: str, date: pd.Timestamp) -> Optional[float]:
|
|
"""Get index close price on or before date."""
|
|
df = self.index_data.get(code)
|
|
if df is None:
|
|
return None
|
|
mask = df.index <= date
|
|
if mask.sum() == 0:
|
|
return None
|
|
return float(df.loc[mask].iloc[-1]['close'])
|
|
|
|
def _get_etf_close(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
|
|
"""Get ETF close price on or before date."""
|
|
df = self.etf_data.get(trade_code)
|
|
if df is None:
|
|
return None
|
|
mask = df.index <= date
|
|
if mask.sum() == 0:
|
|
return None
|
|
return float(df.loc[mask].iloc[-1]['close'])
|
|
|
|
def _get_daily_returns(self, code: str, date: pd.Timestamp) -> Tuple[Optional[float], Optional[float]]:
|
|
"""Get (index_return, etf_return_ctc) for a code on a given date."""
|
|
trade_code = self.signal_to_trade.get(code, code)
|
|
idx_df = self.index_data.get(code)
|
|
etf_df = self.etf_data.get(trade_code)
|
|
|
|
idx_ret = None
|
|
if idx_df is not None:
|
|
mask = idx_df.index <= date
|
|
if mask.sum() >= 2:
|
|
rows = idx_df.loc[mask]
|
|
c_today = float(rows.iloc[-1]['close'])
|
|
c_prev = float(rows.iloc[-2]['close'])
|
|
if c_prev > 0 and not pd.isna(c_today):
|
|
idx_ret = round((c_today - c_prev) / c_prev, 6)
|
|
|
|
etf_ret = None
|
|
if etf_df is not None:
|
|
mask = etf_df.index <= date
|
|
if mask.sum() >= 2:
|
|
rows = etf_df.loc[mask]
|
|
c_today = float(rows.iloc[-1]['close'])
|
|
c_prev = float(rows.iloc[-2]['close'])
|
|
if c_prev > 0 and not pd.isna(c_today):
|
|
etf_ret = round((c_today - c_prev) / c_prev, 6)
|
|
|
|
return idx_ret, etf_ret
|
|
|
|
def _compute_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
|
|
"""Get real premium from API data cache: (ETF_price - NAV) / NAV.
|
|
Returns None for BOND or when premium data is unavailable."""
|
|
group = self.trade_code_to_group.get(trade_code, '')
|
|
if group == 'BOND':
|
|
return None
|
|
premium_dict = self.data_cache.premium_data.get(trade_code)
|
|
if not premium_dict:
|
|
return None
|
|
date_str = date.strftime('%Y-%m-%d')
|
|
val = premium_dict.get(date_str)
|
|
if val is None:
|
|
return None
|
|
return round(float(val), 6)
|
|
|
|
def _get_latest_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
|
|
"""Get premium for trade_code, looking back up to 5 days if exact date not found."""
|
|
group = self.trade_code_to_group.get(trade_code, '')
|
|
if group == 'BOND':
|
|
return None
|
|
premium_dict = self.data_cache.premium_data.get(trade_code)
|
|
if not premium_dict:
|
|
return None
|
|
# Try exact date first, then look back up to 5 calendar days
|
|
for offset in range(6):
|
|
check_date = date - timedelta(days=offset)
|
|
date_str = check_date.strftime('%Y-%m-%d')
|
|
val = premium_dict.get(date_str)
|
|
if val is not None:
|
|
return round(float(val), 6)
|
|
return None
|
|
|
|
def _build_day_assets(self, record: dict, date: pd.Timestamp,
|
|
entry_info: Dict[str, dict]) -> dict:
|
|
"""Build V2-compatible per-asset detail dict for one day."""
|
|
factors = record.get('factors', {})
|
|
holdings = set(record['holdings'])
|
|
threshold = record.get('threshold', 0.0)
|
|
|
|
# Rank: sort all codes by momentum descending, rank 1 = highest
|
|
sorted_codes = sorted(factors.keys(), key=lambda c: factors[c], reverse=True)
|
|
rank_map = {c: i + 1 for i, c in enumerate(sorted_codes)}
|
|
|
|
assets = {}
|
|
for code in self.signal_codes:
|
|
momentum = factors.get(code)
|
|
rank = rank_map.get(code)
|
|
above_thresh = (momentum is not None and momentum >= threshold) if threshold else (momentum is not None)
|
|
is_held = code in holdings
|
|
|
|
trade_code = self.signal_to_trade.get(code, code)
|
|
idx_close = self._get_index_close(code, date)
|
|
etf_close = self._get_etf_close(trade_code, date)
|
|
idx_ret, etf_ret_ctc = self._get_daily_returns(code, date)
|
|
premium = self._compute_premium(trade_code, date)
|
|
|
|
# Entry / holding info
|
|
ei = entry_info.get(code) if is_held else None
|
|
entry_date = ei['entry_date'] if ei else None
|
|
entry_price_etf = ei['entry_price_etf'] if ei else None
|
|
entry_price_idx = ei['entry_price_idx'] if ei else None
|
|
holding_days = 0
|
|
cum_ret_etf = None
|
|
cum_ret_idx = None
|
|
|
|
if ei and is_held:
|
|
entry_dt = pd.Timestamp(ei['entry_date'])
|
|
holding_days = int((date - entry_dt).days)
|
|
if holding_days < 1:
|
|
holding_days = 1
|
|
ep_etf = ei.get('entry_price_etf')
|
|
if ep_etf and ep_etf > 0 and etf_close is not None and not pd.isna(etf_close):
|
|
cum_ret_etf = round((etf_close - ep_etf) / ep_etf, 6)
|
|
ep_idx = ei.get('entry_price_idx')
|
|
if ep_idx and ep_idx > 0 and idx_close is not None and not pd.isna(idx_close):
|
|
cum_ret_idx = round((idx_close - ep_idx) / ep_idx, 6)
|
|
|
|
assets[code] = {
|
|
'momentum': float(momentum) if momentum is not None else None,
|
|
'rank': rank,
|
|
'threshold': float(threshold),
|
|
'above_threshold': bool(above_thresh),
|
|
'index_close': float(idx_close) if idx_close is not None else None,
|
|
'etf_close': float(etf_close) if etf_close is not None else None,
|
|
'index_return': float(idx_ret) if idx_ret is not None else None,
|
|
'etf_return_ctc': float(etf_ret_ctc) if etf_ret_ctc is not None else None,
|
|
'premium': float(premium) if premium is not None else None,
|
|
'is_held': bool(is_held),
|
|
'entry_date': entry_date,
|
|
'entry_price_etf': float(entry_price_etf) if entry_price_etf is not None else None,
|
|
'entry_price_idx': float(entry_price_idx) if entry_price_idx is not None else None,
|
|
'holding_days': holding_days,
|
|
'cum_return_etf': float(cum_ret_etf) if cum_ret_etf is not None else None,
|
|
'cum_return_idx': float(cum_ret_idx) if cum_ret_idx is not None else None,
|
|
}
|
|
return assets
|
|
|
|
# ============================================================
|
|
# Export
|
|
# ============================================================
|
|
|
|
def export_results(self, output_dir: str = None):
    """Export backtest results to CSV and JSON (V2-compatible detail format).

    Writes four files into ``output_dir`` (default: a ``results`` directory
    next to this module, created on demand), in this order:

    * ``simple_rotation_nav.csv``      -- columns date, nav, daily_return
    * ``simple_rotation_signals.csv``  -- columns date, holdings,
      is_rebalance, added, removed
    * ``simple_rotation_detail.json``  -- run metadata plus one entry per day,
      each with per-asset snapshots produced by ``_build_day_assets``
    * ``simple_rotation_metrics.json`` -- output of ``_compute_metrics`` for
      the number of rebalance days

    Prints one line per written file. When there are no daily records it
    prints a notice and writes nothing.
    """
    records = self.daily_records
    if not records:
        print(" x No results to export")
        return

    target = Path(__file__).parent / 'results' if output_dir is None else Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)

    table = pd.DataFrame(records)

    nav_csv = target / 'simple_rotation_nav.csv'
    table[['date', 'nav', 'daily_return']].to_csv(nav_csv, index=False)
    print(f" + NAV: {nav_csv}")

    signals_csv = target / 'simple_rotation_signals.csv'
    table[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(signals_csv, index=False)
    print(f" + Signals: {signals_csv}")

    # Each day's decision is based on the previous record's date (T-1); the
    # first record has no predecessor and maps to its own date.
    stamps = [pd.Timestamp(r['date']) for r in records]
    signal_day_for = dict(zip(stamps, stamps[:1] + stamps[:-1]))

    # Entry info for currently-held codes, rebuilt day by day (kept
    # consistent with the entry tracking done in run()).
    open_entries: Dict[str, dict] = {}
    held_before = []
    day_rows = []

    for record in records:
        day = pd.Timestamp(record['date'])
        decision_day = signal_day_for[day]
        held_now = record['holdings']
        now_set, before_set = set(held_now), set(held_before)

        # New positions: ETF entry = day T's open (actual buy price);
        # index reference = close on the T-1 decision day.
        for code in now_set - before_set:
            etf_px = self._get_etf_prices(self.signal_to_trade.get(code, code), day)
            open_entries[code] = {
                'entry_date': day.strftime('%Y-%m-%d'),
                'entry_price_etf': etf_px['open'] if etf_px else None,
                'entry_price_idx': self._get_index_close(code, decision_day),
            }
        # Closed positions no longer carry entry info.
        for code in before_set - now_set:
            open_entries.pop(code, None)

        selected = dict.fromkeys(held_now, 1)
        asset_rows = self._build_day_assets(record, day, open_entries)

        day_rows.append({
            'date': record['date'],
            'nav': record['nav'],
            'daily_return': record['daily_return'],
            'is_rebalance': record['is_rebalance'],
            'signals': selected,
            'holdings': held_now,
            'added': record['added'],
            'removed': record['removed'],
            'assets': asset_rows,
        })
        held_before = held_now

    backtest_cfg = self.config.backtest
    detail = {
        'meta': {
            'mode': 'Simple: Daily Iteration',
            'start_date': backtest_cfg.start_date,
            # records is guaranteed non-empty by the early return above.
            'end_date': records[-1]['date'],
            'total_days': len(records),
            'select_num': self.select_num,
            'n_days': self.n_days,
            'trade_cost': self.trade_cost,
            'bond_threshold': {
                'enabled': self.use_dynamic_threshold,
                'bond_code': self.bond_code,
                'ratio': self.bond_ratio,
            },
            'codes': self._build_meta_codes(),
        },
        'days': day_rows,
    }

    detail_json = target / 'simple_rotation_detail.json'
    with open(detail_json, 'w', encoding='utf-8') as fh:
        json.dump(detail, fh, ensure_ascii=False, indent=2)
    print(f" + Detail: {detail_json} ({len(day_rows)} days)")

    rebalance_days = sum(bool(r['is_rebalance']) for r in records)
    metrics = self._compute_metrics(rebalance_days)
    metrics_json = target / 'simple_rotation_metrics.json'
    with open(metrics_json, 'w', encoding='utf-8') as fh:
        json.dump(metrics, fh, ensure_ascii=False, indent=2)
    print(f" + Metrics: {metrics_json}")
|
|
|
|
# ============================================================
|
|
# Report Generation (PNG chart with tables)
|
|
# ============================================================
|
|
|
|
def generate_report(self, output_dir: Optional[str] = None):
    """Generate the performance report chart (PNG).

    The figure stacks five panels:
      0. latest signal table: current holdings sorted by momentum score
         (action "调入" when newly added on a rebalance day, else "维持"),
         plus "调出" rows for dropped codes when the last record is itself a
         rebalance day;
      1. strategy vs. benchmark metrics table;
      2. NAV curves for the strategy, the benchmark and each signal asset
         (log scale);
      3. strategy drawdown;
      4. per-asset daily holdings timeline.

    Benchmark and asset NAVs are the ``'close'`` series forward-filled onto the
    strategy's dates and rebased so their first valid value is 1.0.

    Args:
        output_dir: Directory that receives ``simple_rotation_report.png``.
            Defaults to a ``results`` folder next to this module; created if
            missing.

    Prints a notice and returns without writing anything when there are no
    daily records.
    """
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend; the figure is only saved to disk
    import matplotlib.pyplot as plt

    if not self.daily_records:
        print(" x No results for report")
        return

    if output_dir is None:
        output_dir = Path(__file__).parent / 'results'
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    def _rebase(prices):
        """Return ``prices`` divided by its first non-NaN value, or None if none.

        ``(1 + pct_change()).cumprod()`` leaves row 0 as NaN and, once rebased
        to its first valid value, silently drops the move from the first to
        the second date. Dividing by the first valid price keeps the full
        window, matching how the strategy's own total return is measured.
        """
        valid = prices.dropna()
        if valid.empty or valid.iloc[0] == 0:
            return None
        return prices / valid.iloc[0]

    def _fmt(value, spec):
        """Format ``value`` with format spec ``spec``; em dash for None."""
        return "—" if value is None else format(value, spec)

    # ---- Strategy series ----
    df = pd.DataFrame(self.daily_records)
    df['date'] = pd.to_datetime(df['date'])
    df = df.set_index('date')
    strategy_nav = df['nav']
    strategy_ret = df['daily_return']

    # ---- Benchmark and per-asset NAVs, aligned to the strategy dates ----
    benchmark_nav = None
    if self.benchmark_data is not None:
        bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill')
        benchmark_nav = _rebase(bm_close)

    asset_navs = {}
    for code in self.signal_codes:
        if code in self.index_data:
            price = self.index_data[code]['close'].reindex(df.index, method='ffill')
            nav_s = _rebase(price)
            if nav_s is not None:
                asset_navs[code] = nav_s

    # ---- Strategy metrics (252 trading days per year) ----
    n_days = len(df)
    s_total_return = strategy_nav.iloc[-1] / strategy_nav.iloc[0] - 1
    s_annual = (1 + s_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
    s_sharpe = strategy_ret.mean() / strategy_ret.std() * np.sqrt(252) if strategy_ret.std() > 0 else 0
    s_peak = strategy_nav.cummax()
    drawdown = (strategy_nav - s_peak) / s_peak
    s_dd = drawdown.min()
    s_calmar = s_annual / abs(s_dd) if s_dd != 0 else 0
    # Daily win rate ignores zero-return days.
    non_zero = strategy_ret[strategy_ret != 0]
    s_win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0

    # ---- Benchmark metrics (all zero when no benchmark is available) ----
    b_total_return = b_annual = b_sharpe = b_dd = 0
    if benchmark_nav is not None:
        bm_ret = benchmark_nav.pct_change()
        b_total_return = benchmark_nav.iloc[-1] - 1  # series is rebased to 1.0
        b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
        b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0
        b_peak = benchmark_nav.cummax()
        b_dd = ((benchmark_nav - b_peak) / b_peak).min()

    # ---- Latest record ----
    last_rec = self.daily_records[-1]
    last_date = pd.Timestamp(last_rec['date'])
    holdings = last_rec['holdings']
    factors = last_rec.get('factors') or {}  # tolerate a missing or None entry
    is_rebalance_day = last_rec.get('is_rebalance', False)
    prev_holdings = (set(self.daily_records[-2]['holdings'])
                     if len(self.daily_records) > 1 else set())

    # Display metadata keyed by signal (index) code.
    code_config = {}
    for name, asset in self.config.asset_pools.assets.items():
        code_config[asset.signal_source] = {
            'name': getattr(asset, 'name', name),
            'etf': asset.trade_source,
            'market': asset.group,
        }

    # Equal weight per selected slot.
    weight = 1.0 / self.select_num if self.select_num > 0 else 1.0

    # ---- Current positions, highest momentum score first ----
    positions_info = []
    for code in sorted(holdings, key=lambda c: factors.get(c, 0) or 0, reverse=True):
        cfg = code_config.get(code, {})
        trade_code = self.signal_to_trade.get(code, code)
        idx_close = self._get_index_close(code, last_date)
        etf_close = self._get_etf_close(trade_code, last_date)
        premium = self._get_latest_premium(trade_code, last_date)

        # Walk back over consecutive days holding `code` to find where the
        # current holding run started.
        entry_date = None
        for rec in reversed(self.daily_records):
            if code not in rec['holdings']:
                break
            entry_date = pd.Timestamp(rec['date'])

        entry_price = None
        holding_days = 0
        pnl = None
        if entry_date is not None:
            # Only the run's first day matters, so look up its open price
            # once instead of on every day of the backward scan.
            entry_px = self._get_etf_prices(trade_code, entry_date)
            entry_price = entry_px['open'] if entry_px else None
            holding_days = (last_date - entry_date).days  # calendar days
            if entry_price and entry_price > 0 and etf_close and etf_close > 0:
                pnl = etf_close / entry_price - 1

        positions_info.append({
            'name': cfg.get('name', code), 'code': code, 'etf': cfg.get('etf', '—'),
            'weight': weight, 'score': factors.get(code),
            'idx_close': idx_close, 'etf_close': etf_close,
            'premium': premium,
            # Newly added on a rebalance day -> enter; otherwise -> hold.
            'action': '调入' if is_rebalance_day and code not in prev_holdings else '维持',
            'entry_date': entry_date, 'entry_price': entry_price,
            'holding_days': holding_days, 'pnl': pnl,
        })

    # ---- Exits: only listed when the last record is itself a rebalance day ----
    exit_positions = []
    if is_rebalance_day and len(self.daily_records) > 1:
        for code in sorted(prev_holdings - set(holdings)):
            cfg = code_config.get(code, {})
            trade_code = self.signal_to_trade.get(code, code)
            exit_positions.append({
                'name': cfg.get('name', code), 'code': code, 'etf': cfg.get('etf', '—'),
                'weight': weight, 'score': None,
                'idx_close': self._get_index_close(code, last_date),
                'etf_close': self._get_etf_close(trade_code, last_date),
                'premium': self._get_latest_premium(trade_code, last_date),
                'action': '调出',
                'entry_date': None, 'entry_price': None,
                'holding_days': 0, 'pnl': None,
            })

    # ==================== Plot ====================
    plt.rcParams["font.sans-serif"] = ["Arial Unicode MS", "WenQuanYi Zen Hei", "DejaVu Sans"]
    plt.rcParams["axes.unicode_minus"] = False

    n_rows = len(positions_info) + len(exit_positions)
    signal_h = max(1.5, 0.5 + n_rows * 0.35)  # signal panel grows with table rows
    fig = plt.figure(figsize=(16, 10 + signal_h + 1.2 + 8))
    try:
        gs = fig.add_gridspec(5, 1, height_ratios=[signal_h, 1.2, 3, 1, 1.2], hspace=0.35)

        # Panel 0: signal table
        ax0 = fig.add_subplot(gs[0])
        ax0.axis("off")
        ax0.set_title(f"最新调仓信号 (信号日期: {last_date.strftime('%Y-%m-%d')},下一交易日执行)",
                      fontsize=14, fontweight="bold", loc="left", pad=15)

        col_labels = ["标的名称", "指数代码", "ETF代码", "仓位", "得分", "指数最新价",
                      "ETF收盘价", "溢价率", "操作", "持有天数", "盈亏"]
        table_data = [
            [p['name'], p['code'], p['etf'], f"{p['weight']:.0%}",
             _fmt(p['score'], '.2f'), _fmt(p['idx_close'], '.2f'),
             _fmt(p['etf_close'], '.3f'), _fmt(p['premium'], '+.2%'),
             p['action'],
             str(p['holding_days']) if p['holding_days'] > 0 else "—",
             _fmt(p['pnl'], '+.2%')]
            for p in positions_info + exit_positions
        ]

        if table_data:
            tbl = ax0.table(cellText=table_data, colLabels=col_labels, loc="center",
                            cellLoc="center", bbox=[0, 0, 1, 1])
            tbl.auto_set_font_size(False)
            tbl.set_fontsize(9)
            tbl.scale(1, 1.8)
            for j in range(len(col_labels)):
                tbl[0, j].set_facecolor("#2C3E50")
                tbl[0, j].set_text_props(color="white", fontweight="bold")
            # Row colour by action (column 8): enter = light green,
            # exit = light red, hold = light yellow.
            row_colors = {"调入": "#d4edda", "调出": "#f8d7da"}
            for i, row in enumerate(table_data, start=1):
                color = row_colors.get(row[8], "#fff3cd")
                for j in range(len(col_labels)):
                    tbl[i, j].set_facecolor(color)

        # Panel 1: metrics table
        ax1 = fig.add_subplot(gs[1])
        ax1.axis("off")
        ax1.set_title("策略绩效对比", fontsize=14, fontweight="bold", loc="left", pad=10)

        start_s = df.index.min().strftime('%Y-%m-%d')
        end_s = df.index.max().strftime('%Y-%m-%d')
        perf_cols = ["策略", "开始时间", "结束时间", "累计收益", "年化收益", "最大回撤", "夏普比率", "Calmar比率", "日胜率"]
        strat_row = ["轮动策略", start_s, end_s,
                     f"{s_total_return:.2%}", f"{s_annual:.2%}", f"{s_dd:.2%}",
                     f"{s_sharpe:.2f}", f"{s_calmar:.2f}", f"{s_win_rate:.2%}"]
        bench_row = [f"基准({self.benchmark_name})", start_s, end_s,
                     f"{b_total_return:.2%}", f"{b_annual:.2%}", f"{b_dd:.2%}",
                     f"{b_sharpe:.2f}", "—", "—"]
        ptbl = ax1.table(cellText=[strat_row, bench_row], colLabels=perf_cols,
                         loc="center", cellLoc="center", bbox=[0, 0, 1, 1])
        ptbl.auto_set_font_size(False)
        ptbl.set_fontsize(10)
        ptbl.scale(1, 1.8)
        for j in range(len(perf_cols)):
            ptbl[0, j].set_facecolor("#2C3E50")
            ptbl[0, j].set_text_props(color="white", fontweight="bold")
            ptbl[1, j].set_facecolor("#d4edda")
            ptbl[2, j].set_facecolor("#cce5ff")

        # Panel 2: NAV curves
        ax2 = fig.add_subplot(gs[2])
        ax2.plot(strategy_nav.index, strategy_nav.values,
                 label="轮动策略", linewidth=2, color="#E74C3C")
        if benchmark_nav is not None:
            ax2.plot(benchmark_nav.index, benchmark_nav.values,
                     label=self.benchmark_name, linewidth=1.5, color="#3498DB", alpha=0.8)
        colors = plt.cm.tab20.colors
        for i, code in enumerate(self.signal_codes):
            if code in asset_navs:
                # Only the first 10 signal codes get legend entries.
                lbl = code_config.get(code, {}).get('name', code) if i < 10 else None
                ax2.plot(asset_navs[code].index, asset_navs[code].values,
                         label=lbl, linewidth=0.8, alpha=0.4,
                         color=colors[i % len(colors)])
        ax2.set_title("ETF轮动策略 - 净值曲线", fontsize=16, fontweight="bold")
        ax2.set_ylabel("净值")
        ax2.legend(loc="upper left", fontsize=8, ncol=2)
        ax2.grid(True, alpha=0.3)
        ax2.set_yscale("log")

        # Panel 3: drawdown
        ax3 = fig.add_subplot(gs[3])
        ax3.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C")
        ax3.set_title("策略回撤", fontsize=12)
        ax3.set_ylabel("回撤")
        ax3.grid(True, alpha=0.3)

        # Panel 4: holdings timeline, one horizontal band per signal code
        ax4 = fig.add_subplot(gs[4])
        bar_h = 0.8  # height of each code's band
        holdings_series = df['holdings']
        for i, code in enumerate(self.signal_codes):
            mask = holdings_series.apply(lambda h, c=code: c in h)
            if mask.any():
                ax4.fill_between(mask.index, i, i + bar_h,
                                 where=mask, alpha=0.7,
                                 color=colors[i % len(colors)],
                                 label=code_config.get(code, {}).get('name', code))
        ylabels = [code_config.get(c, {}).get('name', c) for c in self.signal_codes]
        ax4.set_title("每日持仓分布", fontsize=12)
        # Place each label at the middle of its band, not at its bottom edge.
        ax4.set_yticks([i + bar_h / 2 for i in range(len(ylabels))])
        ax4.set_yticklabels(ylabels, fontsize=7)
        ax4.grid(True, alpha=0.3)

        chart_path = output_dir / 'simple_rotation_report.png'
        fig.savefig(str(chart_path), dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    print(f" + Report: {chart_path}")
|
|
|
|
|
|
# ============================================================
|
|
# Entry point
|
|
# ============================================================
|
|
|
|
if __name__ == "__main__":
    # Use the default API endpoint unless the environment already provides one.
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')

    strategy = SimpleRotationStrategy()
    result = strategy.run()

    # Outputs are only written when the backtest returned a truthy result.
    if result:
        strategy.export_results()
        strategy.generate_report()
|