import math
from typing import Optional

import torch
import torch.nn as nn

class SourceVE(nn.Module):
    """Perceiver-style variational encoder for condition-dependent source sampling."""

    def __init__(
        self,
        context_dim: int,
        output_dim: int,
        hidden_dim: int,
        depth: int = 4,
        num_heads: int = 8,
        num_queries: int = 16,
        dropout: float = 0.1,
        use_variational: bool = True,
        init_logvar: float = 1.0,
        fixed_std: Optional[float] = None,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.hidden_dim = hidden_dim
        self.use_variational = use_variational
        self.fixed_std = fixed_std

        # Learned latent queries plus their positional embeddings.
        self.query_tokens = nn.Parameter(torch.randn(1, num_queries, hidden_dim) * 0.02)
        self.query_pos_emb = nn.Parameter(torch.randn(1, num_queries, hidden_dim) * 0.02)

        # Project the context to hidden_dim only when the widths differ.
        self.input_proj = (
            nn.Linear(context_dim, hidden_dim) if context_dim != hidden_dim else nn.Identity()
        )

        # Each block: cross-attention (queries -> context), self-attention over
        # the queries, and a feed-forward network, all pre-norm with residuals.
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(
                nn.ModuleDict(
                    {
                        "norm_q": nn.LayerNorm(hidden_dim),
                        "norm_kv": nn.LayerNorm(hidden_dim),
                        "cross_attn": nn.MultiheadAttention(
                            hidden_dim,
                            num_heads,
                            dropout=dropout,
                            batch_first=True,
                        ),
                        "norm_sa": nn.LayerNorm(hidden_dim),
                        "self_attn": nn.MultiheadAttention(
                            hidden_dim,
                            num_heads,
                            dropout=dropout,
                            batch_first=True,
                        ),
                        "norm_ffn": nn.LayerNorm(hidden_dim),
                        "ffn": nn.Sequential(
                            nn.Linear(hidden_dim, hidden_dim * 4),
                            nn.GELU(),
                            nn.Dropout(dropout),
                            nn.Linear(hidden_dim * 4, hidden_dim),
                            nn.Dropout(dropout),
                        ),
                    }
                )
            )

        self.norm = nn.LayerNorm(hidden_dim)
        self.mean_head = nn.Linear(hidden_dim, output_dim)

        # A log-variance head is only needed when the variance is learned.
        if use_variational and fixed_std is None:
            self.log_var_head = nn.Linear(hidden_dim, output_dim)
        else:
            self.log_var_head = None

        self._init_weights(init_logvar)

    def _init_weights(self, init_logvar: float):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

        # Re-initialize the log-variance head last: near-zero weights and a
        # constant bias make it output roughly init_logvar at the start.
        if self.log_var_head is not None:
            nn.init.normal_(self.log_var_head.weight, std=1e-4)
            nn.init.constant_(self.log_var_head.bias, init_logvar)

    def forward(
        self, context: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Map encoded context to a source sample and distribution parameters.

        Returns ``(x0, mu, log_var)``: ``x0`` is a reparameterized sample
        during variational training and equals ``mu`` otherwise; ``log_var``
        is None in the non-variational configuration.
        """
        if context.ndim != 3:
            raise ValueError(
                f"Expected context with shape (B, T, D), got {tuple(context.shape)}"
            )

        batch_size = context.shape[0]

        kv = self.input_proj(context)
        queries = self.query_tokens.expand(batch_size, -1, -1) + self.query_pos_emb

        for layer in self.layers:
            # Cross-attention: latent queries attend to the context tokens.
            q_norm = layer["norm_q"](queries)
            kv_norm = layer["norm_kv"](kv)
            attn_out, _ = layer["cross_attn"](q_norm, kv_norm, kv_norm, need_weights=False)
            queries = queries + attn_out

            # Self-attention among the latent queries.
            sa_norm = layer["norm_sa"](queries)
            sa_out, _ = layer["self_attn"](sa_norm, sa_norm, sa_norm, need_weights=False)
            queries = queries + sa_out

            queries = queries + layer["ffn"](layer["norm_ffn"](queries))

        # Mean-pool the latent queries into one vector per sample.
        pooled = self.norm(queries).mean(dim=1)
        mu = self.mean_head(pooled)

        if self.use_variational:
            if self.fixed_std is not None:
                # log sigma^2 = 2 log sigma, broadcast to the shape of mu.
                log_var = torch.full_like(mu, math.log(self.fixed_std**2))
            else:
                log_var = self.log_var_head(pooled)
        else:
            log_var = None

        # Reparameterization trick: sample only during training; at eval time
        # return the mean for deterministic behavior.
        if log_var is not None and self.training:
            std = torch.exp(0.5 * log_var)
            x0 = mu + torch.randn_like(mu) * std
        else:
            x0 = mu

        return x0, mu, log_var
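

# Usage sketch for SourceVE (illustrative only; the sizes below are
# assumptions, not taken from any particular training configuration):
#
#     encoder = SourceVE(context_dim=512, output_dim=64, hidden_dim=256)
#     ctx = torch.randn(8, 100, 512)   # (B, T, D) encoded condition
#     x0, mu, log_var = encoder(ctx)   # x0 and mu have shape (8, 64)
#
# In eval mode, or with use_variational=False, x0 equals mu exactly.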


def var_kld_loss(
    mu: torch.Tensor,
    log_var: torch.Tensor,
    target_std: float = 1.0,
) -> torch.Tensor:
    """Variance-only KLD regularization used in CSFM.

    Computes KL(N(mu, sigma^2) || N(mu, target_std^2)); since both Gaussians
    share the same mean, this reduces to
    -0.5 * mean(1 + log(sigma^2 / s^2) - sigma^2 / s^2) with s = target_std,
    so only the predicted variance is pulled toward s^2. ``mu`` is accepted
    for API symmetry but does not enter the loss.
    """
    var = log_var.exp()
    if target_std != 1.0:
        # Rescale so the target becomes a unit-variance Gaussian.
        sigma2_target = target_std**2
        var = var / sigma2_target
        log_var = log_var - math.log(sigma2_target)

    return -0.5 * torch.mean(1 + log_var - var)
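

if __name__ == "__main__":
    # Minimal smoke test (a sketch; all sizes here are illustrative
    # assumptions): draw a condition-dependent source sample and apply the
    # variance-only KLD regularizer.
    torch.manual_seed(0)
    encoder = SourceVE(context_dim=512, output_dim=64, hidden_dim=256)
    encoder.train()

    context = torch.randn(4, 50, 512)  # (B, T, D) encoded condition
    x0, mu, log_var = encoder(context)
    assert x0.shape == mu.shape == (4, 64)
    assert log_var is not None and log_var.shape == (4, 64)

    # The KLD is zero exactly when the predicted variance equals target_std**2.
    kld = var_kld_loss(mu, log_var, target_std=1.0)
    print(f"x0: {tuple(x0.shape)}, var-KLD: {kld.item():.4f}")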