import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC
from typing import Optional

from .source_ve import SourceVE, var_kld_loss
from .transport import Sampler, create_transport
from .velocity_net import VelocityNet


def _cfg_get(cfg, key, default):
    """Read ``key`` from a config object of unknown flavor.

    Handles ``None`` configs, dict-like configs (anything exposing a
    callable ``get``) and attribute-style configs; any lookup failure
    falls back to ``default``.
    """
    if cfg is None:
        return default
    # Dict-like configs (plain dict, DictConfig, ...) expose ``get``.
    getter = getattr(cfg, "get", None)
    if callable(getter):
        try:
            return getter(key, default)
        except Exception:
            pass
    # Attribute-style configs (e.g. argparse.Namespace).
    try:
        return getattr(cfg, key)
    except Exception:
        return default


class BASECFM(nn.Module, ABC):
    """Conditional flow-matching base with shared sampling/training logic.

    Subclasses are expected to populate ``estimator`` (velocity network),
    ``src_gen`` (learned source distribution) and ``transport`` before use.
    """

    def __init__(self, feat_dim: int, cfm_params):
        """Store dimensionality and scalar hyper-parameters from ``cfm_params``."""
        super().__init__()
        self.feat_dim = feat_dim
        # Scalar training/inference knobs, read defensively from the config.
        self.kld_weight = float(_cfg_get(cfm_params, "kld_weight", 3.0))
        self.kld_target_std = float(_cfg_get(cfm_params, "kld_target_std", 1.0))
        self.detach_ut = bool(_cfg_get(cfm_params, "detach_ut", False))
        self.solver = str(_cfg_get(cfm_params, "solver", "euler"))
        # Filled in by concrete subclasses.
        self.estimator: Optional[nn.Module] = None
        self.src_gen: Optional[nn.Module] = None
        self.transport = None

    @staticmethod
    def _flatten_bvt(x: torch.Tensor) -> tuple[torch.Tensor, int, int, int]:
        """Collapse a (B, V, T) tensor into (B*T, V) token rows.

        Also returns the original batch, voxel and time sizes so the
        reshape can be undone by :meth:`_unflatten_bvt`.
        """
        if x.ndim != 3:
            raise ValueError(
                f"Expected tensor with shape (B, V, T), got {tuple(x.shape)}"
            )
        n_batch, n_voxels, n_frames = x.shape
        tokens = x.transpose(1, 2).contiguous().reshape(n_batch * n_frames, n_voxels)
        return tokens, n_batch, n_voxels, n_frames

    @staticmethod
    def _unflatten_bvt(x_flat: torch.Tensor, batch_size: int, time_steps: int) -> torch.Tensor:
        """Inverse of :meth:`_flatten_bvt`: (B*T, V) back to (B, V, T)."""
        return x_flat.reshape(batch_size, time_steps, -1).transpose(1, 2).contiguous()

    @staticmethod
    def _batch_indices(batch_size: int, time_steps: int, device: torch.device) -> torch.Tensor:
        """Batch index per flattened token: [0]*T + [1]*T + ... + [B-1]*T."""
        return torch.arange(batch_size, device=device).repeat_interleave(time_steps)

    def _prepare_context(self, mu: torch.Tensor):
        """Encode the conditioning signal and replicate it per token.

        ``mu`` is (B, V, T); the estimator's context encoder consumes
        (B, T, V) and emits (B, T, H). Each flattened token then receives
        its sample's entire encoded sequence, i.e. a (B*T, T, H) tensor.
        """
        seq_first = mu.transpose(1, 2).contiguous()  # (B, T, V)
        encoded = self.estimator.encode_context(seq_first)  # (B, T, H)
        n_batch, n_frames, _ = encoded.shape
        gather_idx = self._batch_indices(n_batch, n_frames, mu.device)
        per_token_ctx = encoded[gather_idx]  # (B*T, T, H)
        return per_token_ctx, n_batch, n_frames

    # ---- inference -----------------------------------------------------------
    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,  # (B, feat_dim, L)
        n_timesteps: int,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Convenience entry point: sample using the configured default solver."""
        return self.synthesise(
            mu=mu,
            n_timesteps=n_timesteps,
            solver_method=self.solver,
            temperature=temperature,
        )

    @torch.inference_mode()
    def synthesise(
        self,
        mu: torch.Tensor,
        n_timesteps: int = 50,
        solver_method: Optional[str] = None,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Generate a (B, feat_dim, L) sample by integrating the learned ODE.

        The source point x0 is drawn from the learned source distribution;
        ``temperature`` rescales its spread (0 disables the noise entirely).
        """
        per_token_ctx, n_batch, n_frames = self._prepare_context(mu)
        _, src_mu, log_var = self.src_gen(per_token_ctx)

        # Reparameterized draw from N(src_mu, temperature^2 * exp(log_var)).
        if log_var is not None and temperature > 0:
            sigma = torch.exp(0.5 * log_var)
            x0 = src_mu + torch.randn_like(src_mu) * sigma * temperature
        else:
            x0 = src_mu

        ode_solver = Sampler(self.transport).sample_ode(
            sampling_method=solver_method or self.solver,
            num_steps=n_timesteps,
        )

        def model_fn(x, t, **kwargs):
            # The sampler may pass extra kwargs; only x and t matter here.
            return self.estimator(
                x=x,
                t=t,
                pre_encoded_context=per_token_ctx,
            )

        states = ode_solver(x0, model_fn)
        # Last state of the trajectory is the generated sample.
        return self._unflatten_bvt(states[-1], n_batch, n_frames)

    # ---- training ------------------------------------------------------------
    def compute_loss(
        self,
        x1: torch.Tensor,  # (B, feat_dim, L)
        mu: torch.Tensor,  # (B, feat_dim, L)
    ) -> tuple:
        """Flow-matching MSE plus weighted KLD on the source distribution.

        Returns ``(loss_total, loss_dict)``; ``loss_dict`` carries scalar
        components ("fm", "kld", "total") for logging.
        """
        if x1.shape != mu.shape:
            raise ValueError(
                f"x1 and mu must share shape (B, V, T), got {tuple(x1.shape)} and {tuple(mu.shape)}"
            )
        x1_flat, _, _, _ = self._flatten_bvt(x1)
        per_token_ctx, _, _ = self._prepare_context(mu)

        x0, src_mu, log_var = self.src_gen(per_token_ctx)
        t = self.transport.sample_timestep(x1_flat)
        t, xt, ut = self.transport.path_sampler.plan(t, x0, x1_flat)

        v_pred = self.estimator(
            x=xt,
            t=t,
            pre_encoded_context=per_token_ctx,
        )

        # Optionally stop gradients flowing through the target velocity.
        if self.detach_ut:
            ut = ut.detach()
        loss_fm = F.mse_loss(v_pred, ut)

        if log_var is not None:
            loss_kld = var_kld_loss(src_mu, log_var, target_std=self.kld_target_std)
        else:
            loss_kld = torch.tensor(0.0, device=x1.device, dtype=x1.dtype)

        loss_total = loss_fm + self.kld_weight * loss_kld
        loss_dict = {
            "fm": loss_fm.item(),
            "kld": loss_kld.item(),
            "total": loss_total.item(),
        }
        return loss_total, loss_dict


class CFM(BASECFM):
    """Concrete CFM wiring together velocity net, source VE and transport."""

    def __init__(
        self,
        feat_dim: int,
        cfm_params,
        velocity_net_params: Optional[dict] = None,
        source_ve_params: Optional[dict] = None,
        transport_params: Optional[dict] = None,
    ):
        super().__init__(feat_dim=feat_dim, cfm_params=cfm_params)

        vn_cfg = dict(velocity_net_params or {})
        vn_cfg.setdefault("output_dim", feat_dim)
        vn_cfg.setdefault("modality_dims", [feat_dim])
        self.estimator = VelocityNet(**vn_cfg)

        # The source generator consumes the velocity net's context embedding,
        # so its context_dim mirrors the velocity net's hidden size.
        hidden_dim = int(vn_cfg.get("hidden_dim", 256))
        sve_cfg = dict(source_ve_params or {})
        sve_cfg.setdefault("context_dim", hidden_dim)
        sve_cfg.setdefault("output_dim", feat_dim)
        sve_cfg.setdefault("hidden_dim", hidden_dim)
        self.src_gen = SourceVE(**sve_cfg)

        tp_cfg = dict(transport_params or {})
        tp_cfg.setdefault("path_type", "Linear")
        tp_cfg.setdefault("prediction", "velocity")
        tp_cfg.setdefault("time_dist_type", "uniform")
        tp_cfg.setdefault("time_dist_shift", float(_cfg_get(cfm_params, "time_dist_shift", 1.0)))
        self.transport = create_transport(**tp_cfg)