# Copyright (c) 2025 Team OpthChat.
#
# This source code is based on web_demo_mm.py by Alibaba Cloud.
# Licensed under Apache License 2.0
import copy
import os
import re
from argparse import ArgumentParser
from threading import Thread
import gradio as gr
import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
DEFAULT_CKPT_PATH = 'farrell236/OpthModel32B_a'
# DEFAULT_CKPT_PATH = '/scratch/llm-weights/Qwen/Qwen2.5-VL-7B-Instruct'
AUTH_TOKEN = os.environ.get("HF_spaces")
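# Parse CLI options: checkpoint path, HF auth token, CPU-only mode,
# FlashAttention-2, and Gradio server/sharing settings.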
def _get_args():
parser = ArgumentParser()
parser.add_argument('-c',
'--checkpoint-path',
type=str,
default=DEFAULT_CKPT_PATH,
help='Checkpoint name or path, default to %(default)r')
parser.add_argument('-t',
'--auth-token',
type=str,
default=AUTH_TOKEN,
help='Authentication token for model repository, default to %(default)r')
parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
parser.add_argument('--flash-attn2',
action='store_true',
default=False,
help='Enable flash_attention_2 when loading the model.')
parser.add_argument('--share',
action='store_true',
default=False,
help='Create a publicly shareable link for the interface.')
parser.add_argument('--inbrowser',
action='store_true',
default=False,
help='Automatically launch the interface in a new tab on the default browser.')
parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Demo server name.')
args = parser.parse_args()
return args
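# Load Qwen2_5_VLForConditionalGeneration in bfloat16 (optionally with FlashAttention-2)
# on the requested device; the processor is loaded from the base Qwen/Qwen2.5-VL-32B-Instruct repo.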
def _load_model_processor(args):
if args.cpu_only:
device_map = 'cpu'
else:
device_map = 'auto'
    # Check whether the flash-attn2 flag is enabled and load the model accordingly
    if args.flash_attn2:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            args.checkpoint_path,
            token=args.auth_token,
            torch_dtype=torch.bfloat16,
            attn_implementation='flash_attention_2',
            device_map=device_map)
    else:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            args.checkpoint_path,
            token=args.auth_token,
            torch_dtype=torch.bfloat16,
            device_map=device_map)
processor = AutoProcessor.from_pretrained('Qwen/Qwen2.5-VL-32B-Instruct')
return model, processor
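# Convert model output into HTML for the Gradio chatbot: turn ``` fences into
# <pre><code> blocks, escape special characters inside code fences, and join lines with <br>.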
def _parse_text(text):
lines = text.split('\n')
lines = [line for line in lines if line != '']
count = 0
for i, line in enumerate(lines):
if '```' in line:
count += 1
items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace('`', r'\`')
                    line = line.replace('<', '&lt;')
                    line = line.replace('>', '&gt;')
                    line = line.replace(' ', '&nbsp;')
                    line = line.replace('*', '&ast;')
                    line = line.replace('_', '&lowbar;')
                    line = line.replace('-', '&#45;')
                    line = line.replace('.', '&#46;')
                    line = line.replace('!', '&#33;')
                    line = line.replace('(', '&#40;')
                    line = line.replace(')', '&#41;')
                    line = line.replace('$', '&#36;')
                lines[i] = '<br>' + line
text = ''.join(lines)
return text
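# Strip <ref>/<box> grounding markup from the streamed response before it is shown in the chat.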
def _remove_image_special(text):
    text = text.replace('<ref>', '').replace('</ref>', '')
    return re.sub(r'<box>.*?(</box>|$)', '', text)
def _is_video_file(filename):
video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
return any(filename.lower().endswith(ext) for ext in video_extensions)
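# Free Python objects and empty the CUDA cache (called when the chat history is cleared).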
def _gc():
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
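# Normalise history entries into the typed {'type': 'image'/'text'/'video', ...} content
# format expected by processor.apply_chat_template and qwen_vl_utils.process_vision_info.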
def _transform_messages(original_messages):
transformed_messages = []
for message in original_messages:
new_content = []
for item in message['content']:
if 'image' in item:
new_item = {'type': 'image', 'image': item['image']}
elif 'text' in item:
new_item = {'type': 'text', 'text': item['text']}
elif 'video' in item:
new_item = {'type': 'video', 'video': item['video']}
else:
continue
new_content.append(new_item)
new_message = {'role': message['role'], 'content': new_content}
transformed_messages.append(new_message)
return transformed_messages
def _launch_demo(args, model, processor):
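    # Run model.generate in a background thread and stream partial text back through
    # TextIteratorStreamer so the chatbot can update as tokens arrive.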
def call_local_model(model, processor, messages,
max_tokens=1024,
temperature=0.6,
top_p=0.9,
top_k=50,
repetition_penalty=1.2):
messages = _transform_messages(messages)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt')
inputs = inputs.to(model.device)
tokenizer = processor.tokenizer
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs = {'streamer': streamer,
'max_new_tokens': max_tokens,
'temperature': temperature,
'top_p': top_p,
'top_k': top_k,
'repetition_penalty': repetition_penalty,
**inputs}
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()
generated_text = ''
for new_text in streamer:
generated_text += new_text
yield generated_text
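    # predict(): rebuild the full multimodal message history from task_history,
    # stream the model's reply into the chatbot, and store the final answer.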
def create_predict_fn():
def predict(_chatbot, task_history,
max_tokens, temperature, top_p, top_k, repetition_penalty):
nonlocal model, processor
chat_query = _chatbot[-1][0]
query = task_history[-1][0]
if len(chat_query) == 0:
_chatbot.pop()
task_history.pop()
return _chatbot
print('User: ' + _parse_text(query))
history_cp = copy.deepcopy(task_history)
full_response = ''
messages = []
content = []
for q, a in history_cp:
if isinstance(q, (tuple, list)):
if _is_video_file(q[0]):
content.append({'video': f'file://{q[0]}'})
else:
content.append({'image': f'file://{q[0]}'})
else:
content.append({'text': q})
messages.append({'role': 'user', 'content': content})
messages.append({'role': 'assistant', 'content': [{'text': a}]})
content = []
messages.pop()
for response in call_local_model(model, processor, messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=repetition_penalty):
_chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
yield _chatbot
full_response = _parse_text(response)
task_history[-1] = (query, full_response)
print('Qwen-VL-Chat: ' + _parse_text(full_response))
yield _chatbot
return predict
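    # regenerate(): drop the last answer from the history and re-run predict()
    # with the current generation parameters.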
def create_regenerate_fn():
        def regenerate(_chatbot, task_history,
                       max_tokens, temperature, top_p, top_k, repetition_penalty):
nonlocal model, processor
if not task_history:
return _chatbot
item = task_history[-1]
if item[1] is None:
return _chatbot
task_history[-1] = (item[0], None)
chatbot_item = _chatbot.pop(-1)
if chatbot_item[0] is None:
_chatbot[-1] = (_chatbot[-1][0], None)
else:
_chatbot.append((chatbot_item[0], None))
            _chatbot_gen = predict(_chatbot, task_history,
                                   max_tokens, temperature, top_p, top_k, repetition_penalty)
for _chatbot in _chatbot_gen:
yield _chatbot
return regenerate
predict = create_predict_fn()
regenerate = create_regenerate_fn()
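    # UI callbacks: append user text or uploaded files to both the visible chatbot
    # and the raw task history, and reset the input box / conversation state.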
def add_text(history, task_history, text):
task_text = text
history = history if history is not None else []
task_history = task_history if task_history is not None else []
history = history + [(_parse_text(text), None)]
task_history = task_history + [(task_text, None)]
        return history, task_history
def add_file(history, task_history, file):
history = history if history is not None else []
task_history = task_history if task_history is not None else []
history = history + [((file.name,), None)]
task_history = task_history + [((file.name,), None)]
return history, task_history
def reset_user_input():
return gr.update(value='')
def reset_state(_chatbot, task_history):
task_history.clear()
_chatbot.clear()
_gc()
return []
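    # Gradio layout: chatbot panel, generation-parameter sliders, text input, and control buttons.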
with gr.Blocks() as demo:
gr.Markdown("# Qwen2.5-VL (model_a) for OpthChat")
chatbot = gr.Chatbot(label='Qwen2.5-VL', elem_classes='control-height', height=500)
with gr.Accordion("Generation Parameters", open=False):
max_tokens = gr.Slider(64, 4096, value=1024, step=64, label="Max Tokens")
temperature = gr.Slider(0.0, 2.0, value=0.6, step=0.1, label="Temperature")
top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
top_k = gr.Slider(0, 100, value=50, step=1, label="Top-k")
repetition_penalty = gr.Slider(0.5, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
query = gr.Textbox(lines=2, label='Input')
task_history = gr.State([])
with gr.Row():
            addfile_btn = gr.UploadButton('📁 Upload', file_types=['image', 'video'])
            submit_btn = gr.Button('🚀 Submit')
            regen_btn = gr.Button('♻️ Regenerate')
            empty_bin = gr.Button('🧹 Clear History')
submit_btn.click(add_text,
[chatbot, task_history, query],
[chatbot, task_history]).then(predict,
[chatbot, task_history,
max_tokens, temperature, top_p, top_k, repetition_penalty],
[chatbot], show_progress=True)
submit_btn.click(reset_user_input, [], [query])
empty_bin.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)
        regen_btn.click(regenerate,
                        [chatbot, task_history,
                         max_tokens, temperature, top_p, top_k, repetition_penalty],
                        [chatbot], show_progress=True)
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
gr.Markdown("##### Note: This demo is governed by the original license of Qwen2.5-VL, "
"WebUI based on [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL/blob/main/web_demo_mm.py). "
"Developed by Alibaba Cloud, modified by Team OpthChat")
demo.queue().launch(
share=args.share,
inbrowser=args.inbrowser,
server_port=args.server_port,
server_name=args.server_name,
)
def main():
args = _get_args()
model, processor = _load_model_processor(args)
_launch_demo(args, model, processor)
if __name__ == '__main__':
main()