""" Converts ONNX models to TFLite for offline edge deployment (Android phones in rural areas). Note: Whisper's encoder and decoder are exported as separate TFLite models and orchestrated together at inference time. Requires: onnx-tf, tensorflow (install separately — large dependencies) """ from __future__ import annotations import logging from pathlib import Path logger = logging.getLogger(__name__) class TFLiteConverter: """Converts ONNX Whisper models to TFLite format for edge deployment.""" def convert( self, onnx_encoder_path: str, onnx_decoder_path: str, output_dir: str, quantize: bool = True, ) -> dict[str, Path]: """ Convert encoder and decoder ONNX models to TFLite. Returns paths to the generated .tflite files. """ output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) encoder_tflite = output_path / "encoder.tflite" decoder_tflite = output_path / "decoder.tflite" logger.info("Converting encoder ONNX → TFLite...") self._onnx_to_tflite(onnx_encoder_path, str(encoder_tflite), quantize=quantize) logger.info("Converting decoder ONNX → TFLite...") self._onnx_to_tflite(onnx_decoder_path, str(decoder_tflite), quantize=quantize) return {"encoder": encoder_tflite, "decoder": decoder_tflite} def _onnx_to_tflite(self, onnx_path: str, output_path: str, quantize: bool) -> None: """Convert a single ONNX model to TFLite via onnx-tf + tensorflow.""" try: import onnx import onnx_tf import tensorflow as tf except ImportError as e: raise ImportError( "TFLite conversion requires onnx-tf and tensorflow. " "Install with: pip install onnx-tf tensorflow" ) from e import tempfile # Step 1: ONNX → TensorFlow SavedModel with tempfile.TemporaryDirectory() as tmp_dir: onnx_model = onnx.load(onnx_path) tf_rep = onnx_tf.backend.prepare(onnx_model) tf_rep.export_graph(tmp_dir) # Step 2: TF SavedModel → TFLite converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir) if quantize: converter.optimizations = [tf.lite.Optimize.DEFAULT] tflite_model = converter.convert() with open(output_path, "wb") as f: f.write(tflite_model) size_mb = Path(output_path).stat().st_size / 1e6 logger.info("TFLite model saved: %s (%.1f MB)", output_path, size_mb)