import os
import shutil
import tempfile
import unittest
from unittest.mock import patch
import torch
from transformers import LlamaConfig
from specforge.modeling.draft.llama3_eagle import (
    LlamaAttention,
    LlamaForCausalLMEagle3,
    LlamaMLP,
    LlamaRMSNorm,
)


class TestLlamaForCausalLMEagle3Loading(unittest.TestCase):
    """Unit tests for constructing, saving, and loading LlamaForCausalLMEagle3."""
    def setUp(self):
        """Set up a temporary directory and a minimal config before each test."""
        self.temp_dir = tempfile.mkdtemp()
        # Llama-3-8B-style settings, trimmed to one hidden layer to keep tests fast.
        config_dict = {
            "architectures": ["LlamaForCausalLM"],
            "bos_token_id": 128000,
            "eos_token_id": 128001,
            "hidden_act": "silu",
            "hidden_size": 4096,
            "initializer_range": 0.02,
            "intermediate_size": 14336,
            "max_position_embeddings": 2048,
            "model_type": "llama",
            "num_attention_heads": 32,
            "num_key_value_heads": 8,
            "num_hidden_layers": 1,
            "pad_token_id": 0,
            "rms_norm_eps": 1e-05,
            "tie_word_embeddings": False,
            "torch_dtype": "float16",
            "transformers_version": "4.28.1",
            "use_cache": True,
            "vocab_size": 128256,
            # Eagle3 drafts into a reduced vocabulary.
            "draft_vocab_size": 32000,
        }
        self.config = LlamaConfig(**config_dict)

    def tearDown(self):
        """Remove the temporary directory after each test."""
        shutil.rmtree(self.temp_dir)

    def test_model_initialization(self):
        """Test that the decoder layer submodules are built with the expected types."""
        model = LlamaForCausalLMEagle3(self.config)
        self.assertIsInstance(model.midlayer.self_attn, LlamaAttention)
        self.assertIsInstance(model.midlayer.mlp, LlamaMLP)
        self.assertIsInstance(model.midlayer.hidden_norm, LlamaRMSNorm)
        self.assertIsInstance(model.midlayer.input_layernorm, LlamaRMSNorm)
        self.assertIsInstance(model.midlayer.post_attention_layernorm, LlamaRMSNorm)
        self.assertEqual(model.midlayer.hidden_size, self.config.hidden_size)

    def test_save_pretrained(self):
        """Test the model's save_pretrained functionality."""
        model = LlamaForCausalLMEagle3(self.config)
        self.config.save_pretrained(self.temp_dir)
        model_path = os.path.join(self.temp_dir, "pytorch_model.bin")
        torch.save(model.state_dict(), model_path)
        self.assertTrue(os.path.exists(os.path.join(self.temp_dir, "config.json")))
        self.assertTrue(os.path.exists(model_path))
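
    def test_config_round_trip(self):
        """Sketch: save_pretrained followed by from_pretrained preserves the config.

        A minimal check using only the transformers config API; assumes the
        extra draft_vocab_size field is serialized like any other config kwarg.
        """
        self.config.save_pretrained(self.temp_dir)
        reloaded = LlamaConfig.from_pretrained(self.temp_dir)
        # Spot-check a few representative fields, including the Eagle3-specific one.
        self.assertEqual(reloaded.hidden_size, self.config.hidden_size)
        self.assertEqual(reloaded.vocab_size, self.config.vocab_size)
        self.assertEqual(reloaded.draft_vocab_size, self.config.draft_vocab_size)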

    @patch("transformers.modeling_utils.PreTrainedModel.from_pretrained")
    def test_from_pretrained_mock(self, mock_from_pretrained):
        """Test from_pretrained with the base class loader mocked out."""
        mock_model = LlamaForCausalLMEagle3(self.config)
        mock_from_pretrained.return_value = mock_model
        loaded_model = LlamaForCausalLMEagle3.from_pretrained(self.temp_dir)
        mock_from_pretrained.assert_called_once_with(self.temp_dir)
        self.assertIsInstance(loaded_model, LlamaForCausalLMEagle3)

    def test_model_forward_pass(self):
        """Test the forward pass output shape on random inputs."""
        model = LlamaForCausalLMEagle3(self.config)
        model.eval()
        batch_size = 2
        seq_len = 10
        input_emb = torch.randn(batch_size, seq_len, self.config.hidden_size)
        # Eagle3 consumes hidden states from three target-model layers,
        # concatenated along the feature dimension (hence hidden_size * 3).
        hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size * 3)
        attention_mask = torch.ones(batch_size, seq_len)
        with torch.no_grad():
            outputs = model(
                inputs_embeds=input_emb,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
            )
        self.assertEqual(outputs.shape, (batch_size, seq_len, self.config.hidden_size))
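
    def test_model_forward_pass_with_padding(self):
        """Sketch: forward pass with a right-padded attention mask.

        Assumes the same forward signature as test_model_forward_pass; only
        the mask differs, so the output shape should be unchanged.
        """
        model = LlamaForCausalLMEagle3(self.config)
        model.eval()
        batch_size = 2
        seq_len = 10
        input_emb = torch.randn(batch_size, seq_len, self.config.hidden_size)
        hidden_states = torch.randn(batch_size, seq_len, self.config.hidden_size * 3)
        # Mark the last two positions of every sequence as padding.
        attention_mask = torch.ones(batch_size, seq_len)
        attention_mask[:, -2:] = 0
        with torch.no_grad():
            outputs = model(
                inputs_embeds=input_emb,
                hidden_states=hidden_states,
                attention_mask=attention_mask,
            )
        self.assertEqual(outputs.shape, (batch_size, seq_len, self.config.hidden_size))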

    def test_state_dict_compatibility(self):
        """Test that a state dict from one instance loads into another identically."""
        model1 = LlamaForCausalLMEagle3(self.config)
        model2 = LlamaForCausalLMEagle3(self.config)
        state_dict = model1.state_dict()
        model2.load_state_dict(state_dict)
        # Build the name -> parameter mapping once rather than inside the loop.
        params2 = dict(model2.named_parameters())
        for name, param1 in model1.named_parameters():
            self.assertTrue(torch.equal(param1, params2[name]))
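
    def test_state_dict_disk_round_trip(self):
        """Sketch: persist a state dict to disk and reload it into a fresh model.

        Combines the save and load paths above; assumes the checkpoint
        contains only tensors so a plain torch.load suffices.
        """
        model1 = LlamaForCausalLMEagle3(self.config)
        path = os.path.join(self.temp_dir, "round_trip.bin")
        torch.save(model1.state_dict(), path)
        model2 = LlamaForCausalLMEagle3(self.config)
        model2.load_state_dict(torch.load(path))
        params2 = dict(model2.named_parameters())
        for name, param1 in model1.named_parameters():
            self.assertTrue(torch.equal(param1, params2[name]))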

    def test_config_validation(self):
        """Test that a config without Eagle3-specific fields is rejected."""
        # A plain LlamaConfig lacks draft_vocab_size, so construction is
        # expected to fail with an AttributeError when the model reads it.
        invalid_config = LlamaConfig(
            vocab_size=1000,
            hidden_size=127,
            num_attention_heads=4,
            num_key_value_heads=2,
        )
        with self.assertRaises(AttributeError):
            LlamaForCausalLMEagle3(invalid_config)


if __name__ == "__main__":
    # unittest.makeSuite was deprecated in Python 3.11 and removed in 3.13;
    # load the tests through TestLoader instead.
    suite = unittest.TestLoader().loadTestsFromTestCase(
        TestLlamaForCausalLMEagle3Loading
    )
    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(suite)