"""Strategy-layer tests.

Exercises the StrategyBase abstract interface through its concrete
``RotationStrategy`` implementation, including that strategy's entry/exit
hooks and its callback registry.
"""
from datetime import datetime

import numpy as np
import pandas as pd
import pytest

from framework.strategy import StrategyBase
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import Position


def _make_rotation_strategy(config=None, *, reset_factors=False):
    """Construct a ``RotationStrategy`` for a single test.

    Args:
        config: Optional config dict forwarded as ``config=``; when ``None``
            the strategy is constructed with no arguments.
        reset_factors: If true, call ``FactorRegistry.clear()`` before
            construction (presumably resetting registry state shared with
            other tests).

    Returns:
        A freshly constructed ``RotationStrategy``.

    The strategy class is imported lazily, so an ImportError from
    ``strategies.rotation`` is raised inside the calling test rather than
    when this module is imported.
    """
    from strategies.rotation.strategy import RotationStrategy

    if reset_factors:
        FactorRegistry.clear()
    if config is None:
        return RotationStrategy()
    return RotationStrategy(config=config)


def _make_price_data(periods: int = 100, seed: int = 0) -> pd.DataFrame:
    """Return a daily random-walk ``close`` series starting at 2020-01-01.

    The series is ``standard_normal(periods).cumsum() + 100``, drawn from a
    seeded generator so every run (and every failure) sees identical data.
    """
    rng = np.random.default_rng(seed)
    dates = pd.date_range('2020-01-01', periods=periods)
    closes = rng.standard_normal(periods).cumsum() + 100
    return pd.DataFrame({'close': closes}, index=dates)


class TestStrategyBase:
    """Tests for the StrategyBase abstract base class (via RotationStrategy)."""

    def test_strategy_config_override(self):
        """Values passed via ``config`` override the class-attribute defaults."""
        strategy = _make_rotation_strategy({'select_num': 5, 'stoploss': -0.03})
        assert strategy.select_num == 5
        assert strategy.stoploss == -0.03

    def test_strategy_default_values(self):
        """Without a config, the class defaults apply."""
        strategy = _make_rotation_strategy()
        assert strategy.select_num == 3
        assert strategy.stoploss == -0.05

    def test_strategy_repr(self):
        """repr() contains both the class name and the substring 'rotation'."""
        repr_str = repr(_make_rotation_strategy())
        assert 'RotationStrategy' in repr_str
        assert 'rotation' in repr_str

    def test_strategy_interface_version(self):
        """The strategy reports interface version 1."""
        assert _make_rotation_strategy().INTERFACE_VERSION == 1


class TestRotationStrategy:
    """Tests for the rotation strategy's own behaviour."""

    def test_rotation_strategy_init(self):
        """Construction sets up the factors and the signal generator."""
        strategy = _make_rotation_strategy(reset_factors=True)
        assert strategy._factors is not None
        assert strategy._signal_gen is not None

    def test_rotation_strategy_run(self):
        """run() returns a frame carrying a 'signal' column."""
        strategy = _make_rotation_strategy(reset_factors=True)
        result = strategy.run(_make_price_data())
        assert 'signal' in result.columns

    def test_dynamic_stoploss(self):
        """A position held for 5 days gets a -0.05 stop-loss threshold."""
        strategy = _make_rotation_strategy(reset_factors=True)
        position_5days = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=95.0,
            entry_time=datetime.now() - pd.Timedelta(days=5),
        )
        # The threshold is a float; compare with a tolerance, not exactly.
        assert strategy.dynamic_stoploss(position_5days) == pytest.approx(-0.05)

    def test_before_entry_premium_filter(self):
        """before_entry accepts a 5% premium and rejects a 15% premium."""
        strategy = _make_rotation_strategy(reset_factors=True)
        assert strategy.before_entry('AAPL', 100.0, premium=0.05)
        assert not strategy.before_entry('AAPL', 100.0, premium=0.15)

    def test_custom_exit(self):
        """custom_exit ignores a -5% move but fires on a -15% move."""
        strategy = _make_rotation_strategy(reset_factors=True)

        small_loss = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=95.0,
            entry_time=datetime.now(),
        )
        assert not strategy.custom_exit(small_loss)

        large_loss = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=85.0,
            entry_time=datetime.now(),
        )
        assert strategy.custom_exit(large_loss)


class TestStrategyCallbacks:
    """Tests for the strategy callback mechanism."""

    def test_callback_registration(self):
        """Hook callbacks are registered automatically on construction."""
        strategy = _make_rotation_strategy(reset_factors=True)
        assert len(strategy._callbacks.get_callbacks('before_entry')) > 0
        assert len(strategy._callbacks.get_callbacks('dynamic_stoploss')) > 0

    def test_callback_trigger_in_run(self):
        """A user-registered before_entry callback is invoked during run()."""
        strategy = _make_rotation_strategy(reset_factors=True)
        calls = []

        def counting_callback(code, price, **kwargs):
            calls.append((code, price))
            # True mirrors before_entry's "accept" value, so this callback
            # never blocks an entry on its own.
            return True

        strategy._callbacks.register('before_entry', counting_callback)
        result = strategy.run(_make_price_data())

        assert 'signal' in result.columns
        # Without this check the test only proved that run() does not crash.
        assert calls, 'before_entry callback was never invoked during run()'


if __name__ == '__main__':
    pytest.main([__file__, '-v'])