Files
factorhack/data.py
2025-11-08 13:39:02 +08:00

139 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据加载和预处理模块
"""
import numpy as np
import pandas as pd
from typing import Optional, List, Dict
def load_data(file_path: str) -> pd.DataFrame:
    """Load a data table from a Feather or CSV file.

    The reader is selected from the file extension. The extension is compared
    case-insensitively, so ``bars.CSV`` is treated like ``bars.csv``, and the
    path may be a plain string or a ``pathlib.Path``.

    After loading, the first column found among ``datetime``, ``time``,
    ``timestamp`` and ``date`` (checked in that order) is parsed with
    ``pd.to_datetime``, promoted to the index, and the frame is sorted by it.
    When none of those columns is present the frame is returned as loaded.

    Parameters
    ----------
    file_path : str
        Path to a ``.feather`` or ``.csv`` file.

    Returns
    -------
    pd.DataFrame
        The loaded table, time-indexed and sorted when a time column exists.

    Raises
    ------
    ValueError
        If the extension is neither ``.feather`` nor ``.csv``.
    """
    # Probe a lower-cased string form so upper-case extensions and Path
    # objects are both recognised; the readers still get the original value.
    suffix_probe = str(file_path).lower()
    if suffix_probe.endswith('.feather'):
        df = pd.read_feather(file_path)
    elif suffix_probe.endswith('.csv'):
        df = pd.read_csv(file_path)
    else:
        raise ValueError(f"不支持的文件格式: {file_path}")

    # Try to build a time index from the first recognised column name.
    for col in ['datetime', 'time', 'timestamp', 'date']:
        if col in df.columns:
            df[col] = pd.to_datetime(df[col])
            df = df.set_index(col).sort_index()
            break
    return df
def compute_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Derive candidate factor columns from price/volume bars.

    Works on a copy of ``df``; the input frame is not modified. The input
    must provide ``close``, ``high``, ``low`` and ``volume`` columns.

    Columns appended, in this order: ``return``, ``volatility``,
    ``skewness``, ``ema4``, ``ema8``, ``ema16``, ``macd``, ``rsi``,
    ``volume_ma6``, ``volume_ratio``, ``atr``, ``price_range``.

    Parameters
    ----------
    df : pd.DataFrame
        Bar data with ``close``, ``high``, ``low`` and ``volume`` columns.

    Returns
    -------
    pd.DataFrame
        Copy of ``df`` with the indicator columns added.
    """
    data = df.copy()
    close = data['close']
    high = data['high']
    low = data['low']
    volume = data['volume']
    prev_close = close.shift(1)

    # Log return, then its rolling dispersion (12 bars) and skew (6 bars).
    log_ret = np.log(close / prev_close)
    data['return'] = log_ret
    data['volatility'] = log_ret.rolling(12).std()
    data['skewness'] = log_ret.rolling(6).skew()

    # Recursive (adjust=False) exponential moving averages of the close.
    for span in (4, 8, 16):
        data[f'ema{span}'] = close.ewm(span=span, adjust=False).mean()

    # MACD line: fast EMA(12) minus slow EMA(26).
    fast_ema = close.ewm(span=12, adjust=False).mean()
    slow_ema = close.ewm(span=26, adjust=False).mean()
    data['macd'] = fast_ema - slow_ema

    # RSI built from simple 14-bar means of the up and down moves; bars that
    # do not qualify (including the leading NaN difference) contribute 0.
    change = close.diff()
    up_moves = change.where(change > 0, 0)
    down_moves = -change.where(change < 0, 0)
    avg_up = up_moves.rolling(14).mean()
    avg_down = down_moves.rolling(14).mean()
    strength = avg_up / avg_down
    data['rsi'] = 100 - (100 / (1 + strength))

    # Volume relative to its 6-bar mean; the epsilon guards a zero mean.
    volume_mean = volume.rolling(6).mean()
    data['volume_ma6'] = volume_mean
    data['volume_ratio'] = volume / (volume_mean + 1e-8)

    # ATR: 14-bar mean of the true range (largest of the three spans per bar).
    bar_span = high - low
    gap_up = (high - prev_close).abs()
    gap_down = (low - prev_close).abs()
    true_range = pd.concat([bar_span, gap_up, gap_down], axis=1).max(axis=1)
    data['atr'] = true_range.rolling(14).mean()

    # High-low span as a fraction of the previous close.
    data['price_range'] = bar_span / (prev_close + 1e-8)
    return data
def preprocess_data(
    df: pd.DataFrame,
    outlier_threshold: float = 3.0,
    fill_method: str = 'ffill',
    normalize_window: int = 30
) -> pd.DataFrame:
    """
    Clean a factor table: outlier handling, gap filling, rolling z-scores.

    Works on a copy of ``df``; the input frame is not modified. The steps
    run in this order:

    1. Outliers: only the ``return`` column is screened (when it exists and
       is numeric). Values further than ``outlier_threshold`` rolling
       standard deviations from the rolling mean are blanked and the column
       is linearly interpolated. The interpolation runs over the whole
       column, so gaps already present in ``return`` are filled as well
       whenever at least one outlier is found.
    2. Gaps: filled according to ``fill_method``.
    3. Normalization: every numeric column except ``open``, ``high``,
       ``low``, ``close`` and ``volume`` gets a ``<column>_norm`` companion
       holding its rolling z-score.

    Parameters:
    -----------
    df : DataFrame
        Raw data.
    outlier_threshold : float
        Outlier cut-off, as a multiple of the rolling standard deviation.
    fill_method : str
        Gap filling strategy: 'ffill', 'bfill' or 'mean' (mean of a 3-row
        rolling window, numeric columns only). Any other value leaves gaps
        unfilled.
    normalize_window : int
        Rolling window length, used both for outlier screening and for the
        z-score.

    Returns:
    --------
    DataFrame
        Cleaned copy of ``df`` with the ``*_norm`` columns appended.
    """
    data = df.copy()
    numeric_cols = data.select_dtypes(include=[np.number]).columns

    # Outlier handling (n-sigma rule) on the return column only.
    if 'return' in numeric_cols:
        col = 'return'
        mean = data[col].rolling(normalize_window, min_periods=1).mean()
        std = data[col].rolling(normalize_window, min_periods=1).std()
        mask = np.abs(data[col] - mean) > (outlier_threshold * std)
        if mask.any():
            # Replace flagged values by interpolating between their neighbours.
            data.loc[mask, col] = np.nan
            data[col] = data[col].interpolate(method='linear')

    # Gap filling.
    if fill_method == 'ffill':
        data = data.ffill()
    elif fill_method == 'bfill':
        data = data.bfill()
    elif fill_method == 'mean':
        # Restrict the rolling mean to numeric columns: recent pandas raises
        # when asked to aggregate text/object columns, which made this branch
        # fail on any frame carrying e.g. a symbol column. Non-numeric
        # columns are left untouched.
        if len(numeric_cols) > 0:
            local_mean = data[numeric_cols].rolling(3, min_periods=1).mean()
            data = data.fillna(local_mean)

    # Rolling z-score for everything except the raw OHLCV columns.
    exclude_cols = {'open', 'high', 'low', 'close', 'volume'}
    cols_to_normalize = [col for col in numeric_cols if col not in exclude_cols]
    if cols_to_normalize:
        rolling_mean = data[cols_to_normalize].rolling(normalize_window, min_periods=1).mean()
        # The epsilon keeps the division finite when a window is constant.
        rolling_std = data[cols_to_normalize].rolling(normalize_window, min_periods=1).std() + 1e-8
        normalized = (data[cols_to_normalize] - rolling_mean) / rolling_std
        for col in cols_to_normalize:
            data[f'{col}_norm'] = normalized[col]
    return data
def compute_forward_returns(price: pd.Series, horizon: int = 1) -> pd.Series:
    """Return the simple return realised over the next ``horizon`` rows.

    The value at row ``t`` is the percentage change from ``price[t]`` to
    ``price[t + horizon]``. The last ``horizon`` rows have no future price
    and are NaN.

    Parameters
    ----------
    price : pd.Series
        Price series.
    horizon : int
        Number of rows to look ahead.

    Returns
    -------
    pd.Series
        Forward returns on the same index as ``price``.
    """
    trailing_change = price.pct_change(horizon)
    # Shifting back by ``horizon`` lines each change up with its starting row.
    forward = trailing_change.shift(-horizon)
    return forward