| import os |
| import sys |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| import torch |
| from dataclasses import dataclass |
| from typing import Any, Literal, Optional |
| import numpy as np |
| import pandas as pd |
|
|
| from lightning_modules.mdm import MaskedDiffusionModule |
|
|
|
|
@dataclass
class SamplingTraceDatapoint:
    """A single event recorded while sampling one sequence."""

    t: float  # sampler time at which the event was recorded
    event_type: Literal["insertion", "change"]  # kind of event
    position: int  # sequence index (or gap index for insertions) of the event
    token: Any  # token id involved in the event
|
|
|
|
@dataclass
class SamplingResult:
    """Sampler output that can also be unpacked as ``samples, trace = result``."""

    samples: torch.Tensor  # sampled token ids
    trace: Optional[list[SamplingTraceDatapoint]]  # recorded events, or None

    def __iter__(self):
        # Yield the two fields in declaration order to support tuple unpacking.
        return iter((self.samples, self.trace))
|
|
|
|
| |
| def _sample_tokens(probs: torch.Tensor) -> torch.Tensor: |
| """Sample one token per position from probability distribution. |
| Args: |
| probs: [batch_size, seq_len, vocab_size] transition probabilities |
| Returns: |
| [batch_size, seq_len] sampled token indices |
| """ |
| batch_size, seq_len, vocab_size = probs.shape |
| flat_probs = probs.view(-1, vocab_size) |
| samples = torch.multinomial(flat_probs, num_samples=1) |
| return samples.view(batch_size, seq_len) |
|
|
|
|
| def _sample_batched_tokens(probs: torch.Tensor) -> torch.Tensor: |
| |
| batch_size, seq_len, vocab_size = probs.shape |
| |
| gumbel_noise = (-torch.log(-torch.log(torch.rand(batch_size, seq_len, vocab_size) + 1e-10) + 1e-10)).to(probs.device) |
| noisy_logits = torch.log(probs + 1e-10) + gumbel_noise |
| |
| |
| samples = noisy_logits.argmax(dim=-1).to(dtype=torch.long) |
| |
| return samples.view(batch_size, seq_len) |
|
|
@torch.no_grad()
def mdm_euler_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Sample fixed-length sequences from a masked diffusion model (Euler steps).

    Starts from all-``mask`` sequences at ``t = 0`` and advances ``t`` by
    ``1 / steps`` per iteration. At every step each still-masked position either
    stays masked or is unmasked to a token drawn from the Euler transition
    weights built from ``unmask_rate * dt``; unmasked tokens never change again.
    On the final step staying masked is disallowed.

    Args:
        model: module called as ``model(xt, t)``; its output is converted by
            ``model.interpolant.to_actual_rate`` into an object exposing
            ``unmask_rate``.
        steps: number of Euler steps.
        mask: id of the mask token.
        pad: unused here (kept for signature parity with the other samplers).
        batch_size: number of sequences to sample.
        max_length: length of every sequence.
        return_trace: must be False; tracing is not implemented.
        temperature: when != 1.0 each position's transition weights ``p`` are
            replaced by ``softmax(log(p) / temperature)``.

    Returns:
        ``(samples, [])`` with ``samples`` an int64 tensor of shape
        ``(batch_size, max_length)``.
    """
    assert not return_trace, "Trace is not yet implemented in MDM Euler sampling"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate

        # Only masked positions may jump. The mask->mask entry receives minus
        # the total outgoing rate, which the clamp below turns into zero, so
        # the "stay masked" weight comes solely from the identity added next.
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # Weight one on every position's current token (index tensor is only read).
        trans_prob.scatter_add_(
            2,
            xt.unsqueeze(-1),
            torch.ones_like(xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if i == steps - 1:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

        new_xt = _sample_tokens(trans_prob)
        # Tokens that are already unmasked are kept as they are.
        xt = torch.where(xt != mask, xt, new_xt)
        t = t + dt

    return xt, []
|
|
|
|
@torch.no_grad()
def any_order_mask_insertion_euler_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Euler sampler for the variable-length mask-and-insert process.

    Starts from all-``pad`` sequences at ``t = 0`` and runs ``steps`` steps of
    size ``1 / steps``. Each step (1) unmasks some ``mask`` tokens using Euler
    transition weights from ``unmask_rate`` and (2), except on the final step,
    inserts new ``mask`` tokens into the gaps of each sequence with
    per-gap Bernoulli probability ``length_rate * dt``. On the final step
    masked positions are not allowed to stay masked.

    Args:
        model: module called as ``model(xt, t)``; ``model.interpolant.to_actual_rate``
            turns its output into an object with ``unmask_rate`` and ``length_rate``.
        steps: number of Euler steps.
        mask: id of the mask token.
        pad: id of the padding token filling unused tail positions.
        batch_size: number of sequences to sample.
        max_length: maximum sequence length (width of the token tensor).
        return_trace: if True, record ``SamplingTraceDatapoint`` events per
            sequence ("change" for unmaskings, "insertion" for new masks).
        temperature: when != 1.0 each position's transition weights ``p`` are
            replaced by ``softmax(log(p) / temperature)``.

    Returns:
        A plain tuple ``(samples, trace)``: int64 ``samples`` of shape
        ``(batch_size, max_length)`` and per-sequence event lists (``None``
        when ``return_trace`` is False).
    """
    device = next(model.parameters()).device

    # Sequences start empty: every slot is padding.
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    # (batch, position) index grids used to scatter tokens after insertions.
    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        last_step = i == steps - 1

        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate
        len_rate = pred_rate.length_rate

        # --- Unmasking -------------------------------------------------------
        # Only masked positions may jump. The mask->mask entry receives minus
        # the total outgoing rate, which the clamp below turns into zero, so
        # the "stay masked" weight comes solely from the identity added next.
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        # Weight one on each position's current token; pad rows are treated as
        # mask rows so they stay well formed (their samples are reset below).
        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

        # Renormalise the rows of masked positions. A row with zero mass falls
        # back to a uniform distribution over token ids [0, mask).
        # NOTE(review): assumes the mask id lies above all regular token ids.
        if mask_pos[0].numel() > 0:
            rows = trans_prob[mask_pos]
            row_sum = rows.sum(dim=-1, keepdim=True)
            zero_rows = row_sum.squeeze(-1) == 0.0
            rows = rows / torch.where(zero_rows.unsqueeze(-1), torch.ones_like(row_sum), row_sum)
            if zero_rows.any():
                uniform_prob = torch.zeros_like(rows[0])
                uniform_prob[:mask] = 1.0 / mask
                rows[zero_rows] = uniform_prob
            trans_prob[mask_pos] = rows

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # --- Insertion (not on the final step) -------------------------------
        ext = None
        if not last_step:
            # ext[b, g] == 1 inserts a mask into gap g, i.e. just before
            # position g; gap xt_len appends after the last token.
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()
            xt_len = xt.ne(pad).sum(dim=1)
            gaps = torch.arange(ext.size(1), device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # Reject every insertion of a sequence that would overflow.
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            # Length after the accepted insertions (rejected rows keep theirs).
            new_len = xt_len + ext.sum(dim=1)
            ext_ex = ext.int().cumsum(dim=1)

            # Canvas of masks up to the new length; surviving tokens are then
            # written at their shifted positions, leaving masks in the gaps.
            xt_tmp = torch.full_like(xt, pad)
            xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            t_host = t.tolist()
            # Unmasking events: non-pad positions whose token changed.
            changed = (xt != pad) & (xt != new_xt)
            for b, pos in changed.nonzero().tolist():
                sampling_trace[b].append(
                    SamplingTraceDatapoint(
                        t=t_host[b],
                        event_type="change",
                        position=pos,
                        token=new_xt[b, pos].item(),
                    )
                )
            # Insertion events (none on the final step), recorded per sequence
            # from the rightmost gap to the leftmost, over every gap column.
            if ext is not None:
                for b, gap in reversed(ext.nonzero().tolist()):
                    sampling_trace[b].append(
                        SamplingTraceDatapoint(
                            t=t_host[b],
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp
        t = t + dt

    return xt, sampling_trace
|
|
@torch.no_grad()
def batch_mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Expand one state into ``batch_size`` children with one Euler step.

    ``xt`` is tiled with ``repeat(batch_size, 1)`` (presumably a single
    ``(1, max_length)`` sequence — confirm with callers) and advanced by one
    mask-and-insert step sampled from ``model``. Each sampled transition is
    scored under both ``model`` (policy) and ``pretrained``.

    Args:
        xt: current token ids.
        t: current time; squeezed and expanded to ``(batch_size,)``.
        dt: step size.
        model: policy module with ``model.interpolant.to_actual_rate``.
        pretrained: reference module with the same interface.
        mask: id of the mask token.
        pad: id of the padding token.
        batch_size: number of children to sample.
        max_length: maximum sequence length.
        last_step: if True, masks may not stay masked and no insertions occur.
        temperature: currently unused (kept for signature parity).

    Returns:
        ``(children, log_rnd, log_policy_step, log_pretrained_step)`` where
        ``children`` is ``(batch_size, max_length)`` and the rest are
        ``(batch_size,)``, with ``log_rnd = log_pretrained_step - log_policy_step``.
    """
    device = next(model.parameters()).device

    # Tile the parent into batch_size independent children.
    xt = xt.repeat(batch_size, 1)
    t = t.squeeze().expand(batch_size)

    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate
    len_rate = pred_rate.length_rate

    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()
    pretrained_len_rate = pretrained_rate.length_rate

    # Euler transition weights: only masked positions may jump; the negative
    # mask->mask entry is clamped to zero, the identity is added below.
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # Weight one on each position's current token (pads treated as masks).
    _xt = xt.clone()
    _xt[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
    )
    pretrained_trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0

    # Renormalise the policy rows of masked positions; zero-mass rows fall
    # back to uniform over token ids [0, mask).
    # NOTE(review): assumes the mask id lies above all regular token ids.
    if mask_pos[0].numel() > 0:
        rows = trans_prob[mask_pos]
        row_sum = rows.sum(dim=-1, keepdim=True)
        zero_rows = row_sum.squeeze(-1) == 0.0
        rows = rows / torch.where(zero_rows.unsqueeze(-1), torch.ones_like(row_sum), row_sum)
        if zero_rows.any():
            uniform_prob = torch.zeros_like(rows[0])
            uniform_prob[:mask] = 1.0 / mask
            rows[zero_rows] = uniform_prob
        trans_prob[mask_pos] = rows

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # Log-probability of each sampled token under policy and pretrained.
    # NOTE(review): pretrained_trans_prob is neither renormalised nor stripped
    # of the stay-masked entry on the last step, unlike trans_prob — confirm
    # the two log-probs are meant to be on these scales.
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

    # Only positions that were unmasked in this step contribute.
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
    log_rnd = log_pretrained_step - log_policy_step

    if not last_step:
        ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()

        # Poisson log-likelihood of the insertion draws under each model.
        # NOTE(review): evaluated on the raw draws, before invalid gaps and
        # overflowing rows are masked out below — confirm this is intended.
        insertion_rate = (len_rate * dt).clamp(min=1e-10)
        pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)
        log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
        log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)

        log_rnd += log_pretrained_insert - log_policy_insert
        log_pretrained_step += log_pretrained_insert
        log_policy_step += log_policy_insert

        # Keep only gaps 0..len (gap g sits before position g) and reject all
        # insertions of a sequence that would exceed max_length.
        xt_len = xt.ne(pad).sum(dim=1)
        gaps = torch.arange(ext.size(1), device=device).view(1, -1)
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        valid = xt_len + ext.sum(dim=1) <= max_length
        ext = ext * valid.view(batch_size, 1).long()
        # Length after accepted insertions (rejected rows keep their length).
        new_len = xt_len + ext.sum(dim=1)
        ext_ex = ext.int().cumsum(dim=1)

        # Canvas of masks up to the new length, then shift surviving tokens.
        xt_tmp = torch.full_like(xt, pad)
        xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
    else:
        xt_tmp = new_xt

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
|
|
|
|
@torch.no_grad()
def mcts_reverse_step(
    xt: torch.Tensor,
    t: torch.Tensor,
    dt: float,
    model: torch.nn.Module,
    pretrained: torch.nn.Module,
    mask: int,
    pad: int,
    max_length: int,
    last_step: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Advance a batch of states by one Euler step and score the transition.

    Each row of ``xt`` takes one mask-and-insert step sampled from ``model``;
    the sampled transition is scored under both ``model`` (policy) and
    ``pretrained``.

    Args:
        xt: ``(batch, max_length)`` current token ids.
        t: current time, passed unchanged to both models.
        dt: step size.
        model: policy module with ``model.interpolant.to_actual_rate``.
        pretrained: reference module with the same interface.
        mask: id of the mask token.
        pad: id of the padding token.
        max_length: maximum sequence length.
        last_step: if True, masks may not stay masked and no insertions occur.
        temperature: currently unused (kept for signature parity).

    Returns:
        ``(next_xt, log_rnd, log_policy_step, log_pretrained_step)`` where the
        last three are ``(batch,)`` and
        ``log_rnd = log_pretrained_step - log_policy_step``.
    """
    device = next(model.parameters()).device
    batch_size = xt.size(0)

    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )

    pred_rate = model(xt, t)
    pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
    unmask_rate = pred_rate.unmask_rate
    len_rate = pred_rate.length_rate

    pretrained_pred = pretrained(xt, t)
    pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
    pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()
    pretrained_len_rate = pretrained_rate.length_rate

    # Euler transition weights: only masked positions may jump; the negative
    # mask->mask entry is clamped to zero, the identity is added below.
    mask_pos = (xt == mask).nonzero(as_tuple=True)
    unmask_rate[xt != mask] = 0
    unmask_rate[mask_pos + (mask,)] = 0
    unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

    pretrained_unmask_rate[xt != mask] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = 0
    pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
    pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

    # Weight one on each position's current token (pads treated as masks).
    _xt = xt.clone()
    _xt[xt == pad] = mask
    trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
    )
    pretrained_trans_prob.scatter_add_(
        2,
        _xt.unsqueeze(-1),
        torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
    )

    if last_step:
        print("Final step, removing mask token from sampling")
        trans_prob[mask_pos + (mask,)] = 0.0

    # Renormalise the policy rows of masked positions; zero-mass rows fall
    # back to uniform over token ids [0, mask).
    # NOTE(review): assumes the mask id lies above all regular token ids.
    if mask_pos[0].numel() > 0:
        rows = trans_prob[mask_pos]
        row_sum = rows.sum(dim=-1, keepdim=True)
        zero_rows = row_sum.squeeze(-1) == 0.0
        rows = rows / torch.where(zero_rows.unsqueeze(-1), torch.ones_like(row_sum), row_sum)
        if zero_rows.any():
            uniform_prob = torch.zeros_like(rows[0])
            uniform_prob[:mask] = 1.0 / mask
            rows[zero_rows] = uniform_prob
        trans_prob[mask_pos] = rows

    new_xt = _sample_tokens(trans_prob)
    new_xt[xt == pad] = pad
    new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

    # Log-probability of each sampled token under policy and pretrained.
    # NOTE(review): pretrained_trans_prob is neither renormalised nor stripped
    # of the stay-masked entry on the last step, unlike trans_prob — confirm
    # the two log-probs are meant to be on these scales.
    lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
    lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

    # Only positions that were unmasked in this step contribute.
    changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

    log_policy_step = (lp * changed_mask).sum(dim=1)
    log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
    log_rnd = log_pretrained_step - log_policy_step

    if not last_step:
        ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()

        # Poisson log-likelihood of the insertion draws under each model.
        # NOTE(review): evaluated on the raw draws, before invalid gaps and
        # overflowing rows are masked out below — confirm this is intended.
        insertion_rate = (len_rate * dt).clamp(min=1e-10)
        pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)
        log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
        log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)

        log_rnd += log_pretrained_insert - log_policy_insert
        log_pretrained_step += log_pretrained_insert
        log_policy_step += log_policy_insert

        # Keep only gaps 0..len (gap g sits before position g) and reject all
        # insertions of a sequence that would exceed max_length.
        xt_len = xt.ne(pad).sum(dim=1)
        gaps = torch.arange(ext.size(1), device=device).view(1, -1)
        ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
        valid = xt_len + ext.sum(dim=1) <= max_length
        ext = ext * valid.view(batch_size, 1).long()
        # Length after accepted insertions (rejected rows keep their length).
        new_len = xt_len + ext.sum(dim=1)
        ext_ex = ext.int().cumsum(dim=1)

        # Canvas of masks up to the new length, then shift surviving tokens.
        xt_tmp = torch.full_like(xt, pad)
        xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask
        new_pos_orig = pos_idx_L + ext_ex[:, :max_length]
        orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
        xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
    else:
        xt_tmp = new_xt

    return xt_tmp, log_rnd, log_policy_step, log_pretrained_step
|
|
@torch.no_grad()
def any_order_euler_sampling_with_schedule(
    model: torch.nn.Module,
    time_schedule: torch.Tensor,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
) -> SamplingResult:
    """Mask-and-insert Euler sampler driven by an explicit time schedule.

    Same update as the fixed-step sampler, except that step ``i`` uses
    ``t = time_schedule[i]`` and ``dt = |time_schedule[i] - time_schedule[i+1]|``.
    The schedule is flipped if needed so that it is non-increasing.

    Args:
        model: module called as ``model(xt, t)``; ``model.interpolant.to_actual_rate``
            turns its output into an object with ``unmask_rate`` and ``length_rate``.
        time_schedule: 1-D tensor of ``steps + 1`` time points.
        mask: id of the mask token.
        pad: id of the padding token.
        batch_size: number of sequences to sample.
        max_length: maximum sequence length.
        return_trace: if True, record per-sequence ``SamplingTraceDatapoint`` events.
        temperature: when != 1.0 each position's transition weights ``p`` are
            replaced by ``softmax(log(p) / temperature)``.

    Returns:
        A plain tuple ``(samples, trace)`` (``trace`` is ``None`` unless
        ``return_trace``).
    """
    device = next(model.parameters()).device

    # NOTE(review): the fixed-step samplers in this module start from all-pad
    # sequences at t = 0 and move towards t = 1, while this one forces a
    # non-increasing schedule — confirm the intended time direction.
    time_schedule = time_schedule.to(device)
    if time_schedule[0] < time_schedule[-1]:
        time_schedule = torch.flip(time_schedule, [0])

    steps = len(time_schedule) - 1

    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        last_step = i == steps - 1

        t = time_schedule[i].repeat(batch_size)
        t_next = time_schedule[i + 1]
        dt = (t - t_next).abs()  # per-sequence step size, shape (batch_size,)

        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate
        len_rate = pred_rate.length_rate

        # --- Unmasking -------------------------------------------------------
        # Only masked positions may jump; the negative mask->mask entry is
        # clamped to zero and the identity added below supplies "stay masked".
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt[:, None, None]).clamp(0.0, 1.0)

        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

        # Renormalise the rows of masked positions; zero-mass rows fall back
        # to uniform over token ids [0, mask).
        # NOTE(review): assumes the mask id lies above all regular token ids.
        if mask_pos[0].numel() > 0:
            rows = trans_prob[mask_pos]
            row_sum = rows.sum(dim=-1, keepdim=True)
            zero_rows = row_sum.squeeze(-1) == 0.0
            rows = rows / torch.where(zero_rows.unsqueeze(-1), torch.ones_like(row_sum), row_sum)
            if zero_rows.any():
                uniform_prob = torch.zeros_like(rows[0])
                uniform_prob[:mask] = 1.0 / mask
                rows[zero_rows] = uniform_prob
            trans_prob[mask_pos] = rows

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # --- Insertion (not on the final step) -------------------------------
        ext = None
        if not last_step:
            # ext[b, g] == 1 inserts a mask just before position g.
            ext = torch.bernoulli((len_rate * dt[:, None]).clamp(0.0, 1.0)).long()
            xt_len = xt.ne(pad).sum(dim=1)
            gaps = torch.arange(ext.size(1), device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # Reject every insertion of a sequence that would overflow.
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            new_len = xt_len + ext.sum(dim=1)
            ext_ex = ext.int().cumsum(dim=1)

            xt_tmp = torch.full_like(xt, pad)
            xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            t_host = t.tolist()
            changed = (xt != pad) & (xt != new_xt)
            for b, pos in changed.nonzero().tolist():
                sampling_trace[b].append(
                    SamplingTraceDatapoint(
                        t=t_host[b],
                        event_type="change",
                        position=pos,
                        token=new_xt[b, pos].item(),
                    )
                )
            # Insertions (none on the final step), rightmost gap first.
            if ext is not None:
                for b, gap in reversed(ext.nonzero().tolist()):
                    sampling_trace[b].append(
                        SamplingTraceDatapoint(
                            t=t_host[b],
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp

    return xt, sampling_trace
|
|
|
|
@torch.no_grad()
def any_order_mask_insertion_euler_sampling_with_rnd(
    model, pretrained, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    alpha = 0.1,
    temperature: float = 1.0,
):
    """Sample with ``model`` while accumulating log(pretrained / policy) weights.

    Runs the mask-and-insert Euler sampler under ``model`` for ``steps`` steps.
    It accumulates, over the whole trajectory, the log-probability of every
    sampled transition under ``pretrained`` minus that under ``model``. The
    final sequences are decoded with ``tokenizer``, filtered by
    ``analyzer.is_peptide`` and scored by ``reward_model``.

    Returns:
        ``(valid_samples, log_rnd, scalar_rewards, trace)`` where ``log_rnd`` is
        the accumulated log weight plus ``scalar_rewards / alpha`` for the valid
        sequences only. When no sequence is valid, the first three are empty
        tensors.
    """
    device = next(model.parameters()).device

    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    # Accumulated over all steps: sum of log p_pretrained - log p_policy.
    log_rnd = torch.zeros(batch_size, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        last_step = i == steps - 1

        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate
        len_rate = pred_rate.length_rate

        pretrained_pred = pretrained(xt, t)
        pretrained_rate = pretrained.interpolant.to_actual_rate(xt, pretrained_pred, t)
        pretrained_unmask_rate = pretrained_rate.unmask_rate.clone()
        pretrained_len_rate = pretrained_rate.length_rate

        # Euler transition weights for both models (negative diagonal is
        # clamped to zero; the identity is added below).
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        pretrained_unmask_rate[xt != mask] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = 0
        pretrained_unmask_rate[mask_pos + (mask,)] = -pretrained_unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        pretrained_trans_prob = (pretrained_unmask_rate * dt).clamp(0.0, 1.0)

        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )
        pretrained_trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=pretrained_trans_prob.dtype),
        )

        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

        # Renormalise the policy rows of masked positions; zero-mass rows fall
        # back to uniform over token ids [0, mask).
        # NOTE(review): assumes the mask id lies above all regular token ids.
        if mask_pos[0].numel() > 0:
            rows = trans_prob[mask_pos]
            row_sum = rows.sum(dim=-1, keepdim=True)
            zero_rows = row_sum.squeeze(-1) == 0.0
            rows = rows / torch.where(zero_rows.unsqueeze(-1), torch.ones_like(row_sum), row_sum)
            if zero_rows.any():
                uniform_prob = torch.zeros_like(rows[0])
                uniform_prob[:mask] = 1.0 / mask
                rows[zero_rows] = uniform_prob
            trans_prob[mask_pos] = rows

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        # NOTE(review): pretrained_trans_prob is not renormalised like
        # trans_prob — confirm the two log-probs are meant to be on these scales.
        lp = torch.gather(torch.log(trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)
        lp_pre = torch.gather(torch.log(pretrained_trans_prob + 1e-10), 2, new_xt.unsqueeze(-1)).squeeze(-1)

        changed_mask = (xt == mask) & (new_xt != mask) & (new_xt != pad)

        log_policy_step = (lp * changed_mask).sum(dim=1)
        log_pretrained_step = (lp_pre * changed_mask).sum(dim=1)
        # Accumulate across steps (previously overwritten every step).
        log_rnd = log_rnd + (log_pretrained_step - log_policy_step)

        ext = None
        if not last_step:
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()

            # Poisson log-likelihood of the raw insertion draws.
            # NOTE(review): computed before invalid gaps/rows are masked out.
            insertion_rate = (len_rate * dt).clamp(min=1e-10)
            pretrained_insertion_rate = (pretrained_len_rate * dt).clamp(min=1e-10)
            log_policy_insert = (ext * torch.log(insertion_rate) - insertion_rate).sum(dim=1)
            log_pretrained_insert = (ext * torch.log(pretrained_insertion_rate) - pretrained_insertion_rate).sum(dim=1)
            log_rnd = log_rnd + (log_pretrained_insert - log_policy_insert)

            xt_len = xt.ne(pad).sum(dim=1)
            gaps = torch.arange(ext.size(1), device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # Reject every insertion of a sequence that would overflow.
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            new_len = xt_len + ext.sum(dim=1)
            ext_ex = ext.int().cumsum(dim=1)

            xt_tmp = torch.full_like(xt, pad)
            xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            t_host = t.tolist()
            changed = (xt != pad) & (xt != new_xt)
            for b, pos in changed.nonzero().tolist():
                sampling_trace[b].append(
                    SamplingTraceDatapoint(
                        t=t_host[b],
                        event_type="change",
                        position=pos,
                        token=new_xt[b, pos].item(),
                    )
                )
            if ext is not None:
                for b, gap in reversed(ext.nonzero().tolist()):
                    sampling_trace[b].append(
                        SamplingTraceDatapoint(
                            t=t_host[b],
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp
        t = t + dt

    decoded_samples = tokenizer.batch_decode(xt)
    valid_idx = [idx for idx, seq in enumerate(decoded_samples) if analyzer.is_peptide(seq)]
    validSequences = [decoded_samples[idx] for idx in valid_idx]

    print("len valid sequences:", len(validSequences))

    if not validSequences:
        # Nothing to score: return empty tensors instead of crashing on stack().
        return (
            xt.new_empty((0, max_length)),
            log_rnd.new_empty((0,)),
            torch.empty(0, dtype=torch.float32, device=device),
            sampling_trace,
        )

    score_vectors = reward_model(input_seqs=validSequences)
    scalar_rewards = np.sum(score_vectors, axis=-1)
    scalar_rewards = torch.as_tensor(scalar_rewards, dtype=torch.float32, device=device)

    print(f"scalar reward dim{len(scalar_rewards)}")

    keep = torch.tensor(valid_idx, dtype=torch.long, device=device)
    valid_x_final = xt[keep]
    log_rnd = log_rnd[keep] + (scalar_rewards / alpha)

    return valid_x_final, log_rnd, scalar_rewards, sampling_trace
|
|
@torch.no_grad()
def any_order_finetuned_euler_sampler(
    model, reward_model, analyzer,
    tokenizer, steps,
    mask,
    pad,
    batch_size,
    max_length,
    return_trace = False,
    dataframe = False,
    temperature: float = 1.0,
):
    """Sample with a fine-tuned model and score the valid peptides.

    Runs the mask-and-insert Euler sampler for ``steps`` steps, decodes the
    result with ``tokenizer``, keeps sequences accepted by
    ``analyzer.is_peptide`` and scores them with ``reward_model``. Columns 0-4
    of the returned score matrix are read as affinity, solubility, hemolysis,
    nonfouling and permeability.

    Returns:
        ``(samples, affinity, sol, hemo, nf, permeability, valid_fraction)``,
        plus a ``pandas.DataFrame`` of the valid sequences when ``dataframe``
        is True. With no valid sequences every score is ``[0.0]`` and the
        DataFrame is empty (same columns).
    """
    device = next(model.parameters()).device

    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)

    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    batch_idx_L = (
        torch.arange(batch_size, device=device)
        .view(batch_size, 1)
        .expand(batch_size, max_length)
    )
    pos_idx_L = (
        torch.arange(max_length, device=device)
        .view(1, max_length)
        .expand(batch_size, max_length)
    )
    sampling_trace = [[] for _ in range(batch_size)] if return_trace else None

    for i in range(steps):
        last_step = i == steps - 1

        pred_rate = model(xt, t)
        pred_rate = model.interpolant.to_actual_rate(xt, pred_rate, t)
        unmask_rate = pred_rate.unmask_rate
        len_rate = pred_rate.length_rate

        # Only masked positions may jump; the negative mask->mask entry is
        # clamped to zero and the identity added below supplies "stay masked".
        mask_pos = (xt == mask).nonzero(as_tuple=True)
        unmask_rate[xt != mask] = 0
        unmask_rate[mask_pos + (mask,)] = 0
        unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1)
        trans_prob = (unmask_rate * dt).clamp(0.0, 1.0)

        _xt = xt.clone()
        _xt[xt == pad] = mask
        trans_prob.scatter_add_(
            2,
            _xt.unsqueeze(-1),
            torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype),
        )

        if temperature != 1.0:
            logits = torch.log(trans_prob + 1e-10) / temperature
            trans_prob = torch.softmax(logits, dim=-1)

        if last_step:
            print("Final step, removing mask token from sampling")
            trans_prob[mask_pos + (mask,)] = 0.0

        # Renormalise the rows of masked positions; zero-mass rows fall back
        # to uniform over token ids [0, mask).
        # NOTE(review): assumes the mask id lies above all regular token ids.
        if mask_pos[0].numel() > 0:
            rows = trans_prob[mask_pos]
            row_sum = rows.sum(dim=-1, keepdim=True)
            zero_rows = row_sum.squeeze(-1) == 0.0
            rows = rows / torch.where(zero_rows.unsqueeze(-1), torch.ones_like(row_sum), row_sum)
            if zero_rows.any():
                uniform_prob = torch.zeros_like(rows[0])
                uniform_prob[:mask] = 1.0 / mask
                rows[zero_rows] = uniform_prob
            trans_prob[mask_pos] = rows

        new_xt = _sample_tokens(trans_prob)
        new_xt[xt == pad] = pad
        new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        ext = None
        if not last_step:
            # ext[b, g] == 1 inserts a mask just before position g.
            ext = torch.bernoulli((len_rate * dt).clamp(0.0, 1.0)).long()
            xt_len = xt.ne(pad).sum(dim=1)
            gaps = torch.arange(ext.size(1), device=device).view(1, -1)
            ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()
            # Reject every insertion of a sequence that would overflow.
            valid = xt_len + ext.sum(dim=1) <= max_length
            ext = ext * valid.view(batch_size, 1).long()
            new_len = xt_len + ext.sum(dim=1)
            ext_ex = ext.int().cumsum(dim=1)

            xt_tmp = torch.full_like(xt, pad)
            xt_tmp[pos_idx_L < new_len.view(batch_size, 1)] = mask
            new_pos_orig = pos_idx_L + ext_ex[:, :max_length]
            orig_mask = pos_idx_L < xt_len.view(batch_size, 1)
            xt_tmp[batch_idx_L[orig_mask], new_pos_orig[orig_mask]] = new_xt[orig_mask]
        else:
            xt_tmp = new_xt

        if return_trace:
            t_host = t.tolist()
            changed = (xt != pad) & (xt != new_xt)
            for b, pos in changed.nonzero().tolist():
                sampling_trace[b].append(
                    SamplingTraceDatapoint(
                        t=t_host[b],
                        event_type="change",
                        position=pos,
                        token=new_xt[b, pos].item(),
                    )
                )
            if ext is not None:
                for b, gap in reversed(ext.nonzero().tolist()):
                    sampling_trace[b].append(
                        SamplingTraceDatapoint(
                            t=t_host[b],
                            event_type="insertion",
                            position=gap,
                            token=mask,
                        )
                    )

        xt = xt_tmp
        t = t + dt

    samples = xt

    decoded_samples = tokenizer.batch_decode(samples)
    validSequences = [seq for seq in decoded_samples if analyzer.is_peptide(seq)]

    print("len valid sequences:", len(validSequences))
    valid_fraction = len(validSequences) / batch_size

    if validSequences:
        score_vectors = reward_model(input_seqs=validSequences)
        average_scores = score_vectors.T

        affinity = average_scores[0]
        sol = average_scores[1]
        hemo = average_scores[2]
        nf = average_scores[3]
        permeability = average_scores[4]
    else:
        zeros = [0.0]
        affinity = sol = hemo = nf = permeability = zeros

    if dataframe:
        if validSequences:
            df = pd.DataFrame({
                "Peptide Sequence": validSequences,
                "Binding Affinity": affinity,
                "Solubility": sol,
                "Hemolysis": hemo,
                "Nonfouling": nf,
                "Permeability": permeability,
            })
        else:
            # The [0.0] placeholders cannot share a frame with zero sequences
            # (pandas requires equal column lengths), so return an empty frame.
            df = pd.DataFrame({
                "Peptide Sequence": [],
                "Binding Affinity": [],
                "Solubility": [],
                "Hemolysis": [],
                "Nonfouling": [],
                "Permeability": [],
            })
        return samples, affinity, sol, hemo, nf, permeability, valid_fraction, df

    return samples, affinity, sol, hemo, nf, permeability, valid_fraction
|
|
@torch.no_grad()
def mdm_tau_leaping_sampling(
    model: MaskedDiffusionModule,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    temperature: float = 1.0,
):
    """Sample fixed-length sequences from a masked diffusion model by tau-leaping.

    Starts from all-``mask`` sequences at ``t = 0``. On each non-final step,
    Poisson jump counts with rate ``unmask_rate * dt`` are drawn per
    (position, token). A masked position is unmasked only when exactly one
    jump fired. On the final step every remaining mask is replaced by its
    highest-rate non-mask token.

    Args:
        model: module called as ``model(xt, t)``; ``model.interpolant.to_actual_rate``
            turns its output into an object exposing ``unmask_rate``.
        steps: number of steps (step size ``1 / steps``).
        mask: id of the mask token.
        pad: unused (signature parity with the other samplers).
        batch_size: number of sequences to sample.
        max_length: length of every sequence.
        return_trace: must be False; tracing is not supported.
        temperature: unused.

    Returns:
        ``(samples, [])`` with int64 ``samples`` of shape ``(batch_size, max_length)``.
    """
    assert not return_trace, "Trace is not yet supported"
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), mask, dtype=torch.int64, device=device)
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)

    for i in range(steps):
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate
        mask_pos = xt == mask

        if i == steps - 1:
            # Fill all remaining masks greedily. The mask id itself is excluded
            # so that argmax can never leave a position masked.
            final_rate = unmask_rate.clone()
            final_rate[..., mask] = float("-inf")
            new_token = final_rate.argmax(dim=2)
            xt = torch.where(mask_pos, new_token, xt)
            t = t + dt
            continue

        # Jump counts; only masked positions may jump, never to the mask id.
        counts = torch.poisson(unmask_rate * dt).long()
        counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
        counts[..., mask] = 0

        # Accept only unambiguous single jumps; other positions stay masked.
        one_event = counts.sum(dim=2) == 1
        new_token = counts.argmax(dim=2)
        xt = torch.where(one_event, new_token, xt)
        t = t + dt

    return xt, []
|
|
| |
# Discrete distribution over candidate sequence lengths, as {length: probability}.
# The probabilities sum to 1.0. calculate_rate and calculate_rate_batch iterate
# over it to weight each candidate length.
lengths = {
    4: 0.1,
    16: 0.4,
    32: 0.4,
    64: 0.1,
}
|
|
def binomial_mass(k, n, p):
    """
    Return the binomial probability mass P(X = k) for X ~ Binomial(n, p).

    Args:
        k (int): Number of successes.
        n (int): Number of trials.
        p (float): Probability of success in a single trial.

    Returns:
        float: ``C(n, k) * p**k * (1 - p)**(n - k)``, or ``0.0`` when ``k`` is
        outside ``[0, n]`` (including a negative ``n``).
    """
    import math

    if not 0 <= k <= n:
        # Outside the support: k successes are impossible in n trials.
        return 0.0

    # float(comb) equals the correctly rounded factorial quotient exactly.
    coefficient = float(math.comb(n, k))
    return coefficient * (p ** k) * ((1 - p) ** (n - k))
|
|
def calculate_rate_batch(alpha_t, len_t, length_probs=None):
    """
    Calculate rate for a batch of alpha_t and len_t values.

    For every element this computes

        sum_L (L - len_t) * P(L) * Bin(len_t; L, alpha_t)
        -------------------------------------------------
              sum_L P(L) * Bin(len_t; L, alpha_t)

    over the lengths ``L`` of the length distribution with ``0 <= len_t <= L``.
    Elements whose denominator is zero (no admissible length) get 0.

    Args:
        alpha_t (torch.Tensor): Tensor of shape (batch_size,), the Binomial
            success probabilities (must lie in [0, 1]).
        len_t (torch.Tensor): Tensor of shape (batch_size,), current lengths.
        length_probs (dict[int, float], optional): distribution over lengths
            as ``{length: probability}``. Defaults to the module-level
            ``lengths`` table.

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) containing calculated rates
    """
    if length_probs is None:
        length_probs = lengths

    batch_size = alpha_t.shape[0]
    device = alpha_t.device

    nom = torch.zeros(batch_size, device=device)
    denom = torch.zeros(batch_size, device=device)

    for length, probability in length_probs.items():
        # Only lengths that can produce len_t contribute (Binomial support).
        valid_mask = (len_t <= length) & (len_t >= 0)
        if not valid_mask.any():
            continue

        valid_indices = valid_mask.nonzero(as_tuple=True)[0]
        valid_len_t = len_t[valid_indices]
        valid_alpha_t = alpha_t[valid_indices]

        binom_dist = torch.distributions.Binomial(total_count=length, probs=valid_alpha_t)
        # Cast defensively so integer length tensors are evaluated in the
        # probability dtype; the values are whole numbers, so nothing changes.
        binom_probs = binom_dist.log_prob(valid_len_t.to(valid_alpha_t.dtype)).exp()

        nom[valid_indices] += (length - valid_len_t) * probability * binom_probs
        denom[valid_indices] += probability * binom_probs

    # Avoid 0/0 where no length in the distribution is admissible.
    result = torch.zeros_like(nom)
    div_mask = denom > 0
    result[div_mask] = nom[div_mask] / denom[div_mask]

    return result
|
|
| |
def calculate_rate(alpha_t, len_t):
    """Legacy scalar version of calculate_rate_batch.

    Returns ``sum_L (L - len_t) * P(L) * Bin(len_t; L, alpha_t)`` divided by
    ``sum_L P(L) * Bin(len_t; L, alpha_t)`` over the module-level ``lengths``
    table, restricted to lengths ``L >= len_t``. Returns ``0.0`` when the
    denominator is zero. A tensor ``alpha_t`` with at least one dimension is
    delegated to ``calculate_rate_batch``.
    """
    if isinstance(alpha_t, torch.Tensor) and alpha_t.ndim > 0:
        return calculate_rate_batch(alpha_t, len_t)

    weighted_gap = 0
    evidence = 0
    for length, prior in lengths.items():
        if length >= len_t:
            # Binomial mass is evaluated once and reused for both sums.
            mass = binomial_mass(len_t, length, alpha_t)
            weighted_gap += (length - len_t) * prior * mass
            evidence += prior * mass

    return 0.0 if evidence == 0 else weighted_gap / evidence
|
|
|
|
def _aomi_confidence_scores(unmask_rate, xt, pad, method):
    """Score every position of ``xt`` (higher = unmask earlier); shape (batch, seq_len)."""
    import torch.nn.functional as F

    if method == "position":
        # Earlier positions, relative to the current non-pad length, score higher.
        batch_size, seq_len = xt.shape
        cur_len = (xt != pad).sum(dim=1)
        positions = torch.arange(seq_len, device=xt.device).unsqueeze(0).expand(batch_size, -1)
        return 1.0 - (positions.float() / cur_len.unsqueeze(1).float().clamp(min=1))

    if method not in ("top_prob", "prob_diff", "entropy"):
        raise ValueError(f"Unknown confidence_method: {method}")

    # NOTE(review): the softmax is taken over the unmask *rates*, not logits;
    # kept as in the original implementation.
    probs = F.softmax(unmask_rate, dim=-1)
    if method == "top_prob":
        return probs.max(dim=-1)[0]
    if method == "prob_diff":
        top2, _ = torch.topk(probs, k=2, dim=-1)
        return top2[:, :, 0] - top2[:, :, 1]
    # "entropy": confidence is the negated entropy, i.e. sum p * log p.
    return torch.sum(probs * torch.log(probs + 1e-10), dim=-1)


def _aomi_insert_masks(xt, new_xt, len_rate, dt, mask, pad):
    """Tau-leap insertion of ``mask`` tokens into the gaps of ``xt``; returns the new buffer."""
    batch_size, max_length = xt.shape
    device = xt.device

    # ext[b, g] = masks inserted into gap g (gap g sits before position g; gap
    # xt_len is after the last token). Assumes len_rate is (batch, max_length + 1).
    ext = torch.poisson(len_rate * dt).long()
    xt_len = xt.ne(pad).sum(dim=1)
    gaps = torch.arange(max_length + 1, device=device).view(1, -1)
    ext = ext * (gaps <= xt_len.view(batch_size, 1)).long()

    # Rows whose insertions would overflow the buffer get no insertions at all.
    fits = xt_len + ext.sum(dim=1) <= max_length
    ext = ext * fits.view(batch_size, 1).long()
    # Length is computed AFTER rejection so rejected rows keep their length.
    new_len = xt_len + ext.sum(dim=1)

    pos_idx = torch.arange(max_length, device=device).view(1, max_length).expand(batch_size, max_length)
    batch_idx = torch.arange(batch_size, device=device).view(batch_size, 1).expand(batch_size, max_length)

    out = torch.full_like(xt, pad)
    out[pos_idx < new_len.view(batch_size, 1)] = mask
    # The token at position p moves right by the insertions in gaps 0..p.
    shifted = pos_idx + ext.cumsum(dim=1)[:, :max_length]
    keep = pos_idx < xt_len.view(batch_size, 1)
    out[batch_idx[keep], shifted[keep]] = new_xt[keep]
    return out


@torch.no_grad()
def any_order_mask_insertion_tau_leaping_sampling(
    model: torch.nn.Module,
    steps: int,
    mask: int,
    pad: int,
    batch_size: int,
    max_length: int,
    return_trace: bool = False,
    confidence_based_sampling: bool = True,
    alpha: float = 5.0,
    max_window: int = 32,
    confidence_method: str = "prob_diff",
    use_sliding_window: bool = False,
    temperature: float = 1.0,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Sample from an any-order mask-insertion model with tau-leaping.

    Starts from a ``(batch_size, max_length)`` buffer that is entirely ``pad``.
    Every step except the last:

    1. Unmasks some ``mask`` positions, either
       - (``confidence_based_sampling``) ``Poisson(num_masks * dt)`` of the
         highest-confidence masked positions (clamped to the number of
         eligible positions), each set to the argmax of ``unmask_rate``.
         With ``use_sliding_window`` only the leftmost
         ``min(alpha * count, max_window, num_masks)`` masks are eligible.
       - otherwise, per-position ``Poisson(unmask_rate * dt)`` events, where a
         position is unmasked only when exactly one event fires.
    2. Inserts new ``mask`` tokens into the gaps with
       ``Poisson(length_rate * dt)`` counts. Rows whose insertions would
       exceed ``max_length`` receive no insertions that step.

    The last step fills every remaining mask with the argmax of ``unmask_rate``.

    The model output, after ``model.interpolant.to_actual_rate``, must expose
    ``unmask_rate`` (indexed as (batch, seq, vocab)) and ``length_rate``
    (assumed (batch, max_length + 1)).

    Args:
        confidence_method: one of "position", "top_prob", "prob_diff", "entropy".
        temperature: currently unused.

    Returns:
        ``(xt, trace)``: the final tokens ``(batch_size, max_length)``. When
        ``return_trace`` is set, ``trace`` holds the token tensor after each
        non-final step; otherwise it is an empty list.

    Raises:
        ValueError: unknown ``confidence_method`` (raised when it is first used).
    """
    device = next(model.parameters()).device
    xt = torch.full((batch_size, max_length), pad, dtype=torch.int64, device=device)
    sampling_trace = []
    dt = 1.0 / steps
    t = torch.zeros(batch_size, device=device)
    neg_inf = torch.tensor(-float("inf"), device=device)

    for i in range(steps):
        pred = model(xt, t)
        pred = model.interpolant.to_actual_rate(xt, pred, t)
        unmask_rate = pred.unmask_rate
        len_rate = pred.length_rate

        if i == steps - 1:
            # Final step: greedily fill every remaining mask; no insertions.
            mask_pos = xt == mask
            new_xt = xt.clone()
            new_xt[mask_pos] = unmask_rate.argmax(dim=2)[mask_pos]
            xt = new_xt
            break

        if confidence_based_sampling > 0.0:
            mask_positions = xt == mask
            num_mask_positions = mask_positions.sum(dim=1)
            # How many positions to unmask in each row this step.
            unmask_counts = torch.poisson(num_mask_positions.float() * dt).long()
            confidence = _aomi_confidence_scores(unmask_rate, xt, pad, confidence_method)

            if use_sliding_window:
                # Only the leftmost k masked positions are eligible.
                k_values = torch.minimum(
                    (alpha * unmask_counts).long().clamp(max=max_window),
                    num_mask_positions,
                )
                candidates = mask_positions & (mask_positions.cumsum(dim=1) <= k_values.unsqueeze(1))
            else:
                candidates = mask_positions
            confidence = torch.where(candidates, confidence, neg_inf)

            # A Poisson draw can exceed the number of eligible positions. Without
            # this clamp, topk would pick -inf (pad / already-decoded) positions
            # and overwrite them, or fail outright when k > max_length.
            unmask_counts = torch.minimum(unmask_counts, candidates.sum(dim=1))

            new_xt = xt.clone()
            max_unmask = int(unmask_counts.max().item())
            if max_unmask > 0:
                _, top_indices = torch.topk(confidence, k=max_unmask, dim=1, largest=True)
                take = torch.arange(max_unmask, device=device).unsqueeze(0) < unmask_counts.unsqueeze(1)
                rows = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, max_unmask)[take]
                cols = top_indices[take]
                new_xt[rows, cols] = unmask_rate.argmax(dim=-1)[rows, cols]
        else:
            counts = torch.poisson(unmask_rate * dt).long()
            mask_pos = xt == mask
            # Only masked positions may jump, and never to the mask token itself.
            counts[~mask_pos.unsqueeze(-1).expand_as(counts)] = 0
            counts[..., mask] = 0
            one_event = counts.sum(dim=2) == 1
            new_token = counts.argmax(dim=2)
            new_xt = xt.clone()
            new_xt[one_event] = new_token[one_event]
            new_xt = torch.where(xt == pad, pad, new_xt)
            new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt)

        xt = _aomi_insert_masks(xt, new_xt, len_rate, dt, mask, pad)
        t = t + dt
        if return_trace:
            sampling_trace.append(xt)

    return xt, sampling_trace
|
|