# Hugging Face Space: "Running on Zero" (ZeroGPU) demo for Patram-7B-Instruct.
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
import spaces
import threading
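# `spaces` provides the @spaces.GPU decorator used below; on ZeroGPU Spaces it
# requests a GPU for the duration of each decorated call and releases it after.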
# --- 1. Model and Processor Setup ---
model_id = "bharatgenai/patram-7b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load processor and model
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
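# Note: device_map="auto" relies on the `accelerate` package to place weights;
# it is assumed here that accelerate is available in the Space's environment.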
print("Model and processor loaded successfully.")
# Default system prompt
DEFAULT_SYSTEM_PROMPT = """You are Patram, a helpful AI assistant created by BharatGenAI. You are designed to analyze images and answer questions about them.
Think step by step before providing your answers. Be detailed, accurate, and helpful in your responses.
You can understand both text and image inputs to provide comprehensive answers to user queries."""
# --- Define and apply a more flexible chat template ---
chat_template = """{% for message in messages %}
{{ message['role'].capitalize() }}: {{ message['content'] }}
{% if not loop.last %}{{ '\n' }}{% endif %}
{% endfor %}
{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"""
processor.tokenizer.chat_template = chat_template
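# transformers renders chat templates with trim_blocks/lstrip_blocks enabled,
# so for a system message plus one user turn the template above produces
# roughly:
#
#   System: You are Patram, ...
#
#   User: What is in this image?
#   Assistant:
#
# The trailing "Assistant:" comes from add_generation_prompt=True and cues the
# model to answer in the assistant role.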
# --- 2. Gradio Chatbot Logic ---
@spaces.GPU
def generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
    """Stream a response from the model, yielding the text generated so far."""
    try:
        # Work on a copy so the caller's history is not modified here
        current_messages = messages_list.copy()
        current_messages.append({"role": "user", "content": user_message})
        print(current_messages)  # debug: full conversation sent to the model
        # Apply the chat template to build the full prompt string
        prompt = processor.tokenizer.apply_chat_template(
            current_messages,
            tokenize=False,
            add_generation_prompt=True
        )
        if image_pil is not None:
            # Preprocess the image together with the fully formatted prompt
            inputs = processor.process(images=[image_pil], text=prompt)
        else:
            inputs = processor.process(text=prompt)
        # Move tensors to the model's device and add a batch dimension
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
        # Cast float32 tensors to float16 to match the model weights
        inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
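        # Assumption: for this Molmo-style processor, `inputs` holds tensors such
        # as input_ids plus image tensors when an image is supplied; the exact
        # key names depend on the model's remote code.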
        # The streamer yields decoded text pieces as generation progresses
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
        # Sampling configuration driven by the UI sliders
        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            eos_token_id=processor.tokenizer.eos_token_id,
            pad_token_id=processor.tokenizer.pad_token_id
        )
        # Patram exposes generate_from_batch (Molmo-style) rather than generate
        generate_kwargs = dict(
            batch=inputs,
            streamer=streamer,
            generation_config=generation_config
        )
        # Run generation in a background thread so tokens can be streamed here
        thread = threading.Thread(target=model.generate_from_batch, kwargs=generate_kwargs)
        thread.start()
        # Yield the accumulated response as new tokens arrive
        response = ""
        for new_token in streamer:
            response += new_token
            yield response
    except Exception as e:
        print(f"Error during inference: {e}")
        yield f"Sorry, an error occurred during processing: {e}"
def process_chat(user_message, chatbot_display, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
    """Handle one chat turn, streaming partial responses to the UI."""
    # Append the user's message to the chatbot display list
    chatbot_display.append((user_message, ""))
    # Stream the response, updating the display as each chunk arrives
    response = ""
    for chunk in generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
        response = chunk
        chatbot_display[-1] = (user_message, response)
        yield chatbot_display, messages_list, ""
    # Persist both sides of the turn in the conversation history
    messages_list.append({"role": "user", "content": user_message})
    messages_list.append({"role": "assistant", "content": response})
    yield chatbot_display, messages_list, ""
def clear_chat():
    """Reset the chat display, history, image, textbox, and generation parameters."""
    return [], [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}], None, "", 256, 0.9, 50, 0.6, DEFAULT_SYSTEM_PROMPT
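# The return values above correspond one-to-one with the `outputs` list wired
# to clear_btn.click below: chatbot_display, messages_list, image_input_render,
# user_textbox, max_new_tokens, top_p, top_k, temperature, system_prompt.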
# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")
    # State variables holding the per-session conversation history and image
    messages_list = gr.State([])
    image_input = gr.State(None)
    with gr.Row():
        with gr.Column(scale=1):
            image_input_render = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")
            with gr.Accordion("Generation Parameters", open=False):
                system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, interactive=True, lines=5)
                max_new_tokens = gr.Slider(minimum=32, maximum=4096, value=256, step=32, label="Max New Tokens")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (Nucleus Sampling)")
                top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.05, label="Temperature")
        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
    # Initialize messages_list with the default system prompt on page load
    demo.load(
        fn=lambda: [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}],
        inputs=None,
        outputs=messages_list
    )
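    # demo.load runs once per browser session, so each visitor starts with a
    # fresh history containing only the system message.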
    # Replace the stored history whenever the system prompt is edited
    system_prompt.change(
        fn=lambda system_prompt: [{"role": "system", "content": system_prompt}],
        inputs=system_prompt,
        outputs=messages_list
    )
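    # Note: editing the system prompt discards any earlier turns, since the
    # history is reset to just the new system message.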
    # --- Event Listeners ---
    # Submitting a message (Enter key or Send button) runs the same handler
    user_textbox.submit(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature],
        outputs=[chatbot_display, messages_list, user_textbox]
    )
    submit_btn.click(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature],
        outputs=[chatbot_display, messages_list, user_textbox]
    )
    # Clear button resets the display, history, image, textbox, and parameters
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input_render, user_textbox, max_new_tokens, top_p, top_k, temperature, system_prompt],
        queue=False
    )
    # Mirror uploaded images into the image state used at inference time
    image_input_render.change(
        fn=lambda x: x,
        inputs=image_input_render,
        outputs=image_input
    )
if __name__ == "__main__":
    demo.launch(mcp_server=True)
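# mcp_server=True additionally exposes the app's endpoints as MCP tools; this
# assumes a Gradio version with MCP support (e.g. `pip install "gradio[mcp]"`).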