test: 更新测试以验证框架重构正确性

- 测试文件改用strategies.shared的具体实现
- 新增framework_comparison_test.py对比新旧实现结果
- 因子计算相关系数达到1.0000,差异为0.000000
- 79个单元测试全部通过
This commit is contained in:
2026-05-11 23:10:02 +08:00
parent de31271ab3
commit fc59836ec3
7 changed files with 1066 additions and 618 deletions

View File

@@ -1,291 +1,323 @@
"""
风控层测试
测试RiskControl、StopLossControl、PositionLimitControl、CallbackHook
测试RiskControl、CallbackHook抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime, timedelta
from datetime import datetime
from framework.risk import (
RiskControl, StopLossControl, PositionLimitControl, PremiumControl,
CallbackHook, Position, Trade,
premium_filter_callback, crash_filter_callback, holding_time_stoploss_callback
)
from framework.risk import RiskControl, CallbackHook, Position
from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl
class TestPosition:
"""测试持仓信息"""
"""测试Position数据结构"""
def test_position_profit(self):
"""测试盈亏计算"""
def test_position_creation(self):
"""测试持仓创建"""
position = Position(
code='code1',
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now(),
quantity=10
)
assert position.code == 'AAPL'
assert position.entry_price == 100.0
assert position.current_price == 105.0
def test_profit_ratio(self):
"""测试盈亏比例计算"""
position = Position(
code='AAPL',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=110.0,
current_date=datetime(2020, 1, 10),
quantity=100,
weight=0.33
entry_time=datetime.now()
)
assert position.profit_ratio == 0.10
assert position.is_profit == True
assert position.holding_days == 9
def test_position_loss(self):
"""测试亏损计算"""
def test_profit_amount(self):
"""测试盈亏金额计算"""
position = Position(
code='code1',
code='AAPL',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=95.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
current_price=110.0,
entry_time=datetime.now(),
quantity=10
)
assert position.profit_ratio == -0.05
assert position.is_profit == False
assert position.holding_days == 4
assert position.profit_amount == 100.0
def test_position_repr(self):
"""测试持仓字符串表示"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
repr_str = repr(position)
assert 'AAPL' in repr_str
assert '5.00%' in repr_str
class TestStopLossControl:
"""测试止损控制"""
def test_fixed_stoploss_check(self):
"""测试固定止损检查"""
def test_stop_loss_init(self):
"""测试初始化"""
control = StopLossControl(threshold=-0.05)
assert control.threshold == -0.05
assert control.name == "stop_loss"
def test_stop_loss_check(self):
"""测试止损检查"""
control = StopLossControl(threshold=-0.05)
# 未触发止损
position = Position(
code='code1',
# 盈利持仓应该通过
profit_position = Position(
code='AAPL',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=96.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
current_price=110.0,
entry_time=datetime.now()
)
assert control.check(profit_position) == True
# 亏损持仓应该触发止损
loss_position = Position(
code='AAPL',
entry_price=100.0,
current_price=90.0, # 亏损10%
entry_time=datetime.now()
)
# 亏损超过阈值应该触发
assert control.check(loss_position) == False
def test_trailing_stop_loss(self):
"""测试跟踪止损"""
control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03)
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
# 初始最高价=入场价100当前价105
# 更新后最高价=105
control.check(position)
# 从最高价回撤不超过阈值应该通过
position.current_price = 103.0 # 回撤约2%
assert control.check(position) == True
# 触发止损
position.current_price = 94.0
# 回撤超过阈值应该触发
position.current_price = 100.0 # 回撤约5%
assert control.check(position) == False
def test_fixed_stoploss_apply(self):
"""测试固定止损应用"""
control = StopLossControl(threshold=-0.05)
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=95.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
)
stop_price = control.apply(position)
assert stop_price == 95.0 # 100 * (1 - 0.05)
def test_trailing_stoploss(self):
"""测试跟踪止损"""
control = StopLossControl(trailing=True, trailing_percent=0.03)
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=105.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
)
# 最高价更新为105
control.check(position)
assert control._highest_price['code1'] == 105.0
# 当前价回撤到101从105回撤4%超过3%阈值
position.current_price = 101.0
assert control.check(position) == False
# 止损价格应为 105 * (1 - 0.03) = 101.85
stop_price = control.apply(position)
assert abs(stop_price - 101.85) < 0.01
class TestPositionLimitControl:
"""测试仓位限制控制"""
def test_position_limit_init(self):
"""测试初始化"""
control = PositionLimitControl(max_position=0.33)
assert control.max_position == 0.33
assert control.name == "position_limit"
def test_position_limit_check(self):
"""测试仓位限制检查"""
control = PositionLimitControl(max_position=0.33)
# 仓位未超限
position = Position(
code='code1',
# 正常仓位应该通过
normal_position = Position(
code='AAPL',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=105.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.30
entry_time=datetime.now(),
weight=0.20
)
assert control.check(position) == True
assert control.check(normal_position) == True
# 仓位超限
position.weight = 0.40
assert control.check(position) == False
def test_position_limit_apply(self):
"""测试仓位限制应用"""
control = PositionLimitControl(max_position=0.33)
position = Position(
code='code1',
# 超限仓位应该触发
over_position = Position(
code='AAPL',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=105.0,
current_date=datetime(2020, 1, 5),
quantity=100,
entry_time=datetime.now(),
weight=0.50
)
suggested_weight = control.apply(position)
assert suggested_weight == 0.33
assert control.check(over_position) == False
class TestPremiumControl:
"""测试溢价控制"""
def test_premium_filter(self):
"""测试溢价过滤"""
def test_premium_control_init(self):
"""测试初始化"""
control = PremiumControl(threshold=0.10)
assert control.threshold == 0.10
assert control.name == "premium"
def test_premium_filter_mode(self):
"""测试溢价过滤模式"""
control = PremiumControl(threshold=0.10, mode='filter')
# 溢价未超限
assert control.check(None, premium=0.05) == True
# 溢价超限
assert control.check(None, premium=0.15) == False
def test_premium_penalize(self):
"""测试溢价降权"""
control = PremiumControl(threshold=0.10, mode='penalize')
# 降权模式下允许通过
assert control.check(None, premium=0.15) == True
# 返回降权系数
position = Position(
code='code1',
code='AAPL',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=105.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
entry_time=datetime.now()
)
penalty = control.apply(position)
assert penalty == 0.5
# 正常溢价应该通过
assert control.check(position, premium=0.05) == True
# 高溢价应该被过滤
assert control.check(position, premium=0.15) == False
class TestCallbackHook:
"""测试回调钩子"""
def test_register_hook(self):
def test_callback_hook_init(self):
"""测试初始化"""
hook = CallbackHook()
assert len(hook.list_hooks()) == 6
def test_register_callback(self):
"""测试注册回调"""
hook = CallbackHook()
def dummy_callback(code, price):
def my_callback(code, price, **kwargs):
return True
hook.register('before_entry', dummy_callback)
assert len(hook._hooks['before_entry']) == 1
hook.register('before_entry', my_callback)
callbacks = hook.get_callbacks('before_entry')
assert len(callbacks) == 1
def test_trigger_before_entry(self):
"""测试触发入场前回调"""
hook = CallbackHook()
# 注册溢价过滤回调
hook.register('before_entry', premium_filter_callback(threshold=0.10))
def always_pass(code, price, **kwargs):
return True
# 溢价正常,允许入场
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
assert result == True
def always_block(code, price, **kwargs):
return False
# 溢价过高,拒绝入场
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
hook.register('before_entry', always_pass)
hook.register('before_entry', always_block)
# before_entry要求所有回调返回True才允许
result = hook.trigger('before_entry', 'AAPL', 100.0)
assert result == False
def test_trigger_dynamic_stoploss(self):
"""测试触发动态止损回调"""
hook = CallbackHook()
# 注册持仓时间止损回调
hook.register('dynamic_stoploss', holding_time_stoploss_callback())
def stoploss_5p(position):
return -0.05
def stoploss_3p(position):
return -0.03
hook.register('dynamic_stoploss', stoploss_5p)
hook.register('dynamic_stoploss', stoploss_3p)
# 持仓5天止损-5%
position = Position(
code='code1',
code='AAPL',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=95.0,
current_date=datetime(2020, 1, 6),
quantity=100,
weight=0.33
current_price=105.0,
entry_time=datetime.now()
)
stoploss = hook.trigger('dynamic_stoploss', position)
# holding_days=5返回-0.05
assert stoploss == -0.05
# dynamic_stoploss返回最小止损值最严格
result = hook.trigger('dynamic_stoploss', position)
assert result == -0.05
def test_trigger_custom_exit(self):
"""测试触发自定义出场回调"""
hook = CallbackHook()
def reversal_exit_callback(position):
# 反转信号触发出场
return position.profit_ratio < -0.02
def exit_on_loss(position):
return position.profit_ratio < -0.05
hook.register('custom_exit', reversal_exit_callback)
def exit_on_profit(position):
return position.profit_ratio > 0.20
# 未触发出场
position = Position(
code='code1',
hook.register('custom_exit', exit_on_loss)
hook.register('custom_exit', exit_on_profit)
# custom_exit任一回调触发即可
profit_position = Position(
code='AAPL',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=99.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
current_price=125.0, # 盈利25%
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', position)
assert result == False
# 触发出场
position.current_price = 97.0
result = hook.trigger('custom_exit', position)
result = hook.trigger('custom_exit', profit_position)
assert result == True
# 未触发
normal_position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', normal_position)
assert result == False
def test_multiple_callbacks(self):
"""测试多个回调组合"""
def test_clear_hooks(self):
"""测试清空回调"""
hook = CallbackHook()
# 注册多个入场前回调
hook.register('before_entry', premium_filter_callback(0.10))
hook.register('before_entry', lambda code, price, **kwargs: price > 50)
def callback(code, price, **kwargs):
return True
# 溢价正常 + 价格>50允许入场
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
hook.register('before_entry', callback)
# 清空特定钩子
hook.clear('before_entry')
assert len(hook.get_callbacks('before_entry')) == 0
# 清空所有钩子
hook.register('before_entry', callback)
hook.register('after_entry', callback)
hook.clear()
assert len(hook.get_callbacks('before_entry')) == 0
assert len(hook.get_callbacks('after_entry')) == 0
def test_default_behavior(self):
"""测试默认行为"""
hook = CallbackHook() # 无注册回调
# before_entry默认允许
result = hook.trigger('before_entry', 'AAPL', 100.0)
assert result == True
# 溢价过高拒绝入场任一回调返回False
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
assert result == False
# dynamic_stoploss默认值
result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05)
assert result == -0.05
# 价格过低,拒绝入
result = hook.trigger('before_entry', 'code1', 40.0, premium=0.05)
# custom_exit默认不出
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', position)
assert result == False