Spaces:
Runtime error
Runtime error
File size: 2,074 Bytes
56798b8 2cb7990 b66115a 56798b8 1a39554 b66115a 56798b8 1a39554 36a9fc1 1a39554 2cb7990 56798b8 88a21ad 56798b8 1a39554 56798b8 1a39554 a32d6c4 88a21ad 1a39554 2cb7990 88a21ad 2cb7990 1a39554 2cb7990 88a21ad 1a39554 88a21ad 1a39554 88a21ad 2cb7990 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# Third-party dependencies, gathered in one place.
import torch
import spaces
import gradio as gr
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from diffusers import DiffusionPipeline

# Checkpoint identifiers passed to the ``from_pretrained`` loaders below.
MUSICGEN_CHECKPOINT = "facebook/musicgen-small"
STABLE_DIFFUSION_CHECKPOINT = "runwayml/stable-diffusion-v1-5"
# Alternative image checkpoint that was tried earlier:
#   "stabilityai/stable-diffusion-xl-base-1.0"

# Text-to-music model, loaded once at import time.
music_gen_model = MusicgenForConditionalGeneration.from_pretrained(MUSICGEN_CHECKPOINT)
# Sample rate of the generated audio, read from the model's audio encoder config.
sampling_rate = music_gen_model.config.audio_encoder.sampling_rate
# Processor that turns text prompts into model inputs for MusicGen.
processor = AutoProcessor.from_pretrained(MUSICGEN_CHECKPOINT)

# Text-to-image pipeline, requested in half precision from the "fp16"
# safetensors variant of the checkpoint.
sd_pipe = DiffusionPipeline.from_pretrained(
    STABLE_DIFFUSION_CHECKPOINT,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
@spaces.GPU
def generate_music(desc):
    """Generate audio from a text description with the MusicGen model.

    Args:
        desc: Text prompt describing the music to generate.

    Returns:
        A ``(sampling_rate, samples)`` tuple. ``samples`` is the ``[0][0]``
        slice of the generated output converted to a NumPy array
        (presumably first batch item, first channel -- TODO confirm).
    """
    # Run on the GPU when one is visible, otherwise stay on the CPU.
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    music_gen_model.to(target_device)

    # Encode the single prompt as a padded batch of PyTorch tensors and
    # move it next to the model.
    encoded_prompt = processor(text=[desc], padding=True, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(target_device)

    generated = music_gen_model.generate(
        **encoded_prompt,
        do_sample=True,
        guidance_scale=3,
        max_new_tokens=256,
    )

    # Bring the selected waveform back to host memory for the caller.
    waveform = generated[0][0]
    return sampling_rate, waveform.cpu().numpy()
@spaces.GPU
def generate_pic(desc):
    """Generate one image from a text description with Stable Diffusion.

    The pipeline runs on CUDA when it is available. Without CUDA it falls
    back to the CPU instead of raising, matching the device selection used
    by the other handlers in this file.

    Args:
        desc: Text prompt describing the image to generate.

    Returns:
        The first image produced by the pipeline for ``desc``.
    """
    if torch.cuda.is_available():
        sd_pipe.to("cuda")
    else:
        # The pipeline is loaded with torch_dtype=float16; half precision is
        # generally not usable for CPU inference, so upcast on the fallback.
        sd_pipe.to("cpu", torch.float32)
    return sd_pipe(prompt=desc).images[0]
@spaces.GPU
def test_gpu():
    """Report the torch device name usable from inside a GPU-decorated call.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
# Gradio UI: one text prompt drives image generation, music generation and
# a small GPU availability probe.
# NOTE(review): the original indentation was lost, so which components sit
# inside gr.Row() is an assumption -- confirm against the intended layout.
with gr.Blocks() as app:
    with gr.Row():
        # Shared text prompt used by both generators.
        music_desc = gr.TextArea(label="Music Description")
        # Output slot for the generated image.
        music_pic = gr.Image(label="Music Image(StableDiffusion)")
        # Output slot for the generated audio.
        music_player = gr.Audio(label="Play My Tune")
        # Read-only field showing the device name returned by test_gpu.
        device_name = gr.Text(label='device name', interactive=False)
    gen_pic_btn = gr.Button("Gen Picture")
    gen_music_btn = gr.Button("Get Some Tune!!")
    has_gpu_btn = gr.Button("test gpu")
    # Wire each button to its handler; both generators read the same prompt.
    gen_pic_btn.click(fn=generate_pic, inputs=[music_desc], outputs=[music_pic])
    gen_music_btn.click(fn=generate_music, inputs=[music_desc], outputs=[music_player])
    has_gpu_btn.click(fn=test_gpu, outputs=[device_name])
# Launch the Gradio app only when this file is executed as a script.
if __name__ == '__main__':
    app.launch()