lhallee committed
Commit b701ed0 · verified · 1 Parent(s): 91de303

Upload modeling_esm_plusplus.py with huggingface_hub

Files changed (1)
  1. modeling_esm_plusplus.py +1191 -1195
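For orientation, a minimal usage sketch follows (it is not part of the commit): the file implements a Hugging Face-compatible ESM++ stack, so a checkpoint that registers modeling_esm_plusplus.py as custom code can typically be loaded through transformers with trust_remote_code=True. The repo id below is a hypothetical placeholder, and the auto_map wiring in the checkpoint's config.json is assumed.

```python
# Minimal usage sketch (not part of this commit). The repo id is a placeholder;
# substitute a checkpoint that ships this file as custom code.
import torch
from transformers import AutoModelForMaskedLM

repo_id = "user/ESMplusplus-checkpoint"  # hypothetical placeholder

model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = model.tokenizer  # EsmSequenceTokenizer is attached in __init__ below

inputs = tokenizer("MQIFVKTLTGKTITLEVEPSDTIENVKAK", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(outputs.logits.shape)             # (1, seq_len + 2 special tokens, vocab_size)
print(outputs.last_hidden_state.shape)  # (1, seq_len + 2 special tokens, hidden_size)
```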
modeling_esm_plusplus.py CHANGED
@@ -1,1195 +1,1191 @@
1
- """
2
- ESM++ model implementation.
3
-
4
- ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
5
- The ESM Python package is not required
6
-
7
- Modified from https://github.com/evolutionaryscale/esm
8
- License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
- """
10
-
11
- import math
12
- import os
13
- import warnings
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- import networkx as nx
18
- from dataclasses import dataclass
19
- from functools import cache, partial
20
- from pathlib import Path
21
- from typing import Optional, Tuple, Union, List, Callable, Dict
22
- from einops import rearrange, repeat
23
- from huggingface_hub import snapshot_download
24
- from tokenizers import Tokenizer
25
- from tokenizers.models import BPE
26
- from tokenizers.processors import TemplateProcessing
27
- from torch.utils.data import Dataset as TorchDataset
28
- from torch.utils.data import DataLoader
29
- from tqdm.auto import tqdm
30
- from transformers import PreTrainedModel, PreTrainedTokenizerFast, PreTrainedTokenizerBase, PretrainedConfig
31
- from transformers.modeling_outputs import ModelOutput
32
-
33
- from .embedding_mixin import EmbeddingMixin, Pooler
34
-
35
- try:
36
- from torch.nn.attention.flex_attention import create_block_mask
37
- from torch.nn.attention.flex_attention import flex_attention as _raw_flex_attention
38
- except ImportError:
39
- create_block_mask = None
40
- _raw_flex_attention = None
41
-
42
-
43
- def _resolve_flex_attention(attn_compile: bool):
44
- if _raw_flex_attention is None:
45
- return None
46
- if not attn_compile:
47
- return _raw_flex_attention
48
- try:
49
- return torch.compile(_raw_flex_attention, dynamic=True)
50
- except Exception:
51
- return _raw_flex_attention
52
-
53
-
54
- def _create_pad_block_mask(attention_mask_2d: torch.Tensor, block_size: int):
55
- assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
56
- token_valid = attention_mask_2d.bool()
57
- batch_size, seq_len = token_valid.shape
58
-
59
- def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
60
- return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]
61
-
62
- return create_block_mask(
63
- mask_mod,
64
- batch_size,
65
- 1,
66
- seq_len,
67
- seq_len,
68
- device=attention_mask_2d.device,
69
- BLOCK_SIZE=block_size,
70
- )
71
-
72
-
73
- class ESMplusplusConfig(PretrainedConfig):
74
- """Configuration class for ESM++ model.
75
-
76
- Args:
77
- vocab_size: Size of the vocabulary
78
- hidden_size: Dimension of hidden layers
79
- num_attention_heads: Number of attention heads
80
- num_hidden_layers: Number of transformer layers
81
- num_labels: Number of output labels for classification
82
- problem_type: Type of problem - regression, single/multi label classification
83
- """
84
- model_type = "ESMplusplus"
85
- def __init__(
86
- self,
87
- vocab_size: int = 64,
88
- hidden_size: int = 960,
89
- num_attention_heads: int = 15,
90
- num_hidden_layers: int = 30,
91
- num_labels: int = 2,
92
- problem_type: str | None = None,
93
- dropout: float = 0.0,
94
- initializer_range: float = 0.02,
95
- attn_backend: str = "flex",
96
- attn_compile: bool = True,
97
- flex_block_size: int = 128,
98
- **kwargs,
99
- ):
100
- super().__init__(**kwargs)
101
- self.vocab_size = vocab_size
102
- self.hidden_size = hidden_size
103
- self.num_attention_heads = num_attention_heads
104
- self.num_hidden_layers = num_hidden_layers
105
- self.num_labels = num_labels
106
- self.problem_type = problem_type
107
- self.dropout = dropout
108
- self.initializer_range = initializer_range
109
- self.tie_word_embeddings = False
110
- self.attn_backend = attn_backend
111
- self.attn_compile = attn_compile
112
- self.flex_block_size = flex_block_size
113
-
114
-
115
- ### Rotary Embeddings
116
- def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
117
- """Rotates half the hidden dims of the input."""
118
- if not interleaved:
119
- x1, x2 = x.chunk(2, dim=-1)
120
- return torch.cat((-x2, x1), dim=-1)
121
- else:
122
- x1, x2 = x[..., ::2], x[..., 1::2]
123
- return rearrange(
124
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
125
- )
126
-
127
-
128
- def apply_rotary_emb_torch(
129
- x: torch.Tensor,
130
- cos: torch.Tensor,
131
- sin: torch.Tensor,
132
- interleaved: bool = False,
133
- _inplace: bool = False,
134
- ) -> torch.Tensor:
135
- """Apply rotary embeddings to input based on cos and sin."""
136
- ro_dim = cos.shape[-1] * 2
137
- assert ro_dim <= x.shape[-1]
138
- seqlen = x.size(1)
139
- cos = cos[:seqlen]
140
- sin = sin[:seqlen]
141
- cos = repeat(cos, "s d -> s 1 (2 d)")
142
- sin = repeat(sin, "s d -> s 1 (2 d)")
143
- return torch.cat(
144
- [
145
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
146
- x[..., ro_dim:],
147
- ],
148
- dim=-1,
149
- )
150
-
151
-
152
- class RotaryEmbedding(torch.nn.Module):
153
- """Rotary position embeddings.
154
-
155
- Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"
156
-
157
- Args:
158
- dim: Dimension of the embedding
159
- base: Base for computing angular frequencies
160
- interleaved: Whether to use interleaved rotations
161
- scale_base: Base for scaling
162
- scaling_factor: Factor for scaling positions
163
- pos_idx_in_fp32: Whether to compute position indices in fp32
164
- device: Computation device
165
- """
166
- def __init__(
167
- self,
168
- dim: int,
169
- base: float = 10000.0,
170
- interleaved: bool = False,
171
- scale_base: Optional[float] = None,
172
- scaling_factor: float = 1.0,
173
- pos_idx_in_fp32: bool = True,
174
- device: Optional[torch.device] = None,
175
- ):
176
- super().__init__()
177
- self.dim = dim
178
- self.base = float(base)
179
- self.pos_idx_in_fp32 = pos_idx_in_fp32
180
- self.interleaved = interleaved
181
- self.scale_base = scale_base
182
- self.scaling_factor = scaling_factor
183
- self.device = device
184
-
185
- self._seq_len_cached = 0
186
- self._cos_cached = None
187
- self._sin_cached = None
188
- self._cos_k_cached = None
189
- self._sin_k_cached = None
190
- self.reset_parameters()
191
-
192
- def reset_parameters(self):
193
- """Reset the parameters of the embedding."""
194
- inv_freq = self._compute_inv_freq(self.device)
195
- self.register_buffer("inv_freq", inv_freq, persistent=False)
196
- arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
197
- scale = (
198
- (arange + 0.4 * self.dim) / (1.4 * self.dim)
199
- if self.scale_base is not None
200
- else None
201
- )
202
- self.register_buffer("scale", scale)
203
-
204
- def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
205
- """Compute inverse frequency bands."""
206
- return 1 / (
207
- self.base
208
- ** (
209
- torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
210
- / self.dim
211
- )
212
- )
213
-
214
- def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
215
- """Update the cached cosine and sine values."""
216
- if (
217
- seqlen > self._seq_len_cached
218
- or self._cos_cached is None
219
- or self._cos_cached.device != device
220
- or self._cos_cached.dtype != dtype
221
- or (self.training and self._cos_cached.is_inference())
222
- ):
223
- self._seq_len_cached = seqlen
224
- if self.pos_idx_in_fp32:
225
- t = torch.arange(seqlen, device=device, dtype=torch.float32)
226
- t /= self.scaling_factor
227
- if self.inv_freq.dtype != torch.float32:
228
- inv_freq = self.inv_freq.to(torch.float32)
229
- else:
230
- inv_freq = self.inv_freq
231
- else:
232
- t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
233
- t /= self.scaling_factor
234
- inv_freq = self.inv_freq
235
- freqs = torch.outer(t, inv_freq)
236
-
237
- if self.scale is None:
238
- self._cos_cached = torch.cos(freqs).to(dtype)
239
- self._sin_cached = torch.sin(freqs).to(dtype)
240
- else:
241
- power = (
242
- torch.arange(
243
- seqlen, dtype=self.scale.dtype, device=self.scale.device
244
- )
245
- - seqlen // 2
246
- ) / self.scale_base
247
- scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
248
- self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
249
- self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
250
- self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
251
- self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
252
-
253
- def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
254
- """Apply rotary embeddings to queries and keys.
255
-
256
- Args:
257
- q: Query tensor of shape (batch, seqlen, nheads, headdim)
258
- k: Key tensor of shape (batch, seqlen, nheads, headdim)
259
-
260
- Returns:
261
- Tuple of rotated query and key tensors
262
- """
263
- self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
264
- assert self._cos_cached is not None
265
- assert self._sin_cached is not None
266
- if self.scale is None:
267
- return (
268
- apply_rotary_emb_torch(
269
- q,
270
- self._cos_cached,
271
- self._sin_cached,
272
- self.interleaved,
273
- True, # inplace=True
274
- ),
275
- apply_rotary_emb_torch(
276
- k,
277
- self._cos_cached,
278
- self._sin_cached,
279
- self.interleaved,
280
- True, # inplace=True
281
- ),
282
- ) # type: ignore
283
- else:
284
- assert False
285
-
286
-
287
- ### Feedforward Network Components
288
- def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
289
- """Compute corrected dimension for SwiGLU."""
290
- return int(((expansion_ratio * d_model) + 255) // 256 * 256)
291
-
292
-
293
- class SwiGLU(nn.Module):
294
- """SwiGLU activation function."""
295
- def __init__(self):
296
- super(SwiGLU, self).__init__()
297
-
298
- def forward(self, x: torch.Tensor) -> torch.Tensor:
299
- x1, x2 = x.chunk(2, dim=-1)
300
- return F.silu(x1) * x2
301
-
302
-
303
- def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
304
- """Create SwiGLU feedforward network with layer normalization."""
305
- return nn.Sequential(
306
- nn.LayerNorm(d_model),
307
- nn.Linear(
308
- d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
309
- ),
310
- SwiGLU(),
311
- nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
312
- )
313
-
314
-
315
- ### Attention
316
- class MultiHeadAttention(nn.Module):
317
- """Multi-head attention with rotary embeddings.
318
-
319
- Args:
320
- d_model: Model dimension
321
- n_heads: Number of attention heads
322
- """
323
- def __init__(
324
- self,
325
- d_model: int,
326
- n_heads: int,
327
- attn_backend: str = "flex",
328
- attn_compile: bool = True,
329
- flex_block_size: int = 128,
330
- ):
331
- super().__init__()
332
- self.d_model = d_model
333
- self.n_heads = n_heads
334
- self.d_head = self.d_model // self.n_heads
335
- self.attn_backend = attn_backend
336
- self.flex_block_size = flex_block_size
337
- self.flex_attention = _resolve_flex_attention(attn_compile)
338
- self._warned_flex_fallback = False
339
- self.layernorm_qkv = nn.Sequential(
340
- nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
341
- )
342
- self.out_proj = nn.Linear(d_model, d_model, bias=False)
343
- self.q_ln = nn.LayerNorm(d_model, bias=False)
344
- self.k_ln = nn.LayerNorm(d_model, bias=False)
345
- self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
346
- self.rotary = RotaryEmbedding(d_model // n_heads)
347
-
348
- def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
349
- """Apply rotary embeddings to query and key."""
350
- q = q.unflatten(-1, (self.n_heads, self.d_head))
351
- k = k.unflatten(-1, (self.n_heads, self.d_head))
352
- q, k = self.rotary(q, k)
353
- q = q.flatten(-2, -1)
354
- k = k.flatten(-2, -1)
355
- return q, k
356
-
357
- def forward(
358
- self,
359
- x: torch.Tensor,
360
- attention_mask: Optional[torch.Tensor] = None,
361
- flex_block_mask: Optional[object] = None,
362
- output_attentions: bool = False,
363
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
364
- """
365
- Args:
366
- x: Input tensor
367
- attention_mask: Optional attention mask
368
- output_attentions: Whether to return attention weights
369
-
370
- Returns:
371
- Output tensor after self attention, and optionally attention weights
372
- """
373
- attn_weights = None
374
- qkv_BLD3 = self.layernorm_qkv(x)
375
- query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
376
- query_BLD, key_BLD = (
377
- self.q_ln(query_BLD).to(query_BLD.dtype),
378
- self.k_ln(key_BLD).to(query_BLD.dtype),
379
- )
380
- query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
381
- query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
382
-
383
- if output_attentions: # Manual attention computation
384
- b, h, l, d = query_BHLD.shape
385
- scale = 1 / math.sqrt(d)
386
- attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
387
- if attention_mask is not None:
388
- attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
389
- attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
390
- attn_weights += attn_bias
391
- attn_weights = F.softmax(attn_weights, dim=-1)
392
- context_BHLD = torch.matmul(attn_weights, value_BHLD)
393
- else:
394
- sdpa_mask = None
395
- if attention_mask is not None:
396
- sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
397
- sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
398
- use_flex = (
399
- self.attn_backend == "flex"
400
- and self.flex_attention is not None
401
- and (attention_mask is None or flex_block_mask is not None)
402
- )
403
- if use_flex:
404
- try:
405
- context_BHLD = self.flex_attention(
406
- query_BHLD,
407
- key_BHLD,
408
- value_BHLD,
409
- block_mask=flex_block_mask,
410
- enable_gqa=query_BHLD.shape[1] != key_BHLD.shape[1],
411
- )
412
- except Exception as exc:
413
- if not self._warned_flex_fallback:
414
- warnings.warn(
415
- f"Flex attention failed in ESM++ attention; falling back to SDPA. Error: {exc}",
416
- RuntimeWarning,
417
- )
418
- self._warned_flex_fallback = True
419
- context_BHLD = F.scaled_dot_product_attention(
420
- query_BHLD,
421
- key_BHLD,
422
- value_BHLD,
423
- attn_mask=sdpa_mask,
424
- )
425
- else:
426
- context_BHLD = F.scaled_dot_product_attention(
427
- query_BHLD,
428
- key_BHLD,
429
- value_BHLD,
430
- attn_mask=sdpa_mask,
431
- )
432
-
433
- context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
434
- output = self.out_proj(context_BLD)
435
- return output, attn_weights
436
-
437
-
438
- ### Regression Head
439
- def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
440
- """Create a regression head with optional hidden dimension.
441
-
442
- Args:
443
- d_model: Input dimension
444
- output_dim: Output dimension
445
- hidden_dim: Optional hidden dimension (defaults to d_model)
446
- """
447
- hidden_dim = hidden_dim if hidden_dim is not None else d_model
448
- return nn.Sequential(
449
- nn.Linear(d_model, hidden_dim),
450
- nn.GELU(),
451
- nn.LayerNorm(hidden_dim),
452
- nn.Linear(hidden_dim, output_dim),
453
- )
454
-
455
-
456
- ### Transformer Block
457
- class UnifiedTransformerBlock(nn.Module):
458
- """Transformer block with attention and feedforward layers.
459
-
460
- Args:
461
- d_model: Model dimension
462
- n_heads: Number of attention heads
463
- residue_scaling_factor: Factor for scaling residual connections
464
- expansion_ratio: Expansion ratio for feedforward network
465
- """
466
- def __init__(
467
- self,
468
- d_model: int,
469
- n_heads: int,
470
- residue_scaling_factor: float = 1,
471
- expansion_ratio: float = 8 / 3,
472
- dropout: float = 0.0,
473
- attn_backend: str = "flex",
474
- attn_compile: bool = True,
475
- flex_block_size: int = 128,
476
- ):
477
- super().__init__()
478
- self.attn = MultiHeadAttention(
479
- d_model=d_model,
480
- n_heads=n_heads,
481
- attn_backend=attn_backend,
482
- attn_compile=attn_compile,
483
- flex_block_size=flex_block_size,
484
- )
485
- self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
486
- self.scaling_factor = residue_scaling_factor
487
- self.dropout = nn.Dropout(dropout)
488
-
489
- def forward(
490
- self,
491
- x: torch.Tensor,
492
- attention_mask: Optional[torch.Tensor] = None,
493
- flex_block_mask: Optional[object] = None,
494
- output_attentions: bool = False,
495
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
496
- """
497
- Args:
498
- x: Input tensor
499
- attention_mask: Optional attention mask
500
- output_attentions: Whether to return attention weights
501
-
502
- Returns:
503
- Output tensor after transformer block, and optionally attention weights
504
- """
505
- attn_output, attn_weights = self.attn(
506
- x,
507
- attention_mask,
508
- flex_block_mask,
509
- output_attentions,
510
- )
511
- x = x + self.dropout(attn_output) / self.scaling_factor
512
- x = x + self.dropout(self.ffn(x)) / self.scaling_factor
513
- return x, attn_weights
514
-
515
-
516
- ### Model Outputs
517
- @dataclass
518
- class TransformerOutput(ModelOutput):
519
- """Output type for transformer encoder."""
520
- last_hidden_state: Optional[torch.Tensor] = None
521
- hidden_states: Optional[Tuple[torch.Tensor]] = None
522
- attentions: Optional[Tuple[torch.Tensor]] = None
523
-
524
-
525
- @dataclass
526
- class ESMplusplusOutput(ModelOutput):
527
- """Output type for ESM++ models."""
528
- loss: Optional[torch.Tensor] = None
529
- logits: Optional[torch.Tensor] = None
530
- last_hidden_state: Optional[torch.Tensor] = None
531
- hidden_states: Optional[Tuple[torch.Tensor]] = None
532
- attentions: Optional[Tuple[torch.Tensor]] = None
533
-
534
-
535
- ### Transformer Stack
536
- class TransformerStack(nn.Module):
537
- """Stack of transformer blocks.
538
-
539
- Args:
540
- d_model: Model dimension
541
- n_heads: Number of attention heads
542
- n_layers: Number of transformer layers
543
- dropout: Dropout rate
544
- """
545
- def __init__(
546
- self,
547
- d_model: int,
548
- n_heads: int,
549
- n_layers: int,
550
- dropout: float = 0.0,
551
- attn_backend: str = "flex",
552
- attn_compile: bool = True,
553
- flex_block_size: int = 128,
554
- ):
555
- super().__init__()
556
- self.attn_backend = attn_backend
557
- self.flex_block_size = flex_block_size
558
- self.blocks = nn.ModuleList(
559
- [
560
- UnifiedTransformerBlock(
561
- d_model,
562
- n_heads,
563
- residue_scaling_factor=math.sqrt(n_layers / 36),
564
- dropout=dropout,
565
- attn_backend=attn_backend,
566
- attn_compile=attn_compile,
567
- flex_block_size=flex_block_size,
568
- )
569
- for i in range(n_layers)
570
- ]
571
- )
572
- self.norm = nn.LayerNorm(d_model, bias=False)
573
- self.gradient_checkpointing = False
574
-
575
- def forward(
576
- self,
577
- x: torch.Tensor,
578
- attention_mask: Optional[torch.Tensor] = None,
579
- output_hidden_states: bool = False,
580
- output_attentions: bool = False,
581
- ) -> TransformerOutput:
582
- """
583
- Args:
584
- x: Input tensor
585
- attention_mask: Optional attention mask
586
- output_hidden_states: Whether to return all hidden states
587
- output_attentions: Whether to return attention weights
588
-
589
- Returns:
590
- TransformerOutput containing last hidden state and optionally all hidden states and attention weights
591
- """
592
- batch_size, seq_len, _ = x.shape
593
- hidden_states = () if output_hidden_states else None
594
- attentions = () if output_attentions else None
595
-
596
- if attention_mask is not None:
597
- attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
598
- if self.attn_backend == "flex" and create_block_mask is not None and not output_attentions:
599
- token_attention_mask = attention_mask[:, 0, 0, :]
600
- flex_block_mask = _create_pad_block_mask(token_attention_mask, self.flex_block_size)
601
- else:
602
- flex_block_mask = None
603
- else:
604
- flex_block_mask = None
605
-
606
- for block in self.blocks:
607
- if self.gradient_checkpointing and self.training:
608
- x, attn_weights = self._gradient_checkpointing_func(
609
- block.__call__,
610
- x,
611
- attention_mask,
612
- flex_block_mask,
613
- output_attentions,
614
- )
615
- else:
616
- x, attn_weights = block(x, attention_mask, flex_block_mask, output_attentions)
617
-
618
- if attentions is not None:
619
- attentions += (attn_weights,)
620
-
621
- if output_hidden_states:
622
- assert hidden_states is not None
623
- hidden_states += (x,)
624
-
625
- return TransformerOutput(
626
- last_hidden_state=self.norm(x),
627
- hidden_states=hidden_states,
628
- attentions=attentions
629
- )
630
-
631
-
632
- class PreTrainedESMplusplusModel(PreTrainedModel):
633
- """
634
- Handles weight initialization and pretrained loading for ESM++ models.
635
- """
636
- config_class = ESMplusplusConfig
637
- base_model_prefix = "esm++"
638
- supports_gradient_checkpointing = True
639
- all_tied_weights_keys = {}
640
-
641
- def _init_weights(self, module):
642
- """Initialize the weights"""
643
- if isinstance(module, nn.Linear):
644
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
645
- if module.bias is not None:
646
- module.bias.data.zero_()
647
- elif isinstance(module, nn.Embedding):
648
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
649
- if module.padding_idx is not None:
650
- module.weight.data[module.padding_idx].zero_()
651
- elif isinstance(module, nn.LayerNorm):
652
- if module.bias is not None:
653
- module.bias.data.zero_()
654
- module.weight.data.fill_(1.0)
655
-
656
- @classmethod
657
- def from_pretrained_esm(cls, model_name: str):
658
- """Load a pretrained ESM++ model."""
659
- if '300' in model_name:
660
- return ESMplusplus_300M()
661
- elif '600' in model_name:
662
- return ESMplusplus_600M()
663
- else:
664
- raise ValueError(f"Invalid model name: {model_name}")
665
-
666
-
667
- ### ESM++ Models
668
- class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
669
- """
670
- ESM++ base model: a transformer encoder with no task-specific heads.
671
- """
672
- config_class = ESMplusplusConfig
673
- def __init__(self, config: ESMplusplusConfig, **kwargs):
674
- PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
675
- self.config = config
676
- self.vocab_size = config.vocab_size
677
- self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
678
- self.transformer = TransformerStack(
679
- d_model=config.hidden_size,
680
- n_heads=config.num_attention_heads,
681
- n_layers=config.num_hidden_layers,
682
- dropout=config.dropout,
683
- attn_backend=config.attn_backend,
684
- attn_compile=config.attn_compile,
685
- flex_block_size=config.flex_block_size,
686
- )
687
- self.tokenizer = EsmSequenceTokenizer()
688
- self.init_weights()
689
-
690
- def get_input_embeddings(self):
691
- return self.embed
692
-
693
- def set_input_embeddings(self, value):
694
- self.embed = value
695
-
696
- def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
697
- x = self.embed(input_ids)
698
- return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
699
-
700
- def forward(
701
- self,
702
- input_ids: Optional[torch.Tensor] = None,
703
- attention_mask: Optional[torch.Tensor] = None,
704
- inputs_embeds: Optional[torch.Tensor] = None,
705
- output_attentions: Optional[bool] = None,
706
- output_hidden_states: Optional[bool] = None,
707
- return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
708
- **kwargs,
709
- ) -> TransformerOutput:
710
- """Forward pass for masked language modeling.
711
-
712
- Args:
713
- input_ids: Input token IDs
714
- attention_mask: Attention mask
715
- inputs_embeds: Optional precomputed embeddings
716
- output_hidden_states: Whether to return all hidden states
717
- output_attentions: Whether to return attention weights
718
-
719
- Returns:
720
- TransformerOutput containing last hidden state and optionally all hidden states and attention weights
721
- """
722
- if inputs_embeds is None:
723
- x = self.embed(input_ids)
724
- else:
725
- x = inputs_embeds
726
- return self.transformer(x, attention_mask, output_hidden_states, output_attentions)
727
-
728
-
729
- class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
730
- """
731
- ESM++ model for masked language modeling.
732
- Implements the base ESM++ architecture with a masked language modeling head.
733
- """
734
- config_class = ESMplusplusConfig
735
- def __init__(self, config: ESMplusplusConfig, **kwargs):
736
- PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
737
- self.config = config
738
- self.vocab_size = config.vocab_size
739
- self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
740
- self.transformer = TransformerStack(
741
- d_model=config.hidden_size,
742
- n_heads=config.num_attention_heads,
743
- n_layers=config.num_hidden_layers,
744
- dropout=config.dropout,
745
- attn_backend=config.attn_backend,
746
- attn_compile=config.attn_compile,
747
- flex_block_size=config.flex_block_size,
748
- )
749
- self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
750
- self.ce_loss = nn.CrossEntropyLoss()
751
- self.tokenizer = EsmSequenceTokenizer()
752
- self.init_weights()
753
-
754
- def get_input_embeddings(self):
755
- return self.embed
756
-
757
- def set_input_embeddings(self, value):
758
- self.embed = value
759
-
760
- def get_output_embeddings(self):
761
- return self.sequence_head[-1]
762
-
763
- def set_output_embeddings(self, new_embeddings):
764
- self.sequence_head[-1] = new_embeddings
765
-
766
- def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
767
- x = self.embed(input_ids)
768
- return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
769
-
770
- def forward(
771
- self,
772
- input_ids: Optional[torch.Tensor] = None,
773
- attention_mask: Optional[torch.Tensor] = None,
774
- inputs_embeds: Optional[torch.Tensor] = None,
775
- labels: Optional[torch.Tensor] = None,
776
- output_attentions: Optional[bool] = None,
777
- output_hidden_states: Optional[bool] = None,
778
- return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
779
- **kwargs,
780
- ) -> ESMplusplusOutput:
781
- """Forward pass for masked language modeling.
782
-
783
- Args:
784
- input_ids: Input token IDs
785
- attention_mask: Attention mask
786
- inputs_embeds: Optional precomputed embeddings
787
- labels: Optional labels for masked tokens
788
- output_hidden_states: Whether to return all hidden states
789
- output_attentions: Whether to return attention weights
790
-
791
- Returns:
792
- ESMplusplusOutput containing loss, logits, hidden states and attention weights
793
- """
794
- if inputs_embeds is None:
795
- x = self.embed(input_ids)
796
- else:
797
- x = inputs_embeds
798
- output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
799
- x = output.last_hidden_state
800
- logits = self.sequence_head(x)
801
- loss = None
802
- if labels is not None:
803
- loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
804
- return ESMplusplusOutput(
805
- loss=loss,
806
- logits=logits,
807
- last_hidden_state=x,
808
- hidden_states=output.hidden_states,
809
- attentions=output.attentions,
810
- )
811
-
812
-
813
- class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
814
- """
815
- ESM++ model for sequence classification.
816
- Extends the base ESM++ model with a classification head.
817
- """
818
- def __init__(self, config: ESMplusplusConfig, **kwargs):
819
- ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
820
- self.config = config
821
- self.num_labels = config.num_labels
822
- self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
823
- # Large intermediate projections help with sequence classification tasks (*4)
824
- self.mse = nn.MSELoss()
825
- self.ce = nn.CrossEntropyLoss()
826
- self.bce = nn.BCEWithLogitsLoss()
827
- # if kwargs has pooling_types, use them, otherwise use ['cls', 'mean']
828
- if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], list) and len(kwargs['pooling_types']) > 0:
829
- pooling_types = kwargs['pooling_types']
830
- else:
831
- pooling_types = ['cls', 'mean']
832
- self.pooler = Pooler(pooling_types)
833
- self.init_weights()
834
-
835
- def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
836
- x = self.embed(input_ids)
837
- return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
838
-
839
- def forward(
840
- self,
841
- input_ids: Optional[torch.Tensor] = None,
842
- attention_mask: Optional[torch.Tensor] = None,
843
- inputs_embeds: Optional[torch.Tensor] = None,
844
- labels: Optional[torch.Tensor] = None,
845
- output_attentions: Optional[bool] = None,
846
- output_hidden_states: Optional[bool] = None,
847
- return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
848
- **kwargs,
849
- ) -> ESMplusplusOutput:
850
- """Forward pass for sequence classification.
851
-
852
- Args:
853
- input_ids: Input token IDs
854
- attention_mask: Attention mask
855
- inputs_embeds: Optional precomputed embeddings
856
- labels: Optional labels for classification
857
- output_hidden_states: Whether to return all hidden states
858
- output_attentions: Whether to return attention weights
859
-
860
- Returns:
861
- ESMplusplusOutput containing loss, logits, and hidden states
862
- """
863
- output = super().forward(
864
- input_ids=input_ids,
865
- attention_mask=attention_mask,
866
- inputs_embeds=inputs_embeds,
867
- labels=None,
868
- output_attentions=output_attentions,
869
- output_hidden_states=output_hidden_states
870
- )
871
- x = output.last_hidden_state
872
- features = self.pooler(x, attention_mask)
873
- logits = self.classifier(features)
874
- loss = None
875
- if labels is not None:
876
- labels = labels.to(logits.device)
877
- if self.config.problem_type is None:
878
- if self.num_labels == 1:
879
- self.config.problem_type = "regression"
880
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
881
- self.config.problem_type = "single_label_classification"
882
- else:
883
- self.config.problem_type = "multi_label_classification"
884
-
885
- if self.config.problem_type == "regression":
886
- if self.num_labels == 1:
887
- loss = self.mse(logits.flatten(), labels.flatten())
888
- else:
889
- loss = self.mse(logits, labels)
890
- elif self.config.problem_type == "single_label_classification":
891
- loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
892
- elif self.config.problem_type == "multi_label_classification":
893
- loss = self.bce(logits, labels)
894
-
895
- return ESMplusplusOutput(
896
- loss=loss,
897
- logits=logits,
898
- last_hidden_state=x,
899
- hidden_states=output.hidden_states,
900
- attentions=output.attentions,
901
- )
902
-
903
-
904
- class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
905
- """
906
- ESM++ model for token classification.
907
- Extends the base ESM++ model with a token classification head.
908
- """
909
- def __init__(self, config: ESMplusplusConfig, **kwargs):
910
- ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
911
- self.config = config
912
- self.num_labels = config.num_labels
913
- self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
914
- # Large intermediate projections help with token classification tasks (*4)
915
- self.loss_fct = nn.CrossEntropyLoss()
916
- self.init_weights()
917
-
918
- def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
919
- x = self.embed(input_ids)
920
- return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
921
-
922
- def forward(
923
- self,
924
- input_ids: Optional[torch.Tensor] = None,
925
- attention_mask: Optional[torch.Tensor] = None,
926
- inputs_embeds: Optional[torch.Tensor] = None,
927
- labels: Optional[torch.Tensor] = None,
928
- output_attentions: Optional[bool] = None,
929
- output_hidden_states: Optional[bool] = None,
930
- return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
931
- **kwargs,
932
- ) -> ESMplusplusOutput:
933
- """Forward pass for token classification.
934
-
935
- Args:
936
- input_ids: Input token IDs
937
- attention_mask: Attention mask
938
- inputs_embeds: Optional precomputed embeddings
939
- labels: Optional labels for token classification
940
- output_hidden_states: Whether to return all hidden states
941
- output_attentions: Whether to return attention weights
942
-
943
- Returns:
944
- ESMplusplusOutput containing loss, logits, and hidden states
945
- """
946
- output = super().forward(
947
- input_ids=input_ids,
948
- attention_mask=attention_mask,
949
- inputs_embeds=inputs_embeds,
950
- labels=None,
951
- output_attentions=output_attentions,
952
- output_hidden_states=output_hidden_states
953
- )
954
- x = output.last_hidden_state
955
- logits = self.classifier(x)
956
- loss = None
957
- if labels is not None:
958
- loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
959
- return ESMplusplusOutput(
960
- loss=loss,
961
- logits=logits,
962
- last_hidden_state=x,
963
- hidden_states=output.hidden_states,
964
- attentions=output.attentions,
965
- )
966
-
967
-
968
- ### Loading from EvolutionaryScale
969
- @staticmethod
970
- @cache
971
- def data_root(model: str):
972
- if "INFRA_PROVIDER" in os.environ:
973
- return Path("")
974
- # Try to download from the Hugging Face Hub if it doesn't exist locally
975
- if model.startswith("esmc-300"):
976
- path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
977
- elif model.startswith("esmc-600"):
978
- path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
979
- else:
980
- raise ValueError(f"{model=} is an invalid model name.")
981
- return path
982
-
983
-
984
- def ESMplusplus_300M(device: torch.device | str = "cpu"):
985
- with torch.device(device):
986
- config = ESMplusplusConfig(
987
- hidden_size=960,
988
- num_attention_heads=15,
989
- num_hidden_layers=30,
990
- )
991
- model = ESMplusplusForMaskedLM(config)
992
- state_dict = torch.load(
993
- data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
994
- map_location=device,
995
- )
996
- model.load_state_dict(state_dict)
997
- return model
998
-
999
-
1000
- def ESMplusplus_600M(device: torch.device | str = "cpu"):
1001
- with torch.device(device):
1002
- config = ESMplusplusConfig(
1003
- hidden_size=1152,
1004
- num_attention_heads=18,
1005
- num_hidden_layers=36,
1006
- )
1007
- model = ESMplusplusForMaskedLM(config)
1008
- state_dict = torch.load(
1009
- data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
1010
- map_location=device,
1011
- )
1012
- model.load_state_dict(state_dict)
1013
- return model
1014
-
1015
-
1016
- ### Tokenization
1017
- SEQUENCE_VOCAB = [
1018
- "<cls>", "<pad>", "<eos>", "<unk>",
1019
- "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
1020
- "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
1021
- "O", ".", "-", "|",
1022
- "<mask>",
1023
- ]
1024
-
1025
- class EsmSequenceTokenizer(PreTrainedTokenizerFast):
1026
- model_input_names = ["input_ids", "attention_mask"]
1027
-
1028
- def __init__(
1029
- self,
1030
- unk_token="<unk>",
1031
- cls_token="<cls>",
1032
- pad_token="<pad>",
1033
- mask_token="<mask>",
1034
- eos_token="<eos>",
1035
- chain_break_token="|",
1036
- **kwargs,
1037
- ):
1038
- all_tokens = SEQUENCE_VOCAB
1039
- token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
1040
-
1041
- # a character-level tokenizer is the same as BPE with no token merges
1042
- bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
1043
- tokenizer = Tokenizer(bpe)
1044
- special_tokens = [
1045
- cls_token,
1046
- pad_token,
1047
- mask_token,
1048
- eos_token,
1049
- chain_break_token,
1050
- ]
1051
- self.cb_token = chain_break_token
1052
- additional_special_tokens = [chain_break_token]
1053
-
1054
- tokenizer.add_special_tokens(special_tokens)
1055
-
1056
- # This is where we configure the automatic addition of special tokens when we call
1057
- # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
1058
- # sequences are merged if you want.
1059
- tokenizer.post_processor = TemplateProcessing( # type: ignore
1060
- single="<cls> $A <eos>",
1061
- pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
1062
- special_tokens=[
1063
- ("<cls>", tokenizer.token_to_id("<cls>")),
1064
- ("<eos>", tokenizer.token_to_id("<eos>")),
1065
- ],
1066
- )
1067
- super().__init__(
1068
- tokenizer_object=tokenizer,
1069
- unk_token=unk_token,
1070
- cls_token=cls_token,
1071
- pad_token=pad_token,
1072
- mask_token=mask_token,
1073
- eos_token=eos_token,
1074
- additional_special_tokens=additional_special_tokens,
1075
- **kwargs,
1076
- )
1077
-
1078
- # These are a footgun: we never use the `bos` token anywhere, so we just alias it to `cls` here.
1079
- @property
1080
- def bos_token(self):
1081
- return self.cls_token
1082
-
1083
- @property
1084
- def bos_token_id(self):
1085
- return self.cls_token_id
1086
-
1087
- @property
1088
- def chain_break_token(self):
1089
- return self.cb_token
1090
-
1091
- @property
1092
- def chain_break_token_id(self):
1093
- return self.convert_tokens_to_ids(self.chain_break_token)
1094
-
1095
- @property
1096
- def all_token_ids(self):
1097
- return list(range(self.vocab_size))
1098
-
1099
- @property
1100
- def special_token_ids(self):
1101
- return self.all_special_ids
1102
-
1103
-
1104
- if __name__ == "__main__":
1105
- # Use CUDA if available, otherwise fall back to CPU
1106
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1107
- print(f"Using device: {device}")
1108
-
1109
- # Test tokenizer
1110
- tokenizer = EsmSequenceTokenizer()
1111
- sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
1112
- encoding = tokenizer(sample_sequence, return_tensors="pt")
1113
- print(f"Input sequence length: {len(sample_sequence)}")
1114
- print(f"Tokenized sequence: {encoding['input_ids'].shape}")
1115
-
1116
- # Prepare inputs
1117
- input_ids = encoding['input_ids'].to(device)
1118
- attention_mask = encoding['attention_mask'].to(device)
1119
-
1120
- # Test base model with smaller config for quick testing
1121
- print("\n=== Testing ESMplusplus Base Model ===")
1122
- base_config = ESMplusplusConfig(
1123
- hidden_size=384,
1124
- num_attention_heads=6,
1125
- num_hidden_layers=4
1126
- )
1127
- base_model = ESMplusplusModel(base_config).to(device)
1128
-
1129
- with torch.no_grad():
1130
- outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
1131
-
1132
- print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1133
-
1134
- # Test embedding functionality
1135
- print("\nTesting embedding functionality:")
1136
- with torch.no_grad():
1137
- embeddings = base_model._embed(input_ids, attention_mask)
1138
- print(f"Embedding shape: {embeddings.shape}")
1139
-
1140
- # Test masked language modeling
1141
- print("\n=== Testing ESMplusplus For Masked LM ===")
1142
- mlm_model = ESMplusplusForMaskedLM(base_config).to(device)
1143
-
1144
- with torch.no_grad():
1145
- outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)
1146
-
1147
- print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1148
- print(f"Logits shape: {outputs.logits.shape}")
1149
-
1150
- # Test sequence classification model
1151
- print("\n=== Testing Sequence Classification Model ===")
1152
- classification_model = ESMplusplusForSequenceClassification(base_config).to(device)
1153
-
1154
- with torch.no_grad():
1155
- outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)
1156
-
1157
- print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1158
- print(f"Logits shape: {outputs.logits.shape}")
1159
-
1160
- # Test token classification model
1161
- print("\n=== Testing Token Classification Model ===")
1162
- token_model = ESMplusplusForTokenClassification(base_config).to(device)
1163
-
1164
- with torch.no_grad():
1165
- outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)
1166
-
1167
- print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1168
- print(f"Logits shape: {outputs.logits.shape}")
1169
-
1170
- # Test embedding dataset functionality with a mini dataset
1171
- print("\n=== Testing Embed Dataset Functionality ===")
1172
- mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
1173
- print(f"Creating embeddings for {len(mini_dataset)} sequences")
1174
-
1175
- # Only run this if save path doesn't exist to avoid overwriting
1176
- if not os.path.exists("test_embeddings.pth"):
1177
- embeddings = mlm_model.embed_dataset(
1178
- sequences=mini_dataset,
1179
- tokenizer=tokenizer,
1180
- batch_size=2,
1181
- max_len=100,
1182
- full_embeddings=False,
1183
- pooling_types=['mean'],
1184
- save_path="test_embeddings.pth"
1185
- )
1186
- if embeddings:
1187
- print(f"Embedding dictionary size: {len(embeddings)}")
1188
- for seq, emb in embeddings.items():
1189
- print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
1190
- break
1191
- else:
1192
- print("Skipping embedding test as test_embeddings.pth already exists")
1193
-
1194
- print("\nAll tests completed successfully!")
1195
-
 
1
+ """
2
+ ESM++ model implementation.
3
+
4
+ ESM++ is a faithful implementation of ESMC that allows for batching and standard Huggingface compatibility
5
+ The ESM Python package is not required
6
+
7
+ Modified from https://github.com/evolutionaryscale/esm
8
+ License: https://www.evolutionaryscale.ai/policies/cambrian-non-commercial-license-agreement
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from dataclasses import dataclass
18
+ from functools import cache, partial
19
+ from pathlib import Path
20
+ from typing import Optional, Tuple, Union, List
21
+ from einops import rearrange, repeat
22
+ from huggingface_hub import snapshot_download
23
+ from tokenizers import Tokenizer
24
+ from tokenizers.models import BPE
25
+ from tokenizers.processors import TemplateProcessing
26
+ from transformers import PreTrainedModel, PreTrainedTokenizerFast, PretrainedConfig
27
+ from transformers.modeling_outputs import ModelOutput
28
+
29
+ from .embedding_mixin import EmbeddingMixin, Pooler
30
+
31
+ try:
32
+ from torch.nn.attention.flex_attention import create_block_mask
33
+ from torch.nn.attention.flex_attention import flex_attention as _raw_flex_attention
34
+ except ImportError:
35
+ create_block_mask = None
36
+ _raw_flex_attention = None
37
+
38
+
39
+ def _resolve_flex_attention(attn_compile: bool):
40
+ if _raw_flex_attention is None:
41
+ return None
42
+ if not attn_compile:
43
+ return _raw_flex_attention
44
+ try:
45
+ return torch.compile(_raw_flex_attention, dynamic=True)
46
+ except Exception:
47
+ return _raw_flex_attention
48
+
49
+
50
+ def _create_pad_block_mask(attention_mask_2d: torch.Tensor, block_size: int):
51
+ assert create_block_mask is not None, "Flex attention block mask requires create_block_mask."
52
+ token_valid = attention_mask_2d.bool()
53
+ batch_size, seq_len = token_valid.shape
54
+
55
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
56
+ return token_valid[batch_idx, q_idx] & token_valid[batch_idx, kv_idx]
57
+
58
+ return create_block_mask(
59
+ mask_mod,
60
+ batch_size,
61
+ 1,
62
+ seq_len,
63
+ seq_len,
64
+ device=attention_mask_2d.device,
65
+ BLOCK_SIZE=block_size,
66
+ )
67
+
68
+
69
+ class ESMplusplusConfig(PretrainedConfig):
70
+ """Configuration class for ESM++ model.
71
+
72
+ Args:
73
+ vocab_size: Size of the vocabulary
74
+ hidden_size: Dimension of hidden layers
75
+ num_attention_heads: Number of attention heads
76
+ num_hidden_layers: Number of transformer layers
77
+ num_labels: Number of output labels for classification
78
+ problem_type: Type of problem - regression, single/multi label classification
79
+ """
80
+ model_type = "ESMplusplus"
81
+ def __init__(
82
+ self,
83
+ vocab_size: int = 64,
84
+ hidden_size: int = 960,
85
+ num_attention_heads: int = 15,
86
+ num_hidden_layers: int = 30,
87
+ num_labels: int = 2,
88
+ problem_type: str | None = None,
89
+ dropout: float = 0.0,
90
+ initializer_range: float = 0.02,
91
+ attn_backend: str = "flex",
92
+ attn_compile: bool = True,
93
+ flex_block_size: int = 128,
94
+ **kwargs,
95
+ ):
96
+ super().__init__(**kwargs)
97
+ self.vocab_size = vocab_size
98
+ self.hidden_size = hidden_size
99
+ self.num_attention_heads = num_attention_heads
100
+ self.num_hidden_layers = num_hidden_layers
101
+ self.num_labels = num_labels
102
+ self.problem_type = problem_type
103
+ self.dropout = dropout
104
+ self.initializer_range = initializer_range
105
+ self.tie_word_embeddings = False
106
+ self.attn_backend = attn_backend
107
+ self.attn_compile = attn_compile
108
+ self.flex_block_size = flex_block_size
109
+
110
+
111
+ ### Rotary Embeddings
112
+ def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
113
+ """Rotates half the hidden dims of the input."""
114
+ if not interleaved:
115
+ x1, x2 = x.chunk(2, dim=-1)
116
+ return torch.cat((-x2, x1), dim=-1)
117
+ else:
118
+ x1, x2 = x[..., ::2], x[..., 1::2]
119
+ return rearrange(
120
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
121
+ )
122
+
123
+
124
+ def apply_rotary_emb_torch(
125
+ x: torch.Tensor,
126
+ cos: torch.Tensor,
127
+ sin: torch.Tensor,
128
+ interleaved: bool = False,
129
+ _inplace: bool = False,
130
+ ) -> torch.Tensor:
131
+ """Apply rotary embeddings to input based on cos and sin."""
132
+ ro_dim = cos.shape[-1] * 2
133
+ assert ro_dim <= x.shape[-1]
134
+ seqlen = x.size(1)
135
+ cos = cos[:seqlen]
136
+ sin = sin[:seqlen]
137
+ cos = repeat(cos, "s d -> s 1 (2 d)")
138
+ sin = repeat(sin, "s d -> s 1 (2 d)")
139
+ return torch.cat(
140
+ [
141
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
142
+ x[..., ro_dim:],
143
+ ],
144
+ dim=-1,
145
+ )
146
+
147
+
148
+ class RotaryEmbedding(torch.nn.Module):
149
+ """Rotary position embeddings.
150
+
151
+ Based on the paper "RoFormer: Enhanced Transformer with Rotary Position Embedding"
152
+
153
+ Args:
154
+ dim: Dimension of the embedding
155
+ base: Base for computing angular frequencies
156
+ interleaved: Whether to use interleaved rotations
157
+ scale_base: Base for scaling
158
+ scaling_factor: Factor for scaling positions
159
+ pos_idx_in_fp32: Whether to compute position indices in fp32
160
+ device: Computation device
161
+ """
162
+ def __init__(
163
+ self,
164
+ dim: int,
165
+ base: float = 10000.0,
166
+ interleaved: bool = False,
167
+ scale_base: Optional[float] = None,
168
+ scaling_factor: float = 1.0,
169
+ pos_idx_in_fp32: bool = True,
170
+ device: Optional[torch.device] = None,
171
+ ):
172
+ super().__init__()
173
+ self.dim = dim
174
+ self.base = float(base)
175
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
176
+ self.interleaved = interleaved
177
+ self.scale_base = scale_base
178
+ self.scaling_factor = scaling_factor
179
+ self.device = device
180
+
181
+ self._seq_len_cached = 0
182
+ self._cos_cached = None
183
+ self._sin_cached = None
184
+ self._cos_k_cached = None
185
+ self._sin_k_cached = None
186
+ self.reset_parameters()
187
+
188
+ def reset_parameters(self):
189
+ """Reset the parameters of the embedding."""
190
+ inv_freq = self._compute_inv_freq(self.device)
191
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
192
+ arange = torch.arange(0, self.dim, 2, device=self.device, dtype=torch.float32)
193
+ scale = (
194
+ (arange + 0.4 * self.dim) / (1.4 * self.dim)
195
+ if self.scale_base is not None
196
+ else None
197
+ )
198
+ self.register_buffer("scale", scale)
199
+
200
+ def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
201
+ """Compute inverse frequency bands."""
202
+ return 1 / (
203
+ self.base
204
+ ** (
205
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
206
+ / self.dim
207
+ )
208
+ )
209
+
210
+ def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
211
+ """Update the cached cosine and sine values."""
212
+ if (
213
+ seqlen > self._seq_len_cached
214
+ or self._cos_cached is None
215
+ or self._cos_cached.device != device
216
+ or self._cos_cached.dtype != dtype
217
+ or (self.training and self._cos_cached.is_inference())
218
+ ):
219
+ self._seq_len_cached = seqlen
220
+ if self.pos_idx_in_fp32:
221
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
222
+ t /= self.scaling_factor
223
+ if self.inv_freq.dtype != torch.float32:
224
+ inv_freq = self.inv_freq.to(torch.float32)
225
+ else:
226
+ inv_freq = self.inv_freq
227
+ else:
228
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
229
+ t /= self.scaling_factor
230
+ inv_freq = self.inv_freq
231
+ freqs = torch.outer(t, inv_freq)
232
+
233
+ if self.scale is None:
234
+ self._cos_cached = torch.cos(freqs).to(dtype)
235
+ self._sin_cached = torch.sin(freqs).to(dtype)
236
+ else:
237
+ power = (
238
+ torch.arange(
239
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
240
+ )
241
+ - seqlen // 2
242
+ ) / self.scale_base
243
+ scale = self.scale.to(device=power.device) ** power.unsqueeze(-1)
244
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
245
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
246
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
247
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
248
+
249
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
250
+ """Apply rotary embeddings to queries and keys.
251
+
252
+ Args:
253
+ q: Query tensor of shape (batch, seqlen, nheads, headdim)
254
+ k: Key tensor of shape (batch, seqlen, nheads, headdim)
255
+
256
+ Returns:
257
+ Tuple of rotated query and key tensors
258
+ """
259
+ self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
260
+ assert self._cos_cached is not None
261
+ assert self._sin_cached is not None
262
+ if self.scale is None:
263
+ return (
264
+ apply_rotary_emb_torch(
265
+ q,
266
+ self._cos_cached,
267
+ self._sin_cached,
268
+ self.interleaved,
269
+ True, # inplace=True
270
+ ),
271
+ apply_rotary_emb_torch(
272
+ k,
273
+ self._cos_cached,
274
+ self._sin_cached,
275
+ self.interleaved,
276
+ True, # inplace=True
277
+ ),
278
+ ) # type: ignore
279
+ else:
280
+ assert False
281
+
282
+
283
+ ### Feedforward Network Components
284
+ def swiglu_correction_fn(expansion_ratio: float, d_model: int) -> int:
285
+ """Compute corrected dimension for SwiGLU."""
286
+ return int(((expansion_ratio * d_model) + 255) // 256 * 256)
287
+
288
+
289
+ class SwiGLU(nn.Module):
290
+ """SwiGLU activation function."""
291
+ def __init__(self):
292
+ super(SwiGLU, self).__init__()
293
+
294
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
295
+ x1, x2 = x.chunk(2, dim=-1)
296
+ return F.silu(x1) * x2
297
+
298
+
299
+ def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
300
+ """Create SwiGLU feedforward network with layer normalization."""
301
+ return nn.Sequential(
302
+ nn.LayerNorm(d_model),
303
+ nn.Linear(
304
+ d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
305
+ ),
306
+ SwiGLU(),
307
+ nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
308
+ )
309
+
310
+
311
+ ### Attention
312
+ class MultiHeadAttention(nn.Module):
313
+ """Multi-head attention with rotary embeddings.
314
+
315
+ Args:
316
+ d_model: Model dimension
317
+ n_heads: Number of attention heads
318
+ """
319
+ def __init__(
320
+ self,
321
+ d_model: int,
322
+ n_heads: int,
323
+ attn_backend: str = "flex",
324
+ attn_compile: bool = True,
325
+ flex_block_size: int = 128,
326
+ ):
327
+ super().__init__()
328
+ self.d_model = d_model
329
+ self.n_heads = n_heads
330
+ self.d_head = self.d_model // self.n_heads
331
+ self.attn_backend = attn_backend
332
+ self.flex_block_size = flex_block_size
333
+ self.flex_attention = _resolve_flex_attention(attn_compile)
334
+ self._warned_flex_fallback = False
335
+ self.layernorm_qkv = nn.Sequential(
336
+ nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=False)
337
+ )
338
+ self.out_proj = nn.Linear(d_model, d_model, bias=False)
339
+ self.q_ln = nn.LayerNorm(d_model, bias=False)
340
+ self.k_ln = nn.LayerNorm(d_model, bias=False)
341
+ self.reshaper = partial(rearrange, pattern="b s (h d) -> b h s d", h=n_heads)
342
+ self.rotary = RotaryEmbedding(d_model // n_heads)
343
+
344
+ def _apply_rotary(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
345
+ """Apply rotary embeddings to query and key."""
346
+ q = q.unflatten(-1, (self.n_heads, self.d_head))
347
+ k = k.unflatten(-1, (self.n_heads, self.d_head))
348
+ q, k = self.rotary(q, k)
349
+ q = q.flatten(-2, -1)
350
+ k = k.flatten(-2, -1)
351
+ return q, k
352
+
353
+ def forward(
354
+ self,
355
+ x: torch.Tensor,
356
+ attention_mask: Optional[torch.Tensor] = None,
357
+ flex_block_mask: Optional[object] = None,
358
+ output_attentions: bool = False,
359
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
360
+ """
361
+ Args:
362
+ x: Input tensor
363
+ attention_mask: Optional attention mask
364
+ output_attentions: Whether to return attention weights
365
+
366
+ Returns:
367
+ Output tensor after self attention, and optionally attention weights
368
+ """
369
+ attn_weights = None
370
+ qkv_BLD3 = self.layernorm_qkv(x)
371
+ query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
372
+ query_BLD, key_BLD = (
373
+ self.q_ln(query_BLD).to(query_BLD.dtype),
374
+ self.k_ln(key_BLD).to(query_BLD.dtype),
375
+ )
376
+ query_BLD, key_BLD = self._apply_rotary(query_BLD, key_BLD)
377
+ query_BHLD, key_BHLD, value_BHLD = map(self.reshaper, (query_BLD, key_BLD, value_BLD))
378
+
379
+ if output_attentions: # Manual attention computation
380
+ b, h, l, d = query_BHLD.shape
381
+ scale = 1 / math.sqrt(d)
382
+ attn_bias = torch.zeros(b, h, l, l, dtype=query_BLD.dtype, device=query_BLD.device)
383
+ if attention_mask is not None:
384
+ attn_bias.masked_fill_(attention_mask.logical_not(), float('-inf'))
385
+ attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale
386
+ attn_weights += attn_bias
387
+ attn_weights = F.softmax(attn_weights, dim=-1)
388
+ context_BHLD = torch.matmul(attn_weights, value_BHLD)
389
+ else:
390
+ sdpa_mask = None
391
+ if attention_mask is not None:
392
+ sdpa_mask = torch.zeros_like(attention_mask, dtype=query_BHLD.dtype)
393
+ sdpa_mask.masked_fill_(attention_mask.logical_not(), float("-inf"))
394
+ use_flex = (
395
+ self.attn_backend == "flex"
396
+ and self.flex_attention is not None
397
+ and (attention_mask is None or flex_block_mask is not None)
398
+ )
399
+ if use_flex:
400
+ try:
401
+ context_BHLD = self.flex_attention(
402
+ query_BHLD,
403
+ key_BHLD,
404
+ value_BHLD,
405
+ block_mask=flex_block_mask,
406
+ enable_gqa=query_BHLD.shape[1] != key_BHLD.shape[1],
407
+ )
408
+ except Exception as exc:
409
+ if not self._warned_flex_fallback:
410
+ warnings.warn(
411
+ f"Flex attention failed in ESM++ attention; falling back to SDPA. Error: {exc}",
412
+ RuntimeWarning,
413
+ )
414
+ self._warned_flex_fallback = True
415
+ context_BHLD = F.scaled_dot_product_attention(
416
+ query_BHLD,
417
+ key_BHLD,
418
+ value_BHLD,
419
+ attn_mask=sdpa_mask,
420
+ )
421
+ else:
422
+ context_BHLD = F.scaled_dot_product_attention(
423
+ query_BHLD,
424
+ key_BHLD,
425
+ value_BHLD,
426
+ attn_mask=sdpa_mask,
427
+ )
428
+
429
+ context_BLD = rearrange(context_BHLD, "b h s d -> b s (h d)")
430
+ output = self.out_proj(context_BLD)
431
+ return output, attn_weights
432
+
433
+
434
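+ # Illustrative shape sketch for MultiHeadAttention (not executed; sizes assume the default
+ # 300M-style configuration of d_model=960, n_heads=15 and the SDPA fallback backend):
+ #
+ #     attn = MultiHeadAttention(d_model=960, n_heads=15, attn_backend="sdpa",
+ #                               attn_compile=False, flex_block_size=128)
+ #     x = torch.randn(2, 128, 960)
+ #     out, weights = attn(x, attention_mask=None, output_attentions=True)
+ #     # out: (2, 128, 960); weights: (2, 15, 128, 128) post-softmax attention maps
+
+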
+ ### Regression Head
435
+ def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
436
+ """Create a regression head with optional hidden dimension.
437
+
438
+ Args:
439
+ d_model: Input dimension
440
+ output_dim: Output dimension
441
+ hidden_dim: Optional hidden dimension (defaults to d_model)
442
+ """
443
+ hidden_dim = hidden_dim if hidden_dim is not None else d_model
444
+ return nn.Sequential(
445
+ nn.Linear(d_model, hidden_dim),
446
+ nn.GELU(),
447
+ nn.LayerNorm(hidden_dim),
448
+ nn.Linear(hidden_dim, output_dim),
449
+ )
450
+
451
+
452
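+ # Usage sketch (illustrative only): the head maps the model dimension to `output_dim`
+ # through Linear -> GELU -> LayerNorm -> Linear, e.g.
+ #
+ #     head = RegressionHead(d_model=960, output_dim=2)
+ #     head(torch.randn(4, 960)).shape   # torch.Size([4, 2])
+
+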
+ ### Transformer Block
453
+ class UnifiedTransformerBlock(nn.Module):
454
+ """Transformer block with attention and feedforward layers.
455
+
456
+ Args:
457
+ d_model: Model dimension
458
+ n_heads: Number of attention heads
459
+ residue_scaling_factor: Factor for scaling residual connections
460
+ expansion_ratio: Expansion ratio for feedforward network
461
+ """
462
+ def __init__(
463
+ self,
464
+ d_model: int,
465
+ n_heads: int,
466
+ residue_scaling_factor: float = 1,
467
+ expansion_ratio: float = 8 / 3,
468
+ dropout: float = 0.0,
469
+ attn_backend: str = "flex",
470
+ attn_compile: bool = True,
471
+ flex_block_size: int = 128,
472
+ ):
473
+ super().__init__()
474
+ self.attn = MultiHeadAttention(
475
+ d_model=d_model,
476
+ n_heads=n_heads,
477
+ attn_backend=attn_backend,
478
+ attn_compile=attn_compile,
479
+ flex_block_size=flex_block_size,
480
+ )
481
+ self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
482
+ self.scaling_factor = residue_scaling_factor
483
+ self.dropout = nn.Dropout(dropout)
484
+
485
+ def forward(
486
+ self,
487
+ x: torch.Tensor,
488
+ attention_mask: Optional[torch.Tensor] = None,
489
+ flex_block_mask: Optional[object] = None,
490
+ output_attentions: bool = False,
491
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
492
+ """
493
+ Args:
494
+ x: Input tensor
495
+ attention_mask: Optional attention mask
496
+ flex_block_mask: Optional precomputed flex-attention block mask
+ output_attentions: Whether to return attention weights
497
+
498
+ Returns:
499
+ Output tensor after transformer block, and optionally attention weights
500
+ """
501
+ attn_output, attn_weights = self.attn(
502
+ x,
503
+ attention_mask,
504
+ flex_block_mask,
505
+ output_attentions,
506
+ )
507
+ x = x + self.dropout(attn_output) / self.scaling_factor
508
+ x = x + self.dropout(self.ffn(x)) / self.scaling_factor
509
+ return x, attn_weights
510
+
511
+
512
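+ # Note (illustrative sketch, not executed): both residual branches are divided by
+ # `residue_scaling_factor`; TransformerStack sets it to sqrt(n_layers / 36). For example:
+ #
+ #     block = UnifiedTransformerBlock(960, 15, attn_backend="sdpa", attn_compile=False)
+ #     y, attn = block(torch.randn(2, 64, 960))   # y: (2, 64, 960); attn is None unless
+ #                                                # output_attentions=True is passed
+
+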
+ ### Model Outputs
513
+ @dataclass
514
+ class TransformerOutput(ModelOutput):
515
+ """Output type for transformer encoder."""
516
+ last_hidden_state: Optional[torch.Tensor] = None
517
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
518
+ attentions: Optional[Tuple[torch.Tensor]] = None
519
+
520
+
521
+ @dataclass
522
+ class ESMplusplusOutput(ModelOutput):
523
+ """Output type for ESM++ models."""
524
+ loss: Optional[torch.Tensor] = None
525
+ logits: Optional[torch.Tensor] = None
526
+ last_hidden_state: Optional[torch.Tensor] = None
527
+ hidden_states: Optional[Tuple[torch.Tensor]] = None
528
+ attentions: Optional[Tuple[torch.Tensor]] = None
529
+
530
+
531
+ ### Transformer Stack
532
+ class TransformerStack(nn.Module):
533
+ """Stack of transformer blocks.
534
+
535
+ Args:
536
+ d_model: Model dimension
537
+ n_heads: Number of attention heads
538
+ n_layers: Number of transformer layers
539
+ dropout: Dropout rate
540
+ """
541
+ def __init__(
542
+ self,
543
+ d_model: int,
544
+ n_heads: int,
545
+ n_layers: int,
546
+ dropout: float = 0.0,
547
+ attn_backend: str = "flex",
548
+ attn_compile: bool = True,
549
+ flex_block_size: int = 128,
550
+ ):
551
+ super().__init__()
552
+ self.attn_backend = attn_backend
553
+ self.flex_block_size = flex_block_size
554
+ self.blocks = nn.ModuleList(
555
+ [
556
+ UnifiedTransformerBlock(
557
+ d_model,
558
+ n_heads,
559
+ residue_scaling_factor=math.sqrt(n_layers / 36),
560
+ dropout=dropout,
561
+ attn_backend=attn_backend,
562
+ attn_compile=attn_compile,
563
+ flex_block_size=flex_block_size,
564
+ )
565
+ for i in range(n_layers)
566
+ ]
567
+ )
568
+ self.norm = nn.LayerNorm(d_model, bias=False)
569
+ self.gradient_checkpointing = False
570
+
571
+ def forward(
572
+ self,
573
+ x: torch.Tensor,
574
+ attention_mask: Optional[torch.Tensor] = None,
575
+ output_hidden_states: bool = False,
576
+ output_attentions: bool = False,
577
+ ) -> TransformerOutput:
578
+ """
579
+ Args:
580
+ x: Input tensor
581
+ attention_mask: Optional attention mask
582
+ output_hidden_states: Whether to return all hidden states
583
+ output_attentions: Whether to return attention weights
584
+
585
+ Returns:
586
+ TransformerOutput containing last hidden state and optionally all hidden states and attention weights
587
+ """
588
+ batch_size, seq_len, _ = x.shape
589
+ hidden_states = () if output_hidden_states else None
590
+ attentions = () if output_attentions else None
591
+
592
+ if attention_mask is not None:
593
+ attention_mask = attention_mask[:, None, None, :].expand(batch_size, 1, seq_len, seq_len).bool()
594
+ if self.attn_backend == "flex" and create_block_mask is not None and not output_attentions:
595
+ token_attention_mask = attention_mask[:, 0, 0, :]
596
+ flex_block_mask = _create_pad_block_mask(token_attention_mask, self.flex_block_size)
597
+ else:
598
+ flex_block_mask = None
599
+ else:
600
+ flex_block_mask = None
601
+
602
+ for block in self.blocks:
603
+ if self.gradient_checkpointing and self.training:
604
+ x, attn_weights = self._gradient_checkpointing_func(
605
+ block.__call__,
606
+ x,
607
+ attention_mask,
608
+ flex_block_mask,
609
+ output_attentions,
610
+ )
611
+ else:
612
+ x, attn_weights = block(x, attention_mask, flex_block_mask, output_attentions)
613
+
614
+ if attentions is not None:
615
+ attentions += (attn_weights,)
616
+
617
+ if output_hidden_states:
618
+ assert hidden_states is not None
619
+ hidden_states += (x,)
620
+
621
+ return TransformerOutput(
622
+ last_hidden_state=self.norm(x),
623
+ hidden_states=hidden_states,
624
+ attentions=attentions
625
+ )
626
+
627
+
628
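+ # Usage sketch (illustrative only; deliberately small toy sizes): `attention_mask` follows the
+ # standard Hugging Face convention of 1 for real tokens and 0 for padding, and `hidden_states`
+ # collects one tensor per block when requested:
+ #
+ #     stack = TransformerStack(d_model=384, n_heads=6, n_layers=4, attn_backend="sdpa", attn_compile=False)
+ #     x = torch.randn(2, 10, 384)
+ #     mask = torch.ones(2, 10, dtype=torch.long)
+ #     mask[1, 6:] = 0                                   # last four positions of sequence 1 are padding
+ #     out = stack(x, attention_mask=mask, output_hidden_states=True)
+ #     # out.last_hidden_state: (2, 10, 384); len(out.hidden_states) == 4
+
+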
+ class PreTrainedESMplusplusModel(PreTrainedModel):
629
+ """
630
+ Shared base class for ESM++ models: weight initialization and pretrained checkpoint helpers.
631
+ """
632
+ config_class = ESMplusplusConfig
633
+ base_model_prefix = "esm++"
634
+ supports_gradient_checkpointing = True
635
+ all_tied_weights_keys = {}
636
+
637
+ def _init_weights(self, module):
638
+ """Initialize the weights"""
639
+ if isinstance(module, nn.Linear):
640
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
641
+ if module.bias is not None:
642
+ module.bias.data.zero_()
643
+ elif isinstance(module, nn.Embedding):
644
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
645
+ if module.padding_idx is not None:
646
+ module.weight.data[module.padding_idx].zero_()
647
+ elif isinstance(module, nn.LayerNorm):
648
+ if module.bias is not None:
649
+ module.bias.data.zero_()
650
+ module.weight.data.fill_(1.0)
651
+
652
+ @classmethod
653
+ def from_pretrained_esm(cls, model_name: str):
654
+ """Load a pretrained ESM++ model."""
655
+ if '300' in model_name:
656
+ return ESMplusplus_300M()
657
+ elif '600' in model_name:
658
+ return ESMplusplus_600M()
659
+ else:
660
+ raise ValueError(f"Invalid model name: {model_name}")
661
+
662
+
663
+ ### ESM++ Models
664
+ class ESMplusplusModel(PreTrainedESMplusplusModel, EmbeddingMixin):
665
+ """
666
+ Bare ESM++ transformer encoder with no task-specific head.
667
+ """
668
+ config_class = ESMplusplusConfig
669
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
670
+ PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
671
+ self.config = config
672
+ self.vocab_size = config.vocab_size
673
+ self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
674
+ self.transformer = TransformerStack(
675
+ d_model=config.hidden_size,
676
+ n_heads=config.num_attention_heads,
677
+ n_layers=config.num_hidden_layers,
678
+ dropout=config.dropout,
679
+ attn_backend=config.attn_backend,
680
+ attn_compile=config.attn_compile,
681
+ flex_block_size=config.flex_block_size,
682
+ )
683
+ self.tokenizer = EsmSequenceTokenizer()
684
+ self.init_weights()
685
+
686
+ def get_input_embeddings(self):
687
+ return self.embed
688
+
689
+ def set_input_embeddings(self, value):
690
+ self.embed = value
691
+
692
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
693
+ x = self.embed(input_ids)
694
+ return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
695
+
696
+ def forward(
697
+ self,
698
+ input_ids: Optional[torch.Tensor] = None,
699
+ attention_mask: Optional[torch.Tensor] = None,
700
+ inputs_embeds: Optional[torch.Tensor] = None,
701
+ output_attentions: Optional[bool] = None,
702
+ output_hidden_states: Optional[bool] = None,
703
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
704
+ **kwargs,
705
+ ) -> TransformerOutput:
706
+ """Forward pass for masked language modeling.
707
+
708
+ Args:
709
+ input_ids: Input token IDs
710
+ attention_mask: Attention mask
711
+ inputs_embeds: Optional precomputed embeddings
712
+ output_hidden_states: Whether to return all hidden states
713
+ output_attentions: Whether to return attention weights
714
+
715
+ Returns:
716
+ TransformerOutput containing last hidden state and optionally all hidden states and attention weights
717
+ """
718
+ if inputs_embeds is None:
719
+ x = self.embed(input_ids)
720
+ else:
721
+ x = inputs_embeds
722
+ return self.transformer(x, attention_mask, output_hidden_states, output_attentions)
723
+
724
+
725
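+ # End-to-end sketch (illustrative only; uses a deliberately tiny config so it stays cheap):
+ #
+ #     config = ESMplusplusConfig(hidden_size=384, num_attention_heads=6, num_hidden_layers=4)
+ #     model = ESMplusplusModel(config)
+ #     batch = model.tokenizer(["MKTAYIAK", "MKT"], padding=True, return_tensors="pt")
+ #     out = model(**batch)
+ #     # out.last_hidden_state: (2, 10, 384) -- 8 residues plus <cls>/<eos>, shorter sequence padded
+
+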
+ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel, EmbeddingMixin):
726
+ """
727
+ ESM++ model for masked language modeling.
728
+ Implements the base ESM++ architecture with a masked language modeling head.
729
+ """
730
+ config_class = ESMplusplusConfig
731
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
732
+ PreTrainedESMplusplusModel.__init__(self, config, **kwargs)
733
+ self.config = config
734
+ self.vocab_size = config.vocab_size
735
+ self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
736
+ self.transformer = TransformerStack(
737
+ d_model=config.hidden_size,
738
+ n_heads=config.num_attention_heads,
739
+ n_layers=config.num_hidden_layers,
740
+ dropout=config.dropout,
741
+ attn_backend=config.attn_backend,
742
+ attn_compile=config.attn_compile,
743
+ flex_block_size=config.flex_block_size,
744
+ )
745
+ self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
746
+ self.ce_loss = nn.CrossEntropyLoss()
747
+ self.tokenizer = EsmSequenceTokenizer()
748
+ self.init_weights()
749
+
750
+ def get_input_embeddings(self):
751
+ return self.embed
752
+
753
+ def set_input_embeddings(self, value):
754
+ self.embed = value
755
+
756
+ def get_output_embeddings(self):
757
+ return self.sequence_head[-1]
758
+
759
+ def set_output_embeddings(self, new_embeddings):
760
+ self.sequence_head[-1] = new_embeddings
761
+
762
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
763
+ x = self.embed(input_ids)
764
+ return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
765
+
766
+ def forward(
767
+ self,
768
+ input_ids: Optional[torch.Tensor] = None,
769
+ attention_mask: Optional[torch.Tensor] = None,
770
+ inputs_embeds: Optional[torch.Tensor] = None,
771
+ labels: Optional[torch.Tensor] = None,
772
+ output_attentions: Optional[bool] = None,
773
+ output_hidden_states: Optional[bool] = None,
774
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
775
+ **kwargs,
776
+ ) -> ESMplusplusOutput:
777
+ """Forward pass for masked language modeling.
778
+
779
+ Args:
780
+ input_ids: Input token IDs
781
+ attention_mask: Attention mask
782
+ inputs_embeds: Optional precomputed embeddings
783
+ labels: Optional labels for masked tokens
784
+ output_hidden_states: Whether to return all hidden states
785
+ output_attentions: Whether to return attention weights
786
+
787
+ Returns:
788
+ ESMplusplusOutput containing loss, logits, hidden states and attention weights
789
+ """
790
+ if inputs_embeds is None:
791
+ x = self.embed(input_ids)
792
+ else:
793
+ x = inputs_embeds
794
+ output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
795
+ x = output.last_hidden_state
796
+ logits = self.sequence_head(x)
797
+ loss = None
798
+ if labels is not None:
799
+ loss = self.ce_loss(logits.view(-1, self.vocab_size), labels.view(-1))
800
+ return ESMplusplusOutput(
801
+ loss=loss,
802
+ logits=logits,
803
+ last_hidden_state=x,
804
+ hidden_states=output.hidden_states,
805
+ attentions=output.attentions,
806
+ )
807
+
808
+
809
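+ # Masked-LM training sketch (illustrative only, assuming `mlm_model` is an existing
+ # ESMplusplusForMaskedLM instance; the name is hypothetical). `ce_loss` is a standard
+ # CrossEntropyLoss, so positions labelled -100 are ignored:
+ #
+ #     batch = mlm_model.tokenizer(["MKTAYIAK"], return_tensors="pt")
+ #     labels = batch["input_ids"].clone()
+ #     masked = batch["input_ids"].clone()
+ #     masked[0, 3] = mlm_model.tokenizer.mask_token_id          # hide one residue
+ #     labels[masked != mlm_model.tokenizer.mask_token_id] = -100
+ #     out = mlm_model(input_ids=masked, attention_mask=batch["attention_mask"], labels=labels)
+ #     # out.logits: (1, 10, vocab_size); out.loss scores only the single masked position
+
+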
+ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
810
+ """
811
+ ESM++ model for sequence classification.
812
+ Extends the base ESM++ model with a classification head.
813
+ """
814
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
815
+ ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
816
+ self.config = config
817
+ self.num_labels = config.num_labels
818
+ self.classifier = RegressionHead(config.hidden_size * 2, config.num_labels, config.hidden_size * 4)
819
+ # Large intermediate projections help with sequence classification tasks (*4)
820
+ self.mse = nn.MSELoss()
821
+ self.ce = nn.CrossEntropyLoss()
822
+ self.bce = nn.BCEWithLogitsLoss()
823
+ # if kwargs has pooling_types, use them, otherwise use ['cls', 'mean']
824
+ if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], list) and len(kwargs['pooling_types']) > 0:
825
+ pooling_types = kwargs['pooling_types']
826
+ else:
827
+ pooling_types = ['cls', 'mean']
828
+ self.pooler = Pooler(pooling_types)
829
+ self.init_weights()
830
+
831
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
832
+ x = self.embed(input_ids)
833
+ return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
834
+
835
+ def forward(
836
+ self,
837
+ input_ids: Optional[torch.Tensor] = None,
838
+ attention_mask: Optional[torch.Tensor] = None,
839
+ inputs_embeds: Optional[torch.Tensor] = None,
840
+ labels: Optional[torch.Tensor] = None,
841
+ output_attentions: Optional[bool] = None,
842
+ output_hidden_states: Optional[bool] = None,
843
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
844
+ **kwargs,
845
+ ) -> ESMplusplusOutput:
846
+ """Forward pass for sequence classification.
847
+
848
+ Args:
849
+ input_ids: Input token IDs
850
+ attention_mask: Attention mask
851
+ inputs_embeds: Optional precomputed embeddings
852
+ labels: Optional labels for classification
853
+ output_hidden_states: Whether to return all hidden states
854
+ output_attentions: Whether to return attention weights
855
+
856
+ Returns:
857
+ ESMplusplusOutput containing loss, logits, and hidden states
858
+ """
859
+ output = super().forward(
860
+ input_ids=input_ids,
861
+ attention_mask=attention_mask,
862
+ inputs_embeds=inputs_embeds,
863
+ labels=None,
864
+ output_attentions=output_attentions,
865
+ output_hidden_states=output_hidden_states
866
+ )
867
+ x = output.last_hidden_state
868
+ features = self.pooler(x, attention_mask)
869
+ logits = self.classifier(features)
870
+ loss = None
871
+ if labels is not None:
872
+ labels = labels.to(logits.device)
873
+ if self.config.problem_type is None:
874
+ if self.num_labels == 1:
875
+ self.config.problem_type = "regression"
876
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
877
+ self.config.problem_type = "single_label_classification"
878
+ else:
879
+ self.config.problem_type = "multi_label_classification"
880
+
881
+ if self.config.problem_type == "regression":
882
+ if self.num_labels == 1:
883
+ loss = self.mse(logits.flatten(), labels.flatten())
884
+ else:
885
+ loss = self.mse(logits, labels)
886
+ elif self.config.problem_type == "single_label_classification":
887
+ loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
888
+ elif self.config.problem_type == "multi_label_classification":
889
+ loss = self.bce(logits, labels)
890
+
891
+ return ESMplusplusOutput(
892
+ loss=loss,
893
+ logits=logits,
894
+ last_hidden_state=x,
895
+ hidden_states=output.hidden_states,
896
+ attentions=output.attentions,
897
+ )
898
+
899
+
900
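+ # Classification sketch (illustrative only): the classifier consumes pooled features of size
+ # 2 * hidden_size, matching the default ['cls', 'mean'] pooling handled by the Pooler:
+ #
+ #     config = ESMplusplusConfig(hidden_size=384, num_attention_heads=6, num_hidden_layers=4, num_labels=3)
+ #     clf = ESMplusplusForSequenceClassification(config, pooling_types=['cls', 'mean'])
+ #     batch = clf.tokenizer(["MKTAYIAK", "MKT"], padding=True, return_tensors="pt")
+ #     out = clf(**batch, labels=torch.tensor([0, 2]))
+ #     # out.logits: (2, 3); labels are integer class ids, so cross-entropy is used
+
+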
+ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM, EmbeddingMixin):
901
+ """
902
+ ESM++ model for token classification.
903
+ Extends the base ESM++ model with a token classification head.
904
+ """
905
+ def __init__(self, config: ESMplusplusConfig, **kwargs):
906
+ ESMplusplusForMaskedLM.__init__(self, config, **kwargs)
907
+ self.config = config
908
+ self.num_labels = config.num_labels
909
+ self.classifier = RegressionHead(config.hidden_size, config.num_labels, config.hidden_size * 4)
910
+ # Large intermediate projections also help with token classification tasks (*4)
911
+ self.loss_fct = nn.CrossEntropyLoss()
912
+ self.init_weights()
913
+
914
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
915
+ x = self.embed(input_ids)
916
+ return self.transformer(x, attention_mask, output_hidden_states=False, output_attentions=False).last_hidden_state
917
+
918
+ def forward(
919
+ self,
920
+ input_ids: Optional[torch.Tensor] = None,
921
+ attention_mask: Optional[torch.Tensor] = None,
922
+ inputs_embeds: Optional[torch.Tensor] = None,
923
+ labels: Optional[torch.Tensor] = None,
924
+ output_attentions: Optional[bool] = None,
925
+ output_hidden_states: Optional[bool] = None,
926
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
927
+ **kwargs,
928
+ ) -> ESMplusplusOutput:
929
+ """Forward pass for token classification.
930
+
931
+ Args:
932
+ input_ids: Input token IDs
933
+ attention_mask: Attention mask
934
+ inputs_embeds: Optional precomputed embeddings
935
+ labels: Optional labels for token classification
936
+ output_hidden_states: Whether to return all hidden states
937
+ output_attentions: Whether to return attention weights
938
+
939
+ Returns:
940
+ ESMplusplusOutput containing loss, logits, and hidden states
941
+ """
942
+ output = super().forward(
943
+ input_ids=input_ids,
944
+ attention_mask=attention_mask,
945
+ inputs_embeds=inputs_embeds,
946
+ labels=None,
947
+ output_attentions=output_attentions,
948
+ output_hidden_states=output_hidden_states
949
+ )
950
+ x = output.last_hidden_state
951
+ logits = self.classifier(x)
952
+ loss = None
953
+ if labels is not None:
954
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
955
+ return ESMplusplusOutput(
956
+ loss=loss,
957
+ logits=logits,
958
+ last_hidden_state=x,
959
+ hidden_states=output.hidden_states,
960
+ attentions=output.attentions,
961
+ )
962
+
963
+
964
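+ # Per-residue sketch (illustrative only, assuming `tok_model` is an existing
+ # ESMplusplusForTokenClassification instance; the name is hypothetical):
+ #
+ #     batch = tok_model.tokenizer(["MKTAYIAK"], return_tensors="pt")
+ #     labels = torch.full((1, 10), -100)      # ignore <cls>/<eos>; fill residue positions with class ids
+ #     labels[0, 1:9] = torch.randint(0, tok_model.num_labels, (8,))
+ #     out = tok_model(**batch, labels=labels)
+ #     # out.logits: (1, 10, num_labels) -- one prediction per token; the loss ignores -100 positions
+
+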
+ ### Loading from EvolutionaryScale
965
+ @cache
967
+ def data_root(model: str):
968
+ if "INFRA_PROVIDER" in os.environ:
969
+ return Path("")
970
+ # Try to download from the Hugging Face Hub if the weights are not available locally
971
+ if model.startswith("esmc-300"):
972
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-300m-2024-12"))
973
+ elif model.startswith("esmc-600"):
974
+ path = Path(snapshot_download(repo_id="EvolutionaryScale/esmc-600m-2024-12"))
975
+ else:
976
+ raise ValueError(f"{model=} is an invalid model name.")
977
+ return path
978
+
979
+
980
+ def ESMplusplus_300M(device: torch.device | str = "cpu"):
981
+ with torch.device(device):
982
+ config = ESMplusplusConfig(
983
+ hidden_size=960,
984
+ num_attention_heads=15,
985
+ num_hidden_layers=30,
986
+ )
987
+ model = ESMplusplusForMaskedLM(config)
988
+ state_dict = torch.load(
989
+ data_root("esmc-300") / "data/weights/esmc_300m_2024_12_v0.pth",
990
+ map_location=device,
991
+ )
992
+ model.load_state_dict(state_dict)
993
+ return model
994
+
995
+
996
+ def ESMplusplus_600M(device: torch.device | str = "cpu"):
997
+ with torch.device(device):
998
+ config = ESMplusplusConfig(
999
+ hidden_size=1152,
1000
+ num_attention_heads=18,
1001
+ num_hidden_layers=36,
1002
+ )
1003
+ model = ESMplusplusForMaskedLM(config)
1004
+ state_dict = torch.load(
1005
+ data_root("esmc-600") / "data/weights/esmc_600m_2024_12_v0.pth",
1006
+ map_location=device,
1007
+ )
1008
+ model.load_state_dict(state_dict)
1009
+ return model
1010
+
1011
+
1012
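+ # Loading sketch (illustrative only): both constructors fetch the corresponding ESMC checkpoint
+ # from the EvolutionaryScale repositories on the Hugging Face Hub, so they need network access
+ # (or weights already present locally when INFRA_PROVIDER is set) and are subject to the
+ # upstream license linked in the module docstring:
+ #
+ #     model = ESMplusplus_300M(device="cpu")                                # 960-dim, 30 layers
+ #     model = PreTrainedESMplusplusModel.from_pretrained_esm("esmc-600m")   # dispatches on "600"
+
+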
+ ### Tokenization
1013
+ SEQUENCE_VOCAB = [
1014
+ "<cls>", "<pad>", "<eos>", "<unk>",
1015
+ "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K",
1016
+ "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z",
1017
+ "O", ".", "-", "|",
1018
+ "<mask>",
1019
+ ]
1020
+
1021
+ class EsmSequenceTokenizer(PreTrainedTokenizerFast):
1022
+ model_input_names = ["input_ids", "attention_mask"]
1023
+
1024
+ def __init__(
1025
+ self,
1026
+ unk_token="<unk>",
1027
+ cls_token="<cls>",
1028
+ pad_token="<pad>",
1029
+ mask_token="<mask>",
1030
+ eos_token="<eos>",
1031
+ chain_break_token="|",
1032
+ **kwargs,
1033
+ ):
1034
+ all_tokens = SEQUENCE_VOCAB
1035
+ token_to_id = {tok: ind for ind, tok in enumerate(all_tokens)}
1036
+
1037
+ # a character-level tokenizer is the same as BPE with no token merges
1038
+ bpe = BPE(token_to_id, merges=[], unk_token=unk_token)
1039
+ tokenizer = Tokenizer(bpe)
1040
+ special_tokens = [
1041
+ cls_token,
1042
+ pad_token,
1043
+ mask_token,
1044
+ eos_token,
1045
+ chain_break_token,
1046
+ ]
1047
+ self.cb_token = chain_break_token
1048
+ additional_special_tokens = [chain_break_token]
1049
+
1050
+ tokenizer.add_special_tokens(special_tokens)
1051
+
1052
+ # This is where we configure the automatic addition of special tokens when we call
1053
+ # tokenizer(text, add_special_tokens=True). Note that you can also configure how two
1054
+ # sequences are merged if you want.
1055
+ tokenizer.post_processor = TemplateProcessing( # type: ignore
1056
+ single="<cls> $A <eos>",
1057
+ pair="<cls>:0 $A:0 <eos>:0 $B:1 <eos>:1",
1058
+ special_tokens=[
1059
+ ("<cls>", tokenizer.token_to_id("<cls>")),
1060
+ ("<eos>", tokenizer.token_to_id("<eos>")),
1061
+ ],
1062
+ )
1063
+ super().__init__(
1064
+ tokenizer_object=tokenizer,
1065
+ unk_token=unk_token,
1066
+ cls_token=cls_token,
1067
+ pad_token=pad_token,
1068
+ mask_token=mask_token,
1069
+ eos_token=eos_token,
1070
+ additional_special_tokens=additional_special_tokens,
1071
+ **kwargs,
1072
+ )
1073
+
1074
+ # These are a footgun, we never use the `bos` token anywhere so we're just overriding it here.
1075
+ @property
1076
+ def bos_token(self):
1077
+ return self.cls_token
1078
+
1079
+ @property
1080
+ def bos_token_id(self):
1081
+ return self.cls_token_id
1082
+
1083
+ @property
1084
+ def chain_break_token(self):
1085
+ return self.cb_token
1086
+
1087
+ @property
1088
+ def chain_break_token_id(self):
1089
+ return self.convert_tokens_to_ids(self.chain_break_token)
1090
+
1091
+ @property
1092
+ def all_token_ids(self):
1093
+ return list(range(self.vocab_size))
1094
+
1095
+ @property
1096
+ def special_token_ids(self):
1097
+ return self.all_special_ids
1098
+
1099
+
1100
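+ # Tokenization sketch (illustrative only): sequences are tokenized character-by-character and
+ # wrapped with <cls> ... <eos> by the post-processor, so a protein of length N yields N + 2 tokens:
+ #
+ #     tok = EsmSequenceTokenizer()
+ #     enc = tok("MKT")
+ #     # enc["input_ids"] == [0, 20, 15, 11, 2]   # <cls>, M, K, T, <eos> per SEQUENCE_VOCAB order
+
+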
+ if __name__ == "__main__":
1101
+ # Use CUDA if available, otherwise fall back to CPU
1102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1103
+ print(f"Using device: {device}")
1104
+
1105
+ # Test tokenizer
1106
+ tokenizer = EsmSequenceTokenizer()
1107
+ sample_sequence = "MQIFVKTLTGKTITLEVEPSDTIENVKAKIQDKEGIPPDQQRLIFAGKQLEDGRTLSDYNIQKESTLHLVLRLRGG"
1108
+ encoding = tokenizer(sample_sequence, return_tensors="pt")
1109
+ print(f"Input sequence length: {len(sample_sequence)}")
1110
+ print(f"Tokenized sequence: {encoding['input_ids'].shape}")
1111
+
1112
+ # Prepare inputs
1113
+ input_ids = encoding['input_ids'].to(device)
1114
+ attention_mask = encoding['attention_mask'].to(device)
1115
+
1116
+ # Test base model with smaller config for quick testing
1117
+ print("\n=== Testing ESMplusplus Base Model ===")
1118
+ base_config = ESMplusplusConfig(
1119
+ hidden_size=384,
1120
+ num_attention_heads=6,
1121
+ num_hidden_layers=4
1122
+ )
1123
+ base_model = ESMplusplusModel(base_config).to(device)
1124
+
1125
+ with torch.no_grad():
1126
+ outputs = base_model(input_ids=input_ids, attention_mask=attention_mask)
1127
+
1128
+ print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1129
+
1130
+ # Test embedding functionality
1131
+ print("\nTesting embedding functionality:")
1132
+ with torch.no_grad():
1133
+ embeddings = base_model._embed(input_ids, attention_mask)
1134
+ print(f"Embedding shape: {embeddings.shape}")
1135
+
1136
+ # Test masked language modeling
1137
+ print("\n=== Testing ESMplusplus For Masked LM ===")
1138
+ mlm_model = ESMplusplusForMaskedLM(base_config).to(device)
1139
+
1140
+ with torch.no_grad():
1141
+ outputs = mlm_model(input_ids=input_ids, attention_mask=attention_mask)
1142
+
1143
+ print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1144
+ print(f"Logits shape: {outputs.logits.shape}")
1145
+
1146
+ # Test sequence classification model
1147
+ print("\n=== Testing Sequence Classification Model ===")
1148
+ classification_model = ESMplusplusForSequenceClassification(base_config).to(device)
1149
+
1150
+ with torch.no_grad():
1151
+ outputs = classification_model(input_ids=input_ids, attention_mask=attention_mask)
1152
+
1153
+ print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1154
+ print(f"Logits shape: {outputs.logits.shape}")
1155
+
1156
+ # Test token classification model
1157
+ print("\n=== Testing Token Classification Model ===")
1158
+ token_model = ESMplusplusForTokenClassification(base_config).to(device)
1159
+
1160
+ with torch.no_grad():
1161
+ outputs = token_model(input_ids=input_ids, attention_mask=attention_mask)
1162
+
1163
+ print(f"Last hidden state shape: {outputs.last_hidden_state.shape}")
1164
+ print(f"Logits shape: {outputs.logits.shape}")
1165
+
1166
+ # Test embedding dataset functionality with a mini dataset
1167
+ print("\n=== Testing Embed Dataset Functionality ===")
1168
+ mini_dataset = [sample_sequence, sample_sequence[:50], sample_sequence[:30]]
1169
+ print(f"Creating embeddings for {len(mini_dataset)} sequences")
1170
+
1171
+ # Only run this if save path doesn't exist to avoid overwriting
1172
+ if not os.path.exists("test_embeddings.pth"):
1173
+ embeddings = mlm_model.embed_dataset(
1174
+ sequences=mini_dataset,
1175
+ tokenizer=tokenizer,
1176
+ batch_size=2,
1177
+ max_len=100,
1178
+ full_embeddings=False,
1179
+ pooling_types=['mean'],
1180
+ save_path="test_embeddings.pth"
1181
+ )
1182
+ if embeddings:
1183
+ print(f"Embedding dictionary size: {len(embeddings)}")
1184
+ for seq, emb in embeddings.items():
1185
+ print(f"Sequence length: {len(seq)}, Embedding shape: {emb.shape}")
1186
+ break
1187
+ else:
1188
+ print("Skipping embedding test as test_embeddings.pth already exists")
1189
+
1190
+ print("\nAll tests completed successfully!")
1191
+