Spaces:
Sleeping
Sleeping
""" | |
Adaptive Music Exercise Generator (Strict Duration Enforcement) | |
============================================================== | |
Generates custom musical exercises with LLM, perfectly fit to user-specified number of measures | |
AND time signature, guaranteeing exact durations in MIDI and in the UI! | |
Major updates: | |
- Added Gemma, Kimi Dev 72b, and Llama 3.1 AI model options | |
- Added duration sum display in Exercise Data tab | |
- Shows total duration units (16th notes) for verification | |
- Added DeepSeek AI model option | |
- Fixed difficulty level implementation | |
- Maintained all original functionality | |
""" | |
# ----------------------------------------------------------------------------- | |
# 1. Runtime package installation (for fresh containers, Colab, etc.)
# ----------------------------------------------------------------------------- | |
import sys | |
import subprocess | |
from typing import Dict, Optional, Tuple, List | |
import time | |
import random | |
def install(packages: List[str]):
    """Make sure every package in *packages* can be imported.

    Each entry is used both as the module name passed to ``__import__`` and as
    the distribution name given to ``pip install``. Packages that fail to import
    are installed with the current interpreter's pip, one at a time, in order.
    """
    def _already_available(module_name: str) -> bool:
        try:
            __import__(module_name)
        except ImportError:
            return False
        return True

    for name in packages:
        if _already_available(name):
            continue
        print(f"Installing missing package: {name}")
        pip_command = [sys.executable, "-m", "pip", "install", name]
        subprocess.check_call(pip_command)
# Third-party packages needed at runtime. Each name is also the module name
# that install() tries to import before falling back to pip.
install(
    [
        "mido",
        "midi2audio",
        "pydub",
        "gradio",
        "openai",
        "requests",
        "numpy",
        "matplotlib",
        "librosa",
        "scipy",
    ]
)
# ----------------------------------------------------------------------------- | |
# 2. Static imports | |
# ----------------------------------------------------------------------------- | |
import requests | |
import json | |
import tempfile | |
import mido | |
from mido import Message, MidiFile, MidiTrack, MetaMessage | |
import re | |
from io import BytesIO | |
from midi2audio import FluidSynth | |
from pydub import AudioSegment | |
import gradio as gr | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import librosa | |
from scipy.io import wavfile | |
import os | |
import subprocess as sp | |
import base64 | |
import shutil | |
from openai import OpenAI # For API models | |
# ----------------------------------------------------------------------------- | |
# 3. Configuration & constants | |
# ----------------------------------------------------------------------------- | |
# API credentials come from the environment so that secrets never live in
# source control. NOTE(review): the keys that used to be hard-coded here were
# published with the source and should be revoked/rotated.
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "")

# One shared OpenRouter key, optionally overridden per model.
_OPENROUTER_DEFAULT_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_API_KEYS = {
    "DeepSeek": os.environ.get("OPENROUTER_API_KEY_DEEPSEEK", _OPENROUTER_DEFAULT_KEY),
    "Claude": os.environ.get("OPENROUTER_API_KEY_CLAUDE", _OPENROUTER_DEFAULT_KEY),
    "Gemma": os.environ.get("OPENROUTER_API_KEY_GEMMA", _OPENROUTER_DEFAULT_KEY),
    "Kimi": os.environ.get("OPENROUTER_API_KEY_KIMI", _OPENROUTER_DEFAULT_KEY),
    "Llama 3.1": os.environ.get("OPENROUTER_API_KEY_LLAMA", _OPENROUTER_DEFAULT_KEY),
}
# Violin, Clarinet and Flute all share this one SoundFont download.
_SALC5_LIGHT_SF2_URL = "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2"

# Instrument name -> URL of the SoundFont used to render it.
SOUNDFONT_URLS = dict(
    Trumpet="https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
    Piano="https://musical-artifacts.com/artifacts/2719/GeneralUser_GS_1.471.sf2",
    Violin=_SALC5_LIGHT_SF2_URL,
    Clarinet=_SALC5_LIGHT_SF2_URL,
    Flute=_SALC5_LIGHT_SF2_URL,
)

SAMPLE_RATE = 44_100  # audio sample rate in Hz
TICKS_PER_BEAT = 480  # MIDI resolution (ticks per quarter-note beat)
TICKS_PER_16TH = TICKS_PER_BEAT // 4  # a 16th note is a quarter of a beat: 120 ticks
# FluidSynth is the MIDI renderer invoked by midi_to_mp3(). If it is not on
# PATH, attempt an apt-get install (works in Debian/Ubuntu containers).
if shutil.which("fluidsynth") is None:
    # os.system reports failure through its exit status instead of raising,
    # so the status must be checked explicitly for the warning to ever show.
    _fluidsynth_install_status = os.system('apt-get update && apt-get install -y fluidsynth')
    if _fluidsynth_install_status != 0:
        print("Could not install FluidSynth automatically. Please install it manually.")
# Rendered MP3s and exported MIDI files are written to this directory.
os.makedirs("static", exist_ok=True)
# ----------------------------------------------------------------------------- | |
# 4. Music theory helpers (note names ↔︎ MIDI numbers) | |
# ----------------------------------------------------------------------------- | |
# Spellings for each pitch class, indexed by semitones above C. Flats are
# stored upper-cased ("DB", "EB", ...) to match lookups on upper-cased input.
_PITCH_CLASS_SPELLINGS = (
    ("C",), ("C#", "DB"), ("D",), ("D#", "EB"), ("E",), ("F",),
    ("F#", "GB"), ("G",), ("G#", "AB"), ("A",), ("A#", "BB"), ("B",),
)

# Upper-cased note spelling -> semitone offset above C (0-11).
NOTE_MAP: Dict[str, int] = {
    spelling: semitone
    for semitone, spellings in enumerate(_PITCH_CLASS_SPELLINGS)
    for spelling in spellings
}

# Instrument name -> 0-based General MIDI program number for program_change.
INSTRUMENT_PROGRAMS: Dict[str, int] = dict(
    Piano=0,
    Trumpet=56,
    Violin=40,
    Clarinet=71,
    Flute=73,
)
def note_name_to_midi(note: str) -> int:
    """Convert a scientific-pitch note name to a MIDI note number.

    Accepts a letter A-G (any case), any number of ``#`` (sharp) / ``b`` (flat)
    accidentals, and an integer octave that may be negative or multi-digit:
    "C4" -> 60, "Bb3" -> 58, "F#5" -> 78, "C-1" -> 0, "G9" -> 127.
    Enharmonic spellings that cross an octave boundary are resolved
    arithmetically ("B#3" -> 60, "Cb4" -> 59). Leading whitespace and any text
    after the octave are ignored, keeping the previous lenient parsing.

    Raises:
        ValueError: if no note name can be parsed, or the result lies outside
            the MIDI range 0-127.
    """
    letter_offsets = {"C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11}
    # The octave is matched as a whole signed integer. A single-digit pattern
    # silently misread "C10" as C1 and rejected "C-1".
    match = re.match(r"\s*([A-Ga-g])([#b]*)(-?\d+)", note)
    if not match:
        raise ValueError(f"Invalid note: {note}")
    letter, accidentals, octave = match.groups()
    pitch_class = (
        letter_offsets[letter.upper()]
        + accidentals.count("#")
        - accidentals.count("b")
    )
    midi_num = pitch_class + (int(octave) + 1) * 12
    if not 0 <= midi_num <= 127:
        raise ValueError(f"Note out of MIDI range (0-127): {note}")
    return midi_num
def midi_to_note_name(midi_num: int) -> str:
    """Return the sharp-spelled scientific pitch name of *midi_num* (60 -> "C4")."""
    sharp_names = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    octave_block, pitch_class = divmod(midi_num, 12)
    # MIDI note 0 is C-1, so the octave number is one less than the block index.
    return sharp_names[pitch_class] + str(octave_block - 1)
# ----------------------------------------------------------------------------- | |
# 5. Duration scaling: guarantee the output sums to requested total (using integers) | |
# ----------------------------------------------------------------------------- | |
def scale_json_durations(json_data, target_units: int) -> list:
    """Rescale [note, duration] pairs so the durations sum to exactly *target_units*.

    Durations are integer 16th-note units. Each output duration is proportional
    to its input duration, using largest-remainder apportionment in pure integer
    arithmetic, and is at least 1. Input that already sums to *target_units*
    keeps its durations unchanged.

    Edge cases:
      * empty input, or ``target_units <= 0``: returns ``[]``.
      * all durations zero: the target is split as evenly as possible.
      * more notes than target units: only the first ``target_units`` notes
        are kept, one unit each, because no note can be shorter than a 16th.
      * negative durations are treated as 0.
    Durations may be ints or int-like strings; each is converted with ``int()``.

    Returns:
        A new list of ``[note, int_duration]`` lists.
    """
    if not json_data or target_units <= 0:
        return []
    # Convert once and use the converted value everywhere. Using the raw value
    # in the arithmetic turned "4" * n into string repetition.
    pairs = [(note, max(0, int(duration))) for note, duration in json_data]
    pairs = pairs[:target_units]
    weights = [units for _, units in pairs]
    total = sum(weights)
    if total == 0:
        weights = [1] * len(pairs)
        total = len(pairs)

    # Integer floor and remainder of each ideal share (weight * target / total).
    shares = [divmod(weight * target_units, total) for weight in weights]
    allocation = [floor for floor, _ in shares]
    # The floors can only fall short, by fewer units than there are notes.
    # Hand the shortfall to the largest remainders; ties go to earlier notes.
    leftover = target_units - sum(allocation)
    by_remainder = sorted(range(len(shares)), key=lambda idx: -shares[idx][1])
    for idx in by_remainder[:leftover]:
        allocation[idx] += 1

    # Enforce a one-16th minimum by borrowing from the currently longest note.
    # Because len(pairs) <= target_units, a donor with >= 2 units always exists.
    for idx in range(len(allocation)):
        if allocation[idx] == 0:
            donor = max(range(len(allocation)), key=allocation.__getitem__)
            allocation[donor] -= 1
            allocation[idx] = 1

    return [[note, units] for (note, _), units in zip(pairs, allocation)]
# ----------------------------------------------------------------------------- | |
# 6. MIDI from scaled JSON (using integer durations) | |
# ----------------------------------------------------------------------------- | |
def json_to_midi(json_data: list, instrument: str, tempo: int, time_signature: str, measures: int) -> MidiFile:
    """Build a one-track MIDI file from [note_name, duration_units] pairs.

    One duration unit is a 16th note (TICKS_PER_16TH ticks). Notes play
    back-to-back (monophonically), each with a random velocity in 60-100.

    Args:
        json_data: sequence of [note_name, duration_units] pairs. A note name of
            "R" or "rest" (any case) produces silence of that duration.
        instrument: key into INSTRUMENT_PROGRAMS; unknown names use program 56.
        tempo: beats per minute, written as a set_tempo meta message.
        time_signature: "numerator/denominator" string such as "4/4".
        measures: accepted for interface compatibility; not used in the body.

    Returns:
        The assembled MidiFile. It is not saved to disk.

    A note whose name cannot be converted is logged and rendered as a rest, so
    later notes keep their timing and the track length still equals the sum of
    the durations. A pair whose duration is not numeric is logged and skipped.
    """
    mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
    track = MidiTrack()
    mid.tracks.append(track)
    program = INSTRUMENT_PROGRAMS.get(instrument, 56)
    numerator, denominator = map(int, time_signature.split('/'))
    track.append(MetaMessage('time_signature', numerator=numerator,
                             denominator=denominator, time=0))
    track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
    track.append(Message('program_change', program=program, time=0))

    # MIDI times are deltas, so silence is carried forward and attached to the
    # next note_on (or to the end-of-track marker).
    pending_rest = 0
    for note_name, duration_units in json_data:
        try:
            ticks = max(int(float(duration_units) * TICKS_PER_16TH), 1)
        except (TypeError, ValueError) as e:
            print(f"Error parsing note {note_name}: {e}")
            continue

        if str(note_name).strip().lower() in ("r", "rest"):
            pending_rest += ticks
            continue

        try:
            note_num = note_name_to_midi(note_name)
            if not 0 <= note_num <= 127:
                raise ValueError(f"MIDI note {note_num} out of range 0-127")
        except (TypeError, ValueError) as e:
            print(f"Error parsing note {note_name}: {e}")
            pending_rest += ticks  # keep the slot so later notes stay in time
            continue

        velocity = random.randint(60, 100)
        track.append(Message('note_on', note=note_num, velocity=velocity, time=pending_rest))
        track.append(Message('note_off', note=note_num, velocity=velocity, time=ticks))
        pending_rest = 0

    if pending_rest:
        # A trailing rest must still occupy time, so end the track after it.
        track.append(MetaMessage('end_of_track', time=pending_rest))
    return mid
# ----------------------------------------------------------------------------- | |
# 7. MIDI → Audio (MP3) helpers | |
# ----------------------------------------------------------------------------- | |
def get_soundfont(instrument: str) -> str:
    """Return a local path to the SoundFont for *instrument*, downloading it if needed.

    Files are cached as ``soundfonts/<instrument>.sf2``. Instruments missing from
    SOUNDFONT_URLS use the Trumpet URL. The download first goes to a ``.part``
    file, which is renamed into place only once it completes successfully. As a
    result, an HTTP error page, a timeout or an interrupted transfer never leaves
    a corrupt file that later calls would treat as cached.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on network failure or timeout.
    """
    os.makedirs("soundfonts", exist_ok=True)
    sf2_path = f"soundfonts/{instrument}.sf2"
    if os.path.exists(sf2_path):
        return sf2_path

    url = SOUNDFONT_URLS.get(instrument, SOUNDFONT_URLS["Trumpet"])
    print(f"Downloading SoundFont for {instrument}…")
    partial_path = sf2_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, sf2_path)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return sf2_path
def midi_to_mp3(midi_obj: MidiFile, instrument: str = "Trumpet") -> Tuple[str, float]:
    """Render *midi_obj* to an MP3 under ``static/`` and return (mp3_path, duration_seconds).

    Synthesis first uses the ``fluidsynth`` command line with the instrument's
    SoundFont. If that cannot run or exits non-zero, it falls back to
    midi2audio's FluidSynth wrapper. Before MP3 export, Trumpet gets a 200 Hz
    high-pass filter and Violin a 5 kHz low-pass filter.

    The intermediate .mid/.wav (and any unmoved .mp3) temp files are removed
    on every path, including when synthesis or export fails.
    """
    fd, midi_path = tempfile.mkstemp(suffix=".mid")
    os.close(fd)  # mido writes the file by path; the raw descriptor is not needed
    base_path, _ = os.path.splitext(midi_path)
    wav_path = base_path + ".wav"
    mp3_path = base_path + ".mp3"
    try:
        midi_obj.save(midi_path)
        sf2_path = get_soundfont(instrument)
        try:
            sp.run([
                'fluidsynth', '-ni', sf2_path, midi_path,
                '-F', wav_path, '-r', str(SAMPLE_RATE), '-g', '1.0'
            ], check=True, capture_output=True)
        except (OSError, sp.SubprocessError):
            fs = FluidSynth(sf2_path, sample_rate=SAMPLE_RATE, gain=1.0)
            fs.midi_to_audio(midi_path, wav_path)

        sound = AudioSegment.from_wav(wav_path)
        if instrument == "Trumpet":
            sound = sound.high_pass_filter(200)
        elif instrument == "Violin":
            sound = sound.low_pass_filter(5000)
        sound.export(mp3_path, format="mp3")
        static_mp3_path = os.path.join('static', os.path.basename(mp3_path))
        shutil.move(mp3_path, static_mp3_path)
        return static_mp3_path, sound.duration_seconds
    finally:
        # After a successful move the temp .mp3 no longer exists; that is expected.
        for leftover in (midi_path, wav_path, mp3_path):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
# ----------------------------------------------------------------------------- | |
# 8. Prompt engineering for variety (using integer durations) | |
# ----------------------------------------------------------------------------- | |
def get_fallback_exercise(instrument: str, level: str, key: str,
                          time_sig: str, measures: int) -> str:
    """Build an offline exercise as a JSON string of [note, duration] pairs.

    The instrument's six-note pattern repeats in quarter notes (4 units). If
    the measure total is not a multiple of 4, the final note is shortened so the
    durations sum exactly to ``measures`` bars of *time_sig*. Unknown instruments
    use the Trumpet pattern. *level* and *key* are accepted but not used.
    """
    trumpet_pattern = ["C4", "D4", "E4", "G4", "E4", "C4"]
    patterns_by_instrument = {
        "Trumpet": trumpet_pattern,
        "Piano": ["C4", "E4", "G4", "C5", "G4", "E4"],
        "Violin": ["G4", "A4", "B4", "D5", "B4", "G4"],
        "Clarinet": ["E4", "F4", "G4", "Bb4", "G4", "E4"],
        "Flute": ["A4", "B4", "C5", "E5", "C5", "A4"],
    }
    pattern = patterns_by_instrument.get(instrument, trumpet_pattern)

    beats, beat_unit = (int(part) for part in time_sig.split('/'))
    sixteenths_total = measures * (beats * (16 // beat_unit))

    # Whole quarter notes first, then one shorter note for any remainder.
    quarter_count, remainder = divmod(sixteenths_total, 4)
    durations = [4] * int(quarter_count)
    if remainder:
        durations.append(remainder)

    exercise = [[pattern[idx % len(pattern)], length] for idx, length in enumerate(durations)]
    return json.dumps(exercise)
def get_style_based_on_level(level: str) -> str:
    """Return a random style adjective suited to *level* ("technical" if unrecognized)."""
    if level == "Beginner":
        candidates = ["simple", "legato", "stepwise"]
    elif level == "Intermediate":
        candidates = ["jazzy", "bluesy", "march-like", "syncopated"]
    elif level == "Advanced":
        candidates = ["technical", "chromatic", "fast arpeggios", "wide intervals"]
    else:
        candidates = ["technical"]
    return random.choice(candidates)
def get_technique_based_on_level(level: str) -> str:
    """Return a random technique phrase suited to *level* ("with slurs" if unrecognized)."""
    if level == "Beginner":
        candidates = ["with long tones", "with simple rhythms", "focusing on tone"]
    elif level == "Intermediate":
        candidates = ["with slurs", "with accents", "using triplets"]
    elif level == "Advanced":
        candidates = ["with double tonguing", "with extreme registers", "complex rhythms"]
    else:
        candidates = ["with slurs"]
    return random.choice(candidates)
# ----------------------------------------------------------------------------- | |
# 9. LLM Query Function (with enhanced error handling) | |
# ----------------------------------------------------------------------------- | |
# OpenRouter model identifiers for each selectable non-Mistral exercise model.
_EXERCISE_OPENROUTER_MODELS = {
    "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
    "Claude": "anthropic/claude-3.5-sonnet:beta",
    "Gemma": "google/gemma-3n-e2b-it:free",
    "Kimi": "moonshotai/kimi-dev-72b:free",
    "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free",
}
# Upper bound for a single LLM HTTP request, so a stalled API cannot hang the UI.
_EXERCISE_REQUEST_TIMEOUT = 60  # seconds


def _strip_code_fences(text: Optional[str]) -> str:
    """Remove Markdown ```json / ``` fences that LLMs often wrap around JSON."""
    if text is None:
        raise ValueError("LLM returned an empty response")
    return text.replace("```json", "").replace("```", "").strip()


def _exercise_request_mistral(system_prompt: str, user_prompt: str, temperature: float) -> str:
    """Send one exercise request to the Mistral chat-completions API; return the de-fenced reply."""
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "mistral-medium",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "max_tokens": 1000,
        "top_p": 0.95,
        "frequency_penalty": 0.2,
        "presence_penalty": 0.2,
    }
    response = requests.post(MISTRAL_API_URL, headers=headers, json=payload,
                             timeout=_EXERCISE_REQUEST_TIMEOUT)
    response.raise_for_status()
    return _strip_code_fences(response.json()["choices"][0]["message"]["content"])


def _exercise_request_openrouter(model_name: str, system_prompt: str, user_prompt: str,
                                 temperature: float) -> str:
    """Send one exercise request to *model_name* through OpenRouter; return the de-fenced reply."""
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=OPENROUTER_API_KEYS[model_name],
        timeout=_EXERCISE_REQUEST_TIMEOUT,
    )
    if model_name == "Gemma":
        # Gemma is sent no system role (presumably unsupported by this endpoint;
        # verify). The instructions are prepended to the user turn instead of
        # being dropped, because they carry the integer-duration rules.
        messages = [{"role": "user", "content": f"{system_prompt}\n\n{user_prompt}"}]
    else:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
    completion = client.chat.completions.create(
        extra_headers={
            "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
            "X-Title": "Music Exercise Generator",
        },
        model=_EXERCISE_OPENROUTER_MODELS[model_name],
        messages=messages,
        temperature=temperature,
        max_tokens=1000,
        top_p=0.95,
        frequency_penalty=0.2,
        presence_penalty=0.2,
    )
    return _strip_code_fences(completion.choices[0].message.content)


def query_llm(model_name: str, prompt: str, instrument: str, level: str, key: str,
              time_sig: str, measures: int) -> str:
    """Ask an LLM for an exercise as a JSON array of [note, duration] pairs.

    Args:
        model_name: "Mistral" or one of the OpenRouter-backed names
            ("DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"). Any other
            value returns the offline fallback exercise immediately.
        prompt: free-text user request. If blank, a prompt is generated from
            *instrument*, *key*, *time_sig* and a level-based style/technique.
        instrument, level, key, time_sig, measures: exercise parameters. The
            required duration total is ``measures`` bars of *time_sig* in 16ths.

    Returns:
        The model's raw text reply with Markdown fences removed. It is not
        validated here. If every API attempt fails, returns the output of
        get_fallback_exercise() instead.

    Rate-limited attempts (message mentions 429 / rate limit) are retried up to
    three times with exponential backoff. Any other error stops retrying
    immediately. After the selected model fails, Mistral is tried once more
    before the offline fallback.
    """
    numerator, denominator = map(int, time_sig.split('/'))
    units_per_measure = numerator * (16 // denominator)
    required_total = measures * units_per_measure
    duration_constraint = (
        f"Sum of all durations MUST BE EXACTLY {required_total} units (16th notes). "
        f"Each integer duration represents a 16th note (1=16th, 2=8th, 4=quarter, 8=half, 16=whole). "
        f"If it doesn't match, the exercise is invalid."
    )
    system_prompt = (
        f"You are an expert music teacher specializing in {instrument.lower()}. "
        "Create customized exercises using INTEGER durations representing 16th notes."
    )
    if prompt.strip():
        user_prompt = (
            f"{prompt} {duration_constraint} Output ONLY a JSON array of [note, duration] pairs."
        )
    else:
        style = get_style_based_on_level(level)
        technique = get_technique_based_on_level(level)
        user_prompt = (
            f"Create a {style} {instrument.lower()} exercise in {key} with {time_sig} time signature "
            f"{technique} for a {level.lower()} player. {duration_constraint} "
            "Output ONLY a JSON array of [note, duration] pairs following these rules: "
            "Use standard note names (e.g., \"Bb4\", \"F#5\"). Monophonic only. "
            "Durations: 1=16th, 2=8th, 4=quarter, 8=half, 16=whole. "
            "Sum must be exactly as specified. ONLY output the JSON array. No prose."
        )

    if model_name != "Mistral" and model_name not in _EXERCISE_OPENROUTER_MODELS:
        return get_fallback_exercise(instrument, level, key, time_sig, measures)

    temperature = 0.7 if level == "Advanced" else 0.5
    max_retries = 3
    retry_delay = 5  # seconds; doubled after each rate-limited attempt
    for attempt in range(max_retries):
        try:
            if model_name == "Mistral":
                return _exercise_request_mistral(system_prompt, user_prompt, temperature)
            return _exercise_request_openrouter(model_name, system_prompt, user_prompt, temperature)
        except Exception as e:
            print(f"Error querying {model_name} API (attempt {attempt+1}): {e}")
            error_text = str(e)
            rate_limited = "429" in error_text or "rate limit" in error_text.lower()
            # Do not sleep after the final attempt; nothing would follow it.
            if not rate_limited or attempt == max_retries - 1:
                break
            print(f"Rate limited, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    print(f"All attempts failed for {model_name}, using Mistral fallback")
    try:
        return _exercise_request_mistral(system_prompt, user_prompt, temperature)
    except Exception as e:
        print(f"Error querying Mistral fallback: {e}")
        return get_fallback_exercise(instrument, level, key, time_sig, measures)
# ----------------------------------------------------------------------------- | |
# 10. Robust JSON parsing for LLM outputs | |
# ----------------------------------------------------------------------------- | |
def safe_parse_json(text: str) -> Optional[list]:
    """Extract the first JSON array of [note, duration] pairs from LLM output.

    Every ``[`` in the text is tried as the start of a JSON value, via
    ``json.JSONDecoder.raw_decode``. The first value that is a non-empty list of
    two-element lists is returned, so surrounding prose, code fences or an
    earlier ``[]`` are skipped. If the text as written contains no such array,
    it is retried with single quotes converted to double quotes, since LLMs
    sometimes emit Python-style ``[['C4', 4]]``.

    Returns:
        The parsed list of pairs, or None if none is found (a diagnostic is printed).
    """
    if not isinstance(text, str):
        print(f"JSON parsing error: expected text, got {type(text).__name__}\nRaw text: {text}")
        return None

    decoder = json.JSONDecoder()
    candidates = [text]
    requoted = text.replace("'", '"')
    if requoted != text:
        candidates.append(requoted)

    for candidate in candidates:
        start = candidate.find("[")
        while start != -1:
            try:
                value, _ = decoder.raw_decode(candidate, start)
            except ValueError:
                value = None
            if (
                isinstance(value, list)
                and value
                and all(isinstance(item, list) and len(item) == 2 for item in value)
            ):
                return value
            start = candidate.find("[", start + 1)

    print(f"JSON parsing error: no [note, duration] array found\nRaw text: {text}")
    return None
# ----------------------------------------------------------------------------- | |
# 11. Main orchestration: talk to API, *scale durations*, build MIDI, UI values | |
# ----------------------------------------------------------------------------- | |
def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_signature: str,
                      measures: int, custom_prompt: str, mode: str, ai_model: str) -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
    """Run the whole pipeline for one exercise and return the values the UI shows.

    The pipeline is: LLM query, JSON parse, duration scaling to exactly
    ``measures`` bars, MIDI build, and MP3 render.

    Returns a 7-tuple:
    (JSON text or error message, MP3 path or None, tempo as text,
    MidiFile or None, duration label, time signature, total 16th-note units).
    Any exception is turned into an "Error: ..." first element rather than raised.
    """
    tempo_label = str(tempo)
    try:
        prompt_text = custom_prompt if mode == "Exercise Prompt" else ""
        llm_reply = query_llm(ai_model, prompt_text, instrument, level, key, time_signature, measures)
        notes = safe_parse_json(llm_reply)
        if not notes:
            return "Invalid JSON format", None, tempo_label, None, "0", time_signature, 0

        beats, beat_unit = map(int, time_signature.split('/'))
        target_units = measures * (beats * (16 // beat_unit))

        fitted = scale_json_durations(notes, target_units)
        unit_total = sum(units for _, units in fitted)

        midi = json_to_midi(fitted, instrument, tempo, time_signature, measures)
        mp3_path, seconds = midi_to_mp3(midi, instrument)
        pretty_json = json.dumps(fitted, indent=2)
        return (pretty_json, mp3_path, tempo_label, midi,
                f"{seconds:.2f} seconds", time_signature, unit_total)
    except Exception as exc:
        return f"Error: {str(exc)}", None, tempo_label, None, "0", time_signature, 0
# ----------------------------------------------------------------------------- | |
# 12. AI chat assistant with enhanced error handling | |
# ----------------------------------------------------------------------------- | |
# Upper bound for a single chat HTTP request.
_CHAT_REQUEST_TIMEOUT = 60  # seconds


def _chat_request_mistral(messages: List[dict]) -> str:
    """Send a chat transcript to Mistral and return the assistant reply text."""
    headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
    payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
    response = requests.post(MISTRAL_API_URL, headers=headers, json=payload,
                             timeout=_CHAT_REQUEST_TIMEOUT)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


def _fold_system_into_first_user(messages: List[dict]) -> List[dict]:
    """Return a copy of *messages* with system text merged into the first user turn.

    Used for Gemma, which is sent no system role (presumably unsupported by its
    endpoint; verify). Assistant turns keep their role, so the conversation
    history is not rewritten as if the user had said everything.
    """
    system_text = "\n\n".join(m["content"] for m in messages if m["role"] == "system")
    dialogue = [dict(m) for m in messages if m["role"] != "system"]
    if not system_text:
        return dialogue
    if dialogue and dialogue[0]["role"] == "user":
        dialogue[0]["content"] = f"{system_text}\n\n{dialogue[0]['content']}"
    else:
        dialogue.insert(0, {"role": "user", "content": system_text})
    return dialogue


def _chat_request_openrouter(ai_model: str, model_id: str, messages: List[dict]) -> str:
    """Send a chat transcript to *model_id* through OpenRouter and return the reply text."""
    client = OpenAI(
        base_url="https://openrouter.ai/api/v1",
        api_key=OPENROUTER_API_KEYS[ai_model],
        timeout=_CHAT_REQUEST_TIMEOUT,
    )
    if ai_model == "Gemma":
        messages = _fold_system_into_first_user(messages)
    completion = client.chat.completions.create(
        extra_headers={
            "HTTP-Referer": "https://github.com/AdaptiveMusicExerciseGenerator",
            "X-Title": "Music Exercise Generator",
        },
        model=model_id,
        messages=messages,
        temperature=0.7,
        max_tokens=500,
    )
    return completion.choices[0].message.content


def handle_chat(message: str, history: List, instrument: str, level: str, ai_model: str):
    """Answer one chat turn; return ("", updated_history) for the Gradio outputs.

    Args:
        message: the user's new message. Blank messages leave the history unchanged.
        history: previous turns as (user, assistant) pairs. ``None`` counts as empty,
            and ``None`` entries inside a pair are skipped when building the request.
        instrument, level: used to build the teacher system prompt.
        ai_model: "Mistral" or an OpenRouter-backed model name.

    Rate-limited requests (message mentions 429 / rate limit) are retried up to
    three times with exponential backoff. Other failures, and exhausted retries
    on a non-Mistral model, fall back to one Mistral request. If that also fails,
    its error text is appended to the history as the reply.
    """
    history = list(history or [])
    if not message.strip():
        return "", history

    messages = [{"role": "system", "content": f"You are a {instrument} teacher for {level} students."}]
    for user_msg, assistant_msg in history:
        if user_msg is not None:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    openrouter_models = {
        "DeepSeek": "deepseek/deepseek-chat-v3-0324:free",
        "Claude": "anthropic/claude-3.5-sonnet:beta",
        "Gemma": "google/gemma-3n-e2b-it:free",
        "Kimi": "moonshotai/kimi-dev-72b:free",
        "Llama 3.1": "meta-llama/llama-3.1-405b-instruct:free",
    }
    if ai_model != "Mistral" and ai_model not in openrouter_models:
        history.append((message, "Error: Invalid AI model selected"))
        return "", history

    max_retries = 3
    retry_delay = 3  # seconds; doubled after each rate-limited attempt
    rate_limit_exhausted = False
    for attempt in range(max_retries):
        try:
            if ai_model == "Mistral":
                content = _chat_request_mistral(messages)
            else:
                content = _chat_request_openrouter(ai_model, openrouter_models[ai_model], messages)
            history.append((message, content))
            return "", history
        except Exception as e:
            print(f"Chat error with {ai_model} (attempt {attempt+1}): {e}")
            error_text = str(e)
            if not ("429" in error_text or "rate limit" in error_text.lower()):
                break
            if attempt == max_retries - 1:
                rate_limit_exhausted = True
                break
            print(f"Rate limited, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
            retry_delay *= 2

    if rate_limit_exhausted and ai_model == "Mistral":
        # Mistral itself is throttling; a Mistral "fallback" would be throttled too.
        history.append((message, "Error: All API attempts failed"))
        return "", history

    print("Using Mistral fallback for chat")
    try:
        content = _chat_request_mistral(messages)
    except Exception as e:
        history.append((message, f"Error: {str(e)}"))
        return "", history
    history.append((message, content))
    return "", history
# ----------------------------------------------------------------------------- | |
# 13. Gradio user interface definition | |
# ----------------------------------------------------------------------------- | |
def create_ui() -> gr.Blocks:
    """Build the Gradio interface and wire its event handlers; return it unlaunched.

    Layout:
      * a mode switch between the parameter form and a free-text prompt;
      * a settings column;
      * tabs for playback, raw JSON data, MIDI export and an AI chat assistant.
    Exercises are always generated at a fixed 60 BPM.
    """
    with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
        gr.Markdown("# 🎼 Adaptive Music Exercise Generator")
        current_midi = gr.State(None)  # MidiFile from the last generation
        current_exercise = gr.State("")  # not referenced by any handler below
        mode = gr.Radio(["Exercise Parameters", "Exercise Prompt"], value="Exercise Parameters", label="Generation Mode")

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(visible=True) as params_group:
                    gr.Markdown("### Exercise Parameters")
                    ai_model = gr.Radio(
                        ["Mistral", "DeepSeek", "Claude", "Gemma", "Kimi", "Llama 3.1"],
                        value="Mistral",
                        label="AI Model"
                    )
                    instrument = gr.Dropdown([
                        "Trumpet", "Piano", "Violin", "Clarinet", "Flute",
                    ], value="Trumpet", label="Instrument")
                    level = gr.Radio([
                        "Beginner", "Intermediate", "Advanced",
                    ], value="Intermediate", label="Difficulty Level")
                    key = gr.Dropdown([
                        "C Major", "G Major", "D Major", "F Major", "Bb Major", "A Minor", "E Minor",
                    ], value="C Major", label="Key Signature")
                    time_signature = gr.Dropdown(["3/4", "4/4"], value="4/4", label="Time Signature")
                    measures = gr.Radio([4, 8], value=4, label="Length (measures)")
                with gr.Group(visible=False) as prompt_group:
                    gr.Markdown("### Exercise Prompt")
                    custom_prompt = gr.Textbox("", label="Enter your custom prompt", lines=3)
                    measures_prompt = gr.Radio([4, 8], value=4, label="Length (measures)")
                generate_btn = gr.Button("Generate Exercise", variant="primary")

            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Exercise Player"):
                        audio_output = gr.Audio(label="Generated Exercise", autoplay=True, type="filepath")
                        bpm_display = gr.Textbox(label="Tempo (BPM)")
                        time_sig_display = gr.Textbox(label="Time Signature")
                        duration_display = gr.Textbox(label="Audio Duration", interactive=False)
                    with gr.TabItem("Exercise Data"):
                        json_output = gr.Code(label="JSON Representation", language="json")
                        # Sum of all durations, for verifying the exact-length guarantee.
                        duration_sum = gr.Number(
                            label="Total Duration Units (16th notes)",
                            interactive=False,
                            precision=0
                        )
                    with gr.TabItem("MIDI Export"):
                        midi_output = gr.File(label="MIDI File")
                        download_midi = gr.Button("Generate MIDI File")
                    with gr.TabItem("AI Chat"):
                        chat_history = gr.Chatbot(label="Practice Assistant", height=400)
                        chat_message = gr.Textbox(label="Ask the AI anything about your practice")
                        send_chat_btn = gr.Button("Send")

        # Show only the input group that belongs to the selected mode.
        mode.change(
            fn=lambda m: {
                params_group: gr.update(visible=(m == "Exercise Parameters")),
                prompt_group: gr.update(visible=(m == "Exercise Prompt")),
            },
            inputs=[mode], outputs=[params_group, prompt_group]
        )

        def generate_caller(mode_val, instrument_val, level_val, key_val,
                            time_sig_val, measures_val, prompt_val, measures_prompt_val, ai_model_val):
            """Pick the length control for the active mode and generate at a fixed 60 BPM."""
            real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
            fixed_tempo = 60
            return generate_exercise(
                instrument_val, level_val, key_val, fixed_tempo, time_sig_val,
                real_measures, prompt_val, mode_val, ai_model_val
            )

        generate_btn.click(
            fn=generate_caller,
            inputs=[mode, instrument, level, key, time_signature, measures, custom_prompt, measures_prompt, ai_model],
            outputs=[json_output, audio_output, bpm_display, current_midi, duration_display, time_sig_display, duration_sum]
        )

        def save_midi(json_data, instr, time_sig):
            """Rebuild a MIDI file from the JSON tab and return its path (None if unparsable).

            The bar count is estimated from the duration total, and the durations
            are rescaled to fill whole bars. Each export gets a unique file under
            ``static/``, so concurrent sessions cannot overwrite each other's download.
            """
            parsed = safe_parse_json(json_data)
            if not parsed:
                return None
            numerator, denominator = map(int, time_sig.split('/'))
            units_per_measure = numerator * (16 // denominator)
            total_units = sum(int(d[1]) for d in parsed)
            measures_est = max(1, round(total_units / units_per_measure))
            scaled = scale_json_durations(parsed, measures_est * units_per_measure)
            midi_obj = json_to_midi(scaled, instr, 60, time_sig, measures_est)
            fd, midi_path = tempfile.mkstemp(prefix="exercise_", suffix=".mid", dir="static")
            os.close(fd)  # mido writes by path; the raw descriptor is not needed
            midi_obj.save(midi_path)
            return midi_path

        download_midi.click(
            fn=save_midi,
            inputs=[json_output, instrument, time_signature],
            outputs=[midi_output],
        )

        chat_inputs = [chat_message, chat_history, instrument, level, ai_model]
        chat_outputs = [chat_message, chat_history]
        send_chat_btn.click(fn=handle_chat, inputs=chat_inputs, outputs=chat_outputs)
        # Pressing Enter in the message box sends it, just like the Send button.
        chat_message.submit(fn=handle_chat, inputs=chat_inputs, outputs=chat_outputs)
    return demo
# ----------------------------------------------------------------------------- | |
# 14. Entry point | |
# ----------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Build the interface, keep it under the module-level name `demo`, then
    # start the Gradio server.
    demo = create_ui()
    demo.launch()