feat(rotation): componentize position weighting + fix bond threshold consistency

- Extract compute_position_weights() as pluggable pure function
- Add WeightType enum (equal/rank) and RotationConfig.weight field
- Fix bond threshold dimension mismatch: use configured factor function
  for all assets instead of hardcoded weighted_momentum_score
- Default weight: equal in config, active: rank in config_simple.yaml
This commit is contained in:
2026-06-06 22:28:08 +08:00
parent 44588d5026
commit 4973a9a2a5
3 changed files with 103 additions and 26 deletions

View File

@@ -49,6 +49,12 @@ class ThresholdMode(str, Enum):
DYNAMIC = "dynamic" # 动态阈值
class WeightType(str, Enum):
"""仓位加权模式"""
EQUAL = "equal" # 等权
RANK = "rank" # 按排名加权 (slot i gets (N-i)/triangular(N))
class DataSourceType(str, Enum):
"""数据源类型"""
FLASK_API = "flask_api"
@@ -169,6 +175,7 @@ class RotationConfig(BaseModel):
select_num: int = Field(default=3, ge=1, le=10)
diversified: bool = Field(default=True)
threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
weight: WeightType = Field(default=WeightType.EQUAL)
class RebalanceConfig(BaseModel):

View File

@@ -108,6 +108,7 @@ rebalance:
rotation:
diversified: true
select_num: 3
weight: rank
threshold:
dynamic:
fallback_enabled: true

View File

@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Tuple
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType, WeightType
# ============================================================
# HTTP client (requests + trust_env=False绕过系统代理避免 SSL EOF)
@@ -164,6 +164,49 @@ def momentum_score(prices: np.ndarray) -> float:
return prices[-1] / prices[0] - 1
# ============================================================
# Pure functions: position weight schemes
# ============================================================
def compute_position_weights(
ranked_holdings: List[str],
weight_type: str = 'equal',
) -> Dict[str, float]:
"""Compute position weights from ranked slot list.
Args:
ranked_holdings: Ordered list of signal codes, best first.
May contain duplicates (e.g. bond fills).
weight_type: 'equal' or 'rank'.
Returns:
Dict mapping each unique code to its total weight (sum of slots).
Schemes:
equal: each slot = 1/N, duplicates summed.
rank: slot i (0-indexed) = (N-i) / triangular(N), duplicates summed.
For N=3: [3/6, 2/6, 1/6] = [50%, 33%, 17%].
"""
N = len(ranked_holdings)
if N == 0:
return {}
weights: Dict[str, float] = {}
if weight_type == 'rank':
triangular = N * (N + 1) / 2
for i, code in enumerate(ranked_holdings):
w = (N - i) / triangular
weights[code] = weights.get(code, 0.0) + w
else:
# equal (default)
w = 1.0 / N
for code in ranked_holdings:
weights[code] = weights.get(code, 0.0) + w
return weights
def is_crash(prices: np.ndarray) -> bool:
"""Crash filter: 3 consecutive days drop > 5%"""
if len(prices) < 4:
@@ -351,6 +394,7 @@ class SimpleRotationStrategy:
self.n_days = self.config.factor.n_days
self.select_num = self.config.rotation.select_num
self.trade_cost = self.config.rebalance.trade_cost
self.weight_type = self.config.rotation.weight.value # 'equal' or 'rank'
# Dynamic threshold
threshold = self.config.rotation.threshold
@@ -388,6 +432,9 @@ class SimpleRotationStrategy:
self.daily_records: List[dict] = []
self.trading_calendar: Optional[pd.DatetimeIndex] = None
# Position weights: code -> weight (updated each day by _generate_signals)
self._position_weights: Dict[str, float] = {}
def _preload_data(self):
"""Preload all historical data"""
start_date = self.config.backtest.start_date
@@ -449,7 +496,9 @@ class SimpleRotationStrategy:
return weighted_momentum_score(prices)
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
"""Always compute weighted momentum (for threshold comparison)"""
"""Compute momentum using the configured factor function.
For VOL_ADJUSTED_MOMENTUM, falls back to weighted_momentum_score
so that threshold comparison operates in raw return space."""
if signal_code not in self.index_data:
return None
df = self.index_data[signal_code]
@@ -460,7 +509,10 @@ class SimpleRotationStrategy:
prices = recent['close'].values[-self.n_days:]
if len(prices) >= 4 and is_crash(prices):
return 0.0
return weighted_momentum_score(prices)
if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM:
return weighted_momentum_score(prices)
# All other factors: use the same score function as ranking
return self._compute_momentum(signal_code, date)
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
"""
@@ -524,14 +576,19 @@ class SimpleRotationStrategy:
candidates = list(selected_by_group.values())
candidates.sort(key=lambda x: x[1], reverse=True)
final_holdings = [code for code, _ in candidates[:self.select_num]]
ranked_holdings = [code for code, _ in candidates[:self.select_num]]
if len(final_holdings) < self.select_num and self.bond_code:
if self.bond_code not in final_holdings:
n_slots = self.select_num - len(final_holdings)
final_holdings.extend([self.bond_code] * n_slots)
if len(ranked_holdings) < self.select_num and self.bond_code:
if self.bond_code not in ranked_holdings:
n_slots = self.select_num - len(ranked_holdings)
ranked_holdings.extend([self.bond_code] * n_slots)
return sorted(final_holdings), factors, bond_momentum
# Compute position weights via configured scheme
self._position_weights = compute_position_weights(
ranked_holdings, self.weight_type,
)
return sorted(ranked_holdings), factors, bond_momentum
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
"""Get ETF prices on a given date: {open, close, prev_close}
@@ -566,52 +623,64 @@ class SimpleRotationStrategy:
'prev_close': prev_close,
}
def _get_weight(self, code: str, n_unique: int) -> float:
"""Get position weight for a code. Falls back to equal weight if not available."""
if code in self._position_weights:
return self._position_weights[code]
return 1.0 / n_unique if n_unique > 0 else 0.0
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
"""
Compute daily return (T+1 execution):
- Hold: close-to-close
- Sell: close-to-open (sold at open)
- Buy: open-to-close (intraday)
Compute daily return (T+1 execution) with configurable position weighting:
- Hold: close-to-close, weighted by today's position weight
- Sell: close-to-open (sold at open), weighted by today's position weight
- Buy: open-to-close (intraday), weighted by today's position weight
"""
if not old_holdings:
if not new_holdings:
return 0.0
weight = 1.0 / len(new_holdings)
ret = 0.0
seen = set()
n_unique = len(set(new_holdings))
for code in new_holdings:
if code in seen:
continue
seen.add(code)
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
w = self._get_weight(code, n_unique)
if p and p['open'] > 0:
ret += weight * (p['close'] - p['open']) / p['open']
ret += w * (p['close'] - p['open']) / p['open']
if is_rebalance:
ret -= self.trade_cost
return ret
old_set = set(old_holdings)
new_set = set(new_holdings)
weight = 1.0 / len(old_holdings)
n_unique = len(old_set)
daily_return = 0.0
for code in old_holdings:
for code in old_set:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
if p is None or p['prev_close'] == 0:
continue
w = self._get_weight(code, n_unique)
if code in new_set:
r = (p['close'] - p['prev_close']) / p['prev_close']
else:
r = (p['open'] - p['prev_close']) / p['prev_close']
if not math.isnan(r):
daily_return += weight * r
daily_return += w * r
for code in new_holdings:
if code not in old_set:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
if p and p['open'] > 0 and not math.isnan(p['close']):
r = (p['close'] - p['open']) / p['open']
if not math.isnan(r):
daily_return += weight * r
for code in new_set - old_set:
tc = self.signal_to_trade.get(code, code)
p = self._get_etf_prices(tc, date)
w = self._get_weight(code, len(new_set))
if p and p['open'] > 0 and not math.isnan(p['close']):
r = (p['close'] - p['open']) / p['open']
if not math.isnan(r):
daily_return += w * r
if is_rebalance:
daily_return -= self.trade_cost