feat(strategy): 实现策略层与配置加载
核心组件: - StrategyBase: 策略抽象基类(含回调钩子) - 类属性(可被配置覆盖) - init_factors(): 初始化因子组合 - init_signal_generator(): 初始化信号生成器 - run(): 运行策略 - RotationStrategy: 轮动策略示例实现 - 动量因子 + TopN选股 - before_entry回调(溢价过滤) - dynamic_stoploss回调(持仓时间动态止损) - ConfigLoader: 配置加载器(YAML支持) - StrategyConfig: 策略配置数据类 特点: - 配置覆盖类属性 - 回调自动注册 - 策略工厂模式 测试覆盖:8个测试全部通过
This commit is contained in:
140
framework/tests/test_strategy.py
Normal file
140
framework/tests/test_strategy.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
策略与配置层测试
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
|
||||
from framework.strategy import StrategyBase, ConfigLoader, RotationStrategy
|
||||
from framework.factors import FactorRegistry
|
||||
from framework.factors.momentum import MomentumFactor
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
class TestStrategyBase:
|
||||
"""测试策略基类"""
|
||||
|
||||
def test_strategy_attributes(self):
|
||||
"""测试策略属性"""
|
||||
strategy = RotationStrategy()
|
||||
assert strategy.name == "rotation"
|
||||
assert strategy.select_num == 3
|
||||
assert strategy.stoploss == -0.05
|
||||
|
||||
def test_config_override(self):
|
||||
"""测试配置覆盖"""
|
||||
config = {
|
||||
'params': {
|
||||
'select_num': 5,
|
||||
'stoploss': -0.08
|
||||
}
|
||||
}
|
||||
|
||||
strategy = RotationStrategy(config=config)
|
||||
assert strategy.select_num == 5
|
||||
assert strategy.stoploss == -0.08
|
||||
|
||||
def test_factor_initialization(self):
|
||||
"""测试因子初始化"""
|
||||
strategy = RotationStrategy()
|
||||
factors = strategy.init_factors()
|
||||
|
||||
assert factors is not None
|
||||
assert len(factors.factors) == 1
|
||||
|
||||
def test_signal_generator_initialization(self):
|
||||
"""测试信号生成器初始化"""
|
||||
strategy = RotationStrategy()
|
||||
signal_gen = strategy.init_signal_generator()
|
||||
|
||||
assert signal_gen is not None
|
||||
assert signal_gen.select_num == 3
|
||||
|
||||
def test_strategy_run(self):
|
||||
"""测试策略运行"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据(需要'close'列)
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'code1': np.random.randn(100).cumsum() + 100,
|
||||
'code2': np.random.randn(100).cumsum() + 100,
|
||||
'code3': np.random.randn(100).cumsum() + 100,
|
||||
}, index=dates)
|
||||
|
||||
result = strategy.run(data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
assert len(result) == len(data)
|
||||
|
||||
|
||||
class TestRotationStrategy:
|
||||
"""测试轮动策略"""
|
||||
|
||||
def test_before_entry_callback(self):
|
||||
"""测试入场前回调"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 溢价正常,允许入场
|
||||
result = strategy.before_entry('code1', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 溢价过高,拒绝入场
|
||||
result = strategy.before_entry('code1', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
def test_dynamic_stoploss_callback(self):
|
||||
"""测试动态止损回调"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 持仓5天
|
||||
position = Position(
|
||||
code='code1',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 6), # 5天后
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
)
|
||||
stoploss = strategy.dynamic_stoploss(position)
|
||||
# holding_days=5,返回-0.05
|
||||
assert stoploss == -0.05
|
||||
|
||||
|
||||
class TestConfigLoader:
|
||||
"""测试配置加载器"""
|
||||
|
||||
def test_load_from_yaml_string(self):
|
||||
"""测试从YAML字符串加载"""
|
||||
yaml_str = """
|
||||
strategy:
|
||||
name: test_strategy
|
||||
version: 1
|
||||
|
||||
factors:
|
||||
- name: momentum
|
||||
weight: 1.0
|
||||
params:
|
||||
n_days: 25
|
||||
|
||||
signal:
|
||||
mode: top_n
|
||||
select_num: 3
|
||||
|
||||
params:
|
||||
stoploss: -0.05
|
||||
"""
|
||||
|
||||
config = ConfigLoader.from_yaml(yaml_str)
|
||||
|
||||
assert config['strategy']['name'] == 'test_strategy'
|
||||
assert config['factors'][0]['name'] == 'momentum'
|
||||
assert config['signal']['select_num'] == 3
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user