File size: 2,113 Bytes
76db545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Phase 4a: Merge LoRA adapters and export language-specific ONNX models.

For each supported language (Bambara, Fula) whose adapter directory exists,
the base Whisper model named in ``configs/base_config.yaml`` (``model.id``) is
loaded, the LoRA adapter is applied, and the result is handed to
``ONNXExporter.merge_and_export`` with the output directory
``<paths.models>/onnx/<language>``.

Target of this phase: ONNX WER within 2% of the PyTorch baseline.
NOTE(review): this script does not compute WER itself; if the check exists it
presumably lives inside ``ONNXExporter.merge_and_export`` — confirm.

Environment variables (a ``.env`` file is loaded first):
    HF_TOKEN              Passed as ``token`` to the Hugging Face ``from_pretrained`` calls.
    BAMBARA_ADAPTER_PATH  Bambara adapter dir (default ``./adapters/bambara``).
    FULA_ADAPTER_PATH     Fula adapter dir (default ``./adapters/fula``).

Usage (run from the repository root: the config path and the default adapter
paths are resolved relative to the current working directory):
    python scripts/export_onnx.py
"""
import logging
import os
import sys
from pathlib import Path

# Make the project root (two levels up from this file) importable so that
# ``src.*`` resolves when this file is executed directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from dotenv import load_dotenv

# Populate os.environ from a .env file before any os.getenv() calls below.
load_dotenv()

# Root logger configuration. This script itself reports progress via print(),
# so this presumably matters for log output from imported modules
# (e.g. ONNXExporter, transformers) — confirm.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s — %(message)s")

import yaml

from src.optimization.onnx_exporter import ONNXExporter


def export_language(language: str, adapter_path: str, config: dict) -> None:
    """Apply one language's LoRA adapter to the base Whisper model and export it to ONNX.

    Args:
        language: Language key (e.g. ``"bambara"``). Used as the PEFT adapter
            name, as the output sub-directory name, and (upper-cased) as the
            progress-message prefix.
        adapter_path: Location of the trained PEFT adapter to load.
        config: Parsed base config. Reads ``config["model"]["id"]`` (base model
            identifier) and ``config["paths"]["models"]`` (output root).
    """
    # Imported inside the function (not at module top) — presumably so the
    # heavy ML stacks are only loaded once an adapter is actually exported.
    from peft import PeftModel
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    tag = f"[{language.upper()}]"
    token = os.getenv("HF_TOKEN")
    base_id = config["model"]["id"]

    print(f"\n{tag} Loading base model...")
    whisper = WhisperForConditionalGeneration.from_pretrained(base_id, token=token)
    processor = WhisperProcessor.from_pretrained(base_id, token=token)

    print(f"{tag} Loading adapter from {adapter_path}...")
    adapted = PeftModel.from_pretrained(whisper, adapter_path, adapter_name=language)

    models_root = config["paths"]["models"]
    destination = f"{models_root}/onnx/{language}"
    exported_to = ONNXExporter().merge_and_export(adapted, processor, destination, language)
    print(f"{tag} ONNX exported to: {exported_to}")


def main() -> None:
    """Export an ONNX model for every language whose LoRA adapter is present.

    Loads ``configs/base_config.yaml`` (relative to the current working
    directory). For Bambara and Fula it then resolves the adapter location
    from ``BAMBARA_ADAPTER_PATH`` / ``FULA_ADAPTER_PATH``, falling back to
    ``./adapters/<language>``, and calls :func:`export_language` for each
    adapter that exists. A missing adapter is reported and skipped, not
    treated as an error. The final message distinguishes "something was
    exported" from "nothing was found".
    """
    # Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows)
    # can fail to decode non-ASCII characters in the YAML config.
    with open("configs/base_config.yaml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    print("=" * 60)
    print("Sahel-Agri Voice AI — ONNX Export")
    print("=" * 60)

    # Both env lookups happen up front, in this order (dicts keep insertion order).
    adapters = {
        "bambara": os.getenv("BAMBARA_ADAPTER_PATH", "./adapters/bambara"),
        "fula": os.getenv("FULA_ADAPTER_PATH", "./adapters/fula"),
    }

    exported = []
    for language, adapter_path in adapters.items():
        if not Path(adapter_path).exists():
            print(f"\nSkipping {language}: adapter not found at {adapter_path}")
            continue
        export_language(language, adapter_path, config)
        exported.append(language)

    # Don't claim success when every adapter was skipped.
    if exported:
        print("\nExport complete.")
    else:
        print("\nNo adapters found; nothing was exported.")


if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()