# NOTE: "Spaces: Running Running" — Hugging Face Spaces status-banner residue
# from a page scrape; not part of the application source.
| """ | |
| app.py — Gradio Interface for the Tabular AutoML Framework | |
| Run: python app.py | |
| """ | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from matplotlib.gridspec import GridSpec | |
| import io, os, sys, time, threading, json, textwrap | |
| from pathlib import Path | |
| # ── make sure the automl package is importable ──────────────────────────────── | |
| sys.path.insert(0, str(Path(__file__).parent)) | |
| from automl import AutoML | |
# ── Gradio version compatibility check ───────────────────────────────────────
# Reuse the module imported at the top of the file instead of importing gradio
# a second time under a throwaway alias.
_gr_check = gr
# Major.minor only; patch segments may be non-numeric on dev builds.
_gr_version = tuple(int(x) for x in _gr_check.__version__.split(".")[:2])
_is_gradio_4_plus = _gr_version[0] >= 4
_is_gradio_6_plus = _gr_version[0] >= 6
print(f" Gradio version: {_gr_check.__version__}")
# gr.File accepted type="file" in 3.x only; 4.x+ rejects that kwarg.
_FILE_KWARGS = {} if _is_gradio_4_plus else {"type": "file"}
# ── Global state ──────────────────────────────────────────────────────────────
# Gradio callbacks are plain module-level functions, so shared state lives in
# this dict rather than in any component.
_state: dict = {
    "automl": None,    # fitted AutoML instance, set by run_training()
    "df": None,        # most recently uploaded DataFrame, set by handle_upload()
    "log_lines": [],   # timestamped stdout lines captured during training
    "running": False,  # True while run_training() is executing
}
| # ── File path compatibility helper (Gradio 3.x / 4.x / 5.x / 6.x) ────────── | |
| def _get_filepath(file): | |
| """Handle gr.File output across ALL Gradio versions including 6.x.""" | |
| if file is None: | |
| return None | |
| # Gradio 6.x: returns plain string filepath directly | |
| if isinstance(file, str): | |
| return file | |
| # Gradio 6.x: sometimes returns a list (multiple files) | |
| if isinstance(file, list): | |
| return file[0] if file else None | |
| # Gradio 3.x early: returns dict | |
| if isinstance(file, dict): | |
| return file.get("name") or file.get("path") or file.get("tmp_path") | |
| # Gradio 4.x / 5.x: UploadData with .path attribute | |
| if hasattr(file, "path"): | |
| return file.path | |
| # Gradio 3.x late: object with .name | |
| if hasattr(file, "name"): | |
| return file.name | |
| # Last resort | |
| return str(file) | |
# GitHub-dark inspired palette shared by every matplotlib figure and HTML card.
PALETTE = {
    "bg": "#0d1117",       # page / figure background
    "surface": "#161b22",  # panel and axes background
    "border": "#30363d",   # hairline borders and spines
    "accent": "#58a6ff",   # primary blue
    "green": "#3fb950",    # success / best-model highlight
    "yellow": "#d29922",   # warning
    "red": "#f85149",      # error / severe
    "text": "#e6edf3",     # primary text
    "muted": "#8b949e",    # secondary text, tick labels
}
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Helpers | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def _fig_to_pil(fig):
    """Render *fig* to an in-memory PNG and return it as a PIL Image.

    The matplotlib figure is closed afterwards so long sessions don't
    accumulate open figures.
    """
    from PIL import Image

    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", bbox_inches="tight",
                    facecolor=fig.get_facecolor(), dpi=130)
        buf.seek(0)
        # .copy() detaches the image from the buffer before it is closed.
        img = Image.open(buf).copy()
    plt.close(fig)
    return img
def _styled_fig(w=12, h=6):
    """Create a new figure with the dashboard's dark background."""
    return plt.figure(figsize=(w, h), facecolor=PALETTE["bg"])
def _ax_style(ax, title="", xlabel="", ylabel=""):
    """Apply the dark dashboard theme to *ax* and return it for chaining.

    Title and axis labels are only set when non-empty strings are given.
    """
    ax.set_facecolor(PALETTE["surface"])
    # A single call colors ticks on both axes; the original issued two extra
    # redundant per-axis tick_params calls with the same color.
    ax.tick_params(colors=PALETTE["muted"], labelsize=9)
    ax.spines[:].set_color(PALETTE["border"])
    if title:
        ax.set_title(title, color=PALETTE["text"], fontsize=11, pad=10, fontweight="bold")
    if xlabel:
        ax.set_xlabel(xlabel, color=PALETTE["muted"], fontsize=9)
    if ylabel:
        ax.set_ylabel(ylabel, color=PALETTE["muted"], fontsize=9)
    return ax
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Tab 1 — Upload & Preview | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def handle_upload(file):
    """Read an uploaded CSV, cache it in _state, and refresh the Upload tab.

    Always returns a 3-tuple of gr.update payloads for
    (target-column dropdown, dataset-summary HTML, column-plot image) —
    including on error, so the UI never receives a partial update.
    """
    if file is None:
        # Upload cleared: reset every output to its initial state.
        return (gr.update(choices=[], value=None),
                gr.update(value="<p style='color:#8b949e;'>Upload a CSV to see summary.</p>"),
                gr.update(value=None))
    try:
        print(f" [Upload] file type: {type(file)}, value: {repr(file)[:200]}")
        filepath = _get_filepath(file)
        print(f" [Upload] resolved filepath: {filepath}")
        if filepath is None:
            return (gr.update(choices=[], value=None),
                    gr.update(value="<p style='color:red'>Could not read file. Try uploading again.</p>"),
                    gr.update(value=None))
        df = pd.read_csv(filepath)
        _state["df"] = df  # cached for the training / analysis tabs
        cols = df.columns.tolist()
        # Build a rich HTML preview
        n_rows, n_cols = df.shape
        missing = df.isnull().sum().sum()
        dtypes = df.dtypes.value_counts().to_dict()
        dtype_str = ", ".join(f"{v}× {k}" for k, v in dtypes.items())
        summary_html = f"""
        <div style="font-family:'JetBrains Mono',monospace;
                    background:#161b22;border:1px solid #30363d;
                    border-radius:10px;padding:18px;color:#e6edf3;">
          <div style="display:flex;gap:30px;margin-bottom:14px;flex-wrap:wrap;">
            <div style="text-align:center;">
              <div style="font-size:28px;font-weight:700;color:#58a6ff;">{n_rows:,}</div>
              <div style="font-size:11px;color:#8b949e;">ROWS</div>
            </div>
            <div style="text-align:center;">
              <div style="font-size:28px;font-weight:700;color:#3fb950;">{n_cols}</div>
              <div style="font-size:11px;color:#8b949e;">COLUMNS</div>
            </div>
            <div style="text-align:center;">
              <div style="font-size:28px;font-weight:700;color:#d29922;">{missing:,}</div>
              <div style="font-size:11px;color:#8b949e;">MISSING CELLS</div>
            </div>
          </div>
          <div style="font-size:11px;color:#8b949e;border-top:1px solid #30363d;padding-top:10px;">
            Dtypes: {dtype_str}
          </div>
        </div>
        """
        # First 6 rows as styled HTML table
        preview_df = df.head(6)
        tbl = preview_df.to_html(index=False, border=0, classes="preview-tbl")
        table_html = f"""
        <style>
        .preview-tbl {{width:100%;border-collapse:collapse;font-family:'JetBrains Mono',monospace;font-size:12px;}}
        .preview-tbl th {{background:#21262d;color:#58a6ff;padding:8px 12px;text-align:left;border-bottom:1px solid #30363d;}}
        .preview-tbl td {{padding:6px 12px;color:#e6edf3;border-bottom:1px solid #21262d;}}
        .preview-tbl tr:hover td {{background:#1c2128;}}
        </style>
        <div style="overflow-x:auto;border-radius:8px;border:1px solid #30363d;">
        {tbl}
        </div>
        """
        combined = summary_html + "<br>" + table_html
        # Default the target dropdown to the last column (common CSV layout).
        return (gr.update(choices=cols, value=cols[-1]),
                gr.update(value=combined),
                gr.update(value=None))
    except Exception as e:
        import traceback
        err_detail = traceback.format_exc()
        print(f" [Upload Error] {err_detail}")
        return (gr.update(choices=[], value=None),
                gr.update(value=f"<div style='color:#f85149;font-family:monospace;padding:12px;background:#161b22;border:1px solid #f85149;border-radius:6px;'><b>Error loading file:</b><br>{str(e)}<br><br><small>Check HF Space logs for details.</small></div>"),
                gr.update(value=None))
def show_column_stats(col_name):
    """Plot one column of the cached dataframe: distribution plus a stats card.

    Returns a gr.update carrying a PIL image, or clearing the image when no
    dataframe is loaded / no column is selected.
    """
    df = _state.get("df")
    if df is None or not col_name:
        return gr.update(value=None)
    s = df[col_name]
    fig = _styled_fig(10, 3.5)
    gs = GridSpec(1, 2, figure=fig, wspace=0.4)
    # Left panel: histogram for numeric columns, top-12 bar chart otherwise.
    ax1 = fig.add_subplot(gs[0])
    _ax_style(ax1, title=f"Distribution — {col_name}")
    if pd.api.types.is_numeric_dtype(s):
        clean = s.dropna()
        ax1.hist(clean, bins=30, color=PALETTE["accent"], alpha=0.85, edgecolor="none")
        ax1.axvline(clean.mean(), color=PALETTE["yellow"], lw=1.5, linestyle="--",
                    label=f"mean={clean.mean():.2f}")
        ax1.legend(fontsize=8, labelcolor=PALETTE["muted"],
                   facecolor=PALETTE["surface"], edgecolor=PALETTE["border"])
    else:
        vc = s.value_counts().head(12)
        # (removed an unused `bars =` binding — the container was never used)
        ax1.barh(vc.index.astype(str), vc.values, color=PALETTE["accent"], alpha=0.85)
        ax1.invert_yaxis()
    # Right panel: text-only summary statistics.
    ax2 = fig.add_subplot(gs[1])
    ax2.axis("off")
    lines = [f"dtype: {s.dtype}",
             f"missing: {s.isnull().sum()} ({s.isnull().mean()*100:.1f}%)",
             f"unique: {s.nunique()}"]
    if pd.api.types.is_numeric_dtype(s):
        lines += [f"mean: {s.mean():.3f}",
                  f"std: {s.std():.3f}",
                  f"min: {s.min():.3f}",
                  f"25%: {s.quantile(.25):.3f}",
                  f"median: {s.median():.3f}",
                  f"75%: {s.quantile(.75):.3f}",
                  f"max: {s.max():.3f}"]
    txt = "\n".join(lines)
    ax2.text(0.05, 0.95, txt, transform=ax2.transAxes,
             va="top", ha="left", fontsize=9.5,
             fontfamily="monospace", color=PALETTE["text"],
             bbox=dict(boxstyle="round,pad=0.6", facecolor=PALETTE["surface"],
                       edgecolor=PALETTE["border"]))
    fig.patch.set_facecolor(PALETTE["bg"])
    return gr.update(value=_fig_to_pil(fig))
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Tab 2 — Training | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class LogCapture:
    """Tee stdout: forward every write to the real stream and mirror
    non-blank messages, timestamped, into the shared log buffer."""

    def __init__(self, original):
        self.original = original
        self.lines = _state["log_lines"]

    def write(self, msg):
        self.original.write(msg)
        # Skip pure-whitespace writes (e.g. the newline print() emits).
        if msg.strip():
            stamp = time.strftime("%H:%M:%S")
            self.lines.append(f"[{stamp}] {msg.rstrip()}")

    def flush(self):
        self.original.flush()
def run_training(target_col, task_type, time_budget, n_trials,
                 use_fe, val_size, test_size, seed):
    """Run a full AutoML fit on the cached dataframe.

    Returns (status message, captured log text, leaderboard DataFrame,
    leaderboard plot image). stdout is tee'd into _state["log_lines"] for
    the duration of the run so the UI can show a live log.
    """
    df = _state.get("df")
    if df is None:
        return ("❌ Please upload a dataset first.", "", None, None)
    if not target_col:
        return ("❌ Please select a target column.", "", None, None)
    _state["log_lines"].clear()
    _state["running"] = True
    original_stdout = sys.stdout
    sys.stdout = LogCapture(original_stdout)  # tee stdout into the log buffer
    status = "✅ Training complete!"
    try:
        # Blank or non-positive budget means "no time limit".
        budget = float(time_budget) if time_budget and float(time_budget) > 0 else None
        automl = AutoML(
            task_type=task_type,
            time_budget=budget,
            n_optuna_trials=int(n_trials),
            val_size=float(val_size),
            test_size=float(test_size),
            seed=int(seed),
            use_feature_engineering=use_fe,
            output_dir="./automl_output",
        )
        automl.fit(df, target_col=target_col)
        _state["automl"] = automl
    except Exception as e:
        status = f"❌ Error: {e}"
        import traceback; traceback.print_exc()
    finally:
        # Always restore stdout, even when fit() raised mid-run.
        sys.stdout = original_stdout
        _state["running"] = False
    log_text = "\n".join(_state["log_lines"])
    lb_df = _build_leaderboard_df()
    lb_plot = _build_leaderboard_plot()
    return (status, log_text, lb_df, lb_plot)
def _build_leaderboard_df():
    """Return the leaderboard as a display-ready DataFrame, or None before
    any training run. Internal columns are stripped and floats rounded."""
    am = _state.get("automl")
    if am is None:
        return None
    table = am.leaderboard.to_dataframe()
    table = table.drop(columns=["_type", "primary_score"], errors="ignore")
    # Round every float column to 4 decimals for display.
    float_cols = table.select_dtypes("float").columns
    table[float_cols] = table[float_cols].round(4)
    return table
def _build_leaderboard_plot():
    """Bar-chart every leaderboard metric side by side, one subplot per metric,
    with the best model's bar highlighted in green.

    Returns a PIL image, or None before any training run / with no metrics.
    """
    am = _state.get("automl")
    if am is None:
        return None
    lb = am.leaderboard.to_dataframe()
    if lb.empty:
        return None
    lb = lb.drop(columns=["_type", "primary_score"], errors="ignore")
    metric_cols = [c for c in lb.columns if c != "model_name"]
    if not metric_cols:
        return None
    n_metrics = len(metric_cols)
    fig, axes = plt.subplots(1, n_metrics,
                             figsize=(max(4, 3.5 * n_metrics), 4.5),
                             facecolor=PALETTE["bg"])
    if n_metrics == 1:
        axes = [axes]  # plt.subplots returns a bare Axes when ncols == 1
    colors = [PALETTE["accent"], PALETTE["green"], PALETTE["yellow"],
              "#a371f7", "#f78166", "#79c0ff", "#56d364"]
    for i, (ax, metric) in enumerate(zip(axes, metric_cols)):
        _ax_style(ax, title=metric.upper())
        vals = lb[metric].values
        names = lb["model_name"].values
        col = colors[i % len(colors)]
        bars = ax.barh(names, vals, color=col, alpha=0.85, height=0.55)
        ax.invert_yaxis()
        for bar, v in zip(bars, vals):
            ax.text(bar.get_width() + max(vals) * 0.01, bar.get_y() + bar.get_height() / 2,
                    f"{v:.3f}", va="center", fontsize=8, color=PALETTE["text"])
        ax.set_xlim(0, max(vals) * 1.18)
        # Lower-is-better metrics pick the minimum; everything else the maximum.
        best_idx = np.argmax(vals) if metric not in ("rmse", "mae") else np.argmin(vals)
        # BUG FIX: recolor the winning *bar*. The old code indexed
        # ax.get_children(), whose ordering mixes spines/texts/patches, so an
        # arbitrary artist (usually not a bar) got recolored instead.
        bars[best_idx].set_color(PALETTE["green"])
    fig.suptitle("Model Leaderboard — All Metrics", color=PALETTE["text"],
                 fontsize=13, fontweight="bold", y=1.02)
    fig.tight_layout()
    return _fig_to_pil(fig)
def poll_log():
    """Return the accumulated training log as one newline-joined string
    (wired to the manual "Refresh Log" button)."""
    captured = _state["log_lines"]
    return "\n".join(captured)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Tab 3 — Results & Metrics | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def build_results_tab():
    """Assemble the Results tab content from the fitted AutoML object.

    Returns gr.update payloads for (best-model HTML card, metrics chart
    image, loss-curves image). Before any training run, a placeholder
    message and empty images are returned.
    """
    am = _state.get("automl")
    if am is None:
        return (gr.update(value="No training run yet."),
                gr.update(value=None),
                gr.update(value=None))
    # Summary card
    best_name = am.best_model_name
    best_metrics = am.best_metrics
    task = am.task_type
    # One label/value row per metric in the card.
    metrics_rows = "".join(
        f"""<div style="display:flex;justify-content:space-between;
            border-bottom:1px solid #30363d;padding:6px 0;">
            <span style="color:#8b949e;">{k}</span>
            <span style="color:#58a6ff;font-weight:600;">{v:.4f}</span>
            </div>"""
        for k, v in best_metrics.items()
    )
    card = f"""
    <div style="font-family:'JetBrains Mono',monospace;
                background:#161b22;border:1px solid #3fb950;
                border-radius:10px;padding:20px;color:#e6edf3;max-width:480px;">
      <div style="font-size:11px;color:#3fb950;letter-spacing:2px;margin-bottom:6px;">
        🏆 BEST MODEL
      </div>
      <div style="font-size:22px;font-weight:700;margin-bottom:16px;">{best_name}</div>
      <div style="font-size:11px;color:#8b949e;margin-bottom:8px;">TASK: {task.upper()}</div>
      {metrics_rows}
    </div>
    """
    # Metric radar / bar chart
    fig = _build_metrics_radar(best_metrics, task)
    radar_img = _fig_to_pil(fig)
    # Learning curves for PyTorch models (None when there are none)
    curves_img = _build_loss_curves()
    return (gr.update(value=card),
            gr.update(value=radar_img),
            gr.update(value=curves_img))
def _build_metrics_radar(metrics, task):
    """Bar chart of the best model's metric values.

    *task* is accepted for interface compatibility with build_results_tab
    but is not currently used in the plot.
    """
    names = list(metrics.keys())
    vals = list(metrics.values())
    fig, ax = plt.subplots(figsize=(5.5, 4.5), facecolor=PALETTE["bg"])
    # Plain string — the original used an f-string with no placeholders.
    _ax_style(ax, title="Best Model Metrics")
    x = np.arange(len(names))
    bars = ax.bar(x, vals, color=[PALETTE["accent"], PALETTE["green"],
                                  PALETTE["yellow"], "#a371f7", "#f78166"][:len(names)],
                  alpha=0.85, width=0.55)
    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=20, ha="right", fontsize=9)
    ax.set_ylim(0, max(vals) * 1.2)
    # Value label above each bar.
    for bar, v in zip(bars, vals):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + max(vals) * 0.01,
                f"{v:.4f}", ha="center", va="bottom", fontsize=8.5,
                color=PALETTE["text"])
    fig.tight_layout()
    return fig
def _build_loss_curves():
    """Plot train/validation loss curves for each PyTorch leaderboard entry.

    Returns a PIL image, or None when no entry is tagged "_type" == "pytorch"
    with a trainer exposing a `history` mapping (keys used here:
    "train_loss", "val_loss").
    """
    am = _state.get("automl")
    if am is None:
        return None
    # Collect trainers from leaderboard
    trainers = [(e["model_name"], e["_model"])
                for e in am.leaderboard.entries
                if e.get("_type") == "pytorch" and hasattr(e["_model"], "history")]
    if not trainers:
        return None
    fig, axes = plt.subplots(1, len(trainers),
                             figsize=(5.5 * len(trainers), 4),
                             facecolor=PALETTE["bg"])
    if len(trainers) == 1:
        axes = [axes]  # plt.subplots returns a bare Axes when ncols == 1
    for ax, (name, trainer) in zip(axes, trainers):
        _ax_style(ax, title=f"{name} — Loss Curves",
                  xlabel="Epoch", ylabel="Loss")
        h = trainer.history
        epochs = range(1, len(h["train_loss"]) + 1)
        ax.plot(epochs, h["train_loss"], color=PALETTE["accent"],
                lw=1.8, label="Train")
        ax.plot(epochs, h["val_loss"], color=PALETTE["green"],
                lw=1.8, linestyle="--", label="Validation")
        ax.legend(fontsize=8, labelcolor=PALETTE["muted"],
                  facecolor=PALETTE["surface"], edgecolor=PALETTE["border"])
        # Soft fill under each curve for the dashboard look.
        ax.fill_between(epochs, h["train_loss"], alpha=0.08, color=PALETTE["accent"])
        ax.fill_between(epochs, h["val_loss"], alpha=0.08, color=PALETTE["green"])
    fig.tight_layout()
    return _fig_to_pil(fig)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Tab 4 — Feature Importance | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def build_importance_tab(top_k):
    """Render the top-k feature importances as a horizontal bar chart.

    Returns a gr.update with the chart image, or clearing the image when
    no model has been trained / no importances are available.
    """
    am = _state.get("automl")
    if am is None or not am.feature_importance:
        return gr.update(value=None)
    top = list(am.feature_importance.items())[:int(top_k)]
    names = [name for name, _ in top]
    vals = [score for _, score in top]
    fig = _styled_fig(10, max(4, len(names) * 0.45))
    ax = fig.add_subplot(111)
    _ax_style(ax, title=f"Top {len(names)} Feature Importances", xlabel="Importance Score")
    # Blue gradient, brightest at the top (barh draws bottom-up, hence the reversals).
    shades = plt.cm.Blues(np.linspace(0.4, 0.9, len(names)))[::-1]
    bars = ax.barh(names[::-1], vals[::-1], color=shades, alpha=0.9, height=0.65)
    peak = max(vals)
    for bar, score in zip(bars, vals[::-1]):
        ax.text(bar.get_width() + peak * 0.01, bar.get_y() + bar.get_height() / 2,
                f"{score:.4f}", va="center", fontsize=8, color=PALETTE["text"])
    ax.set_xlim(0, peak * 1.18)
    ax.xaxis.set_tick_params(labelsize=8)
    ax.yaxis.set_tick_params(labelsize=8)
    fig.tight_layout()
    return gr.update(value=_fig_to_pil(fig))
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Tab 5 — Predict on New Data | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def predict_on_file(file):
    """Batch-predict on an uploaded CSV with the trained model.

    Returns (path to a downloadable predictions CSV, summary HTML); on any
    failure the path is None and the HTML carries the error message.
    """
    am = _state.get("automl")
    if am is None:
        return (None, "<p style='color:red'>No trained model. Run training first.</p>")
    if file is None:
        return (None, "<p style='color:red'>Upload a CSV to predict on.</p>")
    try:
        filepath = _get_filepath(file)
        new_df = pd.read_csv(filepath)
        preds = am.predict(new_df)
        new_df["prediction"] = preds
        # Persist next to the training artifacts so gr.File can serve it.
        out_path = "./automl_output/predictions.csv"
        os.makedirs("./automl_output", exist_ok=True)
        new_df.to_csv(out_path, index=False)
        n = len(preds)
        if am.task_type == "regression":
            # Regression: report summary statistics of the predictions.
            html = f"""
            <div style="font-family:monospace;background:#161b22;border:1px solid #30363d;
            border-radius:8px;padding:16px;color:#e6edf3;">
            <b>✅ {n} predictions generated</b><br><br>
            Mean: {preds.mean():.3f} |
            Std: {preds.std():.3f} |
            Min: {preds.min():.3f} |
            Max: {preds.max():.3f}
            </div>"""
        else:
            # Classification: report the predicted class distribution.
            unique, counts = np.unique(preds, return_counts=True)
            dist = " | ".join(f"Class {u}: {c}" for u, c in zip(unique, counts))
            html = f"""
            <div style="font-family:monospace;background:#161b22;border:1px solid #30363d;
            border-radius:8px;padding:16px;color:#e6edf3;">
            <b>✅ {n} predictions generated</b><br><br>{dist}
            </div>"""
        return (out_path, html)
    except Exception as e:
        return (None, f"<p style='color:red'>Error: {e}</p>")
def predict_manual(vals_json):
    """Predict on a single JSON-encoded row typed in by the user.

    Returns an HTML card with the prediction, or an error paragraph.
    """
    am = _state.get("automl")
    df = _state.get("df")
    if am is None or df is None:
        return "<p style='color:red'>Train a model first.</p>"
    try:
        row = json.loads(vals_json)
        input_df = pd.DataFrame([row])
        pred = am.predict(input_df)
        val = pred[0] if hasattr(pred, "__len__") else pred
        # BUG FIX: the conditional must wrap the whole formatting expression.
        # The original "{val:.4f if isinstance(val, float) else val}" placed
        # the conditional *inside the format spec*, an invalid specifier that
        # raised ValueError on every call (swallowed by the except below).
        val_str = f"{val:.4f}" if isinstance(val, float) else str(val)
        return f"""
        <div style="font-family:monospace;background:#161b22;border:2px solid #3fb950;
        border-radius:8px;padding:20px;color:#e6edf3;text-align:center;">
        <div style="font-size:13px;color:#8b949e;margin-bottom:8px;">PREDICTION</div>
        <div style="font-size:36px;font-weight:700;color:#3fb950;">{val_str}</div>
        <div style="font-size:11px;color:#8b949e;margin-top:8px;">Model: {am.best_model_name}</div>
        </div>"""
    except Exception as e:
        return f"<p style='color:red'>Error: {e}</p>"
def build_manual_input_template():
    """Return a JSON template (the first data row) for the manual-predict box.

    Returns "{}" until both a dataset and a trained model exist.
    """
    df = _state.get("df")
    am = _state.get("automl")
    if df is None or am is None:
        return "{}"
    # NOTE(review): this drops `best_model_name` (a model label such as a
    # library name), which is almost never a dataframe column — the intent
    # was presumably to drop the *target* column. Behavior preserved as-is;
    # confirm which AutoML attribute exposes the target name before changing.
    # (Also removed an unused `feature_cols` local from the original.)
    drop_cols = [am.best_model_name] if am.best_model_name in df.columns else []
    sample = df.drop(columns=drop_cols, errors="ignore").iloc[0].to_dict()
    # Make the row JSON-serialisable: unwrap numpy scalars, drop NaN values.
    clean = {k: (float(v) if isinstance(v, (np.floating, np.integer)) else v)
             for k, v in sample.items()
             if not isinstance(v, float) or not np.isnan(v)}
    return json.dumps(clean, indent=2)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Tab 6 — Dataset Analysis Visuals | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def build_analysis_plots():
    """Build the three dataset-analysis figures from the cached dataframe.

    Returns (missing-values chart, correlation heatmap, feature-type pie)
    as PIL images, or (None, None, None) before any upload.
    """
    df = _state.get("df")
    if df is None:
        return (None, None, None)
    # 1. Missing values heatmap style bar
    missing = df.isnull().mean() * 100  # per-column missing percentage
    missing = missing[missing > 0].sort_values(ascending=False)
    fig1 = _styled_fig(10, max(3, len(missing) * 0.4 + 1.5))
    if not missing.empty:
        ax = fig1.add_subplot(111)
        _ax_style(ax, title="Missing Values by Column (%)", xlabel="Missing %")
        # Severity coloring: red >20%, yellow 5-20%, blue otherwise.
        cols_m = [PALETTE["red"] if v > 20 else PALETTE["yellow"] if v > 5 else PALETTE["accent"]
                  for v in missing.values]
        bars = ax.barh(missing.index, missing.values, color=cols_m, alpha=0.85)
        for bar, v in zip(bars, missing.values):
            ax.text(bar.get_width() + 0.3, bar.get_y() + bar.get_height()/2,
                    f"{v:.1f}%", va="center", fontsize=8, color=PALETTE["text"])
        ax.set_xlim(0, max(missing.values) * 1.25)
        red_p = mpatches.Patch(color=PALETTE["red"], label=">20%")
        yel_p = mpatches.Patch(color=PALETTE["yellow"], label="5-20%")
        blue_p = mpatches.Patch(color=PALETTE["accent"], label="<5%")
        ax.legend(handles=[red_p, yel_p, blue_p], fontsize=8,
                  facecolor=PALETTE["surface"], edgecolor=PALETTE["border"],
                  labelcolor=PALETTE["muted"])
    else:
        # No gaps at all — show a friendly placeholder instead of empty axes.
        ax = fig1.add_subplot(111)
        _ax_style(ax, title="Missing Values")
        ax.text(0.5, 0.5, "✅ No missing values!", transform=ax.transAxes,
                ha="center", va="center", fontsize=14, color=PALETTE["green"])
    fig1.tight_layout()
    # 2. Correlation matrix (numeric only)
    num_df = df.select_dtypes("number")
    fig2 = _styled_fig(9, 7)
    ax2 = fig2.add_subplot(111)
    if len(num_df.columns) >= 2:
        corr = num_df.corr()
        im = ax2.imshow(corr, cmap="RdYlBu_r", vmin=-1, vmax=1, aspect="auto")
        plt.colorbar(im, ax=ax2, fraction=0.03, pad=0.04)
        ticks = range(len(corr.columns))
        ax2.set_xticks(ticks); ax2.set_yticks(ticks)
        ax2.set_xticklabels(corr.columns, rotation=45, ha="right",
                            fontsize=7, color=PALETTE["muted"])
        ax2.set_yticklabels(corr.columns, fontsize=7, color=PALETTE["muted"])
        ax2.set_title("Correlation Matrix", color=PALETTE["text"], fontsize=11,
                      fontweight="bold", pad=10)
        ax2.set_facecolor(PALETTE["surface"])
    else:
        _ax_style(ax2, title="Correlation Matrix")
        ax2.text(0.5, 0.5, "Need ≥2 numeric columns", transform=ax2.transAxes,
                 ha="center", va="center", color=PALETTE["muted"])
    fig2.patch.set_facecolor(PALETTE["bg"])
    fig2.tight_layout()
    # 3. Data types pie
    # Object columns with an average string length over 30 chars are counted
    # as free text; shorter ones as categoricals.
    dtypes_count = {}
    for col in df.columns:
        if pd.api.types.is_numeric_dtype(df[col]):
            dtypes_count["Numeric"] = dtypes_count.get("Numeric", 0) + 1
        elif pd.api.types.is_object_dtype(df[col]):
            avg_len = df[col].dropna().astype(str).str.len().mean()
            if avg_len > 30:
                dtypes_count["Text"] = dtypes_count.get("Text", 0) + 1
            else:
                dtypes_count["Categorical"] = dtypes_count.get("Categorical", 0) + 1
        else:
            dtypes_count["Other"] = dtypes_count.get("Other", 0) + 1
    fig3 = _styled_fig(5.5, 4.5)
    ax3 = fig3.add_subplot(111)
    ax3.set_facecolor(PALETTE["bg"])
    wedge_colors = [PALETTE["accent"], PALETTE["green"],
                    PALETTE["yellow"], PALETTE["red"]][:len(dtypes_count)]
    wedges, texts, autotexts = ax3.pie(
        dtypes_count.values(),
        labels=dtypes_count.keys(),
        colors=wedge_colors,
        autopct="%1.0f%%",
        startangle=140,
        pctdistance=0.75,
        wedgeprops=dict(width=0.55, edgecolor=PALETTE["bg"], linewidth=2),
    )
    for t in texts: t.set_color(PALETTE["muted"]); t.set_fontsize(10)
    for t in autotexts: t.set_color(PALETTE["bg"]); t.set_fontsize(9)
    ax3.set_title("Feature Type Distribution", color=PALETTE["text"],
                  fontsize=11, fontweight="bold", pad=10)
    fig3.patch.set_facecolor(PALETTE["bg"])
    return (_fig_to_pil(fig1), _fig_to_pil(fig2), _fig_to_pil(fig3))
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Build the Gradio App | |
| # ───────────────────────────────────────────────────────────────────────────── | |
# Dark GitHub-style skin injected via gr.Blocks(css=...). Selectors target
# legacy ".gr-*" class names as well as generic element selectors so the
# theme applies across the Gradio versions the app supports.
CUSTOM_CSS = """
/* ── Global ───────────────────────────────── */
body, .gradio-container {
    background: #0d1117 !important;
    font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
    color: #e6edf3 !important;
}
/* ── Header ───────────────────────────────── */
.app-header {
    background: linear-gradient(135deg, #161b22 0%, #0d1117 100%);
    border-bottom: 1px solid #30363d;
    padding: 28px 32px 20px;
    margin-bottom: 8px;
}
.app-title {
    font-size: 28px;
    font-weight: 800;
    letter-spacing: -0.5px;
    background: linear-gradient(90deg, #58a6ff, #3fb950);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
}
.app-subtitle {
    color: #8b949e;
    font-size: 12px;
    margin-top: 4px;
    letter-spacing: 1px;
}
/* ── Tabs ─────────────────────────────────── */
.tab-nav button {
    background: transparent !important;
    color: #8b949e !important;
    border-bottom: 2px solid transparent !important;
    border-radius: 0 !important;
    font-family: inherit !important;
    font-size: 12px !important;
    padding: 10px 18px !important;
    letter-spacing: 0.5px;
}
.tab-nav button.selected {
    color: #58a6ff !important;
    border-bottom-color: #58a6ff !important;
}
/* ── Inputs & Textboxes ───────────────────── */
input, textarea, select, .gr-input, .gr-textarea {
    background: #161b22 !important;
    border: 1px solid #30363d !important;
    color: #e6edf3 !important;
    border-radius: 6px !important;
    font-family: inherit !important;
}
input:focus, textarea:focus {
    border-color: #58a6ff !important;
    outline: none !important;
    box-shadow: 0 0 0 2px rgba(88,166,255,0.15) !important;
}
/* ── Buttons ──────────────────────────────── */
.gr-button-primary, button.primary {
    background: linear-gradient(135deg, #1f6feb, #388bfd) !important;
    color: white !important;
    border: none !important;
    border-radius: 6px !important;
    font-family: inherit !important;
    font-size: 13px !important;
    font-weight: 600 !important;
    padding: 10px 22px !important;
    letter-spacing: 0.3px;
    transition: opacity 0.2s !important;
}
.gr-button-primary:hover { opacity: 0.88 !important; }
.gr-button-secondary, button.secondary {
    background: #21262d !important;
    color: #e6edf3 !important;
    border: 1px solid #30363d !important;
    border-radius: 6px !important;
    font-family: inherit !important;
}
/* ── Dropdown ─────────────────────────────── */
.gr-dropdown select { background: #161b22 !important; color: #e6edf3 !important; }
/* ── Slider ───────────────────────────────── */
.gr-slider input[type=range] { accent-color: #58a6ff; }
/* ── Blocks / Panels ──────────────────────── */
.gr-block, .gr-panel, .gr-box {
    background: #161b22 !important;
    border: 1px solid #30363d !important;
    border-radius: 8px !important;
}
/* ── Log textbox ──────────────────────────── */
.log-box textarea {
    background: #0d1117 !important;
    color: #3fb950 !important;
    font-size: 11px !important;
    font-family: 'JetBrains Mono', monospace !important;
    border: 1px solid #30363d !important;
}
/* ── Status badge ─────────────────────────── */
.status-box textarea {
    background: #161b22 !important;
    color: #58a6ff !important;
    font-weight: 600 !important;
    font-size: 13px !important;
    border: 1px solid #30363d !important;
    text-align: center !important;
}
/* ── DataFrame ────────────────────────────── */
.gr-dataframe table {
    background: #161b22 !important;
    color: #e6edf3 !important;
    font-size: 12px !important;
    font-family: inherit !important;
}
.gr-dataframe th {
    background: #21262d !important;
    color: #58a6ff !important;
    border-bottom: 1px solid #30363d !important;
}
.gr-dataframe td { border-bottom: 1px solid #21262d !important; }
/* ── Labels ───────────────────────────────── */
label, .gr-form label, .gr-label {
    color: #8b949e !important;
    font-size: 11px !important;
    letter-spacing: 0.5px !important;
    text-transform: uppercase !important;
}
/* ── Accordion ────────────────────────────── */
.gr-accordion { border: 1px solid #30363d !important; border-radius: 8px !important; }
"""
# Static banner rendered at the top of the Blocks app (styled by .app-header
# rules in CUSTOM_CSS).
HEADER_HTML = """
<div class="app-header">
  <div class="app-title">⚡ Tabular AutoML</div>
  <div class="app-subtitle">AUTOMATED MACHINE LEARNING · CLASSICAL + NEURAL MODELS · BAYESIAN HPO</div>
</div>
"""
def build_app():
    """Construct and return the Gradio Blocks application.

    Declares the six-tab UI (Upload & Explore, Configure & Train,
    Results & Metrics, Feature Importance, Predict, Dataset Analysis)
    and wires every control to the handler functions defined earlier in
    this module (``handle_upload``, ``run_training``, ...).  Handlers
    exchange data through the module-level ``_state`` dict, so no Gradio
    State components are needed here.

    Returns:
        gr.Blocks: the fully assembled (but not yet launched) app.
    """
    # gr.Blocks accepts ``title=`` since Gradio 3.9.  Construct the
    # container once and retry without the kwarg on older versions; the
    # previous code built (and discarded) an extra probe Blocks instance
    # just to test the constructor signature.
    try:
        blocks = gr.Blocks(css=CUSTOM_CSS, title="AutoML Studio")
    except TypeError:
        blocks = gr.Blocks(css=CUSTOM_CSS)

    with blocks as app:
        gr.HTML(HEADER_HTML)
        with gr.Tabs():
            # ══════════════════════════════════════════════════════════════
            # TAB 1 — Upload & Explore
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("📂 Upload & Explore"):
                gr.Markdown("### Upload your CSV dataset to get started")
                with gr.Row():
                    with gr.Column(scale=1):
                        upload_btn = gr.File(label="📂 Drop CSV here (or click to browse)")
                        target_dd = gr.Dropdown(label="Target Column", choices=[], interactive=True)
                        explore_btn = gr.Button("🔍 Analyze Selected Column", variant="secondary")
                    with gr.Column(scale=2):
                        dataset_summary = gr.HTML(value="<p style='color:#8b949e;'>Upload a CSV to see summary.</p>")
                        col_plot = gr.Image(label="Column Distribution")
                # Uploading a CSV repopulates the target dropdown, the
                # dataset summary card and the default column plot.
                upload_btn.change(
                    fn=handle_upload,
                    inputs=[upload_btn],
                    outputs=[target_dd, dataset_summary, col_plot]
                )
                explore_btn.click(
                    fn=show_column_stats,
                    inputs=[target_dd],
                    outputs=[col_plot]
                )

            # ══════════════════════════════════════════════════════════════
            # TAB 2 — Configure & Train
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("🚀 Configure & Train"):
                gr.Markdown("### Training Configuration")
                with gr.Row():
                    with gr.Column(scale=1):
                        task_radio = gr.Radio(["classification", "regression"],
                                              label="Task Type", value="regression")
                        time_budget = gr.Number(label="Time Budget (seconds, blank = unlimited)",
                                                value=300, precision=0)
                        n_trials = gr.Slider(3, 30, value=15, step=1,
                                             label="Optuna HPO Trials per Model")
                        use_fe = gr.Checkbox(label="Enable Feature Engineering", value=True)
                    with gr.Column(scale=1):
                        val_size = gr.Slider(0.05, 0.3, value=0.15, step=0.01,
                                             label="Validation Split Size")
                        test_size = gr.Slider(0.05, 0.3, value=0.15, step=0.01,
                                              label="Test Split Size")
                        seed_in = gr.Number(label="Random Seed", value=42, precision=0)
                train_btn = gr.Button("▶ Start Training", variant="primary")
                with gr.Row():
                    train_status = gr.Textbox(label="Status", interactive=False)
                with gr.Accordion("📋 Training Log", open=True):
                    log_box = gr.Textbox(label="Live Output", lines=18,
                                         interactive=False)
                    refresh_btn = gr.Button("↻ Refresh Log", variant="secondary")
                gr.Markdown("### Leaderboard Preview")
                lb_table = gr.DataFrame(label="Model Scores")
                lb_plot = gr.Image(label="Metric Comparison")
                # ``target_dd`` comes from Tab 1: cross-tab component refs
                # are valid as long as both live in the same Blocks.
                train_btn.click(
                    fn=run_training,
                    inputs=[target_dd, task_radio, time_budget, n_trials,
                            use_fe, val_size, test_size, seed_in],
                    outputs=[train_status, log_box, lb_table, lb_plot]
                )
                # Manual pull of the background training-log buffer.
                refresh_btn.click(fn=poll_log, inputs=[], outputs=[log_box])

            # ══════════════════════════════════════════════════════════════
            # TAB 3 — Results & Metrics
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("📊 Results & Metrics"):
                gr.Markdown("### Best Model Performance")
                results_btn = gr.Button("Load Results", variant="secondary")
                with gr.Row():
                    with gr.Column(scale=1):
                        best_card = gr.HTML(value="<p style='color:#8b949e;'>Run training first.</p>")
                    with gr.Column(scale=1):
                        metrics_plot = gr.Image(label="Metrics Chart")
                gr.Markdown("### PyTorch Training Curves")
                curves_plot = gr.Image(label="Loss Curves (neural models only)")
                results_btn.click(
                    fn=build_results_tab,
                    inputs=[],
                    outputs=[best_card, metrics_plot, curves_plot]
                )

            # ══════════════════════════════════════════════════════════════
            # TAB 4 — Feature Importance
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("🔍 Feature Importance"):
                gr.Markdown("### SHAP / Model-Based Feature Importance")
                with gr.Row():
                    top_k_slider = gr.Slider(5, 40, value=15, step=1, label="Top K Features")
                    imp_btn = gr.Button("Generate Importance Plot", variant="primary")
                importance_plot = gr.Image(label="Feature Importances")
                imp_btn.click(
                    fn=build_importance_tab,
                    inputs=[top_k_slider],
                    outputs=[importance_plot]
                )
                # Also regenerate live whenever the slider moves.
                top_k_slider.change(
                    fn=build_importance_tab,
                    inputs=[top_k_slider],
                    outputs=[importance_plot]
                )

            # ══════════════════════════════════════════════════════════════
            # TAB 5 — Predict
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("🎯 Predict"):
                gr.Markdown("### Batch Prediction (CSV file)")
                with gr.Row():
                    with gr.Column():
                        pred_file = gr.File(label="Upload CSV for prediction")
                        pred_btn = gr.Button("Run Prediction", variant="primary")
                    with gr.Column():
                        pred_result = gr.HTML()
                        pred_dl = gr.File(label="Download Predictions CSV")
                pred_btn.click(
                    fn=predict_on_file,
                    inputs=[pred_file],
                    outputs=[pred_dl, pred_result]
                )
                gr.Markdown("---")
                gr.Markdown("### Manual Single-Row Prediction")
                gr.Markdown("Paste a JSON object with feature values:")
                with gr.Row():
                    with gr.Column():
                        template_btn = gr.Button("📋 Load Sample Row Template", variant="secondary")
                        manual_json = gr.Textbox(label="Input JSON", lines=10,
                                                 placeholder='{"feature_0": 1.23, "cat_A": "high", ...}')
                        manual_btn = gr.Button("Predict", variant="primary")
                    with gr.Column():
                        manual_out = gr.HTML()
                # Fill the textbox with a sample row so the user sees the
                # expected feature names/format.
                template_btn.click(
                    fn=build_manual_input_template,
                    inputs=[], outputs=[manual_json]
                )
                manual_btn.click(
                    fn=predict_manual,
                    inputs=[manual_json], outputs=[manual_out]
                )

            # ══════════════════════════════════════════════════════════════
            # TAB 6 — Dataset Analysis
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("🧬 Dataset Analysis"):
                gr.Markdown("### Automated Dataset Visualizations")
                analysis_btn = gr.Button("Generate Analysis Plots", variant="primary")
                with gr.Row():
                    missing_plot = gr.Image(label="Missing Values")
                with gr.Row():
                    corr_plot = gr.Image(label="Correlation Matrix")
                    dtype_plot = gr.Image(label="Feature Types")
                analysis_btn.click(
                    fn=build_analysis_plots,
                    inputs=[],
                    outputs=[missing_plot, corr_plot, dtype_plot]
                )

        # Footer
        gr.HTML("""
        <div style="text-align:center;padding:20px;
                    border-top:1px solid #30363d;
                    color:#8b949e;font-size:11px;
                    letter-spacing:0.5px;margin-top:16px;">
            TABULAR AUTOML · SKLEARN + PYTORCH · OPTUNA HPO · SHAP EXPLAINABILITY
        </div>
        """)
    return app
| # ───────────────────────────────────────────────────────────────────────────── | |
if __name__ == "__main__":
    print("\n" + "="*55)
    print(" AutoML Gradio Interface")
    print(f" Gradio version: {_gr_check.__version__}")
    print("="*55 + "\n")
    app = build_app()
    # Detect Hugging Face Spaces via its SPACE_ID env var.
    # (``os`` is already imported at the top of the file; the previous
    # redundant ``import os`` here has been dropped.)
    on_hf_spaces = os.environ.get("SPACE_ID") is not None
    if on_hf_spaces:
        # HF Spaces injects host/port itself: minimal launch args.
        launch_kwargs = {}
    else:
        # Local run: bind all interfaces on the conventional Gradio port.
        launch_kwargs = dict(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
        )
    # ssr_mode only exists in Gradio 5+.  Guard it in BOTH branches — the
    # old code passed it unconditionally on Spaces, which would raise
    # TypeError if the Space ever ran a Gradio 4.x image.
    if _gr_version[0] >= 5:
        launch_kwargs["ssr_mode"] = False
    app.launch(**launch_kwargs)