import os
import threading
import pandas as pd
import mlflow
import shutil
import logging
from flaml import AutoML
import matplotlib.pyplot as plt
import time
from src.mlflow_utils import safe_set_experiment
from src.onnx_utils import export_to_onnx
logger = logging.getLogger(__name__)
def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
                      valid_data: pd.DataFrame = None, test_data: pd.DataFrame = None,
                      time_budget: int = 60, task: str = 'classification', metric: str = 'auto',
                      estimator_list: list = 'auto', seed: int = 42, cv_folds: int = 0,
                      n_jobs: int = 1,
                      stop_event=None, telemetry_queue=None):
    """
    Trains a FLAML model inside a nested MLflow run and logs the results.

    Logs hyperparameters, the best loss, the pickled AutoML object, a
    best-effort ONNX export, a generated consumption-code sample and the
    FLAML search log as MLflow artifacts.

    Args:
        train_data: Training dataframe; rows with a NaN target are dropped.
        target: Name of the target column.
        run_name: MLflow run name; also used in model artifact filenames.
        valid_data: Optional validation dataframe (must contain `target`).
        test_data: Optional test dataframe (must contain `target`); only a
            presence flag is logged here.
        time_budget: Search time budget in seconds (None = no budget set).
        task: FLAML task type, e.g. 'classification' or 'regression'.
        metric: Optimization metric; 'auto' lets FLAML choose.
        estimator_list: Estimators to try ('auto' or a list of names).
        seed: Random seed for reproducibility.
        cv_folds: If > 0, use cross-validation with this many splits.
        n_jobs: Parallel jobs for individual estimators.
        stop_event: Optional threading.Event; when set, the search budget is
            zeroed so FLAML stops, and the run is reported as cancelled.
        telemetry_queue: Optional queue receiving per-iteration progress dicts.

    Returns:
        (automl, run_id): the fitted AutoML object and the MLflow run id.

    Raises:
        ValueError: if `target` is missing from valid_data/test_data.
        RuntimeError: if FLAML stopped without finding any valid model.
        StopIteration: if the user cancelled training via `stop_event`.
    """
    safe_set_experiment("FLAML_Experiments")
    logger.info(f"Starting FLAML training for run: {run_name}")

    # Ensure flaml's own logger is also at INFO level so its progress shows up.
    flaml_logger = logging.getLogger('flaml')
    flaml_logger.setLevel(logging.INFO)

    # Ensure no leaked runs in this thread before opening a nested run.
    try:
        if mlflow.active_run():
            mlflow.end_run()
    except Exception:
        # Best-effort cleanup only; a failure here must not abort training.
        pass

    with mlflow.start_run(run_name=run_name, nested=True) as run:
        # Data cleaning: drop rows where target is NaN
        train_data = train_data.dropna(subset=[target])
        logger.info(f"Data ready: {len(train_data)} rows.")

        # Log parameters
        mlflow.log_param("target", target)
        mlflow.log_param("time_budget", time_budget)
        mlflow.log_param("task", task)
        mlflow.log_param("metric", metric)
        mlflow.log_param("estimator_list", str(estimator_list))
        mlflow.log_param("seed", seed)

        X_train = train_data.drop(columns=[target])
        y_train = train_data[target]

        X_val, y_val = None, None
        if valid_data is not None:
            if target not in valid_data.columns:
                raise ValueError(f"Target column '{target}' not found in Validation data.")
            valid_data = valid_data.dropna(subset=[target])
            X_val = valid_data.drop(columns=[target])
            y_val = valid_data[target]
            mlflow.log_param("has_validation_data", True)

        if test_data is not None:
            if target not in test_data.columns:
                raise ValueError(f"Target column '{target}' not found in Test data.")
            mlflow.log_param("has_test_data", True)

        automl = AutoML()
        # Note: We are NOT using low_cost_partial_config because it causes
        # TypeError in some estimators (like LGBM) when passed via automl.fit.
        # The 'No low-cost partial config given' message is just an INFO warning from FLAML.
        settings = {
            "metric": metric,
            "task": task,
            "estimator_list": estimator_list,
            "log_file_name": "flaml.log",
            "seed": seed,
            "n_jobs": n_jobs,
            "verbose": 0,  # Reduce internal verbosity to avoid pollution, progress goes to flaml.log
        }
        if time_budget is not None:
            settings["time_budget"] = time_budget
        if cv_folds > 0:
            settings["eval_method"] = "cv"
            settings["n_splits"] = cv_folds
        if X_val is not None:
            settings["X_val"] = X_val
            settings["y_val"] = y_val

        # Start a watcher thread to respect stop_event: when the event fires,
        # zeroing the remaining time budget signals FLAML to stop searching.
        _cancel_watcher = None
        if stop_event is not None:
            def _watch():
                stop_event.wait()
                try:
                    automl._state.time_budget = 0  # Signal FLAML to stop
                except Exception:
                    pass
            _cancel_watcher = threading.Thread(target=_watch, daemon=True)
            _cancel_watcher.start()

        # Custom callback for telemetry; must never raise into FLAML's loop.
        def _telemetry_callback(iter_count, time_used, best_loss, best_config, estimator, trial_id):
            try:
                if telemetry_queue:
                    telemetry_queue.put({
                        "status": "running",
                        "iterations": iter_count,
                        "time_used": time_used,
                        "best_loss": best_loss,
                        "best_estimator": str(estimator),
                        "best_config_preview": str(best_config)[:200]
                    })
            except Exception:
                pass

        if telemetry_queue:
            settings["callbacks"] = [_telemetry_callback]

        # Train model
        logger.info("Executing hyperparameter search (automl.fit)...")
        try:
            automl.fit(X_train=X_train, y_train=y_train, **settings)
            logger.info("Search finished successfully.")
        except StopIteration:
            logger.info("Search interrupted (time limit reached).")

        if not hasattr(automl, 'best_estimator') or automl.best_estimator is None:
            raise RuntimeError("FLAML stopped without finding a valid model.")

        if stop_event and stop_event.is_set():
            raise StopIteration("Training cancelled by user")

        # Log metrics
        if hasattr(automl, 'best_loss'):
            mlflow.log_metric("best_loss", automl.best_loss)
            logger.info(f"Best final Loss: {automl.best_loss:.4f}")

        # Save best model (whole AutoML object, so preprocessing is kept)
        model_path = os.path.join("models", f"flaml_{run_name}.pkl")
        os.makedirs("models", exist_ok=True)
        import pickle
        with open(model_path, "wb") as f:
            pickle.dump(automl, f)

        # Log as artifact
        mlflow.log_artifact(model_path, artifact_path="model")
        mlflow.log_param("model_type", "flaml")

        # ONNX Export (best effort: not every estimator is convertible)
        try:
            onnx_path = os.path.join("models", f"flaml_{run_name}.onnx")
            # For FLAML, we can often export the underlying best estimator or the AutoML object if it's scikit-learn compatible
            # We pass X_train[:1] as sample input for shape inference
            export_to_onnx(automl.model.estimator, "flaml", target, onnx_path, input_sample=X_train[:1])
            mlflow.log_artifact(onnx_path, artifact_path="model")
        except Exception as e:
            logger.warning(f"Failed to export FLAML model to ONNX: {e}")

        # Generate and log consumption code sample
        try:
            from src.code_gen_utils import generate_consumption_code
            code_sample = generate_consumption_code("flaml", run.info.run_id, target)
            code_path = "consumption_sample.py"
            with open(code_path, "w") as f:
                f.write(code_sample)
            mlflow.log_artifact(code_path)
            if os.path.exists(code_path):
                os.remove(code_path)
        except Exception as e:
            logger.warning(f"Failed to generate consumption code: {e}")

        # Log training log as artifact
        if os.path.exists("flaml.log"):
            mlflow.log_artifact("flaml.log")

        return automl, run.info.run_id
def load_flaml_model(run_id: str):
    """Download the 'model' artifact folder of an MLflow run and unpickle
    the FLAML model found inside it.

    Args:
        run_id: MLflow run id whose artifacts contain the pickled model.

    Returns:
        The unpickled AutoML object.

    Raises:
        FileNotFoundError: if no .pkl file exists among the run's artifacts.
    """
    import pickle

    artifacts_dir = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path="model")
    # Scan the downloaded tree and load the first pickle we encounter.
    for folder, _subdirs, filenames in os.walk(artifacts_dir):
        for name in filenames:
            if not name.endswith(".pkl"):
                continue
            with open(os.path.join(folder, name), "rb") as handle:
                return pickle.load(handle)
    raise FileNotFoundError("FLAML model not found in artifacts.")