| | import torch |
| | import torch.nn as nn |
| | from typing import Optional |
| | from transformers import EsmTokenizer, EsmConfig |
| | from transformers.utils import ModelOutput |
| | from dataclasses import dataclass |
| |
|
# The transformer building blocks are imported from whichever module path
# resolves first. Only ImportError (which includes ModuleNotFoundError) triggers
# the next fallback, so errors raised *inside* the module, as well as
# KeyboardInterrupt/SystemExit, propagate instead of being silently masked.
try:
    # Absolute import of the top-level `model_components` package.
    from model_components.transformer import TransformerForMaskedLM, TransformerConfig
except ImportError:
    try:
        # Same module addressed through the `protify` package.
        from protify.model_components.transformer import TransformerForMaskedLM, TransformerConfig
    except ImportError:
        # Last resort: resolve relative to the parent package of this module.
        from ..model_components.transformer import TransformerForMaskedLM, TransformerConfig
| |
|
| |
|
# Preset name -> Hugging Face checkpoint identifier. 'Random' maps to the
# placeholder string 'random' rather than a checkpoint; 'Random-Transformer'
# and 'Random-ESM2-35' point at the same checkpoint.
presets = {
    "Random": "random",
    "Random-Transformer": "facebook/esm2_t12_35M_UR50D",
    "Random-ESM2-8": "facebook/esm2_t6_8M_UR50D",
    "Random-ESM2-35": "facebook/esm2_t12_35M_UR50D",
    "Random-ESM2-150": "facebook/esm2_t30_150M_UR50D",
    "Random-ESM2-650": "facebook/esm2_t36_650M_UR50D",
}
| |
|
| |
|
@dataclass
class RandomModelOutput(ModelOutput):
    """Output container pairing hidden states with language-model logits.

    Both fields default to ``None`` and are therefore annotated as optional.
    """

    # Hidden states produced by the model.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Logits produced by the model's language-model head.
    logits: Optional[torch.FloatTensor] = None
| |
|
| |
|
class RandomModel(nn.Module):
    """Baseline that returns Gaussian noise instead of computed embeddings.

    The input tokens are only inspected for their shape; each forward pass
    draws a fresh ``(batch, seq_len, hidden_size)`` tensor with ``torch.randn``.

    Args:
        config: Configuration supplying ``hidden_size`` and ``vocab_size``.
    """

    def __init__(self, config: EsmConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Never used in a computation: its device and dtype decide where and
        # in which precision the random hidden states are created.
        self.holder_param = nn.Parameter(torch.randn(1, 1, self.hidden_size))
        self.lm_head = nn.Linear(self.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_logits: bool = False,
    ):
        """Draw random hidden states matching the shape of ``input_ids``.

        Args:
            input_ids: Two-dimensional tensor; only its shape is read.
            attention_mask: Accepted but not used.
            return_logits: When true, also project the hidden states through
                ``lm_head`` and wrap both tensors in a ``RandomModelOutput``.

        Returns:
            The random hidden-state tensor, or a ``RandomModelOutput`` holding
            the hidden states and logits when ``return_logits`` is true.
        """
        anchor = self.holder_param
        batch_size, seq_len = input_ids.shape
        hidden = torch.randn(
            batch_size,
            seq_len,
            self.hidden_size,
            device=anchor.device,
            dtype=anchor.dtype,
        )
        if not return_logits:
            return hidden
        return RandomModelOutput(last_hidden_state=hidden, logits=self.lm_head(hidden))
| |
|
| |
|
class RandomTransformer(nn.Module):
    """Wrapper that exposes the hidden states of a ``TransformerForMaskedLM``.

    Args:
        config: Configuration handed to ``TransformerForMaskedLM``.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.transformer = TransformerForMaskedLM(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> torch.Tensor:
        """Run the wrapped transformer and return its last hidden state.

        Args:
            input_ids: Token ids forwarded to the wrapped transformer.
            attention_mask: Optional mask forwarded to the wrapped transformer.
            output_attentions: When truthy, the flag is forwarded and the
                attentions are returned alongside the hidden states.

        Returns:
            ``last_hidden_state`` alone, or the tuple
            ``(last_hidden_state, attentions)`` when ``output_attentions`` is
            truthy.
        """
        if not output_attentions:
            return self.transformer(input_ids, attention_mask).last_hidden_state
        result = self.transformer(input_ids, attention_mask, output_attentions=output_attentions)
        return result.last_hidden_state, result.attentions
| |
|
| |
|
class RandomTransformerForMaskedLM(nn.Module):
    """Random-initialized transformer that returns logits for ProteinGym scoring.

    Args:
        config: Configuration handed to ``TransformerForMaskedLM``.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.transformer = TransformerForMaskedLM(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> RandomModelOutput:
        """Run the wrapped transformer and repackage its output.

        The wrapped model is called with ``return_preds=False``.

        Args:
            input_ids: Token ids forwarded to the wrapped transformer.
            attention_mask: Optional mask forwarded to the wrapped transformer.

        Returns:
            A ``RandomModelOutput`` carrying the wrapped model's
            ``last_hidden_state`` and ``logits``.
        """
        result = self.transformer(input_ids, attention_mask, return_preds=False)
        return RandomModelOutput(
            last_hidden_state=result.last_hidden_state,
            logits=result.logits,
        )
| |
|
| |
|
def _build_random_transformer_config(preset: str) -> TransformerConfig:
    """Create a ``TransformerConfig`` sized like the preset's ESM checkpoint.

    Args:
        preset: Key of ``presets``; its value is passed to
            ``EsmConfig.from_pretrained``.

    Returns:
        A ``TransformerConfig`` whose hidden size, head count, layer count and
        vocabulary size are copied from the ESM configuration, with
        ``attn_implementation`` set to ``'sdpa'``.

    Raises:
        KeyError: If ``preset`` is not a key of ``presets``.
    """
    reference = EsmConfig.from_pretrained(presets[preset])
    transformer_config = TransformerConfig()
    # (attribute on the new config, attribute on the ESM config)
    copied_fields = (
        ('hidden_size', 'hidden_size'),
        ('n_heads', 'num_attention_heads'),
        ('n_layers', 'num_hidden_layers'),
        ('vocab_size', 'vocab_size'),
    )
    for target_name, source_name in copied_fields:
        setattr(transformer_config, target_name, getattr(reference, source_name))
    transformer_config.attn_implementation = 'sdpa'
    return transformer_config
| |
|
| |
|
def build_random_model(preset: str, masked_lm: bool = False, model_path: str = None, **kwargs):
    """Build a random baseline model together with an ESM tokenizer.

    Args:
        preset: ``'Random'`` selects ``RandomModel``; any other value is
            passed to ``_build_random_transformer_config``.
        masked_lm: For non-``'Random'`` presets, choose
            ``RandomTransformerForMaskedLM`` instead of ``RandomTransformer``.
        model_path: Accepted but not used.
        **kwargs: Accepted but not used.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer is always loaded from
        ``'facebook/esm2_t12_35M_UR50D'``; the transformer variants are
        returned in eval mode.
    """
    tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t12_35M_UR50D')
    if preset == 'Random':
        noise_config = EsmConfig.from_pretrained('facebook/esm2_t12_35M_UR50D')
        return RandomModel(noise_config), tokenizer
    model_cls = RandomTransformerForMaskedLM if masked_lm else RandomTransformer
    transformer_config = _build_random_transformer_config(preset)
    return model_cls(transformer_config).eval(), tokenizer
| |
|
| |
|
if __name__ == '__main__':
    # Smoke run: build the 'Random-Transformer' preset and print both objects.
    built_model, built_tokenizer = build_random_model('Random-Transformer')
    for built in (built_model, built_tokenizer):
        print(built)