# Revisual-R1 demo app (Hugging Face Space: csfufu/Revisual-R1-final)
# NOTE: the lines above were web-page copy residue ("cyrus28214's picture",
# commit "ed8c92d verified") and were not valid Python; kept here as a comment.
import gradio as gr
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
from qwen_vl_utils import process_vision_info
from threading import Thread
import spaces
# Hugging Face Hub repo id for the Revisual-R1 checkpoint.
file_path = "csfufu/Revisual-R1-final"
# Processor bundles the tokenizer and the image preprocessor.
# min/max_pixels bound the resolution images are resized to before being
# turned into vision tokens (values are multiples of the 28x28 patch size).
processor = AutoProcessor.from_pretrained(
    file_path,
    min_pixels=256*28*28,
    max_pixels=1280*28*28
)
# torch_dtype="auto" keeps the checkpoint's native precision;
# device_map="auto" places the weights on whatever accelerator is available.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    file_path, torch_dtype="auto", device_map="auto"
)
@spaces.GPU
def respond(
    input_dict,
    chat_history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a model reply for the latest multimodal user turn.

    Parameters
    ----------
    input_dict : dict
        Gradio multimodal textbox value: ``{"text": str, "files": [paths]}``.
    chat_history : list[dict]
        Prior turns in Gradio "messages" format. String content is plain
        text; non-string content is assumed to be a sequence of image file
        paths (Gradio stores uploaded files that way — TODO confirm).
    system_message : str
        System prompt prepended to the conversation.
    max_tokens : int
        Cap on newly generated tokens.
    temperature, top_p : float
        Sampling controls forwarded to ``model.generate``.

    Yields
    ------
    str
        The accumulated response text; Gradio re-renders on each yield.
    """
    text = input_dict["text"]
    files = input_dict["files"]

    # Rebuild the conversation in the Qwen chat-template message schema.
    messages = [{
        "role": "system",
        "content": system_message
    }]
    for message in chat_history:
        if isinstance(message["content"], str):
            content = [{"type": "text", "text": message["content"]}]
        else:
            # Non-string content: iterate the attached image paths.
            content = [
                {"type": "image", "image": image}
                for image in message["content"]
            ]
        messages.append({"role": message["role"], "content": content})

    # Current user turn: the typed text plus any freshly uploaded images.
    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                *[{"type": "image", "image": image} for image in files],
            ],
        }
    )

    image_inputs, video_inputs = process_vision_info(messages)
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    # BUG FIX: do_sample=True is required — without it generate() runs greedy
    # decoding and silently ignores temperature/top_p, so the UI sliders had
    # no effect on the output.
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    # generate() blocks, so run it on a worker thread and consume the
    # streamer from this generator to yield partial output.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
# ---- Gradio UI wiring ------------------------------------------------------
# Chat panel; raw HTML is allowed through so model markup renders as-is.
chat_panel = gr.Chatbot(
    type='messages',
    sanitize_html=False,
    scale=1,
)

# Extra controls shown under the chat box, passed to respond() positionally.
generation_controls = [
    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
    gr.Slider(minimum=1, maximum=8192, value=512, step=1, label="Max new tokens"),
    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
]

demo = gr.ChatInterface(
    fn=respond,
    title='Revisual-R1',
    type='messages',
    multimodal=True,
    chatbot=chat_panel,
    additional_inputs=generation_controls,
    examples=[[{"text": "Solve this question.", "files": ["example.png"]}]],
    cache_examples=False,
)

demo.launch(debug=True)