""" Phase 4a: Merge LoRA adapters and export language-specific ONNX models. Validates that ONNX WER is within 2% of PyTorch baseline. Usage: python scripts/export_onnx.py """ import logging import os import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from dotenv import load_dotenv load_dotenv() logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s") import yaml from src.optimization.onnx_exporter import ONNXExporter def export_language(language: str, adapter_path: str, config: dict) -> None: from peft import PeftModel from transformers import WhisperForConditionalGeneration, WhisperProcessor hf_token = os.getenv("HF_TOKEN") model_id = config["model"]["id"] print(f"\n[{language.upper()}] Loading base model...") base_model = WhisperForConditionalGeneration.from_pretrained(model_id, token=hf_token) processor = WhisperProcessor.from_pretrained(model_id, token=hf_token) print(f"[{language.upper()}] Loading adapter from {adapter_path}...") peft_model = PeftModel.from_pretrained(base_model, adapter_path, adapter_name=language) output_dir = f"{config['paths']['models']}/onnx/{language}" exporter = ONNXExporter() result_path = exporter.merge_and_export(peft_model, processor, output_dir, language) print(f"[{language.upper()}] ONNX exported to: {result_path}") def main() -> None: with open("configs/base_config.yaml") as f: config = yaml.safe_load(f) print("=" * 60) print("Sahel-Agri Voice AI — ONNX Export") print("=" * 60) bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara") fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula") for language, adapter_path in [("bambara", bambara_path), ("fula", fula_path)]: if Path(adapter_path).exists(): export_language(language, adapter_path, config) else: print(f"\nSkipping {language}: adapter not found at {adapter_path}") print("\nExport complete.") if __name__ == "__main__": main()