import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
from threading import Thread
import spaces

# --- 1. Model and Processor Setup ---
# This part runs only once, when the script starts.
try:
    model_id = "bharatgenai/patram-7b-instruct"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load processor and model
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # Use float16 for lower memory usage on GPU
        device_map="auto",          # Automatically place weights on available GPUs
        trust_remote_code=True
    )
    print("Model and processor loaded successfully.")

    # --- Define and apply the chat template ---
    # This is crucial for multi-turn conversation.
    chat_template = """{% for message in messages -%}
{%- if (loop.index % 2 == 1 and message['role'] != 'user') or
(loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{{ message['role'].capitalize() + ': ' + message['content'] }}
{%- if not loop.last -%}
{{ ' ' }}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ ' Assistant:' }}
{%- endif %}"""
    processor.tokenizer.chat_template = chat_template
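    # For reference, a hypothetical two-turn history such as
    #   [{"role": "user", "content": "What is in this image?"},
    #    {"role": "assistant", "content": "A dog."},
    #    {"role": "user", "content": "What breed is it?"}]
    # is rendered by this template (with add_generation_prompt=True) as the single line:
    #   "User: What is in this image? Assistant: A dog. User: What breed is it? Assistant:"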
except Exception as e:
    print(f"Error during model loading: {e}")
    # Exit if the model can't be loaded, as the app is unusable.
    exit()

# --- 2. Gradio Chatbot Logic with Streaming ---
@spaces.GPU  # Request GPU for this handler (assumes a ZeroGPU Space; this decorator is a no-op elsewhere)
def process_chat_streaming(user_message, chatbot_display, messages_list, image_pil):
    """
    This generator function handles the chat logic with streaming.
    It yields the updated chatbot display at each step.
    """
    # Check if an image has been uploaded
    if image_pil is None:
        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
        yield chatbot_display, messages_list
        return  # Stop the generator

    # Append the user's message to the conversation history and the display
    messages_list.append({"role": "user", "content": user_message})
    chatbot_display.append((user_message, ""))  # Add an empty slot for the streaming response
    try:
        # Use the tokenizer to apply the chat template to the full history
        prompt = processor.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True
        )

        # Preprocess the image together with the entire formatted prompt
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}

        # Set up the streamer that receives tokens as they are generated
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Define generation configuration
        generation_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            top_p=0.9,
            temperature=0.6,
            stop_strings=["<|endoftext|>", "User:"]  # Stop strings to prevent over-generation
        )

        # Run generation in a separate thread so tokens can be streamed to the UI.
        # generate_from_batch takes the processed batch as its first argument; the
        # tokenizer is passed so the stop_strings can be matched during generation.
        thread = Thread(
            target=model.generate_from_batch,
            args=(inputs,),
            kwargs=dict(
                streamer=streamer,
                generation_config=generation_config,
                tokenizer=processor.tokenizer
            )
        )
        thread.start()

        # Yield incremental updates to the Gradio UI
        full_response = ""
        for new_text in streamer:
            full_response += new_text
            chatbot_display[-1] = (user_message, full_response)
            yield chatbot_display, messages_list

        # After the loop, generation is complete.
        # Add the final full response to the messages list for context.
        messages_list.append({"role": "assistant", "content": full_response})
        yield chatbot_display, messages_list  # Yield the final state
    except Exception as e:
        print(f"Error during streaming inference: {e}")
        error_message = f"Sorry, an error occurred: {e}"
        chatbot_display[-1] = (user_message, error_message)
        yield chatbot_display, messages_list
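
# For reference, the same inference can be run without streaming in a single call, following
# the usual generate_from_batch pattern for this model family (comment-only sketch, not executed):
#
#   output = model.generate_from_batch(
#       inputs,
#       GenerationConfig(max_new_tokens=512, stop_strings=["<|endoftext|>", "User:"]),
#       tokenizer=processor.tokenizer,
#   )
#   generated_tokens = output[0, inputs["input_ids"].size(1):]
#   response = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)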

# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Streaming Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The response will stream in real time.")

    # State variable to hold the conversation history
    messages_list = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")
        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500,
                avatar_images=(None, "https://cdn-avatars.huggingface.co/v1/production/uploads/67b462a1f4f414c2b3e2bc2f/EnVeNWEIeZ6yF6ueZ7E3Y.jpeg")
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False
                )

    # There is no separate submit button; messages are sent by pressing Enter in the textbox.

    # --- Event Listeners ---
    # Define the action for submitting a message (via the Enter key)
    submit_action = user_textbox.submit(
        fn=process_chat_streaming,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list],
        # Streaming generator outputs require Gradio's queue, so queuing is left enabled (the default).
    )

    # Chain the action to also clear the textbox after submission
    submit_action.then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=[user_textbox],
        queue=False
    )

    # Define the action for the clear button
    clear_btn.click(
        fn=lambda: ([], [], None, ""),  # Function to return empty/default values
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False
    )

if __name__ == "__main__":
    demo.launch(debug=True, mcp_server=True)
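
# To run this app locally, an assumed dependency set (not taken from the Space's
# requirements.txt) is: torch, transformers, accelerate, gradio, pillow, spaces.
#   pip install torch transformers accelerate gradio pillow spaces
#   python app.py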