|
|
import codecs
import os
import pickle

import tiktoken
import torch
from rich.traceback import install

from model import GPTConfig, GPT
|
|
|
# Pretty tracebacks from `rich` for any uncaught exception.
install()

# --- Configuration ---------------------------------------------------------
ckpt_path = 'out/ckpt.pt'
meta_path = 'data/mydata/meta.pkl'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer_name = 'cl100k_base'
max_new_tokens = 1024
temperature = 0.8
top_k = 100
# Passed as `allowed_special` when encoding, and used to suppress tokens whose
# text matches one of these strings from the printed output.
# NOTE(review): cl100k_base appears to define only "<|endoftext|>" as a special
# token; if so, tiktoken encodes "<|im_start|>"/"<|im_stop|>" as ordinary
# multi-token text — confirm this matches how the training data was encoded.
special_tokens = {"<|endoftext|>", "<|im_start|>", "<|im_stop|>"}

# --- Tokenizer -------------------------------------------------------------
enc = tiktoken.get_encoding(tokenizer_name)
encode = enc.encode
decode = enc.decode

# --- Dataset metadata --------------------------------------------------------
# NOTE(security): pickle.load can execute arbitrary code; only load meta files
# you produced or otherwise trust.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']

# --- Checkpoint ------------------------------------------------------------
# Load onto the CPU: the weights are copied into a CPU-constructed model and the
# model is moved to `device` afterwards, so anything else stored in the
# checkpoint never has to occupy accelerator memory.
checkpoint = torch.load(ckpt_path, map_location='cpu')
model_args = checkpoint['model_args']
# The vocabulary size from meta.pkl overrides the checkpoint's value.
# NOTE(review): it must equal the size the weights were trained with, or
# load_state_dict below fails with a shape mismatch — confirm.
model_args['vocab_size'] = vocab_size
block_size = model_args.get('block_size', 1024)

# --- Model -----------------------------------------------------------------
model = GPT(GPTConfig(**model_args))
state_dict = checkpoint['model']
# A state dict saved from a torch.compile()-wrapped module prefixes every key
# with "_orig_mod."; strip it (in place) so the keys match the plain GPT module.
unwanted_prefix = '_orig_mod.'
for key in list(state_dict):
    if key.startswith(unwanted_prefix):
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
|
|
|
@torch.no_grad()
def generate_stream(model, input_ids, max_new_tokens, temperature=1.0, top_k=None):
    """Sample up to ``max_new_tokens`` tokens from ``model``, streaming them to stdout.

    Each step runs the (cropped) context through the model, takes the logits of
    the last position, applies temperature and optional top-k filtering, samples
    one token, appends it to the context and prints its text. Generation stops
    early once the ``<|endoftext|>`` token is sampled. A trailing newline is
    printed when generation ends.

    Args:
        model: Module called as ``logits, _ = model(input_ids)``; the logits are
            indexed as ``logits[:, -1, :]`` (presumably batch, seq, vocab —
            confirm against the model definition).
        input_ids: LongTensor of token ids indexed as (batch, seq). Batch size
            must be 1, because the sampled token is read with ``.item()``.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Logits are divided by this before sampling. ``0`` selects
            greedy (argmax) decoding instead of dividing by zero.
        top_k: If not None, keep only the ``top_k`` largest logits (clamped to
            the vocabulary size) before sampling. Ignored for greedy decoding.

    Returns:
        The running context: the input plus every sampled token (including a
        final ``<|endoftext|>`` if one was sampled), with the oldest tokens
        dropped whenever the context exceeded ``block_size`` at the start of a
        step.
    """
    model.eval()
    special_token_id = encode("<|endoftext|>", allowed_special=special_tokens)[0]
    # A single BPE token may hold only part of a multi-byte UTF-8 character;
    # decoding tokens one at a time would print U+FFFD replacement characters.
    # The incremental decoder buffers incomplete sequences until they finish.
    utf8_decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    for _ in range(max_new_tokens):
        # Keep at most `block_size` tokens of context.
        if input_ids.size(1) > block_size:
            input_ids = input_ids[:, -block_size:]

        logits, _ = model(input_ids)
        logits = logits[:, -1, :]

        if temperature == 0:
            # Greedy decoding: dividing by zero would yield inf/nan logits and
            # make torch.multinomial fail.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            logits = logits / temperature
            if top_k is not None:
                # torch.topk raises if k exceeds the vocabulary size.
                k = min(top_k, logits.size(-1))
                v, _ = torch.topk(logits, k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = torch.nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        next_token_id = next_token.item()

        input_ids = torch.cat((input_ids, next_token), dim=1)

        token_bytes = enc.decode_single_token_bytes(next_token_id)
        # Same text as decode([next_token_id]); used only to suppress tokens
        # that spell one of `special_tokens` on their own.
        if token_bytes.decode("utf-8", errors="replace") not in special_tokens:
            print(utf8_decoder.decode(token_bytes), end='', flush=True)

        if next_token_id == special_token_id:
            break

    # Flush any bytes left over from an unfinished multi-byte sequence, then
    # end the line.
    print(utf8_decoder.decode(b"", final=True))
    return input_ids
|
|
|
def main():
    """Run the interactive chat loop on stdin/stdout.

    Each line typed by the user is wrapped in the chat template below, encoded,
    and passed to ``generate_stream``, which prints the reply. The session ends
    when the user types ``exit`` or ``quit`` (case-insensitive, surrounding
    whitespace ignored), presses Ctrl+C, or stdin reaches EOF (e.g. Ctrl+D).
    """
    # NOTE(review): the "π€" / "π" prefixes in the messages below look like
    # mis-encoded UTF-8 emoji (mojibake); they are kept byte-for-byte until the
    # intended characters are confirmed.
    print("π€ AI Assistant is ready. Type 'exit' or press Ctrl+C to quit.\n")

    try:
        while True:
            try:
                user_input = input("You: ")
            except EOFError:
                # stdin closed (Ctrl+D or end of piped input): exit cleanly
                # instead of dying with a traceback.
                print("\nπ Exiting assistant.")
                break

            if user_input.strip().lower() in {"exit", "quit"}:
                print("π Exiting assistant.")
                break

            # Chat template, spelled out line by line.
            # NOTE(review): reconstructed from the visible source, with no
            # indentation inside the template; it must match the format the
            # model was trained on byte-for-byte — confirm.
            # NOTE(security): because "<|endoftext|>" is in `special_tokens`,
            # a user typing that literal string gets it encoded as the real
            # special token.
            prompt = (
                "\n"
                "<|im_start|>user\n"
                f"{user_input}<|endoftext|>\n"
                "<|im_stop|>\n"
                "\n"
                "<|im_start|>assistant\n"
                "\n"
            )
            input_ids = torch.tensor(
                encode(prompt, allowed_special=special_tokens),
                dtype=torch.long,
                device=device,
            )[None, ...]  # add a batch dimension of size 1

            print("π€ Assistant:", end=' ', flush=True)
            generate_stream(model, input_ids, max_new_tokens, temperature, top_k)
            print("-" * 50)

    except KeyboardInterrupt:
        print("\nπ Exiting assistant.")


if __name__ == "__main__":
    main()
|
|
|