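"""Model-facing wrappers: a CrossEncoder re-ranker, Gemini text generation
(blocking and streaming), batched Gemini embeddings, and a prompt wrapper."""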
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor

import aiofiles
from dotenv import load_dotenv
from google import genai
from google.genai import types
from sentence_transformers import CrossEncoder  # cross-encoder used as a re-ranker

from app.core.chunks import Chunk
from app.settings import settings, BASE_DIR, GeminiEmbeddingSettings, logger

load_dotenv()
class Reranker:
    def __init__(self, model: str = "cross-encoder/ms-marco-MiniLM-L6-v2"):
        self.device: str = settings.device
        self.model_name: str = model
        self.model: CrossEncoder = CrossEncoder(model, device=self.device)

    async def rank(self, query: str, chunks: list[Chunk]) -> list[dict[str, int | float]]:
        # Fetch chunk texts concurrently instead of awaiting them one by one.
        texts = await asyncio.gather(*(chunk.get_raw_text() for chunk in chunks))
        # CrossEncoder.rank returns [{"corpus_id": int, "score": float}, ...],
        # sorted by score; scoring is CPU/GPU-bound, so run it off the event loop.
        return await asyncio.to_thread(self.model.rank, query, texts)
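
# Usage sketch (illustrative; assumes an existing list of Chunk objects):
#     reranker = Reranker()
#     ranked = await reranker.rank(query, chunks)
#     best_chunk = chunks[ranked[0]["corpus_id"]]  # rank() sorts by score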
class GeminiLLM:
    def __init__(self, model: str = "gemini-2.0-flash"):
        self.client = genai.Client(api_key=settings.api_key)
        self.model = model

    async def get_response(self, prompt: str, stream: bool = True, use_default_config: bool = False) -> str:
        # `stream` is currently unused; streaming lives in get_streaming_response.
        # Persist the prompt for inspection before calling the model.
        path_to_prompt = os.path.join(BASE_DIR, "models_io", "prompt.txt")
        async with aiofiles.open(path_to_prompt, "w", encoding="utf-8", errors="replace") as f:
            await f.write(prompt)
        # Use the SDK's async client so the event loop is not blocked.
        response = await self.client.aio.models.generate_content(
            model=self.model,
            contents=prompt,
            config=(
                types.GenerateContentConfig(**settings.gemini_generation.model_dump())
                if use_default_config
                else None
            ),
        )
        return response.text
    async def get_streaming_response(self, prompt: str, use_default_config: bool = False):
        loop = asyncio.get_running_loop()
        start = loop.time()
        path_to_prompt = os.path.join(BASE_DIR, "models_io", "prompt.txt")
        async with aiofiles.open(path_to_prompt, "w", encoding="utf-8", errors="replace") as f:
            await f.write(prompt)
        await logger.info(f"Prompt saved to file in {loop.time() - start:.4f}s")
        # Stream through the async client; chunks arrive as they are generated.
        response = await self.client.aio.models.generate_content_stream(
            model=self.model,
            contents=prompt,
            config=(
                types.GenerateContentConfig(**settings.gemini_generation.model_dump())
                if use_default_config
                else None
            ),
        )
        async for chunk in response:
            yield chunk
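
# Usage sketch (illustrative):
#     llm = GeminiLLM()
#     async for chunk in llm.get_streaming_response(prompt):
#         print(chunk.text, end="", flush=True)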
class GeminiEmbed:
    def __init__(self, model: str = "text-embedding-004"):
        self.client = genai.Client(api_key=settings.api_key)
        self.model = model
        self.settings = GeminiEmbeddingSettings()
        # Bounded pool: at most 5 embedding requests in flight at once.
        self.max_workers = 5
        self.embed_executor = ThreadPoolExecutor(max_workers=self.max_workers)

    def _embed_batch_sync(self, batch: list[str], idx: int) -> dict:
        # Blocking call, executed inside the thread pool; `idx` preserves the
        # batch position so results can be re-assembled in input order.
        response = self.client.models.embed_content(
            model=self.model,
            contents=batch,
            config=types.EmbedContentConfig(**settings.gemini_embedding.model_dump()),
        ).embeddings
        return {"idx": idx, "embeddings": response}
    async def encode(self, text: str | list[str]) -> list[list[float]]:
        if isinstance(text, str):
            text = [text]
        max_batch_size = 100  # hard limit: the Gemini API caps a batch at 100 texts
        batches: list[list[str]] = [
            text[i : i + max_batch_size]
            for i in range(0, len(text), max_batch_size)
        ]
        # Fan the batches out across the thread pool and embed them concurrently.
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(self.embed_executor, self._embed_batch_sync, batch, idx)
            for idx, batch in enumerate(batches)
        ]
        groups = await asyncio.gather(*tasks)
        # Restore the original input order before flattening.
        groups.sort(key=lambda x: x["idx"])
        result: list[list[float]] = []
        for group in groups:
            for vec in group["embeddings"]:
                result.append(vec.values)
        return result
    async def get_vector_dimensionality(self) -> int | None:
        return getattr(self.settings, "output_dimensionality", None)
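
# Usage sketch (illustrative):
#     embedder = GeminiEmbed()
#     vectors = await embedder.encode(["first passage", "second passage"])
#     assert len(vectors) == 2 and len(vectors[0]) > 0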
class Wrapper:
    def __init__(self, model: str = "gemini-2.0-flash"):
        self.model = model
        self.client = genai.Client(api_key=settings.api_key)

    async def wrap(self, prompt: str) -> str:
        # Single-shot completion with the wrapper-specific generation config;
        # the blocking SDK call is pushed to a worker thread.
        def wrapper(prompt: str) -> str:
            response = self.client.models.generate_content(
                model=self.model,
                contents=prompt,
                config=types.GenerateContentConfig(**settings.gemini_wrapper.model_dump()),
            )
            return response.text

        return await asyncio.to_thread(wrapper, prompt)
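
if __name__ == "__main__":
    # Minimal smoke test (a sketch: assumes a valid API key is available via
    # .env and that app.settings imports cleanly in this environment).
    async def _demo() -> None:
        embedder = GeminiEmbed()
        vectors = await embedder.encode("hello world")
        print(f"got {len(vectors)} vector(s) of dimension {len(vectors[0])}")

    asyncio.run(_demo())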