From 69081297c58049954c700cbdafa0a36b9099ccf2 Mon Sep 17 00:00:00 2001 From: aszerW Date: Mon, 11 May 2026 23:09:35 +0800 Subject: [PATCH] =?UTF-8?q?feat(strategies):=20=E5=AE=9E=E7=8E=B0=E5=AE=9A?= =?UTF-8?q?=E5=88=B6=E7=BB=84=E4=BB=B6=EF=BC=88=E5=9B=A0=E5=AD=90=E3=80=81?= =?UTF-8?q?=E4=BF=A1=E5=8F=B7=E7=94=9F=E6=88=90=E5=99=A8=E3=80=81=E9=A3=8E?= =?UTF-8?q?=E6=8E=A7=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - strategies/shared/factors/momentum.py: MomentumFactor/TrendFactor/ReversalFactor/VolatilityFactor - strategies/shared/signals/selectors.py: TopNSelector/TrendFollower/ReversalTrader - strategies/shared/risk/controls.py: StopLossControl/PositionLimitControl/PremiumControl - strategies/shared/__init__.py: 统一入口导出所有定制组件 --- strategies/shared/__init__.py | 54 ++++++ strategies/shared/factors/momentum.py | 243 +++++++++++++++++++++++++ strategies/shared/risk/controls.py | 143 +++++++++++++++ strategies/shared/signals/selectors.py | 215 ++++++++++++++++++++++ 4 files changed, 655 insertions(+) create mode 100644 strategies/shared/__init__.py create mode 100644 strategies/shared/factors/momentum.py create mode 100644 strategies/shared/risk/controls.py create mode 100644 strategies/shared/signals/selectors.py diff --git a/strategies/shared/__init__.py b/strategies/shared/__init__.py new file mode 100644 index 0000000..fcc6ef3 --- /dev/null +++ b/strategies/shared/__init__.py @@ -0,0 +1,54 @@ +""" +定制组件统一入口 + +所有定制因子、信号生成器、风控组件都在这里导出 +""" + +# 定制因子 +from strategies.shared.factors.momentum import ( + MomentumFactor, + TrendFactor, + ReversalFactor, + VolatilityFactor +) + +# 定制信号生成器 +from strategies.shared.signals.selectors import ( + TopNSelector, + TrendFollower, + ReversalTrader +) + +# 定制风控组件 +from strategies.shared.risk.controls import ( + StopLossControl, + PositionLimitControl, + PremiumControl, + premium_filter_callback, + crash_filter_callback, + holding_time_stoploss_callback +) + + +__all__ = [ + # 因子 + 'MomentumFactor', + 'TrendFactor', + 'ReversalFactor', + 'VolatilityFactor', + + # 信号生成器 + 'TopNSelector', + 'TrendFollower', + 'ReversalTrader', + + # 风控组件 + 'StopLossControl', + 'PositionLimitControl', + 'PremiumControl', + + # 回调函数 + 'premium_filter_callback', + 'crash_filter_callback', + 'holding_time_stoploss_callback', +] \ No newline at end of file diff --git a/strategies/shared/factors/momentum.py b/strategies/shared/factors/momentum.py new file mode 100644 index 0000000..02fae44 --- /dev/null +++ b/strategies/shared/factors/momentum.py @@ -0,0 +1,243 @@ +""" +定制因子实现 + +这些因子继承framework.core.factors.FactorBase +""" + +from framework.factors import FactorBase, FactorRegistry +import pandas as pd +import numpy as np +import math + + +class MomentumFactor(FactorBase): + """ + 动量因子(定制实现) + + 计算加权线性回归动量得分: + 得分 = 年化收益率 × R² + + 参数: + - n_days: 动量窗口(默认25) + - weighted: 是否加权(默认True) + - crash_filter: 是否启用崩盘过滤(默认True) + """ + + name = "momentum" + category = "momentum" + + def __init__( + self, + n_days: int = 25, + weighted: bool = True, + crash_filter: bool = True + ): + super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter) + self.n_days = n_days + self.weighted = weighted + self.crash_filter = crash_filter + + def compute(self, data: pd.DataFrame) -> pd.Series: + """计算动量因子值""" + if 'close' not in data.columns: + raise ValueError("data must contain 'close' column") + + prices = data['close'] + + if self.weighted: + factor_values = prices.rolling(self.n_days).apply( + lambda x: self._weighted_momentum_score(x.values), + raw=False + ) + else: + factor_values = prices.pct_change(self.n_days) + + if self.crash_filter: + factor_values = self._apply_crash_filter(prices, factor_values) + + return factor_values + + def _weighted_momentum_score(self, prices: np.ndarray) -> float: + """计算加权动量得分""" + if len(prices) < 5: + return 0.0 + + y = np.log(prices) + x = np.arange(len(y)) + weights = np.linspace(1, 2, len(y)) + + slope, intercept = np.polyfit(x, y, 1, w=weights) + annualized_returns = math.exp(slope * 250) - 1 + + y_pred = slope * x + intercept + ss_res = np.sum(weights * (y - y_pred) ** 2) + ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2) + r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0 + + return annualized_returns * r2 + + def _apply_crash_filter(self, prices: pd.Series, factor_values: pd.Series) -> pd.Series: + """崩盘过滤:连续3天跌>5%清零""" + result = factor_values.copy() + + for i in range(3, len(prices)): + r1 = prices.iloc[i] / prices.iloc[i-1] + r2 = prices.iloc[i-1] / prices.iloc[i-2] + r3 = prices.iloc[i-2] / prices.iloc[i-3] + + con1 = min(r1, r2, r3) < 0.95 + con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95) + + if con1 or con2: + result.iloc[i] = 0.0 + + return result + + +class TrendFactor(FactorBase): + """趋势因子(定制实现)""" + + name = "trend" + category = "trend" + + def __init__(self, method: str = 'ma_cross', fast: int = 5, slow: int = 20): + super().__init__(method=method, fast=fast, slow=slow) + self.method = method + self.fast = fast + self.slow = slow + + def compute(self, data: pd.DataFrame) -> pd.Series: + """计算趋势因子值""" + if 'close' not in data.columns: + raise ValueError("data must contain 'close' column") + + prices = data['close'] + + if self.method == 'ma_cross': + fast_ma = prices.rolling(self.fast).mean() + slow_ma = prices.rolling(self.slow).mean() + return (fast_ma - slow_ma) / slow_ma + + elif self.method == 'macd': + ema12 = prices.ewm(span=12).mean() + ema26 = prices.ewm(span=26).mean() + macd = ema12 - ema26 + signal = macd.ewm(span=9).mean() + return macd - signal + + else: + raise ValueError(f"Unknown method: {self.method}") + + +class ReversalFactor(FactorBase): + """反转因子(定制实现)""" + + name = "reversal" + category = "reversal" + + def __init__(self, method: str = 'rsi', period: int = 14, overbought: float = 70, oversold: float = 30): + super().__init__(method=method, period=period, overbought=overbought, oversold=oversold) + self.method = method + self.period = period + self.overbought = overbought + self.oversold = oversold + + def compute(self, data: pd.DataFrame) -> pd.Series: + """计算反转因子值""" + if 'close' not in data.columns: + raise ValueError("data must contain 'close' column") + + prices = data['close'] + + if self.method == 'rsi': + rsi = self._compute_rsi(prices, self.period) + reversal_signal = np.where( + rsi > self.overbought, + -(rsi - self.overbought) / (100 - self.overbought), + np.where( + rsi < self.oversold, + (self.oversold - rsi) / self.oversold, + 0 + ) + ) + return pd.Series(reversal_signal, index=prices.index) + + elif self.method == 'kdj': + return self._compute_kdj(data) + + else: + raise ValueError(f"Unknown method: {self.method}") + + def _compute_rsi(self, prices: pd.Series, period: int) -> pd.Series: + """计算RSI""" + delta = prices.diff() + gain = delta.where(delta > 0, 0) + loss = (-delta).where(delta < 0, 0) + + avg_gain = gain.rolling(period).mean() + avg_loss = loss.rolling(period).mean() + + rs = avg_gain / avg_loss + return 100 - (100 / (1 + rs)) + + def _compute_kdj(self, data: pd.DataFrame) -> pd.Series: + """计算KDJ反转信号""" + low = data['low'] + high = data['high'] + close = data['close'] + + low_min = low.rolling(self.period).min() + high_max = high.rolling(self.period).max() + + rsv = (close - low_min) / (high_max - low_min) * 100 + + k = rsv.ewm(alpha=1/3).mean() + d = k.ewm(alpha=1/3).mean() + j = 3 * k - 2 * d + + return j + + +class VolatilityFactor(FactorBase): + """波动率因子(定制实现)""" + + name = "volatility" + category = "volatility" + + def __init__(self, method: str = 'std', period: int = 20): + super().__init__(method=method, period=period) + self.method = method + self.period = period + + def compute(self, data: pd.DataFrame) -> pd.Series: + """计算波动率因子值""" + if self.method == 'std': + return data['close'].rolling(self.period).std() + + elif self.method == 'atr': + return self._compute_atr(data) + + else: + raise ValueError(f"Unknown method: {self.method}") + + def _compute_atr(self, data: pd.DataFrame) -> pd.Series: + """计算ATR""" + high = data['high'] + low = data['low'] + close = data['close'] + + prev_close = close.shift(1) + tr = pd.concat([ + high - low, + (high - prev_close).abs(), + (low - prev_close).abs() + ], axis=1).max(axis=1) + + return tr.rolling(self.period).mean() + + +# 注册因子 +FactorRegistry.register(MomentumFactor) +FactorRegistry.register(TrendFactor) +FactorRegistry.register(ReversalFactor) +FactorRegistry.register(VolatilityFactor) \ No newline at end of file diff --git a/strategies/shared/risk/controls.py b/strategies/shared/risk/controls.py new file mode 100644 index 0000000..79d94a1 --- /dev/null +++ b/strategies/shared/risk/controls.py @@ -0,0 +1,143 @@ +""" +定制风控组件实现 + +这些风控组件继承framework.core.risk.RiskControl +""" + +from framework.risk import RiskControl, Position, CallbackHook + + +class StopLossControl(RiskControl): + """止损控制(定制实现)""" + + name = "stop_loss" + + def __init__(self, threshold: float = -0.05, trailing: bool = False, trailing_percent: float = 0.03): + super().__init__(threshold=threshold, trailing=trailing, trailing_percent=trailing_percent) + self.threshold = threshold + self.trailing = trailing + self.trailing_percent = trailing_percent + self._highest_price = {} + + def check(self, position: Position, **kwargs) -> bool: + """检查是否触发止损""" + if position is None: + return True + + if self.trailing: + if position.code not in self._highest_price: + self._highest_price[position.code] = position.entry_price + self._highest_price[position.code] = max( + self._highest_price[position.code], + position.current_price + ) + + if self.trailing: + highest = self._highest_price[position.code] + drawdown = (position.current_price - highest) / highest + return drawdown > -self.trailing_percent + else: + return position.profit_ratio > self.threshold + + def apply(self, position: Position): + """返回止损价格""" + if self.trailing: + highest = self._highest_price.get(position.code, position.entry_price) + return highest * (1 - self.trailing_percent) + else: + return position.entry_price * (1 + self.threshold) + + +class PositionLimitControl(RiskControl): + """仓位限制控制(定制实现)""" + + name = "position_limit" + + def __init__(self, max_position: float = 0.33, max_total: float = 1.0): + super().__init__(max_position=max_position, max_total=max_total) + self.max_position = max_position + self.max_total = max_total + + def check(self, position: Position, **kwargs) -> bool: + """检查仓位是否超限""" + if position is None: + return True + + if position.weight > self.max_position: + return False + + return True + + def apply(self, position: Position): + """返回建议仓位""" + return min(position.weight, self.max_position) + + +class PremiumControl(RiskControl): + """溢价控制(定制实现)""" + + name = "premium" + + def __init__(self, threshold: float = 0.10, mode: str = 'filter'): + super().__init__(threshold=threshold, mode=mode) + self.threshold = threshold + self.mode = mode + + def check(self, position: Position, **kwargs) -> bool: + """检查溢价是否超限""" + premium = kwargs.get('premium', 0) + + if self.mode == 'filter': + return premium <= self.threshold + else: + return True + + def apply(self, position: Position): + """返回溢价惩罚系数""" + if self.mode == 'penalize': + return 0.5 + return None + + +# 定制回调函数 +def premium_filter_callback(threshold: float = 0.10): + """溢价过滤回调(定制实现)""" + def callback(code: str, price: float, **kwargs) -> bool: + premium = kwargs.get('premium', 0) + if premium > threshold: + print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})") + return False + return True + return callback + + +def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05): + """崩盘过滤回调(定制实现)""" + def callback(code: str, price: float, **kwargs) -> bool: + history = kwargs.get('history', None) + if history is None: + return True + + recent = history.tail(lookback) + if len(recent) < lookback: + return True + + returns = recent['close'].pct_change() + min_return = returns.min() + + if min_return < -crash_threshold: + print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})") + return False + return True + return callback + + +def holding_time_stoploss_callback(day_5_stoploss: float = -0.05, day_10_stoploss: float = -0.03): + """持仓时间动态止损回调(定制实现)""" + def callback(position: Position) -> float: + if position.holding_days >= 10: + return day_10_stoploss + elif position.holding_days >= 5: + return day_5_stoploss + return -0.10 + return callback \ No newline at end of file diff --git a/strategies/shared/signals/selectors.py b/strategies/shared/signals/selectors.py new file mode 100644 index 0000000..97c8bc7 --- /dev/null +++ b/strategies/shared/signals/selectors.py @@ -0,0 +1,215 @@ +""" +定制信号生成器实现 + +这些信号生成器继承framework.core.signals.SignalGenerator +""" + +from framework.signals import SignalGenerator +import pandas as pd +import numpy as np +from typing import Dict, List, Optional, Any + + +class TopNSelector(SignalGenerator): + """ + Top N选股器(定制实现) + + 用于轮动策略: + - 按因子值排序,选出Top N标的 + - 支持分组选股(先类内竞争,再跨类排序) + + 参数: + - select_num: 选中数量(默认3) + - group_by: 分组列名(可选,如'market') + - top_per_group: 每组选中数量(默认1) + - min_score: 最小得分阈值(可选) + """ + + mode = "top_n" + + def __init__( + self, + select_num: int = 3, + group_by: Optional[str] = None, + top_per_group: int = 1, + min_score: Optional[float] = None + ): + super().__init__( + select_num=select_num, + group_by=group_by, + top_per_group=top_per_group, + min_score=min_score + ) + self.select_num = select_num + self.group_by = group_by + self.top_per_group = top_per_group + self.min_score = min_score + + def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame: + """生成Top N选股信号""" + result = pd.DataFrame(index=factor_data.index) + + factor_cols = self._get_factor_columns(factor_data) + + if not factor_cols: + result['signal'] = '' + return result + + signals = [] + for date in factor_data.index: + row = factor_data.loc[date] + + scores = {} + for col in factor_cols: + score = row[col] + if pd.notna(score): + scores[col] = score + + if self.min_score: + scores = {k: v for k, v in scores.items() if v >= self.min_score} + + if self.group_by and 'group_info' in factor_data.columns: + selected = self._grouped_selection(scores, factor_data.loc[date]) + else: + selected = self._global_top_n(scores) + + signals.append(','.join(selected) if selected else '') + + result['signal'] = signals + result['signal_raw'] = signals + + result['signal'] = result['signal'].shift(1) + + return result + + def _get_factor_columns(self, data: pd.DataFrame) -> List[str]: + """获取因子列名""" + exclude_cols = ['signal', 'signal_raw', 'group_info', 'combined', 'open', 'high', 'low', 'close', 'volume'] + return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')] + + def _global_top_n(self, scores: Dict[str, float]) -> List[str]: + """全局Top N选股""" + if not scores: + return [] + + sorted_items = sorted(scores.items(), key=lambda x: x[1], reverse=True) + return [item[0] for item in sorted_items[:self.select_num]] + + def _grouped_selection(self, scores: Dict[str, float], row: pd.Series) -> List[str]: + """分组选股:先类内竞争,再跨类排序""" + if 'group_info' not in row.index: + return self._global_top_n(scores) + + group_info = row['group_info'] + if pd.isna(group_info): + return self._global_top_n(scores) + + groups = group_info if isinstance(group_info, dict) else {} + + group_champions = {} + for code, score in scores.items(): + group = groups.get(code, 'default') + if group not in group_champions or score > group_champions[group][1]: + group_champions[group] = (code, score) + + champions_scores = {code: score for code, score in group_champions.values()} + return self._global_top_n(champions_scores) + + +class TrendFollower(SignalGenerator): + """趋势跟随器(定制实现)""" + + mode = "trend" + + def __init__(self, entry_threshold: float = 0.02, exit_threshold: float = -0.02, select_num: int = 1): + super().__init__(entry_threshold=entry_threshold, exit_threshold=exit_threshold, select_num=select_num) + self.entry_threshold = entry_threshold + self.exit_threshold = exit_threshold + self.select_num = select_num + + def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame: + """生成趋势跟随信号""" + result = pd.DataFrame(index=factor_data.index) + + factor_cols = self._get_factor_columns(factor_data) + + for col in factor_cols: + trend_strength = factor_data[col] + + result[f'{col}_entry'] = trend_strength > self.entry_threshold + result[f'{col}_exit'] = trend_strength < self.exit_threshold + + signals = [] + for date in result.index: + entry_signals = [] + for col in factor_cols: + if result.loc[date, f'{col}_entry']: + score = factor_data.loc[date, col] + if pd.notna(score): + entry_signals.append((col, score)) + + entry_signals.sort(key=lambda x: x[1], reverse=True) + selected = [item[0] for item in entry_signals[:self.select_num]] + signals.append(','.join(selected) if selected else '') + + result['signal'] = signals + result['signal'] = result['signal'].shift(1) + + return result + + def _get_factor_columns(self, data: pd.DataFrame) -> List[str]: + """获取因子列名""" + exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume'] + return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')] + + +class ReversalTrader(SignalGenerator): + """反转交易器(定制实现)""" + + mode = "reversal" + + def __init__(self, overbought: float = 70, oversold: float = 30, reversal_threshold: float = 0.1): + super().__init__(overbought=overbought, oversold=oversold, reversal_threshold=reversal_threshold) + self.overbought = overbought + self.oversold = oversold + self.reversal_threshold = reversal_threshold + + def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame: + """生成反转交易信号""" + result = pd.DataFrame(index=factor_data.index) + + factor_cols = self._get_factor_columns(factor_data) + + for col in factor_cols: + reversal_signal = factor_data[col] + + result[f'{col}_buy'] = reversal_signal > self.reversal_threshold + result[f'{col}_sell'] = reversal_signal < -self.reversal_threshold + + signals = [] + for date in result.index: + buy_signals = [] + sell_signals = [] + + for col in factor_cols: + if result.loc[date, f'{col}_buy']: + buy_signals.append(col) + if result.loc[date, f'{col}_sell']: + sell_signals.append(col) + + if buy_signals: + signals.append(f"BUY:{','.join(buy_signals)}") + elif sell_signals: + signals.append(f"SELL:{','.join(sell_signals)}") + else: + signals.append('') + + result['signal'] = signals + result['signal'] = result['signal'].shift(1) + + return result + + def _get_factor_columns(self, data: pd.DataFrame) -> List[str]: + """获取因子列名""" + exclude_cols = ['signal', 'signal_raw', 'combined', 'open', 'high', 'low', 'close', 'volume'] + return [col for col in data.columns if col not in exclude_cols and not col.endswith('_weighted')] \ No newline at end of file