""" app.py — Gradio Interface for the Tabular AutoML Framework Run: python app.py """ import gradio as gr import pandas as pd import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import matplotlib.patches as mpatches from matplotlib.gridspec import GridSpec import io, os, sys, time, threading, json, textwrap from pathlib import Path # ── make sure the automl package is importable ──────────────────────────────── sys.path.insert(0, str(Path(__file__).parent)) from automl import AutoML # ── Gradio version compatibility check ─────────────────────────────────────── import gradio as _gr_check _gr_version = tuple(int(x) for x in _gr_check.__version__.split(".")[:2]) _is_gradio_4_plus = _gr_version[0] >= 4 _is_gradio_6_plus = _gr_version[0] >= 6 print(f" Gradio version: {_gr_check.__version__}") # gr.File type param: needed in 3.x only _FILE_KWARGS = {} if _is_gradio_4_plus else {"type": "file"} # ── Global state ────────────────────────────────────────────────────────────── _state: dict = { "automl": None, "df": None, "log_lines": [], "running": False, } # ── File path compatibility helper (Gradio 3.x / 4.x / 5.x / 6.x) ────────── def _get_filepath(file): """Handle gr.File output across ALL Gradio versions including 6.x.""" if file is None: return None # Gradio 6.x: returns plain string filepath directly if isinstance(file, str): return file # Gradio 6.x: sometimes returns a list (multiple files) if isinstance(file, list): return file[0] if file else None # Gradio 3.x early: returns dict if isinstance(file, dict): return file.get("name") or file.get("path") or file.get("tmp_path") # Gradio 4.x / 5.x: UploadData with .path attribute if hasattr(file, "path"): return file.path # Gradio 3.x late: object with .name if hasattr(file, "name"): return file.name # Last resort return str(file) PALETTE = { "bg": "#0d1117", "surface": "#161b22", "border": "#30363d", "accent": "#58a6ff", "green": "#3fb950", "yellow": "#d29922", "red": "#f85149", "text": "#e6edf3", "muted": "#8b949e", } # ───────────────────────────────────────────────────────────────────────────── # Helpers # ───────────────────────────────────────────────────────────────────────────── def _fig_to_pil(fig): buf = io.BytesIO() fig.savefig(buf, format="png", bbox_inches="tight", facecolor=fig.get_facecolor(), dpi=130) buf.seek(0) from PIL import Image img = Image.open(buf).copy() buf.close() plt.close(fig) return img def _styled_fig(w=12, h=6): fig = plt.figure(figsize=(w, h), facecolor=PALETTE["bg"]) return fig def _ax_style(ax, title="", xlabel="", ylabel=""): ax.set_facecolor(PALETTE["surface"]) ax.tick_params(colors=PALETTE["muted"], labelsize=9) ax.spines[:].set_color(PALETTE["border"]) if title: ax.set_title(title, color=PALETTE["text"], fontsize=11, pad=10, fontweight="bold") if xlabel: ax.set_xlabel(xlabel, color=PALETTE["muted"], fontsize=9) if ylabel: ax.set_ylabel(ylabel, color=PALETTE["muted"], fontsize=9) ax.tick_params(axis="x", colors=PALETTE["muted"]) ax.tick_params(axis="y", colors=PALETTE["muted"]) return ax # ───────────────────────────────────────────────────────────────────────────── # Tab 1 — Upload & Preview # ───────────────────────────────────────────────────────────────────────────── def handle_upload(file): if file is None: return (gr.update(choices=[], value=None), gr.update(value="
Upload a CSV to see summary.
"), gr.update(value=None)) try: print(f" [Upload] file type: {type(file)}, value: {repr(file)[:200]}") filepath = _get_filepath(file) print(f" [Upload] resolved filepath: {filepath}") if filepath is None: return (gr.update(choices=[], value=None), gr.update(value="Could not read file. Try uploading again.
"), gr.update(value=None)) df = pd.read_csv(filepath) _state["df"] = df cols = df.columns.tolist() # Build a rich HTML preview n_rows, n_cols = df.shape missing = df.isnull().sum().sum() dtypes = df.dtypes.value_counts().to_dict() dtype_str = ", ".join(f"{v}× {k}" for k, v in dtypes.items()) summary_html = f"""No trained model. Run training first.
") if file is None: return (None, "Upload a CSV to predict on.
") try: filepath = _get_filepath(file) new_df = pd.read_csv(filepath) preds = am.predict(new_df) new_df["prediction"] = preds out_path = "./automl_output/predictions.csv" os.makedirs("./automl_output", exist_ok=True) new_df.to_csv(out_path, index=False) n = len(preds) if am.task_type == "regression": html = f"""Error: {e}
") def predict_manual(vals_json): am = _state.get("automl") df = _state.get("df") if am is None or df is None: return "Train a model first.
" try: row = json.loads(vals_json) input_df = pd.DataFrame([row]) pred = am.predict(input_df) val = pred[0] if hasattr(pred, "__len__") else pred return f"""Error: {e}
" def build_manual_input_template(): df = _state.get("df") am = _state.get("automl") if df is None or am is None: return "{}" feature_cols = [c for c in df.columns if c != am.best_model_name] sample = df.drop(columns=[am.best_model_name] if am.best_model_name in df.columns else [], errors="ignore").iloc[0].to_dict() # Clean non-serialisable types clean = {k: (float(v) if isinstance(v, (np.floating, np.integer)) else v) for k, v in sample.items() if not isinstance(v, float) or not np.isnan(v)} return json.dumps(clean, indent=2) # ───────────────────────────────────────────────────────────────────────────── # Tab 6 — Dataset Analysis Visuals # ───────────────────────────────────────────────────────────────────────────── def build_analysis_plots(): df = _state.get("df") if df is None: return (None, None, None) # 1. Missing values heatmap style bar missing = df.isnull().mean() * 100 missing = missing[missing > 0].sort_values(ascending=False) fig1 = _styled_fig(10, max(3, len(missing) * 0.4 + 1.5)) if not missing.empty: ax = fig1.add_subplot(111) _ax_style(ax, title="Missing Values by Column (%)", xlabel="Missing %") cols_m = [PALETTE["red"] if v > 20 else PALETTE["yellow"] if v > 5 else PALETTE["accent"] for v in missing.values] bars = ax.barh(missing.index, missing.values, color=cols_m, alpha=0.85) for bar, v in zip(bars, missing.values): ax.text(bar.get_width() + 0.3, bar.get_y() + bar.get_height()/2, f"{v:.1f}%", va="center", fontsize=8, color=PALETTE["text"]) ax.set_xlim(0, max(missing.values) * 1.25) red_p = mpatches.Patch(color=PALETTE["red"], label=">20%") yel_p = mpatches.Patch(color=PALETTE["yellow"], label="5-20%") blue_p = mpatches.Patch(color=PALETTE["accent"], label="<5%") ax.legend(handles=[red_p, yel_p, blue_p], fontsize=8, facecolor=PALETTE["surface"], edgecolor=PALETTE["border"], labelcolor=PALETTE["muted"]) else: ax = fig1.add_subplot(111) _ax_style(ax, title="Missing Values") ax.text(0.5, 0.5, "✅ No missing values!", transform=ax.transAxes, ha="center", va="center", fontsize=14, color=PALETTE["green"]) fig1.tight_layout() # 2. Correlation matrix (numeric only) num_df = df.select_dtypes("number") fig2 = _styled_fig(9, 7) ax2 = fig2.add_subplot(111) if len(num_df.columns) >= 2: corr = num_df.corr() im = ax2.imshow(corr, cmap="RdYlBu_r", vmin=-1, vmax=1, aspect="auto") plt.colorbar(im, ax=ax2, fraction=0.03, pad=0.04) ticks = range(len(corr.columns)) ax2.set_xticks(ticks); ax2.set_yticks(ticks) ax2.set_xticklabels(corr.columns, rotation=45, ha="right", fontsize=7, color=PALETTE["muted"]) ax2.set_yticklabels(corr.columns, fontsize=7, color=PALETTE["muted"]) ax2.set_title("Correlation Matrix", color=PALETTE["text"], fontsize=11, fontweight="bold", pad=10) ax2.set_facecolor(PALETTE["surface"]) else: _ax_style(ax2, title="Correlation Matrix") ax2.text(0.5, 0.5, "Need ≥2 numeric columns", transform=ax2.transAxes, ha="center", va="center", color=PALETTE["muted"]) fig2.patch.set_facecolor(PALETTE["bg"]) fig2.tight_layout() # 3. Data types pie dtypes_count = {} for col in df.columns: if pd.api.types.is_numeric_dtype(df[col]): dtypes_count["Numeric"] = dtypes_count.get("Numeric", 0) + 1 elif pd.api.types.is_object_dtype(df[col]): avg_len = df[col].dropna().astype(str).str.len().mean() if avg_len > 30: dtypes_count["Text"] = dtypes_count.get("Text", 0) + 1 else: dtypes_count["Categorical"] = dtypes_count.get("Categorical", 0) + 1 else: dtypes_count["Other"] = dtypes_count.get("Other", 0) + 1 fig3 = _styled_fig(5.5, 4.5) ax3 = fig3.add_subplot(111) ax3.set_facecolor(PALETTE["bg"]) wedge_colors = [PALETTE["accent"], PALETTE["green"], PALETTE["yellow"], PALETTE["red"]][:len(dtypes_count)] wedges, texts, autotexts = ax3.pie( dtypes_count.values(), labels=dtypes_count.keys(), colors=wedge_colors, autopct="%1.0f%%", startangle=140, pctdistance=0.75, wedgeprops=dict(width=0.55, edgecolor=PALETTE["bg"], linewidth=2), ) for t in texts: t.set_color(PALETTE["muted"]); t.set_fontsize(10) for t in autotexts: t.set_color(PALETTE["bg"]); t.set_fontsize(9) ax3.set_title("Feature Type Distribution", color=PALETTE["text"], fontsize=11, fontweight="bold", pad=10) fig3.patch.set_facecolor(PALETTE["bg"]) return (_fig_to_pil(fig1), _fig_to_pil(fig2), _fig_to_pil(fig3)) # ───────────────────────────────────────────────────────────────────────────── # Build the Gradio App # ───────────────────────────────────────────────────────────────────────────── CUSTOM_CSS = """ /* ── Global ───────────────────────────────── */ body, .gradio-container { background: #0d1117 !important; font-family: 'JetBrains Mono', 'Fira Code', monospace !important; color: #e6edf3 !important; } /* ── Header ───────────────────────────────── */ .app-header { background: linear-gradient(135deg, #161b22 0%, #0d1117 100%); border-bottom: 1px solid #30363d; padding: 28px 32px 20px; margin-bottom: 8px; } .app-title { font-size: 28px; font-weight: 800; letter-spacing: -0.5px; background: linear-gradient(90deg, #58a6ff, #3fb950); -webkit-background-clip: text; -webkit-text-fill-color: transparent; } .app-subtitle { color: #8b949e; font-size: 12px; margin-top: 4px; letter-spacing: 1px; } /* ── Tabs ─────────────────────────────────── */ .tab-nav button { background: transparent !important; color: #8b949e !important; border-bottom: 2px solid transparent !important; border-radius: 0 !important; font-family: inherit !important; font-size: 12px !important; padding: 10px 18px !important; letter-spacing: 0.5px; } .tab-nav button.selected { color: #58a6ff !important; border-bottom-color: #58a6ff !important; } /* ── Inputs & Textboxes ───────────────────── */ input, textarea, select, .gr-input, .gr-textarea { background: #161b22 !important; border: 1px solid #30363d !important; color: #e6edf3 !important; border-radius: 6px !important; font-family: inherit !important; } input:focus, textarea:focus { border-color: #58a6ff !important; outline: none !important; box-shadow: 0 0 0 2px rgba(88,166,255,0.15) !important; } /* ── Buttons ──────────────────────────────── */ .gr-button-primary, button.primary { background: linear-gradient(135deg, #1f6feb, #388bfd) !important; color: white !important; border: none !important; border-radius: 6px !important; font-family: inherit !important; font-size: 13px !important; font-weight: 600 !important; padding: 10px 22px !important; letter-spacing: 0.3px; transition: opacity 0.2s !important; } .gr-button-primary:hover { opacity: 0.88 !important; } .gr-button-secondary, button.secondary { background: #21262d !important; color: #e6edf3 !important; border: 1px solid #30363d !important; border-radius: 6px !important; font-family: inherit !important; } /* ── Dropdown ─────────────────────────────── */ .gr-dropdown select { background: #161b22 !important; color: #e6edf3 !important; } /* ── Slider ───────────────────────────────── */ .gr-slider input[type=range] { accent-color: #58a6ff; } /* ── Blocks / Panels ──────────────────────── */ .gr-block, .gr-panel, .gr-box { background: #161b22 !important; border: 1px solid #30363d !important; border-radius: 8px !important; } /* ── Log textbox ──────────────────────────── */ .log-box textarea { background: #0d1117 !important; color: #3fb950 !important; font-size: 11px !important; font-family: 'JetBrains Mono', monospace !important; border: 1px solid #30363d !important; } /* ── Status badge ─────────────────────────── */ .status-box textarea { background: #161b22 !important; color: #58a6ff !important; font-weight: 600 !important; font-size: 13px !important; border: 1px solid #30363d !important; text-align: center !important; } /* ── DataFrame ────────────────────────────── */ .gr-dataframe table { background: #161b22 !important; color: #e6edf3 !important; font-size: 12px !important; font-family: inherit !important; } .gr-dataframe th { background: #21262d !important; color: #58a6ff !important; border-bottom: 1px solid #30363d !important; } .gr-dataframe td { border-bottom: 1px solid #21262d !important; } /* ── Labels ───────────────────────────────── */ label, .gr-form label, .gr-label { color: #8b949e !important; font-size: 11px !important; letter-spacing: 0.5px !important; text-transform: uppercase !important; } /* ── Accordion ────────────────────────────── */ .gr-accordion { border: 1px solid #30363d !important; border-radius: 8px !important; } """ HEADER_HTML = """Upload a CSV to see summary.
") col_plot = gr.Image(label="Column Distribution") upload_btn.change( fn=handle_upload, inputs=[upload_btn], outputs=[target_dd, dataset_summary, col_plot] ) explore_btn.click( fn=show_column_stats, inputs=[target_dd], outputs=[col_plot] ) # ══════════════════════════════════════════════════════════════ # TAB 2 — Configure & Train # ══════════════════════════════════════════════════════════════ with gr.Tab("🚀 Configure & Train"): gr.Markdown("### Training Configuration") with gr.Row(): with gr.Column(scale=1): task_radio = gr.Radio(["classification", "regression"], label="Task Type", value="regression") time_budget = gr.Number(label="Time Budget (seconds, blank = unlimited)", value=300, precision=0) n_trials = gr.Slider(3, 30, value=15, step=1, label="Optuna HPO Trials per Model") use_fe = gr.Checkbox(label="Enable Feature Engineering", value=True) with gr.Column(scale=1): val_size = gr.Slider(0.05, 0.3, value=0.15, step=0.01, label="Validation Split Size") test_size = gr.Slider(0.05, 0.3, value=0.15, step=0.01, label="Test Split Size") seed_in = gr.Number(label="Random Seed", value=42, precision=0) train_btn = gr.Button("▶ Start Training", variant="primary") with gr.Row(): train_status = gr.Textbox(label="Status", interactive=False) with gr.Accordion("📋 Training Log", open=True): log_box = gr.Textbox(label="Live Output", lines=18, interactive=False) refresh_btn = gr.Button("↻ Refresh Log", variant="secondary") gr.Markdown("### Leaderboard Preview") lb_table = gr.DataFrame(label="Model Scores") lb_plot = gr.Image(label="Metric Comparison") train_btn.click( fn=run_training, inputs=[target_dd, task_radio, time_budget, n_trials, use_fe, val_size, test_size, seed_in], outputs=[train_status, log_box, lb_table, lb_plot] ) refresh_btn.click(fn=poll_log, inputs=[], outputs=[log_box]) # ══════════════════════════════════════════════════════════════ # TAB 3 — Results & Metrics # ══════════════════════════════════════════════════════════════ with gr.Tab("📊 Results & Metrics"): gr.Markdown("### Best Model Performance") results_btn = gr.Button("Load Results", variant="secondary") with gr.Row(): with gr.Column(scale=1): best_card = gr.HTML(value="Run training first.
") with gr.Column(scale=1): metrics_plot = gr.Image(label="Metrics Chart") gr.Markdown("### PyTorch Training Curves") curves_plot = gr.Image(label="Loss Curves (neural models only)") results_btn.click( fn=build_results_tab, inputs=[], outputs=[best_card, metrics_plot, curves_plot] ) # ══════════════════════════════════════════════════════════════ # TAB 4 — Feature Importance # ══════════════════════════════════════════════════════════════ with gr.Tab("🔍 Feature Importance"): gr.Markdown("### SHAP / Model-Based Feature Importance") with gr.Row(): top_k_slider = gr.Slider(5, 40, value=15, step=1, label="Top K Features") imp_btn = gr.Button("Generate Importance Plot", variant="primary") importance_plot = gr.Image(label="Feature Importances") imp_btn.click( fn=build_importance_tab, inputs=[top_k_slider], outputs=[importance_plot] ) top_k_slider.change( fn=build_importance_tab, inputs=[top_k_slider], outputs=[importance_plot] ) # ══════════════════════════════════════════════════════════════ # TAB 5 — Predict # ══════════════════════════════════════════════════════════════ with gr.Tab("🎯 Predict"): gr.Markdown("### Batch Prediction (CSV file)") with gr.Row(): with gr.Column(): pred_file = gr.File(label="Upload CSV for prediction") pred_btn = gr.Button("Run Prediction", variant="primary") with gr.Column(): pred_result = gr.HTML() pred_dl = gr.File(label="Download Predictions CSV") pred_btn.click( fn=predict_on_file, inputs=[pred_file], outputs=[pred_dl, pred_result] ) gr.Markdown("---") gr.Markdown("### Manual Single-Row Prediction") gr.Markdown("Paste a JSON object with feature values:") with gr.Row(): with gr.Column(): template_btn = gr.Button("📋 Load Sample Row Template", variant="secondary") manual_json = gr.Textbox(label="Input JSON", lines=10, placeholder='{"feature_0": 1.23, "cat_A": "high", ...}') manual_btn = gr.Button("Predict", variant="primary") with gr.Column(): manual_out = gr.HTML() template_btn.click( fn=build_manual_input_template, inputs=[], outputs=[manual_json] ) manual_btn.click( fn=predict_manual, inputs=[manual_json], outputs=[manual_out] ) # ══════════════════════════════════════════════════════════════ # TAB 6 — Dataset Analysis # ══════════════════════════════════════════════════════════════ with gr.Tab("🧬 Dataset Analysis"): gr.Markdown("### Automated Dataset Visualizations") analysis_btn = gr.Button("Generate Analysis Plots", variant="primary") with gr.Row(): missing_plot = gr.Image(label="Missing Values") with gr.Row(): corr_plot = gr.Image(label="Correlation Matrix") dtype_plot = gr.Image(label="Feature Types") analysis_btn.click( fn=build_analysis_plots, inputs=[], outputs=[missing_plot, corr_plot, dtype_plot] ) # Footer gr.HTML("""