import torch
from typing import Optional, Tuple, List
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline


class TKGDMNoiseOptimizer:
    """
    TKG-DM: Training-free Chroma Key Content Generation Diffusion Model

    Implementation of non-reactive noise optimization for background control.
    """

    def __init__(self, device: str = "cuda", dtype: Optional[torch.dtype] = None):
        self.device = device
        self.dtype = dtype or (torch.float16 if device == "cuda" else torch.float32)

    def calculate_initial_ratio(self, noise_tensor: torch.Tensor, channel: int) -> float:
        """Calculate the initial positive ratio for a specific channel."""
        channel_data = noise_tensor[:, channel, :, :]
        positive_pixels = (channel_data > 0).sum().float()
        total_pixels = channel_data.numel()
        return (positive_pixels / total_pixels).item()

    def optimize_channel_shift(self,
                               noise_tensor: torch.Tensor,
                               channel: int,
                               target_shift: float,
                               max_iterations: int = 100,
                               tolerance: float = 1e-4) -> float:
        """
        Optimize a channel's mean shift to achieve the target positive ratio.

        Args:
            noise_tensor: Initial noise tensor [B, C, H, W]
            channel: Latent channel index to optimize
            target_shift: Desired shift in positive ratio
            max_iterations: Maximum optimization iterations
            tolerance: Convergence tolerance

        Returns:
            Optimal channel shift value
        """
        initial_ratio = self.calculate_initial_ratio(noise_tensor, channel)
        target_ratio = initial_ratio + target_shift

        # Binary search for the optimal shift
        delta_min, delta_max = -10.0, 10.0
        delta_optimal = 0.0

        for _ in range(max_iterations):
            delta_mid = (delta_min + delta_max) / 2

            # Apply the candidate shift and measure the resulting ratio
            shifted_tensor = noise_tensor.clone()
            shifted_tensor[:, channel, :, :] += delta_mid
            current_ratio = self.calculate_initial_ratio(shifted_tensor, channel)

            if abs(current_ratio - target_ratio) < tolerance:
                delta_optimal = delta_mid
                break

            # Shifting the mean up increases the positive ratio, so raise the
            # lower bound while the ratio is still too small.
            if current_ratio < target_ratio:
                delta_min = delta_mid
            else:
                delta_max = delta_mid
            delta_optimal = delta_mid

        return delta_optimal

    def create_bounding_box_mask(self,
                                 height: int,
                                 width: int,
                                 bounding_boxes: List[Tuple[float, float, float, float]],
                                 sigma: Optional[float] = None) -> torch.Tensor:
        """
        Create a binary occupancy map from axis-aligned bounding boxes and
        soften it with a Gaussian blur.

        Args:
            height, width: Mask dimensions
            bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates in [0, 1]
            sigma: Gaussian blur standard deviation (auto-calculated if None)

        Returns:
            Soft transition mask M_blur with values in [0, 1]
        """
        # Create the binary occupancy map
        binary_mask = torch.zeros((height, width), device=self.device)

        for x1, y1, x2, y2 in bounding_boxes:
            # Convert normalized coordinates to pixel coordinates
            px1 = int(x1 * width)
            py1 = int(y1 * height)
            px2 = int(x2 * width)
            py2 = int(y2 * height)

            # Ensure valid bounds
            px1, px2 = max(0, min(px1, px2)), min(width, max(px1, px2))
            py1, py2 = max(0, min(py1, py2)), min(height, max(py1, py2))

            # Mark pixels inside the bounding box
            binary_mask[py1:py2, px1:px2] = 1.0

        # Apply a Gaussian blur for a soft transition
        if sigma is None:
            # Standard deviation proportional to the shorter side, as per the paper
            sigma = min(height, width) * 0.02  # 2% of the shorter side

        # Convert to 4D tensor format for the blur
        mask_4d = binary_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]

        # Kernel size from the 6-sigma rule, forced to be odd
        kernel_size = int(6 * sigma + 1)
        if kernel_size % 2 == 0:
            kernel_size += 1

        blur_mask = self._gaussian_blur_2d(mask_4d, kernel_size, sigma)
        return blur_mask.squeeze(0).squeeze(0)  # [H, W]
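    # A minimal usage sketch for the mask construction above (kept as a
    # comment so the module stays import-safe; box coordinates are illustrative):
    #
    #   optimizer = TKGDMNoiseOptimizer(device="cpu")
    #   # Reserve the left third of a 64x64 latent grid
    #   mask = optimizer.create_bounding_box_mask(64, 64, [(0.0, 0.0, 0.33, 1.0)])
    #   # Values blur smoothly from 1.0 inside the box to 0.0 outside
    #   print(mask.shape, float(mask.min()), float(mask.max()))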
    def _gaussian_blur_2d(self, tensor: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
        """Apply a separable 2D Gaussian blur to a [1, 1, H, W] tensor."""
        # Create the 1D Gaussian kernel
        coords = torch.arange(kernel_size, dtype=tensor.dtype, device=tensor.device)
        coords = coords - kernel_size // 2
        kernel = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        kernel = kernel / kernel.sum()
        kernel = kernel.view(1, 1, 1, -1)

        # Apply the separable blur: horizontal pass, then vertical pass
        padding = kernel_size // 2
        blurred = torch.nn.functional.conv2d(tensor, kernel, padding=(0, padding))
        kernel_t = kernel.transpose(-1, -2)
        blurred = torch.nn.functional.conv2d(blurred, kernel_t, padding=(padding, 0))

        return blurred

    def create_gaussian_mask(self, height: int, width: int,
                             center_x: float, center_y: float,
                             sigma: float) -> torch.Tensor:
        """
        Create a 2D Gaussian mask for noise blending.

        Args:
            height, width: Mask dimensions
            center_x, center_y: Gaussian center (normalized 0-1)
            sigma: Gaussian spread parameter

        Returns:
            2D Gaussian mask tensor
        """
        y_coords = torch.linspace(0, 1, height).view(-1, 1)
        x_coords = torch.linspace(0, 1, width).view(1, -1)

        # Distances from the center
        dx = x_coords - center_x
        dy = y_coords - center_y

        # Gaussian formula
        mask = torch.exp(-((dx ** 2 + dy ** 2) / (2 * sigma ** 2)))
        return mask.to(self.device)

    def apply_color_shift(self,
                          noise_tensor: torch.Tensor,
                          target_color: List[float],
                          target_shifts: List[float]) -> torch.Tensor:
        """
        Apply latent channel shifts to a noise tensor.

        Args:
            noise_tensor: Input noise [B, C, H, W] (typically 4 channels for an SD latent)
            target_color: Target RGB color [R, G, B] in the 0-1 range
                (currently unused; kept for API compatibility)
            target_shifts: Channel shift amounts for all latent channels

        Returns:
            Color-shifted noise tensor
        """
        shifted_noise = noise_tensor.clone()

        # Apply shifts to all available latent channels
        num_channels = min(len(target_shifts), noise_tensor.shape[1])
        for channel in range(num_channels):
            if abs(target_shifts[channel]) > 1e-6:
                delta = self.optimize_channel_shift(noise_tensor, channel, target_shifts[channel])
                shifted_noise[:, channel, :, :] += delta

        return shifted_noise

    def blend_noise_with_mask(self,
                              original_noise: torch.Tensor,
                              shifted_noise: torch.Tensor,
                              mask: torch.Tensor) -> torch.Tensor:
        """
        Blend original and shifted noise using the space-aware formula from the
        paper: ε_masked = ε + M_blur * (ε_shifted - ε)

        Args:
            original_noise: Original Gaussian noise ε [B, C, H, W]
            shifted_noise: Non-reactive shifted noise ε_shifted [B, C, H, W]
            mask: Soft transition mask M_blur [H, W] with values in [0, 1]

        Returns:
            Composite noise tensor ε_masked
        """
        # Expand the mask to match the noise dimensions and ensure the correct dtype
        mask_expanded = mask.unsqueeze(0).unsqueeze(0).to(original_noise.dtype)  # [1, 1, H, W]
        mask_expanded = mask_expanded.expand_as(original_noise)

        # Inside reserved boxes (M_blur ≈ 1): dominated by mean-shifted (non-reactive) noise.
        # Outside the boxes (M_blur ≈ 0): reduces to ordinary Gaussian noise.
        blended_noise = original_noise + mask_expanded * (shifted_noise - original_noise)

        return blended_noise
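    # Sketch of the blending identity above (illustrative values, commented to
    # avoid side effects at import time):
    #
    #   eps = torch.randn(1, 4, 64, 64)
    #   eps_shifted = eps + 0.5            # a uniform mean shift
    #   mask = torch.ones(64, 64)          # fully reserved region
    #   out = optimizer.blend_noise_with_mask(eps, eps_shifted, mask)
    #   # With M_blur == 1 everywhere, out equals eps_shifted exactly;
    #   # with M_blur == 0 it would equal eps.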
    def generate_chroma_key_noise(self,
                                  shape: Tuple[int, int, int, int],
                                  background_color: List[float] = [0.0, 1.0, 0.0],  # Green screen
                                  foreground_center: Tuple[float, float] = (0.5, 0.5),
                                  foreground_size: float = 0.3,
                                  target_shift_percent: float = 0.07) -> torch.Tensor:
        """
        Generate optimized noise for chroma key content generation using the TKG-DM method.

        Handles latent space channels as follows:
        - Channels 0, 3: Luminance and color information
        - Channels 1, 2: Color channels (positive = pink/yellow, negative = red/blue)

        Args:
            shape: Noise tensor shape [B, C, H, W] (typically [1, 4, 64, 64] for SD)
            background_color: Target background RGB color (0-1)
            foreground_center: Center of the foreground object (x, y), normalized
            foreground_size: Size of the foreground region (sigma parameter)
            target_shift_percent: Target shift percentage (±7% as per the paper)

        Returns:
            Optimized noise tensor for chroma key generation
        """
        batch_size, channels, height, width = shape

        # Generate initial random noise with the correct dtype
        original_noise = torch.randn(shape, device=self.device, dtype=self.dtype)

        # Calculate target shifts for the latent channels based on the color mapping
        target_shifts = []
        r, g, b = background_color[0], background_color[1], background_color[2]

        for c in range(channels):
            initial_ratio = self.calculate_initial_ratio(original_noise, c)

            if c == 0:  # Channel 0: luminance, based on overall brightness
                brightness = (r + g + b) / 3.0
                target_shift = target_shift_percent if brightness > 0.5 else -target_shift_percent
            elif c == 1:  # Channel 1: pink/yellow (positive) vs red/blue (negative)
                if g > 0.7:  # Green or yellow (yellow also has g > 0.7)
                    target_shift = target_shift_percent
                elif r > 0.7 or b > 0.7:  # Red or blue
                    target_shift = -target_shift_percent
                else:
                    target_shift = 0.0  # Neutral for other colors
            elif c == 2:  # Channel 2: complementary color channel for full color control
                if r > 0.7 and g < 0.3:  # Pure red
                    target_shift = -target_shift_percent
                elif g > 0.7 and r < 0.3:  # Pure green
                    target_shift = target_shift_percent
                else:
                    target_shift = 0.0
            elif c == 3:  # Channel 3: secondary luminance/color information
                if b > 0.7:  # Blue background
                    target_shift = -target_shift_percent
                elif r > 0.7 and g > 0.7:  # Yellow/orange
                    target_shift = target_shift_percent * 0.8  # Moderate positive
                    # contrast = max(r, g, b) - min(r, g, b)
                    # target_shift = target_shift_percent * 0.5 if contrast > 0.5 else -target_shift_percent * 0.3
                else:
                    target_shift = 0.0
            else:
                target_shift = 0.0  # No shift for additional channels

            target_shifts.append(target_shift)
            print(f"Latent Channel {c}: Initial ratio={initial_ratio:.4f}, Target shift={target_shift:+.1%}")

        # Apply color shifts to create background-optimized noise
        shifted_noise = self.apply_color_shift(original_noise, background_color, target_shifts)

        # Create a Gaussian mask for foreground/background separation
        mask = self.create_gaussian_mask(
            height, width,
            foreground_center[0], foreground_center[1],
            foreground_size
        )

        # Blend original and shifted noise
        optimized_noise = self.blend_noise_with_mask(original_noise, shifted_noise, mask)

        return optimized_noise
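    # Example call for the chroma-key path (illustrative values; commented to
    # keep the module import-safe):
    #
    #   noise = optimizer.generate_chroma_key_noise(
    #       shape=(1, 4, 64, 64),
    #       background_color=[0.0, 1.0, 0.0],   # green screen
    #       foreground_center=(0.5, 0.5),
    #       foreground_size=0.3,
    #   )
    #   # `noise` can then be passed to a diffusers pipeline via `latents=`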
    def generate_direct_channel_noise(self,
                                      shape: Tuple[int, int, int, int],
                                      channel_shifts: List[float],
                                      foreground_center: Tuple[float, float] = (0.5, 0.5),
                                      foreground_size: float = 0.3,
                                      target_shift_percent: float = 0.07) -> torch.Tensor:
        """
        Generate optimized noise with direct channel shift control.

        Args:
            shape: Noise tensor shape [B, C, H, W]
            channel_shifts: Direct shift values for each channel, in [-1, 1]
            foreground_center: Center of the foreground object (x, y), normalized
            foreground_size: Size of the foreground region (sigma parameter)
            target_shift_percent: Base shift percentage (±7% as per the paper)

        Returns:
            Optimized noise tensor with direct channel control
        """
        batch_size, channels, height, width = shape

        # Generate initial random noise with the correct dtype
        original_noise = torch.randn(shape, device=self.device, dtype=self.dtype)

        # Convert user channel shifts to target shifts
        target_shifts = []
        for c in range(min(channels, len(channel_shifts))):
            # Scale user input (-1 to 1) to the target shift percentage
            user_shift = channel_shifts[c]
            target_shift = user_shift * target_shift_percent
            target_shifts.append(target_shift)

            initial_ratio = self.calculate_initial_ratio(original_noise, c)
            print(f"Direct Channel {c}: User shift={user_shift:+.2f}, Target shift={target_shift:+.1%}, Initial ratio={initial_ratio:.4f}")

        # Fill remaining channels with zero shift
        while len(target_shifts) < channels:
            target_shifts.append(0.0)

        # Apply the direct channel shifts
        shifted_noise = self.apply_color_shift(original_noise, [0.0, 0.0, 0.0], target_shifts)

        # Create a Gaussian mask for foreground/background separation
        mask = self.create_gaussian_mask(
            height, width,
            foreground_center[0], foreground_center[1],
            foreground_size
        )

        # Blend original and shifted noise
        optimized_noise = self.blend_noise_with_mask(original_noise, shifted_noise, mask)

        return optimized_noise

    def generate_space_aware_noise(self,
                                   shape: Tuple[int, int, int, int],
                                   bounding_boxes: List[Tuple[float, float, float, float]],
                                   channel_shifts: Optional[List[float]] = None,
                                   background_color: Optional[List[float]] = None,
                                   target_shift_percent: float = 0.07,
                                   blur_sigma: Optional[float] = None) -> torch.Tensor:
        """
        Generate space-aware noise with reserved bounding boxes (Section 2.2).

        Args:
            shape: Noise tensor shape [B, C, H, W]
            bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates for reserved regions
            channel_shifts: Direct channel shifts (optional)
            background_color: RGB background color for shift calculation (fallback)
            target_shift_percent: Base shift percentage (±7% as per the paper)
            blur_sigma: Gaussian blur sigma (auto-calculated if None)

        Returns:
            Space-aware composite noise tensor ε_masked
        """
        batch_size, channels, height, width = shape

        # Generate a standard Gaussian tensor ε ~ N(0, I)
        epsilon = torch.randn(shape, device=self.device, dtype=self.dtype)

        if channel_shifts is None and background_color is None:
            raise ValueError("Either channel_shifts or background_color must be provided")

        # No reserved boxes: return standard noise without computing a shift
        if not bounding_boxes:
            return epsilon

        # Generate non-reactive noise ε_shifted using TKG-DM
        if channel_shifts is not None:
            # Use direct channel control
            target_shifts = []
            for c in range(min(channels, len(channel_shifts))):
                user_shift = channel_shifts[c]
                target_shift = user_shift * target_shift_percent
                target_shifts.append(target_shift)

            # Fill remaining channels with zero shift
            while len(target_shifts) < channels:
                target_shifts.append(0.0)

            epsilon_shifted = self.apply_color_shift(epsilon, [0.0, 0.0, 0.0], target_shifts)
        else:
            # Fall back to the background color method
            epsilon_shifted = self._generate_color_shifted_noise(epsilon, background_color, target_shift_percent)

        # Create the binary occupancy map and blur it: M_blur = GaussianBlur(M, σ)
        mask_blur = self.create_bounding_box_mask(height, width, bounding_boxes, blur_sigma)

        # Apply space-aware blending: ε_masked = ε + M_blur * (ε_shifted - ε)
        epsilon_masked = self.blend_noise_with_mask(epsilon, epsilon_shifted, mask_blur)

        print(f"🔲 Space-aware noise: {len(bounding_boxes)} reserved boxes, sigma={blur_sigma or 'auto'}")
        return epsilon_masked
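    # Example call for the space-aware path (illustrative values; the shift
    # vector pushing toward green is an assumption based on the channel
    # heuristics above):
    #
    #   noise = optimizer.generate_space_aware_noise(
    #       shape=(1, 4, 64, 64),
    #       bounding_boxes=[(0.25, 0.25, 0.75, 0.75)],  # reserve the center
    #       channel_shifts=[0.0, 1.0, 1.0, 0.0],
    #   )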
    def _generate_color_shifted_noise(self,
                                      epsilon: torch.Tensor,
                                      background_color: List[float],
                                      target_shift_percent: float) -> torch.Tensor:
        """Helper to generate color-shifted noise from a background color."""
        channels = epsilon.shape[1]
        r, g, b = background_color[0], background_color[1], background_color[2]

        target_shifts = []
        for c in range(channels):
            if c == 0:  # Luminance
                brightness = (r + g + b) / 3.0
                target_shift = target_shift_percent if brightness > 0.5 else -target_shift_percent
            elif c == 1:  # Color channel 1
                if g > 0.7:  # Green or yellow
                    target_shift = target_shift_percent
                elif r > 0.7 or b > 0.7:
                    target_shift = -target_shift_percent
                else:
                    target_shift = 0.0
            elif c == 2:  # Color channel 2
                if r > 0.7 and g < 0.3:
                    target_shift = -target_shift_percent
                elif g > 0.7 and r < 0.3:
                    target_shift = target_shift_percent
                else:
                    target_shift = 0.0
            elif c == 3:  # Secondary color
                if b > 0.7:
                    target_shift = -target_shift_percent
                elif r > 0.7 and g > 0.7:
                    target_shift = target_shift_percent * 0.8
                else:
                    target_shift = 0.0
            else:
                target_shift = 0.0
            target_shifts.append(target_shift)

        return self.apply_color_shift(epsilon, background_color, target_shifts)
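# The noise optimizer is self-contained and can be exercised without loading a
# diffusion model; a quick CPU smoke test might look like this (illustrative,
# commented to keep the module import-safe):
#
#   opt = TKGDMNoiseOptimizer(device="cpu")
#   eps = torch.randn(1, 4, 64, 64)
#   ratio_before = opt.calculate_initial_ratio(eps, channel=1)
#   delta = opt.optimize_channel_shift(eps, channel=1, target_shift=0.07)
#   eps[:, 1] += delta
#   ratio_after = opt.calculate_initial_ratio(eps, channel=1)
#   print(f"{ratio_before:.3f} -> {ratio_after:.3f}")  # roughly +7 points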
class TKGDMPipeline:
    """
    Enhanced diffusion pipeline with TKG-DM non-reactive noise.

    Supports multiple model architectures: SD 1.5, SDXL, SD 2.1, etc.
    """

    # Model configurations for the supported architectures
    MODEL_CONFIGS = {
        'sd1.5': {
            'pipeline_class': StableDiffusionPipeline,
            'default_model': 'runwayml/stable-diffusion-v1-5',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (512, 512)
        },
        'sdxl': {
            'pipeline_class': StableDiffusionXLPipeline,
            'default_model': 'stabilityai/stable-diffusion-xl-base-1.0',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (1024, 1024)
        },
        'sd2.1': {
            'pipeline_class': StableDiffusionPipeline,
            'default_model': 'stabilityai/stable-diffusion-2-1',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (768, 768)
        }
    }

    def __init__(self,
                 model_id: Optional[str] = None,
                 model_type: str = "sd1.5",
                 device: str = "cuda"):
        """
        Initialize the TKG-DM pipeline with the specified model architecture.

        Args:
            model_id: Specific model ID (optional, uses the default for model_type)
            model_type: Model architecture type ('sd1.5', 'sdxl', 'sd2.1')
            device: Device to load the model on
        """
        self.device = device
        self.dtype = torch.float16 if device == "cuda" else torch.float32
        self.model_type = model_type

        # Get the model configuration
        if model_type not in self.MODEL_CONFIGS:
            raise ValueError(f"Unsupported model type: {model_type}. "
                             f"Supported: {list(self.MODEL_CONFIGS.keys())}")

        self.config = self.MODEL_CONFIGS[model_type]
        self.model_id = model_id or self.config['default_model']

        # Auto-detect the model type for custom models if needed
        if model_id and model_id != self.config['default_model']:
            detected_type = self._detect_model_type(model_id)
            if detected_type and detected_type != model_type:
                print(f"🔄 Auto-detected model type: {detected_type} for {model_id}")
                self.model_type = detected_type
                self.config = self.MODEL_CONFIGS[detected_type]

        # Load the appropriate pipeline
        self._load_pipeline()

        # Initialize the noise optimizer with model-specific parameters
        self.noise_optimizer = TKGDMNoiseOptimizer(device, self.dtype)

    def _detect_model_type(self, model_id: str) -> Optional[str]:
        """Auto-detect the model type based on model ID patterns."""
        model_id_lower = model_id.lower()

        # SDXL detection patterns
        if any(pattern in model_id_lower for pattern in ['xl', 'sdxl', 'stable-diffusion-xl']):
            return 'sdxl'

        # SD 2.x detection patterns
        if any(pattern in model_id_lower for pattern in ['stable-diffusion-2', 'sd-2', 'v2-']):
            return 'sd2.1'

        # Default to SD 1.5 for most other cases
        return 'sd1.5'

    def _load_pipeline(self):
        """Load the diffusion pipeline based on the model type."""
        try:
            pipeline_class = self.config['pipeline_class']

            # Common pipeline arguments
            pipeline_args = {'torch_dtype': self.dtype}

            if self.model_type == 'sdxl':
                # SDXL-specific arguments (the SDXL pipeline takes no safety_checker)
                pipeline_args.update({
                    'use_safetensors': True,
                    'variant': "fp16" if self.dtype == torch.float16 else None
                })
            else:
                # Disable the safety checker for SD 1.x / 2.x pipelines
                pipeline_args.update({
                    'safety_checker': None,
                    'requires_safety_checker': False
                })

            self.pipe = pipeline_class.from_pretrained(
                self.model_id,
                **pipeline_args
            ).to(self.device)

            print(f"✅ Successfully loaded {self.model_type} model: {self.model_id}")

        except Exception as e:
            print(f"❌ Error loading {self.model_type} pipeline: {e}")
            self.pipe = None

    def get_latent_shape(self, height: int, width: int) -> Tuple[int, int, int, int]:
        """Get the latent shape for the given image dimensions."""
        scale_factor = self.config['latent_scale_factor']
        channels = self.config['latent_channels']
        return (1, channels, height // scale_factor, width // scale_factor)
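    # For a 512x512 SD 1.5 image this yields a (1, 4, 64, 64) latent, since the
    # VAE downsamples by a factor of 8 in each spatial dimension (illustrative;
    # constructing the pipeline downloads the checkpoint):
    #
    #   pipeline = TKGDMPipeline(model_type="sd1.5")
    #   print(pipeline.get_latent_shape(512, 512))  # (1, 4, 64, 64)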
    def __call__(self,
                 prompt: str,
                 negative_prompt: Optional[str] = None,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 num_inference_steps: int = 50,
                 guidance_scale: float = 7.5,
                 background_color: Optional[List[float]] = None,  # Backwards compatibility
                 channel_shifts: Optional[List[float]] = None,  # Direct channel control
                 bounding_boxes: Optional[List[Tuple[float, float, float, float]]] = None,  # Space-aware boxes
                 foreground_center: Tuple[float, float] = (0.5, 0.5),  # Legacy single region
                 foreground_size: float = 0.3,  # Legacy single region
                 target_shift_percent: float = 0.07,
                 blur_sigma: Optional[float] = None,  # Gaussian blur sigma
                 generator: Optional[torch.Generator] = None,
                 **kwargs):
        """
        Generate an image with space-aware text-to-image using TKG-DM (Section 2.2).

        Args:
            prompt: Text prompt for generation
            negative_prompt: Negative prompt
            height, width: Output dimensions (uses model defaults if None)
            num_inference_steps: Denoising steps
            guidance_scale: CFG guidance scale
            background_color: Target background RGB (0-1), for backwards compatibility
            channel_shifts: Direct latent channel shift values in [-1, 1] for each channel
            bounding_boxes: List of (x1, y1, x2, y2) normalized coords for reserved regions
            foreground_center: Foreground object center (x, y), legacy single region
            foreground_size: Foreground region size, legacy single region
            target_shift_percent: TKG-DM shift percentage (±7%)
            blur_sigma: Gaussian blur sigma for soft transitions (auto if None)
            generator: Random number generator
            **kwargs: Additional model-specific parameters

        Returns:
            Generated PIL image
        """
        if self.pipe is None:
            raise RuntimeError("Pipeline not loaded successfully")

        # Use model default dimensions if not specified
        if height is None or width is None:
            default_height, default_width = self.config['default_size']
            height = height or default_height
            width = width or default_width

        # Get the model-specific latent shape
        noise_shape = self.get_latent_shape(height, width)

        # Generate the optimized initial noise
        if bounding_boxes is not None:
            # Space-aware noise generation with reserved bounding boxes (Section 2.2)
            optimized_noise = self.noise_optimizer.generate_space_aware_noise(
                noise_shape,
                bounding_boxes=bounding_boxes,
                channel_shifts=channel_shifts,
                background_color=background_color,
                target_shift_percent=target_shift_percent,
                blur_sigma=blur_sigma
            )
        elif channel_shifts is not None:
            # Direct channel shifts with the legacy single region
            optimized_noise = self.noise_optimizer.generate_direct_channel_noise(
                noise_shape,
                channel_shifts=channel_shifts,
                foreground_center=foreground_center,
                foreground_size=foreground_size,
                target_shift_percent=target_shift_percent
            )
        else:
            # Fall back to the background color method with the legacy single region
            bg_color = background_color or [0.0, 1.0, 0.0]
            optimized_noise = self.noise_optimizer.generate_chroma_key_noise(
                noise_shape,
                background_color=bg_color,
                foreground_center=foreground_center,
                foreground_size=foreground_size,
                target_shift_percent=target_shift_percent
            )

        # Prepare pipeline arguments
        pipe_args = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'height': height,
            'width': width,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'latents': optimized_noise,
            'generator': generator
        }

        # Add model-specific arguments
        if self.model_type == 'sdxl':
            # SDXL supports additional conditioning parameters
            pipe_args.update({
                'denoising_end': kwargs.get('denoising_end', None),
                'original_size': kwargs.get('original_size', (width, height)),
                'target_size': kwargs.get('target_size', (width, height)),
                'crops_coords_top_left': kwargs.get('crops_coords_top_left', (0, 0))
            })

        # Filter out None values
        pipe_args = {k: v for k, v in pipe_args.items() if v is not None}

        # Generate the image using the optimized noise
        with torch.no_grad():
            result = self.pipe(**pipe_args)
            image = result.images[0]

        return image
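
# A minimal end-to-end sketch, assuming the default SD 1.5 checkpoint can be
# downloaded; the prompt, box, and shift values below are illustrative only:
if __name__ == "__main__":
    pipeline = TKGDMPipeline(
        model_type="sd1.5",
        device="cuda" if torch.cuda.is_available() else "cpu"
    )
    if pipeline.pipe is not None:
        image = pipeline(
            prompt="a product photo of a sneaker, studio lighting",
            bounding_boxes=[(0.3, 0.3, 0.7, 0.8)],  # reserve the center for the subject
            channel_shifts=[0.0, 1.0, 1.0, 0.0],    # push the background toward green
            num_inference_steps=30
        )
        image.save("tkgdm_demo.png")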