Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# python  (stray Markdown code-fence language tag; commented out so the module parses and imports)
|
| 2 |
+
# app.py
|
| 3 |
+
"""
|
| 4 |
+
Gradio (blank) tabanlı Hugging Face Space uygulaması.
|
| 5 |
+
- OKX REST API'den BTC/USDT (spot) candle verisi çeker
|
| 6 |
+
- Teknik göstergeler üretir
|
| 7 |
+
- Ensemble: LightGBM, XGBoost, RandomForest (sklearn) + küçük PyTorch LSTM
|
| 8 |
+
- Eğer pretrained model dosyaları yoksa küçük demo modeller oluşturur
|
| 9 |
+
- Outputs: tahmin (regresyon: next-close), model katkıları, grafikler
|
| 10 |
+
|
| 11 |
+
Not:
|
| 12 |
+
- requirements.txt'de aşağıdakiler olmalı:
|
| 13 |
+
gradio, pandas, numpy, requests, ta, scikit-learn, lightgbm, xgboost, torch, matplotlib
|
| 14 |
+
- Kullanıcı OKX API anahtarı gerekli değildir (public candles endpoint kullanılıyor).
|
| 15 |
+
- Bu dosya tek başına çalışır; ancak ağır paketler (lightgbm, xgboost, torch) Spaces ortamında kurulmadıysa hata verebilir.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import io
|
| 20 |
+
import time
|
| 21 |
+
import math
|
| 22 |
+
import json
|
| 23 |
+
import threading
|
| 24 |
+
from typing import Tuple, Dict, Any, List
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import pandas as pd
|
| 28 |
+
import requests
|
| 29 |
+
from datetime import datetime, timedelta, timezone
|
| 30 |
+
|
| 31 |
+
# Visualization
|
| 32 |
+
import matplotlib
|
| 33 |
+
matplotlib.use("Agg")
|
| 34 |
+
import matplotlib.pyplot as plt
|
| 35 |
+
|
| 36 |
+
# Technical indicators
|
| 37 |
+
try:
|
| 38 |
+
import ta
|
| 39 |
+
except Exception:
|
| 40 |
+
# Minimal fallback implementations if ta isn't installed
|
| 41 |
+
ta = None
|
| 42 |
+
|
| 43 |
+
# ML libs
|
| 44 |
+
from sklearn.ensemble import RandomForestRegressor
|
| 45 |
+
from sklearn.preprocessing import StandardScaler
|
| 46 |
+
from sklearn.pipeline import Pipeline
|
| 47 |
+
from sklearn.base import BaseEstimator, RegressorMixin
|
| 48 |
+
|
| 49 |
+
# Try import optional libs
|
| 50 |
+
HAS_LGB = True
|
| 51 |
+
HAS_XGB = True
|
| 52 |
+
HAS_TORCH = True
|
| 53 |
+
try:
|
| 54 |
+
import lightgbm as lgb
|
| 55 |
+
except Exception:
|
| 56 |
+
HAS_LGB = False
|
| 57 |
+
try:
|
| 58 |
+
import xgboost as xgb
|
| 59 |
+
except Exception:
|
| 60 |
+
HAS_XGB = False
|
| 61 |
+
try:
|
| 62 |
+
import torch
|
| 63 |
+
import torch.nn as nn
|
| 64 |
+
import torch.nn.functional as F
|
| 65 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 66 |
+
except Exception:
|
| 67 |
+
HAS_TORCH = False
|
| 68 |
+
|
| 69 |
+
# Gradio
|
| 70 |
+
import gradio as gr
|
| 71 |
+
|
| 72 |
+
# -------------------------
|
| 73 |
+
# Configuration/Constants
|
| 74 |
+
# -------------------------
|
| 75 |
+
# Base URL that okx_candles() prefixes to the REST path.
OKX_BASE = "https://www.okx.com"
# Public candles: GET /api/v5/market/history-candles?instId=BTC-USDT-SWAP&bar=1m&limit=100
# We'll use spot: BTC-USDT
DEFAULT_INSTRUMENT = "BTC-USDT"
DEFAULT_BAR = "1m" # options: 1m, 3m, 5m, 15m, 1H etc.
# NOTE(review): the example request above uses limit=100; confirm the
# per-request maximum of the history-candles endpoint before relying on 500.
DEFAULT_LIMIT = 500 # up to 1000 depending on endpoint

# Model filenames (in repo or persisted by training)
MODEL_DIR = "models"
# Import-time side effect: the directory is created as soon as the module loads.
os.makedirs(MODEL_DIR, exist_ok=True)
LGB_MODEL_FILE = os.path.join(MODEL_DIR, "lgb_model.txt")
XGB_MODEL_FILE = os.path.join(MODEL_DIR, "xgb_model.json")
RF_MODEL_FILE = os.path.join(MODEL_DIR, "rf_model.pkl")
LSTM_MODEL_FILE = os.path.join(MODEL_DIR, "lstm_model.pt")
SCALER_FILE = os.path.join(MODEL_DIR, "scaler.npy") # save scaler mean/scale

# Thread-safe model cache
# _MODELS is filled by load_models() and emptied by the "Refresh Models" button;
# both take _MODEL_LOCK before touching it.
_MODEL_LOCK = threading.Lock()
_MODELS = {}
|
| 94 |
+
|
| 95 |
+
# -------------------------
|
| 96 |
+
# Utilities
|
| 97 |
+
# -------------------------
|
| 98 |
+
def now_iso():
    """Return the current UTC time formatted as an ISO-8601 string."""
    current_utc = datetime.now(timezone.utc)
    return current_utc.isoformat()
|
| 100 |
+
|
| 101 |
+
def okx_candles(inst_id: str = DEFAULT_INSTRUMENT, bar: str = DEFAULT_BAR, limit: int = DEFAULT_LIMIT) -> pd.DataFrame:
    """
    Fetch recent candle data from the OKX public REST API.

    Args:
        inst_id: Instrument identifier sent as ``instId`` (e.g. "BTC-USDT").
        bar: Candle size sent as ``bar`` (e.g. "1m", "1H").
        limit: Number of candles requested; sent as a string.
            NOTE(review): the endpoint may cap this below the requested
            value - confirm against the OKX documentation.

    Returns:
        DataFrame with columns ts, open, high, low, close, volume, sorted by
        ``ts`` ascending. ``ts`` holds timezone-aware UTC datetimes.

    Raises:
        requests.HTTPError: If the HTTP status is not successful.
        RuntimeError: If the payload is not a JSON object, carries a
            non-zero OKX error code, or contains no candles.
    """
    url = f"{OKX_BASE}/api/v5/market/history-candles"
    params = {"instId": inst_id, "bar": bar, "limit": str(limit)}
    resp = requests.get(url, params=params, timeout=15)
    resp.raise_for_status()
    payload = resp.json()

    if not isinstance(payload, dict):
        raise RuntimeError("Unexpected response structure from OKX")

    # A missing code is tolerated; anything other than "0"/0 is an API error
    # and is reported with the server's own message instead of being ignored.
    code = payload.get("code")
    if code not in (None, "0", 0):
        raise RuntimeError(f"OKX API error {code}: {payload.get('msg', '')}")

    candles = payload.get("data") or []
    if not candles:
        raise RuntimeError("No candle data returned from OKX")

    # Each candle is a list: [ts, open, high, low, close, volume, ...]
    rows = []
    for candle in candles:
        raw_ts = int(candle[0])
        # Values longer than 10 digits are treated as milliseconds.
        seconds = raw_ts // 1000 if len(str(candle[0])) > 10 else raw_ts
        rows.append({
            "ts": datetime.fromtimestamp(seconds, tz=timezone.utc),
            "open": float(candle[1]),
            "high": float(candle[2]),
            "low": float(candle[3]),
            "close": float(candle[4]),
            "volume": float(candle[5]),
        })

    frame = pd.DataFrame(rows)
    return frame.sort_values("ts").reset_index(drop=True)
|
| 142 |
+
|
| 143 |
+
# Minimal TA indicators if `ta` package is not available
|
| 144 |
+
def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of ``df`` with technical-indicator columns appended.

    Adds rsi, ema12, ema26, macd, macd_signal, bb_high, bb_low and atr.
    When the ``ta`` package is available it is used; otherwise pandas-based
    fallbacks are computed (which additionally leave bb_mid and bb_std
    columns in the result).

    Args:
        df: Candle frame with at least ``high``, ``low`` and ``close`` columns.

    Returns:
        New DataFrame; remaining NaNs are back-filled, then forward-filled,
        then replaced by 0.0.
    """
    df = df.copy()
    close = df["close"]
    if ta is not None:
        # Use ta to add common indicators
        df["rsi"] = ta.momentum.RSIIndicator(close, window=14, fillna=True).rsi()
        df["ema12"] = ta.trend.EMAIndicator(close, window=12, fillna=True).ema_indicator()
        df["ema26"] = ta.trend.EMAIndicator(close, window=26, fillna=True).ema_indicator()
        macd = ta.trend.MACD(close, window_slow=26, window_fast=12, window_sign=9, fillna=True)
        df["macd"] = macd.macd()
        df["macd_signal"] = macd.macd_signal()
        # Build the Bollinger indicator once and read both bands from it.
        bands = ta.volatility.BollingerBands(close, window=20, fillna=True)
        df["bb_high"] = bands.bollinger_hband()
        df["bb_low"] = bands.bollinger_lband()
        df["atr"] = ta.volatility.AverageTrueRange(
            df["high"], df["low"], close, window=14, fillna=True
        ).average_true_range()
    else:
        # Fallback simple computations
        df["rsi"] = simple_rsi(close, window=14)
        df["ema12"] = close.ewm(span=12, adjust=False).mean()
        df["ema26"] = close.ewm(span=26, adjust=False).mean()
        df["macd"] = df["ema12"] - df["ema26"]
        df["macd_signal"] = df["macd"].ewm(span=9, adjust=False).mean()
        df["bb_mid"] = close.rolling(20).mean()
        df["bb_std"] = close.rolling(20).std()
        df["bb_high"] = df["bb_mid"] + 2 * df["bb_std"]
        df["bb_low"] = df["bb_mid"] - 2 * df["bb_std"]
        df["atr"] = simple_atr(df, window=14)
    # fillna(method=...) is deprecated in pandas; bfill()/ffill() are the
    # equivalent dedicated methods.
    return df.bfill().ffill().fillna(0.0)
|
| 172 |
+
|
| 173 |
+
def simple_rsi(series: pd.Series, window: int = 14) -> pd.Series:
    """
    Compute an RSI-style oscillator from exponentially smoothed gains and losses.

    Gains and losses are smoothed with ``ewm(alpha=1/window, adjust=False)``;
    a small epsilon keeps the ratio finite when there are no losses.
    Positions where the result is NaN are filled with 50.0.
    """
    change = series.diff()
    gains = change.clip(lower=0)
    losses = -1 * change.clip(upper=0)
    smoothing = 1 / window
    avg_gain = gains.ewm(alpha=smoothing, adjust=False).mean()
    avg_loss = losses.ewm(alpha=smoothing, adjust=False).mean()
    strength = avg_gain / (avg_loss + 1e-8)
    oscillator = 100 - (100 / (1 + strength))
    return oscillator.fillna(50.0)
|
| 182 |
+
|
| 183 |
+
def simple_atr(df: pd.DataFrame, window: int = 14) -> pd.Series:
    """
    Compute an average-true-range series from ``high``, ``low`` and ``close``.

    The true range is the row-wise maximum of high-low, |high - previous close|
    and |low - previous close|; it is smoothed with ``ewm(span=window,
    adjust=False)`` and NaNs are replaced by 0.0.
    """
    previous_close = df["close"].shift()
    candidates = [
        df["high"] - df["low"],
        (df["high"] - previous_close).abs(),
        (df["low"] - previous_close).abs(),
    ]
    true_range = pd.concat(candidates, axis=1).max(axis=1)
    smoothed = true_range.ewm(span=window, adjust=False).mean()
    return smoothed.fillna(0.0)
|
| 190 |
+
|
| 191 |
+
def create_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build the model feature frame from raw candles.

    Adds the technical indicators, then return, volatility, moving-average,
    ratio and time features. Each row's features are intended to predict the
    next row's close.

    Args:
        df: Candle frame with ``ts``, ``high``, ``low`` and ``close`` columns;
            ``ts`` must be datetime-like (``.dt`` accessors are used).

    Returns:
        New DataFrame with the extra feature columns; remaining NaNs are
        back-filled and then replaced by 0.0.
    """
    # add_technical_indicators works on its own copy, so the caller's frame
    # is never mutated.
    df = add_technical_indicators(df)
    close = df["close"]

    # Feature engineering: returns, log returns, vol, moving averages, ratios
    df["return_1"] = close.pct_change().fillna(0.0)
    df["log_return_1"] = np.log1p(df["return_1"])
    df["vol_5"] = close.rolling(5).std().fillna(0.0)
    df["vol_20"] = close.rolling(20).std().fillna(0.0)
    # .bfill() replaces the deprecated fillna(method="bfill").
    df["ma_5"] = close.rolling(5).mean().bfill()
    df["ma_20"] = close.rolling(20).mean().bfill()
    df["ma_50"] = close.rolling(50).mean().bfill()

    # ratio features (epsilon guards against division by zero)
    df["ma5_div_ma20"] = df["ma_5"] / (df["ma_20"] + 1e-9)
    df["ema_diff"] = df["ema12"] - df["ema26"]

    # time features
    # NOTE(review): the // 10**9 assumes the integer view of ts is in
    # nanoseconds - confirm for the pandas version in use.
    df["ts_unix"] = df["ts"].astype(np.int64) // 10**9
    df["hour"] = df["ts"].dt.hour
    df["minute"] = df["ts"].dt.minute

    # fill remaining na
    return df.bfill().fillna(0.0)
|
| 213 |
+
|
| 214 |
+
# -------------------------
|
| 215 |
+
# Model wrappers and helpers
|
| 216 |
+
# -------------------------
|
| 217 |
+
class DummyRegressor(BaseEstimator, RegressorMixin):
    """Fallback regressor that always predicts the mean of its training targets."""

    def fit(self, X, y):
        """Remember the mean of ``y`` (0.0 when ``y`` is empty) and return self."""
        if len(y):
            self._mean = np.mean(y)
        else:
            self._mean = 0.0
        return self

    def predict(self, X):
        """Return one constant prediction per row of ``X`` (0.0 if never fitted)."""
        constant = getattr(self, "_mean", 0.0)
        row_count = X.shape[0]
        return np.full((row_count,), constant)
|
| 224 |
+
|
| 225 |
+
def save_numpy(obj: np.ndarray, path: str):
    """Persist ``obj`` at ``path`` using ``np.save``."""
    np.save(file=path, arr=obj)
|
| 227 |
+
|
| 228 |
+
def load_numpy(path: str) -> np.ndarray:
    """Read and return the array stored at ``path`` using ``np.load``."""
    loaded = np.load(file=path)
    return loaded
|
| 230 |
+
|
| 231 |
+
def get_feature_columns() -> List[str]:
    """Return the ordered list of column names used as model input features."""
    raw_candle = ["open", "high", "low", "close", "volume"]
    indicators = ["rsi", "ema12", "ema26", "macd", "macd_signal", "bb_high", "bb_low", "atr"]
    derived = ["return_1", "log_return_1", "vol_5", "vol_20", "ma_5", "ma_20", "ma_50"]
    ratios_and_time = ["ma5_div_ma20", "ema_diff", "ts_unix", "hour", "minute"]
    return raw_candle + indicators + derived + ratios_and_time
|
| 239 |
+
|
| 240 |
+
# Model persistence helpers (light, simple)
|
| 241 |
+
def load_models() -> Dict[str, Any]:
    """
    Load the model set from MODEL_DIR and cache it in the module-level dict.

    If the cache is already populated it is returned as-is. Otherwise each
    artifact is loaded when its file exists (and, for the optional libraries,
    when the library imported); anything that is missing or fails to load is
    stored as None, except the random forest, which falls back to a fresh
    *unfitted* RandomForestRegressor, and the scaler, which falls back to a
    fresh unfitted StandardScaler.

    Returns:
        The shared cache dict with keys "rf", "lgb", "xgb", "lstm", "scaler".
    """
    with _MODEL_LOCK:
        if _MODELS:
            return _MODELS

        loaded = {}

        # Scaler: rebuilt from the saved mean/scale/feature-count dict.
        # NOTE(review): allow_pickle=True unpickles the file contents; only
        # safe while the scaler file comes from a trusted source.
        scaler = None
        if os.path.exists(SCALER_FILE):
            try:
                stats = np.load(SCALER_FILE, allow_pickle=True).item()
                scaler = StandardScaler()
                scaler.mean_ = stats["mean"]
                scaler.scale_ = stats["scale"]
                scaler.n_features_in_ = stats["n_in"]
            except Exception:
                scaler = None

        # RandomForest (sklearn): a missing file, a missing joblib or a
        # failed load all end in the same small demo forest.
        try:
            import joblib
            if not os.path.exists(RF_MODEL_FILE):
                raise FileNotFoundError
            loaded["rf"] = joblib.load(RF_MODEL_FILE)
        except Exception:
            loaded["rf"] = RandomForestRegressor(n_estimators=10, random_state=42)

        # LightGBM
        loaded["lgb"] = None
        if HAS_LGB and os.path.exists(LGB_MODEL_FILE):
            try:
                loaded["lgb"] = lgb.Booster(model_file=LGB_MODEL_FILE)
            except Exception:
                loaded["lgb"] = None

        # XGBoost
        loaded["xgb"] = None
        if HAS_XGB and os.path.exists(XGB_MODEL_FILE):
            try:
                booster = xgb.Booster()
                booster.load_model(XGB_MODEL_FILE)
                loaded["xgb"] = booster
            except Exception:
                loaded["xgb"] = None

        # LSTM / PyTorch
        loaded["lstm"] = None
        if HAS_TORCH and os.path.exists(LSTM_MODEL_FILE):
            try:
                loaded["lstm"] = torch.load(LSTM_MODEL_FILE, map_location=torch.device("cpu"))
            except Exception:
                loaded["lstm"] = None

        # Without a persisted scaler, hand out an unfitted one; the
        # prediction step fits it on first use.
        if scaler is None:
            scaler = StandardScaler()
        loaded["scaler"] = scaler

        _MODELS.update(loaded)
        return _MODELS
|
| 313 |
+
|
| 314 |
+
def save_scaler(scaler: StandardScaler, path: str = SCALER_FILE):
    """
    Persist the fitted scaler's statistics with ``np.save``.

    The saved object is a dict with keys "mean", "scale" and "n_in", which is
    the layout load_models() reads back.
    """
    snapshot = dict(
        mean=scaler.mean_,
        scale=scaler.scale_,
        n_in=scaler.n_features_in_,
    )
    np.save(path, snapshot)
|
| 317 |
+
|
| 318 |
+
# -------------------------
|
| 319 |
+
# Inference logic
|
| 320 |
+
# -------------------------
|
| 321 |
+
def prepare_inference_features(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], pd.DataFrame]:
    """
    Turn raw candles into the model input matrix.

    Returns:
        (X, feature_cols, df_ready): ``X`` is the 2-D array of feature values
        in ``feature_cols`` order, and ``df_ready`` is the full feature frame.
        Any expected feature column that is absent is added as 0.0.
    """
    frame = create_features(df)
    columns = get_feature_columns()
    # Guarantee every expected column exists before slicing.
    absent = [name for name in columns if name not in frame.columns]
    for name in absent:
        frame[name] = 0.0
    matrix = frame[columns].values
    return matrix, columns, frame
|
| 334 |
+
|
| 335 |
+
def predict_ensemble(X: np.ndarray, models: Dict[str, Any]) -> Dict[str, Any]:
    """
    Predict the next-step close from the last feature row using every
    available model.

    Args:
        X: Feature matrix (rows in time order) or a single 1-D feature row.
        models: Dict as returned by load_models(); keys "scaler", "rf",
            "lgb", "xgb", "lstm". Missing or None entries are skipped.

    Returns:
        Dict with:
        - "per_model": {name: float prediction, NaN when unavailable/failed}
        - "ensemble_mean": mean of the non-NaN predictions, or the last
          unscaled close when no model produced a value
        - "weighted": currently identical to "ensemble_mean" (equal weights)

    Side effects:
        If the scaler is not fitted, it is fitted on ``X`` and written to
        disk via save_scaler().
    """
    scaler = models.get("scaler", None)
    if scaler is None:
        scaler = StandardScaler()

    # Use last row features to predict next
    X_row = X.reshape(1, -1) if X.ndim == 1 else X[-1:, :]

    try:
        Xs = scaler.transform(X_row)
    except Exception:
        # Scaler not fitted: fit on X as a fallback, else use raw features.
        try:
            scaler.fit(X)
            save_scaler(scaler)
            Xs = scaler.transform(X_row)
        except Exception:
            Xs = X_row

    def _guarded(predict_fn) -> float:
        """Run one model's prediction; any failure becomes NaN."""
        try:
            return float(predict_fn())
        except Exception:
            return float(np.nan)

    def _lstm_predict():
        model = models["lstm"]
        model.eval()
        with torch.no_grad():
            # X_row is (1, features); the extra leading axis gives the
            # (batch=1, seq=1, features) layout.
            # NOTE(review): the LSTM receives the unscaled row while the
            # tree models receive the scaled one - confirm this matches
            # how the LSTM was trained.
            batch = torch.tensor(X_row, dtype=torch.float32).unsqueeze(0)
            out = model(batch)
        return out.squeeze().cpu().numpy()

    preds = {}

    # RandomForest
    rf = models.get("rf", None)
    preds["rf"] = _guarded(lambda: rf.predict(Xs)[0]) if rf is not None else float(np.nan)

    # LightGBM (Booster.predict takes the array directly)
    if HAS_LGB and models.get("lgb", None) is not None:
        preds["lgb"] = _guarded(lambda: models["lgb"].predict(Xs)[0])
    else:
        preds["lgb"] = float(np.nan)

    # XGBoost
    if HAS_XGB and models.get("xgb", None) is not None:
        preds["xgb"] = _guarded(lambda: models["xgb"].predict(xgb.DMatrix(Xs))[0])
    else:
        preds["xgb"] = float(np.nan)

    # LSTM (PyTorch)
    if HAS_TORCH and models.get("lstm", None) is not None:
        preds["lstm"] = _guarded(_lstm_predict)
    else:
        preds["lstm"] = float(np.nan)

    # Check None before isnan so a None can never reach math.isnan.
    valid_preds = [v for v in preds.values() if v is not None and not math.isnan(v)]
    if not valid_preds:
        # fallback naive next-close = last close (taken from the unscaled row)
        naive = float(X_row[0, get_feature_columns().index("close")])
        ensemble_mean = naive
        weighted = naive
    else:
        ensemble_mean = float(np.nanmean(valid_preds))
        # Simple weighting: prefer models that exist; equal weight
        weighted = ensemble_mean

    return {
        "per_model": preds,
        "ensemble_mean": ensemble_mean,
        "weighted": weighted,
    }
|
| 435 |
+
|
| 436 |
+
# -------------------------
|
| 437 |
+
# LSTM simple architecture (for demo)
|
| 438 |
+
# -------------------------
|
| 439 |
+
if HAS_TORCH:
    class SimpleLSTM(nn.Module):
        """Small LSTM regressor: an LSTM stack followed by a 1-unit linear head."""

        def __init__(self, input_size: int, hidden_size: int = 32, num_layers: int = 1):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size=input_size,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,
            )
            self.fc = nn.Linear(hidden_size, 1)

        def forward(self, x):
            """Map ``x`` of shape (batch, seq_len, input_size) to one value per batch item."""
            sequence_out, _state = self.lstm(x)
            # Only the final time step feeds the regression head.
            final_step = sequence_out[:, -1, :]
            return self.fc(final_step)
|
| 451 |
+
|
| 452 |
+
# -------------------------
|
| 453 |
+
# Visualization helpers
|
| 454 |
+
# -------------------------
|
| 455 |
+
def plot_price_and_preds(df: pd.DataFrame, preds: Dict[str, Any]) -> bytes:
    """
    Render the close-price line with the ensemble prediction and return PNG bytes.

    The prediction is read from ``preds["weighted"]``, falling back to
    ``preds["ensemble_mean"]`` and then to the last close. It is drawn as a
    red marker one second after the last timestamp; a dashed grey line marks
    the last close.
    """
    final_ts = df["ts"].iloc[-1]
    final_close = df["close"].iloc[-1]
    predicted = preds.get("weighted", preds.get("ensemble_mean", final_close))
    marker_time = final_ts + pd.Timedelta(seconds=1)

    figure, axes = plt.subplots(figsize=(9, 4))
    axes.plot(df["ts"], df["close"], label="close", color="black", lw=1)
    axes.scatter([marker_time], [predicted], color="red", label="ensemble_pred")
    axes.axhline(final_close, linestyle="--", color="gray", alpha=0.6)
    axes.set_title("BTC/USDT close and ensemble prediction")
    axes.set_xlabel("Time (UTC)")
    axes.set_ylabel("Price")
    axes.legend()
    figure.tight_layout()

    png_buffer = io.BytesIO()
    figure.savefig(png_buffer, format="png")
    plt.close(figure)
    return png_buffer.getvalue()
|
| 474 |
+
|
| 475 |
+
def plot_model_contributions(per_model: Dict[str, float]) -> bytes:
    """
    Render a bar chart of each model's prediction and return PNG bytes.

    NaN predictions are drawn as bars of height 0.0.
    """
    labels = list(per_model.keys())
    heights = [0.0 if math.isnan(per_model[label]) else per_model[label] for label in labels]
    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"]

    figure, axes = plt.subplots(figsize=(6, 3))
    axes.bar(labels, heights, color=palette)
    axes.set_title("Per-model predictions (abs values)")
    axes.set_ylabel("Predicted price")
    figure.tight_layout()

    png_buffer = io.BytesIO()
    figure.savefig(png_buffer, format="png")
    plt.close(figure)
    return png_buffer.getvalue()
|
| 488 |
+
|
| 489 |
+
# -------------------------
|
| 490 |
+
# Gradio app components
|
| 491 |
+
# -------------------------
|
| 492 |
+
def inference_pipeline(inst_id: str = DEFAULT_INSTRUMENT,
                       bar: str = DEFAULT_BAR,
                       limit: int = DEFAULT_LIMIT,
                       show_plot: bool = True):
    """
    High-level function called by Gradio: fetch candles, build features,
    load models, predict, and render the two charts.

    Args:
        inst_id: Instrument identifier passed to okx_candles().
        bar: Candle size passed to okx_candles().
        limit: Number of candles to request (coerced with ``int``).
        show_plot: Accepted for interface compatibility; not used here.

    Returns:
        On a fetch failure: ``{"error": message}``.
        Otherwise: ``{"result": summary_dict, "img_price": png_bytes,
        "img_contrib": png_bytes}``.
    """
    # Step 1: fetch candles
    try:
        df = okx_candles(inst_id=inst_id, bar=bar, limit=int(limit))
    except Exception as e:
        return {"error": f"Failed to fetch candles: {e}"}

    # Step 2: prepare features (only the matrix and the feature frame are needed)
    X, _feat_cols, df_ready = prepare_inference_features(df)

    # Step 3: load models
    models = load_models()

    # Step 4: predict
    preds = predict_ensemble(X, models)

    # Step 5: build result
    last_close = float(df_ready["close"].iloc[-1])
    ensemble = preds.get("weighted", preds.get("ensemble_mean", last_close))

    out = {
        "instrument": inst_id,
        "bar": bar,
        # Report the number of candles actually returned, which can be
        # smaller than the requested limit.
        "fetched_candles": int(len(df)),
        "last_ts": df_ready["ts"].iloc[-1].isoformat(),
        "last_close": float(last_close),
        "ensemble_prediction": float(ensemble),
        "per_model": preds.get("per_model", {}),
    }

    # Prepare images
    img_price = plot_price_and_preds(df_ready, {"weighted": ensemble})
    img_contrib = plot_model_contributions(out["per_model"])

    return {
        "result": out,
        "img_price": img_price,
        "img_contrib": img_contrib,
    }
|
| 537 |
+
|
| 538 |
+
# Helper to convert bytes to gradio displayable
|
| 539 |
+
def bytes_to_pil(b: bytes):
    """Open encoded image bytes (e.g. a PNG) as a PIL image."""
    # Imported lazily, so Pillow is only required once an image is converted.
    from PIL import Image
    return Image.open(io.BytesIO(b))
|
| 543 |
+
|
| 544 |
+
# -------------------------
|
| 545 |
+
# Gradio layout (blank template)
|
| 546 |
+
# -------------------------
|
| 547 |
+
def build_gradio_app():
    """
    Build and return the Gradio Blocks UI.

    The left column holds the inputs (instrument, bar, limit), the two
    buttons and the text output; the right column holds the price chart and
    the per-model chart.
    """
    title = "BTC/USDT Price Prediction (OKX REST) — Ensemble Demo"
    description = "Fetch recent candles from OKX and predict next close using an ensemble (demo)."
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"## {title}")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1):
                inst_in = gr.Textbox(label="Instrument", value=DEFAULT_INSTRUMENT)
                bar_in = gr.Dropdown(label="Candle bar", choices=["1m","3m","5m","15m","1H","4H","1D"], value=DEFAULT_BAR)
                limit_in = gr.Slider(label="Limit (number of candles)", minimum=50, maximum=1000, step=50, value=DEFAULT_LIMIT)
                run_btn = gr.Button("Run Inference")
                refresh_btn = gr.Button("Refresh Models (clear cache)")
                info_out = gr.Textbox(label="Info / JSON result", interactive=False)
            with gr.Column(scale=2):
                price_img = gr.Image(label="Price & Prediction", type="pil")
                contrib_img = gr.Image(label="Per-model predictions", type="pil")

        # Callbacks
        def on_run(inst, bar, limit):
            """Run the pipeline; return (price image, contribution image, info text)."""
            res = inference_pipeline(inst, bar, limit)
            if "error" in res:
                # Exactly one value per output component, in the order of
                # outputs=[price_img, contrib_img, info_out]: clear both
                # images and show the error as JSON text.
                return None, None, json.dumps({"error": res["error"]}, indent=2)
            out = res["result"]
            price_pil = bytes_to_pil(res["img_price"])
            contrib_pil = bytes_to_pil(res["img_contrib"])
            info_json = json.dumps(out, indent=2, default=str)
            return price_pil, contrib_pil, info_json

        def on_refresh():
            """Empty the model cache so the next run reloads models from disk."""
            with _MODEL_LOCK:
                _MODELS.clear()
            return "Model cache cleared."

        run_btn.click(on_run, inputs=[inst_in, bar_in, limit_in], outputs=[price_img, contrib_img, info_out])
        refresh_btn.click(on_refresh, inputs=None, outputs=info_out)

        gr.Markdown("Notes: This demo uses public OKX market endpoints. For production, validate rate limits and handle API keys for private data. Ensemble models here are demo-friendly; train and persist stronger models for real use.")
    return demo
|
| 588 |
+
|
| 589 |
+
# -------------------------
|
| 590 |
+
# If run as app
|
| 591 |
+
# -------------------------
|
| 592 |
+
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    app = build_gradio_app()
    app.launch(server_name="0.0.0.0", server_port=listen_port)
|
| 595 |
+
# ```  (stray Markdown code-fence terminator; commented out because it is a syntax error in a .py file)
|
| 596 |
+
|
| 597 |
+
# Kullanım/Notlar:
# - Bu tek dosya app.py olarak koyulup Spaces'e deploy edilebilir. Ancak dependencies (requirements.txt) doğru kurulmalı: gradio, requests, pandas, numpy, scikit-learn, matplotlib, ta (opsiyonel), lightgbm (opsiyonel), xgboost (opsiyonel), torch (opsiyonel), pillow.
# - Eğer LightGBM/XGBoost/PyTorch kurulmazsa kod bunların yokluğuna dayanacak (demo model ile çalışır).
# - Gerçek model eğitimi için ek `train.py` ve model kayıt adımları eklenmeli; istersen onu da üretirim.
#
# İstersen aynı dosyayı eğitim ve model kaydetme yeteneği eklenmiş hâliyle (train/save) de veririm. Hemen başka bir şey ekleyeyim mi?
|