"""
Author: Eshan Jayasundara
Last Updated: 2nd of March 2025
Created: 28th of February 2025
___

About:
└── Single-head Transformer (encoder-decoder with self-attention), trained with teacher forcing
___

Training:
└── Teacher Forcing (Baseline)
    ├── During training, the actual ground-truth tokens (from the dataset) are fed as input to the decoder instead of the model's own predictions.
    ├── This makes training faster and ensures the model learns accurate token-to-token mappings.
    └── Drawback: at inference time the model doesn't see ground-truth inputs, so errors can accumulate (this is called exposure bias).
___

Vocabulary dataset (from Hugging Face):
└── "yukiarimo/english-vocabulary"
___

Architecture:

Encoder
├── Input text
│   └── E.g.: "Hello, how are you?"
├── Remove punctuation from the input text
├── Input tokenization
├── Embedding lookup with torch.nn.Embedding
├── Positional encoding (sine, cosine)
├── Self-attention
│   ├── single-head
│   ├── Q = Wq @ Embedding
│   ├── K = Wk @ Embedding
│   └── V = Wv @ Embedding
├── Add and norm
├── Feed-forward layer
│   ├── 2 linear layers with one hidden layer of width d_ff
│   ├── ReLU as the activation of the hidden layer
│   ├── No activation at the output layer
│   └── nn.Linear(in_features=embedding_dim, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=embedding_dim)
├── Add and norm (again)
└── Keep the encoder output for use in the decoder's cross-attention

Decoder
├── Decoder teacher text (the target text shifted right by one token)
│   ├── E.g. decoder teacher text: "<SOS> hello, I'm fine."
│   └── E.g. target text: "hello, I'm fine. <EOS>"
├── Remove punctuation from the target text
├── Input tokenization
├── Embedding lookup with torch.nn.Embedding
├── Positional encoding (sine, cosine)
├── Masked self-attention (single-head; implemented by a separate class for masked self-attention)
│   ├── single-head
│   ├── causal mask built from a triangular matrix
│   ├── Q = Wq @ Embedding
│   ├── K = Wk @ Embedding
│   └── V = Wv @ Embedding
├── Add and norm
├── Cross-attention (reuses the same class as the encoder self-attention)
│   ├── single-head
│   ├── Q = Wq @ add-and-normalized output of the masked self-attention
│   ├── K = Wk @ Encoder output
│   └── V = Wv @ Encoder output
├── Add and norm
├── Feed-forward layer
│   ├── 2 linear layers with one hidden layer of width d_ff
│   ├── ReLU as the activation of the hidden layer
│   ├── No activation at the output layer
│   └── nn.Linear(in_features=embedding_dim, out_features=d_ff), nn.ReLU(), nn.Linear(in_features=d_ff, out_features=embedding_dim)
├── Add and norm (again)
└── Linear layer producing raw logits (no softmax here, unlike the final softmax in 'Attention Is All You Need'; F.cross_entropy applies it internally)

Optimization
├── Initialize the Adam optimizer with the model's parameters and a specified learning rate.
│   └── self.optimizer = torch.optim.Adam(params=<model parameters>, lr=learning_rate)
├── Before computing gradients for the current example, reset any gradients left from the previous iteration.
│   └── self.optimizer.zero_grad()
├── The model takes `input_tokens` and `decoder_teacher_tokens` and performs a forward pass to compute `logits`.
│   └── logits = self.forward(input_tokens, decoder_teacher_tokens)
├── The cross-entropy loss
│   ├── Measures the difference between the predicted token distribution (logits) and the actual target tokens (decoder_target_tokens).
│   ├── It expects logits to be raw scores (not probabilities) and applies (log-)softmax internally.
│   └── loss = F.cross_entropy(logits, decoder_target_tokens)
├── Compute the gradients of the loss with respect to all trainable parameters using automatic differentiation (backpropagation).
│   └── loss.backward()
└── The optimizer updates the model's weights using the computed gradients.
    └── self.optimizer.step()

After training, output tokens (and from them, text) are produced by autoregressive text generation, one token at a time:
├── Start with <SOS> as the initial decoder input; the encoder input is the `prompt`.
├── The model predicts the next token.
├── Append the predicted token to the sequence.
├── Repeat until an <EOS> token is produced or the maximum length is reached.
└── For illustration, words are shown below instead of tokens (their numerical representation):
    <SOS>
    <SOS> hello
    <SOS> hello I'm
    <SOS> hello I'm good
    <SOS> hello I'm good <EOS>
___

Future Improvements:
├── Multi-head attention instead of single-head attention.
├── Layer normalization instead of simple mean-variance normalization.
└── Dropout layers for better generalization.
"""


import string

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset


# Select the compute device once at import time.
# The original preference for the second GPU ("cuda:1") is kept on multi-GPU
# machines; on a single-GPU machine "cuda:1" does not exist and the first CUDA
# allocation would fail, so fall back to "cuda:0" there.
if torch.cuda.is_available():
    _cuda_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{_cuda_index}")
    # Report the name of the GPU actually selected (previously always GPU 0).
    print(f"Using Device: {device} | Name: {torch.cuda.get_device_name(device)}")
else:
    device = torch.device('cpu')
    print(f"Using Device: {device}")


class SingleHeadAttention(torch.nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Used for encoder self-attention and for decoder cross-attention
    (queries from the decoder, keys/values from the encoder output).

    Args:
        embedding_dim: size of the last dimension of every input; the Q, K and
            V projections all map ``embedding_dim -> embedding_dim``.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.query_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.key_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.value_layer = torch.nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

    def forward(self, q_embedding, k_embedding, v_embedding, attention_mask=None):
        """Attend from ``q_embedding`` over ``k_embedding`` / ``v_embedding``.

        Args:
            q_embedding: query inputs, ``(..., q_len, embedding_dim)``.
            k_embedding: key inputs, ``(..., k_len, embedding_dim)``.
            v_embedding: value inputs, ``(..., k_len, embedding_dim)``.
            attention_mask: optional tensor broadcastable to the score matrix
                ``(..., q_len, k_len)``; positions where it equals 0 are
                excluded from attention. Defaults to no mask.

        Returns:
            ``(attention_output, attention_weights)`` of shapes
            ``(..., q_len, embedding_dim)`` and ``(..., q_len, k_len)``.
        """
        Q = self.query_layer(q_embedding)
        K = self.key_layer(k_embedding)
        V = self.value_layer(v_embedding)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.embedding_dim ** 0.5
        # Mask and softmax in at least float32 for numerical stability
        # (float16/bfloat16 are upcast, float64 is kept as is).
        compute_dtype = torch.promote_types(attention_scores.dtype, torch.float32)
        attention_scores = attention_scores.to(compute_dtype)

        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, torch.finfo(attention_scores.dtype).min)

        # Cast the weights back to V's dtype; an unconditional float32 cast
        # made this matmul fail for any non-float32 input.
        attention_weights = F.softmax(attention_scores, dim=-1).to(V.dtype)
        attention_output = torch.matmul(attention_weights, V)

        return attention_output, attention_weights


class FeedForwardLayer(torch.nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        embedding_dim: feature size of both the input and the output.
        d_ff: width of the single hidden layer.
    """

    def __init__(self, embedding_dim=64, d_ff=256):
        super().__init__()
        # Expand to the hidden width, then project back. The layers are
        # created in this order so seeded initialisation stays reproducible.
        self.fc1 = torch.nn.Linear(embedding_dim, d_ff)
        self.fc2 = torch.nn.Linear(d_ff, embedding_dim)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        projected = self.fc2(hidden)
        return projected


class DecoderMaskedAttention(nn.Module):
    """Single-head scaled dot-product attention with a causal (look-ahead) mask.

    Query position ``i`` may only attend to key positions ``j <= i``.

    Args:
        embedding_dim: size of the last dimension of every input; the Q, K and
            V projections all map ``embedding_dim -> embedding_dim``.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.query_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.key_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.value_layer = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)

    def forward(self, q_embedding, k_embedding, v_embedding, attention_mask=None):
        """Causally masked attention from ``q_embedding`` over ``k_embedding`` / ``v_embedding``.

        Args:
            q_embedding: query inputs, ``(..., q_len, embedding_dim)``.
            k_embedding: key inputs, ``(..., k_len, embedding_dim)``.
            v_embedding: value inputs, ``(..., k_len, embedding_dim)``.
            attention_mask: optional extra mask broadcastable to
                ``(..., q_len, k_len)``; positions where it equals 0 are
                excluded in addition to the causal mask.

        Returns:
            ``(attention_output, attention_weights)`` of shapes
            ``(..., q_len, embedding_dim)`` and ``(..., q_len, k_len)``.
        """
        Q = self.query_layer(q_embedding)
        K = self.key_layer(k_embedding)
        V = self.value_layer(v_embedding)

        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embedding_dim ** 0.5)

        # Size the causal mask from the score matrix itself, whose last two
        # dims are (q_len, k_len) for batched and unbatched inputs alike. Using
        # q_embedding.shape[0] picked the batch size for batched inputs.
        # The mask is created on the scores' device rather than a global one.
        q_len, k_len = attention_scores.shape[-2:]
        causal_mask = torch.triu(
            torch.ones(q_len, k_len, device=attention_scores.device), diagonal=1
        ).bool()

        fill_value = torch.finfo(attention_scores.dtype).min
        attention_scores = attention_scores.masked_fill(causal_mask, fill_value)

        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(attention_mask == 0, fill_value)

        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)

        return attention_output, attention_weights


class Transformer(torch.nn.Module):
    """Single-head encoder-decoder Transformer trained with teacher forcing.

    The vocabulary comes from a Hugging Face dataset with a ``text`` column
    (one entry per row). The special tokens ``<PAD>``, ``<UNK>``, ``<SOS>``
    and ``<EOS>`` always receive ids 0-3.

    Token tensors are created on the device of the model's parameters, so
    call ``model.to(device)`` once after construction to run on a GPU. The
    encoder/decoder methods process one unbatched sequence: 1-D token
    tensors in, ``(seq_len, vocab_size)`` logits out.

    Args:
        embedding_dim: width of the token embeddings and of every sub-layer.
        learning_rate: Adam learning rate.
        vocab_dataset: Hugging Face dataset name holding the vocabulary.
        split: dataset split to load.
    """

    # Translation table used by remove_punctuation; built once, not per call.
    _PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation)
    # Special tokens in id order (ids 0-3).
    _SPECIAL_TOKENS = ("<PAD>", "<UNK>", "<SOS>", "<EOS>")
    # Dataset entries not kept as ordinary words: the bare marker names the
    # original code dropped, plus the special tokens themselves so they can
    # never appear twice in the vocabulary.
    _DROPPED_WORDS = frozenset({"PAD", "SOS", "EOS", *_SPECIAL_TOKENS})

    def __init__(self, embedding_dim, learning_rate=1e-3, vocab_dataset="yukiarimo/english-vocabulary", split="train"):
        super().__init__()
        vocab_df = load_dataset(vocab_dataset, split=split).to_pandas()
        self._build_vocab(vocab_df)
        self._build_layers(embedding_dim=embedding_dim, learning_rate=learning_rate)

    def _build_vocab(self, vocab_df):
        """Set ``vocab_df``, ``vocab_size`` and ``vocab`` from a DataFrame with a ``text`` column.

        The special tokens are prepended instead of written over rows 0-3.
        Assigning by index label after dropping rows either overwrote real
        words or, when those labels had been dropped, appended rows whose
        other columns were NaN.
        """
        words = vocab_df[~vocab_df["text"].isin(list(self._DROPPED_WORDS))]
        specials = pd.DataFrame({"text": list(self._SPECIAL_TOKENS)})
        vocab_df = pd.concat([specials, words], ignore_index=True)

        self.vocab_size = vocab_df.shape[0]
        vocab_df["idx"] = range(0, self.vocab_size)
        # Indexed by token text; the "idx" column holds the token id.
        self.vocab_df = vocab_df.set_index("text")
        # token text -> token id (for a duplicated text the last id wins).
        self.vocab = self.vocab_df["idx"].to_dict()

    def _build_layers(self, embedding_dim, learning_rate):
        """Create the sub-layers and the Adam optimizer (requires ``self.vocab_size``)."""
        self.embedding_fn = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=embedding_dim)
        self.encoder_self_attention = SingleHeadAttention(embedding_dim=embedding_dim)
        self.encoder_ff = FeedForwardLayer(embedding_dim=embedding_dim, d_ff=embedding_dim * 4)
        self.cross_attention = SingleHeadAttention(embedding_dim=embedding_dim)
        self.decoder_masked_attention = DecoderMaskedAttention(embedding_dim=embedding_dim)
        self.decoder_ff = FeedForwardLayer(embedding_dim=embedding_dim, d_ff=embedding_dim * 4)
        self.linear = nn.Linear(in_features=embedding_dim, out_features=self.vocab_size)

        # nn.Module.parameters() yields the parameters of every sub-layer
        # registered above, in this same order. Storing a list under the
        # attribute name ``parameters`` (as before) shadowed that method and
        # broke Module.zero_grad(), requires_grad_() and similar helpers.
        self.optimizer = torch.optim.Adam(params=self.parameters(), lr=learning_rate)

    @property
    def _device(self):
        """Device of the model's parameters (CPU if it has none yet)."""
        try:
            return next(self.parameters()).device
        except (StopIteration, AttributeError):
            return torch.device("cpu")

    def remove_punctuation(self, text):
        """Return ``text`` with every character in ``string.punctuation`` removed."""
        return text.translate(self._PUNCTUATION_TABLE)

    def tokenize(self, text, unk_token="<UNK>"):
        """Map whitespace-separated words to token ids; unknown words map to ``unk_token``.

        Always returns a 1-D ``torch.long`` tensor, including for empty text
        (which previously produced a float tensor that nn.Embedding rejects).
        """
        unk_id = self.vocab.get(unk_token)
        ids = [self.vocab.get(token, unk_id) for token in text.split()]
        return torch.tensor(ids, dtype=torch.long, device=self._device)

    def positional_encoding(self, embedding, max_len, embedding_dim=64):
        """Add sinusoidal positional encodings to ``embedding``.

        Args:
            embedding: tensor whose last two dims are ``(max_len, embedding_dim)``.
            max_len: number of positions to encode.
            embedding_dim: feature size of ``embedding``; odd sizes are supported.

        Returns:
            ``embedding + PE`` where ``PE[pos, 2i] = sin(pos / 10000^(2i/d))`` and
            ``PE[pos, 2i+1] = cos(pos / 10000^(2i/d))``.
        """
        pe_device = embedding.device
        pe = torch.zeros(max_len, embedding_dim, device=pe_device)

        position = torch.arange(0, max_len, dtype=torch.float, device=pe_device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2, device=pe_device).float()
            * torch.log(torch.tensor(10000.0, device=pe_device))
            / embedding_dim
        )
        angles = position / div_term

        pe[:, 0::2] = torch.sin(angles)
        # For odd embedding_dim there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])

        return embedding + pe.to(embedding.dtype)

    def add_norm(self, old_tensor, new_tensor, eps=1e-5):
        """Residual add, then normalise over the last dim to zero mean / unit std.

        ``eps`` keeps the result finite when all features are equal (std == 0),
        where the plain division produced NaNs.
        """
        addition = old_tensor + new_tensor
        centered = addition - addition.mean(dim=-1, keepdim=True)
        return centered / (addition.std(dim=-1, keepdim=True) + eps)

    def encoder(self, encoder_input_tokens):
        """Encode a 1-D tensor of source token ids into ``(seq_len, embedding_dim)`` features."""
        embeddings = self.embedding_fn(encoder_input_tokens)
        # The feature size comes from the embeddings (it was hard-coded to 64).
        x = self.positional_encoding(embeddings, max_len=embeddings.shape[-2], embedding_dim=embeddings.shape[-1])

        attention_out, _ = self.encoder_self_attention(
            q_embedding=x,
            k_embedding=x,
            v_embedding=x,
            attention_mask=None,
        )
        x = self.add_norm(old_tensor=x, new_tensor=attention_out)

        ff_out = self.encoder_ff(x)
        return self.add_norm(old_tensor=x, new_tensor=ff_out)

    def decoder(self, decoder_teacher_tokens, encoder_out):
        """Run the decoder and return raw logits of shape ``(seq_len, vocab_size)``.

        Args:
            decoder_teacher_tokens: 1-D decoder input ids (``<SOS>`` + target while training).
            encoder_out: output of ``encoder`` for the source sequence.
        """
        embeddings = self.embedding_fn(decoder_teacher_tokens)
        x = self.positional_encoding(embeddings, max_len=embeddings.shape[-2], embedding_dim=embeddings.shape[-1])

        masked_out, _ = self.decoder_masked_attention(
            q_embedding=x,
            k_embedding=x,
            v_embedding=x,
            attention_mask=None,
        )
        x = self.add_norm(old_tensor=x, new_tensor=masked_out)

        # Cross-attention: queries from the decoder, keys/values from the encoder.
        cross_out, _ = self.cross_attention(
            q_embedding=x,
            k_embedding=encoder_out,
            v_embedding=encoder_out,
            attention_mask=None,
        )
        x = self.add_norm(old_tensor=x, new_tensor=cross_out)

        ff_out = self.decoder_ff(x)
        x = self.add_norm(old_tensor=x, new_tensor=ff_out)
        # Raw logits: softmax is left to F.cross_entropy (training) / argmax (inference).
        return self.linear(x)

    def forward(self, encoder_input_tokens, decoder_teacher_tokens):
        """Encode the source tokens and decode the teacher tokens into logits."""
        encoder_out = self.encoder(encoder_input_tokens)
        return self.decoder(decoder_teacher_tokens, encoder_out=encoder_out)

    def train(self, dataset=True, epochs=100):
        """Train with teacher forcing, or set the module mode like ``nn.Module.train``.

        ``train()`` / ``train(True)`` / ``train(False)`` (``eval()`` calls the
        latter) only switch training mode and return ``self``, as for any
        ``nn.Module``. Previously this override made ``eval()`` raise.

        Args:
            dataset: iterable of ``(input_text, output_text)`` string pairs, or
                a bool to only set the training mode.
            epochs: number of passes over ``dataset``.
        """
        if isinstance(dataset, bool):
            return super().train(dataset)

        super().train(True)
        for epoch in range(epochs):
            total_loss = 0
            for input_text, output_text in dataset:
                encoder_input_text = self.remove_punctuation(input_text)
                target_text = self.remove_punctuation(output_text)
                # Teacher forcing: the decoder input is the target shifted right.
                decoder_teacher_text = "<SOS> " + target_text
                decoder_target_text = target_text + " <EOS>"

                encoder_input_tokens = self.tokenize(encoder_input_text)
                decoder_teacher_tokens = self.tokenize(decoder_teacher_text)
                decoder_target_tokens = self.tokenize(decoder_target_text)

                self.optimizer.zero_grad()
                logits = self.forward(encoder_input_tokens=encoder_input_tokens, decoder_teacher_tokens=decoder_teacher_tokens)
                loss = F.cross_entropy(logits, decoder_target_tokens)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()

            if (epoch+1) % 10 == 0:
                print(f"Epoch {epoch+1:04d} - Loss: {total_loss:.4f}")

        print("*** END ***\n")

    def predict_tokens(self, encoder_input_tokens, max_output_len=20):
        """Greedily generate token ids, starting from ``<SOS>``.

        Returns the generated ids including the leading ``<SOS>``, and the
        ``<EOS>`` if it was produced within ``max_output_len`` steps.
        """
        eos_id = self.vocab["<EOS>"]
        decoder_input = [self.vocab["<SOS>"]]
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            encoder_out = self.encoder(encoder_input_tokens)
            for _ in range(max_output_len):
                current_decoder_tokens = torch.tensor(decoder_input, device=self._device)
                logits = self.decoder(current_decoder_tokens, encoder_out)
                pred_index = torch.argmax(logits[-1, :]).item()
                decoder_input.append(pred_index)
                if pred_index == eos_id:
                    break
        return decoder_input

    def predict_text(self, encoder_input_tokens):
        """Generate tokens with ``predict_tokens`` and join their texts with spaces."""
        # Build the id -> text map once per call instead of scanning the
        # DataFrame for every generated token.
        idx_to_token = {int(idx): token for idx, token in zip(self.vocab_df["idx"], self.vocab_df.index)}
        return ' '.join(
            idx_to_token[token]
            for token in self.predict_tokens(encoder_input_tokens=encoder_input_tokens)
        )