"""
Phase 4a: Merge LoRA adapters and export language-specific ONNX models.
Validates that ONNX WER is within 2% of PyTorch baseline.
Usage:
python scripts/export_onnx.py
"""
import logging
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")
import yaml
from src.optimization.onnx_exporter import ONNXExporter
def export_language(language: str, adapter_path: str, config: dict) -> None:
    """Load the base Whisper model plus one LoRA adapter and export it to ONNX.

    Args:
        language: Language name; used as the PEFT adapter name, as the last
            component of the output directory, and (upper-cased) as the
            prefix of the progress messages.
        adapter_path: Location of the LoRA adapter to load on top of the
            base model.
        config: Parsed configuration. ``config["model"]["id"]`` selects the
            base model and ``config["paths"]["models"]`` is the root under
            which ``onnx/<language>`` is written.

    The Hugging Face token is read from the ``HF_TOKEN`` environment variable
    and may be unset, in which case ``None`` is passed through.
    """
    # Imported here rather than at module level, so these packages are only
    # loaded when an export actually runs.
    from peft import PeftModel
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    hf_token = os.getenv("HF_TOKEN")
    model_id = config["model"]["id"]
    tag = language.upper()

    print(f"\n[{tag}] Loading base model...")
    whisper = WhisperForConditionalGeneration.from_pretrained(
        model_id, token=hf_token
    )
    whisper_processor = WhisperProcessor.from_pretrained(model_id, token=hf_token)

    print(f"[{tag}] Loading adapter from {adapter_path}...")
    adapted = PeftModel.from_pretrained(
        whisper, adapter_path, adapter_name=language
    )

    models_root = config["paths"]["models"]
    target_dir = f"{models_root}/onnx/{language}"

    # The exporter's return value is reported as the export location.
    exported_to = ONNXExporter().merge_and_export(
        adapted, whisper_processor, target_dir, language
    )
    print(f"[{tag}] ONNX exported to: {exported_to}")
def main() -> None:
    """Export an ONNX model for each language whose LoRA adapter exists.

    Reads ``configs/base_config.yaml`` (resolved against the current working
    directory), then processes Bambara and Fula in that order. Adapter
    locations come from the ``BAMBARA_ADAPTER_PATH`` and ``FULA_ADAPTER_PATH``
    environment variables, defaulting to ``./adapters/bambara`` and
    ``./adapters/fula``. A language whose adapter path does not exist is
    skipped with a message instead of aborting the run.

    Raises:
        FileNotFoundError: If the configuration file cannot be found.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open("configs/base_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    print("=" * 60)
    print("Sahel-Agri Voice AI — ONNX Export")
    print("=" * 60)

    bambara_path = os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara")
    fula_path = os.getenv("FULA_ADAPTER_PATH", "./adapters/fula")

    for language, adapter_path in [("bambara", bambara_path), ("fula", fula_path)]:
        if not Path(adapter_path).exists():
            print(f"\nSkipping {language}: adapter not found at {adapter_path}")
            continue
        export_language(language, adapter_path, config)

    print("\nExport complete.")
# Script entry point: run the export only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|