#!/usr/bin/env python3 """Custom web application for the drum sample extractor. Run with: uvicorn app:app --host 0.0.0.0 --port 7860 """ from __future__ import annotations import asyncio import json import shutil import time import traceback import uuid from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict from pathlib import Path from threading import Lock from typing import Any from fastapi import Body, FastAPI, File, Form, HTTPException, UploadFile from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse, JSONResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from pipeline_runner import PipelineParams, SPLEETER_MODELS, SPLEETER_STEMS, SEPARATION_BACKENDS, clear_disk_cache, initial_stages, run_extraction_pipeline, separation_runtime_status from sample_extractor import DEMUCS_MODELS, DEMUCS_STEMS, cache_clear from supervised_state import ( accept_suggestion, explain_cluster as build_cluster_explanation, force_onset as apply_force_onset, load_or_create_state, lock_cluster as apply_cluster_lock, move_hit as apply_hit_move, public_state, pull_hit_to_new_cluster, reject_suggestion, restore_hit as apply_hit_restore, set_hit_review_status, suppress_hit as apply_hit_suppression, draw_next_representative as apply_draw_next_representative, edit_hit_timing as apply_hit_timing_edit, undo_last as apply_undo, ) from supervised_export import export_selected_samples, export_supervised_state ROOT = Path(__file__).resolve().parent WEB_DIR = ROOT / "web" RUNS_DIR = ROOT / ".runs" RUNS_DIR.mkdir(exist_ok=True) app = FastAPI(title="Drum Sample Extractor", version="15.0.0") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=False, allow_methods=["*"], allow_headers=["*"], ) executor = ThreadPoolExecutor(max_workers=1) jobs_lock = Lock() jobs: dict[str, dict[str, Any]] = {} def _job_url(job_id: str, relative_path: str) -> str: return f"/api/jobs/{job_id}/files/{relative_path}" 
def _serialise_export(job_id: str, export_manifest: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of an export manifest with download URLs attached.

    Adds ``file_urls`` (one URL per entry of ``files``) and rewrites
    ``samples`` so that every sample carries a ``url`` built from its
    ``file`` entry. The input manifest itself is not modified.
    """
    payload = dict(export_manifest)
    payload["file_urls"] = {key: _job_url(job_id, path) for key, path in payload.get("files", {}).items()}
    payload["samples"] = [
        {**sample, "url": _job_url(job_id, sample["file"])}
        for sample in payload.get("samples", [])
    ]
    return payload


def _serialise_job(job: dict[str, Any]) -> dict[str, Any]:
    """Build the client-facing representation of a job record.

    The server-side ``input_path`` and ``output_dir`` entries are dropped,
    a progress summary is derived from the stages when none is stored, and
    download URLs are attached to partial samples, result files, result
    samples and result hits.
    """
    payload = {key: value for key, value in job.items() if key not in {"input_path", "output_dir"}}
    if payload.get("progress") is None:
        payload["progress"] = _progress_from_stages(payload.get("stages", []), payload.get("status"))
    if payload.get("partial_samples"):
        payload["partial_samples"] = [
            {**sample, "url": _job_url(job["id"], sample["file"])}
            for sample in payload.get("partial_samples", [])
        ]
    if payload.get("result"):
        # Copy before decorating so the stored result stays untouched.
        result = dict(payload["result"])
        result["file_urls"] = {key: _job_url(job["id"], path) for key, path in result.get("files", {}).items()}
        result["samples"] = [
            {**sample, "url": _job_url(job["id"], sample["file"])}
            for sample in result.get("samples", [])
        ]
        result["hits"] = [
            {**hit, "url": _job_url(job["id"], hit["file"])}
            for hit in result.get("hits", [])
        ]
        payload["result"] = result
    return payload


def _progress_from_stages(stages: list[dict[str, Any]], status: str | None = None) -> dict[str, Any]:
    """Summarise overall progress from a list of stage dictionaries.

    Each stage contributes ``work_total`` units (1 when missing or falsy).
    Stages marked ``done`` count in full; a ``running`` stage counts its
    ``work_done`` when both ``work_total`` and ``work_done`` are available,
    otherwise the running stage's ``progress`` fraction times its units.

    Args:
        stages: Stage dictionaries with ``status`` and optional ``progress``,
            ``work_done``, ``work_total``, ``key`` and ``label`` entries.
        status: Overall job status; ``"complete"`` forces the summary to 100%.

    Returns:
        A dictionary with the overall fraction, step/unit counters and the
        details of the first running stage (``None`` values when no stage is
        running).
    """
    running = next((stage for stage in stages if stage.get("status") == "running"), None)
    running_progress = float(running.get("progress") or 0.0) if running else 0.0
    total = 0.0
    completed_units = 0.0
    completed = 0
    for stage in stages:
        units = float(stage.get("work_total") or 1)
        total += units
        stage_status = stage.get("status")
        if stage_status == "done":
            completed += 1
            completed_units += units
        elif stage_status == "running":
            if stage.get("work_total") and stage.get("work_done") is not None:
                completed_units += float(stage.get("work_done"))
            else:
                completed_units += running_progress * units
    # Clamp into [0, 1]; max(total, 1.0) also guards the empty-stage case.
    fraction = min(1.0, max(0.0, completed_units / max(total, 1.0)))
    if status == "complete":
        fraction = 1.0
        completed_units = total
    return {
        "fraction": round(fraction, 6),
        "completed_steps": completed,
        "total_steps": len(stages),
        "completed_units": round(completed_units, 6),
        "total_units": round(total, 6),
        "stage_key": running.get("key") if running else None,
        "stage_label": running.get("label") if running else None,
        "stage_fraction": round(running_progress, 6),
        "stage_work_done": running.get("work_done") if running else None,
        "stage_work_total": running.get("work_total") if running else None,
        "basis": "exact completed work units: Demucs chunks when available, otherwise stage boundary units; no time-based estimates",
    }


def _manifest_path(job_id: str) -> Path:
    """Return the path of the manifest file inside a job's output directory."""
    return RUNS_DIR / job_id / "output" / "manifest.json"


def _read_manifest_job(job_id: str) -> dict[str, Any] | None:
    """Rebuild a completed job record from its manifest on disk.

    Returns:
        A job dictionary with status ``"complete"``, or ``None`` when the
        job has no manifest file.

    Raises:
        OSError, ValueError: When the manifest cannot be read or is not
            valid JSON (propagated from reading/parsing the file).
    """
    manifest = _manifest_path(job_id)
    if not manifest.exists():
        return None
    result = json.loads(manifest.read_text(encoding="utf-8"))
    return {
        "id": job_id,
        "status": "complete",
        "filename": result.get("source", {}).get("filename"),
        "params": result.get("params", {}),
        "stages": result.get("stages", []),
        "progress": _progress_from_stages(result.get("stages", []), "complete"),
        "logs": [],
        "result": result,
        "error": None,
        "traceback": None,
        "output_dir": str(manifest.parent),
    }


def _summarise_job(job: dict[str, Any]) -> dict[str, Any]:
    """Return the compact row used by the job listing for one job record."""
    result = job.get("result") or {}
    # Parameters stored with the result take precedence over the job's own.
    params = result.get("params") or job.get("params") or {}
    return {
        "id": job["id"],
        "status": job.get("status"),
        "filename": job.get("filename"),
        "created_at": job.get("created_at"),
        "duration_sec": result.get("duration_sec"),
        "audio_duration_sec": result.get("audio_duration_sec"),
        "realtime_factor": result.get("realtime_factor"),
        "bpm": result.get("bpm"),
        "hit_count": result.get("hit_count"),
        "cluster_count": result.get("cluster_count"),
        "clustering_mode": params.get("clustering_mode"),
        "stem": params.get("stem"),
        "error": job.get("error"),
    }


def _list_manifest_jobs(limit: int = 50) -> list[dict[str, Any]]:
    """List summaries of jobs that have a manifest on disk, newest first.

    Manifests are ordered by modification time, which is also reported as
    ``created_at``. A manifest that cannot be read or parsed is skipped so
    that one damaged run does not break the whole listing.

    Args:
        limit: Maximum number of summaries to return.
    """
    rows: list[dict[str, Any]] = []
    manifests = sorted(
        RUNS_DIR.glob("*/output/manifest.json"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    for manifest in manifests:
        job_id = manifest.parents[1].name
        try:
            manifest_job = _read_manifest_job(job_id)
        except (OSError, ValueError):
            # Unreadable or corrupt manifest (JSONDecodeError is a ValueError).
            continue
        if not manifest_job:
            continue
        summary = _summarise_job(manifest_job)
        summary["created_at"] = manifest.stat().st_mtime
        rows.append(summary)
        if len(rows) >= limit:
            break
    return rows


def _update_job(job_id: str, **patch: Any) -> None:
    """Merge ``patch`` into the in-memory job record while holding the lock."""
    with jobs_lock:
        jobs[job_id].update(patch)


def _append_log(job_id: str, message: str) -> None:
    """Append one message to the job's log list while holding the lock."""
    with jobs_lock:
        jobs[job_id].setdefault("logs", []).append(message)


def _run_job(job_id: str) -> None:
    """Run the extraction pipeline for a registered job.

    Marks the job as running, forwards pipeline progress events into the
    in-memory job record and log, and finally stores either the result or
    the error text together with its traceback.
    """
    with jobs_lock:
        job = jobs[job_id]
        input_path = Path(job["input_path"])
        output_dir = Path(job["output_dir"])
        params = job["params"]
        job["status"] = "running"

    def progress(event: dict[str, Any]) -> None:
        """Translate one pipeline event into job-record updates and log lines."""
        patch: dict[str, Any] = {}
        if "stages" in event:
            patch["stages"] = event["stages"]
        if "progress" in event:
            patch["progress"] = event["progress"]
        elif "stages" in event:
            # No explicit progress supplied: derive it from the stage list.
            patch["progress"] = _progress_from_stages(event["stages"], "running")
        if event.get("type") == "sample":
            patch["partial_samples"] = event.get("samples", [])
        if patch:
            _update_job(job_id, **patch)
        if event.get("type") == "sample" and event.get("sample"):
            sample = event["sample"]
            _append_log(job_id, f"Sample ready: {sample.get('label')} ({sample.get('classification')})")
        if event.get("stage"):
            stage = event["stage"]
            if stage.get("status") == "running":
                _append_log(job_id, f"Started: {stage['label']}")
            elif stage.get("status") == "done":
                # The separator is a middle dot; the original literal was mojibake.
                detail = f" · {stage['detail']}" if stage.get("detail") else ""
                _append_log(job_id, f"Finished: {stage['label']} in {stage['duration_sec']:.3f}s{detail}")

    try:
        result = run_extraction_pipeline(input_path, output_dir, PipelineParams.from_mapping(params), progress_cb=progress)
        load_or_create_state(job_id, output_dir)
        _update_job(job_id, status="complete", result=asdict(result), error=None)
    except Exception as exc:  # deliberately explicit for UI diagnostics
        _update_job(job_id, status="error", error=str(exc), traceback=traceback.format_exc())
        _append_log(job_id, f"Error: {exc}")


@app.get("/api/health")
def health() -> dict[str, str]:
    """Liveness probe; always reports ``ok``."""
    return {"status": "ok"}


@app.get("/api/config")
def config() -> dict[str, Any]:
    """Return the option lists, defaults and runtime status for the UI."""
    return {
        "separation_backends": SEPARATION_BACKENDS,
        "spleeter_models": SPLEETER_MODELS,
        "spleeter_stems": {key: value + ["all"] for key, value in SPLEETER_STEMS.items()},
        "demucs_models": DEMUCS_MODELS,
        "demucs_stems": {key: value + ["all"] for key, value in DEMUCS_STEMS.items()},
        "defaults": asdict(PipelineParams()),
        "runtime": separation_runtime_status(),
        "stages": initial_stages(),
        "clustering_modes": {
            "batch_quality": "Batch quality: all-pairs mel/NCC + agglomerative clustering",
            "online_preview": "Online preview: prototype assignment for near-realtime feedback",
        },
    }


@app.post("/api/cache/clear")
def clear_cache() -> dict[str, str]:
    """Clear both caches by calling ``cache_clear`` and ``clear_disk_cache``."""
    cache_clear()
    clear_disk_cache()
    return {"status": "cleared", "scope": "memory+disk"}


@app.get("/api/jobs")
def list_jobs(limit: int = 50) -> dict[str, Any]:
    """List non-complete in-memory jobs and the on-disk job history.

    Args:
        limit: Maximum number of history rows; clamped to the range 1..200.
    """
    limit = max(1, min(int(limit), 200))
    with jobs_lock:
        active = [_summarise_job(dict(job)) for job in jobs.values() if job.get("status") != "complete"]
    history = _list_manifest_jobs(limit=limit)
    return {"active": active, "history": history}


@app.post("/api/jobs")
async def create_job(file: UploadFile = File(...), params: str = Form("{}")) -> JSONResponse:
    """Store an uploaded audio file, register a job and queue its extraction.

    Args:
        file: Uploaded audio file; a non-empty extension must be one of the
            supported audio extensions.
        params: JSON object with extraction parameters.

    Returns:
        The serialised job with HTTP status 202.

    Raises:
        HTTPException: 400 for an unsupported extension or invalid parameters.
    """
    filename = file.filename or "input.wav"
    suffix = Path(filename).suffix.lower()
    if suffix and suffix not in {".wav", ".mp3", ".flac", ".aiff", ".aif", ".ogg", ".m4a", ".aac"}:
        raise HTTPException(status_code=400, detail=f"Unsupported audio file extension: {suffix}")
    try:
        parsed_params = json.loads(params)
        validated = PipelineParams.from_mapping(parsed_params)
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid params JSON: {exc.msg}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid extraction parameter: {exc}") from exc
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid extraction request: {exc}") from exc

    job_id = uuid.uuid4().hex[:12]
    job_dir = RUNS_DIR / job_id
    input_dir = job_dir / "input"
    output_dir = job_dir / "output"
    input_dir.mkdir(parents=True, exist_ok=True)
    output_dir.mkdir(parents=True, exist_ok=True)
    suffix = suffix or ".wav"
    input_path = input_dir / f"source{suffix}"
    with input_path.open("wb") as handle:
        # The copy is blocking file I/O; run it in a worker thread so this
        # coroutine does not stall the event loop (and with it every open
        # event stream) while a large upload is written to disk.
        await asyncio.to_thread(shutil.copyfileobj, file.file, handle)

    job = {
        "id": job_id,
        "status": "pending",
        "filename": file.filename,
        "created_at": time.time(),
        "params": asdict(validated),
        "stages": initial_stages(),
        "progress": _progress_from_stages(initial_stages(), "pending"),
        "logs": [],
        "result": None,
        "partial_samples": [],
        "error": None,
        "traceback": None,
        "input_path": str(input_path),
        "output_dir": str(output_dir),
    }
    with jobs_lock:
        jobs[job_id] = job
    executor.submit(_run_job, job_id)
    return JSONResponse(_serialise_job(job), status_code=202)


@app.get("/api/jobs/{job_id}")
def get_job(job_id: str) -> dict[str, Any]:
    """Return a job from memory, falling back to its manifest on disk.

    Raises:
        HTTPException: 404 when the job is known neither in memory nor on disk.
    """
    with jobs_lock:
        job = jobs.get(job_id)
        # Snapshot under the lock so concurrent updates cannot interleave
        # with the copy.
        snapshot = dict(job) if job else None
    if snapshot:
        return _serialise_job(snapshot)
    manifest_job = _read_manifest_job(job_id)
    if manifest_job:
        return _serialise_job(manifest_job)
    raise HTTPException(status_code=404, detail="Job not found")


def _job_output_dir(job_id: str) -> Path:
    """Resolve the output directory of a job.

    Raises:
        HTTPException: 404 when the job is known neither in memory nor on disk.
    """
    with jobs_lock:
        job = jobs.get(job_id)
        if job and job.get("output_dir"):
            return Path(job["output_dir"])
    manifest = _manifest_path(job_id)
    if manifest.exists():
        return manifest.parent
    raise HTTPException(status_code=404, detail="Job not found")


def _state_payload(job_id: str) -> dict[str, Any]:
    """Load the supervised state of a job and return its public form.

    Raises:
        HTTPException: 404 for an unknown job, 409 when the state cannot be
            created because a required file is missing, 500 for any other
            failure while loading the state.
    """
    out = _job_output_dir(job_id)
    try:
        state = load_or_create_state(job_id, out)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=409, detail="Job has no manifest yet; wait until extraction completes") from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return public_state(state, url_for=lambda rel: _job_url(job_id, rel))


def _json_patch(payload: dict[str, Any] | None) -> dict[str, Any]:
    """Return a shallow copy of a request body, treating ``None`` as empty."""
    return dict(payload or {})


def _state_for_mutation(job_id: str) -> tuple[Path, dict[str, Any]]:
    """Return the output directory and loaded state of a job.

    Raises:
        HTTPException: 404 for an unknown job, 409 when the state cannot be
            created because a required file is missing.
    """
    out = _job_output_dir(job_id)
    try:
        return out, load_or_create_state(job_id, out)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=409, detail="Job has no manifest yet; wait until extraction completes") from exc


def _cluster_id_for_sample_label(state: dict[str, Any], sample_label: str) -> str:
    """Map a sample label to the id of the matching cluster in ``state``.

    An exact label match wins. Otherwise the label without its last
    ``_``-separated part is compared against cluster classifications, and
    that match is accepted only when it is unique.

    Raises:
        HTTPException: 404 when no cluster can be determined.
    """
    clusters = state.get("clusters", {})
    exact = [cid for cid, cluster in clusters.items() if str(cluster.get("label")) == str(sample_label)]
    if exact:
        return exact[0]
    # Fall back to a classification/base-name match for labels that have been renamed by user edits.
    base = str(sample_label).rsplit("_", 1)[0]
    fuzzy = [cid for cid, cluster in clusters.items() if str(cluster.get("classification") or "") == base]
    if len(fuzzy) == 1:
        return fuzzy[0]
    raise HTTPException(status_code=404, detail=f"Sample label not found in current state: {sample_label}")


def _public_sample_from_cluster(job_id: str, state: dict[str, Any], cluster_id: str, label_override: str | None = None) -> dict[str, Any]:
    """Build the sample-card payload for one cluster of the supervised state.

    The representative hit is the cluster's ``representative_hit_id`` when
    that hit is still active (present and not suppressed), otherwise the
    first active hit.

    Args:
        job_id: Job that owns the state; used to build the file URL.
        state: Supervised state with ``clusters`` and ``hits`` mappings.
        cluster_id: Key of the cluster inside ``state["clusters"]``.
        label_override: Label to report instead of the cluster's own label.

    Raises:
        HTTPException: 404 for an unknown cluster, 409 when the cluster has
            no active hits.
    """
    clusters = state.get("clusters", {})
    hits = state.get("hits", {})
    if cluster_id not in clusters:
        raise HTTPException(status_code=404, detail=f"Unknown cluster: {cluster_id}")
    cluster = clusters[cluster_id]
    active_ids = [hid for hid in cluster.get("hit_ids", []) if hid in hits and not hits[hid].get("suppressed")]
    if not active_ids:
        raise HTTPException(status_code=409, detail=f"Cluster {cluster.get('label', cluster_id)} has no active hits")
    rep_id = cluster.get("representative_hit_id") if cluster.get("representative_hit_id") in active_ids else active_ids[0]
    hit = hits[rep_id]
    # State cluster ids may carry a "prefix:" part; report the bare id, as an
    # integer whenever it parses as one.
    raw_cluster_id = cluster_id.split(":", 1)[1] if ":" in cluster_id else cluster_id
    try:
        raw_cluster_id_value: int | str = int(raw_cluster_id)
    except ValueError:
        raw_cluster_id_value = raw_cluster_id
    file_path = str(hit.get("file") or "")
    first_onset = min(float(hits[hid].get("onset_sec") or 0.0) for hid in active_ids)
    return {
        "label": label_override or cluster.get("label") or str(cluster_id),
        "classification": cluster.get("classification") or str(cluster.get("label") or "other").rsplit("_", 1)[0],
        "hits": len(active_ids),
        "midi_note": cluster.get("midi_note", 60),
        "score": "edited",
        "duration_ms": round(float(hit.get("duration_ms") or 0.0), 1),
        "first_onset_sec": round(first_onset, 4),
        "representative_hit_index": int(hit.get("index") or 0),
        "state_hit_id": rep_id,
        "cluster_id": raw_cluster_id_value,
        "state_cluster_id": cluster_id,
        "file": file_path,
        "url": _job_url(job_id, file_path) if file_path else None,
    }


@app.get("/api/jobs/{job_id}/events")
def get_job_events(job_id: str) -> StreamingResponse:
    """Stream job updates as server-sent events until the job finishes.

    The job is polled every 0.5 seconds; a ``job`` event is emitted whenever
    the serialised payload changes, and the stream ends once the status is
    ``complete`` or ``error``.

    Raises:
        HTTPException: 404 when the job is known neither in memory nor on disk.
    """
    with jobs_lock:
        exists_in_memory = job_id in jobs
    exists_on_disk = _read_manifest_job(job_id) is not None
    if not exists_in_memory and not exists_on_disk:
        raise HTTPException(status_code=404, detail="Job not found")

    async def event_stream():
        last_payload: str | None = None
        while True:
            with jobs_lock:
                memory_job = jobs.get(job_id)
                job = dict(memory_job) if memory_job else None
            if job is None:
                job = _read_manifest_job(job_id)
            if job is None:
                payload = {"id": job_id, "status": "error", "error": "Job disappeared"}
            else:
                payload = _serialise_job(job)
            # sort_keys makes the encoding stable so that unchanged payloads
            # compare equal and are not re-sent.
            encoded = json.dumps(payload, sort_keys=True)
            if encoded != last_payload:
                yield f"event: job\ndata: {encoded}\n\n"
                last_payload = encoded
            if payload.get("status") in {"complete", "error"}:
                break
            await asyncio.sleep(0.5)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )


@app.get("/api/jobs/{job_id}/state")
def get_job_state(job_id: str) -> dict[str, Any]:
    """Return the public supervised state of a job."""
    return _state_payload(job_id)


@app.post("/api/jobs/{job_id}/export")
def post_supervised_export(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Export the supervised state of a job and return the export manifest.

    Body keys: ``synthesize`` (default true), ``quantize``, ``subdivision``.

    Raises:
        HTTPException: 404 for an unknown job, 500 when the export fails.
    """
    patch = _json_patch(payload)
    try:
        export_manifest = export_supervised_state(
            _job_output_dir(job_id),
            job_id,
            synthesize=bool(patch.get("synthesize", True)),
            quantize=patch.get("quantize"),
            subdivision=patch.get("subdivision"),
        )
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"export": _serialise_export(job_id, export_manifest), "state": _state_payload(job_id)}


@app.post("/api/jobs/{job_id}/export-selected")
def post_selected_export(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Export only the samples whose labels are listed in the request body.

    Body keys: ``labels`` (required, non-empty), ``synthesize`` (default
    true), ``quantize``, ``subdivision``.

    Raises:
        HTTPException: 400 for missing labels or a ``ValueError`` from the
            export, 404 for an unknown job, 500 for any other failure.
    """
    patch = _json_patch(payload)
    labels = [str(item) for item in patch.get("labels", []) if str(item).strip()]
    if not labels:
        raise HTTPException(status_code=400, detail="labels must contain at least one selected sample label")
    try:
        export_manifest = export_selected_samples(
            _job_output_dir(job_id),
            job_id,
            selected_labels=labels,
            synthesize=bool(patch.get("synthesize", True)),
            quantize=patch.get("quantize"),
            subdivision=patch.get("subdivision"),
        )
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"export": _serialise_export(job_id, export_manifest), "state": _state_payload(job_id)}


@app.post("/api/jobs/{job_id}/samples/{sample_label:path}/draw")
def post_draw_sample(job_id: str, sample_label: str) -> dict[str, Any]:
    """Draw the next representative hit for the cluster behind a sample card.

    Raises:
        HTTPException: 404 for an unknown job/label or a ``KeyError``, 409
            for a ``ValueError``, 500 for any other failure.
    """
    try:
        out, state = _state_for_mutation(job_id)
        cluster_id = _cluster_id_for_sample_label(state, sample_label)
        state = apply_draw_next_representative(out, job_id, cluster_id, source="sample-card")
        return {
            "sample": _public_sample_from_cluster(job_id, state, cluster_id, label_override=sample_label),
            "state": public_state(state, url_for=lambda rel: _job_url(job_id, rel)),
        }
    except HTTPException:
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/api/jobs/{job_id}/samples/{sample_label:path}/edit")
def post_edit_sample(job_id: str, sample_label: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Edit the timing of the representative hit behind a sample card.

    Body keys: ``start_offset_ms`` and ``tail_offset_ms`` (both default 0.0).

    Raises:
        HTTPException: 404 for an unknown job/label or a ``KeyError``, 409
            when the sample has no active representative hit, 400 for a
            ``ValueError``, 500 for any other failure.
    """
    patch = _json_patch(payload)
    try:
        out, state = _state_for_mutation(job_id)
        cluster_id = _cluster_id_for_sample_label(state, sample_label)
        cluster = state.get("clusters", {}).get(cluster_id) or {}
        active_ids = [
            hid
            for hid in cluster.get("hit_ids", [])
            if hid in state.get("hits", {}) and not state["hits"][hid].get("suppressed")
        ]
        rep_id = cluster.get("representative_hit_id") if cluster.get("representative_hit_id") in active_ids else (active_ids[0] if active_ids else None)
        if not rep_id:
            raise HTTPException(status_code=409, detail=f"Sample {sample_label} has no active representative hit")
        state = apply_hit_timing_edit(
            out,
            job_id,
            rep_id,
            start_offset_ms=float(patch.get("start_offset_ms", 0.0)),
            tail_offset_ms=float(patch.get("tail_offset_ms", 0.0)),
            source="sample-card",
        )
        return {
            "sample": _public_sample_from_cluster(job_id, state, cluster_id, label_override=sample_label),
            "state": public_state(state, url_for=lambda rel: _job_url(job_id, rel)),
        }
    except HTTPException:
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.post("/api/jobs/{job_id}/hits/force-onset")
def post_force_onset(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Force a hit at a given onset time and return the updated state.

    Body keys: ``onset_sec`` (required), ``duration_ms``, ``label``,
    ``target_cluster_id``, ``pre_pad_sec`` (default 0.003).

    Raises:
        HTTPException: 400 for a missing ``onset_sec`` or a ``ValueError``,
            404 for an unknown job or a ``KeyError``, 500 otherwise.
    """
    patch = _json_patch(payload)
    if "onset_sec" not in patch:
        raise HTTPException(status_code=400, detail="onset_sec is required")
    try:
        apply_force_onset(
            _job_output_dir(job_id),
            job_id,
            float(patch["onset_sec"]),
            duration_ms=patch.get("duration_ms"),
            label=patch.get("label"),
            target_cluster_id=patch.get("target_cluster_id"),
            pre_pad_sec=float(patch.get("pre_pad_sec", 0.003)),
        )
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.post("/api/jobs/{job_id}/hits/{hit_id}/move")
def post_move_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Move a hit into the cluster named by ``target_cluster_id``.

    Raises:
        HTTPException: 400 for a missing ``target_cluster_id``, 404 for an
            unknown job or a ``KeyError``, 500 otherwise.
    """
    target_cluster_id = _json_patch(payload).get("target_cluster_id")
    if not target_cluster_id:
        raise HTTPException(status_code=400, detail="target_cluster_id is required")
    try:
        apply_hit_move(_job_output_dir(job_id), job_id, hit_id, str(target_cluster_id))
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.post("/api/jobs/{job_id}/hits/{hit_id}/pull-out")
def post_pull_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Pull a hit out into a new cluster, optionally named by ``label``.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 otherwise.
    """
    label = _json_patch(payload).get("label")
    try:
        pull_hit_to_new_cluster(_job_output_dir(job_id), job_id, hit_id, label=label)
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.post("/api/jobs/{job_id}/hits/{hit_id}/suppress")
def post_suppress_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Suppress a hit; the ``reason`` body key defaults to ``"bleed"``.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 otherwise.
    """
    reason = str(_json_patch(payload).get("reason") or "bleed")
    try:
        apply_hit_suppression(_job_output_dir(job_id), job_id, hit_id, reason=reason)
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.post("/api/jobs/{job_id}/hits/{hit_id}/restore")
def post_restore_hit(job_id: str, hit_id: str) -> dict[str, Any]:
    """Restore a hit and return the updated state.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 otherwise.
    """
    try:
        apply_hit_restore(_job_output_dir(job_id), job_id, hit_id)
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.post("/api/jobs/{job_id}/hits/{hit_id}/review")
def post_review_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Set the review status of a hit; ``status`` defaults to ``"accepted"``.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 400 for a
            ``ValueError``, 500 otherwise.
    """
    status = str(_json_patch(payload).get("status") or "accepted")
    try:
        set_hit_review_status(_job_output_dir(job_id), job_id, hit_id, status=status)
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.post("/api/jobs/{job_id}/clusters/{cluster_id:path}/lock")
def post_lock_cluster(job_id: str, cluster_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Lock or unlock a cluster; the ``locked`` body key defaults to true.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 otherwise.
    """
    locked = bool(_json_patch(payload).get("locked", True))
    try:
        apply_cluster_lock(_job_output_dir(job_id), job_id, cluster_id, locked=locked)
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.get("/api/jobs/{job_id}/suggestions")
def get_suggestions(job_id: str) -> dict[str, Any]:
    """Return the ``suggestions`` and ``summary`` parts of the public state."""
    state = _state_payload(job_id)
    return {"suggestions": state.get("suggestions", []), "summary": state.get("summary", {})}


@app.post("/api/jobs/{job_id}/suggestions/{suggestion_id}/accept")
def post_accept_suggestion(job_id: str, suggestion_id: str) -> dict[str, Any]:
    """Accept a suggestion and return the updated state.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 400 for a
            ``ValueError``, 500 otherwise.
    """
    try:
        accept_suggestion(_job_output_dir(job_id), job_id, suggestion_id)
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.post("/api/jobs/{job_id}/suggestions/{suggestion_id}/reject")
def post_reject_suggestion(job_id: str, suggestion_id: str) -> dict[str, Any]:
    """Reject a suggestion and return the updated state.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 otherwise.
    """
    try:
        reject_suggestion(_job_output_dir(job_id), job_id, suggestion_id)
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.get("/api/jobs/{job_id}/explain/cluster/{cluster_id:path}")
def get_cluster_explanation(job_id: str, cluster_id: str) -> dict[str, Any]:
    """Return the explanation built for one cluster of a job's state.

    Raises:
        HTTPException: 404 for an unknown job or cluster, 409 when the state
            cannot be created because a required file is missing.
    """
    # Load through the shared helper so a missing manifest yields the same
    # 409 as the other state endpoints instead of an unhandled error.
    _, state = _state_for_mutation(job_id)
    try:
        explanation = build_cluster_explanation(state, cluster_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    return explanation


@app.post("/api/jobs/{job_id}/undo")
def post_undo(job_id: str) -> dict[str, Any]:
    """Undo the last state change of a job and return the updated state.

    Raises:
        HTTPException: 404 for an unknown job, 500 when the undo fails.
    """
    try:
        apply_undo(_job_output_dir(job_id), job_id)
    except HTTPException:
        # Keep the 404 from the job lookup instead of masking it as a 500.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)


@app.get("/api/jobs/{job_id}/files/{relative_path:path}")
def get_job_file(job_id: str, relative_path: str) -> FileResponse:
    """Serve a file from a job's output directory.

    Both the job directory and the requested file are resolved and checked
    for containment, so neither ``job_id`` nor ``relative_path`` can point
    outside the runs directory.

    Raises:
        HTTPException: 404 when the path escapes its root or is not a file.
    """
    runs_root = RUNS_DIR.resolve()
    root = (RUNS_DIR / job_id / "output").resolve()
    path = (root / relative_path).resolve()
    try:
        # A job_id such as ".." would otherwise move the root out of RUNS_DIR.
        root.relative_to(runs_root)
        path.relative_to(root)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail="File not found") from exc
    if not path.exists() or not path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(path)


# Static front-end assets are only mounted when the web directory exists.
if WEB_DIR.exists():
    app.mount("/web", StaticFiles(directory=WEB_DIR), name="web")


@app.get("/")
def index() -> FileResponse:
    """Serve the single-page front end.

    Raises:
        HTTPException: 500 when ``web/index.html`` is missing.
    """
    index_path = WEB_DIR / "index.html"
    if not index_path.exists():
        raise HTTPException(status_code=500, detail="web/index.html is missing")
    return FileResponse(index_path)