Spaces:
Running
Running
File size: 9,425 Bytes
ef814bf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 | import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.data.dataset import Dataset
from anndata import AnnData
import pandas as pd
import random
import numpy as np
def get_mlm_loaders(train_data, val_data, batch_size=32, batch_key='batch_no', data_dtype=torch.float32):
    """
    Build the training and validation DataLoaders for masked-language-model style training.

    Args:
        train_data: Either an AnnData object or a tuple ``(pd.DataFrame, batch_indices)``.
        val_data: Same kind of object as ``train_data``.
        batch_size (int): Batch size used by both loaders.
        batch_key (str): Column of ``.obs`` holding the batch index (AnnData input only).
        data_dtype (torch.dtype): dtype of the feature tensor.
    Returns:
        tuple: ``(train_loader, val_loader)``. Each loader yields ``(X, batch)`` pairs;
        the training loader shuffles, the validation loader does not.
    Raises:
        ValueError: If the two inputs are not both AnnData objects or both
            ``(pd.DataFrame, batch_indices)`` tuples.
    """

    def _is_frame_pair(data):
        # A usable tuple input carries a DataFrame plus its batch indices.
        return isinstance(data, tuple) and len(data) >= 2 and isinstance(data[0], pd.DataFrame)

    def _dense(matrix):
        # .X may be a scipy sparse matrix or already a dense array.
        return matrix.toarray() if hasattr(matrix, 'toarray') else np.asarray(matrix)

    def _plain(values):
        # torch.tensor indexes unknown sequences with integer keys (values[0]),
        # which is label-based on a pandas Series and breaks when the index is
        # not 0..n-1 (e.g. cell barcodes). Hand torch a plain list instead.
        return values.tolist() if isinstance(values, (pd.Series, pd.Index)) else values

    if _is_frame_pair(train_data) and _is_frame_pair(val_data):
        splits = [(data[0].values, data[1]) for data in (train_data, val_data)]
    elif isinstance(train_data, AnnData) and isinstance(val_data, AnnData):
        splits = [(_dense(data.X), data.obs[batch_key]) for data in (train_data, val_data)]
    else:
        raise ValueError("Data must be an AnnData object or a tuple of (pd.DataFrame, list).")

    (X_train, b_train), (X_val, b_val) = [
        (torch.tensor(values, dtype=data_dtype),
         torch.tensor(_plain(batches), dtype=torch.int32))
        for values, batches in splits
    ]

    mlm_train_dataset = TensorDataset(X_train, b_train)
    mlm_train_loader = DataLoader(mlm_train_dataset, batch_size=batch_size, shuffle=True)
    mlm_val_dataset = TensorDataset(X_val, b_val)
    mlm_val_loader = DataLoader(mlm_val_dataset, batch_size=batch_size, shuffle=False)
    return mlm_train_loader, mlm_val_loader
def get_cls_dataset(data, batch_key='batch_no', label_key='label',
                    pct_key='pct', filter_pcts=50.0,
                    data_dtype=torch.float32, label_map=None):
    """
    Build the classification dataset, keeping only samples whose percentage
    value is strictly greater than ``filter_pcts``.

    Args:
        data: Either an AnnData object or a tuple
            ``(pd.DataFrame, labels, batch_indices, pcts)``.
        batch_key (str): ``.obs`` column with the batch index (AnnData input only).
        label_key (str): ``.obs`` column with the label name (AnnData input only).
        pct_key (str): ``.obs`` column with the percentage value (AnnData input only).
        filter_pcts (float): Samples with a percentage ``<= filter_pcts`` are dropped.
        data_dtype (torch.dtype): dtype of the feature tensor.
        label_map (dict | None): Mapping from label name to numeric class.
            Defaults to ``{'reprogramming': 1, 'dead-end': 0}``.
    Returns:
        tuple:
            - dataset (TensorDataset): yields ``(X, batch, label)`` per sample.
            - pcts (torch.Tensor): percentages of the kept samples.
            - feature_names (list): column / variable names of ``X``.
    Raises:
        ValueError: If ``data`` is neither of the supported input kinds.
        KeyError: If a label is not present in ``label_map``.
    """
    if label_map is None:
        label_map = {'reprogramming': 1, 'dead-end': 0}

    def _plain(values):
        # torch.tensor indexes unknown sequences with integer keys (values[0]),
        # which is label-based on a pandas Series and breaks when the index is
        # not 0..n-1 (e.g. cell barcodes). Hand torch a plain list instead.
        return values.tolist() if isinstance(values, (pd.Series, pd.Index)) else values

    if isinstance(data, tuple) and isinstance(data[0], pd.DataFrame):
        frame, raw_labels, raw_batches, raw_pcts = data[0], data[1], data[2], data[3]
        values = frame.values
        feature_names = frame.columns.tolist()
    elif isinstance(data, AnnData):
        # .X may be a scipy sparse matrix or already a dense array.
        values = data.X.toarray() if hasattr(data.X, 'toarray') else np.asarray(data.X)
        raw_labels = data.obs[label_key]
        raw_batches = data.obs[batch_key]
        raw_pcts = data.obs[pct_key]
        feature_names = data.var_names.tolist()
    else:
        raise ValueError("Data must be an AnnData object or a tuple of (pd.DataFrame, list, list, list).")

    X = torch.tensor(values, dtype=data_dtype)
    y = torch.tensor([label_map[name] for name in raw_labels], dtype=torch.float32)
    b = torch.tensor(_plain(raw_batches), dtype=torch.int32)
    pcts = torch.tensor(_plain(raw_pcts), dtype=torch.float32)

    # Compute the mask once and apply it to every per-sample tensor.
    keep = pcts > filter_pcts
    dataset = TensorDataset(X[keep], b[keep], y[keep])
    return dataset, pcts[keep], feature_names
def get_pair_modalities(adata_rna, adata_atac, flux_df, include_unused_atacs=False, seed=42):
    """
    Pair RNA, ATAC and Flux data based on clone IDs.

    Every RNA sample is paired with one not-yet-used ATAC sample that shares its
    ``clone_id``; when none is available the ATAC row is filled with zeros.
    The inputs are not modified and the global ``random`` state is not touched.

    Args:
        adata_rna (AnnData): RNA data. ``.obs`` must provide 'clone_id', 'label',
            'batch_no' and 'pct'.
        adata_atac (AnnData): ATAC data. ``.obs`` must provide 'clone_id' (and
            'label', 'batch_no', 'pct' when ``include_unused_atacs`` is set).
        flux_df (pd.DataFrame): Flux data, looked up by RNA observation name.
        include_unused_atacs (bool): Include ATAC samples that do not have a paired RNA sample
            (their RNA and flux rows are filled with zeros).
        seed: Seed of the generator used to pick among candidate ATAC samples.
    Returns:
        tuple:
            - X_i (tuple): ``(rna_data, atac_data, flux_data)``, three pd.DataFrames
              with one row per paired sample, indexed by the (RNA, ATAC) name pairs.
            - y_i (np.array): labels.
            - b_i (np.array): batch indices.
            - indices (pd.DataFrame): columns "RNA" and "ATAC" with the names of the
              matched samples; the value is None where a modality has no match.
            - pcts (np.array): the 'pct' value of each paired sample.
    Raises:
        KeyError: If an RNA observation name is missing from ``flux_df``.
    """

    def _flat_row(adata, obs_name):
        # One observation as a flat 1-D array, whether .X is sparse or dense.
        values = adata[obs_name].X
        if hasattr(values, 'toarray'):
            values = values.toarray()
        return np.asarray(values).flatten()

    # Map each ATAC clone ID to the observation names carrying it (in row order),
    # without adding a helper column to the caller's adata_atac.obs.
    atac_clone_to_indices = {}
    for atac_name, atac_clone in adata_atac.obs['clone_id'].items():
        if pd.isna(atac_clone):
            # Samples without a clone ID can never be matched.
            continue
        atac_clone_to_indices.setdefault(atac_clone, []).append(atac_name)

    rna_data, atac_data, flux_data, labels, batch_ind, indices, pcts = [], [], [], [], [], [], []
    used_atac_indices = set()

    for rna_index, row in adata_rna.obs.iterrows():
        sibling_atac_indices = [
            idx for idx in atac_clone_to_indices.get(row['clone_id'], [])
            if idx not in used_atac_indices
        ]
        rna_sample = _flat_row(adata_rna, rna_index)
        if sibling_atac_indices:
            # A fresh, identically seeded generator per draw reproduces the
            # historical pairing exactly while leaving the module-level
            # `random` state alone.
            atac_index = random.Random(seed).choice(sibling_atac_indices)
            used_atac_indices.add(atac_index)
            atac_sample = _flat_row(adata_atac, atac_index)
        else:
            atac_index = None
            atac_sample = np.zeros(adata_atac.shape[1])  # Fill with zeros if no ATAC pair is found

        rna_data.append(rna_sample)
        atac_data.append(atac_sample)
        flux_data.append(flux_df.loc[rna_index].values)
        labels.append(row['label'])
        batch_ind.append(row['batch_no'])
        pcts.append(row['pct'])
        indices.append((rna_index, atac_index))

    if include_unused_atacs:
        unused_atac_indices = sorted(set(adata_atac.obs.index) - used_atac_indices)
        for atac_index in unused_atac_indices:
            rna_data.append(np.zeros(adata_rna.shape[1]))  # Fill with zeros for RNA
            atac_data.append(_flat_row(adata_atac, atac_index))
            flux_data.append(np.zeros(flux_df.shape[1]))  # Fill with zeros for flux
            labels.append(adata_atac.obs.loc[atac_index, 'label'])
            batch_ind.append(adata_atac.obs.loc[atac_index, 'batch_no'])
            pcts.append(adata_atac.obs.loc[atac_index, 'pct'])
            indices.append((None, atac_index))

    rna_data = pd.DataFrame(rna_data, columns=adata_rna.var_names, index=indices)
    atac_data = pd.DataFrame(atac_data, columns=adata_atac.var_names, index=indices)
    flux_data = pd.DataFrame(flux_data, columns=flux_df.columns, index=indices)

    X_i = (rna_data, atac_data, flux_data)
    y_i = np.array(labels)
    b_i = np.array(batch_ind)
    indices = pd.DataFrame(np.array(indices), columns=["RNA", "ATAC"])
    pcts = np.array(pcts)
    return X_i, y_i, b_i, indices, pcts
class MultiModalDataset(Dataset):
    """
    Multi-modal dataset for RNA, ATAC, and Flux data.

    Args:
        X (tuple): Tuple of (RNA, ATAC, Flux) data. Each element may be a
            pd.DataFrame or anything ``torch.tensor`` accepts. RNA is stored as
            int32, ATAC and Flux as float32.
        batch_no (list): List of batch indices (stored as int32).
        labels (list): List of labels (stored as float32).
        df_indics: Optional bookkeeping object, returned by ``get_df_indices``.
        pcts: Optional bookkeeping object, returned by ``get_pcts``.
        label_names: Optional bookkeeping object, returned by ``get_label_names``.

    Each item is ``((rna, atac, flux), batch_no, label)``.
    """

    def __init__(self, X, batch_no, labels, df_indics=None, pcts=None, label_names=None):
        super().__init__()
        self.rna_data = self._to_tensor(X[0], torch.int32)
        self.atac_data = self._to_tensor(X[1], torch.float32)
        self.flux_data = self._to_tensor(X[2], torch.float32)
        self.batch_no = self._to_tensor(batch_no, torch.int32)
        self.labels = self._to_tensor(labels, torch.float32)
        # Stored untouched; only handed back through the getters below.
        self.df_indics = df_indics
        self.pcts = pcts
        self.label_names = label_names

    @staticmethod
    def _to_tensor(values, dtype):
        """Convert one input to a tensor, unwrapping pandas containers first."""
        if isinstance(values, pd.DataFrame):
            values = values.values
        elif isinstance(values, (pd.Series, pd.Index)):
            # torch.tensor indexes unknown sequences with integer keys, which is
            # label-based on a Series and fails for non 0..n-1 indexes.
            values = values.tolist()
        return torch.tensor(values, dtype=dtype)

    def __len__(self):
        return len(self.labels)

    def get_df_indices(self):
        return self.df_indics

    def get_pcts(self):
        return self.pcts

    def get_label_names(self):
        return self.label_names

    def __getitem__(self, idx):
        rna_sample = self.rna_data[idx]
        atac_sample = self.atac_data[idx]
        flux_sample = self.flux_data[idx]
        batch_no = self.batch_no[idx]
        label = self.labels[idx]
        return (rna_sample, atac_sample, flux_sample), batch_no, label
|