Files
etf/archive/framework/tests/test_execution.py
aszerW c905230a40 refactor(archive): move unused modules to archive/
Archive legacy framework and utility modules that are no longer
referenced by the active core (datasource/ and rotation/):

- framework/ -> archive/framework/
- framework_v2/ -> archive/framework_v2/
- strategies/ -> archive/strategies/
- config/ -> archive/config/
- visualization/ -> archive/visualization/
- scripts/ -> archive/scripts/
- tests/ -> archive/tests/
- run_rotation.py, run_us_rotation.py -> archive/single_files/
- compare_*.py, test_api_dates.py -> archive/single_files/
2026-06-03 23:41:46 +08:00

174 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
执行层测试
测试Portfolio、Executor抽象接口
"""
import pandas as pd
import numpy as np
import pytest
from datetime import datetime
from framework.execution import Portfolio, Executor, BacktestExecutor, DryRunExecutor
from framework.risk import Position
class TestPortfolio:
    """Tests for Portfolio: cash bookkeeping, positions, valuation and repr."""

    def test_portfolio_init(self):
        """A fresh portfolio holds only cash equal to its initial capital."""
        portfolio = Portfolio(initial_capital=100000)
        assert portfolio.initial_capital == 100000
        assert portfolio.cash == 100000
        assert len(portfolio.positions) == 0

    def test_add_position(self):
        """Opening a position registers it and debits price * quantity from cash."""
        portfolio = Portfolio(initial_capital=100000)
        portfolio.add_position(
            code='AAPL',
            price=100.0,
            quantity=10,
            time=datetime.now(),
        )
        assert len(portfolio.positions) == 1
        assert 'AAPL' in portfolio.positions
        assert portfolio.cash == pytest.approx(99000)  # 100000 - 100 * 10

    def test_remove_position(self):
        """Closing a position removes it, returns the P&L and credits the proceeds."""
        portfolio = Portfolio(initial_capital=100000)
        portfolio.add_position('AAPL', 100.0, 10, datetime.now())
        profit = portfolio.remove_position('AAPL', 110.0, datetime.now())
        assert len(portfolio.positions) == 0
        assert profit == pytest.approx(100.0)  # (110 - 100) * 10
        assert portfolio.cash == pytest.approx(100100)  # 99000 + 110 * 10

    def test_update_prices(self):
        """update_prices sets current_price on every listed position."""
        portfolio = Portfolio(initial_capital=100000)
        portfolio.add_position('AAPL', 100.0, 10, datetime.now())
        portfolio.add_position('MSFT', 200.0, 5, datetime.now())
        portfolio.update_prices({'AAPL': 110.0, 'MSFT': 220.0})
        assert portfolio.positions['AAPL'].current_price == 110.0
        assert portfolio.positions['MSFT'].current_price == 220.0

    def test_get_net_value(self):
        """Net value is cash plus positions marked at their current price."""
        portfolio = Portfolio(initial_capital=100000)
        portfolio.add_position('AAPL', 100.0, 10, datetime.now())
        portfolio.update_prices({'AAPL': 110.0})
        net_value = portfolio.get_net_value()
        expected = 99000 + 110 * 10  # cash + position value
        assert net_value == pytest.approx(expected)

    def test_get_weight(self):
        """A position's weight is its market value divided by total net value."""
        portfolio = Portfolio(initial_capital=100000)
        portfolio.add_position('AAPL', 100.0, 100, datetime.now())
        portfolio.update_prices({'AAPL': 110.0})
        weight = portfolio.get_weight('AAPL')
        # Position value = 110 * 100 = 11000; net value = 90000 + 11000 = 101000.
        expected_weight = 11000 / 101000
        # Use a tight relative tolerance: an absolute 0.01 window around ~0.109
        # would also accept wrong formulas such as 11000 / 100000 (= 0.11,
        # weighting against initial capital) or 10000 / 101000 (≈ 0.099,
        # using cost basis instead of market value).
        assert weight == pytest.approx(expected_weight)

    def test_record_net_value(self):
        """Each record_net_value call appends one point to the net-value series."""
        portfolio = Portfolio(initial_capital=100000)
        portfolio.record_net_value()
        portfolio.add_position('AAPL', 100.0, 10, datetime.now())
        portfolio.record_net_value()
        series = portfolio.get_net_value_series()
        assert len(series) == 2

    def test_portfolio_repr(self):
        """repr names the class and reports the number of open positions."""
        portfolio = Portfolio(initial_capital=100000)
        portfolio.add_position('AAPL', 100.0, 10, datetime.now())
        repr_str = repr(portfolio)
        assert 'Portfolio' in repr_str
        assert 'positions=1' in repr_str
class TestBacktestExecutor:
    """Tests for the backtest executor."""

    def test_backtest_executor_init(self):
        """Constructor arguments are stored and the mode is 'backtest'."""
        executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
        assert executor.initial_capital == 100000
        assert executor.trade_cost == 0.001
        assert executor.mode == "backtest"

    def test_backtest_execute(self):
        """Running a backtest on a constant single-code signal returns a Portfolio."""
        executor = BacktestExecutor(initial_capital=100000)
        periods = 50
        dates = pd.date_range('2020-01-01', periods=periods)
        signals = pd.DataFrame({'signal': ['AAPL'] * periods}, index=dates)
        # Seeded generator: the close series (a random walk starting near 100)
        # is identical on every run, so any failure is reproducible.
        rng = np.random.default_rng(0)
        data = pd.DataFrame(
            {'close': rng.standard_normal(periods).cumsum() + 100},
            index=dates,
        )
        portfolio = executor.execute(signals, data)
        assert isinstance(portfolio, Portfolio)

    def test_backtest_executor_repr(self):
        """repr includes the class name."""
        executor = BacktestExecutor(initial_capital=100000, trade_cost=0.001)
        repr_str = repr(executor)
        assert 'BacktestExecutor' in repr_str
class TestDryRunExecutor:
    """Tests for the dry-run (simulated) executor."""

    def test_dry_run_executor_init(self):
        """The verbose flag is stored and the mode is 'dry_run'."""
        executor = DryRunExecutor(verbose=True)
        assert executor.verbose is True
        assert executor.mode == "dry_run"

    def test_dry_run_execute(self):
        """A simulated run over a multi-code signal returns a Portfolio."""
        executor = DryRunExecutor(verbose=False)
        periods = 50
        dates = pd.date_range('2020-01-01', periods=periods)
        signals = pd.DataFrame({'signal': ['AAPL,MSFT'] * periods}, index=dates)
        # Seeded generator: the close series (a random walk starting near 100)
        # is identical on every run, so any failure is reproducible.
        rng = np.random.default_rng(0)
        data = pd.DataFrame(
            {'close': rng.standard_normal(periods).cumsum() + 100},
            index=dates,
        )
        portfolio = executor.execute(signals, data)
        assert isinstance(portfolio, Portfolio)
if __name__ == '__main__':
    # Propagate pytest's exit code so that running this file directly exits
    # with a non-zero status when any test fails.
    raise SystemExit(pytest.main([__file__, '-v']))