ground-zero / scripts /export_onnx.py
jefffffff9
Initial commit: Sahel-Agri Voice AI
76db545
"""
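Phase 4a: Merge LoRA adapters and export language-specific ONNX models.

Validates that ONNX WER is within 2% of PyTorch baseline.

Environment variables:
    HF_TOKEN              Token passed to ``from_pretrained`` when loading the base model.
    BAMBARA_ADAPTER_PATH  Bambara adapter directory (default: ./adapters/bambara).
    FULA_ADAPTER_PATH     Fula adapter directory (default: ./adapters/fula).

Usage:
    python scripts/export_onnx.py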
"""
import logging
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")
import yaml
from src.optimization.onnx_exporter import ONNXExporter
def export_language(language: str, adapter_path: str, config: dict) -> None:
    """Attach one language's LoRA adapter to the base Whisper model and export it.

    Args:
        language: Language name. Used as the PEFT adapter name, as the output
            sub-directory, and (upper-cased) as the prefix of progress messages.
        adapter_path: Location of the trained adapter, passed to
            ``PeftModel.from_pretrained``.
        config: Parsed base configuration. Reads ``config["model"]["id"]`` and
            ``config["paths"]["models"]``.
    """
    # The ML libraries are imported lazily, only once an export actually runs.
    from peft import PeftModel
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    auth_token = os.getenv("HF_TOKEN")
    checkpoint = config["model"]["id"]
    tag = f"[{language.upper()}]"

    print(f"\n{tag} Loading base model...")
    whisper = WhisperForConditionalGeneration.from_pretrained(checkpoint, token=auth_token)
    processor = WhisperProcessor.from_pretrained(checkpoint, token=auth_token)

    print(f"{tag} Loading adapter from {adapter_path}...")
    lora_model = PeftModel.from_pretrained(whisper, adapter_path, adapter_name=language)

    # Exports land under <models>/onnx/<language>.
    models_root = config["paths"]["models"]
    export_dir = f"{models_root}/onnx/{language}"

    onnx_path = ONNXExporter().merge_and_export(lora_model, processor, export_dir, language)
    print(f"{tag} ONNX exported to: {onnx_path}")
def main() -> None:
    """Export an ONNX model for every language whose adapter is available.

    Loads ``configs/base_config.yaml`` (relative to the current working
    directory), then handles Bambara and Fula in that order. Each adapter
    directory comes from ``BAMBARA_ADAPTER_PATH`` / ``FULA_ADAPTER_PATH`` and
    defaults to ``./adapters/<language>``. A language whose adapter path does
    not exist is skipped with a message instead of being treated as an error.
    """
    # Decode the config as UTF-8 explicitly: without ``encoding`` open() uses
    # the locale's preferred encoding, which can mis-decode non-ASCII values.
    with open("configs/base_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    print("=" * 60)
    print("Sahel-Agri Voice AI — ONNX Export")
    print("=" * 60)

    bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara")
    fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula")

    for language, adapter_path in (("bambara", bambara_path), ("fula", fula_path)):
        if Path(adapter_path).exists():
            export_language(language, adapter_path, config)
        else:
            print(f"\nSkipping {language}: adapter not found at {adapter_path}")

    print("\nExport complete.")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()