import torch
from typing import Optional, Tuple, List

from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)


class TKGDMNoiseOptimizer:
    """
    TKG-DM: Training-free Chroma Key Content Generation Diffusion Model

    Implementation of non-reactive noise optimization for background control.
    """

    def __init__(self, device: str = "cuda", dtype: Optional[torch.dtype] = None):
        self.device = device
        self.dtype = dtype or (torch.float16 if device == "cuda" else torch.float32)

    def calculate_initial_ratio(self, noise_tensor: torch.Tensor, channel: int) -> float:
        """Calculate the initial positive ratio for a specific channel."""
        channel_data = noise_tensor[:, channel, :, :]
        positive_pixels = (channel_data > 0).sum().float()
        total_pixels = channel_data.numel()
        return (positive_pixels / total_pixels).item()

    def optimize_channel_shift(self,
                               noise_tensor: torch.Tensor,
                               channel: int,
                               target_shift: float,
                               max_iterations: int = 100,
                               tolerance: float = 1e-4) -> float:
        """
        Optimize the channel mean shift to achieve a target positive ratio.

        Args:
            noise_tensor: Initial noise tensor [B, C, H, W]
            channel: Latent channel index to optimize
            target_shift: Desired change in positive ratio
            max_iterations: Maximum bisection iterations
            tolerance: Convergence tolerance on the positive ratio

        Returns:
            Optimal channel shift value
        """
        initial_ratio = self.calculate_initial_ratio(noise_tensor, channel)
        target_ratio = initial_ratio + target_shift

        # Bisection search: the positive ratio increases monotonically with the shift.
        delta_min, delta_max = -10.0, 10.0
        delta_optimal = 0.0

        for _ in range(max_iterations):
            delta_mid = (delta_min + delta_max) / 2

            # Measure the positive ratio after applying the candidate shift.
            shifted_tensor = noise_tensor.clone()
            shifted_tensor[:, channel, :, :] += delta_mid
            current_ratio = self.calculate_initial_ratio(shifted_tensor, channel)

            if abs(current_ratio - target_ratio) < tolerance:
                delta_optimal = delta_mid
                break

            if current_ratio < target_ratio:
                delta_min = delta_mid
            else:
                delta_max = delta_mid

            delta_optimal = delta_mid

        return delta_optimal
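
    # A closed-form alternative (illustrative sketch, not part of the original
    # method): for i.i.d. standard-normal noise, P(x + delta > 0) = Phi(delta),
    # so the shift that hits a target positive ratio r is delta = Phi^{-1}(r).
    # It ignores sampling noise in the empirical ratio; the helper name is
    # hypothetical.
    @staticmethod
    def _closed_form_channel_shift(initial_ratio: float, target_shift: float) -> float:
        """Sketch: invert the standard-normal CDF instead of bisecting."""
        target_ratio = min(max(initial_ratio + target_shift, 1e-6), 1.0 - 1e-6)
        # Phi^{-1}(r) = sqrt(2) * erfinv(2r - 1)
        return float(torch.erfinv(torch.tensor(2.0 * target_ratio - 1.0)) * (2.0 ** 0.5))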

    def create_bounding_box_mask(self,
                                 height: int,
                                 width: int,
                                 bounding_boxes: List[Tuple[float, float, float, float]],
                                 sigma: Optional[float] = None) -> torch.Tensor:
        """
        Create a binary occupancy map from axis-aligned bounding boxes, then
        soften it with a Gaussian blur.

        Args:
            height, width: Mask dimensions
            bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates in [0, 1]
            sigma: Gaussian blur standard deviation (auto-calculated if None)

        Returns:
            Soft transition mask M_blur with values in [0, 1]
        """
        binary_mask = torch.zeros((height, width), device=self.device)

        for x1, y1, x2, y2 in bounding_boxes:
            # Convert normalized coordinates to pixel indices.
            px1 = int(x1 * width)
            py1 = int(y1 * height)
            px2 = int(x2 * width)
            py2 = int(y2 * height)

            # Order the corners and clamp them to the mask bounds.
            px1, px2 = max(0, min(px1, px2)), min(width, max(px1, px2))
            py1, py2 = max(0, min(py1, py2)), min(height, max(py1, py2))

            binary_mask[py1:py2, px1:px2] = 1.0

        if sigma is None:
            # Default blur: 2% of the smaller mask dimension.
            sigma = min(height, width) * 0.02

        # conv2d expects a 4D tensor [B, C, H, W].
        mask_4d = binary_mask.unsqueeze(0).unsqueeze(0)

        # Kernel wide enough to cover +/- 3 sigma, forced to an odd size.
        kernel_size = int(6 * sigma + 1)
        if kernel_size % 2 == 0:
            kernel_size += 1

        blur_mask = self._gaussian_blur_2d(mask_4d, kernel_size, sigma)

        return blur_mask.squeeze(0).squeeze(0)

    def _gaussian_blur_2d(self, tensor: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
        """Apply a separable 2D Gaussian blur to a 4D tensor."""
        # Build a normalized 1D Gaussian kernel centered on the window.
        coords = torch.arange(kernel_size, dtype=tensor.dtype, device=tensor.device)
        coords = coords - kernel_size // 2

        kernel = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        kernel = kernel / kernel.sum()
        kernel = kernel.view(1, 1, 1, -1)

        # Horizontal pass.
        padding = kernel_size // 2
        blurred = torch.nn.functional.conv2d(
            tensor, kernel, padding=(0, padding)
        )

        # Vertical pass with the transposed kernel.
        kernel_t = kernel.transpose(-1, -2)
        blurred = torch.nn.functional.conv2d(
            blurred, kernel_t, padding=(padding, 0)
        )

        return blurred
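
    # Note on the blur above: convolving with a 1D kernel horizontally and
    # then vertically is mathematically identical to a full 2D Gaussian
    # convolution (the 2D Gaussian is separable), but costs O(K) instead of
    # O(K^2) kernel taps per pixel for a K-wide kernel.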

    def create_gaussian_mask(self,
                             height: int,
                             width: int,
                             center_x: float,
                             center_y: float,
                             sigma: float) -> torch.Tensor:
        """
        Create a 2D Gaussian mask for noise blending.

        Args:
            height, width: Mask dimensions
            center_x, center_y: Gaussian center (normalized 0-1)
            sigma: Gaussian spread parameter (in normalized units)

        Returns:
            2D Gaussian mask tensor with a peak value of 1 at the center
        """
        y_coords = torch.linspace(0, 1, height).view(-1, 1)
        x_coords = torch.linspace(0, 1, width).view(1, -1)

        # Distances from the Gaussian center in normalized coordinates.
        dx = x_coords - center_x
        dy = y_coords - center_y

        mask = torch.exp(-((dx ** 2 + dy ** 2) / (2 * sigma ** 2)))

        return mask.to(self.device)

    def apply_color_shift(self,
                          noise_tensor: torch.Tensor,
                          target_color: List[float],
                          target_shifts: List[float]) -> torch.Tensor:
        """
        Apply per-channel mean shifts to a latent noise tensor.

        Args:
            noise_tensor: Input noise [B, C, H, W] (typically 4 channels for SD latents)
            target_color: Target RGB color [R, G, B] in the 0-1 range
                (currently unused here; callers derive target_shifts from it)
            target_shifts: Shift amounts for each latent channel

        Returns:
            Color-shifted noise tensor
        """
        shifted_noise = noise_tensor.clone()

        num_channels = min(len(target_shifts), noise_tensor.shape[1])

        for channel in range(num_channels):
            # Skip channels with effectively zero shift.
            if abs(target_shifts[channel]) > 1e-6:
                delta = self.optimize_channel_shift(
                    noise_tensor, channel, target_shifts[channel]
                )
                shifted_noise[:, channel, :, :] += delta

        return shifted_noise

    def blend_noise_with_mask(self,
                              original_noise: torch.Tensor,
                              shifted_noise: torch.Tensor,
                              mask: torch.Tensor) -> torch.Tensor:
        """
        Blend original and shifted noise using the space-aware formula from the paper:
        ε_masked = ε + M_blur(ε_shifted - ε)

        Regions where the mask is 1 receive the shifted noise; regions where
        it is 0 keep the original noise.

        Args:
            original_noise: Original Gaussian noise ε [B, C, H, W]
            shifted_noise: Non-reactive shifted noise ε_shifted [B, C, H, W]
            mask: Soft transition mask M_blur [H, W] with values in [0, 1]

        Returns:
            Composite noise tensor ε_masked
        """
        # Broadcast the [H, W] mask across the batch and channel dimensions.
        mask_expanded = mask.unsqueeze(0).unsqueeze(0).to(original_noise.dtype)
        mask_expanded = mask_expanded.expand_as(original_noise)

        blended_noise = original_noise + mask_expanded * (shifted_noise - original_noise)

        return blended_noise
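
    # Quick sanity check of the blend at the mask extremes (illustrative only):
    #
    #     eps = torch.zeros(1, 4, 8, 8)
    #     eps_shift = torch.ones(1, 4, 8, 8)
    #     # mask = 1 everywhere: blend = eps + 1 * (eps_shift - eps) = eps_shift
    #     # mask = 0 everywhere: blend = eps + 0 * (eps_shift - eps) = eps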

    def generate_chroma_key_noise(self,
                                  shape: Tuple[int, int, int, int],
                                  background_color: List[float] = [0.0, 1.0, 0.0],
                                  foreground_center: Tuple[float, float] = (0.5, 0.5),
                                  foreground_size: float = 0.3,
                                  target_shift_percent: float = 0.07) -> torch.Tensor:
        """
        Generate optimized noise for chroma key content generation using the TKG-DM method.

        Handles latent space channels rather than RGB:
        - Channels 0, 3: luminance and color information
        - Channels 1, 2: color channels (positive = pink/yellow, negative = red/blue)

        Args:
            shape: Noise tensor shape [B, C, H, W] (typically [1, 4, 64, 64] for SD)
            background_color: Target background RGB color (0-1)
            foreground_center: Center of foreground object (x, y), normalized
            foreground_size: Size of foreground region (sigma parameter)
            target_shift_percent: Target shift percentage (±7% as per the paper)

        Returns:
            Optimized noise tensor for chroma key generation
        """
        batch_size, channels, height, width = shape

        original_noise = torch.randn(shape, device=self.device, dtype=self.dtype)

        # Derive a per-channel shift direction from the requested background color.
        target_shifts = []

        r, g, b = background_color[0], background_color[1], background_color[2]

        for c in range(channels):
            initial_ratio = self.calculate_initial_ratio(original_noise, c)

            if c == 0:
                # Heuristic: bright targets shift positive, dark targets negative.
                brightness = (r + g + b) / 3.0
                target_shift = target_shift_percent if brightness > 0.5 else -target_shift_percent

            elif c == 1:
                # Heuristic: green/yellow-dominant targets shift positive,
                # red/blue-dominant targets shift negative.
                if g > 0.7 or (r > 0.7 and g > 0.7):
                    target_shift = target_shift_percent
                elif r > 0.7 or b > 0.7:
                    target_shift = -target_shift_percent
                else:
                    target_shift = 0.0

            elif c == 2:
                # Heuristic: red-dominant shifts negative, green-dominant positive.
                if r > 0.7 and g < 0.3:
                    target_shift = -target_shift_percent
                elif g > 0.7 and r < 0.3:
                    target_shift = target_shift_percent
                else:
                    target_shift = 0.0

            elif c == 3:
                # Heuristic: blue shifts negative, yellow (red + green) positive.
                if b > 0.7:
                    target_shift = -target_shift_percent
                elif r > 0.7 and g > 0.7:
                    target_shift = target_shift_percent * 0.8
                else:
                    target_shift = 0.0

            else:
                target_shift = 0.0

            target_shifts.append(target_shift)
            print(f"Latent Channel {c}: Initial ratio={initial_ratio:.4f}, Target shift={target_shift:+.1%}")

        shifted_noise = self.apply_color_shift(original_noise, background_color, target_shifts)

        mask = self.create_gaussian_mask(
            height, width,
            foreground_center[0], foreground_center[1],
            foreground_size
        )

        # The Gaussian peaks at the foreground center, so invert it before
        # blending: the foreground keeps the original (reactive) noise while
        # the periphery receives the shifted, background-inducing noise.
        optimized_noise = self.blend_noise_with_mask(original_noise, shifted_noise, 1.0 - mask)

        return optimized_noise
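
    # Worked example of the heuristic above for a pure green screen,
    # background_color=[0.0, 1.0, 0.0] with target_shift_percent=0.07:
    #   channel 0: brightness = 1/3 <= 0.5       -> shift -0.07
    #   channel 1: g > 0.7                       -> shift +0.07
    #   channel 2: g > 0.7 and r < 0.3           -> shift +0.07
    #   channel 3: neither blue nor yellow case  -> shift  0.0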

    def generate_direct_channel_noise(self,
                                      shape: Tuple[int, int, int, int],
                                      channel_shifts: List[float],
                                      foreground_center: Tuple[float, float] = (0.5, 0.5),
                                      foreground_size: float = 0.3,
                                      target_shift_percent: float = 0.07) -> torch.Tensor:
        """
        Generate optimized noise with direct channel shift control.

        Args:
            shape: Noise tensor shape [B, C, H, W]
            channel_shifts: Direct shift values for each channel in [-1, 1]
            foreground_center: Center of foreground object (x, y), normalized
            foreground_size: Size of foreground region (sigma parameter)
            target_shift_percent: Base shift percentage (±7% as per the paper)

        Returns:
            Optimized noise tensor with direct channel control
        """
        batch_size, channels, height, width = shape

        original_noise = torch.randn(shape, device=self.device, dtype=self.dtype)

        # Scale each user-specified shift by the base shift percentage.
        target_shifts = []
        for c in range(min(channels, len(channel_shifts))):
            user_shift = channel_shifts[c]
            target_shift = user_shift * target_shift_percent
            target_shifts.append(target_shift)

            initial_ratio = self.calculate_initial_ratio(original_noise, c)
            print(f"Direct Channel {c}: User shift={user_shift:+.2f}, Target shift={target_shift:+.1%}, Initial ratio={initial_ratio:.4f}")

        # Pad with zero shifts for any remaining channels.
        while len(target_shifts) < channels:
            target_shifts.append(0.0)

        shifted_noise = self.apply_color_shift(original_noise, [0.0, 0.0, 0.0], target_shifts)

        mask = self.create_gaussian_mask(
            height, width,
            foreground_center[0], foreground_center[1],
            foreground_size
        )

        # As in generate_chroma_key_noise, invert the Gaussian so the
        # foreground keeps the original noise.
        optimized_noise = self.blend_noise_with_mask(original_noise, shifted_noise, 1.0 - mask)

        return optimized_noise

    def generate_space_aware_noise(self,
                                   shape: Tuple[int, int, int, int],
                                   bounding_boxes: List[Tuple[float, float, float, float]],
                                   channel_shifts: Optional[List[float]] = None,
                                   background_color: Optional[List[float]] = None,
                                   target_shift_percent: float = 0.07,
                                   blur_sigma: Optional[float] = None) -> torch.Tensor:
        """
        Generate space-aware noise with reserved bounding boxes as per section 2.2.
        The boxed regions receive the shifted (non-reactive) noise; everything
        outside them keeps the original noise.

        Args:
            shape: Noise tensor shape [B, C, H, W]
            bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates for reserved regions
            channel_shifts: Direct channel shifts (optional)
            background_color: RGB background color for shift calculation (fallback)
            target_shift_percent: Base shift percentage (±7% as per the paper)
            blur_sigma: Gaussian blur sigma (auto-calculated if None)

        Returns:
            Space-aware composite noise tensor ε_masked
        """
        batch_size, channels, height, width = shape

        epsilon = torch.randn(shape, device=self.device, dtype=self.dtype)

        if channel_shifts is not None:
            # Direct control: scale user shifts by the base percentage.
            target_shifts = []
            for c in range(min(channels, len(channel_shifts))):
                user_shift = channel_shifts[c]
                target_shift = user_shift * target_shift_percent
                target_shifts.append(target_shift)

            # Pad with zero shifts for any remaining channels.
            while len(target_shifts) < channels:
                target_shifts.append(0.0)

            epsilon_shifted = self.apply_color_shift(epsilon, [0.0, 0.0, 0.0], target_shifts)

        elif background_color is not None:
            # Fallback: derive the shifts from the requested background color.
            epsilon_shifted = self._generate_color_shifted_noise(epsilon, background_color, target_shift_percent)

        else:
            raise ValueError("Either channel_shifts or background_color must be provided")

        # With no reserved boxes there is nothing to blend; return the plain noise.
        if not bounding_boxes:
            return epsilon

        mask_blur = self.create_bounding_box_mask(height, width, bounding_boxes, blur_sigma)

        epsilon_masked = self.blend_noise_with_mask(epsilon, epsilon_shifted, mask_blur)

        print(f"🔲 Space-aware noise: {len(bounding_boxes)} reserved boxes, sigma={blur_sigma or 'auto'}")

        return epsilon_masked

    def _generate_color_shifted_noise(self,
                                      epsilon: torch.Tensor,
                                      background_color: List[float],
                                      target_shift_percent: float) -> torch.Tensor:
        """Helper: derive per-channel shifts from a background color (mirrors generate_chroma_key_noise)."""
        channels = epsilon.shape[1]
        r, g, b = background_color[0], background_color[1], background_color[2]

        target_shifts = []
        for c in range(channels):
            if c == 0:
                brightness = (r + g + b) / 3.0
                target_shift = target_shift_percent if brightness > 0.5 else -target_shift_percent
            elif c == 1:
                if g > 0.7 or (r > 0.7 and g > 0.7):
                    target_shift = target_shift_percent
                elif r > 0.7 or b > 0.7:
                    target_shift = -target_shift_percent
                else:
                    target_shift = 0.0
            elif c == 2:
                if r > 0.7 and g < 0.3:
                    target_shift = -target_shift_percent
                elif g > 0.7 and r < 0.3:
                    target_shift = target_shift_percent
                else:
                    target_shift = 0.0
            elif c == 3:
                if b > 0.7:
                    target_shift = -target_shift_percent
                elif r > 0.7 and g > 0.7:
                    target_shift = target_shift_percent * 0.8
                else:
                    target_shift = 0.0
            else:
                target_shift = 0.0

            target_shifts.append(target_shift)

        return self.apply_color_shift(epsilon, background_color, target_shifts)
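

# A minimal usage sketch (illustrative, not part of the original module):
# builds chroma-key noise for an SD-1.5-sized latent on CPU. The function
# name below is hypothetical.
def _demo_noise_optimizer() -> None:
    optimizer = TKGDMNoiseOptimizer(device="cpu", dtype=torch.float32)
    noise = optimizer.generate_chroma_key_noise(
        shape=(1, 4, 64, 64),              # [B, C, H, W] for a 512x512 image
        background_color=[0.0, 1.0, 0.0],  # green screen
        foreground_center=(0.5, 0.5),
        foreground_size=0.3,
        target_shift_percent=0.07,
    )
    print(noise.shape)  # torch.Size([1, 4, 64, 64])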


class TKGDMPipeline:
    """
    Enhanced diffusion pipeline with TKG-DM non-reactive noise.
    Supports multiple model architectures: SD 1.5, SDXL, SD 2.1, etc.
    """

    MODEL_CONFIGS = {
        'sd1.5': {
            'pipeline_class': StableDiffusionPipeline,
            'default_model': 'runwayml/stable-diffusion-v1-5',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (512, 512)
        },
        'sdxl': {
            'pipeline_class': StableDiffusionXLPipeline,
            'default_model': 'stabilityai/stable-diffusion-xl-base-1.0',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (1024, 1024)
        },
        'sd2.1': {
            'pipeline_class': StableDiffusionPipeline,
            'default_model': 'stabilityai/stable-diffusion-2-1',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (768, 768)
        }
    }

    def __init__(self,
                 model_id: Optional[str] = None,
                 model_type: str = "sd1.5",
                 device: str = "cuda"):
        """
        Initialize the TKG-DM pipeline with the specified model architecture.

        Args:
            model_id: Specific model ID (optional, uses the default for model_type)
            model_type: Model architecture type ('sd1.5', 'sdxl', 'sd2.1')
            device: Device to load the model on
        """
        self.device = device
        self.dtype = torch.float16 if device == "cuda" else torch.float32
        self.model_type = model_type

        if model_type not in self.MODEL_CONFIGS:
            raise ValueError(f"Unsupported model type: {model_type}. Supported: {list(self.MODEL_CONFIGS.keys())}")

        self.config = self.MODEL_CONFIGS[model_type]
        self.model_id = model_id or self.config['default_model']

        # If a custom model ID was given, try to auto-detect its architecture.
        if model_id and model_id != self.config['default_model']:
            detected_type = self._detect_model_type(model_id)
            if detected_type and detected_type != model_type:
                print(f"🔄 Auto-detected model type: {detected_type} for {model_id}")
                self.model_type = detected_type
                self.config = self.MODEL_CONFIGS[detected_type]

        self._load_pipeline()

        self.noise_optimizer = TKGDMNoiseOptimizer(device, self.dtype)

    def _detect_model_type(self, model_id: str) -> Optional[str]:
        """Auto-detect the model type from common model ID patterns."""
        model_id_lower = model_id.lower()

        if any(pattern in model_id_lower for pattern in ['xl', 'sdxl', 'stable-diffusion-xl']):
            return 'sdxl'

        if any(pattern in model_id_lower for pattern in ['stable-diffusion-2', 'sd-2', 'v2-']):
            return 'sd2.1'

        # Fall back to SD 1.5 for anything unrecognized.
        return 'sd1.5'

    def _load_pipeline(self):
        """Load the diffusion pipeline based on the model type."""
        try:
            pipeline_class = self.config['pipeline_class']

            pipeline_args = {
                'torch_dtype': self.dtype,
                'safety_checker': None,
                'requires_safety_checker': False
            }

            # SDXL checkpoints ship fp16 safetensors variants.
            if self.model_type == 'sdxl':
                pipeline_args.update({
                    'use_safetensors': True,
                    'variant': "fp16" if self.dtype == torch.float16 else None
                })

            self.pipe = pipeline_class.from_pretrained(
                self.model_id,
                **pipeline_args
            ).to(self.device)

            print(f"✅ Successfully loaded {self.model_type} model: {self.model_id}")

        except Exception as e:
            print(f"❌ Error loading {self.model_type} pipeline: {e}")
            self.pipe = None

    def get_latent_shape(self, height: int, width: int) -> Tuple[int, int, int, int]:
        """Get the latent shape for the given image dimensions."""
        scale_factor = self.config['latent_scale_factor']
        channels = self.config['latent_channels']
        return (1, channels, height // scale_factor, width // scale_factor)
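
    # Worked example: for SD 1.5 at 512x512 the latent shape is
    # (1, 4, 512 // 8, 512 // 8) = (1, 4, 64, 64).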

    def __call__(self,
                 prompt: str,
                 negative_prompt: Optional[str] = None,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 num_inference_steps: int = 50,
                 guidance_scale: float = 7.5,
                 background_color: Optional[List[float]] = None,
                 channel_shifts: Optional[List[float]] = None,
                 bounding_boxes: Optional[List[Tuple[float, float, float, float]]] = None,
                 foreground_center: Tuple[float, float] = (0.5, 0.5),
                 foreground_size: float = 0.3,
                 target_shift_percent: float = 0.07,
                 blur_sigma: Optional[float] = None,
                 generator: Optional[torch.Generator] = None,
                 **kwargs):
        """
        Generate an image with space-aware text-to-image using TKG-DM (section 2.2).

        Args:
            prompt: Text prompt for generation
            negative_prompt: Negative prompt
            height, width: Output dimensions (uses model defaults if None)
            num_inference_steps: Denoising steps
            guidance_scale: CFG guidance scale
            background_color: Target background RGB (0-1), for backwards compatibility
            channel_shifts: Direct latent channel shift values in [-1, 1] per channel
            bounding_boxes: List of (x1, y1, x2, y2) normalized coords for reserved regions
            foreground_center: Foreground object center (x, y), legacy single region
            foreground_size: Foreground region size, legacy single region
            target_shift_percent: TKG-DM shift percentage (±7%)
            blur_sigma: Gaussian blur sigma for soft transitions (auto if None)
            generator: Random number generator
            **kwargs: Additional model-specific parameters

        Returns:
            Generated PIL image
        """
        if self.pipe is None:
            raise RuntimeError("Pipeline not loaded successfully")

        # Fall back to the model's default resolution.
        if height is None or width is None:
            default_height, default_width = self.config['default_size']
            height = height or default_height
            width = width or default_width

        noise_shape = self.get_latent_shape(height, width)

        # Choose the noise-generation strategy based on which controls were given.
        if bounding_boxes is not None:
            optimized_noise = self.noise_optimizer.generate_space_aware_noise(
                noise_shape,
                bounding_boxes=bounding_boxes,
                channel_shifts=channel_shifts,
                background_color=background_color,
                target_shift_percent=target_shift_percent,
                blur_sigma=blur_sigma
            )
        elif channel_shifts is not None:
            optimized_noise = self.noise_optimizer.generate_direct_channel_noise(
                noise_shape,
                channel_shifts=channel_shifts,
                foreground_center=foreground_center,
                foreground_size=foreground_size,
                target_shift_percent=target_shift_percent
            )
        else:
            # Legacy chroma-key path; default to a green background.
            bg_color = background_color or [0.0, 1.0, 0.0]
            optimized_noise = self.noise_optimizer.generate_chroma_key_noise(
                noise_shape,
                background_color=bg_color,
                foreground_center=foreground_center,
                foreground_size=foreground_size,
                target_shift_percent=target_shift_percent
            )

        pipe_args = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'height': height,
            'width': width,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'latents': optimized_noise,
            'generator': generator
        }

        # SDXL-specific conditioning (diffusers expects (height, width) tuples).
        if self.model_type == 'sdxl':
            pipe_args.update({
                'denoising_end': kwargs.get('denoising_end', None),
                'original_size': kwargs.get('original_size', (height, width)),
                'target_size': kwargs.get('target_size', (height, width)),
                'crops_coords_top_left': kwargs.get('crops_coords_top_left', (0, 0))
            })

        # Drop unset arguments so the pipeline defaults apply.
        pipe_args = {k: v for k, v in pipe_args.items() if v is not None}

        with torch.no_grad():
            result = self.pipe(**pipe_args)
            image = result.images[0]

        return image
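

# Minimal usage sketch (illustrative; assumes the diffusers weights are
# available locally or downloadable). Generates a green-screen image with a
# centered subject; the prompt and file name are placeholders.
if __name__ == "__main__":
    pipeline = TKGDMPipeline(model_type="sd1.5", device="cuda")
    image = pipeline(
        prompt="a studio photo of a corgi",
        background_color=[0.0, 1.0, 0.0],  # chroma-key green
        foreground_center=(0.5, 0.5),
        foreground_size=0.3,
        num_inference_steps=30,
    )
    image.save("corgi_green_screen.png")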