Files
etf/archive/framework/tests/test_risk.py
aszerW c905230a40 refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer
referenced by the active core (datasource/ and rotation/):

- framework/ -> archive/framework/
- framework_v2/ -> archive/framework_v2/
- strategies/ -> archive/strategies/
- config/ -> archive/config/
- visualization/ -> archive/visualization/
- scripts/ -> archive/scripts/
- tests/ -> archive/tests/
- run_rotation.py, run_us_rotation.py -> archive/single_files/
- compare_*.py, test_api_dates.py -> archive/single_files/
2026-06-03 23:41:46 +08:00

325 lines
9.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
风控层测试
测试RiskControl、CallbackHook抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime
from framework.risk import RiskControl, CallbackHook, Position
from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl
class TestPosition:
    """Tests for the ``Position`` data structure."""

    def test_position_creation(self):
        """Constructor arguments are exposed as attributes."""
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
            quantity=10
        )
        assert position.code == 'AAPL'
        assert position.entry_price == 100.0
        assert position.current_price == 105.0

    def test_profit_ratio(self):
        """profit_ratio is the relative gain from entry to current price."""
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=110.0,
            entry_time=datetime.now()
        )
        # 100 -> 110 is +10%. Compare approximately: a formula such as
        # current / entry - 1 yields 0.10000000000000009, not exactly 0.1.
        assert position.profit_ratio == pytest.approx(0.10)

    def test_profit_amount(self):
        """profit_amount is the absolute gain across the whole quantity."""
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=110.0,
            entry_time=datetime.now(),
            quantity=10
        )
        # (110 - 100) * 10 shares = 100.
        assert position.profit_amount == pytest.approx(100.0)

    def test_position_repr(self):
        """repr() includes the code and the profit ratio as a percentage."""
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now()
        )
        repr_str = repr(position)
        assert 'AAPL' in repr_str
        # 100 -> 105 is +5%, expected to be rendered with two decimals.
        assert '5.00%' in repr_str
class TestStopLossControl:
    """Tests for ``StopLossControl`` (fixed and trailing stop loss)."""

    def test_stop_loss_init(self):
        """Constructor stores the threshold and exposes the control name."""
        control = StopLossControl(threshold=-0.05)
        assert control.threshold == -0.05
        assert control.name == "stop_loss"

    def test_stop_loss_check(self):
        """check() passes a profitable position and rejects one past the loss threshold."""
        control = StopLossControl(threshold=-0.05)

        # A profitable position (+10%) should pass.
        profit_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=110.0,
            entry_time=datetime.now()
        )
        assert control.check(profit_position)

        # A 10% loss is beyond the -5% threshold, so the stop should trigger.
        loss_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=90.0,
            entry_time=datetime.now()
        )
        assert not control.check(loss_position)

    def test_trailing_stop_loss(self):
        """A trailing stop triggers on a pullback from the running high."""
        control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03)
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now()
        )
        # Prime the control: the high starts at the entry price (100) and
        # should be updated to 105 by this call.
        control.check(position)

        # (105 - 103) / 105 ≈ 1.9% pullback, within the 3% trailing band -> pass.
        position.current_price = 103.0
        assert control.check(position)

        # (105 - 100) / 105 ≈ 4.8% pullback, beyond the 3% trailing band -> trigger.
        # Overall P&L here is 0%, so the fixed -5% stop cannot be the cause;
        # the trigger must come from the trailing rule.
        position.current_price = 100.0
        assert not control.check(position)
class TestPositionLimitControl:
    """Tests for ``PositionLimitControl`` (maximum portfolio weight per position)."""

    def test_position_limit_init(self):
        """Constructor stores the limit and exposes the control name."""
        control = PositionLimitControl(max_position=0.33)
        assert control.max_position == 0.33
        assert control.name == "position_limit"

    def test_position_limit_check(self):
        """check() passes a weight under the limit and rejects one above it."""
        control = PositionLimitControl(max_position=0.33)

        # Weight 0.20 is below the 0.33 cap -> pass.
        normal_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
            weight=0.20
        )
        assert control.check(normal_position)

        # Weight 0.50 exceeds the 0.33 cap -> reject.
        over_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now(),
            weight=0.50
        )
        assert not control.check(over_position)
class TestPremiumControl:
    """Tests for ``PremiumControl``."""

    def test_premium_control_init(self):
        """Constructor stores the threshold and exposes the control name."""
        control = PremiumControl(threshold=0.10)
        assert control.threshold == 0.10
        assert control.name == "premium"

    def test_premium_filter_mode(self):
        """In 'filter' mode, premiums above the threshold are rejected."""
        control = PremiumControl(threshold=0.10, mode='filter')
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now()
        )
        # A 5% premium is under the 10% threshold -> pass.
        assert control.check(position, premium=0.05)
        # A 15% premium is over the 10% threshold -> filtered out.
        assert not control.check(position, premium=0.15)
class TestCallbackHook:
    """Tests for ``CallbackHook`` registration, triggering and defaults."""

    def test_callback_hook_init(self):
        """A fresh hook reports six hook points."""
        hook = CallbackHook()
        # NOTE(review): magic count of hook names; update if hooks are added.
        assert len(hook.list_hooks()) == 6

    def test_register_callback(self):
        """A registered callback is returned by get_callbacks()."""
        hook = CallbackHook()

        def my_callback(code, price, **kwargs):
            return True

        hook.register('before_entry', my_callback)
        callbacks = hook.get_callbacks('before_entry')
        assert len(callbacks) == 1

    def test_trigger_before_entry(self):
        """before_entry is blocked if any registered callback returns False."""
        hook = CallbackHook()

        def always_pass(code, price, **kwargs):
            return True

        def always_block(code, price, **kwargs):
            return False

        hook.register('before_entry', always_pass)
        hook.register('before_entry', always_block)

        # before_entry requires every callback to return True to allow entry.
        result = hook.trigger('before_entry', 'AAPL', 100.0)
        assert not result

    def test_trigger_dynamic_stoploss(self):
        """dynamic_stoploss aggregates the values returned by its callbacks."""
        hook = CallbackHook()

        def stoploss_5p(position):
            return -0.05

        def stoploss_3p(position):
            return -0.03

        hook.register('dynamic_stoploss', stoploss_5p)
        hook.register('dynamic_stoploss', stoploss_3p)
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now()
        )
        # The expected result is the minimum of the callback values (-0.05).
        # NOTE(review): the original comment called the minimum the "strictest"
        # stop, but a -5% stop fires later than a -3% one, so min() is actually
        # the loosest. Confirm which aggregation CallbackHook is meant to use.
        result = hook.trigger('dynamic_stoploss', position)
        assert result == -0.05

    def test_trigger_custom_exit(self):
        """custom_exit fires if any registered callback returns True."""
        hook = CallbackHook()

        def exit_on_loss(position):
            return position.profit_ratio < -0.05

        def exit_on_profit(position):
            return position.profit_ratio > 0.20

        hook.register('custom_exit', exit_on_loss)
        hook.register('custom_exit', exit_on_profit)

        # +25% satisfies exit_on_profit, so the hook should request an exit.
        profit_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=125.0,
            entry_time=datetime.now()
        )
        result = hook.trigger('custom_exit', profit_position)
        assert result

        # +10% satisfies neither callback -> no exit.
        normal_position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=110.0,
            entry_time=datetime.now()
        )
        result = hook.trigger('custom_exit', normal_position)
        assert not result

    def test_clear_hooks(self):
        """clear(name) empties one hook; clear() empties all hooks."""
        hook = CallbackHook()

        def callback(code, price, **kwargs):
            return True

        hook.register('before_entry', callback)

        # Clear a single named hook.
        hook.clear('before_entry')
        assert len(hook.get_callbacks('before_entry')) == 0

        # Clear every hook at once.
        hook.register('before_entry', callback)
        hook.register('after_entry', callback)
        hook.clear()
        assert len(hook.get_callbacks('before_entry')) == 0
        assert len(hook.get_callbacks('after_entry')) == 0

    def test_default_behavior(self):
        """With no callbacks registered, each hook returns its neutral default."""
        hook = CallbackHook()

        # before_entry allows entry by default.
        result = hook.trigger('before_entry', 'AAPL', 100.0)
        assert result

        # dynamic_stoploss falls back to the supplied default_stoploss.
        result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05)
        assert result == -0.05

        # custom_exit does not request an exit by default.
        position = Position(
            code='AAPL',
            entry_price=100.0,
            current_price=105.0,
            entry_time=datetime.now()
        )
        result = hook.trigger('custom_exit', position)
        assert not result
if __name__ == '__main__':
    # pytest.main() returns an exit code. Propagate it so that running this
    # file directly reports failure to the shell/CI when tests fail.
    raise SystemExit(pytest.main([__file__, '-v']))