| import os |
| import torch |
| import torch.nn as nn |
| import pytorch_lightning as pl |
| from omegaconf import DictConfig |
| import torch.nn.functional as F |
| from model.transformer import AnyOrderMaskInsertionFlow |
| from model.interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction |
| from .bregman import jump_kernel_elbo, mse |
| from .schedule import get_schedule_from_config |
| from lightning_modules.any_order import AnyOrderInsertionFlowModule |
| from model.model_wrapper import RemaskingAnyOrder |
| from sampling import _sample_tokens |
|
|
| import re |
| from typing import Dict, Any |
| from dataclasses import dataclass |
|
|
def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build a copy of ``state_dict`` with every '._orig_mod.' key segment collapsed.

    Each occurrence of '._orig_mod.' inside a key is replaced by a single '.',
    so for example
        'model._orig_mod.vocab_embed.embedding'
    is stored under
        'model.vocab_embed.embedding'

    Values are carried over untouched and the input mapping is not modified.
    If two keys collapse to the same name, the later one wins.
    """
    return {
        re.sub(r"\._orig_mod\.", ".", name): entry
        for name, entry in state_dict.items()
    }
|
|
|
|
| @torch.no_grad() |
| def _binary_auc(scores: torch.Tensor, labels: torch.Tensor) -> float: |
| """Rank-based AUROC (Mann-Whitney U statistic). |
| |
| AUC = P(score[pos] > score[neg]); 0.5 means no discrimination. Returns NaN |
| when only one class is present (AUC undefined). Ties are not averaged, which |
| is fine for continuous logits used here. |
| """ |
| scores = scores.float().reshape(-1) |
| labels = labels.float().reshape(-1) |
| n_pos = labels.sum() |
| n_neg = labels.numel() - n_pos |
| if n_pos == 0 or n_neg == 0: |
| return float("nan") |
| order = torch.argsort(scores) |
| ranks = torch.empty_like(scores) |
| ranks[order] = torch.arange(1, scores.numel() + 1, device=scores.device, dtype=scores.dtype) |
| auc = (ranks[labels == 1].sum() - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg) |
| return auc.item() |
|
|
|
|
| class AnyOrderInsertionFlowModuleFT(AnyOrderInsertionFlowModule): |
| """ |
| Wrapper around AnyOrderInsertionFlowModule that adds adaptive schedule model |
| for fine-tuning. Can load a pretrained AnyOrderInsertionFlowModule checkpoint |
| and add the schedule model on top. |
| """ |
| def __init__(self, config, args, pretrained_checkpoint, insertion_planner=False): |
| |
| super().__init__(config) |
| |
| self.args = args |
| self.insertion_planner = insertion_planner |
| |
| |
| self.save_hyperparameters(ignore=['pretrained_checkpoint', 'args']) |
| |
| |
| if pretrained_checkpoint is not None: |
| self.load_pretrained_model(pretrained_checkpoint) |
| |
| |
| self.planner = RemaskingAnyOrder( |
| backbone=self, |
| d_model=self.config.model.hidden_size, |
| insertion_planner=insertion_planner) |
| |
| def load_pretrained_model(self, checkpoint_path: str): |
| """ |
| Load pretrained AnyOrderInsertionFlowModule weights. |
| Only loads the base model and interpolant, not the schedule model. |
| """ |
| print(f"Loading pretrained model from {checkpoint_path}") |
| checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False) |
| |
| |
| if 'state_dict' in checkpoint: |
| state_dict = checkpoint['state_dict'] |
| else: |
| state_dict = checkpoint |
| |
| |
| state_dict = strip_orig_mod_keys(state_dict) |
| |
| |
| base_state_dict = {k: v for k, v in state_dict.items() |
| if not k.startswith('planner.')} |
| |
| |
| |
| incompatible_keys = self.load_state_dict(base_state_dict, strict=False) |
| |
| |
| unexpected_missing = [k for k in incompatible_keys.missing_keys |
| if not k.startswith('planner.')] |
| planner_missing = [k for k in incompatible_keys.missing_keys |
| if k.startswith('planner.')] |
| |
| if unexpected_missing: |
| print(f"Warning: Unexpected missing keys from pretrained checkpoint: {unexpected_missing}") |
| if planner_missing: |
| print(f"Note: Planner will be trained from scratch ({len(planner_missing)} parameters)") |
| if incompatible_keys.unexpected_keys: |
| print(f"Warning: Unexpected keys in pretrained checkpoint: {incompatible_keys.unexpected_keys}") |
| |
| |
| if self.config.training.get('freeze_base_model', False): |
| print("Freezing base model parameters") |
| for name, param in self.named_parameters(): |
| if not name.startswith('planner.'): |
| param.requires_grad = False |
|
|
| def forward(self, x, t, return_features=False): |
| |
| return super().forward(x, t, return_features=return_features) |
|
|
| def training_loss(self, x1, t): |
| |
| |
| unmask_loss, insertion_loss, total_loss = super().training_loss(x1, t) |
| return unmask_loss, insertion_loss, total_loss |
| |
| |
| def training_step(self, batch, batch_idx): |
| |
| if isinstance(batch, dict): |
| batch = batch["input_ids"] |
|
|
| x1 = batch |
| t = self.sample_time(x1.shape[0], x1.device) |
|
|
| |
| unmask_loss, len_loss, loss = self.training_loss(x1, t) |
| |
| |
| self.log("train/unmask_loss", unmask_loss, prog_bar=True) |
| self.log("train/len_loss", len_loss, prog_bar=True) |
| self.log("train/total_loss", loss, prog_bar=True) |
| |
| return loss |
|
|
| def validation_step(self, batch, batch_idx): |
| if isinstance(batch, dict): |
| batch = batch["input_ids"] |
|
|
| x1 = batch |
| t = self.sample_time(x1.shape[0], x1.device) |
| unmask_loss, len_loss, loss = self.training_loss(x1, t) |
|
|
| self.log("val/unmask_loss", unmask_loss, prog_bar=True, sync_dist=True) |
| self.log("val/len_loss", len_loss, prog_bar=True, sync_dist=True) |
| self.log("val_loss", loss, prog_bar=True, sync_dist=True) |
|
|
| return loss |
| |
| @classmethod |
| def load_from_checkpoint(cls, checkpoint_path, map_location=None, strict=True, **kwargs): |
| """ |
| Custom checkpoint loading that handles finetuned checkpoints wrapped by PeptideFinetuner. |
| Extracts config from original pretrained checkpoint and loads finetuned weights. |
| """ |
| print(f"Loading finetuned checkpoint from {checkpoint_path}") |
| checkpoint = torch.load(checkpoint_path, map_location=map_location or 'cpu', weights_only=False) |
| |
| |
| hparams = checkpoint.get('hyper_parameters', {}) |
| state_dict = checkpoint.get('state_dict', {}) |
| |
| |
| has_policy_prefix = any(k.startswith('policy_model.') for k in state_dict.keys()) |
| |
| if has_policy_prefix: |
| |
| |
| vocab_size = None |
| for k, v in state_dict.items(): |
| if 'vocab_embed.embedding' in k: |
| vocab_size = v.shape[0] |
| break |
| |
| is_molecule_model = vocab_size is not None and vocab_size > 1000 |
| model_type = "MolFinetuner" if is_molecule_model else "PeptideFinetuner" |
| print(f"Detected wrapped finetuned checkpoint ({model_type}, vocab_size={vocab_size})") |
| |
| |
| if 'args' not in hparams: |
| raise ValueError(f"Cannot find 'args' in hyperparameters. This checkpoint may not be from {model_type}.") |
| |
| args = hparams['args'] |
| print(f"Found args in hyperparameters, type: {type(args)}") |
| |
| |
| |
| original_ckpt_path = None |
| if hasattr(args, 'checkpoint_path'): |
| original_ckpt_path = args.checkpoint_path |
| elif isinstance(args, dict) and 'checkpoint_path' in args: |
| original_ckpt_path = args['checkpoint_path'] |
| |
| |
| |
| if original_ckpt_path is None: |
| _repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
| if is_molecule_model: |
| original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_mol.ckpt') |
| print(f"Warning: checkpoint_path not found in args, using default molecule pretrained checkpoint") |
| else: |
| original_ckpt_path = os.path.join(_repo_root, 'pretrained', 'anylength_pep.ckpt') |
| print(f"Warning: checkpoint_path not found in args, using default peptide pretrained checkpoint") |
| |
| |
| |
| if 'config' in checkpoint: |
| print("Found config directly in checkpoint") |
| config = checkpoint['config'] |
| else: |
| print(f"Config not in checkpoint, loading from original checkpoint: {original_ckpt_path}") |
| |
| |
| orig_ckpt = torch.load(original_ckpt_path, map_location='cpu', weights_only=False) |
| if 'config' not in orig_ckpt: |
| raise ValueError(f"Original checkpoint {original_ckpt_path} does not contain config") |
| |
| config = orig_ckpt['config'] |
| |
| |
| |
| from omegaconf import OmegaConf |
| if hasattr(config, 'training'): |
| OmegaConf.set_struct(config, False) |
| config.training.use_adaptive_schedule = True |
| OmegaConf.set_struct(config, True) |
| |
| |
| if not hasattr(args, '__dict__'): |
| |
| class Args: |
| pass |
| args_obj = Args() |
| for k, v in args.items(): |
| setattr(args_obj, k, v) |
| args = args_obj |
| |
| |
| model = cls( |
| config=config, |
| args=args, |
| pretrained_checkpoint=None, |
| insertion_planner=getattr(args, 'insertion_planner', False) |
| ) |
| |
| |
| policy_state = {} |
| for k, v in state_dict.items(): |
| if k.startswith('policy_model.'): |
| |
| new_key = k[len('policy_model.'):] |
| policy_state[new_key] = v |
| |
| |
| incompatible = model.load_state_dict(policy_state, strict=False) |
| if incompatible.missing_keys or incompatible.unexpected_keys: |
| print(f"Warning: Incompatible keys when loading finetuned weights:") |
| if incompatible.missing_keys: |
| print(f" Missing: {incompatible.missing_keys[:5]}...") |
| if incompatible.unexpected_keys: |
| print(f" Unexpected: {incompatible.unexpected_keys[:5]}...") |
| |
| |
| if model.use_ema: |
| if "ema_params" in checkpoint: |
| |
| model.ema_params = checkpoint["ema_params"] |
| print("Loaded EMA params from checkpoint") |
| else: |
| |
| model.ema_params = { |
| name: param.clone().detach() |
| for name, param in model.named_parameters() |
| } |
| print("Initialized EMA params from current model state") |
| else: |
| model.ema_params = {} |
| |
| |
| if "planner_state" in checkpoint and hasattr(model, 'planner'): |
| model.planner.load_state_dict(checkpoint["planner_state"], strict=False) |
| print("Loaded planner state from checkpoint") |
| |
| return model |
| else: |
| |
| |
| raise NotImplementedError( |
| "Direct finetuned checkpoints (not wrapped by PeptideFinetuner) are not yet supported. " |
| "Please provide config and args as kwargs." |
| ) |
| |
| def on_save_checkpoint(self, checkpoint): |
| """Save config and EMA params, including planner state.""" |
| |
| super().on_save_checkpoint(checkpoint) |
| |
| |
| if hasattr(self, 'planner'): |
| checkpoint["planner_state"] = self.planner.state_dict() |
| |
| def on_load_checkpoint(self, checkpoint): |
| """Load config and reinitialize interpolant, including planner.""" |
| |
| |
| if "config" in checkpoint: |
| |
| super().on_load_checkpoint(checkpoint) |
| else: |
| |
| |
| if self.use_ema and "ema_params" in checkpoint: |
| self.ema_params = checkpoint["ema_params"] |
| |
| |
| if hasattr(self, 'planner') and "planner_state" in checkpoint: |
| self.planner.load_state_dict(checkpoint["planner_state"]) |
| print("Loaded planner from checkpoint") |
| |
| def loss_wdce_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0): |
| r""" |
| Weighted denoising cross entropy loss |
| X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X) |
| |
| log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer) |
| x: [B, L] (no mask) |
| num_replicates: R, number of replicates of each row in x |
| weight_func: w(lambda) for each sample, 1/lambda by default |
| centering_strength: float, controls how much of the mean is subtracted (DMPO-style) |
| softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights) |
| """ |
| |
| batch = x.repeat_interleave(num_replicates, dim=0) |
| |
| batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) |
| if centering: |
| batch_weights = batch_weights - centering_strength * batch_weights.mean() |
| |
| batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) |
| |
| lamda = torch.rand(batch.shape[0], device=batch.device) |
| lamda_weights = weight_func(lamda).clamp(max=1e5) |
| |
| t = lamda |
| |
| |
| interpolant_sample = self.interpolant.sample_interpolant(t, batch) |
| unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch) |
|
|
| prediction: ModelPrediction = self(interpolant_sample.xt, t) |
|
|
| scale_factor = self.config.interpolant.max_length |
|
|
| match self.unmask_loss_fn: |
| case "elbo": |
| mask_indices = interpolant_sample.mask_indices |
| unmask_loss_all = torch.zeros_like(unmask_weight) |
| unmask_loss_all[mask_indices] = unmask_weight[mask_indices] * F.cross_entropy( |
| prediction.token_logits[mask_indices], |
| interpolant_sample.unmasked[mask_indices], |
| reduction="none", |
| ) |
| unmask_loss = unmask_loss_all.sum(dim=1) / scale_factor |
| case _: |
| raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}") |
|
|
| match self.insert_loss_fn: |
| case "expectation": |
| gaps, gaps_mask = interpolant_sample.gaps_and_mask |
| insertion_loss_all = torch.zeros_like(insert_weight) |
| insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * jump_kernel_elbo( |
| gaps[gaps_mask], prediction.expected_gaps[gaps_mask] |
| ) |
| insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor |
|
|
| case "distribution": |
| gaps, gaps_mask = interpolant_sample.gaps_and_mask |
| insertion_loss_all = torch.zeros_like(insert_weight) |
| insertion_loss_all[gaps_mask] = insert_weight[gaps_mask] * F.cross_entropy( |
| prediction.length_posterior[gaps_mask], gaps[gaps_mask] |
| ) |
| insertion_loss = insertion_loss_all.sum(dim=1) / scale_factor |
|
|
| total_loss = unmask_loss + insertion_loss |
| |
| |
| weighted_loss = total_loss * batch_weights |
| return weighted_loss.mean() |
| |
| def one_step_sampler(self, xt, t, pred_rate=None): |
| """ |
| Sample one step of unmasking using model predictions. |
| |
| Args: |
| xt: Current state [B, L] |
| t: Time [B] |
| pred_rate: Optional pre-computed ModelPrediction. If None, will compute from model. |
| |
| Returns: |
| new_xt: Next state [B, L] |
| update_ids: Boolean mask of updated positions [B, L] |
| """ |
| mask = self.interpolant.mask_token |
| pad = self.interpolant.pad_token |
| batch_size, L = xt.shape |
| device = xt.device |
| steps = self.args.total_num_steps |
| dt = 1.0 / steps |
| max_length = self.interpolant.max_length |
| |
| batch_idx_L = ( |
| torch.arange(batch_size, device=device) |
| .view(batch_size, 1) |
| .expand(batch_size, L) |
| ) |
| pos_idx_L = ( |
| torch.arange(L, device=device) |
| .view(1, L) |
| .expand(batch_size, L) |
| ) |
| |
| |
| if pred_rate is None: |
| pred_rate = self(xt, t) |
| pred_rate = self.interpolant.to_actual_rate(xt, pred_rate, t) |
| unmask_rate = pred_rate.unmask_rate |
| len_rate = pred_rate.length_rate |
|
|
| |
| mask_pos = (xt == self.interpolant.mask_token).nonzero(as_tuple=True) |
| unmask_rate[xt != mask] = 0 |
| unmask_rate[mask_pos + (mask,)] = 0 |
| unmask_rate[mask_pos + (mask,)] = -unmask_rate[mask_pos + (slice(None),)].sum(dim=1) |
| trans_prob = (unmask_rate * dt).clamp(0.0, 1.0) |
| |
| |
| _xt = xt.clone() |
| _xt[xt == pad] = mask |
| trans_prob.scatter_add_( |
| 2, |
| _xt.unsqueeze(-1), |
| torch.ones_like(_xt.unsqueeze(-1), dtype=trans_prob.dtype), |
| ) |
|
|
| trans_prob[mask_pos + (mask,)] = 0.0 |
| |
| |
| prob_sum = trans_prob[mask_pos].sum(dim=-1, keepdim=True) |
| |
| mask_has_zero_prob = (prob_sum.squeeze(-1) == 0.0) |
| if mask_has_zero_prob.any(): |
| |
| num_zero_prob = mask_has_zero_prob.sum().item() |
| uniform_prob = torch.zeros((num_zero_prob, trans_prob.shape[-1]), device=device, dtype=trans_prob.dtype) |
| uniform_prob[:, :mask] = 1.0 / mask |
| trans_prob[mask_pos[0][mask_has_zero_prob], mask_pos[1][mask_has_zero_prob]] = uniform_prob |
| else: |
| |
| trans_prob[mask_pos] = trans_prob[mask_pos] / prob_sum |
|
|
| new_xt = _sample_tokens(trans_prob) |
| new_xt[xt == pad] = pad |
| new_xt = torch.where((xt != mask) & (xt != pad), xt, new_xt) |
| |
| |
| |
| |
| |
| |
| |
| |
| old_update_ids = (xt != new_xt) & (xt != pad) & (xt != mask) |
| |
| |
| update_ids = (xt != new_xt) & (xt != pad) |
| |
| if self.insertion_planner is False: |
| return new_xt, update_ids |
| |
| |
| ext = torch.poisson(len_rate * dt).long() |
| xt_len = xt.ne(pad).sum(dim=1) |
| |
| actual_max_length = ext.shape[1] - 1 |
| gaps = torch.arange(ext.shape[1], device=device).view(1, -1) |
| ext = ext * (gaps <= xt_len.view(batch_size, 1)).long() |
| total_ext = ext.sum(dim=1) |
| valid = xt_len + total_ext <= actual_max_length |
| ext = ext * valid.view(batch_size, 1).long() |
|
|
| ext_ex = ext.int().cumsum(dim=1) |
| new_len = xt_len + total_ext |
|
|
| xt_tmp = torch.full_like(xt, pad) |
| |
| pos_idx_for_fill = torch.arange(xt_tmp.shape[1], device=device).view(1, -1).expand(batch_size, -1) |
| mask_fill = pos_idx_for_fill < new_len.view(batch_size, 1) |
| xt_tmp[mask_fill] = mask |
|
|
| new_pos_orig = pos_idx_L + ext_ex[:, :actual_max_length] |
| orig_mask = pos_idx_L < xt_len.view(batch_size, 1) |
| flat_b = batch_idx_L[orig_mask] |
| flat_p = new_pos_orig[orig_mask] |
| xt_tmp[flat_b, flat_p] = new_xt[orig_mask] |
| |
| new_ins_xt = xt_tmp |
| |
| |
| newly_inserted_masks = (new_ins_xt == mask) & (xt != mask) & (xt != pad) |
| |
| update_ins_ids = newly_inserted_masks |
| |
| return new_xt, update_ids, new_ins_xt, update_ins_ids |
| |
| def loss_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0): |
| r""" |
| Weighted denoising cross entropy loss |
| X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X) |
| |
| log_rnd: [B] — pre-computed importance weights (already softmax-normalized over the full buffer) |
| x: [B, L] (no mask) |
| num_replicates: R, number of replicates of each row in x |
| weight_func: w(lambda) for each sample, 1/lambda by default |
| centering_strength: float, controls how much of the mean is subtracted (DMPO-style) |
| softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights) |
| """ |
| |
| batch = x.repeat_interleave(num_replicates, dim=0) |
| batch_size = batch.shape[0] |
| |
| batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) |
| if centering: |
| batch_weights = batch_weights - centering_strength * batch_weights.mean() |
| |
| batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) |
| |
| lamda = torch.rand(batch.shape[0], device=batch.device) |
| lamda_weights = weight_func(lamda).clamp(max=1e5) |
| |
| t = lamda |
| scale_factor = self.config.interpolant.max_length |
| |
| |
| interpolant_sample = self.interpolant.sample_interpolant(t, batch) |
| unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch) |
|
|
| prediction: ModelPrediction = self(interpolant_sample.xt, t) |
| |
| with torch.no_grad(): |
| sampler_out = self.one_step_sampler(interpolant_sample.xt, t, prediction) |
| |
| xs, update_ids = sampler_out[0], sampler_out[1] |
|
|
| |
| |
| |
| planner = self.planner(xs, t) |
| remasking_conf = planner["remasking_conf"] |
|
|
| |
| |
| |
| st = interpolant_sample.st |
| batch_reordered = torch.gather(batch, 1, st) |
| |
| binary_label = (xs == batch_reordered).float() |
| |
| |
| per_token_loss = F.binary_cross_entropy_with_logits( |
| remasking_conf.squeeze(-1), |
| binary_label, |
| reduction="none" |
| ) |
| |
| per_token_loss = per_token_loss * update_ids.float() |
| |
| |
| per_sample_loss = per_token_loss.sum(dim=1) / (update_ids.sum(dim=1).float() + 1e-8) |
| |
| |
| weighted_loss = per_sample_loss * batch_weights |
|
|
| |
| with torch.no_grad(): |
| metrics = {} |
| sel_u = update_ids.bool() |
| if sel_u.any(): |
| u_scores = remasking_conf.squeeze(-1)[sel_u] |
| u_labels = binary_label[sel_u] |
| metrics["unmask_auc"] = _binary_auc(u_scores, u_labels) |
| metrics["unmask_label_mean"] = u_labels.mean().item() |
| metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item() |
| metrics["unmask_n"] = float(sel_u.sum().item()) |
| self._last_planner_metrics = metrics |
|
|
| return weighted_loss.mean() |
| |
| def loss_insert_planner_flexible(self, log_rnd, x, num_replicates=16, weight_func=lambda l: 1/l, eps=1e-3, centering=False, centering_strength=1.0, softmax_temperature=1.0): |
| r""" |
| Weighted denoising cross entropy loss |
| X_T ~ P^u_T and weights \log\frac{dP^*}{dP^u}(X) |
| |
| log_rnd: [B] — pre-computed importance weights |
| x: [B, L] (no mask) |
| num_replicates: R, number of replicates of each row in x |
| weight_func: w(lambda) for each sample, 1/lambda by default |
| centering_strength: float, controls how much of the mean is subtracted (DMPO-style) |
| softmax_temperature: float, temperature for softmax on log_rnd (>1 smooths weights) |
| """ |
| |
| batch = x.repeat_interleave(num_replicates, dim=0) |
| batch_size = batch.shape[0] |
| |
| batch_weights = (log_rnd.detach() / softmax_temperature).softmax(dim=-1) |
| if centering: |
| batch_weights = batch_weights - centering_strength * batch_weights.mean() |
| |
| batch_weights = batch_weights.repeat_interleave(num_replicates, dim=0) |
| |
| lamda = torch.rand(batch.shape[0], device=batch.device) |
| lamda_weights = weight_func(lamda).clamp(max=1e5) |
| |
| t = lamda |
| scale_factor = self.config.interpolant.max_length |
| |
| |
| |
| |
| interpolant_sample, deleted_mask, gap_assignment = self.interpolant.sample_interpolant_plan(t, batch) |
| unmask_weight, insert_weight = self.interpolant.elbo_weight(t, batch) |
|
|
| prediction: ModelPrediction = self(interpolant_sample.xt, t) |
| |
| with torch.no_grad(): |
| xs_unmask, update_unmask_ids, xs_insert, update_ins_ids = self.one_step_sampler(interpolant_sample.xt, t, prediction) |
|
|
| |
| |
| |
| |
| planner = self.planner(xs_unmask, t) |
| remasking_conf = planner["remasking_conf"] |
|
|
| |
| |
| |
| |
| |
| if self.planner.insertion_planner: |
| insertion_conf = self.planner(xs_insert, t)["insertion_conf"] |
| else: |
| insertion_conf = None |
| |
| |
| |
| |
| |
| st = interpolant_sample.st |
| batch_reordered = torch.gather(batch, 1, st) |
| |
| |
| binary_label = (xs_unmask == batch_reordered).float() |
| |
| |
| per_token_loss = F.binary_cross_entropy_with_logits( |
| remasking_conf.squeeze(-1), |
| binary_label, |
| reduction="none" |
| ) |
| |
| per_token_loss = per_token_loss * update_unmask_ids.float() |
| |
| |
| unmask_per_sample_loss = per_token_loss.sum(dim=1) / (update_unmask_ids.sum(dim=1).float() + 1e-8) |
| |
| |
| |
| |
| |
| |
| |
| with torch.no_grad(): |
| prediction_after_insert: ModelPrediction = self(xs_insert, t) |
| |
| |
| |
| token_probs = F.softmax(prediction_after_insert.token_logits, dim=-1) |
| |
| |
| |
| |
| |
| |
| vocab_size = token_probs.shape[-1] |
| L = token_probs.shape[1] |
| max_gaps = gap_assignment.shape[1] |
| |
| |
| |
| gap_vocab_mask = torch.zeros(batch_size, max_gaps, vocab_size, device=batch.device, dtype=torch.float) |
| |
| |
| |
| tokens_expanded = batch.unsqueeze(1).expand(batch_size, max_gaps, L) |
| |
| |
| valid_mask = (gap_assignment > 0) & (tokens_expanded != self.interpolant.pad_token) |
| |
| |
| gap_vocab_mask.scatter_add_( |
| 2, |
| tokens_expanded.clamp(0, vocab_size - 1), |
| valid_mask.float() |
| ) |
| |
| |
| gap_vocab_mask = (gap_vocab_mask > 0).float() |
| |
| |
| |
| |
| |
| |
| insertion_quality_full = (token_probs * gap_vocab_mask[:, :L, :]).sum(dim=-1) |
| |
| |
| insertion_quality = insertion_quality_full * update_ins_ids.float() |
| |
| |
| if insertion_conf is not None: |
| |
| |
| |
| |
| |
| ins_per_token_loss = F.binary_cross_entropy_with_logits( |
| insertion_conf.squeeze(-1), |
| insertion_quality, |
| reduction="none" |
| ) |
| |
| |
| ins_per_token_loss = ins_per_token_loss * update_ins_ids.float() |
| |
| |
| ins_per_sample_loss = ins_per_token_loss.sum(dim=1) / (update_ins_ids.sum(dim=1).float() + 1e-8) |
| else: |
| |
| ins_per_sample_loss = torch.zeros_like(unmask_per_sample_loss) |
| |
| |
| per_sample_loss = unmask_per_sample_loss + ins_per_sample_loss |
| |
| |
| weighted_loss = per_sample_loss * batch_weights |
|
|
| |
| |
| with torch.no_grad(): |
| metrics = {} |
| sel_u = update_unmask_ids.bool() |
| if sel_u.any(): |
| u_scores = remasking_conf.squeeze(-1)[sel_u] |
| u_labels = binary_label[sel_u] |
| metrics["unmask_auc"] = _binary_auc(u_scores, u_labels) |
| metrics["unmask_label_mean"] = u_labels.mean().item() |
| metrics["unmask_conf_mean"] = torch.sigmoid(u_scores).mean().item() |
| metrics["unmask_n"] = float(sel_u.sum().item()) |
| if insertion_conf is not None: |
| sel_i = update_ins_ids.bool() |
| if sel_i.any(): |
| i_scores = insertion_conf.squeeze(-1)[sel_i] |
| i_targets = insertion_quality[sel_i] |
| i_labels = (i_targets > 0.5).float() |
| metrics["insert_auc"] = _binary_auc(i_scores, i_labels) |
| metrics["insert_target_mean"] = i_targets.mean().item() |
| metrics["insert_conf_mean"] = torch.sigmoid(i_scores).mean().item() |
| metrics["insert_n"] = float(sel_i.sum().item()) |
| self._last_planner_metrics = metrics |
|
|
| return unmask_per_sample_loss.mean(), ins_per_sample_loss.mean(), weighted_loss.mean() |
|
|