| """ |
| Inference script for the 1B Transformer — Single GPU. |
| |
| Usage: |
| python inference.py # auto-finds latest checkpoint |
| python inference.py /path/to/checkpoint.pt # specific checkpoint |
| """ |
|
|
| import sys |
| import os |
| import glob |
| import time |
| import torch |
| import torch.nn.functional as F |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from model.config import ModelConfig |
| from model.transformer import Transformer |
| from model.data import get_tokenizer |
|
|
|
|
def find_latest_checkpoint(checkpoint_dir="/jfs/deepak-kumar/checkpoints"):
    """Return the path of the most recent checkpoint in ``checkpoint_dir``.

    Files matching ``step_*.pt`` are ranked by the integer that follows
    ``step_`` and the highest one wins. Matching files whose step field is not
    an integer (for example ``step_best.pt``) are skipped instead of aborting
    the search. When no numbered checkpoint is usable, ``final.pt`` is
    returned if it exists.

    Args:
        checkpoint_dir: Directory to search.

    Returns:
        The checkpoint path as a string, or ``None`` when nothing was found.
    """
    numbered = []
    for path in glob.glob(os.path.join(checkpoint_dir, "step_*.pt")):
        # "step_1200.pt" -> "1200". The glob pattern guarantees the "step_"
        # prefix, so index 1 always exists.
        step_field = os.path.basename(path).split("_")[1].split(".")[0]
        try:
            numbered.append((int(step_field), path))
        except ValueError:
            # Not a numbered checkpoint; ignore it rather than crash.
            continue

    if numbered:
        # Rank on the step only, so ties keep glob order.
        return max(numbered, key=lambda entry: entry[0])[1]

    final = os.path.join(checkpoint_dir, "final.pt")
    return final if os.path.exists(final) else None
|
|
|
|
def load_model(checkpoint_path, device="cuda:0"):
    """Build the Transformer, restore its weights and prepare it for inference.

    The checkpoint is deserialised on the CPU, its ``"model"`` entry is loaded
    into a freshly constructed ``Transformer(ModelConfig())``, and the model is
    then moved to ``device``, cast to bfloat16 and switched to eval mode. A
    short summary (step, loss, parameter count, device) is printed.

    Args:
        checkpoint_path: Path of the ``.pt`` file to load.
        device: Target device for the model.

    Returns:
        A ``(model, config)`` tuple.
    """
    cfg = ModelConfig()
    net = Transformer(cfg)

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE: weights_only=False runs the full unpickler, which can execute
    # arbitrary code -- only load checkpoints from a trusted source.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    net.load_state_dict(state["model"])

    net = net.to(device)
    net = net.bfloat16()
    net = net.eval()

    # Older checkpoints may lack these bookkeeping fields, hence the "?".
    trained_step = state.get("step", "?")
    last_loss = state.get("loss", "?")
    n_params = sum(p.numel() for p in net.parameters())
    for line in (
        f" Step: {trained_step} | Loss: {last_loss}",
        f" Params: {n_params:,}",
        f" Device: {device}",
    ):
        print(line)

    # Drop the checkpoint dict before inference starts to free its memory.
    del state
    torch.cuda.empty_cache()
    return net, cfg
|
|
|
|
def _sample_next_token(logits, temperature, top_k, top_p):
    """Choose one token id per row from ``logits`` of shape (batch, vocab).

    ``temperature <= 0`` selects the arg-max (greedy decoding). Otherwise the
    logits are temperature-scaled, optionally restricted to the ``top_k``
    highest entries and/or the smallest prefix of the sorted distribution
    whose preceding probability mass is below ``top_p``, and then sampled.

    Returns:
        An integer tensor of shape (batch, 1).
    """
    if temperature <= 0:
        # Dividing by zero would turn every logit into inf/nan.
        return torch.argmax(logits, dim=-1, keepdim=True)

    logits = logits / temperature

    if top_k > 0:
        # torch.topk raises when k exceeds the vocabulary size.
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k).values[:, -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Probability mass strictly before each token in sorted order.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        drop = mass_before >= top_p
        # Always keep the best token so top_p <= 0 cannot mask everything.
        drop[:, 0] = False
        sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))
        # Undo the sort: put every logit back at its vocabulary index.
        logits = torch.empty_like(sorted_logits).scatter(
            1, sorted_idx, sorted_logits
        )

    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)


@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=200,
             temperature=0.8, top_k=50, top_p=0.9, device="cuda:0"):
    """Autoregressively extend ``prompt`` with tokens from ``model``.

    The whole sequence is fed through the model again on every step (no KV
    cache). Generation stops after ``max_new_tokens`` tokens, when the
    sequence reaches ``model.config.max_seq_len``, or when the end-of-sequence
    token is drawn (that token is not appended).

    Args:
        model: Callable returning ``(logits, _)`` for a batch of token ids and
            exposing ``config.max_seq_len``.
        tokenizer: Provides ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Text to continue. Only a single sequence is supported, since
            the EOS check calls ``.item()`` on the sampled token.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Softmax temperature; ``<= 0`` means greedy decoding.
        top_k: Keep only the ``top_k`` best logits; ``0`` disables the filter.
        top_p: Nucleus threshold; ``>= 1.0`` disables the filter.
        device: Device the token ids are placed on.

    Returns:
        ``(text, gen_tokens, tok_per_sec)``: the decoded prompt plus
        continuation, the number of generated tokens, and the throughput.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    # Remember the prompt length now instead of tokenising the prompt again.
    prompt_len = input_ids.shape[1]
    # Autocast must target the device the tensors actually live on.
    autocast_device = input_ids.device.type
    max_len = model.config.max_seq_len
    t0 = time.time()

    for _step in range(max_new_tokens):
        if input_ids.shape[1] >= max_len:
            break

        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            logits, _ = model(input_ids)

        # Sample in float32: bf16 is too coarse for softmax/cumsum over the
        # vocabulary.
        next_token = _sample_next_token(
            logits[:, -1, :].float(), temperature, top_k, top_p
        )

        if next_token.item() == tokenizer.eos_token_id:
            break

        input_ids = torch.cat([input_ids, next_token], dim=1)

    elapsed = time.time() - t0
    gen_tokens = input_ids.shape[1] - prompt_len
    tok_per_sec = gen_tokens / max(elapsed, 1e-9)

    text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return text, gen_tokens, tok_per_sec
|
|
|
|
def main():
    """Run a fixed set of sample prompts through a checkpoint and print them.

    The checkpoint path comes from the first command-line argument when one is
    given; otherwise the newest checkpoint found by
    ``find_latest_checkpoint()`` is used. The process exits with status 1 when
    no checkpoint is available.
    """
    device = "cuda:0"

    checkpoint = sys.argv[1] if len(sys.argv) > 1 else find_latest_checkpoint()
    if checkpoint is None:
        print("No checkpoint found!")
        sys.exit(1)

    model, _config = load_model(checkpoint, device)
    tokenizer = get_tokenizer()

    sample_prompts = (
        "The meaning of life is",
        "In machine learning, a neural network",
        "The capital of France is",
        "Once upon a time, there was a",
        "To solve a quadratic equation, you need to",
        "The theory of relativity explains that",
        "Python is a programming language that",
        "The sun rises in the east and",
    )

    banner = "=" * 70
    rule = "─" * 60

    print("\n" + banner)
    print(" INFERENCE — 1B Transformer (Single GPU)")
    print(banner)

    for prompt in sample_prompts:
        print(f"\n{rule}")
        print(f"PROMPT: {prompt}")
        print(f"{rule}")
        full_text, token_count, rate = generate(
            model,
            tokenizer,
            prompt,
            max_new_tokens=150,
            temperature=0.8,
            top_k=50,
            device=device,
        )
        # The decoded text starts with the prompt; show only the continuation.
        continuation = full_text[len(prompt):]
        print(f"OUTPUT:{continuation}")
        print(f" [{token_count} tokens, {rate:.1f} tok/s]")

    print("\n" + banner)


# Script entry point: only run the demo when executed directly.
if __name__ == "__main__":
    main()
|
|