Spaces:
Running
on
L40S
Running
on
L40S
import os | |
import torch | |
import torch.nn as nn | |
import logging | |
from pathlib import Path | |
from huggingface_hub import hf_hub_download | |
from diffsynth import ModelManager, WanVideoReCamMasterPipeline | |
logger = logging.getLogger(__name__)

# Root directory under which every model file is stored. Override it with the
# RECAMMASTER_MODELS_DIR environment variable; defaults to /data/models.
MODELS_ROOT_DIR = os.environ.get("RECAMMASTER_MODELS_DIR", "/data/models")
logger.info(f"Using models root directory: {MODELS_ROOT_DIR}")

# Wan2.1 base model: Hub repository, local target directory and the files
# fetched from that repository.
WAN21_REPO_ID = "Wan-AI/Wan2.1-T2V-1.3B"
WAN21_LOCAL_DIR = f"{MODELS_ROOT_DIR}/Wan-AI/Wan2.1-T2V-1.3B"
WAN21_FILES = [
    "diffusion_pytorch_model.safetensors",
    "models_t5_umt5-xxl-enc-bf16.pth",
    "Wan2.1_VAE.pth",
]

# Tokenizer files, given as paths relative to the root of WAN21_REPO_ID; they
# are downloaded into the same relative locations below WAN21_LOCAL_DIR.
UMT5_XXL_TOKENIZER_FILES = [
    "google/umt5-xxl/special_tokens_map.json",
    "google/umt5-xxl/spiece.model",
    "google/umt5-xxl/tokenizer.json",
    "google/umt5-xxl/tokenizer_config.json",
]

# ReCamMaster fine-tuned checkpoint: Hub repository, file name and local
# target directory.
RECAMMASTER_REPO_ID = "KwaiVGI/ReCamMaster-Wan2.1"
RECAMMASTER_CHECKPOINT_FILE = "step20000.ckpt"
RECAMMASTER_LOCAL_DIR = f"{MODELS_ROOT_DIR}/ReCamMaster/checkpoints"
class ModelLoader:
    """Download the required model files and assemble the ReCamMaster pipeline.

    The loader fetches the Wan2.1 weights, the UMT5-XXL tokenizer files and the
    ReCamMaster checkpoint from the Hugging Face Hub (skipping files that are
    already on disk) and builds a ``WanVideoReCamMasterPipeline`` from them.

    Every ``progress_callback`` accepted by the methods below is called as
    ``progress_callback(fraction, desc=message)``.
    """

    def __init__(self):
        # All three attributes are only updated once load_models() succeeds.
        self.model_manager = None
        self.pipe = None
        self.is_loaded = False

    @staticmethod
    def _scaled_progress(progress_callback, start, end):
        """Return a callback that maps a 0..1 fraction into ``[start, end]``.

        The download methods report their own progress from 0.0 to 1.0. This
        wrapper lets load_models() give each of them a slice of the overall
        progress range so the reported value never jumps backwards.

        Returns ``None`` when ``progress_callback`` is ``None``.
        """
        if progress_callback is None:
            return None

        def report(fraction, desc=None):
            progress_callback(start + (end - start) * fraction, desc=desc)

        return report

    def download_umt5_xxl_tokenizer(self, progress_callback=None):
        """Download UMT5-XXL tokenizer files from HuggingFace.

        Files already present below ``WAN21_LOCAL_DIR`` are not downloaded
        again.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with a fraction in
                the range 0.0 to 1.0.

        Returns:
            List of local paths (strings), one per tokenizer file.

        Raises:
            Exception: Whatever ``hf_hub_download`` raised; it is logged and
                re-raised unchanged.
        """
        total_files = len(UMT5_XXL_TOKENIZER_FILES)
        downloaded_paths = []

        for i, file_path in enumerate(UMT5_XXL_TOKENIZER_FILES):
            local_dir = f"{WAN21_LOCAL_DIR}/{os.path.dirname(file_path)}"
            filename = os.path.basename(file_path)
            full_local_path = f"{WAN21_LOCAL_DIR}/{file_path}"

            # "is not None" rather than truthiness: a callable object that
            # defines __len__ can evaluate as falsy.
            if progress_callback is not None:
                progress_callback(i/total_files, desc=f"Checking tokenizer file {i+1}/{total_files}: {filename}")

            # Skip files that are already on disk.
            if os.path.exists(full_local_path):
                logger.info(f"✓ Tokenizer file {filename} already exists at {full_local_path}")
                downloaded_paths.append(full_local_path)
                continue

            os.makedirs(local_dir, exist_ok=True)

            logger.info(f"Downloading tokenizer file {filename} from {WAN21_REPO_ID}/{file_path}...")
            if progress_callback is not None:
                progress_callback(i/total_files, desc=f"Downloading tokenizer file {i+1}/{total_files}: {filename}")

            try:
                # file_path is relative to the repository root, so the file
                # lands at WAN21_LOCAL_DIR/file_path.
                # NOTE(review): newer huggingface_hub releases appear to
                # deprecate local_dir_use_symlinks - confirm against the
                # pinned version before removing it.
                downloaded_path = hf_hub_download(
                    repo_id=WAN21_REPO_ID,
                    filename=file_path,
                    local_dir=WAN21_LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
                logger.info(f"✓ Successfully downloaded tokenizer file {filename} to {downloaded_path}!")
                downloaded_paths.append(downloaded_path)
            except Exception as e:
                logger.error(f"✗ Error downloading tokenizer file {filename}: {e}")
                raise

        if progress_callback is not None:
            progress_callback(1.0, desc="All tokenizer files downloaded successfully!")

        return downloaded_paths

    def download_wan21_models(self, progress_callback=None):
        """Download Wan2.1 model files from HuggingFace.

        Files already present in ``WAN21_LOCAL_DIR`` are not downloaded again.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with a fraction in
                the range 0.0 to 1.0.

        Returns:
            List of local paths (strings), one per entry of ``WAN21_FILES``.

        Raises:
            Exception: Whatever ``hf_hub_download`` raised; it is logged and
                re-raised unchanged.
        """
        total_files = len(WAN21_FILES)
        downloaded_paths = []

        Path(WAN21_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        for i, filename in enumerate(WAN21_FILES):
            local_path = Path(WAN21_LOCAL_DIR) / filename

            if progress_callback is not None:
                progress_callback(i/total_files, desc=f"Checking Wan2.1 file {i+1}/{total_files}: {filename}")

            # Skip files that are already on disk.
            if local_path.exists():
                logger.info(f"✓ {filename} already exists at {local_path}")
                downloaded_paths.append(str(local_path))
                continue

            logger.info(f"Downloading {filename} from {WAN21_REPO_ID}...")
            if progress_callback is not None:
                progress_callback(i/total_files, desc=f"Downloading Wan2.1 file {i+1}/{total_files}: {filename}")

            try:
                downloaded_path = hf_hub_download(
                    repo_id=WAN21_REPO_ID,
                    filename=filename,
                    local_dir=WAN21_LOCAL_DIR,
                    local_dir_use_symlinks=False,
                )
                logger.info(f"✓ Successfully downloaded {filename} to {downloaded_path}!")
                downloaded_paths.append(downloaded_path)
            except Exception as e:
                logger.error(f"✗ Error downloading {filename}: {e}")
                raise

        if progress_callback is not None:
            progress_callback(1.0, desc="All Wan2.1 models downloaded successfully!")

        return downloaded_paths

    def download_recammaster_checkpoint(self, progress_callback=None):
        """Download ReCamMaster checkpoint from HuggingFace using huggingface_hub.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, desc=...)``. It is only called
                when a download actually takes place.

        Returns:
            Local path of the checkpoint as a string, whether the file was
            already present or has just been downloaded.

        Raises:
            Exception: Whatever ``hf_hub_download`` raised; it is logged and
                re-raised unchanged.
        """
        checkpoint_path = Path(RECAMMASTER_LOCAL_DIR) / RECAMMASTER_CHECKPOINT_FILE

        if checkpoint_path.exists():
            logger.info(f"✓ ReCamMaster checkpoint already exists at {checkpoint_path}")
            # Return a str here as well, so both code paths (and the sibling
            # download methods) hand back the same type.
            return str(checkpoint_path)

        Path(RECAMMASTER_LOCAL_DIR).mkdir(parents=True, exist_ok=True)

        logger.info("Downloading ReCamMaster checkpoint from HuggingFace...")
        logger.info(f"Repository: {RECAMMASTER_REPO_ID}")
        logger.info(f"File: {RECAMMASTER_CHECKPOINT_FILE}")
        logger.info(f"Destination: {checkpoint_path}")

        if progress_callback is not None:
            progress_callback(0.0, desc="Downloading ReCamMaster checkpoint...")

        try:
            downloaded_path = hf_hub_download(
                repo_id=RECAMMASTER_REPO_ID,
                filename=RECAMMASTER_CHECKPOINT_FILE,
                local_dir=RECAMMASTER_LOCAL_DIR,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            logger.error(f"✗ Error downloading checkpoint: {e}")
            raise

        logger.info(f"✓ Successfully downloaded ReCamMaster checkpoint to {downloaded_path}!")
        if progress_callback is not None:
            progress_callback(1.0, desc="ReCamMaster checkpoint downloaded successfully!")

        return str(downloaded_path)

    def create_symlink_for_tokenizer(self):
        """Create symlink for google/umt5-xxl to handle potential path issues.

        Links ``MODELS_ROOT_DIR/google/umt5-xxl`` to the tokenizer directory
        below ``WAN21_LOCAL_DIR``. This is best effort: any failure is logged
        as a warning and never raised, so loading can still be attempted.
        """
        try:
            google_dir = f"{MODELS_ROOT_DIR}/google"
            os.makedirs(google_dir, exist_ok=True)

            umt5_xxl_symlink = f"{google_dir}/umt5-xxl"
            umt5_xxl_source = f"{WAN21_LOCAL_DIR}/google/umt5-xxl"

            # Nothing to link to until the tokenizer has been downloaded.
            if not os.path.exists(umt5_xxl_source):
                return

            # A dangling link (its target was moved or deleted) would make
            # os.symlink fail with FileExistsError, so replace it.
            if os.path.islink(umt5_xxl_symlink) and not os.path.exists(umt5_xxl_symlink):
                os.unlink(umt5_xxl_symlink)

            if not os.path.lexists(umt5_xxl_symlink):
                # os.symlink is portable; target_is_directory only matters on
                # Windows and is ignored elsewhere. Failures raise OSError and
                # end up in the warning below instead of being reported as
                # success.
                os.symlink(umt5_xxl_source, umt5_xxl_symlink, target_is_directory=True)
                logger.info(f"Created symlink from {umt5_xxl_source} to {umt5_xxl_symlink}")
        except Exception as e:
            # This is a warning, not an error, as we'll try to proceed anyway.
            logger.warning(f"Could not create symlink for google/umt5-xxl: {str(e)}")

    def load_models(self, progress_callback=None):
        """Load the ReCamMaster models.

        Sets up the test data structure, makes sure all required files are on
        disk, builds the pipeline, adds the per-block ReCamMaster modules and
        loads the ReCamMaster checkpoint into the DiT.

        ``self.model_manager``, ``self.pipe`` and ``self.is_loaded`` are only
        updated when every step succeeded, so a failed attempt does not leave
        a half-initialised pipeline on the instance.

        Args:
            progress_callback: Optional callable invoked as
                ``progress_callback(fraction, desc=...)`` with an overall
                fraction between 0.0 and 1.0.

        Returns:
            A status string: ``"Models loaded successfully!"``,
            ``"Models already loaded!"``, or a message starting with
            ``"Error"``. Failures are reported through the return value; this
            method does not raise.
        """
        if self.is_loaded:
            return "Models already loaded!"

        def report(fraction, desc):
            if progress_callback is not None:
                progress_callback(fraction, desc=desc)

        try:
            logger.info("Starting model loading...")

            # Imported lazily; an ImportError is reported by the outer handler.
            from test_data import create_test_data_structure

            # First create the test data structure.
            report(0.05, "Setting up test data structure...")
            try:
                # NOTE(review): the callback is passed through unscaled because
                # the progress values this helper reports are not visible here.
                create_test_data_structure(progress_callback)
            except Exception as e:
                error_msg = f"Error creating test data structure: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # Second, ensure the checkpoint is downloaded. Each download step
            # reports 0..1 itself, so give it its own slice of the overall
            # range to keep the reported progress monotonic.
            report(0.1, "Checking for ReCamMaster checkpoint...")
            try:
                ckpt_path = self.download_recammaster_checkpoint(
                    self._scaled_progress(progress_callback, 0.1, 0.2)
                )
                logger.info(f"Using checkpoint at {ckpt_path}")
            except Exception as e:
                error_msg = f"Error downloading ReCamMaster checkpoint: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # Third, download Wan2.1 models if needed.
            report(0.2, "Checking for Wan2.1 models...")
            try:
                wan21_paths = self.download_wan21_models(
                    self._scaled_progress(progress_callback, 0.2, 0.3)
                )
                logger.info(f"Using Wan2.1 models: {wan21_paths}")
            except Exception as e:
                error_msg = f"Error downloading Wan2.1 models: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # Fourth, download UMT5-XXL tokenizer files.
            report(0.3, "Checking for UMT5-XXL tokenizer files...")
            try:
                tokenizer_paths = self.download_umt5_xxl_tokenizer(
                    self._scaled_progress(progress_callback, 0.3, 0.4)
                )
                logger.info(f"Using UMT5-XXL tokenizer files: {tokenizer_paths}")
            except Exception as e:
                error_msg = f"Error downloading UMT5-XXL tokenizer files: {str(e)}"
                logger.error(error_msg)
                return error_msg

            # Now, load the models.
            report(0.4, "Loading model manager...")
            self.create_symlink_for_tokenizer()

            # Built in locals and published on self only after full success.
            model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")

            report(0.5, "Loading Wan2.1 models...")
            model_files = [f"{WAN21_LOCAL_DIR}/{filename}" for filename in WAN21_FILES]
            for model_file in model_files:
                logger.info(f"Loading model from: {model_file}")
                if not os.path.exists(model_file):
                    error_msg = f"Error: Model file not found: {model_file}"
                    logger.error(error_msg)
                    return error_msg

            # Set environment variable for transformers to find the tokenizer.
            os.environ["TRANSFORMERS_CACHE"] = MODELS_ROOT_DIR
            # Disable tokenizers parallelism warning.
            os.environ["TOKENIZERS_PARALLELISM"] = "false"

            model_manager.load_models(model_files)

            report(0.7, "Creating pipeline...")
            pipe = WanVideoReCamMasterPipeline.from_model_manager(model_manager, device="cuda")

            report(0.8, "Initializing ReCamMaster modules...")
            # Add the extra modules introduced in ReCamMaster to every DiT
            # block. They start as a zero map (cam_encoder) and an identity map
            # (projector); the checkpoint loaded below supplies their weights.
            # NOTE(review): the input width 12 is presumably a flattened 3x4
            # camera pose - confirm against the ReCamMaster code.
            dim = pipe.dit.blocks[0].self_attn.q.weight.shape[0]
            for block in pipe.dit.blocks:
                block.cam_encoder = nn.Linear(12, dim)
                block.projector = nn.Linear(dim, dim)
                block.cam_encoder.weight.data.zero_()
                block.cam_encoder.bias.data.zero_()
                block.projector.weight = nn.Parameter(torch.eye(dim))
                block.projector.bias = nn.Parameter(torch.zeros(dim))

            report(0.9, "Loading ReCamMaster checkpoint...")
            if not os.path.exists(ckpt_path):
                error_msg = f"Error: ReCamMaster checkpoint not found at {ckpt_path} even after download attempt."
                logger.error(error_msg)
                return error_msg

            # SECURITY NOTE(review): torch.load unpickles the file. The
            # checkpoint is downloaded from the Hub, so consider passing
            # weights_only=True once it is confirmed to hold only tensors.
            state_dict = torch.load(ckpt_path, map_location="cpu")
            pipe.dit.load_state_dict(state_dict, strict=True)

            pipe.to("cuda")
            pipe.to(dtype=torch.bfloat16)

            self.model_manager = model_manager
            self.pipe = pipe
            self.is_loaded = True

            report(1.0, "Models loaded successfully!")
            logger.info("Models loaded successfully!")
            return "Models loaded successfully!"

        except Exception as e:
            # logger.exception keeps the traceback, which the bare message
            # returned to the caller does not carry.
            logger.exception(f"Error loading models: {str(e)}")
            return f"Error loading models: {str(e)}"