# imagebind2 / main.py
import os
import torch
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType as ImageBindModalityType
from pydub import AudioSegment
from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field  # BaseSettings no longer lives here in pydantic v2
from pydantic_settings import BaseSettings  # it moved to the pydantic-settings package
from typing import List, Dict, Optional, Tuple, Any
from enum import Enum
import secrets
import tempfile
import uvicorn
import numpy as np
import logging
from contextlib import asynccontextmanager
class Settings(BaseSettings):
api_token: str = "your-default-token-here"
model_device: Optional[str] = None
log_level: str = "INFO"
class Config:
env_file = ".env"
env_file_encoding = 'utf-8'
settings = Settings()
logging.basicConfig(level=settings.log_level.upper())
logger = logging.getLogger(__name__)
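
# A minimal example .env for the Settings class above (placeholder values):
#
#   API_TOKEN=change-me
#   MODEL_DEVICE=cuda:0
#   LOG_LEVEL=INFO
#
# pydantic-settings matches environment variables to field names
# case-insensitively, so API_TOKEN populates Settings.api_token.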
class EmbeddingManager:
_instance = None
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ accepts no extra arguments, so *args/**kwargs
            # must not be forwarded here (doing so raises a TypeError).
            cls._instance = super().__new__(cls)
        return cls._instance
def __init__(self):
if not hasattr(self, 'initialized'):
self.device = settings.model_device or ("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Initializing EmbeddingManager on device: {self.device}")
try:
self.model = imagebind_model.imagebind_huge(pretrained=True)
self.model.eval()
self.model.to(self.device)
self.initialized = True
logger.info("ImageBind model loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load ImageBind model: {e}")
                raise RuntimeError(f"Failed to load ImageBind model: {e}") from e
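    # Note: imagebind_huge(pretrained=True) fetches the checkpoint into
    # ./.checkpoints/ on first use, so the first startup needs network
    # access and several GB of disk space.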
async def compute_embeddings(self,
image_inputs: Optional[List[Tuple[str, str]]] = None,
audio_inputs: Optional[List[Tuple[str, str]]] = None,
text_inputs: Optional[List[str]] = None,
depth_inputs: Optional[List[Tuple[str, str]]] = None,
thermal_inputs: Optional[List[Tuple[str, str]]] = None,
imu_inputs: Optional[List[Tuple[str, str]]] = None
) -> Dict[str, List[Dict[str, Any]]]:
inputs = {}
input_ids = {}
if text_inputs:
inputs[ImageBindModalityType.TEXT] = data.load_and_transform_text(text_inputs, self.device)
input_ids[ImageBindModalityType.TEXT] = text_inputs
if image_inputs:
paths = [item[0] for item in image_inputs]
inputs[ImageBindModalityType.VISION] = data.load_and_transform_vision_data(paths, self.device)
input_ids[ImageBindModalityType.VISION] = [item[1] for item in image_inputs]
if audio_inputs:
paths = [item[0] for item in audio_inputs]
inputs[ImageBindModalityType.AUDIO] = data.load_and_transform_audio_data(paths, self.device)
input_ids[ImageBindModalityType.AUDIO] = [item[1] for item in audio_inputs]
if depth_inputs:
logger.warning("Depth modality processing is not yet fully implemented.")
if thermal_inputs:
logger.warning("Thermal modality processing is not yet fully implemented.")
if imu_inputs:
logger.warning("IMU modality processing is not yet fully implemented.")
if not inputs:
return {}
        def _forward():
            # Grad mode is thread-local in PyTorch, so no_grad must be
            # entered inside the worker thread that run_in_threadpool
            # dispatches to, not in the event-loop thread awaiting it.
            with torch.no_grad():
                return self.model(inputs)

        raw_embeddings = await run_in_threadpool(_forward)
result_embeddings = {}
        for modality_type, embeddings_tensor in raw_embeddings.items():
            # ImageBind's ModalityType members are plain lowercase strings
            # ("vision", "text", "audio", ...), not enums with a .name.
            modality_key = str(modality_type).lower()
result_embeddings[modality_key] = []
ids_for_modality = input_ids.get(modality_type, [])
for i, emb in enumerate(embeddings_tensor.cpu().numpy().tolist()):
item_id = ids_for_modality[i] if i < len(ids_for_modality) else f"item_{i}"
result_embeddings[modality_key].append({"id": item_id, "embedding": emb})
return result_embeddings
embedding_manager: Optional[EmbeddingManager] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global embedding_manager
logger.info("Application startup...")
embedding_manager = EmbeddingManager()
settings.model_device = embedding_manager.device
yield
logger.info("Application shutdown...")
app = FastAPI(lifespan=lifespan, title="ImageBind API", version="0.2.0")
security = HTTPBearer()
async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    # HTTPBearer already rejects non-Bearer schemes, so only the token itself
    # needs checking; compare_digest keeps the comparison constant-time.
    if not secrets.compare_digest(credentials.credentials, settings.api_token):
        logger.warning("Invalid authentication attempt.")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication token",
headers={"WWW-Authenticate": "Bearer"},
)
return credentials.credentials
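
# Example authenticated call (hypothetical token; the server listens on
# port 7860, see the __main__ block below):
#
#   curl -H "Authorization: Bearer change-me" http://localhost:7860/health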
async def _save_upload_file_tmp(upload_file: UploadFile) -> Tuple[str, str]:
try:
        suffix = os.path.splitext(upload_file.filename or "")[1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
content = await upload_file.read()
tmp.write(content)
return tmp.name, upload_file.filename
except Exception as e:
logger.error(f"Error saving uploaded file {upload_file.filename}: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Could not save file: {upload_file.filename}")
def convert_audio_to_wav(audio_path: str, original_filename: str) -> str:
    # pydub delegates decoding to ffmpeg, which must be available on PATH.
    # Any non-WAV input (mp3, ogg, flac, ...) is converted; the separate
    # .mp3 check was redundant, since .mp3 never ends with .wav anyway.
    if not audio_path.lower().endswith('.wav'):
wav_path = audio_path.rsplit('.', 1)[0] + '.wav'
try:
logger.info(f"Converting {original_filename} to WAV format.")
audio = AudioSegment.from_file(audio_path)
audio.export(wav_path, format='wav')
if audio_path != wav_path and os.path.exists(audio_path):
try:
os.unlink(audio_path)
except OSError:
pass
return wav_path
except Exception as e:
logger.error(f"Error converting audio file {original_filename} to WAV: {e}")
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Could not process audio file {original_filename}: {e}")
return audio_path
class ModalityType(str, Enum):
    # A bare `class ModalityType(str)` has no .value attribute, which the
    # similarity endpoint relies on; a str-Enum provides it and lets
    # pydantic validate the modality fields.
    VISION = "vision"
    AUDIO = "audio"
    TEXT = "text"
    DEPTH = "depth"
    THERMAL = "thermal"
    IMU = "imu"
class EmbeddingItem(BaseModel):
id: str
embedding: List[float]
class EmbeddingPayload(BaseModel):
vision: Optional[List[EmbeddingItem]] = None
audio: Optional[List[EmbeddingItem]] = None
text: Optional[List[EmbeddingItem]] = None
depth: Optional[List[EmbeddingItem]] = None
thermal: Optional[List[EmbeddingItem]] = None
imu: Optional[List[EmbeddingItem]] = None
class EmbeddingResponse(BaseModel):
embeddings: EmbeddingPayload
message: str = "Embeddings computed successfully"
class SimilarityMatch(BaseModel):
item_a_id: str
item_b_id: str
modality_a: ModalityType
modality_b: ModalityType
score: float
class SimilarityRequest(BaseModel):
embeddings_payload: EmbeddingPayload
threshold: float = 0.5
top_k: Optional[int] = None
normalize_scores: bool = True
compare_within_modalities: bool = True
compare_across_modalities: bool = True
class SimilarityResponse(BaseModel):
matches: List[SimilarityMatch]
statistics: Dict[str, float]
modality_pairs_compared: List[str]
@app.post("/compute_embeddings", response_model=EmbeddingResponse, dependencies=[Depends(verify_token)])
async def generate_embeddings_endpoint(
texts: Optional[List[str]] = Form(None),
images: Optional[List[UploadFile]] = File(default=None),
audio_files: Optional[List[UploadFile]] = File(default=None)
):
if embedding_manager is None:
raise HTTPException(status_code=503, detail="Embedding manager not initialized.")
temp_files_to_clean = []
try:
image_inputs: List[Tuple[str, str]] = []
audio_inputs: List[Tuple[str, str]] = []
if images:
for img_file in images:
path, name = await _save_upload_file_tmp(img_file)
image_inputs.append((path, name))
temp_files_to_clean.append(path)
if audio_files:
for audio_file_in in audio_files:
path, name = await _save_upload_file_tmp(audio_file_in)
temp_files_to_clean.append(path)
wav_path = convert_audio_to_wav(path, name)
audio_inputs.append((wav_path, name))
if wav_path != path:
temp_files_to_clean.append(wav_path)
text_inputs_processed = [t.strip() for t in texts if t.strip()] if texts else None
if not any([image_inputs, audio_inputs, text_inputs_processed]):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No valid inputs provided for embedding.")
computed_data = await embedding_manager.compute_embeddings(
image_inputs=image_inputs if image_inputs else None,
audio_inputs=audio_inputs if audio_inputs else None,
text_inputs=text_inputs_processed if text_inputs_processed else None
)
        payload_data = {
            # Enum members hash by name, so use .value to match the plain
            # string keys returned by compute_embeddings.
            ModalityType.VISION.value: computed_data.get(ModalityType.VISION.value, []),
            ModalityType.AUDIO.value: computed_data.get(ModalityType.AUDIO.value, []),
            ModalityType.TEXT.value: computed_data.get(ModalityType.TEXT.value, []),
        }
embedding_payload = EmbeddingPayload(**payload_data)
return EmbeddingResponse(embeddings=embedding_payload)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error in /compute_embeddings: {e}", exc_info=True)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An unexpected error occurred: {str(e)}")
finally:
for temp_file in temp_files_to_clean:
try:
if os.path.exists(temp_file):
os.unlink(temp_file)
except Exception as e_clean:
logger.warning(f"Could not clean up temporary file {temp_file}: {e_clean}")
def _compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool) -> torch.Tensor:
    """Pairwise similarity between two batches of embeddings.

    With normalize=True the rows are L2-normalized first, so the matrix
    product yields cosine similarities; otherwise it is a raw dot product.
    """
    if normalize:
        tensor1 = torch.nn.functional.normalize(tensor1, p=2, dim=1)
        tensor2 = torch.nn.functional.normalize(tensor2, p=2, dim=1)
    return torch.matmul(tensor1, tensor2.T)
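
# Illustrative /compute_similarities request body (ImageBind-Huge
# embeddings are 1024-dimensional; "..." elides the rest):
#
#   {
#     "embeddings_payload": {
#       "text":   [{"id": "a dog", "embedding": [0.01, ...]}],
#       "vision": [{"id": "dog.jpg", "embedding": [0.02, ...]}]
#     },
#     "threshold": 0.2,
#     "top_k": 5
#   }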
@app.post("/compute_similarities", response_model=SimilarityResponse, dependencies=[Depends(verify_token)])
async def compute_similarities_endpoint(request: SimilarityRequest):
all_matches: List[SimilarityMatch] = []
all_scores: List[float] = []
modality_pairs_compared_set = set()
embeddings_by_modality: Dict[ModalityType, List[EmbeddingItem]] = {}
if request.embeddings_payload.vision:
embeddings_by_modality[ModalityType.VISION] = request.embeddings_payload.vision
if request.embeddings_payload.audio:
embeddings_by_modality[ModalityType.AUDIO] = request.embeddings_payload.audio
if request.embeddings_payload.text:
embeddings_by_modality[ModalityType.TEXT] = request.embeddings_payload.text
modalities_present = list(embeddings_by_modality.keys())
current_device = embedding_manager.device if embedding_manager else "cpu"
for i, mod1_type in enumerate(modalities_present):
items1 = embeddings_by_modality[mod1_type]
if not items1: continue
tensor1 = torch.tensor([item.embedding for item in items1], device=current_device)
if request.compare_within_modalities:
sim_matrix_intra = _compute_similarity_matrix(tensor1, tensor1, request.normalize_scores)
modality_pairs_compared_set.add(f"{mod1_type.value}_vs_{mod1_type.value}")
for r_idx in range(len(items1)):
for c_idx in range(r_idx + 1, len(items1)):
score = float(sim_matrix_intra[r_idx, c_idx].item())
if score >= request.threshold:
all_matches.append(SimilarityMatch(
item_a_id=items1[r_idx].id, item_b_id=items1[c_idx].id,
modality_a=mod1_type, modality_b=mod1_type, score=score
))
all_scores.append(score)
if request.compare_across_modalities:
for j in range(i + 1, len(modalities_present)):
mod2_type = modalities_present[j]
items2 = embeddings_by_modality[mod2_type]
if not items2: continue
tensor2 = torch.tensor([item.embedding for item in items2], device=current_device)
sim_matrix_inter = _compute_similarity_matrix(tensor1, tensor2, request.normalize_scores)
modality_pairs_compared_set.add(f"{mod1_type.value}_vs_{mod2_type.value}")
for r_idx in range(len(items1)):
for c_idx in range(len(items2)):
score = float(sim_matrix_inter[r_idx, c_idx].item())
if score >= request.threshold:
all_matches.append(SimilarityMatch(
item_a_id=items1[r_idx].id, item_b_id=items2[c_idx].id,
modality_a=mod1_type, modality_b=mod2_type, score=score
))
all_scores.append(score)
all_matches.sort(key=lambda x: x.score, reverse=True)
if request.top_k and len(all_matches) > request.top_k:
all_matches = all_matches[:request.top_k]
all_scores = [match.score for match in all_matches]
stats = {
"total_matches_found_above_threshold": len(all_matches),
"avg_score": float(np.mean(all_scores)) if all_scores else 0.0,
"max_score": float(np.max(all_scores)) if all_scores else 0.0,
"min_score": float(np.min(all_scores)) if all_scores else 0.0,
}
return SimilarityResponse(
matches=all_matches,
statistics=stats,
modality_pairs_compared=sorted(list(modality_pairs_compared_set))
)
@app.get("/health", status_code=status.HTTP_200_OK, dependencies=[Depends(verify_token)])
async def health_check():
return {
"status": "healthy",
"model_device": settings.model_device,
"torch_version": torch.__version__,
"cuda_available": torch.cuda.is_available()
}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860, log_level=settings.log_level.lower())
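
# A minimal client sketch for the two endpoints (assumes a local server on
# port 7860 and a matching API token; file names are hypothetical):
#
#   import requests
#
#   headers = {"Authorization": "Bearer change-me"}
#   resp = requests.post(
#       "http://localhost:7860/compute_embeddings",
#       headers=headers,
#       data={"texts": ["a dog", "a trumpet"]},
#       files=[("images", open("dog.jpg", "rb"))],
#   )
#   payload = resp.json()["embeddings"]
#
#   sim = requests.post(
#       "http://localhost:7860/compute_similarities",
#       headers=headers,
#       json={"embeddings_payload": payload, "threshold": 0.2, "top_k": 5},
#   )
#   print(sim.json()["matches"])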