124 lines
2.9 KiB
Python
124 lines
2.9 KiB
Python
"""
|
||
因子挖掘抽象层:支持多种挖掘算法(DEAP、DL、RL等)
|
||
"""
|
||
from abc import ABC, abstractmethod
|
||
from typing import List, Dict, Optional, Any
|
||
import pandas as pd
|
||
from dataclasses import dataclass
|
||
|
||
from factor_mining.operators import FactorFormula
|
||
|
||
|
||
@dataclass
class MiningConfig:
    """Base class for mining configuration.

    Attributes:
    -----------
    ret_horizon : int
        Defaults to 1. Presumably the forward-return horizon; the unit is
        not visible in this file -- TODO confirm against the miners.
    ic_window : int
        Defaults to 30. Presumably the window used for IC computation --
        TODO confirm against the miners.
    ic_method : str
        "spearman" or "pearson"; defaults to "spearman". The value is not
        validated here.
    seed : Optional[int]
        Optional seed; defaults to None.
    """

    ret_horizon: int = 1
    ic_window: int = 30
    ic_method: str = "spearman"  # "spearman" or "pearson"
    seed: Optional[int] = None
|
||
|
||
|
||
class FactorMiner(ABC):
    """Abstract base class for factor miners.

    Subclasses must implement both ``mine`` and ``get_name``; the class
    cannot be instantiated until they do.
    """

    def __init__(self, config: MiningConfig) -> None:
        # Store the configuration on the instance; nothing else is set up here.
        self.config = config

    @abstractmethod
    def mine(
        self,
        data: pd.DataFrame,
        feature_cols: List[str],
        price_col: str = "close",
    ) -> List[FactorFormula]:
        """
        Mine factors.

        Parameters:
        -----------
        data : DataFrame
            Input data.
        feature_cols : List[str]
            Names of the feature columns.
        price_col : str
            Name of the price column.

        Returns:
        --------
        List[FactorFormula]: The mined factor formulas.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return the name of this miner."""
|
||
|
||
|
||
class MiningPipeline:
    """Manager for the mining workflow.

    Keeps a name -> miner registry and forwards mining requests to the
    miner selected by name.
    """

    def __init__(self) -> None:
        # Registered miners, keyed by the value their get_name() returned
        # at registration time.
        self.miners: Dict[str, FactorMiner] = {}

    def register_miner(self, miner: FactorMiner) -> None:
        """
        Register a miner under the name returned by ``miner.get_name()``.

        Raises:
        -------
        ValueError: If a miner with the same name is already registered.
        """
        name = miner.get_name()
        if name in self.miners:
            # Refuse to overwrite silently; duplicates are an error.
            raise ValueError(f"挖掘器 '{name}' 已存在")
        self.miners[name] = miner

    def get_miner(self, name: str) -> Optional[FactorMiner]:
        """Return the miner registered under ``name``, or None if absent."""
        return self.miners.get(name)

    def list_miners(self) -> List[str]:
        """Return the names of all registered miners."""
        return list(self.miners.keys())

    def mine(
        self,
        miner_name: str,
        data: pd.DataFrame,
        feature_cols: List[str],
        price_col: str = "close",
    ) -> List[FactorFormula]:
        """
        Run mining with the specified miner.

        Parameters:
        -----------
        miner_name : str
            Name of the miner to use.
        data : DataFrame
            Input data.
        feature_cols : List[str]
            Names of the feature columns.
        price_col : str
            Name of the price column.

        Returns:
        --------
        List[FactorFormula]: The mined factor formulas.

        Raises:
        -------
        ValueError: If no miner is registered under ``miner_name``.
        """
        miner = self.get_miner(miner_name)
        if miner is None:
            raise ValueError(f"挖掘器 '{miner_name}' 不存在")

        # Arguments are forwarded positionally, unchanged.
        return miner.mine(data, feature_cols, price_col)
|
||
|
||
|
||
# Global mining pipeline manager (module-level instance created at import time).
_pipeline = MiningPipeline()
|
||
|
||
|
||
def register_miner(miner: FactorMiner) -> None:
    """Register a miner with the global pipeline manager.

    Delegates to ``_pipeline.register_miner``; any exception raised there
    propagates to the caller.
    """
    _pipeline.register_miner(miner)
|
||
|
||
|
||
def get_pipeline() -> MiningPipeline:
    """Return the global pipeline manager (the same object on every call)."""
    return _pipeline
|
||
|