# A2D2 / model / model_wrapper.py
# Author: Sophia - initial commit 8019be0 (3.55 kB)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
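# ------------------------------------------------------------
# Additional per-position scoring heads.
# Each head outputs one raw logit per position; no sigmoid is applied
# inside forward().
# ------------------------------------------------------------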
class RemaskingHead(nn.Module):
    """Small MLP head producing one scalar score per input position.

    Pipeline: LayerNorm -> Linear(hidden, hidden) -> GELU -> Linear(hidden, 1).
    The output is the raw result of the final linear layer; no sigmoid (or
    any other squashing) is applied here.

    Args:
        hidden_size: Size of the last dimension of the input features.
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        # NOTE: the attribute names below become state_dict keys
        # ("norm.weight", "proj1.weight", ...); renaming them would break
        # loading of previously saved checkpoints.
        self.norm = nn.LayerNorm(hidden_size)
        self.proj1 = nn.Linear(hidden_size, hidden_size)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score each position of ``x``.

        Args:
            x: Features whose last dimension equals ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1 (unnormalized scores).
        """
        normalized = self.norm(x)
        projected = self.act(self.proj1(normalized))
        return self.proj2(projected)
class InsertionQualityHead(nn.Module):
    """Small MLP head producing one scalar quality score per input position.

    Pipeline: LayerNorm -> Linear(hidden, hidden) -> GELU -> Linear(hidden, 1).
    The output is the raw result of the final linear layer; no sigmoid (or
    any other squashing) is applied here.

    Args:
        hidden_size: Size of the last dimension of the input features.
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        # NOTE: the attribute names below become state_dict keys
        # ("norm.weight", "proj1.weight", ...); renaming them would break
        # loading of previously saved checkpoints.
        self.norm = nn.LayerNorm(hidden_size)
        self.proj1 = nn.Linear(hidden_size, hidden_size)
        self.act = nn.GELU()
        self.proj2 = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score each position of ``x``.

        Args:
            x: Features whose last dimension equals ``hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1 (unnormalized scores).
        """
        normalized = self.norm(x)
        projected = self.act(self.proj1(normalized))
        return self.proj2(projected)
class RemaskingAnyOrder(nn.Module):
    """Remasking adapter for AnyOrderMaskInsertionFlow models.

    Wraps a backbone and adds a ``RemaskingHead`` (and, optionally, an
    ``InsertionQualityHead``) on top of the backbone's feature output.

    The backbone is deliberately NOT registered as a submodule. As a
    consequence it is excluded from this module's ``parameters()``,
    ``state_dict()``, ``.to(...)``, ``.train()`` / ``.eval()`` traversal;
    callers must manage the backbone (device, mode, checkpointing)
    themselves.

    Args:
        backbone: Model called as ``backbone(indices, t, return_features=True)``
            and expected to return ``(prediction, features)`` where
            ``prediction`` exposes a ``token_logits`` attribute.
        d_model: Hidden size fed to the head(s).
        insertion_planner: When True, an ``insertion_head`` is also created
            and its output is included in ``forward``'s result.
    """

    def __init__(self, backbone: nn.Module, d_model: int, insertion_planner: bool = False):
        super().__init__()
        # Store backbone as a plain attribute to avoid a circular reference in
        # the module tree: object.__setattr__ bypasses nn.Module.__setattr__,
        # which would otherwise register the backbone as a child module.
        object.__setattr__(self, "backbone", backbone)
        self.d_model = d_model
        self.insertion_planner = insertion_planner
        self.remasking_head = RemaskingHead(d_model)
        # insertion_head only exists when the planner is enabled, so that
        # checkpoints of planner-less adapters carry no extra keys.
        if insertion_planner:
            self.insertion_head = InsertionQualityHead(d_model)

    def forward(self, indices: torch.Tensor, t: torch.Tensor, **kwargs):
        """Forward pass for remasking training.

        Args:
            indices: Token indices [batch_size, seq_len].
            t: Timesteps [batch_size].
            **kwargs: Additional arguments (accepted and ignored, for
                call-site compatibility).

        Returns:
            Dict with ``'logits'`` and ``'remasking_conf'`` keys, plus
            ``'insertion_conf'`` when ``insertion_planner`` is enabled.
        """
        # Single backbone pass returning both prediction and post-block
        # features. Per the original author's note, features has shape
        # [B, L+1, hidden]; dropping the last position along dim 1 gives the
        # [B, L, hidden] slice that get_hidden_states would have returned.
        prediction, features = self.backbone(indices, t, return_features=True)
        hidden_states = features[:, :-1]

        result = {
            "logits": prediction.token_logits,
            "remasking_conf": self.remasking_head(hidden_states),
        }
        if self.insertion_planner:
            result["insertion_conf"] = self.insertion_head(hidden_states)
        return result

    def get_hidden_states(self, indices: torch.Tensor, t: torch.Tensor):
        """Delegate to ``backbone.get_hidden_states`` for adapter training.

        Args:
            indices: Token indices [batch_size, seq_len].
            t: Timesteps [batch_size].

        Returns:
            Whatever ``backbone.get_hidden_states(indices, t)`` returns;
            expected to be a tuple of (token_logits, hidden_states,
            conditioning) -- defined by the backbone, not checked here.
        """
        return self.backbone.get_hidden_states(indices, t)

    @property
    def device(self):
        """Device of the backbone's first parameter.

        Falls back to this adapter's own parameters when the backbone has no
        parameters at all, instead of leaking a bare ``StopIteration``.
        """
        first_param = next(self.backbone.parameters(), None)
        if first_param is None:
            # Backbone is parameter-free; the heads always own parameters.
            first_param = next(self.parameters())
        return first_param.device