Files
factorhack/backtest.py
2025-11-08 13:39:02 +08:00

181 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Backtesting module.

Provides ``BacktestEngine``, a long/flat backtester for a single price series
driven by discrete trading signals.
"""
import numpy as np
import pandas as pd
from typing import Dict, Optional, Tuple
class BacktestEngine:
    """Signal-driven long/flat backtest engine for a single price series.

    At most one position is held at a time and it is always fully invested:
    a buy moves all capital into the instrument, a sell moves it back to
    cash. Each fill is charged ``commission + slippage`` as a fraction of the
    capital being traded.
    """

    def __init__(
        self,
        commission: float = 0.001,  # commission rate per fill (fraction of capital)
        slippage: float = 0.0005,  # slippage per fill (fraction of capital)
        initial_capital: float = 10000.0,
        periods_per_year: int = 252 * 6,
    ):
        self.commission = commission
        self.slippage = slippage
        self.initial_capital = initial_capital
        # Bars per year used to annualise metrics. The default assumes 4h bars:
        # 6 bars per day and 252 trading days per year.
        self.periods_per_year = periods_per_year

    @staticmethod
    def _align_inputs(
        signals: pd.Series,
        price: pd.Series,
        score: Optional[pd.Series] = None
    ) -> pd.DataFrame:
        """Join signals and price on their index, keeping only complete rows.

        ``score`` is reindexed onto the surviving rows (missing scores become
        NaN), so it can never introduce rows that lack a signal or a price.
        """
        aligned = pd.concat([signals, price], axis=1).dropna()
        aligned.columns = ['signal', 'price']
        if score is not None:
            aligned = aligned.assign(score=score.reindex(aligned.index))
        return aligned

    def run(
        self,
        signals: pd.Series,
        price: pd.Series,
        score: Optional[pd.Series] = None
    ) -> Dict:
        """
        Run the backtest.

        Parameters
        ----------
        signals : Series
            Trading signal per bar: 1 = buy, -1 = sell, 0 = hold (keep the
            current position). A buy while invested or a sell while flat is
            ignored.
        price : Series
            Price series; fills happen at the same bar's price.
        score : Series, optional
            Factor score. When given, the score of the fill bar is recorded on
            each trade under the ``'score'`` key.

        Returns
        -------
        dict
            ``'equity'``: per-bar equity, net of costs, with the open position
            marked to market; ``'returns'``: per-bar price return earned by the
            position held over that bar (gross of costs); ``'trades'``: list of
            fill records; ``'metrics'``: output of ``_calculate_metrics``;
            ``'final_capital'``: last equity value, or ``initial_capital`` when
            signals and price have no overlapping rows.
        """
        aligned = self._align_inputs(signals, price, score)
        n = len(aligned)
        sig = aligned['signal'].to_numpy()
        px = aligned['price'].to_numpy()
        scores = aligned['score'].to_numpy() if score is not None else None
        cost = self.commission + self.slippage

        capital = self.initial_capital
        position = 0  # 0 = flat (all cash), 1 = fully invested
        buy_price = None  # entry price of the open position
        held = np.zeros(n)  # position carried into each bar
        equity = np.full(n, capital, dtype=float)  # np.full also handles n == 0
        trades = []

        # Bar 0 is only the starting point; trading starts at bar 1.
        # Trades are level-triggered: a buy signal while flat (or a sell signal
        # while invested) executes even if the previous bar had the same
        # signal, so a series that starts with a run of 1s still opens a
        # position.
        for i in range(1, n):
            held[i] = position
            current_price = px[i]
            action = None

            if sig[i] == 1 and position == 0:
                capital *= (1 - cost)
                position = 1
                buy_price = current_price
                action = 'buy'
            elif sig[i] == -1 and position == 1:
                # Realise the position's P&L before paying the exit costs.
                capital *= current_price / buy_price
                capital *= (1 - cost)
                position = 0
                buy_price = None
                action = 'sell'

            if action is not None:
                record = {
                    'date': aligned.index[i],
                    'action': action,
                    'price': current_price,
                    'capital': capital,
                }
                if scores is not None:
                    record['score'] = scores[i]
                trades.append(record)

            # Mark the open position to market; when flat, equity is the cash.
            if position == 1:
                equity[i] = capital * (current_price / buy_price)
            else:
                equity[i] = capital

        equity_series = pd.Series(equity, index=aligned.index)
        price_pct = aligned['price'].pct_change().fillna(0).to_numpy()
        # Price return earned by the position actually held over each bar
        # (gross of costs; costs are reflected in the equity curve only).
        returns_series = pd.Series(price_pct * held, index=aligned.index)

        metrics = self._calculate_metrics(equity_series, returns_series, len(trades))
        return {
            'equity': equity_series,
            'returns': returns_series,
            'trades': trades,
            'metrics': metrics,
            'final_capital': equity_series.iloc[-1] if len(equity_series) > 0 else self.initial_capital
        }

    def _calculate_metrics(
        self,
        equity: pd.Series,
        returns: pd.Series,
        num_trades: int = 0
    ) -> Dict:
        """Compute summary statistics of a backtest.

        Parameters
        ----------
        equity : Series
            Per-bar equity curve.
        returns : Series
            Per-bar strategy returns; zero on bars with no position.
        num_trades : int
            Number of fills, reported as ``'total_trades'``.

        Returns
        -------
        dict
            Empty if either input is empty; otherwise total/annualised return,
            annualised volatility, Sharpe ratio (no risk-free rate), maximum
            drawdown, per-bar win rate, profit/loss ratio and trade count.
        """
        if len(equity) == 0 or len(returns) == 0:
            return {}

        periods_per_year = self.periods_per_year

        total_return = equity.iloc[-1] / equity.iloc[0] - 1

        # The equity curve spans len(equity) - 1 bar-to-bar intervals.
        n_periods = len(equity) - 1
        if n_periods > 0:
            annual_return = (1 + total_return) ** (periods_per_year / n_periods) - 1
        else:
            annual_return = 0

        annual_vol = returns.std() * np.sqrt(periods_per_year)

        # The epsilon avoids division by zero when returns are constant.
        sharpe = annual_return / (annual_vol + 1e-8)

        cummax = equity.cummax()
        drawdown = (equity - cummax) / cummax
        max_drawdown = drawdown.min()

        # Win rate and profit/loss ratio are per bar: only bars with a
        # non-zero return (i.e. while a position is held) are counted.
        position_returns = returns[returns != 0]
        winning_bars = (position_returns > 0).sum()
        win_rate = winning_bars / len(position_returns) if len(position_returns) > 0 else 0

        positive_returns = position_returns[position_returns > 0]
        negative_returns = position_returns[position_returns < 0]
        avg_win = positive_returns.mean() if len(positive_returns) > 0 else 0
        avg_loss = abs(negative_returns.mean()) if len(negative_returns) > 0 else 0
        profit_loss_ratio = avg_win / (avg_loss + 1e-8)

        return {
            'total_return': total_return,
            'annual_return': annual_return,
            'annual_volatility': annual_vol,
            'sharpe_ratio': sharpe,
            'max_drawdown': max_drawdown,
            'win_rate': win_rate,
            'profit_loss_ratio': profit_loss_ratio,
            'total_trades': num_trades  # number of fills (buys + sells)
        }