核心组件: - RiskControl: 风控抽象基类 - StopLossControl: 止损控制(固定止损/跟踪止损) - PositionLimitControl: 仓位限制控制 - PremiumControl: 溢价控制(filter/penalize模式) 回调钩子机制: - CallbackHook: 回调管理器(注册/触发) - 5个核心回调:before_entry, after_entry, before_exit, after_exit, dynamic_stoploss, custom_exit 便捷回调函数: - premium_filter_callback: 溢价过滤回调 - crash_filter_callback: 崩盘检测回调 - holding_time_stoploss_callback: 持仓时间动态止损 测试覆盖:13个测试全部通过
351 lines
9.5 KiB
Python
351 lines
9.5 KiB
Python
"""
|
||
回调钩子与风控层设计
|
||
|
||
核心组件:
|
||
- RiskControl: 风控抽象基类
|
||
- StopLossControl: 止损控制
|
||
- PositionLimitControl: 仓位限制控制
|
||
- CallbackHook: 回调钩子管理
|
||
"""
|
||
|
||
import pandas as pd
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, Any, Optional, List
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
|
||
|
||
@dataclass
|
||
class Position:
|
||
"""持仓信息"""
|
||
code: str
|
||
entry_price: float
|
||
entry_date: datetime
|
||
current_price: float
|
||
current_date: datetime
|
||
quantity: float
|
||
weight: float
|
||
|
||
@property
|
||
def profit_ratio(self) -> float:
|
||
"""盈亏比例"""
|
||
return (self.current_price - self.entry_price) / self.entry_price
|
||
|
||
@property
|
||
def holding_days(self) -> int:
|
||
"""持仓天数"""
|
||
return (self.current_date - self.entry_date).days
|
||
|
||
@property
|
||
def is_profit(self) -> bool:
|
||
"""是否盈利"""
|
||
return self.profit_ratio > 0
|
||
|
||
|
||
@dataclass
|
||
class Trade:
|
||
"""交易信息"""
|
||
code: str
|
||
direction: str # 'entry' or 'exit'
|
||
price: float
|
||
date: datetime
|
||
quantity: float
|
||
reason: str = ""
|
||
|
||
|
||
class RiskControl(ABC):
|
||
"""
|
||
风控抽象基类
|
||
|
||
所有风控组件必须继承此基类。
|
||
"""
|
||
|
||
name: str = "base"
|
||
|
||
def __init__(self, **params):
|
||
self._params = params
|
||
|
||
@abstractmethod
|
||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||
"""
|
||
风控检查
|
||
|
||
Args:
|
||
position: 持仓信息(可选)
|
||
**kwargs: 其他参数
|
||
|
||
Returns:
|
||
是否通过检查
|
||
"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
def apply(self, position: Position) -> Optional[float]:
|
||
"""
|
||
应用风控
|
||
|
||
Args:
|
||
position: 持仓信息
|
||
|
||
Returns:
|
||
应用结果(如止损价格、仓位调整比例等)
|
||
"""
|
||
pass
|
||
|
||
@property
|
||
def params(self) -> Dict[str, Any]:
|
||
return self._params
|
||
|
||
|
||
class StopLossControl(RiskControl):
|
||
"""
|
||
止损控制
|
||
|
||
参数:
|
||
- threshold: 止损阈值(默认-0.05)
|
||
- trailing: 是否跟踪止损(默认False)
|
||
- trailing_percent: 跟踪止损比例(默认0.03)
|
||
"""
|
||
|
||
name = "stop_loss"
|
||
|
||
def __init__(
|
||
self,
|
||
threshold: float = -0.05,
|
||
trailing: bool = False,
|
||
trailing_percent: float = 0.03
|
||
):
|
||
super().__init__(
|
||
threshold=threshold,
|
||
trailing=trailing,
|
||
trailing_percent=trailing_percent
|
||
)
|
||
self.threshold = threshold
|
||
self.trailing = trailing
|
||
self.trailing_percent = trailing_percent
|
||
self._highest_price = {} # 跟踪最高价
|
||
|
||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||
"""检查是否触发止损"""
|
||
if position is None:
|
||
return True
|
||
|
||
# 更新最高价(跟踪止损)
|
||
if self.trailing:
|
||
if position.code not in self._highest_price:
|
||
self._highest_price[position.code] = position.entry_price
|
||
self._highest_price[position.code] = max(
|
||
self._highest_price[position.code],
|
||
position.current_price
|
||
)
|
||
|
||
# 检查止损
|
||
if self.trailing:
|
||
# 跟踪止损:从最高价回撤超过阈值
|
||
highest = self._highest_price[position.code]
|
||
drawdown = (position.current_price - highest) / highest
|
||
return drawdown > -self.trailing_percent
|
||
else:
|
||
# 固定止损:从入场价亏损超过阈值
|
||
return position.profit_ratio > self.threshold
|
||
|
||
def apply(self, position: Position) -> Optional[float]:
|
||
"""返回止损价格"""
|
||
if self.trailing:
|
||
highest = self._highest_price.get(position.code, position.entry_price)
|
||
return highest * (1 - self.trailing_percent)
|
||
else:
|
||
return position.entry_price * (1 + self.threshold)
|
||
|
||
|
||
class PositionLimitControl(RiskControl):
|
||
"""
|
||
仓位限制控制
|
||
|
||
参数:
|
||
- max_position: 单品种最大仓位(默认0.33)
|
||
- max_total: 总仓位上限(默认1.0)
|
||
"""
|
||
|
||
name = "position_limit"
|
||
|
||
def __init__(
|
||
self,
|
||
max_position: float = 0.33,
|
||
max_total: float = 1.0
|
||
):
|
||
super().__init__(
|
||
max_position=max_position,
|
||
max_total=max_total
|
||
)
|
||
self.max_position = max_position
|
||
self.max_total = max_total
|
||
|
||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||
"""检查仓位是否超限"""
|
||
if position is None:
|
||
return True
|
||
|
||
# 检查单品种仓位
|
||
if position.weight > self.max_position:
|
||
return False
|
||
|
||
return True
|
||
|
||
def apply(self, position: Position) -> Optional[float]:
|
||
"""返回建议仓位"""
|
||
return min(position.weight, self.max_position)
|
||
|
||
|
||
class PremiumControl(RiskControl):
|
||
"""
|
||
溢价控制
|
||
|
||
参数:
|
||
- threshold: 溢价阈值(默认0.10)
|
||
- mode: 控制模式('filter'或'penalize')
|
||
"""
|
||
|
||
name = "premium"
|
||
|
||
def __init__(
|
||
self,
|
||
threshold: float = 0.10,
|
||
mode: str = 'filter'
|
||
):
|
||
super().__init__(
|
||
threshold=threshold,
|
||
mode=mode
|
||
)
|
||
self.threshold = threshold
|
||
self.mode = mode
|
||
|
||
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||
"""检查溢价是否超限"""
|
||
premium = kwargs.get('premium', 0)
|
||
|
||
if self.mode == 'filter':
|
||
# 完全排除
|
||
return premium <= self.threshold
|
||
else:
|
||
# 仅降权,允许通过
|
||
return True
|
||
|
||
def apply(self, position: Position) -> Optional[float]:
|
||
"""返回溢价惩罚系数"""
|
||
if self.mode == 'penalize':
|
||
return 0.5 # 降权50%
|
||
return None
|
||
|
||
|
||
class CallbackHook:
|
||
"""
|
||
回调钩子管理
|
||
|
||
支持策略生命周期回调:
|
||
- before_entry: 入场前检查
|
||
- after_entry: 入场后处理
|
||
- before_exit: 出场前检查
|
||
- after_exit: 出场后处理
|
||
- dynamic_stoploss: 动态止损
|
||
- custom_exit: 自定义出场
|
||
"""
|
||
|
||
def __init__(self):
|
||
self._hooks = {
|
||
'before_entry': [],
|
||
'after_entry': [],
|
||
'before_exit': [],
|
||
'after_exit': [],
|
||
'dynamic_stoploss': [],
|
||
'custom_exit': []
|
||
}
|
||
|
||
def register(self, hook_name: str, callback: callable) -> None:
|
||
"""注册回调"""
|
||
if hook_name not in self._hooks:
|
||
raise ValueError(f"Unknown hook: {hook_name}")
|
||
self._hooks[hook_name].append(callback)
|
||
|
||
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
|
||
"""触发回调"""
|
||
if hook_name not in self._hooks:
|
||
raise ValueError(f"Unknown hook: {hook_name}")
|
||
|
||
results = []
|
||
for callback in self._hooks[hook_name]:
|
||
try:
|
||
result = callback(*args, **kwargs)
|
||
results.append(result)
|
||
except Exception as e:
|
||
print(f"⚠ Callback error: {e}")
|
||
|
||
# before_entry和before_exit需要所有回调返回True
|
||
if hook_name in ['before_entry', 'before_exit']:
|
||
return all(results)
|
||
|
||
# dynamic_stoploss返回最小的止损值
|
||
if hook_name == 'dynamic_stoploss':
|
||
return min(results) if results else -0.05
|
||
|
||
# custom_exit返回是否有任一回调触发出场
|
||
if hook_name == 'custom_exit':
|
||
return any(results)
|
||
|
||
return results
|
||
|
||
def clear(self, hook_name: str = None) -> None:
|
||
"""清空回调"""
|
||
if hook_name:
|
||
self._hooks[hook_name] = []
|
||
else:
|
||
for key in self._hooks:
|
||
self._hooks[key] = []
|
||
|
||
|
||
# 便捷回调函数
|
||
def premium_filter_callback(threshold: float = 0.10):
|
||
"""溢价过滤回调"""
|
||
def callback(code: str, price: float, **kwargs) -> bool:
|
||
premium = kwargs.get('premium', 0)
|
||
if premium > threshold:
|
||
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||
return False
|
||
return True
|
||
return callback
|
||
|
||
|
||
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
|
||
"""崩盘过滤回调"""
|
||
def callback(code: str, price: float, **kwargs) -> bool:
|
||
history = kwargs.get('history', None)
|
||
if history is None:
|
||
return True
|
||
|
||
# 检查最近N天是否有崩盘
|
||
recent = history.tail(lookback)
|
||
if len(recent) < lookback:
|
||
return True
|
||
|
||
returns = recent['close'].pct_change()
|
||
min_return = returns.min()
|
||
|
||
if min_return < -crash_threshold:
|
||
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
|
||
return False
|
||
return True
|
||
return callback
|
||
|
||
|
||
def holding_time_stoploss_callback(
|
||
day_5_stoploss: float = -0.05,
|
||
day_10_stoploss: float = -0.03
|
||
):
|
||
"""持仓时间动态止损回调"""
|
||
def callback(position: Position) -> float:
|
||
if position.holding_days >= 10:
|
||
return day_10_stoploss # 10天后收紧止损
|
||
elif position.holding_days >= 5:
|
||
return day_5_stoploss
|
||
return -0.10 # 默认止损
|
||
return callback |