"""
Merge LoRA adapter weights into the backbone and export the result to ONNX.

Produces one ONNX export directory per language, because an exported ONNX
graph cannot hot-swap LoRA adapters at runtime.

Requires: optimum[onnxruntime]
"""

from __future__ import annotations

import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from peft import PeftModel
    from transformers import WhisperProcessor

logger = logging.getLogger(__name__)


class ONNXExporter:
    """Merges a LoRA PeftModel into its base model and exports to ONNX.

    Export settings are class attributes so a subclass (or a single
    instance) can override them without changing any method signature.
    """

    # Task name forwarded to optimum's ``main_export(task=...)``.
    task: str = "automatic-speech-recognition"
    # ONNX opset version forwarded to ``main_export(opset=...)``.
    opset: int = 17
    # Optimization level forwarded to ``main_export(optimize=...)``.
    optimize: str | None = "O2"

    def merge_and_export(
        self,
        peft_model: "PeftModel",
        processor: "WhisperProcessor",
        output_dir: str,
        language: str,
    ) -> Path:
        """Merge the LoRA weights into the base model and export it to ONNX.

        Steps:
            1. Merge LoRA weights into the base model (``merge_and_unload``).
            2. Export the merged model to ONNX via optimum.

        Args:
            peft_model: LoRA-wrapped model whose adapter weights are merged.
            processor: Processor saved next to the merged weights before
                export.
            output_dir: Root directory; the export is written to
                ``output_dir/language`` (created if missing).
            language: Name of the per-language subdirectory; also used in
                the log message.

        Returns:
            The directory the ONNX export was written to.
        """
        output_path = Path(output_dir) / language
        output_path.mkdir(parents=True, exist_ok=True)

        # NOTE(review): ``language`` only names the output directory and the
        # log line below; the adapter(s) merged are whichever are active on
        # ``peft_model``. merge_and_unload() also rewrites the wrapped base
        # model, so the same ``peft_model`` should presumably not be reused
        # for another language afterwards — confirm callers load it fresh.
        logger.info("Merging LoRA adapter '%s' into base model...", language)
        merged_model = peft_model.merge_and_unload()
        merged_model.eval()

        logger.info("Exporting to ONNX: %s", output_path)
        self._export_with_optimum(merged_model, processor, str(output_path))
        return output_path

    def _export_with_optimum(
        self,
        merged_model,
        processor: "WhisperProcessor",
        output_dir: str,
    ) -> None:
        """Export ``merged_model`` into ``output_dir`` with optimum's exporter.

        ``main_export`` takes a model path, so the merged model and its
        processor are first saved to a temporary directory, which is deleted
        once the export has run.
        """
        # Imported lazily: optimum is an optional, heavyweight dependency.
        from optimum.exporters.onnx import main_export

        with tempfile.TemporaryDirectory() as tmp_dir:
            logger.info("Saving merged model to temp dir for export...")
            merged_model.save_pretrained(tmp_dir)
            processor.save_pretrained(tmp_dir)

            logger.info("Running optimum ONNX export...")
            main_export(
                model_name_or_path=tmp_dir,
                output=output_dir,
                task=self.task,
                opset=self.opset,
                optimize=self.optimize,
            )

        logger.info("ONNX export complete: %s", output_dir)

    def validate(
        self,
        onnx_dir: str,
        processor: "WhisperProcessor",
        test_audio_arrays: list,
        sample_rate: int = 16_000,
        reference_texts: list[str] | None = None,
    ) -> dict:
        """Run the exported ONNX model on test audio and optionally score WER.

        Args:
            onnx_dir: Directory containing the ONNX export (e.g. the path
                returned by ``merge_and_export``).
            processor: Processor used to featurize each clip and to decode the
                generated token ids.
            test_audio_arrays: Raw audio clips; each is transcribed separately.
            sample_rate: Sampling rate of the clips, forwarded to the processor.
            reference_texts: Optional ground-truth transcripts aligned with
                ``test_audio_arrays``. When non-empty, WER is computed.

        Returns:
            ``{"transcriptions": [...]}``, plus ``"wer"`` rounded to 4 decimal
            places when ``reference_texts`` is non-empty.

        Raises:
            ValueError: If ``reference_texts`` is non-empty and its length
                differs from the number of audio clips.
        """
        audio_clips = list(test_audio_arrays)
        # Fail fast, before loading the model and running inference on every
        # clip, instead of only discovering misaligned references at the end.
        if reference_texts and len(reference_texts) != len(audio_clips):
            raise ValueError(
                f"reference_texts has {len(reference_texts)} entries but "
                f"{len(audio_clips)} audio clips were given"
            )

        from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

        logger.info("Validating ONNX model at %s...", onnx_dir)
        ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir)

        transcriptions = []
        for audio in audio_clips:
            inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
            outputs = ort_model.generate(inputs.input_features)
            # One clip per call, so the decoded batch holds a single string.
            transcriptions.append(
                processor.batch_decode(outputs, skip_special_tokens=True)[0]
            )

        result = {"transcriptions": transcriptions}
        if reference_texts:
            import jiwer

            result["wer"] = round(jiwer.wer(reference_texts, transcriptions), 4)
        return result