File size: 2,197 Bytes
94391f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import random

class PLEncoder(nn.Module):
    """Collect graph neighbours for each pocket and delegate embedding to an aggregator.

    For every query pocket, ``forward`` looks up the pocket's assay id, walks that
    assay's entries in ``pocket_graph`` and keeps the neighbours that are
    (a) a different assay, (b) listed in ``assayid_lst_train`` and (c) not bound
    to the same ligand SMILES as the paired query ligand. Each kept neighbour is
    emitted as ``(ligand_index, score_bin)`` and the per-pocket lists are passed
    to ``aggregator.forward``.

    Args:
        embed_dim: Embedding width; used to size ``linear1``.
        pocket_graph: Mapping ``assay_id -> iterable of (neighbour_assay_id, score)``.
        aggregator: Object exposing ``forward(nodes, to_neighs)`` and
            ``forward_inference(pocket_embed, neighbor_list)``.
        idx2assayid: Mapping from pocket node key to assay id.
        assayid_lst_train: Assay ids that neighbours must belong to.
        mol_smi: Ligand SMILES, indexable by ligand id; iteration order defines
            the ligand index stored in neighbour tuples.
        train_label_lst: Records with an ``"assay_id"`` key and a ``"ligands"``
            list whose first entry carries a ``"smi"`` key.
        cuda: Device identifier, stored as ``self.device``.
        uv: Stored as ``self.uv``; not read by any method of this class.
    """

    def __init__(self, embed_dim, pocket_graph=None, aggregator=None, idx2assayid=None,
                 assayid_lst_train=None, mol_smi=None, train_label_lst=None,
                 cuda="cpu", uv=True):
        super().__init__()

        # Fresh containers per instance: the former ``{}``/``[]`` defaults were
        # evaluated once and shared by every encoder built without these args.
        idx2assayid = {} if idx2assayid is None else idx2assayid
        assayid_lst_train = [] if assayid_lst_train is None else assayid_lst_train
        mol_smi = {} if mol_smi is None else mol_smi
        train_label_lst = [] if train_label_lst is None else train_label_lst

        self.uv = uv
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        self.embed_dim = embed_dim
        self.device = cuda

        self.idx2assayid = idx2assayid
        self.assayid_lst_train = assayid_lst_train
        # SMILES -> position in mol_smi; on duplicate SMILES the last position wins.
        self.smi2idx = {smi: idx for idx, smi in enumerate(mol_smi)}
        self.mol_smi = mol_smi
        self.train_label_lst = train_label_lst
        self.assayid_set_train = set(assayid_lst_train)
        # assay_id -> its label record, for O(1) lookup of a neighbour's ligand.
        self.label_dicts = {rec["assay_id"]: rec for rec in self.train_label_lst}
        # NOTE(review): not used by any method in this class; presumably consumed
        # by a caller or subclass — confirm before removing.
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Build neighbour lists for ``nodes_pocket`` and aggregate them.

        Args:
            nodes_pocket: Sized iterable of pocket node keys into ``idx2assayid``.
            nodes_lig: Optional iterable of ligand ids (indices into ``mol_smi``),
                paired element-wise with ``nodes_pocket``. Neighbours whose ligand
                SMILES equals the paired ligand's SMILES are excluded. When
                ``None``, the placeholder ``"----"`` is used instead, so in
                practice no neighbour is excluded on that basis.
            max_sample: Currently unused; kept for interface compatibility.

        Returns:
            Whatever ``self.aggregator.forward(nodes_pocket, to_neighs)`` returns,
            where ``to_neighs[i]`` is the list of ``(ligand_index, score_bin)``
            tuples for ``nodes_pocket[i]``.

        Raises:
            KeyError: if a pocket node is missing from ``idx2assayid``, or a
                training-split neighbour has no record in ``train_label_lst``.
        """
        if nodes_lig is None:
            lig_smi_lst = ["----"] * len(nodes_pocket)
        else:
            lig_smi_lst = [self.mol_smi[lig_id] for lig_id in nodes_lig]

        to_neighs = []
        for node, smi in zip(nodes_pocket, lig_smi_lst):
            assayid = self.idx2assayid[node]
            neighbors = []
            for n_assayid, score in self.pocket_graph.get(assayid, []):
                # Filter on assay membership BEFORE reading label_dicts: that dict
                # is built from train_label_lst only, so looking up a neighbour
                # outside it first raised KeyError instead of skipping it.
                if n_assayid == assayid or n_assayid not in self.assayid_set_train:
                    continue
                nbr_smi = self.label_dicts[n_assayid]["ligands"][0]["smi"]
                if nbr_smi == smi:
                    continue
                # Integer score bin. int() truncates toward zero, so this equals
                # floor((score - 0.5) * 10) only for score >= 0.5; scores on exact
                # decimal edges (e.g. 0.7) may fall one bin low due to float error.
                neighbors.append((self.smi2idx[nbr_smi], int((score - 0.5) * 10)))
            to_neighs.append(neighbors)

        neigh_feats = self.aggregator.forward(nodes_pocket, to_neighs)
        return neigh_feats

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Delegate to ``aggregator.forward_inference(pocket_embed, neighbor_list)``."""
        return self.aggregator.forward_inference(pocket_embed, neighbor_list)