feat(risk): 实现风控层与回调钩子机制(融合Freqtrade设计)

核心组件:
- RiskControl: 风控抽象基类
- StopLossControl: 止损控制(固定止损/跟踪止损)
- PositionLimitControl: 仓位限制控制
- PremiumControl: 溢价控制(filter/penalize模式)

回调钩子机制:
- CallbackHook: 回调管理器(注册/触发)
- 5个核心回调:before_entry, after_entry, before_exit, after_exit, dynamic_stoploss, custom_exit

便捷回调函数:
- premium_filter_callback: 溢价过滤回调
- crash_filter_callback: 崩盘检测回调
- holding_time_stoploss_callback: 持仓时间动态止损

测试覆盖:13个测试全部通过
This commit is contained in:
2026-05-11 22:18:41 +08:00
parent f5e6202eee
commit 512b73ac04
2 changed files with 644 additions and 0 deletions

351
framework/risk/__init__.py Normal file
View File

@@ -0,0 +1,351 @@
"""
回调钩子与风控层设计
核心组件:
- RiskControl: 风控抽象基类
- StopLossControl: 止损控制
- PositionLimitControl: 仓位限制控制
- CallbackHook: 回调钩子管理
"""
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from datetime import datetime
@dataclass
class Position:
"""持仓信息"""
code: str
entry_price: float
entry_date: datetime
current_price: float
current_date: datetime
quantity: float
weight: float
@property
def profit_ratio(self) -> float:
"""盈亏比例"""
return (self.current_price - self.entry_price) / self.entry_price
@property
def holding_days(self) -> int:
"""持仓天数"""
return (self.current_date - self.entry_date).days
@property
def is_profit(self) -> bool:
"""是否盈利"""
return self.profit_ratio > 0
@dataclass
class Trade:
"""交易信息"""
code: str
direction: str # 'entry' or 'exit'
price: float
date: datetime
quantity: float
reason: str = ""
class RiskControl(ABC):
"""
风控抽象基类
所有风控组件必须继承此基类。
"""
name: str = "base"
def __init__(self, **params):
self._params = params
@abstractmethod
def check(self, position: Optional[Position], **kwargs) -> bool:
"""
风控检查
Args:
position: 持仓信息(可选)
**kwargs: 其他参数
Returns:
是否通过检查
"""
pass
@abstractmethod
def apply(self, position: Position) -> Optional[float]:
"""
应用风控
Args:
position: 持仓信息
Returns:
应用结果(如止损价格、仓位调整比例等)
"""
pass
@property
def params(self) -> Dict[str, Any]:
return self._params
class StopLossControl(RiskControl):
"""
止损控制
参数:
- threshold: 止损阈值(默认-0.05
- trailing: 是否跟踪止损默认False
- trailing_percent: 跟踪止损比例默认0.03
"""
name = "stop_loss"
def __init__(
self,
threshold: float = -0.05,
trailing: bool = False,
trailing_percent: float = 0.03
):
super().__init__(
threshold=threshold,
trailing=trailing,
trailing_percent=trailing_percent
)
self.threshold = threshold
self.trailing = trailing
self.trailing_percent = trailing_percent
self._highest_price = {} # 跟踪最高价
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查是否触发止损"""
if position is None:
return True
# 更新最高价(跟踪止损)
if self.trailing:
if position.code not in self._highest_price:
self._highest_price[position.code] = position.entry_price
self._highest_price[position.code] = max(
self._highest_price[position.code],
position.current_price
)
# 检查止损
if self.trailing:
# 跟踪止损:从最高价回撤超过阈值
highest = self._highest_price[position.code]
drawdown = (position.current_price - highest) / highest
return drawdown > -self.trailing_percent
else:
# 固定止损:从入场价亏损超过阈值
return position.profit_ratio > self.threshold
def apply(self, position: Position) -> Optional[float]:
"""返回止损价格"""
if self.trailing:
highest = self._highest_price.get(position.code, position.entry_price)
return highest * (1 - self.trailing_percent)
else:
return position.entry_price * (1 + self.threshold)
class PositionLimitControl(RiskControl):
"""
仓位限制控制
参数:
- max_position: 单品种最大仓位默认0.33
- max_total: 总仓位上限默认1.0
"""
name = "position_limit"
def __init__(
self,
max_position: float = 0.33,
max_total: float = 1.0
):
super().__init__(
max_position=max_position,
max_total=max_total
)
self.max_position = max_position
self.max_total = max_total
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查仓位是否超限"""
if position is None:
return True
# 检查单品种仓位
if position.weight > self.max_position:
return False
return True
def apply(self, position: Position) -> Optional[float]:
"""返回建议仓位"""
return min(position.weight, self.max_position)
class PremiumControl(RiskControl):
"""
溢价控制
参数:
- threshold: 溢价阈值默认0.10
- mode: 控制模式('filter''penalize'
"""
name = "premium"
def __init__(
self,
threshold: float = 0.10,
mode: str = 'filter'
):
super().__init__(
threshold=threshold,
mode=mode
)
self.threshold = threshold
self.mode = mode
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查溢价是否超限"""
premium = kwargs.get('premium', 0)
if self.mode == 'filter':
# 完全排除
return premium <= self.threshold
else:
# 仅降权,允许通过
return True
def apply(self, position: Position) -> Optional[float]:
"""返回溢价惩罚系数"""
if self.mode == 'penalize':
return 0.5 # 降权50%
return None
class CallbackHook:
"""
回调钩子管理
支持策略生命周期回调:
- before_entry: 入场前检查
- after_entry: 入场后处理
- before_exit: 出场前检查
- after_exit: 出场后处理
- dynamic_stoploss: 动态止损
- custom_exit: 自定义出场
"""
def __init__(self):
self._hooks = {
'before_entry': [],
'after_entry': [],
'before_exit': [],
'after_exit': [],
'dynamic_stoploss': [],
'custom_exit': []
}
def register(self, hook_name: str, callback: callable) -> None:
"""注册回调"""
if hook_name not in self._hooks:
raise ValueError(f"Unknown hook: {hook_name}")
self._hooks[hook_name].append(callback)
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
"""触发回调"""
if hook_name not in self._hooks:
raise ValueError(f"Unknown hook: {hook_name}")
results = []
for callback in self._hooks[hook_name]:
try:
result = callback(*args, **kwargs)
results.append(result)
except Exception as e:
print(f"⚠ Callback error: {e}")
# before_entry和before_exit需要所有回调返回True
if hook_name in ['before_entry', 'before_exit']:
return all(results)
# dynamic_stoploss返回最小的止损值
if hook_name == 'dynamic_stoploss':
return min(results) if results else -0.05
# custom_exit返回是否有任一回调触发出场
if hook_name == 'custom_exit':
return any(results)
return results
def clear(self, hook_name: str = None) -> None:
"""清空回调"""
if hook_name:
self._hooks[hook_name] = []
else:
for key in self._hooks:
self._hooks[key] = []
# 便捷回调函数
def premium_filter_callback(threshold: float = 0.10):
"""溢价过滤回调"""
def callback(code: str, price: float, **kwargs) -> bool:
premium = kwargs.get('premium', 0)
if premium > threshold:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
return True
return callback
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
"""崩盘过滤回调"""
def callback(code: str, price: float, **kwargs) -> bool:
history = kwargs.get('history', None)
if history is None:
return True
# 检查最近N天是否有崩盘
recent = history.tail(lookback)
if len(recent) < lookback:
return True
returns = recent['close'].pct_change()
min_return = returns.min()
if min_return < -crash_threshold:
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
return False
return True
return callback
def holding_time_stoploss_callback(
day_5_stoploss: float = -0.05,
day_10_stoploss: float = -0.03
):
"""持仓时间动态止损回调"""
def callback(position: Position) -> float:
if position.holding_days >= 10:
return day_10_stoploss # 10天后收紧止损
elif position.holding_days >= 5:
return day_5_stoploss
return -0.10 # 默认止损
return callback

View File

@@ -0,0 +1,293 @@
"""
风控层测试
测试RiskControl、StopLossControl、PositionLimitControl、CallbackHook
"""
import pandas as pd
import pytest
from datetime import datetime, timedelta
from framework.risk import (
RiskControl, StopLossControl, PositionLimitControl, PremiumControl,
CallbackHook, Position, Trade,
premium_filter_callback, crash_filter_callback, holding_time_stoploss_callback
)
class TestPosition:
"""测试持仓信息"""
def test_position_profit(self):
"""测试盈亏计算"""
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=110.0,
current_date=datetime(2020, 1, 10),
quantity=100,
weight=0.33
)
assert position.profit_ratio == 0.10
assert position.is_profit == True
assert position.holding_days == 9
def test_position_loss(self):
"""测试亏损计算"""
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=95.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
)
assert position.profit_ratio == -0.05
assert position.is_profit == False
assert position.holding_days == 4
class TestStopLossControl:
"""测试止损控制"""
def test_fixed_stoploss_check(self):
"""测试固定止损检查"""
control = StopLossControl(threshold=-0.05)
# 未触发止损
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=96.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
)
assert control.check(position) == True
# 触发止损
position.current_price = 94.0
assert control.check(position) == False
def test_fixed_stoploss_apply(self):
"""测试固定止损应用"""
control = StopLossControl(threshold=-0.05)
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=95.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
)
stop_price = control.apply(position)
assert stop_price == 95.0 # 100 * (1 - 0.05)
def test_trailing_stoploss(self):
"""测试跟踪止损"""
control = StopLossControl(trailing=True, trailing_percent=0.03)
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=105.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
)
# 最高价更新为105
control.check(position)
assert control._highest_price['code1'] == 105.0
# 当前价回撤到101从105回撤4%超过3%阈值
position.current_price = 101.0
assert control.check(position) == False
# 止损价格应为 105 * (1 - 0.03) = 101.85
stop_price = control.apply(position)
assert abs(stop_price - 101.85) < 0.01
class TestPositionLimitControl:
"""测试仓位限制控制"""
def test_position_limit_check(self):
"""测试仓位限制检查"""
control = PositionLimitControl(max_position=0.33)
# 仓位未超限
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=105.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.30
)
assert control.check(position) == True
# 仓位超限
position.weight = 0.40
assert control.check(position) == False
def test_position_limit_apply(self):
"""测试仓位限制应用"""
control = PositionLimitControl(max_position=0.33)
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=105.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.50
)
suggested_weight = control.apply(position)
assert suggested_weight == 0.33
class TestPremiumControl:
"""测试溢价控制"""
def test_premium_filter(self):
"""测试溢价过滤"""
control = PremiumControl(threshold=0.10, mode='filter')
# 溢价未超限
assert control.check(None, premium=0.05) == True
# 溢价超限
assert control.check(None, premium=0.15) == False
def test_premium_penalize(self):
"""测试溢价降权"""
control = PremiumControl(threshold=0.10, mode='penalize')
# 降权模式下允许通过
assert control.check(None, premium=0.15) == True
# 返回降权系数
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=105.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
)
penalty = control.apply(position)
assert penalty == 0.5
class TestCallbackHook:
"""测试回调钩子"""
def test_register_hook(self):
"""测试注册回调"""
hook = CallbackHook()
def dummy_callback(code, price):
return True
hook.register('before_entry', dummy_callback)
assert len(hook._hooks['before_entry']) == 1
def test_trigger_before_entry(self):
"""测试触发入场前回调"""
hook = CallbackHook()
# 注册溢价过滤回调
hook.register('before_entry', premium_filter_callback(threshold=0.10))
# 溢价正常,允许入场
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
assert result == True
# 溢价过高,拒绝入场
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
assert result == False
def test_trigger_dynamic_stoploss(self):
"""测试触发动态止损回调"""
hook = CallbackHook()
# 注册持仓时间止损回调
hook.register('dynamic_stoploss', holding_time_stoploss_callback())
# 持仓5天止损-5%
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=95.0,
current_date=datetime(2020, 1, 6),
quantity=100,
weight=0.33
)
stoploss = hook.trigger('dynamic_stoploss', position)
# holding_days=5返回-0.05
assert stoploss == -0.05
def test_trigger_custom_exit(self):
"""测试触发自定义出场回调"""
hook = CallbackHook()
def reversal_exit_callback(position):
# 反转信号触发出场
return position.profit_ratio < -0.02
hook.register('custom_exit', reversal_exit_callback)
# 未触发出场
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=99.0,
current_date=datetime(2020, 1, 5),
quantity=100,
weight=0.33
)
result = hook.trigger('custom_exit', position)
assert result == False
# 触发出场
position.current_price = 97.0
result = hook.trigger('custom_exit', position)
assert result == True
def test_multiple_callbacks(self):
"""测试多个回调组合"""
hook = CallbackHook()
# 注册多个入场前回调
hook.register('before_entry', premium_filter_callback(0.10))
hook.register('before_entry', lambda code, price, **kwargs: price > 50)
# 溢价正常 + 价格>50允许入场
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
assert result == True
# 溢价过高拒绝入场任一回调返回False
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
assert result == False
# 价格过低,拒绝入场
result = hook.trigger('before_entry', 'code1', 40.0, premium=0.05)
assert result == False
if __name__ == '__main__':
pytest.main([__file__, '-v'])