| |
| from __future__ import annotations |
|
|
| from typing import List, Optional, Literal |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_geometric.data import Batch |
|
|
| from src.conv import build_gnn_encoder, GNNEncoder |
|
|
|
|
def get_activation(name: str) -> nn.Module:
    """Return a fresh activation module for a case-insensitive name.

    Supported names: "relu", "gelu", "silu", and "leaky_relu"/"lrelu"
    (negative slope 0.1).

    Raises:
        ValueError: if the name is not one of the supported activations.
    """
    key = name.lower()
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "leaky_relu": lambda: nn.LeakyReLU(0.1),
        "lrelu": lambda: nn.LeakyReLU(0.1),
    }
    maker = factories.get(key)
    if maker is None:
        raise ValueError(f"Unknown activation: {key}")
    return maker()
|
|
|
|
class FiLM(nn.Module):
    """
    Feature-wise linear modulation from a condition vector.

    Two linear maps project the condition to per-feature scale and shift
    terms; features are transformed as (1 + gamma) * h + beta, so a
    zero-output condition path reduces to the identity.
    """
    def __init__(self, feat_dim: int, cond_dim: int):
        super().__init__()
        # Separate projections for the multiplicative and additive terms.
        self.gamma = nn.Linear(cond_dim, feat_dim)
        self.beta = nn.Linear(cond_dim, feat_dim)


    def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Modulate features `h` with scale/shift derived from `cond`."""
        scale = 1.0 + self.gamma(cond)
        shift = self.beta(cond)
        return h * scale + shift
|
|
|
|
class TaskHead(nn.Module):
    """
    Per-task MLP head. Input is concatenation of [graph_embed, optional task_embed].
    Outputs either a mean only (scalar) or mean+logvar (heteroscedastic).
    """
    def __init__(
        self,
        in_dim: int,
        hidden_dim: int = 512,
        depth: int = 2,
        act: str = "relu",
        dropout: float = 0.0,
        heteroscedastic: bool = False,
    ):
        """Build `depth` hidden stages (Linear + activation [+ Dropout]) followed
        by a final Linear emitting 1 value (mean) or 2 (mean, log-variance)."""
        super().__init__()
        stages: List[nn.Module] = []
        width = in_dim
        for _ in range(depth):
            block = [nn.Linear(width, hidden_dim), get_activation(act)]
            if dropout > 0:
                block.append(nn.Dropout(dropout))
            stages.extend(block)
            width = hidden_dim
        # Final projection: 2 outputs when predicting a log-variance as well.
        stages.append(nn.Linear(width, 2 if heteroscedastic else 1))
        self.net = nn.Sequential(*stages)
        self.hetero = heteroscedastic


    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Return raw head outputs: [B, 1], or [B, 2] when heteroscedastic."""
        return self.net(z)
|
|
|
|
class MultiTaskMultiFidelityModel(nn.Module):
    """
    General multi-task, multi-fidelity GNN.

    - Any number of tasks (properties) via T = len(task_names)
    - Any number of fidelities via num_fids
    - Fidelity conditioning with an embedding and FiLM on the graph embedding
    - Optional task embeddings concatenated into each task head input
    - Single forward returning predictions [B, T] (means); if heteroscedastic, also returns log-variances

    Expected input Batch fields (PyG):
      - x : [N_nodes, F_node]
      - edge_index : [2, N_edges]
      - edge_attr : [N_edges, F_edge] (required if gnn_type="gine")
      - batch : [N_nodes]
      - fid_idx : [B] or [B, 1] long; integer fidelity per graph.
                  Only required when fidelity embeddings are enabled
                  (fid_emb_dim > 0).

    Notes:
      - Targets should already be normalized outside the model; apply inverse transform for plots.
      - Loss weighting/equal-importance and curriculum happen in the trainer, not here.
    """


    def __init__(
        self,
        in_dim_node: int,
        in_dim_edge: int,
        task_names: List[str],
        num_fids: int,
        gnn_type: Literal["gine", "gin", "gcn"] = "gine",
        gnn_emb_dim: int = 256,
        gnn_layers: int = 5,
        gnn_norm: Literal["batch", "layer", "none"] = "batch",
        gnn_readout: Literal["mean", "sum", "max"] = "mean",
        gnn_act: str = "relu",
        gnn_dropout: float = 0.0,
        gnn_residual: bool = True,
        fid_emb_dim: int = 64,
        use_film: bool = True,
        use_task_embed: bool = True,
        task_emb_dim: int = 32,
        head_hidden: int = 512,
        head_depth: int = 2,
        head_act: str = "relu",
        head_dropout: float = 0.0,
        heteroscedastic: bool = False,
        use_task_uncertainty: bool = False,
        fid_emb_l2: float = 0.0,
        task_emb_l2: float = 0.0,
    ):
        """Build encoder, fidelity/task embeddings, and one head per task.

        Args:
            in_dim_node / in_dim_edge: raw node/edge feature dimensions.
            task_names: one head is created per entry; order defines column
                order of the [B, T] prediction tensor.
            num_fids: number of discrete fidelity levels.
            gnn_*: forwarded to ``build_gnn_encoder``.
            fid_emb_dim: 0 disables fidelity conditioning entirely.
            use_film: modulate the graph embedding via FiLM; otherwise the
                fidelity embedding is concatenated.
            use_task_embed / task_emb_dim: optional per-task vector appended
                to each head's input (0 disables).
            head_*: per-task MLP head configuration.
            heteroscedastic: heads emit (mean, logvar) instead of mean only.
            use_task_uncertainty: adds a learnable per-task log-variance
                parameter the trainer may use to weight task losses.
            fid_emb_l2 / task_emb_l2: coefficients for regularization_loss().
        """
        super().__init__()
        self.task_names = list(task_names)
        self.num_tasks = len(task_names)
        self.num_fids = int(num_fids)
        self.hetero = heteroscedastic
        self.fid_emb_l2 = float(fid_emb_l2)
        self.task_emb_l2 = float(task_emb_l2)
        self.use_film = use_film
        self.use_task_embed = use_task_embed

        # Optional learnable per-task log-variances (loss weighting is done
        # by the trainer; this module only owns the parameter).
        self.use_task_uncertainty = bool(use_task_uncertainty)
        if self.use_task_uncertainty:
            self.task_log_sigma2 = nn.Parameter(torch.zeros(self.num_tasks))
        else:
            self.task_log_sigma2 = None

        # Shared message-passing encoder: graphs -> [B, gnn_emb_dim].
        self.encoder: GNNEncoder = build_gnn_encoder(
            in_dim_node=in_dim_node,
            emb_dim=gnn_emb_dim,
            num_layers=gnn_layers,
            gnn_type=gnn_type,
            in_dim_edge=in_dim_edge,
            act=gnn_act,
            dropout=gnn_dropout,
            residual=gnn_residual,
            norm=gnn_norm,
            readout=gnn_readout,
        )

        # Fidelity conditioning: embedding table plus either FiLM modulation
        # or plain concatenation.
        self.fid_embed = nn.Embedding(self.num_fids, fid_emb_dim) if fid_emb_dim > 0 else None
        self.film = FiLM(gnn_emb_dim, fid_emb_dim) if (use_film and fid_emb_dim > 0) else None

        # FiLM keeps the embedding width; the concat fallback widens it.
        self.gnn_out_dim = gnn_emb_dim + (fid_emb_dim if (self.fid_embed is not None and self.film is None) else 0)

        # Optional per-task embedding appended to every head input.
        self.task_embed = nn.Embedding(self.num_tasks, task_emb_dim) if (use_task_embed and task_emb_dim > 0) else None

        # One MLP head per task, all sharing the same input layout.
        head_in_dim = self.gnn_out_dim + (task_emb_dim if self.task_embed is not None else 0)
        self.heads = nn.ModuleList([
            TaskHead(
                in_dim=head_in_dim,
                hidden_dim=head_hidden,
                depth=head_depth,
                act=head_act,
                dropout=head_dropout,
                heteroscedastic=heteroscedastic,
            ) for _ in range(self.num_tasks)
        ])


    def reset_parameters(self):
        """Re-initialize the fidelity/task embedding tables (normal, std=0.02).

        Encoder and head parameters are left to their own module defaults.
        """
        if self.fid_embed is not None:
            nn.init.normal_(self.fid_embed.weight, mean=0.0, std=0.02)
        if self.task_embed is not None:
            nn.init.normal_(self.task_embed.weight, mean=0.0, std=0.02)


    def forward(self, data: Batch) -> dict:
        """
        Encode the batch, apply fidelity conditioning, run every task head.

        Returns:
          {
            "pred": [B, T] means,
            "logvar": [B, T] optional if heteroscedastic,
            "h": [B, D] graph embedding after FiLM (useful for diagnostics).
          }
        """
        x, edge_index = data.x, data.edge_index
        edge_attr = getattr(data, "edge_attr", None)
        batch = data.batch
        if edge_attr is None and hasattr(self.encoder, "gnn_type") and self.encoder.gnn_type == "gine":
            raise ValueError("GINE encoder requires edge_attr, but Batch.edge_attr is None.")

        # Graph-level embedding: [B, gnn_emb_dim].
        g = self.encoder(x, edge_index, edge_attr, batch)

        # Fidelity conditioning. NOTE: fid_idx is read only when fidelity
        # embeddings are enabled, so batches without it stay valid when
        # fid_emb_dim == 0 (previously this attribute was read unconditionally
        # and crashed such batches).
        if self.fid_embed is not None:
            fid_idx = getattr(data, "fid_idx", None)
            if fid_idx is None:
                raise ValueError(
                    "Batch.fid_idx is required when fidelity embeddings are enabled (fid_emb_dim > 0)."
                )
            c = self.fid_embed(fid_idx.view(-1).long())
            if self.film is not None:
                g = self.film(g, c)  # width stays gnn_emb_dim
            else:
                g = torch.cat([g, c], dim=-1)  # width becomes gnn_out_dim

        # Run each task head on the (optionally task-embedded) embedding.
        preds: List[torch.Tensor] = []
        logvars: Optional[List[torch.Tensor]] = [] if self.hetero else None
        for t_idx, head in enumerate(self.heads):
            if self.task_embed is not None:
                tvec = self.task_embed.weight[t_idx].unsqueeze(0).expand(g.size(0), -1)
                z = torch.cat([g, tvec], dim=-1)
            else:
                z = g
            out = head(z)
            if self.hetero:
                preds.append(out[..., 0:1])    # mean
                logvars.append(out[..., 1:2])  # log-variance
            else:
                preds.append(out)

        pred = torch.cat(preds, dim=-1)
        result = {"pred": pred, "h": g}
        if self.hetero and logvars is not None:
            result["logvar"] = torch.cat(logvars, dim=-1)
        return result


    def regularization_loss(self) -> torch.Tensor:
        """
        Optional small L2 on embeddings to keep them bounded.

        Returns a scalar tensor on the model's device; zero when both
        coefficients are 0 or the corresponding embeddings are disabled.
        """
        device = next(self.parameters()).device
        reg = torch.zeros([], device=device)
        if self.fid_embed is not None and self.fid_emb_l2 > 0:
            reg = reg + self.fid_emb_l2 * (self.fid_embed.weight.pow(2).mean())
        if self.task_embed is not None and self.task_emb_l2 > 0:
            reg = reg + self.task_emb_l2 * (self.task_embed.weight.pow(2).mean())
        return reg
|
|
|
|
def build_model(
    *,
    in_dim_node: int,
    in_dim_edge: int,
    task_names: List[str],
    num_fids: int,
    gnn_type: Literal["gine", "gin", "gcn"] = "gine",
    gnn_emb_dim: int = 256,
    gnn_layers: int = 5,
    gnn_norm: Literal["batch", "layer", "none"] = "batch",
    gnn_readout: Literal["mean", "sum", "max"] = "mean",
    gnn_act: str = "relu",
    gnn_dropout: float = 0.0,
    gnn_residual: bool = True,
    fid_emb_dim: int = 64,
    use_film: bool = True,
    use_task_embed: bool = True,
    task_emb_dim: int = 32,
    head_hidden: int = 512,
    use_task_uncertainty: bool = False,
    head_depth: int = 2,
    head_act: str = "relu",
    head_dropout: float = 0.0,
    heteroscedastic: bool = False,
    fid_emb_l2: float = 0.0,
    task_emb_l2: float = 0.0,
) -> MultiTaskMultiFidelityModel:
    """
    Factory to construct the multi-task, multi-fidelity model with a consistent API.

    Every keyword argument matches a parameter of
    MultiTaskMultiFidelityModel.__init__ by name and is forwarded unchanged.
    """
    # Snapshot the parameters before any locals are introduced; names here
    # mirror the constructor's signature exactly, so a plain splat suffices.
    kwargs = dict(locals())
    return MultiTaskMultiFidelityModel(**kwargs)
|
|