import os, io, textwrap
import gradio as gr
import numpy as np
import pandas as pd
import soundfile as sf
from PIL import Image, ImageDraw, ImageFont
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity

import torch
torch.set_num_threads(1)
from transformers import ClapProcessor, ClapModel
from datasets import load_dataset, Audio
import requests
from itertools import islice

from scipy.signal import resample_poly
import math

import gc

from diffusers import AutoPipelineForText2Image  # NEW
LAST_AUDIO_EMB = None                             # NEW
_SD_PIPE = None                                   # NEW

# =========================
# CONFIG
# =========================
# Real (WikiArt) embeddings + metadata (you already have these)
REAL_EMB_PATH = "wikiart_embeddings.npy"
REAL_META_PATH = "wikiart_metadata.csv"

# Synthetic (optional; placeholders used if missing)
SYNTH_EMB_PATH = "synthetic_image_embeddings.npy"
SYNTH_META_PATH = "synthetic_image_metadata.csv"

# Data sources (remote)
WIKIART_DATASET = "huggan/wikiart"        # images streamed by order
WIKIART_COUNT = 1500                      # stream only this many images (should match the precomputed embeddings)
GTZAN_DATASET   = "sanchit-gandhi/gtzan"  # <-- keep using your dataset
PRESET_ORDINALS = [1, 102, 203, 304, 405]         # GTZAN presets by ordinal (1-based)

# =========================
# LOAD EMBEDDINGS/METADATA
# =========================
real_embeddings = np.load(REAL_EMB_PATH)
real_meta = pd.read_csv(REAL_META_PATH)
real_paths = real_meta["path"].astype(str).tolist()
N_REAL = len(real_paths)

# Synthetic (optional): fall back to placeholders if not present
try:
    synthetic_embeddings = np.load(SYNTH_EMB_PATH)
    synthetic_meta = pd.read_csv(SYNTH_META_PATH)
    synthetic_paths = synthetic_meta["path"].astype(str).tolist()
except Exception:
    synthetic_embeddings = np.zeros((10, real_embeddings.shape[1]), dtype=np.float32)
    synthetic_paths = [f"synthetic_placeholder_{i}.png" for i in range(10)]

# =========================
# HELPERS
# =========================
HF_DATASET_RESOLVE_BASE = "https://huggingface.co/datasets"

def _placeholder_img(text: str, size=(512, 512)) -> Image.Image:
    img = Image.new("RGB", size, (240, 240, 240))
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None
    lines = textwrap.wrap(text, width=28)
    line_boxes = [draw.textbbox((0, 0), ln, font=font) for ln in lines]
    total_h = sum(b[3]-b[1] for b in line_boxes) + (len(lines)-1)*6
    y = (size[1]-total_h)//2
    for ln, bb in zip(lines, line_boxes):
        w = bb[2]-bb[0]; h = bb[3]-bb[1]
        x = (size[0]-w)//2
        draw.text((x, y), ln, fill=(30,30,30), font=font)
        y += h + 6
    return img

def slice_audio_to_wav_bytes(raw_bytes: bytes, start_sec: float | None, duration_sec: float = 30.0) -> bytes:
    """Trim to [start_sec, start_sec+duration_sec) and return WAV bytes."""
    # decode
    data, sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
    if data.ndim > 1:
        data = data.mean(axis=1)

    total_sec = len(data) / sr
    # default start=0 if None/blank/NaN
    try:
        s = float(start_sec) if start_sec is not None else 0.0
    except Exception:
        s = 0.0
    s = max(0.0, s)

    # clamp start so we always get up to 30s if possible
    if total_sec > duration_sec and s + duration_sec > total_sec:
        s = max(0.0, total_sec - duration_sec)

    start_idx = int(round(s * sr))
    end_idx = min(len(data), start_idx + int(round(duration_sec * sr)))
    seg = data[start_idx:end_idx]

    # re-encode to WAV bytes
    buf = io.BytesIO()
    sf.write(buf, seg, sr, format="WAV")
    return buf.getvalue()
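
# Illustrative usage (sketch; "clip.mp3" is a hypothetical local file): trim an
# upload to the 30-second window starting at 5 s before embedding it.
#   with open("clip.mp3", "rb") as f:
#       wav30 = slice_audio_to_wav_bytes(f.read(), start_sec=5.0, duration_sec=30.0)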

def _open_local_image(path: str):
    try:
        img = Image.open(path).convert("RGB")
        # keep memory down for the gallery
        img.thumbnail((512, 512))
        return img
    except Exception:
        return None

# =========================
# WIKIART IMAGES (STREAM FIRST WIKIART_COUNT BY ORDER)
# =========================
from functools import lru_cache

@lru_cache(maxsize=64)  # small cache to avoid re-fetching the same few images
def _get_wikiart_image_by_index(idx: int):
    if idx is None or idx < 0 or idx >= min(WIKIART_COUNT, N_REAL):
        return None
    # Stream and jump to idx; loads just one image
    stream = load_dataset(WIKIART_DATASET, split="train", streaming=True)
    ex = next(islice(stream, idx, idx + 1), None)
    if ex is None:
        return None
    img = ex["image"]
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    # Downsize for gallery to save RAM
    try:
        img.thumbnail((512, 512))
    except Exception:
        pass
    return img

# =========================
# GTZAN AUDIO (NON-STREAMING, INDEXABLE) — keeps dataset = sanchit-gandhi/gtzan
# =========================
_gtzan_ds = None

def _load_gtzan():
    """
    Load once (non-streaming) so we can index reliably by ordinal.
    We also ensure the 'audio' column is decoded to array+sr.
    """
    global _gtzan_ds
    if _gtzan_ds is None:
        ds = load_dataset(GTZAN_DATASET, split="train")  # local cache, no manual URLs
        # If 'audio' is already an Audio feature (decoded), great.
        # If not, cast to Audio(decode=True) so ds[i]["audio"] -> {'array', 'sampling_rate'}.
        if "audio" not in ds.features or not isinstance(ds.features["audio"], Audio):
            ds = ds.cast_column("audio", Audio(decode=True))
        _gtzan_ds = ds
    return _gtzan_ds

def get_gtzan_audio_by_ordinal(n: int):
    """
    1-based index → ((sr, waveform), raw_wav_bytes)
    """
    if n < 1:
        raise gr.Error("Index must be ≥ 1")

    ds = _load_gtzan()
    idx = n - 1
    if idx >= len(ds):
        raise gr.Error(f"No record at index {n} (dataset size: {len(ds)}).")

    ex = ds[idx]
    # 'genre' in some variants, 'label' in others — support both
    genre = ex.get("genre", ex.get("label", ""))

    audio = ex["audio"]  # {'array': np.ndarray, 'sampling_rate': int}
    arr = audio["array"]
    sr  = audio["sampling_rate"]

    # Ensure mono for CLAP
    if arr.ndim > 1:
        arr = arr.mean(axis=1)

    # Create WAV bytes so the rest of your pipeline can reuse uniformly
    buf = io.BytesIO()
    sf.write(buf, arr, sr, format="WAV")
    raw = buf.getvalue()

    return (sr, arr.astype(np.float32)), raw

def get_preset_audio_bytes(i: int) -> bytes:
    """Preset buttons: first 5 items by order."""
    n = PRESET_ORDINALS[i]
    (_sr_data, raw) = get_gtzan_audio_by_ordinal(n)
    return raw
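
# Illustrative usage (sketch): fetch the first preset and inspect it.
#   (sr, wave), wav_bytes = get_gtzan_audio_by_ordinal(PRESET_ORDINALS[0])
#   print(sr, wave.shape)   # e.g. 22050 and a 1-D float32 array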

# =========================
# CLAP MODEL
# =========================
DEVICE = "cpu"
clap_processor = ClapProcessor.from_pretrained("laion/clap-htsat-unfused")
clap_model = ClapModel.from_pretrained("laion/clap-htsat-unfused").to(DEVICE)
clap_model.eval()

# Candidates the model can choose from (expand anytime)
CANDS = {
    "subject": [
        "jazz band", "symphony orchestra", "solo piano", "acoustic guitar",
        "electric guitar", "violin", "hip-hop rapper", "DJ in a nightclub",
        "techno producer", "ambient soundscape", "drum and bass", "punk band",
        "choir", "saxophone solo", "lofi beats"
    ],
    "mood": [
        "moody", "melancholic", "energetic", "uplifting", "intense",
        "dreamy", "peaceful", "dark", "nostalgic", "aggressive", "euphoric"
    ],
    "style": [
        "surreal", "impressionist", "cubist", "abstract", "street art",
        "photorealistic", "digital painting", "watercolor", "oil painting",
        "futuristic", "minimalist", "psychedelic", "retro"
    ],
    "visuals": [
        "neon lights", "soft gradients", "high contrast", "film grain",
        "bokeh", "long exposure", "smoky bar", "concert stage",
        "urban graffiti", "fractals", "misty atmosphere", "warm window light"
    ]
}

# Build a flat list and index slices for each pool, then encode once with CLAP's text tower.
ALL_PHRASES, POOL_SLICES = [], {}
_start = 0
for _k, _lst in CANDS.items():
    ALL_PHRASES.extend(_lst)
    POOL_SLICES[_k] = slice(_start, _start + len(_lst))
    _start += len(_lst)

with torch.no_grad():
    _ti = clap_processor(text=ALL_PHRASES, return_tensors="pt", padding=True)
    _ti = {k: v.to(DEVICE) for k, v in _ti.items()}
    _ALL_EMB = clap_model.get_text_features(**_ti)                 # [N, D]
    _ALL_EMB = torch.nn.functional.normalize(_ALL_EMB, dim=-1).cpu().numpy()

def _best_from_pool(audio_emb: np.ndarray, pool_key: str, topk: int = 1):
    a = audio_emb / (np.linalg.norm(audio_emb) + 1e-12)
    sl = POOL_SLICES[pool_key]
    sims = (_ALL_EMB[sl] @ a)
    idx = np.argsort(-sims)[:topk]
    phrases = [CANDS[pool_key][i] for i in idx]
    return phrases, [float(sims[i]) for i in idx]

def build_prompt_from_audio(audio_emb: np.ndarray) -> str:
    subj, _ = _best_from_pool(audio_emb, "subject", topk=1)
    moods, _ = _best_from_pool(audio_emb, "mood", topk=2)
    styles, _ = _best_from_pool(audio_emb, "style", topk=1)
    vis, _   = _best_from_pool(audio_emb, "visuals", topk=2)
    directive = "high quality, detailed, 512x512"
    return f"{subj[0]}, {', '.join(moods)}, {styles[0]}, {', '.join(vis)}, {directive}"
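
# For reference, a prompt assembled this way looks like (illustrative only;
# all phrases come from the CANDS pools above):
#   "jazz band, moody, nostalgic, impressionist, smoky bar, warm window light,
#    high quality, detailed, 512x512"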

def embed_audio_bytes(raw_bytes: bytes):
    # Read original audio for playback
    orig, orig_sr = sf.read(io.BytesIO(raw_bytes), dtype="float32", always_2d=False)
    if orig.ndim > 1:
        orig = orig.mean(axis=1)

    # CLAP needs 48 kHz mono
    target_sr = 48000
    proc = orig
    if orig_sr != target_sr:
        g = math.gcd(target_sr, orig_sr)
        up = target_sr // g
        down = orig_sr // g
        proc = resample_poly(orig, up, down)

    inputs = clap_processor(audios=proc, sampling_rate=target_sr, return_tensors="pt")
    with torch.no_grad():
        feats = clap_model.get_audio_features(**{k: v.to(DEVICE) for k, v in inputs.items()})

    # Return features + the ORIGINAL (sr, waveform) for gradio playback
    return feats[0].cpu().numpy(), (orig_sr, orig)
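
# Illustrative usage (sketch, reusing a preset clip): embed audio and keep the
# playable (sr, waveform) tuple for the Gradio player. The returned embedding is
# a 1-D numpy vector; the nearest-neighbour index below assumes it matches the
# dimensionality of the precomputed image embeddings.
#   emb, (sr, wave) = embed_audio_bytes(get_preset_audio_bytes(0))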

# =========================
# NEIGHBOR INDEX + MATCHING
# =========================
nn_real = NearestNeighbors(n_neighbors=3, metric="cosine")
nn_real.fit(real_embeddings)

def recommend_top3_real_plus_best_synth(audio_emb: np.ndarray):
    # Top-3 real
    dists, idxs = nn_real.kneighbors(audio_emb.reshape(1, -1), n_neighbors=3)
    real_results = [(real_paths[i], 1 - d) for i, d in zip(idxs[0], dists[0])]
    # Best synthetic
    synth_sims = cosine_similarity(audio_emb.reshape(1, -1), synthetic_embeddings)[0]
    j = int(np.argmax(synth_sims))
    synth_result = (synthetic_paths[j], float(synth_sims[j]))
    return real_results, synth_result
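
# Illustrative usage (sketch, continuing the embedding example above):
#   real3, synth1 = recommend_top3_real_plus_best_synth(emb)
#   real3  -> [(path, similarity), ...] for the three closest WikiArt rows
#   synth1 -> (path, similarity) for the single best synthetic image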

def open_images_with_captions(match_list):
    """
    match_list: list of (csv_path_string, score, tag)
    Assumes order alignment: CSV row index == WikiArt streamed index.
    """
    images, captions = [], []
    if not hasattr(open_images_with_captions, "_real_index_map"):
        open_images_with_captions._real_index_map = {real_paths[i]: i for i in range(len(real_paths))}
    idx_map = open_images_with_captions._real_index_map

    for p, score, tag in match_list:
        img = None
        base = os.path.basename(str(p))

        if tag == "Real":
            i = idx_map.get(p)
            img = _get_wikiart_image_by_index(i) if i is not None else None
        elif tag == "Synthetic":
            # Try to load from repo file path, e.g. "synthetic_images/synthetic_08.jpg"
            if isinstance(p, str) and os.path.exists(p):
                img = _open_local_image(p)

        if img is None:
            # fallback placeholder if anything failed above
            img = _placeholder_img(f"{tag}\n{base}")

        images.append(img)
        captions.append(f"{tag}: {base}  (sim {score:.2f})")

    return images, "\n".join(captions)

# =========================
# PIPELINE
# =========================
def _get_sd_pipe():
    global _SD_PIPE
    if _SD_PIPE is not None:
        return _SD_PIPE

    if torch.cuda.is_available():
        # Fast path (recommended)
        _SD_PIPE = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo",
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
    else:
        # CPU fallback: smaller + few steps to avoid timeouts/OOM
        _SD_PIPE = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sd-turbo",           # lighter than SDXL
            torch_dtype=torch.float32,
            use_safetensors=True,
        ).to("cpu")
    return _SD_PIPE
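
# Optional (sketch): pre-warm the pipeline at startup so the first "Generate"
# click is not slow. Left commented out to keep startup memory low.
# _ = _get_sd_pipe()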

import contextlib

def generate_image_from_last_audio():
    if LAST_AUDIO_EMB is None:
        gr.Warning("Load or upload audio first, then try again.")
        return None

    prompt = build_prompt_from_audio(LAST_AUDIO_EMB)
    pipe = _get_sd_pipe()
    if pipe is None:
        gr.Warning("Image generator not available.")
        return None

    is_cuda = (pipe.device.type == "cuda")
    autocast_ctx = torch.autocast("cuda") if is_cuda else contextlib.nullcontext()

    # Small settings on CPU to keep it reasonable
    steps = 4 if is_cuda else 2
    w = h = 512 if is_cuda else 256
    guidance = 0.0  # Turbo models like low guidance

    with autocast_ctx, torch.inference_mode():
        img = pipe(prompt, num_inference_steps=steps,
                   guidance_scale=guidance, width=w, height=h).images[0]

    if not is_cuda:
        gr.Info("Generated on CPU (reduced size/steps). For fast 512×512, enable a GPU runtime.")
    return img
    
def run_example(i: int):
    # full clip for examples (no trimming)
    raw = get_preset_audio_bytes(i)
    playable, images, caps = run_pipeline_from_bytes(raw)
    # enable generate button
    return playable, images, caps, gr.update(interactive=True)

def run_pipeline_from_bytes(raw_bytes: bytes):
    if raw_bytes is None:
        return None, [], ""
    audio_emb, playable = embed_audio_bytes(raw_bytes)
    
    global LAST_AUDIO_EMB
    LAST_AUDIO_EMB = audio_emb.copy()
    
    real3, synth1 = recommend_top3_real_plus_best_synth(audio_emb)
    combined = [(p, s, "Real") for (p, s) in real3] + [(synth1[0], synth1[1], "Synthetic")]
    images, captions_str = open_images_with_captions(combined)

    # free large locals we no longer need (LAST_AUDIO_EMB keeps its own copy)
    del audio_emb, real3, synth1, combined
    gc.collect()

    return playable, images, captions_str

# =========================
# UI
# =========================
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("# 🎧 EchoArt\nMatch your music to visual art — Top 3 real artworks + 1 AI‑generated image.")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("**Upload audio (WAV/MP3) or try a preset:**")
            audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio Input")
            start_in = gr.Number(label="Start at (seconds, optional)", value=None, precision=2)
            with gr.Row():
                ex_labels = [f"Example {k+1} (idx {PRESET_ORDINALS[k]})" for k in range(5)]
                ex_buttons = [gr.Button(lbl) for lbl in ex_labels]
            submit = gr.Button("Find Matching Art", variant="primary")
        with gr.Column(scale=2):
            audio_out = gr.Audio(label="Your Audio", interactive=False)
            gallery = gr.Gallery(label="Top 3 Real + 1 Synthetic", columns=4, height=320)
            captions = gr.Textbox(label="Matches", lines=4)

    # --- AI image generation row (button is disabled until audio is processed) ---
    with gr.Row():
        gen_btn = gr.Button("🎨 Generate image from this audio", interactive=False)
        gen_img = gr.Image(label="AI Image", interactive=False)

    # -----------------------
    # Handlers (must be inside Blocks)
    # -----------------------
    def on_submit(audio_path, start_sec):
        if not audio_path:
            # keep the gen button disabled
            return None, [], "", gr.update(interactive=False)
        with open(audio_path, "rb") as f:
            raw = f.read()
        # trim uploads to 30 seconds (start_sec may be None)
        raw = slice_audio_to_wav_bytes(raw, start_sec, duration_sec=30.0)
        playable, images, caps = run_pipeline_from_bytes(raw)
        # enable generate button now that LAST_AUDIO_EMB is set
        return playable, images, caps, gr.update(interactive=True)

    submit.click(
        on_submit,
        inputs=[audio_in, start_in],
        outputs=[audio_out, gallery, captions, gen_btn],
    )

    # Wire example buttons (freeze loop var with k=k)
    for k, btn in enumerate(ex_buttons):
        btn.click(lambda k=k: run_example(k), inputs=None,
                  outputs=[audio_out, gallery, captions, gen_btn])

    def on_generate():
        img = generate_image_from_last_audio()
        return img

    gen_btn.click(on_generate, inputs=None, outputs=[gen_img])


# Outside the Blocks context:
demo.launch()