| import torch, argparse, json |
| from tokenizers import Tokenizer |
| from model.tiny_gpt2 import TinyGPT2, GPTConfig |
|
|
# Command-line interface: --prompt must be supplied; the checkpoint, model
# config, and tokenizer paths default to the standard training output layout.
parser = argparse.ArgumentParser()
_cli_flags = (
    ("--prompt", dict(type=str, required=True)),
    ("--ckpt", dict(type=str, default="out/sft/model_sft.pt")),
    ("--cfg", dict(type=str, default="out/pretrain/gpt_config.json")),
    ("--tok", dict(type=str, default="out/tokenizer.json")),
)
for _flag, _kw in _cli_flags:
    parser.add_argument(_flag, **_kw)
args = parser.parse_args()
|
|
# Load tokenizer, model configuration, and fine-tuned weights (CPU only).
tok = Tokenizer.from_file(args.tok)
# Fix: the original `json.load(open(args.cfg))` never closed the file handle;
# use a context manager and an explicit text encoding.
with open(args.cfg, encoding="utf-8") as _cfg_file:
    cfg = GPTConfig(**json.load(_cfg_file))
m = TinyGPT2(cfg)
# weights_only=True restricts unpickling to tensors/primitive containers —
# sufficient for a state_dict and guards against arbitrary-code-execution
# payloads in an untrusted checkpoint (torch >= 1.13).
state = torch.load(args.ckpt, map_location="cpu", weights_only=True)
m.load_state_dict(state)
m.eval()  # inference mode: disables dropout and other train-time behavior
|
|
# Tokenize the prompt with a leading "[BOS] " marker, autoregressively
# generate up to 80 new tokens without tracking gradients, then decode and
# print the result.
prompt_ids = tok.encode("[BOS] " + args.prompt).ids
batch = torch.tensor([prompt_ids], dtype=torch.long)
with torch.no_grad():
    out_ids = m.generate(batch, max_new_tokens=80)
print(tok.decode(out_ids[0].tolist()))
|
|