File size: 12,510 Bytes
d84bfe0 4c9d5a2 d84bfe0 4c9d5a2 d84bfe0 405581e 4058b98 d7da4f2 405581e d7da4f2 405581e f95c930 405581e d7da4f2 405581e 4058b98 405581e d7da4f2 405581e d7da4f2 405581e d7da4f2 405581e d7da4f2 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e 08e1931 405581e 08e1931 405581e 08e1931 516e5a3 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 08e1931 d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 08e1931 d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 405581e d6e1b99 08e1931 d6e1b99 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 |
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
import tempfile # For creating temporary video files
import os # Import the 'os' module
import accelerate # Import accelerate for better memory management (recommended)
# Marigold specific imports
from diffusers import MarigoldDepthPipeline, DDIMScheduler
from huggingface_hub import login # For Hugging Face Hub login if needed
# --- Marigold Model Setup ---
# Everything in this section runs once at import time. On success `pipe` holds
# the ready-to-use depth pipeline; on any loading failure `pipe` is None and
# the processing functions are expected to check for that.
CHECKPOINT = "prs-eth/marigold-depth-v1-1"

# Optional Hugging Face Hub login (private models or higher rate limits).
# The token is read from HF_TOKEN_LOGIN. An unset or empty value is skipped,
# and a failed login is only reported: it must not abort the import, because
# the model loading below is guarded separately.
_hf_token = os.environ.get("HF_TOKEN_LOGIN")
if _hf_token:
    try:
        login(token=_hf_token)
    except Exception as e:
        print(f"Warning: Hugging Face Hub login failed, continuing without it: {e}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Use bfloat16 for CUDA if available for performance, else float32
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load the Marigold pipeline
try:
    pipe = MarigoldDepthPipeline.from_pretrained(CHECKPOINT)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe = pipe.to(device=device, dtype=dtype)

    # Enable xformers for memory-efficient attention ONLY IF CUDA is available
    if torch.cuda.is_available():
        try:
            import xformers  # noqa: F401  (imported only to detect availability)
            pipe.enable_xformers_memory_efficient_attention()
            print("xFormers enabled for Marigold pipeline.")
        except ImportError:
            print("xFormers not found, running without memory-efficient attention (on GPU).")
        except Exception as e:
            # xFormers is an optional optimisation: if it is installed but
            # unusable, keep the already-loaded pipeline instead of letting the
            # outer handler discard it.
            print(f"xFormers could not be enabled ({e}), running without memory-efficient attention (on GPU).")
    else:
        print("Running on CPU or MPS. xFormers memory-efficient attention is not applicable.")

    print(f"MarigoldDepthPipeline loaded successfully from {CHECKPOINT} on {device}.")
except Exception as e:
    print(f"Error loading MarigoldDepthPipeline: {e}")
    pipe = None  # Set pipe to None to gracefully handle if it couldn't be loaded

# --- Default Marigold Parameters (from their demo) ---
DEFAULT_MARIGOLD_ENSEMBLE_SIZE = 1
DEFAULT_MARIGOLD_DENOISE_STEPS = 4
DEFAULT_MARIGOLD_PROCESSING_RES = 768  # Recommended resolution for Marigold
def _prediction_to_depth_map(prediction, height, width):
    """Return the pipeline prediction as a (height, width) float32 array."""
    if isinstance(prediction, torch.Tensor):
        # bfloat16 has no NumPy counterpart, so upcast before leaving torch.
        prediction = prediction.detach().float().cpu().numpy()
    depth = np.asarray(prediction, dtype=np.float32)
    if depth.size == height * width:
        # Only singleton batch/channel axes surround the map; drop them.
        return depth.reshape(height, width)
    depth = np.squeeze(depth)
    if depth.ndim != 2:
        raise ValueError(f"Unexpected depth prediction shape: {depth.shape}")
    # Prediction came back at a different resolution than the input image.
    return cv2.resize(depth, (width, height), interpolation=cv2.INTER_LINEAR)


def _forward_warp(image_np, disparity, direction):
    """Shift every pixel horizontally by direction * disparity; return (warped, hole_mask)."""
    height, width = disparity.shape
    warped = np.zeros_like(image_np)
    holes = np.ones((height, width), dtype=bool)
    rows = np.arange(height)
    # Walk source columns left to right and move a whole column per step.
    # Within one column every row is written at most once, and columns are
    # visited in increasing x, so when several sources land on the same target
    # the right-most source wins, exactly as in a row-major per-pixel scan.
    for x in range(width):
        target_x = x + direction * disparity[:, x]
        valid = (target_x >= 0) & (target_x < width)
        warped[rows[valid], target_x[valid]] = image_np[valid, x]
        holes[rows[valid], target_x[valid]] = False
    return warped, holes


def process_image(image, max_disparity_ratio, inpaint_radius, ensemble_size, denoise_steps, processing_res):
    """
    Convert a 2D photo to a stereoscopic 3D image pair using Marigold for depth estimation
    and DIBR, with adjustable parameters.

    Args:
        image: Input picture (PIL image). Non-RGB PIL images are converted to RGB.
        max_disparity_ratio: Maximum horizontal shift as a fraction of the image width.
        inpaint_radius: Radius passed to cv2.inpaint when filling the holes left by the shift.
        ensemble_size: Forwarded to the pipeline as ``ensemble_size``.
        denoise_steps: Forwarded to the pipeline as ``num_inference_steps``.
        processing_res: Forwarded to the pipeline as ``processing_resolution``;
            a value of 0 also selects batch size 1 instead of 2.

    Returns:
        A PIL image with the left-eye view and the right-eye view side by side
        (twice the input width). Failure is signalled without raising:
        a 200x200 red image when the model is not loaded, a 200x200 orange
        image when depth estimation fails, and None when no image was given.
    """
    if pipe is None:
        print("Error: Marigold model not loaded. Cannot process image.")
        return Image.new('RGB', (200, 200), color = 'red')
    if image is None:
        # Happens when the button is pressed with an empty image input.
        print("Error: No image provided. Cannot process image.")
        return None

    # OpenCV's inpaint is documented for 1- or 3-channel input, so RGBA,
    # palette and similar images are normalised to RGB first.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Convert PIL image to numpy array
    image_np = np.array(image)
    height, width = image_np.shape[:2]

    # Step 1: Estimate the depth map using Marigold
    try:
        # Fixed seed so repeated runs on the same image give the same depth.
        generator = torch.Generator(device=device).manual_seed(2024)
        prediction = pipe(
            image,  # Pass PIL Image directly
            ensemble_size=ensemble_size,
            num_inference_steps=denoise_steps,
            processing_resolution=processing_res,
            batch_size=1 if processing_res == 0 else 2,
            generator=generator,
        ).prediction
        # The prediction is a NumPy array under the pipeline's default output
        # type and a torch tensor with output_type="pt"; accept both.
        depth_map = _prediction_to_depth_map(prediction, height, width)
    except Exception as e:
        print(f"Error during Marigold depth estimation: {e}")
        # Return an orange image to indicate a depth estimation specific error
        return Image.new('RGB', (200, 200), color = 'orange')

    # Normalize the depth map to [0,1]
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        depth_map = (depth_map - depth_min) / depth_range
    else:
        depth_map = np.zeros_like(depth_map)  # Handle flat depth map

    # Smooth the depth map to reduce noise for DIBR
    depth_map = cv2.GaussianBlur(depth_map, (5, 5), 0)

    # Step 2: Disparity is largest where the normalised depth value is smallest.
    max_disparity_pixels = int(max_disparity_ratio * width)
    disparity_map = max_disparity_pixels * (1 - depth_map)
    # Truncate toward zero, matching int() applied to each element.
    disparity_px = disparity_map.astype(np.int32)

    # Steps 3-4: Forward-warp the source into both views, tracking unfilled pixels.
    left_image, left_mask = _forward_warp(image_np, disparity_px, +1)
    right_image, right_mask = _forward_warp(image_np, disparity_px, -1)

    # Convert masks to uint8 for OpenCV inpainting
    left_mask_uint8 = left_mask.astype(np.uint8) * 255
    right_mask_uint8 = right_mask.astype(np.uint8) * 255

    # Step 5: Apply inpainting to fill holes
    left_image_inpaint = cv2.inpaint(left_image, left_mask_uint8, inpaint_radius, cv2.INPAINT_TELEA)
    right_image_inpaint = cv2.inpaint(right_image, right_mask_uint8, inpaint_radius, cv2.INPAINT_TELEA)

    # Step 6: Combine into a side-by-side stereoscopic image
    stereo_image = np.hstack((left_image_inpaint, right_image_inpaint))
    return Image.fromarray(stereo_image)
def process_video(video_path, max_disparity_ratio, inpaint_radius, ensemble_size, denoise_steps, processing_res):
    """
    Convert a 2D video to a stereoscopic 3D video by processing each frame.

    Every frame is passed through process_image and the side-by-side result is
    appended to a temporary MP4 file ('mp4v' codec) that is twice as wide as
    the source video.

    Args:
        video_path: Path of the input video file.
        max_disparity_ratio, inpaint_radius, ensemble_size, denoise_steps,
        processing_res: Forwarded unchanged to process_image for every frame.

    Returns:
        Path of the generated temporary video file, or None when the model is
        not loaded, no video was given, the input cannot be opened, or the
        output writer cannot be created.
    """
    if pipe is None:
        print("Error: Marigold model not loaded. Cannot process video.")
        return None
    if not video_path:
        # Happens when the button is pressed with an empty video input.
        print("Error: No video provided. Cannot process video.")
        return None

    cap = cv2.VideoCapture(video_path)
    out = None
    temp_output_video_path = None
    finished = False
    frame_count = 0
    try:
        if not cap.isOpened():
            print(f"Error: Could not open video file at {video_path}")
            return None

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not (fps > 0):
            # Some containers report 0 or NaN; fall back to a common frame rate
            # so the writer still gets a valid one.
            fps = 30.0
        original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        output_width = original_width * 2
        output_height = original_height

        # Only the unique name is needed; close the handle right away so it is
        # not leaked and OpenCV can open the path itself.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
            temp_output_video_path = tmp_file.name

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4
        out = cv2.VideoWriter(temp_output_video_path, fourcc, fps, (output_width, output_height))
        if not out.isOpened():
            print(f"Error: Could not create video writer for {temp_output_video_path}")
            return None

        while True:
            ret, frame_bgr = cap.read()  # frame_bgr is in BGR format
            if not ret:
                break

            frame_rgb_pil = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))

            # Process the single frame using the existing image processing logic
            processed_frame_pil = process_image(
                frame_rgb_pil,
                max_disparity_ratio,
                inpaint_radius,
                ensemble_size,
                denoise_steps,
                processing_res,
            )

            # process_image reports failures with a small placeholder image
            # rather than None, so a result of the wrong size is also treated
            # as a failed frame. The writer needs frames of exactly the
            # configured size, so a black frame keeps the timeline intact.
            if processed_frame_pil is None or processed_frame_pil.size != (output_width, output_height):
                print(f"Skipping frame {frame_count} due to processing error.")
                processed_frame_bgr = np.zeros((output_height, output_width, 3), dtype=np.uint8)
            else:
                processed_frame_np_rgb = np.array(processed_frame_pil)
                processed_frame_bgr = cv2.cvtColor(processed_frame_np_rgb, cv2.COLOR_RGB2BGR)

            out.write(processed_frame_bgr)
            frame_count += 1
            print(f"Processed frame {frame_count}...")

        finished = True
    finally:
        # Always release the capture and writer, including on early return or
        # when a frame raises.
        cap.release()
        if out is not None:
            out.release()
        # Do not leave a partial or empty temp file behind on failure.
        if not finished and temp_output_video_path and os.path.exists(temp_output_video_path):
            os.remove(temp_output_video_path)

    print(f"Finished processing {frame_count} frames. Output video saved to: {temp_output_video_path}")
    return temp_output_video_path
# Define the Gradio web interface layout and components
_INTRO_MARKDOWN = """
# 2D to Stereoscopic 3D Converter (with Marigold Depth)
Upload a 2D photo or video to generate a stereoscopic 3D image or video pair for viewing on a Quest headset.
The output is a side-by-side format: left half for the left eye, right half for the right eye.
Adjust the sliders to fine-tune the 3D effect and Marigold's depth estimation.
"""

with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    # DIBR controls shared by the image and video tabs
    with gr.Row():
        max_disparity_slider = gr.Slider(
            label="Max Disparity Ratio (controls 3D intensity)",
            info="Higher values mean a stronger 3D effect, but can cause more distortion.",
            minimum=0.01,
            maximum=0.10,
            step=0.005,
            value=0.03,
        )
        inpaint_radius_slider = gr.Slider(
            label="Inpainting Radius (controls hole filling)",
            info="Larger values fill holes more, but can blur details around shifted objects.",
            minimum=1,
            maximum=20,
            step=1,
            value=5,
        )

    # Depth-estimation controls, collapsed by default
    with gr.Accordion("Marigold Depth Estimation Settings", open=False):
        with gr.Row():
            ensemble_size_slider = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=DEFAULT_MARIGOLD_ENSEMBLE_SIZE,
                label="Marigold Ensemble size",
                info="Higher values improve accuracy but increase processing time.",
            )
            denoise_steps_slider = gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=DEFAULT_MARIGOLD_DENOISE_STEPS,
                label="Marigold Denoising steps",
                info="More steps improve quality but increase processing time.",
            )
        processing_res_radio = gr.Radio(
            choices=[("Native", 0), ("Recommended (768)", 768), ("High (1024)", 1024)],
            value=DEFAULT_MARIGOLD_PROCESSING_RES,
            label="Marigold Processing resolution",
            info="Resolution for Marigold's internal processing. Native uses original image resolution. Higher resolutions are more accurate but slower.",
        )

    # Both conversion functions take the media input followed by these five
    # settings, in this order.
    conversion_settings = [
        max_disparity_slider,
        inpaint_radius_slider,
        ensemble_size_slider,
        denoise_steps_slider,
        processing_res_radio,
    ]

    with gr.Tabs():
        with gr.TabItem("Image Conversion"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(type="pil", label="Upload a 2D Photo")
                    image_process_button = gr.Button("Convert Image to 3D")
                with gr.Column():
                    image_output = gr.Image(type="pil", label="Stereoscopic 3D Image Output (Side-by-Side)")
            # Wire the image button to process_image
            image_process_button.click(
                fn=process_image,
                inputs=[image_input, *conversion_settings],
                outputs=image_output,
            )

        with gr.TabItem("Video Conversion"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Upload a 2D MP4 Video")
                    video_process_button = gr.Button("Convert Video to 3D")
                with gr.Column():
                    video_output = gr.Video(label="Stereoscopic 3D Video Output (Side-by-Side)")
            # Wire the video button to process_video
            video_process_button.click(
                fn=process_video,
                inputs=[video_input, *conversion_settings],
                outputs=video_output,
            )

# Launch only when run as a script (e.g. local testing); a hosting platform
# such as Hugging Face Spaces typically starts the app through its own mechanism.
if __name__ == "__main__":
    demo.launch()
|