| | import gc |
| | import unittest |
| | import weakref |
| | from unittest.mock import MagicMock |
| |
|
| | import torch |
| |
|
| | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, pipeline |
| | from transformers.generation.candidate_generator import ( |
| | AssistantToTargetTranslator, |
| | AssistantVocabTranslatorCache, |
| | UniversalSpeculativeDecodingGenerator, |
| | ) |
| | from transformers.testing_utils import require_torch, torch_device |
| |
|
| |
|
@require_torch
class TestAssistantToTargetTranslator(unittest.TestCase):
    """Unit tests for `AssistantToTargetTranslator` built from two small, partly overlapping vocabularies."""

    def setUp(self):
        # Mocks stand in for the tokenizers and the assistant model; only `get_vocab()` on the
        # tokenizers and `device` on the model are configured here.
        self.target_tokenizer = MagicMock()
        self.assistant_tokenizer = MagicMock()
        self.assistant_model = MagicMock(device=torch_device)

        # "hello", "world" and "foo" carry the same ids in both vocabularies; "bar" exists only
        # in the target vocabulary and "baz" only in the assistant vocabulary.
        self.target_vocab = {"hello": 0, "world": 1, "foo": 2, "bar": 3}
        self.assistant_vocab = {"hello": 0, "world": 1, "foo": 2, "baz": 4}
        self.target_tokenizer.get_vocab.return_value = self.target_vocab
        self.assistant_tokenizer.get_vocab.return_value = self.assistant_vocab
        self.target_vocab_size = 6

        # Translator under test, shared by every test method of this class.
        translator_kwargs = {
            "target_tokenizer": self.target_tokenizer,
            "assistant_tokenizer": self.assistant_tokenizer,
            "target_vocab_size": self.target_vocab_size,
            "assistant_model": self.assistant_model,
            "assistant_prune_lm_head": False,
        }
        self.translator = AssistantToTargetTranslator(**translator_kwargs)

    def test_get_assistant_to_target_input_ids(self):
        """Assistant ids 0-2 are expected to map to themselves, ids 3 and 4 to `SUPPRESS_TOKEN_ID`."""
        suppress = self.translator.SUPPRESS_TOKEN_ID
        expected_mapping = [0, 1, 2] + [suppress] * 2
        self.assertEqual(self.translator._assistant_to_target_input_ids.tolist(), expected_mapping)

    def test_get_suppress_input_ids(self):
        """Assistant ids 3 and 4 are expected to be reported as suppressed."""
        suppressed = self.translator._get_suppress_input_ids().tolist()
        self.assertEqual(suppressed, [3, 4])

    def test_get_target_ids(self):
        """Candidate ids are translated to target ids; the assistant-only id 4 becomes `SUPPRESS_TOKEN_ID`."""
        device = self.assistant_model.device
        suppress = self.translator.SUPPRESS_TOKEN_ID

        prompt_ids = [[0, 1, 2]]
        assistant_input_ids = torch.LongTensor(prompt_ids).to(device)
        target_input_ids = torch.LongTensor(prompt_ids).to(device)
        assistant_candidate_ids = torch.LongTensor([[0, 1, 2, 4]]).to(device)
        expected_target_ids = torch.LongTensor([[0, 1, 2, suppress]]).to(device)

        translated = self.translator.get_target_ids(assistant_input_ids, target_input_ids, assistant_candidate_ids)
        self.assertTrue(torch.equal(translated, expected_target_ids))

    def test_get_target_logits(self):
        """Logits of assistant ids 0-2 are expected at the same target ids, `FILTER_VALUE` everywhere else."""
        device = self.assistant_model.device
        filter_value = self.translator.FILTER_VALUE

        # One batch entry, one position, five assistant-vocabulary logits.
        assistant_logits = torch.FloatTensor([[[0.1, 0.2, 0.3, 0.4, filter_value]]]).to(device)

        # Start from an all-filtered target distribution and fill in the three shared tokens.
        expected_target_logits = torch.full((1, 1, self.target_vocab_size), filter_value).to(device)
        for token_id, logit in enumerate((0.1, 0.2, 0.3)):
            expected_target_logits[0, 0, token_id] = logit

        translated = self.translator.get_target_logits(assistant_logits)
        self.assertTrue(torch.equal(translated, expected_target_logits))
| |
|
| |
|
class MockTokenizer:
    """Minimal stand-in tokenizer; being a plain Python object, it can be the target of a weak reference."""

    def __init__(self, vocab=None):
        # A falsy `vocab` (None or an empty mapping) is replaced by a fresh empty dict.
        self._vocab = vocab if vocab else {}

    def get_vocab(self):
        """Return the token -> id mapping held by this tokenizer."""
        return self._vocab

    def __call__(self, text, add_special_tokens=True):
        """Split `text` on whitespace and look each piece up in the vocabulary.

        Pieces missing from the vocabulary map to id 0. `add_special_tokens` is accepted
        but not used.
        """
        lookup = self._vocab.get
        return {"input_ids": [lookup(piece, 0) for piece in text.split()]}
| |
|
| |
|
@require_torch
class TestAssistantVocabTranslatorCache(unittest.TestCase):
    """Tests for the translator cache exposed through `AssistantVocabTranslatorCache.get_translator`."""

    def setUp(self):
        # Start every test from an empty cache so cache-size assertions are not affected by other tests.
        AssistantVocabTranslatorCache._cache.clear()

        # Two unrelated (target, assistant) tokenizer pairs.
        self.target_tokenizer = MockTokenizer({"hello": 0, "world": 1})
        self.assistant_tokenizer = MockTokenizer({"hello": 0, "world": 1, "foo": 2})
        self.other_target_tokenizer = MockTokenizer({"foo": 2, "bar": 3})
        self.other_assistant_tokenizer = MockTokenizer({"baz": 4, "qux": 5})
        self.assistant_model = MagicMock(device=torch_device)

        self.target_vocab_size = 6

    def _get_translator(self, target_tokenizer, assistant_tokenizer):
        """Fetch a translator from the cache with the settings shared by every test in this class."""
        return AssistantVocabTranslatorCache.get_translator(
            target_tokenizer,
            assistant_tokenizer,
            target_vocab_size=self.target_vocab_size,
            assistant_model=self.assistant_model,
            assistant_prune_lm_head=False,
        )

    def test_same_instance_for_same_tokenizers(self):
        """Requesting a translator twice for the same tokenizer pair must return the same object."""
        first = self._get_translator(self.target_tokenizer, self.assistant_tokenizer)
        second = self._get_translator(self.target_tokenizer, self.assistant_tokenizer)
        self.assertIs(first, second, "Translators should be cached and identical")

    def test_different_instances_for_different_tokenizers(self):
        """Different tokenizer pairs must not share a translator."""
        first = self._get_translator(self.target_tokenizer, self.assistant_tokenizer)
        second = self._get_translator(self.other_target_tokenizer, self.other_assistant_tokenizer)
        self.assertIsNot(first, second, "Translators should differ for different tokenizers")

    def test_cache_with_weakref_key(self):
        """Check the cache size before and after dropping the local references and running `cleanup()`."""
        cache = AssistantVocabTranslatorCache._cache
        size_before = len(cache)
        target_tokenizer = MockTokenizer({"hello": 0})
        assistant_tokenizer = MockTokenizer({"hello": 0})

        # Creating the translator is expected to add exactly one cache entry.
        translator = self._get_translator(target_tokenizer, assistant_tokenizer)
        self.assertEqual(len(cache), size_before + 1)

        # Drop every local strong reference, collect garbage and let the cache prune itself.
        del target_tokenizer, assistant_tokenizer, translator
        gc.collect()
        AssistantVocabTranslatorCache.cleanup()

        # The test expects the entry to still be present after cleanup.
        self.assertEqual(len(cache), size_before + 1)

    def test_weakref_cache_cleanup(self):
        """Check which objects are still reachable after local references are dropped and `cleanup()` runs."""

        def create_translator():
            target_tokenizer = MockTokenizer({"hello": 0})
            assistant_tokenizer = MockTokenizer({"hello": 0})
            translator = self._get_translator(target_tokenizer, assistant_tokenizer)
            # Only weak references leave this function.
            tracked = (
                weakref.ref(translator),
                weakref.ref(target_tokenizer),
                weakref.ref(assistant_tokenizer),
            )
            del target_tokenizer, assistant_tokenizer, translator
            return tracked

        translator_ref, target_ref, assistant_ref = create_translator()

        gc.collect()
        AssistantVocabTranslatorCache.cleanup()

        # All three objects are expected to remain alive.
        self.assertIsNotNone(target_ref(), "Target tokenizer should still be alive due to strong references")
        self.assertIsNotNone(assistant_ref(), "Assistant tokenizer should still be alive due to strong references")
        self.assertIsNotNone(translator_ref(), "Translator should still be alive due to strong references")
| |
|
| |
|
@require_torch
class TestUniversalSpeculativeDecoding(unittest.TestCase):
    """Tests for `UniversalSpeculativeDecodingGenerator` using a tiny Llama target and a tiny Phi assistant."""

    @classmethod
    def setUpClass(cls):
        # Hub checkpoint names shared by every test in the class.
        cls.target_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
        cls.assistant_name = "hf-internal-testing/tiny-random-PhiForCausalLM"

    def setUp(self):
        self.target_tokenizer = AutoTokenizer.from_pretrained(self.target_name)
        self.target_config = AutoConfig.from_pretrained(self.target_name)
        self.assistant_model = AutoModelForCausalLM.from_pretrained(self.assistant_name).to(torch_device)
        self.assistant_tokenizer = AutoTokenizer.from_pretrained(self.assistant_name)

        self.generation_config = GenerationConfig()

        # Fall back to the EOS id when a tokenizer defines no pad or bos token.
        for tokenizer in (self.target_tokenizer, self.assistant_tokenizer):
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            if tokenizer.bos_token_id is None:
                tokenizer.bos_token_id = tokenizer.eos_token_id

        self.input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
        self.model_kwargs = {
            # `ones_like` already allocates on the device of `self.input_ids`.
            "attention_mask": torch.ones_like(self.input_ids),
        }
        atm_translator = AssistantVocabTranslatorCache.get_translator(
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            assistant_model=self.assistant_model,
            target_vocab_size=self.target_config.vocab_size,
        )
        self.generator = UniversalSpeculativeDecodingGenerator(
            input_ids=self.input_ids,
            assistant_model=self.assistant_model,
            target_tokenizer=self.target_tokenizer,
            assistant_tokenizer=self.assistant_tokenizer,
            generation_config=self.generation_config,
            model_kwargs=self.model_kwargs,
            atm_translator=atm_translator,
        )

    def test_basic_generation(self):
        """Test basic speculative decoding works"""
        input_text = "The quick brown fox"
        input_ids = self.target_tokenizer.encode(input_text, return_tensors="pt")
        self.generator.input_ids = input_ids
        candidates, scores = self.generator.get_candidates(input_ids)

        self.assertIsNotNone(candidates)
        self.assertIsNotNone(scores)
        self.assertTrue(torch.is_tensor(candidates))
        self.assertTrue(torch.is_tensor(scores))

    def test_mismatched_vocabularies(self):
        """Test handling of mismatched vocabularies between models"""
        # Build the lookups once: the filter below runs for every target token.
        assistant_vocab = self.assistant_tokenizer.get_vocab()
        special_tokens = set(self.target_tokenizer.all_special_tokens)

        # Pick a regular target token that the assistant tokenizer does not know.
        missing_token = next(
            (
                token
                for token in self.target_tokenizer.get_vocab()
                if token not in assistant_vocab and token not in special_tokens and "reserved_" not in token
            ),
            None,
        )
        # Fail with a readable message instead of a bare StopIteration when no such token exists.
        self.assertIsNotNone(missing_token, "No target token is missing from the assistant vocabulary")

        input_ids = torch.tensor([[self.target_tokenizer.convert_tokens_to_ids(missing_token)]])
        self.generator.input_ids = input_ids
        candidates, _ = self.generator.get_candidates(input_ids)
        self.assertIsNotNone(candidates)

    def test_speculation_depth(self):
        """Test different speculation depths"""
        input_ids = self.target_tokenizer.encode("Test text", return_tensors="pt")
        self.generator.input_ids = input_ids

        for depth in [1, 8, 17]:
            # `subTest` reports which depth failed and still runs the remaining depths.
            with self.subTest(depth=depth):
                self.generator.num_assistant_tokens = depth
                candidates, _ = self.generator.get_candidates(input_ids)
                self.assertLessEqual(candidates.shape[1] - input_ids.shape[1], depth)

    def test_device_consistency(self):
        """Test that candidates are returned on the same device as the input ids"""
        input_ids = torch.tensor([[1, 2, 3]]).to(torch_device)
        self.generator.input_ids = input_ids
        candidates, _ = self.generator.get_candidates(input_ids)
        self.assertEqual(candidates.device, input_ids.device)

    def test_usd_vs_vanilla_sampling(self):
        """Test that USD matches vanilla sampling with temperature set to nearly 0"""
        prompt = "Test text"

        # Reference output: greedy decoding with the target model alone.
        pipe_vanilla = pipeline(
            "text-generation",
            model=self.target_name,
        )
        pipe_vanilla_output = pipe_vanilla(prompt, max_new_tokens=5, do_sample=False)
        vanilla_text = pipe_vanilla_output[0]["generated_text"]

        # Same target model, now assisted by a model with a different tokenizer.
        pipe_usd = pipeline(
            "text-generation",
            model=self.target_name,
            assistant_model=self.assistant_name,
        )
        pipe_usd_output = pipe_usd(prompt, max_new_tokens=5, do_sample=True, temperature=1e-9)
        usd_text = pipe_usd_output[0]["generated_text"]

        # With a near-zero temperature the sampled text is expected to equal the greedy text.
        self.assertEqual(usd_text, vanilla_text)
| |
|