import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput


class TinyWayConfig(PretrainedConfig):
    model_type = "tinyway"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=256,
        n_embd=512,
        n_layer=10,
        n_head=8,
        dropout=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.dropout = dropout

        # Standard Hugging Face attribute names, aliased so generic tooling
        # that reads hidden_size, num_hidden_layers, etc. keeps working.
        self.hidden_size = n_embd
        self.num_hidden_layers = n_layer
        self.num_attention_heads = n_head
        self.max_position_embeddings = n_positions


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        self.n_head = config.n_head
        self.head_dim = config.n_embd // config.n_head

        # Single projection producing queries, keys, and values in one pass.
        self.qkv = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.proj = nn.Linear(config.n_embd, config.n_embd)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.proj_dropout = nn.Dropout(config.dropout)

        # Lower-triangular causal mask, precomputed for the maximum context.
        self.register_buffer(
            "mask",
            torch.tril(
                torch.ones(
                    config.n_positions,
                    config.n_positions,
                    dtype=torch.bool
                )
            )
        )

        # Last attention weights are kept around for inspection.
        self.last_attn = None

    def forward(self, x):
        B, T, C = x.shape

        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)

        # Reshape to (B, n_head, T, head_dim) for per-head attention.
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores, with future positions masked out.
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(
            ~self.mask[:T, :T],
            torch.finfo(att.dtype).min
        )

        att = F.softmax(att, dim=-1)
        self.last_attn = att.detach()

        att = self.attn_dropout(att)

        out = att @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)

        out = self.proj(out)
        out = self.proj_dropout(out)

        return out


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)

        self.ln2 = nn.LayerNorm(config.n_embd)

        # Position-wise feed-forward network with the usual 4x expansion.
        self.ffn = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        # Pre-norm residual layout: LayerNorm is applied before each sublayer.
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class TinyWayForCausalLM(PreTrainedModel):
    config_class = TinyWayConfig

    def __init__(self, config):
        super().__init__(config)

        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.n_positions, config.n_embd)

        self.blocks = nn.ModuleList([
            Block(config) for _ in range(config.n_layer)
        ])

        self.ln = nn.LayerNorm(config.n_embd)

        self.head = nn.Linear(
            config.n_embd,
            config.vocab_size,
            bias=False
        )

        # Weight tying: the output projection shares the token embedding matrix.
        self.head.weight = self.token_emb.weight

        self.dropout = nn.Dropout(config.dropout)

        self.post_init()

    def forward(
        self,
        input_ids,
        labels=None,
        attention_mask=None,
        **kwargs
    ):
        # attention_mask is accepted for API compatibility but not used here;
        # inputs are assumed to be unpadded sequences.
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device)

        # Learned token embeddings plus learned absolute position embeddings.
        x = self.token_emb(input_ids) + self.pos_emb(pos)
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.ln(x)
        logits = self.head(x)
        loss = None
        if labels is not None:
            # Shift so position t predicts token t + 1, as in standard
            # causal language modeling.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
|
| | return CausalLMOutput( |
| | loss=loss, |
| | logits=logits |
| | ) |
| |
|
| | |
| |
|
| | def prepare_inputs_for_generation(self, input_ids, **kwargs): |
| | return {"input_ids": input_ids} |
| |
|
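

# Minimal usage sketch: instantiate a deliberately small configuration, run a
# single forward pass on random token ids, and check the output shapes. The
# sizes below are arbitrary illustrative values, not the defaults above.
if __name__ == "__main__":
    config = TinyWayConfig(
        vocab_size=1000,
        n_positions=64,
        n_embd=128,
        n_layer=2,
        n_head=4,
    )
    model = TinyWayForCausalLM(config)
    model.eval()

    # Batch of 2 sequences, 16 tokens each, drawn uniformly from the vocab.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))

    with torch.no_grad():
        out = model(input_ids=input_ids, labels=input_ids)

    print("logits shape:", tuple(out.logits.shape))  # (2, 16, 1000)
    print("loss:", out.loss.item())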