| from __future__ import annotations |
|
|
| import re |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| from torch_geometric.data import Data |
|
|
| from src.data_builder import featurize_smiles, TargetScaler |
| from src.model import build_model |
| from src.utils import to_device, apply_inverse_transform |
|
|
|
|
| |
| |
| |
# Per-property multiplicative factors applied to predictions after the target
# scaler's inverse transform (properties not listed are left unscaled).
# NOTE(review): presumably these restore physical units (e.g. viscosity,
# diffusivity) — confirm exact unit semantics against the training pipeline.
POST_SCALE = dict(
    td=1e-7,
    dif=1e-5,
    visc=1e-3,
)
|
|
|
|
def _load_scaler_compat(path: Path) -> TargetScaler:
    """Load a TargetScaler from *path*, tolerating older checkpoint formats.

    Older blobs may lack "transforms"/"eps"/"targets"; those fall back to the
    freshly constructed scaler's own values (or an empty target list).
    Raises RuntimeError when the mandatory "mean"/"std" tensors are absent.
    """
    payload = torch.load(path, map_location="cpu")
    if any(field not in payload for field in ("mean", "std")):
        raise RuntimeError(f"Unrecognized target_scaler format: {path}")

    scaler = TargetScaler(
        transforms=payload.get("transforms", None),
        eps=payload.get("eps", None),
    )
    state = {
        "mean": payload["mean"].float(),
        "std": payload["std"].float(),
        "transforms": payload.get("transforms", scaler.transforms),
        "eps": payload.get("eps", scaler.eps),
    }
    scaler.load_state_dict(state)
    # Normalize target names to lowercase for case-insensitive property lookup.
    scaler.targets = [str(name).lower() for name in payload.get("targets", [])]
    return scaler
|
|
|
|
| def _infer_seed_from_name(path: Path) -> Optional[int]: |
| m = re.search(r"_([0-9]+)\.pt$", path.name) |
| return int(m.group(1)) if m else None |
|
|
|
|
def _make_one_graph(smiles: str) -> Data:
    """Featurize *smiles* into a single-graph Data object with placeholder labels.

    y / y_mask are zero-filled dummies (inference only needs the features) and
    fid_idx is pinned to fidelity 0. The raw SMILES string is attached for
    traceability.
    """
    node_feats, edges, edge_feats = featurize_smiles(smiles)
    graph = Data(
        x=node_feats,
        edge_index=edges,
        edge_attr=edge_feats,
        y=torch.zeros(1, 1),
        y_mask=torch.zeros(1, 1, dtype=torch.bool),
        fid_idx=torch.tensor([0], dtype=torch.long),
    )
    graph.smiles = smiles
    return graph
|
|
|
|
class SingleTaskEnsemblePredictor:
    """
    Single-task ensemble:
      models/single_models/{prop}_single_model_{seed}.pt
      models/single_models/{prop}_single_scalar_{seed}.pt

    Checkpoints are discovered by filename, loaded lazily per (property, seed),
    and cached. Models are built only once the node/edge feature dimensions are
    known (they come from the first featurized graph). Predictions are averaged
    over all seeds found on disk.
    """

    def __init__(self, models_dir: str = "models/single_models", device: str = "cpu"):
        """
        Args:
            models_dir: directory containing the per-seed model/scaler files.
            device: "cpu" or any CUDA spec ("cuda", "cuda:0", ...). Falls back
                to CPU when CUDA is not available.
        """
        self.models_dir = Path(models_dir)
        # Fix: the previous `device == "cuda"` test silently dropped device
        # strings like "cuda:1" to CPU; accept any "cuda*" spec instead.
        use_cuda = device.startswith("cuda") and torch.cuda.is_available()
        self.device = torch.device(device if use_cuda else "cpu")
        # (prop, seed) -> (model-or-None, scaler, meta). The model slot stays
        # None until _build_model_if_needed has the input feature dims.
        self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {}

    def available_seeds(self, prop: str) -> List[int]:
        """Return the sorted, de-duplicated seeds that have a model file for *prop*."""
        prop = prop.lower()
        seeds = set()
        for p in self.models_dir.glob(f"{prop}_single_model_*.pt"):
            s = _infer_seed_from_name(p)
            if s is not None:
                seeds.add(s)
        return sorted(seeds)

    def _load_one(self, prop: str, seed: int):
        """Load scaler + checkpoint metadata for (prop, seed) into the cache.

        The model itself is deferred (stored as None) because building it needs
        the input feature dimensions. The checkpoint's state dict is stashed in
        meta["state_dict"] so _build_model_if_needed does not have to reload
        the file from disk (previously the checkpoint was torch.load-ed twice).

        Raises:
            FileNotFoundError: if either the model or scaler file is missing.
        """
        prop = prop.lower()
        key = (prop, seed)
        if key in self._cache:
            return self._cache[key]

        ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt"
        scaler_path = self.models_dir / f"{prop}_single_scalar_{seed}.pt"
        if not ckpt_path.exists() or not scaler_path.exists():
            raise FileNotFoundError(f"Missing model/scaler for {prop} seed {seed}")

        ckpt = torch.load(ckpt_path, map_location=self.device)
        train_args = ckpt.get("args", {})

        scaler = _load_scaler_compat(scaler_path)
        # Fall back to the property name when the scaler carries no target list.
        task_names = list(getattr(scaler, "targets", [])) or [prop]

        meta = {
            "train_args": train_args,
            "task_names": task_names,
            # Dropped again once the model has been built (see below).
            "state_dict": ckpt["model"],
        }
        self._cache[key] = (None, scaler, meta)
        return self._cache[key]

    def _build_model_if_needed(self, prop: str, seed: int, in_dim_node: int, in_dim_edge: int):
        """Instantiate, weight-load, and cache the model for (prop, seed).

        Safe to call without a prior _load_one: metadata is loaded on demand
        (previously this indexed self._cache directly and relied on the caller
        having primed it). Returns the cached (model, scaler, meta) triple.
        """
        prop = prop.lower()
        key = (prop, seed)
        model, scaler, meta = self._load_one(prop, seed)
        if model is not None:
            return model, scaler, meta

        train_args = meta["train_args"]
        task_names = meta["task_names"]
        # Consume the stashed weights; no second torch.load of the checkpoint.
        state_dict = meta.pop("state_dict")

        # Infer the fidelity-embedding size from the checkpoint itself;
        # checkpoints without a fid embedding table are single-fidelity.
        if "fid_embed.weight" in state_dict:
            num_fids = state_dict["fid_embed.weight"].shape[0]
        else:
            num_fids = 1

        # Architecture hyperparameters come from the saved training args, with
        # defaults mirroring the training script for older checkpoints.
        model = build_model(
            in_dim_node=in_dim_node,
            in_dim_edge=in_dim_edge,
            task_names=task_names,
            num_fids=num_fids,
            gnn_type=train_args.get("gnn_type", "gine"),
            gnn_emb_dim=train_args.get("gnn_emb_dim", 256),
            gnn_layers=train_args.get("gnn_layers", 5),
            gnn_norm=train_args.get("gnn_norm", "batch"),
            gnn_readout=train_args.get("gnn_readout", "mean"),
            gnn_act=train_args.get("gnn_act", "relu"),
            gnn_dropout=train_args.get("gnn_dropout", 0.0),
            gnn_residual=train_args.get("gnn_residual", True),
            fid_emb_dim=train_args.get("fid_emb_dim", 64),
            use_film=train_args.get("use_film", True),
            use_task_embed=train_args.get("use_task_embed", True),
            task_emb_dim=train_args.get("task_emb_dim", 32),
            head_hidden=train_args.get("head_hidden", 512),
            head_depth=train_args.get("head_depth", 2),
            head_act=train_args.get("head_act", "relu"),
            head_dropout=train_args.get("head_dropout", 0.0),
            heteroscedastic=train_args.get("heteroscedastic", False),
            fid_emb_l2=0.0,  # regularization is irrelevant at inference time
            task_emb_l2=0.0,
            use_task_uncertainty=train_args.get("task_uncertainty", False),
        ).to(self.device)

        model.load_state_dict(state_dict, strict=True)
        model.eval()

        self._cache[key] = (model, scaler, meta)
        return model, scaler, meta

    def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]:
        """Ensemble prediction of *prop* for a single SMILES string.

        Returns:
            (mean, std, per_seed) where per_seed maps seed -> prediction in
            post-scaled units. std is the sample std (ddof=1), 0.0 for a
            single-member ensemble. Returns (None, None, {}) when no
            checkpoints exist for *prop* or the SMILES cannot be featurized.
        """
        prop = prop.lower()
        seeds = self.available_seeds(prop)
        if not seeds:
            return None, None, {}

        # Best-effort contract: an unfeaturizable SMILES yields "no prediction"
        # rather than propagating the featurizer's exception.
        try:
            g = _make_one_graph(smiles)
        except Exception:
            return None, None, {}

        in_dim_node = g.x.shape[1]
        in_dim_edge = g.edge_attr.shape[1]

        per_seed: Dict[int, float] = {}
        with torch.no_grad():
            for seed in seeds:
                model, scaler, meta = self._build_model_if_needed(prop, seed, in_dim_node, in_dim_edge)

                batch = to_device(g, self.device)
                out = model(batch)
                pred_n = out["pred"]
                pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1)
                val = float(pred[0])

                # Map back to reporting units (identity for unlisted props).
                val *= POST_SCALE.get(prop, 1.0)

                per_seed[seed] = val

        vals = np.array(list(per_seed.values()), dtype=float)
        mean = float(vals.mean())
        std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0
        return mean, std, per_seed
|
|