from pathlib import Path
import json
import numpy as np
import torch
from transformers import CLIPProcessor, CLIPModel, Wav2Vec2Processor, Wav2Vec2Model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_here = Path(__file__).parent
_labels = json.loads((_here / "labels.json").read_text())["labels"]
LABELS = [x["name"] for x in _labels]
PROMPTS = [x["prompt"] for x in _labels]
_clip_model = None
_clip_proc = None
_wav_model = None
_wav_proc = None

def _lazy_load_models():
    global _clip_model, _clip_proc, _wav_model, _wav_proc
    if _clip_model is None:
        _clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
        _clip_model.eval()
        _clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    if _wav_model is None:
        _wav_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").to(DEVICE)
        _wav_model.eval()
        _wav_proc = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# image branch (CLIP)
@torch.no_grad()
def clip_image_probs(pil_image, prompts=PROMPTS):
    _lazy_load_models()
    # text features
    text_inputs = _clip_proc(text=prompts, return_tensors="pt", padding=True).to(DEVICE)
    text_feats = _clip_model.get_text_features(**text_inputs)  # [K, d]
    text_feats = torch.nn.functional.normalize(text_feats, dim=-1)
    # image features
    img_inputs = _clip_proc(images=pil_image, return_tensors="pt").to(DEVICE)
    img_feats = _clip_model.get_image_features(**img_inputs)  # [1, d]
    img_feats = torch.nn.functional.normalize(img_feats, dim=-1)
    # cosine similarity, then softmax over labels
    sims = (img_feats @ text_feats.T).squeeze(0)  # [K]
    probs = torch.softmax(sims, dim=-1)  # [K]
    return probs.detach().cpu().numpy()  # np.float32[K]

# audio branch (Wav2Vec2 + energy prior)
@torch.no_grad()
def wav2vec2_embed_energy(wave_16k: np.ndarray):
    _lazy_load_models()
    # wave_16k must be mono float32 in [-1, 1], sampled at 16 kHz
    inp = _wav_proc(wave_16k, sampling_rate=16000, return_tensors="pt").to(DEVICE)
    out = _wav_model(**inp).last_hidden_state  # [1, T, 768]
    emb = out.mean(dim=1).squeeze(0)  # [768], mean-pooled over time
    emb = torch.nn.functional.normalize(emb, dim=-1)
    emb_np = emb.detach().cpu().numpy()
    # simple loudness proxy: RMS, roughly 0..1 for audio normalized to [-1, 1]
    rms = float(np.sqrt(np.mean(np.square(wave_16k))))
    return emb_np, rms

def audio_prior_from_rms(rms: float) -> np.ndarray:
    # clamp RMS to [0, 1]
    r = max(0.0, min(1.0, rms))
    # hand-tuned curves mapping loudness to label weights
    calm = max(0.0, 1.0 - 2.0 * r)  # high when quiet
    sad = max(0.0, 1.2 - 2.2 * r)
    energetic = r ** 0.8  # grows with loudness
    joyful = (r ** 0.9) * 0.9 + 0.1 * (1 - r)  # loudness-driven, with a small quiet-side bias
    suspense = 0.6 * (1.0 - abs(r - 0.5) * 2)  # peaks at medium loudness
    # this order must match LABELS (i.e. the label order in labels.json)
    vec = np.array([calm, energetic, suspense, joyful, sad], dtype=np.float32)
    vec = np.clip(vec, 1e-4, None)
    vec = vec / vec.sum()
    return vec

# fusion
def fuse_probs(image_probs: np.ndarray, audio_prior: np.ndarray, alpha: float = 0.7) -> np.ndarray:
    # alpha closer to 1 favors the image branch; closer to 0 favors the audio prior
    p_img = image_probs / (image_probs.sum() + 1e-8)
    p_aud = audio_prior / (audio_prior.sum() + 1e-8)
    p = alpha * p_img + (1.0 - alpha) * p_aud
    p = p / (p.sum() + 1e-8)
    return p

def top1_label_from_probs(p: np.ndarray) -> str:
    return LABELS[int(p.argmax())]
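
# Illustrative usage sketch (an addition, not part of the module's API): it assumes
# labels.json defines the five labels in the order used by audio_prior_from_rms, and
# feeds a synthetic image and waveform through the pipeline end to end. Real callers
# would pass a PIL image and real 16 kHz mono float32 audio instead.
if __name__ == "__main__":
    from PIL import Image
    demo_img = Image.new("RGB", (224, 224), color=(30, 60, 200))  # synthetic stand-in image
    demo_wave = (0.2 * np.sin(2 * np.pi * 440.0 * np.arange(16000) / 16000.0)).astype(np.float32)  # 1 s tone
    img_probs = clip_image_probs(demo_img)
    _, rms = wav2vec2_embed_energy(demo_wave)
    fused = fuse_probs(img_probs, audio_prior_from_rms(rms), alpha=0.7)
    print(dict(zip(LABELS, fused.round(3))))
    print("top-1 label:", top1_label_from_probs(fused))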