File size: 3,996 Bytes
673210b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))


# DESCRIPTION = ""
# if not torch.cuda.is_available():
#     DESCRIPTION = "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


if torch.cuda.is_available():
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")



tokenizer = AutoTokenizer.from_pretrained("Back-up/T5-pretrain")
model = AutoModelForSeq2SeqLM.from_pretrained("Back-up/T5-large-QA")
model.to(device)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    tokenized_text = tokenizer.encode(message, return_tensors="pt").to(model.device)

    model.eval()
    summary_ids = model.generate(
                        tokenized_text,
                        max_length=1024,
                        min_length=8,
                        num_beams=5,
                        repetition_penalty=2.5,
                        length_penalty=1.0,
                        early_stopping=True
                    )
    output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    yield output

chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Trường đại học Nông Lâm thành phố Hồ Chí Minh nằm ở đâu?"],
        ["Mục tiêu chiến lược của trường đại học Nông Lâm thành phố Hồ Chí Minh là gì?"],
        ["Sinh viên được khen thưởng cá nhân và tập thể khi nào?"],
        ["Điều kiện cơ bản để được hỗ trợ vay tiền sinh viên là gì?"],
        ["Trường Đại học Nông Lâm đã trải qua bao nhiêu năm hoạt động tính đến năm 2023?"],
        ["Những hành vi nào của sinh viên bị coi là vi phạm quy định của Nhà trường?"],
        ["Địa chỉ của Phân hiệu Trường Đại học Nông Lâm tại Ninh Thuận?"],
        ["Làm thế nào khi sinh viên không hài lòng với việc giải quyết thắc mắc của Trưởng Bộ môn?"],
        ["Làm thế để yêu cầu phúc khảo bài thi?"],
        ["Nghĩa vụ của sinh viên là gì?"],
        ["Viết cho tôi một chương trình tính số nguyên tố bằng python."]
    ],
)

with gr.Blocks(css="style.css") as demo:
    chat_interface.render()


if __name__ == "__main__":
    demo.queue(max_size=20).launch()