import asyncio
import os
from concurrent.futures import ThreadPoolExecutor

import aiofiles
import aiofiles.os
from dotenv import load_dotenv
from sentence_transformers import CrossEncoder  # cross-encoder re-ranker
from google import genai
from google.genai import types

from app.core.chunks import Chunk
from app.settings import settings, BASE_DIR, GeminiEmbeddingSettings, logger

load_dotenv()

# Every generation call dumps its prompt here for later inspection/debugging.
PROMPT_DUMP_PATH = os.path.join(BASE_DIR, "models_io", "prompt.txt")


class Reranker:
    def __init__(self, model: str = "cross-encoder/ms-marco-MiniLM-L6-v2"):
        self.device: str = settings.device
        self.model_name: str = model
        self.model: CrossEncoder = CrossEncoder(model, device=self.device)

    async def rank(self, query: str, chunks: list[Chunk]) -> list[dict[str, int | float]]:
        # Fetch chunk texts concurrently, then run the blocking cross-encoder
        # scoring in a worker thread. CrossEncoder.rank returns
        # [{"corpus_id": int, "score": float}, ...] sorted by relevance.
        texts = await asyncio.gather(*(chunk.get_raw_text() for chunk in chunks))
        return await asyncio.to_thread(self.model.rank, query, list(texts))


class GeminiLLM:
    def __init__(self, model: str = "gemini-2.0-flash"):
        self.client = genai.Client(api_key=settings.api_key)
        self.model = model

    async def _save_prompt(self, prompt: str) -> None:
        # Persist the last prompt to disk so it can be inspected after the run.
        await aiofiles.os.makedirs(os.path.dirname(PROMPT_DUMP_PATH), exist_ok=True)
        async with aiofiles.open(
            PROMPT_DUMP_PATH, "w", encoding="utf-8", errors="replace"
        ) as f:
            await f.write(prompt)

    async def get_response(self, prompt: str, use_default_config: bool = False) -> str:
        await self._save_prompt(prompt)
        # Run the blocking SDK call in a worker thread so the event loop stays free.
        response = await asyncio.to_thread(
            self.client.models.generate_content,
            model=self.model,
            contents=prompt,
            config=(
                types.GenerateContentConfig(**settings.gemini_generation.model_dump())
                if use_default_config
                else None
            ),
        )
        return response.text

    async def get_streaming_response(self, prompt: str, use_default_config: bool = False):
        loop = asyncio.get_running_loop()
        start = loop.time()
        await self._save_prompt(prompt)
        await logger.info(f"Saved prompt to disk in {loop.time() - start:.4f}s")
        # client.aio is the SDK's async surface; it yields chunks without
        # blocking the event loop between them.
        response = await self.client.aio.models.generate_content_stream(
            model=self.model,
            contents=prompt,
            config=(
                types.GenerateContentConfig(**settings.gemini_generation.model_dump())
                if use_default_config
                else None
            ),
        )
        async for chunk in response:
            yield chunk


class GeminiEmbed:
    def __init__(self, model: str = "text-embedding-004"):
        self.client = genai.Client(api_key=settings.api_key)
        self.model = model
        self.settings = GeminiEmbeddingSettings()
        self.max_workers = 5
        self.embed_executor = ThreadPoolExecutor(max_workers=self.max_workers)

    def _embed_batch_sync(self, batch: list[str], idx: int) -> dict:
        # Blocking call; runs inside the executor. idx tags the batch so the
        # caller can restore the original input order.
        embeddings = self.client.models.embed_content(
            model=self.model,
            contents=batch,
            config=types.EmbedContentConfig(**settings.gemini_embedding.model_dump()),
        ).embeddings
        return {"idx": idx, "embeddings": embeddings}

    async def encode(self, text: str | list[str]) -> list[list[float]]:
        if isinstance(text, str):
            text = [text]
        max_batch_size = 100  # hard API limit: embed_content accepts at most 100 inputs per request
        batches: list[list[str]] = [
            text[i : i + max_batch_size] for i in range(0, len(text), max_batch_size)
        ]
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(self.embed_executor, self._embed_batch_sync, batch, idx)
            for idx, batch in enumerate(batches)
        ]
        # gather() already preserves submission order; the explicit idx sort
        # keeps that guarantee obvious and robust.
        groups: list[dict] = list(await asyncio.gather(*tasks))
        groups.sort(key=lambda x: x["idx"])
        result: list[list[float]] = []
        for group in groups:
            for vec in group["embeddings"]:
                result.append(vec.values)
        return result

    async def get_vector_dimensionality(self) -> int | None:
        return self.settings.output_dimensionality
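# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how GeminiEmbed and Reranker compose in a retrieval
# pass, assuming `query` and `chunks` come from the application's ingestion
# layer; `_retrieval_example` and both parameters are hypothetical names, not
# part of this module's public API.
async def _retrieval_example(query: str, chunks: list[Chunk]) -> list[dict[str, int | float]]:
    embedder = GeminiEmbed()
    reranker = Reranker()
    # encode() accepts a str or list[str] and returns one vector per input,
    # preserving input order across the 100-item API batches.
    texts = [await chunk.get_raw_text() for chunk in chunks]
    _corpus_vectors = await embedder.encode(texts)
    # rank() scores every chunk against the query with the cross-encoder and
    # returns [{"corpus_id": ..., "score": ...}, ...] sorted by relevance.
    return await reranker.rank(query, chunks)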
class Wrapper:
    def __init__(self, model: str = "gemini-2.0-flash"):
        self.model = model
        self.client = genai.Client(api_key=settings.api_key)

    async def wrap(self, prompt: str) -> str:
        def _generate(p: str) -> str:
            response = self.client.models.generate_content(
                model=self.model,
                contents=p,
                config=types.GenerateContentConfig(**settings.gemini_wrapper.model_dump()),
            )
            return response.text

        # Run the blocking SDK call in a worker thread.
        return await asyncio.to_thread(_generate, prompt)
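# --- End-to-end sketch (hedged) ----------------------------------------------
# A minimal, guarded demo of the generation path, assuming valid Gemini
# credentials in `settings` and network access; the sample question is
# illustrative. The __main__ guard keeps importing this module side-effect free.
async def _generation_demo() -> None:
    wrapper = Wrapper()
    llm = GeminiLLM()
    # Wrapper first reformulates the raw user prompt, then GeminiLLM streams
    # an answer chunk by chunk.
    rewritten = await wrapper.wrap("Summarize what a cross-encoder re-ranker does.")
    async for chunk in llm.get_streaming_response(rewritten):
        print(chunk.text, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(_generation_demo())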