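"""Gradio chatbot for bharatgenai/patram-7b-instruct.

Upload an image and ask questions about it; the conversation history is
kept in a Gradio State so follow-up questions retain context.
"""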
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
from PIL import Image
import gradio as gr
import spaces
# --- 1. Model and Processor Setup ---
model_id = "bharatgenai/patram-7b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load processor and model
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # float16 halves GPU memory use vs. float32
    device_map="auto",          # automatically places weights on available GPUs
    trust_remote_code=True,
)
print("Model and processor loaded successfully.")
# --- Define and apply the chat template ---
# Enforces strictly alternating user/assistant turns and renders them as
# "User: ... Assistant: ..." separated by single spaces.
chat_template = """{% for message in messages -%}
{%- if (loop.index % 2 == 1 and message['role'] != 'user') or
   (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
    {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{{ message['role'].capitalize() + ': ' + message['content'] }}
{%- if not loop.last -%}
    {{ ' ' }}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{ ' Assistant:' }}
{%- endif %}"""
processor.tokenizer.chat_template = chat_template
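# Example rendering (derived by tracing the template above) for a single
# user turn with add_generation_prompt=True:
#   [{"role": "user", "content": "What is in this image?"}]
# produces the prompt string:
#   "User: What is in this image? Assistant:"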
# --- 2. Gradio Chatbot Logic ---
@spaces.GPU  # request a GPU on Spaces ZeroGPU; a no-op on other hardware
def process_chat(user_message, chatbot_display, messages_list, image_pil):
    """Runs one chat turn: formats history, runs inference, updates the UI."""
    if image_pil is None:
        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
        return chatbot_display, messages_list, ""

    messages_list.append({"role": "user", "content": user_message})
    chatbot_display.append((user_message, None))

    try:
        # Render the whole conversation with the chat template defined above
        prompt = processor.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Preprocess the image and the entire formatted prompt together
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}

        # Ensure all float32 tensors match the model's float16 weights
        inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

        # Generate output using the model's custom batch-generation method
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(
                max_new_tokens=512,
                do_sample=True,
                top_p=0.9,
                temperature=0.6,
                stop_strings="<|endoftext|>",
            ),
            tokenizer=processor.tokenizer,
        )

        # Decode only the newly generated tokens, skipping the prompt
        generated_tokens = output[0, inputs["input_ids"].size(1):]
        response = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        messages_list.append({"role": "assistant", "content": response})
        chatbot_display[-1] = (user_message, response)
    except Exception as e:
        print(f"Error during inference: {e}")
        error_message = f"Sorry, an error occurred during processing: {e}"
        chatbot_display[-1] = (user_message, error_message)

    return chatbot_display, messages_list, ""
def clear_chat():
    """Resets the chat display, message history, image, and textbox."""
    return [], [], None, ""
# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")

    messages_list = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")
        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500,
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False,
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
    # --- Event Listeners ---
    # Pressing Enter in the textbox and clicking Send trigger the same handler.
    user_textbox.submit(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list, user_textbox],
    )
    submit_btn.click(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list, user_textbox],
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False,
    )
if __name__ == "__main__":
    demo.launch(mcp_server=True)
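# To run outside Spaces (a sketch; assumes this file is saved as app.py and
# that these are the only dependencies):
#   pip install torch transformers gradio spaces pillow accelerate
#   python app.py
# accelerate is assumed here because device_map="auto" requires it.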