ground-zero / tests /test_data_pipeline.py
jefffffff9
Initial commit: Sahel-Agri Voice AI
76db545
"""
Unit tests for the data pipeline: augmentation, feature extractor, agri dictionary.
"""
from __future__ import annotations
import numpy as np
import pytest
class TestFieldNoiseAugmenter:
    """Tests for ``FieldNoiseAugmenter`` built over an empty noise directory.

    Every test passes pytest's empty ``tmp_path`` as the noise directory, so
    these cases exercise the augmenter without any noise recordings on disk.
    """

    SAMPLE_RATE = 16000

    @staticmethod
    def _make_augmenter(noise_dir, augmentation_prob):
        """Build an augmenter over *noise_dir* with the given augmentation probability."""
        from src.data.augmentation import FieldNoiseAugmenter

        config = {
            "audio": {
                "noise_snr_db_range": [5, 20],
                "augmentation_prob": augmentation_prob,
            }
        }
        return FieldNoiseAugmenter(str(noise_dir), config)

    def test_augmenter_without_noise_files(self, tmp_path):
        """Augmenter with empty noise_dir should fall back to Gaussian-only and still be ready."""
        augmenter = self._make_augmenter(tmp_path, 0.6)
        assert augmenter.is_ready()
        assert augmenter._gaussian_only

    def test_augmenter_output_shape(self, tmp_path):
        """Augmented audio should have the same length as input."""
        augmenter = self._make_augmenter(tmp_path, 1.0)
        # Seeded generator so the input signal is identical on every run.
        rng = np.random.default_rng(0)
        audio = rng.standard_normal(self.SAMPLE_RATE).astype(np.float32) * 0.01
        augmented = augmenter.augment(audio, self.SAMPLE_RATE)
        assert augmented.shape == audio.shape

    def test_augmenter_no_crash_on_silent_audio(self, tmp_path):
        """Silent audio (all zeros) should not crash the augmenter."""
        # Assumes augmentation_prob is the per-call chance of augmenting:
        # 1.0 forces the noise-mixing path every run, whereas 0.5 could skip
        # it and leave the zero-signal edge case untested.
        augmenter = self._make_augmenter(tmp_path, 1.0)
        audio = np.zeros(self.SAMPLE_RATE, dtype=np.float32)
        result = augmenter.augment(audio, self.SAMPLE_RATE)
        assert result is not None
        assert result.shape == audio.shape
        # Zero signal power must not leak NaN/inf out of (presumably) the
        # SNR-based noise scaling.
        assert np.all(np.isfinite(result))
class TestAgriculturalDictionary:
    """Sanity checks for the agricultural vocabularies and prompt text."""

    def test_bambara_vocab_not_empty(self):
        """The Bambara vocabulary must define at least one term."""
        from src.data.agri_dictionary import BAMBARA_VOCAB as vocab

        assert len(vocab) >= 1

    def test_fula_vocab_not_empty(self):
        """The Fula vocabulary must define at least one term."""
        from src.data.agri_dictionary import FULA_VOCAB as vocab

        assert len(vocab) >= 1

    def test_get_vocab_invalid_language(self):
        """An unknown language code must be rejected with ValueError."""
        from src.data.agri_dictionary import AgriculturalDictionary

        dictionary = AgriculturalDictionary()
        with pytest.raises(ValueError):
            dictionary.get_vocab("xyz")

    def test_prompt_text_contains_terms(self):
        """The Bambara prompt text must contain the terms 'sɛnɛ' and 'kaba'."""
        from src.data.agri_dictionary import AgriculturalDictionary

        text = AgriculturalDictionary().get_prompt_text("bam")
        for expected in ("sɛnɛ", "kaba"):
            assert expected in text
class TestDataCollator:
    """Tests for ``DataCollatorSpeechSeq2SeqWithPadding`` driven by a mocked processor."""

    def test_collator_pads_labels(self):
        """DataCollator should pad labels and replace pad tokens with -100."""
        from unittest.mock import MagicMock

        torch = pytest.importorskip("torch")
        from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding

        # Mock processor: the feature extractor and tokenizer return fixed
        # batches, so only the collator's own label handling is under test.
        processor = MagicMock()
        processor.feature_extractor.pad.return_value = {
            "input_features": torch.zeros(2, 80, 3000)
        }
        # Simulated padded labels batch: pad id 0 appears exactly where the
        # attention mask is 0 (1 position in row 0, 2 in row 1).
        padded_labels = MagicMock()
        padded_labels.input_ids = torch.tensor([[1, 2, 3, 0], [1, 4, 0, 0]])
        padded_labels.attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
        processor.tokenizer.pad.return_value = padded_labels

        collator = DataCollatorSpeechSeq2SeqWithPadding(
            processor=processor,
            decoder_start_token_id=1,
        )
        features = [
            {"input_features": np.zeros((80, 3000)), "labels": [1, 2, 3]},
            {"input_features": np.zeros((80, 3000)), "labels": [1, 4]},
        ]
        batch = collator(features)

        assert "labels" in batch
        rows = batch["labels"].tolist()
        flat = [token for row in rows for token in row]
        # Every masked position must become -100. These checks hold whether or
        # not the collator also drops a leading decoder-start column, because
        # that column is fully unmasked in the mock data.
        assert flat.count(-100) == 3
        assert all(row[-1] == -100 for row in rows)
        # No raw pad id may survive into the loss targets.
        assert 0 not in flat