Spaces:
Sleeping
Sleeping
| import os | |
| import logging | |
| from typing import Any, Optional, Tuple | |
| logger = logging.getLogger(__name__) | |
| # Cached availability of the ONNX stack: None = not probed yet, True/False = outcome of the import probe in _check_onnx_availability() | |
| ONNX_AVAILABLE = None | |
def _check_onnx_availability():
    """Report whether both ``onnx`` and ``onnxruntime`` can be imported.

    The import probe runs at most once per process: its outcome is stored in
    the module-level ``ONNX_AVAILABLE`` flag and returned directly on every
    later call.

    Returns:
        True when both packages import cleanly, False otherwise.
    """
    global ONNX_AVAILABLE

    # None means "not probed yet"; True/False is a cached answer.
    if ONNX_AVAILABLE is None:
        try:
            import onnx  # noqa: F401 -- imported only to test availability
            import onnxruntime as ort  # noqa: F401
        except Exception as e:
            # Any failure while importing counts as "unavailable", not only
            # ImportError.
            logger.warning(f"ONNX or ONNXRuntime not available: {e}")
            ONNX_AVAILABLE = False
        else:
            ONNX_AVAILABLE = True

    return ONNX_AVAILABLE
def export_to_onnx(
    model: Any,
    model_type: str,
    target_col: str,
    output_path: str,
    input_sample: Optional[Any] = None,
) -> str:
    """
    Exports a trained model to ONNX format.

    Supports: flaml, pycaret, tpot (converted with skl2onnx), autokeras
    (converted with tf2onnx) and autogluon (through the model's own
    ``export_onnx`` method).

    Args:
        model: The trained model object to convert.
        model_type: Identifier selecting the conversion path (see above).
        target_col: Name of the label column. On the scikit-learn path it is
            dropped from ``input_sample`` when that is a DataFrame containing
            it.
        output_path: Destination file for the ONNX model. May be a bare file
            name; a missing parent directory is created.
        input_sample: Example input data. Required for every model type
            except autogluon.

    Returns:
        ``output_path``, once the model has been written.

    Raises:
        ImportError: If onnx / onnxruntime (or the converter package needed
            for the chosen path) cannot be imported.
        ValueError: If ``input_sample`` is missing where it is required, or
            ``model_type`` is not supported.
        NotImplementedError: If an autogluon model raises AttributeError when
            ``export_onnx`` is called on it.
    """
    if not _check_onnx_availability():
        raise ImportError("ONNX or ONNXRuntime is not available in this environment.")

    import onnx
    import pandas as pd

    logger.info(f"Exporting {model_type} model to ONNX: {output_path}")

    # os.path.dirname() returns "" for a bare file name, and os.makedirs("")
    # raises FileNotFoundError -- only create a directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    try:
        if model_type in ("flaml", "pycaret", "tpot"):
            from skl2onnx import to_onnx

            if input_sample is None:
                raise ValueError("input_sample is required for scikit-learn based ONNX export")

            # The converter must see features only, not the label column.
            if isinstance(input_sample, pd.DataFrame) and target_col in input_sample.columns:
                input_sample = input_sample.drop(columns=[target_col])

            # Only the first row is handed to the converter as sample input.
            onx = to_onnx(model, input_sample[:1], initial_types=None)
            with open(output_path, "wb") as f:
                f.write(onx.SerializeToString())

        elif model_type == "autokeras":
            import tf2onnx
            import tensorflow as tf

            if input_sample is None:
                raise ValueError("input_sample is required for TensorFlow/AutoKeras ONNX export")

            # Leading dimension is left as None; the remaining dimensions are
            # copied from the sample and the dtype is fixed to float32.
            input_signature = [
                tf.TensorSpec([None] + list(input_sample.shape[1:]), tf.float32, name='input')
            ]
            onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=13)
            onnx.save_model(onnx_model, output_path)

        elif model_type == "autogluon":
            try:
                model.export_onnx(output_path)
            except AttributeError as exc:
                logger.warning("AutoGluon model does not support direct export_onnx.")
                raise NotImplementedError(
                    "AutoGluon ONNX export fallback not implemented."
                ) from exc

        else:
            raise ValueError(f"Unsupported model type for ONNX export: {model_type}")

        logger.info(f"Successfully exported model to {output_path}")
        return output_path

    except Exception as e:
        # Log every failure once at this boundary, then let the caller see
        # the original exception.
        logger.error(f"Failed to export {model_type} model to ONNX: {e}")
        raise
def load_onnx_session(onnx_path: str):
    """Create an onnxruntime inference session for the model at ``onnx_path``.

    Args:
        onnx_path: Path of the ONNX model file to load.

    Returns:
        An ``onnxruntime.InferenceSession`` built from the file.

    Raises:
        ImportError: If the ONNX availability check fails.
        FileNotFoundError: If ``onnx_path`` does not exist.
    """
    runtime_ready = _check_onnx_availability()
    if not runtime_ready:
        raise ImportError("ONNXRuntime is not available.")

    # Imported lazily so the module itself loads without onnxruntime.
    import onnxruntime as ort

    if os.path.exists(onnx_path):
        return ort.InferenceSession(onnx_path)

    raise FileNotFoundError(f"ONNX file not found: {onnx_path}")
# Maps the tensor type string reported by a session input node to the NumPy
# dtype name its feed array must have.
_ONNX_TENSOR_DTYPES = {
    "tensor(float16)": "float16",
    "tensor(float)": "float32",
    "tensor(double)": "float64",
    "tensor(int8)": "int8",
    "tensor(int16)": "int16",
    "tensor(int32)": "int32",
    "tensor(int64)": "int64",
    "tensor(uint8)": "uint8",
    "tensor(uint16)": "uint16",
    "tensor(uint32)": "uint32",
    "tensor(uint64)": "uint64",
    "tensor(bool)": "bool",
    "tensor(string)": "str",
}


def predict_onnx(session: Any, df: Any) -> Any:
    """Runs inference on a DataFrame using an ONNX session.

    The input feed is built from the session's declared inputs:

    * an input whose name matches a column of ``df`` receives that column as
      an ``(n_rows, 1)`` array;
    * otherwise, if the session has exactly one input, it receives the whole
      frame as a single array;
    * otherwise the input is left out of the feed and ``session.run`` decides
      whether that is an error.

    Each array is cast to the dtype matching the input's declared tensor type
    (``node.type``). Inputs with an unknown or missing type are cast to
    float32.

    Args:
        session: Object exposing ``get_inputs()`` and ``run()``, such as the
            session returned by ``load_onnx_session``.
        df: DataFrame holding the feature values.

    Returns:
        The first output returned by ``session.run``.
    """
    input_nodes = session.get_inputs()
    single_input = len(input_nodes) == 1

    inputs = {}
    for node in input_nodes:
        name = node.name
        if name in df.columns:
            values = df[[name]].values
        elif single_input:
            values = df.values
        else:
            continue

        # Casting everything to float32 makes the runtime reject inputs
        # declared as double / int64 / string, so honour the declared type.
        dtype = _ONNX_TENSOR_DTYPES.get(getattr(node, "type", None), "float32")
        inputs[name] = values.astype(dtype)

    outputs = session.run(None, inputs)
    return outputs[0]