GoshawkVortexAI commited on
Commit
354aba9
·
verified ·
1 Parent(s): f9e6c03

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +602 -0
app.py ADDED
@@ -0,0 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python
2
+ # app.py
3
+ """
4
+ Gradio (blank) tabanlı Hugging Face Space uygulaması.
5
+ - OKX REST API'den BTC/USDT (spot) candle verisi çeker
6
+ - Teknik göstergeler üretir
7
+ - Ensemble: LightGBM, XGBoost, RandomForest (sklearn) + küçük PyTorch LSTM
8
+ - Eğer pretrained model dosyaları yoksa küçük demo modeller oluşturur
9
+ - Outputs: tahmin (regresyon: next-close), model katkıları, grafikler
10
+
11
+ Not:
12
+ - requirements.txt'de aşağıdakiler olmalı:
13
+ gradio, pandas, numpy, requests, ta, scikit-learn, lightgbm, xgboost, torch, matplotlib
14
+ - Kullanıcı OKX API anahtarı gerekli değildir (public candles endpoint kullanılıyor).
15
+ - Bu dosya tek başına çalışır; ancak ağır paketler (lightgbm, xgboost, torch) Spaces ortamında kurulmadıysa hata verebilir.
16
+ """
17
+
18
+ import os
19
+ import io
20
+ import time
21
+ import math
22
+ import json
23
+ import threading
24
+ from typing import Tuple, Dict, Any, List
25
+
26
+ import numpy as np
27
+ import pandas as pd
28
+ import requests
29
+ from datetime import datetime, timedelta, timezone
30
+
31
+ # Visualization
32
+ import matplotlib
33
+ matplotlib.use("Agg")
34
+ import matplotlib.pyplot as plt
35
+
36
+ # Technical indicators
37
+ try:
38
+ import ta
39
+ except Exception:
40
+ # Minimal fallback implementations if ta isn't installed
41
+ ta = None
42
+
43
+ # ML libs
44
+ from sklearn.ensemble import RandomForestRegressor
45
+ from sklearn.preprocessing import StandardScaler
46
+ from sklearn.pipeline import Pipeline
47
+ from sklearn.base import BaseEstimator, RegressorMixin
48
+
49
+ # Try import optional libs
50
+ HAS_LGB = True
51
+ HAS_XGB = True
52
+ HAS_TORCH = True
53
+ try:
54
+ import lightgbm as lgb
55
+ except Exception:
56
+ HAS_LGB = False
57
+ try:
58
+ import xgboost as xgb
59
+ except Exception:
60
+ HAS_XGB = False
61
+ try:
62
+ import torch
63
+ import torch.nn as nn
64
+ import torch.nn.functional as F
65
+ from torch.utils.data import DataLoader, TensorDataset
66
+ except Exception:
67
+ HAS_TORCH = False
68
+
69
+ # Gradio
70
+ import gradio as gr
71
+
72
+ # -------------------------
73
+ # Configuration/Constants
74
+ # -------------------------
75
+ OKX_BASE = "https://www.okx.com"
76
+ # Public candles: GET /api/v5/market/history-candles?instId=BTC-USDT-SWAP&bar=1m&limit=100
77
+ # We'll use spot: BTC-USDT
78
+ DEFAULT_INSTRUMENT = "BTC-USDT"
79
+ DEFAULT_BAR = "1m" # options: 1m, 3m, 5m, 15m, 1H etc.
80
+ DEFAULT_LIMIT = 500 # up to 1000 depending on endpoint
81
+
82
+ # Model filenames (in repo or persisted by training)
83
+ MODEL_DIR = "models"
84
+ os.makedirs(MODEL_DIR, exist_ok=True)
85
+ LGB_MODEL_FILE = os.path.join(MODEL_DIR, "lgb_model.txt")
86
+ XGB_MODEL_FILE = os.path.join(MODEL_DIR, "xgb_model.json")
87
+ RF_MODEL_FILE = os.path.join(MODEL_DIR, "rf_model.pkl")
88
+ LSTM_MODEL_FILE = os.path.join(MODEL_DIR, "lstm_model.pt")
89
+ SCALER_FILE = os.path.join(MODEL_DIR, "scaler.npy") # save scaler mean/scale
90
+
91
+ # Thread-safe model cache
92
+ _MODEL_LOCK = threading.Lock()
93
+ _MODELS = {}
94
+
95
+ # -------------------------
96
+ # Utilities
97
+ # -------------------------
98
+ def now_iso():
99
+ return datetime.now(timezone.utc).isoformat()
100
+
101
+ def okx_candles(inst_id: str = DEFAULT_INSTRUMENT, bar: str = DEFAULT_BAR, limit: int = DEFAULT_LIMIT) -> pd.DataFrame:
102
+ """
103
+ Fetch recent candle data from OKX public REST API.
104
+ Returns DataFrame with columns: ts, open, high, low, close, volume
105
+ ts in UTC datetime
106
+ """
107
+ url = f"{OKX_BASE}/api/v5/market/history-candles"
108
+ params = {"instId": inst_id, "bar": bar, "limit": str(limit)}
109
+ resp = requests.get(url, params=params, timeout=15)
110
+ resp.raise_for_status()
111
+ data = resp.json()
112
+
113
+ if not data or data.get("code") not in (None, "0", 0):
114
+ # OKX returns "code": "0" on success sometimes; be permissive
115
+ # If structure unexpected, raise
116
+ # Try to parse anyway
117
+ pass
118
+
119
+ cand = data.get("data", [])
120
+ if not cand:
121
+ # Possibly different field
122
+ raise RuntimeError("No candle data returned from OKX")
123
+
124
+ # OKX returns list of lists: [ts, open, high, low, close, volume, ...]
125
+ # timestamp in millis
126
+ rows = []
127
+ for c in cand:
128
+ # According to OKX docs: [ts, open, high, low, close, volume]
129
+ ts = int(c[0]) // 1000 if len(str(c[0])) > 10 else int(c[0])
130
+ dt = datetime.fromtimestamp(ts, tz=timezone.utc)
131
+ rows.append({
132
+ "ts": dt,
133
+ "open": float(c[1]),
134
+ "high": float(c[2]),
135
+ "low": float(c[3]),
136
+ "close": float(c[4]),
137
+ "volume": float(c[5])
138
+ })
139
+ df = pd.DataFrame(rows)
140
+ df = df.sort_values("ts").reset_index(drop=True)
141
+ return df
142
+
143
+ # Minimal TA indicators if `ta` package is not available
144
+ def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
145
+ df = df.copy()
146
+ if ta is not None:
147
+ # Use ta to add common indicators
148
+ df["rsi"] = ta.momentum.RSIIndicator(df["close"], window=14, fillna=True).rsi()
149
+ df["ema12"] = ta.trend.EMAIndicator(df["close"], window=12, fillna=True).ema_indicator()
150
+ df["ema26"] = ta.trend.EMAIndicator(df["close"], window=26, fillna=True).ema_indicator()
151
+ macd = ta.trend.MACD(df["close"], window_slow=26, window_fast=12, window_sign=9, fillna=True)
152
+ df["macd"] = macd.macd()
153
+ df["macd_signal"] = macd.macd_signal()
154
+ df["bb_high"] = ta.volatility.BollingerBands(df["close"], window=20, fillna=True).bollinger_hband()
155
+ df["bb_low"] = ta.volatility.BollingerBands(df["close"], window=20, fillna=True).bollinger_lband()
156
+ df["atr"] = ta.volatility.AverageTrueRange(df["high"], df["low"], df["close"], window=14, fillna=True).average_true_range()
157
+ else:
158
+ # Fallback simple computations
159
+ df["rsi"] = simple_rsi(df["close"], window=14)
160
+ df["ema12"] = df["close"].ewm(span=12, adjust=False).mean()
161
+ df["ema26"] = df["close"].ewm(span=26, adjust=False).mean()
162
+ df["macd"] = df["ema12"] - df["ema26"]
163
+ df["macd_signal"] = df["macd"].ewm(span=9, adjust=False).mean()
164
+ df["bb_mid"] = df["close"].rolling(20).mean()
165
+ df["bb_std"] = df["close"].rolling(20).std()
166
+ df["bb_high"] = df["bb_mid"] + 2 * df["bb_std"]
167
+ df["bb_low"] = df["bb_mid"] - 2 * df["bb_std"]
168
+ df["atr"] = simple_atr(df, window=14)
169
+ # Fill na
170
+ df = df.fillna(method="bfill").fillna(method="ffill").fillna(0.0)
171
+ return df
172
+
173
+ def simple_rsi(series: pd.Series, window: int = 14) -> pd.Series:
174
+ delta = series.diff()
175
+ up = delta.clip(lower=0)
176
+ down = -1 * delta.clip(upper=0)
177
+ ma_up = up.ewm(alpha=1/window, adjust=False).mean()
178
+ ma_down = down.ewm(alpha=1/window, adjust=False).mean()
179
+ rs = ma_up / (ma_down + 1e-8)
180
+ rsi = 100 - (100 / (1 + rs))
181
+ return rsi.fillna(50.0)
182
+
183
+ def simple_atr(df: pd.DataFrame, window: int = 14) -> pd.Series:
184
+ high_low = df["high"] - df["low"]
185
+ high_close = (df["high"] - df["close"].shift()).abs()
186
+ low_close = (df["low"] - df["close"].shift()).abs()
187
+ tr = pd.concat([high_low, high_close, low_close], axis=1).max(axis=1)
188
+ atr = tr.ewm(span=window, adjust=False).mean()
189
+ return atr.fillna(0.0)
190
+
191
+ def create_features(df: pd.DataFrame) -> pd.DataFrame:
192
+ df = df.copy()
193
+ df = add_technical_indicators(df)
194
+ # Returns features aligned to each row predicting next row's close
195
+ # Feature engineering: returns, log returns, vol, moving averages, ratios
196
+ df["return_1"] = df["close"].pct_change().fillna(0.0)
197
+ df["log_return_1"] = np.log1p(df["return_1"])
198
+ df["vol_5"] = df["close"].rolling(5).std().fillna(0.0)
199
+ df["vol_20"] = df["close"].rolling(20).std().fillna(0.0)
200
+ df["ma_5"] = df["close"].rolling(5).mean().fillna(method="bfill")
201
+ df["ma_20"] = df["close"].rolling(20).mean().fillna(method="bfill")
202
+ df["ma_50"] = df["close"].rolling(50).mean().fillna(method="bfill")
203
+ # ratio features
204
+ df["ma5_div_ma20"] = df["ma_5"] / (df["ma_20"] + 1e-9)
205
+ df["ema_diff"] = df["ema12"] - df["ema26"]
206
+ # time features
207
+ df["ts_unix"] = df["ts"].astype(np.int64) // 10**9
208
+ df["hour"] = df["ts"].dt.hour
209
+ df["minute"] = df["ts"].dt.minute
210
+ # fill remaining na
211
+ df = df.fillna(method="bfill").fillna(0.0)
212
+ return df
213
+
214
+ # -------------------------
215
+ # Model wrappers and helpers
216
+ # -------------------------
217
+ class DummyRegressor(BaseEstimator, RegressorMixin):
218
+ """Simple mean predictor used as fallback."""
219
+ def fit(self, X, y):
220
+ self._mean = np.mean(y) if len(y) else 0.0
221
+ return self
222
+ def predict(self, X):
223
+ return np.full((X.shape[0],), getattr(self, "_mean", 0.0))
224
+
225
+ def save_numpy(obj: np.ndarray, path: str):
226
+ np.save(path, obj)
227
+
228
+ def load_numpy(path: str) -> np.ndarray:
229
+ return np.load(path)
230
+
231
+ def get_feature_columns() -> List[str]:
232
+ cols = [
233
+ "open","high","low","close","volume",
234
+ "rsi","ema12","ema26","macd","macd_signal","bb_high","bb_low","atr",
235
+ "return_1","log_return_1","vol_5","vol_20","ma_5","ma_20","ma_50",
236
+ "ma5_div_ma20","ema_diff","ts_unix","hour","minute"
237
+ ]
238
+ return cols
239
+
240
+ # Model persistence helpers (light, simple)
241
+ def load_models() -> Dict[str, Any]:
242
+ """
243
+ Try to load pretrained models from MODEL_DIR. If missing, create small demo models.
244
+ Returns dict of models and scaler.
245
+ """
246
+ with _MODEL_LOCK:
247
+ if _MODELS:
248
+ return _MODELS
249
+
250
+ models = {}
251
+ scaler = None
252
+
253
+ # Try load scaler if exists
254
+ if os.path.exists(SCALER_FILE):
255
+ try:
256
+ sc = np.load(SCALER_FILE, allow_pickle=True).item()
257
+ scaler = StandardScaler()
258
+ scaler.mean_ = sc["mean"]
259
+ scaler.scale_ = sc["scale"]
260
+ scaler.n_features_in_ = sc["n_in"]
261
+ except Exception:
262
+ scaler = None
263
+
264
+ # RandomForest (sklearn)
265
+ try:
266
+ import joblib
267
+ if os.path.exists(RF_MODEL_FILE):
268
+ models["rf"] = joblib.load(RF_MODEL_FILE)
269
+ else:
270
+ raise FileNotFoundError
271
+ except Exception:
272
+ # create small RF demo
273
+ models["rf"] = RandomForestRegressor(n_estimators=10, random_state=42)
274
+
275
+ # LightGBM
276
+ if HAS_LGB and os.path.exists(LGB_MODEL_FILE):
277
+ try:
278
+ models["lgb"] = lgb.Booster(model_file=LGB_MODEL_FILE)
279
+ except Exception:
280
+ models["lgb"] = None
281
+ else:
282
+ models["lgb"] = None if not HAS_LGB else None
283
+
284
+ # XGBoost
285
+ if HAS_XGB and os.path.exists(XGB_MODEL_FILE):
286
+ try:
287
+ models["xgb"] = xgb.Booster()
288
+ models["xgb"].load_model(XGB_MODEL_FILE)
289
+ except Exception:
290
+ models["xgb"] = None
291
+ else:
292
+ models["xgb"] = None
293
+
294
+ # LSTM / PyTorch
295
+ if HAS_TORCH and os.path.exists(LSTM_MODEL_FILE):
296
+ try:
297
+ lstm = torch.load(LSTM_MODEL_FILE, map_location=torch.device("cpu"))
298
+ models["lstm"] = lstm
299
+ except Exception:
300
+ models["lstm"] = None
301
+ else:
302
+ models["lstm"] = None
303
+
304
+ # If scaler missing, create a dummy one later in pipeline when training; for inference create StandardScaler default
305
+ if scaler is None:
306
+ scaler = StandardScaler()
307
+
308
+ # Create an ensemble wrapper
309
+ models["scaler"] = scaler
310
+
311
+ _MODELS.update(models)
312
+ return _MODELS
313
+
314
+ def save_scaler(scaler: StandardScaler, path: str = SCALER_FILE):
315
+ obj = {"mean": scaler.mean_, "scale": scaler.scale_, "n_in": scaler.n_features_in_}
316
+ np.save(path, obj)
317
+
318
+ # -------------------------
319
+ # Inference logic
320
+ # -------------------------
321
+ def prepare_inference_features(df: pd.DataFrame) -> Tuple[np.ndarray, List[str], pd.DataFrame]:
322
+ """
323
+ Takes raw candles df, returns (X, feature_cols, df_ready)
324
+ X is 2D array for model input, aligned so that each row predicts next close.
325
+ """
326
+ df2 = create_features(df)
327
+ feat_cols = get_feature_columns()
328
+ # Ensure columns present
329
+ for c in feat_cols:
330
+ if c not in df2.columns:
331
+ df2[c] = 0.0
332
+ X = df2[feat_cols].values
333
+ return X, feat_cols, df2
334
+
335
+ def predict_ensemble(X: np.ndarray, models: Dict[str, Any]) -> Dict[str, Any]:
336
+ """
337
+ Predict next-step close using ensemble of models.
338
+ Return dict:
339
+ - per_model_preds: {name: scalar_pred}
340
+ - ensemble_mean: float
341
+ - weighted: float (weights fallback equal)
342
+ """
343
+ scaler = models.get("scaler", None)
344
+ if scaler is None:
345
+ scaler = StandardScaler()
346
+ # Use last row features to predict next
347
+ if X.ndim == 1:
348
+ X_row = X.reshape(1, -1)
349
+ else:
350
+ X_row = X[-1:, :]
351
+ # scale
352
+ try:
353
+ Xs = scaler.transform(X_row)
354
+ except Exception:
355
+ # If scaler not fitted, fit on X (fallback)
356
+ try:
357
+ scaler.fit(X)
358
+ save_scaler(scaler)
359
+ Xs = scaler.transform(X_row)
360
+ except Exception:
361
+ Xs = X_row
362
+
363
+ preds = {}
364
+ # RandomForest
365
+ rf = models.get("rf", None)
366
+ if rf is not None:
367
+ try:
368
+ p = rf.predict(Xs)[0]
369
+ except Exception:
370
+ p = float(np.nan)
371
+ else:
372
+ p = float(np.nan)
373
+ preds["rf"] = float(p)
374
+
375
+ # LightGBM
376
+ if HAS_LGB and models.get("lgb", None) is not None:
377
+ try:
378
+ dmat = lgb.Dataset(Xs, free_raw_data=False)
379
+ p = models["lgb"].predict(Xs)[0]
380
+ except Exception:
381
+ p = float(np.nan)
382
+ else:
383
+ p = float(np.nan)
384
+ preds["lgb"] = float(p)
385
+
386
+ # XGBoost
387
+ if HAS_XGB and models.get("xgb", None) is not None:
388
+ try:
389
+ dm = xgb.DMatrix(Xs)
390
+ p = models["xgb"].predict(dm)[0]
391
+ except Exception:
392
+ p = float(np.nan)
393
+ else:
394
+ p = float(np.nan)
395
+ preds["xgb"] = float(p)
396
+
397
+ # LSTM (PyTorch)
398
+ if HAS_TORCH and models.get("lstm", None) is not None:
399
+ try:
400
+ model = models["lstm"]
401
+ model.eval()
402
+ with torch.no_grad():
403
+ t = torch.tensor(X_row, dtype=torch.float32).unsqueeze(0) # shape (1,1,features) if expected
404
+ # try both (1,features) or (1,seq,features)
405
+ if t.dim() == 3:
406
+ out = model(t)
407
+ else:
408
+ # reshape to (1,1,features)
409
+ t2 = t.unsqueeze(1)
410
+ out = model(t2)
411
+ p = float(out.squeeze().cpu().numpy())
412
+ except Exception:
413
+ p = float(np.nan)
414
+ else:
415
+ p = float(np.nan)
416
+ preds["lstm"] = float(p)
417
+
418
+ # If models missing, fallback: use RF or mean of last price as naive
419
+ valid_preds = [v for v in preds.values() if not (math.isnan(v) or v is None)]
420
+ if not valid_preds:
421
+ # fallback naive next-close = last close
422
+ naive = float(X_row[0, get_feature_columns().index("close")])
423
+ ensemble_mean = naive
424
+ weighted = naive
425
+ else:
426
+ ensemble_mean = float(np.nanmean(valid_preds))
427
+ # Simple weighting: prefer models that exist; equal weight
428
+ weighted = ensemble_mean
429
+
430
+ return {
431
+ "per_model": preds,
432
+ "ensemble_mean": ensemble_mean,
433
+ "weighted": weighted
434
+ }
435
+
436
+ # -------------------------
437
+ # LSTM simple architecture (for demo)
438
+ # -------------------------
439
+ if HAS_TORCH:
440
+ class SimpleLSTM(nn.Module):
441
+ def __init__(self, input_size: int, hidden_size: int = 32, num_layers: int = 1):
442
+ super().__init__()
443
+ self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
444
+ self.fc = nn.Linear(hidden_size, 1)
445
+ def forward(self, x):
446
+ # x: (batch, seq_len, input_size)
447
+ out, _ = self.lstm(x)
448
+ # take last time step
449
+ last = out[:, -1, :]
450
+ return self.fc(last)
451
+
452
+ # -------------------------
453
+ # Visualization helpers
454
+ # -------------------------
455
+ def plot_price_and_preds(df: pd.DataFrame, preds: Dict[str, Any]) -> bytes:
456
+ fig, ax = plt.subplots(figsize=(9,4))
457
+ ax.plot(df["ts"], df["close"], label="close", color="black", lw=1)
458
+ # mark last price and ensemble prediction
459
+ last_ts = df["ts"].iloc[-1]
460
+ last_close = df["close"].iloc[-1]
461
+ pred = preds.get("weighted", preds.get("ensemble_mean", last_close))
462
+ ax.scatter([last_ts + pd.Timedelta(seconds=1)], [pred], color="red", label="ensemble_pred")
463
+ ax.axhline(last_close, linestyle="--", color="gray", alpha=0.6)
464
+ ax.set_title("BTC/USDT close and ensemble prediction")
465
+ ax.set_xlabel("Time (UTC)")
466
+ ax.set_ylabel("Price")
467
+ ax.legend()
468
+ fig.tight_layout()
469
+ buf = io.BytesIO()
470
+ fig.savefig(buf, format="png")
471
+ plt.close(fig)
472
+ buf.seek(0)
473
+ return buf.read()
474
+
475
+ def plot_model_contributions(per_model: Dict[str, float]) -> bytes:
476
+ names = list(per_model.keys())
477
+ vals = [per_model[n] if (not math.isnan(per_model[n])) else 0.0 for n in names]
478
+ fig, ax = plt.subplots(figsize=(6,3))
479
+ ax.bar(names, vals, color=["#1f77b4","#ff7f0e","#2ca02c","#d62728"])
480
+ ax.set_title("Per-model predictions (abs values)")
481
+ ax.set_ylabel("Predicted price")
482
+ fig.tight_layout()
483
+ buf = io.BytesIO()
484
+ fig.savefig(buf, format="png")
485
+ plt.close(fig)
486
+ buf.seek(0)
487
+ return buf.read()
488
+
489
+ # -------------------------
490
+ # Gradio app components
491
+ # -------------------------
492
+ def inference_pipeline(inst_id: str = DEFAULT_INSTRUMENT,
493
+ bar: str = DEFAULT_BAR,
494
+ limit: int = DEFAULT_LIMIT,
495
+ show_plot: bool = True):
496
+ """
497
+ High-level function called by Gradio. Returns JSON/dicts + image bytes for display.
498
+ """
499
+ # Step 1: fetch candles
500
+ try:
501
+ df = okx_candles(inst_id=inst_id, bar=bar, limit=int(limit))
502
+ except Exception as e:
503
+ return {"error": f"Failed to fetch candles: {e}"}
504
+
505
+ # Step 2: prepare features
506
+ X, feat_cols, df_ready = prepare_inference_features(df)
507
+
508
+ # Step 3: load models
509
+ models = load_models()
510
+
511
+ # Step 4: predict
512
+ preds = predict_ensemble(X, models)
513
+
514
+ # Step 5: build result
515
+ last_close = float(df_ready["close"].iloc[-1])
516
+ ensemble = preds.get("weighted", preds.get("ensemble_mean", last_close))
517
+
518
+ out = {
519
+ "instrument": inst_id,
520
+ "bar": bar,
521
+ "fetched_candles": int(limit),
522
+ "last_ts": df_ready["ts"].iloc[-1].isoformat(),
523
+ "last_close": float(last_close),
524
+ "ensemble_prediction": float(ensemble),
525
+ "per_model": preds.get("per_model", {})
526
+ }
527
+
528
+ # Prepare images
529
+ img_price = plot_price_and_preds(df_ready, {"weighted": ensemble})
530
+ img_contrib = plot_model_contributions(out["per_model"])
531
+
532
+ return {
533
+ "result": out,
534
+ "img_price": img_price,
535
+ "img_contrib": img_contrib
536
+ }
537
+
538
+ # Helper to convert bytes to gradio displayable
539
+ def bytes_to_pil(b: bytes):
540
+ from PIL import Image
541
+ buf = io.BytesIO(b)
542
+ return Image.open(buf)
543
+
544
+ # -------------------------
545
+ # Gradio layout (blank template)
546
+ # -------------------------
547
+ def build_gradio_app():
548
+ title = "BTC/USDT Price Prediction (OKX REST) — Ensemble Demo"
549
+ description = "Fetch recent candles from OKX and predict next close using an ensemble (demo)."
550
+ with gr.Blocks(title=title) as demo:
551
+ gr.Markdown(f"## {title}")
552
+ gr.Markdown(description)
553
+
554
+ with gr.Row():
555
+ with gr.Column(scale=1):
556
+ inst_in = gr.Textbox(label="Instrument", value=DEFAULT_INSTRUMENT)
557
+ bar_in = gr.Dropdown(label="Candle bar", choices=["1m","3m","5m","15m","1H","4H","1D"], value=DEFAULT_BAR)
558
+ limit_in = gr.Slider(label="Limit (number of candles)", minimum=50, maximum=1000, step=50, value=DEFAULT_LIMIT)
559
+ run_btn = gr.Button("Run Inference")
560
+ refresh_btn = gr.Button("Refresh Models (clear cache)")
561
+ info_out = gr.Textbox(label="Info / JSON result", interactive=False)
562
+ with gr.Column(scale=2):
563
+ price_img = gr.Image(label="Price & Prediction", type="pil")
564
+ contrib_img = gr.Image(label="Per-model predictions", type="pil")
565
+
566
+ # Callbacks
567
+ def on_run(inst, bar, limit):
568
+ res = inference_pipeline(inst, bar, limit)
569
+ if "error" in res:
570
+ return "", gr.update(value=None), gr.update(value=None), json.dumps({"error": res["error"]}, indent=2)
571
+ out = res["result"]
572
+ price_pil = bytes_to_pil(res["img_price"])
573
+ contrib_pil = bytes_to_pil(res["img_contrib"])
574
+ info_json = json.dumps(out, indent=2, default=str)
575
+ return price_pil, contrib_pil, info_json
576
+
577
+ def on_refresh():
578
+ # clear model cache and reload
579
+ with _MODEL_LOCK:
580
+ _MODELS.clear()
581
+ return "Model cache cleared."
582
+
583
+ run_btn.click(on_run, inputs=[inst_in, bar_in, limit_in], outputs=[price_img, contrib_img, info_out])
584
+ refresh_btn.click(on_refresh, inputs=None, outputs=info_out)
585
+
586
+ gr.Markdown("Notes: This demo uses public OKX market endpoints. For production, validate rate limits and handle API keys for private data. Ensemble models here are demo-friendly; train and persist stronger models for real use.")
587
+ return demo
588
+
589
+ # -------------------------
590
+ # If run as app
591
+ # -------------------------
592
+ if __name__ == "__main__":
593
+ app = build_gradio_app()
594
+ app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
595
+ ```
596
+
597
+ Kullanım/Notlar:
598
+ - Bu tek dosya app.py olarak koyulup Spaces'e deploy edilebilir. Ancak dependencies (requirements.txt) doğru kurulmalı: gradio, requests, pandas, numpy, scikit-learn, matplotlib, ta (opsiyonel), lightgbm (opsiyonel), xgboost (opsiyonel), torch (opsiyonel), pillow.
599
+ - Eğer LightGBM/XGBoost/PyTorch kurulmazsa kod bunların yokluğuna dayanacak (demo model ile çalışır).
600
+ - Gerçek model eğitimi için ek `train.py` ve model kayıt adımları eklenmeli; istersen onu da üretirim.
601
+
602
+ İstersen aynı dosyayı eğitim ve model kaydetme yeteneği eklenmiş hâliyle (train/save) de veririm. Hemen başka bir şey ekleyeyim mi?