| | from model_components import Block |
| | from constants import * |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from utils import tokenizer, vocab_size |
| |
|
class DecoderLanguageModel(nn.Module):
    """
    Transformer decoder language model with an (optional) coordinate regression head.

    Consumes a pre-built sequence of embeddings (e.g. vision + token embeddings
    concatenated by the owning VLM) and produces next-token logits. A regression
    head that emits MAX_POINTS (x, y) pairs in [0, 1] is constructed here, but
    forward() does not invoke it — coordinate regression is driven by the VLM
    using the hidden states this module returns.
    """

    def __init__(self, n_embd=HIDDEN_DIM, vocab_size=vocab_size, num_heads=NUM_HEADS,
                 n_layer=NUM_LAYERS, max_context=CONTEXT_LENGTH, dropout=DROPOUT):
        """
        Args:
            n_embd: Embedding / hidden dimension.
            vocab_size: Token vocabulary size (intentionally shadows the
                module-level import so callers can override it).
            num_heads: Attention heads per transformer block.
            n_layer: Number of decoder blocks.
            max_context: Maximum sequence length the positional table supports.
            dropout: Dropout probability for embeddings and inside blocks.
        """
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(max_context, n_embd)
        self.dropout = nn.Dropout(dropout)

        # Stack of causal transformer blocks (is_decoder=True enables masking
        # inside Block — see model_components).
        self.blocks = nn.ModuleList([
            Block(n_embd, num_heads, dropout, is_decoder=True)
            for _ in range(n_layer)
        ])

        # Final layer norm applied before the output heads.
        self.ln_f = nn.LayerNorm(n_embd)

        # Token prediction head; bias-free so its weight can be tied to the
        # token embedding table below.
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # Coordinate regression head: MAX_POINTS (x, y) pairs squashed to
        # [0, 1] by the Sigmoid. Built here, invoked by the enclosing VLM.
        self.regression_head = nn.Sequential(
            nn.Linear(n_embd, n_embd // 2),
            nn.GELU(),
            nn.Linear(n_embd // 2, MAX_POINTS * 2),
            nn.Sigmoid()
        )

        self.n_embd = n_embd
        self.max_context = max_context
        # Weight tying: input embedding and output projection share a single
        # Parameter. Tying before _init_weights is safe — both isinstance
        # branches re-init the same shared tensor, so the tie survives apply().
        self.token_embedding_table.weight = self.lm_head.weight
        self.apply(self._init_weights)
        print(f"DecoderLanguageModel initialized with {n_layer} layers.")

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) for weights, zeros for biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, combined_embeds, attention_mask=None, targets=None):
        """
        Forward pass for training or inference where loss is calculated.

        Regression output is handled *outside* this module by the VLM, which
        consumes the returned hidden states.

        Args:
            combined_embeds: (B, T, C) float tensor of input embeddings.
            attention_mask: Optional (B, T) mask forwarded to each block.
            targets: Optional (B, T) integer token targets; positions set to
                -100 are ignored by the cross-entropy loss.

        Returns:
            Tuple (logits, class_loss, x_norm):
                logits: (B, T, vocab_size) token scores.
                class_loss: Scalar cross-entropy loss, or None when targets is
                    None, the loss computation failed, or the loss was NaN.
                x_norm: (B, T, C) final layer-normed hidden states.

        Raises:
            ValueError: If combined_embeds is not 3-dimensional.
        """
        if combined_embeds.ndim != 3:
            raise ValueError(f"DecoderLM received non-3D combined_embeds! Shape: {combined_embeds.shape}")
        B, T, C = combined_embeds.shape

        # Keep only the most recent max_context positions if the input is too
        # long; mask and targets are truncated in lockstep.
        if T > self.max_context:
            print(f"WARNING (Decoder forward): Input sequence length {T} > max context {self.max_context}. Truncating.")
            combined_embeds = combined_embeds[:, -self.max_context:, :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_context:]
            if targets is not None:
                targets = targets[:, -self.max_context:]
            T = self.max_context

        # Add learned positional embeddings (broadcast over the batch). The
        # clamp is a defensive no-op after the truncation above.
        pos = torch.arange(0, T, dtype=torch.long, device=combined_embeds.device)
        pos = pos.clamp(max=self.position_embedding_table.num_embeddings - 1)
        pos_emb = self.position_embedding_table(pos)
        x = combined_embeds + pos_emb.unsqueeze(0)
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x, attention_mask=attention_mask)

        x_norm = self.ln_f(x)
        logits = self.lm_head(x_norm)

        # Best-effort loss: failures are reported and surfaced as None rather
        # than aborting the step (matches the training loop's expectations).
        class_loss = None
        if targets is not None:
            try:
                class_loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    targets.view(-1),
                    ignore_index=-100
                )
                if torch.isnan(class_loss):
                    print("Warning: class_loss is NaN.")
                    class_loss = None
            except Exception as e:
                print(f"Error calculating cross_entropy: {e}")
                print(f"Logits shape: {logits.shape}, Targets shape: {targets.shape}")
                class_loss = None

        return logits, class_loss, x_norm

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressive generation from starting token IDs.

        NOTE: This version doesn't handle combined embeddings directly.
        The VisionLanguageModel should ideally use a method like
        generate_from_embeddings or implement the loop externally.

        Args:
            idx: (B, T0) int64 tensor of prompt token ids.
            max_new_tokens: Maximum number of tokens to sample.
            temperature: Softmax temperature (must be > 0).
            top_k: If given and > 0, restrict sampling to the top-k logits.

        Returns:
            (B, T0 + n) tensor of prompt plus up to max_new_tokens sampled ids
            (shorter if every sequence in the batch emits EOS).
        """
        # FIX: remember the caller's mode instead of unconditionally calling
        # self.train() on exit, which silently put eval-mode models back into
        # training mode. @torch.no_grad() also avoids building an autograd
        # graph across the whole sampling loop.
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop the context window to the last max_context tokens.
                idx_cond = idx if idx.size(1) <= self.max_context else idx[:, -self.max_context:]

                tok_embeds = self.token_embedding_table(idx_cond)

                pos = torch.arange(0, idx_cond.size(1), dtype=torch.long, device=idx.device)
                pos = pos.clamp(max=self.max_context - 1)
                pos_emb = self.position_embedding_table(pos).unsqueeze(0)
                x = self.dropout(tok_embeds + pos_emb)

                for block in self.blocks:
                    x = block(x, attention_mask=None)

                # Only the last position's logits are needed for sampling.
                x = self.ln_f(x[:, -1:, :])
                logits = self.lm_head(x)
                logits = logits.squeeze(1)

                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

                idx = torch.cat((idx, idx_next), dim=1)

                # FIX: guard against tokenizers whose eos_token_id is None —
                # comparing a tensor to None yields a plain bool and .all()
                # would raise AttributeError.
                eos_id = getattr(tokenizer, 'eos_token_id', None)
                if eos_id is not None and (idx_next == eos_id).all():
                    break
        finally:
            if was_training:
                self.train()
        return idx
| | |
| |
|