| | import argparse |
| | import os |
| | import yaml as yaml |
| | import numpy as np |
| | import random |
| | import time |
| | import datetime |
| | import json |
| | from pathlib import Path |
| | import warnings |
| |
|
| | warnings.filterwarnings("ignore") |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import DataLoader |
| | from sklearn.metrics import roc_auc_score, precision_recall_curve, accuracy_score |
| |
|
| | from models.model_MeDSLIP import MeDSLIP |
| | from dataset.dataset import Chestxray14_Dataset |
| | from models.tokenization_bert import BertTokenizer |
| |
|
| | from tqdm import tqdm |
| |
|
# ChestX-ray14 label names used for evaluation.
# Order matters: positions in this list (via chexray14_mapping) are used to
# slice the columns of sample["label"] in test(), so this order must match
# the label column order produced by Chestxray14_Dataset.
# Names are spelled to match entries of original_class so they can be looked
# up there.
chexray14_cls = [
    "atelectasis",
    "cardiomegaly",
    "effusion",
    "infiltrate",
    "mass",
    "nodule",
    "pneumonia",
    "pneumothorax",
    "consolidation",
    "edema",
    "emphysema",
    # NOTE(review): presumably stands in for CXR14 "Fibrosis" — confirm.
    "tail_abnorm_obs",
    # NOTE(review): presumably stands in for CXR14 "Pleural_Thickening" — confirm.
    "thicken",
    "hernia",
]
| |
|
# Class vocabulary of the MeDSLIP model's output, in output order.
# test() reshapes the model output to (-1, len(original_class), 2) and selects
# columns using original_class.index(...) positions, so both the length and
# the order of this list must match the checkpoint being evaluated.
# NOTE(review): the name MIMIC_mapping suggests this vocabulary comes from the
# MIMIC-CXR pretraining setup — confirm before editing.
original_class = [
    "normal",
    "clear",
    "sharp",
    "sharply",
    "unremarkable",
    "intact",
    "stable",
    "free",
    "effusion",
    "opacity",
    "pneumothorax",
    "edema",
    "atelectasis",
    "tube",
    "consolidation",
    "process",
    "abnormality",
    "enlarge",
    "tip",
    "low",
    "pneumonia",
    "line",
    "congestion",
    "catheter",
    "cardiomegaly",
    "fracture",
    "air",
    "tortuous",
    "lead",
    "disease",
    "calcification",
    "prominence",
    "device",
    "engorgement",
    "picc",
    "clip",
    "elevation",
    "expand",
    "nodule",
    "wire",
    "fluid",
    "degenerative",
    "pacemaker",
    "thicken",
    "marking",
    "scar",
    "hyperinflate",
    "blunt",
    "loss",
    "widen",
    "collapse",
    "density",
    "emphysema",
    "aerate",
    "mass",
    "crowd",
    "infiltrate",
    "obscure",
    "deformity",
    "hernia",
    "drainage",
    "distention",
    "shift",
    "stent",
    "pressure",
    "lesion",
    "finding",
    "borderline",
    "hardware",
    "dilation",
    "chf",
    "redistribution",
    "aspiration",
    "tail_abnorm_obs",
    "excluded_obs",
]
| |
|
# mapping[k] is the position of chexray14_cls[k] in the model's output
# vocabulary (original_class), or -1 if the model has no such class.
mapping = []
for disease in chexray14_cls:
    mapping.append(original_class.index(disease) if disease in original_class else -1)

# Restrict evaluation to classes present in both vocabularies:
#   MIMIC_mapping     - model output columns (indices into original_class)
#   chexray14_mapping - matching label columns (indices into chexray14_cls)
#   target_class      - names of the retained classes, in evaluation order
MIMIC_mapping = [out_pos for out_pos in mapping if out_pos != -1]
chexray14_mapping = [label_col for label_col, out_pos in enumerate(mapping) if out_pos != -1]
target_class = [chexray14_cls[label_col] for label_col in chexray14_mapping]
| |
|
| |
|
def compute_AUCs(gt, pred, n_class=None):
    """Compute the per-class area under the ROC curve (AUROC).

    Args:
        gt: torch tensor of shape [n_samples, n_classes] holding the true
            binary labels. It may live on any device; it is copied to CPU.
        pred: torch tensor of the same shape holding scores for the positive
            class (probability estimates, confidences or binary decisions).
        n_class: number of leading columns to score. Defaults to every
            column of ``gt``.

    Returns:
        List of ``n_class`` floats, one AUROC per class. A class whose ground
        truth contains a single label value (e.g. no positives) has no defined
        AUROC. It is reported as ``nan`` instead of aborting the whole
        evaluation.
    """
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    if n_class is None:
        n_class = gt_np.shape[1]

    AUROCs = []
    for i in range(n_class):
        # roc_auc_score raises ValueError when y_true holds only one class.
        if np.unique(gt_np[:, i]).size < 2:
            AUROCs.append(float("nan"))
        else:
            AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs
| |
|
| |
|
def get_tokenizer(tokenizer, target_text, max_length=64):
    """Tokenise a collection of texts into fixed-length PyTorch tensors.

    Args:
        tokenizer: callable tokenizer (e.g. a ``BertTokenizer`` instance)
            accepting a list of strings plus ``padding``/``truncation``/
            ``max_length``/``return_tensors`` keyword arguments.
        target_text: iterable of strings. It is materialised with ``list()``,
            so a plain ``str`` would be split into single characters.
        max_length: every sequence is padded or truncated to this many tokens.
            Defaults to 64, the value previously hard-coded.

    Returns:
        Whatever ``tokenizer`` returns for the batch.
    """
    return tokenizer(
        list(target_text),
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
| |
|
| |
|
def _build_test_dataloader(config):
    """Create the ChestX-ray14 test DataLoader described by ``config``."""
    test_dataset = Chestxray14_Dataset(config["test_file"], is_train=False)
    return DataLoader(
        test_dataset,
        batch_size=config["test_batch_size"],
        num_workers=30,
        pin_memory=True,
        sampler=None,
        # AUROC / F1 / accuracy do not depend on sample order, so shuffling
        # only made evaluation runs non-reproducible.
        shuffle=False,
        collate_fn=None,
        drop_last=False,
    )


def _load_model(args, config, device):
    """Build MeDSLIP with the tokenised disease book and load ``args.model_path``."""
    print("Creating book")
    with open(config["disease_book"], "r") as book_file:
        json_book = json.load(book_file)
    # Description texts in the JSON's key order.
    disease_book = list(json_book.values())
    tokenizer = BertTokenizer.from_pretrained(config["text_encoder"])
    disease_book_tokenizer = get_tokenizer(tokenizer, disease_book).to(device)

    print("Creating model")
    model = MeDSLIP(config, disease_book_tokenizer)
    if args.ddp:
        # Despite the flag name, this is single-process nn.DataParallel.
        model = nn.DataParallel(
            model, device_ids=list(range(torch.cuda.device_count()))
        )
    model = model.to(device)

    print("Load model from checkpoint:", args.model_path)
    checkpoint = torch.load(args.model_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    return model


def _predict(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and return ``(gt, pred)`` tensors.

    ``gt`` holds the ChestX-ray14 label columns selected by
    ``chexray14_mapping``. ``pred`` holds, for the matching model classes in
    ``MIMIC_mapping``, the softmax probability at index 1 of each 2-way
    output pair (presumably the "present" score — confirm against the model).
    """
    gt_batches = []
    pred_batches = []
    print("Start testing")
    model.eval()
    loop = tqdm(dataloader)
    with torch.no_grad():
        for i, sample in enumerate(loop):
            loop.set_description(f"Testing: {i+1}/{len(dataloader)}")
            label = sample["label"][:, chexray14_mapping].float().to(device)
            input_image = sample["image"].to(device, non_blocking=True)
            logits = model(input_image)
            # Softmax over each 2-logit pair, then restore (batch, classes, 2).
            probs = F.softmax(logits.reshape(-1, 2), dim=-1).reshape(
                -1, len(original_class), 2
            )
            gt_batches.append(label)
            pred_batches.append(probs[:, MIMIC_mapping, 1])
    if not gt_batches:
        raise RuntimeError("Test dataloader yielded no batches")
    # Concatenate once; growing a tensor with torch.cat each batch is quadratic.
    return torch.cat(gt_batches, 0), torch.cat(pred_batches, 0)


def _max_f1_and_accuracy(gt_col, pred_col):
    """Return ``(max_f1, accuracy)`` for one binary class.

    ``max_f1`` is the best F1 over the thresholds from
    ``precision_recall_curve``. ``accuracy`` is measured after binarising
    ``pred_col`` at the threshold that achieved that F1.
    """
    precision, recall, thresholds = precision_recall_curve(gt_col, pred_col)
    numerator = 2 * recall * precision
    denom = recall + precision
    f1_scores = np.divide(
        numerator, denom, out=np.zeros_like(denom), where=(denom != 0)
    )
    max_f1 = np.max(f1_scores)
    max_f1_thresh = thresholds[np.argmax(f1_scores)]
    # precision_recall_curve treats a score as positive when it is
    # >= threshold; use the same rule so accuracy is measured at the exact
    # operating point that produced max_f1 (">" dropped ties at the threshold).
    accuracy = accuracy_score(gt_col, pred_col >= max_f1_thresh)
    return max_f1, accuracy


def test(args, config):
    """Evaluate a MeDSLIP checkpoint zero-shot on the ChestX-ray14 test split.

    Builds the test DataLoader and the tokenised disease "book", loads the
    weights from ``args.model_path``, runs inference, and prints the mean and
    per-class AUROC followed by the mean best-threshold F1 and accuracy.

    Args:
        args: parsed CLI namespace; reads ``args.ddp`` and ``args.model_path``.
        config: mapping loaded from the YAML config; reads ``test_file``,
            ``test_batch_size``, ``disease_book`` and ``text_encoder``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Total CUDA devices: ", torch.cuda.device_count())
    torch.set_default_tensor_type("torch.FloatTensor")

    test_dataloader = _build_test_dataloader(config)
    model = _load_model(args, config, device)
    gt, pred = _predict(model, test_dataloader, device)

    n_class = len(target_class)
    AUROCs = compute_AUCs(gt, pred, n_class)
    # nanmean: a class with an undefined AUROC (nan) is skipped, not poisoning the mean.
    AUROC_avg = np.nanmean(AUROCs)
    print("The average AUROC is {AUROC_avg:.4f}".format(AUROC_avg=AUROC_avg))
    for i in range(n_class):
        print("The AUROC of {} is {}".format(target_class[i], AUROCs[i]))

    # Convert once rather than once per class.
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    max_f1s = []
    accs = []
    for i in range(n_class):
        max_f1, acc = _max_f1_and_accuracy(gt_np[:, i], pred_np[:, i])
        max_f1s.append(max_f1)
        accs.append(acc)

    f1_avg = np.array(max_f1s).mean()
    acc_avg = np.array(accs).mean()
    print("The average f1 is {F1_avg:.4f}".format(F1_avg=f1_avg))
    print("The average ACC is {ACC_avg:.4f}".format(ACC_avg=acc_avg))
| |
|
| |
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default="Sample_zero-shot_Classification_CXR14/configs/MeDSLIP_config.yaml",
    )
    parser.add_argument("--model_path", default="MeDSLIP_resnet50.pth")
    # NOTE(review): --device is parsed but never read; test() chooses the
    # device itself from torch.cuda.is_available().
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--gpu", type=str, default="0", help="gpu")
    # NOTE(review): test() implements this flag with nn.DataParallel, not DDP.
    parser.add_argument("--ddp", action="store_true", help="whether to use ddp")
    args = parser.parse_args()

    # SECURITY: yaml.Loader can construct arbitrary Python objects; only pass
    # trusted config files (consider yaml.safe_load if the config allows it).
    with open(args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    # Restrict visible GPUs; this must happen before CUDA is initialised.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    if args.gpu != "-1":
        # Forces CUDA initialisation up front.
        torch.cuda.current_device()
        # NOTE(review): sets a private torch attribute; likely redundant after
        # current_device() — confirm before removing.
        torch.cuda._initialized = True

    test(args, config)
| |
|