Spaces:
Runtime error
Runtime error
# MusicGen + Gradio + GPT Demo App (CPU-Optimized with MCP Server)
import gradio as gr | |
import os | |
import numpy as np | |
import torch | |
from transformers import AutoProcessor, MusicgenForConditionalGeneration | |
from openai import OpenAI | |
import scipy.io.wavfile | |
# Inference is pinned to the CPU, so the app does not require a GPU.
device = torch.device("cpu")

# A single checkpoint name feeds both the text processor and the MusicGen model.
model_name = "facebook/musicgen-small"
processor = AutoProcessor.from_pretrained(model_name)
model = MusicgenForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

# OpenAI client; the key is read from the OPENAI_API_KEY environment variable
# (set it in the Hugging Face Spaces secrets).
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Refine user prompt via GPT
def refine_prompt(user_input):
    """Ask GPT to expand a short music idea into a richer generation prompt.

    Args:
        user_input: Free-form text describing the desired mood or style.

    Returns:
        The model's rewritten prompt with surrounding whitespace removed. If
        the response carries no usable text (no choices, or a ``content`` that
        is ``None`` or blank), the stripped ``user_input`` is returned instead
        so callers always receive a string.
    """
    completion = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {
                "role": "system",
                "content": "You are a music assistant. Make the user's input more descriptive for an AI music generator.",
            },
            {"role": "user", "content": user_input},
        ],
    )

    # ``message.content`` is nullable in the Chat Completions API, so calling
    # ``.strip()`` on it unguarded can raise AttributeError.
    choices = completion.choices
    content = choices[0].message.content if choices else None
    refined = (content or "").strip()

    # Fall back to the user's own words rather than returning an empty prompt.
    return refined or (user_input or "").strip()
# Generate music (shorter tokens for CPU speed)
def generate_music(prompt, max_new_tokens: int = 128):
    """Generate a short clip with MusicGen and save it as a 16-bit WAV file.

    Args:
        prompt: Text description passed to the MusicGen processor.
        max_new_tokens: Upper bound on generated tokens. Cast to ``int``
            because a UI slider may deliver the value as a float.

    Returns:
        ``(sampling_rate, audio)`` where ``audio`` is a 1-D ``float32`` array
        peak-normalized to the range [-1.0, 1.0]. Silent or empty model output
        is returned as-is (zeros) instead of being divided by a zero peak.

    Side effects:
        Overwrites ``/tmp/output.wav`` with the int16 version of ``audio``.
    """
    inputs = processor(text=[prompt], return_tensors="pt").to(device)
    audio_values = model.generate(**inputs, max_new_tokens=int(max_new_tokens))
    sampling_rate = model.config.audio_encoder.sampling_rate

    audio = audio_values[0].cpu().numpy().astype(np.float32)

    # Collapse to one 1-D track *before* returning: gr.Audio treats a 2-D
    # array as (samples, channels), so a leftover leading channel axis would
    # be read as a single sample with many channels.
    audio = np.atleast_1d(np.squeeze(audio))
    if audio.ndim > 1:
        # NOTE(review): assumes a channels-first layout (channels, samples),
        # so index 0 selects the first channel -- confirm for stereo models.
        audio = audio[0]

    # Peak-normalize, guarding against silent/empty output where the peak is
    # zero and a plain division would fill the array with NaNs.
    peak = float(np.max(np.abs(audio))) if audio.size else 0.0
    if peak > 0.0:
        audio = (audio / peak).astype(np.float32)

    # 16-bit PCM copy for the downloadable file (in /tmp for Spaces).
    int_audio = (audio * 32767).astype(np.int16)
    scipy.io.wavfile.write("/tmp/output.wav", sampling_rate, int_audio)

    return sampling_rate, audio
# Combined Gradio function
def main(user_input, max_new_tokens):
    """Run the full pipeline for one UI request.

    The user's text is first rewritten by ``refine_prompt``; the rewritten
    prompt is then handed to ``generate_music`` together with the token budget.

    Args:
        user_input: Text describing the desired mood or style.
        max_new_tokens: Token budget forwarded to ``generate_music``.

    Returns:
        A 3-tuple of the refined prompt, a ``(sampling_rate, audio)`` pair,
        and the path ``"/tmp/output.wav"``.
    """
    enhanced = refine_prompt(user_input)
    rate, waveform = generate_music(enhanced, max_new_tokens)
    playback = (rate, waveform)
    return enhanced, playback, "/tmp/output.wav"
# Build Gradio UI
with gr.Blocks() as demo:
    # Heading emoji restored: the original text contained the mojibake of a
    # mis-decoded musical-note character.
    gr.Markdown(
        "# 🎵 AI Music Generator\n"
        "Enter a music idea or mood and get a short AI-generated track. (CPU mode)"
    )

    # Inputs
    user_input = gr.Textbox(label="Describe the mood or style of music")
    max_tokens = gr.Slider(32, 256, value=128, step=32, label="Length (tokens) for CPU")
    generate_btn = gr.Button("Generate Music")

    # Outputs, in the same order as the tuple returned by ``main``
    refined_output = gr.Textbox(label="Enhanced Prompt by GPT")
    audio_output = gr.Audio(label="Generated Audio", type="numpy")
    download_wav = gr.File(label="Download .wav file")

    generate_btn.click(
        main,
        inputs=[user_input, max_tokens],
        outputs=[refined_output, audio_output, download_wav],
    )
# Launch the app and expose it as an MCP server
if __name__ == "__main__":
    # ``gradio.mcp_server`` / ``MCPServer`` is not part of Gradio's public API,
    # so importing it fails at startup. MCP support is enabled through
    # ``launch(mcp_server=True)``, which needs the ``gradio[mcp]`` extra.
    # Host and port are the same values the original passed to MCPServer.
    demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)