"""FastAPI service exposing ImageBind embeddings and cross-modal similarity search."""

import os
import tempfile
import logging
from contextlib import asynccontextmanager
from enum import Enum
from typing import List, Dict, Optional, Tuple, Any

import numpy as np
import torch
import uvicorn
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType as ImageBindModalityType
from pydub import AudioSegment
from fastapi import FastAPI, UploadFile, File, Form, Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel  # BaseSettings is no longer imported from here
from pydantic_settings import BaseSettings  # changed import: BaseSettings moved to pydantic-settings in Pydantic v2


class Settings(BaseSettings):
    api_token: str = "your-default-token-here"
    model_device: Optional[str] = None
    log_level: str = "INFO"

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


settings = Settings()

logging.basicConfig(level=settings.log_level.upper())
logger = logging.getLogger(__name__)


class EmbeddingManager:
    """Singleton wrapper around the ImageBind model."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # object.__new__ accepts no extra arguments; forwarding *args would raise.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not hasattr(self, "initialized"):
            self.device = settings.model_device or ("cuda:0" if torch.cuda.is_available() else "cpu")
            logger.info(f"Initializing EmbeddingManager on device: {self.device}")
            try:
                self.model = imagebind_model.imagebind_huge(pretrained=True)
                self.model.eval()
                self.model.to(self.device)
                self.initialized = True
                logger.info("ImageBind model loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load ImageBind model: {e}")
                raise RuntimeError(f"Failed to load ImageBind model: {e}")

    def _run_model(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # torch.no_grad() is thread-local, so it must be entered inside the worker
        # thread that run_in_threadpool dispatches to, not around the await.
        with torch.no_grad():
            return self.model(inputs)

    async def compute_embeddings(
        self,
        image_inputs: Optional[List[Tuple[str, str]]] = None,
        audio_inputs: Optional[List[Tuple[str, str]]] = None,
        text_inputs: Optional[List[str]] = None,
        depth_inputs: Optional[List[Tuple[str, str]]] = None,
        thermal_inputs: Optional[List[Tuple[str, str]]] = None,
        imu_inputs: Optional[List[Tuple[str, str]]] = None,
    ) -> Dict[str, List[Dict[str, Any]]]:
        inputs = {}
        input_ids = {}
        if text_inputs:
            inputs[ImageBindModalityType.TEXT] = data.load_and_transform_text(text_inputs, self.device)
            input_ids[ImageBindModalityType.TEXT] = text_inputs
        if image_inputs:
            paths = [item[0] for item in image_inputs]
            inputs[ImageBindModalityType.VISION] = data.load_and_transform_vision_data(paths, self.device)
            input_ids[ImageBindModalityType.VISION] = [item[1] for item in image_inputs]
        if audio_inputs:
            paths = [item[0] for item in audio_inputs]
            inputs[ImageBindModalityType.AUDIO] = data.load_and_transform_audio_data(paths, self.device)
            input_ids[ImageBindModalityType.AUDIO] = [item[1] for item in audio_inputs]
        if depth_inputs:
            logger.warning("Depth modality processing is not yet fully implemented.")
        if thermal_inputs:
            logger.warning("Thermal modality processing is not yet fully implemented.")
        if imu_inputs:
            logger.warning("IMU modality processing is not yet fully implemented.")

        if not inputs:
            return {}

        raw_embeddings = await run_in_threadpool(self._run_model, inputs)

        result_embeddings = {}
        for modality_type, embeddings_tensor in raw_embeddings.items():
            # ImageBind's ModalityType values are plain lowercase strings
            # ("vision", "audio", ...), so there is no .name attribute to read.
            modality_key = str(modality_type).lower()
            result_embeddings[modality_key] = []
            ids_for_modality = input_ids.get(modality_type, [])
            for i, emb in enumerate(embeddings_tensor.cpu().numpy().tolist()):
                item_id = ids_for_modality[i] if i < len(ids_for_modality) else f"item_{i}"
                result_embeddings[modality_key].append({"id": item_id, "embedding": emb})
        return result_embeddings
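
# A minimal usage sketch of EmbeddingManager outside the API, e.g. for a quick
# smoke test. The prompt and file path below are hypothetical placeholders;
# the call assumes the imagebind package downloaded its checkpoint successfully.
#
#   import asyncio
#   manager = EmbeddingManager()
#   result = asyncio.run(manager.compute_embeddings(
#       text_inputs=["a dog playing in the snow"],
#       image_inputs=[("/tmp/dog.jpg", "dog.jpg")],
#   ))
#   # result is shaped like:
#   # {"text": [{"id": "a dog playing in the snow", "embedding": [...]}],
#   #  "vision": [{"id": "dog.jpg", "embedding": [...]}]}
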
embedding_manager: Optional[EmbeddingManager] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    global embedding_manager
    logger.info("Application startup...")
    embedding_manager = EmbeddingManager()
    settings.model_device = embedding_manager.device
    yield
    logger.info("Application shutdown...")


app = FastAPI(lifespan=lifespan, title="ImageBind API", version="0.2.0")
security = HTTPBearer()


async def verify_token(credentials: HTTPAuthorizationCredentials = Depends(security)):
    if credentials.scheme != "Bearer" or credentials.credentials != settings.api_token:
        logger.warning(f"Invalid authentication attempt. Scheme: {credentials.scheme}")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication token",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials


async def _save_upload_file_tmp(upload_file: UploadFile) -> Tuple[str, str]:
    """Persist an upload to a temp file; return (temp path, original filename)."""
    try:
        suffix = os.path.splitext(upload_file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            content = await upload_file.read()
            tmp.write(content)
        return tmp.name, upload_file.filename
    except Exception as e:
        logger.error(f"Error saving uploaded file {upload_file.filename}: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Could not save file: {upload_file.filename}",
        )


def convert_audio_to_wav(audio_path: str, original_filename: str) -> str:
    """Convert any non-WAV audio file to WAV; return the path to a WAV file."""
    # The .mp3 special case is subsumed by the general "not .wav" check.
    if not audio_path.lower().endswith(".wav"):
        wav_path = audio_path.rsplit(".", 1)[0] + ".wav"
        try:
            logger.info(f"Converting {original_filename} to WAV format.")
            audio = AudioSegment.from_file(audio_path)
            audio.export(wav_path, format="wav")
            if audio_path != wav_path and os.path.exists(audio_path):
                try:
                    os.unlink(audio_path)
                except OSError:
                    pass
            return wav_path
        except Exception as e:
            logger.error(f"Error converting audio file {original_filename} to WAV: {e}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Could not process audio file {original_filename}: {e}",
            )
    return audio_path


class ModalityType(str, Enum):
    # A str Enum so members compare equal to their lowercase string values and
    # serialize cleanly; a bare `class ModalityType(str)` would leave the class
    # attributes as plain strings with no .value attribute, breaking the
    # similarity endpoint below.
    VISION = "vision"
    AUDIO = "audio"
    TEXT = "text"
    DEPTH = "depth"
    THERMAL = "thermal"
    IMU = "imu"


class EmbeddingItem(BaseModel):
    id: str
    embedding: List[float]


class EmbeddingPayload(BaseModel):
    vision: Optional[List[EmbeddingItem]] = None
    audio: Optional[List[EmbeddingItem]] = None
    text: Optional[List[EmbeddingItem]] = None
    depth: Optional[List[EmbeddingItem]] = None
    thermal: Optional[List[EmbeddingItem]] = None
    imu: Optional[List[EmbeddingItem]] = None


class EmbeddingResponse(BaseModel):
    embeddings: EmbeddingPayload
    message: str = "Embeddings computed successfully"


class SimilarityMatch(BaseModel):
    item_a_id: str
    item_b_id: str
    modality_a: ModalityType
    modality_b: ModalityType
    score: float


class SimilarityRequest(BaseModel):
    embeddings_payload: EmbeddingPayload
    threshold: float = 0.5
    top_k: Optional[int] = None
    normalize_scores: bool = True
    compare_within_modalities: bool = True
    compare_across_modalities: bool = True


class SimilarityResponse(BaseModel):
    matches: List[SimilarityMatch]
    statistics: Dict[str, float]
    modality_pairs_compared: List[str]
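
# Example request against /compute_embeddings (host, token, and file are
# placeholders; adjust them to your deployment):
#
#   curl -X POST http://localhost:7860/compute_embeddings \
#        -H "Authorization: Bearer your-default-token-here" \
#        -F "texts=a cat on a sofa" \
#        -F "images=@cat.jpg"
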
@app.post("/compute_embeddings", response_model=EmbeddingResponse, dependencies=[Depends(verify_token)])
async def generate_embeddings_endpoint(
    texts: Optional[List[str]] = Form(None),
    images: Optional[List[UploadFile]] = File(default=None),
    audio_files: Optional[List[UploadFile]] = File(default=None),
):
    if embedding_manager is None:
        raise HTTPException(status_code=503, detail="Embedding manager not initialized.")

    temp_files_to_clean = []
    try:
        image_inputs: List[Tuple[str, str]] = []
        audio_inputs: List[Tuple[str, str]] = []

        if images:
            for img_file in images:
                path, name = await _save_upload_file_tmp(img_file)
                image_inputs.append((path, name))
                temp_files_to_clean.append(path)

        if audio_files:
            for audio_file_in in audio_files:
                path, name = await _save_upload_file_tmp(audio_file_in)
                temp_files_to_clean.append(path)
                wav_path = convert_audio_to_wav(path, name)
                audio_inputs.append((wav_path, name))
                if wav_path != path:
                    temp_files_to_clean.append(wav_path)

        text_inputs_processed = [t.strip() for t in texts if t.strip()] if texts else None

        if not any([image_inputs, audio_inputs, text_inputs_processed]):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="No valid inputs provided for embedding.",
            )

        computed_data = await embedding_manager.compute_embeddings(
            image_inputs=image_inputs if image_inputs else None,
            audio_inputs=audio_inputs if audio_inputs else None,
            text_inputs=text_inputs_processed if text_inputs_processed else None,
        )

        embedding_payload = EmbeddingPayload(
            vision=computed_data.get(ModalityType.VISION.value, []),
            audio=computed_data.get(ModalityType.AUDIO.value, []),
            text=computed_data.get(ModalityType.TEXT.value, []),
        )
        return EmbeddingResponse(embeddings=embedding_payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in /compute_embeddings: {e}", exc_info=True)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An unexpected error occurred: {str(e)}",
        )
    finally:
        for temp_file in temp_files_to_clean:
            try:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)
            except Exception as e_clean:
                logger.warning(f"Could not clean up temporary file {temp_file}: {e_clean}")


def _compute_similarity_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor, normalize: bool) -> torch.Tensor:
    # With L2-normalized rows, the dot product equals cosine similarity.
    if normalize:
        tensor1 = torch.nn.functional.normalize(tensor1, p=2, dim=1)
        tensor2 = torch.nn.functional.normalize(tensor2, p=2, dim=1)
    return torch.matmul(tensor1, tensor2.T)
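
# Quick sanity check for _compute_similarity_matrix (shapes and the 1024-dim
# embedding size are illustrative, not required by the helper):
#
#   a = torch.randn(3, 1024)   # 3 embeddings
#   b = torch.randn(5, 1024)   # 5 embeddings
#   sim = _compute_similarity_matrix(a, b, normalize=True)
#   assert sim.shape == (3, 5)           # one score per (row of a, row of b) pair
#   assert sim.max() <= 1.0 + 1e-6       # cosine similarity is bounded by 1
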
@app.post("/compute_similarities", response_model=SimilarityResponse, dependencies=[Depends(verify_token)])
async def compute_similarities_endpoint(request: SimilarityRequest):
    all_matches: List[SimilarityMatch] = []
    all_scores: List[float] = []
    modality_pairs_compared_set = set()

    embeddings_by_modality: Dict[ModalityType, List[EmbeddingItem]] = {}
    if request.embeddings_payload.vision:
        embeddings_by_modality[ModalityType.VISION] = request.embeddings_payload.vision
    if request.embeddings_payload.audio:
        embeddings_by_modality[ModalityType.AUDIO] = request.embeddings_payload.audio
    if request.embeddings_payload.text:
        embeddings_by_modality[ModalityType.TEXT] = request.embeddings_payload.text

    modalities_present = list(embeddings_by_modality.keys())
    current_device = embedding_manager.device if embedding_manager else "cpu"

    for i, mod1_type in enumerate(modalities_present):
        items1 = embeddings_by_modality[mod1_type]
        if not items1:
            continue
        tensor1 = torch.tensor([item.embedding for item in items1], device=current_device)

        if request.compare_within_modalities:
            sim_matrix_intra = _compute_similarity_matrix(tensor1, tensor1, request.normalize_scores)
            modality_pairs_compared_set.add(f"{mod1_type.value}_vs_{mod1_type.value}")
            # Upper triangle only: each unordered pair once, no self-matches.
            for r_idx in range(len(items1)):
                for c_idx in range(r_idx + 1, len(items1)):
                    score = float(sim_matrix_intra[r_idx, c_idx].item())
                    if score >= request.threshold:
                        all_matches.append(SimilarityMatch(
                            item_a_id=items1[r_idx].id,
                            item_b_id=items1[c_idx].id,
                            modality_a=mod1_type,
                            modality_b=mod1_type,
                            score=score,
                        ))
                        all_scores.append(score)

        if request.compare_across_modalities:
            for j in range(i + 1, len(modalities_present)):
                mod2_type = modalities_present[j]
                items2 = embeddings_by_modality[mod2_type]
                if not items2:
                    continue
                tensor2 = torch.tensor([item.embedding for item in items2], device=current_device)
                sim_matrix_inter = _compute_similarity_matrix(tensor1, tensor2, request.normalize_scores)
                modality_pairs_compared_set.add(f"{mod1_type.value}_vs_{mod2_type.value}")
                for r_idx in range(len(items1)):
                    for c_idx in range(len(items2)):
                        score = float(sim_matrix_inter[r_idx, c_idx].item())
                        if score >= request.threshold:
                            all_matches.append(SimilarityMatch(
                                item_a_id=items1[r_idx].id,
                                item_b_id=items2[c_idx].id,
                                modality_a=mod1_type,
                                modality_b=mod2_type,
                                score=score,
                            ))
                            all_scores.append(score)

    all_matches.sort(key=lambda x: x.score, reverse=True)
    if request.top_k and len(all_matches) > request.top_k:
        all_matches = all_matches[:request.top_k]
        all_scores = [match.score for match in all_matches]

    stats = {
        "total_matches_found_above_threshold": len(all_matches),
        "avg_score": float(np.mean(all_scores)) if all_scores else 0.0,
        "max_score": float(np.max(all_scores)) if all_scores else 0.0,
        "min_score": float(np.min(all_scores)) if all_scores else 0.0,
    }

    return SimilarityResponse(
        matches=all_matches,
        statistics=stats,
        modality_pairs_compared=sorted(modality_pairs_compared_set),
    )


@app.get("/health", status_code=status.HTTP_200_OK, dependencies=[Depends(verify_token)])
async def health_check():
    return {
        "status": "healthy",
        "model_device": settings.model_device,
        "torch_version": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level=settings.log_level.lower())
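
# Example /compute_similarities request body (embeddings truncated for brevity;
# in practice they come straight from a /compute_embeddings response):
#
#   {
#     "embeddings_payload": {
#       "text":   [{"id": "a cat on a sofa", "embedding": [0.01, ...]}],
#       "vision": [{"id": "cat.jpg", "embedding": [0.02, ...]}]
#     },
#     "threshold": 0.2,
#     "top_k": 10
#   }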