# Patram-7b-Demo / app.py
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
import spaces
import threading
# --- 1. Model and Processor Setup ---
model_id = "bharatgenai/patram-7b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load processor and model
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True
)
print("Model and processor loaded successfully.")
# Default system prompt
DEFAULT_SYSTEM_PROMPT = """You are Patram, a helpful AI assistant created by BharatGenAI. You are designed to analyze images and answer questions about them.
Think step by step before providing your answers. Be detailed, accurate, and helpful in your responses.
You can understand both text and image inputs to provide comprehensive answers to user queries."""
# --- Define and apply a more flexible chat template ---
chat_template = """{% for message in messages %}
{{ message['role'].capitalize() }}: {{ message['content'] }}
{% if not loop.last %}{{ '\n' }}{% endif %}
{% endfor %}
{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"""
processor.tokenizer.chat_template = chat_template
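# Rendered example (illustrative) for a [system, user] history with add_generation_prompt=True:
#   System: You are Patram, a helpful AI assistant ...
#   User: What is shown in this document?
#   Assistant: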
# --- 2. Gradio Chatbot Logic ---
@spaces.GPU
def generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
"""
Generate a response from the model using streaming.
"""
try:
# Create a copy of the messages list to avoid modifying the original
current_messages = messages_list.copy()
current_messages.append({"role": "user", "content": user_message})
print(current_messages)
# Use the processor to apply the chat template
prompt = processor.tokenizer.apply_chat_template(
current_messages,
tokenize=False,
add_generation_prompt=True
)
if image_pil:
# Preprocess image and the entire formatted prompt
inputs = processor.process(images=[image_pil], text=prompt)
else:
inputs = processor.process(text=prompt)
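        # processor.process returns unbatched tensors, so add a batch dimension with
        # unsqueeze(0) and cast any float32 tensors to float16 to match the model weights.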
inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
# Initialize the streamer
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
# Define generation config
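        # do_sample=True is what makes top_p / top_k / temperature take effect;
        # without it, generation falls back to greedy decoding and ignores them.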
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
eos_token_id=processor.tokenizer.eos_token_id,
pad_token_id=processor.tokenizer.pad_token_id
)
# Generate output using model's specific method
generate_kwargs = dict(
batch=inputs,
streamer=streamer,
generation_config=generation_config
)
# Start the generation in a separate thread to allow streaming
thread = threading.Thread(target=model.generate_from_batch, kwargs=generate_kwargs)
thread.start()
# Yield the generated tokens as they become available
response = ""
for new_token in streamer:
response += new_token
yield response
except Exception as e:
print(f"Error during inference: {e}")
yield f"Sorry, an error occurred during processing: {e}"
def process_chat(user_message, chatbot_display, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
"""
This function handles the chat logic for a single turn with streaming.
"""
# Append user's message to the chatbot display list
chatbot_display.append((user_message, ""))
# Generate the response using streaming
response = ""
for chunk in generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
response = chunk
# Update the chatbot display with the current response
chatbot_display[-1] = (user_message, response)
yield chatbot_display, messages_list, ""
    # Persist both sides of this turn in the conversation history
    messages_list.append({"role": "user", "content": user_message})
    messages_list.append({"role": "assistant", "content": response})
    # Emit a final update so the stored state reflects the completed turn
    yield chatbot_display, messages_list, ""
def clear_chat():
"""Resets the chat, history, and image."""
    return [], [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}], None, "", 256, 0.9, 50, 0.6, DEFAULT_SYSTEM_PROMPT
# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
gr.Markdown("# πŸ€– Patram-7B-Instruct Chatbot")
gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")
# State variables to hold conversation history and image
messages_list = gr.State([])
image_input = gr.State(None)
with gr.Row():
with gr.Column(scale=1):
image_input_render = gr.Image(type="pil", label="Upload Image")
clear_btn = gr.Button("πŸ—‘οΈ Clear Chat and Image")
with gr.Accordion("Generation Parameters", open=False):
                system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, interactive=True, lines=5)
max_new_tokens = gr.Slider(minimum=32, maximum=4096, value=256, step=32, label="Max New Tokens")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (Nucleus Sampling)")
top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.05, label="Temperature")
with gr.Column(scale=2):
chatbot_display = gr.Chatbot(
label="Conversation",
bubble_full_width=False,
height=500
)
with gr.Row():
user_textbox = gr.Textbox(
placeholder="Type your question here...",
show_label=False,
scale=4,
container=False
)
submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
# Initialize messages_list with system prompt
demo.load(
fn=lambda: [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}],
inputs=None,
outputs=messages_list
)
# Update messages_list when system prompt changes
system_prompt.change(
fn=lambda system_prompt: [{"role": "system", "content": system_prompt}],
inputs=system_prompt,
outputs=messages_list
)
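    # Note: editing the system prompt replaces the stored history with just the new
    # system message, so earlier conversation turns are discarded.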
# --- Event Listeners ---
# Define the action for submitting a message (via button or enter key)
    user_textbox.submit(
fn=process_chat,
inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature],
outputs=[chatbot_display, messages_list, user_textbox]
)
submit_btn.click(
fn=process_chat,
inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature],
outputs=[chatbot_display, messages_list, user_textbox]
)
# Define the action for the clear button
clear_btn.click(
fn=clear_chat,
inputs=[],
outputs=[chatbot_display, messages_list, image_input_render, user_textbox, max_new_tokens, top_p, top_k, temperature, system_prompt],
queue=False
)
# Update the image state when a new image is uploaded
image_input_render.change(
fn=lambda x: x,
inputs=image_input_render,
outputs=image_input
)
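# mcp_server=True additionally exposes the app's functions over the Model Context
# Protocol (available in recent Gradio versions with the gradio[mcp] extra installed).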
if __name__ == "__main__":
demo.launch(mcp_server=True)