init
This commit is contained in:
286
macd_fit.py
Normal file
286
macd_fit.py
Normal file
@@ -0,0 +1,286 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone script to:
- Load OHLCV feather data
- Compute MACD (12, 26, 9 by default)
- Fit the MACD histogram with a simple trigonometric model

Usage example:
    python macd_fit.py \
        --feather "/Users/aszer/Documents/vscode/cta/user_data/data/okx/ADA_USDT-1d.feather"

Optional arguments (see -h):
    --fast 12 --slow 26 --signal 9
    --min-period 5 --max-period 500
    --recent 1000   # only use the most recent N points for fitting
    --plot          # show a quick plot (if matplotlib is available)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Data structures
|
||||
# ----------------------------
|
||||
|
||||
@dataclass
class MacdResult:
    """Bundle of the three MACD series, each stored as a NumPy array."""

    # MACD line (fast EMA minus slow EMA).
    macd: np.ndarray
    # Signal line (EMA of the MACD line).
    signal: np.ndarray
    # Histogram (MACD line minus signal line).
    hist: np.ndarray
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# MACD computation
|
||||
# ----------------------------
|
||||
|
||||
def compute_ema(values: pd.Series, span: int) -> pd.Series:
    """Return the exponential moving average of ``values`` for the given span.

    The work is delegated to pandas' exponentially weighted window. Passing
    ``adjust=False`` selects the recursive EMA form conventionally used in
    trading rather than the bias-adjusted weighted average.
    """
    window = values.ewm(span=span, adjust=False)
    return window.mean()
|
||||
|
||||
|
||||
def compute_macd(close_prices: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> MacdResult:
    """Compute the MACD line, its signal line and the histogram.

    Args:
        close_prices: Series of close prices.
        fast: Span of the fast EMA.
        slow: Span of the slow EMA; must be strictly greater than ``fast``.
        signal: Span of the EMA applied to the MACD line.

    Returns:
        A ``MacdResult`` holding the three series as NumPy arrays.

    Raises:
        ValueError: If ``slow`` is not greater than ``fast``.
    """
    if fast >= slow:
        raise ValueError("'slow' period must be greater than 'fast' period")

    # MACD line = fast EMA - slow EMA; histogram = MACD line - signal line.
    macd_line = compute_ema(close_prices, span=fast) - compute_ema(close_prices, span=slow)
    signal_line = compute_ema(macd_line, span=signal)
    histogram = macd_line - signal_line

    return MacdResult(
        macd=macd_line.to_numpy(),
        signal=signal_line.to_numpy(),
        hist=histogram.to_numpy(),
    )
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Trigonometric fitting
|
||||
# ----------------------------
|
||||
|
||||
def trig_model(t: np.ndarray, a: float, b: float, c: float, omega: float) -> np.ndarray:
    """Evaluate the single-frequency model ``a*sin(omega*t) + b*cos(omega*t) + c``."""
    phase = omega * t
    sine_part = a * np.sin(phase)
    cosine_part = b * np.cos(phase)
    return sine_part + cosine_part + c
|
||||
|
||||
|
||||
def r2_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the coefficient of determination of ``y_pred`` against ``y_true``.

    When ``y_true`` has no variance (total sum of squares is not positive)
    the ratio is undefined and 0.0 is returned instead.
    """
    residuals = y_true - y_pred
    deviations = y_true - np.mean(y_true)

    ss_res = float(np.sum(residuals ** 2))
    ss_tot = float(np.sum(deviations ** 2))

    if ss_tot > 0:
        return 1.0 - ss_res / ss_tot
    return 0.0
|
||||
|
||||
|
||||
def fit_with_scipy(
    t: np.ndarray,
    y: np.ndarray,
    omega_bounds: Tuple[float, float],
    initial_period_guess: int,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Fit the trig model with SciPy's bounded ``curve_fit`` when SciPy is usable.

    Args:
        t: Sample positions.
        y: Values to fit, one per entry of ``t``.
        omega_bounds: ``(lower, upper)`` limits for the angular frequency.
        initial_period_guess: Starting period; it is converted to an angular
            frequency and clipped into ``omega_bounds`` before fitting.

    Returns:
        ``(params, cov)`` from ``curve_fit`` with ``params = [a, b, c, omega]``,
        or ``None`` when SciPy cannot be imported or the optimiser does not
        converge, so that the caller can fall back to another fitting method.
    """
    try:
        from scipy.optimize import curve_fit  # type: ignore
    except Exception:
        # Deliberately broad: a broken SciPy install should behave like a
        # missing one and trigger the caller's fallback.
        return None

    omega_lo = float(omega_bounds[0])
    omega_hi = float(omega_bounds[1])

    # Initial guesses
    y_std = float(np.std(y))
    y_mean = float(np.mean(y))
    omega0 = 2.0 * math.pi / float(max(initial_period_guess, 1))
    # curve_fit rejects a starting point outside the bounds ("x0 is
    # infeasible"), so pull the frequency guess into the allowed range.
    omega0 = min(max(omega0, omega_lo), omega_hi)
    p0 = np.array([0.7 * y_std if y_std > 0 else 0.0, 0.0, y_mean, omega0], dtype=float)

    # Only omega is constrained; amplitudes and offset are free.
    bounds = (
        np.array([-np.inf, -np.inf, -np.inf, omega_lo], dtype=float),
        np.array([np.inf, np.inf, np.inf, omega_hi], dtype=float),
    )

    try:
        params, cov = curve_fit(trig_model, t, y, p0=p0, bounds=bounds, maxfev=20000)
    except RuntimeError:
        # Raised by curve_fit when the least-squares minimisation fails.
        return None
    return params, cov
|
||||
|
||||
|
||||
def fit_without_scipy(
    t: np.ndarray,
    y: np.ndarray,
    min_period: int,
    max_period: int,
    num_omegas: int = 200,
) -> Tuple[np.ndarray, float]:
    """Grid-search the frequency and solve the linear part by least squares.

    Candidate periods are spaced evenly between ``min_period`` and
    ``max_period``. For every candidate the model
    ``y ~ a*sin(omega*t) + b*cos(omega*t) + c`` is linear in ``a, b, c`` and is
    solved with ``np.linalg.lstsq``; the candidate with the highest R^2 wins.

    Returns:
        ``(best_params, best_r2)`` where ``best_params`` is ``[a, b, c, omega]``.
    """
    # Keep the period range non-degenerate: at least 2, and max above min.
    min_period = max(min_period, 2)
    max_period = max_period if max_period > min_period else min_period + 1

    period_grid = np.linspace(min_period, max_period, num=num_omegas)

    # Defaults returned if no candidate improves on -inf (e.g. an empty grid).
    top_score = -np.inf
    top_params = np.array(
        [0.0, 0.0, float(np.mean(y)), 2.0 * math.pi / float(max_period)],
        dtype=float,
    )

    # The intercept column is the same for every frequency.
    intercept = np.ones_like(t)

    for candidate in period_grid:
        omega = 2.0 * math.pi / float(candidate)
        phase = omega * t
        # Design matrix columns: sin(omega*t), cos(omega*t), 1
        design = np.column_stack((np.sin(phase), np.cos(phase), intercept))
        solution = np.linalg.lstsq(design, y, rcond=None)[0]
        current = r2_score(y, design @ solution)
        if current > top_score:
            top_score = current
            top_params = np.array([solution[0], solution[1], solution[2], omega], dtype=float)

    return top_params, float(top_score)
|
||||
|
||||
|
||||
def fit_histogram(
    hist: np.ndarray,
    min_period: int,
    max_period: int,
    initial_period_guess: int,
) -> Tuple[np.ndarray, float]:
    """Fit the histogram with the trig model and report the fit quality.

    The time axis is the sample index (uniform steps), so absolute timestamps
    are not needed as long as the samples are regularly spaced.

    Args:
        hist: Histogram values to fit.
        min_period: Shortest period to consider; values below 2 are raised to 2.
        max_period: Longest period to consider; if it does not exceed the
            (clamped) minimum it is set to ``min_period + 1``.
        initial_period_guess: Starting period for the SciPy optimiser.

    Returns:
        ``(best_params, best_r2)`` where ``best_params`` is ``[a, b, c, omega]``.

    Raises:
        ValueError: If ``hist`` is empty.
    """
    n = hist.shape[0]
    if n == 0:
        raise ValueError("Cannot fit an empty histogram")
    t = np.arange(n, dtype=float)

    # Normalise the period range the same way fit_without_scipy does, so both
    # paths accept the same inputs and the omega bounds are never inverted,
    # equal, or derived from a non-positive period.
    min_period = max(min_period, 2)
    if max_period <= min_period:
        max_period = min_period + 1

    # Long periods map to low frequencies and vice versa.
    omega_min = 2.0 * math.pi / float(max_period)
    omega_max = 2.0 * math.pi / float(min_period)

    scipy_fit = fit_with_scipy(t, hist, (omega_min, omega_max), initial_period_guess)
    if scipy_fit is not None:
        params, _ = scipy_fit
        y_hat = trig_model(t, *params)
        return params, float(r2_score(hist, y_hat))

    # Fallback path when the SciPy fit is unavailable
    return fit_without_scipy(t, hist, min_period=min_period, max_period=max_period)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# I/O and CLI
|
||||
# ----------------------------
|
||||
|
||||
def read_feather_ohlcv(feather_path: str) -> pd.DataFrame:
    """Load an OHLCV feather file into a normalised, date-sorted DataFrame.

    Column names are matched case-insensitively. The result contains exactly
    the columns ``date, open, high, low, close, volume`` (in that order), with
    ``date`` parsed as UTC datetimes (unparseable values are coerced rather
    than raising) and rows sorted by ``date`` with a fresh index.

    Raises:
        ValueError: If any of the six required columns is missing.
    """
    required = ["date", "open", "high", "low", "close", "volume"]

    raw = pd.read_feather(feather_path)

    # Lower-cased name -> actual column name, for case-insensitive lookup.
    actual_name = {column.lower(): column for column in raw.columns}
    if any(name not in actual_name for name in required):
        raise ValueError(f"Feather file is missing required columns. Found: {list(raw.columns)}")

    # Keep only the required columns, in canonical order.
    frame = raw[[actual_name[name] for name in required]].copy()

    date_column = actual_name["date"]
    frame[date_column] = pd.to_datetime(frame[date_column], utc=True, errors="coerce")

    # Rename whatever casing the file used to the canonical lower-case names.
    frame = frame.rename(columns={actual_name[name]: name for name in required})

    return frame.sort_values("date").reset_index(drop=True)
|
||||
|
||||
|
||||
def try_plot(df: pd.DataFrame, hist: np.ndarray, y_hat: np.ndarray) -> None:
    """Show close prices and the histogram with its fit, if matplotlib imports.

    Two stacked panels share the date axis: the close price on top and the
    MACD histogram with the fitted curve below. When matplotlib cannot be
    imported an info line is printed and nothing is drawn.
    """
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except Exception:
        print("[info] matplotlib not available; skipping plot.")
        return

    dates = df["date"]
    fig, panels = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
    price_panel, hist_panel = panels[0], panels[1]

    # Top panel: close price.
    price_panel.plot(dates, df["close"], label="Close", color="#1976D2")
    price_panel.set_title("Close")
    price_panel.grid(True, alpha=0.3)
    price_panel.legend()

    # Bottom panel: histogram and fitted curve.
    hist_panel.plot(dates, hist, label="MACD Hist", color="#D32F2F", linewidth=1)
    hist_panel.plot(dates, y_hat, label="Trig Fit", color="#388E3C", linewidth=1.2)
    hist_panel.set_title("MACD Histogram and Trig Fit")
    hist_panel.grid(True, alpha=0.3)
    hist_panel.legend()

    fig.tight_layout()
    plt.show()
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the MACD fitting script."""
    parser = argparse.ArgumentParser(description="Compute MACD and fit histogram with trigonometric model.")
    parser.add_argument(
        "--feather",
        type=str,
        default="/Users/aszer/Documents/vscode/cta/user_data/data/okx/ADA_USDT-1d.feather",
        help="Path to feather OHLCV file",
    )
    parser.add_argument("--fast", type=int, default=12, help="MACD fast EMA period")
    parser.add_argument("--slow", type=int, default=26, help="MACD slow EMA period")
    parser.add_argument("--signal", type=int, default=9, help="MACD signal EMA period")
    parser.add_argument("--recent", type=int, default=0, help="Use only most recent N rows for fitting (0 = all)")
    parser.add_argument("--min-period", type=int, default=5, help="Minimum oscillation period (in bars) for fit")
    parser.add_argument("--max-period", type=int, default=500, help="Maximum oscillation period (in bars) for fit")
    parser.add_argument("--period-guess", type=int, default=50, help="Initial period guess (in bars)")
    parser.add_argument("--plot", action="store_true", help="Show a quick matplotlib plot (if available)")
    return parser


def _format_date(value) -> str:
    """Format a timestamp as YYYY-MM-DD, printing "NaT" for a missing date."""
    return "NaT" if pd.isna(value) else value.strftime("%Y-%m-%d")


def main(argv: Optional[list] = None) -> int:
    """Run the CLI: load data, compute MACD, fit the histogram, print a summary.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) makes argparse read ``sys.argv``, as before.

    Returns:
        Process exit status: 0 on success, 1 when there are no rows to fit.
    """
    args = _build_arg_parser().parse_args(argv)

    df = read_feather_ohlcv(args.feather)
    macd_res = compute_macd(df["close"], fast=args.fast, slow=args.slow, signal=args.signal)

    hist = macd_res.hist
    if args.recent > 0:
        hist = hist[-args.recent :]
        # Keep the frame aligned with the (possibly shorter) histogram.
        df_for_fit = df.tail(hist.shape[0]).reset_index(drop=True)
    else:
        df_for_fit = df

    if hist.shape[0] == 0:
        # Nothing to fit; fail with a message instead of a fitter traceback.
        print("[error] no rows available for fitting.", file=sys.stderr)
        return 1

    params, score = fit_histogram(
        hist=hist,
        min_period=args.min_period,
        max_period=args.max_period,
        initial_period_guess=args.period_guess,
    )

    a, b, c, omega = map(float, params)
    period = 2.0 * math.pi / omega if omega != 0 else float("inf")
    y_hat = trig_model(np.arange(hist.shape[0], dtype=float), a, b, c, omega)

    # Console summary
    print("=== MACD Histogram Trigonometric Fit ===")
    print(f"Rows used: {hist.shape[0]}")
    print(f"Parameters: a={a:.6g}, b={b:.6g}, c={c:.6g}, omega={omega:.6g}")
    print(f"Implied period (bars): {period:.3f}")
    print(f"R^2: {score:.6f}")

    # Show last few comparisons
    tail_n = min(10, hist.shape[0])
    print("\nLast samples (date, hist, fit):")
    for i in range(-tail_n, 0):
        # Dates that failed to parse are NaT, which has no usable strftime.
        date_str = _format_date(df_for_fit["date"].iloc[i])
        print(f" {date_str} hist={hist[i]: .6g} fit={y_hat[i]: .6g}")

    if args.plot:
        try_plot(df_for_fit, hist, y_hat)

    return 0
|
||||
|
||||
|
||||
# Run the CLI only when executed as a script; main()'s return value is the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user