File size: 57,173 Bytes
ee5a9d4 6b126b7 cf5f729 6b126b7 d53605c cf5f729 d53605c 6b126b7 d53605c e30d28d d53605c 6b126b7 d53605c 5c25abb 9299bf9 d53605c 6b126b7 d53605c 5c25abb 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 5c25abb 6b126b7 cf5f729 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c cf5f729 6b126b7 d53605c 6b126b7 d53605c 6b126b7 cf5f729 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c cf5f729 6b126b7 63e6edc b319fef d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 981df93 57f3c4a 981df93 57f3c4a 981df93 6b126b7 d53605c 57f3c4a d53605c 6b126b7 cf5f729 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 57f3c4a 6b126b7 57f3c4a d53605c 6b126b7 57f3c4a 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 3a0729b d53605c 6b126b7 3a0729b d53605c ee5a9d4 d53605c 6b126b7 d53605c b319fef 6b126b7 5c25abb d53605c 6b126b7 d53605c 6b126b7 57f3c4a 86e68e0 d53605c 57f3c4a 6b126b7 d53605c 6b126b7 cf5f729 981df93 6b126b7 cf5f729 6b126b7 d53605c 6b126b7 86e68e0 981df93 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 61fe200 c5eac10 5ea985a 61fe200 e8437a9 61fe200 cfdf6fe 61fe200 cfdf6fe 61fe200 cfdf6fe 61fe200 cfdf6fe 61fe200 cfdf6fe 61fe200 cfdf6fe acdbd04 61fe200 cfdf6fe 61fe200 cfdf6fe 61fe200 cfdf6fe 61fe200 cfdf6fe a3592e2 cfdf6fe 61fe200 cfdf6fe 61fe200 e8437a9 61fe200 e8437a9 61fe200 e8437a9 61fe200 6b126b7 57f3c4a 6b126b7 57f3c4a 6b126b7 57f3c4a 6b126b7 57f3c4a 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 c770e18 d53605c 6b126b7 57f3c4a 6b126b7 86e68e0 d53605c 6b126b7 d53605c c770e18 d53605c 6b126b7 d53605c 6b126b7 981df93 6b126b7 981df93 6b126b7 61fe200 d53605c 6b126b7 ae60899 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 57f3c4a d53605c 6b126b7 57f3c4a d53605c 6b126b7 57f3c4a 6b126b7 61fe200 57f3c4a e8437a9 61fe200 6b126b7 d53605c 57f3c4a 6b126b7 61fe200 d53605c 
cf5f729 6b126b7 61fe200 e8437a9 61fe200 6b126b7 ae60899 6b126b7 ae60899 cf5f729 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c 6b126b7 d53605c cfdf6fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 |
"""
Adaptive Music Exercise Generator (Strict Duration Enforcement)
==============================================================
Generates custom musical exercises with LLM, perfectly fit to user-specified number of measures
AND time signature, guaranteeing exact durations in MIDI and in the UI!
Major updates:
- Changed base duration unit from 16th notes to 8th notes (1 unit = 8th note)
- Updated all calculations and prompts to use new duration system
- Duration sum display now shows total in 8th notes
- Maintained all original functionality
- Added cumulative duration tracking
- Enforced JSON output format with note, duration, cumulative_duration
- Enhanced rest handling and JSON parsing
- Fixed JSON parsing errors for 8-measure exercises
- Added robust error handling for MIDI generation
"""
# -----------------------------------------------------------------------------
# 1. Runtime package installation (for fresh containers/Colab/etc.)
# -----------------------------------------------------------------------------
import sys
import subprocess
from typing import Dict, Optional, Tuple, List
def install(packages: List[str]):
    """Pip-install every entry of *packages* that cannot currently be imported.

    Each entry is used both as the import name and as the pip requirement, so
    this only suits packages whose module and distribution names match.
    """
    import importlib

    installed_any = False
    for package in packages:
        try:
            __import__(package)
        except ImportError:
            print(f"Installing missing package: {package}")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
            installed_any = True
    if installed_any:
        # Import finders cache directory contents; refresh them so modules that
        # were installed during this run are visible to the imports that follow.
        importlib.invalidate_caches()
# Third-party dependencies only: uuid and datetime are standard-library modules
# (the PyPI "uuid" distribution is an unrelated legacy package).
install([
    "mido", "midi2audio", "pydub", "gradio",
    "requests", "numpy", "matplotlib", "librosa", "scipy",
])
# -----------------------------------------------------------------------------
# 2. Static imports
# -----------------------------------------------------------------------------
import random
import requests
import json
import tempfile
import mido
from mido import Message, MidiFile, MidiTrack, MetaMessage
import re
from io import BytesIO
from midi2audio import FluidSynth
from pydub import AudioSegment
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import librosa
from scipy.io import wavfile
import os
import subprocess as sp
import base64
import shutil
import ast
import uuid
from datetime import datetime
import time
# -----------------------------------------------------------------------------
# 3. Configuration & constants (UPDATED TO USE 8TH NOTES)
# -----------------------------------------------------------------------------
MISTRAL_API_URL = "https://api.mistral.ai/v1/chat/completions"
# Secrets must not live in source control: the key is read from the environment.
# NOTE(review): a key was previously committed in this file; revoke/rotate it.
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "")
if not MISTRAL_API_KEY:
    print("Warning: MISTRAL_API_KEY is not set; Mistral API requests will fail.")
SOUNDFONT_URLS = {
    "Trumpet": "https://github.com/FluidSynth/fluidsynth/raw/master/sf2/Trumpet.sf2",
    "Piano": "https://musical-artifacts.com/artifacts/2719/GeneralUser_GS_1.471.sf2",
    "Violin": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
    "Clarinet": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
    "Flute": "https://musical-artifacts.com/artifacts/2744/SalC5Light.sf2",
}
SAMPLE_RATE = 44100  # Hz
TICKS_PER_BEAT = 480  # Standard MIDI resolution
TICKS_PER_8TH = TICKS_PER_BEAT // 2  # 240 ticks per 8th note
# Best-effort install of the FluidSynth CLI when it is not on PATH.
# os.system reports failure through its return status rather than raising.
if shutil.which("fluidsynth") is None:
    _fluidsynth_install_status = os.system('apt-get update && apt-get install -y fluidsynth')
    if _fluidsynth_install_status != 0:
        print("Could not install FluidSynth automatically. Please install it manually.")
os.makedirs("static", exist_ok=True)
os.makedirs("temp_audio", exist_ok=True)
# -----------------------------------------------------------------------------
# 4. Music theory helpers (note names ↔︎ MIDI numbers) - ENHANCED REST HANDLING
# -----------------------------------------------------------------------------
NOTE_MAP: Dict[str, int] = {
    "C": 0, "C#": 1, "DB": 1,
    "D": 2, "D#": 3, "EB": 3,
    "E": 4, "F": 5, "F#": 6, "GB": 6,
    "G": 7, "G#": 8, "AB": 8,
    "A": 9, "A#": 10, "BB": 10,
    "B": 11,
}
REST_INDICATORS = ["rest", "r", "Rest", "R", "P", "p", "pause"]
INSTRUMENT_PROGRAMS: Dict[str, int] = {
    "Piano": 0, "Trumpet": 56, "Violin": 40,
    "Clarinet": 71, "Flute": 73,
}
# Lower-cased rest tokens, built once for O(1) membership tests.
_REST_TOKENS = frozenset(token.lower() for token in REST_INDICATORS)
# Letter, optional accidental (# or b), optional apostrophes, optional single-digit octave.
_NOTE_PATTERN = re.compile(r"([A-Ga-g])([#b]?)('*)(\d?)")
_SHARP_NAMES = ("C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B")


def is_rest(note: str) -> bool:
    """Return True if *note* is a rest token (case-insensitive, whitespace ignored)."""
    return note.strip().lower() in _REST_TOKENS


def note_name_to_midi(note: str) -> int:
    """Convert a note name to a MIDI note number, or -1 for a rest.

    Accepted forms:
      * scientific pitch: "C4" (60), "Bb3", "F#5", including Cb/Fb/E#/B#
        (e.g. "Cb4" is B3 = 59);
      * apostrophe octaves: each apostrophe raises the octave above 4
        ("C'" = C5, "C''" = C6);
      * a bare letter/accidental, which defaults to octave 4.
    Surrounding whitespace is ignored, as is any text after the recognised prefix.

    Raises:
        ValueError: if *note* does not start with a note letter A-G.
    """
    if is_rest(note):
        return -1  # Special value for rests
    match = _NOTE_PATTERN.match(note.strip())
    if not match:
        raise ValueError(f"Invalid note: {note}")
    letter, accidental, apostrophes, octave = match.groups()
    # Work from the natural note plus an accidental offset so that enharmonic
    # spellings (Cb, Fb, E#, B#) resolve to the correct pitch and octave.
    semitone = NOTE_MAP[letter.upper()]
    if accidental == "#":
        semitone += 1
    elif accidental == "b":
        semitone -= 1
    if octave:
        octave_num = int(octave)
    elif apostrophes:
        octave_num = 4 + len(apostrophes)
    else:
        octave_num = 4
    return semitone + (octave_num + 1) * 12


def midi_to_note_name(midi_num: int) -> str:
    """Return the sharp-spelled scientific name for *midi_num* ("Rest" for -1)."""
    if midi_num == -1:
        return "Rest"
    return f"{_SHARP_NAMES[midi_num % 12]}{(midi_num // 12) - 1}"
# -----------------------------------------------------------------------------
# 5. Duration scaling: guarantee the output sums to requested total (using integers)
# -----------------------------------------------------------------------------
def scale_json_durations(json_data, target_units: int) -> list:
    """Rescale ``[note, duration]`` pairs so the durations sum to exactly *target_units*.

    Durations are integer 8th-note units. Each note's share is proportional to
    its original duration. Rounding is applied to the *cumulative* boundaries,
    so rounding errors never accumulate and the total is exact. Every kept note
    gets at least one unit:

    * if there are more notes than units, the first ``target_units - 1`` notes
      and the final note are kept, one unit each;
    * if no input duration is positive, the units are split (near-)evenly.

    Returns:
        A new list of ``[note, int_duration]`` lists; empty when *json_data* is
        empty or *target_units* <= 0.
    """
    target_units = int(target_units)
    pairs = [(note, duration) for note, duration in json_data]
    if not pairs or target_units <= 0:
        return []
    if len(pairs) > target_units:
        # Not enough room for one unit per note: keep the opening and the ending.
        kept = pairs[:target_units - 1] + pairs[-1:]
        return [[note, 1] for note, _ in kept]

    weights = [max(0, int(duration)) for _, duration in pairs]
    total = sum(weights)
    if total == 0:
        weights = [1] * len(pairs)
        total = len(pairs)

    portions = []
    running = 0
    previous_boundary = 0
    for weight in weights:
        running += weight
        # round(running * target_units / total), half-up, in exact integer math;
        # the final boundary is exactly target_units.
        boundary = (2 * running * target_units + total) // (2 * total)
        portions.append(boundary - previous_boundary)
        previous_boundary = boundary

    # Tiny notes may have rounded to zero. Because len(pairs) <= target_units,
    # some other portion is >= 2 whenever a zero exists, so borrow from the largest.
    for index, portion in enumerate(portions):
        if portion == 0:
            donor = max(range(len(portions)), key=portions.__getitem__)
            portions[donor] -= 1
            portions[index] = 1

    return [[note, portion] for (note, _), portion in zip(pairs, portions)]
# -----------------------------------------------------------------------------
# 6. MIDI from scaled JSON (using integer durations) - UPDATED REST HANDLING
# -----------------------------------------------------------------------------
def json_to_midi(json_data: list, instrument: str, tempo: int, time_signature: str, measures: int, key: str = "C Major") -> MidiFile:
    """Build a single-track MIDI file from a list of note events.

    Args:
        json_data: Events as ``[note, duration]`` / ``(note, duration)`` pairs or
            ``{"note": ..., "duration": ...}`` dicts; durations are in 8th-note
            units (TICKS_PER_8TH ticks each, minimum 1 tick).
        instrument: Looked up in INSTRUMENT_PROGRAMS (program 56 if unknown).
        tempo: Tempo in BPM.
        time_signature: A "numerator/denominator" string such as "4/4".
        measures: Not used by this function; kept for interface compatibility.
        key: Key name mapped to a MIDI key signature ("C" if unknown).

    Rests, unparseable pitches and pitches outside 0-127 all become silence of
    the same length, so the total track duration always matches the input.
    Events with an unsupported shape or a non-numeric duration are skipped.

    Returns:
        The constructed mido MidiFile.
    """
    mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
    track = MidiTrack()
    mid.tracks.append(track)
    program = INSTRUMENT_PROGRAMS.get(instrument, 56)
    numerator, denominator = map(int, time_signature.split('/'))
    track.append(MetaMessage('time_signature', numerator=numerator,
                             denominator=denominator, time=0))
    track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
    # mido expects key-signature strings such as 'C', 'Bb' or 'Am'.
    key_map = {
        "C Major": "C",
        "G Major": "G",
        "D Major": "D",
        "F Major": "F",
        "Bb Major": "Bb",
        "A Minor": "Am",
        "E Minor": "Em",
    }
    track.append(MetaMessage('key_signature', key=key_map.get(key, "C"), time=0))
    track.append(Message('program_change', program=program, time=0))

    def append_rest(rest_ticks: int) -> None:
        # A velocity-0 note_on/note_off pair on note 0 is inaudible but
        # advances the track clock by rest_ticks.
        track.append(Message('note_on', note=0, velocity=0, time=rest_ticks))
        track.append(Message('note_off', note=0, velocity=0, time=0))

    # Consecutive rests are merged and emitted just before the next sounding note.
    accumulated_rest = 0
    for note_item in json_data:
        try:
            if isinstance(note_item, (list, tuple)) and len(note_item) == 2:
                note_name, duration_units = note_item
            elif isinstance(note_item, dict):
                note_name = note_item["note"]
                duration_units = note_item["duration"]
            else:
                print(f"Unsupported note format: {note_item}")
                continue
            ticks = max(int(float(duration_units) * TICKS_PER_8TH), 1)
        except (KeyError, TypeError, ValueError, OverflowError) as e:
            print(f"Error parsing note {note_item}: {e}")
            continue

        note_name = str(note_name)
        note_num = -1
        if not is_rest(note_name):
            try:
                note_num = note_name_to_midi(note_name)
            except ValueError as e:
                print(f"Error parsing note {note_item}: {e}")
            if not 0 <= note_num <= 127:
                # Preserve the bar length: an unusable pitch becomes silence.
                print(f"Treating {note_name!r} as a rest")
                note_num = -1

        if note_num == -1:
            accumulated_rest += ticks
            continue

        if accumulated_rest > 0:
            append_rest(accumulated_rest)
            accumulated_rest = 0
        velocity = random.randint(60, 100)
        track.append(Message('note_on', note=note_num, velocity=velocity, time=0))
        track.append(Message('note_off', note=note_num, velocity=velocity, time=ticks))

    # Trailing rest so the file still spans the full requested length.
    if accumulated_rest > 0:
        append_rest(accumulated_rest)
    return mid
# -----------------------------------------------------------------------------
# 7. MIDI → Audio (MP3) helpers
# -----------------------------------------------------------------------------
def get_soundfont(instrument: str) -> str:
    """Return the local path of the SoundFont for *instrument*, downloading it once.

    Files are cached as ``soundfonts/<instrument>.sf2``; unknown instruments
    download the Trumpet URL. The download streams into a ``.part`` file that is
    renamed only after it completes, so partial or failed downloads are never
    reused as a cached SoundFont.

    Raises:
        requests.HTTPError: if the server returns an error status.
        requests.RequestException: on network failure or timeout.
    """
    os.makedirs("soundfonts", exist_ok=True)
    sf2_path = f"soundfonts/{instrument}.sf2"
    if os.path.exists(sf2_path):
        return sf2_path
    url = SOUNDFONT_URLS.get(instrument, SOUNDFONT_URLS["Trumpet"])
    print(f"Downloading SoundFont for {instrument}…")
    partial_path = sf2_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, sf2_path)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return sf2_path
def midi_to_mp3(midi_obj: MidiFile, instrument: str = "Trumpet") -> Tuple[str, float]:
    """Render *midi_obj* to an MP3 under ``static/`` and return ``(path, seconds)``.

    Synthesis uses the ``fluidsynth`` CLI with the instrument's SoundFont, and
    falls back to midi2audio's FluidSynth wrapper if the CLI is missing or fails.
    Trumpet audio gets a 200 Hz high-pass and Violin a 5 kHz low-pass before export.
    The intermediate .mid/.wav/.mp3 temp files are removed even if rendering fails.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as mid_file:
        mid_path = mid_file.name
    stem = os.path.splitext(mid_path)[0]
    wav_path = stem + ".wav"
    mp3_path = stem + ".mp3"
    try:
        midi_obj.save(mid_path)
        sf2_path = get_soundfont(instrument)
        try:
            sp.run([
                'fluidsynth', '-ni', sf2_path, mid_path,
                '-F', wav_path, '-r', str(SAMPLE_RATE), '-g', '1.0'
            ], check=True, capture_output=True)
        except (OSError, sp.SubprocessError):
            # CLI not installed or it failed: try the Python wrapper instead.
            fs = FluidSynth(sf2_path, sample_rate=SAMPLE_RATE, gain=1.0)
            fs.midi_to_audio(mid_path, wav_path)
        sound = AudioSegment.from_wav(wav_path)
        if instrument == "Trumpet":
            sound = sound.high_pass_filter(200)
        elif instrument == "Violin":
            sound = sound.low_pass_filter(5000)
        # pydub returns the open output file; close it before moving the file.
        sound.export(mp3_path, format="mp3").close()
        static_mp3_path = os.path.join('static', os.path.basename(mp3_path))
        shutil.move(mp3_path, static_mp3_path)
        return static_mp3_path, sound.duration_seconds
    finally:
        for path in (mid_path, wav_path, mp3_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
# -----------------------------------------------------------------------------
# 8. Prompt engineering for variety (using integer durations) - UPDATED DURATION SYSTEM
# -----------------------------------------------------------------------------
def get_fallback_exercise(instrument: str, level: str, key: str,
                          time_sig: str, measures: int) -> str:
    """Generate a random offline exercise as a JSON string.

    Notes are drawn at random from a one-octave diatonic scale for *key*
    (unknown keys use C Major). Rhythms come from a fixed pattern (a 3-based
    one when the numerator is 3). The exercise always ends with a 2-unit note
    on the key's tonic. Durations are integer 8th-note units summing to
    ``measures * numerator * (8 // denominator)``.

    ``instrument`` and ``level`` are accepted for interface compatibility but
    are not used.

    Returns:
        A JSON array of ``{"note", "duration", "cumulative_duration"}`` objects.
    """
    key_notes = {
        "C Major": ["C4", "D4", "E4", "F4", "G4", "A4", "B4"],
        "G Major": ["G3", "A3", "B3", "C4", "D4", "E4", "F#4"],
        "D Major": ["D4", "E4", "F#4", "G4", "A4", "B4", "C#5"],
        "F Major": ["F3", "G3", "A3", "Bb3", "C4", "D4", "E4"],
        "Bb Major": ["Bb3", "C4", "D4", "Eb4", "F4", "G4", "A4"],
        "A Minor": ["A3", "B3", "C4", "D4", "E4", "F4", "G4"],
        "E Minor": ["E3", "F#3", "G3", "A3", "B3", "C4", "D4"],
    }
    # The scale lists are diatonic, so a major key's list never contains its
    # minor 7th. Every listed note (including the leading tone) is usable.
    notes = key_notes.get(key, key_notes["C Major"])
    # Tonic with octave for the closing note, e.g. 'G' -> 'G3' (first entry if absent).
    tonic_name = key.split()[0]
    final_note = next((n for n in notes if n.startswith(tonic_name)), notes[0])

    numerator, denominator = map(int, time_sig.split('/'))
    target_units = measures * numerator * (8 // denominator)
    rhythm = [2, 1, 2, 1, 2] if numerator == 3 else [2, 2, 1, 1, 2, 2]

    final_duration_reserved = 2  # keep room for the closing tonic
    body_units = target_units - final_duration_reserved

    result = []
    elapsed = 0
    while elapsed < body_units:
        note = random.choice(notes)
        dur = min(random.choice(rhythm), body_units - elapsed)
        elapsed += dur
        result.append({"note": note, "duration": dur, "cumulative_duration": elapsed})

    final_duration = target_units - elapsed
    if final_duration > 0:
        result.append({
            "note": final_note,
            "duration": final_duration,
            "cumulative_duration": elapsed + final_duration,
        })
    return json.dumps(result)
def get_style_based_on_level(level: str) -> str:
    """Pick a random style adjective suited to *level*.

    Levels other than Beginner, Intermediate and Advanced always yield "technical".
    """
    catalogue = dict(
        Beginner=["simple", "legato", "stepwise", "folk-like", "gentle"],
        Intermediate=["jazzy", "bluesy", "march-like", "syncopated", "dance-like", "lyrical"],
        Advanced=["technical", "chromatic", "fast arpeggios", "wide intervals",
                  "virtuosic", "complex", "contemporary"],
    )
    pool = catalogue[level] if level in catalogue else ["technical"]
    return random.choice(pool)
def get_technique_based_on_level(level: str) -> str:
    """Pick a random technique phrase (e.g. "with slurs") suited to *level*.

    Unrecognised levels always yield "with slurs".
    """
    beginner = [
        "with long tones", "with simple rhythms", "focusing on tone",
        "with step-wise motion", "with easy intervals", "focusing on breath control",
        "with simple articulation", "with repeated patterns",
    ]
    intermediate = [
        "with slurs", "with accents", "using triplets", "with moderate syncopation",
        "with varied articulation", "with moderate interval jumps", "with dynamic contrast",
        "with scale patterns", "with simple ornaments", "with moderate register changes",
    ]
    advanced = [
        "with double tonguing", "with extreme registers", "with complex rhythms",
        "with challenging intervals", "with rapid articulation", "with advanced ornaments",
        "with extended techniques", "with complex syncopation", "with virtuosic passages",
        "with extreme dynamic contrast", "with challenging arpeggios",
    ]
    by_level = {"Beginner": beginner, "Intermediate": intermediate, "Advanced": advanced}
    pool = by_level[level] if level in by_level else ["with slurs"]
    return random.choice(pool)
# -----------------------------------------------------------------------------
# 9. Mistral API: query, fallback on errors - UPDATED DURATION SYSTEM
# -----------------------------------------------------------------------------
def query_mistral(prompt: str, instrument: str, level: str, key: str,
                  time_sig: str, measures: int, difficulty_modifier: int = 0,
                  practice_focus: str = "Balanced") -> str:
    """Request an exercise from the Mistral chat API and return the model's text.

    If *prompt* is non-empty, it is sent as-is together with the duration and
    output-format constraints. Otherwise a prompt is assembled from the
    instrument, key, time signature, *practice_focus* and a style/technique
    chosen for *level* shifted by *difficulty_modifier* (clamped to
    Beginner..Advanced). In both cases the model is asked for a JSON array whose
    integer durations (8th-note units) sum to ``measures * numerator * (8 // denominator)``.

    Returns:
        The model reply with Markdown code fences removed. On any request or
        response error, the JSON string from get_fallback_exercise() instead.
    """
    headers = {
        "Authorization": f"Bearer {MISTRAL_API_KEY}",
        "Content-Type": "application/json",
    }
    numerator, denominator = map(int, time_sig.split('/'))
    units_per_measure = numerator * (8 // denominator)
    required_total = measures * units_per_measure
    duration_constraint = (
        f"Sum of all durations MUST BE EXACTLY {required_total} units (8th notes). "
        f"Each integer duration represents an 8th note (1=8th, 2=quarter, 4=half, 8=whole). "
        f"If it doesn't match, the exercise is invalid."
    )
    system_prompt = (
        f"You are an expert music teacher specializing in {instrument.lower()}. "
        "Create customized exercises using INTEGER durations representing 8th notes."
    )
    # Shown verbatim to the model. This is a plain string, so the braces are
    # single; double quotes keep the example valid JSON.
    schema_example = '[{"note": string, "duration": integer, "cumulative_duration": integer}]'

    if prompt.strip():
        user_prompt = (
            f"{prompt} {duration_constraint} Output ONLY a JSON array of objects with "
            f"the following structure: {schema_example}"
        )
    else:
        # Shift the level by the difficulty modifier, clamped to the known levels.
        effective_level = level
        if difficulty_modifier != 0:
            level_map = {"Beginner": 0, "Intermediate": 1, "Advanced": 2}
            level_list = ["Beginner", "Intermediate", "Advanced"]
            base_level_idx = level_map.get(level, 1)
            adjusted_idx = max(0, min(2, base_level_idx + difficulty_modifier))
            effective_level = level_list[adjusted_idx]
        style = get_style_based_on_level(effective_level)
        technique = get_technique_based_on_level(effective_level)

        fundamental_note = key.split()[0]  # 'C' from 'C Major', 'A' from 'A Minor'
        is_major = "Major" in key
        key_constraints = (
            f"The exercise MUST end on the fundamental note of the key ({fundamental_note}). "
            f"{'' if not is_major else 'For this major key, avoid using the minor 7th degree.'}"
        )
        focus_constraints = {
            "Rhythmic Focus": "Include varied rhythmic patterns with syncopation and different note durations. ",
            "Melodic Focus": "Create a melodically interesting line with good contour and phrasing. ",
            "Technical Focus": "Include technical challenges like arpeggios, scales, or interval jumps. ",
            "Expressive Focus": "Design a lyrical exercise with opportunities for dynamic contrast and expression. ",
        }.get(practice_focus, "")
        difficulty_desc = ""
        if difficulty_modifier > 0:
            difficulty_desc = f"Make this slightly more challenging than a typical {level.lower()} exercise. "
        elif difficulty_modifier < 0:
            difficulty_desc = f"Make this slightly easier than a typical {level.lower()} exercise. "
        user_prompt = (
            f"Create a {style} {instrument.lower()} exercise in {key} with {time_sig} time signature "
            f"{technique} for a {level.lower()} player. {difficulty_desc}{focus_constraints}{duration_constraint} {key_constraints} "
            "Output ONLY a JSON array of objects with the following structure: "
            f"{schema_example} "
            "Use standard note names (e.g., \"Bb4\", \"F#5\"). Monophonic only. "
            "Durations: 1=8th, 2=quarter, 4=half, 8=whole. "
            "Sum must be exactly as specified. ONLY output the JSON array. No prose."
        )

    payload = {
        "model": "mistral-medium",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.7 if level == "Advanced" else 0.5,
        "max_tokens": 1000,
        "top_p": 0.95,
        "frequency_penalty": 0.2,
        "presence_penalty": 0.2,
    }
    try:
        # Bounded wait so a stalled API call cannot hang the UI indefinitely.
        response = requests.post(MISTRAL_API_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        return content.replace("```json", "").replace("```", "").strip()
    except Exception as e:
        print(f"Error querying Mistral API: {e}")
        return get_fallback_exercise(instrument, level, key, time_sig, measures)
# -----------------------------------------------------------------------------
# 10. Robust JSON parsing for LLM outputs - ENHANCED PARSING
# -----------------------------------------------------------------------------
def _load_json_array(candidate: str):
    """Decode *candidate*: strict JSON first, then a Python literal, then a regex repair.

    Raises:
        json.JSONDecodeError: if even the repaired text is not valid JSON.
    """
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        pass
    try:
        # Accepts single-quoted strings and trailing commas without touching
        # apostrophes inside values; literal_eval evaluates literals only.
        return ast.literal_eval(candidate)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        pass
    repaired = candidate.replace("'", '"')
    repaired = re.sub(r',\s*([}\]])', r'\1', repaired)  # trailing commas
    repaired = re.sub(r'([{,]\s*)([A-Za-z_]\w*)\s*:', r'\1"\2":', repaired)  # unquoted keys
    repaired = re.sub(r':\s*([a-zA-Z_][a-zA-Z0-9_]*)(\s*[,}])', r':"\1"\2', repaired)  # unquoted strings
    return json.loads(repaired)


def _coerce_duration(value) -> Optional[int]:
    """Return *value* as an int (truncating "2.0"/2.5-style values), or None."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        pass
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return None


def _normalize_note_item(item) -> Optional[dict]:
    """Map one parsed element to ``{"note": str, "duration": int}``, or None if unusable."""
    if isinstance(item, dict):
        note_val = next(
            (str(item[k]) for k in ("note", "pitch", "nota", "ton") if k in item), None
        )
        dur_val = None
        for k in ("duration", "dur", "length", "value"):
            if k in item:
                dur_val = _coerce_duration(item[k])
                if dur_val is not None:
                    break  # first usable duration key wins
    elif isinstance(item, (list, tuple)) and len(item) >= 2:
        note_val = str(item[0])
        dur_val = _coerce_duration(item[1])
    else:
        return None
    if note_val is None or dur_val is None:
        return None
    return {"note": note_val, "duration": dur_val}


def safe_parse_json(text: str) -> Optional[list]:
    """Extract a list of ``{"note", "duration"}`` dicts from free-form LLM output.

    The span from the first '[' to the last ']' is decoded leniently (see
    _load_json_array). Elements may be dicts using note keys note/pitch/nota/ton
    and duration keys duration/dur/length/value, or ``[note, duration]`` pairs.
    Elements lacking a note or an integer-convertible duration are dropped.

    Returns:
        The normalized list, or None if nothing usable could be parsed.
    """
    try:
        text = text.strip()
        start_idx = text.find('[')
        end_idx = text.rfind(']')
        if start_idx == -1 or end_idx == -1:
            return None
        parsed = _load_json_array(text[start_idx:end_idx + 1])
        if not isinstance(parsed, list):
            return None
        normalized = []
        for item in parsed:
            entry = _normalize_note_item(item)
            if entry is not None:
                normalized.append(entry)
        return normalized or None
    except Exception as e:
        # Best-effort boundary: callers fall back to a generated exercise on None.
        print(f"JSON parsing error: {e}\nRaw text: {text}")
        return None
# -----------------------------------------------------------------------------
# 11. Main orchestration: talk to API, *scale durations*, build MIDI, UI values - UPDATED
# -----------------------------------------------------------------------------
def _ultimate_fallback_notes(key: str, target_units: int) -> list:
    """Return an ascending one-octave scale in *key* whose durations sum to *target_units*.

    Unknown keys use C Major. If there are fewer units than scale notes, the
    scale is truncated so every note keeps at least one unit.
    """
    key_notes = {
        "C Major": ["C4", "D4", "E4", "F4", "G4", "A4", "B4", "C5"],
        "G Major": ["G3", "A3", "B3", "C4", "D4", "E4", "F#4", "G4"],
        "D Major": ["D4", "E4", "F#4", "G4", "A4", "B4", "C#5", "D5"],
        "F Major": ["F3", "G3", "A3", "Bb3", "C4", "D4", "E4", "F4"],
        "Bb Major": ["Bb3", "C4", "D4", "Eb4", "F4", "G4", "A4", "Bb4"],
        "A Minor": ["A3", "B3", "C4", "D4", "E4", "F4", "G4", "A4"],
        "E Minor": ["E3", "F#3", "G3", "A3", "B3", "C4", "D4", "E4"],
    }
    notes = key_notes.get(key, key_notes["C Major"])[:max(1, target_units)]
    note_duration = max(1, target_units // len(notes))
    parsed = [{"note": n, "duration": note_duration} for n in notes]
    # Leftover units go to the last note so the total is exact.
    parsed[-1]["duration"] += target_units - note_duration * len(notes)
    return parsed


def generate_exercise(instrument: str, level: str, key: str, tempo: int, time_signature: str,
                      measures: int, custom_prompt: str, mode: str, difficulty_modifier: int = 0,
                      practice_focus: str = "Balanced") -> Tuple[str, Optional[str], str, MidiFile, str, str, int]:
    """Generate an exercise, force its length to the requested bars, and render it.

    Pipeline: query the LLM (the custom prompt is used only in "Exercise Prompt"
    mode), then parse the reply. If that fails, parse a random offline exercise,
    and if that also fails, use a plain scale. Durations are then rescaled to
    exactly ``measures * numerator * (8 // denominator)`` 8th-note units, and
    the result is built into MIDI and rendered to MP3.

    Returns:
        ``(json_str, mp3_path, tempo_str, midi, "<seconds> seconds",
        time_signature, total_units)``. On any error:
        ``("Error: ...", None, tempo_str, None, "0", time_signature, 0)``.
    """
    try:
        measures = int(measures)
        numerator, denominator = map(int, time_signature.split('/'))
        total_units = measures * numerator * (8 // denominator)
        if total_units <= 0:
            # Durations are whole 8th notes, so e.g. x/16 meters cannot be represented.
            raise ValueError(
                f"Unsupported time signature or measure count ({time_signature}, "
                f"{measures} measures): the exercise must span at least one 8th note."
            )

        prompt_to_use = custom_prompt if mode == "Exercise Prompt" else ""
        output = query_mistral(prompt_to_use, instrument, level, key, time_signature,
                               measures, difficulty_modifier, practice_focus)
        parsed = safe_parse_json(output)
        if not parsed:
            print("Primary parsing failed, using fallback")
            parsed = safe_parse_json(
                get_fallback_exercise(instrument, level, key, time_signature, measures)
            )
        if not parsed:
            print("Fallback parsing failed, using ultimate fallback")
            parsed = _ultimate_fallback_notes(key, total_units)

        # scale_json_durations works on [note, duration] pairs.
        pairs = [[item["note"], item["duration"]] if isinstance(item, dict) else item
                 for item in parsed]
        scaled = scale_json_durations(pairs, total_units)

        cumulative = 0
        parsed_scaled = []
        for note, dur in scaled:
            cumulative += dur
            parsed_scaled.append({
                "note": note,
                "duration": dur,
                "cumulative_duration": cumulative,
            })

        midi = json_to_midi(parsed_scaled, instrument, tempo, time_signature, measures, key)
        mp3_path, real_duration = midi_to_mp3(midi, instrument)
        output_json_str = json.dumps(parsed_scaled, indent=2)
        return (output_json_str, mp3_path, str(tempo), midi,
                f"{real_duration:.2f} seconds", time_signature, cumulative)
    except Exception as e:
        return f"Error: {str(e)}", None, str(tempo), None, "0", time_signature, 0
# -----------------------------------------------------------------------------
# 12. Simple AI chat assistant (optional, shares LLM)
# -----------------------------------------------------------------------------
def handle_chat(message: str, history: List, instrument: str, level: str):
    """Send a chat message to the Mistral API and append the exchange to history.

    The system prompt casts the model as a teacher for the selected instrument
    and skill level; every earlier turn in ``history`` is replayed so the model
    sees the whole conversation.

    Args:
        message: The user's new message. Blank (or missing) messages are ignored.
        history: Prior turns as ``(user, assistant)`` pairs. ``None`` is
            treated as an empty history.
        instrument: Instrument name inserted into the system prompt.
        level: Skill level inserted into the system prompt.

    Returns:
        ``("", history)``: an empty string (to clear the input box) and a new
        history list including this exchange. API or parsing failures are
        appended as an ``"Error: ..."`` reply instead of being raised.
    """
    # Work on a copy so the caller's list is never mutated in place; an empty
    # chat may arrive as None.
    history = list(history or [])
    if not message or not message.strip():
        return "", history
    messages = [{"role": "system", "content": f"You are a {instrument} teacher for {level} students."}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})
    headers = {"Authorization": f"Bearer {MISTRAL_API_KEY}", "Content-Type": "application/json"}
    payload = {"model": "mistral-medium", "messages": messages, "temperature": 0.7, "max_tokens": 500}
    try:
        # Bound the wait: without a timeout a stalled connection would block
        # this handler indefinitely.
        response = requests.post(MISTRAL_API_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]
        history.append((message, content))
    except Exception as e:
        history.append((message, f"Error: {str(e)}"))
    return "", history
# -----------------------------------------------------------------------------
# 13. Extra features: visualization, notation, metronome, and difficulty rating
# -----------------------------------------------------------------------------
# Visualization function to create a piano roll representation of the exercise
def create_visualization(json_data, time_sig):
    """Render the exercise as a piano-roll PNG and return the image path.

    Args:
        json_data: JSON string containing a list of objects with ``"note"``
            and ``"duration"`` keys (other keys and items are ignored).
            Empty input, or any string containing ``"Error"``, yields ``None``.
        time_sig: Time signature such as ``"4/4"``. Durations are treated as
            eighth-note units when placing bar lines.

    Returns:
        Path of the saved PNG under ``static/``, or ``None`` when there is
        nothing to draw or rendering fails (the error is printed).
    """
    try:
        if not json_data or "Error" in json_data:
            return None
        parsed = json.loads(json_data)
        if not isinstance(parsed, list) or len(parsed) == 0:
            return None

        # Parallel lists, one entry per event; None in ``notes`` marks a rest.
        notes = []
        durations = []
        for item in parsed:
            if not (isinstance(item, dict) and "note" in item and "duration" in item):
                continue
            note_name = item["note"]
            if is_rest(note_name):
                notes.append(None)
            else:
                try:
                    notes.append(note_name_to_midi(note_name))
                except ValueError:
                    notes.append(60)  # Default to middle C if parsing fails
            durations.append(item["duration"])
        if not durations:
            return None  # nothing drawable

        # Onset of each event is the sum of the durations before it.
        time_positions = []
        max_time = 0
        for dur in durations:
            time_positions.append(max_time)
            max_time += dur

        numerator, denominator = map(int, time_sig.split('/'))
        units_per_measure = numerator * (8 // denominator)

        # Pad the pitch range of the sounding notes by 5 semitones; rests do
        # not count. If every event is a rest, show MIDI 55-75.
        pitched = [n for n in notes if n is not None]
        y_low = min(pitched) - 5 if pitched else 55
        y_high = max(pitched) + 5 if pitched else 75

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            for note, dur, pos in zip(notes, durations, time_positions):
                if note is None:
                    continue  # rests are left as gaps
                ax.add_patch(plt.Rectangle((pos, note - 0.4), dur, 0.8, color='blue', alpha=0.7))
                ax.text(pos + dur / 2, note + 0.5, midi_to_note_name(note),
                        ha='center', va='bottom', fontsize=8)
            # Dashed bar lines at every full-measure boundary.
            for measure in range(1, int(max_time / units_per_measure) + 1):
                measure_pos = measure * units_per_measure
                if measure_pos <= max_time:
                    ax.axvline(x=measure_pos, color='gray', linestyle='--', alpha=0.5)
            ax.set_ylim(y_low, y_high)
            ax.set_xlim(0, max_time)
            ax.set_ylabel('MIDI Note')
            ax.set_xlabel('Time (8th note units)')
            ax.set_title('Exercise Visualization')
            # Label every natural (white-key) pitch in the visible range,
            # using the same C4 = MIDI 60 naming as the previous fixed ticks.
            natural_names = {0: 'C', 2: 'D', 4: 'E', 5: 'F', 7: 'G', 9: 'A', 11: 'B'}
            ticks = [m for m in range(int(y_low), int(y_high) + 1) if m % 12 in natural_names]
            ax.set_yticks(ticks)
            ax.set_yticklabels([f"{natural_names[m % 12]}{m // 12 - 1}" for m in ticks])
            ax.grid(True, axis='y', alpha=0.3)

            os.makedirs('static', exist_ok=True)
            temp_img_path = os.path.join('static', f'visualization_{uuid.uuid4().hex}.png')
            fig.tight_layout()
            fig.savefig(temp_img_path)
        finally:
            # Always release this figure, even if drawing failed, so pyplot
            # does not accumulate open figures across requests.
            plt.close(fig)
        return temp_img_path
    except Exception as e:
        print(f"Error creating visualization: {e}")
        return None
# VexFlow music notation visualization function
_VEXFLOW_CDN_URL = "https://cdn.jsdelivr.net/npm/vexflow@4.2.2/build/cjs/vexflow.js"

# Browser-side renderer used by create_vexflow_notation. It reads everything
# it needs from a single injected ``CONFIG`` object ({notes, timeSignature,
# keySpec}), so no user-provided text is spliced into the script source.
_VEXFLOW_RENDER_JS = r"""
const output = document.getElementById('output');
try {
  const { Factory } = Vex.Flow;
  const timeSignature = CONFIG.timeSignature;
  const [numerator, denominator] = timeSignature.split('/').map(Number);
  const unitsPerMeasure = numerator * (8 / denominator);

  // Map a length in eighth-note units to an EasyScore duration. EasyScore
  // expects dots after the note type (e.g. "B4/4/r."), so they are returned
  // separately. Lengths with no single-note equivalent fall back to an eighth.
  const durationToVex = (units) => {
    switch (units) {
      case 1: return { base: '8', dots: '' };
      case 2: return { base: '4', dots: '' };
      case 3: return { base: '4', dots: '.' };
      case 4: return { base: '2', dots: '' };
      case 6: return { base: '2', dots: '.' };
      case 8: return { base: '1', dots: '' };
      default: return { base: '8', dots: '' };
    }
  };

  // Convert one JSON event to EasyScore text, or null if it is unusable.
  const toEasyScore = (item) => {
    const noteName = item && item.note ? String(item.note).trim() : '';
    const duration = item ? Number(item.duration) : 0;
    if (!noteName || !(duration > 0)) return null;
    const { base, dots } = durationToVex(duration);
    // Pitch names start with A-G, so a leading "r"/"p" marks a rest.
    if (/^[rp]/i.test(noteName)) return `B4/${base}/r${dots}`;
    const match = noteName.match(/([A-Ga-g][#b]?)(\d)/);
    const pitch = match ? `${match[1].toLowerCase()}${match[2]}` : 'c4';
    return `${pitch}/${base}${dots}`;
  };

  // Group events into bars of unitsPerMeasure eighth-note units.
  const bars = [];
  let currentBar = [];
  let currentUnits = 0;
  (Array.isArray(CONFIG.notes) ? CONFIG.notes : []).forEach((item) => {
    const vexNote = toEasyScore(item);
    if (vexNote === null) return;
    currentBar.push(vexNote);
    currentUnits += Number(item.duration);
    if (currentUnits >= unitsPerMeasure) {
      bars.push(currentBar);
      currentBar = [];
      currentUnits = 0;
    }
  });
  if (currentBar.length > 0) bars.push(currentBar);

  // Lay bars out left to right, four per line. Each bar gets its own System
  // because a System stacks its staves vertically.
  const barsPerLine = 4;
  const barWidth = 250;
  const headerWidth = 90;  // extra room for clef, key and time signature
  const lineHeight = 130;
  const lineCount = Math.max(1, Math.ceil(bars.length / barsPerLine));
  const vf = new Factory({
    renderer: {
      elementId: 'output',
      width: 20 + headerWidth + barsPerLine * barWidth,
      height: 20 + lineCount * lineHeight,
    },
  });
  const score = vf.EasyScore();
  score.set({ time: timeSignature });

  bars.forEach((bar, index) => {
    const col = index % barsPerLine;
    const row = Math.floor(index / barsPerLine);
    const isLineStart = col === 0;
    const system = vf.System({
      x: isLineStart ? 10 : 10 + headerWidth + col * barWidth,
      y: 10 + row * lineHeight,
      width: isLineStart ? headerWidth + barWidth : barWidth,
    });
    // Soft mode tolerates short or overfull bars (e.g. a partial last bar).
    const voice = vf.Voice({ time: timeSignature }).setStrict(false);
    voice.addTickables(score.notes(bar.join(', ')));
    const stave = system.addStave({ voices: [voice] });
    if (isLineStart) {
      stave.addClef('treble').addKeySignature(CONFIG.keySpec);
      if (index === 0) stave.addTimeSignature(timeSignature);
    }
  });
  vf.draw();
} catch (err) {
  output.textContent = 'Could not render notation: ' + (err && err.message ? err.message : err);
}
"""


def _vexflow_key_spec(key_sig):
    """Translate a UI key name such as "A Minor" into a VexFlow key spec ("Am")."""
    parts = str(key_sig or "").split()
    if not parts:
        return "C"
    tonic = parts[0]
    is_minor = len(parts) > 1 and parts[1].lower().startswith("min")
    return tonic + "m" if is_minor else tonic


def create_vexflow_notation(json_data, time_sig, key_sig):
    """Build a standalone HTML page that engraves the exercise with VexFlow.

    Args:
        json_data: JSON string containing a list of ``{"note", "duration"}``
            objects, durations in eighth-note units. Empty input, or any
            string containing ``"Error"``, yields ``None``.
        time_sig: Time signature such as ``"3/4"``.
        key_sig: Key name such as ``"Bb Major"`` or ``"A Minor"``; minor keys
            get the relative-major signature via VexFlow's ``"Am"``-style spec.

    Returns:
        The HTML page as a string; ``None`` for empty/error input or an empty
        list; or a short ``<p>`` error message if building the page fails.
        A copy of the page is also written to ``static/notation_<hex>.html``
        on a best-effort basis.

    NOTE(review): Gradio's ``gr.HTML`` may not execute inline ``<script>``
    tags when this string is inserted into the page; confirm the notation
    actually renders there, or serve the saved file through an iframe.
    """
    from html import escape

    if not json_data or "Error" in json_data:
        return None
    try:
        parsed = json.loads(json_data)
        if not isinstance(parsed, list) or len(parsed) == 0:
            return None
        config = {
            "notes": parsed,
            "timeSignature": str(time_sig),
            "keySpec": _vexflow_key_spec(key_sig),
        }
        # The note data can come from an LLM: escape characters that could
        # close the <script> element or start markup before embedding it.
        config_js = (
            json.dumps(config)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )
        title = f"Exercise in {escape(str(key_sig))}, {escape(str(time_sig))}"
        html_content = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Music Notation</title>
<script src="{_VEXFLOW_CDN_URL}"></script>
<style>
#output {{width: 100%; overflow: auto;}}
body {{font-family: Arial, sans-serif;}}
h2 {{color: #333;}}
</style>
</head>
<body>
<h2>{title}</h2>
<div id="output"></div>
<script>
const CONFIG = {config_js};
{_VEXFLOW_RENDER_JS}
</script>
</body>
</html>
"""
        # Best-effort copy on disk for callers that expect a file.
        try:
            os.makedirs('static', exist_ok=True)
            html_path = os.path.join('static', f'notation_{uuid.uuid4().hex}.html')
            with open(html_path, 'w', encoding='utf-8') as f:
                f.write(html_content)
        except (OSError, ValueError) as file_error:
            print(f"Warning: Could not save notation file: {file_error}")
        return html_content
    except Exception as e:
        print(f"Error creating VexFlow notation: {e}")
        return "<p>Failed to generate music notation. Error: " + escape(str(e)) + "</p>"
# Metronome function
def create_metronome_audio(tempo, time_sig, measures):
    """Render a click track for ``measures`` bars and return its MP3 path.

    The first beat of each bar is MIDI note 77 at velocity 100; other beats
    are note 76 at velocity 80. The clicks are rendered with the "Piano"
    soundfont.

    Args:
        tempo: Tempo in quarter-note BPM (converted with ``int``).
        time_sig: Time signature such as ``"3/4"``; the numerator is the
            number of clicks per bar and the denominator sets the click
            length (4 -> quarter notes).
        measures: Number of bars to render.

    Returns:
        Path of the MP3 under ``static/``, or ``None`` on any failure (the
        error is printed). Intermediate .mid/.wav/.mp3 files are removed in
        every case.
    """
    temp_paths = []  # intermediate files to delete once we are done
    try:
        numerator, denominator = map(int, time_sig.split('/'))
        total_beats = numerator * int(measures)
        # MIDI ticks_per_beat (and set_tempo) refer to the quarter note, so a
        # beat of 1/denominator lasts TICKS_PER_BEAT * 4 / denominator ticks.
        beat_ticks = TICKS_PER_BEAT * 4 // denominator
        click_ticks = min(10, beat_ticks)  # short click length

        mid = MidiFile(ticks_per_beat=TICKS_PER_BEAT)
        track = MidiTrack()
        mid.tracks.append(track)
        track.append(MetaMessage('time_signature', numerator=numerator,
                                 denominator=denominator, time=0))
        track.append(MetaMessage('set_tempo', tempo=mido.bpm2tempo(int(tempo)), time=0))

        for beat in range(total_beats):
            downbeat = beat % numerator == 0
            note_num = 77 if downbeat else 76
            velocity = 100 if downbeat else 80
            # Message times are deltas from the previous event (the last
            # click's note_off), so subtract the click length to keep
            # note_on events exactly one beat apart.
            delta = 0 if beat == 0 else beat_ticks - click_ticks
            track.append(Message('note_on', note=note_num, velocity=velocity, time=delta))
            track.append(Message('note_off', note=note_num, velocity=0, time=click_ticks))

        # Reserve a unique temp name, then write after the handle is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mid") as mid_file:
            mid_path = mid_file.name
        temp_paths.append(mid_path)
        mid.save(mid_path)

        stem = os.path.splitext(mid_path)[0]
        wav_path = stem + ".wav"
        mp3_path = stem + ".mp3"
        temp_paths.extend([wav_path, mp3_path])

        sf2_path = get_soundfont("Piano")
        try:
            sp.run([
                'fluidsynth', '-ni', sf2_path, mid_path,
                '-F', wav_path, '-r', '44100', '-g', '1.0'
            ], check=True, capture_output=True)
        except Exception:
            # Fall back to the FluidSynth helper if the fluidsynth CLI call fails.
            # NOTE(review): if FluidSynth is midi2audio's class, its constructor
            # may not accept ``gain`` -- confirm this fallback actually works.
            fs = FluidSynth(sf2_path, sample_rate=44100, gain=1.0)
            fs.midi_to_audio(mid_path, wav_path)

        sound = AudioSegment.from_wav(wav_path)
        # export() hands back the opened output file; close it so the handle
        # does not leak (and the file can be moved on every platform).
        sound.export(mp3_path, format="mp3").close()

        os.makedirs('static', exist_ok=True)
        static_mp3_path = os.path.join('static', f'metronome_{uuid.uuid4().hex}.mp3')
        shutil.move(mp3_path, static_mp3_path)
        return static_mp3_path
    except Exception as e:
        print(f"Error creating metronome: {e}")
        return None
    finally:
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass  # already moved or never created
# Function to calculate difficulty rating
def calculate_difficulty_rating(json_data, level, difficulty_modifier=0, practice_focus="Balanced"):
    """Score how hard an exercise is on a 1-10 scale.

    Four factors, each capped at 1.0, are blended with weights picked by
    ``practice_focus``:

    * range  -- pitch span in semitones / 12
    * rhythm -- number of distinct note durations / 4
    * jump   -- mean absolute interval between consecutive notes / 7
    * speed  -- 2 / mean note duration (1.0 if the mean is not positive)

    The blend is scaled by a level multiplier (Beginner 0.7, Intermediate
    1.0, Advanced 1.3, anything else 1.0) and by
    ``1 + 0.15 * difficulty_modifier``, then multiplied by 10, rounded and
    clamped to 1-10. Rests and notes that ``note_name_to_midi`` rejects with
    ``ValueError`` are skipped.

    Returns:
        The integer rating, or 0 when the input is empty, contains "Error",
        is not a non-empty JSON list, has no usable notes, or any other
        error occurs (the error is printed).
    """
    try:
        if not json_data or "Error" in json_data:
            return 0
        parsed = json.loads(json_data)
        if not isinstance(parsed, list) or len(parsed) == 0:
            return 0

        # Collect pitched notes only, with their durations, in order.
        pitches, lengths = [], []
        for entry in parsed:
            if not (isinstance(entry, dict) and "note" in entry and "duration" in entry):
                continue
            name = entry["note"]
            if is_rest(name):
                continue
            try:
                pitch = note_name_to_midi(name)
            except ValueError:
                continue
            pitches.append(pitch)
            lengths.append(entry["duration"])
        if not pitches:
            return 0

        range_factor = min((max(pitches) - min(pitches)) / 12, 1.0)
        rhythm_factor = min(len(set(lengths)) / 4, 1.0)
        intervals = [abs(curr - prev) for prev, curr in zip(pitches, pitches[1:])]
        mean_interval = sum(intervals) / len(intervals) if intervals else 0
        jump_factor = min(mean_interval / 7, 1.0)
        mean_length = sum(lengths) / len(lengths)
        speed_factor = min(2.0 / mean_length, 1.0) if mean_length > 0 else 1.0

        # (range, rhythm, jump, speed) weights; "Balanced" and unknown
        # focuses use the even split.
        weights = (0.25, 0.25, 0.25, 0.25)
        focus_profiles = (
            ("Rhythmic Focus", (0.15, 0.55, 0.15, 0.15)),
            ("Melodic Focus", (0.40, 0.15, 0.30, 0.15)),
            ("Technical Focus", (0.25, 0.15, 0.40, 0.20)),
            ("Expressive Focus", (0.35, 0.25, 0.25, 0.15)),
        )
        for focus_name, profile in focus_profiles:
            if practice_focus == focus_name:
                weights = profile
                break
        w_range, w_rhythm, w_jump, w_speed = weights

        base_difficulty = (
            range_factor * w_range
            + rhythm_factor * w_rhythm
            + jump_factor * w_jump
            + speed_factor * w_speed
        )
        level_scale = {"Beginner": 0.7, "Intermediate": 1.0, "Advanced": 1.3}.get(level, 1.0)
        modifier_scale = 1.0 + (difficulty_modifier * 0.15)  # ~15% per modifier step
        rating = round(base_difficulty * level_scale * modifier_scale * 10)
        return max(1, min(rating, 10))
    except Exception as e:
        print(f"Error calculating difficulty: {e}")
        return 0
# -----------------------------------------------------------------------------
# 14. Gradio user interface definition (for humans!) - ENHANCED GUI
# -----------------------------------------------------------------------------
def create_ui() -> gr.Blocks:
    """Build the Gradio interface for the exercise generator.

    Layout:
        * A generation-mode switch.
        * A left column of inputs: structured parameters or a free-text prompt.
        * A right column of tabs: player + metronome, raw JSON, piano-roll
          visualization, VexFlow notation, MIDI export and an AI chat assistant.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Adaptive Music Exercise Generator", theme="soft") as demo:
        gr.Markdown("# 🎼 Adaptive Music Exercise Generator")
        current_midi = gr.State(None)
        current_exercise = gr.State("")
        current_audio_path = gr.State(None)
        mode = gr.Radio(["Exercise Parameters", "Exercise Prompt"], value="Exercise Parameters", label="Generation Mode")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group(visible=True) as params_group:
                    gr.Markdown("### Exercise Parameters")
                    instrument = gr.Dropdown([
                        "Trumpet", "Piano", "Violin", "Clarinet", "Flute",
                    ], value="Trumpet", label="Instrument")
                    level = gr.Radio([
                        "Beginner", "Intermediate", "Advanced",
                    ], value="Intermediate", label="Difficulty Level")
                    difficulty_modifier = gr.Slider(minimum=-2, maximum=2, value=0, step=1,
                                                    label="Difficulty Modifier",
                                                    info="Fine-tune the difficulty: -2 (easier) to +2 (harder)")
                    practice_focus = gr.Dropdown([
                        "Balanced", "Rhythmic Focus", "Melodic Focus", "Technical Focus", "Expressive Focus"
                    ], value="Balanced", label="Practice Focus")
                    key = gr.Dropdown([
                        "C Major", "G Major", "D Major", "F Major", "Bb Major", "A Minor", "E Minor",
                    ], value="C Major", label="Key Signature")
                    time_signature = gr.Dropdown(["3/4", "4/4"], value="4/4", label="Time Signature")
                    measures = gr.Radio([4, 8, 12, 16], value=4, label="Length (measures)")
                with gr.Group(visible=False) as prompt_group:
                    gr.Markdown("### Exercise Prompt")
                    custom_prompt = gr.Textbox("", label="Enter your custom prompt", lines=3)
                    measures_prompt = gr.Radio([4, 8, 12, 16], value=4, label="Length (measures)")
                generate_btn = gr.Button("Generate Exercise", variant="primary")
            with gr.Column(scale=2):
                with gr.Tabs():
                    with gr.TabItem("Exercise Player"):
                        audio_output = gr.Audio(label="Generated Exercise", autoplay=True, type="filepath")
                        with gr.Row():
                            bpm_display = gr.Textbox(label="Tempo (BPM)")
                            time_sig_display = gr.Textbox(label="Time Signature")
                            duration_display = gr.Textbox(label="Audio Duration", interactive=False)
                        with gr.Row():
                            difficulty_rating = gr.Number(label="Difficulty Rating (1-10)", interactive=False, precision=1)
                        # Metronome section
                        gr.Markdown("### Metronome")
                        with gr.Row():
                            metronome_tempo = gr.Slider(minimum=40, maximum=200, value=60, step=1, label="Metronome Tempo")
                            metronome_btn = gr.Button("Generate Metronome", variant="secondary")
                        metronome_audio = gr.Audio(label="Metronome", type="filepath")
                    with gr.TabItem("Exercise Data"):
                        json_output = gr.Code(label="JSON Representation", language="json")
                        duration_sum = gr.Number(
                            label="Total Duration Units (8th notes)",
                            interactive=False,
                            precision=0
                        )
                    with gr.TabItem("Visualization"):
                        visualization_output = gr.Image(label="Exercise Visualization", type="filepath")
                        visualize_btn = gr.Button("Generate Visualization", variant="secondary")
                    with gr.TabItem("Music Notation"):
                        notation_html = gr.HTML(label="Music Notation")
                        notation_btn = gr.Button("Generate Music Notation", variant="secondary")
                    with gr.TabItem("MIDI Export"):
                        midi_output = gr.File(label="MIDI File")
                        download_midi = gr.Button("Generate MIDI File")
                    with gr.TabItem("AI Chat"):
                        chat_history = gr.Chatbot(label="Practice Assistant", height=400)
                        chat_message = gr.Textbox(label="Ask the AI anything about your practice")
                        send_chat_btn = gr.Button("Send")

        # Show only the input group that matches the selected mode.
        mode.change(
            fn=lambda m: {
                params_group: gr.update(visible=(m == "Exercise Parameters")),
                prompt_group: gr.update(visible=(m == "Exercise Prompt")),
            },
            inputs=[mode], outputs=[params_group, prompt_group]
        )

        def generate_caller(mode_val, instrument_val, level_val, key_val,
                            time_sig_val, measures_val, prompt_val, measures_prompt_val,
                            difficulty_modifier_val, practice_focus_val):
            """Generate an exercise and refresh every output that depends on it."""
            real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
            fixed_tempo = 60
            json_data, mp3_path, tempo, midi, duration, time_sig, total_duration = generate_exercise(
                instrument_val, level_val, key_val, fixed_tempo, time_sig_val,
                real_measures, prompt_val, mode_val, difficulty_modifier_val, practice_focus_val
            )
            rating = calculate_difficulty_rating(json_data, level_val, difficulty_modifier_val, practice_focus_val)
            viz_path = create_visualization(json_data, time_sig_val)
            html_content = create_vexflow_notation(json_data, time_sig_val, key_val)
            if not html_content:
                html_content = ""
            return json_data, mp3_path, tempo, midi, duration, time_sig, total_duration, rating, viz_path, mp3_path, html_content

        generate_btn.click(
            fn=generate_caller,
            inputs=[mode, instrument, level, key, time_signature, measures, custom_prompt, measures_prompt,
                    difficulty_modifier, practice_focus],
            outputs=[json_output, audio_output, bpm_display, current_midi, duration_display,
                     time_sig_display, duration_sum, difficulty_rating, visualization_output, current_audio_path, notation_html]
        )

        # Visualization button
        visualize_btn.click(
            fn=create_visualization,
            inputs=[json_output, time_signature],
            outputs=[visualization_output]
        )

        # Music Notation button
        def display_notation(json_data, time_sig, key_val):
            """Return notation HTML, or a short failure message."""
            html_content = create_vexflow_notation(json_data, time_sig, key_val)
            if html_content:
                return html_content
            return "<p>Failed to generate music notation.</p>"

        notation_btn.click(
            fn=display_notation,
            inputs=[json_output, time_signature, key],
            outputs=[notation_html]
        )

        # Metronome generation
        def generate_metronome(tempo, time_sig, measures_val, measures_prompt_val=None, mode_val=None):
            """Render a metronome as long as the exercise in the active mode."""
            # Use the same bar count as generate_caller: the prompt group's
            # length when in prompt mode, the parameter group's otherwise.
            real_measures = measures_prompt_val if mode_val == "Exercise Prompt" else measures_val
            return create_metronome_audio(tempo, time_sig, real_measures)

        metronome_btn.click(
            fn=generate_metronome,
            inputs=[metronome_tempo, time_signature, measures, measures_prompt, mode],
            outputs=[metronome_audio]
        )

        def save_midi(json_data, instr, time_sig, key_sig="C Major"):
            """Rebuild a MIDI file from the JSON shown in the UI; return its path or None."""
            try:
                if not json_data or "Error" in json_data:
                    return None
                parsed = json.loads(json_data)
                # Validate JSON structure
                if not isinstance(parsed, list):
                    return None
                old_format = []
                for item in parsed:
                    if isinstance(item, dict) and "note" in item and "duration" in item:
                        old_format.append([item["note"], item["duration"]])
                if not old_format:
                    return None
                # Estimate the bar count from the total length in 8th-note units.
                total_units = sum(d[1] for d in old_format)
                numerator, denominator = map(int, time_sig.split('/'))
                units_per_measure = numerator * (8 // denominator)
                measures_est = max(1, round(total_units / units_per_measure))
                cumulative = 0
                scaled_new = []
                for note, dur in old_format:
                    cumulative += dur
                    scaled_new.append({
                        "note": note,
                        "duration": dur,
                        "cumulative_duration": cumulative
                    })
                midi_obj = json_to_midi(scaled_new, instr, 60, time_sig, measures_est, key=key_sig)
                os.makedirs("static", exist_ok=True)
                # Unique name per export so concurrent sessions cannot
                # overwrite (and download) each other's file.
                midi_path = os.path.join("static", f"exercise_{uuid.uuid4().hex}.mid")
                midi_obj.save(midi_path)
                return midi_path
            except Exception as e:
                print(f"Error saving MIDI: {e}")
                return None

        download_midi.click(
            fn=save_midi,
            inputs=[json_output, instrument, time_signature, key],
            outputs=[midi_output],
        )
        send_chat_btn.click(
            fn=handle_chat,
            inputs=[chat_message, chat_history, instrument, level],
            outputs=[chat_message, chat_history],
        )
    return demo
# -----------------------------------------------------------------------------
# 15. Entry point
# -----------------------------------------------------------------------------
# Script entry point: build the Gradio app, keep it bound to the module-level
# name ``demo``, and start serving it.
if __name__ == "__main__":
    demo = create_ui()
    demo.launch()