ChatGPT
fix: make upload and fallback robust
5a90820
#!/usr/bin/env python3
"""Custom web application for the drum sample extractor.
Run with:
uvicorn app:app --host 0.0.0.0 --port 7860
"""
from __future__ import annotations
import asyncio
import json
import shutil
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
from pathlib import Path
from threading import Lock
from typing import Any
from fastapi import Body, FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pipeline_runner import PipelineParams, SPLEETER_MODELS, SPLEETER_STEMS, SEPARATION_BACKENDS, clear_disk_cache, initial_stages, run_extraction_pipeline, separation_runtime_status
from sample_extractor import DEMUCS_MODELS, DEMUCS_STEMS, cache_clear
from supervised_state import (
accept_suggestion,
explain_cluster as build_cluster_explanation,
force_onset as apply_force_onset,
load_or_create_state,
lock_cluster as apply_cluster_lock,
move_hit as apply_hit_move,
public_state,
pull_hit_to_new_cluster,
reject_suggestion,
restore_hit as apply_hit_restore,
set_hit_review_status,
suppress_hit as apply_hit_suppression,
draw_next_representative as apply_draw_next_representative,
edit_hit_timing as apply_hit_timing_edit,
undo_last as apply_undo,
)
from supervised_export import export_selected_samples, export_supervised_state
# Filesystem layout: every path is resolved relative to the directory containing this file.
ROOT = Path(__file__).resolve().parent
# Front-end assets directory (mounted under /web further down when it exists).
WEB_DIR = ROOT / "web"
# Per-job working directories are created under .runs/<job_id>/ by create_job.
RUNS_DIR = ROOT / ".runs"
RUNS_DIR.mkdir(exist_ok=True)
app = FastAPI(title="Drum Sample Extractor", version="15.0.0")
# CORS accepts any origin, method and header; credentials are disabled.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=False,
allow_methods=["*"],
allow_headers=["*"],
)
# Single worker thread: jobs submitted to this executor run one at a time.
executor = ThreadPoolExecutor(max_workers=1)
# In-memory job registry keyed by job id; accesses in this file are guarded by jobs_lock.
jobs_lock = Lock()
jobs: dict[str, dict[str, Any]] = {}
def _job_url(job_id: str, relative_path: str) -> str:
return f"/api/jobs/{job_id}/files/{relative_path}"
def _serialise_export(job_id: str, export_manifest: dict[str, Any]) -> dict[str, Any]:
payload = dict(export_manifest)
payload["file_urls"] = {key: _job_url(job_id, path) for key, path in payload.get("files", {}).items()}
payload["samples"] = [
{**sample, "url": _job_url(job_id, sample["file"])}
for sample in payload.get("samples", [])
]
return payload
def _serialise_job(job: dict[str, Any]) -> dict[str, Any]:
payload = {key: value for key, value in job.items() if key not in {"input_path", "output_dir"}}
if payload.get("progress") is None:
payload["progress"] = _progress_from_stages(payload.get("stages", []), payload.get("status"))
if payload.get("partial_samples"):
payload["partial_samples"] = [
{**sample, "url": _job_url(job["id"], sample["file"])}
for sample in payload.get("partial_samples", [])
]
if payload.get("result"):
result = dict(payload["result"])
result["file_urls"] = {key: _job_url(job["id"], path) for key, path in result.get("files", {}).items()}
result["samples"] = [
{**sample, "url": _job_url(job["id"], sample["file"])}
for sample in result.get("samples", [])
]
result["hits"] = [
{**hit, "url": _job_url(job["id"], hit["file"])}
for hit in result.get("hits", [])
]
payload["result"] = result
return payload
def _progress_from_stages(stages: list[dict[str, Any]], status: str | None = None) -> dict[str, Any]:
running = next((stage for stage in stages if stage.get("status") == "running"), None)
running_progress = float(running.get("progress") or 0.0) if running else 0.0
total = 0.0
completed_units = 0.0
completed = 0
for stage in stages:
units = float(stage.get("work_total") or 1)
total += units
if stage.get("status") == "done":
completed += 1
completed_units += units
elif stage.get("status") == "running":
completed_units += float(stage.get("work_done")) if stage.get("work_total") and stage.get("work_done") is not None else running_progress * units
fraction = min(1.0, max(0.0, completed_units / max(total, 1.0)))
if status == "complete":
fraction = 1.0
completed_units = total
return {
"fraction": round(fraction, 6),
"completed_steps": completed,
"total_steps": len(stages),
"completed_units": round(completed_units, 6),
"total_units": round(total, 6),
"stage_key": running.get("key") if running else None,
"stage_label": running.get("label") if running else None,
"stage_fraction": round(running_progress, 6),
"stage_work_done": running.get("work_done") if running else None,
"stage_work_total": running.get("work_total") if running else None,
"basis": "exact completed work units: Demucs chunks when available, otherwise stage boundary units; no time-based estimates",
}
def _manifest_path(job_id: str) -> Path:
    """Return the path of the ``manifest.json`` inside the job's output directory."""
    output_dir = RUNS_DIR / job_id / "output"
    return output_dir / "manifest.json"
def _read_manifest_job(job_id: str) -> dict[str, Any] | None:
    """Rebuild a completed-job record from the job's on-disk manifest.

    Returns ``None`` when the manifest is missing, cannot be read, is not valid
    JSON, or does not contain a JSON object. Callers already treat ``None`` as
    "no such job", so one damaged manifest no longer breaks the job history
    listing or turns a lookup into an unhandled error.
    """
    manifest = _manifest_path(job_id)
    try:
        # Read directly instead of checking exists() first: the file may vanish in between.
        result = json.loads(manifest.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: invalid JSON or invalid UTF-8.
        return None
    if not isinstance(result, dict):
        return None

    # Tolerate explicit nulls for the nested structures read below.
    source = result.get("source") or {}
    stages = result.get("stages") or []
    return {
        "id": job_id,
        "status": "complete",
        "filename": source.get("filename"),
        "params": result.get("params", {}),
        "stages": stages,
        "progress": _progress_from_stages(stages, "complete"),
        "logs": [],
        "result": result,
        "error": None,
        "traceback": None,
        "output_dir": str(manifest.parent),
    }
def _summarise_job(job: dict[str, Any]) -> dict[str, Any]:
result = job.get("result") or {}
return {
"id": job["id"],
"status": job.get("status"),
"filename": job.get("filename"),
"created_at": job.get("created_at"),
"duration_sec": result.get("duration_sec"),
"audio_duration_sec": result.get("audio_duration_sec"),
"realtime_factor": result.get("realtime_factor"),
"bpm": result.get("bpm"),
"hit_count": result.get("hit_count"),
"cluster_count": result.get("cluster_count"),
"clustering_mode": (result.get("params") or job.get("params") or {}).get("clustering_mode"),
"stem": (result.get("params") or job.get("params") or {}).get("stem"),
"error": job.get("error"),
}
def _list_manifest_jobs(limit: int = 50) -> list[dict[str, Any]]:
    """List summaries of jobs that have a manifest on disk, most recently modified first.

    The manifest's modification time is reported as ``created_at``. Collection
    stops once ``limit`` rows have been gathered.
    """
    manifests = list(RUNS_DIR.glob("*/output/manifest.json"))
    manifests.sort(key=lambda p: p.stat().st_mtime, reverse=True)

    rows: list[dict[str, Any]] = []
    for manifest in manifests:
        # Layout is <RUNS_DIR>/<job_id>/output/manifest.json, so the job id is two levels up.
        record = _read_manifest_job(manifest.parents[1].name)
        if not record:
            continue
        row = _summarise_job(record)
        row["created_at"] = manifest.stat().st_mtime
        rows.append(row)
        if len(rows) >= limit:
            break
    return rows
def _update_job(job_id: str, **patch: Any) -> None:
    """Merge the keyword arguments into the in-memory record of ``job_id`` while holding the registry lock."""
    with jobs_lock:
        record = jobs[job_id]
        record.update(patch)
def _append_log(job_id: str, message: str) -> None:
    """Append ``message`` to the job's ``logs`` list (created on first use) while holding the registry lock."""
    with jobs_lock:
        record = jobs[job_id]
        if "logs" not in record:
            record["logs"] = []
        record["logs"].append(message)
def _run_job(job_id: str) -> None:
    """Run the extraction pipeline for a registered job (submitted to the executor by create_job).

    Marks the job ``running``, forwards pipeline progress events into the job
    record and its log, and finally stores either the result (``complete``) or
    the error message plus traceback (``error``).
    """
    with jobs_lock:
        record = jobs[job_id]
        input_path = Path(record["input_path"])
        output_dir = Path(record["output_dir"])
        params = record["params"]
        record["status"] = "running"

    def progress(event: dict[str, Any]) -> None:
        """Translate one pipeline event into job-record updates and log lines."""
        changes: dict[str, Any] = {}
        has_stages = "stages" in event
        if has_stages:
            changes["stages"] = event["stages"]
        if "progress" in event:
            changes["progress"] = event["progress"]
        elif has_stages:
            # No explicit progress supplied: derive it from the stage list.
            changes["progress"] = _progress_from_stages(event["stages"], "running")

        is_sample_event = event.get("type") == "sample"
        if is_sample_event:
            changes["partial_samples"] = event.get("samples", [])
        if changes:
            _update_job(job_id, **changes)

        sample = event.get("sample") if is_sample_event else None
        if sample:
            _append_log(job_id, f"Sample ready: {sample.get('label')} ({sample.get('classification')})")

        stage = event.get("stage")
        if not stage:
            return
        stage_status = stage.get("status")
        if stage_status == "running":
            _append_log(job_id, f"Started: {stage['label']}")
        elif stage_status == "done":
            detail = f" · {stage['detail']}" if stage.get("detail") else ""
            _append_log(job_id, f"Finished: {stage['label']} in {stage['duration_sec']:.3f}s{detail}")

    try:
        pipeline_params = PipelineParams.from_mapping(params)
        result = run_extraction_pipeline(input_path, output_dir, pipeline_params, progress_cb=progress)
        load_or_create_state(job_id, output_dir)
        _update_job(job_id, status="complete", result=asdict(result), error=None)
    except Exception as exc:  # deliberately explicit for UI diagnostics
        _update_job(job_id, status="error", error=str(exc), traceback=traceback.format_exc())
        _append_log(job_id, f"Error: {exc}")
@app.get("/api/health")
def health() -> dict[str, str]:
    """Liveness probe: always reports ``{"status": "ok"}``."""
    report = {"status": "ok"}
    return report
@app.get("/api/config")
def config() -> dict[str, Any]:
    """Describe the available backends, models, stems, defaults and clustering modes.

    Every per-model stem list is extended with the extra ``"all"`` choice.
    """

    def with_all(stems_by_model: dict[str, Any]) -> dict[str, Any]:
        extended = {}
        for model, stems in stems_by_model.items():
            extended[model] = stems + ["all"]
        return extended

    clustering_modes = {
        "batch_quality": "Batch quality: all-pairs mel/NCC + agglomerative clustering",
        "online_preview": "Online preview: prototype assignment for near-realtime feedback",
    }
    return {
        "separation_backends": SEPARATION_BACKENDS,
        "spleeter_models": SPLEETER_MODELS,
        "spleeter_stems": with_all(SPLEETER_STEMS),
        "demucs_models": DEMUCS_MODELS,
        "demucs_stems": with_all(DEMUCS_STEMS),
        "defaults": asdict(PipelineParams()),
        "runtime": separation_runtime_status(),
        "stages": initial_stages(),
        "clustering_modes": clustering_modes,
    }
@app.post("/api/cache/clear")
def clear_cache() -> dict[str, str]:
    """Invoke both cache-clearing helpers and report the cleared scope."""
    for clear in (cache_clear, clear_disk_cache):
        clear()
    return {"status": "cleared", "scope": "memory+disk"}
@app.get("/api/jobs")
def list_jobs(limit: int = 50) -> dict[str, Any]:
    """List in-memory jobs that are not complete plus the on-disk job history.

    ``limit`` bounds the history and is clamped to the range 1..200.
    """
    capped = min(int(limit), 200)
    limit = max(1, capped)
    with jobs_lock:
        active = []
        for job in jobs.values():
            if job.get("status") != "complete":
                active.append(_summarise_job(dict(job)))
    history = _list_manifest_jobs(limit=limit)
    return {"active": active, "history": history}
@app.post("/api/jobs")
async def create_job(file: UploadFile = File(...), params: str = Form("{}")) -> JSONResponse:
    """Accept an audio upload, register a new job and queue it for extraction.

    Args:
        file: The uploaded audio file. A filename extension, when present, must
            be one of the supported audio extensions.
        params: JSON-encoded extraction parameters, validated through
            ``PipelineParams.from_mapping``.

    Returns:
        A 202 response carrying the serialised pending job.

    Raises:
        HTTPException: 400 for an unsupported extension, invalid parameters or
            an empty upload.
    """
    filename = file.filename or "input.wav"
    suffix = Path(filename).suffix.lower()
    if suffix and suffix not in {".wav", ".mp3", ".flac", ".aiff", ".aif", ".ogg", ".m4a", ".aac"}:
        raise HTTPException(status_code=400, detail=f"Unsupported audio file extension: {suffix}")
    try:
        parsed_params = json.loads(params)
        validated = PipelineParams.from_mapping(parsed_params)
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid params JSON: {exc.msg}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid extraction parameter: {exc}") from exc
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid extraction request: {exc}") from exc

    job_id = uuid.uuid4().hex[:12]
    job_dir = RUNS_DIR / job_id
    input_dir = job_dir / "input"
    output_dir = job_dir / "output"
    suffix = suffix or ".wav"
    input_path = input_dir / f"source{suffix}"

    stored = False
    try:
        input_dir.mkdir(parents=True, exist_ok=True)
        output_dir.mkdir(parents=True, exist_ok=True)
        bytes_written = 0
        with input_path.open("wb") as handle:
            # Read through UploadFile's async API in bounded chunks so a large
            # upload neither blocks the event loop nor is held fully in memory.
            while True:
                chunk = await file.read(1024 * 1024)
                if not chunk:
                    break
                handle.write(chunk)
                bytes_written += len(chunk)
        if bytes_written == 0:
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        stored = True
    finally:
        if not stored:
            # Do not leave a half-created job directory behind when storing the upload failed.
            shutil.rmtree(job_dir, ignore_errors=True)
        await file.close()

    job = {
        "id": job_id,
        "status": "pending",
        "filename": file.filename,
        "created_at": time.time(),
        "params": asdict(validated),
        "stages": initial_stages(),
        "progress": _progress_from_stages(initial_stages(), "pending"),
        "logs": [],
        "result": None,
        "partial_samples": [],
        "error": None,
        "traceback": None,
        "input_path": str(input_path),
        "output_dir": str(output_dir),
    }
    with jobs_lock:
        jobs[job_id] = job
    executor.submit(_run_job, job_id)
    return JSONResponse(_serialise_job(job), status_code=202)
@app.get("/api/jobs/{job_id}")
def get_job(job_id: str) -> dict[str, Any]:
    """Return a job's serialised record, preferring the in-memory registry over the on-disk manifest.

    Raises:
        HTTPException: 404 when the job is known neither in memory nor on disk.
    """
    with jobs_lock:
        live = jobs.get(job_id)
        snapshot = dict(live) if live else None
    if snapshot is None:
        snapshot = _read_manifest_job(job_id)
    if not snapshot:
        raise HTTPException(status_code=404, detail="Job not found")
    return _serialise_job(snapshot)
def _job_output_dir(job_id: str) -> Path:
    """Locate a job's output directory.

    The in-memory record's ``output_dir`` wins; otherwise the directory holding
    an existing manifest is used.

    Raises:
        HTTPException: 404 when neither source knows the job.
    """
    with jobs_lock:
        record = jobs.get(job_id)
        recorded_dir = record.get("output_dir") if record else None
    if recorded_dir:
        return Path(recorded_dir)
    manifest = _manifest_path(job_id)
    if not manifest.exists():
        raise HTTPException(status_code=404, detail="Job not found")
    return manifest.parent
def _state_payload(job_id: str) -> dict[str, Any]:
    """Load (or create) the job's supervised state and return its public form with file URLs.

    Raises:
        HTTPException: 409 when loading raises ``FileNotFoundError`` (no
            manifest yet), 500 for any other loading failure.
    """
    output_dir = _job_output_dir(job_id)
    try:
        state = load_or_create_state(job_id, output_dir)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=409, detail="Job has no manifest yet; wait until extraction completes") from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    def url_for(rel: str) -> str:
        return _job_url(job_id, rel)

    return public_state(state, url_for=url_for)
def _json_patch(payload: dict[str, Any] | None) -> dict[str, Any]:
return dict(payload or {})
def _state_for_mutation(job_id: str) -> tuple[Path, dict[str, Any]]:
    """Return ``(output_dir, state)`` for a job that is about to be edited.

    Raises:
        HTTPException: 409 when loading the state raises ``FileNotFoundError``.
    """
    output_dir = _job_output_dir(job_id)
    try:
        state = load_or_create_state(job_id, output_dir)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=409, detail="Job has no manifest yet; wait until extraction completes") from exc
    return output_dir, state
def _cluster_id_for_sample_label(state: dict[str, Any], sample_label: str) -> str:
clusters = state.get("clusters", {})
exact = [cid for cid, cluster in clusters.items() if str(cluster.get("label")) == str(sample_label)]
if exact:
return exact[0]
# Fall back to a classification/base-name match for labels that have been renamed by user edits.
base = str(sample_label).rsplit("_", 1)[0]
fuzzy = [cid for cid, cluster in clusters.items() if str(cluster.get("classification") or "") == base]
if len(fuzzy) == 1:
return fuzzy[0]
raise HTTPException(status_code=404, detail=f"Sample label not found in current state: {sample_label}")
def _public_sample_from_cluster(job_id: str, state: dict[str, Any], cluster_id: str, label_override: str | None = None) -> dict[str, Any]:
clusters = state.get("clusters", {})
hits = state.get("hits", {})
if cluster_id not in clusters:
raise HTTPException(status_code=404, detail=f"Unknown cluster: {cluster_id}")
cluster = clusters[cluster_id]
active_ids = [hid for hid in cluster.get("hit_ids", []) if hid in hits and not hits[hid].get("suppressed")]
if not active_ids:
raise HTTPException(status_code=409, detail=f"Cluster {cluster.get('label', cluster_id)} has no active hits")
rep_id = cluster.get("representative_hit_id") if cluster.get("representative_hit_id") in active_ids else active_ids[0]
hit = hits[rep_id]
raw_cluster_id = cluster_id.split(":", 1)[1] if ":" in cluster_id else cluster_id
try:
raw_cluster_id_value: int | str = int(raw_cluster_id)
except ValueError:
raw_cluster_id_value = raw_cluster_id
file_path = str(hit.get("file") or "")
first_onset = min(float(hits[hid].get("onset_sec") or 0.0) for hid in active_ids)
return {
"label": label_override or cluster.get("label") or str(cluster_id),
"classification": cluster.get("classification") or str(cluster.get("label") or "other").rsplit("_", 1)[0],
"hits": len(active_ids),
"midi_note": cluster.get("midi_note", 60),
"score": "edited",
"duration_ms": round(float(hit.get("duration_ms") or 0.0), 1),
"first_onset_sec": round(first_onset, 4),
"representative_hit_index": int(hit.get("index") or 0),
"state_hit_id": rep_id,
"cluster_id": raw_cluster_id_value,
"state_cluster_id": cluster_id,
"file": file_path,
"url": _job_url(job_id, file_path) if file_path else None,
}
@app.get("/api/jobs/{job_id}/events")
def get_job_events(job_id: str) -> StreamingResponse:
    """Stream job updates as server-sent events.

    A ``job`` event is emitted whenever the serialised job differs from the
    previously sent one; the state is polled every 0.5 s and the stream ends
    once the status is ``complete`` or ``error``.

    Raises:
        HTTPException: 404 when the job is known neither in memory nor on disk.
    """
    with jobs_lock:
        known_in_memory = job_id in jobs
    known_on_disk = _read_manifest_job(job_id) is not None
    if not (known_in_memory or known_on_disk):
        raise HTTPException(status_code=404, detail="Job not found")

    terminal_statuses = {"complete", "error"}

    async def event_stream():
        previous: str | None = None
        while True:
            with jobs_lock:
                live = jobs.get(job_id)
                snapshot = dict(live) if live else None
            if snapshot is None:
                # Not (or no longer) in memory: fall back to the on-disk manifest.
                snapshot = _read_manifest_job(job_id)
            if snapshot is not None:
                payload = _serialise_job(snapshot)
            else:
                payload = {"id": job_id, "status": "error", "error": "Job disappeared"}

            encoded = json.dumps(payload, sort_keys=True)
            if encoded != previous:
                yield f"event: job\ndata: {encoded}\n\n"
                previous = encoded
            if payload.get("status") in terminal_statuses:
                break
            await asyncio.sleep(0.5)

    response_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(event_stream(), media_type="text/event-stream", headers=response_headers)
@app.get("/api/jobs/{job_id}/state")
def get_job_state(job_id: str) -> dict[str, Any]:
    """Return the public supervised state of a job."""
    state = _state_payload(job_id)
    return state
@app.post("/api/jobs/{job_id}/export")
def post_supervised_export(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Export the job's supervised state and return the export manifest plus the refreshed state.

    Body keys read: ``synthesize`` (default true), ``quantize``, ``subdivision``.

    Raises:
        HTTPException: 404 for an unknown job, 500 when the export fails.
    """
    patch = _json_patch(payload)
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        export_manifest = export_supervised_state(
            output_dir,
            job_id,
            synthesize=bool(patch.get("synthesize", True)),
            quantize=patch.get("quantize"),
            subdivision=patch.get("subdivision"),
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"export": _serialise_export(job_id, export_manifest), "state": _state_payload(job_id)}
@app.post("/api/jobs/{job_id}/export-selected")
def post_selected_export(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Export only the samples whose labels are listed in the request body.

    Body keys read: ``labels`` (required, blank entries are ignored),
    ``synthesize`` (default true), ``quantize``, ``subdivision``.

    Raises:
        HTTPException: 400 when no label is given or the exporter raises
            ``ValueError``, 404 for an unknown job, 500 for other export failures.
    """
    patch = _json_patch(payload)
    labels = [str(item) for item in patch.get("labels", []) if str(item).strip()]
    if not labels:
        raise HTTPException(status_code=400, detail="labels must contain at least one selected sample label")
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        export_manifest = export_selected_samples(
            output_dir,
            job_id,
            selected_labels=labels,
            synthesize=bool(patch.get("synthesize", True)),
            quantize=patch.get("quantize"),
            subdivision=patch.get("subdivision"),
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"export": _serialise_export(job_id, export_manifest), "state": _state_payload(job_id)}
@app.post("/api/jobs/{job_id}/samples/{sample_label:path}/draw")
def post_draw_sample(job_id: str, sample_label: str) -> dict[str, Any]:
    """Draw the next representative hit for the cluster behind a sample card.

    Returns the refreshed sample card (keeping the requested label) together
    with the public state. ``KeyError`` maps to 404, ``ValueError`` to 409,
    an ``HTTPException`` passes through unchanged and anything else becomes 500.
    """

    def url_for(rel: str) -> str:
        return _job_url(job_id, rel)

    try:
        output_dir, state = _state_for_mutation(job_id)
        cluster_id = _cluster_id_for_sample_label(state, sample_label)
        state = apply_draw_next_representative(output_dir, job_id, cluster_id, source="sample-card")
        sample = _public_sample_from_cluster(job_id, state, cluster_id, label_override=sample_label)
        return {"sample": sample, "state": public_state(state, url_for=url_for)}
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.post("/api/jobs/{job_id}/samples/{sample_label:path}/edit")
def post_edit_sample(job_id: str, sample_label: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Apply a timing edit to the representative hit of the cluster behind a sample card.

    Body keys read: ``start_offset_ms`` and ``tail_offset_ms`` (both default 0.0).
    ``KeyError`` maps to 404, ``ValueError`` to 400, an ``HTTPException``
    passes through unchanged and anything else becomes 500.
    """
    patch = _json_patch(payload)

    def url_for(rel: str) -> str:
        return _job_url(job_id, rel)

    try:
        output_dir, state = _state_for_mutation(job_id)
        cluster_id = _cluster_id_for_sample_label(state, sample_label)
        cluster = state.get("clusters", {}).get(cluster_id) or {}

        active_ids = []
        for hid in cluster.get("hit_ids", []):
            if hid in state.get("hits", {}) and not state["hits"][hid].get("suppressed"):
                active_ids.append(hid)

        # Prefer the stored representative; otherwise fall back to the first active hit.
        preferred = cluster.get("representative_hit_id")
        if preferred in active_ids:
            rep_id = preferred
        elif active_ids:
            rep_id = active_ids[0]
        else:
            rep_id = None
        if not rep_id:
            raise HTTPException(status_code=409, detail=f"Sample {sample_label} has no active representative hit")

        start_offset_ms = float(patch.get("start_offset_ms", 0.0))
        tail_offset_ms = float(patch.get("tail_offset_ms", 0.0))
        state = apply_hit_timing_edit(
            output_dir,
            job_id,
            rep_id,
            start_offset_ms=start_offset_ms,
            tail_offset_ms=tail_offset_ms,
            source="sample-card",
        )
        sample = _public_sample_from_cluster(job_id, state, cluster_id, label_override=sample_label)
        return {"sample": sample, "state": public_state(state, url_for=url_for)}
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.post("/api/jobs/{job_id}/hits/force-onset")
def post_force_onset(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Force a hit at a user-supplied onset time and return the refreshed state.

    Body keys read: ``onset_sec`` (required), ``duration_ms``, ``label``,
    ``target_cluster_id``, ``pre_pad_sec`` (default 0.003).

    Raises:
        HTTPException: 400 for a missing or non-numeric ``onset_sec`` /
            ``pre_pad_sec`` or a ``ValueError`` from the operation, 404 for an
            unknown job or a ``KeyError``, 500 for any other failure.
    """
    patch = _json_patch(payload)
    if "onset_sec" not in patch:
        raise HTTPException(status_code=400, detail="onset_sec is required")
    # Validate numeric inputs up front so a wrong JSON type (e.g. null or a list)
    # is reported as a client error rather than an internal one.
    try:
        onset_sec = float(patch["onset_sec"])
        pre_pad_sec = float(patch.get("pre_pad_sec", 0.003))
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        apply_force_onset(
            output_dir,
            job_id,
            onset_sec,
            duration_ms=patch.get("duration_ms"),
            label=patch.get("label"),
            target_cluster_id=patch.get("target_cluster_id"),
            pre_pad_sec=pre_pad_sec,
        )
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.post("/api/jobs/{job_id}/hits/{hit_id}/move")
def post_move_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Move a hit into the cluster named by ``target_cluster_id`` and return the refreshed state.

    Raises:
        HTTPException: 400 when ``target_cluster_id`` is missing, 404 for an
            unknown job or a ``KeyError``, 500 for any other failure.
    """
    target_cluster_id = _json_patch(payload).get("target_cluster_id")
    if not target_cluster_id:
        raise HTTPException(status_code=400, detail="target_cluster_id is required")
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        apply_hit_move(output_dir, job_id, hit_id, str(target_cluster_id))
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.post("/api/jobs/{job_id}/hits/{hit_id}/pull-out")
def post_pull_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Pull a hit out into a new cluster (optionally labelled via ``label``) and return the refreshed state.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 for any
            other failure.
    """
    label = _json_patch(payload).get("label")
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        pull_hit_to_new_cluster(output_dir, job_id, hit_id, label=label)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.post("/api/jobs/{job_id}/hits/{hit_id}/suppress")
def post_suppress_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Suppress a hit (``reason`` defaults to ``"bleed"``) and return the refreshed state.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 for any
            other failure.
    """
    reason = str(_json_patch(payload).get("reason") or "bleed")
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        apply_hit_suppression(output_dir, job_id, hit_id, reason=reason)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.post("/api/jobs/{job_id}/hits/{hit_id}/restore")
def post_restore_hit(job_id: str, hit_id: str) -> dict[str, Any]:
    """Restore a hit and return the refreshed state.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 for any
            other failure.
    """
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        apply_hit_restore(output_dir, job_id, hit_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.post("/api/jobs/{job_id}/hits/{hit_id}/review")
def post_review_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Set a hit's review status (``status`` defaults to ``"accepted"``) and return the refreshed state.

    Raises:
        HTTPException: 400 for a ``ValueError``, 404 for an unknown job or a
            ``KeyError``, 500 for any other failure.
    """
    status = str(_json_patch(payload).get("status") or "accepted")
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        set_hit_review_status(output_dir, job_id, hit_id, status=status)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.post("/api/jobs/{job_id}/clusters/{cluster_id:path}/lock")
def post_lock_cluster(job_id: str, cluster_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Lock or unlock a cluster (``locked`` defaults to true) and return the refreshed state.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 for any
            other failure.
    """
    locked = bool(_json_patch(payload).get("locked", True))
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        apply_cluster_lock(output_dir, job_id, cluster_id, locked=locked)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.get("/api/jobs/{job_id}/suggestions")
def get_suggestions(job_id: str) -> dict[str, Any]:
    """Return only the ``suggestions`` and ``summary`` parts of the job's public state."""
    state = _state_payload(job_id)
    suggestions = state.get("suggestions", [])
    summary = state.get("summary", {})
    return {"suggestions": suggestions, "summary": summary}
@app.post("/api/jobs/{job_id}/suggestions/{suggestion_id}/accept")
def post_accept_suggestion(job_id: str, suggestion_id: str) -> dict[str, Any]:
    """Accept a suggestion and return the refreshed state.

    Raises:
        HTTPException: 400 for a ``ValueError``, 404 for an unknown job or a
            ``KeyError``, 500 for any other failure.
    """
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        accept_suggestion(output_dir, job_id, suggestion_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.post("/api/jobs/{job_id}/suggestions/{suggestion_id}/reject")
def post_reject_suggestion(job_id: str, suggestion_id: str) -> dict[str, Any]:
    """Reject a suggestion and return the refreshed state.

    Raises:
        HTTPException: 404 for an unknown job or a ``KeyError``, 500 for any
            other failure.
    """
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        reject_suggestion(output_dir, job_id, suggestion_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.get("/api/jobs/{job_id}/explain/cluster/{cluster_id:path}")
def get_cluster_explanation(job_id: str, cluster_id: str) -> dict[str, Any]:
    """Return the explanation built for one cluster of the job's supervised state.

    Raises:
        HTTPException: 404 for an unknown job or cluster (``KeyError``), 409
            when the job has no manifest yet.
    """
    out = _job_output_dir(job_id)
    try:
        state = load_or_create_state(job_id, out)
    except FileNotFoundError as exc:
        # Same mapping the other state readers in this file use, instead of an unhandled 500.
        raise HTTPException(status_code=409, detail="Job has no manifest yet; wait until extraction completes") from exc
    try:
        explanation = build_cluster_explanation(state, cluster_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    return explanation
@app.post("/api/jobs/{job_id}/undo")
def post_undo(job_id: str) -> dict[str, Any]:
    """Undo the last supervised edit and return the refreshed state.

    Raises:
        HTTPException: 404 for an unknown job, 500 when the undo fails.
    """
    # Resolve the job before the try block: HTTPException is an Exception subclass,
    # so looking it up inside would turn "Job not found" (404) into a 500.
    output_dir = _job_output_dir(job_id)
    try:
        apply_undo(output_dir, job_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
@app.get("/api/jobs/{job_id}/files/{relative_path:path}")
def get_job_file(job_id: str, relative_path: str) -> FileResponse:
    """Serve a file from the job's output directory.

    Both paths are resolved first, and anything that ends up outside the
    output directory is answered with 404, as is a path that is not an
    existing regular file.
    """
    output_root = (RUNS_DIR / job_id / "output").resolve()
    requested = (output_root / relative_path).resolve()
    try:
        # relative_to() raises ValueError when the resolved path escapes the output directory.
        requested.relative_to(output_root)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail="File not found") from exc
    if not (requested.exists() and requested.is_file()):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(requested)
# Serve the front-end assets under /web, but only when the web/ directory exists.
if WEB_DIR.exists():
    app.mount("/web", StaticFiles(directory=WEB_DIR), name="web")
@app.get("/")
def index() -> FileResponse:
    """Serve ``web/index.html``; answer 500 when that file is missing."""
    index_path = WEB_DIR / "index.html"
    if index_path.exists():
        return FileResponse(index_path)
    raise HTTPException(status_code=500, detail="web/index.html is missing")