- 测试文件改用strategies.shared的具体实现 - 新增framework_comparison_test.py对比新旧实现结果 - 因子计算相关系数达到1.0000,差异为0.000000 - 79个单元测试全部通过
163 lines
5.0 KiB
Python
163 lines
5.0 KiB
Python
"""
|
||
信号层测试
|
||
|
||
测试SignalGenerator抽象接口
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
import pytest
|
||
|
||
from framework.signals import SignalGenerator
|
||
from strategies.shared.signals.selectors import TopNSelector, TrendFollower, ReversalTrader
|
||
|
||
|
||
class TestTopNSelector:
|
||
"""测试TopNSelector"""
|
||
|
||
def test_top_n_selector_init(self):
|
||
"""测试初始化"""
|
||
selector = TopNSelector(select_num=3)
|
||
assert selector.select_num == 3
|
||
assert selector.mode == "top_n"
|
||
|
||
def test_top_n_selection(self):
|
||
"""测试Top N选股"""
|
||
dates = pd.date_range('2020-01-01', periods=50)
|
||
|
||
# 生成因子数据
|
||
data = pd.DataFrame({
|
||
'factor_A': np.random.randn(50),
|
||
'factor_B': np.random.randn(50),
|
||
'factor_C': np.random.randn(50),
|
||
}, index=dates)
|
||
|
||
selector = TopNSelector(select_num=2)
|
||
result = selector.generate(data)
|
||
|
||
# 检查结果列
|
||
assert 'signal' in result.columns
|
||
assert 'signal_raw' in result.columns
|
||
|
||
# 检查T+1移位(signal比signal_raw滞后1天)
|
||
assert result['signal'].iloc[0] == '' or pd.isna(result['signal'].iloc[0])
|
||
|
||
def test_min_score_filter(self):
|
||
"""测试最小得分过滤"""
|
||
dates = pd.date_range('2020-01-01', periods=50)
|
||
|
||
# 生成因子数据,部分为负值
|
||
data = pd.DataFrame({
|
||
'factor_A': [0.1] * 50,
|
||
'factor_B': [-0.1] * 50, # 负分
|
||
'factor_C': [0.2] * 50,
|
||
}, index=dates)
|
||
|
||
selector = TopNSelector(select_num=2, min_score=0.0)
|
||
result = selector.generate(data)
|
||
|
||
# 负分因子应该被过滤
|
||
signals = result['signal_raw'].dropna().unique()
|
||
for sig in signals:
|
||
if sig:
|
||
codes = sig.split(',')
|
||
assert 'factor_B' not in codes
|
||
|
||
def test_grouped_selection(self):
|
||
"""测试分组选股"""
|
||
dates = pd.date_range('2020-01-01', periods=50)
|
||
|
||
# 生成因子数据和分组信息
|
||
data = pd.DataFrame({
|
||
'factor_A': [0.1] * 50,
|
||
'factor_B': [0.2] * 50,
|
||
'factor_C': [0.15] * 50,
|
||
}, index=dates)
|
||
|
||
# 分组信息(模拟)
|
||
# 注:实际使用需要在数据中包含group_info列
|
||
|
||
selector = TopNSelector(select_num=2, group_by='market')
|
||
result = selector.generate(data)
|
||
|
||
assert 'signal' in result.columns
|
||
|
||
|
||
class TestTrendFollower:
|
||
"""测试趋势跟随器"""
|
||
|
||
def test_trend_follower_init(self):
|
||
"""测试初始化"""
|
||
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
|
||
assert follower.entry_threshold == 0.02
|
||
assert follower.exit_threshold == -0.02
|
||
assert follower.mode == "trend"
|
||
|
||
def test_trend_signal_generation(self):
|
||
"""测试趋势信号生成"""
|
||
dates = pd.date_range('2020-01-01', periods=50)
|
||
|
||
# 生成趋势因子数据
|
||
data = pd.DataFrame({
|
||
'trend_A': [0.05] * 50, # 强趋势
|
||
'trend_B': [-0.05] * 50, # 弱趋势
|
||
}, index=dates)
|
||
|
||
follower = TrendFollower(entry_threshold=0.02)
|
||
result = follower.generate(data)
|
||
|
||
# 检查信号列
|
||
assert 'signal' in result.columns
|
||
|
||
|
||
class TestReversalTrader:
|
||
"""测试反转交易器"""
|
||
|
||
def test_reversal_trader_init(self):
|
||
"""测试初始化"""
|
||
trader = ReversalTrader(overbought=70, oversold=30)
|
||
assert trader.overbought == 70
|
||
assert trader.oversold == 30
|
||
assert trader.mode == "reversal"
|
||
|
||
def test_reversal_signal_generation(self):
|
||
"""测试反转信号生成"""
|
||
dates = pd.date_range('2020-01-01', periods=50)
|
||
|
||
# 生成反转因子数据
|
||
data = pd.DataFrame({
|
||
'reversal_A': [0.15] * 50, # 超卖反转
|
||
'reversal_B': [-0.15] * 50, # 超买反转
|
||
}, index=dates)
|
||
|
||
trader = ReversalTrader(reversal_threshold=0.1)
|
||
result = trader.generate(data)
|
||
|
||
# 检查信号列
|
||
assert 'signal' in result.columns
|
||
|
||
|
||
class TestSignalGeneratorBase:
|
||
"""测试SignalGenerator抽象基类"""
|
||
|
||
def test_validate_factor_data(self):
|
||
"""测试数据验证"""
|
||
selector = TopNSelector(select_num=3)
|
||
|
||
# 空数据应该返回False
|
||
empty_data = pd.DataFrame()
|
||
assert selector.validate_factor_data(empty_data) == False
|
||
|
||
# 有效数据应该返回True
|
||
valid_data = pd.DataFrame({'factor_A': [1, 2, 3]})
|
||
assert selector.validate_factor_data(valid_data) == True
|
||
|
||
def test_repr(self):
|
||
"""测试字符串表示"""
|
||
selector = TopNSelector(select_num=3, min_score=0.5)
|
||
repr_str = repr(selector)
|
||
assert "TopNSelector" in repr_str
|
||
|
||
|
||
if __name__ == '__main__':
|
||
pytest.main([__file__, '-v']) |