核心组件: - StrategyBase: 策略抽象基类(含回调钩子) - 类属性(可被配置覆盖) - init_factors(): 初始化因子组合 - init_signal_generator(): 初始化信号生成器 - run(): 运行策略 - RotationStrategy: 轮动策略示例实现 - 动量因子 + TopN选股 - before_entry回调(溢价过滤) - dynamic_stoploss回调(持仓时间动态止损) - ConfigLoader: 配置加载器(YAML支持) - StrategyConfig: 策略配置数据类 特点: - 配置覆盖类属性 - 回调自动注册 - 策略工厂模式 测试覆盖:8个测试全部通过
357 lines
9.6 KiB
Python
357 lines
9.6 KiB
Python
"""
|
||
策略基类与配置层
|
||
|
||
核心组件:
|
||
- StrategyBase: 策略抽象基类(含回调钩子)
|
||
- ConfigLoader: 配置加载器
|
||
"""
|
||
|
||
import yaml
|
||
import pandas as pd
|
||
from abc import ABC, abstractmethod
|
||
from typing import Dict, Any, Optional, List
|
||
from pathlib import Path
|
||
from dataclasses import dataclass
|
||
|
||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||
from framework.factors.momentum import MomentumFactor
|
||
from framework.signals import SignalGenerator, TopNSelector
|
||
from framework.risk import CallbackHook, Position
|
||
|
||
|
||
@dataclass
|
||
class StrategyConfig:
|
||
"""策略配置"""
|
||
name: str
|
||
version: int
|
||
factors: List[Dict]
|
||
signal: Dict
|
||
callbacks: Dict
|
||
params: Dict
|
||
|
||
|
||
class StrategyBase(ABC):
|
||
"""
|
||
策略抽象基类
|
||
|
||
融合Freqtrade回调机制 + 模块化因子设计
|
||
|
||
类属性(可被配置覆盖):
|
||
- name: 策略名称
|
||
- version: 接口版本
|
||
- timeframe: K线周期
|
||
- select_num: 选中数量
|
||
- stoploss: 止损比例
|
||
|
||
回调钩子(可选实现):
|
||
- before_entry: 入场前检查
|
||
- after_entry: 入场后处理
|
||
- before_exit: 出场前检查
|
||
- after_exit: 出场后处理
|
||
- dynamic_stoploss: 动态止损
|
||
- custom_exit: 自定义出场条件
|
||
"""
|
||
|
||
# 接口版本
|
||
INTERFACE_VERSION = 1
|
||
|
||
# 类属性(可被配置覆盖)
|
||
name: str = "base"
|
||
timeframe: str = "1d"
|
||
select_num: int = 3
|
||
stoploss: float = -0.05
|
||
|
||
def __init__(self, config: Optional[Dict] = None):
|
||
"""
|
||
初始化策略
|
||
|
||
Args:
|
||
config: 策略配置(可覆盖类属性)
|
||
"""
|
||
# 配置覆盖类属性
|
||
if config:
|
||
self._apply_config(config)
|
||
|
||
# 初始化回调钩子
|
||
self._callbacks = CallbackHook()
|
||
self._register_default_callbacks()
|
||
|
||
# 初始化因子和信号生成器
|
||
self._factors = None
|
||
self._signal_gen = None
|
||
|
||
def _apply_config(self, config: Dict) -> None:
|
||
"""应用配置"""
|
||
params = config.get('params', {})
|
||
|
||
# 覆盖类属性
|
||
for key, value in params.items():
|
||
if hasattr(self, key):
|
||
setattr(self, key, value)
|
||
|
||
# 保存完整配置
|
||
self._config = config
|
||
|
||
def _register_default_callbacks(self) -> None:
|
||
"""注册默认回调"""
|
||
# 注册入场前回调(溢价过滤)
|
||
if hasattr(self, 'before_entry'):
|
||
self._callbacks.register('before_entry', self.before_entry)
|
||
|
||
# 注册动态止损回调
|
||
if hasattr(self, 'dynamic_stoploss'):
|
||
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
|
||
|
||
# 注册自定义出场回调
|
||
if hasattr(self, 'custom_exit'):
|
||
self._callbacks.register('custom_exit', self.custom_exit)
|
||
|
||
@abstractmethod
|
||
def init_factors(self) -> FactorCombiner:
|
||
"""
|
||
初始化因子组合
|
||
|
||
Returns:
|
||
因子组合器
|
||
"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
def init_signal_generator(self) -> SignalGenerator:
|
||
"""
|
||
初始化信号生成器
|
||
|
||
Returns:
|
||
信号生成器
|
||
"""
|
||
pass
|
||
|
||
def run(self, data: pd.DataFrame) -> pd.DataFrame:
|
||
"""
|
||
运行策略
|
||
|
||
Args:
|
||
data: 输入数据
|
||
|
||
Returns:
|
||
包含信号的DataFrame
|
||
"""
|
||
# 初始化因子和信号生成器
|
||
if self._factors is None:
|
||
self._factors = self.init_factors()
|
||
|
||
if self._signal_gen is None:
|
||
self._signal_gen = self.init_signal_generator()
|
||
|
||
# 1. 计算因子
|
||
factor_data = self._factors.compute(data)
|
||
|
||
# 2. 生成信号
|
||
signals = self._signal_gen.generate(factor_data)
|
||
|
||
# 3. 应用回调钩子
|
||
signals = self._apply_callbacks(signals, data)
|
||
|
||
return signals
|
||
|
||
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
|
||
"""应用回调钩子"""
|
||
# 遍历每行信号
|
||
for date in signals.index:
|
||
signal = signals.loc[date, 'signal']
|
||
|
||
if not signal or pd.isna(signal):
|
||
continue
|
||
|
||
# 解析信号(逗号分隔的代码)
|
||
codes = signal.split(',')
|
||
|
||
# 应用入场前回调
|
||
for code in codes:
|
||
if code not in data.columns:
|
||
continue
|
||
|
||
price = data.loc[date, code]
|
||
premium = 0.0 # TODO: 从溢价数据获取
|
||
|
||
# 触发回调
|
||
allowed = self._callbacks.trigger(
|
||
'before_entry',
|
||
code,
|
||
price,
|
||
premium=premium,
|
||
history=data
|
||
)
|
||
|
||
if not allowed:
|
||
# 移除被拒绝的代码
|
||
codes.remove(code)
|
||
|
||
# 更新信号
|
||
signals.loc[date, 'signal'] = ','.join(codes) if codes else ''
|
||
|
||
return signals
|
||
|
||
# ===== 可选回调方法 =====
|
||
|
||
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||
"""
|
||
入场前检查
|
||
|
||
Args:
|
||
code: 标的代码
|
||
price: 入场价格
|
||
**kwargs: 其他参数(premium, history等)
|
||
|
||
Returns:
|
||
是否允许入场
|
||
"""
|
||
# 默认:允许入场
|
||
return True
|
||
|
||
def after_entry(self, trade, **kwargs) -> None:
|
||
"""入场后处理"""
|
||
pass
|
||
|
||
def before_exit(self, position: Position, **kwargs) -> bool:
|
||
"""出场前检查"""
|
||
return True
|
||
|
||
def after_exit(self, trade, **kwargs) -> None:
|
||
"""出场后处理"""
|
||
pass
|
||
|
||
def dynamic_stoploss(self, position: Position) -> float:
|
||
"""
|
||
动态止损
|
||
|
||
Args:
|
||
position: 持仓信息
|
||
|
||
Returns:
|
||
止损比例
|
||
"""
|
||
# 默认:返回固定止损
|
||
return self.stoploss
|
||
|
||
def custom_exit(self, position: Position) -> bool:
|
||
"""
|
||
自定义出场条件
|
||
|
||
Args:
|
||
position: 持仓信息
|
||
|
||
Returns:
|
||
是否触发出场
|
||
"""
|
||
return False
|
||
|
||
|
||
class ConfigLoader:
|
||
"""
|
||
配置加载器
|
||
|
||
支持YAML配置文件加载和验证
|
||
"""
|
||
|
||
def __init__(self, config_path: str):
|
||
"""
|
||
初始化配置加载器
|
||
|
||
Args:
|
||
config_path: 配置文件路径
|
||
"""
|
||
self._config_path = Path(config_path)
|
||
self._config = None
|
||
|
||
def load(self) -> Dict:
|
||
"""加载配置"""
|
||
if not self._config_path.exists():
|
||
raise FileNotFoundError(f"Config file not found: {self._config_path}")
|
||
|
||
with open(self._config_path, 'r', encoding='utf-8') as f:
|
||
self._config = yaml.safe_load(f)
|
||
|
||
return self._config
|
||
|
||
def validate(self) -> bool:
|
||
"""验证配置"""
|
||
if self._config is None:
|
||
self.load()
|
||
|
||
# 必须字段
|
||
required_fields = ['strategy', 'factors', 'signal']
|
||
|
||
for field in required_fields:
|
||
if field not in self._config:
|
||
raise ValueError(f"Missing required field: {field}")
|
||
|
||
return True
|
||
|
||
def get_strategy_config(self) -> StrategyConfig:
|
||
"""获取策略配置"""
|
||
if self._config is None:
|
||
self.load()
|
||
|
||
return StrategyConfig(
|
||
name=self._config['strategy']['name'],
|
||
version=self._config['strategy'].get('version', 1),
|
||
factors=self._config['factors'],
|
||
signal=self._config['signal'],
|
||
callbacks=self._config.get('callbacks', {}),
|
||
params=self._config.get('params', {})
|
||
)
|
||
|
||
@staticmethod
|
||
def from_yaml(yaml_str: str) -> Dict:
|
||
"""从YAML字符串加载"""
|
||
return yaml.safe_load(yaml_str)
|
||
|
||
|
||
# 示例策略实现
|
||
class RotationStrategy(StrategyBase):
|
||
"""
|
||
ETF轮动策略
|
||
|
||
基于动量因子 + Top N选股
|
||
"""
|
||
|
||
name = "rotation"
|
||
select_num = 3
|
||
|
||
def init_factors(self) -> FactorCombiner:
|
||
"""初始化动量因子"""
|
||
FactorRegistry.clear()
|
||
FactorRegistry.register(MomentumFactor)
|
||
|
||
return FactorCombiner([
|
||
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
|
||
])
|
||
|
||
def init_signal_generator(self) -> SignalGenerator:
|
||
"""初始化Top N选股器"""
|
||
from framework.signals import TopNSelector
|
||
|
||
return TopNSelector(
|
||
select_num=self.select_num,
|
||
min_score=0.0
|
||
)
|
||
|
||
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||
"""入场前:溢价过滤"""
|
||
premium = kwargs.get('premium', 0)
|
||
|
||
# 溢价超过10%拒绝入场
|
||
if premium > 0.10:
|
||
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||
return False
|
||
|
||
return True
|
||
|
||
def dynamic_stoploss(self, position: Position) -> float:
|
||
"""动态止损:根据持仓时间调整"""
|
||
if position.holding_days >= 10:
|
||
return -0.03 # 10天后收紧止损
|
||
elif position.holding_days >= 5:
|
||
return -0.05
|
||
return -0.10 |