factor miner
This commit is contained in:
123
factor_mining/mining.py
Normal file
123
factor_mining/mining.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Factor-mining abstraction layer: supports multiple mining algorithms (DEAP, DL, RL, etc.)
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Optional, Any
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
|
||||
from factor_mining.operators import FactorFormula
|
||||
|
||||
|
||||
@dataclass
class MiningConfig:
    """Base configuration for factor miners.

    Attributes:
    -----------
    ret_horizon : int
        Defaults to 1. Presumably the forward-return horizon used when
        scoring factors -- TODO confirm against the concrete miners.
    ic_window : int
        Defaults to 30. Presumably the window length used for IC
        computation -- TODO confirm against the concrete miners.
    ic_method : str
        Correlation method, either "spearman" or "pearson".
    seed : Optional[int]
        Optional random seed; ``None`` leaves it unset.

    Raises:
    -------
    ValueError
        If ``ic_method`` is not one of the supported methods.
    """
    ret_horizon: int = 1
    ic_window: int = 30
    ic_method: str = "spearman"  # "spearman" or "pearson"
    seed: Optional[int] = None

    def __post_init__(self) -> None:
        # Reject unsupported methods at construction time so a typo in the
        # configuration surfaces immediately rather than inside a mining run.
        supported = ("spearman", "pearson")
        if self.ic_method not in supported:
            raise ValueError(
                f"ic_method must be one of {supported}, got {self.ic_method!r}"
            )
|
||||
|
||||
|
||||
class FactorMiner(ABC):
    """Abstract base class for factor miners.

    Concrete miners must implement :meth:`mine` and :meth:`get_name`.
    The configuration passed to the constructor is stored on
    ``self.config``.
    """

    def __init__(self, config: MiningConfig):
        self.config = config

    @abstractmethod
    def mine(
        self,
        data: pd.DataFrame,
        feature_cols: List[str],
        price_col: str = "close",
    ) -> List[FactorFormula]:
        """Mine factors.

        Parameters:
        -----------
        data : DataFrame
            Input data.
        feature_cols : List[str]
            Names of the feature columns.
        price_col : str
            Name of the price column.

        Returns:
        --------
        List[FactorFormula]: the mined factor formulas.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return the name of this miner."""
|
||||
|
||||
|
||||
class MiningPipeline:
    """Registry that keeps miners by name and dispatches mining calls to them."""

    def __init__(self) -> None:
        # Maps the value returned by ``miner.get_name()`` to the miner itself.
        self.miners: Dict[str, FactorMiner] = {}

    def register_miner(self, miner: FactorMiner) -> None:
        """Register a miner under the name returned by ``miner.get_name()``.

        Raises:
        -------
        ValueError
            If a miner with the same name is already registered.
        """
        name = miner.get_name()
        if name not in self.miners:
            self.miners[name] = miner
            return
        raise ValueError(f"挖掘器 '{name}' 已存在")

    def get_miner(self, name: str) -> Optional[FactorMiner]:
        """Return the miner registered as ``name``, or ``None`` if there is none."""
        return self.miners.get(name, None)

    def list_miners(self) -> List[str]:
        """Return the names of all registered miners."""
        return list(self.miners)

    def mine(
        self,
        miner_name: str,
        data: pd.DataFrame,
        feature_cols: List[str],
        price_col: str = "close",
    ) -> List[FactorFormula]:
        """Run the named miner and return its result.

        Parameters:
        -----------
        miner_name : str
            Name of the registered miner to use.
        data : DataFrame
            Input data, forwarded to the miner.
        feature_cols : List[str]
            Names of the feature columns, forwarded to the miner.
        price_col : str
            Name of the price column, forwarded to the miner.

        Returns:
        --------
        List[FactorFormula]: whatever the selected miner's ``mine`` returns.

        Raises:
        -------
        ValueError
            If no miner is registered under ``miner_name``.
        """
        miner = self.get_miner(miner_name)
        if miner is not None:
            return miner.mine(data, feature_cols, price_col)
        raise ValueError(f"挖掘器 '{miner_name}' 不存在")
|
||||
|
||||
|
||||
# Global mining pipeline manager
|
||||
# Shared instance used by the module-level register_miner() and get_pipeline().
_pipeline = MiningPipeline()
|
||||
|
||||
|
||||
def register_miner(miner: FactorMiner) -> None:
    """Register ``miner`` with the module-level pipeline.

    Delegates to ``_pipeline.register_miner``, so a ``ValueError`` from
    that call (raised for a name that is already registered) propagates
    to the caller.
    """
    _pipeline.register_miner(miner)
|
||||
|
||||
|
||||
def get_pipeline() -> MiningPipeline:
    """Return the module-level ``MiningPipeline`` instance."""
    pipeline = _pipeline
    return pipeline
|
||||
|
||||
Reference in New Issue
Block a user