DurgaDeepak committed on
Commit b6e2d20 · verified · 1 Parent(s): c0d825f

Update app.py

Files changed (1)
  1. app.py +0 -321
app.py CHANGED
@@ -71,327 +71,6 @@ def preload_models():
     get_model("segmentation", "deeplabv3_resnet50", device="cpu")
     get_model("depth", "midas_v21_small_256", device="cpu")
 
-
-# Utility Functions
-def format_error(message):
-    """Formats error messages for consistent user feedback."""
-    return {"error": message}
-
-def toggle_visibility(show, *components):
-    """Toggles visibility for multiple Gradio components."""
-    return [gr.update(visible=show) for _ in components]
-
-def generate_session_id():
-    """Generates a unique session ID for tracking inputs."""
-    return str(uuid.uuid4())
-
-def log_runtime(start_time):
-    """Logs the runtime of a process."""
-    elapsed_time = time.time() - start_time
-    logger.info(f"Process completed in {elapsed_time:.2f} seconds.")
-    return elapsed_time
-
-def is_public_ip(url):
-    """
-    Checks whether the resolved IP address of a URL is public (non-local).
-    Prevents SSRF by blocking internal addresses like 127.0.0.1 or 192.168.x.x.
-    """
-    try:
-        hostname = urlparse(url).hostname
-        ip = socket.gethostbyname(hostname)
-        ip_obj = ipaddress.ip_address(ip)
-        return ip_obj.is_global  # Only allow globally routable IPs
-    except Exception as e:
-        logger.warning(f"URL IP validation failed: {e}")
-        return False
-
-
-def fetch_media_from_url(url):
-    """
-    Downloads media from a URL. Supports images and videos.
-    Returns PIL.Image or video file path.
-    """
-    logger.info(f"Fetching media from URL: {url}")
-    if not is_public_ip(url):
-        logger.warning("Blocked non-public URL request (possible SSRF).")
-        return None
-
-    try:
-        parsed_url = urlparse(url)
-        ext = os.path.splitext(parsed_url.path)[-1].lower()
-        headers = {"User-Agent": "Mozilla/5.0"}
-        r = requests.get(url, headers=headers, timeout=10)
-
-        if r.status_code != 200 or len(r.content) > 50 * 1024 * 1024:
-            logger.warning(f"Download failed or file too large.")
-            return None
-
-        tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=ext)
-        tmp_file.write(r.content)
-        tmp_file.close()
-
-        if ext in [".jpg", ".jpeg", ".png"]:
-            return Image.open(tmp_file.name).convert("RGB")
-        elif ext in [".mp4", ".avi", ".mov"]:
-            return tmp_file.name
-        else:
-            logger.warning("Unsupported file type from URL.")
-            return None
-    except Exception as e:
-        logger.error(f"URL fetch failed: {e}")
-        return None
-
-# Input Validation Functions
-def validate_image(img):
-    """
-    Validates the uploaded image based on size and resolution limits.
-
-    Args:
-        img (PIL.Image.Image): Image to validate.
-
-    Returns:
-        Tuple[bool, str or None]: (True, None) if valid; (False, reason) otherwise.
-    """
-    logger.info("Validating uploaded image.")
-    try:
-        buffer = io.BytesIO()
-        img.save(buffer, format="PNG")
-        size_mb = len(buffer.getvalue()) / (1024 * 1024)
-
-        if size_mb > MAX_IMAGE_MB:
-            logger.warning("Image exceeds size limit of 5MB.")
-            return False, "Image exceeds 5MB limit."
-
-        if img.width > MAX_IMAGE_RES[0] or img.height > MAX_IMAGE_RES[1]:
-            logger.warning("Image resolution exceeds 1920x1080.")
-            return False, "Image resolution exceeds 1920x1080."
-
-        logger.info("Image validation passed.")
-        return True, None
-    except Exception as e:
-        logger.error(f"Error validating image: {e}")
-        return False, str(e)
-
-def validate_video(path):
-    """
-    Validates the uploaded video based on size and duration limits.
-
-    Args:
-        path (str): Path to the video file.
-
-    Returns:
-        Tuple[bool, str or None]: (True, None) if valid; (False, reason) otherwise.
-    """
-    logger.info(f"Validating video file at: {path}")
-    try:
-        size_mb = os.path.getsize(path) / (1024 * 1024)
-        if size_mb > MAX_VIDEO_MB:
-            logger.warning("Video exceeds size limit of 50MB.")
-            return False, "Video exceeds 50MB limit."
-
-        cap = cv2.VideoCapture(path)
-        fps = cap.get(cv2.CAP_PROP_FPS)
-        frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
-        duration = frames / fps if fps else 0
-        cap.release()
-
-        if duration > MAX_VIDEO_DURATION:
-            logger.warning("Video exceeds 30 seconds duration limit.")
-            return False, "Video exceeds 30 seconds duration limit."
-
-        logger.info("Video validation passed.")
-        return True, None
-    except Exception as e:
-        logger.error(f"Error validating video: {e}")
-        return False, str(e)
-
-# Input Resolution
-def resolve_input(mode, media_upload, url):
-    """
-    Resolves the media input based on selected mode.
-    - If mode is 'Upload', accepts either:
-        * 1–5 images (PIL.Image)
-        * OR 1 video file (file path as string)
-    - If mode is 'URL', fetches remote image or video.
-
-    Args:
-        mode (str): 'Upload' or 'URL'
-        media_upload (List[Union[PIL.Image.Image, str]]): Uploaded media
-        url (str): URL to image or video
-
-    Returns:
-        List[Union[PIL.Image.Image, str]] or None
-    """
-    try:
-        logger.info(f"Resolving input for mode: {mode}")
-
-        if mode == "Upload":
-            if not media_upload:
-                logger.warning("No upload detected.")
-                return None
-
-            image_files = [f for f in media_upload if isinstance(f, Image.Image)]
-            video_files = [f for f in media_upload if isinstance(f, str) and f.lower().endswith((".mp4", ".mov", ".avi"))]
-
-            if image_files and video_files:
-                logger.warning("Mixed media upload not supported (images + video).")
-                return None
-
-            if image_files:
-                if 1 <= len(image_files) <= 5:
-                    logger.info(f"Accepted {len(image_files)} image(s).")
-                    return image_files
-                logger.warning("Invalid number of images. Must be 1 to 5.")
-                return None
-
-            if video_files:
-                if len(video_files) == 1:
-                    logger.info("Accepted single video upload.")
-                    return video_files
-                logger.warning("Only one video allowed.")
-                return None
-
-            logger.warning("Unsupported upload type.")
-            return None
-
-        elif mode == "URL":
-            if not url:
-                logger.warning("URL mode selected but URL is empty.")
-                return None
-            media = fetch_media_from_url(url)
-            if media:
-                logger.info("Media successfully fetched from URL.")
-                return [media]
-            else:
-                logger.warning("Failed to resolve media from URL.")
-                return None
-
-        else:
-            logger.error(f"Invalid mode selected: {mode}")
-            return None
-
-    except Exception as e:
-        logger.error(f"Exception in resolve_input(): {e}")
-        return None
-
-@timeout_decorator.timeout(35, use_signals=False)  # 35 sec limit per image
-def process_image(
-    image: Image.Image,
-    run_det: bool,
-    det_model: str,
-    det_confidence: float,
-    run_seg: bool,
-    seg_model: str,
-    run_depth: bool,
-    depth_model: str,
-    blend: float
-):
-    """
-    Runs selected perception tasks on the input image and packages results.
-
-    Args:
-        image (PIL.Image): Input image.
-        run_det (bool): Run object detection.
-        det_model (str): Detection model key.
-        det_confidence (float): Detection confidence threshold.
-        run_seg (bool): Run segmentation.
-        seg_model (str): Segmentation model key.
-        run_depth (bool): Run depth estimation.
-        depth_model (str): Depth model key.
-        blend (float): Overlay blend alpha (0.0 - 1.0).
-
-    Returns:
-        Tuple[Image, dict, Tuple[str, bytes]]: Final image, scene JSON, and downloadable ZIP.
-    """
-    logger.info("Starting image processing pipeline.")
-    start_time = time.time()
-    outputs, scene = {}, {}
-    combined_np = np.array(image)
-
-    try:
-        # Detection
-        if run_det:
-            logger.info(f"Running detection with model: {det_model}")
-            load_start = time.time()
-            model = get_model("detection", DETECTION_MODEL_MAP[det_model], device="cpu")
-            logger.info(f"{det_model} detection model loaded in {time.time() - load_start:.2f} seconds.")
-            boxes = model.predict(image, conf_threshold=det_confidence)
-            overlay = model.draw(image, boxes)
-            combined_np = np.array(overlay)
-            buf = io.BytesIO()
-            overlay.save(buf, format="PNG")
-            outputs["detection.png"] = buf.getvalue()
-            scene["detection"] = boxes
-
-        # Segmentation
-        if run_seg:
-            logger.info(f"Running segmentation with model: {seg_model}")
-            load_start = time.time()
-            model = get_model("segmentation", SEGMENTATION_MODEL_MAP[seg_model], device="cpu")
-            logger.info(f"{seg_model} segmentation model loaded in {time.time() - load_start:.2f} seconds.")
-            mask = model.predict(image)
-            overlay = model.draw(image, mask, alpha=blend)
-            combined_np = cv2.addWeighted(combined_np, 1 - blend, np.array(overlay), blend, 0)
-            buf = io.BytesIO()
-            overlay.save(buf, format="PNG")
-            outputs["segmentation.png"] = buf.getvalue()
-            scene["segmentation"] = mask.tolist()
-
-        # Depth Estimation
-        if run_depth:
-            logger.info(f"Running depth estimation with model: {depth_model}")
-            load_start = time.time()
-            model = get_model("depth", DEPTH_MODEL_MAP[depth_model], device="cpu")
-            logger.info(f"{depth_model} depth model loaded in {time.time() - load_start:.2f} seconds.")
-            dmap = model.predict(image)
-            norm_dmap = ((dmap - dmap.min()) / (dmap.ptp()) * 255).astype(np.uint8)
-            d_pil = Image.fromarray(norm_dmap)
-            combined_np = cv2.addWeighted(combined_np, 1 - blend, np.array(d_pil.convert("RGB")), blend, 0)
-            buf = io.BytesIO()
-            d_pil.save(buf, format="PNG")
-            outputs["depth_map.png"] = buf.getvalue()
-            scene["depth"] = dmap.tolist()
-
-        # Final image overlay
-        final_img = Image.fromarray(combined_np)
-        buf = io.BytesIO()
-        final_img.save(buf, format="PNG")
-        outputs["scene_blueprint.png"] = buf.getvalue()
-
-        # Scene description
-        try:
-            scene_json = describe_scene(**scene)
-        except Exception as e:
-            logger.warning(f"describe_scene failed: {e}")
-            scene_json = {"error": str(e)}
-        telemetry = {
-            "session_id": generate_session_id(),
-            "runtime_sec": round(log_runtime(start_time), 2),
-            "used_models": {
-                "detection": det_model if run_det else None,
-                "segmentation": seg_model if run_seg else None,
-                "depth": depth_model if run_depth else None
-            }
-        }
-        scene_json["telemetry"] = telemetry
-
-        outputs["scene_description.json"] = json.dumps(scene_json, indent=2).encode("utf-8")
-
-        # ZIP file creation
-        zip_buf = io.BytesIO()
-        with zipfile.ZipFile(zip_buf, "w") as zipf:
-            for name, data in outputs.items():
-                zipf.writestr(name, data)
-
-        elapsed = log_runtime(start_time)
-        logger.info(f"Image processing completed in {elapsed:.2f} seconds.")
-
-        return final_img, scene_json, ("uvis_results.zip", zip_buf.getvalue())
-
-    except Exception as e:
-        logger.error(f"Error in processing pipeline: {e}")
-        return None, {"error": str(e)}, None
-
 # Main Handler
 def handle(mode, media_upload, url, run_det, det_model, det_confidence, run_seg, seg_model, run_depth, depth_model, blend):
     """
 