Files
etf/rotation/simple_rotation.py
aszerW 8b7bcf206a feat(weight): 实现 Kelly 仓位权重模式
- config_loader.py: WeightType 枚举新增 KELLY
- simple_rotation.py: compute_position_weights 新增 kelly 分支
  - 公式: w_i = max(score_i, 0) / sum(max(score_j, 0))
  - 负分自动排除 (Kelly: 不下注负期望)
  - 全负分时 fallback 到等权
- _generate_signals 传递 scores 给 kelly 模式
- config_simple.yaml: weight 改为 kelly
- 新增策略总结文档: kelly_weight.md

回测对比 (2020-2026):
- equal: 年化 19.88%, 夏普 1.13, 回撤 -14.65%
- rank:  年化 22.90%, 夏普 1.12, 回撤 -16.27%
- kelly: 年化 30.13%, 夏普 1.15, 回撤 -20.44%
2026-06-08 23:05:26 +08:00

1628 lines
71 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Simple Rotation Strategy (Daily Iteration)
From A-share trading calendar, iterate daily:
1. Get signal source last N days -> compute momentum
2. Get bond momentum (dynamic threshold)
3. Group selection -> generate holdings
4. Compare with yesterday -> compute T+1 return
5. Update NAV
"""
import os
import sys
import math
import json
import time
import requests
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
# Put the parent of this file's directory (the project root) at the front of
# sys.path so the `rotation.config_loader` import that follows can resolve.
# Presumably needed when this file is executed directly as a script rather
# than imported as part of a package — TODO confirm.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType, WeightType
# ============================================================
# HTTP client (requests with trust_env=False to bypass the system proxy and avoid SSL EOF errors)
# ============================================================
# Proxies such as Clash can trigger SSL EOF errors when handling TLS 1.3 with
# post-quantum key exchange.
# trust_env=False makes requests ignore proxy settings from environment
# variables and connect to the target server directly.
# NOTE(review): per the requests docs, trust_env=False also ignores other
# environment-derived settings (e.g. netrc credentials, CA-bundle env vars);
# confirm none of those are needed here.
_session = requests.Session()
_session.trust_env = False
def _http_get(url: str, params: dict = None, timeout: int = 120) -> requests.Response:
    """Send a GET request through the shared module-level ``_session``.

    The session is configured with ``trust_env=False``, so proxies from the
    environment are not used.

    Args:
        url: Absolute URL to request.
        params: Optional query-string parameters.
        timeout: Request timeout in seconds.

    Returns:
        The ``requests.Response``; the status code is not checked here.
    """
    request_options = {'params': params, 'timeout': timeout}
    return _session.get(url, **request_options)
def _sanitize_json(obj):
"""Recursively replace NaN/Inf with None in-place so json.dump produces valid JSON"""
if isinstance(obj, dict):
for k, v in obj.items():
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
obj[k] = None
elif isinstance(v, (dict, list)):
_sanitize_json(v)
elif isinstance(obj, list):
for i, v in enumerate(obj):
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
obj[i] = None
elif isinstance(v, (dict, list)):
_sanitize_json(v)
# ============================================================
# Pure functions: momentum
# ============================================================
def weighted_momentum_score(prices: np.ndarray) -> float:
    """Weighted log-linear regression momentum: annualized_return * R^2.

    A line is fitted to log(prices) with weights rising linearly from 1
    (oldest point) to 2 (newest point). The slope is annualized over 250
    trading days as ``exp(slope * 250) - 1`` and multiplied by the weighted
    R^2 of the fit. Prices are floored at 0.01 before taking logs.

    Returns 0.0 for fewer than 5 prices or when any log-price is NaN/Inf.

    NOTE(review): ``np.polyfit`` applies ``w`` to the *unsquared* residuals,
    so the fit minimizes sum(w**2 * r**2) while the R^2 below weights squared
    residuals by ``w``. The two weightings differ; kept as-is so scores stay
    identical to existing backtests.
    """
    if len(prices) < 5:
        return 0.0
    log_p = np.log(np.clip(prices, 0.01, None))
    if not np.isfinite(log_p).all():
        return 0.0
    n = len(log_p)
    t = np.arange(n)
    w = np.linspace(1, 2, n)
    slope, intercept = np.polyfit(t, log_p, 1, w=w)
    fitted = slope * t + intercept
    resid_ss = np.sum(w * (log_p - fitted) ** 2)
    total_ss = np.sum(w * (log_p - np.average(log_p, weights=w)) ** 2)
    r_squared = 1 - resid_ss / total_ss if total_ss > 0 else 0
    return (math.exp(slope * 250) - 1) * r_squared
def vol_adjusted_momentum_score(prices: np.ndarray) -> float:
    """Volatility-adjusted momentum: (annualized_return / annualized_vol) * R^2.

    Follows the Moskowitz, Ooi & Pedersen (2012) time-series-momentum idea:
    dividing momentum by realized volatility makes scores comparable across
    assets.

    - annualized_return: ``exp(slope * 250) - 1`` from a weighted (1 -> 2)
      linear fit on log prices.
    - annualized_vol: population std of daily log returns * sqrt(250),
      floored at 0.01.
    - R^2: weighted R^2 of the same fit.

    Prices are floored at 0.01. Returns 0.0 for fewer than 5 prices or when
    any log-price is NaN/Inf.
    """
    if len(prices) < 5:
        return 0.0
    log_p = np.log(np.clip(prices, 0.01, None))
    if not np.isfinite(log_p).all():
        return 0.0
    t = np.arange(len(log_p))
    w = np.linspace(1, 2, len(log_p))
    slope, intercept = np.polyfit(t, log_p, 1, w=w)
    ann_ret = math.exp(slope * 250) - 1
    # Floor keeps near-constant series from producing exploding scores.
    ann_vol = max(float(np.std(np.diff(log_p))) * math.sqrt(250), 0.01)
    fitted = slope * t + intercept
    resid_ss = np.sum(w * (log_p - fitted) ** 2)
    total_ss = np.sum(w * (log_p - np.average(log_p, weights=w)) ** 2)
    r_squared = 1 - resid_ss / total_ss if total_ss > 0 else 0
    return (ann_ret / ann_vol) * r_squared
def slope_r2_score(prices: np.ndarray) -> float:
    """Slope * R^2 score on prices normalized to the first observation.

    Fits an unweighted straight line to ``prices / prices[0]`` (prices floored
    at 0.01) and returns ``10000 * slope * R^2``.

    Returns 0.0 for fewer than 5 prices, or when any normalized price is
    NaN/Inf. This matches the guard in ``weighted_momentum_score``; without it
    ``np.polyfit`` can raise (LinAlgError) or NaN leaks into the ranking.
    """
    if len(prices) < 5:
        return 0.0
    prices = np.clip(prices, 0.01, None)
    y = prices / prices[0]  # normalize so different price levels are comparable
    if not np.isfinite(y).all():
        return 0.0
    x = np.arange(len(y))
    slope, intercept = np.polyfit(x, y, 1)
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
    return 10000 * slope * r2
def standardized_slope_score(prices: np.ndarray) -> float:
    """Standardized slope (t-statistic): slope / SE(slope).

    Academic basis:
    - Uses normalized prices (p/p[0]) for cross-asset comparability,
      consistent with slope_r2_score.
    - Divides slope by its standard error, yielding a statistical significance
      measure rather than raw magnitude.
    - Equivalent to the t-value for H0: slope=0, penalizing noisy trends.

    Formula:
        SE(slope) = sqrt(MSE / Sxx)
        MSE = SS_res / (n - 2)
        Sxx = sum((xi - x_bar)^2) = n*(n-1)*(n+1)/12 for x = 0..n-1

    Returns 0.0 for fewer than 5 prices or when any normalized price is
    NaN/Inf (otherwise ``np.polyfit`` can raise or NaN leaks into ranking).
    SE is floored at 1e-12, so near-perfect linear series give very large
    magnitudes.
    """
    n = len(prices)
    if n < 5:
        return 0.0
    prices = np.clip(prices, 0.01, None)
    y = prices / prices[0]  # normalize
    if not np.isfinite(y).all():
        return 0.0
    x = np.arange(n)
    slope, intercept = np.polyfit(x, y, 1)
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)
    # Standard error of slope
    mse = ss_res / (n - 2)  # unbiased MSE
    sxx = n * (n - 1) * (n + 1) / 12  # sum of squared deviations of x
    se_slope = math.sqrt(mse / sxx) if sxx > 0 else 1e-9
    se_slope = max(se_slope, 1e-12)  # avoid division by ~0 for perfect fits
    return slope / se_slope
def momentum_score(prices: np.ndarray) -> float:
    """Simple price return over the window: ``last / first - 1``.

    Prices are floored at 0.01 first. Returns 0.0 for fewer than 5 prices or
    when either endpoint is NaN/Inf, so a bad endpoint cannot inject NaN/Inf
    into the cross-asset ranking (consistent with the other score functions).
    Interior prices do not affect the result.
    """
    if len(prices) < 5:
        return 0.0
    prices = np.clip(prices, 0.01, None)
    first, last = prices[0], prices[-1]
    if not (np.isfinite(first) and np.isfinite(last)):
        return 0.0
    return last / first - 1
# ============================================================
# Pure functions: position weight schemes
# ============================================================
def compute_position_weights(
    ranked_holdings: List[str],
    weight_type: str = 'equal',
    scores: Dict[str, float] = None,
) -> Dict[str, float]:
    """Compute position weights from a ranked slot list.

    Args:
        ranked_holdings: Ordered list of signal codes, best first. May contain
            duplicates (e.g. bond fills); every entry is one slot.
        weight_type: 'equal', 'rank', or 'kelly'. Any other value is treated
            as 'equal'.
        scores: Required for 'kelly'. Dict mapping code -> momentum score.

    Returns:
        Dict mapping each unique code to its total weight (sum of its slots).
        For every scheme the weights sum to 1 (empty dict for no holdings).

    Raises:
        ValueError: if weight_type is 'kelly' and ``scores`` is None/empty.

    Schemes:
        equal: each slot = 1/N, duplicates summed.
        rank:  slot i (0-indexed) = (N-i) / triangular(N), duplicates summed.
               For N=3: [3/6, 2/6, 1/6] = [50%, 33%, 17%].
        kelly: slot i = max(score_i, 0) / sum over slots of max(score_j, 0),
               duplicates summed. Score-proportional weighting as a Kelly
               criterion proxy. Non-positive and non-finite scores get zero
               weight (Kelly: don't bet on a negative edge). If no slot has a
               positive score, falls back to the 'equal' scheme.
    """
    n_slots = len(ranked_holdings)
    if n_slots == 0:
        return {}
    weights: Dict[str, float] = {}

    if weight_type == 'kelly':
        if not scores:
            raise ValueError("Kelly weighting requires 'scores' parameter")

        def edge(code: str) -> float:
            """Positive part of the code's score; NaN/Inf count as no edge."""
            s = scores.get(code, 0.0)
            return max(s, 0.0) if math.isfinite(s) else 0.0

        slot_edges = [edge(code) for code in ranked_holdings]
        # Normalize over SLOTS, the same unit the per-code sums are built
        # from; normalizing over unique codes would over-allocate duplicated
        # codes and push the total weight above 1.
        total = sum(slot_edges)
        if total > 0:
            for code, e in zip(ranked_holdings, slot_edges):
                weights[code] = weights.get(code, 0.0) + e / total
            return weights
        # No positive edge anywhere: fall through to equal weighting.
    elif weight_type == 'rank':
        triangular = n_slots * (n_slots + 1) / 2
        for i, code in enumerate(ranked_holdings):
            weights[code] = weights.get(code, 0.0) + (n_slots - i) / triangular
        return weights

    # equal (default, also the kelly fallback)
    w = 1.0 / n_slots
    for code in ranked_holdings:
        weights[code] = weights.get(code, 0.0) + w
    return weights
def is_crash(prices: np.ndarray) -> bool:
    """Crash filter over the last 4 prices (the last 3 daily moves).

    Returns True if either condition holds:
    - any single one of the last 3 days fell by more than 5%
      (day ratio < 0.95), or
    - all 3 days fell and the cumulative 3-day ratio is below 0.95.
    Fewer than 4 prices -> False.
    """
    if len(prices) < 4:
        return False
    p0, p1, p2, p3 = prices[-4:]
    day_ratios = (p3 / p2, p2 / p1, p1 / p0)
    if min(day_ratios) < 0.95:
        return True
    return all(r < 1 for r in day_ratios) and p3 / p0 < 0.95
# ============================================================
# Data cache
# ============================================================
class DataCache:
    """Fetch OHLCV data and trading calendars from the HTTP API.

    There is no on-disk or OHLCV cache: every ``preload`` call goes to the
    API. The only in-memory state is ``premium_data``, the per-ETF premium
    series returned alongside OHLCV data, shaped
    ``{trade_code: {date_str: premium}}``.

    Each request is attempted up to 3 times. Non-200 responses and network
    errors (timeout, SSL, connection) are retried after a 1s, then 2s, pause.
    """

    def __init__(self, base_url: str, timeout: int = 120):
        """
        Args:
            base_url: API root; a trailing '/' is stripped.
            timeout: Per-request timeout in seconds.
        """
        self.base_url = base_url.rstrip('/')
        self.api_path = '/api/v1/ohlcv'
        self.timeout = timeout
        # premium data in memory: {trade_code: {date_str: premium_ratio}}
        self.premium_data: Dict[str, Dict[str, float]] = {}

    def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
        """Fetch OHLCV data for ``code`` from the API (see ``_fetch_api``)."""
        return self._fetch_api(code, start_date, end_date, adj)

    def _store_premium(self, code: str, premium_series: List[dict]) -> Dict[str, float]:
        """Merge an API ``premium_series`` ([{'date', 'premium'}, ...]) into
        ``premium_data[code]`` and return the newly parsed entries."""
        new_data = {item['date']: item['premium'] for item in premium_series}
        self.premium_data.setdefault(code, {}).update(new_data)
        return new_data

    def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
        """Fetch ``code`` from the OHLCV endpoint.

        Returns a DataFrame (indexed by 'date' and sorted, when the API sends
        a 'date' column) holding whichever of open/high/low/close/volume are
        present, or None on HTTP/API/network failure or when no rows come
        back. If the response carries a ``premium_series`` (ETFs), it is
        attached as ``df.attrs['premium_series']`` and merged into
        ``premium_data``.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        # Same increasing back-off (1s, 2s) as the other endpoints
                        time.sleep(1 + attempt)
                        continue
                    print(f" x {code}: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x {code}: {data['error']}")
                    return None
                records = data.get('data', [])
                if not records:
                    return None
                df = pd.DataFrame(records)
                if 'date' in df.columns:
                    df['date'] = pd.to_datetime(df['date'])
                    df = df.set_index('date').sort_index()
                keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
                df = df[keep]
                # Extract premium_series (ETF only) and store in memory
                premium_series = data.get('premium_series', [])
                if premium_series:
                    df.attrs['premium_series'] = self._store_premium(code, premium_series)
                print(f" + {code}: {len(df)} rows ({adj})")
                return df
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                # Network errors (timeout, SSL, dropped connection) are retried
                if attempt < 2:
                    time.sleep(1 + attempt)  # increasing delay: 1s, 2s
                    continue
                print(f" x {code}: {type(e).__name__} after {attempt+1} retries")
                return None
            except Exception as e:
                print(f" x {code}: {e}")
                return None
        return None

    def preload_premium(self, code: str, end_date: str = None) -> Optional[Dict[str, float]]:
        """Return premium data for ``code``, fetching from the API if needed.

        In-memory data is reused when ``end_date`` is not given, or when its
        latest date key is >= ``end_date``. Otherwise the history from
        2000-01-01 to ``end_date`` (default: today) is fetched and merged.
        Returns None if no premium data is available.
        """
        cached = self.premium_data.get(code)
        if cached is not None:
            # Keys are presumably 'YYYY-MM-DD' strings (the format
            # _compute_premium looks up), so the lexical max is the latest date.
            if not end_date or (cached and max(cached) >= end_date):
                return cached
        self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
        return self.premium_data.get(code)

    def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
        """Fetch ``premium_series`` for ``code`` and merge it into ``premium_data``.

        Best effort: failures are reported on stdout (instead of being
        swallowed silently) and leave ``premium_data`` unchanged.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1 + attempt)
                        continue
                    print(f" x premium {code}: HTTP {resp.status_code}")
                    return
                data = resp.json()
                if 'error' in data:
                    print(f" x premium {code}: {data['error']}")
                    return
                premium_series = data.get('premium_series', [])
                if premium_series:
                    new_data = self._store_premium(code, premium_series)
                    print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})")
                return
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x premium {code}: {type(e).__name__} after {attempt+1} attempts")
                return
            except Exception as e:
                print(f" x premium {code}: {e}")
                return

    def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]:
        """Fetch the trading calendar for ``market`` between the two dates.

        Returns a DatetimeIndex (empty if the API lists no dates) or None on
        failure. Unlike the OHLCV fetch, unexpected exceptions are retried too.
        """
        url = f"{self.base_url}/api/v1/trading-calendar"
        params = {'market': market, 'start': start_date, 'end': end_date}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1 + attempt)
                        continue
                    print(f" x calendar: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x calendar: {data['error']}")
                    return None
                dates = data.get('trading_dates', [])
                if not dates:
                    return pd.DatetimeIndex([])
                result = pd.DatetimeIndex(dates)
                print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
                return result
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x calendar: {type(e).__name__} after {attempt+1} retries")
                return None
            except Exception as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x calendar: {e}")
                return None
        return None
# ============================================================
# Core strategy class
# ============================================================
class SimpleRotationStrategy:
"""
Simple rotation strategy (daily iteration)
Flow:
1. Load config
2. Preload all historical data (index raw + ETF hfq)
3. Fetch A-share trading calendar
4. For each trading day:
- Compute momentum (last 25 days)
- Group selection -> generate holdings
- Compare with yesterday -> compute T+1 return
- Update NAV
5. Export results
"""
def __init__(self, config_path: str = None):
    """Load the strategy config and set up parameters, mappings and data source.

    Args:
        config_path: Path to the rotation YAML config. Defaults to
            ``config_simple.yaml`` in this file's directory.
    """
    if config_path is None:
        config_path = Path(__file__).parent / 'config_simple.yaml'
    self.config: RotationStrategyConfig = load_rotation_config(str(config_path))
    # Strategy params
    self.n_days = self.config.factor.n_days
    self.select_num = self.config.rotation.select_num
    self.trade_cost = self.config.rebalance.trade_cost
    # Scheme name passed to compute_position_weights:
    # 'equal', 'rank' or 'kelly' (any other value behaves like 'equal').
    self.weight_type = self.config.rotation.weight.value
    # Dynamic threshold: non-BOND candidates need raw momentum >= reference
    # momentum * ratio; fallback_value replaces a missing reference momentum.
    threshold = self.config.rotation.threshold
    self.use_dynamic_threshold = (threshold.mode.value == 'dynamic')
    self.bond_code = threshold.dynamic.reference if threshold.dynamic else None
    self.bond_ratio = threshold.dynamic.ratio if threshold.dynamic else 1.0
    self.fallback_value = threshold.dynamic.fallback_value if threshold.dynamic else 0.0
    # Signal codes and mappings (signal source -> trade source / group)
    self.signal_codes = []
    self.signal_to_trade = {}
    self.code_to_group = {}
    self.trade_code_to_group = {}
    for code, asset in self.config.asset_pools.assets.items():
        self.signal_codes.append(asset.signal_source)
        self.signal_to_trade[asset.signal_source] = asset.trade_source
        self.code_to_group[asset.signal_source] = asset.group
        self.trade_code_to_group[asset.trade_source] = asset.group
    # Data source (only the first configured source is used)
    data_source = self.config.data.sources[0]
    base_url = data_source.url or 'https://k3s.tokenpluse.xyz'
    self.data_cache = DataCache(base_url=base_url, timeout=data_source.timeout)
    # Preloaded data
    self.index_data: Dict[str, pd.DataFrame] = {}
    self.etf_data: Dict[str, pd.DataFrame] = {}
    self.benchmark_data: Optional[pd.DataFrame] = None
    # Benchmark config
    self.benchmark_code = self.config.benchmark.code
    self.benchmark_name = self.config.benchmark.name
    # Results
    self.daily_records: List[dict] = []
    self.trading_calendar: Optional[pd.DatetimeIndex] = None
    # Position weights actually applied by _get_weight / _calculate_daily_return
    self._position_weights: Dict[str, float] = {}
    # Weights produced by the latest _generate_signals call; run() copies them
    # into the active weights only on a rebalance (or the first day).
    # Initialized here so run() cannot hit AttributeError when no signal has
    # produced weights yet.
    self._pending_weights: Dict[str, float] = {}
def _preload_data(self):
"""Preload all historical data"""
start_date = self.config.backtest.start_date
end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
preload_start = (pd.Timestamp(start_date) - timedelta(days=self.n_days * 2)).strftime('%Y-%m-%d')
print("\n[1/4] Preloading signal sources (index raw)...")
for code in self.signal_codes:
df = self.data_cache.preload(code, preload_start, end_date, adj='raw')
if df is not None:
self.index_data[code] = df
print(f"\n Signal: {len(self.index_data)}/{len(self.signal_codes)} OK")
print("\n[2/4] Preloading trade sources (ETF hfq)...")
trade_codes = set(self.signal_to_trade.values())
for code in trade_codes:
is_bond = any(
a.trade_source == code and a.group == 'BOND'
for a in self.config.asset_pools.assets.values()
)
adj = 'raw' if is_bond else 'hfq'
df = self.data_cache.preload(code, preload_start, end_date, adj=adj)
if df is not None:
self.etf_data[code] = df
# Load premium data cache for all ETF trade codes
for code in trade_codes:
self.data_cache.preload_premium(code, end_date=end_date)
print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK, premium: {len(self.data_cache.premium_data)} loaded")
# Load benchmark
print(f"\n Loading benchmark ({self.benchmark_code})...")
bm_df = self.data_cache.preload(self.benchmark_code, preload_start, end_date, adj='raw')
if bm_df is not None:
self.benchmark_data = bm_df
print(f" Benchmark: {len(bm_df)} rows")
def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
    """Score ``signal_code`` as of ``date`` using the configured factor.

    Uses the last ``n_days`` closes on or before ``date``.

    Returns:
        None if the code has no data or fewer than ``n_days`` bars.
        0.0 when ``is_crash`` fires on that window.
        Otherwise the score from the function selected by
        ``config.factor.type``; ``weighted_momentum_score`` for any other type.
    """
    if signal_code not in self.index_data:
        return None
    frame = self.index_data[signal_code]
    history = frame[frame.index <= date]
    if len(history) < self.n_days:
        return None
    window = history['close'].values[-self.n_days:]
    if len(window) >= 4 and is_crash(window):
        return 0.0
    factor = self.config.factor.type
    if factor == FactorType.VOL_ADJUSTED_MOMENTUM:
        return vol_adjusted_momentum_score(window)
    if factor == FactorType.SLOPE_R2:
        return slope_r2_score(window)
    if factor == FactorType.STANDARDIZED_SLOPE:
        return standardized_slope_score(window)
    if factor == FactorType.MOMENTUM:
        return momentum_score(window)
    return weighted_momentum_score(window)
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
    """Momentum used for the bond-threshold comparison.

    For every factor except VOL_ADJUSTED_MOMENTUM this is exactly
    ``_compute_momentum``. For VOL_ADJUSTED_MOMENTUM it is
    ``weighted_momentum_score``, with the same None / crash handling, so the
    threshold comparison operates in raw return space instead of
    volatility-scaled units.
    """
    if self.config.factor.type != FactorType.VOL_ADJUSTED_MOMENTUM:
        return self._compute_momentum(signal_code, date)
    if signal_code not in self.index_data:
        return None
    frame = self.index_data[signal_code]
    history = frame[frame.index <= date]
    if len(history) < self.n_days:
        return None
    window = history['close'].values[-self.n_days:]
    if len(window) >= 4 and is_crash(window):
        return 0.0
    return weighted_momentum_score(window)
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
    """
    Generate rotation signals (group competition + dynamic threshold + BOND fill).

    Side effect: sets ``self._pending_weights`` to the position weights for
    the selected slots (``compute_position_weights`` with the configured
    weight type and ``factors`` as scores). On the early-return paths (no
    factors / no group winner) it is reset to ``{}``. This stops run() from
    locking in weights left over from an earlier day, or hitting
    AttributeError on the first day.

    Returns:
        (holdings, factors, bond_momentum): the selected holdings, sorted
        (may contain repeated bond fills); all computed factor scores; and
        the bond momentum threshold value (None when the dynamic threshold is
        disabled or has no bond code).

    Logic (identical to V2):
        1. Each group: select Top 1 (non-BOND groups need raw momentum
           >= bond_momentum * bond_ratio)
        2. From all group winners: sort by score, select Top select_num
        3. Fill remaining slots with bond_code (only if it is not already
           among the winners)
    """
    factors: Dict[str, float] = {}
    for code in self.signal_codes:
        score = self._compute_momentum(code, date)
        if score is not None:
            factors[code] = score
    if not factors:
        self._pending_weights = {}
        return [], {}, None
    # Bond threshold always uses raw (weighted) momentum
    bond_momentum = None
    if self.use_dynamic_threshold and self.bond_code:
        bond_momentum = self._compute_raw_momentum(self.bond_code, date)
        if bond_momentum is None:
            bond_momentum = self.fallback_value
    # Raw factors for threshold comparison
    raw_factors: Dict[str, float] = {}
    if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM:
        for code in self.signal_codes:
            score = self._compute_raw_momentum(code, date)
            if score is not None:
                raw_factors[code] = score
    else:
        raw_factors = factors
    groups = self.config.asset_pools.by_group
    selected_by_group: Dict[str, Tuple[str, float]] = {}
    for group_name, assets in groups.items():
        group_codes = [a.signal_source for a in assets.values()]
        group_factors = {c: factors[c] for c in group_codes if c in factors}
        group_raw = {c: raw_factors[c] for c in group_codes if c in raw_factors}
        if not group_factors:
            continue
        # Threshold comparison uses raw (weighted) momentum
        if group_name != 'BOND' and bond_momentum is not None:
            thresh = bond_momentum * self.bond_ratio
            group_factors = {c: s for c, s in group_factors.items()
                             if group_raw.get(c, 0) >= thresh}
            if not group_factors:
                continue
        top_code = max(group_factors, key=group_factors.get)
        selected_by_group[group_name] = (top_code, group_factors[top_code])
    if not selected_by_group:
        self._pending_weights = {}
        return [], factors, bond_momentum
    candidates = list(selected_by_group.values())
    candidates.sort(key=lambda x: x[1], reverse=True)
    ranked_holdings = [code for code, _ in candidates[:self.select_num]]
    if len(ranked_holdings) < self.select_num and self.bond_code:
        if self.bond_code not in ranked_holdings:
            n_slots = self.select_num - len(ranked_holdings)
            ranked_holdings.extend([self.bond_code] * n_slots)
    # Compute position weights via configured scheme.
    # These are *pending* weights; the caller (run) locks them in
    # only when an actual rebalance occurs.
    self._pending_weights = compute_position_weights(
        ranked_holdings, self.weight_type, scores=factors,
    )
    return sorted(ranked_holdings), factors, bond_momentum
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
"""Get ETF prices on a given date: {open, close, prev_close}
Note: Index data may lack 'open' column; use close as fallback.
"""
if trade_code not in self.etf_data:
return None
df = self.etf_data[trade_code]
mask = df.index <= date
recent = df.loc[mask]
if len(recent) < 2:
return None
today_row = recent.iloc[-1]
prev_row = recent.iloc[-2]
today_date = recent.index[-1]
if (date - today_date).days > 5:
return None
close_raw = today_row.get('close')
prev_close_raw = prev_row.get('close')
if close_raw is None or pd.isna(close_raw) or prev_close_raw is None or pd.isna(prev_close_raw):
return None
close = float(close_raw)
prev_close = float(prev_close_raw)
# Handle missing/invalid open (common for index data like 931862.CSI)
raw_open = today_row.get('open')
open_price = float(raw_open) if raw_open is not None and not pd.isna(raw_open) else close
if pd.isna(open_price) or open_price == 0:
open_price = close
return {
'open': open_price,
'close': close,
'prev_close': prev_close,
}
def _get_weight(self, code: str, n_unique: int) -> float:
"""Get position weight for a code. Falls back to equal weight if not available."""
if code in self._position_weights:
return self._position_weights[code]
return 1.0 / n_unique if n_unique > 0 else 0.0
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
    """
    Compute the portfolio return for trading day ``date`` (T+1 execution).

    Each code is weighted by ``_get_weight``: its entry in
    ``self._position_weights``, or equal weight when it has none.

    - First day (no old holdings): each unique new code contributes
      open -> close.
    - Hold (in old and new): close-to-close (prev_close -> close).
    - Sell (old only): prev_close -> open (sold at the open).
    - Buy (new only): open -> close (intraday).

    Codes with missing prices are skipped. On rebalance days a flat
    ``trade_cost`` is subtracted once, independent of turnover.

    Args:
        old_holdings: Signal codes held before today (may repeat).
        new_holdings: Signal codes held after today's open (may repeat).
        date: Trading day whose prices are used.
        is_rebalance: Whether holdings changed today.

    Returns:
        The day's return as an additive weighted sum of per-code returns.

    NOTE(review): the prev_close -> open leg of held codes uses whatever
    weights are active today (possibly the post-rebalance weights), which is
    an approximation on rebalance days.
    """
    if not old_holdings:
        if not new_holdings:
            return 0.0
        ret = 0.0
        seen = set()
        # Equal-weight fallback denominator: number of distinct new codes
        n_unique = len(set(new_holdings))
        for code in new_holdings:
            if code in seen:  # duplicated slots are weighted once via _get_weight
                continue
            seen.add(code)
            tc = self.signal_to_trade.get(code, code)
            p = self._get_etf_prices(tc, date)
            w = self._get_weight(code, n_unique)
            if p and p['open'] > 0:
                ret += w * (p['close'] - p['open']) / p['open']
        # run() only sets is_rebalance when old holdings were non-empty,
        # so from run() this subtraction does not fire on this path.
        if is_rebalance:
            ret -= self.trade_cost
        return ret
    old_set = set(old_holdings)
    new_set = set(new_holdings)
    # Equal-weight fallback denominator for old (held/sold) codes
    n_unique = len(old_set)
    daily_return = 0.0
    for code in old_set:
        tc = self.signal_to_trade.get(code, code)
        p = self._get_etf_prices(tc, date)
        if p is None or p['prev_close'] == 0:
            continue
        w = self._get_weight(code, n_unique)
        if code in new_set:
            # Held through the day: close-to-close
            r = (p['close'] - p['prev_close']) / p['prev_close']
        else:
            # Sold at today's open: only the overnight gap is realized
            r = (p['open'] - p['prev_close']) / p['prev_close']
        if not math.isnan(r):
            daily_return += w * r
    for code in new_set - old_set:
        # Bought at today's open: intraday open -> close
        tc = self.signal_to_trade.get(code, code)
        p = self._get_etf_prices(tc, date)
        w = self._get_weight(code, len(new_set))
        if p and p['open'] > 0 and not math.isnan(p['close']):
            r = (p['close'] - p['open']) / p['open']
            if not math.isnan(r):
                daily_return += w * r
    if is_rebalance:
        daily_return -= self.trade_cost
    return daily_return
def run(self) -> dict:
    """Main backtest loop.

    Preloads data and fetches the A-share trading calendar. Then, for each
    trading day, it:
    - generates signals from the previous calendar day's data;
    - locks in new position weights on rebalance days (and the first day);
    - applies the day's return to NAV;
    - appends a record to ``self.daily_records``, which is reset at the start
      of every run.

    Returns:
        ``{'metrics': ..., 'daily_records': ...}``, or ``{}`` if the trading
        calendar could not be fetched.

    Raises:
        RuntimeError: If a currently held asset has no factor score on some
            day (data failure), to avoid a false rebalance.
    """
    print("=" * 60)
    print(" Simple Rotation Strategy (Daily Iteration)")
    print("=" * 60)
    self._preload_data()
    print("\n[3/4] Fetching A-share trading calendar...")
    start_date = self.config.backtest.start_date
    end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
    self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date)
    if self.trading_calendar is None or len(self.trading_calendar) == 0:
        print(" x Calendar fetch failed")
        return {}
    print(f"\n[4/4] Backtesting ({len(self.trading_calendar)} trading days)...")
    # Fresh records per run: NAV restarts at 1.0 below, so appending to a
    # previous run's records would corrupt the metrics.
    self.daily_records = []
    current_holdings: List[str] = []
    nav = 1.0
    rebalance_count = 0
    # signal_code -> {entry_date, entry_price_etf, entry_price_idx}
    # NOTE(review): entry_info is maintained but neither stored on self nor
    # returned; confirm whether a caller is meant to receive it.
    entry_info: Dict[str, dict] = {}
    active_weights: Dict[str, float] = {}  # locked-in weights, updated only on rebalance
    for i, date in enumerate(self.trading_calendar):
        # Signal timing: 9:00 AM on day T
        # At this moment, T's market has NOT opened yet.
        # Only T-1 close data is available for all markets.
        # Use calendar day -1 (not A-share prev trading day) so that
        # foreign markets (US/HK) that traded during A-share holidays are captured.
        # Execution happens at 9:30 AM using T's ETF prices.
        signal_date = date - timedelta(days=1)
        new_holdings, factors, bond_momentum = self._generate_signals(signal_date)
        # Data integrity check: if any currently held asset is missing from
        # today's factors, abort immediately to prevent false rebalancing.
        if current_holdings:
            missing = [c for c in current_holdings if c not in factors]
            if missing:
                raise RuntimeError(
                    f"Data failure: held assets {missing} missing from factors on "
                    f"{date.strftime('%Y-%m-%d')}. Aborting to prevent false rebalance."
                )
        is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0
        # Lock in position weights only on rebalance (or first day)
        if is_rebalance or not current_holdings:
            previous_weights = active_weights
            active_weights = dict(getattr(self, '_pending_weights', {}))
            # Codes sold today are still held overnight. Keep their
            # pre-rebalance weight for the prev_close -> open leg, instead
            # of letting _get_weight fall back to 1/len(old holdings), which
            # mis-weights them under rank/kelly or duplicated bond slots.
            effective_weights = dict(active_weights)
            for code in set(current_holdings) - set(new_holdings):
                if code in previous_weights:
                    effective_weights[code] = previous_weights[code]
            self._position_weights = effective_weights
        else:
            self._position_weights = active_weights
        # Return uses T's ETF prices (open for buy/sell, close for hold)
        daily_return = self._calculate_daily_return(
            current_holdings, new_holdings, date, is_rebalance
        )
        nav *= (1 + daily_return)
        if is_rebalance:
            rebalance_count += 1
        # Update entry tracking for held assets
        added = set(new_holdings) - set(current_holdings)
        removed = set(current_holdings) - set(new_holdings)
        for code in added:
            trade_code = self.signal_to_trade.get(code, code)
            etf_prices = self._get_etf_prices(trade_code, date)
            # Entry price is the actual buy price: today's open
            entry_etf = etf_prices['open'] if etf_prices else None
            # Index close at T-1 (the data used to make the decision)
            entry_idx = self._get_index_close(code, signal_date)
            entry_info[code] = {
                'entry_date': date.strftime('%Y-%m-%d'),
                'entry_price_etf': entry_etf,
                'entry_price_idx': entry_idx,
            }
        for code in removed:
            entry_info.pop(code, None)
        # Compute bond threshold value for detail record
        threshold_val = 0.0
        if self.use_dynamic_threshold and bond_momentum is not None:
            threshold_val = round(bond_momentum * self.bond_ratio, 6)
        self.daily_records.append({
            'date': date.strftime('%Y-%m-%d'),
            'nav': round(nav, 6),
            'daily_return': round(daily_return, 6),
            'is_rebalance': is_rebalance,
            'holdings': sorted(new_holdings),
            'added': sorted(added),
            'removed': sorted(removed),
            'factors': {k: round(v, 6) for k, v in factors.items()},
            'threshold': threshold_val,
            'position_weights': {k: round(v, 6) for k, v in active_weights.items()},
        })
        current_holdings = new_holdings
        if (i + 1) % 100 == 0 or i == len(self.trading_calendar) - 1:
            print(f" [{i+1}/{len(self.trading_calendar)}] "
                  f"NAV: {nav:.4f} | Rebal: {rebalance_count}")
    metrics = self._compute_metrics(rebalance_count)
    print("\n" + "=" * 60)
    print(" Backtest Complete")
    print("=" * 60)
    if self.daily_records:
        # metrics is {} without records, so the summary is printed only here
        print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}")
        print(f" Trading days: {len(self.daily_records)}")
        print(f" Total return: {metrics['total_return']:.2%}")
        print(f" Annual return: {metrics['annual_return']:.2%}")
        print(f" Max drawdown: {metrics['max_drawdown']:.2%}")
        print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}")
        print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}")
        print(f" Win rate: {metrics['win_rate']:.2%}")
        print(f" Rebalances: {rebalance_count}")
    print("=" * 60)
    return {
        'metrics': metrics,
        'daily_records': self.daily_records,
    }
def _compute_metrics(self, rebalance_count: int) -> dict:
"""Compute performance metrics"""
if not self.daily_records:
return {}
df = pd.DataFrame(self.daily_records)
df['date'] = pd.to_datetime(df['date'])
nav = df['nav']
returns = df['daily_return']
total_return = nav.iloc[-1] / nav.iloc[0] - 1
n_days = len(df)
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
peak = nav.cummax()
drawdown = (nav - peak) / peak
max_drawdown = drawdown.min()
sharpe = returns.mean() / returns.std() * np.sqrt(252) if returns.std() > 0 else 0
calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0
non_zero = returns[returns != 0]
win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0
return {
'total_return': total_return,
'annual_return': annual_return,
'max_drawdown': max_drawdown,
'sharpe_ratio': sharpe,
'calmar_ratio': calmar,
'win_rate': win_rate,
'n_days': n_days,
'rebalance_count': rebalance_count,
}
# ============================================================
# Detail JSON helpers (V2-compatible)
# ============================================================
def _build_meta_codes(self) -> dict:
"""Build meta.codes mapping: signal_code -> {name, etf, market}"""
codes = {}
for code, asset in self.config.asset_pools.assets.items():
codes[asset.signal_source] = {
'name': getattr(asset, 'name', code),
'etf': asset.trade_source,
'market': asset.group,
}
return codes
def _get_index_close(self, code: str, date: pd.Timestamp) -> Optional[float]:
"""Get index close price on or before date."""
df = self.index_data.get(code)
if df is None:
return None
mask = df.index <= date
if mask.sum() == 0:
return None
return float(df.loc[mask].iloc[-1]['close'])
def _get_etf_close(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
"""Get ETF close price on or before date."""
df = self.etf_data.get(trade_code)
if df is None:
return None
mask = df.index <= date
if mask.sum() == 0:
return None
return float(df.loc[mask].iloc[-1]['close'])
def _get_daily_returns(self, code: str, date: pd.Timestamp) -> Tuple[Optional[float], Optional[float]]:
"""Get (index_return, etf_return_ctc) for a code on a given date."""
trade_code = self.signal_to_trade.get(code, code)
idx_df = self.index_data.get(code)
etf_df = self.etf_data.get(trade_code)
idx_ret = None
if idx_df is not None:
mask = idx_df.index <= date
if mask.sum() >= 2:
rows = idx_df.loc[mask]
c_today = float(rows.iloc[-1]['close'])
c_prev = float(rows.iloc[-2]['close'])
if c_prev > 0 and not pd.isna(c_today):
idx_ret = round((c_today - c_prev) / c_prev, 6)
etf_ret = None
if etf_df is not None:
mask = etf_df.index <= date
if mask.sum() >= 2:
rows = etf_df.loc[mask]
c_today = float(rows.iloc[-1]['close'])
c_prev = float(rows.iloc[-2]['close'])
if c_prev > 0 and not pd.isna(c_today):
etf_ret = round((c_today - c_prev) / c_prev, 6)
return idx_ret, etf_ret
def _compute_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
"""Get real premium from API data cache: (ETF_price - NAV) / NAV.
Returns None for BOND or when premium data is unavailable."""
group = self.trade_code_to_group.get(trade_code, '')
if group == 'BOND':
return None
premium_dict = self.data_cache.premium_data.get(trade_code)
if not premium_dict:
return None
date_str = date.strftime('%Y-%m-%d')
val = premium_dict.get(date_str)
if val is None:
return None
return round(float(val), 6)
def _get_latest_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
    """Return the most recent cached premium on or up to 5 calendar days before ``date``.

    Like ``_compute_premium``, BOND-group codes and codes without a cached
    premium series yield None. Otherwise the exact date is tried first, then
    each of the 5 preceding calendar days. The first hit is rounded to 6
    decimals; None is returned if all six days are missing.
    """
    if self.trade_code_to_group.get(trade_code, '') == 'BOND':
        return None
    series = self.data_cache.premium_data.get(trade_code)
    if not series:
        return None
    for days_back in range(6):
        key = (date - timedelta(days=days_back)).strftime('%Y-%m-%d')
        raw = series.get(key)
        if raw is not None:
            return round(float(raw), 6)
    return None
def _build_day_assets(self, record: dict, date: pd.Timestamp,
                      entry_info: Dict[str, dict]) -> dict:
    """Assemble the V2-compatible per-asset snapshot for one day.

    Args:
        record: One entry of ``daily_records``. Reads ``holdings`` and the
            optional ``factors`` (code -> momentum), ``threshold`` and
            ``position_weights`` keys.
        date: The day being described.
        entry_info: Open-position bookkeeping keyed by signal code. Each value
            holds ``entry_date``, ``entry_price_etf`` and ``entry_price_idx``.

    Returns:
        A dict keyed by every code in ``self.signal_codes``. Each value holds
        momentum, rank (1 = highest momentum among ``factors``), threshold
        status, index/ETF closes, daily returns, premium and holding fields.
        Holding fields are filled only for held codes that have a truthy
        ``entry_info`` record.
    """
    momentum_by_code = record.get('factors', {})
    held = set(record['holdings'])
    threshold = record.get('threshold', 0.0)
    weights = record.get('position_weights', {})

    ranking = sorted(momentum_by_code, key=momentum_by_code.__getitem__, reverse=True)
    rank_of = {code: pos for pos, code in enumerate(ranking, start=1)}

    def as_float(value):
        return None if value is None else float(value)

    def relative_change(current, base):
        # Cumulative return since entry; None when either side is unusable.
        if base and base > 0 and current is not None and not pd.isna(current):
            return round((current - base) / base, 6)
        return None

    assets = {}
    for code in self.signal_codes:
        momentum = momentum_by_code.get(code)
        # A falsy threshold (e.g. 0.0) means any code with a momentum counts.
        if threshold:
            above = momentum is not None and momentum >= threshold
        else:
            above = momentum is not None
        is_held = code in held
        trade_code = self.signal_to_trade.get(code, code)

        idx_close = self._get_index_close(code, date)
        etf_close = self._get_etf_close(trade_code, date)
        idx_ret, etf_ret_ctc = self._get_daily_returns(code, date)
        premium = self._compute_premium(trade_code, date)

        position = entry_info.get(code) if is_held else None
        entry_date = entry_price_etf = entry_price_idx = None
        holding_days = 0
        cum_etf = cum_idx = None
        if position:
            entry_date = position['entry_date']
            entry_price_etf = position['entry_price_etf']
            entry_price_idx = position['entry_price_idx']
            # Calendar days since entry; an entry made today counts as 1 day.
            holding_days = max(int((date - pd.Timestamp(position['entry_date'])).days), 1)
            cum_etf = relative_change(etf_close, position.get('entry_price_etf'))
            cum_idx = relative_change(idx_close, position.get('entry_price_idx'))

        assets[code] = {
            'momentum': as_float(momentum),
            'rank': rank_of.get(code),
            'threshold': float(threshold),
            'above_threshold': bool(above),
            'index_close': as_float(idx_close),
            'etf_close': as_float(etf_close),
            'index_return': as_float(idx_ret),
            'etf_return_ctc': as_float(etf_ret_ctc),
            'premium': as_float(premium),
            'is_held': bool(is_held),
            'entry_date': entry_date,
            'entry_price_etf': as_float(entry_price_etf),
            'entry_price_idx': as_float(entry_price_idx),
            'holding_days': holding_days,
            'cum_return_etf': as_float(cum_etf),
            'cum_return_idx': as_float(cum_idx),
            'weight': float(weights.get(code, 0.0)) if is_held else None,
        }
    return assets
# ============================================================
# Export
# ============================================================
def export_results(self, output_dir: str = None, detail: bool = True):
    """Write the backtest outputs to disk (V2-compatible detail format).

    Files written under ``output_dir``:
        * ``simple_rotation_nav.csv``: date, nav, daily_return
        * ``simple_rotation_signals.csv``: date, holdings, is_rebalance, added, removed
        * ``simple_rotation_detail.json``: per-day detail for backtest_viewer
          (only when ``detail`` is True)
        * ``simple_rotation_metrics.json``: summary from ``_compute_metrics``

    Args:
        output_dir: Destination directory. Defaults to a ``results`` folder
            next to this module and is created if missing.
        detail: Build the detail JSON. It is large and slow to build, so
            daily runs pass False.
    """
    records = self.daily_records
    if not records:
        print(" x No results to export")
        return

    out_dir = Path(__file__).parent / 'results' if output_dir is None else Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    table = pd.DataFrame(records)

    nav_path = out_dir / 'simple_rotation_nav.csv'
    table[['date', 'nav', 'daily_return']].to_csv(nav_path, index=False)
    print(f" + NAV: {nav_path}")

    sig_path = out_dir / 'simple_rotation_signals.csv'
    table[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(sig_path, index=False)
    print(f" + Signals: {sig_path}")

    if detail:
        detail_path = out_dir / 'simple_rotation_detail.json'
        # Open positions, rebuilt day by day so each asset row can report
        # its entry date/prices.
        open_positions: Dict[str, dict] = {}
        held_before = set()
        days_out = []
        for rec in records:
            date = pd.Timestamp(rec['date'])
            # Signal day is the previous *calendar* day, not the previous A-share
            # trading day; the on-or-before close lookup resolves the gap.
            signal_date = date - timedelta(days=1)
            holdings = rec['holdings']
            held_now = set(holdings)

            # Entry bookkeeping mirrors run(): buy at T's open; the index
            # reference is the close seen at T-1 (the signal data).
            for code in held_now - held_before:
                prices = self._get_etf_prices(self.signal_to_trade.get(code, code), date)
                open_positions[code] = {
                    'entry_date': date.strftime('%Y-%m-%d'),
                    'entry_price_etf': prices['open'] if prices else None,
                    'entry_price_idx': self._get_index_close(code, signal_date),
                }
            for code in held_before - held_now:
                open_positions.pop(code, None)

            assets = self._build_day_assets(rec, date, open_positions)
            days_out.append({
                'date': rec['date'],
                'nav': rec['nav'],
                'daily_return': rec['daily_return'],
                'is_rebalance': rec['is_rebalance'],
                'signals': dict.fromkeys(holdings, 1),
                'holdings': holdings,
                'added': rec['added'],
                'removed': rec['removed'],
                'assets': assets,
            })
            held_before = held_now

        detail_data = {
            'meta': {
                'mode': 'Simple: Daily Iteration',
                'start_date': self.config.backtest.start_date,
                # records is non-empty here (early return above)
                'end_date': records[-1]['date'],
                'total_days': len(records),
                'select_num': self.select_num,
                'n_days': self.n_days,
                'trade_cost': self.trade_cost,
                'bond_threshold': {
                    'enabled': self.use_dynamic_threshold,
                    'bond_code': self.bond_code,
                    'ratio': self.bond_ratio,
                },
                'codes': self._build_meta_codes(),
            },
            'days': days_out,
        }
        _sanitize_json(detail_data)
        with open(detail_path, 'w', encoding='utf-8') as fh:
            json.dump(detail_data, fh, ensure_ascii=False, indent=2)
        print(f" + Detail: {detail_path} ({len(days_out)} days)")

    rebalance_count = sum(1 for r in records if r['is_rebalance'])
    metrics = self._compute_metrics(rebalance_count)
    metrics_path = out_dir / 'simple_rotation_metrics.json'
    _sanitize_json(metrics)
    with open(metrics_path, 'w', encoding='utf-8') as fh:
        json.dump(metrics, fh, ensure_ascii=False, indent=2)
    print(f" + Metrics: {metrics_path}")
# ============================================================
# Report Generation (PNG chart with tables)
# ============================================================
def generate_report(self, output_dir: str = None):
    """Render the performance report PNG (``simple_rotation_report.png``).

    The figure stacks six panels:
        0. a ranking/position table for the latest record (held, exited,
           not selected);
        1. strategy-vs-benchmark metrics;
        2. a year x month return matrix;
        3. NAV curves (strategy, benchmark, individual indices; log scale);
        4. strategy drawdown;
        5. a daily holdings map.

    Args:
        output_dir: Destination directory. Defaults to a ``results`` folder
            next to this module and is created if missing.

    Notes:
        Benchmark and per-asset curves are scaled by their first valid close.
        Their returns therefore cover the same span as the strategy's
        ``nav[-1] / nav[0] - 1``. The first month of the monthly matrix is
        measured from the starting NAV, so the monthly figures compound to
        the reported cumulative return.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    if not self.daily_records:
        print(" x No results for report")
        return
    if output_dir is None:
        output_dir = Path(__file__).parent / 'results'
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Time series: strategy, benchmark and per-asset NAVs
    # ------------------------------------------------------------------
    df = pd.DataFrame(self.daily_records)
    df['date'] = pd.to_datetime(df['date'])
    df = df.set_index('date')
    strategy_nav = df['nav']
    strategy_ret = df['daily_return']

    def rebase(close):
        # Scale a close series so its first valid value is 1.0; None if it has
        # no valid value. Dividing by the first close keeps the first day's
        # move. cumprod(1 + pct_change) started one row late, because the
        # first pct_change is NaN.
        valid = close.dropna()
        if valid.empty:
            return None
        return close / valid.iloc[0]

    benchmark_nav = None
    if self.benchmark_data is not None:
        benchmark_nav = rebase(self.benchmark_data['close'].reindex(df.index, method='ffill'))

    asset_navs = {}
    for code in self.signal_codes:
        if code in self.index_data:
            nav_s = rebase(self.index_data[code]['close'].reindex(df.index, method='ffill'))
            if nav_s is not None:
                asset_navs[code] = nav_s

    # ------------------------------------------------------------------
    # Performance metrics (252 trading days per year)
    # ------------------------------------------------------------------
    s_total_return = strategy_nav.iloc[-1] / strategy_nav.iloc[0] - 1
    n_days = len(df)
    s_annual = (1 + s_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
    s_sharpe = strategy_ret.mean() / strategy_ret.std() * np.sqrt(252) if strategy_ret.std() > 0 else 0
    s_peak = strategy_nav.cummax()
    drawdown = (strategy_nav - s_peak) / s_peak
    s_dd = drawdown.min()
    s_calmar = s_annual / abs(s_dd) if s_dd != 0 else 0
    non_zero = strategy_ret[strategy_ret != 0]
    s_win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0

    b_total_return = b_annual = b_sharpe = b_dd = 0
    if benchmark_nav is not None:
        bm_ret = benchmark_nav.pct_change(fill_method=None)
        b_total_return = benchmark_nav.iloc[-1] - 1
        b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
        b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0
        b_peak = benchmark_nav.cummax()
        b_dd = ((benchmark_nav - b_peak) / b_peak).min()

    # ------------------------------------------------------------------
    # Latest-day position tables
    # ------------------------------------------------------------------
    last_idx = len(self.daily_records) - 1
    last_rec = self.daily_records[last_idx]
    last_date = pd.Timestamp(last_rec['date'])
    holdings = last_rec['holdings']
    factors = last_rec.get('factors', {})
    position_weights = last_rec.get('position_weights', {})
    is_rebalance_day = last_rec.get('is_rebalance', False)
    prev_rec = self.daily_records[last_idx - 1] if last_idx > 0 else None
    prev_holdings = set(prev_rec['holdings']) if prev_rec is not None else set()
    exited_codes = sorted(prev_holdings - set(holdings))

    def has_bar_on_last_date(code):
        etf_df = self.etf_data.get(self.signal_to_trade.get(code, code))
        return etf_df is not None and len(etf_df) > 0 and etf_df.index[-1] >= last_date

    # The market counts as open for last_date once an ETF we hold, or have
    # just sold, has a bar dated last_date. Including the sold codes keeps
    # exit pricing right when the book rotates fully out.
    market_opened = any(has_bar_on_last_date(c) for c in [*holdings, *exited_codes])

    # Display config keyed by signal code
    code_config = {}
    for name, asset in self.config.asset_pools.assets.items():
        code_config[asset.signal_source] = {
            'name': asset.name if hasattr(asset, 'name') else name,
            'etf': asset.trade_source,
            'market': asset.group,
        }

    # Held positions, highest momentum first
    sorted_holdings = sorted(holdings, key=lambda c: factors.get(c, 0) or 0, reverse=True)
    positions_info = []
    for code in sorted_holdings:
        cfg = code_config.get(code, {})
        trade_code = self.signal_to_trade.get(code, code)
        idx_close = self._get_index_close(code, last_date)
        etf_close = self._get_etf_close(trade_code, last_date)
        premium = self._get_latest_premium(trade_code, last_date)
        # "调入" = newly bought on a rebalance day, "维持" = carried over
        is_new_entry = is_rebalance_day and code not in prev_holdings

        entry_date = None
        entry_price = None
        holding_days = 0
        pnl = None
        # A new entry on a day the market has not opened yet has no fill data.
        if not (is_new_entry and not market_opened):
            # Walk back through the continuous holding streak. The earliest
            # day with price data is the entry, bought at that day's open.
            for rec in reversed(self.daily_records):
                if code not in rec['holdings']:
                    break
                rec_date = pd.Timestamp(rec['date'])
                p = self._get_etf_prices(trade_code, rec_date)
                if p is not None:
                    entry_date = rec_date
                    entry_price = p['open']
            if entry_date is not None:
                holding_days = (last_date - entry_date).days
            # P&L is measured against the latest available close
            if entry_price and entry_price > 0 and etf_close and etf_close > 0:
                pnl = etf_close / entry_price - 1

        positions_info.append({
            'name': cfg.get('name', code), 'code': code, 'etf': cfg.get('etf', ''),
            'weight': position_weights.get(code, 1.0 / len(holdings)), 'score': factors.get(code),
            'idx_close': idx_close, 'etf_close': etf_close,
            'premium': premium, 'action': '调入' if is_new_entry else '维持',
            'entry_date': entry_date, 'entry_price': entry_price,
            'holding_days': holding_days, 'pnl': pnl,
            'exit_date': None, 'exit_price': None,
        })

    # Exited positions: only reported when the last day is a rebalance day
    exit_positions = []
    if is_rebalance_day and prev_rec is not None:
        prev_weights = prev_rec.get('position_weights', {})
        for code in exited_codes:
            cfg = code_config.get(code, {})
            trade_code = self.signal_to_trade.get(code, code)
            idx_close = self._get_index_close(code, last_date)
            etf_close = self._get_etf_close(trade_code, last_date)
            premium = self._get_latest_premium(trade_code, last_date)

            # Recover the holding streak that ended yesterday
            exit_entry_date = None
            exit_entry_price = None
            exit_holding_days = 0
            exit_pnl = None
            for rec in reversed(self.daily_records[:last_idx]):
                if code not in rec['holdings']:
                    break
                exit_entry_date = pd.Timestamp(rec['date'])
                p = self._get_etf_prices(trade_code, exit_entry_date)
                exit_entry_price = p['open'] if p else None
            if exit_entry_date is not None:
                exit_holding_days = (last_date - exit_entry_date).days

            # Sell at today's open once the market has opened, else use prev_close
            exit_prices = self._get_etf_prices(trade_code, last_date)
            if market_opened and exit_prices:
                sell_price = exit_prices['open']
            elif exit_prices:
                sell_price = exit_prices['prev_close']
            else:
                sell_price = None
            if exit_entry_price and exit_entry_price > 0 and sell_price and sell_price > 0:
                exit_pnl = sell_price / exit_entry_price - 1

            exit_positions.append({
                'name': cfg.get('name', code), 'code': code, 'etf': cfg.get('etf', ''),
                'weight': prev_weights.get(code, 0.0), 'score': factors.get(code),
                'idx_close': idx_close, 'etf_close': etf_close,
                'premium': premium, 'action': '调出',
                'entry_date': exit_entry_date, 'entry_price': exit_entry_price,
                'holding_days': exit_holding_days, 'pnl': exit_pnl,
                'exit_date': last_date, 'exit_price': sell_price,
            })

    # Not selected: every signal code neither held nor exited, by momentum
    held_or_exited = set(holdings) | {p['code'] for p in exit_positions}
    unselected_codes = [c for c in self.signal_codes if c not in held_or_exited]
    unselected_codes.sort(key=lambda c: factors.get(c, -999), reverse=True)
    unselected_positions = []
    for code in unselected_codes:
        cfg = code_config.get(code, {})
        trade_code = self.signal_to_trade.get(code, code)
        unselected_positions.append({
            'name': cfg.get('name', code), 'code': code, 'etf': cfg.get('etf', ''),
            'weight': 0, 'score': factors.get(code),
            'idx_close': self._get_index_close(code, last_date),
            'etf_close': self._get_etf_close(trade_code, last_date),
            'premium': self._get_latest_premium(trade_code, last_date), 'action': '未入选',
            'entry_date': None, 'entry_price': None,
            'holding_days': 0, 'pnl': None,
            'exit_date': None, 'exit_price': None,
        })

    # ------------------------------------------------------------------
    # Figure
    # ------------------------------------------------------------------
    plt.rcParams["font.sans-serif"] = ["Arial Unicode MS", "WenQuanYi Zen Hei", "DejaVu Sans"]
    plt.rcParams["axes.unicode_minus"] = False
    # Ordered: entered/maintained -> exited -> not selected
    all_rows = positions_info + exit_positions + unselected_positions
    signal_h = max(1.5, 0.5 + len(all_rows) * 0.35)
    fig = plt.figure(figsize=(16, 10 + signal_h + 1.2 + 1.5 + 8))
    gs = fig.add_gridspec(6, 1, height_ratios=[signal_h, 1.2, 1.5, 3, 1, 1.2], hspace=0.35)

    # Panel 0: full asset ranking table
    ax0 = fig.add_subplot(gs[0])
    ax0.axis("off")
    threshold_val = last_rec.get('threshold', 0.0)
    ax0.set_title(f"全标的动量排名 (信号日期: {last_date.strftime('%Y-%m-%d')},阈值: {threshold_val:.4f})",
                  fontsize=14, fontweight="bold", loc="left", pad=15)
    ranked_all = sorted(factors.keys(), key=lambda c: factors[c], reverse=True)
    rank_map = {c: i + 1 for i, c in enumerate(ranked_all)}
    rank_cols = ["排名", "标的名称", "市场", "指数代码", "ETF代码", "仓位", "得分",
                 "指数最新价", "ETF收盘价", "溢价率", "状态",
                 "进场日期", "持有天数", "盈亏", "退场日期", "退场价格"]
    rank_rows = []
    row_actions = []  # action per row, used for coloring
    for p in all_rows:
        rank_rows.append([
            rank_map.get(p['code'], ''),
            p['name'],
            code_config.get(p['code'], {}).get('market', ''),
            p['code'],
            p['etf'],
            f"{p['weight']:.0%}" if p['weight'] > 0 else "",
            f"{p['score']:.4f}" if p['score'] is not None else "",
            f"{p['idx_close']:.2f}" if p['idx_close'] is not None else "",
            f"{p['etf_close']:.3f}" if p['etf_close'] is not None else "",
            f"{p['premium']:+.2%}" if p['premium'] is not None else "",
            p['action'],
            p['entry_date'].strftime('%Y-%m-%d') if p['entry_date'] else "",
            str(p['holding_days']) if p['holding_days'] > 0 else "",
            f"{p['pnl']:+.2%}" if p['pnl'] is not None else "",
            p['exit_date'].strftime('%Y-%m-%d') if p.get('exit_date') else "",
            f"{p['exit_price']:.3f}" if p.get('exit_price') else "",
        ])
        row_actions.append(p['action'])
    if rank_rows:
        tbl = ax0.table(cellText=rank_rows, colLabels=rank_cols, loc="center",
                        cellLoc="center", bbox=[0, 0, 1, 1])
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(8)
        tbl.scale(1, 1.6)
        for j in range(len(rank_cols)):
            tbl[0, j].set_facecolor("#2C3E50")
            tbl[0, j].set_text_props(color="white", fontweight="bold")
        action_colors = {
            "调入": "#d4edda",  # light green: entered
            "调出": "#f8d7da",  # light red: exited
            "维持": "#fff3cd",  # light yellow: maintained
            "未入选": "#f0f0f0",  # light grey: not selected
        }
        for i, action in enumerate(row_actions):
            color = action_colors.get(action, "#ffffff")
            for j in range(len(rank_cols)):
                tbl[i + 1, j].set_facecolor(color)

    # Panel 1: metrics table
    ax1 = fig.add_subplot(gs[1])
    ax1.axis("off")
    ax1.set_title("策略绩效对比", fontsize=14, fontweight="bold", loc="left", pad=10)
    start_s = df.index.min().strftime('%Y-%m-%d')
    end_s = df.index.max().strftime('%Y-%m-%d')
    perf_cols = ["策略", "开始时间", "结束时间", "累计收益", "年化收益", "最大回撤", "夏普比率", "Calmar比率", "日胜率"]
    strat_row = ["轮动策略", start_s, end_s,
                 f"{s_total_return:.2%}", f"{s_annual:.2%}", f"{s_dd:.2%}",
                 f"{s_sharpe:.2f}", f"{s_calmar:.2f}", f"{s_win_rate:.2%}"]
    bench_row = [f"基准({self.benchmark_name})", start_s, end_s,
                 f"{b_total_return:.2%}", f"{b_annual:.2%}", f"{b_dd:.2%}",
                 f"{b_sharpe:.2f}", "", ""]
    ptbl = ax1.table(cellText=[strat_row, bench_row], colLabels=perf_cols,
                     loc="center", cellLoc="center", bbox=[0, 0, 1, 1])
    ptbl.auto_set_font_size(False)
    ptbl.set_fontsize(10)
    ptbl.scale(1, 1.8)
    for j in range(len(perf_cols)):
        ptbl[0, j].set_facecolor("#2C3E50")
        ptbl[0, j].set_text_props(color="white", fontweight="bold")
        ptbl[1, j].set_facecolor("#d4edda")
        ptbl[2, j].set_facecolor("#cce5ff")

    # Panel 2: monthly returns matrix (year x month, in percent)
    ax2 = fig.add_subplot(gs[2])
    ax2.axis('off')
    ax2.set_title('月度收益矩阵 (%)', fontsize=14, fontweight='bold', loc='left', pad=10)
    monthly_nav = strategy_nav.resample('ME').last().dropna()
    monthly_ret = monthly_nav.pct_change(fill_method=None)
    if len(monthly_ret) > 0:
        # pct_change leaves the first month NaN. Measure it from the starting
        # NAV so the first (usually partial) month is neither missing from the
        # matrix nor from that year's annual total.
        monthly_ret.iloc[0] = monthly_nav.iloc[0] / strategy_nav.iloc[0] - 1
    monthly_ret_pct = (monthly_ret.dropna() * 100).round(2)
    ret_by_month = {(ts.year, ts.month): float(v) for ts, v in monthly_ret_pct.items()}
    years = sorted({yr for yr, _ in ret_by_month})
    month_cols = [f'{m}' for m in range(1, 13)] + ['年度']
    month_rows = []
    for yr in years:
        row = [ret_by_month.get((yr, m)) for m in range(1, 13)]
        observed = [v for v in row if v is not None]
        yr_total = 0.0
        for v in observed:
            # Compound monthly percentages into the annual figure
            yr_total = (1 + yr_total / 100) * (1 + v / 100) * 100 - 100
        row.append(round(yr_total, 2) if observed else None)
        month_rows.append(row)
    if month_rows:
        cell_text = [[f'{v:+.1f}' if v is not None else '' for v in row] for row in month_rows]
        tbl = ax2.table(cellText=cell_text, rowLabels=[str(y) for y in years],
                        colLabels=month_cols, loc='center', cellLoc='center', bbox=[0, 0, 1, 1])
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(9)
        tbl.scale(1, 1.6)
        for j in range(len(month_cols)):
            tbl[0, j].set_facecolor('#2C3E50')
            tbl[0, j].set_text_props(color='white', fontweight='bold')
        # Red for gains, green for losses; the sqrt mapping keeps small moves visible
        max_abs = max(abs(v) for row in month_rows for v in row if v is not None) or 1
        for i, row in enumerate(month_rows):
            for j, v in enumerate(row):
                if v is not None:
                    intensity = min((abs(v) / max_abs) ** 0.5, 1.0)
                    if v >= 0:
                        r, g, b = 1.0, 1.0 - intensity * 0.55, 1.0 - intensity * 0.55
                    else:
                        r, g, b = 1.0 - intensity * 0.55, 1.0, 1.0 - intensity * 0.55
                    tbl[i + 1, j].set_facecolor((r, g, b))
                else:
                    tbl[i + 1, j].set_facecolor('#f5f5f5')
            # Annual column: stronger coloring
            ann_v = row[-1]
            if ann_v is not None:
                intensity = min((abs(ann_v) / max_abs) ** 0.5, 1.0)
                if ann_v >= 0:
                    r, g, b = 1.0, 1.0 - intensity * 0.7, 1.0 - intensity * 0.7
                else:
                    r, g, b = 1.0 - intensity * 0.7, 1.0, 1.0 - intensity * 0.7
                tbl[i + 1, len(month_cols) - 1].set_facecolor((r, g, b))

    # Panel 3: NAV curves (first 10 assets get legend entries)
    ax3 = fig.add_subplot(gs[3])
    ax3.plot(strategy_nav.index, strategy_nav.values,
             label="轮动策略", linewidth=2, color="#E74C3C")
    if benchmark_nav is not None:
        ax3.plot(benchmark_nav.index, benchmark_nav.values,
                 label=self.benchmark_name, linewidth=1.5, color="#3498DB", alpha=0.8)
    colors = plt.cm.tab20.colors
    for i, code in enumerate(self.signal_codes):
        if code in asset_navs:
            lbl = code_config.get(code, {}).get('name', code) if i < 10 else None
            ax3.plot(asset_navs[code].index, asset_navs[code].values,
                     label=lbl, linewidth=0.8, alpha=0.4,
                     color=colors[i % len(colors)])
    ax3.set_title("ETF轮动策略 - 净值曲线", fontsize=16, fontweight="bold")
    ax3.set_ylabel("净值")
    ax3.legend(loc="upper left", fontsize=8, ncol=2)
    ax3.grid(True, alpha=0.3)
    ax3.set_yscale("log")

    # Panel 4: drawdown
    ax4 = fig.add_subplot(gs[4])
    ax4.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C")
    ax4.set_title("策略回撤", fontsize=12)
    ax4.set_ylabel("回撤")
    ax4.grid(True, alpha=0.3)

    # Panel 5: holdings distribution, one horizontal band per signal code
    ax5 = fig.add_subplot(gs[5])
    holdings_series = df['holdings']
    for i, code in enumerate(self.signal_codes):
        mask = holdings_series.apply(lambda h, c=code: c in h)
        if mask.any():
            ax5.fill_between(mask.index, i, i + 0.8,
                             where=mask, alpha=0.7,
                             color=colors[i % len(colors)],
                             label=code_config.get(code, {}).get('name', code))
    ylabels = [code_config.get(c, {}).get('name', c) for c in self.signal_codes]
    ax5.set_title("每日持仓分布", fontsize=12)
    # Each band spans [i, i + 0.8]; centre the tick label on it.
    ax5.set_yticks([i + 0.4 for i in range(len(ylabels))])
    ax5.set_yticklabels(ylabels, fontsize=7)
    ax5.grid(True, alpha=0.3)

    chart_path = output_dir / 'simple_rotation_report.png'
    fig.savefig(str(chart_path), dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" + Report: {chart_path}")
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    import argparse

    # Default API endpoint, used only when the caller has not configured one.
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')

    cli = argparse.ArgumentParser(description='Simple Rotation Strategy Backtest')
    cli.add_argument('--no-detail', action='store_true',
                     help='Skip detail JSON export (faster, for daily runs)')
    cli.add_argument('--no-report', action='store_true',
                     help='Skip report PNG generation')
    opts = cli.parse_args()

    strategy = SimpleRotationStrategy()
    if strategy.run():
        strategy.export_results(detail=not opts.no_detail)
        if not opts.no_report:
            strategy.generate_report()