feat(strategy): 实现策略层与配置加载
核心组件: - StrategyBase: 策略抽象基类(含回调钩子) - 类属性(可被配置覆盖) - init_factors(): 初始化因子组合 - init_signal_generator(): 初始化信号生成器 - run(): 运行策略 - RotationStrategy: 轮动策略示例实现 - 动量因子 + TopN选股 - before_entry回调(溢价过滤) - dynamic_stoploss回调(持仓时间动态止损) - ConfigLoader: 配置加载器(YAML支持) - StrategyConfig: 策略配置数据类 特点: - 配置覆盖类属性 - 回调自动注册 - 策略工厂模式 测试覆盖:8个测试全部通过
This commit is contained in:
357
framework/strategy/__init__.py
Normal file
357
framework/strategy/__init__.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
策略基类与配置层
|
||||
|
||||
核心组件:
|
||||
- StrategyBase: 策略抽象基类(含回调钩子)
|
||||
- ConfigLoader: 配置加载器
|
||||
"""
|
||||
|
||||
import yaml
|
||||
import pandas as pd
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.factors.momentum import MomentumFactor
|
||||
from framework.signals import SignalGenerator, TopNSelector
|
||||
from framework.risk import CallbackHook, Position
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyConfig:
|
||||
"""策略配置"""
|
||||
name: str
|
||||
version: int
|
||||
factors: List[Dict]
|
||||
signal: Dict
|
||||
callbacks: Dict
|
||||
params: Dict
|
||||
|
||||
|
||||
class StrategyBase(ABC):
|
||||
"""
|
||||
策略抽象基类
|
||||
|
||||
融合Freqtrade回调机制 + 模块化因子设计
|
||||
|
||||
类属性(可被配置覆盖):
|
||||
- name: 策略名称
|
||||
- version: 接口版本
|
||||
- timeframe: K线周期
|
||||
- select_num: 选中数量
|
||||
- stoploss: 止损比例
|
||||
|
||||
回调钩子(可选实现):
|
||||
- before_entry: 入场前检查
|
||||
- after_entry: 入场后处理
|
||||
- before_exit: 出场前检查
|
||||
- after_exit: 出场后处理
|
||||
- dynamic_stoploss: 动态止损
|
||||
- custom_exit: 自定义出场条件
|
||||
"""
|
||||
|
||||
# 接口版本
|
||||
INTERFACE_VERSION = 1
|
||||
|
||||
# 类属性(可被配置覆盖)
|
||||
name: str = "base"
|
||||
timeframe: str = "1d"
|
||||
select_num: int = 3
|
||||
stoploss: float = -0.05
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置(可覆盖类属性)
|
||||
"""
|
||||
# 配置覆盖类属性
|
||||
if config:
|
||||
self._apply_config(config)
|
||||
|
||||
# 初始化回调钩子
|
||||
self._callbacks = CallbackHook()
|
||||
self._register_default_callbacks()
|
||||
|
||||
# 初始化因子和信号生成器
|
||||
self._factors = None
|
||||
self._signal_gen = None
|
||||
|
||||
def _apply_config(self, config: Dict) -> None:
|
||||
"""应用配置"""
|
||||
params = config.get('params', {})
|
||||
|
||||
# 覆盖类属性
|
||||
for key, value in params.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
# 保存完整配置
|
||||
self._config = config
|
||||
|
||||
def _register_default_callbacks(self) -> None:
|
||||
"""注册默认回调"""
|
||||
# 注册入场前回调(溢价过滤)
|
||||
if hasattr(self, 'before_entry'):
|
||||
self._callbacks.register('before_entry', self.before_entry)
|
||||
|
||||
# 注册动态止损回调
|
||||
if hasattr(self, 'dynamic_stoploss'):
|
||||
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
|
||||
|
||||
# 注册自定义出场回调
|
||||
if hasattr(self, 'custom_exit'):
|
||||
self._callbacks.register('custom_exit', self.custom_exit)
|
||||
|
||||
@abstractmethod
|
||||
def init_factors(self) -> FactorCombiner:
|
||||
"""
|
||||
初始化因子组合
|
||||
|
||||
Returns:
|
||||
因子组合器
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def init_signal_generator(self) -> SignalGenerator:
|
||||
"""
|
||||
初始化信号生成器
|
||||
|
||||
Returns:
|
||||
信号生成器
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
运行策略
|
||||
|
||||
Args:
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
包含信号的DataFrame
|
||||
"""
|
||||
# 初始化因子和信号生成器
|
||||
if self._factors is None:
|
||||
self._factors = self.init_factors()
|
||||
|
||||
if self._signal_gen is None:
|
||||
self._signal_gen = self.init_signal_generator()
|
||||
|
||||
# 1. 计算因子
|
||||
factor_data = self._factors.compute(data)
|
||||
|
||||
# 2. 生成信号
|
||||
signals = self._signal_gen.generate(factor_data)
|
||||
|
||||
# 3. 应用回调钩子
|
||||
signals = self._apply_callbacks(signals, data)
|
||||
|
||||
return signals
|
||||
|
||||
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""应用回调钩子"""
|
||||
# 遍历每行信号
|
||||
for date in signals.index:
|
||||
signal = signals.loc[date, 'signal']
|
||||
|
||||
if not signal or pd.isna(signal):
|
||||
continue
|
||||
|
||||
# 解析信号(逗号分隔的代码)
|
||||
codes = signal.split(',')
|
||||
|
||||
# 应用入场前回调
|
||||
for code in codes:
|
||||
if code not in data.columns:
|
||||
continue
|
||||
|
||||
price = data.loc[date, code]
|
||||
premium = 0.0 # TODO: 从溢价数据获取
|
||||
|
||||
# 触发回调
|
||||
allowed = self._callbacks.trigger(
|
||||
'before_entry',
|
||||
code,
|
||||
price,
|
||||
premium=premium,
|
||||
history=data
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
# 移除被拒绝的代码
|
||||
codes.remove(code)
|
||||
|
||||
# 更新信号
|
||||
signals.loc[date, 'signal'] = ','.join(codes) if codes else ''
|
||||
|
||||
return signals
|
||||
|
||||
# ===== 可选回调方法 =====
|
||||
|
||||
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||||
"""
|
||||
入场前检查
|
||||
|
||||
Args:
|
||||
code: 标的代码
|
||||
price: 入场价格
|
||||
**kwargs: 其他参数(premium, history等)
|
||||
|
||||
Returns:
|
||||
是否允许入场
|
||||
"""
|
||||
# 默认:允许入场
|
||||
return True
|
||||
|
||||
def after_entry(self, trade, **kwargs) -> None:
|
||||
"""入场后处理"""
|
||||
pass
|
||||
|
||||
def before_exit(self, position: Position, **kwargs) -> bool:
|
||||
"""出场前检查"""
|
||||
return True
|
||||
|
||||
def after_exit(self, trade, **kwargs) -> None:
|
||||
"""出场后处理"""
|
||||
pass
|
||||
|
||||
def dynamic_stoploss(self, position: Position) -> float:
|
||||
"""
|
||||
动态止损
|
||||
|
||||
Args:
|
||||
position: 持仓信息
|
||||
|
||||
Returns:
|
||||
止损比例
|
||||
"""
|
||||
# 默认:返回固定止损
|
||||
return self.stoploss
|
||||
|
||||
def custom_exit(self, position: Position) -> bool:
|
||||
"""
|
||||
自定义出场条件
|
||||
|
||||
Args:
|
||||
position: 持仓信息
|
||||
|
||||
Returns:
|
||||
是否触发出场
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""
|
||||
配置加载器
|
||||
|
||||
支持YAML配置文件加载和验证
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
"""
|
||||
初始化配置加载器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
self._config_path = Path(config_path)
|
||||
self._config = None
|
||||
|
||||
def load(self) -> Dict:
|
||||
"""加载配置"""
|
||||
if not self._config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {self._config_path}")
|
||||
|
||||
with open(self._config_path, 'r', encoding='utf-8') as f:
|
||||
self._config = yaml.safe_load(f)
|
||||
|
||||
return self._config
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置"""
|
||||
if self._config is None:
|
||||
self.load()
|
||||
|
||||
# 必须字段
|
||||
required_fields = ['strategy', 'factors', 'signal']
|
||||
|
||||
for field in required_fields:
|
||||
if field not in self._config:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
|
||||
return True
|
||||
|
||||
def get_strategy_config(self) -> StrategyConfig:
|
||||
"""获取策略配置"""
|
||||
if self._config is None:
|
||||
self.load()
|
||||
|
||||
return StrategyConfig(
|
||||
name=self._config['strategy']['name'],
|
||||
version=self._config['strategy'].get('version', 1),
|
||||
factors=self._config['factors'],
|
||||
signal=self._config['signal'],
|
||||
callbacks=self._config.get('callbacks', {}),
|
||||
params=self._config.get('params', {})
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_yaml(yaml_str: str) -> Dict:
|
||||
"""从YAML字符串加载"""
|
||||
return yaml.safe_load(yaml_str)
|
||||
|
||||
|
||||
# 示例策略实现
|
||||
class RotationStrategy(StrategyBase):
|
||||
"""
|
||||
ETF轮动策略
|
||||
|
||||
基于动量因子 + Top N选股
|
||||
"""
|
||||
|
||||
name = "rotation"
|
||||
select_num = 3
|
||||
|
||||
def init_factors(self) -> FactorCombiner:
|
||||
"""初始化动量因子"""
|
||||
FactorRegistry.clear()
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
|
||||
return FactorCombiner([
|
||||
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
|
||||
])
|
||||
|
||||
def init_signal_generator(self) -> SignalGenerator:
|
||||
"""初始化Top N选股器"""
|
||||
from framework.signals import TopNSelector
|
||||
|
||||
return TopNSelector(
|
||||
select_num=self.select_num,
|
||||
min_score=0.0
|
||||
)
|
||||
|
||||
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||||
"""入场前:溢价过滤"""
|
||||
premium = kwargs.get('premium', 0)
|
||||
|
||||
# 溢价超过10%拒绝入场
|
||||
if premium > 0.10:
|
||||
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def dynamic_stoploss(self, position: Position) -> float:
|
||||
"""动态止损:根据持仓时间调整"""
|
||||
if position.holding_days >= 10:
|
||||
return -0.03 # 10天后收紧止损
|
||||
elif position.holding_days >= 5:
|
||||
return -0.05
|
||||
return -0.10
|
||||
140
framework/tests/test_strategy.py
Normal file
140
framework/tests/test_strategy.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
策略与配置层测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
from framework.strategy import StrategyBase, ConfigLoader, RotationStrategy
|
||||
from framework.factors import FactorRegistry
|
||||
from framework.factors.momentum import MomentumFactor
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
class TestStrategyBase:
|
||||
"""测试策略基类"""
|
||||
|
||||
def test_strategy_attributes(self):
|
||||
"""测试策略属性"""
|
||||
strategy = RotationStrategy()
|
||||
assert strategy.name == "rotation"
|
||||
assert strategy.select_num == 3
|
||||
assert strategy.stoploss == -0.05
|
||||
|
||||
def test_config_override(self):
|
||||
"""测试配置覆盖"""
|
||||
config = {
|
||||
'params': {
|
||||
'select_num': 5,
|
||||
'stoploss': -0.08
|
||||
}
|
||||
}
|
||||
|
||||
strategy = RotationStrategy(config=config)
|
||||
assert strategy.select_num == 5
|
||||
assert strategy.stoploss == -0.08
|
||||
|
||||
def test_factor_initialization(self):
|
||||
"""测试因子初始化"""
|
||||
strategy = RotationStrategy()
|
||||
factors = strategy.init_factors()
|
||||
|
||||
assert factors is not None
|
||||
assert len(factors.factors) == 1
|
||||
|
||||
def test_signal_generator_initialization(self):
|
||||
"""测试信号生成器初始化"""
|
||||
strategy = RotationStrategy()
|
||||
signal_gen = strategy.init_signal_generator()
|
||||
|
||||
assert signal_gen is not None
|
||||
assert signal_gen.select_num == 3
|
||||
|
||||
def test_strategy_run(self):
|
||||
"""测试策略运行"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据(需要'close'列)
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'code1': np.random.randn(100).cumsum() + 100,
|
||||
'code2': np.random.randn(100).cumsum() + 100,
|
||||
'code3': np.random.randn(100).cumsum() + 100,
|
||||
}, index=dates)
|
||||
|
||||
result = strategy.run(data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
assert len(result) == len(data)
|
||||
|
||||
|
||||
class TestRotationStrategy:
|
||||
"""测试轮动策略"""
|
||||
|
||||
def test_before_entry_callback(self):
|
||||
"""测试入场前回调"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 溢价正常,允许入场
|
||||
result = strategy.before_entry('code1', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 溢价过高,拒绝入场
|
||||
result = strategy.before_entry('code1', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
def test_dynamic_stoploss_callback(self):
|
||||
"""测试动态止损回调"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 持仓5天
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 6), # 5天后
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
stoploss = strategy.dynamic_stoploss(position)
|
||||
# holding_days=5,返回-0.05
|
||||
assert stoploss == -0.05
|
||||
|
||||
|
||||
class TestConfigLoader:
|
||||
"""测试配置加载器"""
|
||||
|
||||
def test_load_from_yaml_string(self):
|
||||
"""测试从YAML字符串加载"""
|
||||
yaml_str = """
|
||||
strategy:
|
||||
name: test_strategy
|
||||
version: 1
|
||||
|
||||
factors:
|
||||
- name: momentum
|
||||
weight: 1.0
|
||||
params:
|
||||
n_days: 25
|
||||
|
||||
signal:
|
||||
mode: top_n
|
||||
select_num: 3
|
||||
|
||||
params:
|
||||
stoploss: -0.05
|
||||
"""
|
||||
|
||||
config = ConfigLoader.from_yaml(yaml_str)
|
||||
|
||||
assert config['strategy']['name'] == 'test_strategy'
|
||||
assert config['factors'][0]['name'] == 'momentum'
|
||||
assert config['signal']['select_num'] == 3
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user