asimfayaz commited on
Commit
a4dff59
·
1 Parent(s): 4f7fdcd

Fixed tensor dimension issues for multi-view processing -- again

Browse files
Files changed (1) hide show
  1. hy3dshape/hy3dshape/pipelines.py +35 -4
hy3dshape/hy3dshape/pipelines.py CHANGED
@@ -500,10 +500,41 @@ class Hunyuan3DDiTPipeline:
500
 
501
  # Handle dictionary input (multi-view mode)
502
  if isinstance(image, dict):
503
- # Use the multi-view image processor for dictionaries
504
- from .preprocessors import MVImageProcessorV2
505
- mv_processor = MVImageProcessorV2(size=self.image_processor.size)
506
- return mv_processor(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
 
508
  if not isinstance(image, list):
509
  image = [image]
 
500
 
501
  # Handle dictionary input (multi-view mode)
502
  if isinstance(image, dict):
503
+ # Process each view individually with the single-image processor
504
+ # and then combine them appropriately
505
+ processed_views = []
506
+ view_order = []
507
+
508
+ # Define the standard view order
509
+ view_mapping = {'front': 0, 'left': 1, 'back': 2, 'right': 3}
510
+
511
+ # Sort views by their standard order
512
+ sorted_views = sorted(image.items(), key=lambda x: view_mapping.get(x[0], 999))
513
+
514
+ for view_name, view_image in sorted_views:
515
+ # Process each view individually
516
+ view_output = self.image_processor(view_image)
517
+ processed_views.append(view_output)
518
+ view_order.append(view_mapping.get(view_name, 0))
519
+
520
+ # Combine all views into a single batch
521
+ # Each view_output has shape [1, 3, H, W], we want to concatenate along batch dimension
522
+ combined_images = []
523
+ combined_masks = []
524
+
525
+ for view_output in processed_views:
526
+ combined_images.append(view_output['image'])
527
+ combined_masks.append(view_output['mask'])
528
+
529
+ # Concatenate along batch dimension: [num_views, 3, H, W]
530
+ final_image = torch.cat(combined_images, dim=0)
531
+ final_mask = torch.cat(combined_masks, dim=0)
532
+
533
+ return {
534
+ 'image': final_image,
535
+ 'mask': final_mask,
536
+ 'view_idxs': view_order
537
+ }
538
 
539
  if not isinstance(image, list):
540
  image = [image]