KingNish committed on
Commit c547944 · verified · 1 Parent(s): 86d52bb

Update app.py

Files changed (1)
  1. app.py +2 -5
app.py CHANGED
@@ -4,9 +4,6 @@ from PIL import Image
 import gradio as gr
 import spaces
 
-# Set the default dtype for tensors to float16
-torch.set_default_dtype(torch.float16)
-
 # --- 1. Model and Processor Setup ---
 model_id = "bharatgenai/patram-7b-instruct"
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -59,8 +56,8 @@ def process_chat(user_message, chatbot_display, messages_list, image_pil):
 inputs = processor.process(images=[image_pil], text=prompt)
 inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}
 
-# Ensure all tensors are in float16
-inputs = {k: v.half() for k, v in inputs.items()}
+# Ensure all tensors are in the same dtype
+inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
 
 # Generate output using model's specific method
 output = model.generate_from_batch(
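
The change above swaps the blanket float16 cast for a conditional one. As context, here is a minimal sketch of why the dtype check matters; the tensor names below are hypothetical stand-ins for the processor output, not taken from the Space. Processor outputs typically mix integer tensors (token ids) with float32 tensors (pixel values), and an unconditional .half() would also downcast the integer ids.

import torch

# Hypothetical stand-in for the processor output: a mix of integer and float tensors.
inputs = {
    "input_ids": torch.tensor([[1, 2, 3]]),    # int64 — must stay integral for the embedding lookup
    "images": torch.rand(1, 3, 224, 224),      # float32 — safe to downcast to float16
}

# Old behaviour: calling v.half() on everything would also turn input_ids into float16.
# New behaviour: only float32 tensors are downcast; every other dtype is left untouched.
inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}

for k, v in inputs.items():
    print(k, v.dtype)   # input_ids -> torch.int64, images -> torch.float16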