
MaxSim

Introduction

A fast, memory-efficient exact MaxSim kernel for ColBERT-style late-interaction retrieval and training. It's a drop-in replacement for PyTorch's MaxSim that's ~10× faster in training and saves 98% of the backward memory. See benchmarks.

The kernel is packaged as a Hugging Face kernels repo, so you can load it in your own code with kernels.get_kernel(...) and minimal dependencies.

The kernel computes

score(q, d) = sum_i max_j  <q_i, d_j>

over batches of (query, document) pairs without materialising the full [Lq, Ld] similarity matrix. Most of those values are discarded by max anyway, so the forward pass never writes them to memory and saves only the surviving argmax positions ([Lq] int32 per pair) for backward. This is where the speed and memory savings come from.
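For orientation, here is a minimal PyTorch sketch of the formula above for a single (query, document) pair. It is not the kernel's implementation: it materialises the full [Lq, Ld] matrix, which is exactly what the kernel avoids. The function name maxsim_reference is made up for this example.

```python
import torch

def maxsim_reference(q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    """Naive MaxSim for one pair (hypothetical helper, not part of the kernel).

    q: [Lq, dim] query token embeddings
    d: [Ld, dim] document token embeddings
    """
    sim = q.float() @ d.float().T                  # [Lq, Ld] full similarity matrix
    best_per_query_token = sim.max(dim=1).values   # [Lq] max over document tokens
    return best_per_query_token.sum()              # scalar: sum over query tokens

q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
d = torch.tensor([[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]])
score = maxsim_reference(q, d)   # max per query token is 2 and 3
```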

Install

The kernel is ahead-of-time compiled and distributed as a Hugging Face kernel repo, so you can use it without a compile step and with kernels as the only extra dependency. Using the kernel is as easy as:

uv add kernels        # or: pip install kernels
from kernels import get_kernel
maxsim = get_kernel("erikkaum/maxsim", version=2, trust_remote_code=True)

Common workflows

The MaxSim kernel has three entrypoints with an inference and training variant of each:

  • Contrastive: every query scored against every document. This is the classic ColBERT-style paradigm that's mostly used in training and fine-tuning.
  • Padded: each query has a fixed number C of candidates, queries and documents are both padded to max-Lq / max-Ld. This is the most common inference workflow as a second stage exact rerank.
  • Packed: arbitrary (query, document) pair grids over ragged queries and documents. The lowest-level surface; the most flexible, but also the most boilerplate to set up. You build the offset arrays and pair-id arrays yourself.

The two most common ways you will use the kernel are contrastive training and padded reranking. The packed API is lower-level and is covered in the reference section below.

Contrastive training

Use this when training or fine-tuning a ColBERT-style model with in-batch negatives. score_contrastive_train scores every query in the batch against every document in the batch, producing an [Nq, Nb] score matrix.

These are the abbreviations used in this section:

Nq  = number of queries in the batch
Nb  = number of documents in the batch
q   = query index
d   = document index
Lq  = query length after padding/truncation
dim = embedding dimension from the encoder

For example, in the common in-batch setup:

scores[q, d] = MaxSim(queries[q], documents[d])

This scores query q against document d. The batch is arranged so that queries[i] and documents[i] are the positive pair for training example i, so the diagonal entries are positives and the off-diagonal entries are in-batch negatives:

    d=0    d=1    d=2    d=3
q=0  +      -      -      -
q=1  -      +      -      -
q=2  -      -      +      -
q=3  -      -      -      +

That is why the result can be passed directly to cross-entropy with diagonal targets, like:

targets = torch.arange(Nq, device=scores.device)
loss = torch.nn.functional.cross_entropy(scores, targets)

The contrastive API uses a mixed layout: queries are rectangular, while documents are packed.

Queries are usually short and padded/truncated to one fixed length for the batch. Each filled square is one token embedding:

queries: [Nq, Lq, dim]

     t0 t1 t2 t3 ... t31
q0   ■  ■  ■  ■      ■
q1   ■  ■  ■  ■      ■
q2   ■  ■  ■  ■      ■

In contrastive mode, all Lq query positions participate in MaxSim; there is no separate query_lengths tensor. This matches the usual ColBERT training setup where queries are encoded to a fixed query length.

Documents are often longer and have more variable lengths. If we padded every document to the longest document in the batch, some of the work would go to padding:

padded documents: [Nb, max_Ld, dim]

     t0 t1 t2 ... t79 t80 ... t255
d0   ■  ■  ■      ■   ■       ■      256 tokens
d1   ■  ■  ■      ■   □       □      80 tokens
d2   ■  ■  ■  ... □   □       □      64 tokens

■ = real token embedding
□ = padding

Instead, contrastive mode packs only the real document token embeddings into one flat tensor:

documents: [total_d_tokens, dim]

idx     0 1 2 ... 255 | 256 257 ... 335 | 336 337 ... 399
token   ■ ■ ■      ■  |  ■   ■       ■  |  ■   ■       ■
doc          doc 0    |      doc 1      |      doc 2

The companion document_offsets tensor marks the boundaries between documents:

document_offsets = [0, 256, 336, 400]

document 0 = documents[0:256]     # 256 tokens
document 1 = documents[256:336]   # 80 tokens
document 2 = documents[336:400]   # 64 tokens

A typical training step looks like this:

import torch
import torch.nn.functional as F
from kernels import get_kernel

maxsim = get_kernel("erikkaum/maxsim", version=2, trust_remote_code=True)

# Encoder outputs.
queries = encoder_q(query_ids)      # [Nq, Lq, dim]
doc_emb = encoder_d(document_ids)   # [Nb, Ld, dim] before packing

# Naive/illustrative packing according to real token lengths.
documents = torch.cat(
    [doc_emb[i, :doc_lengths[i]] for i in range(Nb)],
    dim=0,
)  # [total_d_tokens, dim]

document_offsets = torch.zeros(Nb + 1, dtype=torch.int32, device=doc_emb.device)
document_offsets[1:] = doc_lengths.cumsum(0)

scores = maxsim.score_contrastive_train(
    queries,
    documents,
    document_offsets,
)  # [Nq, Nb] fp32

# Diagonal positives: query i should match document i.
targets = torch.arange(Nq, device=scores.device)

# MaxSim scores grow with Lq, so training loops often use a temperature/scale.
temperature = 1.0
loss = F.cross_entropy(scores / temperature, targets)
loss.backward()
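The list-comprehension packing in the step above is labelled as illustrative. One possible loop-free alternative, sketched under the assumption that doc_lengths is an integer tensor on the same device as doc_emb, uses a boolean mask. The helper name pack_documents is hypothetical and not part of the kernel API.

```python
import torch

def pack_documents(doc_emb: torch.Tensor, doc_lengths: torch.Tensor):
    """Pack padded [Nb, Ld, dim] embeddings into [total_d_tokens, dim] plus offsets.

    Hypothetical helper; assumes doc_lengths is an integer tensor of shape [Nb].
    """
    Nb, Ld, _ = doc_emb.shape
    positions = torch.arange(Ld, device=doc_emb.device)
    mask = positions[None, :] < doc_lengths[:, None]   # [Nb, Ld], True for real tokens
    documents = doc_emb[mask]                          # [total_d_tokens, dim], document order kept
    offsets = torch.zeros(Nb + 1, dtype=torch.int32, device=doc_emb.device)
    offsets[1:] = doc_lengths.cumsum(0)                # boundaries between documents
    return documents, offsets

doc_emb = torch.arange(2 * 3 * 2, dtype=torch.float32).reshape(2, 3, 2)
doc_lengths = torch.tensor([3, 1])
documents, document_offsets = pack_documents(doc_emb, doc_lengths)
```

Boolean-mask indexing is differentiable, so gradients still reach the encoder outputs through the packed tensor.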

Typical ColBERT-style values are:

Nq, Nb = 16 to 64
Lq     = 32
Ld     = 80 to 256 document tokens before packing
dim    = encoder output width, often 64 / 96 / 128

On CUDA, the contrastive training kernel currently uses the WMMA path and expects Lq % 16 == 0 and dim % 16 == 0. Typical ColBERT shapes such as Lq=32, dim=128 satisfy this.

The kernel saves only the winning document-token index per query token for backward, rather than retaining the full [Nq, Nb, Lq, Ld] similarity tensor. This is where most of the training-time memory saving comes from.
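When checking kernel output, a naive reference that does materialise the similarities can be handy. The sketch below assumes the same input layout as score_contrastive_train (rectangular queries, packed documents, offsets). The name contrastive_reference is made up here, and this is not necessarily the baseline used in the benchmarks.

```python
import torch

def contrastive_reference(queries, documents, document_offsets):
    """Naive [Nq, Nb] MaxSim scores (hypothetical helper for cross-checking)."""
    offsets = document_offsets.tolist()
    columns = []
    for start, end in zip(offsets[:-1], offsets[1:]):
        doc = documents[start:end].float()                        # [Ld_i, dim]
        sim = torch.einsum("qld,td->qlt", queries.float(), doc)   # [Nq, Lq, Ld_i]
        columns.append(sim.max(dim=2).values.sum(dim=1))          # [Nq]
    return torch.stack(columns, dim=1)                            # [Nq, Nb]

queries = torch.eye(2).unsqueeze(0)                               # [Nq=1, Lq=2, dim=2]
documents = torch.tensor([[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]])    # doc 0 has 2 tokens, doc 1 has 1
document_offsets = torch.tensor([0, 2, 3], dtype=torch.int32)
reference_scores = contrastive_reference(queries, documents, document_offsets)
```

Comparing this against the kernel output with a tolerance suited to fp16/bf16 inputs is a reasonable sanity check on small shapes.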

Padded inference

A typical ColBERT/late-interaction retrieval pipeline has two stages:

  1. A retriever such as PLAID returns K candidate documents for each query.
  2. score_candidates_padded(...) reranks those K candidates with exact MaxSim. It scores every query-candidate pair and returns the scores as a [B, K] matrix.

The inputs come from your retriever, e.g. PLAID:

queries        [B, Lq, dim]       fp16/bf16
candidates     [B, K, Ld, dim]    same dtype as queries
query_lengths  [B]                int32, real per-query token count (≤ Lq)
doc_lengths    [B, K]             int32, real per-doc token count   (≤ Ld)

Lq and Ld are the padded lengths of the tensors, while query_lengths and doc_lengths are the actual per-query and per-document token counts. Padding positions are ignored when computing MaxSim. In practice, Lq is usually the longest query length in the batch, and Ld is the longest document length in the batch.

These are the abbreviations used in this section:

B   = number of queries in the batch
K   = number of candidate documents per query
Lq  = padded query length
Ld  = padded document length
dim = embedding dimension from your encoder, e.g. 64 / 96 / 128

Pass these to score_candidates_padded, which returns a [B, K] fp32 matrix regardless of the input dtype:

# Exact MaxSim score per (query, candidate)
scores = maxsim.score_candidates_padded(queries, candidates, query_lengths, doc_lengths)

Then sort the scores to get the reranked order:

reranked_ids = scores.argsort(dim=1, descending=True)
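Note that argsort returns positions within each row of K candidates, not corpus-level document ids. A small sketch of mapping those positions back, assuming your retriever also gave you a [B, K] tensor of ids (candidate_doc_ids is a hypothetical name, and the scores here are stand-in values rather than kernel output):

```python
import torch

# Stand-in for kernel output: [B=2, K=3] scores.
scores = torch.tensor([[0.2, 0.9, 0.5],
                       [0.7, 0.1, 0.3]])
# Hypothetical retriever output: corpus ids of the K candidates per query.
candidate_doc_ids = torch.tensor([[101, 102, 103],
                                  [201, 202, 203]])

order = scores.argsort(dim=1, descending=True)          # [B, K] positions within each row
reranked_doc_ids = candidate_doc_ids.gather(1, order)   # [B, K] corpus ids, best first
reranked_scores = scores.gather(1, order)               # [B, K] scores, descending
```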

For B=32, this runs ~1.2–3× faster than PyTorch einsum+max at K=50 and ~2.5–5× at K=100. The padded kernel's advantage grows with the candidate count K. See Benchmarks.
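For reference, an einsum+max baseline with padding masked out might look like the sketch below. It is an illustration of the padded semantics described above, not necessarily the exact baseline used in the benchmarks. The name padded_reference is made up, and the sketch assumes every candidate has at least one real token.

```python
import torch

def padded_reference(queries, candidates, query_lengths, doc_lengths):
    """Naive masked MaxSim (hypothetical helper): returns [B, K] fp32 scores."""
    _, Lq, _ = queries.shape
    _, _, Ld, _ = candidates.shape
    # Full similarity tensor, which the kernel avoids materialising.
    sim = torch.einsum("bqd,bktd->bkqt", queries.float(), candidates.float())  # [B, K, Lq, Ld]
    doc_valid = torch.arange(Ld, device=sim.device)[None, None, :] < doc_lengths[:, :, None]
    sim = sim.masked_fill(~doc_valid[:, :, None, :], float("-inf"))   # ignore document padding
    best = sim.max(dim=3).values                                      # [B, K, Lq]
    query_valid = torch.arange(Lq, device=sim.device)[None, :] < query_lengths[:, None]
    best = best.masked_fill(~query_valid[:, None, :], 0.0)            # ignore query padding
    return best.sum(dim=2)                                            # [B, K]

queries = torch.tensor([[[1.0, 0.0], [9.0, 9.0]]])            # second query token is padding
candidates = torch.tensor([[[[2.0, 0.0], [5.0, 0.0]],         # second doc token is padding
                            [[1.0, 0.0], [3.0, 0.0]]]])
query_lengths = torch.tensor([1])
doc_lengths = torch.tensor([[1, 2]])
reference_scores = padded_reference(queries, candidates, query_lengths, doc_lengths)
```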

Reference

Inference

# Padded: per-query reranking with a fixed K candidates/documents each. Inputs are
# padded to max-Lq / max-Ld; *_lengths give the real per-row length.
scores = maxsim.score_candidates_padded(
    queries,        # [B, Lq, dim]      fp32 / fp16 / bf16
    documents,      # [B, K, Ld, dim]   same dtype as queries
    query_lengths,  # [B]               int32 — real query length (≤ Lq)
    doc_lengths,    # [B, K]            int32 — real doc length   (≤ Ld)
)  # -> [B, K] fp32

# Contrastive: every query × every document. Returns the full score
# matrix [Nq, Nb] — what most contrastive losses (InfoNCE, etc.) want.
scores = maxsim.score_contrastive(
    queries,           # [Nq, Lq, dim]          fp16 / bf16 on CUDA; fp32 also on Metal
    documents,         # [total_d_tokens, dim]  flat (all docs concatenated)
    document_offsets,  # [Nb + 1]   int32       documents[offsets[i]:offsets[i+1]] = doc i
)  # -> [Nq, Nb] fp32

# Packed: arbitrary (q, d) pair grids. You pick which pairs to score via
# pair_query_ids / pair_document_ids; the kernel computes only those.
scores = maxsim.score_pairs_packed(
    queries,            # [total_q_tokens, dim]   fp32 / fp16 / bf16
    query_offsets,      # [num_queries + 1]       int32 — CSR offsets into queries
    documents,          # [total_d_tokens, dim]   same dtype as queries
    document_offsets,   # [num_documents + 1]     int32 — CSR offsets into documents
    pair_query_ids,     # [num_pairs]             int32 — which query to score
    pair_document_ids,  # [num_pairs]             int32 — which document to score
)  # -> [num_pairs] fp32
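Since the packed API leaves the offset and pair-id arrays to you, here is one way they could be built from per-item lengths and a list of (query, document) pairs. The helper build_offsets and the example lengths are assumptions made for illustration; only the resulting shapes and int32 dtypes follow the signature above.

```python
import torch

def build_offsets(lengths: torch.Tensor) -> torch.Tensor:
    """CSR offsets (hypothetical helper): [n] lengths -> [n + 1] int32 offsets starting at 0."""
    offsets = torch.zeros(lengths.numel() + 1, dtype=torch.int32, device=lengths.device)
    offsets[1:] = lengths.cumsum(0)
    return offsets

query_lengths = torch.tensor([3, 5])         # two ragged queries
doc_lengths = torch.tensor([4, 2, 6])        # three ragged documents
query_offsets = build_offsets(query_lengths)
document_offsets = build_offsets(doc_lengths)

# Sparse pair grid: query 0 against documents 0 and 2, query 1 against document 1.
pairs = [(0, 0), (0, 2), (1, 1)]
pair_query_ids = torch.tensor([q for q, _ in pairs], dtype=torch.int32)
pair_document_ids = torch.tensor([d for _, d in pairs], dtype=torch.int32)
```

The flat queries and documents tensors would then hold 8 and 12 token embeddings respectively, concatenated in the same order as the lengths.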

Argmax helpers

Each inference API also has a _with_argmax variant. It returns the same fp32 scores plus the document-token index that won the max for each query token. These helpers are useful for debugging alignments, visualizing which document tokens matched, and inspecting the compact state used by the training kernels.

scores, argmax = maxsim.score_candidates_padded_with_argmax(
    queries,        # [B, Lq, dim]
    documents,      # [B, K, Ld, dim]
    query_lengths,  # [B]
    doc_lengths,    # [B, K]
)  # scores: [B, K] fp32; argmax: [B, K, Lq] int32

scores, argmax = maxsim.score_contrastive_with_argmax(
    queries,           # [Nq, Lq, dim]
    documents,         # [total_d_tokens, dim]
    document_offsets,  # [Nb + 1]
)  # scores: [Nq, Nb] fp32; argmax: [Nq, Nb, Lq] int32

scores, argmax = maxsim.score_pairs_packed_with_argmax(
    queries,
    query_offsets,
    documents,
    document_offsets,
    pair_query_ids,
    pair_document_ids,
)  # scores: [num_pairs] fp32; argmax: [num_pairs, max_q_len] int32
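As an illustration of the alignment use case, the sketch below looks up which document token won for each query token. It uses stand-in tensors rather than real kernel output. It also assumes that, for the padded API, argmax indices are positions along each candidate's Ld axis, and that you kept a [B, K, Ld] tensor of document token ids around (doc_token_ids is a hypothetical name).

```python
import torch

# Stand-in for the argmax returned by score_candidates_padded_with_argmax.
argmax = torch.tensor([[[2, 0], [1, 1]]], dtype=torch.int32)    # [B=1, K=2, Lq=2]
# Hypothetical tokenizer ids of the candidate documents.
doc_token_ids = torch.tensor([[[11, 12, 13], [21, 22, 23]]])    # [B=1, K=2, Ld=3]

# gather needs int64 indices, so cast the int32 argmax first.
matched_token_ids = doc_token_ids.gather(2, argmax.long())      # [B, K, Lq]
```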

Training

Each inference API has a _train variant with the same input layout and output shape, but connected to PyTorch autograd. The _train variants return fp32 scores and propagate gradients to queries and documents; length, offset, and pair-id tensors are metadata and do not receive gradients. Backward uses fp32 accumulation.
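To see why the argmax alone is enough for backward, consider a single pair: each query token's gradient is its winning document token, and each document token's gradient is the sum of the query tokens it won. The sketch below checks that against PyTorch autograd. It is a plain PyTorch illustration of the math, not the kernel's backward implementation, and the function name is made up.

```python
import torch

def maxsim_backward_from_argmax(q, d, argmax, grad_score):
    """Gradients of sum_i max_j <q_i, d_j> for one pair, using only the argmax."""
    idx = argmax.long()
    grad_q = grad_score * d[idx]                 # [Lq, dim]: q_i receives its winning d_j
    grad_d = torch.zeros_like(d)
    grad_d.index_add_(0, idx, grad_score * q)    # [Ld, dim]: winner j accumulates q_i
    return grad_q, grad_d

torch.manual_seed(0)
q = torch.randn(4, 8, requires_grad=True)
d = torch.randn(6, 8, requires_grad=True)

sim = q @ d.T                                    # naive path keeps [Lq, Ld] for autograd
sim.max(dim=1).values.sum().backward()

grad_q, grad_d = maxsim_backward_from_argmax(
    q.detach(), d.detach(), sim.argmax(dim=1), 1.0
)
```

The accumulation into grad_d is the step where several query tokens can add into the same document row.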

# Padded training.
scores = maxsim.score_candidates_padded_train(
    queries,        # [B, Lq, dim]      fp32 / fp16 / bf16
    documents,      # [B, K, Ld, dim]   same dtype as queries
    query_lengths,  # [B]               int32 / int64
    doc_lengths,    # [B, K]            int32 / int64
)  # -> [B, K] fp32

# Contrastive training.
scores = maxsim.score_contrastive_train(
    queries,           # [Nq, Lq, dim]          fp16 / bf16 on CUDA; fp32 also on Metal
    documents,         # [total_d_tokens, dim]  same dtype as queries
    document_offsets,  # [Nb + 1]               int32 / int64
)  # -> [Nq, Nb] fp32

# Packed training.
scores = maxsim.score_pairs_packed_train(
    queries,            # [total_q_tokens, dim]   fp32 / fp16 / bf16
    query_offsets,      # [num_queries + 1]       int32 / int64
    documents,          # [total_d_tokens, dim]   same dtype as queries
    document_offsets,   # [num_documents + 1]     int32 / int64
    pair_query_ids,     # [num_pairs]             int32 / int64
    pair_document_ids,  # [num_pairs]             int32 / int64
)  # -> [num_pairs] fp32

Example scripts

The examples/ directory is the quickest way to run some code locally:

just example 01   # quickstart: padded inference + padded training
just example 02   # argmax / alignment visualization
just example 03   # tiny contrastive InfoNCE training loop
just example 04   # retained-memory showcase vs naive PyTorch

Benchmarks

All numbers below are medians over 30 iterations × 3 independent runs (MAXSIM_BENCH_REPEATS=3), generated by just cuda-bench-matrix. The baseline is a straightforward PyTorch implementation that uses torch.einsum to materialise the full similarity tensor before max. Raw outputs are committed under bench_results/v2/.

ColBERT training on H200

The headline workload uses a common ColBERT-style in-batch contrastive shape: Nq=Nb=32, Lq=32, Ld=80, dim=128. On this shape the training step is ~10× faster, with 1/80th of the retained backward state.

dtype  maxsim step  naive step  full step×  bwd alone×  peak mem×  retained state
fp16   0.261 ms     2.697 ms    10.33×      12.86×      1.27×      1/80
bf16   0.272 ms     2.704 ms    9.94×       12.22×      1.27×      1/80

NVIDIA H200 (PTX-JIT). Run-to-run spread over 3 repeats: 0.33× (bf16), 1.20× (fp16). The fp16 step is noisier at this small grid on Hopper. Raw JSON artifact: bench_results/v2/h200.json.

The backward speedup is the load-bearing one — training is backward-dominated, and that's where the kernel's argmax-only save pays off:

naive contrastive backward:  [Nq, Nb, Lq, Ld] fp32 similarities
maxsim contrastive backward: [Nq, Nb, Lq]     int32 argmax
maxsim retained state = 1/Ld of naive

Ld    naive retained  maxsim retained  maxsim / naive
80    10.5 MB         0.13 MB          1/80
128   16.8 MB         0.13 MB          1/128
256   33.6 MB         0.13 MB          1/256
512   67.1 MB         0.13 MB          1/512
1024  134.2 MB        0.13 MB          1/1024
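The retained-state figures follow directly from the two tensor shapes. A short sketch of the arithmetic for the headline Nq=Nb=32, Lq=32 shape, taking 1 MB as 10^6 bytes (the function name is made up for this example):

```python
def retained_bytes(Nq: int, Nb: int, Lq: int, Ld: int):
    """Bytes kept for backward: naive fp32 similarities vs int32 argmax."""
    naive_bytes = Nq * Nb * Lq * Ld * 4    # [Nq, Nb, Lq, Ld] fp32
    argmax_bytes = Nq * Nb * Lq * 4        # [Nq, Nb, Lq] int32
    return naive_bytes, argmax_bytes

for Ld in (80, 128, 256, 512, 1024):
    naive_bytes, argmax_bytes = retained_bytes(32, 32, 32, Ld)
    print(
        f"Ld={Ld:5d}  naive={naive_bytes / 1e6:6.1f} MB  "
        f"maxsim={argmax_bytes / 1e6:.2f} MB  ratio=1/{naive_bytes // argmax_bytes}"
    )
```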

Cross-Device Contrastive fp16

This table is rendered from the GPU JSON artifacts available under bench_results/v2/.

GPU             sm     maxsim step  naive step  speedup
H200 (PTX-JIT)  sm_90  0.261 ms     2.697 ms    10.33×
A100 SXM4 80GB  sm_80  0.378 ms     4.516 ms    11.94×
L4              sm_89  0.360 ms     3.340 ms    9.28×
A10G            sm_86  0.433 ms     4.086 ms    9.43×

Full H200 benchmark matrix

The full matrix includes every benchmark workload present in the rendered JSON artifact.

Surface            Preset             Shape                                dtype  maxsim     PyTorch    speedup  padded    bwd×    peak×  retained state
contrastive_train  Contrastive        Nq=32, Nb=32, Lq=32, Ld=80, D=128    fp16   0.261 ms   2.697 ms   10.33×   —         12.86×  1.27×  1/80
contrastive_train  Contrastive        Nq=32, Nb=32, Lq=32, Ld=80, D=128    bf16   0.272 ms   2.704 ms   9.94×    —         12.22×  1.27×  1/80
contrastive_train  LongDocs           Nq=32, Nb=32, Lq=32, Ld=512, D=128   fp16   0.270 ms   2.703 ms   10.01×   —         25.69×  2.22×  1/512
contrastive_train  LongDocs           Nq=32, Nb=32, Lq=32, Ld=512, D=128   bf16   0.263 ms   2.713 ms   10.32×   —         32.21×  2.22×  1/512
contrastive_train  BigBatch           Nq=64, Nb=64, Lq=32, Ld=128, D=128   fp16   0.320 ms   4.906 ms   15.34×   —         43.90×  2.47×  1/128
contrastive_train  BigBatch           Nq=64, Nb=64, Lq=32, Ld=128, D=128   bf16   0.321 ms   4.911 ms   15.30×   —         44.26×  2.47×  1/128
padded_infer       Rerank             B=32, K=50, Lq=32, Ld=180, D=128     fp16   0.327 ms   0.381 ms   1.17×    —         —       2.32×  —
padded_infer       Rerank             B=32, K=50, Lq=32, Ld=180, D=128     bf16   0.338 ms   0.382 ms   1.13×    —         —       2.32×  —
padded_infer       HeavyRerank        B=32, K=100, Lq=32, Ld=256, D=128    fp16   0.335 ms   0.826 ms   2.47×    —         —       2.89×  —
padded_infer       HeavyRerank        B=32, K=100, Lq=32, Ld=256, D=128    bf16   0.374 ms   0.830 ms   2.22×    —         —       2.89×  —
packed_infer       PackedRerank       B=32, K=50, Lq=32, Ld=180, D=128     fp16   5.527 ms   0.380 ms   0.07×    0.328 ms  —       2.32×  —
packed_infer       PackedRerank       B=32, K=50, Lq=32, Ld=180, D=128     bf16   5.533 ms   0.382 ms   0.07×    0.338 ms  —       2.32×  —
packed_infer       PackedHeavyRerank  B=32, K=100, Lq=32, Ld=256, D=128    fp16   14.640 ms  0.825 ms   0.06×    0.333 ms  —       2.89×  —
packed_infer       PackedHeavyRerank  B=32, K=100, Lq=32, Ld=256, D=128    bf16   14.654 ms  0.830 ms   0.06×    0.372 ms  —       2.89×  —

Full A100 benchmark matrix

Surface            Preset             Shape                                dtype  maxsim     PyTorch    speedup  padded    bwd×    peak×  retained state
contrastive_train  Contrastive        Nq=32, Nb=32, Lq=32, Ld=80, D=128    fp16   0.378 ms   4.516 ms   11.94×   —         18.73×  1.78×  1/80
contrastive_train  Contrastive        Nq=32, Nb=32, Lq=32, Ld=80, D=128    bf16   0.378 ms   4.512 ms   11.95×   —         18.66×  1.78×  1/80
contrastive_train  LongDocs           Nq=32, Nb=32, Lq=32, Ld=512, D=128   fp16   0.422 ms   4.517 ms   10.69×   —         47.36×  3.24×  1/512
contrastive_train  LongDocs           Nq=32, Nb=32, Lq=32, Ld=512, D=128   bf16   0.424 ms   4.510 ms   10.64×   —         48.17×  3.24×  1/512
contrastive_train  BigBatch           Nq=64, Nb=64, Lq=32, Ld=128, D=128   fp16   0.505 ms   8.286 ms   16.41×   —         48.09×  4.15×  1/128
contrastive_train  BigBatch           Nq=64, Nb=64, Lq=32, Ld=128, D=128   bf16   0.515 ms   8.260 ms   16.04×   —         47.61×  4.15×  1/128
padded_infer       Rerank             B=32, K=50, Lq=32, Ld=180, D=128     fp16   0.543 ms   0.961 ms   1.77×    —         —       3.01×  —
padded_infer       Rerank             B=32, K=50, Lq=32, Ld=180, D=128     bf16   0.540 ms   0.964 ms   1.79×    —         —       3.01×  —
padded_infer       HeavyRerank        B=32, K=100, Lq=32, Ld=256, D=128    fp16   0.592 ms   2.469 ms   4.17×    —         —       3.30×  —
padded_infer       HeavyRerank        B=32, K=100, Lq=32, Ld=256, D=128    bf16   0.576 ms   2.475 ms   4.30×    —         —       3.30×  —
packed_infer       PackedRerank       B=32, K=50, Lq=32, Ld=180, D=128     fp16   8.976 ms   0.961 ms   0.11×    0.542 ms  —       3.01×  —
packed_infer       PackedRerank       B=32, K=50, Lq=32, Ld=180, D=128     bf16   8.970 ms   0.964 ms   0.11×    0.540 ms  —       3.01×  —
packed_infer       PackedHeavyRerank  B=32, K=100, Lq=32, Ld=256, D=128    fp16   24.709 ms  2.468 ms   0.10×    0.589 ms  —       3.30×  —
packed_infer       PackedHeavyRerank  B=32, K=100, Lq=32, Ld=256, D=128    bf16   24.718 ms  2.473 ms   0.10×    0.573 ms  —       3.30×  —

Packed rows compare score_pairs_packed against both the equivalent padded kernel layout and the PyTorch einsum baseline.

Apple Silicon (Metal) benchmark matrix

MaxSim also ships a Metal backend. The table below reports the same workloads as above, measured on the MPS build against the same naive PyTorch baseline. Peak-memory figures aren't captured on MPS, so the peak× column shows —.

Surface            Preset             Shape                                dtype  maxsim     PyTorch    speedup  padded    bwd×    peak×  retained state
contrastive_train  Contrastive        Nq=32, Nb=32, Lq=32, Ld=80, D=128    fp16   0.925 ms   4.945 ms   5.34×    —         —       —      1/80
contrastive_train  Contrastive        Nq=32, Nb=32, Lq=32, Ld=80, D=128    bf16   0.863 ms   4.738 ms   5.49×    —         —       —      1/80
contrastive_train  LongDocs           Nq=32, Nb=32, Lq=32, Ld=512, D=128   fp16   2.471 ms   15.966 ms  6.46×    —         61.52×  —      1/512
contrastive_train  LongDocs           Nq=32, Nb=32, Lq=32, Ld=512, D=128   bf16   2.455 ms   15.830 ms  6.45×    —         46.16×  —      1/512
contrastive_train  BigBatch           Nq=64, Nb=64, Lq=32, Ld=128, D=128   fp16   2.501 ms   15.527 ms  6.21×    —         56.26×  —      1/128
contrastive_train  BigBatch           Nq=64, Nb=64, Lq=32, Ld=128, D=128   bf16   2.485 ms   15.485 ms  6.23×    —         71.39×  —      1/128
padded_infer       Rerank             B=32, K=50, Lq=32, Ld=180, D=128     fp16   1.693 ms   6.229 ms   3.68×    —         —       —      —
padded_infer       Rerank             B=32, K=50, Lq=32, Ld=180, D=128     bf16   1.688 ms   5.867 ms   3.48×    —         —       —      —
padded_infer       HeavyRerank        B=32, K=100, Lq=32, Ld=256, D=128    fp16   3.255 ms   16.490 ms  5.07×    —         —       —      —
padded_infer       HeavyRerank        B=32, K=100, Lq=32, Ld=256, D=128    bf16   3.311 ms   16.440 ms  4.96×    —         —       —      —
packed_infer       PackedRerank       B=32, K=50, Lq=32, Ld=180, D=128     fp16   2.529 ms   5.897 ms   2.33×    1.697 ms  —       —      —
packed_infer       PackedRerank       B=32, K=50, Lq=32, Ld=180, D=128     bf16   2.566 ms   6.085 ms   2.37×    1.686 ms  —       —      —
packed_infer       PackedHeavyRerank  B=32, K=100, Lq=32, Ld=256, D=128    fp16   4.335 ms   16.524 ms  3.81×    3.324 ms  —       —      —
packed_infer       PackedHeavyRerank  B=32, K=100, Lq=32, Ld=256, D=128    bf16   4.306 ms   16.533 ms  3.84×    3.273 ms  —       —      —

Reproducing

just cuda-bench-matrix a100-large 3 30   # flavor, repeats, iters
just cuda-bench-matrix l40sx1     3 30   # optional additional GPU
just cuda-bench-matrix a10g-small 3 30   # optional additional GPU
just bench-matrix-metal           3 30   # optional Apple Silicon / MPS run
just bench-render                        # regenerate README tables from JSON

The CUDA bench recipe submits an HF Jobs run; bench-matrix-metal runs the same matrix against the Metal/MPS build. Both write one JSON artifact per device class to bench_results/v2/. The artifacts capture the kernel commit, device/torch/CUDA versions, timing + variance per workload, and a structured shape per row. just bench-render rewrites the tables above from those JSONs — same data, no copy-paste drift.

Supported

Backend  Devices                               Input dtypes        Backward accumulation
Metal    Apple Silicon (MPS, MSL 3.1+)         fp32 / fp16 / bf16  fp32
CUDA     sm_80, sm_86, sm_89, sm_90 (PTX JIT)  fp32 / fp16 / bf16  fp32

The table shows package-level support. A few API-specific constraints are listed below.

Limitations and edge cases

  • CUDA contrastive APIs (score_contrastive, score_contrastive_train, and score_contrastive_with_argmax) currently use the WMMA path directly. They require fp16/bf16 inputs, Lq % 16 == 0, and dim % 16 == 0.
  • Padded inference (score_candidates_padded) accepts fp32/fp16/bf16 and arbitrary Lq, Ld, and dim on both backends. CUDA uses the WMMA fast path for fp16/bf16 when Lq % 16 == 0 and dim % 16 == 0; otherwise it uses the scalar path.
  • Padded training on CUDA uses the argmax WMMA path and currently requires fp16/bf16 inputs, Lq >= 16, Lq % 16 == 0, and dim % 16 == 0.
  • Metal handles arbitrary Lq for padded and contrastive paths. Its simdgroup_matrix path fires when dim % 8 == 0; otherwise it uses scalar work. Very large Lq * dim shapes can hit Metal threadgroup-memory limits.
  • Packed CUDA paths are scalar today: score_pairs_packed uses a scalar score-only kernel, and packed training uses scalar argmax/backward kernels rather than WMMA. As a result, there is no performance win over native PyTorch on CUDA: in the H200 and A100 rerank benchmarks above, packed inference runs at 0.06–0.11× the speed of the einsum baseline. Use packed for arbitrary or sparse pair grids; for fixed-K reranking, prefer the padded API.
  • Backward kernels accumulate gradients with fp32 atomics. The results are numerically stable for training, but not bit-exact deterministic across repeated backward passes because floating-point additions can arrive in a different order.
  • NaN inputs do not follow PyTorch's exact torch.max NaN propagation semantics. Kernel comparisons use v > max, so NaN candidates are skipped. A NaN in a query token usually drives the whole score for that pair to -inf; a NaN in one document token may simply be ignored if another document token wins the max. Treat non-finite encoder outputs as upstream bugs and pre-screen with torch.isfinite(...) if you need strict behavior.
  • Hopper (sm_90) currently runs via PTX forward compatibility. Native WGMMA tuning is future work.
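The CUDA contrastive constraints and the NaN caveat above can be screened before calling the kernel. The helper below is a sketch of such a pre-check. Its name and error messages are made up, and it only encodes the constraints listed in this section (fp16/bf16 inputs, Lq % 16 == 0, dim % 16 == 0, finite values).

```python
import torch

def check_cuda_contrastive_inputs(queries: torch.Tensor, documents: torch.Tensor) -> list[str]:
    """Return a list of problems (hypothetical helper); empty means the listed constraints hold."""
    problems = []
    if queries.dtype not in (torch.float16, torch.bfloat16):
        problems.append(f"dtype {queries.dtype} is not fp16/bf16")
    if documents.dtype != queries.dtype:
        problems.append("documents dtype differs from queries dtype")
    _, Lq, dim = queries.shape                      # queries: [Nq, Lq, dim]
    if Lq % 16 != 0:
        problems.append(f"Lq={Lq} is not a multiple of 16")
    if dim % 16 != 0:
        problems.append(f"dim={dim} is not a multiple of 16")
    if not (torch.isfinite(queries).all() and torch.isfinite(documents).all()):
        problems.append("non-finite values in encoder outputs")
    return problems

ok = check_cuda_contrastive_inputs(
    torch.zeros(2, 32, 128, dtype=torch.float16),
    torch.zeros(10, 128, dtype=torch.float16),
)
bad = check_cuda_contrastive_inputs(
    torch.zeros(2, 30, 128),                        # fp32 and Lq=30 both violate the constraints
    torch.zeros(10, 128),
)
```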

Related projects and acknowledgements

  • Flash-MaxSim is a fused Triton MaxSim kernel for ColBERT/ColPali scoring.

Source

Source: https://github.com/erikkaum/maxsim.

Kernel repo: https://huggingface.co/kernels/erikkaum/maxsim.

License: Apache-2.0.
