"""
BitsAndBytes quantization for GPU-constrained deployment.
4-bit NF4: reduces Whisper-large-v3-turbo from ~3GB to ~1GB VRAM.
8-bit: intermediate option with less accuracy loss.
"""
from __future__ import annotations
import logging
import time
import numpy as np
import torch
from transformers import BitsAndBytesConfig, WhisperForConditionalGeneration, WhisperProcessor
logger = logging.getLogger(__name__)
def load_4bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration:
"""Load Whisper with 4-bit NF4 quantization. Reduces VRAM to ~1GB."""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",  # NormalFloat4, designed for normally distributed weights
        bnb_4bit_compute_dtype=torch.float16,  # dequantize to fp16 for matmuls
        bnb_4bit_use_double_quant=True,  # also quantize the quantization constants for extra savings
    )
logger.info("Loading %s with 4-bit NF4 quantization...", model_id)
model = WhisperForConditionalGeneration.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
token=hf_token,
)
return model
def load_8bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration:
"""Load Whisper with 8-bit quantization. Reduces VRAM to ~1.5GB."""
bnb_config = BitsAndBytesConfig(load_in_8bit=True)
logger.info("Loading %s with 8-bit quantization...", model_id)
model = WhisperForConditionalGeneration.from_pretrained(
model_id,
quantization_config=bnb_config,
device_map="auto",
token=hf_token,
)
return model
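
# ModelQuantizer's docstring promises quantized-vs-full-precision benchmarks, but the
# module only ships 4-bit and 8-bit loaders. A minimal baseline loader is sketched
# below; loading in float16 (rather than float32) is an assumption, chosen to match
# bnb_4bit_compute_dtype above.
def load_full(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration:
    """Load Whisper unquantized in float16 as a benchmark baseline."""
    logger.info("Loading %s in float16 (unquantized baseline)...", model_id)
    model = WhisperForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        token=hf_token,
    )
    return model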
class ModelQuantizer:
"""Benchmarks quantized vs full-precision models."""
def __init__(self, model_id: str, hf_token: str | None = None) -> None:
self.model_id = model_id
self.hf_token = hf_token
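
    # self.model_id and self.hf_token are stored but otherwise unused; this
    # convenience wrapper is a hedged sketch (the method name and loader list are
    # assumptions, not part of the original API) that loads each variant in turn
    # and benchmarks it with the method below.
    def compare(
        self,
        processor: WhisperProcessor,
        test_audio_arrays: list[np.ndarray],
    ) -> dict[str, dict]:
        """Benchmark 4-bit, 8-bit, and fp16 variants of self.model_id."""
        results: dict[str, dict] = {}
        for name, loader in (("4bit-nf4", load_4bit), ("8bit", load_8bit), ("fp16", load_full)):
            model = loader(self.model_id, self.hf_token)
            results[name] = self.benchmark(model, processor, test_audio_arrays)
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # free VRAM before loading the next variant
        return results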
def benchmark(
self,
model: WhisperForConditionalGeneration,
processor: WhisperProcessor,
        test_audio_arrays: list[np.ndarray],
sample_rate: int = 16_000,
) -> dict:
"""Measure latency and memory for a list of audio arrays."""
        device = next(model.parameters()).device
        if not test_audio_arrays:
            raise ValueError("test_audio_arrays must not be empty")
        latencies: list[float] = []
        for audio in test_audio_arrays:
            inputs = processor.feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt")
            # Match the model's compute dtype (e.g. float16) to avoid dtype-mismatch errors.
            features = inputs.input_features.to(device, dtype=model.dtype)
            if device.type == "cuda":
                torch.cuda.synchronize()  # flush pending kernels so the timer starts clean
            t0 = time.perf_counter()
            with torch.no_grad():
                model.generate(features, max_new_tokens=50)
            if device.type == "cuda":
                torch.cuda.synchronize()  # wait for generation to finish before stopping the timer
            latencies.append((time.perf_counter() - t0) * 1000)
result = {
"mean_latency_ms": round(sum(latencies) / len(latencies), 1),
"max_latency_ms": round(max(latencies), 1),
}
        if device.type == "cuda":
            # Current allocation, not the peak; use torch.cuda.max_memory_allocated() for peak usage.
            result["vram_allocated_gb"] = round(torch.cuda.memory_allocated(device) / 1e9, 2)
return result
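
# A minimal smoke test, assuming network access to the Hugging Face Hub and a CUDA
# GPU. "openai/whisper-large-v3-turbo" and the synthetic noise clips are illustrative
# stand-ins, not project fixtures; random noise exercises latency and memory only,
# not transcription accuracy.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    model_id = "openai/whisper-large-v3-turbo"
    processor = WhisperProcessor.from_pretrained(model_id)
    # Three one-second clips of random noise at 16 kHz.
    rng = np.random.default_rng(0)
    clips = [rng.standard_normal(16_000).astype(np.float32) for _ in range(3)]
    for name, stats in ModelQuantizer(model_id).compare(processor, clips).items():
        logger.info("%s: %s", name, stats)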