feat(factors): 实现因子层抽象
核心组件: - FactorBase: 因子抽象基类(compute方法 + 数据验证) - FactorRegistry: 因子注册器(注册/获取/按类别筛选) - FactorCombiner: 因子组合器(加权组合4种方法) 已实现因子: - MomentumFactor: 加权动量因子(含崩盘过滤) - TrendFactor: 趋势因子(MA交叉/MACD) - ReversalFactor: 反转因子(RSI/KDJ) - VolatilityFactor: 波动率因子(ATR/标准差) 测试覆盖:18个测试全部通过
This commit is contained in:
282
framework/tests/test_factors.py
Normal file
282
framework/tests/test_factors.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
因子层测试
|
||||
|
||||
测试FactorBase、FactorRegistry、FactorCombiner
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
|
||||
|
||||
|
||||
class TestFactorBase:
|
||||
"""测试FactorBase抽象基类"""
|
||||
|
||||
def test_factor_meta(self):
|
||||
"""测试因子元信息"""
|
||||
factor = MomentumFactor(n_days=25)
|
||||
assert factor.name == "momentum"
|
||||
assert factor.category == "momentum"
|
||||
assert factor.params == {'n_days': 25, 'weighted': True, 'crash_filter': True}
|
||||
|
||||
def test_factor_repr(self):
|
||||
"""测试因子字符串表示"""
|
||||
factor = MomentumFactor(n_days=30)
|
||||
repr_str = repr(factor)
|
||||
assert "MomentumFactor" in repr_str
|
||||
assert "momentum" in repr_str
|
||||
|
||||
def test_validate_data(self):
|
||||
"""测试数据验证"""
|
||||
factor = MomentumFactor(n_days=25)
|
||||
|
||||
# 数据充足
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
})
|
||||
assert factor.validate_data(data) == True
|
||||
|
||||
# 数据不足
|
||||
short_data = pd.DataFrame({
|
||||
'close': np.random.randn(10).cumsum() + 100
|
||||
})
|
||||
assert factor.validate_data(short_data) == False
|
||||
|
||||
|
||||
class TestFactorRegistry:
|
||||
"""测试因子注册器"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_register_factor(self):
|
||||
"""测试因子注册"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
assert 'momentum' in FactorRegistry.list()
|
||||
|
||||
def test_get_factor(self):
|
||||
"""测试获取因子实例"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
factor = FactorRegistry.get('momentum', n_days=30)
|
||||
assert isinstance(factor, MomentumFactor)
|
||||
assert factor.n_days == 30
|
||||
|
||||
def test_get_unknown_factor(self):
|
||||
"""测试获取未注册因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
with pytest.raises(KeyError):
|
||||
FactorRegistry.get('unknown_factor')
|
||||
|
||||
def test_list_by_category(self):
|
||||
"""测试按类别列出因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
FactorRegistry.register(TrendFactor)
|
||||
FactorRegistry.register(ReversalFactor)
|
||||
|
||||
categories = FactorRegistry.list_by_category()
|
||||
assert 'momentum' in categories
|
||||
assert 'trend' in categories
|
||||
assert 'reversal' in categories
|
||||
|
||||
def test_register_invalid_factor(self):
|
||||
"""测试注册无效因子"""
|
||||
with pytest.raises(TypeError):
|
||||
FactorRegistry.register(str) # 不是FactorBase子类
|
||||
|
||||
|
||||
class TestFactorCombiner:
|
||||
"""测试因子组合器"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_combiner_init(self):
|
||||
"""测试组合器初始化"""
|
||||
factors = [
|
||||
MomentumFactor(n_days=25),
|
||||
TrendFactor(method='ma_cross')
|
||||
]
|
||||
combiner = FactorCombiner(factors, weights=[0.7, 0.3])
|
||||
|
||||
assert len(combiner.factors) == 2
|
||||
assert combiner.weights == [0.7, 0.3] # 未归一化时
|
||||
|
||||
def test_combiner_equal_weights(self):
|
||||
"""测试等权组合"""
|
||||
factors = [
|
||||
MomentumFactor(n_days=25),
|
||||
TrendFactor()
|
||||
]
|
||||
combiner = FactorCombiner(factors) # 默认等权
|
||||
|
||||
# 权重应该归一化
|
||||
assert sum(combiner.weights) == 1.0
|
||||
|
||||
def test_combiner_compute(self):
|
||||
"""测试因子组合计算"""
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
factors = [
|
||||
MomentumFactor(n_days=20),
|
||||
TrendFactor(fast=5, slow=10)
|
||||
]
|
||||
combiner = FactorCombiner(factors, weights=[0.6, 0.4])
|
||||
|
||||
result = combiner.compute(data)
|
||||
|
||||
# 检查结果列
|
||||
assert 'momentum' in result.columns
|
||||
assert 'trend' in result.columns
|
||||
assert 'combined' in result.columns
|
||||
|
||||
# 检查加权列
|
||||
assert 'momentum_weighted' in result.columns
|
||||
assert 'trend_weighted' in result.columns
|
||||
|
||||
def test_combiner_method_max(self):
|
||||
"""测试max组合方法"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
factors = [
|
||||
MomentumFactor(n_days=20),
|
||||
TrendFactor()
|
||||
]
|
||||
combiner = FactorCombiner(factors, method='max')
|
||||
|
||||
result = combiner.compute(data)
|
||||
|
||||
# combined应该是momentum和trend的最大值
|
||||
factor_cols = ['momentum', 'trend']
|
||||
expected_max = result[factor_cols].max(axis=1)
|
||||
pd.testing.assert_series_equal(result['combined'], expected_max, check_names=False)
|
||||
|
||||
|
||||
class TestMomentumFactor:
|
||||
"""测试动量因子"""
|
||||
|
||||
def test_momentum_compute(self):
|
||||
"""测试动量因子计算"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成上升趋势数据
|
||||
prices = 100 + np.arange(100) * 0.5
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, weighted=True)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 上升趋势应该有正的动量得分
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
def test_crash_filter(self):
|
||||
"""测试崩盘过滤"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成正常数据,然后在末尾添加崩盘
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
prices[-3:] = prices[-4] * np.array([0.96, 0.93, 0.90]) # 连续大跌
|
||||
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, crash_filter=True)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 崩盘后动量得分应该被清零
|
||||
assert values.iloc[-1] == 0.0
|
||||
|
||||
def test_simple_momentum(self):
|
||||
"""测试简单动量(无加权,无崩盘过滤)"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, weighted=False, crash_filter=False)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 简单动量应该是N日涨幅(无崩盘过滤时)
|
||||
expected = data['close'].pct_change(25)
|
||||
# 验证前25个值都是NaN
|
||||
assert values.iloc[:25].isna().all()
|
||||
# 验证后续值大致正确
|
||||
assert len(values) == len(expected)
|
||||
|
||||
|
||||
class TestTrendFactor:
|
||||
"""测试趋势因子"""
|
||||
|
||||
def test_ma_cross(self):
|
||||
"""测试MA交叉趋势"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成上升趋势
|
||||
prices = 100 + np.arange(100) * 0.5
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = TrendFactor(method='ma_cross', fast=5, slow=20)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 上升趋势应该有正的趋势强度
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
def test_macd(self):
|
||||
"""测试MACD趋势"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = TrendFactor(method='macd')
|
||||
values = factor.compute(data)
|
||||
|
||||
# 检查计算结果
|
||||
assert len(values) == len(data)
|
||||
assert not values.iloc[:26].isna().all() # MACD应该有值
|
||||
|
||||
|
||||
class TestReversalFactor:
|
||||
"""测试反转因子"""
|
||||
|
||||
def test_rsi_reversal(self):
|
||||
"""测试RSI反转信号"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成超买数据(持续上涨)
|
||||
prices = 100 + np.arange(100) * 1.0
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = ReversalFactor(method='rsi', period=14, overbought=70)
|
||||
values = factor.compute(data)
|
||||
|
||||
# RSI超过70应该产生负值(反转向下信号)
|
||||
assert values.iloc[-1] < 0
|
||||
|
||||
def test_rsi_oversold(self):
|
||||
"""测试RSI超卖信号"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成超卖数据(持续下跌)
|
||||
prices = 100 - np.arange(100) * 1.0
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = ReversalFactor(method='rsi', period=14, oversold=30)
|
||||
values = factor.compute(data)
|
||||
|
||||
# RSI低于30应该产生正值(反转向上信号)
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user