import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import copy
# from utils.constants import *
# from utils.structure_utils import rigid_from_3_points
# from scipy.spatial.transform import Rotation as scipy_R
# def init_lecun_normal(module):
# def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
# normal = torch.distributions.normal.Normal(0, 1)
# alpha = (a - mu) / sigma
# beta = (b - mu) / sigma
# alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
# p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
# v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
# x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
# x = torch.clamp(x, a, b)
# return x
# def sample_truncated_normal(shape):
# stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
# return stddev * truncated_normal(torch.rand(shape))
# module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
# return module
# def init_lecun_normal_param(weight):
# def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
# normal = torch.distributions.normal.Normal(0, 1)
# alpha = (a - mu) / sigma
# beta = (b - mu) / sigma
# alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
# p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
# v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
# x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
# x = torch.clamp(x, a, b)
# return x
# def sample_truncated_normal(shape):
# stddev = np.sqrt(1.0/shape[-1])/.87962566103423978 # shape[-1] = fan_in
# return stddev * truncated_normal(torch.rand(shape))
# weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
# return weight
# # for gradient checkpointing
# def create_custom_forward(module, **kwargs):
# def custom_forward(*inputs):
# return module(*inputs, **kwargs)
# return custom_forward
# def get_clones(module, N):
# return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class Dropout(nn.Module):
    """Structured dropout: drops whole slices shared along `broadcast_dim`
    (e.g. entire rows or columns) instead of independent elements.

    Args:
        broadcast_dim: dimension along which the drop mask is shared (that
            dimension is sampled with size 1 and broadcast over x);
            None means ordinary element-wise dropout.
        p_drop: probability of zeroing an entry (mask ~ Bernoulli(1 - p_drop)).
    """
    def __init__(self, broadcast_dim=None, p_drop=0.15):
        super(Dropout, self).__init__()
        # give ones with probability of 1-p_drop / zeros with p_drop
        self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1 - p_drop]))
        self.broadcast_dim = broadcast_dim
        self.p_drop = p_drop

    def forward(self, x):
        if not self.training:  # no drophead during evaluation mode
            return x
        shape = list(x.shape)
        # fix: use `is not None` instead of `not ... == None`
        if self.broadcast_dim is not None:
            shape[self.broadcast_dim] = 1
        # sample() appends the (1,)-event dim, so view back to `shape`;
        # the mask then broadcasts against x along broadcast_dim
        mask = self.sampler.sample(shape).to(x.device).view(shape)
        # inverted dropout: rescale so the expected output equals x
        x = mask * x / (1.0 - self.p_drop)
        return x
def rbf(D, D_min=0., D_max=20., D_count=36):
    """Encode pairwise distances with Gaussian radial basis functions.

    D: (B, L, L) distance tensor.
    Returns (B, L, L, D_count), one Gaussian response per basis center.
    """
    centers = torch.linspace(D_min, D_max, D_count).to(D.device)[None, :]
    width = (D_max - D_min) / D_count
    scaled = (torch.unsqueeze(D, -1) - centers) / width
    return torch.exp(-scaled ** 2)
def get_centrality(D, dist_max):
    """Count neighbors within `dist_max` for every node.

    The `def` line had been commented out, leaving the body as orphaned
    module-level statements (NameError on D) — restored as a function.

    Args:
        D: (B, L, L) pairwise distance tensor.
        dist_max: distance threshold for two nodes to be linked.
    Returns:
        cent_vec: (B, L) int tensor, number of neighbors per node
            (self-links removed via the identity subtraction).
        link_mat: (B, L, L) int adjacency matrix, zero diagonal.
    """
    link_mat = (D < dist_max).type(torch.int).to(D.device)
    # remove self-links before counting
    link_mat -= torch.eye(link_mat.shape[1], dtype=torch.int).to(D.device)
    cent_vec = torch.sum(link_mat, dim=-1, dtype=torch.int)
    return cent_vec, link_mat
def get_sub_graph(xyz, mask=None, dist_max=20, kmin=32):
    """Select graph edges: spatially close pairs OR sequence neighbors.

    xyz: (B, L, 3) coordinates; mask: optional (B, L) validity mask.
    Returns flat source/target node indices (b*L + i) plus the raw
    (batch, row, col) index tuple.
    """
    B, L = xyz.shape[:2]
    device = xyz.device
    # pairwise distances; push the diagonal far past dist_max so a node
    # never links to itself
    dist = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0) * 10 * dist_max  # (B, L, L)
    pos = torch.arange(L, device=device)[None, :]  # (1, L)
    sep = (pos[:, None, :] - pos[:, :, None]).abs()  # (1, L, L) sequence separation
    if mask is not None:
        # invalidate pairs where either residue is masked out
        pair_mask = mask.unsqueeze(2).type(torch.float32) * mask.unsqueeze(1).type(torch.float32)  # (B, L, L)
        dist = dist + 10 * dist_max * (1 - pair_mask)
        sep = sep * pair_mask  # (B, L, L)
    near_in_seq = torch.logical_and(sep > 0, sep < kmin)
    keep = torch.logical_or(dist < dist_max, near_in_seq)
    b, i, j = torch.where(keep)  # indices of kept edges
    return b * L + i, b * L + j, (b, i, j)
def scatter_add(src, index, dim_index, num_nodes):
    """Sum rows of `src` into `num_nodes` buckets chosen by `index`.

    src: (E, C) per-edge features; index: (E,) destination node ids;
    dim_index: dimension passed through to scatter_add_ (0 for rows).
    Returns a (num_nodes, C) tensor of per-node sums.
    """
    accum = src.new_zeros(num_nodes, src.shape[1])
    dest = index.reshape(-1, 1).expand_as(src)
    return accum.scatter_add_(dim_index, dest, src)
# # Load ESM-2 model (module-level setup, currently disabled; get_pre_repr
# # needs `model`, `alphabet`, `batch_converter` in scope when called)
# model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# batch_converter = alphabet.get_batch_converter()
# model.eval()  # disables dropout for deterministic results
def get_pre_repr(seq):
    """Run one sequence through ESM-2 and return per-residue and pair features.

    The `def` line had been commented out, leaving a bare top-level `return`
    (a SyntaxError) — restored as a function.

    Args:
        seq: iterable of integer residue codes, mapped to one-letter amino
            acids via num2aa/aa_321 — presumably from the commented-out
            utils.constants import; TODO confirm.
    Returns:
        node_repr: (L, C) layer-33 per-residue representations
            (BOS/EOS tokens stripped via [1:-1]).
        pair_repr: (L, L, H) last-layer attention maps, heads moved last.
    NOTE(review): relies on module-level `model`, `alphabet`,
    `batch_converter`, `aa_321`, `num2aa`, all currently commented out —
    calling this raises NameError until the setup above is restored.
    """
    # Example inputs kept from the original (ESM batch format):
    # data = [
    #     ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    #     ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    #     ("protein2 with mask", "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    #     ("protein3", "K A <mask> I S Q"),
    # ]
    seq_string = ''.join([aa_321[num2aa[i]] for i in seq])
    data = [("protein1", seq_string)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
    # Extract per-residue representations (on CPU)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)
    # token 0 is the beginning-of-sequence token, so residues start at token 1
    node_repr = results["representations"][33][0, 1:-1, :]
    pair_repr = results['attentions'][0, -1, :, 1:-1, 1:-1].permute(1, 2, 0)
    return node_repr, pair_repr