feat(framework_v2): 实现 StrategyBase 抽象基类和简单轮动策略
StrategyBase ABC: - 定义标准回测流程:get_data → compute_factors → generate_signals → manage_positions → execute - 实现通用数据获取(使用 FlaskAPIFetcher.fetch_indices) - 提供 run() 方法执行完整回测流程 SimpleRotationStrategy: - 实现 4 个抽象方法:get_codes, compute_factors, generate_signals, manage_positions - 支持动量因子计算(MomentumFactor) - 实现全局选股和等权仓位管理 - 修复 int64 → float 转换问题 框架定位: - 通用量化回测框架,支持轮动、CTA、趋势跟踪等多种策略 - 策略只需实现 4 个抽象方法即可接入框架
This commit is contained in:
@@ -8,72 +8,115 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
from framework_v2.config.schemas import StrategyConfig
|
||||
|
||||
|
||||
class StrategyBase(ABC):
|
||||
"""
|
||||
策略抽象基类
|
||||
策略抽象基类(V2 增强版)
|
||||
|
||||
定义策略的标准生命周期:
|
||||
1. 初始化配置
|
||||
2. 获取数据
|
||||
3. 计算因子
|
||||
4. 生成信号
|
||||
5. 执行回测
|
||||
1. 初始化配置(使用 Pydantic Schema)
|
||||
2. 获取数据(支持多数据源)
|
||||
3. 计算因子(使用框架因子库)
|
||||
4. 生成信号(策略特定逻辑)
|
||||
5. 执行回测(框架通用执行器)
|
||||
|
||||
子类必须实现:
|
||||
- init_factors(): 初始化因子
|
||||
- init_signal_generator(): 初始化信号生成器
|
||||
- get_codes(): 获取标的列表
|
||||
- compute_factors(): 计算因子
|
||||
- generate_signals(): 生成信号
|
||||
- manage_positions(): 仓位管理
|
||||
"""
|
||||
|
||||
INTERFACE_VERSION = 2 # V2 版本
|
||||
|
||||
name: str = "base"
|
||||
timeframe: str = "1d"
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
def __init__(self, config: StrategyConfig):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置字典
|
||||
config: 策略配置(Pydantic Schema)
|
||||
"""
|
||||
self.config = config or {}
|
||||
self._factor = None
|
||||
self._signal_generator = None
|
||||
self.config = config
|
||||
self.name = config.metadata.strategy
|
||||
|
||||
# 组件将在子类中初始化
|
||||
self._data_fetcher = None
|
||||
self._executor = None
|
||||
|
||||
@abstractmethod
|
||||
def init_factors(self) -> Any:
|
||||
def get_codes(self) -> list:
|
||||
"""
|
||||
初始化因子组件
|
||||
获取标的列表(策略必须实现)
|
||||
|
||||
Returns:
|
||||
因子实例(继承 FactorBase)
|
||||
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def init_signal_generator(self) -> Any:
|
||||
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
|
||||
"""
|
||||
初始化信号生成器
|
||||
计算因子(策略必须实现)
|
||||
|
||||
Args:
|
||||
data: 数据字典 {code: DataFrame}
|
||||
|
||||
Returns:
|
||||
信号生成器实例(继承 SignalGenerator)
|
||||
因子字典 {code: Series}
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self) -> Dict[str, Any]:
|
||||
@abstractmethod
|
||||
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
|
||||
"""
|
||||
获取数据(可选覆盖)
|
||||
生成信号(策略必须实现)
|
||||
|
||||
Args:
|
||||
factors: 因子字典 {code: Series}
|
||||
|
||||
Returns:
|
||||
数据字典,包含:
|
||||
- index_data: 指数数据
|
||||
- etf_data: ETF数据
|
||||
- benchmark_data: 基准数据
|
||||
- valid_codes: 有效标的列表
|
||||
- trading_calendar: 交易日历
|
||||
信号 DataFrame(包含 'signal' 列,1=买入,0=空仓)
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement get_data()")
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
仓位管理(策略必须实现)
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
|
||||
Returns:
|
||||
仓位 DataFrame(包含 'weight' 列,权重和为 1)
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取数据(框架实现,策略可覆盖)
|
||||
|
||||
Returns:
|
||||
数据字典 {code: DataFrame}
|
||||
"""
|
||||
if self._data_fetcher is None:
|
||||
self._data_fetcher = self._create_data_fetcher()
|
||||
|
||||
codes = self.get_codes()
|
||||
|
||||
# 批量获取数据(fetch_indices 返回 {code: DataFrame})
|
||||
try:
|
||||
data = self._data_fetcher.fetch_indices(
|
||||
codes=codes,
|
||||
start=self.config.backtest.start_date,
|
||||
end=self.config.backtest.end_date
|
||||
)
|
||||
return data
|
||||
except Exception as e:
|
||||
print(f" 错误: 数据获取失败 - {e}")
|
||||
return {}
|
||||
|
||||
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
|
||||
"""
|
||||
@@ -111,28 +154,46 @@ class StrategyBase(ABC):
|
||||
|
||||
return self._signal_generator.generate(factor_df)
|
||||
|
||||
def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
运行完整回测流程
|
||||
运行完整回测流程(框架标准流程)
|
||||
|
||||
Args:
|
||||
data: 可选,如不提供则自动获取
|
||||
|
||||
Returns:
|
||||
回测结果字典
|
||||
回测结果字典,包含:
|
||||
- equity_curve: 净值曲线
|
||||
- trades: 交易记录
|
||||
- metrics: 绩效指标
|
||||
"""
|
||||
# 1. 获取数据
|
||||
if data is None:
|
||||
print("[1/5] 获取数据...")
|
||||
data = self.get_data()
|
||||
print(f" 获取 {len(data)} 个标的")
|
||||
|
||||
# 2. 计算因子
|
||||
factor_df = self.compute_factors(data)
|
||||
print("[2/5] 计算因子...")
|
||||
factors = self.compute_factors(data)
|
||||
print(f" 计算 {len(factors)} 个因子")
|
||||
|
||||
# 3. 生成信号
|
||||
signals = self.generate_signals(factor_df)
|
||||
print("[3/5] 生成信号...")
|
||||
signals = self.generate_signals(factors)
|
||||
print(f" 生成 {signals.shape[0]} 个信号")
|
||||
|
||||
# 4. 执行回测(子类实现)
|
||||
return self._execute_backtest(signals, data)
|
||||
# 4. 仓位管理
|
||||
print("[4/5] 仓位管理...")
|
||||
positions = self.manage_positions(signals)
|
||||
print(f" 平均持仓: {positions['weight'].sum().mean():.2%}")
|
||||
|
||||
# 5. 执行回测
|
||||
print("[5/5] 执行回测...")
|
||||
result = self._execute_backtest(positions, data)
|
||||
print(f" 回测完成")
|
||||
|
||||
return result
|
||||
|
||||
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -147,5 +208,21 @@ class StrategyBase(ABC):
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement _execute_backtest()")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
def _create_data_fetcher(self):
|
||||
"""
|
||||
创建数据获取器(框架实现)
|
||||
|
||||
Returns:
|
||||
DataFetcher 实例
|
||||
"""
|
||||
from framework_v2.shared.data import FlaskAPIFetcher
|
||||
|
||||
# 使用配置中的第一个启用的数据源
|
||||
for source_config in self.config.data.sources:
|
||||
if source_config.enabled and source_config.type.value == 'flask_api':
|
||||
return FlaskAPIFetcher(
|
||||
base_url=source_config.url,
|
||||
timeout=source_config.timeout
|
||||
)
|
||||
|
||||
raise ValueError("未找到可用的数据源配置")
|
||||
|
||||
92
framework_v2/strategies/rotation/config_simple.yaml
Normal file
92
framework_v2/strategies/rotation/config_simple.yaml
Normal file
@@ -0,0 +1,92 @@
|
||||
# 简单轮动策略配置
|
||||
#
|
||||
# 配置版本: 1.0.0
|
||||
# 最后更新: 2024-04-16
|
||||
# 策略名称: simple_rotation
|
||||
# 描述: 基于动量因子的简单 ETF 轮动策略
|
||||
|
||||
# ============================================================
|
||||
# 元数据
|
||||
# ============================================================
|
||||
metadata:
|
||||
version: "1.0.0"
|
||||
strategy: "simple_rotation"
|
||||
description: "简单轮动策略 - 等权分配 + Top-N 选择"
|
||||
last_updated: "2024-04-16"
|
||||
|
||||
# ============================================================
|
||||
# 资产池配置(简化版:只选 3 个标的)
|
||||
# ============================================================
|
||||
asset_pools:
|
||||
equity:
|
||||
"399006.SZ":
|
||||
name: "创业板指"
|
||||
etf: "159915.SZ"
|
||||
market: "CN_EQUITY"
|
||||
description: "创业板指数"
|
||||
|
||||
"NDX":
|
||||
name: "纳指100"
|
||||
etf: "513100.SH"
|
||||
market: "US_EQUITY"
|
||||
description: "纳斯达克100指数"
|
||||
|
||||
commodity: {}
|
||||
fixed_income: {}
|
||||
|
||||
# ============================================================
|
||||
# 基准配置
|
||||
# ============================================================
|
||||
benchmark:
|
||||
code: "000300.SH"
|
||||
name: "沪深300"
|
||||
|
||||
# ============================================================
|
||||
# 回测配置
|
||||
# ============================================================
|
||||
backtest:
|
||||
start_date: "2023-01-01"
|
||||
end_date: "2024-12-31"
|
||||
|
||||
# ============================================================
|
||||
# 因子配置
|
||||
# ============================================================
|
||||
factor:
|
||||
type: "weighted_momentum" # 加权动量
|
||||
n_days: 25 # 25 天窗口
|
||||
|
||||
# ============================================================
|
||||
# 轮动配置
|
||||
# ============================================================
|
||||
rotation:
|
||||
select_num: 2 # 选择 Top-2
|
||||
threshold:
|
||||
mode: "fixed"
|
||||
fixed_value: 0.0 # 无阈值过滤
|
||||
|
||||
# ============================================================
|
||||
# 调仓配置
|
||||
# ============================================================
|
||||
rebalance:
|
||||
min_hold_days: 1
|
||||
score_threshold: 0.0
|
||||
trade_cost: 0.001 # 0.1% 交易成本
|
||||
|
||||
# ============================================================
|
||||
# 溢价控制(禁用)
|
||||
# ============================================================
|
||||
premium_control:
|
||||
enabled: false
|
||||
|
||||
# ============================================================
|
||||
# 数据配置
|
||||
# ============================================================
|
||||
data:
|
||||
sources:
|
||||
- type: "flask_api"
|
||||
enabled: true
|
||||
url: "${FLASK_API_URL}"
|
||||
timeout: 120
|
||||
|
||||
use_cache: true
|
||||
cache_dir: "data_cache"
|
||||
241
framework_v2/strategies/rotation/simple.py
Normal file
241
framework_v2/strategies/rotation/simple.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
简单轮动策略
|
||||
|
||||
基于动量因子的 ETF 轮动策略
|
||||
- 计算各标的动量得分
|
||||
- 选择 Top-N 标的
|
||||
- 等权分配仓位
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
|
||||
from framework_v2.core.strategy import StrategyBase
|
||||
from framework_v2.config.schemas import StrategyConfig
|
||||
from framework_v2.shared.factors import MomentumFactor
|
||||
|
||||
|
||||
class SimpleRotationStrategy(StrategyBase):
|
||||
"""
|
||||
简单轮动策略
|
||||
|
||||
策略逻辑:
|
||||
1. 计算各标的动量得分(加权线性回归)
|
||||
2. 选择得分最高的 Top-N 标的
|
||||
3. 等权分配仓位
|
||||
|
||||
示例:
|
||||
from framework_v2.config import load_config
|
||||
from framework_v2.strategies.rotation.simple import SimpleRotationStrategy
|
||||
|
||||
config = load_config('rotation_simple.yaml')
|
||||
strategy = SimpleRotationStrategy(config)
|
||||
result = strategy.run()
|
||||
"""
|
||||
|
||||
def __init__(self, config: StrategyConfig):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置
|
||||
"""
|
||||
super().__init__(config)
|
||||
|
||||
# 初始化动量因子
|
||||
self.momentum = MomentumFactor(
|
||||
n_days=config.factor.n_days,
|
||||
weighted=(config.factor.type.value == 'weighted_momentum')
|
||||
)
|
||||
|
||||
# 策略参数
|
||||
self.select_num = config.rotation.select_num if config.rotation else 3
|
||||
self.min_score = config.rotation.threshold.fixed_value if config.rotation else 0.0
|
||||
|
||||
def get_codes(self) -> list:
|
||||
"""
|
||||
获取标的列表
|
||||
|
||||
从配置的资产池中获取所有标的
|
||||
"""
|
||||
codes = []
|
||||
|
||||
# 股票资产
|
||||
if self.config.asset_pools.equity:
|
||||
codes.extend(self.config.asset_pools.equity.keys())
|
||||
|
||||
# 商品资产
|
||||
if self.config.asset_pools.commodity:
|
||||
codes.extend(self.config.asset_pools.commodity.keys())
|
||||
|
||||
# 固定收益资产
|
||||
if self.config.asset_pools.fixed_income:
|
||||
codes.extend(self.config.asset_pools.fixed_income.keys())
|
||||
|
||||
return codes
|
||||
|
||||
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
|
||||
"""
|
||||
计算动量因子
|
||||
|
||||
Args:
|
||||
data: 数据字典 {code: DataFrame}
|
||||
|
||||
Returns:
|
||||
因子字典 {code: Series}
|
||||
"""
|
||||
factors = {}
|
||||
|
||||
for code, df in data.items():
|
||||
try:
|
||||
# 计算动量得分
|
||||
factor_values = self.momentum.compute(df)
|
||||
factors[code] = factor_values
|
||||
except Exception as e:
|
||||
print(f" 警告: {code} 因子计算失败 - {e}")
|
||||
continue
|
||||
|
||||
return factors
|
||||
|
||||
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
|
||||
"""
|
||||
生成轮动信号
|
||||
|
||||
逻辑:
|
||||
1. 每个交易日选择动量得分最高的 Top-N 标的
|
||||
2. 过滤得分低于阈值的标的
|
||||
|
||||
Args:
|
||||
factors: 因子字典 {code: Series}
|
||||
|
||||
Returns:
|
||||
信号 DataFrame(index=日期, columns=标的, values=1或0)
|
||||
"""
|
||||
if not factors:
|
||||
return pd.DataFrame()
|
||||
|
||||
# 对齐所有因子的日期
|
||||
factor_df = pd.DataFrame(factors)
|
||||
|
||||
# 生成信号
|
||||
signals = pd.DataFrame(index=factor_df.index, columns=factor_df.columns, data=0)
|
||||
|
||||
for date in factor_df.index:
|
||||
# 获取当日因子值
|
||||
scores = factor_df.loc[date].dropna()
|
||||
|
||||
if scores.empty:
|
||||
continue
|
||||
|
||||
# 过滤低分标的
|
||||
if self.min_score > 0:
|
||||
scores = scores[scores >= self.min_score]
|
||||
|
||||
# 选择 Top-N
|
||||
if len(scores) > self.select_num:
|
||||
top_codes = scores.nlargest(self.select_num).index
|
||||
else:
|
||||
top_codes = scores.index
|
||||
|
||||
# 标记信号
|
||||
signals.loc[date, top_codes] = 1
|
||||
|
||||
return signals.astype(int)
|
||||
|
||||
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
仓位管理(等权分配)
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
|
||||
Returns:
|
||||
仓位 DataFrame(包含 'weight' 列)
|
||||
"""
|
||||
positions = signals.astype(float).copy()
|
||||
|
||||
# 计算每个日期的权重
|
||||
for date in positions.index:
|
||||
signal_row = positions.loc[date]
|
||||
n_selected = signal_row.sum()
|
||||
|
||||
if n_selected > 0:
|
||||
# 等权分配
|
||||
positions.loc[date] = signal_row / n_selected
|
||||
else:
|
||||
# 空仓
|
||||
positions.loc[date] = 0
|
||||
|
||||
return positions
|
||||
|
||||
def _execute_backtest(self, positions: pd.DataFrame, data: Dict[str, pd.DataFrame]) -> Dict[str, any]:
|
||||
"""
|
||||
执行回测
|
||||
|
||||
Args:
|
||||
positions: 仓位 DataFrame
|
||||
data: 数据字典 {code: DataFrame}
|
||||
|
||||
Returns:
|
||||
回测结果字典
|
||||
"""
|
||||
# 提取收盘价
|
||||
close_prices = {}
|
||||
for code, df in data.items():
|
||||
if 'close' in df.columns:
|
||||
close_prices[code] = df['close']
|
||||
|
||||
close_df = pd.DataFrame(close_prices)
|
||||
|
||||
# 计算收益率
|
||||
returns = close_df.pct_change()
|
||||
|
||||
# 计算策略收益(仓位加权)
|
||||
# 注意:T+1 执行,今天的信号明天生效
|
||||
positions_delayed = positions.shift(1).fillna(0)
|
||||
strategy_returns = (positions_delayed * returns).sum(axis=1)
|
||||
|
||||
# 计算净值曲线
|
||||
equity_curve = (1 + strategy_returns).cumprod()
|
||||
|
||||
# 检查是否有数据
|
||||
if len(equity_curve) == 0:
|
||||
return {
|
||||
'equity_curve': equity_curve,
|
||||
'strategy_returns': strategy_returns,
|
||||
'positions': positions,
|
||||
'metrics': {
|
||||
'total_return': 0,
|
||||
'annual_return': 0,
|
||||
'max_drawdown': 0,
|
||||
'sharpe_ratio': 0,
|
||||
'n_days': 0,
|
||||
}
|
||||
}
|
||||
|
||||
# 计算绩效指标
|
||||
total_return = equity_curve.iloc[-1] / equity_curve.iloc[0] - 1
|
||||
n_days = len(strategy_returns)
|
||||
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
|
||||
|
||||
# 最大回撤
|
||||
cumulative_max = equity_curve.cummax()
|
||||
drawdown = (equity_curve - cumulative_max) / cumulative_max
|
||||
max_drawdown = drawdown.min()
|
||||
|
||||
# 夏普比率
|
||||
sharpe = strategy_returns.mean() / strategy_returns.std() * np.sqrt(252) if strategy_returns.std() > 0 else 0
|
||||
|
||||
return {
|
||||
'equity_curve': equity_curve,
|
||||
'strategy_returns': strategy_returns,
|
||||
'positions': positions,
|
||||
'metrics': {
|
||||
'total_return': total_return,
|
||||
'annual_return': annual_return,
|
||||
'max_drawdown': max_drawdown,
|
||||
'sharpe_ratio': sharpe,
|
||||
'n_days': n_days,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user