Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -4,9 +4,6 @@ from PIL import Image
|
|
4 |
import gradio as gr
|
5 |
import spaces
|
6 |
|
7 |
-
# Set the default dtype for tensors to float16
|
8 |
-
torch.set_default_dtype(torch.float16)
|
9 |
-
|
10 |
# --- 1. Model and Processor Setup ---
|
11 |
model_id = "bharatgenai/patram-7b-instruct"
|
12 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
@@ -59,8 +56,8 @@ def process_chat(user_message, chatbot_display, messages_list, image_pil):
|
|
59 |
inputs = processor.process(images=[image_pil], text=prompt)
|
60 |
inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
|
61 |
|
62 |
-
# Ensure all tensors are in
|
63 |
-
inputs = {k: v.half() for k, v in inputs.items()}
|
64 |
|
65 |
# Generate output using model's specific method
|
66 |
output = model.generate_from_batch(
|
|
|
4 |
import gradio as gr
|
5 |
import spaces
|
6 |
|
|
|
|
|
|
|
7 |
# --- 1. Model and Processor Setup ---
|
8 |
model_id = "bharatgenai/patram-7b-instruct"
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
56 |
inputs = processor.process(images=[image_pil], text=prompt)
|
57 |
inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
|
58 |
|
59 |
+
# Ensure all tensors are in the same dtype
|
60 |
+
inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
|
61 |
|
62 |
# Generate output using model's specific method
|
63 |
output = model.generate_from_batch(
|