File size: 3,679 Bytes
3301b3c
 
 
 
 
04db7e0
 
 
6c61722
c613bb1
3301b3c
 
 
 
 
 
 
 
6c61722
33f4e34
 
 
 
 
 
a1d050d
3301b3c
33f4e34
 
 
 
 
3301b3c
04db7e0
33f4e34
3301b3c
 
 
 
 
 
 
6c61722
 
3301b3c
a1d050d
3301b3c
04db7e0
 
6c61722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04db7e0
 
6c61722
 
04db7e0
 
6c61722
 
 
 
 
 
 
 
 
 
 
 
 
 
 
04db7e0
 
6c61722
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Utilities module: LLM client wrapper and text-embedding helpers (local and OpenAI/Azure OpenAI).
"""
import os
import openai
from typing import List
from openai import AzureOpenAI
from langchain_openai import AzureOpenAIEmbeddings
from sentence_transformers import SentenceTransformer
from src import logger


class LLMClient:
    """
    Thin wrapper around the Azure OpenAI chat-completions API.

    Connection settings are read from the environment on every call:
    ``AZURE_API_KEY``, ``AZURE_ENDPOINT`` and ``AZURE_API_VERSION``. The
    model name comes from the ``model`` argument, falling back to the
    ``OPENAI_MODEL`` environment variable and finally to ``'gpt-4o'``.
    """

    @staticmethod
    def generate(prompt: str, model: str = None, max_tokens: int = 512, **kwargs) -> tuple[str, str]:
        """
        Send ``prompt`` as a single user message and return the reply.

        Args:
            prompt: Text of the user message.
            model: Model name passed to the API. When falsy, ``OPENAI_MODEL``
                (default ``'gpt-4o'``) is used instead.
            max_tokens: Completion token limit passed to the API.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``client.chat.completions.create``. ``temperature`` defaults
                to ``0.0`` and may be overridden here.

        Returns:
            A ``(text, finish_reason)`` tuple taken from the first choice.
            ``text`` is stripped of surrounding whitespace and is ``''`` when
            the API returned no message content.

        Raises:
            EnvironmentError: If a required setting is unset or empty.
            Exception: Any error raised by the API call is logged and
                re-raised unchanged.
        """
        settings = {
            'AZURE_API_KEY': os.getenv('AZURE_API_KEY'),
            'AZURE_ENDPOINT': os.getenv('AZURE_ENDPOINT'),
            'AZURE_API_VERSION': os.getenv('AZURE_API_VERSION'),
        }
        openai_model_name = model or os.getenv('OPENAI_MODEL', 'gpt-4o')

        # Every setting is required, so report each one that is absent
        # instead of only failing when all of them are missing at once.
        missing = [name for name, value in settings.items() if not value]
        if not openai_model_name:
            # Only reachable when OPENAI_MODEL is explicitly set to ''.
            missing.append('OPENAI_MODEL')
        if missing:
            message = f"Missing required environment variable(s): {', '.join(missing)}"
            logger.error(message)
            raise EnvironmentError(message)

        client = AzureOpenAI(
            api_key=settings['AZURE_API_KEY'],
            azure_endpoint=settings['AZURE_ENDPOINT'],
            api_version=settings['AZURE_API_VERSION'],
        )

        # Merge so that a caller-supplied temperature overrides the default
        # rather than colliding with it as a duplicate keyword argument.
        request_options = {'temperature': 0.0, **kwargs}

        try:
            resp = client.chat.completions.create(
                model=openai_model_name,
                messages=[{"role": "system", "content": "You are a helpful assistant."},
                          {"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                **request_options
            )
            choice = resp.choices[0]
            # The content can be None (e.g. filtered or tool-call replies);
            # return an empty string and let finish_reason explain why.
            text = (choice.message.content or '').strip()
            return text, choice.finish_reason
        except Exception as e:
            logger.error(f'LLM generation failed: {e}')
            raise


class LocalEmbedder:
    """
    Text embedder backed by a locally loaded SentenceTransformer model.
    """

    def __init__(self, model_name: str):
        """Load the SentenceTransformer identified by ``model_name``."""
        logger_message = f"Initialized local embedder with model: {model_name}"
        self.model = SentenceTransformer(model_name)
        logger.info(logger_message)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """
        Encode ``texts`` with the local model and return the result of
        calling ``tolist()`` on the encoder output.

        Any exception raised while encoding or converting is logged and
        re-raised unchanged.
        """
        try:
            # Progress bar is disabled explicitly for every encode call.
            return self.model.encode(texts, show_progress_bar=False).tolist()
        except Exception as exc:
            logger.error(f"Local embedding failed: {exc}")
            raise


class OpenAIEmbedder:
    """
    Wrapper around OpenAI and Azure OpenAI embeddings.

    Azure OpenAI is used when both ``AZURE_API_KEY`` and ``AZURE_ENDPOINT``
    are set to non-empty values in the environment; otherwise the standard
    OpenAI embeddings client is used.

    Attributes set by ``__init__``:
        model_name: The embedding model name given by the caller.
        is_azure: ``True`` when the Azure backend was selected.
        embedder: The underlying langchain embeddings object.
    """

    def __init__(self, model_name: str):
        """Select the backend from the environment and build the embedder."""
        self.model_name = model_name
        azure_api_key = os.getenv('AZURE_API_KEY')
        azure_endpoint = os.getenv('AZURE_ENDPOINT')
        # Store a real boolean rather than the raw endpoint string / None.
        self.is_azure = bool(azure_api_key and azure_endpoint)

        if self.is_azure:
            logger.info("Using Azure OpenAI for embeddings.")
            # Pass the credentials that selected this branch explicitly (as
            # LLMClient does for its client) instead of relying on the
            # library to discover credentials from the environment itself.
            self.embedder = AzureOpenAIEmbeddings(
                model=self.model_name,
                azure_deployment=os.getenv("AZURE_EMBEDDING_DEPLOYMENT"),  # Assumes a deployment name is set
                api_key=azure_api_key,
                azure_endpoint=azure_endpoint,
                api_version=os.getenv("AZURE_API_VERSION"),
            )
        else:
            logger.info("Using standard OpenAI for embeddings.")
            # This part would need OPENAI_API_KEY to be set
            from langchain_openai import OpenAIEmbeddings
            self.embedder = OpenAIEmbeddings(model=self.model_name)

    def embed(self, texts: List[str]) -> List[List[float]]:
        """
        Embed ``texts`` by delegating to ``self.embedder.embed_documents``.

        Any exception raised by the embedder is logged and re-raised
        unchanged.
        """
        try:
            return self.embedder.embed_documents(texts)
        except Exception as e:
            logger.error(f"Embedding failed: {e}")
            raise