import math
from typing import Optional

import torch
import torch.nn as nn


class SourceVE(nn.Module):
    """Perceiver-style variational encoder for condition-dependent source sampling."""

    def __init__(
        self,
        context_dim: int,
        output_dim: int,
        hidden_dim: int,
        depth: int = 4,
        num_heads: int = 8,
        num_queries: int = 16,
        dropout: float = 0.1,
        use_variational: bool = True,
        init_logvar: float = 1.0,
        fixed_std: Optional[float] = None,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.hidden_dim = hidden_dim
        self.use_variational = use_variational
        self.fixed_std = fixed_std

        self.query_tokens = nn.Parameter(torch.randn(1, num_queries, hidden_dim) * 0.02)
        self.query_pos_emb = nn.Parameter(torch.randn(1, num_queries, hidden_dim) * 0.02)

        self.input_proj = (
            nn.Linear(context_dim, hidden_dim) if context_dim != hidden_dim else nn.Identity()
        )
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(
                nn.ModuleDict(
                    {
                        "norm_q": nn.LayerNorm(hidden_dim),
                        "norm_kv": nn.LayerNorm(hidden_dim),
                        "cross_attn": nn.MultiheadAttention(
                            hidden_dim,
                            num_heads,
                            dropout=dropout,
                            batch_first=True,
                        ),
                        "norm_sa": nn.LayerNorm(hidden_dim),
                        "self_attn": nn.MultiheadAttention(
                            hidden_dim,
                            num_heads,
                            dropout=dropout,
                            batch_first=True,
                        ),
                        "norm_ffn": nn.LayerNorm(hidden_dim),
                        "ffn": nn.Sequential(
                            nn.Linear(hidden_dim, hidden_dim * 4),
                            nn.GELU(),
                            nn.Dropout(dropout),
                            nn.Linear(hidden_dim * 4, hidden_dim),
                            nn.Dropout(dropout),
                        ),
                    }
                )
            )
        self.norm = nn.LayerNorm(hidden_dim)
        self.mean_head = nn.Linear(hidden_dim, output_dim)
        if use_variational and fixed_std is None:
            self.log_var_head = nn.Linear(hidden_dim, output_dim)
        else:
            self.log_var_head = None
        self._init_weights(init_logvar)

    def _init_weights(self, init_logvar: float):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        if self.log_var_head is not None:
            # Tiny weights so the predicted log-variance starts near init_logvar for all inputs.
            nn.init.normal_(self.log_var_head.weight, std=1e-4)
            nn.init.constant_(self.log_var_head.bias, init_logvar)

    def forward(
        self, context: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Map encoded context to source sample and distribution parameters."""
        if context.ndim != 3:
            raise ValueError(
                f"Expected context with shape (B, T, D), got {tuple(context.shape)}"
            )
        batch_size = context.shape[0]
        kv = self.input_proj(context)
        queries = self.query_tokens.expand(batch_size, -1, -1) + self.query_pos_emb

        for layer in self.layers:
            # Cross-attention: latent queries attend to the encoded context.
            q_norm = layer["norm_q"](queries)
            kv_norm = layer["norm_kv"](kv)
            attn_out, _ = layer["cross_attn"](q_norm, kv_norm, kv_norm, need_weights=False)
            queries = queries + attn_out
            # Self-attention among the latent queries, then a feed-forward block.
            sa_norm = layer["norm_sa"](queries)
            sa_out, _ = layer["self_attn"](sa_norm, sa_norm, sa_norm, need_weights=False)
            queries = queries + sa_out
            queries = queries + layer["ffn"](layer["norm_ffn"](queries))

        pooled = self.norm(queries).mean(dim=1)
        mu = self.mean_head(pooled)
        if self.use_variational:
            if self.fixed_std is not None:
                log_var = torch.full_like(mu, math.log(self.fixed_std**2))
            else:
                log_var = self.log_var_head(pooled)
        else:
            log_var = None

        # Reparameterized sampling during training; deterministic mean at eval time.
        if log_var is not None and self.training:
            std = torch.exp(0.5 * log_var)
            x0 = mu + torch.randn_like(mu) * std
        else:
            x0 = mu
        return x0, mu, log_var


def var_kld_loss(
    mu: torch.Tensor,
    log_var: torch.Tensor,
    target_std: float = 1.0,
) -> torch.Tensor:
    """Variance-only KLD regularization used in CSFM."""
    # Only the variance term of the Gaussian KLD is penalized; the mean term is
    # omitted, so `mu` is accepted but not used here.
    var = log_var.exp()
    if target_std != 1.0:
        # Rescale so the penalty targets a variance of target_std**2 instead of 1.
        sigma2_target = target_std**2
        var = var / sigma2_target
        log_var = log_var - math.log(sigma2_target)
    return -0.5 * torch.mean(1 + log_var - var)
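

# Minimal usage sketch (illustrative, not part of the original module): the
# dimensions, batch size, and sequence length below are assumptions chosen only
# to show how SourceVE and var_kld_loss fit together in a training step.
if __name__ == "__main__":
    encoder = SourceVE(context_dim=512, output_dim=64, hidden_dim=256, depth=2)
    context = torch.randn(2, 100, 512)  # (B, T, D) encoded conditioning sequence

    # In training mode, x0 is a reparameterized sample; at eval time it equals mu.
    x0, mu, log_var = encoder(context)
    print(x0.shape, mu.shape)  # -> torch.Size([2, 64]) torch.Size([2, 64])

    if log_var is not None:
        kld = var_kld_loss(mu, log_var, target_std=1.0)
        print(f"variance-only KLD: {kld.item():.4f}")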