# File size: 31,239 Bytes
# c8db08b
import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Tuple, List, Union
from diffusers import (
StableDiffusionPipeline,
StableDiffusionXLPipeline,
DiffusionPipeline
)
import math
class TKGDMNoiseOptimizer:
    """
    TKG-DM: Training-free Chroma Key Content Generation Diffusion Model
    Implementation of non-reactive noise optimization for background control

    The optimizer draws standard Gaussian noise, shifts the mean of selected
    latent channels so that a target fraction of their values is positive,
    and blends the shifted noise with the original noise through a spatial
    mask (a Gaussian around a foreground centre, or blurred bounding boxes).
    """

    def __init__(self, device: str = "cuda", dtype: torch.dtype = None):
        """
        Args:
            device: Device on which noise tensors and masks are created
            dtype: Noise dtype; when omitted, float16 on CUDA devices
                (including indexed ones such as "cuda:0"), float32 otherwise
        """
        self.device = device
        self.dtype = dtype or (
            torch.float16 if str(device).startswith("cuda") else torch.float32
        )

    def calculate_initial_ratio(self, noise_tensor: torch.Tensor, channel: int) -> float:
        """Calculate the fraction of strictly positive values in one channel"""
        channel_data = noise_tensor[:, channel, :, :]
        positive_pixels = (channel_data > 0).sum().float()
        total_pixels = channel_data.numel()
        return (positive_pixels / total_pixels).item()

    def optimize_channel_shift(self,
                               noise_tensor: torch.Tensor,
                               channel: int,
                               target_shift: float,
                               max_iterations: int = 100,
                               tolerance: float = 1e-4) -> float:
        """
        Optimize channel mean shift to achieve target positive ratio

        Args:
            noise_tensor: Initial noise tensor [B, C, H, W]
            channel: Channel index to optimize
            target_shift: Desired change of the positive ratio
            max_iterations: Maximum number of bisection steps
            tolerance: Accepted absolute error on the positive ratio

        Returns:
            Additive offset for the channel, searched in [-10, 10]
        """
        channel_data = noise_tensor[:, channel, :, :]
        total_pixels = channel_data.numel()
        initial_ratio = self.calculate_initial_ratio(noise_tensor, channel)
        # A ratio outside [0, 1] is unreachable and would drive the search
        # to the +/-10 limits, so the target is clamped to the valid range.
        target_ratio = min(1.0, max(0.0, initial_ratio + target_shift))

        # Binary search: the positive ratio grows monotonically with the offset
        delta_min, delta_max = -10.0, 10.0
        delta_optimal = 0.0
        for _ in range(max_iterations):
            delta_mid = (delta_min + delta_max) / 2
            # Only the examined channel is shifted; no full-tensor copy needed
            positive_pixels = ((channel_data + delta_mid) > 0).sum().float()
            current_ratio = (positive_pixels / total_pixels).item()
            delta_optimal = delta_mid
            if abs(current_ratio - target_ratio) < tolerance:
                break
            if current_ratio < target_ratio:
                delta_min = delta_mid
            else:
                delta_max = delta_mid
        return delta_optimal

    def create_bounding_box_mask(self,
                                 height: int,
                                 width: int,
                                 bounding_boxes: List[Tuple[float, float, float, float]],
                                 sigma: Optional[float] = None) -> torch.Tensor:
        """
        Create binary occupancy map from axis-aligned bounding boxes with Gaussian blur

        Args:
            height, width: Mask dimensions
            bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates [0,1];
                coordinates outside that range are clipped to the mask
            sigma: Gaussian blur standard deviation in pixels (2% of the
                shorter side if None); a value <= 0 disables the blur

        Returns:
            Soft transition mask M_blur [H, W] with values in [0,1]
        """
        binary_mask = torch.zeros((height, width), device=self.device)
        for x1, y1, x2, y2 in bounding_boxes:
            # Normalized -> pixel coordinates, tolerating swapped corners
            left, right = sorted((int(x1 * width), int(x2 * width)))
            top, bottom = sorted((int(y1 * height), int(y2 * height)))
            # Clamp BOTH edges into the mask: a negative upper bound would
            # otherwise be read as a wrap-around slice index.
            left, right = min(max(left, 0), width), min(max(right, 0), width)
            top, bottom = min(max(top, 0), height), min(max(bottom, 0), height)
            binary_mask[top:bottom, left:right] = 1.0

        if sigma is None:
            sigma = min(height, width) * 0.02  # 2% of shorter side
        if sigma <= 0:
            # A zero-width Gaussian has no valid kernel (0/0); keep hard edges
            return binary_mask

        kernel_size = int(6 * sigma + 1)  # 6-sigma rule
        if kernel_size % 2 == 0:
            kernel_size += 1  # odd size keeps the kernel centred

        mask_4d = binary_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
        blur_mask = self._gaussian_blur_2d(mask_4d, kernel_size, sigma)
        # Guard against float round-off pushing values marginally out of range
        return blur_mask.squeeze(0).squeeze(0).clamp(0.0, 1.0)  # [H, W]

    def _gaussian_blur_2d(self, tensor: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
        """Apply a separable, zero-padded 2D Gaussian blur to a [1, 1, H, W] tensor"""
        # Normalized 1D Gaussian kernel centred on kernel_size // 2
        coords = torch.arange(kernel_size, dtype=tensor.dtype, device=tensor.device)
        coords = coords - kernel_size // 2
        kernel = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        kernel = kernel / kernel.sum()
        kernel = kernel.view(1, 1, 1, -1)

        padding = kernel_size // 2
        # Horizontal pass
        blurred = torch.nn.functional.conv2d(tensor, kernel, padding=(0, padding))
        # Vertical pass with the transposed kernel
        kernel_t = kernel.transpose(-1, -2)
        blurred = torch.nn.functional.conv2d(blurred, kernel_t, padding=(padding, 0))
        return blurred

    def create_gaussian_mask(self,
                             height: int,
                             width: int,
                             center_x: float,
                             center_y: float,
                             sigma: float) -> torch.Tensor:
        """
        Create 2D Gaussian mask for noise blending

        Args:
            height, width: Mask dimensions
            center_x, center_y: Gaussian center (normalized 0-1)
            sigma: Gaussian spread parameter (in normalized coordinates)

        Returns:
            2D Gaussian mask tensor [H, W]; 1 at the center, decaying outwards
        """
        y_coords = torch.linspace(0, 1, height, device=self.device).view(-1, 1)
        x_coords = torch.linspace(0, 1, width, device=self.device).view(1, -1)
        dx = x_coords - center_x
        dy = y_coords - center_y
        return torch.exp(-((dx ** 2 + dy ** 2) / (2 * sigma ** 2)))

    def apply_color_shift(self,
                          noise_tensor: torch.Tensor,
                          target_color: List[float],
                          target_shifts: List[float]) -> torch.Tensor:
        """
        Apply latent channel shifts to noise tensor

        Args:
            noise_tensor: Input noise [B, C, H, W]
            target_color: Target RGB color; not used by the computation
            target_shifts: Desired positive-ratio change per latent channel

        Returns:
            Shifted copy of the noise tensor (the input is not modified)
        """
        shifted_noise = noise_tensor.clone()
        num_channels = min(len(target_shifts), noise_tensor.shape[1])
        for channel in range(num_channels):
            # Negligible requests are skipped to avoid a pointless search
            if abs(target_shifts[channel]) > 1e-6:
                delta = self.optimize_channel_shift(
                    noise_tensor, channel, target_shifts[channel]
                )
                shifted_noise[:, channel, :, :] += delta
        return shifted_noise

    def blend_noise_with_mask(self,
                              original_noise: torch.Tensor,
                              shifted_noise: torch.Tensor,
                              mask: torch.Tensor) -> torch.Tensor:
        """
        Blend original and shifted noise using space-aware formula:
        ε_masked = ε + M_blur(ε_shifted - ε)

        Args:
            original_noise: Original Gaussian noise ε [B, C, H, W]
            shifted_noise: Shifted noise ε_shifted [B, C, H, W]
            mask: Soft mask [H, W] with values in [0,1]

        Returns:
            Composite noise: shifted noise where mask ≈ 1, original where mask ≈ 0
        """
        mask_expanded = mask.unsqueeze(0).unsqueeze(0).to(original_noise.dtype)  # [1, 1, H, W]
        mask_expanded = mask_expanded.expand_as(original_noise)
        return original_noise + mask_expanded * (shifted_noise - original_noise)

    def _sample_noise(self,
                      shape: Tuple[int, int, int, int],
                      generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw standard Gaussian noise, honoring an optional generator"""
        if generator is None:
            return torch.randn(shape, device=self.device, dtype=self.dtype)
        # Sample on the generator's own device, then move to the target device
        noise = torch.randn(shape, generator=generator,
                            device=generator.device, dtype=self.dtype)
        return noise.to(self.device)

    @staticmethod
    def _color_to_target_shifts(background_color: List[float],
                                channels: int,
                                target_shift_percent: float) -> List[float]:
        """Map an RGB background color to one positive-ratio shift per latent channel"""
        r, g, b = background_color[0], background_color[1], background_color[2]
        shifts = []
        for c in range(channels):
            shift = 0.0  # default: no rule matched, or channel index >= 4
            if c == 0:  # Channel 0: luminance, from overall brightness
                brightness = (r + g + b) / 3.0
                shift = target_shift_percent if brightness > 0.5 else -target_shift_percent
            elif c == 1:  # Channel 1: first color channel
                if g > 0.7:  # green or yellow
                    shift = target_shift_percent
                elif r > 0.7 or b > 0.7:  # red or blue
                    shift = -target_shift_percent
            elif c == 2:  # Channel 2: second color channel
                if r > 0.7 and g < 0.3:  # pure red
                    shift = -target_shift_percent
                elif g > 0.7 and r < 0.3:  # pure green
                    shift = target_shift_percent
            elif c == 3:  # Channel 3: secondary luminance/color control
                if b > 0.7:  # blue background
                    shift = -target_shift_percent
                elif r > 0.7 and g > 0.7:  # yellow/orange
                    shift = target_shift_percent * 0.8
            shifts.append(shift)
        return shifts

    @staticmethod
    def _scale_channel_shifts(channel_shifts: List[float],
                              channels: int,
                              target_shift_percent: float) -> List[float]:
        """Scale user shifts by the base percentage, truncated/zero-padded to `channels`"""
        scaled = [shift * target_shift_percent for shift in channel_shifts[:channels]]
        scaled.extend([0.0] * (channels - len(scaled)))
        return scaled

    def _blend_around_foreground(self,
                                 original_noise: torch.Tensor,
                                 shifted_noise: torch.Tensor,
                                 height: int,
                                 width: int,
                                 foreground_center: Tuple[float, float],
                                 foreground_size: float) -> torch.Tensor:
        """Keep original noise at the foreground centre, shifted noise elsewhere"""
        foreground_mask = self.create_gaussian_mask(
            height, width,
            foreground_center[0], foreground_center[1],
            foreground_size
        )
        # blend_noise_with_mask puts the SHIFTED noise where its mask is 1,
        # so the background weight is the complement of the foreground mask.
        return self.blend_noise_with_mask(
            original_noise, shifted_noise, 1.0 - foreground_mask
        )

    def generate_chroma_key_noise(self,
                                  shape: Tuple[int, int, int, int],
                                  background_color: Optional[List[float]] = None,
                                  foreground_center: Tuple[float, float] = (0.5, 0.5),
                                  foreground_size: float = 0.3,
                                  target_shift_percent: float = 0.07,
                                  generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """
        Generate optimized noise for chroma key content generation using TKG-DM method

        Args:
            shape: Noise tensor shape [B, C, H, W] (typically [1, 4, 64, 64] for SD)
            background_color: Target background RGB color (0-1);
                green screen [0.0, 1.0, 0.0] if None
            foreground_center: Center of foreground object (x, y) normalized
            foreground_size: Size of foreground region (sigma parameter)
            target_shift_percent: Target shift percentage
            generator: Optional RNG for reproducible noise

        Returns:
            Noise tensor: original noise around the foreground centre,
            color-shifted noise towards the background
        """
        if background_color is None:
            background_color = [0.0, 1.0, 0.0]  # Green screen
        batch_size, channels, height, width = shape
        original_noise = self._sample_noise(shape, generator)

        target_shifts = self._color_to_target_shifts(
            background_color, channels, target_shift_percent
        )
        for c, target_shift in enumerate(target_shifts):
            initial_ratio = self.calculate_initial_ratio(original_noise, c)
            print(f"Latent Channel {c}: Initial ratio={initial_ratio:.4f}, Target shift={target_shift:+.1%}")

        shifted_noise = self.apply_color_shift(original_noise, background_color, target_shifts)
        return self._blend_around_foreground(
            original_noise, shifted_noise, height, width,
            foreground_center, foreground_size
        )

    def generate_direct_channel_noise(self,
                                      shape: Tuple[int, int, int, int],
                                      channel_shifts: List[float],
                                      foreground_center: Tuple[float, float] = (0.5, 0.5),
                                      foreground_size: float = 0.3,
                                      target_shift_percent: float = 0.07,
                                      generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """
        Generate optimized noise with direct channel shift control

        Args:
            shape: Noise tensor shape [B, C, H, W]
            channel_shifts: Direct shift values for each channel [-1, 1];
                missing channels get no shift, extra entries are ignored
            foreground_center: Center of foreground object (x, y) normalized
            foreground_size: Size of foreground region (sigma parameter)
            target_shift_percent: Base shift percentage each user value is scaled by
            generator: Optional RNG for reproducible noise

        Returns:
            Noise tensor: original noise around the foreground centre,
            channel-shifted noise towards the background
        """
        batch_size, channels, height, width = shape
        original_noise = self._sample_noise(shape, generator)

        target_shifts = self._scale_channel_shifts(
            channel_shifts, channels, target_shift_percent
        )
        for c in range(min(channels, len(channel_shifts))):
            user_shift = channel_shifts[c]
            target_shift = target_shifts[c]
            initial_ratio = self.calculate_initial_ratio(original_noise, c)
            print(f"Direct Channel {c}: User shift={user_shift:+.2f}, Target shift={target_shift:+.1%}, Initial ratio={initial_ratio:.4f}")

        shifted_noise = self.apply_color_shift(original_noise, [0.0, 0.0, 0.0], target_shifts)
        return self._blend_around_foreground(
            original_noise, shifted_noise, height, width,
            foreground_center, foreground_size
        )

    def generate_space_aware_noise(self,
                                   shape: Tuple[int, int, int, int],
                                   bounding_boxes: List[Tuple[float, float, float, float]],
                                   channel_shifts: Optional[List[float]] = None,
                                   background_color: Optional[List[float]] = None,
                                   target_shift_percent: float = 0.07,
                                   blur_sigma: Optional[float] = None,
                                   generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """
        Generate space-aware noise with reserved bounding boxes

        Args:
            shape: Noise tensor shape [B, C, H, W]
            bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates for reserved regions
            channel_shifts: Direct channel shifts (takes precedence when given)
            background_color: RGB background color for shift calculation (fallback)
            target_shift_percent: Base shift percentage
            blur_sigma: Gaussian blur sigma (auto-calculated if None)
            generator: Optional RNG for reproducible noise

        Returns:
            Composite noise ε_masked: shifted noise inside the (blurred) boxes,
            standard noise outside; plain standard noise if no boxes are given

        Raises:
            ValueError: If neither channel_shifts nor background_color is provided
        """
        if channel_shifts is None and background_color is None:
            raise ValueError("Either channel_shifts or background_color must be provided")

        batch_size, channels, height, width = shape
        # Standard Gaussian tensor ε ~ N(0, I)
        epsilon = self._sample_noise(shape, generator)

        if not bounding_boxes:
            # No reserved boxes - skip the shift search, return standard noise
            return epsilon

        if channel_shifts is not None:
            target_shifts = self._scale_channel_shifts(
                channel_shifts, channels, target_shift_percent
            )
            epsilon_shifted = self.apply_color_shift(epsilon, [0.0, 0.0, 0.0], target_shifts)
        else:
            epsilon_shifted = self._generate_color_shifted_noise(
                epsilon, background_color, target_shift_percent
            )

        # M_blur = GaussianBlur(M, σ), then ε_masked = ε + M_blur(ε_shifted - ε)
        mask_blur = self.create_bounding_box_mask(height, width, bounding_boxes, blur_sigma)
        epsilon_masked = self.blend_noise_with_mask(epsilon, epsilon_shifted, mask_blur)
        sigma_label = blur_sigma if blur_sigma is not None else 'auto'
        print(f"🔲 Space-aware noise: {len(bounding_boxes)} reserved boxes, sigma={sigma_label}")
        return epsilon_masked

    def _generate_color_shifted_noise(self,
                                      epsilon: torch.Tensor,
                                      background_color: List[float],
                                      target_shift_percent: float) -> torch.Tensor:
        """Helper method to generate color-shifted noise from background color"""
        target_shifts = self._color_to_target_shifts(
            background_color, epsilon.shape[1], target_shift_percent
        )
        return self.apply_color_shift(epsilon, background_color, target_shifts)
class TKGDMPipeline:
    """
    Enhanced Diffusion pipeline with TKG-DM non-reactive noise
    Supports multiple model architectures: SD 1.5, SDXL, SD 2.1, etc.

    Wraps a diffusers pipeline and feeds it initial latents produced by
    TKGDMNoiseOptimizer instead of plain Gaussian noise.
    """

    # Model configurations for different architectures
    MODEL_CONFIGS = {
        'sd1.5': {
            'pipeline_class': StableDiffusionPipeline,
            'default_model': 'runwayml/stable-diffusion-v1-5',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (512, 512)
        },
        'sdxl': {
            'pipeline_class': StableDiffusionXLPipeline,
            'default_model': 'stabilityai/stable-diffusion-xl-base-1.0',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (1024, 1024)
        },
        'sd2.1': {
            'pipeline_class': StableDiffusionPipeline,
            'default_model': 'stabilityai/stable-diffusion-2-1',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (768, 768)
        }
    }

    def __init__(self,
                 model_id: Optional[str] = None,
                 model_type: str = "sd1.5",
                 device: str = "cuda"):
        """
        Initialize TKG-DM pipeline with specified model architecture

        Args:
            model_id: Specific model ID (optional, uses default for model_type)
            model_type: Model architecture type ('sd1.5', 'sdxl', 'sd2.1')
            device: Device to load model on

        Raises:
            ValueError: If model_type is not a key of MODEL_CONFIGS
        """
        self.device = device
        # Half precision on any CUDA device, including indexed ones ("cuda:0")
        self.dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
        self.model_type = model_type

        if model_type not in self.MODEL_CONFIGS:
            raise ValueError(f"Unsupported model type: {model_type}. Supported: {list(self.MODEL_CONFIGS.keys())}")
        self.config = self.MODEL_CONFIGS[model_type]
        self.model_id = model_id or self.config['default_model']

        # For custom models, let a recognizable ID pattern override model_type
        if model_id and model_id != self.config['default_model']:
            detected_type = self._detect_model_type(model_id)
            if detected_type and detected_type != model_type:
                print(f"🔄 Auto-detected model type: {detected_type} for {model_id}")
                self.model_type = detected_type
                self.config = self.MODEL_CONFIGS[detected_type]

        # Load the appropriate pipeline (leaves self.pipe as None on failure)
        self._load_pipeline()
        self.noise_optimizer = TKGDMNoiseOptimizer(device, self.dtype)

    def _detect_model_type(self, model_id: str) -> Optional[str]:
        """
        Auto-detect model type based on model ID patterns

        Returns:
            'sdxl', 'sd2.1' or 'sd1.5' when the ID contains a known pattern,
            None otherwise so that the caller's explicit model_type is kept
        """
        model_id_lower = model_id.lower()
        # SDXL detection patterns
        if any(pattern in model_id_lower for pattern in ['xl', 'sdxl', 'stable-diffusion-xl']):
            return 'sdxl'
        # SD 2.x detection patterns
        if any(pattern in model_id_lower for pattern in ['stable-diffusion-2', 'sd-2', 'v2-']):
            return 'sd2.1'
        # SD 1.x detection patterns
        if any(pattern in model_id_lower for pattern in ['stable-diffusion-v1', 'v1-5', 'v1-4', 'sd-1', 'sd1']):
            return 'sd1.5'
        # Unknown naming: do not guess, a guess would override the caller's choice
        return None

    def _load_pipeline(self):
        """
        Load the diffusion pipeline based on model type

        Any loading error is printed and leaves self.pipe set to None;
        __call__ then raises RuntimeError.
        """
        self.pipe = None
        try:
            pipeline_class = self.config['pipeline_class']
            # Common pipeline arguments
            pipeline_args = {
                'torch_dtype': self.dtype,
                'safety_checker': None,
                'requires_safety_checker': False
            }
            # SDXL-specific arguments
            if self.model_type == 'sdxl':
                pipeline_args.update({
                    'use_safetensors': True,
                    'variant': "fp16" if self.dtype == torch.float16 else None
                })
            self.pipe = pipeline_class.from_pretrained(
                self.model_id,
                **pipeline_args
            ).to(self.device)
            print(f"✅ Successfully loaded {self.model_type} model: {self.model_id}")
        except Exception as e:
            # Deliberate best-effort: construction succeeds, generation fails later
            print(f"❌ Error loading {self.model_type} pipeline: {e}")
            self.pipe = None

    def get_latent_shape(self, height: int, width: int) -> Tuple[int, int, int, int]:
        """Get latent shape (1, C, H // scale, W // scale) for given image dimensions"""
        scale_factor = self.config['latent_scale_factor']
        channels = self.config['latent_channels']
        return (1, channels, height // scale_factor, width // scale_factor)

    def _build_initial_noise(self,
                             noise_shape: Tuple[int, int, int, int],
                             background_color: Optional[List[float]],
                             channel_shifts: Optional[List[float]],
                             bounding_boxes: Optional[List[Tuple[float, float, float, float]]],
                             foreground_center: Tuple[float, float],
                             foreground_size: float,
                             target_shift_percent: float,
                             blur_sigma: Optional[float]) -> torch.Tensor:
        """Pick the noise-generation mode from the arguments the caller supplied"""
        if bounding_boxes is not None:
            # Space-aware noise with reserved bounding boxes
            return self.noise_optimizer.generate_space_aware_noise(
                noise_shape,
                bounding_boxes=bounding_boxes,
                channel_shifts=channel_shifts,
                background_color=background_color,
                target_shift_percent=target_shift_percent,
                blur_sigma=blur_sigma
            )
        if channel_shifts is not None:
            # Direct channel shifts with legacy single region
            return self.noise_optimizer.generate_direct_channel_noise(
                noise_shape,
                channel_shifts=channel_shifts,
                foreground_center=foreground_center,
                foreground_size=foreground_size,
                target_shift_percent=target_shift_percent
            )
        # Fallback to background color method with legacy single region
        bg_color = background_color or [0.0, 1.0, 0.0]
        return self.noise_optimizer.generate_chroma_key_noise(
            noise_shape,
            background_color=bg_color,
            foreground_center=foreground_center,
            foreground_size=foreground_size,
            target_shift_percent=target_shift_percent
        )

    def __call__(self,
                 prompt: str,
                 negative_prompt: Optional[str] = None,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 num_inference_steps: int = 50,
                 guidance_scale: float = 7.5,
                 background_color: Optional[List[float]] = None,  # Backwards compatibility
                 channel_shifts: Optional[List[float]] = None,  # Direct channel control
                 bounding_boxes: Optional[List[Tuple[float, float, float, float]]] = None,  # Space-aware boxes
                 foreground_center: Tuple[float, float] = (0.5, 0.5),  # Legacy single region
                 foreground_size: float = 0.3,  # Legacy single region
                 target_shift_percent: float = 0.07,
                 blur_sigma: Optional[float] = None,  # Gaussian blur sigma
                 generator: Optional[torch.Generator] = None,
                 **kwargs) -> torch.Tensor:
        """
        Generate image with space-aware text-to-image using TKG-DM

        Args:
            prompt: Text prompt for generation
            negative_prompt: Negative prompt
            height, width: Output dimensions (uses model defaults if None)
            num_inference_steps: Denoising steps
            guidance_scale: CFG guidance scale
            background_color: Target background RGB (0-1) - for backwards compatibility
            channel_shifts: Direct latent channel shift values [-1,1] for each channel
            bounding_boxes: List of (x1,y1,x2,y2) normalized coords for reserved regions
            foreground_center: Foreground object center (x, y) - legacy single region
            foreground_size: Foreground region size - legacy single region
            target_shift_percent: TKG-DM shift percentage
            blur_sigma: Gaussian blur sigma for soft transitions (auto if None)
            generator: Random number generator; also seeds the initial noise
            **kwargs: For SDXL only: denoising_end, guidance_scale_end,
                original_size, target_size, crops_coords_top_left

        Returns:
            First entry of the pipeline result's `images`
            (NOTE(review): presumably a PIL image with the diffusers default
            output type, despite the torch.Tensor annotation - confirm)

        Raises:
            RuntimeError: If the pipeline failed to load
        """
        if self.pipe is None:
            raise RuntimeError("Pipeline not loaded successfully")

        # Use model default dimensions if not specified
        if height is None or width is None:
            default_height, default_width = self.config['default_size']
            height = height or default_height
            width = width or default_width

        noise_shape = self.get_latent_shape(height, width)
        noise_args = (
            noise_shape, background_color, channel_shifts, bounding_boxes,
            foreground_center, foreground_size, target_shift_percent, blur_sigma
        )
        if generator is None:
            optimized_noise = self._build_initial_noise(*noise_args)
        else:
            # The latents are drawn from the global RNG, so a seeded generator
            # would otherwise have no effect on them. Derive a seed from the
            # generator and apply it inside a forked RNG scope, so the noise
            # is reproducible and the caller's global RNG state is restored.
            seed = int(torch.randint(
                0, 2 ** 31 - 1, (1,),
                generator=generator, device=generator.device
            ).item())
            cuda_devices = [self.device] if str(self.device).startswith("cuda") else []
            with torch.random.fork_rng(devices=cuda_devices):
                torch.manual_seed(seed)
                optimized_noise = self._build_initial_noise(*noise_args)

        # Prepare pipeline arguments
        pipe_args = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'height': height,
            'width': width,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'latents': optimized_noise,
            'generator': generator
        }
        # Add model-specific arguments
        if self.model_type == 'sdxl':
            # Size conditioning is given as (height, width)
            pipe_args.update({
                'denoising_end': kwargs.get('denoising_end', None),
                'guidance_scale_end': kwargs.get('guidance_scale_end', None),
                'original_size': kwargs.get('original_size', (height, width)),
                'target_size': kwargs.get('target_size', (height, width)),
                'crops_coords_top_left': kwargs.get('crops_coords_top_left', (0, 0))
            })
        # Filter None values so the wrapped pipeline applies its own defaults
        pipe_args = {k: v for k, v in pipe_args.items() if v is not None}

        # Generate image using optimized noise
        with torch.no_grad():
            result = self.pipe(**pipe_args)
        return result.images[0]