Files
etf/framework/risk/__init__.py
aszerW 512b73ac04 feat(risk): 实现风控层与回调钩子机制(融合Freqtrade设计)
核心组件:
- RiskControl: 风控抽象基类
- StopLossControl: 止损控制(固定止损/跟踪止损)
- PositionLimitControl: 仓位限制控制
- PremiumControl: 溢价控制(filter/penalize模式)

回调钩子机制:
- CallbackHook: 回调管理器(注册/触发)
- 5个核心回调:before_entry, after_entry, before_exit, after_exit, dynamic_stoploss, custom_exit

便捷回调函数:
- premium_filter_callback: 溢价过滤回调
- crash_filter_callback: 崩盘检测回调
- holding_time_stoploss_callback: 持仓时间动态止损

测试覆盖:13个测试全部通过
2026-05-11 22:18:41 +08:00

351 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
回调钩子与风控层设计
核心组件:
- RiskControl: 风控抽象基类
- StopLossControl: 止损控制
- PositionLimitControl: 仓位限制控制
- CallbackHook: 回调钩子管理
"""
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from dataclasses import dataclass
from datetime import datetime
@dataclass
class Position:
"""持仓信息"""
code: str
entry_price: float
entry_date: datetime
current_price: float
current_date: datetime
quantity: float
weight: float
@property
def profit_ratio(self) -> float:
"""盈亏比例"""
return (self.current_price - self.entry_price) / self.entry_price
@property
def holding_days(self) -> int:
"""持仓天数"""
return (self.current_date - self.entry_date).days
@property
def is_profit(self) -> bool:
"""是否盈利"""
return self.profit_ratio > 0
@dataclass
class Trade:
"""交易信息"""
code: str
direction: str # 'entry' or 'exit'
price: float
date: datetime
quantity: float
reason: str = ""
class RiskControl(ABC):
"""
风控抽象基类
所有风控组件必须继承此基类。
"""
name: str = "base"
def __init__(self, **params):
self._params = params
@abstractmethod
def check(self, position: Optional[Position], **kwargs) -> bool:
"""
风控检查
Args:
position: 持仓信息(可选)
**kwargs: 其他参数
Returns:
是否通过检查
"""
pass
@abstractmethod
def apply(self, position: Position) -> Optional[float]:
"""
应用风控
Args:
position: 持仓信息
Returns:
应用结果(如止损价格、仓位调整比例等)
"""
pass
@property
def params(self) -> Dict[str, Any]:
return self._params
class StopLossControl(RiskControl):
"""
止损控制
参数:
- threshold: 止损阈值(默认-0.05
- trailing: 是否跟踪止损默认False
- trailing_percent: 跟踪止损比例默认0.03
"""
name = "stop_loss"
def __init__(
self,
threshold: float = -0.05,
trailing: bool = False,
trailing_percent: float = 0.03
):
super().__init__(
threshold=threshold,
trailing=trailing,
trailing_percent=trailing_percent
)
self.threshold = threshold
self.trailing = trailing
self.trailing_percent = trailing_percent
self._highest_price = {} # 跟踪最高价
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查是否触发止损"""
if position is None:
return True
# 更新最高价(跟踪止损)
if self.trailing:
if position.code not in self._highest_price:
self._highest_price[position.code] = position.entry_price
self._highest_price[position.code] = max(
self._highest_price[position.code],
position.current_price
)
# 检查止损
if self.trailing:
# 跟踪止损:从最高价回撤超过阈值
highest = self._highest_price[position.code]
drawdown = (position.current_price - highest) / highest
return drawdown > -self.trailing_percent
else:
# 固定止损:从入场价亏损超过阈值
return position.profit_ratio > self.threshold
def apply(self, position: Position) -> Optional[float]:
"""返回止损价格"""
if self.trailing:
highest = self._highest_price.get(position.code, position.entry_price)
return highest * (1 - self.trailing_percent)
else:
return position.entry_price * (1 + self.threshold)
class PositionLimitControl(RiskControl):
"""
仓位限制控制
参数:
- max_position: 单品种最大仓位默认0.33
- max_total: 总仓位上限默认1.0
"""
name = "position_limit"
def __init__(
self,
max_position: float = 0.33,
max_total: float = 1.0
):
super().__init__(
max_position=max_position,
max_total=max_total
)
self.max_position = max_position
self.max_total = max_total
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查仓位是否超限"""
if position is None:
return True
# 检查单品种仓位
if position.weight > self.max_position:
return False
return True
def apply(self, position: Position) -> Optional[float]:
"""返回建议仓位"""
return min(position.weight, self.max_position)
class PremiumControl(RiskControl):
"""
溢价控制
参数:
- threshold: 溢价阈值默认0.10
- mode: 控制模式('filter''penalize'
"""
name = "premium"
def __init__(
self,
threshold: float = 0.10,
mode: str = 'filter'
):
super().__init__(
threshold=threshold,
mode=mode
)
self.threshold = threshold
self.mode = mode
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查溢价是否超限"""
premium = kwargs.get('premium', 0)
if self.mode == 'filter':
# 完全排除
return premium <= self.threshold
else:
# 仅降权,允许通过
return True
def apply(self, position: Position) -> Optional[float]:
"""返回溢价惩罚系数"""
if self.mode == 'penalize':
return 0.5 # 降权50%
return None
class CallbackHook:
"""
回调钩子管理
支持策略生命周期回调:
- before_entry: 入场前检查
- after_entry: 入场后处理
- before_exit: 出场前检查
- after_exit: 出场后处理
- dynamic_stoploss: 动态止损
- custom_exit: 自定义出场
"""
def __init__(self):
self._hooks = {
'before_entry': [],
'after_entry': [],
'before_exit': [],
'after_exit': [],
'dynamic_stoploss': [],
'custom_exit': []
}
def register(self, hook_name: str, callback: callable) -> None:
"""注册回调"""
if hook_name not in self._hooks:
raise ValueError(f"Unknown hook: {hook_name}")
self._hooks[hook_name].append(callback)
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
"""触发回调"""
if hook_name not in self._hooks:
raise ValueError(f"Unknown hook: {hook_name}")
results = []
for callback in self._hooks[hook_name]:
try:
result = callback(*args, **kwargs)
results.append(result)
except Exception as e:
print(f"⚠ Callback error: {e}")
# before_entry和before_exit需要所有回调返回True
if hook_name in ['before_entry', 'before_exit']:
return all(results)
# dynamic_stoploss返回最小的止损值
if hook_name == 'dynamic_stoploss':
return min(results) if results else -0.05
# custom_exit返回是否有任一回调触发出场
if hook_name == 'custom_exit':
return any(results)
return results
def clear(self, hook_name: str = None) -> None:
"""清空回调"""
if hook_name:
self._hooks[hook_name] = []
else:
for key in self._hooks:
self._hooks[key] = []
# 便捷回调函数
def premium_filter_callback(threshold: float = 0.10):
"""溢价过滤回调"""
def callback(code: str, price: float, **kwargs) -> bool:
premium = kwargs.get('premium', 0)
if premium > threshold:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
return True
return callback
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
"""崩盘过滤回调"""
def callback(code: str, price: float, **kwargs) -> bool:
history = kwargs.get('history', None)
if history is None:
return True
# 检查最近N天是否有崩盘
recent = history.tail(lookback)
if len(recent) < lookback:
return True
returns = recent['close'].pct_change()
min_return = returns.min()
if min_return < -crash_threshold:
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
return False
return True
return callback
def holding_time_stoploss_callback(
day_5_stoploss: float = -0.05,
day_10_stoploss: float = -0.03
):
"""持仓时间动态止损回调"""
def callback(position: Position) -> float:
if position.holding_days >= 10:
return day_10_stoploss # 10天后收紧止损
elif position.holding_days >= 5:
return day_5_stoploss
return -0.10 # 默认止损
return callback