Files
factorhack/factor_mining/mining.py
2025-11-09 23:07:20 +08:00

124 lines
2.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Factor-mining abstraction layer.

Supports multiple mining algorithms (DEAP, DL, RL, etc.).
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Optional, Any
import pandas as pd
from dataclasses import dataclass
from factor_mining.FactorFormula import FactorFormula
@dataclass
class MiningConfig:
    """Base class for mining configuration.

    The fields are only declared here; how each one is consumed is up to
    the concrete miner. Notes marked "presumably" are inferred from the
    field names and should be confirmed against the miners that read them.
    """

    # Presumably the forward-return horizon that factors are scored against.
    ret_horizon: int = 1
    # Presumably the window length used when computing IC.
    ic_window: int = 30
    # IC correlation method: "spearman" or "pearson".
    ic_method: str = "spearman"
    # Optional seed, presumably for reproducible runs; defaults to None.
    seed: Optional[int] = None
class FactorMiner(ABC):
    """Abstract base class for factor miners.

    Concrete subclasses must implement ``mine`` and ``get_name``; the
    configuration passed to the constructor is kept on ``self.config``.
    """

    def __init__(self, config: MiningConfig) -> None:
        """Store ``config`` on the instance as ``self.config``."""
        self.config = config

    @abstractmethod
    def mine(
        self,
        data: pd.DataFrame,
        feature_cols: List[str],
        price_col: str = "close"
    ) -> List[FactorFormula]:
        """
        Mine factors.

        Parameters:
        -----------
        data : DataFrame
            Input data.
        feature_cols : List[str]
            Names of the feature columns.
        price_col : str
            Name of the price column.

        Returns:
        --------
        List[FactorFormula]: the mined factor formulas.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return the name of this miner."""
class MiningPipeline:
    """Mining pipeline manager: a name -> miner registry plus dispatch."""

    def __init__(self) -> None:
        """Create a pipeline with an empty miner registry."""
        self.miners: Dict[str, FactorMiner] = {}

    def register_miner(self, miner: FactorMiner) -> None:
        """
        Register a miner under the name returned by ``miner.get_name()``.

        Parameters:
        -----------
        miner : FactorMiner
            Miner to register.

        Raises:
        -------
        ValueError: a miner with the same name is already registered.
        """
        name = miner.get_name()
        if name in self.miners:
            raise ValueError(f"挖掘器 '{name}' 已存在")
        self.miners[name] = miner

    def get_miner(self, name: str) -> Optional[FactorMiner]:
        """Return the miner registered as ``name``, or ``None`` if absent."""
        return self.miners.get(name)

    def list_miners(self) -> List[str]:
        """Return the names of all registered miners."""
        # Iterating a dict yields its keys, so .keys() is unnecessary.
        return list(self.miners)

    def mine(
        self,
        miner_name: str,
        data: pd.DataFrame,
        feature_cols: List[str],
        price_col: str = "close"
    ) -> List[FactorFormula]:
        """
        Run mining with the named miner.

        Parameters:
        -----------
        miner_name : str
            Name of the registered miner to use.
        data : DataFrame
            Input data.
        feature_cols : List[str]
            Names of the feature columns.
        price_col : str
            Name of the price column.

        Returns:
        --------
        List[FactorFormula]: the mined factor formulas, as returned by the
        miner's own ``mine``.

        Raises:
        -------
        ValueError: no miner is registered under ``miner_name``.
        """
        # Look up through get_miner so subclasses overriding it are honored.
        miner = self.get_miner(miner_name)
        if miner is None:
            raise ValueError(f"挖掘器 '{miner_name}' 不存在")
        return miner.mine(data, feature_cols, price_col)
# Global mining pipeline manager: one module-level instance, created at import time.
_pipeline: MiningPipeline = MiningPipeline()
def register_miner(miner: FactorMiner) -> None:
    """Register ``miner`` with the global pipeline manager.

    Delegates to ``_pipeline.register_miner``, so any exception raised
    there propagates to the caller.
    """
    _pipeline.register_miner(miner)
def get_pipeline() -> MiningPipeline:
    """Return the global mining pipeline manager (the module-level instance)."""
    return _pipeline