"""
Unit tests for the data pipeline: augmentation, feature extractor, agri dictionary.
"""

from __future__ import annotations

import numpy as np
import pytest


class TestFieldNoiseAugmenter:
    """Tests for ``src.data.augmentation.FieldNoiseAugmenter``."""

    def test_augmenter_without_noise_files(self, tmp_path):
        """Augmenter with empty noise_dir should fall back to Gaussian-only and still be ready."""
        config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 0.6}}
        from src.data.augmentation import FieldNoiseAugmenter

        # tmp_path is an empty directory, i.e. no noise files are available.
        augmenter = FieldNoiseAugmenter(str(tmp_path), config)
        assert augmenter.is_ready()
        assert augmenter._gaussian_only

    def test_augmenter_output_shape(self, tmp_path):
        """Augmented audio should have the same length as input."""
        config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 1.0}}
        from src.data.augmentation import FieldNoiseAugmenter

        augmenter = FieldNoiseAugmenter(str(tmp_path), config)
        # Seeded generator so the input signal is reproducible across runs.
        rng = np.random.default_rng(0)
        audio = rng.standard_normal(16000).astype(np.float32) * 0.01
        augmented = augmenter.augment(audio, 16000)
        assert augmented.shape == audio.shape

    def test_augmenter_no_crash_on_silent_audio(self, tmp_path):
        """Silent audio (all zeros) should not crash the augmenter or produce NaN/inf."""
        # augmentation_prob is presumably the chance of applying augmentation;
        # 1.0 ensures the noise path actually runs on the silent input instead
        # of being skipped at random.
        config = {"audio": {"noise_snr_db_range": [5, 20], "augmentation_prob": 1.0}}
        from src.data.augmentation import FieldNoiseAugmenter

        augmenter = FieldNoiseAugmenter(str(tmp_path), config)
        audio = np.zeros(16000, dtype=np.float32)
        result = augmenter.augment(audio, 16000)
        assert result is not None
        assert result.shape == audio.shape
        # Zero signal power is the divide-by-zero case for SNR-based scaling;
        # make sure no NaN/inf leaks into the output.
        assert np.all(np.isfinite(np.asarray(result)))


class TestAgriculturalDictionary:
    """Tests for the vocabularies and helpers in ``src.data.agri_dictionary``."""

    def test_bambara_vocab_not_empty(self):
        """BAMBARA_VOCAB should contain at least one entry."""
        from src.data.agri_dictionary import BAMBARA_VOCAB

        assert len(BAMBARA_VOCAB) > 0

    def test_fula_vocab_not_empty(self):
        """FULA_VOCAB should contain at least one entry."""
        from src.data.agri_dictionary import FULA_VOCAB

        assert len(FULA_VOCAB) > 0

    def test_get_vocab_invalid_language(self):
        """get_vocab should raise ValueError for an unknown language code."""
        from src.data.agri_dictionary import AgriculturalDictionary

        d = AgriculturalDictionary()
        with pytest.raises(ValueError):
            d.get_vocab("xyz")

    def test_prompt_text_contains_terms(self):
        """The Bambara ("bam") prompt text should include known agricultural terms."""
        from src.data.agri_dictionary import AgriculturalDictionary

        d = AgriculturalDictionary()
        prompt = d.get_prompt_text("bam")
        assert "sɛnɛ" in prompt
        assert "kaba" in prompt


class TestDataCollator:
    """Tests for ``src.data.feature_extractor.DataCollatorSpeechSeq2SeqWithPadding``."""

    def test_collator_pads_labels(self):
        """DataCollator should pad labels and replace pad tokens with -100."""
        from unittest.mock import MagicMock

        import torch

        from src.data.feature_extractor import DataCollatorSpeechSeq2SeqWithPadding

        # Mock processor so no real feature extractor / tokenizer is needed.
        processor = MagicMock()
        processor.feature_extractor.pad.return_value = {
            "input_features": torch.zeros(2, 80, 3000)
        }
        # Simulate padded labels batch: pad id 0, and attention_mask marks
        # three padded positions in total (one in row 0, two in row 1).
        padded_labels = MagicMock()
        padded_labels.input_ids = torch.tensor([[1, 2, 3, 0], [1, 4, 0, 0]])
        padded_labels.attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
        processor.tokenizer.pad.return_value = padded_labels

        collator = DataCollatorSpeechSeq2SeqWithPadding(
            processor=processor,
            decoder_start_token_id=1,
        )
        features = [
            {"input_features": np.zeros((80, 3000)), "labels": [1, 2, 3]},
            {"input_features": np.zeros((80, 3000)), "labels": [1, 4]},
        ]
        batch = collator(features)

        assert "labels" in batch
        flat_labels = batch["labels"].flatten().tolist()
        # -100 should appear exactly where attention_mask is 0, in every row.
        # The count is unaffected if the collator drops a leading
        # decoder-start column, since that column contains no padding.
        assert flat_labels.count(-100) == 3
        # No raw pad id may survive in the labels.
        assert 0 not in flat_labels