LigUnity / HGNN /PP_Encoder.py
Irwiny123's picture
提交LigUnity初始代码
94391f2
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import random
import copy
class PPEncoder(nn.Module):
def __init__(self, pocket_encoder, embed_dim, pocket_graph=None, aggregator=None, assayid_lst_all=[], assayid_lst_train=[], base_model=None, cuda="cpu"):
super(PPEncoder, self).__init__()
self.pocket_encoder = pocket_encoder
self.pocket_graph = pocket_graph
self.aggregator = aggregator
if base_model != None:
self.base_model = base_model
self.embed_dim = embed_dim
self.device = cuda
self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)
self.assayid_lst_all, self.assayid_set_train = assayid_lst_all, set(assayid_lst_train)
self.assayid2idxes = {}
for idx, assayid in enumerate(assayid_lst_all):
if assayid not in self.assayid2idxes:
self.assayid2idxes[assayid] = []
self.assayid2idxes[assayid].append(idx)
def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
to_neighs = []
for node in nodes_pocket:
assayid = self.assayid_lst_all[node]
neighbors = []
nbr_pockets = self.pocket_graph.get(assayid, [])
for n_assayid, score in nbr_pockets:
if n_assayid == assayid:
continue
if n_assayid not in self.assayid_set_train:
continue
neighbors.append((random.choices(self.assayid2idxes[n_assayid])[0], score))
to_neighs.append(neighbors)
neigh_feats = self.aggregator.forward(nodes_pocket, to_neighs) # user-user network
self_feats = self.pocket_encoder(nodes_pocket, nodes_lig, max_sample)
return (self_feats + neigh_feats) / 2
def refine_pocket(self, pocket_embed, neighbor_list=None):
neigh_feats = self.aggregator.forward_inference(pocket_embed, neighbor_list)
self_feats = self.pocket_encoder.refine_pocket(pocket_embed, neighbor_list)
return (self_feats + neigh_feats) / 2