| |
| from __future__ import annotations |
|
|
| from typing import Dict, List, Optional, Sequence, Literal |
|
|
| import math |
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| |
| from src.data_builder import TargetScaler, grouped_split_by_smiles |
|
|
|
|
| |
| |
| |
|
|
def seed_everything(seed: int) -> None:
    """Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices)."""
    import random

    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
|
|
|
def to_device(batch, device: torch.device):
    """Transfer a PyG Batch (anything exposing `.to`) or a dict of tensors to `device`."""
    if hasattr(batch, "to"):
        return batch.to(device)
    if not isinstance(batch, dict):
        # Unknown container: hand it back untouched.
        return batch
    moved = {}
    for key, value in batch.items():
        moved[key] = value.to(device) if torch.is_tensor(value) else value
    return moved
|
|
|
|
| |
| |
| |
|
|
| def _safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor: |
| den = torch.clamp(den, min=1e-12) |
| return num / den |
|
|
|
|
def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
               reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor:
    """
    Masked squared-error loss.

    pred/target: [B, T]; mask: [B, T] bool. With reduction="mean" the summed
    squared error is divided by the number of unmasked elements.
    """
    m = mask.bool()
    sq_err = (pred.float() - target.float()) ** 2
    masked_total = (sq_err * m).sum()
    if reduction == "sum":
        return masked_total
    return _safe_div(masked_total, m.sum().float())
|
|
|
|
def masked_mae(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
               reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor:
    """
    Masked absolute-error loss.

    pred/target: [B, T]; mask: [B, T] bool. With reduction="mean" the summed
    absolute error is divided by the number of unmasked elements.
    """
    # Mirror masked_mse: compute in float32 and force the mask to bool so that
    # integer/half inputs and float-valued masks behave identically across the
    # masked_* metric family (previously this function skipped both casts).
    pred, target = pred.float(), target.float()
    mask = mask.bool()
    ae = (pred - target).abs() * mask
    if reduction == "sum":
        return ae.sum()
    return _safe_div(ae.sum(), mask.sum().float())
|
|
|
|
def masked_rmse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Square root of the masked mean squared error (see `masked_mse`)."""
    mse = masked_mse(pred, target, mask, reduction="mean")
    return mse.sqrt()
|
|
|
|
def masked_r2(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Masked coefficient of determination, pooled over every unmasked element.
    """
    p = pred.float()
    y = target.float()
    m = mask.bool()
    n = m.sum().float().clamp(min=1.0)
    # Mean of the observed targets only.
    y_bar = _safe_div((y * m).sum(), n)
    ss_tot = (((y - y_bar) ** 2) * m).sum()
    ss_res = (((y - p) ** 2) * m).sum()
    return 1.0 - _safe_div(ss_res, ss_tot.clamp(min=1e-12))
|
|
|
|
def masked_metrics_overall(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> Dict[str, float]:
    """Masked RMSE/MAE/R^2 over the whole batch, returned as plain Python floats."""
    raw = {
        "rmse": masked_rmse(pred, target, mask),
        "mae": masked_mae(pred, target, mask),
        "r2": masked_r2(pred, target, mask),
    }
    return {name: float(value.detach().cpu()) for name, value in raw.items()}
|
|
|
|
def masked_metrics_per_task(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,
    task_names: Sequence[str],
) -> Dict[str, Dict[str, float]]:
    """
    Masked RMSE/MAE/R^2 computed independently for each task column.

    Tasks with no unmasked entries report NaN for every metric.
    """
    results: Dict[str, Dict[str, float]] = {}
    for col, name in enumerate(task_names):
        col_mask = mask[:, col]
        if not col_mask.any():
            nan = float("nan")
            results[name] = {"rmse": nan, "mae": nan, "r2": nan}
            continue
        # Keep 2-D shapes so the masked_* helpers see [B, 1] inputs.
        p = pred[:, col:col + 1]
        y = target[:, col:col + 1]
        m2 = col_mask.unsqueeze(1)
        results[name] = {
            "rmse": float(masked_rmse(p, y, m2).detach().cpu()),
            "mae": float(masked_mae(p, y, m2).detach().cpu()),
            "r2": float(masked_r2(p, y, m2).detach().cpu()),
        }
    return results
|
|
|
|
def masked_metrics_by_fidelity(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,
    fid_idx: torch.Tensor,
    fid_names: Sequence[str],
    task_names: Sequence[str],
) -> Dict[str, Dict[str, float]]:
    """
    Overall masked metrics per fidelity, aggregated across tasks.

    Fidelities absent from the batch report NaN for every metric.
    (`task_names` is accepted for signature symmetry with the other metric
    helpers but is not used here.)
    """
    fid_flat = fid_idx.view(-1).long()
    results: Dict[str, Dict[str, float]] = {}
    for fid, fname in enumerate(fid_names):
        rows = fid_flat == fid
        if rows.any():
            results[fname] = masked_metrics_overall(pred[rows], target[rows], mask[rows])
        else:
            nan = float("nan")
            results[fname] = {"rmse": nan, "mae": nan, "r2": nan}
    return results
|
|
|
|
| |
| |
| |
|
|
def gaussian_nll(mu: torch.Tensor, logvar: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Per-element Gaussian negative log-likelihood of `target` under N(mu, exp(logvar)).

    No reduction is applied; mu, logvar, target must broadcast to [B, T].
    `logvar` is clamped to [-20, 20] before exponentiation for numerical safety.
    """
    lv = torch.as_tensor(logvar, device=mu.device, dtype=mu.dtype)
    lv = lv.clamp(min=-20.0, max=20.0)
    var = lv.exp()
    # 0.5 * [ (y - mu)^2 / sigma^2 + log sigma^2 + log(2*pi) ]
    return 0.5 * ((target - mu).square() / var + lv + math.log(2.0 * math.pi))
|
|
|
|
def loss_multitask_fidelity(
    *,
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,
    fid_idx: torch.Tensor,
    fid_loss_w: Sequence[float] | torch.Tensor | None,
    task_weights: Optional[Sequence[float] | torch.Tensor] = None,
    hetero_logvar: Optional[torch.Tensor] = None,
    reduction: Literal["mean", "sum"] = "mean",
    task_log_sigma2: Optional[torch.Tensor] = None,
    balanced: bool = True,
) -> torch.Tensor:
    """
    Multi-task, multi-fidelity loss with *balanced per-task reduction* by default.

    - If `hetero_logvar` is given: uses Gaussian NLL per element.
    - Applies per-fidelity weights via `fid_idx`.
    - Balanced reduction: compute mean loss per task first, then average across tasks
      (optionally weight by `task_weights` or learned uncertainty `task_log_sigma2`).
    - If `balanced=False`, uses legacy global reduction.

    Args:
        pred: [B, T] predictions (cast to float32 internally).
        target: [B, T] labels (cast to float32 internally).
        mask: [B, T], truthy where a label is observed (cast to bool).
        fid_idx: [B] integer fidelity index per sample (flattened + cast to long).
        fid_loss_w: per-fidelity weights indexed by `fid_idx`; None => all ones.
        task_weights: optional [T] weights; renormalized so they sum to T.
        hetero_logvar: optional [B, T] predicted log-variance; switches the
            element loss from squared error to Gaussian NLL.
        reduction: "mean" or "sum".
        task_log_sigma2: optional learned [T] log sigma^2; applied only in the
            balanced path as loss/(2*sigma^2) + 0.5*log(sigma^2).
        balanced: True => per-task mean then cross-task reduction;
            False => single weighted mean/sum over all elements.

    Returns:
        Scalar loss tensor.
    """
    B, T = pred.shape
    pred = pred.float()
    target = target.float()
    mask = mask.bool()
    fid_idx = fid_idx.view(-1).long()

    # Per-task weights, rescaled so an all-ones vector is the identity
    # (weights sum to T after normalization).
    if task_weights is None:
        tw = pred.new_ones(T)
    else:
        tw = torch.as_tensor(task_weights, dtype=pred.dtype, device=pred.device)
        assert tw.numel() == T, f"task_weights len {tw.numel()} != T {T}"
        s = tw.sum().clamp(min=1e-12)
        tw = tw * (T / s)

    # Per-sample fidelity weight, broadcast across tasks -> [B, T].
    if fid_loss_w is None:
        fw = pred.new_ones(int(fid_idx.max().item()) + 1)
    else:
        fw = torch.as_tensor(fid_loss_w, dtype=pred.dtype, device=pred.device)
    w_fid = fw[fid_idx].unsqueeze(1).expand(-1, T)

    # Element-wise loss: heteroscedastic Gaussian NLL if a log-variance head
    # is provided, otherwise plain squared error.
    if hetero_logvar is not None:
        elem_loss = gaussian_nll(pred, hetero_logvar.float(), target)
    else:
        elem_loss = (pred - target) ** 2

    if not balanced:
        # Legacy path: one global weighted reduction over all elements; the
        # mean is normalized by the total applied weight, not the element count.
        w_task = tw.view(1, T).expand(B, -1)
        weighted = elem_loss * mask * w_task * w_fid
        if reduction == "sum":
            return weighted.sum()
        denom = (mask * w_task * w_fid).sum().float().clamp(min=1e-12)
        return weighted.sum() / denom

    # Balanced path: fidelity-weighted mean per task column -> [T].
    # Tasks with no observed labels contribute ~0 via the clamped denominator.
    num = (elem_loss * mask * w_fid).sum(dim=0)
    den = (mask * w_fid).sum(dim=0).float().clamp(min=1e-12)
    per_task_loss = num / den

    # Explicit task weights (tw is all-ones when task_weights is None, so the
    # multiply is skipped entirely in that case).
    if task_weights is not None:
        per_task_loss = per_task_loss * tw

    # Learned per-task uncertainty weighting: downweights noisy tasks via
    # 1/(2*sigma^2) and regularizes with 0.5*log(sigma^2).
    if task_log_sigma2 is not None:
        assert task_log_sigma2.numel() == T, f"task_log_sigma2 must be [T], got {task_log_sigma2.shape}"
        sigma2 = torch.exp(task_log_sigma2)
        per_task_loss = per_task_loss / (2.0 * sigma2) + 0.5 * torch.log(sigma2)

    # Final cross-task reduction.
    if reduction == "sum":
        return per_task_loss.sum()
    return per_task_loss.mean()
|
|
|
|
| |
| |
| |
|
|
def exp_weight_at_epoch(
    epoch: int,
    total_epochs: int,
    schedule: Literal["none", "linear", "cosine"] = "none",
    start: float = 0.6,
    end: float = 1.0,
) -> float:
    """
    EXP-fidelity loss weight at `epoch` under the chosen ramp schedule.

    "none" always returns `end`; "linear" interpolates start -> end over
    `total_epochs`; "cosine" follows a half-cosine ramp. A non-positive
    `total_epochs` also yields `end`. Raises ValueError for unknown names.
    """
    if schedule == "none" or total_epochs <= 0:
        return float(end)
    frac = min(max(epoch, 0), total_epochs) / float(total_epochs)
    if schedule == "linear":
        return float(start + (end - start) * frac)
    if schedule == "cosine":
        ramp = 0.5 - 0.5 * math.cos(math.pi * frac)
        return float(start + (end - start) * ramp)
    raise ValueError(f"Unknown schedule: {schedule}")
|
|
|
|
def make_fid_loss_weights(
    fids: Sequence[str],
    base_weights: Optional[Sequence[float]] = None,
    exp_weight: Optional[float] = None,
) -> List[float]:
    """
    Build a per-fidelity weight vector aligned with the order of `fids`.

    Starts from `base_weights` (must match len(fids)) or all ones, then
    overrides the entry for the 'exp' fidelity (matched case-insensitively)
    when `exp_weight` is given.
    """
    lowered = [name.lower() for name in fids]
    n = len(lowered)
    if base_weights is None:
        weights = [1.0] * n
    else:
        assert len(base_weights) == n, f"base_weights len {len(base_weights)} != {n}"
        weights = list(map(float, base_weights))
    if exp_weight is not None and "exp" in lowered:
        weights[lowered.index("exp")] = float(exp_weight)
    return weights
|
|
|
|
| |
| |
| |
|
|
def apply_inverse_transform(pred: torch.Tensor, scaler):
    """
    Undo target scaling on `pred`, keeping everything on pred's device.

    Any tensor-valued scaler statistics (mean/std/eps) are moved to pred's
    device in place before calling `scaler.inverse`. Non-tensor attributes
    (None, Python scalars, NumPy arrays from legacy scalers) are left alone
    so `inverse` can handle them itself — the old hasattr-only check crashed
    with AttributeError on `.device` for those.
    """
    dev = pred.device
    for attr in ("mean", "std", "eps"):
        value = getattr(scaler, attr, None)
        # Only torch tensors carry a .device; skip everything else.
        if torch.is_tensor(value) and value.device != dev:
            setattr(scaler, attr, value.to(dev))
    return scaler.inverse(pred)
|
|
|
|
|
|
def ensure_2d(x: torch.Tensor) -> torch.Tensor:
    """Promote a 1-D tensor to [B, 1]; return anything else unchanged."""
    return x.unsqueeze(1) if x.dim() == 1 else x
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: a random batch of 5 samples x 3 tasks with 4 fidelity levels.
    torch.manual_seed(0)
    B, T = 5, 3
    pred = torch.randn(B, T)
    targ = torch.randn(B, T)
    mask = torch.rand(B, T) > 0.3  # roughly 70% of labels observed
    fid_idx = torch.randint(0, 4, (B,))
    fid_w = [1.0, 0.8, 0.6, 0.5]
    task_w = [1.0, 2.0, 1.0]

    # Balanced loss with and without explicit per-task weights.
    l1 = loss_multitask_fidelity(pred=pred, target=targ, mask=mask, fid_idx=fid_idx, fid_loss_w=fid_w, task_weights=task_w)
    l2 = loss_multitask_fidelity(pred=pred, target=targ, mask=mask, fid_idx=fid_idx, fid_loss_w=fid_w, task_weights=None)
    print("Loss with task weights:", float(l1))
    print("Loss without task weights:", float(l2))

    # Pooled masked metrics over the same fake batch.
    m_all = masked_metrics_overall(pred, targ, mask)
    print("Overall metrics:", m_all)
|
|