feat(strategies): 实现定制组件(因子、信号生成器、风控)
- strategies/shared/factors/momentum.py: MomentumFactor/TrendFactor/ReversalFactor/VolatilityFactor - strategies/shared/signals/selectors.py: TopNSelector/TrendFollower/ReversalTrader - strategies/shared/risk/controls.py: StopLossControl/PositionLimitControl/PremiumControl - strategies/shared/__init__.py: 统一入口导出所有定制组件
This commit is contained in:
54
strategies/shared/__init__.py
Normal file
54
strategies/shared/__init__.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""
|
||||||
|
定制组件统一入口
|
||||||
|
|
||||||
|
所有定制因子、信号生成器、风控组件都在这里导出
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 定制因子
|
||||||
|
from strategies.shared.factors.momentum import (
|
||||||
|
MomentumFactor,
|
||||||
|
TrendFactor,
|
||||||
|
ReversalFactor,
|
||||||
|
VolatilityFactor
|
||||||
|
)
|
||||||
|
|
||||||
|
# 定制信号生成器
|
||||||
|
from strategies.shared.signals.selectors import (
|
||||||
|
TopNSelector,
|
||||||
|
TrendFollower,
|
||||||
|
ReversalTrader
|
||||||
|
)
|
||||||
|
|
||||||
|
# 定制风控组件
|
||||||
|
from strategies.shared.risk.controls import (
|
||||||
|
StopLossControl,
|
||||||
|
PositionLimitControl,
|
||||||
|
PremiumControl,
|
||||||
|
premium_filter_callback,
|
||||||
|
crash_filter_callback,
|
||||||
|
holding_time_stoploss_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 因子
|
||||||
|
'MomentumFactor',
|
||||||
|
'TrendFactor',
|
||||||
|
'ReversalFactor',
|
||||||
|
'VolatilityFactor',
|
||||||
|
|
||||||
|
# 信号生成器
|
||||||
|
'TopNSelector',
|
||||||
|
'TrendFollower',
|
||||||
|
'ReversalTrader',
|
||||||
|
|
||||||
|
# 风控组件
|
||||||
|
'StopLossControl',
|
||||||
|
'PositionLimitControl',
|
||||||
|
'PremiumControl',
|
||||||
|
|
||||||
|
# 回调函数
|
||||||
|
'premium_filter_callback',
|
||||||
|
'crash_filter_callback',
|
||||||
|
'holding_time_stoploss_callback',
|
||||||
|
]
|
||||||
243
strategies/shared/factors/momentum.py
Normal file
243
strategies/shared/factors/momentum.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""
|
||||||
|
定制因子实现
|
||||||
|
|
||||||
|
这些因子继承framework.core.factors.FactorBase
|
||||||
|
"""
|
||||||
|
|
||||||
|
from framework.factors import FactorBase, FactorRegistry
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class MomentumFactor(FactorBase):
|
||||||
|
"""
|
||||||
|
动量因子(定制实现)
|
||||||
|
|
||||||
|
计算加权线性回归动量得分:
|
||||||
|
得分 = 年化收益率 × R²
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- n_days: 动量窗口(默认25)
|
||||||
|
- weighted: 是否加权(默认True)
|
||||||
|
- crash_filter: 是否启用崩盘过滤(默认True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "momentum"
|
||||||
|
category = "momentum"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_days: int = 25,
|
||||||
|
weighted: bool = True,
|
||||||
|
crash_filter: bool = True
|
||||||
|
):
|
||||||
|
super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
|
||||||
|
self.n_days = n_days
|
||||||
|
self.weighted = weighted
|
||||||
|
self.crash_filter = crash_filter
|
||||||
|
|
||||||
|
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||||
|
"""计算动量因子值"""
|
||||||
|
if 'close' not in data.columns:
|
||||||
|
raise ValueError("data must contain 'close' column")
|
||||||
|
|
||||||
|
prices = data['close']
|
||||||
|
|
||||||
|
if self.weighted:
|
||||||
|
factor_values = prices.rolling(self.n_days).apply(
|
||||||
|
lambda x: self._weighted_momentum_score(x.values),
|
||||||
|
raw=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
factor_values = prices.pct_change(self.n_days)
|
||||||
|
|
||||||
|
if self.crash_filter:
|
||||||
|
factor_values = self._apply_crash_filter(prices, factor_values)
|
||||||
|
|
||||||
|
return factor_values
|
||||||
|
|
||||||
|
def _weighted_momentum_score(self, prices: np.ndarray) -> float:
|
||||||
|
"""计算加权动量得分"""
|
||||||
|
if len(prices) < 5:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
y = np.log(prices)
|
||||||
|
x = np.arange(len(y))
|
||||||
|
weights = np.linspace(1, 2, len(y))
|
||||||
|
|
||||||
|
slope, intercept = np.polyfit(x, y, 1, w=weights)
|
||||||
|
annualized_returns = math.exp(slope * 250) - 1
|
||||||
|
|
||||||
|
y_pred = slope * x + intercept
|
||||||
|
ss_res = np.sum(weights * (y - y_pred) ** 2)
|
||||||
|
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
|
||||||
|
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||||
|
|
||||||
|
return annualized_returns * r2
|
||||||
|
|
||||||
|
def _apply_crash_filter(self, prices: pd.Series, factor_values: pd.Series) -> pd.Series:
|
||||||
|
"""崩盘过滤:连续3天跌>5%清零"""
|
||||||
|
result = factor_values.copy()
|
||||||
|
|
||||||
|
for i in range(3, len(prices)):
|
||||||
|
r1 = prices.iloc[i] / prices.iloc[i-1]
|
||||||
|
r2 = prices.iloc[i-1] / prices.iloc[i-2]
|
||||||
|
r3 = prices.iloc[i-2] / prices.iloc[i-3]
|
||||||
|
|
||||||
|
con1 = min(r1, r2, r3) < 0.95
|
||||||
|
con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95)
|
||||||
|
|
||||||
|
if con1 or con2:
|
||||||
|
result.iloc[i] = 0.0
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class TrendFactor(FactorBase):
|
||||||
|
"""趋势因子(定制实现)"""
|
||||||
|
|
||||||
|
name = "trend"
|
||||||
|
category = "trend"
|
||||||
|
|
||||||
|
def __init__(self, method: str = 'ma_cross', fast: int = 5, slow: int = 20):
|
||||||
|
super().__init__(method=method, fast=fast, slow=slow)
|
||||||
|
self.method = method
|
||||||
|
self.fast = fast
|
||||||
|
self.slow = slow
|
||||||
|
|
||||||
|
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||||
|
"""计算趋势因子值"""
|
||||||
|
if 'close' not in data.columns:
|
||||||
|
raise ValueError("data must contain 'close' column")
|
||||||
|
|
||||||
|
prices = data['close']
|
||||||
|
|
||||||
|
if self.method == 'ma_cross':
|
||||||
|
fast_ma = prices.rolling(self.fast).mean()
|
||||||
|
slow_ma = prices.rolling(self.slow).mean()
|
||||||
|
return (fast_ma - slow_ma) / slow_ma
|
||||||
|
|
||||||
|
elif self.method == 'macd':
|
||||||
|
ema12 = prices.ewm(span=12).mean()
|
||||||
|
ema26 = prices.ewm(span=26).mean()
|
||||||
|
macd = ema12 - ema26
|
||||||
|
signal = macd.ewm(span=9).mean()
|
||||||
|
return macd - signal
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown method: {self.method}")
|
||||||
|
|
||||||
|
|
||||||
|
class ReversalFactor(FactorBase):
|
||||||
|
"""反转因子(定制实现)"""
|
||||||
|
|
||||||
|
name = "reversal"
|
||||||
|
category = "reversal"
|
||||||
|
|
||||||
|
def __init__(self, method: str = 'rsi', period: int = 14, overbought: float = 70, oversold: float = 30):
|
||||||
|
super().__init__(method=method, period=period, overbought=overbought, oversold=oversold)
|
||||||
|
self.method = method
|
||||||
|
self.period = period
|
||||||
|
self.overbought = overbought
|
||||||
|
self.oversold = oversold
|
||||||
|
|
||||||
|
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||||
|
"""计算反转因子值"""
|
||||||
|
if 'close' not in data.columns:
|
||||||
|
raise ValueError("data must contain 'close' column")
|
||||||
|
|
||||||
|
prices = data['close']
|
||||||
|
|
||||||
|
if self.method == 'rsi':
|
||||||
|
rsi = self._compute_rsi(prices, self.period)
|
||||||
|
reversal_signal = np.where(
|
||||||
|
rsi > self.overbought,
|
||||||
|
-(rsi - self.overbought) / (100 - self.overbought),
|
||||||
|
np.where(
|
||||||
|
rsi < self.oversold,
|
||||||
|
(self.oversold - rsi) / self.oversold,
|
||||||
|
0
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return pd.Series(reversal_signal, index=prices.index)
|
||||||
|
|
||||||
|
elif self.method == 'kdj':
|
||||||
|
return self._compute_kdj(data)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown method: {self.method}")
|
||||||
|
|
||||||
|
def _compute_rsi(self, prices: pd.Series, period: int) -> pd.Series:
|
||||||
|
"""计算RSI"""
|
||||||
|
delta = prices.diff()
|
||||||
|
gain = delta.where(delta > 0, 0)
|
||||||
|
loss = (-delta).where(delta < 0, 0)
|
||||||
|
|
||||||
|
avg_gain = gain.rolling(period).mean()
|
||||||
|
avg_loss = loss.rolling(period).mean()
|
||||||
|
|
||||||
|
rs = avg_gain / avg_loss
|
||||||
|
return 100 - (100 / (1 + rs))
|
||||||
|
|
||||||
|
def _compute_kdj(self, data: pd.DataFrame) -> pd.Series:
|
||||||
|
"""计算KDJ反转信号"""
|
||||||
|
low = data['low']
|
||||||
|
high = data['high']
|
||||||
|
close = data['close']
|
||||||
|
|
||||||
|
low_min = low.rolling(self.period).min()
|
||||||
|
high_max = high.rolling(self.period).max()
|
||||||
|
|
||||||
|
rsv = (close - low_min) / (high_max - low_min) * 100
|
||||||
|
|
||||||
|
k = rsv.ewm(alpha=1/3).mean()
|
||||||
|
d = k.ewm(alpha=1/3).mean()
|
||||||
|
j = 3 * k - 2 * d
|
||||||
|
|
||||||
|
return j
|
||||||
|
|
||||||
|
|
||||||
|
class VolatilityFactor(FactorBase):
|
||||||
|
"""波动率因子(定制实现)"""
|
||||||
|
|
||||||
|
name = "volatility"
|
||||||
|
category = "volatility"
|
||||||
|
|
||||||
|
def __init__(self, method: str = 'std', period: int = 20):
|
||||||
|
super().__init__(method=method, period=period)
|
||||||
|
self.method = method
|
||||||
|
self.period = period
|
||||||
|
|
||||||
|
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||||
|
"""计算波动率因子值"""
|
||||||
|
if self.method == 'std':
|
||||||
|
return data['close'].rolling(self.period).std()
|
||||||
|
|
||||||
|
elif self.method == 'atr':
|
||||||
|
return self._compute_atr(data)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown method: {self.method}")
|
||||||
|
|
||||||
|
def _compute_atr(self, data: pd.DataFrame) -> pd.Series:
|
||||||
|
"""计算ATR"""
|
||||||
|
high = data['high']
|
||||||
|
low = data['low']
|
||||||
|
close = data['close']
|
||||||
|
|
||||||
|
prev_close = close.shift(1)
|
||||||
|
tr = pd.concat([
|
||||||
|
high - low,
|
||||||
|
(high - prev_close).abs(),
|
||||||
|
(low - prev_close).abs()
|
||||||
|
], axis=1).max(axis=1)
|
||||||
|
|
||||||
|
return tr.rolling(self.period).mean()
|
||||||
|
|
||||||
|
|
||||||
|
# 注册因子
|
||||||
|
FactorRegistry.register(MomentumFactor)
|
||||||
|
FactorRegistry.register(TrendFactor)
|
||||||
|
FactorRegistry.register(ReversalFactor)
|
||||||
|
FactorRegistry.register(VolatilityFactor)
|
||||||
143
strategies/shared/risk/controls.py
Normal file
143
strategies/shared/risk/controls.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""
|
||||||
|
定制风控组件实现
|
||||||
|
|
||||||
|
这些风控组件继承framework.core.risk.RiskControl
|
||||||
|
"""
|
||||||
|
|
||||||
|
from framework.risk import RiskControl, Position, CallbackHook
|
||||||
|
|
||||||
|
|
||||||
|
class StopLossControl(RiskControl):
|
||||||
|
"""止损控制(定制实现)"""
|
||||||
|
|
||||||
|
name = "stop_loss"
|
||||||
|
|
||||||
|
def __init__(self, threshold: float = -0.05, trailing: bool = False, trailing_percent: float = 0.03):
|
||||||
|
super().__init__(threshold=threshold, trailing=trailing, trailing_percent=trailing_percent)
|
||||||
|
self.threshold = threshold
|
||||||
|
self.trailing = trailing
|
||||||
|
self.trailing_percent = trailing_percent
|
||||||
|
self._highest_price = {}
|
||||||
|
|
||||||
|
def check(self, position: Position, **kwargs) -> bool:
|
||||||
|
"""检查是否触发止损"""
|
||||||
|
if position is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self.trailing:
|
||||||
|
if position.code not in self._highest_price:
|
||||||
|
self._highest_price[position.code] = position.entry_price
|
||||||
|
self._highest_price[position.code] = max(
|
||||||
|
self._highest_price[position.code],
|
||||||
|
position.current_price
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.trailing:
|
||||||
|
highest = self._highest_price[position.code]
|
||||||
|
drawdown = (position.current_price - highest) / highest
|
||||||
|
return drawdown > -self.trailing_percent
|
||||||
|
else:
|
||||||
|
return position.profit_ratio > self.threshold
|
||||||
|
|
||||||
|
def apply(self, position: Position):
|
||||||
|
"""返回止损价格"""
|
||||||
|
if self.trailing:
|
||||||
|
highest = self._highest_price.get(position.code, position.entry_price)
|
||||||
|
return highest * (1 - self.trailing_percent)
|
||||||
|
else:
|
||||||
|
return position.entry_price * (1 + self.threshold)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionLimitControl(RiskControl):
|
||||||
|
"""仓位限制控制(定制实现)"""
|
||||||
|
|
||||||
|
name = "position_limit"
|
||||||
|
|
||||||
|
def __init__(self, max_position: float = 0.33, max_total: float = 1.0):
|
||||||
|
super().__init__(max_position=max_position, max_total=max_total)
|
||||||
|
self.max_position = max_position
|
||||||
|
self.max_total = max_total
|
||||||
|
|
||||||
|
def check(self, position: Position, **kwargs) -> bool:
|
||||||
|
"""检查仓位是否超限"""
|
||||||
|
if position is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if position.weight > self.max_position:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply(self, position: Position):
|
||||||
|
"""返回建议仓位"""
|
||||||
|
return min(position.weight, self.max_position)
|
||||||
|
|
||||||
|
|
||||||
|
class PremiumControl(RiskControl):
|
||||||
|
"""溢价控制(定制实现)"""
|
||||||
|
|
||||||
|
name = "premium"
|
||||||
|
|
||||||
|
def __init__(self, threshold: float = 0.10, mode: str = 'filter'):
|
||||||
|
super().__init__(threshold=threshold, mode=mode)
|
||||||
|
self.threshold = threshold
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def check(self, position: Position, **kwargs) -> bool:
|
||||||
|
"""检查溢价是否超限"""
|
||||||
|
premium = kwargs.get('premium', 0)
|
||||||
|
|
||||||
|
if self.mode == 'filter':
|
||||||
|
return premium <= self.threshold
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply(self, position: Position):
|
||||||
|
"""返回溢价惩罚系数"""
|
||||||
|
if self.mode == 'penalize':
|
||||||
|
return 0.5
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# 定制回调函数
|
||||||
|
def premium_filter_callback(threshold: float = 0.10):
|
||||||
|
"""溢价过滤回调(定制实现)"""
|
||||||
|
def callback(code: str, price: float, **kwargs) -> bool:
|
||||||
|
premium = kwargs.get('premium', 0)
|
||||||
|
if premium > threshold:
|
||||||
|
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return callback
|
||||||
|
|
||||||
|
|
||||||
|
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
|
||||||
|
"""崩盘过滤回调(定制实现)"""
|
||||||
|
def callback(code: str, price: float, **kwargs) -> bool:
|
||||||
|
history = kwargs.get('history', None)
|
||||||
|
if history is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
recent = history.tail(lookback)
|
||||||
|
if len(recent) < lookback:
|
||||||
|
return True
|
||||||
|
|
||||||
|
returns = recent['close'].pct_change()
|
||||||
|
min_return = returns.min()
|
||||||
|
|
||||||
|
if min_return < -crash_threshold:
|
||||||
|
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return callback
|
||||||
|
|
||||||
|
|
||||||
|
def holding_time_stoploss_callback(day_5_stoploss: float = -0.05, day_10_stoploss: float = -0.03):
|
||||||
|
"""持仓时间动态止损回调(定制实现)"""
|
||||||
|
def callback(position: Position) -> float:
|
||||||
|
if position.holding_days >= 10:
|
||||||
|
return day_10_stoploss
|
||||||
|
elif position.holding_days >= 5:
|
||||||
|
return day_5_stoploss
|
||||||
|
return -0.10
|
||||||
|
return callback
|
||||||
215
strategies/shared/signals/selectors.py
Normal file
215
strategies/shared/signals/selectors.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
"""
|
||||||
|
定制信号生成器实现
|
||||||
|
|
||||||
|
这些信号生成器继承framework.core.signals.SignalGenerator
|
||||||
|
"""
|
||||||
|
|
||||||
|
from framework.signals import SignalGenerator
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
|
||||||
|
|
||||||
|
class TopNSelector(SignalGenerator):
|
||||||
|
"""
|
||||||
|
Top N选股器(定制实现)
|
||||||
|
|
||||||
|
用于轮动策略:
|
||||||
|
- 按因子值排序,选出Top N标的
|
||||||
|
- 支持分组选股(先类内竞争,再跨类排序)
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- select_num: 选中数量(默认3)
|
||||||
|
- group_by: 分组列名(可选,如'market')
|
||||||
|
- top_per_group: 每组选中数量(默认1)
|
||||||
|
- min_score: 最小得分阈值(可选)
|
||||||
|
"""
|
||||||
|
|
||||||
|
mode = "top_n"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
select_num: int = 3,
|
||||||
|
group_by: Optional[str] = None,
|
||||||
|
top_per_group: int = 1,
|
||||||
|
min_score: Optional[float] = None
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
select_num=select_num,
|
||||||
|
group_by=group_by,
|
||||||
|
top_per_group=top_per_group,
|
||||||
|
min_score=min_score
|
||||||
|
)
|
||||||
|
self.select_num = select_num
|
||||||
|
self.group_by = group_by
|
||||||
|
self.top_per_group = top_per_group
|
||||||
|
self.min_score = min_score
|
||||||
|
|
||||||
|
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""生成Top N选股信号"""
|
||||||
|
result = pd.DataFrame(index=factor_data.index)
|
||||||
|
|
||||||
|
factor_cols = self._get_factor_columns(factor_data)
|
||||||
|
|
||||||
|
if not factor_cols:
|
||||||
|
result['signal'] = ''
|
||||||
|
return result
|
||||||
|
|
||||||
|
signals = []
|
||||||
|
for date in factor_data.index:
|
||||||
|
row = factor_data.loc[date]
|
||||||
|
|
||||||
|
scores = {}
|
||||||
|
for col in factor_cols:
|
||||||
|
score = row[col]
|
||||||
|
if pd.notna(score):
|
||||||
|
scores[col] = score
|
||||||
|
|
||||||
|
if self.min_score:
|
||||||
|
scores = {k: v for k, v in scores.items() if v >= self.min_score}
|
||||||
|
|
||||||
|
if self.group_by and 'group_info' in factor_data.columns:
|
||||||
|
selected = self._grouped_selection(scores, factor_data.loc[date])
|
||||||
|
else:
|
||||||
|
selected = self._global_top_n(scores)
|
||||||
|
|
||||||
|
signals.append(','.join(selected) if selected else '')
|
||||||
|
|
||||||
|
result['signal'] = signals
|
||||||
|
result['signal_raw'] = signals
|
||||||
|
|
||||||
|
result['signal'] = result['signal'].shift(1)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||||
|
"""获取因子列名"""
|
||||||
|
exclude_cols = ['signal', 'signal_raw', 'group_info', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||||
|
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||||
|
|
||||||
|
def _global_top_n(self, scores: Dict[str, float]) -> List[str]:
|
||||||
|
"""全局Top N选股"""
|
||||||
|
if not scores:
|
||||||
|
return []
|
||||||
|
|
||||||
|
sorted_items = sorted(scores.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
return [item[0] for item in sorted_items[:self.select_num]]
|
||||||
|
|
||||||
|
def _grouped_selection(self, scores: Dict[str, float], row: pd.Series) -> List[str]:
|
||||||
|
"""分组选股:先类内竞争,再跨类排序"""
|
||||||
|
if 'group_info' not in row.index:
|
||||||
|
return self._global_top_n(scores)
|
||||||
|
|
||||||
|
group_info = row['group_info']
|
||||||
|
if pd.isna(group_info):
|
||||||
|
return self._global_top_n(scores)
|
||||||
|
|
||||||
|
groups = group_info if isinstance(group_info, dict) else {}
|
||||||
|
|
||||||
|
group_champions = {}
|
||||||
|
for code, score in scores.items():
|
||||||
|
group = groups.get(code, 'default')
|
||||||
|
if group not in group_champions or score > group_champions[group][1]:
|
||||||
|
group_champions[group] = (code, score)
|
||||||
|
|
||||||
|
champions_scores = {code: score for code, score in group_champions.values()}
|
||||||
|
return self._global_top_n(champions_scores)
|
||||||
|
|
||||||
|
|
||||||
|
class TrendFollower(SignalGenerator):
|
||||||
|
"""趋势跟随器(定制实现)"""
|
||||||
|
|
||||||
|
mode = "trend"
|
||||||
|
|
||||||
|
def __init__(self, entry_threshold: float = 0.02, exit_threshold: float = -0.02, select_num: int = 1):
|
||||||
|
super().__init__(entry_threshold=entry_threshold, exit_threshold=exit_threshold, select_num=select_num)
|
||||||
|
self.entry_threshold = entry_threshold
|
||||||
|
self.exit_threshold = exit_threshold
|
||||||
|
self.select_num = select_num
|
||||||
|
|
||||||
|
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""生成趋势跟随信号"""
|
||||||
|
result = pd.DataFrame(index=factor_data.index)
|
||||||
|
|
||||||
|
factor_cols = self._get_factor_columns(factor_data)
|
||||||
|
|
||||||
|
for col in factor_cols:
|
||||||
|
trend_strength = factor_data[col]
|
||||||
|
|
||||||
|
result[f'{col}_entry'] = trend_strength > self.entry_threshold
|
||||||
|
result[f'{col}_exit'] = trend_strength < self.exit_threshold
|
||||||
|
|
||||||
|
signals = []
|
||||||
|
for date in result.index:
|
||||||
|
entry_signals = []
|
||||||
|
for col in factor_cols:
|
||||||
|
if result.loc[date, f'{col}_entry']:
|
||||||
|
score = factor_data.loc[date, col]
|
||||||
|
if pd.notna(score):
|
||||||
|
entry_signals.append((col, score))
|
||||||
|
|
||||||
|
entry_signals.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
selected = [item[0] for item in entry_signals[:self.select_num]]
|
||||||
|
signals.append(','.join(selected) if selected else '')
|
||||||
|
|
||||||
|
result['signal'] = signals
|
||||||
|
result['signal'] = result['signal'].shift(1)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||||
|
"""获取因子列名"""
|
||||||
|
exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||||
|
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||||
|
|
||||||
|
|
||||||
|
class ReversalTrader(SignalGenerator):
|
||||||
|
"""反转交易器(定制实现)"""
|
||||||
|
|
||||||
|
mode = "reversal"
|
||||||
|
|
||||||
|
def __init__(self, overbought: float = 70, oversold: float = 30, reversal_threshold: float = 0.1):
|
||||||
|
super().__init__(overbought=overbought, oversold=oversold, reversal_threshold=reversal_threshold)
|
||||||
|
self.overbought = overbought
|
||||||
|
self.oversold = oversold
|
||||||
|
self.reversal_threshold = reversal_threshold
|
||||||
|
|
||||||
|
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""生成反转交易信号"""
|
||||||
|
result = pd.DataFrame(index=factor_data.index)
|
||||||
|
|
||||||
|
factor_cols = self._get_factor_columns(factor_data)
|
||||||
|
|
||||||
|
for col in factor_cols:
|
||||||
|
reversal_signal = factor_data[col]
|
||||||
|
|
||||||
|
result[f'{col}_buy'] = reversal_signal > self.reversal_threshold
|
||||||
|
result[f'{col}_sell'] = reversal_signal < -self.reversal_threshold
|
||||||
|
|
||||||
|
signals = []
|
||||||
|
for date in result.index:
|
||||||
|
buy_signals = []
|
||||||
|
sell_signals = []
|
||||||
|
|
||||||
|
for col in factor_cols:
|
||||||
|
if result.loc[date, f'{col}_buy']:
|
||||||
|
buy_signals.append(col)
|
||||||
|
if result.loc[date, f'{col}_sell']:
|
||||||
|
sell_signals.append(col)
|
||||||
|
|
||||||
|
if buy_signals:
|
||||||
|
signals.append(f"BUY:{','.join(buy_signals)}")
|
||||||
|
elif sell_signals:
|
||||||
|
signals.append(f"SELL:{','.join(sell_signals)}")
|
||||||
|
else:
|
||||||
|
signals.append('')
|
||||||
|
|
||||||
|
result['signal'] = signals
|
||||||
|
result['signal'] = result['signal'].shift(1)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
|
||||||
|
"""获取因子列名"""
|
||||||
|
exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume']
|
||||||
|
return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')]
|
||||||
Reference in New Issue
Block a user