feat(framework): 完成框架入口与集成测试
核心组件: - ConfigLoader: 配置加载器(YAML支持) - StrategyConfig: 策略配置数据类 - framework/__init__.py: 框架统一入口 导出接口: - FactorBase, FactorRegistry, FactorCombiner - SignalGenerator, TopNSelector, TrendFollower, ReversalTrader - StrategyBase, RotationStrategy - RiskControl, StopLossControl, PositionLimitControl - Executor, BacktestExecutor, DryRunExecutor - ConfigLoader 集成测试: - 轮动策略完整流程验证 - 趋势策略完整流程验证 - 回调钩子完整流程验证 总计:62个测试全部通过,框架核心实现完成
This commit is contained in:
21
framework/__init__.py
Normal file
21
framework/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""
|
||||||
|
量化策略通用框架
|
||||||
|
|
||||||
|
融合Freqtrade回调机制 + 模块化因子设计
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .factors import FactorBase, FactorRegistry, FactorCombiner
|
||||||
|
from .signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader
|
||||||
|
from .strategy import StrategyBase, RotationStrategy
|
||||||
|
from .risk import RiskControl, StopLossControl, PositionLimitControl
|
||||||
|
from .execution import Executor, BacktestExecutor, DryRunExecutor
|
||||||
|
from .config import ConfigLoader
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'FactorBase', 'FactorRegistry', 'FactorCombiner',
|
||||||
|
'SignalGenerator', 'TopNSelector', 'TrendFollower', 'ReversalTrader',
|
||||||
|
'StrategyBase', 'RotationStrategy',
|
||||||
|
'RiskControl', 'StopLossControl', 'PositionLimitControl',
|
||||||
|
'Executor', 'BacktestExecutor', 'DryRunExecutor',
|
||||||
|
'ConfigLoader',
|
||||||
|
]
|
||||||
83
framework/config/__init__.py
Normal file
83
framework/config/__init__.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""
|
||||||
|
配置层抽象设计
|
||||||
|
|
||||||
|
核心组件:
|
||||||
|
- ConfigLoader: 配置加载器
|
||||||
|
"""
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StrategyConfig:
|
||||||
|
"""策略配置"""
|
||||||
|
name: str
|
||||||
|
version: int
|
||||||
|
factors: list
|
||||||
|
signal: dict
|
||||||
|
callbacks: dict
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigLoader:
|
||||||
|
"""
|
||||||
|
配置加载器
|
||||||
|
|
||||||
|
支持YAML配置文件加载和验证
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config_path: str):
|
||||||
|
"""
|
||||||
|
初始化配置加载器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: 配置文件路径
|
||||||
|
"""
|
||||||
|
self._config_path = Path(config_path)
|
||||||
|
self._config = None
|
||||||
|
|
||||||
|
def load(self) -> Dict:
|
||||||
|
"""加载配置"""
|
||||||
|
if not self._config_path.exists():
|
||||||
|
raise FileNotFoundError(f"Config file not found: {self._config_path}")
|
||||||
|
|
||||||
|
with open(self._config_path, 'r', encoding='utf-8') as f:
|
||||||
|
self._config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def validate(self) -> bool:
|
||||||
|
"""验证配置"""
|
||||||
|
if self._config is None:
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
# 必须字段
|
||||||
|
required_fields = ['strategy', 'factors', 'signal']
|
||||||
|
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in self._config:
|
||||||
|
raise ValueError(f"Missing required field: {field}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_strategy_config(self) -> StrategyConfig:
|
||||||
|
"""获取策略配置"""
|
||||||
|
if self._config is None:
|
||||||
|
self.load()
|
||||||
|
|
||||||
|
return StrategyConfig(
|
||||||
|
name=self._config['strategy']['name'],
|
||||||
|
version=self._config['strategy'].get('version', 1),
|
||||||
|
factors=self._config['factors'],
|
||||||
|
signal=self._config['signal'],
|
||||||
|
callbacks=self._config.get('callbacks', {}),
|
||||||
|
params=self._config.get('params', {})
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_yaml(yaml_str: str) -> Dict:
|
||||||
|
"""从YAML字符串加载"""
|
||||||
|
return yaml.safe_load(yaml_str)
|
||||||
101
framework/tests/test_integration.py
Normal file
101
framework/tests/test_integration.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""
|
||||||
|
框架核心测试
|
||||||
|
|
||||||
|
验证框架整体功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||||
|
from framework.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
|
||||||
|
from framework.signals import TopNSelector, TrendFollower, ReversalTrader
|
||||||
|
from framework.risk import StopLossControl, CallbackHook, Position, premium_filter_callback
|
||||||
|
|
||||||
|
|
||||||
|
class TestFrameworkIntegration:
|
||||||
|
"""测试框架集成"""
|
||||||
|
|
||||||
|
def test_rotation_strategy_workflow(self):
|
||||||
|
"""测试轮动策略完整流程"""
|
||||||
|
# 清空注册表
|
||||||
|
FactorRegistry.clear()
|
||||||
|
|
||||||
|
# 1. 注册因子
|
||||||
|
FactorRegistry.register(MomentumFactor)
|
||||||
|
FactorRegistry.register(VolatilityFactor)
|
||||||
|
|
||||||
|
# 2. 创建因子组合
|
||||||
|
factors = FactorCombiner([
|
||||||
|
FactorRegistry.get('momentum', n_days=25, crash_filter=True),
|
||||||
|
], weights=[1.0])
|
||||||
|
|
||||||
|
# 3. 生成测试数据(需要'close'列)
|
||||||
|
dates = pd.date_range('2020-01-01', periods=100)
|
||||||
|
data = pd.DataFrame({
|
||||||
|
'close': np.random.randn(100).cumsum() + 100,
|
||||||
|
'code1': np.random.randn(100).cumsum() + 100,
|
||||||
|
'code2': np.random.randn(100).cumsum() + 100,
|
||||||
|
'code3': np.random.randn(100).cumsum() + 100,
|
||||||
|
}, index=dates)
|
||||||
|
|
||||||
|
# 4. 计算因子
|
||||||
|
factor_result = factors.compute(data)
|
||||||
|
|
||||||
|
# 5. 生成信号
|
||||||
|
selector = TopNSelector(select_num=2)
|
||||||
|
signals = selector.generate(factor_result)
|
||||||
|
|
||||||
|
# 6. 验证结果
|
||||||
|
assert 'signal' in signals.columns
|
||||||
|
assert len(signals) == len(data)
|
||||||
|
|
||||||
|
def test_trend_strategy_workflow(self):
|
||||||
|
"""测试趋势策略完整流程"""
|
||||||
|
FactorRegistry.clear()
|
||||||
|
FactorRegistry.register(TrendFactor)
|
||||||
|
|
||||||
|
factors = FactorCombiner([
|
||||||
|
FactorRegistry.get('trend', method='ma_cross', fast=5, slow=20),
|
||||||
|
])
|
||||||
|
|
||||||
|
dates = pd.date_range('2020-01-01', periods=100)
|
||||||
|
data = pd.DataFrame({
|
||||||
|
'close': np.random.randn(100).cumsum() + 100,
|
||||||
|
}, index=dates)
|
||||||
|
|
||||||
|
factor_result = factors.compute(data)
|
||||||
|
follower = TrendFollower(entry_threshold=0.02)
|
||||||
|
signals = follower.generate(factor_result)
|
||||||
|
|
||||||
|
assert 'signal' in signals.columns
|
||||||
|
|
||||||
|
def test_callbacks_workflow(self):
|
||||||
|
"""测试回调钩子完整流程"""
|
||||||
|
hook = CallbackHook()
|
||||||
|
|
||||||
|
# 注册回调
|
||||||
|
hook.register('before_entry', premium_filter_callback(0.10))
|
||||||
|
hook.register('dynamic_stoploss', lambda pos: -0.05)
|
||||||
|
|
||||||
|
# 测试入场前回调
|
||||||
|
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||||
|
assert result == True
|
||||||
|
|
||||||
|
# 测试动态止损
|
||||||
|
position = Position(
|
||||||
|
code='code1',
|
||||||
|
entry_price=100.0,
|
||||||
|
entry_date=pd.Timestamp('2020-01-01'),
|
||||||
|
current_price=95.0,
|
||||||
|
current_date=pd.Timestamp('2020-01-10'),
|
||||||
|
quantity=100,
|
||||||
|
weight=0.33
|
||||||
|
)
|
||||||
|
stoploss = hook.trigger('dynamic_stoploss', position)
|
||||||
|
assert stoploss == -0.05
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
pytest.main([__file__, '-v'])
|
||||||
Reference in New Issue
Block a user