jbilcke-hf HF Staff commited on
Commit
2932acc
·
verified ·
1 Parent(s): bc60693

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -605
app.py CHANGED
@@ -1,512 +1,30 @@
1
  import gradio as gr
2
  import torch
3
- import torch.nn as nn
4
  import os
5
  import tempfile
6
  import shutil
7
  import imageio
8
- import pandas as pd
9
- import numpy as np
10
- from diffsynth import ModelManager, WanVideoReCamMasterPipeline, save_video
11
- import json
12
- from torchvision.transforms import v2
13
- from einops import rearrange
14
- import torchvision
15
- from PIL import Image
16
  import logging
17
  from pathlib import Path
18
- from huggingface_hub import hf_hub_download
 
 
 
 
19
 
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
23
- # Get model storage path from environment variable or use default
24
- MODELS_ROOT_DIR = os.environ.get("RECAMMASTER_MODELS_DIR", "/data/models")
25
- logger.info(f"Using models root directory: {MODELS_ROOT_DIR}")
26
-
27
- # Camera transformation types
28
- CAMERA_TRANSFORMATIONS = {
29
- "1": "Pan Right",
30
- "2": "Pan Left",
31
- "3": "Tilt Up",
32
- "4": "Tilt Down",
33
- "5": "Zoom In",
34
- "6": "Zoom Out",
35
- "7": "Translate Up (with rotation)",
36
- "8": "Translate Down (with rotation)",
37
- "9": "Arc Left (with rotation)",
38
- "10": "Arc Right (with rotation)"
39
- }
40
-
41
- # Global variables for model
42
- model_manager = None
43
- pipe = None
44
- is_model_loaded = False
45
-
46
- # Define model repositories and files
47
- WAN21_REPO_ID = "Wan-AI/Wan2.1-T2V-1.3B"
48
- WAN21_LOCAL_DIR = f"{MODELS_ROOT_DIR}/Wan-AI/Wan2.1-T2V-1.3B"
49
- WAN21_FILES = [
50
- "diffusion_pytorch_model.safetensors",
51
- "models_t5_umt5-xxl-enc-bf16.pth",
52
- "Wan2.1_VAE.pth"
53
- ]
54
-
55
- # Define tokenizer files to download
56
- UMT5_XXL_TOKENIZER_FILES = [
57
- "google/umt5-xxl/special_tokens_map.json",
58
- "google/umt5-xxl/spiece.model",
59
- "google/umt5-xxl/tokenizer.json",
60
- "google/umt5-xxl/tokenizer_config.json"
61
- ]
62
-
63
- RECAMMASTER_REPO_ID = "KwaiVGI/ReCamMaster-Wan2.1"
64
- RECAMMASTER_CHECKPOINT_FILE = "step20000.ckpt"
65
- RECAMMASTER_LOCAL_DIR = f"{MODELS_ROOT_DIR}/ReCamMaster/checkpoints"
66
-
67
- # Define test data directory
68
- TEST_DATA_DIR = "example_test_data"
69
-
70
- def download_umt5_xxl_tokenizer(progress_callback=None):
71
- """Download UMT5-XXL tokenizer files from HuggingFace"""
72
-
73
- total_files = len(UMT5_XXL_TOKENIZER_FILES)
74
- downloaded_paths = []
75
-
76
- for i, file_path in enumerate(UMT5_XXL_TOKENIZER_FILES):
77
- local_dir = f"{WAN21_LOCAL_DIR}/{os.path.dirname(file_path)}"
78
- filename = os.path.basename(file_path)
79
- full_local_path = f"{WAN21_LOCAL_DIR}/{file_path}"
80
-
81
- # Update progress
82
- if progress_callback:
83
- progress_callback(i/total_files, desc=f"Checking tokenizer file {i+1}/{total_files}: {filename}")
84
-
85
- # Check if already exists
86
- if os.path.exists(full_local_path):
87
- logger.info(f"✓ Tokenizer file {filename} already exists at {full_local_path}")
88
- downloaded_paths.append(full_local_path)
89
- continue
90
-
91
- # Create directory if it doesn't exist
92
- os.makedirs(local_dir, exist_ok=True)
93
-
94
- # Download the file
95
- logger.info(f"Downloading tokenizer file {filename} from {WAN21_REPO_ID}/{file_path}...")
96
-
97
- if progress_callback:
98
- progress_callback(i/total_files, desc=f"Downloading tokenizer file {i+1}/{total_files}: {filename}")
99
-
100
- try:
101
- # Download using huggingface_hub
102
- downloaded_path = hf_hub_download(
103
- repo_id=WAN21_REPO_ID,
104
- filename=file_path,
105
- local_dir=WAN21_LOCAL_DIR,
106
- local_dir_use_symlinks=False
107
- )
108
- logger.info(f"✓ Successfully downloaded tokenizer file {filename} to {downloaded_path}!")
109
- downloaded_paths.append(downloaded_path)
110
- except Exception as e:
111
- logger.error(f"✗ Error downloading tokenizer file {filename}: {e}")
112
- raise
113
-
114
- if progress_callback:
115
- progress_callback(1.0, desc=f"All tokenizer files downloaded successfully!")
116
-
117
- return downloaded_paths
118
-
119
- def download_wan21_models(progress_callback=None):
120
- """Download Wan2.1 model files from HuggingFace"""
121
-
122
- total_files = len(WAN21_FILES)
123
- downloaded_paths = []
124
-
125
- # Create directory if it doesn't exist
126
- Path(WAN21_LOCAL_DIR).mkdir(parents=True, exist_ok=True)
127
-
128
- for i, filename in enumerate(WAN21_FILES):
129
- local_path = Path(WAN21_LOCAL_DIR) / filename
130
-
131
- # Update progress
132
- if progress_callback:
133
- progress_callback(i/total_files, desc=f"Checking Wan2.1 file {i+1}/{total_files}: {filename}")
134
-
135
- # Check if already exists
136
- if local_path.exists():
137
- logger.info(f"✓ {filename} already exists at {local_path}")
138
- downloaded_paths.append(str(local_path))
139
- continue
140
-
141
- # Download the file
142
- logger.info(f"Downloading {filename} from {WAN21_REPO_ID}...")
143
-
144
- if progress_callback:
145
- progress_callback(i/total_files, desc=f"Downloading Wan2.1 file {i+1}/{total_files}: {filename}")
146
-
147
- try:
148
- # Download using huggingface_hub
149
- downloaded_path = hf_hub_download(
150
- repo_id=WAN21_REPO_ID,
151
- filename=filename,
152
- local_dir=WAN21_LOCAL_DIR,
153
- local_dir_use_symlinks=False
154
- )
155
- logger.info(f"✓ Successfully downloaded {filename} to {downloaded_path}!")
156
- downloaded_paths.append(downloaded_path)
157
- except Exception as e:
158
- logger.error(f"✗ Error downloading {filename}: {e}")
159
- raise
160
-
161
- if progress_callback:
162
- progress_callback(1.0, desc=f"All Wan2.1 models downloaded successfully!")
163
-
164
- return downloaded_paths
165
-
166
- def download_recammaster_checkpoint(progress_callback=None):
167
- """Download ReCamMaster checkpoint from HuggingFace using huggingface_hub"""
168
- checkpoint_path = Path(RECAMMASTER_LOCAL_DIR) / RECAMMASTER_CHECKPOINT_FILE
169
-
170
- # Check if already exists
171
- if checkpoint_path.exists():
172
- logger.info(f"✓ ReCamMaster checkpoint already exists at {checkpoint_path}")
173
- return checkpoint_path
174
-
175
- # Create directory if it doesn't exist
176
- Path(RECAMMASTER_LOCAL_DIR).mkdir(parents=True, exist_ok=True)
177
-
178
- # Download the checkpoint
179
- logger.info("Downloading ReCamMaster checkpoint from HuggingFace...")
180
- logger.info(f"Repository: {RECAMMASTER_REPO_ID}")
181
- logger.info(f"File: {RECAMMASTER_CHECKPOINT_FILE}")
182
- logger.info(f"Destination: {checkpoint_path}")
183
-
184
- if progress_callback:
185
- progress_callback(0.0, desc=f"Downloading ReCamMaster checkpoint...")
186
-
187
- try:
188
- # Download using huggingface_hub
189
- downloaded_path = hf_hub_download(
190
- repo_id=RECAMMASTER_REPO_ID,
191
- filename=RECAMMASTER_CHECKPOINT_FILE,
192
- local_dir=RECAMMASTER_LOCAL_DIR,
193
- local_dir_use_symlinks=False
194
- )
195
- logger.info(f"✓ Successfully downloaded ReCamMaster checkpoint to {downloaded_path}!")
196
-
197
- if progress_callback:
198
- progress_callback(1.0, desc=f"ReCamMaster checkpoint downloaded successfully!")
199
-
200
- return downloaded_path
201
- except Exception as e:
202
- logger.error(f"✗ Error downloading checkpoint: {e}")
203
- raise
204
-
205
- def create_test_data_structure(progress_callback=None):
206
- """Create sample camera extrinsics data for testing"""
207
-
208
- if progress_callback:
209
- progress_callback(0.0, desc="Creating test data structure...")
210
-
211
- # Create directories
212
- data_dir = Path(f"{TEST_DATA_DIR}/cameras")
213
- videos_dir = Path(f"{TEST_DATA_DIR}/videos")
214
- data_dir.mkdir(parents=True, exist_ok=True)
215
- videos_dir.mkdir(parents=True, exist_ok=True)
216
-
217
- camera_file = data_dir / "camera_extrinsics.json"
218
-
219
- # Skip if file already exists
220
- if camera_file.exists():
221
- logger.info(f"✓ Camera extrinsics already exist at {camera_file}")
222
-
223
- if progress_callback:
224
- progress_callback(1.0, desc="Test data structure already exists")
225
-
226
- return
227
-
228
- if progress_callback:
229
- progress_callback(0.3, desc="Generating camera extrinsics data...")
230
-
231
- # Generate sample camera data
232
- camera_data = {}
233
-
234
- # Create 81 frames with 10 camera trajectories each
235
- for frame_idx in range(81):
236
- frame_key = f"frame{frame_idx}"
237
- camera_data[frame_key] = {}
238
-
239
- for cam_idx in range(1, 11): # Camera types 1-10
240
- # Create a sample camera matrix (this is just an example - replace with actual logic if needed)
241
- # In reality, these would be calculated based on specific camera movement patterns
242
-
243
- # Create a base identity matrix
244
- base_matrix = np.eye(4)
245
-
246
- # Add some variation based on frame and camera type
247
- # This is a simplistic example - real camera movements would be more complex
248
- if cam_idx == 1: # Pan Right
249
- base_matrix[0, 3] = 0.01 * frame_idx # Move right over time
250
- elif cam_idx == 2: # Pan Left
251
- base_matrix[0, 3] = -0.01 * frame_idx # Move left over time
252
- elif cam_idx == 3: # Tilt Up
253
- # Rotate around X-axis
254
- angle = 0.005 * frame_idx
255
- base_matrix[1, 1] = np.cos(angle)
256
- base_matrix[1, 2] = -np.sin(angle)
257
- base_matrix[2, 1] = np.sin(angle)
258
- base_matrix[2, 2] = np.cos(angle)
259
- elif cam_idx == 4: # Tilt Down
260
- # Rotate around X-axis (opposite direction)
261
- angle = -0.005 * frame_idx
262
- base_matrix[1, 1] = np.cos(angle)
263
- base_matrix[1, 2] = -np.sin(angle)
264
- base_matrix[2, 1] = np.sin(angle)
265
- base_matrix[2, 2] = np.cos(angle)
266
- elif cam_idx == 5: # Zoom In
267
- base_matrix[2, 3] = -0.01 * frame_idx # Move forward over time
268
- elif cam_idx == 6: # Zoom Out
269
- base_matrix[2, 3] = 0.01 * frame_idx # Move backward over time
270
- elif cam_idx == 7: # Translate Up (with rotation)
271
- base_matrix[1, 3] = 0.01 * frame_idx # Move up over time
272
- angle = 0.003 * frame_idx
273
- base_matrix[0, 0] = np.cos(angle)
274
- base_matrix[0, 2] = np.sin(angle)
275
- base_matrix[2, 0] = -np.sin(angle)
276
- base_matrix[2, 2] = np.cos(angle)
277
- elif cam_idx == 8: # Translate Down (with rotation)
278
- base_matrix[1, 3] = -0.01 * frame_idx # Move down over time
279
- angle = -0.003 * frame_idx
280
- base_matrix[0, 0] = np.cos(angle)
281
- base_matrix[0, 2] = np.sin(angle)
282
- base_matrix[2, 0] = -np.sin(angle)
283
- base_matrix[2, 2] = np.cos(angle)
284
- elif cam_idx == 9: # Arc Left (with rotation)
285
- angle = 0.005 * frame_idx
286
- radius = 2.0
287
- base_matrix[0, 3] = -radius * np.sin(angle)
288
- base_matrix[2, 3] = -radius * np.cos(angle) + radius
289
- # Rotate to look at center
290
- look_angle = angle + np.pi
291
- base_matrix[0, 0] = np.cos(look_angle)
292
- base_matrix[0, 2] = np.sin(look_angle)
293
- base_matrix[2, 0] = -np.sin(look_angle)
294
- base_matrix[2, 2] = np.cos(look_angle)
295
- elif cam_idx == 10: # Arc Right (with rotation)
296
- angle = -0.005 * frame_idx
297
- radius = 2.0
298
- base_matrix[0, 3] = -radius * np.sin(angle)
299
- base_matrix[2, 3] = -radius * np.cos(angle) + radius
300
- # Rotate to look at center
301
- look_angle = angle + np.pi
302
- base_matrix[0, 0] = np.cos(look_angle)
303
- base_matrix[0, 2] = np.sin(look_angle)
304
- base_matrix[2, 0] = -np.sin(look_angle)
305
- base_matrix[2, 2] = np.cos(look_angle)
306
-
307
- # Format the matrix as a string (as expected by the app)
308
- matrix_str = ' '.join([' '.join([str(base_matrix[i, j]) for j in range(4)]) for i in range(4)])
309
- matrix_str = '[ ' + matrix_str.replace(' ', ' ] [ ', 3) + ' ]'
310
-
311
- camera_data[frame_key][f"cam{cam_idx:02d}"] = matrix_str
312
-
313
- if progress_callback:
314
- progress_callback(0.7, desc="Saving camera extrinsics data...")
315
-
316
- # Save camera extrinsics to JSON file
317
- with open(camera_file, 'w') as f:
318
- json.dump(camera_data, f, indent=2)
319
-
320
- logger.info(f"Created sample camera extrinsics at {camera_file}")
321
- logger.info(f"Created directory for example videos at {videos_dir}")
322
-
323
- if progress_callback:
324
- progress_callback(1.0, desc="Test data structure created successfully!")
325
-
326
- class Camera(object):
327
- def __init__(self, c2w):
328
- c2w_mat = np.array(c2w).reshape(4, 4)
329
- self.c2w_mat = c2w_mat
330
- self.w2c_mat = np.linalg.inv(c2w_mat)
331
-
332
- def parse_matrix(matrix_str):
333
- """Parse camera matrix string from JSON format"""
334
- rows = matrix_str.strip().split('] [')
335
- matrix = []
336
- for row in rows:
337
- row = row.replace('[', '').replace(']', '')
338
- matrix.append(list(map(float, row.split())))
339
- return np.array(matrix)
340
 
341
- def get_relative_pose(cam_params):
342
- """Calculate relative camera poses"""
343
- abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
344
- abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
345
-
346
- cam_to_origin = 0
347
- target_cam_c2w = np.array([
348
- [1, 0, 0, 0],
349
- [0, 1, 0, -cam_to_origin],
350
- [0, 0, 1, 0],
351
- [0, 0, 0, 1]
352
- ])
353
- abs2rel = target_cam_c2w @ abs_w2cs[0]
354
- ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
355
- ret_poses = np.array(ret_poses, dtype=np.float32)
356
- return ret_poses
357
-
358
- def load_models(progress_callback=None):
359
- """Load the ReCamMaster models"""
360
- global model_manager, pipe, is_model_loaded
361
-
362
- if is_model_loaded:
363
- return "Models already loaded!"
364
-
365
- try:
366
- logger.info("Starting model loading...")
367
-
368
- # First create the test data structure
369
- if progress_callback:
370
- progress_callback(0.05, desc="Setting up test data structure...")
371
-
372
- try:
373
- create_test_data_structure(progress_callback)
374
- except Exception as e:
375
- error_msg = f"Error creating test data structure: {str(e)}"
376
- logger.error(error_msg)
377
- return error_msg
378
-
379
- # Second, ensure the checkpoint is downloaded
380
- if progress_callback:
381
- progress_callback(0.1, desc="Checking for ReCamMaster checkpoint...")
382
-
383
- try:
384
- ckpt_path = download_recammaster_checkpoint(progress_callback)
385
- logger.info(f"Using checkpoint at {ckpt_path}")
386
- except Exception as e:
387
- error_msg = f"Error downloading ReCamMaster checkpoint: {str(e)}"
388
- logger.error(error_msg)
389
- return error_msg
390
-
391
- # Third, download Wan2.1 models if needed
392
- if progress_callback:
393
- progress_callback(0.2, desc="Checking for Wan2.1 models...")
394
-
395
- try:
396
- wan21_paths = download_wan21_models(progress_callback)
397
- logger.info(f"Using Wan2.1 models: {wan21_paths}")
398
- except Exception as e:
399
- error_msg = f"Error downloading Wan2.1 models: {str(e)}"
400
- logger.error(error_msg)
401
- return error_msg
402
-
403
- # Fourth, download UMT5-XXL tokenizer files
404
- if progress_callback:
405
- progress_callback(0.3, desc="Checking for UMT5-XXL tokenizer files...")
406
-
407
- try:
408
- tokenizer_paths = download_umt5_xxl_tokenizer(progress_callback)
409
- logger.info(f"Using UMT5-XXL tokenizer files: {tokenizer_paths}")
410
- except Exception as e:
411
- error_msg = f"Error downloading UMT5-XXL tokenizer files: {str(e)}"
412
- logger.error(error_msg)
413
- return error_msg
414
-
415
- # Now, load the models
416
- if progress_callback:
417
- progress_callback(0.4, desc="Loading model manager...")
418
-
419
- # Create symlink for google/umt5-xxl to handle potential path issues
420
- # Some libraries might look for this in a different way
421
- try:
422
- google_dir = f"{MODELS_ROOT_DIR}/google"
423
- if not os.path.exists(google_dir):
424
- os.makedirs(google_dir, exist_ok=True)
425
-
426
- umt5_xxl_symlink = f"{google_dir}/umt5-xxl"
427
- umt5_xxl_source = f"{WAN21_LOCAL_DIR}/google/umt5-xxl"
428
-
429
- # Create a symlink if it doesn't exist
430
- if not os.path.exists(umt5_xxl_symlink) and os.path.exists(umt5_xxl_source):
431
- if os.name == 'nt': # Windows
432
- import ctypes
433
- kdll = ctypes.windll.LoadLibrary("kernel32.dll")
434
- kdll.CreateSymbolicLinkA(umt5_xxl_symlink.encode(), umt5_xxl_source.encode(), 1)
435
- else: # Unix/Linux
436
- os.symlink(umt5_xxl_source, umt5_xxl_symlink)
437
- logger.info(f"Created symlink from {umt5_xxl_source} to {umt5_xxl_symlink}")
438
- except Exception as e:
439
- logger.warning(f"Could not create symlink for google/umt5-xxl: {str(e)}")
440
- # This is a warning, not an error, as we'll try to proceed anyway
441
-
442
- # Load Wan2.1 pre-trained models
443
- model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
444
-
445
- if progress_callback:
446
- progress_callback(0.5, desc="Loading Wan2.1 models...")
447
-
448
- # Build full paths for the model files
449
- model_files = [f"{WAN21_LOCAL_DIR}/{filename}" for filename in WAN21_FILES]
450
-
451
- for model_file in model_files:
452
- logger.info(f"Loading model from: {model_file}")
453
- if not os.path.exists(model_file):
454
- error_msg = f"Error: Model file not found: {model_file}"
455
- logger.error(error_msg)
456
- return error_msg
457
-
458
- # Set environment variable for transformers to find the tokenizer
459
- os.environ["TRANSFORMERS_CACHE"] = MODELS_ROOT_DIR
460
-
461
- # Set the configuration for the text encoder to use the downloaded tokenizer path
462
- # This is needed because the WanTextEncoder expects the tokenizer to be at this path
463
- os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism warning
464
-
465
- model_manager.load_models(model_files)
466
-
467
- if progress_callback:
468
- progress_callback(0.7, desc="Creating pipeline...")
469
-
470
- pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")
471
-
472
- if progress_callback:
473
- progress_callback(0.8, desc="Initializing ReCamMaster modules...")
474
-
475
- # Initialize additional modules introduced in ReCamMaster
476
- dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
477
- for block in pipe.dit.blocks:
478
- block.cam_encoder = nn.Linear(12, dim)
479
- block.projector = nn.Linear(dim, dim)
480
- block.cam_encoder.weight.data.zero_()
481
- block.cam_encoder.bias.data.zero_()
482
- block.projector.weight = nn.Parameter(torch.eye(dim))
483
- block.projector.bias = nn.Parameter(torch.zeros(dim))
484
-
485
- if progress_callback:
486
- progress_callback(0.9, desc="Loading ReCamMaster checkpoint...")
487
-
488
- # Load ReCamMaster checkpoint
489
- if not os.path.exists(ckpt_path):
490
- error_msg = f"Error: ReCamMaster checkpoint not found at {ckpt_path} even after download attempt."
491
- logger.error(error_msg)
492
- return error_msg
493
-
494
- state_dict = torch.load(ckpt_path, map_location="cpu")
495
- pipe.dit.load_state_dict(state_dict, strict=True)
496
- pipe.to("cuda")
497
- pipe.to(dtype=torch.bfloat16)
498
-
499
- is_model_loaded = True
500
-
501
- if progress_callback:
502
- progress_callback(1.0, desc="Models loaded successfully!")
503
-
504
- logger.info("Models loaded successfully!")
505
- return "Models loaded successfully!"
506
-
507
- except Exception as e:
508
- logger.error(f"Error loading models: {str(e)}")
509
- return f"Error loading models: {str(e)}"
510
 
511
  def extract_frames_from_video(video_path, output_dir, max_frames=81):
512
  """Extract frames from video and ensure we have at least 81 frames"""
@@ -535,93 +53,6 @@ def extract_frames_from_video(video_path, output_dir, max_frames=81):
535
 
536
  return len(frames[:max_frames]), fps
537
 
538
- def process_video_for_recammaster(video_path, text_prompt, cam_type, height=480, width=832):
539
- """Process video through ReCamMaster model"""
540
- global pipe
541
-
542
- # Create frame processor
543
- frame_process = v2.Compose([
544
- v2.CenterCrop(size=(height, width)),
545
- v2.Resize(size=(height, width), antialias=True),
546
- v2.ToTensor(),
547
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
548
- ])
549
-
550
- def crop_and_resize(image):
551
- width_img, height_img = image.size
552
- scale = max(width / width_img, height / height_img)
553
- image = torchvision.transforms.functional.resize(
554
- image,
555
- (round(height_img*scale), round(width_img*scale)),
556
- interpolation=torchvision.transforms.InterpolationMode.BILINEAR
557
- )
558
- return image
559
-
560
- # Load video frames
561
- reader = imageio.get_reader(video_path)
562
- frames = []
563
-
564
- for i in range(81): # ReCamMaster needs exactly 81 frames
565
- try:
566
- frame = reader.get_data(i)
567
- frame = Image.fromarray(frame)
568
- frame = crop_and_resize(frame)
569
- frame = frame_process(frame)
570
- frames.append(frame)
571
- except:
572
- # If we run out of frames, repeat the last one
573
- if frames:
574
- frames.append(frames[-1])
575
- else:
576
- raise ValueError("Video is too short!")
577
-
578
- reader.close()
579
-
580
- frames = torch.stack(frames, dim=0)
581
- frames = rearrange(frames, "T C H W -> C T H W")
582
- video_tensor = frames.unsqueeze(0) # Add batch dimension
583
-
584
- # Load camera trajectory
585
- tgt_camera_path = f"./{TEST_DATA_DIR}/cameras/camera_extrinsics.json"
586
- with open(tgt_camera_path, 'r') as file:
587
- cam_data = json.load(file)
588
-
589
- # Get camera trajectory for selected type
590
- cam_idx = list(range(81))[::4] # Sample every 4 frames
591
- traj = [parse_matrix(cam_data[f"frame{idx}"][f"cam{int(cam_type):02d}"]) for idx in cam_idx]
592
- traj = np.stack(traj).transpose(0, 2, 1)
593
-
594
- c2ws = []
595
- for c2w in traj:
596
- c2w = c2w[:, [1, 2, 0, 3]]
597
- c2w[:3, 1] *= -1.
598
- c2w[:3, 3] /= 100
599
- c2ws.append(c2w)
600
-
601
- tgt_cam_params = [Camera(cam_param) for cam_param in c2ws]
602
- relative_poses = []
603
- for i in range(len(tgt_cam_params)):
604
- relative_pose = get_relative_pose([tgt_cam_params[0], tgt_cam_params[i]])
605
- relative_poses.append(torch.as_tensor(relative_pose)[:,:3,:][1])
606
-
607
- pose_embedding = torch.stack(relative_poses, dim=0) # 21x3x4
608
- pose_embedding = rearrange(pose_embedding, 'b c d -> b (c d)')
609
- camera_tensor = pose_embedding.to(torch.bfloat16).unsqueeze(0) # Add batch dimension
610
-
611
- # Generate video with ReCamMaster
612
- video = pipe(
613
- prompt=[text_prompt],
614
- negative_prompt=["worst quality, low quality, blurry, jittery, distorted"],
615
- source_video=video_tensor,
616
- target_camera=camera_tensor,
617
- cfg_scale=5.0,
618
- num_inference_steps=50,
619
- seed=0,
620
- tiled=True
621
- )
622
-
623
- return video
624
-
625
  def generate_recammaster_video(
626
  video_file,
627
  text_prompt,
@@ -629,11 +60,13 @@ def generate_recammaster_video(
629
  progress=gr.Progress()
630
  ):
631
  """Main function to generate video with ReCamMaster"""
632
- global pipe, is_model_loaded
633
 
634
- if not is_model_loaded:
635
  return None, "Error: Models not loaded! Please load models first."
636
 
 
 
 
637
  if video_file is None:
638
  return None, "Please upload a video file."
639
 
@@ -653,7 +86,7 @@ def generate_recammaster_video(
653
 
654
  # Process with ReCamMaster
655
  progress(0.3, desc="Processing with ReCamMaster...")
656
- output_video = process_video_for_recammaster(
657
  input_video_path,
658
  text_prompt,
659
  camera_type
@@ -662,6 +95,7 @@ def generate_recammaster_video(
662
  # Save output video
663
  progress(0.9, desc="Saving output video...")
664
  output_path = os.path.join(temp_dir, "output.mp4")
 
665
  save_video(output_video, output_path, fps=30, quality=5)
666
 
667
  # Copy to persistent location
@@ -681,22 +115,12 @@ def generate_recammaster_video(
681
 
682
  # Create Gradio interface
683
  with gr.Blocks(title="ReCamMaster Demo") as demo:
684
- # Show loading status
685
- loading_status = gr.Textbox(
686
- label="Model Loading Status",
687
- value="Loading models, please wait...",
688
- interactive=False,
689
- visible=True
690
- )
691
-
692
  gr.Markdown(f"""
693
- # 🎥 ReCamMaster Demo
694
 
695
  ReCamMaster allows you to re-capture videos with novel camera trajectories.
696
  Upload a video and select a camera transformation to see the magic!
697
-
698
- **Note:** All required models will be automatically downloaded to {MODELS_ROOT_DIR} when you start the app.
699
- You can customize this location by setting the RECAMMASTER_MODELS_DIR environment variable.
700
  """)
701
 
702
  with gr.Row():
@@ -738,13 +162,6 @@ with gr.Blocks(title="ReCamMaster Demo") as demo:
738
  inputs=[video_input, text_prompt, camera_type],
739
  )
740
 
741
- # Load models automatically when the interface loads
742
- def on_load():
743
- status = load_models()
744
- return gr.update(value=status, visible=True if "Error" in status else False)
745
-
746
- demo.load(on_load, outputs=[loading_status])
747
-
748
  # Event handlers
749
  generate_btn.click(
750
  fn=generate_recammaster_video,
@@ -753,4 +170,5 @@ with gr.Blocks(title="ReCamMaster Demo") as demo:
753
  )
754
 
755
  if __name__ == "__main__":
 
756
  demo.launch(share=True)
 
1
  import gradio as gr
2
  import torch
 
3
  import os
4
  import tempfile
5
  import shutil
6
  import imageio
 
 
 
 
 
 
 
 
7
  import logging
8
  from pathlib import Path
9
+
10
+ # Import from our modules
11
+ from model_loader import ModelLoader, MODELS_ROOT_DIR
12
+ from video_processor import VideoProcessor
13
+ from config import CAMERA_TRANSFORMATIONS, TEST_DATA_DIR
14
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
+ # Global model loader instance
19
+ model_loader = ModelLoader()
20
+ video_processor = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
+ def init_video_processor():
23
+ """Initialize video processor"""
24
+ global video_processor
25
+ if model_loader.is_loaded and video_processor is None:
26
+ video_processor = VideoProcessor(model_loader.pipe)
27
+ return video_processor is not None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  def extract_frames_from_video(video_path, output_dir, max_frames=81):
30
  """Extract frames from video and ensure we have at least 81 frames"""
 
53
 
54
  return len(frames[:max_frames]), fps
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def generate_recammaster_video(
57
  video_file,
58
  text_prompt,
 
60
  progress=gr.Progress()
61
  ):
62
  """Main function to generate video with ReCamMaster"""
 
63
 
64
+ if not model_loader.is_loaded:
65
  return None, "Error: Models not loaded! Please load models first."
66
 
67
+ if not init_video_processor():
68
+ return None, "Error: Failed to initialize video processor."
69
+
70
  if video_file is None:
71
  return None, "Please upload a video file."
72
 
 
86
 
87
  # Process with ReCamMaster
88
  progress(0.3, desc="Processing with ReCamMaster...")
89
+ output_video = video_processor.process_video(
90
  input_video_path,
91
  text_prompt,
92
  camera_type
 
95
  # Save output video
96
  progress(0.9, desc="Saving output video...")
97
  output_path = os.path.join(temp_dir, "output.mp4")
98
+ from diffsynth import save_video
99
  save_video(output_video, output_path, fps=30, quality=5)
100
 
101
  # Copy to persistent location
 
115
 
116
  # Create Gradio interface
117
  with gr.Blocks(title="ReCamMaster Demo") as demo:
118
+
 
 
 
 
 
 
 
119
  gr.Markdown(f"""
120
+ # 🎥 ReCamMaster
121
 
122
  ReCamMaster allows you to re-capture videos with novel camera trajectories.
123
  Upload a video and select a camera transformation to see the magic!
 
 
 
124
  """)
125
 
126
  with gr.Row():
 
162
  inputs=[video_input, text_prompt, camera_type],
163
  )
164
 
 
 
 
 
 
 
 
165
  # Event handlers
166
  generate_btn.click(
167
  fn=generate_recammaster_video,
 
170
  )
171
 
172
  if __name__ == "__main__":
173
+ model_loader.load_models()
174
  demo.launch(share=True)