"""
Factor-layer tests.

Exercises the abstract FactorBase, FactorRegistry and FactorCombiner
interfaces, plus the concrete momentum / trend / reversal / volatility factors.
"""
import pandas as pd
import numpy as np
import pytest

from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from strategies.shared.factors.momentum import (
    MomentumFactor,
    TrendFactor,
    ReversalFactor,
    VolatilityFactor,
)

# Fixed seed so every run sees identical synthetic data; unseeded random
# fixtures make pass/fail depend on luck.
SEED = 42
N_DAYS = 100


def _dates(n=N_DAYS):
    """Return a daily DatetimeIndex of length ``n`` starting 2020-01-01."""
    return pd.date_range('2020-01-01', periods=n)


def _random_walk(n=N_DAYS, start=100.0, seed=SEED):
    """Return a reproducible Gaussian random-walk price path of length ``n``."""
    rng = np.random.default_rng(seed)
    return start + rng.standard_normal(n).cumsum()


def _ohlc_frame(n=N_DAYS, seed=SEED):
    """Return a reproducible close/high/low frame satisfying low <= close <= high.

    High and low are offset from close by non-negative spreads; drawing them
    as independent random walks would let high cross below low or close.
    """
    rng = np.random.default_rng(seed)
    close = 100.0 + rng.standard_normal(n).cumsum()
    high = close + np.abs(rng.standard_normal(n))
    low = close - np.abs(rng.standard_normal(n))
    return pd.DataFrame({'close': close, 'high': high, 'low': low}, index=_dates(n))


class TestFactorBase:
    """Tests for the FactorBase abstract base class."""

    def test_factor_meta(self):
        """Factor metadata (name, category) is exposed."""
        factor = MomentumFactor(n_days=25)
        assert factor.name == "momentum"
        assert factor.category == "momentum"

    def test_factor_repr(self):
        """repr() mentions the concrete factor class."""
        factor = MomentumFactor(n_days=30)
        repr_str = repr(factor)
        assert "MomentumFactor" in repr_str

    def test_validate_data(self):
        """validate_data accepts enough history and rejects too little."""
        factor = MomentumFactor(n_days=25)

        # Enough data.
        data = pd.DataFrame({'close': _random_walk(100)})
        assert factor.validate_data(data)

        # Not enough data.
        short_data = pd.DataFrame({'close': _random_walk(10)})
        assert not factor.validate_data(short_data)


class TestFactorRegistry:
    """Tests for the factor registry."""

    def setup_method(self):
        """Clear the registry before each test."""
        FactorRegistry.clear()

    def test_register_factor(self):
        """A registered factor appears in list_factors()."""
        FactorRegistry.register(MomentumFactor)
        assert 'momentum' in FactorRegistry.list_factors()

    def test_get_factor(self):
        """get() builds an instance forwarding keyword arguments."""
        FactorRegistry.register(MomentumFactor)
        factor = FactorRegistry.get('momentum', n_days=30)
        assert isinstance(factor, MomentumFactor)
        assert factor.n_days == 30

    def test_get_unknown_factor(self):
        """Looking up an unregistered factor raises ValueError."""
        with pytest.raises(ValueError):
            FactorRegistry.get('unknown_factor')

    def test_get_category(self):
        """get_category() returns the registered factor's category."""
        FactorRegistry.register(MomentumFactor)
        category = FactorRegistry.get_category('momentum')
        assert category == 'momentum'


class TestFactorCombiner:
    """Tests for the factor combiner."""

    def setup_method(self):
        """Clear the registry before each test."""
        FactorRegistry.clear()

    def test_combiner_init(self):
        """Combiner keeps one name per input factor."""
        factors = [
            MomentumFactor(n_days=25),
            TrendFactor(method='ma_cross'),
        ]
        combiner = FactorCombiner(factors, weights=[0.7, 0.3])
        assert len(combiner.get_factor_names()) == 2

    def test_combiner_equal_weights(self):
        """Without explicit weights the combiner uses normalised equal weights."""
        factors = [
            MomentumFactor(n_days=25),
            TrendFactor(),
        ]
        combiner = FactorCombiner(factors)  # equal weights by default

        weights = list(combiner._weights)
        # Weights should be normalised; compare floats with a tolerance.
        assert sum(weights) == pytest.approx(1.0)
        assert weights == pytest.approx([0.5, 0.5])

    def test_combiner_compute(self):
        """compute() outputs one column per factor plus 'combined'."""
        data = _ohlc_frame()

        factors = [
            MomentumFactor(n_days=20),
            TrendFactor(fast=5, slow=10),
        ]
        combiner = FactorCombiner(factors, weights=[0.6, 0.4])
        result = combiner.compute(data)

        # Check the result columns.
        assert 'momentum' in result.columns
        assert 'trend' in result.columns
        assert 'combined' in result.columns

    def test_combiner_method_rank_average(self):
        """The rank_average combination method produces a 'combined' column."""
        data = pd.DataFrame({'close': _random_walk()}, index=_dates())

        factors = [
            MomentumFactor(n_days=20),
            TrendFactor(),
        ]
        combiner = FactorCombiner(factors, method='rank_average')
        result = combiner.compute(data)

        # 'combined' should hold the rank average.
        assert 'combined' in result.columns


class TestMomentumFactor:
    """Tests for the momentum factor."""

    def test_momentum_compute(self):
        """A steady uptrend yields a positive momentum score."""
        # Linearly rising prices.
        prices = 100 + np.arange(N_DAYS) * 0.5
        data = pd.DataFrame({'close': prices}, index=_dates())

        factor = MomentumFactor(n_days=25, weighted=True)
        values = factor.compute(data)

        assert values.iloc[-1] > 0

    def test_crash_filter(self):
        """A sharp drop at the end zeroes the score when crash_filter is on."""
        # Normal random walk, then force consecutive large drops at the end.
        prices = _random_walk()
        prices[-3:] = prices[-4] * np.array([0.96, 0.93, 0.90])
        data = pd.DataFrame({'close': prices}, index=_dates())

        factor = MomentumFactor(n_days=25, crash_filter=True)
        values = factor.compute(data)

        # After the crash the momentum score should be cleared to zero.
        assert values.iloc[-1] == 0.0

    def test_simple_momentum(self):
        """Simple momentum (unweighted, no crash filter) covers every row."""
        prices = _random_walk()
        data = pd.DataFrame({'close': prices}, index=_dates())

        factor = MomentumFactor(n_days=25, weighted=False, crash_filter=False)
        values = factor.compute(data)

        # Simple momentum is intended to be the N-day return.
        # NOTE(review): only lengths are compared; value equality with
        # pct_change is not asserted here.
        expected = data['close'].pct_change(25)
        assert len(values) == len(expected)


class TestTrendFactor:
    """Tests for the trend factor."""

    def test_ma_cross(self):
        """A steady uptrend yields positive MA-cross trend strength."""
        prices = 100 + np.arange(N_DAYS) * 0.5
        data = pd.DataFrame({'close': prices}, index=_dates())

        factor = TrendFactor(method='ma_cross', fast=5, slow=20)
        values = factor.compute(data)

        assert values.iloc[-1] > 0

    def test_macd(self):
        """MACD trend output is aligned row-for-row with the input."""
        data = pd.DataFrame({'close': _random_walk()}, index=_dates())

        factor = TrendFactor(method='macd')
        values = factor.compute(data)

        assert len(values) == len(data)


class TestReversalFactor:
    """Tests for the reversal factor."""

    def test_rsi_reversal(self):
        """Overbought RSI (steady rise) gives a negative, sell-side signal."""
        prices = 100 + np.arange(N_DAYS) * 1.0
        data = pd.DataFrame({'close': prices}, index=_dates())

        factor = ReversalFactor(method='rsi', period=14, overbought=70)
        values = factor.compute(data)

        # RSI above 70 should produce a negative value (reversal down).
        assert values.iloc[-1] < 0

    def test_rsi_oversold(self):
        """Oversold RSI (steady fall) gives a positive, buy-side signal."""
        prices = 100 - np.arange(N_DAYS) * 1.0
        data = pd.DataFrame({'close': prices}, index=_dates())

        factor = ReversalFactor(method='rsi', period=14, oversold=30)
        values = factor.compute(data)

        # RSI below 30 should produce a positive value (reversal up).
        assert values.iloc[-1] > 0


class TestVolatilityFactor:
    """Tests for the volatility factor."""

    def test_std_volatility(self):
        """Standard-deviation volatility is aligned row-for-row with the input."""
        data = pd.DataFrame({'close': _random_walk()}, index=_dates())

        factor = VolatilityFactor(method='std', period=20)
        values = factor.compute(data)

        assert len(values) == len(data)

    def test_atr_volatility(self):
        """ATR volatility on consistent OHLC data is aligned with the input."""
        data = _ohlc_frame()

        factor = VolatilityFactor(method='atr', period=20)
        values = factor.compute(data)

        assert len(values) == len(data)


if __name__ == '__main__':
    pytest.main([__file__, '-v'])