Mochi / Mochi++

Pretrained checkpoints for Mochi and Mochi++: a meta-learned few-shot graph foundation model that unifies node classification, link prediction, and graph classification under a single differentiable ridge-regression readout.
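The card describes the readout only as a differentiable ridge readout. As a hedged illustration, here is the standard closed-form ridge-regression readout used in few-shot meta-learning, written in NumPy. It is not taken from the Mochi code. The function names `ridge_readout` and `predict`, the dual-form solve, and the toy embeddings are all assumptions. `lam=10.0` mirrors the card's `ridge_lambda=10.0`, and `d=512` mirrors `latdim=512`. The weights are a closed-form function of the embeddings, computed with a single linear solve, so the same computation written in an autograd framework is differentiable with respect to the encoder.

```python
import numpy as np

def ridge_readout(support_emb, support_labels, num_classes, lam=10.0):
    """Closed-form ridge regression from embeddings to one-hot labels (sketch).

    Dual form W = Z^T (Z Z^T + lam * I)^{-1} Y: only an n x n system is solved,
    where n is the number of support examples (small in a few-shot episode).
    """
    Z = np.asarray(support_emb, dtype=np.float64)          # (n, d) support embeddings
    Y = np.eye(num_classes)[np.asarray(support_labels)]     # (n, C) one-hot targets
    gram = Z @ Z.T + lam * np.eye(Z.shape[0])               # (n, n) regularised Gram matrix
    alpha = np.linalg.solve(gram, Y)                        # (n, C) dual coefficients
    return Z.T @ alpha                                      # (d, C) readout weights

def predict(query_emb, W):
    """Score each query against every class and return the argmax class index."""
    return (np.asarray(query_emb, dtype=np.float64) @ W).argmax(axis=1)

# Toy episode: random vectors stand in for GNN embeddings (hypothetical data).
rng = np.random.default_rng(0)
num_classes, shots, d = 3, 5, 512
centers = 3.0 * rng.normal(size=(num_classes, d))
support_labels = np.repeat(np.arange(num_classes), shots)
support = centers[support_labels] + 0.1 * rng.normal(size=(num_classes * shots, d))
query_labels = np.arange(num_classes).repeat(4)
query = centers[query_labels] + 0.1 * rng.normal(size=(query_labels.size, d))

W = ridge_readout(support, support_labels, num_classes, lam=10.0)
accuracy = (predict(query, W) == query_labels).mean()
```

The sketch uses the dual form because a few-shot episode has far fewer support examples than embedding dimensions. The primal form (Z^T Z + λI)^{-1} Z^T Y gives the same W but solves a d × d system.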

Source code: https://github.com/joaopedromattos/mochi

Contents

| File | Variant | Seed |
|---|---|---|
| checkpoints/mochi++_s0.pt | Mochi++ | 0 |
| checkpoints/mochi++_s1.pt | Mochi++ | 1 |
| checkpoints/mochi++_s2.pt | Mochi++ | 2 |

All checkpoints use the paper-default configuration (latdim=512, gnn_layer=3, niter=2, ridge_lambda=10.0). They were trained for 12,991 steps on three groups of datasets:

- the 15-dataset link1 link-prediction (LP) group
- node-classification (NC) datasets {citeseer, pubmed, physics, computers}
- graph-classification (GC) datasets {DD, ENZYMES, REDDIT-MULTI-5K}
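The checkpoints assume this exact configuration. A small guard can therefore check a parameter dict before loading weights. The following is a hypothetical helper, not part of the mochi package. The key names come from this card, but whether `mochi.default_params` uses exactly these keys is an assumption.

```python
# Hypothetical: key names are those quoted on this card; that mochi.default_params
# uses the same keys is an assumption.
PAPER_DEFAULTS = {
    "latdim": 512,
    "gnn_layer": 3,
    "niter": 2,
    "ridge_lambda": 10.0,
}

def config_mismatches(params, reference=PAPER_DEFAULTS):
    """Return {key: (expected, actual)} for each reference key that differs or is missing."""
    return {
        key: (expected, params.get(key))   # .get() yields None for a missing key
        for key, expected in reference.items()
        if params.get(key) != expected
    }

# Usage: a modified config is reported as a readable diff.
print(config_mismatches(dict(PAPER_DEFAULTS, latdim=256)))  # {'latdim': (512, 256)}
```

Without such a check, a drifted shape-defining setting such as latdim would surface only as a size-mismatch error from load_state_dict. A changed non-shape setting such as ridge_lambda might not surface at all.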

Quickstart

```python
from mochi import Mochi, default_params, load_pretrained

model = Mochi(**default_params)
load_pretrained(model, seed=2)   # downloads from this repo and loads weights
```

Or via huggingface_hub directly:

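```python
from huggingface_hub import hf_hub_download
import torch
from mochi import Mochi, default_params

# Fetch one checkpoint (cached locally after the first download)
path = hf_hub_download(repo_id="jrm28/mochi",
                       filename="checkpoints/mochi++_s2.pt")

model = Mochi(**default_params)
# map_location="cpu" lets the weights load on machines without a GPU
model.load_state_dict(torch.load(path, map_location="cpu"))
model.eval()  # inference mode: disables dropout and batch-norm statistic updates
```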

Citation

If you use these weights, please cite the Mochi paper.
