from __future__ import annotations

import os
from pathlib import Path
from typing import List, Optional, Union

from tokenizers import Tokenizer as BaseTokenizer

from .aliases import PathOrStr
from .config import ModelConfig, TokenizerConfig, TrainConfig, TruncationDirection
from .exceptions import OLMoConfigurationError

__all__ = ["Tokenizer"]


class Tokenizer:
    """
    A :class:`Tokenizer` is a light-weight wrapper around a HuggingFace :class:`tokenizers.Tokenizer`.

    :param base_tokenizer: The :class:`tokenizers.Tokenizer` to use.
    :param eos_token_id: The token ID corresponding to the "end-of-sentence" token.
    :param pad_token_id: The token ID to use for padding. Defaults to ``eos_token_id`` if not set.
    :param truncate_to: Truncate when tokenizing to this number of token IDs.
    :param truncate_direction: The direction to truncate in. "right" means truncate the tokens
        on the right, "left" means truncate the tokens on the left. If ``truncate_to`` is null,
        this setting has no effect.
    """

    def __init__(
        self,
        base_tokenizer: BaseTokenizer,
        eos_token_id: int,
        pad_token_id: Optional[int] = None,
        truncate_to: Optional[int] = None,
        truncate_direction: Union[str, TruncationDirection] = TruncationDirection.right,
    ):
        self.base_tokenizer = base_tokenizer
        # Disable the base tokenizer's built-in truncation; truncation is applied
        # manually in :meth:`encode_batch` so special tokens can be accounted for.
        self.base_tokenizer.no_truncation()
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id if pad_token_id is not None else eos_token_id
        self.truncate_to = truncate_to
        self.truncate_direction = TruncationDirection(truncate_direction)
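
    # A minimal construction sketch. The Hub identifier and token ID below are
    # illustrative assumptions, not project defaults:
    #
    #     base = BaseTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    #     tok = Tokenizer(base, eos_token_id=0, truncate_to=2048, truncate_direction="left")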

    @property
    def vocab_size(self) -> int:
        return self.base_tokenizer.get_vocab_size()

    @property
    def eos_token(self) -> str:
        return self.decode([self.eos_token_id], skip_special_tokens=False)

    @property
    def pad_token(self) -> str:
        return self.decode([self.pad_token_id], skip_special_tokens=False)
|
| | @classmethod |
| | def from_train_config(cls, config: TrainConfig) -> Tokenizer: |
| | tokenizer_identifier = config.tokenizer.identifier |
| | if Path(tokenizer_identifier).is_file(): |
| | tokenizer = cls.from_file( |
| | tokenizer_identifier, |
| | eos_token_id=config.model.eos_token_id, |
| | pad_token_id=config.model.pad_token_id, |
| | ) |
| | else: |
| | tokenizer = cls.from_pretrained( |
| | tokenizer_identifier, |
| | eos_token_id=config.model.eos_token_id, |
| | pad_token_id=config.model.pad_token_id, |
| | ) |
| | if config.model.vocab_size != tokenizer.vocab_size: |
| | raise OLMoConfigurationError("vocab size mismatch between config and tokenizer") |
| | return tokenizer |
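
    # Sketch of the intended call pattern (the YAML path is hypothetical):
    #
    #     cfg = TrainConfig.load("path/to/train_config.yaml")
    #     tokenizer = Tokenizer.from_train_config(cfg)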

    @classmethod
    def from_pretrained(cls, identifier: str, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a pretrained tokenizer on the HuggingFace Hub.

        :param identifier: The identifier of a model on the Hub that contains a
            ``tokenizer.json`` file.
        :param kwargs: Other keyword arguments passed to :class:`Tokenizer`.
        """
        base_tokenizer = BaseTokenizer.from_pretrained(identifier)
        eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
        return cls(base_tokenizer, eos_token_id, **kwargs)
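
    # Example (the identifier is illustrative; any Hub repo with a
    # ``tokenizer.json`` works, and ``eos_token_id`` should normally be passed
    # explicitly rather than relying on the last-vocab-ID fallback):
    #
    #     tokenizer = Tokenizer.from_pretrained("EleutherAI/gpt-neox-20b", eos_token_id=0)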

    @classmethod
    def from_file(cls, filename: PathOrStr, **kwargs) -> Tokenizer:
        """
        Initialize a tokenizer from a file.

        You can create such files with ``BaseTokenizer.save()``.

        :param filename: The name of a file containing a tokenizer specification.
        :param kwargs: Other keyword arguments passed to :class:`Tokenizer`.
        """
        # Coerce to ``str`` so a ``Path`` works regardless of ``tokenizers`` version.
        base_tokenizer = BaseTokenizer.from_file(str(filename))
        eos_token_id = kwargs.pop("eos_token_id", base_tokenizer.get_vocab_size() - 1)
        return cls(base_tokenizer, eos_token_id, **kwargs)
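
    # Example round trip through a saved spec file (paths are illustrative):
    #
    #     base.save("tokenizer.json")
    #     tokenizer = Tokenizer.from_file("tokenizer.json", eos_token_id=0)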

    @classmethod
    def from_checkpoint(cls, checkpoint_dir: PathOrStr) -> Tokenizer:
        """
        Load a tokenizer from a checkpoint.
        """
        from cached_path import cached_path

        # Load the configs saved alongside the checkpoint.
        config_path = cached_path(os.path.join(checkpoint_dir, "config.yaml"))
        tokenizer_config = TokenizerConfig.load(config_path, key="tokenizer")
        model_config = ModelConfig.load(config_path, key="model")

        # Initialize the tokenizer from the configs.
        if Path(tokenizer_config.identifier).is_file():
            tokenizer = cls.from_file(
                tokenizer_config.identifier,
                eos_token_id=model_config.eos_token_id,
                pad_token_id=model_config.pad_token_id,
            )
        else:
            tokenizer = cls.from_pretrained(
                tokenizer_config.identifier,
                eos_token_id=model_config.eos_token_id,
                pad_token_id=model_config.pad_token_id,
            )
        if model_config.vocab_size != tokenizer.vocab_size:
            raise OLMoConfigurationError("vocab size mismatch between config and tokenizer")
        return tokenizer
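
    # Sketch (the checkpoint directory is hypothetical; remote URLs also work
    # since the path goes through ``cached_path``):
    #
    #     tokenizer = Tokenizer.from_checkpoint("/path/to/checkpoint")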

    def add_special_tokens(self, input_ids: List[int]) -> List[int]:
        """
        Add special tokens in-place (if not already present) to the given token IDs.
        """
        if not input_ids or input_ids[-1] != self.eos_token_id:
            input_ids.append(self.eos_token_id)
        return input_ids
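
    # The append is idempotent with respect to a trailing EOS, e.g. (IDs are
    # illustrative, assuming ``eos_token_id == 0``):
    #
    #     tok.add_special_tokens([5, 6, 7])     # -> [5, 6, 7, 0]
    #     tok.add_special_tokens([5, 6, 7, 0])  # -> [5, 6, 7, 0] (unchanged)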

    def num_special_tokens_to_add(self, is_pair: bool = False) -> int:
        """
        The number of special tokens added when encoding a single input (1) or a pair (2).
        """
        return 2 if is_pair else 1

    def _truncate(
        self, input_ids: List[int], truncate_to: Optional[int], direction: TruncationDirection
    ) -> List[int]:
        if truncate_to is None or len(input_ids) <= truncate_to:
            return input_ids
        elif direction == TruncationDirection.left:
            # Keep the last ``truncate_to`` token IDs.
            return input_ids[len(input_ids) - truncate_to :]
        else:
            # Keep the first ``truncate_to`` token IDs.
            return input_ids[:truncate_to]
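
    # Behavior sketch with ``truncate_to=2`` (IDs are illustrative):
    #
    #     tok._truncate([1, 2, 3], 2, TruncationDirection.right)  # -> [1, 2]
    #     tok._truncate([1, 2, 3], 2, TruncationDirection.left)   # -> [2, 3]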

    def encode(self, input: str, add_special_tokens: bool = True) -> List[int]:
        """
        Encode a string into token IDs.
        """
        return self.encode_batch([input], add_special_tokens=add_special_tokens)[0]

    def encode_batch(self, inputs: List[str], add_special_tokens: bool = True) -> List[List[int]]:
        """
        Encode a batch of strings into token IDs.
        """
        truncate_to = self.truncate_to
        if truncate_to is not None and add_special_tokens:
            # Reserve room for the special tokens that will be appended below.
            truncate_to -= self.num_special_tokens_to_add(False)

        batch_encoding = self.base_tokenizer.encode_batch(inputs)

        all_input_ids = []
        for encoding in batch_encoding:
            input_ids = self._truncate(encoding.ids, truncate_to, self.truncate_direction)
            if add_special_tokens:
                input_ids = self.add_special_tokens(input_ids)
            all_input_ids.append(input_ids)

        return all_input_ids
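
    # Example (the exact token IDs depend on the underlying vocabulary):
    #
    #     ids = tok.encode_batch(["Hello world.", "Goodbye."])
    #     # Each entry ends with ``tok.eos_token_id`` and respects ``truncate_to``.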

    def decode(self, token_ids: List[int], skip_special_tokens: bool = True) -> str:
        """
        Decode a list of token IDs to a string.
        """
        return self.base_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
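
    # Round-trip sketch. Decoding an encoding often recovers the original text
    # for byte-level BPE vocabularies, though exact equality is not guaranteed
    # in general (normalization may alter the text):
    #
    #     text = "Hello world."
    #     tok.decode(tok.encode(text))  # usually == text for byte-level BPE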