"""
Unit Test Suite for TinyBert-CNN Intent Classifier Pipeline.

Tests: model init, dataset tokenization, forward pass, predict,
compound splitter, dataset generator output, and auto_trainer state I/O.
"""

import json
import os
import sys
import tempfile
import unittest

import torch
import pandas as pd

# Put the directory containing this file at the front of sys.path so the
# project modules imported below (and lazily inside tests) resolve even when
# the runner was started from another working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from TinyBert import IntentClassifier, IntentDataset, CompoundSentenceSplitter, TinyBertCNN  # noqa: E402


# ─────────────────────────────────────────────────────────────────────
# 1. MODEL INITIALIZATION
# ─────────────────────────────────────────────────────────────────────

class TestModelInit(unittest.TestCase):
    """Construction-time checks for a 5-class IntentClassifier."""

    @classmethod
    def setUpClass(cls):
        # Built once and shared: every test below only inspects attributes.
        cls.classifier = IntentClassifier(num_classes=5)

    def test_model_instance(self):
        model = self.classifier.model
        self.assertIsInstance(model, TinyBertCNN)

    def test_num_classes(self):
        self.assertEqual(self.classifier.num_classes, 5)

    def test_device_assigned(self):
        self.assertIsNotNone(self.classifier.device)

    def test_tokenizer_loaded(self):
        self.assertIsNotNone(self.classifier.tokenizer)

    def test_model_has_batchnorm(self):
        """The model exposes `batchnorms` with one entry per filter size (3)."""
        model = self.classifier.model
        self.assertTrue(hasattr(model, 'batchnorms'))
        self.assertEqual(len(model.batchnorms), 3)

    def test_model_has_hidden_fc(self):
        """Both the `fc_hidden` and `bn_hidden` attributes exist on the model."""
        model = self.classifier.model
        for attr_name in ('fc_hidden', 'bn_hidden'):
            self.assertTrue(hasattr(model, attr_name))
# ─────────────────────────────────────────────────────────────────────
# 2. INTENT DATASET
# ─────────────────────────────────────────────────────────────────────

class TestIntentDataset(unittest.TestCase):
    """Test tokenization and tensor shapes from IntentDataset."""

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=5)
        # Two rows whose session_context uses the compact "key:value | ..." format.
        cls.sample_data = [
            {'student_input': 'How do I use for loops?',
             'session_context': 'topic:For Loops | prev:If/Else | ability:If/Else:85% | emotion:engaged | pace:normal | slides:14,15,16',
             'label': 0},
            {'student_input': "What's the weather?",
             'session_context': 'topic:Variables | prev:None | ability:N/A | emotion:bored | pace:slow | slides:5,6,7',
             'label': 1},
        ]
        cls.dataset = IntentDataset(cls.sample_data, cls.classifier.tokenizer, max_length=128)

    def test_dataset_length(self):
        self.assertEqual(len(self.dataset), 2)

    def test_output_keys(self):
        item = self.dataset[0]
        self.assertIn('input_ids', item)
        self.assertIn('attention_mask', item)
        self.assertIn('labels', item)

    def test_tensor_shapes(self):
        """Items are padded/truncated to the requested max_length (128)."""
        item = self.dataset[0]
        self.assertEqual(item['input_ids'].shape, torch.Size([128]))
        self.assertEqual(item['attention_mask'].shape, torch.Size([128]))

    def test_label_type(self):
        item = self.dataset[0]
        self.assertEqual(item['labels'].dtype, torch.long)

    def test_token_type_ids_present(self):
        """TinyBERT should produce token_type_ids for sentence pairs.

        When the dataset does not emit the key, the test is reported as
        skipped rather than silently passing. That keeps the absence visible
        in the test report.
        """
        item = self.dataset[0]
        if 'token_type_ids' not in item:
            self.skipTest("IntentDataset item has no 'token_type_ids'")
        self.assertEqual(item['token_type_ids'].shape, torch.Size([128]))

    def test_handles_string_labels(self):
        """A string intent name is converted to its id ('Pace-Related' -> 3)."""
        data = [{'student_input': 'test', 'session_context': 'ctx', 'label': 'Pace-Related'}]
        ds = IntentDataset(data, self.classifier.tokenizer)
        item = ds[0]
        self.assertEqual(item['labels'].item(), 3)
# ─────────────────────────────────────────────────────────────────────
# 3. FORWARD PASS
# ─────────────────────────────────────────────────────────────────────

class TestForwardPass(unittest.TestCase):
    """Feed random token ids through TinyBertCNN and check the logits shape."""

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=5)

    def _forward(self, batch_size, seq_len, with_token_type_ids=False):
        """Build a dummy batch on the classifier's device and return eval-mode logits.

        Token ids are drawn uniformly from [0, 1000) and the attention mask is
        all ones. When ``with_token_type_ids`` is true, an all-zero
        ``token_type_ids`` tensor is passed as a keyword argument.
        """
        device = self.classifier.device
        ids = torch.randint(0, 1000, (batch_size, seq_len)).to(device)
        mask = torch.ones(batch_size, seq_len, dtype=torch.long).to(device)
        extra_kwargs = {}
        if with_token_type_ids:
            extra_kwargs['token_type_ids'] = torch.zeros(
                batch_size, seq_len, dtype=torch.long).to(device)

        model = self.classifier.model
        model.eval()
        with torch.no_grad():
            return model(ids, mask, **extra_kwargs)

    def test_output_shape(self):
        logits = self._forward(4, 128)
        self.assertEqual(logits.shape, torch.Size([4, 5]))

    def test_output_with_token_type_ids(self):
        logits = self._forward(2, 128, with_token_type_ids=True)
        self.assertEqual(logits.shape, torch.Size([2, 5]))

    def test_single_sample(self):
        """A batch of one must not crash in eval mode (a BatchNorm edge case)."""
        logits = self._forward(1, 128)
        self.assertEqual(logits.shape, torch.Size([1, 5]))
# ─────────────────────────────────────────────────────────────────────
# 4. PREDICT
# ─────────────────────────────────────────────────────────────────────

class TestPredict(unittest.TestCase):
    """Test the predict() method with real text."""

    @classmethod
    def setUpClass(cls):
        cls.classifier = IntentClassifier(num_classes=5)

    def _assert_prediction_shapes(self, preds, probs, n_inputs):
        """Assert one prediction and one 5-way probability row per input text."""
        self.assertEqual(len(preds), n_inputs)
        self.assertEqual(probs.shape[0], n_inputs)
        self.assertEqual(probs.shape[1], 5)

    def test_predict_with_context(self):
        preds, probs = self.classifier.predict(
            ["How do loops work?"],
            ["topic:For Loops | prev:None | ability:N/A | emotion:neutral | pace:normal | slides:10,11,12"],
        )
        self._assert_prediction_shapes(preds, probs, 1)

    def test_predict_without_context(self):
        preds, probs = self.classifier.predict(["I'm feeling frustrated"])
        self._assert_prediction_shapes(preds, probs, 1)

    def test_predict_empty_string(self):
        """Empty input should not crash."""
        preds, probs = self.classifier.predict([""])
        self._assert_prediction_shapes(preds, probs, 1)

    def test_predict_multiple(self):
        preds, probs = self.classifier.predict(
            ["Hello", "Can you repeat?", "Speed up please"],
            ["ctx1", "ctx2", "ctx3"],
        )
        self._assert_prediction_shapes(preds, probs, 3)


# ─────────────────────────────────────────────────────────────────────
# 5. COMPOUND SENTENCE SPLITTER
# ─────────────────────────────────────────────────────────────────────

class TestCompoundSplitter(unittest.TestCase):
    """Test the CompoundSentenceSplitter edge cases."""

    @classmethod
    def setUpClass(cls):
        cls.splitter = CompoundSentenceSplitter()

    def test_compound_question_splits(self):
        result = self.splitter.split_compound_question(
            "What is a variable and how do I use it?"
        )
        self.assertGreaterEqual(len(result), 2)

    def test_single_question_no_split(self):
        result = self.splitter.split_compound_question("How do loops work?")
        self.assertEqual(len(result), 1)

    def test_non_question_no_split(self):
        result = self.splitter.split_compound_question("I like programming.")
        self.assertEqual(len(result), 1)

    def test_multiple_question_marks(self):
        result = self.splitter.split_compound_question("What is a loop? How does it work?")
        self.assertEqual(len(result), 2)

    def test_empty_string(self):
        """An empty string still yields exactly one piece."""
        result = self.splitter.split_compound_question("")
        self.assertEqual(len(result), 1)


# ─────────────────────────────────────────────────────────────────────
# 6. DATASET GENERATOR
# ─────────────────────────────────────────────────────────────────────

class TestDatasetGenerator(unittest.TestCase):
    """Test that the dataset generator produces correct output."""

    @classmethod
    def setUpClass(cls):
        # Generate a small dataset inside a throwaway directory. The CSVs are
        # read back from the relative path data/, so build_dataset presumably
        # writes relative to the cwd as well.
        from dataset_generator import build_dataset
        cls.original_dir = os.getcwd()
        cls._tmp_handle = tempfile.TemporaryDirectory()
        cls.tmp_dir = cls._tmp_handle.name
        os.chdir(cls.tmp_dir)
        try:
            build_dataset(num_samples_per_class=20)
            cls.train_df = pd.read_csv('data/train.csv')
            cls.val_df = pd.read_csv('data/val.csv')
            cls.test_df = pd.read_csv('data/test.csv')
        except BaseException:
            # unittest does not call tearDownClass when setUpClass raises.
            # Undo the chdir here, otherwise every later test class would run
            # inside the temp directory.
            cls._restore_environment()
            raise

    @classmethod
    def tearDownClass(cls):
        cls._restore_environment()

    @classmethod
    def _restore_environment(cls):
        """Return to the original working directory, then delete the temp dir."""
        # Leave the directory first: some platforms (e.g. Windows) refuse to
        # delete the current working directory.
        os.chdir(cls.original_dir)
        cls._tmp_handle.cleanup()

    def _splits(self):
        """Map each split name to its loaded DataFrame."""
        return {'train': self.train_df, 'val': self.val_df, 'test': self.test_df}

    def test_columns_exist(self):
        for split_name, df in self._splits().items():
            for col in ['student_input', 'session_context', 'label', 'intent_name']:
                with self.subTest(split=split_name, column=col):
                    self.assertIn(col, df.columns)

    def test_three_splits_exist(self):
        self.assertGreater(len(self.train_df), 0)
        self.assertGreater(len(self.val_df), 0)
        self.assertGreater(len(self.test_df), 0)

    def test_all_classes_present(self):
        all_labels = set(self.train_df['label'].unique())
        self.assertEqual(all_labels, {0, 1, 2, 3, 4})

    def test_compact_context_format(self):
        ctx = self.train_df.iloc[0]['session_context']
        self.assertIn('topic:', ctx)
        self.assertIn('prev:', ctx)
        self.assertIn('emotion:', ctx)

    def test_no_empty_inputs(self):
        # pandas.read_csv loads empty fields as NaN by default, so isna()
        # catches both missing and empty cells.
        for split_name, df in self._splits().items():
            with self.subTest(split=split_name):
                self.assertFalse(df['student_input'].isna().any())
                self.assertFalse(df['session_context'].isna().any())
# ─────────────────────────────────────────────────────────────────────
# 7. AUTO TRAINER STATE
# ─────────────────────────────────────────────────────────────────────

class TestAutoTrainerState(unittest.TestCase):
    """Test load_state / save_state round-trip.

    Both tests touch auto_trainer.STATE_FILE on disk. setUp snapshots the
    file's exact bytes, or records that it was absent. A registered cleanup
    then restores that snapshot even when an assertion fails. This covers a
    pre-existing empty file and removes a file the test itself created.
    """

    def setUp(self):
        from auto_trainer import STATE_FILE
        self.state_file = STATE_FILE
        self._backup = None
        if os.path.exists(STATE_FILE):
            # Binary mode keeps the restore byte-exact (no newline translation).
            with open(STATE_FILE, 'rb') as f:
                self._backup = f.read()
        self.addCleanup(self._restore_state_file)

    def _restore_state_file(self):
        """Write back the snapshot, or delete the file if none existed before."""
        # `is not None` (not truthiness) so an originally empty file is restored too.
        if self._backup is not None:
            with open(self.state_file, 'wb') as f:
                f.write(self._backup)
        elif os.path.exists(self.state_file):
            os.remove(self.state_file)

    def test_state_round_trip(self):
        from auto_trainer import load_state, save_state
        test_state = {"sessions_since_last_train": 42, "total_sessions": 100}
        save_state(test_state)
        loaded = load_state()
        self.assertEqual(loaded["sessions_since_last_train"], 42)
        self.assertEqual(loaded["total_sessions"], 100)

    def test_default_state(self):
        from auto_trainer import load_state
        # Force the "no state file" path; the cleanup restores any backup.
        if os.path.exists(self.state_file):
            os.remove(self.state_file)
        state = load_state()
        self.assertEqual(state["sessions_since_last_train"], 0)
        self.assertEqual(state["total_sessions"], 0)


if __name__ == '__main__':
    unittest.main(verbosity=2)