| | import numpy as np |
| | import random |
| | import math |
| |
|
| | from sklearn.metrics import * |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset |
| | import pickle |
| |
|
| | |
def word2idx(word, words):
    """Return the vocabulary index of ``word`` as an ``int``.

    ``words`` is a mapping from token to index. Tokens that are not among its
    keys map to 0 (presumably the padding/unknown index — confirm with the
    vocabulary builder).
    """
    if word not in words.keys():
        return 0
    return int(words[word])
| |
|
def pad_seq(dataset, max_len):
    """Right-pad each sequence in ``dataset`` with zeros to ``max_len``.

    Sequences longer than ``max_len`` are truncated to their first
    ``max_len`` elements rather than raising a broadcast error.

    Args:
        dataset: Iterable of numeric sequences (lists or 1-D arrays).
        max_len: Target length of every output row.

    Returns:
        float64 ``np.ndarray`` of shape ``(len(dataset), max_len)``. An empty
        ``dataset`` gives shape ``(0, max_len)``.
    """
    seqs = list(dataset)
    output = np.zeros((len(seqs), max_len))
    for row, seq in zip(output, seqs):
        # ``row`` is a view into ``output``, so slice assignment fills it in place.
        clipped = seq[:max_len]
        row[:len(clipped)] = clipped
    return output
| |
|
def str2bool(seq):
    """Convert string labels to integers: ``"positive"`` -> 1, ``"negative"`` -> 0.

    Args:
        seq: Iterable of label strings.

    Returns:
        Integer ``np.ndarray`` with one entry per input label.

    Raises:
        ValueError: If a label is neither ``"positive"`` nor ``"negative"``.
    """
    mapping = {"positive": 1, "negative": 0}
    out = []
    for s in seq:
        try:
            out.append(mapping[s])
        except KeyError:
            # Fail loudly: dropping an unknown label would shift every later
            # label out of alignment with its input row.
            raise ValueError(
                f"unrecognised label {s!r}; expected 'positive' or 'negative'"
            ) from None
    return np.array(out, dtype=int)
| | |
class API_Dataset(Dataset):
    """Map-style dataset of aptamer/protein token rows, labels and attention masks.

    Every input is stored as a numpy array. ``apta``, ``esm_prot`` and ``y``
    are stored as int64; the masks keep their own dtype. Indexing returns a
    5-tuple of int64 tensors ``(apta, esm_prot, y, apta_attn_mask,
    prot_attn_mask)`` for that row. The length is taken from ``apta``; the
    other inputs are assumed to have the same first dimension.
    """

    def __init__(self, apta, esm_prot, y, apta_attn_mask, prot_attn_mask):
        # Zero-argument super() resolves from API_Dataset. The previous
        # super(Dataset, self) started the MRO lookup *after* Dataset.
        super().__init__()

        self.apta = np.array(apta, dtype=np.int64)
        self.esm_prot = np.array(esm_prot, dtype=np.int64)
        self.y = np.array(y, dtype=np.int64)
        self.apta_attn_mask = np.array(apta_attn_mask)
        self.prot_attn_mask = np.array(prot_attn_mask)
        self.len = len(self.apta)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # torch.tensor copies, so callers cannot mutate the stored arrays
        # through the returned tensors.
        fields = (self.apta, self.esm_prot, self.y,
                  self.apta_attn_mask, self.prot_attn_mask)
        return tuple(torch.tensor(arr[index], dtype=torch.int64) for arr in fields)
| |
|
def find_opt_threshold(target, pred):
    """Grid-search the decision threshold that maximises F1.

    The candidate thresholds are ``0, 0.001, ..., 0.999``. ``pred`` is
    binarised as ``pred > threshold``. The first (smallest) threshold reaching
    the best F1 wins. If no threshold gives a positive F1, the function
    returns 0.
    """
    best_threshold, best_f1 = 0, 0
    for step in range(1000):
        threshold = step / 1000
        score = f1_score(target, np.where(pred > threshold, 1, 0))
        # Strict '>' keeps the earliest threshold on ties.
        if score > best_f1:
            best_threshold, best_f1 = threshold, score
    return best_threshold
| |
|
def argument_seqset(seqset):
    """Augment a collection of ``(s, ss)`` pairs with their reversals.

    For every pair, the output gets ``[s, ss]`` followed by
    ``[s[::-1], ss[::-1]]``. Both elements are reversed together, so the
    pair stays aligned.
    """
    augmented = []
    for first, second in seqset:
        augmented.extend(([first, second], [first[::-1], second[::-1]]))
    return augmented
| |
|
def augment_apis(apta, prot, ys):
    """Double a set of aptamer/protein/label triples by reversing the aptamer.

    Every triple ``(a, p, y)`` yields two rows: ``(a, p, y)`` and then
    ``(a[::-1], p, y)``. Only the aptamer is reversed; the protein and the
    label are repeated unchanged.

    Returns:
        Three numpy arrays ``(aug_apta, aug_prot, aug_y)``, each twice as long
        as the input.
    """
    records = list(zip(apta, prot, ys))
    aug_apta = [seq for a, _, _ in records for seq in (a, a[::-1])]
    aug_prot = [p for _, p, _ in records for _ in range(2)]
    aug_y = [y for _, _, y in records for _ in range(2)]
    return np.array(aug_apta), np.array(aug_prot), np.array(aug_y)
| |
|
| |
|
| |
|
def load_data_source(filepath):
    """Unpickle a table and split its rows by the ``"dataset"`` column.

    The unpickled object is indexed with a boolean mask on its ``"dataset"``
    column (presumably a pandas DataFrame — confirm with the data builder).
    Rows labelled "training dataset", "test dataset" and "benchmark dataset"
    are each converted to a numpy array.

    Returns:
        Tuple ``(train, test, bench)`` of numpy arrays.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file. Only call this on trusted, locally produced files.
    with open(filepath, "rb") as handle:
        frame = pickle.load(handle)
    split_names = ("training dataset", "test dataset", "benchmark dataset")
    return tuple(np.array(frame[frame["dataset"] == name]) for name in split_names)
| |
|
| |
|
def get_dataset(filepath, prot_max_len, n_prot_vocabs, prot_words):
    """Load the pickled splits and encode them with the project tokenizers.

    First, the training split is augmented via ``augment_apis``. Then each
    split becomes ``[rna2vec(aptamers), tokenize_sequences(proteins, ...),
    str2bool(labels)]``. ``rna2vec`` and ``tokenize_sequences`` are defined
    outside this section.

    Returns:
        ``(train, test, bench)``, each a 3-element list.
    """
    train_raw, test_raw, bench_raw = load_data_source(filepath)

    def encode(apta, prot, labels):
        return [
            rna2vec(apta),
            tokenize_sequences(prot, prot_max_len, n_prot_vocabs, prot_words),
            str2bool(labels),
        ]

    train_apta, train_prot, train_y = augment_apis(train_raw[:, 0], train_raw[:, 1], train_raw[:, 2])
    train_encoded = encode(train_apta, train_prot, train_y)
    test_encoded = encode(test_raw[:, 0], test_raw[:, 1], test_raw[:, 2])
    bench_encoded = encode(bench_raw[:, 0], bench_raw[:, 1], bench_raw[:, 2])
    return train_encoded, test_encoded, bench_encoded
| |
|
| |
|
def get_esm_dataset(filepath, batch_converter, alphabet):
    """Load the pickled splits and encode proteins with an ESM batch converter.

    Aptamers go through ``rna2vec``, which is defined outside this section.
    Proteins go through ``batch_converter``, which receives
    ``(name, sequence)`` pairs; the third element of its result is used as
    the token tensor. The training split is augmented with ``augment_apis``
    first.

    Benchmark protein tokens are truncated or right-padded (with
    ``alphabet.padding_idx``) to exactly 1678 columns. Train and test tokens
    keep the width that ``batch_converter`` returns.

    Returns:
        ``(train, test, bench)``, each ``[apta_vectors, prot_tokens, labels]``.
    """
    train_split, test_split, bench_split = load_data_source(filepath)

    train_apta, train_prot, train_y = augment_apis(*(train_split[:, col] for col in range(3)))
    # The training labels are used as batch_converter sequence names.
    _, _, train_tokens = batch_converter(list(zip(train_y, train_prot)))
    train_part = [rna2vec(train_apta), train_tokens, str2bool(train_y)]

    _, _, test_tokens = batch_converter(list(enumerate(test_split[:, 1])))
    test_part = [rna2vec(test_split[:, 0]), test_tokens, str2bool(test_split[:, 2])]

    _, _, bench_tokens = batch_converter(list(enumerate(bench_split[:, 1])))
    width = 1678
    clipped = bench_tokens[:, :width]
    bench_padded = torch.ones((clipped.shape[0], width), dtype=torch.int64) * alphabet.padding_idx
    bench_padded[:, :clipped.shape[1]] = clipped
    bench_part = [rna2vec(bench_split[:, 0]), bench_padded, str2bool(bench_split[:, 2])]

    return train_part, test_part, bench_part
| |
|
def get_nt_esm_dataset(filepath, nt_tokenizer, batch_converter, alphabet,
                       apta_max_len=275, prot_pad_len=1680):
    """Load the pickled splits and tokenize them for a nucleotide-LM + ESM model.

    Aptamers are tokenized with ``nt_tokenizer.batch_encode_plus`` and padded
    to ``apta_max_len`` tokens. Proteins are tokenized with the ESM
    ``batch_converter``. The training split is augmented with ``augment_apis``
    first.

    Args:
        filepath: Path passed to ``load_data_source``.
        nt_tokenizer: Tokenizer exposing ``batch_encode_plus`` and ``pad_token_id``.
        batch_converter: Callable taking ``(name, sequence)`` pairs. The third
            element of its result is used as the token tensor.
        alphabet: Object exposing ``padding_idx``.
        apta_max_len: Aptamer token width (default 275).
        prot_pad_len: Width that test/benchmark protein tokens are
            right-padded to (default 1680). It must be >= the widest
            ``batch_converter`` output, or the copy into the padded tensor
            fails.

    Returns:
        ``(train, test, bench)``, each a list
        ``[apta_toks, prot_toks, labels, apta_attention_mask, prot_attention_mask]``.
        Each attention mask is a boolean tensor that is True at non-padding
        positions.
    """
    dataset_train, dataset_test, dataset_bench = load_data_source(filepath)

    def encode_prot(named_seqs, pad_len=None):
        # Tokenize proteins and optionally right-pad them to a fixed width
        # with the alphabet's padding index.
        _, _, toks = batch_converter(named_seqs)
        if pad_len is not None:
            padded = torch.ones((toks.shape[0], pad_len), dtype=torch.int64) * alphabet.padding_idx
            padded[:, :toks.shape[1]] = toks
            toks = padded
        return toks, toks != alphabet.padding_idx

    def encode_apta(seqs):
        # Fixed-width aptamer token ids plus a mask of non-pad positions.
        toks = nt_tokenizer.batch_encode_plus(
            seqs, return_tensors='pt', padding='max_length', max_length=apta_max_len
        )['input_ids']
        return toks, toks != nt_tokenizer.pad_token_id

    def build(named_prot, apta_seqs, labels, pad_len=None):
        # Assemble one split in the order consumed downstream.
        prot_toks, prot_mask = encode_prot(named_prot, pad_len)
        apta_toks, apta_mask = encode_apta(apta_seqs)
        return [apta_toks, prot_toks, str2bool(labels), apta_mask, prot_mask]

    arg_apta, arg_prot, arg_y = augment_apis(dataset_train[:, 0], dataset_train[:, 1], dataset_train[:, 2])
    # NOTE(review): training protein tokens are NOT padded to prot_pad_len,
    # so their width depends on the longest training sequence. Confirm this
    # is intended.
    datasets_train = build(list(zip(arg_y, arg_prot)), arg_apta, arg_y)
    datasets_test = build(list(enumerate(dataset_test[:, 1])), dataset_test[:, 0],
                          dataset_test[:, 2], prot_pad_len)
    datasets_bench = build(list(enumerate(dataset_bench[:, 1])), dataset_bench[:, 0],
                           dataset_bench[:, 2], prot_pad_len)

    return datasets_train, datasets_test, datasets_bench
| |
|
def get_scores(target, pred):
    """Compute binary-classification metrics at the F1-optimal threshold.

    The threshold comes from ``find_opt_threshold``. Accuracy, MCC, F1 and
    the classification report use ``pred > threshold``. ROC-AUC and PR-AUC
    (average precision) use the raw scores.

    Args:
        target: Binary ground-truth labels.
        pred: Predicted scores, as an array or any array-like (lists are
            accepted).

    Returns:
        Dict with keys ``threshold``, ``acc``, ``roc_auc``, ``mcc``, ``f1``,
        ``pr_auc`` and ``cls_report`` (a text report).
    """
    # Elementwise '>' below (and inside find_opt_threshold) needs an array;
    # a plain Python list would raise TypeError.
    pred = np.asarray(pred)
    threshold = find_opt_threshold(target, pred)
    pred_threshold = np.where(pred > threshold, 1, 0)
    return {
        'threshold': threshold,
        'acc': accuracy_score(target, pred_threshold),
        'roc_auc': roc_auc_score(target, pred),
        'mcc': matthews_corrcoef(target, pred_threshold),
        'f1': f1_score(target, pred_threshold),
        'pr_auc': average_precision_score(target, pred),
        'cls_report': classification_report(target, pred_threshold),
    }
| |
|