feat(framework_v2): 创建框架V2骨架 - 三层架构+因子验证通过
## 架构设计 - 三层架构:core(抽象接口) → shared(通用实现) → tests(验证测试) - 5个核心抽象基类:StrategyBase, FactorBase, SignalGenerator, Executor, DataFetcher - 零侵入:与现有框架并行开发,不修改生产代码 ## 已完成 ✓ 核心接口层(5个ABC类) ✓ 通用因子层(MomentumFactor完全复制现有逻辑) ✓ 对比验证测试(新旧因子输出差异=0,测试通过) ## 验证结果 - 最大差异: 0.000000e+00 - 平均差异: 0.000000e+00 - 容差: < 1e-10 ## 下一步 - 阶段3: 信号层迁移(TopNSelector, DynamicThreshold, RebalanceController) - 阶段4: 执行层迁移(BacktestRunner) - 阶段5: 数据层迁移(DataFetcher实现) - 阶段6: 完整策略对比验证 ## 设计原则 - 按需抽象,不预先设计 - 职责分离,避免框架膨胀 - 测试驱动,每个组件必须有对比测试 - 渐进式迁移,验证通过再替换
This commit is contained in:
19
framework_v2/core/__init__.py
Normal file
19
framework_v2/core/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
核心抽象接口层(纯ABC,零实现)
|
||||
|
||||
只定义策略框架的标准接口,不包含任何业务逻辑
|
||||
"""
|
||||
|
||||
from framework_v2.core.strategy import StrategyBase
|
||||
from framework_v2.core.factor import FactorBase
|
||||
from framework_v2.core.signal import SignalGenerator
|
||||
from framework_v2.core.executor import Executor
|
||||
from framework_v2.core.data import DataFetcher
|
||||
|
||||
__all__ = [
|
||||
'StrategyBase',
|
||||
'FactorBase',
|
||||
'SignalGenerator',
|
||||
'Executor',
|
||||
'DataFetcher',
|
||||
]
|
||||
97
framework_v2/core/data.py
Normal file
97
framework_v2/core/data.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
数据获取器抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class DataFetcher(ABC):
|
||||
"""
|
||||
数据获取器抽象基类
|
||||
|
||||
所有数据获取器必须实现必要方法
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""
|
||||
初始化数据获取器参数
|
||||
|
||||
Args:
|
||||
**params: 数据源参数(如 api_url, ssh_config 等)
|
||||
"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def fetch_indices(
|
||||
self,
|
||||
codes: List[str],
|
||||
start: str,
|
||||
end: str
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取指数 OHLCV 数据
|
||||
|
||||
Args:
|
||||
codes: 指数代码列表
|
||||
start: 开始日期 (YYYY-MM-DD)
|
||||
end: 结束日期 (YYYY-MM-DD)
|
||||
|
||||
Returns:
|
||||
{code: DataFrame} 字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def fetch_etf(
|
||||
self,
|
||||
codes: List[str],
|
||||
start: str,
|
||||
end: str
|
||||
) -> Dict[str, pd.DataFrame]:
|
||||
"""
|
||||
获取 ETF 数据(价格 + 净值)
|
||||
|
||||
Args:
|
||||
codes: ETF 代码列表
|
||||
start: 开始日期
|
||||
end: 结束日期
|
||||
|
||||
Returns:
|
||||
{code: DataFrame} 字典
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_trading_calendar(self, market: str = 'A') -> pd.Index:
|
||||
"""
|
||||
获取交易日历
|
||||
|
||||
Args:
|
||||
market: 市场代码('A', 'US', 'HK' 等)
|
||||
|
||||
Returns:
|
||||
交易日历 Index
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_benchmark(self, code: str, start: str, end: str) -> pd.Series:
|
||||
"""
|
||||
获取基准数据(可选)
|
||||
|
||||
Args:
|
||||
code: 基准代码
|
||||
start: 开始日期
|
||||
end: 结束日期
|
||||
|
||||
Returns:
|
||||
基准收盘价 Series
|
||||
"""
|
||||
raise NotImplementedError("Optional method")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
46
framework_v2/core/executor.py
Normal file
46
framework_v2/core/executor.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
执行器抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class Executor(ABC):
|
||||
"""
|
||||
执行器抽象基类
|
||||
|
||||
所有执行器必须实现 execute 方法
|
||||
"""
|
||||
|
||||
mode: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""
|
||||
初始化执行器参数
|
||||
|
||||
Args:
|
||||
**params: 执行参数(如 initial_capital, trade_cost 等)
|
||||
"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, signals: pd.DataFrame, data: pd.DataFrame) -> dict:
|
||||
"""
|
||||
执行信号
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
data: 收益率数据 DataFrame
|
||||
|
||||
Returns:
|
||||
回测结果字典,包含:
|
||||
- result: 回测 DataFrame(含净值、收益率)
|
||||
- portfolio: 组合对象(可选)
|
||||
- metrics: 绩效指标(可选)
|
||||
"""
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
59
framework_v2/core/factor.py
Normal file
59
framework_v2/core/factor.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
因子抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class FactorBase(ABC):
|
||||
"""
|
||||
因子抽象基类
|
||||
|
||||
所有因子必须实现 compute 方法
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
category: str = "unknown"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""
|
||||
初始化因子参数
|
||||
|
||||
Args:
|
||||
**params: 因子参数(如 n_days, weighted 等)
|
||||
"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
计算因子值
|
||||
|
||||
Args:
|
||||
data: OHLCV 数据,必须包含 'close' 列
|
||||
|
||||
Returns:
|
||||
因子值序列(与 data 同索引)
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_data(self, data: pd.DataFrame) -> bool:
|
||||
"""
|
||||
验证数据是否满足计算要求
|
||||
|
||||
Args:
|
||||
data: OHLCV 数据
|
||||
|
||||
Returns:
|
||||
True 如果数据有效
|
||||
"""
|
||||
if 'close' not in data.columns:
|
||||
return False
|
||||
|
||||
min_periods = self._params.get('min_periods', 20)
|
||||
return len(data) >= min_periods
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
57
framework_v2/core/signal.py
Normal file
57
framework_v2/core/signal.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
信号生成器抽象基类
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class SignalGenerator(ABC):
|
||||
"""
|
||||
信号生成器抽象基类
|
||||
|
||||
所有信号生成器必须实现 generate 方法
|
||||
"""
|
||||
|
||||
mode: str = "base"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""
|
||||
初始化信号生成器参数
|
||||
|
||||
Args:
|
||||
**params: 信号参数(如 select_num, rebalance_days 等)
|
||||
"""
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, factor_data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
生成交易信号
|
||||
|
||||
Args:
|
||||
factor_data: 因子数据 DataFrame
|
||||
|
||||
Returns:
|
||||
信号 DataFrame,必须包含 'signal' 列
|
||||
"""
|
||||
pass
|
||||
|
||||
def validate_factor_data(self, factor_data: pd.DataFrame) -> bool:
|
||||
"""
|
||||
验证因子数据是否有效
|
||||
|
||||
Args:
|
||||
factor_data: 因子数据
|
||||
|
||||
Returns:
|
||||
True 如果数据有效
|
||||
"""
|
||||
if factor_data.empty:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __repr__(self) -> str:
|
||||
params_str = ', '.join([f"{k}={v}" for k, v in self._params.items()])
|
||||
return f"{self.__class__.__name__}({params_str})"
|
||||
151
framework_v2/core/strategy.py
Normal file
151
framework_v2/core/strategy.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
策略抽象基类
|
||||
|
||||
所有策略必须继承此类并实现必要方法
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional, Any
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class StrategyBase(ABC):
|
||||
"""
|
||||
策略抽象基类
|
||||
|
||||
定义策略的标准生命周期:
|
||||
1. 初始化配置
|
||||
2. 获取数据
|
||||
3. 计算因子
|
||||
4. 生成信号
|
||||
5. 执行回测
|
||||
|
||||
子类必须实现:
|
||||
- init_factors(): 初始化因子
|
||||
- init_signal_generator(): 初始化信号生成器
|
||||
"""
|
||||
|
||||
INTERFACE_VERSION = 2 # V2 版本
|
||||
|
||||
name: str = "base"
|
||||
timeframe: str = "1d"
|
||||
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
"""
|
||||
初始化策略
|
||||
|
||||
Args:
|
||||
config: 策略配置字典
|
||||
"""
|
||||
self.config = config or {}
|
||||
self._factor = None
|
||||
self._signal_generator = None
|
||||
|
||||
@abstractmethod
|
||||
def init_factors(self) -> Any:
|
||||
"""
|
||||
初始化因子组件
|
||||
|
||||
Returns:
|
||||
因子实例(继承 FactorBase)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def init_signal_generator(self) -> Any:
|
||||
"""
|
||||
初始化信号生成器
|
||||
|
||||
Returns:
|
||||
信号生成器实例(继承 SignalGenerator)
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_data(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取数据(可选覆盖)
|
||||
|
||||
Returns:
|
||||
数据字典,包含:
|
||||
- index_data: 指数数据
|
||||
- etf_data: ETF数据
|
||||
- benchmark_data: 基准数据
|
||||
- valid_codes: 有效标的列表
|
||||
- trading_calendar: 交易日历
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement get_data()")
|
||||
|
||||
def compute_factors(self, data: Dict[str, Any]) -> pd.DataFrame:
|
||||
"""
|
||||
计算因子(可选覆盖)
|
||||
|
||||
Args:
|
||||
data: 数据字典
|
||||
|
||||
Returns:
|
||||
因子 DataFrame(日期 × 标的)
|
||||
"""
|
||||
if self._factor is None:
|
||||
self._factor = self.init_factors()
|
||||
|
||||
# 默认实现:遍历标的计算因子
|
||||
factor_values = {}
|
||||
for code in data.get('valid_codes', []):
|
||||
if code in data.get('index_data', {}):
|
||||
factor_values[code] = self._factor.compute(data['index_data'][code])
|
||||
|
||||
return pd.DataFrame(factor_values)
|
||||
|
||||
def generate_signals(self, factor_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
生成信号
|
||||
|
||||
Args:
|
||||
factor_df: 因子 DataFrame
|
||||
|
||||
Returns:
|
||||
信号 DataFrame(包含 'signal' 列)
|
||||
"""
|
||||
if self._signal_generator is None:
|
||||
self._signal_generator = self.init_signal_generator()
|
||||
|
||||
return self._signal_generator.generate(factor_df)
|
||||
|
||||
def run_backtest(self, data: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
运行完整回测流程
|
||||
|
||||
Args:
|
||||
data: 可选,如不提供则自动获取
|
||||
|
||||
Returns:
|
||||
回测结果字典
|
||||
"""
|
||||
# 1. 获取数据
|
||||
if data is None:
|
||||
data = self.get_data()
|
||||
|
||||
# 2. 计算因子
|
||||
factor_df = self.compute_factors(data)
|
||||
|
||||
# 3. 生成信号
|
||||
signals = self.generate_signals(factor_df)
|
||||
|
||||
# 4. 执行回测(子类实现)
|
||||
return self._execute_backtest(signals, data)
|
||||
|
||||
def _execute_backtest(self, signals: pd.DataFrame, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
执行回测(子类可覆盖)
|
||||
|
||||
Args:
|
||||
signals: 信号 DataFrame
|
||||
data: 数据字典
|
||||
|
||||
Returns:
|
||||
回测结果
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement _execute_backtest()")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name})"
|
||||
Reference in New Issue
Block a user