File size: 3,679 Bytes
3301b3c 04db7e0 6c61722 c613bb1 3301b3c 6c61722 33f4e34 a1d050d 3301b3c 33f4e34 3301b3c 04db7e0 33f4e34 3301b3c 6c61722 3301b3c a1d050d 3301b3c 04db7e0 6c61722 04db7e0 6c61722 04db7e0 6c61722 04db7e0 6c61722 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
"""
Utilities module: LLM client wrapper and shared helpers.
"""
import os
import openai
from typing import List
from openai import AzureOpenAI
from langchain_openai import AzureOpenAIEmbeddings
from sentence_transformers import SentenceTransformer
from src import logger
class LLMClient:
    """
    Thin wrapper around the Azure OpenAI chat-completions API.

    Connection settings are read from the environment variables
    ``AZURE_API_KEY``, ``AZURE_ENDPOINT`` and ``AZURE_API_VERSION``. The model
    name comes from the ``model`` argument, then ``OPENAI_MODEL``, then
    ``'gpt-4o'``.
    """

    @staticmethod
    def generate(prompt: str, model: str = None, max_tokens: int = 512, **kwargs) -> tuple[str, str]:
        """
        Send ``prompt`` as a single user message and return the reply.

        Args:
            prompt: Text sent as the user message.
            model: Model name to use; when falsy, ``OPENAI_MODEL`` or
                ``'gpt-4o'`` is used instead.
            max_tokens: Upper bound on generated tokens.
            **kwargs: Forwarded to ``client.chat.completions.create``. May
                include ``temperature`` to override the default of ``0.0``.

        Returns:
            A ``(text, finish_reason)`` tuple. ``text`` is the stripped reply,
            or an empty string when the response carries no content.

        Raises:
            EnvironmentError: If any of the required ``AZURE_*`` variables is
                unset or empty.
            Exception: Any error raised by the API call is logged and
                re-raised unchanged.
        """
        settings = {
            'AZURE_API_KEY': os.getenv('AZURE_API_KEY'),
            'AZURE_ENDPOINT': os.getenv('AZURE_ENDPOINT'),
            'AZURE_API_VERSION': os.getenv('AZURE_API_VERSION'),
        }
        # The previous check only fired when *every* value was missing, and
        # the model name always has a default, so it could never trigger.
        missing = [name for name, value in settings.items() if not value]
        if missing:
            message = f"Missing required environment variable(s): {', '.join(missing)}"
            logger.error(message)
            raise EnvironmentError(message)

        # An empty OPENAI_MODEL should not override the built-in default.
        openai_model_name = model or os.getenv('OPENAI_MODEL') or 'gpt-4o'

        client = AzureOpenAI(
            api_key=settings['AZURE_API_KEY'],
            azure_endpoint=settings['AZURE_ENDPOINT'],
            api_version=settings['AZURE_API_VERSION'],
        )

        # Merge instead of passing temperature separately so that a caller
        # supplying `temperature` overrides the default rather than causing a
        # duplicate-keyword TypeError.
        request_options = {"temperature": 0.0, **kwargs}

        try:
            resp = client.chat.completions.create(
                model=openai_model_name,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=max_tokens,
                **request_options,
            )
            choice = resp.choices[0]
            # `content` can be None (e.g. filtered or tool-call responses);
            # fall back to an empty string so the return type stays `str`.
            text = (choice.message.content or "").strip()
            return text, choice.finish_reason
        except Exception as e:
            logger.error(f'LLM generation failed: {e}')
            raise
class LocalEmbedder:
    """
    Embeds text with a locally loaded SentenceTransformer model.
    """

    def __init__(self, model_name: str):
        """Load the SentenceTransformer identified by ``model_name``."""
        self.model = SentenceTransformer(model_name)
        logger.info(f"Initialized local embedder with model: {model_name}")

    def embed(self, texts: List[str]) -> List[List[float]]:
        """
        Encode ``texts`` with the local model and return the result of
        calling ``.tolist()`` on the encoder output.

        Any exception is logged and re-raised unchanged.
        """
        try:
            # Progress bar is disabled on every call.
            return self.model.encode(texts, show_progress_bar=False).tolist()
        except Exception as exc:
            logger.error(f"Local embedding failed: {exc}")
            raise
class OpenAIEmbedder:
    """
    Wrapper around OpenAI and Azure OpenAI embeddings (via ``langchain_openai``).

    Azure is used when both ``AZURE_API_KEY`` and ``AZURE_ENDPOINT`` are set to
    non-empty values; otherwise the standard OpenAI embeddings client is used.
    """

    def __init__(self, model_name: str):
        """
        Build the underlying embeddings client.

        Args:
            model_name: Embedding model name passed to the client as ``model``.
        """
        self.model_name = model_name
        azure_api_key = os.getenv('AZURE_API_KEY')
        azure_endpoint = os.getenv('AZURE_ENDPOINT')
        # Store a real boolean rather than whichever env string `and` returns.
        self.is_azure = bool(azure_api_key and azure_endpoint)

        if self.is_azure:
            logger.info("Using Azure OpenAI for embeddings.")
            # Hand the credentials over explicitly: the branch is selected on
            # AZURE_API_KEY / AZURE_ENDPOINT, so those are the values the
            # client must actually receive.
            self.embedder = AzureOpenAIEmbeddings(
                model=self.model_name,
                azure_deployment=os.getenv("AZURE_EMBEDDING_DEPLOYMENT"),  # Assumes a deployment name is set
                api_key=azure_api_key,
                azure_endpoint=azure_endpoint,
                api_version=os.getenv("AZURE_API_VERSION"),
            )
        else:
            logger.info("Using standard OpenAI for embeddings.")
            # This part would need OPENAI_API_KEY to be set
            from langchain_openai import OpenAIEmbeddings
            self.embedder = OpenAIEmbeddings(model=self.model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """
        Embed a list of texts via the client's ``embed_documents``.

        Any exception is logged and re-raised unchanged.
        """
        try:
            return self.embedder.embed_documents(texts)
        except Exception as e:
            logger.error(f"Embedding failed: {e}")
            raise