import json
import os
import io
import base64
from typing import Any, Dict, List, Type, Union, Optional

import google.generativeai as genai
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold  # For safety settings
import weave  # Assuming weave is still used
from pydantic import BaseModel, ValidationError  # For schema validation

# Assuming these utilities are in the same relative paths or accessible
from app.utils.converter import product_data_to_str
from app.utils.image_processing import (
    get_data_format,  # Assuming this returns 'jpeg', 'png', etc.
    get_image_base64_and_type,  # Assuming this fetches a URL and returns (base64_str, type_str)
    get_image_data,  # Assuming this reads a local path and returns a base64 string
)
from app.utils.logger import exception_to_str, setup_logger

# Assuming these are correctly defined and accessible
from ..config import get_settings
from ..core import errors
from ..core.errors import BadRequestError, VendorError  # Using your custom errors
from ..core.prompts import get_prompts  # Assuming prompts are compatible or adapted
from .base import BaseAttributionService  # Assuming this base class exists
# Environment and Weave setup (kept as-is)
ENV = os.getenv("ENV", "LOCAL")
if ENV == "LOCAL":
    weave_project_name = "cfai/attribution-exp"
elif ENV == "DEV":
    weave_project_name = "cfai/attribution-dev"
elif ENV == "UAT":
    weave_project_name = "cfai/attribution-uat"
else:
    weave_project_name = None  # No weave for PROD (or unrecognized environments)

if ENV != "PROD":
    # weave.init(project_name=weave_project_name)  # Assuming weave.init() is called elsewhere if needed
    print(f"Weave project name (potentially initialized elsewhere): {weave_project_name}")

settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
# Configure the Gemini client
try:
    if settings.GEMINI_API_KEY:
        genai.configure(api_key=settings.GEMINI_API_KEY)
    else:
        logger.error("GEMINI_API_KEY not found in settings.")
        # Potentially raise an error or handle this case as per application requirements
except AttributeError:
    logger.error("Settings object does not have a GEMINI_API_KEY attribute.")
    # Handle the missing settings attribute
# Define default safety settings for Gemini.
# Adjust these as per your application's requirements.
DEFAULT_SAFETY_SETTINGS = {
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}
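# Note: the thresholds above are a conservative default; per category you could instead
# use, e.g., HarmBlockThreshold.BLOCK_ONLY_HIGH or HarmBlockThreshold.BLOCK_NONE.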
class GeminiService(BaseAttributionService):
    def __init__(self, model_name: str = "gemini-2.5-flash-preview-04-17"):
        """
        Initializes the GeminiService.

        Args:
            model_name (str): The name of the Gemini model to use.
        """
        try:
            self.model = genai.GenerativeModel(
                model_name,
                safety_settings=DEFAULT_SAFETY_SETTINGS,
                # system_instruction can be set here if a global system message is always used
            )
            logger.info(f"GeminiService initialized with model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize Gemini GenerativeModel: {exception_to_str(e)}")
            # Depending on requirements, you might want to raise an error here.
            # For now we proceed; calls will fail if the model isn't initialized.
            self.model = None
    def _prepare_image_parts(
        self,
        img_urls: Optional[List[str]] = None,
        img_paths: Optional[List[str]] = None,
        pil_images: Optional[List[Any]] = None,  # PIL.Image.Image objects
    ) -> List[Dict[str, Any]]:
        """
        Prepares image data in the format expected by the Gemini API.

        Decodes base64 image data to raw bytes and converts PIL images to bytes.
        """
        image_parts: List[Dict[str, Any]] = []

        # Process image URLs
        if img_urls:
            for img_url in img_urls:
                try:
                    base64_data, img_type = get_image_base64_and_type(img_url)
                    if base64_data and img_type:
                        # Gemini expects raw bytes, so decode the base64 payload
                        image_bytes = base64.b64decode(base64_data)
                        mime_type = f"image/{img_type.lower()}"
                        image_parts.append({"mime_type": mime_type, "data": image_bytes})
                    else:
                        logger.warning(f"Could not retrieve or identify type for image URL: {img_url}")
                except Exception as e:
                    logger.error(f"Error processing image URL {img_url}: {exception_to_str(e)}")

        # Process image paths
        if img_paths:
            for img_path in img_paths:
                try:
                    base64_data = get_image_data(img_path)  # Assuming this returns a base64 string
                    img_type = get_data_format(img_path)  # Assuming this returns 'png', 'jpeg', etc.
                    if base64_data and img_type:
                        image_bytes = base64.b64decode(base64_data)
                        mime_type = f"image/{img_type.lower()}"
                        image_parts.append({"mime_type": mime_type, "data": image_bytes})
                    else:
                        logger.warning(f"Could not retrieve or identify type for image path: {img_path}")
                except Exception as e:
                    logger.error(f"Error processing image path {img_path}: {exception_to_str(e)}")

        # Process PIL images
        if pil_images:
            for i, pil_image in enumerate(pil_images):
                try:
                    img_format = pil_image.format or "PNG"  # Default to PNG if the format is unavailable
                    mime_type = f"image/{img_format.lower()}"
                    with io.BytesIO() as img_byte_arr:
                        pil_image.save(img_byte_arr, format=img_format)
                        image_bytes = img_byte_arr.getvalue()
                    image_parts.append({"mime_type": mime_type, "data": image_bytes})
                except Exception as e:
                    logger.error(f"Error processing PIL image #{i}: {exception_to_str(e)}")

        return image_parts
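
    @staticmethod
    def _inline_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]:
        """
        Illustrative sketch, not wired into the call path: Gemini's `response_schema`
        accepts only a subset of OpenAPI and does not resolve the `$ref`/`$defs`
        indirection that Pydantic emits for nested models. One option, assuming the
        schema is non-recursive, is to inline every local `#/$defs/<Name>` reference.
        """
        defs = schema.get("$defs", {})

        def resolve(node: Any) -> Any:
            if isinstance(node, dict):
                ref = node.get("$ref")
                if isinstance(ref, str) and ref.startswith("#/$defs/"):
                    # Inline the referenced definition (would loop forever on recursive models)
                    return resolve(defs[ref.split("/")[-1]])
                return {k: resolve(v) for k, v in node.items() if k != "$defs"}
            if isinstance(node, list):
                return [resolve(v) for v in node]
            return node

        return resolve(schema)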
    # Assuming weave.op can be used as a decorator directly
    @weave.op()
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,  # The Gemini model name, e.g., "gemini-1.5-flash-latest"
        img_urls: Optional[List[str]] = None,
        product_taxonomy: str = "",
        product_data: Optional[Dict[str, Union[str, List[str]]]] = None,
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        if not self.model:
            raise VendorError("Gemini model not initialized.")

        # The SDK normalizes model_name to "models/<name>", so compare the bare names.
        if self.model.model_name.split("/")[-1] != ai_model.split("/")[-1]:
            # A different model was requested for this specific call.
            logger.info(f"Switching to model {ai_model} for this extraction request.")
            # Note: This creates a new model object for the call.
            # If this happens frequently, consider how model instances are managed.
            current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
        else:
            current_model = self.model
        # Construct the prompt text.
        # System and human prompts are combined, since Gemini takes a list of content parts;
        # system instructions can also be set at model initialization.
        system_message = prompts.EXTRACT_INFO_SYSTEM_MESSAGE
        human_message = prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
            product_taxonomy=product_taxonomy,
            product_data=product_data_to_str(product_data if product_data else {}),
        )
        full_prompt_text = f"{system_message}\n\n{human_message}"

        # For logging or debugging the prompt
        logger.info(f"Gemini Prompt Text: {full_prompt_text[:500]}...")  # Log a snippet

        content_parts: List[Any] = [full_prompt_text]

        # Prepare image parts
        try:
            image_parts = self._prepare_image_parts(img_urls, img_paths, pil_images)
            content_parts.extend(image_parts)
        except Exception as e:
            logger.error(f"Failed during image preparation: {exception_to_str(e)}")
            raise BadRequestError(f"Image processing failed: {e}")

        if not image_parts and (img_urls or img_paths or pil_images):
            logger.warning("Image sources provided, but no image parts were successfully prepared.")
        # Define the generation config for JSON output.
        # Pydantic's model_json_schema() generates an OpenAPI-compliant schema dictionary.
        # Note: Gemini's response_schema supports only a subset of OpenAPI; the $ref/$defs
        # indirection emitted for nested Pydantic models may need to be inlined first
        # (see the _inline_schema_refs sketch above).
        try:
            schema_for_gemini = attributes_model.model_json_schema()
        except Exception as e:
            logger.error(f"Error generating JSON schema from Pydantic model: {exception_to_str(e)}")
            raise VendorError(f"Could not generate schema for attributes_model: {e}")

        generation_config = GenerationConfig(
            response_mime_type="application/json",
            response_schema=schema_for_gemini,  # Gemini expects the schema here
            temperature=0.0,  # For deterministic output, similar to a low top_p
            max_output_tokens=2048,  # Adjust as needed (was 1000 for OpenAI)
            # top_p and top_k can also be set if needed
        )
logger.info(f"Extracting attributes via Gemini model: {current_model.model_name}...") | |
try: | |
response = await current_model.generate_content_async( | |
contents=content_parts, | |
generation_config=generation_config, | |
# request_options={"timeout": 120} # Example: set timeout in seconds | |
) | |
except Exception as e: # Catches google.api_core.exceptions and others | |
error_message = exception_to_str(e) | |
logger.error(f"Gemini API call failed: {error_message}") | |
# More specific error handling for Gemini can be added here | |
# e.g., if isinstance(e, google.api_core.exceptions.InvalidArgument): | |
# raise BadRequestError(f"Invalid argument to Gemini: {error_message}") | |
raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message)) | |
        # Process the response
        try:
            # Check for safety blocks or refusals
            if not response.candidates:
                # This can happen if all candidates were filtered due to safety or other reasons.
                block_reason_detail = "Unknown reason (no candidates)"
                if response.prompt_feedback and response.prompt_feedback.block_reason:
                    block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
                    if response.prompt_feedback.block_reason_message:
                        block_reason_detail += f" - {response.prompt_feedback.block_reason_message}"
                logger.error(f"Gemini response was blocked or empty. {block_reason_detail}")
                raise VendorError(f"Gemini response blocked or empty. {block_reason_detail}")

            # Assuming the first candidate is the one we want
            candidate = response.candidates[0]
            if candidate.finish_reason not in [1, 2]:  # 1=STOP, 2=MAX_TOKENS
                finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
                logger.warning(f"Gemini generation finished with reason: {finish_reason_str}")
                # Potentially raise an error if the finish reason is SAFETY, RECITATION, etc.
                if finish_reason_str == "SAFETY":
                    safety_ratings_str = ", ".join(
                        f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings
                    )
                    raise VendorError(f"Gemini content generation stopped due to safety concerns. Ratings: [{safety_ratings_str}]")

            if not candidate.content.parts or not candidate.content.parts[0].text:
                logger.error("Gemini response content is empty or not in the expected text format.")
                raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + " (empty response text)")

            response_text = candidate.content.parts[0].text
            # Parse and validate the JSON response using the Pydantic model
            parsed_data = attributes_model.model_validate_json(response_text)
            return parsed_data.model_dump()  # Return as dict
        except ValidationError as ve:
            logger.error(f"Pydantic validation failed for Gemini response: {ve}")
            logger.debug(f"Invalid JSON received from Gemini: {response_text[:500]}...")  # Log a snippet of the invalid JSON
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {ve}")
        except json.JSONDecodeError as je:
            logger.error(f"JSON decoding failed for Gemini response: {je}")
            logger.debug(f"Non-JSON response received: {response_text[:500]}...")
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {je}")
        except VendorError:  # Re-raise VendorErrors
            raise
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Error processing Gemini response: {error_message}")
            # Log the raw response text if available when an error occurred
            raw_response_snippet = response_text[:500] if 'response_text' in locals() else "N/A"
            logger.debug(f"Problematic Gemini response snippet: {raw_response_snippet}")
            raise VendorError(f"Failed to process Gemini response: {error_message}")
    async def follow_schema(
        self,
        schema: Dict[str, Any],  # This should be an OpenAPI schema dictionary
        data: Dict[str, Any],
        ai_model: str = "gemini-1.5-flash-latest",  # Model for this specific task
    ) -> Dict[str, Any]:
        if not self.model:  # Check whether the main model was initialized
            logger.warning("Main Gemini model not initialized. Attempting to initialize a temporary one for follow_schema.")
            try:
                current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
            except Exception as e:
                raise VendorError(f"Failed to initialize Gemini model for follow_schema: {exception_to_str(e)}")
        elif self.model.model_name.split("/")[-1] != ai_model.split("/")[-1]:  # Compare bare names (SDK prefixes "models/")
            logger.info(f"Switching to model {ai_model} for this follow_schema request.")
            current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
        else:
            current_model = self.model
logger.info(f"Following schema via Gemini model: {current_model.model_name}...") | |
# Prepare the prompt | |
# System message can be part of the model or prepended here. | |
system_message = prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE | |
# The human message needs to contain the data to be transformed. | |
# Ensure `json_info` placeholder is correctly used by your prompt string. | |
try: | |
data_as_json_string = json.dumps(data, indent=2) | |
except TypeError as te: | |
logger.error(f"Could not serialize 'data' to JSON for prompt: {te}") | |
raise BadRequestError(f"Input data for schema following is not JSON serializable: {te}") | |
human_message = prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data_as_json_string) | |
full_prompt_text = f"{system_message}\n\n{human_message}" | |
content_parts = [full_prompt_text] | |
# Define generation config for JSON output using the provided schema | |
generation_config = GenerationConfig( | |
response_mime_type="application/json", | |
response_schema=schema, # The provided schema dictionary | |
temperature=0.0, # For deterministic output | |
max_output_tokens=2048, # Adjust as needed | |
) | |
        try:
            response = await current_model.generate_content_async(
                contents=content_parts,
                generation_config=generation_config,
            )
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Gemini API call failed for follow_schema: {error_message}")
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))
        # Process the response
        try:
            if not response.candidates:
                block_reason_detail = "Unknown reason (no candidates)"
                if response.prompt_feedback and response.prompt_feedback.block_reason:
                    block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
                logger.error(f"Gemini response was blocked or empty in follow_schema. {block_reason_detail}")
                # The OpenAI version returned {"status": "refused"}; mimic that for a block
                return {"status": "refused", "reason": block_reason_detail}

            candidate = response.candidates[0]
            if candidate.finish_reason not in [1, 2]:  # 1=STOP, 2=MAX_TOKENS
                finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
                logger.warning(f"Gemini generation (follow_schema) finished with reason: {finish_reason_str}")
                if finish_reason_str == "SAFETY":
                    safety_ratings_str = ", ".join(
                        f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings
                    )
                    return {"status": "refused", "reason": f"Safety block. Ratings: [{safety_ratings_str}]"}

            if not candidate.content.parts or not candidate.content.parts[0].text:
                logger.error("Gemini response content (follow_schema) is empty.")
                # Mimic OpenAI's refusal structure rather than raising
                return {"status": "refused", "reason": "Empty content from Gemini"}

            response_text = candidate.content.parts[0].text
            parsed_data = json.loads(response_text)  # The schema is enforced by Gemini
            return parsed_data
        except json.JSONDecodeError as je:
            logger.error(f"JSON decoding failed for Gemini response (follow_schema): {je}")
            logger.debug(f"Non-JSON response received: {response_text[:500]}...")
            # The original OpenAI code raised ValueError(errors.VENDOR_ERROR_INVALID_JSON);
            # VendorError is used here for consistency with the rest of this service.
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" (follow_schema) Details: {je}")
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Error processing Gemini response (follow_schema): {error_message}")
            raw_response_snippet = response_text[:500] if 'response_text' in locals() else "N/A"
            logger.debug(f"Problematic Gemini response snippet (follow_schema): {raw_response_snippet}")
            raise VendorError(f"Failed to process Gemini response (follow_schema): {error_message}")