# ground-zero/src/optimization/onnx_exporter.py
# jefffffff9
# Initial commit: Sahel-Agri Voice AI
# 76db545
"""
Merges LoRA adapter weights into the backbone and exports to ONNX.
Produces one ONNX file per language (ONNX cannot hot-swap adapters at runtime).
Requires: optimum[onnxruntime]
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from peft import PeftModel
from transformers import WhisperProcessor
logger = logging.getLogger(__name__)


class ONNXExporter:
    """Merges a LoRA PeftModel into its base model and exports it to ONNX.

    Each call to :meth:`merge_and_export` writes one export directory,
    ``<output_dir>/<language>``. :meth:`validate` loads an exported directory
    with onnxruntime and transcribes test audio, optionally scoring WER.
    """

    def merge_and_export(
        self,
        peft_model: "PeftModel",
        processor: "WhisperProcessor",
        output_dir: str,
        language: str,
        *,
        opset: int = 17,
        optimize: str | None = "O2",
    ) -> Path:
        """Merge the LoRA weights into the backbone and export the result to ONNX.

        Steps:
            1. Merge LoRA weights into the base model (``merge_and_unload``).
            2. Export the merged model to ONNX via optimum.

        Args:
            peft_model: LoRA-wrapped model. Whichever adapter is currently
                active on it gets merged; this method does not call
                ``set_adapter(language)``, so activate the right adapter first.
            processor: Whisper processor; saved next to the merged weights in
                the temporary directory that optimum exports from.
            output_dir: Root directory. The export goes to
                ``<output_dir>/<language>``, which is created if missing.
            language: Name of the per-language sub-directory. Must be a single,
                non-empty path component.
            opset: ONNX opset version passed to optimum (default 17).
            optimize: optimum ONNX Runtime optimization level (e.g. ``"O1"``
                to ``"O4"``), or ``None`` to skip graph optimization.
                Defaults to ``"O2"``.

        Returns:
            The export directory ``<output_dir>/<language>``.

        Raises:
            ValueError: If ``language`` is empty, ``"."``/``".."``, or contains
                a path separator.

        NOTE(review): in PEFT, ``merge_and_unload`` folds the adapter into the
        base weights and strips the LoRA layers, so ``peft_model`` should not
        be reused (e.g. to export another language) afterwards -- confirm
        against the installed peft version.
        """
        # Guard the per-language layout: "" would export straight into
        # output_dir (mixing languages) and "../x" or "a/b" would escape it.
        if language in {"", ".", ".."} or Path(language).name != language:
            raise ValueError(
                f"language must be a single non-empty path component, got {language!r}"
            )
        output_path = Path(output_dir) / language
        output_path.mkdir(parents=True, exist_ok=True)

        logger.info("Merging LoRA adapter '%s' into base model...", language)
        merged_model = peft_model.merge_and_unload()
        # Switch to inference mode before the graph is exported.
        merged_model.eval()

        logger.info("Exporting to ONNX: %s", output_path)
        self._export_with_optimum(
            merged_model,
            processor,
            str(output_path),
            opset=opset,
            optimize=optimize,
        )
        return output_path

    def _export_with_optimum(
        self,
        merged_model,
        processor: "WhisperProcessor",
        output_dir: str,
        *,
        opset: int = 17,
        optimize: str | None = "O2",
    ) -> None:
        """Export ``merged_model`` into ``output_dir`` with optimum's ONNX pipeline.

        ``main_export`` loads its model from a name or path, so the in-memory
        merged model and its processor are first written to a temporary
        directory. That directory is deleted once the export finishes.
        """
        import tempfile

        from optimum.exporters.onnx import main_export

        with tempfile.TemporaryDirectory() as tmp_dir:
            logger.info("Saving merged model to temp dir for export...")
            merged_model.save_pretrained(tmp_dir)
            processor.save_pretrained(tmp_dir)
            logger.info("Running optimum ONNX export...")
            main_export(
                model_name_or_path=tmp_dir,
                output=output_dir,
                task="automatic-speech-recognition",
                opset=opset,
                optimize=optimize,
            )
        logger.info("ONNX export complete: %s", output_dir)

    def validate(
        self,
        onnx_dir: str,
        processor: "WhisperProcessor",
        test_audio_arrays: list,
        sample_rate: int = 16_000,
        reference_texts: list[str] | None = None,
    ) -> dict:
        """Transcribe test audio with an exported ONNX model and optionally score WER.

        Args:
            onnx_dir: Directory holding an exported model (e.g. the path
                returned by :meth:`merge_and_export`).
            processor: Processor used to build input features and to decode
                the generated token ids.
            test_audio_arrays: Raw audio clips; each one is transcribed
                individually.
            sample_rate: Sampling rate of the clips, forwarded to the processor.
            reference_texts: Optional ground-truth transcripts, one per clip and
                in the same order. When non-empty, WER is computed with jiwer.

        Returns:
            ``{"transcriptions": [...]}``. A ``"wer"`` entry, rounded to 4
            decimal places, is included when ``reference_texts`` is non-empty.

        Raises:
            ValueError: If ``reference_texts`` is non-empty and its length
                differs from the number of clips.

        NOTE(review): ``generate`` is called without explicit language/task
        arguments, so decoding relies on whatever generation config was
        exported -- confirm this yields the intended language.
        """
        from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

        logger.info("Validating ONNX model at %s...", onnx_dir)
        ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir)

        transcriptions = []
        for audio in test_audio_arrays:
            features = processor(
                audio, sampling_rate=sample_rate, return_tensors="pt"
            ).input_features
            token_ids = ort_model.generate(features)
            transcriptions.append(
                processor.batch_decode(token_ids, skip_special_tokens=True)[0]
            )

        result = {"transcriptions": transcriptions}
        if reference_texts:
            # Report a count mismatch explicitly rather than relying on jiwer's
            # own error for unequal reference/hypothesis lists.
            if len(reference_texts) != len(transcriptions):
                raise ValueError(
                    f"Got {len(reference_texts)} reference texts for "
                    f"{len(transcriptions)} audio clips; they must match one-to-one."
                )
            import jiwer

            result["wer"] = round(jiwer.wer(reference_texts, transcriptions), 4)
        return result