"""
Interactive chat with the 1B Transformer.
Runs in an infinite conversation loop from the terminal.

Usage:
    python chat.py                                              # auto-find latest checkpoint
    python chat.py /jfs/deepak-kumar/checkpoints/step_19000.pt  # specific checkpoint
"""
|
|
| import sys |
| import os |
| import glob |
| import time |
| import torch |
| import torch.nn.functional as F |
| import readline |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from model.config import ModelConfig |
| from model.transformer import Transformer |
| from model.data import get_tokenizer |
|
|
|
|
def find_latest_checkpoint():
    """Locate the best available checkpoint, preferring DPO > SFT > pretrained.

    Within each stage, a `*_final.pt` file wins; otherwise the `*_step_N.pt`
    file with the highest step number is used.

    Returns:
        (path, is_chat): `path` is the checkpoint file path or None if nothing
        was found; `is_chat` is True for DPO/SFT (chat-tuned) checkpoints.
    """

    def _latest_step_file(directory, prefix):
        # Highest-numbered "<prefix><N>.pt" in `directory`, or None.
        candidates = glob.glob(os.path.join(directory, f"{prefix}*.pt"))
        if not candidates:
            return None

        def step_of(path):
            # Parse the step from the basename only, so a parent directory
            # that happens to contain the prefix can't break the int() parse.
            name = os.path.basename(path)
            return int(name[len(prefix):].split(".")[0])

        return max(candidates, key=step_of)

    # (directory, final-file name or None, step-file prefix, is_chat)
    searches = [
        ("/jfs/deepak-kumar/checkpoints_dpo", "dpo_final.pt", "dpo_step_", True),
        ("/jfs/deepak-kumar/checkpoints_sft", "sft_final.pt", "sft_step_", True),
        ("/jfs/deepak-kumar/checkpoints", None, "step_", False),
    ]
    for directory, final_name, prefix, is_chat in searches:
        if final_name is not None:
            final_path = os.path.join(directory, final_name)
            if os.path.exists(final_path):
                return final_path, is_chat
        latest = _latest_step_file(directory, prefix)
        if latest is not None:
            return latest, is_chat

    return None, False
|
|
|
|
def load_model(checkpoint_path, tokenizer, device="cuda:0"):
    """Instantiate the Transformer and restore weights from a checkpoint.

    Args:
        checkpoint_path: path to a .pt checkpoint saved by the training scripts.
        tokenizer: kept for interface compatibility; not used here.
        device: target device string, e.g. "cuda:0".

    Returns:
        (model, config, step, loss) — the model in eval mode, bfloat16, on
        `device`; `step`/`loss` come from the checkpoint (or "?" if absent).
    """
    config = ModelConfig()
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Checkpoints trained after chat special tokens were added carry a larger
    # vocabulary; grow the config so the embedding matrices line up.
    saved_vocab = state.get("vocab_size", config.vocab_size)
    config.vocab_size = max(config.vocab_size, saved_vocab)

    model = Transformer(config)
    model.load_state_dict(state["model"])
    model = model.to(device).bfloat16().eval()

    step = state.get("step", "?")
    loss = state.get("loss", "?")

    # Drop the CPU copy of the weights before starting the chat loop.
    del state
    torch.cuda.empty_cache()
    return model, config, step, loss
|
|
|
|
@torch.no_grad()
def generate_stream(model, tokenizer, prompt, max_new_tokens=512,
                    temperature=0.8, top_k=50, top_p=0.9,
                    repetition_penalty=1.15, device="cuda:0",
                    stop_token_ids=None):
    """Generate tokens one at a time, yielding decoded text for streaming.

    Re-runs the full forward pass every step (no KV cache), samples with
    temperature / top-k / top-p plus a repetition penalty, and yields only the
    newly decoded suffix after each accepted token.

    Args:
        model: callable returning (logits, aux) for a (1, T) id tensor; must
            expose `model.config.max_seq_len`.
        tokenizer: HF-style tokenizer with encode/decode and `eos_token_id`.
        prompt: text to condition on.
        max_new_tokens: cap on the number of generated tokens.
        temperature: softmax temperature; values <= 0 are clamped to a tiny
            positive value (effectively greedy) instead of dividing by zero.
        top_k: keep only the k most likely tokens (0 disables).
        top_p: nucleus sampling threshold (1.0 disables).
        repetition_penalty: > 1.0 discourages re-emitting previous tokens.
        device: device string; also selects the autocast device type.
        stop_token_ids: extra token ids that end generation (EOS always stops).

    Yields:
        str: newly decoded text after each generated token.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    generated_ids = []
    prev_decoded_len = 0  # length of the text already yielded

    stop_token_ids = set() if stop_token_ids is None else set(stop_token_ids)
    if tokenizer.eos_token_id is not None:
        stop_token_ids.add(tokenizer.eos_token_id)

    # Derive the autocast device from `device` (was hard-coded to "cuda",
    # which broke CPU generation).
    autocast_device = "cuda" if "cuda" in str(device) else "cpu"

    for _ in range(max_new_tokens):
        if input_ids.shape[1] >= model.config.max_seq_len:
            break

        with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
            logits, _ = model(input_ids)

        logits = logits[:, -1, :]  # next-token logits only

        # Repetition penalty, vectorized: shrink positive logits and amplify
        # negative ones for every token already emitted.
        if repetition_penalty != 1.0 and generated_ids:
            prev_tokens = torch.tensor(generated_ids, device=device).unique()
            vals = logits[0, prev_tokens]
            logits[0, prev_tokens] = torch.where(
                vals > 0, vals / repetition_penalty, vals * repetition_penalty
            )

        # Clamp so temperature == 0 acts near-greedy instead of dividing by zero.
        logits = logits / max(temperature, 1e-6)

        if top_k > 0:
            # Clamp k so a top_k larger than the vocab can't crash torch.topk.
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k).values[:, -1:]
            logits[logits < kth_value] = float("-inf")

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cum_probs = torch.cumsum(sorted_probs, dim=-1)
            # Drop tokens whose cumulative mass *before* them already reaches
            # top_p (i.e. always keep at least the most likely token).
            sorted_logits[cum_probs - sorted_probs >= top_p] = float("-inf")
            # Un-sort: scatter the filtered logits back to vocabulary order.
            logits = torch.full_like(logits, float("-inf")).scatter(
                1, sorted_idx, sorted_logits
            )

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        token_id = next_token.item()

        if token_id in stop_token_ids:
            break

        generated_ids.append(token_id)
        input_ids = torch.cat([input_ids, next_token], dim=1)

        # Decode the whole sequence and emit only the new suffix — decoding
        # token-by-token would mangle multi-token characters/whitespace.
        full_decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
        new_text = full_decoded[prev_decoded_len:]
        prev_decoded_len = len(full_decoded)
        yield new_text

    return
|
|
|
|
def print_banner(step, loss, device):
    """Print the startup banner: checkpoint details plus the command reference."""
    bar = "=" * 60
    print("\033[1;36m")
    print(bar)
    print(" 1B TRANSFORMER — Interactive Chat")
    print(bar)
    print(f"\033[0m Checkpoint : step {step}")
    print(f" Loss : {loss}")
    print(f" Device : {device}")
    print(f" Parameters : 1.106B")
    print()
    print(" \033[90mCommands:\033[0m")
    # (command, description) pairs, rendered identically to the old banner.
    command_help = [
        ("/quit", "exit"),
        ("/clear", "clear conversation context"),
        ("/temp N", "set temperature (default 0.8)"),
        ("/tokens N", "set max tokens (default 512)"),
        ("/topp N", "set top-p (default 0.9)"),
        ("/topk N", "set top-k (default 50)"),
        ("/rep N", "set repetition penalty (default 1.15)"),
    ]
    for name, description in command_help:
        print(f" \033[33m{name}\033[0m — {description}")
    print()
    print("\033[90m" + "─" * 60 + "\033[0m")
|
|
|
|
def main():
    """Interactive REPL: resolve a checkpoint, load the model, then loop
    reading user turns and streaming model replies until /quit or EOF."""
    device = "cuda:0"

    # Resolve the checkpoint: an explicit CLI argument wins, otherwise take
    # the newest one on disk (DPO > SFT > pretrained).
    is_sft = False
    if len(sys.argv) > 1:
        checkpoint = sys.argv[1]
        is_sft = "sft" in checkpoint.lower()
    else:
        checkpoint, is_sft = find_latest_checkpoint()
        if checkpoint is None:
            print("No checkpoint found!")
            sys.exit(1)

    tokenizer = get_tokenizer()

    # SFT/DPO checkpoints were trained with chat special tokens; make sure
    # the tokenizer knows them so prompts round-trip correctly.
    if is_sft:
        special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
        vocab = tokenizer.get_vocab()
        new_tokens = [t for t in special_tokens if t not in vocab]
        if new_tokens:
            tokenizer.add_tokens(new_tokens, special_tokens=True)

    print(f"\n Loading model from {checkpoint}...")
    print(f" Mode: {'SFT (chat)' if is_sft else 'Base (completion)'}")
    model, config, step, loss = load_model(checkpoint, tokenizer, device)
    print(" Model loaded!\n")

    print_banner(step, loss, device)
    if is_sft:
        print(" \033[1;32mSFT mode: The model will respond as a chat assistant.\033[0m\n")

    # Generation settings, adjustable at runtime via slash-commands.
    settings = {
        "temperature": 0.7 if is_sft else 0.8,
        "max_tokens": 512,
        "top_p": 0.9,
        "top_k": 50,
        "rep_penalty": 1.15,
    }
    context = ""

    # Chat-template markers used by the SFT/DPO training data.
    USER_START = "<|user|>\n"
    ASST_START = "<|assistant|>\n"
    TURN_END = "\n<|end|>\n"

    # In chat mode, stop generating at turn boundaries.
    sft_stop_ids = []
    if is_sft:
        vocab = tokenizer.get_vocab()
        sft_stop_ids = [vocab[t] for t in ("<|end|>", "<|user|>") if t in vocab]

    # command -> (settings key, parser, label for the confirmation message).
    numeric_cmds = {
        "/temp": ("temperature", float, "Temperature"),
        "/tokens": ("max_tokens", int, "Max tokens"),
        "/topp": ("top_p", float, "Top-p"),
        "/topk": ("top_k", int, "Top-k"),
        "/rep": ("rep_penalty", float, "Repetition penalty"),
    }

    while True:
        try:
            user_input = input("\n\033[1;32mYou:\033[0m ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\n\nGoodbye!")
            break

        if not user_input:
            continue

        # Slash-commands adjust settings without touching the conversation.
        if user_input.startswith("/"):
            cmd = user_input.lower().split()
            if cmd[0] == "/quit":
                print("Goodbye!")
                break
            elif cmd[0] == "/clear":
                context = ""
                print("\033[90m [Context cleared]\033[0m")
            elif cmd[0] in numeric_cmds and len(cmd) > 1:
                key, cast, label = numeric_cmds[cmd[0]]
                try:
                    settings[key] = cast(cmd[1])
                except ValueError:
                    # A bad numeric argument used to crash the whole REPL.
                    print(f"\033[90m [Could not parse {cmd[1]!r} as a number]\033[0m")
                else:
                    print(f"\033[90m [{label} set to {settings[key]}]\033[0m")
            else:
                print("\033[90m Unknown command. Try /quit, /clear, /temp, /tokens, /topp, /topk, /rep\033[0m")
            continue

        # Build the prompt: chat template in SFT mode, raw continuation otherwise.
        if is_sft:
            prompt = context + USER_START + user_input + TURN_END + ASST_START
        else:
            prompt = context + "\n" + user_input if context else user_input

        # Trim the oldest history while the prompt leaves no room to generate.
        while len(tokenizer.encode(prompt)) > config.max_seq_len - settings["max_tokens"]:
            if is_sft:
                turns = context.split(TURN_END)
                if len(turns) <= 2:
                    break
                # Drop the oldest user+assistant pair (two TURN_END segments).
                context = TURN_END.join(turns[2:])
                prompt = context + USER_START + user_input + TURN_END + ASST_START
            else:
                lines = prompt.split("\n")
                if len(lines) <= 2:
                    break
                prompt = "\n".join(lines[1:])

        print("\033[1;34mModel:\033[0m ", end="", flush=True)
        t0 = time.time()
        full_response = ""
        token_count = 0

        for token_text in generate_stream(
            model, tokenizer, prompt,
            max_new_tokens=settings["max_tokens"],
            temperature=settings["temperature"],
            top_k=settings["top_k"],
            top_p=settings["top_p"],
            repetition_penalty=settings["rep_penalty"],
            device=device,
            stop_token_ids=sft_stop_ids if is_sft else None,
        ):
            print(token_text, end="", flush=True)
            full_response += token_text
            token_count += 1

        elapsed = time.time() - t0
        tps = token_count / max(elapsed, 1e-9)
        print(f"\n\033[90m [{token_count} tokens, {tps:.1f} tok/s, {elapsed:.1f}s]\033[0m")

        # Fold the finished turn back into the rolling context.
        if is_sft:
            context = (context + USER_START + user_input + TURN_END +
                       ASST_START + full_response.strip() + TURN_END)
        else:
            context = prompt + full_response
|
|
|
|
# Run the chat loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|