| |
| |
|
|
| from typing import List |
|
|
| import torch |
| import traceback |
|
|
| from llama.tokenizer import Tokenizer |
| from llama.model import Transformer |
| from tqdm import trange |
|
|
|
|
class LLaMA:
    """Text-generation wrapper that pairs a transformer model with its tokenizer."""

    def __init__(self, model: Transformer, tokenizer: Tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
        self,
        prompts: List[str],
        max_gen_len: int,
        temperature: float = 0.8,
        top_p: float = 0.95,
    ) -> List[str]:
        """Generate a continuation for each prompt, one token position per step.

        After every step the current token prefix of each sequence is decoded
        and printed. Generation stops early as soon as a decoded sequence
        contains more newlines than the first prompt.

        Args:
            prompts: Batch of prompt strings; its size must not exceed
                ``self.model.params.max_batch_size``.
            max_gen_len: Upper bound on the number of tokens kept after each
                prompt; the total length is also capped by
                ``self.model.params.max_seq_len``.
            temperature: When greater than 0, logits are divided by it before
                the softmax and the next token is drawn with ``sample_top_p``;
                otherwise the arg-max token is taken.
            top_p: Cumulative-probability threshold forwarded to
                ``sample_top_p``.

        Returns:
            One decoded string per prompt, decoded from the start of the
            sequence (so the prompt tokens are included). An entry stays
            ``None`` if that sequence was never decoded successfully. An empty
            ``prompts`` list yields an empty list.
        """
        # An empty batch has nothing to generate; without this guard
        # ``prompts[0]`` below raises IndexError.
        if not prompts:
            return []

        bsz = len(prompts)
        params = self.model.params
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        # Baseline for the early-stop check further down.
        # NOTE(review): only the first prompt's newline count is used, even
        # for batches with several different prompts -- confirm intended.
        count_newlines = prompts[0].count("\n")

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
        prompt_sizes = [len(t) for t in prompt_tokens]
        min_prompt_size = min(prompt_sizes)
        max_prompt_size = max(prompt_sizes)

        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)

        tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).long()
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t).long()
            # The last slot is pre-filled with EOS. Assuming eos_id differs
            # from pad_id, the mask below treats it as given input, so it is
            # never overwritten by a generated token.
            tokens[k, -1] = self.tokenizer.eos_id
        input_text_mask = tokens != self.tokenizer.pad_id
        start_pos = min_prompt_size
        prev_pos = 0
        decoded = [None] * bsz
        for cur_pos in trange(start_pos, total_len, desc="forward"):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits, dim=-1)
            next_token = next_token.reshape(-1).cpu()
            # Keep the existing token wherever this position was already
            # filled (prompt text or the trailing EOS); otherwise take the
            # newly produced one.
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            prev_pos = cur_pos

            # Decode and print the progress of every sequence after each step.
            print("-" * 30)
            for i, t in enumerate(tokens.tolist()):
                # Keep at most ``max_gen_len`` tokens past the prompt and
                # nothing at or beyond the current position.
                # NOTE(review): the slice excludes the token just written at
                # ``cur_pos``, so the decoded text lags one step behind.
                t = t[: min(cur_pos, len(prompt_tokens[i]) + max_gen_len)]
                # Cut at the first EOS token, if there is one.
                try:
                    t = t[: t.index(self.tokenizer.eos_id)]
                except ValueError:
                    pass
                try:
                    d = self.tokenizer.decode(t)
                    print([i] * 20)
                    print(d)
                    decoded[i] = d

                    # Stop the whole batch once a sequence has gained a line.
                    result_count_newlines = d.count("\n")
                    if result_count_newlines > count_newlines:
                        return decoded
                except IndexError:
                    # Best effort: report the failure and keep going.
                    traceback.print_exc()
                    print(t)
            print("-" * 30)
        return decoded
|
|
|
|
def sample_top_p(probs, p):
    """Draw one index per row of ``probs`` using top-p (nucleus) sampling.

    Entries are ranked by descending probability along the last dimension.
    An entry is dropped when the probability mass ranked ahead of it already
    exceeds ``p``; the survivors are renormalised and one is sampled.

    Args:
        probs: Tensor of probabilities over the last dimension.
        p: Cumulative-probability threshold.

    Returns:
        Tensor of sampled indices into the last dimension of ``probs``, with
        a last dimension of size 1.
    """
    ranked = torch.sort(probs, dim=-1, descending=True)
    sorted_probs, sorted_idx = ranked.values, ranked.indices
    # Exclusive cumulative sum: the mass strictly ahead of each entry.
    mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
    sorted_probs[mass_before > p] = 0.0
    # Renormalise the surviving entries so they sum to one again.
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    choice = torch.multinomial(sorted_probs, num_samples=1)
    # Translate the position in the ranking back to the original index.
    return torch.gather(sorted_idx, -1, choice)
|
|