import os
import secrets
import torch
from enum import Enum
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType as ImageBindModalityType
from pydub import AudioSegment
from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field  # BaseSettings is no longer imported from here
from pydantic_settings import BaseSettings  # moved to pydantic-settings in Pydantic v2
from typing import List, Dict, Optional, Tuple, Any
import tempfile
import uvicorn
import numpy as np
import logging
from contextlib import asynccontextmanager
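
# Dependencies implied by the imports above: torch, numpy, fastapi, uvicorn,
# pydub (which needs ffmpeg on PATH), pydantic>=2 with pydantic-settings, and
# ImageBind from github.com/facebookresearch/ImageBind.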

class Settings(BaseSettings):
    api_token: str = "your-default-token-here"
    model_device: Optional[str] = None
    log_level: str = "INFO"

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
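
# Example .env picked up by Settings (illustrative values, not real credentials):
#   API_TOKEN=change-me
#   MODEL_DEVICE=cuda:0
#   LOG_LEVEL=DEBUG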
settings = Settings()

logging.basicConfig(level=settings.log_level.upper())
logger = logging.getLogger(__name__)

class EmbeddingManager:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # object.__new__ takes no extra arguments; forwarding *args/**kwargs
            # here raises TypeError on Python 3.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not hasattr(self, "initialized"):
            self.device = settings.model_device or ("cuda:0" if torch.cuda.is_available() else "cpu")
            logger.info(f"Initializing EmbeddingManager on device: {self.device}")
            try:
                self.model = imagebind_model.imagebind_huge(pretrained=True)
                self.model.eval()
                self.model.to(self.device)
                self.initialized = True
                logger.info("ImageBind model loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load ImageBind model: {e}")
                raise RuntimeError(f"Failed to load ImageBind model: {e}")

    async def compute_embeddings(
        self,
        image_inputs: Optional[List[Tuple[str, str]]] = None,
        audio_inputs: Optional[List[Tuple[str, str]]] = None,
        text_inputs: Optional[List[str]] = None,
        depth_inputs: Optional[List[Tuple[str, str]]] = None,
        thermal_inputs: Optional[List[Tuple[str, str]]] = None,
        imu_inputs: Optional[List[Tuple[str, str]]] = None,
    ) -> Dict[str, List[Dict[str, Any]]]:
        inputs = {}
        input_ids = {}
        if text_inputs:
            inputs[ImageBindModalityType.TEXT] = data.load_and_transform_text(text_inputs, self.device)
            input_ids[ImageBindModalityType.TEXT] = text_inputs
        if image_inputs:
            paths = [item[0] for item in image_inputs]
            inputs[ImageBindModalityType.VISION] = data.load_and_transform_vision_data(paths, self.device)
            input_ids[ImageBindModalityType.VISION] = [item[1] for item in image_inputs]
        if audio_inputs:
            paths = [item[0] for item in audio_inputs]
            inputs[ImageBindModalityType.AUDIO] = data.load_and_transform_audio_data(paths, self.device)
            input_ids[ImageBindModalityType.AUDIO] = [item[1] for item in audio_inputs]
        if depth_inputs:
            logger.warning("Depth modality processing is not yet implemented.")
        if thermal_inputs:
            logger.warning("Thermal modality processing is not yet implemented.")
        if imu_inputs:
            logger.warning("IMU modality processing is not yet implemented.")
        if not inputs:
            return {}

        # torch.no_grad() is thread-local, so it must be entered inside the
        # worker thread that run_in_threadpool dispatches to, not around the await.
        def _forward():
            with torch.no_grad():
                return self.model(inputs)

        raw_embeddings = await run_in_threadpool(_forward)

        result_embeddings = {}
        for modality_type, embeddings_tensor in raw_embeddings.items():
            # ImageBind's ModalityType values are plain lowercase strings
            # ("vision", "audio", ...), not enum members, so `.name` would fail here.
            modality_key = str(modality_type).lower()
            result_embeddings[modality_key] = []
            ids_for_modality = input_ids.get(modality_type, [])
            for i, emb in enumerate(embeddings_tensor.cpu().numpy().tolist()):
                item_id = ids_for_modality[i] if i < len(ids_for_modality) else f"item_{i}"
                result_embeddings[modality_key].append({"id": item_id, "embedding": emb})
        return result_embeddings
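
# Minimal usage sketch (illustrative; assumes an asyncio context such as a
# FastAPI handler):
#   manager = EmbeddingManager()
#   result = await manager.compute_embeddings(text_inputs=["a dog", "a cat"])
#   vector = result["text"][0]["embedding"]  # list[float], 1024-dim for imagebind_huge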

embedding_manager: Optional[EmbeddingManager] = None


@asynccontextmanager  # FastAPI lifespan handlers must be async context managers
async def lifespan(app: FastAPI):
    global embedding_manager
    logger.info("Application startup...")
    embedding_manager = EmbeddingManager()
    settings.model_device = embedding_manager.device
    yield
    logger.info("Application shutdown...")


app = FastAPI(lifespan=lifespan, title="ImageBind API", version="0.2.0")
security = HTTPBearer()


async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    # compare_digest avoids leaking token contents through comparison timing.
    if credentials.scheme != "Bearer" or not secrets.compare_digest(credentials.credentials, settings.api_token):
        logger.warning(f"Invalid authentication attempt. Scheme: {credentials.scheme}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
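
# Clients pass the token as a standard bearer header (token value illustrative):
#   curl -H "Authorization: Bearer change-me" http://localhost:7860/...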

async def _save_upload_file_tmp(upload_file: UploadFile) -> Tuple[str, str]:
    try:
        # filename may be None on some clients; fall back to a generic name.
        filename = upload_file.filename or "upload"
        suffix = os.path.splitext(filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            content = await upload_file.read()
            tmp.write(content)
        return tmp.name, filename
    except Exception as e:
        logger.error(f"Error saving uploaded file {upload_file.filename}: {e}")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                            detail=f"Could not save file: {upload_file.filename}")

def convert_audio_to_wav(audio_path: str, original_filename: str) -> str:
    # Normalize everything to WAV before handing it to the audio loader;
    # the .wav check alone already covers the MP3 case.
    if not audio_path.lower().endswith(".wav"):
        wav_path = audio_path.rsplit(".", 1)[0] + ".wav"
        try:
            logger.info(f"Converting {original_filename} to WAV format.")
            audio = AudioSegment.from_file(audio_path)
            audio.export(wav_path, format="wav")
            if audio_path != wav_path and os.path.exists(audio_path):
                try:
                    os.unlink(audio_path)
                except OSError:
                    pass
            return wav_path
        except Exception as e:
            logger.error(f"Error converting audio file {original_filename} to WAV: {e}")
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                                detail=f"Could not process audio file {original_filename}: {e}")
    return audio_path
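
# Note: pydub delegates decoding to ffmpeg (or libav); the binary must be
# available on PATH for non-WAV inputs to convert successfully.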

class ModalityType(str, Enum):
    # str+Enum so members serialize as their string values and `.value`
    # (used in the similarity endpoint below) actually exists; a bare str
    # subclass has no `.value` attribute.
    VISION = "vision"
    AUDIO = "audio"
    TEXT = "text"
    DEPTH = "depth"
    THERMAL = "thermal"
    IMU = "imu"

class EmbeddingItem(BaseModel):
    id: str
    embedding: List[float]


class EmbeddingPayload(BaseModel):
    vision: Optional[List[EmbeddingItem]] = None
    audio: Optional[List[EmbeddingItem]] = None
    text: Optional[List[EmbeddingItem]] = None
    depth: Optional[List[EmbeddingItem]] = None
    thermal: Optional[List[EmbeddingItem]] = None
    imu: Optional[List[EmbeddingItem]] = None


class EmbeddingResponse(BaseModel):
    embeddings: EmbeddingPayload
    message: str = "Embeddings computed successfully"


class SimilarityMatch(BaseModel):
    item_a_id: str
    item_b_id: str
    modality_a: ModalityType
    modality_b: ModalityType
    score: float


class SimilarityRequest(BaseModel):
    embeddings_payload: EmbeddingPayload
    threshold: float = 0.5
    top_k: Optional[int] = None
    normalize_scores: bool = True
    compare_within_modalities: bool = True
    compare_across_modalities: bool = True


class SimilarityResponse(BaseModel):
    matches: List[SimilarityMatch]
    statistics: Dict[str, float]
    modality_pairs_compared: List[str]

@app.post("/compute_embeddings", response_model=EmbeddingResponse,
          dependencies=[Depends(verify_token)])  # path taken from the log message below; auth wiring assumed
async def generate_embeddings_endpoint(
    texts: Optional[List[str]] = Form(None),
    images: Optional[List[UploadFile]] = File(default=None),
    audio_files: Optional[List[UploadFile]] = File(default=None)
):
    if embedding_manager is None:
        raise HTTPException(status_code=503, detail="Embedding manager not initialized.")
    temp_files_to_clean = []
    try:
        image_inputs: List[Tuple[str, str]] = []
        audio_inputs: List[Tuple[str, str]] = []
        if images:
            for img_file in images:
                path, name = await _save_upload_file_tmp(img_file)
                image_inputs.append((path, name))
                temp_files_to_clean.append(path)
        if audio_files:
            for audio_file_in in audio_files:
                path, name = await _save_upload_file_tmp(audio_file_in)
                temp_files_to_clean.append(path)
                wav_path = convert_audio_to_wav(path, name)
                audio_inputs.append((wav_path, name))
                if wav_path != path:
                    temp_files_to_clean.append(wav_path)
        text_inputs_processed = [t.strip() for t in texts if t.strip()] if texts else None
        if not any([image_inputs, audio_inputs, text_inputs_processed]):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                                detail="No valid inputs provided for embedding.")
        computed_data = await embedding_manager.compute_embeddings(
            image_inputs=image_inputs if image_inputs else None,
            audio_inputs=audio_inputs if audio_inputs else None,
            text_inputs=text_inputs_processed if text_inputs_processed else None
        )
        # compute_embeddings returns plain string keys, so index with .value
        # rather than the enum members themselves.
        payload_data = {
            ModalityType.VISION.value: computed_data.get(ModalityType.VISION.value, []),
            ModalityType.AUDIO.value: computed_data.get(ModalityType.AUDIO.value, []),
            ModalityType.TEXT.value: computed_data.get(ModalityType.TEXT.value, []),
        }
        embedding_payload = EmbeddingPayload(**payload_data)
        return EmbeddingResponse(embeddings=embedding_payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in /compute_embeddings: {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                            detail=f"An unexpected error occurred: {str(e)}")
    finally:
        for temp_file in temp_files_to_clean:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
            except Exception as e_clean:
                logger.warning(f"Could not clean up temporary file {temp_file}: {e_clean}")
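
# Example request (token and file names are illustrative):
#   curl -X POST http://localhost:7860/compute_embeddings \
#     -H "Authorization: Bearer change-me" \
#     -F "texts=a photo of a dog" \
#     -F "images=@dog.jpg" \
#     -F "audio_files=@bark.mp3"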

def _compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool) -> torch.Tensor:
    if normalize:
        tensor1 = torch.nn.functional.normalize(tensor1, p=2, dim=1)
        tensor2 = torch.nn.functional.normalize(tensor2, p=2, dim=1)
    return torch.matmul(tensor1, tensor2.T)
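
# With normalize=True both sets of rows are unit length, so the matrix product
# is pairwise cosine similarity in [-1, 1]; with normalize=False the threshold
# in SimilarityRequest is applied to raw dot products instead.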

@app.post("/compute_similarities", response_model=SimilarityResponse,
          dependencies=[Depends(verify_token)])  # route path inferred from the handler name
async def compute_similarities_endpoint(request: SimilarityRequest):
    all_matches: List[SimilarityMatch] = []
    all_scores: List[float] = []
    modality_pairs_compared_set = set()
    embeddings_by_modality: Dict[ModalityType, List[EmbeddingItem]] = {}
    if request.embeddings_payload.vision:
        embeddings_by_modality[ModalityType.VISION] = request.embeddings_payload.vision
    if request.embeddings_payload.audio:
        embeddings_by_modality[ModalityType.AUDIO] = request.embeddings_payload.audio
    if request.embeddings_payload.text:
        embeddings_by_modality[ModalityType.TEXT] = request.embeddings_payload.text
    modalities_present = list(embeddings_by_modality.keys())
    current_device = embedding_manager.device if embedding_manager else "cpu"
    for i, mod1_type in enumerate(modalities_present):
        items1 = embeddings_by_modality[mod1_type]
        if not items1:
            continue
        tensor1 = torch.tensor([item.embedding for item in items1], device=current_device)
        if request.compare_within_modalities:
            sim_matrix_intra = _compute_similarity_matrix(tensor1, tensor1, request.normalize_scores)
            modality_pairs_compared_set.add(f"{mod1_type.value}_vs_{mod1_type.value}")
            # Upper triangle only: each unordered pair once, no self-matches.
            for r_idx in range(len(items1)):
                for c_idx in range(r_idx + 1, len(items1)):
                    score = float(sim_matrix_intra[r_idx, c_idx].item())
                    if score >= request.threshold:
                        all_matches.append(SimilarityMatch(
                            item_a_id=items1[r_idx].id, item_b_id=items1[c_idx].id,
                            modality_a=mod1_type, modality_b=mod1_type, score=score
                        ))
                        all_scores.append(score)
        if request.compare_across_modalities:
            for j in range(i + 1, len(modalities_present)):
                mod2_type = modalities_present[j]
                items2 = embeddings_by_modality[mod2_type]
                if not items2:
                    continue
                tensor2 = torch.tensor([item.embedding for item in items2], device=current_device)
                sim_matrix_inter = _compute_similarity_matrix(tensor1, tensor2, request.normalize_scores)
                modality_pairs_compared_set.add(f"{mod1_type.value}_vs_{mod2_type.value}")
                for r_idx in range(len(items1)):
                    for c_idx in range(len(items2)):
                        score = float(sim_matrix_inter[r_idx, c_idx].item())
                        if score >= request.threshold:
                            all_matches.append(SimilarityMatch(
                                item_a_id=items1[r_idx].id, item_b_id=items2[c_idx].id,
                                modality_a=mod1_type, modality_b=mod2_type, score=score
                            ))
                            all_scores.append(score)
    all_matches.sort(key=lambda x: x.score, reverse=True)
    if request.top_k and len(all_matches) > request.top_k:
        all_matches = all_matches[:request.top_k]
        all_scores = [match.score for match in all_matches]
    stats = {
        "total_matches_found_above_threshold": len(all_matches),
        "avg_score": float(np.mean(all_scores)) if all_scores else 0.0,
        "max_score": float(np.max(all_scores)) if all_scores else 0.0,
        "min_score": float(np.min(all_scores)) if all_scores else 0.0,
    }
    return SimilarityResponse(
        matches=all_matches,
        statistics=stats,
        modality_pairs_compared=sorted(modality_pairs_compared_set)
    )
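
# Example request body (embedding vectors elided; their length must match the
# model's output dimension):
#   {"embeddings_payload": {"text":   [{"id": "t1",   "embedding": [...]}],
#                           "vision": [{"id": "img1", "embedding": [...]}]},
#    "threshold": 0.3, "top_k": 10}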

@app.get("/health")  # route path assumed; the handler was not registered anywhere
async def health_check():
    return {
        "status": "healthy",
        "model_device": settings.model_device,
        "torch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available()
    }


if __name__ == "__main__":
    # 7860 is the port Hugging Face Spaces expects the app to listen on.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level=settings.log_level.lower())