test: 更新测试以验证框架重构正确性
- 测试文件改用strategies.shared的具体实现 - 新增framework_comparison_test.py对比新旧实现结果 - 因子计算相关系数达到1.0000,差异为0.000000 - 79个单元测试全部通过
This commit is contained in:
@@ -1,139 +1,193 @@
|
||||
"""
|
||||
策略与配置层测试
|
||||
策略层测试
|
||||
|
||||
测试StrategyBase抽象接口
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from framework.strategy import StrategyBase, ConfigLoader, RotationStrategy
|
||||
from framework.factors import FactorRegistry
|
||||
from framework.factors.momentum import MomentumFactor
|
||||
from framework.strategy import StrategyBase
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
class TestStrategyBase:
|
||||
"""测试策略基类"""
|
||||
"""测试StrategyBase抽象基类"""
|
||||
|
||||
def test_strategy_attributes(self):
|
||||
"""测试策略属性"""
|
||||
def test_strategy_config_override(self):
|
||||
"""测试配置覆盖类属性"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy(config={'select_num': 5, 'stoploss': -0.03})
|
||||
|
||||
assert strategy.select_num == 5
|
||||
assert strategy.stoploss == -0.03
|
||||
|
||||
def test_strategy_default_values(self):
|
||||
"""测试默认值"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy()
|
||||
assert strategy.name == "rotation"
|
||||
|
||||
assert strategy.select_num == 3
|
||||
assert strategy.stoploss == -0.05
|
||||
|
||||
def test_config_override(self):
|
||||
"""测试配置覆盖"""
|
||||
config = {
|
||||
'params': {
|
||||
'select_num': 5,
|
||||
'stoploss': -0.08
|
||||
}
|
||||
}
|
||||
def test_strategy_repr(self):
|
||||
"""测试字符串表示"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy(config=config)
|
||||
assert strategy.select_num == 5
|
||||
assert strategy.stoploss == -0.08
|
||||
|
||||
def test_factor_initialization(self):
|
||||
"""测试因子初始化"""
|
||||
strategy = RotationStrategy()
|
||||
factors = strategy.init_factors()
|
||||
repr_str = repr(strategy)
|
||||
|
||||
assert factors is not None
|
||||
assert len(factors.factors) == 1
|
||||
assert 'RotationStrategy' in repr_str
|
||||
assert 'rotation' in repr_str
|
||||
|
||||
def test_signal_generator_initialization(self):
|
||||
"""测试信号生成器初始化"""
|
||||
def test_strategy_interface_version(self):
|
||||
"""测试接口版本"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy()
|
||||
signal_gen = strategy.init_signal_generator()
|
||||
|
||||
assert signal_gen is not None
|
||||
assert signal_gen.select_num == 3
|
||||
|
||||
def test_strategy_run(self):
|
||||
"""测试策略运行"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据(需要'close'列)
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'code1': np.random.randn(100).cumsum() + 100,
|
||||
'code2': np.random.randn(100).cumsum() + 100,
|
||||
'code3': np.random.randn(100).cumsum() + 100,
|
||||
}, index=dates)
|
||||
|
||||
result = strategy.run(data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
assert len(result) == len(data)
|
||||
assert strategy.INTERFACE_VERSION == 1
|
||||
|
||||
|
||||
class TestRotationStrategy:
|
||||
"""测试轮动策略"""
|
||||
|
||||
def test_before_entry_callback(self):
|
||||
"""测试入场前回调"""
|
||||
def test_rotation_strategy_init(self):
|
||||
"""测试初始化"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 溢价正常,允许入场
|
||||
result = strategy.before_entry('code1', 100.0, premium=0.05)
|
||||
# 检查因子初始化
|
||||
assert strategy._factors is not None
|
||||
assert strategy._signal_gen is not None
|
||||
|
||||
def test_rotation_strategy_run(self):
|
||||
"""测试策略运行"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
result = strategy.run(data)
|
||||
|
||||
# 检查结果
|
||||
assert 'signal' in result.columns
|
||||
|
||||
def test_dynamic_stoploss(self):
|
||||
"""测试动态止损"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 测试不同持仓时间
|
||||
position_5days = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=95.0,
|
||||
entry_time=datetime.now() - pd.Timedelta(days=5)
|
||||
)
|
||||
|
||||
# 5天持仓止损阈值应该为-0.05
|
||||
stoploss = strategy.dynamic_stoploss(position_5days)
|
||||
assert stoploss == -0.05
|
||||
|
||||
def test_before_entry_premium_filter(self):
|
||||
"""测试入场前溢价过滤"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 正常溢价应该通过
|
||||
result = strategy.before_entry('AAPL', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 溢价过高,拒绝入场
|
||||
result = strategy.before_entry('code1', 100.0, premium=0.15)
|
||||
# 高溢价应该被拒绝
|
||||
result = strategy.before_entry('AAPL', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
def test_dynamic_stoploss_callback(self):
|
||||
"""测试动态止损回调"""
|
||||
def test_custom_exit(self):
|
||||
"""测试自定义出场"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 持仓5天
|
||||
position = Position(
|
||||
code='code1',
|
||||
# 正常盈亏不触发
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
entry_date=datetime(2020, 1, 1),
|
||||
current_price=95.0,
|
||||
current_date=datetime(2020, 1, 6), # 5天后
|
||||
quantity=100,
|
||||
weight=0.33
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
stoploss = strategy.dynamic_stoploss(position)
|
||||
# holding_days=5,返回-0.05
|
||||
assert stoploss == -0.05
|
||||
result = strategy.custom_exit(normal_position)
|
||||
assert result == False
|
||||
|
||||
# 大亏损触发出场
|
||||
loss_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=85.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = strategy.custom_exit(loss_position)
|
||||
assert result == True
|
||||
|
||||
|
||||
class TestConfigLoader:
|
||||
"""测试配置加载器"""
|
||||
class TestStrategyCallbacks:
|
||||
"""测试策略回调机制"""
|
||||
|
||||
def test_load_from_yaml_string(self):
|
||||
"""测试从YAML字符串加载"""
|
||||
yaml_str = """
|
||||
strategy:
|
||||
name: test_strategy
|
||||
version: 1
|
||||
|
||||
factors:
|
||||
- name: momentum
|
||||
weight: 1.0
|
||||
params:
|
||||
n_days: 25
|
||||
|
||||
signal:
|
||||
mode: top_n
|
||||
select_num: 3
|
||||
|
||||
params:
|
||||
stoploss: -0.05
|
||||
"""
|
||||
def test_callback_registration(self):
|
||||
"""测试回调自动注册"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
config = ConfigLoader.from_yaml(yaml_str)
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
assert config['strategy']['name'] == 'test_strategy'
|
||||
assert config['factors'][0]['name'] == 'momentum'
|
||||
assert config['signal']['select_num'] == 3
|
||||
# 检查回调是否注册
|
||||
callbacks = strategy._callbacks.get_callbacks('before_entry')
|
||||
assert len(callbacks) > 0
|
||||
|
||||
callbacks = strategy._callbacks.get_callbacks('dynamic_stoploss')
|
||||
assert len(callbacks) > 0
|
||||
|
||||
def test_callback_trigger_in_run(self):
|
||||
"""测试回调在策略运行中触发"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 添加自定义回调
|
||||
call_count = {'count': 0}
|
||||
|
||||
def counting_callback(code, price, **kwargs):
|
||||
call_count['count'] += 1
|
||||
return True
|
||||
|
||||
strategy._callbacks.register('before_entry', counting_callback)
|
||||
|
||||
# 运行策略
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
strategy.run(data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user