- FactorBase/FactorRegistry/FactorCombiner: 因子抽象接口 - SignalGenerator: 信号生成抽象接口 - RiskControl/Position/CallbackHook: 风控抽象接口 - StrategyBase: 策略抽象基类 - Executor/Portfolio: 执行器抽象接口 - ConfigLoader: 配置加载器 - 删除framework/factors/momentum.py(具体实现)
188 lines
5.2 KiB
Python
188 lines
5.2 KiB
Python
"""
|
||
风控层抽象接口(通用)
|
||
|
||
只提供抽象基类和回调机制,具体风控组件在strategies/shared/risk/
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, List, Any, Callable, Optional
|
||
from dataclasses import dataclass
|
||
from datetime import datetime
|
||
|
||
|
||
@dataclass
|
||
class Position:
|
||
"""
|
||
持仓数据结构(通用)
|
||
|
||
用于表示单个持仓的状态
|
||
"""
|
||
code: str
|
||
entry_price: float
|
||
current_price: float
|
||
entry_time: datetime
|
||
quantity: float = 1.0
|
||
weight: float = 1.0
|
||
|
||
@property
|
||
def profit_ratio(self) -> float:
|
||
"""计算盈亏比例"""
|
||
return (self.current_price - self.entry_price) / self.entry_price
|
||
|
||
@property
|
||
def profit_amount(self) -> float:
|
||
"""计算盈亏金额"""
|
||
return (self.current_price - self.entry_price) * self.quantity
|
||
|
||
@property
|
||
def holding_days(self) -> int:
|
||
"""计算持仓天数"""
|
||
if self.entry_time is None:
|
||
return 0
|
||
return (datetime.now() - self.entry_time).days
|
||
|
||
def __repr__(self) -> str:
|
||
return f"Position(code={self.code}, profit={self.profit_ratio:.2%}, days={self.holding_days})"
|
||
|
||
|
||
class RiskControl(ABC):
|
||
"""
|
||
风控组件抽象基类
|
||
|
||
所有风控组件必须实现check方法
|
||
"""
|
||
|
||
name: str = "base"
|
||
|
||
def __init__(self, **params):
|
||
"""初始化风控参数"""
|
||
self._params = params
|
||
|
||
@abstractmethod
|
||
def check(self, position: Position, **kwargs) -> bool:
|
||
"""
|
||
检查风控条件
|
||
|
||
Args:
|
||
position: 持仓对象
|
||
kwargs: 额外参数(如premium、history等)
|
||
|
||
Returns:
|
||
True表示通过检查,False表示触发风控
|
||
"""
|
||
pass
|
||
|
||
def apply(self, position: Position) -> Any:
|
||
"""
|
||
应用风控规则(可选)
|
||
|
||
Args:
|
||
position: 持仓对象
|
||
|
||
Returns:
|
||
风控结果(如止损价格、建议仓位等)
|
||
"""
|
||
return None
|
||
|
||
def __repr__(self) -> str:
|
||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||
return f"{self.__class__.__name__}({params_str})"
|
||
|
||
|
||
class CallbackHook:
|
||
"""
|
||
回调钩子管理(通用)
|
||
|
||
支持在策略生命周期的关键节点注入自定义逻辑
|
||
"""
|
||
|
||
SUPPORTED_HOOKS = [
|
||
'before_entry', # 入场前检查
|
||
'after_entry', # 入场后处理
|
||
'before_exit', # 出场前检查
|
||
'after_exit', # 出场后处理
|
||
'dynamic_stoploss', # 动态止损计算
|
||
'custom_exit' # 自定义出场条件
|
||
]
|
||
|
||
def __init__(self):
|
||
"""初始化回调钩子"""
|
||
self._hooks: Dict[str, List[Callable]] = {
|
||
hook: [] for hook in self.SUPPORTED_HOOKS
|
||
}
|
||
|
||
def register(self, hook_name: str, callback: Callable) -> None:
|
||
"""注册回调函数"""
|
||
if hook_name not in self._hooks:
|
||
raise ValueError(f"Unsupported hook: {hook_name}")
|
||
|
||
self._hooks[hook_name].append(callback)
|
||
|
||
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
|
||
"""
|
||
触发回调
|
||
|
||
Args:
|
||
hook_name: 钩子名称
|
||
args: 位置参数
|
||
kwargs: 关键字参数
|
||
|
||
Returns:
|
||
回调结果(根据钩子类型返回不同结果)
|
||
"""
|
||
if hook_name not in self._hooks:
|
||
raise ValueError(f"Unsupported hook: {hook_name}")
|
||
|
||
callbacks = self._hooks[hook_name]
|
||
|
||
if not callbacks:
|
||
# 默认行为
|
||
if hook_name == 'dynamic_stoploss':
|
||
return kwargs.get('default_stoploss', -0.05)
|
||
elif hook_name in ['before_entry', 'before_exit']:
|
||
return True
|
||
elif hook_name == 'custom_exit':
|
||
return False
|
||
return None
|
||
|
||
results = []
|
||
for callback in callbacks:
|
||
result = callback(*args, **kwargs)
|
||
results.append(result)
|
||
|
||
# 根据钩子类型返回不同结果
|
||
if hook_name in ['before_entry', 'before_exit']:
|
||
# 所有回调返回True才允许
|
||
return all(results)
|
||
|
||
if hook_name == 'dynamic_stoploss':
|
||
# 返回最小止损值(最严格)
|
||
return min(results)
|
||
|
||
if hook_name == 'custom_exit':
|
||
# 任一回调触发出场
|
||
return any(results)
|
||
|
||
# 其他钩子返回最后一个结果
|
||
return results[-1] if results else None
|
||
|
||
def clear(self, hook_name: Optional[str] = None) -> None:
|
||
"""清空回调"""
|
||
if hook_name:
|
||
if hook_name in self._hooks:
|
||
self._hooks[hook_name] = []
|
||
else:
|
||
for hook in self._hooks:
|
||
self._hooks[hook] = []
|
||
|
||
def list_hooks(self) -> List[str]:
|
||
"""列出支持的钩子"""
|
||
return self.SUPPORTED_HOOKS
|
||
|
||
def get_callbacks(self, hook_name: str) -> List[Callable]:
|
||
"""获取钩子的所有回调"""
|
||
return self._hooks.get(hook_name, [])
|
||
|
||
|
||
# 导出抽象接口
|
||
__all__ = ['Position', 'RiskControl', 'CallbackHook'] |