# attributionapi/app/services/service_gemini.py
import json
import os
import io
import base64
from typing import Any, Dict, List, Type, Union, Optional
import google.generativeai as genai
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold # For safety settings
import weave # Assuming weave is still used
from pydantic import BaseModel, ValidationError # For schema validation
# Assuming these utilities are in the same relative paths or accessible
from app.utils.converter import product_data_to_str
from app.utils.image_processing import (
get_data_format, # Assuming this returns 'jpeg', 'png' etc.
get_image_base64_and_type, # Assuming this fetches URL and returns (base64_str, type_str)
get_image_data, # Assuming this reads local path and returns base64_str
)
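
# Assumed helper signatures (hypothetical sketch, inferred from usage below;
# the real utilities may differ):
#     get_data_format(path: str) -> str                        # e.g. "png", "jpeg"
#     get_image_base64_and_type(url: str) -> tuple[str, str]   # (base64_str, type_str)
#     get_image_data(path: str) -> str                         # base64-encoded file contents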
from app.utils.logger import exception_to_str, setup_logger
# Assuming these are correctly defined and accessible
from ..config import get_settings
from ..core import errors
from ..core.errors import BadRequestError, VendorError # Using your custom errors
from ..core.prompts import get_prompts # Assuming prompts are compatible or adapted
from .base import BaseAttributionService # Assuming this base class exists
# Environment and Weave setup ( 그대로 유지 )
# Environment and Weave setup (kept as-is)
ENV = os.getenv("ENV", "LOCAL")
if ENV == "LOCAL":
    weave_project_name = "cfai/attribution-exp"
elif ENV == "DEV":
    weave_project_name = "cfai/attribution-dev"
elif ENV == "UAT":
    weave_project_name = "cfai/attribution-uat"
else:
    weave_project_name = None  # No Weave project for PROD (or unrecognized environments)
if ENV != "PROD" and weave_project_name:
    # weave.init(project_name=weave_project_name)  # Assuming weave.init() is called elsewhere if needed
    print(f"Weave project name (potentially initialized elsewhere): {weave_project_name}")
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
# Configure the Gemini client
api_key = getattr(settings, "GEMINI_API_KEY", None)
if api_key:
    genai.configure(api_key=api_key)
else:
    # Covers both a missing attribute and an empty value; API calls will fail until a key is provided.
    logger.error("GEMINI_API_KEY not found in settings.")
# Define default safety settings for Gemini
# Adjust these as per your application's requirements
DEFAULT_SAFETY_SETTINGS = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}
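
# Illustrative only: a caller that needs laxer filtering for a specific request
# could build an override dict like the hypothetical one below and pass it as
# `safety_settings=` to GenerativeModel or generate_content_async. This sketch
# is not used anywhere in this module.
# RELAXED_SAFETY_SETTINGS = {
#     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
#     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
# }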
class GeminiService(BaseAttributionService):
def __init__(self, model_name: str = "gemini-2.5-flash-preview-04-17"):
"""
Initializes the GeminiService.
Args:
model_name (str): The name of the Gemini model to use.
"""
try:
self.model = genai.GenerativeModel(
model_name,
safety_settings=DEFAULT_SAFETY_SETTINGS
# system_instruction can be set here if a global system message is always used
)
logger.info(f"GeminiService initialized with model: {model_name}")
except Exception as e:
logger.error(f"Failed to initialize Gemini GenerativeModel: {exception_to_str(e)}")
# Depending on requirements, you might want to raise an error here
# For now, we'll let it proceed, and calls will fail if model isn't initialized.
self.model = None
def _prepare_image_parts(
self,
img_urls: Optional[List[str]] = None,
img_paths: Optional[List[str]] = None,
pil_images: Optional[List[Any]] = None, # PIL.Image.Image objects
) -> List[Dict[str, Any]]:
"""
Prepares image data in the format expected by Gemini API.
Decodes base64 image data to bytes.
Converts PIL images to bytes.
"""
image_parts = []
# Process image URLs
if img_urls:
for img_url in img_urls:
try:
base64_data, img_type = get_image_base64_and_type(img_url)
if base64_data and img_type:
# Gemini expects raw bytes, so decode base64
image_bytes = base64.b64decode(base64_data)
mime_type = f"image/{img_type.lower()}"
image_parts.append({"mime_type": mime_type, "data": image_bytes})
else:
logger.warning(f"Could not retrieve or identify type for image URL: {img_url}")
except Exception as e:
logger.error(f"Error processing image URL {img_url}: {exception_to_str(e)}")
# Process image paths
if img_paths:
for img_path in img_paths:
try:
base64_data = get_image_data(img_path) # Assuming this returns base64 string
img_type = get_data_format(img_path) # Assuming this returns 'png', 'jpeg'
if base64_data and img_type:
image_bytes = base64.b64decode(base64_data)
mime_type = f"image/{img_type.lower()}"
image_parts.append({"mime_type": mime_type, "data": image_bytes})
else:
logger.warning(f"Could not retrieve or identify type for image path: {img_path}")
except Exception as e:
logger.error(f"Error processing image path {img_path}: {exception_to_str(e)}")
# Process PIL images
if pil_images:
for i, pil_image in enumerate(pil_images):
try:
img_format = pil_image.format or 'PNG' # Default to PNG if format is not available
mime_type = f"image/{img_format.lower()}"
with io.BytesIO() as img_byte_arr:
pil_image.save(img_byte_arr, format=img_format)
image_bytes = img_byte_arr.getvalue()
image_parts.append({"mime_type": mime_type, "data": image_bytes})
except Exception as e:
logger.error(f"Error processing PIL image #{i}: {exception_to_str(e)}")
return image_parts
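
    # Shape sketch (illustrative): for a single PNG, _prepare_image_parts returns
    # something like:
    #     [{"mime_type": "image/png", "data": b"\x89PNG..."}]
    # which the genai SDK accepts directly as inline image parts in `contents`.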
@weave.op() # Assuming weave.op can be used as a decorator directly
async def extract_attributes(
self,
attributes_model: Type[BaseModel],
ai_model: str, # This will be the Gemini model name, e.g., "gemini-1.5-flash-latest"
img_urls: Optional[List[str]] = None,
product_taxonomy: str = "",
product_data: Optional[Dict[str, Union[str, List[str]]]] = None,
pil_images: Optional[List[Any]] = None,
img_paths: Optional[List[str]] = None,
) -> Dict[str, Any]:
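        """
        Extracts product attributes from text and/or images via Gemini.

        Args:
            attributes_model: Pydantic model describing the expected attributes;
                its JSON schema is passed to Gemini as the response schema.
            ai_model: Gemini model name to use for this call.
            img_urls / img_paths / pil_images: Optional image sources.
            product_taxonomy: Taxonomy string interpolated into the prompt.
            product_data: Raw product fields interpolated into the prompt.

        Returns:
            The validated attributes as a plain dict.
        """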
if not self.model:
raise VendorError("Gemini model not initialized.")
        # genai normalizes model names to the "models/<name>" form, so normalize
        # the requested name before comparing.
        requested_model = ai_model if "/" in ai_model else f"models/{ai_model}"
        if requested_model != self.model.model_name:  # A different model is requested for this call
            logger.info(f"Switching to model {ai_model} for this extraction request.")
            # Note: This creates a new model object for the call.
            # If this happens frequently, consider how model instances are managed.
            current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
        else:
            current_model = self.model
# Construct the prompt text
# Combining system and human prompts as Gemini typically takes a list of contents.
# System instructions can also be part of the model's initialization.
system_message = prompts.EXTRACT_INFO_SYSTEM_MESSAGE
human_message = prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
product_taxonomy=product_taxonomy,
product_data=product_data_to_str(product_data if product_data else {}),
)
full_prompt_text = f"{system_message}\n\n{human_message}"
# For logging or debugging the prompt
logger.info(f"Gemini Prompt Text: {full_prompt_text[:500]}...") # Log a snippet
content_parts = [full_prompt_text]
# Prepare image parts
try:
image_parts = self._prepare_image_parts(img_urls, img_paths, pil_images)
content_parts.extend(image_parts)
except Exception as e:
logger.error(f"Failed during image preparation: {exception_to_str(e)}")
raise BadRequestError(f"Image processing failed: {e}")
if not image_parts and (img_urls or img_paths or pil_images):
logger.warning("Image sources provided, but no image parts were successfully prepared.")
# Define generation config for JSON output
        # Pydantic's model_json_schema() generates a JSON Schema dictionary.
try:
schema_for_gemini = attributes_model.model_json_schema()
except Exception as e:
logger.error(f"Error generating JSON schema from Pydantic model: {exception_to_str(e)}")
raise VendorError(f"Could not generate schema for attributes_model: {e}")
generation_config = GenerationConfig(
response_mime_type="application/json",
response_schema=schema_for_gemini, # Gemini expects the schema here
temperature=0.0, # For deterministic output, similar to low top_p
max_output_tokens=2048, # Adjust as needed, was 1000 for OpenAI
# top_p, top_k can also be set if needed
)
logger.info(f"Extracting attributes via Gemini model: {current_model.model_name}...")
try:
response = await current_model.generate_content_async(
contents=content_parts,
generation_config=generation_config,
# request_options={"timeout": 120} # Example: set timeout in seconds
)
except Exception as e: # Catches google.api_core.exceptions and others
error_message = exception_to_str(e)
logger.error(f"Gemini API call failed: {error_message}")
# More specific error handling for Gemini can be added here
# e.g., if isinstance(e, google.api_core.exceptions.InvalidArgument):
# raise BadRequestError(f"Invalid argument to Gemini: {error_message}")
raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))
# Process the response
try:
# Check for safety blocks or refusals
if not response.candidates:
# This can happen if all candidates were filtered due to safety or other reasons.
block_reason_detail = "Unknown reason (no candidates)"
if response.prompt_feedback and response.prompt_feedback.block_reason:
block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
                    # block_reason_message is not available in every SDK version; read it defensively.
                    block_reason_message = getattr(response.prompt_feedback, "block_reason_message", None)
                    if block_reason_message:
                        block_reason_detail += f" - {block_reason_message}"
logger.error(f"Gemini response was blocked or empty. {block_reason_detail}")
raise VendorError(f"Gemini response blocked or empty. {block_reason_detail}")
# Assuming the first candidate is the one we want
candidate = response.candidates[0]
if candidate.finish_reason not in [1, 2]: # 1=STOP, 2=MAX_TOKENS
finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
logger.warning(f"Gemini generation finished with reason: {finish_reason_str}")
# Potentially raise error if finish reason is SAFETY, RECITATION, etc.
if finish_reason_str == "SAFETY":
safety_ratings_str = ", ".join([f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings])
raise VendorError(f"Gemini content generation stopped due to safety concerns. Ratings: [{safety_ratings_str}]")
if not candidate.content.parts or not candidate.content.parts[0].text:
logger.error("Gemini response content is empty or not in the expected text format.")
raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + " (empty response text)")
response_text = candidate.content.parts[0].text
# Parse and validate the JSON response using the Pydantic model
parsed_data = attributes_model.model_validate_json(response_text)
return parsed_data.model_dump() # Return as dict
except ValidationError as ve:
logger.error(f"Pydantic validation failed for Gemini response: {ve}")
logger.debug(f"Invalid JSON received from Gemini: {response_text[:500]}...") # Log snippet of invalid JSON
raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {ve}")
except json.JSONDecodeError as je:
logger.error(f"JSON decoding failed for Gemini response: {je}")
logger.debug(f"Non-JSON response received: {response_text[:500]}...")
raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {je}")
except VendorError: # Re-raise VendorErrors
raise
except Exception as e:
error_message = exception_to_str(e)
logger.error(f"Error processing Gemini response: {error_message}")
# Log the raw response text if available and an error occurred
raw_response_snippet = response_text[:500] if 'response_text' in locals() else "N/A"
logger.debug(f"Problematic Gemini response snippet: {raw_response_snippet}")
raise VendorError(f"Failed to process Gemini response: {error_message}")
@weave.op()
async def follow_schema(
self,
schema: Dict[str, Any], # This should be an OpenAPI schema dictionary
data: Dict[str, Any],
ai_model: str = "gemini-1.5-flash-latest" # Model for this specific task
) -> Dict[str, Any]:
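        """
        Re-shapes `data` into the structure described by `schema` by asking
        Gemini to emit JSON constrained to that schema.

        Args:
            schema: OpenAPI-style schema dict passed as Gemini's response_schema.
                Example (hypothetical):
                    {"type": "object", "properties": {"color": {"type": "string"}}}
            data: Arbitrary JSON-serializable payload to be transformed.
            ai_model: Gemini model name to use for this call.

        Returns:
            The transformed dict, or {"status": "refused", ...} when Gemini
            blocks or returns empty content (mirroring the OpenAI service).
        """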
if not self.model: # Check if the main model was initialized
logger.warning("Main Gemini model not initialized. Attempting to initialize a temporary one for follow_schema.")
try:
current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
except Exception as e:
raise VendorError(f"Failed to initialize Gemini model for follow_schema: {exception_to_str(e)}")
        elif self.model.model_name != (ai_model if "/" in ai_model else f"models/{ai_model}"):
            # genai stores model names in the "models/<name>" form, hence the normalization above.
            logger.info(f"Switching to model {ai_model} for this follow_schema request.")
            current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
else:
current_model = self.model
logger.info(f"Following schema via Gemini model: {current_model.model_name}...")
# Prepare the prompt
# System message can be part of the model or prepended here.
system_message = prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE
# The human message needs to contain the data to be transformed.
# Ensure `json_info` placeholder is correctly used by your prompt string.
try:
data_as_json_string = json.dumps(data, indent=2)
except TypeError as te:
logger.error(f"Could not serialize 'data' to JSON for prompt: {te}")
raise BadRequestError(f"Input data for schema following is not JSON serializable: {te}")
human_message = prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data_as_json_string)
full_prompt_text = f"{system_message}\n\n{human_message}"
content_parts = [full_prompt_text]
# Define generation config for JSON output using the provided schema
generation_config = GenerationConfig(
response_mime_type="application/json",
response_schema=schema, # The provided schema dictionary
temperature=0.0, # For deterministic output
max_output_tokens=2048, # Adjust as needed
)
try:
response = await current_model.generate_content_async(
contents=content_parts,
generation_config=generation_config,
)
except Exception as e:
error_message = exception_to_str(e)
logger.error(f"Gemini API call failed for follow_schema: {error_message}")
raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))
# Process response
try:
if not response.candidates:
block_reason_detail = "Unknown reason (no candidates)"
if response.prompt_feedback and response.prompt_feedback.block_reason:
block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
logger.error(f"Gemini response was blocked or empty in follow_schema. {block_reason_detail}")
                # The OpenAI version returned {"status": "refused"}; mimic that for blocked responses.
return {"status": "refused", "reason": block_reason_detail}
candidate = response.candidates[0]
if candidate.finish_reason not in [1, 2]: # 1=STOP, 2=MAX_TOKENS
finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
logger.warning(f"Gemini generation (follow_schema) finished with reason: {finish_reason_str}")
if finish_reason_str == "SAFETY":
safety_ratings_str = ", ".join([f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings])
return {"status": "refused", "reason": f"Safety block. Ratings: [{safety_ratings_str}]"}
if not candidate.content.parts or not candidate.content.parts[0].text:
logger.error("Gemini response content (follow_schema) is empty.")
# Mimic OpenAI's refusal structure or raise error
return {"status": "refused", "reason": "Empty content from Gemini"}
response_text = candidate.content.parts[0].text
parsed_data = json.loads(response_text) # The schema is enforced by Gemini
return parsed_data
except json.JSONDecodeError as je:
logger.error(f"JSON decoding failed for Gemini response (follow_schema): {je}")
logger.debug(f"Non-JSON response received: {response_text[:500]}...")
            # The OpenAI version raised ValueError(errors.VENDOR_ERROR_INVALID_JSON);
            # VendorError is used here for consistency with the rest of this service.
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" (follow_schema) Details: {je}")
except Exception as e:
error_message = exception_to_str(e)
logger.error(f"Error processing Gemini response (follow_schema): {error_message}")
raw_response_snippet = response_text[:500] if 'response_text' in locals() else "N/A"
logger.debug(f"Problematic Gemini response snippet (follow_schema): {raw_response_snippet}")
raise VendorError(f"Failed to process Gemini response (follow_schema): {error_message}")