File size: 12,510 Bytes
d84bfe0
4c9d5a2
d84bfe0
4c9d5a2
d84bfe0
405581e
4058b98
 
d7da4f2
405581e
 
 
d7da4f2
405581e
 
 
 
f95c930
405581e
 
 
 
 
 
 
d7da4f2
405581e
 
 
 
4058b98
 
 
 
 
 
 
 
 
 
405581e
 
d7da4f2
405581e
 
d7da4f2
405581e
 
 
 
d7da4f2
405581e
d7da4f2
405581e
 
d6e1b99
405581e
 
d6e1b99
 
405581e
d6e1b99
 
 
405581e
08e1931
405581e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
08e1931
405581e
08e1931
 
 
 
516e5a3
 
 
405581e
d6e1b99
405581e
d6e1b99
 
405581e
d6e1b99
 
 
08e1931
d6e1b99
 
405581e
 
d6e1b99
405581e
d6e1b99
 
 
 
 
 
 
 
 
405581e
d6e1b99
 
 
405581e
d6e1b99
405581e
d6e1b99
 
 
405581e
d6e1b99
 
 
405581e
d6e1b99
 
405581e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6e1b99
08e1931
d6e1b99
 
 
405581e
 
 
 
d6e1b99
 
405581e
 
d6e1b99
405581e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6e1b99
405581e
 
d6e1b99
 
 
405581e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d6e1b99
 
08e1931
 
d6e1b99
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
import tempfile # For creating temporary video files
import os # Import the 'os' module
import accelerate # Import accelerate for better memory management (recommended)

# Marigold specific imports
from diffusers import MarigoldDepthPipeline, DDIMScheduler
from huggingface_hub import login # For Hugging Face Hub login if needed

# --- Marigold Model Setup ---
# Everything in this section runs once at import time and publishes three
# module-level names used by the processing functions below:
#   device - torch device the pipeline lives on
#   dtype  - parameter dtype the pipeline was cast to
#   pipe   - the loaded MarigoldDepthPipeline, or None if loading failed
CHECKPOINT = "prs-eth/marigold-depth-v1-1"

# Check for HF_TOKEN_LOGIN environment variable for private models or higher rate limits.
# The login is best-effort: a stale or malformed token must not prevent the app
# from starting, so failures are reported and the load below is still attempted.
if "HF_TOKEN_LOGIN" in os.environ:
    try:
        login(token=os.environ["HF_TOKEN_LOGIN"])
    except Exception as login_error:
        print(f"Hugging Face login failed, continuing without it: {login_error}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use bfloat16 for CUDA if available for performance, else float32
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load the Marigold pipeline
try:
    pipe = MarigoldDepthPipeline.from_pretrained(CHECKPOINT)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe = pipe.to(device=device, dtype=dtype)

    # Enable xformers for memory-efficient attention ONLY IF CUDA is available
    if torch.cuda.is_available():
        try:
            import xformers
            pipe.enable_xformers_memory_efficient_attention()
            print("xFormers enabled for Marigold pipeline.")
        except ImportError:
            print("xFormers not found, running without memory-efficient attention (on GPU).")
        except Exception as xformers_error:
            # xFormers is only an optimization. Any other failure while enabling
            # it (e.g. an incompatible build) must not reach the outer handler,
            # which would discard the otherwise usable pipeline.
            print(f"xFormers could not be enabled ({xformers_error}); running without memory-efficient attention (on GPU).")
    else:
        print("Running on CPU or MPS. xFormers memory-efficient attention is not applicable.")

    print(f"MarigoldDepthPipeline loaded successfully from {CHECKPOINT} on {device}.")
except Exception as e:
    print(f"Error loading MarigoldDepthPipeline: {e}")
    pipe = None # Set pipe to None to gracefully handle if it couldn't be loaded

# --- Default Marigold Parameters (from their demo) ---
DEFAULT_MARIGOLD_ENSEMBLE_SIZE = 1
DEFAULT_MARIGOLD_DENOISE_STEPS = 4
DEFAULT_MARIGOLD_PROCESSING_RES = 768 # Recommended resolution for Marigold

def _warp_view(image_np, disparity, direction):
    """Forward-warp an image horizontally to synthesize one eye's view.

    Args:
        image_np: Source image as an (H, W, C) array.
        disparity: Integer (H, W) array of non-negative per-pixel shifts, in pixels.
        direction: +1 moves pixels to the right (left-eye view),
            -1 moves them to the left (right-eye view).

    Returns:
        A ``(warped, holes)`` tuple: ``warped`` has the shape and dtype of
        ``image_np``; ``holes`` is a boolean (H, W) array that is True wherever
        no source pixel landed and the view therefore needs inpainting.
    """
    height, width = disparity.shape
    warped = np.zeros_like(image_np)
    holes = np.ones((height, width), dtype=bool)
    rows = np.arange(height)

    # Columns are visited so that, when two source pixels land on the same
    # target, the one with the larger disparity (the nearer surface) is written
    # last and wins. For a rightward shift that means scanning right-to-left;
    # for a leftward shift, left-to-right.
    columns = range(width - 1, -1, -1) if direction > 0 else range(width)
    for x in columns:
        # One whole column is moved per iteration. Row indices within a column
        # are unique, so the fancy-index assignment never has duplicate targets.
        target_x = x + direction * disparity[:, x]
        inside = (target_x >= 0) & (target_x < width)
        warped[rows[inside], target_x[inside]] = image_np[rows[inside], x]
        holes[rows[inside], target_x[inside]] = False

    return warped, holes


def process_image(image, max_disparity_ratio, inpaint_radius, ensemble_size, denoise_steps, processing_res):
    """
    Convert a 2D photo to a side-by-side stereoscopic 3D image using Marigold
    for depth estimation and depth-image-based rendering (DIBR).

    Args:
        image: Input picture (a PIL image, as delivered by the Gradio input).
        max_disparity_ratio: Maximum horizontal shift as a fraction of the image width.
        inpaint_radius: Radius passed to ``cv2.inpaint`` when filling warp holes.
        ensemble_size: Marigold ``ensemble_size`` argument.
        denoise_steps: Marigold ``num_inference_steps`` argument.
        processing_res: Marigold ``processing_resolution`` argument (0 = native).

    Returns:
        A PIL image twice as wide as the input (left view | right view).
        On failure a 200x200 placeholder is returned instead: red when the
        model is not loaded, orange when depth estimation failed. ``None`` is
        returned when no input image was supplied.
    """
    if pipe is None:
        print("Error: Marigold model not loaded. Cannot process image.")
        return Image.new('RGB', (200, 200), color = 'red')

    if image is None:
        # Gradio passes None when the button is pressed without an upload.
        print("Error: No input image provided.")
        return None

    # cv2.inpaint only accepts 8-bit 1- or 3-channel images, so RGBA / palette /
    # grayscale inputs are normalized to RGB up front.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Convert PIL image to numpy array
    image_np = np.array(image)
    height, width = image_np.shape[:2]

    # Step 1: Estimate the depth map using Marigold
    try:
        # Use a fixed seed for reproducibility if desired, otherwise remove 'generator'.
        generator = torch.Generator(device=device).manual_seed(2024)
        # Slider values can arrive as floats; the pipeline expects integers.
        processing_res = int(processing_res)
        prediction = pipe(
            image, # Pass PIL Image directly
            ensemble_size=int(ensemble_size),
            num_inference_steps=int(denoise_steps),
            processing_resolution=processing_res,
            batch_size=1 if processing_res == 0 else 2, # Batch size recommended by Marigold for resolutions
            generator=generator,
        ).prediction

        # `prediction` is a NumPy array or a torch tensor depending on the
        # pipeline's output type, so both are accepted. A tensor is cast to
        # float32 first because bfloat16 cannot be converted to NumPy directly.
        if isinstance(prediction, torch.Tensor):
            prediction = prediction.detach().float().cpu().numpy()
        depth_map = np.squeeze(np.asarray(prediction, dtype=np.float32))
        if depth_map.ndim != 2:
            raise ValueError(f"unexpected depth prediction shape {depth_map.shape}")

    except Exception as e:
        print(f"Error during Marigold depth estimation: {e}")
        # Return an orange image to indicate a depth estimation specific error
        return Image.new('RGB', (200, 200), color = 'orange')

    # The warp indexes the depth map with image coordinates, so both must have
    # the same size.
    if depth_map.shape != (height, width):
        depth_map = cv2.resize(depth_map, (width, height), interpolation=cv2.INTER_LINEAR)

    # Normalize the depth map to [0,1]
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        depth_map = (depth_map - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(depth_map) # Handle flat depth map

    # Smooth the depth map to reduce noise for DIBR
    depth_map = cv2.GaussianBlur(depth_map, (5, 5), 0)

    # Step 2: Calculate the disparity map (inversely proportional to depth).
    # Clipping guards against tiny float overshoot after the blur; the integer
    # cast truncates exactly like int() does for these non-negative values.
    max_disparity_pixels = int(max_disparity_ratio * width)
    disparity_map = (max_disparity_pixels * (1.0 - np.clip(depth_map, 0.0, 1.0))).astype(np.int32)

    # Steps 3-4: Forward-warp the source into the two eye views, tracking holes
    left_image, left_mask = _warp_view(image_np, disparity_map, 1)
    right_image, right_mask = _warp_view(image_np, disparity_map, -1)

    # Convert masks to uint8 for OpenCV inpainting
    left_mask_uint8 = left_mask.astype(np.uint8) * 255
    right_mask_uint8 = right_mask.astype(np.uint8) * 255

    # Step 5: Apply inpainting to fill holes
    left_image_inpaint = cv2.inpaint(left_image, left_mask_uint8, inpaint_radius, cv2.INPAINT_TELEA)
    right_image_inpaint = cv2.inpaint(right_image, right_mask_uint8, inpaint_radius, cv2.INPAINT_TELEA)

    # Step 6: Combine into a side-by-side stereoscopic image
    stereo_image = np.hstack((left_image_inpaint, right_image_inpaint))

    return Image.fromarray(stereo_image)


def process_video(video_path, max_disparity_ratio, inpaint_radius, ensemble_size, denoise_steps, processing_res):
    """
    Convert a 2D video to a side-by-side stereoscopic 3D video by running
    ``process_image`` on every frame.

    Args:
        video_path: Path of the input video file.
        max_disparity_ratio: Forwarded to ``process_image``.
        inpaint_radius: Forwarded to ``process_image``.
        ensemble_size: Forwarded to ``process_image``.
        denoise_steps: Forwarded to ``process_image``.
        processing_res: Forwarded to ``process_image``.

    Returns:
        Path of a temporary ``.mp4`` file that is twice as wide as the input,
        or ``None`` if the model is not loaded, no video was supplied, or the
        input/output video could not be opened.
    """
    if pipe is None:
        print("Error: Marigold model not loaded. Cannot process video.")
        return None

    if not video_path:
        # Gradio passes None when the button is pressed without an upload.
        print("Error: No input video provided.")
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Error: Could not open video file at {video_path}")
        cap.release()
        return None

    out = None
    temp_output_video_path = None
    finished = False
    frame_count = 0
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0 or NaN; fall back to a conventional rate
            # so the writer can still be opened.
            fps = 30.0
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        output_width = original_width * 2
        output_height = original_height

        # Only the name is needed; close the handle right away so OpenCV can
        # open the path itself and no file descriptor is leaked.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
            temp_output_video_path = temp_file.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Codec for MP4
        out = cv2.VideoWriter(temp_output_video_path, fourcc, fps, (output_width, output_height))

        if not out.isOpened():
            print(f"Error: Could not create video writer for {temp_output_video_path}")
            return None

        while True:
            ret, frame_bgr = cap.read() # frame_bgr is in BGR format
            if not ret:
                break

            frame_rgb_pil = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))

            # Process the single frame using the existing image processing logic
            processed_frame_pil = process_image(
                frame_rgb_pil,
                max_disparity_ratio,
                inpaint_radius,
                ensemble_size, # Pass Marigold params
                denoise_steps, # Pass Marigold params
                processing_res   # Pass Marigold params
            )

            if processed_frame_pil is None:
                print(f"Skipping frame {frame_count} due to processing error.")
                processed_frame_bgr = np.zeros((output_height, output_width, 3), dtype=np.uint8)
            else:
                processed_frame_np_rgb = np.array(processed_frame_pil)
                processed_frame_bgr = cv2.cvtColor(processed_frame_np_rgb, cv2.COLOR_RGB2BGR)
                if processed_frame_bgr.shape[:2] != (output_height, output_width):
                    # The writer expects every frame to match the size it was
                    # opened with. process_image can hand back a small
                    # placeholder on error, so mismatched frames are rescaled
                    # to keep the frame count and timing intact.
                    print(f"Frame {frame_count} has unexpected size {processed_frame_bgr.shape[:2]}; resizing to fit the output.")
                    processed_frame_bgr = cv2.resize(processed_frame_bgr, (output_width, output_height))

            out.write(processed_frame_bgr)
            frame_count += 1
            print(f"Processed frame {frame_count}...")

        finished = True
    finally:
        # Always release the capture/writer, even if a frame raised.
        cap.release()
        if out is not None:
            out.release()
        # Do not leave an unusable temp file behind on failure.
        if not finished and temp_output_video_path and os.path.exists(temp_output_video_path):
            os.remove(temp_output_video_path)

    print(f"Finished processing {frame_count} frames. Output video saved to: {temp_output_video_path}")
    return temp_output_video_path

# Gradio front end: one set of shared DIBR / Marigold controls at the top,
# followed by two tabs (image and video) that both read those controls.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 2D to Stereoscopic 3D Converter (with Marigold Depth)
        Upload a 2D photo or video to generate a stereoscopic 3D image or video pair for viewing on a Quest headset.
        The output is a side-by-side format: left half for the left eye, right half for the right eye.
        Adjust the sliders to fine-tune the 3D effect and Marigold's depth estimation.
        """
    )

    # DIBR controls (always visible)
    with gr.Row():
        max_disparity_slider = gr.Slider(
            label="Max Disparity Ratio (controls 3D intensity)",
            info="Higher values mean a stronger 3D effect, but can cause more distortion.",
            minimum=0.01,
            maximum=0.10,
            step=0.005,
            value=0.03,  # A balanced default
        )
        inpaint_radius_slider = gr.Slider(
            label="Inpainting Radius (controls hole filling)",
            info="Larger values fill holes more, but can blur details around shifted objects.",
            minimum=1,
            maximum=20,
            step=1,
            value=5,  # A common default for inpainting
        )

    # Marigold controls (collapsed by default)
    with gr.Accordion("Marigold Depth Estimation Settings", open=False):
        with gr.Row():
            ensemble_size_slider = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=DEFAULT_MARIGOLD_ENSEMBLE_SIZE,
                label="Marigold Ensemble size",
                info="Higher values improve accuracy but increase processing time.",
            )
            denoise_steps_slider = gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=DEFAULT_MARIGOLD_DENOISE_STEPS,
                label="Marigold Denoising steps",
                info="More steps improve quality but increase processing time.",
            )
            # (label shown to the user, value handed to the callback)
            resolution_choices = [
                ("Native", 0),
                ("Recommended (768)", 768),
                ("High (1024)", 1024),
            ]
            processing_res_radio = gr.Radio(
                resolution_choices,
                value=DEFAULT_MARIGOLD_PROCESSING_RES,
                label="Marigold Processing resolution",
                info="Resolution for Marigold's internal processing. Native uses original image resolution. Higher resolutions are more accurate but slower.",
            )

    # Controls both callbacks receive after their media input, in the order
    # the processing functions declare their parameters.
    shared_settings = [
        max_disparity_slider,
        inpaint_radius_slider,
        ensemble_size_slider,
        denoise_steps_slider,
        processing_res_radio,
    ]

    with gr.Tabs():
        with gr.TabItem("Image Conversion"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(label="Upload a 2D Photo", type="pil")
                    image_process_button = gr.Button("Convert Image to 3D")
                with gr.Column():
                    image_output = gr.Image(label="Stereoscopic 3D Image Output (Side-by-Side)", type="pil")
            # Image button -> process_image
            image_process_button.click(
                fn=process_image,
                inputs=[image_input, *shared_settings],
                outputs=image_output,
            )

        with gr.TabItem("Video Conversion"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload a 2D MP4 Video")
                    video_process_button = gr.Button("Convert Video to 3D")
                with gr.Column():
                    video_output = gr.Video(label="Stereoscopic 3D Video Output (Side-by-Side)")
            # Video button -> process_video
            video_process_button.click(
                fn=process_video,
                inputs=[video_input, *shared_settings],
                outputs=video_output,
            )

# Start the app only when this file is executed as a script (e.g. local
# testing); hosted environments such as Hugging Face Spaces may launch it
# through their own mechanism.
if __name__ == "__main__":
    demo.launch()