lhallee committed on
Commit
807fd23
·
verified ·
1 Parent(s): 4f5e2b4

Upload modeling_dplm2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_dplm2.py +102 -11
modeling_dplm2.py CHANGED
@@ -156,6 +156,26 @@ def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]],
156
  return _collate_fn
157
 
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  class EmbeddingMixin:
160
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
161
  raise NotImplementedError
@@ -243,7 +263,7 @@ class EmbeddingMixin:
243
 
244
  def embed_dataset(
245
  self,
246
- sequences: List[str],
247
  tokenizer: Optional[PreTrainedTokenizerBase] = None,
248
  batch_size: int = 2,
249
  max_len: int = 512,
@@ -256,6 +276,7 @@ class EmbeddingMixin:
256
  save: bool = True,
257
  sql_db_path: str = 'embeddings.db',
258
  save_path: str = 'embeddings.pth',
 
259
  **kwargs,
260
  ) -> Optional[dict[str, torch.Tensor]]:
261
  """
@@ -264,7 +285,15 @@ class EmbeddingMixin:
264
  Supports two modes:
265
  - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
266
  - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
 
 
 
267
  """
 
 
 
 
 
268
  sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
269
  sequences = sorted(sequences, key=len, reverse=True)
270
  hidden_size = self.config.hidden_size
@@ -668,8 +697,9 @@ def get_attention_mask(
668
  flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
669
  return attention_mask_2d, None, flex_block_mask
670
 
671
- # SDPA / manual
672
- attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
 
673
  return attention_mask_2d, attention_mask_4d, None
674
 
675
 
@@ -680,6 +710,55 @@ def _infer_modality_type(input_ids: torch.Tensor, attention_mask: torch.Tensor)
680
  return modality_type
681
 
682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
  @dataclass
684
  class DPLM2MaskedLMOutput(ModelOutput):
685
  loss: Optional[torch.Tensor] = None
@@ -747,17 +826,22 @@ class DPLM2PreTrainedModel(EsmPreTrainedModel):
747
 
748
 
749
  class ModifiedRotaryEmbedding(RotaryEmbedding):
750
- def __init__(self, dim: int, aa_type: int, struct_type: int):
751
  super().__init__(dim)
752
  self.aa_type = aa_type
753
  self.struct_type = struct_type
 
754
 
755
  def _has_multimodal_tokens(self, type_ids: Optional[torch.Tensor]) -> bool:
756
- if type_ids is None:
757
- return False
758
- aa_present = (type_ids == self.aa_type).any()
759
- struct_present = (type_ids == self.struct_type).any()
760
- return bool(aa_present and struct_present)
 
 
 
 
761
 
762
  def _update_cos_sin_tables(
763
  self,
@@ -774,12 +858,13 @@ class ModifiedRotaryEmbedding(RotaryEmbedding):
774
  or self._sin_cached is None
775
  or seq_len != self._seq_len_cached
776
  or self._cos_cached.device != x.device
 
777
  )
778
  if cache_is_stale:
779
  self._seq_len_cached = seq_len
780
  t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
781
  freqs = torch.outer(t, self.inv_freq)
782
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
783
  self._cos_cached = emb.cos()[None, None, :, :]
784
  self._sin_cached = emb.sin()[None, None, :, :]
785
 
@@ -823,6 +908,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
823
  dim=self.attention_head_size,
824
  aa_type=config.aa_type,
825
  struct_type=config.struct_type,
 
826
  )
827
 
828
  def forward(
@@ -1112,6 +1198,7 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
1112
  self.embeddings.word_embeddings = value
1113
 
1114
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
 
1115
  if attention_mask is None:
1116
  attention_mask = input_ids.ne(self.config.pad_token_id)
1117
  type_ids = _infer_modality_type(input_ids, attention_mask)
@@ -1143,7 +1230,7 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
1143
  if input_ids is not None and inputs_embeds is not None:
1144
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1145
  elif input_ids is not None:
1146
- pass
1147
  elif inputs_embeds is None:
1148
  raise ValueError("You have to specify either input_ids or inputs_embeds")
1149
 
@@ -1248,6 +1335,7 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
1248
  self.lm_head.decoder = new_embeddings
1249
 
1250
  def _get_modality_type(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
 
1251
  return _infer_modality_type(input_ids, attention_mask)
1252
 
1253
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -1299,6 +1387,7 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
1299
  logits = self.lm_head(sequence_output)
1300
  loss = None
1301
  if labels is not None:
 
1302
  labels = labels.to(logits.device)
1303
  loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1304
 
@@ -1353,6 +1442,7 @@ class DPLM2ForSequenceClassification(DPLM2PreTrainedModel, EmbeddingMixin):
1353
  if type_ids is None and input_ids is not None:
1354
  if attention_mask is None:
1355
  attention_mask = input_ids.ne(self.config.pad_token_id)
 
1356
  type_ids = _infer_modality_type(input_ids, attention_mask)
1357
 
1358
  outputs = self.esm(
@@ -1432,6 +1522,7 @@ class DPLM2ForTokenClassification(DPLM2PreTrainedModel, EmbeddingMixin):
1432
  if type_ids is None and input_ids is not None:
1433
  if attention_mask is None:
1434
  attention_mask = input_ids.ne(self.config.pad_token_id)
 
1435
  type_ids = _infer_modality_type(input_ids, attention_mask)
1436
 
1437
  outputs = self.esm(
 
156
  return _collate_fn
157
 
158
 
159
def parse_fasta(fasta_path: str) -> List[str]:
    """Read a FASTA file and return its sequences (headers are discarded).

    Multi-line records are concatenated into a single string per record and
    blank lines are skipped.

    Args:
        fasta_path: Path to an existing FASTA file.

    Returns:
        The sequences in file order, one string per record.

    Raises:
        FileNotFoundError: If `fasta_path` does not exist.
    """
    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would turn a missing file into an opaque open() failure downstream.
    if not os.path.exists(fasta_path):
        raise FileNotFoundError(f"FASTA file does not exist: {fasta_path}")
    sequences = []
    current_seq = []
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # A new header closes the previous record, if any.
                if current_seq:
                    sequences.append(''.join(current_seq))
                    current_seq = []
            else:
                current_seq.append(line)
    # Flush the last record — files do not end with a header line.
    if current_seq:
        sequences.append(''.join(current_seq))
    return sequences
177
+
178
+
179
  class EmbeddingMixin:
180
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
181
  raise NotImplementedError
 
263
 
264
  def embed_dataset(
265
  self,
266
+ sequences: Optional[List[str]] = None,
267
  tokenizer: Optional[PreTrainedTokenizerBase] = None,
268
  batch_size: int = 2,
269
  max_len: int = 512,
 
276
  save: bool = True,
277
  sql_db_path: str = 'embeddings.db',
278
  save_path: str = 'embeddings.pth',
279
+ fasta_path: Optional[str] = None,
280
  **kwargs,
281
  ) -> Optional[dict[str, torch.Tensor]]:
282
  """
 
285
  Supports two modes:
286
  - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
287
  - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
288
+
289
+ Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
290
+ `fasta_path`, or both (the two sources are combined). At least one must be provided.
291
  """
292
+ if fasta_path is not None:
293
+ fasta_sequences = parse_fasta(fasta_path)
294
+ sequences = list(sequences or []) + fasta_sequences
295
+ assert sequences is not None and len(sequences) > 0, \
296
+ "Must provide at least one sequence via `sequences` or `fasta_path`."
297
  sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
298
  sequences = sorted(sequences, key=len, reverse=True)
299
  hidden_size = self.config.hidden_size
 
697
  flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
698
  return attention_mask_2d, None, flex_block_mask
699
 
700
+ # SDPA / manual — only mask the key dimension so padding query positions attend to
701
+ # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
702
+ attention_mask_4d = attention_mask_2d[:, None, None, :]
703
  return attention_mask_2d, attention_mask_4d, None
704
 
705
 
 
710
  return modality_type
711
 
712
 
713
+ def _normalize_dplm2_input_ids(input_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
714
+ if input_ids.numel() == 0:
715
+ return input_ids
716
+
717
+ normalized_input_ids = input_ids.clone()
718
+ generic_to_aa_special_ids = {
719
+ vocab_size: 2,
720
+ vocab_size + 1: 3,
721
+ vocab_size + 2: 0,
722
+ vocab_size + 3: 32,
723
+ }
724
+ for generic_id, aa_id in generic_to_aa_special_ids.items():
725
+ normalized_input_ids[input_ids == generic_id] = aa_id
726
+
727
+ valid_token_mask = normalized_input_ids.ge(0)
728
+ if valid_token_mask.any():
729
+ max_token_id = int(normalized_input_ids[valid_token_mask].max().item())
730
+ assert max_token_id < vocab_size, (
731
+ f"Found token id {max_token_id} outside the DPLM2 embedding table (vocab_size={vocab_size}). "
732
+ "Tokenizer special tokens must be normalized before embedding."
733
+ )
734
+ return normalized_input_ids
735
+
736
+
737
+ def _has_packed_multimodal_layout(
738
+ type_ids: Optional[torch.Tensor],
739
+ aa_type: int,
740
+ struct_type: int,
741
+ pad_type: int,
742
+ ) -> bool:
743
+ if type_ids is None:
744
+ return False
745
+ assert type_ids.ndim == 2, f"Expected type_ids to have shape (batch, seq_len), got {tuple(type_ids.shape)}"
746
+ seq_len = type_ids.shape[-1]
747
+ if seq_len % 2 != 0:
748
+ return False
749
+
750
+ half_len = seq_len // 2
751
+ first_half = type_ids[:, :half_len]
752
+ second_half = type_ids[:, half_len:]
753
+
754
+ first_half_valid = ((first_half == aa_type) | (first_half == pad_type)).all(dim=-1)
755
+ second_half_valid = ((second_half == struct_type) | (second_half == pad_type)).all(dim=-1)
756
+ aa_count = (first_half == aa_type).sum(dim=-1)
757
+ struct_count = (second_half == struct_type).sum(dim=-1)
758
+ packed_rows = first_half_valid & second_half_valid & aa_count.gt(0) & aa_count.eq(struct_count)
759
+ return bool(packed_rows.all())
760
+
761
+
762
  @dataclass
763
  class DPLM2MaskedLMOutput(ModelOutput):
764
  loss: Optional[torch.Tensor] = None
 
826
 
827
 
828
  class ModifiedRotaryEmbedding(RotaryEmbedding):
829
+ def __init__(self, dim: int, aa_type: int, struct_type: int, pad_type: int):
830
  super().__init__(dim)
831
  self.aa_type = aa_type
832
  self.struct_type = struct_type
833
+ self.pad_type = pad_type
834
 
835
  def _has_multimodal_tokens(self, type_ids: Optional[torch.Tensor]) -> bool:
836
+ # The split rotary path only works when the sequence tensor is already packed
837
+ # as [AA half | structure half]. Plain protein batches can still contain
838
+ # high-ID special tokens, so mere modality presence is not enough.
839
+ return _has_packed_multimodal_layout(
840
+ type_ids=type_ids,
841
+ aa_type=self.aa_type,
842
+ struct_type=self.struct_type,
843
+ pad_type=self.pad_type,
844
+ )
845
 
846
  def _update_cos_sin_tables(
847
  self,
 
858
  or self._sin_cached is None
859
  or seq_len != self._seq_len_cached
860
  or self._cos_cached.device != x.device
861
+ or self._cos_cached.dtype != x.dtype
862
  )
863
  if cache_is_stale:
864
  self._seq_len_cached = seq_len
865
  t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
866
  freqs = torch.outer(t, self.inv_freq)
867
+ emb = torch.cat((freqs, freqs), dim=-1).to(device=x.device, dtype=x.dtype)
868
  self._cos_cached = emb.cos()[None, None, :, :]
869
  self._sin_cached = emb.sin()[None, None, :, :]
870
 
 
908
  dim=self.attention_head_size,
909
  aa_type=config.aa_type,
910
  struct_type=config.struct_type,
911
+ pad_type=config.pad_type,
912
  )
913
 
914
  def forward(
 
1198
  self.embeddings.word_embeddings = value
1199
 
1200
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1201
+ input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
1202
  if attention_mask is None:
1203
  attention_mask = input_ids.ne(self.config.pad_token_id)
1204
  type_ids = _infer_modality_type(input_ids, attention_mask)
 
1230
  if input_ids is not None and inputs_embeds is not None:
1231
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1232
  elif input_ids is not None:
1233
+ input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
1234
  elif inputs_embeds is None:
1235
  raise ValueError("You have to specify either input_ids or inputs_embeds")
1236
 
 
1335
  self.lm_head.decoder = new_embeddings
1336
 
1337
  def _get_modality_type(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
1338
+ input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
1339
  return _infer_modality_type(input_ids, attention_mask)
1340
 
1341
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
 
1387
  logits = self.lm_head(sequence_output)
1388
  loss = None
1389
  if labels is not None:
1390
+ labels = _normalize_dplm2_input_ids(labels, self.config.vocab_size)
1391
  labels = labels.to(logits.device)
1392
  loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1393
 
 
1442
  if type_ids is None and input_ids is not None:
1443
  if attention_mask is None:
1444
  attention_mask = input_ids.ne(self.config.pad_token_id)
1445
+ input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
1446
  type_ids = _infer_modality_type(input_ids, attention_mask)
1447
 
1448
  outputs = self.esm(
 
1522
  if type_ids is None and input_ids is not None:
1523
  if attention_mask is None:
1524
  attention_mask = input_ids.ne(self.config.pad_token_id)
1525
+ input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
1526
  type_ids = _infer_modality_type(input_ids, attention_mask)
1527
 
1528
  outputs = self.esm(