# EchoArt / app.py
# Tomershor1's picture
# Update app.py
# 7ed2000 verified
import os, io, textwrap
import gradio as gr
import numpy as np
import pandas as pd
import soundfile as sf
from PIL import Image, ImageDraw, ImageFont
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
import torch
torch.set_num_threads(1)
from transformers import ClapProcessor, ClapModel
from datasets import load_dataset, Audio
import requests
from itertools import islice
from scipy.signal import resample_poly
import math
import gc
from diffusers import AutoPipelineForText2Image # NEW
# Module-level state shared between the Gradio handlers.
LAST_AUDIO_EMB = None # most recent audio embedding; written by run_pipeline_from_bytes(), read by generate_image_from_last_audio()
_SD_PIPE = None # text-to-image pipeline, created on first use by _get_sd_pipe()
# =========================
# CONFIG
# =========================
# Real (WikiArt) embeddings + metadata; both files are loaded at import time.
REAL_EMB_PATH = "wikiart_embeddings.npy"
REAL_META_PATH = "wikiart_metadata.csv"
# Synthetic (optional; placeholders are used if these cannot be loaded)
SYNTH_EMB_PATH = "synthetic_image_embeddings.npy"
SYNTH_META_PATH = "synthetic_image_metadata.csv"
# Data sources (remote)
WIKIART_DATASET = "huggan/wikiart" # images are streamed and addressed by their position in the stream
WIKIART_COUNT = 1500 # upper bound on the streamed index (combined with N_REAL via min()); NOTE(review): an older comment said 500 — confirm this matches the embeddings row count
GTZAN_DATASET = "sanchit-gandhi/gtzan" # audio dataset used for the preset examples
PRESET_ORDINALS = [1, 102, 203, 304, 405] # GTZAN presets by ordinal (1-based), one per example button
# =========================
# LOAD EMBEDDINGS/METADATA
# =========================
# Real (WikiArt) data is mandatory: if either file is missing or malformed the
# exception propagates and the app fails at import time.
real_embeddings = np.load(REAL_EMB_PATH)
real_meta = pd.read_csv(REAL_META_PATH)
real_paths = real_meta["path"].astype(str).tolist()
N_REAL = len(real_paths)

# Synthetic data is optional. Any failure (missing file, unreadable CSV, no
# "path" column, ...) switches to ten all-zero placeholder embeddings with the
# same width as the real ones. The fallback stays deliberately broad, but the
# reason is now reported instead of being swallowed silently.
try:
    synthetic_embeddings = np.load(SYNTH_EMB_PATH)
    synthetic_meta = pd.read_csv(SYNTH_META_PATH)
    synthetic_paths = synthetic_meta["path"].astype(str).tolist()
except Exception as _synth_err:
    print(f"[EchoArt] synthetic data unavailable ({_synth_err!r}); using placeholders")
    synthetic_embeddings = np.zeros((10, real_embeddings.shape[1]), dtype=np.float32)
    synthetic_paths = [f"synthetic_placeholder_{i}.png" for i in range(10)]
    # Keep the name defined on both paths so later code can rely on it.
    synthetic_meta = pd.DataFrame({"path": synthetic_paths})
# =========================
# HELPERS
# =========================
# Base URL of the Hugging Face datasets site.
# NOTE(review): not referenced anywhere else in the visible part of this file.
HF_DATASET_RESOLVE_BASE = "https://huggingface.co/datasets"
def _placeholder_img(text: str, size=(512, 512)) -> Image.Image:
    """Return a light-grey image with *text* centred on it.

    Used as a stand-in when a real or synthetic picture cannot be loaded.

    Args:
        text: Caption to draw. Explicit line breaks are kept; every line is
            additionally wrapped at 28 characters.
        size: ``(width, height)`` of the canvas in pixels.

    Returns:
        A new RGB ``PIL.Image.Image``. Empty text yields the blank canvas.
    """
    img = Image.new("RGB", size, (240, 240, 240))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        # Best effort: with font=None the drawing calls pick their own default.
        font = None

    # textwrap.wrap() turns "\n" into a plain space, which used to merge the
    # caller's "tag\nfilename" into one line. Wrap each explicit line on its own.
    lines = []
    for raw_line in text.splitlines():
        lines.extend(textwrap.wrap(raw_line, width=28))
    if not lines:
        return img

    gap = 6  # vertical pixels between consecutive lines
    boxes = [draw.textbbox((0, 0), ln, font=font) for ln in lines]
    heights = [box[3] - box[1] for box in boxes]
    total_h = sum(heights) + gap * (len(lines) - 1)

    y = (size[1] - total_h) // 2
    for ln, box, line_h in zip(lines, boxes, heights):
        line_w = box[2] - box[0]
        x = (size[0] - line_w) // 2
        draw.text((x, y), ln, fill=(30, 30, 30), font=font)
        y += line_h + gap
    return img
def slice_audio_to_wav_bytes(raw_bytes: bytes, start_sec: float | None, duration_sec: float = 30.0) -> bytes:
    """Trim to [start_sec, start_sec+duration_sec) and return WAV bytes.

    Args:
        raw_bytes: Encoded audio in any format ``soundfile`` can decode.
        start_sec: Requested start offset in seconds. ``None``, NaN,
            unparsable or negative values are treated as 0.
        duration_sec: Maximum length of the returned excerpt in seconds.

    Returns:
        The mono excerpt re-encoded as WAV, at the source sampling rate.

    Start clamping:
        * Clip longer than ``duration_sec``: the start is pulled back so a
          full-length window always fits.
        * Clip not longer than ``duration_sec``: a start inside the clip is
          honoured; a start at or past the end falls back to 0. Previously
          that case produced an empty WAV, and an infinite start raised
          ``OverflowError``.
    """
    # Decode; average the second axis when the result is multi-dimensional.
    data, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
    if data.ndim > 1:
        data = data.mean(axis=1)
    total_sec = len(data) / sr

    # Parse the requested start defensively (the UI field may be blank).
    try:
        s = float(start_sec) if start_sec is not None else 0.0
    except (TypeError, ValueError, OverflowError):
        s = 0.0
    if math.isnan(s):
        s = 0.0
    s = max(0.0, s)

    if total_sec > duration_sec:
        # Also maps +inf onto the last full window.
        s = min(s, total_sec - duration_sec)
    elif s >= total_sec:
        # Start lies beyond a short clip: return the whole clip, not nothing.
        s = 0.0

    start_idx = int(round(s * sr))
    end_idx = min(len(data), start_idx + int(round(duration_sec * sr)))
    seg = data[start_idx:end_idx]

    # Re-encode to WAV bytes.
    buf = io.BytesIO()
    sf.write(buf, seg, sr, format="WAV")
    return buf.getvalue()
def _open_local_image(path: str):
    """Load an image file as RGB, shrunk to fit 512x512, or return ``None``.

    Args:
        path: Filesystem path of the image.

    Returns:
        A ``PIL.Image.Image`` in RGB mode, or ``None`` when the file cannot be
        opened or decoded. The caller substitutes a placeholder in that case.
    """
    try:
        # The context manager closes the file handle deterministically;
        # convert() yields an independent in-memory image before that happens.
        with Image.open(path) as src:
            img = src.convert("RGB")
        # Keep memory down for the gallery (thumbnail() resizes in place).
        img.thumbnail((512, 512))
        return img
    except Exception:
        # Deliberate best-effort: any load failure means "no image".
        return None
# =========================
# WIKIART IMAGES (STREAMED BY ORDER, LIMITED TO THE FIRST WIKIART_COUNT ITEMS)
# =========================
from functools import lru_cache
@lru_cache(maxsize=64)  # remember the last few lookups so repeated hits skip the stream
def _get_wikiart_image_by_index(idx: int):
    """Return the WikiArt image at stream position *idx*, or ``None``.

    ``None`` is returned when *idx* is ``None``, negative, at or beyond
    ``min(WIKIART_COUNT, N_REAL)``, or when the stream ends before reaching
    it. Results are memoised, so callers share the returned image object.
    """
    if idx is None:
        return None
    if idx < 0 or idx >= min(WIKIART_COUNT, N_REAL):
        return None

    # Open a fresh streaming view and advance to the requested record only.
    records = load_dataset(WIKIART_DATASET, split="train", streaming=True)
    record = next(islice(records, idx, idx + 1), None)
    if record is None:
        return None

    picture = record["image"]
    if hasattr(picture, "convert"):
        picture = picture.convert("RGB")
    try:
        # Shrink in place so the gallery does not hold full-size images.
        picture.thumbnail((512, 512))
    except Exception:
        pass
    return picture
# =========================
# GTZAN AUDIO (NON-STREAMING, INDEXABLE) — keeps dataset = sanchit-gandhi/gtzan
# =========================
_gtzan_ds = None  # populated by the first _load_gtzan() call


def _load_gtzan():
    """Return the GTZAN train split, loading it on the first call only.

    The split is loaded non-streaming so records can be addressed by
    position. If the "audio" column is missing from the features or is not an
    ``Audio`` feature, it is cast to ``Audio(decode=True)``.
    """
    global _gtzan_ds
    if _gtzan_ds is not None:
        return _gtzan_ds

    dataset = load_dataset(GTZAN_DATASET, split="train")
    has_audio_feature = "audio" in dataset.features and isinstance(dataset.features["audio"], Audio)
    if not has_audio_feature:
        dataset = dataset.cast_column("audio", Audio(decode=True))
    _gtzan_ds = dataset
    return _gtzan_ds
def get_gtzan_audio_by_ordinal(n: int):
    """Fetch one GTZAN clip by its 1-based position in the train split.

    Args:
        n: 1-based ordinal of the record.

    Returns:
        ``((sampling_rate, waveform), wav_bytes)``. ``waveform`` is the decoded
        array as float32; ``wav_bytes`` is the same audio encoded as WAV so
        the rest of the pipeline can treat presets like uploads.

    Raises:
        gr.Error: if *n* is below 1 or beyond the end of the dataset.
    """
    if n < 1:
        raise gr.Error("Index must be ≥ 1")
    ds = _load_gtzan()
    idx = n - 1
    if idx >= len(ds):
        raise gr.Error(f"No record at index {n} (dataset size: {len(ds)}).")

    # The unused genre/label lookup was dropped; only the audio is needed here.
    audio = ds[idx]["audio"]  # mapping with 'array' and 'sampling_rate'
    arr = audio["array"]
    sr = audio["sampling_rate"]

    # Ensure mono for CLAP.
    # NOTE(review): assumes a 2-D array is (samples, channels) — confirm.
    if arr.ndim > 1:
        arr = arr.mean(axis=1)

    # Encode to WAV bytes so presets and uploads share one code path.
    buf = io.BytesIO()
    sf.write(buf, arr, sr, format="WAV")
    raw = buf.getvalue()
    return (sr, arr.astype(np.float32)), raw
def get_preset_audio_bytes(i: int) -> bytes:
    """Return the WAV bytes of preset number *i* (index into PRESET_ORDINALS)."""
    ordinal = PRESET_ORDINALS[i]
    _playable, wav_bytes = get_gtzan_audio_by_ordinal(ordinal)
    return wav_bytes
# =========================
# CLAP MODEL
# =========================
# Device that the CLAP model and its inputs are moved to.
DEVICE = "cpu"
# Processor and model come from the same checkpoint and are loaded once at import time.
clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused").to(DEVICE)
# Inference only: put the module in eval mode.
clap_model.eval()
# Candidate phrases, grouped into pools. For each pool, build_prompt_from_audio()
# picks the phrase(s) whose CLAP text embedding is most similar to the audio
# embedding. The four keys are looked up by name there; the lists can be extended freely.
CANDS = {
"subject": [
"jazz band", "symphony orchestra", "solo piano", "acoustic guitar",
"electric guitar", "violin", "hip-hop rapper", "DJ in a nightclub",
"techno producer", "ambient soundscape", "drum and bass", "punk band",
"choir", "saxophone solo", "lofi beats"
],
"mood": [
"moody", "melancholic", "energetic", "uplifting", "intense",
"dreamy", "peaceful", "dark", "nostalgic", "aggressive", "euphoric"
],
"style": [
"surreal", "impressionist", "cubist", "abstract", "street art",
"photorealistic", "digital painting", "watercolor", "oil painting",
"futuristic", "minimalist", "psychedelic", "retro"
],
"visuals": [
"neon lights", "soft gradients", "high contrast", "film grain",
"bokeh", "long exposure", "smoky bar", "concert stage",
"urban graffiti", "fractals", "misty atmosphere", "warm window light"
]
}
# Flatten every pool into one phrase list, remembering which slice of that list
# belongs to which pool, then embed all phrases in a single CLAP text pass.
ALL_PHRASES = []
POOL_SLICES = {}
_start = 0
for _k, _lst in CANDS.items():
    _stop = _start + len(_lst)
    POOL_SLICES[_k] = slice(_start, _stop)
    ALL_PHRASES += _lst
    _start = _stop

with torch.no_grad():
    _ti = clap_processor(text=ALL_PHRASES, return_tensors="pt", padding=True)
    _ti = {name: tensor.to(DEVICE) for name, tensor in _ti.items()}
    # One row per phrase, L2-normalised so a dot product is a cosine similarity.
    _ALL_EMB = torch.nn.functional.normalize(
        clap_model.get_text_features(**_ti), dim=-1
    ).cpu().numpy()
def _best_from_pool(audio_emb: np.ndarray, pool_key: str, topk: int = 1):
    """Rank the phrases of one pool against an audio embedding.

    Returns ``(phrases, scores)`` for the *topk* best matches, best first.
    Scores are dot products between the unit-normalised audio embedding and
    the pre-normalised phrase embeddings.
    """
    unit_audio = audio_emb / (np.linalg.norm(audio_emb) + 1e-12)  # epsilon avoids 0/0
    pool_scores = _ALL_EMB[POOL_SLICES[pool_key]] @ unit_audio
    ranking = np.argsort(-pool_scores)[:topk]
    pool_phrases = CANDS[pool_key]
    best_phrases = [pool_phrases[j] for j in ranking]
    best_scores = [float(pool_scores[j]) for j in ranking]
    return best_phrases, best_scores
def build_prompt_from_audio(audio_emb: np.ndarray) -> str:
    """Compose a text-to-image prompt from the phrases best matching the audio.

    The prompt is: one subject, two moods, one style, two visual cues and a
    fixed quality directive, joined by ", ".
    """
    subject = _best_from_pool(audio_emb, "subject", topk=1)[0]
    mood_pair = _best_from_pool(audio_emb, "mood", topk=2)[0]
    style = _best_from_pool(audio_emb, "style", topk=1)[0]
    visual_pair = _best_from_pool(audio_emb, "visuals", topk=2)[0]
    pieces = [
        subject[0],
        ", ".join(mood_pair),
        style[0],
        ", ".join(visual_pair),
        "high quality, detailed, 512x512",
    ]
    return ", ".join(pieces)
def embed_audio_bytes(raw_bytes: bytes):
    """Decode audio bytes and embed them with CLAP.

    Returns ``(embedding, (sampling_rate, waveform))``. The embedding is a
    NumPy vector; the second item is the original-rate mono audio, kept for
    playback in the UI.
    """
    wave, wave_sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
    if wave.ndim > 1:
        wave = wave.mean(axis=1)

    # The model is fed 48 kHz audio; resample by the reduced up/down ratio.
    clap_sr = 48000
    if wave_sr == clap_sr:
        model_input = wave
    else:
        common = math.gcd(clap_sr, wave_sr)
        model_input = resample_poly(wave, clap_sr // common, wave_sr // common)

    batch = clap_processor(audios=model_input, sampling_rate=clap_sr, return_tensors="pt")
    with torch.no_grad():
        on_device = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
        feats = clap_model.get_audio_features(**on_device)

    embedding = feats[0].cpu().numpy()
    return embedding, (wave_sr, wave)
# =========================
# NEIGHBOR INDEX + MATCHING
# =========================
# Cosine-distance neighbour index over the real (WikiArt) embeddings, built once at import time.
# recommend_top3_real_plus_best_synth() turns the returned distances into similarities via 1 - d.
nn_real = NearestNeighbors(n_neighbors=3, metric="cosine")
nn_real.fit(real_embeddings)
def recommend_top3_real_plus_best_synth(audio_emb: np.ndarray):
    """Match an audio embedding against the real and synthetic image sets.

    Returns ``(real_results, synth_result)``: a list of three
    ``(path, similarity)`` pairs for the nearest real images, and one
    ``(path, similarity)`` pair for the most similar synthetic image.
    """
    query = audio_emb.reshape(1, -1)

    # Three nearest real images; cosine distance -> similarity.
    distances, rows = nn_real.kneighbors(query, n_neighbors=3)
    real_results = []
    for row, dist in zip(rows[0], distances[0]):
        real_results.append((real_paths[row], 1 - dist))

    # Single best synthetic image by cosine similarity.
    sims = cosine_similarity(query, synthetic_embeddings)[0]
    best = int(np.argmax(sims))
    synth_result = (synthetic_paths[best], float(sims[best]))
    return real_results, synth_result
def open_images_with_captions(match_list):
    """Resolve matches to images and build their caption text.

    match_list: list of (csv_path_string, score, tag) with tag "Real" or
    "Synthetic". Real entries rely on CSV row index == WikiArt stream index.

    Returns ``(images, captions)`` where ``captions`` holds one line per match.
    Any match whose image cannot be obtained gets a placeholder image.
    """
    # Path -> row index lookup, built on first use and kept on the function.
    try:
        row_of = open_images_with_captions._real_index_map
    except AttributeError:
        row_of = {path: pos for pos, path in enumerate(real_paths)}
        open_images_with_captions._real_index_map = row_of

    pictures = []
    caption_lines = []
    for path, score, tag in match_list:
        name = os.path.basename(str(path))
        picture = None
        if tag == "Real":
            row = row_of.get(path)
            if row is not None:
                picture = _get_wikiart_image_by_index(row)
        elif tag == "Synthetic":
            # Synthetic images live in the repo, e.g. "synthetic_images/synthetic_08.jpg".
            if isinstance(path, str) and os.path.exists(path):
                picture = _open_local_image(path)
        if picture is None:
            picture = _placeholder_img(f"{tag}\n{name}")
        pictures.append(picture)
        caption_lines.append(f"{tag}: {name} (sim {score:.2f})")
    return pictures, "\n".join(caption_lines)
# =========================
# PIPELINE
# =========================
def _get_sd_pipe():
    """Return the shared text-to-image pipeline, creating it on first use.

    With CUDA available the SDXL-Turbo checkpoint is loaded in float16 on the
    GPU; otherwise the lighter SD-Turbo checkpoint is loaded in float32 on CPU.
    """
    global _SD_PIPE
    if _SD_PIPE is None:
        if torch.cuda.is_available():
            checkpoint, dtype, device = "stabilityai/sdxl-turbo", torch.float16, "cuda"
        else:
            # CPU fallback: smaller model to limit time and memory.
            checkpoint, dtype, device = "stabilityai/sd-turbo", torch.float32, "cpu"
        _SD_PIPE = AutoPipelineForText2Image.from_pretrained(
            checkpoint,
            torch_dtype=dtype,
            use_safetensors=True,
        ).to(device)
    return _SD_PIPE
import contextlib
def generate_image_from_last_audio():
    """Generate an image from the most recently processed audio.

    Builds a prompt from ``LAST_AUDIO_EMB`` and runs the text-to-image
    pipeline. Returns the generated image, or ``None`` (after a UI warning)
    when no audio has been processed yet or no pipeline is available.
    """
    if LAST_AUDIO_EMB is None:
        gr.Warning("Load or upload audio first, then try again.")
        return None

    prompt = build_prompt_from_audio(LAST_AUDIO_EMB)
    pipe = _get_sd_pipe()
    if pipe is None:
        gr.Warning("Image generator not available.")
        return None

    on_gpu = pipe.device.type == "cuda"
    if on_gpu:
        amp_ctx, steps, side = torch.autocast("cuda"), 4, 512
    else:
        # Reduced settings keep CPU generation time reasonable.
        amp_ctx, steps, side = contextlib.nullcontext(), 2, 256

    with amp_ctx, torch.inference_mode():
        result = pipe(
            prompt,
            num_inference_steps=steps,
            guidance_scale=0.0,  # guidance is disabled for the turbo checkpoints
            width=side,
            height=side,
        )
        image = result.images[0]

    if not on_gpu:
        gr.Info("Generated on CPU (reduced size/steps). For fast 512×512, enable a GPU runtime.")
    return image
def run_example(i: int):
    """Run the pipeline on preset *i* (the whole clip, untrimmed).

    Returns the three pipeline outputs plus an update that enables the
    "generate" button.
    """
    preset_bytes = get_preset_audio_bytes(i)
    playable, images, caps = run_pipeline_from_bytes(preset_bytes)
    enable_generate = gr.update(interactive=True)
    return playable, images, caps, enable_generate
def run_pipeline_from_bytes(raw_bytes: bytes):
    """Embed an audio clip and find its best-matching images.

    Args:
        raw_bytes: Encoded audio, or ``None`` when there is nothing to process.

    Returns:
        ``(playable, images, captions)``: the ``(sampling_rate, waveform)``
        pair for playback, the list of matched images (three real + one
        synthetic), and a newline-separated caption string. ``(None, [], "")``
        is returned for ``None`` input.

    Side effects:
        Stores a copy of the audio embedding in ``LAST_AUDIO_EMB`` so the
        image generator can use it later.
    """
    global LAST_AUDIO_EMB
    if raw_bytes is None:
        return None, [], ""

    audio_emb, playable = embed_audio_bytes(raw_bytes)
    LAST_AUDIO_EMB = audio_emb.copy()

    real3, synth1 = recommend_top3_real_plus_best_synth(audio_emb)
    combined = [(p, s, "Real") for (p, s) in real3] + [(synth1[0], synth1[1], "Synthetic")]
    images, captions_str = open_images_with_captions(combined)
    # The old del/gc.collect() cleanup and second return sat after this
    # return and could never run; they are removed. Behaviour is unchanged.
    return playable, images, captions_str
# =========================
# UI
# =========================
# UI layout. Every component created here is referenced by the handler wiring further down.
# NOTE(review): indentation was lost in this copy of the file, so the exact nesting of the
# Row/Column containers cannot be read from the text — restore it from the original before running.
with gr.Blocks(theme="soft") as demo:
gr.Markdown("# 🎧 EchoArt\nMatch your music to visual art — Top 3 real artworks + 1 AI‑generated image.")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("**Upload audio (WAV/MP3) or try a preset:**")
# type="filepath": the submit handler receives a path and reads the file itself.
audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
start_in = gr.Number(label="Start at (seconds, optional)", value=None, precision=2)
with gr.Row():
# One button per preset. NOTE(review): range(5) is hard-coded and must stay in step with len(PRESET_ORDINALS).
ex_labels = [f"Example {k+1} (idx {PRESET_ORDINALS[k]})" for k in range(5)]
ex_buttons = [gr.Button(lbl) for lbl in ex_labels]
submit = gr.Button("Find Matching Art", variant="primary")
with gr.Column(scale=2):
audio_out = gr.Audio(label="Your Audio", interactive=False)
gallery = gr.Gallery(label="Top 3 Real + 1 Synthetic", columns=4, height=320)
captions = gr.Textbox(label="Matches", lines=4)
# --- AI image generation row (button starts disabled; the audio handlers enable it) ---
with gr.Row():
gen_btn = gr.Button("🎨 Generate image from this audio", interactive=False)
gen_img = gr.Image(label="AI Image", interactive=False)
# -----------------------
# Handlers (must be inside Blocks)
# -----------------------
def on_submit(audio_path, start_sec):
    """Handle "Find Matching Art": trim the uploaded clip and run the pipeline.

    Returns ``(playable, images, captions, generate_button_update)``. Without
    an audio path the outputs are cleared and the generate button stays
    disabled.
    """
    if audio_path:
        with open(audio_path, "rb") as fh:
            uploaded = fh.read()
        # Uploads are cut to at most 30 seconds (start_sec may be None).
        clip = slice_audio_to_wav_bytes(uploaded, start_sec, duration_sec=30.0)
        playable, images, caps = run_pipeline_from_bytes(clip)
        # LAST_AUDIO_EMB is set now, so image generation can be offered.
        return playable, images, caps, gr.update(interactive=True)
    return None, [], "", gr.update(interactive=False)
# Main button: process the uploaded/recorded clip; the fourth output toggles the generate button.
submit.click(
    on_submit,
    inputs=[audio_in, start_in],
    outputs=[audio_out, gallery, captions, gen_btn],
)

# Example buttons: the default argument freezes the loop variable per button.
for k, btn in enumerate(ex_buttons):
    btn.click(
        lambda k=k: run_example(k),
        inputs=None,
        outputs=[audio_out, gallery, captions, gen_btn],
    )
def on_generate():
    """Handle the generate button: return the image for the last processed audio."""
    return generate_image_from_last_audio()

gen_btn.click(on_generate, inputs=None, outputs=[gen_img])
# Outside the Blocks context: start the Gradio app.
# NOTE(review): this runs unconditionally whenever the module is executed or imported
# (there is no `if __name__ == "__main__":` guard).
demo.launch()