# Spaces status: Sleeping
import logging
import os
import secrets
import tempfile
from typing import List, Dict

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from pydub import AudioSegment

from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
app = FastAPI()

# Bearer-token authentication scheme used by ``verify_token``.
security = HTTPBearer()

# Token every client must present. Read from the API_TOKEN environment
# variable; the hard-coded fallback is visible to anyone who can read this
# source, so using it is logged as a warning instead of happening silently.
_DEFAULT_API_TOKEN = "your-default-token-here"
API_TOKEN = os.getenv("API_TOKEN", _DEFAULT_API_TOKEN)
if API_TOKEN == _DEFAULT_API_TOKEN:
    logging.getLogger(__name__).warning(
        "API_TOKEN is not set; falling back to the insecure built-in default token."
    )
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Check the request's bearer token against ``API_TOKEN``.

    Used as a FastAPI dependency (``Depends(verify_token)``) by the endpoints
    in this module.

    Args:
        credentials: Scheme and token parsed from the ``Authorization`` header
            by the ``HTTPBearer`` security scheme.

    Returns:
        The presented token string.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            token does not match ``API_TOKEN``.
    """
    # Constant-time comparison, so response timing does not reveal how many
    # leading characters of a guessed token were right. Compare UTF-8 bytes:
    # secrets.compare_digest raises TypeError for str arguments that contain
    # non-ASCII characters, which would surface as an HTTP 500.
    presented = credentials.credentials.encode("utf-8")
    expected = API_TOKEN.encode("utf-8")
    if not secrets.compare_digest(presented, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
def convert_audio_to_wav(audio_path: str) -> str:
    """Return a path to WAV audio for ``audio_path``.

    Paths whose suffix is not ``.mp3`` (compared case-insensitively) are
    returned untouched. For an MP3, the sibling path with the extension
    swapped to ``.wav`` is returned. It is transcoded with pydub only when no
    file already exists at that path.
    """
    if not audio_path.lower().endswith('.mp3'):
        return audio_path

    # The last four characters are ".mp3" in some letter case; swap them out.
    wav_path = audio_path[:-len('.mp3')] + '.wav'
    if os.path.exists(wav_path):
        return wav_path

    AudioSegment.from_mp3(audio_path).export(wav_path, format='wav')
    return wav_path
class EmbeddingManager: | |
def __init__(self): | |
self.device = "cuda:0" if torch.cuda.is_available() else "cpu" | |
self.model = imagebind_model.imagebind_huge(pretrained=True) | |
self.model.eval() | |
self.model.to(self.device) | |
def compute_embeddings(self, | |
images: List[str] = None, | |
audio_files: List[str] = None, | |
texts: List[str] = None) -> dict: | |
"""Compute embeddings for provided modalities only.""" | |
with torch.no_grad(): | |
inputs = {} | |
if texts: | |
inputs[ModalityType.TEXT] = data.load_and_transform_text(texts, self.device) | |
if images: | |
inputs[ModalityType.VISION] = data.load_and_transform_vision_data(images, self.device) | |
if audio_files: | |
inputs[ModalityType.AUDIO] = data.load_and_transform_audio_data(audio_files, self.device) | |
if not inputs: | |
return {} | |
embeddings = self.model(inputs) | |
result = {} | |
if ModalityType.VISION in inputs: | |
result['vision'] = embeddings[ModalityType.VISION].cpu().numpy().tolist() | |
if ModalityType.AUDIO in inputs: | |
result['audio'] = embeddings[ModalityType.AUDIO].cpu().numpy().tolist() | |
if ModalityType.TEXT in inputs: | |
result['text'] = embeddings[ModalityType.TEXT].cpu().numpy().tolist() | |
return result | |
def compute_similarities(embeddings: Dict[str, List[List[float]]]) -> dict:
    """Return softmax-normalised dot-product scores between the given modalities.

    Only entries whose value is a non-empty list or ndarray are used.

    Cross-modal pairs are reported as ``vision_audio``, ``vision_text`` and
    ``audio_text``. Each modality compared against itself is reported as
    ``<m>_<m>`` (e.g. ``text_text``). In every value, row ``i`` is a softmax
    over the raw dot products of item ``i`` with all items of the second
    modality, returned as nested Python lists.

    NOTE(review): the ``async def compute_similarities`` endpoint defined
    later in this module rebinds this name. After import, this helper is no
    longer reachable as ``compute_similarities``.
    """
    def _row_softmax(left: torch.Tensor, right: torch.Tensor) -> list:
        # Softmax across the columns of the (left x right) dot-product matrix.
        return torch.softmax(left @ right.T, dim=-1).numpy().tolist()

    present = {}
    for modality, vectors in embeddings.items():
        if isinstance(vectors, (list, np.ndarray)) and len(vectors) > 0:
            present[modality] = torch.tensor(vectors)

    scores = {}
    cross_pairs = (
        ('vision', 'audio', 'vision_audio'),
        ('vision', 'text', 'vision_text'),
        ('audio', 'text', 'audio_text'),
    )
    for first, second, label in cross_pairs:
        if first in present and second in present:
            scores[label] = _row_softmax(present[first], present[second])

    for modality in ('vision', 'audio', 'text'):
        if modality in present:
            scores[f'{modality}_{modality}'] = _row_softmax(present[modality], present[modality])

    return scores
# Single shared model instance. It is built (pretrained weights included)
# once, at import time, and is referenced by the request handlers below.
embedding_manager = EmbeddingManager()
# Response body of generate_embeddings.
class EmbeddingResponse(BaseModel):
    # Modality ('vision' / 'audio' / 'text') -> embeddings as nested float lists.
    embeddings: dict
    # Input group ('images' / 'audio' / 'texts') -> original filenames or
    # texts, in the same order as the embeddings.
    file_names: dict
# Request body of the similarity endpoint.
class SimilarityRequest(BaseModel):
    # Modality name -> list of embedding vectors (one inner list per item).
    embeddings: dict[str, list[list[float]]]
    # Matches scoring below this value are dropped.
    threshold: float = 0.5
    # Cap on matches taken from each modality pair; None means no cap.
    top_k: int | None = None
    # Also compare each modality's items against themselves.
    include_self_similarity: bool = False
    # L2-normalise vectors before the dot product (cosine similarity).
    normalize_scores: bool = True
# One scored pair of items, reported by the similarity endpoint.
class SimilarityMatch(BaseModel):
    # Row / column position of the two items within their modality's list.
    index_a: int
    index_b: int
    # Similarity score for the pair.
    score: float
    # Modality names the two items come from.
    modality_a: str
    modality_b: str
    # Human-readable identifiers (original filename or text) for each item.
    item_a: str
    item_b: str
# Response body of the similarity endpoint.
class SimilarityResponse(BaseModel):
    # Matches, highest score first.
    matches: list[SimilarityMatch]
    # Summary numbers: avg_score, max_score, min_score, total_comparisons.
    statistics: dict[str, float]
    # Labels of the modality pairs that were compared, e.g. "audio_to_vision".
    modality_pairs: list[str]
class ModalityPair:
    """Unordered pair of modality names with a canonical text label."""

    def __init__(self, mod1: str, mod2: str):
        # Store the names sorted, so (a, b) and (b, a) describe the same pair.
        self.mod1, self.mod2 = sorted((mod1, mod2))

    def __str__(self):
        first, second = self.mod1, self.mod2
        return f"{first}_to_{second}"
def compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool = True) -> torch.Tensor:
    """Return the matrix of dot products between the rows of two embedding sets.

    Args:
        tensor1: Embeddings compared row-wise (normalised along dim 1).
        tensor2: Embeddings compared row-wise (normalised along dim 1).
        normalize: L2-normalise both inputs first, making each entry a cosine
            similarity. When False, raw dot products are returned.

    Returns:
        ``tensor1 @ tensor2.T``. Entry ``[i, j]`` scores row ``i`` of
        ``tensor1`` against row ``j`` of ``tensor2``.
    """
    if normalize:
        l2_normalize = torch.nn.functional.normalize
        tensor1, tensor2 = l2_normalize(tensor1, dim=1), l2_normalize(tensor2, dim=1)
    return tensor1 @ tensor2.T
def get_top_k_matches(similarity_matrix: torch.Tensor, top_k: int | None = None) -> List[tuple]: | |
"""Get top-k matches from a similarity matrix.""" | |
if top_k is None: | |
top_k = similarity_matrix.numel() | |
# Flatten and get top-k indices | |
flat_sim = similarity_matrix.flatten() | |
top_k = min(top_k, flat_sim.numel()) | |
values, indices = torch.topk(flat_sim, k=top_k) | |
# Convert flat indices to 2D indices | |
rows = indices // similarity_matrix.size(1) | |
cols = indices % similarity_matrix.size(1) | |
return [(r.item(), c.item(), v.item()) for r, c, v in zip(rows, cols, values)] | |
async def _save_upload_to_temp(upload: UploadFile, temp_files: List[str]) -> str:
    """Write one uploaded file to a new temporary file and return its path.

    The path is appended to ``temp_files`` before any data is written, so the
    caller's cleanup removes it even if reading or writing fails. The
    upload's extension is kept as the suffix because
    ``convert_audio_to_wav`` decides whether to convert based on it.
    """
    # UploadFile.filename can be None; treat that as "no extension".
    suffix = os.path.splitext(upload.filename or "")[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        temp_files.append(tmp.name)
        tmp.write(await upload.read())
    # Leaving the ``with`` block flushes and closes the file, so anything that
    # reopens it by path sees the complete contents.
    return tmp.name


async def generate_embeddings(
    credentials: HTTPAuthorizationCredentials = Depends(verify_token),
    texts: str | None = Form(None),
    images: List[UploadFile] | None = File(default=None),
    audio_files: List[UploadFile] | None = File(default=None)
):
    """Embed any uploaded images and audio files and any newline-separated texts.

    NOTE(review): no ``@app.<method>(...)`` route decorator is visible for
    this function, so as written it is not registered on ``app``. Confirm.
    NOTE(review): model inference runs synchronously inside this coroutine and
    blocks the event loop while it runs. Consider a threadpool if concurrent
    requests matter.

    Args:
        credentials: Bearer token, validated by ``verify_token``.
        texts: Optional form field. Every non-blank line (whitespace-stripped)
            is embedded as one text.
        images: Optional uploaded image files.
        audio_files: Optional uploaded audio files. ``.mp3`` uploads are
            converted to WAV before embedding.

    Returns:
        EmbeddingResponse. ``embeddings`` holds per-modality results
        (``vision`` / ``audio`` / ``text``). ``file_names`` holds the matching
        input names under ``images`` / ``audio`` / ``texts``. Both are empty
        when nothing was supplied.
    """
    temp_files: List[str] = []
    try:
        image_paths = []
        image_names = []
        for img in images or []:
            image_paths.append(await _save_upload_to_temp(img, temp_files))
            image_names.append(img.filename)

        audio_paths = []
        audio_names = []
        for audio in audio_files or []:
            raw_path = await _save_upload_to_temp(audio, temp_files)
            # Convert only once the upload is fully written and closed. The
            # previous version converted while the file was still open, so the
            # converter could read a partially flushed file.
            audio_path = convert_audio_to_wav(raw_path)
            if audio_path != raw_path:
                temp_files.append(audio_path)
            audio_paths.append(audio_path)
            audio_names.append(audio.filename)

        text_list = [line.strip() for line in (texts or "").split('\n') if line.strip()]

        if not (image_paths or audio_paths or text_list):
            return EmbeddingResponse(
                embeddings={},
                file_names={}
            )

        embeddings = embedding_manager.compute_embeddings(
            image_paths or None,
            audio_paths or None,
            text_list or None,
        )

        file_names = {}
        if image_names:
            file_names['images'] = image_names
        if audio_names:
            file_names['audio'] = audio_names
        if text_list:
            file_names['texts'] = text_list

        return EmbeddingResponse(
            embeddings=embeddings,
            file_names=file_names
        )
    finally:
        for temp_file in temp_files:
            try:
                os.unlink(temp_file)
            except OSError as exc:
                # Best-effort cleanup: a leftover temp file is logged but must
                # not turn an otherwise successful request into an error.
                logging.getLogger(__name__).warning(
                    "Could not remove temporary file %s: %s", temp_file, exc
                )
# generate_embeddings keys embeddings by modality ("vision" / "audio" / "text")
# but keys file_names by input group ("images" / "audio" / "texts"). This maps
# the former onto the latter so its output can be passed straight back in.
_FILE_NAME_KEYS = {"vision": "images", "text": "texts"}


def _item_label(file_names: Dict[str, List[str]], modality: str, index: int) -> str:
    """Return a readable label for item ``index`` of ``modality``.

    Tries ``file_names[modality]`` first, then the group name that
    generate_embeddings uses for that modality. Falls back to the item's
    index as a string when no label is available, instead of failing the
    request with a KeyError/IndexError.
    """
    names = file_names.get(modality)
    if names is None:
        names = file_names.get(_FILE_NAME_KEYS.get(modality, modality))
    if names is not None and 0 <= index < len(names):
        return str(names[index])
    return str(index)


def _embeddings_to_tensors(embeddings: Dict[str, List[List[float]]]) -> Dict[str, torch.Tensor]:
    """Convert each non-empty embedding list to a tensor.

    Raises:
        HTTPException: 400 if a list is not rectangular, or if the modalities
            disagree on embedding size. Every pair of distinct modalities is
            multiplied together, so a size mismatch would otherwise fail with
            a 500.
    """
    tensors = {}
    for modality, vectors in embeddings.items():
        if not (isinstance(vectors, (list, np.ndarray)) and len(vectors) > 0):
            continue
        try:
            tensors[modality] = torch.tensor(vectors)
        except (TypeError, ValueError) as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Embeddings for modality '{modality}' must be a rectangular list of vectors.",
            ) from exc

    dims = {tensor.shape[-1] for tensor in tensors.values()}
    if len(dims) > 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"All modalities must share one embedding size; got {sorted(dims)}.",
        )
    return tensors


async def compute_similarities(
    request: SimilarityRequest,
    file_names: Dict[str, List[str]],  # Maps modality to list of file/text names
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """Compute filtered, ranked similarities between sets of embeddings.

    NOTE(review): this definition rebinds the module-level name
    ``compute_similarities``, replacing the synchronous helper of the same
    name defined earlier. No ``@app.<method>(...)`` route decorator is visible
    here. Confirm how this endpoint is registered.

    Args:
        request: SimilarityRequest. Its fields are used as follows.
            - ``embeddings``: modality -> list of embedding vectors.
            - ``threshold``: matches scoring below it are dropped.
            - ``top_k``: at most this many best cells are taken from each
              modality pair before threshold filtering (None = all).
            - ``include_self_similarity``: also compare each modality with
              itself. Every item is then paired with itself, and each pair of
              distinct items appears in both orders.
            - ``normalize_scores``: L2-normalise first (cosine similarity).
        file_names: Item labels, keyed either by embedding modality
            (``vision`` / ``audio`` / ``text``) or by the group names that
            generate_embeddings returns (``images`` / ``audio`` / ``texts``).
            A missing label falls back to the item's index.
        credentials: Bearer token, validated by ``verify_token``.

    Returns:
        SimilarityResponse. It contains the matches sorted by descending
        score, statistics over the reported matches, and the labels of the
        modality pairs compared.

    Raises:
        HTTPException: 400 for malformed or dimension-mismatched embeddings.
    """
    tensors = _embeddings_to_tensors(request.embeddings)

    matches = []
    all_scores = []
    modality_pairs = []
    modalities = list(tensors)
    for i, mod1 in enumerate(modalities):
        # modalities[i:] visits each unordered pair once. It includes
        # (mod1, mod1), which is compared only when self-similarity is
        # requested. (A former per-cell "skip the diagonal" check could never
        # fire, because same-modality pairs only get here when self-similarity
        # is on.)
        for mod2 in modalities[i:]:
            if mod1 == mod2 and not request.include_self_similarity:
                continue
            modality_pairs.append(str(ModalityPair(mod1, mod2)))

            sim_matrix = compute_similarity_matrix(
                tensors[mod1],
                tensors[mod2],
                normalize=request.normalize_scores
            )
            for idx_a, idx_b, score in get_top_k_matches(sim_matrix, request.top_k):
                if score < request.threshold:
                    continue
                matches.append(SimilarityMatch(
                    index_a=idx_a,
                    index_b=idx_b,
                    score=float(score),
                    modality_a=mod1,
                    modality_b=mod2,
                    item_a=_item_label(file_names, mod1, idx_a),
                    item_b=_item_label(file_names, mod2, idx_b)
                ))
                all_scores.append(score)

    statistics = {
        "avg_score": 0.0,
        "max_score": 0.0,
        "min_score": 1.0,
        "total_comparisons": 0
    }
    if all_scores:
        statistics.update({
            "avg_score": float(np.mean(all_scores)),
            "max_score": float(np.max(all_scores)),
            "min_score": float(np.min(all_scores)),
            # NOTE: despite the key name, this counts the reported matches
            # (those that passed top_k and threshold), not every cell compared.
            "total_comparisons": len(all_scores)
        })

    matches.sort(key=lambda match: match.score, reverse=True)

    return SimilarityResponse(
        matches=matches,
        statistics=statistics,
        modality_pairs=modality_pairs
    )
async def health_check(
    credentials: HTTPAuthorizationCredentials = Depends(verify_token)
):
    """Report that the service is up and which device the model was placed on.

    Requires a valid bearer token, checked by ``verify_token``.
    NOTE(review): no route decorator is visible for this handler. Confirm it
    is registered on ``app``.
    """
    report = {"status": "healthy"}
    report["model_device"] = embedding_manager.device
    return report
if __name__ == "__main__":
    # Serve the app on every network interface, on port 7860.
    # NOTE(review): 7860 presumably matches the hosting Space's expected app
    # port. Confirm before changing it.
    server_settings = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_settings)