hb-setosys commited on
Commit
972246b
·
verified ·
1 Parent(s): 0339008

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -4,6 +4,8 @@ from tensorflow.keras.applications.resnet import ResNet152, preprocess_input, de
4
  from tensorflow.keras.preprocessing.image import img_to_array
5
  from PIL import Image
6
  import numpy as np
 
 
7
 
8
  # Load the pre-trained ResNet152 model
9
  MODEL_PATH = "resnet152-image-classifier.h5" # Path to the saved model
@@ -13,11 +15,25 @@ except Exception as e:
13
  print(f"Error loading model: {e}")
14
  exit()
15
 
 
 
 
 
 
 
 
 
 
 
16
  def predict_image(image):
17
  """
18
  Process the uploaded image and return the top 3 predictions.
19
  """
20
  try:
 
 
 
 
21
  # Preprocess the image
22
  image = image.resize((224, 224)) # ResNet152 expects 224x224 input
23
  image_array = img_to_array(image)
@@ -38,7 +54,7 @@ def predict_image(image):
38
  # Create the Gradio interface
39
  interface = gr.Interface(
40
  fn=predict_image,
41
- inputs=gr.Image(type="pil"), # Accepts an image input
42
  outputs=gr.Label(num_top_classes=3), # Shows top 3 predictions with confidence
43
  title="ResNet152 Image Classifier",
44
  description="Upload an image, and the model will predict what's in the image.",
@@ -47,4 +63,4 @@ interface = gr.Interface(
47
 
48
  # Launch the Gradio app
49
  if __name__ == "__main__":
50
- interface.launch()
 
4
  from tensorflow.keras.preprocessing.image import img_to_array
5
  from PIL import Image
6
  import numpy as np
7
+ import base64
8
+ from io import BytesIO
9
 
10
  # Load the pre-trained ResNet152 model
11
  MODEL_PATH = "resnet152-image-classifier.h5" # Path to the saved model
 
15
  print(f"Error loading model: {e}")
16
  exit()
17
 
18
+ def decode_image_from_base64(base64_str):
19
+ """
20
+ Decodes a base64 string to a PIL image.
21
+ """
22
+ # Decode the base64 string to bytes
23
+ image_data = base64.b64decode(base64_str)
24
+ # Convert the bytes into a PIL image
25
+ image = Image.open(BytesIO(image_data))
26
+ return image
27
+
28
  def predict_image(image):
29
  """
30
  Process the uploaded image and return the top 3 predictions.
31
  """
32
  try:
33
+ # If the image is base64 encoded, decode it
34
+ if isinstance(image, str):
35
+ image = decode_image_from_base64(image)
36
+
37
  # Preprocess the image
38
  image = image.resize((224, 224)) # ResNet152 expects 224x224 input
39
  image_array = img_to_array(image)
 
54
  # Create the Gradio interface
55
  interface = gr.Interface(
56
  fn=predict_image,
57
+ inputs=gr.Image(type="pil", tool="editor"), # Accepts an image input
58
  outputs=gr.Label(num_top_classes=3), # Shows top 3 predictions with confidence
59
  title="ResNet152 Image Classifier",
60
  description="Upload an image, and the model will predict what's in the image.",
 
63
 
64
  # Launch the Gradio app
65
  if __name__ == "__main__":
66
+ interface.launch()