| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| |
| |
| |
class RemaskingHead(nn.Module):
    """Small MLP that maps each feature vector to a single scalar score.

    Pipeline: LayerNorm -> Linear(hidden, hidden) -> GELU -> Linear(hidden, 1).
    No squashing (sigmoid/softmax) is applied, so the output is a raw score.

    Args:
        hidden_size: width of the last dimension of the input features.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Attribute names double as state_dict keys, and construction order
        # fixes the order in which random initialization draws numbers, so
        # both are kept stable for checkpoint/reproducibility compatibility.
        self.norm = nn.LayerNorm(hidden_size)
        self.proj1 = nn.Linear(hidden_size, hidden_size)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score every position of ``x``.

        Args:
            x: tensor whose last dimension is ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a trailing
            dimension of size 1 (unnormalized scores).
        """
        normalized = self.norm(x)
        expanded = self.act(self.proj1(normalized))
        return self.proj2(expanded)
|
|
|
|
class InsertionQualityHead(nn.Module):
    """Per-position scalar scoring MLP used for insertion scores.

    Layout: LayerNorm -> Linear(hidden, hidden) -> GELU -> Linear(hidden, 1).
    The final projection is returned as-is (no sigmoid or other squashing).

    Args:
        hidden_size: size of the last dimension of the incoming features.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Keep these names and this creation order: they define the
        # state_dict keys and the sequence of random-init draws.
        self.norm = nn.LayerNorm(hidden_size)
        self.proj1 = nn.Linear(hidden_size, hidden_size)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one raw score per position of ``x``.

        The result keeps every leading dimension of ``x`` and replaces the
        last (``hidden_size``) dimension with a dimension of size 1.
        """
        return self.proj2(self.act(self.proj1(self.norm(x))))
|
|
|
|
class RemaskingAnyOrder(nn.Module):
    """Remasking adapter for AnyOrderMaskInsertionFlow models.

    Wraps an existing backbone and attaches small scoring heads to its
    features. Only the heads are registered as submodules of this adapter.

    Args:
        backbone: model callable as ``backbone(indices, t, return_features=True)``
            that returns ``(prediction, features)``, where ``prediction`` has a
            ``token_logits`` attribute. It must also expose
            ``get_hidden_states(indices, t)``.
        d_model: width of the backbone features (input size of the heads).
        insertion_planner: when True, an ``insertion_head`` is built and
            ``forward`` additionally returns ``"insertion_conf"``.
    """

    def __init__(self, backbone: nn.Module, d_model: int, insertion_planner: bool = False):
        super().__init__()
        # Store the backbone with object.__setattr__ to bypass nn.Module's
        # registration: it does not appear in self.parameters() / state_dict(),
        # and .to() / .train() / .eval() on this adapter do not reach it.
        # NOTE(review): presumably intentional (e.g. frozen backbone kept out
        # of the adapter checkpoint) — confirm.
        object.__setattr__(self, 'backbone', backbone)
        self.insertion_planner = insertion_planner
        self.d_model = d_model
        # Build the remasking head before the optional insertion head so the
        # order of random-init draws (and thus seeded weights) is unchanged.
        self.remasking_head = RemaskingHead(d_model)
        if self.insertion_planner:
            self.insertion_head = InsertionQualityHead(d_model)

    def forward(self, indices: torch.Tensor, t: torch.Tensor, **kwargs):
        """
        Forward pass for remasking training.

        Args:
            indices: Token indices [batch_size, seq_len]
            t: Timesteps [batch_size]
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Dict with "logits" (the backbone's ``token_logits``),
            "remasking_conf" (remasking head output) and, only when
            ``insertion_planner`` is set, "insertion_conf" (insertion head
            output).
        """
        prediction, features = self.backbone(indices, t, return_features=True)
        # Drop the final entry along dim 1 before scoring.
        # NOTE(review): presumably the backbone emits one extra trailing
        # position beyond the token positions — confirm against the backbone.
        hidden = features[:, :-1]

        remask_scores = self.remasking_head(hidden)
        outputs = {"logits": prediction.token_logits, "remasking_conf": remask_scores}
        if self.insertion_planner:
            outputs["insertion_conf"] = self.insertion_head(hidden)
        return outputs

    def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor):
        """Delegate directly to ``self.backbone.get_hidden_states(indices, t)``.

        Args:
            indices: Token indices [batch_size, seq_len]
            t: Timesteps [batch_size]

        Returns:
            Whatever the backbone returns, unmodified; expected to be a tuple
            of (token_logits, hidden_states, conditioning).
        """
        return self.backbone.get_hidden_states(indices, t)

    @property
    def device(self):
        """Device of the backbone's first parameter.

        Only the backbone is consulted, so this can differ from the heads'
        device if the adapter was moved separately. Raises StopIteration if
        the backbone has no parameters.
        """
        first_param = next(self.backbone.parameters())
        return first_param.device
|
|