File size: 1,194 Bytes
f86dc09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#!/usr/bin/env python3
"""Minimal CPU chat with tilelli_chat_v4.pt — what the README points new users at.

Uses TilelliLiteLM.generate_with_cache so long prompts + replies stay within
the 256-byte context window. Greedy decoding, deliberately tiny."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent / "src"))

import torch
from tilelli.eval.metacog_probe import load_bridge

# Bundled checkpoint, resolved relative to this script rather than the CWD.
CKPT = Path(__file__).parent / "checkpoints" / "tilelli_chat_v4.pt"
# The user message is the first CLI argument; fall back to a canned greeting.
MSG = sys.argv[1] if len(sys.argv) > 1 else "Hello, who are you?"
PROMPT = f"USER: {MSG}\nTILELLI:"
# Upper bound on the number of tokens generated for the reply.
MAX_NEW = 120
# Slack kept free in the context window on top of the prompt and the reply.
_CTX_MARGIN = 4

# load_bridge returns (model, abstain-related object, tokenizer); only the
# model and the tokenizer are used here.
model, _abstain, tok = load_bridge(str(CKPT))
# Encode the prompt and add a leading batch dimension of size 1.
ids = tok.encode(PROMPT).long().unsqueeze(0)

# Trim the prompt from the left so the prompt + reply stays within the
# 256-byte context window the bundled v4 was trained on.
max_ctx = getattr(model, "max_seq_len", 256)
# Clamp both the reply length and the prompt budget to at least 1. Without the
# clamp, a model whose max_seq_len is <= MAX_NEW + 4 yields a zero or negative
# budget, and `ids[:, -budget:]` then either keeps the whole prompt (-0) or
# drops tokens from the front instead of keeping the tail. With the default
# 256 window these evaluate to 120 and 132, the same values as before.
max_new = max(1, min(MAX_NEW, max_ctx - _CTX_MARGIN - 1))
budget = max(1, max_ctx - max_new - _CTX_MARGIN)
if ids.size(1) > budget:
    # Keep the most recent `budget` tokens; the oldest ones are discarded.
    ids = ids[:, -budget:]

# Stop on newline (10) or null (0). generate_with_cache handles the rest.
with torch.no_grad():  # inference only, no autograd bookkeeping needed
    full, _generated, _confs = model.generate_with_cache(
        ids, n_new_tokens=max_new, stop_ids=(10, 0)
    )

# `full` is decoded as a whole and printed.
# NOTE(review): presumably `full` is prompt + reply, so the prompt is echoed
# too; confirm against generate_with_cache.
print(tok.decode(full[0].tolist()))