import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
import spaces
import threading

# --- 1. Model and Processor Setup ---
model_id = "bharatgenai/patram-7b-instruct"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load processor and model
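# trust_remote_code=True pulls in the repository's own processor/model code, which
# provides the processor.process() and model.generate_from_batch() methods used below.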
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
print("Model and processor loaded successfully.")

# Default system prompt
DEFAULT_SYSTEM_PROMPT = """You are Patram, a helpful AI assistant created by BharatGenAI. You are designed to analyze images and answer questions about them.
Think step by step before providing your answers. Be detailed, accurate, and helpful in your responses.
You can understand both text and image inputs to provide comprehensive answers to user queries."""

# --- Define and apply a more flexible chat template ---
chat_template = """{% for message in messages %}
    {{ message['role'].capitalize() }}: {{ message['content'] }}
    {% if not loop.last %}{{ '\n' }}{% endif %}
{% endfor %}
{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"""
processor.tokenizer.chat_template = chat_template
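# With this template, a history such as
#   [{"role": "system", "content": "..."}, {"role": "user", "content": "What is in the image?"}]
# is rendered (with add_generation_prompt=True) roughly as:
#   System: ...
#
#   User: What is in the image?
#   Assistant: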

# --- 2. Gradio Chatbot Logic ---
@spaces.GPU
def generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
    """
    Generate a response from the model using streaming.
    """
    try:
        # Create a copy of the messages list to avoid modifying the original
        current_messages = messages_list.copy()
        current_messages.append({"role": "user", "content": user_message})

        print(f"Messages sent to the model: {current_messages}")

        # Use the processor to apply the chat template
        prompt = processor.tokenizer.apply_chat_template(
            current_messages,
            tokenize=False,
            add_generation_prompt=True
        )

        if image_pil is not None:
            # Preprocess the image together with the entire formatted prompt
            inputs = processor.process(images=[image_pil], text=prompt)
        else:
            # No image uploaded yet, so process the text prompt alone
            inputs = processor.process(text=prompt)
        # Add a batch dimension, move tensors to the model device, and cast any
        # float32 tensors to float16 to match the model weights
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
        inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

        # Initialize the streamer
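        # TextIteratorStreamer exposes the generated text as an iterator: the background
        # generation thread pushes decoded tokens into it while this function consumes
        # them and yields partial responses to Gradio.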
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Define generation config
        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            eos_token_id=processor.tokenizer.eos_token_id,
            pad_token_id=processor.tokenizer.pad_token_id
        )

        # Build the kwargs for the model's custom generate_from_batch method
        generate_kwargs = dict(
            batch=inputs,
            streamer=streamer,
            generation_config=generation_config
        )

        # Start the generation in a separate thread to allow streaming
        thread = threading.Thread(target=model.generate_from_batch, kwargs=generate_kwargs)
        thread.start()

        # Yield the generated tokens as they become available
        response = ""
        for new_token in streamer:
            response += new_token
            yield response

    except Exception as e:
        print(f"Error during inference: {e}")
        yield f"Sorry, an error occurred during processing: {e}"

def process_chat(user_message, chatbot_display, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
    """
    This function handles the chat logic for a single turn with streaming.
    """

    # Append user's message to the chatbot display list
    chatbot_display.append((user_message, ""))

    # Generate the response using streaming
    response = ""
    for chunk in generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
        response = chunk
        # Update the chatbot display with the current response
        chatbot_display[-1] = (user_message, response)
        yield chatbot_display, messages_list, ""

    # Persist both turns in the conversation history; generate_response only adds
    # the user message to a local copy, so it has to be recorded here as well
    messages_list.append({"role": "user", "content": user_message})
    messages_list.append({"role": "assistant", "content": response})
    yield chatbot_display, messages_list, ""

def clear_chat():
    """Resets the chat display, history, image, textbox, and generation parameters."""
    # messages_list is re-seeded with the system prompt, matching demo.load below
    return [], [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}], None, "", 256, 0.9, 50, 0.6, DEFAULT_SYSTEM_PROMPT

# --- 3. Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# πŸ€– Patram-7B-Instruct Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")

    # State variables to hold conversation history and image
    messages_list = gr.State([])
    image_input = gr.State(None)
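    # gr.State values are per-session, so each visitor keeps an independent
    # conversation history and uploaded image.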

    with gr.Row():
        with gr.Column(scale=1):
            image_input_render = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("πŸ—‘οΈ Clear Chat and Image")
            with gr.Accordion("Generation Parameters", open=False):
                system_prompt = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, interactive=True, lines=5)
                max_new_tokens = gr.Slider(minimum=32, maximum=4096, value=256, step=32, label="Max New Tokens")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (Nucleus Sampling)")
                top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.05, label="Temperature")
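                # These defaults (256 tokens, top-p 0.9, top-k 50, temperature 0.6)
                # match the values restored by clear_chat below.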

        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)

    # Initialize messages_list with system prompt
    demo.load(
        fn=lambda: [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}],
        inputs=None,
        outputs=messages_list
    )

    # Update messages_list when system prompt changes
    system_prompt.change(
        fn=lambda system_prompt: [{"role": "system", "content": system_prompt}],
        inputs=system_prompt,
        outputs=messages_list
    )

    # --- Event Listeners ---

    # Define the action for submitting a message (via button or enter key)
    user_textbox.submit(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature],
        outputs=[chatbot_display, messages_list, user_textbox]
    )
    submit_btn.click(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature],
        outputs=[chatbot_display, messages_list, user_textbox]
    )

    # Define the action for the clear button
    clear_btn.click(
        fn=clear_chat,
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input_render, user_textbox, max_new_tokens, top_p, top_k, temperature, system_prompt],
        queue=False
    )

    # Update the image state when a new image is uploaded
    image_input_render.change(
        fn=lambda x: x,
        inputs=image_input_render,
        outputs=image_input
    )

if __name__ == "__main__":
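    # mcp_server=True additionally exposes the app's functions over the Model Context
    # Protocol; this assumes a Gradio version recent enough to include MCP support.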
    demo.launch(mcp_server=True)