# Hugging Face Space: prompt-to-MIDI generator (Gradio MCP app)
import json
import os
import random
import re
import subprocess
import tempfile

import gradio as gr
import pretty_midi
from openai import OpenAI
# === LLM APIs ===
def query_llm(prompt, model_name=None):
    """Send *prompt* to a local Ollama model or the Nebius OpenAI-compatible API.

    Args:
        prompt: Full text prompt for the model.
        model_name: An Ollama model name to route to the local server, or
            "OpenAI"/None to use the hosted Nebius endpoint.

    Returns:
        The model's text response (empty string if Ollama returns none).

    Raises:
        requests.HTTPError: if the Ollama server responds with an error status.
    """
    if model_name and model_name != "OpenAI":
        import requests
        response = requests.post(
            "http://localhost:11434/api/generate",
            json={"model": model_name, "prompt": prompt, "stream": False},
            timeout=120,  # don't hang the UI forever on a stuck local server
        )
        # Surface HTTP failures instead of silently returning "" from error JSON.
        response.raise_for_status()
        return response.json().get("response", "")
    else:
        client = OpenAI(
            base_url="https://api.studio.nebius.com/v1/",
            api_key=os.environ.get("NEBIUS_API_KEY"),
        )
        response = client.chat.completions.create(
            model="Qwen/Qwen3-30B-A3B",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
        )
        return response.choices[0].message.content
# === Step 1: Parse intent ===
def get_intent_from_prompt(prompt, model_name):
    """Ask the LLM to extract structured musical intent from a free-text prompt.

    Args:
        prompt: User's natural-language music description.
        model_name: LLM selector passed through to query_llm.

    Returns:
        Dict with keys tempo, key, scale, genre, emotion, instrument.
        Falls back to neutral defaults when the response has no parseable JSON.
    """
    # Single fallback shared by both failure paths (previously duplicated).
    default_intent = {
        "tempo": 120,
        "key": "C",
        "scale": "major",
        "genre": "default",
        "emotion": "neutral",
        "instrument": "piano",
    }
    system_prompt = f"""
Extract the musical intent from this prompt.
Return JSON with keys: tempo (int), key (A-G#), scale (major/minor), genre (e.g., lo-fi, trap), emotion, instrument.
Prompt: '{prompt}'
"""
    response = query_llm(system_prompt, model_name)
    # Models often wrap JSON in prose; grab the outermost {...} span.
    match = re.search(r'\{.*\}', response, re.DOTALL)
    if match:
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            pass
    return default_intent
# === Step 2: Melody planning ===
def get_melody_from_intent(intent, model_name):
    """Ask the LLM for a melody plan derived from *intent*.

    Args:
        intent: Musical-intent dict (tempo, key, scale, ...).
        model_name: LLM selector passed through to query_llm.

    Returns:
        List of {"note", "octave", "duration"} dicts; a fixed fallback
        melody when no valid JSON array is found in the response.
    """
    melody_prompt = f"""
You are a music composer.
Based on this musical intent:
{json.dumps(intent)}
Generate a melody plan using a list of 16 notes with pitch (A-G#), octave (3-6), and duration (0.25 to 1.0 seconds).
Output ONLY valid JSON like:
[
{{"note": "D", "octave": 4, "duration": 0.5}},
{{"note": "F", "octave": 4, "duration": 1.0}}
]
"""
    response = query_llm(melody_prompt, model_name)
    print(f"\n[DEBUG] LLM Response for melody:\n{response}\n")
    # Try each bracketed JSON-array candidate independently: previously a
    # single malformed match raised out of the loop and discarded any valid
    # candidates that came after it.
    json_strs = re.findall(r'\[\s*\{[^]]+\}\s*\]', response, re.DOTALL)
    for js in json_strs:
        try:
            parsed = json.loads(js)
        except json.JSONDecodeError as e:
            print(f"[ERROR] Melody JSON decode error: {e}")
            continue
        if isinstance(parsed, list) and all(
            "note" in note and "octave" in note and "duration" in note
            for note in parsed
        ):
            return parsed
    print("[WARNING] Using fallback melody.")
    return [
        {"note": "C", "octave": 4, "duration": 0.5},
        {"note": "E", "octave": 4, "duration": 0.5},
        {"note": "G", "octave": 4, "duration": 0.5},
        {"note": "B", "octave": 4, "duration": 0.5},
    ]
# === Step 3: MIDI generation ===
def midi_from_plan(melody, tempo):
    """Render a melody plan into a pretty_midi.PrettyMIDI object.

    Args:
        melody: List of {"note": "C".."B" (optionally sharp), "octave": int,
            "duration": seconds} dicts; malformed entries are skipped.
        tempo: BPM from the parsed intent. NOTE: durations are already in
            seconds, so tempo does not affect timing here; the parameter is
            kept for interface compatibility.

    Returns:
        A PrettyMIDI object containing one program-0 (piano) instrument.
    """
    midi = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=0)  # program 0 = acoustic grand
    note_map = {"C": 0, "C#": 1, "D": 2, "D#": 3, "E": 4, "F": 5, "F#": 6,
                "G": 7, "G#": 8, "A": 9, "A#": 10, "B": 11}
    time = 0.0
    for note_info in melody:
        try:
            # MIDI pitch: C4 -> 60, i.e. 12 * (octave + 1) + semitone offset.
            pitch = 12 * (note_info["octave"] + 1) + note_map[note_info["note"].upper()]
            duration = float(note_info["duration"])
        except (KeyError, TypeError, ValueError):
            # Skip malformed notes only — no longer a bare except that would
            # also swallow KeyboardInterrupt and unrelated bugs.
            continue
        instrument.notes.append(pretty_midi.Note(
            velocity=100, pitch=pitch, start=time, end=time + duration
        ))
        time += duration
    midi.instruments.append(instrument)
    return midi
# === Gradio function ===
def generate_midi_from_prompt(prompt, model_name):
    """End-to-end pipeline: text prompt -> intent -> melody plan -> MIDI file.

    Returns the path of a temp .mid file suitable for a gr.File output.
    """
    parsed_intent = get_intent_from_prompt(prompt, model_name)
    melody_plan = get_melody_from_intent(parsed_intent, model_name)
    rendered = midi_from_plan(melody_plan, parsed_intent.get("tempo", 120))
    # delete=False so Gradio can serve the file after the handle closes.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as tmp_file:
        rendered.write(tmp_file.name)
        output_path = tmp_file.name
    return output_path
# === Get Ollama models ===
def get_ollama_models():
    """List locally installed Ollama models, always prefixed with "OpenAI".

    Returns:
        ["OpenAI", <model>, ...]; just ["OpenAI"] when the `ollama` CLI is
        missing, hangs, or produces unexpected output.
    """
    try:
        result = subprocess.run(
            ["ollama", "list"],
            capture_output=True,
            text=True,
            timeout=10,  # don't block app startup on a hung daemon
        )
        # First line of `ollama list` is a column header; model name is col 1.
        data_lines = result.stdout.strip().splitlines()[1:]
        models = [line.split()[0] for line in data_lines if line.split()]
        return ["OpenAI"] + models
    except Exception:
        # Any failure (CLI absent, timeout, parse error) degrades to hosted-only.
        return ["OpenAI"]
# === Gradio UI ===
# Populate the model picker once at startup; "OpenAI" is always first.
models = get_ollama_models()

prompt_box = gr.Textbox(label="Music Prompt")
model_picker = gr.Dropdown(choices=models, label="LLM Model", value=models[0])
midi_download = gr.File(label="๐ต Download MIDI File")

demo = gr.Interface(
    fn=generate_midi_from_prompt,
    inputs=[prompt_box, model_picker],
    outputs=[midi_download],
    title="๐ผ Music Command Prompt (MCP Agent)",
    description="Describe your music idea and download a generated MIDI file. Choose from local or Nebius/OpenAI LLMs.",
)

# mcp_server=True also exposes the app as an MCP tool endpoint.
demo.launch(mcp_server=True)