Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from monai.metrics import Cumulative, CumulativeAverage | |
| from sklearn.metrics import confusion_matrix, roc_auc_score | |
def train_epoch(cspca_model, loader, optimizer, epoch, args):
    """Run one training epoch and report the loss and ROC AUC.

    Args:
        cspca_model: Model called as ``cspca_model(data)``. Its output is
            squeezed on dim 1 and fed to ``nn.BCELoss``, so it must produce
            probabilities in [0, 1] (presumably shape ``(batch, 1)`` — confirm).
        loader: Iterable of dicts with ``"image"`` and ``"label"`` tensors.
        optimizer: Optimizer updating ``cspca_model``'s parameters.
        epoch: Current epoch index. Not used here; kept so the signature
            parallels ``val_epoch``.
        args: Namespace providing ``args.device``.

    Returns:
        ``(loss_epoch, auc_epoch)``:
            ``loss_epoch``: the ``CumulativeAverage`` of the per-batch BCE
            losses.
            ``auc_epoch``: the ROC AUC over every prediction made this epoch.
            It is ``nan`` when the collected labels contain a single class,
            because ROC AUC is undefined there and ``roc_auc_score`` would
            raise.
    """
    cspca_model.train()
    criterion = nn.BCELoss()
    run_loss = CumulativeAverage()
    targets_cumulative = Cumulative()
    preds_cumulative = Cumulative()

    for batch_data in loader:
        # as_subclass(torch.Tensor) drops any tensor subclass wrapper
        # (e.g. MONAI's MetaTensor) so plain torch ops are used downstream.
        data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
        target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)

        optimizer.zero_grad()
        output = cspca_model(data).squeeze(1)
        # BCELoss requires target dtype == input dtype. Integer or float64
        # labels would otherwise raise, so cast for the loss only.
        loss = criterion(output, target.to(output.dtype))
        loss.backward()
        optimizer.step()

        targets_cumulative.extend(target.detach().cpu())
        preds_cumulative.extend(output.detach().cpu())
        run_loss.append(loss.item())

    loss_epoch = run_loss.aggregate()
    targets = targets_cumulative.get_buffer().cpu()
    target_list = targets.numpy()
    pred_list = preds_cumulative.get_buffer().cpu().numpy()

    # roc_auc_score raises ValueError when only one class is present.
    if targets.unique().numel() < 2:
        auc_epoch = float("nan")
    else:
        auc_epoch = roc_auc_score(target_list, pred_list)
    return loss_epoch, auc_epoch
def val_epoch(cspca_model, loader, epoch, args, threshold=0.5):
    """Evaluate the model for one epoch and report loss, AUC, sensitivity and specificity.

    Args:
        cspca_model: Model called as ``cspca_model(data)``. Its output is
            squeezed on dim 1 and fed to ``nn.BCELoss``, so it must produce
            probabilities in [0, 1] (presumably shape ``(batch, 1)`` — confirm).
        loader: Iterable of dicts with ``"image"`` and ``"label"`` tensors.
            Labels are assumed to be binary 0/1 (BCELoss targets).
        epoch: Epoch index, echoed back in the returned dict.
        args: Namespace providing ``args.device``.
        threshold: Probability at or above which a prediction counts as
            positive when computing sensitivity and specificity. Defaults to
            0.5.

    Returns:
        Dict with keys:
            ``"epoch"``: the ``epoch`` argument.
            ``"loss"``: the ``CumulativeAverage`` of the per-batch BCE losses.
            ``"auc"``: ROC AUC over all predictions. It is ``nan`` when the
            labels contain a single class.
            ``"sensitivity"``: ``tp / (tp + fn)``, or ``nan`` with no positives.
            ``"specificity"``: ``tn / (tn + fp)``, or ``nan`` with no negatives.
    """
    cspca_model.eval()
    criterion = nn.BCELoss()
    run_loss = CumulativeAverage()
    targets_cumulative = Cumulative()
    preds_cumulative = Cumulative()

    with torch.no_grad():
        for batch_data in loader:
            # as_subclass(torch.Tensor) drops any tensor subclass wrapper
            # (e.g. MONAI's MetaTensor) so plain torch ops are used downstream.
            data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
            target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)

            output = cspca_model(data).squeeze(1)
            # BCELoss requires target dtype == input dtype; cast for the loss only.
            loss = criterion(output, target.to(output.dtype))

            targets_cumulative.extend(target.detach().cpu())
            preds_cumulative.extend(output.detach().cpu())
            run_loss.append(loss.item())

    loss_epoch = run_loss.aggregate()
    targets = targets_cumulative.get_buffer().cpu()
    target_list = targets.numpy()
    pred_list = preds_cumulative.get_buffer().cpu().numpy()

    # roc_auc_score raises ValueError when only one class is present.
    if targets.unique().numel() < 2:
        auc_epoch = float("nan")
    else:
        auc_epoch = roc_auc_score(target_list, pred_list)

    y_pred_categoric = pred_list >= threshold
    # Fixing labels=[0, 1] guarantees a 2x2 matrix even when only one class
    # appears in both labels and predictions. Without it, ravel() yields a
    # single value and the 4-way unpack fails. Inputs are cast to int so
    # they share the labels' dtype.
    tn, fp, fn, tp = confusion_matrix(
        target_list.astype(int),
        y_pred_categoric.astype(int),
        labels=[0, 1],
    ).ravel()
    sens_epoch = tp / (tp + fn) if (tp + fn) > 0 else float("nan")
    spec_epoch = tn / (tn + fp) if (tn + fp) > 0 else float("nan")

    val_epoch_metric = {
        "epoch": epoch,
        "loss": loss_epoch,
        "auc": auc_epoch,
        "sensitivity": sens_epoch,
        "specificity": spec_epoch,
    }
    return val_epoch_metric