File size: 9,425 Bytes
ef814bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.data.dataset import Dataset
from anndata import AnnData
import pandas as pd
import random
import numpy as np

def get_mlm_loaders(train_data, val_data, batch_size=32, batch_key='batch_no', data_dtype=torch.float32):
    """Build the training and validation ``DataLoader`` objects for MLM.

    Both inputs must be of the same kind: either two ``AnnData`` objects, or
    two tuples whose first item is a ``pd.DataFrame`` of features and whose
    second item holds the batch index of every row.

    Args:
        train_data: Training split (``AnnData`` or ``(pd.DataFrame, batches)``).
        val_data: Validation split, same kind as ``train_data``.
        batch_size: Number of samples per mini-batch for both loaders.
        batch_key: Column of ``.obs`` holding the batch index (AnnData only).
        data_dtype: Torch dtype used for the feature tensors.

    Returns:
        ``(train_loader, val_loader)``. Each loader yields
        ``(features, batch_index)`` pairs; only the training loader shuffles.

    Raises:
        ValueError: If the inputs are neither two AnnData objects nor two
            ``(pd.DataFrame, batches)`` tuples.
    """

    def _is_frame_pair(data):
        # A usable tuple input carries at least (features, batches).
        return (
            isinstance(data, tuple)
            and len(data) >= 2
            and isinstance(data[0], pd.DataFrame)
        )

    def _as_tensors(features, batches):
        # Sparse matrices expose ``toarray``; dense arrays are used directly.
        if hasattr(features, 'toarray'):
            features = features.toarray()
        X = torch.tensor(np.asarray(features), dtype=data_dtype)
        # Go through NumPy first: handing a pandas Series straight to
        # ``torch.tensor`` relies on positional ``series[i]`` lookups, which
        # break when the index is not a default RangeIndex.
        b = torch.tensor(np.asarray(batches, dtype=np.int64), dtype=torch.int32)
        return X, b

    # The tuple form is tested first; the two forms are mutually exclusive,
    # so the order does not affect which branch a valid input takes.
    if _is_frame_pair(train_data) and _is_frame_pair(val_data):
        X_train, b_train = _as_tensors(train_data[0].values, train_data[1])
        X_val, b_val = _as_tensors(val_data[0].values, val_data[1])
    elif isinstance(train_data, AnnData) and isinstance(val_data, AnnData):
        X_train, b_train = _as_tensors(train_data.X, train_data.obs[batch_key])
        X_val, b_val = _as_tensors(val_data.X, val_data.obs[batch_key])
    else:
        raise ValueError("Data must be an AnnData object or a tuple of (pd.DataFrame, list).")

    mlm_train_dataset = TensorDataset(X_train, b_train)
    mlm_train_loader = DataLoader(mlm_train_dataset, batch_size=batch_size, shuffle=True)

    # Validation order is kept fixed so that runs are comparable.
    mlm_val_dataset = TensorDataset(X_val, b_val)
    mlm_val_loader = DataLoader(mlm_val_dataset, batch_size=batch_size, shuffle=False)

    return mlm_train_loader, mlm_val_loader


def get_cls_dataset(data, batch_key='batch_no', label_key='label', 
                    pct_key='pct', filter_pcts=50.0, 
                    data_dtype=torch.float32):
    """Build the classification ``TensorDataset``, keeping only confident samples.

    Samples whose percentage value is not strictly greater than
    ``filter_pcts`` are dropped from every returned object.

    Args:
        data: Either an ``AnnData`` object, or a tuple
            ``(pd.DataFrame, labels, batches, pcts)``.
        batch_key: ``.obs`` column with the batch index (AnnData only).
        label_key: ``.obs`` column with the label strings (AnnData only).
        pct_key: ``.obs`` column with the percentage values (AnnData only).
        filter_pcts: Threshold; rows with ``pct <= filter_pcts`` are removed.
        data_dtype: Torch dtype used for the feature tensor.

    Returns:
        ``(dataset, pcts, feature_names)`` where ``dataset`` is a
        ``TensorDataset(X, b, y)``, ``pcts`` is the float32 tensor of the
        retained percentages and ``feature_names`` is the list of column
        (or ``var_names``) labels.

    Raises:
        ValueError: If ``data`` is neither an AnnData object nor a tuple of
            at least four items starting with a ``pd.DataFrame``.
        KeyError: If a label is not ``'reprogramming'`` or ``'dead-end'``.
    """
    # Built once instead of once per sample.
    label_codes = {'reprogramming': 1, 'dead-end': 0}

    # The tuple form is tested first; the two forms are mutually exclusive,
    # so the order does not affect which branch a valid input takes.
    if isinstance(data, tuple) and len(data) >= 4 and isinstance(data[0], pd.DataFrame):
        features = data[0].values
        label_values, batch_values, pct_values = data[1], data[2], data[3]
        feature_names = data[0].columns.tolist()
    elif isinstance(data, AnnData):
        # Sparse matrices expose ``toarray``; a dense ``X`` is used as is.
        features = data.X.toarray() if hasattr(data.X, 'toarray') else data.X
        label_values = data.obs[label_key]
        batch_values = data.obs[batch_key]
        pct_values = data.obs[pct_key]
        feature_names = data.var_names.tolist()
    else:
        raise ValueError("Data must be an AnnData object or a tuple of (pd.DataFrame, list, list, list).")

    X = torch.tensor(np.asarray(features), dtype=data_dtype)
    y = torch.tensor([label_codes[name] for name in label_values], dtype=torch.float32)
    # Convert through NumPy: passing a pandas Series directly to
    # ``torch.tensor`` depends on positional ``series[i]`` lookups, which
    # fail when the index is not a default RangeIndex.
    b = torch.tensor(np.asarray(batch_values, dtype=np.int64), dtype=torch.int32)
    pcts = torch.tensor(np.asarray(pct_values, dtype=np.float32), dtype=torch.float32)

    # One mask, applied to every aligned tensor.
    keep = pcts > filter_pcts
    X, y, b, pcts = X[keep], y[keep], b[keep], pcts[keep]

    dataset = TensorDataset(X, b, y)

    return dataset, pcts, feature_names

def _dense_row(adata, obs_name):
    """Return the ``X`` values of observation ``obs_name`` as a flat 1-D array."""
    values = adata[obs_name].X
    # Sparse matrices expose ``toarray``; dense rows are used directly. Both
    # are flattened so every sample has the same 1-D layout.
    if hasattr(values, 'toarray'):
        values = values.toarray()
    return np.asarray(values).ravel()


def get_pair_modalities(adata_rna, adata_atac, flux_df, include_unused_atacs=False, seed=42):
    """
    Pair RNA, ATAC and Flux data based on clone IDs.

    Every RNA observation is paired with one not-yet-used ATAC observation of
    the same ``clone_id``, drawn at random; if none is left, its ATAC row is
    filled with zeros. The flux row is looked up in ``flux_df`` by the RNA
    observation name.

    Args:
        adata_rna (AnnData): RNA data. ``.obs`` must provide the columns
            ``clone_id``, ``label``, ``batch_no`` and ``pct``.
        adata_atac (AnnData): ATAC data. ``.obs`` must provide ``clone_id``
            (and ``label``, ``batch_no``, ``pct`` when
            ``include_unused_atacs`` is set). It is not modified.
        flux_df (pd.DataFrame): Flux data indexed by RNA observation name.
        include_unused_atacs (bool): Also append the ATAC samples that were
            not paired with any RNA sample, with zero-filled RNA and flux rows.
        seed: Seed of the private random generator used to pick ATAC
            siblings. The module-level ``random`` state is left untouched.
    Returns:
        tuple ``(X_i, y_i, b_i, indices, pcts)``:
         - X_i: tuple ``(rna_data, atac_data, flux_data)`` of pd.DataFrame
           objects, one row per sample and one column per feature, indexed
           by the ``(rna_index, atac_index)`` pairs.
         - y_i (np.array): labels.
         - b_i (np.array): batch indices.
         - indices (pd.DataFrame): columns ``RNA`` and ``ATAC`` holding the
           observation names of each pair; a missing modality is None.
         - pcts (np.array): the ``pct`` value of each sample.
    Raises:
        KeyError: If an RNA observation name is absent from ``flux_df``.
    """

    # Map every ATAC clone ID to its observation names, in row order. The
    # map is built without writing to ``adata_atac.obs`` so the caller's
    # object is not altered.
    atac_obs = adata_atac.obs
    atac_clone_to_indices = {}
    for atac_name, clone_id in zip(atac_obs.index, atac_obs['clone_id']):
        if pd.isna(clone_id):
            # Observations without a clone ID can never be matched.
            continue
        atac_clone_to_indices.setdefault(clone_id, []).append(atac_name)

    # A single private generator, seeded once: re-seeding before every draw
    # would make all draws identical and would reset the global RNG state.
    rng = random.Random(seed)

    rna_data, atac_data, flux_data, labels, batch_ind, indices, pcts = [], [], [], [], [], [], []

    used_atac_indices = set()
    n_atac_features = adata_atac.shape[1]

    for rna_index, row in adata_rna.obs.iterrows():
        clone_id = row['clone_id']
        sibling_atac_indices = [
            idx for idx in atac_clone_to_indices.get(clone_id, [])
            if idx not in used_atac_indices
        ]

        if sibling_atac_indices:
            atac_index = rng.choice(sibling_atac_indices)
            # Each ATAC observation is paired at most once.
            used_atac_indices.add(atac_index)
            atac_sample = _dense_row(adata_atac, atac_index)
        else:
            atac_index = None
            atac_sample = np.zeros(n_atac_features)  # Fill with zeros if no ATAC pair is found

        rna_data.append(_dense_row(adata_rna, rna_index))
        atac_data.append(atac_sample)
        flux_data.append(flux_df.loc[rna_index].values)
        labels.append(row['label'])
        batch_ind.append(row['batch_no'])
        pcts.append(row['pct'])
        indices.append((rna_index, atac_index))

    if include_unused_atacs:
        n_rna_features = adata_rna.shape[1]
        n_flux_features = flux_df.shape[1]
        # Sorted so the appended rows come out in a reproducible order.
        unused_atac_indices = sorted(set(atac_obs.index) - used_atac_indices)

        for atac_index in unused_atac_indices:
            rna_data.append(np.zeros(n_rna_features))    # Fill with zeros for RNA
            atac_data.append(_dense_row(adata_atac, atac_index))
            flux_data.append(np.zeros(n_flux_features))  # Fill with zeros for flux
            labels.append(atac_obs.loc[atac_index, 'label'])
            batch_ind.append(atac_obs.loc[atac_index, 'batch_no'])
            pcts.append(atac_obs.loc[atac_index, 'pct'])
            indices.append((None, atac_index))

    rna_data = pd.DataFrame(rna_data, columns=adata_rna.var_names, index=indices)
    atac_data = pd.DataFrame(atac_data, columns=adata_atac.var_names, index=indices)
    flux_data = pd.DataFrame(flux_data, columns=flux_df.columns, index=indices)

    X_i = (rna_data, atac_data, flux_data)
    y_i = np.array(labels)
    b_i = np.array(batch_ind)
    # ``reshape`` keeps the two-column layout even when there are no samples.
    indices = pd.DataFrame(np.array(indices).reshape(-1, 2), columns=["RNA", "ATAC"])
    pcts = np.array(pcts)

    return X_i, y_i, b_i, indices, pcts

class MultiModalDataset(Dataset):
    """
    Multi-modal dataset for RNA, ATAC, and Flux data.

    RNA values are stored as ``int32`` tensors; ATAC and Flux values as
    ``float32`` tensors. Indexing returns
    ``((rna, atac, flux), batch_no, label)`` for one sample.

    Args:
        X (tuple): Tuple of (RNA, ATAC, Flux) data. Each item may
            independently be a ``pd.DataFrame`` or any array-like accepted by
            ``torch.tensor``.
        batch_no (list): Batch indices (list, array or ``pd.Series``).
        labels (list): Numeric labels (list, array or ``pd.Series``).
        df_indics: Optional object stored as is and returned by
            ``get_df_indices``.
        pcts: Optional object stored as is and returned by ``get_pcts``.
        label_names: Optional object stored as is and returned by
            ``get_label_names``.
    """

    def __init__(self, X, batch_no, labels, df_indics=None, pcts=None, label_names=None):
        # Each modality is converted on its own, so DataFrames and plain
        # arrays can be mixed within the same tuple.
        self.rna_data = self._to_tensor(X[0], torch.int32)
        self.atac_data = self._to_tensor(X[1], torch.float32)
        self.flux_data = self._to_tensor(X[2], torch.float32)

        self.batch_no = self._to_tensor(batch_no, torch.int32)
        self.labels = self._to_tensor(labels, torch.float32)
        self.df_indics = df_indics
        self.pcts = pcts
        self.label_names = label_names

    @staticmethod
    def _to_tensor(values, dtype):
        """Convert ``values`` to a tensor of ``dtype``, unwrapping pandas objects first."""
        # pandas containers are reduced to their NumPy values: torch would
        # otherwise index them positionally through their (label) index.
        if isinstance(values, (pd.DataFrame, pd.Series)):
            values = values.to_numpy()
        return torch.tensor(values, dtype=dtype)

    def __len__(self):
        """Return the number of samples (one per label)."""
        return len(self.labels)

    def get_df_indices(self):
        """Return the ``df_indics`` object given at construction."""
        return self.df_indics

    def get_pcts(self):
        """Return the ``pcts`` object given at construction."""
        return self.pcts

    def get_label_names(self):
        """Return the ``label_names`` object given at construction."""
        return self.label_names

    def __getitem__(self, idx):
        """Return ``((rna, atac, flux), batch_no, label)`` for position ``idx``."""
        sample = (self.rna_data[idx], self.atac_data[idx], self.flux_data[idx])
        return sample, self.batch_no[idx], self.labels[idx]