Files
etf/framework/strategy/__init__.py
aszerW 7468130450 feat(strategy): 实现策略层与配置加载
核心组件:
- StrategyBase: 策略抽象基类(含回调钩子)
  - 类属性(可被配置覆盖)
  - init_factors(): 初始化因子组合
  - init_signal_generator(): 初始化信号生成器
  - run(): 运行策略
- RotationStrategy: 轮动策略示例实现
  - 动量因子 + TopN选股
  - before_entry回调(溢价过滤)
  - dynamic_stoploss回调(持仓时间动态止损)
- ConfigLoader: 配置加载器(YAML支持)
- StrategyConfig: 策略配置数据类

特点:
- 配置覆盖类属性
- 回调自动注册
- 策略工厂模式

测试覆盖:8个测试全部通过
2026-05-11 22:18:55 +08:00

357 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略基类与配置层
核心组件:
- StrategyBase: 策略抽象基类(含回调钩子)
- ConfigLoader: 配置加载器
"""
import yaml
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from pathlib import Path
from dataclasses import dataclass
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.factors.momentum import MomentumFactor
from framework.signals import SignalGenerator, TopNSelector
from framework.risk import CallbackHook, Position
@dataclass
class StrategyConfig:
"""策略配置"""
name: str
version: int
factors: List[Dict]
signal: Dict
callbacks: Dict
params: Dict
class StrategyBase(ABC):
"""
策略抽象基类
融合Freqtrade回调机制 + 模块化因子设计
类属性(可被配置覆盖):
- name: 策略名称
- version: 接口版本
- timeframe: K线周期
- select_num: 选中数量
- stoploss: 止损比例
回调钩子(可选实现):
- before_entry: 入场前检查
- after_entry: 入场后处理
- before_exit: 出场前检查
- after_exit: 出场后处理
- dynamic_stoploss: 动态止损
- custom_exit: 自定义出场条件
"""
# 接口版本
INTERFACE_VERSION = 1
# 类属性(可被配置覆盖)
name: str = "base"
timeframe: str = "1d"
select_num: int = 3
stoploss: float = -0.05
def __init__(self, config: Optional[Dict] = None):
"""
初始化策略
Args:
config: 策略配置(可覆盖类属性)
"""
# 配置覆盖类属性
if config:
self._apply_config(config)
# 初始化回调钩子
self._callbacks = CallbackHook()
self._register_default_callbacks()
# 初始化因子和信号生成器
self._factors = None
self._signal_gen = None
def _apply_config(self, config: Dict) -> None:
"""应用配置"""
params = config.get('params', {})
# 覆盖类属性
for key, value in params.items():
if hasattr(self, key):
setattr(self, key, value)
# 保存完整配置
self._config = config
def _register_default_callbacks(self) -> None:
"""注册默认回调"""
# 注册入场前回调(溢价过滤)
if hasattr(self, 'before_entry'):
self._callbacks.register('before_entry', self.before_entry)
# 注册动态止损回调
if hasattr(self, 'dynamic_stoploss'):
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
# 注册自定义出场回调
if hasattr(self, 'custom_exit'):
self._callbacks.register('custom_exit', self.custom_exit)
@abstractmethod
def init_factors(self) -> FactorCombiner:
"""
初始化因子组合
Returns:
因子组合器
"""
pass
@abstractmethod
def init_signal_generator(self) -> SignalGenerator:
"""
初始化信号生成器
Returns:
信号生成器
"""
pass
def run(self, data: pd.DataFrame) -> pd.DataFrame:
"""
运行策略
Args:
data: 输入数据
Returns:
包含信号的DataFrame
"""
# 初始化因子和信号生成器
if self._factors is None:
self._factors = self.init_factors()
if self._signal_gen is None:
self._signal_gen = self.init_signal_generator()
# 1. 计算因子
factor_data = self._factors.compute(data)
# 2. 生成信号
signals = self._signal_gen.generate(factor_data)
# 3. 应用回调钩子
signals = self._apply_callbacks(signals, data)
return signals
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
"""应用回调钩子"""
# 遍历每行信号
for date in signals.index:
signal = signals.loc[date, 'signal']
if not signal or pd.isna(signal):
continue
# 解析信号(逗号分隔的代码)
codes = signal.split(',')
# 应用入场前回调
for code in codes:
if code not in data.columns:
continue
price = data.loc[date, code]
premium = 0.0 # TODO: 从溢价数据获取
# 触发回调
allowed = self._callbacks.trigger(
'before_entry',
code,
price,
premium=premium,
history=data
)
if not allowed:
# 移除被拒绝的代码
codes.remove(code)
# 更新信号
signals.loc[date, 'signal'] = ','.join(codes) if codes else ''
return signals
# ===== 可选回调方法 =====
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""
入场前检查
Args:
code: 标的代码
price: 入场价格
**kwargs: 其他参数premium, history等
Returns:
是否允许入场
"""
# 默认:允许入场
return True
def after_entry(self, trade, **kwargs) -> None:
"""入场后处理"""
pass
def before_exit(self, position: Position, **kwargs) -> bool:
"""出场前检查"""
return True
def after_exit(self, trade, **kwargs) -> None:
"""出场后处理"""
pass
def dynamic_stoploss(self, position: Position) -> float:
"""
动态止损
Args:
position: 持仓信息
Returns:
止损比例
"""
# 默认:返回固定止损
return self.stoploss
def custom_exit(self, position: Position) -> bool:
"""
自定义出场条件
Args:
position: 持仓信息
Returns:
是否触发出场
"""
return False
class ConfigLoader:
"""
配置加载器
支持YAML配置文件加载和验证
"""
def __init__(self, config_path: str):
"""
初始化配置加载器
Args:
config_path: 配置文件路径
"""
self._config_path = Path(config_path)
self._config = None
def load(self) -> Dict:
"""加载配置"""
if not self._config_path.exists():
raise FileNotFoundError(f"Config file not found: {self._config_path}")
with open(self._config_path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f)
return self._config
def validate(self) -> bool:
"""验证配置"""
if self._config is None:
self.load()
# 必须字段
required_fields = ['strategy', 'factors', 'signal']
for field in required_fields:
if field not in self._config:
raise ValueError(f"Missing required field: {field}")
return True
def get_strategy_config(self) -> StrategyConfig:
"""获取策略配置"""
if self._config is None:
self.load()
return StrategyConfig(
name=self._config['strategy']['name'],
version=self._config['strategy'].get('version', 1),
factors=self._config['factors'],
signal=self._config['signal'],
callbacks=self._config.get('callbacks', {}),
params=self._config.get('params', {})
)
@staticmethod
def from_yaml(yaml_str: str) -> Dict:
"""从YAML字符串加载"""
return yaml.safe_load(yaml_str)
# 示例策略实现
class RotationStrategy(StrategyBase):
"""
ETF轮动策略
基于动量因子 + Top N选股
"""
name = "rotation"
select_num = 3
def init_factors(self) -> FactorCombiner:
"""初始化动量因子"""
FactorRegistry.clear()
FactorRegistry.register(MomentumFactor)
return FactorCombiner([
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
])
def init_signal_generator(self) -> SignalGenerator:
"""初始化Top N选股器"""
from framework.signals import TopNSelector
return TopNSelector(
select_num=self.select_num,
min_score=0.0
)
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""入场前:溢价过滤"""
premium = kwargs.get('premium', 0)
# 溢价超过10%拒绝入场
if premium > 0.10:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
return True
def dynamic_stoploss(self, position: Position) -> float:
"""动态止损:根据持仓时间调整"""
if position.holding_days >= 10:
return -0.03 # 10天后收紧止损
elif position.holding_days >= 5:
return -0.05
return -0.10