import math
from functools import partial

import torch
from torch import nn
from torch.nn import functional as F

from src.common.all_atom import compute_backbone
from src.common.geo_utils import calc_distogram
from src.models.net.ipa import TranslationIPA


def get_positional_embedding(indices, embedding_dim, max_len=2056):
    """Creates sine/cosine positional embeddings from prespecified indices.

    Args:
        indices: integer offsets of shape [..., N_edges].
        embedding_dim: dimension of the embeddings to create.
        max_len: maximum length used to scale the embedding frequencies.

    Returns:
        positional embedding of shape [..., N_edges, embedding_dim]
    """
    K = torch.arange(embedding_dim // 2, device=indices.device)
    # Frequencies are shared between the sine and cosine halves.
    freqs = indices[..., None] * math.pi / (max_len ** (2 * K[None] / embedding_dim))
    pos_embedding = torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=-1)
    return pos_embedding
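
# A minimal usage sketch for the positional embedding above; the index range
# and embedding size are illustrative assumptions, not fixed by this module.
#   idx = torch.arange(8)                                  # offsets, [8]
#   emb = get_positional_embedding(idx, embedding_dim=32)  # -> [8, 32]
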

def get_timestep_embedding(timesteps, embedding_dim, max_len=10000):
    """Creates transformer-style sinusoidal embeddings of diffusion timesteps.

    Args:
        timesteps: [B] sampled diffusion times in [0, 1].
        embedding_dim: dimension of the embeddings to create.
        max_len: scale mapping t in [0, 1] to integer-like positions.

    Returns:
        timestep embedding of shape [B, embedding_dim]
    """
    assert len(timesteps.shape) == 1
    timesteps = timesteps * max_len
    half_dim = embedding_dim // 2
    emb = math.log(max_len) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        # Zero-pad the last channel when embedding_dim is odd.
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
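
# A minimal usage sketch for the timestep embedding above (sizes illustrative):
#   t = torch.rand(4)                                    # times in [0, 1], [4]
#   t_emb = get_timestep_embedding(t, embedding_dim=64)  # -> [4, 64]
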

class EmbeddingModule(nn.Module):
    """Embeds per-residue (node) and pairwise (edge) features for denoising."""

    def __init__(self,
                 init_embed_size: int,
                 node_embed_size: int,
                 edge_embed_size: int,
                 num_bins: int = 22,
                 min_bin: float = 1e-5,
                 max_bin: float = 20.0,
                 self_conditioning: bool = True,
                 ):
        super().__init__()
        pos_embed_size = init_embed_size
        t_embed_size = init_embed_size

        # Node features: timestep embedding plus the fixed-residue mask bit.
        node_in_dim = t_embed_size + 1
        # Edge features: the concatenated node features of both endpoints.
        edge_in_dim = (t_embed_size + 1) * 2

        # Positional embeddings of residue indices and relative offsets.
        node_in_dim += pos_embed_size
        edge_in_dim += pos_embed_size

        self.node_embed = nn.Sequential(
            nn.Linear(node_in_dim, node_embed_size),
            nn.ReLU(),
            nn.Linear(node_embed_size, node_embed_size),
            nn.ReLU(),
            nn.Linear(node_embed_size, node_embed_size),
            nn.LayerNorm(node_embed_size),
        )

        # Self-conditioning adds a distogram of the previous Ca prediction.
        self.self_conditioning = self_conditioning
        if self_conditioning:
            edge_in_dim += num_bins

        self.edge_embed = nn.Sequential(
            nn.Linear(edge_in_dim, edge_embed_size),
            nn.ReLU(),
            nn.Linear(edge_embed_size, edge_embed_size),
            nn.ReLU(),
            nn.Linear(edge_embed_size, edge_embed_size),
            nn.LayerNorm(edge_embed_size),
        )

        self.time_embed = partial(
            get_timestep_embedding, embedding_dim=t_embed_size
        )
        self.position_embed = partial(
            get_positional_embedding, embedding_dim=pos_embed_size
        )
        self.distogram_embed = partial(
            calc_distogram,
            min_bin=min_bin,
            max_bin=max_bin,
            num_bins=num_bins,
        )
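
    # A worked example of the input-dimension arithmetic above, for reference
    # (with init_embed_size = E and num_bins = K):
    #   node_in_dim = (E + 1) + E    = 2E + 1
    #   edge_in_dim = 2*(E + 1) + E  = 3E + 2, plus K when self-conditioning
    # e.g. E = 32, K = 22 gives node_in_dim = 65 and edge_in_dim = 120.
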
    def forward(
        self,
        residue_idx,
        t,
        fixed_mask,
        self_conditioning_ca,
    ):
        """
        Args:
            residue_idx: [B, N] positional sequence index for each residue.
            t: [B] sampled diffusion times in [0, 1].
            fixed_mask: [B, N] mask of fixed (motif) residues.
            self_conditioning_ca: [B, N, 3] Ca positions of the
                self-conditioning input.

        Returns:
            node_embed: [B, N, D_node]
            edge_embed: [B, N, N, D_edge]
        """
        B, L = residue_idx.shape
        fixed_mask = fixed_mask[..., None].float()
        node_feats = []
        pair_feats = []

        # Timestep embedding, tiled across residues and concatenated with the
        # fixed-residue mask bit.
        t_embed = torch.tile(self.time_embed(t)[:, None, :], (1, L, 1))
        t_embed = torch.cat([t_embed, fixed_mask], dim=-1)
        node_feats.append(t_embed)

        # Pairwise features: node features of both endpoints of each edge,
        # flattened to [B, L*L, -1].
        concat_1d = torch.cat(
            [torch.tile(t_embed[:, :, None, :], (1, 1, L, 1)),
             torch.tile(t_embed[:, None, :, :], (1, L, 1, 1))],
            dim=-1).float().reshape([B, L**2, -1])
        pair_feats.append(concat_1d)

        # Absolute positional embedding of each residue index.
        node_feats.append(self.position_embed(residue_idx))

        # Relative positional embedding of each pair of residue indices.
        rel_seq_offset = residue_idx[:, :, None] - residue_idx[:, None, :]
        rel_seq_offset = rel_seq_offset.reshape([B, L**2])
        pair_feats.append(self.position_embed(rel_seq_offset))

        # Self-conditioning: distogram of the previous Ca prediction.
        if self.self_conditioning:
            ca_dist = self.distogram_embed(self_conditioning_ca)
            pair_feats.append(ca_dist.reshape([B, L**2, -1]))

        node_embed = self.node_embed(torch.cat(node_feats, dim=-1).float())
        edge_embed = self.edge_embed(torch.cat(pair_feats, dim=-1).float())
        edge_embed = edge_embed.reshape([B, L, L, -1])
        return node_embed, edge_embed
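
# A minimal forward-pass sketch for EmbeddingModule. All sizes below are
# illustrative assumptions; passing zeros for self_conditioning_ca (e.g.
# before the first self-conditioning pass) is an assumption, not an API rule.
#   embedder = EmbeddingModule(init_embed_size=32, node_embed_size=256,
#                              edge_embed_size=128)
#   B, N = 2, 16
#   node, edge = embedder(
#       residue_idx=torch.arange(N)[None].repeat(B, 1),
#       t=torch.rand(B),
#       fixed_mask=torch.zeros(B, N),
#       self_conditioning_ca=torch.zeros(B, N, 3),
#   )  # node: [B, N, 256], edge: [B, N, N, 128]
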

class DenoisingNet(nn.Module):

    def __init__(self,
                 embedder: nn.Module,
                 translator: nn.Module,
                 ):
        super().__init__()
        self.embedder = embedder
        self.translator = translator

    def forward(self, batch, as_tensor_7=False):
        """Computes the denoised frames p(X^t | X^{t+1}).

        Args:
            batch: feature dict; requires 'residue_mask', 'fixed_mask',
                'residue_idx', 't', 'sc_ca_t', 'torsion_angles_sin_cos', and
                optionally 'aatype' (the translator may read further keys).
            as_tensor_7: if True, return the predicted rigids as a [..., 7]
                tensor (rotation quaternion and translation).

        Returns:
            dict with the predicted rigids, psi angles, and atom37/atom14
            backbone coordinates.
        """
        node_mask = batch['residue_mask'].float()
        fixed_mask = batch['fixed_mask'].float()
        edge_mask = node_mask[..., None] * node_mask[..., None, :]

        # Embed node and edge features, then zero out padded positions.
        node_embed, edge_embed = self.embedder(
            residue_idx=batch['residue_idx'],
            t=batch['t'],
            fixed_mask=fixed_mask,
            self_conditioning_ca=batch['sc_ca_t'],
        )
        node_embed = node_embed * node_mask[..., None]
        edge_embed = edge_embed * edge_mask[..., None]

        # Predict rigid frames and psi torsion angles.
        model_out = self.translator(node_embed, edge_embed, batch)

        # Fixed (motif) residues keep their ground-truth psi angles; only
        # diffused residues take the prediction.
        gt_psi = batch['torsion_angles_sin_cos'][..., 2, :]
        psi_pred = (
            gt_psi * fixed_mask[..., None]
            + model_out['psi'] * (1 - fixed_mask[..., None])
        )
        rigids_pred = model_out['out_rigids']

        # Reconstruct backbone atom coordinates from frames and psi angles.
        bb_representations = compute_backbone(
            rigids_pred, psi_pred, aatype=batch['aatype'] if 'aatype' in batch else None
        )
        atom37_pos = bb_representations[0].to(rigids_pred.device)
        atom14_pos = bb_representations[-1].to(rigids_pred.device)

        if as_tensor_7:
            rigids_pred = rigids_pred.to_tensor_7()

        return {
            'rigids': rigids_pred,
            'psi': psi_pred,
            'atom37': atom37_pos,
            'atom14': atom14_pos,
        }
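
# A minimal wiring sketch. The TranslationIPA constructor is not shown in this
# file, so `ipa_conf` below is a hypothetical placeholder for its arguments:
#   embedder = EmbeddingModule(init_embed_size=32, node_embed_size=256,
#                              edge_embed_size=128)
#   translator = TranslationIPA(ipa_conf)  # hypothetical config object
#   net = DenoisingNet(embedder, translator)
#   out = net(batch, as_tensor_7=True)     # batch keys as listed in forward()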