"""Helpers for input masking, embedding extraction, dataset filtering and model construction."""

import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

from models import SingleTransformer, MultiModalTransformer
import config
from data import create_dataset


def create_masked_input(input_tensor, mask_token, mask_prob=0.20):
    """
    Create a masked copy of a tensor by randomly replacing elements with a mask token.

    The input tensor itself is not modified; masking is applied to a clone.

    Args:
        input_tensor (torch.Tensor): The input tensor to be masked.
        mask_token: The value written into every masked position.
        mask_prob (float, optional): The probability of masking an element.
            Defaults to 0.20.

    Returns:
        torch.Tensor: The masked copy of ``input_tensor``.
        torch.Tensor: A boolean mask, same shape as ``input_tensor``, that is
            True where an element was replaced.
    """
    # One uniform sample per element; positions whose sample falls below
    # mask_prob are selected for masking.
    # NOTE(review): the mask is drawn on the default device, not on
    # input_tensor.device -- confirm whether callers rely on that.
    mask = torch.rand(input_tensor.shape) < mask_prob
    masked_input = input_tensor.clone()
    masked_input[mask] = mask_token
    return masked_input, mask


def get_max(adata):
    """
    Get the maximum value found across the ``.X`` matrices of several objects.

    Args:
        adata (list): A list of AnnData objects. Each ``.X`` may be anything
            exposing ``toarray()`` (e.g. a sparse matrix) or an array-like.

    Returns:
        float: The largest value over all ``.X`` matrices (a NumPy scalar).

    Raises:
        AssertionError: If ``adata`` is not a list.
        ValueError: If ``adata`` is empty (raised by ``max`` on an empty list).
    """
    # Kept as an assertion so the exception type seen by callers is unchanged.
    assert isinstance(adata, list), "adata must be a list of AnnData objects."

    per_object_max = []
    for item in adata:
        matrix = item.X
        # Densify when the matrix offers toarray(); otherwise treat it as an
        # array-like directly. No torch round trip is needed to take a max.
        if hasattr(matrix, "toarray"):
            dense = matrix.toarray()
        else:
            dense = np.asarray(matrix)
        per_object_max.append(dense.max())
    return max(per_object_max)


def get_token_embeddings(model, dataset, device):
    """
    Get the token embeddings for the dataset.

    The model is switched to eval mode (and left in that mode), then run
    without gradient tracking over the dataset in unshuffled batches of 32.
    Each batch output is moved to the CPU and all outputs are concatenated
    along dimension 0.

    Args:
        model (torch.nn.Module): Model. It is called as
            ``model(inputs, bi, return_embeddings=True)``.
        dataset (torch.utils.data.Dataset): Dataset whose batches are either
            ``(inputs, bi)`` or ``(inputs, bi, extra)``; the third element is
            ignored. ``inputs`` is either a single tensor or a list/tuple whose
            first three entries are passed on as ``(rna, atac, flux)``.
        device (str): Device to use.

    Returns:
        torch.Tensor: Embeddings.

    Raises:
        ValueError: If a batch has neither 2 nor 3 elements.
    """
    model.eval()
    embeddings = []
    loader = DataLoader(dataset, batch_size=32, shuffle=False)

    with torch.no_grad():
        for batch in loader:
            if len(batch) == 3:
                inputs, bi, _ = batch
            elif len(batch) == 2:
                inputs, bi = batch
            else:
                # Fail loudly instead of hitting an unbound local below.
                raise ValueError(
                    "Expected batches of 2 or 3 elements, got {}.".format(len(batch))
                )

            # Collated multimodal inputs may arrive as a list or a tuple.
            if isinstance(inputs, (list, tuple)):
                rna = inputs[0].to(device)
                atac = inputs[1].to(device)
                flux = inputs[2].to(device)
                inputs = (rna, atac, flux)
            else:
                inputs = inputs.to(device)
            bi = bi.to(device)

            output = model(inputs, bi, return_embeddings=True)
            embeddings.append(output.detach().cpu())

    # Concatenate embeddings across batches
    embeddings = torch.cat(embeddings, dim=0)  # shape: (n_samples, seq_len, d_model)
    return embeddings


def get_all_modalities_available_samples(dataset):
    """
    Build a dataset restricted to samples that have data in all three modalities.

    A sample is kept when its row in each of ``rna_data``, ``atac_data`` and
    ``flux_data`` contains at least one non-zero entry.

    Args:
        dataset: Object exposing ``rna_data``, ``atac_data``, ``flux_data``,
            ``batch_no`` and ``labels`` (presumably a
            ``create_dataset.MultiModalDataset`` -- confirm against callers).

    Returns:
        create_dataset.MultiModalDataset: New dataset built from the selected
            rows of the three modalities, batch numbers and labels.
    """
    rna = dataset.rna_data
    atac = dataset.atac_data
    flux = dataset.flux_data

    # Row-wise "has any signal" test per modality, combined with logical AND.
    has_rna = (rna != 0).any(axis=1)
    has_atac = (atac != 0).any(axis=1)
    has_flux = (flux != 0).any(axis=1)
    mask = has_rna & has_atac & has_flux

    new_ds = create_dataset.MultiModalDataset(
        (rna[mask], atac[mask], flux[mask]),
        dataset.batch_no[mask],
        dataset.labels[mask],
    )
    return new_ds


def separate_dataset(ds):
    """
    Separate a dataset into two groups based on the labels.

    Samples whose label is neither 0 nor 1 end up in neither output.

    Args:
        ds (TensorDataset): Dataset holding exactly three tensors, unpacked as
            ``(X, b, y)`` where ``y`` carries the labels.

    Returns:
        TensorDataset: Dataset with label 0.
        TensorDataset: Dataset with label 1.
    """
    X, b, y = ds.tensors

    # Boolean selectors for each label value.
    mask_0 = y == 0
    mask_1 = y == 1

    # Apply the same selector to all three tensors so rows stay aligned.
    dataset_0 = TensorDataset(X[mask_0], b[mask_0], y[mask_0])
    dataset_1 = TensorDataset(X[mask_1], b[mask_1], y[mask_1])
    return dataset_0, dataset_1


def create_multimodal_model(model_config, device, use_mlm=False):
    """
    Create a multimodal model.

    Three ``SingleTransformer`` encoders (RNA, ATAC, Flux) are built from their
    own configuration merged with the shared configuration, optionally
    initialised from the MLM checkpoints named in ``config``, and wrapped in a
    ``MultiModalTransformer``.

    Args:
        model_config (dict): Model configuration with the keys ``'RNA'``,
            ``'ATAC'``, ``'Flux'``, ``'Share'`` and ``'Multi'``, each mapping
            to a dict of keyword arguments.
        device (str): Device to use.
        use_mlm (bool, optional): Whether to use MLM pretraining. Defaults to False.

    Returns:
        MultiModalTransformer: Multimodal model.
    """
    model_config_rna = model_config['RNA']
    model_config_atac = model_config['ATAC']
    model_config_flux = model_config['Flux']
    share_config = model_config['Share']
    model_config_multi = model_config['Multi']

    rna_model = SingleTransformer("RNA", **model_config_rna, **share_config).to(device)
    atac_model = SingleTransformer("ATAC", **model_config_atac, **share_config).to(device)
    flux_model = SingleTransformer("Flux", **model_config_flux, **share_config).to(device)

    if use_mlm:
        # map_location lets checkpoints saved on another device (e.g. CUDA)
        # load on the requested one. strict=False tolerates missing or
        # unexpected keys in the checkpoint.
        rna_model.load_state_dict(
            torch.load(config.MLM_RNA_CKP, map_location=device), strict=False
        )
        atac_model.load_state_dict(
            torch.load(config.MLM_ATAC_CKP, map_location=device), strict=False
        )
        flux_model.load_state_dict(
            torch.load(config.MLM_FLUX_CKP, map_location=device), strict=False
        )

    model = MultiModalTransformer(
        rna_model, atac_model, flux_model, **model_config_multi
    ).to(device)
    return model