import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean

from data.format import VOCAB
from utils import register as R
from utils.oom_decorator import oom_decorator
from utils.const import aas
from utils.nn_utils import variadic_meshgrid

from .sidechain.api import SideChainModel
from .backbone.api import BackboneModel

from ..dyMEAN.modules.am_egnn import AMEGNN
from ..dyMEAN.nn_utils import SeparatedAminoAcidFeature, ProteinFeature


def create_encoder(
    name,
    atom_embed_size,
    embed_size,
    hidden_size,
    n_channel,
    n_layers,
    dropout,
    n_rbf,
    cutoff
):
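    '''
    Build the structure encoder/decoder backbone.

    Currently only the dyMEAN encoder (AMEGNN) is supported; any other name
    raises NotImplementedError. The remaining arguments are forwarded to the
    AMEGNN constructor (embedding/hidden sizes, number of coordinate channels
    and layers, dropout, and the RBF/cutoff settings for edge features).
    '''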
    if name == 'dyMEAN':
        encoder = AMEGNN(
            embed_size, hidden_size, hidden_size, n_channel,
            channel_nf=atom_embed_size, radial_nf=hidden_size,
            in_edge_nf=0, n_layers=n_layers, residual=True,
            dropout=dropout, dense=False, n_rbf=n_rbf, cutoff=cutoff)
    else:
        raise NotImplementedError(f'Encoder {name} not implemented')

    return encoder


@R.register('AutoEncoder')
class AutoEncoder(nn.Module):
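    '''
    Variational autoencoder over the residue types and full-atom coordinates
    of the region selected by `mask`; the remaining residues act as fixed
    context.

    The encoder maps the masked region to per-residue latent vectors `H` and
    latent coordinate channels `Z`; the decoders reconstruct the sequence
    (seq_decoder + seq_linear) and the full-atom structure (sidechain_decoder),
    optionally followed by backbone/side-chain idealization at inference time.

    `mode` controls what is modeled:
      - 'codesign': both sequence and structure are encoded and reconstructed;
      - 'fixbb':    only the sequence is modeled (no latent coordinates);
      - 'fixseq':   only the structure is modeled (no latent sequence state).
    '''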

    def __init__(
        self,
        embed_size,
        hidden_size,
        latent_size,
        n_channel,
        latent_n_channel=1,
        mask_id=VOCAB.get_mask_idx(),
        latent_id=VOCAB.symbol_to_idx(VOCAB.LAT),
        max_position=2048,
        relative_position=False,
        CA_channel_idx=VOCAB.backbone_atoms.index('CA'),
        n_layers=3,
        dropout=0.1,
        mask_ratio=0.0,
        fix_alpha_carbon=False,
        h_kl_weight=0.1,
        z_kl_weight=0.5,
        coord_loss_weights={
            'Xloss': 1.0,
            'ca_Xloss': 0.0,
            'bb_bond_lengths_loss': 1.0,
            'sc_bond_lengths_loss': 1.0,
            'bb_dihedral_angles_loss': 0.0,
            'sc_chi_angles_loss': 0.5
        },
        coord_loss_ratio=0.5,
        coord_prior_var=1.0,
        anchor_at_ca=False,
        share_decoder=False,
        n_rbf=0,
        cutoff=0,
        encoder='dyMEAN',
        mode='codesign',
        additional_noise_scale=0.0
    ) -> None:
        super().__init__()
        self.mask_id = mask_id
        self.latent_id = latent_id
        self.ca_channel_idx = CA_channel_idx
        self.n_channel = n_channel
        self.mask_ratio = mask_ratio
        self.fix_alpha_carbon = fix_alpha_carbon
        self.h_kl_weight = h_kl_weight
        self.z_kl_weight = z_kl_weight
        self.coord_loss_weights = coord_loss_weights
        self.coord_loss_ratio = coord_loss_ratio
        self.mode = mode
        self.latent_size = 0 if self.mode == 'fixseq' else latent_size
        self.latent_n_channel = 0 if self.mode == 'fixbb' else latent_n_channel
        self.anchor_at_ca = anchor_at_ca
        self.coord_prior_var = coord_prior_var
        self.additional_noise_scale = additional_noise_scale

        if self.fix_alpha_carbon:
            assert self.latent_n_channel == 1, 'fix_alpha_carbon uses CA as the latent coordinate, so the number of latent channels must be 1'
        if self.anchor_at_ca:
            assert self.latent_n_channel == 1, 'anchor_at_ca requires the number of latent channels to be 1'
        if self.mode == 'fixseq':
            assert self.coord_loss_ratio == 1.0, f'fixseq mode requires a coordinate loss ratio of 1.0, got {self.coord_loss_ratio}'
        if self.mode == 'fixbb':
            assert self.coord_loss_ratio == 0.0, f'fixbb mode requires a coordinate loss ratio of 0.0, got {self.coord_loss_ratio}'

        atom_embed_size = embed_size // 4
        self.aa_feature = SeparatedAminoAcidFeature(
            embed_size, atom_embed_size,
            max_position=max_position,
            relative_position=relative_position,
            fix_atom_weights=True
        )
        self.protein_feature = ProteinFeature()

        self.encoder = create_encoder(
            name=encoder,
            atom_embed_size=atom_embed_size,
            embed_size=embed_size,
            hidden_size=hidden_size,
            n_channel=n_channel,
            n_layers=n_layers,
            dropout=dropout,
            n_rbf=n_rbf,
            cutoff=cutoff
        )

        if self.mode != 'fixbb':
            self.sidechain_decoder = create_encoder(
                name=encoder,
                atom_embed_size=atom_embed_size,
                embed_size=embed_size,
                hidden_size=hidden_size,
                n_channel=n_channel,
                n_layers=n_layers,
                dropout=dropout,
                n_rbf=n_rbf,
                cutoff=cutoff
            )
            self.backbone_model = BackboneModel()
            self.sidechain_model = SideChainModel()
            self.W_Z_log_var = nn.Linear(hidden_size, latent_n_channel * 3)

        if self.mode != 'fixseq':
            self.W_mean = nn.Linear(hidden_size, latent_size)
            self.W_log_var = nn.Linear(hidden_size, latent_size)

            self.latent2hidden = nn.Linear(latent_size, hidden_size)
            self.merge_S_H = nn.Linear(hidden_size * 2, hidden_size)

            if share_decoder:
                self.seq_decoder = self.sidechain_decoder
            else:
                self.seq_decoder = create_encoder(
                    name=encoder,
                    atom_embed_size=atom_embed_size,
                    embed_size=embed_size,
                    hidden_size=hidden_size,
                    n_channel=n_channel,
                    n_layers=n_layers,
                    dropout=dropout,
                    n_rbf=n_rbf,
                    cutoff=cutoff
                )
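
        # Mapping between the full VOCAB indices and the compact label space of
        # the sequence head: index 0 is reserved for UNK, indices 1..len(aas)
        # correspond to the amino acid types in utils.const.aas. `s_map` maps a
        # VOCAB index to a compact index, `s_remap` maps it back.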
        self.unk_idx = 0
        self.s_map = [0 for _ in range(len(VOCAB))]
        self.s_remap = [0 for _ in range(len(aas) + 1)]
        self.s_remap[0] = VOCAB.symbol_to_idx(VOCAB.UNK)
        for i, (a, _) in enumerate(aas):
            original_idx = VOCAB.symbol_to_idx(a)
            self.s_map[original_idx] = i + 1
            self.s_remap[i + 1] = original_idx
        self.s_map = nn.Parameter(torch.tensor(self.s_map, dtype=torch.long), requires_grad=False)
        self.s_remap = nn.Parameter(torch.tensor(self.s_remap, dtype=torch.long), requires_grad=False)

        if self.mode != 'fixseq':
            self.seq_linear = nn.Linear(hidden_size, len(self.s_remap))

    @torch.no_grad()
    def prepare_inputs(self, X, S, mask, atom_mask, lengths):
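        '''
        Build the graph connectivity for one batch. Residues are fully
        connected within each sample (via variadic_meshgrid); edges whose two
        endpoints have the same `mask` value are returned as context edges,
        the rest as interaction edges between the generated region and its
        context. Also returns the per-residue batch ids.
        '''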
        batch_ids = self.get_batch_ids(S, lengths)

        row, col = variadic_meshgrid(
            input1=torch.arange(batch_ids.shape[0], device=batch_ids.device),
            size1=lengths,
            input2=torch.arange(batch_ids.shape[0], device=batch_ids.device),
            size2=lengths,
        )

        is_ctx = mask[row] == mask[col]
        is_inter = ~is_ctx
        ctx_edges = torch.stack([row[is_ctx], col[is_ctx]], dim=0)
        inter_edges = torch.stack([row[is_inter], col[is_inter]], dim=0)

        return ctx_edges, inter_edges, batch_ids

    @torch.no_grad()
    def get_batch_ids(self, S, lengths):
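        '''
        Map each residue of the flattened batch to the index of the sample it
        belongs to, given the per-sample lengths.
        '''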
        batch_ids = torch.zeros_like(S)
        batch_ids[torch.cumsum(lengths, dim=0)[:-1]] = 1
        batch_ids.cumsum_(dim=0)
        return batch_ids

    def rsample(self, H, Z, Z_centers, no_randomness=False):
        '''
        Reparameterized sampling of the latent sequence state and latent
        coordinates.

        H: [N, hidden_size], hidden states of the masked residues
        Z: [N, latent_channel, 3], encoded latent coordinates
        Z_centers: [N, latent_channel, 3], anchors used as the prior mean of Z

        Returns (H_vecs, Z_vecs, H_kl_loss, Z_kl_loss); H_vecs is None in
        fixseq mode and Z_vecs is None in fixbb mode.
        '''
        if self.mode != 'fixseq':
            data_size = H.shape[0]
            H_mean = self.W_mean(H)
            H_log_var = -torch.abs(self.W_log_var(H))
            H_kl_loss = -0.5 * torch.sum(1.0 + H_log_var - H_mean * H_mean - torch.exp(H_log_var)) / data_size
            H_vecs = H_mean if no_randomness else H_mean + torch.exp(H_log_var / 2) * torch.randn_like(H_mean)
        else:
            H_vecs, H_kl_loss = None, 0

        if self.mode != 'fixbb':
            data_size = Z.shape[0]
            Z_mean_delta = Z - Z_centers
            Z_log_var = -torch.abs(self.W_Z_log_var(H)).view(-1, self.latent_n_channel, 3)
            Z_kl_loss = -0.5 * torch.sum(1.0 + Z_log_var - math.log(self.coord_prior_var) - Z_mean_delta * Z_mean_delta / self.coord_prior_var - torch.exp(Z_log_var) / self.coord_prior_var) / data_size
            Z_vecs = Z if no_randomness else Z + torch.exp(Z_log_var / 2) * torch.randn_like(Z)
        else:
            Z_vecs, Z_kl_loss = None, 0

        return H_vecs, Z_vecs, H_kl_loss, Z_kl_loss

    def _get_latent_channels(self, X, atom_mask):
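        '''
        Reduce the full-atom coordinates X [N, n_channel, 3] to the latent
        coordinate channels: the CA atom if fix_alpha_carbon is set, the
        masked mean over all atoms for 1 latent channel, or the four backbone
        atoms plus the all-atom mean for 5 latent channels.
        '''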
        atom_weights = atom_mask.float()
        if hasattr(self, 'fix_alpha_carbon') and self.fix_alpha_carbon:
            return X[:, self.ca_channel_idx].unsqueeze(1)
        elif self.latent_n_channel == 1:
            X = (X * atom_weights.unsqueeze(-1)).sum(1)
            X = X / atom_weights.sum(-1).unsqueeze(-1)
            return X.unsqueeze(1)
        elif self.latent_n_channel == 5:
            bb_X = X[:, :4]
            X = (X * atom_weights.unsqueeze(-1)).sum(1)
            X = X / atom_weights.sum(-1).unsqueeze(-1)
            X = torch.cat([bb_X, X.unsqueeze(1)], dim=1)
            return X
        else:
            raise NotImplementedError(f'Latent number of channels: {self.latent_n_channel} not implemented')

    def _get_latent_channel_anchors(self, X, atom_mask):
        if self.anchor_at_ca:
            return X[:, self.ca_channel_idx].unsqueeze(1)
        else:
            return self._get_latent_channels(X, atom_mask)

    def _fill_latent_channels(self, latent_X):
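        '''
        Inverse of _get_latent_channels: broadcast the latent coordinates back
        to all n_channel atom positions so they can be fed to the decoder.
        '''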
        if self.latent_n_channel == 1:
            return latent_X.repeat(1, self.n_channel, 1)
        elif self.latent_n_channel == 5:
            bb_X = latent_X[:, :4]
            sc_X = latent_X[:, 4].unsqueeze(1).repeat(1, self.n_channel - 4, 1)
            return torch.cat([bb_X, sc_X], dim=1)
        else:
            raise NotImplementedError(f'Latent number of channels: {self.latent_n_channel} not implemented')

    def _remove_sidechain_atom_mask(self, atom_mask, mask_generate):
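        '''
        Return a copy of atom_mask where, for the residues selected by
        mask_generate, all side-chain atom channels (index 4 onwards) are
        switched off and only the backbone atoms remain.
        '''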
        atom_mask = atom_mask.clone()
        bb_mask = atom_mask[mask_generate]
        bb_mask[:, 4:] = 0
        atom_mask[mask_generate] = bb_mask
        return atom_mask

    @torch.no_grad()
    def _mask_pep(self, S, atom_mask, mask_generate):
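        '''
        Randomly replace a fraction (mask_ratio) of the residues in the
        generated region with the mask token and drop their side-chain atoms.
        Used to corrupt the encoder input when mask_ratio > 0.
        '''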
        assert self.mask_ratio > 0
        S, atom_mask = S.clone(), atom_mask.clone()
        pep_S = S[mask_generate]
        do_mask = torch.rand_like(pep_S, dtype=torch.float) < self.mask_ratio
        pep_S[do_mask] = self.mask_id

        S[mask_generate] = pep_S
        atom_mask[mask_generate] = self._remove_sidechain_atom_mask(atom_mask[mask_generate], do_mask)

        return S, atom_mask

    def encode(self, X, S, mask, position_ids, lengths, atom_mask, no_randomness=False):
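        '''
        Encode the batch and return the latent state of the masked region:
        (latent_H, latent_X, H_kl_loss, X_kl_loss). The encoder runs on all
        residues; only the rows selected by `mask` are kept and passed to
        rsample. latent_H is None in fixseq mode and latent_X is None in
        fixbb mode.
        '''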
        true_X = X.clone()

        ctx_edges, inter_edges, batch_ids = self.prepare_inputs(X, S, mask, atom_mask, lengths)
        H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids)

        edges = torch.cat([ctx_edges, inter_edges], dim=1)
        atom_weights = atom_mask.float()

        H, pred_X = self.encoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_weights)
        H = H[mask]

        if self.mode != 'fixbb':
            if hasattr(self, 'fix_alpha_carbon') and self.fix_alpha_carbon:
                Z = self._get_latent_channels(true_X, atom_mask)
            else:
                Z = self._get_latent_channels(pred_X, atom_mask)
            Z_centers = self._get_latent_channel_anchors(true_X, atom_mask)
            Z, Z_centers = Z[mask], Z_centers[mask]
        else:
            Z, Z_centers = None, None

        latent_H, latent_X, H_kl_loss, X_kl_loss = self.rsample(H, Z, Z_centers, no_randomness)
        return latent_H, latent_X, H_kl_loss, X_kl_loss

    def decode(self, X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing):
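        '''
        Decode the latent state back to sequence logits and full-atom
        coordinates for the masked region. Masked residues are re-typed as the
        latent token and their coordinates are initialized from Z; the
        sequence decoder then predicts residue types and the side-chain
        decoder predicts coordinates, conditioned on the predicted (or
        teacher-forced) sequence. Returns (pred_S_logits, pred_X), either of
        which is None depending on the mode.
        '''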
        X, S, atom_mask = X.clone(), S.clone(), atom_mask.clone()
        true_S = S[mask].clone()
        if self.mode != 'fixbb':
            X[mask] = self._fill_latent_channels(Z)
        if self.mode != 'fixseq':
            S[mask] = self.latent_id
            H_from_latent = self.latent2hidden(H)

        if self.mode == 'fixbb':
            atom_mask = self._remove_sidechain_atom_mask(atom_mask, mask)
        elif self.mode == 'codesign':
            atom_mask[mask] = 1
        else:  # fixseq: keep the given atom mask
            pass

        ctx_edges, inter_edges, batch_ids = self.prepare_inputs(X, S, mask, atom_mask, lengths)
        edges = torch.cat([ctx_edges, inter_edges], dim=1)

        if self.mode != 'fixseq':
            H_0, (atom_embeddings, _) = self.aa_feature(S, position_ids)
            H_0 = H_0.clone()
            H_0[mask] = H_from_latent
            H, _ = self.seq_decoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_mask.float())
            pred_S_logits = self.seq_linear(H[mask])
            S = S.clone()
            if teacher_forcing:
                S[mask] = true_S
            else:
                S[mask] = self.s_remap[torch.argmax(pred_S_logits, dim=-1)]
        else:
            pred_S_logits = None

        if self.mode != 'fixbb':
            H_0, (atom_embeddings, atom_weights) = self.aa_feature(S, position_ids)
            H_0 = H_0.clone()
            if self.mode != 'fixseq':
                H_0[mask] = self.merge_S_H(torch.cat([H_from_latent, H_0[mask]], dim=-1))

            atom_mask = atom_mask.clone()
            atom_mask[mask] = atom_weights.bool()[mask] & atom_mask[mask]
            _, pred_X = self.sidechain_decoder(H_0, X, edges, channel_attr=atom_embeddings, channel_weights=atom_mask.float())
            pred_X = pred_X[mask]
        else:
            pred_X = None

        return pred_S_logits, pred_X

    @oom_decorator
    def forward(self, X, S, mask, position_ids, lengths, atom_mask, teacher_forcing=True):
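        '''
        Training objective: encode the masked region, decode it back, and
        combine sequence reconstruction (cross entropy), structure
        reconstruction (coordinate MSE plus weighted bond-length/angle terms)
        and the two KL terms, balanced by coord_loss_ratio. Returns the total
        loss together with the individual terms for logging.
        '''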
        true_X, true_S = X[mask].clone(), S[mask].clone()

        if self.mask_ratio > 0:
            input_S, input_atom_mask = self._mask_pep(S, atom_mask, mask)
        else:
            input_S, input_atom_mask = S, atom_mask
        H, Z, H_kl_loss, Z_kl_loss = self.encode(X, input_S, mask, position_ids, lengths, input_atom_mask)

        if self.mode != 'fixbb':
            coord_reg_loss = F.mse_loss(Z, self._get_latent_channel_anchors(true_X, atom_mask[mask]))
        else:
            coord_reg_loss = 0

        if self.training and Z is not None:
            # optional perturbation of the latent coordinates (Z is None in fixbb mode)
            noise = torch.randn_like(Z) * getattr(self, 'additional_noise_scale', 0.0)
            Z = Z + noise

        recon_S_logits, recon_X = self.decode(X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing)

        if self.mode != 'fixseq':
            seq_recon_loss = F.cross_entropy(recon_S_logits, self.s_map[true_S])

            with torch.no_grad():
                aar = (torch.argmax(recon_S_logits, dim=-1) == self.s_map[true_S]).sum() / len(recon_S_logits)
        else:
            seq_recon_loss, aar = 0, 1.0

        if self.mode != 'fixbb':
            xloss_mask = atom_mask[mask]
            batch_ids = self.get_batch_ids(S, lengths)[mask]
            segment_ids = torch.ones_like(true_S, device=true_S.device, dtype=torch.long)
            if self.n_channel == 4:
                loss_profile = {}
            else:
                true_struct_profile = self.protein_feature.get_struct_profile(true_X, true_S, batch_ids, self.aa_feature, segment_ids, xloss_mask)
                recon_struct_profile = self.protein_feature.get_struct_profile(recon_X, true_S, batch_ids, self.aa_feature, segment_ids, xloss_mask)
                loss_profile = {key + '_loss': F.l1_loss(recon_struct_profile[key], true_struct_profile[key]) for key in recon_struct_profile}

            xloss = F.mse_loss(recon_X[xloss_mask], true_X[xloss_mask])
            loss_profile['Xloss'] = xloss

            ca_xloss_mask = xloss_mask[:, self.ca_channel_idx]
            ca_xloss = F.mse_loss(recon_X[:, self.ca_channel_idx][ca_xloss_mask], true_X[:, self.ca_channel_idx][ca_xloss_mask])
            loss_profile['ca_Xloss'] = ca_xloss

            struct_recon_loss = 0
            for name in loss_profile:
                struct_recon_loss = struct_recon_loss + self.coord_loss_weights[name] * loss_profile[name]
        else:
            struct_recon_loss, loss_profile = 0, {}

        recon_loss = (1 - self.coord_loss_ratio) * (seq_recon_loss + self.h_kl_weight * H_kl_loss) + \
                     self.coord_loss_ratio * (struct_recon_loss + self.z_kl_weight * Z_kl_loss)

        return recon_loss, (seq_recon_loss, aar), (struct_recon_loss, loss_profile), (H_kl_loss, Z_kl_loss, coord_reg_loss)

    def _reconstruct(self, X, S, mask, position_ids, lengths, atom_mask, given_laten_H=None, given_latent_X=None, allow_unk=False, optimize_sidechain=True, idealize=False, no_randomness=False):
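        '''
        Single reconstruction pass used by `test`: encode (unless latents are
        given), decode without teacher forcing, optionally idealize the
        backbone and side chains, and compute the per-sample mean negative
        log-likelihood of the predicted residues (used as a perplexity score).
        Returns (recon_X, recon_S, batch_ppls, batch_ids).
        '''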
        if given_laten_H is None and given_latent_X is None:
            H, Z, _, _ = self.encode(X, S, mask, position_ids, lengths, atom_mask, no_randomness=no_randomness)
        else:
            H, Z = given_laten_H, given_latent_X

        recon_S_logits, recon_X = self.decode(X, S, H, Z, mask, position_ids, lengths, atom_mask, teacher_forcing=False)
        batch_ids = self.get_batch_ids(S, lengths)[mask]

        if self.mode != 'fixseq':
            if not allow_unk:
                recon_S_logits[:, 0] = float('-inf')

            recon_S = self.s_remap[torch.argmax(recon_S_logits, dim=-1)]

            snll_all = F.cross_entropy(recon_S_logits, torch.argmax(recon_S_logits, dim=-1), reduction='none')
            batch_ppls = scatter_mean(snll_all, batch_ids, dim=0)
        else:
            recon_S = S[mask]
            batch_ppls = torch.zeros(batch_ids.max() + 1, device=recon_X.device).float()

        if self.mode == 'fixseq' or (self.mode != 'fixbb' and idealize):
            recon_X = self.backbone_model(recon_X, batch_ids)
            recon_X = self.sidechain_model(recon_X, recon_S, batch_ids, optimize_sidechain)

        return recon_X, recon_S, batch_ppls, batch_ids

    @torch.no_grad()
    def test(self, X, S, mask, position_ids, lengths, atom_mask, given_laten_H=None, given_latent_X=None, return_tensor=False, allow_unk=False, optimize_sidechain=True, idealize=False, n_iter=1):
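        '''
        Inference entry point: run `_reconstruct` for n_iter rounds, feeding
        each reconstruction back in as input for the next round. If
        return_tensor is True, the raw tensors (recon_X, recon_S, batch_ppls)
        are returned; otherwise coordinates and sequences are split per
        sample, with sequences rendered as strings of residue symbols.
        '''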
        no_randomness = given_laten_H is not None
        for _ in range(n_iter):
            recon_X, recon_S, batch_ppls, batch_ids = self._reconstruct(X, S, mask, position_ids, lengths, atom_mask, given_laten_H, given_latent_X, allow_unk, optimize_sidechain, idealize, no_randomness)
            X, S = X.clone(), S.clone()
            if self.mode != 'fixbb':
                X[mask] = recon_X
            if self.mode != 'fixseq':
                S[mask] = recon_S
            given_laten_H, given_latent_X = None, None

        if return_tensor:
            return recon_X, recon_S, batch_ppls

        batch_X, batch_S = [], []
        batch_ppls = batch_ppls.tolist()
        for i, _ in enumerate(lengths):
            cur_mask = batch_ids == i
            if self.mode == 'fixbb':
                batch_X.append(None)
            else:
                batch_X.append(recon_X[cur_mask].tolist())
            if self.mode == 'fixseq':
                batch_S.append(None)
            else:
                batch_S.append(''.join([VOCAB.idx_to_symbol(s) for s in recon_S[cur_mask]]))

        return batch_X, batch_S, batch_ppls