|
import gradio as gr |
|
import torch |
|
import torch.nn.functional as F |
|
from diffusers import StableDiffusionImg2ImgPipeline, DDIMScheduler |
|
from PIL import Image |
|
import numpy as np |
|
from typing import List, Optional, Dict, Any |
|
from collections import deque |
|
import cv2 |
|
import os |
|
import tempfile |
|
import imageio |
|
from datetime import datetime |
|
|
|
class SimpleTemporalBuffer: |
|
"""Simplified temporal buffer for SD1.5 img2img""" |
|
|
|
def __init__(self, buffer_size: int = 6): |
|
self.buffer_size = buffer_size |
|
self.frames = deque(maxlen=buffer_size) |
|
self.frame_embeddings = deque(maxlen=buffer_size) |
|
self.motion_vectors = deque(maxlen=buffer_size-1) |
|
|
|
def add_frame(self, frame: Image.Image, embedding: Optional[torch.Tensor] = None): |
|
"""Add frame to buffer""" |
|
try: |
|
|
|
if len(self.frames) > 0: |
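                # Estimate how far the content moved between the previous buffered frame and the incoming one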
|
prev_frame = np.array(self.frames[-1]) |
|
curr_frame = np.array(frame) |
|
|
|
|
|
prev_gray = cv2.cvtColor(prev_frame, cv2.COLOR_RGB2GRAY) |
|
curr_gray = cv2.cvtColor(curr_frame, cv2.COLOR_RGB2GRAY) |
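                # Track a single point at the frame center with sparse Lucas-Kanade optical flow.
                # cv2.calcOpticalFlowPyrLK returns (nextPts, status, err), so [0] keeps only the tracked points.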
|
|
|
|
|
flow = cv2.calcOpticalFlowPyrLK( |
|
prev_gray, curr_gray, |
|
np.array([[frame.width//2, frame.height//2]], dtype=np.float32), |
|
None |
|
)[0] |
|
|
|
if flow is not None: |
|
motion_magnitude = np.linalg.norm(flow[0] - [frame.width//2, frame.height//2]) |
|
self.motion_vectors.append(motion_magnitude) |
|
except Exception as e: |
|
print(f"Motion calculation error: {e}") |
|
|
|
self.frames.append(frame) |
|
if embedding is not None: |
|
self.frame_embeddings.append(embedding) |
|
|
|
def get_reference_frame(self) -> Optional[Image.Image]: |
|
"""Get most recent frame as reference""" |
|
return self.frames[-1] if self.frames else None |
|
|
|
def get_motion_context(self) -> Dict[str, Any]: |
|
"""Get motion context for next frame generation""" |
|
if len(self.motion_vectors) == 0: |
|
return {"has_motion": False, "predicted_motion": 0.0} |
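        # Average the last few motion magnitudes and add half of the recent trend
        # as a simple prediction of the motion expected in the next frame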
|
|
|
|
|
recent_motion = list(self.motion_vectors)[-3:] |
|
avg_motion = np.mean(recent_motion) |
|
motion_trend = recent_motion[-1] - recent_motion[0] if len(recent_motion) > 1 else 0 |
|
|
|
predicted_motion = avg_motion + motion_trend * 0.5 |
|
|
|
return { |
|
"has_motion": True, |
|
"current_motion": avg_motion, |
|
"predicted_motion": predicted_motion, |
|
"motion_trend": motion_trend, |
|
"motion_history": recent_motion |
|
} |
|
|
|
class SD15FlexibleI2VGenerator: |
|
"""Flexible I2V generator using SD1.5 img2img pipeline""" |
|
|
|
def __init__( |
|
self, |
|
model_id: str = "runwayml/stable-diffusion-v1-5", |
|
device: str = "cuda" if torch.cuda.is_available() else "cpu" |
|
): |
|
        self.model_id = model_id
        self.device = device
|
self.pipe = None |
|
self.temporal_buffer = SimpleTemporalBuffer() |
|
self.is_loaded = False |
|
|
|
def load_model(self): |
|
"""Load the SD1.5 pipeline""" |
|
if self.is_loaded: |
|
return "Model already loaded" |
|
|
|
try: |
|
            print(f"Loading SD1.5 pipeline on {self.device}...")
|
|
|
|
|
            self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.model_id,
|
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, |
|
safety_checker=None, |
|
requires_safety_checker=False |
|
) |
|
|
|
|
|
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) |
|
self.pipe = self.pipe.to(self.device) |
|
|
|
|
|
if self.device == "cuda": |
|
self.pipe.enable_attention_slicing() |
|
try: |
|
self.pipe.enable_xformers_memory_efficient_attention() |
|
                except Exception:
|
                    print("⚠️ xformers not available, using standard attention")
|
|
|
self.is_loaded = True |
|
            return "✅ Model loaded successfully!"
|
|
|
except Exception as e: |
|
            return f"❌ Error loading model: {str(e)}"
|
|
|
def calculate_adaptive_strength(self, motion_context: Dict[str, Any], base_strength: float = 0.75) -> float: |
|
"""Calculate adaptive denoising strength based on motion""" |
|
if not motion_context.get("has_motion", False): |
|
return base_strength |
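        # Map the motion magnitude (roughly 0-50 px of center-point displacement) to [0, 1] and
        # reduce the denoising strength by up to 30% when motion is high, so each new frame stays
        # anchored to the previous one; the result is clamped to a safe [0.3, 0.9] range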
|
|
|
motion = motion_context["current_motion"] |
|
|
|
|
|
|
|
motion_factor = np.clip(motion / 50.0, 0.0, 1.0) |
|
adaptive_strength = base_strength * (1.0 - motion_factor * 0.3) |
|
|
|
return np.clip(adaptive_strength, 0.3, 0.9) |
|
|
|
def enhance_prompt_with_motion(self, base_prompt: str, motion_context: Dict[str, Any]) -> str: |
|
"""Enhance prompt based on motion context""" |
|
if not motion_context.get("has_motion", False): |
|
return base_prompt |
|
|
|
motion = motion_context["current_motion"] |
|
trend = motion_context.get("motion_trend", 0) |
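        # Heuristic thresholds on the per-frame displacement (in pixels) of the tracked center point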
|
|
|
|
|
if motion > 30: |
|
if trend > 5: |
|
motion_desc = ", fast movement, dynamic motion, motion blur" |
|
else: |
|
motion_desc = ", steady movement, continuous motion" |
|
elif motion > 10: |
|
motion_desc = ", gentle movement, smooth transition" |
|
else: |
|
motion_desc = ", subtle movement, slight change" |
|
|
|
return base_prompt + motion_desc |
|
|
|
def blend_frames(self, current_frame: Image.Image, reference_frame: Image.Image, blend_ratio: float = 0.15) -> Image.Image: |
|
"""Blend current frame with reference for temporal consistency""" |
|
current_array = np.array(current_frame, dtype=np.float32) |
|
reference_array = np.array(reference_frame, dtype=np.float32) |
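        # Simple alpha blend toward the reference frame: a small blend_ratio keeps most of the
        # newly generated content while damping frame-to-frame flicker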
|
|
|
|
|
blended_array = current_array * (1 - blend_ratio) + reference_array * blend_ratio |
|
blended_array = np.clip(blended_array, 0, 255).astype(np.uint8) |
|
|
|
return Image.fromarray(blended_array) |
|
|
|
@torch.no_grad() |
|
def generate_frame_batch( |
|
self, |
|
init_image: Image.Image, |
|
prompt: str, |
|
num_frames: int = 1, |
|
strength: float = 0.75, |
|
guidance_scale: float = 7.5, |
|
num_inference_steps: int = 20, |
|
generator: Optional[torch.Generator] = None, |
|
progress_callback=None |
|
) -> List[Image.Image]: |
|
"""Generate a batch of frames using img2img""" |
|
|
|
if not self.is_loaded: |
|
raise ValueError("Model not loaded. Please load the model first.") |
|
|
|
frames = [] |
|
current_image = init_image |
|
|
|
for i in range(num_frames): |
|
if progress_callback: |
|
progress_callback(f"Generating frame {i+1}/{num_frames}") |
|
|
|
|
|
motion_context = self.temporal_buffer.get_motion_context() |
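            # Adapt the denoising strength and prompt wording to the motion observed so far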
|
|
|
|
|
adaptive_strength = self.calculate_adaptive_strength(motion_context, strength) |
|
enhanced_prompt = self.enhance_prompt_with_motion(prompt, motion_context) |
|
|
|
|
|
result = self.pipe( |
|
prompt=enhanced_prompt, |
|
image=current_image, |
|
strength=adaptive_strength, |
|
guidance_scale=guidance_scale, |
|
num_inference_steps=num_inference_steps, |
|
generator=generator |
|
) |
|
|
|
generated_frame = result.images[0] |
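            # Blend with the previous frame for temporal consistency; blend less when motion is
            # high so genuine movement is not smeared away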
|
|
|
|
|
if len(self.temporal_buffer.frames) > 0: |
|
reference_frame = self.temporal_buffer.get_reference_frame() |
|
blend_ratio = 0.1 if motion_context.get("current_motion", 0) > 20 else 0.2 |
|
generated_frame = self.blend_frames(generated_frame, reference_frame, blend_ratio) |
|
|
|
|
|
self.temporal_buffer.add_frame(generated_frame) |
|
frames.append(generated_frame) |
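            # The newly generated frame becomes the init image for the next img2img step (autoregressive chaining)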
|
|
|
|
|
current_image = generated_frame |
|
|
|
return frames |
|
|
|
def generate_i2v_sequence( |
|
self, |
|
init_image: Image.Image, |
|
prompt: str, |
|
total_frames: int = 16, |
|
frames_per_batch: int = 2, |
|
strength: float = 0.75, |
|
guidance_scale: float = 7.5, |
|
num_inference_steps: int = 20, |
|
seed: Optional[int] = None, |
|
progress_callback=None |
|
) -> List[Image.Image]: |
|
"""Generate I2V sequence with flexible batch sizes""" |
|
|
|
if not self.is_loaded: |
|
raise ValueError("Model not loaded. Please load the model first.") |
|
|
|
|
|
generator = torch.Generator(device=self.device) |
|
if seed is not None: |
|
generator.manual_seed(seed) |
|
|
|
|
|
self.temporal_buffer = SimpleTemporalBuffer() |
|
self.temporal_buffer.add_frame(init_image) |
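        # The initial image counts as the first output frame; each batch is generated from the
        # last frame of the previous batch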
|
|
|
all_frames = [init_image] |
|
frames_generated = 1 |
|
current_reference = init_image |
|
|
|
|
|
while frames_generated < total_frames: |
|
remaining_frames = total_frames - frames_generated |
|
current_batch_size = min(frames_per_batch, remaining_frames) |
|
|
|
if progress_callback: |
|
progress_callback(f"Batch: Generating frames {frames_generated+1}-{frames_generated+current_batch_size}") |
|
|
|
|
|
batch_frames = self.generate_frame_batch( |
|
init_image=current_reference, |
|
prompt=prompt, |
|
num_frames=current_batch_size, |
|
strength=strength, |
|
guidance_scale=guidance_scale, |
|
num_inference_steps=num_inference_steps, |
|
generator=generator, |
|
progress_callback=progress_callback |
|
) |
|
|
|
|
|
all_frames.extend(batch_frames) |
|
frames_generated += current_batch_size |
|
|
|
|
|
current_reference = batch_frames[-1] |
|
|
|
return all_frames |
|
|
|
|
|
generator = SD15FlexibleI2VGenerator() |
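# A minimal headless usage sketch (illustrative only; "input.png" is a hypothetical path):
#
#   generator.load_model()
#   init = Image.open("input.png").convert("RGB").resize((512, 512))
#   frames = generator.generate_i2v_sequence(
#       init, "a cat walking through a garden", total_frames=12, frames_per_batch=2, seed=42
#   )
#   print(create_frames_to_gif(frames))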
|
|
|
def load_model_interface(): |
|
"""Interface function to load the model""" |
|
status = generator.load_model() |
|
return status |
|
|
|
def create_frames_to_gif(frames: List[Image.Image], duration: int = 200) -> str: |
|
"""Convert frames to GIF and return file path""" |
|
temp_dir = tempfile.mkdtemp() |
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
gif_path = os.path.join(temp_dir, f"i2v_sequence_{timestamp}.gif") |
|
|
|
frames[0].save( |
|
gif_path, |
|
save_all=True, |
|
append_images=frames[1:], |
|
duration=duration, |
|
loop=0 |
|
) |
|
|
|
return gif_path |
|
|
|
def create_frames_to_video(frames: List[Image.Image], fps: int = 8) -> str: |
|
"""Convert frames to MP4 video and return file path""" |
|
temp_dir = tempfile.mkdtemp() |
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
video_path = os.path.join(temp_dir, f"i2v_sequence_{timestamp}.mp4") |
|
|
|
try: |
|
with imageio.get_writer(video_path, fps=fps) as writer: |
|
for frame in frames: |
|
writer.append_data(np.array(frame)) |
|
return video_path |
|
    except Exception:  # fall back to GIF if MP4 writing fails (e.g. no ffmpeg backend)
|
|
|
return create_frames_to_gif(frames, duration=int(1000/fps)) |
|
|
|
def generate_i2v_interface( |
|
init_image, |
|
prompt, |
|
total_frames, |
|
frames_per_batch, |
|
strength, |
|
guidance_scale, |
|
num_inference_steps, |
|
seed, |
|
output_format, |
|
progress=gr.Progress() |
|
): |
|
"""Main interface function for I2V generation""" |
|
|
|
if init_image is None: |
|
        return None, None, "❌ Please upload an initial image"
|
|
|
if not prompt.strip(): |
|
        return None, None, "❌ Please enter a prompt"
|
|
|
try: |
|
|
|
def update_progress(message): |
|
progress(0.5, desc=message) |
|
|
|
progress(0.1, desc="Starting generation...") |
|
|
|
|
|
if init_image.size != (512, 512): |
|
init_image = init_image.resize((512, 512), Image.Resampling.LANCZOS) |
|
|
|
|
|
frames = generator.generate_i2v_sequence( |
|
init_image=init_image, |
|
prompt=prompt, |
|
total_frames=total_frames, |
|
frames_per_batch=frames_per_batch, |
|
strength=strength, |
|
guidance_scale=guidance_scale, |
|
num_inference_steps=num_inference_steps, |
|
            seed=int(seed) if seed is not None and seed > 0 else None,
|
progress_callback=update_progress |
|
) |
|
|
|
progress(0.8, desc="Creating output file...") |
|
|
|
|
|
if output_format == "GIF": |
|
output_path = create_frames_to_gif(frames, duration=200) |
|
else: |
|
output_path = create_frames_to_video(frames, fps=8) |
|
|
|
progress(1.0, desc="Complete!") |
|
|
|
|
|
        return frames[-1], output_path, f"✅ Generated {len(frames)} frames successfully!"
|
|
|
except Exception as e: |
|
        return None, None, f"❌ Error: {str(e)}"
|
|
|
def generate_variable_pattern_interface( |
|
init_image, |
|
prompt, |
|
total_frames, |
|
batch_pattern_str, |
|
strength, |
|
guidance_scale, |
|
num_inference_steps, |
|
seed, |
|
output_format, |
|
progress=gr.Progress() |
|
): |
|
"""Interface for variable batch pattern generation""" |
|
|
|
if init_image is None: |
|
        return None, None, "❌ Please upload an initial image"
|
|
|
if not prompt.strip(): |
|
        return None, None, "❌ Please enter a prompt"
|
|
|
try: |
|
|
|
batch_pattern = [int(x.strip()) for x in batch_pattern_str.split(",")] |
|
if not batch_pattern or any(x <= 0 for x in batch_pattern): |
|
raise ValueError("Invalid batch pattern") |
|
|
|
progress(0.1, desc="Starting variable pattern generation...") |
|
|
|
|
|
if init_image.size != (512, 512): |
|
init_image = init_image.resize((512, 512), Image.Resampling.LANCZOS) |
|
|
|
|
|
frames = [init_image] |
|
frames_generated = 1 |
|
current_reference = init_image |
|
pattern_idx = 0 |
|
|
|
generator.temporal_buffer = SimpleTemporalBuffer() |
|
generator.temporal_buffer.add_frame(init_image) |
|
|
|
gen = torch.Generator(device=generator.device) |
|
        if seed is not None and seed > 0:
            gen.manual_seed(int(seed))
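        # Cycle through the batch pattern, clamping each step to the number of frames still needed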
|
|
|
while frames_generated < total_frames: |
|
current_batch_size = batch_pattern[pattern_idx % len(batch_pattern)] |
|
remaining_frames = total_frames - frames_generated |
|
actual_batch_size = min(current_batch_size, remaining_frames) |
|
|
|
progress(frames_generated / total_frames, |
|
desc=f"Pattern step {pattern_idx+1}: {actual_batch_size} frames") |
|
|
|
batch_frames = generator.generate_frame_batch( |
|
init_image=current_reference, |
|
prompt=prompt, |
|
num_frames=actual_batch_size, |
|
strength=strength, |
|
guidance_scale=guidance_scale, |
|
num_inference_steps=num_inference_steps, |
|
generator=gen |
|
) |
|
|
|
frames.extend(batch_frames) |
|
frames_generated += actual_batch_size |
|
current_reference = batch_frames[-1] |
|
pattern_idx += 1 |
|
|
|
progress(0.9, desc="Creating output file...") |
|
|
|
|
|
final_frames = frames[:total_frames+1] |
|
if output_format == "GIF": |
|
output_path = create_frames_to_gif(final_frames, duration=200) |
|
else: |
|
output_path = create_frames_to_video(final_frames, fps=8) |
|
|
|
progress(1.0, desc="Complete!") |
|
|
|
        return final_frames[-1], output_path, f"✅ Generated {len(final_frames)} frames with pattern {batch_pattern}!"
|
|
|
except Exception as e: |
|
        return None, None, f"❌ Error: {str(e)}"
|
|
|
|
|
def create_gradio_app(): |
|
"""Create the main Gradio application""" |
|
|
|
with gr.Blocks(title="SD1.5 Flexible I2V Generator", theme=gr.themes.Soft()) as app: |
|
|
|
gr.Markdown(""" |
|
        # 🎬 SD1.5 Flexible I2V Generator
|
|
|
Generate image-to-video sequences with **flexible batch processing** and **temporal consistency**! |
|
|
|
## Key Features: |
|
        - 🎯 **Flexible Batch Sizes**: Generate 1, 2, 3+ frames at a time
|
        - **Motion-Aware Processing**: Adapts based on detected motion
|
        - 🎨 **Temporal Consistency**: Smooth transitions between frames
|
        - **Variable Patterns**: Dynamic batch sizing patterns
|
""") |
|
|
|
|
|
with gr.Row(): |
|
            load_btn = gr.Button("Load SD1.5 Model", variant="primary", size="lg")
|
model_status = gr.Textbox( |
|
label="Model Status", |
|
value="Model not loaded. Click 'Load SD1.5 Model' to start.", |
|
interactive=False |
|
) |
|
|
|
load_btn.click(load_model_interface, outputs=model_status) |
|
|
|
|
|
with gr.Tabs(): |
|
|
|
|
|
            with gr.Tab("🎯 Fixed Batch Generation"):
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
init_image_1 = gr.Image( |
|
label="Initial Image", |
|
type="pil", |
|
height=300 |
|
) |
|
prompt_1 = gr.Textbox( |
|
label="Prompt", |
|
placeholder="e.g., a cat walking through a peaceful garden, cinematic lighting", |
|
lines=3 |
|
) |
|
|
|
with gr.Row(): |
|
total_frames_1 = gr.Slider( |
|
label="Total Frames", |
|
minimum=4, |
|
maximum=32, |
|
value=12, |
|
step=1 |
|
) |
|
frames_per_batch_1 = gr.Slider( |
|
label="Frames per Batch (Key Parameter!)", |
|
minimum=1, |
|
maximum=4, |
|
value=2, |
|
step=1 |
|
) |
|
|
|
with gr.Accordion("Advanced Settings", open=False): |
|
strength_1 = gr.Slider( |
|
label="Strength", |
|
minimum=0.3, |
|
maximum=0.9, |
|
value=0.75, |
|
step=0.05 |
|
) |
|
guidance_scale_1 = gr.Slider( |
|
label="Guidance Scale", |
|
minimum=3.0, |
|
maximum=15.0, |
|
value=7.5, |
|
step=0.5 |
|
) |
|
num_inference_steps_1 = gr.Slider( |
|
label="Inference Steps", |
|
minimum=10, |
|
maximum=50, |
|
value=20, |
|
step=5 |
|
) |
|
seed_1 = gr.Number( |
|
label="Seed (-1 for random)", |
|
value=-1 |
|
) |
|
output_format_1 = gr.Radio( |
|
label="Output Format", |
|
choices=["GIF", "MP4"], |
|
value="GIF" |
|
) |
|
|
|
                        generate_btn_1 = gr.Button("🎬 Generate I2V Sequence", variant="primary", size="lg")
|
|
|
with gr.Column(scale=1): |
|
preview_1 = gr.Image(label="Last Frame Preview", height=300) |
|
output_file_1 = gr.File(label="Download Generated Video/GIF") |
|
status_1 = gr.Textbox(label="Status", interactive=False) |
|
|
|
generate_btn_1.click( |
|
generate_i2v_interface, |
|
inputs=[ |
|
init_image_1, prompt_1, total_frames_1, frames_per_batch_1, |
|
strength_1, guidance_scale_1, num_inference_steps_1, seed_1, output_format_1 |
|
], |
|
outputs=[preview_1, output_file_1, status_1] |
|
) |
|
|
|
|
|
            with gr.Tab("Variable Pattern Generation"):
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
init_image_2 = gr.Image( |
|
label="Initial Image", |
|
type="pil", |
|
height=300 |
|
) |
|
prompt_2 = gr.Textbox( |
|
label="Prompt", |
|
placeholder="e.g., smooth camera movement through a scene", |
|
lines=3 |
|
) |
|
|
|
total_frames_2 = gr.Slider( |
|
label="Total Frames", |
|
minimum=6, |
|
maximum=40, |
|
value=16, |
|
step=1 |
|
) |
|
|
|
batch_pattern_2 = gr.Textbox( |
|
label="Batch Pattern (comma-separated)", |
|
value="1,2,3,2,1", |
|
placeholder="e.g., 1,2,3,2,1 or 2,4,2" |
|
) |
|
|
|
gr.Markdown(""" |
|
**Pattern Examples:** |
|
- `1,2,3,2,1` - Start slow, ramp up, slow down |
|
- `2,2,2,2` - Consistent 2-frame batches |
|
- `1,3,1,3` - Alternating single and triple |
|
""") |
|
|
|
with gr.Accordion("Advanced Settings", open=False): |
|
strength_2 = gr.Slider(label="Strength", minimum=0.3, maximum=0.9, value=0.75, step=0.05) |
|
guidance_scale_2 = gr.Slider(label="Guidance Scale", minimum=3.0, maximum=15.0, value=7.5, step=0.5) |
|
num_inference_steps_2 = gr.Slider(label="Inference Steps", minimum=10, maximum=50, value=20, step=5) |
|
seed_2 = gr.Number(label="Seed (-1 for random)", value=-1) |
|
output_format_2 = gr.Radio(label="Output Format", choices=["GIF", "MP4"], value="GIF") |
|
|
|
                        generate_btn_2 = gr.Button("🎨 Generate with Pattern", variant="primary", size="lg")
|
|
|
with gr.Column(scale=1): |
|
preview_2 = gr.Image(label="Last Frame Preview", height=300) |
|
output_file_2 = gr.File(label="Download Generated Video/GIF") |
|
status_2 = gr.Textbox(label="Status", interactive=False) |
|
|
|
generate_btn_2.click( |
|
generate_variable_pattern_interface, |
|
inputs=[ |
|
init_image_2, prompt_2, total_frames_2, batch_pattern_2, |
|
strength_2, guidance_scale_2, num_inference_steps_2, seed_2, output_format_2 |
|
], |
|
outputs=[preview_2, output_file_2, status_2] |
|
) |
|
|
|
|
|
        with gr.Accordion("Example Prompts & Tips", open=False):
|
gr.Markdown(""" |
|
            ## 🎯 Good Prompts for I2V:
|
- `a peaceful lake with gentle ripples, soft sunlight, cinematic` |
|
- `a cat slowly walking through a garden, smooth movement` |
|
- `camera slowly panning across a mountain landscape` |
|
- `a flower blooming in timelapse, natural lighting` |
|
- `gentle waves on a beach, golden hour lighting` |
|
|
|
            ## Parameter Tips:
|
- **Frames per Batch**: |
|
- `1` = Maximum consistency, slower generation |
|
- `2-3` = Balanced quality and speed |
|
- `4+` = Faster but less consistent |
|
- **Strength**: |
|
- `0.6-0.7` = Subtle changes |
|
- `0.7-0.8` = Moderate animation |
|
- `0.8-0.9` = More dramatic changes |
|
- **Batch Patterns**: |
|
- Use `1,2,3,2,1` for organic acceleration/deceleration |
|
- Use consistent values like `2,2,2` for steady pacing |
|
""") |
|
|
|
gr.Markdown(""" |
|
--- |
|
|
|
        ## **Innovation Highlights:**
|
|
|
This app demonstrates **flexible batch processing** for I2V generation: |
|
- Generate multiple frames simultaneously with `frames_per_batch` |
|
- Motion-aware strength adaptation based on optical flow |
|
- Temporal consistency through intelligent frame blending |
|
- Variable stepping patterns for dynamic control |
|
|
|
**Built with SD1.5 img2img pipeline + custom temporal processing!** |
|
""") |
|
|
|
return app |
|
|
|
if __name__ == "__main__": |
|
app = create_gradio_app() |
|
app.launch( |
|
server_name="0.0.0.0", |
|
server_port=7860, |
|
share=False, |
|
debug=True |
|
) |