feat(risk): 实现风控层与回调钩子机制(融合Freqtrade设计)
核心组件: - RiskControl: 风控抽象基类 - StopLossControl: 止损控制(固定止损/跟踪止损) - PositionLimitControl: 仓位限制控制 - PremiumControl: 溢价控制(filter/penalize模式) 回调钩子机制: - CallbackHook: 回调管理器(注册/触发) - 5个核心回调:before_entry, after_entry, before_exit, after_exit, dynamic_stoploss, custom_exit 便捷回调函数: - premium_filter_callback: 溢价过滤回调 - crash_filter_callback: 崩盘检测回调 - holding_time_stoploss_callback: 持仓时间动态止损 测试覆盖:13个测试全部通过
This commit is contained in:
351
framework/risk/__init__.py
Normal file
351
framework/risk/__init__.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
回调钩子与风控层设计
|
||||
|
||||
核心组件:
|
||||
- RiskControl: 风控抽象基类
|
||||
- StopLossControl: 止损控制
|
||||
- PositionLimitControl: 仓位限制控制
|
||||
- CallbackHook: 回调钩子管理
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class Position:
|
||||
"""持仓信息"""
|
||||
code: str
|
||||
entry_price: float
|
||||
entry_date: datetime
|
||||
current_price: float
|
||||
current_date: datetime
|
||||
quantity: float
|
||||
weight: float
|
||||
|
||||
@property
|
||||
def profit_ratio(self) -> float:
|
||||
"""盈亏比例"""
|
||||
return (self.current_price - self.entry_price) / self.entry_price
|
||||
|
||||
@property
|
||||
def holding_days(self) -> int:
|
||||
"""持仓天数"""
|
||||
return (self.current_date - self.entry_date).days
|
||||
|
||||
@property
|
||||
def is_profit(self) -> bool:
|
||||
"""是否盈利"""
|
||||
return self.profit_ratio > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class Trade:
|
||||
"""交易信息"""
|
||||
code: str
|
||||
direction: str # 'entry' or 'exit'
|
||||
price: float
|
||||
date: datetime
|
||||
quantity: float
|
||||
reason: str = ""
|
||||
|
||||
|
||||
class RiskControl(ABC):
|
||||
"""
|
||||
风控抽象基类
|
||||
|
||||
所有风控组件必须继承此基类。
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||
"""
|
||||
风控检查
|
||||
|
||||
Args:
|
||||
position: 持仓信息(可选)
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
是否通过检查
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, position: Position) -> Optional[float]:
|
||||
"""
|
||||
应用风控
|
||||
|
||||
Args:
|
||||
position: 持仓信息
|
||||
|
||||
Returns:
|
||||
应用结果(如止损价格、仓位调整比例等)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def params(self) -> Dict[str, Any]:
|
||||
return self._params
|
||||
|
||||
|
||||
class StopLossControl(RiskControl):
|
||||
"""
|
||||
止损控制
|
||||
|
||||
参数:
|
||||
- threshold: 止损阈值(默认-0.05)
|
||||
- trailing: 是否跟踪止损(默认False)
|
||||
- trailing_percent: 跟踪止损比例(默认0.03)
|
||||
"""
|
||||
|
||||
name = "stop_loss"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
threshold: float = -0.05,
|
||||
trailing: bool = False,
|
||||
trailing_percent: float = 0.03
|
||||
):
|
||||
super().__init__(
|
||||
threshold=threshold,
|
||||
trailing=trailing,
|
||||
trailing_percent=trailing_percent
|
||||
)
|
||||
self.threshold = threshold
|
||||
self.trailing = trailing
|
||||
self.trailing_percent = trailing_percent
|
||||
self._highest_price = {} # 跟踪最高价
|
||||
|
||||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||
"""检查是否触发止损"""
|
||||
if position is None:
|
||||
return True
|
||||
|
||||
# 更新最高价(跟踪止损)
|
||||
if self.trailing:
|
||||
if position.code not in self._highest_price:
|
||||
self._highest_price[position.code] = position.entry_price
|
||||
self._highest_price[position.code] = max(
|
||||
self._highest_price[position.code],
|
||||
position.current_price
|
||||
)
|
||||
|
||||
# 检查止损
|
||||
if self.trailing:
|
||||
# 跟踪止损:从最高价回撤超过阈值
|
||||
highest = self._highest_price[position.code]
|
||||
drawdown = (position.current_price - highest) / highest
|
||||
return drawdown > -self.trailing_percent
|
||||
else:
|
||||
# 固定止损:从入场价亏损超过阈值
|
||||
return position.profit_ratio > self.threshold
|
||||
|
||||
def apply(self, position: Position) -> Optional[float]:
|
||||
"""返回止损价格"""
|
||||
if self.trailing:
|
||||
highest = self._highest_price.get(position.code, position.entry_price)
|
||||
return highest * (1 - self.trailing_percent)
|
||||
else:
|
||||
return position.entry_price * (1 + self.threshold)
|
||||
|
||||
|
||||
class PositionLimitControl(RiskControl):
|
||||
"""
|
||||
仓位限制控制
|
||||
|
||||
参数:
|
||||
- max_position: 单品种最大仓位(默认0.33)
|
||||
- max_total: 总仓位上限(默认1.0)
|
||||
"""
|
||||
|
||||
name = "position_limit"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_position: float = 0.33,
|
||||
max_total: float = 1.0
|
||||
):
|
||||
super().__init__(
|
||||
max_position=max_position,
|
||||
max_total=max_total
|
||||
)
|
||||
self.max_position = max_position
|
||||
self.max_total = max_total
|
||||
|
||||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||
"""检查仓位是否超限"""
|
||||
if position is None:
|
||||
return True
|
||||
|
||||
# 检查单品种仓位
|
||||
if position.weight > self.max_position:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def apply(self, position: Position) -> Optional[float]:
|
||||
"""返回建议仓位"""
|
||||
return min(position.weight, self.max_position)
|
||||
|
||||
|
||||
class PremiumControl(RiskControl):
|
||||
"""
|
||||
溢价控制
|
||||
|
||||
参数:
|
||||
- threshold: 溢价阈值(默认0.10)
|
||||
- mode: 控制模式('filter'或'penalize')
|
||||
"""
|
||||
|
||||
name = "premium"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
threshold: float = 0.10,
|
||||
mode: str = 'filter'
|
||||
):
|
||||
super().__init__(
|
||||
threshold=threshold,
|
||||
mode=mode
|
||||
)
|
||||
self.threshold = threshold
|
||||
self.mode = mode
|
||||
|
||||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||
"""检查溢价是否超限"""
|
||||
premium = kwargs.get('premium', 0)
|
||||
|
||||
if self.mode == 'filter':
|
||||
# 完全排除
|
||||
return premium <= self.threshold
|
||||
else:
|
||||
# 仅降权,允许通过
|
||||
return True
|
||||
|
||||
def apply(self, position: Position) -> Optional[float]:
|
||||
"""返回溢价惩罚系数"""
|
||||
if self.mode == 'penalize':
|
||||
return 0.5 # 降权50%
|
||||
return None
|
||||
|
||||
|
||||
class CallbackHook:
|
||||
"""
|
||||
回调钩子管理
|
||||
|
||||
支持策略生命周期回调:
|
||||
- before_entry: 入场前检查
|
||||
- after_entry: 入场后处理
|
||||
- before_exit: 出场前检查
|
||||
- after_exit: 出场后处理
|
||||
- dynamic_stoploss: 动态止损
|
||||
- custom_exit: 自定义出场
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks = {
|
||||
'before_entry': [],
|
||||
'after_entry': [],
|
||||
'before_exit': [],
|
||||
'after_exit': [],
|
||||
'dynamic_stoploss': [],
|
||||
'custom_exit': []
|
||||
}
|
||||
|
||||
def register(self, hook_name: str, callback: callable) -> None:
|
||||
"""注册回调"""
|
||||
if hook_name not in self._hooks:
|
||||
raise ValueError(f"Unknown hook: {hook_name}")
|
||||
self._hooks[hook_name].append(callback)
|
||||
|
||||
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
|
||||
"""触发回调"""
|
||||
if hook_name not in self._hooks:
|
||||
raise ValueError(f"Unknown hook: {hook_name}")
|
||||
|
||||
results = []
|
||||
for callback in self._hooks[hook_name]:
|
||||
try:
|
||||
result = callback(*args, **kwargs)
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
print(f"⚠ Callback error: {e}")
|
||||
|
||||
# before_entry和before_exit需要所有回调返回True
|
||||
if hook_name in ['before_entry', 'before_exit']:
|
||||
return all(results)
|
||||
|
||||
# dynamic_stoploss返回最小的止损值
|
||||
if hook_name == 'dynamic_stoploss':
|
||||
return min(results) if results else -0.05
|
||||
|
||||
# custom_exit返回是否有任一回调触发出场
|
||||
if hook_name == 'custom_exit':
|
||||
return any(results)
|
||||
|
||||
return results
|
||||
|
||||
def clear(self, hook_name: str = None) -> None:
|
||||
"""清空回调"""
|
||||
if hook_name:
|
||||
self._hooks[hook_name] = []
|
||||
else:
|
||||
for key in self._hooks:
|
||||
self._hooks[key] = []
|
||||
|
||||
|
||||
# 便捷回调函数
|
||||
def premium_filter_callback(threshold: float = 0.10):
|
||||
"""溢价过滤回调"""
|
||||
def callback(code: str, price: float, **kwargs) -> bool:
|
||||
premium = kwargs.get('premium', 0)
|
||||
if premium > threshold:
|
||||
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||
return False
|
||||
return True
|
||||
return callback
|
||||
|
||||
|
||||
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
|
||||
"""崩盘过滤回调"""
|
||||
def callback(code: str, price: float, **kwargs) -> bool:
|
||||
history = kwargs.get('history', None)
|
||||
if history is None:
|
||||
return True
|
||||
|
||||
# 检查最近N天是否有崩盘
|
||||
recent = history.tail(lookback)
|
||||
if len(recent) < lookback:
|
||||
return True
|
||||
|
||||
returns = recent['close'].pct_change()
|
||||
min_return = returns.min()
|
||||
|
||||
if min_return < -crash_threshold:
|
||||
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
|
||||
return False
|
||||
return True
|
||||
return callback
|
||||
|
||||
|
||||
def holding_time_stoploss_callback(
|
||||
day_5_stoploss: float = -0.05,
|
||||
day_10_stoploss: float = -0.03
|
||||
):
|
||||
"""持仓时间动态止损回调"""
|
||||
def callback(position: Position) -> float:
|
||||
if position.holding_days >= 10:
|
||||
return day_10_stoploss # 10天后收紧止损
|
||||
elif position.holding_days >= 5:
|
||||
return day_5_stoploss
|
||||
return -0.10 # 默认止损
|
||||
return callback
|
||||
293
framework/tests/test_risk.py
Normal file
293
framework/tests/test_risk.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
风控层测试
|
||||
|
||||
测试RiskControl、StopLossControl、PositionLimitControl、CallbackHook
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from framework.risk import (
|
||||
RiskControl, StopLossControl, PositionLimitControl, PremiumControl,
|
||||
CallbackHook, Position, Trade,
|
||||
premium_filter_callback, crash_filter_callback, holding_time_stoploss_callback
|
||||
)
|
||||
|
||||
|
||||
class TestPosition:
|
||||
"""测试持仓信息"""
|
||||
|
||||
def test_position_profit(self):
|
||||
"""测试盈亏计算"""
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=110.0,
|
||||
current_date=datetime(2020, 1, 10),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
|
||||
assert position.profit_ratio == 0.10
|
||||
assert position.is_profit == True
|
||||
assert position.holding_days == 9
|
||||
|
||||
def test_position_loss(self):
|
||||
"""测试亏损计算"""
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
|
||||
assert position.profit_ratio == -0.05
|
||||
assert position.is_profit == False
|
||||
assert position.holding_days == 4
|
||||
|
||||
|
||||
class TestStopLossControl:
|
||||
"""测试止损控制"""
|
||||
|
||||
def test_fixed_stoploss_check(self):
|
||||
"""测试固定止损检查"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
|
||||
# 未触发止损
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=96.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
assert control.check(position) == True
|
||||
|
||||
# 触发止损
|
||||
position.current_price = 94.0
|
||||
assert control.check(position) == False
|
||||
|
||||
def test_fixed_stoploss_apply(self):
|
||||
"""测试固定止损应用"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
|
||||
stop_price = control.apply(position)
|
||||
assert stop_price == 95.0 # 100 * (1 - 0.05)
|
||||
|
||||
def test_trailing_stoploss(self):
|
||||
"""测试跟踪止损"""
|
||||
control = StopLossControl(trailing=True, trailing_percent=0.03)
|
||||
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=105.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
|
||||
# 最高价更新为105
|
||||
control.check(position)
|
||||
assert control._highest_price['code1'] == 105.0
|
||||
|
||||
# 当前价回撤到101(从105回撤4%),超过3%阈值
|
||||
position.current_price = 101.0
|
||||
assert control.check(position) == False
|
||||
|
||||
# 止损价格应为 105 * (1 - 0.03) = 101.85
|
||||
stop_price = control.apply(position)
|
||||
assert abs(stop_price - 101.85) < 0.01
|
||||
|
||||
|
||||
class TestPositionLimitControl:
|
||||
"""测试仓位限制控制"""
|
||||
|
||||
def test_position_limit_check(self):
|
||||
"""测试仓位限制检查"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
|
||||
# 仓位未超限
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=105.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.30
|
||||
)
|
||||
assert control.check(position) == True
|
||||
|
||||
# 仓位超限
|
||||
position.weight = 0.40
|
||||
assert control.check(position) == False
|
||||
|
||||
def test_position_limit_apply(self):
|
||||
"""测试仓位限制应用"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=105.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.50
|
||||
)
|
||||
|
||||
suggested_weight = control.apply(position)
|
||||
assert suggested_weight == 0.33
|
||||
|
||||
|
||||
class TestPremiumControl:
|
||||
"""测试溢价控制"""
|
||||
|
||||
def test_premium_filter(self):
|
||||
"""测试溢价过滤"""
|
||||
control = PremiumControl(threshold=0.10, mode='filter')
|
||||
|
||||
# 溢价未超限
|
||||
assert control.check(None, premium=0.05) == True
|
||||
|
||||
# 溢价超限
|
||||
assert control.check(None, premium=0.15) == False
|
||||
|
||||
def test_premium_penalize(self):
|
||||
"""测试溢价降权"""
|
||||
control = PremiumControl(threshold=0.10, mode='penalize')
|
||||
|
||||
# 降权模式下允许通过
|
||||
assert control.check(None, premium=0.15) == True
|
||||
|
||||
# 返回降权系数
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=105.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
penalty = control.apply(position)
|
||||
assert penalty == 0.5
|
||||
|
||||
|
||||
class TestCallbackHook:
|
||||
"""测试回调钩子"""
|
||||
|
||||
def test_register_hook(self):
|
||||
"""测试注册回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def dummy_callback(code, price):
|
||||
return True
|
||||
|
||||
hook.register('before_entry', dummy_callback)
|
||||
assert len(hook._hooks['before_entry']) == 1
|
||||
|
||||
def test_trigger_before_entry(self):
|
||||
"""测试触发入场前回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册溢价过滤回调
|
||||
hook.register('before_entry', premium_filter_callback(threshold=0.10))
|
||||
|
||||
# 溢价正常,允许入场
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 溢价过高,拒绝入场
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
def test_trigger_dynamic_stoploss(self):
|
||||
"""测试触发动态止损回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册持仓时间止损回调
|
||||
hook.register('dynamic_stoploss', holding_time_stoploss_callback())
|
||||
|
||||
# 持仓5天,止损-5%
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 6),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
stoploss = hook.trigger('dynamic_stoploss', position)
|
||||
# holding_days=5,返回-0.05
|
||||
assert stoploss == -0.05
|
||||
|
||||
def test_trigger_custom_exit(self):
|
||||
"""测试触发自定义出场回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def reversal_exit_callback(position):
|
||||
# 反转信号触发出场
|
||||
return position.profit_ratio < -0.02
|
||||
|
||||
hook.register('custom_exit', reversal_exit_callback)
|
||||
|
||||
# 未触发出场
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=99.0,
|
||||
current_date=datetime(2020, 1, 5),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
result = hook.trigger('custom_exit', position)
|
||||
assert result == False
|
||||
|
||||
# 触发出场
|
||||
position.current_price = 97.0
|
||||
result = hook.trigger('custom_exit', position)
|
||||
assert result == True
|
||||
|
||||
def test_multiple_callbacks(self):
|
||||
"""测试多个回调组合"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册多个入场前回调
|
||||
hook.register('before_entry', premium_filter_callback(0.10))
|
||||
hook.register('before_entry', lambda code, price, **kwargs: price > 50)
|
||||
|
||||
# 溢价正常 + 价格>50,允许入场
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 溢价过高,拒绝入场(任一回调返回False)
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
# 价格过低,拒绝入场
|
||||
result = hook.trigger('before_entry', 'code1', 40.0, premium=0.05)
|
||||
assert result == False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user