| """Quick test of model quality with diverse prompts.""" |
|
|
| import os, sys, time, torch |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from model.config import ModelConfig |
| from model.transformer import Transformer |
| from model.data import get_tokenizer |
|
|
# Checkpoint selection: prefer the DPO-tuned weights and fall back to the SFT
# weights when no final DPO checkpoint exists on disk.
DPO_CKPT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
SFT_CKPT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
if os.path.exists(DPO_CKPT):
    CHECKPOINT = DPO_CKPT
else:
    CHECKPOINT = SFT_CKPT
DEVICE = "cuda:0"

# Chat-template markers. main() builds each prompt as
# USER_START + question + TURN_END + ASST_START.
USER_START, ASST_START, TURN_END = "<|user|>\n", "<|assistant|>\n", "\n<|end|>\n"

# Smoke-test questions covering chit-chat, science explanations, creative
# writing, factual recall and advice.
TEST_PROMPTS = [
    "Hi! How are you?",
    "What is photosynthesis?",
    "Explain gravity to a 5-year-old.",
    "Write a short poem about the ocean.",
    "What are the three states of matter?",
    "How does a computer work?",
    "What is the capital of France and why is it famous?",
    "Give me 3 tips for learning a new language.",
    "What is machine learning in simple terms?",
]
|
|
|
|
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=256,
             temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.15,
             max_seq_len=2048, device=None):
    """Sample a completion for ``prompt`` one token at a time.

    Each step feeds the whole running sequence through ``model`` (no KV cache),
    applies a repetition penalty, temperature scaling, top-k and nucleus (top-p)
    filtering, then samples the next token with ``torch.multinomial``.

    Generation stops when the tokenizer's EOS token, the ``<|end|>`` marker or the
    ``<|user|>`` marker is produced (the marker itself is not returned), after
    ``max_new_tokens`` samples, or once the sequence grows beyond ``max_seq_len``.

    Args:
        model: callable taking a ``(1, seq)`` LongTensor and returning a
            ``(logits, _)`` pair; only ``logits[:, -1, :]`` is used.
        tokenizer: object exposing ``encode(text, add_special_tokens=...)``,
            ``decode(ids, skip_special_tokens=...)`` and ``eos_token_id``.
        prompt: fully formatted prompt text.
        max_new_tokens: maximum number of tokens to sample.
        temperature: logits are divided by ``max(temperature, 1e-5)``.
        top_k: keep only the ``top_k`` largest logits; 0 disables.
        top_p: nucleus-sampling threshold; values >= 1.0 disable it.
        repetition_penalty: penalty applied to ids already generated (positive
            logits divided, non-positive multiplied); 1.0 disables.
        max_seq_len: stop once prompt + generated length exceeds this.
        device: device for the input ids and autocast; defaults to ``DEVICE``.

    Returns:
        The decoded generated text with special tokens skipped.
    """
    device = torch.device(DEVICE if device is None else device)
    input_ids = torch.tensor(
        [tokenizer.encode(prompt, add_special_tokens=False)],
        dtype=torch.long, device=device,
    )
    generated = []

    # Stop markers are kept as complete token-id sequences. If a marker is not
    # a single token in this vocabulary, matching only its first sub-token
    # would end generation on unrelated text, so the full sequence must match.
    stop_seqs = []
    if tokenizer.eos_token_id is not None:
        stop_seqs.append([tokenizer.eos_token_id])
    for marker in ("<|end|>", "<|user|>"):
        marker_ids = tokenizer.encode(marker, add_special_tokens=False)
        if marker_ids:
            stop_seqs.append(list(marker_ids))

    for _ in range(max_new_tokens):
        with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
            logits, _ = model(input_ids)
        # Only the last position predicts the next token; upcast for stable math.
        logits = logits[:, -1, :].float()

        if repetition_penalty != 1.0 and generated:
            # Vectorized penalty: same arithmetic as a per-id loop, without a
            # host/device sync for every previously generated id.
            seen = torch.tensor(sorted(set(generated)), dtype=torch.long,
                                device=logits.device)
            vals = logits[0, seen]
            logits[0, seen] = torch.where(vals > 0, vals / repetition_penalty,
                                          vals * repetition_penalty)

        logits = logits / max(temperature, 1e-5)

        if top_k > 0:
            topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < topk_vals[:, -1:]] = float('-inf')

        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            # Drop a token once the mass of strictly higher-ranked tokens
            # already exceeds top_p, so the top token is always kept.
            remove = (torch.cumsum(sorted_probs, dim=-1) - sorted_probs) > top_p
            sorted_logits[remove] = float('-inf')
            # Undo the sort: put each filtered logit back at its vocab index.
            logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)

        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1)
        generated.append(next_token.item())

        hit = next((seq for seq in stop_seqs if generated[-len(seq):] == seq), None)
        if hit is not None:
            del generated[-len(hit):]  # never return the stop marker itself
            break

        input_ids = torch.cat([input_ids, next_token], dim=1)
        if input_ids.size(1) > max_seq_len:
            break

    return tokenizer.decode(generated, skip_special_tokens=True)
|
|
|
|
def main():
    """Load the DPO (or fallback SFT) checkpoint and answer every TEST_PROMPTS entry.

    Ensures the chat markers exist in the tokenizer, sizes the model vocabulary
    to ``len(tokenizer)``, loads the checkpoint weights onto ``DEVICE`` in
    bfloat16, then prints each question, the generated answer, and an
    approximate token count, wall time and throughput.
    """
    # Compare against the configured path instead of substring-matching "dpo",
    # which would mislabel any path that merely contains those letters.
    ckpt_name = "DPO" if CHECKPOINT == DPO_CKPT else "SFT"
    print("=" * 70)
    print(f" {ckpt_name} MODEL TEST")
    print("=" * 70)

    tokenizer = get_tokenizer()
    # The prompt template uses these markers; register any the base tokenizer lacks.
    special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
    vocab = tokenizer.get_vocab()
    new_tokens = [t for t in special_tokens if t not in vocab]
    if new_tokens:
        tokenizer.add_tokens(new_tokens, special_tokens=True)

    config = ModelConfig()
    # Presumably sizes the embedding/output layers; must agree with the
    # checkpoint or the strict load_state_dict below will fail.
    config.vocab_size = len(tokenizer)
    model = Transformer(config)

    print("")
    print(f"Loading checkpoint: {CHECKPOINT}")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted location.
    ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
    model.load_state_dict(ckpt["model"])
    step = ckpt.get("step", "?")
    del ckpt  # drop the CPU copy of the state dict before moving to the GPU

    # .bfloat16() (rather than .to(dtype=...)) casts only floating-point tensors,
    # leaving any integer/complex buffers untouched.
    model = model.to(DEVICE).bfloat16().eval()
    print(f"Model loaded ({ckpt_name} step {step}, vocab {config.vocab_size})")
    mem = torch.cuda.max_memory_allocated(DEVICE) / 1e9
    print(f"GPU memory: {round(mem, 1)} GB")
    print("-" * 70)

    for i, question in enumerate(TEST_PROMPTS, 1):
        prompt = f"{USER_START}{question}{TURN_END}{ASST_START}"

        print("")
        print(f"[Test {i}/{len(TEST_PROMPTS)}]")
        print(f" Q: {question}")

        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        response = generate(model, tokenizer, prompt)
        dt = time.perf_counter() - t0
        # Approximate: re-encodes the decoded text, which may differ slightly
        # from the number of tokens actually sampled.
        tokens = len(tokenizer.encode(response, add_special_tokens=False))

        # Defensive trim in case a chat marker survived decoding.
        response = response.split("<|end|>")[0].split("<|user|>")[0].strip()

        print(f" A: {response}")
        tps = int(tokens / max(dt, 0.01))
        print(f" [{tokens} tokens, {round(dt, 1)}s, {tps} tok/s]")
        print("-" * 70)

    print("")
    print("Done!")


if __name__ == "__main__":
    main()
|
|