StoneSeller commited on
Commit
6cc530c
·
verified ·
1 Parent(s): d16040b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -71,11 +71,9 @@ except Exception as e:
71
  print(f"Error loading model: {str(e)}")
72
  traceback.print_exc()
73
 
74
- # Define image transformation pipeline
75
  transform = transforms.Compose([
76
  transforms.Resize((128, 128)),
77
- transforms.PILToTensor(), # Changed from ToTensor()
78
- transforms.ConvertImageDtype(torch.float32),
79
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
80
  ])
81
 
@@ -86,14 +84,21 @@ def process_image(image):
86
  try:
87
  # Convert numpy array to PIL Image
88
  if isinstance(image, np.ndarray):
89
- image = Image.fromarray(image.astype('uint8'))
 
 
 
90
 
91
  # Convert to RGB if necessary
92
  if image.mode != 'RGB':
93
  image = image.convert('RGB')
94
 
 
 
 
95
  print(f"Processed image size: {image.size}")
96
  print(f"Processed image mode: {image.mode}")
 
97
 
98
  return image
99
  except Exception as e:
@@ -111,10 +116,15 @@ def predict(image):
111
  if processed_image is None:
112
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
113
 
114
- # Transform image to tensor using torchvision transforms
115
  try:
116
- tensor_image = transform(processed_image).unsqueeze(0)
 
 
 
117
  print(f"Input tensor shape: {tensor_image.shape}")
 
 
118
  except Exception as e:
119
  print(f"Error in tensor conversion: {str(e)}")
120
  traceback.print_exc()
@@ -142,7 +152,7 @@ def predict(image):
142
  # Gradio interface
143
  interface = gr.Interface(
144
  fn=predict,
145
- inputs=gr.Image(),
146
  outputs=gr.Label(num_top_classes=3),
147
  title="Mechanical Tools Classifier",
148
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",
 
71
  print(f"Error loading model: {str(e)}")
72
  traceback.print_exc()
73
 
 
74
  transform = transforms.Compose([
75
  transforms.Resize((128, 128)),
76
+ transforms.ToTensor(), # PILToTensor 대신 ToTensor 사용
 
77
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
78
  ])
79
 
 
84
  try:
85
  # Convert numpy array to PIL Image
86
  if isinstance(image, np.ndarray):
87
+ # Ensure the array is uint8
88
+ if image.dtype != np.uint8:
89
+ image = (image * 255).astype(np.uint8)
90
+ image = Image.fromarray(image)
91
 
92
  # Convert to RGB if necessary
93
  if image.mode != 'RGB':
94
  image = image.convert('RGB')
95
 
96
+ # Resize the image
97
+ image = image.resize((128, 128), Image.Resampling.LANCZOS)
98
+
99
  print(f"Processed image size: {image.size}")
100
  print(f"Processed image mode: {image.mode}")
101
+ print(f"Image type: {type(image)}")
102
 
103
  return image
104
  except Exception as e:
 
116
  if processed_image is None:
117
  return {cls: 0.0 for cls in ["Rope", "Hammer", "Other"]}
118
 
119
+ # Transform image to tensor
120
  try:
121
+ # Convert PIL Image to tensor
122
+ tensor_image = transform(processed_image)
123
+ # Add batch dimension
124
+ tensor_image = tensor_image.unsqueeze(0)
125
  print(f"Input tensor shape: {tensor_image.shape}")
126
+ print(f"Tensor dtype: {tensor_image.dtype}")
127
+ print(f"Tensor device: {tensor_image.device}")
128
  except Exception as e:
129
  print(f"Error in tensor conversion: {str(e)}")
130
  traceback.print_exc()
 
152
  # Gradio interface
153
  interface = gr.Interface(
154
  fn=predict,
155
+ inputs=gr.Image(type="pil"), # PIL 이미지 타입으로 명시
156
  outputs=gr.Label(num_top_classes=3),
157
  title="Mechanical Tools Classifier",
158
  description="Upload an image of a tool to classify it as 'Rope', 'Hammer', or 'Other'.",