"""Download Wan2.1 / ReCamMaster weights from HuggingFace and build the ReCamMaster pipeline."""
import os
import torch
import torch.nn as nn
import logging
from pathlib import Path
from huggingface_hub import hf_hub_download
from diffsynth import ModelManager, WanVideoReCamMasterPipeline

logger = logging.getLogger(__name__)

# Get model storage path from environment variable or use default
MODELS_ROOT_DIR = os.environ.get("RECAMMASTER_MODELS_DIR", "/data/models")
logger.info(f"Using models root directory: {MODELS_ROOT_DIR}")

# Define model repositories and files
WAN21_REPO_ID = "Wan-AI/Wan2.1-T2V-1.3B"
WAN21_LOCAL_DIR = f"{MODELS_ROOT_DIR}/Wan-AI/Wan2.1-T2V-1.3B"
WAN21_FILES = [
    "diffusion_pytorch_model.safetensors",
    "models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.1_VAE.pth",
]

# Tokenizer files to download; paths are relative to the root of WAN21_REPO_ID
UMT5_XXL_TOKENIZER_FILES = [
    "google/umt5-xxl/special_tokens_map.json",
    "google/umt5-xxl/spiece.model",
    "google/umt5-xxl/tokenizer.json",
    "google/umt5-xxl/tokenizer_config.json",
]

RECAMMASTER_REPO_ID = "KwaiVGI/ReCamMaster-Wan2.1"
RECAMMASTER_CHECKPOINT_FILE = "step20000.ckpt"
RECAMMASTER_LOCAL_DIR = f"{MODELS_ROOT_DIR}/ReCamMaster/checkpoints"


class ModelLoader:
    """Fetch the required weights on demand and build the ReCamMaster pipeline.

    Attributes:
        model_manager: diffsynth ``ModelManager`` holding the Wan2.1 weights,
            or ``None`` until ``load_models`` runs.
        pipe: the ``WanVideoReCamMasterPipeline``, or ``None`` until loaded.
        is_loaded: ``True`` once ``load_models`` has completed successfully.
    """

    def __init__(self):
        self.model_manager = None
        self.pipe = None
        self.is_loaded = False

    @staticmethod
    def _scaled_progress(progress_callback, start, end):
        """Wrap *progress_callback* so a sub-task's 0..1 progress fills [start, end].

        Returns ``None`` when *progress_callback* is ``None``, so callers can
        keep their ``if progress_callback:`` guards unchanged.
        """
        if progress_callback is None:
            return None

        def scaled(fraction, *args, **kwargs):
            # Only rescale plain numbers; pass anything else through untouched.
            if isinstance(fraction, (int, float)) and not isinstance(fraction, bool):
                fraction = start + (end - start) * fraction
            return progress_callback(fraction, *args, **kwargs)

        return scaled

    def download_umt5_xxl_tokenizer(self, progress_callback=None):
        """Download UMT5-XXL tokenizer files from HuggingFace.

        Files already present under ``WAN21_LOCAL_DIR`` are not downloaded again.

        Args:
            progress_callback: optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with ``fraction`` in [0, 1].

        Returns:
            list[str]: local paths of the tokenizer files, in
            ``UMT5_XXL_TOKENIZER_FILES`` order.

        Raises:
            Exception: whatever ``hf_hub_download`` raises (logged, then re-raised).
        """
        total_files = len(UMT5_XXL_TOKENIZER_FILES)
        downloaded_paths = []

        for i, file_path in enumerate(UMT5_XXL_TOKENIZER_FILES):
            local_dir = f"{WAN21_LOCAL_DIR}/{os.path.dirname(file_path)}"
            filename = os.path.basename(file_path)
            full_local_path = f"{WAN21_LOCAL_DIR}/{file_path}"

            if progress_callback:
                progress_callback(i / total_files, desc=f"Checking tokenizer file {i + 1}/{total_files}: {filename}")

            # Skip files that are already on disk
            if os.path.exists(full_local_path):
                logger.info(f"✓ Tokenizer file {filename} already exists at {full_local_path}")
                downloaded_paths.append(full_local_path)
                continue

            os.makedirs(local_dir, exist_ok=True)

            logger.info(f"Downloading tokenizer file {filename} from {WAN21_REPO_ID}/{file_path}...")
            if progress_callback:
                progress_callback(i / total_files, desc=f"Downloading tokenizer file {i + 1}/{total_files}: {filename}")

            try:
                # filename is repo-relative, so the file lands at WAN21_LOCAL_DIR/<file_path>
                downloaded_path = hf_hub_download(
                    repo_id=WAN21_REPO_ID,
                    filename=file_path,
                    local_dir=WAN21_LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                logger.error(f"✗ Error downloading tokenizer file {filename}: {e}")
                raise

            logger.info(f"✓ Successfully downloaded tokenizer file {filename} to {downloaded_path}!")
            downloaded_paths.append(downloaded_path)

        if progress_callback:
            progress_callback(1.0, desc="All tokenizer files downloaded successfully!")

        return downloaded_paths

    def download_wan21_models(self, progress_callback=None):
        """Download the Wan2.1 model files listed in ``WAN21_FILES`` from HuggingFace.

        Files already present in ``WAN21_LOCAL_DIR`` are not downloaded again.

        Args:
            progress_callback: optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with ``fraction`` in [0, 1].

        Returns:
            list[str]: local paths of the model files, in ``WAN21_FILES`` order.

        Raises:
            Exception: whatever ``hf_hub_download`` raises (logged, then re-raised).
        """
        total_files = len(WAN21_FILES)
        downloaded_paths = []

        Path(WAN21_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        for i, filename in enumerate(WAN21_FILES):
            local_path = Path(WAN21_LOCAL_DIR) / filename

            if progress_callback:
                progress_callback(i / total_files, desc=f"Checking Wan2.1 file {i + 1}/{total_files}: {filename}")

            # Skip files that are already on disk
            if local_path.exists():
                logger.info(f"✓ {filename} already exists at {local_path}")
                downloaded_paths.append(str(local_path))
                continue

            logger.info(f"Downloading {filename} from {WAN21_REPO_ID}...")
            if progress_callback:
                progress_callback(i / total_files, desc=f"Downloading Wan2.1 file {i + 1}/{total_files}: {filename}")

            try:
                downloaded_path = hf_hub_download(
                    repo_id=WAN21_REPO_ID,
                    filename=filename,
                    local_dir=WAN21_LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                logger.error(f"✗ Error downloading {filename}: {e}")
                raise

            logger.info(f"✓ Successfully downloaded {filename} to {downloaded_path}!")
            downloaded_paths.append(downloaded_path)

        if progress_callback:
            progress_callback(1.0, desc="All Wan2.1 models downloaded successfully!")

        return downloaded_paths

    def download_recammaster_checkpoint(self, progress_callback=None):
        """Download the ReCamMaster checkpoint from HuggingFace if it is not cached.

        Args:
            progress_callback: optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with ``fraction`` in [0, 1].

        Returns:
            str: local path of the checkpoint. This is always a ``str``, whether
            the file was already cached or freshly downloaded.

        Raises:
            Exception: whatever ``hf_hub_download`` raises (logged, then re-raised).
        """
        checkpoint_path = Path(RECAMMASTER_LOCAL_DIR) / RECAMMASTER_CHECKPOINT_FILE

        if checkpoint_path.exists():
            logger.info(f"✓ ReCamMaster checkpoint already exists at {checkpoint_path}")
            return str(checkpoint_path)

        Path(RECAMMASTER_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        logger.info("Downloading ReCamMaster checkpoint from HuggingFace...")
        logger.info(f"Repository: {RECAMMASTER_REPO_ID}")
        logger.info(f"File: {RECAMMASTER_CHECKPOINT_FILE}")
        logger.info(f"Destination: {checkpoint_path}")

        if progress_callback:
            progress_callback(0.0, desc="Downloading ReCamMaster checkpoint...")

        try:
            downloaded_path = hf_hub_download(
                repo_id=RECAMMASTER_REPO_ID,
                filename=RECAMMASTER_CHECKPOINT_FILE,
                local_dir=RECAMMASTER_LOCAL_DIR,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            logger.error(f"✗ Error downloading checkpoint: {e}")
            raise

        logger.info(f"✓ Successfully downloaded ReCamMaster checkpoint to {downloaded_path}!")
        if progress_callback:
            progress_callback(1.0, desc="ReCamMaster checkpoint downloaded successfully!")
        return downloaded_path

    def create_symlink_for_tokenizer(self):
        """Link ``MODELS_ROOT_DIR/google/umt5-xxl`` to the downloaded tokenizer directory.

        This is best-effort. Any failure (e.g. missing symlink privilege on
        Windows) is logged as a warning and loading continues.
        """
        try:
            google_dir = f"{MODELS_ROOT_DIR}/google"
            os.makedirs(google_dir, exist_ok=True)

            umt5_xxl_symlink = f"{google_dir}/umt5-xxl"
            umt5_xxl_source = f"{WAN21_LOCAL_DIR}/google/umt5-xxl"

            if not os.path.exists(umt5_xxl_symlink) and os.path.exists(umt5_xxl_source):
                # os.symlink works on both POSIX and Windows. Unlike the raw
                # CreateSymbolicLinkA call it raises on failure, so the success
                # log below is only reached when the link really exists.
                # target_is_directory matters only on Windows.
                os.symlink(umt5_xxl_source, umt5_xxl_symlink, target_is_directory=True)
                logger.info(f"Created symlink from {umt5_xxl_source} to {umt5_xxl_symlink}")
        except Exception as e:
            # This is a warning, not an error, as we'll try to proceed anyway
            logger.warning(f"Could not create symlink for google/umt5-xxl: {str(e)}")

    def load_models(self, progress_callback=None):
        """Download (if needed) and load the ReCamMaster models onto CUDA.

        Args:
            progress_callback: optional callable invoked as
                ``progress_callback(fraction, desc=...)``. The overall progress
                increases monotonically from 0 to 1 across all stages.

        Returns:
            str: a human-readable status message. Failures are reported through
            this string (and the log) rather than raised.
        """
        if self.is_loaded:
            return "Models already loaded!"

        try:
            logger.info("Starting model loading...")

            # Import test data creator
            from test_data import create_test_data_structure

            # 1. Test data structure
            if progress_callback:
                progress_callback(0.05, desc="Setting up test data structure...")
            try:
                # Passed through unwrapped: how this helper uses the callback
                # is not visible here.
                create_test_data_structure(progress_callback)
            except Exception as e:
                error_msg = f"Error creating test data structure: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # 2. ReCamMaster checkpoint (sub-progress mapped into [0.1, 0.2])
            if progress_callback:
                progress_callback(0.1, desc="Checking for ReCamMaster checkpoint...")
            try:
                ckpt_path = self.download_recammaster_checkpoint(
                    self._scaled_progress(progress_callback, 0.1, 0.2)
                )
                logger.info(f"Using checkpoint at {ckpt_path}")
            except Exception as e:
                error_msg = f"Error downloading ReCamMaster checkpoint: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # 3. Wan2.1 base models (sub-progress mapped into [0.2, 0.3])
            if progress_callback:
                progress_callback(0.2, desc="Checking for Wan2.1 models...")
            try:
                wan21_paths = self.download_wan21_models(
                    self._scaled_progress(progress_callback, 0.2, 0.3)
                )
                logger.info(f"Using Wan2.1 models: {wan21_paths}")
            except Exception as e:
                error_msg = f"Error downloading Wan2.1 models: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # 4. UMT5-XXL tokenizer (sub-progress mapped into [0.3, 0.4])
            if progress_callback:
                progress_callback(0.3, desc="Checking for UMT5-XXL tokenizer files...")
            try:
                tokenizer_paths = self.download_umt5_xxl_tokenizer(
                    self._scaled_progress(progress_callback, 0.3, 0.4)
                )
                logger.info(f"Using UMT5-XXL tokenizer files: {tokenizer_paths}")
            except Exception as e:
                error_msg = f"Error downloading UMT5-XXL tokenizer files: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # Now, load the models
            if progress_callback:
                progress_callback(0.4, desc="Loading model manager...")

            self.create_symlink_for_tokenizer()

            # Load Wan2.1 pre-trained weights on CPU in bf16 first
            self.model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")

            if progress_callback:
                progress_callback(0.5, desc="Loading Wan2.1 models...")

            model_files = [f"{WAN21_LOCAL_DIR}/{filename}" for filename in WAN21_FILES]
            for model_file in model_files:
                logger.info(f"Loading model from: {model_file}")
                if not os.path.exists(model_file):
                    error_msg = f"Error: Model file not found: {model_file}"
                    logger.error(error_msg)
                    return error_msg

            # Set environment variable for transformers to find the tokenizer.
            # NOTE(review): transformers may read TRANSFORMERS_CACHE at import
            # time, in which case setting it here has no effect — confirm.
            os.environ["TRANSFORMERS_CACHE"] = MODELS_ROOT_DIR
            os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Disable tokenizers parallelism warning

            self.model_manager.load_models(model_files)

            if progress_callback:
                progress_callback(0.7, desc="Creating pipeline...")

            self.pipe = WanVideoReCamMasterPipeline.from_model_manager(self.model_manager, device="cuda")

            if progress_callback:
                progress_callback(0.8, desc="Initializing ReCamMaster modules...")

            # Add the extra per-block modules introduced by ReCamMaster.
            # cam_encoder starts at zero and projector at identity, so before the
            # checkpoint is loaded they do not alter the base model's behavior.
            # The 12-dim input is presumably a flattened 3x4 camera matrix — TODO confirm.
            dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
            for block in self.pipe.dit.blocks:
                block.cam_encoder = nn.Linear(12, dim)
                block.projector = nn.Linear(dim, dim)
                block.cam_encoder.weight.data.zero_()
                block.cam_encoder.bias.data.zero_()
                block.projector.weight = nn.Parameter(torch.eye(dim))
                block.projector.bias = nn.Parameter(torch.zeros(dim))

            if progress_callback:
                progress_callback(0.9, desc="Loading ReCamMaster checkpoint...")

            if not os.path.exists(ckpt_path):
                error_msg = f"Error: ReCamMaster checkpoint not found at {ckpt_path} even after download attempt."
                logger.error(error_msg)
                return error_msg

            # SECURITY NOTE(review): torch.load unpickles a file downloaded from
            # the network. Since strict=True below implies a plain tensor
            # state_dict, consider passing weights_only=True (torch >= 1.13).
            state_dict = torch.load(ckpt_path, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=True)
            self.pipe.to("cuda")
            self.pipe.to(dtype=torch.bfloat16)

            self.is_loaded = True

            if progress_callback:
                progress_callback(1.0, desc="Models loaded successfully!")

            logger.info("Models loaded successfully!")
            return "Models loaded successfully!"

        except Exception as e:
            # Top-level boundary: keep the traceback in the log, return a message to the UI
            logger.exception(f"Error loading models: {str(e)}")
            return f"Error loading models: {str(e)}"