Spaces:
Running
Running
| """ | |
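| Phase 4a: Merge LoRA adapters into the base Whisper model and export one ONNX model per language. | |
| WER validation (ONNX within 2% of the PyTorch baseline) is not performed by this script itself; | |
| it is presumably handled inside ONNXExporter.merge_and_export (confirm against that module). | |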
| Usage: | |
| python scripts/export_onnx.py | |
| """ | |
import logging
import os
import sys
from pathlib import Path

# Prepend this file's grandparent directory (presumably the repository root)
# to sys.path so the `from src...` import further down can be resolved when
# the script is launched directly. This must run before that import.
sys.path.insert(0, str(Path(__file__).parent.parent))

from dotenv import load_dotenv

# Load variables from a .env file into the process environment before any
# os.getenv() lookup in this script (HF_TOKEN, BAMBARA_ADAPTER_PATH,
# FULA_ADAPTER_PATH).
load_dotenv()

# Configure the root logger at INFO level with timestamp, level and logger
# name in every record.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")

import yaml

# Project-local import; relies on the sys.path tweak above.
from src.optimization.onnx_exporter import ONNXExporter
def export_language(language: str, adapter_path: str, config: dict) -> None:
    """Attach one language's LoRA adapter to the base model and export it to ONNX.

    Args:
        language: Language key. It is used as the PEFT adapter name, as the
            last component of the output directory, and (upper-cased) as the
            prefix of the progress messages.
        adapter_path: Location of the saved adapter, handed to
            ``PeftModel.from_pretrained``.
        config: Parsed configuration. ``config["model"]["id"]`` names the base
            checkpoint and ``config["paths"]["models"]`` is the root under
            which ``onnx/<language>`` is written.

    The Hugging Face token is taken from the ``HF_TOKEN`` environment variable
    (``None`` when it is unset). Progress is reported with ``print``.
    """
    # Heavy ML dependencies are imported lazily, only when an export runs.
    from peft import PeftModel
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    token = os.getenv("HF_TOKEN")
    checkpoint = config["model"]["id"]

    print(f"\n[{language.upper()}] Loading base model...")
    whisper = WhisperForConditionalGeneration.from_pretrained(checkpoint, token=token)
    whisper_processor = WhisperProcessor.from_pretrained(checkpoint, token=token)

    print(f"[{language.upper()}] Loading adapter from {adapter_path}...")
    adapted = PeftModel.from_pretrained(whisper, adapter_path, adapter_name=language)

    # Each language gets its own sub-directory under <models>/onnx/.
    target_dir = f"{config['paths']['models']}/onnx/{language}"
    exported_to = ONNXExporter().merge_and_export(
        adapted,
        whisper_processor,
        target_dir,
        language,
    )
    print(f"[{language.upper()}] ONNX exported to: {exported_to}")
def main() -> None:
    """Export an ONNX model for every language whose adapter is available.

    Reads ``configs/base_config.yaml`` (resolved against the current working
    directory), prints a banner, then for Bambara and Fula either calls
    ``export_language`` or reports that the adapter was not found.

    Adapter locations come from the ``BAMBARA_ADAPTER_PATH`` and
    ``FULA_ADAPTER_PATH`` environment variables. An unset *or empty* variable
    falls back to ``./adapters/<language>``.
    """
    # Explicit UTF-8: without it open() decodes with the locale encoding,
    # which can corrupt non-ASCII YAML content on some platforms.
    with open("configs/base_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    print("=" * 60)
    print("Sahel-Agri Voice AI — ONNX Export")
    print("=" * 60)

    # `or` instead of a getenv default: a variable that is set but empty
    # (e.g. `BAMBARA_ADAPTER_PATH=` in a .env file) would otherwise yield "",
    # and Path("") is the current directory, which always exists.
    adapter_paths = {
        "bambara": os.getenv("BAMBARA_ADAPTER_PATH") or "./adapters/bambara",
        "fula": os.getenv("FULA_ADAPTER_PATH") or "./adapters/fula",
    }

    for language, adapter_path in adapter_paths.items():
        if not Path(adapter_path).exists():
            print(f"\nSkipping {language}: adapter not found at {adapter_path}")
            continue
        export_language(language, adapter_path, config)

    print("\nExport complete.")


if __name__ == "__main__":
    main()