Files
etf/framework/risk/__init__.py
aszerW 30ea2970bd refactor(framework): 框架只保留抽象接口,具体实现移至strategies/shared
- FactorBase/FactorRegistry/FactorCombiner: 因子抽象接口
- SignalGenerator: 信号生成抽象接口
- RiskControl/Position/CallbackHook: 风控抽象接口
- StrategyBase: 策略抽象基类
- Executor/Portfolio: 执行器抽象接口
- ConfigLoader: 配置加载器
- 删除framework/factors/momentum.py(具体实现)
2026-05-11 23:09:01 +08:00

188 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
风控层抽象接口(通用)
只提供抽象基类和回调机制具体风控组件在strategies/shared/risk/
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any, Callable, Optional
from dataclasses import dataclass
from datetime import datetime
@dataclass
class Position:
"""
持仓数据结构(通用)
用于表示单个持仓的状态
"""
code: str
entry_price: float
current_price: float
entry_time: datetime
quantity: float = 1.0
weight: float = 1.0
@property
def profit_ratio(self) -> float:
"""计算盈亏比例"""
return (self.current_price - self.entry_price) / self.entry_price
@property
def profit_amount(self) -> float:
"""计算盈亏金额"""
return (self.current_price - self.entry_price) * self.quantity
@property
def holding_days(self) -> int:
"""计算持仓天数"""
if self.entry_time is None:
return 0
return (datetime.now() - self.entry_time).days
def __repr__(self) -> str:
return f"Position(code={self.code}, profit={self.profit_ratio:.2%}, days={self.holding_days})"
class RiskControl(ABC):
"""
风控组件抽象基类
所有风控组件必须实现check方法
"""
name: str = "base"
def __init__(self, **params):
"""初始化风控参数"""
self._params = params
@abstractmethod
def check(self, position: Position, **kwargs) -> bool:
"""
检查风控条件
Args:
position: 持仓对象
kwargs: 额外参数如premium、history等
Returns:
True表示通过检查False表示触发风控
"""
pass
def apply(self, position: Position) -> Any:
"""
应用风控规则(可选)
Args:
position: 持仓对象
Returns:
风控结果(如止损价格、建议仓位等)
"""
return None
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class CallbackHook:
"""
回调钩子管理(通用)
支持在策略生命周期的关键节点注入自定义逻辑
"""
SUPPORTED_HOOKS = [
'before_entry', # 入场前检查
'after_entry', # 入场后处理
'before_exit', # 出场前检查
'after_exit', # 出场后处理
'dynamic_stoploss', # 动态止损计算
'custom_exit' # 自定义出场条件
]
def __init__(self):
"""初始化回调钩子"""
self._hooks: Dict[str, List[Callable]] = {
hook: [] for hook in self.SUPPORTED_HOOKS
}
def register(self, hook_name: str, callback: Callable) -> None:
"""注册回调函数"""
if hook_name not in self._hooks:
raise ValueError(f"Unsupported hook: {hook_name}")
self._hooks[hook_name].append(callback)
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
"""
触发回调
Args:
hook_name: 钩子名称
args: 位置参数
kwargs: 关键字参数
Returns:
回调结果(根据钩子类型返回不同结果)
"""
if hook_name not in self._hooks:
raise ValueError(f"Unsupported hook: {hook_name}")
callbacks = self._hooks[hook_name]
if not callbacks:
# 默认行为
if hook_name == 'dynamic_stoploss':
return kwargs.get('default_stoploss', -0.05)
elif hook_name in ['before_entry', 'before_exit']:
return True
elif hook_name == 'custom_exit':
return False
return None
results = []
for callback in callbacks:
result = callback(*args, **kwargs)
results.append(result)
# 根据钩子类型返回不同结果
if hook_name in ['before_entry', 'before_exit']:
# 所有回调返回True才允许
return all(results)
if hook_name == 'dynamic_stoploss':
# 返回最小止损值(最严格)
return min(results)
if hook_name == 'custom_exit':
# 任一回调触发出场
return any(results)
# 其他钩子返回最后一个结果
return results[-1] if results else None
def clear(self, hook_name: Optional[str] = None) -> None:
"""清空回调"""
if hook_name:
if hook_name in self._hooks:
self._hooks[hook_name] = []
else:
for hook in self._hooks:
self._hooks[hook] = []
def list_hooks(self) -> List[str]:
"""列出支持的钩子"""
return self.SUPPORTED_HOOKS
def get_callbacks(self, hook_name: str) -> List[Callable]:
"""获取钩子的所有回调"""
return self._hooks.get(hook_name, [])
# 导出抽象接口
__all__ = ['Position', 'RiskControl', 'CallbackHook']