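# Page-header residue from the hosting site (not part of the module):
#   author avatar: svincoff
#   commit message: eval mode, fixed, full binary mode
#   commit: 7b33404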
"""
Define loss functions needed for training the model — padding safe (-1 sentinel)
"""
import torch
import torch.nn.functional as F
from torchmetrics.functional.classification import (
auroc, average_precision, roc, precision_recall_curve
)
import rootutils
from dpacman.utils import pylogger
# Resolve the project root for this file using the ".project-root" indicator;
# pythonpath=True presumably also adds that root to sys.path — verify against rootutils.
# NOTE(review): this runs after the `dpacman` import above, so it cannot help
# resolve that import — confirm `dpacman` is importable on its own.
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# Module-level logger; rank_zero_only=True presumably limits output to rank 0 in
# distributed runs — verify against dpacman.utils.pylogger.
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
def _expand_like(mask: torch.Tensor, like: torch.Tensor):
    """Broadcast ``mask`` to the shape of ``like``.

    Trailing singleton axes are appended to ``mask`` until its rank matches
    ``like`` (e.g. a (B, L) mask against (B, L, 1) values), then the result is
    expanded to ``like``'s full shape.
    """
    missing_axes = like.dim() - mask.dim()
    # A non-positive count leaves the mask's rank untouched.
    for _ in range(missing_axes):
        mask = mask.unsqueeze(-1)
    return mask.expand_as(like)
def bce_loss_masked(logits, targets, mask, pos_weight=None, eps=1e-8):
    """
    Mean binary cross-entropy (with logits) over the elements selected by ``mask``.

    Args:
        logits: raw, pre-sigmoid scores.
        targets: labels matching ``logits``; clamped into [0, 1] before use.
        mask: selects the elements that contribute; it may have fewer trailing
            axes than ``logits`` (they are added before broadcasting).
        pos_weight: forwarded to ``F.binary_cross_entropy_with_logits``.
        eps: lower bound on the selected-element count, so an empty mask does
            not cause a division by zero.

    Returns:
        Scalar tensor: summed per-element loss over the selection divided by
        the number of selected elements.
    """
    per_elem = F.binary_cross_entropy_with_logits(
        logits,
        targets.clamp(0.0, 1.0),  # guards against pad values that slipped through
        pos_weight=pos_weight,
        reduction="none",
    )
    weights = _expand_like(mask, per_elem).to(per_elem.dtype)
    masked_total = (per_elem * weights).sum()
    n_selected = weights.sum().clamp_min(eps)
    return masked_total / n_selected
def mse_peaks_only(logits, targets, peak_mask, eps=1e-8):
    """
    Mean squared error between sigmoid(logits) and targets, over peak positions only.

    Args:
        logits: raw, pre-sigmoid scores.
        targets: regression targets compared against the probabilities.
        peak_mask: selects the elements that contribute; everything else is
            weighted zero.
        eps: lower bound on the selected-element count, so an empty mask does
            not cause a division by zero.

    Returns:
        Scalar tensor: summed squared error over the selection divided by the
        number of selected elements.
    """
    squared_err = F.mse_loss(torch.sigmoid(logits), targets, reduction="none")
    weights = _expand_like(peak_mask, squared_err).to(squared_err.dtype)
    masked_total = (squared_err * weights).sum()
    n_selected = weights.sum().clamp_min(eps)
    return masked_total / n_selected
def calculate_loss(
    logits,
    targets,
    binder_kpm=None,
    glm_kpm=None,
    eps: float = 1e-8,
    alpha: float = 1.0,
    gamma: float = 1.0,
    pos_weight=None,
    pad_value: float = -1.0,
    loss_type="mixed",
):
    """
    Padding-aware training loss.

    Targets equal to ``pad_value`` are pads, 0 is non-peak, and > 0 is peak.

    Args:
        logits: raw, pre-sigmoid scores.
        targets: labels matching ``logits``, with ``pad_value`` at pad positions.
        binder_kpm: key-padding mask for the binder (nonzero/True at PAD).
            Accepted for interface compatibility; not used in the computation.
        glm_kpm: optional key-padding mask (nonzero/True at PAD). When given,
            its valid-position count is cross-checked against the count
            derived from ``targets``.
        eps: lower bound on mask sizes in the masked means.
        alpha: weight on the BCE term.
        gamma: weight on the MSE term (``loss_type == "mixed"`` only).
        pos_weight: forwarded to the BCE term.
        pad_value: sentinel marking padded targets.
        loss_type: ``"mixed"`` -> BCE on non-peak positions plus MSE (on
            probabilities) on peak positions. Any other value -> BCE over all
            valid positions, where labels are expected to be binary.

    Returns:
        Scalar loss tensor.

    Raises:
        AssertionError: if ``glm_kpm`` is given and disagrees with ``targets``
            on the number of valid elements.
    """
    # Validity derived from the targets themselves: pads carry the sentinel.
    valid = targets != pad_value

    if glm_kpm is not None:
        # Cast first: `~` is only a logical NOT on bool tensors; on a 1/0
        # integer mask it would be a bitwise NOT and give a wrong count.
        n_valid_targets = int(valid.sum().item())
        n_valid_mask = int((~glm_kpm.to(torch.bool)).sum().item())
        assert n_valid_targets == n_valid_mask, (
            f"targets mark {n_valid_targets} valid elements but glm_kpm marks "
            f"{n_valid_mask}; padding in targets and glm_kpm must agree"
        )

    # Peak / non-peak masks that exclude pads.
    nonpeak_mask = valid & (targets == 0)
    peak_mask = valid & (targets > 0)

    # Zero out targets at pad positions so the sentinel never feeds into BCE/MSE.
    targets_safe = torch.where(valid, targets, torch.zeros_like(targets))

    if loss_type == "mixed":
        bce_nonpeak = bce_loss_masked(
            logits, targets_safe, nonpeak_mask, pos_weight=pos_weight, eps=eps
        )
        mse_peak = mse_peaks_only(logits, targets_safe, peak_mask, eps=eps)
        return alpha * bce_nonpeak + gamma * mse_peak

    # Binary mode: every label should be exactly 0 or 1; report if not.
    all_binary = ((targets_safe == 1) | (targets_safe == 0)).all().item()
    if not all_binary:
        logger.info(f"WARNING: expecting all binary labels for loss_type={loss_type}. Did not get all binary labels.")

    # BCE over every valid (non-pad) position.
    bce_all = bce_loss_masked(logits, targets_safe, valid, pos_weight=pos_weight, eps=eps)
    return alpha * bce_all
@torch.no_grad()
def auroc_zeros_vs_ones_from_logits(
logits: torch.Tensor, # (B, L)
labels: torch.Tensor, # (B, L)
glm_kpm: torch.Tensor | None = None, # (B, L) True=PAD
pos_thresh: float = 0.99,
):
"""
Returns:
auc: scalar tensor (AUROC)
n_pos, n_neg: ints
tpr, fpr: tensors of shape (T,)
thresholds: tensor of shape (T,)
tp, fp: integer counts per threshold (shape (T,))
"""
device = logits.device
# glm_kpm is 1 where there's a pad, so ~glm_kpm is valid positions
valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
keep = valid & ((labels > pos_thresh) | (labels == 0.0))
if keep.sum() == 0:
return (torch.tensor(float('nan'), device=device), 0, 0,
torch.empty(0, device=device), torch.empty(0, device=device),
torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
y = (labels[keep] > pos_thresh).to(torch.int)
s = logits[keep]
n_pos = int(y.sum().item())
n_neg = y.numel() - n_pos
if n_pos == 0 or n_neg == 0:
return (torch.tensor(float('nan'), device=device), n_pos, n_neg,
torch.empty(0, device=device), torch.empty(0, device=device),
torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
# Full ROC curve
fpr, tpr, thresholds = roc(s, y, task="binary")
# AUROC (TM handles logits)
auc = auroc(s, y, task="binary")
# Convert rates to counts (round to nearest to avoid float off-by-one)
tp = (tpr * n_pos).round().to(torch.long)
fp = (fpr * n_neg).round().to(torch.long)
return auc.to(device), n_pos, n_neg, tpr.to(device), fpr.to(device), thresholds.to(device), tp.to(device), fp.to(device)
@torch.no_grad()
def auprc_zeros_vs_ones_from_logits(
logits: torch.Tensor, # (B, L)
labels: torch.Tensor, # (B, L)
glm_kpm: torch.Tensor | None = None, # (B, L) True=PAD
pos_thresh: float = 0.99,
):
"""
Returns:
ap: scalar tensor (Average Precision / AUPRC)
n_pos, n_neg: ints
precision: (T,)
recall: (T,)
thresholds: (T,)
"""
device = logits.device
# glm_kpm is 1 where there's a pad, so ~glm_kpm is valid
valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
keep = valid & ((labels > pos_thresh) | (labels == 0.0))
if keep.sum() == 0:
return (torch.tensor(float('nan'), device=device), 0, 0,
torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
y = (labels[keep] > pos_thresh).to(torch.int)
s = logits[keep]
n_pos = int(y.sum().item())
n_neg = y.numel() - n_pos
if n_pos == 0:
# By convention, AP=0 when there are no positives
return (torch.tensor(0.0, device=device), 0, n_neg,
torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
# Full PR curve
precision, recall, thresholds = precision_recall_curve(s, y, task="binary")
# Average Precision / AUPRC
ap = average_precision(s, y, task="binary")
return ap.to(device), n_pos, n_neg, precision.to(device), recall.to(device), thresholds.to(device)
def accuracy_percentage(
    logits,
    targets,
    peak_thresh: float = 0.5,
    eps: float = 1e-8,
    pad_value: float = -1.0,
):
    """
    Percentage of non-pad positions where the thresholded prediction matches the label.

    A position is predicted "peak" when sigmoid(logit) >= 0.5 (fixed cut-off)
    and labelled "peak" when its target >= ``peak_thresh``. Positions whose
    target equals ``pad_value`` are excluded from both numerator and denominator.

    Returns:
        Python float in [0, 100]; ``eps`` bounds the denominator from below so
        an all-pad input does not divide by zero.
    """
    predicted_peak = torch.sigmoid(logits) >= 0.5
    is_peak = targets >= peak_thresh
    counted = _expand_like(targets != pad_value, predicted_peak)

    hits = (predicted_peak == is_peak) & counted
    n_correct = hits.to(torch.float32).sum()
    n_counted = counted.to(torch.float32).sum().clamp_min(eps)
    return (n_correct / n_counted).item() * 100.0
if __name__ == "__main__":
    # Smoke test: run the loss and accuracy helpers on small synthetic (B, L)
    # and (B, L, C) inputs that contain peaks, non-peaks and pads.
    torch.manual_seed(0)
    PAD = -1.0

    def make_targets_BL(B=2, L=8, pad_positions=(6, 7)):
        """Create (B,L) targets: 0=non-peak, >0=peak, -1=pad."""
        t = torch.zeros(B, L)
        # sprinkle a few peaks (values in [0.6, 1.0])
        t[:, 1] = torch.rand(B) * 0.4 + 0.6
        t[:, 3] = torch.rand(B) * 0.4 + 0.6
        # pads
        for p in pad_positions:
            t[:, p] = PAD
        return t

    def make_targets_BLC(B=2, L=8, C=3, pad_positions=(6, 7)):
        """
        Create (B,L,C) targets by broadcasting a (B,L) base across channels
        (so masking needs to expand correctly).
        """
        base = make_targets_BL(B, L, pad_positions)  # (B,L)
        t = base.unsqueeze(-1).expand(-1, -1, C).clone()
        # Make channel 1 slightly different to show per-channel variety
        t[..., 1] = torch.where(t[..., 1] > 0, (t[..., 1] * 0.85).clamp(0, 1), t[..., 1])
        return t

    def mask_stats(name, logits, targets, pad_value=PAD):
        """Print shapes, a target preview and the mask sizes used as denominators."""
        valid = (targets != pad_value)
        nonpeak_mask = valid & (targets == 0)
        peak_mask = valid & (targets > 0)
        m_nonpeak = _expand_like(nonpeak_mask, logits)
        m_peak = _expand_like(peak_mask, logits)
        print(f"\n[{name}]")
        print(f" logits.shape = {tuple(logits.shape)}")
        print(f" targets.shape = {tuple(targets.shape)}")
        # Previews (first batch)
        if targets.dim() == 2:  # (B,L)
            print(f" targets[0,:] preview: {targets[0]}")
        else:  # (B,L,C)
            print(f" targets[0,:,0] ch0 preview: {targets[0,:,0]}")
            print(f" targets[0,:,1] ch1 preview: {targets[0,:,1]}")
        # Mask counts after EXPANSION (these define denominators)
        print(f" #non-peak elems used = {m_nonpeak.sum().item():.0f}")
        print(f" #peak elems used = {m_peak.sum().item():.0f}")

    def run_case(name, logits, targets):
        """Report mask stats, loss and accuracy for one synthetic case."""
        mask_stats(name, logits, targets, pad_value=PAD)
        # calculate_loss takes the padding masks as its 3rd/4th parameters, so
        # pass them explicitly: no binder mask here, and a pad mask derived
        # from the targets so the validity cross-check is exercised.
        loss = calculate_loss(
            logits,
            targets,
            binder_kpm=None,
            glm_kpm=(targets == PAD),
            pad_value=PAD,
            alpha=1.0,
            gamma=1.0,
        )
        acc = accuracy_percentage(logits, targets, pad_value=PAD, peak_thresh=0.5)
        print(f" loss_{name} = {loss.item():.6f}")
        print(f" acc_{name} = {acc:.2f}%")

    # =========================
    # Case A: (B, L)
    # =========================
    B, L = 2, 8
    logits_BL = torch.randn(B, L)  # raw scores
    targets_BL = make_targets_BL(B, L)  # 0, >0, and -1 pads
    run_case("BL", logits_BL, targets_BL)

    # =========================
    # Case B: (B, L, C)
    # =========================
    B, L, C = 2, 8, 3
    logits_BLC = torch.randn(B, L, C)  # raw scores with channels
    targets_BLC = make_targets_BLC(B, L, C)  # broadcasted targets + tweaks
    run_case("BLC", logits_BLC, targets_BLC)