From 7468130450a97e7e8273d24c6304f25f48538af2 Mon Sep 17 00:00:00 2001 From: aszerW Date: Mon, 11 May 2026 22:18:55 +0800 Subject: [PATCH] =?UTF-8?q?feat(strategy):=20=E5=AE=9E=E7=8E=B0=E7=AD=96?= =?UTF-8?q?=E7=95=A5=E5=B1=82=E4=B8=8E=E9=85=8D=E7=BD=AE=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 核心组件: - StrategyBase: 策略抽象基类(含回调钩子) - 类属性(可被配置覆盖) - init_factors(): 初始化因子组合 - init_signal_generator(): 初始化信号生成器 - run(): 运行策略 - RotationStrategy: 轮动策略示例实现 - 动量因子 + TopN选股 - before_entry回调(溢价过滤) - dynamic_stoploss回调(持仓时间动态止损) - ConfigLoader: 配置加载器(YAML支持) - StrategyConfig: 策略配置数据类 特点: - 配置覆盖类属性 - 回调自动注册 - 策略工厂模式 测试覆盖:8个测试全部通过 --- framework/strategy/__init__.py | 357 +++++++++++++++++++++++++++++++ framework/tests/test_strategy.py | 140 ++++++++++++ 2 files changed, 497 insertions(+) create mode 100644 framework/strategy/__init__.py create mode 100644 framework/tests/test_strategy.py diff --git a/framework/strategy/__init__.py b/framework/strategy/__init__.py new file mode 100644 index 0000000..9b025cf --- /dev/null +++ b/framework/strategy/__init__.py @@ -0,0 +1,357 @@ +""" +策略基类与配置层 + +核心组件: +- StrategyBase: 策略抽象基类(含回调钩子) +- ConfigLoader: 配置加载器 +""" + +import yaml +import pandas as pd +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional, List +from pathlib import Path +from dataclasses import dataclass + +from framework.factors import FactorBase, FactorRegistry, FactorCombiner +from framework.factors.momentum import MomentumFactor +from framework.signals import SignalGenerator, TopNSelector +from framework.risk import CallbackHook, Position + + +@dataclass +class StrategyConfig: + """策略配置""" + name: str + version: int + factors: List[Dict] + signal: Dict + callbacks: Dict + params: Dict + + +class StrategyBase(ABC): + """ + 策略抽象基类 + + 融合Freqtrade回调机制 + 模块化因子设计 + + 类属性(可被配置覆盖): + - name: 策略名称 + - version: 接口版本 + - timeframe: K线周期 + - select_num: 选中数量 + - stoploss: 止损比例 + + 回调钩子(可选实现): + - before_entry: 入场前检查 + - after_entry: 入场后处理 + - before_exit: 出场前检查 + - after_exit: 出场后处理 + - dynamic_stoploss: 动态止损 + - custom_exit: 自定义出场条件 + """ + + # 接口版本 + INTERFACE_VERSION = 1 + + # 类属性(可被配置覆盖) + name: str = "base" + timeframe: str = "1d" + select_num: int = 3 + stoploss: float = -0.05 + + def __init__(self, config: Optional[Dict] = None): + """ + 初始化策略 + + Args: + config: 策略配置(可覆盖类属性) + """ + # 配置覆盖类属性 + if config: + self._apply_config(config) + + # 初始化回调钩子 + self._callbacks = CallbackHook() + self._register_default_callbacks() + + # 初始化因子和信号生成器 + self._factors = None + self._signal_gen = None + + def _apply_config(self, config: Dict) -> None: + """应用配置""" + params = config.get('params', {}) + + # 覆盖类属性 + for key, value in params.items(): + if hasattr(self, key): + setattr(self, key, value) + + # 保存完整配置 + self._config = config + + def _register_default_callbacks(self) -> None: + """注册默认回调""" + # 注册入场前回调(溢价过滤) + if hasattr(self, 'before_entry'): + self._callbacks.register('before_entry', self.before_entry) + + # 注册动态止损回调 + if hasattr(self, 'dynamic_stoploss'): + self._callbacks.register('dynamic_stoploss', self.dynamic_stoploss) + + # 注册自定义出场回调 + if hasattr(self, 'custom_exit'): + self._callbacks.register('custom_exit', self.custom_exit) + + @abstractmethod + def init_factors(self) -> FactorCombiner: + """ + 初始化因子组合 + + Returns: + 因子组合器 + """ + pass + + @abstractmethod + def init_signal_generator(self) -> SignalGenerator: + """ + 初始化信号生成器 + + Returns: + 信号生成器 + """ + pass + + def run(self, data: pd.DataFrame) -> pd.DataFrame: + """ + 运行策略 + + Args: + data: 输入数据 + + Returns: + 包含信号的DataFrame + """ + # 初始化因子和信号生成器 + if self._factors is None: + self._factors = self.init_factors() + + if self._signal_gen is None: + self._signal_gen = self.init_signal_generator() + + # 1. 计算因子 + factor_data = self._factors.compute(data) + + # 2. 生成信号 + signals = self._signal_gen.generate(factor_data) + + # 3. 应用回调钩子 + signals = self._apply_callbacks(signals, data) + + return signals + + def _apply_callbacks(self, signals: pd.DataFrame, data: pd.DataFrame) -> pd.DataFrame: + """应用回调钩子""" + # 遍历每行信号 + for date in signals.index: + signal = signals.loc[date, 'signal'] + + if not signal or pd.isna(signal): + continue + + # 解析信号(逗号分隔的代码) + codes = signal.split(',') + + # 应用入场前回调 + for code in codes: + if code not in data.columns: + continue + + price = data.loc[date, code] + premium = 0.0 # TODO: 从溢价数据获取 + + # 触发回调 + allowed = self._callbacks.trigger( + 'before_entry', + code, + price, + premium=premium, + history=data + ) + + if not allowed: + # 移除被拒绝的代码 + codes.remove(code) + + # 更新信号 + signals.loc[date, 'signal'] = ','.join(codes) if codes else '' + + return signals + + # ===== 可选回调方法 ===== + + def before_entry(self, code: str, price: float, **kwargs) -> bool: + """ + 入场前检查 + + Args: + code: 标的代码 + price: 入场价格 + **kwargs: 其他参数(premium, history等) + + Returns: + 是否允许入场 + """ + # 默认:允许入场 + return True + + def after_entry(self, trade, **kwargs) -> None: + """入场后处理""" + pass + + def before_exit(self, position: Position, **kwargs) -> bool: + """出场前检查""" + return True + + def after_exit(self, trade, **kwargs) -> None: + """出场后处理""" + pass + + def dynamic_stoploss(self, position: Position) -> float: + """ + 动态止损 + + Args: + position: 持仓信息 + + Returns: + 止损比例 + """ + # 默认:返回固定止损 + return self.stoploss + + def custom_exit(self, position: Position) -> bool: + """ + 自定义出场条件 + + Args: + position: 持仓信息 + + Returns: + 是否触发出场 + """ + return False + + +class ConfigLoader: + """ + 配置加载器 + + 支持YAML配置文件加载和验证 + """ + + def __init__(self, config_path: str): + """ + 初始化配置加载器 + + Args: + config_path: 配置文件路径 + """ + self._config_path = Path(config_path) + self._config = None + + def load(self) -> Dict: + """加载配置""" + if not self._config_path.exists(): + raise FileNotFoundError(f"Config file not found: {self._config_path}") + + with open(self._config_path, 'r', encoding='utf-8') as f: + self._config = yaml.safe_load(f) + + return self._config + + def validate(self) -> bool: + """验证配置""" + if self._config is None: + self.load() + + # 必须字段 + required_fields = ['strategy', 'factors', 'signal'] + + for field in required_fields: + if field not in self._config: + raise ValueError(f"Missing required field: {field}") + + return True + + def get_strategy_config(self) -> StrategyConfig: + """获取策略配置""" + if self._config is None: + self.load() + + return StrategyConfig( + name=self._config['strategy']['name'], + version=self._config['strategy'].get('version', 1), + factors=self._config['factors'], + signal=self._config['signal'], + callbacks=self._config.get('callbacks', {}), + params=self._config.get('params', {}) + ) + + @staticmethod + def from_yaml(yaml_str: str) -> Dict: + """从YAML字符串加载""" + return yaml.safe_load(yaml_str) + + +# 示例策略实现 +class RotationStrategy(StrategyBase): + """ + ETF轮动策略 + + 基于动量因子 + Top N选股 + """ + + name = "rotation" + select_num = 3 + + def init_factors(self) -> FactorCombiner: + """初始化动量因子""" + FactorRegistry.clear() + FactorRegistry.register(MomentumFactor) + + return FactorCombiner([ + FactorRegistry.get('momentum', n_days=25, crash_filter=True) + ]) + + def init_signal_generator(self) -> SignalGenerator: + """初始化Top N选股器""" + from framework.signals import TopNSelector + + return TopNSelector( + select_num=self.select_num, + min_score=0.0 + ) + + def before_entry(self, code: str, price: float, **kwargs) -> bool: + """入场前:溢价过滤""" + premium = kwargs.get('premium', 0) + + # 溢价超过10%拒绝入场 + if premium > 0.10: + print(f"溢价过高,拒绝入场: {code} (溢价={premium:.2%})") + return False + + return True + + def dynamic_stoploss(self, position: Position) -> float: + """动态止损:根据持仓时间调整""" + if position.holding_days >= 10: + return -0.03 # 10天后收紧止损 + elif position.holding_days >= 5: + return -0.05 + return -0.10 \ No newline at end of file diff --git a/framework/tests/test_strategy.py b/framework/tests/test_strategy.py new file mode 100644 index 0000000..59b5b90 --- /dev/null +++ b/framework/tests/test_strategy.py @@ -0,0 +1,140 @@ +""" +策略与配置层测试 +""" + +import pytest +import pandas as pd +import numpy as np +from datetime import datetime + +from framework.strategy import StrategyBase, ConfigLoader, RotationStrategy +from framework.factors import FactorRegistry +from framework.factors.momentum import MomentumFactor +from framework.risk import Position + + +class TestStrategyBase: + """测试策略基类""" + + def test_strategy_attributes(self): + """测试策略属性""" + strategy = RotationStrategy() + assert strategy.name == "rotation" + assert strategy.select_num == 3 + assert strategy.stoploss == -0.05 + + def test_config_override(self): + """测试配置覆盖""" + config = { + 'params': { + 'select_num': 5, + 'stoploss': -0.08 + } + } + + strategy = RotationStrategy(config=config) + assert strategy.select_num == 5 + assert strategy.stoploss == -0.08 + + def test_factor_initialization(self): + """测试因子初始化""" + strategy = RotationStrategy() + factors = strategy.init_factors() + + assert factors is not None + assert len(factors.factors) == 1 + + def test_signal_generator_initialization(self): + """测试信号生成器初始化""" + strategy = RotationStrategy() + signal_gen = strategy.init_signal_generator() + + assert signal_gen is not None + assert signal_gen.select_num == 3 + + def test_strategy_run(self): + """测试策略运行""" + strategy = RotationStrategy() + + # 生成测试数据(需要'close'列) + dates = pd.date_range('2020-01-01', periods=100) + data = pd.DataFrame({ + 'close': np.random.randn(100).cumsum() + 100, + 'code1': np.random.randn(100).cumsum() + 100, + 'code2': np.random.randn(100).cumsum() + 100, + 'code3': np.random.randn(100).cumsum() + 100, + }, index=dates) + + result = strategy.run(data) + + assert 'signal' in result.columns + assert len(result) == len(data) + + +class TestRotationStrategy: + """测试轮动策略""" + + def test_before_entry_callback(self): + """测试入场前回调""" + strategy = RotationStrategy() + + # 溢价正常,允许入场 + result = strategy.before_entry('code1', 100.0, premium=0.05) + assert result == True + + # 溢价过高,拒绝入场 + result = strategy.before_entry('code1', 100.0, premium=0.15) + assert result == False + + def test_dynamic_stoploss_callback(self): + """测试动态止损回调""" + strategy = RotationStrategy() + + # 持仓5天 + position = Position( + code='code1', + entry_price=100.0, + entry_date=datetime(2020, 1, 1), + current_price=95.0, + current_date=datetime(2020, 1, 6), # 5天后 + quantity=100, + weight=0.33 + ) + stoploss = strategy.dynamic_stoploss(position) + # holding_days=5,返回-0.05 + assert stoploss == -0.05 + + +class TestConfigLoader: + """测试配置加载器""" + + def test_load_from_yaml_string(self): + """测试从YAML字符串加载""" + yaml_str = """ +strategy: + name: test_strategy + version: 1 + +factors: + - name: momentum + weight: 1.0 + params: + n_days: 25 + +signal: + mode: top_n + select_num: 3 + +params: + stoploss: -0.05 +""" + + config = ConfigLoader.from_yaml(yaml_str) + + assert config['strategy']['name'] == 'test_strategy' + assert config['factors'][0]['name'] == 'momentum' + assert config['signal']['select_num'] == 3 + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) \ No newline at end of file