import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding added to token embeddings."""

    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # Frequencies follow the original Transformer formulation:
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (B, L, d_model); slice the table to the current sequence length.
        return x + self.pe[:, :x.size(1)].to(x.device)


class TransformerSeq2Seq(nn.Module):
    """
    Designed so that forward(src_embedded, tgt_input_ids, src_attn_mask=None, tgt_attn_mask=None):
    - src_embedded: (B, S, E) source embeddings; these can be computed outside the model,
      e.g. by indexing a pretrained matrix (embedding_src[src_ids])
    - tgt_input_ids: (B, T) token ids for the decoder input (BOS .. token_{n-1})
    - src_attn_mask / tgt_attn_mask: (B, S) / (B, T) with 1 for real tokens, 0 for padding

    A minimal training-step sketch follows this class.
    """

    def __init__(self,
                 embed_dim,
                 vocab_size,
                 embedding_decoder=None,
                 num_heads=2,
                 num_layers=2,
                 dim_feedforward=256,
                 dropout=0.1,
                 freeze_decoder_emb=True,
                 max_len=512):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size

        self.pos_encoder = PositionalEncoding(embed_dim, max_len=max_len)

        # Decoder embedding: either a freshly initialised table or a pretrained
        # matrix (optionally frozen) passed in as `embedding_decoder`.
        if embedding_decoder is None:
            self.embedding_decoder = nn.Embedding(vocab_size, embed_dim)
        else:
            if not isinstance(embedding_decoder, torch.Tensor):
                embedding_decoder = torch.tensor(embedding_decoder, dtype=torch.float)
            self.embedding_decoder = nn.Embedding.from_pretrained(embedding_decoder, freeze=freeze_decoder_emb)

        self.encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                                        dim_feedforward=dim_feedforward, dropout=dropout,
                                                        batch_first=True)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)

        self.decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads,
                                                        dim_feedforward=dim_feedforward, dropout=dropout,
                                                        batch_first=True)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)

        self.output_proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, src_embedded, tgt_input_ids, src_attn_mask=None, tgt_attn_mask=None):
        """
        src_embedded  : (B, S, E)
        tgt_input_ids : (B, T)
        src_attn_mask : (B, S) mask: 1 for real tokens, 0 for padding (optional)
        tgt_attn_mask : (B, T) same convention
        """
        device = src_embedded.device

        tgt_embedded = self.embedding_decoder(tgt_input_ids)

        # Add positional information to both streams.
        src = self.pos_encoder(src_embedded)
        tgt = self.pos_encoder(tgt_embedded)

        # nn.Transformer* expects key-padding masks with True at padded positions,
        # so invert the 1-for-real-token convention used by the caller.
        src_key_padding_mask = None
        tgt_key_padding_mask = None
        if src_attn_mask is not None:
            src_key_padding_mask = (src_attn_mask == 0).to(device)
        if tgt_attn_mask is not None:
            tgt_key_padding_mask = (tgt_attn_mask == 0).to(device)

        memory = self.encoder(src, src_key_padding_mask=src_key_padding_mask)

        # Causal mask so decoder position t cannot attend to positions > t.
        T = tgt.size(1)
        if T > 0:
            tgt_mask = torch.triu(torch.full((T, T), float('-inf'), device=device), diagonal=1)
        else:
            tgt_mask = None

        output = self.decoder(tgt, memory,
                              tgt_mask=tgt_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=src_key_padding_mask)

        logits = self.output_proj(output)
        return logits
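

# Illustrative only: a minimal training-step sketch for TransformerSeq2Seq under the
# teacher-forcing convention described in the class docstring. The shapes, vocabulary
# size and `pad_id` below are assumptions made for this example, not values defined
# elsewhere in this module.
def _example_training_step():
    vocab_size, embed_dim, pad_id = 100, 32, 0
    model = TransformerSeq2Seq(embed_dim=embed_dim, vocab_size=vocab_size)

    src_embedded = torch.randn(2, 7, embed_dim)     # (B, S, E), e.g. from a source embedding matrix
    tgt_ids = torch.randint(1, vocab_size, (2, 6))  # (B, T) full target sequence incl. BOS/EOS

    # Teacher forcing: feed tokens 0..T-2, predict tokens 1..T-1.
    logits = model(src_embedded, tgt_ids[:, :-1])   # (B, T-1, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, vocab_size),
                           tgt_ids[:, 1:].reshape(-1),
                           ignore_index=pad_id)
    loss.backward()
    return loss.item()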


def apply_src_embedding(embedding_src, src_ids):
    """
    embedding_src can be:
      - torch.Tensor of shape (vocab_src, embed_dim) -> plain indexing
      - nn.Embedding instance -> called on the ids
    src_ids: LongTensor (B, S)
    return: (B, S, E) float tensor on the same device as src_ids
    """
    if isinstance(embedding_src, nn.Embedding):
        return embedding_src(src_ids)
    else:
        # Accept raw arrays/lists as well; move the matrix to the ids' device first.
        if not isinstance(embedding_src, torch.Tensor):
            embedding_src = torch.tensor(embedding_src, dtype=torch.float, device=src_ids.device)
        else:
            embedding_src = embedding_src.to(src_ids.device)
        return embedding_src[src_ids]
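

# Illustrative only: apply_src_embedding accepts either a raw embedding matrix or an
# nn.Embedding module; the sizes below are arbitrary assumptions for this example.
def _example_apply_src_embedding():
    src_ids = torch.tensor([[1, 2, 3]])           # (B=1, S=3)
    matrix = torch.randn(10, 32)                  # (vocab_src, embed_dim) tensor path
    module = nn.Embedding(10, 32)                 # equivalent nn.Embedding path
    out_a = apply_src_embedding(matrix, src_ids)  # -> (1, 3, 32) via indexing
    out_b = apply_src_embedding(module, src_ids)  # -> (1, 3, 32) via the module
    return out_a.shape, out_b.shape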


@torch.no_grad()
def translate(model, src_sentence, tokenizer_src, tokenizer_tgt, embedding_src, device, max_len=50):
    model.eval()
    inputs = tokenizer_src([src_sentence], return_tensors="pt", padding=True, truncation=True, max_length=128)
    src_ids = inputs["input_ids"].to(device)
    src_attn = inputs.get("attention_mask", None)
    if src_attn is not None:
        src_attn = src_attn.to(device)

    src_embedded = apply_src_embedding(embedding_src, src_ids)

    # Greedy decoding: start from the target tokenizer's BOS/CLS token and stop
    # at its SEP/EOS token or after max_len steps.
    decoded_ids = [tokenizer_tgt.cls_token_id]
    for _ in range(max_len):
        decoder_input = torch.tensor([decoded_ids], device=device)

        logits = model(src_embedded, decoder_input, src_attn_mask=src_attn, tgt_attn_mask=None)
        next_token = logits[:, -1, :].argmax(dim=-1).item()
        decoded_ids.append(next_token)
        if next_token == tokenizer_tgt.sep_token_id:
            break

    return tokenizer_tgt.decode(decoded_ids, skip_special_tokens=True)
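

# Illustrative only: how the pieces above could fit together at inference time.
# `translate` expects Hugging Face-style tokenizers (callable, exposing cls_token_id /
# sep_token_id); the checkpoint names, embedding matrix and sentence below are
# placeholders for this sketch, not values used elsewhere in this module.
def _example_translate(src_checkpoint, tgt_checkpoint):
    from transformers import AutoTokenizer  # assumed available in the environment

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer_src = AutoTokenizer.from_pretrained(src_checkpoint)
    tokenizer_tgt = AutoTokenizer.from_pretrained(tgt_checkpoint)

    embed_dim = 64                                                    # placeholder size
    embedding_src = torch.randn(tokenizer_src.vocab_size, embed_dim)  # stand-in source embedding matrix
    model = TransformerSeq2Seq(embed_dim=embed_dim, vocab_size=tokenizer_tgt.vocab_size).to(device)

    # With an untrained model this returns meaningless text; it only shows the call pattern.
    return translate(model, "hello world", tokenizer_src, tokenizer_tgt, embedding_src, device)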