# Hugging Face Space: runs on ZeroGPU.
import os
from typing import Dict, Optional

import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image, ImageChops, ImageEnhance
from rembg import new_session, remove
from transformers import DPTForDepthEstimation, DPTImageProcessor

from ip_adapter_instantstyle import IPAdapterXL
from ip_adapter_instantstyle.utils import register_cross_attention_hook
from parametric_control_mlp import control_mlp
file_dir = os.path.dirname(os.path.abspath(__file__))

# Hub repos and local checkpoint paths.
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"

# Cache for the rembg session so its model is only loaded once.
_session_cache = None

# Material parameters, one pretrained control MLP per entry.
CONTROL_MLPS = ["metallic", "roughness", "transparency", "glow"]
def get_session():
    """Return the cached rembg session, creating it on first use."""
    global _session_cache
    if _session_cache is None:
        _session_cache = new_session()
    return _session_cache


def get_device():
    """Pick CUDA when available, otherwise fall back to CPU."""
    return "cuda" if torch.cuda.is_available() else "cpu"
def setup_control_mlps(
    features: int = 1024,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
) -> Dict[str, torch.nn.Module]:
    """Load one control MLP per material parameter, keyed by parameter name."""
    ret = {}
    if device is None:
        device = get_device()
    print(f"Setting up control MLPs on {device}")
    for mlp in CONTROL_MLPS:
        ret[mlp] = setup_control_mlp(mlp, features, device, dtype)
    return ret
def setup_control_mlp(
    material_parameter: str,
    features: int = 1024,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
):
    """Load the pretrained control MLP for a single material parameter."""
    if device is None:
        device = get_device()
    net = control_mlp(features)
    net.load_state_dict(
        torch.load(
            os.path.join(file_dir, f"model_weights/{material_parameter}.pt"),
            map_location=device,
        )
    )
    net.to(device, dtype=dtype)
    net.eval()
    return net
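
# Usage sketch (illustrative, not part of the original module): load all
# control MLPs once and look one up by parameter name. Assumes the
# per-parameter weight files exist under model_weights/.
def _example_load_control_mlps() -> torch.nn.Module:
    mlps = setup_control_mlps()
    return mlps["metallic"]  # eval-mode MLP on the auto-selected device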
def download_ip_adapter():
    """Download IP-Adapter weights from the Hub unless they are already present."""
    repo_id = "h94/IP-Adapter"
    target_folders = ["models/", "sdxl_models/"]
    local_dir = file_dir
    # Skip the download only if both target folders exist and are non-empty.
    folders_exist = all(
        os.path.exists(os.path.join(local_dir, folder)) for folder in target_folders
    )
    if folders_exist:
        folders_empty = any(
            len(os.listdir(os.path.join(local_dir, folder))) == 0
            for folder in target_folders
        )
        if not folders_empty:
            print("IP-Adapter files already downloaded. Skipping download.")
            return
    # List every file in the repo and keep those under the target folders.
    all_files = list_repo_files(repo_id)
    filtered_files = [
        f for f in all_files if any(f.startswith(folder) for folder in target_folders)
    ]
    # Download each file into the local directory.
    for file_path in filtered_files:
        local_path = hf_hub_download(
            repo_id=repo_id,
            filename=file_path,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
        )
        print(f"Downloaded: {file_path} to {local_path}")
def setup_pipeline(
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float16,
):
    """Build the SDXL ControlNet inpainting pipeline wrapped in an InstantStyle IP-Adapter."""
    if device is None:
        device = get_device()
    print(f"Setting up pipeline on {device}")
    download_ip_adapter()
    # InstantStyle-style injection: target a single up-block attention layer.
    cur_block = ("up", 0, 1)
    controlnet = ControlNetModel.from_pretrained(
        controlnet_path, variant="fp16", use_safetensors=True, torch_dtype=dtype
    ).to(device)
    pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        use_safetensors=True,
        torch_dtype=dtype,
        add_watermarker=False,
    ).to(device)
    pipe.unet = register_cross_attention_hook(pipe.unet)
    block_name = f"{cur_block[0]}_blocks.{cur_block[1]}.attentions.{cur_block[2]}"
    print(f"Using target block {block_name}")
    return IPAdapterXL(
        pipe,
        os.path.join(file_dir, image_encoder_path),
        os.path.join(file_dir, ip_ckpt),
        device,
        target_blocks=[block_name],
    )
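
# Usage sketch (illustrative): the returned IPAdapterXL is the `ip_model`
# argument expected by run_parametric_control() and run_blend() below.
def _example_build_ip_model():
    return setup_pipeline()  # downloads IP-Adapter weights on first run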
def get_dpt_model(device: Optional[str] = None, dtype: torch.dtype = torch.float16):
    """Load the Intel DPT-Hybrid MiDaS depth estimator and its image processor."""
    if device is None:
        device = get_device()
    image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
    model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
    model.to(device, dtype=dtype)
    model.eval()
    return model, image_processor
def run_dpt_depth(
    image: Image.Image, model, processor, device: Optional[str] = None
) -> Image.Image:
    """Run DPT depth estimation on an image and return a 1024x1024 depth map."""
    if device is None:
        device = get_device()
    # Prepare the image batch on the model's device and dtype.
    inputs = processor(images=image, return_tensors="pt").to(device, dtype=model.dtype)
    # Predict depth without tracking gradients.
    with torch.no_grad():
        depth_map = model(**inputs).predicted_depth
    # Normalize to the 0-1 range (the epsilon guards against a constant map).
    depth_map = (depth_map - depth_map.min()) / (
        depth_map.max() - depth_map.min() + 1e-7
    )
    depth_map = depth_map.clip(0, 1) * 255
    # Convert to an 8-bit PIL image.
    depth_map = depth_map.squeeze().cpu().numpy().astype(np.uint8)
    return Image.fromarray(depth_map).resize((1024, 1024))
def prepare_mask(image: Image.Image) -> Image.Image:
    """Build a binary foreground mask from an image using rembg."""
    rm_bg = remove(image, session=get_session())
    # rembg zeroes out background pixels; threshold so any non-black pixel
    # becomes white, yielding a binary foreground mask.
    target_mask = (
        rm_bg.convert("RGB")
        .point(lambda x: 0 if x < 1 else 255)
        .convert("L")
        .convert("RGB")
    )
    return target_mask.resize((1024, 1024))
def prepare_init_image(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Prepare the initial image for inpainting: grayscale foreground, original background."""
    # Grayscale version of the full image (the brightness factor of 1.0 is a no-op).
    gray_image = image.convert("L").convert("RGB")
    gray_image = ImageEnhance.Brightness(gray_image).enhance(1.0)
    # Invert the mask to select the background.
    invert_mask = ImageChops.invert(mask)
    # Composite: grayscale inside the mask, original pixels outside it.
    grayscale_img = ImageChops.darker(gray_image, mask)
    img_black_mask = ImageChops.darker(image, invert_mask)
    init_img = ImageChops.lighter(img_black_mask, grayscale_img)
    return init_img.resize((1024, 1024))
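
# Illustrative sketch: precompute the optional inputs that run_parametric_control()
# and run_blend() otherwise derive internally, so they can be reused across calls.
def _example_precompute_inputs(image: Image.Image):
    mask = prepare_mask(image)
    init_img = prepare_init_image(image, mask)
    model, processor = get_dpt_model()
    depth = run_dpt_depth(image, model, processor)
    return mask, init_img, depth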
def run_parametric_control(
    ip_model,
    target_image: Image.Image,
    edit_mlps: Dict[torch.nn.Module, float],
    texture_image: Optional[Image.Image] = None,
    num_inference_steps: int = 30,
    seed: int = 42,
    depth_map: Optional[Image.Image] = None,
    mask: Optional[Image.Image] = None,
) -> Image.Image:
    """Run a parametric material edit (e.g. metallic, roughness) on the target image.

    `edit_mlps` maps each control MLP to its edit strength.
    """
    # Get the depth map, estimating it when none is provided.
    if depth_map is None:
        print("No depth map provided, running DPT depth estimation")
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))
    # Prepare the mask and the init image.
    if mask is None:
        print("No mask provided, preparing mask")
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))
    print("Preparing initial image")
    if texture_image is None:
        texture_image = target_image
    init_img = prepare_init_image(target_image, mask)
    # Generate the edit.
    print("Generating parametric edit")
    images = ip_model.generate_parametric_edits(
        texture_image,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_mlps=edit_mlps,
        strength=1.0,
    )
    return images[0]
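
# Usage sketch (illustrative): edit two material parameters at once. The
# strength values below are placeholders, not values from the original module.
def _example_parametric_edit(ip_model, target_image: Image.Image) -> Image.Image:
    mlps = setup_control_mlps()
    edits = {mlps["metallic"]: 0.8, mlps["roughness"]: 0.2}
    return run_parametric_control(ip_model, target_image, edits)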
def run_blend(
    ip_model,
    target_image: Image.Image,
    texture_image1: Image.Image,
    texture_image2: Image.Image,
    edit_strength: float = 0.0,
    num_inference_steps: int = 20,
    seed: int = 1,
    depth_map: Optional[Image.Image] = None,
    mask: Optional[Image.Image] = None,
) -> Image.Image:
    """Run blending between two texture images."""
    # Get the depth map, estimating it when none is provided.
    if depth_map is None:
        print("No depth map provided, running DPT depth estimation")
        model, processor = get_dpt_model()
        depth_map = run_dpt_depth(target_image, model, processor)
    else:
        depth_map = depth_map.resize((1024, 1024))
    # Prepare the mask and the init image.
    if mask is None:
        print("No mask provided, preparing mask")
        mask = prepare_mask(target_image)
    else:
        mask = mask.resize((1024, 1024))
    print("Preparing initial image")
    init_img = prepare_init_image(target_image, mask)
    # Generate the edit.
    print("Generating edit")
    images = ip_model.generate_edit(
        start_image=texture_image1,
        pil_image=texture_image1,
        pil_image2=texture_image2,
        image=init_img,
        control_image=depth_map,
        mask_image=mask,
        controlnet_conditioning_scale=1.0,
        num_samples=1,
        num_inference_steps=num_inference_steps,
        seed=seed,
        edit_strength=edit_strength,
        clip_strength=1.0,
        strength=1.0,
    )
    return images[0]
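
# Minimal end-to-end sketch (illustrative; the original module defines no entry
# point, and the file names and edit strength below are placeholders).
if __name__ == "__main__":
    ip_model = setup_pipeline()
    target = Image.open("target.png").convert("RGB").resize((1024, 1024))
    texture_a = Image.open("texture_a.png").convert("RGB")
    texture_b = Image.open("texture_b.png").convert("RGB")
    # Blend halfway between the two texture references.
    blended = run_blend(ip_model, target, texture_a, texture_b, edit_strength=0.5)
    blended.save("blended.png")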