"""Distributed-training helpers and batch preparation for comma-separated multi-sequence inputs."""

import os
from dataclasses import dataclass
from typing import Any

import torch
import torch.distributed as dist
from tokenizers import Tokenizer
from torch.nn.utils.rnn import pad_sequence

PAD_TOKEN_ID = 0

# NOTE(review): the special-token spellings below are reconstructed. The reviewed copy of this
# file showed them as empty strings (the boundary list contained three identical "" entries),
# which looks like angle-bracketed names being stripped. Confirm against e1_tokenizer.json.
_PAD_TOKEN = "<pad>"
_BOS_TOKEN = "<bos>"
_EOS_TOKEN = "<eos>"


def get_tokenizer() -> Tokenizer:
    """Load the tokenizer from ``e1_tokenizer.json`` located next to this module.

    Returns:
        The loaded tokenizer.

    Raises:
        AssertionError: If the tokenizer's configured padding id is not ``PAD_TOKEN_ID``.
    """
    fname = os.path.join(os.path.dirname(__file__), "e1_tokenizer.json")
    tokenizer: Tokenizer = Tokenizer.from_file(fname)
    # Sanity check on the bundled asset (not on user input), hence an assert.
    assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, (
        f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}"
    )
    return tokenizer


def is_dist_initialized() -> bool:
    """Return True when ``torch.distributed`` is available and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size(group: Any = None) -> int:
    """Return the world size of ``group``, or 1 when not running distributed.

    "Not distributed" means the ``RANK`` environment variable is unset or no process group
    has been initialized.
    """
    # The original compared os.environ.get("RANK", -1) == -1, which is true exactly when
    # RANK is unset (an environment string never equals the int -1). Spell that out.
    if "RANK" not in os.environ or not is_dist_initialized():
        return 1
    return dist.get_world_size(group=group)


def get_rank(group: Any = None) -> int:
    """Return this process's rank in ``group``, or 0 when not running distributed.

    "Not distributed" means the ``RANK`` environment variable is unset or no process group
    has been initialized.
    """
    if "RANK" not in os.environ or not is_dist_initialized():
        return 0
    return dist.get_rank(group=group)


def get_device() -> torch.device:
    """Return the current CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")


def get_local_rank() -> int:
    """Return ``LOCAL_RANK`` from the environment (default 0), or 0 when no process group is initialized."""
    return int(os.environ.get("LOCAL_RANK", 0)) if is_dist_initialized() else 0


def setup_dist() -> None:
    """Initialize an NCCL process group and select the local CUDA device.

    Does nothing unless ``torch.distributed`` is available, CUDA is available, and the
    ``RANK`` environment variable is set to something other than -1.
    """
    rank = int(os.environ.get("RANK", -1))
    if dist.is_available() and torch.cuda.is_available() and rank != -1:
        torch.distributed.init_process_group(backend="nccl")
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))


def destroy_process_group() -> None:
    """Destroy the process group if one is initialized; otherwise do nothing."""
    if is_dist_initialized():
        dist.destroy_process_group()


def barrier() -> None:
    """Call ``dist.barrier()`` if a process group is initialized; otherwise do nothing."""
    if is_dist_initialized():
        dist.barrier()


@dataclass
class DataPrepConfig:
    """Limits and options applied by ``E1BatchPreparer`` while encoding sequences."""

    # Maximum number of comma-separated sequences allowed in one multi-sequence string.
    max_num_sequences: int = 512
    # Maximum number of characters allowed in one sequence (checked before boundary tokens are added).
    max_num_positions_within_seq: int = 8192
    # When True, tokens equal to the "X" token id are dropped; surviving tokens keep their position ids.
    remove_X_tokens: bool = False


def get_context(sequence: str) -> str | None:
    """Return everything before the last comma of ``sequence``, or None if it contains no comma."""
    if "," in sequence:
        return sequence.rsplit(",", 1)[0]
    return None


class E1BatchPreparer:
    """Turn comma-separated multi-sequence strings into padded tensors of ids, positions and labels."""

    def __init__(
        self,
        data_prep_config: DataPrepConfig | None = None,
        tokenizer: Tokenizer | None = None,
        preserve_context_labels: bool = False,
        device: torch.device | None = None,
    ):
        """Set up the tokenizer, configuration and cached special-token ids.

        Args:
            data_prep_config: Limits/options for encoding; defaults to ``DataPrepConfig()``.
            tokenizer: Tokenizer to use; defaults to ``get_tokenizer()``.
            preserve_context_labels: When False, labels of every sequence except the last one
                in a multi-sequence instance are overwritten with the pad token id.
            device: Device holding ``boundary_token_ids``; defaults to ``get_device()``.
        """
        self.tokenizer = tokenizer or get_tokenizer()
        self.data_prep_config = data_prep_config or DataPrepConfig()
        self.pad_token_id = self.tokenizer.token_to_id(_PAD_TOKEN)
        self.preserve_context_labels = preserve_context_labels
        self.boundary_token_ids = torch.tensor(
            [self.tokenizer.token_to_id(token) for token in [_BOS_TOKEN, _EOS_TOKEN, "1", "2", _PAD_TOKEN]],
            device=(device or get_device()),
        ).long()
        self.mask_token = "?"  # nosec
        self.mask_token_id = self.tokenizer.token_to_id(self.mask_token)
        self.X_token_id = self.tokenizer.token_to_id("X")
        self.vocab = self.tokenizer.get_vocab()

    def get_batch_kwargs(  # type: ignore[override]
        self, sequences: list[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False
    ) -> dict[str, torch.Tensor | list[str] | list[int]]:
        """Encode each multi-sequence string and pad the results into one batch.

        Args:
            sequences: Multi-sequence strings (individual sequences separated by commas).
            device: Device the padded tensors are moved to.
            non_blocking: Request a non-blocking transfer (only honoured for CUDA devices).

        Returns:
            The dictionary produced by ``pad_encodings``.
        """
        sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences]
        return self.pad_encodings(sequence_encodings, device, non_blocking)

    def pad_encodings(
        self,
        sequence_encodings: list[dict[str, torch.Tensor]],
        device: torch.device = torch.device("cpu"),
        non_blocking: bool = False,
    ) -> dict[str, torch.Tensor | list[str] | list[int]]:
        """Pad per-instance encodings to a common length and move them to ``device``.

        Args:
            sequence_encodings: Encodings as returned by ``prepare_multiseq``.
            device: Target device for the padded tensors.
            non_blocking: Request a non-blocking transfer; ignored unless ``device`` is CUDA.

        Returns:
            A dict with batch-first long tensors ``input_ids``, ``sequence_ids``,
            ``within_seq_position_ids``, ``global_position_ids`` and ``labels``, plus the
            per-instance lists ``context`` and ``context_len``.
        """
        non_blocking = non_blocking and device.type == "cuda"
        padded_encodings: dict[str, torch.Tensor | list[str] | list[int]] = {}
        # Note: We use -1 as the padding value for sequence and position ids because the 0 value
        # is a valid value for sequence and position ids. -1 is then used to distinguish valid
        # tokens from padding tokens, for example, when doing padding/unpadding for flash attention.
        padding_values = {
            "input_ids": self.pad_token_id,
            "sequence_ids": -1,
            "within_seq_position_ids": -1,
            "global_position_ids": -1,
            "labels": self.pad_token_id,
        }
        for key, padding_value in padding_values.items():
            padded_encodings[key] = pad_sequence(
                [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value
            ).to(device=device, dtype=torch.long, non_blocking=non_blocking)

        padded_encodings["context"] = [enc["context"] for enc in sequence_encodings]
        padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings]
        return padded_encodings

    def prepare_multiseq(self, sequence: str) -> dict[str, torch.Tensor | str | int]:
        """Encode one comma-separated multi-sequence string into flat 1-D tensors.

        Args:
            sequence: One or more sequences separated by commas; the last one is treated as
                the query and everything before it as context.

        Returns:
            A dict with ``input_ids``, ``sequence_ids``, ``within_seq_position_ids``,
            ``global_position_ids`` and ``labels`` (all of identical shape), the decoded
            ``context`` string and its token count ``context_len``.

        Raises:
            ValueError: If there are more sequences than ``max_num_sequences`` or any single
                sequence is rejected by ``prepare_singleseq``.
        """
        single_sequences = sequence.split(",")
        if len(single_sequences) > self.data_prep_config.max_num_sequences:
            raise ValueError(
                f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}"
                " in the provided multi-sequence instance. Please remove some homologous sequences before trying again."
            )

        single_sequence_encodings = [self.prepare_singleseq(single) for single in single_sequences]
        num_tokens = [len(enc["input_ids"]) for enc in single_sequence_encodings]
        input_ids = torch.cat([enc["input_ids"] for enc in single_sequence_encodings])
        labels = torch.cat([enc["labels"] for enc in single_sequence_encodings])
        within_seq_position_ids = torch.cat([enc["position_ids"] for enc in single_sequence_encodings])

        # Global positions: shift each sequence's own positions by a running offset that
        # advances past the largest position used so far.
        shifted_positions, offset = [], 0
        for enc in single_sequence_encodings:
            shifted_positions.append(enc["position_ids"] + offset)
            offset = max(offset, enc["position_ids"].max().item() + offset + 1)
        global_position_ids = torch.cat(shifted_positions)

        # With a single 1-D argument, repeat_interleave emits index i repeated num_tokens[i] times.
        sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens))

        # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired
        context_len = sum(num_tokens[:-1])
        context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False)
        if not self.preserve_context_labels:
            labels[:context_len] = self.pad_token_id

        assert (
            input_ids.shape
            == sequence_ids.shape
            == within_seq_position_ids.shape
            == global_position_ids.shape
            == labels.shape
        ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape"

        assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length"

        return {
            "input_ids": input_ids,
            "sequence_ids": sequence_ids,
            "within_seq_position_ids": within_seq_position_ids,
            "global_position_ids": global_position_ids,
            "labels": labels,
            "context": context,
            "context_len": context_len,
        }

    def prepare_singleseq(self, sequence: str) -> dict[str, torch.Tensor]:
        """Encode a single sequence, wrapped in boundary tokens, into ids and position ids.

        Args:
            sequence: Uppercase ASCII letters and/or the mask token ``?``.

        Returns:
            A dict with ``input_ids``, ``labels`` (an independent copy of ``input_ids``) and
            ``position_ids``.

        Raises:
            TypeError: If ``sequence`` is not a string.
            ValueError: If the sequence fails validation or is longer than
                ``max_num_positions_within_seq``.
        """
        if not self.validate_sequence(sequence):
            raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only")

        max_length = self.data_prep_config.max_num_positions_within_seq
        if len(sequence) > max_length:
            raise ValueError(f"Sequence length {len(sequence)} exceeds max length {max_length}")

        # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"<bos>1{sequence}2<eos>").ids)`
        # but following is faster since our vocabulary is simple.
        tokens = torch.tensor([self.vocab[token] for token in [_BOS_TOKEN, "1", *sequence, "2", _EOS_TOKEN]])
        position_ids = torch.arange(len(tokens))

        if self.data_prep_config.remove_X_tokens:
            # Keep every non-X token; position ids are filtered with the same index so the
            # surviving tokens retain their original positions.
            keep_positions = torch.where(tokens != self.X_token_id)[0]
            tokens = tokens[keep_positions]
            position_ids = position_ids[keep_positions]

        # Clone so that in-place edits to the labels can never leak into the input ids.
        return {"input_ids": tokens, "labels": tokens.clone(), "position_ids": position_ids}

    def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
        """Return a boolean mask that is True where ``tokens`` is one of the boundary token ids."""
        return torch.isin(tokens, self.boundary_token_ids.to(tokens.device))

    def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor:
        """Return a boolean mask that is True where ``tokens`` equals the mask token id."""
        return tokens == self.mask_token_id

    def validate_sequence(self, sequence: str) -> bool:
        """Check that ``sequence`` consists of uppercase ASCII letters and mask tokens only.

        A sequence that is empty, or empty once mask tokens are removed, is rejected.

        Args:
            sequence: Candidate sequence.

        Returns:
            True if the sequence is valid, False otherwise.

        Raises:
            TypeError: If ``sequence`` is not a string.
        """
        if not isinstance(sequence, str):
            raise TypeError("Sequence must be a string")
        residues = sequence.replace(self.mask_token, "")
        # isalpha()/isupper() alone accept non-ASCII letters such as "É", which are outside
        # the advertised [A-Z] alphabet and would only fail later in the vocabulary lookup.
        return residues.isascii() and residues.isalpha() and residues.isupper()