import copy
import unittest
from functools import lru_cache

from datasets import load_dataset

from transformers import BloomTokenizerFast
from transformers.testing_utils import require_jinja, require_tokenizers

from ...test_tokenization_common import TokenizerTesterMixin, use_cache_if_possible


@require_tokenizers
class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
    from_pretrained_id = "bigscience/tokenizer"
    slow_tokenizer_class = None
    rust_tokenizer_class = BloomTokenizerFast
    tokenizer_class = BloomTokenizerFast
    test_rust_tokenizer = True
    test_slow_tokenizer = False
    from_pretrained_vocab_key = "tokenizer_file"
    special_tokens_map = {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}

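    # The fast tokenizer is downloaded once in setUpClass and reloaded from cls.tmpdirname by get_rust_tokenizer.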
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/tokenizer")
        tokenizer.save_pretrained(cls.tmpdirname)

    @classmethod
    @use_cache_if_possible
    @lru_cache(maxsize=64)
    def get_rust_tokenizer(cls, pretrained_name=None, **kwargs):
        _kwargs = copy.deepcopy(cls.special_tokens_map)
        _kwargs.update(kwargs)
        kwargs = _kwargs
        pretrained_name = pretrained_name or cls.tmpdirname
        return BloomTokenizerFast.from_pretrained(pretrained_name, **kwargs)

    @unittest.skip(reason="This needs a slow tokenizer. Bloom does not have one!")
    def test_encode_decode_with_spaces(self):
        return

    def test_encodings_from_sample_data(self):
        """
        Assert that the created tokens are the same as the hard-coded ones.
        """
        tokenizer = self.get_rust_tokenizer()

        INPUT_SENTENCES = ["The quick brown fox</s>", "jumps over the lazy dog</s>"]
        TARGET_TOKENS = [[2175, 23714, 73173, 144252, 2], [77, 132619, 3478, 368, 109586, 35433, 2]]

        computed_tokens = tokenizer.batch_encode_plus(INPUT_SENTENCES)["input_ids"]
        self.assertListEqual(TARGET_TOKENS, computed_tokens)

        decoded_tokens = tokenizer.batch_decode(computed_tokens)
        self.assertListEqual(decoded_tokens, INPUT_SENTENCES)

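    # Padding must work while a pad token is set and raise a ValueError once tokenizer_r.pad_token is removed,
    # for single sentences, batches, and sentence pairs alike.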
    def test_padding(self, max_length=6):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer_r = self.get_rust_tokenizer(pretrained_name, **kwargs)

                # Simple and pair inputs
                s = "This is a simple input"
                s2 = ["This is a simple input 1", "This is a simple input 2"]
                p = ("This is a simple input", "This is a pair")
                p2 = [
                    ("This is a simple input 1", "This is a simple input 2"),
                    ("This is a simple pair 1", "This is a simple pair 2"),
                ]

                # With a pad token set, truncating/encoding should not raise for any input shape
                try:
                    tokenizer_r.encode(s, max_length=max_length)
                    tokenizer_r.encode_plus(s, max_length=max_length)

                    tokenizer_r.batch_encode_plus(s2, max_length=max_length)
                    tokenizer_r.encode(p, max_length=max_length)
                    tokenizer_r.batch_encode_plus(p2, max_length=max_length)
                except ValueError:
                    self.fail("Bloom Tokenizer should be able to deal with padding")

                tokenizer_r.pad_token = None  # without a pad token, padding to max_length must fail
                self.assertRaises(ValueError, tokenizer_r.encode, s, max_length=max_length, padding="max_length")

                # Simple input
                self.assertRaises(ValueError, tokenizer_r.encode_plus, s, max_length=max_length, padding="max_length")

                # Simple input
                self.assertRaises(
                    ValueError,
                    tokenizer_r.batch_encode_plus,
                    s2,
                    max_length=max_length,
                    padding="max_length",
                )

                # Pair input
                self.assertRaises(ValueError, tokenizer_r.encode, p, max_length=max_length, padding="max_length")

                # Pair input
                self.assertRaises(ValueError, tokenizer_r.encode_plus, p, max_length=max_length, padding="max_length")

                # Pair input
                self.assertRaises(
                    ValueError,
                    tokenizer_r.batch_encode_plus,
                    p2,
                    max_length=max_length,
                    padding="max_length",
                )

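    # Round-trip check: a streamed XNLI sample should decode back to the exact original text in every language.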
    def test_encodings_from_xnli_dataset(self):
        """
        Tests the tokenizer downloaded from here:
            - https://huggingface.co/bigscience/tokenizer/
        """
        tokenizer = self.get_rust_tokenizer()
        ds = load_dataset("facebook/xnli", "all_languages", split="test", streaming=True)

        sample_data = next(iter(ds))["premise"]
        input_text = list(sample_data.values())

        output_tokens = list(map(tokenizer.encode, input_text))
        predicted_text = [tokenizer.decode(x, clean_up_tokenization_spaces=False) for x in output_tokens]
        self.assertListEqual(predicted_text, input_text)

    @require_jinja
    def test_tokenization_for_chat(self):
        tokenizer = self.get_rust_tokenizer()
        tokenizer.chat_template = "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}"
        test_chats = [
            [{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
            [
                {"role": "system", "content": "You are a helpful chatbot."},
                {"role": "user", "content": "Hello!"},
                {"role": "assistant", "content": "Nice to meet you."},
            ],
            [{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
        ]
        tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
        expected_tokens = [
            [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2],
            [5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2, 229126, 427, 11890, 1152, 17, 2],
            [229126, 427, 11890, 1152, 17, 2, 59414, 4, 2],
        ]
        for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
            self.assertListEqual(tokenized_chat, expected_tokens)

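    # add_prefix_space changes how a leading word is split, so the two settings must yield different tokens.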
    def test_add_prefix_space_fast(self):
        tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True)
        tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False)
        tokens_w_prefix = tokenizer_w_prefix.tokenize("Hey")
        tokens_wo_prefix = tokenizer_wo_prefix.tokenize("Hey")
        self.assertNotEqual(tokens_w_prefix, tokens_wo_prefix)