test: 更新测试以验证框架重构正确性
- 测试文件改用strategies.shared的具体实现 - 新增framework_comparison_test.py对比新旧实现结果 - 因子计算相关系数达到1.0000,差异为0.000000 - 79个单元测试全部通过
This commit is contained in:
@@ -1,100 +1,165 @@
|
||||
"""
|
||||
框架核心测试
|
||||
集成测试
|
||||
|
||||
验证框架整体功能
|
||||
测试框架与定制组件的集成
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
|
||||
from framework.signals import TopNSelector, TrendFollower, ReversalTrader
|
||||
from framework.risk import StopLossControl, CallbackHook, Position, premium_filter_callback
|
||||
from framework.factors import FactorRegistry, FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.risk import CallbackHook, Position
|
||||
from framework.strategy import StrategyBase
|
||||
from framework.execution import Portfolio, BacktestExecutor
|
||||
|
||||
from strategies.shared.factors.momentum import MomentumFactor, VolatilityFactor
|
||||
from strategies.shared.signals.selectors import TopNSelector
|
||||
from strategies.shared.risk.controls import StopLossControl, premium_filter_callback
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
|
||||
class TestFrameworkIntegration:
|
||||
"""测试框架集成"""
|
||||
class TestFactorIntegration:
|
||||
"""测试因子集成"""
|
||||
|
||||
def test_rotation_strategy_workflow(self):
|
||||
"""测试轮动策略完整流程"""
|
||||
# 清空注册表
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_register_and_use_custom_factor(self):
|
||||
"""测试注册并使用定制因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
|
||||
# 1. 注册因子
|
||||
factor = FactorRegistry.get('momentum', n_days=25, crash_filter=True)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
values = factor.compute(data)
|
||||
|
||||
assert len(values) == len(data)
|
||||
|
||||
def test_combiner_with_custom_factors(self):
|
||||
"""测试组合器使用定制因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
FactorRegistry.register(VolatilityFactor)
|
||||
|
||||
# 2. 创建因子组合
|
||||
factors = FactorCombiner([
|
||||
FactorRegistry.get('momentum', n_days=25, crash_filter=True),
|
||||
], weights=[1.0])
|
||||
|
||||
# 3. 生成测试数据(需要'close'列)
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'code1': np.random.randn(100).cumsum() + 100,
|
||||
'code2': np.random.randn(100).cumsum() + 100,
|
||||
'code3': np.random.randn(100).cumsum() + 100,
|
||||
}, index=dates)
|
||||
|
||||
# 4. 计算因子
|
||||
factor_result = factors.compute(data)
|
||||
|
||||
# 5. 生成信号
|
||||
selector = TopNSelector(select_num=2)
|
||||
signals = selector.generate(factor_result)
|
||||
|
||||
# 6. 验证结果
|
||||
assert 'signal' in signals.columns
|
||||
assert len(signals) == len(data)
|
||||
|
||||
def test_trend_strategy_workflow(self):
|
||||
"""测试趋势策略完整流程"""
|
||||
FactorRegistry.clear()
|
||||
FactorRegistry.register(TrendFactor)
|
||||
|
||||
factors = FactorCombiner([
|
||||
FactorRegistry.get('trend', method='ma_cross', fast=5, slow=20),
|
||||
])
|
||||
momentum = FactorRegistry.get('momentum', n_days=25)
|
||||
volatility = FactorRegistry.get('volatility', method='std', period=20)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
factor_result = factors.compute(data)
|
||||
follower = TrendFollower(entry_threshold=0.02)
|
||||
signals = follower.generate(factor_result)
|
||||
combiner = FactorCombiner([momentum, volatility], weights=[0.7, 0.3])
|
||||
result = combiner.compute(data)
|
||||
|
||||
assert 'signal' in signals.columns
|
||||
assert 'momentum' in result.columns
|
||||
assert 'volatility' in result.columns
|
||||
assert 'combined' in result.columns
|
||||
|
||||
|
||||
class TestSignalIntegration:
|
||||
"""测试信号生成器集成"""
|
||||
|
||||
def test_callbacks_workflow(self):
|
||||
"""测试回调钩子完整流程"""
|
||||
def test_custom_signal_generator_with_factors(self):
|
||||
"""测试定制信号生成器与因子集成"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
factor_data = pd.DataFrame({
|
||||
'momentum_A': np.random.randn(50),
|
||||
'momentum_B': np.random.randn(50),
|
||||
'momentum_C': np.random.randn(50),
|
||||
}, index=dates)
|
||||
|
||||
selector = TopNSelector(select_num=2, min_score=0.0)
|
||||
result = selector.generate(factor_data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestRiskIntegration:
|
||||
"""测试风控组件集成"""
|
||||
|
||||
def test_custom_risk_control(self):
|
||||
"""测试定制风控组件"""
|
||||
control = StopLossControl(threshold=-0.05, trailing=True)
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=pd.Timestamp.now()
|
||||
)
|
||||
|
||||
# 首次检查,设置最高价
|
||||
assert control.check(position) == True
|
||||
|
||||
# 价格下跌,触发跟踪止损
|
||||
position.current_price = 100.0
|
||||
assert control.check(position) == False
|
||||
|
||||
def test_callback_hook_with_custom_callback(self):
|
||||
"""测试回调钩子与定制回调集成"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册回调
|
||||
hook.register('before_entry', premium_filter_callback(0.10))
|
||||
hook.register('dynamic_stoploss', lambda pos: -0.05)
|
||||
# 注册定制回调
|
||||
callback = premium_filter_callback(threshold=0.10)
|
||||
hook.register('before_entry', callback)
|
||||
|
||||
# 测试入场前回调
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||
# 正常溢价通过
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 测试动态止损
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=pd.Timestamp('2020-01-01'),
|
||||
current_price=95.0,
|
||||
current_date=pd.Timestamp('2020-01-10'),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
stoploss = hook.trigger('dynamic_stoploss', position)
|
||||
assert stoploss == -0.05
|
||||
# 高溢价拒绝
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
|
||||
class TestStrategyIntegration:
|
||||
"""测试策略集成"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_rotation_strategy_full_flow(self):
|
||||
"""测试轮动策略完整流程"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
# 运行策略
|
||||
result = strategy.run(data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
|
||||
def test_strategy_with_backtest_executor(self):
|
||||
"""测试策略与回测执行器集成"""
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
signals = strategy.run(data)
|
||||
|
||||
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
|
||||
portfolio = executor.execute(signals, data)
|
||||
|
||||
assert isinstance(portfolio, Portfolio)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user