Spaces:
Sleeping
Sleeping
"""
Merges LoRA adapter weights into the backbone and exports the result to ONNX.
Produces one ONNX export directory per language, because ONNX graphs cannot
hot-swap adapters at runtime.
Requires: optimum[onnxruntime] (plus jiwer to compute WER in ONNXExporter.validate).
"""
| from __future__ import annotations | |
| import logging | |
| from pathlib import Path | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from peft import PeftModel | |
| from transformers import WhisperProcessor | |
logger = logging.getLogger(__name__)


class ONNXExporter:
    """Merge a LoRA ``PeftModel`` into its base model and export it to ONNX.

    The optimum export settings are class attributes, so a subclass or an
    instance can override them without changing any method signature.
    """

    # optimum export task. When a task is passed explicitly, optimum's
    # ``main_export`` does not add the KV-cache ("-with-past") decoder variant
    # by itself; use "automatic-speech-recognition-with-past" for cached
    # decoding. NOTE(review): verify against the installed optimum version.
    task: str = "automatic-speech-recognition"
    # ONNX opset version passed to optimum.
    opset: int = 17
    # optimum / ONNX Runtime optimization level (e.g. "O1".."O4"), or None to
    # skip the optimization pass.
    optimize: str | None = "O2"

    def merge_and_export(
        self,
        peft_model: "PeftModel",
        processor: "WhisperProcessor",
        output_dir: str,
        language: str,
    ) -> Path:
        """Merge the LoRA weights into the backbone and export it to ONNX.

        Steps:
          1. ``peft_model.merge_and_unload()`` produces the merged model,
             which is then put into eval mode.
          2. The merged model is exported with optimum into
             ``<output_dir>/<language>`` (created if missing).

        Args:
            peft_model: LoRA-wrapped model to merge. This method does not
                call ``set_adapter``; presumably the adapter for ``language``
                is already active -- confirm at the call site.
            processor: Whisper processor, saved alongside the merged model
                before export.
            output_dir: Root output directory.
            language: Name of the per-language subdirectory (also logged).

        Returns:
            The per-language export directory.

        Note:
            ``merge_and_unload`` is called on ``peft_model`` itself, not on a
            copy. NOTE(review): PEFT merges into the base model in place, so
            the same ``peft_model`` should probably not be reused to export a
            different language afterwards -- confirm.
        """
        output_path = Path(output_dir) / language
        output_path.mkdir(parents=True, exist_ok=True)

        logger.info("Merging LoRA adapter '%s' into base model...", language)
        merged_model = peft_model.merge_and_unload()
        merged_model.eval()

        logger.info("Exporting to ONNX: %s", output_path)
        self._export_with_optimum(merged_model, processor, str(output_path))
        return output_path

    def _export_with_optimum(
        self,
        merged_model,
        processor: "WhisperProcessor",
        output_dir: str,
    ) -> None:
        """Export ``merged_model`` to ONNX in ``output_dir`` using optimum.

        ``main_export`` is given a model path, so the merged model and the
        processor are first saved to a temporary directory, which is deleted
        once the export has finished.
        """
        import tempfile

        from optimum.exporters.onnx import main_export

        with tempfile.TemporaryDirectory() as tmp_dir:
            logger.info("Saving merged model to temp dir for export...")
            merged_model.save_pretrained(tmp_dir)
            processor.save_pretrained(tmp_dir)

            logger.info("Running optimum ONNX export...")
            # Must run inside the ``with``: tmp_dir is removed on exit.
            main_export(
                model_name_or_path=tmp_dir,
                output=output_dir,
                task=self.task,
                opset=self.opset,
                optimize=self.optimize,
            )
        logger.info("ONNX export complete: %s", output_dir)

    def validate(
        self,
        onnx_dir: str,
        processor: "WhisperProcessor",
        test_audio_arrays: list,
        sample_rate: int = 16_000,
        reference_texts: list[str] | None = None,
    ) -> dict:
        """Transcribe audio with the exported ONNX model and optionally score WER.

        Args:
            onnx_dir: Directory holding the exported ONNX model.
            processor: Whisper processor used for feature extraction and
                decoding.
            test_audio_arrays: Audio clips, each passed to ``processor`` on its
                own. Any iterable is accepted; it is consumed once.
            sample_rate: Sampling rate forwarded to the processor.
            reference_texts: Optional transcripts aligned one-to-one with the
                clips. When non-empty, WER is computed with jiwer.

        Returns:
            ``{"transcriptions": [...]}``, plus ``"wer"`` rounded to 4 decimals
            when references are given.

        Raises:
            ValueError: If ``reference_texts`` is non-empty and its length
                differs from the number of clips. This is checked before the
                model is loaded, so a mismatch fails fast instead of after
                running inference on every clip.
        """
        clips = list(test_audio_arrays)
        if reference_texts and len(reference_texts) != len(clips):
            raise ValueError(
                f"Got {len(reference_texts)} reference texts for "
                f"{len(clips)} audio clips; they must match one-to-one."
            )

        from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

        logger.info("Validating ONNX model at %s...", onnx_dir)
        ort_model = ORTModelForSpeechSeq2Seq.from_pretrained(onnx_dir)

        transcriptions: list[str] = []
        for audio in clips:
            inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
            outputs = ort_model.generate(inputs.input_features)
            # One clip per call, so only the first decoded string is needed.
            transcriptions.append(
                processor.batch_decode(outputs, skip_special_tokens=True)[0]
            )

        result: dict = {"transcriptions": transcriptions}
        if reference_texts:
            import jiwer

            result["wer"] = round(jiwer.wer(reference_texts, transcriptions), 4)
        return result