import os, io, textwrap
import gradio as gr
import numpy as np
import pandas as pd
import soundfile as sf
from PIL import Image, ImageDraw, ImageFont
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
import torch
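# Limit PyTorch to a single CPU thread (likely to keep CPU/memory use predictable on shared Space hardware)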
torch.set_num_threads(1)
from transformers import ClapProcessor, ClapModel
from datasets import load_dataset, Audio
import requests
from itertools import islice
from scipy.signal import resample_poly
import math
import gc
from diffusers import AutoPipelineForText2Image  # text-to-image pipeline, loaded lazily in _get_sd_pipe()
LAST_AUDIO_EMB = None  # most recent CLAP audio embedding, set by run_pipeline_from_bytes()
_SD_PIPE = None  # cached Stable Diffusion pipeline, created on first use
# =========================
# CONFIG
# =========================
# Real (WikiArt) embeddings + metadata (precomputed; must be present in the repo)
REAL_EMB_PATH = "wikiart_embeddings.npy"
REAL_META_PATH = "wikiart_metadata.csv"
# Synthetic (optional; placeholders used if missing)
SYNTH_EMB_PATH = "synthetic_image_embeddings.npy"
SYNTH_META_PATH = "synthetic_image_metadata.csv"
# Data sources (remote)
WIKIART_DATASET = "huggan/wikiart" # images streamed by order
WIKIART_COUNT = 1500 # only the first WIKIART_COUNT streamed images correspond to the precomputed embeddings
GTZAN_DATASET = "sanchit-gandhi/gtzan" # GTZAN clips, used for the preset example buttons
PRESET_ORDINALS = [1, 102, 203, 304, 405] # GTZAN presets by ordinal (1-based)
# =========================
# LOAD EMBEDDINGS/METADATA
# =========================
real_embeddings = np.load(REAL_EMB_PATH)
real_meta = pd.read_csv(REAL_META_PATH)
real_paths = real_meta["path"].astype(str).tolist()
N_REAL = len(real_paths)
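# Row i of real_embeddings corresponds to row i of real_meta and to the i-th streamed WikiArt image
# (the same ordering assumption used later in open_images_with_captions).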
# Synthetic (optional): fall back to placeholders if not present
try:
synthetic_embeddings = np.load(SYNTH_EMB_PATH)
synthetic_meta = pd.read_csv(SYNTH_META_PATH)
synthetic_paths = synthetic_meta["path"].astype(str).tolist()
except Exception:
synthetic_embeddings = np.zeros((10, real_embeddings.shape[1]), dtype=np.float32)
synthetic_paths = [f"synthetic_placeholder_{i}.png" for i in range(10)]
# =========================
# HELPERS
# =========================
HF_DATASET_RESOLVE_BASE = "https://huggingface.co/datasets"
def _placeholder_img(text: str, size=(512, 512)) -> Image.Image:
img = Image.new("RGB", size, (240, 240, 240))
draw = ImageDraw.Draw(img)
try:
font = ImageFont.load_default()
except Exception:
font = None
lines = textwrap.wrap(text, width=28)
line_boxes = [draw.textbbox((0, 0), ln, font=font) for ln in lines]
total_h = sum(b[3]-b[1] for b in line_boxes) + (len(lines)-1)*6
y = (size[1]-total_h)//2
for ln, bb in zip(lines, line_boxes):
w = bb[2]-bb[0]; h = bb[3]-bb[1]
x = (size[0]-w)//2
draw.text((x, y), ln, fill=(30,30,30), font=font)
y += h + 6
return img
def slice_audio_to_wav_bytes(raw_bytes: bytes, start_sec: float | None, duration_sec: float = 30.0) -> bytes:
"""Trim to [start_sec, start_sec+duration_sec) and return WAV bytes."""
# decode
data, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
if data.ndim > 1:
data = data.mean(axis=1)
total_sec = len(data) / sr
# default start=0 if None/blank/NaN
try:
s = float(start_sec) if start_sec is not None else 0.0
except Exception:
s = 0.0
s = max(0.0, s)
# clamp start so we always get up to 30s if possible
if total_sec > duration_sec and s + duration_sec > total_sec:
s = max(0.0, total_sec - duration_sec)
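    # e.g. a 95 s file requested at start=80 s is clamped to start at 65 s so a full 30 s window still fits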
start_idx = int(round(s * sr))
end_idx = min(len(data), start_idx + int(round(duration_sec * sr)))
seg = data[start_idx:end_idx]
# re‑encode to WAV bytes
buf = io.BytesIO()
sf.write(buf, seg, sr, format="WAV")
return buf.getvalue()
def _open_local_image(path: str):
try:
img = Image.open(path).convert("RGB")
# keep memory down for the gallery
img.thumbnail((512, 512))
return img
except Exception:
return None
# =========================
# WIKIART IMAGES (STREAMED IN ORDER, UP TO WIKIART_COUNT)
# =========================
from functools import lru_cache
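# Each cached item is a <=512x512 RGB thumbnail (~0.75 MB uncompressed), so 64 entries stay modest in RAM.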
@lru_cache(maxsize=64) # small cache to avoid re-fetching the same few images
def _get_wikiart_image_by_index(idx: int):
if idx is None or idx < 0 or idx >= min(WIKIART_COUNT, N_REAL):
return None
    # Stream and skip ahead to idx; only one image is decoded, but each cache miss re-iterates from the start (O(idx) skips)
stream = load_dataset(WIKIART_DATASET, split="train", streaming=True)
ex = next(islice(stream, idx, idx + 1), None)
if ex is None:
return None
img = ex["image"]
if hasattr(img, "convert"):
img = img.convert("RGB")
# Downsize for gallery to save RAM
try:
img.thumbnail((512, 512))
except Exception:
pass
return img
# =========================
# GTZAN AUDIO (NON-STREAMING, INDEXABLE) — keeps dataset = sanchit-gandhi/gtzan
# =========================
_gtzan_ds = None
def _load_gtzan():
"""
Load once (non-streaming) so we can index reliably by ordinal.
We also ensure the 'audio' column is decoded to array+sr.
"""
global _gtzan_ds
if _gtzan_ds is None:
ds = load_dataset(GTZAN_DATASET, split="train") # local cache, no manual URLs
# If 'audio' is already an Audio feature (decoded), great.
# If not, cast to Audio(decode=True) so ds[i]["audio"] -> {'array', 'sampling_rate'}.
if "audio" not in ds.features or not isinstance(ds.features["audio"], Audio):
ds = ds.cast_column("audio", Audio(decode=True))
_gtzan_ds = ds
return _gtzan_ds
def get_gtzan_audio_by_ordinal(n: int):
"""
1-based index → ((sr, waveform), raw_wav_bytes)
"""
if n < 1:
raise gr.Error("Index must be ≥ 1")
ds = _load_gtzan()
idx = n - 1
if idx >= len(ds):
raise gr.Error(f"No record at index {n} (dataset size: {len(ds)}).")
ex = ds[idx]
# 'genre' in some variants, 'label' in others — support both
genre = ex.get("genre", ex.get("label", ""))
audio = ex["audio"] # {'array': np.ndarray, 'sampling_rate': int}
arr = audio["array"]
sr = audio["sampling_rate"]
# Ensure mono for CLAP
if arr.ndim > 1:
arr = arr.mean(axis=1)
# Create WAV bytes so the rest of your pipeline can reuse uniformly
buf = io.BytesIO()
sf.write(buf, arr, sr, format="WAV")
raw = buf.getvalue()
return (sr, arr.astype(np.float32)), raw
def get_preset_audio_bytes(i: int) -> bytes:
"""Preset buttons: first 5 items by order."""
n = PRESET_ORDINALS[i]
(_sr_data, raw) = get_gtzan_audio_by_ordinal(n)
return raw
# =========================
# CLAP MODEL
# =========================
DEVICE = "cpu"
clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused").to(DEVICE)
clap_model.eval()
# Candidates the model can choose from (expand anytime)
CANDS = {
"subject": [
"jazz band", "symphony orchestra", "solo piano", "acoustic guitar",
"electric guitar", "violin", "hip-hop rapper", "DJ in a nightclub",
"techno producer", "ambient soundscape", "drum and bass", "punk band",
"choir", "saxophone solo", "lofi beats"
],
"mood": [
"moody", "melancholic", "energetic", "uplifting", "intense",
"dreamy", "peaceful", "dark", "nostalgic", "aggressive", "euphoric"
],
"style": [
"surreal", "impressionist", "cubist", "abstract", "street art",
"photorealistic", "digital painting", "watercolor", "oil painting",
"futuristic", "minimalist", "psychedelic", "retro"
],
"visuals": [
"neon lights", "soft gradients", "high contrast", "film grain",
"bokeh", "long exposure", "smoky bar", "concert stage",
"urban graffiti", "fractals", "misty atmosphere", "warm window light"
]
}
# Build a flat list and index slices for each pool, then encode once with CLAP's text tower.
ALL_PHRASES, POOL_SLICES = [], {}
_start = 0
for _k, _lst in CANDS.items():
ALL_PHRASES.extend(_lst)
POOL_SLICES[_k] = slice(_start, _start + len(_lst))
_start += len(_lst)
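# With the pools above this yields 51 phrases: subject -> slice(0, 15), mood -> slice(15, 26),
# style -> slice(26, 39), visuals -> slice(39, 51); _ALL_EMB below is [51, D], row-aligned with ALL_PHRASES.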
with torch.no_grad():
_ti = clap_processor(text=ALL_PHRASES, return_tensors="pt", padding=True)
_ti = {k: v.to(DEVICE) for k, v in _ti.items()}
_ALL_EMB = clap_model.get_text_features(**_ti) # [N, D]
_ALL_EMB = torch.nn.functional.normalize(_ALL_EMB, dim=-1).cpu().numpy()
def _best_from_pool(audio_emb: np.ndarray, pool_key: str, topk: int = 1):
a = audio_emb / (np.linalg.norm(audio_emb) + 1e-12)
sl = POOL_SLICES[pool_key]
sims = (_ALL_EMB[sl] @ a)
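    # both sides are unit-normalized, so these dot products are cosine similarities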
idx = np.argsort(-sims)[:topk]
phrases = [CANDS[pool_key][i] for i in idx]
return phrases, [float(sims[i]) for i in idx]
def build_prompt_from_audio(audio_emb: np.ndarray) -> str:
subj, _ = _best_from_pool(audio_emb, "subject", topk=1)
moods, _ = _best_from_pool(audio_emb, "mood", topk=2)
styles, _ = _best_from_pool(audio_emb, "style", topk=1)
vis, _ = _best_from_pool(audio_emb, "visuals", topk=2)
directive = "high quality, detailed, 512x512"
return f"{subj[0]}, {', '.join(moods)}, {styles[0]}, {', '.join(vis)}, {directive}"
def embed_audio_bytes(raw_bytes: bytes):
# Read original audio for playback
orig, orig_sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
if orig.ndim > 1:
orig = orig.mean(axis=1)
# CLAP needs 48 kHz mono
target_sr = 48000
proc = orig
if orig_sr != target_sr:
g = math.gcd(target_sr, orig_sr)
up = target_sr // g
down = orig_sr // g
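        # e.g. 22050 Hz audio -> gcd=150, up=320, down=147 (22050 * 320 / 147 = 48000)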
proc = resample_poly(orig, up, down)
inputs = clap_processor(audios=proc, sampling_rate=target_sr, return_tensors="pt")
with torch.no_grad():
feats = clap_model.get_audio_features(**{k: v.to(DEVICE) for k, v in inputs.items()})
# Return features + the ORIGINAL (sr, waveform) for gradio playback
return feats[0].cpu().numpy(), (orig_sr, orig)
# =========================
# NEIGHBOR INDEX + MATCHING
# =========================
nn_real = NearestNeighbors(n_neighbors=3, metric="cosine")
nn_real.fit(real_embeddings)
def recommend_top3_real_plus_best_synth(audio_emb: np.ndarray):
# Top-3 real
dists, idxs = nn_real.kneighbors(audio_emb.reshape(1, -1), n_neighbors=3)
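    # sklearn's cosine metric returns distance = 1 - cosine similarity, so 1 - d recovers the similarity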
real_results = [(real_paths[i], 1 - d) for i, d in zip(idxs[0], dists[0])]
# Best synthetic
synth_sims = cosine_similarity(audio_emb.reshape(1, -1), synthetic_embeddings)[0]
j = int(np.argmax(synth_sims))
synth_result = (synthetic_paths[j], float(synth_sims[j]))
return real_results, synth_result
def open_images_with_captions(match_list):
"""
match_list: list of (csv_path_string, score, tag)
Assumes order alignment: CSV row index == WikiArt streamed index.
"""
images, captions = [], []
if not hasattr(open_images_with_captions, "_real_index_map"):
open_images_with_captions._real_index_map = {real_paths[i]: i for i in range(len(real_paths))}
idx_map = open_images_with_captions._real_index_map
for p, score, tag in match_list:
img = None
base = os.path.basename(str(p))
if tag == "Real":
i = idx_map.get(p)
img = _get_wikiart_image_by_index(i) if i is not None else None
elif tag == "Synthetic":
# Try to load from repo file path, e.g. "synthetic_images/synthetic_08.jpg"
if isinstance(p, str) and os.path.exists(p):
img = _open_local_image(p)
if img is None:
# fallback placeholder if anything failed above
img = _placeholder_img(f"{tag}\n{base}")
images.append(img)
captions.append(f"{tag}: {base} (sim {score:.2f})")
return images, "\n".join(captions)
# =========================
# PIPELINE
# =========================
def _get_sd_pipe():
global _SD_PIPE
if _SD_PIPE is not None:
return _SD_PIPE
if torch.cuda.is_available():
# Fast path (recommended)
_SD_PIPE = AutoPipelineForText2Image.from_pretrained(
"stabilityai/sdxl-turbo",
torch_dtype=torch.float16,
use_safetensors=True,
).to("cuda")
else:
# CPU fallback: smaller + few steps to avoid timeouts/OOM
_SD_PIPE = AutoPipelineForText2Image.from_pretrained(
"stabilityai/sd-turbo", # lighter than SDXL
torch_dtype=torch.float32,
use_safetensors=True,
).to("cpu")
return _SD_PIPE
import contextlib
def generate_image_from_last_audio():
if LAST_AUDIO_EMB is None:
gr.Warning("Load or upload audio first, then try again.")
return None
prompt = build_prompt_from_audio(LAST_AUDIO_EMB)
pipe = _get_sd_pipe()
if pipe is None:
gr.Warning("Image generator not available.")
return None
is_cuda = (pipe.device.type == "cuda")
autocast_ctx = torch.autocast("cuda") if is_cuda else contextlib.nullcontext()
# Small settings on CPU to keep it reasonable
steps = 4 if is_cuda else 2
w = h = 512 if is_cuda else 256
guidance = 0.0 # Turbo models like low guidance
with autocast_ctx, torch.inference_mode():
img = pipe(prompt, num_inference_steps=steps,
guidance_scale=guidance, width=w, height=h).images[0]
if not is_cuda:
gr.Info("Generated on CPU (reduced size/steps). For fast 512×512, enable a GPU runtime.")
return img
def run_example(i: int):
# full clip for examples (no trimming)
raw = get_preset_audio_bytes(i)
playable, images, caps = run_pipeline_from_bytes(raw)
# enable generate button
return playable, images, caps, gr.update(interactive=True)
def run_pipeline_from_bytes(raw_bytes: bytes):
if raw_bytes is None:
return None, [], ""
audio_emb, playable = embed_audio_bytes(raw_bytes)
global LAST_AUDIO_EMB
LAST_AUDIO_EMB = audio_emb.copy()
real3, synth1 = recommend_top3_real_plus_best_synth(audio_emb)
combined = [(p, s, "Real") for (p, s) in real3] + [(synth1[0], synth1[1], "Synthetic")]
images, captions_str = open_images_with_captions(combined)
    # free large locals before returning (LAST_AUDIO_EMB keeps its own copy)
    del audio_emb, real3, synth1, combined
    gc.collect()
    return playable, images, captions_str
# =========================
# UI
# =========================
with gr.Blocks(theme="soft") as demo:
gr.Markdown("# 🎧 EchoArt\nMatch your music to visual art — Top 3 real artworks + 1 AI‑generated image.")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("**Upload audio (WAV/MP3) or try a preset:**")
audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
start_in = gr.Number(label="Start at (seconds, optional)", value=None, precision=2)
with gr.Row():
ex_labels = [f"Example {k+1} (idx {PRESET_ORDINALS[k]})" for k in range(5)]
ex_buttons = [gr.Button(lbl) for lbl in ex_labels]
submit = gr.Button("Find Matching Art", variant="primary")
with gr.Column(scale=2):
audio_out = gr.Audio(label="Your Audio", interactive=False)
gallery = gr.Gallery(label="Top 3 Real + 1 Synthetic", columns=4, height=320)
captions = gr.Textbox(label="Matches", lines=4)
# --- AI image generation row (button is disabled until audio is processed) ---
with gr.Row():
gen_btn = gr.Button("🎨 Generate image from this audio", interactive=False)
gen_img = gr.Image(label="AI Image", interactive=False)
# -----------------------
# Handlers (must be inside Blocks)
# -----------------------
def on_submit(audio_path, start_sec):
if not audio_path:
# keep the gen button disabled
return None, [], "", gr.update(interactive=False)
with open(audio_path, "rb") as f:
raw = f.read()
# trim uploads to 30 seconds (start_sec may be None)
raw = slice_audio_to_wav_bytes(raw, start_sec, duration_sec=30.0)
playable, images, caps = run_pipeline_from_bytes(raw)
# enable generate button now that LAST_AUDIO_EMB is set
return playable, images, caps, gr.update(interactive=True)
submit.click(
on_submit,
inputs=[audio_in, start_in],
outputs=[audio_out, gallery, captions, gen_btn],
)
# Wire example buttons (freeze loop var with k=k)
for k, btn in enumerate(ex_buttons):
btn.click(lambda k=k: run_example(k), inputs=None,
outputs=[audio_out, gallery, captions, gen_btn])
def on_generate():
img = generate_image_from_last_audio()
return img
gen_btn.click(on_generate, inputs=None, outputs=[gen_img])
# Outside the Blocks context:
demo.launch()