# sabertoaster's picture
# Upload folder using huggingface_hub
# 4edc9aa verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC
from typing import Optional
from .source_ve import SourceVE, var_kld_loss
from .transport import Sampler, create_transport
from .velocity_net import VelocityNet
def _cfg_get(cfg, key, default):
if cfg is None:
return default
get_fn = getattr(cfg, "get", None)
if callable(get_fn):
try:
return get_fn(key, default)
except Exception:
pass
try:
return getattr(cfg, key)
except Exception:
return default
class BASECFM(nn.Module, ABC):
    """Base conditional flow-matching (CFM) model.

    Operates on tensors shaped (B, V, T): batch, feature/voxel dim, time steps.
    Subclasses are expected to populate ``estimator`` (velocity network),
    ``src_gen`` (learned source distribution) and ``transport`` before use
    (see ``CFM`` in this file).
    """

    def __init__(self, feat_dim: int, cfm_params):
        """Store scalar hyper-parameters read from ``cfm_params`` (mapping or namespace)."""
        super().__init__()
        self.feat_dim = feat_dim
        # Weight of the KL term added to the flow-matching loss.
        self.kld_weight = float(_cfg_get(cfm_params, "kld_weight", 3.0))
        # Target std used by the source distribution's KL regularizer.
        self.kld_target_std = float(_cfg_get(cfm_params, "kld_target_std", 1.0))
        # If True, the target velocity u_t is detached before the MSE loss.
        self.detach_ut = bool(_cfg_get(cfm_params, "detach_ut", False))
        # Default ODE solver name forwarded to the transport sampler.
        self.solver = str(_cfg_get(cfm_params, "solver", "euler"))
        # Filled in by subclasses.
        self.estimator: Optional[nn.Module] = None
        self.src_gen: Optional[nn.Module] = None
        self.transport = None

    @staticmethod
    def _flatten_bvt(x: torch.Tensor) -> tuple[torch.Tensor, int, int, int]:
        """Flatten (B, V, T) into (B*T, V) token rows; also return B, V, T.

        Raises ValueError for any non-3D input.
        """
        if x.ndim != 3:
            raise ValueError(
                f"Expected tensor with shape (B, V, T), got {tuple(x.shape)}"
            )
        bsz, voxels, time_steps = x.shape
        flat = x.transpose(1, 2).contiguous().reshape(bsz * time_steps, voxels)
        return flat, bsz, voxels, time_steps

    @staticmethod
    def _unflatten_bvt(x_flat: torch.Tensor, batch_size: int, time_steps: int) -> torch.Tensor:
        """Inverse of ``_flatten_bvt``: reshape (B*T, V) back to (B, V, T)."""
        return x_flat.reshape(batch_size, time_steps, -1).transpose(1, 2).contiguous()

    @staticmethod
    def _batch_indices(batch_size: int, time_steps: int, device: torch.device) -> torch.Tensor:
        """Return [0,...,0, 1,...,1, ...]: the batch index of each flattened token."""
        return torch.arange(batch_size, device=device).repeat_interleave(time_steps)

    def _prepare_context(self, mu: torch.Tensor):
        """Encode the conditioning signal and replicate it once per token.

        Returns ``(context_for_tokens, bsz, time_steps)`` where
        ``context_for_tokens`` is (B*T, T, H): every flattened token sees the
        full encoded context of its own batch element.
        """
        context = mu.transpose(1, 2).contiguous()  # (B, T, V)
        context_encoded = self.estimator.encode_context(context)  # (B, T, H)
        bsz, time_steps, _ = context_encoded.shape
        batch_idx = self._batch_indices(bsz, time_steps, mu.device)
        context_for_tokens = context_encoded[batch_idx]  # (B*T, T, H)
        return context_for_tokens, bsz, time_steps

    # ---- inference -----------------------------------------------------------
    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,  # (B, feat_dim, L)
        n_timesteps: int,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Sample using the configured default solver; see ``synthesise``."""
        return self.synthesise(
            mu=mu,
            n_timesteps=n_timesteps,
            solver_method=self.solver,
            temperature=temperature,
        )

    @torch.inference_mode()
    def synthesise(
        self,
        mu: torch.Tensor,
        n_timesteps: int = 50,
        solver_method: Optional[str] = None,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Generate a (B, feat_dim, L) sample by integrating the learned ODE.

        ``temperature`` scales the noise drawn around the learned source mean;
        with ``temperature <= 0`` or no variance head, the source point is the
        deterministic mean.
        """
        context_for_tokens, bsz, time_steps = self._prepare_context(mu)
        _, src_mu, log_var = self.src_gen(context_for_tokens)
        if log_var is not None and temperature > 0:
            # Reparameterized draw from the learned source distribution.
            std = torch.exp(0.5 * log_var)
            x0 = src_mu + torch.randn_like(src_mu) * std * temperature
        else:
            x0 = src_mu
        sampler = Sampler(self.transport)
        sample_fn = sampler.sample_ode(
            sampling_method=solver_method or self.solver,
            num_steps=n_timesteps,
        )

        def model_fn(x, t, **kwargs):
            # Velocity field conditioned on the pre-encoded context;
            # extra solver kwargs are intentionally ignored.
            return self.estimator(
                x=x,
                t=t,
                pre_encoded_context=context_for_tokens,
            )

        trajectory = sample_fn(x0, model_fn)
        pred_flat = trajectory[-1]  # final ODE state is the generated sample
        return self._unflatten_bvt(pred_flat, bsz, time_steps)

    # ---- training ------------------------------------------------------------
    def compute_loss(
        self,
        x1: torch.Tensor,  # (B, feat_dim, L)
        mu: torch.Tensor,  # (B, feat_dim, L)
    ) -> tuple:
        """Flow-matching MSE plus weighted KL regularizer on the source.

        Returns ``(loss_total, loss_dict)``; ``loss_dict`` holds python floats
        ("fm", "kld", "total") for logging.
        """
        if x1.shape != mu.shape:
            raise ValueError(
                f"x1 and mu must share shape (B, V, T), got {tuple(x1.shape)} and {tuple(mu.shape)}"
            )
        x1_flat, _, _, _ = self._flatten_bvt(x1)
        context_for_tokens, _, _ = self._prepare_context(mu)
        x0, src_gen_mu, log_var = (None, None, None)  # placeholder comment removed below
        x0, src_mu, log_var = self.src_gen(context_for_tokens)
        # Sample timesteps, then build the interpolated state x_t and target velocity u_t.
        t = self.transport.sample_timestep(x1_flat)
        t, xt, ut = self.transport.path_sampler.plan(t, x0, x1_flat)
        pred = self.estimator(
            x=xt,
            t=t,
            pre_encoded_context=context_for_tokens,
        )
        # Optionally stop gradients flowing through the target velocity.
        ut_target = ut.detach() if self.detach_ut else ut
        loss_fm = F.mse_loss(pred, ut_target)
        if log_var is not None:
            loss_kld = var_kld_loss(src_mu, log_var, target_std=self.kld_target_std)
        else:
            # No variance head: KL term is zero (kept as a tensor so .item() works).
            loss_kld = torch.tensor(0.0, device=x1.device, dtype=x1.dtype)
        loss_total = loss_fm + self.kld_weight * loss_kld
        loss_dict = {
            "fm": loss_fm.item(),
            "kld": loss_kld.item(),
            "total": loss_total.item(),
        }
        return loss_total, loss_dict
class CFM(BASECFM):
    """Concrete CFM that wires up the velocity estimator, source generator
    and transport from plain config dicts.

    Each ``*_params`` dict is optional; caller-supplied keys override the
    defaults below.
    """

    def __init__(
        self,
        feat_dim: int,
        cfm_params,
        velocity_net_params: Optional[dict] = None,
        source_ve_params: Optional[dict] = None,
        transport_params: Optional[dict] = None,
    ):
        super().__init__(feat_dim=feat_dim, cfm_params=cfm_params)

        # Velocity network: defaults first, caller overrides win.
        vn_cfg = {
            "output_dim": feat_dim,
            "modality_dims": [feat_dim],
            **(velocity_net_params or {}),
        }
        self.estimator = VelocityNet(**vn_cfg)

        # Shared hidden width; 256 mirrors the estimator's assumed default.
        hidden_dim = int(vn_cfg.get("hidden_dim", 256))

        # Source variational encoder sized to match the estimator's context.
        sve_cfg = {
            "context_dim": hidden_dim,
            "output_dim": feat_dim,
            "hidden_dim": hidden_dim,
            **(source_ve_params or {}),
        }
        self.src_gen = SourceVE(**sve_cfg)

        # Transport: linear path, velocity prediction, uniform time sampling
        # unless the caller says otherwise.
        tp_cfg = {
            "path_type": "Linear",
            "prediction": "velocity",
            "time_dist_type": "uniform",
            "time_dist_shift": float(_cfg_get(cfm_params, "time_dist_shift", 1.0)),
            **(transport_params or {}),
        }
        self.transport = create_transport(**tp_cfg)