feat(rotation): componentize position weighting + fix bond threshold consistency
- Extract compute_position_weights() as pluggable pure function - Add WeightType enum (equal/rank) and RotationConfig.weight field - Fix bond threshold dimension mismatch: use configured factor function for all assets instead of hardcoded weighted_momentum_score - Default weight: equal in config, active: rank in config_simple.yaml
This commit is contained in:
@@ -49,6 +49,12 @@ class ThresholdMode(str, Enum):
|
|||||||
DYNAMIC = "dynamic" # 动态阈值
|
DYNAMIC = "dynamic" # 动态阈值
|
||||||
|
|
||||||
|
|
||||||
|
class WeightType(str, Enum):
|
||||||
|
"""仓位加权模式"""
|
||||||
|
EQUAL = "equal" # 等权
|
||||||
|
RANK = "rank" # 按排名加权 (slot i gets (N-i)/triangular(N))
|
||||||
|
|
||||||
|
|
||||||
class DataSourceType(str, Enum):
|
class DataSourceType(str, Enum):
|
||||||
"""数据源类型"""
|
"""数据源类型"""
|
||||||
FLASK_API = "flask_api"
|
FLASK_API = "flask_api"
|
||||||
@@ -169,6 +175,7 @@ class RotationConfig(BaseModel):
|
|||||||
select_num: int = Field(default=3, ge=1, le=10)
|
select_num: int = Field(default=3, ge=1, le=10)
|
||||||
diversified: bool = Field(default=True)
|
diversified: bool = Field(default=True)
|
||||||
threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
|
threshold: ThresholdConfig = Field(default_factory=ThresholdConfig)
|
||||||
|
weight: WeightType = Field(default=WeightType.EQUAL)
|
||||||
|
|
||||||
|
|
||||||
class RebalanceConfig(BaseModel):
|
class RebalanceConfig(BaseModel):
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ rebalance:
|
|||||||
rotation:
|
rotation:
|
||||||
diversified: true
|
diversified: true
|
||||||
select_num: 3
|
select_num: 3
|
||||||
|
weight: rank
|
||||||
threshold:
|
threshold:
|
||||||
dynamic:
|
dynamic:
|
||||||
fallback_enabled: true
|
fallback_enabled: true
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
|
|
||||||
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType
|
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType, WeightType
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# HTTP client (requests + trust_env=False,绕过系统代理避免 SSL EOF)
|
# HTTP client (requests + trust_env=False,绕过系统代理避免 SSL EOF)
|
||||||
@@ -164,6 +164,49 @@ def momentum_score(prices: np.ndarray) -> float:
|
|||||||
return prices[-1] / prices[0] - 1
|
return prices[-1] / prices[0] - 1
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Pure functions: position weight schemes
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def compute_position_weights(
|
||||||
|
ranked_holdings: List[str],
|
||||||
|
weight_type: str = 'equal',
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""Compute position weights from ranked slot list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ranked_holdings: Ordered list of signal codes, best first.
|
||||||
|
May contain duplicates (e.g. bond fills).
|
||||||
|
weight_type: 'equal' or 'rank'.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping each unique code to its total weight (sum of slots).
|
||||||
|
|
||||||
|
Schemes:
|
||||||
|
equal: each slot = 1/N, duplicates summed.
|
||||||
|
rank: slot i (0-indexed) = (N-i) / triangular(N), duplicates summed.
|
||||||
|
For N=3: [3/6, 2/6, 1/6] = [50%, 33%, 17%].
|
||||||
|
"""
|
||||||
|
N = len(ranked_holdings)
|
||||||
|
if N == 0:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
weights: Dict[str, float] = {}
|
||||||
|
|
||||||
|
if weight_type == 'rank':
|
||||||
|
triangular = N * (N + 1) / 2
|
||||||
|
for i, code in enumerate(ranked_holdings):
|
||||||
|
w = (N - i) / triangular
|
||||||
|
weights[code] = weights.get(code, 0.0) + w
|
||||||
|
else:
|
||||||
|
# equal (default)
|
||||||
|
w = 1.0 / N
|
||||||
|
for code in ranked_holdings:
|
||||||
|
weights[code] = weights.get(code, 0.0) + w
|
||||||
|
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
def is_crash(prices: np.ndarray) -> bool:
|
def is_crash(prices: np.ndarray) -> bool:
|
||||||
"""Crash filter: 3 consecutive days drop > 5%"""
|
"""Crash filter: 3 consecutive days drop > 5%"""
|
||||||
if len(prices) < 4:
|
if len(prices) < 4:
|
||||||
@@ -351,6 +394,7 @@ class SimpleRotationStrategy:
|
|||||||
self.n_days = self.config.factor.n_days
|
self.n_days = self.config.factor.n_days
|
||||||
self.select_num = self.config.rotation.select_num
|
self.select_num = self.config.rotation.select_num
|
||||||
self.trade_cost = self.config.rebalance.trade_cost
|
self.trade_cost = self.config.rebalance.trade_cost
|
||||||
|
self.weight_type = self.config.rotation.weight.value # 'equal' or 'rank'
|
||||||
|
|
||||||
# Dynamic threshold
|
# Dynamic threshold
|
||||||
threshold = self.config.rotation.threshold
|
threshold = self.config.rotation.threshold
|
||||||
@@ -388,6 +432,9 @@ class SimpleRotationStrategy:
|
|||||||
self.daily_records: List[dict] = []
|
self.daily_records: List[dict] = []
|
||||||
self.trading_calendar: Optional[pd.DatetimeIndex] = None
|
self.trading_calendar: Optional[pd.DatetimeIndex] = None
|
||||||
|
|
||||||
|
# Position weights: code -> weight (updated each day by _generate_signals)
|
||||||
|
self._position_weights: Dict[str, float] = {}
|
||||||
|
|
||||||
def _preload_data(self):
|
def _preload_data(self):
|
||||||
"""Preload all historical data"""
|
"""Preload all historical data"""
|
||||||
start_date = self.config.backtest.start_date
|
start_date = self.config.backtest.start_date
|
||||||
@@ -449,7 +496,9 @@ class SimpleRotationStrategy:
|
|||||||
return weighted_momentum_score(prices)
|
return weighted_momentum_score(prices)
|
||||||
|
|
||||||
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
|
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
|
||||||
"""Always compute weighted momentum (for threshold comparison)"""
|
"""Compute momentum using the configured factor function.
|
||||||
|
For VOL_ADJUSTED_MOMENTUM, falls back to weighted_momentum_score
|
||||||
|
so that threshold comparison operates in raw return space."""
|
||||||
if signal_code not in self.index_data:
|
if signal_code not in self.index_data:
|
||||||
return None
|
return None
|
||||||
df = self.index_data[signal_code]
|
df = self.index_data[signal_code]
|
||||||
@@ -460,7 +509,10 @@ class SimpleRotationStrategy:
|
|||||||
prices = recent['close'].values[-self.n_days:]
|
prices = recent['close'].values[-self.n_days:]
|
||||||
if len(prices) >= 4 and is_crash(prices):
|
if len(prices) >= 4 and is_crash(prices):
|
||||||
return 0.0
|
return 0.0
|
||||||
|
if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM:
|
||||||
return weighted_momentum_score(prices)
|
return weighted_momentum_score(prices)
|
||||||
|
# All other factors: use the same score function as ranking
|
||||||
|
return self._compute_momentum(signal_code, date)
|
||||||
|
|
||||||
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
|
def _generate_signals(self, date: pd.Timestamp) -> Tuple[List[str], Dict[str, float], Optional[float]]:
|
||||||
"""
|
"""
|
||||||
@@ -524,14 +576,19 @@ class SimpleRotationStrategy:
|
|||||||
|
|
||||||
candidates = list(selected_by_group.values())
|
candidates = list(selected_by_group.values())
|
||||||
candidates.sort(key=lambda x: x[1], reverse=True)
|
candidates.sort(key=lambda x: x[1], reverse=True)
|
||||||
final_holdings = [code for code, _ in candidates[:self.select_num]]
|
ranked_holdings = [code for code, _ in candidates[:self.select_num]]
|
||||||
|
|
||||||
if len(final_holdings) < self.select_num and self.bond_code:
|
if len(ranked_holdings) < self.select_num and self.bond_code:
|
||||||
if self.bond_code not in final_holdings:
|
if self.bond_code not in ranked_holdings:
|
||||||
n_slots = self.select_num - len(final_holdings)
|
n_slots = self.select_num - len(ranked_holdings)
|
||||||
final_holdings.extend([self.bond_code] * n_slots)
|
ranked_holdings.extend([self.bond_code] * n_slots)
|
||||||
|
|
||||||
return sorted(final_holdings), factors, bond_momentum
|
# Compute position weights via configured scheme
|
||||||
|
self._position_weights = compute_position_weights(
|
||||||
|
ranked_holdings, self.weight_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return sorted(ranked_holdings), factors, bond_momentum
|
||||||
|
|
||||||
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
|
def _get_etf_prices(self, trade_code: str, date: pd.Timestamp) -> Optional[dict]:
|
||||||
"""Get ETF prices on a given date: {open, close, prev_close}
|
"""Get ETF prices on a given date: {open, close, prev_close}
|
||||||
@@ -566,52 +623,64 @@ class SimpleRotationStrategy:
|
|||||||
'prev_close': prev_close,
|
'prev_close': prev_close,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_weight(self, code: str, n_unique: int) -> float:
|
||||||
|
"""Get position weight for a code. Falls back to equal weight if not available."""
|
||||||
|
if code in self._position_weights:
|
||||||
|
return self._position_weights[code]
|
||||||
|
return 1.0 / n_unique if n_unique > 0 else 0.0
|
||||||
|
|
||||||
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
|
def _calculate_daily_return(self, old_holdings, new_holdings, date, is_rebalance):
|
||||||
"""
|
"""
|
||||||
Compute daily return (T+1 execution):
|
Compute daily return (T+1 execution) with configurable position weighting:
|
||||||
- Hold: close-to-close
|
- Hold: close-to-close, weighted by today's position weight
|
||||||
- Sell: close-to-open (sold at open)
|
- Sell: close-to-open (sold at open), weighted by today's position weight
|
||||||
- Buy: open-to-close (intraday)
|
- Buy: open-to-close (intraday), weighted by today's position weight
|
||||||
"""
|
"""
|
||||||
if not old_holdings:
|
if not old_holdings:
|
||||||
if not new_holdings:
|
if not new_holdings:
|
||||||
return 0.0
|
return 0.0
|
||||||
weight = 1.0 / len(new_holdings)
|
|
||||||
ret = 0.0
|
ret = 0.0
|
||||||
|
seen = set()
|
||||||
|
n_unique = len(set(new_holdings))
|
||||||
for code in new_holdings:
|
for code in new_holdings:
|
||||||
|
if code in seen:
|
||||||
|
continue
|
||||||
|
seen.add(code)
|
||||||
tc = self.signal_to_trade.get(code, code)
|
tc = self.signal_to_trade.get(code, code)
|
||||||
p = self._get_etf_prices(tc, date)
|
p = self._get_etf_prices(tc, date)
|
||||||
|
w = self._get_weight(code, n_unique)
|
||||||
if p and p['open'] > 0:
|
if p and p['open'] > 0:
|
||||||
ret += weight * (p['close'] - p['open']) / p['open']
|
ret += w * (p['close'] - p['open']) / p['open']
|
||||||
if is_rebalance:
|
if is_rebalance:
|
||||||
ret -= self.trade_cost
|
ret -= self.trade_cost
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
old_set = set(old_holdings)
|
old_set = set(old_holdings)
|
||||||
new_set = set(new_holdings)
|
new_set = set(new_holdings)
|
||||||
weight = 1.0 / len(old_holdings)
|
n_unique = len(old_set)
|
||||||
daily_return = 0.0
|
daily_return = 0.0
|
||||||
|
|
||||||
for code in old_holdings:
|
for code in old_set:
|
||||||
tc = self.signal_to_trade.get(code, code)
|
tc = self.signal_to_trade.get(code, code)
|
||||||
p = self._get_etf_prices(tc, date)
|
p = self._get_etf_prices(tc, date)
|
||||||
if p is None or p['prev_close'] == 0:
|
if p is None or p['prev_close'] == 0:
|
||||||
continue
|
continue
|
||||||
|
w = self._get_weight(code, n_unique)
|
||||||
if code in new_set:
|
if code in new_set:
|
||||||
r = (p['close'] - p['prev_close']) / p['prev_close']
|
r = (p['close'] - p['prev_close']) / p['prev_close']
|
||||||
else:
|
else:
|
||||||
r = (p['open'] - p['prev_close']) / p['prev_close']
|
r = (p['open'] - p['prev_close']) / p['prev_close']
|
||||||
if not math.isnan(r):
|
if not math.isnan(r):
|
||||||
daily_return += weight * r
|
daily_return += w * r
|
||||||
|
|
||||||
for code in new_holdings:
|
for code in new_set - old_set:
|
||||||
if code not in old_set:
|
|
||||||
tc = self.signal_to_trade.get(code, code)
|
tc = self.signal_to_trade.get(code, code)
|
||||||
p = self._get_etf_prices(tc, date)
|
p = self._get_etf_prices(tc, date)
|
||||||
|
w = self._get_weight(code, len(new_set))
|
||||||
if p and p['open'] > 0 and not math.isnan(p['close']):
|
if p and p['open'] > 0 and not math.isnan(p['close']):
|
||||||
r = (p['close'] - p['open']) / p['open']
|
r = (p['close'] - p['open']) / p['open']
|
||||||
if not math.isnan(r):
|
if not math.isnan(r):
|
||||||
daily_return += weight * r
|
daily_return += w * r
|
||||||
|
|
||||||
if is_rebalance:
|
if is_rebalance:
|
||||||
daily_return -= self.trade_cost
|
daily_return -= self.trade_cost
|
||||||
|
|||||||
Reference in New Issue
Block a user