"""
test_chatmodel.py — Interactive CLI chat and evaluation for the fine-tuned SLLM chat model.

Usage:
    python test_chatmodel.py --run_dir runs/sllm_150m_chat
    python test_chatmodel.py --run_dir runs/sllm_150m_chat --mode sample

In interactive mode:
    Type your message and press Enter.
    Special commands:
        /reset             Clear conversation history
        /system <prompt>   Change the system prompt (bare /system shows the current one)
        /quit              Exit the chat
"""

import argparse
import math
import os
import re
import sys
from pathlib import Path

import torch
import torch.nn as nn
from torch.amp import autocast
from transformers import PreTrainedTokenizerFast

# Add project root to path so the local `model` package is importable.
PROJECT_ROOT = Path(__file__).resolve().parent
sys.path.insert(0, str(PROJECT_ROOT))

from model.config import SLLM_150M  # noqa: E402
from model.model import SLLM  # noqa: E402

DEFAULT_SYSTEM = "You are a helpful, concise assistant."
DEFAULT_RUN_DIR = str(PROJECT_ROOT / "runs" / "sllm_150m_chat")


# ------------------------------------------------------------------ #
# HELPERS
# ------------------------------------------------------------------ #

def _ckpt_sort_key(name: str) -> list:
    """Natural-sort key: digit runs compare as integers, everything else as text."""
    # With a capturing group, re.split puts the digit runs at the odd indices,
    # so str/int positions always line up between two keys.
    return [int(part) if i % 2 else part for i, part in enumerate(re.split(r"(\d+)", name))]


def find_latest_ckpt(run_dir: str) -> str:
    """Return the path of the highest-ranked ``ckpt_*.pt`` file in ``run_dir``.

    Files are ordered with a natural sort, so ``ckpt_1000.pt`` ranks after
    ``ckpt_200.pt`` (a plain string sort gets this wrong). ``ckpt_sft_*`` files
    rank above plain ``ckpt_<step>*`` files because ``"ckpt_sft_"`` sorts after
    ``"ckpt_"``.

    Raises:
        FileNotFoundError: if ``run_dir`` is missing or holds no checkpoints.
    """
    if not os.path.isdir(run_dir):
        raise FileNotFoundError(f"Run directory '{run_dir}' does not exist.")

    # "ckpt_sft_*" already matches the "ckpt_" prefix, so one test covers both.
    ckpts = sorted(
        (f for f in os.listdir(run_dir) if f.startswith("ckpt_") and f.endswith(".pt")),
        key=_ckpt_sort_key,
    )
    if not ckpts:
        raise FileNotFoundError(
            f"No checkpoints found in '{run_dir}'.\n"
            f"Please ensure you have trained the model or point to the correct folder."
        )
    return os.path.join(run_dir, ckpts[-1])


def resize_token_embeddings(model: SLLM, new_vocab_size: int):
    """Resize the token-embedding matrix (and the tied LM head) to ``new_vocab_size`` rows.

    Rows present in both the old and new matrix are copied over. Rows added
    beyond the old size are initialised to the mean of the old embeddings.
    Shrinking keeps the first ``new_vocab_size`` rows. ``model.config.vocab_size``
    is updated to match. Does nothing if the size is unchanged.
    """
    old_size = model.config.vocab_size
    if new_vocab_size == old_size:
        return

    d_model = model.config.d_model
    old_weight = model.token_emb.weight.detach()

    # Start every row at the mean embedding, then copy the surviving old rows.
    new_weight = old_weight.mean(dim=0, keepdim=True).repeat(new_vocab_size, 1)
    n_keep = min(old_size, new_vocab_size)
    new_weight[:n_keep] = old_weight[:n_keep]

    new_emb = nn.Embedding(new_vocab_size, d_model, device=old_weight.device, dtype=old_weight.dtype)
    new_emb.weight.data = new_weight
    model.token_emb = new_emb
    # NOTE(review): assumes lm_head is weight-tied to token_emb with no vocab-sized
    # bias — confirm against model/model.py.
    model.lm_head.weight = model.token_emb.weight
    # NOTE(review): if SLLM keeps a reference to the config it was built from, this
    # also mutates the module-level SLLM_150M object — confirm.
    model.config.vocab_size = new_vocab_size
    print(f" [INFO] Resized model vocab embedding from {old_size:,} to {new_vocab_size:,}")


def load_model_and_tokenizer(run_dir: str, device: torch.device):
    """Load the chat tokenizer and the latest checkpoint from ``run_dir``.

    Tokenizer: prefers the extended tokenizer in ``finetune/data``. Otherwise it
    uses the base tokenizer in ``tokenizer/fineweb_edu_tokenizer`` and adds the
    ``<|im_start|>`` / ``<|im_end|>`` special tokens.

    Checkpoint: if ``run_dir`` has none, falls back to ``runs/sllm_150m``.

    Returns:
        ``(model, tokenizer, ckpt_path, step, loss)``. ``step`` defaults to
        ``"?"`` and ``loss`` to NaN when the checkpoint does not record them.

    Raises:
        FileNotFoundError: if no tokenizer or no checkpoint can be found.
    """
    # ---- Tokenizer ------------------------------------------------- #
    data_tok_dir = PROJECT_ROOT / "finetune" / "data"
    base_tok_dir = PROJECT_ROOT / "tokenizer" / "fineweb_edu_tokenizer"

    if os.path.exists(data_tok_dir / "tokenizer.json"):
        tok_path = str(data_tok_dir)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        print(f" Tokenizer: Loaded extended tokenizer from '{tok_path}'")
    elif os.path.exists(base_tok_dir):
        tok_path = str(base_tok_dir)
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tok_path)
        tokenizer.add_special_tokens({
            "additional_special_tokens": ["<|im_start|>", "<|im_end|>"]
        })
        print(f" Tokenizer: Loaded base tokenizer from '{tok_path}' and added ChatML tokens")
    else:
        raise FileNotFoundError("Could not find a tokenizer directory.")

    # ---- Checkpoint ------------------------------------------------ #
    try:
        ckpt_path = find_latest_ckpt(run_dir)
    except FileNotFoundError:
        # Fall back to base pretraining checkpoint if SFT directory is empty
        print(f" [WARN] No checkpoint found in '{run_dir}'. Trying pretraining base run...")
        base_dir = PROJECT_ROOT / "runs" / "sllm_150m"
        ckpt_path = find_latest_ckpt(str(base_dir))

    print(f" Loading checkpoint: {ckpt_path}")
    # weights_only=False unpickles arbitrary objects: only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)

    # ---- Model ----------------------------------------------------- #
    model = SLLM(SLLM_150M).to(device)
    # Match the checkpoint's vocab size first so load_state_dict sees equal shapes.
    saved_vocab = ckpt.get("vocab_size", len(tokenizer))
    resize_token_embeddings(model, saved_vocab)
    model.load_state_dict(ckpt["model_state_dict"])
    # A base (pre-SFT) checkpoint may have no rows for the ChatML tokens this
    # tokenizer emits; grow the embedding so those ids stay in range.
    if len(tokenizer) > model.config.vocab_size:
        resize_token_embeddings(model, len(tokenizer))
    model.eval()

    step = ckpt.get("step", "?")
    loss = ckpt.get("loss", float("nan"))
    return model, tokenizer, ckpt_path, step, loss


# ------------------------------------------------------------------ #
# PROMPT BUILDING
# ------------------------------------------------------------------ #

def build_prompt(history: list[dict], system_prompt: str,
                 tokenizer: PreTrainedTokenizerFast) -> torch.Tensor:
    """Render the system prompt plus ``history`` as ChatML and tokenize it.

    Each history entry needs ``"role"`` and ``"content"`` keys. The text ends
    with an open ``<|im_start|>assistant`` header so the model continues as
    the assistant.

    Returns:
        A ``(1, T)`` ``torch.long`` tensor of token ids.
    """
    parts = [f"<|im_start|>system\n{system_prompt}<|im_end|>\n"]
    parts.extend(f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n" for turn in history)
    # Prime the model to respond as assistant
    parts.append("<|im_start|>assistant\n")
    ids = tokenizer.encode("".join(parts), add_special_tokens=False)
    return torch.tensor([ids], dtype=torch.long)


# ------------------------------------------------------------------ #
# GENERATION
# ------------------------------------------------------------------ #

@torch.no_grad()
def generate_response(
    model: SLLM,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizerFast,
    max_new_tokens: int = 200,
    temperature: float = 0.7,
    top_k: int = 40,
    top_p: float = 0.9,
    device: torch.device = None,
    dtype_torch: torch.dtype = torch.float32,
    use_amp: bool = False,
) -> str:
    """Autoregressively sample up to ``max_new_tokens`` tokens after ``input_ids``.

    Sampling: ``temperature <= 0`` is greedy (argmax). Otherwise logits are
    temperature-scaled, then top-k and top-p (nucleus) filtered. Generation
    stops early on ``<|im_end|>`` or the tokenizer's EOS token; the stop token
    is not included in the output.

    The model is re-run on the whole context (cropped to
    ``model.config.context_length``) at every step.

    Args:
        device: where to run; defaults to the device of the model's parameters.
        dtype_torch / use_amp: autocast dtype and switch for the forward pass.

    Returns:
        The decoded text with special tokens removed and surrounding whitespace stripped.
    """
    if device is None:
        device = next(model.parameters()).device

    im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    eos_id = tokenizer.eos_token_id  # may be None; harmless in the set below
    stop_ids = {im_end_id, eos_id}
    context_length = model.config.context_length

    ids = input_ids.to(device)
    generated: list[int] = []

    for _ in range(max_new_tokens):
        # Crop context to model window (a no-op while the sequence still fits).
        ctx = ids[:, -context_length:]
        with autocast(device_type=device.type, dtype=dtype_torch, enabled=use_amp):
            logits, _ = model(ctx)  # assumes (logits, loss) with logits (1, T, V) — TODO confirm

        # Sample in float32: bf16/fp16 softmax + cumsum distort the nucleus cutoff.
        logits = logits[:, -1, :].float()

        if temperature <= 0.0:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            # Lower bound guards against overflow for vanishingly small temperatures.
            logits = logits / max(temperature, 1e-8)

            if top_k and top_k > 0:
                kth_best = torch.topk(logits, min(top_k, logits.size(-1))).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float("-inf"))

            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = torch.softmax(sorted_logits, dim=-1)
                # Drop a token once the mass of strictly higher-ranked tokens
                # already exceeds top_p; the top token is always kept.
                mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
                sorted_logits = sorted_logits.masked_fill(mass_before > top_p, float("-inf"))
                logits = torch.zeros_like(logits).scatter_(1, sorted_idx, sorted_logits)

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (1, 1)

        tok_id = next_token.item()
        # Stop if end of message or end of stream token is generated
        if tok_id in stop_ids:
            break
        generated.append(tok_id)
        ids = torch.cat([ids, next_token], dim=1)

    return tokenizer.decode(generated, skip_special_tokens=True).strip()


def _generate_with_args(model, tokenizer, input_ids, device, dtype_torch, use_amp, args) -> str:
    """Call generate_response with the sampling settings parsed from the command line."""
    return generate_response(
        model, input_ids, tokenizer,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        device=device,
        dtype_torch=dtype_torch,
        use_amp=use_amp,
    )


# ------------------------------------------------------------------ #
# MODES
# ------------------------------------------------------------------ #

def run_interactive(model, tokenizer, device, dtype_torch, use_amp, args):
    """Run a multi-turn chat loop on stdin/stdout.

    Commands: /reset, /system [prompt], and /quit (also /exit, quit, exit).
    When the prompt would leave fewer than ``args.max_new_tokens + 10`` tokens
    of context, the oldest user/assistant pair is dropped. Ctrl+C while reading
    input exits; Ctrl+C during generation cancels only that turn.
    """
    system_prompt = args.system
    history = []
    prompt_budget = model.config.context_length - args.max_new_tokens - 10

    print("\n" + "=" * 60)
    print(" CHAT MODE (Interactive)")
    print("=" * 60)
    print(f" System prompt : {system_prompt}")
    print(" Commands : /reset to clear memory | /system | /quit to exit")
    print("─" * 60 + "\n")

    while True:
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nBye!")
            break

        if not user_input:
            continue

        # Check for commands
        command = user_input.lower()
        if command in ("/quit", "/exit", "quit", "exit"):
            print("Bye!")
            break
        if command == "/reset":
            history = []
            print(" [Conversation history reset]\n")
            continue
        if command == "/system":
            # Without an argument, show the prompt rather than sending "/system" to the model.
            print(f" [Current system prompt: {system_prompt}]")
            print(" [Usage: /system <new prompt>]\n")
            continue
        if command.startswith("/system "):
            new_sys = user_input[len("/system "):].strip()
            if new_sys:
                system_prompt = new_sys
                history = []
                print(" [System prompt updated. History cleared.]\n")
            continue

        # Add to history and build ChatML prompt
        history.append({"role": "user", "content": user_input})
        input_ids = build_prompt(history, system_prompt, tokenizer)

        # Trim the oldest user + assistant pair until the prompt leaves room for
        # the reply; the current user turn itself is never dropped.
        while input_ids.shape[1] > prompt_budget and len(history) > 2:
            history = history[2:]
            input_ids = build_prompt(history, system_prompt, tokenizer)

        print("SLLM: ", end="", flush=True)
        try:
            response = _generate_with_args(
                model, tokenizer, input_ids, device, dtype_torch, use_amp, args
            )
        except KeyboardInterrupt:
            # Cancel this turn only; forget the unanswered user message.
            history.pop()
            print("\n [Generation interrupted]\n")
            continue

        print(response + "\n")
        history.append({"role": "assistant", "content": response})


def run_sample(model, tokenizer, device, dtype_torch, use_amp, args):
    """Answer a fixed set of single-turn prompts and print each reply."""
    sample_prompts = [
        "Hello! Who are you?",
        "What is the capital of France?",
        "Write a quick, 3-line poem about a small robot learning to speak.",
        "Explain gravity in one simple sentence.",
    ]

    print("\n" + "=" * 60)
    print(" SAMPLE EVALUATION MODE")
    print("=" * 60)
    print(f" System prompt: {args.system}")
    print("─" * 60)

    for prompt in sample_prompts:
        print(f"\n[PROMPT] : {prompt}")
        input_ids = build_prompt([{"role": "user", "content": prompt}], args.system, tokenizer)
        print("[SLLM] : ", end="", flush=True)
        response = _generate_with_args(model, tokenizer, input_ids, device, dtype_torch, use_amp, args)
        print(response)

    print("\n" + "─" * 60 + "\n")


# ------------------------------------------------------------------ #
# MAIN
# ------------------------------------------------------------------ #

def main():
    """Parse CLI arguments, pick device/precision, load the model and run the chosen mode.

    Exits with status 1 if the model or tokenizer cannot be loaded.
    """
    p = argparse.ArgumentParser(description="SLLM Chat Checker")
    p.add_argument("--run_dir", type=str, default=DEFAULT_RUN_DIR)
    p.add_argument("--mode", type=str, default="interactive", choices=["interactive", "sample"])
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top_k", type=int, default=40)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--max_new_tokens", type=int, default=200)
    p.add_argument("--system", type=str, default=DEFAULT_SYSTEM)
    p.add_argument("--dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nDevice : {device}")
    if device.type == "cuda":
        print(f"GPU : {torch.cuda.get_device_name(0)}")

    # Precision setup: reduced precision only via CUDA autocast; everything else runs fp32.
    use_amp = False
    if args.dtype == "bf16" and device.type == "cuda" and torch.cuda.is_bf16_supported():
        dtype_torch = torch.bfloat16
        use_amp = True
    elif args.dtype == "fp16" and device.type == "cuda":
        dtype_torch = torch.float16
        use_amp = True
    else:
        dtype_torch = torch.float32

    # Report the precision actually used, not just the one requested.
    effective_dtype = {torch.bfloat16: "bf16", torch.float16: "fp16"}.get(dtype_torch, "fp32")
    if effective_dtype == args.dtype:
        print(f"dtype : {args.dtype}")
    else:
        print(f"dtype : {effective_dtype} (requested {args.dtype}; not supported on this device)")

    # Load Model and Tokenizer (top-level CLI boundary: report and exit non-zero).
    try:
        model, tokenizer, ckpt_path, step, loss = load_model_and_tokenizer(args.run_dir, device)
    except Exception as e:
        print(f"\n[ERROR] Failed to load chat model: {e}")
        sys.exit(1)

    print(f" Step : {step}")
    # The saved loss may be a float, a 0-dim tensor, or missing/None.
    try:
        loss_value = float(loss)
    except (TypeError, ValueError, RuntimeError):
        loss_value = math.nan
    if not math.isnan(loss_value):
        print(f" Loss : {loss_value:.4f}")

    if args.mode == "interactive":
        run_interactive(model, tokenizer, device, dtype_torch, use_amp, args)
    elif args.mode == "sample":
        run_sample(model, tokenizer, device, dtype_torch, use_amp, args)


if __name__ == "__main__":
    main()