import torch
from utils.helpers import create_multimodal_model
from models import SingleTransformer
from utils.helpers import get_all_modalities_available_samples
from data import create_dataset
import shap

# Modality identifiers accepted by get_background_data / compute_shap_values.
_VALID_IDS = ('RNA', 'ATAC', 'Flux', 'Multi')


def filter_ds(dataset, indices):
    """Return a new MultiModalDataset containing only the rows of *dataset* at *indices*.

    Args:
        dataset: MultiModalDataset-like object exposing ``rna_data``, ``atac_data``,
            ``flux_data``, ``batch_no`` and ``labels``.
        indices: Row indices, in any form those attributes accept for indexing
            (e.g. a list of ints or a LongTensor).

    Returns:
        A new ``create_dataset.MultiModalDataset`` built from the selected rows.
    """
    rna = dataset.rna_data[indices]
    atac = dataset.atac_data[indices]
    flux = dataset.flux_data[indices]
    return create_dataset.MultiModalDataset(
        (rna, atac, flux), dataset.batch_no[indices], dataset.labels[indices]
    )


def get_background_data(id, dataset, samples=100, return_other_samples=False):
    """
    Get background data with balanced samples from each label.

    Args:
        id: Modality identifier, one of 'RNA', 'ATAC', 'Flux', 'Multi'.
            Only 'Multi' is currently implemented.
        dataset: MultiModalDataset object. For 'Multi' it is first passed through
            ``get_all_modalities_available_samples``.
        samples: Total number of background samples requested. It is split evenly
            across the distinct labels (``samples // n_labels`` per label), so fewer
            rows are returned if a label has too few samples or ``samples`` is not
            a multiple of the label count.
        return_other_samples: If True, also return the non-background samples.

    Returns:
        bg_ds: MultiModalDataset with the background samples.
        background_indices: LongTensor indices of the background samples.
        other_ds: MultiModalDataset with all remaining samples (only if
            ``return_other_samples``).
        other_indices: LongTensor indices of the remaining samples, in ascending
            order (only if ``return_other_samples``).

    Indices refer to positions in the dataset returned by
    ``get_all_modalities_available_samples``, not necessarily in the *dataset*
    argument.

    Raises:
        ValueError: if *id* is not a recognised identifier, or is not 'Multi'.
    """
    if id not in _VALID_IDS:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")
    if id != 'Multi':
        raise ValueError("Not Implemented")

    dataset = get_all_modalities_available_samples(dataset)
    labels = dataset.labels

    # Take the first `samples_per_label` rows of every label for a balanced background.
    unique_labels = torch.unique(labels)
    samples_per_label = samples // len(unique_labels)
    background_indices = torch.cat(
        [torch.where(labels == label)[0][:samples_per_label] for label in unique_labels]
    )
    bg_ds = filter_ds(dataset, background_indices)

    if not return_other_samples:
        return bg_ds, background_indices

    # Boolean mask instead of `i not in tensor` per row (which was O(n * k)).
    # torch.where always yields an int64 tensor, even when no rows remain.
    is_other = torch.ones(len(labels), dtype=torch.bool, device=labels.device)
    is_other[background_indices] = False
    other_indices = torch.where(is_other)[0]
    other_ds = filter_ds(dataset, other_indices)
    return bg_ds, background_indices, other_ds, other_indices


class ShapWrapper(torch.nn.Module):
    """Adapt a multimodal model to the single-tensor interface SHAP explainers use.

    ``forward`` takes one 2-D tensor laid out as
    ``[rna (rna_dim cols) | atac (atac_dim cols) | flux (remaining cols) | batch_no (1 col)]``.
    This is the layout built by ``compute_shap_values``. The wrapper splits the tensor
    back into ``(rna, atac, flux)`` plus batch ids and returns sigmoid probabilities.

    Args:
        model: Module called as ``model((rna, atac, flux), batch)`` whose return value
            is a 2-tuple with the logits first.
        rna_dim: Number of leading RNA columns (default 944).
        atac_dim: Number of ATAC columns following the RNA block (default 883).
    """

    def __init__(self, model, rna_dim=944, atac_dim=883):
        super().__init__()
        self.model = model
        self.rna_dim = rna_dim
        self.atac_dim = atac_dim
        self.model.eval()

    def forward(self, x):
        # Only the last column is the batch id. The previous `x[:, :-2]` also
        # dropped the final flux feature.
        features = x[:, :-1]
        # No squeeze: x[:, -1] is already 1-D, and squeezing made it 0-d for N == 1.
        b = x[:, -1].long()
        rna_end = self.rna_dim
        atac_end = self.rna_dim + self.atac_dim
        # NOTE(review): `.long()` is not differentiable, so gradient-based SHAP values
        # for the RNA columns come out as zero — confirm this is intended.
        inputs = (
            features[:, :rna_end].long(),
            features[:, rna_end:atac_end].float(),
            features[:, atac_end:].float(),
        )
        logits, _ = self.model(inputs, b)
        return torch.sigmoid(logits)


def _stack_inputs(ds, device):
    """Concatenate a dataset's rna/atac/flux features plus its batch id column, on *device*."""
    x = torch.cat([ds.rna_data, ds.atac_data, ds.flux_data], dim=1).to(device)
    b = ds.batch_no.to(device)
    return torch.cat([x, b[..., None]], dim=-1)


def _load_fold_model(id, model_config, device, model_path):
    """Build the architecture for *id* and load a fold's checkpoint into it (eval mode)."""
    if id == 'Multi':
        model = create_multimodal_model(model_config, device, use_mlm=False)
    else:
        model = SingleTransformer(id=id, **model_config).to(device)
    # map_location lets checkpoints saved on GPU load on CPU-only hosts.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model


def compute_shap_values(id, fold_results, dataset, model_config, device):
    """Compute GradientExplainer SHAP values for each cross-validation fold.

    A balanced background of 50 samples is drawn once via ``get_background_data``.
    For each fold, the fold's validation samples that are not in the background
    are explained with that fold's best checkpoint.

    Args:
        id: Modality identifier, one of 'RNA', 'ATAC', 'Flux', 'Multi'. Only 'Multi'
            is currently supported, because ``get_background_data`` only implements it.
        fold_results: Iterable of dicts with keys 'val_idx', 'fold' and
            'best_model_path'.
        dataset: MultiModalDataset to explain.
        model_config: Model configuration forwarded to the model constructor.
        device: Torch device the model and tensors are moved to.

    Returns:
        List with one entry per non-skipped fold: the result of calling the
        ``shap.GradientExplainer`` on that fold's validation inputs.

    Raises:
        ValueError: if *id* is invalid or not yet implemented.
    """
    if id not in _VALID_IDS:
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    # Raises ValueError("Not Implemented") for non-'Multi' ids. Previously those ids
    # crashed later with an UnboundLocalError on bg_ds / other_idx.
    bg_ds, bg_idx, _other_ds, other_idx = get_background_data(
        id, dataset, samples=50, return_other_samples=True
    )
    print("total background samples: ", len(bg_idx), "total test samples: ", len(other_idx))

    # The background is the same for every fold, so build it once.
    bgx = _stack_inputs(bg_ds, device)
    # A set of ints gives O(1) membership instead of a tensor scan per element.
    other_idx_set = set(other_idx.tolist())

    all_shap_values = []
    for fold in fold_results:
        # NOTE(review): fold['val_idx'] presumably indexes `dataset`, while other_idx
        # indexes the dataset returned by get_all_modalities_available_samples. These
        # only coincide if that helper preserves row positions — verify.
        val_idx = [int(i) for i in fold['val_idx'] if int(i) in other_idx_set]
        if not val_idx:
            print('No samples of the specified type in the validation set. Skipping...')
            continue
        print(f'fold {fold["fold"]} -> {len(val_idx)} samples')

        val_ds = filter_ds(dataset, val_idx)
        model = _load_fold_model(id, model_config, device, fold['best_model_path'])
        explainer = shap.GradientExplainer(ShapWrapper(model).to(device), bgx)
        all_shap_values.append(explainer(_stack_inputs(val_ds, device)))
    return all_shap_values