import os
import io
import textwrap
import gradio as gr
import numpy as np
import pandas as pd
import soundfile as sf
from PIL import Image, ImageDraw, ImageFont
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
import torch

torch.set_num_threads(1)  # limit torch intra-op parallelism to a single thread

from transformers import ClapProcessor, ClapModel
from datasets import load_dataset, Audio
import requests
from itertools import islice
from scipy.signal import resample_poly
import math
import gc
from functools import lru_cache
from diffusers import AutoPipelineForText2Image

# Mutable module state, reassigned by the pipeline / generation functions.
LAST_AUDIO_EMB = None  # CLAP embedding of the most recently processed audio
_SD_PIPE = None  # lazily-loaded text-to-image pipeline

# =========================
# CONFIG
# =========================
# Real (WikiArt) embeddings + metadata
REAL_EMB_PATH = "wikiart_embeddings.npy"
REAL_META_PATH = "wikiart_metadata.csv"

# Synthetic (optional; placeholders used if missing or inconsistent)
SYNTH_EMB_PATH = "synthetic_image_embeddings.npy"
SYNTH_META_PATH = "synthetic_image_metadata.csv"

# Data sources (remote)
WIKIART_DATASET = "huggan/wikiart"  # images streamed by order
# Only the first WIKIART_COUNT streamed images (and at most N_REAL) are addressable.
# Row i of the metadata CSV is assumed to correspond to streamed item i.
WIKIART_COUNT = 1500
GTZAN_DATASET = "sanchit-gandhi/gtzan"
PRESET_ORDINALS = [1, 102, 203, 304, 405]  # GTZAN presets by ordinal (1-based)

# =========================
# LOAD EMBEDDINGS/METADATA
# =========================
real_embeddings = np.load(REAL_EMB_PATH)
real_meta = pd.read_csv(REAL_META_PATH)
real_paths = real_meta["path"].astype(str).tolist()
N_REAL = len(real_paths)

# Synthetic (optional). Fall back to placeholders if the files are missing, or if
# they would break matching later: a dimension mismatch with the real embeddings
# breaks cosine_similarity, and more embedding rows than paths breaks the
# synthetic_paths[j] lookup.
try:
    synthetic_embeddings = np.load(SYNTH_EMB_PATH)
    synthetic_meta = pd.read_csv(SYNTH_META_PATH)
    synthetic_paths = synthetic_meta["path"].astype(str).tolist()
    if (
        synthetic_embeddings.ndim != 2
        or synthetic_embeddings.shape[0] == 0
        or synthetic_embeddings.shape[1] != real_embeddings.shape[1]
        or synthetic_embeddings.shape[0] > len(synthetic_paths)
    ):
        raise ValueError(
            f"synthetic embeddings {synthetic_embeddings.shape} are inconsistent with "
            f"{len(synthetic_paths)} paths / real dim {real_embeddings.shape[1]}"
        )
except Exception as _synth_err:  # deliberate best effort: the app still works without them
    print(f"[EchoArt] Using synthetic placeholders: {_synth_err}")
    synthetic_embeddings = np.zeros((10, real_embeddings.shape[1]), dtype=np.float32)
    synthetic_paths = [f"synthetic_placeholder_{i}.png" for i in range(10)]

# =========================
# HELPERS
# =========================
HF_DATASET_RESOLVE_BASE = "https://huggingface.co/datasets"


def _placeholder_img(text: str, size=(512, 512)) -> Image.Image:
    """Render *text* centred on a light-grey RGB canvas of the given size.

    Explicit newlines in *text* start a new line. Each resulting paragraph is
    additionally word-wrapped at 28 characters. Empty text yields a blank canvas.
    """
    img = Image.new("RGB", size, (240, 240, 240))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None
    # textwrap.wrap() replaces "\n" with a space, so split on newlines first to
    # keep caller-intended breaks (e.g. "Real\n<filename>").
    lines = [ln for para in str(text).splitlines() for ln in textwrap.wrap(para, width=28)]
    if not lines:
        return img
    line_gap = 6
    line_boxes = [draw.textbbox((0, 0), ln, font=font) for ln in lines]
    total_h = sum(b[3] - b[1] for b in line_boxes) + (len(lines) - 1) * line_gap
    y = (size[1] - total_h) // 2
    for ln, bb in zip(lines, line_boxes):
        w, h = bb[2] - bb[0], bb[3] - bb[1]
        x = (size[0] - w) // 2
        draw.text((x, y), ln, fill=(30, 30, 30), font=font)
        y += h + line_gap
    return img


def slice_audio_to_wav_bytes(raw_bytes: bytes, start_sec: float | None, duration_sec: float = 30.0) -> bytes:
    """Trim audio to a window of at most ``duration_sec`` seconds and return WAV bytes.

    The input is decoded with soundfile and down-mixed to mono. The window starts
    at ``start_sec``; None, unparsable, negative or non-finite values mean 0.
    For clips longer than the window, the start is pulled back so a full window
    still fits. For shorter clips, a start at or past the end falls back to 0, so
    the start offset alone can never produce an empty segment.

    Args:
        raw_bytes: Encoded audio in any format soundfile can read.
        start_sec: Requested start offset in seconds, or None.
        duration_sec: Window length in seconds.

    Returns:
        The trimmed mono segment, WAV-encoded at the input sample rate.
    """
    data, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
    if data.ndim > 1:
        data = data.mean(axis=1)
    total_sec = len(data) / sr

    # Parse the requested start; anything unusable means "from the beginning".
    try:
        s = float(start_sec) if start_sec is not None else 0.0
    except (TypeError, ValueError, OverflowError):
        s = 0.0
    if not math.isfinite(s):
        s = 0.0
    s = max(0.0, s)

    if total_sec > duration_sec and s + duration_sec > total_sec:
        # Long clip: shift back so we still get a full duration_sec window.
        s = total_sec - duration_sec
    elif s >= total_sec:
        # Short clip with a start past its end: previously produced an empty WAV.
        s = 0.0

    start_idx = int(round(s * sr))
    end_idx = min(len(data), start_idx + int(round(duration_sec * sr)))
    seg = data[start_idx:end_idx]

    # Re-encode to WAV bytes.
    buf = io.BytesIO()
    sf.write(buf, seg, sr, format="WAV")
    return buf.getvalue()


def _open_local_image(path: str):
    """Open a local image as an RGB thumbnail (max 512x512), or return None on failure."""
    try:
        # The context manager releases the file handle once the RGB copy exists.
        with Image.open(path) as src:
            img = src.convert("RGB")
        img.thumbnail((512, 512))  # keep memory down for the gallery
        return img
    except Exception:
        return None


# =========================
# WIKIART IMAGES (STREAMED BY ORDER, FIRST WIKIART_COUNT ONLY)
# =========================
@lru_cache(maxsize=64)  # small cache to avoid re-fetching the same few images
def _get_wikiart_image_by_index(idx: int):
    """Fetch the idx-th (0-based) image of the streamed WikiArt train split.

    Returns an RGB PIL thumbnail (max 512x512), or None if idx is out of range,
    the stream ends early, or the record is not a decodable image. Network or
    dataset errors propagate (and are therefore not cached by lru_cache).
    """
    if idx is None or idx < 0 or idx >= min(WIKIART_COUNT, N_REAL):
        return None
    # Stream and skip to idx, so only one record is decoded. Skipping still
    # iterates over the preceding records.
    stream = load_dataset(WIKIART_DATASET, split="train", streaming=True)
    ex = next(islice(stream, idx, idx + 1), None)
    if ex is None:
        return None
    img = ex["image"]
    if not hasattr(img, "convert"):
        # Not a decoded PIL image; the gallery cannot display it.
        return None
    img = img.convert("RGB")
    # Downsize for gallery to save RAM.
    try:
        img.thumbnail((512, 512))
    except Exception:
        pass
    return img


# =========================
# GTZAN AUDIO (NON-STREAMING, INDEXABLE) — keeps dataset = sanchit-gandhi/gtzan
# =========================
_gtzan_ds = None


def _load_gtzan():
    """Load the GTZAN train split once (non-streaming) so it can be indexed by ordinal.

    The 'audio' column is cast to ``Audio(decode=True)`` if it is not already an
    Audio feature, so ``ds[i]["audio"]`` yields ``{'array', 'sampling_rate'}``.
    """
    global _gtzan_ds
    if _gtzan_ds is None:
        ds = load_dataset(GTZAN_DATASET, split="train")  # local cache, no manual URLs
        if "audio" not in ds.features or not isinstance(ds.features["audio"], Audio):
            ds = ds.cast_column("audio", Audio(decode=True))
        _gtzan_ds = ds
    return _gtzan_ds


def get_gtzan_audio_by_ordinal(n: int):
    """Return ((sr, mono_float32_waveform), wav_bytes) for the n-th (1-based) GTZAN record.

    Raises:
        gr.Error: if n < 1 or n exceeds the dataset size.
    """
    if n < 1:
        raise gr.Error("Index must be ≥ 1")
    ds = _load_gtzan()
    idx = n - 1
    if idx >= len(ds):
        raise gr.Error(f"No record at index {n} (dataset size: {len(ds)}).")
    audio = ds[idx]["audio"]  # {'array': np.ndarray, 'sampling_rate': int}
    arr = audio["array"]
    sr = audio["sampling_rate"]
    # Down-mix to mono for CLAP.
    if arr.ndim > 1:
        arr = arr.mean(axis=1)
    # WAV bytes let the rest of the pipeline treat presets and uploads uniformly.
    buf = io.BytesIO()
    sf.write(buf, arr, sr, format="WAV")
    raw = buf.getvalue()
    return (sr, arr.astype(np.float32)), raw


def get_preset_audio_bytes(i: int) -> bytes:
    """Return WAV bytes for preset button *i* (0-based), i.e. GTZAN ordinal PRESET_ORDINALS[i]."""
    n = PRESET_ORDINALS[i]
    _sr_data, raw = get_gtzan_audio_by_ordinal(n)
    return raw
# =========================
# CLAP MODEL
# =========================
DEVICE = "cpu"

# The processor and the model are both loaded from the same CLAP checkpoint id.
_CLAP_CHECKPOINT = "laion/clap-htsat-unfused"
clap_processor = ClapProcessor.from_pretrained(_CLAP_CHECKPOINT)
clap_model = ClapModel.from_pretrained(_CLAP_CHECKPOINT).to(DEVICE)
clap_model.eval()  # inference mode for the whole app

# Candidate phrases per prompt slot; the audio embedding picks the closest ones.
# Key order matters: it defines the layout of the flattened phrase list.
CANDS = dict(
    subject=[
        "jazz band", "symphony orchestra", "solo piano", "acoustic guitar",
        "electric guitar", "violin", "hip-hop rapper", "DJ in a nightclub",
        "techno producer", "ambient soundscape", "drum and bass", "punk band",
        "choir", "saxophone solo", "lofi beats",
    ],
    mood=[
        "moody", "melancholic", "energetic", "uplifting", "intense", "dreamy",
        "peaceful", "dark", "nostalgic", "aggressive", "euphoric",
    ],
    style=[
        "surreal", "impressionist", "cubist", "abstract", "street art",
        "photorealistic", "digital painting", "watercolor", "oil painting",
        "futuristic", "minimalist", "psychedelic", "retro",
    ],
    visuals=[
        "neon lights", "soft gradients", "high contrast", "film grain", "bokeh",
        "long exposure", "smoky bar", "concert stage", "urban graffiti",
        "fractals", "misty atmosphere", "warm window light",
    ],
)

# Build a flat list and index slices for each pool, then encode once with CLAP's text tower.
ALL_PHRASES = []
POOL_SLICES = {}
for _pool_name, _pool_phrases in CANDS.items():
    _first = len(ALL_PHRASES)
    ALL_PHRASES.extend(_pool_phrases)
    # Each pool occupies a contiguous run of rows in ALL_PHRASES / _ALL_EMB.
    POOL_SLICES[_pool_name] = slice(_first, len(ALL_PHRASES))

with torch.no_grad():
    _text_batch = clap_processor(text=ALL_PHRASES, return_tensors="pt", padding=True)
    _text_batch = {name: tensor.to(DEVICE) for name, tensor in _text_batch.items()}
    _text_feats = clap_model.get_text_features(**_text_batch)  # [N, D]
    # L2-normalised rows, so a dot product with a unit vector is a cosine similarity.
    _ALL_EMB = torch.nn.functional.normalize(_text_feats, dim=-1).cpu().numpy()


def _best_from_pool(audio_emb: np.ndarray, pool_key: str, topk: int = 1):
    """Pick the ``topk`` phrases of pool ``pool_key`` most similar to ``audio_emb``.

    Returns:
        (phrases, similarities), both ordered best-first; similarities are the
        cosine similarities as Python floats.
    """
    unit_audio = audio_emb / (np.linalg.norm(audio_emb) + 1e-12)
    pool_sims = _ALL_EMB[POOL_SLICES[pool_key]] @ unit_audio
    ranked = np.argsort(-pool_sims)[:topk]
    pool = CANDS[pool_key]
    return [pool[r] for r in ranked], [float(pool_sims[r]) for r in ranked]


def build_prompt_from_audio(audio_emb: np.ndarray) -> str:
    """Compose a comma-separated text-to-image prompt from the best-matching phrases.

    Uses one subject, two moods, one style and two visual cues, followed by a
    fixed quality directive.
    """
    (subject,), _ = _best_from_pool(audio_emb, "subject", topk=1)
    moods, _ = _best_from_pool(audio_emb, "mood", topk=2)
    (style,), _ = _best_from_pool(audio_emb, "style", topk=1)
    visuals, _ = _best_from_pool(audio_emb, "visuals", topk=2)
    parts = [
        subject,
        ", ".join(moods),
        style,
        ", ".join(visuals),
        "high quality, detailed, 512x512",
    ]
    return ", ".join(parts)


def embed_audio_bytes(raw_bytes: bytes):
    """Embed encoded audio with CLAP.

    The audio is decoded, down-mixed to mono and resampled to 48 kHz for CLAP.

    Returns:
        (embedding, (native_sr, mono_waveform)). The embedding is the first row
        of CLAP's audio features as a NumPy array. The second element is the
        un-resampled audio, for Gradio playback.
    """
    waveform, native_sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
    if waveform.ndim > 1:
        waveform = waveform.mean(axis=1)

    clap_sr = 48000
    if native_sr == clap_sr:
        clap_input = waveform
    else:
        common = math.gcd(clap_sr, native_sr)
        clap_input = resample_poly(waveform, clap_sr // common, native_sr // common)

    batch = clap_processor(audios=clap_input, sampling_rate=clap_sr, return_tensors="pt")
    with torch.no_grad():
        features = clap_model.get_audio_features(
            **{name: tensor.to(DEVICE) for name, tensor in batch.items()}
        )
    return features[0].cpu().numpy(), (native_sr, waveform)


# =========================
# NEIGHBOR INDEX + MATCHING
# =========================
# Cosine-distance index for the real artwork embeddings.
nn_real = NearestNeighbors(
    n_neighbors=3,
    metric="cosine",
)
nn_real.fit(real_embeddings)

import contextlib
import logging

_log = logging.getLogger(__name__)


def recommend_top3_real_plus_best_synth(audio_emb: np.ndarray):
    """Find the closest real artworks (up to 3) and the single closest synthetic image.

    Args:
        audio_emb: 1-D CLAP audio embedding in the same space as the image embeddings.

    Returns:
        (real_results, synth_result). real_results is a list of (path, cosine
        similarity) tuples ordered best-first; synth_result is one
        (path, cosine similarity) tuple.
    """
    query = audio_emb.reshape(1, -1)
    # kneighbors() raises if asked for more neighbours than indexed samples.
    k = min(3, nn_real.n_samples_fit_)
    dists, idxs = nn_real.kneighbors(query, n_neighbors=k)
    # The "cosine" metric returns a distance, i.e. 1 - cosine similarity.
    real_results = [(real_paths[i], float(1.0 - d)) for i, d in zip(idxs[0], dists[0])]

    synth_sims = cosine_similarity(query, synthetic_embeddings)[0]
    j = int(np.argmax(synth_sims))
    synth_result = (synthetic_paths[j], float(synth_sims[j]))
    return real_results, synth_result


def open_images_with_captions(match_list):
    """Load one gallery image per match and build a caption string.

    Args:
        match_list: list of (csv_path_string, score, tag) with tag "Real" or
            "Synthetic". Real paths are mapped to their metadata row index, which
            is assumed to equal the streamed WikiArt index.

    Returns:
        (images, captions) where images has one PIL image per match (a
        placeholder whenever loading fails) and captions is newline-joined.
    """
    images, captions = [], []
    # Build the path -> row-index map once and memoise it on the function object.
    # For duplicate paths, the last row wins.
    if not hasattr(open_images_with_captions, "_real_index_map"):
        open_images_with_captions._real_index_map = {p: i for i, p in enumerate(real_paths)}
    idx_map = open_images_with_captions._real_index_map

    for p, score, tag in match_list:
        img = None
        base = os.path.basename(str(p))
        try:
            if tag == "Real":
                i = idx_map.get(p)
                if i is not None:
                    img = _get_wikiart_image_by_index(i)
            elif tag == "Synthetic":
                # Try to load from a repo file path, e.g. "synthetic_images/synthetic_08.jpg".
                if isinstance(p, str) and os.path.exists(p):
                    img = _open_local_image(p)
        except Exception:
            # Streaming/decoding failures must not take down the whole request.
            _log.warning("Could not load %s image %r; using placeholder", tag, p, exc_info=True)
            img = None
        if img is None:
            img = _placeholder_img(f"{tag}\n{base}")
        images.append(img)
        captions.append(f"{tag}: {base} (sim {score:.2f})")
    return images, "\n".join(captions)


# =========================
# PIPELINE
# =========================
def _get_sd_pipe():
    """Lazily load and cache the text-to-image pipeline.

    Loads SDXL-Turbo in fp16 on CUDA when available, otherwise the lighter
    SD-Turbo in fp32 on CPU.

    Returns:
        The pipeline, or None if loading failed. The error is logged and the next
        call retries.
    """
    global _SD_PIPE
    if _SD_PIPE is not None:
        return _SD_PIPE
    try:
        if torch.cuda.is_available():
            # Fast path (recommended).
            _SD_PIPE = AutoPipelineForText2Image.from_pretrained(
                "stabilityai/sdxl-turbo",
                torch_dtype=torch.float16,
                use_safetensors=True,
            ).to("cuda")
        else:
            # CPU fallback: smaller model + few steps to avoid timeouts/OOM.
            _SD_PIPE = AutoPipelineForText2Image.from_pretrained(
                "stabilityai/sd-turbo",  # lighter than SDXL
                torch_dtype=torch.float32,
                use_safetensors=True,
            ).to("cpu")
    except Exception:
        _log.exception("Failed to load the text-to-image pipeline")
        return None
    return _SD_PIPE


def generate_image_from_last_audio():
    """Generate an image from a prompt derived from the last processed audio embedding.

    Returns:
        A PIL image, or None (after showing a Gradio warning) when no audio has
        been processed yet or the generator could not be loaded.
    """
    if LAST_AUDIO_EMB is None:
        gr.Warning("Load or upload audio first, then try again.")
        return None
    prompt = build_prompt_from_audio(LAST_AUDIO_EMB)
    pipe = _get_sd_pipe()
    if pipe is None:
        gr.Warning("Image generator not available.")
        return None

    is_cuda = pipe.device.type == "cuda"
    autocast_ctx = torch.autocast("cuda") if is_cuda else contextlib.nullcontext()
    # Small settings on CPU to keep it reasonable.
    steps = 4 if is_cuda else 2
    w = h = 512 if is_cuda else 256
    guidance = 0.0  # Turbo models like low guidance
    with autocast_ctx, torch.inference_mode():
        img = pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
            width=w,
            height=h,
        ).images[0]
    if not is_cuda:
        gr.Info("Generated on CPU (reduced size/steps). For fast 512×512, enable a GPU runtime.")
    return img


def run_example(i: int):
    """Run the matching pipeline on preset *i* (full clip, no trimming) and enable generation."""
    raw = get_preset_audio_bytes(i)
    playable, images, caps = run_pipeline_from_bytes(raw)
    return playable, images, caps, gr.update(interactive=True)


def run_pipeline_from_bytes(raw_bytes: bytes):
    """Embed audio, remember the embedding, and return (playable, images, captions).

    Side effect: sets the module-level LAST_AUDIO_EMB used by
    generate_image_from_last_audio().
    NOTE(review): LAST_AUDIO_EMB is a single global, so concurrent users of one
    running app overwrite each other's embedding; per-session gr.State would
    isolate them.
    """
    global LAST_AUDIO_EMB
    if raw_bytes is None:
        return None, [], ""
    audio_emb, playable = embed_audio_bytes(raw_bytes)
    LAST_AUDIO_EMB = audio_emb.copy()

    real3, synth1 = recommend_top3_real_plus_best_synth(audio_emb)
    combined = [(p, s, "Real") for (p, s) in real3] + [(synth1[0], synth1[1], "Synthetic")]
    images, captions_str = open_images_with_captions(combined)

    # Free large intermediates before returning. Previously this cleanup sat after
    # an early return and never ran.
    del audio_emb, real3, synth1, combined
    gc.collect()
    return playable, images, captions_str


# =========================
# UI
# =========================
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🎧 EchoArt\nMatch your music to visual art — Top 3 real artworks + 1 AI‑generated image.")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("**Upload audio (WAV/MP3) or try a preset:**")
            audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
            start_in = gr.Number(label="Start at (seconds, optional)", value=None, precision=2)
            with gr.Row():
                ex_labels = [f"Example {k+1} (idx {PRESET_ORDINALS[k]})" for k in range(5)]
                ex_buttons = [gr.Button(lbl) for lbl in ex_labels]
            submit = gr.Button("Find Matching Art", variant="primary")
        with gr.Column(scale=2):
            audio_out = gr.Audio(label="Your Audio", interactive=False)
            gallery = gr.Gallery(label="Top 3 Real + 1 Synthetic", columns=4, height=320)
            captions = gr.Textbox(label="Matches", lines=4)

    # --- AI image generation row (button is disabled until audio is processed) ---
    with gr.Row():
        gen_btn = gr.Button("🎨 Generate image from this audio", interactive=False)
        gen_img = gr.Image(label="AI Image", interactive=False)

    # -----------------------
    # Handlers (must be inside Blocks)
    # -----------------------
    def on_submit(audio_path, start_sec):
        """Trim the uploaded file to 30 s, run matching, and enable the generate button."""
        if not audio_path:
            # Keep the generate button disabled.
            return None, [], "", gr.update(interactive=False)
        with open(audio_path, "rb") as f:
            raw = f.read()
        # Trim uploads to 30 seconds (start_sec may be None).
        raw = slice_audio_to_wav_bytes(raw, start_sec, duration_sec=30.0)
        playable, images, caps = run_pipeline_from_bytes(raw)
        # Enable the generate button now that LAST_AUDIO_EMB is set.
        return playable, images, caps, gr.update(interactive=True)

    submit.click(
        on_submit,
        inputs=[audio_in, start_in],
        outputs=[audio_out, gallery, captions, gen_btn],
    )

    # Wire example buttons (freeze the loop variable with k=k).
    for k, btn in enumerate(ex_buttons):
        btn.click(lambda k=k: run_example(k), inputs=None, outputs=[audio_out, gallery, captions, gen_btn])

    def on_generate():
        """Button handler: generate an image from the last processed audio."""
        img = generate_image_from_last_audio()
        return img

    gen_btn.click(on_generate, inputs=None, outputs=[gen_img])

# Outside the Blocks context:
demo.launch()