114 lines
3.3 KiB
Python
114 lines
3.3 KiB
Python
"""
|
||
因子挖掘模块:支持规则因子和遗传编程因子
|
||
"""
|
||
import numpy as np
|
||
import pandas as pd
|
||
from typing import Callable, Dict, List, Optional
|
||
from abc import ABC, abstractmethod
|
||
|
||
|
||
class BaseFactor(ABC):
    """Abstract base class for all factors.

    A factor carries a ``name`` and knows how to turn a DataFrame into a
    Series of factor values via :meth:`compute`. Subclasses must override
    :meth:`compute`; this class cannot be instantiated directly.
    """

    def __init__(self, name: str):
        """Store the identifier of this factor.

        Args:
            name: Label kept on the instance as ``self.name``.
        """
        self.name = name

    @abstractmethod
    def compute(self, data: pd.DataFrame) -> pd.Series:
        """Compute the factor values for ``data``.

        Args:
            data: Input frame the factor is derived from.

        Returns:
            The factor values. The base implementation does nothing and
            returns ``None``; concrete subclasses supply the real result.
        """
|
||
|
||
|
||
class RuleFactor(BaseFactor):
    """Rule-based factor: wraps a fixed function that computes the values."""

    def __init__(self, name: str, compute_func: Callable[[pd.DataFrame], pd.Series]):
        """Create a factor backed by ``compute_func``.

        Args:
            name: Factor identifier, forwarded to ``BaseFactor``.
            compute_func: Callable invoked with the input frame; its return
                value is used as the factor values.
        """
        super().__init__(name)
        self.compute_func = compute_func

    def compute(self, data: pd.DataFrame) -> pd.Series:
        """Delegate to the wrapped function and return its result unchanged.

        Args:
            data: Input frame passed straight through to ``compute_func``.

        Returns:
            Whatever ``compute_func(data)`` returns.
        """
        rule = self.compute_func
        return rule(data)
|
||
|
||
|
||
def create_trend_factor(data: pd.DataFrame) -> pd.Series:
    """Trend factor: a three-state direction signal.

    Args:
        data: Frame providing the ``close``, ``ema16`` and ``ema4`` columns.

    Returns:
        A Series on ``data.index`` holding ``1`` where ``close`` is above
        ``ema16``, ``-1`` where ``close`` is below ``ema4``, and ``0``
        otherwise. When both conditions hold the ``-1`` wins because it is
        applied last.
    """
    close = data['close']
    above_ema16 = close > data['ema16']
    below_ema4 = close < data['ema4']

    signal = pd.Series(0, index=data.index)
    signal.loc[above_ema16] = 1
    # Applied second on purpose: the bearish flag overrides the bullish one.
    signal.loc[below_ema4] = -1
    return signal
|
||
|
||
|
||
def create_volatility_factor(data: pd.DataFrame) -> pd.Series:
    """Volatility factor: pass through the precomputed ``volatility`` column.

    Args:
        data: Frame providing a ``volatility`` column.
            NOTE(review): the original note describes this column as the
            rolling 12-period standard deviation of returns; it is computed
            elsewhere, so confirm against the upstream feature code.

    Returns:
        The ``volatility`` column, unmodified.
    """
    volatility = data['volatility']
    return volatility
|
||
|
||
|
||
def create_volume_price_factor(data: pd.DataFrame) -> pd.Series:
    """Volume-price factor: the return, kept only on above-average volume.

    Args:
        data: Frame providing the ``volume``, ``volume_ma6`` and ``return``
            columns.

    Returns:
        ``return`` multiplied by a 0/1 flag that is 1 where ``volume`` is
        strictly greater than ``volume_ma6``.
    """
    is_high_volume = data['volume'] > data['volume_ma6']
    # Convert the boolean flag to 0/1 so it can gate the return by product.
    gate = is_high_volume.astype(int)
    return gate * data['return']
|
||
|
||
|
||
def create_reversal_factor(data: pd.DataFrame) -> pd.Series:
    """Reversal factor: the negated return of the previous row.

    Args:
        data: Frame providing a ``return`` column.

    Returns:
        ``return`` shifted down by one row and sign-flipped; the first
        element has no predecessor and is therefore NaN.
    """
    previous_return = data['return'].shift(1)
    return -previous_return
|
||
|
||
|
||
def create_momentum_factor(data: pd.DataFrame) -> pd.Series:
    """Momentum factor: pass through the precomputed ``macd`` column.

    Args:
        data: Frame providing a ``macd`` column.

    Returns:
        The ``macd`` column, unmodified.
    """
    macd = data['macd']
    return macd
|
||
|
||
|
||
def create_rsi_factor(data: pd.DataFrame) -> pd.Series:
    """RSI factor: the ``rsi`` column re-centred around 50 and scaled by 50.

    Args:
        data: Frame providing an ``rsi`` column.

    Returns:
        ``(rsi - 50) / 50``. For RSI values in [0, 100] this lands in
        [-1, 1]; inputs are not clipped, so out-of-range RSI values map
        outside that interval.
    """
    midpoint = 50
    centred = data['rsi'] - midpoint
    return centred / midpoint
|
||
|
||
|
||
class FactorMiner:
    """Registry of named factors that can be evaluated together."""

    def __init__(self):
        """Start with an empty name -> factor mapping."""
        self.factors: Dict[str, BaseFactor] = dict()

    def register_factor(self, factor: BaseFactor):
        """Register ``factor`` under its own ``name``.

        A factor already stored under the same name is replaced.
        """
        key = factor.name
        self.factors[key] = factor

    def register_rule_factor(self, name: str, compute_func: Callable):
        """Wrap ``compute_func`` in a ``RuleFactor`` and register it.

        Args:
            name: Name the factor is registered under.
            compute_func: Callable that computes the factor from a frame.
        """
        self.register_factor(RuleFactor(name, compute_func))

    def compute_all_factors(self, data: pd.DataFrame) -> pd.DataFrame:
        """Evaluate every registered factor against ``data``.

        Args:
            data: Input frame handed to each factor's ``compute``.

        Returns:
            A frame indexed like ``data`` with one column per registered
            factor, in registration order. A factor whose computation (or
            column assignment) raises is reported on stdout and its column
            is filled with NaN instead of aborting the whole run.
        """
        results = pd.DataFrame(index=data.index)

        for name, factor_obj in self.factors.items():
            try:
                # The assignment stays inside the try so that alignment or
                # length errors are handled like computation errors.
                results[name] = factor_obj.compute(data)
            except Exception as e:
                # Deliberate best-effort: one bad factor must not stop the rest.
                print(f"计算因子 {name} 时出错: {e}")
                results[name] = np.nan

        return results

    def get_factor(self, name: str) -> Optional[BaseFactor]:
        """Return the factor registered as ``name``, or ``None`` if absent."""
        return self.factors.get(name, None)
|
||
|
||
|
||
def create_default_factors() -> FactorMiner:
    """Build a ``FactorMiner`` pre-loaded with the default rule factors.

    Returns:
        A miner with the factors ``TREND``, ``VOL``, ``VOLP``, ``REV``,
        ``MOM`` and ``RSI`` registered in that order.
    """
    miner = FactorMiner()

    # Basic factors, listed in registration order.
    default_rules = (
        ('TREND', create_trend_factor),
        ('VOL', create_volatility_factor),
        ('VOLP', create_volume_price_factor),
        ('REV', create_reversal_factor),
        ('MOM', create_momentum_factor),
        ('RSI', create_rsi_factor),
    )
    for factor_name, rule in default_rules:
        miner.register_rule_factor(factor_name, rule)

    return miner
|
||
|