import util.env_resolver
import spaces
import torch
import numpy as np
from PIL import Image
from tooncomposer import ToonComposer, get_base_model_paths
import argparse
import json
from util.training_util import extract_img_to_sketch
import os
import tempfile
import cv2
import gradio as gr
from einops import rearrange
from datetime import datetime
from typing import Optional, List, Dict
from huggingface_hub import snapshot_download

os.environ["GRADIO_TEMP_DIR"] = os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache"))

# -----------------------------------------------------------------------------
# Weights resolution and download helpers
# -----------------------------------------------------------------------------
WAN_REPO_ID = "Wan-AI/Wan2.1-I2V-14B-480P"
TOONCOMPOSER_REPO_ID = "TencentARC/ToonComposer"


def _path_is_dir_with_files(dir_path: str, required_files: List[str]) -> bool:
    if not dir_path or not os.path.isdir(dir_path):
        return False
    for f in required_files:
        if not os.path.exists(os.path.join(dir_path, f)):
            return False
    return True


def resolve_wan_model_root(preferred_dir: Optional[str] = None, hf_token: Optional[str] = None) -> str:
    """Return a directory containing Wan2.1-I2V-14B-480P weights.

    Resolution order:
    1) preferred_dir arg (if valid)
    2) WAN21_I2V_DIR env var (if valid)
    3) HF local cache (no download) via snapshot_download(local_files_only=True)
    4) HF download to cache via snapshot_download()
    """
    # Required filenames relative to the model root
    expected = get_base_model_paths("Wan2.1-I2V-14B-480P", format='dict', model_root=".")
    required_files = []
    required_files.extend([os.path.basename(p) for p in expected["dit"]])
    required_files.append(os.path.basename(expected["image_encoder"]))
    required_files.append(os.path.basename(expected["text_encoder"]))
    required_files.append(os.path.basename(expected["vae"]))

    # 1) preferred_dir arg
    if _path_is_dir_with_files(preferred_dir or "", required_files):
        return os.path.abspath(preferred_dir)

    # 2) environment variable
    env_dir = os.environ.get("WAN21_I2V_DIR")
    if _path_is_dir_with_files(env_dir or "", required_files):
        return os.path.abspath(env_dir)

    # 3) try local cache without network
    try:
        cached_dir = snapshot_download(repo_id=WAN_REPO_ID, local_files_only=True)
        return cached_dir
    except Exception:
        pass

    # 4) download (may be large)
    cached_dir = snapshot_download(repo_id=WAN_REPO_ID, token=hf_token)
    return cached_dir


def resolve_tooncomposer_repo_dir(preferred_dir: Optional[str] = None, hf_token: Optional[str] = None) -> str:
    """Return a directory containing the ToonComposer repo with 480p/608p subdirs."""
    # Quick validity check: ensure either a 480p or 608p subdir exists with required files
    def has_resolution_dirs(base_dir: str) -> bool:
        if not base_dir or not os.path.isdir(base_dir):
            return False
        ok = False
        for res in ["480p", "608p"]:
            d = os.path.join(base_dir, res)
            if os.path.isdir(d):
                ckpt = os.path.join(d, "tooncomposer.ckpt")
                cfg = os.path.join(d, "config.json")
                if os.path.exists(ckpt) and os.path.exists(cfg):
                    ok = True
        return ok

    # 1) preferred_dir arg
    if has_resolution_dirs(preferred_dir or ""):
        return os.path.abspath(preferred_dir)

    # 2) environment variable
    env_dir = os.environ.get("TOONCOMPOSER_DIR")
    if has_resolution_dirs(env_dir or ""):
        return os.path.abspath(env_dir)

    # 3) try local cache first
    try:
        cached_dir = snapshot_download(repo_id=TOONCOMPOSER_REPO_ID, local_files_only=True)
        return cached_dir
    except Exception:
        pass

    # 4) download repo to cache
    cached_dir = snapshot_download(repo_id=TOONCOMPOSER_REPO_ID, token=hf_token)
    return cached_dir

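# Illustrative sketch (not executed): the local directory layouts these resolvers accept.
# The paths and filenames below are examples inferred from the checks above, not an
# exhaustive listing of either repository.
#
#   $WAN21_I2V_DIR/          # or --wan_model_dir
#       the diffusion-model shards, image encoder, text encoder and VAE files
#       named by get_base_model_paths("Wan2.1-I2V-14B-480P")
#   $TOONCOMPOSER_DIR/       # or --tooncomposer_dir
#       480p/tooncomposer.ckpt
#       480p/config.json
#       608p/tooncomposer.ckpt
#       608p/config.json
#
# If neither directory is provided, both repos fall back to the Hugging Face cache
# (local cache first, then download).
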
def build_checkpoints_by_resolution(tooncomposer_base_dir: str) -> Dict[str, Dict[str, object]]:
    """Construct resolution mapping from a base repo dir that contains 480p/608p.

    The ToonComposer HF repo stores, inside each resolution dir:
    - tooncomposer.ckpt
    - config.json (model configuration)
    """
    mapping = {}
    # Known target sizes
    res_to_hw = {
        "480p": (480, 832),
        "608p": (608, 1088),
    }
    for res, (h, w) in res_to_hw.items():
        res_dir = os.path.join(tooncomposer_base_dir, res)
        mapping[res] = {
            "target_height": h,
            "target_width": w,
            "snapshot_args_path": os.path.join(res_dir, "config.json"),
            "checkpoint_path": os.path.join(res_dir, "tooncomposer.ckpt"),
        }
    return mapping


# Populated at import time, after the ToonComposer repo directory has been resolved
checkpoints_by_resolution = {}


def tensor2video(frames):
    frames = rearrange(frames, "C T H W -> T H W C")
    frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
    frames = [Image.fromarray(frame) for frame in frames]
    return frames


def _load_model_config(config_path: str) -> Dict[str, object]:
    with open(config_path, "r") as f:
        data = json.load(f)
    return data


def _merge_with_defaults(cfg: Dict[str, object]) -> Dict[str, object]:
    # Provide safe defaults for optional fields used at inference time
    defaults = {
        "base_model_name": "Wan2.1-I2V-14B-480P",
        "learning_rate": 1e-5,
        "train_architecture": "lora",
        "lora_rank": 4,
        "lora_alpha": 4,
        "lora_target_modules": "q,k,v,o,ffn.0,ffn.2",
        "init_lora_weights": "kaiming",
        "use_gradient_checkpointing": True,
        "tiled": False,
        "tile_size_height": 34,
        "tile_size_width": 34,
        "tile_stride_height": 18,
        "tile_stride_width": 16,
        "output_path": "./",
        "use_local_lora": False,
        "use_dera": False,
        "dera_rank": None,
        "use_dera_spatial": True,
        "use_dera_temporal": True,
        "use_sequence_cond": True,
        "sequence_cond_mode": "sparse",
        "use_channel_cond": False,
        "use_sequence_cond_position_aware_residual": True,
        "use_sequence_cond_loss": False,
        "fast_dev": False,
        "max_num_cond_images": 1,
        "max_num_cond_sketches": 2,
        "visualize_attention": False,
        "random_spaced_cond_frames": False,
        "use_sketch_mask": True,
        "sketch_mask_ratio": 0.2,
        "no_first_sketch": False,
    }
    merged = defaults.copy()
    merged.update(cfg)
    return merged

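# Minimal illustration (assumed, not the shipped file): because _merge_with_defaults()
# backfills every missing key, a per-resolution config.json only needs to override the
# fields that differ from the defaults above, e.g.:
#
#   {
#       "lora_rank": 128,
#       "lora_alpha": 64,
#       "max_num_cond_sketches": 4
#   }
#
# The values shown are placeholders; the real config.json ships inside each 480p/608p
# directory of the ToonComposer repo.
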
def initialize_model(resolution="480p", fast_dev=False, device="cuda:0", dtype=torch.bfloat16,
                     wan_model_dir: Optional[str] = None, tooncomposer_dir: Optional[str] = None,
                     hf_token: Optional[str] = None):
    # Initialize model components
    if resolution not in checkpoints_by_resolution:
        raise ValueError(f"Resolution '{resolution}' is not available. Found: {list(checkpoints_by_resolution.keys())}")

    # 1) resolve config and checkpoint from ToonComposer repo (local or HF)
    snapshot_args_path = checkpoints_by_resolution[resolution]["snapshot_args_path"]
    checkpoint_path = checkpoints_by_resolution[resolution]["checkpoint_path"]

    # 2) load model config
    snapshot_args_raw = _load_model_config(snapshot_args_path)
    snapshot_args = _merge_with_defaults(snapshot_args_raw)
    snapshot_args["checkpoint_path"] = checkpoint_path

    # 3) resolve Wan2.1 model root
    snapshot_args["model_root"] = resolve_wan_model_root(preferred_dir=wan_model_dir, hf_token=hf_token)

    # Backward-compat fields
    if "training_max_frame_stride" not in snapshot_args:
        snapshot_args["training_max_frame_stride"] = 4
        snapshot_args["random_spaced_cond_frames"] = False

    args = argparse.Namespace(**snapshot_args)

    if not fast_dev:
        model = ToonComposer(
            base_model_name=args.base_model_name,
            model_root=args.model_root,
            learning_rate=args.learning_rate,
            train_architecture=args.train_architecture,
            lora_rank=args.lora_rank,
            lora_alpha=args.lora_alpha,
            lora_target_modules=args.lora_target_modules,
            init_lora_weights=args.init_lora_weights,
            use_gradient_checkpointing=args.use_gradient_checkpointing,
            checkpoint_path=args.checkpoint_path,
            tiled=args.tiled,
            tile_size=(args.tile_size_height, args.tile_size_width),
            tile_stride=(args.tile_stride_height, args.tile_stride_width),
            output_path=args.output_path,
            use_local_lora=args.use_local_lora,
            use_dera=args.use_dera,
            dera_rank=args.dera_rank,
            use_dera_spatial=args.use_dera_spatial,
            use_dera_temporal=args.use_dera_temporal,
            use_sequence_cond=args.use_sequence_cond,
            sequence_cond_mode=args.sequence_cond_mode,
            use_channel_cond=args.use_channel_cond,
            use_sequence_cond_position_aware_residual=args.use_sequence_cond_position_aware_residual,
            use_sequence_cond_loss=args.use_sequence_cond_loss,
            fast_dev=args.fast_dev,
            max_num_cond_images=args.max_num_cond_images,
            max_num_cond_sketches=args.max_num_cond_sketches,
            visualize_attention=args.visualize_attention,
            random_spaced_cond_frames=args.random_spaced_cond_frames,
            use_sketch_mask=args.use_sketch_mask,
            sketch_mask_ratio=args.sketch_mask_ratio,
            no_first_sketch=args.no_first_sketch,
        )
        model = model.to(device, dtype=dtype).eval()
    else:
        print("Fast dev mode. Models will not be loaded.")
        model = None

    print("Models initialized.")
    return model, device, dtype

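# Usage sketch (assumed): with checkpoints_by_resolution populated (as done below at
# import time), the model can also be created directly, e.g.:
#
#   model, device, dtype = initialize_model(resolution="480p", device="cuda:0",
#                                           dtype=torch.bfloat16, fast_dev=False)
#
# fast_dev=True skips loading the heavy weights and returns model=None.
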
# -----------------------------------------------------------------------------
# CLI args and global initialization
# -----------------------------------------------------------------------------
def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--resolution", type=str, default=os.environ.get("TOONCOMPOSER_RESOLUTION", "480p"),
                        choices=["480p", "608p"], help="Target resolution to load by default.")
    parser.add_argument("--device", type=str, default=os.environ.get("DEVICE", "cuda"))
    parser.add_argument("--dtype", type=str, default=os.environ.get("DTYPE", "bfloat16"), choices=["bfloat16", "float32"])
    parser.add_argument("--wan_model_dir", type=str, default=os.environ.get("WAN21_I2V_DIR"),
                        help="Local directory containing Wan2.1 model files. If not provided, will try HF cache and download if needed.")
    parser.add_argument("--tooncomposer_dir", type=str, default=os.environ.get("TOONCOMPOSER_DIR"),
                        help="Local directory containing ToonComposer weights with 480p/608p subdirectories. "
                             "If not provided, will try HF cache and download if needed.")
    parser.add_argument("--hf_token", type=str, default=os.environ.get("HF_TOKEN"),
                        help="Hugging Face token (if needed for gated models).")
    parser.add_argument("--fast_dev", action="store_true", help="Run in fast dev mode without loading heavy models.")
    return parser.parse_args()


_cli_args = _parse_args()

# Resolve ToonComposer repo dir and build resolution mapping
_toon_dir = resolve_tooncomposer_repo_dir(preferred_dir=_cli_args.tooncomposer_dir, hf_token=_cli_args.hf_token)
checkpoints_by_resolution = build_checkpoints_by_resolution(_toon_dir)

_dtype_map = {
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}
fast_dev = bool(_cli_args.fast_dev)
model, device, dtype = initialize_model(
    resolution=_cli_args.resolution,
    fast_dev=fast_dev,
    device=_cli_args.device,
    dtype=_dtype_map[_cli_args.dtype],
    wan_model_dir=_cli_args.wan_model_dir,
    tooncomposer_dir=_cli_args.tooncomposer_dir,
    hf_token=_cli_args.hf_token,
)


def process_conditions(num_items, item_inputs, num_frames, is_sketch=False, target_height=480, target_width=832):
    """Process condition images/sketches into a masked video tensor and a frame mask"""
    # Allocate an all-zero condition video and frame mask; conditioned frames are filled in below
    video = torch.zeros((1, 3, num_frames, target_height, target_width), device=device)
    mask = torch.zeros((1, num_frames), device=device)

    for i in range(num_items):
        img, frame_idx = item_inputs[i]
        if img is None or frame_idx is None:
            continue

        # Convert PIL image to a tensor in [-1, 1]
        img_tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 127.5 - 1.0
        if is_sketch:
            img_tensor = -img_tensor
        img_tensor = img_tensor.unsqueeze(0).to(device)

        # Resize to the model's expected resolution while preserving aspect ratio
        # Get original dimensions
        _, _, h, w = img_tensor.shape

        # Resize based on the short edge while maintaining aspect ratio
        if h / w < target_height / target_width:
            # Height is the short edge
            new_h = target_height
            new_w = int(w * (new_h / h))
        else:
            # Width is the short edge
            new_w = target_width
            new_h = int(h * (new_w / w))

        # Resize with the calculated dimensions
        img_tensor = torch.nn.functional.interpolate(img_tensor, size=(new_h, new_w), mode="bilinear")

        # Center crop to target resolution if needed
        if new_h > target_height or new_w > target_width:
            # Calculate starting positions for crop
            start_h = max(0, (new_h - target_height) // 2)
            start_w = max(0, (new_w - target_width) // 2)
            # Crop
            img_tensor = img_tensor[:, :, start_h:start_h + target_height, start_w:start_w + target_width]

        # Place in video tensor
        frame_idx = min(max(int(frame_idx), 0), num_frames - 1)
        if is_sketch:
            video[:, :, frame_idx] = img_tensor[:, :3]  # Handle RGBA sketches
        else:
            video[:, :, frame_idx] = img_tensor
        mask[:, frame_idx] = 1.0

    return video, mask

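# Shape sketch (illustrative): for a 61-frame request at 480p, process_conditions()
# returns a condition video of shape [1, 3, 61, 480, 832] with each keyframe written at
# its frame index, and a [1, 61] mask that is 1.0 at those indices and 0.0 elsewhere:
#
#   video, mask = process_conditions(1, [(pil_image, 0)], num_frames=61,
#                                    target_height=480, target_width=832)
#
# `pil_image` here stands for a hypothetical RGB PIL.Image supplied by the caller.
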
def process_sketch_masks(num_sketch_masks, sketch_mask_inputs, num_frames, target_height=480, target_width=832):
    """Process sketch motion masks into a single tensor"""
    # Create a tensor filled with 1s (1 means no mask, keep original)
    sketch_local_mask = torch.ones((1, 1, num_frames, target_height, target_width), device=device)

    for i in range(num_sketch_masks):
        editor_value, frame_idx = sketch_mask_inputs[i]
        if editor_value is None or frame_idx is None:
            continue

        # For ImageMask, we need to extract the mask from the editor_value dictionary.
        # editor_value is a dict with 'background', 'layers', and 'composite' keys from ImageEditor.
        if isinstance(editor_value, dict):
            if "composite" in editor_value and editor_value["composite"] is not None:
                # The 'composite' is the image with the mask drawn on it.
                # Since we're using ImageMask with a fixed black brush, the black areas are the mask.
                # Convert the drawn layer to a binary mask (0 = masked, 1 = not masked).
                # sketch = editor_value["background"]  # This is the sketch itself
                mask = editor_value["layers"][0] if editor_value["layers"] else None  # This is the mask layer

                if mask is not None:
                    # Convert mask to tensor and normalize
                    mask_array = np.array(mask)
                    mask_array = np.max(mask_array, axis=2)

                    # Convert to tensor, normalize to [0, 1]
                    mask_tensor = torch.from_numpy(mask_array).float()
                    if mask_tensor.max() > 1.0:
                        mask_tensor = mask_tensor / 255.0

                    # Resize to the model's expected resolution
                    mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0)  # [1, 1, h, w]
                    mask_tensor = torch.nn.functional.interpolate(mask_tensor, size=(target_height, target_width), mode="nearest")

                    # Invert the mask: black (0) = masked area, white (1) = keep original.
                    # We need to invert because in the UI black means "masked".
                    mask_tensor = 1.0 - mask_tensor

                    # Place in sketch_local_mask tensor
                    frame_idx = min(max(int(frame_idx), 0), num_frames - 1)
                    sketch_local_mask[:, :, frame_idx] = mask_tensor

    sketch_mask_vis = torch.ones((1, 3, num_frames, target_height, target_width), device=device)
    for t in range(sketch_local_mask.shape[2]):
        for c in range(3):
            sketch_mask_vis[0, c, t, :, :] = torch.where(
                sketch_local_mask[0, 0, t] > 0.5,
                1.0,   # White for unmasked areas
                -1.0,  # Black for masked areas
            )
    return sketch_local_mask


def invert_sketch(image):
    """Invert the colors of an image (black to white, white to black)"""
    if image is None:
        return None

    # Handle input from ImageMask component (EditorValue dictionary)
    if isinstance(image, dict) and "background" in image:
        # Extract the background image
        bg_image = image["background"]
        # Invert the background
        inverted_bg = invert_sketch_internal(bg_image)
        # Return updated editor value
        return gr.update(value=inverted_bg)

    # Original function for regular images
    return invert_sketch_internal(image)


def invert_sketch_internal(image):
    """Internal function to invert an image"""
    if image is None:
        return None

    # Convert to PIL image if needed
    if isinstance(image, str):
        # If it's a filepath
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Ensure it's a PIL image now
    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(np.array(image))
        except Exception:
            print(f"Warning: Could not convert image of type {type(image)} to PIL Image")
            return image

    # Invert the image
    inverted = Image.fromarray(255 - np.array(image))
    return inverted


def create_blank_mask(canvas_width=832, canvas_height=480):
    """Create a blank white mask image"""
    return Image.new('RGB', (canvas_width, canvas_height), color='white')


def create_mask_with_sketch(sketch, canvas_width=832, canvas_height=480):
    """Create a mask image with the sketch as background"""
    if sketch is None:
        return create_blank_mask(canvas_width, canvas_height)

    # Convert sketch to PIL if needed
    if not isinstance(sketch, Image.Image):
        sketch = Image.fromarray(np.array(sketch))

    # Resize sketch to fit the canvas
    sketch = sketch.resize((canvas_width, canvas_height))

    # Create a semi-transparent white layer over the sketch
    overlay = Image.new('RGBA', (canvas_width, canvas_height), (255, 255, 255, 128))

    # Ensure sketch has an alpha channel
    if sketch.mode != 'RGBA':
        sketch = sketch.convert('RGBA')

    # Overlay the semi-transparent white layer on the sketch
    result = Image.alpha_composite(sketch, overlay)

    # Convert back to RGB for Gradio
    return result.convert('RGB')

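# Layout note (derived from the Gradio `inputs` wiring at the bottom of this file):
# the *args received by validate_inputs() and tooncomposer_inference() are flattened as
#   args[0:8]   -> 4 x (image, frame_index) pairs
#   args[8:16]  -> 4 x (sketch_editor_value, frame_index) pairs
# which is why num_cond_sketches_index is hard-coded to 8 below.
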
def validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args):
    """Validate user inputs and return error messages if any"""
    errors = []

    # Check text prompt
    if not text_prompt or text_prompt.strip() == "":
        errors.append("❌ Text prompt is required. Please enter a description for your video.")

    # Check condition images
    cond_images_count = 0
    for i in range(int(num_cond_images)):
        img = args[i * 2]
        frame_idx = args[i * 2 + 1]
        if img is None:
            errors.append(f"❌ Image #{i+1} is missing. Please upload an image or reduce the number of keyframe images.")
        else:
            cond_images_count += 1
        if frame_idx is not None and (frame_idx < 0 or frame_idx >= num_frames):
            errors.append(f"❌ Frame index for Image #{i+1} is {frame_idx}, which is out of range. Must be between 0 and {num_frames-1}.")

    # Check condition sketches
    num_cond_sketches_index = 8  # Starting index for sketch inputs
    cond_sketches_count = 0
    sketch_frame_indices = []
    for i in range(int(num_cond_sketches)):
        sketch_idx = num_cond_sketches_index + i * 2
        frame_idx_idx = num_cond_sketches_index + 1 + i * 2
        if sketch_idx < len(args) and frame_idx_idx < len(args):
            sketch = args[sketch_idx]
            frame_idx = args[frame_idx_idx]

            # Check if sketch is provided
            if sketch is None:
                errors.append(f"❌ Sketch #{i+1} is missing. Please upload a sketch or reduce the number of keyframe sketches.")
            else:
                # For ImageMask components, check if a background is provided
                if isinstance(sketch, dict):
                    if "background" not in sketch or sketch["background"] is None:
                        errors.append(f"❌ Sketch #{i+1} is missing. Please upload a sketch image.")
                    else:
                        cond_sketches_count += 1
                else:
                    cond_sketches_count += 1

            # Check frame index
            if frame_idx is not None and (frame_idx < 0 or frame_idx >= num_frames):
                errors.append(f"❌ Frame index for Sketch #{i+1} is {frame_idx}, which is out of range. Must be between 0 and {num_frames-1}.")
            elif frame_idx is not None:
                sketch_frame_indices.append(frame_idx)

    # Check for duplicate frame indices
    image_frame_indices = []
    for i in range(int(num_cond_images)):
        frame_idx = args[i * 2 + 1]
        if frame_idx is not None:
            image_frame_indices.append(frame_idx)

    all_frame_indices = image_frame_indices + sketch_frame_indices
    if len(all_frame_indices) != len(set(all_frame_indices)):
        errors.append("❌ Duplicate frame indices detected. Each image and sketch must be placed at a different frame.")

    # Check minimum requirements
    if cond_images_count == 0:
        errors.append("❌ At least one input image is required.")

    return errors

@spaces.GPU(duration=240)
def tooncomposer_inference(num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale,
                           sequence_cond_residual_scale, resolution, *args):
    # Validate inputs first
    validation_errors = validate_inputs(num_frames, num_cond_images, num_cond_sketches, text_prompt, *args)
    if validation_errors:
        error_message = "\n".join(validation_errors)
        return gr.update(value=None), error_message

    try:
        # Parse inputs
        # Get the condition images
        cond_images = []
        for i in range(int(num_cond_images)):
            img = args[i * 2]
            frame_idx = args[i * 2 + 1]
            if img is not None and frame_idx is not None:
                cond_images.append((img, frame_idx))

        # Get num_cond_sketches
        if num_cond_sketches is None:
            num_cond_sketches = 0
        else:
            num_cond_sketches = int(num_cond_sketches)

        # Get condition sketches and masks
        cond_sketches = []
        sketch_masks = []
        num_cond_sketches_index = 8  # Starting index for sketch inputs
        for i in range(num_cond_sketches):
            sketch_idx = num_cond_sketches_index + i * 2
            frame_idx_idx = num_cond_sketches_index + 1 + i * 2
            if sketch_idx < len(args) and frame_idx_idx < len(args):
                editor_value = args[sketch_idx]
                frame_idx = args[frame_idx_idx]
                if editor_value is not None and frame_idx is not None:
                    # Extract the sketch from the background of the editor value
                    if isinstance(editor_value, dict) and "background" in editor_value:
                        sketch = editor_value["background"]
                        if sketch is not None:
                            cond_sketches.append((sketch, frame_idx))
                            # Also add to sketch_masks for mask processing
                            sketch_masks.append((editor_value, frame_idx))
                    else:
                        # For regular image inputs (first sketch)
                        if editor_value is not None:
                            cond_sketches.append((editor_value, frame_idx))

        # Set target resolution based on selection
        target_height = checkpoints_by_resolution[resolution]["target_height"]
        target_width = checkpoints_by_resolution[resolution]["target_width"]

        # Update model resolution
        if not fast_dev:
            model.update_height_width(target_height, target_width)

        # Process conditions
        with torch.no_grad():
            # Process image conditions
            masked_cond_video, preserved_cond_mask = process_conditions(
                num_cond_images, cond_images, num_frames,
                target_height=target_height, target_width=target_width
            )

            # Process sketch conditions
            masked_cond_sketch, preserved_sketch_mask = process_conditions(
                len(cond_sketches), cond_sketches, num_frames, is_sketch=True,
                target_height=target_height, target_width=target_width
            )

            # Process sketch masks (if any)
            sketch_local_mask = None
            if len(sketch_masks) > 0:
                sketch_local_mask = process_sketch_masks(
                    len(sketch_masks), sketch_masks, num_frames,
                    target_height=target_height, target_width=target_width
                )
            else:
                sketch_local_mask = torch.ones((1, 1, num_frames, target_height, target_width), device=device)

            if fast_dev:
                print("Fast dev mode, returning dummy video")
                # Create a simple dummy video for testing
                temp_dir = tempfile.mkdtemp()
                video_path = os.path.join(temp_dir, "dummy_video.mp4")

                # Create a simple test video
                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                video_writer = cv2.VideoWriter(video_path, fourcc, 20.0, (target_width, target_height))
                for i in range(30):  # 30 frames
                    # Create a simple colored frame
                    frame = np.full((target_height, target_width, 3), (i * 8) % 255, dtype=np.uint8)
                    video_writer.write(frame)
                video_writer.release()

                return video_path, "✅ Dummy video generated successfully in fast dev mode!"

            masked_cond_video = masked_cond_video.to(device=device, dtype=dtype)
            preserved_cond_mask = preserved_cond_mask.to(device=device, dtype=dtype)
            masked_cond_sketch = masked_cond_sketch.to(device=device, dtype=dtype)
            preserved_sketch_mask = preserved_sketch_mask.to(device=device, dtype=dtype)

            with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(device).type):
                # Generate video
                model.pipe.device = device
                generated_video = model.pipe(
                    prompt=[text_prompt],
                    negative_prompt=[model.negative_prompt],
                    input_image=None,
                    num_inference_steps=15,
                    num_frames=num_frames,
                    seed=42,
                    tiled=True,
                    input_condition_video=masked_cond_video,
                    input_condition_preserved_mask=preserved_cond_mask,
                    input_condition_video_sketch=masked_cond_sketch,
                    input_condition_preserved_mask_sketch=preserved_sketch_mask,
                    sketch_local_mask=sketch_local_mask,
                    cfg_scale=cfg_scale,
                    sequence_cond_residual_scale=sequence_cond_residual_scale,
                    height=target_height,
                    width=target_width,
                )

            # Convert to PIL images
            video_frames = model.pipe.tensor2video(generated_video[0].cpu())

            # Convert PIL images to an MP4 video
            temp_dir = tempfile.mkdtemp()
            video_path = os.path.join(temp_dir, "generated_video.mp4")

            width, height = video_frames[0].size
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Codec for MP4 video
            video_writer = cv2.VideoWriter(video_path, fourcc, 20.0, (width, height))  # 20 fps

            for frame in video_frames:
                # Convert PIL image to OpenCV BGR format
                frame_bgr = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
                video_writer.write(frame_bgr)
            video_writer.release()

            print(f"Generated video saved to {video_path}. Current time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
            return video_path, f"✅ Video generated successfully! (with {len(cond_images)} keyframe images, {len(cond_sketches)} keyframe sketches)"

    except Exception as e:
        error_msg = f"❌ Error during generation: {str(e)}"
        print(error_msg)
        return gr.update(value=None), error_msg


def create_sample_gallery():
    """Create gallery items for samples"""
    gallery_items = []
    sample_info = [
        {
            "id": 1,
            "title": "Sample 1",
            "description": "Man playing with blue fish underwater (3 sketches)",
            "preview": "samples/1_image1.png"
        },
        {
            "id": 2,
            "title": "Sample 2",
            "description": "Girl and boy planting a growing flower (2 sketches)",
            "preview": "samples/2_image1.jpg"
        },
        {
            "id": 3,
            "title": "Sample 3",
            "description": "Ancient Chinese boy giving apple to elder (1 sketch)",
            "preview": "samples/3_image1.png"
        }
    ]

    for sample in sample_info:
        if os.path.exists(sample["preview"]):
            gallery_items.append((sample["preview"], f"{sample['title']}: {sample['description']}"))

    return gallery_items


def handle_gallery_select(evt: gr.SelectData):
    """Handle gallery selection and load the corresponding sample"""
    sample_id = evt.index + 1  # Gallery index starts from 0, sample IDs start from 1
    return apply_sample_to_ui(sample_id)

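# File-naming convention used by the sample loaders below (all paths are relative to
# the samples/ directory; missing files are simply skipped):
#   {id}_image1.png or {id}_image1.jpg   -> keyframe image
#   {id}_sketch{n}.jpg                   -> keyframe sketches, n = 1..num_sketches
#   {id}_out.mp4                         -> pre-rendered result video
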
def load_sample_data(sample_id):
    """Load sample data based on the selected sample"""
    samples_dir = "samples"

    # Sample configurations
    sample_configs = {
        1: {
            "prompt": "Underwater scene: A shirtless man plays with a spiraling blue fish. A whale follows a bag in the man's hand, swimming in circles as the man uses the bag to lure the blue fish forward. Anime. High quality.",
            "num_sketches": 3,
            "image_frame": 0,
            "sketch_frames": [20, 40, 60],
            "num_frames": 61
        },
        2: {
            "prompt": "A girl and a silver-haired boy plant a huge flower. As the camera slowly moves up, the huge flower continues to grow and bloom. Anime. High quality.",
            "num_sketches": 2,
            "image_frame": 0,
            "sketch_frames": [30, 60],
            "num_frames": 61
        },
        3: {
            "prompt": "An ancient Chinese boy holds an apple and smiles as he gives it to an elderly man nearby. Anime. High quality.",
            "num_sketches": 1,
            "image_frame": 0,
            "sketch_frames": [30],
            "num_frames": 33
        }
    }

    if sample_id not in sample_configs:
        return None

    config = sample_configs[sample_id]

    # Load image
    image_path = os.path.join(samples_dir, f"{sample_id}_image1.png")
    if not os.path.exists(image_path):
        image_path = os.path.join(samples_dir, f"{sample_id}_image1.jpg")

    # Load sketches
    sketches = []
    for i in range(config["num_sketches"]):
        sketch_path = os.path.join(samples_dir, f"{sample_id}_sketch{i+1}.jpg")
        if os.path.exists(sketch_path):
            sketches.append(sketch_path)

    # Load output video
    output_path = os.path.join(samples_dir, f"{sample_id}_out.mp4")

    return {
        "prompt": config["prompt"],
        "image": image_path if os.path.exists(image_path) else None,
        "sketches": sketches,
        "image_frame": config["image_frame"],
        "sketch_frames": config["sketch_frames"][:len(sketches)],
        "output_video": output_path if os.path.exists(output_path) else None,
        "num_sketches": len(sketches),
        "num_frames": config["num_frames"]
    }


def apply_sample_to_ui(sample_id):
    """Apply sample data to UI components"""
    sample_data = load_sample_data(sample_id)
    if not sample_data:
        # Return one no-op update per bound output component (see sample_outputs below)
        return [gr.update() for _ in range(15)]

    updates = [gr.update(value=sample_data["num_frames"])]

    # Update prompt
    updates.append(gr.update(value=sample_data["prompt"]))

    # Update number of sketches
    updates.append(gr.update(value=sample_data["num_sketches"]))

    # Update condition image
    updates.append(gr.update(value=sample_data["image"]))
    updates.append(gr.update(value=sample_data["image_frame"]))

    # Update sketches (up to 4)
    for i in range(4):
        if i < len(sample_data["sketches"]):
            # Load sketch image
            sketch_img = Image.open(sample_data["sketches"][i])
            # Create ImageMask format
            sketch_dict = {
                "background": sketch_img,
                "layers": [],
                "composite": sketch_img
            }
            updates.append(gr.update(value=sketch_dict))
            updates.append(gr.update(value=sample_data["sketch_frames"][i]))
        else:
            updates.append(gr.update(value=None))
            updates.append(gr.update(value=30))

    # Update output video
    updates.append(gr.update(value=sample_data["output_video"]))

    # Update status
    updates.append(gr.update(value=f"✅ Loaded Sample {sample_id}: {sample_data['prompt'][:50]}..."))

    return updates


if __name__ == "__main__":
    from util.stylesheets import css, pre_js, banner_image

    with gr.Blocks(title="🎨 ToonComposer Demo", css=css, js=pre_js) as iface:
        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML(banner_image)
            with gr.Column(scale=1):
                gr.Markdown("""
                💡 **Quick Guide**
                1. Set the prompt, the number of target frames, the keyframe images/sketches, etc.
                2. Upload a keyframe image as the first frame (with its index set to 0).
                3. Upload sketches with optional motion masks for controlled generation at the specified frame indices.
                4. Click the *Generate* button to create your cartoon video.
                """)

        max_num_frames = 61
        cond_images_inputs = []
        cond_sketches_inputs = []

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Accordion("Video Settings", open=True):
                    num_frames = gr.Slider(
                        minimum=17, maximum=max_num_frames, value=max_num_frames, step=1,
                        label="🎥 Number of Frames",
                        info="Select the total number of frames for the generated video. Should be 4N+1."
                    )
                    resolution = gr.Radio(
                        choices=["480p", "608p"],
                        value="480p",
                        label="🎥 Resolution",
                        info="Select the resolution for the generated video."
                    )
                    text_prompt = gr.Textbox(
                        label="📝 Text Prompt",
                        placeholder="Enter a description for the video.",
                        info="Describe what you want to generate in the video.",
                        lines=5
                    )
                    cfg_scale = gr.Slider(
                        minimum=1.0, maximum=15.0, value=7.5,
                        label="⚙️ CFG Scale",
                        info="Adjust the classifier-free guidance scale for generation."
                    )
                    sequence_cond_residual_scale = gr.Slider(
                        minimum=0.0, maximum=1.2, value=1.0,
                        label="⚙️ Pos-aware Residual Scale",
                        info="Adjust the residual scale for the position-aware sequence condition."
                    )

            with gr.Column(scale=3):
                with gr.Accordion("Keyframe Image(s)", open=True):
                    num_cond_images = gr.Slider(
                        minimum=1, maximum=4, value=1, step=1,
                        label="🖼️ Number of Keyframe Images",
                        info="Specify how many keyframe color images to use (max 4 images)."
                    )
                    for i in range(4):  # Max 4 condition images
                        with gr.Tab(label=f"Image {i+1}", interactive=i == 0) as tab:
                            gr.Markdown("At least one image is required. \n Each image or sketch will be used to control the cartoon generation at the given frame index.")
                            image_input = gr.Image(
                                label=f"Image {i+1}",
                                type="pil",
                                placeholder=f"Upload a keyframe image {i+1}..."
                            )
                            frame_index_input = gr.Slider(
                                label=f"Frame Index for Image #{i+1}",
                                minimum=0,
                                maximum=max_num_frames - 1,
                                value=i * (max_num_frames - 1) // 3,
                                step=1,
                                info=f"Frame position for Image {i+1} (0 to {max_num_frames-1})"
                            )
                            cond_images_inputs.append((image_input, frame_index_input, tab))

            with gr.Column(scale=3):
                with gr.Accordion("Keyframe Sketch(es)", open=True):
                    num_cond_sketches = gr.Slider(
                        minimum=1, maximum=4, value=1, step=1,
                        label="✏️ Number of Keyframe Sketch(es)",
                        info="Specify how many keyframe sketches to use (max 4 sketches)."
                    )
                    for i in range(4):  # Max 4 condition sketches
                        with gr.Tab(label=f"Sketch {i + 1}", interactive=i == 0) as tab:
                            gr.Markdown("At least one sketch is required. \n You can optionally draw black areas using the brush tool to mark regions where motion can be generated freely.")
                            # Use ImageMask, which allows uploading an image and drawing a mask
                            sketch_input = gr.ImageMask(
                                label=f"Sketch {i + 1} with Motion Mask",
                                type="pil",
                                elem_id=f"sketch_mask_{i + 1}"
                            )
                            # All sketches have a frame index input
                            _frame_index_input = gr.Slider(
                                label=f"Frame Index for Sketch #{i + 1}",
                                minimum=0,
                                maximum=max_num_frames - 1,
                                value=max_num_frames - 1,
                                step=1,
                                info=f"Frame position for Sketch {i + 1} (0 to {max_num_frames-1})"
                            )
                            cond_sketches_inputs.append((sketch_input, _frame_index_input, tab))

        with gr.Row():
            with gr.Column(scale=1):
                # Sample Gallery Section
                with gr.Accordion("🔍 Sample Gallery", open=True):
                    gr.Markdown("Click on any sample image below to load the sample inputs.")
                    sample_gallery = gr.Gallery(
                        value=create_sample_gallery(),
                        label="Sample Examples",
                        show_label=False,
                        elem_id="sample-gallery",
                        columns=3,
                        rows=1,
                        height=200,
                        allow_preview=True,
                        object_fit="contain")

                with gr.Accordion("🛠️ Tools", open=False):
                    tool_input = gr.Image(
                        label="Input Image",
                        type="pil",
                        placeholder="Upload an image."
                    )
                    invert_btn = gr.Button("Invert Colors")
                    invert_btn.click(
                        fn=invert_sketch,
                        inputs=[tool_input],
                        outputs=[tool_input]
                    )

            with gr.Column(scale=1):
                status_text = gr.Textbox(
                    label="📊 Status",
                    value="Ready to generate. Please check your inputs and click Run.",
                    interactive=False,
                    lines=5
                )
                with gr.Accordion("🎬 Generated Video", open=True):
                    output_video = gr.Video(
                        label="Video Output",
                        show_label=True
                    )
                run_button = gr.Button("🚀 Generate Video", variant="primary", size="lg")

        def update_visibility(num_items, num_frames):
            # Update interactivity of the sketch tabs and re-space their frame indices
            updates_images = []
            updates_indices = []
            for i in range(4):
                is_visible = i < num_items
                # is_visible = True
                updates_images.append(gr.update(interactive=is_visible))
                updates_indices.append(gr.update(
                    value=((num_frames - 1) // max(num_items, 1)) * (i + 1),
                    minimum=0,
                    maximum=num_frames - 1,
                ))
            return updates_images + updates_indices

        def update_visibility_images(num_items, num_frames):
            # Update interactivity of the image tabs and re-space their frame indices
            updates_images = []
            updates_indices = []
            for i in range(4):
                is_visible = i < num_items
                updates_images.append(gr.update(interactive=is_visible))
                updates_indices.append(gr.update(
                    value=((num_frames - 1) // max(num_items, 1)) * i,
                    minimum=0,
                    maximum=num_frames - 1,
                ))
            return updates_images + updates_indices

        def update_frame_ranges(num_items_images, num_items_sketches, num_frames):
            """Update the maximum values for all frame index sliders"""
            updates = []
            for i in range(4):  # Images
                updates.append(gr.update(
                    value=((num_frames - 1) // max(num_items_images, 1)) * i,
                    maximum=num_frames - 1
                ))
            for i in range(4):  # Sketches
                updates.append(gr.update(
                    value=((num_frames - 1) // max(num_items_sketches, 1)) * (i + 1),
                    maximum=num_frames - 1))
            return updates

        num_cond_images.change(
            fn=update_visibility_images,
            inputs=[num_cond_images, num_frames],
            outputs=[tab for _, _, tab in cond_images_inputs]
                + [frame_index_input for _, frame_index_input, _ in cond_images_inputs],
        )
        num_cond_sketches.change(
            fn=update_visibility,
            inputs=[num_cond_sketches, num_frames],
            outputs=[tab for _, _, tab in cond_sketches_inputs]
                + [frame_index_input for _, frame_index_input, _ in cond_sketches_inputs],
        )
        num_frames.change(
            fn=update_frame_ranges,
            inputs=[num_cond_images, num_cond_sketches, num_frames],
            outputs=[frame_index_input for _, frame_index_input, _ in cond_images_inputs]
                + [frame_index_input for _, frame_index_input, _ in cond_sketches_inputs]
        )

        def update_resolution(resolution):
            model.update_height_width(checkpoints_by_resolution[resolution]["target_height"],
                                      checkpoints_by_resolution[resolution]["target_width"])
            model.load_tooncomposer_checkpoint(checkpoints_by_resolution[resolution]["checkpoint_path"])
            return gr.update(), gr.update()

        resolution.change(
            fn=update_resolution,
            inputs=[resolution],
            outputs=[output_video, run_button]
        )

        sample_outputs = [
            num_frames,
            text_prompt,
            num_cond_sketches,
            cond_images_inputs[0][0], cond_images_inputs[0][1],      # Image 1
            cond_sketches_inputs[0][0], cond_sketches_inputs[0][1],  # Sketch 1
            cond_sketches_inputs[1][0], cond_sketches_inputs[1][1],  # Sketch 2
            cond_sketches_inputs[2][0], cond_sketches_inputs[2][1],  # Sketch 3
            cond_sketches_inputs[3][0], cond_sketches_inputs[3][1],  # Sketch 4
            output_video,
            status_text
        ]
        sample_gallery.select(
            fn=handle_gallery_select,
            outputs=sample_outputs
        )

        inputs = [num_frames, num_cond_images, num_cond_sketches, text_prompt, cfg_scale, sequence_cond_residual_scale, resolution]

        run_button.click(
            fn=tooncomposer_inference,
            inputs=inputs,
            outputs=[output_video, status_text]
        )

        # Add condition image inputs
        for image_input, frame_index_input, _ in cond_images_inputs:
            inputs.append(image_input)
            inputs.append(frame_index_input)
        # Add sketch inputs (both regular and ImageMask)
        for sketch_input, frame_index_input, _ in cond_sketches_inputs:
            inputs.append(sketch_input)
            inputs.append(frame_index_input)

    iface.launch(server_port=7860, server_name="0.0.0.0",
                 allowed_paths=[os.path.abspath(os.path.join(os.path.dirname(__file__), "gradio_cache")),
                                os.path.abspath(os.path.join(os.path.dirname(__file__), "samples"))])
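# Example local launch (paths and the script name are placeholders):
#   WAN21_I2V_DIR=/models/Wan2.1-I2V-14B-480P TOONCOMPOSER_DIR=/models/ToonComposer \
#       python <this_script>.py --resolution 480p --dtype bfloat16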