feat(factors): 实现因子层抽象
核心组件: - FactorBase: 因子抽象基类(compute方法 + 数据验证) - FactorRegistry: 因子注册器(注册/获取/按类别筛选) - FactorCombiner: 因子组合器(加权组合4种方法) 已实现因子: - MomentumFactor: 加权动量因子(含崩盘过滤) - TrendFactor: 趋势因子(MA交叉/MACD) - ReversalFactor: 反转因子(RSI/KDJ) - VolatilityFactor: 波动率因子(ATR/标准差) 测试覆盖:18个测试全部通过
This commit is contained in:
282
framework/factors/__init__.py
Normal file
282
framework/factors/__init__.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
因子层抽象设计
|
||||
|
||||
核心组件:
|
||||
- FactorBase: 因子抽象基类
|
||||
- FactorRegistry: 因子注册器
|
||||
- FactorCombiner: 因子组合器
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactorMeta:
|
||||
"""因子元信息"""
|
||||
name: str
|
||||
category: str # 'momentum', 'trend', 'reversal', 'volatility', 'fundamental'
|
||||
params: Dict[str, Any]
|
||||
description: str = ""
|
||||
|
||||
|
||||
class FactorBase(ABC):
|
||||
"""
|
||||
因子抽象基类
|
||||
|
||||
所有因子必须继承此基类,实现compute方法。
|
||||
支持参数配置、数据验证、元信息管理。
|
||||
"""
|
||||
|
||||
# 类属性(可被配置覆盖)
|
||||
name: str = "base"
|
||||
category: str = "unknown"
|
||||
|
||||
def __init__(self, **params):
|
||||
"""
|
||||
初始化因子
|
||||
|
||||
Args:
|
||||
**params: 因子参数(如n_days=25, period=14等)
|
||||
"""
|
||||
self._params = params
|
||||
self._meta = FactorMeta(
|
||||
name=self.name,
|
||||
category=self.category,
|
||||
params=params,
|
||||
description=self.__doc__ or ""
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
计算因子值
|
||||
|
||||
Args:
|
||||
data: 包含OHLCV数据的DataFrame
|
||||
|
||||
Returns:
|
||||
因子值序列(Series)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def params(self) -> Dict[str, Any]:
|
||||
"""获取因子参数"""
|
||||
return self._params
|
||||
|
||||
@property
|
||||
def meta(self) -> FactorMeta:
|
||||
"""获取因子元信息"""
|
||||
return self._meta
|
||||
|
||||
def validate_data(self, data: pd.DataFrame) -> bool:
|
||||
"""
|
||||
验证数据是否满足计算要求
|
||||
|
||||
Args:
|
||||
data: 数据DataFrame
|
||||
|
||||
Returns:
|
||||
是否满足要求
|
||||
"""
|
||||
# 默认验证:数据长度 >= 最小周期
|
||||
min_periods = self._params.get('min_periods', 20)
|
||||
return len(data) >= min_periods
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(name={self.name}, params={self._params})"
|
||||
|
||||
|
||||
class FactorRegistry:
|
||||
"""
|
||||
因子注册器
|
||||
|
||||
管理所有注册的因子,支持:
|
||||
- 注册因子类
|
||||
- 获取因子实例
|
||||
- 列出可用因子
|
||||
- 按类别筛选因子
|
||||
"""
|
||||
|
||||
_factors: Dict[str, type] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, factor_class: type) -> None:
|
||||
"""
|
||||
注册因子类
|
||||
|
||||
Args:
|
||||
factor_class: 因子类(必须继承FactorBase)
|
||||
"""
|
||||
if not isinstance(factor_class, type) or not issubclass(factor_class, FactorBase):
|
||||
raise TypeError(f"factor_class must be a subclass of FactorBase")
|
||||
|
||||
# 创建临时实例获取名称
|
||||
temp_instance = factor_class()
|
||||
name = temp_instance.name
|
||||
cls._factors[name] = factor_class
|
||||
print(f"✓ 因子已注册: {name} ({factor_class.__name__})")
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str, **params) -> FactorBase:
|
||||
"""
|
||||
获取因子实例
|
||||
|
||||
Args:
|
||||
name: 因子名称
|
||||
**params: 因子参数
|
||||
|
||||
Returns:
|
||||
因子实例
|
||||
"""
|
||||
if name not in cls._factors:
|
||||
raise KeyError(f"Factor '{name}' not registered. Available: {cls.list()}")
|
||||
|
||||
factor_class = cls._factors[name]
|
||||
return factor_class(**params)
|
||||
|
||||
@classmethod
|
||||
def list(cls, category: str = None) -> List[str]:
|
||||
"""
|
||||
列出可用因子
|
||||
|
||||
Args:
|
||||
category: 按类别筛选(可选)
|
||||
|
||||
Returns:
|
||||
因子名称列表
|
||||
"""
|
||||
if category:
|
||||
return [
|
||||
name for name, factor_class in cls._factors.items()
|
||||
if factor_class().category == category
|
||||
]
|
||||
return list(cls._factors.keys())
|
||||
|
||||
@classmethod
|
||||
def list_by_category(cls) -> Dict[str, List[str]]:
|
||||
"""
|
||||
按类别列出因子
|
||||
|
||||
Returns:
|
||||
类别→因子列表字典
|
||||
"""
|
||||
result = {}
|
||||
for name, factor_class in cls._factors.items():
|
||||
cat = factor_class().category
|
||||
if cat not in result:
|
||||
result[cat] = []
|
||||
result[cat].append(name)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def clear(cls) -> None:
|
||||
"""清空注册表(用于测试)"""
|
||||
cls._factors.clear()
|
||||
|
||||
|
||||
class FactorCombiner:
|
||||
"""
|
||||
因子组合器
|
||||
|
||||
支持多因子加权组合,用于:
|
||||
- 多因子策略
|
||||
- 因子权重调整
|
||||
- 因子结果合并
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
factors: List[FactorBase],
|
||||
weights: Optional[List[float]] = None,
|
||||
method: str = 'weighted_sum'
|
||||
):
|
||||
"""
|
||||
初始化因子组合器
|
||||
|
||||
Args:
|
||||
factors: 因子实例列表
|
||||
weights: 权重列表(默认等权)
|
||||
method: 组合方法 ('weighted_sum', 'average', 'max', 'min')
|
||||
"""
|
||||
self._factors = factors
|
||||
self._weights = weights or [1.0 / len(factors)] * len(factors)
|
||||
self._method = method
|
||||
|
||||
# 验证权重
|
||||
if len(self._weights) != len(factors):
|
||||
raise ValueError(f"weights length ({len(self._weights)}) != factors length ({len(factors)})")
|
||||
|
||||
# 归一化权重
|
||||
total_weight = sum(self._weights)
|
||||
self._weights = [w / total_weight for w in self._weights]
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
计算所有因子并组合
|
||||
|
||||
Args:
|
||||
data: 输入数据
|
||||
|
||||
Returns:
|
||||
包含各因子值和组合因子值的DataFrame
|
||||
"""
|
||||
result = pd.DataFrame(index=data.index)
|
||||
|
||||
# 计算各因子
|
||||
for i, factor in enumerate(self._factors):
|
||||
# 验证数据
|
||||
if not factor.validate_data(data):
|
||||
print(f"⚠ 因子 {factor.name} 数据验证失败,跳过")
|
||||
continue
|
||||
|
||||
# 计算因子值
|
||||
factor_values = factor.compute(data)
|
||||
result[factor.name] = factor_values
|
||||
|
||||
# 加权因子值
|
||||
result[f"{factor.name}_weighted"] = factor_values * self._weights[i]
|
||||
|
||||
# 组合因子值
|
||||
weighted_cols = [f"{f.name}_weighted" for f in self._factors if f.name in result.columns]
|
||||
|
||||
if self._method == 'weighted_sum':
|
||||
result['combined'] = result[weighted_cols].sum(axis=1)
|
||||
elif self._method == 'average':
|
||||
factor_cols = [f.name for f in self._factors if f.name in result.columns]
|
||||
result['combined'] = result[factor_cols].mean(axis=1)
|
||||
elif self._method == 'max':
|
||||
factor_cols = [f.name for f in self._factors if f.name in result.columns]
|
||||
result['combined'] = result[factor_cols].max(axis=1)
|
||||
elif self._method == 'min':
|
||||
factor_cols = [f.name for f in self._factors if f.name in result.columns]
|
||||
result['combined'] = result[factor_cols].min(axis=1)
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self._method}")
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def factors(self) -> List[FactorBase]:
|
||||
"""获取因子列表"""
|
||||
return self._factors
|
||||
|
||||
@property
|
||||
def weights(self) -> List[float]:
|
||||
"""获取权重列表"""
|
||||
return self._weights
|
||||
|
||||
def set_weights(self, weights: List[float]) -> None:
|
||||
"""设置权重"""
|
||||
if len(weights) != len(self._factors):
|
||||
raise ValueError(f"weights length must equal factors length")
|
||||
total = sum(weights)
|
||||
self._weights = [w / total for w in weights]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
factor_names = [f.name for f in self._factors]
|
||||
return f"FactorCombiner(factors={factor_names}, weights={self._weights})"
|
||||
312
framework/factors/momentum.py
Normal file
312
framework/factors/momentum.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
动量因子实现
|
||||
|
||||
基于加权线性回归动量的因子
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from framework.factors import FactorBase
|
||||
|
||||
|
||||
class MomentumFactor(FactorBase):
|
||||
"""
|
||||
动量因子
|
||||
|
||||
计算加权线性回归动量得分:
|
||||
得分 = 年化收益率 × R²
|
||||
|
||||
参数:
|
||||
- n_days: 动量窗口(默认25)
|
||||
- weighted: 是否加权(默认True)
|
||||
- crash_filter: 是否启用崩盘过滤(默认True)
|
||||
"""
|
||||
|
||||
name = "momentum"
|
||||
category = "momentum"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_days: int = 25,
|
||||
weighted: bool = True,
|
||||
crash_filter: bool = True
|
||||
):
|
||||
super().__init__(n_days=n_days, weighted=weighted, crash_filter=crash_filter)
|
||||
self.n_days = n_days
|
||||
self.weighted = weighted
|
||||
self.crash_filter = crash_filter
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算动量因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.weighted:
|
||||
# 加权动量得分
|
||||
factor_values = prices.rolling(self.n_days).apply(
|
||||
lambda x: self._weighted_momentum_score(x.values),
|
||||
raw=False
|
||||
)
|
||||
else:
|
||||
# 简单动量
|
||||
factor_values = prices.pct_change(self.n_days)
|
||||
|
||||
# 应用崩盘过滤
|
||||
if self.crash_filter:
|
||||
factor_values = self._apply_crash_filter(prices, factor_values)
|
||||
|
||||
return factor_values
|
||||
|
||||
def _weighted_momentum_score(self, prices: np.ndarray) -> float:
|
||||
"""计算加权动量得分"""
|
||||
if len(prices) < 5:
|
||||
return 0.0
|
||||
|
||||
y = np.log(prices)
|
||||
x = np.arange(len(y))
|
||||
weights = np.linspace(1, 2, len(y))
|
||||
|
||||
# 加权线性回归
|
||||
slope, intercept = np.polyfit(x, y, 1, w=weights)
|
||||
annualized_returns = math.exp(slope * 250) - 1
|
||||
|
||||
# 加权R²
|
||||
y_pred = slope * x + intercept
|
||||
ss_res = np.sum(weights * (y - y_pred) ** 2)
|
||||
ss_tot = np.sum(weights * (y - np.average(y, weights=weights)) ** 2)
|
||||
r2 = 1 - ss_res / ss_tot if ss_tot > 0 else 0
|
||||
|
||||
return annualized_returns * r2
|
||||
|
||||
def _apply_crash_filter(
|
||||
self,
|
||||
prices: pd.Series,
|
||||
factor_values: pd.Series
|
||||
) -> pd.Series:
|
||||
"""崩盘过滤:连续3天跌>5%清零"""
|
||||
result = factor_values.copy()
|
||||
|
||||
for i in range(3, len(prices)):
|
||||
r1 = prices.iloc[i] / prices.iloc[i-1]
|
||||
r2 = prices.iloc[i-1] / prices.iloc[i-2]
|
||||
r3 = prices.iloc[i-2] / prices.iloc[i-3]
|
||||
|
||||
# 条件1:任一天跌>5%
|
||||
con1 = min(r1, r2, r3) < 0.95
|
||||
# 条件2:连续下跌且累计跌>5%
|
||||
con2 = (r1 < 1) and (r2 < 1) and (r3 < 1) and (prices.iloc[i] / prices.iloc[i-3] < 0.95)
|
||||
|
||||
if con1 or con2:
|
||||
result.iloc[i] = 0.0
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class TrendFactor(FactorBase):
|
||||
"""
|
||||
趋势因子
|
||||
|
||||
计算趋势强度:
|
||||
- MA交叉偏离度
|
||||
- MACD趋势
|
||||
|
||||
参数:
|
||||
- method: 趋势方法('ma_cross', 'macd')
|
||||
- fast: 快线周期(默认5)
|
||||
- slow: 慢线周期(默认20)
|
||||
"""
|
||||
|
||||
name = "trend"
|
||||
category = "trend"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str = 'ma_cross',
|
||||
fast: int = 5,
|
||||
slow: int = 20
|
||||
):
|
||||
super().__init__(method=method, fast=fast, slow=slow)
|
||||
self.method = method
|
||||
self.fast = fast
|
||||
self.slow = slow
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算趋势因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.method == 'ma_cross':
|
||||
# MA交叉偏离度
|
||||
fast_ma = prices.rolling(self.fast).mean()
|
||||
slow_ma = prices.rolling(self.slow).mean()
|
||||
trend_strength = (fast_ma - slow_ma) / slow_ma
|
||||
return trend_strength
|
||||
|
||||
elif self.method == 'macd':
|
||||
# MACD趋势
|
||||
ema12 = prices.ewm(span=12).mean()
|
||||
ema26 = prices.ewm(span=26).mean()
|
||||
macd = ema12 - ema26
|
||||
signal = macd.ewm(span=9).mean()
|
||||
return macd - signal
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
|
||||
class ReversalFactor(FactorBase):
|
||||
"""
|
||||
反转因子
|
||||
|
||||
计算超买超卖信号:
|
||||
- RSI偏离度
|
||||
- KDJ
|
||||
|
||||
参数:
|
||||
- method: 反转方法('rsi', 'kdj')
|
||||
- period: 周期(默认14)
|
||||
- overbought: 超买阈值(默认70)
|
||||
- oversold: 超卖阈值(默认30)
|
||||
"""
|
||||
|
||||
name = "reversal"
|
||||
category = "reversal"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str = 'rsi',
|
||||
period: int = 14,
|
||||
overbought: float = 70,
|
||||
oversold: float = 30
|
||||
):
|
||||
super().__init__(method=method, period=period, overbought=overbought, oversold=oversold)
|
||||
self.method = method
|
||||
self.period = period
|
||||
self.overbought = overbought
|
||||
self.oversold = oversold
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算反转因子值"""
|
||||
if 'close' not in data.columns:
|
||||
raise ValueError("data must contain 'close' column")
|
||||
|
||||
prices = data['close']
|
||||
|
||||
if self.method == 'rsi':
|
||||
# RSI反转信号
|
||||
rsi = self._compute_rsi(prices, self.period)
|
||||
|
||||
# 超买超卖偏离度
|
||||
# 超买 → 负值(反转向下信号)
|
||||
# 超卖 → 正值(反转向上信号)
|
||||
reversal_signal = pd.Series(index=prices.index, dtype=float)
|
||||
reversal_signal = np.where(
|
||||
rsi > self.overbought,
|
||||
-(rsi - self.overbought) / (100 - self.overbought), # 超买:负值
|
||||
np.where(
|
||||
rsi < self.oversold,
|
||||
(self.oversold - rsi) / self.oversold, # 超卖:正值
|
||||
0 # 正常区间:0
|
||||
)
|
||||
)
|
||||
return pd.Series(reversal_signal, index=prices.index)
|
||||
|
||||
elif self.method == 'kdj':
|
||||
# KDJ反转信号
|
||||
return self._compute_kdj(data)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
def _compute_rsi(self, prices: pd.Series, period: int) -> pd.Series:
|
||||
"""计算RSI"""
|
||||
delta = prices.diff()
|
||||
gain = delta.where(delta > 0, 0)
|
||||
loss = (-delta).where(delta < 0, 0)
|
||||
|
||||
avg_gain = gain.rolling(period).mean()
|
||||
avg_loss = loss.rolling(period).mean()
|
||||
|
||||
rs = avg_gain / avg_loss
|
||||
rsi = 100 - (100 / (1 + rs))
|
||||
return rsi
|
||||
|
||||
def _compute_kdj(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算KDJ反转信号"""
|
||||
low = data['low']
|
||||
high = data['high']
|
||||
close = data['close']
|
||||
|
||||
# 计算K、D、J
|
||||
low_min = low.rolling(self.period).min()
|
||||
high_max = high.rolling(self.period).max()
|
||||
|
||||
rsv = (close - low_min) / (high_max - low_min) * 100
|
||||
|
||||
k = rsv.ewm(alpha=1/3).mean()
|
||||
d = k.ewm(alpha=1/3).mean()
|
||||
j = 3 * k - 2 * d
|
||||
|
||||
# J值偏离度作为反转信号
|
||||
return j
|
||||
|
||||
|
||||
class VolatilityFactor(FactorBase):
|
||||
"""
|
||||
波动率因子
|
||||
|
||||
计算价格波动率:
|
||||
- ATR
|
||||
- 标准差
|
||||
|
||||
参数:
|
||||
- method: 波动率方法('atr', 'std')
|
||||
- period: 周期(默认20)
|
||||
"""
|
||||
|
||||
name = "volatility"
|
||||
category = "volatility"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
method: str = 'std',
|
||||
period: int = 20
|
||||
):
|
||||
super().__init__(method=method, period=period)
|
||||
self.method = method
|
||||
self.period = period
|
||||
|
||||
def compute(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算波动率因子值"""
|
||||
if self.method == 'std':
|
||||
# 标准差波动率
|
||||
return data['close'].rolling(self.period).std()
|
||||
|
||||
elif self.method == 'atr':
|
||||
# ATR波动率
|
||||
return self._compute_atr(data)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown method: {self.method}")
|
||||
|
||||
def _compute_atr(self, data: pd.DataFrame) -> pd.Series:
|
||||
"""计算ATR"""
|
||||
high = data['high']
|
||||
low = data['low']
|
||||
close = data['close']
|
||||
|
||||
prev_close = close.shift(1)
|
||||
tr = pd.concat([
|
||||
high - low,
|
||||
(high - prev_close).abs(),
|
||||
(low - prev_close).abs()
|
||||
], axis=1).max(axis=1)
|
||||
|
||||
return tr.rolling(self.period).mean()
|
||||
282
framework/tests/test_factors.py
Normal file
282
framework/tests/test_factors.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
因子层测试
|
||||
|
||||
测试FactorBase、FactorRegistry、FactorCombiner
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from framework.factors import FactorBase, FactorRegistry, FactorCombiner
|
||||
from framework.factors.momentum import MomentumFactor, TrendFactor, ReversalFactor, VolatilityFactor
|
||||
|
||||
|
||||
class TestFactorBase:
|
||||
"""测试FactorBase抽象基类"""
|
||||
|
||||
def test_factor_meta(self):
|
||||
"""测试因子元信息"""
|
||||
factor = MomentumFactor(n_days=25)
|
||||
assert factor.name == "momentum"
|
||||
assert factor.category == "momentum"
|
||||
assert factor.params == {'n_days': 25, 'weighted': True, 'crash_filter': True}
|
||||
|
||||
def test_factor_repr(self):
|
||||
"""测试因子字符串表示"""
|
||||
factor = MomentumFactor(n_days=30)
|
||||
repr_str = repr(factor)
|
||||
assert "MomentumFactor" in repr_str
|
||||
assert "momentum" in repr_str
|
||||
|
||||
def test_validate_data(self):
|
||||
"""测试数据验证"""
|
||||
factor = MomentumFactor(n_days=25)
|
||||
|
||||
# 数据充足
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
})
|
||||
assert factor.validate_data(data) == True
|
||||
|
||||
# 数据不足
|
||||
short_data = pd.DataFrame({
|
||||
'close': np.random.randn(10).cumsum() + 100
|
||||
})
|
||||
assert factor.validate_data(short_data) == False
|
||||
|
||||
|
||||
class TestFactorRegistry:
|
||||
"""测试因子注册器"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_register_factor(self):
|
||||
"""测试因子注册"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
assert 'momentum' in FactorRegistry.list()
|
||||
|
||||
def test_get_factor(self):
|
||||
"""测试获取因子实例"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
factor = FactorRegistry.get('momentum', n_days=30)
|
||||
assert isinstance(factor, MomentumFactor)
|
||||
assert factor.n_days == 30
|
||||
|
||||
def test_get_unknown_factor(self):
|
||||
"""测试获取未注册因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
with pytest.raises(KeyError):
|
||||
FactorRegistry.get('unknown_factor')
|
||||
|
||||
def test_list_by_category(self):
|
||||
"""测试按类别列出因子"""
|
||||
FactorRegistry.register(MomentumFactor)
|
||||
FactorRegistry.register(TrendFactor)
|
||||
FactorRegistry.register(ReversalFactor)
|
||||
|
||||
categories = FactorRegistry.list_by_category()
|
||||
assert 'momentum' in categories
|
||||
assert 'trend' in categories
|
||||
assert 'reversal' in categories
|
||||
|
||||
def test_register_invalid_factor(self):
|
||||
"""测试注册无效因子"""
|
||||
with pytest.raises(TypeError):
|
||||
FactorRegistry.register(str) # 不是FactorBase子类
|
||||
|
||||
|
||||
class TestFactorCombiner:
|
||||
"""测试因子组合器"""
|
||||
|
||||
def setup_method(self):
|
||||
"""每个测试前清空注册表"""
|
||||
FactorRegistry.clear()
|
||||
|
||||
def test_combiner_init(self):
|
||||
"""测试组合器初始化"""
|
||||
factors = [
|
||||
MomentumFactor(n_days=25),
|
||||
TrendFactor(method='ma_cross')
|
||||
]
|
||||
combiner = FactorCombiner(factors, weights=[0.7, 0.3])
|
||||
|
||||
assert len(combiner.factors) == 2
|
||||
assert combiner.weights == [0.7, 0.3] # 未归一化时
|
||||
|
||||
def test_combiner_equal_weights(self):
|
||||
"""测试等权组合"""
|
||||
factors = [
|
||||
MomentumFactor(n_days=25),
|
||||
TrendFactor()
|
||||
]
|
||||
combiner = FactorCombiner(factors) # 默认等权
|
||||
|
||||
# 权重应该归一化
|
||||
assert sum(combiner.weights) == 1.0
|
||||
|
||||
def test_combiner_compute(self):
|
||||
"""测试因子组合计算"""
|
||||
# 生成测试数据
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100,
|
||||
'high': np.random.randn(100).cumsum() + 105,
|
||||
'low': np.random.randn(100).cumsum() + 95
|
||||
}, index=dates)
|
||||
|
||||
factors = [
|
||||
MomentumFactor(n_days=20),
|
||||
TrendFactor(fast=5, slow=10)
|
||||
]
|
||||
combiner = FactorCombiner(factors, weights=[0.6, 0.4])
|
||||
|
||||
result = combiner.compute(data)
|
||||
|
||||
# 检查结果列
|
||||
assert 'momentum' in result.columns
|
||||
assert 'trend' in result.columns
|
||||
assert 'combined' in result.columns
|
||||
|
||||
# 检查加权列
|
||||
assert 'momentum_weighted' in result.columns
|
||||
assert 'trend_weighted' in result.columns
|
||||
|
||||
def test_combiner_method_max(self):
|
||||
"""测试max组合方法"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
data = pd.DataFrame({
|
||||
'close': np.random.randn(100).cumsum() + 100
|
||||
}, index=dates)
|
||||
|
||||
factors = [
|
||||
MomentumFactor(n_days=20),
|
||||
TrendFactor()
|
||||
]
|
||||
combiner = FactorCombiner(factors, method='max')
|
||||
|
||||
result = combiner.compute(data)
|
||||
|
||||
# combined应该是momentum和trend的最大值
|
||||
factor_cols = ['momentum', 'trend']
|
||||
expected_max = result[factor_cols].max(axis=1)
|
||||
pd.testing.assert_series_equal(result['combined'], expected_max, check_names=False)
|
||||
|
||||
|
||||
class TestMomentumFactor:
|
||||
"""测试动量因子"""
|
||||
|
||||
def test_momentum_compute(self):
|
||||
"""测试动量因子计算"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成上升趋势数据
|
||||
prices = 100 + np.arange(100) * 0.5
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, weighted=True)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 上升趋势应该有正的动量得分
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
def test_crash_filter(self):
|
||||
"""测试崩盘过滤"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成正常数据,然后在末尾添加崩盘
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
prices[-3:] = prices[-4] * np.array([0.96, 0.93, 0.90]) # 连续大跌
|
||||
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, crash_filter=True)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 崩盘后动量得分应该被清零
|
||||
assert values.iloc[-1] == 0.0
|
||||
|
||||
def test_simple_momentum(self):
|
||||
"""测试简单动量(无加权,无崩盘过滤)"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = MomentumFactor(n_days=25, weighted=False, crash_filter=False)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 简单动量应该是N日涨幅(无崩盘过滤时)
|
||||
expected = data['close'].pct_change(25)
|
||||
# 验证前25个值都是NaN
|
||||
assert values.iloc[:25].isna().all()
|
||||
# 验证后续值大致正确
|
||||
assert len(values) == len(expected)
|
||||
|
||||
|
||||
class TestTrendFactor:
|
||||
"""测试趋势因子"""
|
||||
|
||||
def test_ma_cross(self):
|
||||
"""测试MA交叉趋势"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成上升趋势
|
||||
prices = 100 + np.arange(100) * 0.5
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = TrendFactor(method='ma_cross', fast=5, slow=20)
|
||||
values = factor.compute(data)
|
||||
|
||||
# 上升趋势应该有正的趋势强度
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
def test_macd(self):
|
||||
"""测试MACD趋势"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
prices = 100 + np.random.randn(100).cumsum()
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = TrendFactor(method='macd')
|
||||
values = factor.compute(data)
|
||||
|
||||
# 检查计算结果
|
||||
assert len(values) == len(data)
|
||||
assert not values.iloc[:26].isna().all() # MACD应该有值
|
||||
|
||||
|
||||
class TestReversalFactor:
|
||||
"""测试反转因子"""
|
||||
|
||||
def test_rsi_reversal(self):
|
||||
"""测试RSI反转信号"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成超买数据(持续上涨)
|
||||
prices = 100 + np.arange(100) * 1.0
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = ReversalFactor(method='rsi', period=14, overbought=70)
|
||||
values = factor.compute(data)
|
||||
|
||||
# RSI超过70应该产生负值(反转向下信号)
|
||||
assert values.iloc[-1] < 0
|
||||
|
||||
def test_rsi_oversold(self):
|
||||
"""测试RSI超卖信号"""
|
||||
dates = pd.date_range('2020-01-01', periods=100)
|
||||
|
||||
# 生成超卖数据(持续下跌)
|
||||
prices = 100 - np.arange(100) * 1.0
|
||||
data = pd.DataFrame({'close': prices}, index=dates)
|
||||
|
||||
factor = ReversalFactor(method='rsi', period=14, oversold=30)
|
||||
values = factor.compute(data)
|
||||
|
||||
# RSI低于30应该产生正值(反转向上信号)
|
||||
assert values.iloc[-1] > 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user