import torch import torch.nn as nn from monai.metrics import Cumulative, CumulativeAverage from sklearn.metrics import confusion_matrix, roc_auc_score def train_epoch(cspca_model, loader, optimizer, epoch, args): cspca_model.train() criterion = nn.BCELoss() loss = 0.0 run_loss = CumulativeAverage() targets_cumulative = Cumulative() preds_cumulative = Cumulative() for _, batch_data in enumerate(loader): data = batch_data["image"].as_subclass(torch.Tensor).to(args.device) target = batch_data["label"].as_subclass(torch.Tensor).to(args.device) optimizer.zero_grad() output = cspca_model(data) output = output.squeeze(1) loss = criterion(output, target) loss.backward() optimizer.step() targets_cumulative.extend(target.detach().cpu()) preds_cumulative.extend(output.detach().cpu()) run_loss.append(loss.item()) loss_epoch = run_loss.aggregate() target_list = targets_cumulative.get_buffer().cpu().numpy() pred_list = preds_cumulative.get_buffer().cpu().numpy() auc_epoch = roc_auc_score(target_list, pred_list) return loss_epoch, auc_epoch def val_epoch(cspca_model, loader, epoch, args): cspca_model.eval() criterion = nn.BCELoss() loss = 0.0 run_loss = CumulativeAverage() targets_cumulative = Cumulative() preds_cumulative = Cumulative() with torch.no_grad(): for _, batch_data in enumerate(loader): data = batch_data["image"].as_subclass(torch.Tensor).to(args.device) target = batch_data["label"].as_subclass(torch.Tensor).to(args.device) output = cspca_model(data) output = output.squeeze(1) loss = criterion(output, target) targets_cumulative.extend(target.detach().cpu()) preds_cumulative.extend(output.detach().cpu()) run_loss.append(loss.item()) loss_epoch = run_loss.aggregate() target_list = targets_cumulative.get_buffer().cpu().numpy() pred_list = preds_cumulative.get_buffer().cpu().numpy() auc_epoch = roc_auc_score(target_list, pred_list) y_pred_categoric = pred_list >= 0.5 tn, fp, fn, tp = confusion_matrix(target_list, y_pred_categoric).ravel() sens_epoch = tp / (tp + fn) spec_epoch = tn / (tn + fp) val_epoch_metric = { "epoch": epoch, "loss": loss_epoch, "auc": auc_epoch, "sensitivity": sens_epoch, "specificity": spec_epoch, } return val_epoch_metric