File size: 2,074 Bytes
56798b8
2cb7990
b66115a
 
56798b8
1a39554
 
 
b66115a
56798b8
 
 
 
1a39554
 
36a9fc1
 
1a39554
 
2cb7990
56798b8
88a21ad
 
 
56798b8
1a39554
56798b8
 
1a39554
 
a32d6c4
88a21ad
1a39554
2cb7990
88a21ad
 
 
 
 
2cb7990
 
 
1a39554
2cb7990
 
88a21ad
1a39554
 
88a21ad
1a39554
 
 
88a21ad
2cb7990
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import spaces
import gradio as gr
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from diffusers import DiffusionPipeline

# Hugging Face Hub checkpoint IDs. The MusicGen processor and model are loaded
# from one shared constant so they cannot drift apart.
MUSICGEN_MODEL_ID = "facebook/musicgen-small"
SD_MODEL_ID = "runwayml/stable-diffusion-v1-5"
# NOTE: an alternative SD checkpoint previously referenced here, loadable with
# the same arguments below: "stabilityai/stable-diffusion-xl-base-1.0".

# Text-to-music model; loaded once at import time and moved to a device
# inside each GPU-decorated handler.
music_gen_model = MusicgenForConditionalGeneration.from_pretrained(MUSICGEN_MODEL_ID)
# Sample rate of the model's audio encoder; returned alongside every clip.
sampling_rate = music_gen_model.config.audio_encoder.sampling_rate

# Tokenizer/feature processor matching the MusicGen checkpoint above.
processor = AutoProcessor.from_pretrained(MUSICGEN_MODEL_ID)

# Text-to-image pipeline, loaded with half-precision (fp16) weights.
sd_pipe = DiffusionPipeline.from_pretrained(
    SD_MODEL_ID,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)


@spaces.GPU
def generate_music(desc):
    """Generate a short MusicGen clip conditioned on a text description.

    Args:
        desc: Free-text prompt describing the desired music.

    Returns:
        A ``(sampling_rate, waveform)`` tuple, where ``waveform`` is a NumPy
        array taken from the first generated sample; this tuple is what the
        UI feeds into its ``gr.Audio`` output.
    """
    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"
    music_gen_model.to(target)

    # A batch of one prompt, moved to the same device as the model.
    batch = processor(text=[desc], padding=True, return_tensors="pt").to(target)
    generated = music_gen_model.generate(
        **batch,
        do_sample=True,
        guidance_scale=3,
        max_new_tokens=256,
    )
    # [0, 0] picks the first batch item and its first row -- presumably the
    # first audio channel of a (batch, channels, samples) tensor; confirm.
    waveform = generated[0, 0].cpu().numpy()
    return sampling_rate, waveform
    
@spaces.GPU
def generate_pic(desc):
    """Render an image for a music description with Stable Diffusion.

    Args:
        desc: Free-text prompt; the UI passes the same text used for music.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    # Choose the device the same way the sibling handlers do, instead of
    # hard-coding "cuda", which fails outright when no GPU is visible.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The pipeline was loaded in float16; many float16 kernels are not
    # implemented on CPU, so run in float32 there.
    dtype = torch.float16 if device == "cuda" else torch.float32
    sd_pipe.to(device, dtype)
    return sd_pipe(prompt=desc).images[0]

@spaces.GPU
def test_gpu():
    """Return the torch device name visible inside a GPU-decorated call.

    Returns:
        ``"cuda"`` when torch reports an available CUDA device, else ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

# Gradio UI: one shared text prompt drives both the image generator and the
# music generator; a third button reports the device test_gpu() sees.
with gr.Blocks() as app:
    # One horizontal row: prompt input, generated image, generated audio.
    with gr.Row():
        music_desc = gr.TextArea(label="Music Description")
        music_pic = gr.Image(label="Music Image(StableDiffusion)")
        music_player = gr.Audio(label="Play My Tune")

    # Read-only field that shows the string returned by test_gpu().
    device_name = gr.Text(label='device name', interactive=False)
    gen_pic_btn = gr.Button("Gen Picture")
    gen_music_btn = gr.Button("Get Some Tune!!")
    has_gpu_btn = gr.Button("test gpu")

    # Event wiring: both generators read the same description box.
    gen_pic_btn.click(fn=generate_pic, inputs=[music_desc], outputs=[music_pic])
    gen_music_btn.click(fn=generate_music, inputs=[music_desc], outputs=[music_player])
    # test_gpu takes no inputs; it only writes to device_name.
    has_gpu_btn.click(fn=test_gpu, outputs=[device_name])
    
# Start the Gradio server only when this file is executed as a script.
if __name__ == '__main__':
    app.launch()