File size: 1,726 Bytes
6843cf8
780954b
6843cf8
780954b
 
6843cf8
4b87b66
 
 
 
6843cf8
4b87b66
 
 
6843cf8
4b87b66
780954b
4b87b66
6843cf8
780954b
 
 
4b87b66
780954b
 
 
 
 
 
 
7a41e86
780954b
 
 
 
 
 
 
 
7a41e86
 
 
 
 
780954b
7a41e86
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from typing import Optional

from google import genai
from google.genai import types

from .base import AbstractLLMModel
from .registry import register_llm_model


GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")


@register_llm_model("gemini-2.5-flash")
class GeminiLLM(AbstractLLMModel):
    """LLM backend that proxies text generation to the Google Gemini API
    via the ``google-genai`` client."""

    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        """Create a Gemini-backed model.

        Args:
            model_id: Gemini model name forwarded to ``generate_content``.
            device: Unused here; presumably kept so the signature matches
                local-model backends registered alongside this one — confirm.
            cache_dir: Unused here; same interface-compatibility note as above.
            **kwargs: Forwarded to ``AbstractLLMModel.__init__``.

        Raises:
            ValueError: If no Google API key is available in the environment.
        """
        # Re-read the env var at construction time: the module-level constant
        # is captured at import, so a key exported after import would otherwise
        # be missed even though the presence check runs here.
        api_key = GOOGLE_API_KEY or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError(
                "Please set the GOOGLE_API_KEY environment variable to use Gemini."
            )
        super().__init__(model_id=model_id, **kwargs)
        self.client = genai.Client(api_key=api_key)

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_output_tokens: int = 1024,
        max_iterations: int = 3,
        **kwargs,
    ) -> str:
        """Generate text for ``prompt``, retrying on empty responses.

        Args:
            prompt: User prompt sent as the request contents.
            system_prompt: Optional system instruction for the model.
            max_output_tokens: Token budget passed in the generation config.
            max_iterations: Number of attempts before giving up.
            **kwargs: Extra fields merged into ``GenerateContentConfig``.

        Returns:
            The model's response text, or ``""`` if every attempt returned
            an empty response.
        """
        generation_config_dict = {
            "max_output_tokens": max_output_tokens,
            **kwargs,
        }
        if system_prompt:
            generation_config_dict["system_instruction"] = system_prompt
        for _ in range(max_iterations):
            response = self.client.models.generate_content(
                model=self.model_id,
                contents=prompt,
                config=types.GenerateContentConfig(**generation_config_dict),
            )
            if response.text:
                return response.text
            # Empty text usually means the budget was exhausted (e.g. spent on
            # reasoning tokens); warn and retry with the same config.
            print(
                f"No response from Gemini. May need to increase max_output_tokens. Current {max_output_tokens=}. Try again."
            )
        print(f"Failed to generate response from Gemini after {max_iterations} attempts.")
        return ""