File size: 8,217 Bytes
9244b7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""
training_worker: thread entry point for every AutoML run.
Captures stdout/stderr, feeds log_queue, puts result into result_queue,
and respects the stop_event for graceful cancellation.

Log isolation strategy (definitive):
  - We attach a _QueueLogHandler to each relevant named library logger.
  - Each handler has a _ThreadFilter that only accepts log records whose
    record.thread matches the experiment thread's ID.
  - This means messages from Thread A never land in Thread B's queue,
    even though they share the same named logger objects.
  - propagate is set to False to prevent double-delivery via the root logger.
  - The handler is removed and each logger's propagate flag is restored in the
    worker's finally block.

Stdout/Stderr isolation:
  - redirect_stdout/redirect_stderr are process-global (they overwrite sys.stdout).
  - We use a _ThreadAwareIO wrapper instead: it checks threading.current_thread()
    on every write() call, so writes only reach the owning thread's queue.
"""
import io
import sys
import logging
import threading
import traceback

from src.experiment_manager import ExperimentEntry

# Names of the third-party library loggers whose records are routed into the
# per-experiment log queue while a worker is running.
_LIB_LOGGERS = [
    "flaml",
    "autogluon",
    "mlflow",
    "h2o",
    "tpot",
    "pycaret",
    "lale",
    "hyperopt",
    "lightgbm",
    "xgboost",
    "catboost",
]


# ---------------------------------------------------------------------------
# Thread-aware stdout/stderr router (installed once, process-wide)
# ---------------------------------------------------------------------------
class _ThreadAwareIO(io.TextIOBase):
    """
    Drop-in replacement for sys.stdout / sys.stderr that routes each write()
    to the queue registered for the current thread, or falls back to the
    original stream when the calling thread has no queue registered.

    Queue-bound text is stripped of surrounding whitespace, whitespace-only
    writes are dropped, and a few block-drawing glyphs used by progress bars
    are replaced with ASCII stand-ins before the text is enqueued.
    """

    # Progress-bar glyphs (full block, left half block, lower half block) that
    # fail to encode on Windows cp1252 consoles, mapped to ASCII in one pass.
    _SAFE_CHARS = str.maketrans({'\u2588': '#', '\u258c': '|', '\u2584': '-'})

    def __init__(self, original_stream):
        """Wrap *original_stream*, which receives writes from unregistered threads."""
        super().__init__()
        self._original = original_stream
        self._lock = threading.Lock()  # guards _thread_queues
        self._thread_queues: dict[int, "queue.Queue"] = {}

    def register(self, thread_id: int, q):
        """Route subsequent writes made by thread *thread_id* into queue *q*."""
        with self._lock:
            self._thread_queues[thread_id] = q

    def unregister(self, thread_id: int):
        """Stop routing writes for *thread_id*; a no-op if it was never registered."""
        with self._lock:
            self._thread_queues.pop(thread_id, None)

    def writable(self) -> bool:
        """Report the stream as writable.

        io.IOBase.writable() returns False by default, which would make code
        that checks the stream before writing skip its output.
        """
        return True

    def write(self, s: str) -> int:
        """Write *s* to the calling thread's queue or to the original stream.

        Non-string input is converted with str(); if that conversion fails
        nothing is written and 0 is returned. Otherwise the length of the
        (converted) input is returned, even when the text was dropped.
        """
        if not isinstance(s, str):
            try:
                s = str(s)
            except Exception:
                return 0

        tid = threading.current_thread().ident
        with self._lock:
            q = self._thread_queues.get(tid)

        if q is None:
            # Fall back to original stream for threads not registered.
            # Best-effort: the original stream may be missing or closed.
            try:
                self._original.write(s)
            except Exception:
                pass
            return len(s)

        text = s.strip()
        if text:
            q.put(text.translate(self._SAFE_CHARS))
        return len(s)

    def flush(self):
        """Flush the original stream, ignoring any error it raises."""
        try:
            self._original.flush()
        except Exception:
            pass

    @property
    def encoding(self):
        """Encoding of the original stream, or 'utf-8' when it has none."""
        return getattr(self._original, 'encoding', 'utf-8') or 'utf-8'

    @property
    def errors(self):
        """Error handler of the original stream ('replace' if it lacks the attribute)."""
        return getattr(self._original, 'errors', 'replace')


# Build one router per standard stream (each wrapping the interpreter's
# original stream) and install both process-wide, once, at import time.
_stdout_router, _stderr_router = _ThreadAwareIO(sys.__stdout__), _ThreadAwareIO(sys.__stderr__)
sys.stdout, sys.stderr = _stdout_router, _stderr_router


# ---------------------------------------------------------------------------
# Per-thread log handler with thread filter
# ---------------------------------------------------------------------------
class _ThreadFilter(logging.Filter):
    """Logging filter that passes only records emitted by one given thread."""

    def __init__(self, thread_id: int):
        """Remember *thread_id*, the identifier records must carry to pass."""
        super().__init__()
        self._thread_id = thread_id

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True when *record* was created on the remembered thread."""
        emitted_by = record.thread
        return emitted_by == self._thread_id


class _QueueLogHandler(logging.Handler):
    """Logging handler that pushes each formatted record onto a queue."""

    def __init__(self, log_queue):
        """Store *log_queue*, the object whose put() receives formatted messages."""
        super().__init__()
        self.log_queue = log_queue

    def emit(self, record: logging.LogRecord):
        """Format *record* and enqueue the text.

        Best-effort: any error raised while formatting or enqueueing is
        swallowed so that logging never breaks the code being logged.
        """
        try:
            self.log_queue.put(self.format(record))
        except Exception:
            pass


# ---------------------------------------------------------------------------
# Worker entry point
# ---------------------------------------------------------------------------
# The library loggers are process-global objects shared by every worker thread,
# so overrides of their propagate/level settings are reference-counted under a
# lock: the first active worker snapshots the settings and the last one to
# finish restores them. Saving/restoring per worker would let a second worker
# snapshot the already-overridden values and leave them in place for good.
_LOGGER_STATE_LOCK = threading.Lock()
_LOGGER_STATE = {"active": 0, "saved": {}}  # saved: logger name -> (propagate, level)


def _attach_lib_log_handler(handler):
    """Attach *handler* to every logger in _LIB_LOGGERS and apply the capture overrides."""
    with _LOGGER_STATE_LOCK:
        saved = _LOGGER_STATE["saved"]
        _LOGGER_STATE["active"] += 1
        for lib in _LIB_LOGGERS:
            lib_logger = logging.getLogger(lib)
            # Only the first active worker records the settings; later workers
            # would otherwise record values that were already overridden.
            saved.setdefault(lib, (lib_logger.propagate, lib_logger.level))
            lib_logger.propagate = False  # prevents root from seeing AND double-deliver
            lib_logger.addHandler(handler)
            if lib_logger.level == logging.NOTSET or lib_logger.level > logging.INFO:
                lib_logger.setLevel(logging.INFO)


def _detach_lib_log_handler(handler):
    """Remove *handler* from the library loggers; the last active worker restores their settings."""
    with _LOGGER_STATE_LOCK:
        _LOGGER_STATE["active"] = max(0, _LOGGER_STATE["active"] - 1)
        last_worker = _LOGGER_STATE["active"] == 0
        saved = _LOGGER_STATE["saved"]
        for lib in _LIB_LOGGERS:
            lib_logger = logging.getLogger(lib)
            lib_logger.removeHandler(handler)
            if last_worker and lib in saved:
                propagate, level = saved.pop(lib)
                lib_logger.propagate = propagate
                lib_logger.setLevel(level)
        if last_worker:
            saved.clear()


def _normalise_result(result, framework_key):
    """Convert the value returned by the training function into a result dict.

    A dict is passed through unchanged. A 4-tuple is unpacked as
    (tpot, pipeline, run_id, info). Any other tuple contributes its first item
    as the predictor and its last item as the run id (for a 2-tuple that is
    (predictor, run_id)). Any other value is treated as the predictor itself.
    """
    if isinstance(result, dict):
        return result
    if not isinstance(result, tuple):
        return {
            "success": True, "predictor": result, "run_id": None,
            "type": framework_key
        }
    if len(result) == 4:
        _tpot, pipeline, run_id, info = result
        return {
            "success": True, "predictor": pipeline, "run_id": run_id, "info": info, "type": "tpot"
        }
    return {
        "success": True, "predictor": result[0], "run_id": result[-1],
        "type": framework_key
    }


def run_training_worker(entry: ExperimentEntry, train_fn, kwargs: dict):
    """
    Thread target. Runs train_fn(**kwargs) and puts the final result dict in
    entry.result_queue.

    While train_fn runs, this thread's stdout/stderr writes and the records it
    emits on the loggers listed in _LIB_LOGGERS are routed to entry.log_queue.

    Args:
        entry: experiment record; this function sets entry.status to "running"
            and uses its log_queue, result_queue, metadata, key, stop_event and
            telemetry_queue attributes.
        train_fn: the training callable. If its signature declares a
            'stop_event' and/or 'telemetry_queue' parameter, the matching
            attribute of *entry* is passed in under that name.
        kwargs: keyword arguments for train_fn. The dict is copied before the
            extra arguments are injected, so the caller's dict is not modified.

    Nothing is returned and exceptions derived from Exception are not
    propagated: StopIteration is reported as a cancellation and any other
    Exception as a failure, both through entry.result_queue.
    """
    thread_id = threading.current_thread().ident

    # --- Thread-aware stdout/stderr capture ---
    _stdout_router.register(thread_id, entry.log_queue)
    _stderr_router.register(thread_id, entry.log_queue)

    # --- Per-thread logging handler (with thread filter) ---
    handler = _QueueLogHandler(entry.log_queue)
    handler.setFormatter(logging.Formatter('%(message)s'))
    handler.addFilter(_ThreadFilter(thread_id))
    _attach_lib_log_handler(handler)

    try:
        # Everything after resource setup runs inside the try so the finally
        # block always releases the handler and the stream routing.
        entry.status = "running"
        entry.log_queue.put(f"[Worker] Starting training: {entry.metadata.get('run_name', entry.key)}")

        # Work on a copy so the injected arguments never leak into the caller's dict.
        call_kwargs = dict(kwargs)

        # Inject stop_event and telemetry_queue if the function accepts them.
        # Best-effort: callables without an inspectable signature are simply
        # called with the arguments as given.
        try:
            import inspect
            params = inspect.signature(train_fn).parameters
            if 'stop_event' in params:
                call_kwargs['stop_event'] = entry.stop_event
            if 'telemetry_queue' in params:
                call_kwargs['telemetry_queue'] = entry.telemetry_queue
        except Exception:
            pass

        result = train_fn(**call_kwargs)

        # Normalise result into a standard dict
        framework_key = entry.metadata.get("framework_key", "unknown")
        entry.result_queue.put(_normalise_result(result, framework_key))

    except StopIteration:
        entry.log_queue.put("[Worker] Training cancelled by user request.")
        entry.result_queue.put({"success": False, "cancelled": True, "error": "Cancelled by user"})
    except Exception as e:
        err_tb = traceback.format_exc()
        entry.log_queue.put(f"[Worker] CRITICAL ERROR: {e}\n{err_tb}")
        entry.result_queue.put({"success": False, "error": str(e), "traceback": err_tb})
    finally:
        # Detach from the lib loggers (settings are restored by the last worker)
        _detach_lib_log_handler(handler)

        # Unregister stdout/stderr routing for this thread
        _stdout_router.unregister(thread_id)
        _stderr_router.unregister(thread_id)

        entry.log_queue.put("[Worker] Thread finished.")