Files
etf/archive/framework/tests/test_factors.py
aszerW c905230a40 refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer
referenced by the active core (datasource/ and rotation/):

- framework/ -> archive/framework/
- framework_v2/ -> archive/framework_v2/
- strategies/ -> archive/strategies/
- config/ -> archive/config/
- visualization/ -> archive/visualization/
- scripts/ -> archive/scripts/
- tests/ -> archive/tests/
- run_rotation.py, run_us_rotation.py -> archive/single_files/
- compare_*.py, test_api_dates.py -> archive/single_files/
2026-06-03 23:41:46 +08:00

288 lines
9.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
因子层测试
测试FactorBase、FactorRegistry、FactorCombiner抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from strategies.shared.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
class TestFactorBase:
    """Tests for the FactorBase abstract base class, exercised through MomentumFactor."""

    def test_factor_meta(self):
        """Factor metadata (name and category) is exposed on the instance."""
        factor = MomentumFactor(n_days=25)
        assert factor.name == "momentum"
        assert factor.category == "momentum"

    def test_factor_repr(self):
        """repr() of a factor mentions its concrete class name."""
        factor = MomentumFactor(n_days=30)
        repr_str = repr(factor)
        assert "MomentumFactor" in repr_str

    def test_validate_data(self):
        """validate_data accepts a long history and rejects a short one."""
        # Fixed seed so the generated price paths are identical on every run.
        rng = np.random.default_rng(42)
        factor = MomentumFactor(n_days=25)

        # Sufficient data: 100 rows for an n_days=25 factor.
        data = pd.DataFrame({
            'close': rng.standard_normal(100).cumsum() + 100
        })
        # Truthiness instead of `== True` so numpy bools are accepted too.
        assert factor.validate_data(data)

        # Insufficient data: only 10 rows for an n_days=25 factor.
        short_data = pd.DataFrame({
            'close': rng.standard_normal(10).cumsum() + 100
        })
        assert not factor.validate_data(short_data)
class TestFactorRegistry:
    """Tests for FactorRegistry, which is used here via class-level calls."""

    def setup_method(self):
        """Start every test from an empty registry."""
        FactorRegistry.clear()

    def test_register_factor(self):
        """A registered factor class is listed under its name."""
        FactorRegistry.register(MomentumFactor)
        registered_names = FactorRegistry.list_factors()
        assert 'momentum' in registered_names

    def test_get_factor(self):
        """get() returns an instance of the registered class, forwarding kwargs."""
        FactorRegistry.register(MomentumFactor)
        instance = FactorRegistry.get('momentum', n_days=30)
        assert isinstance(instance, MomentumFactor)
        assert instance.n_days == 30

    def test_get_unknown_factor(self):
        """Looking up a name that was never registered raises ValueError."""
        with pytest.raises(ValueError):
            FactorRegistry.get('unknown_factor')

    def test_get_category(self):
        """get_category() reports the category of a registered factor."""
        FactorRegistry.register(MomentumFactor)
        looked_up = FactorRegistry.get_category('momentum')
        assert looked_up == 'momentum'
class TestFactorCombiner:
    """Tests for FactorCombiner, which merges several factors into a combined score."""

    def setup_method(self):
        """Start every test from an empty factor registry."""
        FactorRegistry.clear()

    @staticmethod
    def _make_ohlc(n=100, seed=0):
        """Build a reproducible close/high/low frame on a daily index.

        ``high`` is always above ``close`` and ``low`` always below it. Independent
        random walks per column could otherwise produce high < low.
        """
        rng = np.random.default_rng(seed)
        dates = pd.date_range('2020-01-01', periods=n)
        close = rng.standard_normal(n).cumsum() + 100
        # Strictly positive half-range, so low < close < high on every bar.
        spread = 1.0 + np.abs(rng.standard_normal(n))
        return pd.DataFrame({
            'close': close,
            'high': close + spread,
            'low': close - spread,
        }, index=dates)

    def test_combiner_init(self):
        """The combiner keeps one name per supplied factor."""
        factors = [
            MomentumFactor(n_days=25),
            TrendFactor(method='ma_cross'),
        ]
        combiner = FactorCombiner(factors, weights=[0.7, 0.3])
        assert len(combiner.get_factor_names()) == 2

    def test_combiner_equal_weights(self):
        """Without explicit weights, factors get equal, normalized weights."""
        factors = [
            MomentumFactor(n_days=25),
            TrendFactor(),
        ]
        combiner = FactorCombiner(factors)  # equal weights by default
        weights = list(combiner._weights)
        # Normalized weights should sum to 1. Compare with a tolerance, because
        # floating-point normalization need not hit exactly 1.0.
        assert sum(weights) == pytest.approx(1.0)
        # "Equal weights" means every weight has the same value.
        assert max(weights) == pytest.approx(min(weights))

    def test_combiner_compute(self):
        """compute() emits one column per factor plus a 'combined' column."""
        data = self._make_ohlc()
        factors = [
            MomentumFactor(n_days=20),
            TrendFactor(fast=5, slow=10),
        ]
        combiner = FactorCombiner(factors, weights=[0.6, 0.4])
        result = combiner.compute(data)
        # Check the result columns.
        assert 'momentum' in result.columns
        assert 'trend' in result.columns
        assert 'combined' in result.columns

    def test_combiner_method_rank_average(self):
        """The rank_average method still produces a 'combined' column."""
        # Close prices only, matching what this test has always supplied.
        data = self._make_ohlc()[['close']]
        factors = [
            MomentumFactor(n_days=20),
            TrendFactor(),
        ]
        combiner = FactorCombiner(factors, method='rank_average')
        result = combiner.compute(data)
        # 'combined' is expected to hold the rank average. NOTE(review): only the
        # column's presence is checked here, not its values.
        assert 'combined' in result.columns
class TestMomentumFactor:
    """Tests for MomentumFactor."""

    def test_momentum_compute(self):
        """A steadily rising price series yields a positive momentum score."""
        dates = pd.date_range('2020-01-01', periods=100)
        # Strictly increasing prices (an uptrend).
        prices = 100 + np.arange(100) * 0.5
        data = pd.DataFrame({'close': prices}, index=dates)
        factor = MomentumFactor(n_days=25, weighted=True)
        values = factor.compute(data)
        # An uptrend should produce a positive momentum score.
        assert values.iloc[-1] > 0

    def test_crash_filter(self):
        """With crash_filter enabled, a sharp multi-day drop zeroes the latest score."""
        dates = pd.date_range('2020-01-01', periods=100)
        # Seeded random walk, so the pre-crash history is reproducible.
        rng = np.random.default_rng(0)
        prices = 100 + rng.standard_normal(100).cumsum()
        # Overwrite the last three closes with a crash relative to the fourth-to-last
        # close: -4%, -7%, -10% cumulative (consecutive large drops).
        prices[-3:] = prices[-4] * np.array([0.96, 0.93, 0.90])
        data = pd.DataFrame({'close': prices}, index=dates)
        factor = MomentumFactor(n_days=25, crash_filter=True)
        values = factor.compute(data)
        # After the crash the momentum score should be zeroed.
        assert values.iloc[-1] == 0.0

    def test_simple_momentum(self):
        """Simple momentum (no weighting, no crash filter) aligns with the input length."""
        dates = pd.date_range('2020-01-01', periods=100)
        rng = np.random.default_rng(1)
        prices = 100 + rng.standard_normal(100).cumsum()
        data = pd.DataFrame({'close': prices}, index=dates)
        factor = MomentumFactor(n_days=25, weighted=False, crash_filter=False)
        values = factor.compute(data)
        # Simple momentum is expected to be the N-day return.
        expected = data['close'].pct_change(25)
        # NOTE(review): only lengths are compared. A value-level comparison against
        # `expected` would be stronger once the factor's exact formula is confirmed.
        assert len(values) == len(expected)
class TestTrendFactor:
    """Tests for TrendFactor's trend-strength methods."""

    @staticmethod
    def _close_frame(prices):
        """Wrap *prices* as a 'close' DataFrame on a daily index starting 2020-01-01."""
        index = pd.date_range('2020-01-01', periods=len(prices))
        return pd.DataFrame({'close': prices}, index=index)

    def test_ma_cross(self):
        """Moving-average crossover reports positive strength on an uptrend."""
        rising = 100 + np.arange(100) * 0.5
        frame = self._close_frame(rising)
        factor = TrendFactor(method='ma_cross', fast=5, slow=20)
        strength = factor.compute(frame)
        # A rising series should have positive trend strength.
        assert strength.iloc[-1] > 0

    def test_macd(self):
        """The MACD method returns one value per input row."""
        walk = 100 + np.random.randn(100).cumsum()
        frame = self._close_frame(walk)
        factor = TrendFactor(method='macd')
        result = factor.compute(frame)
        assert len(result) == len(frame)
class TestReversalFactor:
    """Tests for ReversalFactor's RSI-based reversal signal."""

    @staticmethod
    def _linear_close(step):
        """Return a 100-day 'close' frame that starts at 100 and moves *step* per day."""
        index = pd.date_range('2020-01-01', periods=100)
        return pd.DataFrame({'close': 100 + np.arange(100) * step}, index=index)

    def test_rsi_reversal(self):
        """Persistently rising prices (overbought) give a negative, downward signal."""
        data = self._linear_close(1.0)
        factor = ReversalFactor(method='rsi', period=14, overbought=70)
        signal = factor.compute(data)
        # RSI above 70 should produce a negative (reverse-down) signal.
        assert signal.iloc[-1] < 0

    def test_rsi_oversold(self):
        """Persistently falling prices (oversold) give a positive, upward signal."""
        data = self._linear_close(-1.0)
        factor = ReversalFactor(method='rsi', period=14, oversold=30)
        signal = factor.compute(data)
        # RSI below 30 should produce a positive (reverse-up) signal.
        assert signal.iloc[-1] > 0
class TestVolatilityFactor:
    """Tests for VolatilityFactor."""

    def test_std_volatility(self):
        """Standard-deviation volatility returns one value per input row."""
        dates = pd.date_range('2020-01-01', periods=100)
        # Seeded so the random walk is reproducible.
        rng = np.random.default_rng(0)
        prices = 100 + rng.standard_normal(100).cumsum()
        data = pd.DataFrame({'close': prices}, index=dates)
        factor = VolatilityFactor(method='std', period=20)
        values = factor.compute(data)
        assert len(values) == len(data)

    def test_atr_volatility(self):
        """ATR volatility returns one value per input row on consistent OHLC data."""
        dates = pd.date_range('2020-01-01', periods=100)
        rng = np.random.default_rng(1)
        close = rng.standard_normal(100).cumsum() + 100
        # Build high/low around close so that low < close < high on every bar.
        # Independent random walks per column can violate this and feed ATR
        # nonsensical bars.
        spread = 1.0 + np.abs(rng.standard_normal(100))
        data = pd.DataFrame({
            'close': close,
            'high': close + spread,
            'low': close - spread,
        }, index=dates)
        factor = VolatilityFactor(method='atr', period=20)
        values = factor.compute(data)
        assert len(values) == len(data)
if __name__ == '__main__':
    # pytest.main() returns an exit code. Hand it to SystemExit so that running
    # this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, '-v']))