From 4973a9a2a56e03be5d926bc339dc6e9d650c996d Mon Sep 17 00:00:00 2001 From: aszerW Date: Sat, 6 Jun 2026 22:28:08 +0800 Subject: [PATCH] feat(rotation): componentize position weighting + fix bond threshold consistency - Extract compute_position_weights() as pluggable pure function - Add WeightType enum (equal/rank) and RotationConfig.weight field - Fix bond threshold dimension mismatch: use configured factor function for all assets instead of hardcoded weighted_momentum_score - Default weight: equal in config, active: rank in config_simple.yaml --- rotation/config_loader.py | 7 +++ rotation/config_simple.yaml | 1 + rotation/simple_rotation.py | 121 ++++++++++++++++++++++++++++-------- 3 files changed, 103 insertions(+), 26 deletions(-) diff --git a/rotation/config_loader.py b/rotation/config_loader.py index b34aad5..81ac0a4 100644 --- a/rotation/config_loader.py +++ b/rotation/config_loader.py @@ -49,6 +49,12 @@ class ThresholdMode(str, Enum): DYNAMIC = "dynamic" # 动态阈值 +class WeightType(str, Enum): + """仓位加权模式""" + EQUAL = "equal" # 等权 + RANK = "rank" # 按排名加权 (slot i gets (N-i)/triangular(N)) + + class DataSourceType(str, Enum): """数据源类型""" FLASK_API = "flask_api" @@ -169,6 +175,7 @@ class RotationConfig(BaseModel): select_num: int = Field(default=3, ge=1, le=10) diversified: bool = Field(default=True) threshold: ThresholdConfig = Field(default_factory=ThresholdConfig) + weight: WeightType = Field(default=WeightType.EQUAL) class RebalanceConfig(BaseModel): diff --git a/rotation/config_simple.yaml b/rotation/config_simple.yaml index 01be8fe..3111bb1 100644 --- a/rotation/config_simple.yaml +++ b/rotation/config_simple.yaml @@ -108,6 +108,7 @@ rebalance: rotation: diversified: true select_num: 3 + weight: rank threshold: dynamic: fallback_enabled: true diff --git a/rotation/simple_rotation.py b/rotation/simple_rotation.py index 0caf119..b3115c7 100644 --- a/rotation/simple_rotation.py +++ b/rotation/simple_rotation.py @@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Tuple PROJECT_ROOT = Path(__file__).parent.parent sys.path.insert(0, str(PROJECT_ROOT)) -from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType +from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType, WeightType # ============================================================ # HTTP client (requests + trust_env=False,绕过系统代理避免 SSL EOF) @@ -164,6 +164,49 @@ def momentum_score(prices: np.ndarray) -> float: return prices[-1] / prices[0] - 1 +# ============================================================ +# Pure functions: position weight schemes +# ============================================================ + +def compute_position_weights( + ranked_holdings: List[str], + weight_type: str = 'equal', +) -> Dict[str, float]: + """Compute position weights from ranked slot list. + + Args: + ranked_holdings: Ordered list of signal codes, best first. + May contain duplicates (e.g. bond fills). + weight_type: 'equal' or 'rank'. + + Returns: + Dict mapping each unique code to its total weight (sum of slots). + + Schemes: + equal: each slot = 1/N, duplicates summed. + rank: slot i (0-indexed) = (N-i) / triangular(N), duplicates summed. + For N=3: [3/6, 2/6, 1/6] = [50%, 33%, 17%]. + """ + N = len(ranked_holdings) + if N == 0: + return {} + + weights: Dict[str, float] = {} + + if weight_type == 'rank': + triangular = N * (N + 1) / 2 + for i, code in enumerate(ranked_holdings): + w = (N - i) / triangular + weights[code] = weights.get(code, 0.0) + w + else: + # equal (default) + w = 1.0 / N + for code in ranked_holdings: + weights[code] = weights.get(code, 0.0) + w + + return weights + + def is_crash(prices: np.ndarray) -> bool: """Crash filter: 3 consecutive days drop > 5%""" if len(prices) < 4: @@ -351,6 +394,7 @@ class SimpleRotationStrategy: self.n_days = self.config.factor.n_days self.select_num = self.config.rotation.select_num self.trade_cost = self.config.rebalance.trade_cost + self.weight_type = self.config.rotation.weight.value # 'equal' or 'rank' # Dynamic threshold threshold = self.config.rotation.threshold @@ -388,6 +432,9 @@ class SimpleRotationStrategy: self.daily_records: List[dict] = [] self.trading_calendar: Optional[pd.DatetimeIndex] = None + # Position weights: code -> weight (updated each day by _generate_signals) + self._position_weights: Dict[str, float] = {} + def _preload_data(self): """Preload all historical data""" start_date = self.config.backtest.start_date @@ -449,7 +496,9 @@ class SimpleRotationStrategy: return weighted_momentum_score(prices) def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]: - """Always compute weighted momentum (for threshold comparison)""" + """Compute momentum using the configured factor function. + For VOL_ADJUSTED_MOMENTUM, falls back to weighted_momentum_score + so that threshold comparison operates in raw return space.""" if signal_code not in self.index_data: return None df = self.index_data[signal_code] @@ -460,7 +509,10 @@ class SimpleRotationStrategy: prices = recent['close'].values[-self.n_days:] if len(prices) >= 4 and is_crash(prices): return 0.0 - return weighted_momentum_score(prices) + if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM: + return weighted_momentum_score(prices) + # All other factors: use the same score function as ranking + return self._compute_momentum(signal_code, date) def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]: """ @@ -524,14 +576,19 @@ class SimpleRotationStrategy: candidates = list(selected_by_group.values()) candidates.sort(key=lambda x: x[1], reverse=True) - final_holdings = [code for code, _ in candidates[:self.select_num]] + ranked_holdings = [code for code, _ in candidates[:self.select_num]] - if len(final_holdings) < self.select_num and self.bond_code: - if self.bond_code not in final_holdings: - n_slots = self.select_num - len(final_holdings) - final_holdings.extend([self.bond_code] * n_slots) + if len(ranked_holdings) < self.select_num and self.bond_code: + if self.bond_code not in ranked_holdings: + n_slots = self.select_num - len(ranked_holdings) + ranked_holdings.extend([self.bond_code] * n_slots) - return sorted(final_holdings), factors, bond_momentum + # Compute position weights via configured scheme + self._position_weights = compute_position_weights( + ranked_holdings, self.weight_type, + ) + + return sorted(ranked_holdings), factors, bond_momentum def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]: """Get ETF prices on a given date: {open, close, prev_close} @@ -566,52 +623,64 @@ class SimpleRotationStrategy: 'prev_close': prev_close, } + def _get_weight(self, code: str, n_unique: int) -> float: + """Get position weight for a code. Falls back to equal weight if not available.""" + if code in self._position_weights: + return self._position_weights[code] + return 1.0 / n_unique if n_unique > 0 else 0.0 + def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance): """ - Compute daily return (T+1 execution): - - Hold: close-to-close - - Sell: close-to-open (sold at open) - - Buy: open-to-close (intraday) + Compute daily return (T+1 execution) with configurable position weighting: + - Hold: close-to-close, weighted by today's position weight + - Sell: close-to-open (sold at open), weighted by today's position weight + - Buy: open-to-close (intraday), weighted by today's position weight """ if not old_holdings: if not new_holdings: return 0.0 - weight = 1.0 / len(new_holdings) ret = 0.0 + seen = set() + n_unique = len(set(new_holdings)) for code in new_holdings: + if code in seen: + continue + seen.add(code) tc = self.signal_to_trade.get(code, code) p = self._get_etf_prices(tc, date) + w = self._get_weight(code, n_unique) if p and p['open'] > 0: - ret += weight * (p['close'] - p['open']) / p['open'] + ret += w * (p['close'] - p['open']) / p['open'] if is_rebalance: ret -= self.trade_cost return ret old_set = set(old_holdings) new_set = set(new_holdings) - weight = 1.0 / len(old_holdings) + n_unique = len(old_set) daily_return = 0.0 - for code in old_holdings: + for code in old_set: tc = self.signal_to_trade.get(code, code) p = self._get_etf_prices(tc, date) if p is None or p['prev_close'] == 0: continue + w = self._get_weight(code, n_unique) if code in new_set: r = (p['close'] - p['prev_close']) / p['prev_close'] else: r = (p['open'] - p['prev_close']) / p['prev_close'] if not math.isnan(r): - daily_return += weight * r + daily_return += w * r - for code in new_holdings: - if code not in old_set: - tc = self.signal_to_trade.get(code, code) - p = self._get_etf_prices(tc, date) - if p and p['open'] > 0 and not math.isnan(p['close']): - r = (p['close'] - p['open']) / p['open'] - if not math.isnan(r): - daily_return += weight * r + for code in new_set - old_set: + tc = self.signal_to_trade.get(code, code) + p = self._get_etf_prices(tc, date) + w = self._get_weight(code, len(new_set)) + if p and p['open'] > 0 and not math.isnan(p['close']): + r = (p['close'] - p['open']) / p['open'] + if not math.isnan(r): + daily_return += w * r if is_rebalance: daily_return -= self.trade_cost