| """ |
| Define loss functions needed for training the model — padding safe (-1 sentinel) |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from torchmetrics.functional.classification import ( |
| auroc, average_precision, roc, precision_recall_curve |
| ) |
| import rootutils |
| from dpacman.utils import pylogger |
|
|
| root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) |
| logger = pylogger.RankedLogger(__name__, rank_zero_only=True) |
|
|
|
|
def _expand_like(mask: torch.Tensor, like: torch.Tensor):
    """Broadcast `mask` to `like`'s shape by appending trailing singleton dims, then expand."""
    n_missing = like.dim() - mask.dim()
    if n_missing > 0:
        mask = mask.reshape(mask.shape + (1,) * n_missing)
    return mask.expand_as(like)
|
|
def bce_loss_masked(logits, targets, mask, pos_weight=None, eps=1e-8):
    """
    Masked BCE-with-logits: average the per-element loss over the positions
    selected by `mask` (broadcast up to the loss shape).

    Targets are clamped into [0, 1]; `eps` guards the denominator when the
    mask selects nothing (yielding 0 rather than NaN).
    """
    clamped = targets.clamp(0.0, 1.0)
    per_elem = F.binary_cross_entropy_with_logits(
        logits, clamped, reduction="none", pos_weight=pos_weight
    )
    # Broadcast the mask to the loss rank by appending trailing singleton dims.
    sel = mask
    while sel.dim() < per_elem.dim():
        sel = sel.unsqueeze(-1)
    sel = sel.expand_as(per_elem).to(per_elem.dtype)
    return (sel * per_elem).sum() / sel.sum().clamp_min(eps)
|
|
def mse_peaks_only(logits, targets, peak_mask, eps=1e-8):
    """
    MSE between sigmoid(logits) and targets, averaged over the positions
    selected by `peak_mask` only; `eps` guards an all-empty mask.
    """
    sq_err = F.mse_loss(torch.sigmoid(logits), targets, reduction="none")
    # Broadcast the peak mask to the error's rank with trailing singleton dims.
    sel = peak_mask
    while sel.dim() < sq_err.dim():
        sel = sel.unsqueeze(-1)
    sel = sel.expand_as(sq_err).to(sq_err.dtype)
    return (sel * sq_err).sum() / sel.sum().clamp_min(eps)
|
|
def calculate_loss(
    logits,
    targets,
    binder_kpm=None,
    glm_kpm=None,
    eps: float = 1e-8,
    alpha: float = 1.0,
    gamma: float = 1.0,
    pos_weight=None,
    pad_value: float = -1.0,
    loss_type="mixed",
):
    """
    Combine masked BCE (non-peak) + masked MSE on probabilities (peak),
    ignoring padding.

    Assumes ``targets == pad_value`` are pads; non-peak == 0; peak > 0.

    Args:
        logits: raw model outputs, same shape as (or broadcastable from) targets.
        targets: labels with ``pad_value`` at padded positions.
        binder_kpm: key-padding mask, 1 at PAD positions. Currently unused;
            kept (now optional) for interface compatibility.
        glm_kpm: key-padding mask, 1 at PAD positions, 0 elsewhere. When
            given, it is cross-checked against the pads implied by targets.
        eps: numerical floor for the mask denominators.
        alpha: weight on the BCE term.
        gamma: weight on the MSE term (mixed mode only).
        pos_weight: optional positive-class weight forwarded to BCE.
        pad_value: sentinel value marking padded targets.
        loss_type: "mixed" -> BCE off the peaks + MSE on the peaks;
            anything else -> BCE everywhere (labels expected to be binary).

    Raises:
        AssertionError: if ``glm_kpm`` disagrees with the pads implied by
            ``targets`` (sanity check on the caller's masks).
    """
    valid = targets != pad_value

    # Sanity check: the explicit padding mask must agree with the -1 sentinel.
    if glm_kpm is not None:
        n_valid = int(valid.sum())
        n_valid_from_mask = int((~glm_kpm).sum())
        assert n_valid == n_valid_from_mask, (
            f"glm_kpm marks {n_valid_from_mask} valid positions but targets "
            f"imply {n_valid}"
        )

    nonpeak_mask = valid & (targets == 0)
    peak_mask = valid & (targets > 0)

    # Zero out pads so the -1 sentinel can never leak into the elementwise losses.
    targets_safe = torch.where(valid, targets, torch.zeros_like(targets))

    if loss_type == "mixed":
        bce_nonpeak = bce_loss_masked(
            logits, targets_safe, nonpeak_mask, pos_weight=pos_weight, eps=eps
        )
        mse_peak = mse_peaks_only(logits, targets_safe, peak_mask, eps=eps)
        return alpha * bce_nonpeak + gamma * mse_peak

    # Binary mode: BCE at every valid position; labels should be exactly 0/1.
    all_binary = bool(((targets_safe == 1) | (targets_safe == 0)).all())
    if not all_binary:
        logger.warning(
            f"expecting all binary labels for loss_type={loss_type}. "
            "Did not get all binary labels."
        )

    bce_all = bce_loss_masked(logits, targets_safe, valid, pos_weight=pos_weight, eps=eps)
    return alpha * bce_all
|
|
@torch.no_grad()
def auroc_zeros_vs_ones_from_logits(
    logits: torch.Tensor,
    labels: torch.Tensor,
    glm_kpm: torch.Tensor | None = None,
    pos_thresh: float = 0.99,
):
    """
    AUROC over confident positives (label > pos_thresh) vs exact zeros,
    skipping PAD positions (glm_kpm == 1) and intermediate label values.

    Returns:
        auc: scalar tensor (AUROC; NaN when nothing is kept or a class is missing)
        n_pos, n_neg: ints
        tpr, fpr: tensors of shape (T,)
        thresholds: tensor of shape (T,)
        tp, fp: integer counts per threshold (shape (T,))
    """
    device = logits.device

    def _empty():
        return torch.empty(0, device=device)

    if glm_kpm is None:
        valid = torch.ones_like(labels, dtype=torch.bool, device=device)
    else:
        valid = ~glm_kpm
    keep = valid & ((labels > pos_thresh) | (labels == 0.0))
    if keep.sum() == 0:
        nan = torch.tensor(float('nan'), device=device)
        return nan, 0, 0, _empty(), _empty(), _empty(), _empty(), _empty()

    # Binarize: everything kept is either a confident positive or an exact zero.
    y = (labels[keep] > pos_thresh).to(torch.int)
    s = logits[keep]

    n_pos = int(y.sum().item())
    n_neg = y.numel() - n_pos
    if n_pos == 0 or n_neg == 0:
        # AUROC is undefined with only one class present.
        nan = torch.tensor(float('nan'), device=device)
        return nan, n_pos, n_neg, _empty(), _empty(), _empty(), _empty(), _empty()

    fpr, tpr, thresholds = roc(s, y, task="binary")
    auc = auroc(s, y, task="binary")

    # Recover raw per-threshold counts from the rates so callers can aggregate.
    tp = (tpr * n_pos).round().to(torch.long)
    fp = (fpr * n_neg).round().to(torch.long)

    return (
        auc.to(device), n_pos, n_neg,
        tpr.to(device), fpr.to(device), thresholds.to(device),
        tp.to(device), fp.to(device),
    )
|
|
|
|
@torch.no_grad()
def auprc_zeros_vs_ones_from_logits(
    logits: torch.Tensor,
    labels: torch.Tensor,
    glm_kpm: torch.Tensor | None = None,
    pos_thresh: float = 0.99,
):
    """
    Average precision (AUPRC) over confident positives (label > pos_thresh)
    vs exact zeros, skipping PAD positions (glm_kpm == 1) and intermediate
    label values.

    Returns:
        ap: scalar tensor (Average Precision / AUPRC)
        n_pos, n_neg: ints
        precision: (T,)
        recall: (T,)
        thresholds: (T,)
    """
    device = logits.device

    def _empty():
        return torch.empty(0, device=device)

    if glm_kpm is None:
        valid = torch.ones_like(labels, dtype=torch.bool, device=device)
    else:
        valid = ~glm_kpm
    keep = valid & ((labels > pos_thresh) | (labels == 0.0))
    if keep.sum() == 0:
        nan = torch.tensor(float('nan'), device=device)
        return nan, 0, 0, _empty(), _empty(), _empty()

    # Binarize: everything kept is either a confident positive or an exact zero.
    y = (labels[keep] > pos_thresh).to(torch.int)
    s = logits[keep]

    n_pos = int(y.sum().item())
    n_neg = y.numel() - n_pos
    if n_pos == 0:
        # With no positives at all, report AP = 0 rather than NaN.
        return torch.tensor(0.0, device=device), 0, n_neg, _empty(), _empty(), _empty()

    precision, recall, thresholds = precision_recall_curve(s, y, task="binary")
    ap = average_precision(s, y, task="binary")

    return (
        ap.to(device), n_pos, n_neg,
        precision.to(device), recall.to(device), thresholds.to(device),
    )
|
|
def accuracy_percentage(
    logits,
    targets,
    peak_thresh: float = 0.5,
    eps: float = 1e-8,
    pad_value: float = -1.0,
    pred_thresh: float = 0.5,
):
    """
    Percentage of non-padded positions where the thresholded prediction
    matches the thresholded label.

    Args:
        logits: raw model outputs.
        targets: labels; positions equal to ``pad_value`` are ignored.
        peak_thresh: a label counts as positive when ``targets >= peak_thresh``.
        eps: guards the division when every position is padding.
        pad_value: sentinel value marking padded targets.
        pred_thresh: a prediction counts as positive when
            ``sigmoid(logits) >= pred_thresh`` (previously hard-coded 0.5;
            default preserves the old behavior).

    Returns:
        float accuracy in [0, 100].
    """
    valid = targets != pad_value
    preds_bin = torch.sigmoid(logits) >= pred_thresh
    labels = targets >= peak_thresh

    # Broadcast the validity mask up to the prediction rank (extra trailing dims).
    v = valid
    while v.dim() < preds_bin.dim():
        v = v.unsqueeze(-1)
    v = v.expand_as(preds_bin)

    correct = ((preds_bin == labels) & v).to(torch.float32).sum()
    total = v.to(torch.float32).sum().clamp_min(eps)
    return (correct / total).item() * 100.0
|
|
if __name__ == "__main__":
    # Smoke test: exercise the masking / loss / accuracy helpers on toy data,
    # covering both 2-D (B, L) and 3-D (B, L, C) targets with -1 padding.
    # NOTE: the toy tensors are drawn from a seeded RNG in a fixed statement
    # order, so statement order here is load-bearing for reproducibility.
    import torch

    torch.manual_seed(0)  # deterministic toy tensors
    PAD = -1.0  # pad sentinel, must match calculate_loss's pad_value

    def make_targets_BL(B=2, L=8, pad_positions=(6, 7)):
        """Create (B,L) targets: 0=non-peak, >0=peak, -1=pad."""
        t = torch.zeros(B, L)
        # place peaks (values in [0.6, 1.0)) at fixed columns 1 and 3
        t[:, 1] = torch.rand(B) * 0.4 + 0.6
        t[:, 3] = torch.rand(B) * 0.4 + 0.6
        # mark the trailing columns as padding
        for p in pad_positions:
            t[:, p] = PAD
        return t

    def make_targets_BLC(B=2, L=8, C=3, pad_positions=(6, 7)):
        """
        Create (B,L,C) targets by broadcasting a (B,L) base across channels
        (so masking needs to expand correctly).
        """
        base = make_targets_BL(B, L, pad_positions)
        t = base.unsqueeze(-1).expand(-1, -1, C).clone()
        # perturb channel 1 slightly so the channels are not all identical
        t[..., 1] = torch.where(t[..., 1] > 0, (t[..., 1] * 0.85).clamp(0, 1), t[..., 1])
        return t

    def mask_stats(name, logits, targets, pad_value=PAD):
        # Report how many elements land in the non-peak / peak masks after
        # expansion to the logits' shape (sanity check for the broadcasting).
        valid = (targets != pad_value)
        nonpeak_mask = valid & (targets == 0)
        peak_mask = valid & (targets > 0)

        m_nonpeak = _expand_like(nonpeak_mask, logits)
        m_peak = _expand_like(peak_mask, logits)

        print(f"\n[{name}]")
        print(f" logits.shape = {tuple(logits.shape)}")
        print(f" targets.shape = {tuple(targets.shape)}")

        if targets.dim() == 2:
            print(f" targets[0,:] preview: {targets[0]}")
        else:
            print(f" targets[0,:,0] ch0 preview: {targets[0,:,0]}")
            print(f" targets[0,:,1] ch1 preview: {targets[0,:,1]}")

        print(f" #non-peak elems used = {m_nonpeak.sum().item():.0f}")
        print(f" #peak elems used = {m_peak.sum().item():.0f}")

    # ---- 2-D case: logits and targets both (B, L) ----
    B, L = 2, 8
    logits_BL = torch.randn(B, L)
    targets_BL = make_targets_BL(B, L)

    mask_stats("BL", logits_BL, targets_BL, pad_value=PAD)

    # NOTE(review): these calls omit binder_kpm/glm_kpm — they rely on those
    # parameters being optional in calculate_loss; verify the signature.
    loss_BL = calculate_loss(
        logits_BL, targets_BL, pad_value=PAD, alpha=1.0, gamma=1.0
    )
    acc_BL = accuracy_percentage(
        logits_BL, targets_BL, pad_value=PAD, peak_thresh=0.5
    )

    print(f" loss_BL = {loss_BL.item():.6f}")
    print(f" acc_BL = {acc_BL:.2f}%")

    # ---- 3-D case: (B, L, C) logits and targets (mask must broadcast) ----
    B, L, C = 2, 8, 3
    logits_BLC = torch.randn(B, L, C)
    targets_BLC = make_targets_BLC(B, L, C)

    mask_stats("BLC", logits_BLC, targets_BLC, pad_value=PAD)

    loss_BLC = calculate_loss(
        logits_BLC, targets_BLC, pad_value=PAD, alpha=1.0, gamma=1.0
    )
    acc_BLC = accuracy_percentage(
        logits_BLC, targets_BLC, pad_value=PAD, peak_thresh=0.5
    )

    print(f" loss_BLC = {loss_BLC.item():.6f}")
    print(f" acc_BLC = {acc_BLC:.2f}%")