test: 更新测试以验证框架重构正确性
- 测试文件改用strategies.shared的具体实现 - 新增framework_comparison_test.py对比新旧实现结果 - 因子计算相关系数达到1.0000,差异为0.000000 - 79个单元测试全部通过
This commit is contained in:
@@ -1,291 +1,323 @@
|
||||
"""
|
||||
风控层测试
|
||||
|
||||
测试RiskControl、StopLossControl、PositionLimitControl、CallbackHook
|
||||
测试RiskControl、CallbackHook抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
|
||||
from framework.risk import (
|
||||
RiskControl, StopLossControl, PositionLimitControl, PremiumControl,
|
||||
CallbackHook, Position, Trade,
|
||||
premium_filter_callback, crash_filter_callback, holding_time_stoploss_callback
|
||||
)
|
||||
from framework.risk import RiskControl, CallbackHook, Position
|
||||
from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl
|
||||
|
||||
|
||||
class TestPosition:
|
||||
"""测试持仓信息"""
|
||||
"""测试Position数据结构"""
|
||||
|
||||
def test_position_profit(self):
|
||||
"""测试盈亏计算"""
|
||||
def test_position_creation(self):
|
||||
"""测试持仓创建"""
|
||||
position = Position(
|
||||
code='code1',
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now(),
|
||||
quantity=10
|
||||
)
|
||||
|
||||
assert position.code == 'AAPL'
|
||||
assert position.entry_price == 100.0
|
||||
assert position.current_price == 105.0
|
||||
|
||||
def test_profit_ratio(self):
|
||||
"""测试盈亏比例计算"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=110.0,
|
||||
current_date=datetime(2020, 1, 10),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
assert position.profit_ratio == 0.10
|
||||
assert position.is_profit == True
|
||||
assert position.holding_days == 9
|
||||
|
||||
def test_position_loss(self):
|
||||
"""测试亏损计算"""
|
||||
def test_profit_amount(self):
|
||||
"""测试盈亏金额计算"""
|
||||
position = Position(
|
||||
code='code1',
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now(),
|
||||
quantity=10
|
||||
)
|
||||
|
||||
assert position.profit_ratio == -0.05
|
||||
assert position.is_profit == False
|
||||
assert position.holding_days == 4
|
||||
assert position.profit_amount == 100.0
|
||||
|
||||
def test_position_repr(self):
|
||||
"""测试持仓字符串表示"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
repr_str = repr(position)
|
||||
assert 'AAPL' in repr_str
|
||||
assert '5.00%' in repr_str
|
||||
|
||||
|
||||
class TestStopLossControl:
|
||||
"""测试止损控制"""
|
||||
|
||||
def test_fixed_stoploss_check(self):
|
||||
"""测试固定止损检查"""
|
||||
def test_stop_loss_init(self):
|
||||
"""测试初始化"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
assert control.threshold == -0.05
|
||||
assert control.name == "stop_loss"
|
||||
|
||||
def test_stop_loss_check(self):
|
||||
"""测试止损检查"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
|
||||
# 未触发止损
|
||||
position = Position(
|
||||
code='code1',
|
||||
# 盈利持仓应该通过
|
||||
profit_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=96.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
assert control.check(profit_position) == True
|
||||
|
||||
# 亏损持仓应该触发止损
|
||||
loss_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=90.0, # 亏损10%
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
# 亏损超过阈值应该触发
|
||||
assert control.check(loss_position) == False
|
||||
|
||||
def test_trailing_stop_loss(self):
|
||||
"""测试跟踪止损"""
|
||||
control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03)
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
# 初始最高价=入场价100,当前价105
|
||||
# 更新后最高价=105
|
||||
control.check(position)
|
||||
|
||||
# 从最高价回撤不超过阈值应该通过
|
||||
position.current_price = 103.0 # 回撤约2%
|
||||
assert control.check(position) == True
|
||||
|
||||
# 触发止损
|
||||
position.current_price = 94.0
|
||||
# 回撤超过阈值应该触发
|
||||
position.current_price = 100.0 # 回撤约5%
|
||||
assert control.check(position) == False
|
||||
|
||||
def test_fixed_stoploss_apply(self):
|
||||
"""测试固定止损应用"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
|
||||
stop_price = control.apply(position)
|
||||
assert stop_price == 95.0 # 100 * (1 - 0.05)
|
||||
|
||||
def test_trailing_stoploss(self):
|
||||
"""测试跟踪止损"""
|
||||
control = StopLossControl(trailing=True, trailing_percent=0.03)
|
||||
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=105.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
|
||||
# 最高价更新为105
|
||||
control.check(position)
|
||||
assert control._highest_price['code1'] == 105.0
|
||||
|
||||
# 当前价回撤到101(从105回撤4%),超过3%阈值
|
||||
position.current_price = 101.0
|
||||
assert control.check(position) == False
|
||||
|
||||
# 止损价格应为 105 * (1 - 0.03) = 101.85
|
||||
stop_price = control.apply(position)
|
||||
assert abs(stop_price - 101.85) < 0.01
|
||||
|
||||
|
||||
class TestPositionLimitControl:
|
||||
"""测试仓位限制控制"""
|
||||
|
||||
def test_position_limit_init(self):
|
||||
"""测试初始化"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
assert control.max_position == 0.33
|
||||
assert control.name == "position_limit"
|
||||
|
||||
def test_position_limit_check(self):
|
||||
"""测试仓位限制检查"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
|
||||
# 仓位未超限
|
||||
position = Position(
|
||||
code='code1',
|
||||
# 正常仓位应该通过
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=105.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.30
|
||||
entry_time=datetime.now(),
|
||||
weight=0.20
|
||||
)
|
||||
assert control.check(position) == True
|
||||
assert control.check(normal_position) == True
|
||||
|
||||
# 仓位超限
|
||||
position.weight = 0.40
|
||||
assert control.check(position) == False
|
||||
|
||||
def test_position_limit_apply(self):
|
||||
"""测试仓位限制应用"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
position = Position(
|
||||
code='code1',
|
||||
# 超限仓位应该触发
|
||||
over_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=105.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
entry_time=datetime.now(),
|
||||
weight=0.50
|
||||
)
|
||||
|
||||
suggested_weight = control.apply(position)
|
||||
assert suggested_weight == 0.33
|
||||
assert control.check(over_position) == False
|
||||
|
||||
|
||||
class TestPremiumControl:
|
||||
"""测试溢价控制"""
|
||||
|
||||
def test_premium_filter(self):
|
||||
"""测试溢价过滤"""
|
||||
def test_premium_control_init(self):
|
||||
"""测试初始化"""
|
||||
control = PremiumControl(threshold=0.10)
|
||||
assert control.threshold == 0.10
|
||||
assert control.name == "premium"
|
||||
|
||||
def test_premium_filter_mode(self):
|
||||
"""测试溢价过滤模式"""
|
||||
control = PremiumControl(threshold=0.10, mode='filter')
|
||||
|
||||
# 溢价未超限
|
||||
assert control.check(None, premium=0.05) == True
|
||||
|
||||
# 溢价超限
|
||||
assert control.check(None, premium=0.15) == False
|
||||
|
||||
def test_premium_penalize(self):
|
||||
"""测试溢价降权"""
|
||||
control = PremiumControl(threshold=0.10, mode='penalize')
|
||||
|
||||
# 降权模式下允许通过
|
||||
assert control.check(None, premium=0.15) == True
|
||||
|
||||
# 返回降权系数
|
||||
position = Position(
|
||||
code='code1',
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=105.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
penalty = control.apply(position)
|
||||
assert penalty == 0.5
|
||||
|
||||
# 正常溢价应该通过
|
||||
assert control.check(position, premium=0.05) == True
|
||||
|
||||
# 高溢价应该被过滤
|
||||
assert control.check(position, premium=0.15) == False
|
||||
|
||||
|
||||
class TestCallbackHook:
|
||||
"""测试回调钩子"""
|
||||
|
||||
def test_register_hook(self):
|
||||
def test_callback_hook_init(self):
|
||||
"""测试初始化"""
|
||||
hook = CallbackHook()
|
||||
assert len(hook.list_hooks()) == 6
|
||||
|
||||
def test_register_callback(self):
|
||||
"""测试注册回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def dummy_callback(code, price):
|
||||
def my_callback(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
hook.register('before_entry', dummy_callback)
|
||||
assert len(hook._hooks['before_entry']) == 1
|
||||
hook.register('before_entry', my_callback)
|
||||
|
||||
callbacks = hook.get_callbacks('before_entry')
|
||||
assert len(callbacks) == 1
|
||||
|
||||
def test_trigger_before_entry(self):
|
||||
"""测试触发入场前回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册溢价过滤回调
|
||||
hook.register('before_entry', premium_filter_callback(threshold=0.10))
|
||||
def always_pass(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
# 溢价正常,允许入场
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
def always_block(code, price, **kwargs):
|
||||
return False
|
||||
|
||||
# 溢价过高,拒绝入场
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
|
||||
hook.register('before_entry', always_pass)
|
||||
hook.register('before_entry', always_block)
|
||||
|
||||
# before_entry要求所有回调返回True才允许
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0)
|
||||
assert result == False
|
||||
|
||||
def test_trigger_dynamic_stoploss(self):
|
||||
"""测试触发动态止损回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册持仓时间止损回调
|
||||
hook.register('dynamic_stoploss', holding_time_stoploss_callback())
|
||||
def stoploss_5p(position):
|
||||
return -0.05
|
||||
|
||||
def stoploss_3p(position):
|
||||
return -0.03
|
||||
|
||||
hook.register('dynamic_stoploss', stoploss_5p)
|
||||
hook.register('dynamic_stoploss', stoploss_3p)
|
||||
|
||||
# 持仓5天,止损-5%
|
||||
position = Position(
|
||||
code='code1',
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 6),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
stoploss = hook.trigger('dynamic_stoploss', position)
|
||||
# holding_days=5,返回-0.05
|
||||
assert stoploss == -0.05
|
||||
|
||||
# dynamic_stoploss返回最小止损值(最严格)
|
||||
result = hook.trigger('dynamic_stoploss', position)
|
||||
assert result == -0.05
|
||||
|
||||
def test_trigger_custom_exit(self):
|
||||
"""测试触发自定义出场回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def reversal_exit_callback(position):
|
||||
# 反转信号触发出场
|
||||
return position.profit_ratio < -0.02
|
||||
def exit_on_loss(position):
|
||||
return position.profit_ratio < -0.05
|
||||
|
||||
hook.register('custom_exit', reversal_exit_callback)
|
||||
def exit_on_profit(position):
|
||||
return position.profit_ratio > 0.20
|
||||
|
||||
# 未触发出场
|
||||
position = Position(
|
||||
code='code1',
|
||||
hook.register('custom_exit', exit_on_loss)
|
||||
hook.register('custom_exit', exit_on_profit)
|
||||
|
||||
# custom_exit任一回调触发即可
|
||||
profit_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=99.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
current_price=125.0, # 盈利25%
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', position)
|
||||
assert result == False
|
||||
|
||||
# 触发出场
|
||||
position.current_price = 97.0
|
||||
result = hook.trigger('custom_exit', position)
|
||||
result = hook.trigger('custom_exit', profit_position)
|
||||
assert result == True
|
||||
|
||||
# 未触发
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', normal_position)
|
||||
assert result == False
|
||||
|
||||
def test_multiple_callbacks(self):
|
||||
"""测试多个回调组合"""
|
||||
def test_clear_hooks(self):
|
||||
"""测试清空回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册多个入场前回调
|
||||
hook.register('before_entry', premium_filter_callback(0.10))
|
||||
hook.register('before_entry', lambda code, price, **kwargs: price > 50)
|
||||
def callback(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
# 溢价正常 + 价格>50,允许入场
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||
hook.register('before_entry', callback)
|
||||
|
||||
# 清空特定钩子
|
||||
hook.clear('before_entry')
|
||||
assert len(hook.get_callbacks('before_entry')) == 0
|
||||
|
||||
# 清空所有钩子
|
||||
hook.register('before_entry', callback)
|
||||
hook.register('after_entry', callback)
|
||||
hook.clear()
|
||||
assert len(hook.get_callbacks('before_entry')) == 0
|
||||
assert len(hook.get_callbacks('after_entry')) == 0
|
||||
|
||||
def test_default_behavior(self):
|
||||
"""测试默认行为"""
|
||||
hook = CallbackHook() # 无注册回调
|
||||
|
||||
# before_entry默认允许
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0)
|
||||
assert result == True
|
||||
|
||||
# 溢价过高,拒绝入场(任一回调返回False)
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
# dynamic_stoploss默认值
|
||||
result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05)
|
||||
assert result == -0.05
|
||||
|
||||
# 价格过低,拒绝入场
|
||||
result = hook.trigger('before_entry', 'code1', 40.0, premium=0.05)
|
||||
# custom_exit默认不出场
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', position)
|
||||
assert result == False
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user