核心组件: - SignalGenerator: 信号生成器抽象基类 - TopNSelector: Top N选股器(轮动策略) - 支持分组选股(先类内竞争,再跨类排序) - 支持最小得分阈值过滤 - TrendFollower: 趋势跟随器(趋势策略) - 入场阈值/出场阈值控制 - ReversalTrader: 反转交易器(反转策略) - 超买超卖信号生成 特点: - T+1执行机制(信号shift向后移位) - 向量化计算,避免前视偏差 测试覆盖:10个测试全部通过
353 lines
11 KiB
Python
353 lines
11 KiB
Python
"""
信号层抽象设计

核心组件:
- SignalGenerator: 信号生成器抽象基类
- TopNSelector: Top N选股器(轮动策略)
- TrendFollower: 趋势跟随器(趋势策略)
- ReversalTrader: 反转交易器(反转策略)
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, List, Optional, Any
|
||
from dataclasses import dataclass
|
||
|
||
|
||
@dataclass
class SignalMeta:
    """Metadata record describing a signal generator."""

    # Strategy mode identifier, e.g. 'top_n', 'trend' or 'reversal'.
    mode: str
    # Number of names the generator is configured to select.
    select_num: int
    # Free-form description; defaults to an empty string.
    description: str = ""
|
||
|
||
|
||
class SignalGenerator(ABC):
    """
    Abstract base class for signal generators.

    Every concrete generator subclasses this class and implements
    ``generate``. The keyword arguments handed to the constructor are kept
    as-is and exposed through ``params``; a ``SignalMeta`` record built from
    them is exposed through ``meta``.
    """

    # Class-level attribute (may be overridden by subclasses/configuration).
    mode: str = "base"

    def __init__(self, **params):
        """
        Initialise the signal generator.

        Args:
            **params: Signal parameters. ``select_num`` (default 1 when
                absent) is copied into the metadata record.
        """
        # The class docstring doubles as the human-readable description.
        description = self.__doc__ if self.__doc__ else ""
        select_num = params.get('select_num', 1)

        self._params = params
        self._meta = SignalMeta(
            mode=self.mode,
            select_num=select_num,
            description=description,
        )

    @abstractmethod
    def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
        """
        Generate trading signals.

        Args:
            factor_data: DataFrame holding the factor values.

        Returns:
            DataFrame containing the signal column(s).
        """

    @property
    def params(self) -> Dict[str, Any]:
        """Signal parameters exactly as passed to the constructor."""
        return self._params

    @property
    def meta(self) -> SignalMeta:
        """Metadata record built at construction time."""
        return self._meta

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f"{cls_name}(mode={self.mode}, params={self._params})"
|
||
|
||
|
||
class TopNSelector(SignalGenerator):
    """
    Top-N selector, intended for rotation strategies.

    For every row of the input, the factor columns are ranked by value and
    the names of the highest-scoring columns are emitted as the signal.
    Optionally a grouped selection is performed: candidates first compete
    inside their own group, then the group winners are ranked against each
    other.

    Parameters:
        - select_num: number of names to select (default 3)
        - group_by: grouping switch (optional, e.g. 'market'). Any truthy
          value enables grouped selection; the actual code -> group mapping
          is read from a ``group_info`` column holding one dict per row.
        - top_per_group: number of winners kept per group (default 1);
          ``None`` or values below 1 are treated as 1
        - min_score: minimum score threshold (optional); ``None`` disables
          the filter, ``0`` is a valid threshold
    """

    mode = "top_n"

    # Column names that are never treated as factor columns.
    _NON_FACTOR_COLUMNS = frozenset({
        'signal', 'signal_raw', 'group_info', 'combined',
        'open', 'high', 'low', 'close', 'volume',
    })

    def __init__(
        self,
        select_num: int = 3,
        group_by: Optional[str] = None,
        top_per_group: int = 1,
        min_score: Optional[float] = None
    ):
        super().__init__(
            select_num=select_num,
            group_by=group_by,
            top_per_group=top_per_group,
            min_score=min_score
        )
        self.select_num = select_num
        self.group_by = group_by
        self.top_per_group = top_per_group
        self.min_score = min_score

    def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
        """
        Generate the Top-N selection signal.

        Args:
            factor_data: DataFrame whose factor columns hold one score per
                candidate name and row.

        Returns:
            DataFrame indexed like ``factor_data`` with two columns:
            ``signal`` (comma-separated selected names, shifted down by one
            row for T+1 execution, so the first row is NaN) and
            ``signal_raw`` (the unshifted selection). When no factor column
            is found, only an unshifted ``signal`` column of empty strings
            is returned.
        """
        result = pd.DataFrame(index=factor_data.index)

        # Factor columns = everything that is not a known non-factor column.
        factor_cols = self._get_factor_columns(factor_data)

        if not factor_cols:
            print("⚠ 未找到因子列")
            result['signal'] = ''
            return result

        # Loop-invariant: grouped selection needs both the switch and the data.
        use_groups = bool(self.group_by) and 'group_info' in factor_data.columns

        # Row-by-row selection. iterrows() walks rows positionally, so a
        # duplicated index label cannot turn a "row" into a DataFrame.
        signals = []
        for _, row in factor_data.iterrows():
            # Keep only the candidates that have a score on this row.
            scores = {
                col: row[col] for col in factor_cols if pd.notna(row[col])
            }

            # Explicit None check: a threshold of 0 must still filter.
            if self.min_score is not None:
                scores = {
                    code: score for code, score in scores.items()
                    if score >= self.min_score
                }

            if use_groups:
                # In-group competition first, then cross-group ranking.
                selected = self._grouped_selection(scores, row)
            else:
                selected = self._global_top_n(scores)

            # Signal format: comma-separated names ('' when nothing selected).
            signals.append(','.join(selected))

        result['signal'] = signals
        result['signal_raw'] = signals  # raw signal (not shifted)

        # T+1 execution: move the signal one row later.
        result['signal'] = result['signal'].shift(1)

        return result

    def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
        """Return the column names treated as factors.

        Known non-factor columns and columns ending in ``_weighted`` are
        excluded; the original column order is preserved.
        """
        return [
            col for col in data.columns
            if col not in self._NON_FACTOR_COLUMNS
            and not col.endswith('_weighted')
        ]

    def _global_top_n(self, scores: Dict[str, float]) -> List[str]:
        """Return the ``select_num`` highest-scoring names, best first.

        Ties keep the insertion order of ``scores`` (stable sort).
        """
        if not scores:
            return []

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return [code for code, _ in ranked[:self.select_num]]

    def _grouped_selection(
        self,
        scores: Dict[str, float],
        row: pd.Series
    ) -> List[str]:
        """
        Grouped selection: in-group competition, then cross-group ranking.

        Args:
            scores: Mapping of candidate name to score for the current row.
            row: The current row; its ``group_info`` entry is expected to be
                a dict mapping candidate name to group label.

        Returns:
            Selected names, best first. Falls back to the global Top-N when
            ``group_info`` is absent or missing for this row.
        """
        if 'group_info' not in row.index:
            return self._global_top_n(scores)

        group_info = row['group_info']
        if pd.isna(group_info):
            return self._global_top_n(scores)

        # Grouping info is {code: group}; anything else yields no mapping,
        # which places every candidate in the 'default' group.
        groups = group_info if isinstance(group_info, dict) else {}

        # None / non-positive values keep the previous one-winner-per-group
        # behaviour.
        keep = self.top_per_group
        if not keep or keep < 1:
            keep = 1

        # Bucket candidates by group, preserving first-seen group order.
        members_by_group: Dict[Any, List[Any]] = {}
        for code, score in scores.items():
            group = groups.get(code, 'default')
            members_by_group.setdefault(group, []).append((code, score))

        # In-group competition: keep the best `keep` candidates per group.
        # The sort is stable, so the first-seen candidate wins a tie.
        finalists = {}
        for members in members_by_group.values():
            members.sort(key=lambda item: item[1], reverse=True)
            for code, score in members[:keep]:
                finalists[code] = score

        # Cross-group ranking: Top-N among the group winners.
        return self._global_top_n(finalists)
|
||
|
||
|
||
class TrendFollower(SignalGenerator):
    """
    Trend follower, intended for trend-tracking strategies.

    - trend strength > entry threshold -> entry flag
    - trend strength < exit threshold  -> exit flag

    The combined ``signal`` column is built from the entry flags only; the
    exit flags are emitted as separate columns and are not used when
    composing ``signal``.

    Parameters:
        - entry_threshold: entry threshold (default 0.02)
        - exit_threshold: exit threshold (default -0.02)
        - select_num: maximum number of names held (default 1)
    """

    mode = "trend"

    # Column names that are never treated as factor columns. 'group_info'
    # is listed so the exclusion matches TopNSelector's list.
    _NON_FACTOR_COLUMNS = frozenset({
        'signal', 'signal_raw', 'group_info', 'combined',
        'open', 'high', 'low', 'close', 'volume',
    })

    def __init__(
        self,
        entry_threshold: float = 0.02,
        exit_threshold: float = -0.02,
        select_num: int = 1
    ):
        super().__init__(
            entry_threshold=entry_threshold,
            exit_threshold=exit_threshold,
            select_num=select_num
        )
        self.entry_threshold = entry_threshold
        self.exit_threshold = exit_threshold
        self.select_num = select_num

    def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
        """
        Generate the trend-following signal.

        Args:
            factor_data: DataFrame whose factor columns hold the trend
                strength per name and row.

        Returns:
            DataFrame indexed like ``factor_data`` with, for every factor
            column, boolean ``<col>_entry`` and ``<col>_exit`` columns, plus
            a ``signal`` column: the comma-separated names of the
            ``select_num`` strongest entries, shifted down by one row for
            T+1 execution (first row is NaN).
        """
        result = pd.DataFrame(index=factor_data.index)
        factor_cols = self._get_factor_columns(factor_data)

        # Keep positional arrays next to the result columns so the row loop
        # below avoids label-based scalar lookups (slow, and ambiguous when
        # the index contains duplicate labels).
        entry_flags = {}
        strengths = {}
        for col in factor_cols:
            trend_strength = factor_data[col]
            entered = trend_strength > self.entry_threshold

            # Entry flag: trend strength above the entry threshold.
            result[f'{col}_entry'] = entered
            # Exit flag: trend strength below the exit threshold.
            result[f'{col}_exit'] = trend_strength < self.exit_threshold

            entry_flags[col] = entered.to_numpy()
            strengths[col] = trend_strength.to_numpy()

        # Combined signal: the Top-N names by entry strength on each row.
        # A NaN strength never passes the '>' comparison, so flagged
        # candidates always carry a real score.
        signals = []
        for pos in range(len(result.index)):
            candidates = [
                (col, strengths[col][pos])
                for col in factor_cols
                if entry_flags[col][pos]
            ]
            # Stable sort: ties keep the factor-column order.
            candidates.sort(key=lambda item: item[1], reverse=True)
            chosen = [col for col, _ in candidates[:self.select_num]]
            signals.append(','.join(chosen))

        result['signal'] = signals
        result['signal'] = result['signal'].shift(1)  # T+1 execution

        return result

    def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
        """Return the column names treated as factors.

        Known non-factor columns and columns ending in ``_weighted`` are
        excluded; the original column order is preserved.
        """
        return [
            col for col in data.columns
            if col not in self._NON_FACTOR_COLUMNS
            and not col.endswith('_weighted')
        ]
|
||
|
||
|
||
class ReversalTrader(SignalGenerator):
    """
    Reversal trader, intended for mean-reversion strategies.

    - factor value >  reversal_threshold -> buy flag
    - factor value < -reversal_threshold -> sell flag

    Parameters:
        - overbought: overbought level (default 70)
        - oversold: oversold level (default 30)
        - reversal_threshold: reversal signal strength threshold (default 0.1)

    NOTE(review): ``overbought`` and ``oversold`` are stored on the instance
    but ``generate`` does not read them; only ``reversal_threshold`` drives
    the flags. Presumably the RSI-style levels are applied upstream when the
    factor is computed -- confirm against the factor layer.
    """

    mode = "reversal"

    # Column names that are never treated as factor columns. 'group_info'
    # is listed so the exclusion matches TopNSelector's list.
    _NON_FACTOR_COLUMNS = frozenset({
        'signal', 'signal_raw', 'group_info', 'combined',
        'open', 'high', 'low', 'close', 'volume',
    })

    def __init__(
        self,
        overbought: float = 70,
        oversold: float = 30,
        reversal_threshold: float = 0.1
    ):
        super().__init__(
            overbought=overbought,
            oversold=oversold,
            reversal_threshold=reversal_threshold
        )
        self.overbought = overbought
        self.oversold = oversold
        self.reversal_threshold = reversal_threshold

    def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
        """
        Generate the reversal trading signal.

        Args:
            factor_data: DataFrame whose factor columns hold the reversal
                signal strength per name and row.

        Returns:
            DataFrame indexed like ``factor_data`` with, for every factor
            column, boolean ``<col>_buy`` and ``<col>_sell`` columns, plus a
            ``signal`` column formatted as ``'BUY:code1,code2'``,
            ``'SELL:code1'`` or ``''``, shifted down by one row for T+1
            execution (first row is NaN). When a row has both buy and sell
            flags, only the BUY list is emitted.
        """
        result = pd.DataFrame(index=factor_data.index)
        factor_cols = self._get_factor_columns(factor_data)

        # Keep positional arrays next to the result columns so the row loop
        # below avoids label-based scalar lookups (slow, and ambiguous when
        # the index contains duplicate labels).
        buy_flags = {}
        sell_flags = {}
        for col in factor_cols:
            reversal_signal = factor_data[col]
            buy = reversal_signal > self.reversal_threshold
            sell = reversal_signal < -self.reversal_threshold

            # Buy flag: signal above +threshold.
            result[f'{col}_buy'] = buy
            # Sell flag: signal below -threshold.
            result[f'{col}_sell'] = sell

            buy_flags[col] = buy.to_numpy()
            sell_flags[col] = sell.to_numpy()

        # Combined signal per row; BUY takes precedence over SELL.
        signals = []
        for pos in range(len(result.index)):
            buys = [col for col in factor_cols if buy_flags[col][pos]]
            if buys:
                signals.append(f"BUY:{','.join(buys)}")
                continue

            sells = [col for col in factor_cols if sell_flags[col][pos]]
            if sells:
                signals.append(f"SELL:{','.join(sells)}")
            else:
                signals.append('')

        result['signal'] = signals
        result['signal'] = result['signal'].shift(1)  # T+1 execution

        return result

    def _get_factor_columns(self, data: pd.DataFrame) -> List[str]:
        """Return the column names treated as factors.

        Known non-factor columns and columns ending in ``_weighted`` are
        excluded; the original column order is preserved.
        """
        return [
            col for col in data.columns
            if col not in self._NON_FACTOR_COLUMNS
            and not col.endswith('_weighted')
        ]