import json
import os
import io
import base64
from typing import Any, Dict, List, Type, Union, Optional

import google.generativeai as genai
from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold # For safety settings
import weave # Assuming weave is still used
from pydantic import BaseModel, ValidationError # For schema validation

# Assuming these utilities are in the same relative paths or accessible
from app.utils.converter import product_data_to_str
from app.utils.image_processing import (
    get_data_format, # Assuming this returns 'jpeg', 'png' etc.
    get_image_base64_and_type, # Assuming this fetches URL and returns (base64_str, type_str)
    get_image_data, # Assuming this reads local path and returns base64_str
)
from app.utils.logger import exception_to_str, setup_logger

# Assuming these are correctly defined and accessible
from ..config import get_settings 
from ..core import errors
from ..core.errors import BadRequestError, VendorError # Using your custom errors
from ..core.prompts import get_prompts # Assuming prompts are compatible or adapted
from .base import BaseAttributionService # Assuming this base class exists

# Environment and Weave setup (kept as-is from the previous implementation)
ENV = os.getenv("ENV", "LOCAL")
WEAVE_PROJECT_NAMES = {
    "LOCAL": "cfai/attribution-exp",
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
}
# .get() returns None for PROD and any unrecognized environment, avoiding a NameError below.
weave_project_name = WEAVE_PROJECT_NAMES.get(ENV)

if ENV != "PROD" and weave_project_name:
    # weave.init(project_name=weave_project_name)  # Assuming weave.init() is called elsewhere if needed
    print(f"Weave project name (potentially initialized elsewhere): {weave_project_name}")

settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)

# Configure the Gemini client
api_key = getattr(settings, "GEMINI_API_KEY", None)
if api_key:
    genai.configure(api_key=api_key)
else:
    # Covers both a missing attribute and an empty value; raise here instead if the
    # application should fail fast without a key.
    logger.error("GEMINI_API_KEY is not set; Gemini API calls will fail until it is configured.")

# Define default safety settings for Gemini
# Adjust these as per your application's requirements
DEFAULT_SAFETY_SETTINGS = {
    HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
}

class GeminiService(BaseAttributionService):
    def __init__(self, model_name: str = "gemini-2.5-flash-preview-04-17"):
        """
        Initializes the GeminiService.
        Args:
            model_name (str): The name of the Gemini model to use.
        """
        try:
            self.model = genai.GenerativeModel(
                model_name,
                safety_settings=DEFAULT_SAFETY_SETTINGS
                # system_instruction can be set here if a global system message is always used
            )
            logger.info(f"GeminiService initialized with model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize Gemini GenerativeModel: {exception_to_str(e)}")
            # Depending on requirements, you might want to raise an error here
            # For now, we'll let it proceed, and calls will fail if model isn't initialized.
            self.model = None


    def _prepare_image_parts(
        self,
        img_urls: Optional[List[str]] = None,
        img_paths: Optional[List[str]] = None,
        pil_images: Optional[List[Any]] = None, # PIL.Image.Image objects
    ) -> List[Dict[str, Any]]:
        """
        Prepares image data in the format expected by Gemini API.
        Decodes base64 image data to bytes.
        Converts PIL images to bytes.
        """
        image_parts = []

        # Process image URLs
        if img_urls:
            for img_url in img_urls:
                try:
                    base64_data, img_type = get_image_base64_and_type(img_url)
                    if base64_data and img_type:
                        # Gemini expects raw bytes, so decode base64
                        image_bytes = base64.b64decode(base64_data)
                        mime_type = f"image/{img_type.lower()}"
                        image_parts.append({"mime_type": mime_type, "data": image_bytes})
                    else:
                        logger.warning(f"Could not retrieve or identify type for image URL: {img_url}")
                except Exception as e:
                    logger.error(f"Error processing image URL {img_url}: {exception_to_str(e)}")

        # Process image paths
        if img_paths:
            for img_path in img_paths:
                try:
                    base64_data = get_image_data(img_path) # Assuming this returns base64 string
                    img_type = get_data_format(img_path) # Assuming this returns 'png', 'jpeg'
                    if base64_data and img_type:
                        image_bytes = base64.b64decode(base64_data)
                        mime_type = f"image/{img_type.lower()}"
                        image_parts.append({"mime_type": mime_type, "data": image_bytes})
                    else:
                        logger.warning(f"Could not retrieve or identify type for image path: {img_path}")
                except Exception as e:
                    logger.error(f"Error processing image path {img_path}: {exception_to_str(e)}")
        
        # Process PIL images
        if pil_images:
            for i, pil_image in enumerate(pil_images):
                try:
                    img_format = pil_image.format or 'PNG' # Default to PNG if format is not available
                    mime_type = f"image/{img_format.lower()}"
                    with io.BytesIO() as img_byte_arr:
                        pil_image.save(img_byte_arr, format=img_format)
                        image_bytes = img_byte_arr.getvalue()
                        image_parts.append({"mime_type": mime_type, "data": image_bytes})
                except Exception as e:
                    logger.error(f"Error processing PIL image #{i}: {exception_to_str(e)}")
        
        return image_parts

    @weave.op() # Assuming weave.op can be used as a decorator directly
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str, # This will be the Gemini model name, e.g., "gemini-1.5-flash-latest"
        img_urls: Optional[List[str]] = None,
        product_taxonomy: str = "",
        product_data: Optional[Dict[str, Union[str, List[str]]]] = None,
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        if not self.model:
            raise VendorError("Gemini model not initialized.")
        # GenerativeModel.model_name is fully qualified (e.g. "models/gemini-1.5-flash-latest"),
        # so compare bare model ids to avoid creating a new model on every call.
        if self.model.model_name.split("/")[-1] != ai_model.split("/")[-1]:
            logger.info(f"Switching to model {ai_model} for this extraction request.")
            # Note: this creates a new model object for the call. If this happens
            # frequently, consider caching model instances instead.
            current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
        else:
            current_model = self.model

        # Construct the prompt text
        # Combining system and human prompts as Gemini typically takes a list of contents.
        # System instructions can also be part of the model's initialization.
        system_message = prompts.EXTRACT_INFO_SYSTEM_MESSAGE
        human_message = prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
            product_taxonomy=product_taxonomy,
            product_data=product_data_to_str(product_data if product_data else {}),
        )
        full_prompt_text = f"{system_message}\n\n{human_message}"
        
        # For logging or debugging the prompt
        logger.info(f"Gemini Prompt Text: {full_prompt_text[:500]}...") # Log a snippet

        content_parts = [full_prompt_text]
        
        # Prepare image parts
        try:
            image_parts = self._prepare_image_parts(img_urls, img_paths, pil_images)
            content_parts.extend(image_parts)
        except Exception as e:
            logger.error(f"Failed during image preparation: {exception_to_str(e)}")
            raise BadRequestError(f"Image processing failed: {e}")

        if not image_parts and (img_urls or img_paths or pil_images):
            logger.warning("Image sources provided, but no image parts were successfully prepared.")
        
        # Define generation config for JSON output
        # Pydantic's model_json_schema() generates an OpenAPI compliant schema dictionary.
        try:
            schema_for_gemini = attributes_model.model_json_schema()
        except Exception as e:
            logger.error(f"Error generating JSON schema from Pydantic model: {exception_to_str(e)}")
            raise VendorError(f"Could not generate schema for attributes_model: {e}")

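        # Note: Gemini's response_schema supports only a subset of OpenAPI. Deeply nested
        # Pydantic models emit "$defs"/"$ref" entries that the API may reject; if that
        # happens, inline the references before passing the schema (an assumption about
        # current google-generativeai behavior; verify against the SDK version in use).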
        generation_config = GenerationConfig(
            response_mime_type="application/json",
            response_schema=schema_for_gemini, # Gemini expects the schema here
            temperature=0.0, # Greedy decoding for (near-)deterministic output
            max_output_tokens=2048, # Adjust as needed, was 1000 for OpenAI
            # top_p, top_k can also be set if needed
        )
        
        logger.info(f"Extracting attributes via Gemini model: {current_model.model_name}...")
        try:
            response = await current_model.generate_content_async(
                contents=content_parts,
                generation_config=generation_config,
                # request_options={"timeout": 120} # Example: set timeout in seconds
            )
        except Exception as e: # Catches google.api_core.exceptions and others
            error_message = exception_to_str(e)
            logger.error(f"Gemini API call failed: {error_message}")
            # More specific error handling for Gemini can be added here
            # e.g., if isinstance(e, google.api_core.exceptions.InvalidArgument):
            # raise BadRequestError(f"Invalid argument to Gemini: {error_message}")
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))

        # Process the response
        try:
            # Check for safety blocks or refusals
            if not response.candidates:
                # This can happen if all candidates were filtered due to safety or other reasons.
                block_reason_detail = "Unknown reason (no candidates)"
                if response.prompt_feedback and response.prompt_feedback.block_reason:
                    block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
                    if response.prompt_feedback.block_reason_message:
                        block_reason_detail += f" - {response.prompt_feedback.block_reason_message}"
                logger.error(f"Gemini response was blocked or empty. {block_reason_detail}")
                raise VendorError(f"Gemini response blocked or empty. {block_reason_detail}")

            # Assuming the first candidate is the one we want
            candidate = response.candidates[0]
            
            # FinishReason proto enum: 1 = STOP, 2 = MAX_TOKENS.
            if candidate.finish_reason not in (1, 2):
                finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
                logger.warning(f"Gemini generation finished with reason: {finish_reason_str}")
                # Raise when generation was cut short for safety; RECITATION and others
                # fall through with only the warning above.
                if finish_reason_str == "SAFETY":
                    safety_ratings_str = ", ".join(
                        f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings
                    )
                    raise VendorError(f"Gemini content generation stopped due to safety concerns. Ratings: [{safety_ratings_str}]")

            if not candidate.content.parts or not candidate.content.parts[0].text:
                logger.error("Gemini response content is empty or not in the expected text format.")
                raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + " (empty response text)")

            response_text = candidate.content.parts[0].text
            
            # Parse and validate the JSON response using the Pydantic model
            parsed_data = attributes_model.model_validate_json(response_text)
            return parsed_data.model_dump() # Return as dict

        except ValidationError as ve:
            logger.error(f"Pydantic validation failed for Gemini response: {ve}")
            logger.debug(f"Invalid JSON received from Gemini: {response_text[:500]}...") # Log snippet of invalid JSON
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {ve}")
        except json.JSONDecodeError as je:
            # Defensive: Pydantic v2's model_validate_json raises ValidationError even for
            # malformed JSON, so this branch is rarely reached.
            logger.error(f"JSON decoding failed for Gemini response: {je}")
            logger.debug(f"Non-JSON response received: {response_text[:500]}...")
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {je}")
        except VendorError: # Re-raise VendorErrors
            raise
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Error processing Gemini response: {error_message}")
            # Log the raw response text if available and an error occurred
            raw_response_snippet = response_text[:500] if 'response_text' in locals() else "N/A"
            logger.debug(f"Problematic Gemini response snippet: {raw_response_snippet}")
            raise VendorError(f"Failed to process Gemini response: {error_message}")

    @weave.op()
    async def follow_schema(
        self, 
        schema: Dict[str, Any], # This should be an OpenAPI schema dictionary
        data: Dict[str, Any],
        ai_model: str = "gemini-1.5-flash-latest" # Model for this specific task
    ) -> Dict[str, Any]:
        if not self.model:  # Check if the main model was initialized
            logger.warning("Main Gemini model not initialized. Attempting to initialize a temporary one for follow_schema.")
            try:
                current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
            except Exception as e:
                raise VendorError(f"Failed to initialize Gemini model for follow_schema: {exception_to_str(e)}")
        elif self.model.model_name.split("/")[-1] != ai_model.split("/")[-1]:
            # GenerativeModel.model_name is fully qualified ("models/..."); compare bare ids.
            logger.info(f"Switching to model {ai_model} for this follow_schema request.")
            current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
        else:
            current_model = self.model

        logger.info(f"Following schema via Gemini model: {current_model.model_name}...")
        
        # Prepare the prompt
        # System message can be part of the model or prepended here.
        system_message = prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE
        # The human message needs to contain the data to be transformed.
        # Ensure `json_info` placeholder is correctly used by your prompt string.
        try:
            data_as_json_string = json.dumps(data, indent=2)
        except TypeError as te:
            logger.error(f"Could not serialize 'data' to JSON for prompt: {te}")
            raise BadRequestError(f"Input data for schema following is not JSON serializable: {te}")

        human_message = prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data_as_json_string)
        full_prompt_text = f"{system_message}\n\n{human_message}"
        
        content_parts = [full_prompt_text]

        # Define generation config for JSON output using the provided schema
        generation_config = GenerationConfig(
            response_mime_type="application/json",
            response_schema=schema, # The provided schema dictionary
            temperature=0.0, # For deterministic output
            max_output_tokens=2048, # Adjust as needed
        )

        try:
            response = await current_model.generate_content_async(
                contents=content_parts,
                generation_config=generation_config,
            )
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Gemini API call failed for follow_schema: {error_message}")
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))

        # Process response
        try:
            if not response.candidates:
                block_reason_detail = "Unknown reason (no candidates)"
                if response.prompt_feedback and response.prompt_feedback.block_reason:
                    block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
                logger.error(f"Gemini response was blocked or empty in follow_schema. {block_reason_detail}")
                # OpenAI version returned {"status": "refused"}, mimicking similar for block
                return {"status": "refused", "reason": block_reason_detail}

            candidate = response.candidates[0]

            # FinishReason proto enum: 1 = STOP, 2 = MAX_TOKENS.
            if candidate.finish_reason not in (1, 2):
                finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
                logger.warning(f"Gemini generation (follow_schema) finished with reason: {finish_reason_str}")
                if finish_reason_str == "SAFETY":
                    safety_ratings_str = ", ".join(
                        f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings
                    )
                    return {"status": "refused", "reason": f"Safety block. Ratings: [{safety_ratings_str}]"}

            if not candidate.content.parts or not candidate.content.parts[0].text:
                logger.error("Gemini response content (follow_schema) is empty.")
                # Mimic OpenAI's refusal structure rather than raising.
                return {"status": "refused", "reason": "Empty content from Gemini"}

            response_text = candidate.content.parts[0].text
            parsed_data = json.loads(response_text) # The schema is enforced by Gemini
            return parsed_data

        except json.JSONDecodeError as je:
            logger.error(f"JSON decoding failed for Gemini response (follow_schema): {je}")
            logger.debug(f"Non-JSON response received: {response_text[:500]}...")
            # Raise VendorError rather than the original ValueError, for consistency
            # with extract_attributes.
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" (follow_schema) Details: {je}")
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Error processing Gemini response (follow_schema): {error_message}")
            raw_response_snippet = response_text[:500] if 'response_text' in locals() else "N/A"
            logger.debug(f"Problematic Gemini response snippet (follow_schema): {raw_response_snippet}")
            raise VendorError(f"Failed to process Gemini response (follow_schema): {error_message}")