refactor(framework): 框架只保留抽象接口,具体实现移至strategies/shared

- FactorBase/FactorRegistry/FactorCombiner: 因子抽象接口
- SignalGenerator: 信号生成抽象接口
- RiskControl/Position/CallbackHook: 风控抽象接口
- StrategyBase: 策略抽象基类
- Executor/Portfolio: 执行器抽象接口
- ConfigLoader: 配置加载器
- 删除framework/factors/momentum.py(具体实现)
This commit is contained in:
2026-05-11 23:09:01 +08:00
parent 9a8a0d7c72
commit 30ea2970bd
8 changed files with 503 additions and 1516 deletions

View File

@@ -1,63 +1,30 @@
"""
策略基类与配置
策略层抽象基类(通用)
核心组件:
- StrategyBase: 策略抽象基类(含回调钩子)
- ConfigLoader: 配置加载器
只提供抽象接口具体策略实现在strategies/
"""
import yaml
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Optional, Any
import pandas as pd
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.factors.momentum import MomentumFactor
from framework.signals import SignalGenerator, TopNSelector
from framework.factors import FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import CallbackHook, Position
@dataclass
class StrategyConfig:
"""策略配置"""
name: str
version: int
factors: List[Dict]
signal: Dict
callbacks: Dict
params: Dict
class StrategyBase(ABC):
"""
策略抽象基类
融合Freqtrade回调机制 + 模块化因子设计
类属性(可被配置覆盖):
- name: 策略名称
- version: 接口版本
- timeframe: K线周期
- select_num: 选中数量
- stoploss: 止损比例
回调钩子(可选实现):
- before_entry: 入场前检查
- after_entry: 入场后处理
- before_exit: 出场前检查
- after_exit: 出场后处理
- dynamic_stoploss: 动态止损
- custom_exit: 自定义出场条件
所有策略必须实现init_factors和init_signal_generator方法
"""
# 接口版本
INTERFACE_VERSION = 1
# 类属性(可被配置覆盖)
name: str = "base"
timeframe: str = "1d"
# 类属性(可被配置覆盖)
select_num: int = 3
stoploss: float = -0.05
@@ -66,53 +33,50 @@ class StrategyBase(ABC):
初始化策略
Args:
config: 策略配置(可覆盖类属性)
config: 配置字典(可选,用于覆盖类属性)
"""
# 配置覆盖类属性
if config:
self._apply_config(config)
# 初始化回调钩子
self._callbacks = CallbackHook()
self._register_default_callbacks()
# 初始化因子和信号生成器
self._factors = None
self._signal_gen = None
self._factors = self.init_factors()
self._signal_gen = self.init_signal_generator()
def _apply_config(self, config: Dict) -> None:
"""应用配置"""
params = config.get('params', {})
# 覆盖类属性
for key, value in params.items():
"""应用配置覆盖类属性"""
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
# 保存完整配置
self._config = config
def _register_default_callbacks(self) -> None:
"""注册默认回调"""
# 注册入场前回调(溢价过滤)
"""注册默认回调方法"""
if hasattr(self, 'before_entry'):
self._callbacks.register('before_entry', self.before_entry)
# 注册动态止损回调
if hasattr(self, 'after_entry'):
self._callbacks.register('after_entry', self.after_entry)
if hasattr(self, 'before_exit'):
self._callbacks.register('before_exit', self.before_exit)
if hasattr(self, 'after_exit'):
self._callbacks.register('after_exit', self.after_exit)
if hasattr(self, 'dynamic_stoploss'):
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
# 注册自定义出场回调
if hasattr(self, 'custom_exit'):
self._callbacks.register('custom_exit', self.custom_exit)
@abstractmethod
def init_factors(self) -> FactorCombiner:
"""
初始化因子组合
初始化因子组合
Returns:
因子组合器
FactorCombiner实例
"""
pass
@@ -122,7 +86,7 @@ class StrategyBase(ABC):
初始化信号生成器
Returns:
信号生成器
SignalGenerator实例
"""
pass
@@ -131,85 +95,28 @@ class StrategyBase(ABC):
运行策略
Args:
data: 输入数据
data: OHLCV数据
Returns:
包含信号的DataFrame
"""
# 初始化因子和信号生成器
if self._factors is None:
self._factors = self.init_factors()
if self._signal_gen is None:
self._signal_gen = self.init_signal_generator()
# 1. 计算因子
factor_data = self._factors.compute(data)
# 2. 生成信号
signals = self._signal_gen.generate(factor_data)
# 3. 应用回调钩子
signals = self._apply_callbacks(signals, data)
return signals
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
"""应用回调钩子"""
# 遍历每行信号
for date in signals.index:
signal = signals.loc[date, 'signal']
if not signal or pd.isna(signal):
continue
# 解析信号(逗号分隔的代码)
codes = signal.split(',')
# 应用入场前回调
for code in codes:
if code not in data.columns:
continue
price = data.loc[date, code]
premium = 0.0 # TODO: 从溢价数据获取
# 触发回调
allowed = self._callbacks.trigger(
'before_entry',
code,
price,
premium=premium,
history=data
)
if not allowed:
# 移除被拒绝的代码
codes.remove(code)
# 更新信号
signals.loc[date, 'signal'] = ','.join(codes) if codes else ''
"""应用回调处理"""
return signals
# ===== 可选回调方法 =====
# 可选回调方法(子类可覆盖)
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""
入场前检查
Args:
code: 标的代码
price: 入场价格
**kwargs: 其他参数premium, history等
Returns:
是否允许入场
"""
# 默认:允许入场
"""入场前检查"""
return True
def after_entry(self, trade, **kwargs) -> None:
def after_entry(self, code: str, price: float, **kwargs) -> None:
"""入场后处理"""
pass
@@ -217,141 +124,21 @@ class StrategyBase(ABC):
"""出场前检查"""
return True
def after_exit(self, trade, **kwargs) -> None:
def after_exit(self, position: Position, **kwargs) -> None:
"""出场后处理"""
pass
def dynamic_stoploss(self, position: Position) -> float:
"""
动态止损
Args:
position: 持仓信息
Returns:
止损比例
"""
# 默认:返回固定止损
"""动态止损"""
return self.stoploss
def custom_exit(self, position: Position) -> bool:
"""
自定义出场条件
Args:
position: 持仓信息
Returns:
是否触发出场
"""
"""自定义出场条件"""
return False
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"
class ConfigLoader:
"""
配置加载器
支持YAML配置文件加载和验证
"""
def __init__(self, config_path: str):
"""
初始化配置加载器
Args:
config_path: 配置文件路径
"""
self._config_path = Path(config_path)
self._config = None
def load(self) -> Dict:
"""加载配置"""
if not self._config_path.exists():
raise FileNotFoundError(f"Config file not found: {self._config_path}")
with open(self._config_path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f)
return self._config
def validate(self) -> bool:
"""验证配置"""
if self._config is None:
self.load()
# 必须字段
required_fields = ['strategy', 'factors', 'signal']
for field in required_fields:
if field not in self._config:
raise ValueError(f"Missing required field: {field}")
return True
def get_strategy_config(self) -> StrategyConfig:
"""获取策略配置"""
if self._config is None:
self.load()
return StrategyConfig(
name=self._config['strategy']['name'],
version=self._config['strategy'].get('version', 1),
factors=self._config['factors'],
signal=self._config['signal'],
callbacks=self._config.get('callbacks', {}),
params=self._config.get('params', {})
)
@staticmethod
def from_yaml(yaml_str: str) -> Dict:
"""从YAML字符串加载"""
return yaml.safe_load(yaml_str)
# 示例策略实现
class RotationStrategy(StrategyBase):
"""
ETF轮动策略
基于动量因子 + Top N选股
"""
name = "rotation"
select_num = 3
def init_factors(self) -> FactorCombiner:
"""初始化动量因子"""
FactorRegistry.clear()
FactorRegistry.register(MomentumFactor)
return FactorCombiner([
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
])
def init_signal_generator(self) -> SignalGenerator:
"""初始化Top N选股器"""
from framework.signals import TopNSelector
return TopNSelector(
select_num=self.select_num,
min_score=0.0
)
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""入场前:溢价过滤"""
premium = kwargs.get('premium', 0)
# 溢价超过10%拒绝入场
if premium > 0.10:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
return True
def dynamic_stoploss(self, position: Position) -> float:
"""动态止损:根据持仓时间调整"""
if position.holding_days >= 10:
return -0.03 # 10天后收紧止损
elif position.holding_days >= 5:
return -0.05
return -0.10
# 导出抽象接口
__all__ = ['StrategyBase']