- 测试文件改用strategies.shared的具体实现 - 新增framework_comparison_test.py对比新旧实现结果 - 因子计算相关系数达到1.0000,差异为0.000000 - 79个单元测试全部通过
325 lines
9.4 KiB
Python
325 lines
9.4 KiB
Python
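"""
Risk-control layer tests.

Covers the Position data structure, the concrete risk controls from
strategies.shared.risk.controls (StopLossControl, PositionLimitControl,
PremiumControl), and the CallbackHook callback interface.
"""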
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
import pytest
|
||
from datetime import datetime
|
||
|
||
from framework.risk import RiskControl, CallbackHook, Position
|
||
from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl
|
||
|
||
|
||
class TestPosition:
    """Tests for the Position data structure."""

    def test_position_creation(self):
        """A Position exposes the code and prices it was constructed with."""
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
            quantity=10,
        )

        assert position.code == 'AAPL'
        assert position.entry_price == 100.0
        assert position.current_price == 105.0

    def test_profit_ratio(self):
        """profit_ratio is the relative gain from entry: 100 -> 110 is 0.10."""
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=110.0,
            entry_time=datetime.now(),
        )

        # Compare with a tolerance: the ratio is a floating-point result, and an
        # algebraically equivalent formula (e.g. 110 / 100 - 1) produces
        # 0.10000000000000009, which an exact == check would reject.
        assert position.profit_ratio == pytest.approx(0.10)

    def test_profit_amount(self):
        """profit_amount is (current - entry) * quantity: (110 - 100) * 10."""
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=110.0,
            entry_time=datetime.now(),
            quantity=10,
        )

        assert position.profit_amount == pytest.approx(100.0)

    def test_position_repr(self):
        """repr() includes the code and the profit ratio formatted as a percent."""
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
        )

        repr_str = repr(position)
        assert 'AAPL' in repr_str
        # 100 -> 105 is a 5% gain, rendered with two decimals.
        assert '5.00%' in repr_str
|
||
|
||
|
||
class TestStopLossControl:
    """Tests for StopLossControl (fixed and trailing stop-loss)."""

    def test_stop_loss_init(self):
        """The threshold is stored as given and the control is named 'stop_loss'."""
        control = StopLossControl(threshold=-0.05)
        assert control.threshold == -0.05
        assert control.name == "stop_loss"

    def test_stop_loss_check(self):
        """check() passes a profitable position and fails one beyond the threshold."""
        control = StopLossControl(threshold=-0.05)

        # A position in profit should pass.
        profit_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=110.0,
            entry_time=datetime.now(),
        )
        assert control.check(profit_position)

        # A 10% loss is beyond the -5% threshold, so the stop must trigger.
        loss_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=90.0,
            entry_time=datetime.now(),
        )
        assert not control.check(loss_position)

    def test_trailing_stop_loss(self):
        """A trailing stop triggers on drawdown from the highest price seen."""
        control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03)

        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
        )

        # The high-water mark starts at the entry price (100); this first check
        # moves it up to 105. At the peak there is no drawdown, so it passes.
        assert control.check(position)

        # 105 -> 103 is a ~1.9% drawdown, within the 3% trailing limit.
        position.current_price = 103.0
        assert control.check(position)

        # 105 -> 100 is a ~4.8% drawdown, beyond the 3% trailing limit. The
        # position is flat versus entry, so the fixed -5% threshold is not what
        # fires here; only the trailing rule can.
        position.current_price = 100.0
        assert not control.check(position)
|
||
|
||
|
||
class TestPositionLimitControl:
    """Tests for PositionLimitControl (maximum position weight)."""

    def test_position_limit_init(self):
        """max_position is stored as given and the control is named 'position_limit'."""
        control = PositionLimitControl(max_position=0.33)
        assert control.max_position == 0.33
        assert control.name == "position_limit"

    def test_position_limit_check(self):
        """check() passes weights under the limit and fails weights above it."""
        control = PositionLimitControl(max_position=0.33)

        # 20% weight is under the 33% limit and should pass.
        normal_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
            weight=0.20,
        )
        assert control.check(normal_position)

        # 50% weight exceeds the 33% limit and should be rejected.
        over_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
            weight=0.50,
        )
        assert not control.check(over_position)
|
||
|
||
|
||
class TestPremiumControl:
    """Tests for PremiumControl (premium threshold)."""

    def test_premium_control_init(self):
        """The threshold is stored as given and the control is named 'premium'."""
        control = PremiumControl(threshold=0.10)
        assert control.threshold == 0.10
        assert control.name == "premium"

    def test_premium_filter_mode(self):
        """In 'filter' mode, check() rejects premiums above the threshold."""
        control = PremiumControl(threshold=0.10, mode='filter')

        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
        )

        # A 5% premium is under the 10% threshold and should pass.
        assert control.check(position, premium=0.05)

        # A 15% premium exceeds the threshold and should be filtered out.
        assert not control.check(position, premium=0.15)
|
||
|
||
|
||
class TestCallbackHook:
    """Tests for CallbackHook registration, triggering, and default behavior."""

    def test_callback_hook_init(self):
        """A fresh hook lists six hook points."""
        hook = CallbackHook()
        assert len(hook.list_hooks()) == 6

    def test_register_callback(self):
        """register() adds a callback retrievable through get_callbacks()."""
        hook = CallbackHook()

        def my_callback(code, price, **kwargs):
            return True

        hook.register('before_entry', my_callback)

        callbacks = hook.get_callbacks('before_entry')
        assert len(callbacks) == 1

    def test_trigger_before_entry(self):
        """before_entry allows entry only if every callback returns True."""
        hook = CallbackHook()

        def always_pass(code, price, **kwargs):
            return True

        def always_block(code, price, **kwargs):
            return False

        hook.register('before_entry', always_pass)
        hook.register('before_entry', always_block)

        # One blocking callback is enough to veto the entry.
        result = hook.trigger('before_entry', 'AAPL', 100.0)
        assert not result

    def test_trigger_dynamic_stoploss(self):
        """dynamic_stoploss combines the stop levels of all registered callbacks."""
        hook = CallbackHook()

        def stoploss_5p(position):
            return -0.05

        def stoploss_3p(position):
            return -0.03

        hook.register('dynamic_stoploss', stoploss_5p)
        hook.register('dynamic_stoploss', stoploss_3p)

        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
        )

        # The combined result is pinned to the numerically smallest value (-0.05).
        # NOTE(review): an earlier comment called this the "strictest" stop, but
        # -0.05 tolerates a larger loss than -0.03, so it is the loosest one.
        # Confirm which aggregation CallbackHook is intended to use.
        result = hook.trigger('dynamic_stoploss', position)
        assert result == -0.05

    def test_trigger_custom_exit(self):
        """custom_exit fires if any registered callback requests an exit."""
        hook = CallbackHook()

        def exit_on_loss(position):
            return position.profit_ratio < -0.05

        def exit_on_profit(position):
            return position.profit_ratio > 0.20

        hook.register('custom_exit', exit_on_loss)
        hook.register('custom_exit', exit_on_profit)

        # A 25% gain trips exit_on_profit, which is enough to exit.
        profit_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=125.0,
            entry_time=datetime.now(),
        )
        result = hook.trigger('custom_exit', profit_position)
        assert result

        # A 10% gain trips neither callback.
        normal_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=110.0,
            entry_time=datetime.now(),
        )
        result = hook.trigger('custom_exit', normal_position)
        assert not result

    def test_clear_hooks(self):
        """clear(name) empties one hook point; clear() empties all of them."""
        hook = CallbackHook()

        def callback(code, price, **kwargs):
            return True

        hook.register('before_entry', callback)

        # Clear a single hook point.
        hook.clear('before_entry')
        assert len(hook.get_callbacks('before_entry')) == 0

        # Clear every hook point.
        hook.register('before_entry', callback)
        hook.register('after_entry', callback)
        hook.clear()
        assert len(hook.get_callbacks('before_entry')) == 0
        assert len(hook.get_callbacks('after_entry')) == 0

    def test_default_behavior(self):
        """With no callbacks registered, each hook returns its neutral default."""
        hook = CallbackHook()

        # before_entry allows entry by default.
        result = hook.trigger('before_entry', 'AAPL', 100.0)
        assert result

        # dynamic_stoploss falls back to the supplied default_stoploss.
        result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05)
        assert result == -0.05

        # custom_exit does not request an exit by default.
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
        )
        result = hook.trigger('custom_exit', position)
        assert not result
|
||
|
||
|
||
if __name__ == '__main__':
    # Propagate pytest's exit status so that running this file directly reports
    # test failures to the shell/CI, instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, '-v']))