160 lines
4.1 KiB
Python
160 lines
4.1 KiB
Python
"""
|
|
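算子系统:基础数学算子和技术指标算子的注册与管理
支持算子的注册、查询、反射调用

Operator system: registration and management of basic math operators and
technical-indicator operators. Supports registering, looking up, and
reflectively invoking operators.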
|
|
"""
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from typing import Dict, Callable, List, Optional, Any
|
|
from abc import ABC, abstractmethod
|
|
import inspect
|
|
|
|
import talib
|
|
from factor_mining.time_series_op import register_time_series_operator
|
|
|
|
|
|
class Operator(ABC):
    """A named, callable wrapper around a plain function.

    Calls are forwarded to the wrapped function unchanged. The function's
    signature is looked up once at construction time and cached so callers
    can introspect the expected arguments.
    """

    def __init__(self, name: str, func: Callable, description: str = ""):
        """
        Parameters
        ----------
        name : str
            Operator name (its unique identifier within a registry).
        func : Callable
            The function implementing the operator.
        description : str
            Human-readable description of the operator.

        Raises
        ------
        TypeError
            If ``func`` is not callable.
        """
        # Checked explicitly so non-callables are still rejected even though
        # signature lookup failures are tolerated below.
        if not callable(func):
            raise TypeError(f"func must be callable, got {type(func).__name__}")
        self.name = name
        self.func = func
        self.description = description
        # Some callables (e.g. the builtin ``max``) expose no introspectable
        # signature and inspect.signature raises for them. Registration should
        # not fail merely because this metadata is unavailable, so store None.
        try:
            self._signature: Optional[inspect.Signature] = inspect.signature(func)
        except (ValueError, TypeError):
            self._signature = None

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped function with the given arguments."""
        return self.func(*args, **kwargs)

    def get_signature(self) -> Optional[inspect.Signature]:
        """Return the cached signature of the wrapped function.

        Returns ``None`` when no signature could be determined for it.
        """
        return self._signature

    def __repr__(self):
        return f"Operator(name='{self.name}', description='{self.description}')"
|
|
|
|
|
|
class OperatorRegistry:
    """A name-keyed collection of Operator objects.

    Each name may be registered only once; a second registration under the
    same name raises ValueError.
    """

    def __init__(self):
        # name -> Operator; dict keeps insertion order, so listings follow
        # the order in which operators were registered.
        self._operators: Dict[str, Operator] = {}

    def register(self, operator: Operator):
        """Store ``operator`` under ``operator.name``.

        Raises
        ------
        ValueError
            If an operator with that name is already registered.
        """
        if operator.name not in self._operators:
            self._operators[operator.name] = operator
            return
        raise ValueError(f"算子 '{operator.name}' 已存在")

    def register_function(self, name: str, func: Callable, description: str = ""):
        """Wrap ``func`` in an Operator and register it under ``name``."""
        self.register(Operator(name, func, description))

    def get(self, name: str) -> Optional[Operator]:
        """Return the operator registered as ``name``, or None if absent."""
        return self._operators.get(name, None)

    def has(self, name: str) -> bool:
        """Report whether an operator named ``name`` is registered."""
        found = name in self._operators
        return found

    def list_all(self) -> List[str]:
        """Return the registered operator names in registration order."""
        return [*self._operators]

    def get_all(self) -> Dict[str, Operator]:
        """Return a shallow copy of the name -> Operator mapping."""
        return dict(self._operators)
|
|
|
|
|
|
# Global operator registry. The module-level helpers register_operator,
# get_operator and get_registry all read and write this single instance.
_registry: OperatorRegistry = OperatorRegistry()
|
|
|
|
|
|
def register_operator(name: str, description: str = ""):
    """Build a decorator that registers the decorated function as an operator.

    The decorated function is added to the global registry under ``name``
    with ``description``. It is then returned unchanged, so it stays directly
    callable. A ValueError from the registry propagates if ``name`` is
    already taken.
    """

    def _register(fn: Callable):
        _registry.register_function(name, fn, description)
        return fn

    return _register
|
|
|
|
|
|
def get_operator(name: str) -> Optional[Operator]:
    """Look up ``name`` in the global registry; None if it is not registered."""
    found = _registry.get(name)
    return found
|
|
|
|
|
|
def get_registry() -> OperatorRegistry:
    """Return the shared global OperatorRegistry instance itself (not a copy)."""
    shared = _registry
    return shared
|
|
|
|
|
|
# ==================== 基础数学算子 (basic math operators) ====================
|
|
|
|
|
|
@register_operator("add", "加法: x + y")
def _add(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Element-wise sum, using the operands' own ``+``."""
    total = x + y
    return total
|
|
|
|
|
|
@register_operator("sub", "减法: x - y")
def _sub(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Element-wise difference, using the operands' own ``-``."""
    difference = x - y
    return difference
|
|
|
|
|
|
@register_operator("mul", "乘法: x * y")
def _mul(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Element-wise product, using the operands' own ``*``."""
    product = x * y
    return product
|
|
|
|
|
|
@register_operator("div", "除法: x / y (安全除法)")
def _div(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Safe division ``x / y``.

    Divisors with magnitude below 1e-12 are replaced by NaN, so those
    positions yield NaN instead of inf or a huge value.
    """
    near_zero = np.abs(y) < 1e-12
    safe_y = np.where(near_zero, np.nan, y)
    return x / safe_y
|
|
|
|
|
|
@register_operator("neg", "取负: -x")
def _neg(x: np.ndarray) -> np.ndarray:
    """Element-wise negation via ``np.negative``."""
    negated = np.negative(x)
    return negated
|
|
|
|
|
|
@register_operator("abs", "绝对值: |x|")
def _abs(x: np.ndarray) -> np.ndarray:
    """Element-wise absolute value (``np.absolute``, the same ufunc as ``np.abs``)."""
    return np.absolute(x)
|
|
|
|
|
|
@register_operator("log", "对数: log(|x|)")
def _log(x: np.ndarray) -> np.ndarray:
    """Natural log of ``|x|``.

    Magnitudes are floored at 1e-12, so zeros map to log(1e-12) (about
    -27.63) instead of -inf.
    """
    magnitude = np.abs(x)
    floored = np.clip(magnitude, 1e-12, None)
    return np.log(floored)
|
|
|
|
|
|
@register_operator("sqrt", "平方根: sqrt(x)")
def _sqrt(x: np.ndarray) -> np.ndarray:
    """Square root with negatives clipped to 0 first (so they yield 0, not NaN)."""
    non_negative = np.clip(x, 0.0, None)
    return np.sqrt(non_negative)
|
|
|
|
|
|
@register_operator("pow", "幂运算: x^y (限制范围)")
def _pow(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Bounded power ``x ** y``.

    The exponent is clipped to [-3, 3] and the base to [-1e6, 1e6].
    Non-finite results are replaced by NaN. These include overflow,
    ``0 ** negative`` and a negative base with a fractional exponent. The
    corresponding floating-point warnings are suppressed.

    Scalar inputs are supported: they return a NumPy scalar (NaN when the
    result is non-finite) instead of failing on item assignment.
    """
    y_clip = np.clip(y, -3.0, 3.0)
    # divide="ignore" covers 0 ** negative (-> inf); without it NumPy emits a
    # divide-by-zero RuntimeWarning even though the value is discarded below.
    with np.errstate(over="ignore", invalid="ignore", divide="ignore"):
        out = np.power(np.clip(x, -1e6, 1e6), y_clip)
    if np.ndim(out) == 0:
        # Scalar inputs yield a NumPy scalar, which does not support item
        # assignment; return a NaN of the same scalar type instead.
        return out if np.isfinite(out) else out.dtype.type(np.nan)
    out[~np.isfinite(out)] = np.nan
    return out
|
|
|
|
|
|
# ==================== 时间序列算子 (time-series operators) ====================
# Runs at import time. It hands the global registry to
# factor_mining.time_series_op.register_time_series_operator, which
# presumably registers the time-series operators into it — TODO confirm in
# that module. If it registers via OperatorRegistry.register, any name that
# clashes with the math operators above raises ValueError on import.
register_time_series_operator(_registry)
|