import torch
import torch.nn.functional as F
from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler
from PIL import Image
import numpy as np
from typing import List, Optional, Dict, Any
from collections import deque
import cv2

class SimpleTemporalBuffer:
    """Simplified temporal buffer for SD1.5 img2img"""

    def __init__(self, buffer_size: int = 6):
        self.buffer_size = buffer_size
        self.frames = deque(maxlen=buffer_size)
        self.frame_embeddings = deque(maxlen=buffer_size)
        self.motion_vectors = deque(maxlen=buffer_size - 1)

    def add_frame(self, frame: Image.Image, embedding: Optional[torch.Tensor] = None):
        """Add a frame to the buffer and estimate motion against the previous frame."""
        if len(self.frames) > 0:
            prev_frame = np.array(self.frames[-1])
            curr_frame = np.array(frame)

            prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY)
            curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_RGB2GRAY)

            # Track a single point at the image center with pyramidal Lucas-Kanade
            # and use its displacement as a coarse motion estimate.
            center = np.array([[frame.width // 2, frame.height // 2]], dtype=np.float32)
            flow, status, _ = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, center, None)

            if flow is not None and status is not None and status[0]:
                motion_magnitude = np.linalg.norm(flow[0] - center[0])
                self.motion_vectors.append(motion_magnitude)

        self.frames.append(frame)
        if embedding is not None:
            self.frame_embeddings.append(embedding)

    def get_reference_frame(self) -> Optional[Image.Image]:
        """Get the most recent frame as reference"""
        return self.frames[-1] if self.frames else None

    def get_motion_context(self) -> Dict[str, Any]:
        """Get motion context for next frame generation"""
        if len(self.motion_vectors) == 0:
            return {"has_motion": False, "predicted_motion": 0.0}

        # Average the last few motion magnitudes and extrapolate the trend.
        recent_motion = list(self.motion_vectors)[-3:]
        avg_motion = np.mean(recent_motion)
        motion_trend = recent_motion[-1] - recent_motion[0] if len(recent_motion) > 1 else 0

        predicted_motion = avg_motion + motion_trend * 0.5

        return {
            "has_motion": True,
            "current_motion": avg_motion,
            "predicted_motion": predicted_motion,
            "motion_trend": motion_trend,
            "motion_history": recent_motion,
        }
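
# A minimal sketch (commented out) of how the buffer might be exercised on its
# own; `frame_a` and `frame_b` are hypothetical same-size PIL images:
#
#   buf = SimpleTemporalBuffer(buffer_size=4)
#   buf.add_frame(frame_a)
#   buf.add_frame(frame_b)            # optical flow is estimated against frame_a
#   ctx = buf.get_motion_context()    # {"has_motion": True, "current_motion": ..., ...}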

class SD15FlexibleI2VGenerator:
    """Flexible I2V generator using SD1.5 img2img pipeline"""

    def __init__(
        self,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        self.device = device
        print(f"Loading SD1.5 pipeline on {device}...")

        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            safety_checker=None,
            requires_safety_checker=False
        )

        # Swap in a DDIM scheduler, commonly used for few-step img2img sampling.
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to(device)

        # Memory optimizations on GPU.
        if device == "cuda":
            self.pipe.enable_attention_slicing()
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                print("xformers not available, using standard attention")

        self.temporal_buffer = SimpleTemporalBuffer()
        print("SD1.5 Flexible I2V Generator ready!")

    def calculate_adaptive_strength(self, motion_context: Dict[str, Any], base_strength: float = 0.75) -> float:
        """Calculate adaptive denoising strength based on motion"""
        if not motion_context.get("has_motion", False):
            return base_strength

        motion = motion_context["current_motion"]

        # More motion -> lower denoising strength, so more of the previous
        # frame survives and the sequence stays temporally coherent.
        motion_factor = np.clip(motion / 50.0, 0.0, 1.0)
        adaptive_strength = base_strength * (1.0 - motion_factor * 0.3)

        return float(np.clip(adaptive_strength, 0.3, 0.9))
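
    # Worked example (illustrative numbers): with base_strength=0.75 and a
    # tracked motion of 25 px, motion_factor = 25 / 50 = 0.5, so the result is
    # 0.75 * (1 - 0.5 * 0.3) = 0.6375; larger motion lowers the denoising
    # strength so more of the previous frame is preserved.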

    def enhance_prompt_with_motion(self, base_prompt: str, motion_context: Dict[str, Any]) -> str:
        """Enhance prompt based on motion context"""
        if not motion_context.get("has_motion", False):
            return base_prompt

        motion = motion_context["current_motion"]
        trend = motion_context.get("motion_trend", 0)

        # Map motion magnitude (and whether it is accelerating) to a descriptive suffix.
        if motion > 30:
            if trend > 5:
                motion_desc = ", fast movement, dynamic motion, motion blur"
            else:
                motion_desc = ", steady movement, continuous motion"
        elif motion > 10:
            motion_desc = ", gentle movement, smooth transition"
        else:
            motion_desc = ", subtle movement, slight change"

        return base_prompt + motion_desc
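
    # Example of the suffixes above: with current_motion = 35 and motion_trend = 6
    # the prompt gains ", fast movement, dynamic motion, motion blur"; with
    # current_motion = 8 it gains ", subtle movement, slight change".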

    def blend_frames(self, current_frame: Image.Image, reference_frame: Image.Image, blend_ratio: float = 0.15) -> Image.Image:
        """Blend current frame with reference for temporal consistency"""
        current_array = np.array(current_frame, dtype=np.float32)
        reference_array = np.array(reference_frame, dtype=np.float32)

        # Convex blend: blend_ratio is the fraction of the reference frame
        # carried over into the new frame.
        blended_array = current_array * (1 - blend_ratio) + reference_array * blend_ratio
        blended_array = np.clip(blended_array, 0, 255).astype(np.uint8)

        return Image.fromarray(blended_array)

    @torch.no_grad()
    def generate_frame_batch(
        self,
        init_image: Image.Image,
        prompt: str,
        num_frames: int = 1,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        generator: Optional[torch.Generator] = None
    ) -> List[Image.Image]:
        """Generate a batch of frames using img2img"""
        frames = []
        current_image = init_image

        for i in range(num_frames):
            print(f"  Generating frame {i + 1}/{num_frames}")

            # Motion-aware conditioning: adapt strength and prompt to the
            # motion observed in the temporal buffer.
            motion_context = self.temporal_buffer.get_motion_context()
            adaptive_strength = self.calculate_adaptive_strength(motion_context, strength)
            enhanced_prompt = self.enhance_prompt_with_motion(prompt, motion_context)

            print(f"    Motion: {motion_context.get('current_motion', 0):.1f}, Strength: {adaptive_strength:.2f}")

            result = self.pipe(
                prompt=enhanced_prompt,
                image=current_image,
                strength=adaptive_strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            )
            generated_frame = result.images[0]

            # Blend with the previous frame for temporal consistency; blend less
            # when motion is high so movement is not smeared away.
            if len(self.temporal_buffer.frames) > 0:
                reference_frame = self.temporal_buffer.get_reference_frame()
                blend_ratio = 0.1 if motion_context.get("current_motion", 0) > 20 else 0.2
                generated_frame = self.blend_frames(generated_frame, reference_frame, blend_ratio)

            self.temporal_buffer.add_frame(generated_frame)
            frames.append(generated_frame)

            # The generated frame becomes the init image for the next iteration.
            current_image = generated_frame

        return frames
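
    # Note: frames are produced auto-regressively (each output is fed back as
    # the next init image), so small errors can accumulate over long sequences;
    # the blending step above is intended to damp that drift.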

    def generate_i2v_sequence(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 16,
        frames_per_batch: int = 2,
        strength: float = 0.75,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 20,
        seed: Optional[int] = None
    ) -> List[Image.Image]:
        """Generate I2V sequence with flexible batch sizes"""
        print(f"Generating {total_frames} frames in batches of {frames_per_batch}")

        generator = torch.Generator(device=self.device)
        if seed is not None:
            generator.manual_seed(seed)

        # Reset the temporal buffer and seed it with the initial image.
        self.temporal_buffer = SimpleTemporalBuffer()
        self.temporal_buffer.add_frame(init_image)

        # The init image counts as frame 1 of the returned sequence.
        all_frames = [init_image]
        frames_generated = 1
        current_reference = init_image

        while frames_generated < total_frames:
            remaining_frames = total_frames - frames_generated
            current_batch_size = min(frames_per_batch, remaining_frames)

            print(f"Batch: generating frames {frames_generated + 1}-{frames_generated + current_batch_size}")

            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=current_batch_size,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            )

            all_frames.extend(batch_frames)
            frames_generated += current_batch_size

            # The last frame of this batch seeds the next batch.
            current_reference = batch_frames[-1]

            print(f"Completed batch - {frames_generated}/{total_frames} frames done")

        return all_frames

    def generate_with_variable_batching(
        self,
        init_image: Image.Image,
        prompt: str,
        total_frames: int = 24,
        batch_pattern: Optional[List[int]] = None,
        seed: Optional[int] = None,
        **kwargs
    ) -> List[Image.Image]:
        """Generate with dynamic batch pattern"""
        if batch_pattern is None:
            batch_pattern = [1, 2, 3, 2, 2, 1]

        # Turn an optional seed into a torch.Generator, since
        # generate_frame_batch has no seed parameter of its own.
        if seed is not None and "generator" not in kwargs:
            kwargs["generator"] = torch.Generator(device=self.device).manual_seed(seed)

        print(f"Generating {total_frames} frames with pattern: {batch_pattern}")

        all_frames = [init_image]
        frames_generated = 1
        current_reference = init_image
        pattern_idx = 0

        while frames_generated < total_frames:
            # Cycle through the batch pattern, never exceeding the frames left.
            current_batch_size = batch_pattern[pattern_idx % len(batch_pattern)]
            remaining_frames = total_frames - frames_generated
            actual_batch_size = min(current_batch_size, remaining_frames)

            print(f"Pattern step {pattern_idx + 1}: {actual_batch_size} frames")

            batch_frames = self.generate_frame_batch(
                init_image=current_reference,
                prompt=prompt,
                num_frames=actual_batch_size,
                **kwargs
            )

            all_frames.extend(batch_frames)
            frames_generated += actual_batch_size
            current_reference = batch_frames[-1]
            pattern_idx += 1

        return all_frames[:total_frames]
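
    # Note: as in generate_i2v_sequence, the init image counts toward
    # total_frames, and the pattern is cycled with modulo indexing, so any
    # total_frames value works regardless of the pattern length.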


def save_frames_as_gif(frames: List[Image.Image], output_path: str, duration: int = 100):
    """Save frames as animated GIF"""
    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=duration,
        loop=0
    )
    print(f"Saved {len(frames)} frames to {output_path}")


def save_frames_as_video(frames: List[Image.Image], output_path: str, fps: int = 8):
    """Save frames as MP4 video"""
    try:
        import imageio
        with imageio.get_writer(output_path, fps=fps) as writer:
            for frame in frames:
                writer.append_data(np.array(frame))
        print(f"Saved {len(frames)} frames to {output_path}")
    except ImportError:
        print("imageio not available, saving as GIF instead")
        save_frames_as_gif(frames, output_path.replace('.mp4', '.gif'))
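
# Note (assumption about the environment): imageio usually needs an FFmpeg
# plugin such as imageio-ffmpeg installed to write .mp4 files; the ImportError
# fallback above only covers the imageio package itself being missing.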


def example_usage():
    """Example of how to use the SD1.5 Flexible I2V Generator"""
    generator = SD15FlexibleI2VGenerator()

    # Placeholder init image; replace with a real photo or rendering.
    init_image = Image.new('RGB', (512, 512), color='lightblue')

    print("\nExample 1: Fixed batch size (2 frames per batch)")
    frames_fixed = generator.generate_i2v_sequence(
        init_image=init_image,
        prompt="a peaceful lake with gentle ripples, cinematic lighting",
        total_frames=8,
        frames_per_batch=2,
        strength=0.7,
        guidance_scale=7.5,
        num_inference_steps=20,
        seed=42
    )

    print("\nExample 2: Variable batch pattern")
    frames_variable = generator.generate_with_variable_batching(
        init_image=init_image,
        prompt="a cat walking through a garden, smooth motion",
        total_frames=12,
        batch_pattern=[1, 2, 3, 2, 1],
        strength=0.75,
        guidance_scale=7.5,
        seed=123
    )

    save_frames_as_gif(frames_fixed, "sd15_fixed_batch.gif", duration=200)
    save_frames_as_gif(frames_variable, "sd15_variable_batch.gif", duration=150)

    print("\nSD1.5 Flexible Batch I2V Generation Complete!")
    print("Key innovations demonstrated:")
    print("  - Flexible batch sizing (frames_per_batch parameter)")
    print("  - Motion-aware adaptive strength")
    print("  - Temporal consistency through frame blending")
    print("  - Variable batch patterns for dynamic control")
    print("  - Optical flow-based motion analysis")


if __name__ == "__main__":
    example_usage()