|
import os |
|
import torch |
|
from imagebind import data |
|
from imagebind.models import imagebind_model |
|
from imagebind.models.imagebind_model import ModalityType |
|
from pydub import AudioSegment |
|
from fastapi import FastAPI, UploadFile, File, Form |
|
from typing import List, Dict |
|
import tempfile |
|
from pydantic import BaseModel |
|
import uvicorn |
|
import numpy as np |
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
|
from fastapi import Depends, HTTPException, status |
|
|
|
# ASGI application object; the route decorators further down register on it.
app = FastAPI()

# Bearer-token scheme that ``verify_token`` depends on to obtain the
# credentials sent with each request.
security = HTTPBearer()

# Shared secret clients must present, read once at import time.
# NOTE(review): when the API_TOKEN environment variable is unset the service
# falls back to a fixed placeholder value, so anyone who knows the placeholder
# can authenticate. Always set API_TOKEN in real deployments.
API_TOKEN = os.environ.get("API_TOKEN", "your-default-token-here")
|
|
|
|
|
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Validate the bearer token supplied with the request.

    Args:
        credentials: Credentials produced by the ``security`` dependency.

    Returns:
        The token string that was presented, when it matches ``API_TOKEN``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            presented token does not match ``API_TOKEN``.
    """
    # Imported locally so this function brings its own dependency with it.
    import secrets

    # Compare as bytes with a constant-time routine: a plain ``!=`` returns as
    # soon as the first differing character is found, which leaks timing
    # information about the secret. Encoding first also avoids the TypeError
    # ``compare_digest`` raises for non-ASCII ``str`` arguments.
    supplied = credentials.credentials.encode("utf-8")
    expected = API_TOKEN.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
|
|
|
def convert_audio_to_wav(audio_path: str) -> str:
    """Return a WAV path for ``audio_path``, converting from MP3 when needed.

    Paths that do not end in ``.mp3`` (case-insensitive) are returned as-is.
    For MP3 paths the sibling ``.wav`` path is returned; the conversion is
    only performed when that file does not exist yet.
    """
    mp3_suffix = '.mp3'
    if not audio_path.lower().endswith(mp3_suffix):
        return audio_path

    # Swap the trailing ".mp3" for ".wav", keeping the rest of the path.
    wav_path = audio_path[:-len(mp3_suffix)] + '.wav'
    if os.path.exists(wav_path):
        return wav_path

    AudioSegment.from_mp3(audio_path).export(wav_path, format='wav')
    return wav_path
|
|
|
class EmbeddingManager:
    """Holds the ImageBind model and exposes embedding/similarity helpers."""

    def __init__(self):
        # Use the first CUDA device when torch reports one, otherwise the CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = imagebind_model.imagebind_huge(pretrained=True)
        self.model.eval()
        self.model.to(self.device)

    def compute_embeddings(self,
                           images: List[str] = None,
                           audio_files: List[str] = None,
                           texts: List[str] = None) -> dict:
        """Embed whichever modalities were supplied.

        Args:
            images: Image file paths, or None/empty to skip vision.
            audio_files: Audio file paths, or None/empty to skip audio.
            texts: Text strings, or None/empty to skip text.

        Returns:
            A dict with any of the keys 'vision', 'audio' and 'text', each
            mapping to the model output converted to nested Python lists.
            Empty when no modality was supplied.
        """
        # (items, modality key, loader), in the order the model inputs are built.
        loaders = (
            (texts, ModalityType.TEXT, data.load_and_transform_text),
            (images, ModalityType.VISION, data.load_and_transform_vision_data),
            (audio_files, ModalityType.AUDIO, data.load_and_transform_audio_data),
        )
        # (modality key, output name), in the order the result dict is built.
        output_names = (
            (ModalityType.VISION, 'vision'),
            (ModalityType.AUDIO, 'audio'),
            (ModalityType.TEXT, 'text'),
        )

        with torch.no_grad():
            inputs = {
                modality: load(items, self.device)
                for items, modality, load in loaders
                if items
            }
            if not inputs:
                return {}

            embeddings = self.model(inputs)
            return {
                name: embeddings[modality].cpu().numpy().tolist()
                for modality, name in output_names
                if modality in inputs
            }

    @staticmethod
    def compute_similarities(embeddings: Dict[str, List[List[float]]]) -> dict:
        """Softmax-normalised dot-product scores between the given embeddings.

        Cross-modal pairs are emitted first ('vision_audio', 'vision_text',
        'audio_text'), followed by each modality against itself
        ('vision_vision', 'audio_audio', 'text_text'). Pairs whose modalities
        are missing or empty are left out.
        """
        tensors = {}
        for name, vectors in embeddings.items():
            if isinstance(vectors, (list, np.ndarray)) and len(vectors) > 0:
                tensors[name] = torch.tensor(vectors)

        cross_pairs = [
            ('vision', 'audio', 'vision_audio'),
            ('vision', 'text', 'vision_text'),
            ('audio', 'text', 'audio_text'),
        ]
        self_pairs = [
            (modality, modality, f'{modality}_{modality}')
            for modality in ['vision', 'audio', 'text']
        ]

        similarities = {}
        for left, right, key in cross_pairs + self_pairs:
            if left in tensors and right in tensors:
                scores = tensors[left] @ tensors[right].T
                similarities[key] = torch.softmax(scores, dim=-1).numpy().tolist()
        return similarities
|
|
|
|
|
# Module-level instance used by the endpoints below. Constructing it runs
# EmbeddingManager.__init__, which loads the pretrained model at import time.
embedding_manager = EmbeddingManager()
|
|
|
# Response body of POST /compute_embeddings.
class EmbeddingResponse(BaseModel):
    # Embeddings per modality, as returned by
    # EmbeddingManager.compute_embeddings (keys 'vision', 'audio', 'text').
    embeddings: dict
    # Original upload names / texts; the endpoint fills the keys
    # 'images', 'audio' and 'texts' for whichever inputs were provided.
    file_names: dict
|
|
|
# Request options for POST /compute_similarities.
class SimilarityRequest(BaseModel):
    # Modality name -> list of embedding vectors.
    embeddings: Dict[str, List[List[float]]]
    # Matches scoring below this value are dropped from the response.
    threshold: float = 0.5
    # Upper bound on candidates taken per modality pair; None means all.
    top_k: int | None = None
    # When False, a modality is never compared against itself.
    include_self_similarity: bool = False
    # Forwarded as ``normalize`` to compute_similarity_matrix.
    normalize_scores: bool = True
|
|
|
# One scored pair of items reported by POST /compute_similarities.
class SimilarityMatch(BaseModel):
    # Row index of the first item within its modality's embedding list.
    index_a: int
    # Row index of the second item within its modality's embedding list.
    index_b: int
    # Similarity score of the pair.
    score: float
    # Modality name of the first item.
    modality_a: str
    # Modality name of the second item.
    modality_b: str
    # Display name of the first item, taken from the request's file_names.
    item_a: str
    # Display name of the second item, taken from the request's file_names.
    item_b: str
|
|
|
# Response body of POST /compute_similarities.
class SimilarityResponse(BaseModel):
    # Retained matches; the endpoint sorts them by descending score.
    matches: List[SimilarityMatch]
    # Summary values; the endpoint fills 'avg_score', 'max_score',
    # 'min_score' and 'total_comparisons'.
    statistics: Dict[str, float]
    # String labels of the modality pairs that were compared.
    modality_pairs: List[str]
|
|
|
class ModalityPair:
    """Order-independent label for a pair of modality names.

    The two names are stored in sorted order, so ``ModalityPair(a, b)`` and
    ``ModalityPair(b, a)`` render to the same string.
    """

    def __init__(self, mod1: str, mod2: str):
        smaller, larger = sorted((mod1, mod2))
        self.mod1 = smaller
        self.mod2 = larger

    def __str__(self):
        return "{}_to_{}".format(self.mod1, self.mod2)
|
|
|
def compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Pairwise similarity between the rows of two embedding matrices.

    With ``normalize=True`` each row is L2-normalised along dim 1 first, so
    the result is cosine similarity; with ``normalize=False`` it is the raw
    dot product. Entry ``[i, j]`` compares row ``i`` of ``tensor1`` with row
    ``j`` of ``tensor2``.
    """
    left, right = tensor1, tensor2
    if normalize:
        left, right = (
            torch.nn.functional.normalize(matrix, dim=1)
            for matrix in (left, right)
        )
    return left @ right.T
|
|
|
def get_top_k_matches(similarity_matrix: torch.Tensor, top_k: int | None = None) -> List[tuple]: |
|
"""Get top-k matches from a similarity matrix.""" |
|
if top_k is None: |
|
top_k = similarity_matrix.numel() |
|
|
|
|
|
flat_sim = similarity_matrix.flatten() |
|
top_k = min(top_k, flat_sim.numel()) |
|
values, indices = torch.topk(flat_sim, k=top_k) |
|
|
|
|
|
rows = indices // similarity_matrix.size(1) |
|
cols = indices % similarity_matrix.size(1) |
|
|
|
return [(r.item(), c.item(), v.item()) for r, c, v in zip(rows, cols, values)] |
|
|
|
async def _save_upload_to_temp(upload: UploadFile, temp_files: List[str]) -> str:
    """Copy an uploaded file into a named temp file and return its path.

    The path is appended to ``temp_files`` before any data is read, so the
    caller's cleanup removes it even when reading or writing fails.
    """
    # ``filename`` may be missing on an upload; fall back to no suffix.
    suffix = os.path.splitext(upload.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        temp_files.append(tmp.name)
        tmp.write(await upload.read())
    # The file is closed (and therefore flushed) once the ``with`` exits.
    return tmp.name


@app.post("/compute_embeddings", response_model=EmbeddingResponse)
async def generate_embeddings(
    credentials: HTTPAuthorizationCredentials = Depends(verify_token),
    texts: str | None = Form(None),
    images: List[UploadFile] | None = File(default=None),
    audio_files: List[UploadFile] | None = File(default=None)
):
    """Generate embeddings for any provided files and texts.

    Args:
        credentials: Result of the ``verify_token`` dependency (auth gate).
        texts: Newline-separated texts; blank lines are ignored.
        images: Uploaded image files.
        audio_files: Uploaded audio files; ``.mp3`` uploads are converted to
            WAV via ``convert_audio_to_wav`` before embedding.

    Returns:
        EmbeddingResponse with the embeddings per modality and the original
        names under the keys 'images', 'audio' and 'texts'. Both dicts are
        empty when nothing was provided.
    """
    temp_files = []

    try:
        image_paths, image_names = [], []
        for img in images or []:
            image_paths.append(await _save_upload_to_temp(img, temp_files))
            image_names.append(img.filename)

        audio_paths, audio_names = [], []
        for audio in audio_files or []:
            raw_path = await _save_upload_to_temp(audio, temp_files)
            # Convert only after the upload has been fully written and closed.
            audio_path = convert_audio_to_wav(raw_path)
            if audio_path != raw_path:
                # The converted WAV is a second temp artefact to clean up.
                temp_files.append(audio_path)
            audio_paths.append(audio_path)
            audio_names.append(audio.filename)

        text_list = []
        if texts:
            text_list = [text.strip() for text in texts.split('\n') if text.strip()]

        if not any([image_paths, audio_paths, text_list]):
            return EmbeddingResponse(
                embeddings={},
                file_names={}
            )

        # NOTE(review): this call is synchronous inside an async endpoint, so
        # it blocks the event loop while the model runs.
        embeddings = embedding_manager.compute_embeddings(
            image_paths if image_paths else None,
            audio_paths if audio_paths else None,
            text_list if text_list else None
        )

        file_names = {}
        if image_names:
            file_names['images'] = image_names
        if audio_names:
            file_names['audio'] = audio_names
        if text_list:
            file_names['texts'] = text_list

        return EmbeddingResponse(
            embeddings=embeddings,
            file_names=file_names
        )

    finally:
        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except OSError:
                # Best-effort cleanup: a file that is already gone or cannot
                # be removed must not mask the request's real outcome.
                pass
|
|
|
# /compute_embeddings returns embeddings under 'vision'/'audio'/'text' but
# names under 'images'/'audio'/'texts'; accept either spelling for names.
_FILE_NAME_ALIASES = {"vision": "images", "text": "texts"}


def _resolve_item_name(file_names: Dict[str, List[str]], modality: str, index: int) -> str:
    """Look up an item's display name, falling back to a positional label."""
    for key in (modality, _FILE_NAME_ALIASES.get(modality)):
        names = file_names.get(key) if key is not None else None
        if names is not None and 0 <= index < len(names):
            return names[index]
    return f"{modality}[{index}]"


@app.post("/compute_similarities", response_model=SimilarityResponse)
async def compute_similarities(
    request: SimilarityRequest,
    file_names: Dict[str, List[str]],
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """
    Compute cross-modal similarities with filtering and matching options.

    Parameters:
    - embeddings: Dict mapping modality to a list of embedding vectors
    - threshold: Minimum similarity score to include in results
    - top_k: Maximum number of candidates considered per modality pair
      (the threshold is applied to those candidates afterwards)
    - include_self_similarity: Whether a modality is also compared against
      itself; when False only pairs of different modalities are scored
    - normalize_scores: Whether to normalize embeddings before comparison
    - file_names: Dict mapping modality to list of original file/text names.
      Keys may be the embedding keys or the 'images'/'texts' spellings
      returned by /compute_embeddings; a missing name falls back to
      "<modality>[<index>]".

    Returns a SimilarityResponse whose matches are sorted by descending
    score. ``statistics['total_comparisons']`` counts the retained matches.

    Raises HTTPException 400 for a negative top_k, for embeddings that do
    not form a rectangular matrix, and for modality pairs whose embedding
    dimensions differ.
    """
    if request.top_k is not None and request.top_k < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="top_k must be non-negative",
        )

    try:
        tensors = {
            k: torch.tensor(v) for k, v in request.embeddings.items()
            if isinstance(v, (list, np.ndarray)) and len(v) > 0
        }
    except (ValueError, TypeError) as exc:
        # torch.tensor rejects ragged nested lists.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Embeddings for each modality must all have the same length",
        ) from exc

    matches = []
    modality_pairs = []

    modalities = list(tensors)
    for i, mod1 in enumerate(modalities):
        for mod2 in modalities[i:]:
            if mod1 == mod2 and not request.include_self_similarity:
                continue

            # The matrix product below needs matching embedding widths.
            if tensors[mod1].shape[-1] != tensors[mod2].shape[-1]:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Embedding dimensions differ between '{mod1}' and '{mod2}'",
                )

            modality_pairs.append(str(ModalityPair(mod1, mod2)))

            sim_matrix = compute_similarity_matrix(
                tensors[mod1],
                tensors[mod2],
                normalize=request.normalize_scores
            )

            for idx_a, idx_b, score in get_top_k_matches(sim_matrix, request.top_k):
                if score < request.threshold:
                    continue

                # No separate diagonal check is needed: same-modality pairs
                # only reach this point when include_self_similarity is True.
                matches.append(SimilarityMatch(
                    index_a=idx_a,
                    index_b=idx_b,
                    score=float(score),
                    modality_a=mod1,
                    modality_b=mod2,
                    item_a=_resolve_item_name(file_names, mod1, idx_a),
                    item_b=_resolve_item_name(file_names, mod2, idx_b)
                ))

    # Defaults reported when no match survives the filters.
    statistics = {
        "avg_score": 0.0,
        "max_score": 0.0,
        "min_score": 1.0,
        "total_comparisons": 0
    }
    all_scores = [match.score for match in matches]
    if all_scores:
        statistics.update({
            "avg_score": float(np.mean(all_scores)),
            "max_score": float(np.max(all_scores)),
            "min_score": float(np.min(all_scores)),
            "total_comparisons": len(all_scores)
        })

    matches.sort(key=lambda x: x.score, reverse=True)

    return SimilarityResponse(
        matches=matches,
        statistics=statistics,
        modality_pairs=modality_pairs
    )
|
|
|
@app.get("/health")
async def health_check(
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """Report that the service is up and which device the model runs on."""
    payload = {"status": "healthy"}
    payload["model_device"] = embedding_manager.device
    return payload
|
|
|
# Script entry point: when run directly, serve ``app`` with uvicorn on all
# network interfaces (0.0.0.0) at port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)