Files
etf/strategies/shared/signals/selectors.py
aszerW 69081297c5 feat(strategies): 实现定制组件(因子、信号生成器、风控)
- strategies/shared/factors/momentum.py: MomentumFactor/TrendFactor/ReversalFactor/VolatilityFactor
- strategies/shared/signals/selectors.py: TopNSelector/TrendFollower/ReversalTrader
- strategies/shared/risk/controls.py: StopLossControl/PositionLimitControl/PremiumControl
- strategies/shared/__init__.py: 统一入口导出所有定制组件
2026-05-11 23:09:35 +08:00

215 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Custom signal generator implementations.

The generators in this module subclass ``SignalGenerator``, imported here from
``framework.signals`` (an earlier note referenced
``framework.core.signals.SignalGenerator`` -- presumably re-exported; verify).
"""
from framework.signals import SignalGenerator
import pandas as pd
import numpy as np
from typing import Dict, List, Optional, Any
class TopNSelector(SignalGenerator):
    """
    Top-N selector (custom implementation).

    Intended for rotation strategies:
    - Rank candidates by factor value and select the top N.
    - Optionally select by group: keep the best candidates inside each group
      first, then rank those group winners against each other.

    Every factor column of the input frame is one candidate; its column name
    is the code written into the signal string.

    Parameters:
    - select_num: number of candidates selected overall (default 3).
    - group_by: optional grouping column name (e.g. 'market'). It acts as a
      switch: when truthy and the input has a ``group_info`` column, grouped
      selection is used. The group of each code is read from that row's
      ``group_info`` value, expected to be a ``{code: group}`` dict. A column
      named ``group_by`` is never treated as a factor column.
    - top_per_group: number of winners kept per group before the cross-group
      ranking (default 1).
    - min_score: optional minimum score; candidates scoring below it are
      discarded (a threshold of 0 is honoured).
    """

    mode = "top_n"

    def __init__(
        self,
        select_num: int = 3,
        group_by: Optional[str] = None,
        top_per_group: int = 1,
        min_score: Optional[float] = None,
    ):
        super().__init__(
            select_num=select_num,
            group_by=group_by,
            top_per_group=top_per_group,
            min_score=min_score,
        )
        self.select_num = select_num
        self.group_by = group_by
        self.top_per_group = top_per_group
        self.min_score = min_score

    def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
        """Generate Top-N selection signals.

        Args:
            factor_data: one row per observation; factor columns hold scores.

        Returns:
            A frame with ``factor_data``'s index and two columns:
            ``signal_raw`` -- comma-separated codes selected from each row
            ('' when nothing qualifies) -- and ``signal`` -- ``signal_raw``
            shifted down one row, so each row carries the selection computed
            from the previous row (the first row is NaN). When no factor
            columns are found, both columns are ''.
        """
        result = pd.DataFrame(index=factor_data.index)
        factor_cols = self._get_factor_columns(factor_data)
        if not factor_cols:
            result['signal'] = ''
            result['signal_raw'] = ''
            return result

        use_groups = bool(self.group_by) and 'group_info' in factor_data.columns
        signals = []
        # iterrows walks rows by position, so duplicate index labels are
        # processed one row at a time rather than as a sub-frame.
        for _, row in factor_data.iterrows():
            scores = {col: row[col] for col in factor_cols if pd.notna(row[col])}
            # Compare against None so that min_score=0 still filters.
            if self.min_score is not None:
                scores = {code: s for code, s in scores.items() if s >= self.min_score}
            if use_groups:
                selected = self._grouped_selection(scores, row)
            else:
                selected = self._global_top_n(scores)
            signals.append(','.join(selected))

        result['signal'] = signals
        result['signal_raw'] = signals
        # Shift by one row: each row carries the selection computed from the
        # previous row's scores.
        result['signal'] = result['signal'].shift(1)
        return result

    def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
        """Return the columns of ``data`` treated as factor scores.

        Excludes price/volume columns, signal/bookkeeping columns,
        ``group_info``, the configured ``group_by`` column and any
        ``*_weighted`` column.
        """
        excluded = {
            'signal', 'signal_raw', 'group_info', 'combined',
            'open', 'high', 'low', 'close', 'volume',
        }
        if self.group_by:
            # The grouping column holds labels, not scores.
            excluded.add(self.group_by)
        return [
            col for col in data.columns
            if col not in excluded
            and not (isinstance(col, str) and col.endswith('_weighted'))
        ]

    def _global_top_n(self, scores: Dict[str, float]) -> List[str]:
        """Return up to ``select_num`` codes with the highest scores.

        Ties keep their insertion order because ``sorted`` is stable.
        """
        if not scores:
            return []
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return [code for code, _ in ranked[:self.select_num]]

    def _grouped_selection(self, scores: Dict[str, float], row: pd.Series) -> List[str]:
        """Grouped selection: rank within each group, then across groups.

        Keeps the best ``top_per_group`` codes of every group (codes missing
        from the mapping share a 'default' group), then applies the global
        Top-N ranking to the survivors. Falls back to the plain global Top-N
        when the row has no ``group_info`` dict.
        """
        group_info = row['group_info'] if 'group_info' in row.index else None
        # Only a {code: group} dict carries grouping information; NaN or any
        # other value means "no grouping" for this row.
        if not isinstance(group_info, dict):
            return self._global_top_n(scores)

        members: Dict[Any, List[tuple]] = {}
        for code, score in scores.items():
            members.setdefault(group_info.get(code, 'default'), []).append((code, score))

        keep = max(self.top_per_group, 0)
        survivors: Dict[str, float] = {}
        for entries in members.values():
            # sorted() is stable, so among equal scores the first code seen wins.
            for code, score in sorted(entries, key=lambda item: item[1], reverse=True)[:keep]:
                survivors[code] = score
        return self._global_top_n(survivors)
class TrendFollower(SignalGenerator):
    """Trend follower (custom implementation).

    Each factor column is one candidate; its values are read as trend strength.

    Parameters:
    - entry_threshold: a value strictly above this flags an entry (default 0.02).
    - exit_threshold: a value strictly below this flags an exit (default -0.02).
    - select_num: maximum number of entry candidates listed per row (default 1).
    """

    mode = "trend"

    def __init__(self, entry_threshold: float = 0.02, exit_threshold: float = -0.02, select_num: int = 1):
        super().__init__(entry_threshold=entry_threshold, exit_threshold=exit_threshold, select_num=select_num)
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold
        self.select_num = select_num

    def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
        """Generate trend-following signals.

        Returns:
            A frame with ``factor_data``'s index containing, for every factor
            column ``col``, boolean ``{col}_entry`` / ``{col}_exit`` flags, plus
            ``signal``: per row, up to ``select_num`` entry columns ordered by
            strength (comma-separated, '' if none), shifted down one row so
            the first row is NaN.

        Note: the ``{col}_exit`` flags are returned for callers but are not
        used when building ``signal``.
        """
        result = pd.DataFrame(index=factor_data.index)
        factor_cols = self._get_factor_columns(factor_data)
        strength = factor_data[factor_cols]
        entry_mask = strength > self.entry_threshold
        exit_mask = strength < self.exit_threshold
        for col in factor_cols:
            result[f'{col}_entry'] = entry_mask[col]
            result[f'{col}_exit'] = exit_mask[col]

        signals = []
        # Iterate positionally over row arrays so duplicate index labels do
        # not turn a single-cell lookup into a Series.
        for values, entries in zip(strength.to_numpy(), entry_mask.to_numpy()):
            candidates = [
                (col, value)
                for col, value, is_entry in zip(factor_cols, values, entries)
                if is_entry and pd.notna(value)
            ]
            # Stable sort: equal strengths keep column order.
            candidates.sort(key=lambda item: item[1], reverse=True)
            signals.append(','.join(col for col, _ in candidates[:self.select_num]))

        result['signal'] = signals
        # Shift by one row: each row carries the selection computed from the
        # previous row's values.
        result['signal'] = result['signal'].shift(1)
        return result

    def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
        """Return the columns of ``data`` treated as factor values.

        Excludes price/volume columns, signal/bookkeeping columns,
        ``group_info`` (its dict values cannot be compared with thresholds)
        and any ``*_weighted`` column.
        """
        excluded = {
            'signal', 'signal_raw', 'group_info', 'combined',
            'open', 'high', 'low', 'close', 'volume',
        }
        return [
            col for col in data.columns
            if col not in excluded
            and not (isinstance(col, str) and col.endswith('_weighted'))
        ]
class ReversalTrader(SignalGenerator):
    """Reversal trader (custom implementation).

    Each factor column is one candidate; its values are compared with
    +/- ``reversal_threshold``.

    Parameters:
    - overbought / oversold: stored and forwarded to the base class, but not
      read by ``generate`` in this implementation.
    - reversal_threshold: a value strictly above it flags a buy, a value
      strictly below its negation flags a sell (default 0.1).
    """

    mode = "reversal"

    def __init__(self, overbought: float = 70, oversold: float = 30, reversal_threshold: float = 0.1):
        super().__init__(overbought=overbought, oversold=oversold, reversal_threshold=reversal_threshold)
        self.overbought = overbought
        self.oversold = oversold
        self.reversal_threshold = reversal_threshold

    def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
        """Generate reversal trading signals.

        Returns:
            A frame with ``factor_data``'s index containing, for every factor
            column ``col``, boolean ``{col}_buy`` / ``{col}_sell`` flags, plus
            ``signal``: per row ``"BUY:<cols>"`` if any column flags a buy,
            otherwise ``"SELL:<cols>"`` if any flags a sell, otherwise ''.
            ``signal`` is shifted down one row, so the first row is NaN.
        """
        result = pd.DataFrame(index=factor_data.index)
        factor_cols = self._get_factor_columns(factor_data)
        values = factor_data[factor_cols]
        buy_mask = values > self.reversal_threshold
        sell_mask = values < -self.reversal_threshold
        for col in factor_cols:
            result[f'{col}_buy'] = buy_mask[col]
            result[f'{col}_sell'] = sell_mask[col]

        signals = []
        # Positional iteration keeps duplicate index labels from producing
        # ambiguous Series lookups.
        for buy_row, sell_row in zip(buy_mask.to_numpy(), sell_mask.to_numpy()):
            buys = [col for col, hit in zip(factor_cols, buy_row) if hit]
            if buys:
                # Buy flags take precedence over sell flags on the same row.
                signals.append(f"BUY:{','.join(buys)}")
                continue
            sells = [col for col, hit in zip(factor_cols, sell_row) if hit]
            signals.append(f"SELL:{','.join(sells)}" if sells else '')

        result['signal'] = signals
        # Shift by one row: each row carries the decision computed from the
        # previous row's values.
        result['signal'] = result['signal'].shift(1)
        return result

    def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
        """Return the columns of ``data`` treated as factor values.

        Excludes price/volume columns, signal/bookkeeping columns,
        ``group_info`` (its dict values cannot be compared with thresholds)
        and any ``*_weighted`` column.
        """
        excluded = {
            'signal', 'signal_raw', 'group_info', 'combined',
            'open', 'high', 'low', 'close', 'volume',
        }
        return [
            col for col in data.columns
            if col not in excluded
            and not (isinstance(col, str) and col.endswith('_weighted'))
        ]