# File size: 14,620 Bytes
# b0f83cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler
from PIL import Image
import numpy as np
from typing import List, Optional, Dict, Any
from collections import deque
import cv2

class SimpleTemporalBuffer:
    """Simplified temporal buffer for SD1.5 img2img.

    Keeps a rolling window of recent frames (and optional embeddings) plus a
    crude scalar motion history. Motion between two consecutive frames is the
    pixel displacement of the frame centre, tracked with pyramidal
    Lucas-Kanade optical flow.
    """

    def __init__(self, buffer_size: int = 6):
        """
        Args:
            buffer_size: Maximum number of frames/embeddings retained. One
                fewer motion values are kept, because motion is measured
                between consecutive frames. Must be >= 1.

        Raises:
            ValueError: If ``buffer_size`` is smaller than 1.
        """
        if buffer_size < 1:
            raise ValueError(f"buffer_size must be >= 1, got {buffer_size}")
        self.buffer_size = buffer_size
        self.frames = deque(maxlen=buffer_size)
        self.frame_embeddings = deque(maxlen=buffer_size)
        self.motion_vectors = deque(maxlen=buffer_size - 1)

    def add_frame(self, frame: Image.Image, embedding: Optional[torch.Tensor] = None):
        """Append ``frame`` and record its motion relative to the previous frame.

        A motion value is appended only when there is a previous frame and the
        centre point could be tracked; otherwise the motion history is left
        unchanged. ``embedding`` is stored as-is when given (this class does
        not read it back).
        """
        if self.frames:
            motion = self._estimate_center_motion(self.frames[-1], frame)
            if motion is not None:
                self.motion_vectors.append(motion)

        self.frames.append(frame)
        if embedding is not None:
            self.frame_embeddings.append(embedding)

    @staticmethod
    def _estimate_center_motion(prev_frame: Image.Image, curr_frame: Image.Image) -> Optional[float]:
        """Return the LK displacement (pixels) of ``curr_frame``'s centre, or None if untracked."""
        # COLOR_RGB2GRAY needs 3-channel input; normalise L/RGBA/P frames first.
        prev_rgb = prev_frame.convert("RGB")
        curr_rgb = curr_frame.convert("RGB")
        # Optical flow requires equally sized images. Pipeline outputs may be a
        # rounded size of their input, so align the previous frame to the new one.
        if prev_rgb.size != curr_rgb.size:
            prev_rgb = prev_rgb.resize(curr_rgb.size)

        prev_gray = cv2.cvtColor(np.array(prev_rgb), cv2.COLOR_RGB2GRAY)
        curr_gray = cv2.cvtColor(np.array(curr_rgb), cv2.COLOR_RGB2GRAY)

        center = np.array([[curr_rgb.width // 2, curr_rgb.height // 2]], dtype=np.float32)
        next_pts, status, _err = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, center, None)

        # status == 0 means the point was not found; next_pts is meaningless then.
        if next_pts is None or status is None or not status.ravel()[0]:
            return None
        displacement = next_pts.reshape(-1, 2)[0] - center[0]
        return float(np.linalg.norm(displacement))

    def get_reference_frame(self) -> Optional[Image.Image]:
        """Return the most recently added frame, or None if the buffer is empty."""
        return self.frames[-1] if self.frames else None

    def get_motion_context(self) -> Dict[str, Any]:
        """Summarise recent motion for conditioning the next frame.

        Returns ``{"has_motion": False, "predicted_motion": 0.0}`` when no
        motion has been recorded. Otherwise ``current_motion`` is the mean of
        the last (up to) 3 motion values, ``motion_trend`` is last minus first
        of those, and ``predicted_motion`` is ``current_motion + 0.5 * trend``.
        """
        if not self.motion_vectors:
            return {"has_motion": False, "predicted_motion": 0.0}

        recent_motion = list(self.motion_vectors)[-3:]  # Last 3 motions
        avg_motion = float(np.mean(recent_motion))
        motion_trend = recent_motion[-1] - recent_motion[0] if len(recent_motion) > 1 else 0

        predicted_motion = avg_motion + motion_trend * 0.5

        return {
            "has_motion": True,
            "current_motion": avg_motion,
            "predicted_motion": predicted_motion,
            "motion_trend": motion_trend,
            "motion_history": recent_motion
        }

class SD15FlexibleI2VGenerator:
    """Flexible I2V generator using SD1.5 img2img pipeline.

    Frames are produced autoregressively: each new frame is an img2img pass
    over the previous one. Denoising strength and prompt wording are adapted
    from the motion estimate in ``self.temporal_buffer``, and every output is
    alpha-blended with the previous frame for temporal consistency.
    """

    # Default for ``generate_with_variable_batching``. It is a tuple so the
    # default cannot be mutated across calls, unlike a list default argument.
    _DEFAULT_BATCH_PATTERN = (1, 2, 3, 2, 2, 1)

    def __init__(
        self,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Load the img2img pipeline, switch it to DDIM and move it to ``device``.

        Args:
            model_id: Model id or path given to
                ``StableDiffusionImg2ImgPipeline.from_pretrained``.
            device: Torch device string, e.g. ``"cuda"``, ``"cuda:1"`` or ``"cpu"``.
        """
        self.device = device
        # Treat indexed CUDA devices ("cuda:1") as CUDA too, not only "cuda".
        is_cuda = torch.device(device).type == "cuda"
        print(f"🚀 Loading SD1.5 pipeline on {device}...")

        # fp16 only on CUDA; CPU inference needs fp32.
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if is_cuda else torch.float32,
            safety_checker=None,
            requires_safety_checker=False
        )

        # Use DDIM for more consistent results
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to(device)

        # Enable memory efficient attention
        if is_cuda:
            self.pipe.enable_attention_slicing()
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                # xformers is optional: keep the already-configured attention.
                # A bare ``except:`` would also swallow KeyboardInterrupt.
                print("⚠️  xformers not available, using standard attention")

        self.temporal_buffer = SimpleTemporalBuffer()
        print("✅ SD1.5 Flexible I2V Generator ready!")

    def _make_generator(self, seed: Optional[int]) -> torch.Generator:
        """Create an RNG on ``self.device``, seeded with ``seed`` or non-deterministically."""
        generator = torch.Generator(device=self.device)
        if seed is not None:
            generator.manual_seed(seed)
        else:
            # A fresh Generator starts from a fixed default seed; without this,
            # "no seed" would silently reproduce the same output every call.
            generator.seed()
        return generator

    def _reset_temporal_buffer(self, init_image: Image.Image) -> None:
        """Start a new sequence: fresh buffer containing only ``init_image``."""
        self.temporal_buffer = SimpleTemporalBuffer()
        self.temporal_buffer.add_frame(init_image)

    def calculate_adaptive_strength(self, motion_context: Dict[str, Any], base_strength: float = 0.75) -> float:
        """Scale the img2img denoising strength by the measured motion.

        With no motion information, ``base_strength`` is returned unchanged.
        Otherwise it is reduced by up to 30% as ``current_motion`` rises
        towards 50 (larger motion is capped), and the result is clipped to
        [0.3, 0.9].
        """
        if not motion_context.get("has_motion", False):
            return base_strength

        motion = motion_context["current_motion"]

        # More motion = less strength (preserve more of previous frame)
        # Less motion = more strength (allow more change)
        motion_factor = np.clip(motion / 50.0, 0.0, 1.0)  # Normalize motion
        adaptive_strength = base_strength * (1.0 - motion_factor * 0.3)

        return float(np.clip(adaptive_strength, 0.3, 0.9))

    def enhance_prompt_with_motion(self, base_prompt: str, motion_context: Dict[str, Any]) -> str:
        """Append a motion descriptor chosen from ``current_motion`` / ``motion_trend``.

        Thresholds: motion > 30 → "fast" (if trend > 5) or "steady";
        motion > 10 → "gentle"; otherwise "subtle". The prompt is returned
        unchanged when there is no motion information.
        """
        if not motion_context.get("has_motion", False):
            return base_prompt

        motion = motion_context["current_motion"]
        trend = motion_context.get("motion_trend", 0)

        if motion > 30:
            if trend > 5:
                motion_desc = ", fast movement, dynamic motion, motion blur"
            else:
                motion_desc = ", steady movement, continuous motion"
        elif motion > 10:
            motion_desc = ", gentle movement, smooth transition"
        else:
            motion_desc = ", subtle movement, slight change"

        return base_prompt + motion_desc

    def blend_frames(self, current_frame: Image.Image, reference_frame: Image.Image, blend_ratio: float = 0.15) -> Image.Image:
        """Return ``current * (1 - blend_ratio) + reference * blend_ratio`` as a uint8 image.

        ``reference_frame`` is converted to ``current_frame``'s mode and resized
        to its size first. The pipeline's output size can differ from its input,
        and mismatched arrays cannot be blended.
        """
        if reference_frame.mode != current_frame.mode:
            reference_frame = reference_frame.convert(current_frame.mode)
        if reference_frame.size != current_frame.size:
            reference_frame = reference_frame.resize(current_frame.size)

        current_array = np.array(current_frame, dtype=np.float32)
        reference_array = np.array(reference_frame, dtype=np.float32)

        blended_array = current_array * (1 - blend_ratio) + reference_array * blend_ratio
        blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)

        return Image.fromarray(blended_array)

    @torch.no_grad()
    def generate_frame_batch(
        self,
        init_image: Image.Image,
        prompt: str,
        num_frames: int = 1,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        generator: Optional[torch.Generator] = None
    ) -> List[Image.Image]:
        """Generate ``num_frames`` frames, each img2img'd from the previous one.

        Reads and appends to ``self.temporal_buffer``; it is NOT reset here, so
        callers starting a new sequence should reset it first. When the buffer
        is non-empty, each output is blended with the buffer's latest frame
        (ratio 0.1 if current_motion > 20, else 0.2).

        Returns:
            The generated frames in order (empty if ``num_frames`` <= 0).
        """
        frames = []
        current_image = init_image

        for i in range(num_frames):
            print(f"  🎬 Generating frame {i+1}/{num_frames}")

            motion_context = self.temporal_buffer.get_motion_context()
            adaptive_strength = self.calculate_adaptive_strength(motion_context, strength)
            enhanced_prompt = self.enhance_prompt_with_motion(prompt, motion_context)

            print(f"    📊 Motion: {motion_context.get('current_motion', 0):.1f}, Strength: {adaptive_strength:.2f}")

            result = self.pipe(
                prompt=enhanced_prompt,
                image=current_image,
                strength=adaptive_strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            )

            generated_frame = result.images[0]

            # Apply temporal consistency blending
            if len(self.temporal_buffer.frames) > 0:
                reference_frame = self.temporal_buffer.get_reference_frame()
                blend_ratio = 0.1 if motion_context.get("current_motion", 0) > 20 else 0.2
                generated_frame = self.blend_frames(generated_frame, reference_frame, blend_ratio)

            self.temporal_buffer.add_frame(generated_frame)
            frames.append(generated_frame)

            # Use generated frame as input for next iteration
            current_image = generated_frame

        return frames

    def generate_i2v_sequence(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 16,
        frames_per_batch: int = 2,  # Key parameter - batch size!
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        seed: Optional[int] = None
    ) -> List[Image.Image]:
        """Generate a sequence in fixed-size batches.

        The temporal buffer is reset to contain only ``init_image``.

        Args:
            seed: Seeds the RNG for reproducible output; ``None`` picks a
                fresh non-deterministic seed.

        Returns:
            ``total_frames`` frames, the first being ``init_image`` (just
            ``[init_image]`` when ``total_frames`` <= 1).

        Raises:
            ValueError: If ``frames_per_batch`` < 1 (no progress could be made).
        """
        if frames_per_batch < 1:
            raise ValueError(f"frames_per_batch must be >= 1, got {frames_per_batch}")

        print(f"🎯 Generating {total_frames} frames in batches of {frames_per_batch}")

        generator = self._make_generator(seed)
        self._reset_temporal_buffer(init_image)

        all_frames = [init_image]  # Start with initial frame
        current_reference = init_image

        while len(all_frames) < total_frames:
            frames_generated = len(all_frames)
            current_batch_size = min(frames_per_batch, total_frames - frames_generated)

            print(f"🚀 Batch: Generating frames {frames_generated+1}-{frames_generated+current_batch_size}")

            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=current_batch_size,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            )

            all_frames.extend(batch_frames)
            current_reference = batch_frames[-1]

            print(f"✅ Completed batch - {len(all_frames)}/{total_frames} frames done")

        return all_frames

    def generate_with_variable_batching(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 24,
        batch_pattern: Optional[List[int]] = None,  # Variable batch sizes
        **kwargs
    ) -> List[Image.Image]:
        """Generate a sequence whose batch sizes cycle through ``batch_pattern``.

        The temporal buffer is reset to contain only ``init_image``, so state
        from a previous run does not leak into this sequence.

        Args:
            batch_pattern: Positive batch sizes, cycled in order; defaults to
                ``(1, 2, 3, 2, 2, 1)``.
            **kwargs: Forwarded to ``generate_frame_batch`` (``strength``,
                ``guidance_scale``, ``num_inference_steps``, ``generator``).
                A ``seed`` is also accepted and turned into a generator unless
                an explicit ``generator`` is given, which takes precedence.

        Returns:
            ``total_frames`` frames, the first being ``init_image``.

        Raises:
            ValueError: If ``batch_pattern`` is empty or has a size < 1.
        """
        pattern = list(self._DEFAULT_BATCH_PATTERN if batch_pattern is None else batch_pattern)
        if not pattern:
            raise ValueError("batch_pattern must contain at least one batch size")
        if any(size < 1 for size in pattern):
            raise ValueError(f"batch_pattern sizes must all be >= 1, got {pattern}")

        # generate_frame_batch has no `seed` parameter; convert it here.
        seed = kwargs.pop("seed", None)
        if seed is not None and kwargs.get("generator") is None:
            kwargs["generator"] = self._make_generator(seed)

        print(f"🎨 Generating {total_frames} frames with pattern: {pattern}")

        self._reset_temporal_buffer(init_image)

        all_frames = [init_image]
        current_reference = init_image
        pattern_idx = 0

        while len(all_frames) < total_frames:
            current_batch_size = pattern[pattern_idx % len(pattern)]
            actual_batch_size = min(current_batch_size, total_frames - len(all_frames))

            print(f"📊 Pattern step {pattern_idx+1}: {actual_batch_size} frames")

            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=actual_batch_size,
                **kwargs
            )

            all_frames.extend(batch_frames)
            current_reference = batch_frames[-1]
            pattern_idx += 1

        return all_frames

def save_frames_as_gif(frames: List[Image.Image], output_path: str, duration: int = 100):
    """Save ``frames`` as an infinitely looping animated GIF.

    Args:
        frames: Frames in playback order; the first one is saved as the base
            image and the rest are appended.
        output_path: Destination file path.
        duration: Display time per frame in milliseconds.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    if not frames:
        raise ValueError("save_frames_as_gif requires at least one frame")

    first_frame, *other_frames = frames
    first_frame.save(
        output_path,
        save_all=True,
        append_images=other_frames,
        duration=duration,
        loop=0  # 0 = loop forever
    )
    print(f"💾 Saved {len(frames)} frames to {output_path}")
    print(f"πŸ’Ύ Saved {len(frames)} frames to {output_path}")

def save_frames_as_video(frames: List[Image.Image], output_path: str, fps: int = 8):
    """Save ``frames`` as a video via imageio, falling back to a GIF.

    If an ImportError occurs (imageio, or a backend it imports, is missing),
    the frames are written instead as a GIF next to ``output_path``. The GIF
    has the same base name with a ``.gif`` extension and the same frame rate.

    Args:
        frames: Frames in playback order.
        output_path: Destination video path (e.g. ``"out.mp4"``).
        fps: Frames per second.

    Raises:
        ValueError: If ``frames`` is empty.
    """
    import os

    if not frames:
        raise ValueError("save_frames_as_video requires at least one frame")

    try:
        import imageio
        with imageio.get_writer(output_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(np.array(frame))
        print(f"🎥 Saved {len(frames)} frames to {output_path}")
    except ImportError:
        print("⚠️  imageio not available, saving as GIF instead")
        # Swap only the final extension; str.replace('.mp4', ...) would hit any
        # occurrence and leave non-.mp4 names with an extension GIF can't use.
        gif_path = os.path.splitext(output_path)[0] + ".gif"
        # GIF timing is per-frame milliseconds; keep the requested frame rate.
        frame_ms = max(1, round(1000 / fps))
        save_frames_as_gif(frames, gif_path, duration=frame_ms)

# Example usage
def example_usage():
    """Demo of the SD1.5 Flexible I2V Generator.

    Runs a fixed-batch-size sequence and a variable-batch-pattern sequence
    from a flat placeholder image, and saves both as GIFs in the working
    directory.
    """
    i2v = SD15FlexibleI2VGenerator()

    # For demo, create a simple colored image (swap in a real photo for
    # meaningful motion estimates).
    init_image = Image.new('RGB', (512, 512), color='lightblue')

    # Example 1: Fixed batch size generation
    print("\n🎬 Example 1: Fixed batch size (2 frames per batch)")
    frames_fixed = i2v.generate_i2v_sequence(
        init_image=init_image,
        prompt="a peaceful lake with gentle ripples, cinematic lighting",
        total_frames=8,
        frames_per_batch=2,  # Generate 2 frames at a time
        strength=0.7,
        guidance_scale=7.5,
        num_inference_steps=20,
        seed=42
    )

    # Example 2: Variable batch pattern
    print("\n🎨 Example 2: Variable batch pattern")
    # Extra keyword arguments are forwarded to generate_frame_batch, which
    # takes a torch.Generator (not a seed), so pass a seeded generator.
    frames_variable = i2v.generate_with_variable_batching(
        init_image=init_image,
        prompt="a cat walking through a garden, smooth motion",
        total_frames=12,
        batch_pattern=[1, 2, 3, 2, 1],  # Start slow, ramp up, slow down
        strength=0.75,
        guidance_scale=7.5,
        generator=torch.Generator(device=i2v.device).manual_seed(123)
    )

    save_frames_as_gif(frames_fixed, "sd15_fixed_batch.gif", duration=200)
    save_frames_as_gif(frames_variable, "sd15_variable_batch.gif", duration=150)

    print("\n🎉 SD1.5 Flexible Batch I2V Generation Complete!")
    print("✨ Key innovations demonstrated:")
    print("  - Flexible batch sizing (frames_per_batch parameter)")
    print("  - Motion-aware adaptive strength")
    print("  - Temporal consistency through frame blending")
    print("  - Variable batch patterns for dynamic control")
    print("  - Optical flow-based motion analysis")


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    example_usage()