File size: 7,776 Bytes
1330e26
9244b7e
1330e26
 
 
 
 
 
 
 
9244b7e
1330e26
 
 
0cb14f5
 
9244b7e
 
 
 
1330e26
 
 
 
9c720d9
1330e26
 
 
 
 
 
 
9244b7e
 
 
 
 
 
 
 
1330e26
 
9c720d9
1330e26
 
 
 
 
 
 
 
 
 
 
 
0cb14f5
 
 
9c720d9
0cb14f5
 
 
9244b7e
0cb14f5
 
 
 
9c720d9
0cb14f5
 
1330e26
 
 
 
 
 
 
 
 
 
 
 
9244b7e
9c720d9
1330e26
9244b7e
 
1330e26
9c720d9
 
 
 
0cb14f5
 
 
 
9244b7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1330e26
9c720d9
1330e26
 
9c720d9
1330e26
9c720d9
1330e26
9c720d9
1330e26
9244b7e
 
 
1330e26
 
 
9c720d9
1330e26
 
 
 
 
 
 
 
 
 
 
 
9244b7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1330e26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c720d9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import os
import threading
import pandas as pd
import mlflow
import shutil
import logging
from flaml import AutoML
import matplotlib.pyplot as plt
import time
from src.mlflow_utils import safe_set_experiment
from src.onnx_utils import export_to_onnx

logger = logging.getLogger(__name__)

def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str, 
                      valid_data: pd.DataFrame = None, test_data: pd.DataFrame = None,
                       time_budget: int = 60, task: str = 'classification', metric: str = 'auto',
                       estimator_list: list = 'auto', seed: int = 42, cv_folds: int = 0,
                       n_jobs: int = 1,
                       stop_event=None, telemetry_queue=None):
    """
    Trains a FLAML model and logs results to MLflow.

    Parameters:
        train_data: Training dataframe; rows with NaN target are dropped.
        target: Name of the target column.
        run_name: MLflow run name, also used in saved-model filenames.
        valid_data: Optional validation dataframe (must contain `target`).
        test_data: Optional test dataframe (must contain `target`); only its
            presence is logged here, it is not evaluated in this function.
        time_budget: Search budget in seconds; None omits the budget setting.
        task: FLAML task type (e.g. 'classification', 'regression').
        metric: Optimization metric ('auto' lets FLAML pick one).
        estimator_list: Estimators to try ('auto' or an explicit list).
        seed: Random seed for reproducibility.
        cv_folds: If > 0, use cross-validation with this many splits.
        n_jobs: Parallel jobs per trial.
        stop_event: Optional threading.Event; when set, a watcher thread
            zeroes FLAML's remaining time budget to stop the search.
        telemetry_queue: Optional queue that receives per-trial progress dicts.

    Returns:
        (automl, run_id): the fitted AutoML object and the MLflow run id.

    Raises:
        ValueError: if `target` is missing from valid_data or test_data.
        RuntimeError: if the search stops without finding a valid model.
        StopIteration: if the user cancelled training via `stop_event`.
    """
    safe_set_experiment("FLAML_Experiments")
    logger.info(f"Starting FLAML training for run: {run_name}")
    
    # Ensure flaml logger is also at INFO level so its progress reaches our
    # handlers. (AutoML is already imported at module level; no re-import.)
    flaml_logger = logging.getLogger('flaml')
    flaml_logger.setLevel(logging.INFO)
    
    # Ensure no leaked runs in this thread. Catch only Exception (not a bare
    # except) so KeyboardInterrupt/SystemExit still propagate.
    try:
        if mlflow.active_run():
            mlflow.end_run()
    except Exception:
        logger.debug("Could not end a previously active MLflow run", exc_info=True)

    with mlflow.start_run(run_name=run_name, nested=True) as run:
        # Data cleaning: drop rows where target is NaN
        train_data = train_data.dropna(subset=[target])
        logger.info(f"Data ready: {len(train_data)} rows.")
        
        # Log search parameters for reproducibility.
        mlflow.log_param("target", target)
        mlflow.log_param("time_budget", time_budget)
        mlflow.log_param("task", task)
        mlflow.log_param("metric", metric)
        mlflow.log_param("estimator_list", str(estimator_list))
        mlflow.log_param("seed", seed)
        
        X_train = train_data.drop(columns=[target])
        y_train = train_data[target]
        
        X_val, y_val = None, None
        if valid_data is not None:
            if target not in valid_data.columns:
                raise ValueError(f"Target column '{target}' not found in Validation data.")
            valid_data = valid_data.dropna(subset=[target])
            X_val = valid_data.drop(columns=[target])
            y_val = valid_data[target]

            mlflow.log_param("has_validation_data", True)
            
        if test_data is not None:
            if target not in test_data.columns:
                raise ValueError(f"Target column '{target}' not found in Test data.")
            mlflow.log_param("has_test_data", True)
        
        automl = AutoML()
        
        # Note: We are NOT using low_cost_partial_config because it causes 
        # TypeError in some estimators (like LGBM) when passed via automl.fit.
        # The 'No low-cost partial config given' message is just an INFO warning from FLAML.

        settings = {
            "metric": metric,
            "task": task,
            "estimator_list": estimator_list,
            "log_file_name": "flaml.log",
            "seed": seed,
            "n_jobs": n_jobs,
            "verbose": 0, # Reduce internal verbosity to avoid pollution, progress goes to flaml.log
        }
        if time_budget is not None:
            settings["time_budget"] = time_budget
        
        if cv_folds > 0:
            settings["eval_method"] = "cv"
            settings["n_splits"] = cv_folds
            
        if X_val is not None:
            settings["X_val"] = X_val
            settings["y_val"] = y_val
        
        # Start a watcher thread to respect stop_event: once the event fires,
        # zeroing the internal time budget makes FLAML wind down the search.
        _cancel_watcher = None
        if stop_event is not None:
            def _watch():
                stop_event.wait()
                try:
                    automl._state.time_budget = 0  # Signal FLAML to stop
                except Exception:
                    pass
            _cancel_watcher = threading.Thread(target=_watch, daemon=True)
            _cancel_watcher.start()

        # Custom callback for telemetry; best-effort, must never break the fit.
        def _telemetry_callback(iter_count, time_used, best_loss, best_config, estimator, trial_id):
            try:
                if telemetry_queue:
                    telemetry_queue.put({
                        "status": "running",
                        "iterations": iter_count,
                        "time_used": time_used,
                        "best_loss": best_loss,
                        "best_estimator": str(estimator),
                        "best_config_preview": str(best_config)[:200]
                    })
            except Exception:
                pass

        if telemetry_queue:
            settings["callbacks"] = [_telemetry_callback]

        # Train model
        logger.info("Executing hyperparameter search (automl.fit)...")
        try:
            automl.fit(X_train=X_train, y_train=y_train, **settings)
            logger.info("Search finished successfully.")
        except StopIteration:
            logger.info("Search interrupted (time limit reached).")
            if not hasattr(automl, 'best_estimator') or automl.best_estimator is None:
                raise RuntimeError("FLAML stopped without finding a valid model.")
        
        if stop_event and stop_event.is_set():
            raise StopIteration("Training cancelled by user")
        
        # Log metrics
        if hasattr(automl, 'best_loss'):
            mlflow.log_metric("best_loss", automl.best_loss)
            logger.info(f"Best final Loss: {automl.best_loss:.4f}")
        
        # Save best model (pickled AutoML object; loaded back by load_flaml_model).
        model_path = os.path.join("models", f"flaml_{run_name}.pkl")
        os.makedirs("models", exist_ok=True)
        import pickle
        with open(model_path, "wb") as f:
            pickle.dump(automl, f)
            
        # Log as artifact
        mlflow.log_artifact(model_path, artifact_path="model")
        mlflow.log_param("model_type", "flaml")
        
        # ONNX Export (best-effort; failure is logged, not fatal)
        try:
            onnx_path = os.path.join("models", f"flaml_{run_name}.onnx")
            # For FLAML, we can often export the underlying best estimator or the AutoML object if it's scikit-learn compatible
            # We pass X_train[:1] as sample input for shape inference
            export_to_onnx(automl.model.estimator, "flaml", target, onnx_path, input_sample=X_train[:1])
            mlflow.log_artifact(onnx_path, artifact_path="model")
        except Exception as e:
            logger.warning(f"Failed to export FLAML model to ONNX: {e}")
        
        # Generate and log consumption code sample (best-effort)
        try:
            from src.code_gen_utils import generate_consumption_code
            code_sample = generate_consumption_code("flaml", run.info.run_id, target)
            code_path = "consumption_sample.py"
            with open(code_path, "w") as f:
                f.write(code_sample)
            mlflow.log_artifact(code_path)
            if os.path.exists(code_path):
                os.remove(code_path)
        except Exception as e:
            logger.warning(f"Failed to generate consumption code: {e}")
            
        # Log training log as artifact
        if os.path.exists("flaml.log"):
            mlflow.log_artifact("flaml.log")
            
        return automl, run.info.run_id

def load_flaml_model(run_id: str):
    """Download the 'model' artifacts of an MLflow run and unpickle the
    first .pkl file found inside the downloaded folder.

    Raises FileNotFoundError when no .pkl artifact exists for the run.
    """
    import mlflow
    import pickle
    artifact_dir = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path="model")
    # Scan the downloaded tree for the pickled AutoML object.
    for folder, _subdirs, filenames in os.walk(artifact_dir):
        pkl_names = [name for name in filenames if name.endswith(".pkl")]
        if pkl_names:
            with open(os.path.join(folder, pkl_names[0]), "rb") as handle:
                return pickle.load(handle)
    raise FileNotFoundError("FLAML model not found in artifacts.")