"""Graph neural network encoders built on PyTorch Geometric.

Provides a configurable ``GNNEncoder`` backbone (GINE/GIN/GCN), small
projection helpers, and a ``build_gnn_encoder`` factory.
"""

from __future__ import annotations

from typing import Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    GCNConv,
    GINConv,
    GINEConv,
    global_add_pool,
    global_max_pool,
    global_mean_pool,
)


def get_activation(name: str) -> nn.Module:
    """Return a fresh activation module for the given name."""
    name = name.lower()
    if name == "relu":
        return nn.ReLU()
    if name == "gelu":
        return nn.GELU()
    if name == "silu":
        return nn.SiLU()
    if name in ("leaky_relu", "lrelu"):
        return nn.LeakyReLU(0.1)
    raise ValueError(f"Unknown activation: {name}")


class MLP(nn.Module):
    """Small MLP used inside GNN layers and projections."""

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_layers: int = 2,
        act: str = "relu",
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        assert num_layers >= 1
        layers: list[nn.Module] = []
        # e.g. num_layers=3: in -> hidden -> hidden -> out, with activation
        # (and optional dropout) after every layer except the last.
        dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1], bias=bias))
            if i < len(dims) - 2:
                layers.append(get_activation(act))
                if dropout > 0:
                    layers.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class NodeProjector(nn.Module):
    """Projects raw node features to model embedding size."""

    def __init__(self, in_dim_node: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        if in_dim_node == emb_dim:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Sequential(
                nn.Linear(in_dim_node, emb_dim),
                get_activation(act),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class EdgeProjector(nn.Module):
    """Projects raw edge attributes to model embedding size for GINE."""

    def __init__(self, in_dim_edge: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        if in_dim_edge <= 0:
            raise ValueError("in_dim_edge must be > 0 when using edge attributes")
        self.proj = nn.Sequential(
            nn.Linear(in_dim_edge, emb_dim),
            get_activation(act),
        )

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        return self.proj(e)


class GNNEncoder(nn.Module):
    """
    Backbone GNN with a selectable convolution type.

    gnn_type:
        - "gine": chemistry-ready, uses edge_attr (recommended)
        - "gin" : ignores edge_attr, strong node-level MPNN
        - "gcn" : ignores edge_attr, fast spectral conv
    norm: "batch" | "layer" | "none"
    readout: "mean" | "sum" | "max"
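
    Example (a minimal sketch; the dimensions below are illustrative
    assumptions, not fixed by this module):

        enc = GNNEncoder(in_dim_node=9, emb_dim=64, num_layers=3,
                         gnn_type="gine", in_dim_edge=3)
        g = enc(x, edge_index, edge_attr, batch)  # -> [num_graphs, 64]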
| """ |
|
|
| def __init__( |
| self, |
| in_dim_node: int, |
| emb_dim: int, |
| num_layers: int = 5, |
| gnn_type: Literal["gine", "gin", "gcn"] = "gine", |
| in_dim_edge: int = 0, |
| act: str = "relu", |
| dropout: float = 0.0, |
| residual: bool = True, |
| norm: Literal["batch", "layer", "none"] = "batch", |
| readout: Literal["mean", "sum", "max"] = "mean", |
| ): |
| super().__init__() |
| assert num_layers >= 1 |
|
|
| self.gnn_type = gnn_type.lower() |
| self.emb_dim = emb_dim |
| self.num_layers = num_layers |
| self.residual = residual |
| self.dropout_p = float(dropout) |
| self.readout = readout.lower() |
|
|
| self.node_proj = NodeProjector(in_dim_node, emb_dim, act=act) |
| self.edge_proj: Optional[EdgeProjector] = None |
|
|
| if self.gnn_type == "gine": |
| if in_dim_edge <= 0: |
| raise ValueError( |
| "gine selected but in_dim_edge <= 0. Provide edge attributes or switch gnn_type." |
| ) |
| self.edge_proj = EdgeProjector(in_dim_edge, emb_dim, act=act) |
|
|
| |
| self.convs = nn.ModuleList() |
| self.norms = nn.ModuleList() |
|
|
| for _ in range(num_layers): |
| if self.gnn_type == "gine": |
| |
| nn_mlp = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0) |
| conv = GINEConv(nn_mlp) |
| elif self.gnn_type == "gin": |
| nn_mlp = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0) |
| conv = GINConv(nn_mlp) |
| elif self.gnn_type == "gcn": |
| conv = GCNConv(emb_dim, emb_dim, add_self_loops=True, normalize=True) |
| else: |
| raise ValueError(f"Unknown gnn_type: {gnn_type}") |
| self.convs.append(conv) |
|
|
| if norm == "batch": |
| self.norms.append(nn.BatchNorm1d(emb_dim)) |
| elif norm == "layer": |
| self.norms.append(nn.LayerNorm(emb_dim)) |
| elif norm == "none": |
| self.norms.append(nn.Identity()) |
| else: |
| raise ValueError(f"Unknown norm: {norm}") |
|
|
| self.act = get_activation(act) |
|
|
| def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: |
| if self.readout == "mean": |
| return global_mean_pool(x, batch) |
| if self.readout == "sum": |
| return global_add_pool(x, batch) |
| if self.readout == "max": |
| return global_max_pool(x, batch) |
| raise ValueError(f"Unknown readout: {self.readout}") |
|
|

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor] = None,
        batch: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Returns a graph-level embedding of shape [B, emb_dim].
        If batch is None, assumes a single graph and creates a zero batch vector.
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        # Project raw node features into the embedding space.
        x = x.float()
        x = self.node_proj(x)

        e = None
        if self.gnn_type == "gine":
            if edge_attr is None:
                raise ValueError("GINE requires edge_attr, but got None.")
            assert self.edge_proj is not None
            e = self.edge_proj(edge_attr.float())

        # Message passing: conv -> norm -> activation, with optional
        # residual connection and dropout per layer.
        h = x
        for conv, norm in zip(self.convs, self.norms):
            if self.gnn_type in ("gcn", "gin"):
                h_next = conv(h, edge_index)
            else:
                h_next = conv(h, edge_index, e)

            h_next = norm(h_next)
            h_next = self.act(h_next)

            if self.residual and h_next.shape == h.shape:
                h = h + h_next
            else:
                h = h_next

            if self.dropout_p > 0:
                h = F.dropout(h, p=self.dropout_p, training=self.training)

        return self._readout(h, batch)


def build_gnn_encoder(
    in_dim_node: int,
    emb_dim: int,
    num_layers: int = 5,
    gnn_type: Literal["gine", "gin", "gcn"] = "gine",
    in_dim_edge: int = 0,
    act: str = "relu",
    dropout: float = 0.0,
    residual: bool = True,
    norm: Literal["batch", "layer", "none"] = "batch",
    readout: Literal["mean", "sum", "max"] = "mean",
) -> GNNEncoder:
    """
    Factory to create a GNNEncoder with a consistent, minimal API.
    Prefer calling this from model.py so encoder construction is centralized.
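
    A minimal usage sketch (the dimensions are illustrative assumptions):

        enc = build_gnn_encoder(in_dim_node=9, emb_dim=64, gnn_type="gin")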
| """ |
| return GNNEncoder( |
| in_dim_node=in_dim_node, |
| emb_dim=emb_dim, |
| num_layers=num_layers, |
| gnn_type=gnn_type, |
| in_dim_edge=in_dim_edge, |
| act=act, |
| dropout=dropout, |
| residual=residual, |
| norm=norm, |
| readout=readout, |
| ) |


__all__ = ["GNNEncoder", "build_gnn_encoder"]
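

# A minimal smoke test -- a sketch assuming torch and torch_geometric are
# installed. The tiny graph below (4 nodes, 9-dim node features, 3-dim edge
# attributes) is an illustrative assumption, not a dataset this module defines.
if __name__ == "__main__":
    num_nodes, node_dim, edge_dim = 4, 9, 3
    x = torch.randn(num_nodes, node_dim)
    # A small directed ring; GINE consumes one edge_attr row per edge.
    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]], dtype=torch.long)
    edge_attr = torch.randn(edge_index.size(1), edge_dim)

    enc = build_gnn_encoder(
        in_dim_node=node_dim, emb_dim=64, num_layers=3,
        gnn_type="gine", in_dim_edge=edge_dim,
    )
    out = enc(x, edge_index, edge_attr, batch=None)  # batch=None -> one graph
    print(out.shape)  # expected: torch.Size([1, 64])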