# ground-zero/scripts/verify_baseline.py
# Author: jefffffff9
# Initial commit: Sahel-Agri Voice AI (commit 76db545)
"""
Phase 1 smoke test: load Whisper, run inference on a sample audio clip.
Prints model info, inference time, GPU memory usage, and sample transcript.
Usage:
python scripts/verify_baseline.py
"""
import sys
import time
from pathlib import Path
# Allow imports from project root
sys.path.insert(0, str(Path(__file__).parent.parent))
import numpy as np
import torch
def main() -> None:
    """Run the Phase 1 smoke test.

    Loads the Whisper backbone, transcribes 1 s of synthetic low-amplitude
    noise, and prints environment info, load/inference timings, and GPU
    memory usage.

    Raises:
        Whatever ``WhisperBackbone`` raises if the config or model files
        are missing — this script deliberately does not swallow errors.
    """
    # Deferred import so the script prints environment info even if the
    # project package has import-time costs.
    from src.engine.whisper_base import WhisperBackbone

    print("=" * 60)
    print("Sahel-Agri Voice AI — Baseline Verification")
    print("=" * 60)

    # 1. Check environment
    print(f"\nPython: {sys.version.split()[0]}")
    print(f"PyTorch: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # 2. Load model
    print("\n[1/3] Loading backbone model...")
    t0 = time.perf_counter()  # perf_counter: monotonic, right for intervals
    backbone = WhisperBackbone("configs/base_config.yaml")
    # BUG FIX: the original hard-coded device="cuda", which crashes on
    # CPU-only hosts even though CUDA availability is checked above.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    backbone.load(device=device)
    load_time = time.perf_counter() - t0
    print(f" Loaded in {load_time:.1f}s")
    if torch.cuda.is_available():
        used = torch.cuda.memory_allocated() / 1e9
        reserved = torch.cuda.memory_reserved() / 1e9
        print(f" GPU memory: {used:.2f} GB allocated / {reserved:.2f} GB reserved")

    # 3. Generate synthetic test audio: 1 s of low-amplitude white noise.
    # Noise (rather than silence) avoids degenerate all-zero features.
    print("\n[2/3] Generating test audio (1s white noise)...")
    sample_rate = 16000  # Whisper's expected sampling rate
    duration = 1.0
    audio = np.random.randn(int(sample_rate * duration)).astype(np.float32) * 0.01

    # 4. Run inference
    print("[3/3] Running inference...")
    processor = backbone.processor
    model = backbone.model
    inputs = processor(audio, sampling_rate=sample_rate, return_tensors="pt")
    input_features = inputs.input_features.to(backbone.device)
    # NOTE(review): .half() assumes the backbone loads fp16 weights on
    # CUDA — confirm against WhisperBackbone.load.
    if backbone.device == "cuda":
        input_features = input_features.half()
    t0 = time.perf_counter()
    with torch.no_grad():
        predicted_ids = model.generate(input_features, max_new_tokens=50)
    infer_time = time.perf_counter() - t0
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    print(f"\n{'=' * 60}")
    print(f"Transcript: '{transcription}' (noise input — blank expected)")
    print(f"Inference time: {infer_time * 1000:.0f} ms")
    print(f"\nBaseline verification PASSED.")
    print(f"{'=' * 60}")
if __name__ == "__main__":
main()