feat(framework): 完成框架入口与集成测试
核心组件: - ConfigLoader: 配置加载器(YAML支持) - StrategyConfig: 策略配置数据类 - framework/__init__.py: 框架统一入口 导出接口: - FactorBase, FactorRegistry, FactorCombiner - SignalGenerator, TopNSelector, TrendFollower, ReversalTrader - StrategyBase, RotationStrategy - RiskControl, StopLossControl, PositionLimitControl - Executor, BacktestExecutor, DryRunExecutor - ConfigLoader 集成测试: - 轮动策略完整流程验证 - 趋势策略完整流程验证 - 回调钩子完整流程验证 总计:62个测试全部通过,框架核心实现完成
This commit is contained in:
21
framework/__init__.py
Normal file
21
framework/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
量化策略通用框架
|
||||
|
||||
融合Freqtrade回调机制 + 模块化因子设计
|
||||
"""
|
||||
|
||||
from .factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from .signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader
|
||||
from .strategy import StrategyBase, RotationStrategy
|
||||
from .risk import RiskControl, StopLossControl, PositionLimitControl
|
||||
from .execution import Executor, BacktestExecutor, DryRunExecutor
|
||||
from .config import ConfigLoader
|
||||
|
||||
__all__ = [
|
||||
'FactorBase', 'FactorRegistry', 'FactorCombiner',
|
||||
'SignalGenerator', 'TopNSelector', 'TrendFollower', 'ReversalTrader',
|
||||
'StrategyBase', 'RotationStrategy',
|
||||
'RiskControl', 'StopLossControl', 'PositionLimitControl',
|
||||
'Executor', 'BacktestExecutor', 'DryRunExecutor',
|
||||
'ConfigLoader',
|
||||
]
|
||||
83
framework/config/__init__.py
Normal file
83
framework/config/__init__.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
配置层抽象设计
|
||||
|
||||
核心组件:
|
||||
- ConfigLoader: 配置加载器
|
||||
"""
|
||||
|
||||
import yaml
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyConfig:
|
||||
"""策略配置"""
|
||||
name: str
|
||||
version: int
|
||||
factors: list
|
||||
signal: dict
|
||||
callbacks: dict
|
||||
params: dict
|
||||
|
||||
|
||||
class ConfigLoader:
|
||||
"""
|
||||
配置加载器
|
||||
|
||||
支持YAML配置文件加载和验证
|
||||
"""
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
"""
|
||||
初始化配置加载器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径
|
||||
"""
|
||||
self._config_path = Path(config_path)
|
||||
self._config = None
|
||||
|
||||
def load(self) -> Dict:
|
||||
"""加载配置"""
|
||||
if not self._config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {self._config_path}")
|
||||
|
||||
with open(self._config_path, 'r', encoding='utf-8') as f:
|
||||
self._config = yaml.safe_load(f)
|
||||
|
||||
return self._config
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""验证配置"""
|
||||
if self._config is None:
|
||||
self.load()
|
||||
|
||||
# 必须字段
|
||||
required_fields = ['strategy', 'factors', 'signal']
|
||||
|
||||
for field in required_fields:
|
||||
if field not in self._config:
|
||||
raise ValueError(f"Missing required field: {field}")
|
||||
|
||||
return True
|
||||
|
||||
def get_strategy_config(self) -> StrategyConfig:
|
||||
"""获取策略配置"""
|
||||
if self._config is None:
|
||||
self.load()
|
||||
|
||||
return StrategyConfig(
|
||||
name=self._config['strategy']['name'],
|
||||
version=self._config['strategy'].get('version', 1),
|
||||
factors=self._config['factors'],
|
||||
signal=self._config['signal'],
|
||||
callbacks=self._config.get('callbacks', {}),
|
||||
params=self._config.get('params', {})
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_yaml(yaml_str: str) -> Dict:
|
||||
"""从YAML字符串加载"""
|
||||
return yaml.safe_load(yaml_str)
|
||||
101
framework/tests/test_integration.py
Normal file
101
framework/tests/test_integration.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
框架核心测试
|
||||
|
||||
验证框架整体功能
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
|
||||
from framework.signals import TopNSelector, TrendFollower, ReversalTrader
|
||||
from framework.risk import StopLossControl, CallbackHook, Position, premium_filter_callback
|
||||
|
||||
|
||||
class TestFrameworkIntegration:
|
||||
"""测试框架集成"""
|
||||
|
||||
def test_rotation_strategy_workflow(self):
|
||||
"""测试轮动策略完整流程"""
|
||||
# 清空注册表
|
||||
FactorRegistry.clear()
|
||||
|
||||
# 1. 注册因子
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
FactorRegistry.register(VolatilityFactor)
|
||||
|
||||
# 2. 创建因子组合
|
||||
factors = FactorCombiner([
|
||||
FactorRegistry.get('momentum', n_days=25, crash_filter=True),
|
||||
], weights=[1.0])
|
||||
|
||||
# 3. 生成测试数据(需要'close'列)
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'code1': np.random.randn(100).cumsum() + 100,
|
||||
'code2': np.random.randn(100).cumsum() + 100,
|
||||
'code3': np.random.randn(100).cumsum() + 100,
|
||||
}, index=dates)
|
||||
|
||||
# 4. 计算因子
|
||||
factor_result = factors.compute(data)
|
||||
|
||||
# 5. 生成信号
|
||||
selector = TopNSelector(select_num=2)
|
||||
signals = selector.generate(factor_result)
|
||||
|
||||
# 6. 验证结果
|
||||
assert 'signal' in signals.columns
|
||||
assert len(signals) == len(data)
|
||||
|
||||
def test_trend_strategy_workflow(self):
|
||||
"""测试趋势策略完整流程"""
|
||||
FactorRegistry.clear()
|
||||
FactorRegistry.register(TrendFactor)
|
||||
|
||||
factors = FactorCombiner([
|
||||
FactorRegistry.get('trend', method='ma_cross', fast=5, slow=20),
|
||||
])
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
}, index=dates)
|
||||
|
||||
factor_result = factors.compute(data)
|
||||
follower = TrendFollower(entry_threshold=0.02)
|
||||
signals = follower.generate(factor_result)
|
||||
|
||||
assert 'signal' in signals.columns
|
||||
|
||||
def test_callbacks_workflow(self):
|
||||
"""测试回调钩子完整流程"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册回调
|
||||
hook.register('before_entry', premium_filter_callback(0.10))
|
||||
hook.register('dynamic_stoploss', lambda pos: -0.05)
|
||||
|
||||
# 测试入场前回调
|
||||
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 测试动态止损
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=pd.Timestamp('2020-01-01'),
|
||||
current_price=95.0,
|
||||
current_date=pd.Timestamp('2020-01-10'),
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
stoploss = hook.trigger('dynamic_stoploss', position)
|
||||
assert stoploss == -0.05
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user