AlienChen committed
Commit b5bda14 · verified · 1 Parent(s): 77f4137

Create peptiverse_classifiers.py

Files changed (1)
  1. models/peptiverse_classifiers.py +446 -0
models/peptiverse_classifiers.py ADDED
@@ -0,0 +1,446 @@
+ import torch
+ import pytorch_lightning as pl
+ from modules.bindevaluator_modules import *
+ from transformers import AutoModelWithLMHead, AutoTokenizer, EsmModel
+
+ from flow_matching.path import MixtureDiscreteProbPath
+ from flow_matching.path.scheduler import PolynomialConvexScheduler
+ from flow_matching.solver import MixtureDiscreteEulerSolver
+ from flow_matching.utils import ModelWrapper
+ from flow_matching.loss import MixturePathGeneralizedKL
+
+ from models.peptide_models import CNNModel
+ # from models.uaa_models import *
+
+ import sys
+ sys.path.append('./PeptiVerse')
+ from inference import PeptiVersePredictor
+
+ pred = PeptiVersePredictor(
+     manifest_path="./PeptiVerse/best_models.txt",  # best model list
+     classifier_weight_root="./PeptiVerse/",  # repo root (where training_classifiers/ lives)
+     device="cuda",  # or "cpu"
+ )
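+
+ # Hedged usage sketch (the peptide sequence is illustrative, not from this
+ # repo): each wrapper below maps a list of sequences to a 1-D score tensor
+ # through the shared module-level `pred`:
+ #   scores = HemolysisWT()(["GIGKFLHSAKKFGKAFVGEIMNS"])  # -> tensor of shape (1,)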
+
+ class HemolysisWT:
+     def __call__(self, input_seqs):
+         scores = []
+         for seq in input_seqs:
+             score = pred.predict_property("hemolysis", col="wt", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ class NonfoulingWT:
+     def __call__(self, input_seqs):
+         scores = []
+         for seq in input_seqs:
+             score = pred.predict_property("nf", col="wt", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ # class Solubility:
+ #     def __init__(self):
+ #         self.hydrophobic = list("AVLIMFWPavilmfwpŶƘṂŁĊ")
+
+ #     def __call__(self, aa_seqs: list):
+ #         scores = []
+ #         for seq in aa_seqs:
+ #             if len(seq) == 0:
+ #                 scores.append(0)
+ #                 continue
+ #             score = len([tok for tok in seq if tok not in self.hydrophobic]) / len(seq)
+ #             scores.append(score)
+ #         return torch.tensor(scores)
+
+ class Solubility:
+     def __call__(self, input_seqs):
+         scores = []
+         for seq in input_seqs:
+             score = pred.predict_property("solubility", col="wt", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+
+ class PermeabilityWT:
+     def __call__(self, input_seqs):
+         scores = []
+         for seq in input_seqs:
+             score = pred.predict_property("permeability_penetrance", col="wt", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ class HalfLifeWT:
+     def __call__(self, input_seqs):
+         scores = []
+         for seq in input_seqs:
+             score = pred.predict_property("halflife", col="wt", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ class AffinityWT:
+     def __init__(self, target):
+         self.target = target
+
+     def __call__(self, input_seqs):
+         scores = []
+         for seq in input_seqs:
+             score = pred.predict_binding_affinity(col="wt", target_seq=self.target, binder_str=seq)['affinity']
+             scores.append(score / 10)  # scale raw affinity into roughly [0, 1] (assumes a ~0-10 output range)
+
+         return torch.tensor(scores)
+
+
+ def parse_motifs(motif: str) -> torch.Tensor:
+     parts = motif.split(',')
+     result = []
+
+     for part in parts:
+         part = part.strip()
+         if '-' in part:
+             start, end = map(int, part.split('-'))
+             result.extend(range(start, end + 1))
+         else:
+             result.append(int(part))
+
+     # result = [pos-1 for pos in result]
+     print(f'Target Motifs: {result}')
+     return torch.tensor(result)
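+
+ # Example (hedged): motif strings are comma-separated residue indices with
+ # inclusive ranges, e.g. parse_motifs("3-5, 9") -> tensor([3, 4, 5, 9]).
+ # The indices are used directly as token positions; with ESM tokenization the
+ # BOS token at position 0 means 1-based residue numbering lines up with them.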
+
+
+ class BindEvaluatorWT(pl.LightningModule):
+     def __init__(self, n_layers, d_model, d_hidden, n_head,
+                  d_k, d_v, d_inner, dropout=0.2,
+                  learning_rate=0.00001, max_epochs=15, kl_weight=1):
+         super(BindEvaluatorWT, self).__init__()
+
+         self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.esm_model.eval()
+         # freeze all the esm_model parameters
+         for param in self.esm_model.parameters():
+             param.requires_grad = False
+
+         self.repeated_module = RepeatedModule3(n_layers, d_model, d_hidden,
+                                                n_head, d_k, d_v, d_inner, dropout=dropout)
+
+         self.final_attention_layer = MultiHeadAttentionSequence(n_head, d_model,
+                                                                 d_k, d_v, dropout=dropout)
+
+         self.final_ffn = FFN(d_model, d_inner, dropout=dropout)
+
+         self.output_projection_prot = nn.Linear(d_model, 1)
+
+         self.learning_rate = learning_rate
+         self.max_epochs = max_epochs
+         self.kl_weight = kl_weight
+
+         self.classification_threshold = nn.Parameter(torch.tensor(0.5))  # initial threshold
+         self.historical_memory = 0.9
+         self.class_weights = torch.tensor([3.000471363174231, 0.5999811490272925])  # binding-site / non-binding-site weights
+
+     def forward(self, binder_tokens, target_tokens):
+         peptide_sequence = self.esm_model(**binder_tokens).last_hidden_state
+         protein_sequence = self.esm_model(**target_tokens).last_hidden_state
+
+         prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
+             prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(peptide_sequence,
+                                                                                     protein_sequence)
+
+         prot_enc, final_prot_seq_attention = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
+
+         prot_enc = self.final_ffn(prot_enc)
+
+         prot_enc = self.output_projection_prot(prot_enc)
+
+         return prot_enc
+
+     def get_probs(self, x_t, target_sequence):
+         '''
+         Inputs:
+             - x_t: Shape (bsz, seq_len)
+             - target_sequence: Shape (1, tgt_len)
+         '''
+         target_sequence = target_sequence.repeat(x_t.shape[0], 1)
+         binder_attention_mask = torch.ones_like(x_t)
+         target_attention_mask = torch.ones_like(target_sequence)
+
+         # mask out the BOS/EOS positions
+         binder_attention_mask[:, 0] = binder_attention_mask[:, -1] = 0
+         target_attention_mask[:, 0] = target_attention_mask[:, -1] = 0
+
+         binder_tokens = {'input_ids': x_t, 'attention_mask': binder_attention_mask.to(x_t.device)}
+         target_tokens = {'input_ids': target_sequence, 'attention_mask': target_attention_mask.to(target_sequence.device)}
+
+         logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
+         logits[:, 0] = logits[:, -1] = -100  # effectively float('-inf') after the sigmoid
+         probs = torch.sigmoid(logits)
+
+         return probs  # shape (bsz, tgt_len)
+
+     def motif_score(self, x_t, target_sequence, motifs):
+         probs = self.get_probs(x_t, target_sequence)
+         motif_probs = probs[:, motifs]
+         motif_score = motif_probs.sum(dim=-1) / len(motifs)
+         return motif_score
+
+     def non_motif_score(self, x_t, target_sequence, motifs):
+         probs = self.get_probs(x_t, target_sequence)
+         non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
+         mask = non_motif_probs >= 0.5
+         count = mask.sum(dim=-1)
+
+         # mean probability over predicted (>= 0.5) non-motif positions; the clamp avoids 0/0
+         non_motif_score = torch.where(count > 0, (non_motif_probs * mask).sum(dim=-1) / count.clamp(min=1), torch.zeros_like(count, dtype=probs.dtype))
+
+         return non_motif_score
+
+     def scoring(self, x_t, target_sequence, motifs, penalty=False):
+         probs = self.get_probs(x_t, target_sequence)
+         motif_probs = probs[:, motifs]
+         motif_score = motif_probs.sum(dim=-1) / len(motifs)
+
+         if penalty:
+             non_motif_probs = probs[:, [i for i in range(probs.shape[1]) if i not in motifs]]
+             mask = non_motif_probs >= 0.5
+             count = mask.sum(dim=-1)
+             # non_motif_score = 1 - torch.where(count > 0, (non_motif_probs * mask).sum(dim=-1) / count, torch.zeros_like(count))
+             # penalize the fraction of off-motif positions predicted as binding
+             non_motif_score = count / target_sequence.shape[1]
+             return motif_score, 1 - non_motif_score
+         else:
+             return motif_score
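+
+ # Note on BindEvaluatorWT.scoring: it returns the mean predicted binding
+ # probability over the requested motif positions; with penalty=True it also
+ # returns a specificity term, 1 - (# predicted off-motif sites) / target length.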
+
+ class MotifModelWT(nn.Module):
+     def __init__(self, bindevaluator, target_sequence, motifs, tokenizer, device, penalty=False):
+         super(MotifModelWT, self).__init__()
+         self.bindevaluator = bindevaluator
+         self.target_sequence = target_sequence
+         self.motifs = motifs
+         self.penalty = penalty
+
+         self.tokenizer = tokenizer
+         self.device = device
+
+     def forward(self, input_seqs):
+         # note: assumes equal-length binders; pass padding=True for mixed lengths
+         x = self.tokenizer(input_seqs, return_tensors='pt')['input_ids'].to(self.device)
+         return self.bindevaluator.scoring(x, self.target_sequence, self.motifs, self.penalty)
+
+ def load_bindevaluator(checkpoint_path, device):
+     bindevaluator = BindEvaluatorWT.load_from_checkpoint(checkpoint_path, n_layers=8, d_model=128, d_hidden=128, n_head=8, d_k=64, d_v=128, d_inner=64).to(device)
+     bindevaluator.eval()
+     for param in bindevaluator.parameters():
+         param.requires_grad = False
+
+     return bindevaluator
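+
+ # Hedged usage sketch (checkpoint path, target sequence, and motif string are
+ # placeholders, not files shipped with this commit):
+ #   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+ #   bindevaluator = load_bindevaluator("checkpoints/bindevaluator.ckpt", "cuda")
+ #   target_ids = tokenizer("MKT...TARGETSEQ", return_tensors='pt')['input_ids'].to("cuda")
+ #   motif_model = MotifModelWT(bindevaluator, target_ids, parse_motifs("3-5, 9").to("cuda"),
+ #                              tokenizer, "cuda", penalty=False)
+ #   motif_scores = motif_model(["GIGKFLHSAKKFGKAFVGEIMNS"])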
+
+
+ def load_solver(checkpoint_path, vocab_size, device):
+     # hyperparameters carried over from training; only embed_dim and
+     # hidden_dim are used in this function
+     lr = 1e-4
+     epochs = 200
+     embed_dim = 512
+     hidden_dim = 256
+     epsilon = 1e-3
+     batch_size = 256
+     warmup_epochs = epochs // 10
+
+     probability_denoiser = CNNModel(alphabet_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim).to(device)
+     probability_denoiser.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=False))
+     probability_denoiser.eval()
+     for param in probability_denoiser.parameters():
+         param.requires_grad = False
+
+     # instantiate a convex path object
+     scheduler = PolynomialConvexScheduler(n=2.0)
+     path = MixtureDiscreteProbPath(scheduler=scheduler)
+
+     class WrappedModel(ModelWrapper):
+         def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
+             return torch.softmax(self.model(x, t), dim=-1)
+
+     wrapped_probability_denoiser = WrappedModel(probability_denoiser)
+     solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
+
+     return solver
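+
+ # Hedged sampling sketch (assumes the flow_matching MixtureDiscreteEulerSolver
+ # API; batch size, sequence length, and step size are illustrative):
+ #   x_init = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
+ #   samples = solver.sample(x_init=x_init, step_size=1 / 100)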
+
+ # def load_uaa_solver(checkpoint_path, vocab_size, device):
+ #     checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
+ #     model_args = checkpoint['args']
+
+ #     probability_denoiser = MDLM(
+ #         vocab_size=model_args.vocab_size,
+ #         seq_len=model_args.seq_len,
+ #         model_dim=model_args.model_dim,
+ #         n_heads=model_args.n_heads,
+ #         n_layers=model_args.n_layers
+ #     ).to(device)
+
+ #     probability_denoiser.load_state_dict(checkpoint['model_state_dict'])
+ #     probability_denoiser.eval()
+ #     for param in probability_denoiser.parameters():
+ #         param.requires_grad = False
+
+ #     # instantiate a convex path object
+ #     scheduler = PolynomialConvexScheduler(n=2.0)
+ #     path = MixtureDiscreteProbPath(scheduler=scheduler)
+
+ #     class WrappedModel(ModelWrapper):
+ #         def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
+ #             return torch.softmax(self.model(x, t), dim=-1)
+
+ #     wrapped_probability_denoiser = WrappedModel(probability_denoiser)
+ #     solver = MixtureDiscreteEulerSolver(model=wrapped_probability_denoiser, path=path, vocabulary_size=vocab_size)
+
+ #     return solver
+
+
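+ # SMILES-input variants of the wrappers above: the same PeptiVerse predictor
+ # is queried with col="smiles" instead of col="wt".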
+ class HemolysisSMILES:
+     def __call__(self, smiles_seqs):
+         scores = []
+         for seq in smiles_seqs:
+             score = pred.predict_property("hemolysis", col="smiles", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ class NonfoulingSMILES:
+     def __call__(self, smiles_seqs):
+         scores = []
+         for seq in smiles_seqs:
+             score = pred.predict_property("nf", col="smiles", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ class PermeabilitySMILES:
+     def __call__(self, smiles_seqs):
+         scores = []
+         for seq in smiles_seqs:
+             score = pred.predict_property("permeability_pampa", col="smiles", input_str=seq)['score']
+             score = (score + 9) / (-4 + 9)  # min-max normalize an assumed raw PAMPA range of [-9, -4] to [0, 1]
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ class HalfLifeSMILES:
+     def __call__(self, smiles_seqs):
+         scores = []
+         for seq in smiles_seqs:
+             score = pred.predict_property("halflife", col="smiles", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ class ToxicitySMILES:
+     def __call__(self, smiles_seqs):
+         scores = []
+         for seq in smiles_seqs:
+             score = pred.predict_property("toxicity", col="smiles", input_str=seq)['score']
+             scores.append(score)
+
+         return torch.tensor(scores)
+
+ class AffinitySMILES:
+     def __init__(self, target):
+         self.target = target
+
+     def __call__(self, smiles_seqs):
+         scores = []
+         for seq in smiles_seqs:
+             score = pred.predict_binding_affinity(col="smiles", target_seq=self.target, binder_str=seq)['affinity']
+             scores.append(score / 10)  # scale raw affinity into roughly [0, 1] (assumes a ~0-10 output range)
+
+         return torch.tensor(scores)
+
+
+ class BindEvaluatorSMILES(pl.LightningModule):
+     def __init__(self, cfg):
+         super(BindEvaluatorSMILES, self).__init__()
+
+         self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").eval()
+         for param in self.esm_model.parameters():
+             param.requires_grad = False
+
+         self.chemberta_model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k").roberta.eval()
+         for param in self.chemberta_model.parameters():
+             param.requires_grad = False
+
+         self.repeated_module = RepeatedModule(cfg.model.n_layers, cfg.model.d_model, cfg.model.d_hidden,
+                                               cfg.model.n_head, cfg.model.d_k, cfg.model.d_v, cfg.model.d_inner, dropout=cfg.model.dropout)
+
+         self.final_attention_layer = MultiHeadAttentionSequence(cfg.model.n_head, cfg.model.d_model,
+                                                                 cfg.model.d_k, cfg.model.d_v, dropout=cfg.model.dropout)
+
+         self.final_ffn = FFN(cfg.model.d_model, cfg.model.d_inner, dropout=cfg.model.dropout)
+
+         self.output_projection_prot = nn.Linear(cfg.model.d_model, 1)
+
+     def forward(self, binder_tokens, target_tokens):
+         peptide_sequence = self.chemberta_model(**binder_tokens).last_hidden_state
+         protein_sequence = self.esm_model(**target_tokens).last_hidden_state
+
+         binder_mask = binder_tokens["attention_mask"]  # [B, Ls]
+         target_mask = target_tokens["attention_mask"]  # [B, Lp]
+
+         prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
+             prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(
+                 peptide_sequence,
+                 protein_sequence,
+                 peptide_mask=binder_mask,
+                 protein_mask=target_mask,
+             )
+
+         # final cross-attention: protein queries attend to binder keys
+         prot_enc, final_prot_seq_attention = self.final_attention_layer(
+             prot_enc, sequence_enc, sequence_enc,
+             key_padding_mask=binder_mask,
+             query_padding_mask=target_mask,
+         )
+
+         prot_enc = self.final_ffn(prot_enc, padding_mask=target_mask)
+         prot_enc = self.output_projection_prot(prot_enc)
+         return prot_enc
+
+ class MotifModelSMILES:
+     def __init__(self, cfg, target, motifs, device, specificity):
+         self.cfg = cfg
+         self.threshold = 0.918
+         self.device = device
+         self.specificity = specificity
+         self.chemberta_tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k")
+         self.esm_tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
+         self.target = self.esm_tokenizer(target, return_tensors='pt').to(device)
+         self.motifs = parse_motifs(motifs).to(device)
+         self.bindevaluator = BindEvaluatorSMILES.load_from_checkpoint(cfg.inference.ckpt, cfg=cfg, map_location=device)
+
+     def __call__(self, smiles_seqs):
+         L = self.target['input_ids'].shape[1]
+
+         motif_scores = []
+         specificity_scores = []
+         for seq in smiles_seqs:
+             binder = self.chemberta_tokenizer(seq, return_tensors='pt').to(self.device)
+             prediction = self.bindevaluator(binder, self.target).squeeze(-1)
+             probs = torch.sigmoid(prediction).squeeze(0)  # (L,) after dropping the batch dim
+             motif_score = probs[self.motifs].mean()
+             motif_scores.append(motif_score)
+             if self.specificity:
+                 non_motif_probs = probs[[i for i in range(probs.shape[0]) if i not in self.motifs]]
+                 mask = non_motif_probs >= self.threshold
+                 count = mask.sum()
+                 specificity = 1 - count / (L - 2)  # exclude BOS/EOS from the target length
+
+                 specificity_scores.append(specificity)
+
+         if self.specificity:
+             return torch.stack(motif_scores), torch.stack(specificity_scores)
+         else:
+             return torch.stack(motif_scores)
+
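+ # Hedged end-to-end sketch (cfg, the target sequence, motif string, and the
+ # SMILES string are placeholders, not values shipped with this commit):
+ #   motif_model = MotifModelSMILES(cfg, target="MKT...TARGETSEQ", motifs="3-5, 9",
+ #                                  device="cuda", specificity=True)
+ #   motif_scores, specificity_scores = motif_model(["CC(C)C[C@H](N)C(=O)O"])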