# ReCamMaster / model_loader.py
# Uploaded by jbilcke-hf (HF Staff) in "Upload 5 files", commit 099dc67 (verified), 14.1 kB.
import os
import torch
import torch.nn as nn
import logging
from pathlib import Path
from huggingface_hub import hf_hub_download
from diffsynth import ModelManager, WanVideoReCamMasterPipeline
logger = logging.getLogger(__name__)

# Root directory under which every model asset is stored. It can be overridden
# through the RECAMMASTER_MODELS_DIR environment variable.
MODELS_ROOT_DIR = os.getenv("RECAMMASTER_MODELS_DIR", "/data/models")
logger.info(f"Using models root directory: {MODELS_ROOT_DIR}")

# Wan2.1 base model: Hub repository, local mirror directory (root + repo id)
# and the weight files fetched from the repository root.
WAN21_REPO_ID = "Wan-AI/Wan2.1-T2V-1.3B"
WAN21_LOCAL_DIR = f"{MODELS_ROOT_DIR}/{WAN21_REPO_ID}"
WAN21_FILES = [
    "diffusion_pytorch_model.safetensors",
    "models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.1_VAE.pth",
]

# UMT5-XXL tokenizer files, given as paths relative to the Wan2.1 repository
# (they are downloaded from WAN21_REPO_ID, below its google/umt5-xxl folder).
UMT5_XXL_TOKENIZER_FILES = [
    f"google/umt5-xxl/{tokenizer_file}"
    for tokenizer_file in (
        "special_tokens_map.json",
        "spiece.model",
        "tokenizer.json",
        "tokenizer_config.json",
    )
]

# ReCamMaster checkpoint: Hub repository, file name and local directory.
RECAMMASTER_REPO_ID = "KwaiVGI/ReCamMaster-Wan2.1"
RECAMMASTER_CHECKPOINT_FILE = "step20000.ckpt"
RECAMMASTER_LOCAL_DIR = f"{MODELS_ROOT_DIR}/ReCamMaster/checkpoints"
class ModelLoader:
    """Download the ReCamMaster assets and assemble the inference pipeline.

    Instance attributes (all populated by :meth:`load_models`):

    * ``model_manager`` -- the ``ModelManager`` holding the Wan2.1 weights,
      or ``None`` while nothing is loaded.
    * ``pipe`` -- the ``WanVideoReCamMasterPipeline`` built from the manager,
      or ``None`` while nothing is loaded.
    * ``is_loaded`` -- ``True`` only after a fully successful load.
    """

    def __init__(self):
        self.model_manager = None
        self.pipe = None
        self.is_loaded = False

    @staticmethod
    def _notify(progress_callback, fraction, desc):
        """Forward a progress update when a (truthy) callback was supplied."""
        if progress_callback:
            progress_callback(fraction, desc=desc)

    def _abort_load(self, error_msg, exc_info=False):
        """Log ``error_msg``, drop any partially built state and return the message.

        Clearing ``model_manager`` and ``pipe`` lets a failed attempt release
        the objects it created instead of keeping them referenced until the
        next call.
        """
        logger.error(error_msg, exc_info=exc_info)
        self.model_manager = None
        self.pipe = None
        self.is_loaded = False
        return error_msg

    def download_umt5_xxl_tokenizer(self, progress_callback=None):
        """Download the UMT5-XXL tokenizer files from HuggingFace.

        Files listed in ``UMT5_XXL_TOKENIZER_FILES`` that already exist below
        ``WAN21_LOCAL_DIR`` are skipped; the others are fetched from
        ``WAN21_REPO_ID``.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, desc=message)``.

        Returns:
            A list with the local path (``str``) of every tokenizer file, in
            the order of ``UMT5_XXL_TOKENIZER_FILES``.

        Raises:
            Exception: Whatever ``hf_hub_download`` raises is logged (with
                traceback) and re-raised.
        """
        total_files = len(UMT5_XXL_TOKENIZER_FILES)
        downloaded_paths = []

        for i, file_path in enumerate(UMT5_XXL_TOKENIZER_FILES):
            local_dir = f"{WAN21_LOCAL_DIR}/{os.path.dirname(file_path)}"
            filename = os.path.basename(file_path)
            full_local_path = f"{WAN21_LOCAL_DIR}/{file_path}"

            self._notify(
                progress_callback,
                i / total_files,
                f"Checking tokenizer file {i+1}/{total_files}: {filename}",
            )

            # Nothing to do when the file is already on disk.
            if os.path.exists(full_local_path):
                logger.info(f"✓ Tokenizer file {filename} already exists at {full_local_path}")
                downloaded_paths.append(full_local_path)
                continue

            os.makedirs(local_dir, exist_ok=True)

            logger.info(f"Downloading tokenizer file {filename} from {WAN21_REPO_ID}/{file_path}...")
            self._notify(
                progress_callback,
                i / total_files,
                f"Downloading tokenizer file {i+1}/{total_files}: {filename}",
            )

            try:
                # file_path is relative to the repository root, so the file
                # lands at WAN21_LOCAL_DIR/<file_path>.
                downloaded_path = hf_hub_download(
                    repo_id=WAN21_REPO_ID,
                    filename=file_path,
                    local_dir=WAN21_LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                logger.exception(f"✗ Error downloading tokenizer file {filename}: {e}")
                raise

            logger.info(f"✓ Successfully downloaded tokenizer file {filename} to {downloaded_path}!")
            downloaded_paths.append(downloaded_path)

        self._notify(progress_callback, 1.0, "All tokenizer files downloaded successfully!")
        return downloaded_paths

    def download_wan21_models(self, progress_callback=None):
        """Download the Wan2.1 model files from HuggingFace.

        Files listed in ``WAN21_FILES`` that already exist in
        ``WAN21_LOCAL_DIR`` are skipped; the others are fetched from
        ``WAN21_REPO_ID``.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, desc=message)``.

        Returns:
            A list with the local path (``str``) of every model file, in the
            order of ``WAN21_FILES``.

        Raises:
            Exception: Whatever ``hf_hub_download`` raises is logged (with
                traceback) and re-raised.
        """
        total_files = len(WAN21_FILES)
        downloaded_paths = []

        Path(WAN21_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        for i, filename in enumerate(WAN21_FILES):
            local_path = Path(WAN21_LOCAL_DIR) / filename

            self._notify(
                progress_callback,
                i / total_files,
                f"Checking Wan2.1 file {i+1}/{total_files}: {filename}",
            )

            # Nothing to do when the file is already on disk.
            if local_path.exists():
                logger.info(f"✓ {filename} already exists at {local_path}")
                downloaded_paths.append(str(local_path))
                continue

            logger.info(f"Downloading {filename} from {WAN21_REPO_ID}...")
            self._notify(
                progress_callback,
                i / total_files,
                f"Downloading Wan2.1 file {i+1}/{total_files}: {filename}",
            )

            try:
                downloaded_path = hf_hub_download(
                    repo_id=WAN21_REPO_ID,
                    filename=filename,
                    local_dir=WAN21_LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
            except Exception as e:
                logger.exception(f"✗ Error downloading {filename}: {e}")
                raise

            logger.info(f"✓ Successfully downloaded {filename} to {downloaded_path}!")
            downloaded_paths.append(downloaded_path)

        self._notify(progress_callback, 1.0, "All Wan2.1 models downloaded successfully!")
        return downloaded_paths

    def download_recammaster_checkpoint(self, progress_callback=None):
        """Download the ReCamMaster checkpoint from HuggingFace.

        The download is skipped when ``RECAMMASTER_CHECKPOINT_FILE`` already
        exists in ``RECAMMASTER_LOCAL_DIR``.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, desc=message)``; it is only
                called when a download actually takes place.

        Returns:
            The local checkpoint location as a ``pathlib.Path``, both for an
            existing file and for a fresh download.

        Raises:
            Exception: Whatever ``hf_hub_download`` raises is logged (with
                traceback) and re-raised.
        """
        checkpoint_path = Path(RECAMMASTER_LOCAL_DIR) / RECAMMASTER_CHECKPOINT_FILE

        if checkpoint_path.exists():
            logger.info(f"✓ ReCamMaster checkpoint already exists at {checkpoint_path}")
            return checkpoint_path

        Path(RECAMMASTER_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        logger.info("Downloading ReCamMaster checkpoint from HuggingFace...")
        logger.info(f"Repository: {RECAMMASTER_REPO_ID}")
        logger.info(f"File: {RECAMMASTER_CHECKPOINT_FILE}")
        logger.info(f"Destination: {checkpoint_path}")

        self._notify(progress_callback, 0.0, "Downloading ReCamMaster checkpoint...")

        try:
            downloaded_path = hf_hub_download(
                repo_id=RECAMMASTER_REPO_ID,
                filename=RECAMMASTER_CHECKPOINT_FILE,
                local_dir=RECAMMASTER_LOCAL_DIR,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            logger.exception(f"✗ Error downloading checkpoint: {e}")
            raise

        logger.info(f"✓ Successfully downloaded ReCamMaster checkpoint to {downloaded_path}!")
        self._notify(progress_callback, 1.0, "ReCamMaster checkpoint downloaded successfully!")

        # hf_hub_download returns a str; wrap it so both code paths of this
        # method hand back the same type.
        return Path(downloaded_path)

    def create_symlink_for_tokenizer(self):
        """Create symlink for google/umt5-xxl to handle potential path issues.

        Links ``MODELS_ROOT_DIR/google/umt5-xxl`` to the tokenizer directory
        inside ``WAN21_LOCAL_DIR`` when the link does not exist yet and the
        source directory does. This is best effort: any failure is logged as
        a warning and never raised.
        """
        try:
            google_dir = f"{MODELS_ROOT_DIR}/google"
            os.makedirs(google_dir, exist_ok=True)

            umt5_xxl_symlink = f"{google_dir}/umt5-xxl"
            umt5_xxl_source = f"{WAN21_LOCAL_DIR}/google/umt5-xxl"

            if not os.path.exists(umt5_xxl_symlink) and os.path.exists(umt5_xxl_source):
                # os.symlink works on every platform and raises on failure,
                # so the success message below is only logged when the link
                # really exists. target_is_directory only matters on Windows.
                os.symlink(umt5_xxl_source, umt5_xxl_symlink, target_is_directory=True)
                logger.info(f"Created symlink from {umt5_xxl_source} to {umt5_xxl_symlink}")
        except Exception as e:
            # This is a warning, not an error, as we'll try to proceed anyway.
            logger.warning(f"Could not create symlink for google/umt5-xxl: {str(e)}")

    def load_models(self, progress_callback=None):
        """Load the ReCamMaster models.

        Steps: create the test data structure, make sure the checkpoint,
        Wan2.1 weights and tokenizer files are on disk, load the Wan2.1
        weights on CPU, build the pipeline on CUDA, add the ReCamMaster
        modules to every DiT block, load the checkpoint and move the
        pipeline to CUDA in bfloat16.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, desc=message)``; it is also
                passed on to the download helpers.

        Returns:
            A human-readable status string. Failures deriving from
            ``Exception`` are not raised; they are logged and reported
            through the returned message, and any partially built state is
            discarded.
        """
        if self.is_loaded:
            return "Models already loaded!"

        try:
            logger.info("Starting model loading...")

            # Imported here (inside the try) so that an import failure is
            # reported through the returned status message as well.
            from test_data import create_test_data_structure

            # First create the test data structure.
            self._notify(progress_callback, 0.05, "Setting up test data structure...")
            try:
                create_test_data_structure(progress_callback)
            except Exception as e:
                return self._abort_load(
                    f"Error creating test data structure: {str(e)}", exc_info=True
                )

            # Second, ensure the checkpoint is downloaded. The download
            # helpers already log the traceback before re-raising.
            self._notify(progress_callback, 0.1, "Checking for ReCamMaster checkpoint...")
            try:
                ckpt_path = self.download_recammaster_checkpoint(progress_callback)
                logger.info(f"Using checkpoint at {ckpt_path}")
            except Exception as e:
                return self._abort_load(f"Error downloading ReCamMaster checkpoint: {str(e)}")

            # Third, download Wan2.1 models if needed.
            self._notify(progress_callback, 0.2, "Checking for Wan2.1 models...")
            try:
                wan21_paths = self.download_wan21_models(progress_callback)
                logger.info(f"Using Wan2.1 models: {wan21_paths}")
            except Exception as e:
                return self._abort_load(f"Error downloading Wan2.1 models: {str(e)}")

            # Fourth, download UMT5-XXL tokenizer files.
            self._notify(progress_callback, 0.3, "Checking for UMT5-XXL tokenizer files...")
            try:
                tokenizer_paths = self.download_umt5_xxl_tokenizer(progress_callback)
                logger.info(f"Using UMT5-XXL tokenizer files: {tokenizer_paths}")
            except Exception as e:
                return self._abort_load(f"Error downloading UMT5-XXL tokenizer files: {str(e)}")

            # Now, load the models.
            self._notify(progress_callback, 0.4, "Loading model manager...")
            self.create_symlink_for_tokenizer()

            # Wan2.1 pre-trained weights are first loaded on the CPU.
            self.model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
            self._notify(progress_callback, 0.5, "Loading Wan2.1 models...")

            model_files = [f"{WAN21_LOCAL_DIR}/{filename}" for filename in WAN21_FILES]
            for model_file in model_files:
                logger.info(f"Loading model from: {model_file}")
                if not os.path.exists(model_file):
                    return self._abort_load(f"Error: Model file not found: {model_file}")

            # Set environment variable for transformers to find the tokenizer.
            os.environ["TRANSFORMERS_CACHE"] = MODELS_ROOT_DIR
            # Disable the tokenizers parallelism warning.
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

            self.model_manager.load_models(model_files)

            self._notify(progress_callback, 0.7, "Creating pipeline...")
            self.pipe = WanVideoReCamMasterPipeline.from_model_manager(
                self.model_manager, device="cuda"
            )

            self._notify(progress_callback, 0.8, "Initializing ReCamMaster modules...")

            # Additional modules introduced in ReCamMaster: every block gets a
            # camera encoder initialised to zero and a projector initialised
            # to the identity. The strict checkpoint load below requires the
            # checkpoint to provide weights for them.
            # NOTE(review): the 12 input features are presumably a flattened
            # 3x4 camera pose -- confirm against the pipeline code.
            dim = self.pipe.dit.blocks[0].self_attn.q.weight.shape[0]
            for block in self.pipe.dit.blocks:
                block.cam_encoder = nn.Linear(12, dim)
                block.projector = nn.Linear(dim, dim)
                block.cam_encoder.weight.data.zero_()
                block.cam_encoder.bias.data.zero_()
                block.projector.weight = nn.Parameter(torch.eye(dim))
                block.projector.bias = nn.Parameter(torch.zeros(dim))

            self._notify(progress_callback, 0.9, "Loading ReCamMaster checkpoint...")

            if not os.path.exists(ckpt_path):
                return self._abort_load(
                    f"Error: ReCamMaster checkpoint not found at {ckpt_path} even after download attempt."
                )

            # SECURITY NOTE: torch.load unpickles the file. The checkpoint is
            # downloaded from RECAMMASTER_REPO_ID, so only point that constant
            # at a trusted repository (or switch to weights_only=True once the
            # checkpoint is confirmed to contain tensors only).
            state_dict = torch.load(ckpt_path, map_location="cpu")
            self.pipe.dit.load_state_dict(state_dict, strict=True)
            self.pipe.to("cuda")
            self.pipe.to(dtype=torch.bfloat16)

            self.is_loaded = True
            self._notify(progress_callback, 1.0, "Models loaded successfully!")
            logger.info("Models loaded successfully!")
            return "Models loaded successfully!"

        except Exception as e:
            # Top-level boundary: report the failure as a status string, keep
            # the traceback in the log and release partially built objects.
            return self._abort_load(f"Error loading models: {str(e)}", exc_info=True)