"""Gemini-backed attribution service.

Exposes :class:`GeminiService`, which requests JSON-mode structured output from
Google Gemini to (a) extract product attributes into a Pydantic model and
(b) reshape arbitrary data so it follows a caller-supplied schema.
"""

import base64
import io
import json
import os
from typing import Any, Dict, List, Optional, Type, Union

import google.generativeai as genai
from google.generativeai.types import GenerationConfig, HarmBlockThreshold, HarmCategory
import weave
from pydantic import BaseModel, ValidationError

from app.utils.converter import product_data_to_str
from app.utils.image_processing import (
    get_data_format,  # presumably returns a format/extension such as "png" or "jpg"
    get_image_base64_and_type,  # presumably fetches a URL -> (base64_str, type_str)
    get_image_data,  # presumably reads a local path -> base64_str
)
from app.utils.logger import exception_to_str, setup_logger

from ..config import get_settings
from ..core import errors
from ..core.errors import BadRequestError, VendorError
from ..core.prompts import get_prompts
from .base import BaseAttributionService

# --- Environment and Weave setup (kept as-is) -------------------------------
ENV = os.getenv("ENV", "LOCAL")

# Weave project per environment; PROD intentionally has no Weave project.
_WEAVE_PROJECTS: Dict[str, str] = {
    "LOCAL": "cfai/attribution-exp",
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
}
# A lookup (rather than an if/elif chain) guarantees the name is always bound:
# an unrecognized ENV yields None instead of a NameError at import time.
weave_project_name: Optional[str] = _WEAVE_PROJECTS.get(ENV)

if ENV != "PROD":
    if weave_project_name is None:
        print(f"No Weave project configured for ENV={ENV!r}.")
    else:
        # weave.init(project_name=...) is not called here; it is assumed to be
        # done elsewhere if tracing is wanted.
        print(f"Weave project name (potentially initialized elsewhere): {weave_project_name}")

settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)

# --- Gemini client configuration --------------------------------------------
# Only the settings lookup is guarded, so an AttributeError raised inside
# genai.configure itself is not misreported as a missing setting.
try:
    _gemini_api_key = settings.GEMINI_API_KEY
except AttributeError:
    _gemini_api_key = None
    logger.error("Settings object does not have GEMINI_API_KEY attribute.")
else:
    if _gemini_api_key:
        genai.configure(api_key=_gemini_api_key)
    else:
        # Deliberately non-fatal; decide per application requirements whether
        # a missing key should raise instead.
        logger.error("GEMINI_API_KEY not found in settings.")

# Default safety settings for every Gemini model built by this module.
# Adjust these as per your application's requirements.
DEFAULT_SAFETY_SETTINGS = {
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}

# Candidate finish reasons treated as a normal completion (1 = STOP, 2 = MAX_TOKENS).
_ACCEPTED_FINISH_REASONS = (1, 2)
_FINISH_REASON_MAX_TOKENS = 2


def _normalize_model_name(model_name: str) -> str:
    """Return *model_name* in the form ``genai.GenerativeModel.model_name`` reports.

    The SDK prefixes bare names with ``"models/"``, so a caller-supplied bare
    name must be normalized before comparing it with an existing model.
    """
    return model_name if "/" in model_name else f"models/{model_name}"


def _image_mime_type(img_format: str) -> str:
    """Map an image format/extension (e.g. ``"PNG"``, ``"jpg"``) to a MIME type."""
    fmt = img_format.lower()
    # "image/jpg" is not a valid MIME type; JPEG data is "image/jpeg".
    if fmt == "jpg":
        fmt = "jpeg"
    return f"image/{fmt}"


def _text_snippet(text: Optional[str], limit: int = 500) -> str:
    """Return the first *limit* characters of *text*, or ``"N/A"`` when it is None."""
    return text[:limit] if text is not None else "N/A"


def _block_reason_detail(response: Any) -> str:
    """Describe why a Gemini response carried no candidates.

    ``block_reason_message`` is read defensively: it is not present on every
    SDK's ``PromptFeedback`` type, and a plain attribute access could raise.
    """
    detail = "Unknown reason (no candidates)"
    feedback = response.prompt_feedback
    if feedback and feedback.block_reason:
        detail = f"Blocked due to: {feedback.block_reason.name}"
        block_message = getattr(feedback, "block_reason_message", None)
        if block_message:
            detail += f" - {block_message}"
    return detail


def _safety_ratings_str(candidate: Any) -> str:
    """Render a candidate's safety ratings as ``"CATEGORY: PROBABILITY, ..."``."""
    return ", ".join(f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings)


def _finish_reason_name(candidate: Any) -> str:
    """Return the candidate's finish-reason name, or ``"UNKNOWN"`` when unset (0)."""
    return candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"


class GeminiService(BaseAttributionService):
    """Attribution service that calls Google Gemini for structured JSON output."""

    def __init__(self, model_name: str = "gemini-2.5-flash-preview-04-17"):
        """Initialize the GeminiService.

        Construction failures are logged and leave ``self.model`` as None;
        ``extract_attributes`` then raises ``VendorError`` and
        ``follow_schema`` builds a temporary model instead.

        Args:
            model_name (str): The name of the Gemini model to use.
        """
        try:
            self.model = genai.GenerativeModel(
                model_name,
                safety_settings=DEFAULT_SAFETY_SETTINGS,
                # system_instruction could be set here if a global system message is always used.
            )
            logger.info(f"GeminiService initialized with model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize Gemini GenerativeModel: {exception_to_str(e)}")
            self.model = None

    def _model_for(self, ai_model: str, request_label: str) -> Any:
        """Return ``self.model`` if it already is *ai_model*, else a per-call model.

        Assumes ``self.model`` is not None (callers check first).

        Raises:
            VendorError: if building the per-call model fails.
        """
        if self.model.model_name == _normalize_model_name(ai_model):
            return self.model
        logger.info(f"Switching to model {ai_model} for this {request_label} request.")
        try:
            # A fresh model object per call; if this happens frequently,
            # consider caching instances per model name.
            return genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
        except Exception as e:
            raise VendorError(f"Failed to initialize Gemini model {ai_model}: {exception_to_str(e)}") from e

    @staticmethod
    def _base64_image_part(base64_data: Optional[str], img_type: Optional[str]) -> Optional[Dict[str, Any]]:
        """Build a Gemini inline image part from base64 data, or None if either input is missing."""
        if not (base64_data and img_type):
            return None
        # Gemini takes raw bytes, so the base64 payload is decoded here.
        return {"mime_type": _image_mime_type(img_type), "data": base64.b64decode(base64_data)}

    def _prepare_image_parts(
        self,
        img_urls: Optional[List[str]] = None,
        img_paths: Optional[List[str]] = None,
        pil_images: Optional[List[Any]] = None,  # PIL.Image.Image objects
    ) -> List[Dict[str, Any]]:
        """Convert image sources into Gemini inline parts (``{"mime_type", "data"}``).

        Sources are processed in the order URLs, paths, then PIL images. Each
        image is handled best-effort: failures are logged and the image is
        skipped, so the result may be shorter than the inputs.
        """
        image_parts: List[Dict[str, Any]] = []

        for img_url in img_urls or []:
            try:
                part = self._base64_image_part(*get_image_base64_and_type(img_url))
                if part is None:
                    logger.warning(f"Could not retrieve or identify type for image URL: {img_url}")
                else:
                    image_parts.append(part)
            except Exception as e:
                logger.error(f"Error processing image URL {img_url}: {exception_to_str(e)}")

        for img_path in img_paths or []:
            try:
                base64_data = get_image_data(img_path)
                img_type = get_data_format(img_path)
                part = self._base64_image_part(base64_data, img_type)
                if part is None:
                    logger.warning(f"Could not retrieve or identify type for image path: {img_path}")
                else:
                    image_parts.append(part)
            except Exception as e:
                logger.error(f"Error processing image path {img_path}: {exception_to_str(e)}")

        for i, pil_image in enumerate(pil_images or []):
            try:
                # In-memory images have no source format; fall back to PNG.
                img_format = pil_image.format or "PNG"
                with io.BytesIO() as buffer:
                    pil_image.save(buffer, format=img_format)
                    image_bytes = buffer.getvalue()
                image_parts.append({"mime_type": _image_mime_type(img_format), "data": image_bytes})
            except Exception as e:
                logger.error(f"Error processing PIL image #{i}: {exception_to_str(e)}")

        return image_parts

    @weave.op()
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,  # Gemini model name, e.g. "gemini-1.5-flash-latest"
        img_urls: Optional[List[str]] = None,
        product_taxonomy: str = "",
        product_data: Optional[Dict[str, Union[str, List[str]]]] = None,
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Extract product attributes with Gemini and validate them against a model.

        Args:
            attributes_model: Pydantic model for the expected output. Its JSON
                schema is sent as ``response_schema``, and the reply is
                validated with ``model_validate_json``.
            ai_model: Gemini model name. If it differs from the service's
                model, a model object is built for this call only.
            img_urls: Image URLs to fetch and attach.
            product_taxonomy: Taxonomy text substituted into the prompt.
            product_data: Product fields rendered into the prompt via
                ``product_data_to_str``.
            pil_images: PIL images to attach.
            img_paths: Local image paths to read and attach.

        Returns:
            The validated ``attributes_model`` instance, dumped to a dict.

        Raises:
            VendorError: model unavailable, API failure, blocked/empty/safety
                responses, or output that fails validation.
            BadRequestError: image preparation raised unexpectedly.
        """
        if not self.model:
            raise VendorError("Gemini model not initialized.")
        current_model = self._model_for(ai_model, "extraction")

        # System and human prompts are concatenated into a single text part.
        system_message = prompts.EXTRACT_INFO_SYSTEM_MESSAGE
        human_message = prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
            product_taxonomy=product_taxonomy,
            product_data=product_data_to_str(product_data or {}),
        )
        full_prompt_text = f"{system_message}\n\n{human_message}"
        logger.info(f"Gemini Prompt Text: {full_prompt_text[:500]}...")  # Log a snippet only

        content_parts: List[Any] = [full_prompt_text]
        try:
            image_parts = self._prepare_image_parts(img_urls, img_paths, pil_images)
        except Exception as e:
            logger.error(f"Failed during image preparation: {exception_to_str(e)}")
            raise BadRequestError(f"Image processing failed: {e}") from e
        content_parts.extend(image_parts)

        if not image_parts and (img_urls or img_paths or pil_images):
            logger.warning("Image sources provided, but no image parts were successfully prepared.")

        # NOTE(review): model_json_schema() can emit "$defs"/"$ref" and keys such
        # as "additionalProperties" that Gemini's OpenAPI-subset response_schema
        # may reject for nested models; verify with nested attribute models.
        try:
            schema_for_gemini = attributes_model.model_json_schema()
        except Exception as e:
            logger.error(f"Error generating JSON schema from Pydantic model: {exception_to_str(e)}")
            raise VendorError(f"Could not generate schema for attributes_model: {e}") from e

        generation_config = GenerationConfig(
            response_mime_type="application/json",
            response_schema=schema_for_gemini,
            temperature=0.0,  # Deterministic output
            max_output_tokens=2048,
        )

        logger.info(f"Extracting attributes via Gemini model: {current_model.model_name}...")
        try:
            response = await current_model.generate_content_async(
                contents=content_parts,
                generation_config=generation_config,
            )
        except Exception as e:  # google.api_core exceptions and anything else
            error_message = exception_to_str(e)
            logger.error(f"Gemini API call failed: {error_message}")
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message)) from e

        response_text: Optional[str] = None
        try:
            if not response.candidates:
                # Every candidate was filtered (safety or other reasons).
                block_reason_detail = _block_reason_detail(response)
                logger.error(f"Gemini response was blocked or empty. {block_reason_detail}")
                raise VendorError(f"Gemini response blocked or empty. {block_reason_detail}")

            candidate = response.candidates[0]
            if candidate.finish_reason not in _ACCEPTED_FINISH_REASONS:
                finish_reason_str = _finish_reason_name(candidate)
                logger.warning(f"Gemini generation finished with reason: {finish_reason_str}")
                if finish_reason_str == "SAFETY":
                    raise VendorError(
                        "Gemini content generation stopped due to safety concerns. "
                        f"Ratings: [{_safety_ratings_str(candidate)}]"
                    )
            elif candidate.finish_reason == _FINISH_REASON_MAX_TOKENS:
                # Truncated output is the likely cause of a validation failure below.
                logger.warning("Gemini output hit max_output_tokens; the JSON may be truncated.")

            if not candidate.content.parts or not candidate.content.parts[0].text:
                logger.error("Gemini response content is empty or not in the expected text format.")
                raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + " (empty response text)")

            response_text = candidate.content.parts[0].text
            parsed_data = attributes_model.model_validate_json(response_text)
            return parsed_data.model_dump()
        except ValidationError as ve:
            logger.error(f"Pydantic validation failed for Gemini response: {ve}")
            logger.debug(f"Invalid JSON received from Gemini: {_text_snippet(response_text)}...")
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {ve}") from ve
        except json.JSONDecodeError as je:
            logger.error(f"JSON decoding failed for Gemini response: {je}")
            logger.debug(f"Non-JSON response received: {_text_snippet(response_text)}...")
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {je}") from je
        except VendorError:
            raise
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Error processing Gemini response: {error_message}")
            logger.debug(f"Problematic Gemini response snippet: {_text_snippet(response_text)}")
            raise VendorError(f"Failed to process Gemini response: {error_message}") from e

    @weave.op()
    async def follow_schema(
        self,
        schema: Dict[str, Any],  # OpenAPI-style schema dictionary
        data: Dict[str, Any],
        ai_model: str = "gemini-1.5-flash-latest",
    ) -> Dict[str, Any]:
        """Ask Gemini to restate *data* so that it conforms to *schema*.

        Args:
            schema: Schema dict passed to Gemini as ``response_schema``.
            data: JSON-serializable data substituted into the prompt.
            ai_model: Gemini model name for this call.

        Returns:
            The parsed JSON reply. For blocked, safety-stopped, or empty
            responses, returns ``{"status": "refused", "reason": ...}``
            instead (mirrors the previous OpenAI implementation).

        Raises:
            BadRequestError: *data* is not JSON serializable.
            VendorError: model construction or API failure, or a non-JSON reply.
        """
        if not self.model:
            logger.warning("Main Gemini model not initialized. Attempting to initialize a temporary one for follow_schema.")
            try:
                current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
            except Exception as e:
                raise VendorError(f"Failed to initialize Gemini model for follow_schema: {exception_to_str(e)}") from e
        else:
            current_model = self._model_for(ai_model, "follow_schema")

        logger.info(f"Following schema via Gemini model: {current_model.model_name}...")

        system_message = prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE
        # json.dumps raises TypeError for unsupported types and ValueError for
        # circular references; both mean the caller's input is unusable.
        try:
            data_as_json_string = json.dumps(data, indent=2)
        except (TypeError, ValueError) as te:
            logger.error(f"Could not serialize 'data' to JSON for prompt: {te}")
            raise BadRequestError(f"Input data for schema following is not JSON serializable: {te}") from te

        human_message = prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data_as_json_string)
        content_parts: List[Any] = [f"{system_message}\n\n{human_message}"]

        generation_config = GenerationConfig(
            response_mime_type="application/json",
            response_schema=schema,
            temperature=0.0,  # Deterministic output
            max_output_tokens=2048,
        )

        try:
            response = await current_model.generate_content_async(
                contents=content_parts,
                generation_config=generation_config,
            )
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Gemini API call failed for follow_schema: {error_message}")
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message)) from e

        response_text: Optional[str] = None
        try:
            if not response.candidates:
                block_reason_detail = _block_reason_detail(response)
                logger.error(f"Gemini response was blocked or empty in follow_schema. {block_reason_detail}")
                return {"status": "refused", "reason": block_reason_detail}

            candidate = response.candidates[0]
            if candidate.finish_reason not in _ACCEPTED_FINISH_REASONS:
                finish_reason_str = _finish_reason_name(candidate)
                logger.warning(f"Gemini generation (follow_schema) finished with reason: {finish_reason_str}")
                if finish_reason_str == "SAFETY":
                    return {"status": "refused", "reason": f"Safety block. Ratings: [{_safety_ratings_str(candidate)}]"}

            if not candidate.content.parts or not candidate.content.parts[0].text:
                logger.error("Gemini response content (follow_schema) is empty.")
                return {"status": "refused", "reason": "Empty content from Gemini"}

            response_text = candidate.content.parts[0].text
            # The output shape is constrained by response_schema on Gemini's side.
            return json.loads(response_text)
        except json.JSONDecodeError as je:
            logger.error(f"JSON decoding failed for Gemini response (follow_schema): {je}")
            logger.debug(f"Non-JSON response received: {_text_snippet(response_text)}...")
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" (follow_schema) Details: {je}") from je
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Error processing Gemini response (follow_schema): {error_message}")
            logger.debug(f"Problematic Gemini response snippet (follow_schema): {_text_snippet(response_text)}")
            raise VendorError(f"Failed to process Gemini response (follow_schema): {error_message}") from e