import os
from typing import Optional
from google import genai
from google.genai import types
from .base import AbstractLLMModel
from .registry import register_llm_model

# Read the API key once at import time; __init__ fails fast if it is missing.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")


@register_llm_model("gemini-2.5-flash")
class GeminiLLM(AbstractLLMModel):
    """LLM backend that proxies generation to the Gemini API via google-genai."""

    def __init__(
        self, model_id: str, device: str = "auto", cache_dir: str = "cache", **kwargs
    ):
        # device and cache_dir are part of the shared constructor signature but
        # are unused here: generation happens remotely, nothing is loaded locally.
        if not GOOGLE_API_KEY:
            raise ValueError(
                "Please set the GOOGLE_API_KEY environment variable to use Gemini."
            )
        super().__init__(model_id=model_id, **kwargs)
        self.client = genai.Client(api_key=GOOGLE_API_KEY)

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_output_tokens: int = 1024,
        max_iterations: int = 3,
        **kwargs,
    ) -> str:
        # Extra kwargs (e.g. temperature, top_p) are forwarded verbatim into
        # GenerateContentConfig.
        generation_config_dict = {
            "max_output_tokens": max_output_tokens,
            **kwargs,
        }
        if system_prompt:
            generation_config_dict["system_instruction"] = system_prompt
        # Retry a few times: Gemini occasionally returns an empty response,
        # e.g. when the token budget is exhausted before any text is emitted.
        for _ in range(max_iterations):
            response = self.client.models.generate_content(
                model=self.model_id,
                contents=prompt,
                config=types.GenerateContentConfig(**generation_config_dict),
            )
            if response.text:
                return response.text
            print(
                "No response from Gemini. May need to increase "
                f"max_output_tokens. Current {max_output_tokens=}. Trying again."
            )
        print(f"Failed to generate a response from Gemini after {max_iterations} attempts.")
        return ""