Files
etf/framework/tests/test_strategy.py
aszerW 7468130450 feat(strategy): 实现策略层与配置加载
核心组件:
- StrategyBase: 策略抽象基类(含回调钩子)
  - 类属性(可被配置覆盖)
  - init_factors(): 初始化因子组合
  - init_signal_generator(): 初始化信号生成器
  - run(): 运行策略
- RotationStrategy: 轮动策略示例实现
  - 动量因子 + TopN选股
  - before_entry回调(溢价过滤)
  - dynamic_stoploss回调(持仓时间动态止损)
- ConfigLoader: 配置加载器(YAML支持)
- StrategyConfig: 策略配置数据类

特点:
- 配置覆盖类属性
- 回调自动注册
- 策略工厂模式

测试覆盖:8个测试全部通过
2026-05-11 22:18:55 +08:00

140 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
策略与配置层测试
"""
import pytest
import pandas as pd
import numpy as np
from datetime import datetime
from framework.strategy import StrategyBase, ConfigLoader, RotationStrategy
from framework.factors import FactorRegistry
from framework.factors.momentum import MomentumFactor
from framework.risk import Position
class TestStrategyBase:
"""测试策略基类"""
def test_strategy_attributes(self):
"""测试策略属性"""
strategy = RotationStrategy()
assert strategy.name == "rotation"
assert strategy.select_num == 3
assert strategy.stoploss == -0.05
def test_config_override(self):
"""测试配置覆盖"""
config = {
'params': {
'select_num': 5,
'stoploss': -0.08
}
}
strategy = RotationStrategy(config=config)
assert strategy.select_num == 5
assert strategy.stoploss == -0.08
def test_factor_initialization(self):
"""测试因子初始化"""
strategy = RotationStrategy()
factors = strategy.init_factors()
assert factors is not None
assert len(factors.factors) == 1
def test_signal_generator_initialization(self):
"""测试信号生成器初始化"""
strategy = RotationStrategy()
signal_gen = strategy.init_signal_generator()
assert signal_gen is not None
assert signal_gen.select_num == 3
def test_strategy_run(self):
"""测试策略运行"""
strategy = RotationStrategy()
# 生成测试数据(需要'close'列)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'code1': np.random.randn(100).cumsum() + 100,
'code2': np.random.randn(100).cumsum() + 100,
'code3': np.random.randn(100).cumsum() + 100,
}, index=dates)
result = strategy.run(data)
assert 'signal' in result.columns
assert len(result) == len(data)
class TestRotationStrategy:
"""测试轮动策略"""
def test_before_entry_callback(self):
"""测试入场前回调"""
strategy = RotationStrategy()
# 溢价正常,允许入场
result = strategy.before_entry('code1', 100.0, premium=0.05)
assert result == True
# 溢价过高,拒绝入场
result = strategy.before_entry('code1', 100.0, premium=0.15)
assert result == False
def test_dynamic_stoploss_callback(self):
"""测试动态止损回调"""
strategy = RotationStrategy()
# 持仓5天
position = Position(
code='code1',
entry_price=100.0,
entry_date=datetime(2020, 1, 1),
current_price=95.0,
current_date=datetime(2020, 1, 6), # 5天后
quantity=100,
weight=0.33
)
stoploss = strategy.dynamic_stoploss(position)
# holding_days=5返回-0.05
assert stoploss == -0.05
class TestConfigLoader:
"""测试配置加载器"""
def test_load_from_yaml_string(self):
"""测试从YAML字符串加载"""
yaml_str = """
strategy:
name: test_strategy
version: 1
factors:
- name: momentum
weight: 1.0
params:
n_days: 25
signal:
mode: top_n
select_num: 3
params:
stoploss: -0.05
"""
config = ConfigLoader.from_yaml(yaml_str)
assert config['strategy']['name'] == 'test_strategy'
assert config['factors'][0]['name'] == 'momentum'
assert config['signal']['select_num'] == 3
if __name__ == '__main__':
pytest.main([__file__, '-v'])