# MultiMolecule
# Copyright (C) 2024-Present  MultiMolecule
#
# This file is part of MultiMolecule.
#
# MultiMolecule is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.
#
# MultiMolecule is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
# For additional terms and clarifications, please refer to our License FAQ at:
# <https://multimolecule.danling.org/about/license-faq>.

from __future__ import annotations

import os
from collections import OrderedDict
from functools import lru_cache
from itertools import product
from pathlib import Path
from typing import Any, List, Sequence, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from transformers.tokenization_utils import PreTrainedTokenizer

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

# Special tokens shared by all MultiMolecule tokenizers.
# NOTE(review): the "content" values arrived as empty strings — the angle-bracket
# tokens were evidently stripped in transit. Restored to the tokens the rest of
# this module expects: `identify_special_token` matches "pad"/"cls"/... as a
# substring of each token, and the doctests below rely on these exact strings.
SPECIAL_TOKENS_MAP = {
    "pad_token": {
        "content": "<pad>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "cls_token": {
        "content": "<cls>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "eos_token": {
        "content": "<eos>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "unk_token": {
        "content": "<unk>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "mask_token": {
        "content": "<mask>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "null_token": {
        "content": "<null>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
}

# Full RNA alphabet: the four bases, IUPAC degenerate codes, gap (".", "-"),
# and extra residue symbols ("X", "*", "I").
STANDARD_ALPHABET = list("ACGUNRYSWKMBDHV.X*-I")
# IUPAC nucleotide codes only (bases + degenerate codes + gap ".").
IUPAC_ALPHABET = list("ACGUNRYSWKMBDHV.")
# The four bases plus "N" (any base); used for k-mer vocabularies to keep them small.
STREAMLINE_ALPHABET = list("ACGUN")
NUCLEOBASE_ALPHABET = list("ACGU")

# Named alphabets selectable by string in `get_alphabet`.
ALPHABETS = {
    "standard": STANDARD_ALPHABET,
    "iupac": IUPAC_ALPHABET,
    "streamline": STREAMLINE_ALPHABET,
    "nucleobase": NUCLEOBASE_ALPHABET,
}

# Degenerate IUPAC codes mapped to the set of concrete bases each may represent.
VOCAB_MAPPING = {
    "R": "AG",
    "Y": "CU",
    "S": "CG",
    "W": "AU",
    "K": "GU",
    "M": "AC",
    "B": "CGU",
    "D": "AGU",
    "H": "ACU",
    "V": "ACG",
    "X": "ACGU",
}

TOKENIZER_CONFIG = {
    "tokenizer_class": "RnaTokenizer",
    "clean_up_tokenization_spaces": True,
}


def get_alphabet(alphabet: List[str] | str | None = None, nmers: int = 1, **kwargs) -> Alphabet:
    """Build an :class:`Alphabet`, resolving a string to one of the predefined alphabets.

    Args:
        alphabet: A list of tokens, the name of a predefined alphabet (see `ALPHABETS`),
            or None. When None, the standard alphabet is used for single-token
            vocabularies and the streamline alphabet for k-mers (keeping the
            k-mer vocabulary size manageable).
        nmers: Size of k-mer the alphabet will be expanded to.
        **kwargs: Forwarded to the :class:`Alphabet` constructor.

    Raises:
        KeyError: If ``alphabet`` is a string that names no predefined alphabet.
    """
    if alphabet is None:
        alphabet = STANDARD_ALPHABET if nmers <= 1 else STREAMLINE_ALPHABET
    elif isinstance(alphabet, str):
        alphabet = ALPHABETS[alphabet]
    return Alphabet(alphabet, nmers=nmers, **kwargs)


def get_vocab_mapping():
    """Return the mapping from degenerate IUPAC codes to their concrete base sets."""
    return VOCAB_MAPPING


def get_special_tokens_map():
    """Return the special-tokens map used by MultiMolecule tokenizers."""
    return SPECIAL_TOKENS_MAP


def get_tokenizer_config(add_special_tokens: bool = False):
    """Return a tokenizer config dict, optionally including the special tokens.

    A copy of the module-level template is returned so that neither the caller
    nor repeated calls with ``add_special_tokens=True`` can mutate
    ``TOKENIZER_CONFIG`` itself (the original implementation modified the
    shared dict in place).
    """
    config: dict = dict(TOKENIZER_CONFIG)
    if add_special_tokens:
        config.setdefault("added_tokens_decoder", {})
        for i, v in enumerate(SPECIAL_TOKENS_MAP.values()):
            config["added_tokens_decoder"][str(i)] = v  # type: ignore[index]
    return config


class Alphabet:
    """An ordered token alphabet with optional special tokens and k-mer expansion.

    The full vocabulary is ``prepend_tokens + kmer(tokens) + append_tokens`` and
    is computed lazily (and cached) by the :attr:`vocabulary` property.
    """

    # Special tokens placed before the (k-mer expanded) tokens in the vocabulary.
    # NOTE(review): restored — these arrived as six empty strings; the doctests
    # below (e.g. 'acgu' -> [1, 6, 7, 8, 9, 2]) require exactly these six tokens
    # in this order so that <cls>=1, <eos>=2, <unk>=3 and regular tokens start at 6.
    prepend_tokens: Tuple[str, ...] = ("<pad>", "<cls>", "<eos>", "<unk>", "<mask>", "<null>")
    append_tokens: Tuple[str, ...] = ()
    tokens: Tuple[str, ...]
    nmers: int

    def __init__(
        self,
        tokens: Sequence[str],
        prepend_tokens: Tuple[str, ...] | None = None,
        append_tokens: Tuple[str, ...] | None = None,
        nmers: int = 1,
    ):
        # Allow constructing from another Alphabet (copy its raw tokens).
        if isinstance(tokens, Alphabet):
            tokens = tokens.tokens
        self.tokens = tuple(tokens)
        # None keeps the class-level defaults; an explicit tuple overrides them.
        if prepend_tokens is not None:
            self.prepend_tokens = tuple(prepend_tokens)
        if append_tokens is not None:
            self.append_tokens = tuple(append_tokens)
        self.nmers = nmers

    @property
    def vocabulary(self) -> Tuple[str, ...]:
        return self._vocabulary(self.prepend_tokens, self.tokens, self.nmers, self.append_tokens)

    @staticmethod
    @lru_cache(maxsize=None)
    def _vocabulary(
        prepend_tokens: Tuple[str, ...], tokens: Tuple[str, ...], nmers: int, append_tokens: Tuple[str, ...]
    ) -> Tuple[str, ...]:
        # Static + lru_cache: identical alphabet configurations share one tuple,
        # and caching does not pin Alphabet instances alive.
        return prepend_tokens + generate_kmer_vocabulary(tokens, nmers) + append_tokens

    def __iter__(self):
        return iter(self.vocabulary)

    def __len__(self):
        return len(self.vocabulary)

    def __contains__(self, item: str):
        return item in self.vocabulary

    def __repr__(self) -> str:
        repr_parts = [f"Alphabet(tokens={self.tokens}"]
        if self.nmers > 1:
            repr_parts.append(f"nmers={self.nmers}")
        repr_parts.append(f"prepend_tokens={self.prepend_tokens}")
        repr_parts.append(f"append_tokens={self.append_tokens})")
        return ", ".join(repr_parts)


def _merge_extra_special_tokens(
    additional_special_tokens: List | Tuple | None,
    kwargs: dict[str, Any],
) -> List | Tuple | None:
    """Fold an ``extra_special_tokens`` kwarg into ``additional_special_tokens``.

    ``extra_special_tokens`` may be a dict (values used), list, or tuple; entries
    that are dicts with a ``"content"`` key contribute that content string.
    Duplicates are skipped. ``kwargs`` is mutated: the key is popped so it is not
    forwarded to ``PreTrainedTokenizer.__init__``.

    Raises:
        TypeError: If ``extra_special_tokens`` is not a dict, list, or tuple.
    """
    if "extra_special_tokens" not in kwargs:
        return additional_special_tokens
    extra_special_tokens = kwargs.pop("extra_special_tokens")
    merged_special_tokens = [] if additional_special_tokens is None else list(additional_special_tokens)
    if isinstance(extra_special_tokens, dict):
        extra_tokens = list(extra_special_tokens.values())
    elif isinstance(extra_special_tokens, (list, tuple)):
        extra_tokens = list(extra_special_tokens)
    else:
        raise TypeError(
            f"extra_special_tokens must be dict, list, or tuple, got {type(extra_special_tokens).__name__}"
        )
    for token in extra_tokens:
        token_value = token
        if isinstance(token, dict) and "content" in token:
            token_value = token["content"]
        if token_value not in merged_special_tokens:
            merged_special_tokens.append(token_value)
    return merged_special_tokens


def generate_kmer_vocabulary(vocabulary: Tuple[str, ...], nmers: int = 1) -> Tuple[str, ...]:
    """
    Generates a kmer vocabulary given an original vocabulary and the size of kmer.

    Special tokens (anything starting with ``<`` or ``[``) are carried over
    unchanged and are never combined into k-mers.

    Args:
        vocabulary: The original vocabulary.
        nmers: The size of kmer to generate. Values <= 1 return the input unchanged.

    Returns:
        The kmer vocabulary: special tokens first, followed by all
        ``len(tokens) ** nmers`` combinations in ``itertools.product`` order.
    """
    if nmers <= 1:
        return vocabulary
    special_tokens, tokens = [], []
    for token in vocabulary:
        # Special tokens are conventionally wrapped in <> or []; never k-merize them.
        if token.startswith(("<", "[")):
            special_tokens.append(token)
        else:
            tokens.append(token)
    return tuple(special_tokens) + tuple("".join(kmer) for kmer in product(tokens, repeat=nmers))


class Tokenizer(PreTrainedTokenizer):
    """
    Constructs a Base tokenizer.

    Args:
        alphabet: List of tokens or an Alphabet object to use in tokenization.
            Either alphabet or vocab_file must be specified.
        bos_token: A special token representing the beginning of a sequence.
            Defaults to the cls token when left unspecified.
        cls_token: A special token representing the classification token.
        pad_token: A special token representing padding.
        eos_token: A special token representing the end of a sequence.
        sep_token: A special token representing the separator token.
            Falls back to the eos token when the alphabet has no dedicated sep token.
        unk_token: A special token representing unknown tokens.
        mask_token: A special token representing the mask token.
        null_token: A special token representing the null token.
        additional_special_tokens: Additional special tokens to add to the vocabulary.
        do_upper_case: Whether to convert input to uppercase.
        vocab_file: Path to a vocabulary file. Either alphabet or vocab_file must be specified.

    Each special-token argument defaults to the Ellipsis sentinel, which means
    "auto-detect from the alphabet"; pass an explicit None to disable a token.

    Examples:
        >>> from multimolecule.tokenisers import Tokenizer
        >>> tokenizer = Tokenizer(["A", "C", "G", "T", "N"], unk_token="N")
        >>> tokenizer('ACGTN')["input_ids"]
        [0, 1, 2, 3, 4]
        >>> tokenizer('acgtn')["input_ids"]
        [0, 1, 2, 3, 4]
        >>> len(tokenizer)
        5
        >>> tokenizer = Tokenizer(["A", "C", "G", "T", "N"], unk_token="N", do_upper_case=False)
        >>> tokenizer('ACGTN')["input_ids"]
        [0, 1, 2, 3, 4]
        >>> tokenizer('acgtn')["input_ids"]
        [4, 4, 4, 4, 4]
        >>> tokenizer('ACgtN')["input_ids"]
        [0, 1, 4, 4, 4]
        >>> tokenizer = Tokenizer(["<pad>", "<cls>", "A", "C", "G", "T", "N", "<mask>", "<eos>"])
        >>> tokenizer('ACGTN')["input_ids"]
        [1, 2, 3, 4, 5, 6, 8]
        >>> tokenizer('AC<mask>GTN')["input_ids"]
        [1, 2, 3, 7, 4, 5, 6, 8]
        >>> tokenizer(['TATATAT', 'ATCGN'], padding=True)["input_ids"]
        [[1, 5, 2, 5, 2, 5, 2, 5, 8], [1, 2, 5, 3, 4, 6, 8, 0, 0]]
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = VOCAB_FILES_NAMES

    # Whether _tokenize upper-cases input before lookup.
    do_upper_case: bool = True

    def __init__(
        self,
        alphabet: Alphabet | List[str] | None = None,
        bos_token: str | None = ...,  # type: ignore[assignment]
        cls_token: str | None = ...,  # type: ignore[assignment]
        pad_token: str | None = ...,  # type: ignore[assignment]
        eos_token: str | None = ...,  # type: ignore[assignment]
        sep_token: str | None = ...,  # type: ignore[assignment]
        unk_token: str | None = ...,  # type: ignore[assignment]
        mask_token: str | None = ...,  # type: ignore[assignment]
        null_token: str | None = ...,  # type: ignore[assignment]
        additional_special_tokens: List | Tuple | None = None,
        do_upper_case: bool = True,
        vocab_file: str | None = None,
        **kwargs,
    ):
        if alphabet is None and vocab_file is None:
            raise ValueError("You must specify either alphabet or vocab_file")
        if vocab_file is not None:
            alphabet = self.load_vocabulary(vocab_file)
        # The lookup tables must exist before super().__init__, which may call
        # _convert_token_to_id while registering special tokens.
        self._id_to_token: OrderedDict[int, str] = OrderedDict(enumerate(alphabet))
        self._token_to_id: OrderedDict[str, int] = OrderedDict({tok: ind for ind, tok in enumerate(alphabet)})
        # `...` means auto-detect the token from the alphabet; explicit None disables it.
        if cls_token is ...:
            cls_token = self.identify_special_token(alphabet, "cls")
        if bos_token is ...:
            bos_token = cls_token  # bos falls back to cls
        if pad_token is ...:
            pad_token = self.identify_special_token(alphabet, "pad")
        if eos_token is ...:
            eos_token = self.identify_special_token(alphabet, "eos")
        if sep_token is ...:
            # sep falls back to eos when the alphabet has no dedicated sep token
            sep_token = self.identify_special_token(alphabet, "sep") or self.identify_special_token(alphabet, "eos")
        if unk_token is ...:
            unk_token = self.identify_special_token(alphabet, "unk")
        if mask_token is ...:
            mask_token = self.identify_special_token(alphabet, "mask")
        if null_token is ...:
            null_token = self.identify_special_token(alphabet, "null")
        additional_special_tokens = _merge_extra_special_tokens(additional_special_tokens, kwargs)
        if additional_special_tokens is None:
            additional_special_tokens = []
        # PreTrainedTokenizer has no null_token slot; register it as an
        # additional special token so it is never split.
        if null_token in alphabet and null_token not in additional_special_tokens:  # type: ignore[operator]
            additional_special_tokens = list(additional_special_tokens)
            additional_special_tokens.append(null_token)
        super().__init__(
            bos_token=bos_token,
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
            sep_token=sep_token,
            unk_token=unk_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.do_upper_case = do_upper_case
        # Tokens added by super().__init__ extend the lookup tables.
        self._id_to_token.update(self.added_tokens_decoder)
        self._token_to_id.update(self.added_tokens_encoder)
        # TODO, all the tokens are added? But they are also part of the vocab... bit strange.
        # none of them are special, but they all need special splitting.
        # self.unique_no_split_tokens = self.all_tokens
        # self._update_trie(self.unique_no_split_tokens)

    def _tokenize(self, text: str, **kwargs):
        """Split ``text`` into single-character tokens (upper-cased if configured)."""
        if self.do_upper_case:
            text = text.upper()
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        index = self._token_to_id.get(token, self.unk_token_id)
        if index is None:
            raise ValueError(f"Token {token} is not in the vocabulary, and no UNK token is set!")
        return index

    def _convert_id_to_token(self, index: int) -> str:
        token = self._id_to_token.get(index, self.unk_token)
        if token is None:
            raise ValueError(f"ID {index} is not in the vocabulary, and no UNK token is set!")
        return token

    def token_to_id(self, token: str) -> int:
        """Public alias for `_convert_token_to_id`."""
        return self._convert_token_to_id(token)

    def id_to_token(self, index: int) -> str:
        """Public alias for `_convert_id_to_token`."""
        return self._convert_id_to_token(index)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: List[int] | None = None
    ) -> List[int]:
        """Add bos/sep/eos around one or two sequences, honoring disabled tokens."""
        bos = [self.bos_token_id]  # points to cls
        sep = [self.sep_token_id]  # points to eos
        eos = [self.eos_token_id]  # eos is eos
        if token_ids_1 is None:
            # Single sequence: add whichever of bos/eos are configured.
            if self.bos_token_id is None:
                if self.eos_token_id is None:
                    return token_ids_0
                return token_ids_0 + eos
            if self.eos_token_id is None:
                return bos + token_ids_0
            return bos + token_ids_0 + eos
        if self.eos_token_id is None:
            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
        return bos + token_ids_0 + sep + token_ids_1 + eos

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: List[int] | None = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                List of ids of the second sequence.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
        # Mirror build_inputs_with_special_tokens: bos + t0 + sep(+ t1 + eos).
        mask = [0] * len(token_ids_0)
        if self.bos_token_id is not None:
            mask = [1] + mask
        if self.sep_token_id is not None:
            mask += [1]
        if token_ids_1 is not None:
            mask += [0] * len(token_ids_1)
            if self.eos_token_id is not None:
                mask += [1]
        return mask

    @staticmethod
    def load_vocabulary(vocab_file: str | Path) -> List[str]:
        """Read a vocabulary file: one token per line, in id order."""
        with open(vocab_file, encoding="utf-8") as reader:
            vocabulary = reader.read().splitlines()
        return vocabulary

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None):
        """Write the vocabulary (one token per line) and return the file path tuple."""
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
        # encoding matches load_vocabulary, so round-tripping is lossless.
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)

    @staticmethod
    def identify_special_token(alphabet: Alphabet | List[str], token) -> str | None:
        """Find the unique alphabet entry containing ``token`` (e.g. "cls" -> "<cls>").

        Returns None if absent; raises ValueError if more than one entry matches.
        """
        tokens = [i for i in alphabet if token in i.lower()]
        if len(tokens) == 1:
            return tokens[0]
        if len(tokens) == 0:
            return None
        raise ValueError(f"Token {token} is ambiguous, could be {tokens}")

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    @property
    def vocab(self) -> OrderedDict[str, int]:
        # Copy so callers cannot mutate the internal table.
        return self._token_to_id.copy()

    @property
    def all_tokens(self) -> List[str]:
        return list(self.get_vocab().keys())

    @property
    def vocab_size(self) -> int:
        return len(self.all_tokens)


class RnaTokenizer(Tokenizer):
    """
    Tokenizer for RNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If is `None`, the standard RNA alphabet will be used.
            - If is a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If is an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons (non-overlapping 3-mers).
        replace_T_with_U: Whether to replace T with U.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import RnaTokenizer
        >>> tokenizer = RnaTokenizer()
        >>> tokenizer('ACGUNRYSWKMBDHV.X*-I')["input_ids"]
        [1, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = RnaTokenizer(replace_T_with_U=False)
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = RnaTokenizer(nmers=3)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 17, 64, 49, 96, 84, 22, 2]
        >>> tokenizer = RnaTokenizer(codon=True)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 49, 22, 2]
        >>> tokenizer('uagcuuauca')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        # NOTE(review): was `codon: bool = True`, which contradicts the doctests
        # above — `RnaTokenizer()` must tokenize per-nucleotide, and a True
        # default would raise for any sequence whose length is not a multiple
        # of 3. Restored to False.
        codon: bool = False,
        replace_T_with_U: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        # Codon mode is exactly non-overlapping 3-mers; any other explicit nmers
        # is contradictory.
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            nmers = 3  # set to 3 to get correct vocab
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        additional_special_tokens = _merge_extra_special_tokens(additional_special_tokens, kwargs)
        # nmers/codon/replace_T_with_U are forwarded so they are persisted in
        # the tokenizer's init_kwargs and survive save/load round-trips.
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_T_with_U=replace_T_with_U,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_T_with_U = replace_T_with_U
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        """Tokenize into characters, codons, or overlapping k-mers per configuration.

        Raises:
            ValueError: In codon mode, when the sequence length is not a multiple of 3.
        """
        if self.do_upper_case:
            text = text.upper()
        if self.replace_T_with_U:
            text = text.replace("T", "U")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            # Non-overlapping triplets.
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            # Overlapping k-mers (sliding window of width nmers, stride 1).
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]  # noqa: E203
        return list(text)


class RotaryEmbedding(nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and keys are transformed by rotation matrices which depend on their relative positions.

    Tip: **Cache**
        The inverse frequency buffer is cached and updated only when the sequence length changes or the device changes.

    Success: **Sequence Length**
        Rotary Embedding is irrespective of the sequence length and can be used for any sequence length.
        Use the `scale` parameter to extend context length beyond training (e.g., scale=2.0 doubles effective context).

    Example:
        >>> embedding = RotaryEmbedding(embedding_dim=64)
        >>> query, key = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
        >>> query, key = embedding(query, key)
        >>> query.shape
        torch.Size([2, 4, 28, 64])
        >>> # For extended context length
        >>> embedding_extended = RotaryEmbedding(embedding_dim=64, scale=2.0)
        >>> embedding.state_dict()  # no weight in state_dict
        OrderedDict()
    """

    # Lazily-built cos/sin caches; rebuilt when seq length or device changes.
    _seq_len_cached: int | None = None
    _cos_cached: Tensor | None = None
    _sin_cached: Tensor | None = None

    def __init__(
        self,
        embedding_dim: int,
        base: float = 10000.0,
        scale: float = 1.0,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Initialize rotary position embeddings.

        Args:
            embedding_dim: Dimension of the embeddings (must be even)
            base: Base for computing inverse frequencies. Defaults to 10000.0.
            scale: Scaling factor for frequencies. Values > 1.0 extend context length
                (e.g., scale=2.0 doubles the effective context). Defaults to 1.0.
            dtype: Data type for computations. Defaults to torch.float32.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, embedding_dim, 2, dtype=dtype) / embedding_dim))
        # persistent=False keeps the buffer out of state_dict (see class doctest).
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.scale = scale

    def forward(self, q: Tensor, k: Tensor, offset: int = 0, seq_length: int | None = None) -> Tuple[Tensor, Tensor]:
        """
        Apply rotary position embeddings to query and key tensors.

        Args:
            q: Query tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            k: Key tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            offset: Position offset for the start of the sequence (used with past_key_values). Defaults to 0.
            seq_length: Full sequence length including offset. If None, uses the sequence length
                from the input tensors. Required when offset > 0.

        Returns:
            Tuple of (rotated_query, rotated_key) tensors with the same shapes as inputs.

        Raises:
            ValueError: If offset > 0 but seq_length is not provided.
        """
        if offset > 0 and seq_length is None:
            raise ValueError("seq_length must be provided when offset > 0")
        if seq_length is None:
            seq_length = k.shape[-2]
        self._update_cos_sin_tables(k, seq_len_dim=-2, seq_length=seq_length)
        return self.apply_rotary_pos_emb(q, offset=offset), self.apply_rotary_pos_emb(k, offset=offset)

    def _update_cos_sin_tables(
        self, x: Tensor, seq_len_dim: int = -2, seq_length: int | None = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Update cached cos/sin tables for rotary embeddings.

        Args:
            x: Input tensor to determine device and dtype
            seq_len_dim: Dimension containing sequence length (default: -2)
            seq_length: Full sequence length to cache. If None, uses x.shape[seq_len_dim]
        """
        # NOTE(review): default was 2, contradicting this docstring and forward's
        # explicit -2; for the documented 4-D inputs, dim 2 == dim -2, so aligning
        # the default to -2 is behavior-preserving.
        if seq_length is None:
            seq_length = x.shape[seq_len_dim]
        # Rebuild only when the cached tables are absent, stale, or on another device.
        if seq_length != self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_length
            inv_freq = self.inv_freq
            if not isinstance(inv_freq, Tensor):
                raise RuntimeError("inv_freq buffer is not a Tensor")
            t = torch.arange(seq_length, device=x.device, dtype=inv_freq.dtype)
            # Apply scaling: divide frequencies by scale to extend context length
            freqs = torch.outer(t, inv_freq) / self.scale
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]
        # At this point, _cos_cached and _sin_cached are guaranteed to be Tensor
        assert self._cos_cached is not None and self._sin_cached is not None
        return self._cos_cached, self._sin_cached

    def apply_rotary_pos_emb(self, x: Tensor, offset: int = 0) -> Tensor:
        """
        Apply rotary position embeddings to a tensor.

        Args:
            x: Input tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            offset: Position offset for the start of the sequence (used with past_key_values). Defaults to 0.

        Returns:
            Rotated tensor with the same shape as input.

        Raises:
            RuntimeError: If the cos/sin caches have not been built yet.
        """
        if self._cos_cached is None or self._sin_cached is None:
            raise RuntimeError("Cos/sin tables not initialized. Call forward() or _update_cos_sin_tables() first.")
        cos = self._cos_cached[:, :, offset : offset + x.shape[-2], :]
        sin = self._sin_cached[:, :, offset : offset + x.shape[-2], :]
        return (x * cos) + (self.rotate_half(x) * sin)

    @staticmethod
    def rotate_half(x: Tensor) -> Tensor:
        # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)