# Source: Hugging Face Space "videogenerator", file app.py
# (page header: uploader "englissi"; commit "Update app.py", d0237b3, verified)
import os, tempfile
import numpy as np
import torch
import gradio as gr
from diffusers import LTXPipeline, AutoModel
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video
# --------------------------------------------
# Required packages (Spaces):
# requirements.txt:
# torch>=2.2
# torchvision>=0.17
# accelerate>=0.28.0
# transformers>=4.40.0
# diffusers>=0.31.0
# safetensors>=0.4.2
# sentencepiece>=0.2.0
# gradio>=4.32.0
# imageio>=2.34.0
# imageio-ffmpeg>=0.4.9
# packages.txt:
# ffmpeg
# --------------------------------------------
def load_pipeline():
    """Load the Lightricks/LTX-Video pipeline and enable optional memory savers.

    On CUDA the weights are loaded in bfloat16 (``bf16`` variant) and two
    best-effort optimisations are attempted: FP8 layerwise casting of the
    transformer and group offloading of the transformer, text encoder and
    VAE.  On CPU the weights are loaded in float32 and neither optimisation
    is attempted.  A failure of either optimisation is logged and reported
    through the returned flags instead of aborting start-up.

    Returns:
        tuple: ``(pipe, fp8_ok, offload_ok, device)`` where ``pipe`` is the
        ``LTXPipeline``, ``fp8_ok`` / ``offload_ok`` tell whether the
        respective optimisation was enabled without error, and ``device``
        is ``"cuda"`` or ``"cpu"``.
    """
    import logging  # local import: not part of the file's import block

    log = logging.getLogger(__name__)

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # CPU cannot use float16/float8 here, so fall back to float32.
    dtype = torch.bfloat16 if use_cuda else torch.float32
    # dtype is bfloat16 exactly when CUDA is used, so the variant only
    # depends on use_cuda.
    variant = "bf16" if use_cuda else None

    transformer = AutoModel.from_pretrained(
        "Lightricks/LTX-Video",
        subfolder="transformer",
        torch_dtype=dtype,
        # LTXPipeline ignores trust_remote_code, but passing it is harmless.
        trust_remote_code=True,
        variant=variant,
    )

    # Try FP8 storage only where it can work (CUDA).
    fp8_ok = False
    if use_cuda:
        try:
            transformer.enable_layerwise_casting(
                storage_dtype=torch.float8_e4m3fn, compute_dtype=dtype
            )
        except Exception:
            # Best-effort: keep running without FP8, but record why.
            log.warning(
                "FP8 layerwise casting could not be enabled", exc_info=True
            )
        else:
            fp8_ok = True

    # NOTE(review): the whole pipeline is moved to `device` before group
    # offloading is configured below — confirm this does not defeat the
    # memory saving that offloading is meant to provide.
    pipe = LTXPipeline.from_pretrained(
        "Lightricks/LTX-Video",
        transformer=transformer,
        torch_dtype=dtype,
        trust_remote_code=True,
        variant=variant,
    ).to(device)

    offload_ok = False
    if use_cuda:
        try:
            onload_device = torch.device(device)
            offload_device = torch.device("cpu")
            pipe.transformer.enable_group_offload(
                onload_device=onload_device,
                offload_device=offload_device,
                offload_type="leaf_level",
                use_stream=True,
            )
            apply_group_offloading(
                pipe.text_encoder,
                onload_device=onload_device,
                offload_type="block_level",
                num_blocks_per_group=2,
            )
            apply_group_offloading(
                pipe.vae,
                onload_device=onload_device,
                offload_type="leaf_level",
            )
        except Exception:
            # Best-effort: a failure part-way may leave some components
            # offloaded and others not; the flag reports "not fully enabled".
            log.warning("Group offloading could not be enabled", exc_info=True)
        else:
            offload_ok = True

    return pipe, fp8_ok, offload_ok, device
# Build the pipeline once at import time. FP8_OK / OFFLOAD_OK record which
# optional optimisations load_pipeline() enabled; DEVICE is the device string
# ("cuda" or "cpu") it selected.
PIPE, FP8_OK, OFFLOAD_OK, DEVICE = load_pipeline()
def _to_uint8_frames(frames):
    """Convert video frames to a ``(T, H, W, C)`` ``uint8`` NumPy array.

    Args:
        frames: One of
            * a ``torch.Tensor`` (moved to CPU; bfloat16 is widened to
              float32 because NumPy has no bfloat16 dtype),
            * a NumPy array,
            * a non-empty list/tuple of per-frame items, each a tensor or
              anything ``np.asarray`` accepts (e.g. arrays, PIL images);
              the items are stacked along a new leading time axis.
            A 3-D result ``(T, H, W)`` gets a trailing channel axis.

    Returns:
        numpy.ndarray: 4-D ``uint8`` array.  Non-``uint8`` input whose
        maximum is ``<= 1.0`` is treated as normalised and scaled by 255;
        otherwise values are clipped to ``[0, 255]``.  Fractions are
        truncated, not rounded.

    Raises:
        ValueError: If the sequence is empty or the data is not 3-D/4-D.
    """

    def _as_array(item):
        # Bring a single tensor / array-like onto the CPU as a NumPy array.
        if isinstance(item, torch.Tensor):
            item = item.detach().to("cpu")
            if item.dtype == torch.bfloat16:
                item = item.float()  # .numpy() rejects bfloat16
            return item.numpy()
        return np.asarray(item)

    if isinstance(frames, (list, tuple)):
        if not frames:
            raise ValueError("No frames to convert")
        frames = np.stack([_as_array(frame) for frame in frames])
    else:
        frames = _as_array(frames)

    if frames.ndim == 3:  # (T, H, W) -> (T, H, W, 1)
        frames = frames[..., None]
    if frames.ndim != 4:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError(f"Unexpected frames shape: {frames.shape}")

    if frames.dtype != np.uint8:
        peak = float(frames.max()) if frames.size else 1.0
        if peak <= 1.0:
            frames = (np.clip(frames, 0, 1) * 255).astype(np.uint8)
        else:
            frames = np.clip(frames, 0, 255).astype(np.uint8)
    return frames
def generate_video(
    prompt, negative_prompt,
    width, height, num_frames, fps,
    decode_timestep, decode_noise_scale,
    steps, seed
):
    """Generate a video from a text prompt and save it as an MP4 file.

    Args:
        prompt: Text prompt; ``None`` is treated as empty, whitespace is
            stripped.
        negative_prompt: Negative prompt; empty/blank is passed as ``None``.
        width, height: Output size, cast to ``int``.
        num_frames: Number of frames requested, cast to ``int``.
        fps: Frame rate used only when writing the file.
        decode_timestep, decode_noise_scale: Cast to ``float`` and forwarded
            to the pipeline.
        steps: Number of inference steps, cast to ``int``.
        seed: A non-negative integer fixes the random generator; a negative
            or unusable value leaves generation unseeded.

    Returns:
        tuple: ``(save_path, info)`` — path of the written ``output.mp4``
        (inside a fresh temporary directory) and a one-line status string.
    """
    import logging  # local import: not part of the file's import block

    log = logging.getLogger(__name__)

    # Seed: only a non-negative integer creates a seeded generator.
    generator = None
    try:
        seed_value = int(seed)
        if seed_value >= 0:
            generator = torch.Generator(device=DEVICE).manual_seed(seed_value)
    except (TypeError, ValueError, OverflowError, RuntimeError):
        # Best-effort: an unusable seed must not block generation.
        log.warning("Ignoring unusable seed %r", seed, exc_info=True)

    # -------- Inference --------
    with torch.inference_mode():
        out = PIPE(
            prompt=(prompt or "").strip(),
            negative_prompt=(negative_prompt or "").strip() or None,
            width=int(width),
            height=int(height),
            num_frames=int(num_frames),
            # LTXPipeline takes no `fps` argument; fps is applied on save only.
            decode_timestep=float(decode_timestep),
            decode_noise_scale=float(decode_noise_scale),
            num_inference_steps=int(steps),
            generator=generator,
            # Ask for array output so the frames have a known array shape
            # (per diffusers docs the default is a list of PIL images and
            # "np" yields floats in [0, 1] — TODO confirm for the installed
            # diffusers version).
            output_type="np",
        )
    frames = np.asarray(out.frames[0])

    # -------- Save --------
    # NOTE: the temporary directory is not removed here; the returned path
    # must stay valid for the caller.
    tmpdir = tempfile.mkdtemp()
    save_path = os.path.join(tmpdir, "output.mp4")
    target_fps = int(fps)

    # Preferred: the diffusers saver.
    try:
        # export_to_video is given the float frames unchanged: per diffusers
        # it multiplies ndarray frames by 255 itself, so pre-converted uint8
        # input would be scaled twice — TODO confirm against installed version.
        export_to_video(list(frames), save_path, fps=target_fps)
    except Exception:
        log.warning(
            "export_to_video failed; falling back to imageio", exc_info=True
        )
        # Fallback: imageio-ffmpeg, which is handed uint8 frames.
        import imageio.v3 as iio
        iio.imwrite(
            save_path,
            _to_uint8_frames(frames),
            fps=target_fps,
            codec="libx264",
        )

    info = (
        f"FP8: {'ON' if FP8_OK else 'OFF'} | "
        f"Offloading: {'ON' if OFFLOAD_OK else 'OFF'} | "
        f"Device: {DEVICE} | "
        f"Frames: {frames.shape} | FPS: {target_fps}"
    )
    return save_path, info
# ----------------------------- Gradio UI -----------------------------
# Gradio front end: collects the generation parameters, calls
# generate_video() on click, and shows the resulting video plus a status line.
# NOTE(review): the grouping of widgets into rows below is a best guess —
# confirm against the intended layout.
with gr.Blocks(title="LTX-Video — Prompt to Short Video") as demo:
    gr.Markdown("## 🎬 LTX-Video — Prompt to Short Video")

    with gr.Row():
        prompt_in = gr.Textbox(
            label="Prompt",
            lines=6,
            value="A cinematic close-up of a smiling woman under warm sunset light.",
        )
        neg_in = gr.Textbox(
            label="Negative Prompt",
            lines=4,
            value="worst quality, inconsistent motion, blurry, jittery, distorted",
        )

    with gr.Row():
        # Step 32: per the diffusers LTXPipeline source, height/width that are
        # not divisible by 32 are rejected with a ValueError (confirm for the
        # installed version). 256, 512, 768 and 1024 are all multiples of 32.
        width_in = gr.Slider(256, 1024, value=768, step=32, label="Width")
        height_in = gr.Slider(256, 1024, value=512, step=32, label="Height")

    with gr.Row():
        # Step 8 from 17 yields 17, 25, ..., 241, i.e. counts of the form
        # 8*k + 1, which is what LTX-Video documents as supported frame
        # counts — TODO confirm.
        frames_in = gr.Slider(17, 241, value=65, step=8, label="num_frames")
        fps_in = gr.Slider(8, 30, value=24, step=1, label="FPS (save only)")

    with gr.Row():
        dt_in = gr.Slider(0.0, 0.2, value=0.03, step=0.001, label="decode_timestep")
        dns_in = gr.Slider(0.0, 0.2, value=0.025, step=0.001, label="decode_noise_scale")
        steps_in = gr.Slider(10, 75, value=40, step=1, label="num_inference_steps")
        seed_in = gr.Number(value=-1, label="Seed (>=0 to fix)")

    gen_btn = gr.Button("🎥 Generate", variant="primary")
    video_out = gr.Video(label="Output", autoplay=True)
    info_out = gr.Markdown()

    # Input order must match generate_video's positional parameters.
    gen_btn.click(
        fn=generate_video,
        inputs=[
            prompt_in, neg_in, width_in, height_in, frames_in,
            fps_in, dt_in, dns_in, steps_in, seed_in,
        ],
        outputs=[video_out, info_out],
    )

demo.queue().launch()