import torch
from typing import Optional, Tuple, List
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)
class TKGDMNoiseOptimizer:
"""
TKG-DM: Training-free Chroma Key Content Generation Diffusion Model
Implementation of non-reactive noise optimization for background control
"""
    def __init__(self, device: str = "cuda", dtype: Optional[torch.dtype] = None):
self.device = device
self.dtype = dtype or (torch.float16 if device == "cuda" else torch.float32)
def calculate_initial_ratio(self, noise_tensor: torch.Tensor, channel: int) -> float:
"""Calculate initial positive ratio for a specific channel"""
channel_data = noise_tensor[:, channel, :, :]
positive_pixels = (channel_data > 0).sum().float()
total_pixels = channel_data.numel()
return (positive_pixels / total_pixels).item()
def optimize_channel_shift(self,
noise_tensor: torch.Tensor,
channel: int,
target_shift: float,
max_iterations: int = 100,
tolerance: float = 1e-4) -> float:
"""
Optimize channel mean shift to achieve target positive ratio
Args:
noise_tensor: Initial noise tensor [B, C, H, W]
channel: Channel index to optimize (0=R, 1=G, 2=B)
target_shift: Desired shift in positive ratio
max_iterations: Maximum optimization iterations
tolerance: Convergence tolerance
Returns:
Optimal channel shift value
"""
        initial_ratio = self.calculate_initial_ratio(noise_tensor, channel)
        target_ratio = initial_ratio + target_shift
        channel_data = noise_tensor[:, channel, :, :]
        total_pixels = channel_data.numel()
        # Binary search for the optimal shift: the positive ratio increases
        # monotonically with the shift, so bisection converges
        delta_min, delta_max = -10.0, 10.0
        delta_optimal = 0.0
        for _ in range(max_iterations):
            delta_mid = (delta_min + delta_max) / 2
            delta_optimal = delta_mid
            # Ratio after shifting only this channel (avoids cloning the full tensor)
            current_ratio = ((channel_data + delta_mid) > 0).sum().item() / total_pixels
            if abs(current_ratio - target_ratio) < tolerance:
                break
            if current_ratio < target_ratio:
                delta_min = delta_mid
            else:
                delta_max = delta_mid
        return delta_optimal
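    # Worked check for the bisection above (a sketch assuming standard normal
    # noise): for x ~ N(0, 1), P(x + delta > 0) = Phi(delta), so a target
    # positive ratio of 0.57 (the initial ~50% plus a +7% shift) corresponds
    # to delta = Phi^{-1}(0.57) ≈ 0.176. The search recovers this value
    # numerically without assuming any particular noise distribution.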
def create_bounding_box_mask(self,
height: int,
width: int,
bounding_boxes: List[Tuple[float, float, float, float]],
sigma: Optional[float] = None) -> torch.Tensor:
"""
Create binary occupancy map from axis-aligned bounding boxes with Gaussian blur
Args:
height, width: Mask dimensions
bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates [0,1]
sigma: Gaussian blur standard deviation (auto-calculated if None)
Returns:
Soft transition mask M_blur with values in [0,1]
"""
# Create binary occupancy map
binary_mask = torch.zeros((height, width), device=self.device)
for x1, y1, x2, y2 in bounding_boxes:
# Convert normalized coordinates to pixel coordinates
px1 = int(x1 * width)
py1 = int(y1 * height)
px2 = int(x2 * width)
py2 = int(y2 * height)
# Ensure valid bounds
px1, px2 = max(0, min(px1, px2)), min(width, max(px1, px2))
py1, py2 = max(0, min(py1, py2)), min(height, max(py1, py2))
# Mark pixels inside bounding box
binary_mask[py1:py2, px1:px2] = 1.0
# Apply Gaussian blur for soft transition
if sigma is None:
# Standard deviation proportional to shorter side as per paper
sigma = min(height, width) * 0.02 # 2% of shorter side
# Convert to tensor format for Gaussian blur
mask_4d = binary_mask.unsqueeze(0).unsqueeze(0) # [1, 1, H, W]
# Create Gaussian blur kernel
kernel_size = int(6 * sigma + 1) # 6-sigma rule
if kernel_size % 2 == 0:
kernel_size += 1
# Manual Gaussian blur implementation
blur_mask = self._gaussian_blur_2d(mask_4d, kernel_size, sigma)
return blur_mask.squeeze(0).squeeze(0) # [H, W]
def _gaussian_blur_2d(self, tensor: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
"""Apply 2D Gaussian blur to tensor"""
# Create 1D Gaussian kernel
coords = torch.arange(kernel_size, dtype=tensor.dtype, device=tensor.device)
coords = coords - kernel_size // 2
kernel = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
kernel = kernel / kernel.sum()
kernel = kernel.view(1, 1, 1, -1)
# Apply separable Gaussian blur
# Horizontal blur
padding = kernel_size // 2
blurred = torch.nn.functional.conv2d(
tensor, kernel, padding=(0, padding)
)
# Vertical blur
kernel_t = kernel.transpose(-1, -2)
blurred = torch.nn.functional.conv2d(
blurred, kernel_t, padding=(padding, 0)
)
return blurred
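    # Note: conv2d with zero padding slightly attenuates the mask near image
    # borders. If torchvision is available (an assumed dependency, not used
    # here), a near-equivalent is
    # torchvision.transforms.functional.gaussian_blur(
    #     tensor, [kernel_size, kernel_size], [sigma, sigma]),
    # which pads by reflection instead.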
def create_gaussian_mask(self,
height: int,
width: int,
center_x: float,
center_y: float,
sigma: float) -> torch.Tensor:
"""
Create 2D Gaussian mask for noise blending
Args:
height, width: Mask dimensions
center_x, center_y: Gaussian center (normalized 0-1)
sigma: Gaussian spread parameter
Returns:
2D Gaussian mask tensor
"""
        y_coords = torch.linspace(0, 1, height, device=self.device).view(-1, 1)
        x_coords = torch.linspace(0, 1, width, device=self.device).view(1, -1)
        # Distances from the Gaussian center
        dx = x_coords - center_x
        dy = y_coords - center_y
        # Isotropic 2D Gaussian
        mask = torch.exp(-((dx**2 + dy**2) / (2 * sigma**2)))
        return mask
def apply_color_shift(self,
noise_tensor: torch.Tensor,
target_color: List[float],
target_shifts: List[float]) -> torch.Tensor:
"""
Apply latent channel shifts to noise tensor
Args:
noise_tensor: Input noise [B, C, H, W] (typically 4 channels for SD latent)
            target_color: Target RGB color [R, G, B] (0-1 range); currently
                unused, kept for API compatibility (shifts are given per
                latent channel via target_shifts)
target_shifts: Channel shift amounts for all latent channels
Returns:
Color-shifted noise tensor
"""
shifted_noise = noise_tensor.clone()
# Apply shifts to all available latent channels
num_channels = min(len(target_shifts), noise_tensor.shape[1])
for channel in range(num_channels):
if abs(target_shifts[channel]) > 1e-6:
delta = self.optimize_channel_shift(
noise_tensor, channel, target_shifts[channel]
)
shifted_noise[:, channel, :, :] += delta
return shifted_noise
def blend_noise_with_mask(self,
original_noise: torch.Tensor,
shifted_noise: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
"""
Blend original and shifted noise using space-aware formula from paper:
ε_masked = ε + M_blur(ε_shifted - ε)
Args:
original_noise: Original Gaussian noise ε [B, C, H, W]
shifted_noise: Non-reactive shifted noise ε_shifted [B, C, H, W]
mask: Soft transition mask M_blur [H, W] with values in [0,1]
Returns:
Composite noise tensor ε_masked
"""
# Expand mask to match noise dimensions and ensure correct dtype
mask_expanded = mask.unsqueeze(0).unsqueeze(0).to(original_noise.dtype) # [1, 1, H, W]
mask_expanded = mask_expanded.expand_as(original_noise)
# Apply space-aware blending formula: ε_masked = ε + M_blur(ε_shifted - ε)
# Inside reserved boxes (M_blur ≈ 1): dominated by mean-shifted noise (nonreactive)
# Outside boxes (M_blur ≈ 0): reduces to ordinary Gaussian noise
blended_noise = original_noise + mask_expanded * (shifted_noise - original_noise)
return blended_noise
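    # Equivalent convex form of the blend: ε_masked = (1 - M_blur)·ε + M_blur·ε_shifted.
    # Since apply_color_shift adds a scalar delta per channel, ε_shifted - ε
    # is constant per channel, so blending reduces to a spatially weighted
    # mean shift M_blur·delta and the per-pixel variance stays at 1 everywhere.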
def generate_chroma_key_noise(self,
shape: Tuple[int, int, int, int],
background_color: List[float] = [0.0, 1.0, 0.0], # Green screen
foreground_center: Tuple[float, float] = (0.5, 0.5),
foreground_size: float = 0.3,
target_shift_percent: float = 0.07) -> torch.Tensor:
"""
Generate optimized noise for chroma key content generation using TKG-DM method
Handles latent space channels correctly:
- Channels 0,3: Luminance and color information
- Channels 1,2: Color channels (positive = pink/yellow, negative = red/blue)
Args:
shape: Noise tensor shape [B, C, H, W] (typically [1, 4, 64, 64] for SD)
background_color: Target background RGB color (0-1)
foreground_center: Center of foreground object (x, y) normalized
foreground_size: Size of foreground region (sigma parameter)
target_shift_percent: Target shift percentage (±7% as per paper)
Returns:
Optimized noise tensor for chroma key generation
"""
batch_size, channels, height, width = shape
# Generate initial random noise with correct dtype
original_noise = torch.randn(shape, device=self.device, dtype=self.dtype)
# Calculate target shifts for latent space channels based on color mapping
target_shifts = []
# Convert RGB to latent space color shifts
r, g, b = background_color[0], background_color[1], background_color[2]
for c in range(channels):
initial_ratio = self.calculate_initial_ratio(original_noise, c)
if c == 0: # Channel 0: Luminance
# Luminance based on overall brightness
brightness = (r + g + b) / 3.0
target_shift = target_shift_percent if brightness > 0.5 else -target_shift_percent
elif c == 1: # Channel 1: Pink/Yellow vs Red/Blue (first color channel)
# Pink/Yellow (positive) vs Red/Blue (negative)
                # Green and yellow favor positive, red and blue favor negative
                if g > 0.7:  # Green or yellow (yellow also has high green)
                    target_shift = target_shift_percent  # Positive for pink/yellow
elif r > 0.7 or b > 0.7: # Red or Blue
target_shift = -target_shift_percent # Negative for red/blue
else:
target_shift = 0.0 # Neutral for other colors
elif c == 2: # Channel 2: Pink/Yellow vs Red/Blue (second color channel)
# Complementary to channel 1 for full color control
if r > 0.7 and g < 0.3: # Pure red
target_shift = -target_shift_percent # Negative for red
elif g > 0.7 and r < 0.3: # Pure green
target_shift = target_shift_percent # Positive for green
else:
target_shift = 0.0
elif c == 3: # Channel 3: Additional luminance/color information
# Secondary luminance control
if b > 0.7: # Blue background
target_shift = -target_shift_percent # Negative for blue
elif r > 0.7 and g > 0.7: # Yellow/orange
target_shift = target_shift_percent * 0.8 # Moderate positive
else:
target_shift = 0.0 # No shift for additional channels
target_shifts.append(target_shift)
print(f"Latent Channel {c}: Initial ratio={initial_ratio:.4f}, Target shift={target_shift:+.1%}")
        # Apply color shifts to create background-optimized noise
        shifted_noise = self.apply_color_shift(original_noise, background_color, target_shifts)
        # Create Gaussian mask over the foreground, then invert it: the
        # shifted (non-reactive) noise must occupy the background, while the
        # foreground keeps ordinary Gaussian noise so the object can form
        foreground_mask = self.create_gaussian_mask(
            height, width,
            foreground_center[0], foreground_center[1],
            foreground_size
        )
        background_mask = 1.0 - foreground_mask
        # Blend original and shifted noise
        optimized_noise = self.blend_noise_with_mask(original_noise, shifted_noise, background_mask)
return optimized_noise
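    # Illustrative usage (a sketch; the latent shape follows SD 1.5's
    # 512x512 -> [1, 4, 64, 64] convention):
    #   optimizer = TKGDMNoiseOptimizer(device="cuda")
    #   latents = optimizer.generate_chroma_key_noise(
    #       (1, 4, 64, 64), background_color=[0.0, 1.0, 0.0])  # green screen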
def generate_direct_channel_noise(self,
shape: Tuple[int, int, int, int],
channel_shifts: List[float],
foreground_center: Tuple[float, float] = (0.5, 0.5),
foreground_size: float = 0.3,
target_shift_percent: float = 0.07) -> torch.Tensor:
"""
Generate optimized noise with direct channel shift control
Args:
shape: Noise tensor shape [B, C, H, W]
channel_shifts: Direct shift values for each channel [-1, 1]
foreground_center: Center of foreground object (x, y) normalized
foreground_size: Size of foreground region (sigma parameter)
target_shift_percent: Base shift percentage (±7% as per paper)
Returns:
Optimized noise tensor with direct channel control
"""
batch_size, channels, height, width = shape
# Generate initial random noise with correct dtype
original_noise = torch.randn(shape, device=self.device, dtype=self.dtype)
# Convert user channel shifts to target shifts
target_shifts = []
for c in range(min(channels, len(channel_shifts))):
# Scale user input (-1 to 1) to target shift percentage
user_shift = channel_shifts[c]
target_shift = user_shift * target_shift_percent
target_shifts.append(target_shift)
initial_ratio = self.calculate_initial_ratio(original_noise, c)
print(f"Direct Channel {c}: User shift={user_shift:+.2f}, Target shift={target_shift:+.1%}, Initial ratio={initial_ratio:.4f}")
# Fill remaining channels with zero shift
while len(target_shifts) < channels:
target_shifts.append(0.0)
# Apply direct channel shifts
shifted_noise = self.apply_color_shift(original_noise, [0.0, 0.0, 0.0], target_shifts)
        # Create Gaussian mask over the foreground and invert it so the
        # shifted noise only affects the background region
        foreground_mask = self.create_gaussian_mask(
            height, width,
            foreground_center[0], foreground_center[1],
            foreground_size
        )
        background_mask = 1.0 - foreground_mask
        # Blend original and shifted noise
        optimized_noise = self.blend_noise_with_mask(original_noise, shifted_noise, background_mask)
return optimized_noise
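    # Example: channel_shifts=[0.0, 1.0, 0.0, 0.0] applies the full +7%
    # positive-ratio shift to latent channel 1 and leaves the other channels
    # untouched; values in [-1, 1] scale target_shift_percent linearly.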
def generate_space_aware_noise(self,
shape: Tuple[int, int, int, int],
bounding_boxes: List[Tuple[float, float, float, float]],
channel_shifts: Optional[List[float]] = None,
background_color: Optional[List[float]] = None,
target_shift_percent: float = 0.07,
blur_sigma: Optional[float] = None) -> torch.Tensor:
"""
Generate space-aware noise with reserved bounding boxes as per section 2.2
Args:
shape: Noise tensor shape [B, C, H, W]
bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates for reserved regions
channel_shifts: Direct channel shifts (optional)
background_color: RGB background color for shift calculation (fallback)
target_shift_percent: Base shift percentage (±7% as per paper)
blur_sigma: Gaussian blur sigma (auto-calculated if None)
Returns:
Space-aware composite noise tensor ε_masked
"""
        batch_size, channels, height, width = shape
        # Generate standard Gaussian tensor ε ~ N(0, I)
        epsilon = torch.randn(shape, device=self.device, dtype=self.dtype)
        # No reserved boxes: the composite reduces to ordinary Gaussian noise,
        # so skip the shift computation entirely
        if not bounding_boxes:
            return epsilon
        # Generate non-reactive noise ε_shifted using TKG-DM
        if channel_shifts is not None:
            # Use direct channel control, filling missing channels with zero shift
            target_shifts = [
                channel_shifts[c] * target_shift_percent
                for c in range(min(channels, len(channel_shifts)))
            ]
            target_shifts.extend([0.0] * (channels - len(target_shifts)))
            epsilon_shifted = self.apply_color_shift(epsilon, [0.0, 0.0, 0.0], target_shifts)
        elif background_color is not None:
            # Fall back to the background-color heuristic
            epsilon_shifted = self._generate_color_shifted_noise(epsilon, background_color, target_shift_percent)
        else:
            raise ValueError("Either channel_shifts or background_color must be provided")
        # Create binary occupancy map and apply Gaussian blur: M_blur = GaussianBlur(M, σ)
mask_blur = self.create_bounding_box_mask(height, width, bounding_boxes, blur_sigma)
# Apply space-aware blending: ε_masked = ε + M_blur(ε_shifted - ε)
epsilon_masked = self.blend_noise_with_mask(epsilon, epsilon_shifted, mask_blur)
print(f"🔲 Space-aware noise: {len(bounding_boxes)} reserved boxes, sigma={blur_sigma or 'auto'}")
return epsilon_masked
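    # Illustrative usage (a sketch): reserve the right half of the canvas as
    # a non-reactive region, e.g. for a later text or logo overlay:
    #   noise = optimizer.generate_space_aware_noise(
    #       (1, 4, 64, 64),
    #       bounding_boxes=[(0.5, 0.0, 1.0, 1.0)],
    #       channel_shifts=[0.0, 1.0, 0.0, 0.0])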
def _generate_color_shifted_noise(self,
epsilon: torch.Tensor,
background_color: List[float],
target_shift_percent: float) -> torch.Tensor:
"""Helper method to generate color-shifted noise from background color"""
channels = epsilon.shape[1]
r, g, b = background_color[0], background_color[1], background_color[2]
target_shifts = []
for c in range(channels):
if c == 0: # Luminance
brightness = (r + g + b) / 3.0
target_shift = target_shift_percent if brightness > 0.5 else -target_shift_percent
            elif c == 1:  # Color channel 1
                if g > 0.7:  # Green or yellow (both have high green)
                    target_shift = target_shift_percent
                elif r > 0.7 or b > 0.7:
                    target_shift = -target_shift_percent
                else:
                    target_shift = 0.0
elif c == 2: # Color channel 2
if r > 0.7 and g < 0.3:
target_shift = -target_shift_percent
elif g > 0.7 and r < 0.3:
target_shift = target_shift_percent
else:
target_shift = 0.0
elif c == 3: # Secondary color
if b > 0.7:
target_shift = -target_shift_percent
elif r > 0.7 and g > 0.7:
target_shift = target_shift_percent * 0.8
else:
target_shift = 0.0
else:
target_shift = 0.0
target_shifts.append(target_shift)
return self.apply_color_shift(epsilon, background_color, target_shifts)
class TKGDMPipeline:
"""
    Enhanced diffusion pipeline with TKG-DM non-reactive noise
Supports multiple model architectures: SD 1.5, SDXL, SD 2.1, etc.
"""
# Model configurations for different architectures
MODEL_CONFIGS = {
'sd1.5': {
'pipeline_class': StableDiffusionPipeline,
'default_model': 'runwayml/stable-diffusion-v1-5',
'latent_channels': 4,
'latent_scale_factor': 8,
'default_size': (512, 512)
},
'sdxl': {
'pipeline_class': StableDiffusionXLPipeline,
'default_model': 'stabilityai/stable-diffusion-xl-base-1.0',
'latent_channels': 4,
'latent_scale_factor': 8,
'default_size': (1024, 1024)
},
'sd2.1': {
'pipeline_class': StableDiffusionPipeline,
'default_model': 'stabilityai/stable-diffusion-2-1',
'latent_channels': 4,
'latent_scale_factor': 8,
'default_size': (768, 768)
}
}
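    # Additional SD-style checkpoints can be supported by adding an entry
    # here, e.g. a hypothetical 'sd2.0' entry pointing at
    # 'stabilityai/stable-diffusion-2' with the same 4-channel, factor-8
    # latent layout and a (768, 768) default size.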
def __init__(self,
model_id: Optional[str] = None,
model_type: str = "sd1.5",
device: str = "cuda"):
"""
Initialize TKG-DM pipeline with specified model architecture
Args:
model_id: Specific model ID (optional, uses default for model_type)
model_type: Model architecture type ('sd1.5', 'sdxl', 'sd2.1')
device: Device to load model on
"""
self.device = device
self.dtype = torch.float16 if device == "cuda" else torch.float32
self.model_type = model_type
# Get model configuration
if model_type not in self.MODEL_CONFIGS:
raise ValueError(f"Unsupported model type: {model_type}. Supported: {list(self.MODEL_CONFIGS.keys())}")
self.config = self.MODEL_CONFIGS[model_type]
self.model_id = model_id or self.config['default_model']
# Auto-detect model type for custom models if needed
if model_id and model_id != self.config['default_model']:
detected_type = self._detect_model_type(model_id)
if detected_type and detected_type != model_type:
print(f"🔄 Auto-detected model type: {detected_type} for {model_id}")
self.model_type = detected_type
self.config = self.MODEL_CONFIGS[detected_type]
# Load the appropriate pipeline
self._load_pipeline()
# Initialize noise optimizer with model-specific parameters
self.noise_optimizer = TKGDMNoiseOptimizer(device, self.dtype)
    def _detect_model_type(self, model_id: str) -> str:
        """Auto-detect model type from model ID patterns (defaults to 'sd1.5')"""
model_id_lower = model_id.lower()
# SDXL detection patterns
if any(pattern in model_id_lower for pattern in ['xl', 'sdxl', 'stable-diffusion-xl']):
return 'sdxl'
# SD 2.x detection patterns
if any(pattern in model_id_lower for pattern in ['stable-diffusion-2', 'sd-2', 'v2-']):
return 'sd2.1'
# Default to SD 1.5 for most other cases
return 'sd1.5'
    def _load_pipeline(self):
        """Load the diffusion pipeline based on model type"""
        try:
            pipeline_class = self.config['pipeline_class']
            pipeline_args = {'torch_dtype': self.dtype}
            if self.model_type == 'sdxl':
                # SDXL-specific arguments (SDXL pipelines have no safety checker)
                pipeline_args.update({
                    'use_safetensors': True,
                    'variant': "fp16" if self.dtype == torch.float16 else None
                })
            else:
                # SD 1.x / 2.x pipelines accept a safety checker; disable it here
                pipeline_args.update({
                    'safety_checker': None,
                    'requires_safety_checker': False
                })
            self.pipe = pipeline_class.from_pretrained(
                self.model_id,
                **pipeline_args
            ).to(self.device)
            print(f"✅ Successfully loaded {self.model_type} model: {self.model_id}")
        except Exception as e:
            print(f"❌ Error loading {self.model_type} pipeline: {e}")
            self.pipe = None
def get_latent_shape(self, height: int, width: int) -> Tuple[int, int, int, int]:
"""Get latent shape for given image dimensions"""
scale_factor = self.config['latent_scale_factor']
channels = self.config['latent_channels']
return (1, channels, height // scale_factor, width // scale_factor)
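    # Example: SD 1.5 at 512x512 yields (1, 4, 64, 64); SDXL at 1024x1024
    # yields (1, 4, 128, 128) (factor-8 VAE downsampling in both cases).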
def __call__(self,
prompt: str,
negative_prompt: Optional[str] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
background_color: Optional[List[float]] = None, # Backwards compatibility
channel_shifts: Optional[List[float]] = None, # Direct channel control
bounding_boxes: Optional[List[Tuple[float, float, float, float]]] = None, # Space-aware boxes
foreground_center: Tuple[float, float] = (0.5, 0.5), # Legacy single region
foreground_size: float = 0.3, # Legacy single region
target_shift_percent: float = 0.07,
blur_sigma: Optional[float] = None, # Gaussian blur sigma
                 generator: Optional[torch.Generator] = None,
                 **kwargs) -> Image.Image:
"""
        Generate an image with space-aware TKG-DM noise (Section 2.2)
Args:
prompt: Text prompt for generation
negative_prompt: Negative prompt
height, width: Output dimensions (uses model defaults if None)
num_inference_steps: Denoising steps
guidance_scale: CFG guidance scale
background_color: Target background RGB (0-1) - for backwards compatibility
channel_shifts: Direct latent channel shift values [-1,1] for each channel
bounding_boxes: List of (x1,y1,x2,y2) normalized coords for reserved regions
foreground_center: Foreground object center (x, y) - legacy single region
foreground_size: Foreground region size - legacy single region
target_shift_percent: TKG-DM shift percentage (±7%)
blur_sigma: Gaussian blur sigma for soft transitions (auto if None)
generator: Random number generator
**kwargs: Additional model-specific parameters
Returns:
            Generated image (PIL.Image.Image, the pipeline's default output)
"""
if self.pipe is None:
raise RuntimeError("Pipeline not loaded successfully")
# Use model default dimensions if not specified
if height is None or width is None:
default_height, default_width = self.config['default_size']
height = height or default_height
width = width or default_width
# Get model-specific latent shape
noise_shape = self.get_latent_shape(height, width)
# Generate optimized initial noise
if bounding_boxes is not None:
# Use space-aware noise generation with reserved bounding boxes (Section 2.2)
optimized_noise = self.noise_optimizer.generate_space_aware_noise(
noise_shape,
bounding_boxes=bounding_boxes,
channel_shifts=channel_shifts,
background_color=background_color,
target_shift_percent=target_shift_percent,
blur_sigma=blur_sigma
)
elif channel_shifts is not None:
# Use direct channel shifts with legacy single region
optimized_noise = self.noise_optimizer.generate_direct_channel_noise(
noise_shape,
channel_shifts=channel_shifts,
foreground_center=foreground_center,
foreground_size=foreground_size,
target_shift_percent=target_shift_percent
)
else:
# Fallback to background color method with legacy single region
bg_color = background_color or [0.0, 1.0, 0.0]
optimized_noise = self.noise_optimizer.generate_chroma_key_noise(
noise_shape,
background_color=bg_color,
foreground_center=foreground_center,
foreground_size=foreground_size,
target_shift_percent=target_shift_percent
)
# Prepare pipeline arguments
pipe_args = {
'prompt': prompt,
'negative_prompt': negative_prompt,
'height': height,
'width': width,
'num_inference_steps': num_inference_steps,
'guidance_scale': guidance_scale,
'latents': optimized_noise,
'generator': generator
}
        # Add model-specific arguments
        if self.model_type == 'sdxl':
            # SDXL micro-conditioning; diffusers expects (height, width) ordering
            pipe_args.update({
                'denoising_end': kwargs.get('denoising_end', None),
                'original_size': kwargs.get('original_size', (height, width)),
                'target_size': kwargs.get('target_size', (height, width)),
                'crops_coords_top_left': kwargs.get('crops_coords_top_left', (0, 0))
            })
# Filter None values
pipe_args = {k: v for k, v in pipe_args.items() if v is not None}
# Generate image using optimized noise
with torch.no_grad():
result = self.pipe(**pipe_args)
image = result.images[0]
return image
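

# Minimal end-to-end sketch (assumes a CUDA machine and that the default
# SD 1.5 checkpoint can be fetched from the Hugging Face Hub; the prompt and
# box values are illustrative only):
if __name__ == "__main__":
    pipeline = TKGDMPipeline(model_type="sd1.5", device="cuda")
    image = pipeline(
        prompt="a product photo of a ceramic mug",
        negative_prompt="blurry, low quality",
        num_inference_steps=30,
        # Reserve the top third of the frame as a non-reactive region
        bounding_boxes=[(0.0, 0.0, 1.0, 0.33)],
        channel_shifts=[0.0, 1.0, 0.0, 0.0],
    )
    image.save("tkgdm_output.png")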