import spaces import torch print(f'torch version:{torch.__version__}') import huggingface_hub print(f' huggingface_hub.__version__ {huggingface_hub.__version__}') import functools import gc import os os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0' import subprocess import shutil import sys import tempfile import time from datetime import datetime from pathlib import Path import uuid import cv2 import gradio as gr from huggingface_hub import hf_hub_download from PIL import Image sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.misc.image_io import save_interpolated_video from src.model.model.anysplat import AnySplat from src.model.ply_export import export_ply from src.utils.image import process_image os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/proprocess_results" from plyfile import PlyData import numpy as np import argparse from io import BytesIO def process_ply_to_splat(ply_file_path): plydata = PlyData.read(ply_file_path) vert = plydata["vertex"] sorted_indices = np.argsort( -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"]) / (1 + np.exp(-vert["opacity"])) ) buffer = BytesIO() for idx in sorted_indices: v = plydata["vertex"][idx] position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32) scales = np.exp( np.array( [v["scale_0"], v["scale_1"], v["scale_2"]], dtype=np.float32, ) ) rot = np.array( [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], dtype=np.float32, ) SH_C0 = 0.28209479177387814 color = np.array( [ 0.5 + SH_C0 * v["f_dc_0"], 0.5 + SH_C0 * v["f_dc_1"], 0.5 + SH_C0 * v["f_dc_2"], 1 / (1 + np.exp(-v["opacity"])), ] ) buffer.write(position.tobytes()) buffer.write(scales.tobytes()) buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes()) buffer.write( ((rot / np.linalg.norm(rot)) * 128 + 128) .clip(0, 255) .astype(np.uint8) .tobytes() ) return buffer.getvalue() def save_splat_file(splat_data, output_path): with open(output_path, "wb") as f: f.write(splat_data) def get_reconstructed_scene(outdir, image_files, model, device): images = [process_image(img_path) for img_path in image_files] images = torch.stack(images, dim=0).unsqueeze(0).to(device) # [1, K, 3, 448, 448] b, v, c, h, w = images.shape assert c == 3, "Images must have 3 channels" gaussians, pred_context_pose = model.inference((images + 1) * 0.5) pred_all_extrinsic = pred_context_pose["extrinsic"] pred_all_intrinsic = pred_context_pose["intrinsic"] video, depth_colored = save_interpolated_video( pred_all_extrinsic, pred_all_intrinsic, b, h, w, gaussians, outdir, model.decoder, ) plyfile = os.path.join(outdir, "gaussians.ply") # splatfile = os.path.join(outdir, "gaussians.splat") export_ply( gaussians.means[0], gaussians.scales[0], gaussians.rotations[0], gaussians.harmonics[0], gaussians.opacities[0], Path(plyfile), save_sh_dc_only=True, ) # splat_data = process_ply_to_splat(plyfile) # save_splat_file(splat_data, splatfile) # Clean up torch.cuda.empty_cache() return plyfile, video, depth_colored def extract_images(input_images, session_id): start_time = time.time() gc.collect() torch.cuda.empty_cache() base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id) target_dir = base_dir target_dir_images = os.path.join(target_dir, "images") if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir) os.makedirs(target_dir_images) image_paths = [] if input_images is not None: for file_data in input_images: if isinstance(file_data, dict) and "name" in file_data: file_path = file_data["name"] else: file_path = file_data dst_path = os.path.join(target_dir_images, os.path.basename(file_path)) shutil.copy(file_path, dst_path) image_paths.append(dst_path) end_time = time.time() print( f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds" ) return target_dir, image_paths def extract_frames(input_video, session_id): start_time = time.time() gc.collect() torch.cuda.empty_cache() base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id) target_dir = base_dir target_dir_images = os.path.join(target_dir, "images") if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir) os.makedirs(target_dir_images) image_paths = [] if input_video is not None: if isinstance(input_video, dict) and "name" in input_video: video_path = input_video["name"] else: video_path = input_video vs = cv2.VideoCapture(video_path) fps = vs.get(cv2.CAP_PROP_FPS) frame_interval = int(fps * 1) # 1 frame/sec count = 0 video_frame_num = 0 while True: gotit, frame = vs.read() if not gotit: break count += 1 if count % frame_interval == 0: image_path = os.path.join( target_dir_images, f"{video_frame_num:06}.png" ) cv2.imwrite(image_path, frame) image_paths.append(image_path) video_frame_num += 1 # Sort final images for gallery image_paths = sorted(image_paths) end_time = time.time() print( f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds" ) return target_dir, image_paths def update_gallery_on_video_upload(input_video, session_id): if not input_video: return None, None, None target_dir, image_paths = extract_frames(input_video, session_id) return None, target_dir, image_paths def update_gallery_on_images_upload(input_images, session_id): if not input_images: return None, None, None target_dir, image_paths = extract_images(input_images, session_id) return None, target_dir, image_paths @spaces.GPU() def generate_splats_from_video(video_path, session_id=None): """ Perform Gaussian Splatting from Unconstrained Views a Given Video, using a Feed-forward model. Args: video_path (str): Path to the input video file on disk. Returns: plyfile: Path to the reconstructed 3D object from the given video. rgb_vid: Path the the interpolated rgb video, increasing the frame rate using guassian splatting and interpolation of frames. depth_vid: Path the the interpolated depth video, increasing the frame rate using guassian splatting and interpolation of frames. image_paths: A list of paths from extracted frame from the video that is used for training Gaussian Splatting. """ if session_id is None: session_id = uuid.uuid4().hex images_folder, image_paths = extract_frames(video_path, session_id) plyfile, rgb_vid, depth_vid = generate_splats_from_images(image_paths, session_id) return plyfile, rgb_vid, depth_vid, image_paths @spaces.GPU() def generate_splats_from_images(image_paths, session_id=None): """ Perform Gaussian Splatting from Unconstrained Views a Given Images , using a Feed-forward model. Args: image_paths (str): Path to the input image files on disk. Returns: plyfile: Path to the reconstructed 3D object from the given image files. rgb_vid: Path the the interpolated rgb video, increasing the frame rate using guassian splatting and interpolation of frames. depth_vid: Path the the interpolated depth video, increasing the frame rate using guassian splatting and interpolation of frames. """ processed_image_paths = [] for file_data in image_paths: if isinstance(file_data, tuple): file_path, _ = file_data processed_image_paths.append(file_path) else: processed_image_paths.append(file_data) image_paths = processed_image_paths print(image_paths) if len(image_paths) == 1: image_paths.append(image_paths[0]) if session_id is None: session_id = uuid.uuid4().hex start_time = time.time() gc.collect() torch.cuda.empty_cache() base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id) print("Running run_model...") with torch.no_grad(): plyfile, rgb_vid, depth_vid = get_reconstructed_scene(base_dir, image_paths, model, device) end_time = time.time() print(f"Total time: {end_time - start_time:.2f} seconds (including IO)") return plyfile, rgb_vid, depth_vid def cleanup(request: gr.Request): sid = request.session_hash if sid: d1 = os.path.join(os.environ["ANYSPLAT_PROCESSED"], sid) shutil.rmtree(d1, ignore_errors=True) def start_session(request: gr.Request): return request.session_hash if __name__ == "__main__": share = True device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load model model = AnySplat.from_pretrained( "lhjiang/anysplat" ) model = model.to(device) model.eval() for param in model.parameters(): param.requires_grad = False css = """ #col-container { margin: 0 auto; max-width: 1024px; } """ with gr.Blocks(css=css, title="AnySplat Demo") as demo: session_state = gr.State() demo.load(start_session, outputs=[session_state]) target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None") scene_name = gr.Textbox(label="scene_name", visible=False, value="None") image_type = gr.Textbox(label="image_type", visible=False, value="None") with gr.Column(elem_id="col-container"): gr.HTML( """
""" ) with gr.Row(): with gr.Column(): with gr.Tab("Video"): input_video = gr.Video(label="Upload Video", sources=["upload"], interactive=True, height=512) with gr.Tab("Images"): input_images = gr.File(file_count="multiple", label="Upload Files", height=512) submit_btn = gr.Button( "🖌️ Generate Gaussian Splat", scale=1, variant="primary" ) image_gallery = gr.Gallery( label="Preview", columns=4, height="300px", show_download_button=True, object_fit="contain", preview=True, ) with gr.Column(): with gr.Column(): gr.HTML( """This might take a few seconds to load the 3D model
""" ) reconstruction_output = gr.Model3D( label="Ply Gaussian Model", height=512, zoom_speed=0.5, pan_speed=0.5, # camera_position=[20, 20, 20], ) with gr.Row(): rgb_video = gr.Video( label="RGB Video", interactive=False, autoplay=True ) depth_video = gr.Video( label="Depth Video", interactive=False, autoplay=True, ) with gr.Row(): examples = [ ["examples/video/re10k_1eca36ec55b88fe4.mp4"], ["examples/video/spann3r.mp4"], ["examples/video/bungeenerf_colosseum.mp4"], ["examples/video/fox.mp4"], ["examples/video/vrnerf_apartment.mp4"], # [None, "examples/video/vrnerf_kitchen.mp4", "vrnerf", "kitchen", "17", "Real", "True",], # [None, "examples/video/vrnerf_riverview.mp4", "vrnerf", "riverview", "12", "Real", "True",], # [None, "examples/video/vrnerf_workshop.mp4", "vrnerf", "workshop", "32", "Real", "True",], # [None, "examples/video/fillerbuster_ramen.mp4", "fillerbuster", "ramen", "32", "Real", "True",], # [None, "examples/video/meganerf_rubble.mp4", "meganerf", "rubble", "10", "Real", "True",], # [None, "examples/video/llff_horns.mp4", "llff", "horns", "12", "Real", "True",], # [None, "examples/video/llff_fortress.mp4", "llff", "fortress", "7", "Real", "True",], # [None, "examples/video/dtu_scan_106.mp4", "dtu", "scan_106", "20", "Real", "True",], # [None, "examples/video/horizongs_hillside_summer.mp4", "horizongs", "hillside_summer", "55", "Synthetic", "True",], # [None, "examples/video/kitti360.mp4", "kitti360", "kitti360", "64", "Real", "True",], ] gr.Examples( examples=examples, inputs=[ input_video ], outputs=[ reconstruction_output, rgb_video, depth_video, image_gallery ], fn=generate_splats_from_video, cache_examples=True, ) submit_btn.click( fn=generate_splats_from_images, inputs=[image_gallery, session_state], outputs=[reconstruction_output, rgb_video, depth_video]) input_video.upload( fn=update_gallery_on_video_upload, inputs=[input_video, session_state], outputs=[reconstruction_output, target_dir_output, image_gallery], show_api=False ) input_images.upload( fn=update_gallery_on_images_upload, inputs=[input_images, session_state], outputs=[reconstruction_output, target_dir_output, image_gallery], show_api=False ) demo.unload(cleanup) demo.queue() demo.launch(show_error=True, share=True, mcp_server=True)