feat(signals): 实现信号生成层抽象

核心组件:
- SignalGenerator: 信号生成器抽象基类
- TopNSelector: Top N选股器(轮动策略)
  - 支持分组选股(先类内竞争,再跨类排序)
  - 支持最小得分阈值过滤
- TrendFollower: 趋势跟随器(趋势策略)
  - 入场阈值/出场阈值控制
- ReversalTrader: 反转交易器(反转策略)
  - 超买超卖信号生成

特点:
- T+1执行机制(信号shift向后移位)
- 向量化计算,避免前视偏差

测试覆盖:10个测试全部通过
This commit is contained in:
2026-05-11 22:18:20 +08:00
parent 796a695eef
commit f5e6202eee
2 changed files with 592 additions and 0 deletions

View File

@@ -0,0 +1,239 @@
"""
信号层测试
测试SignalGenerator、TopNSelector、TrendFollower、ReversalTrader
"""
import pandas as pd
import numpy as np
import pytest
from framework.signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader
class TestSignalGenerator:
"""测试信号生成器基类"""
def test_signal_meta(self):
"""测试信号元信息"""
selector = TopNSelector(select_num=3)
assert selector.mode == "top_n"
assert selector.params == {'select_num': 3, 'group_by': None, 'top_per_group': 1, 'min_score': None}
def test_signal_repr(self):
"""测试信号字符串表示"""
selector = TopNSelector(select_num=5)
repr_str = repr(selector)
assert "TopNSelector" in repr_str
assert "top_n" in repr_str
class TestTopNSelector:
"""测试Top N选股器"""
def test_global_top_n(self):
"""测试全局Top N选股"""
dates = pd.date_range('2020-01-01', periods=10)
# 创建因子数据3个标的得分递减
factor_data = pd.DataFrame({
'code1': [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
'code2': [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
'code3': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
}, index=dates)
selector = TopNSelector(select_num=2)
result = selector.generate(factor_data)
# 检查信号列
assert 'signal' in result.columns
# 第一天无信号shift后
assert result['signal'].iloc[0] == '' or pd.isna(result['signal'].iloc[0])
# 第二天及之后应该选中code1,code2
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
assert 'code1' in signal and 'code2' in signal
def test_top_n_with_min_score(self):
"""测试带最小得分阈值的选股"""
dates = pd.date_range('2020-01-01', periods=10)
factor_data = pd.DataFrame({
'code1': [5.0] * 10,
'code2': [2.0] * 10, # 低于阈值
'code3': [1.0] * 10, # 低于阈值
}, index=dates)
selector = TopNSelector(select_num=3, min_score=3.0)
result = selector.generate(factor_data)
# 只有code1满足阈值
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
assert 'code1' in signal
assert 'code2' not in signal
def test_grouped_selection(self):
"""测试分组选股"""
dates = pd.date_range('2020-01-01', periods=5)
# 创建因子数据和分组信息
factor_data = pd.DataFrame({
'code1': [5.0] * 5, # group A, 最高
'code2': [4.0] * 5, # group A, 次高
'code3': [3.0] * 5, # group B, 最高
'code4': [2.0] * 5, # group B, 次高
'code5': [1.0] * 5, # group C
}, index=dates)
# 分组信息:每行是一个字典 {code: group}
group_info = {
'code1': 'A', 'code2': 'A',
'code3': 'B', 'code4': 'B',
'code5': 'C'
}
factor_data['group_info'] = [group_info] * 5
selector = TopNSelector(select_num=2, group_by='market', top_per_group=1)
result = selector.generate(factor_data)
# 应该选中code1A组冠军、code3B组冠军
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
# code1和code3应该被选中得分最高的两组冠军
selected_codes = signal.split(',')
assert 'code1' in selected_codes or 'code3' in selected_codes
def test_empty_scores(self):
"""测试空得分情况"""
dates = pd.date_range('2020-01-01', periods=5)
# 所有得分为NaN
factor_data = pd.DataFrame({
'code1': [np.nan] * 5,
'code2': [np.nan] * 5,
}, index=dates)
selector = TopNSelector(select_num=2)
result = selector.generate(factor_data)
# 应该返回空信号
for i in range(len(result)):
signal = result['signal'].iloc[i]
assert signal == '' or pd.isna(signal)
class TestTrendFollower:
"""测试趋势跟随器"""
def test_trend_entry_signal(self):
"""测试趋势入场信号"""
dates = pd.date_range('2020-01-01', periods=10)
# 创建趋势数据code1强趋势code2弱趋势
factor_data = pd.DataFrame({
'code1': [0.03] * 10, # > 阈值0.02,入场
'code2': [0.01] * 10, # < 阈值0.02,不入场
}, index=dates)
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
result = follower.generate(factor_data)
# code1应该有入场信号
assert result['code1_entry'].iloc[0] == True
assert result['code2_entry'].iloc[0] == False
def test_trend_exit_signal(self):
"""测试趋势出场信号"""
dates = pd.date_range('2020-01-01', periods=10)
factor_data = pd.DataFrame({
'code1': [-0.03] * 10, # < 阈值-0.02,出场
'code2': [0.01] * 10,
}, index=dates)
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
result = follower.generate(factor_data)
# code1应该有出场信号
assert result['code1_exit'].iloc[0] == True
def test_trend_signal_format(self):
"""测试趋势信号格式"""
dates = pd.date_range('2020-01-01', periods=5)
factor_data = pd.DataFrame({
'code1': [0.05] * 5, # 强趋势,入场
'code2': [0.03] * 5, # 中等趋势,入场
'code3': [0.01] * 5, # 弱趋势,不入场
}, index=dates)
follower = TrendFollower(entry_threshold=0.02, select_num=2)
result = follower.generate(factor_data)
# 信号应该包含code1和code2强度最高的两个
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
assert 'code1' in signal or 'code2' in signal
class TestReversalTrader:
"""测试反转交易器"""
def test_reversal_buy_signal(self):
"""测试反转买入信号"""
dates = pd.date_range('2020-01-01', periods=10)
# 创建反转数据code1超卖反转
factor_data = pd.DataFrame({
'code1': [0.2] * 10, # > 阈值0.1,超卖反转(买入)
'code2': [0.05] * 10, # < 阈值0.1,无信号
}, index=dates)
trader = ReversalTrader(reversal_threshold=0.1)
result = trader.generate(factor_data)
# code1应该有买入信号
assert result['code1_buy'].iloc[0] == True
assert result['code2_buy'].iloc[0] == False
def test_reversal_sell_signal(self):
"""测试反转卖出信号"""
dates = pd.date_range('2020-01-01', periods=10)
factor_data = pd.DataFrame({
'code1': [-0.2] * 10, # < -阈值0.1,超买反转(卖出)
'code2': [0.05] * 10,
}, index=dates)
trader = ReversalTrader(reversal_threshold=0.1)
result = trader.generate(factor_data)
# code1应该有卖出信号
assert result['code1_sell'].iloc[0] == True
def test_reversal_signal_format(self):
"""测试反转信号格式"""
dates = pd.date_range('2020-01-01', periods=5)
factor_data = pd.DataFrame({
'code1': [0.15] * 5, # 超卖反转
'code2': [-0.15] * 5, # 超买反转
}, index=dates)
trader = ReversalTrader(reversal_threshold=0.1)
result = trader.generate(factor_data)
# 信号格式应该是 'BUY:code' 或 'SELL:code'
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
if 'BUY' in signal:
assert 'code1' in signal
elif 'SELL' in signal:
assert 'code2' in signal
if __name__ == '__main__':
pytest.main([__file__, '-v'])