|
""" |
|
Adaptive Music Exercise Generator (Strict Duration Enforcement) |
|
============================================================== |
|
Generates custom musical exercises with LLM, perfectly fit to user-specified number of measures |
|
AND time signature, guaranteeing exact durations in MIDI and in the UI! |
|
Major updates: |
|
- Changed base duration unit from 16th notes to 8th notes (1 unit = 8th note) |
|
- Updated all calculations and prompts to use new duration system |
|
- Duration sum display now shows total in 8th notes |
|
- Maintained all original functionality |
|
- Added cumulative duration tracking |
|
- Enforced JSON output format with note, duration, cumulative_duration |
|
- Enhanced rest handling and JSON parsing |
|
- Fixed JSON parsing errors for 8-measure exercises |
|
- Added robust error handling for MIDI generation |
|
""" |
|
|
|
|
|
|
|
|
|
import sys |
|
import subprocess |
|
from typing import Dict, Optional, Tuple, List |
|
|
|
def install(packages: List[str]):
    """Make sure every listed package can be imported, installing it if not.

    Each entry is used both as the module name passed to ``__import__`` and
    as the requirement handed to ``pip install``. Packages that already
    import cleanly are left alone.

    Args:
        packages: Names to probe and, when missing, install with the
            interpreter that is running this script.
    """
    for name in packages:
        try:
            __import__(name)
        except ImportError:
            importable = False
        else:
            importable = True

        if importable:
            continue

        print(f"Installing missing package: {name}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", name])
|
|
|
# Bootstrap the third-party dependencies before the imports further down run.
# "uuid" and "datetime" are standard-library modules: their import probe
# always succeeds, so pip is never invoked for those two entries.
install([
    "mido", "midi2audio", "pydub", "gradio",
    "requests", "numpy", "matplotlib", "librosa", "scipy",
    "uuid", "datetime"
])
|
|
|
|
|
|
|
|
|
import random |
|
import requests |
|
import json |
|
import tempfile |
|
import mido |
|
from mido import Message, MidiFile, MidiTrack, MetaMessage |
|
import re |
|
from io import BytesIO |
|
from midi2audio import FluidSynth |
|
from pydub import AudioSegment |
|
import gradio as gr |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import librosa |
|
from scipy.io import wavfile |
|
import os |
|
import subprocess as sp |
|
import base64 |
|
import shutil |
|
import ast |
|
import uuid |
|
from datetime import datetime |
|
import time |
|
|
|
|
|
|
|
|
|
# Chat-completions endpoint used for exercise generation and for the chat tab.
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"

# SECURITY: a credential committed to source control should be treated as
# leaked and rotated. The MISTRAL_API_KEY environment variable takes
# precedence; the literal is kept only so existing deployments keep working.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY") or "yQdfM8MLbX9uhInQ7id4iUTwN4h4pDLX"

# Download location of the SoundFont used for each instrument.
# Violin, Clarinet and Flute all point at the same file.
SOUNDFONT_URLS = {
    "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
    "Piano": "https://musical-artifacts.com/artifacts/2719/GeneralUser_GS_1.471.sf2",
    "Violin": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
    "Clarinet": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
    "Flute": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
}

# Audio sample rate in Hz.
SAMPLE_RATE = 44100
# MIDI resolution: ticks per quarter note.
TICKS_PER_BEAT = 480
# One duration unit in this program is an 8th note, i.e. half a beat.
TICKS_PER_8TH = TICKS_PER_BEAT // 2
|
|
|
# Best-effort installation of the FluidSynth binary when it is neither on
# PATH nor at its usual Debian location.
if shutil.which('fluidsynth') is None and not os.path.exists('/usr/bin/fluidsynth'):
    # os.system signals failure through its return value, not by raising,
    # so the exit status has to be checked explicitly.
    if os.system('apt-get update && apt-get install -y fluidsynth') != 0:
        print("Could not install FluidSynth automatically. Please install it manually.")

# Output directories used by the rendering helpers.
os.makedirs("static", exist_ok=True)
os.makedirs("temp_audio", exist_ok=True)
|
|
|
|
|
|
|
|
|
# Semitone offset above C for every accepted pitch spelling. Keys are
# upper-cased, so flats appear as e.g. "BB" for B-flat.
NOTE_MAP: Dict[str, int] = {
    "C": 0,
    "C#": 1, "DB": 1,
    "D": 2,
    "D#": 3, "EB": 3,
    "E": 4,
    "F": 5,
    "F#": 6, "GB": 6,
    "G": 7,
    "G#": 8, "AB": 8,
    "A": 9,
    "A#": 10, "BB": 10,
    "B": 11,
}

# Note-name strings that denote silence instead of a pitch.
REST_INDICATORS = [
    "rest",
    "r",
    "Rest",
    "R",
    "P",
    "p",
    "pause",
]

# MIDI program number sent in the program_change message per instrument.
INSTRUMENT_PROGRAMS: Dict[str, int] = {
    "Piano": 0,
    "Trumpet": 56,
    "Violin": 40,
    "Clarinet": 71,
    "Flute": 73,
}
|
|
|
def is_rest(note: str) -> bool:
    """Tell whether *note* is a rest marker.

    The comparison ignores case and surrounding whitespace and is made
    against every entry of ``REST_INDICATORS``.
    """
    token = note.strip().lower()
    return any(token == marker.lower() for marker in REST_INDICATORS)
|
|
|
def note_name_to_midi(note: str) -> int:
    """Convert a note name such as ``"Bb4"`` or ``"F#5"`` to a MIDI number.

    The name is read as a pitch letter, an optional ``#``/``b`` accidental,
    and then either a single octave digit or a run of apostrophes. Without
    either, octave 4 is assumed; apostrophes select octave ``5 + count``.
    Spellings absent from ``NOTE_MAP`` (Cb, Fb, E#, B#) are resolved by
    shifting the natural letter one semitone, so ``Cb4`` equals ``B3``.

    Args:
        note: Note name, or one of the rest markers.

    Returns:
        The MIDI note number, or ``-1`` for a rest.

    Raises:
        ValueError: If *note* does not start with a pitch letter A-G.
    """
    if is_rest(note):
        return -1

    match = re.match(r"([A-Ga-g])([#b]?)(\'*)(\d?)", note)
    if not match:
        raise ValueError(f"Invalid note: {note}")

    letter, accidental, apostrophes, octave = match.groups()
    letter = letter.upper()
    pitch = letter + accidental.upper()

    octave_num = 4
    if octave:
        octave_num = int(octave)
    elif apostrophes:
        octave_num = 5 + len(apostrophes)

    if pitch in NOTE_MAP:
        semitone = NOTE_MAP[pitch]
    else:
        # Enharmonic spelling without a table entry: derive it from the
        # natural note. No modulo is applied, so Cb lands in the octave
        # below and B# in the octave above, as the spelling implies.
        semitone = NOTE_MAP[letter] + (1 if accidental == "#" else -1)

    return semitone + (octave_num + 1) * 12
|
|
|
def midi_to_note_name(midi_num: int) -> str:
    """Render a MIDI note number as a sharp-spelled name with octave.

    ``-1`` is the rest sentinel and yields ``"Rest"``. Any other number is
    split into an octave and a pitch class, where MIDI 60 is ``"C4"``.
    """
    if midi_num == -1:
        return "Rest"

    pitch_classes = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")
    octave_index, pitch_index = divmod(midi_num, 12)
    return f"{pitch_classes[pitch_index]}{octave_index - 1}"
|
|
|
|
|
|
|
|
|
def scale_json_durations(json_data, target_units: int) -> list:
    """Rescale note durations so they sum to exactly ``target_units``.

    Durations are scaled proportionally and every note keeps at least one
    unit. Rounding is settled with a largest-remainder pass, so the total
    is exact whenever ``target_units`` is at least the number of notes.
    With fewer units than notes an exact fit is impossible while keeping
    every note audible; each note then gets one unit.

    Args:
        json_data: Sequence of ``[note, duration]`` pairs.
        target_units: Required total length in 8th-note units.

    Returns:
        A new list of ``[note, duration]`` pairs. The input is returned
        unchanged when it carries no positive duration to scale from.
    """
    weights = [max(0, int(d)) for _, d in json_data]
    total = sum(weights)
    if total == 0:
        return json_data

    count = len(json_data)
    if target_units <= count:
        return [[note, 1] for note, _ in json_data]

    ideal = [w * target_units / total for w in weights]
    portions = [max(1, int(share)) for share in ideal]
    diff = target_units - sum(portions)

    if diff > 0:
        # Hand the missing units to the notes that lost the most to
        # truncation; the stable sort keeps ties in score order.
        by_remainder = sorted(
            range(count),
            key=lambda i: ideal[i] - int(ideal[i]),
            reverse=True,
        )
        for k in range(diff):
            portions[by_remainder[k % count]] += 1

    # The one-unit floor can overshoot; take the excess from the longest
    # notes. target_units > count guarantees some note is longer than one.
    while diff < 0:
        longest = max(range(count), key=portions.__getitem__)
        portions[longest] -= 1
        diff += 1

    return [[note, portion] for (note, _), portion in zip(json_data, portions)]
|
|
|
|
|
|
|
|
|
def json_to_midi(json_data: list, instrument: str, tempo: int, time_signature: str, measures: int, key: str = "C Major") -> MidiFile:
    """Build a single-track MIDI file from a list of notes.

    Each item is a ``[note, duration]`` pair or a dict with ``"note"`` and
    ``"duration"`` keys; durations are in 8th-note units. Rests become
    delta time on the next sounding note. An item whose note name cannot be
    converted is reported and played as silence of the same length, so the
    total duration of the piece is preserved.

    Args:
        json_data: Notes to render, in playing order.
        instrument: Key into ``INSTRUMENT_PROGRAMS``; unknown names use
            program 56.
        tempo: Beats per minute.
        time_signature: ``"numerator/denominator"`` string.
        measures: Accepted for interface compatibility; not used here.
        key: Key name; unknown keys are written as C.

    Returns:
        The assembled ``MidiFile``.
    """
    mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
    track = MidiTrack()
    mid.tracks.append(track)
    program = INSTRUMENT_PROGRAMS.get(instrument, 56)
    numerator, denominator = map(int, time_signature.split('/'))

    track.append(MetaMessage('time_signature', numerator=numerator,
                             denominator=denominator, time=0))
    track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))

    key_map = {
        "C Major": "C",
        "G Major": "G",
        "D Major": "D",
        "F Major": "F",
        "Bb Major": "Bb",
        "A Minor": "Am",
        "E Minor": "Em",
    }
    midi_key = key_map.get(key, "C")
    track.append(MetaMessage('key_signature', key=midi_key, time=0))
    track.append(Message('program_change', program=program, time=0))

    # Silence accumulated since the last sounding note, in ticks.
    pending_rest = 0

    for note_item in json_data:
        if isinstance(note_item, (list, tuple)) and len(note_item) == 2:
            note_name, duration_units = note_item
        elif isinstance(note_item, dict):
            note_name = note_item.get("note")
            duration_units = note_item.get("duration")
        else:
            print(f"Unsupported note format: {note_item}")
            continue

        try:
            # float() first, so a numeric string is not repeated by "*".
            ticks = max(int(float(duration_units) * TICKS_PER_8TH), 1)
        except (TypeError, ValueError) as e:
            print(f"Error parsing note {note_item}: {e}")
            continue

        try:
            note_num = note_name_to_midi(note_name)
            if note_num != -1 and not 0 <= note_num <= 127:
                raise ValueError(f"MIDI note out of range: {note_num}")
        except Exception as e:
            print(f"Error parsing note {note_item}: {e}")
            # Keep the timeline intact: an unreadable note becomes a rest.
            note_num = -1

        if note_num == -1:
            pending_rest += ticks
            continue

        velocity = random.randint(60, 100)
        track.append(Message('note_on', note=note_num, velocity=velocity, time=pending_rest))
        track.append(Message('note_off', note=note_num, velocity=velocity, time=ticks))
        pending_rest = 0

    if pending_rest > 0:
        # A trailing rest needs an event to carry its delta time; a silent
        # velocity-0 note pair is used for that.
        track.append(Message('note_on', note=0, velocity=0, time=pending_rest))
        track.append(Message('note_off', note=0, velocity=0, time=0))

    return mid
|
|
|
|
|
|
|
|
|
def get_soundfont(instrument: str) -> str:
    """Return the local path of the SoundFont for *instrument*, downloading it once.

    Files are cached under ``soundfonts/``. Instruments without an entry in
    ``SOUNDFONT_URLS`` use the Trumpet URL. The download is written to a
    temporary name and renamed only after it succeeds, so a failed request
    never leaves a bad file in the cache.

    Raises:
        requests.RequestException: If the download fails, times out, or
            returns an HTTP error status.
    """
    os.makedirs("soundfonts", exist_ok=True)
    sf2_path = f"soundfonts/{instrument}.sf2"
    if not os.path.exists(sf2_path):
        url = SOUNDFONT_URLS.get(instrument, SOUNDFONT_URLS["Trumpet"])
        print(f"Downloading SoundFont for {instrument}…")
        response = requests.get(url, timeout=60)
        # Without this check an HTML error page would be cached as the .sf2.
        response.raise_for_status()
        partial_path = f"{sf2_path}.part"
        with open(partial_path, "wb") as f:
            f.write(response.content)
        os.replace(partial_path, sf2_path)
    return sf2_path
|
|
|
def midi_to_mp3(midi_obj: MidiFile, instrument: str = "Trumpet") -> Tuple[str, float]:
    """Render a MIDI object to an MP3 file under ``static/``.

    The MIDI is saved to a temporary file and synthesised to WAV with the
    ``fluidsynth`` command line, falling back to the ``FluidSynth`` wrapper
    if that fails. Trumpet audio is high-pass filtered at 200 Hz and Violin
    audio low-pass filtered at 5000 Hz before the MP3 export.

    Args:
        midi_obj: The MIDI file to render.
        instrument: Instrument name used to pick the SoundFont and filter.

    Returns:
        ``(mp3_path, duration_seconds)``.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as mid_file:
        mid_path = mid_file.name

    # splitext only touches the extension; replacing ".mid" anywhere in the
    # string would corrupt a path whose directory name contains it.
    base_path, _ = os.path.splitext(mid_path)
    wav_path = base_path + ".wav"
    mp3_path = base_path + ".mp3"

    try:
        midi_obj.save(mid_path)
        sf2_path = get_soundfont(instrument)
        try:
            sp.run([
                'fluidsynth', '-ni', sf2_path, mid_path,
                '-F', wav_path, '-r', '44100', '-g', '1.0'
            ], check=True, capture_output=True)
        except Exception:
            fs = FluidSynth(sf2_path, sample_rate=44100, gain=1.0)
            fs.midi_to_audio(mid_path, wav_path)

        sound = AudioSegment.from_wav(wav_path)
        if instrument == "Trumpet":
            sound = sound.high_pass_filter(200)
        elif instrument == "Violin":
            sound = sound.low_pass_filter(5000)
        sound.export(mp3_path, format="mp3")

        static_mp3_path = os.path.join('static', os.path.basename(mp3_path))
        shutil.move(mp3_path, static_mp3_path)
        return static_mp3_path, sound.duration_seconds
    finally:
        # Runs on every exit path, so a failed synthesis leaves nothing
        # behind; the mp3 entry is already gone after a successful move.
        for leftover in (mid_path, wav_path, mp3_path):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
|
|
|
|
|
|
|
|
|
def get_fallback_exercise(instrument: str, level: str, key: str,
                          time_sig: str, measures: int) -> str:
    """Generate a random in-key exercise locally, without calling the LLM.

    Notes are drawn from a fixed table for the key (C major when the key is
    unknown) and durations from a small rhythm pool. The piece ends on the
    key's fundamental note and fills exactly
    ``measures * numerator * (8 // denominator)`` 8th-note units.

    Args:
        instrument: Unused; kept for signature parity with the LLM path.
        level: Unused; kept for signature parity with the LLM path.
        key: Key name such as ``"G Major"``.
        time_sig: ``"numerator/denominator"`` string.
        measures: Number of measures to fill.

    Returns:
        A JSON array string of objects with ``note``, ``duration`` and
        ``cumulative_duration``.
    """
    key_notes = {
        "C Major": ["C4", "D4", "E4", "F4", "G4", "A4", "B4"],
        "G Major": ["G3", "A3", "B3", "C4", "D4", "E4", "F#4"],
        "D Major": ["D4", "E4", "F#4", "G4", "A4", "B4", "C#5"],
        "F Major": ["F3", "G3", "A3", "Bb3", "C4", "D4", "E4"],
        "Bb Major": ["Bb3", "C4", "D4", "Eb4", "F4", "G4", "A4"],
        "A Minor": ["A3", "B3", "C4", "D4", "E4", "F4", "G4"],
        "E Minor": ["E3", "F#3", "G3", "A3", "B3", "C4", "D4"],
    }
    # Pitch of the lowered 7th degree to avoid in each major key.
    lowered_sevenths = {
        "C Major": "Bb",
        "G Major": "F",
        "D Major": "C",
        "F Major": "Eb",
        "Bb Major": "Ab",
    }

    fundamental_note = key.split()[0]
    notes = key_notes.get(key, key_notes["C Major"])
    fundamental_with_octave = next(
        (n for n in notes if n.startswith(fundamental_note)), notes[0]
    )

    numerator, denominator = map(int, time_sig.split('/'))
    units_per_measure = numerator * (8 // denominator)
    target_units = measures * units_per_measure

    rhythm = [2, 1, 2, 1, 2] if numerator == 3 else [2, 2, 1, 1, 2, 2]

    # Compare the whole pitch name with the octave stripped. A prefix test
    # would also match the sharpened note and drop the key's real leading
    # tone (F#4 in G major, C#5 in D major).
    avoided = lowered_sevenths.get(key)
    available_notes = [n for n in notes if n.rstrip("0123456789") != avoided]

    # Leave room for the closing note on the fundamental.
    final_note_duration = min(4, max(2, rhythm[0]))
    available_units = target_units - final_note_duration

    result = []
    filled = 0
    while filled < available_units:
        note = random.choice(available_notes)
        # Clamp the last random note so it never runs into the reserved tail.
        dur = min(random.choice(rhythm), available_units - filled)
        filled += dur
        result.append({
            "note": note,
            "duration": dur,
            "cumulative_duration": filled,
        })

    final_duration = target_units - filled
    if final_duration > 0:
        result.append({
            "note": fundamental_with_octave,
            "duration": final_duration,
            "cumulative_duration": filled + final_duration,
        })

    return json.dumps(result)
|
|
|
def get_style_based_on_level(level: str) -> str:
    """Pick a random musical style descriptor suited to *level*.

    Levels other than Beginner, Intermediate and Advanced always yield
    ``"technical"``.
    """
    if level == "Beginner":
        pool = ["simple", "legato", "stepwise", "folk-like", "gentle"]
    elif level == "Intermediate":
        pool = ["jazzy", "bluesy", "march-like", "syncopated", "dance-like", "lyrical"]
    elif level == "Advanced":
        pool = ["technical", "chromatic", "fast arpeggios", "wide intervals", "virtuosic", "complex", "contemporary"]
    else:
        pool = ["technical"]
    return random.choice(pool)
|
|
|
def get_technique_based_on_level(level: str) -> str:
    """Pick a random technique phrase suited to *level*.

    Levels other than Beginner, Intermediate and Advanced always yield
    ``"with slurs"``.
    """
    beginner = [
        "with long tones", "with simple rhythms", "focusing on tone",
        "with step-wise motion", "with easy intervals", "focusing on breath control",
        "with simple articulation", "with repeated patterns",
    ]
    intermediate = [
        "with slurs", "with accents", "using triplets", "with moderate syncopation",
        "with varied articulation", "with moderate interval jumps", "with dynamic contrast",
        "with scale patterns", "with simple ornaments", "with moderate register changes",
    ]
    advanced = [
        "with double tonguing", "with extreme registers", "with complex rhythms",
        "with challenging intervals", "with rapid articulation", "with advanced ornaments",
        "with extended techniques", "with complex syncopation", "with virtuosic passages",
        "with extreme dynamic contrast", "with challenging arpeggios",
    ]
    by_level = {
        "Beginner": beginner,
        "Intermediate": intermediate,
        "Advanced": advanced,
    }

    try:
        pool = by_level[level]
    except KeyError:
        pool = ["with slurs"]
    return random.choice(pool)
|
|
|
|
|
|
|
|
|
def query_mistral(prompt: str, instrument: str, level: str, key: str,
                  time_sig: str, measures: int, difficulty_modifier: int = 0,
                  practice_focus: str = "Balanced") -> str:
    """Ask the Mistral chat API for an exercise and return its raw JSON text.

    A non-blank *prompt* is sent as given, followed by the duration and
    output-format constraints. Otherwise a prompt is assembled from the
    instrument, level, key, time signature, difficulty modifier and
    practice focus. Markdown code fences are stripped from the reply.

    Args:
        prompt: Free-form user request; blank selects the generated prompt.
        instrument: Instrument name.
        level: ``"Beginner"``, ``"Intermediate"`` or ``"Advanced"``.
        key: Key name such as ``"C Major"``.
        time_sig: ``"numerator/denominator"`` string.
        measures: Number of measures requested.
        difficulty_modifier: Shifts the level used to pick style and
            technique; the result is clamped to the three known levels.
        practice_focus: One of the focus labels, or ``"Balanced"``.

    Returns:
        The model's reply with code fences removed, or the output of
        ``get_fallback_exercise`` if the request fails for any reason.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    numerator, denominator = map(int, time_sig.split('/'))

    units_per_measure = numerator * (8 // denominator)
    required_total = measures * units_per_measure

    duration_constraint = (
        f"Sum of all durations MUST BE EXACTLY {required_total} units (8th notes). "
        f"Each integer duration represents an 8th note (1=8th, 2=quarter, 4=half, 8=whole). "
        f"If it doesn't match, the exercise is invalid."
    )
    system_prompt = (
        f"You are an expert music teacher specializing in {instrument.lower()}. "
        "Create customized exercises using INTEGER durations representing 8th notes."
    )
    # Shown to the model as the expected shape. Written with single braces
    # and double quotes so the example itself looks like valid JSON.
    schema_hint = '[{"note": string, "duration": integer, "cumulative_duration": integer}]'

    if prompt.strip():
        user_prompt = (
            f"{prompt} {duration_constraint} Output ONLY a JSON array of objects with "
            f"the following structure: {schema_hint}"
        )
    else:
        effective_level = level
        if difficulty_modifier != 0:
            level_list = ["Beginner", "Intermediate", "Advanced"]
            level_map = {"Beginner": 0, "Intermediate": 1, "Advanced": 2}
            base_level_idx = level_map.get(level, 1)
            adjusted_idx = max(0, min(2, base_level_idx + difficulty_modifier))
            effective_level = level_list[adjusted_idx]

        style = get_style_based_on_level(effective_level)
        technique = get_technique_based_on_level(effective_level)

        fundamental_note = key.split()[0]
        is_major = "Major" in key

        key_constraints = (
            f"The exercise MUST end on the fundamental note of the key ({fundamental_note}). "
            f"{'' if not is_major else 'For this major key, avoid using the minor 7th degree.'}"
        )

        focus_by_name = {
            "Rhythmic Focus": "Include varied rhythmic patterns with syncopation and different note durations. ",
            "Melodic Focus": "Create a melodically interesting line with good contour and phrasing. ",
            "Technical Focus": "Include technical challenges like arpeggios, scales, or interval jumps. ",
            "Expressive Focus": "Design a lyrical exercise with opportunities for dynamic contrast and expression. ",
        }
        focus_constraints = focus_by_name.get(practice_focus, "")

        difficulty_desc = ""
        if difficulty_modifier > 0:
            difficulty_desc = f"Make this slightly more challenging than a typical {level.lower()} exercise. "
        elif difficulty_modifier < 0:
            difficulty_desc = f"Make this slightly easier than a typical {level.lower()} exercise. "

        user_prompt = (
            f"Create a {style} {instrument.lower()} exercise in {key} with {time_sig} time signature "
            f"{technique} for a {level.lower()} player. {difficulty_desc}{focus_constraints}{duration_constraint} {key_constraints} "
            "Output ONLY a JSON array of objects with the following structure: "
            f"{schema_hint} "
            "Use standard note names (e.g., \"Bb4\", \"F#5\"). Monophonic only. "
            "Durations: 1=8th, 2=quarter, 4=half, 8=whole. "
            "Sum must be exactly as specified. ONLY output the JSON array. No prose."
        )

    payload = {
        "model": "mistral-medium",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.7 if level == "Advanced" else 0.5,
        "max_tokens": 1000,
        "top_p": 0.95,
        "frequency_penalty": 0.2,
        "presence_penalty": 0.2,
    }

    try:
        # A timeout keeps a stalled connection from hanging the UI; the
        # except branch then serves the local fallback.
        response = requests.post(MISTRAL_API_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        return content.replace("```json", "").replace("```", "").strip()
    except Exception as e:
        print(f"Error querying Mistral API: {e}")
        return get_fallback_exercise(instrument, level, key, time_sig, measures)
|
|
|
|
|
|
|
|
|
def safe_parse_json(text: str) -> Optional[list]:
    """Leniently extract a note list from LLM output.

    The text between the first ``[`` and the last ``]`` is repaired (single
    quotes, trailing commas, unquoted keys, bare-word values) and parsed as
    JSON. Each item may be an object or a two-element ``[note, duration]``
    array. For objects, the note is read from the first present key among
    ``note``/``pitch``/``nota``/``ton`` and the duration from the first key
    among ``duration``/``dur``/``length``/``value`` whose value converts to
    an int. Items missing either part are dropped.

    Args:
        text: Raw model output.

    Returns:
        A list of ``{"note": str, "duration": int}`` dicts, or ``None`` when
        nothing usable could be extracted.
    """
    note_keys = ('note', 'pitch', 'nota', 'ton')
    duration_keys = ('duration', 'dur', 'length', 'value')

    try:
        text = text.strip().replace("'", '"')

        start_idx = text.find('[')
        end_idx = text.rfind(']')
        if start_idx == -1 or end_idx == -1:
            return None

        json_str = text[start_idx:end_idx + 1]

        # Repairs for common LLM deviations from strict JSON.
        json_str = re.sub(r',\s*([}\]])', r'\1', json_str)
        json_str = re.sub(r'{\s*(\w+)\s*:', r'{"\1":', json_str)
        json_str = re.sub(r':\s*([a-zA-Z_][a-zA-Z0-9_]*)(\s*[,}])', r':"\1"\2', json_str)

        parsed = json.loads(json_str)

        normalized = []
        for item in parsed:
            note_val = None
            dur_val = None

            if isinstance(item, dict):
                for name in note_keys:
                    if name in item:
                        note_val = str(item[name])
                        break

                for name in duration_keys:
                    if name not in item:
                        continue
                    try:
                        dur_val = int(item[name])
                    except (TypeError, ValueError):
                        continue
                    # First usable key wins, so a later alias cannot
                    # overwrite a valid "duration".
                    break
            elif isinstance(item, (list, tuple)) and len(item) == 2:
                note_val = str(item[0])
                try:
                    dur_val = int(item[1])
                except (TypeError, ValueError):
                    dur_val = None

            if note_val is not None and dur_val is not None:
                normalized.append({"note": note_val, "duration": dur_val})

        return normalized if normalized else None

    except Exception as e:
        print(f"JSON parsing error: {e}\nRaw text: {text}")
        return None
|
|
|
|
|
|
|
|
|
def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_signature: str,
                      measures: int, custom_prompt: str, mode: str, difficulty_modifier: int = 0,
                      practice_focus: str = "Balanced") -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
    """Produce an exercise end to end: notes, MIDI object and MP3 file.

    The LLM reply is parsed first. If that yields nothing, the locally
    generated fallback is parsed. If that fails too, a plain scale in the
    requested key is used. Durations are then rescaled to the required
    total and given running cumulative values before rendering.

    Returns:
        ``(json_text, mp3_path, tempo_text, midi, duration_text,
        time_signature, total_units)``. When any step raises, the tuple is
        ``("Error: ...", None, tempo_text, None, "0", time_signature, 0)``.
    """
    try:
        prompt_to_use = custom_prompt if mode == "Exercise Prompt" else ""
        output = query_mistral(prompt_to_use, instrument, level, key, time_signature,
                               measures, difficulty_modifier, practice_focus)
        parsed = safe_parse_json(output)

        if not parsed:
            print("Primary parsing failed, using fallback")
            parsed = safe_parse_json(
                get_fallback_exercise(instrument, level, key, time_signature, measures)
            )

        # Required length of the whole exercise in 8th-note units.
        numerator, denominator = map(int, time_signature.split('/'))
        units_per_measure = numerator * (8 // denominator)
        total_units = measures * units_per_measure

        if not parsed:
            print("Fallback parsing failed, using ultimate fallback")
            key_notes = {
                "C Major": ["C4", "D4", "E4", "F4", "G4", "A4", "B4", "C5"],
                "G Major": ["G3", "A3", "B3", "C4", "D4", "E4", "F#4", "G4"],
                "D Major": ["D4", "E4", "F#4", "G4", "A4", "B4", "C#5", "D5"],
                "F Major": ["F3", "G3", "A3", "Bb3", "C4", "D4", "E4", "F4"],
                "Bb Major": ["Bb3", "C4", "D4", "Eb4", "F4", "G4", "A4", "Bb4"],
                "A Minor": ["A3", "B3", "C4", "D4", "E4", "F4", "G4", "A4"],
                "E Minor": ["E3", "F#3", "G3", "A3", "B3", "C4", "D4", "E4"],
            }
            scale = key_notes.get(key, key_notes["C Major"])
            per_note = max(1, total_units // len(scale))
            parsed = [{"note": pitch, "duration": per_note} for pitch in scale]

            # Put whatever is missing or surplus on the last note.
            drift = total_units - sum(entry["duration"] for entry in parsed)
            if drift != 0:
                parsed[-1]["duration"] += drift

        # scale_json_durations works on [note, duration] pairs.
        pairs = [
            [entry["note"], entry["duration"]] if isinstance(entry, dict) else entry
            for entry in parsed
        ]
        scaled_pairs = scale_json_durations(pairs, total_units)

        running = 0
        parsed_scaled = []
        for pitch, length in scaled_pairs:
            running += length
            parsed_scaled.append({
                "note": pitch,
                "duration": length,
                "cumulative_duration": running,
            })
        total_duration = running

        midi = json_to_midi(parsed_scaled, instrument, tempo, time_signature, measures, key)
        mp3_path, real_duration = midi_to_mp3(midi, instrument)
        output_json_str = json.dumps(parsed_scaled, indent=2)
        return (output_json_str, mp3_path, str(tempo), midi,
                f"{real_duration:.2f} seconds", time_signature, total_duration)
    except Exception as e:
        return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
|
|
|
|
|
|
|
|
|
def handle_chat(message: str, history: List, instrument: str, level: str):
    """Send one chat turn to the Mistral API and record the reply.

    Args:
        message: The user's new message; blank messages are ignored.
        history: Earlier ``(user, assistant)`` pairs. The new pair is
            appended to this list in place.
        instrument: Instrument named in the system prompt.
        level: Student level named in the system prompt.

    Returns:
        ``("", history)``: an empty string followed by the updated history.
        When the request fails, the assistant half of the new pair holds
        the error text.
    """
    if not message.strip():
        return "", history

    messages = [{"role": "system", "content": f"You are a {instrument} teacher for {level} students."}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
    payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}

    try:
        # Without a timeout a stalled connection would block the handler forever.
        response = requests.post(MISTRAL_API_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        reply = response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        reply = f"Error: {str(e)}"

    history.append((message, reply))
    return "", history
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_visualization(json_data, time_sig):
    """Draw a piano-roll style PNG of the exercise and return its path.

    Args:
        json_data: JSON array string of objects with ``note`` and
            ``duration`` (8th-note units).
        time_sig: ``"numerator/denominator"`` string, used to place the
            dashed measure lines.

    Returns:
        Path of the PNG written under ``static/``, or ``None`` when the
        input is empty, is an error message, holds no usable notes, or
        drawing fails.
    """
    fig = None
    try:
        if not json_data or "Error" in json_data:
            return None

        parsed = json.loads(json_data)
        if not isinstance(parsed, list) or len(parsed) == 0:
            return None

        # Parallel lists; a rest is stored as None so it still takes up time.
        notes = []
        durations = []
        for item in parsed:
            if not (isinstance(item, dict) and "note" in item and "duration" in item):
                continue
            note_name = item["note"]
            if is_rest(note_name):
                pitch = None
            else:
                try:
                    pitch = note_name_to_midi(note_name)
                except ValueError:
                    # Unreadable names are drawn at middle C.
                    pitch = 60
            notes.append(pitch)
            durations.append(item["duration"])

        if not durations:
            return None

        fig, ax = plt.subplots(figsize=(12, 6))

        # Start time of every event in 8th-note units.
        time_positions = [0]
        for dur in durations[:-1]:
            time_positions.append(time_positions[-1] + dur)

        for note, dur, pos in zip(notes, durations, time_positions):
            if note is None:
                continue
            ax.add_patch(plt.Rectangle((pos, note - 0.4), dur, 0.8, color='blue', alpha=0.7))
            ax.text(pos + dur / 2, note + 0.5, midi_to_note_name(note),
                    ha='center', va='bottom', fontsize=8)

        numerator, denominator = map(int, time_sig.split('/'))
        units_per_measure = numerator * (8 // denominator)
        max_time = time_positions[-1] + durations[-1]
        for measure in range(1, int(max_time / units_per_measure) + 1):
            measure_pos = measure * units_per_measure
            if measure_pos <= max_time:
                ax.axvline(x=measure_pos, color='gray', linestyle='--', alpha=0.5)

        # Fit the y-range to the sounding notes only, so a rest no longer
        # forces a fixed window that may hide them.
        pitched = [n for n in notes if n is not None]
        if pitched:
            ax.set_ylim(min(pitched) - 5, max(pitched) + 5)
        else:
            ax.set_ylim(55, 75)
        ax.set_xlim(0, max_time)
        ax.set_ylabel('MIDI Note')
        ax.set_xlabel('Time (8th note units)')
        ax.set_title('Exercise Visualization')

        ax.set_yticks([60, 62, 64, 65, 67, 69, 71, 72])
        ax.set_yticklabels(['C4', 'D4', 'E4', 'F4', 'G4', 'A4', 'B4', 'C5'])
        ax.grid(True, axis='y', alpha=0.3)

        temp_img_path = os.path.join('static', f'visualization_{uuid.uuid4().hex}.png')
        fig.tight_layout()
        fig.savefig(temp_img_path)

        return temp_img_path
    except Exception as e:
        print(f"Error creating visualization: {e}")
        return None
    finally:
        # Close the figure on every path so failed renders do not pile up.
        if fig is not None:
            plt.close(fig)
|
|
|
|
|
def create_vexflow_notation(json_data, time_sig, key_sig):
    """Build an HTML page that renders the exercise as notation with VexFlow.

    The notes are embedded as a JavaScript literal, and a script in the page
    groups them into measures and draws them. A copy of the page is also
    written under ``static/``; failure to write it is only reported.

    Args:
        json_data: JSON array string of objects with ``note`` and
            ``duration`` (8th-note units).
        time_sig: ``"numerator/denominator"`` string.
        key_sig: Key name such as ``"G Major"``; its first word is used as
            the key signature.

    Returns:
        The HTML document as a string; ``None`` when the input is empty, is
        an error message, or is not a non-empty JSON list; or a short HTML
        error paragraph if generation raises.
    """
    if not json_data or "Error" in json_data:
        return None

    try:
        parsed = json.loads(json_data)
        if not isinstance(parsed, list) or len(parsed) == 0:
            return None

        # Escape "</" so a note string cannot close the <script> element;
        # "<\/" is the same text inside a JavaScript string literal.
        notes_json = json.dumps(parsed).replace("</", "<\\/")

        html_content = f'''
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Music Notation</title>
<script src="https://cdn.jsdelivr.net/npm/vexflow@4.2.2/build/cjs/vexflow.js"></script>
<style>
#output {{width: 100%; overflow: auto;}}
body {{font-family: Arial, sans-serif;}}
h2 {{color: #333;}}
</style>
</head>
<body>
<h2>Exercise in {key_sig}, {time_sig}</h2>
<div id="output"></div>
<script>
const {{Factory, EasyScore, System}} = Vex.Flow;

// Create VexFlow factory and context
const vf = new Factory({{renderer: {{elementId: 'output', width: 1200, height: 200}}}});
const score = vf.EasyScore();
const system = vf.System();

// Parse notes from JSON
const jsonData = {notes_json};

// Convert to VexFlow notation
let vexNotes = [];
let currentMeasure = [];
let currentDuration = 0;
const timeSignature = "{time_sig}";
const [numerator, denominator] = timeSignature.split('/').map(Number);
const unitsPerMeasure = numerator * (8 / denominator);

// Helper function to convert duration units to VexFlow duration
function durationToVex(units) {{
    if (units === 1) return "8";
    if (units === 2) return "4";
    if (units === 3) return "4d";
    if (units === 4) return "2";
    if (units === 6) return "2d";
    if (units === 8) return "1";
    return "8";
}}

// Process notes
jsonData.forEach(item => {{
    const noteName = item.note;
    const duration = item.duration;

    // Skip invalid notes
    if (!noteName || duration <= 0) return;

    // Handle rests: whole-string match against the rest markers
    const isRest = /^(rest|r|p|pause)$/i.test(String(noteName).trim());
    let vexNote;

    if (isRest) {{
        vexNote = `B4/${{durationToVex(duration)}}/r`;
    }} else {{
        // Convert scientific notation to VexFlow format
        // VexFlow uses lowercase for note names
        const noteRegex = /([A-Ga-g][#b]?)(\\d)/;
        const match = String(noteName).match(noteRegex);
        if (match) {{
            const [_, pitch, octave] = match;
            vexNote = `${{pitch.toLowerCase()}}${{octave}}/${{durationToVex(duration)}}`;
        }} else {{
            // Default if parsing fails
            vexNote = `c4/${{durationToVex(duration)}}`;
        }}
    }}

    currentMeasure.push(vexNote);
    currentDuration += duration;

    // Check if measure is complete
    if (currentDuration >= unitsPerMeasure) {{
        vexNotes.push(currentMeasure);
        currentMeasure = [];
        currentDuration = 0;
    }}
}});

// Add any remaining notes
if (currentMeasure.length > 0) {{
    vexNotes.push(currentMeasure);
}}

// Create staves and add notes
const staves = [];
const measuresPerLine = 4;

for (let i = 0; i < vexNotes.length; i += measuresPerLine) {{
    const lineStaves = [];
    const lineNotes = vexNotes.slice(i, i + measuresPerLine);

    // Create a new system for each line
    const lineSystem = vf.System({{width: 1100}});

    // Add staves for each measure in the line
    lineNotes.forEach((measure, index) => {{
        const stave = lineSystem.addStave({{
            voices: [
                score.voice(score.notes(measure.join(', ')))
            ]
        }});

        // Add time signature and key to first measure of first line
        if (i === 0 && index === 0) {{
            stave.addTimeSignature(timeSignature);
            stave.addKeySignature("{key_sig.split()[0]}");
        }}
    }});

    lineSystem.addConnector("singleRight");
    staves.push(lineSystem);
}}

// Format and draw
vf.draw();
</script>
</body>
</html>
'''

        # Saving a copy is best-effort; the returned HTML does not depend on it.
        try:
            html_path = os.path.join('static', f'notation_{uuid.uuid4().hex}.html')
            with open(html_path, 'w', encoding='utf-8') as f:
                f.write(html_content)
        except Exception as file_error:
            print(f"Warning: Could not save notation file: {file_error}")

        return html_content
    except Exception as e:
        print(f"Error creating VexFlow notation: {e}")
        return "<p>Failed to generate music notation. Error: " + str(e) + "</p>"
|
|
|
|
|
def create_metronome_audio(tempo, time_sig, measures):
    """Render a metronome click track as an MP3 under ``static/``.

    One click is written per beat of the time signature, where a beat is
    one note of the denominator's value. The first beat of each measure
    uses note 77 at velocity 100; the others use note 76 at velocity 80.
    Audio is synthesised with the Piano SoundFont.

    Args:
        tempo: Beats per minute (converted with ``int``).
        time_sig: ``"numerator/denominator"`` string.
        measures: Number of measures to click through.

    Returns:
        Path of the generated MP3, or ``None`` if anything fails.
    """
    mid_path = wav_path = mp3_path = None
    try:
        numerator, denominator = map(int, time_sig.split('/'))

        mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
        track = MidiTrack()
        mid.tracks.append(track)

        track.append(MetaMessage('time_signature', numerator=numerator,
                                 denominator=denominator, time=0))
        track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(int(tempo)), time=0))

        beats_per_measure = numerator
        total_beats = beats_per_measure * measures

        # TICKS_PER_BEAT is a quarter note; scale it to the denominator's
        # note value so e.g. 6/8 clicks on 8th notes.
        beat_ticks = max(TICKS_PER_BEAT * 4 // denominator, 1)
        click_ticks = min(10, beat_ticks)
        # The click length is part of the beat. Adding it on top of a full
        # beat gap would make every beat late by click_ticks.
        gap_ticks = beat_ticks - click_ticks

        for beat in range(total_beats):
            downbeat = beat % beats_per_measure == 0
            note_num = 77 if downbeat else 76
            velocity = 100 if downbeat else 80

            track.append(Message('note_on', note=note_num, velocity=velocity,
                                 time=0 if beat == 0 else gap_ticks))
            track.append(Message('note_off', note=note_num, velocity=0, time=click_ticks))

        with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as mid_file:
            mid_path = mid_file.name
        # splitext only touches the extension, unlike str.replace.
        base_path, _ = os.path.splitext(mid_path)
        wav_path = base_path + ".wav"
        mp3_path = base_path + ".mp3"
        mid.save(mid_path)

        sf2_path = get_soundfont("Piano")
        try:
            sp.run([
                'fluidsynth', '-ni', sf2_path, mid_path,
                '-F', wav_path, '-r', '44100', '-g', '1.0'
            ], check=True, capture_output=True)
        except Exception:
            fs = FluidSynth(sf2_path, sample_rate=44100, gain=1.0)
            fs.midi_to_audio(mid_path, wav_path)

        sound = AudioSegment.from_wav(wav_path)
        sound.export(mp3_path, format="mp3")

        static_mp3_path = os.path.join('static', f'metronome_{uuid.uuid4().hex}.mp3')
        shutil.move(mp3_path, static_mp3_path)

        return static_mp3_path
    except Exception as e:
        print(f"Error creating metronome: {e}")
        return None
    finally:
        # Remove intermediates on success and on failure alike.
        for leftover in (mid_path, wav_path, mp3_path):
            if leftover is None:
                continue
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
|
|
|
|
|
|
|
|
|
def calculate_difficulty_rating(json_data, level, difficulty_modifier=0, practice_focus="Balanced"):
    """Estimate a 1-10 difficulty rating for an exercise given as JSON text.

    The JSON must decode to a list of objects carrying ``"note"`` and
    ``"duration"`` keys.  Rests and notes whose name cannot be converted to a
    MIDI number are ignored.  Four factors, each capped at 1.0, are combined
    with weights chosen by ``practice_focus``:

    * range  - pitch span divided by 12
    * rhythm - number of distinct durations divided by 4
    * jump   - mean absolute interval between consecutive notes divided by 7
    * speed  - 2.0 divided by the mean duration

    The weighted sum is scaled by a per-level multiplier and by
    ``1.0 + difficulty_modifier * 0.15``, multiplied by 10, rounded and
    clamped to 1..10.

    Returns:
        The integer rating, or 0 when the input is empty, contains the text
        "Error", has no usable notes, or any exception occurs (the error is
        printed).
    """
    try:
        if not json_data or "Error" in json_data:
            return 0

        entries = json.loads(json_data)
        if not isinstance(entries, list) or len(entries) == 0:
            return 0

        # Collect MIDI pitches and their durations for sounding notes only.
        pitches = []
        lengths = []
        for entry in entries:
            if not (isinstance(entry, dict) and "note" in entry and "duration" in entry):
                continue
            name = entry["note"]
            if is_rest(name):
                continue
            try:
                pitch = note_name_to_midi(name)
            except ValueError:
                # Unparseable note names are skipped.
                continue
            pitches.append(pitch)
            lengths.append(entry["duration"])

        if not pitches:
            return 0

        # Pitch span of the exercise.
        span = max(pitches) - min(pitches)
        range_factor = min(span / 12, 1.0)

        # Variety of note lengths.
        distinct_lengths = len(set(lengths))
        rhythm_factor = min(distinct_lengths / 4, 1.0)

        # Mean size of the interval between neighbouring notes.
        leaps = [abs(current - previous) for previous, current in zip(pitches, pitches[1:])]
        mean_leap = sum(leaps) / len(leaps) if leaps else 0
        jump_factor = min(mean_leap / 7, 1.0)

        # Shorter average durations give a higher speed factor.
        mean_length = sum(lengths) / len(lengths) if lengths else 0
        if mean_length > 0:
            speed_factor = min(2.0 / mean_length, 1.0)
        else:
            speed_factor = 1.0

        # Weight presets per practice focus; anything else is balanced.
        focus_presets = (
            ("Rhythmic Focus", {"range": 0.15, "rhythm": 0.55, "jump": 0.15, "speed": 0.15}),
            ("Melodic Focus", {"range": 0.40, "rhythm": 0.15, "jump": 0.30, "speed": 0.15}),
            ("Technical Focus", {"range": 0.25, "rhythm": 0.15, "jump": 0.40, "speed": 0.20}),
            ("Expressive Focus", {"range": 0.35, "rhythm": 0.25, "jump": 0.25, "speed": 0.15}),
        )
        weights = next(
            (preset for focus, preset in focus_presets if practice_focus == focus),
            {"range": 0.25, "rhythm": 0.25, "jump": 0.25, "speed": 0.25},
        )

        base_difficulty = (
            range_factor * weights["range"]
            + rhythm_factor * weights["rhythm"]
            + jump_factor * weights["jump"]
            + speed_factor * weights["speed"]
        )

        level_table = {"Beginner": 0.7, "Intermediate": 1.0, "Advanced": 1.3}
        level_multiplier = level_table.get(level, 1.0)

        modifier_multiplier = 1.0 + (difficulty_modifier * 0.15)

        scaled = round(base_difficulty * level_multiplier * modifier_multiplier * 10)
        return max(1, min(scaled, 10))
    except Exception as e:
        print(f"Error calculating difficulty: {e}")
        return 0
|
|
|
|
|
|
|
|
|
def create_ui() -> gr.Blocks:
    """Build the Gradio interface and wire up its event handlers.

    The layout has a left column with two alternative input groups
    ("Exercise Parameters" and "Exercise Prompt", toggled by the generation
    mode radio) plus the generate button, and a right column of tabs:
    exercise player with metronome, exercise data, visualization, music
    notation, MIDI export and AI chat.

    Returns:
        The assembled ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
        gr.Markdown("# 🎼 Adaptive Music Exercise Generator")
        current_midi = gr.State(None)
        current_exercise = gr.State("")
        current_audio_path = gr.State(None)

        mode = gr.Radio(["Exercise Parameters","Exercise Prompt"], value="Exercise Parameters", label="Generation Mode")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(visible=True) as params_group:
                    gr.Markdown("### Exercise Parameters")
                    instrument = gr.Dropdown([
                        "Trumpet", "Piano", "Violin", "Clarinet", "Flute",
                    ], value="Trumpet", label="Instrument")
                    level = gr.Radio([
                        "Beginner", "Intermediate", "Advanced",
                    ], value="Intermediate", label="Difficulty Level")
                    difficulty_modifier = gr.Slider(minimum=-2, maximum=2, value=0, step=1,
                                                    label="Difficulty Modifier",
                                                    info="Fine-tune the difficulty: -2 (easier) to +2 (harder)")
                    practice_focus = gr.Dropdown([
                        "Balanced", "Rhythmic Focus", "Melodic Focus", "Technical Focus", "Expressive Focus"
                    ], value="Balanced", label="Practice Focus")
                    key = gr.Dropdown([
                        "C Major", "G Major", "D Major", "F Major", "Bb Major", "A Minor", "E Minor",
                    ], value="C Major", label="Key Signature")
                    time_signature = gr.Dropdown(["3/4", "4/4"], value="4/4", label="Time Signature")
                    measures = gr.Radio([4, 8, 12, 16], value=4, label="Length (measures)")
                with gr.Group(visible=False) as prompt_group:
                    gr.Markdown("### Exercise Prompt")
                    custom_prompt = gr.Textbox("", label="Enter your custom prompt", lines=3)
                    measures_prompt = gr.Radio([4, 8, 12, 16], value=4, label="Length (measures)")
                generate_btn = gr.Button("Generate Exercise", variant="primary")
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Exercise Player"):
                        audio_output = gr.Audio(label="Generated Exercise", autoplay=True, type="filepath")
                        with gr.Row():
                            bpm_display = gr.Textbox(label="Tempo (BPM)")
                            time_sig_display = gr.Textbox(label="Time Signature")
                            duration_display = gr.Textbox(label="Audio Duration", interactive=False)
                        with gr.Row():
                            difficulty_rating = gr.Number(label="Difficulty Rating (1-10)", interactive=False, precision=1)

                        gr.Markdown("### Metronome")
                        with gr.Row():
                            metronome_tempo = gr.Slider(minimum=40, maximum=200, value=60, step=1, label="Metronome Tempo")
                            metronome_btn = gr.Button("Generate Metronome", variant="secondary")
                        metronome_audio = gr.Audio(label="Metronome", type="filepath")

                    with gr.TabItem("Exercise Data"):
                        json_output = gr.Code(label="JSON Representation", language="json")
                        duration_sum = gr.Number(
                            label="Total Duration Units (8th notes)",
                            interactive=False,
                            precision=0
                        )

                    with gr.TabItem("Visualization"):
                        visualization_output = gr.Image(label="Exercise Visualization", type="filepath")
                        visualize_btn = gr.Button("Generate Visualization", variant="secondary")

                    with gr.TabItem("Music Notation"):
                        notation_html = gr.HTML(label="Music Notation")
                        notation_btn = gr.Button("Generate Music Notation", variant="secondary")

                    with gr.TabItem("MIDI Export"):
                        midi_output = gr.File(label="MIDI File")
                        download_midi = gr.Button("Generate MIDI File")

                    with gr.TabItem("AI Chat"):
                        chat_history = gr.Chatbot(label="Practice Assistant", height=400)
                        chat_message = gr.Textbox(label="Ask the AI anything about your practice")
                        send_chat_btn = gr.Button("Send")

        # Show exactly one of the two input groups, depending on the mode.
        mode.change(
            fn=lambda m: {
                params_group: gr.update(visible=(m == "Exercise Parameters")),
                prompt_group: gr.update(visible=(m == "Exercise Prompt")),
            },
            inputs=[mode], outputs=[params_group, prompt_group]
        )

        def generate_caller(mode_val, instrument_val, level_val, key_val,
                            time_sig_val, measures_val, prompt_val, measures_prompt_val,
                            difficulty_modifier_val, practice_focus_val):
            """Generate an exercise and every derived output shown in the UI."""
            # Each mode has its own measure selector; use the visible one.
            real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
            fixed_tempo = 60
            json_data, mp3_path, tempo, midi, duration, time_sig, total_duration = generate_exercise(
                instrument_val, level_val, key_val, fixed_tempo, time_sig_val,
                real_measures, prompt_val, mode_val, difficulty_modifier_val, practice_focus_val
            )

            rating = calculate_difficulty_rating(json_data, level_val, difficulty_modifier_val, practice_focus_val)
            viz_path = create_visualization(json_data, time_sig_val)
            html_content = create_vexflow_notation(json_data, time_sig_val, key_val) or ""

            # mp3_path appears twice: once for the player, once for the state.
            return (json_data, mp3_path, tempo, midi, duration, time_sig, total_duration,
                    rating, viz_path, mp3_path, html_content)

        generate_btn.click(
            fn=generate_caller,
            inputs=[mode, instrument, level, key, time_signature, measures, custom_prompt, measures_prompt,
                    difficulty_modifier, practice_focus],
            outputs=[json_output, audio_output, bpm_display, current_midi, duration_display,
                     time_sig_display, duration_sum, difficulty_rating, visualization_output,
                     current_audio_path, notation_html]
        )

        visualize_btn.click(
            fn=create_visualization,
            inputs=[json_output, time_signature],
            outputs=[visualization_output]
        )

        def display_notation(json_data, time_sig, key_val):
            """Return notation HTML, or a failure message when none is produced."""
            html_content = create_vexflow_notation(json_data, time_sig, key_val)
            return html_content or "<p>Failed to generate music notation.</p>"

        notation_btn.click(
            fn=display_notation,
            inputs=[json_output, time_signature, key],
            outputs=[notation_html]
        )

        def generate_metronome(tempo, time_sig, measures_val,
                               mode_val="Exercise Parameters", measures_prompt_val=None):
            """Create a metronome track matching the active mode's length."""
            # Mirror generate_caller: in prompt mode the length comes from the
            # prompt group's selector, not the (hidden) parameter group's one.
            use_prompt_length = mode_val == "Exercise Prompt" and measures_prompt_val is not None
            real_measures = measures_prompt_val if use_prompt_length else measures_val
            return create_metronome_audio(tempo, time_sig, real_measures)

        metronome_btn.click(
            fn=generate_metronome,
            inputs=[metronome_tempo, time_signature, measures, mode, measures_prompt],
            outputs=[metronome_audio]
        )

        def save_midi(json_data, instr, time_sig, key_sig="C Major"):
            """Convert the displayed JSON into a MIDI file and return its path.

            Returns ``None`` when the JSON is empty, contains "Error", holds
            no note/duration objects, or any step raises (error is printed).
            """
            try:
                if not json_data or "Error" in json_data:
                    return None

                parsed = json.loads(json_data)
                if not isinstance(parsed, list):
                    return None

                note_pairs = [
                    (item["note"], item["duration"])
                    for item in parsed
                    if isinstance(item, dict) and "note" in item and "duration" in item
                ]
                if not note_pairs:
                    return None

                # Estimate the measure count from the total length in 8th-note
                # units.  True division keeps this valid for denominators
                # above 8, where ``8 // denominator`` would be zero.
                total_units = sum(dur for _, dur in note_pairs)
                numerator, denominator = map(int, time_sig.split('/'))
                units_per_measure = numerator * 8 / denominator
                measures_est = max(1, round(total_units / units_per_measure))

                # Rebuild the running totals rather than trusting the input.
                cumulative = 0
                scaled_new = []
                for note, dur in note_pairs:
                    cumulative += dur
                    scaled_new.append({
                        "note": note,
                        "duration": dur,
                        "cumulative_duration": cumulative
                    })

                midi_obj = json_to_midi(scaled_new, instr, 60, time_sig, measures_est, key=key_sig)
                os.makedirs("static", exist_ok=True)
                # NOTE(review): fixed file name, so concurrent sessions
                # overwrite each other's export - confirm this is acceptable.
                midi_path = os.path.join("static", "exercise.mid")
                midi_obj.save(midi_path)
                return midi_path
            except Exception as e:
                print(f"Error saving MIDI: {e}")
                return None

        download_midi.click(
            fn=save_midi,
            inputs=[json_output, instrument, time_signature, key],
            outputs=[midi_output],
        )
        send_chat_btn.click(
            fn=handle_chat,
            inputs=[chat_message, chat_history, instrument, level],
            outputs=[chat_message, chat_history],
        )
    return demo
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the interface, then start the Gradio server.
    demo = create_ui()
    demo.launch()