# Patram-7b-Demo / app.py
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
from PIL import Image
import gradio as gr
import spaces
# --- 1. Model and Processor Setup ---
model_id = "bharatgenai/patram-7b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load processor and model
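# trust_remote_code=True executes custom modeling/processing code shipped in
# the model repo; processor.process and model.generate_from_batch used below
# come from that remote code rather than the stock transformers API.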
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # float16 for lower memory usage on GPU
    device_map="auto",          # automatically place weights on available GPUs
    trust_remote_code=True,
)
print("Model and processor loaded successfully.")
# --- Define and apply the chat template ---
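# Attach a minimal Jinja chat template: turns render as plain "Role: content"
# text, and roles that do not strictly alternate user/assistant raise an error.
# The apply_chat_template call in process_chat relies on this contract.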
chat_template = """{% for message in messages -%}
{%- if (loop.index % 2 == 1 and message['role'].lower() != 'user') or
       (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif -%}
{{ message['role'].capitalize() + ': ' + message['content'] }}
{%- if not loop.last -%}
{{ ' ' }}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{ ' Assistant:' }}
{%- endif %}"""
processor.tokenizer.chat_template = chat_template
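# Illustrative rendering for a single user turn (add_generation_prompt=True):
#   [{"role": "user", "content": "Describe the image."}]
#   -> "User: Describe the image. Assistant:"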
# --- 2. Gradio Chatbot Logic ---
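# On a Hugging Face ZeroGPU Space, @spaces.GPU requests GPU hardware for the
# duration of each decorated call; outside ZeroGPU it is effectively a no-op.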
@spaces.GPU
def process_chat(user_message, chatbot_display, messages_list, image_pil):
    """Runs one chat turn over the uploaded image and updates the Gradio state."""
    if image_pil is None:
        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
        return chatbot_display, messages_list, ""

    messages_list.append({"role": "user", "content": user_message})
    chatbot_display.append((user_message, None))

    try:
        prompt = processor.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Preprocess the image and the entire formatted prompt.
        # processor.process returns unbatched tensors, so add a batch dimension.
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
        # Cast any float32 tensors (e.g. pixel values) to float16 to match the model.
        inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

        # Generate output using the model's custom batch-generation method.
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(
                max_new_tokens=512,
                do_sample=True,
                top_p=0.9,
                temperature=0.6,
                stop_strings="<|endoftext|>",
            ),
            tokenizer=processor.tokenizer,
        )

        # Decode only the newly generated tokens, skipping the prompt.
        generated_tokens = output[0, inputs["input_ids"].size(1):]
        response = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        messages_list.append({"role": "assistant", "content": response})
        chatbot_display[-1] = (user_message, response)
    except Exception as e:
        print(f"Error during inference: {e}")
        # Drop the failed user turn so the template's strict user/assistant
        # alternation is not violated on the next message.
        messages_list.pop()
        error_message = f"Sorry, an error occurred during processing: {e}"
        chatbot_display[-1] = (user_message, error_message)
    return chatbot_display, messages_list, ""
def clear_chat():
    """Resets the chat display, message history, uploaded image, and textbox."""
    return [], [], None, ""
# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")

    messages_list = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")
        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500,
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False,
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
    # --- Event Listeners ---
    user_textbox.submit(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list, user_textbox],
    )
    submit_btn.click(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list, user_textbox],
    )
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False,
    )
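# Note: mcp_server=True additionally exposes this app as an MCP server; this
# assumes a recent Gradio 5.x release with the gradio[mcp] extra installed.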
if __name__ == "__main__":
    demo.launch(mcp_server=True)