refactor(framework): 框架只保留抽象接口,具体实现移至strategies/shared

- FactorBase/FactorRegistry/FactorCombiner: 因子抽象接口
- SignalGenerator: 信号生成抽象接口
- RiskControl/Position/CallbackHook: 风控抽象接口
- StrategyBase: 策略抽象基类
- Executor/Portfolio: 执行器抽象接口
- ConfigLoader: 配置加载器
- 删除framework/factors/momentum.py(具体实现)
This commit is contained in:
2026-05-11 23:09:01 +08:00
parent 9a8a0d7c72
commit 30ea2970bd
8 changed files with 503 additions and 1516 deletions

View File

@@ -1,351 +1,188 @@
"""
回调钩子与风控层设计
风控层抽象接口(通用)
核心组件:
- RiskControl: 风控抽象基类
- StopLossControl: 止损控制
- PositionLimitControl: 仓位限制控制
- CallbackHook: 回调钩子管理
只提供抽象基类和回调机制具体风控组件在strategies/shared/risk/
"""
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from typing import Dict, List, Any, Callable, Optional
from dataclasses import dataclass
from datetime import datetime
@dataclass
class Position:
"""持仓信息"""
"""
持仓数据结构(通用)
用于表示单个持仓的状态
"""
code: str
entry_price: float
entry_date: datetime
current_price: float
current_date: datetime
quantity: float
weight: float
entry_time: datetime
quantity: float = 1.0
weight: float = 1.0
@property
def profit_ratio(self) -> float:
"""盈亏比例"""
"""计算盈亏比例"""
return (self.current_price - self.entry_price) / self.entry_price
@property
def holding_days(self) -> int:
"""持仓天数"""
return (self.current_date - self.entry_date).days
def profit_amount(self) -> float:
"""计算盈亏金额"""
return (self.current_price - self.entry_price) * self.quantity
@property
def is_profit(self) -> bool:
"""是否盈利"""
return self.profit_ratio > 0
@dataclass
class Trade:
"""交易信息"""
code: str
direction: str # 'entry' or 'exit'
price: float
date: datetime
quantity: float
reason: str = ""
def holding_days(self) -> int:
"""计算持仓天数"""
if self.entry_time is None:
return 0
return (datetime.now() - self.entry_time).days
def __repr__(self) -> str:
return f"Position(code={self.code}, profit={self.profit_ratio:.2%}, days={self.holding_days})"
class RiskControl(ABC):
"""
风控抽象基类
风控组件抽象基类
所有风控组件必须继承此基类。
所有风控组件必须实现check方法
"""
name: str = "base"
def __init__(self, **params):
"""初始化风控参数"""
self._params = params
@abstractmethod
def check(self, position: Optional[Position], **kwargs) -> bool:
def check(self, position: Position, **kwargs) -> bool:
"""
风控检查
检查风控条件
Args:
position: 持仓信息(可选)
**kwargs: 其他参数
position: 持仓对象
kwargs: 额外参数如premium、history等
Returns:
是否通过检查
True表示通过检查False表示触发风控
"""
pass
@abstractmethod
def apply(self, position: Position) -> Optional[float]:
def apply(self, position: Position) -> Any:
"""
应用风控
应用风控规则(可选)
Args:
position: 持仓信息
position: 持仓对象
Returns:
应用结果(如止损价格、仓位调整比例等)
风控结果(如止损价格、建议仓位等)
"""
pass
@property
def params(self) -> Dict[str, Any]:
return self._params
class StopLossControl(RiskControl):
"""
止损控制
参数:
- threshold: 止损阈值(默认-0.05
- trailing: 是否跟踪止损默认False
- trailing_percent: 跟踪止损比例默认0.03
"""
name = "stop_loss"
def __init__(
self,
threshold: float = -0.05,
trailing: bool = False,
trailing_percent: float = 0.03
):
super().__init__(
threshold=threshold,
trailing=trailing,
trailing_percent=trailing_percent
)
self.threshold = threshold
self.trailing = trailing
self.trailing_percent = trailing_percent
self._highest_price = {} # 跟踪最高价
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查是否触发止损"""
if position is None:
return True
# 更新最高价(跟踪止损)
if self.trailing:
if position.code not in self._highest_price:
self._highest_price[position.code] = position.entry_price
self._highest_price[position.code] = max(
self._highest_price[position.code],
position.current_price
)
# 检查止损
if self.trailing:
# 跟踪止损:从最高价回撤超过阈值
highest = self._highest_price[position.code]
drawdown = (position.current_price - highest) / highest
return drawdown > -self.trailing_percent
else:
# 固定止损:从入场价亏损超过阈值
return position.profit_ratio > self.threshold
def apply(self, position: Position) -> Optional[float]:
"""返回止损价格"""
if self.trailing:
highest = self._highest_price.get(position.code, position.entry_price)
return highest * (1 - self.trailing_percent)
else:
return position.entry_price * (1 + self.threshold)
class PositionLimitControl(RiskControl):
"""
仓位限制控制
参数:
- max_position: 单品种最大仓位默认0.33
- max_total: 总仓位上限默认1.0
"""
name = "position_limit"
def __init__(
self,
max_position: float = 0.33,
max_total: float = 1.0
):
super().__init__(
max_position=max_position,
max_total=max_total
)
self.max_position = max_position
self.max_total = max_total
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查仓位是否超限"""
if position is None:
return True
# 检查单品种仓位
if position.weight > self.max_position:
return False
return True
def apply(self, position: Position) -> Optional[float]:
"""返回建议仓位"""
return min(position.weight, self.max_position)
class PremiumControl(RiskControl):
"""
溢价控制
参数:
- threshold: 溢价阈值默认0.10
- mode: 控制模式('filter''penalize'
"""
name = "premium"
def __init__(
self,
threshold: float = 0.10,
mode: str = 'filter'
):
super().__init__(
threshold=threshold,
mode=mode
)
self.threshold = threshold
self.mode = mode
def check(self, position: Optional[Position], **kwargs) -> bool:
"""检查溢价是否超限"""
premium = kwargs.get('premium', 0)
if self.mode == 'filter':
# 完全排除
return premium <= self.threshold
else:
# 仅降权,允许通过
return True
def apply(self, position: Position) -> Optional[float]:
"""返回溢价惩罚系数"""
if self.mode == 'penalize':
return 0.5 # 降权50%
return None
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class CallbackHook:
"""
回调钩子管理
回调钩子管理(通用)
支持策略生命周期回调:
- before_entry: 入场前检查
- after_entry: 入场后处理
- before_exit: 出场前检查
- after_exit: 出场后处理
- dynamic_stoploss: 动态止损
- custom_exit: 自定义出场
支持策略生命周期的关键节点注入自定义逻辑
"""
SUPPORTED_HOOKS = [
'before_entry', # 入场前检查
'after_entry', # 入场后处理
'before_exit', # 出场前检查
'after_exit', # 出场后处理
'dynamic_stoploss', # 动态止损计算
'custom_exit' # 自定义出场条件
]
def __init__(self):
self._hooks = {
'before_entry': [],
'after_entry': [],
'before_exit': [],
'after_exit': [],
'dynamic_stoploss': [],
'custom_exit': []
"""初始化回调钩子"""
self._hooks: Dict[str, List[Callable]] = {
hook: [] for hook in self.SUPPORTED_HOOKS
}
def register(self, hook_name: str, callback: callable) -> None:
"""注册回调"""
def register(self, hook_name: str, callback: Callable) -> None:
"""注册回调函数"""
if hook_name not in self._hooks:
raise ValueError(f"Unknown hook: {hook_name}")
raise ValueError(f"Unsupported hook: {hook_name}")
self._hooks[hook_name].append(callback)
def trigger(self, hook_name: str, *args, **kwargs) -> Any:
"""触发回调"""
"""
触发回调
Args:
hook_name: 钩子名称
args: 位置参数
kwargs: 关键字参数
Returns:
回调结果(根据钩子类型返回不同结果)
"""
if hook_name not in self._hooks:
raise ValueError(f"Unknown hook: {hook_name}")
raise ValueError(f"Unsupported hook: {hook_name}")
callbacks = self._hooks[hook_name]
if not callbacks:
# 默认行为
if hook_name == 'dynamic_stoploss':
return kwargs.get('default_stoploss', -0.05)
elif hook_name in ['before_entry', 'before_exit']:
return True
elif hook_name == 'custom_exit':
return False
return None
results = []
for callback in self._hooks[hook_name]:
try:
result = callback(*args, **kwargs)
results.append(result)
except Exception as e:
print(f"⚠ Callback error: {e}")
for callback in callbacks:
result = callback(*args, **kwargs)
results.append(result)
# before_entry和before_exit需要所有回调返回True
# 根据钩子类型返回不同结果
if hook_name in ['before_entry', 'before_exit']:
# 所有回调返回True才允许
return all(results)
# dynamic_stoploss返回最小的止损值
if hook_name == 'dynamic_stoploss':
return min(results) if results else -0.05
# 返回最小止损值(最严格)
return min(results)
# custom_exit返回是否有任一回调触发出场
if hook_name == 'custom_exit':
# 任一回调触发出场
return any(results)
return results
# 其他钩子返回最后一个结果
return results[-1] if results else None
def clear(self, hook_name: str = None) -> None:
def clear(self, hook_name: Optional[str] = None) -> None:
"""清空回调"""
if hook_name:
self._hooks[hook_name] = []
if hook_name in self._hooks:
self._hooks[hook_name] = []
else:
for key in self._hooks:
self._hooks[key] = []
for hook in self._hooks:
self._hooks[hook] = []
def list_hooks(self) -> List[str]:
"""列出支持的钩子"""
return self.SUPPORTED_HOOKS
def get_callbacks(self, hook_name: str) -> List[Callable]:
"""获取钩子的所有回调"""
return self._hooks.get(hook_name, [])
# 便捷回调函数
def premium_filter_callback(threshold: float = 0.10):
"""溢价过滤回调"""
def callback(code: str, price: float, **kwargs) -> bool:
premium = kwargs.get('premium', 0)
if premium > threshold:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
return True
return callback
def crash_filter_callback(lookback: int = 3, crash_threshold: float = 0.05):
"""崩盘过滤回调"""
def callback(code: str, price: float, **kwargs) -> bool:
history = kwargs.get('history', None)
if history is None:
return True
# 检查最近N天是否有崩盘
recent = history.tail(lookback)
if len(recent) < lookback:
return True
returns = recent['close'].pct_change()
min_return = returns.min()
if min_return < -crash_threshold:
print(f"崩盘检测,拒绝入场: {code} (最大跌幅={min_return:.2%})")
return False
return True
return callback
def holding_time_stoploss_callback(
day_5_stoploss: float = -0.05,
day_10_stoploss: float = -0.03
):
"""持仓时间动态止损回调"""
def callback(position: Position) -> float:
if position.holding_days >= 10:
return day_10_stoploss # 10天后收紧止损
elif position.holding_days >= 5:
return day_5_stoploss
return -0.10 # 默认止损
return callback
# 导出抽象接口
__all__ = ['Position', 'RiskControl', 'CallbackHook']