# gradio_chat_001 / app.py
import os
from collections.abc import Iterator
from threading import Thread
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import logging
import time
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gradio_chat_001")
logger.setLevel(logging.INFO)
logger.info("Starting logging for gradio_chat_001.")
categories = [
"Legal", "Specification", "Facts and Figures",
"Publication", "Payment Scheme",
"Alternative Payment Systems", "Crypto Payments",
"Card Payments", "Banking", "Regulations", "Account Payments"
]
logging.debug("Categories to classify: " + repr(categories))
# DESCRIPTION = """\
# # Llama 3.2 3B Instruct
# Llama 3.2 3B is Meta's latest iteration of open LLMs.
# This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
# For more details, please check [our post](https://huggingface.co/blog/llama32).
# """
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
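# CUDA is deliberately disabled below so the model always runs on CPU in this Space.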
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    logger.warning("CUDA is available, but this Space is meant to run on CPU; disabling it.")
    USE_CUDA = False
device = torch.device("cuda:0" if USE_CUDA else "cpu")
# model_id = "meta-llama/Llama-3.2-3B-Instruct"
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
)
logger.info("Created model: " + model_id)
logger.info("Model repr: " + repr(model))
logger.info("Tokenizer repr: " + repr(tokenizer))
model.eval()
# Example:
# from transformers import AutoTokenizer, DeepseekV3ForCausalLM
# model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
# tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf")
# prompt = "Hey, are you conscious? Can you talk to me?"
# inputs = tokenizer(prompt, return_tensors="pt")
# # Generate
# generate_ids = model.generate(inputs.input_ids, max_length=30)
# tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
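# generate(): reference streaming implementation. It builds a chat-template prompt, trims it to
# MAX_INPUT_TOKEN_LENGTH, and yields the accumulated reply as chunks arrive. The UI below uses
# try_generate() instead, which wraps the same steps in error handling.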
# @spaces.GPU(duration=90)
def generate(
message: str,
chat_history: list[dict],
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
conversation = [*chat_history, {"role": "user", "content": message}]
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
logger.warn(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
attention_mask = torch.ones_like(input_ids)
streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids, "attention_mask": attention_mask},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
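    # Run model.generate() in a background thread so the streamer can be consumed here
    # and partial output yielded back to Gradio as it is produced.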
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
def analyse_time_array(arr, extended=False):
length = len(arr)
if length == 0:
return "Empty"
if length == 1:
return "Start"
start = arr[0]
end = arr[-1]
diff = end - start
msg = f"{length-1} Tokens in {diff}s | {diff/length} Tokens/s"
if extended:
diffs = sorted([arr[i+1]-arr[i] for i in range(0, length-1)])
# msg += "\nDiffs between tokens:"
msg += "\nBest/shortest: " + ", ".join(f"{x:.02f}s" for x in diffs[:5])
msg += "\nWorst/longest: " + ", ".join(f"{x:.02f}s" for x in diffs[-5:])
return msg
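# Visual separator between the generated text and the timing info appended to each yield.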
SPACER = "\n\n" + "-"*80 + "\n\n"
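# try_generate(): same flow as generate(), but each stage (prompt building, streamer setup,
# thread start, streaming loop) has its own try/except, so failures are reported in the chat
# together with timing statistics instead of silently killing the request.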
def try_generate(
message: str,
chat_history: list[dict],
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
):
try:
logger.info("Create input")
yield "<Create Input>"
conversation = [*chat_history, {"role": "user", "content": message}]
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
logger.warn(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
yield f"<input_ids>{repr(input_ids)}</input_ids>"
input_ids = input_ids.to(model.device)
attention_mask = torch.ones_like(input_ids)
except Exception as e:
logger.warn("Failed to create input parameters: " + repr(e))
yield "Failed to create input parameters: " + repr(e)
return
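    # Build the streamer and generation kwargs; the 120 s streamer timeout gives the CPU-only
    # model time to produce its first chunk before the iterator raises.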
try:
streamer = TextIteratorStreamer(tokenizer, timeout=120.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
{"input_ids": input_ids, "attention_mask": attention_mask},
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
except Exception as e:
msg ="Failed to create streamer: " + repr(e)
logger.warning(msg)
yield msg
return
try:
yield "<start thread>"
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
except Exception as e:
msg = "Failed to create thread: " + repr(e)
logger.warning(msg)
yield msg
return
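    # Stream chunks as they arrive, recording a timestamp per chunk so throughput
    # can be reported alongside the generated text.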
outputs = []
times = [time.time()]
try:
yield "<start text>"
for text in streamer:
outputs.append(text)
times.append(time.time())
# yield "".join(outputs)
msg = "".join(outputs)
info = analyse_time_array(times, True)
yield msg+SPACER+info
except Exception as e:
n = len(outputs)
exp = repr(e)
error = f"Failed creating output @ position {n}: {exp}"
logger.warning(error)
msg = "".join(outputs)
info = analyse_time_array(times, True)
yield msg+SPACER+info+SPACER+error
# yield f"{output}\n--------------------\n{msg}"
msg = "".join(outputs)
info = analyse_time_array(times, True)
yield msg+SPACER+info+"\n--- DONE ---"
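# Chat UI: try_generate() streams replies; the sliders map onto its keyword arguments in order.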
demo = gr.ChatInterface(
fn=try_generate,
additional_inputs=[
gr.Slider(
label="Max new tokens",
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
),
gr.Slider(
label="Temperature",
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.6,
),
gr.Slider(
label="Top-p (nucleus sampling)",
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
),
gr.Slider(
label="Top-k",
minimum=1,
maximum=1000,
step=1,
value=50,
),
gr.Slider(
label="Repetition penalty",
minimum=1.0,
maximum=2.0,
step=0.05,
value=1.2,
),
],
stop_btn=None,
examples=[
["Hello there! How are you doing?"],
["Can you explain briefly to me what is the Python programming language?"],
["Explain the plot of Cinderella in a sentence."],
["How many hours does it take a man to eat a Helicopter?"],
["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
],
cache_examples=False,
type="messages",
# description=DESCRIPTION,
# css_paths="style.css",
fill_height=True,
)
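# Queue requests (at most 20 waiting) and launch the app.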
if __name__ == "__main__":
demo.queue(max_size=20).launch()