Trinity-Nano Pre-Compiled for AWS Inferentia2 (TP=1)
Pre-compiled and pre-sharded Trinity-Nano-Preview (MoE, ~6B total / ~1B active parameters) for AWS Neuron SDK 2.28, ready to load on inf2.xlarge (16 GB system RAM) or any larger Inferentia2/Trainium instance.
Why Pre-Sharded?
The standard NxDI load path downloads the full HuggingFace checkpoint (~12GB bf16) into CPU RAM for weight conversion and sharding. On inf2.xlarge (16GB system RAM), this causes an OOM kill at 15+ GB RSS.
Pre-sharded weights bypass this step entirely: NxDI reads directly from the per-rank sharded files, with a peak RSS of only 1.4 GB (12.6% of system RAM).
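A back-of-the-envelope check of why the full checkpoint does not fit, assuming ~6B parameters (the approximate figure from this card, not an exact count) stored at 2 bytes each in bf16. One full copy already takes about three quarters of system RAM, so any temporary buffers needed during conversion plausibly push the process past 15 GB:

```python
BYTES_PER_BF16 = 2  # bfloat16 stores each parameter in 16 bits


def checkpoint_gb(num_params: float, bytes_per_param: int = BYTES_PER_BF16) -> float:
    """Approximate checkpoint size in GB (10**9 bytes), ignoring file metadata."""
    return num_params * bytes_per_param / 1e9


SYSTEM_RAM_GB = 16  # inf2.xlarge system memory, as stated above
full_gb = checkpoint_gb(6e9)  # ~6B total parameters (approximate)
print(f"bf16 checkpoint: ~{full_gb:.0f} GB, {full_gb / SYSTEM_RAM_GB:.0%} of system RAM")
```

This prints `bf16 checkpoint: ~12 GB, 75% of system RAM`, in line with the ~12 GB checkpoint size quoted above.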
Contents
| File | Size | Description |
|---|---|---|
| `model.pt` | 49 MB | Compiled Neuron NEFF graphs |
| `neuron_config.json` | 9 KB | NxDI configuration (TP=1, BS=1, seq_len=2048, bf16) |
| `weights/tp0_sharded_checkpoint.safetensors` | 12 GB | Pre-sharded model weights for rank 0 |
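A safetensors file starts with an 8-byte little-endian header length followed by a JSON header listing each tensor's dtype, shape, and byte offsets. Because of that layout, you can inventory the 12 GB shard without reading any weight data into RAM. The stdlib-only sketch below relies on that format alone; `read_safetensors_header` and `summarize` are hypothetical helpers, not part of NxDI:

```python
import json
import math
import struct
from typing import Dict, Tuple


def read_safetensors_header(path: str) -> Dict[str, dict]:
    """Return the JSON header of a .safetensors file without reading tensor data."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))  # u64, little-endian
        header = json.loads(f.read(header_len))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header


def summarize(path: str) -> Tuple[int, int, Dict[str, int]]:
    """Return (tensor count, total element count, element count per dtype)."""
    header = read_safetensors_header(path)
    per_dtype: Dict[str, int] = {}
    total = 0
    for info in header.values():
        n = math.prod(info["shape"])  # scalar tensors have shape [] -> 1 element
        per_dtype[info["dtype"]] = per_dtype.get(info["dtype"], 0) + n
        total += n
    return len(header), total, per_dtype
```

For example, `summarize("/home/ubuntu/Trinity-Nano-Neuron-TP1/weights/tp0_sharded_checkpoint.safetensors")` (the download path used in the Quick Start) returns the tensor count and per-dtype element counts while reading only the header.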
Performance
Measured on inf2.xlarge (TP=1 on a single NeuronCore, 16 GB system RAM):
| Metric | Value |
|---|---|
| TTFT | 706 ms |
| TKG (per token) | 9.0 ms |
| Throughput | 112 tok/s |
| Load time | 18.4 s |
| Peak RSS | 1.39 GB |
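These metrics combine into a first-order latency model for one greedy request, assuming the first token costs one TTFT and every later token one TKG step. The model ignores sampling and host overhead, and TTFT depends on prompt length, which the table does not state:

```python
TTFT_MS = 706.0  # time to first token (prefill), from the table above
TKG_MS = 9.0     # per-token decode latency, from the table above


def estimated_latency_ms(new_tokens: int, ttft_ms: float = TTFT_MS, tkg_ms: float = TKG_MS) -> float:
    """Estimated wall time (ms) to produce `new_tokens` tokens for one request."""
    if new_tokens <= 0:
        return 0.0
    return ttft_ms + (new_tokens - 1) * tkg_ms


def decode_tokens_per_s(tkg_ms: float = TKG_MS) -> float:
    """Steady-state decode rate implied by the per-token latency."""
    return 1000.0 / tkg_ms
```

For 32 new tokens this gives 706 + 31 × 9.0 = 985 ms. The implied decode rate of about 111 tok/s is consistent with the reported 112 tok/s, given that TKG is rounded to 0.1 ms.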
Quick Start
Prerequisites
- AWS instance with Inferentia2: inf2.xlarge, inf2.8xlarge, or larger
- Deep Learning AMI Neuron (Ubuntu 24.04) 20260227 (SDK 2.28)
- Activate the pre-installed venv:

  ```bash
  source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13/bin/activate
  ```
1. Clone the model implementation
The Trinity Neuron implementation is not yet merged into the main NxDI repo. Use the contrib branch from the fork:
```bash
git clone --branch contrib/trinity-model --single-branch \
    https://github.com/jimburtoft/neuronx-distributed-inference.git nxdi-trinity
```
2. Download this artifact and the base model config/tokenizer
```python
from huggingface_hub import snapshot_download

# Download the pre-compiled artifact (model.pt + sharded weights)
snapshot_download("jburtoft/Trinity-Nano-Neuron-TP1",
                  local_dir="/home/ubuntu/Trinity-Nano-Neuron-TP1")

# Download config + tokenizer only (no model weights needed)
snapshot_download("arcee-ai/Trinity-Nano-Preview",
                  local_dir="/home/ubuntu/Trinity-Nano-Preview",
                  ignore_patterns=["*.safetensors", "*.bin", "*.pt", "*.gguf"])
```
3. Load and run inference
```python
import sys

import torch
from transformers import AutoTokenizer
from neuronx_distributed_inference.models.config import MoENeuronConfig

# Point to the Trinity implementation from the cloned repo
sys.path.insert(0, "/home/ubuntu/nxdi-trinity/contrib/models/Trinity/src")
from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig

# Build model with save_sharded_checkpoint=True (must match compilation)
neuron_config = MoENeuronConfig(
    tp_degree=1,
    batch_size=1,
    seq_len=2048,
    torch_dtype=torch.bfloat16,
    save_sharded_checkpoint=True,
)
config = TrinityInferenceConfig.from_pretrained(
    "/home/ubuntu/Trinity-Nano-Preview",
    neuron_config=neuron_config,
)
model = NeuronTrinityForCausalLM("/home/ubuntu/Trinity-Nano-Preview", config)
# Load the compiled graphs and pre-sharded weights from the downloaded artifact
model.load("/home/ubuntu/Trinity-Nano-Neuron-TP1")

# Tokenize
tokenizer = AutoTokenizer.from_pretrained(
    "/home/ubuntu/Trinity-Nano-Preview", trust_remote_code=True
)
prompt = "Hello, how are you today?"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids

# Generate: run the whole prompt once (prefill) and pick the next token greedily
model.reset()  # reset generation state before a new request
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
seq_ids = torch.arange(1)  # batch slot 0 (batch_size=1)
with torch.no_grad():
    outputs = model(input_ids, position_ids=position_ids, seq_ids=seq_ids)
logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
next_token = torch.argmax(logits[:, -1, :], dim=-1)
print(f"Prompt: {prompt}")
print(f"Next token: {tokenizer.decode(next_token)}")

# Autoregressive generation: feed one token per step (1 + 31 = 32 new tokens)
generated = [next_token.unsqueeze(0)]
for i in range(31):
    # The token being fed sits right after the prompt plus i earlier generated tokens
    pos = torch.tensor([[input_ids.shape[1] + i]])
    with torch.no_grad():
        outputs = model(generated[-1], position_ids=pos, seq_ids=seq_ids)
    logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
    next_token = torch.argmax(logits[:, -1, :], dim=-1)
    generated.append(next_token.unsqueeze(0))

text = tokenizer.decode(torch.cat(generated, dim=1)[0], skip_special_tokens=True)
print(f"Generated: {text}")
```
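The prefill-then-decode loop above can be factored into a reusable helper that also stops at an end-of-sequence token. This is a sketch that assumes the model call signature used above; `greedy_generate` and `make_neuron_step` are hypothetical names. The decoding loop is plain Python, so it can be exercised with a stub instead of a Neuron device, and `torch` is imported only inside the adapter:

```python
from typing import Callable, List, Optional

# step(new_token_ids, start_position) -> next token id
StepFn = Callable[[List[int], int], int]


def greedy_generate(step: StepFn, prompt_ids: List[int], max_new_tokens: int,
                    eos_token_id: Optional[int] = None) -> List[int]:
    """Prefill with the whole prompt, then decode one token per step."""
    generated: List[int] = []
    if max_new_tokens <= 0:
        return generated
    next_id = step(prompt_ids, 0)  # prefill: prompt occupies positions 0..len-1
    position = len(prompt_ids)     # position of the first generated token
    while True:
        generated.append(next_id)
        if next_id == eos_token_id or len(generated) >= max_new_tokens:
            break
        next_id = step([next_id], position)  # decode: feed the latest token
        position += 1
    return generated


def make_neuron_step(model) -> StepFn:
    """Adapt the NxDI model call shown above to the StepFn interface."""
    import torch  # imported lazily so greedy_generate works without torch

    seq_ids = torch.arange(1)

    def step(token_ids: List[int], start_pos: int) -> int:
        input_ids = torch.tensor([token_ids])
        position_ids = torch.arange(start_pos, start_pos + len(token_ids)).unsqueeze(0)
        with torch.no_grad():
            outputs = model(input_ids, position_ids=position_ids, seq_ids=seq_ids)
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        return int(torch.argmax(logits[:, -1, :], dim=-1).item())

    return step
```

Usage, after `model.reset()`: `new_ids = greedy_generate(make_neuron_step(model), input_ids[0].tolist(), 32, tokenizer.eos_token_id)`, then `tokenizer.decode(new_ids, skip_special_tokens=True)`. Keep the prompt length plus `max_new_tokens` within the compiled `seq_len` of 2048.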
Compilation Details
| Parameter | Value |
|---|---|
| SDK | 2.28 (NxDI 0.8.16251, neuronx-cc 2.23.6484, torch-neuronx 2.9.0.2.12) |
| TP degree | 1 |
| Batch size | 1 |
| Sequence length | 2048 |
| Dtype | bfloat16 |
| `save_sharded_checkpoint` | True |
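Compiled Neuron artifacts may not load under a different SDK than the one that produced them. A quick check that your environment matches the versions in this table can therefore save you from a confusing load failure. The sketch below uses `importlib.metadata`. The pip distribution names are assumptions, and because installed versions can carry extra suffixes, it accepts an exact match or a `.`/`+` extension:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Versions from the Compilation Details table. The pip distribution names are
# assumptions; adjust them if your environment names the packages differently.
EXPECTED_VERSIONS = {
    "neuronx-distributed-inference": "0.8.16251",
    "neuronx-cc": "2.23.6484",
    "torch-neuronx": "2.9.0.2.12",
}


def version_matches(installed: str, expected: str) -> bool:
    """True if `installed` equals `expected` or extends it with a '.' or '+' suffix."""
    return installed == expected or installed.startswith((expected + ".", expected + "+"))


def version_mismatches(expected: Dict[str, str] = EXPECTED_VERSIONS) -> Dict[str, Optional[str]]:
    """Map each mismatched distribution to its installed version (None if missing)."""
    mismatches: Dict[str, Optional[str]] = {}
    for dist, want in expected.items():
        try:
            got: Optional[str] = version(dist)
        except PackageNotFoundError:
            got = None
        if got is None or not version_matches(got, want):
            mismatches[dist] = got
    return mismatches
```

Running `print(version_mismatches() or "versions match")` inside the activated venv lists any package that differs from the compilation environment.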
Compiling Your Own
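To compile for different configurations (e.g., TP=2, BS=4), use an instance with more system RAM, such as inf2.8xlarge or trn2.3xlarge. Compiling with `save_sharded_checkpoint=True` converts and shards the full checkpoint in CPU memory, which does not fit on inf2.xlarge: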
```python
import sys

import torch
from neuronx_distributed_inference.models.config import MoENeuronConfig

sys.path.insert(0, "/path/to/nxdi-trinity/contrib/models/Trinity/src")
from modeling_trinity import NeuronTrinityForCausalLM, TrinityInferenceConfig

neuron_config = MoENeuronConfig(
    tp_degree=1,       # Adjust as needed
    batch_size=1,      # Adjust as needed
    seq_len=2048,      # Adjust as needed
    torch_dtype=torch.bfloat16,
    save_sharded_checkpoint=True,  # Required for pre-sharded deployment
)
config = TrinityInferenceConfig.from_pretrained(
    "/path/to/Trinity-Nano-Preview", neuron_config=neuron_config
)
model = NeuronTrinityForCausalLM("/path/to/Trinity-Nano-Preview", config)
model.compile("/path/to/compiled-output")
# Output: model.pt, neuron_config.json, weights/tp{rank}_sharded_checkpoint.safetensors
```
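Based on the output layout in the comment above, a compiled directory for TP degree N should contain one sharded checkpoint per rank. The sketch below checks a compile (or download) directory before you call `model.load()`. The file-name pattern comes from that comment, and the helper names are hypothetical:

```python
from pathlib import Path
from typing import List


def expected_artifacts(tp_degree: int) -> List[str]:
    """Relative paths expected in a compiled output directory."""
    if tp_degree < 1:
        raise ValueError("tp_degree must be >= 1")
    files = ["model.pt", "neuron_config.json"]
    files += [f"weights/tp{rank}_sharded_checkpoint.safetensors" for rank in range(tp_degree)]
    return files


def missing_artifacts(output_dir: str, tp_degree: int) -> List[str]:
    """Return expected files absent from output_dir (an empty list means complete)."""
    root = Path(output_dir)
    return [rel for rel in expected_artifacts(tp_degree) if not (root / rel).is_file()]
```

For the artifact in this repo, `missing_artifacts("/home/ubuntu/Trinity-Nano-Neuron-TP1", 1)` should return an empty list once step 2 of the Quick Start has finished.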
Base Model
- Model: arcee-ai/Trinity-Nano-Preview
- Architecture: MoE (128 experts, top-8 active, 1 shared expert)
- Parameters: ~6B total, ~1B active per token
- License: Apache 2.0
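To make "top-8 of 128 experts plus 1 shared expert" concrete: for each token a router scores all 128 routed experts, only the 8 highest-scoring ones run, and the shared expert runs for every token. This is why only ~1B of the ~6B parameters are active per token. The sketch shows generic top-k routing with a softmax over the selected scores. It is an illustration of the technique, not Trinity's actual router, whose scoring and normalization may differ:

```python
import math
from typing import List, Tuple


def top_k_route(router_logits: List[float], k: int = 8) -> List[Tuple[int, float]]:
    """Select the k highest-scoring experts and softmax-normalize their scores.

    Returns (expert_index, weight) pairs in descending score order; weights sum to 1.
    """
    if not 1 <= k <= len(router_logits):
        raise ValueError("k must be between 1 and the number of experts")
    top = sorted(range(len(router_logits)), key=lambda i: router_logits[i], reverse=True)[:k]
    m = max(router_logits[i] for i in top)  # subtract max for numerical stability
    exps = [math.exp(router_logits[i] - m) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]
```

A token's routed-expert output is then the weighted sum of the 8 selected experts' outputs, with the shared expert's output added unconditionally.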
Model Implementation
The NeuronX Distributed Inference implementation for Trinity is available at:
github.com/jimburtoft/neuronx-distributed-inference (branch: contrib/trinity-model)
This implementation supports all three Trinity model sizes (Nano, Mini, Large) with a single unified modeling_trinity.py.