# flow-matching/src/stage2/source_ve.py
# Uploaded by sabertoaster via huggingface_hub (commit 4edc9aa, verified)
import math
from typing import Optional
import torch
import torch.nn as nn
class SourceVE(nn.Module):
    """Perceiver-style variational encoder for condition-dependent source sampling.

    A fixed set of learned query tokens attends over an encoded context
    sequence through ``depth`` blocks of (cross-attention, self-attention,
    feed-forward), each with pre-LayerNorm and residual connections. The
    queries are then mean-pooled and mapped to the parameters of a diagonal
    Gaussian from which the source sample ``x0`` is drawn (reparameterization
    trick, training mode only).
    """

    def __init__(
        self,
        context_dim: int,
        output_dim: int,
        hidden_dim: int,
        depth: int = 4,
        num_heads: int = 8,
        num_queries: int = 16,
        dropout: float = 0.1,
        use_variational: bool = True,
        init_logvar: float = 1.0,
        fixed_std: Optional[float] = None,
    ):
        # context_dim:     feature dim D of the (B, T, D) context fed to forward().
        # output_dim:      dim of the emitted source sample / Gaussian parameters.
        # hidden_dim:      width of the query tokens and attention layers.
        # depth:           number of (cross-attn, self-attn, FFN) blocks.
        # num_queries:     number of learned latent query tokens.
        # use_variational: if False, forward() returns log_var=None and x0=mu.
        # init_logvar:     initial bias of the log-variance head (ignored when
        #                  fixed_std is set or use_variational is False).
        # fixed_std:       if set, use a constant std instead of a learned head.
        super().__init__()
        self.num_queries = num_queries
        self.hidden_dim = hidden_dim
        self.use_variational = use_variational
        self.fixed_std = fixed_std
        # Learned latent queries plus an additive positional embedding,
        # both initialized small (std 0.02).
        self.query_tokens = nn.Parameter(torch.randn(1, num_queries, hidden_dim) * 0.02)
        self.query_pos_emb = nn.Parameter(torch.randn(1, num_queries, hidden_dim) * 0.02)
        # Project context into the hidden width only when the dims differ.
        self.input_proj = (
            nn.Linear(context_dim, hidden_dim) if context_dim != hidden_dim else nn.Identity()
        )
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(
                nn.ModuleDict(
                    {
                        # Pre-norms for the cross-attention (queries and context).
                        "norm_q": nn.LayerNorm(hidden_dim),
                        "norm_kv": nn.LayerNorm(hidden_dim),
                        "cross_attn": nn.MultiheadAttention(
                            hidden_dim,
                            num_heads,
                            dropout=dropout,
                            batch_first=True,
                        ),
                        "norm_sa": nn.LayerNorm(hidden_dim),
                        "self_attn": nn.MultiheadAttention(
                            hidden_dim,
                            num_heads,
                            dropout=dropout,
                            batch_first=True,
                        ),
                        "norm_ffn": nn.LayerNorm(hidden_dim),
                        # Standard 4x-expansion MLP with GELU and dropout.
                        "ffn": nn.Sequential(
                            nn.Linear(hidden_dim, hidden_dim * 4),
                            nn.GELU(),
                            nn.Dropout(dropout),
                            nn.Linear(hidden_dim * 4, hidden_dim),
                            nn.Dropout(dropout),
                        ),
                    }
                )
            )
        # Final norm applied before pooling the queries.
        self.norm = nn.LayerNorm(hidden_dim)
        self.mean_head = nn.Linear(hidden_dim, output_dim)
        # A learned log-variance head exists only when variational sampling is
        # on AND no fixed std was requested; otherwise it stays None.
        if use_variational and fixed_std is None:
            self.log_var_head = nn.Linear(hidden_dim, output_dim)
        else:
            self.log_var_head = None
        self._init_weights(init_logvar)

    def _init_weights(self, init_logvar: float):
        """Truncated-normal init for all Linear layers; special-case the log-var head."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        if self.log_var_head is not None:
            # Override the generic init above: near-zero weights and a constant
            # bias so the head starts out predicting log_var ~= init_logvar.
            nn.init.normal_(self.log_var_head.weight, std=1e-4)
            nn.init.constant_(self.log_var_head.bias, init_logvar)

    def forward(
        self, context: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Map encoded context to source sample and distribution parameters.

        Args:
            context: Tensor of shape (B, T, D) — an encoded conditioning
                sequence with D == context_dim.

        Returns:
            Tuple ``(x0, mu, log_var)`` where each of ``x0`` and ``mu`` has
            shape (B, output_dim). ``log_var`` is None when the encoder is
            non-variational. ``x0`` is a reparameterized sample during
            training and equals ``mu`` at eval time (or when log_var is None).

        Raises:
            ValueError: if ``context`` is not 3-dimensional.
        """
        if context.ndim != 3:
            raise ValueError(
                f"Expected context with shape (B, T, D), got {tuple(context.shape)}"
            )
        batch_size = context.shape[0]
        kv = self.input_proj(context)
        # Broadcast the learned queries over the batch and add positions.
        queries = self.query_tokens.expand(batch_size, -1, -1) + self.query_pos_emb
        for layer in self.layers:
            # Cross-attention: queries read from the (normed) context.
            q_norm = layer["norm_q"](queries)
            kv_norm = layer["norm_kv"](kv)
            attn_out, _ = layer["cross_attn"](q_norm, kv_norm, kv_norm, need_weights=False)
            queries = queries + attn_out
            # Self-attention among the query tokens.
            sa_norm = layer["norm_sa"](queries)
            sa_out, _ = layer["self_attn"](sa_norm, sa_norm, sa_norm, need_weights=False)
            queries = queries + sa_out
            # Position-wise feed-forward, again pre-norm + residual.
            queries = queries + layer["ffn"](layer["norm_ffn"](queries))
        # Mean-pool the query tokens into a single vector per batch element.
        pooled = self.norm(queries).mean(dim=1)
        mu = self.mean_head(pooled)
        if self.use_variational:
            if self.fixed_std is not None:
                # Constant variance: log_var = log(fixed_std^2) everywhere.
                log_var = torch.full_like(mu, math.log(self.fixed_std**2))
            else:
                log_var = self.log_var_head(pooled)
        else:
            log_var = None
        if log_var is not None and self.training:
            # Reparameterization trick: x0 = mu + std * eps, eps ~ N(0, I).
            std = torch.exp(0.5 * log_var)
            x0 = mu + torch.randn_like(mu) * std
        else:
            # Deterministic at eval time (or when non-variational).
            x0 = mu
        return x0, mu, log_var
def var_kld_loss(
    mu: torch.Tensor,
    log_var: torch.Tensor,
    target_std: float = 1.0,
) -> torch.Tensor:
    """Variance-only KLD regularization used in CSFM.

    Penalizes the predicted per-dimension variance for straying from
    ``target_std**2`` while ignoring the mean term of the usual Gaussian
    KLD: ``-0.5 * mean(1 + log(var/s^2) - var/s^2)``. ``mu`` is accepted
    for interface symmetry but does not enter the loss.
    """
    scaled_var = log_var.exp()
    scaled_logvar = log_var
    if target_std != 1.0:
        # Rescale both terms so the reference distribution has unit variance.
        target_var = target_std**2
        scaled_var = scaled_var / target_var
        scaled_logvar = scaled_logvar - math.log(target_var)
    return -0.5 * torch.mean(1 + scaled_logvar - scaled_var)