feat(framework): 完成框架入口与集成测试

核心组件:
- ConfigLoader: 配置加载器(YAML支持)
- StrategyConfig: 策略配置数据类
- framework/__init__.py: 框架统一入口

导出接口:
- FactorBase, FactorRegistry, FactorCombiner
- SignalGenerator, TopNSelector, TrendFollower, ReversalTrader
- StrategyBase, RotationStrategy
- RiskControl, StopLossControl, PositionLimitControl
- Executor, BacktestExecutor, DryRunExecutor
- ConfigLoader

集成测试:
- 轮动策略完整流程验证
- 趋势策略完整流程验证
- 回调钩子完整流程验证

总计:62个测试全部通过,框架核心实现完成
This commit is contained in:
2026-05-11 22:19:26 +08:00
parent babf224203
commit 95c0d79172
3 changed files with 205 additions and 0 deletions

21
framework/__init__.py Normal file
View File

@@ -0,0 +1,21 @@
"""
量化策略通用框架
融合Freqtrade回调机制 + 模块化因子设计
"""
from .factors import FactorBase, FactorRegistry, FactorCombiner
from .signals import SignalGenerator, TopNSelector, TrendFollower, ReversalTrader
from .strategy import StrategyBase, RotationStrategy
from .risk import RiskControl, StopLossControl, PositionLimitControl
from .execution import Executor, BacktestExecutor, DryRunExecutor
from .config import ConfigLoader
__all__ = [
'FactorBase', 'FactorRegistry', 'FactorCombiner',
'SignalGenerator', 'TopNSelector', 'TrendFollower', 'ReversalTrader',
'StrategyBase', 'RotationStrategy',
'RiskControl', 'StopLossControl', 'PositionLimitControl',
'Executor', 'BacktestExecutor', 'DryRunExecutor',
'ConfigLoader',
]

View File

@@ -0,0 +1,83 @@
"""
配置层抽象设计
核心组件:
- ConfigLoader: 配置加载器
"""
import yaml
from typing import Dict, Any, Optional
from pathlib import Path
from dataclasses import dataclass
@dataclass
class StrategyConfig:
"""策略配置"""
name: str
version: int
factors: list
signal: dict
callbacks: dict
params: dict
class ConfigLoader:
"""
配置加载器
支持YAML配置文件加载和验证
"""
def __init__(self, config_path: str):
"""
初始化配置加载器
Args:
config_path: 配置文件路径
"""
self._config_path = Path(config_path)
self._config = None
def load(self) -> Dict:
"""加载配置"""
if not self._config_path.exists():
raise FileNotFoundError(f"Config file not found: {self._config_path}")
with open(self._config_path, 'r', encoding='utf-8') as f:
self._config = yaml.safe_load(f)
return self._config
def validate(self) -> bool:
"""验证配置"""
if self._config is None:
self.load()
# 必须字段
required_fields = ['strategy', 'factors', 'signal']
for field in required_fields:
if field not in self._config:
raise ValueError(f"Missing required field: {field}")
return True
def get_strategy_config(self) -> StrategyConfig:
"""获取策略配置"""
if self._config is None:
self.load()
return StrategyConfig(
name=self._config['strategy']['name'],
version=self._config['strategy'].get('version', 1),
factors=self._config['factors'],
signal=self._config['signal'],
callbacks=self._config.get('callbacks', {}),
params=self._config.get('params', {})
)
@staticmethod
def from_yaml(yaml_str: str) -> Dict:
"""从YAML字符串加载"""
return yaml.safe_load(yaml_str)

View File

@@ -0,0 +1,101 @@
"""
框架核心测试
验证框架整体功能
"""
import pandas as pd
import numpy as np
import pytest
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
from framework.signals import TopNSelector, TrendFollower, ReversalTrader
from framework.risk import StopLossControl, CallbackHook, Position, premium_filter_callback
class TestFrameworkIntegration:
"""测试框架集成"""
def test_rotation_strategy_workflow(self):
"""测试轮动策略完整流程"""
# 清空注册表
FactorRegistry.clear()
# 1. 注册因子
FactorRegistry.register(MomentumFactor)
FactorRegistry.register(VolatilityFactor)
# 2. 创建因子组合
factors = FactorCombiner([
FactorRegistry.get('momentum', n_days=25, crash_filter=True),
], weights=[1.0])
# 3. 生成测试数据(需要'close'列)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'code1': np.random.randn(100).cumsum() + 100,
'code2': np.random.randn(100).cumsum() + 100,
'code3': np.random.randn(100).cumsum() + 100,
}, index=dates)
# 4. 计算因子
factor_result = factors.compute(data)
# 5. 生成信号
selector = TopNSelector(select_num=2)
signals = selector.generate(factor_result)
# 6. 验证结果
assert 'signal' in signals.columns
assert len(signals) == len(data)
def test_trend_strategy_workflow(self):
"""测试趋势策略完整流程"""
FactorRegistry.clear()
FactorRegistry.register(TrendFactor)
factors = FactorCombiner([
FactorRegistry.get('trend', method='ma_cross', fast=5, slow=20),
])
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
}, index=dates)
factor_result = factors.compute(data)
follower = TrendFollower(entry_threshold=0.02)
signals = follower.generate(factor_result)
assert 'signal' in signals.columns
def test_callbacks_workflow(self):
"""测试回调钩子完整流程"""
hook = CallbackHook()
# 注册回调
hook.register('before_entry', premium_filter_callback(0.10))
hook.register('dynamic_stoploss', lambda pos: -0.05)
# 测试入场前回调
result = hook.trigger('before_entry', 'code1', 100.0, premium=0.05)
assert result == True
# 测试动态止损
position = Position(
code='code1',
entry_price=100.0,
entry_date=pd.Timestamp('2020-01-01'),
current_price=95.0,
current_date=pd.Timestamp('2020-01-10'),
quantity=100,
weight=0.33
)
stoploss = hook.trigger('dynamic_stoploss', position)
assert stoploss == -0.05
if __name__ == '__main__':
pytest.main([__file__, '-v'])