import os
import gradio as gr
import json
import logging
import torch
from PIL import Image
import spaces
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video, load_image
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
import copy
import random
import numpy as np
import imageio
import time
import re
# --- LoRA related: Load LoRAs from JSON file ---
try:
    with open('loras.json', 'r') as f:
        loras = json.load(f)
except FileNotFoundError:
    print("WARNING: loras.json not found. LoRA gallery will be empty or non-functional.")
    print("Please create loras.json with entries like: [{'title': 'My LTX LoRA', 'repo': 'user/repo', 'weights': 'lora.safetensors', 'trigger_word': 'my style', 'image': 'url_to_image.jpg'}]")
    loras = []
except json.JSONDecodeError:
    print("WARNING: loras.json is not valid JSON. LoRA gallery will be empty or non-functional.")
    loras = []
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.7-dev", torch_dtype=dtype)
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("Lightricks/ltxv-spatial-upscaler-0.9.7", vae=pipe.vae, torch_dtype=dtype)
pipe.to(device)
pipe_upsample.to(device)
pipe.vae.enable_tiling()
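# Note: enable_tiling() makes the VAE decode in tiles, which should lower peak VRAM usage
# at a small speed cost; the latent upsampler shares the same VAE instance, so it benefits too.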
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1280
MAX_NUM_FRAMES = 257
FPS = 30.0
MIN_DIM_SLIDER = 256
TARGET_FIXED_SIDE = 768

last_lora = ""
last_fused = False

class calculateDuration:
    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        self.elapsed_time = self.end_time - self.start_time
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")

def update_lora_selection(evt: gr.SelectData):
    if not loras or evt.index is None or evt.index >= len(loras):
        return gr.update(), None  # No update to markdown, no selected index
    selected_lora_item = loras[evt.index]
    # new_placeholder = f"Type a prompt for {selected_lora_item['title']}"  # Not updating placeholders directly
    lora_repo = selected_lora_item["repo"]
    updated_text = f"### Selected LoRA: [{selected_lora_item['title']}](https://huggingface.co/{lora_repo}) ✨"
    if selected_lora_item.get('trigger_word'):
        updated_text += f"\nTrigger word: `{selected_lora_item['trigger_word']}`"
    return (
        # gr.update(placeholder=new_placeholder),  # Not changing prompt placeholder
        updated_text,
        evt.index,
    )

def get_huggingface_safetensors_for_ltx(link):  # Renamed for clarity
    split_link = link.split("/")
    if len(split_link) != 2:
        raise Exception("Invalid Hugging Face repository link format. Should be 'username/repository_name'.")
    print(f"Repository attempted: {link}")  # Use the combined link
    model_card = ModelCard.load(link)  # link is "username/repository_name"
    base_model = model_card.data.get("base_model")
    print(f"Base model from card: {base_model}")

    # Validate model type for LTX
    acceptable_models = {"Lightricks/LTX-Video-0.9.7-dev"}  # Key line for LTX compatibility
    models_to_check = base_model if isinstance(base_model, list) else [base_model]
    if not any(str(model).strip() in acceptable_models for model in models_to_check):  # Ensure string comparison
        raise Exception(f"Not a LoRA for a compatible LTX base model! Expected one of {acceptable_models}, found {models_to_check}")

    image_path = None
    if model_card.data.get("widget") and isinstance(model_card.data["widget"], list) and len(model_card.data["widget"]) > 0:
        image_path = model_card.data["widget"][0].get("output", {}).get("url", None)
    trigger_word = model_card.data.get("instance_prompt", "")
    image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None

    fs = HfFileSystem()
    try:
        list_of_files = fs.ls(link, detail=False)
        safetensors_name = None

        # Prioritize files common for LoRAs
        common_lora_filenames = ["lora.safetensors", "pytorch_lora_weights.safetensors"]
        for f_common in common_lora_filenames:
            if f"{link}/{f_common}" in list_of_files:
                safetensors_name = f_common
                break

        if not safetensors_name:  # Fallback to the first .safetensors file
            for file_path in list_of_files:
                filename = file_path.split("/")[-1]
                if filename.endswith(".safetensors"):
                    safetensors_name = filename
                    break

        if not safetensors_name:  # If still not found, raise an error
            raise Exception("No valid *.safetensors file found in the repository.")

        if not image_url:  # Fallback image search
            for file_path in list_of_files:
                filename = file_path.split("/")[-1]
                if filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
                    image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
                    break
    except Exception as e:
        print(f"Error accessing repository or finding safetensors: {e}")
        raise Exception(f"Could not validate Hugging Face repository '{link}' or find a .safetensors LoRA file.") from e

    # split_link[0] is the user, split_link[1] is the repo name
    return split_link[1], link, safetensors_name, trigger_word, image_url
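
# get_huggingface_safetensors_for_ltx returns (title, repo_id, weights_filename, trigger_word, image_url),
# which add_custom_lora_for_ltx below unpacks to build the gallery entry and the preview card.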

def check_custom_model_for_ltx(link_input):  # Renamed for clarity
    print(f"Checking a custom model on: {link_input}")
    if not link_input or not isinstance(link_input, str):
        raise Exception("Invalid custom LoRA input. Please provide a Hugging Face repository path (e.g., 'username/repo-name') or URL.")

    link_to_check = link_input.strip()
    if link_to_check.startswith("https://huggingface.co/"):
        link_to_check = link_to_check.replace("https://huggingface.co/", "").split("?")[0]  # Remove base URL and query params
    elif link_to_check.startswith("www.huggingface.co/"):
        link_to_check = link_to_check.replace("www.huggingface.co/", "").split("?")[0]

    # Basic check for 'user/repo' format
    if '/' not in link_to_check or len(link_to_check.split('/')) != 2:
        raise Exception("Invalid Hugging Face repository path. Use 'username/repo-name' format.")
    return get_huggingface_safetensors_for_ltx(link_to_check)

def add_custom_lora_for_ltx(custom_lora_path_input):  # Renamed for clarity
    global loras  # To modify the global loras list
    if custom_lora_path_input:
        try:
            title, repo_id, weights_filename, trigger_word, image_url = check_custom_model_for_ltx(custom_lora_path_input)
            print(f"Loaded custom LoRA: {repo_id}")

            # Create HTML card for display
            card_html = f'''
            <div class="custom_lora_card">
                <span>Loaded custom LoRA:</span>
                <div class="card_internal">
                    <img src="{image_url if image_url else 'https://huggingface.co/front/assets/huggingface_logo-noborder.svg'}" alt="{title}" style="width:80px; height:80px; object-fit:cover;" />
                    <div>
                        <h4>{title}</h4>
                        <small>Repo: {repo_id}<br>Weights: {weights_filename}<br>
                        {"Trigger: <code><b>" + trigger_word + "</b></code>" if trigger_word else "No trigger word found. If one is needed, include it in your prompt."}
                        </small>
                    </div>
                </div>
            </div>
            '''

            # Check if this LoRA (by repo_id) already exists
            existing_item_index = next((index for (index, item) in enumerate(loras) if item['repo'] == repo_id), None)
            new_item_data = {
                "image": image_url,
                "title": title,
                "repo": repo_id,
                "weights": weights_filename,
                "trigger_word": trigger_word,
                "custom": True  # Mark as custom
            }

            if existing_item_index is not None:
                loras[existing_item_index] = new_item_data  # Update existing
            else:
                loras.append(new_item_data)
                existing_item_index = len(loras) - 1

            # Update gallery content (a Gallery takes its items as `value`, not `choices`)
            gallery_choices = [(item.get("image", "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"), item["title"]) for item in loras]

            return (
                gr.update(visible=True, value=card_html),
                gr.update(visible=True),  # Show remove button
                gr.update(value=gallery_choices),  # Update gallery items
                f"Custom LoRA '{title}' added. Select it from the gallery.",  # Selected info text
                None,  # Reset selected_index state
                ""  # Clear custom LoRA input textbox
            )
        except Exception as e:
            gr.Warning(f"Invalid Custom LoRA: {e}")
            return gr.update(visible=True, value=f"<p style='color:red;'>Error adding LoRA: {e}</p>"), gr.update(visible=False), gr.update(), "", None, custom_lora_path_input
    else:  # No input
        return gr.update(visible=False), gr.update(visible=False), gr.update(), "", None, ""

def remove_custom_lora_for_ltx():  # Renamed for clarity
    global loras
    # Remove the last added custom LoRA if it's marked (simplistic: assumes one custom at a time).
    # A more robust way would be to track the index of the custom LoRA being displayed.
    # For now, find the *last* custom LoRA and remove it.
    custom_lora_indices = [i for i, item in enumerate(loras) if item.get("custom")]
    if custom_lora_indices:
        loras.pop(custom_lora_indices[-1])  # Remove the last one marked as custom

    gallery_choices = [(item.get("image", "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"), item["title"]) for item in loras]
    return gr.update(visible=False, value=""), gr.update(visible=False), gr.update(value=gallery_choices), "", None, ""

def round_to_nearest_resolution_acceptable_by_vae(height, width):
    height = height - (height % pipe.vae_spatial_compression_ratio)
    width = width - (width % pipe.vae_spatial_compression_ratio)
    return height, width
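
# Example (assuming a VAE spatial compression ratio of 32): 750 x 1000 -> 736 x 992,
# i.e. both sides are rounded *down* to the nearest multiple of the compression ratio.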

def calculate_new_dimensions(orig_w, orig_h):
    """Calculates new dimensions maintaining aspect ratio, with the longer side fixed to TARGET_FIXED_SIDE."""
    if orig_w == 0 or orig_h == 0:
        return MIN_DIM_SLIDER, MIN_DIM_SLIDER  # Avoid division by zero
    if orig_w > orig_h:  # Landscape
        new_w = TARGET_FIXED_SIDE
        new_h = int(TARGET_FIXED_SIDE * orig_h / orig_w)
    else:  # Portrait or square
        new_h = TARGET_FIXED_SIDE
        new_w = int(TARGET_FIXED_SIDE * orig_w / orig_h)

    # Ensure dimensions are at least MIN_DIM_SLIDER
    new_w = max(MIN_DIM_SLIDER, new_w)
    new_h = max(MIN_DIM_SLIDER, new_h)

    # Ensure divisibility by the VAE compression ratio (e.g., 32)
    new_h, new_w = round_to_nearest_resolution_acceptable_by_vae(new_h, new_w)
    return new_h, new_w
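
# Worked example (compression ratio of 32 assumed): a 1920x1080 upload maps to a 768-wide
# target with height 1080 / 1920 * 768 = 432, which the VAE rounding turns into (h, w) = (416, 768).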

def handle_image_upload_for_dims(image_filepath, current_h, current_w):
    if not image_filepath:
        return gr.update(value=current_h), gr.update(value=current_w)
    try:
        img = Image.open(image_filepath)
        orig_w, orig_h = img.size
        new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception as e:
        print(f"Error processing image for dimension update: {e}")
        return gr.update(value=current_h), gr.update(value=current_w)

def handle_video_upload_for_dims(video_filepath, current_h, current_w):
    if not video_filepath:
        return gr.update(value=current_h), gr.update(value=current_w)
    try:
        video_filepath_str = str(video_filepath)
        if not os.path.exists(video_filepath_str):
            print(f"Video file path does not exist for dimension update: {video_filepath_str}")
            return gr.update(value=current_h), gr.update(value=current_w)

        orig_w, orig_h = -1, -1
        with imageio.get_reader(video_filepath_str) as reader:
            meta = reader.get_meta_data()
            if 'size' in meta:
                orig_w, orig_h = meta['size']
            else:
                try:
                    first_frame = reader.get_data(0)
                    orig_h, orig_w = first_frame.shape[0], first_frame.shape[1]
                except Exception as e_frame:
                    print(f"Could not get video size from metadata or first frame: {e_frame}")
                    return gr.update(value=current_h), gr.update(value=current_w)

        if orig_w == -1 or orig_h == -1:
            print(f"Could not determine dimensions for video: {video_filepath_str}")
            return gr.update(value=current_h), gr.update(value=current_w)

        new_h, new_w = calculate_new_dimensions(orig_w, orig_h)
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception as e:
        print(f"Error processing video for dimension update: {e} (Path: {video_filepath}, Type: {type(video_filepath)})")
        return gr.update(value=current_h), gr.update(value=current_w)

def update_task_image():
    return "image-to-video"

def update_task_text():
    return "text-to-video"

def update_task_video():
    return "video-to-video"

def get_duration(prompt, negative_prompt, image, video, height, width, mode, steps, num_frames,
                 frames_to_use, seed, randomize_seed, guidance_scale, duration_input, improve_texture,
                 # New LoRA params
                 selected_lora_index, lora_scale_value,
                 progress):
    if duration_input > 7:
        return 95
    else:
        return 85
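
# Assumption: on a ZeroGPU Space, get_duration is meant as the duration callback for the
# spaces.GPU decorator, budgeting GPU seconds per call from the requested video length.
# The decorator below reflects that assumption.
@spaces.GPU(duration=get_duration)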
def generate(prompt,
             negative_prompt,
             image,
             video,
             height,
             width,
             mode,
             steps,
             num_frames_slider_val,  # Renamed to avoid conflict with internal num_frames
             frames_to_use,
             seed,
             randomize_seed,
             guidance_scale,
             duration_input,
             improve_texture=False,
             # New LoRA params
             selected_lora_index=None,
             lora_scale_value=0.8,  # Default LoRA scale
             progress=gr.Progress(track_tqdm=True)):
    effective_prompt = prompt
    global last_fused, last_lora

    # --- LoRA Handling ---
    if selected_lora_index is not None and 0 <= selected_lora_index < len(loras):
        selected_lora_data = loras[selected_lora_index]
        lora_repo_id = selected_lora_data["repo"]
        lora_weights_name = selected_lora_data.get("weights", None)
        lora_trigger = selected_lora_data.get("trigger_word", "")
        print("Last LoRA: ", last_lora)
        print("Current LoRA: ", lora_repo_id)
        print("Last fused: ", last_fused)
        print(f"Selected LoRA: {selected_lora_data['title']} from {lora_repo_id}")

        if last_lora != lora_repo_id:
            # Unload any previously fused LoRA from the main pipe first to prevent conflicts
            if last_fused:
                with calculateDuration("Unloading previous LoRAs"):
                    pipe.unfuse_lora()
                    print("Previous LoRAs unloaded if any.")
            with calculateDuration(f"Loading LoRA weights for {selected_lora_data['title']}"):
                pipe.load_lora_weights(
                    lora_repo_id,
                    weight_name=lora_weights_name,
                    adapter_name="active_lora"
                )
                # pipe.set_adapters(["active_lora"], adapter_weights=[lora_scale_value])
                # Fuse the adapter into the base weights at the requested scale, then drop the
                # adapter itself so only the fused weights stay in memory.
                pipe.fuse_lora(adapter_names=["active_lora"], lora_scale=lora_scale_value)
                pipe.unload_lora_weights()
                print(f"LoRA loaded into main pipe with scale {lora_scale_value}")
            last_fused = True
            last_lora = lora_repo_id

        if lora_trigger:
            print(f"Applying trigger word: {lora_trigger}")
            if selected_lora_data.get("trigger_position") == "prepend":
                effective_prompt = f"{lora_trigger} {prompt}"
            else:  # Default to append or if not specified
                effective_prompt = f"{prompt} {lora_trigger}"
    else:
        print("No LoRA selected or invalid index.")
    # --- End LoRA Handling ---
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    target_frames_ideal = duration_input * FPS
    target_frames_rounded = round(target_frames_ideal)
    if target_frames_rounded < 1:
        target_frames_rounded = 1
    n_val = round((float(target_frames_rounded) - 1.0) / 8.0)
    actual_num_frames = int(n_val * 8 + 1)
    actual_num_frames = max(9, actual_num_frames)
    num_frames = min(MAX_NUM_FRAMES, actual_num_frames)  # This num_frames is used by the pipe
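    # Frame-count example: the pipe expects num_frames of the form 8*n + 1, clamped to
    # [9, MAX_NUM_FRAMES]. A 2 s request at 30 fps targets 60 frames, which rounds to
    # n = round(59 / 8) = 7, i.e. 57 frames (~1.9 s of output).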
if mode == "video-to-video" and (video is not None): | |
loaded_video_frames = load_video(video)[:frames_to_use] | |
condition_input_video = True | |
width, height = loaded_video_frames[0].size | |
# steps = 4 # This was hardcoded, let user control steps | |
elif mode == "image-to-video" and (image is not None): | |
loaded_video_frames = [load_image(image)] | |
width, height = loaded_video_frames[0].size | |
condition_input_video = True | |
else: # text-to-video | |
condition_input_video=False | |
loaded_video_frames = None # No video frames for pure t2v | |
if condition_input_video and loaded_video_frames: | |
condition1 = LTXVideoCondition(video=loaded_video_frames, frame_index=0) | |
else: | |
condition1 = None | |
expected_height, expected_width = height, width | |
downscale_factor = 2 / 3 | |
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor) | |
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width) | |
#timesteps_first_pass = [1000, 993, 987, 981, 975, 909, 725] | |
#timesteps_second_pass = [1000, 909, 725, 421] | |
#if steps == 8: | |
#timesteps_first_pass = [1000, 993, 987, 981, 975, 909, 725, 0.03] | |
# timesteps_second_pass = [1000, 909, 725, 421, 0] | |
# elif 7 < steps < 8: # Non-integer steps could be an issue for these pre-defined timesteps | |
#timesteps_first_pass = None | |
# timesteps_second_pass = None | |
with calculateDuration("video generation"): | |
latents = pipe( | |
conditions=condition1, | |
prompt=effective_prompt, # Use prompt with trigger word | |
negative_prompt=negative_prompt, | |
width=downscaled_width, | |
height=downscaled_height, | |
num_frames=num_frames, | |
num_inference_steps=steps, | |
decode_timestep=0.05, | |
decode_noise_scale=0.025, | |
#timesteps=timesteps_first_pass, | |
image_cond_noise_scale=0.025, | |
guidance_rescale=0.7, | |
guidance_scale=guidance_scale, | |
generator=torch.Generator(device=device).manual_seed(seed), | |
output_type="latent", | |
).frames | |
    final_video_frames_np = None  # Initialize
    if improve_texture:
        # Internal working dimensions for the second pass, not the user-facing width/height
        upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
        with calculateDuration("Latent upscaling"):
            upscaled_latents = pipe_upsample(
                latents=latents,
                adain_factor=1.0,
                output_type="latent"
            ).frames
        with calculateDuration("Denoising upscaled video"):
            final_video_frames_np = pipe(  # Using the main pipe for the refinement pass
                conditions=condition1,  # Re-pass condition if applicable
                prompt=effective_prompt,
                negative_prompt=negative_prompt,
                width=upscaled_width,  # Use upscaled dimensions for this pass
                height=upscaled_height,
                num_frames=num_frames,
                guidance_scale=guidance_scale,
                denoise_strength=0.4,
                # timesteps=timesteps_second_pass,
                num_inference_steps=10,  # Or make this configurable
                latents=upscaled_latents,
                decode_timestep=0.05,
                decode_noise_scale=0.025,
                image_cond_noise_scale=0.025,
                guidance_rescale=0.7,
                generator=torch.Generator(device=device).manual_seed(seed),
                output_type="np",
            ).frames[0]
    else:  # No texture improvement: just upscale the latents and decode
        with calculateDuration("Latent upscaling and decoding (no improve_texture)"):
            final_video_frames_np = pipe_upsample(
                latents=latents,
                output_type="np"  # Decode directly
            ).frames[0]
    # Video saving
    video_uint8_frames = [(frame * 255).astype(np.uint8) for frame in final_video_frames_np]
    output_filename = "output.mp4"
    with calculateDuration("Saving video to mp4"):
        with imageio.get_writer(output_filename, fps=FPS, quality=8, macro_block_size=1) as writer:  # Removed bitrate=None
            for frame_idx, frame_data in enumerate(video_uint8_frames):
                progress((frame_idx + 1) / len(video_uint8_frames), desc="Encoding video frames...")
                writer.append_data(frame_data)

    return output_filename, seed  # Return the seed for display

# --- Gradio UI ---
css = """
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.5em}
#col-container { margin: 0 auto; max-width: 1000px; } /* Increased max-width for gallery */
#gallery .grid-wrap{height: 20vh !important; max-height: 250px !important;}
.custom_lora_card { border: 1px solid #e0e0e0; border-radius: 8px; padding: 10px; margin-top: 10px; background-color: #f9f9f9; }
.card_internal { display: flex; align-items: center; }
.card_internal img { margin-right: 1em; border-radius: 4px; }
.card_internal div h4 { margin-bottom: 0.2em; }
.card_internal div small { font-size: 0.9em; color: #555; }
#lora_list_link { font-size: 90%; background: var(--block-background-fill); padding: 0.5em 1em; border-radius: 8px; display:inline-block; margin-top:10px;}
"""

with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend Deca"), "sans-serif"])) as demo:
    # gr.Markdown("# LTX Video 0.9.7 Distilled with LoRA Explorer")
    # gr.Markdown("Fast high quality video generation with custom LoRA support. [Model](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-distilled.safetensors) [GitHub](https://github.com/Lightricks/LTX-Video)")
    title = gr.HTML(
        """<h1><img src="https://huggingface.co/spaces/linoyts/LTXV-lora-the-explorer/resolve/main/Group%20588.png" alt="LoRA"> LTX Video LoRA the Explorer</h1>""",
        elem_id="title",
    )
    gr.Markdown("[🧨diffusers implementation of LTX Video 0.9.7 Distilled](https://huggingface.co/Lightricks/LTX-Video-0.9.7-distilled) with community-trained LoRAs 🤗")
    selected_lora_index_state = gr.State(None)
    with gr.Row():
        with gr.Column(scale=1):  # Main controls
            with gr.Tab("image-to-video") as image_tab:
                with gr.Group():
                    video_i_hidden = gr.Textbox(label="video_i", visible=False, value=None)
                    image_i2v = gr.Image(label="Input Image", type="filepath", sources=["upload", "clipboard"])  # Removed webcam
                    i2v_prompt = gr.Textbox(label="Prompt", value="", lines=3)
                    i2v_button = gr.Button("Generate Image-to-Video", variant="primary")
            with gr.Tab("text-to-video") as text_tab:
                with gr.Group():
                    image_n_hidden = gr.Textbox(label="image_n", visible=False, value=None)
                    video_n_hidden = gr.Textbox(label="video_n", visible=False, value=None)
                    t2v_prompt = gr.Textbox(label="Prompt", value="a playful penguin", lines=3)
                    t2v_button = gr.Button("Generate Text-to-Video", variant="primary")
            with gr.Tab("video-to-video", visible=False) as video_tab:
                with gr.Group():
                    image_v_hidden = gr.Textbox(label="image_v", visible=False, value=None)
                    video_v2v = gr.Video(label="Input Video")
                    frames_to_use_slider = gr.Slider(label="Frames to use from input video", minimum=9, maximum=MAX_NUM_FRAMES, value=9, step=8, info="Number of initial frames for conditioning. Must be N*8+1.")
                    v2v_prompt = gr.Textbox(label="Prompt", value="Change the style to cinematic anime", lines=3)
                    v2v_button = gr.Button("Generate Video-to-Video", variant="primary")

            # duration_slider = gr.Slider(
            #     label="Video Duration (seconds)", minimum=0.3, maximum=8.5, value=2, step=0.1,
            #     info="Target video duration (0.3s to 8.5s). Actual frames depend on model constraints (multiple of 8 + 1)."
            # )
            # improve_texture_checkbox = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, info="Uses a two-pass generation for better quality, but is slower.")
        with gr.Column(scale=1):  # LoRA Gallery and Output
            selected_lora_info_markdown = gr.Markdown("No LoRA selected.")
            lora_gallery_display = gr.Gallery(
                # Gallery value is a list of (image_url, title) tuples
                value=[(item.get("image", "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"), item["title"]) for item in loras] if loras else [],
                label="Pick a LoRA",
                allow_preview=False,
                columns=3, height="auto",
                elem_id="gallery"
            )
            with gr.Group():
                custom_lora_input_path = gr.Textbox(label="Add Custom LoRA from Hugging Face", info="Path like 'username/repo-name'", placeholder="e.g., ", visible=False)
                # gr.Markdown("[Find LTX-compatible LoRAs on Hugging Face](https://huggingface.co/models?other=base_model:Lightricks/LTX-Video-0.9.7-distilled&sort=trending)", elem_id="lora_list_link")
                custom_lora_status_html = gr.HTML(visible=False)  # For displaying the custom LoRA card
                remove_custom_lora_button = gr.Button("Remove Last Added Custom LoRA", visible=False)
        with gr.Column(scale=1):
            output_video = gr.Video(label="Generated Video", interactive=False)
            duration_slider = gr.Slider(
                label="Video Duration (seconds)", minimum=0.3, maximum=8.5, value=2, step=0.1,
                info="Target video duration (0.3s to 8.5s). Actual frames depend on model constraints (multiple of 8 + 1)."
            )
            improve_texture_checkbox = gr.Checkbox(label="Improve Texture (multi-scale)", value=True, info="Uses a two-pass generation for better quality, but is slower.")
            # gr.DeepLinkButton()
# gr.DeepLinkButton() | |
with gr.Accordion("Advanced settings", open=False): | |
with gr.Row(): | |
lora_scale_slider = gr.Slider(label="LoRA Scale", minimum=0.0, maximum=3, step=0.05, value=1.5, info="Adjusts the influence of the selected LoRA.") | |
mode_dropdown = gr.Dropdown(["text-to-video", "image-to-video", "video-to-video"], label="Task Mode", value="image-to-video", visible=False) # Keep internal | |
negative_prompt = gr.Textbox(label="Negative Prompt", value="worst quality, inconsistent motion, blurry, jittery, distorted", lines=2) | |
with gr.Row(): | |
seed_number_input = gr.Number(label="Seed", value=0, precision=0) | |
randomize_seed_checkbox = gr.Checkbox(label="Randomize Seed", value=True) | |
with gr.Row(): | |
guidance_scale_slider = gr.Slider(label="Guidance Scale (CFG)", minimum=0, maximum=10, value=5.0, step=0.1) # LTX uses low CFG | |
steps_slider = gr.Slider(label="Inference Steps (Main Pass)", minimum=1, maximum=30, value=25, step=1) # Default steps for LTX | |
# num_frames_slider = gr.Slider(label="# Frames (Debug - Overridden by Duration)", minimum=9, maximum=MAX_NUM_FRAMES, value=96, step=8, visible=False) # Hidden, as duration controls it | |
with gr.Row(): | |
height_slider = gr.Slider(label="Target Height", value=512, step=pipe.vae_spatial_compression_ratio, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info=f"Must be divisible by {pipe.vae_spatial_compression_ratio}.") | |
width_slider = gr.Slider(label="Target Width", value=704, step=pipe.vae_spatial_compression_ratio, minimum=MIN_DIM_SLIDER, maximum=MAX_IMAGE_SIZE, info=f"Must be divisible by {pipe.vae_spatial_compression_ratio}.") | |

    # --- Event Handlers ---
    image_i2v.upload(fn=handle_image_upload_for_dims, inputs=[image_i2v, height_slider, width_slider], outputs=[height_slider, width_slider])
    video_v2v.upload(fn=handle_video_upload_for_dims, inputs=[video_v2v, height_slider, width_slider], outputs=[height_slider, width_slider])
    video_v2v.clear(lambda cur_h, cur_w: (gr.update(value=cur_h), gr.update(value=cur_w)), inputs=[height_slider, width_slider], outputs=[height_slider, width_slider])
    image_i2v.clear(lambda cur_h, cur_w: (gr.update(value=cur_h), gr.update(value=cur_w)), inputs=[height_slider, width_slider], outputs=[height_slider, width_slider])

    image_tab.select(fn=update_task_image, outputs=[mode_dropdown])
    text_tab.select(fn=update_task_text, outputs=[mode_dropdown])
    video_tab.select(fn=update_task_video, outputs=[mode_dropdown])

    # LoRA Gallery Callbacks
    lora_gallery_display.select(
        update_lora_selection,
        outputs=[selected_lora_info_markdown, selected_lora_index_state]
    )
    custom_lora_input_path.submit(
        add_custom_lora_for_ltx,
        inputs=[custom_lora_input_path],
        outputs=[custom_lora_status_html, remove_custom_lora_button, lora_gallery_display, selected_lora_info_markdown, selected_lora_index_state, custom_lora_input_path]
    )
    remove_custom_lora_button.click(
        remove_custom_lora_for_ltx,
        outputs=[custom_lora_status_html, remove_custom_lora_button, lora_gallery_display, selected_lora_info_markdown, selected_lora_index_state, custom_lora_input_path]
    )

    # Consolidate inputs for the generate function
    gen_inputs = [
        height_slider, width_slider, mode_dropdown, steps_slider,
        gr.Number(value=96, visible=False),  # Placeholder for num_frames_slider_val; the actual count is derived from duration
        frames_to_use_slider,
        seed_number_input, randomize_seed_checkbox, guidance_scale_slider, duration_slider, improve_texture_checkbox,
        selected_lora_index_state, lora_scale_slider
    ]

    t2v_button.click(fn=generate,
                     inputs=[t2v_prompt, negative_prompt, image_n_hidden, video_n_hidden] + gen_inputs,
                     outputs=[output_video, seed_number_input])  # Seed is echoed back for display
    i2v_button.click(fn=generate,
                     inputs=[i2v_prompt, negative_prompt, image_i2v, video_i_hidden] + gen_inputs,
                     outputs=[output_video, seed_number_input])
    v2v_button.click(fn=generate,
                     inputs=[v2v_prompt, negative_prompt, image_v_hidden, video_v2v] + gen_inputs,
                     outputs=[output_video, seed_number_input])

demo.queue(max_size=10).launch()