import torch
from typing import Optional, Tuple, List
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
)


class TKGDMNoiseOptimizer:
    """
    TKG-DM: Training-free Chroma Key Content Generation Diffusion Model
    Implementation of non-reactive noise optimization for background control
    """
    
    def __init__(self, device: str = "cuda", dtype: torch.dtype = None):
        self.device = device
        self.dtype = dtype or (torch.float16 if device == "cuda" else torch.float32)
        
    def calculate_initial_ratio(self, noise_tensor: torch.Tensor, channel: int) -> float:
        """Calculate initial positive ratio for a specific channel"""
        channel_data = noise_tensor[:, channel, :, :]
        positive_pixels = (channel_data > 0).sum().float()
        total_pixels = channel_data.numel()
        return (positive_pixels / total_pixels).item()
    
    def optimize_channel_shift(self, 
                             noise_tensor: torch.Tensor, 
                             channel: int, 
                             target_shift: float,
                             max_iterations: int = 100,
                             tolerance: float = 1e-4) -> float:
        """
        Optimize channel mean shift to achieve target positive ratio
        
        Args:
            noise_tensor: Initial noise tensor [B, C, H, W]
            channel: Latent channel index to optimize (SD latent channels, not RGB)
            target_shift: Desired shift in positive ratio
            max_iterations: Maximum optimization iterations
            tolerance: Convergence tolerance
            
        Returns:
            Optimal channel shift value
        """
        initial_ratio = self.calculate_initial_ratio(noise_tensor, channel)
        target_ratio = initial_ratio + target_shift
        
        # Binary search for optimal shift
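        # The positive ratio increases monotonically with the added mean shift,
        # so plain bisection on delta is guaranteed to converge.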
        delta_min, delta_max = -10.0, 10.0
        delta_optimal = 0.0
        
        for _ in range(max_iterations):
            delta_mid = (delta_min + delta_max) / 2
            
            # Apply shift and calculate new ratio
            shifted_tensor = noise_tensor.clone()
            shifted_tensor[:, channel, :, :] += delta_mid
            current_ratio = self.calculate_initial_ratio(shifted_tensor, channel)
            
            if abs(current_ratio - target_ratio) < tolerance:
                delta_optimal = delta_mid
                break
                
            if current_ratio < target_ratio:
                delta_min = delta_mid
            else:
                delta_max = delta_mid
                
            delta_optimal = delta_mid
        
        return delta_optimal
    
    def create_bounding_box_mask(self,
                               height: int,
                               width: int,
                               bounding_boxes: List[Tuple[float, float, float, float]],
                               sigma: Optional[float] = None) -> torch.Tensor:
        """
        Create binary occupancy map from axis-aligned bounding boxes with Gaussian blur
        
        Args:
            height, width: Mask dimensions
            bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates [0,1]
            sigma: Gaussian blur standard deviation (auto-calculated if None)
            
        Returns:
            Soft transition mask M_blur with values in [0,1]
        """
        # Create binary occupancy map
        binary_mask = torch.zeros((height, width), device=self.device)
        
        for x1, y1, x2, y2 in bounding_boxes:
            # Convert normalized coordinates to pixel coordinates
            px1 = int(x1 * width)
            py1 = int(y1 * height)
            px2 = int(x2 * width)
            py2 = int(y2 * height)
            
            # Ensure valid bounds
            px1, px2 = max(0, min(px1, px2)), min(width, max(px1, px2))
            py1, py2 = max(0, min(py1, py2)), min(height, max(py1, py2))
            
            # Mark pixels inside bounding box
            binary_mask[py1:py2, px1:px2] = 1.0
        
        # Apply Gaussian blur for soft transition
        if sigma is None:
            # Standard deviation proportional to shorter side as per paper
            sigma = min(height, width) * 0.02  # 2% of shorter side
        
        # Convert to tensor format for Gaussian blur
        mask_4d = binary_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
        
        # Create Gaussian blur kernel
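        # Roughly 6*sigma spans ±3 standard deviations (>99.7% of the Gaussian's
        # mass); the adjustment below keeps the kernel width odd.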
        kernel_size = int(6 * sigma + 1)  # 6-sigma rule
        if kernel_size % 2 == 0:
            kernel_size += 1
        
        # Manual Gaussian blur implementation
        blur_mask = self._gaussian_blur_2d(mask_4d, kernel_size, sigma)
        
        return blur_mask.squeeze(0).squeeze(0)  # [H, W]
    
    def _gaussian_blur_2d(self, tensor: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
        """Apply 2D Gaussian blur to tensor"""
        # Create 1D Gaussian kernel
        coords = torch.arange(kernel_size, dtype=tensor.dtype, device=tensor.device)
        coords = coords - kernel_size // 2
        
        kernel = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        kernel = kernel / kernel.sum()
        kernel = kernel.view(1, 1, 1, -1)
        
        # Apply separable Gaussian blur
        # Horizontal blur
        padding = kernel_size // 2
        blurred = torch.nn.functional.conv2d(
            tensor, kernel, padding=(0, padding)
        )
        
        # Vertical blur
        kernel_t = kernel.transpose(-1, -2)
        blurred = torch.nn.functional.conv2d(
            blurred, kernel_t, padding=(padding, 0)
        )
        
        return blurred

    def create_gaussian_mask(self, 
                           height: int, 
                           width: int, 
                           center_x: float, 
                           center_y: float, 
                           sigma: float) -> torch.Tensor:
        """
        Create 2D Gaussian mask for noise blending
        
        Args:
            height, width: Mask dimensions
            center_x, center_y: Gaussian center (normalized 0-1)
            sigma: Gaussian spread parameter
            
        Returns:
            2D Gaussian mask tensor
        """
        y_coords = torch.linspace(0, 1, height).view(-1, 1)
        x_coords = torch.linspace(0, 1, width).view(1, -1)
        
        # Calculate distances from center
        dx = x_coords - center_x
        dy = y_coords - center_y
        
        # Gaussian formula
        mask = torch.exp(-((dx**2 + dy**2) / (2 * sigma**2)))
        
        return mask.to(self.device)
    
    def apply_color_shift(self, 
                         noise_tensor: torch.Tensor,
                         target_color: List[float],
                         target_shifts: List[float]) -> torch.Tensor:
        """
        Apply latent channel shifts to noise tensor
        
        Args:
            noise_tensor: Input noise [B, C, H, W] (typically 4 channels for SD latent)
            target_color: Target RGB color [R, G, B] (0-1 range); currently unused, kept for API compatibility
            target_shifts: Channel shift amounts for all latent channels
            
        Returns:
            Color-shifted noise tensor
        """
        shifted_noise = noise_tensor.clone()
        
        # Apply shifts to all available latent channels
        num_channels = min(len(target_shifts), noise_tensor.shape[1])
        
        for channel in range(num_channels):
            if abs(target_shifts[channel]) > 1e-6:
                delta = self.optimize_channel_shift(
                    noise_tensor, channel, target_shifts[channel]
                )
                shifted_noise[:, channel, :, :] += delta
        
        return shifted_noise
    
    def blend_noise_with_mask(self, 
                            original_noise: torch.Tensor,
                            shifted_noise: torch.Tensor,
                            mask: torch.Tensor) -> torch.Tensor:
        """
        Blend original and shifted noise using space-aware formula from paper:
        ε_masked = ε + M_blur(ε_shifted - ε)
        
        Args:
            original_noise: Original Gaussian noise ε [B, C, H, W]
            shifted_noise: Non-reactive shifted noise ε_shifted [B, C, H, W]  
            mask: Soft transition mask M_blur [H, W] with values in [0,1]
            
        Returns:
            Composite noise tensor ε_masked
        """
        # Expand mask to match noise dimensions and ensure correct dtype
        mask_expanded = mask.unsqueeze(0).unsqueeze(0).to(original_noise.dtype)  # [1, 1, H, W]
        mask_expanded = mask_expanded.expand_as(original_noise)
        
        # Apply space-aware blending formula: ε_masked = ε + M_blur(ε_shifted - ε)
        # Inside reserved boxes (M_blur ≈ 1): dominated by mean-shifted noise (nonreactive)
        # Outside boxes (M_blur ≈ 0): reduces to ordinary Gaussian noise
        blended_noise = original_noise + mask_expanded * (shifted_noise - original_noise)
        
        return blended_noise
    
    def generate_chroma_key_noise(self,
                                 shape: Tuple[int, int, int, int],
                                 background_color: List[float] = [0.0, 1.0, 0.0],  # Green screen
                                 foreground_center: Tuple[float, float] = (0.5, 0.5),
                                 foreground_size: float = 0.3,
                                 target_shift_percent: float = 0.07) -> torch.Tensor:
        """
        Generate optimized noise for chroma key content generation using TKG-DM method
        Handles latent space channels correctly:
        - Channels 0,3: Luminance and color information 
        - Channels 1,2: Color channels (positive = pink/yellow, negative = red/blue)
        
        Args:
            shape: Noise tensor shape [B, C, H, W] (typically [1, 4, 64, 64] for SD)
            background_color: Target background RGB color (0-1)
            foreground_center: Center of foreground object (x, y) normalized
            foreground_size: Size of foreground region (sigma parameter)
            target_shift_percent: Target shift percentage (±7% as per paper)
            
        Returns:
            Optimized noise tensor for chroma key generation
        """
        batch_size, channels, height, width = shape
        
        # Generate initial random noise with correct dtype
        original_noise = torch.randn(shape, device=self.device, dtype=self.dtype)
        
        # Calculate target shifts for latent space channels based on color mapping
        target_shifts = []
        
        # Convert RGB to latent space color shifts
        r, g, b = background_color[0], background_color[1], background_color[2]
        
        for c in range(channels):
            initial_ratio = self.calculate_initial_ratio(original_noise, c)
            
            if c == 0:  # Channel 0: Luminance
                # Luminance based on overall brightness
                brightness = (r + g + b) / 3.0
                target_shift = target_shift_percent if brightness > 0.5 else -target_shift_percent
                
            elif c == 1:  # Channel 1: Pink/Yellow vs Red/Blue (first color channel)
                # Pink/Yellow (positive) vs Red/Blue (negative)
                # Green and Yellow favor positive, Red and Blue favor negative
                if g > 0.7:  # Green or Yellow (yellow also has high G)
                    target_shift = target_shift_percent  # Positive for pink/yellow
                elif r > 0.7 or b > 0.7:  # Red or Blue
                    target_shift = -target_shift_percent  # Negative for red/blue
                else:
                    target_shift = 0.0  # Neutral for other colors
                    
            elif c == 2:  # Channel 2: Pink/Yellow vs Red/Blue (second color channel)
                # Complementary to channel 1 for full color control
                if r > 0.7 and g < 0.3:  # Pure red
                    target_shift = -target_shift_percent  # Negative for red
                elif g > 0.7 and r < 0.3:  # Pure green
                    target_shift = target_shift_percent  # Positive for green
                else:
                    target_shift = 0.0
                    
            elif c == 3:  # Channel 3: Additional luminance/color information
                # Secondary luminance control
                if b > 0.7:  # Blue background
                    target_shift = -target_shift_percent  # Negative for blue
                elif r > 0.7 and g > 0.7:  # Yellow/orange
                    target_shift = target_shift_percent * 0.8  # Moderate positive
                else:
                    # Without this branch, target_shift would silently keep the
                    # value assigned for the previous channel
                    target_shift = 0.0
                
            else:
                target_shift = 0.0  # No shift for additional channels
                
            target_shifts.append(target_shift)
            print(f"Latent Channel {c}: Initial ratio={initial_ratio:.4f}, Target shift={target_shift:+.1%}")
        
        # Apply color shifts to create background-optimized noise
        shifted_noise = self.apply_color_shift(original_noise, background_color, target_shifts)
        
        # Create Gaussian mask for foreground/background separation
        mask = self.create_gaussian_mask(
            height, width, 
            foreground_center[0], foreground_center[1], 
            foreground_size
        )
        
        # Blend original and shifted noise
        optimized_noise = self.blend_noise_with_mask(original_noise, shifted_noise, mask)
        
        return optimized_noise
    
    def generate_direct_channel_noise(self,
                                     shape: Tuple[int, int, int, int],
                                     channel_shifts: List[float],
                                     foreground_center: Tuple[float, float] = (0.5, 0.5),
                                     foreground_size: float = 0.3,
                                     target_shift_percent: float = 0.07) -> torch.Tensor:
        """
        Generate optimized noise with direct channel shift control
        
        Args:
            shape: Noise tensor shape [B, C, H, W]
            channel_shifts: Direct shift values for each channel [-1, 1]
            foreground_center: Center of foreground object (x, y) normalized
            foreground_size: Size of foreground region (sigma parameter)
            target_shift_percent: Base shift percentage (±7% as per paper)
            
        Returns:
            Optimized noise tensor with direct channel control
        """
        batch_size, channels, height, width = shape
        
        # Generate initial random noise with correct dtype
        original_noise = torch.randn(shape, device=self.device, dtype=self.dtype)
        
        # Convert user channel shifts to target shifts
        target_shifts = []
        for c in range(min(channels, len(channel_shifts))):
            # Scale user input (-1 to 1) to target shift percentage
            user_shift = channel_shifts[c]
            target_shift = user_shift * target_shift_percent
            target_shifts.append(target_shift)
            
            initial_ratio = self.calculate_initial_ratio(original_noise, c)
            print(f"Direct Channel {c}: User shift={user_shift:+.2f}, Target shift={target_shift:+.1%}, Initial ratio={initial_ratio:.4f}")
        
        # Fill remaining channels with zero shift
        while len(target_shifts) < channels:
            target_shifts.append(0.0)
        
        # Apply direct channel shifts
        shifted_noise = self.apply_color_shift(original_noise, [0.0, 0.0, 0.0], target_shifts)
        
        # Create Gaussian mask for foreground/background separation
        mask = self.create_gaussian_mask(
            height, width, 
            foreground_center[0], foreground_center[1], 
            foreground_size
        )
        
        # Blend original and shifted noise
        optimized_noise = self.blend_noise_with_mask(original_noise, shifted_noise, mask)
        
        return optimized_noise
    
    def generate_space_aware_noise(self,
                                  shape: Tuple[int, int, int, int],
                                  bounding_boxes: List[Tuple[float, float, float, float]],
                                  channel_shifts: Optional[List[float]] = None,
                                  background_color: Optional[List[float]] = None,
                                  target_shift_percent: float = 0.07,
                                  blur_sigma: Optional[float] = None) -> torch.Tensor:
        """
        Generate space-aware noise with reserved bounding boxes as per section 2.2
        
        Args:
            shape: Noise tensor shape [B, C, H, W]
            bounding_boxes: List of (x1, y1, x2, y2) normalized coordinates for reserved regions
            channel_shifts: Direct channel shifts (optional)
            background_color: RGB background color for shift calculation (fallback)
            target_shift_percent: Base shift percentage (±7% as per paper)
            blur_sigma: Gaussian blur sigma (auto-calculated if None)
            
        Returns:
            Space-aware composite noise tensor ε_masked
        """
        batch_size, channels, height, width = shape
        
        # Generate standard Gaussian tensor ε ~ N(0, I)
        epsilon = torch.randn(shape, device=self.device, dtype=self.dtype)

        # No reserved boxes: nothing to blend, so return the plain Gaussian
        # noise before doing any shift-optimization work
        if not bounding_boxes:
            return epsilon
        
        # Generate non-reactive noise ε_shifted using TKG-DM
        if channel_shifts is not None:
            # Use direct channel control
            target_shifts = []
            for c in range(min(channels, len(channel_shifts))):
                user_shift = channel_shifts[c]
                target_shift = user_shift * target_shift_percent
                target_shifts.append(target_shift)
            
            # Fill remaining channels with zero shift
            while len(target_shifts) < channels:
                target_shifts.append(0.0)
                
            epsilon_shifted = self.apply_color_shift(epsilon, [0.0, 0.0, 0.0], target_shifts)
            
        elif background_color is not None:
            # Fallback to background color method
            epsilon_shifted = self._generate_color_shifted_noise(epsilon, background_color, target_shift_percent)
            
        else:
            raise ValueError("Either channel_shifts or background_color must be provided")
        
        # Create binary occupancy map and apply Gaussian blur: M_blur = GaussianBlur(M, σ)
        mask_blur = self.create_bounding_box_mask(height, width, bounding_boxes, blur_sigma)
        
        # Apply space-aware blending: ε_masked = ε + M_blur(ε_shifted - ε)
        epsilon_masked = self.blend_noise_with_mask(epsilon, epsilon_shifted, mask_blur)
        
        print(f"🔲 Space-aware noise: {len(bounding_boxes)} reserved boxes, sigma={blur_sigma or 'auto'}")
        
        return epsilon_masked
    
    def _generate_color_shifted_noise(self, 
                                    epsilon: torch.Tensor,
                                    background_color: List[float],
                                    target_shift_percent: float) -> torch.Tensor:
        """Helper method to generate color-shifted noise from background color"""
        channels = epsilon.shape[1]
        r, g, b = background_color[0], background_color[1], background_color[2]
        
        target_shifts = []
        for c in range(channels):
            if c == 0:  # Luminance
                brightness = (r + g + b) / 3.0
                target_shift = target_shift_percent if brightness > 0.5 else -target_shift_percent
            elif c == 1:  # Color channel 1
                if g > 0.7:  # Green or yellow (yellow also has high G)
                    target_shift = target_shift_percent
                elif r > 0.7 or b > 0.7:
                    target_shift = -target_shift_percent
                else:
                    target_shift = 0.0
            elif c == 2:  # Color channel 2
                if r > 0.7 and g < 0.3:
                    target_shift = -target_shift_percent
                elif g > 0.7 and r < 0.3:
                    target_shift = target_shift_percent
                else:
                    target_shift = 0.0
            elif c == 3:  # Secondary color
                if b > 0.7:
                    target_shift = -target_shift_percent
                elif r > 0.7 and g > 0.7:
                    target_shift = target_shift_percent * 0.8
                else:
                    target_shift = 0.0
            else:
                target_shift = 0.0
            
            target_shifts.append(target_shift)
        
        return self.apply_color_shift(epsilon, background_color, target_shifts)
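
# Usage sketch for the noise optimizer on its own (illustrative; assumes a CUDA
# device and the standard 4-channel SD latent layout):
#
#   optimizer = TKGDMNoiseOptimizer(device="cuda")
#   latents = optimizer.generate_chroma_key_noise(
#       shape=(1, 4, 64, 64),              # 512x512 image -> 64x64 latent grid
#       background_color=[0.0, 1.0, 0.0],  # green screen
#       foreground_center=(0.5, 0.5),
#       foreground_size=0.3,
#   )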


class TKGDMPipeline:
    """
    Enhanced Diffusion pipeline with TKG-DM non-reactive noise
    Supports multiple model architectures: SD 1.5, SDXL, SD 2.1, etc.
    """
    
    # Model configurations for different architectures
    MODEL_CONFIGS = {
        'sd1.5': {
            'pipeline_class': StableDiffusionPipeline,
            'default_model': 'runwayml/stable-diffusion-v1-5',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (512, 512)
        },
        'sdxl': {
            'pipeline_class': StableDiffusionXLPipeline,
            'default_model': 'stabilityai/stable-diffusion-xl-base-1.0',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (1024, 1024)
        },
        'sd2.1': {
            'pipeline_class': StableDiffusionPipeline,
            'default_model': 'stabilityai/stable-diffusion-2-1',
            'latent_channels': 4,
            'latent_scale_factor': 8,
            'default_size': (768, 768)
        }
    }
    
    def __init__(self, 
                 model_id: Optional[str] = None,
                 model_type: str = "sd1.5",
                 device: str = "cuda"):
        """
        Initialize TKG-DM pipeline with specified model architecture
        
        Args:
            model_id: Specific model ID (optional, uses default for model_type)
            model_type: Model architecture type ('sd1.5', 'sdxl', 'sd2.1')
            device: Device to load model on
        """
        self.device = device
        self.dtype = torch.float16 if device == "cuda" else torch.float32
        self.model_type = model_type
        
        # Get model configuration
        if model_type not in self.MODEL_CONFIGS:
            raise ValueError(f"Unsupported model type: {model_type}. Supported: {list(self.MODEL_CONFIGS.keys())}")
        
        self.config = self.MODEL_CONFIGS[model_type]
        self.model_id = model_id or self.config['default_model']
        
        # Auto-detect model type for custom models if needed
        if model_id and model_id != self.config['default_model']:
            detected_type = self._detect_model_type(model_id)
            if detected_type and detected_type != model_type:
                print(f"🔄 Auto-detected model type: {detected_type} for {model_id}")
                self.model_type = detected_type
                self.config = self.MODEL_CONFIGS[detected_type]
        
        # Load the appropriate pipeline
        self._load_pipeline()
        
        # Initialize noise optimizer with model-specific parameters
        self.noise_optimizer = TKGDMNoiseOptimizer(device, self.dtype)
    
    def _detect_model_type(self, model_id: str) -> Optional[str]:
        """Auto-detect model type based on model ID patterns"""
        model_id_lower = model_id.lower()
        
        # SDXL detection patterns
        if any(pattern in model_id_lower for pattern in ['xl', 'sdxl', 'stable-diffusion-xl']):
            return 'sdxl'
        
        # SD 2.x detection patterns  
        if any(pattern in model_id_lower for pattern in ['stable-diffusion-2', 'sd-2', 'v2-']):
            return 'sd2.1'
        
        # Default to SD 1.5 for most other cases
        return 'sd1.5'
    
    def _load_pipeline(self):
        """Load the diffusion pipeline based on model type"""
        try:
            pipeline_class = self.config['pipeline_class']
            
            # Common pipeline arguments
            pipeline_args = {
                'torch_dtype': self.dtype,
                'safety_checker': None,
                'requires_safety_checker': False
            }
            
            # SDXL-specific arguments
            if self.model_type == 'sdxl':
                pipeline_args.update({
                    'use_safetensors': True,
                    'variant': "fp16" if self.dtype == torch.float16 else None
                })
            
            self.pipe = pipeline_class.from_pretrained(
                self.model_id,
                **pipeline_args
            ).to(self.device)
            
            print(f"✅ Successfully loaded {self.model_type} model: {self.model_id}")
            
        except Exception as e:
            print(f"❌ Error loading {self.model_type} pipeline: {e}")
            self.pipe = None
    
    def get_latent_shape(self, height: int, width: int) -> Tuple[int, int, int, int]:
        """Get latent shape for given image dimensions"""
        scale_factor = self.config['latent_scale_factor']
        channels = self.config['latent_channels']
        return (1, channels, height // scale_factor, width // scale_factor)
        
    def __call__(self,
                 prompt: str,
                 negative_prompt: Optional[str] = None,
                 height: Optional[int] = None,
                 width: Optional[int] = None,
                 num_inference_steps: int = 50,
                 guidance_scale: float = 7.5,
                 background_color: Optional[List[float]] = None,  # Backwards compatibility
                 channel_shifts: Optional[List[float]] = None,  # Direct channel control
                 bounding_boxes: Optional[List[Tuple[float, float, float, float]]] = None,  # Space-aware boxes
                 foreground_center: Tuple[float, float] = (0.5, 0.5),  # Legacy single region
                 foreground_size: float = 0.3,  # Legacy single region
                 target_shift_percent: float = 0.07,
                 blur_sigma: Optional[float] = None,  # Gaussian blur sigma
                 generator: Optional[torch.Generator] = None,
                 **kwargs) -> Image.Image:
        """
        Generate an image with space-aware text-to-image control using TKG-DM (Section 2.2)
        
        Args:
            prompt: Text prompt for generation
            negative_prompt: Negative prompt
            height, width: Output dimensions (uses model defaults if None)
            num_inference_steps: Denoising steps
            guidance_scale: CFG guidance scale
            background_color: Target background RGB (0-1) - for backwards compatibility
            channel_shifts: Direct latent channel shift values [-1,1] for each channel
            bounding_boxes: List of (x1,y1,x2,y2) normalized coords for reserved regions
            foreground_center: Foreground object center (x, y) - legacy single region
            foreground_size: Foreground region size - legacy single region
            target_shift_percent: TKG-DM shift percentage (±7%)
            blur_sigma: Gaussian blur sigma for soft transitions (auto if None)
            generator: Random number generator
            **kwargs: Additional model-specific parameters
            
        Returns:
            Generated PIL image
        """
        if self.pipe is None:
            raise RuntimeError("Pipeline not loaded successfully")
        
        # Use model default dimensions if not specified
        if height is None or width is None:
            default_height, default_width = self.config['default_size']
            height = height or default_height
            width = width or default_width
        
        # Get model-specific latent shape
        noise_shape = self.get_latent_shape(height, width)
        
        # Generate optimized initial noise
        if bounding_boxes is not None:
            # Use space-aware noise generation with reserved bounding boxes (Section 2.2)
            optimized_noise = self.noise_optimizer.generate_space_aware_noise(
                noise_shape,
                bounding_boxes=bounding_boxes,
                channel_shifts=channel_shifts,
                background_color=background_color,
                target_shift_percent=target_shift_percent,
                blur_sigma=blur_sigma
            )
        elif channel_shifts is not None:
            # Use direct channel shifts with legacy single region
            optimized_noise = self.noise_optimizer.generate_direct_channel_noise(
                noise_shape,
                channel_shifts=channel_shifts,
                foreground_center=foreground_center,
                foreground_size=foreground_size,
                target_shift_percent=target_shift_percent
            )
        else:
            # Fallback to background color method with legacy single region
            bg_color = background_color or [0.0, 1.0, 0.0]
            optimized_noise = self.noise_optimizer.generate_chroma_key_noise(
                noise_shape,
                background_color=bg_color,
                foreground_center=foreground_center,
                foreground_size=foreground_size,
                target_shift_percent=target_shift_percent
            )
        
        # Prepare pipeline arguments
        pipe_args = {
            'prompt': prompt,
            'negative_prompt': negative_prompt,
            'height': height,
            'width': width,
            'num_inference_steps': num_inference_steps,
            'guidance_scale': guidance_scale,
            'latents': optimized_noise,
            'generator': generator
        }
        
        # Add model-specific arguments
        if self.model_type == 'sdxl':
            # SDXL supports additional denoising and size-conditioning parameters
            pipe_args.update({
                'denoising_end': kwargs.get('denoising_end', None),
                'original_size': kwargs.get('original_size', (width, height)),
                'target_size': kwargs.get('target_size', (width, height)),
                'crops_coords_top_left': kwargs.get('crops_coords_top_left', (0, 0))
            })
        
        # Filter None values
        pipe_args = {k: v for k, v in pipe_args.items() if v is not None}
        
        # Generate image using optimized noise
        with torch.no_grad():
            result = self.pipe(**pipe_args)
            image = result.images[0]
        
        return image
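

if __name__ == "__main__":
    # Minimal end-to-end sketch (assumptions: a CUDA GPU, network access to
    # download the default SD 1.5 weights, and illustrative prompt/settings;
    # not part of the library API).
    pipeline = TKGDMPipeline(model_type="sd1.5", device="cuda")
    image = pipeline(
        prompt="a product photo of a sneaker, studio lighting",
        background_color=[0.0, 1.0, 0.0],  # green-screen chroma key background
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    image.save("tkgdm_chroma_key.png")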