# P2DFlow/analysis/src/models/net/denoising_ipa.py
import math
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
from src.common.all_atom import compute_backbone
from src.common.geo_utils import calc_distogram
from src.models.net.ipa import TranslationIPA
def get_positional_embedding(indices, embedding_dim, max_len=2056):
    """Creates sine / cosine positional embeddings from prespecified indices.

    Args:
        indices: integer offsets of shape [..., N].
        embedding_dim: dimension of the embeddings to create.
        max_len: maximum length.

    Returns:
        Positional embedding of shape [..., N, embedding_dim].
    """
    K = torch.arange(embedding_dim // 2, device=indices.device)
    pos_embedding_sin = torch.sin(
        indices[..., None] * math.pi / (max_len**(2 * K[None] / embedding_dim)))
    pos_embedding_cos = torch.cos(
        indices[..., None] * math.pi / (max_len**(2 * K[None] / embedding_dim)))
    pos_embedding = torch.cat(
        [pos_embedding_sin, pos_embedding_cos], dim=-1)
    return pos_embedding
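
# Shape sketch (illustrative values, not from the P2DFlow pipeline): for
# residue indices of shape [B, N], e.g.
#     idx = torch.arange(64)[None]               # [1, 64]
#     emb = get_positional_embedding(idx, 32)    # [1, 64, 32]
# the sin and cos halves each contribute embedding_dim // 2 channels.
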
def get_timestep_embedding(timesteps, embedding_dim, max_len=10000):
    """Creates sinusoidal embeddings of a 1-D batch of timesteps in [0, 1],
    scaled internally by max_len.

    Code from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
    """
    assert len(timesteps.shape) == 1
    timesteps = timesteps * max_len
    half_dim = embedding_dim // 2
    emb = math.log(max_len) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
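
# Shape sketch (illustrative values): a 1-D tensor of diffusion times in
# [0, 1] maps to a [B, embedding_dim] feature, e.g.
#     t = torch.rand(4)                          # [4]
#     emb = get_timestep_embedding(t, 32)        # [4, 32]
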
class EmbeddingModule(nn.Module):

    def __init__(self,
                 init_embed_size: int,
                 node_embed_size: int,
                 edge_embed_size: int,
                 num_bins: int = 22,
                 min_bin: float = 1e-5,
                 max_bin: float = 20.0,
                 self_conditioning: bool = True,
                 ):
        super(EmbeddingModule, self).__init__()
        pos_embed_size = init_embed_size
        t_embed_size = init_embed_size

        # time embedding (+1 for the fixed-residue mask bit)
        node_in_dim = t_embed_size + 1
        edge_in_dim = (t_embed_size + 1) * 2
        # positional embedding
        node_in_dim += pos_embed_size
        edge_in_dim += pos_embed_size

        self.node_embed = nn.Sequential(
            nn.Linear(node_in_dim, node_embed_size),
            nn.ReLU(),
            nn.Linear(node_embed_size, node_embed_size),
            nn.ReLU(),
            nn.Linear(node_embed_size, node_embed_size),
            nn.LayerNorm(node_embed_size),
        )

        # self-conditioning trick used in RFDiffusion
        self.self_conditioning = self_conditioning
        if self_conditioning:
            edge_in_dim += num_bins

        self.edge_embed = nn.Sequential(
            nn.Linear(edge_in_dim, edge_embed_size),
            nn.ReLU(),
            nn.Linear(edge_embed_size, edge_embed_size),
            nn.ReLU(),
            nn.Linear(edge_embed_size, edge_embed_size),
            nn.LayerNorm(edge_embed_size),
        )

        self.time_embed = partial(
            get_timestep_embedding, embedding_dim=t_embed_size
        )
        self.position_embed = partial(
            get_positional_embedding, embedding_dim=pos_embed_size
        )
        self.distogram_embed = partial(
            calc_distogram,
            min_bin=min_bin,
            max_bin=max_bin,
            num_bins=num_bins,
        )
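
    # Worked dimension example (hypothetical sizes): with init_embed_size=32
    # and num_bins=22,
    #     node_in_dim = 32 + 1 + 32            = 65
    #     edge_in_dim = (32 + 1) * 2 + 32 + 22 = 120
    # where the +1 is the fixed-mask bit appended to the time embedding and
    # the +22 is the self-conditioning distogram.
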
    def forward(
            self,
            residue_idx,
            t,
            fixed_mask,
            self_conditioning_ca,
    ):
        """
        Args:
            residue_idx: [..., N] positional sequence index for each residue.
            t: [B] sampled t in [0, 1].
            fixed_mask: [..., N] mask of fixed (motif) residues.
            self_conditioning_ca: [..., N, 3] Ca positions of the
                self-conditioning input.

        Returns:
            node_embed: [B, N, D_node]
            edge_embed: [B, N, N, D_edge]
        """
        B, L = residue_idx.shape
        fixed_mask = fixed_mask[..., None].float()
        node_feats = []
        pair_feats = []

        # time embedding, tiled per residue and concatenated with the fixed mask
        t_embed = torch.tile(self.time_embed(t)[:, None, :], (1, L, 1))
        t_embed = torch.cat([t_embed, fixed_mask], dim=-1)
        node_feats.append(t_embed)

        # make pair embedding from 1d time feats
        concat_1d = torch.cat(
            [torch.tile(t_embed[:, :, None, :], (1, 1, L, 1)),
             torch.tile(t_embed[:, None, :, :], (1, L, 1, 1))],
            dim=-1).float().reshape([B, L**2, -1])
        pair_feats.append(concat_1d)

        # positional embedding
        node_feats.append(self.position_embed(residue_idx))

        # relative 2d positional embedding
        rel_seq_offset = residue_idx[:, :, None] - residue_idx[:, None, :]
        rel_seq_offset = rel_seq_offset.reshape([B, L**2])
        pair_feats.append(self.position_embed(rel_seq_offset))

        # self-conditioning distogram of C-alpha atoms
        if self.self_conditioning:
            ca_dist = self.distogram_embed(self_conditioning_ca)
            pair_feats.append(ca_dist.reshape([B, L**2, -1]))

        node_embed = self.node_embed(torch.cat(node_feats, dim=-1).float())
        edge_embed = self.edge_embed(torch.cat(pair_feats, dim=-1).float())
        edge_embed = edge_embed.reshape([B, L, L, -1])
        return node_embed, edge_embed
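
# Usage sketch (hypothetical sizes; the real configuration lives elsewhere
# in P2DFlow):
#     embedder = EmbeddingModule(init_embed_size=32,
#                                node_embed_size=256,
#                                edge_embed_size=128)
#     node_embed, edge_embed = embedder(
#         residue_idx=idx,            # [B, N] integer indices
#         t=t,                        # [B] times in [0, 1]
#         fixed_mask=fixed_mask,      # [B, N]
#         self_conditioning_ca=ca,    # [B, N, 3]
#     )
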
class DenoisingNet(nn.Module):

    def __init__(self,
                 embedder: nn.Module,
                 translator: nn.Module,
                 ):
        super(DenoisingNet, self).__init__()
        self.embedder = embedder      # embedding module
        self.translator = translator  # TranslationIPA

    def forward(self, batch, as_tensor_7=False):
        """Computes the denoised frames p(X^t | X^{t+1})."""
        # Frames as [batch, res, 7] tensors (quaternion + translation)
        # when as_tensor_7 is True.
        node_mask = batch['residue_mask'].type(torch.float)  # [B, N]
        fixed_mask = batch['fixed_mask'].type(torch.float)
        edge_mask = node_mask[..., None] * node_mask[..., None, :]

        # Get embeddings.
        node_embed, edge_embed = self.embedder(
            residue_idx=batch['residue_idx'],
            t=batch['t'],
            fixed_mask=fixed_mask,
            self_conditioning_ca=batch['sc_ca_t'],
        )
        node_embed = node_embed * node_mask[..., None]  # [B, N, D]
        edge_embed = edge_embed * edge_mask[..., None]  # [B, N, N, D]

        # Translation for frames.
        model_out = self.translator(node_embed, edge_embed, batch)

        # Psi angle prediction: keep the ground-truth psi on fixed (motif)
        # residues and use the predicted psi elsewhere.
        gt_psi = batch['torsion_angles_sin_cos'][..., 2, :]
        psi_pred = gt_psi * fixed_mask[..., None] + model_out['psi'] * (1 - fixed_mask[..., None])

        rigids_pred = model_out['out_rigids']
        bb_representations = compute_backbone(
            rigids_pred, psi_pred, aatype=batch['aatype'] if 'aatype' in batch else None
        )
        atom37_pos = bb_representations[0].to(rigids_pred.device)
        atom14_pos = bb_representations[-1].to(rigids_pred.device)

        if as_tensor_7:
            rigids_pred = rigids_pred.to_tensor_7()
        return {
            'rigids': rigids_pred,
            'psi': psi_pred,
            'atom37': atom37_pos,
            'atom14': atom14_pos,
        }
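
if __name__ == '__main__':
    # Minimal smoke test for the two embedding helpers (a sketch with
    # made-up shapes; DenoisingNet itself needs a TranslationIPA instance
    # and a full feature batch, so it is not exercised here).
    idx = torch.arange(16)[None]                  # [1, 16]
    pos_emb = get_positional_embedding(idx, 32)   # [1, 16, 32]
    t = torch.rand(4)                             # [4], t in [0, 1]
    t_emb = get_timestep_embedding(t, 32)         # [4, 32]
    print(pos_emb.shape, t_emb.shape)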