import torch
from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler
from PIL import Image
import numpy as np
from typing import List, Optional, Dict, Any
from collections import deque
import cv2


class SimpleTemporalBuffer:
    """Simplified temporal buffer for SD1.5 img2img"""

    def __init__(self, buffer_size: int = 6):
        self.buffer_size = buffer_size
        self.frames = deque(maxlen=buffer_size)
        self.frame_embeddings = deque(maxlen=buffer_size)
        self.motion_vectors = deque(maxlen=buffer_size - 1)

    def add_frame(self, frame: Image.Image, embedding: Optional[torch.Tensor] = None):
        """Add frame to buffer"""
        # Calculate optical flow if we have previous frames
        if len(self.frames) > 0:
            prev_frame = np.array(self.frames[-1])
            curr_frame = np.array(frame)
            # Convert to grayscale for optical flow
            prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY)
            curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_RGB2GRAY)
            # Track a single point at the image center with sparse Lucas-Kanade flow
            center = np.array([[frame.width // 2, frame.height // 2]], dtype=np.float32)
            flow, status, _err = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, center, None)
            if flow is not None:
                motion_magnitude = float(np.linalg.norm(flow[0] - center[0]))
                self.motion_vectors.append(motion_magnitude)
        self.frames.append(frame)
        if embedding is not None:
            self.frame_embeddings.append(embedding)

    def get_reference_frame(self) -> Optional[Image.Image]:
        """Get most recent frame as reference"""
        return self.frames[-1] if self.frames else None

    def get_motion_context(self) -> Dict[str, Any]:
        """Get motion context for next frame generation"""
        if len(self.motion_vectors) == 0:
            return {"has_motion": False, "predicted_motion": 0.0}
        # Simple motion prediction based on recent vectors
        recent_motion = list(self.motion_vectors)[-3:]  # Last 3 motions
        avg_motion = float(np.mean(recent_motion))
        motion_trend = recent_motion[-1] - recent_motion[0] if len(recent_motion) > 1 else 0.0
        predicted_motion = avg_motion + motion_trend * 0.5
        return {
            "has_motion": True,
            "current_motion": avg_motion,
            "predicted_motion": predicted_motion,
            "motion_trend": motion_trend,
            "motion_history": recent_motion,
        }
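

# Usage sketch (illustrative only; `first_frame` and `second_frame` below are
# placeholder PIL images, not variables defined in this file):
#
#   buf = SimpleTemporalBuffer(buffer_size=4)
#   buf.add_frame(first_frame)
#   buf.add_frame(second_frame)
#   print(buf.get_motion_context())  # magnitude, trend and predicted motion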


class SD15FlexibleI2VGenerator:
    """Flexible I2V generator using SD1.5 img2img pipeline"""

    def __init__(
        self,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.device = device
        print(f"🚀 Loading SD1.5 pipeline on {device}...")
        # Load pipeline with DDIM scheduler for better img2img
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            safety_checker=None,
            requires_safety_checker=False
        )
        # Use DDIM for more consistent results
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to(device)
        # Enable memory-efficient attention
        if device == "cuda":
            self.pipe.enable_attention_slicing()
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                print("⚠️ xformers not available, using standard attention")
        self.temporal_buffer = SimpleTemporalBuffer()
        print("✅ SD1.5 Flexible I2V Generator ready!")

    def calculate_adaptive_strength(self, motion_context: Dict[str, Any], base_strength: float = 0.75) -> float:
        """Calculate adaptive denoising strength based on motion"""
        if not motion_context.get("has_motion", False):
            return base_strength
        motion = motion_context["current_motion"]
        # More motion = less strength (preserve more of previous frame)
        # Less motion = more strength (allow more change)
        motion_factor = np.clip(motion / 50.0, 0.0, 1.0)  # Normalize motion to [0, 1]
        adaptive_strength = base_strength * (1.0 - motion_factor * 0.3)
        return float(np.clip(adaptive_strength, 0.3, 0.9))
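
    # Worked example (illustrative numbers, assuming the default base_strength=0.75):
    #   motion = 10  -> motion_factor 0.2 -> 0.75 * (1 - 0.06) = 0.705
    #   motion >= 50 -> motion_factor 1.0 -> 0.75 * 0.70 = 0.525
    # The result is always clipped to the [0.3, 0.9] range.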

    def enhance_prompt_with_motion(self, base_prompt: str, motion_context: Dict[str, Any]) -> str:
        """Enhance prompt based on motion context"""
        if not motion_context.get("has_motion", False):
            return base_prompt
        motion = motion_context["current_motion"]
        trend = motion_context.get("motion_trend", 0)
        # Add motion descriptors based on analysis
        if motion > 30:
            if trend > 5:
                motion_desc = ", fast movement, dynamic motion, motion blur"
            else:
                motion_desc = ", steady movement, continuous motion"
        elif motion > 10:
            motion_desc = ", gentle movement, smooth transition"
        else:
            motion_desc = ", subtle movement, slight change"
        return base_prompt + motion_desc

    def blend_frames(self, current_frame: Image.Image, reference_frame: Image.Image, blend_ratio: float = 0.15) -> Image.Image:
        """Blend current frame with reference for temporal consistency"""
        current_array = np.array(current_frame, dtype=np.float32)
        reference_array = np.array(reference_frame, dtype=np.float32)
        # Blend frames
        blended_array = current_array * (1 - blend_ratio) + reference_array * blend_ratio
        blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)
        return Image.fromarray(blended_array)
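
    # Example: with blend_ratio=0.15 each output pixel is 0.85 * current frame
    # + 0.15 * reference frame; higher ratios trade per-frame change for
    # temporal stability.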

    @torch.no_grad()
    def generate_frame_batch(
        self,
        init_image: Image.Image,
        prompt: str,
        num_frames: int = 1,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        generator: Optional[torch.Generator] = None
    ) -> List[Image.Image]:
        """Generate a batch of frames using img2img"""
        frames = []
        current_image = init_image
        for i in range(num_frames):
            print(f" 🎬 Generating frame {i+1}/{num_frames}")
            # Get motion context
            motion_context = self.temporal_buffer.get_motion_context()
            # Adaptive parameters based on motion
            adaptive_strength = self.calculate_adaptive_strength(motion_context, strength)
            enhanced_prompt = self.enhance_prompt_with_motion(prompt, motion_context)
            print(f" 📊 Motion: {motion_context.get('current_motion', 0):.1f}, Strength: {adaptive_strength:.2f}")
            # Generate frame
            result = self.pipe(
                prompt=enhanced_prompt,
                image=current_image,
                strength=adaptive_strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            )
            generated_frame = result.images[0]
            # Apply temporal consistency blending
            if len(self.temporal_buffer.frames) > 0:
                reference_frame = self.temporal_buffer.get_reference_frame()
                blend_ratio = 0.1 if motion_context.get("current_motion", 0) > 20 else 0.2
                generated_frame = self.blend_frames(generated_frame, reference_frame, blend_ratio)
            # Update buffer
            self.temporal_buffer.add_frame(generated_frame)
            frames.append(generated_frame)
            # Use generated frame as input for next iteration
            current_image = generated_frame
        return frames

    def generate_i2v_sequence(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 16,
        frames_per_batch: int = 2,  # Key parameter - batch size!
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        seed: Optional[int] = None
    ) -> List[Image.Image]:
        """Generate I2V sequence with flexible batch sizes"""
        print(f"🎯 Generating {total_frames} frames in batches of {frames_per_batch}")
        # Setup generator
        generator = torch.Generator(device=self.device)
        if seed is not None:
            generator.manual_seed(seed)
        # Reset temporal buffer and add initial frame
        self.temporal_buffer = SimpleTemporalBuffer()
        self.temporal_buffer.add_frame(init_image)
        all_frames = [init_image]  # Start with initial frame
        frames_generated = 1
        current_reference = init_image
        # Generate in batches
        while frames_generated < total_frames:
            remaining_frames = total_frames - frames_generated
            current_batch_size = min(frames_per_batch, remaining_frames)
            print(f"🚀 Batch: Generating frames {frames_generated+1}-{frames_generated+current_batch_size}")
            # Generate batch
            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=current_batch_size,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            )
            # Add to results
            all_frames.extend(batch_frames)
            frames_generated += current_batch_size
            # Update reference for next batch
            current_reference = batch_frames[-1]
            print(f"✅ Completed batch - {frames_generated}/{total_frames} frames done")
        return all_frames
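
    # Note: the returned list includes the initial frame, so e.g. total_frames=8
    # yields the init image plus 7 newly generated frames (8 images overall).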

    def generate_with_variable_batching(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 24,
        batch_pattern: Optional[List[int]] = None,  # Variable batch sizes, e.g. [1, 2, 3, 2, 2, 1]
        seed: Optional[int] = None,
        **kwargs
    ) -> List[Image.Image]:
        """Generate with dynamic batch pattern"""
        if batch_pattern is None:
            batch_pattern = [1, 2, 3, 2, 2, 1]
        print(f"🎨 Generating {total_frames} frames with pattern: {batch_pattern}")
        # Setup generator (generate_frame_batch expects a torch.Generator, not a raw seed)
        generator = torch.Generator(device=self.device)
        if seed is not None:
            generator.manual_seed(seed)
        # Reset temporal buffer and add initial frame
        self.temporal_buffer = SimpleTemporalBuffer()
        self.temporal_buffer.add_frame(init_image)
        all_frames = [init_image]
        frames_generated = 1
        current_reference = init_image
        pattern_idx = 0
        while frames_generated < total_frames:
            # Get current batch size from the pattern (cycled with modulo)
            current_batch_size = batch_pattern[pattern_idx % len(batch_pattern)]
            remaining_frames = total_frames - frames_generated
            actual_batch_size = min(current_batch_size, remaining_frames)
            print(f"📊 Pattern step {pattern_idx+1}: {actual_batch_size} frames")
            # Generate batch with current settings
            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=actual_batch_size,
                generator=generator,
                **kwargs
            )
            all_frames.extend(batch_frames)
            frames_generated += actual_batch_size
            current_reference = batch_frames[-1]
            pattern_idx += 1
        return all_frames  # Initial frame plus total_frames - 1 generated frames
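

# Illustrative batching arithmetic (using the values from Example 2 below):
# with total_frames=12 and batch_pattern=[1, 2, 3, 2, 1], the pattern cycles and
# the batch sizes become 1, 2, 3, 2, 1, 1, 1 -> 11 generated frames, which plus
# the initial frame gives 12 images in total.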


def save_frames_as_gif(frames: List[Image.Image], output_path: str, duration: int = 100):
    """Save frames as animated GIF"""
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )
    print(f"💾 Saved {len(frames)} frames to {output_path}")


def save_frames_as_video(frames: List[Image.Image], output_path: str, fps: int = 8):
    """Save frames as MP4 video"""
    try:
        import imageio
        with imageio.get_writer(output_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(np.array(frame))
        print(f"🎥 Saved {len(frames)} frames to {output_path}")
    except ImportError:
        print("⚠️ imageio not available, saving as GIF instead")
        save_frames_as_gif(frames, output_path.replace('.mp4', '.gif'))
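

# Note: writing MP4 via imageio generally also needs an ffmpeg backend (e.g. the
# imageio-ffmpeg package). Depending on the imageio version, a missing backend
# may surface as an error other than ImportError, in which case the GIF fallback
# above is not triggered.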


# Example usage
def example_usage():
    """Example of how to use the SD1.5 Flexible I2V Generator"""
    # Initialize generator
    generator = SD15FlexibleI2VGenerator()

    # Load or create an initial image
    # For the demo, create a simple solid-color image
    init_image = Image.new('RGB', (512, 512), color='lightblue')

    # Example 1: Fixed batch size generation
    print("\n🎬 Example 1: Fixed batch size (2 frames per batch)")
    frames_fixed = generator.generate_i2v_sequence(
        init_image=init_image,
        prompt="a peaceful lake with gentle ripples, cinematic lighting",
        total_frames=8,
        frames_per_batch=2,  # Generate 2 frames at a time
        strength=0.7,
        guidance_scale=7.5,
        num_inference_steps=20,
        seed=42
    )

    # Example 2: Variable batch pattern
    print("\n🎨 Example 2: Variable batch pattern")
    frames_variable = generator.generate_with_variable_batching(
        init_image=init_image,
        prompt="a cat walking through a garden, smooth motion",
        total_frames=12,
        batch_pattern=[1, 2, 3, 2, 1],  # Start slow, ramp up, slow down
        strength=0.75,
        guidance_scale=7.5,
        seed=123
    )

    # Save results
    save_frames_as_gif(frames_fixed, "sd15_fixed_batch.gif", duration=200)
    save_frames_as_gif(frames_variable, "sd15_variable_batch.gif", duration=150)

    print("\n🎉 SD1.5 Flexible Batch I2V Generation Complete!")
    print("✨ Key features demonstrated:")
    print(" - Flexible batch sizing (frames_per_batch parameter)")
    print(" - Motion-aware adaptive strength")
    print(" - Temporal consistency through frame blending")
    print(" - Variable batch patterns for dynamic control")
    print(" - Optical flow-based motion analysis")


if __name__ == "__main__":
    example_usage()