diff --git a/framework/risk/__init__.py b/framework/risk/__init__.py new file mode 100644 index 0000000..1596769 --- /dev/null +++ b/framework/risk/__init__.py @@ -0,0 +1,351 @@ +""" +回调钩子与风控层设计 + +核心组件: +- RiskControl: 风控抽象基类 +- StopLossControl: 止损控制 +- PositionLimitControl: 仓位限制控制 +- CallbackHook: 回调钩子管理 +""" + +import pandas as pd +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List +from dataclasses import dataclass +from datetime import datetime + + +@dataclass +class Position: + """持仓信息""" + code: str + entry_price: float + entry_date: datetime + current_price: float + current_date: datetime + quantity: float + weight: float + + @property + def profit_ratio(self) -> float: + """盈亏比例""" + return (self.current_price - self.entry_price) / self.entry_price + + @property + def holding_days(self) -> int: + """持仓天数""" + return (self.current_date - self.entry_date).days + + @property + def is_profit(self) -> bool: + """是否盈利""" + return self.profit_ratio > 0 + + +@dataclass +class Trade: + """交易信息""" + code: str + direction: str # 'entry' or 'exit' + price: float + date: datetime + quantity: float + reason: str = "" + + +class RiskControl(ABC): + """ + 风控抽象基类 + + 所有风控组件必须继承此基类。 + """ + + name: str = "base" + + def __init__(self, **params): + self._params = params + + @abstractmethod + def check(self, position: Optional[Position], **kwargs) -> bool: + """ + 风控检查 + + Args: + position: 持仓信息(可选) + **kwargs: 其他参数 + + Returns: + 是否通过检查 + """ + pass + + @abstractmethod + def apply(self, position: Position) -> Optional[float]: + """ + 应用风控 + + Args: + position: 持仓信息 + + Returns: + 应用结果(如止损价格、仓位调整比例等) + """ + pass + + @property + def params(self) -> Dict[str, Any]: + return self._params + + +class StopLossControl(RiskControl): + """ + 止损控制 + + 参数: + - threshold: 止损阈值(默认-0.05) + - trailing: 是否跟踪止损(默认False) + - trailing_percent: 跟踪止损比例(默认0.03) + """ + + name = "stop_loss" + + def __init__( + self, + threshold: float = -0.05, + trailing: bool = False, + trailing_percent: float = 0.03 + ): + super().__init__( + threshold=threshold, + trailing=trailing, + trailing_percent=trailing_percent + ) + self.threshold = threshold + self.trailing = trailing + self.trailing_percent = trailing_percent + self._highest_price = {} # 跟踪最高价 + + def check(self, position: Optional[Position], **kwargs) -> bool: + """检查是否触发止损""" + if position is None: + return True + + # 更新最高价(跟踪止损) + if self.trailing: + if position.code not in self._highest_price: + self._highest_price[position.code] = position.entry_price + self._highest_price[position.code] = max( + self._highest_price[position.code], + position.current_price + ) + + # 检查止损 + if self.trailing: + # 跟踪止损:从最高价回撤超过阈值 + highest = self._highest_price[position.code] + drawdown = (position.current_price - highest) / highest + return drawdown > -self.trailing_percent + else: + # 固定止损:从入场价亏损超过阈值 + return position.profit_ratio > self.threshold + + def apply(self, position: Position) -> Optional[float]: + """返回止损价格""" + if self.trailing: + highest = self._highest_price.get(position.code, position.entry_price) + return highest * (1 - self.trailing_percent) + else: + return position.entry_price * (1 + self.threshold) + + +class PositionLimitControl(RiskControl): + """ + 仓位限制控制 + + 参数: + - max_position: 单品种最大仓位(默认0.33) + - max_total: 总仓位上限(默认1.0) + """ + + name = "position_limit" + + def __init__( + self, + max_position: float = 0.33, + max_total: float = 1.0 + ): + super().__init__( + max_position=max_position, + max_total=max_total + ) + self.max_position = max_position + self.max_total = max_total + + def check(self, position: Optional[Position], **kwargs) -> bool: + """检查仓位是否超限""" + if position is None: + return True + + # 检查单品种仓位 + if position.weight > self.max_position: + return False + + return True + + def apply(self, position: Position) -> Optional[float]: + """返回建议仓位""" + return min(position.weight, self.max_position) + + +class PremiumControl(RiskControl): + """ + 溢价控制 + + 参数: + - threshold: 溢价阈值(默认0.10) + - mode: 控制模式('filter'或'penalize') + """ + + name = "premium" + + def __init__( + self, + threshold: float = 0.10, + mode: str = 'filter' + ): + super().__init__( + threshold=threshold, + mode=mode + ) + self.threshold = threshold + self.mode = mode + + def check(self, position: Optional[Position], **kwargs) -> bool: + """检查溢价是否超限""" + premium = kwargs.get('premium', 0) + + if self.mode == 'filter': + # 完全排除 + return premium <= self.threshold + else: + # 仅降权,允许通过 + return True + + def apply(self, position: Position) -> Optional[float]: + """返回溢价惩罚系数""" + if self.mode == 'penalize': + return 0.5 # 降权50% + return None + + +class CallbackHook: + """ + 回调钩子管理 + + 支持策略生命周期回调: + - before_entry: 入场前检查 + - after_entry: 入场后处理 + - before_exit: 出场前检查 + - after_exit: 出场后处理 + - dynamic_stoploss: 动态止损 + - custom_exit: 自定义出场 + """ + + def __init__(self): + self._hooks = { + 'before_entry': [], + 'after_entry': [], + 'before_exit': [], + 'after_exit': [], + 'dynamic_stoploss': [], + 'custom_exit': [] + } + + def register(self, hook_name: str, callback: callable) -> None: + """注册回调""" + if hook_name not in self._hooks: + raise ValueError(f"Unknown hook: {hook_name}") + self._hooks[hook_name].append(callback) + + def trigger(self, hook_name: str, *args, **kwargs) -> Any: + """触发回调""" + if hook_name not in self._hooks: + raise ValueError(f"Unknown hook: {hook_name}") + + results = [] + for callback in self._hooks[hook_name]: + try: + result = callback(*args, **kwargs) + results.append(result) + except Exception as e: + print(f"⚠ Callback error: {e}") + + # before_entry和before_exit需要所有回调返回True + if hook_name in ['before_entry', 'before_exit']: + return all(results) + + # dynamic_stoploss返回最小的止损值 + if hook_name == 'dynamic_stoploss': + return min(results) if results else -0.05 + + # custom_exit返回是否有任一回调触发出场 + if hook_name == 'custom_exit': + return any(results) + + return results + + def clear(self, hook_name: str = None) -> None: + """清空回调""" + if hook_name: + self._hooks[hook_name] = [] + else: + for key in self._hooks: + self._hooks[key] = [] + + +# 便捷回调函数 +def premium_filter_callback(threshold: float = 0.10): + """溢价过滤回调""" + def callback(code: str, price: float, **kwargs) -> bool: + premium = kwargs.get('premium', 0) + if premium > threshold: + print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})") + return False + return True + return callback + + +def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05): + """崩盘过滤回调""" + def callback(code: str, price: float, **kwargs) -> bool: + history = kwargs.get('history', None) + if history is None: + return True + + # 检查最近N天是否有崩盘 + recent = history.tail(lookback) + if len(recent) < lookback: + return True + + returns = recent['close'].pct_change() + min_return = returns.min() + + if min_return < -crash_threshold: + print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})") + return False + return True + return callback + + +def holding_time_stoploss_callback( + day_5_stoploss: float = -0.05, + day_10_stoploss: float = -0.03 +): + """持仓时间动态止损回调""" + def callback(position: Position) -> float: + if position.holding_days >= 10: + return day_10_stoploss # 10天后收紧止损 + elif position.holding_days >= 5: + return day_5_stoploss + return -0.10 # 默认止损 + return callback \ No newline at end of file diff --git a/framework/tests/test_risk.py b/framework/tests/test_risk.py new file mode 100644 index 0000000..67bb249 --- /dev/null +++ b/framework/tests/test_risk.py @@ -0,0 +1,293 @@ +""" +风控层测试 + +测试RiskControl、StopLossControl、PositionLimitControl、CallbackHook +""" + +import pandas as pd +import pytest +from datetime import datetime, timedelta + +from framework.risk import ( + RiskControl, StopLossControl, PositionLimitControl, PremiumControl, + CallbackHook, Position, Trade, + premium_filter_callback, crash_filter_callback, holding_time_stoploss_callback +) + + +class TestPosition: + """测试持仓信息""" + + def test_position_profit(self): + """测试盈亏计算""" + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=110.0, + current_date=datetime(2020, 1, 10), + quantity=100, + weight=0.33 + ) + + assert position.profit_ratio == 0.10 + assert position.is_profit == True + assert position.holding_days == 9 + + def test_position_loss(self): + """测试亏损计算""" + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=95.0, + current_date=datetime(2020, 1, 5), + quantity=100, + weight=0.33 + ) + + assert position.profit_ratio == -0.05 + assert position.is_profit == False + assert position.holding_days == 4 + + +class TestStopLossControl: + """测试止损控制""" + + def test_fixed_stoploss_check(self): + """测试固定止损检查""" + control = StopLossControl(threshold=-0.05) + + # 未触发止损 + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=96.0, + current_date=datetime(2020, 1, 5), + quantity=100, + weight=0.33 + ) + assert control.check(position) == True + + # 触发止损 + position.current_price = 94.0 + assert control.check(position) == False + + def test_fixed_stoploss_apply(self): + """测试固定止损应用""" + control = StopLossControl(threshold=-0.05) + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=95.0, + current_date=datetime(2020, 1, 5), + quantity=100, + weight=0.33 + ) + + stop_price = control.apply(position) + assert stop_price == 95.0 # 100 * (1 - 0.05) + + def test_trailing_stoploss(self): + """测试跟踪止损""" + control = StopLossControl(trailing=True, trailing_percent=0.03) + + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=105.0, + current_date=datetime(2020, 1, 5), + quantity=100, + weight=0.33 + ) + + # 最高价更新为105 + control.check(position) + assert control._highest_price['code1'] == 105.0 + + # 当前价回撤到101(从105回撤4%),超过3%阈值 + position.current_price = 101.0 + assert control.check(position) == False + + # 止损价格应为 105 * (1 - 0.03) = 101.85 + stop_price = control.apply(position) + assert abs(stop_price - 101.85) < 0.01 + + +class TestPositionLimitControl: + """测试仓位限制控制""" + + def test_position_limit_check(self): + """测试仓位限制检查""" + control = PositionLimitControl(max_position=0.33) + + # 仓位未超限 + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=105.0, + current_date=datetime(2020, 1, 5), + quantity=100, + weight=0.30 + ) + assert control.check(position) == True + + # 仓位超限 + position.weight = 0.40 + assert control.check(position) == False + + def test_position_limit_apply(self): + """测试仓位限制应用""" + control = PositionLimitControl(max_position=0.33) + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=105.0, + current_date=datetime(2020, 1, 5), + quantity=100, + weight=0.50 + ) + + suggested_weight = control.apply(position) + assert suggested_weight == 0.33 + + +class TestPremiumControl: + """测试溢价控制""" + + def test_premium_filter(self): + """测试溢价过滤""" + control = PremiumControl(threshold=0.10, mode='filter') + + # 溢价未超限 + assert control.check(None, premium=0.05) == True + + # 溢价超限 + assert control.check(None, premium=0.15) == False + + def test_premium_penalize(self): + """测试溢价降权""" + control = PremiumControl(threshold=0.10, mode='penalize') + + # 降权模式下允许通过 + assert control.check(None, premium=0.15) == True + + # 返回降权系数 + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=105.0, + current_date=datetime(2020, 1, 5), + quantity=100, + weight=0.33 + ) + penalty = control.apply(position) + assert penalty == 0.5 + + +class TestCallbackHook: + """测试回调钩子""" + + def test_register_hook(self): + """测试注册回调""" + hook = CallbackHook() + + def dummy_callback(code, price): + return True + + hook.register('before_entry', dummy_callback) + assert len(hook._hooks['before_entry']) == 1 + + def test_trigger_before_entry(self): + """测试触发入场前回调""" + hook = CallbackHook() + + # 注册溢价过滤回调 + hook.register('before_entry', premium_filter_callback(threshold=0.10)) + + # 溢价正常,允许入场 + result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05) + assert result == True + + # 溢价过高,拒绝入场 + result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15) + assert result == False + + def test_trigger_dynamic_stoploss(self): + """测试触发动态止损回调""" + hook = CallbackHook() + + # 注册持仓时间止损回调 + hook.register('dynamic_stoploss', holding_time_stoploss_callback()) + + # 持仓5天,止损-5% + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=95.0, + current_date=datetime(2020, 1, 6), + quantity=100, + weight=0.33 + ) + stoploss = hook.trigger('dynamic_stoploss', position) + # holding_days=5,返回-0.05 + assert stoploss == -0.05 + + def test_trigger_custom_exit(self): + """测试触发自定义出场回调""" + hook = CallbackHook() + + def reversal_exit_callback(position): + # 反转信号触发出场 + return position.profit_ratio < -0.02 + + hook.register('custom_exit', reversal_exit_callback) + + # 未触发出场 + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=99.0, + current_date=datetime(2020, 1, 5), + quantity=100, + weight=0.33 + ) + result = hook.trigger('custom_exit', position) + assert result == False + + # 触发出场 + position.current_price = 97.0 + result = hook.trigger('custom_exit', position) + assert result == True + + def test_multiple_callbacks(self): + """测试多个回调组合""" + hook = CallbackHook() + + # 注册多个入场前回调 + hook.register('before_entry', premium_filter_callback(0.10)) + hook.register('before_entry', lambda code, price, **kwargs: price > 50) + + # 溢价正常 + 价格>50,允许入场 + result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05) + assert result == True + + # 溢价过高,拒绝入场(任一回调返回False) + result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15) + assert result == False + + # 价格过低,拒绝入场 + result = hook.trigger('before_entry', 'code1', 40.0, premium=0.05) + assert result == False + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) \ No newline at end of file