第一版流程
This commit is contained in:
138
data.py
Normal file
138
data.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
数据加载和预处理模块
|
||||
"""
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
|
||||
def load_data(file_path: str) -> pd.DataFrame:
|
||||
"""加载数据文件(支持feather和csv格式)"""
|
||||
if file_path.endswith('.feather'):
|
||||
df = pd.read_feather(file_path)
|
||||
elif file_path.endswith('.csv'):
|
||||
df = pd.read_csv(file_path)
|
||||
else:
|
||||
raise ValueError(f"不支持的文件格式: {file_path}")
|
||||
|
||||
# 尝试解析时间索引
|
||||
for col in ['datetime', 'time', 'timestamp', 'date']:
|
||||
if col in df.columns:
|
||||
df[col] = pd.to_datetime(df[col])
|
||||
df = df.set_index(col).sort_index()
|
||||
break
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def compute_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""计算技术指标作为候选因子"""
|
||||
data = df.copy()
|
||||
|
||||
# 收益率
|
||||
data['return'] = np.log(data['close'] / data['close'].shift(1))
|
||||
|
||||
# 波动率(滚动12期标准差)
|
||||
data['volatility'] = data['return'].rolling(12).std()
|
||||
|
||||
# 偏度(滚动6期)
|
||||
data['skewness'] = data['return'].rolling(6).skew()
|
||||
|
||||
# EMA
|
||||
data['ema4'] = data['close'].ewm(span=4, adjust=False).mean()
|
||||
data['ema8'] = data['close'].ewm(span=8, adjust=False).mean()
|
||||
data['ema16'] = data['close'].ewm(span=16, adjust=False).mean()
|
||||
|
||||
# MACD
|
||||
ema12 = data['close'].ewm(span=12, adjust=False).mean()
|
||||
ema26 = data['close'].ewm(span=26, adjust=False).mean()
|
||||
data['macd'] = ema12 - ema26
|
||||
|
||||
# RSI
|
||||
delta = data['close'].diff()
|
||||
gain = (delta.where(delta > 0, 0)).rolling(14).mean()
|
||||
loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
|
||||
rs = gain / loss
|
||||
data['rsi'] = 100 - (100 / (1 + rs))
|
||||
|
||||
# 成交量指标
|
||||
data['volume_ma6'] = data['volume'].rolling(6).mean()
|
||||
data['volume_ratio'] = data['volume'] / (data['volume_ma6'] + 1e-8)
|
||||
|
||||
# ATR
|
||||
high_low = data['high'] - data['low']
|
||||
high_close = np.abs(data['high'] - data['close'].shift())
|
||||
low_close = np.abs(data['low'] - data['close'].shift())
|
||||
ranges = pd.concat([high_low, high_close, low_close], axis=1)
|
||||
true_range = ranges.max(axis=1)
|
||||
data['atr'] = true_range.rolling(14).mean()
|
||||
|
||||
# 高低价差率
|
||||
data['price_range'] = (data['high'] - data['low']) / (data['close'].shift(1) + 1e-8)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def preprocess_data(
|
||||
df: pd.DataFrame,
|
||||
outlier_threshold: float = 3.0,
|
||||
fill_method: str = 'ffill',
|
||||
normalize_window: int = 30
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
数据预处理:异常值处理、缺失值填补、标准化
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
df : DataFrame
|
||||
原始数据
|
||||
outlier_threshold : float
|
||||
异常值阈值(标准差倍数)
|
||||
fill_method : str
|
||||
缺失值填充方法:'ffill', 'bfill', 'mean'
|
||||
normalize_window : int
|
||||
标准化滚动窗口
|
||||
"""
|
||||
data = df.copy()
|
||||
|
||||
# 异常值处理(3σ法则)- 向量化优化
|
||||
numeric_cols = data.select_dtypes(include=[np.number]).columns
|
||||
if 'return' in numeric_cols:
|
||||
col = 'return'
|
||||
mean = data[col].rolling(normalize_window, min_periods=1).mean()
|
||||
std = data[col].rolling(normalize_window, min_periods=1).std()
|
||||
mask = np.abs(data[col] - mean) > (outlier_threshold * std)
|
||||
if mask.any():
|
||||
# 用前后相邻数据的线性插值替换
|
||||
data.loc[mask, col] = np.nan
|
||||
data[col] = data[col].interpolate(method='linear')
|
||||
|
||||
# 缺失值填补(向量化)
|
||||
if fill_method == 'ffill':
|
||||
data = data.ffill()
|
||||
elif fill_method == 'bfill':
|
||||
data = data.bfill()
|
||||
elif fill_method == 'mean':
|
||||
data = data.fillna(data.rolling(3, min_periods=1).mean())
|
||||
|
||||
# 标准化(滚动Z-score)- 向量化处理
|
||||
exclude_cols = {'open', 'high', 'low', 'close', 'volume'}
|
||||
cols_to_normalize = [col for col in numeric_cols if col not in exclude_cols]
|
||||
|
||||
if cols_to_normalize:
|
||||
# 批量计算滚动均值和标准差
|
||||
rolling_mean = data[cols_to_normalize].rolling(normalize_window, min_periods=1).mean()
|
||||
rolling_std = data[cols_to_normalize].rolling(normalize_window, min_periods=1).std() + 1e-8
|
||||
# 批量标准化
|
||||
normalized = (data[cols_to_normalize] - rolling_mean) / rolling_std
|
||||
# 批量赋值
|
||||
for col in cols_to_normalize:
|
||||
data[f'{col}_norm'] = normalized[col]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def compute_forward_returns(price: pd.Series, horizon: int = 1) -> pd.Series:
|
||||
"""计算未来收益率"""
|
||||
return price.pct_change(horizon).shift(-horizon)
|
||||
|
||||
Reference in New Issue
Block a user