Spaces:
Runtime error
Runtime error
File size: 3,996 Bytes
673210b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
# DESCRIPTION = ""
# if not torch.cuda.is_available():
# DESCRIPTION = "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
device = torch.device("cuda")
print('There are %d GPU(s) available.' % torch.cuda.device_count())
print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
print('No GPU available, using the CPU instead.')
device = torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained("Back-up/T5-pretrain")
model = AutoModelForSeq2SeqLM.from_pretrained("Back-up/T5-large-QA")
model.to(device)
@spaces.GPU
def generate(
message: str,
chat_history: list[tuple[str, str]],
system_prompt: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
tokenized_text = tokenizer.encode(message, return_tensors="pt").to(model.device)
model.eval()
summary_ids = model.generate(
tokenized_text,
max_length=1024,
min_length=8,
num_beams=5,
repetition_penalty=2.5,
length_penalty=1.0,
early_stopping=True
)
output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
yield output
chat_interface = gr.ChatInterface(
fn=generate,
additional_inputs=[
gr.Textbox(label="System prompt", lines=6),
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.6,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.2,
),
],
stop_btn=None,
examples=[
["Trường đại học Nông Lâm thành phố Hồ Chí Minh nằm ở đâu?"],
["Mục tiêu chiến lược của trường đại học Nông Lâm thành phố Hồ Chí Minh là gì?"],
["Sinh viên được khen thưởng cá nhân và tập thể khi nào?"],
["Điều kiện cơ bản để được hỗ trợ vay tiền sinh viên là gì?"],
["Trường Đại học Nông Lâm đã trải qua bao nhiêu năm hoạt động tính đến năm 2023?"],
["Những hành vi nào của sinh viên bị coi là vi phạm quy định của Nhà trường?"],
["Địa chỉ của Phân hiệu Trường Đại học Nông Lâm tại Ninh Thuận?"],
["Làm thế nào khi sinh viên không hài lòng với việc giải quyết thắc mắc của Trưởng Bộ môn?"],
["Làm thế để yêu cầu phúc khảo bài thi?"],
["Nghĩa vụ của sinh viên là gì?"],
["Viết cho tôi một chương trình tính số nguyên tố bằng python."]
],
)
with gr.Blocks(css="style.css") as demo:
chat_interface.render()
if __name__ == "__main__":
demo.queue(max_size=20).launch()
|