"""
app.py — Gradio Interface for the Tabular AutoML Framework

Run:
    python app.py
"""

import gradio as gr
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
import io
import json
import os
import sys
import textwrap
import threading
import time
from html import escape as _escape_html
from pathlib import Path

# ── make sure the automl package is importable ────────────────────────────────
sys.path.insert(0, str(Path(__file__).parent))
from automl import AutoML

# ── Gradio version compatibility check ───────────────────────────────────────
import gradio as _gr_check
_gr_version = tuple(int(x) for x in _gr_check.__version__.split(".")[:2])
_is_gradio_4_plus = _gr_version[0] >= 4
_is_gradio_6_plus = _gr_version[0] >= 6
print(f" Gradio version: {_gr_check.__version__}")

# gr.File type param: needed in 3.x only
_FILE_KWARGS = {} if _is_gradio_4_plus else {"type": "file"}

# ── Global state ──────────────────────────────────────────────────────────────
# Shared by every callback below; "target_col" records the column the current
# model was fitted on so prediction helpers can exclude it from feature input.
_state: dict = {
    "automl": None,
    "df": None,
    "log_lines": [],
    "running": False,
    "target_col": None,
}


# ── File path compatibility helper (Gradio 3.x / 4.x / 5.x / 6.x) ──────────
def _get_filepath(file):
    """Handle gr.File output across ALL Gradio versions including 6.x.

    Accepts ``None``, a plain path string, a list of uploads (the first entry
    is used), a dict with a ``name``/``path``/``tmp_path`` key, or an object
    exposing ``.path`` or ``.name``. Returns the resolved path, or ``None``
    when nothing was uploaded.
    """
    if file is None:
        return None
    # Gradio 6.x: returns plain string filepath directly
    if isinstance(file, str):
        return file
    # Gradio 6.x: sometimes returns a list (multiple files). Resolve the first
    # entry recursively because list items may themselves be dicts/objects.
    if isinstance(file, list):
        return _get_filepath(file[0]) if file else None
    # Gradio 3.x early: returns dict
    if isinstance(file, dict):
        return file.get("name") or file.get("path") or file.get("tmp_path")
    # Gradio 4.x / 5.x: UploadData with .path attribute
    if hasattr(file, "path"):
        return file.path
    # Gradio 3.x late: object with .name
    if hasattr(file, "name"):
        return file.name
    # Last resort
    return str(file)


PALETTE = {
    "bg": "#0d1117",
    "surface": "#161b22",
    "border": "#30363d",
    "accent": "#58a6ff",
    "green": "#3fb950",
    "yellow": "#d29922",
    "red": "#f85149",
    "text": "#e6edf3",
    "muted": "#8b949e",
}


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def _fig_to_pil(fig):
    """Render a matplotlib figure to a PIL image and close the figure."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight",
                facecolor=fig.get_facecolor(), dpi=130)
    buf.seek(0)
    from PIL import Image
    # .copy() detaches the image from the buffer so the buffer can be closed.
    img = Image.open(buf).copy()
    buf.close()
    plt.close(fig)
    return img


def _styled_fig(w=12, h=6):
    """Create an empty figure of ``w`` x ``h`` inches on the dark background."""
    fig = plt.figure(figsize=(w, h), facecolor=PALETTE["bg"])
    return fig


def _ax_style(ax, title="", xlabel="", ylabel=""):
    """Apply the dark theme to ``ax`` and set any non-empty labels; return ``ax``."""
    ax.set_facecolor(PALETTE["surface"])
    ax.tick_params(colors=PALETTE["muted"], labelsize=9)
    ax.spines[:].set_color(PALETTE["border"])
    if title:
        ax.set_title(title, color=PALETTE["text"], fontsize=11, pad=10,
                     fontweight="bold")
    if xlabel:
        ax.set_xlabel(xlabel, color=PALETTE["muted"], fontsize=9)
    if ylabel:
        ax.set_ylabel(ylabel, color=PALETTE["muted"], fontsize=9)
    ax.tick_params(axis="x", colors=PALETTE["muted"])
    ax.tick_params(axis="y", colors=PALETTE["muted"])
    return ax


# ─────────────────────────────────────────────────────────────────────────────
# Tab 1 — Upload & Preview
# ─────────────────────────────────────────────────────────────────────────────

def handle_upload(file):
    """Load an uploaded CSV into ``_state["df"]`` and build its HTML summary.

    Returns a 3-tuple of ``gr.update`` objects for the target dropdown, the
    summary HTML panel, and the column plot (always reset to ``None``).
    """
    if file is None:
        return (gr.update(choices=[], value=None),
                gr.update(value="\n\nUpload a CSV to see summary.\n\n"),
                gr.update(value=None))
    try:
        print(f" [Upload] file type: {type(file)}, value: {repr(file)[:200]}")
        filepath = _get_filepath(file)
        print(f" [Upload] resolved filepath: {filepath}")
        if filepath is None:
            return (gr.update(choices=[], value=None),
                    gr.update(value="\n\nCould not read file. Try uploading again.\n\n"),
                    gr.update(value=None))

        df = pd.read_csv(filepath)
        _state["df"] = df
        cols = df.columns.tolist()

        # Build a rich HTML preview
        n_rows, n_cols = df.shape
        missing = df.isnull().sum().sum()
        dtypes = df.dtypes.value_counts().to_dict()
        dtype_str = ", ".join(f"{v}× {k}" for k, v in dtypes.items())
        summary_html = f"""
{n_rows:,}
ROWS
{n_cols}
COLUMNS
{missing:,}
MISSING CELLS
Dtypes: {dtype_str}
"""
        # First 6 rows as styled HTML table
        preview_df = df.head(6)
        tbl = preview_df.to_html(index=False, border=0, classes="preview-tbl")
        table_html = f"""
{tbl}
"""
        combined = summary_html + "\n" + table_html
        # The last column is pre-selected as the default target.
        return (gr.update(choices=cols, value=cols[-1]),
                gr.update(value=combined),
                gr.update(value=None))
    except Exception as e:
        import traceback
        err_detail = traceback.format_exc()
        print(f" [Upload Error] {err_detail}")
        # The message is rendered as raw HTML, and parser errors can echo file
        # content, so escape it before interpolation.
        return (gr.update(choices=[], value=None),
                gr.update(value=f"\nError loading file:\n{_escape_html(str(e))}\n\nCheck HF Space logs for details.\n"),
                gr.update(value=None))


def show_column_stats(col_name):
    """Plot the distribution and summary statistics of one dataset column.

    Returns a ``gr.update`` carrying the rendered image, or ``value=None``
    when no dataset is loaded or no column is selected.
    """
    df = _state.get("df")
    if df is None or not col_name:
        return gr.update(value=None)

    s = df[col_name]
    is_numeric = pd.api.types.is_numeric_dtype(s)
    fig = _styled_fig(10, 3.5)
    gs = GridSpec(1, 2, figure=fig, wspace=0.4)

    # Distribution plot
    ax1 = fig.add_subplot(gs[0])
    _ax_style(ax1, title=f"Distribution — {col_name}")
    if is_numeric:
        clean = s.dropna()
        ax1.hist(clean, bins=30, color=PALETTE["accent"], alpha=0.85,
                 edgecolor="none")
        ax1.axvline(clean.mean(), color=PALETTE["yellow"], lw=1.5,
                    linestyle="--", label=f"mean={clean.mean():.2f}")
        ax1.legend(fontsize=8, labelcolor=PALETTE["muted"],
                   facecolor=PALETTE["surface"], edgecolor=PALETTE["border"])
    else:
        # Categorical: show the 12 most frequent values.
        vc = s.value_counts().head(12)
        ax1.barh(vc.index.astype(str), vc.values, color=PALETTE["accent"],
                 alpha=0.85)
        ax1.invert_yaxis()

    # Stats panel
    ax2 = fig.add_subplot(gs[1])
    ax2.axis("off")
    lines = [f"dtype: {s.dtype}",
             f"missing: {s.isnull().sum()} ({s.isnull().mean()*100:.1f}%)",
             f"unique: {s.nunique()}"]
    if is_numeric:
        lines += [f"mean: {s.mean():.3f}",
                  f"std: {s.std():.3f}",
                  f"min: {s.min():.3f}",
                  f"25%: {s.quantile(.25):.3f}",
                  f"median: {s.median():.3f}",
                  f"75%: {s.quantile(.75):.3f}",
                  f"max: {s.max():.3f}"]
    txt = "\n".join(lines)
    ax2.text(0.05, 0.95, txt, transform=ax2.transAxes, va="top", ha="left",
             fontsize=9.5, fontfamily="monospace", color=PALETTE["text"],
             bbox=dict(boxstyle="round,pad=0.6", facecolor=PALETTE["surface"],
                       edgecolor=PALETTE["border"]))
    fig.patch.set_facecolor(PALETTE["bg"])
    return gr.update(value=_fig_to_pil(fig))


# ─────────────────────────────────────────────────────────────────────────────
# Tab 2 — Training
# ─────────────────────────────────────────────────────────────────────────────

class LogCapture:
    """Redirect stdout to both terminal and our log buffer."""

    def __init__(self, original):
        self.original = original
        # Same list object as the global buffer, so poll_log() sees new lines.
        self.lines = _state["log_lines"]

    def write(self, msg):
        self.original.write(msg)
        # Whitespace-only writes (e.g. print's trailing newline) are not logged.
        if msg.strip():
            ts = time.strftime("%H:%M:%S")
            self.lines.append(f"[{ts}] {msg.rstrip()}")

    def flush(self):
        self.original.flush()


def run_training(target_col, task_type, time_budget, n_trials, use_fe,
                 val_size, test_size, seed):
    """Fit an ``AutoML`` instance on the uploaded dataset.

    stdout is captured into ``_state["log_lines"]`` for the duration of the
    run. Returns ``(status, log_text, leaderboard_df, leaderboard_plot)``.
    """
    df = _state.get("df")
    if df is None:
        return ("❌ Please upload a dataset first.", "", None, None)
    if not target_col:
        return ("❌ Please select a target column.", "", None, None)

    _state["log_lines"].clear()
    _state["running"] = True
    original_stdout = sys.stdout
    sys.stdout = LogCapture(original_stdout)
    status = "✅ Training complete!"
    try:
        # A blank, zero or negative budget means "no time limit".
        budget = float(time_budget) if time_budget and float(time_budget) > 0 else None
        automl = AutoML(
            task_type=task_type,
            time_budget=budget,
            n_optuna_trials=int(n_trials),
            val_size=float(val_size),
            test_size=float(test_size),
            seed=int(seed),
            use_feature_engineering=use_fe,
            output_dir="./automl_output",
        )
        automl.fit(df, target_col=target_col)
        _state["automl"] = automl
        # Remember which column was the target so the manual-prediction
        # template can leave it out of the feature row.
        _state["target_col"] = target_col
    except Exception as e:
        status = f"❌ Error: {e}"
        import traceback
        traceback.print_exc()
    finally:
        # Always restore stdout, even when training fails.
        sys.stdout = original_stdout
        _state["running"] = False

    log_text = "\n".join(_state["log_lines"])
    lb_df = _build_leaderboard_df()
    lb_plot = _build_leaderboard_plot()
    return (status, log_text, lb_df, lb_plot)


def _build_leaderboard_df():
    """Return the leaderboard as a DataFrame with floats rounded to 4 places."""
    am = _state.get("automl")
    if am is None:
        return None
    df = am.leaderboard.to_dataframe()
    df = df.drop(columns=["_type", "primary_score"], errors="ignore")
    # Round floats
    for c in df.select_dtypes("float").columns:
        df[c] = df[c].round(4)
    return df


def _build_leaderboard_plot():
    """Draw one horizontal bar chart per leaderboard metric; return a PIL image."""
    am = _state.get("automl")
    if am is None:
        return None
    lb = am.leaderboard.to_dataframe()
    if lb.empty:
        return None
    lb = lb.drop(columns=["_type", "primary_score"], errors="ignore")
    metric_cols = [c for c in lb.columns if c != "model_name"]
    if not metric_cols:
        return None

    n_metrics = len(metric_cols)
    fig, axes = plt.subplots(1, n_metrics,
                             figsize=(max(4, 3.5 * n_metrics), 4.5),
                             facecolor=PALETTE["bg"])
    # plt.subplots returns a bare Axes (not an array) for a single subplot.
    if n_metrics == 1:
        axes = [axes]
    colors = [PALETTE["accent"], PALETTE["green"], PALETTE["yellow"],
              "#a371f7", "#f78166", "#79c0ff", "#56d364"]

    for i, (ax, metric) in enumerate(zip(axes, metric_cols)):
        _ax_style(ax, title=metric.upper())
        vals = lb[metric].values
        names = lb["model_name"].values
        col = colors[i % len(colors)]
        peak = max(vals)
        bars = ax.barh(names, vals, color=col, alpha=0.85, height=0.55)
        ax.invert_yaxis()
        for bar, v in zip(bars, vals):
            ax.text(bar.get_width() + peak * 0.01,
                    bar.get_y() + bar.get_height() / 2,
                    f"{v:.3f}", va="center", fontsize=8, color=PALETTE["text"])
        ax.set_xlim(0, peak * 1.18)
        # Highlight best bar (error metrics are lower-is-better). Index the bar
        # container directly rather than relying on Axes child ordering.
        best_idx = np.argmax(vals) if metric not in ("rmse", "mae") else np.argmin(vals)
        bars[best_idx].set_color(PALETTE["green"])

    fig.suptitle("Model Leaderboard — All Metrics", color=PALETTE["text"],
                 fontsize=13, fontweight="bold", y=1.02)
    fig.tight_layout()
    return _fig_to_pil(fig)


def poll_log():
    """Stream log lines while training is running."""
    return "\n".join(_state["log_lines"])


# ─────────────────────────────────────────────────────────────────────────────
# Tab 3 — Results & Metrics
# ─────────────────────────────────────────────────────────────────────────────

def build_results_tab():
    """Build the best-model card, its metrics chart and the loss curves.

    Returns three ``gr.update`` objects (card HTML, metrics image, curves
    image); the curves image is ``None`` when no PyTorch model was trained.
    """
    am = _state.get("automl")
    if am is None:
        return (gr.update(value="No training run yet."),
                gr.update(value=None),
                gr.update(value=None))

    # Summary card
    best_name = am.best_model_name
    best_metrics = am.best_metrics
    task = am.task_type
    metrics_rows = "".join(
        f"""
{k} {v:.4f}
"""
        for k, v in best_metrics.items()
    )
    card = f"""
🏆 BEST MODEL
{best_name}
TASK: {task.upper()}
{metrics_rows}
"""
    # Metric radar / bar chart
    fig = _build_metrics_radar(best_metrics, task)
    radar_img = _fig_to_pil(fig)
    # Learning curves for PyTorch models
    curves_img = _build_loss_curves()
    return (gr.update(value=card),
            gr.update(value=radar_img),
            gr.update(value=curves_img))


def _build_metrics_radar(metrics, task):
    """Return a bar-chart figure of ``metrics`` (name -> value).

    ``task`` is accepted for interface compatibility and is not used.
    """
    names = list(metrics.keys())
    vals = list(metrics.values())
    fig, ax = plt.subplots(figsize=(5.5, 4.5), facecolor=PALETTE["bg"])
    _ax_style(ax, title="Best Model Metrics")
    x = np.arange(len(names))
    bars = ax.bar(x, vals,
                  color=[PALETTE["accent"], PALETTE["green"], PALETTE["yellow"],
                         "#a371f7", "#f78166"][:len(names)],
                  alpha=0.85, width=0.55)
    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=20, ha="right", fontsize=9)
    ax.set_ylim(0, max(vals) * 1.2)
    for bar, v in zip(bars, vals):
        ax.text(bar.get_x() + bar.get_width() / 2,
                bar.get_height() + max(vals) * 0.01,
                f"{v:.4f}", ha="center", va="bottom", fontsize=8.5,
                color=PALETTE["text"])
    fig.tight_layout()
    return fig


def _build_loss_curves():
    """Plot train/validation loss for every PyTorch leaderboard entry.

    Returns a PIL image, or ``None`` when there is no trained model or no
    entry of type ``"pytorch"`` exposing a ``history`` attribute.
    """
    am = _state.get("automl")
    if am is None:
        return None
    # Collect trainers from leaderboard
    trainers = [(e["model_name"], e["_model"])
                for e in am.leaderboard.entries
                if e.get("_type") == "pytorch" and hasattr(e["_model"], "history")]
    if not trainers:
        return None

    fig, axes = plt.subplots(1, len(trainers),
                             figsize=(5.5 * len(trainers), 4),
                             facecolor=PALETTE["bg"])
    if len(trainers) == 1:
        axes = [axes]
    for ax, (name, trainer) in zip(axes, trainers):
        _ax_style(ax, title=f"{name} — Loss Curves", xlabel="Epoch",
                  ylabel="Loss")
        h = trainer.history
        epochs = range(1, len(h["train_loss"]) + 1)
        ax.plot(epochs, h["train_loss"], color=PALETTE["accent"], lw=1.8,
                label="Train")
        ax.plot(epochs, h["val_loss"], color=PALETTE["green"], lw=1.8,
                linestyle="--", label="Validation")
        ax.legend(fontsize=8, labelcolor=PALETTE["muted"],
                  facecolor=PALETTE["surface"], edgecolor=PALETTE["border"])
        ax.fill_between(epochs, h["train_loss"], alpha=0.08,
                        color=PALETTE["accent"])
        ax.fill_between(epochs, h["val_loss"], alpha=0.08,
                        color=PALETTE["green"])
    fig.tight_layout()
    return _fig_to_pil(fig)


# ─────────────────────────────────────────────────────────────────────────────
# Tab 4 — Feature Importance
# ─────────────────────────────────────────────────────────────────────────────

def build_importance_tab(top_k):
    """Plot the first ``top_k`` entries of the model's feature-importance map.

    Entries are taken in the mapping's own order. Returns a ``gr.update`` with
    the image, or ``value=None`` when no importances are available.
    """
    am = _state.get("automl")
    if am is None or not am.feature_importance:
        return gr.update(value=None)

    importance = am.feature_importance
    items = list(importance.items())[:int(top_k)]
    names = [i[0] for i in items]
    vals = [i[1] for i in items]

    fig = _styled_fig(10, max(4, len(names) * 0.45))
    ax = fig.add_subplot(111)
    _ax_style(ax, title=f"Top {len(names)} Feature Importances",
              xlabel="Importance Score")
    colors_grad = plt.cm.Blues(np.linspace(0.4, 0.9, len(names)))[::-1]
    # Reversed so the first entry is drawn at the top of the chart.
    bars = ax.barh(names[::-1], vals[::-1], color=colors_grad, alpha=0.9,
                   height=0.65)
    for bar, v in zip(bars, vals[::-1]):
        ax.text(bar.get_width() + max(vals) * 0.01,
                bar.get_y() + bar.get_height() / 2,
                f"{v:.4f}", va="center", fontsize=8, color=PALETTE["text"])
    ax.set_xlim(0, max(vals) * 1.18)
    ax.xaxis.set_tick_params(labelsize=8)
    ax.yaxis.set_tick_params(labelsize=8)
    fig.tight_layout()
    return gr.update(value=_fig_to_pil(fig))


# ─────────────────────────────────────────────────────────────────────────────
# Tab 5 — Predict on New Data
# ─────────────────────────────────────────────────────────────────────────────

def predict_on_file(file):
    """Run the trained model on an uploaded CSV and save the predictions.

    Writes ``./automl_output/predictions.csv`` and returns
    ``(path_or_None, summary_html)``.
    """
    am = _state.get("automl")
    if am is None:
        return (None, "\n\nNo trained model. Run training first.\n\n")
    if file is None:
        return (None, "\n\nUpload a CSV to predict on.\n\n")
    try:
        filepath = _get_filepath(file)
        new_df = pd.read_csv(filepath)
        preds = am.predict(new_df)
        new_df["prediction"] = preds
        out_path = "./automl_output/predictions.csv"
        os.makedirs("./automl_output", exist_ok=True)
        new_df.to_csv(out_path, index=False)
        n = len(preds)
        if am.task_type == "regression":
            summary = f"""
✅ {n} predictions generated

Mean: {preds.mean():.3f}  |  Std: {preds.std():.3f}  |  Min: {preds.min():.3f}  |  Max: {preds.max():.3f}
"""
        else:
            unique, counts = np.unique(preds, return_counts=True)
            dist = "  |  ".join(f"Class {u}: {c}" for u, c in zip(unique, counts))
            summary = f"""
✅ {n} predictions generated

{dist}
"""
        return (out_path, summary)
    except Exception as e:
        # Rendered as raw HTML, so escape the exception text.
        return (None, f"\n\nError: {_escape_html(str(e))}\n\n")


def predict_manual(vals_json):
    """Predict a single row given as a JSON object of feature values.

    Returns an HTML snippet with the prediction, or an error message.
    """
    am = _state.get("automl")
    df = _state.get("df")
    if am is None or df is None:
        return "\n\nTrain a model first.\n\n"
    try:
        row = json.loads(vals_json)
        input_df = pd.DataFrame([row])
        pred = am.predict(input_df)
        val = pred[0] if hasattr(pred, "__len__") else pred
        # Format before interpolating: a conditional cannot appear inside an
        # f-string format spec. NumPy floats are formatted the same way.
        shown = f"{val:.4f}" if isinstance(val, (float, np.floating)) else val
        return f"""
PREDICTION
{shown}
Model: {am.best_model_name}
"""
    except Exception as e:
        # Rendered as raw HTML, so escape the exception text.
        return f"\n\nError: {_escape_html(str(e))}\n\n"


def build_manual_input_template():
    """Return the first dataset row as pretty-printed JSON, minus the target.

    NaN cells are omitted and NumPy scalars are converted to plain Python
    values so the row serialises cleanly. Returns ``"{}"`` when there is no
    dataset or no trained model.
    """
    df = _state.get("df")
    am = _state.get("automl")
    if df is None or am is None:
        return "{}"

    # Exclude the column the model was trained to predict.
    target = _state.get("target_col")
    drop_cols = [target] if target is not None and target in df.columns else []
    sample = df.drop(columns=drop_cols, errors="ignore").iloc[0].to_dict()

    # Clean non-serialisable types
    clean = {}
    for k, v in sample.items():
        if isinstance(v, np.generic):
            v = v.item()
        if isinstance(v, float) and np.isnan(v):
            continue
        clean[k] = v
    return json.dumps(clean, indent=2)


# ─────────────────────────────────────────────────────────────────────────────
# Tab 6 — Dataset Analysis Visuals
# ─────────────────────────────────────────────────────────────────────────────

def build_analysis_plots():
    """Build the missing-value, correlation and feature-type plots.

    Returns three PIL images, or ``(None, None, None)`` with no dataset.
    """
    df = _state.get("df")
    if df is None:
        return (None, None, None)

    # 1. Missing values heatmap style bar
    missing = df.isnull().mean() * 100
    missing = missing[missing > 0].sort_values(ascending=False)
    fig1 = _styled_fig(10, max(3, len(missing) * 0.4 + 1.5))
    if not missing.empty:
        ax = fig1.add_subplot(111)
        _ax_style(ax, title="Missing Values by Column (%)", xlabel="Missing %")
        cols_m = [PALETTE["red"] if v > 20
                  else PALETTE["yellow"] if v > 5
                  else PALETTE["accent"]
                  for v in missing.values]
        bars = ax.barh(missing.index, missing.values, color=cols_m, alpha=0.85)
        for bar, v in zip(bars, missing.values):
            ax.text(bar.get_width() + 0.3,
                    bar.get_y() + bar.get_height() / 2,
                    f"{v:.1f}%", va="center", fontsize=8,
                    color=PALETTE["text"])
        ax.set_xlim(0, max(missing.values) * 1.25)
        red_p = mpatches.Patch(color=PALETTE["red"], label=">20%")
        yel_p = mpatches.Patch(color=PALETTE["yellow"], label="5-20%")
        blue_p = mpatches.Patch(color=PALETTE["accent"], label="<5%")
        ax.legend(handles=[red_p, yel_p, blue_p], fontsize=8,
                  facecolor=PALETTE["surface"], edgecolor=PALETTE["border"],
                  labelcolor=PALETTE["muted"])
    else:
        ax = fig1.add_subplot(111)
        _ax_style(ax, title="Missing Values")
        ax.text(0.5, 0.5, "✅ No missing values!", transform=ax.transAxes,
                ha="center", va="center", fontsize=14, color=PALETTE["green"])
    fig1.tight_layout()

    # 2. Correlation matrix (numeric only)
    num_df = df.select_dtypes("number")
    fig2 = _styled_fig(9, 7)
    ax2 = fig2.add_subplot(111)
    if len(num_df.columns) >= 2:
        corr = num_df.corr()
        im = ax2.imshow(corr, cmap="RdYlBu_r", vmin=-1, vmax=1, aspect="auto")
        plt.colorbar(im, ax=ax2, fraction=0.03, pad=0.04)
        ticks = range(len(corr.columns))
        ax2.set_xticks(ticks)
        ax2.set_yticks(ticks)
        ax2.set_xticklabels(corr.columns, rotation=45, ha="right", fontsize=7,
                            color=PALETTE["muted"])
        ax2.set_yticklabels(corr.columns, fontsize=7, color=PALETTE["muted"])
        ax2.set_title("Correlation Matrix", color=PALETTE["text"], fontsize=11,
                      fontweight="bold", pad=10)
        ax2.set_facecolor(PALETTE["surface"])
    else:
        _ax_style(ax2, title="Correlation Matrix")
        ax2.text(0.5, 0.5, "Need ≥2 numeric columns", transform=ax2.transAxes,
                 ha="center", va="center", color=PALETTE["muted"])
    fig2.patch.set_facecolor(PALETTE["bg"])
    fig2.tight_layout()

    # 3. Data types pie
    dtypes_count = {}
    for col in df.columns:
        if pd.api.types.is_numeric_dtype(df[col]):
            dtypes_count["Numeric"] = dtypes_count.get("Numeric", 0) + 1
        elif pd.api.types.is_object_dtype(df[col]):
            # Object columns with long average values are counted as free text.
            avg_len = df[col].dropna().astype(str).str.len().mean()
            if avg_len > 30:
                dtypes_count["Text"] = dtypes_count.get("Text", 0) + 1
            else:
                dtypes_count["Categorical"] = dtypes_count.get("Categorical", 0) + 1
        else:
            dtypes_count["Other"] = dtypes_count.get("Other", 0) + 1

    fig3 = _styled_fig(5.5, 4.5)
    ax3 = fig3.add_subplot(111)
    ax3.set_facecolor(PALETTE["bg"])
    wedge_colors = [PALETTE["accent"], PALETTE["green"], PALETTE["yellow"],
                    PALETTE["red"]][:len(dtypes_count)]
    wedges, texts, autotexts = ax3.pie(
        dtypes_count.values(),
        labels=dtypes_count.keys(),
        colors=wedge_colors,
        autopct="%1.0f%%",
        startangle=140,
        pctdistance=0.75,
        wedgeprops=dict(width=0.55, edgecolor=PALETTE["bg"], linewidth=2),
    )
    for t in texts:
        t.set_color(PALETTE["muted"])
        t.set_fontsize(10)
    for t in autotexts:
        t.set_color(PALETTE["bg"])
        t.set_fontsize(9)
    ax3.set_title("Feature Type Distribution", color=PALETTE["text"],
                  fontsize=11, fontweight="bold", pad=10)
    fig3.patch.set_facecolor(PALETTE["bg"])

    return (_fig_to_pil(fig1), _fig_to_pil(fig2), _fig_to_pil(fig3))


# ─────────────────────────────────────────────────────────────────────────────
# Build the Gradio App
# ─────────────────────────────────────────────────────────────────────────────

CUSTOM_CSS = """
/* ── Global ───────────────────────────────── */
body, .gradio-container {
  background: #0d1117 !important;
  font-family: 'JetBrains Mono', 'Fira Code', monospace !important;
  color: #e6edf3 !important;
}
/* ── Header ───────────────────────────────── */
.app-header {
  background: linear-gradient(135deg, #161b22 0%, #0d1117 100%);
  border-bottom: 1px solid #30363d;
  padding: 28px 32px 20px;
  margin-bottom: 8px;
}
.app-title {
  font-size: 28px;
  font-weight: 800;
  letter-spacing: -0.5px;
  background: linear-gradient(90deg, #58a6ff, #3fb950);
  -webkit-background-clip: text;
  -webkit-text-fill-color: transparent;
}
.app-subtitle {
  color: #8b949e;
  font-size: 12px;
  margin-top: 4px;
  letter-spacing: 1px;
}
/* ── Tabs ─────────────────────────────────── */
.tab-nav button {
  background: transparent !important;
  color: #8b949e !important;
  border-bottom: 2px solid transparent !important;
  border-radius: 0 !important;
  font-family: inherit !important;
  font-size: 12px !important;
  padding: 10px 18px !important;
  letter-spacing: 0.5px;
}
.tab-nav button.selected {
  color: #58a6ff !important;
  border-bottom-color: #58a6ff !important;
}
/* ── Inputs & Textboxes ───────────────────── */
input, textarea, select, .gr-input, .gr-textarea {
  background: #161b22 !important;
  border: 1px solid #30363d !important;
  color: #e6edf3 !important;
  border-radius: 6px !important;
  font-family: inherit !important;
}
input:focus, textarea:focus {
  border-color: #58a6ff !important;
  outline: none !important;
  box-shadow: 0 0 0 2px rgba(88,166,255,0.15) !important;
}
/* ── Buttons ──────────────────────────────── */
.gr-button-primary, button.primary {
  background: linear-gradient(135deg, #1f6feb, #388bfd) !important;
  color: white !important;
  border: none !important;
  border-radius: 6px !important;
  font-family: inherit !important;
  font-size: 13px !important;
  font-weight: 600 !important;
  padding: 10px 22px !important;
  letter-spacing: 0.3px;
  transition: opacity 0.2s !important;
}
.gr-button-primary:hover {
  opacity: 0.88 !important;
}
.gr-button-secondary, button.secondary {
  background: #21262d !important;
  color: #e6edf3 !important;
  border: 1px solid #30363d !important;
  border-radius: 6px !important;
  font-family: inherit !important;
}
/* ── Dropdown ─────────────────────────────── */
.gr-dropdown select {
  background: #161b22 !important;
  color: #e6edf3 !important;
}
/* ── Slider ───────────────────────────────── */
.gr-slider input[type=range] {
  accent-color: #58a6ff;
}
/* ── Blocks / Panels ──────────────────────── */
.gr-block, .gr-panel, .gr-box {
  background: #161b22 !important;
  border: 1px solid #30363d !important;
  border-radius: 8px !important;
}
/* ── Log textbox ──────────────────────────── */
.log-box textarea {
  background: #0d1117 !important;
  color: #3fb950 !important;
  font-size: 11px !important;
  font-family: 'JetBrains Mono', monospace !important;
  border: 1px solid #30363d !important;
}
/* ── Status badge ─────────────────────────── */
.status-box textarea {
  background: #161b22 !important;
  color: #58a6ff !important;
  font-weight: 600 !important;
  font-size: 13px !important;
  border: 1px solid #30363d !important;
  text-align: center !important;
}
/* ── DataFrame ────────────────────────────── */
.gr-dataframe table {
  background: #161b22 !important;
  color: #e6edf3 !important;
  font-size: 12px !important;
  font-family: inherit !important;
}
.gr-dataframe th {
  background: #21262d !important;
  color: #58a6ff !important;
  border-bottom: 1px solid #30363d !important;
}
.gr-dataframe td {
  border-bottom: 1px solid #21262d !important;
}
/* ── Labels ───────────────────────────────── */
label, .gr-form label, .gr-label {
  color: #8b949e !important;
  font-size: 11px !important;
  letter-spacing: 0.5px !important;
  text-transform: uppercase !important;
}
/* ── Accordion ────────────────────────────── */
.gr-accordion {
  border: 1px solid #30363d !important;
  border-radius: 8px !important;
}
"""

HEADER_HTML = """
⚡ Tabular AutoML
AUTOMATED MACHINE LEARNING · CLASSICAL + NEURAL MODELS · BAYESIAN HPO
"""


def build_app():
    """Assemble the six-tab Gradio ``Blocks`` app and wire up its callbacks."""
    # Blocks title param available since gradio 3.9
    try:
        blocks_kwargs = dict(css=CUSTOM_CSS, title="AutoML Studio")
        gr.Blocks(**blocks_kwargs)  # test
    except TypeError:
        blocks_kwargs = dict(css=CUSTOM_CSS)

    with gr.Blocks(**blocks_kwargs) as app:
        gr.HTML(HEADER_HTML)

        with gr.Tabs():
            # ══════════════════════════════════════════════════════════════
            # TAB 1 — Upload & Explore
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("📂 Upload & Explore"):
                gr.Markdown("### Upload your CSV dataset to get started")
                with gr.Row():
                    with gr.Column(scale=1):
                        upload_btn = gr.File(label="📂 Drop CSV here (or click to browse)")
                        target_dd = gr.Dropdown(label="Target Column", choices=[],
                                                interactive=True)
                        explore_btn = gr.Button("🔍 Analyze Selected Column",
                                                variant="secondary")
                    with gr.Column(scale=2):
                        dataset_summary = gr.HTML(
                            value="\n\nUpload a CSV to see summary.\n\n")
                        col_plot = gr.Image(label="Column Distribution")

                upload_btn.change(
                    fn=handle_upload,
                    inputs=[upload_btn],
                    outputs=[target_dd, dataset_summary, col_plot]
                )
                explore_btn.click(
                    fn=show_column_stats,
                    inputs=[target_dd],
                    outputs=[col_plot]
                )

            # ══════════════════════════════════════════════════════════════
            # TAB 2 — Configure & Train
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("🚀 Configure & Train"):
                gr.Markdown("### Training Configuration")
                with gr.Row():
                    with gr.Column(scale=1):
                        task_radio = gr.Radio(["classification", "regression"],
                                              label="Task Type", value="regression")
                        time_budget = gr.Number(
                            label="Time Budget (seconds, blank = unlimited)",
                            value=300, precision=0)
                        n_trials = gr.Slider(3, 30, value=15, step=1,
                                             label="Optuna HPO Trials per Model")
                        use_fe = gr.Checkbox(label="Enable Feature Engineering",
                                             value=True)
                    with gr.Column(scale=1):
                        val_size = gr.Slider(0.05, 0.3, value=0.15, step=0.01,
                                             label="Validation Split Size")
                        test_size = gr.Slider(0.05, 0.3, value=0.15, step=0.01,
                                              label="Test Split Size")
                        seed_in = gr.Number(label="Random Seed", value=42,
                                            precision=0)

                train_btn = gr.Button("▶ Start Training", variant="primary")
                with gr.Row():
                    train_status = gr.Textbox(label="Status", interactive=False)
                with gr.Accordion("📋 Training Log", open=True):
                    log_box = gr.Textbox(label="Live Output", lines=18,
                                         interactive=False)
                    refresh_btn = gr.Button("↻ Refresh Log", variant="secondary")

                gr.Markdown("### Leaderboard Preview")
                lb_table = gr.DataFrame(label="Model Scores")
                lb_plot = gr.Image(label="Metric Comparison")

                train_btn.click(
                    fn=run_training,
                    inputs=[target_dd, task_radio, time_budget, n_trials,
                            use_fe, val_size, test_size, seed_in],
                    outputs=[train_status, log_box, lb_table, lb_plot]
                )
                refresh_btn.click(fn=poll_log, inputs=[], outputs=[log_box])

            # ══════════════════════════════════════════════════════════════
            # TAB 3 — Results & Metrics
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("📊 Results & Metrics"):
                gr.Markdown("### Best Model Performance")
                results_btn = gr.Button("Load Results", variant="secondary")
                with gr.Row():
                    with gr.Column(scale=1):
                        best_card = gr.HTML(value="\n\nRun training first.\n\n")
                    with gr.Column(scale=1):
                        metrics_plot = gr.Image(label="Metrics Chart")
                gr.Markdown("### PyTorch Training Curves")
                curves_plot = gr.Image(label="Loss Curves (neural models only)")

                results_btn.click(
                    fn=build_results_tab,
                    inputs=[],
                    outputs=[best_card, metrics_plot, curves_plot]
                )

            # ══════════════════════════════════════════════════════════════
            # TAB 4 — Feature Importance
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("🔍 Feature Importance"):
                gr.Markdown("### SHAP / Model-Based Feature Importance")
                with gr.Row():
                    top_k_slider = gr.Slider(5, 40, value=15, step=1,
                                             label="Top K Features")
                    imp_btn = gr.Button("Generate Importance Plot",
                                        variant="primary")
                importance_plot = gr.Image(label="Feature Importances")

                imp_btn.click(
                    fn=build_importance_tab,
                    inputs=[top_k_slider],
                    outputs=[importance_plot]
                )
                top_k_slider.change(
                    fn=build_importance_tab,
                    inputs=[top_k_slider],
                    outputs=[importance_plot]
                )

            # ══════════════════════════════════════════════════════════════
            # TAB 5 — Predict
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("🎯 Predict"):
                gr.Markdown("### Batch Prediction (CSV file)")
                with gr.Row():
                    with gr.Column():
                        pred_file = gr.File(label="Upload CSV for prediction")
                        pred_btn = gr.Button("Run Prediction", variant="primary")
                    with gr.Column():
                        pred_result = gr.HTML()
                        pred_dl = gr.File(label="Download Predictions CSV")

                pred_btn.click(
                    fn=predict_on_file,
                    inputs=[pred_file],
                    outputs=[pred_dl, pred_result]
                )

                gr.Markdown("---")
                gr.Markdown("### Manual Single-Row Prediction")
                gr.Markdown("Paste a JSON object with feature values:")
                with gr.Row():
                    with gr.Column():
                        template_btn = gr.Button("📋 Load Sample Row Template",
                                                 variant="secondary")
                        manual_json = gr.Textbox(
                            label="Input JSON", lines=10,
                            placeholder='{"feature_0": 1.23, "cat_A": "high", ...}')
                        manual_btn = gr.Button("Predict", variant="primary")
                    with gr.Column():
                        manual_out = gr.HTML()

                template_btn.click(
                    fn=build_manual_input_template,
                    inputs=[],
                    outputs=[manual_json]
                )
                manual_btn.click(
                    fn=predict_manual,
                    inputs=[manual_json],
                    outputs=[manual_out]
                )

            # ══════════════════════════════════════════════════════════════
            # TAB 6 — Dataset Analysis
            # ══════════════════════════════════════════════════════════════
            with gr.Tab("🧬 Dataset Analysis"):
                gr.Markdown("### Automated Dataset Visualizations")
                analysis_btn = gr.Button("Generate Analysis Plots",
                                         variant="primary")
                with gr.Row():
                    missing_plot = gr.Image(label="Missing Values")
                with gr.Row():
                    corr_plot = gr.Image(label="Correlation Matrix")
                    dtype_plot = gr.Image(label="Feature Types")

                analysis_btn.click(
                    fn=build_analysis_plots,
                    inputs=[],
                    outputs=[missing_plot, corr_plot, dtype_plot]
                )

        # Footer
        gr.HTML("""
TABULAR AUTOML · SKLEARN + PYTORCH · OPTUNA HPO · SHAP EXPLAINABILITY
""")

    return app


# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    print("\n" + "="*55)
    print(" AutoML Gradio Interface")
    print(f" Gradio version: {_gr_check.__version__}")
    print("="*55 + "\n")

    app = build_app()

    # Detect if running on HF Spaces
    on_hf_spaces = os.environ.get("SPACE_ID") is not None
    if on_hf_spaces:
        # HF Spaces: minimal launch args
        launch_kwargs = {}
    else:
        # Local: full args
        launch_kwargs = dict(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True,
        )
    # ssr_mode only exists in Gradio 5+; passing it to an older launch() would
    # raise TypeError, so guard it in both environments.
    if _gr_version[0] >= 5:
        launch_kwargs["ssr_mode"] = False
    app.launch(**launch_kwargs)