Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """Custom web application for the drum sample extractor. | |
| Run with: | |
| uvicorn app:app --host 0.0.0.0 --port 7860 | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import shutil | |
| import time | |
| import traceback | |
| import uuid | |
| from concurrent.futures import ThreadPoolExecutor | |
| from dataclasses import asdict | |
| from pathlib import Path | |
| from threading import Lock | |
| from typing import Any | |
| from fastapi import Body, FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, JSONResponse, StreamingResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pipeline_runner import PipelineParams, SPLEETER_MODELS, SPLEETER_STEMS, SEPARATION_BACKENDS, clear_disk_cache, initial_stages, run_extraction_pipeline, separation_runtime_status | |
| from sample_extractor import DEMUCS_MODELS, DEMUCS_STEMS, cache_clear | |
| from supervised_state import ( | |
| accept_suggestion, | |
| explain_cluster as build_cluster_explanation, | |
| force_onset as apply_force_onset, | |
| load_or_create_state, | |
| lock_cluster as apply_cluster_lock, | |
| move_hit as apply_hit_move, | |
| public_state, | |
| pull_hit_to_new_cluster, | |
| reject_suggestion, | |
| restore_hit as apply_hit_restore, | |
| set_hit_review_status, | |
| suppress_hit as apply_hit_suppression, | |
| draw_next_representative as apply_draw_next_representative, | |
| edit_hit_timing as apply_hit_timing_edit, | |
| undo_last as apply_undo, | |
| ) | |
| from supervised_export import export_selected_samples, export_supervised_state | |
# Filesystem layout, anchored at the directory that contains this module.
ROOT = Path(__file__).resolve().parent
WEB_DIR = ROOT / "web"
RUNS_DIR = ROOT / ".runs"
# Per-job working directories live under RUNS_DIR; create it at import time.
RUNS_DIR.mkdir(exist_ok=True)

app = FastAPI(title="Drum Sample Extractor", version="15.0.0")
# CORS accepts any origin, method and header, but never credentials.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=False,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# A single worker thread, so submitted jobs execute one at a time.
executor = ThreadPoolExecutor(max_workers=1)
# In-memory job registry keyed by job id; jobs_lock guards access to it.
jobs: dict[str, dict[str, Any]] = {}
jobs_lock = Lock()
| def _job_url(job_id: str, relative_path: str) -> str: | |
| return f"/api/jobs/{job_id}/files/{relative_path}" | |
def _serialise_export(job_id: str, export_manifest: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of an export manifest decorated with file URLs.

    Adds ``file_urls`` (one URL per entry of ``files``) and rebuilds
    ``samples`` so each sample carries a ``url`` derived from its ``file``.
    The input mapping itself is not modified.
    """

    def to_url(path: str) -> str:
        return _job_url(job_id, path)

    serialised = dict(export_manifest)

    file_urls: dict[str, str] = {}
    for name, path in serialised.get("files", {}).items():
        file_urls[name] = to_url(path)
    serialised["file_urls"] = file_urls

    samples: list[dict[str, Any]] = []
    for entry in serialised.get("samples", []):
        samples.append({**entry, "url": to_url(entry["file"])})
    serialised["samples"] = samples

    return serialised
def _serialise_job(job: dict[str, Any]) -> dict[str, Any]:
    """Build the client-facing view of a job record.

    Server-side paths (``input_path``, ``output_dir``) are dropped, a missing
    ``progress`` is derived from the stage list, and every partial sample,
    result sample and result hit gains a ``url`` built from its ``file``.
    """
    hidden_keys = {"input_path", "output_dir"}

    def with_url(entry: dict[str, Any]) -> dict[str, Any]:
        # job["id"] is looked up lazily: only records that carry files need it.
        return {**entry, "url": _job_url(job["id"], entry["file"])}

    view = {key: value for key, value in job.items() if key not in hidden_keys}

    if view.get("progress") is None:
        view["progress"] = _progress_from_stages(view.get("stages", []), view.get("status"))

    if view.get("partial_samples"):
        view["partial_samples"] = [with_url(sample) for sample in view.get("partial_samples", [])]

    if view.get("result"):
        # Copy before decorating so the stored result is left untouched.
        decorated = dict(view["result"])
        decorated["file_urls"] = {
            name: _job_url(job["id"], path)
            for name, path in decorated.get("files", {}).items()
        }
        decorated["samples"] = [with_url(sample) for sample in decorated.get("samples", [])]
        decorated["hits"] = [with_url(hit) for hit in decorated.get("hits", [])]
        view["result"] = decorated

    return view
| def _progress_from_stages(stages: list[dict[str, Any]], status: str | None = None) -> dict[str, Any]: | |
| running = next((stage for stage in stages if stage.get("status") == "running"), None) | |
| running_progress = float(running.get("progress") or 0.0) if running else 0.0 | |
| total = 0.0 | |
| completed_units = 0.0 | |
| completed = 0 | |
| for stage in stages: | |
| units = float(stage.get("work_total") or 1) | |
| total += units | |
| if stage.get("status") == "done": | |
| completed += 1 | |
| completed_units += units | |
| elif stage.get("status") == "running": | |
| completed_units += float(stage.get("work_done")) if stage.get("work_total") and stage.get("work_done") is not None else running_progress * units | |
| fraction = min(1.0, max(0.0, completed_units / max(total, 1.0))) | |
| if status == "complete": | |
| fraction = 1.0 | |
| completed_units = total | |
| return { | |
| "fraction": round(fraction, 6), | |
| "completed_steps": completed, | |
| "total_steps": len(stages), | |
| "completed_units": round(completed_units, 6), | |
| "total_units": round(total, 6), | |
| "stage_key": running.get("key") if running else None, | |
| "stage_label": running.get("label") if running else None, | |
| "stage_fraction": round(running_progress, 6), | |
| "stage_work_done": running.get("work_done") if running else None, | |
| "stage_work_total": running.get("work_total") if running else None, | |
| "basis": "exact completed work units: Demucs chunks when available, otherwise stage boundary units; no time-based estimates", | |
| } | |
def _manifest_path(job_id: str) -> Path:
    """Return where the manifest of job ``job_id`` is expected on disk."""
    output_dir = RUNS_DIR / job_id / "output"
    return output_dir / "manifest.json"
def _read_manifest_job(job_id: str) -> dict[str, Any] | None:
    """Rebuild a job record from the manifest stored on disk.

    Returns ``None`` when no manifest file exists for ``job_id``. Otherwise
    the manifest is parsed as UTF-8 JSON and wrapped in a record shaped like
    the in-memory ones, always with status ``complete`` and empty logs.
    """
    manifest_file = _manifest_path(job_id)
    if not manifest_file.exists():
        return None

    with manifest_file.open(encoding="utf-8") as handle:
        manifest_data = json.load(handle)

    stages = manifest_data.get("stages", [])
    source = manifest_data.get("source", {})
    return {
        "id": job_id,
        "status": "complete",
        "filename": source.get("filename"),
        "params": manifest_data.get("params", {}),
        "stages": stages,
        "progress": _progress_from_stages(stages, "complete"),
        "logs": [],
        "result": manifest_data,
        "error": None,
        "traceback": None,
        "output_dir": str(manifest_file.parent),
    }
| def _summarise_job(job: dict[str, Any]) -> dict[str, Any]: | |
| result = job.get("result") or {} | |
| return { | |
| "id": job["id"], | |
| "status": job.get("status"), | |
| "filename": job.get("filename"), | |
| "created_at": job.get("created_at"), | |
| "duration_sec": result.get("duration_sec"), | |
| "audio_duration_sec": result.get("audio_duration_sec"), | |
| "realtime_factor": result.get("realtime_factor"), | |
| "bpm": result.get("bpm"), | |
| "hit_count": result.get("hit_count"), | |
| "cluster_count": result.get("cluster_count"), | |
| "clustering_mode": (result.get("params") or job.get("params") or {}).get("clustering_mode"), | |
| "stem": (result.get("params") or job.get("params") or {}).get("stem"), | |
| "error": job.get("error"), | |
| } | |
def _list_manifest_jobs(limit: int = 50) -> list[dict[str, Any]]:
    """List summaries of finished jobs found on disk, newest manifest first.

    Manifests that disappear while being listed, cannot be read, or do not
    contain the expected JSON object are skipped, so a single damaged run
    cannot break the whole listing.

    Args:
        limit: Maximum number of rows to return; values below 1 give ``[]``.

    Returns:
        Job summaries whose ``created_at`` is the manifest's modification time.
    """
    if limit <= 0:
        return []

    # Stat each manifest once; a run directory may be removed concurrently.
    stamped: list[tuple[float, Path]] = []
    for manifest in RUNS_DIR.glob("*/output/manifest.json"):
        try:
            stamped.append((manifest.stat().st_mtime, manifest))
        except OSError:
            continue
    stamped.sort(key=lambda item: item[0], reverse=True)

    rows: list[dict[str, Any]] = []
    for modified_at, manifest in stamped:
        job_id = manifest.parents[1].name
        try:
            manifest_job = _read_manifest_job(job_id)
            if not manifest_job:
                continue
            summary = _summarise_job(manifest_job)
        except (OSError, ValueError, AttributeError):
            # Unreadable file, invalid JSON (JSONDecodeError is a ValueError),
            # or a manifest whose top level is not a mapping.
            continue
        summary["created_at"] = modified_at
        rows.append(summary)
        if len(rows) >= limit:
            break
    return rows
def _update_job(job_id: str, **patch: Any) -> None:
    """Merge ``patch`` into the in-memory record of ``job_id`` under the lock.

    Raises ``KeyError`` when the job is not registered.
    """
    with jobs_lock:
        record = jobs[job_id]
        record.update(patch)
def _append_log(job_id: str, message: str) -> None:
    """Append ``message`` to the job's log list, creating the list if absent.

    Raises ``KeyError`` when the job is not registered.
    """
    with jobs_lock:
        record = jobs[job_id]
        log_lines = record.setdefault("logs", [])
        log_lines.append(message)
def _run_job(job_id: str) -> None:
    """Run the extraction pipeline for a registered job (submitted to ``executor``).

    Marks the job ``running``, feeds pipeline progress events back into the
    in-memory record and its log, and finally stores either the result
    (status ``complete``) or the error text and traceback (status ``error``).
    """
    with jobs_lock:
        record = jobs[job_id]
        source_file = Path(record["input_path"])
        destination = Path(record["output_dir"])
        raw_params = record["params"]
        record["status"] = "running"

    def progress(event: dict[str, Any]) -> None:
        """Translate one pipeline event into record updates and log lines."""
        changes: dict[str, Any] = {}
        has_stages = "stages" in event
        is_sample = event.get("type") == "sample"

        if has_stages:
            changes["stages"] = event["stages"]
        # An explicit progress payload wins over one derived from the stages.
        if "progress" in event:
            changes["progress"] = event["progress"]
        elif has_stages:
            changes["progress"] = _progress_from_stages(event["stages"], "running")
        if is_sample:
            changes["partial_samples"] = event.get("samples", [])
        if changes:
            _update_job(job_id, **changes)

        if is_sample and event.get("sample"):
            sample = event["sample"]
            _append_log(job_id, f"Sample ready: {sample.get('label')} ({sample.get('classification')})")

        stage = event.get("stage")
        if not stage:
            return
        stage_status = stage.get("status")
        if stage_status == "running":
            _append_log(job_id, f"Started: {stage['label']}")
        elif stage_status == "done":
            detail = f" · {stage['detail']}" if stage.get("detail") else ""
            _append_log(job_id, f"Finished: {stage['label']} in {stage['duration_sec']:.3f}s{detail}")

    try:
        pipeline_params = PipelineParams.from_mapping(raw_params)
        result = run_extraction_pipeline(source_file, destination, pipeline_params, progress_cb=progress)
        load_or_create_state(job_id, destination)
        _update_job(job_id, status="complete", result=asdict(result), error=None)
    except Exception as exc:  # deliberately explicit for UI diagnostics
        _update_job(job_id, status="error", error=str(exc), traceback=traceback.format_exc())
        _append_log(job_id, f"Error: {exc}")
def health() -> dict[str, str]:
    """Return a constant liveness payload."""
    return dict(status="ok")
def config() -> dict[str, Any]:
    """Describe the available backends, models, stems and default parameters.

    Every stem list is extended with an extra ``"all"`` choice; the imported
    stem tables themselves are not modified.
    """

    def with_all(stem_table: dict[str, Any]) -> dict[str, Any]:
        return {model: stems + ["all"] for model, stems in stem_table.items()}

    clustering_modes = {
        "batch_quality": "Batch quality: all-pairs mel/NCC + agglomerative clustering",
        "online_preview": "Online preview: prototype assignment for near-realtime feedback",
    }
    return {
        "separation_backends": SEPARATION_BACKENDS,
        "spleeter_models": SPLEETER_MODELS,
        "spleeter_stems": with_all(SPLEETER_STEMS),
        "demucs_models": DEMUCS_MODELS,
        "demucs_stems": with_all(DEMUCS_STEMS),
        "defaults": asdict(PipelineParams()),
        "runtime": separation_runtime_status(),
        "stages": initial_stages(),
        "clustering_modes": clustering_modes,
    }
def clear_cache() -> dict[str, str]:
    """Clear the in-memory cache, then the disk cache, and report the scope."""
    for clear in (cache_clear, clear_disk_cache):
        clear()
    return dict(status="cleared", scope="memory+disk")
def list_jobs(limit: int = 50) -> dict[str, Any]:
    """Return in-memory jobs that are not complete plus on-disk job history.

    ``limit`` is clamped to the range 1..200 and only bounds the history.
    """
    capped = min(int(limit), 200)
    capped = max(1, capped)

    active: list[dict[str, Any]] = []
    with jobs_lock:
        for record in jobs.values():
            if record.get("status") != "complete":
                # Summarise a copy so the registry entry is only read here.
                active.append(_summarise_job(dict(record)))

    return {"active": active, "history": _list_manifest_jobs(limit=capped)}
async def create_job(file: UploadFile = File(...), params: str = Form("{}")) -> JSONResponse:
    """Accept an audio upload, register a job for it and queue the extraction.

    Args:
        file: Uploaded audio. A non-empty extension must be one of the
            supported audio suffixes; a missing extension is stored as ``.wav``.
        params: JSON object of extraction parameters, validated through
            ``PipelineParams.from_mapping``.

    Returns:
        A 202 response carrying the serialised job as it was registered.

    Raises:
        HTTPException: 400 for an unsupported extension, malformed JSON or
            parameters rejected by validation.
    """
    allowed_suffixes = {".wav", ".mp3", ".flac", ".aiff", ".aif", ".ogg", ".m4a", ".aac"}
    filename = file.filename or "input.wav"
    suffix = Path(filename).suffix.lower()
    if suffix and suffix not in allowed_suffixes:
        raise HTTPException(status_code=400, detail=f"Unsupported audio file extension: {suffix}")

    try:
        parsed_params = json.loads(params)
        validated = PipelineParams.from_mapping(parsed_params)
    except json.JSONDecodeError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid params JSON: {exc.msg}") from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"Invalid extraction parameter: {exc}") from exc
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid extraction request: {exc}") from exc

    job_id = uuid.uuid4().hex[:12]
    job_dir = RUNS_DIR / job_id
    input_dir = job_dir / "input"
    output_dir = job_dir / "output"
    input_path = input_dir / f"source{suffix or '.wav'}"
    try:
        input_dir.mkdir(parents=True, exist_ok=True)
        output_dir.mkdir(parents=True, exist_ok=True)
        with input_path.open("wb") as handle:
            # The copy is blocking file I/O; run it in the default executor
            # so other requests keep being served while the upload is stored.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, shutil.copyfileobj, file.file, handle)
    except BaseException:
        # Do not leave a half-written run directory behind (also on
        # cancellation); the original error is re-raised unchanged.
        shutil.rmtree(job_dir, ignore_errors=True)
        raise

    stages = initial_stages()
    job = {
        "id": job_id,
        "status": "pending",
        "filename": file.filename,
        "created_at": time.time(),
        "params": asdict(validated),
        "stages": stages,
        "progress": _progress_from_stages(stages, "pending"),
        "logs": [],
        "result": None,
        "partial_samples": [],
        "error": None,
        "traceback": None,
        "input_path": str(input_path),
        "output_dir": str(output_dir),
    }
    # Serialise a snapshot before the worker thread can start mutating the record.
    response_payload = _serialise_job(dict(job))
    with jobs_lock:
        jobs[job_id] = job
    executor.submit(_run_job, job_id)
    return JSONResponse(response_payload, status_code=202)
def get_job(job_id: str) -> dict[str, Any]:
    """Return the serialised job, preferring the in-memory record over disk.

    Raises:
        HTTPException: 404 when the job is neither in memory nor on disk.
    """
    with jobs_lock:
        live = jobs.get(job_id)
        snapshot = dict(live) if live else None
    if snapshot is not None:
        return _serialise_job(snapshot)

    from_disk = _read_manifest_job(job_id)
    if not from_disk:
        raise HTTPException(status_code=404, detail="Job not found")
    return _serialise_job(from_disk)
def _job_output_dir(job_id: str) -> Path:
    """Locate the output directory of a job.

    The directory recorded on the in-memory job is used when present;
    otherwise the parent of an existing on-disk manifest.

    Raises:
        HTTPException: 404 when neither source knows the job.
    """
    with jobs_lock:
        record = jobs.get(job_id)
        recorded_dir = record.get("output_dir") if record else None
    if recorded_dir:
        return Path(recorded_dir)

    manifest_file = _manifest_path(job_id)
    if not manifest_file.exists():
        raise HTTPException(status_code=404, detail="Job not found")
    return manifest_file.parent
def _state_payload(job_id: str) -> dict[str, Any]:
    """Load (or create) the supervised state of a job and return its public view.

    Raises:
        HTTPException: 404 for an unknown job (from ``_job_output_dir``),
            409 when loading raises ``FileNotFoundError``, 500 for any other
            loading failure.
    """
    output_dir = _job_output_dir(job_id)
    try:
        state = load_or_create_state(job_id, output_dir)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=409, detail="Job has no manifest yet; wait until extraction completes") from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    def url_for(rel):
        return _job_url(job_id, rel)

    return public_state(state, url_for=url_for)
| def _json_patch(payload: dict[str, Any] | None) -> dict[str, Any]: | |
| return dict(payload or {}) | |
def _state_for_mutation(job_id: str) -> tuple[Path, dict[str, Any]]:
    """Return ``(output_dir, state)`` for a job whose state is about to be edited.

    Raises:
        HTTPException: 404 for an unknown job (from ``_job_output_dir``),
            409 when loading the state raises ``FileNotFoundError``.
    """
    output_dir = _job_output_dir(job_id)
    try:
        state = load_or_create_state(job_id, output_dir)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=409, detail="Job has no manifest yet; wait until extraction completes") from exc
    return output_dir, state
| def _cluster_id_for_sample_label(state: dict[str, Any], sample_label: str) -> str: | |
| clusters = state.get("clusters", {}) | |
| exact = [cid for cid, cluster in clusters.items() if str(cluster.get("label")) == str(sample_label)] | |
| if exact: | |
| return exact[0] | |
| # Fall back to a classification/base-name match for labels that have been renamed by user edits. | |
| base = str(sample_label).rsplit("_", 1)[0] | |
| fuzzy = [cid for cid, cluster in clusters.items() if str(cluster.get("classification") or "") == base] | |
| if len(fuzzy) == 1: | |
| return fuzzy[0] | |
| raise HTTPException(status_code=404, detail=f"Sample label not found in current state: {sample_label}") | |
def _public_sample_from_cluster(job_id: str, state: dict[str, Any], cluster_id: str, label_override: str | None = None) -> dict[str, Any]:
    """Build the sample payload describing one cluster of the supervised state.

    Only hits that exist in ``state["hits"]`` and are not suppressed count as
    active. The representative is the cluster's ``representative_hit_id``
    when that hit is active, otherwise the first active hit.

    Raises:
        HTTPException: 404 for an unknown cluster, 409 when the cluster has
            no active hits.
    """
    clusters = state.get("clusters", {})
    if cluster_id not in clusters:
        raise HTTPException(status_code=404, detail=f"Unknown cluster: {cluster_id}")
    cluster = clusters[cluster_id]
    hits = state.get("hits", {})

    active_ids = []
    for hit_id in cluster.get("hit_ids", []):
        if hit_id in hits and not hits[hit_id].get("suppressed"):
            active_ids.append(hit_id)
    if not active_ids:
        raise HTTPException(status_code=409, detail=f"Cluster {cluster.get('label', cluster_id)} has no active hits")

    preferred = cluster.get("representative_hit_id")
    rep_id = preferred if preferred in active_ids else active_ids[0]
    hit = hits[rep_id]

    # Use the part after the first ":" when there is one, as an int if it parses.
    _prefix, separator, remainder = cluster_id.partition(":")
    raw_cluster_id = remainder if separator else cluster_id
    raw_cluster_id_value: int | str
    try:
        raw_cluster_id_value = int(raw_cluster_id)
    except ValueError:
        raw_cluster_id_value = raw_cluster_id

    file_path = str(hit.get("file") or "")
    onsets = [float(hits[hit_id].get("onset_sec") or 0.0) for hit_id in active_ids]
    first_onset = min(onsets)

    return {
        "label": label_override or cluster.get("label") or str(cluster_id),
        "classification": cluster.get("classification") or str(cluster.get("label") or "other").rsplit("_", 1)[0],
        "hits": len(active_ids),
        "midi_note": cluster.get("midi_note", 60),
        "score": "edited",
        "duration_ms": round(float(hit.get("duration_ms") or 0.0), 1),
        "first_onset_sec": round(first_onset, 4),
        "representative_hit_index": int(hit.get("index") or 0),
        "state_hit_id": rep_id,
        "cluster_id": raw_cluster_id_value,
        "state_cluster_id": cluster_id,
        "file": file_path,
        "url": _job_url(job_id, file_path) if file_path else None,
    }
def get_job_events(job_id: str) -> StreamingResponse:
    """Stream job snapshots as server-sent events until the job finishes.

    Every 0.5 s the job is re-read (memory first, then the on-disk manifest)
    and emitted as a ``job`` event, but only when the serialised payload has
    changed. The stream ends once the status is ``complete`` or ``error``.

    Raises:
        HTTPException: 404 when the job is neither in memory nor on disk.
    """
    with jobs_lock:
        exists_in_memory = job_id in jobs
    # Only parse the manifest when the in-memory registry does not know the job.
    if not exists_in_memory and _read_manifest_job(job_id) is None:
        raise HTTPException(status_code=404, detail="Job not found")

    async def event_stream():
        last_payload: str | None = None
        while True:
            with jobs_lock:
                memory_job = jobs.get(job_id)
                job = dict(memory_job) if memory_job else None
            if job is None:
                job = _read_manifest_job(job_id)
            if job is None:
                # The status "error" also terminates the loop below.
                payload = {"id": job_id, "status": "error", "error": "Job disappeared"}
            else:
                payload = _serialise_job(job)
            # sort_keys gives a stable encoding, so equal states compare equal.
            encoded = json.dumps(payload, sort_keys=True)
            if encoded != last_payload:
                yield f"event: job\ndata: {encoded}\n\n"
                last_payload = encoded
            if payload.get("status") in {"complete", "error"}:
                break
            await asyncio.sleep(0.5)

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
def get_job_state(job_id: str) -> dict[str, Any]:
    """Return the public supervised state of a job (see ``_state_payload``)."""
    state_view = _state_payload(job_id)
    return state_view
def post_supervised_export(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Export the supervised state of a job and return the export plus state.

    Reads ``synthesize`` (default true), ``quantize`` and ``subdivision``
    from the request body and forwards them to ``export_supervised_state``.

    Raises:
        HTTPException: 404 for an unknown job, 500 when the export fails.
    """
    patch = _json_patch(payload)
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        export_manifest = export_supervised_state(
            output_dir,
            job_id,
            synthesize=bool(patch.get("synthesize", True)),
            quantize=patch.get("quantize"),
            subdivision=patch.get("subdivision"),
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"export": _serialise_export(job_id, export_manifest), "state": _state_payload(job_id)}
def post_selected_export(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Export only the samples whose labels are listed in the request body.

    ``labels`` entries are stringified and blank ones dropped. ``synthesize``
    (default true), ``quantize`` and ``subdivision`` are forwarded to
    ``export_selected_samples``.

    Raises:
        HTTPException: 400 when no usable label is given or the export
            raises ``ValueError``; 404 for an unknown job; 500 for any other
            export failure.
    """
    patch = _json_patch(payload)
    # "or []" lets an explicit null fall through to the 400 below instead of
    # failing while iterating None.
    labels = [str(item) for item in (patch.get("labels") or []) if str(item).strip()]
    if not labels:
        raise HTTPException(status_code=400, detail="labels must contain at least one selected sample label")
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        export_manifest = export_selected_samples(
            output_dir,
            job_id,
            selected_labels=labels,
            synthesize=bool(patch.get("synthesize", True)),
            quantize=patch.get("quantize"),
            subdivision=patch.get("subdivision"),
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"export": _serialise_export(job_id, export_manifest), "state": _state_payload(job_id)}
def post_draw_sample(job_id: str, sample_label: str) -> dict[str, Any]:
    """Draw the next representative hit for the cluster behind a sample label.

    Returns the refreshed sample payload (keeping ``sample_label`` as its
    label) together with the public state.

    Raises:
        HTTPException: passed through unchanged when raised by a helper;
            404 for ``KeyError``, 409 for ``ValueError``, 500 otherwise.
    """

    def url_for(rel):
        return _job_url(job_id, rel)

    try:
        output_dir, state = _state_for_mutation(job_id)
        cluster_id = _cluster_id_for_sample_label(state, sample_label)
        state = apply_draw_next_representative(output_dir, job_id, cluster_id, source="sample-card")
        sample = _public_sample_from_cluster(job_id, state, cluster_id, label_override=sample_label)
        return {"sample": sample, "state": public_state(state, url_for=url_for)}
    except HTTPException:
        raise
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
def post_edit_sample(job_id: str, sample_label: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Apply a timing edit to the representative hit of a sample's cluster.

    ``start_offset_ms`` and ``tail_offset_ms`` (both default 0.0) are read
    from the request body. Returns the refreshed sample payload (keeping
    ``sample_label`` as its label) together with the public state.

    Raises:
        HTTPException: 400 when an offset is not numeric or the edit raises
            ``ValueError``; 404 for ``KeyError`` or an unknown job/label;
            409 when the cluster has no active representative; 500 otherwise.
    """
    patch = _json_patch(payload)
    # Validate the offsets up front so null or non-numeric values are reported
    # as a client error rather than surfacing as an internal TypeError.
    try:
        start_offset_ms = float(patch.get("start_offset_ms", 0.0))
        tail_offset_ms = float(patch.get("tail_offset_ms", 0.0))
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    try:
        out, state = _state_for_mutation(job_id)
        cluster_id = _cluster_id_for_sample_label(state, sample_label)
        cluster = state.get("clusters", {}).get(cluster_id) or {}
        hits = state.get("hits", {})
        active_ids = [hid for hid in cluster.get("hit_ids", []) if hid in hits and not hits[hid].get("suppressed")]
        # Prefer the recorded representative while it is still active.
        rep_id = cluster.get("representative_hit_id")
        if rep_id not in active_ids:
            rep_id = active_ids[0] if active_ids else None
        if not rep_id:
            raise HTTPException(status_code=409, detail=f"Sample {sample_label} has no active representative hit")
        state = apply_hit_timing_edit(
            out,
            job_id,
            rep_id,
            start_offset_ms=start_offset_ms,
            tail_offset_ms=tail_offset_ms,
            source="sample-card",
        )
        return {"sample": _public_sample_from_cluster(job_id, state, cluster_id, label_override=sample_label), "state": public_state(state, url_for=lambda rel: _job_url(job_id, rel))}
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
def post_force_onset(job_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Force an onset at ``onset_sec`` and return the updated public state.

    Body keys: ``onset_sec`` (required), ``duration_ms``, ``label``,
    ``target_cluster_id`` and ``pre_pad_sec`` (default 0.003); all are
    forwarded to ``apply_force_onset``.

    Raises:
        HTTPException: 400 when ``onset_sec`` is missing, a numeric field is
            not numeric, or the operation raises ``ValueError``; 404 for an
            unknown job or ``KeyError``; 500 otherwise.
    """
    patch = _json_patch(payload)
    if "onset_sec" not in patch:
        raise HTTPException(status_code=400, detail="onset_sec is required")
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    # Null or non-numeric values are a client error, not an internal one.
    try:
        onset_sec = float(patch["onset_sec"])
        pre_pad_sec = float(patch.get("pre_pad_sec", 0.003))
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    try:
        apply_force_onset(
            output_dir,
            job_id,
            onset_sec,
            duration_ms=patch.get("duration_ms"),
            label=patch.get("label"),
            target_cluster_id=patch.get("target_cluster_id"),
            pre_pad_sec=pre_pad_sec,
        )
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def post_move_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Move a hit into the cluster named by ``target_cluster_id``.

    Returns the updated public state.

    Raises:
        HTTPException: 400 when ``target_cluster_id`` is missing or falsy;
            404 for an unknown job or ``KeyError``; 500 otherwise.
    """
    target_cluster_id = _json_patch(payload).get("target_cluster_id")
    if not target_cluster_id:
        raise HTTPException(status_code=400, detail="target_cluster_id is required")
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        apply_hit_move(output_dir, job_id, hit_id, str(target_cluster_id))
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def post_pull_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Pull a hit out into a new cluster, optionally named by ``label``.

    Returns the updated public state.

    Raises:
        HTTPException: 404 for an unknown job or ``KeyError``; 500 otherwise.
    """
    label = _json_patch(payload).get("label")
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        pull_hit_to_new_cluster(output_dir, job_id, hit_id, label=label)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def post_suppress_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Suppress a hit, recording ``reason`` (``"bleed"`` when not given).

    Returns the updated public state.

    Raises:
        HTTPException: 404 for an unknown job or ``KeyError``; 500 otherwise.
    """
    reason = str(_json_patch(payload).get("reason") or "bleed")
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        apply_hit_suppression(output_dir, job_id, hit_id, reason=reason)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def post_restore_hit(job_id: str, hit_id: str) -> dict[str, Any]:
    """Restore a hit via ``apply_hit_restore`` and return the updated state.

    Raises:
        HTTPException: 404 for an unknown job or ``KeyError``; 500 otherwise.
    """
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        apply_hit_restore(output_dir, job_id, hit_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def post_review_hit(job_id: str, hit_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Set the review status of a hit (``"accepted"`` when not given).

    Returns the updated public state.

    Raises:
        HTTPException: 400 when the operation raises ``ValueError``; 404 for
            an unknown job or ``KeyError``; 500 otherwise.
    """
    status = str(_json_patch(payload).get("status") or "accepted")
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        set_hit_review_status(output_dir, job_id, hit_id, status=status)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def post_lock_cluster(job_id: str, cluster_id: str, payload: dict[str, Any] = Body(default_factory=dict)) -> dict[str, Any]:
    """Lock or unlock a cluster according to ``locked`` (default true).

    Returns the updated public state.

    Raises:
        HTTPException: 404 for an unknown job or ``KeyError``; 500 otherwise.
    """
    locked = bool(_json_patch(payload).get("locked", True))
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        apply_cluster_lock(output_dir, job_id, cluster_id, locked=locked)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def get_suggestions(job_id: str) -> dict[str, Any]:
    """Return only the ``suggestions`` and ``summary`` parts of the public state.

    Missing parts default to an empty list and an empty mapping respectively.
    """
    state_view = _state_payload(job_id)
    suggestions = state_view.get("suggestions", [])
    summary = state_view.get("summary", {})
    return {"suggestions": suggestions, "summary": summary}
def post_accept_suggestion(job_id: str, suggestion_id: str) -> dict[str, Any]:
    """Accept a suggestion via ``accept_suggestion`` and return the updated state.

    Raises:
        HTTPException: 400 when the operation raises ``ValueError``; 404 for
            an unknown job or ``KeyError``; 500 otherwise.
    """
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        accept_suggestion(output_dir, job_id, suggestion_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def post_reject_suggestion(job_id: str, suggestion_id: str) -> dict[str, Any]:
    """Reject a suggestion via ``reject_suggestion`` and return the updated state.

    Raises:
        HTTPException: 404 for an unknown job or ``KeyError``; 500 otherwise.
    """
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        reject_suggestion(output_dir, job_id, suggestion_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def get_cluster_explanation(job_id: str, cluster_id: str) -> dict[str, Any]:
    """Return the explanation built by ``build_cluster_explanation`` for a cluster.

    Raises:
        HTTPException: 404 for an unknown job or when the explanation lookup
            raises ``KeyError``; 409 when loading the state raises
            ``FileNotFoundError``.
    """
    out = _job_output_dir(job_id)
    # Same mapping as the other state loaders in this module: a missing
    # manifest is a conflict the client can retry, not a server error.
    try:
        state = load_or_create_state(job_id, out)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=409, detail="Job has no manifest yet; wait until extraction completes") from exc
    try:
        return build_cluster_explanation(state, cluster_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
def post_undo(job_id: str) -> dict[str, Any]:
    """Undo the last state edit via ``apply_undo`` and return the updated state.

    Raises:
        HTTPException: 404 for an unknown job; 500 when the undo fails.
    """
    # Resolved outside the try: its 404 must not be rewrapped as a 500 by the
    # broad handler below (HTTPException is itself an Exception).
    output_dir = _job_output_dir(job_id)
    try:
        apply_undo(output_dir, job_id)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return _state_payload(job_id)
def get_job_file(job_id: str, relative_path: str) -> FileResponse:
    """Serve one file from a job's output directory.

    Args:
        job_id: Job identifier; must be a single, ordinary path component.
        relative_path: Path of the file relative to the job's output directory.

    Raises:
        HTTPException: 404 when the job id is not a plain directory name,
            the resolved path escapes the output directory, or the target is
            not an existing regular file.
    """
    # Reject ids such as "..", "." or anything containing a separator: they
    # would move the served root outside the job's own run directory.
    if job_id in {"", ".", ".."} or Path(job_id).name != job_id:
        raise HTTPException(status_code=404, detail="File not found")

    root = (RUNS_DIR / job_id / "output").resolve()
    try:
        # Resolving first collapses ".." segments and symlinks, so the
        # containment check below sees the real target.
        path = (root / relative_path).resolve()
        path.relative_to(root)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail="File not found") from exc
    if not path.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(path)
# Serve the bundled front-end assets under /web, but only when the
# directory is actually present next to this module.
if WEB_DIR.exists():
    app.mount(
        "/web",
        StaticFiles(directory=WEB_DIR),
        name="web",
    )
def index() -> FileResponse:
    """Return ``web/index.html``.

    Raises:
        HTTPException: 500 when the file does not exist.
    """
    page = WEB_DIR / "index.html"
    if page.exists():
        return FileResponse(page)
    raise HTTPException(status_code=500, detail="web/index.html is missing")