import math
from typing import Optional

import torch
import torch.nn as nn


class SourceVE(nn.Module):
    """Perceiver-style variational encoder for condition-dependent source sampling."""

    def __init__(
        self,
        context_dim: int,
        output_dim: int,
        hidden_dim: int,
        depth: int = 4,
        num_heads: int = 8,
        num_queries: int = 16,
        dropout: float = 0.1,
        use_variational: bool = True,
        init_logvar: float = 1.0,
        fixed_std: Optional[float] = None,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.hidden_dim = hidden_dim
        self.use_variational = use_variational
        self.fixed_std = fixed_std

        # Learned latent queries (Perceiver-style) plus a learned positional
        # embedding; both are broadcast over the batch in forward().
        self.query_tokens = nn.Parameter(torch.randn(1, num_queries, hidden_dim) * 0.02)
        self.query_pos_emb = nn.Parameter(torch.randn(1, num_queries, hidden_dim) * 0.02)

        # Project the context into the latent width only when the dims differ.
        self.input_proj = (
            nn.Linear(context_dim, hidden_dim)
            if context_dim != hidden_dim
            else nn.Identity()
        )

        # Each layer: pre-norm cross-attention (queries -> context), pre-norm
        # self-attention over the queries, then a pre-norm feed-forward block.
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(
                nn.ModuleDict(
                    {
                        "norm_q": nn.LayerNorm(hidden_dim),
                        "norm_kv": nn.LayerNorm(hidden_dim),
                        "cross_attn": nn.MultiheadAttention(
                            hidden_dim,
                            num_heads,
                            dropout=dropout,
                            batch_first=True,
                        ),
                        "norm_sa": nn.LayerNorm(hidden_dim),
                        "self_attn": nn.MultiheadAttention(
                            hidden_dim,
                            num_heads,
                            dropout=dropout,
                            batch_first=True,
                        ),
                        "norm_ffn": nn.LayerNorm(hidden_dim),
                        "ffn": nn.Sequential(
                            nn.Linear(hidden_dim, hidden_dim * 4),
                            nn.GELU(),
                            nn.Dropout(dropout),
                            nn.Linear(hidden_dim * 4, hidden_dim),
                            nn.Dropout(dropout),
                        ),
                    }
                )
            )

        self.norm = nn.LayerNorm(hidden_dim)
        self.mean_head = nn.Linear(hidden_dim, output_dim)
        # A log-variance head is only needed when the std is learned.
        if use_variational and fixed_std is None:
            self.log_var_head = nn.Linear(hidden_dim, output_dim)
        else:
            self.log_var_head = None
        self._init_weights(init_logvar)

    def _init_weights(self, init_logvar: float):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Start the log-variance head near a constant bias so early predicted
        # variances are close to exp(init_logvar) regardless of the input.
        if self.log_var_head is not None:
            nn.init.normal_(self.log_var_head.weight, std=1e-4)
            nn.init.constant_(self.log_var_head.bias, init_logvar)

    def forward(
        self, context: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Map encoded context to a source sample and its distribution parameters."""
        if context.ndim != 3:
            raise ValueError(
                f"Expected context with shape (B, T, D), got {tuple(context.shape)}"
            )
        batch_size = context.shape[0]

        kv = self.input_proj(context)
        queries = self.query_tokens.expand(batch_size, -1, -1) + self.query_pos_emb

        for layer in self.layers:
            # Cross-attend from the latent queries to the context tokens.
            q_norm = layer["norm_q"](queries)
            kv_norm = layer["norm_kv"](kv)
            attn_out, _ = layer["cross_attn"](q_norm, kv_norm, kv_norm, need_weights=False)
            queries = queries + attn_out
            # Mix information among the queries themselves.
            sa_norm = layer["norm_sa"](queries)
            sa_out, _ = layer["self_attn"](sa_norm, sa_norm, sa_norm, need_weights=False)
            queries = queries + sa_out
            queries = queries + layer["ffn"](layer["norm_ffn"](queries))

        # Mean-pool the latent queries into a single vector per sample.
        pooled = self.norm(queries).mean(dim=1)
        mu = self.mean_head(pooled)

        if self.use_variational:
            if self.fixed_std is not None:
                log_var = torch.full_like(mu, math.log(self.fixed_std**2))
            else:
                log_var = self.log_var_head(pooled)
        else:
            log_var = None

        # Reparameterization trick during training; deterministic mean at eval.
        if log_var is not None and self.training:
            std = torch.exp(0.5 * log_var)
            x0 = mu + torch.randn_like(mu) * std
        else:
            x0 = mu
        return x0, mu, log_var


def var_kld_loss(
    mu: torch.Tensor,
    log_var: torch.Tensor,
    target_std: float = 1.0,
) -> torch.Tensor:
    """Variance-only KLD regularization used in CSFM.

    Drops the mean term of the Gaussian KLD and penalizes only deviation of
    the predicted variance from target_std**2. `mu` is accepted for API
    symmetry but is unused.
    """
    var = log_var.exp()
    if target_std != 1.0:
        # Rescale so the penalty is measured against N(., target_std**2).
        sigma2_target = target_std**2
        var = var / sigma2_target
        log_var = log_var - math.log(sigma2_target)
    return -0.5 * torch.mean(1 + log_var - var)
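# --- Usage sketch (not part of the original module) ---
# A minimal smoke test showing the forward contract and the KLD regularizer.
# The hyperparameters below (context_dim=32, output_dim=8, hidden_dim=64,
# depth=2, num_heads=4) are arbitrary illustrative choices, not values from
# the original configuration.
if __name__ == "__main__":
    torch.manual_seed(0)
    encoder = SourceVE(context_dim=32, output_dim=8, hidden_dim=64, depth=2, num_heads=4)
    context = torch.randn(2, 10, 32)  # (B, T, D)

    # Training mode: x0 is sampled via the reparameterization trick.
    encoder.train()
    x0, mu, log_var = encoder(context)
    assert x0.shape == mu.shape == log_var.shape == (2, 8)

    # Variance-only KLD pulls the predicted std toward target_std.
    kld = var_kld_loss(mu, log_var, target_std=1.0)
    print("x0:", tuple(x0.shape), "kld:", kld.item())

    # Eval mode: the encoder returns the deterministic mean.
    encoder.eval()
    x0_eval, mu_eval, _ = encoder(context)
    assert torch.equal(x0_eval, mu_eval)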