import math
import torch
from torch.nn import functional as F
import numpy as np
from data import utils as du


def calc_distogram(pos, min_bin, max_bin, num_bins):
    """Bins pairwise distances of [B, N, 3] positions into a [B, N, N, num_bins] one-hot distogram."""
    # Pairwise distances with a trailing singleton dim for broadcasting against the bin edges.
    dists_2d = torch.linalg.norm(
        pos[:, :, None, :] - pos[:, None, :, :], dim=-1)[..., None]
    # Lower bin edges; the final bin is open-ended (its upper edge is effectively infinite).
    lower = torch.linspace(
        min_bin,
        max_bin,
        num_bins,
        device=pos.device)
    upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
    dgram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype)
    return dgram


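# A minimal usage sketch for calc_distogram; the batch size, residue count, and
# bin settings below are illustrative, not taken from any training config.
#
#   ca_pos = torch.randn(2, 10, 3)   # [batch, residues, xyz]
#   dgram = calc_distogram(ca_pos, min_bin=1e-5, max_bin=20.0, num_bins=22)
#   assert dgram.shape == (2, 10, 10, 22)

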
def add_RoPE(indices):
    """Applies rotary positional embeddings (RoPE) to the input features.

    Each consecutive pair of channels is rotated by an angle that depends on
    the sequence position and the channel index.

    Args:
        indices: input features of shape [B, L, embed_size], with embed_size even.

    Returns:
        Tensor of shape [B, L, embed_size] with rotary embeddings applied.
    """
    seq_len, embed_size = indices.shape[-2:]
    seq_all = torch.arange(seq_len, device=indices.device)[:, None]
    # Per-channel rotation frequencies; channels 2i and 2i+1 share a frequency.
    theta_all = torch.pow(
        1e4, -(torch.arange(embed_size, device=indices.device) // 2) / embed_size)[None, :]
    sinusoidal_pos = (seq_all * theta_all)[None, ...]

    cos_pos = torch.cos(sinusoidal_pos)
    sin_pos = torch.sin(sinusoidal_pos)
    # Rotate each (even, odd) channel pair: (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin).
    indices_sin = torch.stack([-indices[..., 1::2], indices[..., ::2]], dim=-1)
    indices_sin = indices_sin.reshape(indices.shape)
    indices = indices * cos_pos + indices_sin * sin_pos
    return indices


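# A minimal usage sketch for add_RoPE; the feature tensor below is illustrative.
#
#   feats = torch.randn(2, 16, 64)    # [batch, length, embed_size], embed_size even
#   rotated = add_RoPE(feats)
#   assert rotated.shape == feats.shape

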
def get_index_embedding(indices, embed_size, max_len=2056):
    """Creates sine / cosine positional embeddings from prespecified indices.

    Args:
        indices: offsets of shape [..., N] of type integer.
        embed_size: dimension of the embeddings to create.
        max_len: maximum length used to scale the sinusoid frequencies.

    Returns:
        Positional embedding of shape [..., N, embed_size].
    """
    K = torch.arange(embed_size // 2, device=indices.device)
    # Standard transformer-style sinusoid frequencies, scaled by max_len.
    angles = indices[..., None] * math.pi / (max_len ** (2 * K[None] / embed_size))
    pos_embedding = torch.cat(
        [torch.sin(angles), torch.cos(angles)], dim=-1)
    return pos_embedding


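# A minimal usage sketch for get_index_embedding; the residue indices below are illustrative.
#
#   res_idx = torch.arange(10)   # [N]
#   pos_emb = get_index_embedding(res_idx, embed_size=32)
#   assert pos_emb.shape == (10, 32)

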
def get_time_embedding(timesteps, embedding_dim, max_positions=2000):
    """Creates sinusoidal embeddings for a 1-D batch of diffusion timesteps."""
    assert len(timesteps.shape) == 1
    # Timesteps are scaled up so that values in [0, 1] map onto integer-like positions.
    timesteps = timesteps * max_positions
    half_dim = embedding_dim // 2
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        # Zero-pad the last channel so odd embedding dims still match embedding_dim.
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb


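# A minimal usage sketch for get_time_embedding; the timesteps below are illustrative.
#
#   t = torch.rand(4)   # 1-D batch of timesteps
#   t_emb = get_time_embedding(t, embedding_dim=128)
#   assert t_emb.shape == (4, 128)

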
def t_stratified_loss(batch_t, batch_loss, num_bins=4, loss_name=None):
    """Stratifies loss by binning t into equal-width bins and averaging per bin."""
    batch_t = du.to_numpy(batch_t)
    batch_loss = du.to_numpy(batch_loss)
    flat_losses = batch_loss.flatten()
    flat_t = batch_t.flatten()
    bin_edges = np.linspace(0.0, 1.0 + 1e-3, num_bins+1)
    # Index of the bin that each t value falls into.
    bin_idx = np.sum(bin_edges[:, None] <= flat_t[None, :], axis=0) - 1
    t_binned_loss = np.bincount(bin_idx, weights=flat_losses)
    t_binned_n = np.bincount(bin_idx)
    stratified_losses = {}
    if loss_name is None:
        loss_name = 'loss'
    for t_bin in np.unique(bin_idx).tolist():
        bin_start = bin_edges[t_bin]
        bin_end = bin_edges[t_bin+1]
        t_range = f'{loss_name} t=[{bin_start:.2f},{bin_end:.2f})'
        range_loss = t_binned_loss[t_bin] / t_binned_n[t_bin]
        stratified_losses[t_range] = range_loss
    return stratified_losses


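# A minimal usage sketch for t_stratified_loss; the tensors below are illustrative
# and assume du.to_numpy accepts torch tensors, as the call sites here imply.
#
#   t = torch.rand(8)
#   per_example_loss = torch.rand(8)
#   binned = t_stratified_loss(t, per_example_loss, num_bins=4, loss_name='trans_loss')
#   # -> {'trans_loss t=[0.00,0.25)': ..., 'trans_loss t=[0.25,0.50)': ..., ...}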