""" 风控层测试 测试RiskControl、CallbackHook抽象接口 """ import pandas as pd import numpy as np import pytest from datetime import datetime from framework.risk import RiskControl, CallbackHook, Position from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl class TestPosition: """测试Position数据结构""" def test_position_creation(self): """测试持仓创建""" position = Position( code='AAPL', entry_price=100.0, current_price=105.0, entry_time=datetime.now(), quantity=10 ) assert position.code == 'AAPL' assert position.entry_price == 100.0 assert position.current_price == 105.0 def test_profit_ratio(self): """测试盈亏比例计算""" position = Position( code='AAPL', entry_price=100.0, current_price=110.0, entry_time=datetime.now() ) assert position.profit_ratio == 0.10 def test_profit_amount(self): """测试盈亏金额计算""" position = Position( code='AAPL', entry_price=100.0, current_price=110.0, entry_time=datetime.now(), quantity=10 ) assert position.profit_amount == 100.0 def test_position_repr(self): """测试持仓字符串表示""" position = Position( code='AAPL', entry_price=100.0, current_price=105.0, entry_time=datetime.now() ) repr_str = repr(position) assert 'AAPL' in repr_str assert '5.00%' in repr_str class TestStopLossControl: """测试止损控制""" def test_stop_loss_init(self): """测试初始化""" control = StopLossControl(threshold=-0.05) assert control.threshold == -0.05 assert control.name == "stop_loss" def test_stop_loss_check(self): """测试止损检查""" control = StopLossControl(threshold=-0.05) # 盈利持仓应该通过 profit_position = Position( code='AAPL', entry_price=100.0, current_price=110.0, entry_time=datetime.now() ) assert control.check(profit_position) == True # 亏损持仓应该触发止损 loss_position = Position( code='AAPL', entry_price=100.0, current_price=90.0, # 亏损10% entry_time=datetime.now() ) # 亏损超过阈值应该触发 assert control.check(loss_position) == False def test_trailing_stop_loss(self): """测试跟踪止损""" control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03) position = Position( code='AAPL', entry_price=100.0, current_price=105.0, entry_time=datetime.now() ) # 初始最高价=入场价100,当前价105 # 更新后最高价=105 control.check(position) # 从最高价回撤不超过阈值应该通过 position.current_price = 103.0 # 回撤约2% assert control.check(position) == True # 回撤超过阈值应该触发 position.current_price = 100.0 # 回撤约5% assert control.check(position) == False class TestPositionLimitControl: """测试仓位限制控制""" def test_position_limit_init(self): """测试初始化""" control = PositionLimitControl(max_position=0.33) assert control.max_position == 0.33 assert control.name == "position_limit" def test_position_limit_check(self): """测试仓位限制检查""" control = PositionLimitControl(max_position=0.33) # 正常仓位应该通过 normal_position = Position( code='AAPL', entry_price=100.0, current_price=105.0, entry_time=datetime.now(), weight=0.20 ) assert control.check(normal_position) == True # 超限仓位应该触发 over_position = Position( code='AAPL', entry_price=100.0, current_price=105.0, entry_time=datetime.now(), weight=0.50 ) assert control.check(over_position) == False class TestPremiumControl: """测试溢价控制""" def test_premium_control_init(self): """测试初始化""" control = PremiumControl(threshold=0.10) assert control.threshold == 0.10 assert control.name == "premium" def test_premium_filter_mode(self): """测试溢价过滤模式""" control = PremiumControl(threshold=0.10, mode='filter') position = Position( code='AAPL', entry_price=100.0, current_price=105.0, entry_time=datetime.now() ) # 正常溢价应该通过 assert control.check(position, premium=0.05) == True # 高溢价应该被过滤 assert control.check(position, premium=0.15) == False class TestCallbackHook: """测试回调钩子""" def test_callback_hook_init(self): """测试初始化""" hook = CallbackHook() assert len(hook.list_hooks()) == 6 def test_register_callback(self): """测试注册回调""" hook = CallbackHook() def my_callback(code, price, **kwargs): return True hook.register('before_entry', my_callback) callbacks = hook.get_callbacks('before_entry') assert len(callbacks) == 1 def test_trigger_before_entry(self): """测试触发入场前回调""" hook = CallbackHook() def always_pass(code, price, **kwargs): return True def always_block(code, price, **kwargs): return False hook.register('before_entry', always_pass) hook.register('before_entry', always_block) # before_entry要求所有回调返回True才允许 result = hook.trigger('before_entry', 'AAPL', 100.0) assert result == False def test_trigger_dynamic_stoploss(self): """测试触发动态止损回调""" hook = CallbackHook() def stoploss_5p(position): return -0.05 def stoploss_3p(position): return -0.03 hook.register('dynamic_stoploss', stoploss_5p) hook.register('dynamic_stoploss', stoploss_3p) position = Position( code='AAPL', entry_price=100.0, current_price=105.0, entry_time=datetime.now() ) # dynamic_stoploss返回最小止损值(最严格) result = hook.trigger('dynamic_stoploss', position) assert result == -0.05 def test_trigger_custom_exit(self): """测试触发自定义出场回调""" hook = CallbackHook() def exit_on_loss(position): return position.profit_ratio < -0.05 def exit_on_profit(position): return position.profit_ratio > 0.20 hook.register('custom_exit', exit_on_loss) hook.register('custom_exit', exit_on_profit) # custom_exit任一回调触发即可 profit_position = Position( code='AAPL', entry_price=100.0, current_price=125.0, # 盈利25% entry_time=datetime.now() ) result = hook.trigger('custom_exit', profit_position) assert result == True # 未触发 normal_position = Position( code='AAPL', entry_price=100.0, current_price=110.0, entry_time=datetime.now() ) result = hook.trigger('custom_exit', normal_position) assert result == False def test_clear_hooks(self): """测试清空回调""" hook = CallbackHook() def callback(code, price, **kwargs): return True hook.register('before_entry', callback) # 清空特定钩子 hook.clear('before_entry') assert len(hook.get_callbacks('before_entry')) == 0 # 清空所有钩子 hook.register('before_entry', callback) hook.register('after_entry', callback) hook.clear() assert len(hook.get_callbacks('before_entry')) == 0 assert len(hook.get_callbacks('after_entry')) == 0 def test_default_behavior(self): """测试默认行为""" hook = CallbackHook() # 无注册回调 # before_entry默认允许 result = hook.trigger('before_entry', 'AAPL', 100.0) assert result == True # dynamic_stoploss默认值 result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05) assert result == -0.05 # custom_exit默认不出场 position = Position( code='AAPL', entry_price=100.0, current_price=105.0, entry_time=datetime.now() ) result = hook.trigger('custom_exit', position) assert result == False if __name__ == '__main__': pytest.main([__file__, '-v'])