"""
ExperimentManager: central registry for all training runs.
Stored as a singleton in st.session_state['exp_manager'].
"""
import threading
import queue
import time
import logging
from dataclasses import dataclass, field
from typing import Optional, Any
logger = logging.getLogger(__name__)
@dataclass
class ExperimentEntry:
    """State and communication channels for one training run.

    The three queues are read by ``drain_logs`` and ``check_result``, which
    fold whatever is waiting into the plain attributes (``all_logs``,
    ``latest_telemetry``, ``result``, ``status``). The producer side of the
    queues is not part of this class (presumably the worker thread held in
    ``thread`` -- confirm against the code that starts the run).
    """

    key: str  # unique slug: "autogluon_1712345678"
    metadata: dict  # framework, run_name, config snapshot
    thread: Optional[threading.Thread] = field(default=None, repr=False)
    stop_event: threading.Event = field(default_factory=threading.Event, repr=False)
    log_queue: queue.Queue = field(default_factory=queue.Queue, repr=False)
    telemetry_queue: queue.Queue = field(default_factory=queue.Queue, repr=False)
    result_queue: queue.Queue = field(default_factory=queue.Queue, repr=False)
    status: str = "queued"  # queued | running | completed | failed | cancelled
    started_at: float = field(default_factory=time.time)
    finished_at: Optional[float] = None
    result: Optional[dict] = None  # {predictor, run_id, type, ...} or {error: str}
    all_logs: list = field(default_factory=list)
    latest_telemetry: dict = field(default_factory=dict)
    last_update: float = field(default_factory=time.time)

    @staticmethod
    def _pop_all(source: queue.Queue) -> list:
        """Remove and return every item currently waiting in *source*."""
        items = []
        while True:
            try:
                items.append(source.get_nowait())
            except queue.Empty:
                return items

    def elapsed_str(self) -> str:
        """Return the run time as ``"<minutes>m <seconds>s"``.

        Measures up to ``finished_at`` when it is set, otherwise up to now.
        """
        # Compare against None explicitly: a float timestamp should not be
        # treated as "missing" just because it is falsy.
        end = self.finished_at if self.finished_at is not None else time.time()
        # Clamp so a clock adjustment can never render a negative duration.
        secs = max(0, int(end - self.started_at))
        minutes, seconds = divmod(secs, 60)
        return f"{minutes}m {seconds:02d}s"

    def status_icon(self) -> str:
        """Return an emoji for the current status, or a question mark if unknown."""
        # Written as escapes so the glyphs survive any re-encoding of this file.
        # NOTE(review): the literals arrived mis-decoded and were reconstructed
        # from the bytes that survived; "failed" and the fallback glyph are the
        # least certain -- confirm against the original file.
        icons = {
            "queued": "\u23f3",  # hourglass with flowing sand
            "running": "\U0001f7e2",  # large green circle
            "completed": "\u2705",  # white heavy check mark
            "failed": "\u274c",  # cross mark
            "cancelled": "\U0001f6ab",  # no entry sign
        }
        return icons.get(self.status, "\u2753")  # question mark ornament

    def drain_logs(self) -> bool:
        """Pull all pending log lines and telemetry into the entry.

        Returns:
            True if at least one log line or telemetry dict was absorbed;
            ``last_update`` is refreshed in that case.
        """
        lines = self._pop_all(self.log_queue)
        self.all_logs.extend(lines)
        changed = bool(lines)

        for payload in self._pop_all(self.telemetry_queue):
            # Telemetry items that are not dicts are discarded.
            if isinstance(payload, dict):
                self.latest_telemetry.update(payload)
                changed = True

        if changed:
            self.last_update = time.time()
        return changed

    def check_result(self):
        """Non-blocking check: pull a result from the queue if one is available.

        A result whose ``"success"`` value is truthy marks the entry
        ``"completed"``; anything else marks it ``"failed"``. Does nothing
        when the queue is empty.
        """
        try:
            res = self.result_queue.get_nowait()
        except queue.Empty:
            return

        if not isinstance(res, dict):
            # The item has already been consumed, so raising here would leave
            # the entry stuck in its current status; record a failure instead.
            res = {"success": False, "error": f"Unexpected result payload: {res!r}"}

        now = time.time()
        self.result = res
        self.status = "completed" if res.get("success") else "failed"
        self.finished_at = now
        self.last_update = now
class ExperimentManager:
    """In-process registry of all AutoML experiments.

    Entries are stored by their ``key``. Access to the registry dict goes
    through ``self._lock``; the entries themselves are mutated without it.
    """

    def __init__(self):
        self._runs: dict[str, "ExperimentEntry"] = {}
        self._lock = threading.Lock()

    @staticmethod
    def _mark_cancelled(entry) -> None:
        """Signal the entry's stop event and stamp it as cancelled."""
        now = time.time()
        entry.stop_event.set()
        entry.status = "cancelled"
        entry.finished_at = now
        entry.last_update = now

    def add(self, entry: "ExperimentEntry") -> str:
        """Register *entry* under ``entry.key`` (replacing any previous one) and return the key."""
        with self._lock:
            self._runs[entry.key] = entry
        return entry.key

    def cancel(self, key: str):
        """Request graceful cancellation of a running experiment.

        Unknown keys and entries that are not in the ``"running"`` state are
        left untouched.
        """
        with self._lock:
            entry = self._runs.get(key)
            if entry and entry.status == "running":
                self._mark_cancelled(entry)
                logger.info("Cancel requested for experiment: %s", key)

    def delete(self, key: str):
        """Remove an experiment from the registry.

        An entry that is still ``"running"`` is cancelled first (its stop
        event is set) and then removed. Unknown keys are ignored.
        """
        with self._lock:
            entry = self._runs.get(key)
            if entry and entry.status == "running":
                self._mark_cancelled(entry)
            self._runs.pop(key, None)

    def get(self, key: str) -> Optional["ExperimentEntry"]:
        """Return the entry registered under *key*, or None."""
        with self._lock:
            return self._runs.get(key)

    def get_all(self) -> list["ExperimentEntry"]:
        """Return all experiments newest-first (by ``started_at``)."""
        with self._lock:
            entries = list(self._runs.values())
        return sorted(entries, key=lambda e: e.started_at, reverse=True)

    def has_running(self) -> bool:
        """Return True if any registered experiment has status ``"running"``."""
        # Order is irrelevant here, so skip the sort that get_all() performs.
        with self._lock:
            return any(e.status == "running" for e in self._runs.values())

    def refresh_all(self):
        """Sync status/logs/results for all experiments.

        Logs and telemetry are drained for every entry. Entries that are
        still ``"queued"`` or ``"running"`` are polled for a result, and a
        running entry whose thread has exited without delivering one is
        marked ``"failed"``.
        """
        for entry in self.get_all():
            entry.drain_logs()
            if entry.status not in ("running", "queued"):
                continue

            worker = getattr(entry, "thread", None)
            # Sample liveness BEFORE polling the result queue. If the thread
            # is already dead at this point, anything it published is
            # guaranteed to be in the queue, so an empty poll really means it
            # died without a result. Checking liveness after the poll could
            # misreport a run that finished in between as a crash.
            worker_dead = worker is not None and not worker.is_alive()

            entry.check_result()

            if worker_dead and entry.status == "running" and entry.result is None:
                now = time.time()
                entry.status = "failed"
                entry.result = {"success": False, "error": "Thread terminated unexpectedly"}
                entry.finished_at = now
                entry.last_update = now
def get_or_create_manager(session_state) -> ExperimentManager:
    """Return the ExperimentManager kept in Streamlit session state.

    A new manager is created and stored under ``'exp_manager'`` when that
    slot is missing or holds something that is not an ExperimentManager.
    """
    current = session_state.get('exp_manager')
    if isinstance(current, ExperimentManager):
        return current

    session_state['exp_manager'] = ExperimentManager()
    return session_state['exp_manager']
|