# Patram-7b-Demo / app.py
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
from threading import Thread
import spaces
# --- 1. Model and Processor Setup ---
# This part is loaded only once when the script starts.
try:
model_id = "bharatgenai/patram-7b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load processor and model
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
        torch_dtype=torch.float16,  # float16 reduces GPU memory usage
        device_map="auto",          # place the model weights across available devices automatically
trust_remote_code=True
)
print("Model and processor loaded successfully.")
# --- Define and apply the chat template ---
# This is crucial for multi-turn conversation
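    # It renders the history as "User: ... Assistant: ... User: ..." and, with
    # add_generation_prompt, appends "Assistant:" as the cue for the model's reply.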
chat_template = """{% for message in messages -%}
{%- if (loop.index % 2 == 1 and message['role'].lower() != 'user') or
       (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{{ message['role'].capitalize() + ': ' + message['content'] }}
{%- if not loop.last -%}
{{ ' ' }}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ ' Assistant:' }}
{%- endif %}"""
processor.tokenizer.chat_template = chat_template
except Exception as e:
print(f"Error during model loading: {e}")
# Exit if model can't be loaded, as the app is unusable.
    raise SystemExit(1)
# --- 2. Gradio Chatbot Logic with Streaming ---
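# @spaces.GPU requests a GPU for the duration of each call when running on Hugging Face ZeroGPU Spaces.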
@spaces.GPU
def process_chat_streaming(user_message, chatbot_display, messages_list, image_pil):
"""
This generator function handles the chat logic with streaming.
It yields the updated chatbot display at each step.
"""
# Check if an image has been uploaded
if image_pil is None:
chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
yield chatbot_display, messages_list
return # Stop the generator
# Append user's message to the conversation history and display
messages_list.append({"role": "user", "content": user_message})
chatbot_display.append((user_message, "")) # Add an empty spot for the streaming response
try:
# Use the processor to apply the chat template
prompt = processor.tokenizer.apply_chat_template(
messages_list,
tokenize=False,
add_generation_prompt=True
)
# Preprocess image and the entire formatted prompt
inputs = processor.process(images=[image_pil], text=prompt)
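        # unsqueeze(0) adds the batch dimension the model expects for a single example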
inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
# Setup the streamer
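        # skip_prompt avoids echoing the input; skip_special_tokens strips tokens like <|endoftext|> from the stream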
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
# Define generation configuration
generation_config = GenerationConfig(
max_new_tokens=512,
do_sample=True,
top_p=0.9,
temperature=0.6,
stop_strings=["<|endoftext|>", "User:"] # Add stop strings to prevent over-generation
)
# Create generation kwargs for the thread
generation_kwargs = dict(
inputs,
streamer=streamer,
generation_config=generation_config
)
# Run generation in a separate thread
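        # generate_from_batch is provided by the model's trust_remote_code implementation.
        # Note: plain transformers generate() requires a tokenizer (e.g. tokenizer=processor.tokenizer)
        # to resolve stop_strings; the remote code may or may not handle this internally.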
thread = Thread(target=model.generate_from_batch, kwargs=generation_kwargs)
thread.start()
# Yield updates to the Gradio UI
full_response = ""
for new_text in streamer:
full_response += new_text
chatbot_display[-1] = (user_message, full_response)
yield chatbot_display, messages_list
# After the loop, the generation is complete.
# Add the final full response to the messages list for context.
messages_list.append({"role": "assistant", "content": full_response})
yield chatbot_display, messages_list # Yield the final state
except Exception as e:
print(f"Error during streaming inference: {e}")
error_message = f"Sorry, an error occurred: {e}"
chatbot_display[-1] = (user_message, error_message)
yield chatbot_display, messages_list
# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
gr.Markdown("# πŸ€– Patram-7B-Instruct Streaming Chatbot")
gr.Markdown("Upload an image and ask questions about it. The response will stream in real-time.")
# State variables to hold conversation history
messages_list = gr.State([])
with gr.Row():
with gr.Column(scale=1):
image_input = gr.Image(type="pil", label="Upload Image")
clear_btn = gr.Button("πŸ—‘οΈ Clear Chat and Image")
with gr.Column(scale=2):
chatbot_display = gr.Chatbot(
label="Conversation",
bubble_full_width=False,
height=500,
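                # avatar_images is (user_avatar, assistant_avatar); None keeps the default user avatar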
avatar_images=(None, "https://cdn-avatars.huggingface.co/v1/production/uploads/67b462a1f4f414c2b3e2bc2f/EnVeNWEIeZ6yF6ueZ7E3Y.jpeg")
)
with gr.Row():
user_textbox = gr.Textbox(
placeholder="Type your question here...",
show_label=False,
scale=4,
container=False
)
    # There is no separate send button; pressing Enter in the textbox submits the message.
# --- Event Listeners ---
# Define the action for submitting a message (via enter key)
submit_action = user_textbox.submit(
fn=process_chat_streaming,
inputs=[user_textbox, chatbot_display, messages_list, image_input],
outputs=[chatbot_display, messages_list],
        # The queue stays enabled (default) because this handler is a streaming generator.
)
# Chain the action to also clear the textbox after submission
submit_action.then(
fn=lambda: gr.update(value=""),
inputs=None,
outputs=[user_textbox],
queue=False
)
# Define the action for the clear button
clear_btn.click(
fn=lambda: ([], [], None, ""), # Function to return empty/default values
inputs=[],
outputs=[chatbot_display, messages_list, image_input, user_textbox],
queue=False
)
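# mcp_server=True additionally exposes the app's functions over the Model Context Protocol (supported in recent Gradio releases).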
if __name__ == "__main__":
demo.launch(debug=True, mcp_server=True)