第一版流程

This commit is contained in:
2025-11-08 13:39:02 +08:00
parent dcfe2d84d5
commit a66e42a8ae
11 changed files with 1648 additions and 0 deletions

138
data.py Normal file
View File

@@ -0,0 +1,138 @@
"""
数据加载和预处理模块
"""
import numpy as np
import pandas as pd
from typing import Optional, List, Dict
def load_data(file_path: str) -> pd.DataFrame:
    """Load a table from a Feather or CSV file and index it by time.

    The reader is chosen from the file extension. The extension is compared
    case-insensitively and the path is converted with ``str()`` first, so
    ``prices.CSV`` and ``pathlib.Path`` objects are accepted in addition to
    plain lower-case string paths.

    After loading, the columns ``datetime``, ``time``, ``timestamp`` and
    ``date`` are looked up in that order. The first one found is parsed with
    ``pd.to_datetime``, installed as the index, and the frame is sorted by
    it. If none of them exists the frame is returned exactly as loaded.

    Parameters
    ----------
    file_path : str
        Path to a ``.feather`` or ``.csv`` file.

    Returns
    -------
    pd.DataFrame
        The loaded data, time-indexed and sorted when a time column exists.

    Raises
    ------
    ValueError
        If the extension is neither ``.feather`` nor ``.csv``.
    """
    # Normalise once so upper-case extensions and Path objects both work.
    extension_probe = str(file_path).lower()
    if extension_probe.endswith('.feather'):
        df = pd.read_feather(file_path)
    elif extension_probe.endswith('.csv'):
        df = pd.read_csv(file_path)
    else:
        raise ValueError(f"不支持的文件格式: {file_path}")

    # Try to promote the first recognised time column to a sorted index.
    for col in ['datetime', 'time', 'timestamp', 'date']:
        if col in df.columns:
            df[col] = pd.to_datetime(df[col])
            df = df.set_index(col).sort_index()
            break
    return df
def compute_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Append technical-indicator columns to a copy of ``df``.

    The input frame is not modified. It must provide ``close``, ``high``,
    ``low`` and ``volume`` columns. The following columns are added, in
    this order:

    - ``return``: log of ``close`` over the previous row's ``close``
    - ``volatility``: 12-row rolling standard deviation of ``return``
    - ``skewness``: 6-row rolling skew of ``return``
    - ``ema4`` / ``ema8`` / ``ema16``: exponential moving averages of ``close``
    - ``macd``: EMA(12) minus EMA(26) of ``close``
    - ``rsi``: built from 14-row rolling means of gains and losses
    - ``volume_ma6`` and ``volume_ratio``: 6-row mean volume, and volume
      divided by that mean
    - ``atr``: 14-row rolling mean of the true range
    - ``price_range``: ``high - low`` divided by the previous ``close``

    Returns
    -------
    pd.DataFrame
        Copy of ``df`` with the indicator columns appended.
    """
    out = df.copy()
    close = out['close']
    prev_close = close.shift(1)

    # Log return and its rolling dispersion / asymmetry.
    log_ret = np.log(close / prev_close)
    out['return'] = log_ret
    out['volatility'] = log_ret.rolling(12).std()
    out['skewness'] = log_ret.rolling(6).skew()

    # Exponential moving averages of the close.
    for span in (4, 8, 16):
        out[f'ema{span}'] = close.ewm(span=span, adjust=False).mean()

    # MACD line: fast EMA minus slow EMA.
    fast_ema = close.ewm(span=12, adjust=False).mean()
    slow_ema = close.ewm(span=26, adjust=False).mean()
    out['macd'] = fast_ema - slow_ema

    # RSI from simple rolling means of up-moves and down-moves.
    change = close.diff()
    ups = change.where(change > 0, 0)
    downs = -change.where(change < 0, 0)
    avg_up = ups.rolling(14).mean()
    avg_down = downs.rolling(14).mean()
    strength = avg_up / avg_down
    out['rsi'] = 100 - (100 / (1 + strength))

    # Volume relative to its short moving average; epsilon guards a zero mean.
    volume = out['volume']
    out['volume_ma6'] = volume.rolling(6).mean()
    out['volume_ratio'] = volume / (out['volume_ma6'] + 1e-8)

    # True range is the row-wise largest of the three candidate ranges.
    high = out['high']
    low = out['low']
    candidates = pd.concat(
        [high - low, np.abs(high - prev_close), np.abs(low - prev_close)],
        axis=1,
    )
    out['atr'] = candidates.max(axis=1).rolling(14).mean()

    # High-low spread scaled by the previous close.
    out['price_range'] = (high - low) / (prev_close + 1e-8)
    return out
def preprocess_data(
    df: pd.DataFrame,
    outlier_threshold: float = 3.0,
    fill_method: str = 'ffill',
    normalize_window: int = 30
) -> pd.DataFrame:
    """
    Clean a factor table: outlier handling, gap filling, rolling z-scores.

    The input frame is not modified; all work happens on a copy.

    Steps, in order:

    1. If a numeric ``return`` column exists, values further than
       ``outlier_threshold`` rolling standard deviations from the rolling
       mean are blanked and re-created by linear interpolation.
    2. Missing values are filled according to ``fill_method``.
    3. Every numeric column of the input except ``open``, ``high``, ``low``,
       ``close`` and ``volume`` gets a companion ``<name>_norm`` column
       holding its rolling z-score.

    Parameters:
    -----------
    df : DataFrame
        Raw data.
    outlier_threshold : float
        Outlier cut-off, as a multiple of the rolling standard deviation.
    fill_method : str
        How to fill missing values: 'ffill', 'bfill' or 'mean' (mean of a
        3-row rolling window, applied to numeric columns only). Any other
        value leaves missing values untouched.
    normalize_window : int
        Rolling window length used for both outlier detection and
        normalization.

    Returns:
    --------
    DataFrame
        Cleaned copy of ``df`` with the added ``*_norm`` columns.
    """
    data = df.copy()
    numeric_cols = data.select_dtypes(include=[np.number]).columns

    # Outlier handling on the return column, using rolling statistics.
    if 'return' in numeric_cols:
        col = 'return'
        mean = data[col].rolling(normalize_window, min_periods=1).mean()
        std = data[col].rolling(normalize_window, min_periods=1).std()
        mask = np.abs(data[col] - mean) > (outlier_threshold * std)
        if mask.any():
            # Replace flagged points by interpolating between their neighbours.
            data.loc[mask, col] = np.nan
            data[col] = data[col].interpolate(method='linear')

    # Missing-value filling.
    if fill_method == 'ffill':
        data = data.ffill()
    elif fill_method == 'bfill':
        data = data.bfill()
    elif fill_method == 'mean':
        # Rolling means only make sense for numeric data. Running them on the
        # whole frame fails as soon as a text/object column is present, so
        # restrict both the computation and the fill to the numeric columns.
        if not numeric_cols.empty:
            numeric_part = data[numeric_cols]
            local_mean = numeric_part.rolling(3, min_periods=1).mean()
            data[numeric_cols] = numeric_part.fillna(local_mean)

    # Rolling z-score normalization; raw price/volume columns are left alone.
    exclude_cols = {'open', 'high', 'low', 'close', 'volume'}
    cols_to_normalize = [col for col in numeric_cols if col not in exclude_cols]
    if cols_to_normalize:
        selected = data[cols_to_normalize]
        rolling_mean = selected.rolling(normalize_window, min_periods=1).mean()
        # The epsilon keeps the division finite when the window is constant.
        rolling_std = selected.rolling(normalize_window, min_periods=1).std() + 1e-8
        normalized = (selected - rolling_mean) / rolling_std
        # Assign one column at a time so an existing '<col>_norm' is overwritten.
        for col in cols_to_normalize:
            data[f'{col}_norm'] = normalized[col]
    return data
def compute_forward_returns(price: pd.Series, horizon: int = 1) -> pd.Series:
    """Compute the percentage change of ``price`` over the next ``horizon`` rows.

    Parameters
    ----------
    price : pd.Series
        Price series.
    horizon : int
        Number of rows to look ahead.

    Returns
    -------
    pd.Series
        Series aligned with ``price``; the value on each row is the
        percentage change from that row to the row ``horizon`` steps later.
    """
    # Percentage change measured against the value `horizon` rows earlier.
    trailing_change = price.pct_change(horizon)
    # Move each value back by `horizon` rows so it sits where the move begins.
    forward_change = trailing_change.shift(-horizon)
    return forward_change