"""
Converts ONNX models to TFLite for offline edge deployment (Android phones in rural areas).
Note: Whisper's encoder and decoder are exported as separate TFLite models and
orchestrated together at inference time.
Requires: onnx-tf, tensorflow (install separately — large dependencies)
"""
from __future__ import annotations
import logging
from pathlib import Path
logger = logging.getLogger(__name__)
class TFLiteConverter:
    """Converts ONNX Whisper models to TFLite format for edge deployment.

    The encoder and decoder are converted independently. Each goes
    ONNX -> TensorFlow SavedModel (via onnx-tf) -> TFLite (via
    ``tf.lite.TFLiteConverter``).
    """

    def convert(
        self,
        onnx_encoder_path: str | Path,
        onnx_decoder_path: str | Path,
        output_dir: str | Path,
        quantize: bool = True,
    ) -> dict[str, Path]:
        """
        Convert encoder and decoder ONNX models to TFLite.

        Args:
            onnx_encoder_path: Path to the encoder ``.onnx`` file.
            onnx_decoder_path: Path to the decoder ``.onnx`` file.
            output_dir: Directory for the generated files; created (with
                parents) if it does not exist.
            quantize: If True, enable ``tf.lite.Optimize.DEFAULT`` for both
                models.

        Returns:
            ``{"encoder": <output_dir>/encoder.tflite,
            "decoder": <output_dir>/decoder.tflite}``.

        Raises:
            FileNotFoundError: If either ONNX input is not an existing file.
                Both inputs are checked before any work starts, so a missing
                decoder is reported before the encoder conversion runs.
            ImportError: If onnx / onnx-tf / tensorflow are not installed.
        """
        # Fail fast: validate both inputs before creating directories or
        # spending time converting the encoder.
        for role, model_path in (
            ("encoder", onnx_encoder_path),
            ("decoder", onnx_decoder_path),
        ):
            if not Path(model_path).is_file():
                raise FileNotFoundError(f"ONNX {role} model not found: {model_path}")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        encoder_tflite = output_path / "encoder.tflite"
        decoder_tflite = output_path / "decoder.tflite"

        logger.info("Converting encoder ONNX → TFLite...")
        self._onnx_to_tflite(str(onnx_encoder_path), str(encoder_tflite), quantize=quantize)
        logger.info("Converting decoder ONNX → TFLite...")
        self._onnx_to_tflite(str(onnx_decoder_path), str(decoder_tflite), quantize=quantize)

        return {"encoder": encoder_tflite, "decoder": decoder_tflite}

    def _onnx_to_tflite(self, onnx_path: str, output_path: str, quantize: bool) -> None:
        """Convert a single ONNX model to TFLite via onnx-tf + tensorflow.

        Args:
            onnx_path: Source ``.onnx`` file.
            output_path: Destination ``.tflite`` file (overwritten if present).
            quantize: If True, set ``converter.optimizations`` to
                ``[tf.lite.Optimize.DEFAULT]``.

        Raises:
            ImportError: If onnx, onnx-tf, or tensorflow cannot be imported.
        """
        try:
            import onnx
            # Import the backend submodule explicitly instead of relying on
            # ``import onnx_tf`` to expose ``onnx_tf.backend`` as an attribute.
            from onnx_tf.backend import prepare
            import tensorflow as tf
        except ImportError as e:
            raise ImportError(
                "TFLite conversion requires onnx-tf and tensorflow. "
                "Install with: pip install onnx-tf tensorflow"
            ) from e
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Step 1: ONNX -> TensorFlow SavedModel, written into tmp_dir.
            onnx_model = onnx.load(onnx_path)
            tf_rep = prepare(onnx_model)
            tf_rep.export_graph(tmp_dir)

            # Step 2: TF SavedModel -> TFLite. This must run inside the `with`
            # block: tmp_dir (and the SavedModel in it) is deleted on exit.
            converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
            if quantize:
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
            tflite_model = converter.convert()

        # converter.convert() returns the serialized model as bytes.
        Path(output_path).write_bytes(tflite_model)
        size_mb = len(tflite_model) / 1e6
        logger.info("TFLite model saved: %s (%.1f MB)", output_path, size_mb)