import os import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # add repo root to path import torch from dataclasses import dataclass from typing import Any, Literal, Optional import numpy as np import pandas as pd from lightning_modules.mdm import MaskedDiffusionModule @dataclass class SamplingTraceDatapoint: t: float event_type: Literal["insertion", "change"] position: int token: Any @dataclass class SamplingResult: samples: torch.Tensor # Trace is supposed to be processed sequentially as updates are not commutative trace: Optional[list[SamplingTraceDatapoint]] def __iter__(self): yield from [self.samples, self.trace] # Sample from categorical distribution for each position using the transition probabilities def _sample_tokens(probs: torch.Tensor) -> torch.Tensor: """Sample one token per position from probability distribution. Args: probs: [batch_size, seq_len, vocab_size] transition probabilities Returns: [batch_size, seq_len] sampled token indices """ batch_size, seq_len, vocab_size = probs.shape flat_probs = probs.view(-1, vocab_size) samples = torch.multinomial(flat_probs, num_samples=1) return samples.view(batch_size, seq_len) def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor: batch_size, seq_len, vocab_size = probs.shape gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, seq_len, vocab_size) + 1e-10) + 1e-10)).to(probs.device) noisy_logits = torch.log(probs + 1e-10) + gumbel_noise # add Gumbel noise to log probabilities # select the highest score (most likely category after Gumbel noise) samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long) return samples.view(batch_size, seq_len) @torch.no_grad() def mdm_euler_sampling( model: MaskedDiffusionModule, steps: int, mask: int, pad: int, batch_size: int, max_length: int, return_trace: bool = False, temperature: float = 1.0, ): assert not return_trace, "Trace is not yet implemented in MDM Euler sampling" device = next(model.parameters()).device xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device) dt = 1.0 / steps t = torch.zeros(batch_size, device=device) for i in range(steps): print("i-th sampling step") # ——— predict and convert rates ——— pred_rate = model(xt, t) pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # ——— unmask step (Euler) ——— mask_pos = (xt == mask).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) _xt = xt.clone() trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) # Apply temperature scaling if temperature != 1.0: logits = torch.log(trans_prob + 1e-10) / temperature trans_prob = torch.softmax(logits, dim=-1) if i == steps - 1: print("Final step, removing mask token from sampling") trans_prob[mask_pos + (mask,)] = 0.0 print(trans_prob[mask_pos + (mask,)]) new_xt = _sample_tokens(trans_prob) new_xt = torch.where(xt != mask, xt, new_xt) xt = new_xt t = t + dt return xt, [] @torch.no_grad() def any_order_mask_insertion_euler_sampling( model: torch.nn.Module, steps: int, mask: int, pad: int, batch_size: int, max_length: int, return_trace: bool = False, temperature: float = 1.0, ) -> SamplingResult: device = next(model.parameters()).device # 1) Initialize all‑pad sequence and trace xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) sampling_trace = [] dt = 1.0 / steps t = torch.zeros(batch_size, device=device) # Precompute row indices for scatter batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, max_length) ) pos_idx_L = ( torch.arange(max_length, device=device) .view(1, max_length) .expand(batch_size, max_length) ) sampling_trace = [[] for _ in range(batch_size)] if return_trace else None for i in range(steps): # ——— predict and convert rates ——— pred_rate = model(xt, t) pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # (B, L, V) len_rate = pred_rate.length_rate # (B, L+1) # ——— unmask step (Euler) ——— mask_pos = (xt == mask).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) # add “stay” probability _xt = xt.clone() _xt[xt == pad] = mask trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) if i == steps - 1: print("Final step, removing mask token from sampling") trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step # renormalize probabilities to ensure they sum to 1 prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) if mask_has_zero_prob.any(): # create uniform distribution over valid tokens (excluding mask and pad) uniform_prob = torch.zeros_like(trans_prob[0]) uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob else: # normalize to sum to 1 trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum new_xt = _sample_tokens(trans_prob) new_xt[xt == pad] = pad new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) if i != steps - 1: # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ——— ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) xt_len = xt.ne(pad).sum(dim=1) # (B,) gaps = torch.arange(max_length + 1, device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= max_length ext = ext * valid.view(batch_size, 1).long() ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) xt_tmp = torch.full_like(xt, pad) mask_fill = pos_idx_L < new_len.view(batch_size, 1) xt_tmp[mask_fill] = mask new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] else: xt_tmp = new_xt if return_trace: # Check if the token was changed for batch_idx in range(batch_size): for j in range(max_length): if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="change", position=j, token=new_xt[batch_idx, j].item(), ) ) # Check if a new token was inserted for j in range(max_length): id = max_length - j - 1 if ext[batch_idx, id]: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="insertion", position=id, token=mask, ) ) xt = xt_tmp t = t + dt return xt, sampling_trace @torch.no_grad() def batch_mcts_reverse_step( xt: torch.Tensor, t: torch.Tensor, dt: float, model: torch.nn.Module, pretrained: torch.nn.Module, mask: int, pad: int, batch_size: int, max_length: int, last_step: bool = False, temperature: float = 1.0, ) -> SamplingResult: device = next(model.parameters()).device xt = xt.repeat(batch_size, 1) # squeeze to remove extra dimensions, then expand to batch_size t = t.squeeze().expand(batch_size) # precompute row indices for scatter batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, max_length) ) pos_idx_L = ( torch.arange(max_length, device=device) .view(1, max_length) .expand(batch_size, max_length) ) # ——— predict and convert rates ——— pred_rate = model(xt, t) pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # (B, L, V) len_rate = pred_rate.length_rate # (B, L+1) # ——— get pretrained model rates for log_rnd computation ——— pretrained_pred = pretrained(xt, t) pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t) pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V) pretrained_len_rate = pretrained_rate.length_rate # (B, L+1) # ——— unmask step (Euler) ——— mask_pos = (xt == mask).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) # Same for pretrained pretrained_unmask_rate[xt != mask] = 0 pretrained_unmask_rate[mask_pos + (mask,)] = 0 pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1) pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0) # add “stay” probability _xt = xt.clone() _xt[xt == pad] = mask trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) pretrained_trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype), ) if last_step: print("Final step, removing mask token from sampling") trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step # renormalize probabilities to ensure they sum to 1 prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) if mask_has_zero_prob.any(): # create uniform distribution over valid tokens (excluding mask and pad) uniform_prob = torch.zeros_like(trans_prob[0]) uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob else: # normalize to sum to 1 trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum new_xt = _sample_tokens(trans_prob) new_xt[xt == pad] = pad new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) # ——— compute log probabilities for RND ——— lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad) log_policy_step = (lp * changed_mask).sum(dim=1) log_pretrained_step = (lp_pre * changed_mask).sum(dim=1) log_rnd = log_pretrained_step - log_policy_step # (B,) if not last_step: # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ——— ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1) pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1) # log P(ext; λ) = ext*log(λ) - λ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,) log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,) log_insert_diff = log_pretrained_insert - log_policy_insert # (B,) log_rnd += log_insert_diff log_pretrained_step += log_pretrained_insert log_policy_step += log_policy_insert xt_len = xt.ne(pad).sum(dim=1) # (B,) seq_dim = ext.size(1) # Use actual ext dimension to avoid mismatch gaps = torch.arange(seq_dim, device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= max_length ext = ext * valid.view(batch_size, 1).long() ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) xt_tmp = torch.full_like(xt, pad) mask_fill = pos_idx_L < new_len.view(batch_size, 1) xt_tmp[mask_fill] = mask new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] else: xt_tmp = new_xt return xt_tmp, log_rnd, log_policy_step, log_pretrained_step @torch.no_grad() def mcts_reverse_step( xt: torch.Tensor, t: torch.Tensor, dt: float, model: torch.nn.Module, pretrained: torch.nn.Module, mask: int, pad: int, max_length: int, last_step: bool = False, temperature: float = 1.0, ) -> SamplingResult: device = next(model.parameters()).device batch_size = xt.size(0) # precompute row indices for scatter batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, max_length) ) pos_idx_L = ( torch.arange(max_length, device=device) .view(1, max_length) .expand(batch_size, max_length) ) # ——— predict and convert rates ——— pred_rate = model(xt, t) pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # (B, L, V) len_rate = pred_rate.length_rate # (B, L+1) # ——— get pretrained model rates for log_rnd computation ——— pretrained_pred = pretrained(xt, t) pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t) pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V) pretrained_len_rate = pretrained_rate.length_rate # (B, L+1) # ——— unmask step (Euler) ——— mask_pos = (xt == mask).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) # same for pretrained pretrained_unmask_rate[xt != mask] = 0 pretrained_unmask_rate[mask_pos + (mask,)] = 0 pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1) pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0) # add “stay” probability _xt = xt.clone() _xt[xt == pad] = mask trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) pretrained_trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype), ) if last_step: print("Final step, removing mask token from sampling") trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step # renormalize probabilities to ensure they sum to 1 prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) # avoid division by zero - if all probs are 0, use uniform distribution (excluding mask and pad) mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) if mask_has_zero_prob.any(): # create uniform distribution over valid tokens (excluding mask and pad) uniform_prob = torch.zeros_like(trans_prob[0]) uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob else: # normalize to sum to 1 trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum new_xt = _sample_tokens(trans_prob) new_xt[xt == pad] = pad new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) # ——— compute log probabilities for RND ——— lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad) log_policy_step = (lp * changed_mask).sum(dim=1) log_pretrained_step = (lp_pre * changed_mask).sum(dim=1) log_rnd = log_pretrained_step - log_policy_step # (B,) if not last_step: # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ——— ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1) pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1) # log P(ext; λ) = ext*log(λ) - λ log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,) log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,) log_insert_diff = log_pretrained_insert - log_policy_insert # (B,) log_rnd += log_insert_diff log_pretrained_step += log_pretrained_insert log_policy_step += log_policy_insert xt_len = xt.ne(pad).sum(dim=1) # (B,) seq_dim = ext.size(1) # Use actual ext dimension to avoid mismatch gaps = torch.arange(seq_dim, device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= max_length ext = ext * valid.view(batch_size, 1).long() ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) xt_tmp = torch.full_like(xt, pad) mask_fill = pos_idx_L < new_len.view(batch_size, 1) xt_tmp[mask_fill] = mask new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] else: xt_tmp = new_xt return xt_tmp, log_rnd, log_policy_step, log_pretrained_step @torch.no_grad() def any_order_euler_sampling_with_schedule( model: torch.nn.Module, time_schedule: torch.Tensor, mask: int, pad: int, batch_size: int, max_length: int, return_trace: bool = False, temperature: float = 1.0, ) -> SamplingResult: device = next(model.parameters()).device time_schedule = time_schedule.to(device) if time_schedule[0] < time_schedule[-1]: time_schedule = torch.flip(time_schedule, [0]) # descending order steps = len(time_schedule) - 1 # initialize all-pad sequence and trace xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) # precompute row indices for scatter batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, max_length) ) pos_idx_L = ( torch.arange(max_length, device=device) .view(1, max_length) .expand(batch_size, max_length) ) sampling_trace = [[] for _ in range(batch_size)] if return_trace else None for i in range(steps): # use scheduled timesteps t = time_schedule[i].repeat(batch_size) t_next = time_schedule[i + 1] dt = (t - t_next).abs() # timestep difference # ——— predict and convert rates ——— pred_rate = model(xt, t) pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # (B, L, V) len_rate = pred_rate.length_rate # (B, L+1) # ——— unmask step (Euler) ——— mask_pos = (xt == mask).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0) # add "stay" probability _xt = xt.clone() _xt[xt == pad] = mask trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) # Apply temperature scaling if temperature != 1.0: logits = torch.log(trans_prob + 1e-10) / temperature trans_prob = torch.softmax(logits, dim=-1) if i == steps - 1: print("Final step, removing mask token from sampling") trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) if mask_has_zero_prob.any(): uniform_prob = torch.zeros_like(trans_prob[0]) uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob else: trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum new_xt = _sample_tokens(trans_prob) new_xt[xt == pad] = pad new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) if i != steps - 1: # ——— gap-wise insertion refactored — compute new length, fill masks, scatter tokens ——— ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long() # (B, L+1) xt_len = xt.ne(pad).sum(dim=1) # (B,) gaps = torch.arange(max_length + 1, device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= max_length ext = ext * valid.view(batch_size, 1).long() ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) xt_tmp = torch.full_like(xt, pad) mask_fill = pos_idx_L < new_len.view(batch_size, 1) xt_tmp[mask_fill] = mask new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] else: xt_tmp = new_xt if return_trace: # Check if the token was changed for batch_idx in range(batch_size): for j in range(max_length): if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="change", position=j, token=new_xt[batch_idx, j].item(), ) ) # Check if a new token was inserted for j in range(max_length): id = max_length - j - 1 if ext[batch_idx, id]: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="insertion", position=id, token=mask, ) ) xt = xt_tmp return xt, sampling_trace @torch.no_grad() def any_order_mask_insertion_euler_sampling_with_rnd( model, pretrained, reward_model, analyzer, tokenizer, steps, mask, pad, batch_size, max_length, return_trace = False, alpha = 0.1, temperature: float = 1.0, ): device = next(model.parameters()).device # initialize all‑pad sequence and trace xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) sampling_trace = [] # initialize log_rnd to accumulate log probability ratios log_rnd = torch.zeros(batch_size, device=device) dt = 1.0 / steps t = torch.zeros(batch_size, device=device) # precompute row indices for scatter batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, max_length) ) pos_idx_L = ( torch.arange(max_length, device=device) .view(1, max_length) .expand(batch_size, max_length) ) sampling_trace = [[] for _ in range(batch_size)] if return_trace else None for i in range(steps): # ——— predict and convert rates ——— pred_rate = model(xt, t) pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # (B, L, V) len_rate = pred_rate.length_rate # (B, L+1) # ——— get pretrained model rates for log_rnd computation ——— pretrained_pred = pretrained(xt, t) pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t) pretrained_unmask_rate = pretrained_rate.unmask_rate.clone() # (B, L, V) pretrained_len_rate = pretrained_rate.length_rate # (B, L+1) # ——— unmask step (Euler) ——— mask_pos = (xt == mask).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) # Same for pretrained pretrained_unmask_rate[xt != mask] = 0 pretrained_unmask_rate[mask_pos + (mask,)] = 0 pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1) pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0) # add “stay” probability _xt = xt.clone() _xt[xt == pad] = mask trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) pretrained_trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype), ) # Apply temperature scaling if temperature != 1.0: logits = torch.log(trans_prob + 1e-10) / temperature trans_prob = torch.softmax(logits, dim=-1) if i == steps - 1: print("Final step, removing mask token from sampling") trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step # renormalize probabilities to ensure they sum to 1 prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) if mask_has_zero_prob.any(): # create uniform distribution over valid tokens (excluding mask and pad) uniform_prob = torch.zeros_like(trans_prob[0]) uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob else: trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum new_xt = _sample_tokens(trans_prob) new_xt[xt == pad] = pad new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) # ——— compute log probabilities for RND ——— lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1) changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad) log_policy_step = (lp * changed_mask).sum(dim=1) log_pretrained_step = (lp_pre * changed_mask).sum(dim=1) log_rnd = log_pretrained_step - log_policy_step # (B,) if i != steps - 1: ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) insertion_rate = (len_rate * dt).clamp(min=1e-10) # (B, L+1) pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10) # (B, L+1) log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1) # (B,) log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1) # (B,) log_insert_diff = log_pretrained_insert - log_policy_insert # (B,) log_rnd += log_insert_diff xt_len = xt.ne(pad).sum(dim=1) # (B,) gaps = torch.arange(max_length + 1, device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= max_length ext = ext * valid.view(batch_size, 1).long() ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) xt_tmp = torch.full_like(xt, pad) mask_fill = pos_idx_L < new_len.view(batch_size, 1) xt_tmp[mask_fill] = mask new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] else: xt_tmp = new_xt if return_trace: # check if the token was changed for i in range(batch_size): for j in range(max_length): if xt[i, j] != pad and xt[i, j] != new_xt[i, j]: sampling_trace[i].append( SamplingTraceDatapoint( t=t[i].item(), event_type="change", position=j, token=new_xt[i, j].item(), ) ) # check if a new token was inserted for j in range(max_length): id = max_length - j - 1 if ext[i, id]: sampling_trace[i].append( SamplingTraceDatapoint( t=t[i].item(), event_type="insertion", position=id, token=mask, ) ) xt = xt_tmp t = t + dt # change rewards for peptides samples = xt.to(device) # store raw token IDs # Decode and strip samples decoded_samples = tokenizer.batch_decode(samples) valid_x_final = [] validSequences = [] valid_log_rnd = [] for idx, seq in enumerate(decoded_samples): # check if the peptide is valid if analyzer.is_peptide(seq): valid_x_final.append(xt[idx]) validSequences.append(seq) valid_log_rnd.append(log_rnd[idx]) print("len valid sequences:", len(validSequences)) # compute multi-objective rewards score_vectors = reward_model(input_seqs=validSequences) scalar_rewards = np.sum(score_vectors, axis=-1) scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device) print(f"scalar reward dim{len(scalar_rewards)}") valid_log_rnd = torch.stack(valid_log_rnd, dim=0) log_rnd = valid_log_rnd + (scalar_rewards / alpha) # scale down by alpha valid_x_final = torch.stack(valid_x_final, dim=0) return valid_x_final, log_rnd, scalar_rewards, sampling_trace @torch.no_grad() def any_order_finetuned_euler_sampler( model, reward_model, analyzer, tokenizer, steps, mask, pad, batch_size, max_length, return_trace = False, dataframe = False, temperature: float = 1.0, ): device = next(model.parameters()).device # initialize all‑pad sequence and trace xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) sampling_trace = [] dt = 1.0 / steps t = torch.zeros(batch_size, device=device) # precompute row indices for scatter batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, max_length) ) pos_idx_L = ( torch.arange(max_length, device=device) .view(1, max_length) .expand(batch_size, max_length) ) sampling_trace = [[] for _ in range(batch_size)] if return_trace else None for i in range(steps): # ——— predict and convert rates ——— pred_rate = model(xt, t) pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t) unmask_rate = pred_rate.unmask_rate # (B, L, V) len_rate = pred_rate.length_rate # (B, L+1) # ——— unmask step (Euler) ——— mask_pos = (xt == mask).nonzero(as_tuple=True) unmask_rate[xt != mask] = 0 unmask_rate[mask_pos + (mask,)] = 0 unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) # add “stay” probability _xt = xt.clone() _xt[xt == pad] = mask trans_prob.scatter_add_( 2, _xt.unsqueeze(-1), torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), ) # Apply temperature scaling if temperature != 1.0: logits = torch.log(trans_prob + 1e-10) / temperature trans_prob = torch.softmax(logits, dim=-1) if i == steps - 1: print("Final step, removing mask token from sampling") trans_prob[mask_pos + (mask,)] = 0.0 # remove mask token from sampling at the last step # renormalize probabilities to ensure they sum to 1 prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) # avoid division by zero; if all probs are 0, use uniform distribution (excluding mask and pad) mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) if mask_has_zero_prob.any(): # create uniform distribution over valid tokens (excluding mask and pad) uniform_prob = torch.zeros_like(trans_prob[0]) uniform_prob[:mask] = 1.0 / mask # Uniform over tokens 0 to mask-1 trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob else: # normalize to sum to 1 trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum new_xt = _sample_tokens(trans_prob) new_xt[xt == pad] = pad new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) if i != steps - 1: # gap-wise insertion refactored — compute new length, fill masks, scatter tokens ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long() # (B, L+1) xt_len = xt.ne(pad).sum(dim=1) # (B,) gaps = torch.arange(max_length + 1, device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= max_length ext = ext * valid.view(batch_size, 1).long() ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) xt_tmp = torch.full_like(xt, pad) mask_fill = pos_idx_L < new_len.view(batch_size, 1) xt_tmp[mask_fill] = mask new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] else: xt_tmp = new_xt if return_trace: # check if the token was changed for batch_idx in range(batch_size): for j in range(max_length): if xt[batch_idx, j] != pad and xt[batch_idx, j] != new_xt[batch_idx, j]: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="change", position=j, token=new_xt[batch_idx, j].item(), ) ) # check if a new token was inserted for j in range(max_length): id = max_length - j - 1 if ext[batch_idx, id]: sampling_trace[batch_idx].append( SamplingTraceDatapoint( t=t[batch_idx].item(), event_type="insertion", position=id, token=mask, ) ) xt = xt_tmp t = t + dt # start eval samples = xt.to(device) decoded_samples = tokenizer.batch_decode(samples) valid_x_final = [] validSequences = [] for idx, seq in enumerate(decoded_samples): if analyzer.is_peptide(seq): valid_x_final.append(samples[idx]) validSequences.append(seq) print("len valid sequences:", len(validSequences)) valid_fraction = len(validSequences) / batch_size if (len(validSequences) != 0): # add scores to log score_vectors = reward_model(input_seqs=validSequences) # (num_children, num_objectives) average_scores = score_vectors.T affinity = average_scores[0] sol = average_scores[1] hemo = average_scores[2] nf = average_scores[3] permeability = average_scores[4] else: zeros = [0.0] affinity = zeros sol = zeros hemo = zeros nf = zeros permeability = zeros if dataframe: df = pd.DataFrame({ "Peptide Sequence": validSequences, "Binding Affinity": affinity if len(validSequences) else [0.0], "Solubility": sol if len(validSequences) else [0.0], "Hemolysis": hemo if len(validSequences) else [0.0], "Nonfouling": nf if len(validSequences) else [0.0], "Permeability": permeability if len(validSequences) else [0.0], }) return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df return samples, affinity, sol, hemo, nf, permeability, valid_fraction @torch.no_grad() def mdm_tau_leaping_sampling( model: MaskedDiffusionModule, steps: int, mask: int, pad: int, batch_size: int, max_length: int, return_trace: bool = False, temperature: float = 1.0, ): assert not return_trace, "Trace is not yet supported" device = next(model.parameters()).device xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device) dt = 1.0 / steps t = torch.zeros(batch_size, device=device) for i in range(steps): # ——— predict and convert rates ——— pred = model(xt, t) pred = model.interpolant.to_actual_rate(xt, pred, t) unmask_rate = pred.unmask_rate # (B, L, V) if i == steps - 1: # last step: deterministic unmask via argmax mask_pos = xt == mask # (B, L) new_token = unmask_rate.argmax(dim=2) # (B, L) new_xt = xt.clone() new_xt[mask_pos] = new_token[mask_pos] new_xt = torch.where(xt != mask, xt, new_xt) xt = new_xt t = t + dt continue # tau-leaping via Poisson counts counts = torch.poisson(unmask_rate * dt).long() mask_pos = xt == mask # (B, L) # zero out non-mask positions and mask→mask counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0 counts[..., mask] = 0 # only accept exactly one event sum_c = counts.sum(dim=2) # (B, L) one_event = sum_c == 1 new_token = counts.argmax(dim=2) # (B, L) # build new xt new_xt = xt.clone() new_xt[one_event] = new_token[one_event] # keep pads and already-unmasked tokens new_xt = torch.where(xt != mask, xt, new_xt) xt = new_xt t = t + dt return xt, [] # Not used in production, for debugging purposes lengths = {4: 0.1, 16: 0.4, 32: 0.4, 64: 0.1} def binomial_mass(k, n, p): """ Calculate the probability mass function (PMF) for a binomial distribution. Args: k (int): Number of successes n (int): Number of trials p (float): Probability of success in a single trial Returns: float: Probability mass P(X = k) """ import math # Calculate binomial coefficient (n choose k) try: binom_coef = math.factorial(n) / (math.factorial(k) * math.factorial(n - k)) except ValueError: # Handle cases where k > n or negative values return 0.0 # Calculate probability mass return binom_coef * (p ** k) * ((1 - p) ** (n - k)) def calculate_rate_batch(alpha_t, len_t): """ Calculate rate for a batch of alpha_t and len_t values. Args: alpha_t (torch.Tensor): Tensor of shape (batch_size,) len_t (torch.Tensor): Tensor of shape (batch_size,) Returns: torch.Tensor: Tensor of shape (batch_size,) containing calculated rates """ batch_size = alpha_t.shape[0] device = alpha_t.device # Initialize tensors for numerator and denominator nom = torch.zeros(batch_size, device=device) denom = torch.zeros(batch_size, device=device) for length, probability in lengths.items(): # Create mask for valid entries where len_t <= length valid_mask = (len_t <= length) & (len_t >= 0) if not valid_mask.any(): continue valid_indices = valid_mask.nonzero(as_tuple=True)[0] valid_len_t = len_t[valid_indices] valid_alpha_t = alpha_t[valid_indices] # Calculate binomial probabilities efficiently using torch distribution binom_dist = torch.distributions.Binomial(total_count=length, probs=valid_alpha_t) binom_probs = binom_dist.log_prob(valid_len_t).exp() # Update numerator and denominator for valid indices nom[valid_indices] += (length - valid_len_t) * probability * binom_probs denom[valid_indices] += probability * binom_probs # Handle division by zero in a vectorized way result = torch.zeros_like(nom) div_mask = denom > 0 result[div_mask] = nom[div_mask] / (denom[div_mask]) return result # Keep the original function for backward compatibility def calculate_rate(alpha_t, len_t): """Legacy scalar version of calculate_rate""" if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0: return calculate_rate_batch(alpha_t, len_t) nom, denom = 0, 0 for length, probability in lengths.items(): if length >= len_t: nom += (length - len_t) * probability * binomial_mass(len_t, length, alpha_t) denom += probability * binomial_mass(len_t, length, alpha_t) if denom == 0: return 0.0 return nom /denom @torch.no_grad() def any_order_mask_insertion_tau_leaping_sampling( model: torch.nn.Module, steps: int, mask: int, pad: int, batch_size: int, max_length: int, return_trace: bool = False, confidence_based_sampling: bool = True, # whether to use confidence-based decoding alpha: float = 5.0, # hyperparameter for window size calculation max_window: int = 32, # Maximum window size for sliding window confidence_method: str = "prob_diff", # "position", "top_prob", "prob_diff", "entropy" use_sliding_window: bool = False, # whether to use sliding window for position selection temperature: float = 1.0, ) -> SamplingResult: device = next(model.parameters()).device xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device) sampling_trace = [] dt = 1.0 / steps t = torch.zeros(batch_size, device=device) # Precompute row indices for scatter batch_idx_L = ( torch.arange(batch_size, device=device) .view(batch_size, 1) .expand(batch_size, max_length) ) pos_idx_L = ( torch.arange(max_length, device=device) .view(1, max_length) .expand(batch_size, max_length) ) for i in range(steps): # --- predict rates --- pred = model(xt, t) xt_len = (xt != pad).sum(dim=1) pred = model.interpolant.to_actual_rate(xt, pred, t) unmask_rate = pred.unmask_rate # (B, L, V) len_rate = pred.length_rate # (B, L+1) if i == steps - 1: # last step: deterministic unmask via argmax mask_pos = xt == mask new_token = unmask_rate.argmax(dim=2) new_xt = xt.clone() new_xt[mask_pos] = new_token[mask_pos] new_xt = torch.where(xt == pad, pad, new_xt) new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) xt = new_xt t = t + dt continue # --- confidence-based decoding --- if confidence_based_sampling > 0.0: # Confidence-based unmasking (vectorized) mask_positions = (xt == mask) # (B, L) num_mask_positions = mask_positions.sum(dim=1) # (B,) # 1. Determine number of tokens to unmask using Poisson unmask_counts = torch.poisson(num_mask_positions.float() * dt).long() # (B,) # 2. Calculate confidence based on selected method if confidence_method == "position": # Position-based confidence: position i / len(xt) xt_len = (xt != pad).sum(dim=1) # (B,) - current sequence lengths position_indices = torch.arange(max_length, device=device).unsqueeze(0).expand(batch_size, -1) # (B, L) confidence = 1.0 - (position_indices.float() / xt_len.unsqueeze(1).float().clamp(min=1)) # (B, L) elif confidence_method == "top_prob": # Top probability confidence import torch.nn.functional as F token_logits = unmask_rate # (B, L, V) - use the unmask_rate as logits unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V) confidence = unmask_probs.max(dim=-1)[0] # (B, L) elif confidence_method == "prob_diff": # Probability difference confidence (top - second top) import torch.nn.functional as F token_logits = unmask_rate # (B, L, V) unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V) top2_probs, _ = torch.topk(unmask_probs, k=2, dim=-1) # (B, L, 2) confidence = top2_probs[:, :, 0] - top2_probs[:, :, 1] # (B, L) elif confidence_method == "entropy": # Entropy-based confidence (lower entropy = higher confidence) import torch.nn.functional as F token_logits = unmask_rate # (B, L, V) unmask_probs = F.softmax(token_logits, dim=-1) # (B, L, V) entropy = -torch.sum(unmask_probs * torch.log(unmask_probs + 1e-10), dim=-1) # (B, L) confidence = -entropy # (B, L) - negative entropy so lower entropy gives higher confidence else: raise ValueError(f"Unknown confidence_method: {confidence_method}") # 3. Apply window constraint if enabled if use_sliding_window: # Calculate dynamic k for each batch k_values = torch.minimum( torch.minimum( (alpha * unmask_counts).long(), torch.tensor(max_window, device=device) ), num_mask_positions) # (B,) # Get cumulative count of mask positions mask_cumsum = mask_positions.cumsum(dim=1) # (B, L) # Create window mask: position is eligible if it's a mask and within first k masks is_within_window = mask_cumsum <= k_values.unsqueeze(1) # (B, L) window_mask = mask_positions & is_within_window # (B, L) # Set confidence to -inf for positions outside the window or non-mask positions confidence = torch.where(window_mask, confidence, torch.tensor(-float('inf'), device=device)) else: # No window constraint - only mask positions are eligible confidence = torch.where(mask_positions, confidence, torch.tensor(-float('inf'), device=device)) new_xt = xt.clone() # vectorized unmasking max_unmask = unmask_counts.max().item() if max_unmask > 0: _, all_top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True) # (B, max_unmask) # create mask for valid unmask operations unmask_mask = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1) # (B, max_unmask) most_likely_tokens = unmask_rate.argmax(dim=-1) # (B, L) selected_positions = all_top_indices[unmask_mask] batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[unmask_mask] new_xt[batch_indices, selected_positions] = most_likely_tokens[batch_indices, selected_positions] else: # --- tau-leaping unmask via Poisson --- counts = torch.poisson(unmask_rate * dt).long() mask_pos = xt == mask counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0 counts[..., mask] = 0 sum_c = counts.sum(dim=2) one_event = sum_c == 1 new_token = counts.argmax(dim=2) new_xt = xt.clone() new_xt[one_event] = new_token[one_event] new_xt = torch.where(xt == pad, pad, new_xt) new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) # insertion only on non-last if i != steps - 1: # --- Poisson insertion, compute new lengths and fill masks --- ext = torch.poisson(len_rate * dt).long() # (B, L+1) xt_len = xt.ne(pad).sum(dim=1) # (B,) gaps = torch.arange(max_length + 1, device=device).view(1, -1) ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() total_ext = ext.sum(dim=1) valid = xt_len + total_ext <= max_length ext = ext * valid.view(batch_size, 1).long() # compute prefix sums of insertions ext_ex = ext.int().cumsum(dim=1) # (B, L+1) new_len = xt_len + total_ext # (B,) # initialize with pads, then fill mask up to new_len xt_tmp = torch.full_like(xt, pad) mask_pos = pos_idx_L < new_len.view(batch_size, 1) xt_tmp[mask_pos] = mask # shift and scatter original tokens new_pos_orig = pos_idx_L + ext_ex[:, :max_length] # (B, L) orig_mask = pos_idx_L < xt_len.view(batch_size, 1) flat_b = batch_idx_L[orig_mask] flat_p = new_pos_orig[orig_mask] xt_tmp[flat_b, flat_p] = new_xt[orig_mask] else: xt_tmp = new_xt xt = xt_tmp t = t + dt if return_trace: sampling_trace.append(xt) return xt, sampling_trace