""" BitsAndBytes quantization for GPU-constrained deployment. 4-bit NF4: reduces Whisper-large-v3-turbo from ~3GB to ~1GB VRAM. 8-bit: intermediate option with less accuracy loss. """ from __future__ import annotations import logging import time from typing import TYPE_CHECKING import torch from transformers import BitsAndBytesConfig, WhisperForConditionalGeneration, WhisperProcessor if TYPE_CHECKING: pass logger = logging.getLogger(__name__) def load_4bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration: """Load Whisper with 4-bit NF4 quantization. Reduces VRAM to ~1GB.""" bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, ) logger.info("Loading %s with 4-bit NF4 quantization...", model_id) model = WhisperForConditionalGeneration.from_pretrained( model_id, quantization_config=bnb_config, device_map="auto", token=hf_token, ) return model def load_8bit(model_id: str, hf_token: str | None = None) -> WhisperForConditionalGeneration: """Load Whisper with 8-bit quantization. Reduces VRAM to ~1.5GB.""" bnb_config = BitsAndBytesConfig(load_in_8bit=True) logger.info("Loading %s with 8-bit quantization...", model_id) model = WhisperForConditionalGeneration.from_pretrained( model_id, quantization_config=bnb_config, device_map="auto", token=hf_token, ) return model class ModelQuantizer: """Benchmarks quantized vs full-precision models.""" def __init__(self, model_id: str, hf_token: str | None = None) -> None: self.model_id = model_id self.hf_token = hf_token def benchmark( self, model: WhisperForConditionalGeneration, processor: WhisperProcessor, test_audio_arrays: list, sample_rate: int = 16_000, ) -> dict: """Measure latency and memory for a list of audio arrays.""" import numpy as np device = next(model.parameters()).device latencies = [] for audio in test_audio_arrays: inputs = processor.feature_extractor(audio, sampling_rate=sample_rate, return_tensors="pt") features = inputs.input_features.to(device) if device.type == "cuda": torch.cuda.synchronize() t0 = time.perf_counter() with torch.no_grad(): model.generate(features, max_new_tokens=50) if device.type == "cuda": torch.cuda.synchronize() latencies.append((time.perf_counter() - t0) * 1000) result = { "mean_latency_ms": round(sum(latencies) / len(latencies), 1), "max_latency_ms": round(max(latencies), 1), } if torch.cuda.is_available(): result["vram_allocated_gb"] = round(torch.cuda.memory_allocated() / 1e9, 2) return result