import os, io, textwrap
import gradio as gr
import numpy as np
import pandas as pd
import soundfile as sf
from PIL import Image, ImageDraw, ImageFont
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity
import torch
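# keep torch single-threaded (avoids CPU oversubscription on small CPU-only Spaces)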
torch.set_num_threads(1)
from transformers import ClapProcessor, ClapModel
from datasets import load_dataset, Audio
import requests
from itertools import islice
from scipy.signal import resample_poly
import math
import gc
from diffusers import AutoPipelineForText2Image  # NEW

LAST_AUDIO_EMB = None  # NEW
_SD_PIPE = None  # NEW
# =========================
# CONFIG
# =========================
# Real (WikiArt) embeddings + metadata (you already have these)
REAL_EMB_PATH = "wikiart_embeddings.npy"
REAL_META_PATH = "wikiart_metadata.csv"

# Synthetic (optional; placeholders used if missing)
SYNTH_EMB_PATH = "synthetic_image_embeddings.npy"
SYNTH_META_PATH = "synthetic_image_metadata.csv"

# Data sources (remote)
WIKIART_DATASET = "huggan/wikiart"      # images streamed by order
WIKIART_COUNT = 1500                    # stream only the first 1500 to match your embeddings
GTZAN_DATASET = "sanchit-gandhi/gtzan"  # <-- keep using your dataset
PRESET_ORDINALS = [1, 102, 203, 304, 405]  # GTZAN presets by ordinal (1-based)
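# (ordinals spaced ~100 apart so each preset likely lands in a different genre,
#  assuming the split is ordered by genre with ~100 clips per genre)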
# =========================
# LOAD EMBEDDINGS/METADATA
# =========================
real_embeddings = np.load(REAL_EMB_PATH)
real_meta = pd.read_csv(REAL_META_PATH)
real_paths = real_meta["path"].astype(str).tolist()
N_REAL = len(real_paths)

# Synthetic (optional): fall back to placeholders if not present
try:
    synthetic_embeddings = np.load(SYNTH_EMB_PATH)
    synthetic_meta = pd.read_csv(SYNTH_META_PATH)
    synthetic_paths = synthetic_meta["path"].astype(str).tolist()
except Exception:
    synthetic_embeddings = np.zeros((10, real_embeddings.shape[1]), dtype=np.float32)
    synthetic_paths = [f"synthetic_placeholder_{i}.png" for i in range(10)]
# =========================
# HELPERS
# =========================
HF_DATASET_RESOLVE_BASE = "https://huggingface.co/datasets"
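# (currently unused; kept in case dataset file URLs ever need to be resolved directly)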
def _placeholder_img(text: str, size=(512, 512)) -> Image.Image:
    img = Image.new("RGB", size, (240, 240, 240))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None
    lines = textwrap.wrap(text, width=28)
    line_boxes = [draw.textbbox((0, 0), ln, font=font) for ln in lines]
    total_h = sum(b[3] - b[1] for b in line_boxes) + (len(lines) - 1) * 6
    y = (size[1] - total_h) // 2
    for ln, bb in zip(lines, line_boxes):
        w = bb[2] - bb[0]; h = bb[3] - bb[1]
        x = (size[0] - w) // 2
        draw.text((x, y), ln, fill=(30, 30, 30), font=font)
        y += h + 6
    return img
def slice_audio_to_wav_bytes(raw_bytes: bytes, start_sec: float | None, duration_sec: float = 30.0) -> bytes:
    """Trim to [start_sec, start_sec+duration_sec) and return WAV bytes."""
    # decode
    data, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
    if data.ndim > 1:
        data = data.mean(axis=1)
    total_sec = len(data) / sr
    # default start=0 if None/blank/NaN
    try:
        s = float(start_sec) if start_sec is not None else 0.0
    except Exception:
        s = 0.0
    s = max(0.0, s)
    # clamp start so we always get up to 30s if possible
    if total_sec > duration_sec and s + duration_sec > total_sec:
        s = max(0.0, total_sec - duration_sec)
    start_idx = int(round(s * sr))
    end_idx = min(len(data), start_idx + int(round(duration_sec * sr)))
    seg = data[start_idx:end_idx]
    # re-encode to WAV bytes
    buf = io.BytesIO()
    sf.write(buf, seg, sr, format="WAV")
    return buf.getvalue()
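# e.g. slice_audio_to_wav_bytes(raw, 12.5) returns WAV bytes for the 12.5s-42.5s window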
def _open_local_image(path: str):
    try:
        img = Image.open(path).convert("RGB")
        # keep memory down for the gallery
        img.thumbnail((512, 512))
        return img
    except Exception:
        return None
# =========================
# WIKIART IMAGES (STREAM FIRST WIKIART_COUNT IMAGES BY ORDER)
# =========================
from functools import lru_cache

@lru_cache(maxsize=32)  # small cache to avoid re-fetching the same few images
def _get_wikiart_image_by_index(idx: int):
    if idx is None or idx < 0 or idx >= min(WIKIART_COUNT, N_REAL):
        return None
    # Stream and jump to idx; loads just one image, but must skip idx records each call
    stream = load_dataset(WIKIART_DATASET, split="train", streaming=True)
    ex = next(islice(stream, idx, idx + 1), None)
    if ex is None:
        return None
    img = ex["image"]
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    # Downsize for gallery to save RAM
    try:
        img.thumbnail((512, 512))
    except Exception:
        pass
    return img
# =========================
# GTZAN AUDIO (NON-STREAMING, INDEXABLE) — keeps dataset = sanchit-gandhi/gtzan
# =========================
_gtzan_ds = None

def _load_gtzan():
    """
    Load once (non-streaming) so we can index reliably by ordinal.
    We also ensure the 'audio' column is decoded to array+sr.
    """
    global _gtzan_ds
    if _gtzan_ds is None:
        ds = load_dataset(GTZAN_DATASET, split="train")  # local cache, no manual URLs
        # If 'audio' is already an Audio feature (decoded), great.
        # If not, cast to Audio(decode=True) so ds[i]["audio"] -> {'array', 'sampling_rate'}.
        if "audio" not in ds.features or not isinstance(ds.features["audio"], Audio):
            ds = ds.cast_column("audio", Audio(decode=True))
        _gtzan_ds = ds
    return _gtzan_ds
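# Note: the first call downloads the full split into the local HF cache; later calls reuse it.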
def get_gtzan_audio_by_ordinal(n: int):
    """
    1-based index → ((sr, waveform), raw_wav_bytes)
    """
    if n < 1:
        raise gr.Error("Index must be ≥ 1")
    ds = _load_gtzan()
    idx = n - 1
    if idx >= len(ds):
        raise gr.Error(f"No record at index {n} (dataset size: {len(ds)}).")
    ex = ds[idx]
    # 'genre' in some variants, 'label' in others — support both
    genre = ex.get("genre", ex.get("label", ""))
    audio = ex["audio"]  # {'array': np.ndarray, 'sampling_rate': int}
    arr = audio["array"]
    sr = audio["sampling_rate"]
    # Ensure mono for CLAP
    if arr.ndim > 1:
        arr = arr.mean(axis=1)
    # Create WAV bytes so the rest of your pipeline can reuse uniformly
    buf = io.BytesIO()
    sf.write(buf, arr, sr, format="WAV")
    raw = buf.getvalue()
    return (sr, arr.astype(np.float32)), raw
def get_preset_audio_bytes(i: int) -> bytes:
    """Preset buttons: fixed clips chosen by PRESET_ORDINALS."""
    n = PRESET_ORDINALS[i]
    (_sr_data, raw) = get_gtzan_audio_by_ordinal(n)
    return raw
# =========================
# CLAP MODEL
# =========================
DEVICE = "cpu"
clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused").to(DEVICE)
clap_model.eval()

# Candidates the model can choose from (expand anytime)
CANDS = {
    "subject": [
        "jazz band", "symphony orchestra", "solo piano", "acoustic guitar",
        "electric guitar", "violin", "hip-hop rapper", "DJ in a nightclub",
        "techno producer", "ambient soundscape", "drum and bass", "punk band",
        "choir", "saxophone solo", "lofi beats"
    ],
    "mood": [
        "moody", "melancholic", "energetic", "uplifting", "intense",
        "dreamy", "peaceful", "dark", "nostalgic", "aggressive", "euphoric"
    ],
    "style": [
        "surreal", "impressionist", "cubist", "abstract", "street art",
        "photorealistic", "digital painting", "watercolor", "oil painting",
        "futuristic", "minimalist", "psychedelic", "retro"
    ],
    "visuals": [
        "neon lights", "soft gradients", "high contrast", "film grain",
        "bokeh", "long exposure", "smoky bar", "concert stage",
        "urban graffiti", "fractals", "misty atmosphere", "warm window light"
    ]
}
# Build a flat list and index slices for each pool, then encode once with CLAP's text tower.
ALL_PHRASES, POOL_SLICES = [], {}
_start = 0
for _k, _lst in CANDS.items():
    ALL_PHRASES.extend(_lst)
    POOL_SLICES[_k] = slice(_start, _start + len(_lst))
    _start += len(_lst)

with torch.no_grad():
    _ti = clap_processor(text=ALL_PHRASES, return_tensors="pt", padding=True)
    _ti = {k: v.to(DEVICE) for k, v in _ti.items()}
    _ALL_EMB = clap_model.get_text_features(**_ti)  # [N, D]
    _ALL_EMB = torch.nn.functional.normalize(_ALL_EMB, dim=-1).cpu().numpy()
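# CLAP embeds audio and text in a shared space, so the best phrase in a pool is the one
# whose normalized text embedding has the highest dot product (= cosine similarity)
# with the normalized audio embedding.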
def _best_from_pool(audio_emb: np.ndarray, pool_key: str, topk: int = 1):
    a = audio_emb / (np.linalg.norm(audio_emb) + 1e-12)
    sl = POOL_SLICES[pool_key]
    sims = _ALL_EMB[sl] @ a
    idx = np.argsort(-sims)[:topk]
    phrases = [CANDS[pool_key][i] for i in idx]
    return phrases, [float(sims[i]) for i in idx]
def build_prompt_from_audio(audio_emb: np.ndarray) -> str:
    subj, _ = _best_from_pool(audio_emb, "subject", topk=1)
    moods, _ = _best_from_pool(audio_emb, "mood", topk=2)
    styles, _ = _best_from_pool(audio_emb, "style", topk=1)
    vis, _ = _best_from_pool(audio_emb, "visuals", topk=2)
    directive = "high quality, detailed, 512x512"
    return f"{subj[0]}, {', '.join(moods)}, {styles[0]}, {', '.join(vis)}, {directive}"
def embed_audio_bytes(raw_bytes: bytes):
    # Read original audio for playback
    orig, orig_sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
    if orig.ndim > 1:
        orig = orig.mean(axis=1)
    # CLAP needs 48 kHz mono
    target_sr = 48000
    proc = orig
    if orig_sr != target_sr:
        g = math.gcd(target_sr, orig_sr)
        up = target_sr // g
        down = orig_sr // g
        proc = resample_poly(orig, up, down)
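        # (resample_poly applies an anti-aliasing polyphase filter; up/down are the
        #  reduced rational factors of 48000 / orig_sr)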
    inputs = clap_processor(audios=proc, sampling_rate=target_sr, return_tensors="pt")
    with torch.no_grad():
        feats = clap_model.get_audio_features(**{k: v.to(DEVICE) for k, v in inputs.items()})
    # Return features + the ORIGINAL (sr, waveform) for gradio playback
    return feats[0].cpu().numpy(), (orig_sr, orig)
# =========================
# NEIGHBOR INDEX + MATCHING
# =========================
nn_real = NearestNeighbors(n_neighbors=3, metric="cosine")
nn_real.fit(real_embeddings)
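# With metric="cosine", kneighbors returns distances d = 1 - cosine_similarity,
# so similarity is recovered below as 1 - d.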
def recommend_top3_real_plus_best_synth(audio_emb: np.ndarray):
    # Top-3 real
    dists, idxs = nn_real.kneighbors(audio_emb.reshape(1, -1), n_neighbors=3)
    real_results = [(real_paths[i], 1 - d) for i, d in zip(idxs[0], dists[0])]
    # Best synthetic
    synth_sims = cosine_similarity(audio_emb.reshape(1, -1), synthetic_embeddings)[0]
    j = int(np.argmax(synth_sims))
    synth_result = (synthetic_paths[j], float(synth_sims[j]))
    return real_results, synth_result
def open_images_with_captions(match_list):
    """
    match_list: list of (csv_path_string, score, tag)
    Assumes order alignment: CSV row index == WikiArt streamed index.
    """
    images, captions = [], []
    # lazily build {path -> row index} once and memoize it on the function object
    if not hasattr(open_images_with_captions, "_real_index_map"):
        open_images_with_captions._real_index_map = {real_paths[i]: i for i in range(len(real_paths))}
    idx_map = open_images_with_captions._real_index_map
    for p, score, tag in match_list:
        img = None
        base = os.path.basename(str(p))
        if tag == "Real":
            i = idx_map.get(p)
            img = _get_wikiart_image_by_index(i) if i is not None else None
        elif tag == "Synthetic":
            # Try to load from repo file path, e.g. "synthetic_images/synthetic_08.jpg"
            if isinstance(p, str) and os.path.exists(p):
                img = _open_local_image(p)
        if img is None:
            # fallback placeholder if anything failed above
            img = _placeholder_img(f"{tag}\n{base}")
        images.append(img)
        captions.append(f"{tag}: {base} (sim {score:.2f})")
    return images, "\n".join(captions)
# =========================
# PIPELINE
# =========================
def _get_sd_pipe():
    global _SD_PIPE
    if _SD_PIPE is not None:
        return _SD_PIPE
    if torch.cuda.is_available():
        # Fast path (recommended)
        _SD_PIPE = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
    else:
        # CPU fallback: smaller + few steps to avoid timeouts/OOM
        _SD_PIPE = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sd-turbo",  # lighter than SDXL
            torch_dtype=torch.float32,
            use_safetensors=True,
        ).to("cpu")
    return _SD_PIPE
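# Note: the first call downloads the model weights, so expect a slow cold start on Spaces.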
import contextlib

def generate_image_from_last_audio():
    if LAST_AUDIO_EMB is None:
        gr.Warning("Load or upload audio first, then try again.")
        return None
    prompt = build_prompt_from_audio(LAST_AUDIO_EMB)
    pipe = _get_sd_pipe()
    if pipe is None:
        gr.Warning("Image generator not available.")
        return None
    is_cuda = (pipe.device.type == "cuda")
    autocast_ctx = torch.autocast("cuda") if is_cuda else contextlib.nullcontext()
    # Small settings on CPU to keep it reasonable
    steps = 4 if is_cuda else 2
    w = h = 512 if is_cuda else 256
    guidance = 0.0  # Turbo models are distilled to run with guidance_scale=0.0
    with autocast_ctx, torch.inference_mode():
        img = pipe(prompt, num_inference_steps=steps,
                   guidance_scale=guidance, width=w, height=h).images[0]
    if not is_cuda:
        gr.Info("Generated on CPU (reduced size/steps). For fast 512×512, enable a GPU runtime.")
    return img
def run_example(i: int):
    # full clip for examples (no trimming)
    raw = get_preset_audio_bytes(i)
    playable, images, caps = run_pipeline_from_bytes(raw)
    # enable generate button
    return playable, images, caps, gr.update(interactive=True)
def run_pipeline_from_bytes(raw_bytes: bytes):
    global LAST_AUDIO_EMB
    if raw_bytes is None:
        return None, [], ""
    audio_emb, playable = embed_audio_bytes(raw_bytes)
    LAST_AUDIO_EMB = audio_emb.copy()
    real3, synth1 = recommend_top3_real_plus_best_synth(audio_emb)
    combined = [(p, s, "Real") for (p, s) in real3] + [(synth1[0], synth1[1], "Synthetic")]
    images, captions_str = open_images_with_captions(combined)
    # free large arrays we no longer need before returning
    del audio_emb, real3, synth1, combined
    gc.collect()
    return playable, images, captions_str
# =========================
# UI
# =========================
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🎧 EchoArt\nMatch your music to visual art: Top 3 real artworks + 1 AI-generated image.")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("**Upload audio (WAV/MP3) or try a preset:**")
            audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
            start_in = gr.Number(label="Start at (seconds, optional)", value=None, precision=2)
            with gr.Row():
                ex_labels = [f"Example {k+1} (idx {PRESET_ORDINALS[k]})" for k in range(5)]
                ex_buttons = [gr.Button(lbl) for lbl in ex_labels]
            submit = gr.Button("Find Matching Art", variant="primary")
        with gr.Column(scale=2):
            audio_out = gr.Audio(label="Your Audio", interactive=False)
            gallery = gr.Gallery(label="Top 3 Real + 1 Synthetic", columns=4, height=320)
            captions = gr.Textbox(label="Matches", lines=4)
    # --- AI image generation row (button is disabled until audio is processed) ---
    with gr.Row():
        gen_btn = gr.Button("🎨 Generate image from this audio", interactive=False)
        gen_img = gr.Image(label="AI Image", interactive=False)
    # -----------------------
    # Handlers (must be inside Blocks)
    # -----------------------
    def on_submit(audio_path, start_sec):
        if not audio_path:
            # keep the gen button disabled
            return None, [], "", gr.update(interactive=False)
        with open(audio_path, "rb") as f:
            raw = f.read()
        # trim uploads to 30 seconds (start_sec may be None)
        raw = slice_audio_to_wav_bytes(raw, start_sec, duration_sec=30.0)
        playable, images, caps = run_pipeline_from_bytes(raw)
        # enable generate button now that LAST_AUDIO_EMB is set
        return playable, images, caps, gr.update(interactive=True)

    submit.click(
        on_submit,
        inputs=[audio_in, start_in],
        outputs=[audio_out, gallery, captions, gen_btn],
    )

    # Wire example buttons (freeze loop var with k=k)
    for k, btn in enumerate(ex_buttons):
        btn.click(lambda k=k: run_example(k), inputs=None,
                  outputs=[audio_out, gallery, captions, gen_btn])

    def on_generate():
        img = generate_image_from_last_audio()
        return img

    gen_btn.click(on_generate, inputs=None, outputs=[gen_img])

# Outside the Blocks context:
demo.launch()