# Multi-AutoML-Interface / src/training_worker.py
# Author: PedroM2626
# Commit message: Add ONNX export utilities, pipeline parser, and PyCaret integration
# Commit: 9244b7e
"""
training_worker: thread entry point for every AutoML run.
Captures stdout/stderr, feeds log_queue, puts result into result_queue,
and respects the stop_event for graceful cancellation.
Log isolation strategy (definitive):
- We attach a _QueueLogHandler to each relevant named library logger.
- Each handler has a _ThreadFilter that only accepts log records whose
record.thread matches the experiment thread's ID.
- This means messages from Thread A never land in Thread B's queue,
even though they share the same named logger objects.
- propagate is set to False to prevent double-delivery via the root logger.
- All are restored in the finally block.
Stdout/Stderr isolation:
- redirect_stdout/redirect_stderr are process-global (they overwrite sys.stdout).
- We use a _ThreadAwareIO wrapper instead: it checks threading.current_thread()
on every write() call, so writes only reach the owning thread's queue.
"""
import io
import sys
import logging
import threading
import traceback
from src.experiment_manager import ExperimentEntry
# Names of the library loggers that run_training_worker() attaches its
# per-thread queue handler to for the duration of a training run.
_LIB_LOGGERS = [
    'flaml',
    'autogluon',
    'mlflow',
    'h2o',
    'tpot',
    'pycaret',
    'lale',
    'hyperopt',
    'lightgbm',
    'xgboost',
    'catboost',
]
# ---------------------------------------------------------------------------
# Thread-aware stdout/stderr router (installed once, process-wide)
# ---------------------------------------------------------------------------
class _ThreadAwareIO(io.TextIOBase):
    """
    Drop-in replacement for sys.stdout / sys.stderr that routes each write()
    to the queue registered for the calling thread, or falls back to the
    original stream when the calling thread has no queue registered.

    For a registered thread, every write() becomes at most one queue item:
    the text with surrounding whitespace stripped. Whitespace-only writes
    (e.g. a lone newline) are dropped.
    """

    # Progress-bar block characters that fail to encode on Windows cp1252,
    # mapped to ASCII stand-ins in a single translate() pass.
    # \u2588 is the full block character.
    _PROGRESS_CHARS = str.maketrans({'\u2588': '#', '\u258c': '|', '\u2584': '-'})

    def __init__(self, original_stream):
        """Wrap *original_stream*, the stream used for unregistered threads."""
        super().__init__()
        self._original = original_stream
        # Guards _thread_queues, which is read and written from many threads.
        self._lock = threading.Lock()
        self._thread_queues: dict[int, "queue.Queue"] = {}

    def register(self, thread_id: int, q):
        """Route all writes made by thread *thread_id* to queue *q*."""
        with self._lock:
            self._thread_queues[thread_id] = q

    def unregister(self, thread_id: int):
        """Stop routing writes for *thread_id*; unknown ids are ignored."""
        with self._lock:
            self._thread_queues.pop(thread_id, None)

    def writable(self) -> bool:
        """
        Report the stream as writable.

        io.IOBase.writable() returns False by default, which would make any
        caller that probes the stream before writing treat it as read-only
        even though write() is implemented.
        """
        return True

    def write(self, s: str) -> int:
        """
        Deliver *s* to the calling thread's queue or to the original stream.

        Non-string input is coerced with str(); if that fails nothing is
        written and 0 is returned. Otherwise the return value is len(s),
        including when the text was dropped as whitespace-only or the
        fallback stream raised.
        """
        if not isinstance(s, str):
            try:
                s = str(s)
            except Exception:
                return 0

        tid = threading.current_thread().ident
        # Hold the lock only for the lookup so a slow consumer or stream
        # never blocks register()/unregister() on other threads.
        with self._lock:
            q = self._thread_queues.get(tid)

        if q is None:
            # Fall back to original stream for threads not registered
            try:
                self._original.write(s)
            except Exception:
                # Best effort: the original stream may be missing or closed.
                pass
            return len(s)

        text = s.translate(self._PROGRESS_CHARS).strip()
        if text:
            q.put(text)
        return len(s)

    def flush(self):
        """Flush the original stream; queue delivery needs no flushing."""
        try:
            self._original.flush()
        except Exception:
            pass

    @property
    def encoding(self):
        """Encoding of the original stream, defaulting to 'utf-8'."""
        return getattr(self._original, 'encoding', 'utf-8') or 'utf-8'

    @property
    def errors(self):
        """Error-handling mode of the original stream ('replace' if absent)."""
        return getattr(self._original, 'errors', 'replace')
# Replace the process-wide stdout/stderr with thread-aware routers at import
# time. Each router wraps the interpreter's original stream
# (sys.__stdout__ / sys.__stderr__), which it uses for unregistered threads.
sys.stdout = _stdout_router = _ThreadAwareIO(sys.__stdout__)
sys.stderr = _stderr_router = _ThreadAwareIO(sys.__stderr__)
# ---------------------------------------------------------------------------
# Per-thread log handler with thread filter
# ---------------------------------------------------------------------------
class _ThreadFilter(logging.Filter):
    """
    Logging filter that lets through only the records created on one thread.

    The decision is made by comparing the ``thread`` attribute of each
    record with the thread id given at construction time.
    """

    def __init__(self, thread_id: int):
        """Remember *thread_id* as the only thread whose records pass."""
        super().__init__()
        self._thread_id = thread_id

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True when *record* was emitted by the accepted thread."""
        emitted_by = record.thread
        return emitted_by == self._thread_id
class _QueueLogHandler(logging.Handler):
    """Logging handler that pushes each formatted record onto a queue."""

    def __init__(self, log_queue):
        """Store *log_queue*, the destination for formatted messages."""
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record: logging.LogRecord):
        """
        Format *record* and put the resulting text on the queue.

        Any failure while formatting or enqueueing is swallowed so that a
        logging problem never interrupts the code that logged the message.
        """
        try:
            rendered = self.format(record)
            self.log_queue.put(rendered)
        except Exception:
            return
# ---------------------------------------------------------------------------
# Worker entry point
# ---------------------------------------------------------------------------
# The library loggers are process-global, so their original settings are
# captured when the first worker attaches and restored when the last one
# detaches. Saving/restoring per worker would let overlapping runs "restore"
# a value another worker had already modified.
_lib_logger_lock = threading.Lock()
_lib_logger_saved = {}  # logger name -> (propagate, level) before any worker
_lib_logger_users = 0   # number of workers currently attached


def _attach_lib_handler(handler: logging.Handler) -> None:
    """Add *handler* to every logger in _LIB_LOGGERS and stop propagation."""
    global _lib_logger_users
    with _lib_logger_lock:
        first_user = _lib_logger_users == 0
        _lib_logger_users += 1
        for lib in _LIB_LOGGERS:
            lib_logger = logging.getLogger(lib)
            if first_user:
                _lib_logger_saved[lib] = (lib_logger.propagate, lib_logger.level)
            lib_logger.propagate = False  # prevents root from seeing AND double-deliver
            lib_logger.addHandler(handler)
            if lib_logger.level == logging.NOTSET or lib_logger.level > logging.INFO:
                lib_logger.setLevel(logging.INFO)


def _detach_lib_handler(handler: logging.Handler) -> None:
    """Remove *handler*; the last worker to leave restores logger settings."""
    global _lib_logger_users
    with _lib_logger_lock:
        _lib_logger_users -= 1
        last_user = _lib_logger_users == 0
        for lib in _LIB_LOGGERS:
            lib_logger = logging.getLogger(lib)
            lib_logger.removeHandler(handler)
            if last_user:
                propagate, level = _lib_logger_saved.pop(lib, (True, logging.NOTSET))
                lib_logger.propagate = propagate
                lib_logger.setLevel(level)


def _normalise_result(result, framework_key) -> dict:
    """Convert whatever train_fn returned into the standard result dict."""
    if isinstance(result, dict):
        # Already in dict form: forwarded untouched.
        return result
    if not isinstance(result, tuple):
        return {"success": True, "predictor": result, "run_id": None, "type": framework_key}
    if len(result) == 2:
        predictor, run_id = result
        return {"success": True, "predictor": predictor, "run_id": run_id, "type": framework_key}
    if len(result) == 4:
        _tpot, pipeline, run_id, info = result
        return {"success": True, "predictor": pipeline, "run_id": run_id, "info": info, "type": "tpot"}
    # Any other arity: first element is the predictor, last is the run id.
    return {"success": True, "predictor": result[0], "run_id": result[-1], "type": framework_key}


def run_training_worker(entry: ExperimentEntry, train_fn, kwargs: dict):
    """
    Thread target for one training run.

    Routes this thread's stdout/stderr and library log records to
    entry.log_queue, calls train_fn(**kwargs), and puts exactly one result
    dict on entry.result_queue describing success, cancellation or failure.

    Args:
        entry: experiment record providing log_queue, result_queue,
            stop_event, telemetry_queue, metadata, key and status.
        train_fn: callable to run. ``stop_event`` and ``telemetry_queue``
            are passed as keyword arguments only if its signature names them.
        kwargs: keyword arguments for train_fn. The mapping is copied, so
            the caller's object is never modified; None means no arguments.

    A StopIteration raised by train_fn is reported as a user cancellation;
    any other Exception is reported as an error together with its traceback.
    Log routing is always undone before the thread ends.
    """
    thread_id = threading.current_thread().ident

    # --- Per-thread logging handler (with thread filter) ---
    handler = _QueueLogHandler(entry.log_queue)
    handler.setFormatter(logging.Formatter('%(message)s'))
    handler.addFilter(_ThreadFilter(thread_id))

    # --- Thread-aware stdout/stderr capture ---
    _stdout_router.register(thread_id, entry.log_queue)
    _stderr_router.register(thread_id, entry.log_queue)
    _attach_lib_handler(handler)

    try:
        entry.status = "running"
        entry.log_queue.put(f"[Worker] Starting training: {entry.metadata.get('run_name', entry.key)}")

        # Work on a copy so the injected objects never leak into the caller's dict.
        call_kwargs = dict(kwargs or {})

        # Inject stop_event and telemetry_queue if the function accepts them.
        # Best effort: if the signature cannot be inspected, train_fn is
        # simply called with whatever has been collected so far.
        try:
            import inspect
            params = inspect.signature(train_fn).parameters
            if 'stop_event' in params:
                call_kwargs['stop_event'] = entry.stop_event
            if 'telemetry_queue' in params:
                call_kwargs['telemetry_queue'] = entry.telemetry_queue
        except Exception:
            pass

        result = train_fn(**call_kwargs)
        framework_key = entry.metadata.get("framework_key", "unknown")
        entry.result_queue.put(_normalise_result(result, framework_key))
    except StopIteration:
        entry.log_queue.put("[Worker] Training cancelled by user request.")
        entry.result_queue.put({"success": False, "cancelled": True, "error": "Cancelled by user"})
    except Exception as e:
        err_tb = traceback.format_exc()
        entry.log_queue.put(f"[Worker] CRITICAL ERROR: {e}\n{err_tb}")
        entry.result_queue.put({"success": False, "error": str(e), "traceback": err_tb})
    finally:
        # Detach from the lib loggers (settings restored by the last worker)
        _detach_lib_handler(handler)
        # Unregister stdout/stderr routing for this thread
        _stdout_router.unregister(thread_id)
        _stderr_router.unregister(thread_id)
        entry.log_queue.put("[Worker] Thread finished.")