feat(strategy): 实现策略层与配置加载
核心组件: - StrategyBase: 策略抽象基类(含回调钩子) - 类属性(可被配置覆盖) - init_factors(): 初始化因子组合 - init_signal_generator(): 初始化信号生成器 - run(): 运行策略 - RotationStrategy: 轮动策略示例实现 - 动量因子 + TopN选股 - before_entry回调(溢价过滤) - dynamic_stoploss回调(持仓时间动态止损) - ConfigLoader: 配置加载器(YAML支持) - StrategyConfig: 策略配置数据类 特点: - 配置覆盖类属性 - 回调自动注册 - 策略工厂模式 测试覆盖:8个测试全部通过
This commit is contained in:
357
framework/strategy/__init__.py
Normal file
357
framework/strategy/__init__.py
Normal file
@@ -0,0 +1,357 @@
|
|||||||
|
"""
|
||||||
|
策略基类与配置层
|
||||||
|
|
||||||
|
核心组件:
|
||||||
|
- StrategyBase: 策略抽象基类(含回调钩子)
|
||||||
|
- ConfigLoader: 配置加载器
|
||||||
|
"""
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
import pandas as pd
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from pathlib import Path
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||||
|
from framework.factors.momentum import MomentumFactor
|
||||||
|
from framework.signals import SignalGenerator, TopNSelector
|
||||||
|
from framework.risk import CallbackHook, Position
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StrategyConfig:
|
||||||
|
"""策略配置"""
|
||||||
|
name: str
|
||||||
|
version: int
|
||||||
|
factors: List[Dict]
|
||||||
|
signal: Dict
|
||||||
|
callbacks: Dict
|
||||||
|
params: Dict
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyBase(ABC):
|
||||||
|
"""
|
||||||
|
策略抽象基类
|
||||||
|
|
||||||
|
融合Freqtrade回调机制 + 模块化因子设计
|
||||||
|
|
||||||
|
类属性(可被配置覆盖):
|
||||||
|
- name: 策略名称
|
||||||
|
- version: 接口版本
|
||||||
|
- timeframe: K线周期
|
||||||
|
- select_num: 选中数量
|
||||||
|
- stoploss: 止损比例
|
||||||
|
|
||||||
|
回调钩子(可选实现):
|
||||||
|
- before_entry: 入场前检查
|
||||||
|
- after_entry: 入场后处理
|
||||||
|
- before_exit: 出场前检查
|
||||||
|
- after_exit: 出场后处理
|
||||||
|
- dynamic_stoploss: 动态止损
|
||||||
|
- custom_exit: 自定义出场条件
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 接口版本
|
||||||
|
INTERFACE_VERSION = 1
|
||||||
|
|
||||||
|
# 类属性(可被配置覆盖)
|
||||||
|
name: str = "base"
|
||||||
|
timeframe: str = "1d"
|
||||||
|
select_num: int = 3
|
||||||
|
stoploss: float = -0.05
|
||||||
|
|
||||||
|
def __init__(self, config: Optional[Dict] = None):
|
||||||
|
"""
|
||||||
|
初始化策略
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: 策略配置(可覆盖类属性)
|
||||||
|
"""
|
||||||
|
# 配置覆盖类属性
|
||||||
|
if config:
|
||||||
|
self._apply_config(config)
|
||||||
|
|
||||||
|
# 初始化回调钩子
|
||||||
|
self._callbacks = CallbackHook()
|
||||||
|
self._register_default_callbacks()
|
||||||
|
|
||||||
|
# 初始化因子和信号生成器
|
||||||
|
self._factors = None
|
||||||
|
self._signal_gen = None
|
||||||
|
|
||||||
|
def _apply_config(self, config: Dict) -> None:
|
||||||
|
"""应用配置"""
|
||||||
|
params = config.get('params', {})
|
||||||
|
|
||||||
|
# 覆盖类属性
|
||||||
|
for key, value in params.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
# 保存完整配置
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
def _register_default_callbacks(self) -> None:
|
||||||
|
"""注册默认回调"""
|
||||||
|
# 注册入场前回调(溢价过滤)
|
||||||
|
if hasattr(self, 'before_entry'):
|
||||||
|
self._callbacks.register('before_entry', self.before_entry)
|
||||||
|
|
||||||
|
# 注册动态止损回调
|
||||||
|
if hasattr(self, 'dynamic_stoploss'):
|
||||||
|
self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss)
|
||||||
|
|
||||||
|
# 注册自定义出场回调
|
||||||
|
if hasattr(self, 'custom_exit'):
|
||||||
|
self._callbacks.register('custom_exit', self.custom_exit)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init_factors(self) -> FactorCombiner:
|
||||||
|
"""
|
||||||
|
初始化因子组合
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
因子组合器
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def init_signal_generator(self) -> SignalGenerator:
|
||||||
|
"""
|
||||||
|
初始化信号生成器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
信号生成器
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def run(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
运行策略
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 输入数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含信号的DataFrame
|
||||||
|
"""
|
||||||
|
# 初始化因子和信号生成器
|
||||||
|
if self._factors is None:
|
||||||
|
self._factors = self.init_factors()
|
||||||
|
|
||||||
|
if self._signal_gen is None:
|
||||||
|
self._signal_gen = self.init_signal_generator()
|
||||||
|
|
||||||
|
# 1. 计算因子
|
||||||
|
factor_data = self._factors.compute(data)
|
||||||
|
|
||||||
|
# 2. 生成信号
|
||||||
|
signals = self._signal_gen.generate(factor_data)
|
||||||
|
|
||||||
|
# 3. 应用回调钩子
|
||||||
|
signals = self._apply_callbacks(signals, data)
|
||||||
|
|
||||||
|
return signals
|
||||||
|
|
||||||
|
def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""应用回调钩子"""
|
||||||
|
# 遍历每行信号
|
||||||
|
for date in signals.index:
|
||||||
|
signal = signals.loc[date, 'signal']
|
||||||
|
|
||||||
|
if not signal or pd.isna(signal):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解析信号(逗号分隔的代码)
|
||||||
|
codes = signal.split(',')
|
||||||
|
|
||||||
|
# 应用入场前回调
|
||||||
|
for code in codes:
|
||||||
|
if code not in data.columns:
|
||||||
|
continue
|
||||||
|
|
||||||
|
price = data.loc[date, code]
|
||||||
|
premium = 0.0 # TODO: 从溢价数据获取
|
||||||
|
|
||||||
|
# 触发回调
|
||||||
|
allowed = self._callbacks.trigger(
|
||||||
|
'before_entry',
|
||||||
|
code,
|
||||||
|
price,
|
||||||
|
premium=premium,
|
||||||
|
history=data
|
||||||
|
)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
# 移除被拒绝的代码
|
||||||
|
codes.remove(code)
|
||||||
|
|
||||||
|
# 更新信号
|
||||||
|
signals.loc[date, 'signal'] = ','.join(codes) if codes else ''
|
||||||
|
|
||||||
|
return signals
|
||||||
|
|
||||||
|
# ===== 可选回调方法 =====
|
||||||
|
|
||||||
|
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||||||
|
"""
|
||||||
|
入场前检查
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: 标的代码
|
||||||
|
price: 入场价格
|
||||||
|
**kwargs: 其他参数(premium, history等)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否允许入场
|
||||||
|
"""
|
||||||
|
# 默认:允许入场
|
||||||
|
return True
|
||||||
|
|
||||||
|
def after_entry(self, trade, **kwargs) -> None:
|
||||||
|
"""入场后处理"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def before_exit(self, position: Position, **kwargs) -> bool:
|
||||||
|
"""出场前检查"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def after_exit(self, trade, **kwargs) -> None:
|
||||||
|
"""出场后处理"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def dynamic_stoploss(self, position: Position) -> float:
|
||||||
|
"""
|
||||||
|
动态止损
|
||||||
|
|
||||||
|
Args:
|
||||||
|
position: 持仓信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
止损比例
|
||||||
|
"""
|
||||||
|
# 默认:返回固定止损
|
||||||
|
return self.stoploss
|
||||||
|
|
||||||
|
def custom_exit(self, position: Position) -> bool:
|
||||||
|
"""
|
||||||
|
自定义出场条件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
position: 持仓信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否触发出场
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigLoader:
|
||||||
|
"""
|
||||||
|
配置加载器
|
||||||
|
|
||||||
|
支持YAML配置文件加载和验证
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str):
|
||||||
|
"""
|
||||||
|
初始化配置加载器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: 配置文件路径
|
||||||
|
"""
|
||||||
|
self._config_path = Path(config_path)
|
||||||
|
self._config = None
|
||||||
|
|
||||||
|
def load(self) -> Dict:
|
||||||
|
"""加载配置"""
|
||||||
|
if not self._config_path.exists():
|
||||||
|
raise FileNotFoundError(f"Config file not found: {self._config_path}")
|
||||||
|
|
||||||
|
with open(self._config_path, 'r', encoding='utf-8') as f:
|
||||||
|
self._config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def validate(self) -> bool:
|
||||||
|
"""验证配置"""
|
||||||
|
if self._config is None:
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
# 必须字段
|
||||||
|
required_fields = ['strategy', 'factors', 'signal']
|
||||||
|
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in self._config:
|
||||||
|
raise ValueError(f"Missing required field: {field}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_strategy_config(self) -> StrategyConfig:
|
||||||
|
"""获取策略配置"""
|
||||||
|
if self._config is None:
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
return StrategyConfig(
|
||||||
|
name=self._config['strategy']['name'],
|
||||||
|
version=self._config['strategy'].get('version', 1),
|
||||||
|
factors=self._config['factors'],
|
||||||
|
signal=self._config['signal'],
|
||||||
|
callbacks=self._config.get('callbacks', {}),
|
||||||
|
params=self._config.get('params', {})
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_yaml(yaml_str: str) -> Dict:
|
||||||
|
"""从YAML字符串加载"""
|
||||||
|
return yaml.safe_load(yaml_str)
|
||||||
|
|
||||||
|
|
||||||
|
# 示例策略实现
|
||||||
|
class RotationStrategy(StrategyBase):
|
||||||
|
"""
|
||||||
|
ETF轮动策略
|
||||||
|
|
||||||
|
基于动量因子 + Top N选股
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "rotation"
|
||||||
|
select_num = 3
|
||||||
|
|
||||||
|
def init_factors(self) -> FactorCombiner:
|
||||||
|
"""初始化动量因子"""
|
||||||
|
FactorRegistry.clear()
|
||||||
|
FactorRegistry.register(MomentumFactor)
|
||||||
|
|
||||||
|
return FactorCombiner([
|
||||||
|
FactorRegistry.get('momentum', n_days=25, crash_filter=True)
|
||||||
|
])
|
||||||
|
|
||||||
|
def init_signal_generator(self) -> SignalGenerator:
|
||||||
|
"""初始化Top N选股器"""
|
||||||
|
from framework.signals import TopNSelector
|
||||||
|
|
||||||
|
return TopNSelector(
|
||||||
|
select_num=self.select_num,
|
||||||
|
min_score=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def before_entry(self, code: str, price: float, **kwargs) -> bool:
|
||||||
|
"""入场前:溢价过滤"""
|
||||||
|
premium = kwargs.get('premium', 0)
|
||||||
|
|
||||||
|
# 溢价超过10%拒绝入场
|
||||||
|
if premium > 0.10:
|
||||||
|
print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def dynamic_stoploss(self, position: Position) -> float:
|
||||||
|
"""动态止损:根据持仓时间调整"""
|
||||||
|
if position.holding_days >= 10:
|
||||||
|
return -0.03 # 10天后收紧止损
|
||||||
|
elif position.holding_days >= 5:
|
||||||
|
return -0.05
|
||||||
|
return -0.10
|
||||||
140
framework/tests/test_strategy.py
Normal file
140
framework/tests/test_strategy.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""
|
||||||
|
策略与配置层测试
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from framework.strategy import StrategyBase, ConfigLoader, RotationStrategy
|
||||||
|
from framework.factors import FactorRegistry
|
||||||
|
from framework.factors.momentum import MomentumFactor
|
||||||
|
from framework.risk import Position
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrategyBase:
|
||||||
|
"""测试策略基类"""
|
||||||
|
|
||||||
|
def test_strategy_attributes(self):
|
||||||
|
"""测试策略属性"""
|
||||||
|
strategy = RotationStrategy()
|
||||||
|
assert strategy.name == "rotation"
|
||||||
|
assert strategy.select_num == 3
|
||||||
|
assert strategy.stoploss == -0.05
|
||||||
|
|
||||||
|
def test_config_override(self):
|
||||||
|
"""测试配置覆盖"""
|
||||||
|
config = {
|
||||||
|
'params': {
|
||||||
|
'select_num': 5,
|
||||||
|
'stoploss': -0.08
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
strategy = RotationStrategy(config=config)
|
||||||
|
assert strategy.select_num == 5
|
||||||
|
assert strategy.stoploss == -0.08
|
||||||
|
|
||||||
|
def test_factor_initialization(self):
|
||||||
|
"""测试因子初始化"""
|
||||||
|
strategy = RotationStrategy()
|
||||||
|
factors = strategy.init_factors()
|
||||||
|
|
||||||
|
assert factors is not None
|
||||||
|
assert len(factors.factors) == 1
|
||||||
|
|
||||||
|
def test_signal_generator_initialization(self):
|
||||||
|
"""测试信号生成器初始化"""
|
||||||
|
strategy = RotationStrategy()
|
||||||
|
signal_gen = strategy.init_signal_generator()
|
||||||
|
|
||||||
|
assert signal_gen is not None
|
||||||
|
assert signal_gen.select_num == 3
|
||||||
|
|
||||||
|
def test_strategy_run(self):
|
||||||
|
"""测试策略运行"""
|
||||||
|
strategy = RotationStrategy()
|
||||||
|
|
||||||
|
# 生成测试数据(需要'close'列)
|
||||||
|
dates = pd.date_range('2020-01-01', periods=100)
|
||||||
|
data = pd.DataFrame({
|
||||||
|
'close': np.random.randn(100).cumsum() + 100,
|
||||||
|
'code1': np.random.randn(100).cumsum() + 100,
|
||||||
|
'code2': np.random.randn(100).cumsum() + 100,
|
||||||
|
'code3': np.random.randn(100).cumsum() + 100,
|
||||||
|
}, index=dates)
|
||||||
|
|
||||||
|
result = strategy.run(data)
|
||||||
|
|
||||||
|
assert 'signal' in result.columns
|
||||||
|
assert len(result) == len(data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRotationStrategy:
|
||||||
|
"""测试轮动策略"""
|
||||||
|
|
||||||
|
def test_before_entry_callback(self):
|
||||||
|
"""测试入场前回调"""
|
||||||
|
strategy = RotationStrategy()
|
||||||
|
|
||||||
|
# 溢价正常,允许入场
|
||||||
|
result = strategy.before_entry('code1', 100.0, premium=0.05)
|
||||||
|
assert result == True
|
||||||
|
|
||||||
|
# 溢价过高,拒绝入场
|
||||||
|
result = strategy.before_entry('code1', 100.0, premium=0.15)
|
||||||
|
assert result == False
|
||||||
|
|
||||||
|
def test_dynamic_stoploss_callback(self):
|
||||||
|
"""测试动态止损回调"""
|
||||||
|
strategy = RotationStrategy()
|
||||||
|
|
||||||
|
# 持仓5天
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=datetime(2020, 1, 1),
|
||||||
|
current_price=95.0,
|
||||||
|
current_date=datetime(2020, 1, 6), # 5天后
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
stoploss = strategy.dynamic_stoploss(position)
|
||||||
|
# holding_days=5,返回-0.05
|
||||||
|
assert stoploss == -0.05
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigLoader:
|
||||||
|
"""测试配置加载器"""
|
||||||
|
|
||||||
|
def test_load_from_yaml_string(self):
|
||||||
|
"""测试从YAML字符串加载"""
|
||||||
|
yaml_str = """
|
||||||
|
strategy:
|
||||||
|
name: test_strategy
|
||||||
|
version: 1
|
||||||
|
|
||||||
|
factors:
|
||||||
|
- name: momentum
|
||||||
|
weight: 1.0
|
||||||
|
params:
|
||||||
|
n_days: 25
|
||||||
|
|
||||||
|
signal:
|
||||||
|
mode: top_n
|
||||||
|
select_num: 3
|
||||||
|
|
||||||
|
params:
|
||||||
|
stoploss: -0.05
|
||||||
|
"""
|
||||||
|
|
||||||
|
config = ConfigLoader.from_yaml(yaml_str)
|
||||||
|
|
||||||
|
assert config['strategy']['name'] == 'test_strategy'
|
||||||
|
assert config['factors'][0]['name'] == 'momentum'
|
||||||
|
assert config['signal']['select_num'] == 3
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
Reference in New Issue
Block a user