feat(framework_v2): 实现 StrategyBase 抽象基类和简单轮动策略

StrategyBase ABC:
- 定义标准回测流程:get_data → compute_factors → generate_signals → manage_positions → execute
- 实现通用数据获取(使用 FlaskAPIFetcher.fetch_indices)
- 提供 run() 方法执行完整回测流程

SimpleRotationStrategy:
- 实现 4 个抽象方法:get_codes, compute_factors, generate_signals, manage_positions
- 支持动量因子计算(MomentumFactor)
- 实现全局选股和等权仓位管理
- 修复 int64 → float 转换问题

框架定位:
- 通用量化回测框架,支持轮动、CTA、趋势跟踪等多种策略
- 策略只需实现 4 个抽象方法即可接入框架
This commit is contained in:
2026-05-24 14:25:47 +08:00
parent 341611c32b
commit de988b919b
3 changed files with 450 additions and 40 deletions

View File

@@ -8,72 +8,115 @@ from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import pandas as pd
from framework_v2.config.schemas import StrategyConfig
class StrategyBase(ABC):
"""
策略抽象基类
策略抽象基类V2 增强版)
定义策略的标准生命周期:
1. 初始化配置
2. 获取数据
3. 计算因子
4. 生成信号
5. 执行回测
1. 初始化配置(使用 Pydantic Schema
2. 获取数据(支持多数据源)
3. 计算因子(使用框架因子库)
4. 生成信号(策略特定逻辑)
5. 执行回测(框架通用执行器)
子类必须实现:
- init_factors(): 初始化因子
- init_signal_generator(): 初始化信号生成器
- get_codes(): 获取标的列表
- compute_factors(): 计算因子
- generate_signals(): 生成信号
- manage_positions(): 仓位管理
"""
INTERFACE_VERSION = 2 # V2 版本
name: str = "base"
timeframe: str = "1d"
def __init__(self, config: Optional[Dict] = None):
def __init__(self, config: StrategyConfig):
"""
初始化策略
Args:
config: 策略配置字典
config: 策略配置Pydantic Schema
"""
self.config = config or {}
self._factor = None
self._signal_generator = None
self.config = config
self.name = config.metadata.strategy
# 组件将在子类中初始化
self._data_fetcher = None
self._executor = None
@abstractmethod
def init_factors(self) -> Any:
def get_codes(self) -> list:
"""
初始化因子组件
获取标的列表(策略必须实现)
Returns:
因子实例(继承 FactorBase
标的代码列表,如 ['399006.SZ', 'NDX', 'GC=F']
"""
pass
@abstractmethod
def init_signal_generator(self) -> Any:
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
"""
初始化信号生成器
计算因子(策略必须实现)
Args:
data: 数据字典 {code: DataFrame}
Returns:
信号生成器实例(继承 SignalGenerator
因子字典 {code: Series}
"""
pass
def get_data(self) -> Dict[str, Any]:
@abstractmethod
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
"""
获取数据(可选覆盖
生成信号(策略必须实现
Args:
factors: 因子字典 {code: Series}
Returns:
数据字典,包含:
- index_data: 指数数据
- etf_data: ETF数据
- benchmark_data: 基准数据
- valid_codes: 有效标的列表
- trading_calendar: 交易日历
信号 DataFrame包含 'signal'1=买入0=空仓)
"""
raise NotImplementedError("Subclasses must implement get_data()")
pass
@abstractmethod
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
"""
仓位管理(策略必须实现)
Args:
signals: 信号 DataFrame
Returns:
仓位 DataFrame包含 'weight' 列,权重和为 1
"""
pass
def get_data(self) -> Dict[str, pd.DataFrame]:
"""
获取数据(框架实现,策略可覆盖)
Returns:
数据字典 {code: DataFrame}
"""
if self._data_fetcher is None:
self._data_fetcher = self._create_data_fetcher()
codes = self.get_codes()
# 批量获取数据fetch_indices 返回 {code: DataFrame}
try:
data = self._data_fetcher.fetch_indices(
codes=codes,
start=self.config.backtest.start_date,
end=self.config.backtest.end_date
)
return data
except Exception as e:
print(f" 错误: 数据获取失败 - {e}")
return {}
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
"""
@@ -111,28 +154,46 @@ class StrategyBase(ABC):
return self._signal_generator.generate(factor_df)
def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
def run(self, data: Optional[Dict[str, pd.DataFrame]] = None) -> Dict[str, Any]:
"""
运行完整回测流程
运行完整回测流程(框架标准流程)
Args:
data: 可选,如不提供则自动获取
Returns:
回测结果字典
回测结果字典,包含:
- equity_curve: 净值曲线
- trades: 交易记录
- metrics: 绩效指标
"""
# 1. 获取数据
if data is None:
print("[1/5] 获取数据...")
data = self.get_data()
print(f" 获取 {len(data)} 个标的")
# 2. 计算因子
factor_df = self.compute_factors(data)
print("[2/5] 计算因子...")
factors = self.compute_factors(data)
print(f" 计算 {len(factors)} 个因子")
# 3. 生成信号
signals = self.generate_signals(factor_df)
print("[3/5] 生成信号...")
signals = self.generate_signals(factors)
print(f" 生成 {signals.shape[0]} 个信号")
# 4. 执行回测(子类实现)
return self._execute_backtest(signals, data)
# 4. 仓位管理
print("[4/5] 仓位管理...")
positions = self.manage_positions(signals)
print(f" 平均持仓: {positions['weight'].sum().mean():.2%}")
# 5. 执行回测
print("[5/5] 执行回测...")
result = self._execute_backtest(positions, data)
print(f" 回测完成")
return result
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -147,5 +208,21 @@ class StrategyBase(ABC):
"""
raise NotImplementedError("Subclasses must implement _execute_backtest()")
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"
def _create_data_fetcher(self):
"""
创建数据获取器(框架实现)
Returns:
DataFetcher 实例
"""
from framework_v2.shared.data import FlaskAPIFetcher
# 使用配置中的第一个启用的数据源
for source_config in self.config.data.sources:
if source_config.enabled and source_config.type.value == 'flask_api':
return FlaskAPIFetcher(
base_url=source_config.url,
timeout=source_config.timeout
)
raise ValueError("未找到可用的数据源配置")

View File

@@ -0,0 +1,92 @@
# 简单轮动策略配置
#
# 配置版本: 1.0.0
# 最后更新: 2024-04-16
# 策略名称: simple_rotation
# 描述: 基于动量因子的简单 ETF 轮动策略
# ============================================================
# 元数据
# ============================================================
metadata:
version: "1.0.0"
strategy: "simple_rotation"
description: "简单轮动策略 - 等权分配 + Top-N 选择"
last_updated: "2024-04-16"
# ============================================================
# 资产池配置(简化版:只选 3 个标的)
# ============================================================
asset_pools:
equity:
"399006.SZ":
name: "创业板指"
etf: "159915.SZ"
market: "CN_EQUITY"
description: "创业板指数"
"NDX":
name: "纳指100"
etf: "513100.SH"
market: "US_EQUITY"
description: "纳斯达克100指数"
commodity: {}
fixed_income: {}
# ============================================================
# 基准配置
# ============================================================
benchmark:
code: "000300.SH"
name: "沪深300"
# ============================================================
# 回测配置
# ============================================================
backtest:
start_date: "2023-01-01"
end_date: "2024-12-31"
# ============================================================
# 因子配置
# ============================================================
factor:
type: "weighted_momentum" # 加权动量
n_days: 25 # 25 天窗口
# ============================================================
# 轮动配置
# ============================================================
rotation:
select_num: 2 # 选择 Top-2
threshold:
mode: "fixed"
fixed_value: 0.0 # 无阈值过滤
# ============================================================
# 调仓配置
# ============================================================
rebalance:
min_hold_days: 1
score_threshold: 0.0
trade_cost: 0.001 # 0.1% 交易成本
# ============================================================
# 溢价控制(禁用)
# ============================================================
premium_control:
enabled: false
# ============================================================
# 数据配置
# ============================================================
data:
sources:
- type: "flask_api"
enabled: true
url: "${FLASK_API_URL}"
timeout: 120
use_cache: true
cache_dir: "data_cache"

View File

@@ -0,0 +1,241 @@
"""
简单轮动策略
基于动量因子的 ETF 轮动策略
- 计算各标的动量得分
- 选择 Top-N 标的
- 等权分配仓位
"""
import pandas as pd
import numpy as np
from typing import Dict
from framework_v2.core.strategy import StrategyBase
from framework_v2.config.schemas import StrategyConfig
from framework_v2.shared.factors import MomentumFactor
class SimpleRotationStrategy(StrategyBase):
"""
简单轮动策略
策略逻辑:
1. 计算各标的动量得分(加权线性回归)
2. 选择得分最高的 Top-N 标的
3. 等权分配仓位
示例:
from framework_v2.config import load_config
from framework_v2.strategies.rotation.simple import SimpleRotationStrategy
config = load_config('rotation_simple.yaml')
strategy = SimpleRotationStrategy(config)
result = strategy.run()
"""
def __init__(self, config: StrategyConfig):
"""
初始化策略
Args:
config: 策略配置
"""
super().__init__(config)
# 初始化动量因子
self.momentum = MomentumFactor(
n_days=config.factor.n_days,
weighted=(config.factor.type.value == 'weighted_momentum')
)
# 策略参数
self.select_num = config.rotation.select_num if config.rotation else 3
self.min_score = config.rotation.threshold.fixed_value if config.rotation else 0.0
def get_codes(self) -> list:
"""
获取标的列表
从配置的资产池中获取所有标的
"""
codes = []
# 股票资产
if self.config.asset_pools.equity:
codes.extend(self.config.asset_pools.equity.keys())
# 商品资产
if self.config.asset_pools.commodity:
codes.extend(self.config.asset_pools.commodity.keys())
# 固定收益资产
if self.config.asset_pools.fixed_income:
codes.extend(self.config.asset_pools.fixed_income.keys())
return codes
def compute_factors(self, data: Dict[str, pd.DataFrame]) -> Dict[str, pd.Series]:
"""
计算动量因子
Args:
data: 数据字典 {code: DataFrame}
Returns:
因子字典 {code: Series}
"""
factors = {}
for code, df in data.items():
try:
# 计算动量得分
factor_values = self.momentum.compute(df)
factors[code] = factor_values
except Exception as e:
print(f" 警告: {code} 因子计算失败 - {e}")
continue
return factors
def generate_signals(self, factors: Dict[str, pd.Series]) -> pd.DataFrame:
"""
生成轮动信号
逻辑:
1. 每个交易日选择动量得分最高的 Top-N 标的
2. 过滤得分低于阈值的标的
Args:
factors: 因子字典 {code: Series}
Returns:
信号 DataFrameindex=日期, columns=标的, values=1或0
"""
if not factors:
return pd.DataFrame()
# 对齐所有因子的日期
factor_df = pd.DataFrame(factors)
# 生成信号
signals = pd.DataFrame(index=factor_df.index, columns=factor_df.columns, data=0)
for date in factor_df.index:
# 获取当日因子值
scores = factor_df.loc[date].dropna()
if scores.empty:
continue
# 过滤低分标的
if self.min_score > 0:
scores = scores[scores >= self.min_score]
# 选择 Top-N
if len(scores) > self.select_num:
top_codes = scores.nlargest(self.select_num).index
else:
top_codes = scores.index
# 标记信号
signals.loc[date, top_codes] = 1
return signals.astype(int)
def manage_positions(self, signals: pd.DataFrame) -> pd.DataFrame:
"""
仓位管理(等权分配)
Args:
signals: 信号 DataFrame
Returns:
仓位 DataFrame包含 'weight' 列)
"""
positions = signals.astype(float).copy()
# 计算每个日期的权重
for date in positions.index:
signal_row = positions.loc[date]
n_selected = signal_row.sum()
if n_selected > 0:
# 等权分配
positions.loc[date] = signal_row / n_selected
else:
# 空仓
positions.loc[date] = 0
return positions
def _execute_backtest(self, positions: pd.DataFrame, data: Dict[str, pd.DataFrame]) -> Dict[str, any]:
"""
执行回测
Args:
positions: 仓位 DataFrame
data: 数据字典 {code: DataFrame}
Returns:
回测结果字典
"""
# 提取收盘价
close_prices = {}
for code, df in data.items():
if 'close' in df.columns:
close_prices[code] = df['close']
close_df = pd.DataFrame(close_prices)
# 计算收益率
returns = close_df.pct_change()
# 计算策略收益(仓位加权)
# 注意T+1 执行,今天的信号明天生效
positions_delayed = positions.shift(1).fillna(0)
strategy_returns = (positions_delayed * returns).sum(axis=1)
# 计算净值曲线
equity_curve = (1 + strategy_returns).cumprod()
# 检查是否有数据
if len(equity_curve) == 0:
return {
'equity_curve': equity_curve,
'strategy_returns': strategy_returns,
'positions': positions,
'metrics': {
'total_return': 0,
'annual_return': 0,
'max_drawdown': 0,
'sharpe_ratio': 0,
'n_days': 0,
}
}
# 计算绩效指标
total_return = equity_curve.iloc[-1] / equity_curve.iloc[0] - 1
n_days = len(strategy_returns)
annual_return = (1 + total_return) ** (252 / n_days) - 1 if n_days > 0 else 0
# 最大回撤
cumulative_max = equity_curve.cummax()
drawdown = (equity_curve - cumulative_max) / cumulative_max
max_drawdown = drawdown.min()
# 夏普比率
sharpe = strategy_returns.mean() / strategy_returns.std() * np.sqrt(252) if strategy_returns.std() > 0 else 0
return {
'equity_curve': equity_curve,
'strategy_returns': strategy_returns,
'positions': positions,
'metrics': {
'total_return': total_return,
'annual_return': annual_return,
'max_drawdown': max_drawdown,
'sharpe_ratio': sharpe,
'n_days': n_days,
}
}