| import torch |
| from torch import nn |
|
|
# Identifier of the tokenizer whose ids the model presumably consumes; it is
# not referenced by the code visible here — confirm where it is used.
tokenizer: str = "gpt2"
|
|
| |
| |
| |
class MLA(nn.Module):
    """Pool a sequence into a fixed number of vectors with learned latent queries.

    A learned bank of ``num_latents`` vectors is used as the attention query.
    The input sequence ``x`` serves as both keys and values. The attended
    latents are then refined by a residual feed-forward block.

    Args:
        d_model: Feature size of ``x`` and of the returned latents.
        num_heads: Number of attention heads; must divide ``d_model``
            (required by ``nn.MultiheadAttention``).
        num_latents: Number of learned latent queries, i.e. the output length.
        latent_dim: Size of the stored latent vectors. When it differs from
            ``d_model`` the latents are linearly projected to ``d_model``
            before attention. Without that projection, ``forward`` failed for
            any ``latent_dim != d_model``.
    """

    def __init__(self, d_model=32, num_heads=4, num_latents=4, latent_dim=32):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
        # nn.MultiheadAttention needs queries of size embed_dim (= d_model), so a
        # differing latent_dim is bridged by a projection. Identity keeps the
        # latent_dim == d_model case parameter-free. It consumes no RNG either,
        # so the state_dict and the seeded init match the original.
        if latent_dim == d_model:
            self.latent_proj = nn.Identity()
        else:
            self.latent_proj = nn.Linear(latent_dim, d_model)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            batch_first=True,
        )
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )

    def forward(self, x, key_padding_mask=None):
        """Attend the learned latents over ``x``.

        Args:
            x: Input with ``batch_first`` layout, i.e. ``(batch, seq_len, d_model)``.
            key_padding_mask: Optional boolean ``(batch, seq_len)`` mask. ``True``
                marks positions of ``x`` the latents must not attend to.
                Defaults to ``None`` (attend everywhere, as before).

        Returns:
            Tensor of shape ``(batch, num_latents, d_model)``.
        """
        batch_size = x.size(0)
        # The same learned latents are broadcast to every item in the batch.
        latents = self.latent_proj(self.latents).unsqueeze(0).expand(batch_size, -1, -1)
        updated_latents, _ = self.attn(
            query=latents, key=x, value=x, key_padding_mask=key_padding_mask
        )
        # Residual feed-forward refinement of the attended latents.
        return updated_latents + self.ff(updated_latents)
|
|
|
|
| |
| |
| |
class Model(nn.Module):
    """Sequence classifier: embed, compress along the sequence axis, encode, pool.

    Pipeline:

    1. Token ids are truncated or right-padded to ``max_len``. They are then
       embedded with learned token and position embeddings.
    2. ``compress`` mixes along the *sequence* axis, mapping ``max_len``
       positions to ``d_model`` positions (the output size of its last Linear).
    3. A 6-layer Transformer encoder runs over the compressed sequence.
    4. The features fed to a linear head are the flattened first
       ``num_cls_tokens`` encoder positions, concatenated with the flattened
       MLA latents.

    Args:
        vocab_dim: Vocabulary size of the token embedding.
        d_model: Model width. It is also the compressed sequence length.
            Must be divisible by 6, because the encoder and MLA both use 6 heads.
        num_classes: Number of output logits.
        num_cls_tokens: How many leading encoder positions are pooled; must be
            in ``[0, d_model]``.
        max_len: Fixed length inputs are padded/truncated to. Default 512, the
            previously hard-coded value.

    Raises:
        ValueError: If ``num_cls_tokens`` is outside ``[0, d_model]``.
    """

    def __init__(self, vocab_dim, d_model=36, num_classes=2, num_cls_tokens=4, *, max_len=512):
        super().__init__()
        # compress emits d_model positions, so this is the encoder's sequence
        # length. Slicing more leading positions than exist would silently give
        # the head a wrong input size and only fail later, inside forward.
        compressed_len = d_model
        if not 0 <= num_cls_tokens <= compressed_len:
            raise ValueError(
                f"num_cls_tokens must be between 0 and {compressed_len} "
                f"(the compressed sequence length), got {num_cls_tokens}"
            )
        self.d_model = d_model
        self.num_cls_tokens = num_cls_tokens
        self.max_len = max_len

        self.token_embed = nn.Embedding(vocab_dim, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)

        # Applied to the transposed (batch, d_model, max_len) tensor, so these
        # Linear/RMSNorm layers act on the sequence axis, not the feature axis.
        self.compress = nn.Sequential(
            nn.Linear(max_len, 150),
            nn.GELU(), nn.AlphaDropout(0.05), nn.RMSNorm(150),
            nn.Linear(150, compressed_len),
        )

        te = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=6,
            dim_feedforward=100,
            dropout=0.26,
            activation=nn.functional.gelu,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(te, num_layers=6)

        self.mla = MLA(d_model=d_model, num_heads=6, num_latents=8, latent_dim=d_model)

        self.head = nn.Linear((num_cls_tokens + self.mla.latents.size(0)) * d_model, num_classes)

    def forward(self, x):
        """Return class logits for a batch of token ids.

        Args:
            x: Integer token ids; exactly two dims are unpacked, presumably
                ``(batch, seq_len)``. Sequences longer than ``max_len`` are
                truncated. This matches the old behaviour, which cropped via a
                negative ``F.pad``; it is now explicit.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        batch_size, _ = x.shape

        # Truncate, then right-pad with id 0 to the fixed length the first
        # compress Linear expects.
        # NOTE(review): under the module's "gpt2" tokenizer id 0 is presumably
        # a real token, so padding is indistinguishable from it — confirm.
        x = x[:, : self.max_len]
        x = nn.functional.pad(x, (0, self.max_len - x.size(1)))

        pos = torch.arange(self.max_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        x = self.token_embed(x) + self.pos_embed(pos)

        # Compress over the sequence axis, then restore (batch, positions, d_model).
        x = self.compress(x.transpose(1, 2)).transpose(1, 2)

        out = self.encoder(x)

        # flatten(1) (unlike reshape(batch, -1)) also handles num_cls_tokens == 0.
        cls_embeddings = out[:, : self.num_cls_tokens, :].flatten(1)
        mla_embeddings = self.mla(out).flatten(1)

        features = torch.cat([cls_embeddings, mla_embeddings], dim=-1)
        return self.head(features)