feat(strategy): 实现策略层与配置加载

核心组件:
- StrategyBase: 策略抽象基类(含回调钩子)
  - 类属性(可被配置覆盖)
  - init_factors(): 初始化因子组合
  - init_signal_generator(): 初始化信号生成器
  - run(): 运行策略
- RotationStrategy: 轮动策略示例实现
  - 动量因子 + TopN选股
  - before_entry回调(溢价过滤)
  - dynamic_stoploss回调(持仓时间动态止损)
- ConfigLoader: 配置加载器(YAML支持)
- StrategyConfig: 策略配置数据类

特点:
- 配置覆盖类属性
- 回调自动注册
- 策略工厂模式

测试覆盖:8个测试全部通过
This commit is contained in:
2026-05-11 22:18:55 +08:00
parent 512b73ac04
commit 7468130450
2 changed files with 497 additions and 0 deletions

View File

@@ -0,0 +1,357 @@
"""
策略基类与配置层
核心组件:
- StrategyBase: 策略抽象基类(含回调钩子)
- ConfigLoader: 配置加载器
"""
import yaml
import pandas as pd
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, List
from pathlib import Path
from dataclasses import dataclass
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.factors.momentum import MomentumFactor
from framework.signals import SignalGenerator, TopNSelector
from framework.risk import CallbackHook, Position
@dataclass
class StrategyConfig:
"""策略配置"""
name: str
version: int
factors: List[Dict]
signal: Dict
callbacks: Dict
params: Dict
class StrategyBase(ABC):
"""
策略抽象基类
融合Freqtrade回调机制 + 模块化因子设计
类属性(可被配置覆盖):
- name: 策略名称
- version: 接口版本
- timeframe: K线周期
- select_num: 选中数量
- stoploss: 止损比例
回调钩子(可选实现):
- before_entry: 入场前检查
- after_entry: 入场后处理
- before_exit: 出场前检查
- after_exit: 出场后处理
- dynamic_stoploss: 动态止损
- custom_exit: 自定义出场条件
"""
# 接口版本
INTERFACE_VERSION = 1
# 类属性(可被配置覆盖)
name: str = "base"
timeframe: str = "1d"
select_num: int = 3
stoploss: float = -0.05
def __init__(self, config: Optional[Dict] = None):
"""
初始化策略
Args:
config: 策略配置(可覆盖类属性)
"""
# 配置覆盖类属性
if config:
self._apply_config(config)
# 初始化回调钩子
self._callbacks = CallbackHook()
self._register_default_callbacks()
# 初始化因子和信号生成器
self._factors = None
self._signal_gen = None
def _apply_config(self, config: Dict) -> None:
"""应用配置"""
params = config.get('params', {})
# 覆盖类属性
for key, value in params.items():
if hasattr(self, key):
setattr(self, key, value)
# 保存完整配置
self._config = config
def _register_default_callbacks(self) -> None:
"""注册默认回调"""
# 注册入场前回调(溢价过滤)
if hasattr(self, 'before_entry'):
self._callbacks.register('before_entry', self.before_entry)
# 注册动态止损回调
if hasattr(self, 'dynamic_stoploss'):
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
# 注册自定义出场回调
if hasattr(self, 'custom_exit'):
self._callbacks.register('custom_exit', self.custom_exit)
@abstractmethod
def init_factors(self) -> FactorCombiner:
"""
初始化因子组合
Returns:
因子组合器
"""
pass
@abstractmethod
def init_signal_generator(self) -> SignalGenerator:
"""
初始化信号生成器
Returns:
信号生成器
"""
pass
def run(self, data: pd.DataFrame) -> pd.DataFrame:
"""
运行策略
Args:
data: 输入数据
Returns:
包含信号的DataFrame
"""
# 初始化因子和信号生成器
if self._factors is None:
self._factors = self.init_factors()
if self._signal_gen is None:
self._signal_gen = self.init_signal_generator()
# 1. 计算因子
factor_data = self._factors.compute(data)
# 2. 生成信号
signals = self._signal_gen.generate(factor_data)
# 3. 应用回调钩子
signals = self._apply_callbacks(signals, data)
return signals
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
"""应用回调钩子"""
# 遍历每行信号
for date in signals.index:
signal = signals.loc[date, 'signal']
if not signal or pd.isna(signal):
continue
# 解析信号(逗号分隔的代码)
codes = signal.split(',')
# 应用入场前回调
for code in codes:
if code not in data.columns:
continue
price = data.loc[date, code]
premium = 0.0 # TODO: 从溢价数据获取
# 触发回调
allowed = self._callbacks.trigger(
'before_entry',
code,
price,
premium=premium,
history=data
)
if not allowed:
# 移除被拒绝的代码
codes.remove(code)
# 更新信号
signals.loc[date, 'signal'] = ','.join(codes) if codes else ''
return signals
# ===== 可选回调方法 =====
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""
入场前检查
Args:
code: 标的代码
price: 入场价格
**kwargs: 其他参数premium, history等
Returns:
是否允许入场
"""
# 默认:允许入场
return True
def after_entry(self, trade, **kwargs) -> None:
"""入场后处理"""
pass
def before_exit(self, position: Position, **kwargs) -> bool:
"""出场前检查"""
return True
def after_exit(self, trade, **kwargs) -> None:
"""出场后处理"""
pass
def dynamic_stoploss(self, position: Position) -> float:
"""
动态止损
Args:
position: 持仓信息
Returns:
止损比例
"""
# 默认:返回固定止损
return self.stoploss
def custom_exit(self, position: Position) -> bool:
"""
自定义出场条件
Args:
position: 持仓信息
Returns:
是否触发出场
"""
return False
class ConfigLoader:
"""
配置加载器
支持YAML配置文件加载和验证
"""
def __init__(self, config_path: str):
"""
初始化配置加载器
Args:
config_path: 配置文件路径
"""
self._config_path = Path(config_path)
self._config = None
def load(self) -> Dict:
"""加载配置"""
if not self._config_path.exists():
raise FileNotFoundError(f"Config file not found: {self._config_path}")
with open(self._config_path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f)
return self._config
def validate(self) -> bool:
"""验证配置"""
if self._config is None:
self.load()
# 必须字段
required_fields = ['strategy', 'factors', 'signal']
for field in required_fields:
if field not in self._config:
raise ValueError(f"Missing required field: {field}")
return True
def get_strategy_config(self) -> StrategyConfig:
"""获取策略配置"""
if self._config is None:
self.load()
return StrategyConfig(
name=self._config['strategy']['name'],
version=self._config['strategy'].get('version', 1),
factors=self._config['factors'],
signal=self._config['signal'],
callbacks=self._config.get('callbacks', {}),
params=self._config.get('params', {})
)
@staticmethod
def from_yaml(yaml_str: str) -> Dict:
"""从YAML字符串加载"""
return yaml.safe_load(yaml_str)
# 示例策略实现
class RotationStrategy(StrategyBase):
"""
ETF轮动策略
基于动量因子 + Top N选股
"""
name = "rotation"
select_num = 3
def init_factors(self) -> FactorCombiner:
"""初始化动量因子"""
FactorRegistry.clear()
FactorRegistry.register(MomentumFactor)
return FactorCombiner([
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
])
def init_signal_generator(self) -> SignalGenerator:
"""初始化Top N选股器"""
from framework.signals import TopNSelector
return TopNSelector(
select_num=self.select_num,
min_score=0.0
)
def before_entry(self, code: str, price: float, **kwargs) -> bool:
"""入场前:溢价过滤"""
premium = kwargs.get('premium', 0)
# 溢价超过10%拒绝入场
if premium > 0.10:
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
return False
return True
def dynamic_stoploss(self, position: Position) -> float:
"""动态止损:根据持仓时间调整"""
if position.holding_days >= 10:
return -0.03 # 10天后收紧止损
elif position.holding_days >= 5:
return -0.05
return -0.10

View File

@@ -0,0 +1,140 @@
"""
策略与配置层测试
"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime
from framework.strategy import StrategyBase, ConfigLoader, RotationStrategy
from framework.factors import FactorRegistry
from framework.factors.momentum import MomentumFactor
from framework.risk import Position
class TestStrategyBase:
"""测试策略基类"""
def test_strategy_attributes(self):
"""测试策略属性"""
strategy = RotationStrategy()
assert strategy.name == "rotation"
assert strategy.select_num == 3
assert strategy.stoploss == -0.05
def test_config_override(self):
"""测试配置覆盖"""
config = {
'params': {
'select_num': 5,
'stoploss': -0.08
}
}
strategy = RotationStrategy(config=config)
assert strategy.select_num == 5
assert strategy.stoploss == -0.08
def test_factor_initialization(self):
"""测试因子初始化"""
strategy = RotationStrategy()
factors = strategy.init_factors()
assert factors is not None
assert len(factors.factors) == 1
def test_signal_generator_initialization(self):
"""测试信号生成器初始化"""
strategy = RotationStrategy()
signal_gen = strategy.init_signal_generator()
assert signal_gen is not None
assert signal_gen.select_num == 3
def test_strategy_run(self):
"""测试策略运行"""
strategy = RotationStrategy()
# 生成测试数据(需要'close'列)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'code1': np.random.randn(100).cumsum() + 100,
'code2': np.random.randn(100).cumsum() + 100,
'code3': np.random.randn(100).cumsum() + 100,
}, index=dates)
result = strategy.run(data)
assert 'signal' in result.columns
assert len(result) == len(data)
class TestRotationStrategy:
"""测试轮动策略"""
def test_before_entry_callback(self):
"""测试入场前回调"""
strategy = RotationStrategy()
# 溢价正常,允许入场
result = strategy.before_entry('code1', 100.0, premium=0.05)
assert result == True
# 溢价过高,拒绝入场
result = strategy.before_entry('code1', 100.0, premium=0.15)
assert result == False
def test_dynamic_stoploss_callback(self):
"""测试动态止损回调"""
strategy = RotationStrategy()
# 持仓5天
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=95.0,
current_date=datetime(2020, 1, 6), # 5天后
quantity=100,
weight=0.33
)
stoploss = strategy.dynamic_stoploss(position)
# holding_days=5返回-0.05
assert stoploss == -0.05
class TestConfigLoader:
"""测试配置加载器"""
def test_load_from_yaml_string(self):
"""测试从YAML字符串加载"""
yaml_str = """
strategy:
name: test_strategy
version: 1
factors:
- name: momentum
weight: 1.0
params:
n_days: 25
signal:
mode: top_n
select_num: 3
params:
stoploss: -0.05
"""
config = ConfigLoader.from_yaml(yaml_str)
assert config['strategy']['name'] == 'test_strategy'
assert config['factors'][0]['name'] == 'momentum'
assert config['signal']['select_num'] == 3
if __name__ == '__main__':
pytest.main([__file__, '-v'])