# ground-zero/src/optimization/tflite_converter.py
# NOTE(review): the lines below were web-scrape residue (author/commit metadata),
# preserved here as a comment so the module parses:
# jefffffff9 — Initial commit: Sahel-Agri Voice AI (76db545)
"""
Converts ONNX models to TFLite for offline edge deployment (Android phones in rural areas).
Note: Whisper's encoder and decoder are exported as separate TFLite models and
orchestrated together at inference time.
Requires: onnx-tf, tensorflow (install separately — large dependencies)
"""
from __future__ import annotations
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
class TFLiteConverter:
    """Turns exported ONNX Whisper models into TFLite artifacts for on-device use."""

    def convert(
        self,
        onnx_encoder_path: str,
        onnx_decoder_path: str,
        output_dir: str,
        quantize: bool = True,
    ) -> dict[str, Path]:
        """
        Run the ONNX → TFLite conversion for both halves of the model.

        Args:
            onnx_encoder_path: Path to the encoder's ONNX file.
            onnx_decoder_path: Path to the decoder's ONNX file.
            output_dir: Directory that receives the .tflite files
                (created, parents included, if absent).
            quantize: When True, apply default TFLite optimizations
                (dynamic-range quantization) to both models.

        Returns:
            Mapping with keys "encoder" and "decoder" pointing at the
            generated .tflite files.
        """
        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)

        encoder_out = target_dir / "encoder.tflite"
        logger.info("Converting encoder ONNX → TFLite...")
        self._onnx_to_tflite(onnx_encoder_path, str(encoder_out), quantize=quantize)

        decoder_out = target_dir / "decoder.tflite"
        logger.info("Converting decoder ONNX → TFLite...")
        self._onnx_to_tflite(onnx_decoder_path, str(decoder_out), quantize=quantize)

        return {"encoder": encoder_out, "decoder": decoder_out}

    def _onnx_to_tflite(self, onnx_path: str, output_path: str, quantize: bool) -> None:
        """
        Convert one ONNX graph to a TFLite flatbuffer on disk.

        Pipeline: ONNX file → TF SavedModel (via onnx-tf, staged in a
        temporary directory) → TFLite bytes written to ``output_path``.

        Raises:
            ImportError: If onnx-tf / tensorflow are not installed; these
                are deliberately imported lazily because they are large.
        """
        try:
            import onnx
            import onnx_tf
            import tensorflow as tf
        except ImportError as e:
            raise ImportError(
                "TFLite conversion requires onnx-tf and tensorflow. "
                "Install with: pip install onnx-tf tensorflow"
            ) from e
        import tempfile

        with tempfile.TemporaryDirectory() as saved_model_dir:
            # Stage 1: materialize the ONNX graph as a TF SavedModel.
            onnx_tf.backend.prepare(onnx.load(onnx_path)).export_graph(saved_model_dir)

            # Stage 2: compile the SavedModel down to a TFLite flatbuffer.
            converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
            if quantize:
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
            tflite_bytes = converter.convert()

        with open(output_path, "wb") as out_file:
            out_file.write(tflite_bytes)

        size_mb = Path(output_path).stat().st_size / 1e6
        logger.info("TFLite model saved: %s (%.1f MB)", output_path, size_mb)