KingNish committed
Commit 6493390 · verified · 1 Parent(s): 20b047b

Update app.py

Files changed (1)
  1. app.py +59 -61
app.py CHANGED
@@ -1,8 +1,7 @@
  import torch
- from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
+ from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig
  from PIL import Image
  import gradio as gr
- from threading import Thread
  import spaces
  
  # --- 1. Model and Processor Setup ---
@@ -38,22 +37,32 @@ chat_template = """{% for message in messages -%}
  {%- endif %}"""
  processor.tokenizer.chat_template = chat_template
  
- # --- 2. Gradio Chatbot Logic with Streaming ---
+ # --- 2. Gradio Chatbot Logic ---
  @spaces.GPU
- def process_chat_streaming(user_message, chatbot_display, messages_list, image_pil):
+ def process_chat(user_message, chatbot_display, messages_list, image_pil):
      """
-     This generator function handles the chat logic with streaming.
-     It yields the updated chatbot display at each step.
+     This function handles the chat logic for a single turn.
+ 
+     Args:
+         user_message (str): The new message from the user.
+         chatbot_display (list): The current state of the Gradio chatbot display.
+         messages_list (list): The conversation history in the format for the model.
+         image_pil (PIL.Image): The uploaded image.
+ 
+     Returns:
+         tuple: Updated chatbot_display, updated messages_list, and an empty string for the textbox.
      """
      # Check if an image has been uploaded
      if image_pil is None:
+         # Update the chatbot display with an error message
          chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
-         yield chatbot_display, messages_list
-         return  # Stop the generator
+         return chatbot_display, messages_list, ""  # Clear the input box
  
-     # Append user's message to the conversation history and display
+     # Append user's message to the conversation history for the model
      messages_list.append({"role": "user", "content": user_message})
-     chatbot_display.append((user_message, ""))  # Add an empty spot for the streaming response
+ 
+     # Append user's message to the chatbot display list
+     chatbot_display.append((user_message, None))
  
      try:
          # Use the processor to apply the chat template
@@ -64,61 +73,52 @@ def process_chat_streaming(user_message, chatbot_display, messages_list, image_p
          )
  
          # Preprocess image and the entire formatted prompt
+         # Patram expects a single image and the full text prompt
          inputs = processor.process(images=[image_pil], text=prompt)
          inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
  
-         # Setup the streamer
-         streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
- 
-         # Define generation configuration
-         generation_config = GenerationConfig(
-             max_new_tokens=512,
-             do_sample=True,
-             top_p=0.9,
-             temperature=0.6,
-             stop_strings=["<|endoftext|>", "User:"]  # Add stop strings to prevent over-generation
+         # Generate output using model's specific method
+         output = model.generate_from_batch(
+             inputs,
+             GenerationConfig(max_new_tokens=512, do_sample=True, top_p=0.9, temperature=0.6, stop_strings="<|endoftext|>"),
+             tokenizer=processor.tokenizer
          )
  
-         # *** THE FIX IS HERE ***
-         # We must pass 'inputs' as a positional argument for 'batch'
-         # and the rest as keyword arguments.
-         thread = Thread(
-             target=model.generate_from_batch,
-             args=[inputs],  # Pass `inputs` as the first positional argument ('batch')
-             kwargs={  # Pass the rest as keyword arguments
-                 "generation_config": generation_config,
-                 "tokenizer": processor.tokenizer,
-                 "streamer": streamer,
-             }
-         )
-         thread.start()
+         # Extract generated tokens (excluding input tokens) and decode
+         generated_tokens = output[0, inputs['input_ids'].size(1):]
+         response = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
  
-         # Yield updates to the Gradio UI
-         full_response = ""
-         for new_text in streamer:
-             full_response += new_text
-             chatbot_display[-1] = (user_message, full_response)
-             yield chatbot_display, messages_list
+         # Append assistant's response to the conversation history
+         messages_list.append({"role": "assistant", "content": response})
  
-         # After the loop, the generation is complete.
-         # Add the final full response to the messages list for context.
-         messages_list.append({"role": "assistant", "content": full_response})
-         yield chatbot_display, messages_list  # Yield the final state
+         # Update the chatbot display with the assistant's response
+         chatbot_display[-1] = (user_message, response)
  
      except Exception as e:
-         print(f"Error during streaming inference: {e}")
-         error_message = f"Sorry, an error occurred: {e}"
+         print(f"Error during inference: {e}")
+         error_message = f"Sorry, an error occurred during processing: {e}"
+         # Update the last message in the chatbot display with the error
          chatbot_display[-1] = (user_message, error_message)
-         yield chatbot_display, messages_list
+ 
+     # Return the updated state and clear the input textbox
+     return chatbot_display, messages_list, ""
+ 
+ 
+ def clear_chat(chatbot_display, messages_list, image_input):
+     """Resets the chat, history, and image."""
+     return [], [], None, "Type your question here..."
+ 
  
  # --- 3. Gradio Interface Definition ---
  
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
-     gr.Markdown("# 🤖 Patram-7B-Instruct Streaming Chatbot")
-     gr.Markdown("Upload an image and ask questions about it. The response will stream in real-time.")
+     gr.Markdown("# 🤖 Patram-7B-Instruct Chatbot")
+     gr.Markdown("Upload an image and ask questions about it. The chatbot will remember the conversation context.")
  
-     # State variables to hold conversation history
+     # State variables to hold conversation history and image
      messages_list = gr.State([])
+     # We don't need a state for chatbot_display as it's passed as an input/output directly
+     # The image is also passed directly from the gr.Image component
  
      with gr.Row():
          with gr.Column(scale=1):
@@ -129,8 +129,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutra
              chatbot_display = gr.Chatbot(
                  label="Conversation",
                  bubble_full_width=False,
-                 height=500,
-                 avatar_images=(None, "https://cdn-avatars.huggingface.co/v1/production/uploads/67b462a1f4f414c2b3e2bc2f/EnVeNWEIeZ6yF6ueZ7E3Y.jpeg")
+                 height=500
              )
              with gr.Row():
                  user_textbox = gr.Textbox(
@@ -139,22 +138,21 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutra
                      scale=4,
                      container=False
                  )
+                 submit_btn = gr.Button("Send", variant="primary", scale=1, min_width=0)
+ 
  
      # --- Event Listeners ---
  
-     # Define the action for submitting a message (via enter key)
+     # Define the action for submitting a message (via button or enter key)
      submit_action = user_textbox.submit(
-         fn=process_chat_streaming,
+         fn=process_chat,
          inputs=[user_textbox, chatbot_display, messages_list, image_input],
-         outputs=[chatbot_display, messages_list],
+         outputs=[chatbot_display, messages_list, user_textbox]
      )
- 
-     # Chain the action to also clear the textbox after submission
-     submit_action.then(
-         fn=lambda: gr.update(value=""),
-         inputs=None,
-         outputs=[user_textbox],
-         queue=False
+     submit_btn.click(
+         fn=process_chat,
+         inputs=[user_textbox, chatbot_display, messages_list, image_input],
+         outputs=[chatbot_display, messages_list, user_textbox]
      )
  
      # Define the action for the clear button
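
For reference, the single-turn generation pattern this version of app.py settles on can be exercised outside Gradio. Below is a minimal sketch of that flow; the checkpoint id, image path, and prompt string are illustrative placeholders rather than values taken from this commit, while `processor.process` and `model.generate_from_batch` are supplied by the model's `trust_remote_code` implementation, exactly as called in the diff above:

```python
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig

# Placeholder checkpoint id -- substitute the actual Patram-7B-Instruct repo.
MODEL_ID = "bharatgenai/patram-7b-instruct"  # assumption, not from this commit

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, trust_remote_code=True, torch_dtype="auto"
).to(device)

image = Image.open("document.jpg")  # placeholder path: any RGB document image
prompt = "User: What is this document about? Assistant:"  # simplified prompt

# processor.process() and generate_from_batch() come from the model's remote code.
inputs = processor.process(images=[image], text=prompt)
inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}

output = model.generate_from_batch(
    inputs,
    GenerationConfig(max_new_tokens=512, do_sample=True, top_p=0.9,
                     temperature=0.6, stop_strings="<|endoftext|>"),
    tokenizer=processor.tokenizer,  # required so stop_strings can be matched
)

# The output tensor still contains the prompt tokens; slice them off before decoding.
generated_tokens = output[0, inputs["input_ids"].size(1):]
print(processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip())
```

Because the handler now returns one full decoded string instead of yielding partial chunks, `process_chat` can be a plain function rather than a generator, and Gradio repaints the chat once per turn.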
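
On the UI side, both the Enter key and the new Send button route through the same handler, whose third return value clears the textbox (replacing the old `.then()` chain). A compressed sketch of that wiring, with a stub standing in for the model call; component labels and placeholder strings are illustrative:

```python
import gradio as gr

def process_chat(user_message, chatbot_display, messages_list, image_pil):
    """Stub standing in for the model call in app.py above."""
    reply = ("(model reply would go here)" if image_pil is not None
             else "Please upload an image first to start the conversation.")
    chatbot_display = chatbot_display + [(user_message, reply)]
    return chatbot_display, messages_list, ""  # third value clears the textbox

with gr.Blocks() as demo:
    messages_list = gr.State([])
    image_input = gr.Image(type="pil", label="Document image")
    chatbot_display = gr.Chatbot(label="Conversation")
    user_textbox = gr.Textbox(placeholder="Type your question here...")
    submit_btn = gr.Button("Send", variant="primary")

    # Same fn/inputs/outputs for both triggers, mirroring the diff.
    wiring = dict(
        fn=process_chat,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list, user_textbox],
    )
    user_textbox.submit(**wiring)  # Enter key
    submit_btn.click(**wiring)     # Send button

demo.launch()
```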