diff --git a/rotation/results/simple_rotation_report.png b/rotation/results/simple_rotation_report.png new file mode 100644 index 0000000..7355307 Binary files /dev/null and b/rotation/results/simple_rotation_report.png differ diff --git a/rotation/simple_rotation.py b/rotation/simple_rotation.py index ad33fb1..b572781 100644 --- a/rotation/simple_rotation.py +++ b/rotation/simple_rotation.py @@ -229,16 +229,14 @@ def compute_position_weights( def is_crash(prices: np.ndarray) -> bool: - """Crash filter: 3 consecutive days drop > 5%""" + """Crash filter: single day drop > 5% (con1 only)""" if len(prices) < 4: return False p = prices[-4:] r1 = p[3] / p[2] r2 = p[2] / p[1] r3 = p[1] / p[0] - con1 = min(r1, r2, r3) < 0.95 - con2 = (r1 < 1 and r2 < 1 and r3 < 1 and p[3] / p[0] < 0.95) - return con1 or con2 + return min(r1, r2, r3) < 0.95 # ============================================================ @@ -415,6 +413,7 @@ class SimpleRotationStrategy: self.n_days = self.config.factor.n_days self.select_num = self.config.rotation.select_num self.trade_cost = self.config.rebalance.trade_cost + self.min_hold_days = self.config.rebalance.min_hold_days self.weight_type = self.config.rotation.weight.value # 'equal' or 'rank' # Dynamic threshold @@ -502,20 +501,12 @@ class SimpleRotationStrategy: if len(recent) < self.n_days: return None prices = recent['close'].values[-self.n_days:] - if len(prices) >= 4 and is_crash(prices): - return 0.0 - # Dispatch based on factor type - ft = self.config.factor.type - if ft == FactorType.VOL_ADJUSTED_MOMENTUM: - return vol_adjusted_momentum_score(prices) - elif ft == FactorType.SLOPE_R2: - return slope_r2_score(prices) - elif ft == FactorType.STANDARDIZED_SLOPE: - return standardized_slope_score(prices) - elif ft == FactorType.MOMENTUM: - return momentum_score(prices) - return weighted_momentum_score(prices) + base_momentum = self._compute_base_momentum(prices) + + if is_crash(prices): + return 0.0 + return base_momentum def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]: """Compute momentum using the configured factor function. @@ -529,13 +520,29 @@ class SimpleRotationStrategy: if len(recent) < self.n_days: return None prices = recent['close'].values[-self.n_days:] - if len(prices) >= 4 and is_crash(prices): - return 0.0 if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM: - return weighted_momentum_score(prices) - # All other factors: use the same score function as ranking - return self._compute_momentum(signal_code, date) + base_momentum = weighted_momentum_score(prices) + else: + # Use the same score function as ranking + base_momentum = self._compute_base_momentum(prices) + + if is_crash(prices): + return 0.0 + return base_momentum + + def _compute_base_momentum(self, prices: np.ndarray) -> float: + """Compute base momentum score without crash filter.""" + ft = self.config.factor.type + if ft == FactorType.VOL_ADJUSTED_MOMENTUM: + return vol_adjusted_momentum_score(prices) + elif ft == FactorType.SLOPE_R2: + return slope_r2_score(prices) + elif ft == FactorType.STANDARDIZED_SLOPE: + return standardized_slope_score(prices) + elif ft == FactorType.MOMENTUM: + return momentum_score(prices) + return weighted_momentum_score(prices) def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]: """ @@ -746,6 +753,29 @@ class SimpleRotationStrategy: new_holdings, factors, bond_momentum = self._generate_signals(signal_date) + # Apply min_hold_days constraint: prevent removing assets held for less than min_hold_days + if self.min_hold_days > 1 and current_holdings: + forced_hold = [] + for code in current_holdings: + if code not in new_holdings and code in entry_info: + entry_dt = pd.Timestamp(entry_info[code]['entry_date']) + held_days = (date - entry_dt).days + if held_days < self.min_hold_days: + forced_hold.append(code) + if forced_hold: + # Add forced holds back to new_holdings, remove lowest-ranked assets to make room + new_set = set(new_holdings) + for code in forced_hold: + if code not in new_set: + new_holdings.append(code) + new_set.add(code) + # Trim to select_num by removing lowest momentum assets (not forced holds) + if len(new_holdings) > self.select_num: + # Sort by momentum descending, keep forced holds first + optional = [c for c in new_holdings if c not in forced_hold] + optional.sort(key=lambda c: factors.get(c, 0), reverse=True) + new_holdings = forced_hold + optional[:self.select_num - len(forced_hold)] + # Data integrity check: if any currently held asset is missing from # today's factors, abort immediately to prevent false rebalancing. if current_holdings: @@ -1409,7 +1439,7 @@ class SimpleRotationStrategy: col_labels = ["排名", "标的名称", "市场", "指数代码", "ETF代码", "仓位", "得分", "指数最新价", "ETF收盘价", "溢价率", "状态", - "进场日期", "持有天数", "盈亏", "退场日期", "退场价格"] + "进场日期", "持有天数", "盈亏"] table_data = [] row_actions = [] # track action for coloring @@ -1426,12 +1456,10 @@ class SimpleRotationStrategy: pnl_s = f"{p['pnl']:+.2%}" if p['pnl'] is not None else "—" weight_s = f"{p['weight']:.0%}" if p['weight'] > 0 else "—" market = code_config.get(p['code'], {}).get('market', '—') - exit_date_s = p['exit_date'].strftime('%Y-%m-%d') if p.get('exit_date') else "—" - exit_price_s = f"{p['exit_price']:.3f}" if p.get('exit_price') else "—" table_data.append([ rank, p['name'], market, p['code'], p['etf'], weight_s, score_s, idx_s, etf_s, prem_s, p['action'], - entry_date_s, days_s, pnl_s, exit_date_s, exit_price_s + entry_date_s, days_s, pnl_s ]) row_actions.append(p['action']) @@ -1599,7 +1627,7 @@ class SimpleRotationStrategy: ax5.grid(True, alpha=0.3) chart_path = output_dir / 'simple_rotation_report.png' - plt.savefig(str(chart_path), dpi=150, bbox_inches="tight") + plt.savefig(str(chart_path), dpi=300, bbox_inches="tight") plt.close() print(f" + Report: {chart_path}")