import traceback
from typing import List

import torch
from tqdm import trange

from llama.model import Transformer
from llama.tokenizer import Tokenizer


class LLaMA:
    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        # Newline count of the first prompt; generation stops early once a
        # decoded completion contains more newlines than this baseline.
        count_newlines = prompts[0].count("\n")

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_size = min(len(t) for t in prompt_tokens)
        max_prompt_size = max(len(t) for t in prompt_tokens)

        # Never generate past the model's maximum sequence length.
        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        # Fill the token buffer with padding, copy each prompt in, and force an
        # EOS token into the last slot of every row.
        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
            tokens[k, -1] = self.tokenizer.eos_id
        # True where the buffer already holds prompt tokens; those positions
        # must not be overwritten during generation.
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        decoded = [None] * bsz
        for cur_pos in trange(start_pos, total_len, desc="forward"):
            # Feed only the tokens the model has not seen yet; prev_pos marks
            # where the cached context ends.
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1).cpu()
            # Keep the original token wherever the prompt still covers this
            # position; only truly generated positions take the sampled token.
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

| | print("-" * 30) |
| | for i, t in enumerate(tokens.tolist()): |
| | |
| | |
| | |
| | |
| | t = t[: min(cur_pos, len(prompt_tokens[i]) + max_gen_len)] |
| | |
| | try: |
| | t = t[: t.index(self.tokenizer.eos_id)] |
| | except ValueError: |
| | pass |
| | try: |
| | d = self.tokenizer.decode(t) |
| | print([i] * 20) |
| | print(d) |
| | decoded[i] = d |
| |
|
| | result_count_newlines = d.count("\n") |
| | if result_count_newlines > count_newlines: |
| | return decoded |
| |
|
| | except IndexError: |
| | traceback.print_exc() |
| | print(t) |
| | print("-" * 30) |
| | return decoded |
| |
|
| |
|
def sample_top_p(probs, p):
    # Nucleus (top-p) sampling: keep the smallest set of most likely tokens
    # whose cumulative probability exceeds p, renormalize, and sample from it.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens that fall outside the nucleus.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
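

# Minimal illustrative sketch: run sample_top_p on a made-up 4-token
# distribution. The probabilities and p value below are arbitrary examples,
# not anything from the model. With p=0.7, the cumulative-probability mask
# keeps only the two most likely tokens (indices 0 and 1) as candidates.
if __name__ == "__main__":
    torch.manual_seed(0)
    toy_probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])  # shape (batch=1, vocab=4)
    picked = sample_top_p(toy_probs, p=0.7)
    # `picked` holds the original vocabulary index of the sampled token.
    print(picked)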