diff --git a/framework/__init__.py b/framework/__init__.py new file mode 100644 index 0000000..9add992 --- /dev/null +++ b/framework/__init__.py @@ -0,0 +1,21 @@ +""" +量化策略通用框架 + +融合Freqtrade回调机制 + 模块化因子设计 +""" + +from .factors import FactorBase, FactorRegistry, FactorCombiner +from .signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader +from .strategy import StrategyBase, RotationStrategy +from .risk import RiskControl, StopLossControl, PositionLimitControl +from .execution import Executor, BacktestExecutor, DryRunExecutor +from .config import ConfigLoader + +__all__ = [ + 'FactorBase', 'FactorRegistry', 'FactorCombiner', + 'SignalGenerator', 'TopNSelector', 'TrendFollower', 'ReversalTrader', + 'StrategyBase', 'RotationStrategy', + 'RiskControl', 'StopLossControl', 'PositionLimitControl', + 'Executor', 'BacktestExecutor', 'DryRunExecutor', + 'ConfigLoader', +] \ No newline at end of file diff --git a/framework/config/__init__.py b/framework/config/__init__.py new file mode 100644 index 0000000..7246012 --- /dev/null +++ b/framework/config/__init__.py @@ -0,0 +1,83 @@ +""" +配置层抽象设计 + +核心组件: +- ConfigLoader: 配置加载器 +""" + +import yaml +from typing import Dict, Any, Optional +from pathlib import Path +from dataclasses import dataclass + + +@dataclass +class StrategyConfig: + """策略配置""" + name: str + version: int + factors: list + signal: dict + callbacks: dict + params: dict + + +class ConfigLoader: + """ + 配置加载器 + + 支持YAML配置文件加载和验证 + """ + + def __init__(self, config_path: str): + """ + 初始化配置加载器 + + Args: + config_path: 配置文件路径 + """ + self._config_path = Path(config_path) + self._config = None + + def load(self) -> Dict: + """加载配置""" + if not self._config_path.exists(): + raise FileNotFoundError(f"Config file not found: {self._config_path}") + + with open(self._config_path, 'r', encoding='utf-8') as f: + self._config = yaml.safe_load(f) + + return self._config + + def validate(self) -> bool: + """验证配置""" + if self._config is None: + self.load() + + # 必须字段 + required_fields = ['strategy', 'factors', 'signal'] + + for field in required_fields: + if field not in self._config: + raise ValueError(f"Missing required field: {field}") + + return True + + def get_strategy_config(self) -> StrategyConfig: + """获取策略配置""" + if self._config is None: + self.load() + + return StrategyConfig( + name=self._config['strategy']['name'], + version=self._config['strategy'].get('version', 1), + factors=self._config['factors'], + signal=self._config['signal'], + callbacks=self._config.get('callbacks', {}), + params=self._config.get('params', {}) + ) + + @staticmethod + def from_yaml(yaml_str: str) -> Dict: + """从YAML字符串加载""" + return yaml.safe_load(yaml_str) \ No newline at end of file diff --git a/framework/tests/test_integration.py b/framework/tests/test_integration.py new file mode 100644 index 0000000..49d9629 --- /dev/null +++ b/framework/tests/test_integration.py @@ -0,0 +1,101 @@ +""" +框架核心测试 + +验证框架整体功能 +""" + +import pandas as pd +import numpy as np +import pytest + +from framework.factors import FactorBase, FactorRegistry, FactorCombiner +from framework.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor +from framework.signals import TopNSelector, TrendFollower, ReversalTrader +from framework.risk import StopLossControl, CallbackHook, Position, premium_filter_callback + + +class TestFrameworkIntegration: + """测试框架集成""" + + def test_rotation_strategy_workflow(self): + """测试轮动策略完整流程""" + # 清空注册表 + FactorRegistry.clear() + + # 1. 注册因子 + FactorRegistry.register(MomentumFactor) + FactorRegistry.register(VolatilityFactor) + + # 2. 创建因子组合 + factors = FactorCombiner([ + FactorRegistry.get('momentum', n_days=25, crash_filter=True), + ], weights=[1.0]) + + # 3. 生成测试数据(需要'close'列) + dates = pd.date_range('2020-01-01', periods=100) + data = pd.DataFrame({ + 'close': np.random.randn(100).cumsum() + 100, + 'code1': np.random.randn(100).cumsum() + 100, + 'code2': np.random.randn(100).cumsum() + 100, + 'code3': np.random.randn(100).cumsum() + 100, + }, index=dates) + + # 4. 计算因子 + factor_result = factors.compute(data) + + # 5. 生成信号 + selector = TopNSelector(select_num=2) + signals = selector.generate(factor_result) + + # 6. 验证结果 + assert 'signal' in signals.columns + assert len(signals) == len(data) + + def test_trend_strategy_workflow(self): + """测试趋势策略完整流程""" + FactorRegistry.clear() + FactorRegistry.register(TrendFactor) + + factors = FactorCombiner([ + FactorRegistry.get('trend', method='ma_cross', fast=5, slow=20), + ]) + + dates = pd.date_range('2020-01-01', periods=100) + data = pd.DataFrame({ + 'close': np.random.randn(100).cumsum() + 100, + }, index=dates) + + factor_result = factors.compute(data) + follower = TrendFollower(entry_threshold=0.02) + signals = follower.generate(factor_result) + + assert 'signal' in signals.columns + + def test_callbacks_workflow(self): + """测试回调钩子完整流程""" + hook = CallbackHook() + + # 注册回调 + hook.register('before_entry', premium_filter_callback(0.10)) + hook.register('dynamic_stoploss', lambda pos: -0.05) + + # 测试入场前回调 + result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05) + assert result == True + + # 测试动态止损 + position = Position( + code='code1', + entry_price=100.0, + entry_date=pd.Timestamp('2020-01-01'), + current_price=95.0, + current_date=pd.Timestamp('2020-01-10'), + quantity=100, + weight=0.33 + ) + stoploss = hook.trigger('dynamic_stoploss', position) + assert stoploss == -0.05 + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) \ No newline at end of file