Spaces:
Sleeping
Sleeping
Commit ·
9244b7e
1
Parent(s): 6eb9f5f
Add ONNX export utilities, pipeline parser, and PyCaret integration
Browse files- Implemented `onnx_utils.py` for exporting models to ONNX format and loading ONNX sessions.
- Created `pipeline_parser.py` to infer active steps in AutoML pipelines from logs for various frameworks.
- Developed `pycaret_utils.py` to run PyCaret experiments, including model comparison, tuning, and ONNX export.
- Introduced `training_worker.py` to manage training threads, capturing logs and handling graceful cancellation.
- Added `xai_utils.py` for generating SHAP explanations and occlusion saliency maps for model interpretability.
- .dvc/.gitignore +3 -0
- .dvc/config +0 -0
- app.py +0 -0
- src/__init__.py +0 -0
- src/autogluon_utils.py +165 -30
- src/autokeras_utils.py +147 -0
- src/code_gen_utils.py +422 -0
- src/data_utils.py +75 -3
- src/experiment_manager.py +156 -0
- src/flaml_utils.py +74 -4
- src/h2o_utils.py +125 -34
- src/huggingface_utils.py +96 -0
- src/lale_utils.py +198 -0
- src/onnx_utils.py +105 -0
- src/pipeline_parser.py +328 -0
- src/pycaret_utils.py +212 -0
- src/tpot_utils.py +21 -4
- src/training_worker.py +218 -0
- src/xai_utils.py +216 -0
.dvc/.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/config.local
|
| 2 |
+
/tmp
|
| 3 |
+
/cache
|
.dvc/config
ADDED
|
File without changes
|
app.py
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/__init__.py
ADDED
|
File without changes
|
src/autogluon_utils.py
CHANGED
|
@@ -3,21 +3,56 @@ import pandas as pd
|
|
| 3 |
import mlflow
|
| 4 |
import shutil
|
| 5 |
import logging
|
|
|
|
|
|
|
| 6 |
from src.mlflow_utils import safe_set_experiment
|
|
|
|
| 7 |
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
def train_model(train_data: pd.DataFrame, target: str, run_name: str,
|
| 11 |
valid_data: pd.DataFrame = None, test_data: pd.DataFrame = None,
|
| 12 |
-
time_limit: int = 60, presets: str = 'medium_quality', seed: int = 42, cv_folds: int = 0
|
|
|
|
| 13 |
"""
|
| 14 |
Trains an AutoGluon model and logs results to MLflow using generic artifact logging.
|
|
|
|
| 15 |
"""
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
safe_set_experiment("AutoGluon_Experiments")
|
| 19 |
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
# Data cleaning: drop rows where target is NaN
|
| 22 |
train_data = train_data.dropna(subset=[target])
|
| 23 |
|
|
@@ -44,40 +79,140 @@ def train_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 44 |
test_data = test_data.dropna(subset=[target])
|
| 45 |
mlflow.log_param("has_test_data", True)
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
"presets": presets
|
| 52 |
-
}
|
| 53 |
-
if cv_folds > 0:
|
| 54 |
-
fit_args["num_bag_folds"] = cv_folds
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
-
# Log metrics (leaderboard)
|
| 62 |
-
# If test_data is provided, leaderboard and scoring will strictly use it,
|
| 63 |
-
# otherwise fallback to training data
|
| 64 |
eval_data = test_data if test_data is not None else (valid_data if valid_data is not None else train_data)
|
| 65 |
-
leaderboard = predictor.leaderboard(eval_data, silent=True)
|
| 66 |
-
# Log the best model's score
|
| 67 |
-
best_model_score = leaderboard.iloc[0]['score_val']
|
| 68 |
-
mlflow.log_metric("best_model_score", best_model_score)
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
# Log AutoGluon model directory as a generic artifact
|
| 78 |
-
#
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
return predictor, run.info.run_id
|
| 83 |
|
|
|
|
| 3 |
import mlflow
|
| 4 |
import shutil
|
| 5 |
import logging
|
| 6 |
+
import time
|
| 7 |
+
import threading
|
| 8 |
from src.mlflow_utils import safe_set_experiment
|
| 9 |
+
from src.onnx_utils import export_to_onnx
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
def train_model(train_data: pd.DataFrame, target: str, run_name: str,
|
| 14 |
valid_data: pd.DataFrame = None, test_data: pd.DataFrame = None,
|
| 15 |
+
time_limit: int = 60, presets: str = 'medium_quality', seed: int = 42, cv_folds: int = 0,
|
| 16 |
+
stop_event=None, task_type: str = "Classification", telemetry_queue=None):
|
| 17 |
"""
|
| 18 |
Trains an AutoGluon model and logs results to MLflow using generic artifact logging.
|
| 19 |
+
Supports both Tabular data and Computer Vision tasks (via MultiModalPredictor).
|
| 20 |
"""
|
| 21 |
+
is_cv_task = task_type and task_type.startswith("Computer Vision")
|
| 22 |
+
is_segmentation = task_type == "Computer Vision - Image Segmentation"
|
| 23 |
+
is_multilabel = task_type == "Computer Vision - Multi-Label Classification"
|
| 24 |
+
|
| 25 |
+
if is_cv_task:
|
| 26 |
+
from autogluon.multimodal import MultiModalPredictor
|
| 27 |
+
|
| 28 |
+
def build_image_df(path_df):
|
| 29 |
+
if path_df is None or "Image_Directory" not in path_df.columns:
|
| 30 |
+
return path_df
|
| 31 |
+
img_dir = path_df.iloc[0]["Image_Directory"]
|
| 32 |
+
data = []
|
| 33 |
+
for root, _, files in os.walk(img_dir):
|
| 34 |
+
label = os.path.basename(root)
|
| 35 |
+
for file in files:
|
| 36 |
+
if file.lower().endswith(('.png', '.jpg', '.jpeg')):
|
| 37 |
+
data.append({"image": os.path.join(root, file), target: label})
|
| 38 |
+
return pd.DataFrame(data)
|
| 39 |
+
|
| 40 |
+
train_data = build_image_df(train_data)
|
| 41 |
+
valid_data = build_image_df(valid_data)
|
| 42 |
+
test_data = build_image_df(test_data)
|
| 43 |
+
else:
|
| 44 |
+
from autogluon.tabular import TabularPredictor
|
| 45 |
|
| 46 |
safe_set_experiment("AutoGluon_Experiments")
|
| 47 |
|
| 48 |
+
# Ensure no leaked runs in this thread
|
| 49 |
+
try:
|
| 50 |
+
if mlflow.active_run():
|
| 51 |
+
mlflow.end_run()
|
| 52 |
+
except:
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
with mlflow.start_run(run_name=run_name, nested=True) as run:
|
| 56 |
# Data cleaning: drop rows where target is NaN
|
| 57 |
train_data = train_data.dropna(subset=[target])
|
| 58 |
|
|
|
|
| 79 |
test_data = test_data.dropna(subset=[target])
|
| 80 |
mlflow.log_param("has_test_data", True)
|
| 81 |
|
| 82 |
+
if is_cv_task:
|
| 83 |
+
mm_fit_args = {"train_data": train_data, "time_limit": time_limit}
|
| 84 |
+
if valid_data is not None:
|
| 85 |
+
mm_fit_args["tuning_data"] = valid_data
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
+
problem_type = None
|
| 88 |
+
if is_segmentation:
|
| 89 |
+
problem_type = "semantic_segmentation"
|
| 90 |
+
elif task_type == "Computer Vision - Object Detection":
|
| 91 |
+
problem_type = "object_detection"
|
| 92 |
+
|
| 93 |
+
mm_presets = "high_quality" if presets in ["best_quality", "high_quality"] else "medium_quality"
|
| 94 |
+
predictor = MultiModalPredictor(label=target, problem_type=problem_type, path=model_path).fit(**mm_fit_args, presets=mm_presets)
|
| 95 |
+
else:
|
| 96 |
+
fit_args = {
|
| 97 |
+
"train_data": train_data,
|
| 98 |
+
"time_limit": time_limit,
|
| 99 |
+
"presets": presets
|
| 100 |
+
}
|
| 101 |
+
if cv_folds > 0:
|
| 102 |
+
fit_args["num_bag_folds"] = cv_folds
|
| 103 |
+
|
| 104 |
+
if valid_data is not None:
|
| 105 |
+
fit_args["tuning_data"] = valid_data
|
| 106 |
+
# If bagging is enabled (manually or by presets), we must set use_bag_holdout=True to use separate tuning_data
|
| 107 |
+
if cv_folds > 0 or presets in ["best_quality", "high_quality"]:
|
| 108 |
+
fit_args["use_bag_holdout"] = True
|
| 109 |
+
|
| 110 |
+
if is_multilabel:
|
| 111 |
+
fit_args["problem_type"] = "multiclass"
|
| 112 |
+
mlflow.log_param("is_multilabel", True)
|
| 113 |
+
|
| 114 |
+
# Streaming updates thread
|
| 115 |
+
def _push_ag_telemetry():
|
| 116 |
+
while not (stop_event and stop_event.is_set()):
|
| 117 |
+
try:
|
| 118 |
+
if os.path.exists(model_path):
|
| 119 |
+
# AutoGluon sometimes locks the file, so we try-except
|
| 120 |
+
from autogluon.tabular import TabularPredictor
|
| 121 |
+
try:
|
| 122 |
+
temp_predictor = TabularPredictor.load(path=model_path)
|
| 123 |
+
lb = temp_predictor.leaderboard(silent=True)
|
| 124 |
+
if len(lb) > 0:
|
| 125 |
+
best_model = lb.iloc[0]['model']
|
| 126 |
+
best_score = lb.iloc[0]['score_val']
|
| 127 |
+
if telemetry_queue:
|
| 128 |
+
telemetry_queue.put({
|
| 129 |
+
"status": "running",
|
| 130 |
+
"models_trained": len(lb),
|
| 131 |
+
"best_model": best_model,
|
| 132 |
+
"best_value": best_score,
|
| 133 |
+
"leaderboard_preview": lb.head(5).to_dict(orient='records')
|
| 134 |
+
})
|
| 135 |
+
except:
|
| 136 |
+
pass
|
| 137 |
+
except:
|
| 138 |
+
pass
|
| 139 |
+
time.sleep(10)
|
| 140 |
|
| 141 |
+
if telemetry_queue:
|
| 142 |
+
t_telemetry = threading.Thread(target=_push_ag_telemetry, daemon=True)
|
| 143 |
+
t_telemetry.start()
|
| 144 |
+
|
| 145 |
+
predictor = TabularPredictor(label=target, path=model_path).fit(**fit_args)
|
| 146 |
+
|
| 147 |
+
# Check if cancelled before continuing
|
| 148 |
+
if stop_event and stop_event.is_set():
|
| 149 |
+
raise StopIteration("Training cancelled by user")
|
| 150 |
|
|
|
|
|
|
|
|
|
|
| 151 |
eval_data = test_data if test_data is not None else (valid_data if valid_data is not None else train_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
+
if is_cv_task:
|
| 154 |
+
scores = predictor.evaluate(eval_data)
|
| 155 |
+
best_model_score = scores.get('accuracy', scores.get('roc_auc', 0.0))
|
| 156 |
+
mlflow.log_metrics(scores)
|
| 157 |
+
leaderboard_path = "leaderboard.csv"
|
| 158 |
+
pd.DataFrame([scores]).to_csv(leaderboard_path, index=False)
|
| 159 |
+
else:
|
| 160 |
+
leaderboard = predictor.leaderboard(eval_data, silent=True)
|
| 161 |
+
# Log the best model's score
|
| 162 |
+
best_model_score = leaderboard.iloc[0]['score_val']
|
| 163 |
+
mlflow.log_metric("best_model_score", best_model_score)
|
| 164 |
+
leaderboard_path = "leaderboard.csv"
|
| 165 |
+
leaderboard.to_csv(leaderboard_path, index=False)
|
| 166 |
+
try:
|
| 167 |
+
mlflow.log_artifact(leaderboard_path)
|
| 168 |
+
except Exception as e:
|
| 169 |
+
logger.warning(f"Failed to log leaderboard artifact: {e}")
|
| 170 |
+
finally:
|
| 171 |
+
if os.path.exists(leaderboard_path):
|
| 172 |
+
os.remove(leaderboard_path)
|
| 173 |
|
| 174 |
# Log AutoGluon model directory as a generic artifact
|
| 175 |
+
# We use a try-except here because disk space issues frequently occur during artifact copy
|
| 176 |
+
try:
|
| 177 |
+
mlflow.log_artifacts(model_path, artifact_path="model")
|
| 178 |
+
mlflow.log_param("model_type", "autogluon")
|
| 179 |
+
|
| 180 |
+
# ONNX Export (Best effort for Tabular)
|
| 181 |
+
if not is_cv_task:
|
| 182 |
+
try:
|
| 183 |
+
onnx_path = os.path.join("models", f"ag_{run_name}.onnx")
|
| 184 |
+
# AutoGluon Tabular supports ONNX export for some models
|
| 185 |
+
# This might require specific dependencies or AG version
|
| 186 |
+
# We call our utility which handles AG logic
|
| 187 |
+
export_to_onnx(predictor, "autogluon", target, onnx_path, input_sample=train_data[:1])
|
| 188 |
+
mlflow.log_artifact(onnx_path, artifact_path="model")
|
| 189 |
+
except Exception as e:
|
| 190 |
+
logger.warning(f"Failed to export AutoGluon model to ONNX: {e}")
|
| 191 |
+
|
| 192 |
+
logger.info(f"AutoGluon artifacts logged successfully for {run_name}")
|
| 193 |
+
|
| 194 |
+
# CRITICAL: Delete local model folder after successful MLflow logging to save disk space
|
| 195 |
+
# Only do this if it was logged successfully to the tracking server/local mlruns
|
| 196 |
+
if os.path.exists(model_path):
|
| 197 |
+
shutil.rmtree(model_path)
|
| 198 |
+
logger.info(f"Cleaned up local model folder: {model_path}")
|
| 199 |
+
except Exception as e:
|
| 200 |
+
logger.error(f"Failed to log model artifacts to MLflow (likely disk space): {e}")
|
| 201 |
+
# Do NOT delete model_path here so the user can potentially recover it manually
|
| 202 |
+
# if the MLflow log failed.
|
| 203 |
+
|
| 204 |
+
# Generate and log consumption code sample
|
| 205 |
+
try:
|
| 206 |
+
from src.code_gen_utils import generate_consumption_code
|
| 207 |
+
code_sample = generate_consumption_code("autogluon", run.info.run_id, target)
|
| 208 |
+
code_path = "consumption_sample.py"
|
| 209 |
+
with open(code_path, "w") as f:
|
| 210 |
+
f.write(code_sample)
|
| 211 |
+
mlflow.log_artifact(code_path)
|
| 212 |
+
if os.path.exists(code_path):
|
| 213 |
+
os.remove(code_path)
|
| 214 |
+
except Exception as e:
|
| 215 |
+
logger.warning(f"Failed to generate consumption code: {e}")
|
| 216 |
|
| 217 |
return predictor, run.info.run_id
|
| 218 |
|
src/autokeras_utils.py
ADDED
|
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import time
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import numpy as np
|
| 6 |
+
import mlflow
|
| 7 |
+
import logging
|
| 8 |
+
from src.mlflow_utils import safe_set_experiment
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
def run_autokeras_experiment(train_data: pd.DataFrame, target: str, run_name: str,
|
| 13 |
+
valid_data: pd.DataFrame = None, task_type: str = "Computer Vision - Image Classification",
|
| 14 |
+
time_limit: int = 60, stop_event=None, log_queue=None):
|
| 15 |
+
"""
|
| 16 |
+
Trains an AutoKeras model for Image tasks.
|
| 17 |
+
train_data contains a dataframe with 'Image_Directory' pointing to the dataset path.
|
| 18 |
+
"""
|
| 19 |
+
safe_set_experiment("AutoKeras_Experiments")
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
import autokeras as ak
|
| 23 |
+
import tensorflow as tf
|
| 24 |
+
except ImportError:
|
| 25 |
+
raise ImportError("AutoKeras or TensorFlow not installed. Please install them to use AutoKeras.")
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
if mlflow.active_run():
|
| 29 |
+
mlflow.end_run()
|
| 30 |
+
except Exception:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
def qlog(msg):
|
| 34 |
+
if log_queue:
|
| 35 |
+
log_queue.put(msg)
|
| 36 |
+
logger.info(msg)
|
| 37 |
+
|
| 38 |
+
with mlflow.start_run(run_name=run_name, nested=True) as run:
|
| 39 |
+
mlflow.log_param("framework", "autokeras")
|
| 40 |
+
mlflow.log_param("task_type", task_type)
|
| 41 |
+
mlflow.log_param("time_limit", time_limit)
|
| 42 |
+
|
| 43 |
+
if "Image_Directory" not in train_data.columns:
|
| 44 |
+
raise ValueError("AutoKeras requires 'Image_Directory' in the training payload for CV tasks.")
|
| 45 |
+
|
| 46 |
+
img_dir = train_data.iloc[0]["Image_Directory"]
|
| 47 |
+
qlog(f"Scanning image directory: {img_dir}")
|
| 48 |
+
|
| 49 |
+
# We need to construct tf.data.Dataset from directory
|
| 50 |
+
# Since AutoKeras ImageClassifier accepts tf.data.Dataset
|
| 51 |
+
batch_size = 32
|
| 52 |
+
|
| 53 |
+
train_ds = tf.keras.utils.image_dataset_from_directory(
|
| 54 |
+
img_dir,
|
| 55 |
+
validation_split=0.2 if valid_data is None else None,
|
| 56 |
+
subset="training" if valid_data is None else None,
|
| 57 |
+
seed=42,
|
| 58 |
+
image_size=(256, 256),
|
| 59 |
+
batch_size=batch_size
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
if valid_data is None:
|
| 63 |
+
val_ds = tf.keras.utils.image_dataset_from_directory(
|
| 64 |
+
img_dir,
|
| 65 |
+
validation_split=0.2,
|
| 66 |
+
subset="validation",
|
| 67 |
+
seed=42,
|
| 68 |
+
image_size=(256, 256),
|
| 69 |
+
batch_size=batch_size
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
val_img_dir = valid_data.iloc[0]["Image_Directory"]
|
| 73 |
+
val_ds = tf.keras.utils.image_dataset_from_directory(
|
| 74 |
+
val_img_dir,
|
| 75 |
+
seed=42,
|
| 76 |
+
image_size=(256, 256),
|
| 77 |
+
batch_size=batch_size
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
mlflow.log_param("num_classes", len(train_ds.class_names))
|
| 81 |
+
|
| 82 |
+
model_path = os.path.join("models", run_name)
|
| 83 |
+
if os.path.exists(model_path):
|
| 84 |
+
shutil.rmtree(model_path)
|
| 85 |
+
|
| 86 |
+
qlog("Starting AutoKeras topology search...")
|
| 87 |
+
|
| 88 |
+
# Estimate max trials based on time_limit pseudo translation (1 trial ~ 100s for small data)
|
| 89 |
+
max_trials = max(1, time_limit // 100)
|
| 90 |
+
|
| 91 |
+
def dataset_to_numpy(ds):
|
| 92 |
+
x_all, y_all = [], []
|
| 93 |
+
for x, y in ds:
|
| 94 |
+
x_all.append(x.numpy())
|
| 95 |
+
y_all.append(y.numpy())
|
| 96 |
+
if not x_all: return None, None
|
| 97 |
+
return np.concatenate(x_all, axis=0), np.concatenate(y_all, axis=0)
|
| 98 |
+
|
| 99 |
+
x_train, y_train = dataset_to_numpy(train_ds)
|
| 100 |
+
|
| 101 |
+
x_val, y_val = None, None
|
| 102 |
+
if val_ds:
|
| 103 |
+
x_val, y_val = dataset_to_numpy(val_ds)
|
| 104 |
+
|
| 105 |
+
if task_type == "Computer Vision - Image Classification":
|
| 106 |
+
clf = ak.ImageClassifier(overwrite=True, max_trials=max_trials, directory=model_path)
|
| 107 |
+
if val_ds:
|
| 108 |
+
clf.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5) # Default short epoch
|
| 109 |
+
else:
|
| 110 |
+
clf.fit(x_train, y_train, epochs=5)
|
| 111 |
+
elif task_type == "Computer Vision - Multi-Label Classification":
|
| 112 |
+
clf = ak.ImageClassifier(overwrite=True, max_trials=max_trials, directory=model_path, multi_label=True)
|
| 113 |
+
if val_ds:
|
| 114 |
+
clf.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5) # Default short epoch
|
| 115 |
+
else:
|
| 116 |
+
clf.fit(x_train, y_train, epochs=5)
|
| 117 |
+
else:
|
| 118 |
+
# We don't natively support bounding boxes or segmentation masks right now without specific parser
|
| 119 |
+
raise NotImplementedError(f"AutoKeras task '{task_type}' requires labels not inherently present in the directory structure or is unsupported by AutoKeras basic API.")
|
| 120 |
+
|
| 121 |
+
if stop_event and stop_event.is_set():
|
| 122 |
+
raise StopIteration("Training cancelled by user")
|
| 123 |
+
|
| 124 |
+
qlog("Evaluating best model...")
|
| 125 |
+
loss, accuracy = clf.evaluate(val_ds)
|
| 126 |
+
mlflow.log_metric("val_loss", loss)
|
| 127 |
+
mlflow.log_metric("val_accuracy", accuracy)
|
| 128 |
+
|
| 129 |
+
qlog("Saving and logging artifacts...")
|
| 130 |
+
export_path = os.path.join(model_path, "best_model")
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
model = clf.export_model()
|
| 134 |
+
model.save(export_path, save_format="tf")
|
| 135 |
+
mlflow.log_artifacts(export_path, artifact_path="model")
|
| 136 |
+
mlflow.log_param("model_type", "autokeras")
|
| 137 |
+
qlog("AutoKeras artifacts logged successfully.")
|
| 138 |
+
except Exception as e:
|
| 139 |
+
qlog(f"Warning: Model export failed: {e}")
|
| 140 |
+
|
| 141 |
+
# Return a dictionary of useful data for UI
|
| 142 |
+
return {
|
| 143 |
+
"run_id": run.info.run_id,
|
| 144 |
+
"type": "autokeras",
|
| 145 |
+
# Can't pass TF model across processes easily via queues, so we pass None
|
| 146 |
+
"predictor": None
|
| 147 |
+
}
|
src/code_gen_utils.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def generate_consumption_code(model_type: str, run_id: str, target_column: str) -> str:
|
| 5 |
+
"""
|
| 6 |
+
Generates a Python code snippet to load and run predictions with the trained model.
|
| 7 |
+
Supports: autogluon, flaml, h2o, tpot, pycaret, lale.
|
| 8 |
+
"""
|
| 9 |
+
try:
|
| 10 |
+
client = mlflow.tracking.MlflowClient()
|
| 11 |
+
run = client.get_run(run_id)
|
| 12 |
+
task_type = run.data.params.get("task_type", "Classification")
|
| 13 |
+
except Exception:
|
| 14 |
+
task_type = "Classification"
|
| 15 |
+
|
| 16 |
+
base_code = f"""# Sample code to consume the trained model
|
| 17 |
+
# Run ID: {run_id}
|
| 18 |
+
# Model Type: {model_type}
|
| 19 |
+
# Task Type: {task_type}
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import pandas as pd
|
| 23 |
+
import mlflow
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
if model_type == "autogluon":
|
| 27 |
+
return base_code + f"""
|
| 28 |
+
from autogluon.tabular import TabularPredictor
|
| 29 |
+
|
| 30 |
+
# 1. Download model from MLflow
|
| 31 |
+
local_path = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 32 |
+
|
| 33 |
+
# 2. Load model
|
| 34 |
+
predictor = TabularPredictor.load(local_path)
|
| 35 |
+
|
| 36 |
+
# 3. Predict
|
| 37 |
+
# data = pd.read_csv("your_data.csv")
|
| 38 |
+
# predictions = predictor.predict(data)
|
| 39 |
+
# print(predictions)
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
elif model_type == "flaml":
|
| 43 |
+
return base_code + f"""
|
| 44 |
+
import pickle
|
| 45 |
+
|
| 46 |
+
# 1. Download model from MLflow
|
| 47 |
+
local_path = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 48 |
+
|
| 49 |
+
# 2. Load the .pkl file
|
| 50 |
+
model = None
|
| 51 |
+
for root, dirs, files in os.walk(local_path):
|
| 52 |
+
for f in files:
|
| 53 |
+
if f.endswith(".pkl"):
|
| 54 |
+
with open(os.path.join(root, f), "rb") as fh:
|
| 55 |
+
model = pickle.load(fh)
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
if model is None:
|
| 59 |
+
raise FileNotFoundError("Model .pkl not found in artifacts.")
|
| 60 |
+
|
| 61 |
+
# 3. Predict
|
| 62 |
+
# data = pd.read_csv("your_data.csv")
|
| 63 |
+
# predictions = model.predict(data)
|
| 64 |
+
# print(predictions)
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
elif model_type == "h2o":
|
| 68 |
+
return base_code + f"""
|
| 69 |
+
import h2o
|
| 70 |
+
|
| 71 |
+
# 1. Initialize H2O
|
| 72 |
+
h2o.init()
|
| 73 |
+
|
| 74 |
+
# 2. Download model from MLflow
|
| 75 |
+
local_path = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 76 |
+
|
| 77 |
+
# 3. Load the H2O model
|
| 78 |
+
model = None
|
| 79 |
+
for root, dirs, files in os.walk(local_path):
|
| 80 |
+
for f in files:
|
| 81 |
+
if f.endswith(".zip") or "." not in f:
|
| 82 |
+
model = h2o.load_model(os.path.join(root, f))
|
| 83 |
+
break
|
| 84 |
+
|
| 85 |
+
# 4. Predict
|
| 86 |
+
# h2o_frame = h2o.H2OFrame(pd.read_csv("your_data.csv"))
|
| 87 |
+
# predictions = model.predict(h2o_frame)
|
| 88 |
+
# print(predictions.as_data_frame())
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
elif model_type == "tpot":
|
| 92 |
+
return base_code + f"""
|
| 93 |
+
import mlflow.sklearn
|
| 94 |
+
|
| 95 |
+
# 1. Load model directly from MLflow
|
| 96 |
+
model = mlflow.sklearn.load_model("runs:/{run_id}/model")
|
| 97 |
+
|
| 98 |
+
# 2. Predict
|
| 99 |
+
# data = pd.read_csv("your_data.csv")
|
| 100 |
+
# predictions = model.predict(data)
|
| 101 |
+
# print(predictions)
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
elif model_type == "pycaret":
|
| 105 |
+
if task_type == "Regression":
|
| 106 |
+
pc_module = "pycaret.regression"
|
| 107 |
+
elif task_type == "Time Series Forecasting":
|
| 108 |
+
pc_module = "pycaret.time_series"
|
| 109 |
+
else:
|
| 110 |
+
pc_module = "pycaret.classification"
|
| 111 |
+
|
| 112 |
+
return base_code + f"""
|
| 113 |
+
import joblib
|
| 114 |
+
from {pc_module} import load_model, predict_model
|
| 115 |
+
|
| 116 |
+
# 1. Download model artifact from MLflow
|
| 117 |
+
local_path = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 118 |
+
|
| 119 |
+
# 2. Find and load the PyCaret .pkl file
|
| 120 |
+
model_path = None
|
| 121 |
+
for root, dirs, files in os.walk(local_path):
|
| 122 |
+
for f in files:
|
| 123 |
+
if f.endswith(".pkl"):
|
| 124 |
+
model_path = os.path.join(root, f).replace(".pkl", "")
|
| 125 |
+
break
|
| 126 |
+
|
| 127 |
+
if model_path is None:
|
| 128 |
+
raise FileNotFoundError("PyCaret model .pkl not found in artifacts.")
|
| 129 |
+
|
| 130 |
+
model = load_model(model_path)
|
| 131 |
+
|
| 132 |
+
# 3. Predict
|
| 133 |
+
# data = pd.read_csv("your_data.csv") # For classification/regression, must NOT contain target column
|
| 134 |
+
# predictions = predict_model(model, data=data)
|
| 135 |
+
# print(predictions)
|
| 136 |
+
"""
|
| 137 |
+
|
| 138 |
+
elif model_type == "lale":
|
| 139 |
+
return base_code + f"""
|
| 140 |
+
import joblib
|
| 141 |
+
import numpy as np
|
| 142 |
+
from sklearn.preprocessing import OrdinalEncoder, LabelEncoder
|
| 143 |
+
|
| 144 |
+
# 1. Download model artifact from MLflow
|
| 145 |
+
local_path = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 146 |
+
|
| 147 |
+
# 2. Find and load the Lale joblib bundle
|
| 148 |
+
bundle = None
|
| 149 |
+
for root, dirs, files in os.walk(local_path):
|
| 150 |
+
for f in files:
|
| 151 |
+
if f.endswith(".pkl"):
|
| 152 |
+
bundle = joblib.load(os.path.join(root, f))
|
| 153 |
+
break
|
| 154 |
+
|
| 155 |
+
if bundle is None:
|
| 156 |
+
raise FileNotFoundError("Lale model .pkl not found in artifacts.")
|
| 157 |
+
|
| 158 |
+
model = bundle["model"]
|
| 159 |
+
col_encoders = bundle.get("col_encoders", {{}})
|
| 160 |
+
y_encoder = bundle.get("y_encoder", None)
|
| 161 |
+
|
| 162 |
+
# 3. Preprocess and Predict
|
| 163 |
+
# data = pd.read_csv("your_data.csv") # must NOT contain target column
|
| 164 |
+
#
|
| 165 |
+
# for col, enc in col_encoders.items():
|
| 166 |
+
# data[col] = enc.transform(data[[col]]).ravel()
|
| 167 |
+
#
|
| 168 |
+
# raw_preds = model.predict(data.values)
|
| 169 |
+
#
|
| 170 |
+
# if y_encoder is not None:
|
| 171 |
+
# predictions = y_encoder.inverse_transform(raw_preds)
|
| 172 |
+
# else:
|
| 173 |
+
# predictions = raw_preds
|
| 174 |
+
#
|
| 175 |
+
# print(predictions)
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
else:
|
| 179 |
+
return base_code + f"""
|
| 180 |
+
# Code generation for '{model_type}' is not explicitly implemented.
|
| 181 |
+
# Try loading via: mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 182 |
+
"""
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def _load_code_for_deploy(model_type: str, run_id: str) -> str:
|
| 186 |
+
"""Returns the model-loading block used in the FastAPI main.py."""
|
| 187 |
+
if model_type == "autogluon":
|
| 188 |
+
return f"""
|
| 189 |
+
from autogluon.tabular import TabularPredictor
|
| 190 |
+
import mlflow
|
| 191 |
+
_local = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 192 |
+
model = TabularPredictor.load(_local)
|
| 193 |
+
|
| 194 |
+
def _predict(df):
|
| 195 |
+
return model.predict(df).tolist()
|
| 196 |
+
"""
|
| 197 |
+
elif model_type == "flaml":
|
| 198 |
+
return f"""
|
| 199 |
+
import pickle, os, mlflow
|
| 200 |
+
_local = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 201 |
+
model = None
|
| 202 |
+
for root, _, files in os.walk(_local):
|
| 203 |
+
for f in files:
|
| 204 |
+
if f.endswith(".pkl"):
|
| 205 |
+
with open(os.path.join(root, f), "rb") as fh:
|
| 206 |
+
model = pickle.load(fh)
|
| 207 |
+
break
|
| 208 |
+
if model is None:
|
| 209 |
+
raise FileNotFoundError("FLAML model not found.")
|
| 210 |
+
|
| 211 |
+
def _predict(df):
|
| 212 |
+
return model.predict(df).tolist()
|
| 213 |
+
"""
|
| 214 |
+
elif model_type == "h2o":
|
| 215 |
+
return f"""
|
| 216 |
+
import h2o, os, mlflow
|
| 217 |
+
h2o.init()
|
| 218 |
+
_local = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 219 |
+
model = None
|
| 220 |
+
for root, _, files in os.walk(_local):
|
| 221 |
+
for f in files:
|
| 222 |
+
if f.endswith(".zip") or "." not in f:
|
| 223 |
+
model = h2o.load_model(os.path.join(root, f))
|
| 224 |
+
break
|
| 225 |
+
|
| 226 |
+
def _predict(df):
|
| 227 |
+
hf = h2o.H2OFrame(df)
|
| 228 |
+
return model.predict(hf).as_data_frame()["predict"].tolist()
|
| 229 |
+
"""
|
| 230 |
+
elif model_type == "tpot":
|
| 231 |
+
return f"""
|
| 232 |
+
import mlflow.sklearn
|
| 233 |
+
model = mlflow.sklearn.load_model("runs:/{run_id}/model")
|
| 234 |
+
|
| 235 |
+
def _predict(df):
|
| 236 |
+
return model.predict(df).tolist()
|
| 237 |
+
"""
|
| 238 |
+
elif model_type == "pycaret":
|
| 239 |
+
return f"""
|
| 240 |
+
import os, mlflow, joblib
|
| 241 |
+
import pandas as pd
|
| 242 |
+
_local = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 243 |
+
|
| 244 |
+
try:
|
| 245 |
+
client = mlflow.tracking.MlflowClient()
|
| 246 |
+
run = client.get_run("{run_id}")
|
| 247 |
+
task_type = run.data.params.get("task_type", "Classification")
|
| 248 |
+
except Exception:
|
| 249 |
+
task_type = "Classification"
|
| 250 |
+
|
| 251 |
+
if task_type == "Regression":
|
| 252 |
+
from pycaret.regression import load_model, predict_model
|
| 253 |
+
elif task_type == "Time Series Forecasting":
|
| 254 |
+
from pycaret.time_series import load_model, predict_model
|
| 255 |
+
else:
|
| 256 |
+
from pycaret.classification import load_model, predict_model
|
| 257 |
+
|
| 258 |
+
_mpath = None
|
| 259 |
+
for root, _, files in os.walk(_local):
|
| 260 |
+
for f in files:
|
| 261 |
+
if f.endswith(".pkl"):
|
| 262 |
+
_mpath = os.path.join(root, f).replace(".pkl", "")
|
| 263 |
+
break
|
| 264 |
+
if _mpath is None:
|
| 265 |
+
raise FileNotFoundError("PyCaret model not found.")
|
| 266 |
+
model = load_model(_mpath)
|
| 267 |
+
|
| 268 |
+
def _predict(df):
|
| 269 |
+
preds = predict_model(model, data=df)
|
| 270 |
+
if task_type == "Classification" and "prediction_label" in preds.columns:
|
| 271 |
+
return preds["prediction_label"].tolist()
|
| 272 |
+
else:
|
| 273 |
+
# For regression or time series, it might return 'prediction_label' or just predictions
|
| 274 |
+
if "prediction_label" in preds.columns:
|
| 275 |
+
return preds["prediction_label"].tolist()
|
| 276 |
+
return preds.iloc[:, 0].tolist()
|
| 277 |
+
"""
|
| 278 |
+
elif model_type == "lale":
|
| 279 |
+
return f"""
|
| 280 |
+
import os, mlflow, joblib
|
| 281 |
+
import numpy as np
|
| 282 |
+
_local = mlflow.artifacts.download_artifacts(run_id="{run_id}", artifact_path="model")
|
| 283 |
+
_bundle = None
|
| 284 |
+
for root, _, files in os.walk(_local):
|
| 285 |
+
for f in files:
|
| 286 |
+
if f.endswith(".pkl"):
|
| 287 |
+
_bundle = joblib.load(os.path.join(root, f))
|
| 288 |
+
break
|
| 289 |
+
if _bundle is None:
|
| 290 |
+
raise FileNotFoundError("Lale model not found.")
|
| 291 |
+
_model = _bundle["model"]
|
| 292 |
+
_col_encoders = _bundle.get("col_encoders", {{}})
|
| 293 |
+
_y_encoder = _bundle.get("y_encoder", None)
|
| 294 |
+
|
| 295 |
+
def _predict(df):
|
| 296 |
+
import pandas as _pd
|
| 297 |
+
df = _pd.DataFrame(df)
|
| 298 |
+
for col, enc in _col_encoders.items():
|
| 299 |
+
if col in df.columns:
|
| 300 |
+
df[col] = enc.transform(df[[col]]).ravel()
|
| 301 |
+
raw = _model.predict(df.values)
|
| 302 |
+
if _y_encoder is not None:
|
| 303 |
+
return _y_encoder.inverse_transform(raw).tolist()
|
| 304 |
+
return raw.tolist()
|
| 305 |
+
"""
|
| 306 |
+
else:
|
| 307 |
+
return """
|
| 308 |
+
model = None
|
| 309 |
+
def _predict(df):
|
| 310 |
+
return []
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def generate_api_deployment(model_type: str, run_id: str, target_column: str, output_dir: str = "deploy") -> str:
|
| 315 |
+
"""
|
| 316 |
+
Generates a ready-to-use FastAPI + Docker deployment package for the model.
|
| 317 |
+
Supports: autogluon, flaml, h2o, tpot, pycaret, lale.
|
| 318 |
+
"""
|
| 319 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 320 |
+
|
| 321 |
+
load_code = _load_code_for_deploy(model_type, run_id)
|
| 322 |
+
|
| 323 |
+
main_py = f"""from fastapi import FastAPI, HTTPException
|
| 324 |
+
from pydantic import BaseModel
|
| 325 |
+
import pandas as pd
|
| 326 |
+
import os
|
| 327 |
+
|
| 328 |
+
app = FastAPI(title="AutoML Generated API - {model_type}", version="1.0")
|
| 329 |
+
|
| 330 |
+
# --- Model Loading ---
|
| 331 |
+
{load_code}
|
| 332 |
+
# ---------------------
|
| 333 |
+
|
| 334 |
+
@app.get("/")
|
| 335 |
+
def health():
|
| 336 |
+
return {{"status": "running", "model": "{model_type}", "run_id": "{run_id}"}}
|
| 337 |
+
|
| 338 |
+
@app.post("/predict")
|
| 339 |
+
def predict(payload: dict):
|
| 340 |
+
try:
|
| 341 |
+
if "data" in payload:
|
| 342 |
+
df = pd.DataFrame(payload["data"])
|
| 343 |
+
else:
|
| 344 |
+
df = pd.DataFrame([payload])
|
| 345 |
+
return {{"predictions": _predict(df)}}
|
| 346 |
+
except Exception as e:
|
| 347 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 348 |
+
|
| 349 |
+
if __name__ == "__main__":
|
| 350 |
+
import uvicorn
|
| 351 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
| 352 |
+
"""
|
| 353 |
+
|
| 354 |
+
with open(os.path.join(output_dir, "main.py"), "w", encoding="utf-8") as f:
|
| 355 |
+
f.write(main_py)
|
| 356 |
+
|
| 357 |
+
# requirements.txt
|
| 358 |
+
base_reqs = """fastapi==0.104.1
|
| 359 |
+
uvicorn==0.24.0
|
| 360 |
+
pydantic==2.5.2
|
| 361 |
+
pandas==2.1.4
|
| 362 |
+
mlflow==2.9.2
|
| 363 |
+
"""
|
| 364 |
+
extra = {
|
| 365 |
+
"autogluon": "autogluon==1.0.0\n",
|
| 366 |
+
"flaml": "flaml==2.1.2\n",
|
| 367 |
+
"h2o": "h2o==3.44.0.3\n",
|
| 368 |
+
"tpot": "tpot==0.12.2\nscikit-learn==1.2.2\n",
|
| 369 |
+
"pycaret": "pycaret==3.3.0\nscikit-learn==1.2.2\nscipy==1.11.4\n",
|
| 370 |
+
"lale": "lale==0.9.1\nscikit-learn==1.2.2\njoblib\nhyperopt\n",
|
| 371 |
+
}
|
| 372 |
+
reqs = base_reqs + extra.get(model_type, "")
|
| 373 |
+
with open(os.path.join(output_dir, "requirements.txt"), "w", encoding="utf-8") as f:
|
| 374 |
+
f.write(reqs)
|
| 375 |
+
|
| 376 |
+
# Dockerfile
|
| 377 |
+
dockerfile = f"""FROM python:3.11-slim
|
| 378 |
+
|
| 379 |
+
WORKDIR /app
|
| 380 |
+
|
| 381 |
+
RUN apt-get update && apt-get install -y \\
|
| 382 |
+
build-essential libgomp1 libgl1 python3-dev default-jre curl \\
|
| 383 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 384 |
+
|
| 385 |
+
COPY requirements.txt .
|
| 386 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 387 |
+
|
| 388 |
+
COPY main.py .
|
| 389 |
+
|
| 390 |
+
EXPOSE 8000
|
| 391 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
|
| 392 |
+
"""
|
| 393 |
+
with open(os.path.join(output_dir, "Dockerfile"), "w", encoding="utf-8") as f:
|
| 394 |
+
f.write(dockerfile)
|
| 395 |
+
|
| 396 |
+
# README
|
| 397 |
+
readme = f"""# API Deployment — {model_type} (Run: {run_id})
|
| 398 |
+
|
| 399 |
+
## Local
|
| 400 |
+
```bash
|
| 401 |
+
pip install -r requirements.txt
|
| 402 |
+
python main.py
|
| 403 |
+
```
|
| 404 |
+
|
| 405 |
+
## Docker
|
| 406 |
+
```bash
|
| 407 |
+
docker build -t ml-api:{run_id[:8]} .
|
| 408 |
+
docker run -p 8000:8000 ml-api:{run_id[:8]}
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
## Example request
|
| 412 |
+
```json
|
| 413 |
+
POST http://localhost:8000/predict
|
| 414 |
+
{{
|
| 415 |
+
"data": [{{"feature1": 1.5, "feature2": "value"}}]
|
| 416 |
+
}}
|
| 417 |
+
```
|
| 418 |
+
"""
|
| 419 |
+
with open(os.path.join(output_dir, "README.md"), "w", encoding="utf-8") as f:
|
| 420 |
+
f.write(readme)
|
| 421 |
+
|
| 422 |
+
return output_dir
|
src/data_utils.py
CHANGED
|
@@ -2,21 +2,40 @@ import os
|
|
| 2 |
import subprocess
|
| 3 |
import hashlib
|
| 4 |
import time
|
|
|
|
| 5 |
import pandas as pd
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
def load_data(file):
|
| 8 |
"""
|
| 9 |
Loads data from an uploaded file (CSV or Excel) or a disk path.
|
|
|
|
|
|
|
| 10 |
"""
|
| 11 |
is_path = isinstance(file, str)
|
| 12 |
filename = file if is_path else file.name
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
if filename.endswith('.csv'):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
return pd.read_csv(file)
|
| 16 |
elif filename.endswith(('.xls', '.xlsx')):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
return pd.read_excel(file)
|
| 18 |
else:
|
| 19 |
-
raise ValueError("Unsupported file format. Please use CSV or
|
| 20 |
|
| 21 |
def get_data_summary(df):
|
| 22 |
"""
|
|
@@ -90,14 +109,67 @@ def get_data_lake_files():
|
|
| 90 |
return []
|
| 91 |
|
| 92 |
files = []
|
|
|
|
| 93 |
for f in os.listdir(data_lake_dir):
|
| 94 |
if f.endswith(('.csv', '.xls', '.xlsx')):
|
| 95 |
files.append(os.path.join(data_lake_dir, f))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
# Sort files by creation time descending (newest first)
|
| 98 |
files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
|
| 99 |
return files
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
def get_dvc_hash(file_path):
|
| 102 |
"""
|
| 103 |
Extracts the DVC hash corresponding to a specific file.
|
|
|
|
| 2 |
import subprocess
|
| 3 |
import hashlib
|
| 4 |
import time
|
| 5 |
+
import sys
|
| 6 |
import pandas as pd
|
| 7 |
+
import zipfile
|
| 8 |
+
import shutil
|
| 9 |
|
| 10 |
+
def load_data(file, no_header=False):
|
| 11 |
"""
|
| 12 |
Loads data from an uploaded file (CSV or Excel) or a disk path.
|
| 13 |
+
If no_header is True, treats the first row as data (no header) and
|
| 14 |
+
auto-generates column names as col_0, col_1, ...
|
| 15 |
"""
|
| 16 |
is_path = isinstance(file, str)
|
| 17 |
filename = file if is_path else file.name
|
| 18 |
|
| 19 |
+
if os.path.isdir(filename):
|
| 20 |
+
# For image directories, return a mock DataFrame to avoid crashing the UI
|
| 21 |
+
# AutoGluon / AutoKeras will use the path string instead of this DataFrame.
|
| 22 |
+
num_files = sum(len(files) for _, _, files in os.walk(filename))
|
| 23 |
+
return pd.DataFrame({"Image_Directory": [filename], "Total_Images": [num_files], "Type": ["Computer Vision Dataset"]})
|
| 24 |
+
|
| 25 |
if filename.endswith('.csv'):
|
| 26 |
+
if no_header:
|
| 27 |
+
df = pd.read_csv(file, header=None)
|
| 28 |
+
df.columns = [f"col_{i}" for i in range(len(df.columns))]
|
| 29 |
+
return df
|
| 30 |
return pd.read_csv(file)
|
| 31 |
elif filename.endswith(('.xls', '.xlsx')):
|
| 32 |
+
if no_header:
|
| 33 |
+
df = pd.read_excel(file, header=None)
|
| 34 |
+
df.columns = [f"col_{i}" for i in range(len(df.columns))]
|
| 35 |
+
return df
|
| 36 |
return pd.read_excel(file)
|
| 37 |
else:
|
| 38 |
+
raise ValueError("Unsupported file format. Please use CSV, Excel, or provide a valid image directory.")
|
| 39 |
|
| 40 |
def get_data_summary(df):
|
| 41 |
"""
|
|
|
|
| 109 |
return []
|
| 110 |
|
| 111 |
files = []
|
| 112 |
+
# Add tabular files
|
| 113 |
for f in os.listdir(data_lake_dir):
|
| 114 |
if f.endswith(('.csv', '.xls', '.xlsx')):
|
| 115 |
files.append(os.path.join(data_lake_dir, f))
|
| 116 |
+
|
| 117 |
+
# Add image directories
|
| 118 |
+
images_dir = os.path.join("data_lake", "images")
|
| 119 |
+
if os.path.exists(images_dir):
|
| 120 |
+
for d in os.listdir(images_dir):
|
| 121 |
+
dir_path = os.path.join(images_dir, d)
|
| 122 |
+
if os.path.isdir(dir_path):
|
| 123 |
+
files.append(dir_path)
|
| 124 |
|
|
|
|
| 125 |
files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
|
| 126 |
return files
|
| 127 |
|
| 128 |
+
def process_image_upload(uploaded_files, dataset_name="image_dataset", is_zip=False):
|
| 129 |
+
"""
|
| 130 |
+
Processes uploaded images (multiple files or a zip) and stores them in data_lake/images/<dataset_name>.
|
| 131 |
+
Supports ZIP extraction or direct copying.
|
| 132 |
+
Returns the path to the dataset directory and a hash.
|
| 133 |
+
"""
|
| 134 |
+
data_lake_dir = os.path.join("data_lake", "images", dataset_name)
|
| 135 |
+
os.makedirs(data_lake_dir, exist_ok=True)
|
| 136 |
+
|
| 137 |
+
timestamp = int(time.time())
|
| 138 |
+
target_dir = f"{data_lake_dir}_{timestamp}"
|
| 139 |
+
os.makedirs(target_dir, exist_ok=True)
|
| 140 |
+
|
| 141 |
+
if is_zip and len(uploaded_files) == 1:
|
| 142 |
+
# Extract ZIP
|
| 143 |
+
zip_file = uploaded_files[0]
|
| 144 |
+
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
| 145 |
+
zip_ref.extractall(target_dir)
|
| 146 |
+
else:
|
| 147 |
+
# Multiple Image Files
|
| 148 |
+
for f in uploaded_files:
|
| 149 |
+
file_path = os.path.join(target_dir, f.name)
|
| 150 |
+
with open(file_path, "wb") as out_f:
|
| 151 |
+
out_f.write(f.getbuffer())
|
| 152 |
+
|
| 153 |
+
# Add directory to DVC
|
| 154 |
+
dvc_hash = "unknown_dir_hash"
|
| 155 |
+
try:
|
| 156 |
+
init_dvc()
|
| 157 |
+
subprocess.run(["dvc", "add", target_dir], check=True, capture_output=True)
|
| 158 |
+
dvc_file_path = target_dir + ".dvc"
|
| 159 |
+
if os.path.exists(dvc_file_path):
|
| 160 |
+
with open(dvc_file_path, "r") as f:
|
| 161 |
+
content = f.read()
|
| 162 |
+
import re
|
| 163 |
+
match = re.search(r'md5:\s*([a-fA-F0-9]+)', content)
|
| 164 |
+
if match:
|
| 165 |
+
dvc_hash = match.group(1)
|
| 166 |
+
except Exception as e:
|
| 167 |
+
print(f"DVC error on image dir: {e}")
|
| 168 |
+
# Pseudo hash fallback
|
| 169 |
+
dvc_hash = hashlib.md5(target_dir.encode()).hexdigest()
|
| 170 |
+
|
| 171 |
+
return target_dir, dvc_hash, dvc_hash[:8]
|
| 172 |
+
|
| 173 |
def get_dvc_hash(file_path):
|
| 174 |
"""
|
| 175 |
Extracts the DVC hash corresponding to a specific file.
|
src/experiment_manager.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ExperimentManager: central registry for all training runs.
|
| 3 |
+
Stored as a singleton in st.session_state['exp_manager'].
|
| 4 |
+
"""
|
| 5 |
+
import threading
|
| 6 |
+
import queue
|
| 7 |
+
import time
|
| 8 |
+
import logging
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Optional, Any
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class ExperimentEntry:
|
| 17 |
+
key: str # unique slug: "autogluon_1712345678"
|
| 18 |
+
metadata: dict # framework, run_name, config snapshot
|
| 19 |
+
thread: Optional[threading.Thread] = field(default=None, repr=False)
|
| 20 |
+
stop_event: threading.Event = field(default_factory=threading.Event, repr=False)
|
| 21 |
+
log_queue: queue.Queue = field(default_factory=queue.Queue, repr=False)
|
| 22 |
+
telemetry_queue: queue.Queue = field(default_factory=queue.Queue, repr=False)
|
| 23 |
+
result_queue: queue.Queue = field(default_factory=queue.Queue, repr=False)
|
| 24 |
+
status: str = "queued" # queued | running | completed | failed | cancelled
|
| 25 |
+
started_at: float = field(default_factory=time.time)
|
| 26 |
+
finished_at: Optional[float] = None
|
| 27 |
+
result: Optional[dict] = None # {predictor, run_id, type, ...} or {error: str}
|
| 28 |
+
all_logs: list = field(default_factory=list)
|
| 29 |
+
latest_telemetry: dict = field(default_factory=dict)
|
| 30 |
+
last_update: float = field(default_factory=time.time)
|
| 31 |
+
|
| 32 |
+
def elapsed_str(self) -> str:
|
| 33 |
+
end = self.finished_at or time.time()
|
| 34 |
+
secs = int(end - self.started_at)
|
| 35 |
+
m, s = divmod(secs, 60)
|
| 36 |
+
return f"{m}m {s:02d}s"
|
| 37 |
+
|
| 38 |
+
def status_icon(self) -> str:
|
| 39 |
+
return {
|
| 40 |
+
"queued": "⏳",
|
| 41 |
+
"running": "🟢",
|
| 42 |
+
"completed": "✅",
|
| 43 |
+
"failed": "❌",
|
| 44 |
+
"cancelled": "🚫",
|
| 45 |
+
}.get(self.status, "❓")
|
| 46 |
+
|
| 47 |
+
def drain_logs(self) -> bool:
|
| 48 |
+
"""Pull all pending log lines and telemetry into the entry."""
|
| 49 |
+
new = False
|
| 50 |
+
while not self.log_queue.empty():
|
| 51 |
+
try:
|
| 52 |
+
line = self.log_queue.get_nowait()
|
| 53 |
+
self.all_logs.append(line)
|
| 54 |
+
new = True
|
| 55 |
+
except queue.Empty:
|
| 56 |
+
break
|
| 57 |
+
|
| 58 |
+
while not self.telemetry_queue.empty():
|
| 59 |
+
try:
|
| 60 |
+
data = self.telemetry_queue.get_nowait()
|
| 61 |
+
if isinstance(data, dict):
|
| 62 |
+
self.latest_telemetry.update(data)
|
| 63 |
+
new = True
|
| 64 |
+
except queue.Empty:
|
| 65 |
+
break
|
| 66 |
+
|
| 67 |
+
if new:
|
| 68 |
+
self.last_update = time.time()
|
| 69 |
+
return new
|
| 70 |
+
|
| 71 |
+
def check_result(self):
|
| 72 |
+
"""Non-blocking check: pull result from queue if available."""
|
| 73 |
+
if not self.result_queue.empty():
|
| 74 |
+
try:
|
| 75 |
+
res = self.result_queue.get_nowait()
|
| 76 |
+
self.result = res
|
| 77 |
+
if res.get("success"):
|
| 78 |
+
self.status = "completed"
|
| 79 |
+
else:
|
| 80 |
+
self.status = "failed"
|
| 81 |
+
self.finished_at = time.time()
|
| 82 |
+
self.last_update = time.time()
|
| 83 |
+
except queue.Empty:
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class ExperimentManager:
|
| 88 |
+
"""In-process registry of all AutoML experiments."""
|
| 89 |
+
|
| 90 |
+
def __init__(self):
|
| 91 |
+
self._runs: dict[str, ExperimentEntry] = {}
|
| 92 |
+
self._lock = threading.Lock()
|
| 93 |
+
|
| 94 |
+
def add(self, entry: ExperimentEntry) -> str:
|
| 95 |
+
with self._lock:
|
| 96 |
+
self._runs[entry.key] = entry
|
| 97 |
+
return entry.key
|
| 98 |
+
|
| 99 |
+
def cancel(self, key: str):
|
| 100 |
+
"""Request graceful cancellation of a running experiment."""
|
| 101 |
+
with self._lock:
|
| 102 |
+
entry = self._runs.get(key)
|
| 103 |
+
if entry and entry.status == "running":
|
| 104 |
+
entry.stop_event.set()
|
| 105 |
+
entry.status = "cancelled"
|
| 106 |
+
entry.finished_at = time.time()
|
| 107 |
+
entry.last_update = time.time()
|
| 108 |
+
logger.info(f"Cancel requested for experiment: {key}")
|
| 109 |
+
|
| 110 |
+
def delete(self, key: str):
|
| 111 |
+
"""Remove experiment from registry (only if not actively running)."""
|
| 112 |
+
with self._lock:
|
| 113 |
+
entry = self._runs.get(key)
|
| 114 |
+
if entry and entry.status == "running":
|
| 115 |
+
# Cancel first
|
| 116 |
+
entry.stop_event.set()
|
| 117 |
+
entry.status = "cancelled"
|
| 118 |
+
entry.finished_at = time.time()
|
| 119 |
+
entry.last_update = time.time()
|
| 120 |
+
self._runs.pop(key, None)
|
| 121 |
+
|
| 122 |
+
def get(self, key: str) -> Optional[ExperimentEntry]:
|
| 123 |
+
with self._lock:
|
| 124 |
+
return self._runs.get(key)
|
| 125 |
+
|
| 126 |
+
def get_all(self) -> list[ExperimentEntry]:
|
| 127 |
+
"""Return all experiments newest-first."""
|
| 128 |
+
with self._lock:
|
| 129 |
+
entries = list(self._runs.values())
|
| 130 |
+
return sorted(entries, key=lambda e: e.started_at, reverse=True)
|
| 131 |
+
|
| 132 |
+
def has_running(self) -> bool:
|
| 133 |
+
return any(e.status == "running" for e in self.get_all())
|
| 134 |
+
|
| 135 |
+
def refresh_all(self):
|
| 136 |
+
"""Sync status/logs/results for all experiments."""
|
| 137 |
+
for entry in self.get_all():
|
| 138 |
+
entry.drain_logs()
|
| 139 |
+
if entry.status in ("running", "queued"):
|
| 140 |
+
entry.check_result()
|
| 141 |
+
# Also check if thread died unexpectedly
|
| 142 |
+
if getattr(entry, 'thread', None) is not None:
|
| 143 |
+
# Defensive check for is_alive
|
| 144 |
+
if not entry.thread.is_alive() and entry.status == "running":
|
| 145 |
+
if entry.result is None:
|
| 146 |
+
entry.status = "failed"
|
| 147 |
+
entry.result = {"success": False, "error": "Thread terminated unexpectedly"}
|
| 148 |
+
entry.finished_at = time.time()
|
| 149 |
+
entry.last_update = time.time()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_or_create_manager(session_state) -> ExperimentManager:
|
| 153 |
+
"""Get or create the singleton ExperimentManager from Streamlit session state."""
|
| 154 |
+
if 'exp_manager' not in session_state or not isinstance(session_state.get('exp_manager'), ExperimentManager):
|
| 155 |
+
session_state['exp_manager'] = ExperimentManager()
|
| 156 |
+
return session_state['exp_manager']
|
src/flaml_utils.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
import pandas as pd
|
| 3 |
import mlflow
|
| 4 |
import shutil
|
|
@@ -7,12 +8,16 @@ from flaml import AutoML
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
import time
|
| 9 |
from src.mlflow_utils import safe_set_experiment
|
|
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
|
| 14 |
valid_data: pd.DataFrame = None, test_data: pd.DataFrame = None,
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
| 16 |
"""
|
| 17 |
Trains a FLAML model and logs results to MLflow.
|
| 18 |
"""
|
|
@@ -25,7 +30,14 @@ def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 25 |
flaml_logger = logging.getLogger('flaml')
|
| 26 |
flaml_logger.setLevel(logging.INFO)
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
# Data cleaning: drop rows where target is NaN
|
| 30 |
train_data = train_data.dropna(subset=[target])
|
| 31 |
logging.info(f"Data ready: {len(train_data)} rows.")
|
|
@@ -48,6 +60,7 @@ def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 48 |
valid_data = valid_data.dropna(subset=[target])
|
| 49 |
X_val = valid_data.drop(columns=[target])
|
| 50 |
y_val = valid_data[target]
|
|
|
|
| 51 |
mlflow.log_param("has_validation_data", True)
|
| 52 |
|
| 53 |
if test_data is not None:
|
|
@@ -62,15 +75,16 @@ def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 62 |
# The 'No low-cost partial config given' message is just an INFO warning from FLAML.
|
| 63 |
|
| 64 |
settings = {
|
| 65 |
-
"time_budget": time_budget,
|
| 66 |
"metric": metric,
|
| 67 |
"task": task,
|
| 68 |
"estimator_list": estimator_list,
|
| 69 |
"log_file_name": "flaml.log",
|
| 70 |
"seed": seed,
|
| 71 |
-
"n_jobs":
|
| 72 |
"verbose": 0, # Reduce internal verbosity to avoid pollution, progress goes to flaml.log
|
| 73 |
}
|
|
|
|
|
|
|
| 74 |
|
| 75 |
if cv_folds > 0:
|
| 76 |
settings["eval_method"] = "cv"
|
|
@@ -80,6 +94,36 @@ def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 80 |
settings["X_val"] = X_val
|
| 81 |
settings["y_val"] = y_val
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
# Train model
|
| 84 |
logging.info("Executing hyperparameter search (automl.fit)...")
|
| 85 |
try:
|
|
@@ -90,6 +134,9 @@ def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 90 |
if not hasattr(automl, 'best_estimator') or automl.best_estimator is None:
|
| 91 |
raise RuntimeError("FLAML stopped without finding a valid model.")
|
| 92 |
|
|
|
|
|
|
|
|
|
|
| 93 |
# Log metrics
|
| 94 |
if hasattr(automl, 'best_loss'):
|
| 95 |
mlflow.log_metric("best_loss", automl.best_loss)
|
|
@@ -106,6 +153,29 @@ def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 106 |
mlflow.log_artifact(model_path, artifact_path="model")
|
| 107 |
mlflow.log_param("model_type", "flaml")
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
# Log training log as artifact
|
| 110 |
if os.path.exists("flaml.log"):
|
| 111 |
mlflow.log_artifact("flaml.log")
|
|
|
|
| 1 |
import os
|
| 2 |
+
import threading
|
| 3 |
import pandas as pd
|
| 4 |
import mlflow
|
| 5 |
import shutil
|
|
|
|
| 8 |
import matplotlib.pyplot as plt
|
| 9 |
import time
|
| 10 |
from src.mlflow_utils import safe_set_experiment
|
| 11 |
+
from src.onnx_utils import export_to_onnx
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
| 15 |
def train_flaml_model(train_data: pd.DataFrame, target: str, run_name: str,
|
| 16 |
valid_data: pd.DataFrame = None, test_data: pd.DataFrame = None,
|
| 17 |
+
time_budget: int = 60, task: str = 'classification', metric: str = 'auto',
|
| 18 |
+
estimator_list: list = 'auto', seed: int = 42, cv_folds: int = 0,
|
| 19 |
+
n_jobs: int = 1,
|
| 20 |
+
stop_event=None, telemetry_queue=None):
|
| 21 |
"""
|
| 22 |
Trains a FLAML model and logs results to MLflow.
|
| 23 |
"""
|
|
|
|
| 30 |
flaml_logger = logging.getLogger('flaml')
|
| 31 |
flaml_logger.setLevel(logging.INFO)
|
| 32 |
|
| 33 |
+
# Ensure no leaked runs in this thread
|
| 34 |
+
try:
|
| 35 |
+
if mlflow.active_run():
|
| 36 |
+
mlflow.end_run()
|
| 37 |
+
except:
|
| 38 |
+
pass
|
| 39 |
+
|
| 40 |
+
with mlflow.start_run(run_name=run_name, nested=True) as run:
|
| 41 |
# Data cleaning: drop rows where target is NaN
|
| 42 |
train_data = train_data.dropna(subset=[target])
|
| 43 |
logging.info(f"Data ready: {len(train_data)} rows.")
|
|
|
|
| 60 |
valid_data = valid_data.dropna(subset=[target])
|
| 61 |
X_val = valid_data.drop(columns=[target])
|
| 62 |
y_val = valid_data[target]
|
| 63 |
+
|
| 64 |
mlflow.log_param("has_validation_data", True)
|
| 65 |
|
| 66 |
if test_data is not None:
|
|
|
|
| 75 |
# The 'No low-cost partial config given' message is just an INFO warning from FLAML.
|
| 76 |
|
| 77 |
settings = {
|
|
|
|
| 78 |
"metric": metric,
|
| 79 |
"task": task,
|
| 80 |
"estimator_list": estimator_list,
|
| 81 |
"log_file_name": "flaml.log",
|
| 82 |
"seed": seed,
|
| 83 |
+
"n_jobs": n_jobs,
|
| 84 |
"verbose": 0, # Reduce internal verbosity to avoid pollution, progress goes to flaml.log
|
| 85 |
}
|
| 86 |
+
if time_budget is not None:
|
| 87 |
+
settings["time_budget"] = time_budget
|
| 88 |
|
| 89 |
if cv_folds > 0:
|
| 90 |
settings["eval_method"] = "cv"
|
|
|
|
| 94 |
settings["X_val"] = X_val
|
| 95 |
settings["y_val"] = y_val
|
| 96 |
|
| 97 |
+
# Start a watcher thread to respect stop_event
|
| 98 |
+
_cancel_watcher = None
|
| 99 |
+
if stop_event is not None:
|
| 100 |
+
def _watch():
|
| 101 |
+
stop_event.wait()
|
| 102 |
+
try:
|
| 103 |
+
automl._state.time_budget = 0 # Signal FLAML to stop
|
| 104 |
+
except Exception:
|
| 105 |
+
pass
|
| 106 |
+
_cancel_watcher = threading.Thread(target=_watch, daemon=True)
|
| 107 |
+
_cancel_watcher.start()
|
| 108 |
+
|
| 109 |
+
# Custom callback for telemetry
|
| 110 |
+
def _telemetry_callback(iter_count, time_used, best_loss, best_config, estimator, trial_id):
|
| 111 |
+
try:
|
| 112 |
+
if telemetry_queue:
|
| 113 |
+
telemetry_queue.put({
|
| 114 |
+
"status": "running",
|
| 115 |
+
"iterations": iter_count,
|
| 116 |
+
"time_used": time_used,
|
| 117 |
+
"best_loss": best_loss,
|
| 118 |
+
"best_estimator": str(estimator),
|
| 119 |
+
"best_config_preview": str(best_config)[:200]
|
| 120 |
+
})
|
| 121 |
+
except Exception:
|
| 122 |
+
pass
|
| 123 |
+
|
| 124 |
+
if telemetry_queue:
|
| 125 |
+
settings["callbacks"] = [_telemetry_callback]
|
| 126 |
+
|
| 127 |
# Train model
|
| 128 |
logging.info("Executing hyperparameter search (automl.fit)...")
|
| 129 |
try:
|
|
|
|
| 134 |
if not hasattr(automl, 'best_estimator') or automl.best_estimator is None:
|
| 135 |
raise RuntimeError("FLAML stopped without finding a valid model.")
|
| 136 |
|
| 137 |
+
if stop_event and stop_event.is_set():
|
| 138 |
+
raise StopIteration("Training cancelled by user")
|
| 139 |
+
|
| 140 |
# Log metrics
|
| 141 |
if hasattr(automl, 'best_loss'):
|
| 142 |
mlflow.log_metric("best_loss", automl.best_loss)
|
|
|
|
| 153 |
mlflow.log_artifact(model_path, artifact_path="model")
|
| 154 |
mlflow.log_param("model_type", "flaml")
|
| 155 |
|
| 156 |
+
# ONNX Export
|
| 157 |
+
try:
|
| 158 |
+
onnx_path = os.path.join("models", f"flaml_{run_name}.onnx")
|
| 159 |
+
# For FLAML, we can often export the underlying best estimator or the AutoML object if it's scikit-learn compatible
|
| 160 |
+
# We pass X_train[:1] as sample input for shape inference
|
| 161 |
+
export_to_onnx(automl.model.estimator, "flaml", target, onnx_path, input_sample=X_train[:1])
|
| 162 |
+
mlflow.log_artifact(onnx_path, artifact_path="model")
|
| 163 |
+
except Exception as e:
|
| 164 |
+
logger.warning(f"Failed to export FLAML model to ONNX: {e}")
|
| 165 |
+
|
| 166 |
+
# Generate and log consumption code sample
|
| 167 |
+
try:
|
| 168 |
+
from src.code_gen_utils import generate_consumption_code
|
| 169 |
+
code_sample = generate_consumption_code("flaml", run.info.run_id, target)
|
| 170 |
+
code_path = "consumption_sample.py"
|
| 171 |
+
with open(code_path, "w") as f:
|
| 172 |
+
f.write(code_sample)
|
| 173 |
+
mlflow.log_artifact(code_path)
|
| 174 |
+
if os.path.exists(code_path):
|
| 175 |
+
os.remove(code_path)
|
| 176 |
+
except Exception as e:
|
| 177 |
+
logger.warning(f"Failed to generate consumption code: {e}")
|
| 178 |
+
|
| 179 |
# Log training log as artifact
|
| 180 |
if os.path.exists("flaml.log"):
|
| 181 |
mlflow.log_artifact("flaml.log")
|
src/h2o_utils.py
CHANGED
|
@@ -111,7 +111,8 @@ def train_h2o_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 111 |
valid_data: pd.DataFrame = None, test_data: pd.DataFrame = None,
|
| 112 |
max_runtime_secs: int = 300, max_models: int = 10,
|
| 113 |
nfolds: int = 3, balance_classes: bool = True, seed: int = 42,
|
| 114 |
-
sort_metric: str = "AUTO", exclude_algos: list = None
|
|
|
|
| 115 |
"""
|
| 116 |
Trains H2O AutoML model and registers in MLflow
|
| 117 |
"""
|
|
@@ -125,7 +126,14 @@ def train_h2o_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 125 |
h2o_instance = initialize_h2o()
|
| 126 |
|
| 127 |
try:
|
| 128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
# Prepare data
|
| 130 |
h2o_frame, clean_data = prepare_data_for_h2o(train_data, target)
|
| 131 |
|
|
@@ -158,6 +166,23 @@ def train_h2o_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 158 |
sort_metric=sort_metric,
|
| 159 |
exclude_algos=exclude_algos or []
|
| 160 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
# Prepare test and validation data if present
|
| 163 |
h2o_valid = None
|
|
@@ -178,14 +203,62 @@ def train_h2o_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 178 |
|
| 179 |
# Train model
|
| 180 |
logger.info("Starting H2O AutoML training...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
start_time = time.time()
|
| 182 |
train_kwargs = {"x": features, "y": target, "training_frame": h2o_frame}
|
| 183 |
if h2o_valid is not None:
|
| 184 |
train_kwargs["validation_frame"] = h2o_valid
|
| 185 |
if h2o_test is not None:
|
| 186 |
train_kwargs["leaderboard_frame"] = h2o_test
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
training_duration = time.time() - start_time
|
| 190 |
|
| 191 |
logger.info(f"Training completed in {training_duration:.2f} seconds")
|
|
@@ -214,51 +287,56 @@ def train_h2o_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 214 |
|
| 215 |
# Save leaderboard as metric with safe wrapper
|
| 216 |
try:
|
| 217 |
-
|
| 218 |
-
|
|
|
|
| 219 |
try:
|
|
|
|
| 220 |
leaderboard_df = leaderboard.as_data_frame()
|
| 221 |
-
|
|
|
|
| 222 |
except Exception as e:
|
| 223 |
-
logger.warning(f"
|
| 224 |
-
|
| 225 |
-
|
|
|
|
| 226 |
best_model_score = 0.0
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
break
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
logger.info(f"
|
| 241 |
-
|
| 242 |
-
# Try accessing first row, first metric col
|
| 243 |
-
if len(available_columns) > 0:
|
| 244 |
-
first_col = available_columns[0]
|
| 245 |
-
best_model_score = leaderboard[0, first_col]
|
| 246 |
-
logger.info(f"Using first available column '{first_col}': {best_model_score}")
|
| 247 |
-
|
| 248 |
-
mlflow.log_metric("total_models_trained", leaderboard.nrow)
|
| 249 |
-
except Exception as e:
|
| 250 |
-
logger.warning(f"Could not extract metrics from leaderboard: {e}")
|
| 251 |
-
mlflow.log_metric("total_models_trained", 0)
|
| 252 |
|
|
|
|
|
|
|
| 253 |
mlflow.log_metric("best_model_score", best_model_score)
|
| 254 |
mlflow.log_metric("training_duration", training_duration)
|
|
|
|
|
|
|
| 255 |
|
| 256 |
except Exception as e:
|
| 257 |
logger.warning(f"Error processing leaderboard metrics: {e}")
|
| 258 |
-
#
|
| 259 |
mlflow.log_metric("best_model_score", 0.0)
|
| 260 |
mlflow.log_metric("training_duration", training_duration)
|
| 261 |
-
mlflow.log_metric("total_models_trained", 0)
|
| 262 |
|
| 263 |
# Try saving leaderboard with error handling
|
| 264 |
try:
|
|
@@ -297,6 +375,19 @@ def train_h2o_model(train_data: pd.DataFrame, target: str, run_name: str,
|
|
| 297 |
h2o.save_model(best_model, path=temp_model_path)
|
| 298 |
mlflow.log_artifacts(temp_model_path, artifact_path="model")
|
| 299 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
# Clean temp directory
|
| 301 |
import shutil
|
| 302 |
if os.path.exists(temp_model_path):
|
|
|
|
| 111 |
valid_data: pd.DataFrame = None, test_data: pd.DataFrame = None,
|
| 112 |
max_runtime_secs: int = 300, max_models: int = 10,
|
| 113 |
nfolds: int = 3, balance_classes: bool = True, seed: int = 42,
|
| 114 |
+
sort_metric: str = "AUTO", exclude_algos: list = None,
|
| 115 |
+
stop_event=None, telemetry_queue=None):
|
| 116 |
"""
|
| 117 |
Trains H2O AutoML model and registers in MLflow
|
| 118 |
"""
|
|
|
|
| 126 |
h2o_instance = initialize_h2o()
|
| 127 |
|
| 128 |
try:
|
| 129 |
+
# Ensure no leaked runs in this thread
|
| 130 |
+
try:
|
| 131 |
+
if mlflow.active_run():
|
| 132 |
+
mlflow.end_run()
|
| 133 |
+
except:
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
+
with mlflow.start_run(run_name=run_name, nested=True) as run:
|
| 137 |
# Prepare data
|
| 138 |
h2o_frame, clean_data = prepare_data_for_h2o(train_data, target)
|
| 139 |
|
|
|
|
| 166 |
sort_metric=sort_metric,
|
| 167 |
exclude_algos=exclude_algos or []
|
| 168 |
)
|
| 169 |
+
|
| 170 |
+
# Watcher thread for graceful cancellation
|
| 171 |
+
import threading
|
| 172 |
+
_cancel_watcher = None
|
| 173 |
+
if stop_event is not None:
|
| 174 |
+
def _h2o_watch():
|
| 175 |
+
stop_event.wait()
|
| 176 |
+
try:
|
| 177 |
+
import h2o as _h2o
|
| 178 |
+
jobs = _h2o.cluster().jobs()
|
| 179 |
+
for job in jobs:
|
| 180 |
+
if job.status == 'RUNNING':
|
| 181 |
+
job.cancel()
|
| 182 |
+
except Exception:
|
| 183 |
+
pass
|
| 184 |
+
_cancel_watcher = threading.Thread(target=_h2o_watch, daemon=True)
|
| 185 |
+
_cancel_watcher.start()
|
| 186 |
|
| 187 |
# Prepare test and validation data if present
|
| 188 |
h2o_valid = None
|
|
|
|
| 203 |
|
| 204 |
# Train model
|
| 205 |
logger.info("Starting H2O AutoML training...")
|
| 206 |
+
import sys
|
| 207 |
+
# Guard against deep recursion in H2O/Scipy on some datasets
|
| 208 |
+
sys.setrecursionlimit(max(sys.getrecursionlimit(), 3000))
|
| 209 |
+
|
| 210 |
start_time = time.time()
|
| 211 |
train_kwargs = {"x": features, "y": target, "training_frame": h2o_frame}
|
| 212 |
if h2o_valid is not None:
|
| 213 |
train_kwargs["validation_frame"] = h2o_valid
|
| 214 |
if h2o_test is not None:
|
| 215 |
train_kwargs["leaderboard_frame"] = h2o_test
|
| 216 |
+
|
| 217 |
+
# Streaming updates thread
|
| 218 |
+
def _push_h2o_telemetry():
|
| 219 |
+
while aml.leaderboard is None or aml.leaderboard.nrow == 0:
|
| 220 |
+
if stop_event and stop_event.is_set(): break
|
| 221 |
+
time.sleep(2)
|
| 222 |
|
| 223 |
+
last_row_count = 0
|
| 224 |
+
while not (stop_event and stop_event.is_set()):
|
| 225 |
+
try:
|
| 226 |
+
lb = aml.leaderboard
|
| 227 |
+
if lb is not None and lb.nrow > last_row_count:
|
| 228 |
+
last_row_count = lb.nrow
|
| 229 |
+
lb_df = lb.as_data_frame()
|
| 230 |
+
best_metric = lb_df.columns[1] if len(lb_df.columns) > 1 else "score"
|
| 231 |
+
best_val = lb_df.iloc[0, 1] if len(lb_df) > 0 else 0
|
| 232 |
+
|
| 233 |
+
if telemetry_queue:
|
| 234 |
+
telemetry_queue.put({
|
| 235 |
+
"status": "running",
|
| 236 |
+
"models_trained": last_row_count,
|
| 237 |
+
"best_metric": best_metric,
|
| 238 |
+
"best_value": best_val,
|
| 239 |
+
"leaderboard_preview": lb_df.head(5).to_dict(orient='records')
|
| 240 |
+
})
|
| 241 |
+
except Exception:
|
| 242 |
+
pass
|
| 243 |
+
if training_duration := (time.time() - start_time):
|
| 244 |
+
if training_duration > max_runtime_secs and max_runtime_secs > 0: break
|
| 245 |
+
time.sleep(5)
|
| 246 |
+
|
| 247 |
+
if telemetry_queue:
|
| 248 |
+
t_telemetry = threading.Thread(target=_push_h2o_telemetry, daemon=True)
|
| 249 |
+
t_telemetry.start()
|
| 250 |
+
|
| 251 |
+
# Fix encoding issue on Windows by disabling H2O progress bar if it causes issues
|
| 252 |
+
# or wrapping the call. H2O uses ASCII bars if it detects non-tty, but our router
|
| 253 |
+
# might be confusing it.
|
| 254 |
+
try:
|
| 255 |
+
aml.train(**train_kwargs)
|
| 256 |
+
except UnicodeEncodeError:
|
| 257 |
+
# Fallback: try with minimal verbosity if encoding fails
|
| 258 |
+
logger.warning("Encoding error detected, retrying with lower verbosity...")
|
| 259 |
+
aml.project_name = aml.project_name + "_retry"
|
| 260 |
+
aml.train(**train_kwargs)
|
| 261 |
+
|
| 262 |
training_duration = time.time() - start_time
|
| 263 |
|
| 264 |
logger.info(f"Training completed in {training_duration:.2f} seconds")
|
|
|
|
| 287 |
|
| 288 |
# Save leaderboard as metric with safe wrapper
|
| 289 |
try:
|
| 290 |
+
available_metrics = []
|
| 291 |
+
num_models = 0
|
| 292 |
+
|
| 293 |
try:
|
| 294 |
+
num_models = leaderboard.nrow
|
| 295 |
leaderboard_df = leaderboard.as_data_frame()
|
| 296 |
+
available_metrics = [c.lower() for c in leaderboard_df.columns]
|
| 297 |
+
logger.info(f"Available leaderboard columns: {list(leaderboard_df.columns)}")
|
| 298 |
except Exception as e:
|
| 299 |
+
logger.warning(f"Metadata extraction failed: {e}")
|
| 300 |
+
leaderboard_df = None
|
| 301 |
+
|
| 302 |
+
# Search for metrics in preference order
|
| 303 |
best_model_score = 0.0
|
| 304 |
+
found_metric = "none"
|
| 305 |
+
|
| 306 |
+
metric_candidates = ['auc', 'logloss', 'rmse', 'mae', 'r2', 'mse', 'accuracy', 'f1']
|
| 307 |
+
|
| 308 |
+
if leaderboard_df is not None and not leaderboard_df.empty:
|
| 309 |
+
# Find column index
|
| 310 |
+
col_names_lower = [c.lower() for c in leaderboard_df.columns]
|
| 311 |
+
for m_cand in metric_candidates:
|
| 312 |
+
if m_cand in col_names_lower:
|
| 313 |
+
idx = col_names_lower.index(m_cand)
|
| 314 |
+
actual_col = leaderboard_df.columns[idx]
|
| 315 |
+
best_model_score = float(leaderboard_df.iloc[0][actual_col])
|
| 316 |
+
found_metric = actual_col
|
| 317 |
+
logger.info(f"Using metric '{found_metric}': {best_model_score}")
|
| 318 |
break
|
| 319 |
|
| 320 |
+
# If still 0 and we have columns, pick the second one (usually the main metric)
|
| 321 |
+
if best_model_score == 0.0 and len(leaderboard_df.columns) > 1:
|
| 322 |
+
actual_col = leaderboard_df.columns[1]
|
| 323 |
+
best_model_score = float(leaderboard_df.iloc[0][actual_col])
|
| 324 |
+
found_metric = actual_col
|
| 325 |
+
logger.info(f"Fallback to second column '{found_metric}': {best_model_score}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
+
# Log metrics
|
| 328 |
+
mlflow.log_metric("total_models_trained", float(num_models))
|
| 329 |
mlflow.log_metric("best_model_score", best_model_score)
|
| 330 |
mlflow.log_metric("training_duration", training_duration)
|
| 331 |
+
if found_metric != "none":
|
| 332 |
+
mlflow.set_tag("best_metric_name", found_metric)
|
| 333 |
|
| 334 |
except Exception as e:
|
| 335 |
logger.warning(f"Error processing leaderboard metrics: {e}")
|
| 336 |
+
# Ultimate fallback
|
| 337 |
mlflow.log_metric("best_model_score", 0.0)
|
| 338 |
mlflow.log_metric("training_duration", training_duration)
|
| 339 |
+
mlflow.log_metric("total_models_trained", 0.0)
|
| 340 |
|
| 341 |
# Try saving leaderboard with error handling
|
| 342 |
try:
|
|
|
|
| 375 |
h2o.save_model(best_model, path=temp_model_path)
|
| 376 |
mlflow.log_artifacts(temp_model_path, artifact_path="model")
|
| 377 |
|
| 378 |
+
# Generate and log consumption code sample
|
| 379 |
+
try:
|
| 380 |
+
from src.code_gen_utils import generate_consumption_code
|
| 381 |
+
code_sample = generate_consumption_code("h2o", run.info.run_id, target)
|
| 382 |
+
code_path = "consumption_sample.py"
|
| 383 |
+
with open(code_path, "w") as f:
|
| 384 |
+
f.write(code_sample)
|
| 385 |
+
mlflow.log_artifact(code_path)
|
| 386 |
+
if os.path.exists(code_path):
|
| 387 |
+
os.remove(code_path)
|
| 388 |
+
except Exception as e:
|
| 389 |
+
logger.warning(f"Failed to generate consumption code: {e}")
|
| 390 |
+
|
| 391 |
# Clean temp directory
|
| 392 |
import shutil
|
| 393 |
if os.path.exists(temp_model_path):
|
src/huggingface_utils.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List, Dict, Any, Optional
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
HF_AVAILABLE = None
|
| 8 |
+
|
| 9 |
+
def _check_hf_availability():
|
| 10 |
+
global HF_AVAILABLE
|
| 11 |
+
if HF_AVAILABLE is not None:
|
| 12 |
+
return HF_AVAILABLE
|
| 13 |
+
try:
|
| 14 |
+
import huggingface_hub
|
| 15 |
+
HF_AVAILABLE = True
|
| 16 |
+
except Exception as e:
|
| 17 |
+
logger.warning(f"Hugging Face Hub not available: {e}")
|
| 18 |
+
HF_AVAILABLE = False
|
| 19 |
+
return HF_AVAILABLE
|
| 20 |
+
|
| 21 |
+
class HuggingFaceService:
|
| 22 |
+
def __init__(self, token: Optional[str] = None):
|
| 23 |
+
if not _check_hf_availability():
|
| 24 |
+
self.api = None
|
| 25 |
+
self.token = None
|
| 26 |
+
return
|
| 27 |
+
|
| 28 |
+
from huggingface_hub import HfApi
|
| 29 |
+
self.token = token or os.environ.get("HUGGINGFACE_TOKEN")
|
| 30 |
+
self.api = HfApi(token=self.token) if self.token else HfApi()
|
| 31 |
+
|
| 32 |
+
def authenticate(self, token: str):
|
| 33 |
+
"""Authenticates with the Hugging Face Hub."""
|
| 34 |
+
if not _check_hf_availability():
|
| 35 |
+
raise ImportError("Hugging Face Hub library not found.")
|
| 36 |
+
|
| 37 |
+
from huggingface_hub import login, HfApi
|
| 38 |
+
self.token = token
|
| 39 |
+
self.api = HfApi(token=token)
|
| 40 |
+
login(token=token)
|
| 41 |
+
logger.info("Authenticated with Hugging Face Hub.")
|
| 42 |
+
|
| 43 |
+
def list_models(self, query: str = None, author: str = None) -> List[Dict[str, Any]]:
|
| 44 |
+
"""Lists models on the Hub based on search query or author."""
|
| 45 |
+
if not self.api:
|
| 46 |
+
return []
|
| 47 |
+
models = self.api.list_models(search=query, author=author, limit=10)
|
| 48 |
+
return [{"id": m.id, "author": m.author, "lastModified": m.lastModified} for m in models]
|
| 49 |
+
|
| 50 |
+
def upload_model(self, model_path: str, repo_id: str, commit_message: str = "Upload AutoML model", private: bool = True):
|
| 51 |
+
"""Uploads a model file or directory to a HF repository."""
|
| 52 |
+
if not self.api:
|
| 53 |
+
raise ImportError("Hugging Face Hub library not found.")
|
| 54 |
+
if not self.token:
|
| 55 |
+
raise ValueError("Authentication token is required for upload.")
|
| 56 |
+
|
| 57 |
+
repo_url = self.api.create_repo(repo_id=repo_id, private=private, exist_ok=True)
|
| 58 |
+
logger.info(f"Hub repository ready: {repo_url}")
|
| 59 |
+
|
| 60 |
+
if os.path.isdir(model_path):
|
| 61 |
+
self.api.upload_folder(
|
| 62 |
+
folder_path=model_path,
|
| 63 |
+
repo_id=repo_id,
|
| 64 |
+
commit_message=commit_message
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
self.api.upload_file(
|
| 68 |
+
path_or_fileobj=model_path,
|
| 69 |
+
path_in_repo=os.path.basename(model_path),
|
| 70 |
+
repo_id=repo_id,
|
| 71 |
+
commit_message=commit_message
|
| 72 |
+
)
|
| 73 |
+
logger.info(f"Successfully uploaded {model_path} to {repo_id}")
|
| 74 |
+
|
| 75 |
+
def download_model(self, repo_id: str, filename: str, local_dir: str = "models/hf_downloads") -> str:
|
| 76 |
+
"""Downloads a specific file from a HF repository."""
|
| 77 |
+
if not _check_hf_availability():
|
| 78 |
+
raise ImportError("Hugging Face Hub library not found.")
|
| 79 |
+
|
| 80 |
+
from huggingface_hub import hf_hub_download
|
| 81 |
+
os.makedirs(local_dir, exist_ok=True)
|
| 82 |
+
path = hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
|
| 83 |
+
logger.info(f"Downloaded {filename} from {repo_id} to {path}")
|
| 84 |
+
return path
|
| 85 |
+
|
| 86 |
+
def consult_model_info(self, repo_id: str) -> Dict[str, Any]:
|
| 87 |
+
"""Gets metadata about a model on the Hub."""
|
| 88 |
+
if not self.api:
|
| 89 |
+
return {}
|
| 90 |
+
info = self.api.model_info(repo_id=repo_id)
|
| 91 |
+
return {
|
| 92 |
+
"id": info.id,
|
| 93 |
+
"tags": info.tags,
|
| 94 |
+
"pipeline_tag": info.pipeline_tag,
|
| 95 |
+
"downloads": info.downloads
|
| 96 |
+
}
|
src/lale_utils.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import traceback
|
| 4 |
+
import queue
|
| 5 |
+
import time
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import joblib
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
|
| 11 |
+
import mlflow
|
| 12 |
+
|
| 13 |
+
# Lale core imports
|
| 14 |
+
import lale
|
| 15 |
+
from lale.lib.lale import Hyperopt
|
| 16 |
+
from lale.lib.sklearn import LogisticRegression, RandomForestClassifier
|
| 17 |
+
from lale.lib.sklearn import MinMaxScaler, PCA
|
| 18 |
+
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
|
| 19 |
+
|
| 20 |
+
from src.mlflow_utils import safe_set_experiment
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _preprocess_for_lale(X: pd.DataFrame, y: pd.Series, task_type: str = "Classification"):
|
| 24 |
+
"""
|
| 25 |
+
Encode non-numeric features so that sklearn estimators can handle them.
|
| 26 |
+
Returns (X_encoded, y_encoded, encoders) where encoders can be used for inverse transforms.
|
| 27 |
+
"""
|
| 28 |
+
X = X.copy()
|
| 29 |
+
|
| 30 |
+
# Encode categorical / object columns
|
| 31 |
+
col_encoders = {}
|
| 32 |
+
for col in X.columns:
|
| 33 |
+
if X[col].dtype == object or str(X[col].dtype) == 'category':
|
| 34 |
+
le = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
|
| 35 |
+
X[col] = le.fit_transform(X[[col]]).ravel()
|
| 36 |
+
col_encoders[col] = le
|
| 37 |
+
|
| 38 |
+
# Fill any remaining NaNs
|
| 39 |
+
for col in X.columns:
|
| 40 |
+
if X[col].isna().any():
|
| 41 |
+
X[col] = X[col].fillna(X[col].median() if pd.api.types.is_numeric_dtype(X[col]) else 0)
|
| 42 |
+
|
| 43 |
+
# Encode target if classification (Regression target should remain continuous)
|
| 44 |
+
y_encoder = None
|
| 45 |
+
if task_type != "Regression":
|
| 46 |
+
if y.dtype == object or str(y.dtype) == 'category':
|
| 47 |
+
y_encoder = LabelEncoder()
|
| 48 |
+
y = pd.Series(y_encoder.fit_transform(y), name=y.name)
|
| 49 |
+
|
| 50 |
+
return X, y, col_encoders, y_encoder
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def run_lale_experiment(
|
| 54 |
+
train_df: pd.DataFrame,
|
| 55 |
+
target_col: str,
|
| 56 |
+
run_name: str,
|
| 57 |
+
time_limit: Optional[int],
|
| 58 |
+
log_queue: queue.Queue,
|
| 59 |
+
stop_event=None,
|
| 60 |
+
val_df: Optional[pd.DataFrame] = None,
|
| 61 |
+
task_type: str = "Classification",
|
| 62 |
+
**kwargs
|
| 63 |
+
) -> Dict[str, Any]:
|
| 64 |
+
"""
|
| 65 |
+
Run Lale experiment using scikit-learn compatible pipelines via Hyperopt.
|
| 66 |
+
Handles text/categorical features with automatic encoding.
|
| 67 |
+
"""
|
| 68 |
+
logger = logging.getLogger("lale")
|
| 69 |
+
logger.info(f"Starting Lale experiment: {run_name} (Task: {task_type})")
|
| 70 |
+
logger.info(f"Dataset shape: {train_df.shape}, Target: {target_col}")
|
| 71 |
+
|
| 72 |
+
# Drop NaNs on target
|
| 73 |
+
train_df_c = train_df.dropna(subset=[target_col])
|
| 74 |
+
X_raw = train_df_c.drop(columns=[target_col])
|
| 75 |
+
y_raw = train_df_c[target_col]
|
| 76 |
+
|
| 77 |
+
# Pre-process: encode categoricals/text for sklearn compatibility
|
| 78 |
+
logger.info("Step: Encoding categorical/text features...")
|
| 79 |
+
X, y, col_encoders, y_encoder = _preprocess_for_lale(X_raw, y_raw, task_type)
|
| 80 |
+
|
| 81 |
+
unique_classes_log = ""
|
| 82 |
+
if task_type != "Regression":
|
| 83 |
+
unique_classes_log = f" | Classes: {y.unique()[:5].tolist()}"
|
| 84 |
+
logger.info(f"Features after encoding: {list(X.columns)}{unique_classes_log}")
|
| 85 |
+
|
| 86 |
+
# Validate MLflow tracking
|
| 87 |
+
safe_set_experiment("Multi_AutoML_Project")
|
| 88 |
+
|
| 89 |
+
# Always end any dangling run (Hyperopt can leave runs open)
|
| 90 |
+
try:
|
| 91 |
+
mlflow.end_run()
|
| 92 |
+
except Exception:
|
| 93 |
+
pass
|
| 94 |
+
|
| 95 |
+
if stop_event and stop_event.is_set():
|
| 96 |
+
raise StopIteration("Experiment cancelled before setup.")
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
with mlflow.start_run(run_name=run_name) as run:
|
| 100 |
+
run_id = run.info.run_id
|
| 101 |
+
logger.info(f"MLflow Run ID: {run_id}")
|
| 102 |
+
mlflow.log_param("model_type", "lale")
|
| 103 |
+
mlflow.log_param("n_features", X.shape[1])
|
| 104 |
+
mlflow.log_param("n_samples", X.shape[0])
|
| 105 |
+
mlflow.log_param("task_type", task_type)
|
| 106 |
+
|
| 107 |
+
# 1. Pipeline Definition (only numeric-friendly preprocessors)
|
| 108 |
+
logger.info("Step: Defining Lale Planned Pipeline...")
|
| 109 |
+
|
| 110 |
+
if task_type == "Regression":
|
| 111 |
+
from lale.lib.sklearn import LinearRegression, RandomForestRegressor
|
| 112 |
+
planned_pipeline = (
|
| 113 |
+
(MinMaxScaler | PCA) >>
|
| 114 |
+
(LinearRegression | RandomForestRegressor)
|
| 115 |
+
)
|
| 116 |
+
scoring_metric = "r2"
|
| 117 |
+
else:
|
| 118 |
+
planned_pipeline = (
|
| 119 |
+
(MinMaxScaler | PCA) >>
|
| 120 |
+
(LogisticRegression | RandomForestClassifier)
|
| 121 |
+
)
|
| 122 |
+
scoring_metric = "accuracy"
|
| 123 |
+
|
| 124 |
+
if stop_event and stop_event.is_set():
|
| 125 |
+
raise StopIteration("Experiment cancelled before Hyperopt setup.")
|
| 126 |
+
|
| 127 |
+
# 2. Hyperparameter Tuning
|
| 128 |
+
logger.info("Step: Tuning with Hyperopt...")
|
| 129 |
+
max_evals = 10 if time_limit is None or time_limit >= 300 else 5
|
| 130 |
+
time_args = {}
|
| 131 |
+
if time_limit and time_limit > 0:
|
| 132 |
+
time_args['max_eval_time'] = time_limit
|
| 133 |
+
|
| 134 |
+
optimizer = Hyperopt(
|
| 135 |
+
estimator=planned_pipeline,
|
| 136 |
+
max_evals=max_evals,
|
| 137 |
+
cv=3,
|
| 138 |
+
scoring=scoring_metric,
|
| 139 |
+
show_progressbar=False,
|
| 140 |
+
verbose=True, # show per-trial info so we can debug failures
|
| 141 |
+
**time_args
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# 3. Fit Model
|
| 145 |
+
logger.info(f"Step: Fitting Lale Optimizer (evals={max_evals})...")
|
| 146 |
+
start_time = time.time()
|
| 147 |
+
trained_optimizer = optimizer.fit(X.values, y.values)
|
| 148 |
+
|
| 149 |
+
if stop_event and stop_event.is_set():
|
| 150 |
+
raise StopIteration("Experiment cancelled after fitting.")
|
| 151 |
+
|
| 152 |
+
best_model = trained_optimizer.get_pipeline()
|
| 153 |
+
|
| 154 |
+
# Extract score
|
| 155 |
+
try:
|
| 156 |
+
summary = trained_optimizer.summary()
|
| 157 |
+
best_score = -summary.iloc[0]['loss'] if 'loss' in summary.columns else 0.0
|
| 158 |
+
except Exception:
|
| 159 |
+
best_score = 0.0
|
| 160 |
+
|
| 161 |
+
elapsed_time = time.time() - start_time
|
| 162 |
+
logger.info(f"Best Score (CV {scoring_metric}): {best_score:.4f}")
|
| 163 |
+
logger.info(f"Optimization time: {elapsed_time:.1f}s")
|
| 164 |
+
|
| 165 |
+
# 4. Save Model
|
| 166 |
+
logger.info("Step: Saving model locally...")
|
| 167 |
+
model_dir = "models"
|
| 168 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 169 |
+
model_path = os.path.join(model_dir, f"{run_name}_lale_model.pkl")
|
| 170 |
+
joblib.dump({"model": best_model, "col_encoders": col_encoders, "y_encoder": y_encoder, "task_type": task_type}, model_path)
|
| 171 |
+
|
| 172 |
+
# Log metrics
|
| 173 |
+
mlflow.log_metric(f"best_cv_{scoring_metric}", best_score)
|
| 174 |
+
mlflow.log_metric("optimization_time", elapsed_time)
|
| 175 |
+
mlflow.log_param("max_evals", max_evals)
|
| 176 |
+
mlflow.log_artifact(model_path, artifact_path="model")
|
| 177 |
+
|
| 178 |
+
logger.info("Lale experiment completed successfully.")
|
| 179 |
+
|
| 180 |
+
# 5. Prepare return bundle
|
| 181 |
+
bundle = {"model": best_model, "col_encoders": col_encoders, "y_encoder": y_encoder, "task_type": task_type}
|
| 182 |
+
|
| 183 |
+
return {
|
| 184 |
+
"success": True,
|
| 185 |
+
"predictor": bundle,
|
| 186 |
+
"run_id": run_id,
|
| 187 |
+
"type": "lale",
|
| 188 |
+
"model_path": model_path,
|
| 189 |
+
"metrics": {f"best_cv_{scoring_metric}": best_score}
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
except StopIteration as si:
|
| 193 |
+
logger.warning(f"Cancelled: {si}")
|
| 194 |
+
raise
|
| 195 |
+
except Exception as e:
|
| 196 |
+
logger.error(f"Lale Error: {e}")
|
| 197 |
+
logger.error(traceback.format_exc())
|
| 198 |
+
raise e
|
src/onnx_utils.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
# Global flags for availability
|
| 8 |
+
ONNX_AVAILABLE = None
|
| 9 |
+
|
| 10 |
+
def _check_onnx_availability():
|
| 11 |
+
global ONNX_AVAILABLE
|
| 12 |
+
if ONNX_AVAILABLE is not None:
|
| 13 |
+
return ONNX_AVAILABLE
|
| 14 |
+
try:
|
| 15 |
+
import onnx
|
| 16 |
+
import onnxruntime as ort
|
| 17 |
+
ONNX_AVAILABLE = True
|
| 18 |
+
except Exception as e:
|
| 19 |
+
logger.warning(f"ONNX or ONNXRuntime not available: {e}")
|
| 20 |
+
ONNX_AVAILABLE = False
|
| 21 |
+
return ONNX_AVAILABLE
|
| 22 |
+
|
| 23 |
+
def export_to_onnx(model: Any, model_type: str, target_col: str, output_path: str, input_sample: Optional[Any] = None) -> str:
|
| 24 |
+
"""
|
| 25 |
+
Exports a trained model to ONNX format.
|
| 26 |
+
Supports: flaml, pycaret, autogluon (tabular), autokeras (tensorflow).
|
| 27 |
+
"""
|
| 28 |
+
if not _check_onnx_availability():
|
| 29 |
+
raise ImportError("ONNX or ONNXRuntime is not available in this environment.")
|
| 30 |
+
|
| 31 |
+
import onnx
|
| 32 |
+
import pandas as pd
|
| 33 |
+
import numpy as np
|
| 34 |
+
|
| 35 |
+
logger.info(f"Exporting {model_type} model to ONNX: {output_path}")
|
| 36 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
| 37 |
+
|
| 38 |
+
try:
|
| 39 |
+
if model_type in ["flaml", "pycaret", "tpot"]:
|
| 40 |
+
from skl2onnx import to_onnx
|
| 41 |
+
|
| 42 |
+
if input_sample is None:
|
| 43 |
+
raise ValueError("input_sample is required for scikit-learn based ONNX export")
|
| 44 |
+
|
| 45 |
+
if isinstance(input_sample, pd.DataFrame) and target_col in input_sample.columns:
|
| 46 |
+
input_sample = input_sample.drop(columns=[target_col])
|
| 47 |
+
|
| 48 |
+
onx = to_onnx(model, input_sample[:1], initial_types=None)
|
| 49 |
+
with open(output_path, "wb") as f:
|
| 50 |
+
f.write(onx.SerializeToString())
|
| 51 |
+
|
| 52 |
+
elif model_type == "autokeras":
|
| 53 |
+
import tf2onnx
|
| 54 |
+
import tensorflow as tf
|
| 55 |
+
|
| 56 |
+
if input_sample is None:
|
| 57 |
+
raise ValueError("input_sample is required for TensorFlow/AutoKeras ONNX export")
|
| 58 |
+
|
| 59 |
+
input_signature = [tf.TensorSpec([None] + list(input_sample.shape[1:]), tf.float32, name='input')]
|
| 60 |
+
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=13)
|
| 61 |
+
onnx.save_model(onnx_model, output_path)
|
| 62 |
+
|
| 63 |
+
elif model_type == "autogluon":
|
| 64 |
+
try:
|
| 65 |
+
model.export_onnx(output_path)
|
| 66 |
+
except AttributeError:
|
| 67 |
+
logger.warning("AutoGluon model does not support direct export_onnx.")
|
| 68 |
+
raise NotImplementedError("AutoGluon ONNX export fallback not implemented.")
|
| 69 |
+
|
| 70 |
+
else:
|
| 71 |
+
raise ValueError(f"Unsupported model type for ONNX export: {model_type}")
|
| 72 |
+
|
| 73 |
+
logger.info(f"Successfully exported model to {output_path}")
|
| 74 |
+
return output_path
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.error(f"Failed to export {model_type} model to ONNX: {e}")
|
| 78 |
+
raise
|
| 79 |
+
|
| 80 |
+
def load_onnx_session(onnx_path: str):
|
| 81 |
+
"""Loads an ONNX model into an inference session."""
|
| 82 |
+
if not _check_onnx_availability():
|
| 83 |
+
raise ImportError("ONNXRuntime is not available.")
|
| 84 |
+
|
| 85 |
+
import onnxruntime as ort
|
| 86 |
+
if not os.path.exists(onnx_path):
|
| 87 |
+
raise FileNotFoundError(f"ONNX file not found: {onnx_path}")
|
| 88 |
+
return ort.InferenceSession(onnx_path)
|
| 89 |
+
|
| 90 |
+
def predict_onnx(session: Any, df: Any) -> Any:
|
| 91 |
+
"""Runs inference on a DataFrame using an ONNX session."""
|
| 92 |
+
import numpy as np
|
| 93 |
+
|
| 94 |
+
inputs = {}
|
| 95 |
+
for node in session.get_inputs():
|
| 96 |
+
name = node.name
|
| 97 |
+
if name in df.columns:
|
| 98 |
+
inputs[name] = df[[name]].values.astype(np.float32)
|
| 99 |
+
else:
|
| 100 |
+
if len(session.get_inputs()) == 1:
|
| 101 |
+
inputs[name] = df.values.astype(np.float32)
|
| 102 |
+
break
|
| 103 |
+
|
| 104 |
+
outputs = session.run(None, inputs)
|
| 105 |
+
return outputs[0]
|
src/pipeline_parser.py
ADDED
|
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
pipeline_parser.py — infer which AutoML pipeline step is active from live logs.
|
| 3 |
+
|
| 4 |
+
Each framework has a sequence of steps. This module parses log lines
|
| 5 |
+
to determine which step is "done", which is "active", and which is "pending".
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
|
| 10 |
+
# ── Step definitions per framework ───────────────────────────────────────────
|
| 11 |
+
# Each step has:
|
| 12 |
+
# label — displayed name
|
| 13 |
+
# keywords — log keywords that signal this step has STARTED or is active
|
| 14 |
+
# done_kw — log keywords that signal this step is DONE (optional)
|
| 15 |
+
# description — tooltip / explainer text
|
| 16 |
+
|
| 17 |
+
_STEPS: dict[str, list[dict]] = {
|
| 18 |
+
"autogluon": [
|
| 19 |
+
{
|
| 20 |
+
"label": "Data Preparation",
|
| 21 |
+
"icon": "📊",
|
| 22 |
+
"keywords": ["preprocessing", "converting", "fitting", "loading data", "train_data"],
|
| 23 |
+
"done_kw": ["beginning automl", "fitting model:"],
|
| 24 |
+
"description": "Validates and preprocesses the dataset. Handles missing values, categorical encoding and feature types.",
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"label": "Fitting Models",
|
| 28 |
+
"icon": "🤖",
|
| 29 |
+
"keywords": ["fitting model:", "training model for", "fitting with cpus"],
|
| 30 |
+
"done_kw": ["weightedensemble", "autogluon training complete"],
|
| 31 |
+
"description": "Trains each individual model (LightGBM, XGBoost, CatBoost, RF, etc.) within the time budget.",
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"label": "Stacking / Ensembling",
|
| 35 |
+
"icon": "🏗️",
|
| 36 |
+
"keywords": ["weightedensemble", "ensemble weights", "stacking"],
|
| 37 |
+
"done_kw": ["autogluon training complete"],
|
| 38 |
+
"description": "Combines the best models using weighted ensembling or multi-layer stacking.",
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"label": "Evaluation",
|
| 42 |
+
"icon": "📏",
|
| 43 |
+
"keywords": ["leaderboard", "best model:", "validation score", "score_val"],
|
| 44 |
+
"done_kw": ["tabularpredictor saved", "best model logged"],
|
| 45 |
+
"description": "Evaluates all models on the validation set and builds the final leaderboard.",
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"label": "MLflow Logging",
|
| 49 |
+
"icon": "📝",
|
| 50 |
+
"keywords": ["mlflow", "log_artifacts", "logged successfully", "artifacts logged"],
|
| 51 |
+
"done_kw": ["thread finished"],
|
| 52 |
+
"description": "Persists model artifacts, parameters, and metrics to MLflow for tracking and versioning.",
|
| 53 |
+
},
|
| 54 |
+
],
|
| 55 |
+
"flaml": [
|
| 56 |
+
{
|
| 57 |
+
"label": "Data Preparation",
|
| 58 |
+
"icon": "📊",
|
| 59 |
+
"keywords": ["data ready", "preprocessing", "starting flaml"],
|
| 60 |
+
"done_kw": ["executing hyperparameter search"],
|
| 61 |
+
"description": "Validates the dataset, detects feature types, and prepares inputs for FLAML's optimizer.",
|
| 62 |
+
},
|
| 63 |
+
{
|
| 64 |
+
"label": "Hyperparameter Search",
|
| 65 |
+
"icon": "🔍",
|
| 66 |
+
"keywords": ["executing hyperparameter search", "automl.fit", "[flaml.automl", "trial", "best config"],
|
| 67 |
+
"done_kw": ["search finished"],
|
| 68 |
+
"description": "FLAML runs a cost-effective search over hyperparameter configurations using Bayesian optimization.",
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"label": "Best Config Selection",
|
| 72 |
+
"icon": "🏆",
|
| 73 |
+
"keywords": ["search finished", "best estimator", "best loss", "best final"],
|
| 74 |
+
"done_kw": ["saving best model"],
|
| 75 |
+
"description": "Identifies the best-performing estimator and its configuration from the search results.",
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"label": "Model Saving",
|
| 79 |
+
"icon": "💾",
|
| 80 |
+
"keywords": ["saving best model", "model_path", "artifact_path"],
|
| 81 |
+
"done_kw": ["mlflow", "logged successfully"],
|
| 82 |
+
"description": "Serializes the trained model to disk using pickle.",
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"label": "MLflow Logging",
|
| 86 |
+
"icon": "📝",
|
| 87 |
+
"keywords": ["mlflow", "log_artifact", "logged successfully"],
|
| 88 |
+
"done_kw": ["thread finished"],
|
| 89 |
+
"description": "Persists model artifacts, parameters, and metrics to MLflow for tracking and versioning.",
|
| 90 |
+
},
|
| 91 |
+
],
|
| 92 |
+
"h2o": [
|
| 93 |
+
{
|
| 94 |
+
"label": "H2O Cluster Init",
|
| 95 |
+
"icon": "🌊",
|
| 96 |
+
"keywords": ["h2o cluster initialized", "initializing h2o", "h2o init"],
|
| 97 |
+
"done_kw": ["starting h2o automl"],
|
| 98 |
+
"description": "Starts the local H2O Java cluster and allocates memory for distributed model training.",
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"label": "Data Preparation",
|
| 102 |
+
"icon": "📊",
|
| 103 |
+
"keywords": ["preparing data", "h2oframe", "feature engineering", "asfactor"],
|
| 104 |
+
"done_kw": ["starting h2o automl training"],
|
| 105 |
+
"description": "Converts Pandas DataFrames to H2O frames and applies type casting for features/targets.",
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"label": "AutoML Training",
|
| 109 |
+
"icon": "🤖",
|
| 110 |
+
"keywords": ["starting h2o automl training", "automl session", "training completed", "aml.train"],
|
| 111 |
+
"done_kw": ["training completed in"],
|
| 112 |
+
"description": "H2O trains multiple model families (GBM, XGBoost, GLM, DRF, DeepLearning) and their variants.",
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"label": "Leaderboard & Scoring",
|
| 116 |
+
"icon": "📏",
|
| 117 |
+
"keywords": ["top 5 models", "leaderboard", "best model score", "auc", "total_models_trained"],
|
| 118 |
+
"done_kw": ["model saved at", "log model to mlflow"],
|
| 119 |
+
"description": "Ranks all trained models and evaluates the leader on the validation/test set.",
|
| 120 |
+
},
|
| 121 |
+
{
|
| 122 |
+
"label": "MLflow Logging",
|
| 123 |
+
"icon": "📝",
|
| 124 |
+
"keywords": ["mlflow", "log_artifacts", "logged successfully", "artifacts logged"],
|
| 125 |
+
"done_kw": ["thread finished"],
|
| 126 |
+
"description": "Persists model artifacts, parameters, and metrics to MLflow for tracking and versioning.",
|
| 127 |
+
},
|
| 128 |
+
],
|
| 129 |
+
"tpot": [
|
| 130 |
+
{
|
| 131 |
+
"label": "Data Preparation",
|
| 132 |
+
"icon": "📊",
|
| 133 |
+
"keywords": ["problem type:", "training data shape", "test data shape", "label encoder"],
|
| 134 |
+
"done_kw": ["starting tpot training"],
|
| 135 |
+
"description": "Applies feature engineering pipelines: TF-IDF for text, ordinal encoding, and standard scaling.",
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"label": "Pipeline Generation (GA)",
|
| 139 |
+
"icon": "🧬",
|
| 140 |
+
"keywords": ["starting tpot training", "generation:", "pipeline score:", "optimizing pipeline"],
|
| 141 |
+
"done_kw": ["training completed"],
|
| 142 |
+
"description": "TPOT uses a Genetic Algorithm to evolve and select the best scikit-learn pipeline configurations.",
|
| 143 |
+
},
|
| 144 |
+
{
|
| 145 |
+
"label": "Pipeline Selection",
|
| 146 |
+
"icon": "🏆",
|
| 147 |
+
"keywords": ["training completed", "best pipeline", "fitted_pipeline_", "accuracy:", "f1_macro:"],
|
| 148 |
+
"done_kw": ["pipeline exported"],
|
| 149 |
+
"description": "Identifies the highest-scoring pipeline from the genetic search as the final model.",
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"label": "Export & Analysis",
|
| 153 |
+
"icon": "📤",
|
| 154 |
+
"keywords": ["pipeline exported", "export", "classification report"],
|
| 155 |
+
"done_kw": ["mlflow"],
|
| 156 |
+
"description": "Exports the best pipeline as a .py file and generates a classification/regression report.",
|
| 157 |
+
},
|
| 158 |
+
{
|
| 159 |
+
"label": "MLflow Logging",
|
| 160 |
+
"icon": "📝",
|
| 161 |
+
"keywords": ["mlflow", "tpot automl model", "registered_model_name", "logged successfully"],
|
| 162 |
+
"done_kw": ["thread finished"],
|
| 163 |
+
"description": "Persists model artifacts, parameters, and metrics to MLflow for tracking and versioning.",
|
| 164 |
+
},
|
| 165 |
+
],
|
| 166 |
+
"pycaret": [
|
| 167 |
+
{
|
| 168 |
+
"label": "Environment Setup",
|
| 169 |
+
"icon": "⚙️",
|
| 170 |
+
"keywords": ["setting up pycaret", "dataset shape"],
|
| 171 |
+
"done_kw": ["comparing models", "step: comparing models..."],
|
| 172 |
+
"description": "Initializes the PyCaret setup, handling normalization, encoding, and train/test splits internally.",
|
| 173 |
+
},
|
| 174 |
+
{
|
| 175 |
+
"label": "Model Comparison",
|
| 176 |
+
"icon": "⚖️",
|
| 177 |
+
"keywords": ["comparing models", "including fast/robust models"],
|
| 178 |
+
"done_kw": ["tuning best model", "step: tuning best model..."],
|
| 179 |
+
"description": "Trains and evaluates a fast baseline of multiple estimators to find the top candidates.",
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"label": "Hyperparameter Tuning",
|
| 183 |
+
"icon": "🔧",
|
| 184 |
+
"keywords": ["tuning best model", "step: tuning best model..."],
|
| 185 |
+
"done_kw": ["blending top models", "step: blending top models..."],
|
| 186 |
+
"description": "Applies randomized search to optimize hyperparameters of the best performing model.",
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"label": "Model Blending",
|
| 190 |
+
"icon": "🌪️",
|
| 191 |
+
"keywords": ["blending top models", "step: blending top models..."],
|
| 192 |
+
"done_kw": ["saving model", "pycaret experiment completed"],
|
| 193 |
+
"description": "Creates an ensemble of the top models to improve generalized performance via voting/averaging.",
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"label": "MLflow Logging",
|
| 197 |
+
"icon": "📝",
|
| 198 |
+
"keywords": ["saving model to", "pycaret experiment completed", "thread finished"],
|
| 199 |
+
"done_kw": ["thread finished"],
|
| 200 |
+
"description": "Persists model artifacts, parameters, and metrics to MLflow for tracking and versioning.",
|
| 201 |
+
},
|
| 202 |
+
],
|
| 203 |
+
"lale": [
|
| 204 |
+
{
|
| 205 |
+
"label": "Pipeline Definition",
|
| 206 |
+
"icon": "⚙️",
|
| 207 |
+
"keywords": ["defining lale planned pipeline", "dataset shape"],
|
| 208 |
+
"done_kw": ["tuning with hyperopt", "step: tuning with hyperopt..."],
|
| 209 |
+
"description": "Maps a search space over transformers (PCA, Scalers) and estimators (LR, RF, KNN).",
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"label": "Hyperopt Tuning",
|
| 213 |
+
"icon": "🔧",
|
| 214 |
+
"keywords": ["tuning with hyperopt", "step: tuning with hyperopt..."],
|
| 215 |
+
"done_kw": ["fitting lale optimizer", "step: fitting lale optimizer"],
|
| 216 |
+
"description": "Configures Tree-structured Parzen Estimators (TPE) algorithm for intelligent hyperparameter search.",
|
| 217 |
+
},
|
| 218 |
+
{
|
| 219 |
+
"label": "Fitting Optimizer",
|
| 220 |
+
"icon": "🕒",
|
| 221 |
+
"keywords": ["fitting lale optimizer", "step: fitting lale optimizer"],
|
| 222 |
+
"done_kw": ["saving model locally", "step: saving model locally"],
|
| 223 |
+
"description": "Executes identical cross-validation folds on generated pipelines within the set budget.",
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"label": "Best Model Extraction",
|
| 227 |
+
"icon": "🏆",
|
| 228 |
+
"keywords": ["best pipeline structure:", "best f1 (macro) score"],
|
| 229 |
+
"done_kw": ["saving model locally", "step: saving model locally"],
|
| 230 |
+
"description": "Decodes the structure and metrics of the optimized pipeline graph.",
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"label": "MLflow Logging",
|
| 234 |
+
"icon": "📝",
|
| 235 |
+
"keywords": ["saving model locally", "lale experiment completed", "thread finished"],
|
| 236 |
+
"done_kw": ["thread finished"],
|
| 237 |
+
"description": "Persists model artifacts, parameters, and metrics to MLflow for tracking and versioning.",
|
| 238 |
+
},
|
| 239 |
+
],
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
# ── Public API ────────────────────────────────────────────────────────────────
|
| 243 |
+
|
| 244 |
+
def get_framework_steps(framework_key: str) -> list[dict]:
|
| 245 |
+
"""Return the step definitions for a given framework key."""
|
| 246 |
+
return _STEPS.get(framework_key.lower(), [])
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def infer_pipeline_steps(framework_key: str, logs: list[str], status: str) -> list[dict]:
|
| 250 |
+
"""
|
| 251 |
+
Returns enriched step list with status attached:
|
| 252 |
+
status = "done" | "active" | "pending"
|
| 253 |
+
|
| 254 |
+
On completed/failed/cancelled runs, all matched steps are "done".
|
| 255 |
+
"""
|
| 256 |
+
steps = get_framework_steps(framework_key)
|
| 257 |
+
if not steps:
|
| 258 |
+
return []
|
| 259 |
+
|
| 260 |
+
log_blob = " ".join(logs).lower()
|
| 261 |
+
|
| 262 |
+
if status == "completed":
|
| 263 |
+
# Mark all steps done
|
| 264 |
+
return [{"label": s["label"], "icon": s["icon"], "description": s["description"], "status": "done"} for s in steps]
|
| 265 |
+
|
| 266 |
+
if status in ("failed", "cancelled"):
|
| 267 |
+
# Mark up to the last-seen step as done, rest pending, mark last active as failed
|
| 268 |
+
last_done_idx = -1
|
| 269 |
+
for i, step in enumerate(steps):
|
| 270 |
+
if any(kw in log_blob for kw in step["keywords"]):
|
| 271 |
+
last_done_idx = i
|
| 272 |
+
|
| 273 |
+
result = []
|
| 274 |
+
for i, step in enumerate(steps):
|
| 275 |
+
if i < last_done_idx:
|
| 276 |
+
st_val = "done"
|
| 277 |
+
elif i == last_done_idx:
|
| 278 |
+
st_val = "failed" if status == "failed" else "cancelled"
|
| 279 |
+
else:
|
| 280 |
+
st_val = "pending"
|
| 281 |
+
result.append({"label": step["label"], "icon": step["icon"], "description": step["description"], "status": st_val})
|
| 282 |
+
return result
|
| 283 |
+
|
| 284 |
+
# Running or queued: find the active step
|
| 285 |
+
last_done_idx = -1
|
| 286 |
+
for i, step in enumerate(steps):
|
| 287 |
+
done_signals = step.get("done_kw", [])
|
| 288 |
+
if any(kw in log_blob for kw in done_signals):
|
| 289 |
+
last_done_idx = i
|
| 290 |
+
|
| 291 |
+
# Active = first step after last_done
|
| 292 |
+
active_idx = min(last_done_idx + 1, len(steps) - 1)
|
| 293 |
+
|
| 294 |
+
result = []
|
| 295 |
+
for i, step in enumerate(steps):
|
| 296 |
+
if i <= last_done_idx:
|
| 297 |
+
st_val = "done"
|
| 298 |
+
elif i == active_idx and status == "running":
|
| 299 |
+
st_val = "active"
|
| 300 |
+
else:
|
| 301 |
+
st_val = "pending"
|
| 302 |
+
result.append({"label": step["label"], "icon": step["icon"], "description": step["description"], "status": st_val})
|
| 303 |
+
|
| 304 |
+
return result
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def extract_best_tpot_pipeline(logs: list[str]) -> Optional[str]:
|
| 308 |
+
"""Extract the TPOT best pipeline string from logs."""
|
| 309 |
+
for line in reversed(logs):
|
| 310 |
+
if "best pipeline:" in line.lower() or "fitted_pipeline_" in line.lower():
|
| 311 |
+
return line.strip()
|
| 312 |
+
if "pipeline(" in line.lower():
|
| 313 |
+
return line.strip()
|
| 314 |
+
return None
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def extract_autogluon_leaderboard_text(logs: list[str]) -> Optional[str]:
|
| 318 |
+
"""Extract leaderboard table text from AutoGluon logs."""
|
| 319 |
+
rows = []
|
| 320 |
+
capture = False
|
| 321 |
+
for line in logs:
|
| 322 |
+
if "model" in line.lower() and "score_val" in line.lower():
|
| 323 |
+
capture = True
|
| 324 |
+
if capture:
|
| 325 |
+
rows.append(line)
|
| 326 |
+
if len(rows) > 15:
|
| 327 |
+
break
|
| 328 |
+
return "\n".join(rows) if rows else None
|
src/pycaret_utils.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import traceback
|
| 4 |
+
import queue
|
| 5 |
+
import time
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from typing import Dict, Any, Optional
|
| 8 |
+
|
| 9 |
+
import mlflow
|
| 10 |
+
|
| 11 |
+
from src.mlflow_utils import safe_set_experiment
|
| 12 |
+
from src.onnx_utils import export_to_onnx
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def run_pycaret_experiment(
|
| 16 |
+
train_df: pd.DataFrame,
|
| 17 |
+
target_col: str,
|
| 18 |
+
run_name: str,
|
| 19 |
+
time_limit: Optional[int],
|
| 20 |
+
log_queue: queue.Queue,
|
| 21 |
+
stop_event=None,
|
| 22 |
+
val_df: Optional[pd.DataFrame] = None,
|
| 23 |
+
task_type: str = "Classification",
|
| 24 |
+
n_jobs: int = 1,
|
| 25 |
+
**kwargs
|
| 26 |
+
) -> Dict[str, Any]:
|
| 27 |
+
"""
|
| 28 |
+
Run PyCaret experiment.
|
| 29 |
+
Dynamically loads classification, regression, or time_series depending on task_type.
|
| 30 |
+
"""
|
| 31 |
+
logger = logging.getLogger("pycaret")
|
| 32 |
+
logger.info(f"Starting PyCaret experiment: {run_name} (Task: {task_type})")
|
| 33 |
+
logger.info(f"Dataset shape: {train_df.shape}, Target: {target_col}")
|
| 34 |
+
|
| 35 |
+
# Dynamic imports based on task_type
|
| 36 |
+
if task_type == "Regression":
|
| 37 |
+
from pycaret.regression import setup, compare_models, pull, tune_model, blend_models, save_model
|
| 38 |
+
sort_metric = "R2"
|
| 39 |
+
include_models = ["lr", "rf", "et", "lightgbm"]
|
| 40 |
+
elif task_type == "Time Series Forecasting":
|
| 41 |
+
from pycaret.time_series import setup, compare_models, pull, tune_model, blend_models, save_model
|
| 42 |
+
sort_metric = "MASE"
|
| 43 |
+
include_models = ["naive", "snaive", "arima", "ets"]
|
| 44 |
+
else:
|
| 45 |
+
from pycaret.classification import setup, compare_models, pull, tune_model, blend_models, save_model
|
| 46 |
+
sort_metric = "F1"
|
| 47 |
+
include_models = ["lr", "nb", "rf", "et", "lightgbm"]
|
| 48 |
+
|
| 49 |
+
# Always end any dangling MLflow run to avoid conflicts
|
| 50 |
+
try:
|
| 51 |
+
mlflow.end_run()
|
| 52 |
+
except Exception:
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
# 1. Prepare MLflow Tracking
|
| 56 |
+
safe_set_experiment("Multi_AutoML_Project")
|
| 57 |
+
|
| 58 |
+
if stop_event and stop_event.is_set():
|
| 59 |
+
raise StopIteration("Experiment cancelled before setup.")
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
# 2. PyCaret Setup
|
| 63 |
+
logger.info("Step: Setting up PyCaret environment...")
|
| 64 |
+
|
| 65 |
+
setup_kwargs = {
|
| 66 |
+
"data": train_df,
|
| 67 |
+
"target": target_col,
|
| 68 |
+
"session_id": 42,
|
| 69 |
+
"verbose": False,
|
| 70 |
+
"fold": 3,
|
| 71 |
+
"log_experiment": False,
|
| 72 |
+
"system_log": False,
|
| 73 |
+
"n_jobs": n_jobs
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
if task_type == "Time Series Forecasting":
|
| 77 |
+
setup_kwargs["fh"] = kwargs.get("fh", 12)
|
| 78 |
+
setup_kwargs["seasonal_period"] = kwargs.get("seasonal_period", 12)
|
| 79 |
+
else:
|
| 80 |
+
setup_kwargs["test_data"] = val_df
|
| 81 |
+
setup_kwargs["normalize"] = True
|
| 82 |
+
setup_kwargs["index"] = False
|
| 83 |
+
setup_kwargs["feature_selection"] = False
|
| 84 |
+
setup_kwargs["memory"] = False
|
| 85 |
+
|
| 86 |
+
clf_setup = setup(**setup_kwargs)
|
| 87 |
+
|
| 88 |
+
if stop_event and stop_event.is_set():
|
| 89 |
+
raise StopIteration("Experiment cancelled after setup.")
|
| 90 |
+
|
| 91 |
+
# 3. Start our own MLflow run AFTER PyCaret setup
|
| 92 |
+
with mlflow.start_run(run_name=run_name) as run:
|
| 93 |
+
run_id = run.info.run_id
|
| 94 |
+
logger.info(f"MLflow Run ID: {run_id}")
|
| 95 |
+
mlflow.log_param("framework", "pycaret")
|
| 96 |
+
mlflow.log_param("model_type", "pycaret")
|
| 97 |
+
mlflow.log_param("task_type", task_type)
|
| 98 |
+
|
| 99 |
+
# 4. Model Comparison
|
| 100 |
+
logger.info("Step: Comparing models...")
|
| 101 |
+
n_select = 3
|
| 102 |
+
logger.info(f"Including models: {include_models} (Sorting by {sort_metric})")
|
| 103 |
+
|
| 104 |
+
best_models = compare_models(
|
| 105 |
+
n_select=n_select,
|
| 106 |
+
sort=sort_metric,
|
| 107 |
+
verbose=False,
|
| 108 |
+
include=include_models
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
comparison_df = pull()
|
| 112 |
+
if not comparison_df.empty:
|
| 113 |
+
top_model_name = comparison_df.iloc[0]['Model']
|
| 114 |
+
logger.info(f"Best model found: {top_model_name}")
|
| 115 |
+
|
| 116 |
+
if stop_event and stop_event.is_set():
|
| 117 |
+
raise StopIteration("Experiment cancelled after model comparison.")
|
| 118 |
+
|
| 119 |
+
# Ensure best_models is a list
|
| 120 |
+
if not isinstance(best_models, list):
|
| 121 |
+
best_models = [best_models]
|
| 122 |
+
|
| 123 |
+
best_model = best_models[0]
|
| 124 |
+
|
| 125 |
+
# 5. Tuning (Time Series tuning might require different params, keeping generic)
|
| 126 |
+
logger.info("Step: Tuning best model...")
|
| 127 |
+
n_iter = 10 if time_limit is None or time_limit >= 300 else 5
|
| 128 |
+
|
| 129 |
+
# search_library="scikit-learn" shouldn't be passed to pycaret.time_series
|
| 130 |
+
tune_kwargs = {
|
| 131 |
+
"estimator": best_model,
|
| 132 |
+
"optimize": sort_metric,
|
| 133 |
+
"n_iter": n_iter,
|
| 134 |
+
"verbose": False,
|
| 135 |
+
"choose_better": True
|
| 136 |
+
}
|
| 137 |
+
if task_type != "Time Series Forecasting":
|
| 138 |
+
tune_kwargs["search_library"] = "scikit-learn"
|
| 139 |
+
tune_kwargs["search_algorithm"] = "random"
|
| 140 |
+
|
| 141 |
+
tuned_model = tune_model(**tune_kwargs)
|
| 142 |
+
|
| 143 |
+
if stop_event and stop_event.is_set():
|
| 144 |
+
raise StopIteration("Experiment cancelled after tuning.")
|
| 145 |
+
|
| 146 |
+
# 6. Blending (only if we have multiple models)
|
| 147 |
+
if len(best_models) > 1:
|
| 148 |
+
logger.info("Step: Blending top models...")
|
| 149 |
+
final_model = blend_models(
|
| 150 |
+
estimator_list=best_models,
|
| 151 |
+
optimize=sort_metric,
|
| 152 |
+
verbose=False
|
| 153 |
+
)
|
| 154 |
+
else:
|
| 155 |
+
final_model = tuned_model
|
| 156 |
+
logger.info("Step: Skipping blend (only one model selected).")
|
| 157 |
+
|
| 158 |
+
# 7. Save model
|
| 159 |
+
model_dir = "models"
|
| 160 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 161 |
+
model_path_base = os.path.join(model_dir, f"{run_name}_pycaret_model")
|
| 162 |
+
logger.info(f"Saving model to {model_path_base}.pkl...")
|
| 163 |
+
save_model(final_model, model_path_base)
|
| 164 |
+
|
| 165 |
+
# 8. Log metrics to our MLflow run
|
| 166 |
+
try:
|
| 167 |
+
final_metrics = pull()
|
| 168 |
+
if not final_metrics.empty:
|
| 169 |
+
row = final_metrics.iloc[0]
|
| 170 |
+
for k, v in row.items():
|
| 171 |
+
if isinstance(v, (int, float)):
|
| 172 |
+
mlflow.log_metric(k.lower().replace(" ", "_"), float(v))
|
| 173 |
+
except Exception as me:
|
| 174 |
+
logger.warning(f"Could not pull metrics: {me}")
|
| 175 |
+
|
| 176 |
+
# Log model artifact
|
| 177 |
+
model_pkl = f"{model_path_base}.pkl"
|
| 178 |
+
if os.path.exists(model_pkl):
|
| 179 |
+
mlflow.log_artifact(model_pkl, artifact_path="model")
|
| 180 |
+
|
| 181 |
+
# ONNX Export
|
| 182 |
+
try:
|
| 183 |
+
onnx_path = os.path.join(model_dir, f"{run_name}_pycaret.onnx")
|
| 184 |
+
# PyCaret 'final_model' is a scikit-learn pipeline
|
| 185 |
+
export_to_onnx(final_model, "pycaret", target_col, onnx_path, input_sample=train_df[:1])
|
| 186 |
+
mlflow.log_artifact(onnx_path, artifact_path="model")
|
| 187 |
+
except Exception as e:
|
| 188 |
+
logger.warning(f"Failed to export PyCaret model to ONNX: {e}")
|
| 189 |
+
|
| 190 |
+
logger.info("PyCaret experiment completed successfully.")
|
| 191 |
+
return {
|
| 192 |
+
"success": True,
|
| 193 |
+
"predictor": final_model,
|
| 194 |
+
"run_id": run_id,
|
| 195 |
+
"type": "pycaret",
|
| 196 |
+
"model_path": model_pkl
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
except StopIteration as si:
|
| 200 |
+
logger.warning(f"Cancelled: {si}")
|
| 201 |
+
raise
|
| 202 |
+
except Exception as e:
|
| 203 |
+
logger.error(f"PyCaret Error: {e}")
|
| 204 |
+
logger.error(traceback.format_exc())
|
| 205 |
+
raise e
|
| 206 |
+
finally:
|
| 207 |
+
# Always clean up any dangling run
|
| 208 |
+
try:
|
| 209 |
+
mlflow.end_run()
|
| 210 |
+
except Exception:
|
| 211 |
+
pass
|
| 212 |
+
|
src/tpot_utils.py
CHANGED
|
@@ -165,7 +165,8 @@ def train_tpot_model(df, target_column, run_name,
|
|
| 165 |
generations=5, population_size=20, cv=5,
|
| 166 |
scoring=None, max_time_mins=30, max_eval_time_mins=5, random_state=42,
|
| 167 |
verbosity=2, n_jobs=-1, config_dict='TPOT sparse',
|
| 168 |
-
tfidf_max_features=500, tfidf_ngram_range=(1, 2)
|
|
|
|
| 169 |
"""
|
| 170 |
Train TPOT model with MLflow tracking
|
| 171 |
"""
|
|
@@ -232,10 +233,13 @@ def train_tpot_model(df, target_column, run_name,
|
|
| 232 |
scoring = 'neg_mean_squared_error'
|
| 233 |
|
| 234 |
# Ensure there are no loose active runs that could cause errors on start
|
| 235 |
-
|
| 236 |
-
mlflow.
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
-
with mlflow.start_run(run_name=run_name) as run:
|
| 239 |
logger.info(f"Starting TPOT training for run: {run_name}")
|
| 240 |
|
| 241 |
# Choose TPOT class based on problem type
|
|
@@ -432,6 +436,19 @@ def train_tpot_model(df, target_column, run_name,
|
|
| 432 |
# Log the fitted pipeline
|
| 433 |
mlflow.sklearn.log_model(final_pipeline, "model", registered_model_name=f"TPOT_{run_name}")
|
| 434 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
logger.info("TPOT model successfully registered in MLflow")
|
| 436 |
|
| 437 |
return tpot, final_pipeline, run.info.run_id, model_info
|
|
|
|
| 165 |
generations=5, population_size=20, cv=5,
|
| 166 |
scoring=None, max_time_mins=30, max_eval_time_mins=5, random_state=42,
|
| 167 |
verbosity=2, n_jobs=-1, config_dict='TPOT sparse',
|
| 168 |
+
tfidf_max_features=500, tfidf_ngram_range=(1, 2),
|
| 169 |
+
stop_event=None):
|
| 170 |
"""
|
| 171 |
Train TPOT model with MLflow tracking
|
| 172 |
"""
|
|
|
|
| 233 |
scoring = 'neg_mean_squared_error'
|
| 234 |
|
| 235 |
# Ensure there are no loose active runs that could cause errors on start
|
| 236 |
+
try:
|
| 237 |
+
while mlflow.active_run():
|
| 238 |
+
mlflow.end_run()
|
| 239 |
+
except:
|
| 240 |
+
pass
|
| 241 |
|
| 242 |
+
with mlflow.start_run(run_name=run_name, nested=True) as run:
|
| 243 |
logger.info(f"Starting TPOT training for run: {run_name}")
|
| 244 |
|
| 245 |
# Choose TPOT class based on problem type
|
|
|
|
| 436 |
# Log the fitted pipeline
|
| 437 |
mlflow.sklearn.log_model(final_pipeline, "model", registered_model_name=f"TPOT_{run_name}")
|
| 438 |
|
| 439 |
+
# Generate and log consumption code sample
|
| 440 |
+
try:
|
| 441 |
+
from src.code_gen_utils import generate_consumption_code
|
| 442 |
+
code_sample = generate_consumption_code("tpot", run.info.run_id, target_column)
|
| 443 |
+
code_path = "consumption_sample.py"
|
| 444 |
+
with open(code_path, "w") as f:
|
| 445 |
+
f.write(code_sample)
|
| 446 |
+
mlflow.log_artifact(code_path)
|
| 447 |
+
if os.path.exists(code_path):
|
| 448 |
+
os.remove(code_path)
|
| 449 |
+
except Exception as e:
|
| 450 |
+
logger.warning(f"Failed to generate consumption code: {e}")
|
| 451 |
+
|
| 452 |
logger.info("TPOT model successfully registered in MLflow")
|
| 453 |
|
| 454 |
return tpot, final_pipeline, run.info.run_id, model_info
|
src/training_worker.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
training_worker: thread entry point for every AutoML run.
|
| 3 |
+
Captures stdout/stderr, feeds log_queue, puts result into result_queue,
|
| 4 |
+
and respects the stop_event for graceful cancellation.
|
| 5 |
+
|
| 6 |
+
Log isolation strategy (definitive):
|
| 7 |
+
- We attach a _QueueLogHandler to each relevant named library logger.
|
| 8 |
+
- Each handler has a _ThreadFilter that only accepts log records whose
|
| 9 |
+
record.thread matches the experiment thread's ID.
|
| 10 |
+
- This means messages from Thread A never land in Thread B's queue,
|
| 11 |
+
even though they share the same named logger objects.
|
| 12 |
+
- propagate is set to False to prevent double-delivery via the root logger.
|
| 13 |
+
- All are restored in the finally block.
|
| 14 |
+
|
| 15 |
+
Stdout/Stderr isolation:
|
| 16 |
+
- redirect_stdout/redirect_stderr are process-global (they overwrite sys.stdout).
|
| 17 |
+
- We use a _ThreadAwareIO wrapper instead: it checks threading.current_thread()
|
| 18 |
+
on every write() call, so writes only reach the owning thread's queue.
|
| 19 |
+
"""
|
| 20 |
+
import io
|
| 21 |
+
import sys
|
| 22 |
+
import logging
|
| 23 |
+
import threading
|
| 24 |
+
import traceback
|
| 25 |
+
|
| 26 |
+
from src.experiment_manager import ExperimentEntry
|
| 27 |
+
|
| 28 |
+
_LIB_LOGGERS = [
|
| 29 |
+
'flaml', 'autogluon', 'mlflow', 'h2o', 'tpot',
|
| 30 |
+
'pycaret', 'lale', 'hyperopt', 'lightgbm', 'xgboost', 'catboost'
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Thread-aware stdout/stderr router (installed once, process-wide)
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
class _ThreadAwareIO(io.TextIOBase):
|
| 38 |
+
"""
|
| 39 |
+
Drop-in replacement for sys.stdout / sys.stderr that routes each write()
|
| 40 |
+
to the queue registered for the current thread, or falls back to the
|
| 41 |
+
original stream.
|
| 42 |
+
"""
|
| 43 |
+
def __init__(self, original_stream):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self._original = original_stream
|
| 46 |
+
self._lock = threading.Lock()
|
| 47 |
+
self._thread_queues: dict[int, "queue.Queue"] = {}
|
| 48 |
+
|
| 49 |
+
def register(self, thread_id: int, q):
|
| 50 |
+
with self._lock:
|
| 51 |
+
self._thread_queues[thread_id] = q
|
| 52 |
+
|
| 53 |
+
def unregister(self, thread_id: int):
|
| 54 |
+
with self._lock:
|
| 55 |
+
self._thread_queues.pop(thread_id, None)
|
| 56 |
+
|
| 57 |
+
def write(self, s: str) -> int:
|
| 58 |
+
if not isinstance(s, str):
|
| 59 |
+
try:
|
| 60 |
+
s = str(s)
|
| 61 |
+
except Exception:
|
| 62 |
+
return 0
|
| 63 |
+
tid = threading.current_thread().ident
|
| 64 |
+
with self._lock:
|
| 65 |
+
q = self._thread_queues.get(tid)
|
| 66 |
+
if q is not None:
|
| 67 |
+
if s.strip():
|
| 68 |
+
# Filter out progress bar characters that fail on Windows cp1252
|
| 69 |
+
# \u2588 is the full block character
|
| 70 |
+
safe_s = s.replace('\u2588', '#').replace('\u258c', '|').replace('\u2584', '-')
|
| 71 |
+
q.put(safe_s.strip())
|
| 72 |
+
else:
|
| 73 |
+
# Fall back to original stream for threads not registered
|
| 74 |
+
try:
|
| 75 |
+
self._original.write(s)
|
| 76 |
+
except Exception:
|
| 77 |
+
pass
|
| 78 |
+
return len(s)
|
| 79 |
+
|
| 80 |
+
def flush(self):
|
| 81 |
+
try:
|
| 82 |
+
self._original.flush()
|
| 83 |
+
except Exception:
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def encoding(self):
|
| 88 |
+
return getattr(self._original, 'encoding', 'utf-8') or 'utf-8'
|
| 89 |
+
|
| 90 |
+
@property
|
| 91 |
+
def errors(self):
|
| 92 |
+
return getattr(self._original, 'errors', 'replace')
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Install thread-aware routers once for the entire process
|
| 96 |
+
_stdout_router = _ThreadAwareIO(sys.__stdout__)
|
| 97 |
+
_stderr_router = _ThreadAwareIO(sys.__stderr__)
|
| 98 |
+
sys.stdout = _stdout_router
|
| 99 |
+
sys.stderr = _stderr_router
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ---------------------------------------------------------------------------
|
| 103 |
+
# Per-thread log handler with thread filter
|
| 104 |
+
# ---------------------------------------------------------------------------
|
| 105 |
+
class _ThreadFilter(logging.Filter):
|
| 106 |
+
"""Only accepts log records emitted by a specific OS thread."""
|
| 107 |
+
def __init__(self, thread_id: int):
|
| 108 |
+
super().__init__()
|
| 109 |
+
self._thread_id = thread_id
|
| 110 |
+
|
| 111 |
+
def filter(self, record: logging.LogRecord) -> bool:
|
| 112 |
+
return record.thread == self._thread_id
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class _QueueLogHandler(logging.Handler):
|
| 116 |
+
def __init__(self, log_queue):
|
| 117 |
+
super().__init__()
|
| 118 |
+
self.log_queue = log_queue
|
| 119 |
+
|
| 120 |
+
def emit(self, record: logging.LogRecord):
|
| 121 |
+
try:
|
| 122 |
+
msg = self.format(record)
|
| 123 |
+
self.log_queue.put(msg)
|
| 124 |
+
except Exception:
|
| 125 |
+
pass
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ---------------------------------------------------------------------------
|
| 129 |
+
# Worker entry point
|
| 130 |
+
# ---------------------------------------------------------------------------
|
| 131 |
+
def run_training_worker(entry: ExperimentEntry, train_fn, kwargs: dict):
|
| 132 |
+
"""
|
| 133 |
+
Thread target. Runs train_fn(**kwargs, stop_event=entry.stop_event),
|
| 134 |
+
keeps the entry status updated, and puts the final result dict in
|
| 135 |
+
result_queue.
|
| 136 |
+
"""
|
| 137 |
+
thread_id = threading.current_thread().ident
|
| 138 |
+
|
| 139 |
+
# --- Thread-aware stdout/stderr capture ---
|
| 140 |
+
_stdout_router.register(thread_id, entry.log_queue)
|
| 141 |
+
_stderr_router.register(thread_id, entry.log_queue)
|
| 142 |
+
|
| 143 |
+
# --- Per-thread logging handler (with thread filter) ---
|
| 144 |
+
handler = _QueueLogHandler(entry.log_queue)
|
| 145 |
+
handler.setFormatter(logging.Formatter('%(message)s'))
|
| 146 |
+
handler.addFilter(_ThreadFilter(thread_id))
|
| 147 |
+
|
| 148 |
+
saved_propagate: dict[str, bool] = {}
|
| 149 |
+
for lib in _LIB_LOGGERS:
|
| 150 |
+
lib_logger = logging.getLogger(lib)
|
| 151 |
+
saved_propagate[lib] = lib_logger.propagate
|
| 152 |
+
lib_logger.propagate = False # prevents root from seeing AND double-deliver
|
| 153 |
+
lib_logger.addHandler(handler)
|
| 154 |
+
if lib_logger.level == logging.NOTSET or lib_logger.level > logging.INFO:
|
| 155 |
+
lib_logger.setLevel(logging.INFO)
|
| 156 |
+
|
| 157 |
+
entry.status = "running"
|
| 158 |
+
entry.log_queue.put(f"[Worker] Starting training: {entry.metadata.get('run_name', entry.key)}")
|
| 159 |
+
|
| 160 |
+
try:
|
| 161 |
+
# Inject stop_event and telemetry_queue into kwargs if the function accepts it
|
| 162 |
+
try:
|
| 163 |
+
import inspect
|
| 164 |
+
sig = inspect.signature(train_fn)
|
| 165 |
+
if 'stop_event' in sig.parameters:
|
| 166 |
+
kwargs['stop_event'] = entry.stop_event
|
| 167 |
+
if 'telemetry_queue' in sig.parameters:
|
| 168 |
+
kwargs['telemetry_queue'] = entry.telemetry_queue
|
| 169 |
+
except Exception:
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
result = train_fn(**kwargs)
|
| 173 |
+
|
| 174 |
+
# Normalise result into a standard dict
|
| 175 |
+
if isinstance(result, tuple):
|
| 176 |
+
if len(result) == 2:
|
| 177 |
+
predictor, run_id = result
|
| 178 |
+
entry.result_queue.put({
|
| 179 |
+
"success": True, "predictor": predictor, "run_id": run_id,
|
| 180 |
+
"type": entry.metadata.get("framework_key", "unknown")
|
| 181 |
+
})
|
| 182 |
+
elif len(result) == 4:
|
| 183 |
+
tpot, pipeline, run_id, info = result
|
| 184 |
+
entry.result_queue.put({
|
| 185 |
+
"success": True, "predictor": pipeline, "run_id": run_id, "info": info, "type": "tpot"
|
| 186 |
+
})
|
| 187 |
+
else:
|
| 188 |
+
entry.result_queue.put({
|
| 189 |
+
"success": True, "predictor": result[0], "run_id": result[-1],
|
| 190 |
+
"type": entry.metadata.get("framework_key", "unknown")
|
| 191 |
+
})
|
| 192 |
+
elif isinstance(result, dict):
|
| 193 |
+
entry.result_queue.put(result)
|
| 194 |
+
else:
|
| 195 |
+
entry.result_queue.put({
|
| 196 |
+
"success": True, "predictor": result, "run_id": None,
|
| 197 |
+
"type": entry.metadata.get("framework_key", "unknown")
|
| 198 |
+
})
|
| 199 |
+
|
| 200 |
+
except StopIteration:
|
| 201 |
+
entry.log_queue.put("[Worker] Training cancelled by user request.")
|
| 202 |
+
entry.result_queue.put({"success": False, "cancelled": True, "error": "Cancelled by user"})
|
| 203 |
+
except Exception as e:
|
| 204 |
+
err_tb = traceback.format_exc()
|
| 205 |
+
entry.log_queue.put(f"[Worker] CRITICAL ERROR: {e}\n{err_tb}")
|
| 206 |
+
entry.result_queue.put({"success": False, "error": str(e), "traceback": err_tb})
|
| 207 |
+
finally:
|
| 208 |
+
# Restore all lib loggers
|
| 209 |
+
for lib in _LIB_LOGGERS:
|
| 210 |
+
lib_logger = logging.getLogger(lib)
|
| 211 |
+
lib_logger.removeHandler(handler)
|
| 212 |
+
lib_logger.propagate = saved_propagate.get(lib, True)
|
| 213 |
+
|
| 214 |
+
# Unregister stdout/stderr routing for this thread
|
| 215 |
+
_stdout_router.unregister(thread_id)
|
| 216 |
+
_stderr_router.unregister(thread_id)
|
| 217 |
+
|
| 218 |
+
entry.log_queue.put("[Worker] Thread finished.")
|
src/xai_utils.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import io
|
| 3 |
+
import cv2
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
def generate_shap_explanation(model, X_train: pd.DataFrame, X_valid: pd.DataFrame = None,
|
| 10 |
+
max_background_samples=100, task_type="Classification"):
|
| 11 |
+
"""
|
| 12 |
+
Generates SHAP Global Feature Importance plot for Tabular data.
|
| 13 |
+
"""
|
| 14 |
+
try:
|
| 15 |
+
import shap
|
| 16 |
+
except ImportError:
|
| 17 |
+
warnings.warn("SHAP library not installed. Cannot generate explanations.")
|
| 18 |
+
return None
|
| 19 |
+
|
| 20 |
+
plt.switch_backend('Agg') # Ensure thread-safe rendering without GUI
|
| 21 |
+
|
| 22 |
+
# 1. Determine background dataset (handle large data gracefully)
|
| 23 |
+
bg_data = X_train
|
| 24 |
+
if len(bg_data) > max_background_samples:
|
| 25 |
+
bg_data = bg_data.sample(n=max_background_samples, random_state=42)
|
| 26 |
+
|
| 27 |
+
evaluate_data = X_valid if X_valid is not None else bg_data
|
| 28 |
+
if len(evaluate_data) > max_background_samples:
|
| 29 |
+
evaluate_data = evaluate_data.sample(n=max_background_samples, random_state=42)
|
| 30 |
+
|
| 31 |
+
# Convert non-numeric for generic shap handling if required by models
|
| 32 |
+
# Depending on framework, categorical columns might need Ordinal/OneHot.
|
| 33 |
+
# For robust black-box generic explainer:
|
| 34 |
+
|
| 35 |
+
explainer = None
|
| 36 |
+
shap_values = None
|
| 37 |
+
|
| 38 |
+
# 2. Heuristics to pick the right explainer
|
| 39 |
+
model_type = str(type(model)).lower()
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
if 'lgbm' in model_type or 'xgb' in model_type or 'catboost' in model_type or 'ensemble' in model_type:
|
| 43 |
+
# TreeExplainer is fast for tree-based models and forests
|
| 44 |
+
try:
|
| 45 |
+
explainer = shap.TreeExplainer(model)
|
| 46 |
+
shap_values = explainer.shap_values(evaluate_data)
|
| 47 |
+
except Exception:
|
| 48 |
+
pass # Fallback to generic
|
| 49 |
+
|
| 50 |
+
if explainer is None:
|
| 51 |
+
# For complex pipelines (like sklearn pipelines, PyCaret, generic wrappers)
|
| 52 |
+
# Use KernelExplainer as a Black-Box proxy (requires a predict function)
|
| 53 |
+
|
| 54 |
+
predict_fn = None
|
| 55 |
+
if hasattr(model, "predict_proba") and "classification" in task_type.lower():
|
| 56 |
+
predict_fn = lambda x: model.predict_proba(x)
|
| 57 |
+
elif hasattr(model, "predict"):
|
| 58 |
+
predict_fn = lambda x: model.predict(x)
|
| 59 |
+
else:
|
| 60 |
+
return None # Can't explain
|
| 61 |
+
|
| 62 |
+
# KernelExplainer can be slow, hence the small bg_data
|
| 63 |
+
explainer = shap.KernelExplainer(predict_fn, bg_data)
|
| 64 |
+
shap_values = explainer.shap_values(evaluate_data)
|
| 65 |
+
|
| 66 |
+
except Exception as e:
|
| 67 |
+
warnings.warn(f"SHAP generation failed: {e}")
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
# 3. Generate the Plot
|
| 71 |
+
fig = plt.figure(figsize=(10, 6))
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
# For multi-class, shap_values is a list. For regression/binary, it's an array.
|
| 75 |
+
if isinstance(shap_values, list):
|
| 76 |
+
# Take the shap values for the first class/positive class for overview
|
| 77 |
+
shap.summary_plot(shap_values[1] if len(shap_values)>1 else shap_values[0], evaluate_data, show=False)
|
| 78 |
+
else:
|
| 79 |
+
shap.summary_plot(shap_values, evaluate_data, show=False)
|
| 80 |
+
|
| 81 |
+
plt.tight_layout()
|
| 82 |
+
return fig
|
| 83 |
+
except Exception as e:
|
| 84 |
+
warnings.warn(f"SHAP plot rendering failed: {e}")
|
| 85 |
+
plt.close(fig)
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
def generate_cv_saliency_map(model, image_path: str, target_size=(224, 224), step=15, window_size=30):
|
| 89 |
+
"""
|
| 90 |
+
Universal Occlusion Saliency Map for Black-Box CV Models (AutoGluon/AutoKeras).
|
| 91 |
+
Instead of relying on internal hooks (which heavily abstracted AutoML layers hide),
|
| 92 |
+
we slide a black box ('occlusion') across the image and measure the confidence drop.
|
| 93 |
+
The regions that drop the confidence the most are the most salient (important) for the prediction.
|
| 94 |
+
"""
|
| 95 |
+
try:
|
| 96 |
+
from PIL import Image
|
| 97 |
+
import cv2
|
| 98 |
+
except ImportError:
|
| 99 |
+
warnings.warn("Missing CV libraries (Pillow/OpenCV) for Saliency representation.")
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
try:
|
| 103 |
+
# 1. Load Original Image
|
| 104 |
+
original_img = Image.open(image_path).convert('RGB')
|
| 105 |
+
img_w, img_h = original_img.size
|
| 106 |
+
|
| 107 |
+
# Determine the baseline prediction to see what class we are explaining
|
| 108 |
+
# Since this is a generic AutoML predictor UI, we assume `model.predict_proba` gives a df or dict
|
| 109 |
+
df_single = pd.DataFrame([{"image": image_path}])
|
| 110 |
+
|
| 111 |
+
# Get base probabilities.
|
| 112 |
+
# Note: Depending on AutoGluon/AutoKeras formatting, the predict_proba method might vary.
|
| 113 |
+
if hasattr(model, 'predict_proba'):
|
| 114 |
+
base_probs = model.predict_proba(df_single)
|
| 115 |
+
if isinstance(base_probs, pd.DataFrame):
|
| 116 |
+
# Assuming top class
|
| 117 |
+
top_class = base_probs.iloc[0].idxmax()
|
| 118 |
+
base_score = base_probs.iloc[0][top_class]
|
| 119 |
+
else:
|
| 120 |
+
top_class = np.argmax(base_probs[0])
|
| 121 |
+
base_score = base_probs[0][top_class]
|
| 122 |
+
else:
|
| 123 |
+
warnings.warn("Model does not support predict_proba, Saliency Map cannot track confidence drops.")
|
| 124 |
+
return None
|
| 125 |
+
|
| 126 |
+
# 2. Build Saliency Map Array
|
| 127 |
+
saliency_map = np.zeros((img_h, img_w))
|
| 128 |
+
heatmap_counts = np.zeros((img_h, img_w))
|
| 129 |
+
|
| 130 |
+
# We will create occluded images, save them temporarily, and batch-predict to find drops
|
| 131 |
+
# For performance, we downsize the grid if the image is huge
|
| 132 |
+
grid_step = step
|
| 133 |
+
w_size = window_size
|
| 134 |
+
|
| 135 |
+
# To avoid predicting 1000s of images, let's limit the grid
|
| 136 |
+
if (img_h / step) * (img_w / step) > 200:
|
| 137 |
+
grid_step = max(int(img_h/10), 10)
|
| 138 |
+
w_size = int(grid_step * 1.5)
|
| 139 |
+
|
| 140 |
+
occluded_paths = []
|
| 141 |
+
coords = []
|
| 142 |
+
|
| 143 |
+
tmp_dir = os.path.join("data_lake", "tmp_occlusion")
|
| 144 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
| 145 |
+
img_arr_orig = np.array(original_img)
|
| 146 |
+
|
| 147 |
+
# Generate Occluded Copies
|
| 148 |
+
for y in range(0, img_h, grid_step):
|
| 149 |
+
for x in range(0, img_w, grid_step):
|
| 150 |
+
img_copy = img_arr_orig.copy()
|
| 151 |
+
|
| 152 |
+
# Apply black box
|
| 153 |
+
y1, y2 = max(0, y - w_size // 2), min(img_h, y + w_size // 2)
|
| 154 |
+
x1, x2 = max(0, x - w_size // 2), min(img_w, x + w_size // 2)
|
| 155 |
+
img_copy[y1:y2, x1:x2] = 0 # Occlude
|
| 156 |
+
|
| 157 |
+
t_path = os.path.join(tmp_dir, f"occ_{y}_{x}.jpg")
|
| 158 |
+
Image.fromarray(img_copy).save(t_path)
|
| 159 |
+
occluded_paths.append(t_path)
|
| 160 |
+
coords.append((y1, y2, x1, x2))
|
| 161 |
+
|
| 162 |
+
# Predict all simultaneously
|
| 163 |
+
df_batch = pd.DataFrame({"image": occluded_paths})
|
| 164 |
+
|
| 165 |
+
try:
|
| 166 |
+
batch_probs = model.predict_proba(df_batch)
|
| 167 |
+
except Exception:
|
| 168 |
+
warnings.warn("Batch probability prediction failed for occlusion map.")
|
| 169 |
+
return None
|
| 170 |
+
|
| 171 |
+
# Parse scores based on framework signature
|
| 172 |
+
if isinstance(batch_probs, pd.DataFrame):
|
| 173 |
+
scores = batch_probs[top_class].values
|
| 174 |
+
else:
|
| 175 |
+
scores = batch_probs[:, top_class] if len(batch_probs.shape) > 1 else batch_probs
|
| 176 |
+
|
| 177 |
+
# 3. Calculate importance based on score drops
|
| 178 |
+
for idx, (y1, y2, x1, x2) in enumerate(coords):
|
| 179 |
+
drop = base_score - scores[idx]
|
| 180 |
+
# If the score dropped, this region was important
|
| 181 |
+
importance = max(0, drop)
|
| 182 |
+
saliency_map[y1:y2, x1:x2] += importance
|
| 183 |
+
heatmap_counts[y1:y2, x1:x2] += 1
|
| 184 |
+
|
| 185 |
+
# Average overlaps
|
| 186 |
+
heatmap_counts[heatmap_counts == 0] = 1
|
| 187 |
+
saliency_avg = saliency_map / heatmap_counts
|
| 188 |
+
|
| 189 |
+
# Normalize 0-255
|
| 190 |
+
if np.max(saliency_avg) > 0:
|
| 191 |
+
saliency_avg = (saliency_avg / np.max(saliency_avg)) * 255
|
| 192 |
+
saliency_avg = np.uint8(saliency_avg)
|
| 193 |
+
|
| 194 |
+
# 4. Generate visual overlay
|
| 195 |
+
colormap = cv2.applyColorMap(saliency_avg, cv2.COLORMAP_JET)
|
| 196 |
+
|
| 197 |
+
orig_cv = cv2.cvtColor(np.array(original_img), cv2.COLORRGB_BGR) # To match cv2
|
| 198 |
+
final_overlay = cv2.addWeighted(orig_cv, 0.6, colormap, 0.4, 0)
|
| 199 |
+
final_rgb = cv2.cvtColor(final_overlay, cv2.COLORBGR_RGB)
|
| 200 |
+
|
| 201 |
+
# Cleanup
|
| 202 |
+
for p in occluded_paths:
|
| 203 |
+
try: os.remove(p)
|
| 204 |
+
except: pass
|
| 205 |
+
|
| 206 |
+
fig = plt.figure(figsize=(8, 8))
|
| 207 |
+
plt.imshow(final_rgb)
|
| 208 |
+
plt.title(f"XAI Occlusion Heatmap (Target: {top_class})")
|
| 209 |
+
plt.axis('off')
|
| 210 |
+
plt.tight_layout()
|
| 211 |
+
|
| 212 |
+
return fig
|
| 213 |
+
|
| 214 |
+
except Exception as e:
|
| 215 |
+
warnings.warn(f"CV XAI generation failed: {e}")
|
| 216 |
+
return None
|