refactor(rotation): simplify crash filter and add min_hold_days support
Changes: - Simplify is_crash(): remove con2 (consecutive decline) condition, keep only single-day drop > 5% - Extract _compute_base_momentum() to eliminate factor dispatch duplication - Add min_hold_days config for forced holding constraint (currently disabled, value=1) Backtest comparison (2020-01-10 ~ 2026-06-09): | Metric | Old (con1 OR con2) | New (con1 only) | |-----------------|--------------------|-----------------| | Total Return | 241.73% | 271.98% | | Annual Return | 22.10% | 23.79% | | Max Drawdown | -16.27% | -16.27% | | Sharpe Ratio | 1.09 | 1.14 | | Calmar Ratio | 1.36 | 1.46 | | Win Rate | 53.71% | 53.78% | | Rebalances | 393 | 362 | Conclusion: Relaxing crash filter improves return (+1.69% annual) with same drawdown and fewer rebalances.
This commit is contained in:
BIN
rotation/results/simple_rotation_report.png
Normal file
BIN
rotation/results/simple_rotation_report.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.7 MiB |
@@ -229,16 +229,14 @@ def compute_position_weights(
|
||||
|
||||
|
||||
def is_crash(prices: np.ndarray) -> bool:
|
||||
"""Crash filter: 3 consecutive days drop > 5%"""
|
||||
"""Crash filter: single day drop > 5% (con1 only)"""
|
||||
if len(prices) < 4:
|
||||
return False
|
||||
p = prices[-4:]
|
||||
r1 = p[3] / p[2]
|
||||
r2 = p[2] / p[1]
|
||||
r3 = p[1] / p[0]
|
||||
con1 = min(r1, r2, r3) < 0.95
|
||||
con2 = (r1 < 1 and r2 < 1 and r3 < 1 and p[3] / p[0] < 0.95)
|
||||
return con1 or con2
|
||||
return min(r1, r2, r3) < 0.95
|
||||
|
||||
|
||||
# ============================================================
|
||||
@@ -415,6 +413,7 @@ class SimpleRotationStrategy:
|
||||
self.n_days = self.config.factor.n_days
|
||||
self.select_num = self.config.rotation.select_num
|
||||
self.trade_cost = self.config.rebalance.trade_cost
|
||||
self.min_hold_days = self.config.rebalance.min_hold_days
|
||||
self.weight_type = self.config.rotation.weight.value # 'equal' or 'rank'
|
||||
|
||||
# Dynamic threshold
|
||||
@@ -502,20 +501,12 @@ class SimpleRotationStrategy:
|
||||
if len(recent) < self.n_days:
|
||||
return None
|
||||
prices = recent['close'].values[-self.n_days:]
|
||||
if len(prices) >= 4 and is_crash(prices):
|
||||
return 0.0
|
||||
|
||||
# Dispatch based on factor type
|
||||
ft = self.config.factor.type
|
||||
if ft == FactorType.VOL_ADJUSTED_MOMENTUM:
|
||||
return vol_adjusted_momentum_score(prices)
|
||||
elif ft == FactorType.SLOPE_R2:
|
||||
return slope_r2_score(prices)
|
||||
elif ft == FactorType.STANDARDIZED_SLOPE:
|
||||
return standardized_slope_score(prices)
|
||||
elif ft == FactorType.MOMENTUM:
|
||||
return momentum_score(prices)
|
||||
return weighted_momentum_score(prices)
|
||||
base_momentum = self._compute_base_momentum(prices)
|
||||
|
||||
if is_crash(prices):
|
||||
return 0.0
|
||||
return base_momentum
|
||||
|
||||
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
|
||||
"""Compute momentum using the configured factor function.
|
||||
@@ -529,13 +520,29 @@ class SimpleRotationStrategy:
|
||||
if len(recent) < self.n_days:
|
||||
return None
|
||||
prices = recent['close'].values[-self.n_days:]
|
||||
if len(prices) >= 4 and is_crash(prices):
|
||||
return 0.0
|
||||
|
||||
if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM:
|
||||
return weighted_momentum_score(prices)
|
||||
# All other factors: use the same score function as ranking
|
||||
return self._compute_momentum(signal_code, date)
|
||||
base_momentum = weighted_momentum_score(prices)
|
||||
else:
|
||||
# Use the same score function as ranking
|
||||
base_momentum = self._compute_base_momentum(prices)
|
||||
|
||||
if is_crash(prices):
|
||||
return 0.0
|
||||
return base_momentum
|
||||
|
||||
def _compute_base_momentum(self, prices: np.ndarray) -> float:
|
||||
"""Compute base momentum score without crash filter."""
|
||||
ft = self.config.factor.type
|
||||
if ft == FactorType.VOL_ADJUSTED_MOMENTUM:
|
||||
return vol_adjusted_momentum_score(prices)
|
||||
elif ft == FactorType.SLOPE_R2:
|
||||
return slope_r2_score(prices)
|
||||
elif ft == FactorType.STANDARDIZED_SLOPE:
|
||||
return standardized_slope_score(prices)
|
||||
elif ft == FactorType.MOMENTUM:
|
||||
return momentum_score(prices)
|
||||
return weighted_momentum_score(prices)
|
||||
|
||||
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
|
||||
"""
|
||||
@@ -746,6 +753,29 @@ class SimpleRotationStrategy:
|
||||
|
||||
new_holdings, factors, bond_momentum = self._generate_signals(signal_date)
|
||||
|
||||
# Apply min_hold_days constraint: prevent removing assets held for less than min_hold_days
|
||||
if self.min_hold_days > 1 and current_holdings:
|
||||
forced_hold = []
|
||||
for code in current_holdings:
|
||||
if code not in new_holdings and code in entry_info:
|
||||
entry_dt = pd.Timestamp(entry_info[code]['entry_date'])
|
||||
held_days = (date - entry_dt).days
|
||||
if held_days < self.min_hold_days:
|
||||
forced_hold.append(code)
|
||||
if forced_hold:
|
||||
# Add forced holds back to new_holdings, remove lowest-ranked assets to make room
|
||||
new_set = set(new_holdings)
|
||||
for code in forced_hold:
|
||||
if code not in new_set:
|
||||
new_holdings.append(code)
|
||||
new_set.add(code)
|
||||
# Trim to select_num by removing lowest momentum assets (not forced holds)
|
||||
if len(new_holdings) > self.select_num:
|
||||
# Sort by momentum descending, keep forced holds first
|
||||
optional = [c for c in new_holdings if c not in forced_hold]
|
||||
optional.sort(key=lambda c: factors.get(c, 0), reverse=True)
|
||||
new_holdings = forced_hold + optional[:self.select_num - len(forced_hold)]
|
||||
|
||||
# Data integrity check: if any currently held asset is missing from
|
||||
# today's factors, abort immediately to prevent false rebalancing.
|
||||
if current_holdings:
|
||||
@@ -1409,7 +1439,7 @@ class SimpleRotationStrategy:
|
||||
|
||||
col_labels = ["排名", "标的名称", "市场", "指数代码", "ETF代码", "仓位", "得分",
|
||||
"指数最新价", "ETF收盘价", "溢价率", "状态",
|
||||
"进场日期", "持有天数", "盈亏", "退场日期", "退场价格"]
|
||||
"进场日期", "持有天数", "盈亏"]
|
||||
table_data = []
|
||||
row_actions = [] # track action for coloring
|
||||
|
||||
@@ -1426,12 +1456,10 @@ class SimpleRotationStrategy:
|
||||
pnl_s = f"{p['pnl']:+.2%}" if p['pnl'] is not None else "—"
|
||||
weight_s = f"{p['weight']:.0%}" if p['weight'] > 0 else "—"
|
||||
market = code_config.get(p['code'], {}).get('market', '—')
|
||||
exit_date_s = p['exit_date'].strftime('%Y-%m-%d') if p.get('exit_date') else "—"
|
||||
exit_price_s = f"{p['exit_price']:.3f}" if p.get('exit_price') else "—"
|
||||
table_data.append([
|
||||
rank, p['name'], market, p['code'], p['etf'], weight_s,
|
||||
score_s, idx_s, etf_s, prem_s, p['action'],
|
||||
entry_date_s, days_s, pnl_s, exit_date_s, exit_price_s
|
||||
entry_date_s, days_s, pnl_s
|
||||
])
|
||||
row_actions.append(p['action'])
|
||||
|
||||
@@ -1599,7 +1627,7 @@ class SimpleRotationStrategy:
|
||||
ax5.grid(True, alpha=0.3)
|
||||
|
||||
chart_path = output_dir / 'simple_rotation_report.png'
|
||||
plt.savefig(str(chart_path), dpi=150, bbox_inches="tight")
|
||||
plt.savefig(str(chart_path), dpi=300, bbox_inches="tight")
|
||||
plt.close()
|
||||
print(f" + Report: {chart_path}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user