import torch
from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler
from PIL import Image
import numpy as np
from typing import List, Optional, Dict, Any
from collections import deque
import cv2


class SimpleTemporalBuffer:
    """Simplified temporal buffer for SD1.5 img2img"""

    def __init__(self, buffer_size: int = 6):
        self.buffer_size = buffer_size
        self.frames = deque(maxlen=buffer_size)
        self.frame_embeddings = deque(maxlen=buffer_size)
        self.motion_vectors = deque(maxlen=buffer_size - 1)

    def add_frame(self, frame: Image.Image, embedding: Optional[torch.Tensor] = None):
        """Add frame to buffer"""
        # Calculate optical flow if we have previous frames
        if len(self.frames) > 0:
            prev_frame = np.array(self.frames[-1])
            curr_frame = np.array(frame)

            # Convert to grayscale for optical flow
            prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY)
            curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_RGB2GRAY)

            # Sparse Lucas-Kanade flow on the frame centre point, used as a cheap motion proxy
            center_point = np.array([[[frame.width // 2, frame.height // 2]]], dtype=np.float32)
            next_points, status, _err = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, center_point, None)

            if next_points is not None and status[0][0] == 1:
                motion_magnitude = np.linalg.norm(next_points[0][0] - center_point[0][0])
                self.motion_vectors.append(float(motion_magnitude))

        self.frames.append(frame)
        if embedding is not None:
            self.frame_embeddings.append(embedding)

    def get_reference_frame(self) -> Optional[Image.Image]:
        """Get most recent frame as reference"""
        return self.frames[-1] if self.frames else None

    def get_motion_context(self) -> Dict[str, Any]:
        """Get motion context for next frame generation"""
        if len(self.motion_vectors) == 0:
            return {"has_motion": False, "predicted_motion": 0.0}

        # Simple motion prediction based on recent vectors
        recent_motion = list(self.motion_vectors)[-3:]  # Last 3 motions
        avg_motion = np.mean(recent_motion)
        motion_trend = recent_motion[-1] - recent_motion[0] if len(recent_motion) > 1 else 0
        predicted_motion = avg_motion + motion_trend * 0.5

        return {
            "has_motion": True,
            "current_motion": avg_motion,
            "predicted_motion": predicted_motion,
            "motion_trend": motion_trend,
            "motion_history": recent_motion,
        }


class SD15FlexibleI2VGenerator:
    """Flexible I2V generator using SD1.5 img2img pipeline"""

    def __init__(
        self,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.device = device
        print(f"šŸš€ Loading SD1.5 pipeline on {device}...")

        # Load pipeline with DDIM scheduler for better img2img
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            safety_checker=None,
            requires_safety_checker=False,
        )

        # Use DDIM for more consistent results
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to(device)

        # Enable memory efficient attention
        if device == "cuda":
            self.pipe.enable_attention_slicing()
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                print("āš ļø xformers not available, using standard attention")

        self.temporal_buffer = SimpleTemporalBuffer()
        print("āœ… SD1.5 Flexible I2V Generator ready!")

    def calculate_adaptive_strength(self, motion_context: Dict[str, Any], base_strength: float = 0.75) -> float:
        """Calculate adaptive denoising strength based on motion"""
        if not motion_context.get("has_motion", False):
            return base_strength

        motion = motion_context["current_motion"]

        # More motion = less strength (preserve more of previous frame)
        # Less motion = more strength (allow more change)
        motion_factor = np.clip(motion / 50.0, 0.0, 1.0)  # Normalize motion
        adaptive_strength = base_strength * (1.0 - motion_factor * 0.3)

        return np.clip(adaptive_strength, 0.3, 0.9)

    def enhance_prompt_with_motion(self, base_prompt: str, motion_context: Dict[str, Any]) -> str:
        """Enhance prompt based on motion context"""
        if not motion_context.get("has_motion", False):
            return base_prompt

        motion = motion_context["current_motion"]
        trend = motion_context.get("motion_trend", 0)

        # Add motion descriptors based on analysis
        if motion > 30:
            if trend > 5:
                motion_desc = ", fast movement, dynamic motion, motion blur"
            else:
                motion_desc = ", steady movement, continuous motion"
        elif motion > 10:
            motion_desc = ", gentle movement, smooth transition"
        else:
            motion_desc = ", subtle movement, slight change"

        return base_prompt + motion_desc

    def blend_frames(
        self,
        current_frame: Image.Image,
        reference_frame: Image.Image,
        blend_ratio: float = 0.15,
    ) -> Image.Image:
        """Blend current frame with reference for temporal consistency"""
        current_array = np.array(current_frame, dtype=np.float32)
        reference_array = np.array(reference_frame, dtype=np.float32)

        # Blend frames
        blended_array = current_array * (1 - blend_ratio) + reference_array * blend_ratio
        blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)

        return Image.fromarray(blended_array)

    @torch.no_grad()
    def generate_frame_batch(
        self,
        init_image: Image.Image,
        prompt: str,
        num_frames: int = 1,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        generator: Optional[torch.Generator] = None,
    ) -> List[Image.Image]:
        """Generate a batch of frames using img2img"""
        frames = []
        current_image = init_image

        for i in range(num_frames):
            print(f"  šŸŽ¬ Generating frame {i + 1}/{num_frames}")

            # Get motion context
            motion_context = self.temporal_buffer.get_motion_context()

            # Adaptive parameters based on motion
            adaptive_strength = self.calculate_adaptive_strength(motion_context, strength)
            enhanced_prompt = self.enhance_prompt_with_motion(prompt, motion_context)

            print(f"    šŸ“Š Motion: {motion_context.get('current_motion', 0):.1f}, Strength: {adaptive_strength:.2f}")

            # Generate frame
            result = self.pipe(
                prompt=enhanced_prompt,
                image=current_image,
                strength=adaptive_strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
            )

            generated_frame = result.images[0]

            # Apply temporal consistency blending
            if len(self.temporal_buffer.frames) > 0:
                reference_frame = self.temporal_buffer.get_reference_frame()
                blend_ratio = 0.1 if motion_context.get("current_motion", 0) > 20 else 0.2
                generated_frame = self.blend_frames(generated_frame, reference_frame, blend_ratio)

            # Update buffer
            self.temporal_buffer.add_frame(generated_frame)
            frames.append(generated_frame)

            # Use generated frame as input for next iteration
            current_image = generated_frame

        return frames

    def generate_i2v_sequence(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 16,
        frames_per_batch: int = 2,  # Key parameter - batch size!
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        seed: Optional[int] = None,
    ) -> List[Image.Image]:
        """Generate I2V sequence with flexible batch sizes"""
        print(f"šŸŽÆ Generating {total_frames} frames in batches of {frames_per_batch}")

        # Setup generator
        generator = torch.Generator(device=self.device)
        if seed is not None:
            generator.manual_seed(seed)

        # Reset temporal buffer and add initial frame
        self.temporal_buffer = SimpleTemporalBuffer()
        self.temporal_buffer.add_frame(init_image)

        all_frames = [init_image]  # Start with initial frame
        frames_generated = 1
        current_reference = init_image

        # Generate in batches
        while frames_generated < total_frames:
            remaining_frames = total_frames - frames_generated
            current_batch_size = min(frames_per_batch, remaining_frames)

            print(f"šŸš€ Batch: Generating frames {frames_generated + 1}-{frames_generated + current_batch_size}")

            # Generate batch
            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=current_batch_size,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator,
            )

            # Add to results
            all_frames.extend(batch_frames)
            frames_generated += current_batch_size

            # Update reference for next batch
            current_reference = batch_frames[-1]

            print(f"āœ… Completed batch - {frames_generated}/{total_frames} frames done")

        return all_frames

    def generate_with_variable_batching(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 24,
        batch_pattern: List[int] = [1, 2, 3, 2, 2, 1],  # Variable batch sizes
        seed: Optional[int] = None,
        **kwargs,
    ) -> List[Image.Image]:
        """Generate with dynamic batch pattern"""
        print(f"šŸŽØ Generating {total_frames} frames with pattern: {batch_pattern}")

        # Setup generator here so a seed kwarg is not forwarded to generate_frame_batch
        generator = torch.Generator(device=self.device)
        if seed is not None:
            generator.manual_seed(seed)

        # Reset temporal buffer and add initial frame
        self.temporal_buffer = SimpleTemporalBuffer()
        self.temporal_buffer.add_frame(init_image)

        all_frames = [init_image]
        frames_generated = 1
        current_reference = init_image
        pattern_idx = 0

        while frames_generated < total_frames:
            # Get current batch size from pattern
            current_batch_size = batch_pattern[pattern_idx % len(batch_pattern)]
            remaining_frames = total_frames - frames_generated
            actual_batch_size = min(current_batch_size, remaining_frames)

            print(f"šŸ“Š Pattern step {pattern_idx + 1}: {actual_batch_size} frames")

            # Generate batch with current settings
            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=actual_batch_size,
                generator=generator,
                **kwargs,
            )

            all_frames.extend(batch_frames)
            frames_generated += actual_batch_size
            current_reference = batch_frames[-1]
            pattern_idx += 1

        return all_frames[:total_frames]  # Initial frame plus generated frames


def save_frames_as_gif(frames: List[Image.Image], output_path: str, duration: int = 100):
    """Save frames as animated GIF"""
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0,
    )
    print(f"šŸ’¾ Saved {len(frames)} frames to {output_path}")


def save_frames_as_video(frames: List[Image.Image], output_path: str, fps: int = 8):
    """Save frames as MP4 video"""
    try:
        import imageio

        with imageio.get_writer(output_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(np.array(frame))
        print(f"šŸŽ„ Saved {len(frames)} frames to {output_path}")
    except ImportError:
        print("āš ļø imageio not available, saving as GIF instead")
        save_frames_as_gif(frames, output_path.replace('.mp4', '.gif'))


# Example usage
def example_usage():
    """Example of how to use the SD1.5 Flexible I2V Generator"""

    # Initialize generator
    generator = SD15FlexibleI2VGenerator()

    # Load or create initial image
    # For demo, create a simple colored image
    init_image = Image.new('RGB', (512, 512), color='lightblue')
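
    # Alternative (hedged sketch, not part of the original demo): start from a real
    # photo instead of the flat colour above. "my_photo.jpg" is a placeholder path.
    # init_image = Image.open("my_photo.jpg").convert("RGB").resize((512, 512))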
    # Example 1: Fixed batch size generation
    print("\nšŸŽ¬ Example 1: Fixed batch size (2 frames per batch)")
    frames_fixed = generator.generate_i2v_sequence(
        init_image=init_image,
        prompt="a peaceful lake with gentle ripples, cinematic lighting",
        total_frames=8,
        frames_per_batch=2,  # Generate 2 frames at a time
        strength=0.7,
        guidance_scale=7.5,
        num_inference_steps=20,
        seed=42,
    )

    # Example 2: Variable batch pattern
    print("\nšŸŽØ Example 2: Variable batch pattern")
    frames_variable = generator.generate_with_variable_batching(
        init_image=init_image,
        prompt="a cat walking through a garden, smooth motion",
        total_frames=12,
        batch_pattern=[1, 2, 3, 2, 1],  # Start slow, ramp up, slow down
        strength=0.75,
        guidance_scale=7.5,
        seed=123,
    )

    # Save results
    save_frames_as_gif(frames_fixed, "sd15_fixed_batch.gif", duration=200)
    save_frames_as_gif(frames_variable, "sd15_variable_batch.gif", duration=150)

    print("\nšŸŽ‰ SD1.5 Flexible Batch I2V Generation Complete!")
    print("✨ Key innovations demonstrated:")
    print("   - Flexible batch sizing (frames_per_batch parameter)")
    print("   - Motion-aware adaptive strength")
    print("   - Temporal consistency through frame blending")
    print("   - Variable batch patterns for dynamic control")
    print("   - Optical flow-based motion analysis")


if __name__ == "__main__":
    example_usage()
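
# Optional (hedged sketch): export the sequences as MP4 instead of GIF via the
# save_frames_as_video helper defined above. This assumes imageio (plus an ffmpeg
# backend such as imageio-ffmpeg) is installed; the output file names are
# illustrative, not part of the original demo.
#
#     save_frames_as_video(frames_fixed, "sd15_fixed_batch.mp4", fps=8)
#     save_frames_as_video(frames_variable, "sd15_variable_batch.mp4", fps=8)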