Files
etf/strategies/shared/factors/momentum.py
aszerW 5212b004dc fix: 回测细节导出、交易日历测试和动量因子修复
修复项:
- export_backtest_detail.py: 统一回测导出脚本的数据源调用逻辑
- test_trading_calendar.py: 交易日历功能测试
- verify_fix_result.py: 修复结果验证
- verify_mode_b.py: 模式 B 验证

策略修复:
- momentum.py: 动量因子计算优化
- strategy.py: StrategyBase 数据获取修复(fetch_indices 返回 dict)
2026-05-24 14:26:35 +08:00

250 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Custom factor implementations.

These factors subclass ``FactorBase``, imported from ``framework.factors``.
"""
from framework.factors import FactorBase, FactorRegistry
import pandas as pd
import numpy as np
import math
class MomentumFactor(FactorBase):
    """Momentum factor (custom implementation).

    In weighted mode the score for each row is::

        score = annualized_return * r2

    Both terms come from a weighted linear regression of log close prices
    over the trailing ``n_days`` rows (see ``_weighted_momentum_score``).
    In unweighted mode the score is the plain ``n_days`` percentage change
    of close.

    Parameters:
        n_days: momentum window length, default 25.
        weighted: use the weighted-regression score (True, default) or the
            plain percentage change (False).
        crash_filter: zero the factor after sharp short-term drops, default
            True (see ``_apply_crash_filter``).
    """
    name = "momentum"
    category = "momentum"

    def __init__(
        self,
        n_days: int = 25,
        weighted: bool = True,
        crash_filter: bool = True
    ):
        super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
        self.n_days = n_days
        self.weighted = weighted
        self.crash_filter = crash_filter

    def compute(self, data: pd.DataFrame) -> pd.Series:
        """Compute the momentum factor for every row of ``data``.

        Args:
            data: Price frame; must contain a ``close`` column.

        Returns:
            Series aligned with ``data.index``. Rows without a full
            ``n_days`` window are NaN, unless the crash filter sets them
            to 0.0.

        Raises:
            ValueError: If ``data`` has no ``close`` column.
        """
        if 'close' not in data.columns:
            raise ValueError("data must contain 'close' column")
        prices = data['close']
        if self.weighted:
            # raw=True passes each window as a plain ndarray, avoiding a
            # per-window Series construction.
            factor_values = prices.rolling(self.n_days).apply(
                self._weighted_momentum_score, raw=True
            )
        else:
            factor_values = prices.pct_change(self.n_days)
        if self.crash_filter:
            factor_values = self._apply_crash_filter(prices, factor_values)
        return factor_values

    def _weighted_momentum_score(self, prices: np.ndarray) -> float:
        """Score one window of prices as annualized return times weighted R².

        Fits ``log(price) = slope * t + intercept`` by weighted least squares.
        Weights rise linearly from 1 (oldest) to 2 (newest), so recent prices
        count more.

        Args:
            prices: Closing prices of one rolling window, oldest first.

        Returns:
            ``(exp(slope * 250) - 1) * r2``. Returns 0.0 for windows shorter
            than 5 or windows that are non-finite after the log transform.
        """
        if len(prices) < 5:
            return 0.0
        # Lower-bound prices so log() never sees 0 or negative values.
        y = np.log(np.clip(prices, 0.01, None))
        if not np.all(np.isfinite(y)):
            return 0.0
        x = np.arange(len(y))
        weights = np.linspace(1, 2, len(y))
        # np.polyfit applies ``w`` to the *unsquared* residuals, i.e. it
        # minimises sum((w * r) ** 2). Passing sqrt(weights) makes it minimise
        # sum(weights * r ** 2), the same weighting the R² below uses. R² is
        # then the true weighted R² of this fit and stays within [0, 1].
        slope, intercept = np.polyfit(x, y, 1, w=np.sqrt(weights))
        annualized_returns = math.exp(slope * 250) - 1
        residuals = y - (slope * x + intercept)
        ss_res = np.sum(weights * residuals ** 2)
        ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
        r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
        return annualized_returns * r2

    def _apply_crash_filter(self, prices: pd.Series, factor_values: pd.Series) -> pd.Series:
        """Zero the factor on days that follow a sharp short-term drop.

        A row is flagged using its three most recent daily price ratios
        (today/yesterday, yesterday/day-before, day-before/day-3). It is
        flagged when either:

        * any single ratio is below 0.95 (a one-day drop of more than 5%), or
        * all three ratios are below 1 and the price fell more than 5% over
          those three days.

        Ratios that are NaN (first rows, missing prices) compare False. They
        never trigger a flag themselves, but they do not hide a real >5%
        one-day drop elsewhere in the window. Early rows with fewer than three
        prior ratios are checked against the ratios that exist.

        Args:
            prices: Close prices, positionally aligned with ``factor_values``.
            factor_values: Factor series to filter; it is not modified.

        Returns:
            A new Series equal to ``factor_values`` with flagged rows set to 0.0.
        """
        daily = prices / prices.shift(1)
        r1, r2, r3 = daily, daily.shift(1), daily.shift(2)
        one_day_crash = (r1 < 0.95) | (r2 < 0.95) | (r3 < 0.95)
        three_day_slide = (
            (r1 < 1) & (r2 < 1) & (r3 < 1) & (prices / prices.shift(3) < 0.95)
        )
        # Positional mask, matching the row-by-row alignment of the inputs.
        crashed = (one_day_crash | three_day_slide).to_numpy()
        return factor_values.mask(crashed, 0.0)
class TrendFactor(FactorBase):
    """Trend factor (custom implementation).

    Methods:
        'ma_cross': relative gap between the ``fast``-period and the
            ``slow``-period simple moving averages of close, computed as
            ``(fast_ma - slow_ma) / slow_ma``.
        'macd': MACD histogram, i.e. the MACD line (EMA12 - EMA26) minus its
            9-period EMA signal line. These spans are fixed; ``fast`` and
            ``slow`` are not used by this method.
    """
    name = "trend"
    category = "trend"

    def __init__(self, method: str = 'ma_cross', fast: int = 5, slow: int = 20):
        super().__init__(method=method, fast=fast, slow=slow)
        self.method = method
        self.fast = fast
        self.slow = slow

    def compute(self, data: pd.DataFrame) -> pd.Series:
        """Return the trend factor for each row of ``data``.

        Raises:
            ValueError: If ``data`` lacks a ``close`` column or ``method`` is
                neither 'ma_cross' nor 'macd'.
        """
        if 'close' not in data.columns:
            raise ValueError("data must contain 'close' column")
        close = data['close']

        if self.method == 'macd':
            macd_line = close.ewm(span=12).mean() - close.ewm(span=26).mean()
            signal_line = macd_line.ewm(span=9).mean()
            return macd_line - signal_line

        if self.method != 'ma_cross':
            raise ValueError(f"Unknown method: {self.method}")

        short_avg = close.rolling(self.fast).mean()
        long_avg = close.rolling(self.slow).mean()
        return (short_avg - long_avg) / long_avg
class ReversalFactor(FactorBase):
    """Reversal factor (custom implementation).

    Methods:
        'rsi': maps RSI into a reversal signal:
            * Above ``overbought`` it is negative, scaled by
              ``100 - overbought``.
            * Below ``oversold`` it is positive, scaled by ``oversold``.
            * In between it is 0.
            Rows where RSI is NaN (warm-up, or flat windows where both the
            average gain and average loss are 0) also come out as 0.
        'kdj': returns the raw KDJ ``J`` line; requires ``high`` and ``low``
            columns. NOTE(review): J is high when overbought, which is the
            opposite sign convention to the 'rsi' signal. Confirm that callers
            expect this.
    """
    name = "reversal"
    category = "reversal"

    def __init__(self, method: str = 'rsi', period: int = 14, overbought: float = 70, oversold: float = 30):
        super().__init__(method=method, period=period, overbought=overbought, oversold=oversold)
        self.method = method
        self.period = period
        self.overbought = overbought
        self.oversold = oversold

    def compute(self, data: pd.DataFrame) -> pd.Series:
        """Return the reversal factor for each row of ``data``.

        Raises:
            ValueError: If ``data`` lacks a ``close`` column or ``method`` is
                neither 'rsi' nor 'kdj'.
        """
        if 'close' not in data.columns:
            raise ValueError("data must contain 'close' column")
        close = data['close']

        if self.method == 'kdj':
            return self._compute_kdj(data)
        if self.method != 'rsi':
            raise ValueError(f"Unknown method: {self.method}")

        rsi = self._compute_rsi(close, self.period)
        overbought_signal = -(rsi - self.overbought) / (100 - self.overbought)
        oversold_signal = (self.oversold - rsi) / self.oversold
        # The first matching condition wins, so overbought takes precedence.
        signal = np.select(
            [rsi > self.overbought, rsi < self.oversold],
            [overbought_signal, oversold_signal],
            default=0,
        )
        return pd.Series(signal, index=close.index)

    def _compute_rsi(self, prices: pd.Series, period: int) -> pd.Series:
        """RSI using simple rolling means of gains and losses over ``period``."""
        change = prices.diff()
        # where(..., 0) also turns the leading NaN diff into 0.
        avg_up = change.where(change > 0, 0).rolling(period).mean()
        avg_down = (-change).where(change < 0, 0).rolling(period).mean()
        return 100 - (100 / (1 + avg_up / avg_down))

    def _compute_kdj(self, data: pd.DataFrame) -> pd.Series:
        """Return the KDJ ``J`` line (3K - 2D) over a ``self.period`` window."""
        window_low = data['low'].rolling(self.period).min()
        window_high = data['high'].rolling(self.period).max()
        close = data['close']
        rsv = (close - window_low) / (window_high - window_low) * 100
        smoothing = 1 / 3
        k_line = rsv.ewm(alpha=smoothing).mean()
        d_line = k_line.ewm(alpha=smoothing).mean()
        return 3 * k_line - 2 * d_line
class VolatilityFactor(FactorBase):
    """Volatility factor (custom implementation).

    Methods:
        'std': rolling standard deviation of ``close`` over ``period`` rows.
        'atr': average true range, i.e. the ``period``-row simple mean of the
            true range. Needs ``high``, ``low`` and ``close`` columns.
    """
    name = "volatility"
    category = "volatility"

    def __init__(self, method: str = 'std', period: int = 20):
        super().__init__(method=method, period=period)
        self.method = method
        self.period = period

    def compute(self, data: pd.DataFrame) -> pd.Series:
        """Return the volatility factor for each row of ``data``.

        Raises:
            ValueError: If ``method`` is neither 'std' nor 'atr'.
            KeyError: If a required price column is missing.
        """
        if self.method == 'atr':
            return self._compute_atr(data)
        if self.method != 'std':
            raise ValueError(f"Unknown method: {self.method}")
        return data['close'].rolling(self.period).std()

    def _compute_atr(self, data: pd.DataFrame) -> pd.Series:
        """Average true range over ``self.period`` rows."""
        high = data['high']
        low = data['low']
        previous_close = data['close'].shift(1)
        # The true range is the largest of the three candidate spans. In the
        # first row the previous close is NaN and the row-wise max skips it.
        spans = [
            high - low,
            (high - previous_close).abs(),
            (low - previous_close).abs(),
        ]
        true_range = pd.concat(spans, axis=1).max(axis=1)
        return true_range.rolling(self.period).mean()
# Register the custom factors with the framework registry when this module is
# imported (same order as declared above).
for _factor_cls in (MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor):
    FactorRegistry.register(_factor_cls)
del _factor_cls