import torch from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer from PIL import Image import gradio as gr import spaces import threading # --- 1. Model and Processor Setup --- model_id = "bharatgenai/patram-7b-instruct" device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load processor and model processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True ) print("Model and processor loaded successfully.") # Default system prompt DEFAULT_SYSTEM_PROMPT = """You are Patram, a helpful AI assistant created by BharatGenAI. You are designed to analyze images and answer questions about them. Think step by step before providing your answers. Be detailed, accurate, and helpful in your responses. You can understand both text and image inputs to provide comprehensive answers to user queries.""" # --- Define and apply a more flexible chat template --- chat_template = """{% for message in messages %} {{ message['role'].capitalize() }}: {{ message['content'] }} {% if not loop.last %}{{ '\n' }}{% endif %} {% endfor %} {% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}""" processor.tokenizer.chat_template = chat_template # --- 2. Gradio Chatbot Logic --- @spaces.GPU def generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature): """ Generate a response from the model using streaming. """ try: # Create a copy of the messages list to avoid modifying the original current_messages = messages_list.copy() current_messages.append({"role": "user", "content": user_message}) print(current_messages) # Use the processor to apply the chat template prompt = processor.tokenizer.apply_chat_template( current_messages, tokenize=False, add_generation_prompt=True ) if image_pil: # Preprocess image and the entire formatted prompt inputs = processor.process(images=[image_pil], text=prompt) else: inputs = processor.process(text=prompt) inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()} inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()} # Initialize the streamer streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True) # Define generation config generation_config = GenerationConfig( max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p, top_k=top_k, temperature=temperature, eos_token_id=processor.tokenizer.eos_token_id, pad_token_id=processor.tokenizer.pad_token_id ) # Generate output using model's specific method generate_kwargs = dict( batch=inputs, streamer=streamer, generation_config=generation_config ) # Start the generation in a separate thread to allow streaming thread = threading.Thread(target=model.generate_from_batch, kwargs=generate_kwargs) thread.start() # Yield the generated tokens as they become available response = "" for new_token in streamer: response += new_token yield response except Exception as e: print(f"Error during inference: {e}") yield f"Sorry, an error occurred during processing: {e}" def process_chat(user_message, chatbot_display, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature): """ This function handles the chat logic for a single turn with streaming. """ # Append user's message to the chatbot display list chatbot_display.append((user_message, "")) # Generate the response using streaming response = "" for chunk in generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature): response = chunk # Update the chatbot display with the current response chatbot_display[-1] = (user_message, response) yield chatbot_display, messages_list, "" # Append assistant's response to the conversation history messages_list.append({"role": "assistant", "content": response}) def clear_chat(): """Resets the chat, history, and image.""" return [], [], None, "", 256, 0.9, 50, 0.6, DEFAULT_SYSTEM_PROMPT # --- 3. Gradio Interface Definition --- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo: gr.Markdown("# 🤖 Patram-7B-Instruct Chatbot") gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.") # State variables to hold conversation history and image messages_list = gr.State([]) image_input = gr.State(None) with gr.Row(): with gr.Column(scale=1): image_input_render = gr.Image(type="pil", label="Upload Image") clear_btn = gr.Button("🗑️ Clear Chat and Image") with gr.Accordion("Generation Parameters", open=False): system_prompt = gr.Textbox( label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, interactive=True, lines=5) max_new_tokens = gr.Slider(minimum=32, maximum=4096, value=256, step=32, label="Max New Tokens") top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (Nucleus Sampling)") top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k") temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.05, label="Temperature") with gr.Column(scale=2): chatbot_display = gr.Chatbot( label="Conversation", bubble_full_width=False, height=500 ) with gr.Row(): user_textbox = gr.Textbox( placeholder="Type your question here...", show_label=False, scale=4, container=False ) submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0) # Initialize messages_list with system prompt demo.load( fn=lambda: [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}], inputs=None, outputs=messages_list ) # Update messages_list when system prompt changes system_prompt.change( fn=lambda system_prompt: [{"role": "system", "content": system_prompt}], inputs=system_prompt, outputs=messages_list ) # --- Event Listeners --- # Define the action for submitting a message (via button or enter key) submit_action = user_textbox.submit( fn=process_chat, inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature], outputs=[chatbot_display, messages_list, user_textbox] ) submit_btn.click( fn=process_chat, inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature], outputs=[chatbot_display, messages_list, user_textbox] ) # Define the action for the clear button clear_btn.click( fn=clear_chat, inputs=[], outputs=[chatbot_display, messages_list, image_input_render, user_textbox, max_new_tokens, top_p, top_k, temperature, system_prompt], queue=False ) # Update the image state when a new image is uploaded image_input_render.change( fn=lambda x: x, inputs=image_input_render, outputs=image_input ) if __name__ == "__main__": demo.launch(mcp_server=True)