""" Define loss functions needed for training the model — padding safe (-1 sentinel) """ import torch import torch.nn.functional as F from torchmetrics.functional.classification import ( auroc, average_precision, roc, precision_recall_curve ) import rootutils from dpacman.utils import pylogger root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) logger = pylogger.RankedLogger(__name__, rank_zero_only=True) def _expand_like(mask: torch.Tensor, like: torch.Tensor): # Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1)) while mask.dim() < like.dim(): mask = mask.unsqueeze(-1) return mask.expand_as(like) def bce_loss_masked(logits, targets, mask, pos_weight=None, eps=1e-8): """ Compute masked BCE with logits over non-peak positions only. Expects nonpeak_mask already broadcastable to logits. """ # Clamp targets into [0,1] to be safe, even if pads slip through earlier t = targets.clamp(0.0, 1.0) loss = F.binary_cross_entropy_with_logits( logits, t, reduction="none", pos_weight=pos_weight ) m = _expand_like(mask, loss).to(loss.dtype) denom = m.sum().clamp_min(eps) return (loss * m).sum() / denom def mse_peaks_only(logits, targets, peak_mask, eps=1e-8): """ Calculate MSE on peaks only (on probabilities), masking everything else. """ probs = torch.sigmoid(logits) per_elem = F.mse_loss(probs, targets, reduction="none") m = _expand_like(peak_mask, per_elem).to(per_elem.dtype) denom = m.sum().clamp_min(eps) return (per_elem * m).sum() / denom def calculate_loss( logits, targets, binder_kpm, glm_kpm, eps: float = 1e-8, alpha: float = 1.0, gamma: float = 1.0, pos_weight=None, pad_value: float = -1.0, loss_type="mixed" ): """ Combine masked-BCE (non-peak) + masked-MSE on probs (peak), ignoring padding. Assumes targets == -1 are pads; non-peak = 0; peak > 0. binder_kpm is 1 at PAD positions, 0 elsewhere glm_kpm is 1 at PAD positions, 0 elsewhere if loss_type is mixed, we're doing binary cross entropy off the peaks and MSE on the peaks. if loss_type is binary, we're doing binary cross entropy everywhere because the labels are binary. """ # calculate validity in two ways; these should be the same. # targets are padded to -1 where there is not really a DNA sequence there valid = (targets != pad_value) if glm_kpm is not None: nvalid = torch.sum(valid).item() nvalid_2 = torch.sum(~glm_kpm).item() assert nvalid==nvalid_2 # Peak / non-peak masks that exclude pads nonpeak_mask = valid & (targets == 0) peak_mask = valid & (targets > 0) # For safety, zero-out targets at pad positions so they never feed into BCE/MSE targets_safe = torch.where(valid, targets, torch.zeros_like(targets)) if loss_type=="mixed": bce_nonpeak = bce_loss_masked(logits, targets_safe, nonpeak_mask, pos_weight=pos_weight, eps=eps) mse_peak = mse_peaks_only(logits, targets_safe, peak_mask, eps=eps) return alpha * bce_nonpeak + gamma * mse_peak else: # we're expecting all binary labels. make sure. all_binary = ((targets_safe==1) | (targets_safe==0)).all().item() if not(all_binary): logger.info(f"WARNING: expecting all binary labels for loss_type={loss_type}. Did not get all binary labels.") # bce over all valid positions bce_all = bce_loss_masked(logits, targets_safe, valid, pos_weight=pos_weight, eps=eps) return alpha*bce_all @torch.no_grad() def auroc_zeros_vs_ones_from_logits( logits: torch.Tensor, # (B, L) labels: torch.Tensor, # (B, L) glm_kpm: torch.Tensor | None = None, # (B, L) True=PAD pos_thresh: float = 0.99, ): """ Returns: auc: scalar tensor (AUROC) n_pos, n_neg: ints tpr, fpr: tensors of shape (T,) thresholds: tensor of shape (T,) tp, fp: integer counts per threshold (shape (T,)) """ device = logits.device # glm_kpm is 1 where there's a pad, so ~glm_kpm is valid positions valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device) keep = valid & ((labels > pos_thresh) | (labels == 0.0)) if keep.sum() == 0: return (torch.tensor(float('nan'), device=device), 0, 0, torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device)) y = (labels[keep] > pos_thresh).to(torch.int) s = logits[keep] n_pos = int(y.sum().item()) n_neg = y.numel() - n_pos if n_pos == 0 or n_neg == 0: return (torch.tensor(float('nan'), device=device), n_pos, n_neg, torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device)) # Full ROC curve fpr, tpr, thresholds = roc(s, y, task="binary") # AUROC (TM handles logits) auc = auroc(s, y, task="binary") # Convert rates to counts (round to nearest to avoid float off-by-one) tp = (tpr * n_pos).round().to(torch.long) fp = (fpr * n_neg).round().to(torch.long) return auc.to(device), n_pos, n_neg, tpr.to(device), fpr.to(device), thresholds.to(device), tp.to(device), fp.to(device) @torch.no_grad() def auprc_zeros_vs_ones_from_logits( logits: torch.Tensor, # (B, L) labels: torch.Tensor, # (B, L) glm_kpm: torch.Tensor | None = None, # (B, L) True=PAD pos_thresh: float = 0.99, ): """ Returns: ap: scalar tensor (Average Precision / AUPRC) n_pos, n_neg: ints precision: (T,) recall: (T,) thresholds: (T,) """ device = logits.device # glm_kpm is 1 where there's a pad, so ~glm_kpm is valid valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device) keep = valid & ((labels > pos_thresh) | (labels == 0.0)) if keep.sum() == 0: return (torch.tensor(float('nan'), device=device), 0, 0, torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device)) y = (labels[keep] > pos_thresh).to(torch.int) s = logits[keep] n_pos = int(y.sum().item()) n_neg = y.numel() - n_pos if n_pos == 0: # By convention, AP=0 when there are no positives return (torch.tensor(0.0, device=device), 0, n_neg, torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device)) # Full PR curve precision, recall, thresholds = precision_recall_curve(s, y, task="binary") # Average Precision / AUPRC ap = average_precision(s, y, task="binary") return ap.to(device), n_pos, n_neg, precision.to(device), recall.to(device), thresholds.to(device) def accuracy_percentage( logits, targets, peak_thresh: float = 0.5, eps: float = 1e-8, pad_value: float = -1.0, ): """ Compute accuracy for predicting high-confidence peaks (prob >= 0.5), ignoring padding. """ valid = (targets != pad_value) probs = torch.sigmoid(logits) preds_bin = (probs >= 0.5) labels = (targets >= peak_thresh) v = _expand_like(valid, preds_bin) correct = ((preds_bin == labels) & v).to(torch.float32).sum() total = v.to(torch.float32).sum().clamp_min(eps) return (correct / total).item() * 100.0 if __name__ == "__main__": import torch torch.manual_seed(0) PAD = -1.0 def make_targets_BL(B=2, L=8, pad_positions=(6, 7)): """Create (B,L) targets: 0=non-peak, >0=peak, -1=pad.""" t = torch.zeros(B, L) # sprinkle a few peaks (values in [0.6, 1.0]) t[:, 1] = torch.rand(B) * 0.4 + 0.6 t[:, 3] = torch.rand(B) * 0.4 + 0.6 # pads for p in pad_positions: t[:, p] = PAD return t def make_targets_BLC(B=2, L=8, C=3, pad_positions=(6, 7)): """ Create (B,L,C) targets by broadcasting a (B,L) base across channels (so masking needs to expand correctly). """ base = make_targets_BL(B, L, pad_positions) # (B,L) t = base.unsqueeze(-1).expand(-1, -1, C).clone() # Make channel 1 slightly different to show per-channel variety t[..., 1] = torch.where(t[..., 1] > 0, (t[..., 1] * 0.85).clamp(0, 1), t[..., 1]) return t def mask_stats(name, logits, targets, pad_value=PAD): valid = (targets != pad_value) nonpeak_mask = valid & (targets == 0) peak_mask = valid & (targets > 0) m_nonpeak = _expand_like(nonpeak_mask, logits) m_peak = _expand_like(peak_mask, logits) print(f"\n[{name}]") print(f" logits.shape = {tuple(logits.shape)}") print(f" targets.shape = {tuple(targets.shape)}") # Previews (first batch) if targets.dim() == 2: # (B,L) print(f" targets[0,:] preview: {targets[0]}") else: # (B,L,C) print(f" targets[0,:,0] ch0 preview: {targets[0,:,0]}") print(f" targets[0,:,1] ch1 preview: {targets[0,:,1]}") # Mask counts after EXPANSION (these define denominators) print(f" #non-peak elems used = {m_nonpeak.sum().item():.0f}") print(f" #peak elems used = {m_peak.sum().item():.0f}") # ========================= # Case A: (B, L) # ========================= B, L = 2, 8 logits_BL = torch.randn(B, L) # raw scores targets_BL = make_targets_BL(B, L) # 0, >0, and -1 pads mask_stats("BL", logits_BL, targets_BL, pad_value=PAD) loss_BL = calculate_loss( logits_BL, targets_BL, pad_value=PAD, alpha=1.0, gamma=1.0 ) acc_BL = accuracy_percentage( logits_BL, targets_BL, pad_value=PAD, peak_thresh=0.5 ) print(f" loss_BL = {loss_BL.item():.6f}") print(f" acc_BL = {acc_BL:.2f}%") # ========================= # Case B: (B, L, C) # ========================= B, L, C = 2, 8, 3 logits_BLC = torch.randn(B, L, C) # raw scores with channels targets_BLC = make_targets_BLC(B, L, C) # broadcasted targets + tweaks mask_stats("BLC", logits_BLC, targets_BLC, pad_value=PAD) loss_BLC = calculate_loss( logits_BLC, targets_BLC, pad_value=PAD, alpha=1.0, gamma=1.0 ) acc_BLC = accuracy_percentage( logits_BLC, targets_BLC, pad_value=PAD, peak_thresh=0.5 ) print(f" loss_BLC = {loss_BLC.item():.6f}") print(f" acc_BLC = {acc_BLC:.2f}%")