feat(rotation): componentize position weighting + fix bond threshold consistency
- Extract compute_position_weights() as pluggable pure function - Add WeightType enum (equal/rank) and RotationConfig.weight field - Fix bond threshold dimension mismatch: use configured factor function for all assets instead of hardcoded weighted_momentum_score - Default weight: equal in config, active: rank in config_simple.yaml
This commit is contained in:
@@ -49,6 +49,12 @@ class ThresholdMode(str, Enum):
|
||||
DYNAMIC = "dynamic" # 动态阈值
|
||||
|
||||
|
||||
class WeightType(str, Enum):
|
||||
"""仓位加权模式"""
|
||||
EQUAL = "equal" # 等权
|
||||
RANK = "rank" # 按排名加权 (slot i gets (N-i)/triangular(N))
|
||||
|
||||
|
||||
class DataSourceType(str, Enum):
|
||||
"""数据源类型"""
|
||||
FLASK_API = "flask_api"
|
||||
@@ -169,6 +175,7 @@ class RotationConfig(BaseModel):
|
||||
select_num: int = Field(default=3, ge=1, le=10)
|
||||
diversified: bool = Field(default=True)
|
||||
threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
|
||||
weight: WeightType = Field(default=WeightType.EQUAL)
|
||||
|
||||
|
||||
class RebalanceConfig(BaseModel):
|
||||
|
||||
@@ -108,6 +108,7 @@ rebalance:
|
||||
rotation:
|
||||
diversified: true
|
||||
select_num: 3
|
||||
weight: rank
|
||||
threshold:
|
||||
dynamic:
|
||||
fallback_enabled: true
|
||||
|
||||
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType
|
||||
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType, WeightType
|
||||
|
||||
# ============================================================
|
||||
# HTTP client (requests + trust_env=False,绕过系统代理避免 SSL EOF)
|
||||
@@ -164,6 +164,49 @@ def momentum_score(prices: np.ndarray) -> float:
|
||||
return prices[-1] / prices[0] - 1
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Pure functions: position weight schemes
|
||||
# ============================================================
|
||||
|
||||
def compute_position_weights(
|
||||
ranked_holdings: List[str],
|
||||
weight_type: str = 'equal',
|
||||
) -> Dict[str, float]:
|
||||
"""Compute position weights from ranked slot list.
|
||||
|
||||
Args:
|
||||
ranked_holdings: Ordered list of signal codes, best first.
|
||||
May contain duplicates (e.g. bond fills).
|
||||
weight_type: 'equal' or 'rank'.
|
||||
|
||||
Returns:
|
||||
Dict mapping each unique code to its total weight (sum of slots).
|
||||
|
||||
Schemes:
|
||||
equal: each slot = 1/N, duplicates summed.
|
||||
rank: slot i (0-indexed) = (N-i) / triangular(N), duplicates summed.
|
||||
For N=3: [3/6, 2/6, 1/6] = [50%, 33%, 17%].
|
||||
"""
|
||||
N = len(ranked_holdings)
|
||||
if N == 0:
|
||||
return {}
|
||||
|
||||
weights: Dict[str, float] = {}
|
||||
|
||||
if weight_type == 'rank':
|
||||
triangular = N * (N + 1) / 2
|
||||
for i, code in enumerate(ranked_holdings):
|
||||
w = (N - i) / triangular
|
||||
weights[code] = weights.get(code, 0.0) + w
|
||||
else:
|
||||
# equal (default)
|
||||
w = 1.0 / N
|
||||
for code in ranked_holdings:
|
||||
weights[code] = weights.get(code, 0.0) + w
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def is_crash(prices: np.ndarray) -> bool:
|
||||
"""Crash filter: 3 consecutive days drop > 5%"""
|
||||
if len(prices) < 4:
|
||||
@@ -351,6 +394,7 @@ class SimpleRotationStrategy:
|
||||
self.n_days = self.config.factor.n_days
|
||||
self.select_num = self.config.rotation.select_num
|
||||
self.trade_cost = self.config.rebalance.trade_cost
|
||||
self.weight_type = self.config.rotation.weight.value # 'equal' or 'rank'
|
||||
|
||||
# Dynamic threshold
|
||||
threshold = self.config.rotation.threshold
|
||||
@@ -388,6 +432,9 @@ class SimpleRotationStrategy:
|
||||
self.daily_records: List[dict] = []
|
||||
self.trading_calendar: Optional[pd.DatetimeIndex] = None
|
||||
|
||||
# Position weights: code -> weight (updated each day by _generate_signals)
|
||||
self._position_weights: Dict[str, float] = {}
|
||||
|
||||
def _preload_data(self):
|
||||
"""Preload all historical data"""
|
||||
start_date = self.config.backtest.start_date
|
||||
@@ -449,7 +496,9 @@ class SimpleRotationStrategy:
|
||||
return weighted_momentum_score(prices)
|
||||
|
||||
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
|
||||
"""Always compute weighted momentum (for threshold comparison)"""
|
||||
"""Compute momentum using the configured factor function.
|
||||
For VOL_ADJUSTED_MOMENTUM, falls back to weighted_momentum_score
|
||||
so that threshold comparison operates in raw return space."""
|
||||
if signal_code not in self.index_data:
|
||||
return None
|
||||
df = self.index_data[signal_code]
|
||||
@@ -460,7 +509,10 @@ class SimpleRotationStrategy:
|
||||
prices = recent['close'].values[-self.n_days:]
|
||||
if len(prices) >= 4 and is_crash(prices):
|
||||
return 0.0
|
||||
return weighted_momentum_score(prices)
|
||||
if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM:
|
||||
return weighted_momentum_score(prices)
|
||||
# All other factors: use the same score function as ranking
|
||||
return self._compute_momentum(signal_code, date)
|
||||
|
||||
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
|
||||
"""
|
||||
@@ -524,14 +576,19 @@ class SimpleRotationStrategy:
|
||||
|
||||
candidates = list(selected_by_group.values())
|
||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||
final_holdings = [code for code, _ in candidates[:self.select_num]]
|
||||
ranked_holdings = [code for code, _ in candidates[:self.select_num]]
|
||||
|
||||
if len(final_holdings) < self.select_num and self.bond_code:
|
||||
if self.bond_code not in final_holdings:
|
||||
n_slots = self.select_num - len(final_holdings)
|
||||
final_holdings.extend([self.bond_code] * n_slots)
|
||||
if len(ranked_holdings) < self.select_num and self.bond_code:
|
||||
if self.bond_code not in ranked_holdings:
|
||||
n_slots = self.select_num - len(ranked_holdings)
|
||||
ranked_holdings.extend([self.bond_code] * n_slots)
|
||||
|
||||
return sorted(final_holdings), factors, bond_momentum
|
||||
# Compute position weights via configured scheme
|
||||
self._position_weights = compute_position_weights(
|
||||
ranked_holdings, self.weight_type,
|
||||
)
|
||||
|
||||
return sorted(ranked_holdings), factors, bond_momentum
|
||||
|
||||
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
|
||||
"""Get ETF prices on a given date: {open, close, prev_close}
|
||||
@@ -566,52 +623,64 @@ class SimpleRotationStrategy:
|
||||
'prev_close': prev_close,
|
||||
}
|
||||
|
||||
def _get_weight(self, code: str, n_unique: int) -> float:
|
||||
"""Get position weight for a code. Falls back to equal weight if not available."""
|
||||
if code in self._position_weights:
|
||||
return self._position_weights[code]
|
||||
return 1.0 / n_unique if n_unique > 0 else 0.0
|
||||
|
||||
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
|
||||
"""
|
||||
Compute daily return (T+1 execution):
|
||||
- Hold: close-to-close
|
||||
- Sell: close-to-open (sold at open)
|
||||
- Buy: open-to-close (intraday)
|
||||
Compute daily return (T+1 execution) with configurable position weighting:
|
||||
- Hold: close-to-close, weighted by today's position weight
|
||||
- Sell: close-to-open (sold at open), weighted by today's position weight
|
||||
- Buy: open-to-close (intraday), weighted by today's position weight
|
||||
"""
|
||||
if not old_holdings:
|
||||
if not new_holdings:
|
||||
return 0.0
|
||||
weight = 1.0 / len(new_holdings)
|
||||
ret = 0.0
|
||||
seen = set()
|
||||
n_unique = len(set(new_holdings))
|
||||
for code in new_holdings:
|
||||
if code in seen:
|
||||
continue
|
||||
seen.add(code)
|
||||
tc = self.signal_to_trade.get(code, code)
|
||||
p = self._get_etf_prices(tc, date)
|
||||
w = self._get_weight(code, n_unique)
|
||||
if p and p['open'] > 0:
|
||||
ret += weight * (p['close'] - p['open']) / p['open']
|
||||
ret += w * (p['close'] - p['open']) / p['open']
|
||||
if is_rebalance:
|
||||
ret -= self.trade_cost
|
||||
return ret
|
||||
|
||||
old_set = set(old_holdings)
|
||||
new_set = set(new_holdings)
|
||||
weight = 1.0 / len(old_holdings)
|
||||
n_unique = len(old_set)
|
||||
daily_return = 0.0
|
||||
|
||||
for code in old_holdings:
|
||||
for code in old_set:
|
||||
tc = self.signal_to_trade.get(code, code)
|
||||
p = self._get_etf_prices(tc, date)
|
||||
if p is None or p['prev_close'] == 0:
|
||||
continue
|
||||
w = self._get_weight(code, n_unique)
|
||||
if code in new_set:
|
||||
r = (p['close'] - p['prev_close']) / p['prev_close']
|
||||
else:
|
||||
r = (p['open'] - p['prev_close']) / p['prev_close']
|
||||
if not math.isnan(r):
|
||||
daily_return += weight * r
|
||||
daily_return += w * r
|
||||
|
||||
for code in new_holdings:
|
||||
if code not in old_set:
|
||||
tc = self.signal_to_trade.get(code, code)
|
||||
p = self._get_etf_prices(tc, date)
|
||||
if p and p['open'] > 0 and not math.isnan(p['close']):
|
||||
r = (p['close'] - p['open']) / p['open']
|
||||
if not math.isnan(r):
|
||||
daily_return += weight * r
|
||||
for code in new_set - old_set:
|
||||
tc = self.signal_to_trade.get(code, code)
|
||||
p = self._get_etf_prices(tc, date)
|
||||
w = self._get_weight(code, len(new_set))
|
||||
if p and p['open'] > 0 and not math.isnan(p['close']):
|
||||
r = (p['close'] - p['open']) / p['open']
|
||||
if not math.isnan(r):
|
||||
daily_return += w * r
|
||||
|
||||
if is_rebalance:
|
||||
daily_return -= self.trade_cost
|
||||
|
||||
Reference in New Issue
Block a user