# File size: 2,197 Bytes (revision 94391f2)
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import random
class PLEncoder(nn.Module):
    """Encode pockets by aggregating ligands of neighbouring training assays.

    For every pocket node the encoder looks up its assay id, walks that
    assay's entries in ``pocket_graph`` and collects ``(ligand_index,
    score_bucket)`` pairs, which are then handed to ``aggregator``.

    Args:
        embed_dim: Embedding width; ``linear1`` maps ``2 * embed_dim`` to
            ``embed_dim``.
        pocket_graph: Mapping ``assay_id -> iterable of (assay_id, score)``.
            Must be provided before :meth:`forward` is called.
        aggregator: Object exposing ``forward(nodes, to_neighs)`` and
            ``forward_inference(pocket_embed, neighbor_list)``.
        idx2assayid: Lookup from pocket node id to assay id.
        assayid_lst_train: Assay ids that may be used as neighbours.
        mol_smi: Collection of SMILES strings, indexable by ligand id; the
            enumeration order defines ``smi2idx``.
        train_label_lst: Records with at least an ``"assay_id"`` key and a
            ``"ligands"`` list whose first item has a ``"smi"`` key.
        cuda: Device identifier, stored as ``self.device``.
        uv: Flag stored as ``self.uv``; not read inside this class.
    """

    def __init__(self, embed_dim, pocket_graph=None, aggregator=None, idx2assayid=None, assayid_lst_train=None, mol_smi=None, train_label_lst=None, cuda="cpu", uv=True):
        super().__init__()

        # Fresh containers per instance: mutable defaults in the signature
        # would be shared by every encoder constructed without arguments.
        idx2assayid = {} if idx2assayid is None else idx2assayid
        assayid_lst_train = [] if assayid_lst_train is None else assayid_lst_train
        mol_smi = {} if mol_smi is None else mol_smi
        train_label_lst = [] if train_label_lst is None else train_label_lst

        self.uv = uv
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        self.embed_dim = embed_dim
        self.device = cuda

        self.idx2assayid = idx2assayid
        self.assayid_lst_train = assayid_lst_train
        self.mol_smi = mol_smi
        self.train_label_lst = train_label_lst

        # Position of each SMILES string in ``mol_smi``.
        self.smi2idx = {smi: idx for idx, smi in enumerate(mol_smi)}
        # Set copy of the training assays for O(1) membership tests.
        self.assayid_set_train = set(assayid_lst_train)
        # Training label records keyed by assay id.
        self.label_dicts = {x["assay_id"]: x for x in self.train_label_lst}

        # Not used by the methods of this class, but kept as a registered
        # submodule so the module's parameters / state dict are unchanged.
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Collect graph neighbours for each pocket and aggregate them.

        Args:
            nodes_pocket: Iterable of pocket node ids (keys of
                ``idx2assayid``).
            nodes_lig: Optional ligand ids paired with ``nodes_pocket``; a
                neighbour whose ligand SMILES equals the paired ligand's
                SMILES is excluded. When ``None`` nothing is excluded this
                way (a ``"----"`` placeholder is compared instead).
            max_sample: Accepted for interface compatibility; currently
                unused.

        Returns:
            Whatever ``aggregator.forward(nodes_pocket, to_neighs)`` returns,
            where ``to_neighs`` holds one list of ``(ligand_index,
            score_bucket)`` tuples per pocket.
        """
        if nodes_lig is None:
            lig_smi_lst = ["----"] * len(nodes_pocket)
        else:
            lig_smi_lst = [self.mol_smi[lig_id] for lig_id in nodes_lig]

        to_neighs = []
        for node, smi in zip(nodes_pocket, lig_smi_lst):
            assayid = self.idx2assayid[node]
            neighbors = []
            for n_assayid, score in self.pocket_graph.get(assayid, []):
                # Never use the pocket's own assay as its neighbour.
                if n_assayid == assayid:
                    continue
                # Filter non-training assays *before* touching label_dicts,
                # so a neighbour without a training record is skipped
                # instead of raising KeyError.
                if n_assayid not in self.assayid_set_train:
                    continue
                nbr_smi = self.label_dicts[n_assayid]["ligands"][0]["smi"]
                # Skip neighbours represented by the query ligand itself.
                if nbr_smi == smi:
                    continue
                # Score is shifted by 0.5, scaled by 10 and truncated to an
                # integer bucket.
                neighbors.append((self.smi2idx[nbr_smi], int((score - 0.5) * 10)))
            to_neighs.append(neighbors)

        return self.aggregator.forward(nodes_pocket, to_neighs)

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Delegate to ``aggregator.forward_inference`` and return its result."""
        return self.aggregator.forward_inference(pocket_embed, neighbor_list)
|