feat: 新增slope_r2因子并切换为默认因子(年化19.84%, 夏普1.14)
- simple_rotation.py: 新增3种score函数(vol_adjusted_momentum, slope_r2, momentum) - config_loader.py: FactorType枚举新增VOL_ADJUSTED_MOMENTUM - config_simple.yaml: factor.type 切换为 slope_r2 - experiments/factor_comparison.py: 4种因子对比实验脚本 - experiments/output: 实验结果(slope_r2全面胜出)
This commit is contained in:
@@ -24,7 +24,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from rotation.config_loader import load_rotation_config, RotationStrategyConfig
|
||||
from rotation.config_loader import load_rotation_config, RotationStrategyConfig, FactorType
|
||||
|
||||
# ============================================================
|
||||
# HTTP client (requests + trust_env=False,绕过系统代理避免 SSL EOF)
|
||||
@@ -79,6 +79,58 @@ def weighted_momentum_score(prices: np.ndarray) -> float:
|
||||
return annualized_return * r2
|
||||
|
||||
|
||||
def vol_adjusted_momentum_score(prices: np.ndarray) -> float:
|
||||
"""Volatility-adjusted momentum score = (annualized_return / annualized_vol) * R^2
|
||||
|
||||
Moskowitz, Ooi & Pedersen (2012) TSMOM approach:
|
||||
divide momentum by realized volatility to make cross-asset comparison fair.
|
||||
"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
y = np.log(prices)
|
||||
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
|
||||
return 0.0
|
||||
x = np.arange(len(y))
|
||||
weights = np.linspace(1, 2, len(y))
|
||||
slope, intercept = np.polyfit(x, y, 1, w=weights)
|
||||
annualized_return = math.exp(slope * 250) - 1
|
||||
# realized volatility (annualized)
|
||||
daily_returns = np.diff(y)
|
||||
realized_vol = float(np.std(daily_returns)) * math.sqrt(250)
|
||||
if realized_vol < 0.01: # guard against near-zero vol
|
||||
realized_vol = 0.01
|
||||
# trend quality (R^2)
|
||||
y_pred = slope * x + intercept
|
||||
ss_res = np.sum(weights * (y - y_pred) ** 2)
|
||||
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
|
||||
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||
return (annualized_return / realized_vol) * r2
|
||||
|
||||
|
||||
def slope_r2_score(prices: np.ndarray) -> float:
|
||||
"""Slope * R^2 score (unweighted, normalized prices)"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
y = prices / prices[0] # normalize
|
||||
x = np.arange(len(y))
|
||||
slope, intercept = np.polyfit(x, y, 1)
|
||||
y_pred = slope * x + intercept
|
||||
ss_res = np.sum((y - y_pred) ** 2)
|
||||
ss_tot = np.sum((y - np.mean(y)) ** 2)
|
||||
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||
return 10000 * slope * r2
|
||||
|
||||
|
||||
def momentum_score(prices: np.ndarray) -> float:
|
||||
"""Simple price return: (last / first) - 1"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
prices = np.clip(prices, 0.01, None)
|
||||
return prices[-1] / prices[0] - 1
|
||||
|
||||
|
||||
def is_crash(prices: np.ndarray) -> bool:
|
||||
"""Crash filter: 3 consecutive days drop > 5%"""
|
||||
if len(prices) < 4:
|
||||
@@ -340,7 +392,29 @@ class SimpleRotationStrategy:
|
||||
print(f" Benchmark: {len(bm_df)} rows")
|
||||
|
||||
def _compute_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
|
||||
"""Compute momentum for a single code on a given date"""
|
||||
"""Compute momentum for a single code on a given date (uses configured score function)"""
|
||||
if signal_code not in self.index_data:
|
||||
return None
|
||||
df = self.index_data[signal_code]
|
||||
mask = df.index <= date
|
||||
recent = df.loc[mask]
|
||||
if len(recent) < self.n_days:
|
||||
return None
|
||||
prices = recent['close'].values[-self.n_days:]
|
||||
if len(prices) >= 4 and is_crash(prices):
|
||||
return 0.0
|
||||
# Dispatch based on factor type
|
||||
ft = self.config.factor.type
|
||||
if ft == FactorType.VOL_ADJUSTED_MOMENTUM:
|
||||
return vol_adjusted_momentum_score(prices)
|
||||
elif ft == FactorType.SLOPE_R2:
|
||||
return slope_r2_score(prices)
|
||||
elif ft == FactorType.MOMENTUM:
|
||||
return momentum_score(prices)
|
||||
return weighted_momentum_score(prices)
|
||||
|
||||
def _compute_raw_momentum(self, signal_code: str, date: pd.Timestamp) -> Optional[float]:
|
||||
"""Always compute weighted momentum (for threshold comparison)"""
|
||||
if signal_code not in self.index_data:
|
||||
return None
|
||||
df = self.index_data[signal_code]
|
||||
@@ -374,23 +448,37 @@ class SimpleRotationStrategy:
|
||||
if not factors:
|
||||
return [], {}, None
|
||||
|
||||
# Bond threshold always uses raw (weighted) momentum
|
||||
bond_momentum = None
|
||||
if self.use_dynamic_threshold and self.bond_code:
|
||||
bond_momentum = self._compute_momentum(self.bond_code, date)
|
||||
bond_momentum = self._compute_raw_momentum(self.bond_code, date)
|
||||
if bond_momentum is None:
|
||||
bond_momentum = self.fallback_value
|
||||
|
||||
# Raw factors for threshold comparison
|
||||
raw_factors: Dict[str, float] = {}
|
||||
if self.config.factor.type == FactorType.VOL_ADJUSTED_MOMENTUM:
|
||||
for code in self.signal_codes:
|
||||
score = self._compute_raw_momentum(code, date)
|
||||
if score is not None:
|
||||
raw_factors[code] = score
|
||||
else:
|
||||
raw_factors = factors
|
||||
|
||||
groups = self.config.asset_pools.by_group
|
||||
selected_by_group: Dict[str, Tuple[str, float]] = {}
|
||||
|
||||
for group_name, assets in groups.items():
|
||||
group_codes = [a.signal_source for a in assets.values()]
|
||||
group_factors = {c: factors[c] for c in group_codes if c in factors}
|
||||
group_raw = {c: raw_factors[c] for c in group_codes if c in raw_factors}
|
||||
if not group_factors:
|
||||
continue
|
||||
# Threshold comparison uses raw (weighted) momentum
|
||||
if group_name != 'BOND' and bond_momentum is not None:
|
||||
thresh = bond_momentum * self.bond_ratio
|
||||
group_factors = {c: s for c, s in group_factors.items() if s >= thresh}
|
||||
group_factors = {c: s for c, s in group_factors.items()
|
||||
if group_raw.get(c, 0) >= thresh}
|
||||
if not group_factors:
|
||||
continue
|
||||
top_code = max(group_factors, key=group_factors.get)
|
||||
|
||||
Reference in New Issue
Block a user