""" 执行层测试 测试Portfolio、Executor抽象接口 """ import pandas as pd import numpy as np import pytest from datetime import datetime from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor from framework.risk import Position class TestPortfolio: """测试Portfolio""" def test_portfolio_init(self): """测试初始化""" portfolio = Portfolio(initial_capital=100000) assert portfolio.initial_capital == 100000 assert portfolio.cash == 100000 assert len(portfolio.positions) == 0 def test_add_position(self): """测试添加持仓""" portfolio = Portfolio(initial_capital=100000) portfolio.add_position( code='AAPL', price=100.0, quantity=10, time=datetime.now() ) assert len(portfolio.positions) == 1 assert 'AAPL' in portfolio.positions assert portfolio.cash == 99000 # 100000 - 100*10 def test_remove_position(self): """测试移除持仓""" portfolio = Portfolio(initial_capital=100000) portfolio.add_position('AAPL', 100.0, 10, datetime.now()) profit = portfolio.remove_position('AAPL', 110.0, datetime.now()) assert len(portfolio.positions) == 0 assert profit == 100.0 # (110-100)*10 assert portfolio.cash == 100100 # 99000 + 110*10 def test_update_prices(self): """测试更新价格""" portfolio = Portfolio(initial_capital=100000) portfolio.add_position('AAPL', 100.0, 10, datetime.now()) portfolio.add_position('MSFT', 200.0, 5, datetime.now()) portfolio.update_prices({'AAPL': 110.0, 'MSFT': 220.0}) assert portfolio.positions['AAPL'].current_price == 110.0 assert portfolio.positions['MSFT'].current_price == 220.0 def test_get_net_value(self): """测试净值计算""" portfolio = Portfolio(initial_capital=100000) portfolio.add_position('AAPL', 100.0, 10, datetime.now()) portfolio.update_prices({'AAPL': 110.0}) net_value = portfolio.get_net_value() expected = 99000 + 110 * 10 # cash + position_value assert net_value == expected def test_get_weight(self): """测试权重计算""" portfolio = Portfolio(initial_capital=100000) portfolio.add_position('AAPL', 100.0, 100, datetime.now()) portfolio.update_prices({'AAPL': 110.0}) weight = portfolio.get_weight('AAPL') # 持仓价值=110*100=11000,净值=90000+11000=101000 expected_weight = 11000 / 101000 assert abs(weight - expected_weight) < 0.01 def test_record_net_value(self): """测试净值记录""" portfolio = Portfolio(initial_capital=100000) portfolio.record_net_value() portfolio.add_position('AAPL', 100.0, 10, datetime.now()) portfolio.record_net_value() series = portfolio.get_net_value_series() assert len(series) == 2 def test_portfolio_repr(self): """测试字符串表示""" portfolio = Portfolio(initial_capital=100000) portfolio.add_position('AAPL', 100.0, 10, datetime.now()) repr_str = repr(portfolio) assert 'Portfolio' in repr_str assert 'positions=1' in repr_str class TestBacktestExecutor: """测试回测执行器""" def test_backtest_executor_init(self): """测试初始化""" executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001) assert executor.initial_capital == 100000 assert executor.trade_cost == 0.001 assert executor.mode == "backtest" def test_backtest_execute(self): """测试回测执行""" executor = BacktestExecutor(initial_capital=100000) dates = pd.date_range('2020-01-01', periods=50) signals = pd.DataFrame({ 'signal': ['AAPL'] * 50 }, index=dates) data = pd.DataFrame({ 'close': np.random.randn(50).cumsum() + 100 }, index=dates) portfolio = executor.execute(signals, data) assert isinstance(portfolio, Portfolio) def test_backtest_executor_repr(self): """测试字符串表示""" executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001) repr_str = repr(executor) assert 'BacktestExecutor' in repr_str class TestDryRunExecutor: """测试DryRun执行器""" def test_dry_run_executor_init(self): """测试初始化""" executor = DryRunExecutor(verbose=True) assert executor.verbose == True assert executor.mode == "dry_run" def test_dry_run_execute(self): """测试模拟执行""" executor = DryRunExecutor(verbose=False) dates = pd.date_range('2020-01-01', periods=50) signals = pd.DataFrame({ 'signal': ['AAPL,MSFT'] * 50 }, index=dates) data = pd.DataFrame({ 'close': np.random.randn(50).cumsum() + 100 }, index=dates) portfolio = executor.execute(signals, data) assert isinstance(portfolio, Portfolio) if __name__ == '__main__': pytest.main([__file__, '-v'])