test: 更新测试以验证框架重构正确性

- 测试文件改用strategies.shared的具体实现
- 新增framework_comparison_test.py对比新旧实现结果
- 因子计算相关系数达到1.0000,差异为0.000000
- 79个单元测试全部通过
This commit is contained in:
2026-05-11 23:10:02 +08:00
parent de31271ab3
commit fc59836ec3
7 changed files with 1066 additions and 618 deletions

View File

@@ -1,100 +1,165 @@
"""
框架核心测试
集成测试
验证框架整体功能
测试框架与定制组件的集成
"""
import pandas as pd
import numpy as np
import pytest
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
from framework.signals import TopNSelector, TrendFollower, ReversalTrader
from framework.risk import StopLossControl, CallbackHook, Position, premium_filter_callback
from framework.factors import FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import CallbackHook, Position
from framework.strategy import StrategyBase
from framework.execution import Portfolio, BacktestExecutor
from strategies.shared.factors.momentum import MomentumFactor, VolatilityFactor
from strategies.shared.signals.selectors import TopNSelector
from strategies.shared.risk.controls import StopLossControl, premium_filter_callback
from strategies.rotation.strategy import RotationStrategy
class TestFrameworkIntegration:
"""测试框架集成"""
class TestFactorIntegration:
"""测试因子集成"""
def test_rotation_strategy_workflow(self):
"""测试轮动策略完整流程"""
# 清空注册表
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_register_and_use_custom_factor(self):
"""测试注册并使用定制因子"""
FactorRegistry.register(MomentumFactor)
# 1. 注册因子
factor = FactorRegistry.get('momentum', n_days=25, crash_filter=True)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
values = factor.compute(data)
assert len(values) == len(data)
def test_combiner_with_custom_factors(self):
"""测试组合器使用定制因子"""
FactorRegistry.register(MomentumFactor)
FactorRegistry.register(VolatilityFactor)
# 2. 创建因子组合
factors = FactorCombiner([
FactorRegistry.get('momentum', n_days=25, crash_filter=True),
], weights=[1.0])
# 3. 生成测试数据(需要'close'列)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'code1': np.random.randn(100).cumsum() + 100,
'code2': np.random.randn(100).cumsum() + 100,
'code3': np.random.randn(100).cumsum() + 100,
}, index=dates)
# 4. 计算因子
factor_result = factors.compute(data)
# 5. 生成信号
selector = TopNSelector(select_num=2)
signals = selector.generate(factor_result)
# 6. 验证结果
assert 'signal' in signals.columns
assert len(signals) == len(data)
def test_trend_strategy_workflow(self):
"""测试趋势策略完整流程"""
FactorRegistry.clear()
FactorRegistry.register(TrendFactor)
factors = FactorCombiner([
FactorRegistry.get('trend', method='ma_cross', fast=5, slow=20),
])
momentum = FactorRegistry.get('momentum', n_days=25)
volatility = FactorRegistry.get('volatility', method='std', period=20)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'high': np.random.randn(100).cumsum() + 105,
'low': np.random.randn(100).cumsum() + 95
}, index=dates)
factor_result = factors.compute(data)
follower = TrendFollower(entry_threshold=0.02)
signals = follower.generate(factor_result)
combiner = FactorCombiner([momentum, volatility], weights=[0.7, 0.3])
result = combiner.compute(data)
assert 'signal' in signals.columns
assert 'momentum' in result.columns
assert 'volatility' in result.columns
assert 'combined' in result.columns
class TestSignalIntegration:
"""测试信号生成器集成"""
def test_callbacks_workflow(self):
"""测试回调钩子完整流程"""
def test_custom_signal_generator_with_factors(self):
"""测试定制信号生成器与因子集成"""
dates = pd.date_range('2020-01-01', periods=50)
factor_data = pd.DataFrame({
'momentum_A': np.random.randn(50),
'momentum_B': np.random.randn(50),
'momentum_C': np.random.randn(50),
}, index=dates)
selector = TopNSelector(select_num=2, min_score=0.0)
result = selector.generate(factor_data)
assert 'signal' in result.columns
class TestRiskIntegration:
"""测试风控组件集成"""
def test_custom_risk_control(self):
"""测试定制风控组件"""
control = StopLossControl(threshold=-0.05, trailing=True)
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=pd.Timestamp.now()
)
# 首次检查,设置最高价
assert control.check(position) == True
# 价格下跌,触发跟踪止损
position.current_price = 100.0
assert control.check(position) == False
def test_callback_hook_with_custom_callback(self):
"""测试回调钩子与定制回调集成"""
hook = CallbackHook()
# 注册回调
hook.register('before_entry', premium_filter_callback(0.10))
hook.register('dynamic_stoploss', lambda pos: -0.05)
# 注册定制回调
callback = premium_filter_callback(threshold=0.10)
hook.register('before_entry', callback)
# 测试入场前回调
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
# 正常溢价通过
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.05)
assert result == True
# 测试动态止损
position = Position(
code='code1',
entry_price=100.0,
entry_date=pd.Timestamp('2020-01-01'),
current_price=95.0,
current_date=pd.Timestamp('2020-01-10'),
quantity=100,
weight=0.33
)
stoploss = hook.trigger('dynamic_stoploss', position)
assert stoploss == -0.05
# 高溢价拒绝
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.15)
assert result == False
class TestStrategyIntegration:
"""测试策略集成"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_rotation_strategy_full_flow(self):
"""测试轮动策略完整流程"""
strategy = RotationStrategy()
# 生成测试数据
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
# 运行策略
result = strategy.run(data)
assert 'signal' in result.columns
def test_strategy_with_backtest_executor(self):
"""测试策略与回测执行器集成"""
FactorRegistry.clear()
strategy = RotationStrategy()
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
signals = strategy.run(data)
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
portfolio = executor.execute(signals, data)
assert isinstance(portfolio, Portfolio)
if __name__ == '__main__':