Upload modeling_dplm2.py with huggingface_hub
Browse files- modeling_dplm2.py +102 -11
modeling_dplm2.py
CHANGED
|
@@ -156,6 +156,26 @@ def build_collator(tokenizer: PreTrainedTokenizerBase) -> Callable[[list[str]],
|
|
| 156 |
return _collate_fn
|
| 157 |
|
| 158 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
class EmbeddingMixin:
|
| 160 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 161 |
raise NotImplementedError
|
|
@@ -243,7 +263,7 @@ class EmbeddingMixin:
|
|
| 243 |
|
| 244 |
def embed_dataset(
|
| 245 |
self,
|
| 246 |
-
sequences: List[str],
|
| 247 |
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 248 |
batch_size: int = 2,
|
| 249 |
max_len: int = 512,
|
|
@@ -256,6 +276,7 @@ class EmbeddingMixin:
|
|
| 256 |
save: bool = True,
|
| 257 |
sql_db_path: str = 'embeddings.db',
|
| 258 |
save_path: str = 'embeddings.pth',
|
|
|
|
| 259 |
**kwargs,
|
| 260 |
) -> Optional[dict[str, torch.Tensor]]:
|
| 261 |
"""
|
|
@@ -264,7 +285,15 @@ class EmbeddingMixin:
|
|
| 264 |
Supports two modes:
|
| 265 |
- Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
|
| 266 |
- Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
|
|
|
|
|
|
|
|
|
|
| 267 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
|
| 269 |
sequences = sorted(sequences, key=len, reverse=True)
|
| 270 |
hidden_size = self.config.hidden_size
|
|
@@ -668,8 +697,9 @@ def get_attention_mask(
|
|
| 668 |
flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
|
| 669 |
return attention_mask_2d, None, flex_block_mask
|
| 670 |
|
| 671 |
-
# SDPA / manual
|
| 672 |
-
|
|
|
|
| 673 |
return attention_mask_2d, attention_mask_4d, None
|
| 674 |
|
| 675 |
|
|
@@ -680,6 +710,55 @@ def _infer_modality_type(input_ids: torch.Tensor, attention_mask: torch.Tensor)
|
|
| 680 |
return modality_type
|
| 681 |
|
| 682 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
@dataclass
|
| 684 |
class DPLM2MaskedLMOutput(ModelOutput):
|
| 685 |
loss: Optional[torch.Tensor] = None
|
|
@@ -747,17 +826,22 @@ class DPLM2PreTrainedModel(EsmPreTrainedModel):
|
|
| 747 |
|
| 748 |
|
| 749 |
class ModifiedRotaryEmbedding(RotaryEmbedding):
|
| 750 |
-
def __init__(self, dim: int, aa_type: int, struct_type: int):
|
| 751 |
super().__init__(dim)
|
| 752 |
self.aa_type = aa_type
|
| 753 |
self.struct_type = struct_type
|
|
|
|
| 754 |
|
| 755 |
def _has_multimodal_tokens(self, type_ids: Optional[torch.Tensor]) -> bool:
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 761 |
|
| 762 |
def _update_cos_sin_tables(
|
| 763 |
self,
|
|
@@ -774,12 +858,13 @@ class ModifiedRotaryEmbedding(RotaryEmbedding):
|
|
| 774 |
or self._sin_cached is None
|
| 775 |
or seq_len != self._seq_len_cached
|
| 776 |
or self._cos_cached.device != x.device
|
|
|
|
| 777 |
)
|
| 778 |
if cache_is_stale:
|
| 779 |
self._seq_len_cached = seq_len
|
| 780 |
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 781 |
freqs = torch.outer(t, self.inv_freq)
|
| 782 |
-
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
| 783 |
self._cos_cached = emb.cos()[None, None, :, :]
|
| 784 |
self._sin_cached = emb.sin()[None, None, :, :]
|
| 785 |
|
|
@@ -823,6 +908,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 823 |
dim=self.attention_head_size,
|
| 824 |
aa_type=config.aa_type,
|
| 825 |
struct_type=config.struct_type,
|
|
|
|
| 826 |
)
|
| 827 |
|
| 828 |
def forward(
|
|
@@ -1112,6 +1198,7 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1112 |
self.embeddings.word_embeddings = value
|
| 1113 |
|
| 1114 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
| 1115 |
if attention_mask is None:
|
| 1116 |
attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 1117 |
type_ids = _infer_modality_type(input_ids, attention_mask)
|
|
@@ -1143,7 +1230,7 @@ class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1143 |
if input_ids is not None and inputs_embeds is not None:
|
| 1144 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 1145 |
elif input_ids is not None:
|
| 1146 |
-
|
| 1147 |
elif inputs_embeds is None:
|
| 1148 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 1149 |
|
|
@@ -1248,6 +1335,7 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1248 |
self.lm_head.decoder = new_embeddings
|
| 1249 |
|
| 1250 |
def _get_modality_type(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 1251 |
return _infer_modality_type(input_ids, attention_mask)
|
| 1252 |
|
| 1253 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
@@ -1299,6 +1387,7 @@ class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1299 |
logits = self.lm_head(sequence_output)
|
| 1300 |
loss = None
|
| 1301 |
if labels is not None:
|
|
|
|
| 1302 |
labels = labels.to(logits.device)
|
| 1303 |
loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1304 |
|
|
@@ -1353,6 +1442,7 @@ class DPLM2ForSequenceClassification(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1353 |
if type_ids is None and input_ids is not None:
|
| 1354 |
if attention_mask is None:
|
| 1355 |
attention_mask = input_ids.ne(self.config.pad_token_id)
|
|
|
|
| 1356 |
type_ids = _infer_modality_type(input_ids, attention_mask)
|
| 1357 |
|
| 1358 |
outputs = self.esm(
|
|
@@ -1432,6 +1522,7 @@ class DPLM2ForTokenClassification(DPLM2PreTrainedModel, EmbeddingMixin):
|
|
| 1432 |
if type_ids is None and input_ids is not None:
|
| 1433 |
if attention_mask is None:
|
| 1434 |
attention_mask = input_ids.ne(self.config.pad_token_id)
|
|
|
|
| 1435 |
type_ids = _infer_modality_type(input_ids, attention_mask)
|
| 1436 |
|
| 1437 |
outputs = self.esm(
|
|
|
|
| 156 |
return _collate_fn
|
| 157 |
|
| 158 |
|
| 159 |
+
def parse_fasta(fasta_path: str) -> List[str]:
    """Parse a FASTA file and return its sequences (headers are discarded).

    Blank lines are skipped and multi-line records are joined into a single
    string per record.

    Args:
        fasta_path: Path to an existing FASTA file.

    Returns:
        List of sequence strings, in file order.

    Raises:
        FileNotFoundError: If ``fasta_path`` does not exist.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently defer the failure to the `open()` call.
    if not os.path.exists(fasta_path):
        raise FileNotFoundError(f"FASTA file does not exist: {fasta_path}")
    sequences: List[str] = []
    current_seq: List[str] = []
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                # A new header terminates the previous record, if any.
                if current_seq:
                    sequences.append(''.join(current_seq))
                    current_seq = []
            else:
                current_seq.append(line)
    # Flush the final record (a FASTA file does not end with a header).
    if current_seq:
        sequences.append(''.join(current_seq))
    return sequences
|
| 177 |
+
|
| 178 |
+
|
| 179 |
class EmbeddingMixin:
|
| 180 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 181 |
raise NotImplementedError
|
|
|
|
| 263 |
|
| 264 |
def embed_dataset(
|
| 265 |
self,
|
| 266 |
+
sequences: Optional[List[str]] = None,
|
| 267 |
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
| 268 |
batch_size: int = 2,
|
| 269 |
max_len: int = 512,
|
|
|
|
| 276 |
save: bool = True,
|
| 277 |
sql_db_path: str = 'embeddings.db',
|
| 278 |
save_path: str = 'embeddings.pth',
|
| 279 |
+
fasta_path: Optional[str] = None,
|
| 280 |
**kwargs,
|
| 281 |
) -> Optional[dict[str, torch.Tensor]]:
|
| 282 |
"""
|
|
|
|
| 285 |
Supports two modes:
|
| 286 |
- Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used.
|
| 287 |
- Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used.
|
| 288 |
+
|
| 289 |
+
Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via
|
| 290 |
+
`fasta_path`, or both (the two sources are combined). At least one must be provided.
|
| 291 |
"""
|
| 292 |
+
if fasta_path is not None:
|
| 293 |
+
fasta_sequences = parse_fasta(fasta_path)
|
| 294 |
+
sequences = list(sequences or []) + fasta_sequences
|
| 295 |
+
assert sequences is not None and len(sequences) > 0, \
|
| 296 |
+
"Must provide at least one sequence via `sequences` or `fasta_path`."
|
| 297 |
sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences]))
|
| 298 |
sequences = sorted(sequences, key=len, reverse=True)
|
| 299 |
hidden_size = self.config.hidden_size
|
|
|
|
| 697 |
flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
|
| 698 |
return attention_mask_2d, None, flex_block_mask
|
| 699 |
|
| 700 |
+
# SDPA / manual — only mask the key dimension so padding query positions attend to
|
| 701 |
+
# real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
|
| 702 |
+
attention_mask_4d = attention_mask_2d[:, None, None, :]
|
| 703 |
return attention_mask_2d, attention_mask_4d, None
|
| 704 |
|
| 705 |
|
|
|
|
| 710 |
return modality_type
|
| 711 |
|
| 712 |
|
| 713 |
+
def _normalize_dplm2_input_ids(input_ids: torch.Tensor, vocab_size: int) -> torch.Tensor:
|
| 714 |
+
if input_ids.numel() == 0:
|
| 715 |
+
return input_ids
|
| 716 |
+
|
| 717 |
+
normalized_input_ids = input_ids.clone()
|
| 718 |
+
generic_to_aa_special_ids = {
|
| 719 |
+
vocab_size: 2,
|
| 720 |
+
vocab_size + 1: 3,
|
| 721 |
+
vocab_size + 2: 0,
|
| 722 |
+
vocab_size + 3: 32,
|
| 723 |
+
}
|
| 724 |
+
for generic_id, aa_id in generic_to_aa_special_ids.items():
|
| 725 |
+
normalized_input_ids[input_ids == generic_id] = aa_id
|
| 726 |
+
|
| 727 |
+
valid_token_mask = normalized_input_ids.ge(0)
|
| 728 |
+
if valid_token_mask.any():
|
| 729 |
+
max_token_id = int(normalized_input_ids[valid_token_mask].max().item())
|
| 730 |
+
assert max_token_id < vocab_size, (
|
| 731 |
+
f"Found token id {max_token_id} outside the DPLM2 embedding table (vocab_size={vocab_size}). "
|
| 732 |
+
"Tokenizer special tokens must be normalized before embedding."
|
| 733 |
+
)
|
| 734 |
+
return normalized_input_ids
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def _has_packed_multimodal_layout(
|
| 738 |
+
type_ids: Optional[torch.Tensor],
|
| 739 |
+
aa_type: int,
|
| 740 |
+
struct_type: int,
|
| 741 |
+
pad_type: int,
|
| 742 |
+
) -> bool:
|
| 743 |
+
if type_ids is None:
|
| 744 |
+
return False
|
| 745 |
+
assert type_ids.ndim == 2, f"Expected type_ids to have shape (batch, seq_len), got {tuple(type_ids.shape)}"
|
| 746 |
+
seq_len = type_ids.shape[-1]
|
| 747 |
+
if seq_len % 2 != 0:
|
| 748 |
+
return False
|
| 749 |
+
|
| 750 |
+
half_len = seq_len // 2
|
| 751 |
+
first_half = type_ids[:, :half_len]
|
| 752 |
+
second_half = type_ids[:, half_len:]
|
| 753 |
+
|
| 754 |
+
first_half_valid = ((first_half == aa_type) | (first_half == pad_type)).all(dim=-1)
|
| 755 |
+
second_half_valid = ((second_half == struct_type) | (second_half == pad_type)).all(dim=-1)
|
| 756 |
+
aa_count = (first_half == aa_type).sum(dim=-1)
|
| 757 |
+
struct_count = (second_half == struct_type).sum(dim=-1)
|
| 758 |
+
packed_rows = first_half_valid & second_half_valid & aa_count.gt(0) & aa_count.eq(struct_count)
|
| 759 |
+
return bool(packed_rows.all())
|
| 760 |
+
|
| 761 |
+
|
| 762 |
@dataclass
|
| 763 |
class DPLM2MaskedLMOutput(ModelOutput):
|
| 764 |
loss: Optional[torch.Tensor] = None
|
|
|
|
| 826 |
|
| 827 |
|
| 828 |
class ModifiedRotaryEmbedding(RotaryEmbedding):
|
| 829 |
+
def __init__(self, dim: int, aa_type: int, struct_type: int, pad_type: int):
|
| 830 |
super().__init__(dim)
|
| 831 |
self.aa_type = aa_type
|
| 832 |
self.struct_type = struct_type
|
| 833 |
+
self.pad_type = pad_type
|
| 834 |
|
| 835 |
def _has_multimodal_tokens(self, type_ids: Optional[torch.Tensor]) -> bool:
|
| 836 |
+
# The split rotary path only works when the sequence tensor is already packed
|
| 837 |
+
# as [AA half | structure half]. Plain protein batches can still contain
|
| 838 |
+
# high-ID special tokens, so mere modality presence is not enough.
|
| 839 |
+
return _has_packed_multimodal_layout(
|
| 840 |
+
type_ids=type_ids,
|
| 841 |
+
aa_type=self.aa_type,
|
| 842 |
+
struct_type=self.struct_type,
|
| 843 |
+
pad_type=self.pad_type,
|
| 844 |
+
)
|
| 845 |
|
| 846 |
def _update_cos_sin_tables(
|
| 847 |
self,
|
|
|
|
| 858 |
or self._sin_cached is None
|
| 859 |
or seq_len != self._seq_len_cached
|
| 860 |
or self._cos_cached.device != x.device
|
| 861 |
+
or self._cos_cached.dtype != x.dtype
|
| 862 |
)
|
| 863 |
if cache_is_stale:
|
| 864 |
self._seq_len_cached = seq_len
|
| 865 |
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 866 |
freqs = torch.outer(t, self.inv_freq)
|
| 867 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(device=x.device, dtype=x.dtype)
|
| 868 |
self._cos_cached = emb.cos()[None, None, :, :]
|
| 869 |
self._sin_cached = emb.sin()[None, None, :, :]
|
| 870 |
|
|
|
|
| 908 |
dim=self.attention_head_size,
|
| 909 |
aa_type=config.aa_type,
|
| 910 |
struct_type=config.struct_type,
|
| 911 |
+
pad_type=config.pad_type,
|
| 912 |
)
|
| 913 |
|
| 914 |
def forward(
|
|
|
|
| 1198 |
self.embeddings.word_embeddings = value
|
| 1199 |
|
| 1200 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 1201 |
+
input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
|
| 1202 |
if attention_mask is None:
|
| 1203 |
attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 1204 |
type_ids = _infer_modality_type(input_ids, attention_mask)
|
|
|
|
| 1230 |
if input_ids is not None and inputs_embeds is not None:
|
| 1231 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 1232 |
elif input_ids is not None:
|
| 1233 |
+
input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
|
| 1234 |
elif inputs_embeds is None:
|
| 1235 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 1236 |
|
|
|
|
| 1335 |
self.lm_head.decoder = new_embeddings
|
| 1336 |
|
| 1337 |
def _get_modality_type(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
|
| 1338 |
+
input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
|
| 1339 |
return _infer_modality_type(input_ids, attention_mask)
|
| 1340 |
|
| 1341 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
| 1387 |
logits = self.lm_head(sequence_output)
|
| 1388 |
loss = None
|
| 1389 |
if labels is not None:
|
| 1390 |
+
labels = _normalize_dplm2_input_ids(labels, self.config.vocab_size)
|
| 1391 |
labels = labels.to(logits.device)
|
| 1392 |
loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
|
| 1393 |
|
|
|
|
| 1442 |
if type_ids is None and input_ids is not None:
|
| 1443 |
if attention_mask is None:
|
| 1444 |
attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 1445 |
+
input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
|
| 1446 |
type_ids = _infer_modality_type(input_ids, attention_mask)
|
| 1447 |
|
| 1448 |
outputs = self.esm(
|
|
|
|
| 1522 |
if type_ids is None and input_ids is not None:
|
| 1523 |
if attention_mask is None:
|
| 1524 |
attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 1525 |
+
input_ids = _normalize_dplm2_input_ids(input_ids, self.config.vocab_size)
|
| 1526 |
type_ids = _infer_modality_type(input_ids, attention_mask)
|
| 1527 |
|
| 1528 |
outputs = self.esm(
|