import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import random
import torch
import re
# Clear existing cache
torch.cuda.empty_cache()


# Load the XGen-7B-8K instruct tokenizer and model (fp16, on GPU)
tokenizer = AutoTokenizer.from_pretrained("Salesforce/xgen-7b-8k-inst", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("Salesforce/xgen-7b-8k-inst", torch_dtype=torch.float16).to('cuda')
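
# A commented-out alternative loading sketch (assumption: the `accelerate` package is
# installed). device_map="auto" lets transformers place the fp16 weights across the
# available GPU(s) and CPU automatically, instead of the explicit .to('cuda') above.
# model = AutoModelForCausalLM.from_pretrained(
#     "Salesforce/xgen-7b-8k-inst", torch_dtype=torch.float16, device_map="auto"
# )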

# XGen LLM chat function
def xgen(input_text, history):
    """
    Encode the user's input with the tokenizer, generate a response with the
    XGen LLM, then decode the output ids back into text. The `history` argument
    is supplied by gr.ChatInterface but is not used here, so each turn is
    answered independently.
    """

    # # User's question
    # input_text = "How was jupiter created in the solar system."

    # Prompt template for LLM "context"
    header = (
        "A chat between a curious human and an artificial intelligence assistant called bubble bee. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
    )
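
    # Note (assumption, not from the original script): the XGen-7B-8K-inst examples
    # wrap each user turn as "### Human: <question>\n###" before prepending the header.
    # A sketch of that format, kept commented out so the simpler prompt built below
    # stays unchanged:
    # prompt = f"### Human: {input_text}\n###"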

    # Tokenize the prompt (header + question) and move the tensors to the GPU
    input_ids = tokenizer(header + input_text, return_tensors='pt').to('cuda')

    # The model was already loaded in fp16 above, so no extra casting is needed here

    # Let's use torch.no_grad() to save memory and computation
    with torch.no_grad():
        # Generate output from LLM
        outputs = model.generate(**input_ids,
                                 max_new_tokens=256,
                                 top_k=100,
                                 eos_token_id=50256)
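
        # Note: generate() defaults to greedy decoding, so top_k above only takes
        # effect if sampling is enabled (e.g. by also passing do_sample=True).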

    # Decode the output tensors into string
    outputs_decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # CLEAN UP TEXT
    output_text = outputs_decoded.replace(header, "").strip()
    output_text = re.sub(r'^Assistant:\s*', '', output_text)
    output_text = output_text.replace('<|endoftext\>', '').strip()

    return output_text
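
# Minimal usage sketch (assumes the model above loaded onto a CUDA GPU). The second
# argument is the chat history that gr.ChatInterface passes in; it is ignored here.
#   reply = xgen("How was jupiter created in the solar system.", [])
#   print(reply)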

torch.cuda.empty_cache()

# Create the Gradio chat UI

chatbot = gr.Chatbot(height=700, label='Gradio ChatInterface')

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=xgen,
        chatbot=chatbot,
        fill_height=True,
        title="Bubble Bee 🐝"
    )

if __name__ == "__main__":
    demo.launch()