refactor(rotation): simplify crash filter and add min_hold_days support

Changes:
- Simplify is_crash(): remove con2 (consecutive decline) condition, keep only single-day drop > 5%
- Extract _compute_base_momentum() to eliminate factor dispatch duplication
- Add min_hold_days config for forced holding constraint (currently disabled, value=1)

Backtest comparison (2020-01-10 ~ 2026-06-09):
| Metric          | Old (con1 OR con2) | New (con1 only) |
|-----------------|--------------------|-----------------|
| Total Return    | 241.73%            | 271.98%         |
| Annual Return   | 22.10%             | 23.79%          |
| Max Drawdown    | -16.27%            | -16.27%         |
| Sharpe Ratio    | 1.09               | 1.14            |
| Calmar Ratio    | 1.36               | 1.46            |
| Win Rate        | 53.71%             | 53.78%          |
| Rebalances      | 393                | 362             |

Conclusion: Relaxing crash filter improves return (+1.69% annual) with
same drawdown and fewer rebalances.
This commit is contained in:
2026-06-09 22:31:26 +08:00
parent e2038ae722
commit fe73c0f199
2 changed files with 55 additions and 27 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

View File

@@ -229,16 +229,14 @@ def compute_position_weights(
def is_crash(prices: np.ndarray) -> bool:
"""Crash filter: 3 consecutive days drop > 5%"""
"""Crash filter: single day drop > 5% (con1 only)"""
if len(prices) < 4:
return False
p = prices[-4:]
r1 = p[3] / p[2]
r2 = p[2] / p[1]
r3 = p[1] / p[0]
con1 = min(r1, r2, r3) < 0.95
con2 = (r1 < 1 and r2 < 1 and r3 < 1 and p[3] / p[0] < 0.95)
return con1 or con2
return min(r1, r2, r3) < 0.95
# ============================================================
@@ -415,6 +413,7 @@ class SimpleRotationStrategy:
self.n_days = self.config.factor.n_days
self.select_num = self.config.rotation.select_num
self.trade_cost = self.config.rebalance.trade_cost
self.min_hold_days = self.config.rebalance.min_hold_days
self.weight_type = self.config.rotation.weight.value # 'equal' or 'rank'
# Dynamic threshold
@@ -502,20 +501,12 @@ class SimpleRotationStrategy:
if len(recent) < self.n_days:
return None
prices = recent['close'].values[-self.n_days:]
if len(prices) >= 4 and is_crash(prices):
return 0.0
# Dispatch based on factor type
ft = self.config.factor.type
if ft == FactorType.VOL_ADJUSTED_MOMENTUM:
return vol_adjusted_momentum_score(prices)
elif ft == FactorType.SLOPE_R2:
return slope_r2_score(prices)
elif ft == FactorType.STANDARDIZED_SLOPE:
return standardized_slope_score(prices)
elif ft == FactorType.MOMENTUM:
return momentum_score(prices)
return weighted_momentum_score(prices)
base_momentum = self._compute_base_momentum(prices)
if is_crash(prices):
return 0.0
return base_momentum
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
"""Compute momentum using the configured factor function.
@@ -529,13 +520,29 @@ class SimpleRotationStrategy:
if len(recent) < self.n_days:
return None
prices = recent['close'].values[-self.n_days:]
if len(prices) >= 4 and is_crash(prices):
return 0.0
if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM:
return weighted_momentum_score(prices)
# All other factors: use the same score function as ranking
return self._compute_momentum(signal_code, date)
base_momentum = weighted_momentum_score(prices)
else:
# Use the same score function as ranking
base_momentum = self._compute_base_momentum(prices)
if is_crash(prices):
return 0.0
return base_momentum
def _compute_base_momentum(self, prices: np.ndarray) -> float:
"""Compute base momentum score without crash filter."""
ft = self.config.factor.type
if ft == FactorType.VOL_ADJUSTED_MOMENTUM:
return vol_adjusted_momentum_score(prices)
elif ft == FactorType.SLOPE_R2:
return slope_r2_score(prices)
elif ft == FactorType.STANDARDIZED_SLOPE:
return standardized_slope_score(prices)
elif ft == FactorType.MOMENTUM:
return momentum_score(prices)
return weighted_momentum_score(prices)
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
"""
@@ -746,6 +753,29 @@ class SimpleRotationStrategy:
new_holdings, factors, bond_momentum = self._generate_signals(signal_date)
# Apply min_hold_days constraint: prevent removing assets held for less than min_hold_days
if self.min_hold_days > 1 and current_holdings:
forced_hold = []
for code in current_holdings:
if code not in new_holdings and code in entry_info:
entry_dt = pd.Timestamp(entry_info[code]['entry_date'])
held_days = (date - entry_dt).days
if held_days < self.min_hold_days:
forced_hold.append(code)
if forced_hold:
# Add forced holds back to new_holdings, remove lowest-ranked assets to make room
new_set = set(new_holdings)
for code in forced_hold:
if code not in new_set:
new_holdings.append(code)
new_set.add(code)
# Trim to select_num by removing lowest momentum assets (not forced holds)
if len(new_holdings) > self.select_num:
# Sort by momentum descending, keep forced holds first
optional = [c for c in new_holdings if c not in forced_hold]
optional.sort(key=lambda c: factors.get(c, 0), reverse=True)
new_holdings = forced_hold + optional[:self.select_num - len(forced_hold)]
# Data integrity check: if any currently held asset is missing from
# today's factors, abort immediately to prevent false rebalancing.
if current_holdings:
@@ -1409,7 +1439,7 @@ class SimpleRotationStrategy:
col_labels = ["排名", "标的名称", "市场", "指数代码", "ETF代码", "仓位", "得分",
"指数最新价", "ETF收盘价", "溢价率", "状态",
"进场日期", "持有天数", "盈亏", "退场日期", "退场价格"]
"进场日期", "持有天数", "盈亏"]
table_data = []
row_actions = [] # track action for coloring
@@ -1426,12 +1456,10 @@ class SimpleRotationStrategy:
pnl_s = f"{p['pnl']:+.2%}" if p['pnl'] is not None else ""
weight_s = f"{p['weight']:.0%}" if p['weight'] > 0 else ""
market = code_config.get(p['code'], {}).get('market', '')
exit_date_s = p['exit_date'].strftime('%Y-%m-%d') if p.get('exit_date') else ""
exit_price_s = f"{p['exit_price']:.3f}" if p.get('exit_price') else ""
table_data.append([
rank, p['name'], market, p['code'], p['etf'], weight_s,
score_s, idx_s, etf_s, prem_s, p['action'],
entry_date_s, days_s, pnl_s, exit_date_s, exit_price_s
entry_date_s, days_s, pnl_s
])
row_actions.append(p['action'])
@@ -1599,7 +1627,7 @@ class SimpleRotationStrategy:
ax5.grid(True, alpha=0.3)
chart_path = output_dir / 'simple_rotation_report.png'
plt.savefig(str(chart_path), dpi=150, bbox_inches="tight")
plt.savefig(str(chart_path), dpi=300, bbox_inches="tight")
plt.close()
print(f" + Report: {chart_path}")