from pathlib import Path
import json

import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel, Wav2Vec2FeatureExtractor, Wav2Vec2Model

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_here = Path(__file__).parent
_labels = json.loads((_here / "labels.json").read_text())["labels"]
LABELS = [x["name"] for x in _labels]
PROMPTS = [x["prompt"] for x in _labels]
# Models are loaded lazily on first use so importing this module stays cheap.
_clip_model = None
_clip_proc = None
_wav_model = None
_wav_proc = None

def _lazy_load_models():
    global _clip_model, _clip_proc, _wav_model, _wav_proc
    if _clip_model is None:
        _clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
        _clip_model.eval()
        _clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    if _wav_model is None:
        _wav_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(DEVICE)
        _wav_model.eval()
        # facebook/wav2vec2-base ships no tokenizer files, so Wav2Vec2Processor
        # cannot load from it; the feature extractor alone is what we need here.
        _wav_proc = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")
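# Weights download on the first call; in a long-lived app (e.g. a Space) you
# can warm the models eagerly at startup so the first request isn't slow:
#
#     _lazy_load_models()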

# image branch (CLIP)
@torch.no_grad()
def clip_image_probs(pil_image, prompts=PROMPTS):
    _lazy_load_models()
    # text features
    text_inputs = _clip_proc(text=prompts, return_tensors="pt", padding=True).to(DEVICE)
    text_feats = _clip_model.get_text_features(**text_inputs)      # [K, d]
    text_feats = torch.nn.functional.normalize(text_feats, dim=-1)
    # image features
    img_inputs = _clip_proc(images=pil_image, return_tensors="pt").to(DEVICE)
    img_feats = _clip_model.get_image_features(**img_inputs)       # [1, d]
    img_feats = torch.nn.functional.normalize(img_feats, dim=-1)
    # cosine similarity -> softmax over the K prompts
    sims = (img_feats @ text_feats.T).squeeze(0)                   # [K]
    probs = torch.softmax(sims, dim=-1)                            # [K]
    return probs.detach().cpu().numpy()                            # np.float32[K]
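# Minimal usage sketch for the image branch ("scene.jpg" is a placeholder path):
#
#     from PIL import Image
#     img = Image.open("scene.jpg").convert("RGB")
#     p_img = clip_image_probs(img)   # np.float32[K], sums to ~1
#     print(dict(zip(LABELS, p_img.round(3))))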

# audio branch (Wav2Vec2 + energy prior)
@torch.no_grad()
def wav2vec2_embed_energy(wave_16k: np.ndarray):
    """wave_16k must be float32 mono in [-1, 1], sampled at 16 kHz."""
    _lazy_load_models()
    inp = _wav_proc(wave_16k, sampling_rate=16000, return_tensors="pt").to(DEVICE)
    out = _wav_model(**inp).last_hidden_state    # [1, T, 768]
    emb = out.mean(dim=1).squeeze(0)             # mean-pool over time -> [768]
    emb = torch.nn.functional.normalize(emb, dim=-1)
    emb_np = emb.detach().cpu().numpy()
    # simple loudness proxy (RMS), roughly in [0, 1] for normalized audio
    rms = float(np.sqrt(np.mean(np.square(wave_16k))))
    return emb_np, rms
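# Sanity check with a synthetic 1-second 440 Hz tone (no audio file needed);
# the RMS of A*sin(...) is A/sqrt(2), so amplitude 0.2 gives rms ~= 0.141:
#
#     t = np.linspace(0, 1, 16000, endpoint=False, dtype=np.float32)
#     wave = (0.2 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
#     emb, rms = wav2vec2_embed_energy(wave)   # emb.shape == (768,), rms ~= 0.141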

def audio_prior_from_rms(rms: float) -> np.ndarray:
    # clamp to [0, 1]
    r = max(0.0, min(1.0, rms))
    # hand-tuned curves mapping loudness to label weights
    calm = max(0.0, 1.0 - 2.0 * r)              # high when quiet
    sad = max(0.0, 1.2 - 2.2 * r)               # also favors quiet audio
    energetic = r ** 0.8                        # grows with loudness
    joyful = (r ** 0.9) * 0.9 + 0.1 * (1 - r)   # loud-biased, with a small quiet floor
    suspense = 0.6 * (1.0 - abs(r - 0.5) * 2)   # peaks at mid loudness
    # NOTE: this order must match LABELS as defined in labels.json
    vec = np.array([calm, energetic, suspense, joyful, sad], dtype=np.float32)
    vec = np.clip(vec, 1e-4, None)              # avoid exact zeros before normalizing
    vec = vec / vec.sum()
    return vec
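# The curves behave as intended: quiet audio favors calm/sad, mid loudness
# peaks suspense, loud audio favors energetic/joyful:
#
#     for r in (0.05, 0.5, 0.95):
#         print(r, audio_prior_from_rms(r).round(3))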

# fusion
def fuse_probs(image_probs: np.ndarray, audio_prior: np.ndarray, alpha: float = 0.7) -> np.ndarray:
    """Convex mix of the two branches; alpha closer to 1 favors the image, closer to 0 the audio."""
    p_img = image_probs / (image_probs.sum() + 1e-8)
    p_aud = audio_prior / (audio_prior.sum() + 1e-8)
    p = alpha * p_img + (1.0 - alpha) * p_aud
    p = p / (p.sum() + 1e-8)
    return p

def top1_label_from_probs(p: np.ndarray) -> str:
    return LABELS[int(p.argmax())]
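
# Smoke test with synthetic inputs (no assets needed): a flat gray image and a
# quiet tone stand in for real data; real callers pass their own PIL image and
# 16 kHz float32 mono waveform in [-1, 1].
if __name__ == "__main__":
    from PIL import Image
    img = Image.new("RGB", (224, 224), (128, 128, 128))
    t = np.linspace(0, 1, 16000, endpoint=False, dtype=np.float32)
    wave = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)
    p_img = clip_image_probs(img)
    _, rms = wav2vec2_embed_energy(wave)
    p = fuse_probs(p_img, audio_prior_from_rms(rms), alpha=0.7)
    print("top-1:", top1_label_from_probs(p))
    print(dict(zip(LABELS, p.round(3))))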