File size: 2,114 Bytes
94391f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import random
import copy

class PPEncoder(nn.Module):
    """Pocket encoder that blends a pocket's own embedding with its graph neighbors.

    Every node index is mapped to an assay id through ``assayid_lst_all``.
    Neighboring assays are looked up in ``pocket_graph`` by assay id and
    restricted to training assays. Each neighboring assay is represented by
    one node index sampled uniformly from the indices belonging to that assay.
    The returned embedding is the mean of the ``pocket_encoder`` output and
    the ``aggregator`` output.
    """

    def __init__(self, pocket_encoder, embed_dim, pocket_graph=None, aggregator=None,
                 assayid_lst_all=None, assayid_lst_train=None, base_model=None, cuda="cpu"):
        """Store the sub-components and build the assay-id -> node-index lookup.

        Args:
            pocket_encoder: callable invoked as
                ``pocket_encoder(nodes_pocket, nodes_lig, max_sample)`` by
                ``forward``; ``refine_pocket`` additionally requires it to expose
                ``refine_pocket(pocket_embed, neighbor_list)``.
            embed_dim: embedding width; sizes ``linear1``.
            pocket_graph: mapping from assay id to an iterable of
                ``(neighbor_assay_id, score)`` pairs. ``None`` means no neighbors.
            aggregator: object exposing ``forward(nodes, to_neighs)`` and
                ``forward_inference(pocket_embed, neighbor_list)``.
            assayid_lst_all: assay id of every node; list position == node index.
                Defaults to an empty list (a fresh one per instance).
            assayid_lst_train: assay ids allowed to act as neighbors.
                Defaults to empty.
            base_model: optional model; stored as ``self.base_model`` only when
                not ``None`` (the attribute is absent otherwise).
            cuda: device identifier stored as ``self.device``.
        """
        super().__init__()

        self.pocket_encoder = pocket_encoder
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        if base_model is not None:
            self.base_model = base_model
        self.embed_dim = embed_dim
        # NOTE(review): `device` is not read by this class's methods; presumably
        # used by callers or sub-modules — confirm before removing.
        self.device = cuda
        # NOTE(review): `linear1` is not used by this class's methods; presumably
        # kept so saved state_dict keys stay compatible — confirm before removing.
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)

        # None defaults avoid sharing one mutable list across instances; a
        # caller-supplied list is stored as-is (not copied), as before.
        self.assayid_lst_all = [] if assayid_lst_all is None else assayid_lst_all
        self.assayid_set_train = set(assayid_lst_train or [])

        # assay id -> every node index carrying that assay id, in index order.
        self.assayid2idxes = {}
        for idx, assayid in enumerate(self.assayid_lst_all):
            self.assayid2idxes.setdefault(assayid, []).append(idx)

    def _sample_neighbors(self, node):
        """Return ``[(sampled_node_index, score), ...]`` for the assay of ``node``.

        Self-loops, neighbors outside the training split, and training assays
        with no node index in ``assayid_lst_all`` are skipped.
        """
        assayid = self.assayid_lst_all[node]
        graph = self.pocket_graph if self.pocket_graph is not None else {}
        neighbors = []
        for n_assayid, score in graph.get(assayid, []):
            if n_assayid == assayid or n_assayid not in self.assayid_set_train:
                continue
            candidates = self.assayid2idxes.get(n_assayid)
            if not candidates:
                # Listed as a training assay but has no node to sample from.
                continue
            # Uniformly pick one node index of the neighboring assay.
            neighbors.append((random.choices(candidates)[0], score))
        return neighbors

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Return the mean of the self embedding and the aggregated neighbor embedding.

        Args:
            nodes_pocket: iterable of node indices into ``assayid_lst_all``.
            nodes_lig: passed through unchanged to ``pocket_encoder``.
            max_sample: passed through unchanged to ``pocket_encoder``.

        Returns:
            ``(pocket_encoder(...) + aggregator.forward(...)) / 2``; presumably
            a tensor of pocket embeddings — depends on the injected components.
        """
        # One neighbor list per input node, sampled in input order.
        to_neighs = [self._sample_neighbors(node) for node in nodes_pocket]

        neigh_feats = self.aggregator.forward(nodes_pocket, to_neighs)
        self_feats = self.pocket_encoder(nodes_pocket, nodes_lig, max_sample)

        return (self_feats + neigh_feats) / 2

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Refine a precomputed pocket embedding using explicit neighbors.

        Averages ``aggregator.forward_inference(pocket_embed, neighbor_list)``
        with ``pocket_encoder.refine_pocket(pocket_embed, neighbor_list)``.
        """
        neigh_feats = self.aggregator.forward_inference(pocket_embed, neighbor_list)
        self_feats = self.pocket_encoder.refine_pocket(pocket_embed, neighbor_list)
        return (self_feats + neigh_feats) / 2