Update moppit.py
moppit.py CHANGED

@@ -1,9 +1,38 @@
+import os
+import warnings
+import logging
+
+import os
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+warnings.filterwarnings("ignore")
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+from sklearn.exceptions import InconsistentVersionWarning
+warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
+
+logging.getLogger().setLevel(logging.ERROR)
+logging.getLogger("lightning").setLevel(logging.ERROR)
+logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
+logging.getLogger("transformers").setLevel(logging.ERROR)
+logging.getLogger("absl").setLevel(logging.ERROR)
+
+from transformers import logging as hf_logging
+hf_logging.set_verbosity_error()
+hf_logging.disable_progress_bar()
+
+logging.getLogger("lightning.fabric.utilities.seed").setLevel(logging.ERROR)
+logging.getLogger("pytorch_lightning.utilities.seed").setLevel(logging.ERROR)
+
 import torch
 from transformers import AutoTokenizer
 from pathlib import Path
 import inspect
 
-from models.peptide_classifiers import *
+# from models.peptide_classifiers import *
+from models.peptiverse_classifiers import *
 from utils.parsing import parse_guidance_args
 args = parse_guidance_args()
 
@@ -17,9 +46,9 @@ device = 'cuda:0'
 
 length = args.length
 target = args.target_protein
+
 if args.motifs:
     motifs = parse_motifs(args.motifs).to(device)
-    print(motifs)
 
 tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
 target_sequence = tokenizer(target, return_tensors='pt').to(device)
@@ -29,29 +58,30 @@ solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_lo
 
 score_models = []
 if 'Hemolysis' in args.objectives:
-    hemolysis_model =
+    hemolysis_model = HemolysisWT()
     score_models.append(hemolysis_model)
 if 'Non-Fouling' in args.objectives:
-    nonfouling_model =
+    nonfouling_model = NonfoulingWT()
     score_models.append(nonfouling_model)
 if 'Solubility' in args.objectives:
-    solubility_model =
+    solubility_model = Solubility()
     score_models.append(solubility_model)
+if 'Permeability' in args.objectives:
+    permeability_model = PermeabilityWT()
+    score_models.append(permeability_model)
 if 'Half-Life' in args.objectives:
-    halflife_model =
+    halflife_model = HalfLifeWT()
     score_models.append(halflife_model)
 if 'Affinity' in args.objectives:
-
-    affinity_model = AffinityModel(affinity_predictor, target_sequence, device)
+    affinity_model = AffinityWT(target)
     score_models.append(affinity_model)
-
-if 'Specificity' in args.objectives:
-    motif_penalty = True
-else:
-    motif_penalty = False
 if 'Motif' in args.objectives or 'Specificity' in args.objectives:
     bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
-
+    if 'Specificity' in args.objectives:
+        args.specificity = True
+    else:
+        args.specificity = False
+    motif_model = MotifModelWT(bindevaluator, target_sequence['input_ids'], motifs, tokenizer, device, penalty=args.specificity)
     score_models.append(motif_model)
 
 objective_line = "Binder," + str(args.objectives)[1:-1].replace(' ', '').replace("'", "") + '\n'
@@ -68,36 +98,46 @@ else:
     f.write(objective_line)
 
 for i in range(args.n_batches):
-    if
-        x_init =
-    elif source_distribution == "mask":
-        x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
+    if args.starting_sequence:
+        x_init = tokenizer(args.starting_sequence, return_tensors='pt')['input_ids'].to(device)
     else:
-
+        if source_distribution == "uniform":
+            x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device) # CHANGE!
+        elif source_distribution == "mask":
+            x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
+        else:
+            raise NotImplementedError
 
-
-
-
+        zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
+        twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
+        x_init = torch.cat([zeros, x_init, twos], dim=1)
+
+    if args.fixed_positions is not None:
+        fixed_positions = parse_motifs(args.fixed_positions).tolist()
+    else:
+        fixed_positions = []
+
+    invalid_tokens = torch.tensor([0, 1, 2, 3], device=device)
 
     x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                        step_size=step_size,
                                        verbose=True,
                                        time_grid=torch.tensor([0.0, 1.0-1e-3]),
                                        score_models=score_models,
-                                       num_objectives=len(score_models) + int(
-                                       weights=args.weights
-
-
-
-    print(samples)
+                                       num_objectives=len(score_models) + int(args.specificity),
+                                       weights=args.weights,
+                                       tokenizer=tokenizer,
+                                       fixed_positions=fixed_positions,
+                                       invalid_tokens=invalid_tokens)
 
     scores = []
+    input_seqs = [tokenizer.batch_decode(x_1)[0].replace(' ', '')[5:-5]]
     for i, s in enumerate(score_models):
        sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
        if 't' in sig.parameters:
-            candidate_scores = s(
+            candidate_scores = s(input_seqs, 1)
        else:
-            candidate_scores = s(
+            candidate_scores = s(input_seqs)
 
        if args.objectives[i] == 'Affinity':
            candidate_scores = 10 * candidate_scores
@@ -110,7 +150,7 @@ for i in range(args.n_batches):
     print(scores)
 
     with open(args.output_file, 'a') as f:
-        f.write(
+        f.write(input_seqs[0])
        for score in scores:
            f.write(f",{score}")
        f.write('\n')
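
For reference, the initialization added in this commit draws the source sequence either uniformly over residue tokens or as an all-mask sequence, then wraps it with start and end tokens before guidance sampling. Below is a minimal, self-contained sketch of that logic; the token ids (0 for the leading special token, 2 for the trailing one, 3 for mask, residues from 4 up to vocab_size) and the toy shapes are assumptions read off the script itself, not taken from any library documentation.

import torch

n_samples, length, vocab_size = 4, 12, 33   # toy values for illustration
device = 'cpu'
source_distribution = "mask"                 # or "uniform"

if source_distribution == "uniform":
    # sample only residue tokens (ids >= 4 in this vocabulary)
    x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
elif source_distribution == "mask":
    # start every position from the mask token (id 3 here)
    x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
else:
    raise NotImplementedError

# wrap each sequence with the leading (0) and trailing (2) special tokens
zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
x_init = torch.cat([zeros, x_init, twos], dim=1)
print(x_init.shape)  # torch.Size([4, 14]), i.e. (n_samples, length + 2)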
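The scoring loop dispatches on each model's call signature: models whose forward takes a time argument t receive one, and the rest are called on the decoded sequences alone. A small self-contained sketch of that pattern follows; the two toy scorers are illustrative stand-ins, not classes from this repository.

import inspect

class TimeConditionedScorer:
    # stand-in for a score model whose forward takes a time argument
    def forward(self, seqs, t):
        return [len(s) * t for s in seqs]
    __call__ = forward

def plain_scorer(seqs):
    # stand-in for a score model that only looks at the sequences
    return [len(s) for s in seqs]

input_seqs = ["PEPTIDE"]
for s in (TimeConditionedScorer(), plain_scorer):
    sig = inspect.signature(s.forward) if hasattr(s, 'forward') else inspect.signature(s)
    if 't' in sig.parameters:
        candidate_scores = s(input_seqs, 1)   # pass t = 1, as in the script above
    else:
        candidate_scores = s(input_seqs)
    print(candidate_scores)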