refactor(archive): move unused modules to archive/

Archive legacy framework and utility modules that are no longer
referenced by the active core (datasource/ and rotation/):

- framework/ -> archive/framework/
- framework_v2/ -> archive/framework_v2/
- strategies/ -> archive/strategies/
- config/ -> archive/config/
- visualization/ -> archive/visualization/
- scripts/ -> archive/scripts/
- tests/ -> archive/tests/
- run_rotation.py, run_us_rotation.py -> archive/single_files/
- compare_*.py, test_api_dates.py -> archive/single_files/
This commit is contained in:
2026-06-03 23:41:46 +08:00
parent d700bc1dfd
commit c905230a40
98 changed files with 0 additions and 714 deletions

View File

@@ -0,0 +1,174 @@
"""
执行层测试
测试Portfolio、Executor抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
from framework.risk import Position
class TestPortfolio:
"""测试Portfolio"""
def test_portfolio_init(self):
"""测试初始化"""
portfolio = Portfolio(initial_capital=100000)
assert portfolio.initial_capital == 100000
assert portfolio.cash == 100000
assert len(portfolio.positions) == 0
def test_add_position(self):
"""测试添加持仓"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position(
code='AAPL',
price=100.0,
quantity=10,
time=datetime.now()
)
assert len(portfolio.positions) == 1
assert 'AAPL' in portfolio.positions
assert portfolio.cash == 99000 # 100000 - 100*10
def test_remove_position(self):
"""测试移除持仓"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
profit = portfolio.remove_position('AAPL', 110.0, datetime.now())
assert len(portfolio.positions) == 0
assert profit == 100.0 # (110-100)*10
assert portfolio.cash == 100100 # 99000 + 110*10
def test_update_prices(self):
"""测试更新价格"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
portfolio.add_position('MSFT', 200.0, 5, datetime.now())
portfolio.update_prices({'AAPL': 110.0, 'MSFT': 220.0})
assert portfolio.positions['AAPL'].current_price == 110.0
assert portfolio.positions['MSFT'].current_price == 220.0
def test_get_net_value(self):
"""测试净值计算"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
portfolio.update_prices({'AAPL': 110.0})
net_value = portfolio.get_net_value()
expected = 99000 + 110 * 10 # cash + position_value
assert net_value == expected
def test_get_weight(self):
"""测试权重计算"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 100, datetime.now())
portfolio.update_prices({'AAPL': 110.0})
weight = portfolio.get_weight('AAPL')
# 持仓价值=110*100=11000净值=90000+11000=101000
expected_weight = 11000 / 101000
assert abs(weight - expected_weight) < 0.01
def test_record_net_value(self):
"""测试净值记录"""
portfolio = Portfolio(initial_capital=100000)
portfolio.record_net_value()
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
portfolio.record_net_value()
series = portfolio.get_net_value_series()
assert len(series) == 2
def test_portfolio_repr(self):
"""测试字符串表示"""
portfolio = Portfolio(initial_capital=100000)
portfolio.add_position('AAPL', 100.0, 10, datetime.now())
repr_str = repr(portfolio)
assert 'Portfolio' in repr_str
assert 'positions=1' in repr_str
class TestBacktestExecutor:
"""测试回测执行器"""
def test_backtest_executor_init(self):
"""测试初始化"""
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
assert executor.initial_capital == 100000
assert executor.trade_cost == 0.001
assert executor.mode == "backtest"
def test_backtest_execute(self):
"""测试回测执行"""
executor = BacktestExecutor(initial_capital=100000)
dates = pd.date_range('2020-01-01', periods=50)
signals = pd.DataFrame({
'signal': ['AAPL'] * 50
}, index=dates)
data = pd.DataFrame({
'close': np.random.randn(50).cumsum() + 100
}, index=dates)
portfolio = executor.execute(signals, data)
assert isinstance(portfolio, Portfolio)
def test_backtest_executor_repr(self):
"""测试字符串表示"""
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
repr_str = repr(executor)
assert 'BacktestExecutor' in repr_str
class TestDryRunExecutor:
"""测试DryRun执行器"""
def test_dry_run_executor_init(self):
"""测试初始化"""
executor = DryRunExecutor(verbose=True)
assert executor.verbose == True
assert executor.mode == "dry_run"
def test_dry_run_execute(self):
"""测试模拟执行"""
executor = DryRunExecutor(verbose=False)
dates = pd.date_range('2020-01-01', periods=50)
signals = pd.DataFrame({
'signal': ['AAPL,MSFT'] * 50
}, index=dates)
data = pd.DataFrame({
'close': np.random.randn(50).cumsum() + 100
}, index=dates)
portfolio = executor.execute(signals, data)
assert isinstance(portfolio, Portfolio)
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,288 @@
"""
因子层测试
测试FactorBase、FactorRegistry、FactorCombiner抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from strategies.shared.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
class TestFactorBase:
"""测试FactorBase抽象基类"""
def test_factor_meta(self):
"""测试因子元信息"""
factor = MomentumFactor(n_days=25)
assert factor.name == "momentum"
assert factor.category == "momentum"
def test_factor_repr(self):
"""测试因子字符串表示"""
factor = MomentumFactor(n_days=30)
repr_str = repr(factor)
assert "MomentumFactor" in repr_str
def test_validate_data(self):
"""测试数据验证"""
factor = MomentumFactor(n_days=25)
# 数据充足
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
})
assert factor.validate_data(data) == True
# 数据不足
short_data = pd.DataFrame({
'close': np.random.randn(10).cumsum() + 100
})
assert factor.validate_data(short_data) == False
class TestFactorRegistry:
"""测试因子注册器"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_register_factor(self):
"""测试因子注册"""
FactorRegistry.register(MomentumFactor)
assert 'momentum' in FactorRegistry.list_factors()
def test_get_factor(self):
"""测试获取因子实例"""
FactorRegistry.register(MomentumFactor)
factor = FactorRegistry.get('momentum', n_days=30)
assert isinstance(factor, MomentumFactor)
assert factor.n_days == 30
def test_get_unknown_factor(self):
"""测试获取未注册因子"""
with pytest.raises(ValueError):
FactorRegistry.get('unknown_factor')
def test_get_category(self):
"""测试获取因子类别"""
FactorRegistry.register(MomentumFactor)
category = FactorRegistry.get_category('momentum')
assert category == 'momentum'
class TestFactorCombiner:
"""测试因子组合器"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_combiner_init(self):
"""测试组合器初始化"""
factors = [
MomentumFactor(n_days=25),
TrendFactor(method='ma_cross')
]
combiner = FactorCombiner(factors, weights=[0.7, 0.3])
assert len(combiner.get_factor_names()) == 2
def test_combiner_equal_weights(self):
"""测试等权组合"""
factors = [
MomentumFactor(n_days=25),
TrendFactor()
]
combiner = FactorCombiner(factors) # 默认等权
# 权重应该归一化
assert sum(combiner._weights) == 1.0
def test_combiner_compute(self):
"""测试因子组合计算"""
# 生成测试数据
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'high': np.random.randn(100).cumsum() + 105,
'low': np.random.randn(100).cumsum() + 95
}, index=dates)
factors = [
MomentumFactor(n_days=20),
TrendFactor(fast=5, slow=10)
]
combiner = FactorCombiner(factors, weights=[0.6, 0.4])
result = combiner.compute(data)
# 检查结果列
assert 'momentum' in result.columns
assert 'trend' in result.columns
assert 'combined' in result.columns
def test_combiner_method_rank_average(self):
"""测试rank_average组合方法"""
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
factors = [
MomentumFactor(n_days=20),
TrendFactor()
]
combiner = FactorCombiner(factors, method='rank_average')
result = combiner.compute(data)
# combined应该是排名平均值
assert 'combined' in result.columns
class TestMomentumFactor:
"""测试动量因子"""
def test_momentum_compute(self):
"""测试动量因子计算"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成上升趋势数据
prices = 100 + np.arange(100) * 0.5
data = pd.DataFrame({'close': prices}, index=dates)
factor = MomentumFactor(n_days=25, weighted=True)
values = factor.compute(data)
# 上升趋势应该有正的动量得分
assert values.iloc[-1] > 0
def test_crash_filter(self):
"""测试崩盘过滤"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成正常数据,然后在末尾添加崩盘
prices = 100 + np.random.randn(100).cumsum()
prices[-3:] = prices[-4] * np.array([0.96, 0.93, 0.90]) # 连续大跌
data = pd.DataFrame({'close': prices}, index=dates)
factor = MomentumFactor(n_days=25, crash_filter=True)
values = factor.compute(data)
# 崩盘后动量得分应该被清零
assert values.iloc[-1] == 0.0
def test_simple_momentum(self):
"""测试简单动量(无加权,无崩盘过滤)"""
dates = pd.date_range('2020-01-01', periods=100)
prices = 100 + np.random.randn(100).cumsum()
data = pd.DataFrame({'close': prices}, index=dates)
factor = MomentumFactor(n_days=25, weighted=False, crash_filter=False)
values = factor.compute(data)
# 简单动量应该是N日涨幅
expected = data['close'].pct_change(25)
# 验证长度一致
assert len(values) == len(expected)
class TestTrendFactor:
"""测试趋势因子"""
def test_ma_cross(self):
"""测试MA交叉趋势"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成上升趋势
prices = 100 + np.arange(100) * 0.5
data = pd.DataFrame({'close': prices}, index=dates)
factor = TrendFactor(method='ma_cross', fast=5, slow=20)
values = factor.compute(data)
# 上升趋势应该有正的趋势强度
assert values.iloc[-1] > 0
def test_macd(self):
"""测试MACD趋势"""
dates = pd.date_range('2020-01-01', periods=100)
prices = 100 + np.random.randn(100).cumsum()
data = pd.DataFrame({'close': prices}, index=dates)
factor = TrendFactor(method='macd')
values = factor.compute(data)
# 检查计算结果
assert len(values) == len(data)
class TestReversalFactor:
"""测试反转因子"""
def test_rsi_reversal(self):
"""测试RSI反转信号"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成超买数据(持续上涨)
prices = 100 + np.arange(100) * 1.0
data = pd.DataFrame({'close': prices}, index=dates)
factor = ReversalFactor(method='rsi', period=14, overbought=70)
values = factor.compute(data)
# RSI超过70应该产生负值反转向下信号
assert values.iloc[-1] < 0
def test_rsi_oversold(self):
"""测试RSI超卖信号"""
dates = pd.date_range('2020-01-01', periods=100)
# 生成超卖数据(持续下跌)
prices = 100 - np.arange(100) * 1.0
data = pd.DataFrame({'close': prices}, index=dates)
factor = ReversalFactor(method='rsi', period=14, oversold=30)
values = factor.compute(data)
# RSI低于30应该产生正值反转向上信号
assert values.iloc[-1] > 0
class TestVolatilityFactor:
"""测试波动率因子"""
def test_std_volatility(self):
"""测试标准差波动率"""
dates = pd.date_range('2020-01-01', periods=100)
prices = 100 + np.random.randn(100).cumsum()
data = pd.DataFrame({'close': prices}, index=dates)
factor = VolatilityFactor(method='std', period=20)
values = factor.compute(data)
assert len(values) == len(data)
def test_atr_volatility(self):
"""测试ATR波动率"""
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'high': np.random.randn(100).cumsum() + 105,
'low': np.random.randn(100).cumsum() + 95
}, index=dates)
factor = VolatilityFactor(method='atr', period=20)
values = factor.compute(data)
assert len(values) == len(data)
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,166 @@
"""
集成测试
测试框架与定制组件的集成
"""
import pandas as pd
import numpy as np
import pytest
from framework.factors import FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import CallbackHook, Position
from framework.strategy import StrategyBase
from framework.execution import Portfolio, BacktestExecutor
from strategies.shared.factors.momentum import MomentumFactor, VolatilityFactor
from strategies.shared.signals.selectors import TopNSelector
from strategies.shared.risk.controls import StopLossControl, premium_filter_callback
from strategies.rotation.strategy import RotationStrategy
class TestFactorIntegration:
"""测试因子集成"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_register_and_use_custom_factor(self):
"""测试注册并使用定制因子"""
FactorRegistry.register(MomentumFactor)
factor = FactorRegistry.get('momentum', n_days=25, crash_filter=True)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
values = factor.compute(data)
assert len(values) == len(data)
def test_combiner_with_custom_factors(self):
"""测试组合器使用定制因子"""
FactorRegistry.register(MomentumFactor)
FactorRegistry.register(VolatilityFactor)
momentum = FactorRegistry.get('momentum', n_days=25)
volatility = FactorRegistry.get('volatility', method='std', period=20)
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100,
'high': np.random.randn(100).cumsum() + 105,
'low': np.random.randn(100).cumsum() + 95
}, index=dates)
combiner = FactorCombiner([momentum, volatility], weights=[0.7, 0.3])
result = combiner.compute(data)
assert 'momentum' in result.columns
assert 'volatility' in result.columns
assert 'combined' in result.columns
class TestSignalIntegration:
"""测试信号生成器集成"""
def test_custom_signal_generator_with_factors(self):
"""测试定制信号生成器与因子集成"""
dates = pd.date_range('2020-01-01', periods=50)
factor_data = pd.DataFrame({
'momentum_A': np.random.randn(50),
'momentum_B': np.random.randn(50),
'momentum_C': np.random.randn(50),
}, index=dates)
selector = TopNSelector(select_num=2, min_score=0.0)
result = selector.generate(factor_data)
assert 'signal' in result.columns
class TestRiskIntegration:
"""测试风控组件集成"""
def test_custom_risk_control(self):
"""测试定制风控组件"""
control = StopLossControl(threshold=-0.05, trailing=True)
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=pd.Timestamp.now()
)
# 首次检查,设置最高价
assert control.check(position) == True
# 价格下跌,触发跟踪止损
position.current_price = 100.0
assert control.check(position) == False
def test_callback_hook_with_custom_callback(self):
"""测试回调钩子与定制回调集成"""
hook = CallbackHook()
# 注册定制回调
callback = premium_filter_callback(threshold=0.10)
hook.register('before_entry', callback)
# 正常溢价通过
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.05)
assert result == True
# 高溢价拒绝
result = hook.trigger('before_entry', 'AAPL', 100.0, premium=0.15)
assert result == False
class TestStrategyIntegration:
"""测试策略集成"""
def setup_method(self):
"""每个测试前清空注册表"""
FactorRegistry.clear()
def test_rotation_strategy_full_flow(self):
"""测试轮动策略完整流程"""
strategy = RotationStrategy()
# 生成测试数据
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
# 运行策略
result = strategy.run(data)
assert 'signal' in result.columns
def test_strategy_with_backtest_executor(self):
"""测试策略与回测执行器集成"""
FactorRegistry.clear()
strategy = RotationStrategy()
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
signals = strategy.run(data)
executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
portfolio = executor.execute(signals, data)
assert isinstance(portfolio, Portfolio)
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,325 @@
"""
风控层测试
测试RiskControl、CallbackHook抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime
from framework.risk import RiskControl, CallbackHook, Position
from strategies.shared.risk.controls import StopLossControl, PositionLimitControl, PremiumControl
class TestPosition:
"""测试Position数据结构"""
def test_position_creation(self):
"""测试持仓创建"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now(),
quantity=10
)
assert position.code == 'AAPL'
assert position.entry_price == 100.0
assert position.current_price == 105.0
def test_profit_ratio(self):
"""测试盈亏比例计算"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now()
)
assert position.profit_ratio == 0.10
def test_profit_amount(self):
"""测试盈亏金额计算"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now(),
quantity=10
)
assert position.profit_amount == 100.0
def test_position_repr(self):
"""测试持仓字符串表示"""
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
repr_str = repr(position)
assert 'AAPL' in repr_str
assert '5.00%' in repr_str
class TestStopLossControl:
"""测试止损控制"""
def test_stop_loss_init(self):
"""测试初始化"""
control = StopLossControl(threshold=-0.05)
assert control.threshold == -0.05
assert control.name == "stop_loss"
def test_stop_loss_check(self):
"""测试止损检查"""
control = StopLossControl(threshold=-0.05)
# 盈利持仓应该通过
profit_position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now()
)
assert control.check(profit_position) == True
# 亏损持仓应该触发止损
loss_position = Position(
code='AAPL',
entry_price=100.0,
current_price=90.0, # 亏损10%
entry_time=datetime.now()
)
# 亏损超过阈值应该触发
assert control.check(loss_position) == False
def test_trailing_stop_loss(self):
"""测试跟踪止损"""
control = StopLossControl(threshold=-0.05, trailing=True, trailing_percent=0.03)
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
# 初始最高价=入场价100当前价105
# 更新后最高价=105
control.check(position)
# 从最高价回撤不超过阈值应该通过
position.current_price = 103.0 # 回撤约2%
assert control.check(position) == True
# 回撤超过阈值应该触发
position.current_price = 100.0 # 回撤约5%
assert control.check(position) == False
class TestPositionLimitControl:
"""测试仓位限制控制"""
def test_position_limit_init(self):
"""测试初始化"""
control = PositionLimitControl(max_position=0.33)
assert control.max_position == 0.33
assert control.name == "position_limit"
def test_position_limit_check(self):
"""测试仓位限制检查"""
control = PositionLimitControl(max_position=0.33)
# 正常仓位应该通过
normal_position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now(),
weight=0.20
)
assert control.check(normal_position) == True
# 超限仓位应该触发
over_position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now(),
weight=0.50
)
assert control.check(over_position) == False
class TestPremiumControl:
"""测试溢价控制"""
def test_premium_control_init(self):
"""测试初始化"""
control = PremiumControl(threshold=0.10)
assert control.threshold == 0.10
assert control.name == "premium"
def test_premium_filter_mode(self):
"""测试溢价过滤模式"""
control = PremiumControl(threshold=0.10, mode='filter')
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
# 正常溢价应该通过
assert control.check(position, premium=0.05) == True
# 高溢价应该被过滤
assert control.check(position, premium=0.15) == False
class TestCallbackHook:
"""测试回调钩子"""
def test_callback_hook_init(self):
"""测试初始化"""
hook = CallbackHook()
assert len(hook.list_hooks()) == 6
def test_register_callback(self):
"""测试注册回调"""
hook = CallbackHook()
def my_callback(code, price, **kwargs):
return True
hook.register('before_entry', my_callback)
callbacks = hook.get_callbacks('before_entry')
assert len(callbacks) == 1
def test_trigger_before_entry(self):
"""测试触发入场前回调"""
hook = CallbackHook()
def always_pass(code, price, **kwargs):
return True
def always_block(code, price, **kwargs):
return False
hook.register('before_entry', always_pass)
hook.register('before_entry', always_block)
# before_entry要求所有回调返回True才允许
result = hook.trigger('before_entry', 'AAPL', 100.0)
assert result == False
def test_trigger_dynamic_stoploss(self):
"""测试触发动态止损回调"""
hook = CallbackHook()
def stoploss_5p(position):
return -0.05
def stoploss_3p(position):
return -0.03
hook.register('dynamic_stoploss', stoploss_5p)
hook.register('dynamic_stoploss', stoploss_3p)
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
# dynamic_stoploss返回最小止损值最严格
result = hook.trigger('dynamic_stoploss', position)
assert result == -0.05
def test_trigger_custom_exit(self):
"""测试触发自定义出场回调"""
hook = CallbackHook()
def exit_on_loss(position):
return position.profit_ratio < -0.05
def exit_on_profit(position):
return position.profit_ratio > 0.20
hook.register('custom_exit', exit_on_loss)
hook.register('custom_exit', exit_on_profit)
# custom_exit任一回调触发即可
profit_position = Position(
code='AAPL',
entry_price=100.0,
current_price=125.0, # 盈利25%
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', profit_position)
assert result == True
# 未触发
normal_position = Position(
code='AAPL',
entry_price=100.0,
current_price=110.0,
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', normal_position)
assert result == False
def test_clear_hooks(self):
"""测试清空回调"""
hook = CallbackHook()
def callback(code, price, **kwargs):
return True
hook.register('before_entry', callback)
# 清空特定钩子
hook.clear('before_entry')
assert len(hook.get_callbacks('before_entry')) == 0
# 清空所有钩子
hook.register('before_entry', callback)
hook.register('after_entry', callback)
hook.clear()
assert len(hook.get_callbacks('before_entry')) == 0
assert len(hook.get_callbacks('after_entry')) == 0
def test_default_behavior(self):
"""测试默认行为"""
hook = CallbackHook() # 无注册回调
# before_entry默认允许
result = hook.trigger('before_entry', 'AAPL', 100.0)
assert result == True
# dynamic_stoploss默认值
result = hook.trigger('dynamic_stoploss', None, default_stoploss=-0.05)
assert result == -0.05
# custom_exit默认不出场
position = Position(
code='AAPL',
entry_price=100.0,
current_price=105.0,
entry_time=datetime.now()
)
result = hook.trigger('custom_exit', position)
assert result == False
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,163 @@
"""
信号层测试
测试SignalGenerator抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from framework.signals import SignalGenerator
from strategies.shared.signals.selectors import TopNSelector, TrendFollower, ReversalTrader
class TestTopNSelector:
"""测试TopNSelector"""
def test_top_n_selector_init(self):
"""测试初始化"""
selector = TopNSelector(select_num=3)
assert selector.select_num == 3
assert selector.mode == "top_n"
def test_top_n_selection(self):
"""测试Top N选股"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成因子数据
data = pd.DataFrame({
'factor_A': np.random.randn(50),
'factor_B': np.random.randn(50),
'factor_C': np.random.randn(50),
}, index=dates)
selector = TopNSelector(select_num=2)
result = selector.generate(data)
# 检查结果列
assert 'signal' in result.columns
assert 'signal_raw' in result.columns
# 检查T+1移位signal比signal_raw滞后1天
assert result['signal'].iloc[0] == '' or pd.isna(result['signal'].iloc[0])
def test_min_score_filter(self):
"""测试最小得分过滤"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成因子数据,部分为负值
data = pd.DataFrame({
'factor_A': [0.1] * 50,
'factor_B': [-0.1] * 50, # 负分
'factor_C': [0.2] * 50,
}, index=dates)
selector = TopNSelector(select_num=2, min_score=0.0)
result = selector.generate(data)
# 负分因子应该被过滤
signals = result['signal_raw'].dropna().unique()
for sig in signals:
if sig:
codes = sig.split(',')
assert 'factor_B' not in codes
def test_grouped_selection(self):
"""测试分组选股"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成因子数据和分组信息
data = pd.DataFrame({
'factor_A': [0.1] * 50,
'factor_B': [0.2] * 50,
'factor_C': [0.15] * 50,
}, index=dates)
# 分组信息(模拟)
# 注实际使用需要在数据中包含group_info列
selector = TopNSelector(select_num=2, group_by='market')
result = selector.generate(data)
assert 'signal' in result.columns
class TestTrendFollower:
"""测试趋势跟随器"""
def test_trend_follower_init(self):
"""测试初始化"""
follower = TrendFollower(entry_threshold=0.02, exit_threshold=-0.02)
assert follower.entry_threshold == 0.02
assert follower.exit_threshold == -0.02
assert follower.mode == "trend"
def test_trend_signal_generation(self):
"""测试趋势信号生成"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成趋势因子数据
data = pd.DataFrame({
'trend_A': [0.05] * 50, # 强趋势
'trend_B': [-0.05] * 50, # 弱趋势
}, index=dates)
follower = TrendFollower(entry_threshold=0.02)
result = follower.generate(data)
# 检查信号列
assert 'signal' in result.columns
class TestReversalTrader:
"""测试反转交易器"""
def test_reversal_trader_init(self):
"""测试初始化"""
trader = ReversalTrader(overbought=70, oversold=30)
assert trader.overbought == 70
assert trader.oversold == 30
assert trader.mode == "reversal"
def test_reversal_signal_generation(self):
"""测试反转信号生成"""
dates = pd.date_range('2020-01-01', periods=50)
# 生成反转因子数据
data = pd.DataFrame({
'reversal_A': [0.15] * 50, # 超卖反转
'reversal_B': [-0.15] * 50, # 超买反转
}, index=dates)
trader = ReversalTrader(reversal_threshold=0.1)
result = trader.generate(data)
# 检查信号列
assert 'signal' in result.columns
class TestSignalGeneratorBase:
"""测试SignalGenerator抽象基类"""
def test_validate_factor_data(self):
"""测试数据验证"""
selector = TopNSelector(select_num=3)
# 空数据应该返回False
empty_data = pd.DataFrame()
assert selector.validate_factor_data(empty_data) == False
# 有效数据应该返回True
valid_data = pd.DataFrame({'factor_A': [1, 2, 3]})
assert selector.validate_factor_data(valid_data) == True
def test_repr(self):
"""测试字符串表示"""
selector = TopNSelector(select_num=3, min_score=0.5)
repr_str = repr(selector)
assert "TopNSelector" in repr_str
if __name__ == '__main__':
pytest.main([__file__, '-v'])

View File

@@ -0,0 +1,194 @@
"""
策略层测试
测试StrategyBase抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime
from framework.strategy import StrategyBase
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
from framework.signals import SignalGenerator
from framework.risk import Position
class TestStrategyBase:
"""测试StrategyBase抽象基类"""
def test_strategy_config_override(self):
"""测试配置覆盖类属性"""
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy(config={'select_num': 5, 'stoploss': -0.03})
assert strategy.select_num == 5
assert strategy.stoploss == -0.03
def test_strategy_default_values(self):
"""测试默认值"""
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy()
assert strategy.select_num == 3
assert strategy.stoploss == -0.05
def test_strategy_repr(self):
"""测试字符串表示"""
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy()
repr_str = repr(strategy)
assert 'RotationStrategy' in repr_str
assert 'rotation' in repr_str
def test_strategy_interface_version(self):
"""测试接口版本"""
from strategies.rotation.strategy import RotationStrategy
strategy = RotationStrategy()
assert strategy.INTERFACE_VERSION == 1
class TestRotationStrategy:
"""测试轮动策略"""
def test_rotation_strategy_init(self):
"""测试初始化"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 检查因子初始化
assert strategy._factors is not None
assert strategy._signal_gen is not None
def test_rotation_strategy_run(self):
"""测试策略运行"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 生成测试数据
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
result = strategy.run(data)
# 检查结果
assert 'signal' in result.columns
def test_dynamic_stoploss(self):
"""测试动态止损"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 测试不同持仓时间
position_5days = Position(
code='AAPL',
entry_price=100.0,
current_price=95.0,
entry_time=datetime.now() - pd.Timedelta(days=5)
)
# 5天持仓止损阈值应该为-0.05
stoploss = strategy.dynamic_stoploss(position_5days)
assert stoploss == -0.05
def test_before_entry_premium_filter(self):
"""测试入场前溢价过滤"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 正常溢价应该通过
result = strategy.before_entry('AAPL', 100.0, premium=0.05)
assert result == True
# 高溢价应该被拒绝
result = strategy.before_entry('AAPL', 100.0, premium=0.15)
assert result == False
def test_custom_exit(self):
"""测试自定义出场"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 正常盈亏不触发
normal_position = Position(
code='AAPL',
entry_price=100.0,
current_price=95.0,
entry_time=datetime.now()
)
result = strategy.custom_exit(normal_position)
assert result == False
# 大亏损触发出场
loss_position = Position(
code='AAPL',
entry_price=100.0,
current_price=85.0,
entry_time=datetime.now()
)
result = strategy.custom_exit(loss_position)
assert result == True
class TestStrategyCallbacks:
"""测试策略回调机制"""
def test_callback_registration(self):
"""测试回调自动注册"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 检查回调是否注册
callbacks = strategy._callbacks.get_callbacks('before_entry')
assert len(callbacks) > 0
callbacks = strategy._callbacks.get_callbacks('dynamic_stoploss')
assert len(callbacks) > 0
def test_callback_trigger_in_run(self):
"""测试回调在策略运行中触发"""
from strategies.rotation.strategy import RotationStrategy
FactorRegistry.clear()
strategy = RotationStrategy()
# 添加自定义回调
call_count = {'count': 0}
def counting_callback(code, price, **kwargs):
call_count['count'] += 1
return True
strategy._callbacks.register('before_entry', counting_callback)
# 运行策略
dates = pd.date_range('2020-01-01', periods=100)
data = pd.DataFrame({
'close': np.random.randn(100).cumsum() + 100
}, index=dates)
strategy.run(data)
if __name__ == '__main__':
pytest.main([__file__, '-v'])