import torch
import torch.nn as nn
from typing import Optional
from transformers import EsmTokenizer, EsmConfig
from transformers.utils import ModelOutput
from dataclasses import dataclass

try:
    from model_components.transformer import TransformerForMaskedLM, TransformerConfig
except ImportError:
    try:
        from protify.model_components.transformer import TransformerForMaskedLM, TransformerConfig
    except ImportError:
        from ..model_components.transformer import TransformerForMaskedLM, TransformerConfig


presets = {
    'Random': 'random',
    'Random-Transformer': 'facebook/esm2_t12_35M_UR50D', # default is 35M version
    'Random-ESM2-8': 'facebook/esm2_t6_8M_UR50D',
    'Random-ESM2-35': 'facebook/esm2_t12_35M_UR50D',
    'Random-ESM2-150': 'facebook/esm2_t30_150M_UR50D',
    'Random-ESM2-650': 'facebook/esm2_t36_650M_UR50D',
}


@dataclass
class RandomModelOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    logits: torch.FloatTensor = None


class RandomModel(nn.Module):
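    """Random-embedding baseline: ignores the input tokens and returns random hidden states
    sized by the ESM config, optionally projected to vocabulary logits."""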
    def __init__(self, config: EsmConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
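        # Dummy parameter that anchors the module's device and dtype for the random states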
        self.holder_param = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
        # Simple projection head to produce token logits
        self.lm_head = nn.Linear(self.hidden_size, config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_logits: bool = False,
    ):
        device = self.holder_param.device
        B, T = input_ids.shape
        last_hidden_state = torch.randn(B, T, self.hidden_size, device=device, dtype=self.holder_param.dtype)
        if return_logits:
            logits = self.lm_head(last_hidden_state)  # (B, T, vocab)
            return RandomModelOutput(last_hidden_state=last_hidden_state, logits=logits)
        else:
            return last_hidden_state


class RandomTransformer(nn.Module):
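    """Randomly initialized transformer encoder; returns the last hidden state,
    plus attention maps when output_attentions=True."""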
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.transformer = TransformerForMaskedLM(config)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False) -> torch.Tensor:
        if output_attentions:
            out = self.transformer(input_ids, attention_mask, output_attentions=output_attentions)
            return out.last_hidden_state, out.attentions
        else:
            return self.transformer(input_ids, attention_mask).last_hidden_state


class RandomTransformerForMaskedLM(nn.Module):
    """Random-initialized transformer that returns logits for ProteinGym scoring."""
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.transformer = TransformerForMaskedLM(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> RandomModelOutput:
        out = self.transformer(input_ids, attention_mask, return_preds=False)
        return RandomModelOutput(last_hidden_state=out.last_hidden_state, logits=out.logits)


def _build_random_transformer_config(preset: str) -> TransformerConfig:
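    """Copy the dimensions (hidden size, heads, layers, vocab) of the referenced
    ESM-2 checkpoint into a fresh, randomly initialized TransformerConfig."""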
    esm_config = EsmConfig.from_pretrained(presets[preset])
    config = TransformerConfig()
    config.hidden_size = esm_config.hidden_size
    config.n_heads = esm_config.num_attention_heads
    config.n_layers = esm_config.num_hidden_layers
    config.vocab_size = esm_config.vocab_size
    config.attn_implementation = 'sdpa'
    return config


def build_random_model(preset: str, masked_lm: bool = False, model_path: Optional[str] = None, **kwargs):
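    """Build the random baseline model for the given preset, paired with the ESM-2 tokenizer."""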
    tokenizer = EsmTokenizer.from_pretrained('facebook/esm2_t12_35M_UR50D')
    if preset == 'Random':
        model = RandomModel(EsmConfig.from_pretrained('facebook/esm2_t12_35M_UR50D'))
    else:
        config = _build_random_transformer_config(preset)
        if masked_lm:
            model = RandomTransformerForMaskedLM(config).eval()
        else:
            model = RandomTransformer(config).eval()
    return model, tokenizer


if __name__ == '__main__':
    model, tokenizer = build_random_model('Random-Transformer')
    print(model)
    print(tokenizer)
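
    # A minimal smoke test (a sketch, assuming the usual ESM tokenizer interface and a
    # hypothetical dummy sequence): embed a short sequence with the randomly initialized
    # transformer and check the output shape.
    enc = tokenizer('MKTAYIAKQR', return_tensors='pt')
    with torch.no_grad():
        hidden = model(enc['input_ids'], enc['attention_mask'])
    print(hidden.shape)  # expected: (1, num_tokens, hidden_size)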