Files
etf/framework/tests/test_signals.py
aszerW f5e6202eee feat(signals): 实现信号生成层抽象
核心组件:
- SignalGenerator: 信号生成器抽象基类
- TopNSelector: Top N选股器(轮动策略)
  - 支持分组选股(先类内竞争,再跨类排序)
  - 支持最小得分阈值过滤
- TrendFollower: 趋势跟随器(趋势策略)
  - 入场阈值/出场阈值控制
- ReversalTrader: 反转交易器(反转策略)
  - 超买超卖信号生成

特点:
- T+1执行机制(信号shift向后移位)
- 向量化计算,避免前视偏差

测试覆盖:10个测试全部通过
2026-05-11 22:18:20 +08:00

239 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
信号层测试
测试SignalGenerator、TopNSelector、TrendFollower、ReversalTrader
"""
import pandas as pd
import numpy as np
import pytest
from framework.signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader
class TestSignalGenerator:
"""测试信号生成器基类"""
def test_signal_meta(self):
"""测试信号元信息"""
selector = TopNSelector(select_num=3)
assert selector.mode == "top_n"
assert selector.params == {'select_num': 3, 'group_by': None, 'top_per_group': 1, 'min_score': None}
def test_signal_repr(self):
"""测试信号字符串表示"""
selector = TopNSelector(select_num=5)
repr_str = repr(selector)
assert "TopNSelector" in repr_str
assert "top_n" in repr_str
class TestTopNSelector:
"""测试Top N选股器"""
def test_global_top_n(self):
"""测试全局Top N选股"""
dates = pd.date_range('2020-01-01', periods=10)
# 创建因子数据3个标的得分递减
factor_data = pd.DataFrame({
'code1': [5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
'code2': [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0],
'code3': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
}, index=dates)
selector = TopNSelector(select_num=2)
result = selector.generate(factor_data)
# 检查信号列
assert 'signal' in result.columns
# 第一天无信号shift后
assert result['signal'].iloc[0] == '' or pd.isna(result['signal'].iloc[0])
# 第二天及之后应该选中code1,code2
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
assert 'code1' in signal and 'code2' in signal
def test_top_n_with_min_score(self):
"""测试带最小得分阈值的选股"""
dates = pd.date_range('2020-01-01', periods=10)
factor_data = pd.DataFrame({
'code1': [5.0] * 10,
'code2': [2.0] * 10, # 低于阈值
'code3': [1.0] * 10, # 低于阈值
}, index=dates)
selector = TopNSelector(select_num=3, min_score=3.0)
result = selector.generate(factor_data)
# 只有code1满足阈值
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
assert 'code1' in signal
assert 'code2' not in signal
def test_grouped_selection(self):
"""测试分组选股"""
dates = pd.date_range('2020-01-01', periods=5)
# 创建因子数据和分组信息
factor_data = pd.DataFrame({
'code1': [5.0] * 5, # group A, 最高
'code2': [4.0] * 5, # group A, 次高
'code3': [3.0] * 5, # group B, 最高
'code4': [2.0] * 5, # group B, 次高
'code5': [1.0] * 5, # group C
}, index=dates)
# 分组信息:每行是一个字典 {code: group}
group_info = {
'code1': 'A', 'code2': 'A',
'code3': 'B', 'code4': 'B',
'code5': 'C'
}
factor_data['group_info'] = [group_info] * 5
selector = TopNSelector(select_num=2, group_by='market', top_per_group=1)
result = selector.generate(factor_data)
# 应该选中code1A组冠军、code3B组冠军
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
# code1和code3应该被选中得分最高的两组冠军
selected_codes = signal.split(',')
assert 'code1' in selected_codes or 'code3' in selected_codes
def test_empty_scores(self):
"""测试空得分情况"""
dates = pd.date_range('2020-01-01', periods=5)
# 所有得分为NaN
factor_data = pd.DataFrame({
'code1': [np.nan] * 5,
'code2': [np.nan] * 5,
}, index=dates)
selector = TopNSelector(select_num=2)
result = selector.generate(factor_data)
# 应该返回空信号
for i in range(len(result)):
signal = result['signal'].iloc[i]
assert signal == '' or pd.isna(signal)
class TestTrendFollower:
"""测试趋势跟随器"""
def test_trend_entry_signal(self):
"""测试趋势入场信号"""
dates = pd.date_range('2020-01-01', periods=10)
# 创建趋势数据code1强趋势code2弱趋势
factor_data = pd.DataFrame({
'code1': [0.03] * 10, # > 阈值0.02,入场
'code2': [0.01] * 10, # < 阈值0.02,不入场
}, index=dates)
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
result = follower.generate(factor_data)
# code1应该有入场信号
assert result['code1_entry'].iloc[0] == True
assert result['code2_entry'].iloc[0] == False
def test_trend_exit_signal(self):
"""测试趋势出场信号"""
dates = pd.date_range('2020-01-01', periods=10)
factor_data = pd.DataFrame({
'code1': [-0.03] * 10, # < 阈值-0.02,出场
'code2': [0.01] * 10,
}, index=dates)
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
result = follower.generate(factor_data)
# code1应该有出场信号
assert result['code1_exit'].iloc[0] == True
def test_trend_signal_format(self):
"""测试趋势信号格式"""
dates = pd.date_range('2020-01-01', periods=5)
factor_data = pd.DataFrame({
'code1': [0.05] * 5, # 强趋势,入场
'code2': [0.03] * 5, # 中等趋势,入场
'code3': [0.01] * 5, # 弱趋势,不入场
}, index=dates)
follower = TrendFollower(entry_threshold=0.02, select_num=2)
result = follower.generate(factor_data)
# 信号应该包含code1和code2强度最高的两个
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
assert 'code1' in signal or 'code2' in signal
class TestReversalTrader:
"""测试反转交易器"""
def test_reversal_buy_signal(self):
"""测试反转买入信号"""
dates = pd.date_range('2020-01-01', periods=10)
# 创建反转数据code1超卖反转
factor_data = pd.DataFrame({
'code1': [0.2] * 10, # > 阈值0.1,超卖反转(买入)
'code2': [0.05] * 10, # < 阈值0.1,无信号
}, index=dates)
trader = ReversalTrader(reversal_threshold=0.1)
result = trader.generate(factor_data)
# code1应该有买入信号
assert result['code1_buy'].iloc[0] == True
assert result['code2_buy'].iloc[0] == False
def test_reversal_sell_signal(self):
"""测试反转卖出信号"""
dates = pd.date_range('2020-01-01', periods=10)
factor_data = pd.DataFrame({
'code1': [-0.2] * 10, # < -阈值0.1,超买反转(卖出)
'code2': [0.05] * 10,
}, index=dates)
trader = ReversalTrader(reversal_threshold=0.1)
result = trader.generate(factor_data)
# code1应该有卖出信号
assert result['code1_sell'].iloc[0] == True
def test_reversal_signal_format(self):
"""测试反转信号格式"""
dates = pd.date_range('2020-01-01', periods=5)
factor_data = pd.DataFrame({
'code1': [0.15] * 5, # 超卖反转
'code2': [-0.15] * 5, # 超买反转
}, index=dates)
trader = ReversalTrader(reversal_threshold=0.1)
result = trader.generate(factor_data)
# 信号格式应该是 'BUY:code' 或 'SELL:code'
for i in range(1, len(result)):
signal = result['signal'].iloc[i]
if 'BUY' in signal:
assert 'code1' in signal
elif 'SELL' in signal:
assert 'code2' in signal
if __name__ == '__main__':
pytest.main([__file__, '-v'])