diff --git "a/modeling_esm_plusplus.py" "b/modeling_esm_plusplus.py" --- "a/modeling_esm_plusplus.py" +++ "b/modeling_esm_plusplus.py" @@ -1,2351 +1,2548 @@ -from __future__ import annotations - -import torch -import torch._inductor.config as inductor_config -import torch._dynamo as dynamo - -# Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs) -# Provides significant speedup with minimal precision loss -torch.set_float32_matmul_precision('high') - -# Enable TF32 for matrix multiplications and cuDNN operations -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True - -# Enable cuDNN autotuner - finds fastest algorithms for your hardware -# Best when input sizes are consistent; may slow down first iterations -torch.backends.cudnn.benchmark = True - -# Deterministic operations off for speed (set True if reproducibility needed) -torch.backends.cudnn.deterministic = False -inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM" - -dynamo.config.capture_scalar_outputs = True -torch._dynamo.config.recompile_limit = 16 - -import io -import os -import queue -import sqlite3 -import struct -import threading -import time - -import networkx as nx -import numpy as np -import torch -from tqdm.auto import tqdm -from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple -from torch.utils.data import DataLoader -from torch.utils.data import Dataset as TorchDataset -from transformers import PreTrainedTokenizerBase - - -# Compact blob serialization constants -# Canonical source: core/embed/blob.py. Keep in sync with protify/utils.py. -_COMPACT_VERSION = 0x01 -_DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2} -_CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32} -_CODE_TO_NP_DTYPE = {0: np.float16, 1: np.float16, 2: np.float32} - - -def tensor_to_embedding_blob(tensor: torch.Tensor) -> bytes: - """Serialize a tensor to compact binary format for SQLite blob storage. 
- - Format: [version:1][dtype_code:1][ndim:4][shape:4*ndim][raw_bytes] - bfloat16 tensors are stored as float16 bytes (numpy lacks bfloat16) - but tagged with dtype_code=1 so they can be cast back on read. - Falls back to torch.save for unsupported dtypes. - """ - t = tensor.cpu() - if t.dtype not in _DTYPE_TO_CODE: - buffer = io.BytesIO() - torch.save(t, buffer) - return buffer.getvalue() - dtype_code = _DTYPE_TO_CODE[t.dtype] - - if t.dtype == torch.bfloat16: - raw = t.half().numpy().tobytes() - else: - raw = t.numpy().tobytes() - - shape = t.shape - header = struct.pack(f' bytes: - """Build just the compact header for a given dtype and shape.""" - dtype_code = _DTYPE_TO_CODE[dtype] - return struct.pack(f' List[bytes]: - """Serialize a batch of same-shape tensors to compact blobs (fast path for vectors). - - Builds the header once and slices raw bytes per row. Much faster than - per-row tensor_to_embedding_blob calls for uniform-shape batches. - """ - assert batch.ndim >= 2, f"Expected batch with >= 2 dims, got {batch.ndim}" - t = batch.cpu() - store_dtype = t.dtype - if t.dtype not in _DTYPE_TO_CODE: - return [tensor_to_embedding_blob(t[i]) for i in range(t.shape[0])] - - if t.dtype == torch.bfloat16: - arr = t.half().numpy() - store_dtype = torch.bfloat16 - else: - arr = t.numpy() - - row_shape = tuple(t.shape[1:]) - header = _compact_header(store_dtype, row_shape) - raw = arr.tobytes() - stride = len(raw) // t.shape[0] - return [header + raw[i * stride:(i + 1) * stride] for i in range(t.shape[0])] - - -def embedding_blob_to_tensor(blob: bytes, fallback_shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor: - """Deserialize a blob back to a tensor. Auto-detects compact vs legacy formats.""" - if len(blob) >= 6 and blob[0] == _COMPACT_VERSION: - dtype_code = blob[1] - ndim = struct.unpack_from(' torch.nn.Module: - """Compile model with torch.compile if possible. 
def build_collator(
    tokenizer: "PreTrainedTokenizerBase",
    padding: str = 'max_length',
    max_length: int = 512,
) -> Callable[[List[str]], Dict[str, torch.Tensor]]:
    """Return a DataLoader ``collate_fn`` that tokenizes a list of sequences.

    The returned callable forwards the batch to ``tokenizer`` with
    ``return_tensors="pt"``, ``truncation=True`` and the given ``padding`` /
    ``max_length``. For any padding strategy other than ``'max_length'`` it
    additionally passes ``pad_to_multiple_of=8``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(sequences, **kwargs)``.
        padding: Padding strategy forwarded to the tokenizer.
        max_length: Truncation / padding length forwarded to the tokenizer.

    Returns:
        A function mapping ``List[str]`` to whatever ``tokenizer`` returns.
    """
    # The optional argument only depends on `padding`, so decide it once here
    # instead of on every batch.
    if padding == 'max_length':
        alignment_kwargs: Dict[str, Any] = {}
    else:
        alignment_kwargs = {'pad_to_multiple_of': 8}

    def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]:
        return tokenizer(
            sequences,
            return_tensors="pt",
            padding=padding,
            truncation=True,
            max_length=max_length,
            **alignment_kwargs,
        )

    return _collate_fn
- """ - total = len(dataloader) - if padding == 'max_length' or total <= n_warmup + n_calibration: - for i, batch in tqdm(enumerate(dataloader), total=total, desc='Embedding batches'): - yield i, batch - return - - dl_iter = iter(dataloader) - - # Phase 1: warmup on longest batches (first n_warmup, since sorted longest-first) - warmup_bar = tqdm(range(n_warmup), desc='Warmup (longest batches)', leave=False) - for i in warmup_bar: - batch = next(dl_iter) - yield i, batch - warmup_bar.close() - - # Phase 2: skip to middle of dataset for calibration timing - # We need to yield all intermediate batches too (they contain real data) - mid_start = total // 2 - intermediate_bar = tqdm( - range(n_warmup, mid_start), desc='Embedding batches', leave=False, - ) - for i in intermediate_bar: - batch = next(dl_iter) - yield i, batch - intermediate_bar.close() - - # Phase 3: time calibration batches from the middle - calibration_times: List[float] = [] - cal_bar = tqdm(range(n_calibration), desc='Calibrating ETA', leave=False) - for j in cal_bar: - t0 = time.perf_counter() - batch = next(dl_iter) - yield mid_start + j, batch - calibration_times.append(time.perf_counter() - t0) - cal_bar.close() - - avg_time = sum(calibration_times) / len(calibration_times) - remaining_start = mid_start + n_calibration - remaining_count = total - remaining_start - estimated_total_seconds = avg_time * remaining_count - - # Phase 4: remaining batches with calibrated ETA - main_bar = tqdm( - range(remaining_count), - desc='Embedding batches', - bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]', - ) - main_bar.set_postfix_str(f'ETA ~{estimated_total_seconds:.0f}s (calibrated)') - for k in main_bar: - batch = next(dl_iter) - yield remaining_start + k, batch - main_bar.close() - - -class _SQLWriter: - """Context manager for async SQL embedding writes. 
Matches core/embed/storage.SQLEmbeddingWriter.""" - - def __init__(self, conn: sqlite3.Connection, queue_maxsize: int = 4) -> None: - self._conn = conn - self._queue: queue.Queue = queue.Queue(maxsize=queue_maxsize) - self._thread: Optional[threading.Thread] = None - - def __enter__(self) -> "_SQLWriter": - self._thread = threading.Thread(target=self._writer_loop, daemon=True) - self._thread.start() - return self - - def write_batch(self, rows: List[Tuple[str, bytes]]) -> None: - self._queue.put(rows) - - def _writer_loop(self) -> None: - cursor = self._conn.cursor() - while True: - item = self._queue.get() - if item is None: - break - cursor.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item) - if self._queue.qsize() == 0: - self._conn.commit() - self._conn.commit() - - def __exit__(self, *exc) -> None: - if self._thread is not None: - self._queue.put(None) - self._thread.join() - self._thread = None - - -class Pooler: - def __init__(self, pooling_types: List[str]) -> None: - self.pooling_types = pooling_types - self.pooling_options: Dict[str, Callable] = { - 'mean': self.mean_pooling, - 'max': self.max_pooling, - 'norm': self.norm_pooling, - 'median': self.median_pooling, - 'std': self.std_pooling, - 'var': self.var_pooling, - 'cls': self.cls_pooling, - 'parti': self._pool_parti, - } - - def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: - assert isinstance(attentions, torch.Tensor) - maxed_attentions = torch.max(attentions, dim=1)[0] - return maxed_attentions - - def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]: - G = self._convert_to_graph(attention_matrix) - if G.number_of_nodes() != attention_matrix.shape[0]: - raise Exception( - f"The number of nodes in the graph should be equal to the number of tokens in sequence! 
You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") - if G.number_of_edges() == 0: - raise Exception(f"You don't seem to have any attention edges left in the graph.") - - return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) - - def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph: - G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) - return G - - def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray: - if attention_mask is not None: - for k in list(dict_importance.keys()): - if attention_mask[k] == 0: - del dict_importance[k] - - total = sum(dict_importance.values()) - return np.array([v / total for _, v in dict_importance.items()]) - - def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() - emb_pooled = [] - for e, a, mask in zip(emb, maxed_attentions, attention_mask): - dict_importance = self._page_rank(a) - importance_weights = self._calculate_importance_weights(dict_importance, mask) - num_tokens = int(mask.sum().item()) - emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) - pooled = torch.tensor(np.array(emb_pooled)) - return pooled - - def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: - if attention_mask is None: - return emb.mean(dim=1) - else: - attention_mask = attention_mask.unsqueeze(-1) - return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) - - def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: - if attention_mask is None: - return emb.max(dim=1).values - else: - mask = attention_mask.unsqueeze(-1).bool() - return 
def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
    """Pool token embeddings by their variance along dim 1.

    Without a mask this is ``emb.var(dim=1)``. With a mask, the mean and the
    squared deviations are accumulated over unmasked positions only and
    divided by the number of unmasked positions.

    NOTE(review): the masked branch divides by the token count, while
    ``Tensor.var`` applies Bessel's correction by default, so the two
    branches differ even for an all-ones mask. Confirm this is intended.

    Args:
        emb: Embeddings; dim 1 is the axis being reduced.
        attention_mask: Optional mask, broadcast against ``emb`` after a
            trailing axis is added.
        **kwargs: Accepted for signature parity with other poolers; unused.
    """
    if attention_mask is None:
        return emb.var(dim=1)

    weights = attention_mask.unsqueeze(-1)
    token_count = weights.sum(dim=1)
    centre = ((emb * weights).sum(dim=1) / token_count).unsqueeze(1)
    return ((emb - centre) ** 2 * weights).sum(dim=1) / token_count
- ) - final_emb: List[torch.Tensor] = [] - for pooling_type in self.pooling_types: - final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) - return torch.cat(final_emb, dim=-1) - - -class ProteinDataset(TorchDataset): - """Simple dataset for protein sequences.""" - def __init__(self, sequences: List[str]) -> None: - self.sequences = sequences - - def __len__(self) -> int: - return len(self.sequences) - - def __getitem__(self, idx: int) -> str: - return self.sequences[idx] - - -def parse_fasta(fasta_path: str) -> List[str]: - assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}" - sequences = [] - current_seq = [] - with open(fasta_path, 'r') as f: - for line in f: - line = line.strip() - if not line: - continue - if line.startswith('>'): - if current_seq: - sequences.append(''.join(current_seq)) - current_seq = [] - else: - current_seq.append(line) - if current_seq: - sequences.append(''.join(current_seq)) - return sequences - - -class EmbeddingMixin: - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - raise NotImplementedError - - @property - def device(self) -> torch.device: - """Get the device of the model.""" - return next(self.parameters()).device - - def _read_sequences_from_db(self, db_path: str) -> Set[str]: - """Read sequences from SQLite database.""" - with sqlite3.connect(db_path, timeout=30) as conn: - c = conn.cursor() - c.execute("SELECT sequence FROM embeddings") - return {row[0] for row in c.fetchall()} - - def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: - cursor = conn.cursor() - cursor.execute( - "CREATE TABLE IF NOT EXISTS embeddings (" - "sequence TEXT PRIMARY KEY, " - "embedding BLOB NOT NULL" - ")" - ) - conn.commit() - - def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]: - assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" 
def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]:
    """Load embeddings from the SQLite ``embeddings`` table.

    Args:
        db_path: Path of an existing SQLite database file.
        sequences: When ``None`` every row is loaded; otherwise only rows
            whose ``sequence`` is in this collection. An empty collection
            yields an empty dict.

    Returns:
        Mapping from sequence to the tensor decoded from its blob by
        ``embedding_blob_to_tensor``.

    Raises:
        AssertionError: if ``db_path`` does not exist.
    """
    assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}"
    # SQLite caps the number of bound parameters per statement (999 by
    # default in older builds), so the IN (...) lookup is issued in chunks.
    max_params_per_query = 900
    loaded: Dict[str, torch.Tensor] = {}

    conn = sqlite3.connect(db_path, timeout=30)
    try:
        self._ensure_embeddings_table(conn)
        cursor = conn.cursor()
        if sequences is None:
            cursor.execute("SELECT sequence, embedding FROM embeddings")
            for sequence, embedding_bytes in cursor.fetchall():
                loaded[sequence] = embedding_blob_to_tensor(embedding_bytes)
            return loaded

        wanted = list(sequences)
        for start in range(0, len(wanted), max_params_per_query):
            chunk = wanted[start:start + max_params_per_query]
            placeholders = ",".join(["?"] * len(chunk))
            cursor.execute(
                f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})",
                tuple(chunk),
            )
            for sequence, embedding_bytes in cursor.fetchall():
                loaded[sequence] = embedding_blob_to_tensor(embedding_bytes)
    finally:
        # The sqlite3 connection context manager does not close the
        # connection, so close it here on every path.
        conn.close()
    return loaded
- - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. - - Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via - `fasta_path`, or both (the two sources are combined). At least one must be provided. - """ - if fasta_path is not None: - fasta_sequences = parse_fasta(fasta_path) - sequences = list(sequences or []) + fasta_sequences - assert sequences is not None and len(sequences) > 0, \ - "Must provide at least one sequence via `sequences` or `fasta_path`." - sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) - sequences = sorted(sequences, key=len, reverse=True) - hidden_size = self.config.hidden_size - pooler = Pooler(pooling_types) if not full_embeddings else None - if tokenizer is None and self.config.model_type != "E1": - tokenizer = self.tokenizer - tokenizer_mode = tokenizer is not None - - # Resolve padding and compilation - dynamic = padding == 'longest' - compiled_model = maybe_compile(self, dynamic=dynamic) - - if tokenizer_mode: - collate_fn = build_collator(tokenizer, padding=padding, max_length=max_len) - device = self.device - else: - collate_fn = None - device = None - - def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - assert isinstance(residue_embeddings, torch.Tensor) - if full_embeddings or residue_embeddings.ndim == 2: - return residue_embeddings - return pooler(residue_embeddings, attention_mask) - - def iter_batches(to_embed: List[str]): - if tokenizer_mode: - assert collate_fn is not None - assert device is not None - dataset = ProteinDataset(to_embed) - dataloader = DataLoader( - dataset, - batch_size=batch_size, - num_workers=num_workers, - prefetch_factor=2 if num_workers > 0 else None, - collate_fn=collate_fn, - shuffle=False, - pin_memory=True, - ) - for i, batch in _make_embedding_progress(dataloader, padding): - seqs = to_embed[i * batch_size:(i + 1) * 
batch_size] - input_ids = batch['input_ids'].to(device) - attention_mask = batch['attention_mask'].to(device) - residue_embeddings = compiled_model._embed(input_ids, attention_mask) - yield seqs, residue_embeddings, attention_mask - else: - for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): - seqs = to_embed[batch_start:batch_start + batch_size] - batch_output = compiled_model._embed(seqs, return_attention_mask=True, **kwargs) - assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." - assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." - residue_embeddings, attention_mask = batch_output - assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." - yield seqs, residue_embeddings, attention_mask - - if sql: - # Step 1: DEDUPLICATE - check existing embeddings in SQL - conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False) - conn.execute('PRAGMA journal_mode=WAL') - conn.execute('PRAGMA busy_timeout=30000') - conn.execute('PRAGMA synchronous=OFF') - conn.execute('PRAGMA cache_size=-64000') - self._ensure_embeddings_table(conn) - already_embedded = self._read_sequences_from_db(sql_db_path) - to_embed = [seq for seq in sequences if seq not in already_embedded] - print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") - print(f"Embedding {len(to_embed)} new sequences") - if len(to_embed) > 0: - # Steps 4-7: BATCH+EMBED -> POOL/TRIM -> SERIALIZE -> WRITE (async) - with _SQLWriter(conn) as writer: - with torch.inference_mode(): - for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): - embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) - if full_embeddings: - batch_rows = [] - for seq, emb, mask in zip(seqs, embeddings, attention_mask): - batch_rows.append((seq, 
tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size)))) - else: - blobs = batch_tensor_to_blobs(embeddings) - batch_rows = list(zip(seqs, blobs)) - writer.write_batch(batch_rows) - conn.close() - return None - - embeddings_dict = {} - if os.path.exists(save_path): - embeddings_dict = self.load_embeddings_from_pth(save_path) - to_embed = [seq for seq in sequences if seq not in embeddings_dict] - print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") - print(f"Embedding {len(to_embed)} new sequences") - else: - to_embed = sequences - print(f"Embedding {len(to_embed)} new sequences") - - if len(to_embed) > 0: - with torch.inference_mode(): - for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): - embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) - for seq, emb, mask in zip(seqs, embeddings, attention_mask): - if full_embeddings: - emb = emb[mask.bool()].reshape(-1, hidden_size) - embeddings_dict[seq] = emb.cpu() - - if save: - torch.save(embeddings_dict, save_path) - - return embeddings_dict - - -if __name__ == "__main__": - # py -m pooler - pooler = Pooler(pooling_types=['max', 'parti']) - batch_size = 8 - seq_len = 64 - hidden_size = 128 - num_layers = 12 - emb = torch.randn(batch_size, seq_len, hidden_size) - attentions = torch.randn(batch_size, num_layers, seq_len, seq_len) - attention_mask = torch.ones(batch_size, seq_len) - y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions) - print(y.shape) - -"""Shared attention infrastructure for all FastPLMs models. - -Contains: AttentionBackend enum, backend resolution, mask creation, -flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities. 
-""" -from enum import Enum -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from torch.nn import functional as F -from einops import rearrange - -try: - from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask -except ImportError: - create_block_mask = None - flex_attention = None - BlockMask = None - -_compiled_flex_attention = None - - -def _get_flex_attention_fn(): - """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.""" - global _compiled_flex_attention - if flex_attention is None: - return None - flex_mod = torch.nn.attention.flex_attention - if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False): - return flex_attention - if _compiled_flex_attention is None: - _compiled_flex_attention = torch.compile( - flex_attention, - dynamic=False, - ) - return _compiled_flex_attention - - -### Kernels Flash Attention Detection -def _infer_kernels_flash_variant(kernel) -> Optional[str]: - if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"): - return "flash_attn2" - if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"): - return "flash_attn3" - return None - - -def _try_get_kernels_flash(): - try: - from kernels import get_kernel - except ImportError: - return None, None - - flash_kernel = None - flash_kernel_variant = None - try: - flash_kernel = get_kernel("kernels-community/flash-attn3") - flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) - assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API." - except Exception: - try: - flash_kernel = get_kernel("kernels-community/flash-attn2") - flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) - assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API." 
def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Run the loaded flash-attention kernel on dense (unpadded) inputs.

    With ``softmax_scale=None`` the kernel applies its own default of
    ``1 / sqrt(head_dim)``. Callers that have already scaled Q (the
    convention used by ESM2, DPLM, DPLM2, E1, ESMFold) must pass
    ``softmax_scale=1.0``. Otherwise Q is scaled twice and the error
    compounds across layers: on DPLM-150M (30 layers) this was observed as
    pooled-embedding cosine ~-0.12 and argmax agreement ~0.27 vs sdpa.

    Raises:
        AssertionError: if no flash kernel is loaded or its variant is not
            one of ``"flash_attn2"`` / ``"flash_attn3"``.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."

    if FLASH_KERNEL_VARIANT == "flash_attn2":
        outputs = FLASH_KERNEL.fwd(
            q=query_states, k=key_states, v=value_states,
            softmax_scale=softmax_scale, is_causal=causal,
        )
        return outputs[0]

    if FLASH_KERNEL_VARIANT != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")

    attn_fn = FLASH_KERNEL.flash_attn_func
    try:
        output = attn_fn(
            q=query_states, k=key_states, v=value_states,
            softmax_scale=softmax_scale, causal=causal,
        )
    except TypeError:
        # Keyword call rejected; retry positionally.
        # NOTE(review): the 0.0 presumably fills a dropout slot preceding
        # softmax_scale in that build's signature — confirm.
        output = attn_fn(
            query_states, key_states, value_states,
            0.0, softmax_scale, causal,
        )
    # Some builds return a tuple whose first element is the attention output.
    return output[0] if isinstance(output, tuple) else output
- if FLASH_KERNEL_VARIANT == "flash_attn2": - return FLASH_KERNEL.varlen_fwd( - q=query_states, k=key_states, v=value_states, - cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, - softmax_scale=softmax_scale, is_causal=causal, - )[0] - if FLASH_KERNEL_VARIANT == "flash_attn3": - try: - output = FLASH_KERNEL.flash_attn_varlen_func( - q=query_states, k=key_states, v=value_states, - cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, - softmax_scale=softmax_scale, causal=causal, - ) - except TypeError: - output = FLASH_KERNEL.flash_attn_varlen_func( - query_states, key_states, value_states, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_in_batch_q, max_seqlen_in_batch_k, - 0.0, softmax_scale, causal, - ) - if isinstance(output, tuple): - return output[0] - return output - raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") - - -### Unpad / Pad helpers for varlen flash attention -class IndexFirstAxis(torch.autograd.Function): - @staticmethod - def forward(ctx, input, indices) -> torch.Tensor: - ctx.save_for_backward(indices) - assert input.ndim >= 2 - ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] - second_dim = other_shape.numel() - return torch.gather( - rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim) - ).reshape(-1, *other_shape) - - @staticmethod - def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]: - (indices,) = ctx.saved_tensors - assert grad_output.ndim >= 2 - other_shape = grad_output.shape[1:] - grad_output = rearrange(grad_output, "b ... 
class IndexPutFirstAxis(torch.autograd.Function):
    """Autograd function that scatters rows of ``values`` into a zero tensor.

    Forward builds a ``(first_axis_dim, *values.shape[1:])`` tensor of zeros
    with the dtype and device of ``values`` and writes ``values`` at
    ``indices`` along axis 0. Backward gathers the incoming gradient at the
    same indices; ``indices`` and ``first_axis_dim`` receive no gradient.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded_shape = (first_axis_dim,) + tuple(values.shape[1:])
        padded = torch.zeros(padded_shape, dtype=values.dtype, device=values.device)
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        indices = ctx.saved_tensors[0]
        grad_values = grad_output[indices]
        return grad_values, None, None
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Drop masked-out positions from Q/K/V for the varlen flash kernel.

    ``query_layer`` is unpacked as ``(batch, seq_len, num_heads, head_dim)``;
    key and value are reshaped with those same sizes. Each tensor is
    flattened over batch and sequence, and only rows where
    ``attention_mask_2d`` is non-zero are kept.

    Returns:
        ``(q, k, v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen))``
        where ``indices`` are the kept flat positions, ``cu_seqlens`` is the
        int32 cumulative sum of per-row lengths with a leading 0, and
        ``max_seqlen`` is the largest per-row length as a Python int.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape

    lengths = attention_mask_2d.sum(dim=1).int()
    cumulative = F.pad(lengths.cumsum(0, dtype=torch.int32), (1, 0))
    longest = int(lengths.max().item())
    kept = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()

    def _select(layer: torch.Tensor) -> torch.Tensor:
        # Collapse (batch, seq) into one axis, then keep the unmasked rows.
        flat = layer.reshape(batch_size * seq_len, num_heads, head_dim)
        return index_first_axis(flat, kept)

    unpadded_q = _select(query_layer)
    unpadded_k = _select(key_layer)
    unpadded_v = _select(value_layer)
    return unpadded_q, unpadded_k, unpadded_v, kept, (cumulative, cumulative), (longest, longest)
- if not causal and attention_mask_2d is not None: - batch_size, q_len = query_states.shape[:2] - ( - query_states, key_states, value_states, - indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k), - ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d) - attn_output_unpad = _kernels_flash_varlen_forward( - query_states=query_states, key_states=key_states, value_states=value_states, - cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k, - softmax_scale=softmax_scale, - ) - return pad_input(attn_output_unpad, indices_q, batch_size, q_len) - else: - return _kernels_flash_forward( - query_states=query_states, key_states=key_states, value_states=value_states, - causal=causal, softmax_scale=softmax_scale, - ) - - -### Attention Backend Enum & Resolution -class AttentionBackend(Enum): - AUTO = "auto" - KERNELS_FLASH = "kernels_flash" - FLEX = "flex" - SDPA = "sdpa" - - -VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend) - - -_BACKEND_CONFIRMED = False - - -def resolve_attention_backend(requested_backend: str) -> AttentionBackend: - global _BACKEND_CONFIRMED - assert requested_backend in VALID_ATTENTION_BACKENDS, ( - f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}." - ) - if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value): - _ensure_flash_kernels_loaded() - if requested_backend == AttentionBackend.AUTO.value: - if FLASH_KERNEL is not None: - resolved = AttentionBackend.KERNELS_FLASH - elif flex_attention is not None: - resolved = AttentionBackend.FLEX - else: - resolved = AttentionBackend.SDPA - elif requested_backend == AttentionBackend.KERNELS_FLASH.value: - assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment." 
- resolved = AttentionBackend.KERNELS_FLASH - elif requested_backend == AttentionBackend.FLEX.value: - assert flex_attention is not None, "Flex Attention is not available in this environment." - resolved = AttentionBackend.FLEX - elif requested_backend == AttentionBackend.SDPA.value: - resolved = AttentionBackend.SDPA - else: - raise AssertionError(f"Unsupported attention backend: {requested_backend}") - if not _BACKEND_CONFIRMED: - print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'") - _BACKEND_CONFIRMED = True - return resolved - - -@torch.compiler.disable -def get_attention_mask( - effective_backend: AttentionBackend, - batch_size: int, - seq_len: int, - device: torch.device, - attention_mask: Optional[torch.Tensor] = None, -) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]: - """Build padding masks once for all encoder layers. - - Returns (attention_mask_2d, attention_mask_4d, flex_block_mask). - """ - if attention_mask is None: - return None, None, None - - attention_mask_2d = attention_mask.bool() - - if effective_backend == AttentionBackend.KERNELS_FLASH: - return attention_mask_2d, None, None - - if effective_backend == AttentionBackend.FLEX: - assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable." - valid_lens = attention_mask_2d.sum(dim=-1) - - def mask_mod(batch_idx, head_idx, q_idx, kv_idx): - return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx]) - - flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device) - return attention_mask_2d, None, flex_block_mask - - # SDPA / manual -- only mask the key dimension so padding query positions attend to - # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf). 
- attention_mask_4d = attention_mask_2d[:, None, None, :] - return attention_mask_2d, attention_mask_4d, None - - -def bool_to_additive_mask( - bool_mask: torch.Tensor, - dtype: torch.dtype, -) -> torch.Tensor: - """Convert a bool mask (True = valid) to a float additive mask (0.0 valid, -inf invalid). - - Why this exists: calling `bool_mask.masked_fill(bool_mask.logical_not(), float('-inf'))` - directly on a bool tensor returns a bool tensor -- because `-inf` casts to `True` -- and - silently drops the mask entirely. Always allocate a float tensor first, then fill it. - This helper is the sanctioned way to build an SDPA additive mask from a bool validity mask. - """ - assert bool_mask.dtype == torch.bool, ( - f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}" - ) - additive = torch.zeros_like(bool_mask, dtype=dtype) - additive.masked_fill_(bool_mask.logical_not(), float("-inf")) - return additive - -""" -ESM++ model implementation. - -ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility -The ESM Python package is not required - -Modified from https://github.com/Biohub/esm -License: https://github.com/Biohub/esm/blob/main/LICENSE.md -""" - -import math -import os -import json -import torch -import torch.nn as nn -import torch.nn.functional as F -from dataclasses import dataclass -from functools import cache, partial -from pathlib import Path -from typing import Optional, Tuple, Union, List -from einops import rearrange, repeat -from huggingface_hub import snapshot_download -from safetensors.torch import load_file as load_safetensors_file -from tokenizers import Tokenizer -from tokenizers.models import BPE -from tokenizers.processors import TemplateProcessing -from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig -from transformers.modeling_outputs import ModelOutput - - - -class ESMplusplusConfig(PretrainedConfig): - """Configuration class for ESM++ model. 
- - Args: - vocab_size: Size of the vocabulary - hidden_size: Dimension of hidden layers - num_attention_heads: Number of attention heads - num_hidden_layers: Number of transformer layers - num_labels: Number of output labels for classification - problem_type: Type of problem - regression, single/multi label classification - """ - model_type = "ESMplusplus" - def __init__( - self, - vocab_size: int = 64, - hidden_size: int = 960, - num_attention_heads: int = 15, - num_hidden_layers: int = 30, - num_labels: int = 2, - problem_type: Optional[str] = None, - dropout: float = 0.0, - initializer_range: float = 0.02, - attn_backend: str = "sdpa", - **kwargs, - ): - super().__init__(**kwargs) - self.vocab_size = vocab_size - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.num_hidden_layers = num_hidden_layers - self.num_labels = num_labels - self.problem_type = problem_type - self.dropout = dropout - self.initializer_range = initializer_range - self.tie_word_embeddings = False - self.attn_backend = attn_backend - - -### Rotary Embeddings -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - """Rotates half the hidden dims of the input.""" - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... 
(d two)", two=2 - ) - - -def apply_rotary_emb_torch( - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - interleaved: bool = False, - _inplace: bool = False, -) -> torch.Tensor: - """Apply rotary embeddings to input based on cos and sin.""" - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - seqlen = x.size(1) - cos = cos[:seqlen] - sin = sin[:seqlen] - cos = repeat(cos, "s d -> s 1 (2 d)") - sin = repeat(sin, "s d -> s 1 (2 d)") - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -class RotaryEmbedding(torch.nn.Module): - """Rotary position embeddings. - - Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding" - - Args: - dim: Dimension of the embedding - base: Base for computing angular frequencies - interleaved: Whether to use interleaved rotations - scale_base: Base for scaling - scaling_factor: Factor for scaling positions - pos_idx_in_fp32: Whether to compute position indices in fp32 - device: Computation device - """ - def __init__( - self, - dim: int, - base: float = 10000.0, - interleaved: bool = False, - scale_base: Optional[float] = None, - scaling_factor: float = 1.0, - pos_idx_in_fp32: bool = True, - device: Optional[torch.device] = None, - ): - super().__init__() - self.dim = dim - self.base = float(base) - self.pos_idx_in_fp32 = pos_idx_in_fp32 - self.interleaved = interleaved - self.scale_base = scale_base - self.scaling_factor = scaling_factor - self.device = device - - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - self.reset_parameters() - - def reset_parameters(self): - """Reset the parameters of the embedding.""" - if "inv_freq" in self._buffers and isinstance(self._buffers["inv_freq"], torch.Tensor): - buffer_device = self._buffers["inv_freq"].device - else: - buffer_device = self.device - inv_freq = 
self._compute_inv_freq(buffer_device) - self._seq_len_cached = 0 - self._cos_cached = None - self._sin_cached = None - self._cos_k_cached = None - self._sin_k_cached = None - self.register_buffer("inv_freq", inv_freq, persistent=False) - arange = torch.arange(0, self.dim, 2, device=buffer_device, dtype=torch.float32) - scale = ( - (arange + 0.4 * self.dim) / (1.4 * self.dim) - if self.scale_base is not None - else None - ) - self.register_buffer("scale", scale) - - def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor: - """Compute inverse frequency bands. - - Always computes on CPU then moves to the requested device. This matches - native Biohub ESMC, which computes inv_freq on CPU at - `__init__` and migrates via `.to(device)`. Computing directly on GPU - gives a ~3.7e-9 bit-level difference in inv_freq (fp32 transcendental - precision differs between CPU and GPU), which compounds through the 30 - attention layers to ~1e-3 mse divergence from native at - `hidden_states[-2]`. See testing/parity_debug_rotary.py. 
- """ - cpu_inv_freq = 1 / ( - self.base - ** ( - torch.arange(0, self.dim, 2, device="cpu", dtype=torch.float32) - / self.dim - ) - ) - if device is not None and torch.device(device).type != "cpu": - return cpu_inv_freq.to(device) - return cpu_inv_freq - - def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): - """Update the cached cosine and sine values.""" - if ( - seqlen > self._seq_len_cached - or self._cos_cached is None - or self._cos_cached.device != device - or self._cos_cached.dtype != dtype - or (self.training and self._cos_cached.is_inference()) - ): - self._seq_len_cached = seqlen - if self.pos_idx_in_fp32: - t = torch.arange(seqlen, device=device, dtype=torch.float32) - t /= self.scaling_factor - if self.inv_freq.dtype != torch.float32: - inv_freq = self.inv_freq.to(torch.float32) - else: - inv_freq = self.inv_freq - else: - t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) - t /= self.scaling_factor - inv_freq = self.inv_freq - freqs = torch.outer(t, inv_freq) - - if self.scale is None: - self._cos_cached = torch.cos(freqs).to(dtype) - self._sin_cached = torch.sin(freqs).to(dtype) - else: - power = ( - torch.arange( - seqlen, dtype=self.scale.dtype, device=self.scale.device - ) - - seqlen // 2 - ) / self.scale_base - scale = self.scale.to(device=power.device) ** power.unsqueeze(-1) - self._cos_cached = (torch.cos(freqs) * scale).to(dtype) - self._sin_cached = (torch.sin(freqs) * scale).to(dtype) - self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) - self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) - - def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """Apply rotary embeddings to queries and keys. 
- - Args: - q: Query tensor of shape (batch, seqlen, nheads, headdim) - k: Key tensor of shape (batch, seqlen, nheads, headdim) - - Returns: - Tuple of rotated query and key tensors - """ - # NOTE: do NOT recompute inv_freq here if device has changed. The native - # ESMC implementation computes inv_freq once on CPU at __init__ and - # relies on PyTorch's `.to(device)` to migrate the buffer. Recomputing - # the values directly on GPU gives a ~3.7e-9 bit-level difference vs the - # CPU-computed-then-moved values due to fp32 transcendental precision, - # which compounds through 30 attention layers to ~1e-3 mse divergence - # from native at `hidden_states[-2]`. See testing/parity_debug_rotary.py. - self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype) - assert self._cos_cached is not None - assert self._sin_cached is not None - if self.scale is None: - return ( - apply_rotary_emb_torch( - q, - self._cos_cached, - self._sin_cached, - self.interleaved, - True, # inplace=True - ), - apply_rotary_emb_torch( - k, - self._cos_cached, - self._sin_cached, - self.interleaved, - True, # inplace=True - ), - ) # type: ignore - else: - assert False - - -### Feedforward Network Components -def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int: - """Compute corrected dimension for SwiGLU.""" - return int(((expansion_ratio * d_model) + 255) // 256 * 256) - - -class SwiGLU(nn.Module): - """SwiGLU activation function.""" - def __init__(self): - super(SwiGLU, self).__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x1, x2 = x.chunk(2, dim=-1) - return F.silu(x1) * x2 - - -def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential: - """Create SwiGLU feedforward network with layer normalization.""" - return nn.Sequential( - nn.LayerNorm(d_model), - nn.Linear( - d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False - ), - SwiGLU(), - nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, 
bias=False), - ) - - -### Attention -class MultiHeadAttention(nn.Module): - """Multi-head attention with rotary embeddings and configurable backend. - - Args: - d_model: Model dimension - n_heads: Number of attention heads - attn_backend: One of "auto", "kernels_flash", "flex", "sdpa" - """ - def __init__( - self, - d_model: int, - n_heads: int, - attn_backend: str = "sdpa", - ): - super().__init__() - self.d_model = d_model - self.n_heads = n_heads - self.d_head = self.d_model // self.n_heads - self.scale = 1.0 / math.sqrt(self.d_head) - self.attn_backend = resolve_attention_backend(attn_backend) - self.layernorm_qkv = nn.Sequential( - nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False) - ) - self.out_proj = nn.Linear(d_model, d_model, bias=False) - self.q_ln = nn.LayerNorm(d_model, bias=False) - self.k_ln = nn.LayerNorm(d_model, bias=False) - self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads) - self.rotary = RotaryEmbedding(d_model // n_heads) - - def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - q = q.unflatten(-1, (self.n_heads, self.d_head)) - k = k.unflatten(-1, (self.n_heads, self.d_head)) - q, k = self.rotary(q, k) - q = q.flatten(-2, -1) - k = k.flatten(-2, -1) - return q, k - - def forward( - self, - x: torch.Tensor, - attention_mask_2d: Optional[torch.Tensor] = None, - attention_mask_4d: Optional[torch.Tensor] = None, - flex_block_mask: Optional[BlockMask] = None, - output_attentions: bool = False, - output_s_max: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: - qkv_BLD3 = self.layernorm_qkv(x) - query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1) - query_BLD, key_BLD = ( - self.q_ln(query_BLD).to(query_BLD.dtype), - self.k_ln(key_BLD).to(query_BLD.dtype), - ) - query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD) - query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, 
value_BLD)) - - attn_output, attn_weights, s_max = self._attn( - query_BHLD, key_BHLD, value_BHLD, - attention_mask_2d=attention_mask_2d, - attention_mask_4d=attention_mask_4d, - flex_block_mask=flex_block_mask, - output_attentions=output_attentions, - output_s_max=output_s_max, - ) - - output = self.out_proj(attn_output) - return output, attn_weights, s_max - - def _attn( - self, - query_BHLD: torch.Tensor, - key_BHLD: torch.Tensor, - value_BHLD: torch.Tensor, - attention_mask_2d: Optional[torch.Tensor] = None, - attention_mask_4d: Optional[torch.Tensor] = None, - flex_block_mask: Optional[BlockMask] = None, - output_attentions: bool = False, - output_s_max: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: - if output_attentions: - return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max) - - if self.attn_backend == AttentionBackend.KERNELS_FLASH: - attn_output, attn_weights = self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d) - elif self.attn_backend == AttentionBackend.FLEX: - attn_output, attn_weights = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask) - elif self.attn_backend == AttentionBackend.SDPA: - attn_output, attn_weights = self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d) - else: - raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}") - - s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None - return attn_output, attn_weights, s_max - - @torch.no_grad() - def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]: - q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1) - k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1) - s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * self.scale - return [s_max_bound[h] for h in range(self.n_heads)] - - def _manual_attn( - self, - query_BHLD: 
torch.Tensor, - key_BHLD: torch.Tensor, - value_BHLD: torch.Tensor, - attention_mask_4d: Optional[torch.Tensor] = None, - output_s_max: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]: - attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * self.scale - if attention_mask_4d is not None: - attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf")) - attn_weights = F.softmax(attn_weights, dim=-1) - context_BHLD = torch.matmul(attn_weights, value_BHLD) - attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)") - s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None - return attn_output, attn_weights, s_max - - def _kernels_flash_attn( - self, - query_BHLD: torch.Tensor, - key_BHLD: torch.Tensor, - value_BHLD: torch.Tensor, - attention_mask_2d: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, None]: - query_BLHD = query_BHLD.transpose(1, 2).contiguous() - key_BLHD = key_BHLD.transpose(1, 2).contiguous() - value_BLHD = value_BHLD.transpose(1, 2).contiguous() - attn_output = kernels_flash_attention_func( - query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD, - attention_mask_2d=attention_mask_2d, causal=False, - ) - return rearrange(attn_output, "b s h d -> b s (h d)"), None - - def _flex_attn( - self, - query_BHLD: torch.Tensor, - key_BHLD: torch.Tensor, - value_BHLD: torch.Tensor, - flex_block_mask: Optional[BlockMask] = None, - ) -> Tuple[torch.Tensor, None]: - assert flex_attention is not None, "Flex attention is not available in this environment." 
- fn = _get_flex_attention_fn() - context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale) - return rearrange(context_BHLD, "b h s d -> b s (h d)"), None - - def _sdpa_attn( - self, - query_BHLD: torch.Tensor, - key_BHLD: torch.Tensor, - value_BHLD: torch.Tensor, - attention_mask_4d: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, None]: - context_BHLD = F.scaled_dot_product_attention( - query_BHLD, key_BHLD, value_BHLD, attn_mask=attention_mask_4d, scale=self.scale, - ) - return rearrange(context_BHLD, "b h s d -> b s (h d)"), None - - -### Regression Head -def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module: - """Create a regression head with optional hidden dimension. - - Args: - d_model: Input dimension - output_dim: Output dimension - hidden_dim: Optional hidden dimension (defaults to d_model) - """ - hidden_dim = hidden_dim if hidden_dim is not None else d_model - return nn.Sequential( - nn.Linear(d_model, hidden_dim), - nn.GELU(), - nn.LayerNorm(hidden_dim), - nn.Linear(hidden_dim, output_dim), - ) - - -### Transformer Block -class UnifiedTransformerBlock(nn.Module): - """Transformer block with attention and feedforward layers.""" - def __init__( - self, - d_model: int, - n_heads: int, - residue_scaling_factor: float = 1, - expansion_ratio: float = 8 / 3, - dropout: float = 0.0, - attn_backend: str = "sdpa", - ): - super().__init__() - self.attn = MultiHeadAttention(d_model=d_model, n_heads=n_heads, attn_backend=attn_backend) - self.ffn = swiglu_ln_ffn(d_model, expansion_ratio) - self.scaling_factor = residue_scaling_factor - self.dropout = nn.Dropout(dropout) - - def forward( - self, - x: torch.Tensor, - attention_mask_2d: Optional[torch.Tensor] = None, - attention_mask_4d: Optional[torch.Tensor] = None, - flex_block_mask: Optional[BlockMask] = None, - output_attentions: bool = False, - output_s_max: bool = False, - ) -> Tuple[torch.Tensor, 
Optional[torch.Tensor], Optional[List[torch.Tensor]]]: - attn_output, attn_weights, s_max = self.attn( - x, - attention_mask_2d=attention_mask_2d, - attention_mask_4d=attention_mask_4d, - flex_block_mask=flex_block_mask, - output_attentions=output_attentions, - output_s_max=output_s_max, - ) - x = x + self.dropout(attn_output) / self.scaling_factor - x = x + self.dropout(self.ffn(x)) / self.scaling_factor - return x, attn_weights, s_max - - -### Model Outputs -@dataclass -class TransformerOutput(ModelOutput): - """Output type for transformer encoder.""" - last_hidden_state: Optional[torch.Tensor] = None - hidden_states: Optional[Tuple[torch.Tensor]] = None - attentions: Optional[Tuple[torch.Tensor]] = None - s_max: Optional[Tuple[List[torch.Tensor], ...]] = None - - -@dataclass -class ESMplusplusOutput(ModelOutput): - """Output type for ESM++ models.""" - loss: Optional[torch.Tensor] = None - logits: Optional[torch.Tensor] = None - last_hidden_state: Optional[torch.Tensor] = None - hidden_states: Optional[Tuple[torch.Tensor]] = None - attentions: Optional[Tuple[torch.Tensor]] = None - s_max: Optional[Tuple[List[torch.Tensor], ...]] = None - - -### Transformer Stack -class TransformerStack(nn.Module): - """Stack of transformer blocks.""" - def __init__( - self, - d_model: int, - n_heads: int, - n_layers: int, - dropout: float = 0.0, - attn_backend: str = "sdpa", - ): - super().__init__() - self.attention_backend = resolve_attention_backend(attn_backend) - self.blocks = nn.ModuleList( - [ - UnifiedTransformerBlock( - d_model, - n_heads, - residue_scaling_factor=math.sqrt(n_layers / 36), - dropout=dropout, - attn_backend=attn_backend, - ) - for i in range(n_layers) - ] - ) - self.norm = nn.LayerNorm(d_model, bias=False) - self.gradient_checkpointing = False - - @property - def attn_backend(self) -> AttentionBackend: - return self.attention_backend - - @attn_backend.setter - def attn_backend(self, backend: str) -> None: - resolved = resolve_attention_backend(backend) - 
self.attention_backend = resolved - for block in self.blocks: - block.attn.attn_backend = resolved - - def forward( - self, - x: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = False, - output_attentions: Optional[bool] = False, - output_s_max: Optional[bool] = False, - ) -> TransformerOutput: - hidden_states = () if output_hidden_states else None - attentions = () if output_attentions else None - full_s_max = () if output_s_max else None - - attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask( - effective_backend=self.attention_backend, - batch_size=x.shape[0], - seq_len=x.shape[1], - device=x.device, - attention_mask=attention_mask, - ) - - for block in self.blocks: - if self.gradient_checkpointing and self.training: - x, attn_weights, s_max = self._gradient_checkpointing_func( - block.__call__, - x=x, - attention_mask_2d=attention_mask_2d, - attention_mask_4d=attention_mask_4d, - flex_block_mask=flex_block_mask, - output_attentions=output_attentions, - output_s_max=output_s_max, - ) - else: - x, attn_weights, s_max = block( - x=x, - attention_mask_2d=attention_mask_2d, - attention_mask_4d=attention_mask_4d, - flex_block_mask=flex_block_mask, - output_attentions=output_attentions, - output_s_max=output_s_max, - ) - - if attentions is not None: - attentions += (attn_weights,) - if output_hidden_states: - assert hidden_states is not None - hidden_states += (x,) - if full_s_max is not None: - full_s_max += (s_max,) - - last_hidden_state = self.norm(x) - if output_hidden_states: - hidden_states += (last_hidden_state,) - - return TransformerOutput( - last_hidden_state=last_hidden_state, - hidden_states=hidden_states, - attentions=attentions, - s_max=full_s_max, - ) - - -class PreTrainedESMplusplusModel(PreTrainedModel): - """ - init weights for ESM++ models - """ - config_class = ESMplusplusConfig - base_model_prefix = "esm++" - supports_gradient_checkpointing = True - all_tied_weights_keys 
= {} - - @classmethod - def is_remote_code(cls) -> bool: - # Prevent post-load reinitialization of tensors already loaded from checkpoints. - return True - - def _init_weights(self, module): - """Initialize the weights""" - # HF from_pretrained marks loaded parameters with `_is_hf_initialized`. - # Skip this module if any local parameter is already marked as loaded. - for parameter in module.parameters(recurse=False): - if "_is_hf_initialized" in parameter.__dict__ and parameter.__dict__["_is_hf_initialized"]: - return - - if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - if module.padding_idx is not None: - with torch.no_grad(): - module.weight[module.padding_idx].zero_() - elif isinstance(module, nn.LayerNorm): - if module.bias is not None: - nn.init.zeros_(module.bias) - nn.init.ones_(module.weight) - - @property - def attn_backend(self) -> str: - return self.config.attn_backend - - @attn_backend.setter - def attn_backend(self, backend: str) -> None: - assert backend in VALID_ATTENTION_BACKENDS, f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}." 
- self.config.attn_backend = backend - for module in self.modules(): - if isinstance(module, TransformerStack): - module.attn_backend = backend - - def _reset_rotary_embeddings(self): - """Refresh non-persistent rotary buffers after checkpoint loading.""" - for module in self.modules(): - if isinstance(module, RotaryEmbedding): - module.reset_parameters() - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - output_loading_info = bool(kwargs["output_loading_info"]) if "output_loading_info" in kwargs else False - loaded = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - if output_loading_info: - model, loading_info = loaded - model._reset_rotary_embeddings() - return model, loading_info - loaded._reset_rotary_embeddings() - return loaded - - @classmethod - def from_pretrained_esm( - cls, - model_name: str, - device: Union[torch.device, str] = "cpu", - ): - """Load a pretrained ESM++ model.""" - key = _resolve_esmc_checkpoint_key(model_name) - if key == "esmc-300": - return ESMplusplus_300M(device=device) - if key == "esmc-600": - return ESMplusplus_600M(device=device) - if key == "esmc-6b": - return ESMplusplus_6B(device=device) - raise ValueError(f"Invalid model name: {model_name}") - - -### ESM++ Models -class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin): - """ - ESM++ model. 
transformer model with no heads - """ - config_class = ESMplusplusConfig - def __init__(self, config: ESMplusplusConfig, **kwargs): - PreTrainedESMplusplusModel.__init__(self, config, **kwargs) - self.config = config - self.vocab_size = config.vocab_size - self.embed = nn.Embedding(self.vocab_size, config.hidden_size) - self.transformer = TransformerStack( - d_model=config.hidden_size, - n_heads=config.num_attention_heads, - n_layers=config.num_hidden_layers, - dropout=config.dropout, - attn_backend=config.attn_backend, - ) - self.tokenizer = EsmSequenceTokenizer() - self.init_weights() - - def get_input_embeddings(self): - return self.embed - - def set_input_embeddings(self, value): - self.embed = value - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - x = self.embed(input_ids) - return self.transformer( - x=x, - attention_mask=attention_mask, - output_hidden_states=False, - output_attentions=False, - ).last_hidden_state - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_s_max: Optional[bool] = False, - return_dict: Optional[bool] = None, - **kwargs, - ) -> ESMplusplusOutput: - assert input_ids is not None or inputs_embeds is not None, "You have to specify either input_ids or inputs_embeds" - assert not (input_ids is not None and inputs_embeds is not None), "You cannot specify both input_ids and inputs_embeds at the same time" - - if inputs_embeds is None: - x = self.embed(input_ids) - else: - x = inputs_embeds - - transformer_output = self.transformer( - x=x, - attention_mask=attention_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - output_s_max=output_s_max, - ) - return ESMplusplusOutput( - last_hidden_state=transformer_output.last_hidden_state, 
- hidden_states=transformer_output.hidden_states, - attentions=transformer_output.attentions, - s_max=transformer_output.s_max, - ) - -class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin): - """ - ESM++ model for masked language modeling. - Implements the base ESM++ architecture with a masked language modeling head. - """ - config_class = ESMplusplusConfig - def __init__(self, config: ESMplusplusConfig, **kwargs): - PreTrainedESMplusplusModel.__init__(self, config, **kwargs) - self.config = config - self.vocab_size = config.vocab_size - self.embed = nn.Embedding(self.vocab_size, config.hidden_size) - self.transformer = TransformerStack( - d_model=config.hidden_size, - n_heads=config.num_attention_heads, - n_layers=config.num_hidden_layers, - dropout=config.dropout, - attn_backend=config.attn_backend, - ) - self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size) - self.ce_loss = nn.CrossEntropyLoss() - self.tokenizer = EsmSequenceTokenizer() - self.init_weights() - - def get_input_embeddings(self): - return self.embed - - def set_input_embeddings(self, value): - self.embed = value - - def get_output_embeddings(self): - return self.sequence_head[-1] - - def set_output_embeddings(self, new_embeddings): - self.sequence_head[-1] = new_embeddings - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - x = self.embed(input_ids) - return self.transformer( - x=x, - attention_mask=attention_mask, - output_hidden_states=False, - output_attentions=False, - ).last_hidden_state - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_s_max: Optional[bool] = False, - return_dict: Optional[bool] = None, - **kwargs, - ) -> ESMplusplusOutput: - 
if inputs_embeds is None: - x = self.embed(input_ids) - else: - x = inputs_embeds - - output = self.transformer( - x=x, - attention_mask=attention_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - output_s_max=output_s_max, - ) - - last_hidden_state = output.last_hidden_state - logits = self.sequence_head(last_hidden_state) - loss = None - if labels is not None: - loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1)) - - return ESMplusplusOutput( - loss=loss, - logits=logits, - last_hidden_state=last_hidden_state, - hidden_states=output.hidden_states, - attentions=output.attentions, - s_max=output.s_max, - ) - - -class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin): - """ - ESM++ model for sequence classification. - Extends the base ESM++ model with a classification head. - """ - def __init__(self, config: ESMplusplusConfig, **kwargs): - ESMplusplusForMaskedLM.__init__(self, config, **kwargs) - self.config = config - self.num_labels = config.num_labels - self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4) - # Large intermediate projections help with sequence classification tasks (*4) - self.mse = nn.MSELoss() - self.ce = nn.CrossEntropyLoss() - self.bce = nn.BCEWithLogitsLoss() - # if kwargs has pooling_types, use them, otherwise use ['cls', 'mean'] - if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0: - pooling_types = kwargs['pooling_types'] - else: - pooling_types = ['mean', 'var'] - self.pooler = Pooler(pooling_types) - self.init_weights() - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - x = self.embed(input_ids) - return self.transformer( - x=x, - attention_mask=attention_mask, - output_hidden_states=False, - output_attentions=False, - ).last_hidden_state - - def forward( - self, - input_ids: 
Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_s_max: Optional[bool] = False, - return_dict: Optional[bool] = None, - **kwargs, - ) -> ESMplusplusOutput: - output = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - labels=None, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_s_max=output_s_max, - ) - - last_hidden_state = output.last_hidden_state - features = self.pooler(last_hidden_state, attention_mask) - logits = self.classifier(features) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - if self.num_labels == 1: - loss = self.mse(logits.flatten(), labels.flatten()) - else: - loss = self.mse(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss = self.bce(logits, labels) - - return ESMplusplusOutput( - loss=loss, - logits=logits, - last_hidden_state=last_hidden_state, - hidden_states=output.hidden_states, - attentions=output.attentions, - s_max=output.s_max, - ) - - -class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin): - """ - ESM++ model for token classification. - Extends the base ESM++ model with a token classification head. 
- """ - def __init__(self, config: ESMplusplusConfig, **kwargs): - ESMplusplusForMaskedLM.__init__(self, config, **kwargs) - self.config = config - self.num_labels = config.num_labels - self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4) - # Large intermediate projections help with sequence classification tasks (*4) - self.loss_fct = nn.CrossEntropyLoss() - self.init_weights() - - def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: - x = self.embed(input_ids) - return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_s_max: Optional[bool] = False, - return_dict: Optional[bool] = None, - **kwargs, - ) -> ESMplusplusOutput: - output = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - inputs_embeds=inputs_embeds, - labels=None, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_s_max=output_s_max, - ) - - last_hidden_state = output.last_hidden_state - logits = self.classifier(last_hidden_state) - loss = None - if labels is not None: - loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - - return ESMplusplusOutput( - loss=loss, - logits=logits, - last_hidden_state=last_hidden_state, - hidden_states=output.hidden_states, - attentions=output.attentions, - s_max=output.s_max, - ) - - -### Loading from Biohub -_ESMC_CHECKPOINT_SPECS = { - "esmc-300": { - "repo_id": "biohub/ESMC-300M", - "hidden_size": 960, - "num_attention_heads": 15, - "num_hidden_layers": 30, - }, - "esmc-600": { - "repo_id": "biohub/ESMC-600M", - "hidden_size": 
1152, - "num_attention_heads": 18, - "num_hidden_layers": 36, - }, - "esmc-6b": { - "repo_id": "biohub/ESMC-6B", - "hidden_size": 2560, - "num_attention_heads": 40, - "num_hidden_layers": 80, - }, -} - - -def _resolve_esmc_checkpoint_key(model: str) -> str: - normalized = model.lower().replace("_", "-") - if "300" in normalized: - return "esmc-300" - if "600" in normalized: - return "esmc-600" - if "6b" in normalized: - return "esmc-6b" - raise ValueError(f"{model=} is an invalid ESMC model name.") - - -@staticmethod -@cache -def data_root(model: str): - if "INFRA_PROVIDER" in os.environ: - return Path("") - key = _resolve_esmc_checkpoint_key(model) - return Path(snapshot_download(repo_id=_ESMC_CHECKPOINT_SPECS[key]["repo_id"])) - - -def get_esmc_checkpoint_path(model: str) -> Path: - key = _resolve_esmc_checkpoint_key(model) - spec = _ESMC_CHECKPOINT_SPECS[key] - if "weights_relpath" in spec: - return data_root(key) / spec["weights_relpath"] - checkpoint_dir = data_root(key) - if (checkpoint_dir / "model.safetensors").exists(): - return checkpoint_dir / "model.safetensors" - if (checkpoint_dir / "model.safetensors.index.json").exists(): - return checkpoint_dir / "model.safetensors.index.json" - raise FileNotFoundError(f"No ESMC checkpoint found under {checkpoint_dir}.") - - -def _normalize_esmc_state_key(key: str) -> Optional[str]: - if key.endswith("._extra_state"): - return None - if key.startswith("esmc."): - key = key[len("esmc."):] - if key.startswith("lm_head."): - key = f"sequence_head.{key[len('lm_head.'):]}" - replacements = ( - (".attn.layernorm_qkv.layer_norm_bias", ".attn.layernorm_qkv.0.bias"), - (".attn.layernorm_qkv.layer_norm_weight", ".attn.layernorm_qkv.0.weight"), - (".attn.layernorm_qkv.weight", ".attn.layernorm_qkv.1.weight"), - (".ffn.layer_norm_bias", ".ffn.0.bias"), - (".ffn.layer_norm_weight", ".ffn.0.weight"), - (".ffn.fc1_weight", ".ffn.1.weight"), - (".ffn.fc2_weight", ".ffn.3.weight"), - ) - for old, new in replacements: - key = 
key.replace(old, new) - return key - - -def _normalize_esmc_state_dict(state_dict: dict) -> dict: - normalized = {} - for key, tensor in state_dict.items(): - normalized_key = _normalize_esmc_state_key(key) - if normalized_key is None: - continue - normalized[normalized_key] = tensor - return normalized - - -def _safetensors_checkpoint_files(checkpoint_path: Path) -> List[Path]: - if checkpoint_path.name == "model.safetensors": - return [checkpoint_path] - with checkpoint_path.open("r", encoding="utf-8") as f: - index = json.load(f) - return [ - checkpoint_path.parent / filename - for filename in sorted(set(index["weight_map"].values())) - ] - - -def _load_safetensors_state_dict( - model_obj: ESMplusplusForMaskedLM, - checkpoint_path: Path, - device: Union[torch.device, str], -) -> None: - expected_keys = set(model_obj.state_dict().keys()) - loaded_keys = set() - device_string = str(torch.device(device)) - for shard_path in _safetensors_checkpoint_files(checkpoint_path): - shard_state_dict = load_safetensors_file(shard_path, device=device_string) - normalized = _normalize_esmc_state_dict(shard_state_dict) - unexpected = set(normalized.keys()) - expected_keys - assert len(unexpected) == 0, ( - f"Unexpected ESMC checkpoint keys in {shard_path.name}: " - f"{sorted(unexpected)[:10]}" - ) - model_obj.load_state_dict(normalized, strict=False) - loaded_keys.update(normalized.keys()) - - missing = expected_keys - loaded_keys - assert len(missing) == 0, ( - f"ESMC checkpoint did not provide all expected keys: {sorted(missing)[:10]}" - ) - - -def _load_esmc_checkpoint_model( - config: ESMplusplusConfig, - model: str, - device: Union[torch.device, str] = "cpu", -) -> ESMplusplusForMaskedLM: - key = _resolve_esmc_checkpoint_key(model) - spec = _ESMC_CHECKPOINT_SPECS[key] - assert config.hidden_size == spec["hidden_size"], ( - f"ESMC loader expected hidden_size={spec['hidden_size']} for {key}, " - f"but got {config.hidden_size}." 
- ) - assert config.num_attention_heads == spec["num_attention_heads"], ( - f"ESMC loader expected num_attention_heads={spec['num_attention_heads']} for {key}, " - f"but got {config.num_attention_heads}." - ) - assert config.num_hidden_layers == spec["num_hidden_layers"], ( - f"ESMC loader expected num_hidden_layers={spec['num_hidden_layers']} for {key}, " - f"but got {config.num_hidden_layers}." - ) - with torch.device(device): - model_obj = ESMplusplusForMaskedLM(config) - checkpoint_path = get_esmc_checkpoint_path(key) - if checkpoint_path.suffix == ".safetensors" or checkpoint_path.name == "model.safetensors.index.json": - _load_safetensors_state_dict( - model_obj=model_obj, - checkpoint_path=checkpoint_path, - device=device, - ) - else: - state_dict = torch.load(checkpoint_path, map_location=device) - model_obj.load_state_dict(_normalize_esmc_state_dict(state_dict)) - return model_obj - - -def ESMplusplus_300M(device: Union[torch.device, str] = "cpu"): - config = ESMplusplusConfig( - hidden_size=960, - num_attention_heads=15, - num_hidden_layers=30, - ) - return _load_esmc_checkpoint_model(config=config, model="esmc-300", device=device) - - -def ESMplusplus_600M(device: Union[torch.device, str] = "cpu"): - config = ESMplusplusConfig( - hidden_size=1152, - num_attention_heads=18, - num_hidden_layers=36, - ) - return _load_esmc_checkpoint_model(config=config, model="esmc-600", device=device) - - -def ESMplusplus_6B(device: Union[torch.device, str] = "cpu"): - config = ESMplusplusConfig( - hidden_size=2560, - num_attention_heads=40, - num_hidden_layers=80, - ) - return _load_esmc_checkpoint_model(config=config, model="esmc-6b", device=device) - - -### Tokenization -SEQUENCE_VOCAB = [ - "", "", "", "", - "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", - "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", - "O", ".", "-", "|", - "", -] - -class EsmSequenceTokenizer(PreTrainedTokenizerFast): - model_input_names = ["input_ids", "attention_mask"] - 
- def __init__( - self, - unk_token="", - cls_token="", - pad_token="", - mask_token="", - eos_token="", - chain_break_token="|", - **kwargs, - ): - all_tokens = SEQUENCE_VOCAB - token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)} - - # a character-level tokenizer is the same as BPE with no token merges - bpe = BPE(token_to_id, merges=[], unk_token=unk_token) - tokenizer = Tokenizer(bpe) - special_tokens = [ - cls_token, - pad_token, - mask_token, - eos_token, - chain_break_token, - ] - self.cb_token = chain_break_token - additional_special_tokens = [chain_break_token] - - tokenizer.add_special_tokens(special_tokens) - - # This is where we configure the automatic addition of special tokens when we call - # tokenizer(text, add_special_tokens=True). Note that you can also configure how two - # sequences are merged if you want. - tokenizer.post_processor = TemplateProcessing( # type: ignore - single=" $A ", - pair=":0 $A:0 :0 $B:1 :1", - special_tokens=[ - ("", tokenizer.token_to_id("")), - ("", tokenizer.token_to_id("")), - ], - ) - super().__init__( - tokenizer_object=tokenizer, - unk_token=unk_token, - cls_token=cls_token, - pad_token=pad_token, - mask_token=mask_token, - eos_token=eos_token, - additional_special_tokens=additional_special_tokens, - **kwargs, - ) - - # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here. 
- @property - def bos_token(self): - return self.cls_token - - @property - def bos_token_id(self): - return self.cls_token_id - - @property - def chain_break_token(self): - return self.cb_token - - @property - def chain_break_token_id(self): - return self.convert_tokens_to_ids(self.chain_break_token) - - @property - def all_token_ids(self): - return list(range(self.vocab_size)) - - @property - def special_token_ids(self): - return self.all_special_ids - - -if __name__ == "__main__": - import random - - import torch - - from torch import Tensor - - def print_tensor_shapes(prefix: str, obj): - if isinstance(obj, Tensor): - print(f"{prefix}{obj.shape}") - elif isinstance(obj, dict): - for name, value in obj.items(): - print_tensor_shapes(f"{prefix}{name}.", value) - elif isinstance(obj, list): - for idx, value in enumerate(obj): - print_tensor_shapes(f"{prefix}[{idx}].", value) - elif isinstance(obj, tuple): - for idx, value in enumerate(obj): - print_tensor_shapes(f"{prefix}[{idx}].", value) - elif hasattr(obj, "__dict__"): - for name, value in vars(obj).items(): - if name.startswith("_"): - continue - print_tensor_shapes(f"{prefix}{name}.", value) - else: - print(f"{prefix}{type(obj)}") - - random.seed(0) - torch.manual_seed(0) - - tokenizer = EsmSequenceTokenizer() - num_attention_heads = random.choice([2, 4]) - config = ESMplusplusConfig( - vocab_size=tokenizer.vocab_size, - hidden_size=16 * num_attention_heads, - num_attention_heads=num_attention_heads, - num_hidden_layers=random.choice([1, 2]), - num_labels=2, - dropout=0.0, - ) - - batch = tokenizer(["ACDEFG", "MKTW"], return_tensors="pt", padding=True) - batch["labels"] = batch["input_ids"].clone() - model = ESMplusplusForMaskedLM(config=config).eval() - - with torch.no_grad(): - output = model(**batch, return_dict=True) - - print("Batch shape:") - print_tensor_shapes("", batch) - print("Output shape:") - print_tensor_shapes("", output) +from __future__ import annotations + +import torch +import 
torch._inductor.config as inductor_config +import torch._dynamo as dynamo + +# Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs) +# Provides significant speedup with minimal precision loss +torch.set_float32_matmul_precision('high') + +# Enable TF32 for matrix multiplications and cuDNN operations +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + +# Enable cuDNN autotuner - finds fastest algorithms for your hardware +# Best when input sizes are consistent; may slow down first iterations +torch.backends.cudnn.benchmark = True + +# Deterministic operations off for speed (set True if reproducibility needed) +torch.backends.cudnn.deterministic = False +inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM" + +dynamo.config.capture_scalar_outputs = True +torch._dynamo.config.recompile_limit = 16 + +import io +import os +import queue +import sqlite3 +import struct +import threading +import time + +import networkx as nx +import numpy as np +import torch +from tqdm.auto import tqdm +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as TorchDataset +from transformers import PreTrainedTokenizerBase + + +# SQLite stores tensors as compact blobs. Keep this header format compatible +# with Protify readers that share the same dtype/version codes. +_COMPACT_VERSION = 0x01 +_DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2} +_CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32} +_CODE_TO_NP_DTYPE = {0: np.float16, 1: np.float16, 2: np.float32} + + +def tensor_to_embedding_blob(tensor: torch.Tensor) -> bytes: + """Serialize a tensor to compact binary format for SQLite blob storage. 
+ + Format: [version:1][dtype_code:1][ndim:4][shape:4*ndim][raw_bytes] + bfloat16 tensors are stored as float16 bytes (numpy lacks bfloat16) + but tagged with dtype_code=1 so they can be cast back on read. + Falls back to torch.save for unsupported dtypes. + """ + t = tensor.cpu() + if t.dtype not in _DTYPE_TO_CODE: + buffer = io.BytesIO() + torch.save(t, buffer) + return buffer.getvalue() + dtype_code = _DTYPE_TO_CODE[t.dtype] + + if t.dtype == torch.bfloat16: + raw = t.half().numpy().tobytes() + else: + raw = t.numpy().tobytes() + + shape = t.shape + header = struct.pack(f' bytes: + """Build just the compact header for a given dtype and shape.""" + dtype_code = _DTYPE_TO_CODE[dtype] + return struct.pack(f' List[bytes]: + """Serialize a batch of same-shape tensors to compact blobs (fast path for vectors). + + Builds the header once and slices raw bytes per row. Much faster than + per-row tensor_to_embedding_blob calls for uniform-shape batches. + """ + assert batch.ndim >= 2, f"Expected batch with >= 2 dims, got {batch.ndim}" + t = batch.cpu() + store_dtype = t.dtype + if t.dtype not in _DTYPE_TO_CODE: + return [tensor_to_embedding_blob(t[i]) for i in range(t.shape[0])] + + if t.dtype == torch.bfloat16: + arr = t.half().numpy() + store_dtype = torch.bfloat16 + else: + arr = t.numpy() + + row_shape = tuple(t.shape[1:]) + header = _compact_header(store_dtype, row_shape) + raw = arr.tobytes() + stride = len(raw) // t.shape[0] + return [header + raw[i * stride:(i + 1) * stride] for i in range(t.shape[0])] + + +def embedding_blob_to_tensor(blob: bytes, fallback_shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor: + """Deserialize a blob back to a tensor. Auto-detects compact vs legacy formats.""" + if len(blob) >= 6 and blob[0] == _COMPACT_VERSION: + dtype_code = blob[1] + ndim = struct.unpack_from(' torch.Tensor: + assert isinstance(hidden_state_index, int), "hidden_state_index must be an integer." 
+ if store_all_hidden_states: + assert hidden_states is not None, "store_all_hidden_states requires output_hidden_states=True." + assert len(hidden_states) > 0, "Model returned no hidden states." + return torch.stack(tuple(hidden_states), dim=1) + if hidden_state_index == -1: + return last_hidden_state + assert hidden_states is not None, "hidden_state_index selection requires output_hidden_states=True." + return hidden_states[hidden_state_index] + + +def _trim_full_embedding(embedding: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + mask = attention_mask.bool() + if embedding.ndim == 2: + return embedding[mask].reshape(-1, embedding.shape[-1]) + if embedding.ndim == 3: + return embedding[:, mask, :].reshape(embedding.shape[0], -1, embedding.shape[-1]) + raise AssertionError(f"Expected full embedding tensor with 2 or 3 dims, got {embedding.ndim}.") + + +def pool_embeddings( + embeddings: Dict[str, torch.Tensor], + pooling_types: List[str] = ['mean'], + hidden_state_index: int = -1, +) -> Dict[str, torch.Tensor]: + pooler = Pooler(pooling_types) + pooled: Dict[str, torch.Tensor] = {} + for sequence, embedding in embeddings.items(): + assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." + assert isinstance(embedding, torch.Tensor), "Expected embedding dictionary values to be tensors." + if embedding.ndim == 1: + pooled[sequence] = embedding.cpu() + continue + if embedding.ndim == 3: + embedding = embedding[hidden_state_index] + assert embedding.ndim == 2, f"Expected token-wise embedding with 2 dims, got {embedding.ndim}." 
+ pooled[sequence] = pooler(embedding.unsqueeze(0)).squeeze(0).cpu() + return pooled + + +def load_pooled_embeddings_from_pth( + save_path: str, + pooling_types: List[str] = ['mean'], + hidden_state_index: int = -1, +) -> Dict[str, torch.Tensor]: + assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" + payload = torch.load(save_path, map_location="cpu", weights_only=True) + assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." + return pool_embeddings(payload, pooling_types=pooling_types, hidden_state_index=hidden_state_index) + + +def load_pooled_embeddings_from_db( + db_path: str, + sequences: Optional[List[str]] = None, + pooling_types: List[str] = ['mean'], + hidden_state_index: int = -1, +) -> Dict[str, torch.Tensor]: + assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" + loaded: Dict[str, torch.Tensor] = {} + with sqlite3.connect(db_path, timeout=30) as conn: + cursor = conn.cursor() + if sequences is None: + cursor.execute("SELECT sequence, embedding FROM embeddings") + else: + if len(sequences) == 0: + return loaded + placeholders = ",".join(["?"] * len(sequences)) + cursor.execute( + f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})", + tuple(sequences), + ) + for sequence, embedding_bytes in cursor.fetchall(): + loaded[sequence] = embedding_blob_to_tensor(embedding_bytes) + return pool_embeddings(loaded, pooling_types=pooling_types, hidden_state_index=hidden_state_index) + + +def maybe_compile(model: torch.nn.Module, dynamic: bool = False) -> torch.nn.Module: + """Compile model with torch.compile if possible. + + Skips compilation when dynamic=True (padding='longest') because + flex attention's create_block_mask is incompatible with dynamic shapes + under torch.compile, causing CUDA illegal memory access. 
+ """ + if dynamic: + print("Skipping torch.compile (dynamic shapes + flex attention incompatible)") + return model + try: + model = torch.compile(model) + print("Model compiled") + except Exception as e: + print(f"Skipping torch.compile: {e}") + return model + + +def build_collator( + tokenizer: PreTrainedTokenizerBase, + padding: str = 'max_length', + max_length: int = 512, +) -> Callable[[List[str]], Dict[str, torch.Tensor]]: + def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]: + kwargs: Dict[str, Any] = dict( + return_tensors="pt", padding=padding, truncation=True, max_length=max_length, + ) + if padding != 'max_length': + kwargs['pad_to_multiple_of'] = 8 + return tokenizer(sequences, **kwargs) + return _collate_fn + + +def _make_embedding_progress( + dataloader: DataLoader, + padding: str, + n_warmup: int = 3, + n_calibration: int = 5, +) -> Iterator[Tuple[int, Any]]: + """Progress-bar wrapper for embedding loops. Drop-in replacement for enumerate(dataloader). + + When padding='max_length', all batches have uniform cost so plain tqdm works. + When padding='longest' (sorted longest-first), batch times vary dramatically. + In that case: yield warmup batches first (compiler warmup + OOM check on longest + sequences), then time mid-length calibration batches to estimate total ETA. + + Keep in sync with protify/embedder.py and core/atlas/precomputed.py. + """ + total = len(dataloader) + if padding == 'max_length' or total <= n_warmup + n_calibration: + for i, batch in tqdm(enumerate(dataloader), total=total, desc='Embedding batches'): + yield i, batch + return + + dl_iter = iter(dataloader) + + # Warm up on the longest batches first; sorted inputs make these the OOM-risk + # and compile-stabilization cases. 
+ warmup_bar = tqdm(range(n_warmup), desc='Warmup (longest batches)', leave=False) + for i in warmup_bar: + batch = next(dl_iter) + yield i, batch + warmup_bar.close() + + # Move toward mid-length batches for ETA calibration, yielding every real + # batch on the way so no sequences are skipped. + mid_start = total // 2 + intermediate_bar = tqdm( + range(n_warmup, mid_start), desc='Embedding batches', leave=False, + ) + for i in intermediate_bar: + batch = next(dl_iter) + yield i, batch + intermediate_bar.close() + + # Mid-length batches give a better remaining-time estimate than the longest + # warmup batches. + calibration_times: List[float] = [] + cal_bar = tqdm(range(n_calibration), desc='Calibrating ETA', leave=False) + for j in cal_bar: + t0 = time.perf_counter() + batch = next(dl_iter) + yield mid_start + j, batch + calibration_times.append(time.perf_counter() - t0) + cal_bar.close() + + avg_time = sum(calibration_times) / len(calibration_times) + remaining_start = mid_start + n_calibration + remaining_count = total - remaining_start + estimated_total_seconds = avg_time * remaining_count + + # Finish the tail with the calibrated ETA shown in the progress bar. + main_bar = tqdm( + range(remaining_count), + desc='Embedding batches', + bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]', + ) + main_bar.set_postfix_str(f'ETA ~{estimated_total_seconds:.0f}s (calibrated)') + for k in main_bar: + batch = next(dl_iter) + yield remaining_start + k, batch + main_bar.close() + + +class _SQLWriter: + """Context manager for async SQL embedding writes. 
Matches core/embed/storage.SQLEmbeddingWriter.""" + + def __init__(self, conn: sqlite3.Connection, queue_maxsize: int = 4) -> None: + self._conn = conn + self._queue: queue.Queue = queue.Queue(maxsize=queue_maxsize) + self._thread: Optional[threading.Thread] = None + + def __enter__(self) -> "_SQLWriter": + self._thread = threading.Thread(target=self._writer_loop, daemon=True) + self._thread.start() + return self + + def write_batch(self, rows: List[Tuple[str, bytes]]) -> None: + self._queue.put(rows) + + def _writer_loop(self) -> None: + cursor = self._conn.cursor() + while True: + item = self._queue.get() + if item is None: + break + cursor.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item) + if self._queue.qsize() == 0: + self._conn.commit() + self._conn.commit() + + def __exit__(self, *exc) -> None: + if self._thread is not None: + self._queue.put(None) + self._thread.join() + self._thread = None + + +class Pooler: + def __init__(self, pooling_types: List[str]) -> None: + self.pooling_types = pooling_types + self.pooling_options: Dict[str, Callable] = { + 'mean': self.mean_pooling, + 'max': self.max_pooling, + 'norm': self.norm_pooling, + 'median': self.median_pooling, + 'std': self.std_pooling, + 'var': self.var_pooling, + 'cls': self.cls_pooling, + 'parti': self._pool_parti, + } + + def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: + assert isinstance(attentions, torch.Tensor) + maxed_attentions = torch.max(attentions, dim=1)[0] + return maxed_attentions + + def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]: + G = self._convert_to_graph(attention_matrix) + if G.number_of_nodes() != attention_matrix.shape[0]: + raise Exception( + f"The number of nodes in the graph should be equal to the number of tokens in sequence! 
You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") + if G.number_of_edges() == 0: + raise Exception(f"You don't seem to have any attention edges left in the graph.") + + return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) + + def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph: + G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) + return G + + def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray: + if attention_mask is not None: + for k in list(dict_importance.keys()): + if attention_mask[k] == 0: + del dict_importance[k] + + total = sum(dict_importance.values()) + return np.array([v / total for _, v in dict_importance.items()]) + + def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() + emb_pooled = [] + for e, a, mask in zip(emb, maxed_attentions, attention_mask): + dict_importance = self._page_rank(a) + importance_weights = self._calculate_importance_weights(dict_importance, mask) + num_tokens = int(mask.sum().item()) + emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) + pooled = torch.tensor(np.array(emb_pooled)) + return pooled + + def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + if attention_mask is None: + return emb.mean(dim=1) + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) + + def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + if attention_mask is None: + return emb.max(dim=1).values + else: + mask = attention_mask.unsqueeze(-1).bool() + return 
emb.masked_fill(~mask, float('-inf')).max(dim=1).values + + def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + if attention_mask is None: + return emb.norm(dim=1, p=2) + else: + attention_mask = attention_mask.unsqueeze(-1) + return (emb * attention_mask).norm(dim=1, p=2) + + def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + if attention_mask is None: + return emb.median(dim=1).values + else: + mask = attention_mask.unsqueeze(-1).bool() + return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values + + def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + if attention_mask is None: + return emb.std(dim=1) + else: + var = self.var_pooling(emb, attention_mask, **kwargs) + return torch.sqrt(var) + + def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + if attention_mask is None: + return emb.var(dim=1) + else: + attention_mask = attention_mask.unsqueeze(-1) + mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) + mean = mean.unsqueeze(1) + squared_diff = (emb - mean) ** 2 + var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) + return var + + def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + return emb[:, 0, :] + + def __call__( + self, + emb: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attentions: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if attention_mask is not None: + assert attention_mask.sum(dim=-1).min() > 0, ( + "Pooler received samples with all-zero attention masks. " + "This causes NaN from division by zero. Filter empty inputs before pooling." 
+ ) + final_emb: List[torch.Tensor] = [] + for pooling_type in self.pooling_types: + final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) + return torch.cat(final_emb, dim=-1) + + +class ProteinDataset(TorchDataset): + """Simple dataset for protein sequences.""" + def __init__(self, sequences: List[str]) -> None: + self.sequences = sequences + + def __len__(self) -> int: + return len(self.sequences) + + def __getitem__(self, idx: int) -> str: + return self.sequences[idx] + + +def parse_fasta(fasta_path: str) -> List[str]: + assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}" + sequences = [] + current_seq = [] + with open(fasta_path, 'r') as f: + for line in f: + line = line.strip() + if not line: + continue + if line.startswith('>'): + if current_seq: + sequences.append(''.join(current_seq)) + current_seq = [] + else: + current_seq.append(line) + if current_seq: + sequences.append(''.join(current_seq)) + return sequences + + +class EmbeddingMixin: + def _embed( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + hidden_state_index: int = -1, + store_all_hidden_states: bool = False, + ) -> torch.Tensor: + raise NotImplementedError + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return next(self.parameters()).device + + def _read_sequences_from_db(self, db_path: str) -> Set[str]: + """Read sequences from SQLite database.""" + with sqlite3.connect(db_path, timeout=30) as conn: + c = conn.cursor() + c.execute("SELECT sequence FROM embeddings") + return {row[0] for row in c.fetchall()} + + def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: + cursor = conn.cursor() + cursor.execute( + "CREATE TABLE IF NOT EXISTS embeddings (" + "sequence TEXT PRIMARY KEY, " + "embedding BLOB NOT NULL" + ")" + ) + conn.commit() + + def load_embeddings_from_pth(self, save_path: str) -> Dict[str, 
torch.Tensor]: + assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" + payload = torch.load(save_path, map_location="cpu", weights_only=True) + assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." + for sequence, tensor in payload.items(): + assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." + assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." + return payload + + def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]: + assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" + loaded: Dict[str, torch.Tensor] = {} + with sqlite3.connect(db_path, timeout=30) as conn: + self._ensure_embeddings_table(conn) + cursor = conn.cursor() + if sequences is None: + cursor.execute("SELECT sequence, embedding FROM embeddings") + else: + if len(sequences) == 0: + return loaded + placeholders = ",".join(["?"] * len(sequences)) + cursor.execute( + f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})", + tuple(sequences), + ) + + rows = cursor.fetchall() + for row in rows: + sequence = row[0] + embedding_bytes = row[1] + loaded[sequence] = embedding_blob_to_tensor(embedding_bytes) + return loaded + + def pool_embeddings( + self, + embeddings: Dict[str, torch.Tensor], + pooling_types: List[str] = ['mean'], + hidden_state_index: int = -1, + ) -> Dict[str, torch.Tensor]: + return pool_embeddings(embeddings, pooling_types=pooling_types, hidden_state_index=hidden_state_index) + + def load_pooled_embeddings_from_pth( + self, + save_path: str, + pooling_types: List[str] = ['mean'], + hidden_state_index: int = -1, + ) -> Dict[str, torch.Tensor]: + return load_pooled_embeddings_from_pth( + save_path, + pooling_types=pooling_types, + hidden_state_index=hidden_state_index, + ) + + def load_pooled_embeddings_from_db( + self, + db_path: 
str, + sequences: Optional[List[str]] = None, + pooling_types: List[str] = ['mean'], + hidden_state_index: int = -1, + ) -> Dict[str, torch.Tensor]: + return load_pooled_embeddings_from_db( + db_path, + sequences=sequences, + pooling_types=pooling_types, + hidden_state_index=hidden_state_index, + ) + + def embed_dataset( + self, + sequences: Optional[List[str]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + batch_size: int = 2, + max_len: int = 512, + truncate: bool = True, + full_embeddings: bool = False, + embed_dtype: torch.dtype = torch.float32, + pooling_types: List[str] = ['mean'], + num_workers: int = 0, + sql: bool = False, + save: bool = True, + sql_db_path: str = 'embeddings.db', + save_path: str = 'embeddings.pth', + fasta_path: Optional[str] = None, + padding: str = 'max_length', + hidden_state_index: int = -1, + store_all_hidden_states: bool = False, + **kwargs, + ) -> Optional[Dict[str, torch.Tensor]]: + """ + Embed a dataset of protein sequences. + + Supports two modes: + - Tokenizer mode (ESM2/ESM++): provide `tokenizer` or use `self.tokenizer`. + - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. + + Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via + `fasta_path`, or both (the two sources are combined). At least one must be provided. + """ + if fasta_path is not None: + fasta_sequences = parse_fasta(fasta_path) + sequences = list(sequences or []) + fasta_sequences + assert sequences is not None and len(sequences) > 0, \ + "Must provide at least one sequence via `sequences` or `fasta_path`." + assert isinstance(hidden_state_index, int), "hidden_state_index must be an integer." + assert full_embeddings or not store_all_hidden_states, \ + "store_all_hidden_states=True requires full_embeddings=True." 
+ sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) + sequences = sorted(sequences, key=len, reverse=True) + pooler = Pooler(pooling_types) if not full_embeddings else None + if tokenizer is None and self.config.model_type != "E1": + tokenizer = self.tokenizer + tokenizer_mode = tokenizer is not None + + # Resolve padding and compilation + dynamic = padding == 'longest' + compiled_model = maybe_compile(self, dynamic=dynamic) + + if tokenizer_mode: + collate_fn = build_collator(tokenizer, padding=padding, max_length=max_len) + device = self.device + else: + collate_fn = None + device = None + + def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + assert isinstance(residue_embeddings, torch.Tensor) + if full_embeddings or residue_embeddings.ndim == 2: + return residue_embeddings + return pooler(residue_embeddings, attention_mask) + + def iter_batches(to_embed: List[str]): + if tokenizer_mode: + assert collate_fn is not None + assert device is not None + dataset = ProteinDataset(to_embed) + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + prefetch_factor=2 if num_workers > 0 else None, + collate_fn=collate_fn, + shuffle=False, + pin_memory=True, + ) + for i, batch in _make_embedding_progress(dataloader, padding): + seqs = to_embed[i * batch_size:(i + 1) * batch_size] + input_ids = batch['input_ids'].to(device) + attention_mask = batch['attention_mask'].to(device) + residue_embeddings = compiled_model._embed( + input_ids, + attention_mask, + hidden_state_index=hidden_state_index, + store_all_hidden_states=store_all_hidden_states, + ) + yield seqs, residue_embeddings, attention_mask + else: + for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): + seqs = to_embed[batch_start:batch_start + batch_size] + batch_output = compiled_model._embed( + seqs, + return_attention_mask=True, + 
hidden_state_index=hidden_state_index, + store_all_hidden_states=store_all_hidden_states, + **kwargs, + ) + assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." + assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." + residue_embeddings, attention_mask = batch_output + assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." + yield seqs, residue_embeddings, attention_mask + + if sql: + # Resume safely: skip sequences already present in the SQLite table. + conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False) + conn.execute('PRAGMA journal_mode=WAL') + conn.execute('PRAGMA busy_timeout=30000') + conn.execute('PRAGMA synchronous=OFF') + conn.execute('PRAGMA cache_size=-64000') + self._ensure_embeddings_table(conn) + already_embedded = self._read_sequences_from_db(sql_db_path) + to_embed = [seq for seq in sequences if seq not in already_embedded] + print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") + print(f"Embedding {len(to_embed)} new sequences") + if len(to_embed) > 0: + # Embed batches synchronously; serialize/write them on the SQL writer thread. 
+ with _SQLWriter(conn) as writer: + with torch.inference_mode(): + for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): + embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) + if full_embeddings: + batch_rows = [] + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + batch_rows.append((seq, tensor_to_embedding_blob(_trim_full_embedding(emb, mask)))) + else: + blobs = batch_tensor_to_blobs(embeddings) + batch_rows = list(zip(seqs, blobs)) + writer.write_batch(batch_rows) + conn.close() + return None + + embeddings_dict = {} + if os.path.exists(save_path): + embeddings_dict = self.load_embeddings_from_pth(save_path) + to_embed = [seq for seq in sequences if seq not in embeddings_dict] + print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") + print(f"Embedding {len(to_embed)} new sequences") + else: + to_embed = sequences + print(f"Embedding {len(to_embed)} new sequences") + + if len(to_embed) > 0: + with torch.inference_mode(): + for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): + embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) + for seq, emb, mask in zip(seqs, embeddings, attention_mask): + if full_embeddings: + emb = _trim_full_embedding(emb, mask) + embeddings_dict[seq] = emb.cpu() + + if save: + torch.save(embeddings_dict, save_path) + + return embeddings_dict + + +if __name__ == "__main__": + # Manual smoke test for pooling shape behavior. + pooler = Pooler(pooling_types=['max', 'parti']) + batch_size = 8 + seq_len = 64 + hidden_size = 128 + num_layers = 12 + emb = torch.randn(batch_size, seq_len, hidden_size) + attentions = torch.randn(batch_size, num_layers, seq_len, seq_len) + attention_mask = torch.ones(batch_size, seq_len) + y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions) + print(y.shape) + +"""Shared attention infrastructure for all FastPLMs models. 
+ +Contains: AttentionBackend enum, backend resolution, mask creation, +flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities. +""" +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from torch.nn import functional as F +from einops import rearrange + +try: + from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask +except ImportError: + create_block_mask = None + flex_attention = None + BlockMask = None + +_compiled_flex_attention = None + + +def _get_flex_attention_fn(): + """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.""" + global _compiled_flex_attention + if flex_attention is None: + return None + flex_mod = torch.nn.attention.flex_attention + if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False): + return flex_attention + if _compiled_flex_attention is None: + _compiled_flex_attention = torch.compile( + flex_attention, + dynamic=False, + ) + return _compiled_flex_attention + + +# HuggingFace `kernels` exposes slightly different APIs for Flash Attention 2 +# and 3. Detect the loaded variant once so every caller uses the same dispatch. +def _infer_kernels_flash_variant(kernel) -> Optional[str]: + if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"): + return "flash_attn2" + if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"): + return "flash_attn3" + return None + + +def _try_get_kernels_flash(): + try: + from kernels import get_kernel + except ImportError: + return None, None + + flash_kernel = None + flash_kernel_variant = None + try: + flash_kernel = get_kernel("kernels-community/flash-attn3") + flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) + assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API." 
def _ensure_flash_kernels_loaded():
    """Populate FLASH_KERNEL / FLASH_KERNEL_VARIANT on the first call only.

    The loaded flag is set before the lookup runs, so later calls return
    immediately whatever the lookup produced.
    """
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if not _FLASH_KERNELS_LOADED:
        _FLASH_KERNELS_LOADED = True
        FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()
def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
) -> torch.Tensor:
    """Run the loaded flash kernel's variable-length forward pass.

    Dispatches on FLASH_KERNEL_VARIANT. `softmax_scale` is forwarded to the
    kernel unchanged; callers that pre-scale Q must pass `softmax_scale=1.0`
    (see the `_kernels_flash_forward` docstring for the rationale).

    Raises:
        AssertionError: If no flash kernel is loaded or the variant is unknown.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    variant = FLASH_KERNEL_VARIANT
    if variant == "flash_attn2":
        fa2_result = FLASH_KERNEL.varlen_fwd(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=softmax_scale, is_causal=causal,
        )
        return fa2_result[0]
    if variant != "flash_attn3":
        raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")
    try:
        result = FLASH_KERNEL.flash_attn_varlen_func(
            q=query_states, k=key_states, v=value_states,
            cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k,
            softmax_scale=softmax_scale, causal=causal,
        )
    except TypeError:
        # Keyword call rejected: retry positionally, with 0.0 in the slot
        # ahead of softmax_scale.
        result = FLASH_KERNEL.flash_attn_varlen_func(
            query_states, key_states, value_states,
            cu_seqlens_q, cu_seqlens_k,
            max_seqlen_in_batch_q, max_seqlen_in_batch_k,
            0.0, softmax_scale, causal,
        )
    # The kernel may return either the output alone or a tuple led by it.
    return result[0] if isinstance(result, tuple) else result
class IndexPutFirstAxis(torch.autograd.Function):
    """Scatter rows into a zero-initialised tensor along the first axis.

    Forward writes `values` at the row positions given by `indices` in a new
    tensor with `first_axis_dim` rows; backward gathers the incoming gradient
    at those same positions.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        # new_zeros inherits dtype and device from `values`.
        scattered = values.new_zeros((first_axis_dim, *values.shape[1:]))
        scattered[indices] = values
        return scattered

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        # Only `values` gets a gradient; `indices` and `first_axis_dim` do not.
        (saved_indices,) = ctx.saved_tensors
        return grad_output[saved_indices], None, None
def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Drop masked-out positions from Q/K/V ahead of a varlen kernel call.

    Q, K and V are all flattened using the 4-D shape of `query_layer`, then
    the rows whose mask entry is nonzero are selected.

    Returns:
        (q, k, v, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)),
        where `indices` are the flat positions kept, `cu_seqlens` is the
        zero-prefixed int32 cumulative sum of per-row mask sums, and
        `max_seqlen` is the largest per-row mask sum. The same values are
        returned for the query and key sides.
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape
    tokens_per_row = attention_mask_2d.sum(dim=1).int()
    cumulative = F.pad(tokens_per_row.cumsum(0, dtype=torch.int32), (1, 0))
    longest = int(tokens_per_row.max().item())
    kept = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()
    flat_shape = (batch_size * seq_len, num_heads, head_dim)
    packed_q, packed_k, packed_v = [
        index_first_axis(layer.reshape(flat_shape), kept)
        for layer in (query_layer, key_layer, value_layer)
    ]
    return packed_q, packed_k, packed_v, kept, (cumulative, cumulative), (longest, longest)
def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a user-facing backend string to the backend that will be used.

    "auto" prefers the kernels flash backend, then flex attention, then SDPA,
    depending on what is available. Explicit requests are honoured or fail
    with an AssertionError when the backend is unavailable. The resolution is
    printed once per process.
    """
    global _BACKEND_CONFIRMED
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )
    auto_name = AttentionBackend.AUTO.value
    flash_name = AttentionBackend.KERNELS_FLASH.value
    # Flash kernels are loaded lazily, only for requests that may use them.
    if requested_backend == auto_name or requested_backend == flash_name:
        _ensure_flash_kernels_loaded()

    if requested_backend == auto_name:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        else:
            resolved = AttentionBackend.FLEX if flex_attention is not None else AttentionBackend.SDPA
    elif requested_backend == flash_name:
        assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == AttentionBackend.FLEX.value:
        assert flex_attention is not None, "Flex Attention is not available in this environment."
        resolved = AttentionBackend.FLEX
    elif requested_backend == AttentionBackend.SDPA.value:
        resolved = AttentionBackend.SDPA
    else:
        raise AssertionError(f"Unsupported attention backend: {requested_backend}")

    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved
def bool_to_additive_mask(
    bool_mask: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a bool validity mask into an additive float mask.

    True entries become 0.0 and False entries become -inf in a new tensor of
    the requested `dtype`.

    The float tensor is allocated first and filled afterwards on purpose:
    calling `masked_fill(..., float('-inf'))` on the bool tensor itself yields
    a bool tensor (-inf casts to True), which silently discards the mask.

    Raises:
        AssertionError: If `bool_mask` is not a bool tensor.
    """
    assert bool_mask.dtype == torch.bool, (
        f"bool_to_additive_mask requires a bool tensor, got dtype={bool_mask.dtype}"
    )
    invalid = bool_mask.logical_not()
    return torch.zeros_like(bool_mask, dtype=dtype).masked_fill_(invalid, float("-inf"))
+ + Args: + vocab_size: Size of the vocabulary + hidden_size: Dimension of hidden layers + num_attention_heads: Number of attention heads + num_hidden_layers: Number of transformer layers + num_labels: Number of output labels for classification + problem_type: Type of problem - regression, single/multi label classification + """ + model_type = "ESMplusplus" + def __init__( + self, + vocab_size: int = 64, + hidden_size: int = 960, + num_attention_heads: int = 15, + num_hidden_layers: int = 30, + num_labels: int = 2, + problem_type: Optional[str] = None, + dropout: float = 0.0, + initializer_range: float = 0.02, + attn_backend: str = "sdpa", + **kwargs, + ): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_labels = num_labels + self.problem_type = problem_type + self.dropout = dropout + self.initializer_range = initializer_range + self.tie_word_embeddings = False + self.attn_backend = attn_backend + + +### Rotary Embeddings +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... 
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    interleaved: bool = False,
    _inplace: bool = False,
) -> torch.Tensor:
    """Apply rotary embeddings to input based on cos and sin.

    Only the first `2 * cos.shape[-1]` features of the last dim are rotated;
    any remaining features are passed through unchanged.

    Args:
        x: Input tensor; dim 1 is treated as the sequence axis and the last
            dim as the feature axis.
        cos: Cosine table indexed as (position, frequency); truncated to the
            sequence length of `x`.
        sin: Sine table with the same layout as `cos`.
        interleaved: Whether rotation pairs are adjacent features (2i, 2i + 1)
            instead of the two halves (i, i + d).
        _inplace: Accepted for signature compatibility; not used here.

    Returns:
        A new tensor with the same shape as `x`.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    seqlen = x.size(1)
    cos = cos[:seqlen]
    sin = sin[:seqlen]
    # `rotate_half` pairs features (i, i + d) by default but (2i, 2i + 1) when
    # interleaved, so each frequency's cos/sin must be tiled to land on both
    # members of its pair: [c0..cd-1, c0..cd-1] vs [c0, c0, c1, c1, ...].
    pattern = "s d -> s 1 (d 2)" if interleaved else "s d -> s 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
self._compute_inv_freq(buffer_device) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + self.register_buffer("inv_freq", inv_freq, persistent=False) + arange = torch.arange(0, self.dim, 2, device=buffer_device, dtype=torch.float32) + scale = ( + (arange + 0.4 * self.dim) / (1.4 * self.dim) + if self.scale_base is not None + else None + ) + self.register_buffer("scale", scale) + + def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor: + """Compute inverse frequency bands. + + Always computes on CPU then moves to the requested device. This matches + native Biohub ESMC, which computes inv_freq on CPU at + `__init__` and migrates via `.to(device)`. Computing directly on GPU + gives a ~3.7e-9 bit-level difference in inv_freq (fp32 transcendental + precision differs between CPU and GPU), which compounds through the 30 + attention layers to ~1e-3 mse divergence from native at + `hidden_states[-2]`. See testing/parity_debug_rotary.py. 
+ """ + cpu_inv_freq = 1 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, device="cpu", dtype=torch.float32) + / self.dim + ) + ) + if device is not None and torch.device(device).type != "cpu": + return cpu_inv_freq.to(device) + return cpu_inv_freq + + def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): + """Update the cached cosine and sine values.""" + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + t /= self.scaling_factor + if self.inv_freq.dtype != torch.float32: + inv_freq = self.inv_freq.to(torch.float32) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + t /= self.scaling_factor + inv_freq = self.inv_freq + freqs = torch.outer(t, inv_freq) + + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange( + seqlen, dtype=self.scale.dtype, device=self.scale.device + ) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** power.unsqueeze(-1) + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply rotary embeddings to queries and keys. 
def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
    """Return `expansion_ratio * d_model` rounded up to a multiple of 256."""
    # Adding 255 before the floor division turns it into a ceiling division.
    blocks_of_256 = (expansion_ratio * d_model + 255) // 256
    return int(blocks_of_256 * 256)
bias=False), + ) + + +### Attention +class MultiHeadAttention(nn.Module): + """Multi-head attention with rotary embeddings and configurable backend. + + Args: + d_model: Model dimension + n_heads: Number of attention heads + attn_backend: One of "auto", "kernels_flash", "flex", "sdpa" + """ + def __init__( + self, + d_model: int, + n_heads: int, + attn_backend: str = "sdpa", + ): + super().__init__() + self.d_model = d_model + self.n_heads = n_heads + self.d_head = self.d_model // self.n_heads + self.scale = 1.0 / math.sqrt(self.d_head) + self.attn_backend = resolve_attention_backend(attn_backend) + self.layernorm_qkv = nn.Sequential( + nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False) + ) + self.out_proj = nn.Linear(d_model, d_model, bias=False) + self.q_ln = nn.LayerNorm(d_model, bias=False) + self.k_ln = nn.LayerNorm(d_model, bias=False) + self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads) + self.rotary = RotaryEmbedding(d_model // n_heads) + + def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + q = q.unflatten(-1, (self.n_heads, self.d_head)) + k = k.unflatten(-1, (self.n_heads, self.d_head)) + q, k = self.rotary(q, k) + q = q.flatten(-2, -1) + k = k.flatten(-2, -1) + return q, k + + def forward( + self, + x: torch.Tensor, + attention_mask_2d: Optional[torch.Tensor] = None, + attention_mask_4d: Optional[torch.Tensor] = None, + flex_block_mask: Optional[BlockMask] = None, + output_attentions: bool = False, + output_s_max: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: + qkv_BLD3 = self.layernorm_qkv(x) + query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1) + query_BLD, key_BLD = ( + self.q_ln(query_BLD).to(query_BLD.dtype), + self.k_ln(key_BLD).to(query_BLD.dtype), + ) + query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD) + query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, 
value_BLD)) + + attn_output, attn_weights, s_max = self._attn( + query_BHLD, key_BHLD, value_BHLD, + attention_mask_2d=attention_mask_2d, + attention_mask_4d=attention_mask_4d, + flex_block_mask=flex_block_mask, + output_attentions=output_attentions, + output_s_max=output_s_max, + ) + + output = self.out_proj(attn_output) + return output, attn_weights, s_max + + def _attn( + self, + query_BHLD: torch.Tensor, + key_BHLD: torch.Tensor, + value_BHLD: torch.Tensor, + attention_mask_2d: Optional[torch.Tensor] = None, + attention_mask_4d: Optional[torch.Tensor] = None, + flex_block_mask: Optional[BlockMask] = None, + output_attentions: bool = False, + output_s_max: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: + if output_attentions: + return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max) + + if self.attn_backend == AttentionBackend.KERNELS_FLASH: + attn_output, attn_weights = self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d) + elif self.attn_backend == AttentionBackend.FLEX: + attn_output, attn_weights = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask) + elif self.attn_backend == AttentionBackend.SDPA: + attn_output, attn_weights = self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d) + else: + raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}") + + s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None + return attn_output, attn_weights, s_max + + @torch.no_grad() + def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]: + q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1) + k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1) + s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * self.scale + return [s_max_bound[h] for h in range(self.n_heads)] + + def _manual_attn( + self, + query_BHLD: 
def _kernels_flash_attn(
    self,
    query_BHLD: torch.Tensor,
    key_BHLD: torch.Tensor,
    value_BHLD: torch.Tensor,
    attention_mask_2d: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, None]:
    """Run non-causal attention through the flash-attention kernel wrapper.

    Inputs arrive as (batch, heads, length, head_dim); the kernel wrapper is
    fed contiguous (batch, length, heads, head_dim) tensors instead.

    Args:
        query_BHLD: Query states.
        key_BHLD: Key states.
        value_BHLD: Value states.
        attention_mask_2d: Optional mask forwarded unchanged to the kernel
            wrapper.

    Returns:
        ``(attn_output, None)`` where ``attn_output`` has the head and
        head_dim axes merged; attention weights are not produced.
    """
    # Swap the head and length axes and materialise contiguous copies.
    q_BLHD, k_BLHD, v_BLHD = (
        states.transpose(1, 2).contiguous()
        for states in (query_BHLD, key_BHLD, value_BHLD)
    )
    context_BLHD = kernels_flash_attention_func(
        query_states=q_BLHD,
        key_states=k_BLHD,
        value_states=v_BLHD,
        attention_mask_2d=attention_mask_2d,
        causal=False,
    )
    merged = rearrange(context_BLHD, "b s h d -> b s (h d)")
    return merged, None
def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
    """Build a two-layer projection head.

    The head is ``Linear -> GELU -> LayerNorm -> Linear``.

    Args:
        d_model: Size of the input features.
        output_dim: Size of the produced output.
        hidden_dim: Width of the intermediate layer; falls back to
            ``d_model`` when ``None``.

    Returns:
        An ``nn.Sequential`` holding the four layers in order.
    """
    width = d_model if hidden_dim is None else hidden_dim
    # Layers are created in forward order so parameter initialisation draws
    # from the RNG in the same sequence as before.
    layers = [
        nn.Linear(d_model, width),
        nn.GELU(),
        nn.LayerNorm(width),
        nn.Linear(width, output_dim),
    ]
    return nn.Sequential(*layers)
@dataclass
class TransformerOutput(ModelOutput):
    """Output type for transformer encoder.

    Every field defaults to ``None``; optional fields stay ``None`` when the
    corresponding output was not requested.
    """
    # Final hidden state of the stack.
    last_hidden_state: Optional[torch.Tensor] = None
    # Tuple of collected hidden states, or None when not requested.
    hidden_states: Optional[Tuple[torch.Tensor]] = None
    # Tuple of per-block attention weights, or None when not requested.
    attentions: Optional[Tuple[torch.Tensor]] = None
    # Tuple with one list of tensors per block, or None when not requested.
    s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
def forward(
    self,
    x: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    output_hidden_states: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
    output_s_max: Optional[bool] = False,
) -> TransformerOutput:
    """Run the input through every block and the final layer norm.

    Args:
        x: Input embeddings; dimension 0 is used as the batch size and
            dimension 1 as the sequence length.
        attention_mask: Optional mask handed to ``get_attention_mask`` to
            build the backend-specific masks.
        output_hidden_states: Collect the output of every block plus the
            final normalised state.
        output_attentions: Collect the attention weights of every block.
        output_s_max: Collect the per-head logit bounds of every block.

    Returns:
        A ``TransformerOutput``; fields that were not requested are ``None``.
    """
    # Lists avoid rebuilding a tuple on every layer; they become tuples at
    # the end so the returned types are unchanged.
    hidden_states = [] if output_hidden_states else None
    attentions = [] if output_attentions else None
    full_s_max = [] if output_s_max else None

    attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask(
        effective_backend=self.attention_backend,
        batch_size=x.shape[0],
        seq_len=x.shape[1],
        device=x.device,
        attention_mask=attention_mask,
    )

    use_checkpointing = self.gradient_checkpointing and self.training
    for block in self.blocks:
        if use_checkpointing:
            # Arguments are positional on purpose: torch.utils.checkpoint
            # raises "Unexpected keyword arguments" for extra kwargs when
            # use_reentrant=True. The order mirrors
            # UnifiedTransformerBlock.forward.
            x, attn_weights, s_max = self._gradient_checkpointing_func(
                block.__call__,
                x,
                attention_mask_2d,
                attention_mask_4d,
                flex_block_mask,
                output_attentions,
                output_s_max,
            )
        else:
            x, attn_weights, s_max = block(
                x=x,
                attention_mask_2d=attention_mask_2d,
                attention_mask_4d=attention_mask_4d,
                flex_block_mask=flex_block_mask,
                output_attentions=output_attentions,
                output_s_max=output_s_max,
            )

        if attentions is not None:
            attentions.append(attn_weights)
        if hidden_states is not None:
            hidden_states.append(x)
        if full_s_max is not None:
            full_s_max.append(s_max)

    last_hidden_state = self.norm(x)
    if hidden_states is not None:
        # The normalised final state is appended after the raw block outputs.
        hidden_states.append(last_hidden_state)

    return TransformerOutput(
        last_hidden_state=last_hidden_state,
        hidden_states=None if hidden_states is None else tuple(hidden_states),
        attentions=None if attentions is None else tuple(attentions),
        s_max=None if full_s_max is None else tuple(full_s_max),
    )
= {} + + @classmethod + def is_remote_code(cls) -> bool: + # Prevent post-load reinitialization of tensors already loaded from checkpoints. + return True + + def _init_weights(self, module): + """Initialize the weights""" + # HF from_pretrained marks loaded parameters with `_is_hf_initialized`. + # Skip this module if any local parameter is already marked as loaded. + for parameter in module.parameters(recurse=False): + if "_is_hf_initialized" in parameter.__dict__ and parameter.__dict__["_is_hf_initialized"]: + return + + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + with torch.no_grad(): + module.weight[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + if module.bias is not None: + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + @property + def attn_backend(self) -> str: + return self.config.attn_backend + + @attn_backend.setter + def attn_backend(self, backend: str) -> None: + assert backend in VALID_ATTENTION_BACKENDS, f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}." 
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
    """Load a checkpoint through the parent class, then refresh rotary buffers.

    Behaves like the inherited ``from_pretrained`` and additionally calls
    ``_reset_rotary_embeddings()`` on the loaded model.

    Returns:
        The model, or ``(model, loading_info)`` when a truthy
        ``output_loading_info`` keyword was supplied.
    """
    wants_loading_info = bool(kwargs.get("output_loading_info", False))
    result = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    if not wants_loading_info:
        result._reset_rotary_embeddings()
        return result

    # With loading info requested the parent returns a 2-tuple.
    model, loading_info = result
    model._reset_rotary_embeddings()
    return model, loading_info
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_s_max: Optional[bool] = False,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> ESMplusplusOutput:
    """Encode tokens (or precomputed embeddings) with the transformer stack.

    Exactly one of ``input_ids`` and ``inputs_embeds`` must be given.
    ``return_dict`` and extra keyword arguments are accepted but not used
    by this method.

    Returns:
        An ``ESMplusplusOutput`` carrying the stack's last hidden state,
        hidden states, attentions and s_max values.
    """
    has_ids = input_ids is not None
    has_embeds = inputs_embeds is not None
    assert has_ids or has_embeds, "You have to specify either input_ids or inputs_embeds"
    assert not (has_ids and has_embeds), "You cannot specify both input_ids and inputs_embeds at the same time"

    x = inputs_embeds if has_embeds else self.embed(input_ids)

    encoded = self.transformer(
        x=x,
        attention_mask=attention_mask,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        output_s_max=output_s_max,
    )
    return ESMplusplusOutput(
        last_hidden_state=encoded.last_hidden_state,
        hidden_states=encoded.hidden_states,
        attentions=encoded.attentions,
        s_max=encoded.s_max,
    )
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_s_max: Optional[bool] = False,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> ESMplusplusOutput:
    """Encode the input and predict a token distribution for every position.

    Args:
        input_ids: Token ids; embedded with ``self.embed`` when
            ``inputs_embeds`` is not given.
        attention_mask: Optional mask forwarded to the transformer stack.
        inputs_embeds: Precomputed embeddings. When both this and
            ``input_ids`` are supplied, ``inputs_embeds`` is used.
        labels: Optional target token ids; when given, a cross-entropy loss
            over all positions is returned.
        output_attentions: Forwarded to the transformer stack.
        output_hidden_states: Forwarded to the transformer stack.
        output_s_max: Forwarded to the transformer stack.
        return_dict: Accepted but not used by this method.

    Returns:
        An ``ESMplusplusOutput`` with ``loss`` (or ``None``), ``logits``,
        and the transformer stack's outputs.

    Raises:
        AssertionError: If neither ``input_ids`` nor ``inputs_embeds`` is
            provided.
    """
    # Fail early with the same message ESMplusplusModel.forward uses instead
    # of letting nn.Embedding raise an opaque error on a None input.
    assert input_ids is not None or inputs_embeds is not None, "You have to specify either input_ids or inputs_embeds"

    if inputs_embeds is None:
        x = self.embed(input_ids)
    else:
        x = inputs_embeds

    output = self.transformer(
        x=x,
        attention_mask=attention_mask,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        output_s_max=output_s_max,
    )

    last_hidden_state = output.last_hidden_state
    logits = self.sequence_head(last_hidden_state)
    loss = None
    if labels is not None:
        # Move labels next to the logits, as the sequence-classification
        # head does. The class count is read from the logits so a replaced
        # output layer cannot make it stale; reshape also accepts
        # non-contiguous labels.
        labels = labels.to(logits.device)
        loss = self.ce_loss(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

    return ESMplusplusOutput(
        loss=loss,
        logits=logits,
        last_hidden_state=last_hidden_state,
        hidden_states=output.hidden_states,
        attentions=output.attentions,
        s_max=output.s_max,
    )
def __init__(self, config: ESMplusplusConfig, **kwargs):
    """Build the masked-LM backbone plus a pooled classification head.

    Args:
        config: Model configuration; ``hidden_size`` and ``num_labels`` size
            the classifier.
        **kwargs: Forwarded to ``ESMplusplusForMaskedLM.__init__``. An
            optional ``pooling_types`` entry (a non-empty list of strings)
            selects the pooling strategies; anything else falls back to
            ``['mean', 'var']``.
    """
    ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
    self.config = config
    self.num_labels = config.num_labels
    # Large intermediate projections help with sequence classification tasks (*4)
    # NOTE(review): the input width is fixed at hidden_size * 2, which
    # presumably assumes exactly two pooling types — confirm against Pooler
    # before passing a different number of pooling_types.
    self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
    self.mse = nn.MSELoss()
    self.ce = nn.CrossEntropyLoss()
    self.bce = nn.BCEWithLogitsLoss()
    # Use pooling_types from kwargs when it is a non-empty list of strings,
    # otherwise fall back to ['mean', 'var'].
    # isinstance() cannot take a subscripted generic such as List[str] (it
    # raises TypeError), so the container and its elements are checked
    # separately.
    requested = kwargs.get('pooling_types')
    if (
        isinstance(requested, list)
        and len(requested) > 0
        and all(isinstance(name, str) for name in requested)
    ):
        pooling_types = requested
    else:
        pooling_types = ['mean', 'var']
    self.pooler = Pooler(pooling_types)
    self.init_weights()
output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_s_max=output_s_max, + ) + + last_hidden_state = output.last_hidden_state + features = self.pooler(last_hidden_state, attention_mask) + logits = self.classifier(features) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + if self.num_labels == 1: + loss = self.mse(logits.flatten(), labels.flatten()) + else: + loss = self.mse(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss = self.bce(logits, labels) + + return ESMplusplusOutput( + loss=loss, + logits=logits, + last_hidden_state=last_hidden_state, + hidden_states=output.hidden_states, + attentions=output.attentions, + s_max=output.s_max, + ) + + +class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin): + """ + ESM++ model for token classification. + Extends the base ESM++ model with a token classification head. 
def __init__(self, config: ESMplusplusConfig, **kwargs):
    """Build the masked-LM backbone plus a per-token classification head.

    Args:
        config: Model configuration; ``hidden_size`` and ``num_labels`` size
            the classifier.
        **kwargs: Forwarded to ``ESMplusplusForMaskedLM.__init__``.
    """
    ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
    self.config = config
    self.num_labels = config.num_labels
    self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
    # Large intermediate projection (hidden_size * 4) for the token classification head.
    self.loss_fct = nn.CrossEntropyLoss()
    self.init_weights()
@cache
def data_root(model: str) -> Path:
    """Return the local directory holding the checkpoint files for ``model``.

    When the ``INFRA_PROVIDER`` environment variable is set, an empty
    relative path is returned and nothing is downloaded. Otherwise the
    model name is resolved to a checkpoint key and the matching repository
    snapshot is fetched with ``snapshot_download``.

    Results are memoised per ``model`` string.

    Args:
        model: Model name accepted by ``_resolve_esmc_checkpoint_key``.

    Returns:
        Path of the snapshot directory, or ``Path("")`` under
        ``INFRA_PROVIDER``.
    """
    # This is a plain module-level function: wrapping it in @staticmethod
    # made it uncallable before Python 3.10 and hid the cache's
    # cache_clear()/cache_info() helpers on newer versions.
    if "INFRA_PROVIDER" in os.environ:
        return Path("")
    key = _resolve_esmc_checkpoint_key(model)
    return Path(snapshot_download(repo_id=_ESMC_CHECKPOINT_SPECS[key]["repo_id"]))
".attn.layernorm_qkv.0.bias"), + (".attn.layernorm_qkv.layer_norm_weight", ".attn.layernorm_qkv.0.weight"), + (".attn.layernorm_qkv.weight", ".attn.layernorm_qkv.1.weight"), + (".ffn.layer_norm_bias", ".ffn.0.bias"), + (".ffn.layer_norm_weight", ".ffn.0.weight"), + (".ffn.fc1_weight", ".ffn.1.weight"), + (".ffn.fc2_weight", ".ffn.3.weight"), + ) + for old, new in replacements: + key = key.replace(old, new) + return key + + +def _normalize_esmc_state_dict(state_dict: dict) -> dict: + normalized = {} + for key, tensor in state_dict.items(): + normalized_key = _normalize_esmc_state_key(key) + if normalized_key is None: + continue + normalized[normalized_key] = tensor + return normalized + + +def _safetensors_checkpoint_files(checkpoint_path: Path) -> List[Path]: + if checkpoint_path.name == "model.safetensors": + return [checkpoint_path] + with checkpoint_path.open("r", encoding="utf-8") as f: + index = json.load(f) + return [ + checkpoint_path.parent / filename + for filename in sorted(set(index["weight_map"].values())) + ] + + +def _load_safetensors_state_dict( + model_obj: ESMplusplusForMaskedLM, + checkpoint_path: Path, + device: Union[torch.device, str], +) -> None: + expected_keys = set(model_obj.state_dict().keys()) + loaded_keys = set() + device_string = str(torch.device(device)) + for shard_path in _safetensors_checkpoint_files(checkpoint_path): + shard_state_dict = load_safetensors_file(shard_path, device=device_string) + normalized = _normalize_esmc_state_dict(shard_state_dict) + unexpected = set(normalized.keys()) - expected_keys + assert len(unexpected) == 0, ( + f"Unexpected ESMC checkpoint keys in {shard_path.name}: " + f"{sorted(unexpected)[:10]}" + ) + model_obj.load_state_dict(normalized, strict=False) + loaded_keys.update(normalized.keys()) + + missing = expected_keys - loaded_keys + assert len(missing) == 0, ( + f"ESMC checkpoint did not provide all expected keys: {sorted(missing)[:10]}" + ) + + +def _load_esmc_checkpoint_model( + config: 
def ESMplusplus_300M(device: Union[torch.device, str] = "cpu"):
    """Build the 300M ESM++ masked-LM model and load its ESMC checkpoint.

    Args:
        device: Device passed through to the checkpoint loader.

    Returns:
        Whatever ``_load_esmc_checkpoint_model`` returns for ``"esmc-300"``.
    """
    architecture = dict(
        hidden_size=960,
        num_attention_heads=15,
        num_hidden_layers=30,
    )
    return _load_esmc_checkpoint_model(
        config=ESMplusplusConfig(**architecture),
        model="esmc-300",
        device=device,
    )
# Token order defines the token ids (index == id). The angle-bracket special
# tokens must be spelled out: with empty strings in their place, ids 0-3 and
# the mask id collapse into a single "" key when the id mapping is built.
SEQUENCE_VOCAB = [
    "<cls>", "<pad>", "<eos>", "<unk>",
    "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
    "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
    "O", ".", "-", "|",
    "<mask>",
]

class EsmSequenceTokenizer(PreTrainedTokenizerFast):
    """Character-level tokenizer over ``SEQUENCE_VOCAB``.

    Single sequences are wrapped as ``<cls> ... <eos>``; the ``bos`` token
    is aliased to ``cls``.
    """
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        unk_token="<unk>",
        cls_token="<cls>",
        pad_token="<pad>",
        mask_token="<mask>",
        eos_token="<eos>",
        chain_break_token="|",
        **kwargs,
    ):
        """Build the backing ``tokenizers`` object and register special tokens.

        Args:
            unk_token: Token used for characters outside the vocabulary.
            cls_token: Token prepended to every sequence.
            pad_token: Padding token.
            mask_token: Masking token.
            eos_token: Token appended to every sequence.
            chain_break_token: Token marking a chain break.
            **kwargs: Forwarded to ``PreTrainedTokenizerFast.__init__``.
        """
        all_tokens = SEQUENCE_VOCAB
        token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}

        # a character-level tokenizer is the same as BPE with no token merges
        bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
        tokenizer = Tokenizer(bpe)
        special_tokens = [
            cls_token,
            pad_token,
            mask_token,
            eos_token,
            chain_break_token,
        ]
        self.cb_token = chain_break_token
        additional_special_tokens = [chain_break_token]

        tokenizer.add_special_tokens(special_tokens)

        # This is where we configure the automatic addition of special tokens when we call
        # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
        # sequences are merged if you want.
        tokenizer.post_processor = TemplateProcessing(  # type: ignore
            single="<cls> $A <eos>",
            pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
            special_tokens=[
                ("<cls>", tokenizer.token_to_id("<cls>")),
                ("<eos>", tokenizer.token_to_id("<eos>")),
            ],
        )
        super().__init__(
            tokenizer_object=tokenizer,
            unk_token=unk_token,
            cls_token=cls_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
    @property
    def bos_token(self):
        """Alias of ``cls_token``."""
        return self.cls_token

    @property
    def bos_token_id(self):
        """Alias of ``cls_token_id``."""
        return self.cls_token_id

    @property
    def chain_break_token(self):
        """The chain-break token given at construction time."""
        return self.cb_token

    @property
    def chain_break_token_id(self):
        """Id of the chain-break token."""
        return self.convert_tokens_to_ids(self.chain_break_token)

    @property
    def all_token_ids(self):
        """All ids from 0 up to ``vocab_size - 1``."""
        return list(range(self.vocab_size))

    @property
    def special_token_ids(self):
        """Ids of all registered special tokens."""
        return self.all_special_ids