refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer referenced by the active core (datasource/ and rotation/): - framework/ -> archive/framework/ - framework_v2/ -> archive/framework_v2/ - strategies/ -> archive/strategies/ - config/ -> archive/config/ - visualization/ -> archive/visualization/ - scripts/ -> archive/scripts/ - tests/ -> archive/tests/ - run_rotation.py, run_us_rotation.py -> archive/single_files/ - compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
174
archive/framework/tests/test_execution.py
Normal file
174
archive/framework/tests/test_execution.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
执行层测试
|
||||
|
||||
测试Portfolio、Executor抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
class TestPortfolio:
|
||||
"""测试Portfolio"""
|
||||
|
||||
def test_portfolio_init(self):
|
||||
"""测试初始化"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
assert portfolio.initial_capital == 100000
|
||||
assert portfolio.cash == 100000
|
||||
assert len(portfolio.positions) == 0
|
||||
|
||||
def test_add_position(self):
|
||||
"""测试添加持仓"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position(
|
||||
code='AAPL',
|
||||
price=100.0,
|
||||
quantity=10,
|
||||
time=datetime.now()
|
||||
)
|
||||
|
||||
assert len(portfolio.positions) == 1
|
||||
assert 'AAPL' in portfolio.positions
|
||||
assert portfolio.cash == 99000 # 100000 - 100*10
|
||||
|
||||
def test_remove_position(self):
|
||||
"""测试移除持仓"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
profit = portfolio.remove_position('AAPL', 110.0, datetime.now())
|
||||
|
||||
assert len(portfolio.positions) == 0
|
||||
assert profit == 100.0 # (110-100)*10
|
||||
assert portfolio.cash == 100100 # 99000 + 110*10
|
||||
|
||||
def test_update_prices(self):
|
||||
"""测试更新价格"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
portfolio.add_position('MSFT', 200.0, 5, datetime.now())
|
||||
|
||||
portfolio.update_prices({'AAPL': 110.0, 'MSFT': 220.0})
|
||||
|
||||
assert portfolio.positions['AAPL'].current_price == 110.0
|
||||
assert portfolio.positions['MSFT'].current_price == 220.0
|
||||
|
||||
def test_get_net_value(self):
|
||||
"""测试净值计算"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
portfolio.update_prices({'AAPL': 110.0})
|
||||
|
||||
net_value = portfolio.get_net_value()
|
||||
expected = 99000 + 110 * 10 # cash + position_value
|
||||
assert net_value == expected
|
||||
|
||||
def test_get_weight(self):
|
||||
"""测试权重计算"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.add_position('AAPL', 100.0, 100, datetime.now())
|
||||
portfolio.update_prices({'AAPL': 110.0})
|
||||
|
||||
weight = portfolio.get_weight('AAPL')
|
||||
# 持仓价值=110*100=11000,净值=90000+11000=101000
|
||||
expected_weight = 11000 / 101000
|
||||
assert abs(weight - expected_weight) < 0.01
|
||||
|
||||
def test_record_net_value(self):
|
||||
"""测试净值记录"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
|
||||
portfolio.record_net_value()
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
portfolio.record_net_value()
|
||||
|
||||
series = portfolio.get_net_value_series()
|
||||
assert len(series) == 2
|
||||
|
||||
def test_portfolio_repr(self):
|
||||
"""测试字符串表示"""
|
||||
portfolio = Portfolio(initial_capital=100000)
|
||||
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
|
||||
|
||||
repr_str = repr(portfolio)
|
||||
assert 'Portfolio' in repr_str
|
||||
assert 'positions=1' in repr_str
|
||||
|
||||
|
||||
class TestBacktestExecutor:
|
||||
"""测试回测执行器"""
|
||||
|
||||
def test_backtest_executor_init(self):
|
||||
"""测试初始化"""
|
||||
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
|
||||
|
||||
assert executor.initial_capital == 100000
|
||||
assert executor.trade_cost == 0.001
|
||||
assert executor.mode == "backtest"
|
||||
|
||||
def test_backtest_execute(self):
|
||||
"""测试回测执行"""
|
||||
executor = BacktestExecutor(initial_capital=100000)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
signals = pd.DataFrame({
|
||||
'signal': ['AAPL'] * 50
|
||||
}, index=dates)
|
||||
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(50).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
portfolio = executor.execute(signals, data)
|
||||
|
||||
assert isinstance(portfolio, Portfolio)
|
||||
|
||||
def test_backtest_executor_repr(self):
|
||||
"""测试字符串表示"""
|
||||
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
|
||||
repr_str = repr(executor)
|
||||
|
||||
assert 'BacktestExecutor' in repr_str
|
||||
|
||||
|
||||
class TestDryRunExecutor:
|
||||
"""测试DryRun执行器"""
|
||||
|
||||
def test_dry_run_executor_init(self):
|
||||
"""测试初始化"""
|
||||
executor = DryRunExecutor(verbose=True)
|
||||
|
||||
assert executor.verbose == True
|
||||
assert executor.mode == "dry_run"
|
||||
|
||||
def test_dry_run_execute(self):
|
||||
"""测试模拟执行"""
|
||||
executor = DryRunExecutor(verbose=False)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
signals = pd.DataFrame({
|
||||
'signal': ['AAPL,MSFT'] * 50
|
||||
}, index=dates)
|
||||
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(50).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
portfolio = executor.execute(signals, data)
|
||||
|
||||
assert isinstance(portfolio, Portfolio)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
288
archive/framework/tests/test_factors.py
Normal file
288
archive/framework/tests/test_factors.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
因子层测试
|
||||
|
||||
测试FactorBase、FactorRegistry、FactorCombiner抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from strategies.shared.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
|
||||
|
||||
|
||||
class TestFactorBase:
|
||||
"""测试FactorBase抽象基类"""
|
||||
|
||||
def test_factor_meta(self):
|
||||
"""测试因子元信息"""
|
||||
factor = MomentumFactor(n_days=25)
|
||||
assert factor.name == "momentum"
|
||||
assert factor.category == "momentum"
|
||||
|
||||
def test_factor_repr(self):
|
||||
"""测试因子字符串表示"""
|
||||
factor = MomentumFactor(n_days=30)
|
||||
repr_str = repr(factor)
|
||||
assert "MomentumFactor" in repr_str
|
||||
|
||||
def test_validate_data(self):
|
||||
"""测试数据验证"""
|
||||
factor = MomentumFactor(n_days=25)
|
||||
|
||||
# 数据充足
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
})
|
||||
assert factor.validate_data(data) == True
|
||||
|
||||
# 数据不足
|
||||
short_data = pd.DataFrame({
|
||||
'close': np.random.randn(10).cumsum() + 100
|
||||
})
|
||||
assert factor.validate_data(short_data) == False
|
||||
|
||||
|
||||
class TestFactorRegistry:
|
||||
"""测试因子注册器"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_register_factor(self):
|
||||
"""测试因子注册"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
assert 'momentum' in FactorRegistry.list_factors()
|
||||
|
||||
def test_get_factor(self):
|
||||
"""测试获取因子实例"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
factor = FactorRegistry.get('momentum', n_days=30)
|
||||
assert isinstance(factor, MomentumFactor)
|
||||
assert factor.n_days == 30
|
||||
|
||||
def test_get_unknown_factor(self):
|
||||
"""测试获取未注册因子"""
|
||||
with pytest.raises(ValueError):
|
||||
FactorRegistry.get('unknown_factor')
|
||||
|
||||
def test_get_category(self):
|
||||
"""测试获取因子类别"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
category = FactorRegistry.get_category('momentum')
|
||||
assert category == 'momentum'
|
||||
|
||||
|
||||
class TestFactorCombiner:
|
||||
"""测试因子组合器"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_combiner_init(self):
|
||||
"""测试组合器初始化"""
|
||||
factors = [
|
||||
MomentumFactor(n_days=25),
|
||||
TrendFactor(method='ma_cross')
|
||||
]
|
||||
combiner = FactorCombiner(factors, weights=[0.7, 0.3])
|
||||
|
||||
assert len(combiner.get_factor_names()) == 2
|
||||
|
||||
def test_combiner_equal_weights(self):
|
||||
"""测试等权组合"""
|
||||
factors = [
|
||||
MomentumFactor(n_days=25),
|
||||
TrendFactor()
|
||||
]
|
||||
combiner = FactorCombiner(factors) # 默认等权
|
||||
|
||||
# 权重应该归一化
|
||||
assert sum(combiner._weights) == 1.0
|
||||
|
||||
def test_combiner_compute(self):
|
||||
"""测试因子组合计算"""
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
factors = [
|
||||
MomentumFactor(n_days=20),
|
||||
TrendFactor(fast=5, slow=10)
|
||||
]
|
||||
combiner = FactorCombiner(factors, weights=[0.6, 0.4])
|
||||
|
||||
result = combiner.compute(data)
|
||||
|
||||
# 检查结果列
|
||||
assert 'momentum' in result.columns
|
||||
assert 'trend' in result.columns
|
||||
assert 'combined' in result.columns
|
||||
|
||||
def test_combiner_method_rank_average(self):
|
||||
"""测试rank_average组合方法"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
factors = [
|
||||
MomentumFactor(n_days=20),
|
||||
TrendFactor()
|
||||
]
|
||||
combiner = FactorCombiner(factors, method='rank_average')
|
||||
|
||||
result = combiner.compute(data)
|
||||
|
||||
# combined应该是排名平均值
|
||||
assert 'combined' in result.columns
|
||||
|
||||
|
||||
class TestMomentumFactor:
|
||||
"""测试动量因子"""
|
||||
|
||||
def test_momentum_compute(self):
|
||||
"""测试动量因子计算"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成上升趋势数据
|
||||
prices = 100 + np.arange(100) * 0.5
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, weighted=True)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 上升趋势应该有正的动量得分
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
def test_crash_filter(self):
|
||||
"""测试崩盘过滤"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成正常数据,然后在末尾添加崩盘
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
prices[-3:] = prices[-4] * np.array([0.96, 0.93, 0.90]) # 连续大跌
|
||||
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, crash_filter=True)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 崩盘后动量得分应该被清零
|
||||
assert values.iloc[-1] == 0.0
|
||||
|
||||
def test_simple_momentum(self):
|
||||
"""测试简单动量(无加权,无崩盘过滤)"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, weighted=False, crash_filter=False)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 简单动量应该是N日涨幅
|
||||
expected = data['close'].pct_change(25)
|
||||
# 验证长度一致
|
||||
assert len(values) == len(expected)
|
||||
|
||||
|
||||
class TestTrendFactor:
|
||||
"""测试趋势因子"""
|
||||
|
||||
def test_ma_cross(self):
|
||||
"""测试MA交叉趋势"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成上升趋势
|
||||
prices = 100 + np.arange(100) * 0.5
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = TrendFactor(method='ma_cross', fast=5, slow=20)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 上升趋势应该有正的趋势强度
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
def test_macd(self):
|
||||
"""测试MACD趋势"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = TrendFactor(method='macd')
|
||||
values = factor.compute(data)
|
||||
|
||||
# 检查计算结果
|
||||
assert len(values) == len(data)
|
||||
|
||||
|
||||
class TestReversalFactor:
|
||||
"""测试反转因子"""
|
||||
|
||||
def test_rsi_reversal(self):
|
||||
"""测试RSI反转信号"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成超买数据(持续上涨)
|
||||
prices = 100 + np.arange(100) * 1.0
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = ReversalFactor(method='rsi', period=14, overbought=70)
|
||||
values = factor.compute(data)
|
||||
|
||||
# RSI超过70应该产生负值(反转向下信号)
|
||||
assert values.iloc[-1] < 0
|
||||
|
||||
def test_rsi_oversold(self):
|
||||
"""测试RSI超卖信号"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成超卖数据(持续下跌)
|
||||
prices = 100 - np.arange(100) * 1.0
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = ReversalFactor(method='rsi', period=14, oversold=30)
|
||||
values = factor.compute(data)
|
||||
|
||||
# RSI低于30应该产生正值(反转向上信号)
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
|
||||
class TestVolatilityFactor:
|
||||
"""测试波动率因子"""
|
||||
|
||||
def test_std_volatility(self):
|
||||
"""测试标准差波动率"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = VolatilityFactor(method='std', period=20)
|
||||
values = factor.compute(data)
|
||||
|
||||
assert len(values) == len(data)
|
||||
|
||||
def test_atr_volatility(self):
|
||||
"""测试ATR波动率"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
factor = VolatilityFactor(method='atr', period=20)
|
||||
values = factor.compute(data)
|
||||
|
||||
assert len(values) == len(data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
166
archive/framework/tests/test_integration.py
Normal file
166
archive/framework/tests/test_integration.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
集成测试
|
||||
|
||||
测试框架与定制组件的集成
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.factors import FactorRegistry, FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.risk import CallbackHook, Position
|
||||
from framework.strategy import StrategyBase
|
||||
from framework.execution import Portfolio, BacktestExecutor
|
||||
|
||||
from strategies.shared.factors.momentum import MomentumFactor, VolatilityFactor
|
||||
from strategies.shared.signals.selectors import TopNSelector
|
||||
from strategies.shared.risk.controls import StopLossControl, premium_filter_callback
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
|
||||
class TestFactorIntegration:
|
||||
"""测试因子集成"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_register_and_use_custom_factor(self):
|
||||
"""测试注册并使用定制因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
|
||||
factor = FactorRegistry.get('momentum', n_days=25, crash_filter=True)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
values = factor.compute(data)
|
||||
|
||||
assert len(values) == len(data)
|
||||
|
||||
def test_combiner_with_custom_factors(self):
|
||||
"""测试组合器使用定制因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
FactorRegistry.register(VolatilityFactor)
|
||||
|
||||
momentum = FactorRegistry.get('momentum', n_days=25)
|
||||
volatility = FactorRegistry.get('volatility', method='std', period=20)
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
combiner = FactorCombiner([momentum, volatility], weights=[0.7, 0.3])
|
||||
result = combiner.compute(data)
|
||||
|
||||
assert 'momentum' in result.columns
|
||||
assert 'volatility' in result.columns
|
||||
assert 'combined' in result.columns
|
||||
|
||||
|
||||
class TestSignalIntegration:
|
||||
"""测试信号生成器集成"""
|
||||
|
||||
def test_custom_signal_generator_with_factors(self):
|
||||
"""测试定制信号生成器与因子集成"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
factor_data = pd.DataFrame({
|
||||
'momentum_A': np.random.randn(50),
|
||||
'momentum_B': np.random.randn(50),
|
||||
'momentum_C': np.random.randn(50),
|
||||
}, index=dates)
|
||||
|
||||
selector = TopNSelector(select_num=2, min_score=0.0)
|
||||
result = selector.generate(factor_data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestRiskIntegration:
|
||||
"""测试风控组件集成"""
|
||||
|
||||
def test_custom_risk_control(self):
|
||||
"""测试定制风控组件"""
|
||||
control = StopLossControl(threshold=-0.05, trailing=True)
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=pd.Timestamp.now()
|
||||
)
|
||||
|
||||
# 首次检查,设置最高价
|
||||
assert control.check(position) == True
|
||||
|
||||
# 价格下跌,触发跟踪止损
|
||||
position.current_price = 100.0
|
||||
assert control.check(position) == False
|
||||
|
||||
def test_callback_hook_with_custom_callback(self):
|
||||
"""测试回调钩子与定制回调集成"""
|
||||
hook = CallbackHook()
|
||||
|
||||
# 注册定制回调
|
||||
callback = premium_filter_callback(threshold=0.10)
|
||||
hook.register('before_entry', callback)
|
||||
|
||||
# 正常溢价通过
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 高溢价拒绝
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
|
||||
class TestStrategyIntegration:
|
||||
"""测试策略集成"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_rotation_strategy_full_flow(self):
|
||||
"""测试轮动策略完整流程"""
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
# 运行策略
|
||||
result = strategy.run(data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
|
||||
def test_strategy_with_backtest_executor(self):
|
||||
"""测试策略与回测执行器集成"""
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
signals = strategy.run(data)
|
||||
|
||||
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
|
||||
portfolio = executor.execute(signals, data)
|
||||
|
||||
assert isinstance(portfolio, Portfolio)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
325
archive/framework/tests/test_risk.py
Normal file
325
archive/framework/tests/test_risk.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
风控层测试
|
||||
|
||||
测试RiskControl、CallbackHook抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from framework.risk import RiskControl, CallbackHook, Position
|
||||
from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl
|
||||
|
||||
|
||||
class TestPosition:
|
||||
"""测试Position数据结构"""
|
||||
|
||||
def test_position_creation(self):
|
||||
"""测试持仓创建"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now(),
|
||||
quantity=10
|
||||
)
|
||||
|
||||
assert position.code == 'AAPL'
|
||||
assert position.entry_price == 100.0
|
||||
assert position.current_price == 105.0
|
||||
|
||||
def test_profit_ratio(self):
|
||||
"""测试盈亏比例计算"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
assert position.profit_ratio == 0.10
|
||||
|
||||
def test_profit_amount(self):
|
||||
"""测试盈亏金额计算"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now(),
|
||||
quantity=10
|
||||
)
|
||||
|
||||
assert position.profit_amount == 100.0
|
||||
|
||||
def test_position_repr(self):
|
||||
"""测试持仓字符串表示"""
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
repr_str = repr(position)
|
||||
assert 'AAPL' in repr_str
|
||||
assert '5.00%' in repr_str
|
||||
|
||||
|
||||
class TestStopLossControl:
|
||||
"""测试止损控制"""
|
||||
|
||||
def test_stop_loss_init(self):
|
||||
"""测试初始化"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
assert control.threshold == -0.05
|
||||
assert control.name == "stop_loss"
|
||||
|
||||
def test_stop_loss_check(self):
|
||||
"""测试止损检查"""
|
||||
control = StopLossControl(threshold=-0.05)
|
||||
|
||||
# 盈利持仓应该通过
|
||||
profit_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
assert control.check(profit_position) == True
|
||||
|
||||
# 亏损持仓应该触发止损
|
||||
loss_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=90.0, # 亏损10%
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
# 亏损超过阈值应该触发
|
||||
assert control.check(loss_position) == False
|
||||
|
||||
def test_trailing_stop_loss(self):
|
||||
"""测试跟踪止损"""
|
||||
control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03)
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
# 初始最高价=入场价100,当前价105
|
||||
# 更新后最高价=105
|
||||
control.check(position)
|
||||
|
||||
# 从最高价回撤不超过阈值应该通过
|
||||
position.current_price = 103.0 # 回撤约2%
|
||||
assert control.check(position) == True
|
||||
|
||||
# 回撤超过阈值应该触发
|
||||
position.current_price = 100.0 # 回撤约5%
|
||||
assert control.check(position) == False
|
||||
|
||||
|
||||
class TestPositionLimitControl:
|
||||
"""测试仓位限制控制"""
|
||||
|
||||
def test_position_limit_init(self):
|
||||
"""测试初始化"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
assert control.max_position == 0.33
|
||||
assert control.name == "position_limit"
|
||||
|
||||
def test_position_limit_check(self):
|
||||
"""测试仓位限制检查"""
|
||||
control = PositionLimitControl(max_position=0.33)
|
||||
|
||||
# 正常仓位应该通过
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now(),
|
||||
weight=0.20
|
||||
)
|
||||
assert control.check(normal_position) == True
|
||||
|
||||
# 超限仓位应该触发
|
||||
over_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now(),
|
||||
weight=0.50
|
||||
)
|
||||
assert control.check(over_position) == False
|
||||
|
||||
|
||||
class TestPremiumControl:
|
||||
"""测试溢价控制"""
|
||||
|
||||
def test_premium_control_init(self):
|
||||
"""测试初始化"""
|
||||
control = PremiumControl(threshold=0.10)
|
||||
assert control.threshold == 0.10
|
||||
assert control.name == "premium"
|
||||
|
||||
def test_premium_filter_mode(self):
|
||||
"""测试溢价过滤模式"""
|
||||
control = PremiumControl(threshold=0.10, mode='filter')
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
# 正常溢价应该通过
|
||||
assert control.check(position, premium=0.05) == True
|
||||
|
||||
# 高溢价应该被过滤
|
||||
assert control.check(position, premium=0.15) == False
|
||||
|
||||
|
||||
class TestCallbackHook:
|
||||
"""测试回调钩子"""
|
||||
|
||||
def test_callback_hook_init(self):
|
||||
"""测试初始化"""
|
||||
hook = CallbackHook()
|
||||
assert len(hook.list_hooks()) == 6
|
||||
|
||||
def test_register_callback(self):
|
||||
"""测试注册回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def my_callback(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
hook.register('before_entry', my_callback)
|
||||
|
||||
callbacks = hook.get_callbacks('before_entry')
|
||||
assert len(callbacks) == 1
|
||||
|
||||
def test_trigger_before_entry(self):
|
||||
"""测试触发入场前回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def always_pass(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
def always_block(code, price, **kwargs):
|
||||
return False
|
||||
|
||||
hook.register('before_entry', always_pass)
|
||||
hook.register('before_entry', always_block)
|
||||
|
||||
# before_entry要求所有回调返回True才允许
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0)
|
||||
assert result == False
|
||||
|
||||
def test_trigger_dynamic_stoploss(self):
|
||||
"""测试触发动态止损回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def stoploss_5p(position):
|
||||
return -0.05
|
||||
|
||||
def stoploss_3p(position):
|
||||
return -0.03
|
||||
|
||||
hook.register('dynamic_stoploss', stoploss_5p)
|
||||
hook.register('dynamic_stoploss', stoploss_3p)
|
||||
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
|
||||
# dynamic_stoploss返回最小止损值(最严格)
|
||||
result = hook.trigger('dynamic_stoploss', position)
|
||||
assert result == -0.05
|
||||
|
||||
def test_trigger_custom_exit(self):
|
||||
"""测试触发自定义出场回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def exit_on_loss(position):
|
||||
return position.profit_ratio < -0.05
|
||||
|
||||
def exit_on_profit(position):
|
||||
return position.profit_ratio > 0.20
|
||||
|
||||
hook.register('custom_exit', exit_on_loss)
|
||||
hook.register('custom_exit', exit_on_profit)
|
||||
|
||||
# custom_exit任一回调触发即可
|
||||
profit_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=125.0, # 盈利25%
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', profit_position)
|
||||
assert result == True
|
||||
|
||||
# 未触发
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=110.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', normal_position)
|
||||
assert result == False
|
||||
|
||||
def test_clear_hooks(self):
|
||||
"""测试清空回调"""
|
||||
hook = CallbackHook()
|
||||
|
||||
def callback(code, price, **kwargs):
|
||||
return True
|
||||
|
||||
hook.register('before_entry', callback)
|
||||
|
||||
# 清空特定钩子
|
||||
hook.clear('before_entry')
|
||||
assert len(hook.get_callbacks('before_entry')) == 0
|
||||
|
||||
# 清空所有钩子
|
||||
hook.register('before_entry', callback)
|
||||
hook.register('after_entry', callback)
|
||||
hook.clear()
|
||||
assert len(hook.get_callbacks('before_entry')) == 0
|
||||
assert len(hook.get_callbacks('after_entry')) == 0
|
||||
|
||||
def test_default_behavior(self):
|
||||
"""测试默认行为"""
|
||||
hook = CallbackHook() # 无注册回调
|
||||
|
||||
# before_entry默认允许
|
||||
result = hook.trigger('before_entry', 'AAPL', 100.0)
|
||||
assert result == True
|
||||
|
||||
# dynamic_stoploss默认值
|
||||
result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05)
|
||||
assert result == -0.05
|
||||
|
||||
# custom_exit默认不出场
|
||||
position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=105.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = hook.trigger('custom_exit', position)
|
||||
assert result == False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
163
archive/framework/tests/test_signals.py
Normal file
163
archive/framework/tests/test_signals.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
信号层测试
|
||||
|
||||
测试SignalGenerator抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.signals import SignalGenerator
|
||||
from strategies.shared.signals.selectors import TopNSelector, TrendFollower, ReversalTrader
|
||||
|
||||
|
||||
class TestTopNSelector:
|
||||
"""测试TopNSelector"""
|
||||
|
||||
def test_top_n_selector_init(self):
|
||||
"""测试初始化"""
|
||||
selector = TopNSelector(select_num=3)
|
||||
assert selector.select_num == 3
|
||||
assert selector.mode == "top_n"
|
||||
|
||||
def test_top_n_selection(self):
|
||||
"""测试Top N选股"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成因子数据
|
||||
data = pd.DataFrame({
|
||||
'factor_A': np.random.randn(50),
|
||||
'factor_B': np.random.randn(50),
|
||||
'factor_C': np.random.randn(50),
|
||||
}, index=dates)
|
||||
|
||||
selector = TopNSelector(select_num=2)
|
||||
result = selector.generate(data)
|
||||
|
||||
# 检查结果列
|
||||
assert 'signal' in result.columns
|
||||
assert 'signal_raw' in result.columns
|
||||
|
||||
# 检查T+1移位(signal比signal_raw滞后1天)
|
||||
assert result['signal'].iloc[0] == '' or pd.isna(result['signal'].iloc[0])
|
||||
|
||||
def test_min_score_filter(self):
|
||||
"""测试最小得分过滤"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成因子数据,部分为负值
|
||||
data = pd.DataFrame({
|
||||
'factor_A': [0.1] * 50,
|
||||
'factor_B': [-0.1] * 50, # 负分
|
||||
'factor_C': [0.2] * 50,
|
||||
}, index=dates)
|
||||
|
||||
selector = TopNSelector(select_num=2, min_score=0.0)
|
||||
result = selector.generate(data)
|
||||
|
||||
# 负分因子应该被过滤
|
||||
signals = result['signal_raw'].dropna().unique()
|
||||
for sig in signals:
|
||||
if sig:
|
||||
codes = sig.split(',')
|
||||
assert 'factor_B' not in codes
|
||||
|
||||
def test_grouped_selection(self):
|
||||
"""测试分组选股"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成因子数据和分组信息
|
||||
data = pd.DataFrame({
|
||||
'factor_A': [0.1] * 50,
|
||||
'factor_B': [0.2] * 50,
|
||||
'factor_C': [0.15] * 50,
|
||||
}, index=dates)
|
||||
|
||||
# 分组信息(模拟)
|
||||
# 注:实际使用需要在数据中包含group_info列
|
||||
|
||||
selector = TopNSelector(select_num=2, group_by='market')
|
||||
result = selector.generate(data)
|
||||
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestTrendFollower:
|
||||
"""测试趋势跟随器"""
|
||||
|
||||
def test_trend_follower_init(self):
|
||||
"""测试初始化"""
|
||||
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
|
||||
assert follower.entry_threshold == 0.02
|
||||
assert follower.exit_threshold == -0.02
|
||||
assert follower.mode == "trend"
|
||||
|
||||
def test_trend_signal_generation(self):
|
||||
"""测试趋势信号生成"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成趋势因子数据
|
||||
data = pd.DataFrame({
|
||||
'trend_A': [0.05] * 50, # 强趋势
|
||||
'trend_B': [-0.05] * 50, # 弱趋势
|
||||
}, index=dates)
|
||||
|
||||
follower = TrendFollower(entry_threshold=0.02)
|
||||
result = follower.generate(data)
|
||||
|
||||
# 检查信号列
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestReversalTrader:
|
||||
"""测试反转交易器"""
|
||||
|
||||
def test_reversal_trader_init(self):
|
||||
"""测试初始化"""
|
||||
trader = ReversalTrader(overbought=70, oversold=30)
|
||||
assert trader.overbought == 70
|
||||
assert trader.oversold == 30
|
||||
assert trader.mode == "reversal"
|
||||
|
||||
def test_reversal_signal_generation(self):
|
||||
"""测试反转信号生成"""
|
||||
dates = pd.date_range('2020-01-01', periods=50)
|
||||
|
||||
# 生成反转因子数据
|
||||
data = pd.DataFrame({
|
||||
'reversal_A': [0.15] * 50, # 超卖反转
|
||||
'reversal_B': [-0.15] * 50, # 超买反转
|
||||
}, index=dates)
|
||||
|
||||
trader = ReversalTrader(reversal_threshold=0.1)
|
||||
result = trader.generate(data)
|
||||
|
||||
# 检查信号列
|
||||
assert 'signal' in result.columns
|
||||
|
||||
|
||||
class TestSignalGeneratorBase:
|
||||
"""测试SignalGenerator抽象基类"""
|
||||
|
||||
def test_validate_factor_data(self):
|
||||
"""测试数据验证"""
|
||||
selector = TopNSelector(select_num=3)
|
||||
|
||||
# 空数据应该返回False
|
||||
empty_data = pd.DataFrame()
|
||||
assert selector.validate_factor_data(empty_data) == False
|
||||
|
||||
# 有效数据应该返回True
|
||||
valid_data = pd.DataFrame({'factor_A': [1, 2, 3]})
|
||||
assert selector.validate_factor_data(valid_data) == True
|
||||
|
||||
def test_repr(self):
|
||||
"""测试字符串表示"""
|
||||
selector = TopNSelector(select_num=3, min_score=0.5)
|
||||
repr_str = repr(selector)
|
||||
assert "TopNSelector" in repr_str
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
194
archive/framework/tests/test_strategy.py
Normal file
194
archive/framework/tests/test_strategy.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
策略层测试
|
||||
|
||||
测试StrategyBase抽象接口
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from framework.strategy import StrategyBase
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.signals import SignalGenerator
|
||||
from framework.risk import Position
|
||||
|
||||
|
||||
class TestStrategyBase:
|
||||
"""测试StrategyBase抽象基类"""
|
||||
|
||||
def test_strategy_config_override(self):
|
||||
"""测试配置覆盖类属性"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy(config={'select_num': 5, 'stoploss': -0.03})
|
||||
|
||||
assert strategy.select_num == 5
|
||||
assert strategy.stoploss == -0.03
|
||||
|
||||
def test_strategy_default_values(self):
|
||||
"""测试默认值"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy()
|
||||
|
||||
assert strategy.select_num == 3
|
||||
assert strategy.stoploss == -0.05
|
||||
|
||||
def test_strategy_repr(self):
|
||||
"""测试字符串表示"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy()
|
||||
repr_str = repr(strategy)
|
||||
|
||||
assert 'RotationStrategy' in repr_str
|
||||
assert 'rotation' in repr_str
|
||||
|
||||
def test_strategy_interface_version(self):
|
||||
"""测试接口版本"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
strategy = RotationStrategy()
|
||||
assert strategy.INTERFACE_VERSION == 1
|
||||
|
||||
|
||||
class TestRotationStrategy:
|
||||
"""测试轮动策略"""
|
||||
|
||||
def test_rotation_strategy_init(self):
|
||||
"""测试初始化"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 检查因子初始化
|
||||
assert strategy._factors is not None
|
||||
assert strategy._signal_gen is not None
|
||||
|
||||
def test_rotation_strategy_run(self):
|
||||
"""测试策略运行"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
result = strategy.run(data)
|
||||
|
||||
# 检查结果
|
||||
assert 'signal' in result.columns
|
||||
|
||||
def test_dynamic_stoploss(self):
|
||||
"""测试动态止损"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 测试不同持仓时间
|
||||
position_5days = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=95.0,
|
||||
entry_time=datetime.now() - pd.Timedelta(days=5)
|
||||
)
|
||||
|
||||
# 5天持仓止损阈值应该为-0.05
|
||||
stoploss = strategy.dynamic_stoploss(position_5days)
|
||||
assert stoploss == -0.05
|
||||
|
||||
def test_before_entry_premium_filter(self):
|
||||
"""测试入场前溢价过滤"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 正常溢价应该通过
|
||||
result = strategy.before_entry('AAPL', 100.0, premium=0.05)
|
||||
assert result == True
|
||||
|
||||
# 高溢价应该被拒绝
|
||||
result = strategy.before_entry('AAPL', 100.0, premium=0.15)
|
||||
assert result == False
|
||||
|
||||
def test_custom_exit(self):
|
||||
"""测试自定义出场"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 正常盈亏不触发
|
||||
normal_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=95.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = strategy.custom_exit(normal_position)
|
||||
assert result == False
|
||||
|
||||
# 大亏损触发出场
|
||||
loss_position = Position(
|
||||
code='AAPL',
|
||||
entry_price=100.0,
|
||||
current_price=85.0,
|
||||
entry_time=datetime.now()
|
||||
)
|
||||
result = strategy.custom_exit(loss_position)
|
||||
assert result == True
|
||||
|
||||
|
||||
class TestStrategyCallbacks:
|
||||
"""测试策略回调机制"""
|
||||
|
||||
def test_callback_registration(self):
|
||||
"""测试回调自动注册"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 检查回调是否注册
|
||||
callbacks = strategy._callbacks.get_callbacks('before_entry')
|
||||
assert len(callbacks) > 0
|
||||
|
||||
callbacks = strategy._callbacks.get_callbacks('dynamic_stoploss')
|
||||
assert len(callbacks) > 0
|
||||
|
||||
def test_callback_trigger_in_run(self):
|
||||
"""测试回调在策略运行中触发"""
|
||||
from strategies.rotation.strategy import RotationStrategy
|
||||
|
||||
FactorRegistry.clear()
|
||||
strategy = RotationStrategy()
|
||||
|
||||
# 添加自定义回调
|
||||
call_count = {'count': 0}
|
||||
|
||||
def counting_callback(code, price, **kwargs):
|
||||
call_count['count'] += 1
|
||||
return True
|
||||
|
||||
strategy._callbacks.register('before_entry', counting_callback)
|
||||
|
||||
# 运行策略
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
strategy.run(data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user