Spaces:
Running
on
Zero
Running
on
Zero
Fixed tensor dimension issues for multi-view processing -- again
Browse files
hy3dshape/hy3dshape/pipelines.py
CHANGED
@@ -500,10 +500,41 @@ class Hunyuan3DDiTPipeline:
|
|
500 |
|
501 |
# Handle dictionary input (multi-view mode)
|
502 |
if isinstance(image, dict):
|
503 |
-
#
|
504 |
-
|
505 |
-
|
506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
507 |
|
508 |
if not isinstance(image, list):
|
509 |
image = [image]
|
|
|
500 |
|
501 |
# Handle dictionary input (multi-view mode)
|
502 |
if isinstance(image, dict):
|
503 |
+
# Process each view individually with the single-image processor
|
504 |
+
# and then combine them appropriately
|
505 |
+
processed_views = []
|
506 |
+
view_order = []
|
507 |
+
|
508 |
+
# Define the standard view order
|
509 |
+
view_mapping = {'front': 0, 'left': 1, 'back': 2, 'right': 3}
|
510 |
+
|
511 |
+
# Sort views by their standard order
|
512 |
+
sorted_views = sorted(image.items(), key=lambda x: view_mapping.get(x[0], 999))
|
513 |
+
|
514 |
+
for view_name, view_image in sorted_views:
|
515 |
+
# Process each view individually
|
516 |
+
view_output = self.image_processor(view_image)
|
517 |
+
processed_views.append(view_output)
|
518 |
+
view_order.append(view_mapping.get(view_name, 0))
|
519 |
+
|
520 |
+
# Combine all views into a single batch
|
521 |
+
# Each view_output has shape [1, 3, H, W], we want to concatenate along batch dimension
|
522 |
+
combined_images = []
|
523 |
+
combined_masks = []
|
524 |
+
|
525 |
+
for view_output in processed_views:
|
526 |
+
combined_images.append(view_output['image'])
|
527 |
+
combined_masks.append(view_output['mask'])
|
528 |
+
|
529 |
+
# Concatenate along batch dimension: [num_views, 3, H, W]
|
530 |
+
final_image = torch.cat(combined_images, dim=0)
|
531 |
+
final_mask = torch.cat(combined_masks, dim=0)
|
532 |
+
|
533 |
+
return {
|
534 |
+
'image': final_image,
|
535 |
+
'mask': final_mask,
|
536 |
+
'view_idxs': view_order
|
537 |
+
}
|
538 |
|
539 |
if not isinstance(image, list):
|
540 |
image = [image]
|