import json
import os
import io
import base64
from typing import Any, Dict, List, Type, Union, Optional
import google.generativeai as genai
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold # For safety settings
import weave # Assuming weave is still used
from pydantic import BaseModel, ValidationError # For schema validation
# Assuming these utilities are in the same relative paths or accessible
from app.utils.converter import product_data_to_str
from app.utils.image_processing import (
get_data_format, # Assuming this returns 'jpeg', 'png' etc.
get_image_base64_and_type, # Assuming this fetches URL and returns (base64_str, type_str)
get_image_data, # Assuming this reads local path and returns base64_str
)
from app.utils.logger import exception_to_str, setup_logger
# Assuming these are correctly defined and accessible
from ..config import get_settings
from ..core import errors
from ..core.errors import BadRequestError, VendorError # Using your custom errors
from ..core.prompts import get_prompts # Assuming prompts are compatible or adapted
from .base import BaseAttributionService # Assuming this base class exists
# Environment and Weave setup (kept as-is)
ENV = os.getenv("ENV", "LOCAL")
if ENV == "LOCAL":
weave_project_name = "cfai/attribution-exp"
elif ENV == "DEV":
weave_project_name = "cfai/attribution-dev"
elif ENV == "UAT":
weave_project_name = "cfai/attribution-uat"
elif ENV == "PROD":
pass # No weave for PROD
if ENV != "PROD":
# weave.init(project_name=weave_project_name) # Assuming weave.init() is called elsewhere or if needed here
print(f"Weave project name (potentially initialized elsewhere): {weave_project_name}")
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
# Configure the Gemini client
try:
if settings.GEMINI_API_KEY:
genai.configure(api_key=settings.GEMINI_API_KEY)
else:
logger.error("GEMINI_API_KEY not found in settings.")
# Potentially raise an error or handle this case as per application requirements
except AttributeError:
logger.error("Settings object does not have GEMINI_API_KEY attribute.")
# Handle missing settings attribute
# Define default safety settings for Gemini
# Adjust these as per your application's requirements
DEFAULT_SAFETY_SETTINGS = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}
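# To relax filtering for a category, HarmBlockThreshold.BLOCK_ONLY_HIGH or
# HarmBlockThreshold.BLOCK_NONE can be substituted for any entry above.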
class GeminiService(BaseAttributionService):
def __init__(self, model_name: str = "gemini-2.5-flash-preview-04-17"):
"""
Initializes the GeminiService.
Args:
model_name (str): The name of the Gemini model to use.
"""
try:
self.model = genai.GenerativeModel(
model_name,
safety_settings=DEFAULT_SAFETY_SETTINGS
# system_instruction can be set here if a global system message is always used
)
logger.info(f"GeminiService initialized with model: {model_name}")
except Exception as e:
logger.error(f"Failed to initialize Gemini GenerativeModel: {exception_to_str(e)}")
# Depending on requirements, you might want to raise an error here
# For now, we'll let it proceed, and calls will fail if model isn't initialized.
self.model = None
def _prepare_image_parts(
self,
img_urls: Optional[List[str]] = None,
img_paths: Optional[List[str]] = None,
pil_images: Optional[List[Any]] = None, # PIL.Image.Image objects
) -> List[Dict[str, Any]]:
"""
Prepares image data in the format expected by Gemini API.
Decodes base64 image data to bytes.
Converts PIL images to bytes.
"""
image_parts = []
# Process image URLs
if img_urls:
for img_url in img_urls:
try:
base64_data, img_type = get_image_base64_and_type(img_url)
if base64_data and img_type:
# Gemini expects raw bytes, so decode base64
image_bytes = base64.b64decode(base64_data)
mime_type = f"image/{img_type.lower()}"
image_parts.append({"mime_type": mime_type, "data": image_bytes})
else:
logger.warning(f"Could not retrieve or identify type for image URL: {img_url}")
except Exception as e:
logger.error(f"Error processing image URL {img_url}: {exception_to_str(e)}")
# Process image paths
if img_paths:
for img_path in img_paths:
try:
base64_data = get_image_data(img_path) # Assuming this returns base64 string
img_type = get_data_format(img_path) # Assuming this returns 'png', 'jpeg'
if base64_data and img_type:
image_bytes = base64.b64decode(base64_data)
mime_type = f"image/{img_type.lower()}"
image_parts.append({"mime_type": mime_type, "data": image_bytes})
else:
logger.warning(f"Could not retrieve or identify type for image path: {img_path}")
except Exception as e:
logger.error(f"Error processing image path {img_path}: {exception_to_str(e)}")
# Process PIL images
if pil_images:
for i, pil_image in enumerate(pil_images):
try:
img_format = pil_image.format or 'PNG' # Default to PNG if format is not available
mime_type = f"image/{img_format.lower()}"
with io.BytesIO() as img_byte_arr:
pil_image.save(img_byte_arr, format=img_format)
image_bytes = img_byte_arr.getvalue()
image_parts.append({"mime_type": mime_type, "data": image_bytes})
except Exception as e:
logger.error(f"Error processing PIL image #{i}: {exception_to_str(e)}")
return image_parts
@weave.op() # Assuming weave.op can be used as a decorator directly
async def extract_attributes(
self,
attributes_model: Type[BaseModel],
ai_model: str, # This will be the Gemini model name, e.g., "gemini-1.5-flash-latest"
img_urls: Optional[List[str]] = None,
product_taxonomy: str = "",
product_data: Optional[Dict[str, Union[str, List[str]]]] = None,
pil_images: Optional[List[Any]] = None,
img_paths: Optional[List[str]] = None,
) -> Dict[str, Any]:
if not self.model:
raise VendorError("Gemini model not initialized.")
        # GenerativeModel.model_name is normalized to "models/<name>", so strip the
        # prefix before comparing against the model requested for this call.
        if self.model.model_name.split("/")[-1] != ai_model:
logger.info(f"Switching to model {ai_model} for this extraction request.")
# Note: This creates a new model object for the call.
# If this happens frequently, consider how model instances are managed.
current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
else:
current_model = self.model
# Construct the prompt text
# Combining system and human prompts as Gemini typically takes a list of contents.
# System instructions can also be part of the model's initialization.
system_message = prompts.EXTRACT_INFO_SYSTEM_MESSAGE
human_message = prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
product_taxonomy=product_taxonomy,
product_data=product_data_to_str(product_data if product_data else {}),
)
full_prompt_text = f"{system_message}\n\n{human_message}"
# For logging or debugging the prompt
logger.info(f"Gemini Prompt Text: {full_prompt_text[:500]}...") # Log a snippet
content_parts = [full_prompt_text]
# Prepare image parts
try:
image_parts = self._prepare_image_parts(img_urls, img_paths, pil_images)
content_parts.extend(image_parts)
except Exception as e:
logger.error(f"Failed during image preparation: {exception_to_str(e)}")
raise BadRequestError(f"Image processing failed: {e}")
if not image_parts and (img_urls or img_paths or pil_images):
logger.warning("Image sources provided, but no image parts were successfully prepared.")
# Define generation config for JSON output
# Pydantic's model_json_schema() generates an OpenAPI compliant schema dictionary.
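        # For illustration, a model like `class Color(BaseModel): name: str` yields roughly:
        #   {"title": "Color", "type": "object",
        #    "properties": {"name": {"title": "Name", "type": "string"}},
        #    "required": ["name"]}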
try:
schema_for_gemini = attributes_model.model_json_schema()
except Exception as e:
logger.error(f"Error generating JSON schema from Pydantic model: {exception_to_str(e)}")
raise VendorError(f"Could not generate schema for attributes_model: {e}")
generation_config = GenerationConfig(
response_mime_type="application/json",
response_schema=schema_for_gemini, # Gemini expects the schema here
            temperature=0.0, # Near-deterministic output (similar in intent to a low top_p)
max_output_tokens=2048, # Adjust as needed, was 1000 for OpenAI
# top_p, top_k can also be set if needed
)
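        # Note: response_schema supports a limited OpenAPI subset; deeply nested or
        # $ref-heavy Pydantic schemas may need flattening before being passed here.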
logger.info(f"Extracting attributes via Gemini model: {current_model.model_name}...")
try:
response = await current_model.generate_content_async(
contents=content_parts,
generation_config=generation_config,
# request_options={"timeout": 120} # Example: set timeout in seconds
)
except Exception as e: # Catches google.api_core.exceptions and others
error_message = exception_to_str(e)
logger.error(f"Gemini API call failed: {error_message}")
# More specific error handling for Gemini can be added here
# e.g., if isinstance(e, google.api_core.exceptions.InvalidArgument):
# raise BadRequestError(f"Invalid argument to Gemini: {error_message}")
raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))
# Process the response
try:
# Check for safety blocks or refusals
if not response.candidates:
# This can happen if all candidates were filtered due to safety or other reasons.
block_reason_detail = "Unknown reason (no candidates)"
if response.prompt_feedback and response.prompt_feedback.block_reason:
block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
if response.prompt_feedback.block_reason_message:
block_reason_detail += f" - {response.prompt_feedback.block_reason_message}"
logger.error(f"Gemini response was blocked or empty. {block_reason_detail}")
raise VendorError(f"Gemini response blocked or empty. {block_reason_detail}")
# Assuming the first candidate is the one we want
candidate = response.candidates[0]
if candidate.finish_reason not in [1, 2]: # 1=STOP, 2=MAX_TOKENS
finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
logger.warning(f"Gemini generation finished with reason: {finish_reason_str}")
# Potentially raise error if finish reason is SAFETY, RECITATION, etc.
if finish_reason_str == "SAFETY":
safety_ratings_str = ", ".join([f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings])
raise VendorError(f"Gemini content generation stopped due to safety concerns. Ratings: [{safety_ratings_str}]")
if not candidate.content.parts or not candidate.content.parts[0].text:
logger.error("Gemini response content is empty or not in the expected text format.")
raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + " (empty response text)")
response_text = candidate.content.parts[0].text
# Parse and validate the JSON response using the Pydantic model
parsed_data = attributes_model.model_validate_json(response_text)
return parsed_data.model_dump() # Return as dict
except ValidationError as ve:
logger.error(f"Pydantic validation failed for Gemini response: {ve}")
logger.debug(f"Invalid JSON received from Gemini: {response_text[:500]}...") # Log snippet of invalid JSON
raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {ve}")
except json.JSONDecodeError as je:
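            # Defensive: Pydantic v2's model_validate_json raises ValidationError even
            # for malformed JSON, so this branch is unlikely to be reached in practice.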
logger.error(f"JSON decoding failed for Gemini response: {je}")
logger.debug(f"Non-JSON response received: {response_text[:500]}...")
raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {je}")
except VendorError: # Re-raise VendorErrors
raise
except Exception as e:
error_message = exception_to_str(e)
logger.error(f"Error processing Gemini response: {error_message}")
# Log the raw response text if available and an error occurred
raw_response_snippet = response_text[:500] if 'response_text' in locals() else "N/A"
logger.debug(f"Problematic Gemini response snippet: {raw_response_snippet}")
raise VendorError(f"Failed to process Gemini response: {error_message}")
@weave.op()
async def follow_schema(
self,
schema: Dict[str, Any], # This should be an OpenAPI schema dictionary
data: Dict[str, Any],
ai_model: str = "gemini-1.5-flash-latest" # Model for this specific task
) -> Dict[str, Any]:
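        # Illustrative call (hypothetical values): schema={"type": "object",
        # "properties": {"color": {"type": "string"}}}, data={"colour": "navy"};
        # the model re-emits `data` so that it conforms to `schema`.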
if not self.model: # Check if the main model was initialized
logger.warning("Main Gemini model not initialized. Attempting to initialize a temporary one for follow_schema.")
try:
current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
except Exception as e:
raise VendorError(f"Failed to initialize Gemini model for follow_schema: {exception_to_str(e)}")
        elif self.model.model_name.split("/")[-1] != ai_model:  # model_name carries a "models/" prefix
logger.info(f"Switching to model {ai_model} for this follow_schema request.")
current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
else:
current_model = self.model
logger.info(f"Following schema via Gemini model: {current_model.model_name}...")
# Prepare the prompt
# System message can be part of the model or prepended here.
system_message = prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE
# The human message needs to contain the data to be transformed.
# Ensure `json_info` placeholder is correctly used by your prompt string.
try:
data_as_json_string = json.dumps(data, indent=2)
except TypeError as te:
logger.error(f"Could not serialize 'data' to JSON for prompt: {te}")
raise BadRequestError(f"Input data for schema following is not JSON serializable: {te}")
human_message = prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data_as_json_string)
full_prompt_text = f"{system_message}\n\n{human_message}"
content_parts = [full_prompt_text]
# Define generation config for JSON output using the provided schema
generation_config = GenerationConfig(
response_mime_type="application/json",
response_schema=schema, # The provided schema dictionary
temperature=0.0, # For deterministic output
max_output_tokens=2048, # Adjust as needed
)
try:
response = await current_model.generate_content_async(
contents=content_parts,
generation_config=generation_config,
)
except Exception as e:
error_message = exception_to_str(e)
logger.error(f"Gemini API call failed for follow_schema: {error_message}")
raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))
# Process response
try:
if not response.candidates:
block_reason_detail = "Unknown reason (no candidates)"
if response.prompt_feedback and response.prompt_feedback.block_reason:
block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
logger.error(f"Gemini response was blocked or empty in follow_schema. {block_reason_detail}")
                # The OpenAI version returned {"status": "refused"}; mimic that for blocked responses.
return {"status": "refused", "reason": block_reason_detail}
candidate = response.candidates[0]
if candidate.finish_reason not in [1, 2]: # 1=STOP, 2=MAX_TOKENS
finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
logger.warning(f"Gemini generation (follow_schema) finished with reason: {finish_reason_str}")
if finish_reason_str == "SAFETY":
safety_ratings_str = ", ".join([f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings])
return {"status": "refused", "reason": f"Safety block. Ratings: [{safety_ratings_str}]"}
if not candidate.content.parts or not candidate.content.parts[0].text:
logger.error("Gemini response content (follow_schema) is empty.")
# Mimic OpenAI's refusal structure or raise error
return {"status": "refused", "reason": "Empty content from Gemini"}
response_text = candidate.content.parts[0].text
parsed_data = json.loads(response_text) # The schema is enforced by Gemini
return parsed_data
except json.JSONDecodeError as je:
logger.error(f"JSON decoding failed for Gemini response (follow_schema): {je}")
logger.debug(f"Non-JSON response received: {response_text[:500]}...")
            # The original OpenAI implementation raised ValueError here; VendorError is
            # used instead for consistency with the rest of this service.
raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" (follow_schema) Details: {je}")
except Exception as e:
error_message = exception_to_str(e)
logger.error(f"Error processing Gemini response (follow_schema): {error_message}")
raw_response_snippet = response_text[:500] if 'response_text' in locals() else "N/A"
logger.debug(f"Problematic Gemini response snippet (follow_schema): {raw_response_snippet}")
raise VendorError(f"Failed to process Gemini response (follow_schema): {error_message}")