KingNish committed
Commit 698861e · verified · 1 Parent(s): c547944

Update app.py

Files changed (1): app.py (+84, -29)
app.py CHANGED
@@ -1,10 +1,12 @@
 import torch
-from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
+from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
 from PIL import Image
 import gradio as gr
 import spaces
+import threading
 
 # --- 1. Model and Processor Setup ---
+
 model_id = "bharatgenai/patram-7b-instruct"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
@@ -37,15 +39,15 @@ processor.tokenizer.chat_template = chat_template
 
 # --- 2. Gradio Chatbot Logic ---
 @spaces.GPU
-def process_chat(user_message, chatbot_display, messages_list, image_pil):
-    if image_pil is None:
-        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
-        return chatbot_display, messages_list, ""
-
-    messages_list.append({"role": "user", "content": user_message})
-    chatbot_display.append((user_message, None))
-
+def generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
+    """
+    Generate a response from the model using streaming.
+    """
     try:
+        # Append user's message to the conversation history for the model
+        messages_list.append({"role": "user", "content": user_message})
+
+        # Use the processor to apply the chat template
         prompt = processor.tokenizer.apply_chat_template(
             messages_list,
             tokenize=False,
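
Note on the `apply_chat_template` call both versions share: it renders the accumulated `messages_list` into the single prompt string the model expects, using the `chat_template` the app assigns to `processor.tokenizer` earlier in the file (the hunk is cut off before the call's remaining arguments, and the template string itself is not shown in this diff). A minimal self-contained sketch of what the call does, with a hypothetical template:

from transformers import AutoTokenizer

# Hypothetical stand-in for processor.tokenizer; the app sets its own
# chat_template string, which this diff does not show.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = (
    "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}"
    "{% if add_generation_prompt %}assistant:{% endif %}"
)

messages = [{"role": "user", "content": "What is shown in this image?"}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,             # return a string instead of token IDs
    add_generation_prompt=True, # append the assistant header so the model answers next
)
print(prompt)
# user: What is shown in this image?
# assistant:
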
@@ -55,44 +57,87 @@ def process_chat(user_message, chatbot_display, messages_list, image_pil):
         # Preprocess image and the entire formatted prompt
         inputs = processor.process(images=[image_pil], text=prompt)
         inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
-
-        # Ensure all tensors are in the same dtype
         inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
 
+        # Initialize the streamer
+        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        # Define generation config
+        generation_config = GenerationConfig(
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            stop_strings="<|endoftext|>"
+        )
+
         # Generate output using model's specific method
-        output = model.generate_from_batch(
-            inputs,
-            GenerationConfig(max_new_tokens=512, do_sample=True, top_p=0.9, temperature=0.6, stop_strings="<|endoftext|>"),
+        generate_kwargs = dict(
+            **inputs,
+            streamer=streamer,
+            generation_config=generation_config,
             tokenizer=processor.tokenizer
         )
 
-        generated_tokens = output[0, inputs['input_ids'].size(1):]
-        response = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
+        # Start the generation in a separate thread to allow streaming
+        thread = threading.Thread(target=model.generate_from_batch, kwargs=generate_kwargs)
+        thread.start()
 
-        messages_list.append({"role": "assistant", "content": response})
-        chatbot_display[-1] = (user_message, response)
+        # Yield the generated tokens as they become available
+        for new_token in streamer:
+            yield new_token
 
     except Exception as e:
         print(f"Error during inference: {e}")
-        error_message = f"Sorry, an error occurred during processing: {e}"
-        chatbot_display[-1] = (user_message, error_message)
+        yield f"Sorry, an error occurred during processing: {e}"
 
-    return chatbot_display, messages_list, ""
+def process_chat(user_message, chatbot_display, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
+    """
+    This function handles the chat logic for a single turn with streaming.
+    """
+    if image_pil is None:
+        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
+        return chatbot_display, messages_list, ""
+
+    # Append user's message to the chatbot display list
+    chatbot_display.append((user_message, ""))
+
+    # Initialize the response as an empty string
+    response = ""
+
+    # Generate the response using streaming
+    for chunk in generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
+        response += chunk
+        # Update the chatbot display with the current response
+        chatbot_display[-1] = (user_message, response)
+        yield chatbot_display, messages_list, ""
 
-def clear_chat(chatbot_display, messages_list, image_input):
+    # Append assistant's response to the conversation history
+    messages_list.append({"role": "assistant", "content": response})
+
+def clear_chat():
     """Resets the chat, history, and image."""
-    return [], [], None, "Type your question here..."
+    return [], [], None, "", 256, 0.9, 50, 0.6
 
 # --- 3. Gradio Interface Definition ---
 with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
     gr.Markdown("# 🤖 Patram-7B-Instruct Chatbot")
     gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")
 
+    # State variables to hold conversation history and image
     messages_list = gr.State([])
+    image_input = gr.State(None)
+
     with gr.Row():
         with gr.Column(scale=1):
-            image_input = gr.Image(type="pil", label="Upload Image")
+            image_input_render = gr.Image(type="pil", label="Upload Image")
             clear_btn = gr.Button("🗑️ Clear Chat and Image")
+            with gr.Accordion("Generation Parameters", open=False):
+                max_new_tokens = gr.Slider(minimum=32, maximum=512, value=256, step=32, label="Max New Tokens")
+                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (Nucleus Sampling)")
+                top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
+                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.1, label="Temperature")
 
         with gr.Column(scale=2):
             chatbot_display = gr.Chatbot(
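
The hunk above is the core of the commit: the single blocking `generate_from_batch` call is replaced by the standard transformers streaming pattern, in which generation runs on a worker thread and a `TextIteratorStreamer` hands decoded text chunks back to the consuming loop as they are produced. A minimal self-contained sketch of that pattern with a generic causal LM (plain `model.generate` here; Patram's `generate_from_batch` wrapper is specific to this model):

import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Streaming generation works by", return_tensors="pt")
# skip_prompt=True suppresses the echoed input; skip_special_tokens cleans the text
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until it finishes, so it runs on a background thread while
# the main thread drains the streamer's queue chunk by chunk
thread = threading.Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=40),
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
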
@@ -110,23 +155,33 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutra
110
  submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
111
 
112
  # --- Event Listeners ---
 
 
113
  submit_action = user_textbox.submit(
114
  fn=process_chat,
115
- inputs=[user_textbox, chatbot_display, messages_list, image_input],
116
  outputs=[chatbot_display, messages_list, user_textbox]
117
  )
118
  submit_btn.click(
119
  fn=process_chat,
120
- inputs=[user_textbox, chatbot_display, messages_list, image_input],
121
  outputs=[chatbot_display, messages_list, user_textbox]
122
  )
123
 
 
124
  clear_btn.click(
125
- fn=lambda: ([], [], None, ""),
126
  inputs=[],
127
- outputs=[chatbot_display, messages_list, image_input, user_textbox],
128
  queue=False
129
  )
130
 
 
 
 
 
 
 
 
131
  if __name__ == "__main__":
132
- demo.launch(mcp_server=True)
 
1
  import torch
2
+ from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
3
  from PIL import Image
4
  import gradio as gr
5
  import spaces
6
+ import threading
7
 
8
  # --- 1. Model and Processor Setup ---
9
+
10
  model_id = "bharatgenai/patram-7b-instruct"
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  print(f"Using device: {device}")
 
39
 
40
  # --- 2. Gradio Chatbot Logic ---
41
  @spaces.GPU
42
+ def generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
43
+ """
44
+ Generate a response from the model using streaming.
45
+ """
 
 
 
 
46
  try:
47
+ # Append user's message to the conversation history for the model
48
+ messages_list.append({"role": "user", "content": user_message})
49
+
50
+ # Use the processor to apply the chat template
51
  prompt = processor.tokenizer.apply_chat_template(
52
  messages_list,
53
  tokenize=False,
 
57
  # Preprocess image and the entire formatted prompt
58
  inputs = processor.process(images=[image_pil], text=prompt)
59
  inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
 
 
60
  inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
61
 
62
+ # Initialize the streamer
63
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
64
+
65
+ # Define generation config
66
+ generation_config = GenerationConfig(
67
+ max_new_tokens=max_new_tokens,
68
+ do_sample=True,
69
+ top_p=top_p,
70
+ top_k=top_k,
71
+ temperature=temperature,
72
+ stop_strings="<|endoftext|>"
73
+ )
74
+
75
  # Generate output using model's specific method
76
+ generate_kwargs = dict(
77
+ **inputs,
78
+ streamer=streamer,
79
+ generation_config=generation_config,
80
  tokenizer=processor.tokenizer
81
  )
82
 
83
+ # Start the generation in a separate thread to allow streaming
84
+ thread = threading.Thread(target=model.generate_from_batch, kwargs=generate_kwargs)
85
+ thread.start()
86
 
87
+ # Yield the generated tokens as they become available
88
+ for new_token in streamer:
89
+ yield new_token
90
 
91
  except Exception as e:
92
  print(f"Error during inference: {e}")
93
+ yield f"Sorry, an error occurred during processing: {e}"
 
94
 
95
+ def process_chat(user_message, chatbot_display, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
96
+ """
97
+ This function handles the chat logic for a single turn with streaming.
98
+ """
99
+ if image_pil is None:
100
+ chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
101
+ return chatbot_display, messages_list, ""
102
+
103
+ # Append user's message to the chatbot display list
104
+ chatbot_display.append((user_message, ""))
105
+
106
+ # Initialize the response as an empty string
107
+ response = ""
108
+
109
+ # Generate the response using streaming
110
+ for chunk in generate_response(user_message, messages_list, image_pil, max_new_tokens, top_p, top_k, temperature):
111
+ response += chunk
112
+ # Update the chatbot display with the current response
113
+ chatbot_display[-1] = (user_message, response)
114
+ yield chatbot_display, messages_list, ""
115
 
116
+ # Append assistant's response to the conversation history
117
+ messages_list.append({"role": "assistant", "content": response})
118
+
119
+ def clear_chat():
120
  """Resets the chat, history, and image."""
121
+ return [], [], None, "", 256, 0.9, 50, 0.6
122
 
123
  # --- 3. Gradio Interface Definition ---
124
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
125
  gr.Markdown("# 🤖 Patram-7B-Instruct Chatbot")
126
  gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")
127
 
128
+ # State variables to hold conversation history and image
129
  messages_list = gr.State([])
130
+ image_input = gr.State(None)
131
+
132
  with gr.Row():
133
  with gr.Column(scale=1):
134
+ image_input_render = gr.Image(type="pil", label="Upload Image")
135
  clear_btn = gr.Button("🗑️ Clear Chat and Image")
136
+ with gr.Accordion("Generation Parameters", open=False):
137
+ max_new_tokens = gr.Slider(minimum=32, maximum=512, value=256, step=32, label="Max New Tokens")
138
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (Nucleus Sampling)")
139
+ top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
140
+ temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.6, step=0.1, label="Temperature")
141
 
142
  with gr.Column(scale=2):
143
  chatbot_display = gr.Chatbot(
 
155
  submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
156
 
157
  # --- Event Listeners ---
158
+
159
+ # Define the action for submitting a message (via button or enter key)
160
  submit_action = user_textbox.submit(
161
  fn=process_chat,
162
+ inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature],
163
  outputs=[chatbot_display, messages_list, user_textbox]
164
  )
165
  submit_btn.click(
166
  fn=process_chat,
167
+ inputs=[user_textbox, chatbot_display, messages_list, image_input, max_new_tokens, top_p, top_k, temperature],
168
  outputs=[chatbot_display, messages_list, user_textbox]
169
  )
170
 
171
+ # Define the action for the clear button
172
  clear_btn.click(
173
+ fn=clear_chat,
174
  inputs=[],
175
+ outputs=[chatbot_display, messages_list, image_input_render, user_textbox, max_new_tokens, top_p, top_k, temperature],
176
  queue=False
177
  )
178
 
179
+ # Update the image state when a new image is uploaded
180
+ image_input_render.change(
181
+ fn=lambda x: x,
182
+ inputs=image_input_render,
183
+ outputs=image_input
184
+ )
185
+
186
  if __name__ == "__main__":
187
+ demo.launch()
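
On the Gradio side, streaming works because `process_chat` is now a generator: Gradio treats each `yield` from an event handler as an incremental update to the output components, which is what makes the chatbot text grow in place in the browser. A toy sketch of that mechanism, independent of any model (component and function names here are illustrative, not the app's):

import time
import gradio as gr

def stream_reply(message, history):
    # Add the new turn with an empty reply, then grow it yield by yield
    history = history + [(message, "")]
    reply = ""
    for word in "each yield repaints the chatbot".split():
        reply += word + " "
        history[-1] = (message, reply)
        time.sleep(0.2)    # stand-in for per-token latency
        yield history, ""  # update the Chatbot, keep the textbox cleared

with gr.Blocks() as demo:
    chat = gr.Chatbot()
    box = gr.Textbox(placeholder="Type your question here...")
    box.submit(stream_reply, inputs=[box, chat], outputs=[chat, box])

if __name__ == "__main__":
    demo.launch()

The same event-wiring idea drives the final hunk's `image_input_render.change` listener, which simply mirrors the uploaded image into a `gr.State` so `process_chat` reads it from per-session state rather than from the rendered component.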