import os import shutil import tempfile import unittest from unittest.mock import patch import torch from transformers import LlamaConfig from specforge.modeling.draft.llama3_eagle import ( LlamaAttention, LlamaForCausalLMEagle3, LlamaMLP, LlamaRMSNorm, ) # from model_module import LlamaForCausalLMEagle3 class TestLlamaForCausalLMEagle3Loading(unittest.TestCase): def setUp(self): """Set up the test environment before each test.""" self.temp_dir = tempfile.mkdtemp() config_dict = { "architectures": ["LlamaForCausalLM"], "bos_token_id": 128000, "eos_token_id": 128001, "hidden_act": "silu", "hidden_size": 4096, "initializer_range": 0.02, "intermediate_size": 14336, "max_position_embeddings": 2048, "model_type": "llama", "num_attention_heads": 32, "num_key_value_heads": 8, "num_hidden_layers": 1, "pad_token_id": 0, "rms_norm_eps": 1e-05, "tie_word_embeddings": False, "torch_dtype": "float16", "transformers_version": "4.28.1", "use_cache": True, "vocab_size": 128256, "draft_vocab_size": 32000, } self.config = LlamaConfig(**config_dict) def tearDown(self): shutil.rmtree(self.temp_dir) def test_model_initialization(self): model = LlamaForCausalLMEagle3(self.config) self.assertIsInstance(model.midlayer.self_attn, LlamaAttention) self.assertIsInstance(model.midlayer.mlp, LlamaMLP) self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm) self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm) self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm) self.assertEqual(model.midlayer.hidden_size, self.config.hidden_size) def test_save_pretrained(self): """Test the model's save_pretrained functionality.""" model = LlamaForCausalLMEagle3(self.config) self.config.save_pretrained(self.temp_dir) model_path = os.path.join(self.temp_dir, "pytorch_model.bin") torch.save(model.state_dict(), model_path) self.assertTrue(os.path.exists(os.path.join(self.temp_dir, "config.json"))) self.assertTrue(os.path.exists(model_path)) @patch("transformers.modeling_utils.PreTrainedModel.from_pretrained") def test_from_pretrained_mock(self, mock_from_pretrained): """mock""" mock_model = LlamaForCausalLMEagle3(self.config) mock_from_pretrained.return_value = mock_model loaded_model = LlamaForCausalLMEagle3.from_pretrained(self.temp_dir) mock_from_pretrained.assert_called_once_with(self.temp_dir) self.assertIsInstance(loaded_model, LlamaForCausalLMEagle3) def test_model_forward_pass(self): """forward""" model = LlamaForCausalLMEagle3(self.config) model.eval() batch_size = 2 seq_len = 10 input_emb = torch.randn(batch_size, seq_len, self.config.hidden_size) hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size * 3) attention_mask = torch.ones(batch_size, seq_len) with torch.no_grad(): outputs = model( inputs_embeds=input_emb, hidden_states=hidden_states, attention_mask=attention_mask, ) self.assertEqual(outputs.shape, (batch_size, seq_len, self.config.hidden_size)) def test_state_dict_compatibility(self): model1 = LlamaForCausalLMEagle3(self.config) model2 = LlamaForCausalLMEagle3(self.config) state_dict = model1.state_dict() model2.load_state_dict(state_dict) for name, param1 in model1.named_parameters(): param2 = dict(model2.named_parameters())[name] self.assertTrue(torch.equal(param1, param2)) def test_config_validation(self): invalid_config = LlamaConfig( vocab_size=1000, hidden_size=127, num_attention_heads=4, num_key_value_heads=2, ) with self.assertRaises(AttributeError): LlamaForCausalLMEagle3(invalid_config) if __name__ == "__main__": suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(TestLlamaForCausalLMEagle3Loading)) runner = unittest.TextTestRunner(verbosity=2) runner.run(suite)