from typing import TYPE_CHECKING, List, Optional, Tuple

from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding
from transformers.utils import logging, TensorType, to_py_obj

try:
    from ariautils.midi import MidiDict
    from ariautils.tokenizer import AbsTokenizer
    from ariautils.tokenizer._base import Token
except ImportError:
    raise ImportError(
        "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`."
    )

if TYPE_CHECKING:
    pass

logger = logging.get_logger(__name__)


class AriaTokenizer(PreTrainedTokenizer):
    """
    The Aria tokenizer is NOT a BPE tokenizer. A MIDI file is first converted to a
    MidiDict (despite the name, a MidiDict is not a single dict; it is closer to a
    list of "notes") representing a sequence of notes, stops, etc. The tokenizer is
    then simply a dictionary that maps MidiDict events to discrete indices according
    to a hard-coded rule.

    For a FIM-finetuned model, we also follow a simple FIM format to steer a piece
    of music toward a (possibly very different) suffix given in the prompt:

        <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>

    This way, we expect a continuation that connects PROMPT and GUIDANCE.
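
    Example (a minimal sketch; ``song.mid`` is a placeholder path):

        >>> tokenizer = AriaTokenizer()
        >>> enc = tokenizer.encode_from_file("song.mid")
        >>> midi_dict = tokenizer.decode(enc["input_ids"][0])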
| | """ |
| |
|
| | vocab_files_names = {} |
| | model_input_names = ["input_ids", "attention_mask"] |
| |
|
    def __init__(
        self,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        self._tokenizer = AbsTokenizer()

        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt

        bos_token = self._tokenizer.bos_tok
        eos_token = self._tokenizer.eos_tok
        pad_token = self._tokenizer.pad_tok
        unk_token = self._tokenizer.unk_tok

        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        # Pickling the wrapped AbsTokenizer is not supported; drop all state.
        return {}

    def __setstate__(self, d):
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Returns vocab size."""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        return self._tokenizer.tok_to_id

    def tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        return self._tokenizer.tokenize(midi_dict)

    def _tokenize(self, midi_dict: MidiDict, **kwargs) -> List[Token]:
        return self._tokenizer.tokenize(midi_dict)

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """We cannot rely on the parent method because the inputs are MidiDict(s)
        instead of strings, and shoehorning two entirely different input types into
        one code path would be hacky. This method therefore reimplements __call__
        with support for a limited set of useful arguments. It should not conflict
        with other "string-in-ids-out" tokenizers; if you have to mix the API of
        string-based tokenizers with this MIDI-based one, reconsider your design."""
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        all_tokens: list[list[int]] = []
        all_attn_masks: list[list[int]] = []
        max_len_encoded = 0

        for md in midi_dicts:
            tokens = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                tokens = tokens[:max_length]
            max_len_encoded = max(max_len_encoded, len(tokens))
            all_tokens.append(tokens)
            all_attn_masks.append([1] * len(tokens))

        if pad_to_multiple_of is not None:
            # Round up to the nearest multiple without overshooting when the
            # length is already an exact multiple of `pad_to_multiple_of`.
            max_len_encoded = (
                (max_len_encoded + pad_to_multiple_of - 1) // pad_to_multiple_of
            ) * pad_to_multiple_of
        if padding:
            for tokens, attn_mask in zip(all_tokens, all_attn_masks):
                # Compute the pad length before extending `tokens`; otherwise the
                # attention mask would be extended by zero elements.
                pad_len = max_len_encoded - len(tokens)
                tokens.extend([self.pad_token_id] * pad_len)
                attn_mask.extend([0] * pad_len)

        data = {"input_ids": all_tokens}
        if return_attention_mask is None or return_attention_mask:
            # The conventional key is "attention_mask" (see `model_input_names`).
            data["attention_mask"] = all_attn_masks
        return BatchEncoding(data, tensor_type=return_tensors)
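
    # A minimal sketch of batched encoding (assumes `mds` is a list of MidiDict
    # objects; padding makes the rows rectangular so they can be stacked):
    #
    #     enc = tokenizer(mds, padding=True, return_tensors="pt")
    #     enc["input_ids"].shape  # (batch, max_len)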

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        token_ids = to_py_obj(token_ids)

        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[int]], **kwargs
    ) -> List[MidiDict]:
        return [self.decode(token_ids) for token_ids in token_ids_list]

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(self, filenames: list[str], **kwargs) -> BatchEncoding:
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)
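
    # A minimal round-trip sketch ("a.mid" and "b.mid" are placeholder paths):
    #
    #     enc = tokenizer.encode_from_files(["a.mid", "b.mid"])
    #     midi_dicts = tokenizer.batch_decode(enc["input_ids"])
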
    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id."""
        return self._tokenizer.tok_to_id.get(
            token, self._tokenizer.tok_to_id[self.unk_token]
        )

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) into a token (tuple or str)."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        # The vocabulary is hard-coded in ariautils rather than stored in files
        # (note the empty `vocab_files_names`), so there is nothing to save.
        raise NotImplementedError()