- config_loader.py: WeightType 枚举新增 KELLY - simple_rotation.py: compute_position_weights 新增 kelly 分支 - 公式: w_i = max(score_i, 0) / sum(max(score_j, 0)) - 负分自动排除 (Kelly: 不下注负期望) - 全负分时 fallback 到等权 - _generate_signals 传递 scores 给 kelly 模式 - config_simple.yaml: weight 改为 kelly - 新增策略总结文档: kelly_weight.md 回测对比 (2020-2026): - equal: 年化 19.88%, 夏普 1.13, 回撤 -14.65% - rank: 年化 22.90%, 夏普 1.12, 回撤 -16.27% - kelly: 年化 30.13%, 夏普 1.15, 回撤 -20.44%
1628 lines
71 KiB
Python
1628 lines
71 KiB
Python
"""
|
||
Simple Rotation Strategy (Daily Iteration)
|
||
|
||
From A-share trading calendar, iterate daily:
|
||
1. Get signal source last N days -> compute momentum
|
||
2. Get bond momentum (dynamic threshold)
|
||
3. Group selection -> generate holdings
|
||
4. Compare with yesterday -> compute T+1 return
|
||
5. Update NAV
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import math
|
||
import json
|
||
import time
|
||
import requests
|
||
import numpy as np
|
||
import pandas as pd
|
||
from pathlib import Path
|
||
from datetime import datetime, timedelta
|
||
from typing import Dict, List, Optional, Tuple
|
||
|
||
# Make the project root importable when this file is executed directly, so the
# absolute `rotation.*` import that follows can be resolved.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, os.fspath(PROJECT_ROOT))
|
||
|
||
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType, WeightType
|
||
|
||
# ============================================================
|
||
# HTTP client (requests + trust_env=False,绕过系统代理避免 SSL EOF)
|
||
# ============================================================
|
||
|
||
# Proxies such as Clash can raise "SSL EOF" errors when handling TLS 1.3 with
# post-quantum key exchange.
# trust_env=False makes requests ignore the proxy settings taken from
# environment variables and connect to the target server directly (per the
# requests docs this also ignores similar environment-derived settings).
# This one session object is used by every request issued via _http_get.
_session = requests.Session()
_session.trust_env = False
|
||
|
||
def _http_get(url: str, params: dict = None, timeout: int = 120) -> requests.Response:
    """Send a GET request through the shared, proxy-bypassing session.

    Args:
        url: Absolute URL to request.
        params: Optional query-string parameters.
        timeout: Timeout in seconds forwarded to requests.

    Returns:
        The raw ``requests.Response``; the status code is not checked here.
    """
    response = _session.get(url, timeout=timeout, params=params)
    return response
|
||
|
||
|
||
def _sanitize_json(obj):
|
||
"""Recursively replace NaN/Inf with None in-place so json.dump produces valid JSON"""
|
||
if isinstance(obj, dict):
|
||
for k, v in obj.items():
|
||
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
|
||
obj[k] = None
|
||
elif isinstance(v, (dict, list)):
|
||
_sanitize_json(v)
|
||
elif isinstance(obj, list):
|
||
for i, v in enumerate(obj):
|
||
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
|
||
obj[i] = None
|
||
elif isinstance(v, (dict, list)):
|
||
_sanitize_json(v)
|
||
|
||
|
||
# ============================================================
|
||
# Pure functions: momentum
|
||
# ============================================================
|
||
|
||
def weighted_momentum_score(prices: np.ndarray) -> float:
    """Score a price window as annualized log-trend return times weighted R².

    A straight line is fitted to log(prices) against the bar index. Weights
    rise linearly from 1 (oldest bar) to 2 (newest bar), so recent bars count
    more. The slope is annualized over 250 bars as ``exp(slope * 250) - 1``
    and then scaled by the weighted R² of the fit.

    Args:
        prices: Price window, oldest first; values are floored at 0.01.

    Returns:
        The score. Returns 0.0 when fewer than 5 prices are given or when the
        logged prices are not all finite.
    """
    if len(prices) < 5:
        return 0.0

    log_prices = np.log(np.clip(prices, 0.01, None))
    if not np.isfinite(log_prices).all():
        return 0.0

    bars = np.arange(len(log_prices))
    bar_weights = np.linspace(1, 2, len(log_prices))
    # NOTE(review): np.polyfit applies `w` to the *unsquared* residuals, so
    # this fit effectively minimises sum(w**2 * r**2). The R² below weights
    # the squared residuals by w. Kept as-is to preserve existing scores;
    # confirm which weighting is intended.
    slope, intercept = np.polyfit(bars, log_prices, 1, w=bar_weights)
    annualized = math.exp(slope * 250) - 1

    residuals = log_prices - (slope * bars + intercept)
    deviations = log_prices - np.average(log_prices, weights=bar_weights)
    ss_res = np.sum(bar_weights * residuals ** 2)
    ss_tot = np.sum(bar_weights * deviations ** 2)
    r_squared = 1 - ss_res / ss_tot if ss_tot > 0 else 0
    return annualized * r_squared
|
||
|
||
|
||
def vol_adjusted_momentum_score(prices: np.ndarray) -> float:
    """Volatility-adjusted momentum: (annualized_return / annualized_vol) * R².

    Follows the Moskowitz, Ooi & Pedersen (2012) TSMOM idea of dividing
    momentum by realized volatility, so assets with different volatility can
    be compared fairly. The trend fit is a weighted line through log(prices):
    weights go 1 -> 2, oldest -> newest, and the slope is annualized over 250
    bars. Realized volatility is the population std of daily log returns
    times sqrt(250), floored at 0.01.

    Returns:
        The score. Returns 0.0 when fewer than 5 prices are given or when the
        logged prices are not all finite.
    """
    if len(prices) < 5:
        return 0.0

    log_p = np.log(np.clip(prices, 0.01, None))
    if not np.isfinite(log_p).all():
        return 0.0

    t = np.arange(len(log_p))
    w = np.linspace(1, 2, len(log_p))
    # NOTE(review): np.polyfit applies `w` to the unsquared residuals, so the
    # fit effectively uses w**2 while the R² below uses w. Confirm intent.
    slope, intercept = np.polyfit(t, log_p, 1, w=w)
    annual_return = math.exp(slope * 250) - 1

    # Floor guards against division by a near-zero volatility.
    annual_vol = max(float(np.std(np.diff(log_p))) * math.sqrt(250), 0.01)

    # Trend quality (weighted R²)
    ss_res = np.sum(w * (log_p - (slope * t + intercept)) ** 2)
    ss_tot = np.sum(w * (log_p - np.average(log_p, weights=w)) ** 2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
    return (annual_return / annual_vol) * r2
|
||
|
||
|
||
def slope_r2_score(prices: np.ndarray) -> float:
    """Score a window as 10000 * slope * R² of an unweighted linear fit.

    Prices are normalised by the first price (p / p[0]) so that slopes are
    comparable across assets. They are then fitted with ordinary least
    squares against the bar index.

    Args:
        prices: Price window, oldest first; values are floored at 0.01.

    Returns:
        The score. Returns 0.0 when fewer than 5 prices are given or when the
        window contains NaN/inf, matching ``weighted_momentum_score``.
    """
    if len(prices) < 5:
        return 0.0
    prices = np.clip(prices, 0.01, None)
    # Non-finite input can make np.polyfit raise or produce a NaN score,
    # which would break max()/sort-based ranking; treat it as "no signal".
    if not np.isfinite(prices).all():
        return 0.0

    y = prices / prices[0]  # normalize
    x = np.arange(len(y))
    slope, intercept = np.polyfit(x, y, 1)
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
    return 10000 * slope * r2
|
||
|
||
|
||
def standardized_slope_score(prices: np.ndarray) -> float:
    """Standardized slope (t-statistic): slope / SE(slope).

    Academic basis:
      - Uses normalized prices (p/p[0]) for cross-asset comparability,
        consistent with slope_r2_score.
      - Divides the slope by its standard error, giving a measure of
        statistical significance rather than raw magnitude.
      - Equivalent to the t-value for H0: slope=0, which penalizes noisy trends.

    Formula:
        SE(slope) = sqrt(MSE / Sxx)
        MSE = SS_res / (n - 2)
        Sxx = sum((xi - x_bar)^2) = n*(n-1)*(n+1)/12 for x = 0..n-1

    The standard error is floored at 1e-12, so a perfectly linear window
    yields a very large (but finite) score.

    Returns:
        The t-statistic. Returns 0.0 when fewer than 5 prices are given or
        when the window contains NaN/inf, matching ``weighted_momentum_score``.
    """
    n = len(prices)
    if n < 5:
        return 0.0
    prices = np.clip(prices, 0.01, None)
    # Non-finite input can make np.polyfit raise or yield NaN; treat as "no signal".
    if not np.isfinite(prices).all():
        return 0.0

    y = prices / prices[0]  # normalize
    x = np.arange(n)
    slope, intercept = np.polyfit(x, y, 1)
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)

    # Standard error of slope
    mse = ss_res / (n - 2)  # unbiased MSE
    sxx = n * (n - 1) * (n + 1) / 12  # sum of squared deviations of x
    se_slope = math.sqrt(mse / sxx) if sxx > 0 else 1e-9
    if se_slope < 1e-12:
        se_slope = 1e-12
    return slope / se_slope
|
||
|
||
|
||
def momentum_score(prices: np.ndarray) -> float:
    """Simple price return over the window: last / first - 1.

    Prices are floored at 0.01 first, so a zero start price cannot divide by
    zero. Windows shorter than 5 prices score 0.0.
    """
    if len(prices) >= 5:
        floored = np.clip(prices, 0.01, None)
        return floored[-1] / floored[0] - 1
    return 0.0
|
||
|
||
|
||
# ============================================================
|
||
# Pure functions: position weight schemes
|
||
# ============================================================
|
||
|
||
def compute_position_weights(
    ranked_holdings: List[str],
    weight_type: str = 'equal',
    scores: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Compute position weights from a ranked slot list.

    Args:
        ranked_holdings: Ordered list of signal codes, best first. May contain
            duplicates (e.g. bond fills); each entry is one slot.
        weight_type: 'equal', 'rank', or 'kelly'. Any other value is treated
            as 'equal'.
        scores: Required for 'kelly'. Maps code -> momentum score.

    Returns:
        Dict mapping each unique code to its total weight (sum of its slots).
        For a non-empty list the weights sum to 1.

    Raises:
        ValueError: weight_type is 'kelly' and ``scores`` is empty/None.

    Schemes (every scheme assigns a weight per *slot*; duplicates are summed):
        equal: each slot = 1/N.
        rank:  slot i (0-indexed) = (N-i) / triangular(N).
               For N=3: [3/6, 2/6, 1/6] = [50%, 33%, 17%].
        kelly: slot i = max(score_i, 0) / sum over slots of max(score_j, 0).
               Score-proportional weighting as a Kelly-criterion proxy.
               Non-positive, missing or non-finite scores get no weight
               (Kelly: don't bet on a negative edge). If no slot has a
               positive score, falls back to 'equal'.
               Normalising over slots, not unique codes, keeps the total at
               1 when a code fills several slots.
    """
    n_slots = len(ranked_holdings)
    if n_slots == 0:
        return {}

    if weight_type == 'kelly':
        if not scores:
            raise ValueError("Kelly weighting requires 'scores' parameter")
        # Positive part of each slot's score.
        edges: List[float] = []
        for code in ranked_holdings:
            s = scores.get(code, 0.0)
            usable = s is not None and math.isfinite(s) and s > 0
            edges.append(float(s) if usable else 0.0)
        total = sum(edges)
        if total > 0:
            slot_weights = [e / total for e in edges]
        else:
            # No positive edge anywhere: fall back to the 'equal' scheme.
            slot_weights = [1.0 / n_slots] * n_slots
    elif weight_type == 'rank':
        triangular = n_slots * (n_slots + 1) / 2
        slot_weights = [(n_slots - i) / triangular for i in range(n_slots)]
    else:
        # equal (default)
        slot_weights = [1.0 / n_slots] * n_slots

    weights: Dict[str, float] = {}
    for code, w in zip(ranked_holdings, slot_weights):
        weights[code] = weights.get(code, 0.0) + w
    return weights
|
||
|
||
|
||
def is_crash(prices: np.ndarray) -> bool:
    """Crash filter over the last four prices (three daily moves).

    Flags a crash when either:
      * any one of the last three daily moves fell by more than 5%
        (price ratio below 0.95), or
      * all three daily moves were down AND the cumulative three-day move
        fell by more than 5%.
    Windows shorter than four prices never count as a crash.
    """
    if len(prices) < 4:
        return False
    last4 = prices[-4:]
    daily = [last4[k + 1] / last4[k] for k in (0, 1, 2)]  # oldest move first
    sharp_single_day = min(daily) < 0.95
    steady_slide = (daily[2] < 1 and daily[1] < 1 and daily[0] < 1
                    and last4[3] / last4[0] < 0.95)
    return sharp_single_day or steady_slide
|
||
|
||
|
||
# ============================================================
|
||
# Data cache
|
||
# ============================================================
|
||
|
||
class DataCache:
    """HTTP data access for OHLCV series and trading calendars.

    Despite the name, OHLCV frames are not cached: every ``preload`` call goes
    to the API. The only in-memory state is ``premium_data``, the per-code
    premium series collected from OHLCV responses.
    """

    def __init__(self, base_url: str, timeout: int = 120):
        """
        Args:
            base_url: API root; a trailing '/' is stripped.
            timeout: Per-request timeout in seconds.
        """
        self.base_url = base_url.rstrip('/')
        self.api_path = '/api/v1/ohlcv'
        self.timeout = timeout
        # Premium data kept in memory: {trade_code: {date_str: premium_ratio}}
        self.premium_data: Dict[str, Dict[str, float]] = {}

    def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
        """Fetch OHLCV data for ``code`` from the API (not cached).

        Returns the frame built by ``_fetch_api``, or None on failure.
        """
        return self._fetch_api(code, start_date, end_date, adj)

    def _merge_premium(self, code: str, premium_series: list) -> Dict[str, float]:
        """Parse an API ``premium_series`` list and merge it into ``premium_data[code]``.

        Returns the newly parsed ``{date: premium}`` mapping. A malformed item
        raises KeyError before anything is stored.
        """
        new_data = {item['date']: item['premium'] for item in premium_series}
        self.premium_data.setdefault(code, {}).update(new_data)
        return new_data

    def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
        """Fetch OHLCV rows from the Flask API (up to 3 attempts).

        Returns a date-indexed, date-sorted frame limited to the
        open/high/low/close/volume columns that are present. Returns None on
        HTTP failure, an API error, empty data, or any other exception. If the
        response carries a ``premium_series`` (ETFs), it is attached as
        ``df.attrs['premium_series']`` and merged into ``premium_data``.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1 + attempt)  # back off 1s, then 2s
                        continue
                    print(f" x {code}: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x {code}: {data['error']}")
                    return None
                records = data.get('data', [])
                if not records:
                    return None
                df = pd.DataFrame(records)
                if 'date' in df.columns:
                    df['date'] = pd.to_datetime(df['date'])
                    df = df.set_index('date').sort_index()
                keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
                df = df[keep]
                # Extract premium_series (ETF only) and store it in memory
                premium_series = data.get('premium_series', [])
                if premium_series:
                    df.attrs['premium_series'] = self._merge_premium(code, premium_series)
                print(f" + {code}: {len(df)} rows ({adj})")
                return df
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                # Network errors (timeout, SSL, dropped connection) are retried.
                if attempt < 2:
                    time.sleep(1 + attempt)  # increasing delay: 1s, 2s
                    continue
                print(f" x {code}: {type(e).__name__} after {attempt+1} retries")
                return None
            except Exception as e:
                print(f" x {code}: {e}")
                return None
        return None

    def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]:
        """Return premium data for ``code``, fetching it if needed.

        Data already in memory is returned as-is when no ``end_date`` is given
        or when its latest date is >= ``end_date``. Dates are compared as
        strings, which assumes a sortable format such as 'YYYY-MM-DD'.
        Otherwise the full history up to ``end_date`` (default: today) is
        fetched. Returns None if nothing could be loaded.
        """
        cached = self.premium_data.get(code)
        if cached is not None:
            if not end_date:
                return cached
            if cached and max(cached) >= end_date:
                return cached
        self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
        return self.premium_data.get(code)

    def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
        """Fetch ``premium_series`` from the API and merge it into ``premium_data``.

        Best-effort: failures are logged and leave ``premium_data`` unchanged.
        Returns None.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1 + attempt)
                        continue
                    print(f" x premium {code}: HTTP {resp.status_code}")
                    return
                data = resp.json()
                if 'error' in data:
                    print(f" x premium {code}: {data['error']}")
                    return
                premium_series = data.get('premium_series', [])
                if premium_series:
                    new_data = self._merge_premium(code, premium_series)
                    print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})")
                return
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x premium {code}: {type(e).__name__} after {attempt+1} retries")
                return
            except Exception as e:
                # Previously swallowed silently; keep best-effort but make it visible.
                print(f" x premium {code}: {e}")
                return

    def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]:
        """Fetch the trading calendar for ``market`` between two dates.

        Returns a DatetimeIndex (empty if the API lists no dates), or None when
        the request fails after 3 attempts or the API reports an error. Unlike
        the OHLCV fetch, unexpected exceptions are also retried.
        """
        url = f"{self.base_url}/api/v1/trading-calendar"
        params = {'market': market, 'start': start_date, 'end': end_date}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1 + attempt)
                        continue
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x calendar: {data['error']}")
                    return None
                dates = data.get('trading_dates', [])
                if not dates:
                    return pd.DatetimeIndex([])
                result = pd.DatetimeIndex(dates)
                print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
                return result
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x calendar: {type(e).__name__} after {attempt+1} retries")
                return None
            except Exception as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x calendar: {e}")
                return None
        return None
|
||
|
||
|
||
# ============================================================
|
||
# Core strategy class
|
||
# ============================================================
|
||
|
||
class SimpleRotationStrategy:
|
||
"""
|
||
Simple rotation strategy (daily iteration)
|
||
|
||
Flow:
|
||
1. Load config
|
||
2. Preload all historical data (index raw + ETF hfq)
|
||
3. Fetch A-share trading calendar
|
||
4. For each trading day:
|
||
- Compute momentum (last 25 days)
|
||
- Group selection -> generate holdings
|
||
- Compare with yesterday -> compute T+1 return
|
||
- Update NAV
|
||
5. Export results
|
||
"""
|
||
|
||
def __init__(self, config_path: str = None):
    """Load the strategy config and derive parameters, mappings and state.

    Args:
        config_path: Path to the strategy YAML. Defaults to
            ``config_simple.yaml`` next to this file.
    """
    if config_path is None:
        config_path = Path(__file__).parent / 'config_simple.yaml'
    self.config: RotationStrategyConfig = load_rotation_config(str(config_path))

    # Strategy params
    self.n_days = self.config.factor.n_days
    self.select_num = self.config.rotation.select_num
    self.trade_cost = self.config.rebalance.trade_cost
    self.weight_type = self.config.rotation.weight.value  # 'equal', 'rank' or 'kelly'

    # Dynamic threshold
    threshold = self.config.rotation.threshold
    self.use_dynamic_threshold = (threshold.mode.value == 'dynamic')
    self.bond_code = threshold.dynamic.reference if threshold.dynamic else None
    self.bond_ratio = threshold.dynamic.ratio if threshold.dynamic else 1.0
    self.fallback_value = threshold.dynamic.fallback_value if threshold.dynamic else 0.0

    # Signal codes and mappings
    self.signal_codes = []
    self.signal_to_trade = {}
    self.code_to_group = {}
    self.trade_code_to_group = {}
    for code, asset in self.config.asset_pools.assets.items():
        self.signal_codes.append(asset.signal_source)
        self.signal_to_trade[asset.signal_source] = asset.trade_source
        self.code_to_group[asset.signal_source] = asset.group
        self.trade_code_to_group[asset.trade_source] = asset.group

    # Data source (only the first configured source is used)
    data_source = self.config.data.sources[0]
    base_url = data_source.url or 'https://k3s.tokenpluse.xyz'
    self.data_cache = DataCache(base_url=base_url, timeout=data_source.timeout)

    # Preloaded data
    self.index_data: Dict[str, pd.DataFrame] = {}
    self.etf_data: Dict[str, pd.DataFrame] = {}
    self.benchmark_data: Optional[pd.DataFrame] = None

    # Benchmark config
    self.benchmark_code = self.config.benchmark.code
    self.benchmark_name = self.config.benchmark.name

    # Results
    self.daily_records: List[dict] = []
    self.trading_calendar: Optional[pd.DatetimeIndex] = None

    # Position weights used for return calculation: code -> weight
    self._position_weights: Dict[str, float] = {}
    # Weights proposed by the latest _generate_signals call. Initialised here
    # so run() can read it even if signal generation returns early on day one.
    self._pending_weights: Dict[str, float] = {}
|
||
|
||
def _preload_data(self):
|
||
"""Preload all historical data"""
|
||
start_date = self.config.backtest.start_date
|
||
end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
|
||
preload_start = (pd.Timestamp(start_date) - timedelta(days=self.n_days * 2)).strftime('%Y-%m-%d')
|
||
|
||
print("\n[1/4] Preloading signal sources (index raw)...")
|
||
for code in self.signal_codes:
|
||
df = self.data_cache.preload(code, preload_start, end_date, adj='raw')
|
||
if df is not None:
|
||
self.index_data[code] = df
|
||
print(f"\n Signal: {len(self.index_data)}/{len(self.signal_codes)} OK")
|
||
|
||
print("\n[2/4] Preloading trade sources (ETF hfq)...")
|
||
trade_codes = set(self.signal_to_trade.values())
|
||
for code in trade_codes:
|
||
is_bond = any(
|
||
a.trade_source == code and a.group == 'BOND'
|
||
for a in self.config.asset_pools.assets.values()
|
||
)
|
||
adj = 'raw' if is_bond else 'hfq'
|
||
df = self.data_cache.preload(code, preload_start, end_date, adj=adj)
|
||
if df is not None:
|
||
self.etf_data[code] = df
|
||
# Load premium data cache for all ETF trade codes
|
||
for code in trade_codes:
|
||
self.data_cache.preload_premium(code, end_date=end_date)
|
||
print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK, premium: {len(self.data_cache.premium_data)} loaded")
|
||
|
||
# Load benchmark
|
||
print(f"\n Loading benchmark ({self.benchmark_code})...")
|
||
bm_df = self.data_cache.preload(self.benchmark_code, preload_start, end_date, adj='raw')
|
||
if bm_df is not None:
|
||
self.benchmark_data = bm_df
|
||
print(f" Benchmark: {len(bm_df)} rows")
|
||
|
||
def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
    """Score ``signal_code`` as of ``date`` with the configured factor.

    Uses the last ``self.n_days`` closes on or before ``date``.

    Returns:
        None if the code was not preloaded or has fewer than ``n_days`` bars
        up to ``date``; 0.0 if ``is_crash`` fires on the window; otherwise the
        score from the function matching ``self.config.factor.type``
        (``weighted_momentum_score`` for any type not listed).
    """
    if signal_code not in self.index_data:
        return None
    history = self.index_data[signal_code]
    visible = history.loc[history.index <= date]
    if len(visible) < self.n_days:
        return None
    window = visible['close'].values[-self.n_days:]
    if len(window) >= 4 and is_crash(window):
        return 0.0

    factor = self.config.factor.type
    if factor == FactorType.VOL_ADJUSTED_MOMENTUM:
        scorer = vol_adjusted_momentum_score
    elif factor == FactorType.SLOPE_R2:
        scorer = slope_r2_score
    elif factor == FactorType.STANDARDIZED_SLOPE:
        scorer = standardized_slope_score
    elif factor == FactorType.MOMENTUM:
        scorer = momentum_score
    else:
        scorer = weighted_momentum_score
    return scorer(window)
|
||
|
||
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
    """Momentum on the scale used for threshold comparisons.

    For every factor except VOL_ADJUSTED_MOMENTUM this is exactly
    ``_compute_momentum``. For VOL_ADJUSTED_MOMENTUM it uses
    ``weighted_momentum_score`` instead, so thresholds are compared in raw
    return space rather than volatility-scaled space. The None / crash (0.0)
    rules are the same as in ``_compute_momentum``.
    """
    if self.config.factor.type != FactorType.VOL_ADJUSTED_MOMENTUM:
        return self._compute_momentum(signal_code, date)

    if signal_code not in self.index_data:
        return None
    history = self.index_data[signal_code]
    visible = history.loc[history.index <= date]
    if len(visible) < self.n_days:
        return None
    window = visible['close'].values[-self.n_days:]
    if len(window) >= 4 and is_crash(window):
        return 0.0
    return weighted_momentum_score(window)
|
||
|
||
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
    """
    Generate rotation signals (group competition + dynamic threshold + BOND fill).

    Args:
        date: Signal date; only data on or before it is used.

    Returns:
        (holdings, factors, bond_momentum):
          holdings      -- selected signal codes, sorted. The bond code may
                           appear several times when it fills empty slots.
          factors       -- every factor score that could be computed.
          bond_momentum -- bond momentum used for the threshold, or None when
                           the dynamic threshold is off, there is no bond
                           code, or no factors could be computed.

    Side effect:
        Sets ``self._pending_weights`` to the weights for the returned
        holdings ({} whenever no holdings are returned). The caller (run)
        decides whether to lock them in.

    Logic (identical to V2):
        1. Each group: select Top 1 (non-BOND groups must have raw momentum
           >= bond_momentum * ratio).
        2. From all group winners: sort by score, select Top select_num.
        3. Fill remaining slots with the bond code (skipped if it is already
           among the winners).
    """
    # Reset first, so an early return never leaves a previous day's weights
    # for run() to lock in (and the attribute always exists).
    self._pending_weights = {}

    factors: Dict[str, float] = {}
    for code in self.signal_codes:
        score = self._compute_momentum(code, date)
        if score is not None:
            factors[code] = score
    if not factors:
        return [], {}, None

    # Bond threshold uses "raw" momentum: weighted_momentum_score under
    # VOL_ADJUSTED_MOMENTUM, the configured factor otherwise.
    bond_momentum = None
    if self.use_dynamic_threshold and self.bond_code:
        bond_momentum = self._compute_raw_momentum(self.bond_code, date)
        if bond_momentum is None:
            bond_momentum = self.fallback_value

    # Raw factors for threshold comparison
    raw_factors: Dict[str, float] = {}
    if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM:
        for code in self.signal_codes:
            score = self._compute_raw_momentum(code, date)
            if score is not None:
                raw_factors[code] = score
    else:
        raw_factors = factors

    groups = self.config.asset_pools.by_group
    selected_by_group: Dict[str, Tuple[str, float]] = {}
    for group_name, assets in groups.items():
        group_codes = [a.signal_source for a in assets.values()]
        group_factors = {c: factors[c] for c in group_codes if c in factors}
        group_raw = {c: raw_factors[c] for c in group_codes if c in raw_factors}
        if not group_factors:
            continue
        # Threshold comparison uses raw momentum (missing raw score counts as 0).
        if group_name != 'BOND' and bond_momentum is not None:
            thresh = bond_momentum * self.bond_ratio
            group_factors = {c: s for c, s in group_factors.items()
                             if group_raw.get(c, 0) >= thresh}
            if not group_factors:
                continue
        top_code = max(group_factors, key=group_factors.get)
        selected_by_group[group_name] = (top_code, group_factors[top_code])

    if not selected_by_group:
        return [], factors, bond_momentum

    candidates = list(selected_by_group.values())
    candidates.sort(key=lambda x: x[1], reverse=True)
    ranked_holdings = [code for code, _ in candidates[:self.select_num]]

    if len(ranked_holdings) < self.select_num and self.bond_code:
        if self.bond_code not in ranked_holdings:
            n_slots = self.select_num - len(ranked_holdings)
            ranked_holdings.extend([self.bond_code] * n_slots)

    # Compute position weights via the configured scheme.
    # These are *pending* weights; the caller (run) locks them in
    # only when an actual rebalance occurs.
    self._pending_weights = compute_position_weights(
        ranked_holdings, self.weight_type, scores=factors,
    )

    return sorted(ranked_holdings), factors, bond_momentum
|
||
|
||
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
|
||
"""Get ETF prices on a given date: {open, close, prev_close}
|
||
Note: Index data may lack 'open' column; use close as fallback.
|
||
"""
|
||
if trade_code not in self.etf_data:
|
||
return None
|
||
df = self.etf_data[trade_code]
|
||
mask = df.index <= date
|
||
recent = df.loc[mask]
|
||
if len(recent) < 2:
|
||
return None
|
||
today_row = recent.iloc[-1]
|
||
prev_row = recent.iloc[-2]
|
||
today_date = recent.index[-1]
|
||
if (date - today_date).days > 5:
|
||
return None
|
||
close_raw = today_row.get('close')
|
||
prev_close_raw = prev_row.get('close')
|
||
if close_raw is None or pd.isna(close_raw) or prev_close_raw is None or pd.isna(prev_close_raw):
|
||
return None
|
||
close = float(close_raw)
|
||
prev_close = float(prev_close_raw)
|
||
# Handle missing/invalid open (common for index data like 931862.CSI)
|
||
raw_open = today_row.get('open')
|
||
open_price = float(raw_open) if raw_open is not None and not pd.isna(raw_open) else close
|
||
if pd.isna(open_price) or open_price == 0:
|
||
open_price = close
|
||
return {
|
||
'open': open_price,
|
||
'close': close,
|
||
'prev_close': prev_close,
|
||
}
|
||
|
||
def _get_weight(self, code: str, n_unique: int) -> float:
|
||
"""Get position weight for a code. Falls back to equal weight if not available."""
|
||
if code in self._position_weights:
|
||
return self._position_weights[code]
|
||
return 1.0 / n_unique if n_unique > 0 else 0.0
|
||
|
||
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
|
||
"""
|
||
Compute daily return (T+1 execution) with configurable position weighting:
|
||
- Hold: close-to-close, weighted by today's position weight
|
||
- Sell: close-to-open (sold at open), weighted by today's position weight
|
||
- Buy: open-to-close (intraday), weighted by today's position weight
|
||
"""
|
||
if not old_holdings:
|
||
if not new_holdings:
|
||
return 0.0
|
||
ret = 0.0
|
||
seen = set()
|
||
n_unique = len(set(new_holdings))
|
||
for code in new_holdings:
|
||
if code in seen:
|
||
continue
|
||
seen.add(code)
|
||
tc = self.signal_to_trade.get(code, code)
|
||
p = self._get_etf_prices(tc, date)
|
||
w = self._get_weight(code, n_unique)
|
||
if p and p['open'] > 0:
|
||
ret += w * (p['close'] - p['open']) / p['open']
|
||
if is_rebalance:
|
||
ret -= self.trade_cost
|
||
return ret
|
||
|
||
old_set = set(old_holdings)
|
||
new_set = set(new_holdings)
|
||
n_unique = len(old_set)
|
||
daily_return = 0.0
|
||
|
||
for code in old_set:
|
||
tc = self.signal_to_trade.get(code, code)
|
||
p = self._get_etf_prices(tc, date)
|
||
if p is None or p['prev_close'] == 0:
|
||
continue
|
||
w = self._get_weight(code, n_unique)
|
||
if code in new_set:
|
||
r = (p['close'] - p['prev_close']) / p['prev_close']
|
||
else:
|
||
r = (p['open'] - p['prev_close']) / p['prev_close']
|
||
if not math.isnan(r):
|
||
daily_return += w * r
|
||
|
||
for code in new_set - old_set:
|
||
tc = self.signal_to_trade.get(code, code)
|
||
p = self._get_etf_prices(tc, date)
|
||
w = self._get_weight(code, len(new_set))
|
||
if p and p['open'] > 0 and not math.isnan(p['close']):
|
||
r = (p['close'] - p['open']) / p['open']
|
||
if not math.isnan(r):
|
||
daily_return += w * r
|
||
|
||
if is_rebalance:
|
||
daily_return -= self.trade_cost
|
||
return daily_return
|
||
|
||
def run(self) -> dict:
    """Main backtest loop.

    Preloads data and fetches the A-share trading calendar. Then, for each
    trading day:
      1. generate signals from data up to the previous calendar day;
      2. lock in the pending position weights on rebalance (or while flat);
      3. book the day's return and update NAV;
      4. update entry tracking and append a daily record.

    Returns:
        ``{'metrics': ..., 'daily_records': ...}``, or ``{}`` when the trading
        calendar cannot be fetched.

    Raises:
        RuntimeError: a currently held asset has no factor score on some day
            (treated as a data failure rather than a sell signal).
    """
    print("=" * 60)
    print(" Simple Rotation Strategy (Daily Iteration)")
    print("=" * 60)

    self._preload_data()

    print("\n[3/4] Fetching A-share trading calendar...")
    start_date = self.config.backtest.start_date
    end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
    self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date)

    if self.trading_calendar is None or len(self.trading_calendar) == 0:
        print(" x Calendar fetch failed")
        return {}

    print(f"\n[4/4] Backtesting ({len(self.trading_calendar)} trading days)...")
    current_holdings: List[str] = []
    nav = 1.0
    rebalance_count = 0
    entry_info: Dict[str, dict] = {}  # signal_code -> {entry_date, entry_price_etf, entry_price_idx}
    active_weights: Dict[str, float] = {}  # locked-in weights, updated only on rebalance

    for i, date in enumerate(self.trading_calendar):
        # Signal timing: 9:00 AM on day T. T's market has NOT opened yet, so
        # only T-1 close data is available for all markets. Use calendar day
        # -1 (not the A-share previous trading day) so that foreign markets
        # (US/HK) that traded during A-share holidays are captured.
        # Execution happens at 9:30 AM using T's ETF prices.
        signal_date = date - timedelta(days=1)

        new_holdings, factors, bond_momentum = self._generate_signals(signal_date)

        # Data integrity check: if any currently held asset is missing from
        # today's factors, abort immediately to prevent a false rebalance.
        if current_holdings:
            missing = [c for c in current_holdings if c not in factors]
            if missing:
                raise RuntimeError(
                    f"Data failure: held assets {missing} missing from factors on "
                    f"{date.strftime('%Y-%m-%d')}. Aborting to prevent false rebalance."
                )

        # NOTE: entering from a flat book is not a rebalance, so no trade cost
        # is charged for the initial purchase.
        is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0

        # Weights that sized yesterday's book.
        previous_weights = active_weights
        # Lock in position weights only on rebalance (or while flat).
        if is_rebalance or not current_holdings:
            # getattr: _generate_signals may not have set _pending_weights yet
            # if it returned early.
            active_weights = dict(getattr(self, '_pending_weights', {}))

        # Positions sold at today's open are absent from the new weights.
        # Without their old weights, _get_weight would fall back to equal
        # weight, and the day's exposure would not match the book actually held.
        # NOTE(review): held positions still use today's weight for the whole
        # close-to-close move (an approximation when their weight changes).
        self._position_weights = {**previous_weights, **active_weights}

        # Return uses T's ETF prices (open for buy/sell, close for hold)
        daily_return = self._calculate_daily_return(
            current_holdings, new_holdings, date, is_rebalance
        )
        # Outside the return calculation, expose only the current book.
        self._position_weights = active_weights
        nav *= (1 + daily_return)

        if is_rebalance:
            rebalance_count += 1

        # Update entry tracking for held assets
        added = set(new_holdings) - set(current_holdings)
        removed = set(current_holdings) - set(new_holdings)
        for code in added:
            trade_code = self.signal_to_trade.get(code, code)
            etf_prices = self._get_etf_prices(trade_code, date)
            # Entry price is the actual buy price: today's open
            entry_etf = etf_prices['open'] if etf_prices else None
            # Index close at T-1 (the data used to make the decision)
            entry_idx = self._get_index_close(code, signal_date)
            entry_info[code] = {
                'entry_date': date.strftime('%Y-%m-%d'),
                'entry_price_etf': entry_etf,
                'entry_price_idx': entry_idx,
            }
        for code in removed:
            entry_info.pop(code, None)

        # Bond threshold value for the detail record
        threshold_val = 0.0
        if self.use_dynamic_threshold and bond_momentum is not None:
            threshold_val = round(bond_momentum * self.bond_ratio, 6)

        self.daily_records.append({
            'date': date.strftime('%Y-%m-%d'),
            'nav': round(nav, 6),
            'daily_return': round(daily_return, 6),
            'is_rebalance': is_rebalance,
            'holdings': sorted(new_holdings),
            'added': sorted(added),
            'removed': sorted(removed),
            'factors': {k: round(v, 6) for k, v in factors.items()},
            'threshold': threshold_val,
            'position_weights': {k: round(v, 6) for k, v in active_weights.items()},
        })

        current_holdings = new_holdings

        if (i + 1) % 100 == 0 or i == len(self.trading_calendar) - 1:
            print(f" [{i+1}/{len(self.trading_calendar)}] "
                  f"NAV: {nav:.4f} | Rebal: {rebalance_count}")

    metrics = self._compute_metrics(rebalance_count)

    print("\n" + "=" * 60)
    print(" Backtest Complete")
    print("=" * 60)
    if self.daily_records:
        print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}")
    print(f" Trading days: {len(self.daily_records)}")
    print(f" Total return: {metrics['total_return']:.2%}")
    print(f" Annual return: {metrics['annual_return']:.2%}")
    print(f" Max drawdown: {metrics['max_drawdown']:.2%}")
    print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}")
    print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}")
    print(f" Win rate: {metrics['win_rate']:.2%}")
    print(f" Rebalances: {rebalance_count}")
    print("=" * 60)

    return {
        'metrics': metrics,
        'daily_records': self.daily_records,
    }
|
||
|
||
def _compute_metrics(self, rebalance_count: int) -> dict:
    """Summarise ``self.daily_records`` into headline performance statistics.

    Args:
        rebalance_count: Number of rebalances; copied into the result as-is.

    Returns:
        ``{}`` when there are no records. Otherwise a dict with:
        ``total_return`` (last NAV / first NAV - 1), ``annual_return``
        (compounded over a 252-trading-day year), ``max_drawdown`` (most
        negative NAV drawdown from the running peak), ``sharpe_ratio``
        (annualised mean/std of daily returns, 0 when std is not positive),
        ``calmar_ratio`` (annual return / |max drawdown|, 0 when drawdown is 0),
        ``win_rate`` (share of positive days among days with a non-zero
        return), ``n_days`` and ``rebalance_count``.
    """
    records = self.daily_records
    if not records:
        return {}

    frame = pd.DataFrame(records)
    frame['date'] = pd.to_datetime(frame['date'])
    nav_series = frame['nav']
    ret_series = frame['daily_return']
    day_count = len(frame)

    total = nav_series.iloc[-1] / nav_series.iloc[0] - 1
    annual = 0 if day_count <= 0 else (1 + total) ** (252 / day_count) - 1

    running_peak = nav_series.cummax()
    worst_dd = ((nav_series - running_peak) / running_peak).min()

    vol = ret_series.std()
    # `not vol > 0` also covers a NaN std (single record)
    sharpe = 0 if not vol > 0 else ret_series.mean() / vol * np.sqrt(252)
    calmar = 0 if worst_dd == 0 else annual / abs(worst_dd)

    moved = ret_series[ret_series != 0]
    win_rate = 0 if len(moved) == 0 else (moved > 0).sum() / len(moved)

    return {
        'total_return': total,
        'annual_return': annual,
        'max_drawdown': worst_dd,
        'sharpe_ratio': sharpe,
        'calmar_ratio': calmar,
        'win_rate': win_rate,
        'n_days': day_count,
        'rebalance_count': rebalance_count,
    }
|
||
|
||
# ============================================================
|
||
# Detail JSON helpers (V2-compatible)
|
||
# ============================================================
|
||
|
||
def _build_meta_codes(self) -> dict:
    """Map every asset's signal code to its display metadata.

    Returns:
        ``{signal_source: {'name': ..., 'etf': ..., 'market': ...}}`` built
        from ``self.config.asset_pools.assets``. ``name`` falls back to the
        asset's key in that mapping when the asset object has no ``name``
        attribute; ``etf`` is ``trade_source`` and ``market`` is ``group``.
        If two assets share a signal code, the later one wins.
    """
    pool = self.config.asset_pools.assets
    return {
        asset.signal_source: {
            'name': getattr(asset, 'name', key),
            'etf': asset.trade_source,
            'market': asset.group,
        }
        for key, asset in pool.items()
    }
|
||
|
||
def _get_index_close(self, code: str, date: pd.Timestamp) -> Optional[float]:
    """Return the index close on ``date`` or on the nearest earlier row.

    Args:
        code: Signal (index) code, used as key into ``self.index_data``.
        date: Inclusive cut-off date.

    Returns:
        ``close`` of the last row whose index is <= ``date``, as a float
        (a NaN close comes back as ``nan``), or None when no frame is loaded
        for ``code`` or no row lies on/before ``date``.
    """
    df = self.index_data.get(code)
    if df is None or len(df) == 0:
        return None
    index = df.index
    if index.is_monotonic_increasing:
        # Binary search for the count of rows <= date: O(log n) instead of a
        # full boolean mask on every call. This is hit once per code per day
        # when building the detail JSON.
        pos = index.searchsorted(date, side='right')
        if pos == 0:
            return None
        return float(df['close'].iloc[pos - 1])
    # Unsorted index: keep the original mask semantics (last row, in frame
    # order, among rows <= date).
    mask = index <= date
    if not mask.any():
        return None
    return float(df.loc[mask].iloc[-1]['close'])
|
||
|
||
def _get_etf_close(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
    """Return the ETF close on ``date`` or on the nearest earlier row.

    Args:
        trade_code: ETF code, used as key into ``self.etf_data``.
        date: Inclusive cut-off date.

    Returns:
        ``close`` of the last row whose index is <= ``date``, as a float
        (a NaN close comes back as ``nan``), or None when no frame is loaded
        for ``trade_code`` or no row lies on/before ``date``.
    """
    df = self.etf_data.get(trade_code)
    if df is None or len(df) == 0:
        return None
    index = df.index
    if index.is_monotonic_increasing:
        # O(log n) binary search instead of building a full boolean mask per
        # call (called per code per day by the detail export and report).
        pos = index.searchsorted(date, side='right')
        if pos == 0:
            return None
        return float(df['close'].iloc[pos - 1])
    # Unsorted index: fall back to the original mask semantics.
    mask = index <= date
    if not mask.any():
        return None
    return float(df.loc[mask].iloc[-1]['close'])
|
||
|
||
def _get_daily_returns(self, code: str, date: pd.Timestamp) -> Tuple[Optional[float], Optional[float]]:
    """Return ``(index_return, etf_return_ctc)`` for ``code`` as of ``date``.

    Each value is the close-to-close change between the last two rows on or
    before ``date`` of the respective price frame (index frame keyed by
    ``code``; ETF frame keyed by the mapped trade code), rounded to 6
    decimals. A value is None when its frame is missing, has fewer than two
    rows on/before ``date``, the earlier close is not positive (or is NaN),
    or the later close is NaN.
    """
    def _close_to_close(df: Optional[pd.DataFrame]) -> Optional[float]:
        # Shared by the index and ETF legs (previously duplicated inline).
        if df is None:
            return None
        if df.index.is_monotonic_increasing:
            # Binary search instead of a full boolean mask on every call.
            end = df.index.searchsorted(date, side='right')
            closes = df['close'].iloc[max(end - 2, 0):end]
        else:
            closes = df.loc[df.index <= date, 'close']
        if len(closes) < 2:
            return None
        c_prev = float(closes.iloc[-2])
        c_today = float(closes.iloc[-1])
        if c_prev > 0 and not pd.isna(c_today):
            return round((c_today - c_prev) / c_prev, 6)
        return None

    trade_code = self.signal_to_trade.get(code, code)
    idx_ret = _close_to_close(self.index_data.get(code))
    etf_ret = _close_to_close(self.etf_data.get(trade_code))
    return idx_ret, etf_ret
|
||
|
||
def _compute_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
    """Look up the cached premium ``(ETF_price - NAV) / NAV`` for exactly ``date``.

    Returns:
        The premium rounded to 6 decimals, or None when ``trade_code`` is in
        the BOND group, has no cached premium data, or has no entry keyed by
        ``date`` formatted as ``YYYY-MM-DD``.
    """
    is_bond = self.trade_code_to_group.get(trade_code, '') == 'BOND'
    by_date = None if is_bond else self.data_cache.premium_data.get(trade_code)
    if not by_date:
        return None
    raw = by_date.get(date.strftime('%Y-%m-%d'))
    return None if raw is None else round(float(raw), 6)
|
||
|
||
def _get_latest_premium(self, trade_code: str, date: pd.Timestamp,
                        lookback_days: int = 5) -> Optional[float]:
    """Return the most recent cached premium on or before ``date``.

    Tries ``date`` first, then each earlier calendar day up to
    ``lookback_days`` days back (default 5, the previous hard-coded window).

    Args:
        trade_code: ETF code used as key into ``self.data_cache.premium_data``.
        date: Latest date to consider.
        lookback_days: Maximum number of calendar days to step back; negative
            values are treated as 0.

    Returns:
        The first non-missing premium found, rounded to 6 decimals, or None
        for BOND-group codes, codes with no cached premium data, or when
        nothing usable exists in the window.
    """
    if self.trade_code_to_group.get(trade_code, '') == 'BOND':
        return None
    premium_dict = self.data_cache.premium_data.get(trade_code)
    if not premium_dict:
        return None
    for offset in range(max(lookback_days, 0) + 1):
        date_str = (date - timedelta(days=offset)).strftime('%Y-%m-%d')
        val = premium_dict.get(date_str)
        # Treat NaN placeholders as missing so they neither hide an older
        # valid value nor render as "nan%" in the report.
        if val is not None and not pd.isna(val):
            return round(float(val), 6)
    return None
|
||
|
||
def _build_day_assets(self, record: dict, date: pd.Timestamp,
                      entry_info: Dict[str, dict]) -> dict:
    """Assemble the V2-compatible per-asset detail dict for one day.

    Args:
        record: One entry of ``self.daily_records``; reads ``factors``,
            ``holdings``, ``threshold`` and ``position_weights``.
        date: The record's date; prices, returns and premium are looked up
            as of this date.
        entry_info: ``{signal_code: {'entry_date', 'entry_price_etf',
            'entry_price_idx'}}`` for currently open positions.

    Returns:
        ``{signal_code: {...}}`` for every code in ``self.signal_codes``.
        ``rank`` is 1 for the highest momentum (None if the code has no
        factor). When ``threshold`` is falsy (e.g. 0.0) ``above_threshold``
        is True for every code that has a momentum value. ``holding_days``
        counts calendar days since entry, at least 1 while held; ``weight``
        is None for codes not held.
    """
    def _opt_float(value):
        # Plain float for JSON output; None stays None.
        return None if value is None else float(value)

    def _change_since(current, base):
        # Relative change from an entry price; None if either side is unusable.
        if not (base and base > 0):
            return None
        if current is None or pd.isna(current):
            return None
        return round((current - base) / base, 6)

    factors = record.get('factors', {})
    held_codes = set(record['holdings'])
    threshold = record.get('threshold', 0.0)
    weights = record.get('position_weights', {})

    by_momentum = sorted(factors, key=factors.get, reverse=True)
    rank_of = {code: pos for pos, code in enumerate(by_momentum, start=1)}

    out = {}
    for code in self.signal_codes:
        momentum = factors.get(code)
        if momentum is None:
            above = False
        elif threshold:
            above = momentum >= threshold
        else:
            above = True
        is_held = code in held_codes

        trade_code = self.signal_to_trade.get(code, code)
        idx_close = self._get_index_close(code, date)
        etf_close = self._get_etf_close(trade_code, date)
        idx_ret, etf_ret = self._get_daily_returns(code, date)
        premium = self._compute_premium(trade_code, date)

        entry = entry_info.get(code) if is_held else None
        if entry:
            entry_date = entry['entry_date']
            entry_px_etf = entry['entry_price_etf']
            entry_px_idx = entry['entry_price_idx']
            days_held = max(int((date - pd.Timestamp(entry['entry_date'])).days), 1)
            cum_etf = _change_since(etf_close, entry.get('entry_price_etf'))
            cum_idx = _change_since(idx_close, entry.get('entry_price_idx'))
        else:
            entry_date = entry_px_etf = entry_px_idx = None
            days_held = 0
            cum_etf = cum_idx = None

        out[code] = {
            'momentum': _opt_float(momentum),
            'rank': rank_of.get(code),
            'threshold': float(threshold),
            'above_threshold': bool(above),
            'index_close': _opt_float(idx_close),
            'etf_close': _opt_float(etf_close),
            'index_return': _opt_float(idx_ret),
            'etf_return_ctc': _opt_float(etf_ret),
            'premium': _opt_float(premium),
            'is_held': bool(is_held),
            'entry_date': entry_date,
            'entry_price_etf': _opt_float(entry_px_etf),
            'entry_price_idx': _opt_float(entry_px_idx),
            'holding_days': days_held,
            'cum_return_etf': _opt_float(cum_etf),
            'cum_return_idx': _opt_float(cum_idx),
            'weight': float(weights.get(code, 0.0)) if is_held else None,
        }
    return out
|
||
|
||
# ============================================================
|
||
# Export
|
||
# ============================================================
|
||
|
||
def export_results(self, output_dir: str = None, detail: bool = True):
    """Write backtest outputs (CSV + JSON) into ``output_dir``.

    Always writes ``simple_rotation_nav.csv`` (date, nav, daily_return),
    ``simple_rotation_signals.csv`` (date, holdings, is_rebalance, added,
    removed) and ``simple_rotation_metrics.json``. When ``detail`` is True it
    also writes the V2-compatible ``simple_rotation_detail.json`` used by
    backtest_viewer.

    Args:
        output_dir: Destination directory; defaults to ``results/`` next to
            this file. Created if it does not exist.
        detail: Build the large, slow detail JSON. Pass False for daily runs.
    """
    records = self.daily_records
    if not records:
        print(" x No results to export")
        return

    if output_dir is None:
        target = Path(__file__).parent / 'results'
    else:
        target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)

    frame = pd.DataFrame(records)

    # NAV curve
    nav_csv = target / 'simple_rotation_nav.csv'
    frame[['date', 'nav', 'daily_return']].to_csv(nav_csv, index=False)
    print(f" + NAV: {nav_csv}")

    # Signals
    signals_csv = target / 'simple_rotation_signals.csv'
    frame[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(signals_csv, index=False)
    print(f" + Signals: {signals_csv}")

    if detail:
        detail_json = target / 'simple_rotation_detail.json'
        days_out = []
        # Open positions rebuilt day by day, mirroring the entry logic in run()
        open_positions: Dict[str, dict] = {}
        prev_held = set()

        for rec in records:
            day = pd.Timestamp(rec['date'])
            # Signal date is the previous *calendar* day, not the previous
            # A-share trading day.
            signal_day = day - timedelta(days=1)
            holdings = rec['holdings']
            held = set(holdings)

            for code in held - prev_held:
                trade_code = self.signal_to_trade.get(code, code)
                etf_prices = self._get_etf_prices(trade_code, day)
                open_positions[code] = {
                    'entry_date': day.strftime('%Y-%m-%d'),
                    # Bought at T's open
                    'entry_price_etf': etf_prices['open'] if etf_prices else None,
                    # Index close at T-1: the data the signal decision used
                    'entry_price_idx': self._get_index_close(code, signal_day),
                }
            for code in prev_held - held:
                open_positions.pop(code, None)

            signals = dict.fromkeys(holdings, 1)
            assets = self._build_day_assets(rec, day, open_positions)
            days_out.append({
                'date': rec['date'],
                'nav': rec['nav'],
                'daily_return': rec['daily_return'],
                'is_rebalance': rec['is_rebalance'],
                'signals': signals,
                'holdings': holdings,
                'added': rec['added'],
                'removed': rec['removed'],
                'assets': assets,
            })
            prev_held = held

        detail_data = {
            'meta': {
                'mode': 'Simple: Daily Iteration',
                'start_date': self.config.backtest.start_date,
                'end_date': records[-1]['date'] if records else self.config.backtest.end_date or 'now',
                'total_days': len(records),
                'select_num': self.select_num,
                'n_days': self.n_days,
                'trade_cost': self.trade_cost,
                'bond_threshold': {
                    'enabled': self.use_dynamic_threshold,
                    'bond_code': self.bond_code,
                    'ratio': self.bond_ratio,
                },
                'codes': self._build_meta_codes(),
            },
            'days': days_out,
        }
        _sanitize_json(detail_data)
        with open(detail_json, 'w', encoding='utf-8') as fh:
            json.dump(detail_data, fh, ensure_ascii=False, indent=2)
        print(f" + Detail: {detail_json} ({len(days_out)} days)")

    # Metrics JSON
    rebalances = sum(1 for r in records if r['is_rebalance'])
    metrics = self._compute_metrics(rebalances)
    metrics_json = target / 'simple_rotation_metrics.json'
    _sanitize_json(metrics)
    with open(metrics_json, 'w', encoding='utf-8') as fh:
        json.dump(metrics, fh, ensure_ascii=False, indent=2)
    print(f" + Metrics: {metrics_json}")
|
||
|
||
# ============================================================
|
||
# Report Generation (PNG chart with tables)
|
||
# ============================================================
|
||
|
||
def generate_report(self, output_dir: Optional[str] = None):
    """Render the strategy report as ``simple_rotation_report.png``.

    Six stacked panels are drawn from ``self.daily_records``:
      0. Ranking table for every signal code on the last record: current
         holdings (调入 new on a rebalance day / 维持 kept), exits on a
         rebalance day (调出), and everything else (未入选).
      1. Strategy vs. benchmark performance table.
      2. Year x month return matrix (%) with compounded yearly totals.
      3. Log-scale NAV curves (strategy, benchmark, each signal index).
      4. Strategy drawdown.
      5. Daily holdings distribution.

    Args:
        output_dir: Directory for the PNG; defaults to ``results/`` next to
            this file and is created if missing.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend, no display required
    import matplotlib.pyplot as plt

    if not self.daily_records:
        print(" x No results for report")
        return

    if output_dir is None:
        output_dir = Path(__file__).parent / 'results'
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # ==================== Time series ====================
    df = pd.DataFrame(self.daily_records)
    df['date'] = pd.to_datetime(df['date'])
    df = df.set_index('date')
    strategy_nav = df['nav']
    n_days = len(df)

    # Benchmark NAV, rebased so its first valid value is 1
    benchmark_nav = None
    if self.benchmark_data is not None:
        bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill')
        if not bm_close.isna().all():
            benchmark_nav = (1 + bm_close.pct_change(fill_method=None)).cumprod()
            bm_valid = benchmark_nav.dropna()
            benchmark_nav = benchmark_nav / (bm_valid.iloc[0] if len(bm_valid) > 0 else 1)

    # Individual signal-index NAVs, rebased the same way
    asset_navs = {}
    for code in self.signal_codes:
        if code not in self.index_data:
            continue
        price = self.index_data[code]['close'].reindex(df.index, method='ffill')
        if price.isna().all():
            continue
        nav_s = (1 + price.pct_change(fill_method=None)).cumprod()
        nav_valid = nav_s.dropna()
        asset_navs[code] = nav_s / (nav_valid.iloc[0] if len(nav_valid) > 0 else 1)

    # Strategy metrics: reuse _compute_metrics so the chart always agrees
    # with the exported metrics JSON.
    metrics = self._compute_metrics(
        sum(1 for r in self.daily_records if r.get('is_rebalance')))
    s_total_return = metrics['total_return']
    s_annual = metrics['annual_return']
    s_sharpe = metrics['sharpe_ratio']
    s_dd = metrics['max_drawdown']
    s_calmar = metrics['calmar_ratio']
    s_win_rate = metrics['win_rate']

    # Benchmark metrics (all 0 when no benchmark is available)
    b_total_return = b_annual = b_sharpe = b_dd = 0
    if benchmark_nav is not None:
        bm_ret = benchmark_nav.pct_change(fill_method=None)
        b_total_return = benchmark_nav.iloc[-1] - 1
        b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
        b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0
        b_peak = benchmark_nav.cummax()
        b_dd = ((benchmark_nav - b_peak) / b_peak).min()

    # ==================== Latest-day position tables ====================
    last_rec = self.daily_records[-1]
    last_date = pd.Timestamp(last_rec['date'])
    holdings = last_rec['holdings']
    factors = last_rec.get('factors', {})
    position_weights = last_rec.get('position_weights', {})
    is_rebalance_day = last_rec.get('is_rebalance', False)
    prev_rec = self.daily_records[-2] if len(self.daily_records) > 1 else None
    prev_holdings = set(prev_rec['holdings']) if prev_rec is not None else set()

    # The market counts as opened for last_date once any held ETF has a bar
    # dated on/after it.
    market_opened = False
    for code in holdings:
        trade_code = self.signal_to_trade.get(code, code)
        if trade_code in self.etf_data:
            etf_df = self.etf_data[trade_code]
            if len(etf_df) > 0 and etf_df.index[-1] >= last_date:
                market_opened = True
                break

    code_config = self._build_meta_codes()

    # Current holdings, highest momentum first
    positions_info = []
    for code in sorted(holdings, key=lambda c: factors.get(c, 0) or 0, reverse=True):
        cfg = code_config.get(code, {})
        trade_code = self.signal_to_trade.get(code, code)
        idx_close = self._get_index_close(code, last_date)
        etf_close = self._get_etf_close(trade_code, last_date)
        premium = self._get_latest_premium(trade_code, last_date)
        is_new_entry = is_rebalance_day and code not in prev_holdings
        action = '调入' if is_new_entry else '维持'

        entry_date = None
        entry_price = None
        holding_days = 0
        pnl = None
        # A new entry on a day the market hasn't opened yet has no execution data.
        if not (is_new_entry and not market_opened):
            # Walk back over the contiguous holding run; the earliest day with
            # ETF prices is the entry (bought at that day's open).
            for rec in reversed(self.daily_records):
                if code not in rec['holdings']:
                    break
                rec_date = pd.Timestamp(rec['date'])
                p = self._get_etf_prices(trade_code, rec_date)
                if p is not None:
                    entry_date = rec_date
                    entry_price = p['open']
            if entry_date is not None:
                holding_days = (last_date - entry_date).days
                # Maintained positions: mark to the latest available close
                if entry_price and entry_price > 0 and etf_close and etf_close > 0:
                    pnl = etf_close / entry_price - 1

        positions_info.append({
            'name': cfg.get('name', code), 'code': code, 'etf': cfg.get('etf', '—'),
            'weight': position_weights.get(code, 1.0 / len(holdings)), 'score': factors.get(code),
            'idx_close': idx_close, 'etf_close': etf_close,
            'premium': premium, 'action': action,
            'entry_date': entry_date, 'entry_price': entry_price,
            'holding_days': holding_days, 'pnl': pnl,
            'exit_date': None, 'exit_price': None,
        })

    # Exits: only meaningful when the last day is a rebalance day
    exit_positions = []
    if is_rebalance_day and prev_rec is not None:
        prev_weights = prev_rec.get('position_weights', {})
        history = self.daily_records[:-1]
        for code in sorted(prev_holdings - set(holdings)):
            cfg = code_config.get(code, {})
            trade_code = self.signal_to_trade.get(code, code)
            idx_close = self._get_index_close(code, last_date)
            etf_close = self._get_etf_close(trade_code, last_date)
            premium = self._get_latest_premium(trade_code, last_date)

            # Recover the holding run from history (excluding today)
            exit_entry_date = None
            exit_entry_price = None
            exit_holding_days = 0
            exit_pnl = None
            for rec in reversed(history):
                if code not in rec['holdings']:
                    break
                exit_entry_date = pd.Timestamp(rec['date'])
                p = self._get_etf_prices(trade_code, exit_entry_date)
                exit_entry_price = p['open'] if p else None
            if exit_entry_date is not None:
                exit_holding_days = (last_date - exit_entry_date).days

            # Sold at today's open once the market has opened, else prev close
            exit_prices = self._get_etf_prices(trade_code, last_date)
            if not exit_prices:
                sell_price = None
            elif market_opened:
                sell_price = exit_prices['open']
            else:
                sell_price = exit_prices['prev_close']
            if exit_entry_price and exit_entry_price > 0 and sell_price and sell_price > 0:
                exit_pnl = sell_price / exit_entry_price - 1

            exit_positions.append({
                'name': cfg.get('name', code), 'code': code, 'etf': cfg.get('etf', '—'),
                'weight': prev_weights.get(code, 0.0), 'score': factors.get(code),
                'idx_close': idx_close, 'etf_close': etf_close,
                'premium': premium, 'action': '调出',
                'entry_date': exit_entry_date, 'entry_price': exit_entry_price,
                'holding_days': exit_holding_days, 'pnl': exit_pnl,
                'exit_date': last_date, 'exit_price': sell_price,
            })

    # Unselected (未入选): every other signal code, by momentum descending
    held_or_exited = set(holdings) | {p['code'] for p in exit_positions}
    unselected_codes = [c for c in self.signal_codes if c not in held_or_exited]
    unselected_codes.sort(key=lambda c: factors.get(c, -999), reverse=True)
    unselected_positions = []
    for code in unselected_codes:
        cfg = code_config.get(code, {})
        trade_code = self.signal_to_trade.get(code, code)
        unselected_positions.append({
            'name': cfg.get('name', code), 'code': code, 'etf': cfg.get('etf', '—'),
            'weight': 0, 'score': factors.get(code),
            'idx_close': self._get_index_close(code, last_date),
            'etf_close': self._get_etf_close(trade_code, last_date),
            'premium': self._get_latest_premium(trade_code, last_date), 'action': '未入选',
            'entry_date': None, 'entry_price': None,
            'holding_days': 0, 'pnl': None,
            'exit_date': None, 'exit_price': None,
        })

    # ==================== Plot ====================
    plt.rcParams["font.sans-serif"] = ["Arial Unicode MS", "WenQuanYi Zen Hei", "DejaVu Sans"]
    plt.rcParams["axes.unicode_minus"] = False

    n_rows = len(positions_info) + len(exit_positions) + len(unselected_positions)
    signal_h = max(1.5, 0.5 + n_rows * 0.35)
    fig = plt.figure(figsize=(16, 10 + signal_h + 1.2 + 1.5 + 8))
    gs = fig.add_gridspec(6, 1, height_ratios=[signal_h, 1.2, 1.5, 3, 1, 1.2], hspace=0.35)

    # Panel 0: full asset ranking table
    ax0 = fig.add_subplot(gs[0])
    ax0.axis("off")
    threshold_val = last_rec.get('threshold', 0.0)
    ax0.set_title(f"全标的动量排名 (信号日期: {last_date.strftime('%Y-%m-%d')}，阈值: {threshold_val:.4f})",
                  fontsize=14, fontweight="bold", loc="left", pad=15)

    ranked_all = sorted(factors.keys(), key=lambda c: factors[c], reverse=True)
    rank_map = {c: i + 1 for i, c in enumerate(ranked_all)}

    rank_cols = ["排名", "标的名称", "市场", "指数代码", "ETF代码", "仓位", "得分",
                 "指数最新价", "ETF收盘价", "溢价率", "状态",
                 "进场日期", "持有天数", "盈亏", "退场日期", "退场价格"]
    rank_rows = []
    row_actions = []  # per-row action, used for coloring
    # Ordered: 调入/维持 -> 调出 -> 未入选
    for p in positions_info + exit_positions + unselected_positions:
        rank = rank_map.get(p['code'], '—')
        score_s = f"{p['score']:.4f}" if p['score'] is not None else "—"
        idx_s = f"{p['idx_close']:.2f}" if p['idx_close'] is not None else "—"
        etf_s = f"{p['etf_close']:.3f}" if p['etf_close'] is not None else "—"
        prem_s = f"{p['premium']:+.2%}" if p['premium'] is not None else "—"
        entry_date_s = p['entry_date'].strftime('%Y-%m-%d') if p['entry_date'] else "—"
        days_s = str(p['holding_days']) if p['holding_days'] > 0 else "—"
        pnl_s = f"{p['pnl']:+.2%}" if p['pnl'] is not None else "—"
        weight_s = f"{p['weight']:.0%}" if p['weight'] > 0 else "—"
        market = code_config.get(p['code'], {}).get('market', '—')
        exit_date_s = p['exit_date'].strftime('%Y-%m-%d') if p.get('exit_date') else "—"
        exit_price_s = f"{p['exit_price']:.3f}" if p.get('exit_price') else "—"
        rank_rows.append([
            rank, p['name'], market, p['code'], p['etf'], weight_s,
            score_s, idx_s, etf_s, prem_s, p['action'],
            entry_date_s, days_s, pnl_s, exit_date_s, exit_price_s
        ])
        row_actions.append(p['action'])

    if rank_rows:
        tbl = ax0.table(cellText=rank_rows, colLabels=rank_cols, loc="center",
                        cellLoc="center", bbox=[0, 0, 1, 1])
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(8)
        tbl.scale(1, 1.6)
        for j in range(len(rank_cols)):
            tbl[0, j].set_facecolor("#2C3E50")
            tbl[0, j].set_text_props(color="white", fontweight="bold")
        action_colors = {
            "调入": "#d4edda",  # light green
            "调出": "#f8d7da",  # light red
            "维持": "#fff3cd",  # light yellow
            "未入选": "#f0f0f0",  # light grey
        }
        for i, action in enumerate(row_actions):
            color = action_colors.get(action, "#ffffff")
            for j in range(len(rank_cols)):
                tbl[i + 1, j].set_facecolor(color)

    # Panel 1: metrics table
    ax1 = fig.add_subplot(gs[1])
    ax1.axis("off")
    ax1.set_title("策略绩效对比", fontsize=14, fontweight="bold", loc="left", pad=10)

    start_s = df.index.min().strftime('%Y-%m-%d')
    end_s = df.index.max().strftime('%Y-%m-%d')
    perf_cols = ["策略", "开始时间", "结束时间", "累计收益", "年化收益", "最大回撤", "夏普比率", "Calmar比率", "日胜率"]
    strat_row = ["轮动策略", start_s, end_s,
                 f"{s_total_return:.2%}", f"{s_annual:.2%}", f"{s_dd:.2%}",
                 f"{s_sharpe:.2f}", f"{s_calmar:.2f}", f"{s_win_rate:.2%}"]
    bench_row = [f"基准({self.benchmark_name})", start_s, end_s,
                 f"{b_total_return:.2%}", f"{b_annual:.2%}", f"{b_dd:.2%}",
                 f"{b_sharpe:.2f}", "—", "—"]
    ptbl = ax1.table(cellText=[strat_row, bench_row], colLabels=perf_cols,
                     loc="center", cellLoc="center", bbox=[0, 0, 1, 1])
    ptbl.auto_set_font_size(False)
    ptbl.set_fontsize(10)
    ptbl.scale(1, 1.8)
    for j in range(len(perf_cols)):
        ptbl[0, j].set_facecolor("#2C3E50")
        ptbl[0, j].set_text_props(color="white", fontweight="bold")
        ptbl[1, j].set_facecolor("#d4edda")
        ptbl[2, j].set_facecolor("#cce5ff")

    # Panel 2: monthly returns matrix (year x month)
    ax2 = fig.add_subplot(gs[2])
    ax2.axis('off')
    ax2.set_title('月度收益矩阵 (%)', fontsize=14, fontweight='bold', loc='left', pad=10)
    monthly_nav = strategy_nav.resample('ME').last().dropna()
    # Month return = month-end NAV / previous month-end NAV - 1. The first
    # month is measured from the first record's NAV; pct_change() left it
    # NaN, so the first backtest month used to be missing from the matrix.
    # This base also makes the compounded months equal the total return.
    prev_month_nav = monthly_nav.shift(1)
    if len(prev_month_nav) > 0:
        prev_month_nav.iloc[0] = strategy_nav.iloc[0]
    monthly_ret_pct = ((monthly_nav / prev_month_nav - 1).dropna() * 100).round(2)
    month_table = {(ts.year, ts.month): float(v) for ts, v in monthly_ret_pct.items()}
    years = sorted({y for y, _ in month_table})
    month_cols = [f'{m}月' for m in range(1, 13)] + ['年度']

    month_rows = []
    for yr in years:
        row = []
        yr_total = 0.0
        yr_count = 0
        for m in range(1, 13):
            v = month_table.get((yr, m))
            row.append(v)
            if v is not None:
                yr_total = (1 + yr_total / 100) * (1 + v / 100) * 100 - 100
                yr_count += 1
        row.append(round(yr_total, 2) if yr_count > 0 else None)
        month_rows.append(row)

    if month_rows:
        cell_text = [[f'{v:+.1f}' if v is not None else '—' for v in row] for row in month_rows]
        mtbl = ax2.table(cellText=cell_text, rowLabels=[str(y) for y in years],
                         colLabels=month_cols, loc='center', cellLoc='center', bbox=[0, 0, 1, 1])
        mtbl.auto_set_font_size(False)
        mtbl.set_fontsize(9)
        mtbl.scale(1, 1.6)
        for j in range(len(month_cols)):
            mtbl[0, j].set_facecolor('#2C3E50')
            mtbl[0, j].set_text_props(color='white', fontweight='bold')
        # sqrt intensity mapping so small returns remain visible
        max_abs = max(abs(v) for row in month_rows for v in row if v is not None) or 1
        for i, row in enumerate(month_rows):
            for j, v in enumerate(row):
                if v is not None:
                    intensity = min((abs(v) / max_abs) ** 0.5, 1.0)
                    if v >= 0:
                        r, g, b = 1.0, 1.0 - intensity * 0.55, 1.0 - intensity * 0.55
                    else:
                        r, g, b = 1.0 - intensity * 0.55, 1.0, 1.0 - intensity * 0.55
                    mtbl[i + 1, j].set_facecolor((r, g, b))
                else:
                    mtbl[i + 1, j].set_facecolor('#f5f5f5')
            # Annual column: stronger coloring
            ann_v = row[-1]
            if ann_v is not None:
                intensity = min((abs(ann_v) / max_abs) ** 0.5, 1.0)
                if ann_v >= 0:
                    r, g, b = 1.0, 1.0 - intensity * 0.7, 1.0 - intensity * 0.7
                else:
                    r, g, b = 1.0 - intensity * 0.7, 1.0, 1.0 - intensity * 0.7
                mtbl[i + 1, len(month_cols) - 1].set_facecolor((r, g, b))

    # Panel 3: NAV curves
    ax3 = fig.add_subplot(gs[3])
    ax3.plot(strategy_nav.index, strategy_nav.values,
             label="轮动策略", linewidth=2, color="#E74C3C")
    if benchmark_nav is not None:
        ax3.plot(benchmark_nav.index, benchmark_nav.values,
                 label=self.benchmark_name, linewidth=1.5, color="#3498DB", alpha=0.8)
    colors = plt.cm.tab20.colors
    for i, code in enumerate(self.signal_codes):
        if code in asset_navs:
            lbl = code_config.get(code, {}).get('name', code) if i < 10 else None
            ax3.plot(asset_navs[code].index, asset_navs[code].values,
                     label=lbl, linewidth=0.8, alpha=0.4,
                     color=colors[i % len(colors)])
    ax3.set_title("ETF轮动策略 - 净值曲线", fontsize=16, fontweight="bold")
    ax3.set_ylabel("净值")
    ax3.legend(loc="upper left", fontsize=8, ncol=2)
    ax3.grid(True, alpha=0.3)
    ax3.set_yscale("log")

    # Panel 4: drawdown
    ax4 = fig.add_subplot(gs[4])
    s_peak = strategy_nav.cummax()
    drawdown = (strategy_nav - s_peak) / s_peak
    ax4.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C")
    ax4.set_title("策略回撤", fontsize=12)
    ax4.set_ylabel("回撤")
    ax4.grid(True, alpha=0.3)

    # Panel 5: holdings distribution
    ax5 = fig.add_subplot(gs[5])
    holdings_series = df['holdings']
    for i, code in enumerate(self.signal_codes):
        mask = holdings_series.apply(lambda h, c=code: c in h)
        if mask.any():
            ax5.fill_between(mask.index, i, i + 0.8,
                             where=mask, alpha=0.7,
                             color=colors[i % len(colors)],
                             label=code_config.get(code, {}).get('name', code))
    ylabels = [code_config.get(c, {}).get('name', c) for c in self.signal_codes]
    ax5.set_title("每日持仓分布", fontsize=12)
    ax5.set_yticks(range(len(ylabels)))
    ax5.set_yticklabels(ylabels, fontsize=7)
    ax5.grid(True, alpha=0.3)

    chart_path = output_dir / 'simple_rotation_report.png'
    fig.savefig(str(chart_path), dpi=150, bbox_inches="tight")
    plt.close(fig)  # close this figure explicitly, not whichever is "current"
    print(f" + Report: {chart_path}")
|
||
|
||
|
||
# ============================================================
|
||
# Entry point
|
||
# ============================================================
|
||
|
||
if __name__ == "__main__":
    import argparse

    # Default API endpoint unless the environment already provides one
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')

    cli = argparse.ArgumentParser(description='Simple Rotation Strategy Backtest')
    cli.add_argument('--no-detail', action='store_true',
                     help='Skip detail JSON export (faster, for daily runs)')
    cli.add_argument('--no-report', action='store_true',
                     help='Skip report PNG generation')
    opts = cli.parse_args()

    strategy = SimpleRotationStrategy()
    if strategy.run():
        strategy.export_results(detail=not opts.no_detail)
        if not opts.no_report:
            strategy.generate_report()
|