""" 回调钩子与风控层设计 核心组件: - RiskControl: 风控抽象基类 - StopLossControl: 止损控制 - PositionLimitControl: 仓位限制控制 - CallbackHook: 回调钩子管理 """ import pandas as pd from abc import ABC, abstractmethod from typing import Dict, Any, Optional, List from dataclasses import dataclass from datetime import datetime @dataclass class Position: """持仓信息""" code: str entry_price: float entry_date: datetime current_price: float current_date: datetime quantity: float weight: float @property def profit_ratio(self) -> float: """盈亏比例""" return (self.current_price - self.entry_price) / self.entry_price @property def holding_days(self) -> int: """持仓天数""" return (self.current_date - self.entry_date).days @property def is_profit(self) -> bool: """是否盈利""" return self.profit_ratio > 0 @dataclass class Trade: """交易信息""" code: str direction: str # 'entry' or 'exit' price: float date: datetime quantity: float reason: str = "" class RiskControl(ABC): """ 风控抽象基类 所有风控组件必须继承此基类。 """ name: str = "base" def __init__(self, **params): self._params = params @abstractmethod def check(self, position: Optional[Position], **kwargs) -> bool: """ 风控检查 Args: position: 持仓信息(可选) **kwargs: 其他参数 Returns: 是否通过检查 """ pass @abstractmethod def apply(self, position: Position) -> Optional[float]: """ 应用风控 Args: position: 持仓信息 Returns: 应用结果(如止损价格、仓位调整比例等) """ pass @property def params(self) -> Dict[str, Any]: return self._params class StopLossControl(RiskControl): """ 止损控制 参数: - threshold: 止损阈值(默认-0.05) - trailing: 是否跟踪止损(默认False) - trailing_percent: 跟踪止损比例(默认0.03) """ name = "stop_loss" def __init__( self, threshold: float = -0.05, trailing: bool = False, trailing_percent: float = 0.03 ): super().__init__( threshold=threshold, trailing=trailing, trailing_percent=trailing_percent ) self.threshold = threshold self.trailing = trailing self.trailing_percent = trailing_percent self._highest_price = {} # 跟踪最高价 def check(self, position: Optional[Position], **kwargs) -> bool: """检查是否触发止损""" if position is None: return True # 更新最高价(跟踪止损) if self.trailing: if position.code not in self._highest_price: self._highest_price[position.code] = position.entry_price self._highest_price[position.code] = max( self._highest_price[position.code], position.current_price ) # 检查止损 if self.trailing: # 跟踪止损:从最高价回撤超过阈值 highest = self._highest_price[position.code] drawdown = (position.current_price - highest) / highest return drawdown > -self.trailing_percent else: # 固定止损:从入场价亏损超过阈值 return position.profit_ratio > self.threshold def apply(self, position: Position) -> Optional[float]: """返回止损价格""" if self.trailing: highest = self._highest_price.get(position.code, position.entry_price) return highest * (1 - self.trailing_percent) else: return position.entry_price * (1 + self.threshold) class PositionLimitControl(RiskControl): """ 仓位限制控制 参数: - max_position: 单品种最大仓位(默认0.33) - max_total: 总仓位上限(默认1.0) """ name = "position_limit" def __init__( self, max_position: float = 0.33, max_total: float = 1.0 ): super().__init__( max_position=max_position, max_total=max_total ) self.max_position = max_position self.max_total = max_total def check(self, position: Optional[Position], **kwargs) -> bool: """检查仓位是否超限""" if position is None: return True # 检查单品种仓位 if position.weight > self.max_position: return False return True def apply(self, position: Position) -> Optional[float]: """返回建议仓位""" return min(position.weight, self.max_position) class PremiumControl(RiskControl): """ 溢价控制 参数: - threshold: 溢价阈值(默认0.10) - mode: 控制模式('filter'或'penalize') """ name = "premium" def __init__( self, threshold: float = 0.10, mode: str = 'filter' ): super().__init__( threshold=threshold, mode=mode ) self.threshold = threshold self.mode = mode def check(self, position: Optional[Position], **kwargs) -> bool: """检查溢价是否超限""" premium = kwargs.get('premium', 0) if self.mode == 'filter': # 完全排除 return premium <= self.threshold else: # 仅降权,允许通过 return True def apply(self, position: Position) -> Optional[float]: """返回溢价惩罚系数""" if self.mode == 'penalize': return 0.5 # 降权50% return None class CallbackHook: """ 回调钩子管理 支持策略生命周期回调: - before_entry: 入场前检查 - after_entry: 入场后处理 - before_exit: 出场前检查 - after_exit: 出场后处理 - dynamic_stoploss: 动态止损 - custom_exit: 自定义出场 """ def __init__(self): self._hooks = { 'before_entry': [], 'after_entry': [], 'before_exit': [], 'after_exit': [], 'dynamic_stoploss': [], 'custom_exit': [] } def register(self, hook_name: str, callback: callable) -> None: """注册回调""" if hook_name not in self._hooks: raise ValueError(f"Unknown hook: {hook_name}") self._hooks[hook_name].append(callback) def trigger(self, hook_name: str, *args, **kwargs) -> Any: """触发回调""" if hook_name not in self._hooks: raise ValueError(f"Unknown hook: {hook_name}") results = [] for callback in self._hooks[hook_name]: try: result = callback(*args, **kwargs) results.append(result) except Exception as e: print(f"⚠ Callback error: {e}") # before_entry和before_exit需要所有回调返回True if hook_name in ['before_entry', 'before_exit']: return all(results) # dynamic_stoploss返回最小的止损值 if hook_name == 'dynamic_stoploss': return min(results) if results else -0.05 # custom_exit返回是否有任一回调触发出场 if hook_name == 'custom_exit': return any(results) return results def clear(self, hook_name: str = None) -> None: """清空回调""" if hook_name: self._hooks[hook_name] = [] else: for key in self._hooks: self._hooks[key] = [] # 便捷回调函数 def premium_filter_callback(threshold: float = 0.10): """溢价过滤回调""" def callback(code: str, price: float, **kwargs) -> bool: premium = kwargs.get('premium', 0) if premium > threshold: print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})") return False return True return callback def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05): """崩盘过滤回调""" def callback(code: str, price: float, **kwargs) -> bool: history = kwargs.get('history', None) if history is None: return True # 检查最近N天是否有崩盘 recent = history.tail(lookback) if len(recent) < lookback: return True returns = recent['close'].pct_change() min_return = returns.min() if min_return < -crash_threshold: print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})") return False return True return callback def holding_time_stoploss_callback( day_5_stoploss: float = -0.05, day_10_stoploss: float = -0.03 ): """持仓时间动态止损回调""" def callback(position: Position) -> float: if position.holding_days >= 10: return day_10_stoploss # 10天后收紧止损 elif position.holding_days >= 5: return day_5_stoploss return -0.10 # 默认止损 return callback