"""Helpers for parametric material editing with SDXL: pipeline setup
(depth ControlNet + IP-Adapter/InstantStyle), control-MLP loading, depth and
mask preprocessing, and the parametric-edit / texture-blend entry points."""

import os
from typing import Dict, Optional

import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image, ImageChops, ImageEnhance
from rembg import new_session, remove
from transformers import DPTForDepthEstimation, DPTImageProcessor

from ip_adapter_instantstyle import IPAdapterXL
from ip_adapter_instantstyle.utils import register_cross_attention_hook
from parametric_control_mlp import control_mlp

file_dir = os.path.dirname(os.path.abspath(__file__))

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"

# Cache for rembg sessions
_session_cache = None

CONTROL_MLPS = ["metallic", "roughness", "transparency", "glow"]


def get_session():
    global _session_cache
    if _session_cache is None:
        _session_cache = new_session()
    return _session_cache


def get_device():
    return "cuda" if torch.cuda.is_available() else "cpu"


def setup_control_mlps(
    features: int = 1024,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
) -> Dict[str, torch.nn.Module]:
    ret = {}
    if device is None:
        device = get_device()
    print(f"Setting up control MLPs on {device}")
    for mlp in CONTROL_MLPS:
        ret[mlp] = setup_control_mlp(mlp, features, device, dtype)
    return ret


def setup_control_mlp(
    material_parameter: str,
    features: int = 1024,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
):
    if device is None:
        device = get_device()
    net = control_mlp(features)
    net.load_state_dict(
        torch.load(
            os.path.join(file_dir, f"model_weights/{material_parameter}.pt"),
            map_location=device,
        )
    )
    net.to(device, dtype=dtype)
    net.eval()
    return net


def download_ip_adapter():
    repo_id = "h94/IP-Adapter"
    target_folders = ["models/", "sdxl_models/"]
    local_dir = file_dir

    # Skip the download when both target folders already exist and are non-empty
    folders_exist = all(
        os.path.exists(os.path.join(local_dir, folder)) for folder in target_folders
    )
    if folders_exist:
        folders_empty = any(
            len(os.listdir(os.path.join(local_dir, folder))) == 0
            for folder in target_folders
        )
        if not folders_empty:
            print("IP-Adapter files already downloaded. Skipping download.")
            return

    # List all files in the repo and keep only those in the desired folders
    all_files = list_repo_files(repo_id)
    filtered_files = [
        f for f in all_files if any(f.startswith(folder) for folder in target_folders)
    ]

    # Download each file
    for file_path in filtered_files:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print(f"Downloaded: {file_path} to {local_path}")


def setup_pipeline(
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
):
    if device is None:
        device = get_device()
    print(f"Setting up pipeline on {device}")
    download_ip_adapter()

    # Restrict IP-Adapter conditioning to a single cross-attention block
    # (InstantStyle-style targeting): up_blocks.0.attentions.1
    cur_block = ("up", 0, 1)
    controlnet = ControlNetModel.from_pretrained(
        controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=dtype
    ).to(device)
    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        use_safetensors=True,
        torch_dtype=dtype,
        add_watermarker=False,
    ).to(device)
    pipe.unet = register_cross_attention_hook(pipe.unet)

    block_name = f"{cur_block[0]}_blocks.{cur_block[1]}.attentions.{cur_block[2]}"
    print(f"Targeting block {block_name}")
    return IPAdapterXL(
        pipe,
        os.path.join(file_dir, image_encoder_path),
        os.path.join(file_dir, ip_ckpt),
        device,
        target_blocks=[block_name],
    )
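
# A minimal usage sketch for the setup helpers above. Assumes the repo's
# control-MLP weights (model_weights/*.pt) are present and that a CUDA GPU
# with enough memory for SDXL is available; the first call also downloads the
# IP-Adapter files from the Hugging Face Hub into this directory:
#
#     ip_model = setup_pipeline()   # SDXL inpaint + depth ControlNet + IP-Adapter
#     mlps = setup_control_mlps()   # {"metallic": ..., "roughness": ..., ...}
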
def get_dpt_model(device: Optional[str] = None, dtype: torch.dtype = torch.float16):
    if device is None:
        device = get_device()
    image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
    model.to(device, dtype=dtype)
    model.eval()
    return model, image_processor


def run_dpt_depth(
    image: Image.Image, model, processor, device: Optional[str] = None
) -> Image.Image:
    """Run DPT depth estimation on an image."""
    if device is None:
        device = get_device()

    # Prepare image
    inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)

    # Get depth prediction
    with torch.no_grad():
        depth_map = model(**inputs).predicted_depth

    # Normalize to [0, 1], then scale to the 0-255 byte range
    depth_map = (depth_map - depth_map.min()) / (
        depth_map.max() - depth_map.min() + 1e-7
    )
    depth_map = depth_map.clip(0, 1) * 255

    # Convert to a PIL image at the SDXL working resolution
    depth_map = depth_map.squeeze().cpu().numpy().astype(np.uint8)
    return Image.fromarray(depth_map).resize((1024, 1024))


def prepare_mask(image: Image.Image) -> Image.Image:
    """Prepare a binary foreground mask using rembg background removal."""
    rm_bg = remove(image, session=get_session())
    target_mask = (
        rm_bg.convert("RGB")
        .point(lambda x: 0 if x < 1 else 255)
        .convert("L")
        .convert("RGB")
    )
    return target_mask.resize((1024, 1024))


def prepare_init_image(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Prepare the initial image for inpainting: grayscale inside the mask,
    original content outside."""
    # Create grayscale version
    gray_image = image.convert("L").convert("RGB")
    gray_image = ImageEnhance.Brightness(gray_image).enhance(1.0)

    # Invert the mask to select the background
    invert_mask = ImageChops.invert(mask)

    # Combine: grayscale foreground with the original background
    grayscale_img = ImageChops.darker(gray_image, mask)
    img_black_mask = ImageChops.darker(image, invert_mask)
    init_img = ImageChops.lighter(img_black_mask, grayscale_img)

    return init_img.resize((1024, 1024))


def run_parametric_control(
    ip_model,
    target_image: Image.Image,
    edit_mlps: Dict[torch.nn.Module, float],
    texture_image: Optional[Image.Image] = None,
    num_inference_steps: int = 30,
    seed: int = 42,
    depth_map: Optional[Image.Image] = None,
    mask: Optional[Image.Image] = None,
) -> Image.Image:
    """Run parametric material edits (e.g. metallic, roughness) driven by
    control MLPs."""
    # Get depth map
    if depth_map is None:
        print("No depth map provided, running DPT depth estimation")
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))

    # Prepare mask and init image
    if mask is None:
        print("No mask provided, preparing mask")
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))

    print("Preparing initial image")
    if texture_image is None:
        texture_image = target_image
    init_img = prepare_init_image(target_image, mask)

    # Generate edit
    print("Generating parametric edit")
    images = ip_model.generate_parametric_edits(
        texture_image,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_mlps=edit_mlps,
        strength=1.0,
    )
    return images[0]
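
# A minimal sketch of a single parametric edit. Note that `edit_mlps` maps the
# loaded MLP modules themselves (not parameter names) to edit strengths; the
# file name and strength values below are illustrative assumptions, not
# calibrated settings:
#
#     ip_model = setup_pipeline()
#     mlps = setup_control_mlps()
#     target = Image.open("input.png").convert("RGB").resize((1024, 1024))
#     edited = run_parametric_control(
#         ip_model,
#         target,
#         edit_mlps={mlps["metallic"]: 0.8, mlps["roughness"]: 0.2},
#     )
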
def run_blend(
    ip_model,
    target_image: Image.Image,
    texture_image1: Image.Image,
    texture_image2: Image.Image,
    edit_strength: float = 0.0,
    num_inference_steps: int = 20,
    seed: int = 1,
    depth_map: Optional[Image.Image] = None,
    mask: Optional[Image.Image] = None,
) -> Image.Image:
    """Run blending between two texture images."""
    # Get depth map
    if depth_map is None:
        print("No depth map provided, running DPT depth estimation")
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))

    # Prepare mask and init image
    if mask is None:
        print("No mask provided, preparing mask")
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))

    print("Preparing initial image")
    init_img = prepare_init_image(target_image, mask)

    # Generate edit
    print("Generating edit")
    images = ip_model.generate_edit(
        start_image=texture_image1,
        pil_image=texture_image1,
        pil_image2=texture_image2,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_strength=edit_strength,
        clip_strength=1.0,
        strength=1.0,
    )
    return images[0]
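

if __name__ == "__main__":
    # A minimal end-to-end sketch, assuming an `example.png` next to this file
    # and all checkpoints available locally. The edited parameter ("glow") and
    # strength (0.7) are illustrative only. Inputs are resized to 1024x1024 up
    # front because prepare_init_image combines them with the 1024x1024 mask.
    ip_model = setup_pipeline()
    mlps = setup_control_mlps()
    target = (
        Image.open(os.path.join(file_dir, "example.png"))
        .convert("RGB")
        .resize((1024, 1024))
    )
    result = run_parametric_control(ip_model, target, edit_mlps={mlps["glow"]: 0.7})
    result.save(os.path.join(file_dir, "example_edit.png"))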