| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from abc import ABC |
| from typing import Optional |
|
|
| from .source_ve import SourceVE, var_kld_loss |
| from .transport import Sampler, create_transport |
| from .velocity_net import VelocityNet |
|
|
|
|
| def _cfg_get(cfg, key, default): |
| if cfg is None: |
| return default |
|
|
| get_fn = getattr(cfg, "get", None) |
| if callable(get_fn): |
| try: |
| return get_fn(key, default) |
| except Exception: |
| pass |
|
|
| try: |
| return getattr(cfg, key) |
| except Exception: |
| return default |
|
|
class BASECFM(nn.Module, ABC):
    """Base class for conditional flow matching over (B, V, T) feature maps.

    Three collaborators are left as ``None`` here and must be installed by a
    subclass before :meth:`synthesise` or :meth:`compute_loss` is used:
    ``estimator`` (the velocity network), ``src_gen`` (the source-sample
    generator) and ``transport`` (path/ODE machinery).
    """

    def __init__(self, feat_dim: int, cfm_params):
        super().__init__()
        self.feat_dim = feat_dim
        # Hyper-parameters, read leniently from a possibly-absent config.
        self.kld_weight = float(_cfg_get(cfm_params, "kld_weight", 3.0))
        self.kld_target_std = float(_cfg_get(cfm_params, "kld_target_std", 1.0))
        self.detach_ut = bool(_cfg_get(cfm_params, "detach_ut", False))
        self.solver = str(_cfg_get(cfm_params, "solver", "euler"))
        # Populated by concrete subclasses.
        self.estimator: Optional[nn.Module] = None
        self.src_gen: Optional[nn.Module] = None
        self.transport = None

    @staticmethod
    def _flatten_bvt(x: torch.Tensor) -> tuple[torch.Tensor, int, int, int]:
        """Turn a (B, V, T) tensor into (B*T, V) token rows.

        Returns the flattened tensor together with the original batch,
        voxel and time sizes so callers can invert the operation later.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.ndim != 3:
            raise ValueError(
                f"Expected tensor with shape (B, V, T), got {tuple(x.shape)}"
            )
        n_batch, n_voxels, n_steps = x.shape
        # Time-major first, then collapse batch and time into one token axis.
        tokens = x.permute(0, 2, 1).contiguous().view(n_batch * n_steps, n_voxels)
        return tokens, n_batch, n_voxels, n_steps

    @staticmethod
    def _unflatten_bvt(x_flat: torch.Tensor, batch_size: int, time_steps: int) -> torch.Tensor:
        """Invert :meth:`_flatten_bvt`: (B*T, V) rows back to (B, V, T)."""
        stacked = x_flat.reshape(batch_size, time_steps, -1)
        return stacked.permute(0, 2, 1).contiguous()

    @staticmethod
    def _batch_indices(batch_size: int, time_steps: int, device: torch.device) -> torch.Tensor:
        """Batch index of every flattened token: [0]*T + [1]*T + ... ."""
        base = torch.arange(batch_size, device=device)
        return torch.repeat_interleave(base, time_steps)

    def _prepare_context(self, mu: torch.Tensor):
        """Encode the conditioning signal and expand it per flattened token.

        ``mu`` is (B, V, T); the encoder consumes it time-major, and the
        encoded context is then gathered with batch indices so that every
        token row carries the context of its originating sequence.
        """
        encoded = self.estimator.encode_context(mu.transpose(1, 2).contiguous())

        n_batch, n_steps, _ = encoded.shape
        token_ctx = encoded[self._batch_indices(n_batch, n_steps, mu.device)]

        return token_ctx, n_batch, n_steps

    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,
        n_timesteps: int,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Inference entry point; thin wrapper over :meth:`synthesise`."""
        return self.synthesise(
            mu=mu,
            n_timesteps=n_timesteps,
            solver_method=self.solver,
            temperature=temperature,
        )

    @torch.inference_mode()
    def synthesise(
        self,
        mu: torch.Tensor,
        n_timesteps: int = 50,
        solver_method: Optional[str] = None,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Generate (B, V, T) features conditioned on ``mu``.

        A source sample is drawn from ``src_gen`` — with ``temperature``
        scaling its standard deviation — and transported to the data
        distribution by integrating the learned velocity field for
        ``n_timesteps`` ODE steps.
        """
        token_ctx, n_batch, n_steps = self._prepare_context(mu)

        # Reparameterised source draw; deterministic when no variance is
        # provided or sampling is disabled via temperature <= 0.
        _, src_mu, log_var = self.src_gen(token_ctx)
        if log_var is not None and temperature > 0:
            std = torch.exp(0.5 * log_var)
            x0 = src_mu + torch.randn_like(src_mu) * std * temperature
        else:
            x0 = src_mu

        ode_solve = Sampler(self.transport).sample_ode(
            sampling_method=solver_method if solver_method else self.solver,
            num_steps=n_timesteps,
        )

        def velocity(x, t, **kwargs):
            # Context was pre-encoded once; reuse it at every ODE step.
            return self.estimator(x=x, t=t, pre_encoded_context=token_ctx)

        final_state = ode_solve(x0, velocity)[-1]
        return self._unflatten_bvt(final_state, n_batch, n_steps)

    def compute_loss(
        self,
        x1: torch.Tensor,
        mu: torch.Tensor,
    ) -> tuple:
        """Flow-matching + KLD training loss for targets ``x1`` given ``mu``.

        Returns the scalar total loss tensor and a dict of float components
        (keys ``"fm"``, ``"kld"``, ``"total"``).

        Raises:
            ValueError: if ``x1`` and ``mu`` shapes differ.
        """
        if x1.shape != mu.shape:
            raise ValueError(
                f"x1 and mu must share shape (B, V, T), got {tuple(x1.shape)} and {tuple(mu.shape)}"
            )

        target_tokens, _, _, _ = self._flatten_bvt(x1)
        token_ctx, _, _ = self._prepare_context(mu)

        # Source sample plus its variational statistics.
        x0, src_mu, log_var = self.src_gen(token_ctx)

        # Sample a timestep per token and interpolate along the path.
        t = self.transport.sample_timestep(target_tokens)
        t, xt, ut = self.transport.path_sampler.plan(t, x0, target_tokens)

        pred = self.estimator(x=xt, t=t, pre_encoded_context=token_ctx)

        # Optionally stop gradients flowing through the target velocity.
        target_vel = ut.detach() if self.detach_ut else ut
        loss_fm = F.mse_loss(pred, target_vel)

        # KLD regulariser on the source distribution; zero when the
        # generator is deterministic (no log-variance returned).
        if log_var is None:
            loss_kld = x1.new_zeros(())
        else:
            loss_kld = var_kld_loss(src_mu, log_var, target_std=self.kld_target_std)

        loss_total = loss_fm + self.kld_weight * loss_kld

        return loss_total, {
            "fm": loss_fm.item(),
            "kld": loss_kld.item(),
            "total": loss_total.item(),
        }
|
|
|
|
class CFM(BASECFM):
    """Concrete conditional flow matcher wiring together its submodules.

    Builds a :class:`VelocityNet` estimator, a :class:`SourceVE` source
    generator and a transport object, filling in defaults for any
    configuration keys the caller omits.
    """

    def __init__(
        self,
        feat_dim: int,
        cfm_params,
        velocity_net_params: Optional[dict] = None,
        source_ve_params: Optional[dict] = None,
        transport_params: Optional[dict] = None,
    ):
        super().__init__(feat_dim=feat_dim, cfm_params=cfm_params)

        # Velocity-field estimator; output width follows the feature dim.
        vn_kwargs = {**(velocity_net_params or {})}
        vn_kwargs.setdefault("output_dim", feat_dim)
        vn_kwargs.setdefault("modality_dims", [feat_dim])
        self.estimator = VelocityNet(**vn_kwargs)

        # Source generator; its context/hidden widths track the estimator's
        # hidden size (defaulting to 256 when the caller did not set one).
        hidden_dim = int(vn_kwargs.get("hidden_dim", 256))
        sve_kwargs = {**(source_ve_params or {})}
        sve_kwargs.setdefault("context_dim", hidden_dim)
        sve_kwargs.setdefault("output_dim", feat_dim)
        sve_kwargs.setdefault("hidden_dim", hidden_dim)
        self.src_gen = SourceVE(**sve_kwargs)

        # Transport: linear path with velocity prediction by default.
        tp_kwargs = {**(transport_params or {})}
        tp_kwargs.setdefault("path_type", "Linear")
        tp_kwargs.setdefault("prediction", "velocity")
        tp_kwargs.setdefault("time_dist_type", "uniform")
        tp_kwargs.setdefault(
            "time_dist_shift", float(_cfg_get(cfm_params, "time_dist_shift", 1.0))
        )
        self.transport = create_transport(**tp_kwargs)
|
|