Kseniia-Kholina committed on
Commit c262d88 · verified · 1 Parent(s): a1aee96

script with esm embedding guidance

Files changed (2)
  1. diffusion_emb_guidance.py +1664 -0
  2. sample_emb_guidance.py +173 -0
diffusion_emb_guidance.py ADDED
@@ -0,0 +1,1664 @@
+ """Module for modeling discrete diffusion
+ (absorbing state or uniform) and AR
+ (a special case of absorbing state).
+ """
+ import itertools
+ import math
+ import typing
+ from dataclasses import dataclass
+
+ import hydra.utils
+ import lightning as L
+ import numpy as np
+ import omegaconf
+ import torch
+ import torch.nn.functional as F
+ import torchmetrics
+ import transformers
+ from mamba_ssm.utils.generation import InferenceParams
+ from torch import Tensor
+ from tqdm.auto import tqdm
+ import pdb
+ import gc
+
+ import classifier
+ import dataloader
+ import models
+ import noise_schedule
+ from transformers import AutoTokenizer, EsmModel
+ from faesm.esm import FAEsmForMaskedLM
+
+ LOG2 = math.log(2)
+
+
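+ # `_sample_categorical` uses the exponential-race form of the Gumbel-max
+ # trick: dividing the probabilities by i.i.d. Exp(1) noise and taking the
+ # argmax yields a sample distributed as Categorical(probs), without
+ # materializing a cumulative distribution.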
+ def _sample_categorical(categorical_probs):
+   gumbel_norm = (
+     1e-10
+     - (torch.rand_like(categorical_probs) + 1e-10).log()).to(categorical_probs.dtype)
+   return (categorical_probs / gumbel_norm).argmax(dim=-1)
+
+
+ def _unsqueeze(x, reference):
+   return x.view(
+     *x.shape,
+     *((1,) * (len(reference.shape) - len(x.shape))))
+
+
+ @dataclass
+ class Loss:
+   loss: torch.FloatTensor
+   nlls: torch.FloatTensor
+   token_mask: torch.FloatTensor
+   recon_loss: typing.Optional[torch.FloatTensor] = None
+   diffusion_loss: typing.Optional[torch.FloatTensor] = None
+
+
+ class NLL(torchmetrics.aggregation.MeanMetric):
+   pass
+
+
+ class BPD(NLL):
+   def compute(self) -> Tensor:
+     """Computes the bits per dimension.
+
+     Returns:
+       bpd
+     """
+     return self.mean_value / self.weight / LOG2
+
+
+ class Perplexity(NLL):
+   def compute(self) -> Tensor:
+     """Computes the Perplexity.
+
+     Returns:
+       Perplexity
+     """
+     return torch.exp(self.mean_value / self.weight)
+
+
+ class Diffusion(L.LightningModule):
+   def __init__(
+       self,
+       config,
+       tokenizer: transformers.PreTrainedTokenizer):
+     super().__init__()
+     self.save_hyperparameters()
+     self.config = config
+
+     self.tokenizer = tokenizer
+     self.vocab_size = tokenizer.vocab_size
+
+     self.antithetic_sampling = config.training.antithetic_sampling
+     self.importance_sampling = config.training.importance_sampling
+     self.change_of_variables = config.training.change_of_variables
+     self.noise = noise_schedule.get_noise(config, dtype=self.dtype)
+
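+     # Embed the reference binder once with a frozen fp16 ESM-2 650M model;
+     # its mean-pooled last hidden state is the target embedding used for
+     # similarity guidance during sampling. Note the hard-coded "cuda"
+     # device, and that no torch.no_grad() context is used here.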
+     esm = FAEsmForMaskedLM.from_pretrained(
+       "facebook/esm2_t33_650M_UR50D").to("cuda").eval().to(torch.float16)
+
+     original_binder_input = esm.tokenizer(
+       self.config.sampling.original_binder, return_tensors="pt")
+     original_binder_input = {
+       k: v.to('cuda') for k, v in original_binder_input.items()}
+     original_binder_outputs = esm(**original_binder_input)
+     original_binder_embedding = original_binder_outputs['last_hidden_state']
+     self.original_binder_embedding_avg = torch.mean(
+       original_binder_embedding, dim=1)
+
+     if self.config.is_vision:
+       self.mask_index = getattr(tokenizer, 'mask_token_id', -1)
+     else:
+       if (not hasattr(self.tokenizer, 'mask_token')
+           or tokenizer.mask_token is None):
+         self.mask_index = self.vocab_size
+         self.vocab_size += 1
+       else:
+         self.mask_index = tokenizer.mask_token_id
+
+     # Note: creating limiting distribution with
+     # broadcast-able batch and sequence dimensions.
+     self.parameterization = config.parameterization
+     self.diffusion = config.diffusion
+     if config.parameterization == 'ar':
+       self.limiting_distribution = None
+     else:
+       if self.diffusion == 'absorbing_state':
+         # Not needed, posterior calculated explicitly.
+         limiting_distribution = None
+       elif self.diffusion == 'uniform':
+         limiting_distribution = torch.ones(
+           (1, 1, self.vocab_size), dtype=self.dtype) / self.vocab_size
+       else:
+         raise NotImplementedError(
+           f"Diffusion type {self.diffusion} not implemented.")
+       self.register_buffer('limiting_distribution',
+                            limiting_distribution)
+
+     self.T = config.T
+     self.subs_masking = config.subs_masking
+     self.time_conditioning = config.time_conditioning
+
+     if self.config.backbone == 'dit':
+       self.backbone = models.dit.DIT(
+         self.config, vocab_size=self.vocab_size)
+     elif self.config.backbone == 'dimamba':
+       self.backbone = models.dimamba.DiMamba(
+         self.config, vocab_size=self.vocab_size,
+         pad_token_id=self.tokenizer.pad_token_id)
+     elif self.config.backbone == 'unet':
+       self.backbone = models.unet.UNet(
+         self.config, vocab_size=self.vocab_size)
+     elif self.config.backbone == 'hf_dit':
+       self.backbone = transformers.AutoModelForMaskedLM.from_pretrained(
+         config.model.pretrained_model_name_or_path, trust_remote_code=True)
+     else:
+       raise NotImplementedError(
+         f"Backbone {self.config.backbone} not implemented.")
+
+     self.lr = self.config.optim.lr
+     self.sampling_eps = config.training.sampling_eps
+
+     self.softplus = torch.nn.Softplus()
+     self.neg_infinity = -1_000_000.0
+
+     if config.training.ema > 0:
+       self.ema = models.ema.ExponentialMovingAverage(
+         itertools.chain(self.backbone.parameters(),
+                         self.noise.parameters()),
+         decay=config.training.ema)
+     else:
+       self.ema = None
+
+     # metrics are automatically reset at end of epoch
+     metrics = torchmetrics.MetricCollection({
+       'nll': NLL(),
+       'bpd': BPD(),
+       'ppl': Perplexity(),
+     })
+     metrics.set_dtype(torch.float64)
+     self.train_metrics = metrics.clone(prefix='train/')
+     self.valid_metrics = metrics.clone(prefix='val/')
+     self.test_metrics = metrics.clone(prefix='test/')
+
+     self.fast_forward_epochs = None
+     self.fast_forward_batches = None
+
+     self._validate_configuration()
+
+   def _validate_configuration(self):
+     assert not (self.change_of_variables
+                 and self.importance_sampling)
+     if self.diffusion != 'absorbing_state':
+       assert self.parameterization not in {'ar', 'subs'}
+     if self.T > 0:
+       assert self.parameterization in {'d3pm', 'subs'}
+     if self.subs_masking:
+       assert self.parameterization == 'd3pm'
+
+   def on_load_checkpoint(self, checkpoint):
+     if self.limiting_distribution is not None:
+       checkpoint['state_dict']['limiting_distribution'] = self.limiting_distribution.to(
+         list(checkpoint['state_dict'].values())[0].device)
+     if self.ema:
+       self.ema.load_state_dict(checkpoint['ema'])
+     # Copied from:
+     # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py#L41
+     self.fast_forward_epochs = checkpoint['loops'][
+       'fit_loop']['epoch_progress']['current']['completed']
+     self.fast_forward_batches = checkpoint['loops'][
+       'fit_loop']['epoch_loop.batch_progress'][
+       'current']['completed']
+
+   def on_save_checkpoint(self, checkpoint):
+     # Do not save this buffer
+     checkpoint['state_dict'].pop('limiting_distribution',
+                                  None)
+     if self.ema:
+       checkpoint['ema'] = self.ema.state_dict()
+     # Copied from:
+     # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/tasks/seq.py
+     # ['epoch_loop.batch_progress']['total']['completed'] is
+     # 1 iteration behind, so we're using the optimizer's
+     # progress.
+     checkpoint['loops']['fit_loop'][
+       'epoch_loop.batch_progress']['total'][
+       'completed'] = checkpoint['loops']['fit_loop'][
+         'epoch_loop.automatic_optimization.optim_progress'][
+         'optimizer']['step']['total'][
+         'completed'] * self.trainer.accumulate_grad_batches
+     checkpoint['loops']['fit_loop'][
+       'epoch_loop.batch_progress']['current'][
+       'completed'] = checkpoint['loops']['fit_loop'][
+         'epoch_loop.automatic_optimization.optim_progress'][
+         'optimizer']['step']['current'][
+         'completed'] * self.trainer.accumulate_grad_batches
+     # _batches_that_stepped tracks the number of global
+     # steps, not the number of local steps, so we don't
+     # multiply with self.trainer.accumulate_grad_batches
+     # here.
+     checkpoint['loops']['fit_loop'][
+       'epoch_loop.state_dict'][
+       '_batches_that_stepped'] = checkpoint['loops']['fit_loop'][
+         'epoch_loop.automatic_optimization.optim_progress'][
+         'optimizer']['step']['total']['completed']
+     if 'sampler' not in checkpoint.keys():
+       checkpoint['sampler'] = {}
+     if hasattr(self.trainer.train_dataloader.sampler,
+                'state_dict'):
+       sampler_state_dict = \
+         self.trainer.train_dataloader.sampler.state_dict()
+       checkpoint['sampler'][
+         'random_state'] = sampler_state_dict.get(
+           'random_state', None)
+     else:
+       checkpoint['sampler']['random_state'] = None
+
+   def on_train_start(self):
+     if self.ema:
+       self.ema.move_shadow_params_to_device(self.device)
+     # Adapted from:
+     # https://github.com/Dao-AILab/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
+     distributed = (
+       self.trainer._accelerator_connector.use_distributed_sampler
+       and self.trainer._accelerator_connector.is_distributed)
+     if distributed:
+       sampler_cls = dataloader.FaultTolerantDistributedSampler
+     else:
+       sampler_cls = dataloader.RandomFaultTolerantSampler
+     updated_dls = []
+     for dl in self.trainer.fit_loop._combined_loader.flattened:
+       if hasattr(dl.sampler, 'shuffle'):
+         dl_sampler = sampler_cls(
+           dl.dataset, shuffle=dl.sampler.shuffle)
+       else:
+         dl_sampler = sampler_cls(dl.dataset)
+       if (distributed
+           and self.fast_forward_epochs is not None
+           and self.fast_forward_batches is not None):
+         dl_sampler.load_state_dict({
+           'epoch': self.fast_forward_epochs,
+           'counter': (self.fast_forward_batches
+                       * self.config.loader.batch_size)})
+
+       from functools import partial
+       from dataloader import collate_fn
+       collate_partial = partial(collate_fn)
+       torch.cuda.empty_cache()
+
+       updated_dls.append(
+         torch.utils.data.DataLoader(
+           dl.dataset,
+           # batch_size=self.config.loader.batch_size,
+           num_workers=self.config.loader.num_workers,
+           pin_memory=self.config.loader.pin_memory,
+           # sampler=dl_sampler,
+           shuffle=False,
+           persistent_workers=self.config.loader.persistent_workers,
+           collate_fn=collate_partial
+         ))
+     self.trainer.fit_loop._combined_loader.flattened = updated_dls
+
+   def configure_optimizers(self):
+     # TODO(yair): Lightning currently giving this warning when using `fp16`:
+     # "Detected call of `lr_scheduler.step()` before `optimizer.step()`."
+     # Not clear if this is a problem or not.
+     # See: https://github.com/Lightning-AI/pytorch-lightning/issues/5558
+     optimizer = torch.optim.AdamW(
+       itertools.chain(self.backbone.parameters(),
+                       self.noise.parameters()),
+       lr=self.config.optim.lr,
+       betas=(self.config.optim.beta1,
+              self.config.optim.beta2),
+       eps=self.config.optim.eps,
+       weight_decay=self.config.optim.weight_decay)
+
+     scheduler = hydra.utils.instantiate(
+       self.config.lr_scheduler, optimizer=optimizer)
+     scheduler_dict = {
+       'scheduler': scheduler,
+       'interval': 'step',
+       'monitor': 'val/loss',
+       'name': 'trainer/lr',
+     }
+     return [optimizer], [scheduler_dict]
+
+   def optimizer_step(self, *args, **kwargs):
+     super().optimizer_step(*args, **kwargs)
+     if self.ema:
+       self.ema.update(itertools.chain(
+         self.backbone.parameters(),
+         self.noise.parameters()))
+
+   def _subs_parameterization(self, logits, xt):
+     # "Zero Masking Prob":
+     # log prob at the mask index = - infinity
+     logits[..., self.mask_index] += self.neg_infinity
+
+     # "Copy over":
+     # Apply updates directly in the logits matrix.
+     # For the logits of the unmasked tokens, set all values
+     # to -infinity except for the indices corresponding to
+     # the unmasked tokens.
+     unmasked_indices = (xt != self.mask_index)
+     logits[unmasked_indices] = self.neg_infinity
+     logits[unmasked_indices, xt[unmasked_indices]] = 0
+
+     # Normalize the logits such that x.exp() is
+     # a probability distribution over vocab_size.
+     return logits.log_softmax(dim=-1)
+
+   def _process_sigma(self, sigma):
+     if sigma is None:
+       assert self.parameterization == 'ar'
+       return sigma
+     if sigma.ndim > 1:
+       sigma = sigma.squeeze(-1)
+     if not self.time_conditioning:
+       sigma = torch.zeros_like(sigma)
+     assert sigma.ndim == 1, sigma.shape
+     return sigma
+
+   def forward(self, x, sigma, cond=None, x_emb=None, **kwargs):
+     """Returns log_probs / logits."""
+     sigma = self._process_sigma(sigma)
+
+     with torch.cuda.amp.autocast(dtype=torch.float32):
+       logits = self.backbone(x, sigma, cond, x_emb=x_emb, **kwargs)
+
+     if self.parameterization == 'subs':
+       # returns log_probs
+       return self._subs_parameterization(
+         logits=logits, xt=x)
+     if self.parameterization in {'ar', 'd3pm'}:
+       # returns log_probs
+       if self.subs_masking:  # Can use "zero masking prob"
+         logits[:, :, self.mask_index] += self.neg_infinity
+       return logits.log_softmax(dim=-1)
+     return logits
+
+   def _compute_posterior(self, x, xt, alpha_s, alpha_t):
+     """Computes the posterior / approximate posterior.
+
+     Args:
+       x: Either clean input `x0` (one-hot),
+         or model's predicted `x_theta` of shape (B, L, V).
+       xt: The noisy latent (as indices) of shape (B, L).
+       alpha_s: Noise level at s of shape (B, [L | 1], 1).
+       alpha_t: Noise level at t of shape (B, [L | 1], 1).
+
+     Returns:
+       Posterior / approximate posterior of shape (B, L, V).
+     """
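+     # For uniform diffusion the posterior is available in closed form.
+     # The numerator below mixes four broadcastable (B, L, V) terms: mass
+     # where x and xt agree, mass for copying xt, mass for copying x, and
+     # mass for jumping to the uniform limiting distribution; the
+     # denominator is the corresponding marginal q(xt | x).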
+     alpha_ts = alpha_t / alpha_s
+     d_alpha = alpha_s - alpha_t
+     xt_one_hot = F.one_hot(xt, self.vocab_size)
+     if self.diffusion == 'uniform':
+       return (
+         (alpha_t * self.vocab_size * x * xt_one_hot +
+          (alpha_ts - alpha_t) * xt_one_hot +
+          d_alpha * x +
+          (1 - alpha_ts) * (1 - alpha_s) * self.limiting_distribution)
+         /
+         (alpha_t * self.vocab_size * torch.gather(x, -1, xt[..., None]) +
+          (1 - alpha_t))
+       )
+     raise NotImplementedError(
+       f"Diffusion type {self.diffusion} not implemented.")
+
+   def _d3pm_loss(self, model_output, xt, x0, t):
+     assert self.config.noise.type == 'loglinear', (
+       'D3PM loss only implemented for log-linear noise.')
+     dt = 1 / self.T
+
+     if torch.is_tensor(t):
+       t = t[:, None]
+       assert t.ndim == 2
+       t = t.clamp(0., 1. - 1e-4)
+     alpha_t = 1 - t + torch.zeros_like(xt)
+     alpha_s = 1 - (t - dt) + torch.zeros_like(xt)
+
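+     # Discrete-time bound for the absorbing state: both log-ratio terms
+     # compare the model's mass on the mask token against its mass on the
+     # true token x0 at adjacent noise levels; multiplying by (xt == mask)
+     # restricts the loss to currently masked positions, and the factor of
+     # T on the return scales the per-step bound to the full objective.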
+     if self.diffusion == 'absorbing_state':
+       log_x_theta_at_x0 = torch.gather(
+         model_output, -1, x0[:, :, None]).squeeze(-1)
+       log_x_theta_at_m = model_output[:, :, self.mask_index]
+       x_theta_at_m = log_x_theta_at_m.exp()
+
+       term_1_coef = dt / t
+       term_1_log_nr = torch.log(alpha_t * x_theta_at_m / t + 1)
+       term_1_log_dr = log_x_theta_at_x0
+
+       term_2_coef = 1 - dt / t
+       term_2_log_nr = term_1_log_nr
+       term_2_log_dr = torch.log(alpha_s * x_theta_at_m / (t - dt) + 1)
+
+       L_vb_masked = (
+         term_1_coef * (term_1_log_nr - term_1_log_dr)
+         + term_2_coef * (term_2_log_nr - term_2_log_dr))
+
+       L_vb = L_vb_masked * (xt == self.mask_index)
+     elif self.diffusion == 'uniform':
+       posterior = self._compute_posterior(
+         x=F.one_hot(x0, num_classes=self.vocab_size).to(self.dtype),
+         xt=xt,
+         alpha_s=alpha_s[..., None],
+         alpha_t=alpha_t[..., None])
+       posterior_pred = self._compute_posterior(
+         x=model_output.exp(),
+         xt=xt,
+         alpha_s=alpha_s[..., None],
+         alpha_t=alpha_t[..., None])
+       L_vb = (
+         posterior * (torch.log(posterior + 1e-12) - torch.log(posterior_pred))
+       ).sum(dim=-1)
+     else:
+       raise NotImplementedError(
+         f"Diffusion type {self.diffusion} not implemented for D3PM.")
+     return self.T * L_vb
+
+   def _reconstruction_loss(self, x0, cond=None):
+     # For D3PM parameterization
+     assert self.config.noise.type == 'loglinear', (
+       'Reconstruction loss only implemented for log-linear '
+       'noise.')
+     t0 = torch.zeros(x0.shape[0], dtype=self.dtype,
+                      device=self.device)
+     time_conditioning = self.noise(t0)[0][:, None]
+     model_output_t0 = self.forward(x0, time_conditioning,
+                                    cond=cond)
+     return - torch.gather(input=model_output_t0,
+                           dim=-1,
+                           index=x0[:, :, None]).squeeze(-1)
+
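+   # With antithetic sampling enabled, `_sample_t` stratifies the batch:
+   # each element draws its timestep from a different 1/n-wide slice of
+   # [0, 1), which lowers the variance of the Monte-Carlo loss estimate
+   # relative to fully independent draws.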
+   def _sample_t(self, n):
+     _eps_t = torch.rand(n, device=self.device)
+     if self.antithetic_sampling:
+       offset = torch.arange(n, device=self.device) / n
+       _eps_t = (_eps_t / n + offset) % 1
+     t = (1 - self.sampling_eps) * _eps_t + self.sampling_eps
+     if self.importance_sampling:
+       return self.noise.importance_sampling_transformation(
+         t)
+     return t
+
+   def _q_xt(self, x, move_chance):
+     """Computes the noisy sample xt.
+
+     Args:
+       x: int torch.Tensor with shape (batch_size,
+         diffusion_model_input_length), input.
+       move_chance: float torch.Tensor with shape
+         (batch_size, 1).
+     """
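+     # Each position is independently corrupted with probability
+     # move_chance: to the mask token (absorbing state), to a uniformly
+     # random token (uniform), or to a draw from the data-marginal prior
+     # (uniform_data_marginals).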
+     move_indices = torch.rand(
+       *x.shape, device=x.device) < move_chance
+     if self.diffusion == 'absorbing_state':
+       return torch.where(move_indices, self.mask_index, x)
+     elif self.diffusion == 'uniform':
+       uniform_tensor = torch.randint(
+         0, self.vocab_size, x.shape, device=x.device)
+       return torch.where(move_indices, uniform_tensor, x)
+     elif self.diffusion == 'uniform_data_marginals':
+       return torch.where(
+         move_indices,
+         self._sample_prior(*x.shape),
+         x)
+     raise NotImplementedError(
+       f"Diffusion type {self.diffusion} not implemented.")
+
+   def _forward_pass_diffusion(self, x0, cond=None):
+     t = self._sample_t(x0.shape[0])
+     if self.T > 0:
+       t = (t * self.T).to(torch.int)
+       t = t / self.T
+       # t \in {1/T, 2/T, ..., 1}
+       t += (1 / self.T)
+
+     if self.change_of_variables:
+       time_conditioning = t[:, None]
+       f_T = torch.log1p(- torch.exp(- self.noise.sigma_max))
+       f_0 = torch.log1p(- torch.exp(- self.noise.sigma_min))
+       move_chance = torch.exp(f_0 + t * (f_T - f_0))
+       move_chance = move_chance[:, None]
+       sigma, dsigma = None, None
+     else:
+       sigma, dsigma = self.noise(t)
+       time_conditioning = sigma[:, None]
+       move_chance = 1 - torch.exp(-sigma[:, None])
+
+     xt = self._q_xt(x0, move_chance)
+     model_output = self.forward(xt, time_conditioning,
+                                 cond=cond)
+
+     # Discrete (finite T) time
+     if self.T > 0:
+       diffusion_loss = self._d3pm_loss(
+         model_output=model_output, xt=xt, x0=x0, t=t)
+       if self.parameterization == 'd3pm':
+         reconstruction_loss = self._reconstruction_loss(
+           x0, cond=cond)
+         if self.training and self.config.training.use_simple_ce_loss:
+           loss = -torch.gather(
+             input=model_output,
+             dim=-1,
+             index=x0[:, :, None]).squeeze(-1)
+         else:
+           loss = reconstruction_loss + diffusion_loss
+         return {
+           'recon_loss': reconstruction_loss,
+           'diffusion_loss': diffusion_loss,
+           'loss': loss}
+       elif self.parameterization == 'subs':
+         if self.training and self.config.training.use_simple_ce_loss:
+           loss = -torch.gather(
+             input=model_output,
+             dim=-1,
+             index=x0[:, :, None]).squeeze(-1)
+         else:
+           loss = diffusion_loss
+         return {'diffusion_loss': diffusion_loss, 'loss': loss}
+       else:
+         raise ValueError(
+           f"Invalid parameterization: {self.parameterization} for T > 0.")
+
+     # Continuous (T --> infty) time
+     if self.diffusion == 'absorbing_state':
+       # SUBS parameterization, continuous time.
+       log_p_theta = torch.gather(
+         input=model_output,
+         dim=-1,
+         index=x0[:, :, None]).squeeze(-1)
+
+       if self.change_of_variables or self.importance_sampling:
+         if self.training and self.config.training.use_simple_ce_loss:
+           return {
+             'diffusion_loss': log_p_theta * torch.log1p(
+               -torch.exp(- self.noise.sigma_min)),
+             'loss': -log_p_theta
+           }
+         return log_p_theta * torch.log1p(-torch.exp(- self.noise.sigma_min))
+
+       if self.training and self.config.training.use_simple_ce_loss:
+         return {
+           # Mirrors the continuous-time diffusion loss returned below.
+           'diffusion_loss': -log_p_theta * (dsigma / torch.expm1(sigma))[:, None],
+           # Simple cross-entropy, matching the change-of-variables branch.
+           'loss': -log_p_theta
+         }
+       return - log_p_theta * (dsigma / torch.expm1(sigma))[:, None]
+
+     elif self.diffusion == 'uniform':
+       assert self.config.noise.type == 'loglinear', (
+         'Continuous time uniform diffusion only implemented'
+         ' for log-linear noise.')
+       # TODO: Currently α_t' and α_t are hardcoded to a
+       # log-linear noise.
+       # Make generic (as above, for absorbing state):
+       # alpha_t_prime = -dsigma * (-sigma).exp()
+       # alpha_t = (-sigma).exp()
+       alpha_t_prime = -1.
+       alpha_t = 1. - t[..., None, None]  # B, 1, 1
+
+       # x_bar = N * α_t * x + 1 - α_t ; B, L, V
+       x_bar = self.vocab_size * alpha_t * F.one_hot(
+         x0, self.vocab_size).float() + 1 - alpha_t
+       x_bar_theta = self.vocab_size * alpha_t * model_output.exp() + 1 - alpha_t
+
+       # α_t' / (N*α_t)
+       coeff = alpha_t_prime / (self.vocab_size * alpha_t)  # B, 1, 1
+
+       # Term 1: indices where z_t = 1
+       x_bar_zt = torch.gather(x_bar, -1, xt[..., None])  # B, L, 1
+       x_bar_theta_zt = torch.gather(x_bar_theta, -1, xt[..., None])  # B, L, 1
+       term1 = ((self.vocab_size / x_bar_zt) - (self.vocab_size / x_bar_theta_zt))  # B, L, 1
+
+       # Term 2: indices where z_t = 0
+       term2 = (  # B, L, V before summing --> B, L, 1 after
+         (x_bar / x_bar_zt) *
+         (
+           x_bar_theta_zt.log() - x_bar_theta.log() +
+           x_bar.log() - x_bar_zt.log()
+         )
+       )
+       term2 = term2.sum(dim=-1, keepdim=True)  # B, L, 1
+
+       diffusion_loss = (coeff * (term1 - term2)).squeeze()  # B, L
+       reconstruction_loss = self._reconstruction_loss(
+         x0, cond=cond)
+       if self.training and self.config.training.use_simple_ce_loss:
+         return {
+           'recon_loss': reconstruction_loss,
+           'diffusion_loss': diffusion_loss,
+           'loss': -torch.gather(
+             input=model_output,
+             dim=-1,
+             index=x0[:, :, None]).squeeze(-1)
+         }
+       return {
+         'recon_loss': reconstruction_loss,
+         'diffusion_loss': diffusion_loss,
+         'loss': diffusion_loss if getattr(self.config, 'zero_recon_loss', False)
+         else diffusion_loss + reconstruction_loss
+       }
+     else:
+       raise NotImplementedError(
+         f"Diffusion type {self.diffusion} not "
+         "implemented for continuous time case.")
+
+   def _maybe_sub_sample(self, x0, attention_mask):
+     seqlen = x0.shape[1]
+     # if seqlen > self.config.model.length:
+     #   assert seqlen == 2 * self.config.model.length
+     #   # cropping is necessary for the text8-crop dataset;
+     #   # try the same starting point for now
+     #   start = np.random.choice(self.config.model.length)
+     #   end = start + self.config.model.length
+     #   input_tokens = x0[:, start: end]
+     #   output_tokens = x0[:, start + 1: end + 1]
+     #   new_attention_mask = attention_mask[:, start: end]
+
+     #   # Helps with validation PPL, since the val
+     #   # examples will all start and end with BOS/EOS
+     #   input_tokens[:, 0] = self.tokenizer.bos_token_id
+     #   output_tokens[:, -1] = self.tokenizer.eos_token_id
+     # elif self.parameterization == 'ar':
+     #   input_tokens = x0[:, :-1]
+     #   output_tokens = x0[:, 1:]
+     #   new_attention_mask = attention_mask[:, 1:]
+     # else:
+     #   input_tokens = x0
+     #   output_tokens = None
+     #   new_attention_mask = attention_mask
+
+     input_tokens = x0
+     output_tokens = None
+     new_attention_mask = attention_mask
+     return input_tokens, output_tokens, new_attention_mask
+
+   def _loss(self, x0, attention_mask, cond=None):
+     (input_tokens, output_tokens,
+      attention_mask) = self._maybe_sub_sample(
+        x0, attention_mask)
+
+     recon_loss, diffusion_loss = None, None
+
+     if (cond is not None and self.training
+         and self.config.training.guidance is not None
+         and self.config.training.guidance.cond_dropout > 0):
+       # Randomly mask out conditioning for classifier-free
+       # guidance training.
+       p = torch.bernoulli(
+         torch.ones_like(cond) *
+         self.config.training.guidance.cond_dropout).to(torch.bool)
+       # Use num_classes index as conditioning mask_token_id
+       cond[p] = self.config.data.num_classes
+
+     if self.parameterization == 'ar':
+       logprobs = self.forward(
+         input_tokens, sigma=None, cond=cond)
+       loss = - logprobs.gather(
+         -1, output_tokens[:, :, None])[:, :, 0]
+     else:
+       loss = self._forward_pass_diffusion(input_tokens,
+                                           cond=cond)
+     if isinstance(loss, dict):
+       recon_loss = loss['recon_loss']
+       diffusion_loss = loss['diffusion_loss']
+       loss = loss['loss']
+
+     nlls = loss * attention_mask
+     count = attention_mask.sum()
+
+     if (self.config.training.compute_loss_on_pad_tokens
+         and self.training):
+       token_nll = loss.mean()
+     else:
+       batch_nll = nlls.sum()
+       token_nll = batch_nll / count
+
+     if recon_loss is not None and diffusion_loss is not None:
+       with torch.no_grad():
+         recon_loss_batch = (recon_loss * attention_mask).sum() / count
+         diffusion_loss_batch = (diffusion_loss * attention_mask).sum() / count
+       return Loss(loss=token_nll,
+                   nlls=nlls,
+                   token_mask=attention_mask,
+                   recon_loss=recon_loss_batch,
+                   diffusion_loss=diffusion_loss_batch)
+     return Loss(loss=token_nll,
+                 nlls=nlls,
+                 token_mask=attention_mask)
+
+   def _compute_loss(self, batch, prefix):
+     if 'attention_mask' in batch:
+       attention_mask = batch['attention_mask']
+     else:
+       attention_mask = None
+     cond = None
+     if (self.config.training.guidance is not None or  # Training for / using CFG
+         (hasattr(self.config, 'guidance')
+          and self.config.guidance is not None
+          and self.config.guidance.method == 'cfg')):
+       if self.config.data.label_col in batch:
+         cond = batch[self.config.data.label_col]
+       elif f"{self.config.data.label_col}_threshold" in batch:
+         cond = batch[f"{self.config.data.label_col}_threshold"]
+       else:
+         raise RuntimeError(
+           f"Conditioning {self.config.data.label_col}"
+           f" not found in batch.")
+     losses = self._loss(batch['input_ids'], attention_mask,
+                         cond=cond)
+
+     if prefix == 'train':
+       self.train_metrics.update(losses.nlls,
+                                 losses.token_mask)
+       metrics = self.train_metrics
+     elif prefix == 'val':
+       self.valid_metrics.update(losses.nlls,
+                                 losses.token_mask)
+       metrics = self.valid_metrics
+     elif prefix == 'test':
+       self.test_metrics.update(losses.nlls,
+                                losses.token_mask)
+       metrics = self.test_metrics
+     else:
+       raise ValueError(f"Invalid prefix: {prefix}")
+
+     self.log_dict(metrics,
+                   on_step=False,
+                   on_epoch=True,
+                   sync_dist=True)
+     return losses
+
+   def training_step(self, batch, batch_idx):
+     losses = self._compute_loss(batch, prefix='train')
+     self.log(name='trainer/loss',
+              value=losses.loss.item(),
+              on_step=True,
+              on_epoch=True,
+              sync_dist=True,
+              prog_bar=True)
+     if losses.recon_loss is not None:
+       self.log(name='trainer/recon_loss',
+                value=losses.recon_loss.item(),
+                on_step=True,
+                on_epoch=True,
+                sync_dist=True,
+                prog_bar=False)
+       self.log(name='trainer/diffusion_loss',
+                value=losses.diffusion_loss.item(),
+                on_step=True,
+                on_epoch=True,
+                sync_dist=True,
+                prog_bar=False)
+     self.log(name='lr',
+              value=self.trainer.optimizers[0].param_groups[0]['lr'],
+              on_step=True,
+              on_epoch=False,
+              sync_dist=True,
+              prog_bar=True, logger=False)
+     return losses.loss
+
+   def validation_step(self, batch, batch_idx):
+     losses = self._compute_loss(batch, prefix='val')
+     self.log(name='trainer/val_loss',
+              value=losses.loss.item(),
+              on_step=True,
+              on_epoch=True,
+              prog_bar=True,
+              sync_dist=True)
+     return losses.loss
+
+   def load_ema_params(self):
+     if self.ema:
+       self.ema.store(itertools.chain(
+         self.backbone.parameters(),
+         self.noise.parameters()))
+       self.ema.copy_to(itertools.chain(
+         self.backbone.parameters(),
+         self.noise.parameters()))
+
+   def _restore_non_ema_params(self):
+     if self.ema:
+       self.ema.restore(itertools.chain(
+         self.backbone.parameters(),
+         self.noise.parameters()))
+
+   def on_validation_epoch_start(self):
+     # pdb.set_trace()
+     gc.collect()
+     torch.cuda.empty_cache()
+     self.load_ema_params()
+     assert self.valid_metrics.nll.mean_value == 0
+     assert self.valid_metrics.nll.weight == 0
+
+   def on_validation_epoch_end(self):
+     # pdb.set_trace()
+     # self._restore_non_ema_params()
+     # if (not self.trainer.sanity_checking
+     #     and self.config.eval.generate_samples
+     #     and self.trainer.global_rank == 0):
+     #   self.config.sampling.batch_size = 1
+     #   if self.config.is_vision:
+     #     samples = []
+     #     if self.config.training.guidance is not None:
+     #       # Generate one image per class (up to 10 images)
+     #       guidance = {
+     #         'method': 'cfg', 'condition': 0, 'gamma': 1.0}
+     #       omegaconf.OmegaConf.update(
+     #         self.config, key='guidance', value=guidance,
+     #         force_add=True)
+     #       for i in range(max(self.config.data.num_classes, 10)):
+     #         self.config.guidance.condition = i
+     #         samples.append(self.sample())
+     #     else:
+     #       # Generate ten images
+     #       for i in range(10):
+     #         samples.append(self.sample())
+     #     image_samples = self.tokenizer.batch_decode(
+     #       torch.concat(samples, dim=0))
+     #     if hasattr(self.trainer.logger, 'log_image'):
+     #       self.trainer.logger.log_image(
+     #         key=f"samples@global_step{self.global_step}",
+     #         caption=[str(i) for i in range(len(samples))],
+     #         images=[s for s in image_samples.float()])
+     #   else:
+     #     if self.config.training.guidance is not None:
+     #       guidance = {
+     #         'method': 'cfg', 'condition': 0, 'gamma': 1.0}
+     #       omegaconf.OmegaConf.update(
+     #         self.config, key='guidance', value=guidance,
+     #         force_add=True)
+     #       for i in range(self.config.data.num_classes):
+     #         self.config.guidance.condition = i
+     #         samples = self.sample()
+     #         decoded_samples = self.tokenizer.batch_decode(
+     #           samples)
+     #         if hasattr(self.trainer.logger, 'log_table'):
+     #           # Log some generated samples
+     #           self.trainer.logger.log_table(
+     #             key=f"samples@global_step{self.global_step}_class-{i}",
+     #             columns=['Generated Samples'],
+     #             data=[decoded_samples])
+     #     else:
+     #       self.config.sampling.batch_size = 2
+     #       samples = self.sample()
+     #       decoded_samples = self.tokenizer.batch_decode(
+     #         samples)
+     #       if hasattr(self.trainer.logger, 'log_table'):
+     #         # Log some generated samples
+     #         self.trainer.logger.log_table(
+     #           key=f"samples@global_step{self.global_step}",
+     #           columns=['Generated Samples'],
+     #           data=[[s] for s in decoded_samples])
+     gc.collect()
+     torch.cuda.empty_cache()
+     self._restore_non_ema_params()
+
+   def _sample_prior(self, *batch_dims):
+     if self.diffusion == 'absorbing_state':
+       return self.mask_index * torch.ones(
+         *batch_dims, dtype=torch.int64, device=self.device)
+     elif self.diffusion == 'uniform':
+       return torch.randint(
+         0, self.vocab_size, batch_dims, dtype=torch.int64,
+         device=self.device)
+     elif self.diffusion == 'uniform_data_marginals':
+       if self.limiting_distribution.squeeze().ndim == 2:
+         batch_dims = (batch_dims[0],)
+       return torch.distributions.Categorical(
+         self.limiting_distribution.squeeze()).sample(
+           sample_shape=torch.Size(batch_dims))
+     raise NotImplementedError(
+       f'Diffusion type {self.diffusion} not '
+       'implemented.')
+
+   def sample(
+       self,
+       eps=1e-5,  # Note: differs from self.config.training.sampling_eps
+       target_sequence: torch.Tensor = None,
+       target_motifs: torch.Tensor = None,
+       classifier_model=None):
+     """Generate samples from the (EMA) model.
+
+     Supports both AR and diffusion sampling, with:
+       - standard decoding,
+       - classifier-free guidance (CFG),
+       - classifier-based guidance (CBG / FUDGE),
+       - NOS / PPLM.
+     """
+     # WARNING: Lightning auto-casting is not working in this method.
+     if not self.config.eval.disable_ema:
+       self.load_ema_params()
+     if getattr(self.config, 'guidance', None) is not None:
+       if self.config.guidance.method == 'cfg':
+         cond = (torch.ones(self.config.sampling.batch_size, device=self.device) *
+                 self.config.guidance.condition).to(torch.long)
+       else:
+         cond = None
+       if ((self.parameterization == 'ar' and self.config.guidance.method in {'fudge', 'pplm'})
+           or self.config.guidance.method in {'cbg', 'nos'}):
+         if classifier_model is None:
+           classifier_model = classifier.Classifier.load_from_checkpoint(
+             self.config.guidance.classifier_checkpoint_path,
+             tokenizer=self.tokenizer,
+             config=self.config, logger=False)
+         classifier_model = classifier_model.to(self.device)
+         classifier_model.eval()
+       else:
+         classifier_model = None
+     else:
+       classifier_model, cond = None, None
+
+     if self.parameterization == 'ar':
+       samples = self._ar_sample(
+         classifier_model=classifier_model, cond=cond)
+     else:  # Diffusion sampling, current parameterization: d3pm
+       samples = self._diffusion_sample(
+         classifier_model=classifier_model, cond=cond,
+         eps=eps,
+         target_sequence=target_sequence,
+         target_motifs=target_motifs)
+     if not self.config.eval.disable_ema:
+       self._restore_non_ema_params()
+
+     # return orig binders along with this
+     return samples
+
+   @torch.no_grad()
+   def _ar_sample(
+       self,
+       classifier_model: typing.Optional[classifier.Classifier] = None,
+       cond: typing.Optional[torch.Tensor] = None,
+   ):
+     # precompute token buffer
+     num_pred_tokens = self.config.model.length - 1
+     x = torch.zeros(
+       (self.config.sampling.batch_size, num_pred_tokens + 1),
+       dtype=torch.long,
+       device=self.device)
+     x[:, 0] = self.tokenizer.bos_token_id
+     # precompute Gumbel sampling noise
+     if (getattr(self.config, 'guidance', None) is not None
+         and self.config.guidance.method == 'fudge'):
+       noise = torch.distributions.Gumbel(0, 1).sample(
+         (self.config.sampling.batch_size,  # type: ignore
+          num_pred_tokens,
+          self.config.guidance.topk)).to(self.device)
+     else:
+       noise = torch.distributions.Gumbel(0, 1).sample(
+         (self.config.sampling.batch_size,  # type: ignore
+          num_pred_tokens,
+          self.vocab_size)).to(self.device)
+     if self.config.sampling.use_float64:
+       noise = noise.to(torch.float64)
+     pbar = tqdm(range(num_pred_tokens), desc='AR Sampling',
+                 leave=False)
+     inference_params = InferenceParams(
+       max_seqlen=num_pred_tokens,
+       max_batch_size=x.shape[0],
+       seqlen_offset=1)
+     # For cfg we do 2 forward passes, one for conditional
+     # model and one unconditional, so we need 2 copies of
+     # inference_params.
+     uncond_inference_params = InferenceParams(
+       max_seqlen=num_pred_tokens,
+       max_batch_size=x.shape[0],
+       seqlen_offset=1)
+     for i in pbar:
+       if getattr(self.config, 'guidance', None) is None:
+         if self.config.backbone == 'dimamba':
+           log_probs = self.forward(
+             x[:, i:i + 1], None, cond=None,
+             inference_params=inference_params)
+         else:
+           log_probs = self.forward(x[:, :i + 1],
+                                    None, cond=None)
+         if self.config.sampling.use_float64:
+           log_probs = log_probs.to(torch.float64)
+         next_log_probs = log_probs[:, -1]
+         y = (next_log_probs + noise[:, i]).argmax(-1)
+       else:
+         if self.config.guidance.method == 'cfg':
+           if self.config.backbone == 'dimamba':
+             next_log_probs = self._ar_cfg_denoise(
+               cond=cond,
+               gamma=self.config.guidance.gamma,
+               x=x[:, i:i + 1],
+               i=i,
+               inference_params=(inference_params, uncond_inference_params))
+           else:
+             next_log_probs = self._ar_cfg_denoise(
+               cond=cond,
+               gamma=self.config.guidance.gamma,
+               x=x,
+               i=i)
+           y = (next_log_probs + noise[:, i]).argmax(-1)
+         elif self.config.guidance.method == 'fudge':
+           if self.config.backbone == 'dimamba':
+             next_log_probs, top_indices = self._ar_fudge_denoise(
+               classifier_model=classifier_model,
+               guidance_cond=self.config.guidance.condition,
+               topk=self.config.guidance.topk,
+               gamma=self.config.guidance.gamma,
+               x=x[:, i:i + 1],
+               i=i,
+               inference_params=inference_params)
+           else:
+             next_log_probs, top_indices = self._ar_fudge_denoise(
+               classifier_model=classifier_model,
+               guidance_cond=self.config.guidance.condition,
+               topk=self.config.guidance.topk,
+               gamma=self.config.guidance.gamma,
+               x=x,
+               i=i)
+           y = torch.gather(
+             top_indices,
+             1,
+             (next_log_probs + noise[:, i]).argmax(-1).unsqueeze(1)
+           ).squeeze(1)
+         elif self.config.guidance.method == 'pplm':
+           raise NotImplementedError
+         else:
+           raise NotImplementedError(
+             f"Guidance method {self.config.guidance.method} not implemented.")
+       pbar.set_postfix(
+         prob_check=(next_log_probs.exp().sum() / x.shape[0]).item(),
+         nan_check=bool(next_log_probs.isnan().sum() > 0))
+       x[:, i + 1] = y
+     return x
+
+   def _ar_cfg_denoise(
+       self,
+       cond: torch.Tensor,
+       gamma: float,
+       x: torch.Tensor,
+       i: int,
+       **kwargs
+   ) -> torch.Tensor:
+     if self.config.guidance.gamma == 0.0:  # Sample unconditionally
+       mask_cond = (torch.ones_like(cond) *
+                    self.config.data.num_classes)
+       if self.config.backbone == 'dimamba':
+         inference_params = kwargs.pop('inference_params')
+         log_probs = self.forward(
+           x[:, :i + 1], None, cond=mask_cond,
+           inference_params=inference_params[1])
+       else:
+         log_probs = self.forward(
+           x[:, :i + 1], None, cond=mask_cond, **kwargs)
+     elif gamma == 1.0:  # Sample conditionally
+       if self.config.backbone == 'dimamba':
+         inference_params = kwargs.pop('inference_params')
+         log_probs = self.forward(
+           x[:, :i + 1], None, cond=cond,
+           inference_params=inference_params[0])
+       else:
+         log_probs = self.forward(
+           x[:, :i + 1], None, cond=cond, **kwargs)
+     else:  # Sample from tempered distribution
+       mask_cond = (torch.ones_like(cond) *
+                    self.config.data.num_classes)
+       if self.config.backbone == 'dimamba':
+         inference_params = kwargs.pop('inference_params')
+         log_probs_cond = self.forward(
+           x[:, :i + 1], None, cond=cond,
+           inference_params=inference_params[0])
+         log_probs_uncond = self.forward(
+           x[:, :i + 1], None, cond=mask_cond,
+           inference_params=inference_params[1])
+       else:
+         log_probs_cond = self.forward(
+           x[:, :i + 1], None, cond=cond, **kwargs)
+         log_probs_uncond = self.forward(
+           x[:, :i + 1], None, cond=mask_cond, **kwargs)
+
+       log_probs = gamma * log_probs_cond + (1 - gamma) * log_probs_uncond
+       # Gamma > 1.0 causes instability for Mamba, re-normalizing
+       log_probs = log_probs.log_softmax(dim=-1)
+     return log_probs[:, -1]
+
+   def _ar_fudge_denoise(
+       self,
+       classifier_model: classifier.Classifier,
+       guidance_cond: int,
+       topk: int,
+       gamma: float,
+       x: torch.Tensor,
+       i: int,
+       **kwargs
+   ) -> typing.Tuple[torch.Tensor, torch.LongTensor]:
+     log_probs = self.forward(
+       x[:, :i + 1], None, cond=None, **kwargs)
+     next_log_probs = log_probs[:, -1]
+     top_logits, top_indices = next_log_probs.topk(topk, dim=-1)
+     t_candidates = torch.cat(
+       [x[:, :i + 1].unsqueeze(1).expand(-1, topk, -1),
+        top_indices.unsqueeze(2)],
+       dim=2).view(-1, i + 2)  # (B * K), L
+
+     t = torch.zeros(t_candidates.shape[0],
+                     device=self.device)
+     sigma, dsigma = self.noise(t)
+     time_conditioning = sigma[:, None]
+
+     classifier_log_prob = classifier_model.get_log_probs(
+       t_candidates, time_conditioning)
+     classifier_log_prob = classifier_log_prob[:, i + 1, :].view(
+       x.shape[0], topk, -1)[..., guidance_cond]  # (batch, topk)
+     next_log_probs = (top_logits + gamma * classifier_log_prob).log_softmax(dim=-1)
+     return next_log_probs, top_indices
+
+   def _ar_pplm_denoise(
+       self,
+       classifier_model: classifier.Classifier,
+       guidance_cond: int,
+       num_ppl_steps: int,
+       pplm_step_size: float,
+       pplm_stability_coef: float,
+       x: torch.Tensor,
+       i: int,
+   ):
+     raise NotImplementedError
+
+   @torch.no_grad()
+   def _diffusion_sample(
+       self,
+       classifier_model: typing.Optional[classifier.Classifier] = None,
+       cond: typing.Optional[torch.Tensor] = None,
+       eps: float = 1e-5,  # Note: differs from self.config.training.sampling_eps
+       target_sequence: torch.Tensor = None,
+       target_motifs: torch.Tensor = None,
+   ):
+     xt = self._sample_prior(
+       self.config.sampling.batch_size,
+       self.config.model.length
+     ).to(self.device)
+
+     timesteps = torch.linspace(
+       1, eps, self.config.sampling.steps + 1, device=self.device)
+     dt = (1 - eps) / self.config.sampling.steps
+     pbar = tqdm(range(self.config.sampling.steps),
+                 desc='Sampling',
+                 leave=False)
+     NFEs = 0
+     cache = None
+
+     for i in pbar:
+       t = timesteps[i]
+       if self.T > 0:  # t in {1/T,..., 1}, to match training
+         t = (t * self.T).to(torch.int)
+         t = t / self.T
+         t += (1 / self.T)
+       t = t * torch.ones(xt.shape[0], 1, device=self.device)
+       if cache is None:
+         NFEs += 1
+       sigma_t, _ = self.noise(t)
+       sigma_s, _ = self.noise(t - dt)
+       if sigma_t.ndim > 1:
+         sigma_t = sigma_t.squeeze(-1)
+       if sigma_s.ndim > 1:
+         sigma_s = sigma_s.squeeze(-1)
+       assert sigma_t.ndim == 1, sigma_t.shape
+       assert sigma_s.ndim == 1, sigma_s.shape
+       move_chance_t = 1 - torch.exp(-sigma_t)
+       move_chance_s = 1 - torch.exp(-sigma_s)
+       move_chance_t = move_chance_t[:, None, None]
+       move_chance_s = move_chance_s[:, None, None]
+       assert move_chance_t.ndim == 3, move_chance_t.shape
+
+       if getattr(self.config, 'guidance', None) is None:
+         xs, q_xs, cache = self._ddpm_denoise(
+           xt=xt,
+           time_conditioning=sigma_t,
+           move_chance_t=move_chance_t,
+           move_chance_s=move_chance_s,
+           cache=cache)
+       else:
+         if self.config.guidance.method == 'cfg':
+           xs, q_xs, cache = self._cfg_denoise(
+             cond=cond,
+             gamma=self.config.guidance.gamma,
+             xt=xt,
+             time_conditioning=sigma_t,
+             move_chance_t=move_chance_t,
+             move_chance_s=move_chance_s,
+             cache=cache)
+         elif self.config.guidance.method == 'cbg':
+           xs, q_xs, cache = self._cbg_denoise(
+             classifier_model=classifier_model,
+             conditioning_class=self.config.guidance.condition,
+             gamma=self.config.guidance.gamma,
+             use_approx=self.config.guidance.use_approx,
+             xt=xt,
+             time_conditioning=sigma_t,
+             move_chance_t=move_chance_t,
+             move_chance_s=move_chance_s,
+             target_sequence=target_sequence,
+             target_motifs=target_motifs,
+             cache=cache)
+         elif self.config.guidance.method == 'nos':
+           xs, q_xs, cache = self._nos_denoise(
+             classifier_model=classifier_model,
+             conditioning_class=self.config.guidance.condition,
+             num_nos_steps=self.config.guidance.num_nos_steps,
+             nos_step_size=self.config.guidance.nos_step_size,
+             nos_stability_coef=self.config.guidance.nos_stability_coef,
+             xt=xt,
+             time_conditioning=sigma_t,
+             move_chance_t=move_chance_t,
+             move_chance_s=move_chance_s)
+         else:
+           raise NotImplementedError(
+             f"Guidance method {self.config.guidance.method} not implemented.")
+       pbar.set_postfix(
+         NFEs=NFEs,
+         prob_check=(q_xs.sum() / xt.numel()).item(),
+         nan_check=bool(q_xs.isnan().sum() > 0))
+       if (not self.config.sampling.use_cache or
+           not torch.allclose(xs, xt)):
+         # Disable caching
+         cache = None
+       xt = xs
+     return xt
+
+   def _ddpm_denoise(
+       self,
+       xt: torch.Tensor,
+       time_conditioning: torch.Tensor,
+       move_chance_t: torch.Tensor,
+       move_chance_s: torch.Tensor,
+       cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None,
+   ) -> typing.Tuple[torch.Tensor, torch.Tensor, typing.Dict[str, torch.Tensor]]:
+     # Compute x_theta
+     if cache is not None:
+       log_x_theta = cache['log_x_theta']
+     else:
+       log_x_theta = self.forward(xt, time_conditioning,
+                                  cond=None)
+     if self.config.sampling.use_float64:
+       log_x_theta = log_x_theta.to(torch.float64)
+     x_theta = log_x_theta.exp()
+
+     # Compute posterior
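+     # Absorbing state: a masked position stays masked with probability
+     # move_chance_s / move_chance_t and otherwise commits to a token drawn
+     # from x_theta; positions already unmasked in xt are copied over
+     # unchanged after sampling.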
+     if self.diffusion == 'absorbing_state':
+       q_xs = x_theta * (move_chance_t - move_chance_s)
+       q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
+       q_xs /= move_chance_t
+     elif self.diffusion == 'uniform':
+       q_xs = self._compute_posterior(
+         x=x_theta,
+         xt=xt,
+         alpha_s=1 - move_chance_s,
+         alpha_t=1 - move_chance_t)
+     else:
+       raise NotImplementedError(
+         f"Diffusion type {self.diffusion} not implemented.")
+
+     # Sample from posterior
+     xs = _sample_categorical(q_xs)
+     if self.diffusion == 'absorbing_state':
+       copy_flag = (xt != self.mask_index).to(torch.bool)
+       q_xs[copy_flag] = 0.0
+       q_xs[copy_flag, xt[copy_flag]] = 1.0
+       xs = torch.where(copy_flag, xt, xs)
+
+     return xs, q_xs, {'log_x_theta': log_x_theta}
+
+   def _cfg_denoise(
+       self,
+       cond: torch.Tensor,
+       gamma: float,
+       xt: torch.Tensor,
+       time_conditioning: torch.Tensor,
+       move_chance_t: torch.Tensor,
+       move_chance_s: torch.Tensor,
+       cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None,
+   ) -> typing.Tuple[torch.Tensor, torch.Tensor, typing.Dict[str, torch.Tensor]]:
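+     # Classifier-free guidance: for 0 < gamma < 1 the sampler draws from a
+     # tempered product p_cond^gamma * p_uncond^(1 - gamma), renormalized by
+     # the softmax below; gamma == 0 or gamma == 1 degenerate to purely
+     # unconditional or conditional denoising, needing one forward pass.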
+     # Compute log_x_theta
+     if cache is not None:
+       log_x_theta_uncond = cache['log_x_theta_uncond']
+       log_x_theta_cond = cache['log_x_theta_cond']
+     else:
+       if gamma == 0.0:  # Sample unconditionally
+         mask_cond = (torch.ones_like(cond) *
+                      self.config.data.num_classes)
+         log_x_theta_uncond = self.forward(
+           xt, time_conditioning, cond=mask_cond)
+         log_x_theta_cond = None
+       elif gamma == 1.0:  # Sample conditionally
+         log_x_theta_cond = self.forward(xt, time_conditioning,
+                                         cond=cond)
+         log_x_theta_uncond = None
+       else:  # Sample from tempered distribution
+         log_x_theta_cond = self.forward(xt, time_conditioning,
+                                         cond=cond)
+         mask_cond = (torch.ones_like(cond) *
+                      self.config.data.num_classes)
+         log_x_theta_uncond = self.forward(xt,
+                                           time_conditioning,
+                                           cond=mask_cond)
+     # Compute (weighted) posterior
+     if (log_x_theta_cond is None  # gamma == 0
+         or log_x_theta_uncond is None):  # or gamma == 1
+       log_x_theta = log_x_theta_uncond if log_x_theta_uncond is not None else log_x_theta_cond
+       x_theta = log_x_theta.exp()
+       if self.diffusion == 'absorbing_state':
+         q_xs = x_theta * (move_chance_t - move_chance_s)
+         q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
+         q_xs /= move_chance_t
+       elif self.diffusion == 'uniform':
+         q_xs = self._compute_posterior(
+           x=x_theta,
+           xt=xt,
+           alpha_s=1 - move_chance_s,
+           alpha_t=1 - move_chance_t)
+       else:
+         raise NotImplementedError(
+           f"Diffusion type {self.diffusion} not implemented.")
+     else:  # gamma != 0 and gamma != 1
+       if self.diffusion == 'absorbing_state':
+         log_x_theta = (gamma * log_x_theta_cond + (1 - gamma) * log_x_theta_uncond)
+         x_theta = log_x_theta.softmax(dim=-1)
+         q_xs = x_theta * (move_chance_t - move_chance_s)
+         q_xs[:, :, self.mask_index] = move_chance_s[:, :, 0]
+         q_xs /= move_chance_t
+       elif (self.diffusion == 'uniform'
+             or self.diffusion == 'uniform_data_marginals'):
+         log_q_xs_uncond = self._compute_posterior(
+           x=log_x_theta_uncond.exp(),
+           xt=xt,
+           alpha_s=1 - move_chance_s,
+           alpha_t=1 - move_chance_t).log()
+         log_q_xs_cond = self._compute_posterior(
+           x=log_x_theta_cond.exp(),
+           xt=xt,
+           alpha_s=1 - move_chance_s,
+           alpha_t=1 - move_chance_t).log()
+         log_q_xs = (gamma * log_q_xs_cond +
+                     (1 - gamma) * log_q_xs_uncond)
+         q_xs = log_q_xs.softmax(dim=-1)
+       else:
+         raise NotImplementedError(
+           f"Diffusion type {self.diffusion} not implemented.")
+
+     # Sample from posterior
+     xs = _sample_categorical(q_xs)
+     if self.diffusion == 'absorbing_state':
+       copy_flag = (xt != self.mask_index).to(torch.bool)
+       q_xs[copy_flag] = 0.0
+       q_xs[copy_flag, xt[copy_flag]] = 1.0
+       xs = torch.where(copy_flag, xt, xs)
+
+     return xs, q_xs, {'log_x_theta_uncond': log_x_theta_uncond,
+                       'log_x_theta_cond': log_x_theta_cond}
+
+   def _cbg_denoise(
+       self,
+       conditioning_class: int,
+       gamma: float,
+       classifier_model: classifier.Classifier,
+       xt: torch.Tensor,
+       time_conditioning: torch.Tensor,
+       move_chance_t: torch.Tensor,
+       move_chance_s: torch.Tensor,
+       target_sequence: torch.Tensor = None,
+       target_motifs: torch.Tensor = None,
+       use_approx: bool = False,  # whether to use first-order approximation
+       cache: typing.Optional[typing.Dict[str, torch.Tensor]] = None,
+   ) -> typing.Tuple[torch.Tensor, torch.Tensor, typing.Dict[str, torch.Tensor]]:
+     if cache is not None:
+       log_x_theta = cache['log_x_theta']
+       classifier_log_prob = cache['classifier_log_prob']
+       similarity_log_probs = cache['similarity_log_probs']
+     else:
+       # Diffusion model
+       log_x_theta = self.forward(xt, time_conditioning,
+                                  cond=None)
+       # Classifier model
+       if use_approx:
+         # First-order (Taylor) approximation of the classifier
+         # log-probability ratios via gradients w.r.t. the one-hot input.
+         xt_one_hot = torch.nn.functional.one_hot(
+           xt, self.vocab_size).to(torch.float)
+         with torch.enable_grad():
+           xt_one_hot.requires_grad_(True)
+           classifier_log_prob_xt = classifier_model.get_log_probs(
+             xt_one_hot, time_conditioning)
+           classifier_log_prob_xt[..., conditioning_class].sum().backward()
+           grad_log_prob_xt = xt_one_hot.grad
+
+         classifier_log_prob_ratio = (
+           grad_log_prob_xt - (xt_one_hot * grad_log_prob_xt).sum(dim=-1, keepdim=True)
+         ).detach().requires_grad_(False)
+         classifier_log_prob = (
+           classifier_log_prob_ratio +
+           classifier_log_prob_xt[..., conditioning_class][..., None, None]
+         ).detach().requires_grad_(False)
+ else:
1449
+ # Copied from https://github.com/hnisonoff/discrete_guidance/blob/main/src/fm_utils.py#L441
1450
+ bsz, seq_len = xt.shape
1451
+ # Create bsz*seq_len*N copies of input sequences
1452
+ # Shape: (bsz, 1, seq_len) -> (bsz, seq_len*N, seq_len)
1453
+ # (where N = vocab_size).
1454
+ xt_expand = xt.unsqueeze(1).repeat(1, seq_len * self.vocab_size, 1)
1455
+ # Flatten batch and transition dimensions
1456
+ # Shape: (bsz, seq_len*N, seq_len) -> (bsz*seq_len*N, seq_len)
1457
+ xt_expand = xt_expand.view(-1, seq_len)
1458
+
1459
+ # Create indices for all possible transitions
1460
+ # Shape: (seq_len*N,) -> (bsz, seq_len*N) -> (bsz*seq_len*N,)
1461
+ jump_idx = torch.arange(seq_len * self.vocab_size).to(xt.device)
1462
+ jump_idx = jump_idx.repeat(bsz, 1).flatten()
1463
+
1464
+ # Create tensor for states after one transition
1465
+ xt_jumps = xt_expand.clone()
1466
+
1467
+ # Calculate which dimension changes for each transition
1468
+ # Shape: (bsz*seq_len*N,)
1469
+ jump_dims = jump_idx // self.vocab_size
1470
+
1471
+ # Calculate new value for changed dimension
1472
+ # Shape: (bsz*seq_len*N,)
1473
+ jump_states = jump_idx % self.vocab_size
1474
+
1475
+ # Apply transitions by assigning new values at transition dimensions
1476
+ # Shape: (bsz*seq_len*N, seq_len)
1477
+ xt_jumps[
1478
+ torch.arange(jump_idx.size(0), device=xt.device),
1479
+ jump_dims, # Index the transitioned dimension
1480
+ ] = jump_states # Assign the new state
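+         # Example (hypothetical sizes): with bsz=1, seq_len=4 and
+         # vocab_size N=33, xt_jumps has shape (132, 4); row i*N + v equals
+         # xt except that position i is set to token v.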
+
+         # classifier_log_prob = (classifier_model.get_log_probs(
+         #   xt_jumps, time_conditioning.repeat(seq_len * self.vocab_size)
+         # ))[..., conditioning_class].reshape(bsz, seq_len, self.vocab_size)
+
+         target_sequence = target_sequence.to(self.device)
+         # Binary mask over target positions: 1 at motif positions (shifted
+         # by one, presumably to skip the BOS token), 0 elsewhere
+         mask_vec = torch.tensor(
+           [1 if i - 1 in target_motifs else 0
+            for i in range(target_sequence.shape[1])]).to(self.device)
+
+         bindevaluator_probs, similarity_scores = classifier_model.get_probs(
+           xt_jumps,
+           target_sequence.repeat(xt_jumps.shape[0], 1),
+           self.original_binder_embedding_avg)
+
+         similarity_scores_reshaped = similarity_scores.reshape(
+           bsz, seq_len, self.vocab_size)
+         # Normalize cosine scores from [-1, 1] to [0, 1] before the log
+         normalized_similarity = (similarity_scores_reshaped + 1) / 2
+         similarity_log_probs = torch.log(normalized_similarity + 1e-8)
+         # Guard against log(0)
+         bindevaluator_probs = torch.where(
+           bindevaluator_probs == 0,
+           torch.tensor(1e-8, dtype=bindevaluator_probs.dtype),
+           bindevaluator_probs)
+         # mask_vec broadcasts over the target dimension, zeroing out
+         # non-motif positions so the mean below runs over the motif only
+         classifier_log_prob = torch.log(bindevaluator_probs) * mask_vec
+         classifier_log_prob = classifier_log_prob.sum(dim=-1) / mask_vec.sum()
+         classifier_log_prob = classifier_log_prob.reshape(
+           bsz, seq_len, self.vocab_size)
+
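+     # For absorbing-state diffusion the unguided reverse step keeps a token
+     # masked with probability move_chance_s / move_chance_t and otherwise
+     # unmasks it according to the model's prediction x_theta.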
+     # Compute the unguided posterior
+     if self.diffusion == 'absorbing_state':
+       diffusion_log_probs = log_x_theta + torch.log(
+         1. - (move_chance_s / move_chance_t))
+       diffusion_log_probs[..., self.mask_index] = torch.log(
+         move_chance_s / move_chance_t)[:, :, 0]
+       # Note: detach() is not in-place, so assign the result
+       diffusion_log_probs = diffusion_log_probs.detach()
+     elif self.diffusion == 'uniform':
+       diffusion_log_probs = self._compute_posterior(
+         x=log_x_theta.exp(),
+         xt=xt,
+         alpha_s=1 - move_chance_s,
+         alpha_t=1 - move_chance_t).log()
+     else:
+       raise NotImplementedError(
+         f"Diffusion type {self.diffusion} not implemented.")
+
+     # Apply guidance
+     with torch.no_grad():
+       if self.diffusion == 'absorbing_state':
+         guided_log_probs = (gamma * classifier_log_prob) + diffusion_log_probs
+         copy_flag = (xt != self.mask_index)
+         guided_log_probs[copy_flag] = self.neg_infinity
+         guided_log_probs[copy_flag, xt[copy_flag]] = 0.0
+       elif self.diffusion == 'uniform':
+         # The similarity term pulls samples toward a high ESM-embedding
+         # cosine similarity with the original binder; its weight (2) is
+         # hard-coded here.
+         guided_log_probs = ((gamma * classifier_log_prob)
+                             + diffusion_log_probs
+                             + 2 * similarity_log_probs)
+       else:
+         raise NotImplementedError(
+           f"Diffusion type {self.diffusion} not implemented.")
+
+     guided_probs = guided_log_probs.softmax(dim=-1)
+     # Sample from the guided posterior
+     xs = _sample_categorical(guided_probs)
+     if self.diffusion == 'absorbing_state':
+       xs = torch.where(copy_flag.to(bool), xt, xs)
+     return xs, guided_probs, {'log_x_theta': log_x_theta,
+                               'classifier_log_prob': classifier_log_prob,
+                               'similarity_log_probs': similarity_log_probs}
+
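+   # NOS-style guidance: rather than reweighting the posterior, optimize a
+   # perturbation `delta` of the backbone's final hidden states so that the
+   # classifier's log-probability of the target class increases, while a KL
+   # term keeps the perturbed denoising distribution close to the original.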
+   def _nos_denoise(
+     self,
+     classifier_model: classifier.Classifier,
+     num_nos_steps: int,
+     nos_step_size: float,
+     nos_stability_coef: float,
+     conditioning_class: int,
+     xt: torch.Tensor,
+     time_conditioning: torch.Tensor,
+     move_chance_t: torch.Tensor,
+     move_chance_s: torch.Tensor,
+   ) -> typing.Tuple[torch.Tensor, torch.Tensor, None]:
+     # Compute the original diffusion_log_probs and hidden states
+     copy_flag = (xt != self.mask_index).to(torch.bool)
+     with torch.no_grad():
+       time_conditioning = self._process_sigma(time_conditioning)
+       with torch.cuda.amp.autocast(dtype=torch.float32):
+         logits, hidden_states = self.backbone(
+           xt, time_conditioning, cond=None,
+           return_hidden_states=True)
+       if self.parameterization == 'subs':
+         log_x_theta = self._subs_parameterization(
+           logits=logits, xt=xt)
+       elif self.parameterization == 'd3pm':
+         # Returns log-probs
+         if self.subs_masking:  # Can use "zero masking prob"
+           logits[:, :, self.mask_index] += self.neg_infinity
+         log_x_theta = logits.log_softmax(dim=-1)
+       else:
+         raise NotImplementedError(
+           f"Parameterization {self.parameterization} "
+           "not implemented for NOS guidance.")
+       if self.diffusion == 'absorbing_state':
+         diffusion_log_probs = log_x_theta + torch.log(
+           1. - (move_chance_s / move_chance_t))
+         diffusion_log_probs[..., self.mask_index] = torch.log(
+           move_chance_s / move_chance_t)[:, :, 0]
+         diffusion_log_probs[copy_flag] = self.neg_infinity
+         diffusion_log_probs[copy_flag, xt[copy_flag]] = 0.0
+       elif self.diffusion == 'uniform':
+         diffusion_log_probs = self._compute_posterior(
+           x=log_x_theta.exp(),
+           xt=xt,
+           alpha_s=1 - move_chance_s,
+           alpha_t=1 - move_chance_t).log()
+
+     # Perform NOS steps
+     kl_loss = torch.nn.KLDivLoss(reduction='batchmean',
+                                  log_target=True)
+     delta = torch.nn.Parameter(
+       torch.zeros_like(hidden_states[-1]),
+       requires_grad=True)
+     optimizer = torch.optim.Adagrad([delta], lr=nos_step_size)
+     with torch.enable_grad():
+       for _ in tqdm(range(num_nos_steps),
+                     desc='NOS', leave=False):
+         h_current = hidden_states[-1] + delta
+         target_loss = classifier_model.get_log_probs(
+           xt, time_conditioning,
+           x_emb=h_current)[..., conditioning_class].sum()
+         with torch.cuda.amp.autocast(dtype=torch.float32):
+           new_logits = self.forward(xt, time_conditioning,
+                                     cond=None,
+                                     x_emb=h_current)
+         if self.diffusion == 'absorbing_state':
+           adjusted_log_probs = new_logits + torch.log(
+             1. - (move_chance_s / move_chance_t))
+           adjusted_log_probs[..., self.mask_index] = torch.log(
+             move_chance_s / move_chance_t)[:, :, 0]
+           adjusted_log_probs[copy_flag] = self.neg_infinity
+           adjusted_log_probs[copy_flag, xt[copy_flag]] = 0.0
+         elif self.diffusion == 'uniform':
+           adjusted_log_probs = self._compute_posterior(
+             x=new_logits.exp(),
+             xt=xt,
+             alpha_s=1 - move_chance_s,
+             alpha_t=1 - move_chance_t).log()
+         # KLDivLoss with log_target=True takes log-probs for both
+         # arguments, here KL(diffusion_log_probs || adjusted_log_probs)
+         kl = kl_loss(adjusted_log_probs, diffusion_log_probs)
+         # Maximize the classifier log-prob while penalizing drift from
+         # the unguided posterior
+         loss = -target_loss + nos_stability_coef * kl
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
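+     # After the NOS steps, the logits are recomputed once from the
+     # optimized hidden states (hidden_states[-1] + delta.data) and the
+     # next latent is sampled from the resulting posterior.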
+     with torch.cuda.amp.autocast(dtype=torch.float32):
+       guided_logits = self.forward(
+         xt, time_conditioning,
+         cond=None,
+         x_emb=hidden_states[-1] + delta.data)
+     if self.diffusion == 'absorbing_state':
+       diffusion_log_probs = guided_logits + torch.log(
+         1. - (move_chance_s / move_chance_t))
+       diffusion_log_probs[..., self.mask_index] = torch.log(
+         move_chance_s / move_chance_t)[:, :, 0]
+       # detach() is not in-place, so chain it into the exp()
+       guided_probs = diffusion_log_probs.detach().exp()
+     elif self.diffusion == 'uniform':
+       guided_probs = self._compute_posterior(
+         x=guided_logits.exp(),
+         xt=xt,
+         alpha_s=1 - move_chance_s,
+         alpha_t=1 - move_chance_t).detach()
+     else:
+       raise NotImplementedError(
+         f"Diffusion type {self.diffusion} not implemented.")
+
+     xs = _sample_categorical(guided_probs)
+     if self.diffusion == 'absorbing_state':
+       xs = torch.where(copy_flag, xt, xs)
+
+     return xs, guided_probs, None
sample_emb_guidance.py ADDED
@@ -0,0 +1,173 @@
+ import os
+ import hydra
+ import lightning as L
+ import numpy as np
+ import omegaconf
+ import pandas as pd
+ import rdkit
+ import rich.syntax
+ import rich.tree
+ import torch
+ import torch.nn.functional as F
+ from tqdm.auto import tqdm
+
+ import dataloader
+ import diffusion
+ from models.bindevaluator import BindEvaluator
+ from faesm.esm import FAEsmForMaskedLM
+ # AutoModelForMaskedLM / AutoTokenizer / EsmModel are only needed for the
+ # commented-out PepMLM and plain-transformers ESM alternatives below.
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmModel
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # PEPMLM_NAME = "ChatterjeeLab/PepMLM-650M"
+ # PEPMLM_TOKEN = "hf_..."  # place your access token here; never commit a real token
+ # PEPMLM_MODEL = AutoModelForMaskedLM.from_pretrained(PEPMLM_NAME, token=PEPMLM_TOKEN)
+ # pepmlm_tokenizer = AutoTokenizer.from_pretrained(PEPMLM_NAME, token=PEPMLM_TOKEN)
+ # pepmlm = PEPMLM_MODEL.to(DEVICE)
+ rdkit.rdBase.DisableLog('rdApp.error')
+
+ omegaconf.OmegaConf.register_new_resolver(
+   'cwd', os.getcwd)
+ omegaconf.OmegaConf.register_new_resolver(
+   'device_count', torch.cuda.device_count)
+ omegaconf.OmegaConf.register_new_resolver(
+   'eval', eval)
+ omegaconf.OmegaConf.register_new_resolver(
+   'div_up', lambda x, y: (x + y - 1) // y)
+ omegaconf.OmegaConf.register_new_resolver(
+   'if_then_else',
+   lambda condition, x, y: x if condition else y)
+
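+ # Example: in a YAML config, ${div_up:10,3} resolves to 4, and
+ # ${if_then_else:${flag},a,b} picks 'a' when flag is truthy.
+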
+ def _print_config(
+     config: omegaconf.DictConfig,
+     resolve: bool = True) -> None:
+   """Prints the content of a DictConfig as a tree using the Rich library.
+
+   Args:
+     config (DictConfig): Configuration composed by Hydra.
+     resolve (bool): Whether to resolve reference fields of the DictConfig.
+   """
+   style = 'dim'
+   tree = rich.tree.Tree('CONFIG', style=style,
+                         guide_style=style)
+
+   fields = config.keys()
+   for field in fields:
+     branch = tree.add(field, style=style, guide_style=style)
+     config_section = config.get(field)
+     branch_content = str(config_section)
+     if isinstance(config_section, omegaconf.DictConfig):
+       branch_content = omegaconf.OmegaConf.to_yaml(
+         config_section, resolve=resolve)
+     branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
+   rich.print(tree)
+
+ def parse_motif(motif: str) -> torch.Tensor:
+   """Parses a motif string like "3-5, 9" into a tensor of positions."""
+   parts = motif.split(',')
+   result = []
+   for part in parts:
+     part = part.strip()
+     if '-' in part:
+       start, end = map(int, part.split('-'))
+       result.extend(range(start, end + 1))
+     else:
+       result.append(int(part))
+   return torch.tensor(result)
+
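+ # Example: parse_motif("3-5, 9") returns tensor([3, 4, 5, 9]); these indices
+ # are later matched against positions in the target sequence.
+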
+ @hydra.main(version_base=None, config_path='./configs',
+             config_name='config')
+ def main(config: omegaconf.DictConfig) -> None:
+   # Reproducibility
+   L.seed_everything(config.seed)
+   os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+   torch.use_deterministic_algorithms(True)
+   torch.backends.cudnn.benchmark = False
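+   # (Setting CUBLAS_WORKSPACE_CONFIG=':4096:8' makes cuBLAS matmuls
+   # deterministic once torch.use_deterministic_algorithms(True) is enabled.)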
+
+   # _print_config(config, resolve=True)
+   print(f"Checkpoint: {config.eval.checkpoint_path}")
+
+
99
+ tokenizer = dataloader.get_tokenizer(config)
100
+ target_sequence = tokenizer(config.eval.target_sequence, return_tensors='pt')['input_ids']
101
+
102
+ pretrained = diffusion.Diffusion.load_from_checkpoint(
103
+ config.eval.checkpoint_path,
104
+ tokenizer=tokenizer,
105
+ config=config, logger=False)
106
+ pretrained.eval()
107
+ pretrained = pretrained.to('cuda')
108
+
109
+ bindevaluator = BindEvaluator.load_from_checkpoint(
110
+ config.guidance.classifier_checkpoint_path,
111
+ n_layers=8,
112
+ d_model=128,
113
+ d_hidden=128,
114
+ n_head=8,
115
+ d_k=64,
116
+ d_v=128,
117
+ d_inner=64)
118
+ bindevaluator = bindevaluator.to('cuda')
119
+
120
+ # below is the implementation of ESM with flash attention
121
+ # using 650M --> might use a bugger/smaller model
122
+ # esm = EsmModel.from_pretrained("facebook/esm2_t6_650M_UR50D")
123
+ # esm = esm.to("cuda")
124
+ # tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_650M_UR50D")
125
+
126
+ esm = FAEsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").to("cuda").eval().to(torch.float16)
127
+
128
+
+   samples = []
+   # Mean-pooled ESM embedding of the original binder, reused by the
+   # guidance step and the post-hoc similarity scores below.
+   original_binder = config.sampling.original_binder
+   original_binder_input = esm.tokenizer(original_binder, return_tensors="pt")
+   original_binder_input = {k: v.to('cuda')
+                            for k, v in original_binder_input.items()}
+   with torch.no_grad():
+     original_binder_outputs = esm(**original_binder_input)
+   original_binder_embedding = original_binder_outputs['last_hidden_state']
+   original_binder_embedding_avg = torch.mean(original_binder_embedding, dim=1)
+   # The guidance code reads this attribute during sampling
+   # (self.original_binder_embedding_avg in diffusion_emb_guidance.py).
+   pretrained.original_binder_embedding_avg = original_binder_embedding_avg
+
+   for _ in tqdm(
+       range(config.sampling.num_sample_batches),
+       desc='Gen. batches', leave=False):
+     sample = pretrained.sample(
+       target_sequence=target_sequence,
+       target_motifs=parse_motif(config.eval.target_motifs),
+       classifier_model=bindevaluator)
+     sample_decoded = pretrained.tokenizer.batch_decode(sample)
+     # Drop whitespace and the 5-character special tokens at both ends
+     samples_processed = [seq.replace(' ', '')[5:-5]
+                          for seq in sample_decoded]
+     print('sample: ', samples_processed)
+     samples.extend(samples_processed)
+
+   # Score every sample by cosine similarity to the original binder in ESM
+   # embedding space. A list (rather than a dict keyed by sequence) keeps
+   # duplicate samples aligned with the `samples` column below.
+   samples_similarity = []
+   with torch.no_grad():
+     for seq in tqdm(samples, desc='Computing similarities'):
+       seq_input = esm.tokenizer(seq, return_tensors="pt")
+       seq_input = {k: v.to('cuda') for k, v in seq_input.items()}
+       seq_output = esm(**seq_input)
+       seq_embedding = seq_output['last_hidden_state']
+       seq_embedding_avg = torch.mean(seq_embedding, dim=1)
+       similarity_score = F.cosine_similarity(
+         seq_embedding_avg, original_binder_embedding_avg)
+       samples_similarity.append(similarity_score.item())
+
+   outputs_csv = pd.DataFrame({
+     'samples': samples,
+     'samples_similarity': samples_similarity})
+   print("outputs_csv", outputs_csv)
+   outputs_csv.to_csv('il2_alpha_guidance.csv', index=False)
+
+
+ if __name__ == '__main__':
+   main()