Files
factorhack/factor_mining/operators.py
2025-11-09 20:19:08 +08:00

654 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
算子系统:基础数学算子和技术指标算子的注册与管理
支持算子的注册、查询、反射调用
"""
import numpy as np
import pandas as pd
from typing import Dict, Callable, List, Optional, Any
from abc import ABC, abstractmethod
import inspect
import talib
class Operator(ABC):
    """Base class for operators: a named callable plus a description.

    Calling an ``Operator`` forwards every argument to the wrapped function;
    the function's signature is captured once, at construction time.
    """

    def __init__(self, name: str, func: Callable, description: str = ""):
        """
        Parameters:
        -----------
        name : str
            Operator name (its unique identifier within a registry).
        func : Callable
            The function executed when the operator is called.
        description : str
            Human-readable description of the operator.
        """
        self.name = name
        self.description = description
        self.func = func
        # Computed eagerly: per the inspect docs, signature() raises for
        # callables it cannot introspect, so such callables fail here.
        self._signature = inspect.signature(func)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function with the given arguments."""
        target = self.func
        return target(*args, **kwargs)

    def get_signature(self):
        """Return the ``inspect.Signature`` recorded at construction."""
        return self._signature

    def __repr__(self):
        return f"Operator(name='{self.name}', description='{self.description}')"
class OperatorRegistry:
    """Registry mapping unique operator names to ``Operator`` instances."""

    def __init__(self):
        # A plain dict keeps insertion order, so list_all() follows registration order.
        self._operators: Dict[str, Operator] = dict()

    def register(self, operator: Operator):
        """Add ``operator`` under its ``name``.

        Raises:
            ValueError: if an operator with the same name is already registered.
        """
        existing = self._operators
        if operator.name in existing:
            raise ValueError(f"算子 '{operator.name}' 已存在")
        existing[operator.name] = operator

    def register_function(self, name: str, func: Callable, description: str = ""):
        """Wrap ``func`` in an ``Operator`` and register it (same duplicate check)."""
        self.register(Operator(name, func, description))

    def get(self, name: str) -> Optional[Operator]:
        """Return the operator called ``name``, or None when it is not registered."""
        try:
            return self._operators[name]
        except KeyError:
            return None

    def has(self, name: str) -> bool:
        """Return True if an operator called ``name`` is registered."""
        return name in self._operators

    def list_all(self) -> List[str]:
        """Return every registered name, in registration order."""
        return [*self._operators]

    def get_all(self) -> Dict[str, Operator]:
        """Return a shallow copy of the name -> operator mapping."""
        return dict(self._operators)
# Global operator registry for this module: the register_operator decorator
# and the module-level registration loops all add their operators to it.
_registry = OperatorRegistry()
def register_operator(name: str, description: str = ""):
    """Decorator factory that registers the decorated function as an operator.

    Usage: ``@register_operator("add", "...")`` above a function definition.
    The function is added to the global registry under ``name`` (a duplicate
    name makes the registry raise ValueError). The function itself is
    returned unchanged, so it stays directly callable.
    """

    def _register(fn: Callable):
        _registry.register_function(name, fn, description)
        return fn

    return _register
def get_operator(name: str) -> Optional[Operator]:
    """Return the globally registered operator called ``name``, or None."""
    found = _registry.get(name)
    return found
def get_registry() -> OperatorRegistry:
    """Return the module-level registry that all built-in operators are added to."""
    registry = _registry
    return registry
# Window / period values swept when registering parameterised operators
# (one operator per value): 10 through 99 inclusive, i.e. 90 values.
PERIOD_RANGE = range(10, 100)  # 10 to 99
# ==================== 基础数学算子 ====================
@register_operator("add", "加法: x + y")
def _add(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Element-wise sum ``x + y`` (NumPy broadcasting applies to arrays)."""
    total = x + y
    return total
@register_operator("sub", "减法: x - y")
def _sub(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Element-wise difference ``x - y``."""
    difference = x - y
    return difference
@register_operator("mul", "乘法: x * y")
def _mul(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Element-wise product ``x * y``."""
    product = x * y
    return product
@register_operator("div", "除法: x / y (安全除法)")
def _div(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Safe division ``x / y``.

    Divisors with ``|y| < 1e-12`` are replaced by NaN, so those positions
    yield NaN instead of inf or huge values.
    """
    near_zero = np.abs(y) < 1e-12
    safe_y = np.where(near_zero, np.nan, y)
    return x / safe_y
@register_operator("neg", "取负: -x")
def _neg(x: np.ndarray) -> np.ndarray:
    """Element-wise negation via ``np.negative``."""
    flipped = np.negative(x)
    return flipped
@register_operator("abs", "绝对值: |x|")
def _abs(x: np.ndarray) -> np.ndarray:
    """Element-wise absolute value."""
    return np.absolute(x)
@register_operator("log", "对数: log(|x|)")
def _log(x: np.ndarray) -> np.ndarray:
    """Natural log of ``|x|``.

    The magnitude is floored at 1e-12, so zero maps to ``log(1e-12)``
    rather than ``-inf``.
    """
    magnitude = np.abs(x)
    floored = np.clip(magnitude, 1e-12, None)
    return np.log(floored)
@register_operator("sqrt", "平方根: sqrt(x)")
def _sqrt(x: np.ndarray) -> np.ndarray:
    """Square root with negative inputs clipped to 0 (so they give 0, not NaN)."""
    non_negative = np.clip(x, 0.0, None)
    return np.sqrt(non_negative)
@register_operator("pow", "幂运算: x^y (限制范围)")
def _pow(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Bounded power ``x ** y``.

    The exponent is clipped to [-3, 3] and the base to [-1e6, 1e6]. Any
    result that is still non-finite becomes NaN, for example a negative base
    with a fractional exponent, or ``0 ** negative``.

    Inputs are promoted to float first. This makes integer arrays work: NumPy
    rejects negative integer powers, and an int array cannot store NaN. It also
    makes scalar inputs work: they come back as 0-d arrays instead of failing
    on item assignment.
    """
    base = np.clip(np.asarray(x, dtype=float), -1e6, 1e6)
    exponent = np.clip(np.asarray(y, dtype=float), -3.0, 3.0)
    # "divide" covers 0 ** negative; "invalid" covers negative ** fractional.
    with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
        out = np.power(base, exponent)
    return np.where(np.isfinite(out), out, np.nan)
# ==================== 时间序列算子 ====================
def _rolling_mean(x: np.ndarray, window: int) -> np.ndarray:
    """Trailing mean over ``window`` points.

    A window needs at least ``max(2, window // 2)`` non-NaN values; otherwise
    the output at that position is NaN.
    """
    min_obs = max(2, window // 2)
    windowed = pd.Series(x).rolling(window, min_periods=min_obs)
    return windowed.mean().to_numpy()
def _rolling_std(x: np.ndarray, window: int) -> np.ndarray:
    """Trailing sample standard deviation over ``window`` points.

    A window needs at least ``max(2, window // 2)`` non-NaN values.
    """
    min_obs = max(2, window // 2)
    windowed = pd.Series(x).rolling(window, min_periods=min_obs)
    return windowed.std().to_numpy()
def _ts_delta(x: np.ndarray, period: int) -> np.ndarray:
    """Change over ``period`` steps: ``x[t] - x[t - period]``.

    The first ``period`` positions are NaN.
    """
    series = pd.Series(x)
    delta = series.diff(periods=period)
    return delta.to_numpy()
def _ts_rank(x: np.ndarray, window: int) -> np.ndarray:
    """Percentile rank of the latest value within its trailing window.

    Ties use pandas' default average ranking. A window needs at least
    ``max(2, window // 2)`` non-NaN values.
    """

    def _last_pct_rank(block):
        ranks = pd.Series(block).rank(pct=True)
        return ranks.iloc[-1]

    rolling = pd.Series(x).rolling(window, min_periods=max(2, window // 2))
    return rolling.apply(_last_pct_rank, raw=True).to_numpy()
def _delay(x: np.ndarray, period: int) -> np.ndarray:
    """Value from ``period`` steps earlier (``x[t - period]``); leading positions are NaN."""
    lagged = pd.Series(x).shift(periods=period)
    return lagged.to_numpy()
def _pct_change(x: np.ndarray, period: int = 1) -> np.ndarray:
    """Percentage change over ``period`` steps: ``x[t] / x[t - period] - 1``.

    Missing values are not forward-filled (``fill_method=None``).
    """
    changes = pd.Series(x).pct_change(periods=period, fill_method=None)
    return changes.to_numpy()
# Single-argument percentage-change operator with a fixed 1-step period.
@register_operator("pct", "百分比变化: PCT(x, 1)")
def _pct(x: np.ndarray) -> np.ndarray:
    """One-step percentage change: ``x[t] / x[t - 1] - 1``."""
    return _pct_change(x, period=1)
# Register time-series operators for every window in PERIOD_RANGE:
# sma<w>, std<w>, rank<w>, delta<w>, delay<w> (e.g. sma10 ... sma99).
# Each lambda freezes the current window through the default argument
# ``w=w``; without it, every lambda would see the loop's last value. Because
# the window is a default argument, a caller can still override it, e.g.
# ``sma10(x, 20)`` computes a 20-point mean.
for w in PERIOD_RANGE:
    _registry.register_function(
        f"sma{w}", lambda x, w=w: _rolling_mean(x, w), f"简单移动平均: SMA(x, {w})"
    )
    _registry.register_function(
        f"std{w}", lambda x, w=w: _rolling_std(x, w), f"滚动标准差: STD(x, {w})"
    )
    _registry.register_function(
        f"rank{w}", lambda x, w=w: _ts_rank(x, w), f"滚动排名: RANK(x, {w})"
    )
    _registry.register_function(
        f"delta{w}", lambda x, w=w: _ts_delta(x, w), f"差分: DELTA(x, {w})"
    )
    _registry.register_function(
        f"delay{w}", lambda x, w=w: _delay(x, w), f"延迟: DELAY(x, {w})"
    )
# ==================== 技术指标算子含自定义与ta-lib====================
def _try_float(x):
    """Convert ``x`` to float when possible; otherwise return it unchanged."""
    try:
        converted = float(x)
    except Exception:
        return x
    return converted
def _convert_input(v):
    """Unwrap a pandas Series into its underlying ``.values``; pass anything else through."""
    return v.values if isinstance(v, pd.Series) else v
# Register TA-Lib indicators.
# TA-Lib exposes its indicator functions under upper-case names (e.g. SMA,
# MACD); collect every attribute of the module whose name is upper-case and
# which is callable.
talib_func_list = [
    attr
    for attr in dir(talib)
    if attr.isupper() and callable(getattr(talib, attr))
]
# Substrings identifying period-like talib parameters, one operator version
# per period value is generated for the chosen parameter. A parameter matches
# when its lowercased name contains any entry (so "period" also matches names
# such as "fastperiod"). Listed in priority order: the main period parameter
# is matched first.
PERIOD_PARAM_NAMES = [
    "timeperiod",  # most common parameter name
    "period",  # generic period parameter
    "optintimeperiod",  # TA-Lib internal parameter name
]
# Functions with several period-like parameters, handled explicitly instead of
# by auto-detection (which could pick the wrong parameter).
# Value: (parameter swept over PERIOD_RANGE, other parameters - documentation
# only; they are not passed, so talib's defaults apply). An empty main
# parameter means a single default version is registered.
# NOTE(review): BBANDS lists nbdevup/nbdevdn, which look like band-width
# multipliers rather than periods - confirm the intent.
MULTI_PERIOD_FUNCTIONS = {
    # function name: (main period parameter, secondary period parameters - docs only)
    "MACD": ("fastperiod", ["slowperiod", "signalperiod"]),
    "MACDEXT": ("fastperiod", ["slowperiod", "signalperiod"]),
    "MACDFIX": ("signalperiod", []),
    "STOCH": ("fastk_period", ["slowk_period", "slowd_period"]),
    "STOCHF": ("fastk_period", ["fastd_period"]),
    "STOCHRSI": ("timeperiod", ["fastk_period", "fastd_period"]),
    "BBANDS": ("timeperiod", ["nbdevup", "nbdevdn"]),
    "APO": ("fastperiod", ["slowperiod"]),
    "PPO": ("fastperiod", ["slowperiod"]),
    "ULTOSC": ("timeperiod1", ["timeperiod2", "timeperiod3"]),
    "BOP": ("", []),  # no period parameter: register the default version only
}
def build_talib_wrapper(func, func_name, fixed_params=None):
    """Wrap a talib function, optionally pinning some keyword parameters.

    Parameters
    ----------
    func : callable
        The talib function to call.
    func_name : str
        Its talib name; the wrapper is named ``talib_<name lowercased>``.
    fixed_params : dict, optional
        Keyword arguments always passed to ``func``. Keywords supplied at
        call time override them.

    The returned wrapper unwraps pandas Series arguments (positional and
    keyword) to arrays before calling ``func``. It returns an ndarray, or a
    tuple of ndarrays (None items kept as None) for multi-output functions
    such as MACD.
    """
    preset = fixed_params or {}

    def _talib_wrap(*args, **kwargs):
        call_args = [_convert_input(arg) for arg in args]
        call_kwargs = {
            key: _convert_input(value)
            for key, value in {**preset, **kwargs}.items()
        }
        result = func(*call_args, **call_kwargs)
        if not isinstance(result, tuple):
            return np.asarray(result)
        # Multi-output indicators: keep the tuple shape.
        return tuple(None if item is None else np.asarray(item) for item in result)

    _talib_wrap.__name__ = "talib_" + func_name.lower()
    return _talib_wrap
def _detect_main_period_param(params):
    """Return the name of the talib parameter to sweep over PERIOD_RANGE, or None.

    Used for functions that are not listed in MULTI_PERIOD_FUNCTIONS.
    """
    # Only parameters with a default value are tunable settings. Required
    # parameters are input series and must never be replaced by a scalar.
    # Example: TA-Lib's MAVP(real, periods, minperiod, maxperiod, ...) has an
    # input array named "periods", which contains "period".
    period_params = [
        param_name
        for param_name, param in params.items()
        if param.default is not inspect.Parameter.empty
        and any(keyword in param_name.lower() for keyword in PERIOD_PARAM_NAMES)
    ]
    if not period_params:
        return None
    # Preference: the first candidate containing "timeperiod", then "period",
    # then "optintimeperiod".
    for preferred in PERIOD_PARAM_NAMES:
        for param_name in period_params:
            if preferred in param_name.lower():
                return param_name
    return period_params[0]


for func_name in talib_func_list:
    func = getattr(talib, func_name)
    sig = inspect.signature(func)
    params = sig.parameters
    if func_name in MULTI_PERIOD_FUNCTIONS:
        # Explicit configuration. Fall back to a single default version when no
        # parameter is configured (e.g. BOP) or this talib build lacks it.
        main_period_param, _ = MULTI_PERIOD_FUNCTIONS[func_name]
        if not main_period_param or main_period_param not in params:
            main_period_param = None
    else:
        main_period_param = _detect_main_period_param(params)

    if main_period_param:
        # One operator per period value: talib_<name>_<period>.
        for period_value in PERIOD_RANGE:
            fixed_params = {main_period_param: period_value}
            wrapper = build_talib_wrapper(func, func_name, fixed_params)
            op_name = f"talib_{func_name.lower()}_{period_value}"
            desc = f"ta-lib: {func_name}({main_period_param}={period_value})"
            _registry.register_function(op_name, wrapper, desc)
    else:
        # No period parameter: register the function with talib's defaults.
        wrapper = build_talib_wrapper(func, func_name)
        op_name = f"talib_{func_name.lower()}"
        desc = f"ta-lib: {func_name}"
        _registry.register_function(op_name, wrapper, desc)
# ==================== 自定义常见技术指标 ====================
def _ewm_forward(x: np.ndarray, alpha: float) -> np.ndarray:
    """Exponentially weighted moving average, computed forward in time.

    Recursion: ``y[0] = x[0]``, ``y[t] = alpha * x[t] + (1 - alpha) * y[t-1]``,
    i.e. pandas ``ewm(adjust=False)``. The output is always float64, so
    integer inputs are not truncated.

    NaN inputs are skipped (``ignore_na=True``) rather than contaminating every
    later value:
    - leading NaNs stay NaN;
    - the recursion starts at the first valid observation;
    - a NaN inside the series repeats the last smoothed value.

    Parameters
    ----------
    x : np.ndarray
        Input series.
    alpha : float
        Smoothing factor. pandas requires ``0 < alpha <= 1`` and raises
        ValueError otherwise.

    Returns
    -------
    np.ndarray
        Smoothed series (float64), same length as ``x``.
    """
    values = np.asarray(x, dtype=float)
    smoothed = pd.Series(values).ewm(alpha=alpha, adjust=False, ignore_na=True).mean()
    return smoothed.to_numpy()
def _rsv(x: np.ndarray, window: int) -> np.ndarray:
    """Raw stochastic value: (x - rolling min) / (rolling max - rolling min).

    The window uses ``closed="both"``, so it spans positions t-window..t.
    It needs at least ``max(2, window // 2)`` non-NaN values. A near-zero
    range (< 1e-12) gives NaN.
    """
    series = pd.Series(x)
    win = series.rolling(window, min_periods=max(2, window // 2), closed="both")
    low = win.min()
    high = win.max()
    spread = high - low
    spread = np.where(np.abs(spread) < 1e-12, np.nan, spread)
    return ((series - low) / spread).to_numpy()
def _bband(x: np.ndarray, window: int) -> np.ndarray:
    """Bollinger-style z-score: (x - rolling mean) / rolling sample std.

    The window uses ``closed="both"`` (positions t-window..t) and needs at
    least ``max(2, window // 2)`` non-NaN values. A near-zero std (< 1e-12)
    gives NaN.
    """
    series = pd.Series(x)
    win = series.rolling(window, min_periods=max(2, window // 2), closed="both")
    centre = win.mean()
    scale = win.std()
    scale = np.where(np.abs(scale) < 1e-12, np.nan, scale)
    return ((series - centre) / scale).to_numpy()
def _rsi(x: np.ndarray, window: int, threshold: float = 0.00001) -> np.ndarray:
    """Share of upward movement in total movement over a trailing window.

    The result is ``sum(gains) / (sum(gains) + |sum(losses)|)`` of the first
    difference of ``x``, so it lies in [0, 1]. Changes whose magnitude is
    ``<= threshold`` are ignored. A window with no qualifying movement
    yields NaN.

    The window uses ``closed="both"`` and needs at least
    ``max(2, window // 2)`` non-NaN differences.
    """
    changes = pd.Series(x).diff()

    def _up_share(window_changes):
        gains = window_changes[window_changes > threshold].sum()
        losses = abs(window_changes[window_changes < -threshold].sum())
        movement = gains + losses
        return np.nan if movement < 1e-12 else gains / movement

    windows = changes.rolling(window, min_periods=max(2, window // 2), closed="both")
    return windows.apply(_up_share, raw=False).to_numpy()
def _rolling_skew(x: np.ndarray, window: int) -> np.ndarray:
    """Rolling sample skewness.

    The window uses ``closed="both"`` (positions t-window..t) and needs at
    least ``max(2, window // 2)`` non-NaN values.
    """
    win = pd.Series(x).rolling(window, min_periods=max(2, window // 2), closed="both")
    skewness = win.skew()
    return skewness.to_numpy()
def _rolling_kurtosis(x: np.ndarray, window: int) -> np.ndarray:
    """Rolling sample excess kurtosis (pandas ``kurt``).

    The window uses ``closed="both"`` (positions t-window..t) and needs at
    least ``max(2, window // 2)`` non-NaN values.
    """
    win = pd.Series(x).rolling(window, min_periods=max(2, window // 2), closed="both")
    kurtosis = win.kurt()
    return kurtosis.to_numpy()
def _rolling_linear(x: np.ndarray, window: int) -> np.ndarray:
    """Rolling least-squares slope of ``x`` against time, in units of x per step.

    The window uses ``closed="both"`` (positions t-window..t) and needs at
    least ``max(2, window // 2)`` non-NaN values.

    NaN observations are dropped, but the remaining points keep their true
    positions in the window, so a gap does not make distant observations look
    adjacent. A window whose fit fails yields NaN.
    """

    def _linear_slope(values: np.ndarray) -> float:
        positions = np.flatnonzero(~np.isnan(values))
        if len(positions) < 2:
            return np.nan
        try:
            return np.polyfit(positions, values[positions], 1)[0]
        except (np.linalg.LinAlgError, ValueError):
            # Degenerate fit (e.g. non-finite values): report NaN for this window.
            return np.nan

    rolling = pd.Series(x).rolling(window, min_periods=max(2, window // 2), closed="both")
    return rolling.apply(_linear_slope, raw=True).to_numpy()
def _rolling_autocorr(x: np.ndarray, window: int, lag: int = 1) -> np.ndarray:
    """Rolling autocorrelation at ``lag`` (pandas ``Series.autocorr``).

    The window uses ``closed="both"`` (positions t-window..t) and needs at
    least ``max(2, window // 2)`` non-NaN values. A window with fewer than
    two non-NaN values yields NaN.
    """

    def _window_autocorr(window_values):
        if len(window_values.dropna()) < 2:
            return np.nan
        return window_values.autocorr(lag=lag)

    win = pd.Series(x).rolling(window, min_periods=max(2, window // 2), closed="both")
    return win.apply(_window_autocorr, raw=False).to_numpy()
def _rolling_max(x: np.ndarray, window: int) -> np.ndarray:
    """Rolling maximum.

    The window uses ``closed="both"`` (positions t-window..t) and needs at
    least ``max(2, window // 2)`` non-NaN values.
    """
    win = pd.Series(x).rolling(window, min_periods=max(2, window // 2), closed="both")
    highest = win.max()
    return highest.to_numpy()
def _rolling_min(x: np.ndarray, window: int) -> np.ndarray:
    """Rolling minimum.

    The window uses ``closed="both"`` (positions t-window..t) and needs at
    least ``max(2, window // 2)`` non-NaN values.
    """
    win = pd.Series(x).rolling(window, min_periods=max(2, window // 2), closed="both")
    lowest = win.min()
    return lowest.to_numpy()
def _huanbi(x: np.ndarray, window: int) -> np.ndarray:
    """Ratio of the current value to the first value of its trailing window.

    The window uses ``closed="both"`` (positions t-window..t) and needs at
    least ``max(2, window // 2)`` non-NaN values. A window shorter than 2, or
    one whose first value is near zero (< 1e-12 in magnitude), yields NaN.
    """

    def _ratio_to_first(window_values):
        if len(window_values) < 2:
            return np.nan
        first = window_values.iloc[0]
        if abs(first) < 1e-12:
            return np.nan
        return window_values.iloc[-1] / first

    win = pd.Series(x).rolling(window, min_periods=max(2, window // 2), closed="both")
    return win.apply(_ratio_to_first, raw=False).to_numpy()
# Register technical-indicator operators for every window in PERIOD_RANGE
# (ewm<w>, pct<w>, rsv<w>, bband<w>, rsi<w>, skew<w>, kurt<w>, linear<w>,
# autocorr<w>, max<w>, min<w>, huanbi<w>).
# Each lambda freezes the current loop values through default arguments
# (``w=w``, ``a=alpha``); without them, every lambda would see the loop's
# final values. As default arguments they remain overridable by callers.
for w in PERIOD_RANGE:
    # EWM uses the smoothing factor alpha = 2 / (w + 1). ``a=alpha`` freezes
    # this iteration's alpha; ``w=w`` is unused by the body.
    alpha = 2.0 / (w + 1)
    _registry.register_function(
        f"ewm{w}",
        lambda x, w=w, a=alpha: _ewm_forward(x, a),
        f"指数加权移动平均: EWM(x, {w})",
    )
    # Percentage change over w steps
    _registry.register_function(
        f"pct{w}", lambda x, w=w: _pct_change(x, w), f"百分比变化: PCT(x, {w})"
    )
    # RSV: position of x within its rolling min/max range
    _registry.register_function(
        f"rsv{w}", lambda x, w=w: _rsv(x, w), f"相对强弱值: RSV(x, {w})"
    )
    # Bollinger-band z-score
    _registry.register_function(
        f"bband{w}", lambda x, w=w: _bband(x, w), f"布林带指标: BBAND(x, {w})"
    )
    # RSI-style share of upward movement
    _registry.register_function(
        f"rsi{w}", lambda x, w=w: _rsi(x, w), f"相对强弱指标: RSI(x, {w})"
    )
    # Rolling statistics
    _registry.register_function(
        f"skew{w}", lambda x, w=w: _rolling_skew(x, w), f"滚动偏度: SKEW(x, {w})"
    )
    _registry.register_function(
        f"kurt{w}", lambda x, w=w: _rolling_kurtosis(x, w), f"滚动峰度: KURT(x, {w})"
    )
    _registry.register_function(
        f"linear{w}",
        lambda x, w=w: _rolling_linear(x, w),
        f"滚动线性斜率: LINEAR(x, {w})",
    )
    _registry.register_function(
        f"autocorr{w}",
        lambda x, w=w: _rolling_autocorr(x, w),
        f"滚动自相关: AUTOCORR(x, {w})",
    )
    _registry.register_function(
        f"max{w}", lambda x, w=w: _rolling_max(x, w), f"滚动最大值: MAX(x, {w})"
    )
    _registry.register_function(
        f"min{w}", lambda x, w=w: _rolling_min(x, w), f"滚动最小值: MIN(x, {w})"
    )
    # Ratio of the current value to the window's first value
    _registry.register_function(
        f"huanbi{w}", lambda x, w=w: _huanbi(x, w), f"环比: HUANBI(x, {w})"
    )
# ==================== 因子公式解析与计算 ====================
class FactorFormula:
    """Factor formula: a Python expression over named features and registered operators.

    Supports serialization via ``to_dict`` / ``from_dict``.
    """

    def __init__(self, expression: str, feature_names: List[str]):
        """
        Parameters:
        -----------
        expression : str
            Factor expression written with operator names and feature names.
        feature_names : List[str]
            Names of the features the expression uses. The first one defines
            the expected output length in ``compute``.
        """
        self.expression = expression
        self.feature_names = feature_names

    def compute(self, features: Dict[str, np.ndarray]) -> np.ndarray:
        """
        Evaluate the factor expression.

        Parameters:
        -----------
        features : Dict[str, np.ndarray]
            Feature arrays keyed by feature name.

        Returns:
        --------
        np.ndarray
            Factor values, the same length as ``features[self.feature_names[0]]``.
            Scalar results (Python or NumPy scalars, 0-d arrays) are broadcast
            to that length.

        Raises:
        -------
        KeyError
            If one of ``feature_names`` is missing from ``features``.
        RuntimeError
            If evaluation fails or the result length does not match. The
            underlying exception is chained as ``__cause__``.
        """
        # NOTE(security): eval() with a reduced __builtins__ is NOT a sandbox
        # (attribute tricks such as ().__class__ still work). Only evaluate
        # expressions produced by trusted code, never external input.
        namespace = {}
        for name in self.feature_names:
            if name not in features:
                raise KeyError(f"特征 '{name}' 不存在")
            namespace[name] = features[name]
        # Operators are added after features, so an operator name shadows a
        # feature with the same name.
        for op_name, op in _registry.get_all().items():
            namespace[op_name] = op.func
        # numpy and pandas are available to expressions.
        namespace["np"] = np
        namespace["pd"] = pd
        # Restrict the builtins visible to the expression.
        namespace["__builtins__"] = {
            "abs": abs,
            "min": min,
            "max": max,
            "sum": sum,
            "len": len,
        }

        try:
            # One dict serves as globals, so comprehensions and lambdas inside
            # the expression can also resolve operator and feature names.
            result = eval(self.expression, namespace)
            expected_len = len(features[self.feature_names[0]])
            is_scalar = isinstance(result, (int, float, np.number, np.bool_)) or (
                isinstance(result, np.ndarray) and result.ndim == 0
            )
            if is_scalar:
                # Broadcast a scalar to the feature length.
                result = np.full(expected_len, result)
            elif not isinstance(result, np.ndarray):
                result = np.array(result)
            if len(result) != expected_len:
                raise ValueError(
                    f"表达式结果长度 {len(result)} 与特征长度 {expected_len} 不匹配"
                )
            return result
        except Exception as e:
            raise RuntimeError(
                f"计算因子表达式失败: {e}\n表达式: {self.expression}"
            ) from e

    def to_dict(self) -> Dict:
        """Serialize to a dict; the feature list is copied so callers cannot mutate this formula."""
        return {"expression": self.expression, "feature_names": list(self.feature_names)}

    @classmethod
    def from_dict(cls, data: Dict) -> "FactorFormula":
        """Build a formula from a dict with "expression" and "feature_names" keys."""
        return cls(data["expression"], data["feature_names"])

    def __repr__(self):
        return f"FactorFormula(expression='{self.expression}', features={self.feature_names})"