File size: 1,687 Bytes
474a58d
9e8156a
9812f31
 
febf9d1
 
9812f31
 
 
213805c
9812f31
213805c
1f1ad42
6517794
9812f31
 
 
1f1ad42
9812f31
 
 
 
1f1ad42
9812f31
 
febf9d1
9812f31
febf9d1
9812f31
 
 
 
 
 
 
213805c
1f1ad42
9812f31
 
 
 
1f1ad42
9812f31
 
213805c
 
9812f31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
import imageio
import numpy as np
from PIL import Image
import tempfile
import random
import os

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Stable Diffusion v1.5 weights. The former "runwayml/stable-diffusion-v1-5"
# repository was withdrawn from the Hugging Face Hub; the same weights are
# published under the "stable-diffusion-v1-5" organisation.
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID)
pipe.to(device)
print("Model loaded!")

def generate_image(prompt, num_images=1, *, seed=None, num_inference_steps=25):
    """Generate image(s) for *prompt* with the module-level ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_images: How many images to generate for the prompt. With the
            default of 1 a single image is returned; with a larger value a
            list of ``num_images`` images is returned.
        seed: Optional seed for the torch generator. When ``None`` a fresh
            random seed is drawn, so repeated calls give different images.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        ``pipe(...).images[0]`` when ``num_images == 1``, otherwise the full
        ``pipe(...).images`` list.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")
    if seed is None:
        # Fresh 32-bit seed per call so consecutive calls (e.g. video frames) differ.
        seed = random.randint(0, 2**32 - 1)
    generator = torch.Generator(device=device).manual_seed(seed)
    images = pipe(
        prompt=prompt,
        num_images_per_prompt=num_images,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images
    return images[0] if num_images == 1 else images

def generate_video(prompt, duration=2, fps=5):
    """Render a clip of *duration* seconds for *prompt* and return the MP4 path.

    Each frame is a separate ``generate_image(prompt)`` call, each with its
    own random seed, so consecutive frames are independent images.

    Args:
        prompt: Text prompt used for every frame.
        duration: Clip length in seconds (the UI slider may pass a float).
        fps: Frames per second written to the video file.

    Returns:
        Path to the MP4 file inside a newly created temporary directory. The
        directory is not removed here because the caller still has to read
        the file.
    """
    # round() instead of int() so float error in duration * fps cannot drop a
    # frame; always write at least one frame so the video is never empty.
    frames = max(1, round(duration * fps))
    temp_dir = tempfile.mkdtemp()
    video_path = os.path.join(temp_dir, f"output_{random.randint(10000,99999)}.mp4")

    # The context manager closes the writer even if image generation raises
    # partway through, instead of leaking an open writer.
    with imageio.get_writer(video_path, fps=fps) as writer:
        for _ in range(frames):
            writer.append_data(np.array(generate_image(prompt)))
    return video_path

with gr.Blocks() as demo:
    gr.Markdown("# Simple SD Video Generator")

    # Inputs: the text prompt and the clip length in whole seconds (1-10).
    prompt_input = gr.Textbox(
        label="Prompt",
        value="A majestic dragon flying over a castle",
    )
    duration_input = gr.Slider(
        label="Duration (seconds)",
        minimum=1,
        maximum=10,
        value=2,
        step=1,
    )

    # Output: the MP4 path returned by generate_video.
    output_video = gr.Video(label="Generated Video")

    # Clicking the button calls generate_video(prompt, duration).
    gen_button = gr.Button("Generate Video")
    gen_button.click(
        fn=generate_video,
        inputs=[prompt_input, duration_input],
        outputs=[output_video],
    )

if __name__ == "__main__":
    # Listen on all network interfaces on port 7860 with debug=True.
    # NOTE(review): "0.0.0.0" makes the app reachable from other hosts;
    # presumably intended for a container or remote server — confirm.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
    }
    demo.launch(**launch_kwargs)