Spaces:
Running
Running
| import torch | |
| from utils.helpers import create_multimodal_model | |
| from models import SingleTransformer | |
| from utils.helpers import get_all_modalities_available_samples | |
| from data import create_dataset | |
| import shap | |
def filter_ds(dataset, indices):
    """Build a new MultiModalDataset restricted to the given sample indices.

    Args:
        dataset: Source dataset exposing ``rna_data``, ``atac_data``,
            ``flux_data``, ``batch_no`` and ``labels``; every one of these
            fields is indexed with ``indices``.
        indices: Index object applied to each field (for example the list of
            ints that ``compute_shap_values`` passes).

    Returns:
        A ``create_dataset.MultiModalDataset`` built from the selected rows.
    """
    # Modalities are selected in the (RNA, ATAC, Flux) order that the
    # constructor call expects as its first argument.
    modalities = tuple(
        field[indices]
        for field in (dataset.rna_data, dataset.atac_data, dataset.flux_data)
    )
    batch_subset = dataset.batch_no[indices]
    label_subset = dataset.labels[indices]
    return create_dataset.MultiModalDataset(modalities, batch_subset, label_subset)
def get_background_data(id, dataset, samples=100, return_other_samples=False):
    """Select a label-balanced background subset of ``dataset``.

    Only ``id == 'Multi'`` is implemented. In that case the dataset is first
    passed through ``get_all_modalities_available_samples`` and all returned
    indices refer to that filtered dataset, not to the dataset passed in.

    Args:
        id: Model type, one of 'RNA', 'ATAC', 'Flux', 'Multi'. (The name
            shadows the builtin but is kept for existing callers.)
        dataset: MultiModalDataset object.
        samples: Total background budget. It is split with floor division
            across the distinct labels, and for each label the first
            ``samples // n_labels`` matching samples (in dataset order) are
            taken, or fewer if the label has fewer samples.
        return_other_samples: If True, also return the samples that were not
            selected as background.

    Returns:
        ``(bg_ds, background_indices)`` or, when ``return_other_samples`` is
        True, ``(bg_ds, background_indices, other_ds, other_indices)``:
            bg_ds: MultiModalDataset object with background samples.
            background_indices: Index tensor of the background samples.
            other_ds: MultiModalDataset object with the remaining samples.
            other_indices: Index tensor of the remaining samples, ascending.

    Raises:
        ValueError: If ``id`` is not a known model type, or if it is a known
            type other than 'Multi' (message "Not Implemented").
    """
    if id not in ('RNA', 'ATAC', 'Flux', 'Multi'):
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")
    if id != 'Multi':
        raise ValueError("Not Implemented")

    dataset = get_all_modalities_available_samples(dataset)
    labels = dataset.labels

    # Balance the background across labels: same quota for every label.
    unique_labels = torch.unique(labels)
    samples_per_label = samples // len(unique_labels)
    # Concatenating the index slices keeps an integer dtype even when a
    # slice (or the whole selection) is empty, so the result is always a
    # valid index tensor.
    background_indices = torch.cat([
        torch.where(labels == label)[0][:samples_per_label]
        for label in unique_labels
    ])
    bg_ds = filter_ds(dataset, background_indices)
    if not return_other_samples:
        return bg_ds, background_indices

    # Complement of the background set via a boolean mask: one pass instead
    # of scanning the background tensor once per sample.
    is_other = torch.ones(
        len(labels), dtype=torch.bool, device=background_indices.device
    )
    is_other[background_indices] = False
    other_indices = torch.nonzero(is_other).squeeze(1)
    other_ds = filter_ds(dataset, other_indices)
    return bg_ds, background_indices, other_ds, other_indices
class ShapWrapper(torch.nn.Module):
    """Adapt a ``model(inputs_tuple, batch)`` module to a single-tensor input.

    The wrapped model is called with a 3-tuple of feature tensors and a batch
    index tensor, and is expected to return a 2-tuple whose first element is
    passed through a sigmoid and returned. This wrapper accepts one 2-D
    tensor holding all of that side by side so it can be handed to SHAP.

    Expected column layout of ``x`` as consumed by ``forward``:
        ``[0, rna_dim)``                     -> first input, cast to long
        ``[rna_dim, rna_dim + atac_dim)``    -> second input, cast to float
        ``[rna_dim + atac_dim, n_cols - 2)`` -> third input, cast to float
        last column                          -> batch index, cast to long

    Args:
        model: Module to wrap; it is switched to eval mode on construction.
        rna_dim: Number of leading columns that form the first input.
        atac_dim: Number of columns that form the second input.
    """

    def __init__(self, model, rna_dim=944, atac_dim=883):
        super().__init__()
        self.model = model
        self.model.eval()
        self.rna_dim = rna_dim
        self.atac_dim = atac_dim

    def forward(self, x):
        """Split ``x`` into model inputs and return sigmoid of the first model output."""
        # NOTE(review): this drops the last TWO columns although only the
        # final column is read as the batch index. compute_shap_values in
        # this file appends a single batch column, so the column at index -2
        # is silently discarded. Kept as-is; confirm whether ``:-1`` was
        # intended.
        features = x[:, :-2]
        # No squeeze here: x[:, -1] is already 1-D, and squeezing it would
        # collapse a single-row input to a 0-d tensor.
        batch_ids = x[:, -1].long()

        rna_end = self.rna_dim
        atac_end = self.rna_dim + self.atac_dim
        inputs = (
            features[:, :rna_end].long(),
            features[:, rna_end:atac_end].float(),
            features[:, atac_end:].float(),
        )
        preds, _ = self.model(inputs, batch_ids)
        return torch.sigmoid(preds)
def _stack_for_shap(ds, device):
    """Concatenate a dataset's three modalities and batch ids into one 2-D tensor on ``device``."""
    features = torch.cat([ds.rna_data, ds.atac_data, ds.flux_data], dim=1).to(device)
    batch_ids = ds.batch_no.to(device)
    # The batch id becomes the last column; ShapWrapper reads it back from there.
    return torch.cat([features, batch_ids[..., None]], dim=-1)


def compute_shap_values(id, fold_results, dataset, model_config, device):
    """Compute SHAP explanations for each fold's validation samples.

    A background set is drawn once with ``get_background_data`` (50 samples,
    label-balanced). For every fold, the validation indices that are not
    part of the background are explained with a ``shap.GradientExplainer``
    built around that fold's saved model.

    Args:
        id: Model type, one of 'RNA', 'ATAC', 'Flux', 'Multi'. (The name
            shadows the builtin but is kept for existing callers.)
        fold_results: Iterable of dicts with keys ``'val_idx'``, ``'fold'``
            and ``'best_model_path'``.
        dataset: Dataset passed to ``get_background_data`` and ``filter_ds``.
        model_config: Configuration forwarded to the model factory.
        device: Device the model and tensors are moved to; also used as
            ``map_location`` when loading checkpoints.

    Returns:
        A list with one entry per fold that had samples to explain, each
        being the object returned by calling the explainer.

    Raises:
        ValueError: If ``id`` is not a known model type, or (raised by
            ``get_background_data``) if it is a type other than 'Multi'.
    """
    if id not in ('RNA', 'ATAC', 'Flux', 'Multi'):
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    # Called for every id: get_background_data raises a clear ValueError for
    # ids it does not support, instead of the loop below failing later on
    # undefined background variables.
    bg_ds, bg_idx, _other_ds, other_idx = get_background_data(
        id, dataset, samples=50, return_other_samples=True
    )
    print("total background samples: ", len(bg_idx), "total test samples: ", len(other_idx))

    # Loop-invariant: the background tensor does not depend on the fold.
    background = _stack_for_shap(bg_ds, device)
    # Set membership is O(1); testing against the tensor scans it each time.
    held_out = set(other_idx.tolist())

    all_shap_values = []
    for fold in fold_results:
        # Keep only validation samples that are not in the background set.
        # NOTE(review): other_idx indexes the dataset after the filtering
        # done inside get_background_data, while val_idx is applied to the
        # unfiltered ``dataset`` below. Confirm both share one index space.
        val_idx = [i for i in fold['val_idx'] if int(i) in held_out]
        if not val_idx:
            print('No samples of the specified type in the validation set. Skipping...')
            continue
        print(f'fold {fold["fold"]} -> {len(val_idx)} samples')
        val_ds = filter_ds(dataset, val_idx)

        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            # Only reachable once get_background_data supports this id.
            model = SingleTransformer(id=id, **model_config).to(device)
        # map_location lets a checkpoint saved on another device load here.
        state_dict = torch.load(fold['best_model_path'], map_location=device)
        model.load_state_dict(state_dict)
        model.eval()

        wrapped_model = ShapWrapper(model).to(device)
        explainer = shap.GradientExplainer(wrapped_model, background)
        all_shap_values.append(explainer(_stack_for_shap(val_ds, device)))
    return all_shap_values