feat(risk): 实现风控层与回调钩子机制(融合Freqtrade设计)
核心组件: - RiskControl: 风控抽象基类 - StopLossControl: 止损控制(固定止损/跟踪止损) - PositionLimitControl: 仓位限制控制 - PremiumControl: 溢价控制(filter/penalize模式) 回调钩子机制: - CallbackHook: 回调管理器(注册/触发) - 5个核心回调:before_entry, after_entry, before_exit, after_exit, dynamic_stoploss, custom_exit 便捷回调函数: - premium_filter_callback: 溢价过滤回调 - crash_filter_callback: 崩盘检测回调 - holding_time_stoploss_callback: 持仓时间动态止损 测试覆盖:13个测试全部通过
This commit is contained in:
351
framework/risk/__init__.py
Normal file
351
framework/risk/__init__.py
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
"""
|
||||||
|
回调钩子与风控层设计
|
||||||
|
|
||||||
|
核心组件:
|
||||||
|
- RiskControl: 风控抽象基类
|
||||||
|
- StopLossControl: 止损控制
|
||||||
|
- PositionLimitControl: 仓位限制控制
|
||||||
|
- CallbackHook: 回调钩子管理
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Position:
|
||||||
|
"""持仓信息"""
|
||||||
|
code: str
|
||||||
|
entry_price: float
|
||||||
|
entry_date: datetime
|
||||||
|
current_price: float
|
||||||
|
current_date: datetime
|
||||||
|
quantity: float
|
||||||
|
weight: float
|
||||||
|
|
||||||
|
@property
|
||||||
|
def profit_ratio(self) -> float:
|
||||||
|
"""盈亏比例"""
|
||||||
|
return (self.current_price - self.entry_price) / self.entry_price
|
||||||
|
|
||||||
|
@property
|
||||||
|
def holding_days(self) -> int:
|
||||||
|
"""持仓天数"""
|
||||||
|
return (self.current_date - self.entry_date).days
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_profit(self) -> bool:
|
||||||
|
"""是否盈利"""
|
||||||
|
return self.profit_ratio > 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Trade:
|
||||||
|
"""交易信息"""
|
||||||
|
code: str
|
||||||
|
direction: str # 'entry' or 'exit'
|
||||||
|
price: float
|
||||||
|
date: datetime
|
||||||
|
quantity: float
|
||||||
|
reason: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class RiskControl(ABC):
|
||||||
|
"""
|
||||||
|
风控抽象基类
|
||||||
|
|
||||||
|
所有风控组件必须继承此基类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = "base"
|
||||||
|
|
||||||
|
def __init__(self, **params):
|
||||||
|
self._params = params
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||||
|
"""
|
||||||
|
风控检查
|
||||||
|
|
||||||
|
Args:
|
||||||
|
position: 持仓信息(可选)
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否通过检查
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(self, position: Position) -> Optional[float]:
|
||||||
|
"""
|
||||||
|
应用风控
|
||||||
|
|
||||||
|
Args:
|
||||||
|
position: 持仓信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
应用结果(如止损价格、仓位调整比例等)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def params(self) -> Dict[str, Any]:
|
||||||
|
return self._params
|
||||||
|
|
||||||
|
|
||||||
|
class StopLossControl(RiskControl):
|
||||||
|
"""
|
||||||
|
止损控制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- threshold: 止损阈值(默认-0.05)
|
||||||
|
- trailing: 是否跟踪止损(默认False)
|
||||||
|
- trailing_percent: 跟踪止损比例(默认0.03)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "stop_loss"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
threshold: float = -0.05,
|
||||||
|
trailing: bool = False,
|
||||||
|
trailing_percent: float = 0.03
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
threshold=threshold,
|
||||||
|
trailing=trailing,
|
||||||
|
trailing_percent=trailing_percent
|
||||||
|
)
|
||||||
|
self.threshold = threshold
|
||||||
|
self.trailing = trailing
|
||||||
|
self.trailing_percent = trailing_percent
|
||||||
|
self._highest_price = {} # 跟踪最高价
|
||||||
|
|
||||||
|
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||||
|
"""检查是否触发止损"""
|
||||||
|
if position is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 更新最高价(跟踪止损)
|
||||||
|
if self.trailing:
|
||||||
|
if position.code not in self._highest_price:
|
||||||
|
self._highest_price[position.code] = position.entry_price
|
||||||
|
self._highest_price[position.code] = max(
|
||||||
|
self._highest_price[position.code],
|
||||||
|
position.current_price
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查止损
|
||||||
|
if self.trailing:
|
||||||
|
# 跟踪止损:从最高价回撤超过阈值
|
||||||
|
highest = self._highest_price[position.code]
|
||||||
|
drawdown = (position.current_price - highest) / highest
|
||||||
|
return drawdown > -self.trailing_percent
|
||||||
|
else:
|
||||||
|
# 固定止损:从入场价亏损超过阈值
|
||||||
|
return position.profit_ratio > self.threshold
|
||||||
|
|
||||||
|
def apply(self, position: Position) -> Optional[float]:
|
||||||
|
"""返回止损价格"""
|
||||||
|
if self.trailing:
|
||||||
|
highest = self._highest_price.get(position.code, position.entry_price)
|
||||||
|
return highest * (1 - self.trailing_percent)
|
||||||
|
else:
|
||||||
|
return position.entry_price * (1 + self.threshold)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionLimitControl(RiskControl):
|
||||||
|
"""
|
||||||
|
仓位限制控制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- max_position: 单品种最大仓位(默认0.33)
|
||||||
|
- max_total: 总仓位上限(默认1.0)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "position_limit"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_position: float = 0.33,
|
||||||
|
max_total: float = 1.0
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
max_position=max_position,
|
||||||
|
max_total=max_total
|
||||||
|
)
|
||||||
|
self.max_position = max_position
|
||||||
|
self.max_total = max_total
|
||||||
|
|
||||||
|
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||||
|
"""检查仓位是否超限"""
|
||||||
|
if position is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 检查单品种仓位
|
||||||
|
if position.weight > self.max_position:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply(self, position: Position) -> Optional[float]:
|
||||||
|
"""返回建议仓位"""
|
||||||
|
return min(position.weight, self.max_position)
|
||||||
|
|
||||||
|
|
||||||
|
class PremiumControl(RiskControl):
|
||||||
|
"""
|
||||||
|
溢价控制
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- threshold: 溢价阈值(默认0.10)
|
||||||
|
- mode: 控制模式('filter'或'penalize')
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "premium"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
threshold: float = 0.10,
|
||||||
|
mode: str = 'filter'
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
threshold=threshold,
|
||||||
|
mode=mode
|
||||||
|
)
|
||||||
|
self.threshold = threshold
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def check(self, position: Optional[Position], **kwargs) -> bool:
|
||||||
|
"""检查溢价是否超限"""
|
||||||
|
premium = kwargs.get('premium', 0)
|
||||||
|
|
||||||
|
if self.mode == 'filter':
|
||||||
|
# 完全排除
|
||||||
|
return premium <= self.threshold
|
||||||
|
else:
|
||||||
|
# 仅降权,允许通过
|
||||||
|
return True
|
||||||
|
|
||||||
|
def apply(self, position: Position) -> Optional[float]:
|
||||||
|
"""返回溢价惩罚系数"""
|
||||||
|
if self.mode == 'penalize':
|
||||||
|
return 0.5 # 降权50%
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackHook:
|
||||||
|
"""
|
||||||
|
回调钩子管理
|
||||||
|
|
||||||
|
支持策略生命周期回调:
|
||||||
|
- before_entry: 入场前检查
|
||||||
|
- after_entry: 入场后处理
|
||||||
|
- before_exit: 出场前检查
|
||||||
|
- after_exit: 出场后处理
|
||||||
|
- dynamic_stoploss: 动态止损
|
||||||
|
- custom_exit: 自定义出场
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._hooks = {
|
||||||
|
'before_entry': [],
|
||||||
|
'after_entry': [],
|
||||||
|
'before_exit': [],
|
||||||
|
'after_exit': [],
|
||||||
|
'dynamic_stoploss': [],
|
||||||
|
'custom_exit': []
|
||||||
|
}
|
||||||
|
|
||||||
|
def register(self, hook_name: str, callback: callable) -> None:
|
||||||
|
"""注册回调"""
|
||||||
|
if hook_name not in self._hooks:
|
||||||
|
raise ValueError(f"Unknown hook: {hook_name}")
|
||||||
|
self._hooks[hook_name].append(callback)
|
||||||
|
|
||||||
|
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
|
||||||
|
"""触发回调"""
|
||||||
|
if hook_name not in self._hooks:
|
||||||
|
raise ValueError(f"Unknown hook: {hook_name}")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for callback in self._hooks[hook_name]:
|
||||||
|
try:
|
||||||
|
result = callback(*args, **kwargs)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠ Callback error: {e}")
|
||||||
|
|
||||||
|
# before_entry和before_exit需要所有回调返回True
|
||||||
|
if hook_name in ['before_entry', 'before_exit']:
|
||||||
|
return all(results)
|
||||||
|
|
||||||
|
# dynamic_stoploss返回最小的止损值
|
||||||
|
if hook_name == 'dynamic_stoploss':
|
||||||
|
return min(results) if results else -0.05
|
||||||
|
|
||||||
|
# custom_exit返回是否有任一回调触发出场
|
||||||
|
if hook_name == 'custom_exit':
|
||||||
|
return any(results)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def clear(self, hook_name: str = None) -> None:
|
||||||
|
"""清空回调"""
|
||||||
|
if hook_name:
|
||||||
|
self._hooks[hook_name] = []
|
||||||
|
else:
|
||||||
|
for key in self._hooks:
|
||||||
|
self._hooks[key] = []
|
||||||
|
|
||||||
|
|
||||||
|
# 便捷回调函数
|
||||||
|
def premium_filter_callback(threshold: float = 0.10):
|
||||||
|
"""溢价过滤回调"""
|
||||||
|
def callback(code: str, price: float, **kwargs) -> bool:
|
||||||
|
premium = kwargs.get('premium', 0)
|
||||||
|
if premium > threshold:
|
||||||
|
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return callback
|
||||||
|
|
||||||
|
|
||||||
|
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
|
||||||
|
"""崩盘过滤回调"""
|
||||||
|
def callback(code: str, price: float, **kwargs) -> bool:
|
||||||
|
history = kwargs.get('history', None)
|
||||||
|
if history is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 检查最近N天是否有崩盘
|
||||||
|
recent = history.tail(lookback)
|
||||||
|
if len(recent) < lookback:
|
||||||
|
return True
|
||||||
|
|
||||||
|
returns = recent['close'].pct_change()
|
||||||
|
min_return = returns.min()
|
||||||
|
|
||||||
|
if min_return < -crash_threshold:
|
||||||
|
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
return callback
|
||||||
|
|
||||||
|
|
||||||
|
def holding_time_stoploss_callback(
|
||||||
|
day_5_stoploss: float = -0.05,
|
||||||
|
day_10_stoploss: float = -0.03
|
||||||
|
):
|
||||||
|
"""持仓时间动态止损回调"""
|
||||||
|
def callback(position: Position) -> float:
|
||||||
|
if position.holding_days >= 10:
|
||||||
|
return day_10_stoploss # 10天后收紧止损
|
||||||
|
elif position.holding_days >= 5:
|
||||||
|
return day_5_stoploss
|
||||||
|
return -0.10 # 默认止损
|
||||||
|
return callback
|
||||||
293
framework/tests/test_risk.py
Normal file
293
framework/tests/test_risk.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
"""
|
||||||
|
风控层测试
|
||||||
|
|
||||||
|
测试RiskControl、StopLossControl、PositionLimitControl、CallbackHook
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pytest
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from framework.risk import (
|
||||||
|
RiskControl, StopLossControl, PositionLimitControl, PremiumControl,
|
||||||
|
CallbackHook, Position, Trade,
|
||||||
|
premium_filter_callback, crash_filter_callback, holding_time_stoploss_callback
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPosition:
|
||||||
|
"""测试持仓信息"""
|
||||||
|
|
||||||
|
def test_position_profit(self):
|
||||||
|
"""测试盈亏计算"""
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=110.0,
|
||||||
|
current_date=datetime(2020, 1, 10),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
|
||||||
|
assert position.profit_ratio == 0.10
|
||||||
|
assert position.is_profit == True
|
||||||
|
assert position.holding_days == 9
|
||||||
|
|
||||||
|
def test_position_loss(self):
|
||||||
|
"""测试亏损计算"""
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=95.0,
|
||||||
|
current_date=datetime(2020, 1, 5),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
|
||||||
|
assert position.profit_ratio == -0.05
|
||||||
|
assert position.is_profit == False
|
||||||
|
assert position.holding_days == 4
|
||||||
|
|
||||||
|
|
||||||
|
class TestStopLossControl:
|
||||||
|
"""测试止损控制"""
|
||||||
|
|
||||||
|
def test_fixed_stoploss_check(self):
|
||||||
|
"""测试固定止损检查"""
|
||||||
|
control = StopLossControl(threshold=-0.05)
|
||||||
|
|
||||||
|
# 未触发止损
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=96.0,
|
||||||
|
current_date=datetime(2020, 1, 5),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
assert control.check(position) == True
|
||||||
|
|
||||||
|
# 触发止损
|
||||||
|
position.current_price = 94.0
|
||||||
|
assert control.check(position) == False
|
||||||
|
|
||||||
|
def test_fixed_stoploss_apply(self):
|
||||||
|
"""测试固定止损应用"""
|
||||||
|
control = StopLossControl(threshold=-0.05)
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=95.0,
|
||||||
|
current_date=datetime(2020, 1, 5),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
|
||||||
|
stop_price = control.apply(position)
|
||||||
|
assert stop_price == 95.0 # 100 * (1 - 0.05)
|
||||||
|
|
||||||
|
def test_trailing_stoploss(self):
|
||||||
|
"""测试跟踪止损"""
|
||||||
|
control = StopLossControl(trailing=True, trailing_percent=0.03)
|
||||||
|
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=105.0,
|
||||||
|
current_date=datetime(2020, 1, 5),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
|
||||||
|
# 最高价更新为105
|
||||||
|
control.check(position)
|
||||||
|
assert control._highest_price['code1'] == 105.0
|
||||||
|
|
||||||
|
# 当前价回撤到101(从105回撤4%),超过3%阈值
|
||||||
|
position.current_price = 101.0
|
||||||
|
assert control.check(position) == False
|
||||||
|
|
||||||
|
# 止损价格应为 105 * (1 - 0.03) = 101.85
|
||||||
|
stop_price = control.apply(position)
|
||||||
|
assert abs(stop_price - 101.85) < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
class TestPositionLimitControl:
|
||||||
|
"""测试仓位限制控制"""
|
||||||
|
|
||||||
|
def test_position_limit_check(self):
|
||||||
|
"""测试仓位限制检查"""
|
||||||
|
control = PositionLimitControl(max_position=0.33)
|
||||||
|
|
||||||
|
# 仓位未超限
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=105.0,
|
||||||
|
current_date=datetime(2020, 1, 5),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.30
|
||||||
|
)
|
||||||
|
assert control.check(position) == True
|
||||||
|
|
||||||
|
# 仓位超限
|
||||||
|
position.weight = 0.40
|
||||||
|
assert control.check(position) == False
|
||||||
|
|
||||||
|
def test_position_limit_apply(self):
|
||||||
|
"""测试仓位限制应用"""
|
||||||
|
control = PositionLimitControl(max_position=0.33)
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=105.0,
|
||||||
|
current_date=datetime(2020, 1, 5),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.50
|
||||||
|
)
|
||||||
|
|
||||||
|
suggested_weight = control.apply(position)
|
||||||
|
assert suggested_weight == 0.33
|
||||||
|
|
||||||
|
|
||||||
|
class TestPremiumControl:
|
||||||
|
"""测试溢价控制"""
|
||||||
|
|
||||||
|
def test_premium_filter(self):
|
||||||
|
"""测试溢价过滤"""
|
||||||
|
control = PremiumControl(threshold=0.10, mode='filter')
|
||||||
|
|
||||||
|
# 溢价未超限
|
||||||
|
assert control.check(None, premium=0.05) == True
|
||||||
|
|
||||||
|
# 溢价超限
|
||||||
|
assert control.check(None, premium=0.15) == False
|
||||||
|
|
||||||
|
def test_premium_penalize(self):
|
||||||
|
"""测试溢价降权"""
|
||||||
|
control = PremiumControl(threshold=0.10, mode='penalize')
|
||||||
|
|
||||||
|
# 降权模式下允许通过
|
||||||
|
assert control.check(None, premium=0.15) == True
|
||||||
|
|
||||||
|
# 返回降权系数
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=105.0,
|
||||||
|
current_date=datetime(2020, 1, 5),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
penalty = control.apply(position)
|
||||||
|
assert penalty == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallbackHook:
|
||||||
|
"""测试回调钩子"""
|
||||||
|
|
||||||
|
def test_register_hook(self):
|
||||||
|
"""测试注册回调"""
|
||||||
|
hook = CallbackHook()
|
||||||
|
|
||||||
|
def dummy_callback(code, price):
|
||||||
|
return True
|
||||||
|
|
||||||
|
hook.register('before_entry', dummy_callback)
|
||||||
|
assert len(hook._hooks['before_entry']) == 1
|
||||||
|
|
||||||
|
def test_trigger_before_entry(self):
|
||||||
|
"""测试触发入场前回调"""
|
||||||
|
hook = CallbackHook()
|
||||||
|
|
||||||
|
# 注册溢价过滤回调
|
||||||
|
hook.register('before_entry', premium_filter_callback(threshold=0.10))
|
||||||
|
|
||||||
|
# 溢价正常,允许入场
|
||||||
|
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||||
|
assert result == True
|
||||||
|
|
||||||
|
# 溢价过高,拒绝入场
|
||||||
|
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
|
||||||
|
assert result == False
|
||||||
|
|
||||||
|
def test_trigger_dynamic_stoploss(self):
|
||||||
|
"""测试触发动态止损回调"""
|
||||||
|
hook = CallbackHook()
|
||||||
|
|
||||||
|
# 注册持仓时间止损回调
|
||||||
|
hook.register('dynamic_stoploss', holding_time_stoploss_callback())
|
||||||
|
|
||||||
|
# 持仓5天,止损-5%
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=95.0,
|
||||||
|
current_date=datetime(2020, 1, 6),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
stoploss = hook.trigger('dynamic_stoploss', position)
|
||||||
|
# holding_days=5,返回-0.05
|
||||||
|
assert stoploss == -0.05
|
||||||
|
|
||||||
|
def test_trigger_custom_exit(self):
|
||||||
|
"""测试触发自定义出场回调"""
|
||||||
|
hook = CallbackHook()
|
||||||
|
|
||||||
|
def reversal_exit_callback(position):
|
||||||
|
# 反转信号触发出场
|
||||||
|
return position.profit_ratio < -0.02
|
||||||
|
|
||||||
|
hook.register('custom_exit', reversal_exit_callback)
|
||||||
|
|
||||||
|
# 未触发出场
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=99.0,
|
||||||
|
current_date=datetime(2020, 1, 5),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
result = hook.trigger('custom_exit', position)
|
||||||
|
assert result == False
|
||||||
|
|
||||||
|
# 触发出场
|
||||||
|
position.current_price = 97.0
|
||||||
|
result = hook.trigger('custom_exit', position)
|
||||||
|
assert result == True
|
||||||
|
|
||||||
|
def test_multiple_callbacks(self):
|
||||||
|
"""测试多个回调组合"""
|
||||||
|
hook = CallbackHook()
|
||||||
|
|
||||||
|
# 注册多个入场前回调
|
||||||
|
hook.register('before_entry', premium_filter_callback(0.10))
|
||||||
|
hook.register('before_entry', lambda code, price, **kwargs: price > 50)
|
||||||
|
|
||||||
|
# 溢价正常 + 价格>50,允许入场
|
||||||
|
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||||
|
assert result == True
|
||||||
|
|
||||||
|
# 溢价过高,拒绝入场(任一回调返回False)
|
||||||
|
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.15)
|
||||||
|
assert result == False
|
||||||
|
|
||||||
|
# 价格过低,拒绝入场
|
||||||
|
result = hook.trigger('before_entry', 'code1', 40.0, premium=0.05)
|
||||||
|
assert result == False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
Reference in New Issue
Block a user