| | """Testing suite for the PyTorch Bamba model.""" |
| |
|
| | import inspect |
| | import unittest |
| |
|
| | import pytest |
| |
|
| | from transformers import AutoTokenizer, BambaConfig, is_torch_available |
| | from transformers.testing_utils import ( |
| | Expectations, |
| | require_deterministic_for_xpu, |
| | require_torch, |
| | require_torch_accelerator, |
| | slow, |
| | torch_device, |
| | ) |
| |
|
| | from ...generation.test_utils import GenerationTesterMixin |
| | from ...test_configuration_common import ConfigTester |
| | from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor |
| | from ...test_pipeline_mixin import PipelineTesterMixin |
| |
|
| |
|
| | if is_torch_available(): |
| | import torch |
| |
|
| | from transformers import ( |
| | BambaForCausalLM, |
| | BambaModel, |
| | ) |
| | from transformers.models.bamba.modeling_bamba import ( |
| | HybridMambaAttentionDynamicCache, |
| | ) |
| |
|
| |
|
| | class BambaModelTester: |
| | def __init__( |
| | self, |
| | parent, |
| | batch_size=13, |
| | seq_length=7, |
| | is_training=True, |
| | use_input_mask=True, |
| | use_labels=True, |
| | vocab_size=99, |
| | hidden_size=32, |
| | num_hidden_layers=4, |
| | num_attention_heads=4, |
| | num_key_value_heads=2, |
| | intermediate_size=64, |
| | hidden_act="silu", |
| | attention_dropout=0.0, |
| | attn_layer_indices=None, |
| | attn_rotary_emb=8, |
| | max_position_embeddings=512, |
| | type_vocab_size=16, |
| | initializer_range=0.02, |
| | num_labels=3, |
| | pad_token_id=0, |
| | mamba_n_groups=1, |
| | mamba_n_heads=16, |
| | mamba_d_state=16, |
| | mamba_d_conv=4, |
| | mamba_expand=2, |
| | mamba_chunk_size=16, |
| | scope=None, |
| | ): |
| | self.parent = parent |
| | self.batch_size = batch_size |
| | self.seq_length = seq_length |
| | self.is_training = is_training |
| | self.use_input_mask = use_input_mask |
| | self.use_labels = use_labels |
| | self.vocab_size = vocab_size |
| | self.hidden_size = hidden_size |
| | self.num_hidden_layers = num_hidden_layers |
| | self.num_attention_heads = num_attention_heads |
| | self.num_key_value_heads = num_key_value_heads |
| | self.intermediate_size = intermediate_size |
| | self.hidden_act = hidden_act |
| | self.attention_dropout = attention_dropout |
| | self.attn_layer_indices = attn_layer_indices |
| | self.attn_rotary_emb = attn_rotary_emb |
| | self.max_position_embeddings = max_position_embeddings |
| | self.type_vocab_size = type_vocab_size |
| | self.initializer_range = initializer_range |
| | self.num_labels = num_labels |
| | self.pad_token_id = pad_token_id |
| | self.scope = scope |
| | self.mamba_n_groups = mamba_n_groups |
| | self.mamba_n_heads = mamba_n_heads |
| | self.mamba_d_state = mamba_d_state |
| | self.mamba_d_conv = mamba_d_conv |
| | self.mamba_expand = mamba_expand |
| | self.mamba_chunk_size = mamba_chunk_size |
| |
|
    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device))

        token_labels = None
        if self.use_labels:
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)

        config = self.get_config()

        return config, input_ids, input_mask, token_labels
|
    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            input_mask,
            token_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
        return config, inputs_dict
|
    def get_config(self):
        # force at least 4 layers so that attention and mamba layers can be interleaved
        if self.num_hidden_layers < 4:
            self.num_hidden_layers = 4
        if self.attn_layer_indices is None:
            d = [x for x in range(2, self.num_hidden_layers) if self.num_hidden_layers % x == 0]
            if len(d) == 0:
                raise ValueError("num_hidden_layers is prime, cannot automatically set attn_layer_indices.")
            d = d[-1]  # largest divisor
            self.attn_layer_indices = [x + 1 for x in range(0, self.num_hidden_layers, d)]

        return BambaConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            attention_dropout=self.attention_dropout,
            attn_layer_indices=self.attn_layer_indices,
            attn_rotary_emb=self.attn_rotary_emb,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            pad_token_id=self.pad_token_id,
            mamba_n_groups=self.mamba_n_groups,
            mamba_n_heads=self.mamba_n_heads,
            mamba_d_state=self.mamba_d_state,
            mamba_d_conv=self.mamba_d_conv,
            mamba_expand=self.mamba_expand,
            mamba_chunk_size=self.mamba_chunk_size,
        )
|
    def create_and_check_model(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
    ):
        model = BambaModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids)
        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
    def create_and_check_for_causal_lm(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
    ):
        model = BambaForCausalLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, labels=token_labels)
        result = model(input_ids, attention_mask=input_mask)
        result = model(input_ids, labels=token_labels)
        result = model(input_ids)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        input_mask,
        token_labels,
    ):
        model = BambaForCausalLM(config=config)
        model.to(torch_device)
        model.eval()

        # first forward pass: the cache has to be created explicitly for the model to return it
        past_key_values = HybridMambaAttentionDynamicCache(
            config, input_ids.shape[0], model.dtype, device=model.device
        )
        outputs = model(
            input_ids,
            attention_mask=input_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values

        # create hypothetical next tokens and the matching attention mask
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
        next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)

        # append them to the original inputs
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)

        output_from_no_past = model(
            next_input_ids,
            attention_mask=next_attention_mask,
            output_hidden_states=True,
        )["hidden_states"][0]
        output_from_past = model(
            next_tokens,
            attention_mask=next_attention_mask,
            past_key_values=past_key_values,
            output_hidden_states=True,
            cache_position=torch.arange(
                input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
            ),
        )["hidden_states"][0]

        # select a random slice of the hidden states
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # the outputs with and without cache should match on the selected slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))


@require_torch
class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (BambaModel, BambaForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {
            "feature-extraction": BambaModel,
            "text-generation": BambaForCausalLM,
        }
        if is_torch_available()
        else {}
    )
    test_headmasking = False
    test_pruning = False
    fx_compatible = False

    # splits used by the model-parallelism and cpu/disk offload common tests
    model_split_percents = [0.5, 0.7, 0.8]

    def setUp(self):
        self.model_tester = BambaModelTester(self)
        self.config_tester = ConfigTester(self, config_class=BambaConfig, hidden_size=64)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)
|
    def test_for_causal_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
|
    def test_decoder_model_past_with_large_inputs(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
|
    def test_initialization(self):
        r"""
        Overriding the test_initialization test because the `A_log` and `D` params of the Bamba mixer are initialized
        differently from the rest of the model.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        configs_no_init = _config_zero_init(config)
        for model_class in self.all_model_classes:
            model = model_class(config=configs_no_init)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if "A_log" in name:
                        A = torch.arange(1, config.mamba_n_heads + 1, dtype=torch.float32)
                        torch.testing.assert_close(param.data, torch.log(A), rtol=1e-5, atol=1e-5)
                    elif "D" in name:
                        D = torch.ones(config.mamba_n_heads, dtype=torch.float32)
                        torch.testing.assert_close(param.data, D, rtol=1e-5, atol=1e-5)
                    else:
                        self.assertIn(
                            ((param.data.mean() * 1e9).round() / 1e9).item(),
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
|
    def test_mismatched_shapes_have_properly_initialized_weights(self):
        r"""
        Overriding the test_mismatched_shapes_have_properly_initialized_weights test because the `A_log` and `D`
        params of the Bamba mixer are initialized differently, which is already covered by `test_initialization`.
        """
        self.skipTest(reason="Cumbersome and redundant for Bamba")
|
    def test_attention_outputs(self):
        r"""
        Overriding the test_attention_outputs test because the Bamba model returns attentions only for its attention
        layers.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)
        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

        expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), expected_num_attentions)

            # check that output_attentions also works when set on the config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
            attentions = outputs.attentions
            self.assertEqual(len(attentions), expected_num_attentions)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )
            out_len = len(outputs)

            # requesting hidden states on top of attentions should add exactly one output
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()
            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), expected_num_attentions)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )
|
    def test_batching_equivalence(self):
        # run the common batching-equivalence test without the custom (lower-triangular) input mask
        orig = self.model_tester.use_input_mask
        self.model_tester.use_input_mask = False
        super().test_batching_equivalence()
        self.model_tester.use_input_mask = orig
|
    # Overrides the generic left-padding compatibility test; the forward passes below are run without cache.
    @pytest.mark.generate
    def test_left_padding_compatibility(self):
        # first, filter out models that don't support left padding
        # - the model must have generative capabilities
        if len(self.all_generative_model_classes) == 0:
            self.skipTest(reason="No generative architecture available for this model.")

        # - the model must support padding
        if not self.has_attentions:
            self.skipTest(reason="This model doesn't support padding.")

        # - the model must be a decoder-only architecture
        decoder_only_classes = []
        for model_class in self.all_generative_model_classes:
            config, _ = self.prepare_config_and_inputs_for_generate()
            if config.is_encoder_decoder:
                continue
            else:
                decoder_only_classes.append(model_class)
        if len(decoder_only_classes) == 0:
            self.skipTest(reason="No decoder-only architecture available for this model.")

        # - decoder-only models derived from encoder-decoder architectures (detected through encoder-related config
        #   attributes) are not expected to support left-padding
        has_encoder_attributes = any(
            attr_name
            for attr_name in config.to_dict().keys()
            if attr_name.startswith("encoder") and attr_name != "encoder_no_repeat_ngram_size"
        )
        if has_encoder_attributes:
            self.skipTest(
                reason="The decoder-only derived from encoder-decoder models are not expected to support left-padding."
            )

        # then, test left-padding
        def _prepare_model_kwargs(input_ids, attention_mask, signature):
            model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
            if "position_ids" in signature:
                position_ids = torch.cumsum(attention_mask, dim=-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)
                model_kwargs["position_ids"] = position_ids
            if "cache_position" in signature:
                cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
                model_kwargs["cache_position"] = cache_position
            return model_kwargs

        for model_class in decoder_only_classes:
            config, inputs_dict = self.prepare_config_and_inputs_for_generate()
            input_ids = inputs_dict["input_ids"]

            # left padding needs an all-ones attention mask here, so the mask from `inputs_dict` is not used
            attention_mask = torch.ones_like(input_ids)

            model = model_class(config).to(torch_device).eval()
            signature = inspect.signature(model.forward).parameters.keys()

            # no cache, as some models require special cache classes to be initialized outside of `forward`
            model.generation_config.use_cache = False

            # without padding
            model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
            next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]

            # with left-padding (pad length 32)
            pad_token_id = (
                config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
            )
            pad_size = (input_ids.shape[0], 32)
            padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
            padded_input_ids = torch.cat((padding, input_ids), dim=1)
            padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
            model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
            next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

            # both runs should produce (nearly) identical logits for the last token
            torch.testing.assert_close(next_logits_wo_padding, next_logits_with_padding, rtol=1e-5, atol=1e-5)
|
@slow
@require_torch
@require_torch_accelerator
class BambaModelIntegrationTest(unittest.TestCase):
    model = None
    tokenizer = None
    # used to determine which accelerator the CI runners use (e.g. A100/A10 vs. T4)
    cuda_compute_capability_major_version = None

    @classmethod
    def setUpClass(cls):
        model_id = "ibm-fms/Bamba-9B"
        cls.model = BambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
        cls.tokenizer = AutoTokenizer.from_pretrained(model_id)

        # the tokenizer needs a pad token and left padding for the batched generation test
        cls.tokenizer.pad_token_id = cls.model.config.pad_token_id
        cls.tokenizer.padding_side = "left"

        if is_torch_available() and torch.cuda.is_available():
            # 8 is for A100 / A10, 7 is for T4
            cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0]
|
    def test_simple_generate(self):
        expectations = Expectations(
            {
                (
                    "cuda",
                    8,
                ): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all having a good time.",
                (
                    "rocm",
                    9,
                ): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
                (
                    "xpu",
                    3,
                ): "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all doing well. Today I",
            }
        )

        self.model.to(torch_device)

        input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
            "input_ids"
        ].to(torch_device)
        out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
        output_sentence = self.tokenizer.decode(out[0, :])
        expected = expectations.get_expectation()
        self.assertEqual(output_sentence, expected)

        # the logits check is only validated on compute capability 8.x (A100 / A10) runners
        if self.cuda_compute_capability_major_version == 8:
            with torch.no_grad():
                logits = self.model(input_ids=input_ids, logits_to_keep=40).logits

            EXPECTED_LOGITS_NO_GRAD = torch.tensor(
                [
                    149., 142., 146., 142., 143., 144., 142., 145.,
                    142., 146., 144., 146., 147., 147., 148., 145.,
                    147., 145., 145., 145., 145., 144., 144., 144.,
                    144., 145., 147., 146., 144., 144., 148., 147.,
                    148., 147., 147., 147., 146., 146., 148., 148.
                ], dtype=torch.bfloat16)

            torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1)
|
    @require_deterministic_for_xpu
    def test_simple_batched_generate_with_padding(self):
        # expected texts are keyed by (device type, major device/compute capability version), e.g. 8 for A100 / A10
        EXPECTED_TEXTS = Expectations(
            {
                ("cuda", 7): [],
                ("cuda", 8): [
                    "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
                    "!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
                ],
                ("rocm", 9): [
                    "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are doing well. I am here",
                    "!!!<|begin_of_text|>I am late! I need to be at the airport in 20 minutes! I",
                ],
                ("xpu", 3): [
                    "<|begin_of_text|>Hey how are you doing on this lovely evening? I hope you are all doing well. Today I",
                    "!!!<|begin_of_text|>I am late! I need to get to work! I have to get to the",
                ],
            }
        )
        EXPECTED_TEXT = EXPECTED_TEXTS.get_expectation()

        self.model.to(torch_device)

        inputs = self.tokenizer(
            ["Hey how are you doing on this lovely evening?", "I am late! I need to"],
            padding=True,
            return_tensors="pt",
        ).to(torch_device)
        out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
        output_sentences = self.tokenizer.batch_decode(out)
        self.assertEqual(output_sentences[0], EXPECTED_TEXT[0])
        self.assertEqual(output_sentences[1], EXPECTED_TEXT[1])

        # the logits check is only validated on compute capability 8.x (A100 / A10) runners
        if self.cuda_compute_capability_major_version == 8:
            with torch.no_grad():
                logits = self.model(input_ids=inputs["input_ids"]).logits

            EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
                [
                    149., 142., 146., 142., 143., 144., 142., 145.,
                    142., 146., 144., 146., 147., 147., 148., 145.,
                    147., 145., 145., 145., 145., 144., 144., 144.,
                    144., 145., 147., 146., 144., 144., 148., 147.,
                    148., 147., 147., 147., 146., 146., 148., 148.
                ], dtype=torch.bfloat16)

            EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
                [
                    182., 178., 177., 174., 176., 176., 178., 178.,
                    177., 179., 176., 183., 180., 182., 179., 174.,
                    178., 176., 176., 175., 175., 175., 174., 173.,
                    174., 182., 180., 176., 177., 177., 180., 176.,
                    178., 177., 177., 175., 176., 177., 175., 177.
                ], dtype=torch.bfloat16)

            torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1)
            torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1)