Steven18 committed on
Commit 856fb1f · 1 Parent(s): 4197679

fix image_to_3d with api output

Files changed (1)
  1. app.py +31 -26
app.py CHANGED
@@ -119,26 +119,30 @@ def image_to_3d(
     slat_sampling_steps: int,
     multiimage_algo: Literal["multidiffusion", "stochastic"],
     req: gr.Request,
-) -> Tuple[dict, str]:
+) -> Tuple[dict, dict, str]:
     """
-    Convert an image to a 3D model.
+    Convert an image (or multiple images) into a 3D model and return its state and video.
 
     Args:
-        image (Image.Image): The input image.
-        multiimages (List[Tuple[Image.Image, str]]): The input images in multi-image mode.
-        is_multiimage (bool): Whether is in multi-image mode.
-        seed (int): The random seed.
-        ss_guidance_strength (float): The guidance strength for sparse structure generation.
-        ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
-        slat_guidance_strength (float): The guidance strength for structured latent generation.
-        slat_sampling_steps (int): The number of sampling steps for structured latent generation.
-        multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
+        image (Image.Image): The input image for single-image mode.
+        multiimages (List[Tuple[Image.Image, str]]): List of images with captions for multi-image mode.
+        is_multiimage (bool): Whether to use multi-image generation.
+        seed (int): Random seed for reproducibility.
+        ss_guidance_strength (float): Sparse structure guidance strength.
+        ss_sampling_steps (int): Sparse structure sampling steps.
+        slat_guidance_strength (float): SLAT guidance strength.
+        slat_sampling_steps (int): SLAT sampling steps.
+        multiimage_algo (str): Multi-image algorithm to use.
 
     Returns:
-        dict: The information of the generated 3D model.
-        str: The path to the video of the 3D model.
+        dict: Packed state (Gaussian + Mesh) for later usage (e.g., extract_glb).
+        dict: Gradio-compatible video dictionary {"video": ..., "subtitles": None}.
+        str: Path to raw video file (used by Gradio Client or download logic).
     """
     user_dir = os.path.join(TMP_DIR, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+    # Run pipeline depending on mode
     if not is_multiimage:
         outputs = pipeline.run(
             image,
@@ -156,7 +160,7 @@ def image_to_3d(
         )
     else:
        outputs = pipeline.run_multi_image(
-            [image[0] for image in multiimages],
+            [img[0] for img in multiimages],
             seed=seed,
             formats=["gaussian", "mesh"],
             preprocess_image=False,
@@ -170,25 +174,21 @@ def image_to_3d(
             },
             mode=multiimage_algo,
         )
-    # video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
-    # video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
-    # video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
-    # video_path = os.path.join(user_dir, 'sample.mp4')
-    # imageio.mimsave(video_path, video, fps=15)
-    # state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
-    # torch.cuda.empty_cache()
-    # return state, video_path
+
+    # Render the 3D video combining color and geometry
     video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
     video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
     video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
+    # Save the video
     video_path = os.path.join(user_dir, 'sample.mp4')
-    os.makedirs(os.path.dirname(video_path), exist_ok=True)
     imageio.mimsave(video_path, video, fps=15)
 
+    # Pack state for downstream use
     state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
     torch.cuda.empty_cache()
-    return state, video_path
+    return state, {"video": video_path, "subtitles": None}, video_path
+
 
 
 @spaces.GPU(duration=90)
@@ -324,6 +324,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
 
     is_multiimage = gr.State(False)
     output_buf = gr.State()
+    video_file_path = gr.Textbox(visible=False, label="Video Path")
 
     # Example images at the bottom of the page
     with gr.Row() as single_image_example:
@@ -378,8 +379,12 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
         outputs=[seed],
     ).then(
         image_to_3d,
-        inputs=[image_prompt, multiimage_prompt, is_multiimage, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps, multiimage_algo],
-        outputs=[output_buf, video_output],
+        inputs=[
+            image_prompt, multiimage_prompt, is_multiimage, seed,
+            ss_guidance_strength, ss_sampling_steps,
+            slat_guidance_strength, slat_sampling_steps, multiimage_algo
+        ],
+        outputs=[output_buf, video_output, video_file_path],  # multi output
     ).then(
         lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
         outputs=[extract_glb_btn, extract_gs_btn],
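
With the hidden video_file_path Textbox wired in as an extra output, API callers receive the raw .mp4 path as a plain string in addition to the gr.Video payload. The sketch below (not part of the commit) shows how a remote caller might consume the updated endpoint with the gradio_client package; the Space id, api_name, and parameter handling are assumptions and should be checked against client.view_api().

from gradio_client import Client, handle_file

# Hypothetical Space id; replace with the actual one.
client = Client("Steven18/your-space-id")
client.view_api()  # prints the real parameter names, order, and defaults

# Illustrative call; parameter names and which parameters are exposed must match
# what view_api() reports for this Space.
result = client.predict(
    image=handle_file("input.png"),
    api_name="/image_to_3d",
)

# With video_file_path added as an output, the raw .mp4 path is returned alongside
# the gr.Video payload, so callers can download the rendered video directly.
print(result)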