nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import EsmTokenizer, EsmConfig
from transformers.utils import ModelOutput
try:
from model_components.transformer import TransformerForMaskedLM, TransformerConfig
except:
try:
from protify.model_components.transformer import TransformerForMaskedLM, TransformerConfig
except:
from ..model_components.transformer import TransformerForMaskedLM, TransformerConfig
# Maps each user-facing preset name to the Hugging Face checkpoint whose config
# supplies the architecture sizes. 'Random' is a sentinel value: build_random_model
# handles that preset separately and does not look its value up as a checkpoint.
presets = dict(
    [
        ('Random', 'random'),
        ('Random-Transformer', 'facebook/esm2_t12_35M_UR50D'),  # default is the 35M ESM2 config
        ('Random-ESM2-8', 'facebook/esm2_t6_8M_UR50D'),
        ('Random-ESM2-35', 'facebook/esm2_t12_35M_UR50D'),
        ('Random-ESM2-150', 'facebook/esm2_t30_150M_UR50D'),
        ('Random-ESM2-650', 'facebook/esm2_t36_650M_UR50D'),
    ]
)
@dataclass
class RandomModelOutput(ModelOutput):
    """Output container returned by the random baseline models in this module.

    Both fields default to ``None``.
    """

    # Per-token hidden states produced by the model.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Per-token vocabulary logits.
    logits: Optional[torch.FloatTensor] = None
class RandomModel(nn.Module):
    """Baseline that ignores token content and emits Gaussian noise as hidden states.

    Only the shape of ``input_ids`` is used; the token values and
    ``attention_mask`` never influence the output, and every call draws
    fresh noise from the global torch RNG.
    """

    def __init__(self, config: EsmConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        # Dummy parameter: it carries no information, but it moves with the
        # module on .to()/.half() etc., so forward can read the module's
        # current device and dtype from it.
        self.holder_param = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
        # Linear projection from the noise to vocabulary-sized logits.
        self.lm_head = nn.Linear(self.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_logits: bool = False,
    ):
        """Produce random hidden states shaped like ``input_ids`` plus a hidden dim.

        Args:
            input_ids: 2-D tensor; only its ``(batch, seq_len)`` shape is read.
            attention_mask: accepted for interface compatibility; unused.
            return_logits: when true, also project the noise through ``lm_head``.

        Returns:
            The ``(batch, seq_len, hidden_size)`` noise tensor, or a
            ``RandomModelOutput`` holding it and its logits when
            ``return_logits`` is true.
        """
        anchor = self.holder_param
        batch_size, seq_len = input_ids.shape
        hidden = torch.randn(
            batch_size, seq_len, self.hidden_size, device=anchor.device, dtype=anchor.dtype
        )
        if not return_logits:
            return hidden
        return RandomModelOutput(last_hidden_state=hidden, logits=self.lm_head(hidden))
class RandomTransformer(nn.Module):
    """Untrained ``TransformerForMaskedLM`` exposed as an encoder baseline.

    The wrapped model is constructed from ``config`` alone; no pretrained
    weights are loaded here. ``forward`` returns the final hidden states
    rather than the full model output.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.transformer = TransformerForMaskedLM(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
        """Run the wrapped transformer.

        Args:
            input_ids: token ids passed straight to the wrapped model.
            attention_mask: optional mask passed straight to the wrapped model.
            output_attentions: when truthy, also return the wrapped model's
                ``attentions`` output.

        Returns:
            ``last_hidden_state`` alone when ``output_attentions`` is falsy;
            otherwise a ``(last_hidden_state, attentions)`` tuple, where
            ``attentions`` is whatever the wrapped model reports.
        """
        if not output_attentions:
            return self.transformer(input_ids, attention_mask).last_hidden_state
        out = self.transformer(input_ids, attention_mask, output_attentions=output_attentions)
        return out.last_hidden_state, out.attentions
class RandomTransformerForMaskedLM(nn.Module):
    """Random-initialized transformer that returns logits for ProteinGym scoring."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        # Constructed from the config alone; no pretrained weights are loaded.
        self.transformer = TransformerForMaskedLM(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> RandomModelOutput:
        """Return both hidden states and vocabulary logits from the wrapped model.

        ``return_preds=False`` is forwarded to the wrapped model; its exact
        effect is defined by ``TransformerForMaskedLM`` (presumably it skips
        computing predictions — TODO confirm).
        """
        result = self.transformer(input_ids, attention_mask, return_preds=False)
        hidden, token_logits = result.last_hidden_state, result.logits
        return RandomModelOutput(last_hidden_state=hidden, logits=token_logits)
def _build_random_transformer_config(preset: str) -> TransformerConfig:
    """Create a ``TransformerConfig`` sized like the ESM2 checkpoint named by *preset*.

    Only the checkpoint's config is loaded (via ``EsmConfig.from_pretrained``).
    Hidden size, head count, layer count and vocabulary size are copied over,
    and ``attn_implementation`` is set to ``'sdpa'``. All other
    ``TransformerConfig`` fields keep their defaults.

    Raises:
        KeyError: if *preset* is not a key of ``presets``.
    """
    reference = EsmConfig.from_pretrained(presets[preset])
    config = TransformerConfig()
    # (TransformerConfig attribute, EsmConfig attribute) pairs, copied in order.
    for target_name, source_name in (
        ('hidden_size', 'hidden_size'),
        ('n_heads', 'num_attention_heads'),
        ('n_layers', 'num_hidden_layers'),
        ('vocab_size', 'vocab_size'),
    ):
        setattr(config, target_name, getattr(reference, source_name))
    config.attn_implementation = 'sdpa'
    return config
def build_random_model(preset: str, masked_lm: bool = False, model_path: Optional[str] = None, **kwargs):
    """Build an untrained baseline model together with an ESM2 tokenizer.

    Args:
        preset: key of ``presets``. ``'Random'`` yields a ``RandomModel``
            (noise hidden states). Any other preset yields a transformer built
            from a config sized like the named ESM2 checkpoint.
        masked_lm: for transformer presets, build ``RandomTransformerForMaskedLM``
            (hidden states + logits) instead of ``RandomTransformer``. Ignored
            for ``'Random'``, whose logits are requested per call through
            ``RandomModel.forward(return_logits=True)``.
        model_path: unused here; presumably accepted so this builder shares a
            signature with other model builders — confirm against callers.
        **kwargs: unused here, for the same reason.

    Returns:
        ``(model, tokenizer)``, with the model in eval mode for every preset.

    Raises:
        KeyError: if *preset* is not a key of ``presets``.
    """
    if preset not in presets:
        # Fail before fetching the tokenizer, and list the valid choices.
        raise KeyError(f"Unknown random-model preset {preset!r}; expected one of {sorted(presets)}")
    # One ESM2 tokenizer is used for every preset (assumes they all share the
    # ESM2 vocabulary).
    tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t12_35M_UR50D')
    if preset == 'Random':
        model = RandomModel(EsmConfig.from_pretrained('facebook/esm2_t12_35M_UR50D'))
    else:
        config = _build_random_transformer_config(preset)
        model_cls = RandomTransformerForMaskedLM if masked_lm else RandomTransformer
        model = model_cls(config)
    # Every variant is returned in eval mode so callers get consistent behavior.
    return model.eval(), tokenizer
if __name__ == '__main__':
    # Manual smoke check: build the default random transformer and show what was created.
    model, tokenizer = build_random_model('Random-Transformer')
    for built in (model, tokenizer):
        print(built)