feat(framework_v2): 创建框架V2骨架 - 三层架构+因子验证通过

## 架构设计
- 三层架构:core(抽象接口) → shared(通用实现) → tests(验证测试)
- 5个核心抽象基类:StrategyBase, FactorBase, SignalGenerator, Executor, DataFetcher
- 零侵入:与现有框架并行开发,不修改生产代码

## 已完成
✓ 核心接口层(5个ABC类)
✓ 通用因子层(MomentumFactor完全复制现有逻辑)
✓ 对比验证测试(新旧因子输出差异=0,测试通过)

## 验证结果
- 最大差异: 0.000000e+00
- 平均差异: 0.000000e+00
- 容差: < 1e-10

## 下一步
- 阶段3: 信号层迁移(TopNSelector, DynamicThreshold, RebalanceController)
- 阶段4: 执行层迁移(BacktestRunner)
- 阶段5: 数据层迁移(DataFetcher实现)
- 阶段6: 完整策略对比验证

## 设计原则
- 按需抽象,不预先设计
- 职责分离,避免框架膨胀
- 测试驱动,每个组件必须有对比测试
- 渐进式迁移,验证通过再替换
This commit is contained in:
2026-05-24 09:12:29 +08:00
parent 226a27361f
commit 908b28473f
14 changed files with 928 additions and 0 deletions

180
framework_v2/README.md Normal file
View File

@@ -0,0 +1,180 @@
# 框架 V2 - 重构版本
## 📋 设计理念
### 三层架构
```
framework_v2/
├── core/ # 纯抽象接口(零实现)
├── shared/ # 通用实现2+策略复用)
└── tests/ # 框架测试
```
### 设计原则
1. **按需抽象**:不预先设计,只抽象已验证的通用逻辑
2. **职责分离**:数据获取、因子计算、信号生成、回测执行各司其职
3. **向后兼容**:与现有策略并行运行,验证一致后再替换
4. **测试驱动**:每个组件必须有对比验证测试
---
## 🏗️ 目录结构
```
framework_v2/
├── __init__.py
├── README.md
├── core/ # 核心抽象接口
│ ├── __init__.py
│ ├── strategy.py # StrategyBase (ABC)
│ ├── factor.py # FactorBase (ABC)
│ ├── signal.py # SignalGenerator (ABC)
│ ├── executor.py # Executor (ABC)
│ └── data.py # DataFetcher (ABC)
├── shared/ # 通用实现
│ ├── __init__.py
│ └── factors/
│ ├── __init__.py
│ ├── talib_base.py # TALibFactorBase (需要 talib)
│ └── momentum.py # 动量因子(已验证✓)
└── tests/ # 测试
├── __init__.py
└── test_momentum_parity.py # 因子对比测试(通过✓)
```
---
## ✅ 已完成
### 阶段1: 核心接口层 ✓
- [x] StrategyBase - 策略抽象基类
- [x] FactorBase - 因子抽象基类
- [x] SignalGenerator - 信号生成器抽象基类
- [x] Executor - 执行器抽象基类
- [x] DataFetcher - 数据获取器抽象基类
### 阶段2: 通用因子层 ✓
- [x] MomentumFactor - 动量因子(完全复制现有逻辑)
- [x] 对比验证测试(通过✓,差异 = 0
---
## 🎯 验证结果
### MomentumFactor 对比测试
```
============================================================
MomentumFactor 对比测试
============================================================
1. 加载测试数据...
⚠ 未找到测试数据,使用模拟数据
2. 计算旧因子strategies/shared/factors/momentum.py...
✓ 旧因子计算完成
结果范围: -0.8515 ~ 8.5805
NaN 数量: 22
3. 计算新因子framework_v2/shared/factors/momentum.py...
✓ 新因子计算完成
结果范围: -0.8515 ~ 8.5805
NaN 数量: 22
4. 对比结果...
✓ 索引一致
最大差异: 0.000000e+00
平均差异: 0.000000e+00
✓ 差异在容差范围内 (< 1e-10)
============================================================
✓ 测试通过:新旧因子输出完全一致!
============================================================
```
---
## 📝 下一步计划
### 阶段3: 信号层迁移
- [ ] TopNSelector - Top N 选股器
- [ ] DynamicThreshold - 动态阈值V3逻辑
- [ ] RebalanceController - 调仓控制器
- [ ] 信号对比验证测试
### 阶段4: 执行层迁移
- [ ] BacktestRunner - 回测执行器
- [ ] 收益计算对比测试
### 阶段5: 数据层迁移
- [ ] RotationDataFetcher - 轮动策略数据获取器
- [ ] CrossMarketAligner - 跨市场对齐器
### 阶段6: 策略组装
- [ ] RotationStrategyV2 - 新框架轮动策略
- [ ] 完整策略对比测试
---
## 🔧 使用方法
### 运行测试
```bash
# 运行因子对比测试
python framework_v2/tests/test_momentum_parity.py
# 运行所有测试
python -m pytest framework_v2/tests/
```
### 使用新因子
```python
from framework_v2.shared.factors import MomentumFactor
# 创建因子
factor = MomentumFactor(n_days=25, weighted=True, crash_filter=True)
# 计算因子值
import pandas as pd
data = pd.DataFrame({'close': [...]}, index=[...])
factor_values = factor.compute(data)
```
---
## 📊 与旧框架对比
| 维度 | 旧框架 (framework/) | 新框架 (framework_v2/) |
|------|---------------------|------------------------|
| **架构** | 抽象+实现混杂 | 三层分离core/shared/tests |
| **因子** | 独立实现 | TALibFactorBase + 定制继承 |
| **信号** | 包含所有逻辑 | 拆分为 Signal + Threshold + Rebalance |
| **数据** | 耦合在策略中 | DataFetcher 抽象 |
| **测试** | 部分覆盖 | 每个组件必须有对比测试 |
| **状态** | 生产环境 ✓ | 开发中 🚧 |
---
## ⚠️ 注意事项
1. **talib 依赖**TALibFactorBase 需要安装 `ta-lib`,但未安装不影响 MomentumFactor 使用
2. **并行开发**:新框架与旧框架并行,不修改现有代码
3. **验证优先**:每个模块迁移后立即验证,确保结果一致
---
*创建日期: 2026-05-06*
*版本: 2.0.0*

15
framework_v2/__init__.py Normal file
View File

@@ -0,0 +1,15 @@
"""
框架 V2 - 重构版本
三层架构:
├── core/ # 纯抽象接口(零实现)
├── shared/ # 通用实现2+策略复用)
└── tests/ # 框架测试
设计原则:
├── 按需抽象,不预先设计
├── 只放通用逻辑,定制逻辑在 strategies/
└── 每个组件必须有测试
"""
__version__ = "2.0.0"

View File

@@ -0,0 +1,19 @@
"""
核心抽象接口层纯ABC零实现
只定义策略框架的标准接口,不包含任何业务逻辑
"""
from framework_v2.core.strategy import StrategyBase
from framework_v2.core.factor import FactorBase
from framework_v2.core.signal import SignalGenerator
from framework_v2.core.executor import Executor
from framework_v2.core.data import DataFetcher
__all__ = [
'StrategyBase',
'FactorBase',
'SignalGenerator',
'Executor',
'DataFetcher',
]

97
framework_v2/core/data.py Normal file
View File

@@ -0,0 +1,97 @@
"""
数据获取器抽象基类
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
import pandas as pd
class DataFetcher(ABC):
"""
数据获取器抽象基类
所有数据获取器必须实现必要方法
"""
name: str = "base"
def __init__(self, **params):
"""
初始化数据获取器参数
Args:
**params: 数据源参数(如 api_url, ssh_config 等)
"""
self._params = params
@abstractmethod
def fetch_indices(
self,
codes: List[str],
start: str,
end: str
) -> Dict[str, pd.DataFrame]:
"""
获取指数 OHLCV 数据
Args:
codes: 指数代码列表
start: 开始日期 (YYYY-MM-DD)
end: 结束日期 (YYYY-MM-DD)
Returns:
{code: DataFrame} 字典
"""
pass
@abstractmethod
def fetch_etf(
self,
codes: List[str],
start: str,
end: str
) -> Dict[str, pd.DataFrame]:
"""
获取 ETF 数据(价格 + 净值)
Args:
codes: ETF 代码列表
start: 开始日期
end: 结束日期
Returns:
{code: DataFrame} 字典
"""
pass
@abstractmethod
def get_trading_calendar(self, market: str = 'A') -> pd.Index:
"""
获取交易日历
Args:
market: 市场代码('A', 'US', 'HK' 等)
Returns:
交易日历 Index
"""
pass
def get_benchmark(self, code: str, start: str, end: str) -> pd.Series:
"""
获取基准数据(可选)
Args:
code: 基准代码
start: 开始日期
end: 结束日期
Returns:
基准收盘价 Series
"""
raise NotImplementedError("Optional method")
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}(name={self.name})"

View File

@@ -0,0 +1,46 @@
"""
执行器抽象基类
"""
from abc import ABC, abstractmethod
import pandas as pd
class Executor(ABC):
"""
执行器抽象基类
所有执行器必须实现 execute 方法
"""
mode: str = "base"
def __init__(self, **params):
"""
初始化执行器参数
Args:
**params: 执行参数(如 initial_capital, trade_cost 等)
"""
self._params = params
@abstractmethod
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> dict:
"""
执行信号
Args:
signals: 信号 DataFrame
data: 收益率数据 DataFrame
Returns:
回测结果字典,包含:
- result: 回测 DataFrame含净值、收益率
- portfolio: 组合对象(可选)
- metrics: 绩效指标(可选)
"""
pass
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"

View File

@@ -0,0 +1,59 @@
"""
因子抽象基类
"""
from abc import ABC, abstractmethod
import pandas as pd
class FactorBase(ABC):
"""
因子抽象基类
所有因子必须实现 compute 方法
"""
name: str = "base"
category: str = "unknown"
def __init__(self, **params):
"""
初始化因子参数
Args:
**params: 因子参数(如 n_days, weighted 等)
"""
self._params = params
@abstractmethod
def compute(self, data: pd.DataFrame) -> pd.Series:
"""
计算因子值
Args:
data: OHLCV 数据,必须包含 'close'
Returns:
因子值序列(与 data 同索引)
"""
pass
def validate_data(self, data: pd.DataFrame) -> bool:
"""
验证数据是否满足计算要求
Args:
data: OHLCV 数据
Returns:
True 如果数据有效
"""
if 'close' not in data.columns:
return False
min_periods = self._params.get('min_periods', 20)
return len(data) >= min_periods
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"

View File

@@ -0,0 +1,57 @@
"""
信号生成器抽象基类
"""
from abc import ABC, abstractmethod
import pandas as pd
class SignalGenerator(ABC):
"""
信号生成器抽象基类
所有信号生成器必须实现 generate 方法
"""
mode: str = "base"
def __init__(self, **params):
"""
初始化信号生成器参数
Args:
**params: 信号参数(如 select_num, rebalance_days 等)
"""
self._params = params
@abstractmethod
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
"""
生成交易信号
Args:
factor_data: 因子数据 DataFrame
Returns:
信号 DataFrame必须包含 'signal'
"""
pass
def validate_factor_data(self, factor_data: pd.DataFrame) -> bool:
"""
验证因子数据是否有效
Args:
factor_data: 因子数据
Returns:
True 如果数据有效
"""
if factor_data.empty:
return False
return True
def __repr__(self) -> str:
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
return f"{self.__class__.__name__}({params_str})"

View File

@@ -0,0 +1,151 @@
"""
策略抽象基类
所有策略必须继承此类并实现必要方法
"""
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import pandas as pd
class StrategyBase(ABC):
"""
策略抽象基类
定义策略的标准生命周期:
1. 初始化配置
2. 获取数据
3. 计算因子
4. 生成信号
5. 执行回测
子类必须实现:
- init_factors(): 初始化因子
- init_signal_generator(): 初始化信号生成器
"""
INTERFACE_VERSION = 2 # V2 版本
name: str = "base"
timeframe: str = "1d"
def __init__(self, config: Optional[Dict] = None):
"""
初始化策略
Args:
config: 策略配置字典
"""
self.config = config or {}
self._factor = None
self._signal_generator = None
@abstractmethod
def init_factors(self) -> Any:
"""
初始化因子组件
Returns:
因子实例(继承 FactorBase
"""
pass
@abstractmethod
def init_signal_generator(self) -> Any:
"""
初始化信号生成器
Returns:
信号生成器实例(继承 SignalGenerator
"""
pass
def get_data(self) -> Dict[str, Any]:
"""
获取数据(可选覆盖)
Returns:
数据字典,包含:
- index_data: 指数数据
- etf_data: ETF数据
- benchmark_data: 基准数据
- valid_codes: 有效标的列表
- trading_calendar: 交易日历
"""
raise NotImplementedError("Subclasses must implement get_data()")
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
"""
计算因子(可选覆盖)
Args:
data: 数据字典
Returns:
因子 DataFrame日期 × 标的)
"""
if self._factor is None:
self._factor = self.init_factors()
# 默认实现:遍历标的计算因子
factor_values = {}
for code in data.get('valid_codes', []):
if code in data.get('index_data', {}):
factor_values[code] = self._factor.compute(data['index_data'][code])
return pd.DataFrame(factor_values)
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
"""
生成信号
Args:
factor_df: 因子 DataFrame
Returns:
信号 DataFrame包含 'signal' 列)
"""
if self._signal_generator is None:
self._signal_generator = self.init_signal_generator()
return self._signal_generator.generate(factor_df)
def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
"""
运行完整回测流程
Args:
data: 可选,如不提供则自动获取
Returns:
回测结果字典
"""
# 1. 获取数据
if data is None:
data = self.get_data()
# 2. 计算因子
factor_df = self.compute_factors(data)
# 3. 生成信号
signals = self.generate_signals(factor_df)
# 4. 执行回测(子类实现)
return self._execute_backtest(signals, data)
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
"""
执行回测(子类可覆盖)
Args:
signals: 信号 DataFrame
data: 数据字典
Returns:
回测结果
"""
raise NotImplementedError("Subclasses must implement _execute_backtest()")
def __repr__(self) -> str:
return f"{self.__class__.__name__}(name={self.name})"

View File

@@ -0,0 +1,9 @@
"""
通用实现层2+ 策略复用的组件)
包含:
├── factors/ # 通用因子
├── signals/ # 通用信号生成器
├── execution/ # 通用执行器
└── data/ # 通用数据处理
"""

View File

@@ -0,0 +1,17 @@
"""
通用因子实现
"""
from framework_v2.shared.factors.momentum import MomentumFactor
# TALibFactorBase 需要安装 talib可选导入
try:
from framework_v2.shared.factors.talib_base import TALibFactorBase
__all__ = [
'TALibFactorBase',
'MomentumFactor',
]
except ImportError:
__all__ = [
'MomentumFactor',
]

View File

@@ -0,0 +1,104 @@
"""
动量因子(通用版本)
使用加权线性回归:得分 = 年化收益率 ×
与现有 MomentumFactor 对比验证:
- 输入相同 → 输出应该相同
"""
import pandas as pd
import numpy as np
import math
from framework_v2.core import FactorBase
class MomentumFactor(FactorBase):
"""
动量因子
计算加权线性回归动量得分:
得分 = 年化收益率 ×
参数:
- n_days: 动量窗口默认25
- weighted: 是否加权默认True
- crash_filter: 是否启用崩盘过滤默认True
"""
name = "momentum"
category = "momentum"
def __init__(
self,
n_days: int = 25,
weighted: bool = True,
crash_filter: bool = True
):
super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
self.n_days = n_days
self.weighted = weighted
self.crash_filter = crash_filter
def compute(self, data: pd.DataFrame) -> pd.Series:
"""计算动量因子值"""
if 'close' not in data.columns:
raise ValueError("data must contain 'close' column")
prices = data['close']
if self.weighted:
factor_values = prices.rolling(self.n_days).apply(
lambda x: self._weighted_momentum_score(x.values),
raw=False
)
else:
factor_values = prices.pct_change(self.n_days)
if self.crash_filter:
factor_values = self._apply_crash_filter(prices, factor_values)
return factor_values
def _weighted_momentum_score(self, prices: np.ndarray) -> float:
"""计算加权动量得分(完全复制现有逻辑)"""
if len(prices) < 5:
return 0.0
# 价格下界 clip防止 log(0) 或 log(负数)
prices = np.clip(prices, 0.01, None)
y = np.log(prices)
# 异常值检测
if np.any(np.isnan(y)) or np.any(np.isinf(y)):
return 0.0
x = np.arange(len(y))
weights = np.linspace(1, 2, len(y))
slope, intercept = np.polyfit(x, y, 1, w=weights)
annualized_returns = math.exp(slope * 250) - 1
y_pred = slope * x + intercept
ss_res = np.sum(weights * (y - y_pred) ** 2)
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
return annualized_returns * r2
def _apply_crash_filter(self, prices: pd.Series, factor_values: pd.Series) -> pd.Series:
"""崩盘过滤连续3天跌>5%清零(完全复制现有逻辑)"""
result = factor_values.copy()
for i in range(3, len(prices)):
r1 = prices.iloc[i] / prices.iloc[i-1]
r2 = prices.iloc[i-1] / prices.iloc[i-2]
r3 = prices.iloc[i-2] / prices.iloc[i-3]
con1 = min(r1, r2, r3) < 0.95
con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95)
if con1 or con2:
result.iloc[i] = 0.0
return result

View File

@@ -0,0 +1,55 @@
"""
ta-lib 因子基类(通用)
所有 ta-lib 因子继承此类,只需指定函数和参数
"""
import talib
import pandas as pd
import numpy as np
from framework_v2.core import FactorBase
class TALibFactorBase(FactorBase):
"""
ta-lib 因子基类
子类只需实现:
- name: 因子名称
- _talib_func: 返回 ta-lib 函数
"""
category = "technical"
def __init__(self, period: int = 14, **params):
"""
初始化
Args:
period: 周期参数
**params: 其他参数
"""
super().__init__(period=period, **params)
self.period = period
def compute(self, data: pd.DataFrame) -> pd.Series:
"""
计算因子值
Args:
data: OHLCV 数据
Returns:
因子值序列
"""
close = data['close'].values.astype(float)
# 调用子类指定的 ta-lib 函数
result = self._talib_func(close, timeperiod=self.period)
return pd.Series(result, index=data.index, name=self.name)
@property
def _talib_func(self):
"""子类必须实现,返回 ta-lib 函数"""
raise NotImplementedError("Subclasses must implement _talib_func")

View File

@@ -0,0 +1,3 @@
"""
框架 V2 测试
"""

View File

@@ -0,0 +1,116 @@
"""
因子对比验证测试
验证新框架的 MomentumFactor 与现有实现输出一致
"""
import sys
import pandas as pd
import numpy as np
from pathlib import Path
# 添加项目根目录
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
def test_momentum_factor_parity():
"""验证新因子与旧因子输出一致"""
print("=" * 60)
print(" MomentumFactor 对比测试")
print("=" * 60)
# 1. 加载测试数据
print("\n1. 加载测试数据...")
test_data_path = project_root / 'data' / 'index_history_data'
# 使用纳指100数据测试
import glob
ndx_files = glob.glob(str(test_data_path / '*NDX*'))
if ndx_files:
test_file = ndx_files[0]
data = pd.read_csv(test_file, index_col=0, parse_dates=True)
print(f" ✓ 加载 {test_file}")
print(f" 数据范围: {data.index[0]} ~ {data.index[-1]}")
print(f" 数据长度: {len(data)}")
else:
print(" ⚠ 未找到测试数据,使用模拟数据")
# 生成模拟数据
np.random.seed(42)
dates = pd.date_range('2020-01-01', periods=500, freq='B')
prices = 100 * np.cumprod(1 + np.random.randn(500) * 0.02)
data = pd.DataFrame({
'close': prices,
'open': prices * 0.99,
'high': prices * 1.01,
'low': prices * 0.98,
'volume': np.random.randint(1000000, 10000000, 500)
}, index=dates)
# 2. 计算旧因子
print("\n2. 计算旧因子strategies/shared/factors/momentum.py...")
from strategies.shared.factors.momentum import MomentumFactor as OldMomentum
old_factor = OldMomentum(n_days=25, weighted=True, crash_filter=True)
old_result = old_factor.compute(data)
print(f" ✓ 旧因子计算完成")
print(f" 结果范围: {old_result.min():.4f} ~ {old_result.max():.4f}")
print(f" NaN 数量: {old_result.isna().sum()}")
# 3. 计算新因子
print("\n3. 计算新因子framework_v2/shared/factors/momentum.py...")
from framework_v2.shared.factors.momentum import MomentumFactor as NewMomentum
new_factor = NewMomentum(n_days=25, weighted=True, crash_filter=True)
new_result = new_factor.compute(data)
print(f" ✓ 新因子计算完成")
print(f" 结果范围: {new_result.min():.4f} ~ {new_result.max():.4f}")
print(f" NaN 数量: {new_result.isna().sum()}")
# 4. 对比结果
print("\n4. 对比结果...")
# 检查索引是否一致
if not old_result.index.equals(new_result.index):
print(" ✗ 索引不一致")
return False
print(" ✓ 索引一致")
# 检查数值差异
diff = (old_result - new_result).abs()
max_diff = diff.max()
mean_diff = diff.mean()
print(f" 最大差异: {max_diff:.6e}")
print(f" 平均差异: {mean_diff:.6e}")
# 允许浮点数精度误差1e-10
tolerance = 1e-10
if max_diff < tolerance:
print(f" ✓ 差异在容差范围内 (< {tolerance:.0e})")
print("\n" + "=" * 60)
print(" ✓ 测试通过:新旧因子输出完全一致!")
print("=" * 60)
return True
else:
print(f" ✗ 差异超出容差范围")
print("\n" + "=" * 60)
print(" ✗ 测试失败:新旧因子输出不一致")
print("=" * 60)
# 打印前10个差异点
diff_nonzero = diff[diff > tolerance]
if len(diff_nonzero) > 0:
print(f"\n 前10个差异点:")
for date, val in diff_nonzero.head(10).items():
old_val = old_result.loc[date]
new_val = new_result.loc[date]
print(f" {date}: 旧={old_val:.6f}, 新={new_val:.6f}, 差异={val:.6e}")
return False
if __name__ == '__main__':
success = test_momentum_factor_parity()
sys.exit(0 if success else 1)