#!/usr/bin/env python3
"""
AnalysisGNN Gradio App

A Gradio interface for AnalysisGNN music analysis.
Users can upload MusicXML scores, run the model, and view results.
"""

import logging
import os
import shutil
import subprocess
import tempfile
import time
import traceback
import urllib.request
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import torch

# Suppress warnings for cleaner output (set before importing partitura/AnalysisGNN
# so their import-time warnings are silenced as well).
warnings.filterwarnings('ignore')

# Import partitura and AnalysisGNN
import partitura as pt
from analysisgnn.models.analysis import ContinualAnalysisGNN
from analysisgnn.utils.chord_representations import available_representations, NoteDegree49

# Ensure additional representations are available for decoding
if "note_degree" not in available_representations and NoteDegree49 is not None:
    available_representations["note_degree"] = NoteDegree49

LOG_LEVEL = os.environ.get("ANALYSISGNN_LOG_LEVEL", "INFO").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.INFO),
    format="[%(asctime)s] %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger("analysisgnn_app")


def _env_int(name: str, default: int) -> int:
    """Read an integer setting from environment variable ``name``.

    Returns ``default`` when the variable is unset or blank. Also returns
    ``default`` (after logging a warning) when the value is not a valid integer.
    This way a typo in deployment config does not crash the app at import time.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        logger.warning("Ignoring invalid %s=%r; using default %s.", name, raw, default)
        return default


# "auto" (default), or an explicit on/off value; interpreted by should_parallelize().
PARALLEL_CONFIG = os.environ.get("ANALYSISGNN_PARALLEL", "auto").strip().lower()
CPU_COUNT = os.cpu_count() or 1

MUSESCORE_APPIMAGE_URL = "https://www.modelscope.cn/studio/Genius-Society/piano_trans/resolve/master/MuseScore.AppImage"
MUSESCORE_STORAGE_DIR = Path("artifacts") / "musescore"
MUSESCORE_ENV_VAR = "MUSESCORE_BIN"
# Timeouts (seconds) for MuseScore CLI rendering and AppImage extraction.
MUSESCORE_RENDER_TIMEOUT = _env_int("MUSESCORE_RENDER_TIMEOUT", 180)
MUSESCORE_EXTRACT_TIMEOUT = _env_int("MUSESCORE_EXTRACT_TIMEOUT", 240)
_MUSESCORE_BINARY: Optional[str] = None
_MUSESCORE_READY: bool = False
MUSESCORE_V3_APPIMAGE_URL = "https://github.com/musescore/MuseScore/releases/download/v3.6.2/MuseScore-3.6.2.548021370-x86_64.AppImage"
MUSESCORE_V3_STORAGE_DIR = Path("artifacts") / "musescore_v3"
MUSESCORE_V3_ENV_VAR = "MUSESCORE_V3_BIN"
_MUSESCORE_V3_BINARY: Optional[str] = None
RENDER_OUTPUT_DIR = Path("artifacts") / "rendered_scores"
XVFB_ENV_VAR = "XVFB_BIN"
XVFB_STORAGE_DIR = Path("artifacts") / "xvfb"
_XVFB_BINARY: Optional[str] = None

# Global model variable (populated lazily by load_model()).
MODEL = None
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Using device: %s", DEVICE)
if torch.cuda.is_available():
    logger.info("CUDA device: %s", torch.cuda.get_device_name(0))


@contextmanager
def log_timing(label: str):
    """Log start/stop (with duration) for expensive operations.

    Exceptions raised inside the block are logged with their traceback and the
    elapsed time, then re-raised unchanged.
    """
    start = time.perf_counter()
    logger.info("▶ %s", label)
    try:
        yield
    except Exception:
        elapsed = time.perf_counter() - start
        logger.exception("✗ %s failed after %.2fs", label, elapsed)
        raise
    else:
        elapsed = time.perf_counter() - start
        logger.info("✓ %s in %.2fs", label, elapsed)


def should_parallelize() -> bool:
    """
    Decide whether to run analysis/visualization in parallel.

    Controlled via ANALYSISGNN_PARALLEL env:
      - "0"/"false"/"no"/"off": force sequential
      - "1"/"true"/"yes"/"on": force parallel
      - "auto" (default, or any other value): enable if more than one CPU core is available
    """
    if PARALLEL_CONFIG in {"0", "false", "no", "off"}:
        return False
    if PARALLEL_CONFIG in {"1", "true", "yes", "on"}:
        return True
    return CPU_COUNT > 1


def _find_checkpoint(directory: str) -> Optional[str]:
    """Return ``<directory>/model.ckpt`` if it exists.

    Otherwise return the first ``*.ckpt`` entry in alphabetical order, or None.
    """
    preferred = os.path.join(directory, "model.ckpt")
    if os.path.exists(preferred):
        return preferred
    if not os.path.isdir(directory):
        return None
    for fname in sorted(os.listdir(directory)):
        if fname.endswith(".ckpt"):
            return os.path.join(directory, fname)
    return None


def download_wandb_checkpoint(artifact_path: str = "melkisedeath/AnalysisGNN/model-uvj2ddun:v1") -> str:
    """Download checkpoint from Weights & Biases, or use cached version if available.

    Cached checkpoints are looked up in this order:
      1. ``checkpoint/model.ckpt``
      2. any other ``checkpoint/*.ckpt`` (alphabetical order)
      3. ``checkpoint/<artifact basename>/model.ckpt``

    Returns
    -------
    str
        Path to the checkpoint file.

    Raises
    ------
    FileNotFoundError
        If the downloaded artifact does not contain any ``.ckpt`` file.
    """
    artifacts_dir = "checkpoint"
    os.makedirs(artifacts_dir, exist_ok=True)

    cached = _find_checkpoint(artifacts_dir)
    if cached is None:
        artifact_specific = os.path.join(artifacts_dir, os.path.basename(artifact_path), "model.ckpt")
        if os.path.exists(artifact_specific):
            cached = artifact_specific
    if cached is not None:
        logger.info("Using cached checkpoint: %s", cached)
        return cached

    # Only import and use wandb if checkpoint is not cached
    import wandb

    logger.info("Downloading checkpoint from W&B: %s", artifact_path)
    # Use the public API rather than an offline run: Run.use_artifact() is not
    # supported for offline runs, and the API download does not create a run.
    artifact = wandb.Api().artifact(artifact_path, type="model")
    with log_timing("Downloading W&B checkpoint"):
        artifact_dir = artifact.download(root=artifacts_dir)

    checkpoint_path = _find_checkpoint(artifact_dir)
    if checkpoint_path is None:
        raise FileNotFoundError(f"No .ckpt file found in downloaded W&B artifact at {artifact_dir}")
    return checkpoint_path


def load_model() -> ContinualAnalysisGNN:
    """Load the AnalysisGNN model once and cache it in the module-level MODEL."""
    global MODEL
    if MODEL is None:
        checkpoint_path = download_wandb_checkpoint()
        logger.info("Loading model from: %s", checkpoint_path)
        model = ContinualAnalysisGNN.load_from_checkpoint(
            checkpoint_path,
            map_location=DEVICE
        )
        model.eval()
        model.to(DEVICE)
        # Publish only a fully initialised model so a failure above is retried next call.
        MODEL = model
        logger.info("Model loaded successfully!")
    return MODEL


def _format_bytes(num_bytes: float) -> str:
    """Return human readable size string."""
    units = ["B", "KB", "MB", "GB", "TB"]
    size = float(num_bytes)
    for unit in units:
        if size < 1024:
            return f"{size:.1f}{unit}"
        size /= 1024
    return f"{size:.1f}PB"


def _remove_quietly(path: Optional[str]) -> None:
    """Best-effort delete of a scratch file; None, missing files and OS errors are ignored."""
    if not path:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def _download_file(url: str, destination: Path, timeout: Optional[float] = 60.0) -> bool:
    """Download ``url`` to ``destination`` and return True on success.

    Data is streamed into ``<destination>.part``. It is renamed onto ``destination``
    only after the body has been fully received. A failed or truncated download
    therefore never leaves a partial file that later calls would mistake for a
    complete one. ``timeout`` (seconds) bounds each blocking network operation;
    ``None`` waits indefinitely.
    """
    destination = Path(destination)
    partial_path = destination.with_name(destination.name + ".part")
    try:
        destination.parent.mkdir(parents=True, exist_ok=True)
        logger.info("Starting download: %s -> %s", url, destination)
        with urllib.request.urlopen(url, timeout=timeout) as response, open(partial_path, "wb") as out_file:
            total_size = int(response.headers.get("Content-Length", 0))
            downloaded = 0
            chunk_size = 1024 * 256
            last_log = time.perf_counter()
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                out_file.write(chunk)
                downloaded += len(chunk)
                now = time.perf_counter()
                if now - last_log > 5:
                    pct = (downloaded / total_size * 100) if total_size else 0
                    logger.info(
                        "Downloading... %s / %s (%.1f%%)",
                        _format_bytes(downloaded),
                        _format_bytes(total_size) if total_size else "unknown",
                        pct,
                    )
                    last_log = now
            # http.client does not raise when the connection drops early, so check explicitly.
            if total_size and downloaded < total_size:
                raise OSError(f"Incomplete download: received {downloaded} of {total_size} bytes")
        os.replace(partial_path, destination)
        logger.info(
            "Finished download: %s (%s)",
            destination,
            _format_bytes(destination.stat().st_size),
        )
        return True
    except Exception as exc:
        logger.exception("Error downloading %s: %s", url, exc)
        _remove_quietly(str(partial_path))
        return False


def _cleanup_musescore_artifacts(
    remove_appimage: bool = False,
    storage_dir: Path = MUSESCORE_STORAGE_DIR,
    appimage_name: str = "MuseScore.AppImage",
) -> None:
    """Remove partially extracted MuseScore artifacts to allow a clean retry.

    Always deletes ``<storage_dir>/squashfs-root``. When ``remove_appimage`` is
    true, it also deletes ``<storage_dir>/<appimage_name>`` so the next attempt
    re-downloads it. The defaults target the MuseScore 4 storage directory.
    """
    extract_dir = storage_dir / "squashfs-root"
    if extract_dir.exists():
        logger.warning("Removing stale MuseScore extract at %s", extract_dir)
        shutil.rmtree(extract_dir, ignore_errors=True)
    if remove_appimage:
        appimage = storage_dir / appimage_name
        if appimage.exists():
            try:
                appimage.unlink()
                logger.warning("Removed corrupt MuseScore AppImage at %s", appimage)
            except Exception:
                logger.warning("Could not remove MuseScore AppImage at %s", appimage)


def ensure_musescore_binary() -> Optional[str]:
    """Ensure a MuseScore binary is available for rendering.

    Candidates are tried in this order:
      1. the cached result
      2. ``$MUSESCORE_BIN``
      3. well-known executables on PATH
      4. a previously extracted AppImage under MUSESCORE_STORAGE_DIR
      5. a fresh AppImage download and extraction

    Step 5 makes two attempts; a failed first attempt discards the AppImage so the
    second one re-downloads it. Returns the binary path, or None if none is usable.
    """
    global _MUSESCORE_BINARY
    if _MUSESCORE_BINARY and os.path.exists(_MUSESCORE_BINARY):
        return _MUSESCORE_BINARY

    env_path = os.environ.get(MUSESCORE_ENV_VAR)
    if env_path and os.path.exists(env_path):
        logger.info("Using MuseScore binary from %s", MUSESCORE_ENV_VAR)
        _MUSESCORE_BINARY = env_path
        return _MUSESCORE_BINARY

    for candidate in ("mscore", "mscore3", "musescore3", "musescore", "MuseScore3"):
        found = shutil.which(candidate)
        if found:
            logger.info("Found MuseScore executable on PATH: %s", found)
            _MUSESCORE_BINARY = found
            return _MUSESCORE_BINARY

    MUSESCORE_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
    appimage_path = (MUSESCORE_STORAGE_DIR / "MuseScore.AppImage").resolve(strict=False)
    apprun_path = (MUSESCORE_STORAGE_DIR / "squashfs-root" / "AppRun").resolve(strict=False)

    if apprun_path.exists():
        logger.info("Using cached MuseScore AppRun at %s", apprun_path)
        os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")
        _MUSESCORE_BINARY = str(apprun_path)
        return _MUSESCORE_BINARY

    for attempt in (1, 2):
        if not appimage_path.exists() or appimage_path.stat().st_size == 0:
            logger.info("MuseScore AppImage missing. Downloading (attempt %s)...", attempt)
            if not _download_file(MUSESCORE_APPIMAGE_URL, appimage_path):
                return None

        try:
            os.chmod(appimage_path, 0o755)
        except Exception as exc:
            logger.warning("Could not chmod MuseScore AppImage: %s", exc)

        try:
            with log_timing("Extracting MuseScore AppImage"):
                subprocess.run(
                    [str(appimage_path), "--appimage-extract"],
                    cwd=MUSESCORE_STORAGE_DIR,
                    check=True,
                    stdout=subprocess.PIPE,
                    stderr=subprocess.PIPE,
                    timeout=MUSESCORE_EXTRACT_TIMEOUT,
                )
        except subprocess.CalledProcessError as exc:
            stderr = exc.stderr.decode(errors='ignore') if exc.stderr else str(exc)
            logger.error("MuseScore extraction failed: %s", stderr)
            _cleanup_musescore_artifacts(remove_appimage=(attempt == 1))
            continue
        except subprocess.TimeoutExpired:
            logger.error("MuseScore extraction timed out after %ss", MUSESCORE_EXTRACT_TIMEOUT)
            _cleanup_musescore_artifacts(remove_appimage=(attempt == 1))
            continue

        if apprun_path.exists():
            os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")
            _MUSESCORE_BINARY = str(apprun_path)
            try:
                os.chmod(apprun_path, 0o755)
            except Exception:
                logger.debug("Could not chmod MuseScore AppRun; continuing anyway.")
            logger.info("MuseScore AppRun ready at %s", _MUSESCORE_BINARY)
            return _MUSESCORE_BINARY

        logger.error("MuseScore extraction completed but AppRun was not found.")
        _cleanup_musescore_artifacts(remove_appimage=(attempt == 1))

    logger.error("MuseScore binary unavailable after retries.")
    return None


def ensure_musescore_v3_binary() -> Optional[str]:
    """Ensure a MuseScore 3.x binary is available for rendering.

    Candidates are tried in this order:
      1. the cached result
      2. ``$MUSESCORE_V3_BIN``
      3. a previously extracted AppImage under MUSESCORE_V3_STORAGE_DIR
      4. a one-time AppImage download and extraction

    A failed extraction removes the stale ``squashfs-root`` so the next call starts
    clean; the downloaded AppImage is kept. Returns the binary path or None.
    """
    global _MUSESCORE_V3_BINARY
    if _MUSESCORE_V3_BINARY and os.path.exists(_MUSESCORE_V3_BINARY):
        return _MUSESCORE_V3_BINARY

    env_path = os.environ.get(MUSESCORE_V3_ENV_VAR)
    if env_path and os.path.exists(env_path):
        logger.info("Using MuseScore 3 binary from %s", MUSESCORE_V3_ENV_VAR)
        _MUSESCORE_V3_BINARY = env_path
        return _MUSESCORE_V3_BINARY

    storage = MUSESCORE_V3_STORAGE_DIR
    storage.mkdir(parents=True, exist_ok=True)
    appimage_path = (storage / "MuseScore-3.AppImage").resolve(strict=False)
    apprun_path = (storage / "squashfs-root" / "AppRun").resolve(strict=False)

    if apprun_path.exists():
        logger.info("Using cached MuseScore 3 AppRun at %s", apprun_path)
        _MUSESCORE_V3_BINARY = str(apprun_path)
        return _MUSESCORE_V3_BINARY

    if not appimage_path.exists() or appimage_path.stat().st_size == 0:
        logger.info("MuseScore 3 AppImage missing. Downloading (first run only)...")
        if not _download_file(MUSESCORE_V3_APPIMAGE_URL, appimage_path):
            return None

    try:
        os.chmod(appimage_path, 0o755)
    except Exception as exc:
        logger.warning("Could not chmod MuseScore 3 AppImage: %s", exc)

    try:
        with log_timing("Extracting MuseScore 3 AppImage"):
            subprocess.run(
                [str(appimage_path), "--appimage-extract"],
                cwd=storage,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=MUSESCORE_EXTRACT_TIMEOUT,
            )
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode(errors='ignore') if exc.stderr else str(exc)
        logger.error("MuseScore 3 extraction failed: %s", stderr)
        _cleanup_musescore_artifacts(storage_dir=storage, appimage_name="MuseScore-3.AppImage")
        return None
    except subprocess.TimeoutExpired:
        logger.error("MuseScore 3 extraction timed out after %ss", MUSESCORE_EXTRACT_TIMEOUT)
        _cleanup_musescore_artifacts(storage_dir=storage, appimage_name="MuseScore-3.AppImage")
        return None

    if apprun_path.exists():
        _MUSESCORE_V3_BINARY = str(apprun_path)
        try:
            os.chmod(apprun_path, 0o755)
        except Exception:
            pass
        logger.info("MuseScore 3 AppRun ready at %s", _MUSESCORE_V3_BINARY)
        return _MUSESCORE_V3_BINARY

    logger.error("MuseScore 3 extraction did not produce the expected AppRun binary.")
    _cleanup_musescore_artifacts(storage_dir=storage, appimage_name="MuseScore-3.AppImage")
    return None


def _download_xvfb_package(dest_dir: Path) -> Optional[Path]:
    """Download the Xvfb .deb package using apt.

    Returns the newest ``xvfb_*.deb`` in ``dest_dir``, or None if apt is missing or
    the download failed.
    """
    try:
        completed = subprocess.run(
            ["apt", "download", "xvfb"],
            cwd=str(dest_dir),
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        logger.debug("apt download xvfb stdout: %s", completed.stdout.strip())
        if completed.stderr:
            logger.debug("apt download xvfb stderr: %s", completed.stderr.strip())
    except FileNotFoundError:
        logger.error("'apt' command not available; cannot download Xvfb automatically.")
        return None
    except subprocess.CalledProcessError as exc:
        logger.error(
            "Failed to download Xvfb package (exit %s): %s",
            exc.returncode,
            exc.stderr.strip() if exc.stderr else exc,
        )
        return None

    deb_candidates = sorted(dest_dir.glob("xvfb_*.deb"), key=lambda p: p.stat().st_mtime, reverse=True)
    if not deb_candidates:
        logger.error("apt download xvfb did not produce any .deb files under %s", dest_dir)
        return None
    return deb_candidates[0]


def _extract_xvfb_binary(deb_path: Path, target_dir: Path) -> Optional[Path]:
    """Unpack ``deb_path`` into ``target_dir/pkg`` with dpkg-deb.

    Returns the path to ``usr/bin/Xvfb`` inside it, or None. The .deb is deleted
    (best effort) once the binary has been found.
    """
    extract_dir = target_dir / "pkg"
    if extract_dir.exists():
        shutil.rmtree(extract_dir, ignore_errors=True)
    try:
        subprocess.run(
            ["dpkg-deb", "-x", str(deb_path), str(extract_dir)],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except FileNotFoundError:
        logger.error("'dpkg-deb' command not available; cannot extract Xvfb package.")
        return None
    except subprocess.CalledProcessError as exc:
        stderr = exc.stderr.decode(errors="ignore") if isinstance(exc.stderr, bytes) else exc.stderr
        logger.error("Failed to extract Xvfb package: %s", stderr or exc)
        return None

    xvfb_path = extract_dir / "usr/bin/Xvfb"
    if xvfb_path.exists():
        logger.info("Xvfb binary extracted to %s", xvfb_path)
        try:
            os.chmod(xvfb_path, 0o755)
        except Exception:
            pass
        try:
            deb_path.unlink()
        except Exception:
            pass
        return xvfb_path

    logger.error("Extracted Xvfb package but could not find usr/bin/Xvfb inside %s", extract_dir)
    return None


def ensure_xvfb_binary() -> Optional[str]:
    """Ensure we have an Xvfb binary available for headless rendering.

    Candidates are tried in this order:
      1. the cached result
      2. ``$XVFB_BIN``
      3. ``Xvfb`` on PATH
      4. a previously extracted package
      5. ``apt download`` + ``dpkg-deb`` extraction
    """
    global _XVFB_BINARY
    if _XVFB_BINARY and os.path.exists(_XVFB_BINARY):
        return _XVFB_BINARY

    env_path = os.environ.get(XVFB_ENV_VAR)
    if env_path and os.path.exists(env_path):
        _XVFB_BINARY = env_path
        return _XVFB_BINARY

    which = shutil.which("Xvfb")
    if which:
        _XVFB_BINARY = which
        return _XVFB_BINARY

    XVFB_STORAGE_DIR.mkdir(parents=True, exist_ok=True)
    extracted_bin = XVFB_STORAGE_DIR / "pkg" / "usr" / "bin" / "Xvfb"
    if extracted_bin.exists():
        _XVFB_BINARY = str(extracted_bin)
        return _XVFB_BINARY

    deb_path = _download_xvfb_package(XVFB_STORAGE_DIR)
    if not deb_path:
        return None
    extracted = _extract_xvfb_binary(deb_path, XVFB_STORAGE_DIR)
    if extracted:
        _XVFB_BINARY = str(extracted)
        return _XVFB_BINARY
    return None


def initialize_musescore_backend() -> bool:
    """Initialize MuseScore AppRun at startup to avoid on-demand downloads.

    Returns True if at least one of the MuseScore 4 / MuseScore 3 binaries is usable.
    """
    global _MUSESCORE_READY
    if _MUSESCORE_READY:
        return True

    available = []
    primary = ensure_musescore_binary()
    if primary:
        available.append(primary)
        logger.info("MuseScore 4 AppRun ready at startup: %s", primary)

    legacy = ensure_musescore_v3_binary()
    if legacy:
        available.append(legacy)
        logger.info("MuseScore 3 AppRun ready at startup: %s", legacy)

    if available:
        _MUSESCORE_READY = True
        return True

    logger.warning("No MuseScore AppRun binaries could be initialized; score visualization will fail.")
    return False


def _coalesce_musescore_output(output_path: str) -> Optional[str]:
    """
    Normalize MuseScore CLI output when it renders multiple PNG pages.

    MuseScore writes `basename-1.png`, `basename-2.png`, ... even if we request a
    single filename. We promote the first page (lowest page number) to the
    requested output path and delete the other pages, so downstream code can
    always load one predictable image.

    A zero-byte file already sitting at ``output_path`` (e.g. a pre-created
    temporary file) is treated as a placeholder, not as rendered output.
    """
    target = Path(output_path)
    if target.exists() and target.stat().st_size > 0:
        return str(target)

    suffix = target.suffix
    base = target.stem if suffix else target.name
    matches = list(target.parent.glob(f"{base}-*{suffix}"))
    if not matches:
        return None

    def page_key(page: Path) -> Tuple[int, int, str]:
        # Numeric page order ("-2" before "-10"); non-numeric names sort last.
        label = page.stem if suffix else page.name
        number = label[len(base) + 1:]
        if number.isdecimal():
            return (0, int(number), page.name)
        return (1, 0, page.name)

    matches.sort(key=page_key)
    first_page = matches[0]
    normalized_path: Optional[Path] = None
    try:
        shutil.move(str(first_page), str(target))
        normalized_path = target
    except Exception:
        try:
            shutil.copy(str(first_page), str(target))
            normalized_path = target
        except Exception:
            normalized_path = first_page

    if normalized_path == target:
        logger.debug("Normalized MuseScore output %s -> %s", first_page, target)
    else:
        logger.debug("Using MuseScore page %s as output", first_page)

    # Remove leftover pages to avoid clutter, keep best-effort
    for extra in matches:
        if extra == first_page:
            continue
        try:
            extra.unlink()
        except Exception:
            pass

    return str(normalized_path)


def persist_rendered_image(src_path: str) -> Optional[str]:
    """Copy rendered PNG to a persistent artifacts directory for UI display.

    Returns the copy's path, None if ``src_path`` is missing, or ``src_path``
    itself if the copy fails.
    """
    if not src_path or not os.path.exists(src_path):
        return None
    try:
        RENDER_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
        dest = RENDER_OUTPUT_DIR / f"{int(time.time()*1000)}_{Path(src_path).name}"
        shutil.copy2(src_path, dest)
        return str(dest)
    except Exception as exc:
        logger.warning("Could not persist rendered image %s: %s", src_path, exc)
        return src_path


@contextmanager
def xvfb_session(max_attempts: int = 3):
    """Spin up a temporary Xvfb server if available.

    Yields the DISPLAY string (e.g. ``":99"``) of a running server. Yields None
    when Xvfb is unavailable, no display number is free, or no server could be
    started.

    Display numbers 99-159 are skipped if their X11 socket or lock file already
    exists. If a server exits right after launch, the next free number is tried.
    That happens, for example, when another process claimed the same display
    concurrently. At most ``max_attempts`` servers are launched. The server is
    terminated (killed after 5 s if needed) when the block exits.
    """
    xvfb_bin = ensure_xvfb_binary()
    if not xvfb_bin:
        logger.warning("Xvfb binary unavailable; proceeding without virtual display.")
        yield None
        return

    socket_dir = Path("/tmp/.X11-unix")
    try:
        socket_dir.mkdir(mode=0o1777, exist_ok=True)
    except Exception:
        pass

    proc: Optional[subprocess.Popen] = None
    display: Optional[str] = None
    launches = 0
    for candidate in range(99, 160):
        if launches >= max(1, max_attempts):
            break
        if (socket_dir / f"X{candidate}").exists() or Path(f"/tmp/.X{candidate}-lock").exists():
            continue
        launches += 1
        cmd = [
            xvfb_bin,
            f":{candidate}",
            "-screen",
            "0",
            "1920x1080x24",
            "-nolisten",
            "tcp",
        ]
        logger.debug("Starting Xvfb with command: %s", " ".join(cmd))
        started = subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        time.sleep(0.5)
        if started.poll() is None:
            proc, display = started, f":{candidate}"
            break
        logger.error("Xvfb failed to start (exit %s).", started.returncode)

    if proc is None:
        if launches == 0:
            logger.warning("No free DISPLAY slots for Xvfb.")
        yield None
        return

    try:
        yield display
    finally:
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()  # reap the killed server so it does not linger as a zombie


def render_with_musescore(musicxml_path: Optional[str], output_path: str) -> Optional[str]:
    """Render using MuseScore command-line interface.

    MuseScore 3 is tried first, then MuseScore 4, each inside its own Xvfb
    session. Returns the path of the rendered image, or None if every candidate
    failed.
    """
    if not musicxml_path or not os.path.exists(musicxml_path):
        return None

    candidates = []
    legacy = ensure_musescore_v3_binary()
    if legacy:
        candidates.append(("MuseScore 3", legacy, True))
    primary = ensure_musescore_binary()
    if primary:
        candidates.append(("MuseScore 4", primary, True))

    if not candidates:
        logger.warning("No MuseScore binaries available for rendering.")
        return None

    last_error = None
    for label, musescore_bin, requires_display in candidates:
        env = os.environ.copy()
        env.setdefault("QTWEBENGINE_DISABLE_SANDBOX", "1")
        env.setdefault("MUSESCORE_NO_AUDIO", "1")
        cmd = [musescore_bin, "-o", output_path, musicxml_path]
        logger.info("Attempting rendering with %s (%s).", label, musescore_bin)
        try:
            with xvfb_session() as display:
                if display:
                    env["DISPLAY"] = display
                    env["QT_QPA_PLATFORM"] = "xcb"
                    logger.debug("%s: using Xvfb display %s", label, display)
                else:
                    if requires_display:
                        logger.warning("%s requires an X11 display but Xvfb could not be started.", label)
                        continue
                    env["QT_QPA_PLATFORM"] = "offscreen"
                    logger.debug("%s: using Qt offscreen platform.", label)
                with log_timing(f"{label} rendering"):
                    subprocess.run(
                        cmd,
                        check=True,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE,
                        env=env,
                        timeout=MUSESCORE_RENDER_TIMEOUT,
                    )
        except subprocess.CalledProcessError as exc:
            stderr = exc.stderr.decode(errors='ignore') if exc.stderr else str(exc)
            logger.error("%s rendering failed: %s", label, stderr)
            last_error = stderr
            continue
        except subprocess.TimeoutExpired:
            logger.error("%s rendering timed out after %ss", label, MUSESCORE_RENDER_TIMEOUT)
            last_error = f"{label} timed out"
            continue

        normalized_path = _coalesce_musescore_output(output_path)
        if normalized_path and os.path.exists(normalized_path):
            logger.info("%s rendered %s -> %s", label, musicxml_path, normalized_path)
            return normalized_path
        logger.error("%s rendered score but the expected output file was not found.", label)
        last_error = "output missing"

    logger.error("All MuseScore binaries failed to render the score. Last error: %s", last_error)
    return None


def resolve_musicxml_path(musicxml_file) -> Optional[str]:
    """Return a filesystem path for the uploaded MusicXML file.

    Accepts a path (str/PathLike), a dict with a ``"name"`` key, or any object with a
    ``name`` attribute.
    """
    if musicxml_file is None:
        return None
    if isinstance(musicxml_file, (str, os.PathLike)):
        return str(musicxml_file)
    if isinstance(musicxml_file, dict) and "name" in musicxml_file:
        return musicxml_file["name"]
    file_path = getattr(musicxml_file, "name", None)
    if file_path:
        return file_path
    return None


def save_parsed_musicxml(score: pt.score.Score, original_path: Optional[str]) -> Optional[str]:
    """
    Persist the parsed/normalized score to a temporary MusicXML file.

    The original suffix is kept when it is ``.xml``/``.musicxml``; otherwise
    ``.musicxml`` is used. Returns the path to the saved file, or None if saving
    fails. The temporary file is removed in that case.
    """
    tmp_path: Optional[str] = None
    try:
        suffix = ".musicxml"
        if original_path:
            original_suffix = Path(original_path).suffix.lower()
            if original_suffix in {".xml", ".musicxml"}:
                suffix = original_suffix
        fd, tmp_path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        with log_timing("Saving parsed MusicXML"):
            pt.save_musicxml(score, tmp_path)
        return tmp_path
    except Exception as exc:
        logger.warning("Could not save parsed MusicXML: %s", exc)
        _remove_quietly(tmp_path)
        return None


def render_score_to_image(
    score: pt.score.Score,
    output_path: str,
    source_musicxml_path: Optional[str] = None
) -> Optional[str]:
    """
    Render score directly with the MuseScore AppRun (no other fallbacks).

    The `score` argument is unused but kept for backward compatibility with the
    earlier pipeline that rendered from a score object.
    """
    del score  # Render is driven solely by the MusicXML path
    if not source_musicxml_path or not os.path.exists(source_musicxml_path):
        logger.error("Cannot render score: MusicXML path '%s' not found.", source_musicxml_path)
        return None
    return render_with_musescore(source_musicxml_path, output_path)


def predict_analysis(
    model: ContinualAnalysisGNN,
    score: pt.score.Score,
    tasks: list
) -> Dict[str, np.ndarray]:
    """
    Perform music analysis prediction.

    Parameters
    ----------
    model : ContinualAnalysisGNN
        The model to use for prediction
    score : pt.score.Score
        The score to analyze
    tasks : list
        List of analysis tasks to perform

    Returns
    -------
    dict
        Maps task names to decoded predictions. For tasks whose raw output has
        more than one dimension, ``<task>_confidence`` holds the softmax
        probability of the predicted class. ``onset_beat`` and ``measure`` are
        added when they can be derived. Tasks missing from the model output are
        skipped.
    """
    with torch.no_grad():
        with log_timing("Model prediction"):
            predictions = model.predict(score)

    # Decode predictions
    decoded_predictions = {}
    for task in tasks:
        if task not in predictions:
            continue
        pred_tensor = predictions[task]

        if len(pred_tensor.shape) > 1:
            # Treat the last dimension as class logits.
            pred_probs = torch.softmax(pred_tensor, dim=-1)
            pred_onehot = torch.argmax(pred_tensor, dim=-1)
            # Confidence = probability of the predicted class
            confidence = torch.max(pred_probs, dim=-1)[0]
            decoded_predictions[f"{task}_confidence"] = confidence.cpu().numpy()
        else:
            pred_onehot = pred_tensor

        # Decode using available representations
        if task in available_representations:
            try:
                decoded = available_representations[task].decode(
                    pred_onehot.reshape(-1, 1)
                )
                # Convert to numpy array if it's a list
                if isinstance(decoded, list):
                    decoded_predictions[task] = np.array(decoded).flatten()
                else:
                    decoded_predictions[task] = decoded.flatten()
            except (IndexError, ValueError, TypeError) as e:
                logger.warning("Error decoding %s predictions: %s", task, e)
                # Fallback to raw indices
                decoded_predictions[task] = pred_onehot.cpu().numpy()
        else:
            decoded_predictions[task] = pred_onehot.cpu().numpy()

    # Add timing information
    try:
        if "onset" in predictions:
            decoded_predictions["onset_beat"] = predictions["onset"].cpu().numpy()
        else:
            decoded_predictions["onset_beat"] = score.note_array()["onset_beat"]
    except (AttributeError, KeyError, IndexError) as e:
        logger.warning("Could not add onset timing: %s", e)

    try:
        if "s_measure" in predictions:
            decoded_predictions["measure"] = predictions["s_measure"].cpu().numpy()
        else:
            decoded_predictions["measure"] = score[0].measure_number_map(score.note_array()["onset_div"])
    except (AttributeError, KeyError, IndexError) as e:
        logger.warning("Could not add measure information: %s", e)

    return decoded_predictions


def _render_and_predict(
    model: ContinualAnalysisGNN,
    score: pt.score.Score,
    selected_tasks: list,
    img_path: str,
    source_path: Optional[str],
) -> Tuple[Optional[str], Dict[str, np.ndarray]]:
    """Render the score and run model inference, concurrently when enabled.

    In parallel mode, a step that raised in its worker thread is logged and
    retried once sequentially. Exceptions from that retry (or from sequential
    mode) propagate to the caller.

    Returns
    -------
    tuple
        (rendered image path or None, predictions dict)
    """
    rendered_path: Optional[str] = None
    predictions: Dict[str, np.ndarray] = {}

    parallel_enabled = should_parallelize()
    logger.info("Rendering score (parallel analysis enabled=%s)...", parallel_enabled)

    if not parallel_enabled:
        logger.info("Running analysis and visualization sequentially (parallel disabled).")
        rendered_path = render_score_to_image(
            score,
            img_path,
            source_musicxml_path=source_path,
        )
        predictions = predict_analysis(model, score, selected_tasks)
        return rendered_path, predictions

    logger.info("Running analysis and visualization in parallel (threads=%s).", 2)
    render_success = False
    analysis_success = False
    with ThreadPoolExecutor(max_workers=2) as executor:
        future_map = {
            executor.submit(
                render_score_to_image,
                score,
                img_path,
                source_musicxml_path=source_path,
            ): "render",
            executor.submit(
                predict_analysis,
                model,
                score,
                selected_tasks,
            ): "analysis",
        }
        for future in as_completed(future_map):
            task_name = future_map[future]
            try:
                result = future.result()
            except Exception:
                logger.exception("%s task failed.", task_name.capitalize())
                continue
            if task_name == "render":
                rendered_path = result
                render_success = True
            else:
                predictions = result or {}
                analysis_success = True

    if not render_success:
        logger.info("Retrying score rendering sequentially after parallel failure.")
        rendered_path = render_score_to_image(
            score,
            img_path,
            source_musicxml_path=source_path,
        )
    if not analysis_success:
        logger.info("Retrying analysis sequentially after parallel failure.")
        predictions = predict_analysis(model, score, selected_tasks)

    return rendered_path, predictions


def _build_results_dataframe(predictions: Dict[str, np.ndarray]) -> pd.DataFrame:
    """Turn the predictions dict into the display table.

    Columns are ordered as note_id, onset_beat and measure first, then each
    prediction followed by its confidence. Display-name overrides are applied last.
    """
    df = pd.DataFrame(predictions)

    # Add note/event IDs
    if 'note_id' not in df.columns:
        df.insert(0, 'note_id', range(len(df)))

    # Convert tpc_in_label logits into NCT-friendly labels
    if 'tpc_in_label' in df.columns:
        df['tpc_in_label'] = np.where(
            df['tpc_in_label'].astype(int) == 0,
            "NCT",
            "Chord Tone"
        )

    # Reorder columns to have timing info first, then predictions, then confidence
    timing_cols = [col for col in ['note_id', 'onset_beat', 'measure'] if col in df.columns]
    confidence_cols = [col for col in df.columns if col.endswith('_confidence')]
    prediction_cols = [col for col in df.columns if col not in timing_cols and col not in confidence_cols]

    # Interleave predictions with their confidence scores
    ordered_cols = timing_cols.copy()
    for pred_col in prediction_cols:
        ordered_cols.append(pred_col)
        conf_col = f"{pred_col}_confidence"
        if conf_col in confidence_cols:
            ordered_cols.append(conf_col)
    df = df[ordered_cols]

    # Apply user-friendly column names
    rename_map = {}
    for key, label in DISPLAY_NAME_OVERRIDES.items():
        if key in df.columns:
            rename_map[key] = label
        conf_key = f"{key}_confidence"
        if conf_key in df.columns:
            rename_map[conf_key] = f"{label} Confidence"
    if rename_map:
        df = df.rename(columns=rename_map)
    return df


def process_musicxml(
    musicxml_file,
    selected_tasks: list
) -> Tuple[Optional[str], Optional[pd.DataFrame], Optional[str], str]:
    """
    Process a MusicXML file and return visualization and analysis results.

    Parameters
    ----------
    musicxml_file : file
        Uploaded MusicXML file
    selected_tasks : list
        List of selected analysis tasks

    Returns
    -------
    tuple
        (image_path, dataframe, parsed_musicxml_path, status_message)

    Any exception is caught. Its message and traceback are returned as the
    status text.
    """
    if musicxml_file is None:
        return None, None, None, "Please upload a MusicXML file."

    if not selected_tasks:
        return None, None, None, "Please select at least one analysis task."

    img_path: Optional[str] = None
    rendered_path: Optional[str] = None
    persisted_path: Optional[str] = None
    try:
        score_path = resolve_musicxml_path(musicxml_file)
        if score_path is None or not os.path.exists(score_path):
            return None, None, None, "Could not locate the uploaded MusicXML file."

        # Load the model
        status_msg = "Loading model..."
        logger.info(status_msg)
        model = load_model()

        # Load the score
        status_msg = "Loading score..."
        logger.info(status_msg)
        score = pt.load_musicxml(score_path)
        parsed_score_path = save_parsed_musicxml(score, score_path)

        # Scratch file passed to MuseScore as its output target
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_img:
            img_path = tmp_img.name

        source_path = parsed_score_path or score_path
        rendered_path, predictions = _render_and_predict(
            model, score, selected_tasks, img_path, source_path
        )

        persisted_path = persist_rendered_image(rendered_path) if rendered_path else None
        if rendered_path is None or persisted_path is None:
            logger.warning("MuseScore AppRun could not render the score or save the PNG; visualization will be unavailable.")

        if predictions:
            df = _build_results_dataframe(predictions)
            status_msg = f"✓ Analysis complete! Analyzed {len(df)} notes with {len(selected_tasks)} task(s)."
        else:
            df = pd.DataFrame()
            status_msg = "⚠ Analysis returned no predictions."
        if parsed_score_path:
            status_msg += " Parsed MusicXML ready for download."

        return persisted_path, df, parsed_score_path, status_msg

    except Exception as e:
        error_msg = f"Error processing file: {str(e)}\n\n{traceback.format_exc()}"
        logger.error(error_msg)
        return None, None, None, error_msg
    finally:
        # Temp render outputs are scratch space unless persisting failed and the UI
        # will display that very file.
        for scratch in (img_path, rendered_path):
            if scratch and scratch != persisted_path:
                _remove_quietly(scratch)


# Define available tasks (internal task key -> UI label)
AVAILABLE_TASKS = {
    "cadence": "Cadence Detection",
    "localkey": "Local Key",
    "tonkey": "Tonalized Key",
    "quality": "Chord Quality",
    "root": "Chord Root",
    "bass": "Bass Note",
    "inversion": "Chord Inversion",
    "degree1": "Primary Degree",
    "degree2": "Secondary Degree",
    "romanNumeral": "Roman Numeral Analysis",
    "phrase": "Phrase Segmentation",
    "section": "Section Detection",
    "hrhythm": "Harmonic Rhythm",
    "pcset": "Pitch-Class Set",
    "tpc_in_label": "Non-Chord Tone (NCT)",
    "note_degree": "Note Degree",
}

# Column headers used in the results table (prediction key -> header)
DISPLAY_NAME_OVERRIDES = {
    "tpc_in_label": "NCT",
    "note_degree": "Note Degree",
}

# Ensure MuseScore AppRun is available before the UI is constructed
initialize_musescore_backend()

# Create Gradio interface
with gr.Blocks(title="AnalysisGNN Music Analysis", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎵 AnalysisGNN Music Analysis

    Upload a MusicXML score to perform automatic music analysis using Graph Neural Networks.

    **Supported Analysis Tasks:**
    - Cadence Detection
    - Key Analysis (Local & Tonalized)
    - Harmonic Analysis (Chords, Inversions, Roman Numerals)
    - Phrase & Section Segmentation
    - Non-Chord Tone Detection (TPC-in-label / NCT)
    - Note Degree Labeling

    **Model:** Pre-trained AnalysisGNN from [manoskary/analysisGNN](https://github.com/manoskary/analysisGNN)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📁 Input")
            file_input = gr.File(
                label="Upload MusicXML Score",
                file_types=[".musicxml", ".xml", ".mxl"],
                type="filepath"
            )

            task_selector = gr.CheckboxGroup(
                choices=list(AVAILABLE_TASKS.values()),
                value=["Cadence Detection", "Local Key", "Roman Numeral Analysis"],
                label="Select Analysis Tasks",
                info="Choose which tasks to perform"
            )

            analyze_btn = gr.Button("🎼 Analyze Score", variant="primary", size="lg")

            gr.Markdown("---")
            example_btn = gr.Button("🎵 Try Example (Mozart K.158)", size="sm")

        with gr.Column(scale=2):
            # Output section
            gr.Markdown("### 📊 Results")
            status_output = gr.Textbox(
                label="Status",
                lines=2,
                interactive=False
            )

    with gr.Row():
        with gr.Column():
            # Score visualization
            gr.Markdown("### 🎼 Score Visualization")
            image_output = gr.Image(
                label="Rendered Score",
                type="filepath"
            )
            parsed_score_output = gr.File(
                label="Parsed MusicXML (Download)",
                interactive=False
            )

    with gr.Row():
        with gr.Column():
            # Analysis results table
            gr.Markdown("### 📈 Analysis Results")
            table_output = gr.Dataframe(
                label="Analysis Output",
                wrap=True,
                interactive=False
            )

            download_btn = gr.Button("💾 Download Results as CSV")
            csv_output = gr.File(label="Download CSV")

    # Tips section
    gr.Markdown("""
    ### 💡 Tips & Information

    **Getting Started:**
    - Click "Try Example" to load a Mozart quartet, or upload your own MusicXML file
    - Select the analysis tasks you're interested in
    - Click "Analyze Score" to run the model

    **Analysis Output:**
    The table shows note-level predictions for all selected tasks:
    - **Onset & Measure**: Timing information
    - **Keys**: Detected key areas (local and tonalized)
    - **Chords**: Harmonic analysis with Roman numerals
    - **Cadences**: Identified cadence points and types

    **Score Visualization:**
    Scores are rendered with MuseScore, which is downloaded automatically if it is not installed. If rendering is unavailable, analysis will still work.
    """)

    # Event handlers
    def analyze_wrapper(file, tasks_selected):
        """Map UI task labels back to internal task keys and run the analysis."""
        task_mapping = {v: k for k, v in AVAILABLE_TASKS.items()}
        selected_task_keys = [task_mapping[t] for t in tasks_selected if t in task_mapping]
        return process_musicxml(file, selected_task_keys)

    def load_example():
        """Load example Mozart score (downloaded once into ./artifacts)."""
        url = "https://raw.githubusercontent.com/manoskary/humdrum-mozart-quartets/refs/heads/master/musicxml/k158-01.musicxml"

        # Create artifacts directory if it doesn't exist
        os.makedirs("./artifacts", exist_ok=True)
        example_path = "./artifacts/k158-01.musicxml"

        if not os.path.exists(example_path):
            # Download to a side file so a failed download never leaves a truncated example behind.
            partial_path = example_path + ".part"
            try:
                logger.info("Downloading example score from: %s", url)
                urllib.request.urlretrieve(url, partial_path)
                os.replace(partial_path, example_path)
                logger.info("Example score saved to: %s", example_path)
            except Exception as e:
                _remove_quietly(partial_path)
                return None, f"Error downloading example: {e}"

        return example_path, "Example loaded! Click 'Analyze Score' to proceed."

    analyze_btn.click(
        fn=analyze_wrapper,
        inputs=[file_input, task_selector],
        outputs=[image_output, table_output, parsed_score_output, status_output]
    )

    example_btn.click(
        fn=load_example,
        outputs=[file_input, status_output]
    )

    def save_csv(df):
        """Write the results table to a temporary CSV and return its path (None if empty)."""
        if df is None or len(df) == 0:
            return None

        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp:
            df.to_csv(tmp.name, index=False)
            return tmp.name

    download_btn.click(
        fn=save_csv,
        inputs=[table_output],
        outputs=[csv_output]
    )


# Launch the app
if __name__ == "__main__":
    # Pre-load the model at startup for efficiency
    logger.info("=" * 50)
    logger.info("Initializing AnalysisGNN app...")
    logger.info("=" * 50)

    logger.info("Pre-loading model at startup...")
    load_model()
    logger.info("Model ready. Launching Gradio interface...")
    logger.info("=" * 50)

    demo.launch()