Trinity-Nano Pre-Compiled for AWS Inferentia2 (TP=1)

Pre-compiled and pre-sharded Trinity-Nano-Preview (~6B total, ~1B active MoE) for AWS Neuron SDK 2.28, ready to load on inf2.xlarge (16GB system RAM) or any larger Inferentia2/Trainium instance.

Why Pre-Sharded?

The standard NxDI load path downloads the full HuggingFace checkpoint (~12 GB bf16) into CPU RAM for weight conversion and sharding. On inf2.xlarge (16 GB system RAM), the process is OOM-killed once RSS climbs past ~15 GB.

Pre-sharded weights bypass this entirely: NxDI reads directly from the per-rank sharded files, peaking at about 1.4 GB RSS (under 9% of the 16 GB system RAM).

Contents

| File | Size | Description |
|------|------|-------------|
| model.pt | 49 MB | Compiled Neuron NEFF graphs |
| neuron_config.json | 9 KB | NxDI configuration (TP=1, BS=1, seq_len=2048, bf16) |
| weights/tp0_sharded_checkpoint.safetensors | 12 GB | Pre-sharded model weights for rank 0 |

Performance

Measured on inf2.xlarge (1 NeuronCore, 16GB system RAM):

| Metric | Value |
|--------|-------|
| TTFT | 706 ms |
| TKG (per token) | 9.0 ms |
| Throughput | 112 tok/s |
| Load time | 18.4 s |
| Peak RSS | 1.39 GB |
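The throughput figure is consistent with the per-token (TKG) latency, as a quick cross-check shows:

```python
# Decode throughput implied by the 9.0 ms per-token (TKG) latency.
tkg_ms = 9.0
tokens_per_s = 1000 / tkg_ms  # ~111 tok/s, in line with the measured 112 tok/s
print(f"{tokens_per_s:.1f} tok/s")
```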

Quick Start

Prerequisites

1. Clone the model implementation

The Trinity Neuron implementation is not yet merged into the main NxDI repo. Use the contrib branch from the fork:

git clone --branch contrib/trinity-model --single-branch \
    https://github.com/jimburtoft/neuronx-distributed-inference.git nxdi-trinity

2. Download this artifact and the base model config/tokenizer

from huggingface_hub import snapshot_download

# Download the pre-compiled artifact (model.pt + sharded weights)
snapshot_download("jburtoft/Trinity-Nano-Neuron-TP1",
                  local_dir="/home/ubuntu/Trinity-Nano-Neuron-TP1")

# Download config + tokenizer only (no model weights needed)
snapshot_download("arcee-ai/Trinity-Nano-Preview",
                  local_dir="/home/ubuntu/Trinity-Nano-Preview",
                  ignore_patterns=["*.safetensors", "*.bin", "*.pt", "*.gguf"])

3. Load and run inference

import sys
import torch
from transformers import AutoTokenizer
from neuronx_distributed_inference.models.config import MoENeuronConfig

# Point to the Trinity implementation from the cloned repo
sys.path.insert(0, "/home/ubuntu/nxdi-trinity/contrib/models/Trinity/src")
from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig

# Build model with save_sharded_checkpoint=True (must match compilation)
neuron_config = MoENeuronConfig(
    tp_degree=1,
    batch_size=1,
    seq_len=2048,
    torch_dtype=torch.bfloat16,
    save_sharded_checkpoint=True,
)

config = TrinityInferenceConfig.from_pretrained(
    "/home/ubuntu/Trinity-Nano-Preview",
    neuron_config=neuron_config,
)

model = NeuronTrinityForCausalLM("/home/ubuntu/Trinity-Nano-Preview", config)
model.load("/home/ubuntu/Trinity-Nano-Neuron-TP1")

# Tokenize
tokenizer = AutoTokenizer.from_pretrained(
    "/home/ubuntu/Trinity-Nano-Preview", trust_remote_code=True
)

prompt = "Hello, how are you today?"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids

# Generate
model.reset()
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
seq_ids = torch.arange(1)

with torch.no_grad():
    outputs = model(input_ids, position_ids=position_ids, seq_ids=seq_ids)

logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
next_token = torch.argmax(logits[:, -1, :], dim=-1)
print(f"Prompt: {prompt}")
print(f"Next token: {tokenizer.decode(next_token)}")

# Autoregressive generation
generated = [next_token.unsqueeze(0)]
for i in range(31):
    pos = torch.tensor([[input_ids.shape[1] + i]])
    with torch.no_grad():
        outputs = model(generated[-1], position_ids=pos, seq_ids=seq_ids)
    logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
    next_token = torch.argmax(logits[:, -1, :], dim=-1)
    generated.append(next_token.unsqueeze(0))

text = tokenizer.decode(torch.cat(generated, dim=1)[0], skip_special_tokens=True)
print(f"Generated: {text}")
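The loop above uses greedy decoding (argmax). Swapping in temperature sampling is a small change; a minimal sketch on a dummy logits tensor (the helper name is mine, and in real use you would pass logits[:, -1, :] from the model output above):

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 0.8) -> torch.Tensor:
    """Temperature sampling over logits of shape [batch, vocab]."""
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

# Demo on a dummy 5-token vocabulary; index 2 has by far the highest logit,
# so it is sampled almost every time at this temperature.
torch.manual_seed(0)
dummy = torch.tensor([[0.1, 0.2, 5.0, 0.3, 0.4]])
tok = sample_next_token(dummy)
print(tok.shape)  # torch.Size([1])
```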

Compilation Details

| Parameter | Value |
|-----------|-------|
| SDK | 2.28 (NxDI 0.8.16251, neuronx-cc 2.23.6484, torch-neuronx 2.9.0.2.12) |
| TP degree | 1 |
| Batch size | 1 |
| Sequence length | 2048 |
| Dtype | bfloat16 |
| save_sharded_checkpoint | True |

Compiling Your Own

To compile for different configurations (e.g., TP=2, BS=4), you need a larger instance (inf2.8xlarge or trn2.3xlarge):

import sys
import torch
from neuronx_distributed_inference.models.config import MoENeuronConfig

sys.path.insert(0, "/path/to/nxdi-trinity/contrib/models/Trinity/src")
from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig

neuron_config = MoENeuronConfig(
    tp_degree=1,       # Adjust as needed
    batch_size=1,      # Adjust as needed
    seq_len=2048,      # Adjust as needed
    torch_dtype=torch.bfloat16,
    save_sharded_checkpoint=True,  # Required for pre-sharded deployment
)

config = TrinityInferenceConfig.from_pretrained(
    "/path/to/Trinity-Nano-Preview", neuron_config=neuron_config
)
model = NeuronTrinityForCausalLM("/path/to/Trinity-Nano-Preview", config)
model.compile("/path/to/compiled-output")
# Output: model.pt, neuron_config.json, weights/tp{rank}_sharded_checkpoint.safetensors
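Compiling with a higher TP degree produces one sharded checkpoint per rank, following the tp{rank} pattern in the output comment above. A small sketch of the expected layout (the helper function is mine, for illustration):

```python
# Expected per-rank shard filenames under the compiled-output directory.
def shard_filenames(tp_degree: int) -> list[str]:
    return [f"weights/tp{rank}_sharded_checkpoint.safetensors"
            for rank in range(tp_degree)]

print(shard_filenames(1))  # ['weights/tp0_sharded_checkpoint.safetensors']
print(shard_filenames(2))
```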

Base Model

  • Model: arcee-ai/Trinity-Nano-Preview
  • Architecture: MoE (128 experts, top-8 active, 1 shared expert)
  • Parameters: ~6B total, ~1B active per token
  • License: Apache 2.0

Model Implementation

The NeuronX Distributed Inference implementation for Trinity is available at: github.com/jimburtoft/neuronx-distributed-inference (branch: contrib/trinity-model)

This implementation supports all three Trinity model sizes (Nano, Mini, Large) with a single unified modeling_trinity.py.
