EgoHackZero commited on
Commit
6225fda
·
1 Parent(s): fa1476f

try to add segmentation step

Browse files
Files changed (1) hide show
  1. app.py +62 -34
app.py CHANGED
@@ -1,64 +1,92 @@
1
- import torch
2
- import gradio as gr
3
  import numpy as np
 
4
  import cv2
5
  from PIL import Image
 
 
6
 
 
7
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
8
 
9
- # Загрузка модели
 
10
  midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
11
- midas.to(device)
12
- midas.eval()
13
-
14
- # Загрузка трансформаций
15
  midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
16
  transform = midas_transforms.dpt_transform
17
 
18
- def predict_depth(image):
19
- # ======= 1. Преобразование в OpenCV формат =======
20
- if not isinstance(image, Image.Image):
21
- image = Image.fromarray(image)
22
- image_np = np.array(image)
 
 
23
 
24
- # OpenCV читает в BGR, но image_np скорее всего уже в RGB
25
- img_rgb = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
26
- img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB) # На всякий случай двойная проверка
 
 
 
 
 
27
 
28
- # ======= 2. Преобразование как в официальном туториале =======
29
- input_tensor = transform(img_rgb).to(device) # shape: [3, H, W]
 
 
 
30
 
31
- # ======= 3. Добавление batch размерности =======
32
- if len(input_tensor.shape) == 3:
33
- input_batch = input_tensor.unsqueeze(0) # shape: [1, 3, H, W]
34
- else:
35
- input_batch = input_tensor # Уже batch
36
 
37
- # ======= 4. Предсказание =======
38
  with torch.no_grad():
39
  prediction = midas(input_batch)
40
  prediction = torch.nn.functional.interpolate(
41
  prediction.unsqueeze(1),
42
- size=(img_rgb.shape[0], img_rgb.shape[1]), # (H, W)
43
  mode="bicubic",
44
- align_corners=False,
45
  ).squeeze()
46
 
47
- # ======= 5. Нормализация и преобразование в изображение =======
48
  depth_map = prediction.cpu().numpy()
49
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
50
  depth_map = (depth_map * 255).astype(np.uint8)
51
- depth_img = Image.fromarray(depth_map)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- return depth_img
 
 
 
 
 
 
54
 
55
- # Gradio интерфейс
56
  iface = gr.Interface(
57
- fn=predict_depth,
58
- inputs=gr.Image(type="pil"),
59
- outputs=gr.Image(type="pil"),
60
- title="MiDaS Depth Estimation",
61
- description="Drop img -> depth map"
62
  )
63
 
64
  if __name__ == "__main__":
 
 
 
1
  import numpy as np
2
+ import torch
3
  import cv2
4
  from PIL import Image
5
+ from transformers import pipeline
6
+ import gradio as gr
7
 
8
+ # ===== Device Setup =====
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ device_index = 0 if torch.cuda.is_available() else -1
11
 
12
+ # ===== MiDaS Depth Estimation Setup =====
13
+ # Load MiDaS model and transforms
14
  midas = torch.hub.load("intel-isl/MiDaS", "DPT_Large")
15
+ midas.to(device).eval()
 
 
 
16
  midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
17
  transform = midas_transforms.dpt_transform
18
 
19
+ # ===== Segmentation Setup =====
20
+ segmenter = pipeline(
21
+ "image-segmentation",
22
+ model="nvidia/segformer-b0-finetuned-ade-512-512",
23
+ device=device_index,
24
+ torch_dtype=torch.float16 if device.type == "cuda" else torch.float32
25
+ )
26
 
27
+ # ===== Utility Functions =====
28
+ def resize_image(img: Image.Image, max_size: int = 512) -> Image.Image:
29
+ width, height = img.size
30
+ if max(width, height) > max_size:
31
+ ratio = max_size / max(width, height)
32
+ new_size = (int(width * ratio), int(height * ratio))
33
+ return img.resize(new_size, Image.LANCZOS)
34
+ return img
35
 
36
+ # ===== Depth Prediction =====
37
+ def predict_depth(image: Image.Image) -> Image.Image:
38
+ # Ensure input is PIL Image
39
+ img = image.convert('RGB') if not isinstance(image, Image.Image) else image
40
+ img_np = np.array(img)
41
 
42
+ # Convert to the format expected by MiDaS
43
+ input_tensor = transform(img_np).to(device)
44
+ input_batch = input_tensor.unsqueeze(0) if input_tensor.ndim == 3 else input_tensor
 
 
45
 
46
+ # Predict depth
47
  with torch.no_grad():
48
  prediction = midas(input_batch)
49
  prediction = torch.nn.functional.interpolate(
50
  prediction.unsqueeze(1),
51
+ size=img_np.shape[:2],
52
  mode="bicubic",
53
+ align_corners=False
54
  ).squeeze()
55
 
56
+ # Normalize to 0-255
57
  depth_map = prediction.cpu().numpy()
58
  depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
59
  depth_map = (depth_map * 255).astype(np.uint8)
60
+ return Image.fromarray(depth_map)
61
+
62
+ # ===== Segmentation =====
63
+ def segment_image(img: Image.Image) -> Image.Image:
64
+ img = img.convert('RGB')
65
+ img_resized = resize_image(img)
66
+ results = segmenter(img_resized)
67
+
68
+ overlay = np.array(img_resized, dtype=np.uint8)
69
+ for res in results:
70
+ mask = np.array(res["mask"], dtype=bool)
71
+ color = np.random.randint(50, 255, 3, dtype=np.uint8)
72
+ overlay[mask] = (overlay[mask] * 0.6 + color * 0.4).astype(np.uint8)
73
+
74
+ return Image.fromarray(overlay)
75
 
76
+ # ===== Gradio App =====
77
+ def predict_fn(input_img: Image.Image) -> Image.Image:
78
+ # 1. Compute depth map
79
+ depth_img = predict_depth(input_img)
80
+ # 2. Segment the depth map
81
+ seg_img = segment_image(depth_img)
82
+ return seg_img
83
 
 
84
  iface = gr.Interface(
85
+ fn=predict_fn,
86
+ inputs=gr.Image(type="pil", label="Upload Image"),
87
+ outputs=gr.Image(type="pil", label="Segmented Depth Overlay"),
88
+ title="Depth-then-Segmentation Pipeline",
89
+ description="Upload an image. First computes a depth map via MiDaS, then applies SegFormer segmentation on the depth map."
90
  )
91
 
92
  if __name__ == "__main__":