Files
etf/framework/execution/__init__.py
aszerW 444dc0e751 refactor(execution): 改为固定仓位分配逻辑
- 原逻辑: 按实际持仓数量等权(选出2只时权重50%)
- 新逻辑: 按select_num固定等权(选出2只时权重33.3%+现金33.3%)
- 缺失仓位用现金替代,收益为0
- 交易成本按固定仓位比例计算
- 目的: 保持稳定风险敞口,避免仓位不足时波动放大
2026-05-16 00:18:19 +08:00

356 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
执行层抽象接口(通用)
只提供抽象基类和Portfolio数据结构具体执行器可扩展
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import pandas as pd
import numpy as np
from datetime import datetime
from framework.risk import Position
class Portfolio:
"""
投资组合数据结构(通用)
用于管理持仓集合
"""
def __init__(self, initial_capital: float = 100000):
"""初始化投资组合"""
self.initial_capital = initial_capital
self.cash = initial_capital
self.positions: Dict[str, Position] = {}
self.trades: List[Dict] = []
self._net_value_history: List[float] = []
def add_position(self, code: str, price: float, quantity: float, time: datetime) -> None:
"""添加持仓"""
position = Position(
code=code,
entry_price=price,
current_price=price,
entry_time=time,
quantity=quantity
)
self.positions[code] = position
self.cash -= price * quantity
self.trades.append({
'action': 'BUY',
'code': code,
'price': price,
'quantity': quantity,
'time': time
})
def remove_position(self, code: str, price: float, time: datetime) -> float:
"""移除持仓"""
if code not in self.positions:
return 0
position = self.positions[code]
profit = (price - position.entry_price) * position.quantity
self.cash += price * position.quantity
del self.positions[code]
self.trades.append({
'action': 'SELL',
'code': code,
'price': price,
'quantity': position.quantity,
'time': time,
'profit': profit
})
return profit
def update_prices(self, prices: Dict[str, float]) -> None:
"""更新持仓价格"""
for code, price in prices.items():
if code in self.positions:
self.positions[code].current_price = price
def get_net_value(self) -> float:
"""计算净值"""
positions_value = sum(
pos.current_price * pos.quantity for pos in self.positions.values()
)
return self.cash + positions_value
def record_net_value(self) -> None:
"""记录当前净值"""
self._net_value_history.append(self.get_net_value())
def get_net_value_series(self) -> pd.Series:
"""获取净值序列"""
return pd.Series(self._net_value_history)
def get_weight(self, code: str) -> float:
"""计算持仓权重"""
if code not in self.positions:
return 0
position_value = self.positions[code].current_price * self.positions[code].quantity
return position_value / self.get_net_value()
def __repr__(self) -> str:
return f"Portfolio(capital={self.cash:.2f}, positions={len(self.positions)})"
class Executor(ABC):
"""
执行器抽象基类
所有执行器必须实现execute方法
"""
mode: str = "base"
def __init__(self, **params):
"""初始化执行器参数"""
self._params = params
@abstractmethod
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
"""
执行信号
Args:
signals: 信号DataFrame
data: OHLCV数据
Returns:
Portfolio对象
"""
pass
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"
class BacktestExecutor(Executor):
"""
完整回测执行器(通用)
支持:
- 日收益率计算
- 交易成本扣除
- 净值计算(起点归一化)
- 基准对比
- 持仓跟踪
"""
mode = "backtest"
def __init__(
self,
initial_capital: float = 100000,
trade_cost: float = 0.001,
select_num: int = 1,
benchmark_data: Optional[pd.DataFrame] = None
):
super().__init__(
initial_capital=initial_capital,
trade_cost=trade_cost,
select_num=select_num,
benchmark_data=benchmark_data
)
self.initial_capital = initial_capital
self.trade_cost = trade_cost
self.select_num = select_num
self.benchmark_data = benchmark_data
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
"""
执行完整回测
Args:
signals: 信号DataFrame包含signal或信号列
data: OHLCV数据和日收益率数据
Returns:
Portfolio对象含净值序列、交易记录
"""
portfolio = Portfolio(self.initial_capital)
# 支持中英文列名
signal_col = 'signal' if 'signal' in signals.columns else '信号'
# 删除空信号行
signals = signals.dropna(subset=[signal_col])
signals = signals[signals[signal_col] != '']
if signals.empty:
return portfolio
# 计算策略日收益率
result = self._calculate_daily_returns(signals, data, signal_col)
# 扣除交易成本
result = self._apply_trade_cost(result, signals, signal_col)
# 计算净值(起点归一化)
result = self._calculate_net_value(result)
# 计算基准净值
result = self._calculate_benchmark(result)
# 记录净值历史
for date in result.index:
portfolio.record_net_value()
# 存储回测结果
portfolio.backtest_result = result
return portfolio
def _calculate_daily_returns(self, signals: pd.DataFrame, data: pd.DataFrame, signal_col: str = 'signal') -> pd.DataFrame:
"""计算策略日收益率"""
result = signals.copy()
# 日收益率列名格式日收益率_{code} 或 日收益率_{code}
return_cols = [col for col in data.columns if col.startswith('日收益率_')]
if self.select_num == 1:
# 单标的策略
def calc_return(row):
signal = row[signal_col]
if not signal or pd.isna(signal):
return 0.0
return data.loc[row.name, f'日收益率_{signal}'] if f'日收益率_{signal}' in data.columns else 0.0
result['策略日收益率'] = result.apply(calc_return, axis=1)
else:
# 多标的策略(固定仓位分配)
# 核心逻辑按select_num固定分配仓位缺失标的用现金替代
# 例如select_num=3选出2只标的 → 权重=1/3+1/3现金权重=1/3收益为0
def calc_multi_return(row):
codes = [c for c in row[signal_col].split(',') if c]
if not codes:
# 空仓全部现金收益为0
return 0.0
# 固定仓位权重:每只标的权重 = 1 / select_num
unit_weight = 1.0 / self.select_num
# 计算实际持仓收益缺失标的用现金替代收益为0
total_return = 0.0
for c in codes:
ret = data.loc[row.name, f'日收益率_{c}'] if f'日收益率_{c}' in data.columns else None
if ret is not None and pd.notna(ret):
total_return += ret * unit_weight
# 如果数据缺失视为现金收益为0不累加
# 缺失标的的仓位自动变成现金收益为0
# 总收益 = sum(实际持仓收益) + 0 * (缺失仓位)
return total_return
result['策略日收益率'] = result.apply(calc_multi_return, axis=1)
return result
def _apply_trade_cost(self, result: pd.DataFrame, signals: pd.DataFrame, signal_col: str = 'signal') -> pd.DataFrame:
"""扣除交易成本"""
if self.trade_cost <= 0:
return result
prev_signal = signals[signal_col].shift(1)
if self.select_num == 1:
# 单标的策略:调仓时扣除固定成本
changed = (signals[signal_col] != prev_signal) & prev_signal.notna()
result.loc[changed, '策略日收益率'] -= self.trade_cost
else:
# 多标的策略:按固定仓位比例扣除成本
# 核心逻辑每只标的权重固定为1/select_num
# 换手率 = (调出数量 + 调入数量) / select_num
turnover_list = []
for curr, prev in zip(signals[signal_col], prev_signal):
if pd.isna(prev) or curr == prev:
turnover_list.append(0.0)
else:
old = set(prev.split(','))
new = set(curr.split(','))
# 调出的标的数量(这些仓位需要卖出)
exit_count = len(old - new)
# 调入的标的数量(这些仓位需要买入)
enter_count = len(new - old)
# 换手率 = (卖出 + 买入) / select_num
# 每次调仓涉及的仓位比例
turnover = (exit_count + enter_count) / self.select_num
turnover_list.append(turnover)
result['换手率'] = turnover_list
result['策略日收益率'] -= result['换手率'] * self.trade_cost
return result
def _calculate_net_value(self, result: pd.DataFrame) -> pd.DataFrame:
"""计算净值(起点归一化)"""
result['策略净值'] = (1 + result['策略日收益率']).cumprod()
# 归一化确保净值起点为1.0
result['策略净值'] = result['策略净值'] / result['策略净值'].iloc[0]
return result
def _calculate_benchmark(self, result: pd.DataFrame) -> pd.DataFrame:
"""计算基准净值"""
if self.benchmark_data is None:
return result
# 获取基准收益率
if isinstance(self.benchmark_data, pd.DataFrame):
if 'close' in self.benchmark_data.columns:
bench_close = self.benchmark_data['close']
else:
bench_close = self.benchmark_data.iloc[:, 0]
else:
bench_close = self.benchmark_data
bench_ret = bench_close.pct_change().dropna()
common_dates = result.index.intersection(bench_ret.index)
bench_ret = bench_ret.loc[common_dates]
result['基准日收益率'] = bench_ret.reindex(result.index, fill_value=0)
result['基准净值'] = (1 + result['基准日收益率']).cumprod()
result['基准净值'] = result['基准净值'] / result['基准净值'].iloc[0]
return result
class DryRunExecutor(Executor):
"""
Dry-run执行器通用
用于模拟运行,不实际执行交易
"""
mode = "dry_run"
def __init__(self, verbose: bool = True):
super().__init__(verbose=verbose)
self.verbose = verbose
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> Portfolio:
"""模拟执行"""
portfolio = Portfolio(100000)
for date in signals.index:
signal = signals.loc[date, 'signal']
if signal and self.verbose:
print(f"[{date}] Signal: {signal}")
return portfolio
# 导出抽象接口
__all__ = ['Portfolio', 'Executor', 'BacktestExecutor', 'DryRunExecutor']