Files
etf/rotation/simple_rotation.py
aszerW b564a47a1b feat: 新增slope_r2因子并切换为默认因子(年化19.84%, 夏普1.14)
- simple_rotation.py: 新增3种score函数(vol_adjusted_momentum, slope_r2, momentum)
- config_loader.py: FactorType枚举新增VOL_ADJUSTED_MOMENTUM
- config_simple.yaml: factor.type 切换为 slope_r2
- experiments/factor_comparison.py: 4种因子对比实验脚本
- experiments/output: 实验结果(slope_r2全面胜出)
2026-06-06 15:49:22 +08:00

1321 lines
57 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Simple Rotation Strategy (Daily Iteration)
From A-share trading calendar, iterate daily:
1. Get signal source last N days -> compute momentum
2. Get bond momentum (dynamic threshold)
3. Group selection -> generate holdings
4. Compare with yesterday -> compute T+1 return
5. Update NAV
"""
import os
import sys
import math
import json
import time
import requests
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType
# ============================================================
# HTTP client (requests with trust_env=False, bypassing system proxies to avoid SSL EOF)
# ============================================================
# Proxies such as Clash can trigger SSL EOF errors when handling TLS 1.3 with
# post-quantum key exchange.
# trust_env=False makes requests ignore proxy settings taken from environment
# variables, so requests go directly to the target server.
# This single module-level session is what _http_get uses for every GET.
_session = requests.Session()
_session.trust_env = False
def _http_get(url: str, params: dict = None, timeout: int = 120) -> requests.Response:
    """Send a GET request through the shared proxy-bypassing session.

    The module-level ``_session`` has ``trust_env=False``, so proxy settings
    from environment variables are ignored.

    Args:
        url: Absolute URL to request.
        params: Optional query-string parameters.
        timeout: Request timeout in seconds.

    Returns:
        The raw ``requests.Response``; the status code is not checked here.
    """
    session = _session
    response = session.get(url, timeout=timeout, params=params)
    return response
def _sanitize_json(obj):
"""Recursively replace NaN/Inf with None in-place so json.dump produces valid JSON"""
if isinstance(obj, dict):
for k, v in obj.items():
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
obj[k] = None
elif isinstance(v, (dict, list)):
_sanitize_json(v)
elif isinstance(obj, list):
for i, v in enumerate(obj):
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
obj[i] = None
elif isinstance(v, (dict, list)):
_sanitize_json(v)
# ============================================================
# Pure functions: momentum
# ============================================================
def weighted_momentum_score(prices: np.ndarray) -> float:
    """Score a price window as annualized log-linear trend return times R^2.

    A line is fitted to log prices with weights rising linearly from 1 (oldest)
    to 2 (newest). The per-day slope is annualized over 250 days
    (``exp(250 * slope) - 1``) and multiplied by the weighted R^2 of the fit.

    Returns 0.0 for windows shorter than 5 points or when any log price is
    NaN/inf. Prices are floored at 0.01 before taking logs.

    NOTE: ``np.polyfit`` applies ``w`` to the *unsquared* residuals, so the fit
    effectively weights squared residuals by ``w**2``, while the R^2 below uses
    ``w``. This is kept as-is so existing backtest results stay reproducible.
    """
    if len(prices) < 5:
        return 0.0
    log_px = np.log(np.clip(prices, 0.01, None))
    if not np.all(np.isfinite(log_px)):
        return 0.0

    t = np.arange(len(log_px))
    w = np.linspace(1, 2, len(log_px))
    beta, alpha = np.polyfit(t, log_px, 1, w=w)

    fitted = beta * t + alpha
    resid_ss = np.sum(w * (log_px - fitted) ** 2)
    total_ss = np.sum(w * (log_px - np.average(log_px, weights=w)) ** 2)
    r_squared = 1 - resid_ss / total_ss if total_ss > 0 else 0

    yearly_growth = math.exp(beta * 250) - 1
    return yearly_growth * r_squared
def vol_adjusted_momentum_score(prices: np.ndarray) -> float:
    """Volatility-adjusted momentum: (annualized return / annualized vol) * R^2.

    Follows the Moskowitz, Ooi & Pedersen (2012) TSMOM idea of dividing
    momentum by realized volatility so scores are comparable across assets.

    * Annualized return and R^2 come from the same weighted log-linear fit as
      ``weighted_momentum_score`` (weights 1 -> 2, 250-day annualization).
    * Volatility is the population std of daily log returns times sqrt(250),
      floored at 0.01 to avoid dividing by ~0.

    Returns 0.0 for fewer than 5 prices or any non-finite log price.
    """
    if len(prices) < 5:
        return 0.0
    log_px = np.log(np.clip(prices, 0.01, None))
    if not np.all(np.isfinite(log_px)):
        return 0.0

    t = np.arange(len(log_px))
    w = np.linspace(1, 2, len(log_px))
    beta, alpha = np.polyfit(t, log_px, 1, w=w)
    growth = math.exp(beta * 250) - 1

    # Realized (annualized) volatility with a floor against near-zero vol.
    vol = max(float(np.std(np.diff(log_px))) * math.sqrt(250), 0.01)

    # Trend quality of the weighted fit.
    fitted = beta * t + alpha
    resid_ss = np.sum(w * (log_px - fitted) ** 2)
    total_ss = np.sum(w * (log_px - np.average(log_px, weights=w)) ** 2)
    r_squared = 1 - resid_ss / total_ss if total_ss > 0 else 0

    return (growth / vol) * r_squared
def slope_r2_score(prices: np.ndarray) -> float:
    """Linear-trend score on normalized prices: ``10000 * slope * R^2``.

    Prices (floored at 0.01) are divided by the first price, an unweighted
    line is fitted, and the per-step slope (in units of the starting price) is
    scaled by 10000 (i.e. basis points per step) and multiplied by the
    fit's R^2.

    Returns 0.0 when fewer than 5 prices are given or any price is NaN/inf.
    This matches the guard in ``weighted_momentum_score``; without it a single
    missing close makes ``np.polyfit`` fail or return NaN.
    """
    if len(prices) < 5:
        return 0.0
    prices = np.clip(prices, 0.01, None)
    # np.clip propagates NaN and leaves +inf (no upper bound), so check afterwards.
    if not np.all(np.isfinite(prices)):
        return 0.0
    y = prices / prices[0]  # normalize to the first price
    x = np.arange(len(y))
    slope, intercept = np.polyfit(x, y, 1)
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
    return float(10000 * slope * r2)
def momentum_score(prices: np.ndarray) -> float:
    """Simple window return: ``last / first - 1`` (prices floored at 0.01).

    Returns 0.0 when fewer than 5 prices are given or when either endpoint is
    NaN/inf. This is consistent with ``weighted_momentum_score``; a NaN score
    would otherwise make the downstream ``max()``/sort selection
    order-dependent.
    """
    if len(prices) < 5:
        return 0.0
    prices = np.clip(prices, 0.01, None)
    first = float(prices[0])
    last = float(prices[-1])
    if not (math.isfinite(first) and math.isfinite(last)):
        return 0.0
    return last / first - 1
def is_crash(prices: np.ndarray) -> bool:
    """Crash filter over the last four prices (three daily moves).

    Returns truthy when either:
      * any one of the last three daily price ratios is below 0.95
        (a single-day drop of more than 5%), or
      * all three days closed lower and the cumulative 3-day ratio is
        below 0.95.
    Fewer than 4 prices never count as a crash.
    """
    if len(prices) < 4:
        return False
    p0, p1, p2, p3 = prices[-4:]
    # Most recent move first, matching the original min() argument order
    # (min() is order-sensitive when a NaN is present).
    day_ratios = (p3 / p2, p2 / p1, p1 / p0)
    single_day_drop = min(day_ratios) < 0.95
    three_day_slide = all(r < 1 for r in day_ratios) and p3 / p0 < 0.95
    return single_day_drop or three_day_slide
# ============================================================
# Data cache
# ============================================================
class DataCache:
    """Data fetching client (no local cache, always from API).

    OHLCV and trading-calendar data are fetched from the HTTP API on every
    call. The only in-memory state is ``premium_data``, which accumulates ETF
    premium series returned by the API.
    """

    def __init__(self, base_url: str, timeout: int = 120):
        self.base_url = base_url.rstrip('/')
        self.api_path = '/api/v1/ohlcv'
        self.timeout = timeout
        # Premium data kept in memory: {trade_code: {date_str: premium_ratio}}
        self.premium_data: Dict[str, Dict[str, float]] = {}

    def preload(self, code: str, start_date: str, end_date: str, adj: str = 'raw') -> Optional[pd.DataFrame]:
        """Fetch ``code`` over [start_date, end_date] from the API (see ``_fetch_api``)."""
        return self._fetch_api(code, start_date, end_date, adj)

    def _fetch_api(self, code: str, start_date: str, end_date: str, adj: str) -> Optional[pd.DataFrame]:
        """Fetch OHLCV from the Flask API; also stores any ``premium_series``.

        Returns a DataFrame indexed by sorted ``date``, restricted to the
        open/high/low/close/volume columns that are present. Returns ``None``
        on an API error, empty data, or after 3 failed attempts.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': adj}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1)
                        continue
                    print(f" x {code}: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x {code}: {data['error']}")
                    return None
                records = data.get('data', [])
                if not records:
                    return None
                df = pd.DataFrame(records)
                if 'date' in df.columns:
                    df['date'] = pd.to_datetime(df['date'])
                    df = df.set_index('date').sort_index()
                keep = [c for c in ['open', 'high', 'low', 'close', 'volume'] if c in df.columns]
                df = df[keep]
                # Extract premium_series (ETF only) and store it in memory
                premium_series = data.get('premium_series', [])
                if premium_series:
                    df.attrs['premium_series'] = {item['date']: item['premium'] for item in premium_series}
                    if code not in self.premium_data:
                        self.premium_data[code] = {}
                    self.premium_data[code].update(df.attrs['premium_series'])
                print(f" + {code}: {len(df)} rows ({adj})")
                return df
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                # Retry on network errors (timeout, SSL, dropped connection)
                if attempt < 2:
                    time.sleep(1 + attempt)  # increasing delay: 1s, 2s
                    continue
                print(f" x {code}: {type(e).__name__} after {attempt+1} retries")
                return None
            except Exception as e:
                print(f" x {code}: {e}")
                return None
        return None

    def preload_premium(self, code: str, end_date: Optional[str] = None) -> Optional[Dict[str, float]]:
        """Return premium data for ``code``, fetching it from the API when needed.

        The in-memory series is returned as-is when no ``end_date`` is given,
        or when its latest date key is >= ``end_date``. Date keys are compared
        as strings, which presumably are ISO 'YYYY-MM-DD' (verify against
        the API). Otherwise the full history from 2000-01-01 is re-fetched.
        Returns ``None`` when no premium data exists for ``code``.
        """
        if code in self.premium_data:
            cached = self.premium_data[code]
            if not end_date:
                return cached
            if cached and max(cached) >= end_date:
                return cached
        self._fetch_premium_api(code, '2000-01-01', end_date or datetime.now().strftime('%Y-%m-%d'))
        return self.premium_data.get(code)

    def _fetch_premium_api(self, code: str, start_date: str, end_date: str):
        """Fetch ``premium_series`` from the API and merge it into ``premium_data``.

        Best-effort: failures never raise. They are reported the same way
        ``_fetch_api`` reports them instead of being silently swallowed. A
        response without a premium series (e.g. non-ETF codes) is not an
        error and is ignored quietly.
        """
        url = f"{self.base_url}{self.api_path}"
        params = {'code': code, 'start': start_date, 'end': end_date, 'adj': 'raw'}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1 + attempt)
                        continue
                    print(f" x premium {code}: HTTP {resp.status_code}")
                    return
                data = resp.json()
                if 'error' in data:
                    print(f" x premium {code}: {data['error']}")
                    return
                premium_series = data.get('premium_series', [])
                if premium_series:
                    new_data = {item['date']: item['premium'] for item in premium_series}
                    if code not in self.premium_data:
                        self.premium_data[code] = {}
                    self.premium_data[code].update(new_data)
                    print(f" + premium {code}: +{len(new_data)} days (total {len(self.premium_data[code])})")
                return
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x premium {code}: {type(e).__name__} after {attempt + 1} attempts")
                return
            except Exception as e:
                # Deliberately best-effort: report and give up rather than crash the backtest.
                print(f" x premium {code}: {e}")
                return

    def get_trading_calendar(self, market: str, start_date: str, end_date: str) -> Optional[pd.DatetimeIndex]:
        """Fetch trading dates for ``market`` in [start_date, end_date].

        Returns an empty DatetimeIndex when the API returns no dates, and
        ``None`` on an API error or after 3 failed attempts. Any exception is
        retried here, not only network errors.
        """
        url = f"{self.base_url}/api/v1/trading-calendar"
        params = {'market': market, 'start': start_date, 'end': end_date}
        for attempt in range(3):
            try:
                resp = _http_get(url, params=params, timeout=self.timeout)
                if resp.status_code != 200:
                    if attempt < 2:
                        time.sleep(1 + attempt)
                        continue
                    print(f" x calendar: HTTP {resp.status_code}")
                    return None
                data = resp.json()
                if 'error' in data:
                    print(f" x calendar: {data['error']}")
                    return None
                dates = data.get('trading_dates', [])
                if not dates:
                    return pd.DatetimeIndex([])
                result = pd.DatetimeIndex(dates)
                print(f" + {market}: {len(result)} trading days ({start_date} ~ {end_date})")
                return result
            except (requests.exceptions.Timeout, requests.exceptions.SSLError, requests.exceptions.ConnectionError) as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x calendar: {type(e).__name__} after {attempt+1} retries")
                return None
            except Exception as e:
                if attempt < 2:
                    time.sleep(1 + attempt)
                    continue
                print(f" x calendar: {e}")
                return None
        return None
# ============================================================
# Core strategy class
# ============================================================
class SimpleRotationStrategy:
"""
Simple rotation strategy (daily iteration)
Flow:
1. Load config
2. Preload all historical data (index raw + ETF hfq)
3. Fetch A-share trading calendar
4. For each trading day:
- Compute momentum (last 25 days)
- Group selection -> generate holdings
- Compare with yesterday -> compute T+1 return
- Update NAV
5. Export results
"""
def __init__(self, config_path: str = None):
    """Load the YAML config and derive strategy parameters and code mappings.

    Args:
        config_path: Path to the strategy config; defaults to
            ``config_simple.yaml`` next to this file.
    """
    if config_path is None:
        config_path = Path(__file__).parent / 'config_simple.yaml'
    self.config: RotationStrategyConfig = load_rotation_config(str(config_path))
    cfg = self.config

    # Strategy params
    self.n_days = cfg.factor.n_days
    self.select_num = cfg.rotation.select_num
    self.trade_cost = cfg.rebalance.trade_cost

    # Dynamic threshold: bond momentum acts as the hurdle when enabled
    threshold = cfg.rotation.threshold
    self.use_dynamic_threshold = (threshold.mode.value == 'dynamic')
    dynamic = threshold.dynamic
    if dynamic:
        self.bond_code = dynamic.reference
        self.bond_ratio = dynamic.ratio
        self.fallback_value = dynamic.fallback_value
    else:
        self.bond_code = None
        self.bond_ratio = 1.0
        self.fallback_value = 0.0

    # Signal codes and mappings (signal -> trade code, code -> group)
    self.signal_codes = []
    self.signal_to_trade = {}
    self.code_to_group = {}
    self.trade_code_to_group = {}
    for asset in cfg.asset_pools.assets.values():
        signal, trade, group = asset.signal_source, asset.trade_source, asset.group
        self.signal_codes.append(signal)
        self.signal_to_trade[signal] = trade
        self.code_to_group[signal] = group
        self.trade_code_to_group[trade] = group

    # Data source (only the first configured source is used)
    primary_source = cfg.data.sources[0]
    self.data_cache = DataCache(
        base_url=primary_source.url or 'https://k3s.tokenpluse.xyz',
        timeout=primary_source.timeout,
    )

    # Preloaded data
    self.index_data: Dict[str, pd.DataFrame] = {}
    self.etf_data: Dict[str, pd.DataFrame] = {}
    self.benchmark_data: Optional[pd.DataFrame] = None

    # Benchmark config
    self.benchmark_code = cfg.benchmark.code
    self.benchmark_name = cfg.benchmark.name

    # Results
    self.daily_records: List[dict] = []
    self.trading_calendar: Optional[pd.DatetimeIndex] = None
def _preload_data(self):
    """Fetch every series the backtest needs into memory.

    Loads:
      * signal sources (raw);
      * trade sources (raw for trade codes of BOND-group assets, hfq otherwise);
      * premium series for every trade code;
      * the benchmark (raw).

    Data starts ``2 * n_days`` calendar days before the backtest start so
    the first momentum window can be filled. Codes whose fetch fails are
    simply absent from the resulting dicts.
    """
    backtest_cfg = self.config.backtest
    start_date = backtest_cfg.start_date
    end_date = backtest_cfg.end_date or datetime.now().strftime('%Y-%m-%d')
    warmup = timedelta(days=self.n_days * 2)
    preload_start = (pd.Timestamp(start_date) - warmup).strftime('%Y-%m-%d')
    fetch = self.data_cache.preload

    print("\n[1/4] Preloading signal sources (index raw)...")
    for code in self.signal_codes:
        frame = fetch(code, preload_start, end_date, adj='raw')
        if frame is not None:
            self.index_data[code] = frame
    print(f"\n Signal: {len(self.index_data)}/{len(self.signal_codes)} OK")

    print("\n[2/4] Preloading trade sources (ETF hfq)...")
    trade_codes = set(self.signal_to_trade.values())
    bond_trade_codes = {
        asset.trade_source
        for asset in self.config.asset_pools.assets.values()
        if asset.group == 'BOND'
    }
    for code in trade_codes:
        adj = 'raw' if code in bond_trade_codes else 'hfq'
        frame = fetch(code, preload_start, end_date, adj=adj)
        if frame is not None:
            self.etf_data[code] = frame
    # Premium series for all ETF trade codes (kept in data_cache.premium_data)
    for code in trade_codes:
        self.data_cache.preload_premium(code, end_date=end_date)
    print(f"\n Trade: {len(self.etf_data)}/{len(trade_codes)} OK, premium: {len(self.data_cache.premium_data)} loaded")

    print(f"\n Loading benchmark ({self.benchmark_code})...")
    benchmark = fetch(self.benchmark_code, preload_start, end_date, adj='raw')
    if benchmark is not None:
        self.benchmark_data = benchmark
        print(f" Benchmark: {len(benchmark)} rows")
def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
    """Score ``signal_code`` on ``date`` with the configured factor.

    Uses the last ``n_days`` closes on or before ``date``.

    Returns:
        ``None`` when the code has no data or the window is incomplete;
        ``0.0`` when the crash filter fires; otherwise the score from the
        function selected by ``config.factor.type`` (weighted momentum for
        any type not explicitly listed).
    """
    if signal_code not in self.index_data:
        return None
    frame = self.index_data[signal_code]
    history = frame.loc[frame.index <= date]
    if len(history) < self.n_days:
        return None
    window = history['close'].values[-self.n_days:]
    if len(window) >= 4 and is_crash(window):
        return 0.0

    factor_type = self.config.factor.type
    if factor_type == FactorType.VOL_ADJUSTED_MOMENTUM:
        scorer = vol_adjusted_momentum_score
    elif factor_type == FactorType.SLOPE_R2:
        scorer = slope_r2_score
    elif factor_type == FactorType.MOMENTUM:
        scorer = momentum_score
    else:
        scorer = weighted_momentum_score
    return scorer(window)
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
    """Weighted momentum of ``signal_code`` on ``date``, regardless of factor type.

    Used for threshold comparisons, which are expressed in weighted-momentum
    units. The window, ``None`` cases and crash handling are the same as
    ``_compute_momentum``.
    """
    if signal_code not in self.index_data:
        return None
    frame = self.index_data[signal_code]
    history = frame.loc[frame.index <= date]
    if len(history) < self.n_days:
        return None
    window = history['close'].values[-self.n_days:]
    crashed = len(window) >= 4 and is_crash(window)
    return 0.0 if crashed else weighted_momentum_score(window)
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
    """
    Generate rotation signals (group competition + dynamic threshold + BOND fill).

    Args:
        date: Signal date; only data on or before it is used.

    Returns:
        (holdings, factors, bond_momentum):
          * holdings: sorted list of selected signal codes; ``self.bond_code``
            may appear several times when it fills empty slots.
          * factors: every computed factor score, keyed by signal code.
          * bond_momentum: the bond's weighted momentum used as the hurdle,
            or ``None`` when the dynamic threshold is disabled.

    Logic:
        1. Each group: select Top 1 by factor score. Non-BOND candidates
           must have weighted momentum >= bond_momentum * ratio.
        2. From all group winners: sort by factor score, select Top select_num.
        3. Fill remaining slots with BOND (only if the bond is not already held).
    """
    factors: Dict[str, float] = {}
    for code in self.signal_codes:
        score = self._compute_momentum(code, date)
        if score is not None:
            factors[code] = score
    if not factors:
        return [], {}, None

    # Bond hurdle always uses raw (weighted) momentum.
    bond_momentum = None
    if self.use_dynamic_threshold and self.bond_code:
        bond_momentum = self._compute_raw_momentum(self.bond_code, date)
        if bond_momentum is None:
            bond_momentum = self.fallback_value

    # Values compared against the hurdle must be in the same units
    # (annualized return * R^2). Every factor type that _compute_momentum
    # does NOT score with weighted momentum needs a separate raw computation.
    # Comparing e.g. slope_r2 scores (10000 * slope * R^2) against the bond's
    # weighted momentum would mix scales.
    non_weighted_types = (
        FactorType.VOL_ADJUSTED_MOMENTUM,
        FactorType.SLOPE_R2,
        FactorType.MOMENTUM,
    )
    if bond_momentum is not None and self.config.factor.type in non_weighted_types:
        raw_factors: Dict[str, float] = {}
        for code in self.signal_codes:
            raw = self._compute_raw_momentum(code, date)
            if raw is not None:
                raw_factors[code] = raw
    else:
        raw_factors = factors

    groups = self.config.asset_pools.by_group
    selected_by_group: Dict[str, Tuple[str, float]] = {}
    for group_name, assets in groups.items():
        group_codes = [a.signal_source for a in assets.values()]
        group_factors = {c: factors[c] for c in group_codes if c in factors}
        group_raw = {c: raw_factors[c] for c in group_codes if c in raw_factors}
        if not group_factors:
            continue
        # Hurdle comparison uses weighted momentum; ranking uses the factor score.
        if group_name != 'BOND' and bond_momentum is not None:
            thresh = bond_momentum * self.bond_ratio
            group_factors = {c: s for c, s in group_factors.items()
                             if group_raw.get(c, 0) >= thresh}
            if not group_factors:
                continue
        top_code = max(group_factors, key=group_factors.get)
        selected_by_group[group_name] = (top_code, group_factors[top_code])

    if not selected_by_group:
        return [], factors, bond_momentum

    candidates = list(selected_by_group.values())
    candidates.sort(key=lambda x: x[1], reverse=True)
    final_holdings = [code for code, _ in candidates[:self.select_num]]

    if len(final_holdings) < self.select_num and self.bond_code:
        if self.bond_code not in final_holdings:
            n_slots = self.select_num - len(final_holdings)
            final_holdings.extend([self.bond_code] * n_slots)

    return sorted(final_holdings), factors, bond_momentum
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
"""Get ETF prices on a given date: {open, close, prev_close}
Note: Index data may lack 'open' column; use close as fallback.
"""
if trade_code not in self.etf_data:
return None
df = self.etf_data[trade_code]
mask = df.index <= date
recent = df.loc[mask]
if len(recent) < 2:
return None
today_row = recent.iloc[-1]
prev_row = recent.iloc[-2]
today_date = recent.index[-1]
if (date - today_date).days > 5:
return None
close_raw = today_row.get('close')
prev_close_raw = prev_row.get('close')
if close_raw is None or pd.isna(close_raw) or prev_close_raw is None or pd.isna(prev_close_raw):
return None
close = float(close_raw)
prev_close = float(prev_close_raw)
# Handle missing/invalid open (common for index data like 931862.CSI)
raw_open = today_row.get('open')
open_price = float(raw_open) if raw_open is not None and not pd.isna(raw_open) else close
if pd.isna(open_price) or open_price == 0:
open_price = close
return {
'open': open_price,
'close': close,
'prev_close': prev_close,
}
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
"""
Compute daily return (T+1 execution):
- Hold: close-to-close
- Sell: close-to-open (sold at open)
- Buy: open-to-close (intraday)
"""
if not old_holdings:
if not new_holdings:
return 0.0
weight = 1.0 / len(new_holdings)
ret = 0.0
for code in new_holdings:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
if p and p['open'] > 0:
ret += weight * (p['close'] - p['open']) / p['open']
if is_rebalance:
ret -= self.trade_cost
return ret
old_set = set(old_holdings)
new_set = set(new_holdings)
weight = 1.0 / len(old_holdings)
daily_return = 0.0
for code in old_holdings:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
if p is None or p['prev_close'] == 0:
continue
if code in new_set:
r = (p['close'] - p['prev_close']) / p['prev_close']
else:
r = (p['open'] - p['prev_close']) / p['prev_close']
if not math.isnan(r):
daily_return += weight * r
for code in new_holdings:
if code not in old_set:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
if p and p['open'] > 0 and not math.isnan(p['close']):
r = (p['close'] - p['open']) / p['open']
if not math.isnan(r):
daily_return += weight * r
if is_rebalance:
daily_return -= self.trade_cost
return daily_return
def run(self) -> dict:
    """Main backtest loop.

    Steps:
      1. Preload data.
      2. Fetch the A-share trading calendar for [start_date, end_date].
      3. For every trading day T:
         * generate signals from calendar day T-1;
         * compute the day's return from T's ETF prices;
         * update NAV and entry tracking;
         * append a record to ``self.daily_records``.

    Returns:
        ``{'metrics': ..., 'daily_records': ...}``, or ``{}`` when the
        trading calendar is unavailable or empty.

    Raises:
        RuntimeError: If an asset currently held has no factor score on a
            signal date. This is treated as a data failure rather than as
            a sell signal.
    """
    print("=" * 60)
    print(" Simple Rotation Strategy (Daily Iteration)")
    print("=" * 60)
    self._preload_data()
    print("\n[3/4] Fetching A-share trading calendar...")
    start_date = self.config.backtest.start_date
    end_date = self.config.backtest.end_date or datetime.now().strftime('%Y-%m-%d')
    self.trading_calendar = self.data_cache.get_trading_calendar('A', start_date, end_date)
    if self.trading_calendar is None or len(self.trading_calendar) == 0:
        print(" x Calendar fetch failed")
        return {}
    print(f"\n[4/4] Backtesting ({len(self.trading_calendar)} trading days)...")
    current_holdings: List[str] = []
    # NAV starts at 1.0; each record stores the NAV after that day's return.
    nav = 1.0
    rebalance_count = 0
    entry_info: Dict[str, dict] = {}  # signal_code -> {entry_date, entry_price_etf, entry_price_idx}
    for i, date in enumerate(self.trading_calendar):
        # Signal timing: 9:00 AM on day T
        # At this moment, T's market has NOT opened yet.
        # Only T-1 close data is available for all markets.
        # Use calendar day -1 (not A-share prev trading day) so that
        # foreign markets (US/HK) that traded during A-share holidays are captured.
        # Execution happens at 9:30 AM using T's ETF prices.
        signal_date = date - timedelta(days=1)
        new_holdings, factors, bond_momentum = self._generate_signals(signal_date)
        # Data integrity check: if any currently held asset is missing from
        # today's factors, abort immediately to prevent false rebalancing.
        if current_holdings:
            missing = [c for c in current_holdings if c not in factors]
            if missing:
                raise RuntimeError(
                    f"Data failure: held assets {missing} missing from factors on "
                    f"{date.strftime('%Y-%m-%d')}. Aborting to prevent false rebalance."
                )
        # The initial purchase (empty current_holdings) is not counted as a
        # rebalance, so no trade cost is charged for it.
        is_rebalance = (sorted(new_holdings) != sorted(current_holdings)) and len(current_holdings) > 0
        # Return uses T's ETF prices (open for buy/sell, close for hold)
        daily_return = self._calculate_daily_return(
            current_holdings, new_holdings, date, is_rebalance
        )
        nav *= (1 + daily_return)
        if is_rebalance:
            rebalance_count += 1
        # Update entry tracking for held assets
        added = set(new_holdings) - set(current_holdings)
        removed = set(current_holdings) - set(new_holdings)
        for code in added:
            trade_code = self.signal_to_trade.get(code, code)
            etf_prices = self._get_etf_prices(trade_code, date)
            # Entry price is the actual buy price: today's open
            entry_etf = etf_prices['open'] if etf_prices else None
            # Index close at T-1 (the data used to make the decision)
            entry_idx = self._get_index_close(code, signal_date)
            entry_info[code] = {
                'entry_date': date.strftime('%Y-%m-%d'),
                'entry_price_etf': entry_etf,
                'entry_price_idx': entry_idx,
            }
        for code in removed:
            entry_info.pop(code, None)
        # Compute bond threshold value for detail record
        threshold_val = 0.0
        if self.use_dynamic_threshold and bond_momentum is not None:
            threshold_val = round(bond_momentum * self.bond_ratio, 6)
        self.daily_records.append({
            'date': date.strftime('%Y-%m-%d'),
            'nav': round(nav, 6),
            'daily_return': round(daily_return, 6),
            'is_rebalance': is_rebalance,
            'holdings': sorted(new_holdings),
            'added': sorted(added),
            'removed': sorted(removed),
            'factors': {k: round(v, 6) for k, v in factors.items()},
            'threshold': threshold_val,
        })
        current_holdings = new_holdings
        # Progress line every 100 days and on the final day.
        if (i + 1) % 100 == 0 or i == len(self.trading_calendar) - 1:
            print(f" [{i+1}/{len(self.trading_calendar)}] "
                  f"NAV: {nav:.4f} | Rebal: {rebalance_count}")
    metrics = self._compute_metrics(rebalance_count)
    print("\n" + "=" * 60)
    print(" Backtest Complete")
    print("=" * 60)
    if self.daily_records:
        print(f" Period: {self.daily_records[0]['date']} ~ {self.daily_records[-1]['date']}")
        print(f" Trading days: {len(self.daily_records)}")
        print(f" Total return: {metrics['total_return']:.2%}")
        print(f" Annual return: {metrics['annual_return']:.2%}")
        print(f" Max drawdown: {metrics['max_drawdown']:.2%}")
        print(f" Sharpe ratio: {metrics['sharpe_ratio']:.2f}")
        print(f" Calmar ratio: {metrics['calmar_ratio']:.2f}")
        print(f" Win rate: {metrics['win_rate']:.2%}")
        print(f" Rebalances: {rebalance_count}")
    print("=" * 60)
    return {
        'metrics': metrics,
        'daily_records': self.daily_records,
    }
def _compute_metrics(self, rebalance_count: int) -> dict:
"""Compute performance metrics"""
if not self.daily_records:
return {}
df = pd.DataFrame(self.daily_records)
df['date'] = pd.to_datetime(df['date'])
nav = df['nav']
returns = df['daily_return']
total_return = nav.iloc[-1] / nav.iloc[0] - 1
n_days = len(df)
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
peak = nav.cummax()
drawdown = (nav - peak) / peak
max_drawdown = drawdown.min()
sharpe = returns.mean() / returns.std() * np.sqrt(252) if returns.std() > 0 else 0
calmar = annual_return / abs(max_drawdown) if max_drawdown != 0 else 0
non_zero = returns[returns != 0]
win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0
return {
'total_return': total_return,
'annual_return': annual_return,
'max_drawdown': max_drawdown,
'sharpe_ratio': sharpe,
'calmar_ratio': calmar,
'win_rate': win_rate,
'n_days': n_days,
'rebalance_count': rebalance_count,
}
# ============================================================
# Detail JSON helpers (V2-compatible)
# ============================================================
def _build_meta_codes(self) -> dict:
"""Build meta.codes mapping: signal_code -> {name, etf, market}"""
codes = {}
for code, asset in self.config.asset_pools.assets.items():
codes[asset.signal_source] = {
'name': getattr(asset, 'name', code),
'etf': asset.trade_source,
'market': asset.group,
}
return codes
def _get_index_close(self, code: str, date: pd.Timestamp) -> Optional[float]:
"""Get index close price on or before date."""
df = self.index_data.get(code)
if df is None:
return None
mask = df.index <= date
if mask.sum() == 0:
return None
return float(df.loc[mask].iloc[-1]['close'])
def _get_etf_close(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
"""Get ETF close price on or before date."""
df = self.etf_data.get(trade_code)
if df is None:
return None
mask = df.index <= date
if mask.sum() == 0:
return None
return float(df.loc[mask].iloc[-1]['close'])
def _get_daily_returns(self, code: str, date: pd.Timestamp) -> Tuple[Optional[float], Optional[float]]:
"""Get (index_return, etf_return_ctc) for a code on a given date."""
trade_code = self.signal_to_trade.get(code, code)
idx_df = self.index_data.get(code)
etf_df = self.etf_data.get(trade_code)
idx_ret = None
if idx_df is not None:
mask = idx_df.index <= date
if mask.sum() >= 2:
rows = idx_df.loc[mask]
c_today = float(rows.iloc[-1]['close'])
c_prev = float(rows.iloc[-2]['close'])
if c_prev > 0 and not pd.isna(c_today):
idx_ret = round((c_today - c_prev) / c_prev, 6)
etf_ret = None
if etf_df is not None:
mask = etf_df.index <= date
if mask.sum() >= 2:
rows = etf_df.loc[mask]
c_today = float(rows.iloc[-1]['close'])
c_prev = float(rows.iloc[-2]['close'])
if c_prev > 0 and not pd.isna(c_today):
etf_ret = round((c_today - c_prev) / c_prev, 6)
return idx_ret, etf_ret
def _compute_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
"""Get real premium from API data cache: (ETF_price - NAV) / NAV.
Returns None for BOND or when premium data is unavailable."""
group = self.trade_code_to_group.get(trade_code, '')
if group == 'BOND':
return None
premium_dict = self.data_cache.premium_data.get(trade_code)
if not premium_dict:
return None
date_str = date.strftime('%Y-%m-%d')
val = premium_dict.get(date_str)
if val is None:
return None
return round(float(val), 6)
def _get_latest_premium(self, trade_code: str, date: pd.Timestamp) -> Optional[float]:
"""Get premium for trade_code, looking back up to 5 days if exact date not found."""
group = self.trade_code_to_group.get(trade_code, '')
if group == 'BOND':
return None
premium_dict = self.data_cache.premium_data.get(trade_code)
if not premium_dict:
return None
# Try exact date first, then look back up to 5 calendar days
for offset in range(6):
check_date = date - timedelta(days=offset)
date_str = check_date.strftime('%Y-%m-%d')
val = premium_dict.get(date_str)
if val is not None:
return round(float(val), 6)
return None
def _build_day_assets(self, record: dict, date: pd.Timestamp,
entry_info: Dict[str, dict]) -> dict:
"""Build V2-compatible per-asset detail dict for one day."""
factors = record.get('factors', {})
holdings = set(record['holdings'])
threshold = record.get('threshold', 0.0)
# Rank: sort all codes by momentum descending, rank 1 = highest
sorted_codes = sorted(factors.keys(), key=lambda c: factors[c], reverse=True)
rank_map = {c: i + 1 for i, c in enumerate(sorted_codes)}
assets = {}
for code in self.signal_codes:
momentum = factors.get(code)
rank = rank_map.get(code)
above_thresh = (momentum is not None and momentum >= threshold) if threshold else (momentum is not None)
is_held = code in holdings
trade_code = self.signal_to_trade.get(code, code)
idx_close = self._get_index_close(code, date)
etf_close = self._get_etf_close(trade_code, date)
idx_ret, etf_ret_ctc = self._get_daily_returns(code, date)
premium = self._compute_premium(trade_code, date)
# Entry / holding info
ei = entry_info.get(code) if is_held else None
entry_date = ei['entry_date'] if ei else None
entry_price_etf = ei['entry_price_etf'] if ei else None
entry_price_idx = ei['entry_price_idx'] if ei else None
holding_days = 0
cum_ret_etf = None
cum_ret_idx = None
if ei and is_held:
entry_dt = pd.Timestamp(ei['entry_date'])
holding_days = int((date - entry_dt).days)
if holding_days < 1:
holding_days = 1
ep_etf = ei.get('entry_price_etf')
if ep_etf and ep_etf > 0 and etf_close is not None and not pd.isna(etf_close):
cum_ret_etf = round((etf_close - ep_etf) / ep_etf, 6)
ep_idx = ei.get('entry_price_idx')
if ep_idx and ep_idx > 0 and idx_close is not None and not pd.isna(idx_close):
cum_ret_idx = round((idx_close - ep_idx) / ep_idx, 6)
assets[code] = {
'momentum': float(momentum) if momentum is not None else None,
'rank': rank,
'threshold': float(threshold),
'above_threshold': bool(above_thresh),
'index_close': float(idx_close) if idx_close is not None else None,
'etf_close': float(etf_close) if etf_close is not None else None,
'index_return': float(idx_ret) if idx_ret is not None else None,
'etf_return_ctc': float(etf_ret_ctc) if etf_ret_ctc is not None else None,
'premium': float(premium) if premium is not None else None,
'is_held': bool(is_held),
'entry_date': entry_date,
'entry_price_etf': float(entry_price_etf) if entry_price_etf is not None else None,
'entry_price_idx': float(entry_price_idx) if entry_price_idx is not None else None,
'holding_days': holding_days,
'cum_return_etf': float(cum_ret_etf) if cum_ret_etf is not None else None,
'cum_return_idx': float(cum_ret_idx) if cum_ret_idx is not None else None,
}
return assets
# ============================================================
# Export
# ============================================================
def export_results(self, output_dir: str = None):
"""Export backtest results to CSV and JSON (V2-compatible detail format)"""
if not self.daily_records:
print(" x No results to export")
return
if output_dir is None:
output_dir = Path(__file__).parent / 'results'
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
df = pd.DataFrame(self.daily_records)
# NAV curve
nav_path = output_dir / 'simple_rotation_nav.csv'
df[['date', 'nav', 'daily_return']].to_csv(nav_path, index=False)
print(f" + NAV: {nav_path}")
# Signals
sig_path = output_dir / 'simple_rotation_signals.csv'
df[['date', 'holdings', 'is_rebalance', 'added', 'removed']].to_csv(sig_path, index=False)
print(f" + Signals: {sig_path}")
# Detail JSON (V2-compatible format)
detail_path = output_dir / 'simple_rotation_detail.json'
days_out = []
# Track entry_info across days for asset detail reconstruction
tracked_entry: Dict[str, dict] = {}
prev_holdings = []
# Build date index map for signal_date lookup (calendar T-1, not A-share prev trading day)
date_list = [pd.Timestamp(rec['date']) for rec in self.daily_records]
date_to_signal_date = {d: d - timedelta(days=1) for d in date_list}
for rec in self.daily_records:
date = pd.Timestamp(rec['date'])
signal_date = date_to_signal_date[date] # T-1 for signal
holdings = rec['holdings']
added = set(holdings) - set(prev_holdings)
removed = set(prev_holdings) - set(holdings)
# Update entry tracking (consistent with run() logic)
for code in added:
trade_code = self.signal_to_trade.get(code, code)
etf_prices = self._get_etf_prices(trade_code, date)
# Entry price = actual buy price at T's open
entry_etf = etf_prices['open'] if etf_prices else None
# Index close at T-1 (signal data used for decision)
idx_close = self._get_index_close(code, signal_date)
tracked_entry[code] = {
'entry_date': date.strftime('%Y-%m-%d'),
'entry_price_etf': entry_etf,
'entry_price_idx': idx_close,
}
for code in removed:
tracked_entry.pop(code, None)
# Build signals dict: {code: 1} for selected holdings
signals = {c: 1 for c in holdings}
# Build per-asset details
assets = self._build_day_assets(rec, date, tracked_entry)
days_out.append({
'date': rec['date'],
'nav': rec['nav'],
'daily_return': rec['daily_return'],
'is_rebalance': rec['is_rebalance'],
'signals': signals,
'holdings': holdings,
'added': rec['added'],
'removed': rec['removed'],
'assets': assets,
})
prev_holdings = holdings
detail = {
'meta': {
'mode': 'Simple: Daily Iteration',
'start_date': self.config.backtest.start_date,
'end_date': self.daily_records[-1]['date'] if self.daily_records else self.config.backtest.end_date or 'now',
'total_days': len(self.daily_records),
'select_num': self.select_num,
'n_days': self.n_days,
'trade_cost': self.trade_cost,
'bond_threshold': {
'enabled': self.use_dynamic_threshold,
'bond_code': self.bond_code,
'ratio': self.bond_ratio,
},
'codes': self._build_meta_codes(),
},
'days': days_out,
}
_sanitize_json(detail)
with open(detail_path, 'w', encoding='utf-8') as f:
json.dump(detail, f, ensure_ascii=False, indent=2)
print(f" + Detail: {detail_path} ({len(days_out)} days)")
# Metrics JSON
metrics = self._compute_metrics(sum(1 for r in self.daily_records if r['is_rebalance']))
metrics_path = output_dir / 'simple_rotation_metrics.json'
_sanitize_json(metrics)
with open(metrics_path, 'w', encoding='utf-8') as f:
json.dump(metrics, f, ensure_ascii=False, indent=2)
print(f" + Metrics: {metrics_path}")
# ============================================================
# Report Generation (PNG chart with tables)
# ============================================================
def generate_report(self, output_dir: str = None):
"""Generate performance report chart (PNG) with signal table, metrics, NAV, drawdown, holdings"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
if not self.daily_records:
print(" x No results for report")
return
if output_dir is None:
output_dir = Path(__file__).parent / 'results'
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Build DataFrames
df = pd.DataFrame(self.daily_records)
df['date'] = pd.to_datetime(df['date'])
df = df.set_index('date')
strategy_nav = df['nav']
strategy_ret = df['daily_return']
# Compute benchmark NAV
benchmark_nav = None
if self.benchmark_data is not None:
bm_close = self.benchmark_data['close'].reindex(df.index, method='ffill')
if bm_close is not None and not bm_close.isna().all():
benchmark_nav = (1 + bm_close.pct_change(fill_method=None)).cumprod()
first_valid = benchmark_nav.dropna().iloc[0] if len(benchmark_nav.dropna()) > 0 else 1
benchmark_nav = benchmark_nav / first_valid
# Compute individual asset NAVs
asset_navs = {}
for code in self.signal_codes:
if code in self.index_data:
price = self.index_data[code]['close'].reindex(df.index, method='ffill')
if price is not None and not price.isna().all():
nav_s = (1 + price.pct_change(fill_method=None)).cumprod()
fv = nav_s.dropna().iloc[0] if len(nav_s.dropna()) > 0 else 1
asset_navs[code] = nav_s / fv
# Compute metrics
s_total_return = strategy_nav.iloc[-1] / strategy_nav.iloc[0] - 1
n_days = len(df)
s_annual = (1 + s_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
s_sharpe = strategy_ret.mean() / strategy_ret.std() * np.sqrt(252) if strategy_ret.std() > 0 else 0
s_peak = strategy_nav.cummax()
s_dd = ((strategy_nav - s_peak) / s_peak).min()
s_calmar = s_annual / abs(s_dd) if s_dd != 0 else 0
non_zero = strategy_ret[strategy_ret != 0]
s_win_rate = (non_zero > 0).sum() / len(non_zero) if len(non_zero) > 0 else 0
# Benchmark metrics
b_total_return = b_annual = b_sharpe = b_dd = 0
if benchmark_nav is not None:
bm_ret = benchmark_nav.pct_change(fill_method=None)
b_total_return = benchmark_nav.iloc[-1] - 1
b_annual = (1 + b_total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
b_sharpe = bm_ret.mean() / bm_ret.std() * np.sqrt(252) if bm_ret.std() > 0 else 0
b_peak = benchmark_nav.cummax()
b_dd = ((benchmark_nav - b_peak) / b_peak).min()
# Get latest holdings info
last_rec = self.daily_records[-1]
last_date = pd.Timestamp(last_rec['date'])
holdings = last_rec['holdings']
factors = last_rec.get('factors', {})
# Build code config dict for display
code_config = {}
for name, asset in self.config.asset_pools.assets.items():
code_config[asset.signal_source] = {
'name': asset.name if hasattr(asset, 'name') else name,
'etf': asset.trade_source,
'market': asset.group,
}
# Build positions info for table
# Sort holdings by momentum score descending
weight = 1.0 / self.select_num if self.select_num > 0 else 1.0
sorted_holdings = sorted(holdings, key=lambda c: factors.get(c, 0) or 0, reverse=True)
# Determine previous holdings to distinguish "调入" vs "维持"
last_idx = len(self.daily_records) - 1
prev_holdings = set()
if last_idx > 0:
prev_holdings = set(self.daily_records[last_idx - 1]['holdings'])
is_rebalance_day = last_rec.get('is_rebalance', False)
positions_info = []
for code in sorted_holdings:
cfg = code_config.get(code, {})
name = cfg.get('name', code)
etf_code = cfg.get('etf', '')
score = factors.get(code)
idx_close = self._get_index_close(code, last_date)
trade_code = self.signal_to_trade.get(code, code)
etf_close = self._get_etf_close(trade_code, last_date)
premium = self._get_latest_premium(trade_code, last_date)
# Determine action: 调入 (new on rebalance day) vs 维持 (already holding)
if is_rebalance_day and code not in prev_holdings:
action = '调入'
else:
action = '维持'
# Find entry info: scan backwards for continuous holding start
entry_date = None
entry_price = None
holding_days = 0
pnl = None
for rec in reversed(self.daily_records):
if code in rec['holdings']:
entry_date = pd.Timestamp(rec['date'])
p = self._get_etf_prices(trade_code, entry_date)
entry_price = p['open'] if p else None
else:
break
if entry_date is not None:
holding_days = (last_date - entry_date).days
if entry_price and entry_price > 0 and etf_close and etf_close > 0:
pnl = etf_close / entry_price - 1
positions_info.append({
'name': name, 'code': code, 'etf': etf_code,
'weight': weight, 'score': score,
'idx_close': idx_close, 'etf_close': etf_close,
'premium': premium, 'action': action,
'entry_date': entry_date, 'entry_price': entry_price,
'holding_days': holding_days, 'pnl': pnl,
})
# Find exit positions: ONLY show when the last day is a rebalance day
# If last day is NOT rebalance, don't show any "调出" info (already past)
exit_positions = []
if last_rec.get('is_rebalance', False):
# Find the previous day's holdings to compare
last_idx = len(self.daily_records) - 1
if last_idx > 0:
old_holdings = set(self.daily_records[last_idx - 1]['holdings'])
new_holdings = set(holdings)
removed = old_holdings - new_holdings
for code in sorted(removed):
cfg = code_config.get(code, {})
name = cfg.get('name', code)
etf_code = cfg.get('etf', '')
idx_close = self._get_index_close(code, last_date)
trade_code = self.signal_to_trade.get(code, code)
etf_close = self._get_etf_close(trade_code, last_date)
premium = self._get_latest_premium(trade_code, last_date)
exit_positions.append({
'name': name, 'code': code, 'etf': etf_code,
'weight': weight, 'score': None,
'idx_close': idx_close, 'etf_close': etf_close,
'premium': premium, 'action': '调出',
'entry_date': None, 'entry_price': None,
'holding_days': 0, 'pnl': None,
})
# ==================== Plot ====================
plt.rcParams["font.sans-serif"] = ["Arial Unicode MS", "WenQuanYi Zen Hei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
n_rows = len(positions_info) + len(exit_positions)
signal_h = max(1.5, 0.5 + n_rows * 0.35)
fig = plt.figure(figsize=(16, 10 + signal_h + 1.2 + 8))
gs = fig.add_gridspec(5, 1, height_ratios=[signal_h, 1.2, 3, 1, 1.2], hspace=0.35)
# Panel 0: Signal table
ax0 = fig.add_subplot(gs[0])
ax0.axis("off")
ax0.set_title(f"最新调仓信号 (信号日期: {last_date.strftime('%Y-%m-%d')},下一交易日执行)",
fontsize=14, fontweight="bold", loc="left", pad=15)
col_labels = ["标的名称", "指数代码", "ETF代码", "仓位", "得分", "指数最新价",
"ETF收盘价", "溢价率", "操作", "持有天数", "盈亏"]
table_data = []
for p in positions_info:
score_s = f"{p['score']:.2f}" if p['score'] is not None else ""
idx_s = f"{p['idx_close']:.2f}" if p['idx_close'] is not None else ""
etf_s = f"{p['etf_close']:.3f}" if p['etf_close'] is not None else ""
prem_s = f"{p['premium']:+.2%}" if p['premium'] is not None else ""
days_s = str(p['holding_days']) if p['holding_days'] > 0 else ""
pnl_s = f"{p['pnl']:+.2%}" if p['pnl'] is not None else ""
table_data.append([
p['name'], p['code'], p['etf'], f"{p['weight']:.0%}",
score_s, idx_s, etf_s, prem_s, p['action'], days_s, pnl_s
])
for p in exit_positions:
idx_s = f"{p['idx_close']:.2f}" if p['idx_close'] is not None else ""
etf_s = f"{p['etf_close']:.3f}" if p['etf_close'] is not None else ""
prem_s = f"{p['premium']:+.2%}" if p['premium'] is not None else ""
table_data.append([
p['name'], p['code'], p['etf'], f"{p['weight']:.0%}",
"", idx_s, etf_s, prem_s, "调出", "", ""
])
if table_data:
tbl = ax0.table(cellText=table_data, colLabels=col_labels, loc="center",
cellLoc="center", bbox=[0, 0, 1, 1])
tbl.auto_set_font_size(False)
tbl.set_fontsize(9)
tbl.scale(1, 1.8)
for j in range(len(col_labels)):
tbl[0, j].set_facecolor("#2C3E50")
tbl[0, j].set_text_props(color="white", fontweight="bold")
for i in range(len(table_data)):
action = table_data[i][8]
# Legacy color scheme: 调入=green, 维持=yellow, 调出=red
if action == "调入":
color = "#d4edda" # 浅绿色
elif action == "调出":
color = "#f8d7da" # 浅红色
else:
color = "#fff3cd" # 浅黄色(维持)
for j in range(len(col_labels)):
tbl[i + 1, j].set_facecolor(color)
# Panel 1: Metrics table
ax1 = fig.add_subplot(gs[1])
ax1.axis("off")
ax1.set_title("策略绩效对比", fontsize=14, fontweight="bold", loc="left", pad=10)
start_s = df.index.min().strftime('%Y-%m-%d')
end_s = df.index.max().strftime('%Y-%m-%d')
perf_cols = ["策略", "开始时间", "结束时间", "累计收益", "年化收益", "最大回撤", "夏普比率", "Calmar比率", "日胜率"]
strat_row = ["轮动策略", start_s, end_s,
f"{s_total_return:.2%}", f"{s_annual:.2%}", f"{s_dd:.2%}",
f"{s_sharpe:.2f}", f"{s_calmar:.2f}", f"{s_win_rate:.2%}"]
bench_row = [f"基准({self.benchmark_name})", start_s, end_s,
f"{b_total_return:.2%}", f"{b_annual:.2%}", f"{b_dd:.2%}",
f"{b_sharpe:.2f}", "", ""]
ptbl = ax1.table(cellText=[strat_row, bench_row], colLabels=perf_cols,
loc="center", cellLoc="center", bbox=[0, 0, 1, 1])
ptbl.auto_set_font_size(False)
ptbl.set_fontsize(10)
ptbl.scale(1, 1.8)
for j in range(len(perf_cols)):
ptbl[0, j].set_facecolor("#2C3E50")
ptbl[0, j].set_text_props(color="white", fontweight="bold")
ptbl[1, j].set_facecolor("#d4edda")
ptbl[2, j].set_facecolor("#cce5ff")
# Panel 2: NAV curves
ax2 = fig.add_subplot(gs[2])
ax2.plot(strategy_nav.index, strategy_nav.values,
label="轮动策略", linewidth=2, color="#E74C3C")
if benchmark_nav is not None:
ax2.plot(benchmark_nav.index, benchmark_nav.values,
label=self.benchmark_name, linewidth=1.5, color="#3498DB", alpha=0.8)
colors = plt.cm.tab20.colors
for i, code in enumerate(self.signal_codes):
if code in asset_navs:
cfg = code_config.get(code, {})
lbl = cfg.get('name', code) if i < 10 else None
ax2.plot(asset_navs[code].index, asset_navs[code].values,
label=lbl, linewidth=0.8, alpha=0.4,
color=colors[i % len(colors)])
ax2.set_title("ETF轮动策略 - 净值曲线", fontsize=16, fontweight="bold")
ax2.set_ylabel("净值")
ax2.legend(loc="upper left", fontsize=8, ncol=2)
ax2.grid(True, alpha=0.3)
ax2.set_yscale("log")
# Panel 3: Drawdown
ax3 = fig.add_subplot(gs[3])
drawdown = (strategy_nav - s_peak) / s_peak
ax3.fill_between(drawdown.index, drawdown.values, 0, alpha=0.5, color="#E74C3C")
ax3.set_title("策略回撤", fontsize=12)
ax3.set_ylabel("回撤")
ax3.grid(True, alpha=0.3)
# Panel 4: Holdings distribution
ax4 = fig.add_subplot(gs[4])
holdings_series = df['holdings']
for i, code in enumerate(self.signal_codes):
cfg = code_config.get(code, {})
name = cfg.get('name', code)
mask = holdings_series.apply(lambda h: code in h)
if mask.any():
ax4.fill_between(mask.index, i, i + 0.8,
where=mask, alpha=0.7,
color=colors[i % len(colors)], label=name)
ylabels = [code_config.get(c, {}).get('name', c) for c in self.signal_codes]
ax4.set_title("每日持仓分布", fontsize=12)
ax4.set_yticks(range(len(ylabels)))
ax4.set_yticklabels(ylabels, fontsize=7)
ax4.grid(True, alpha=0.3)
chart_path = output_dir / 'simple_rotation_report.png'
plt.savefig(str(chart_path), dpi=150, bbox_inches="tight")
plt.close()
print(f" + Report: {chart_path}")
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    # Point the data API at the default endpoint unless the caller already set one.
    os.environ.setdefault('FLASK_API_URL', 'https://k3s.tokenpluse.xyz')
    runner = SimpleRotationStrategy()
    # Export files and the PNG report only when the backtest produced a result.
    if runner.run():
        runner.export_results()
        runner.generate_report()