mgbam committed on
Commit
3539a49
·
verified ·
1 Parent(s): 55e2347

Update core/gemini_handler.py

Browse files
Files changed (1) hide show
  1. core/gemini_handler.py +119 -36
core/gemini_handler.py CHANGED
@@ -2,64 +2,147 @@
2
  import google.generativeai as genai
3
  import json
4
  import re
 
 
 
5
 
6
  class GeminiHandler:
7
  def __init__(self, api_key):
8
  genai.configure(api_key=api_key)
9
- self.model = genai.GenerativeModel('gemini-1.5-flash-latest') # Or your preferred Gemini model
 
 
 
 
10
 
11
  def _clean_json_response(self, text_response):
12
- # Attempt to find JSON within backticks or directly
 
 
 
 
 
 
 
13
  match = re.search(r"```json\s*([\s\S]*?)\s*```", text_response, re.DOTALL)
14
  if match:
15
  json_str = match.group(1).strip()
16
- else:
17
- # Try to find the start of a list or object
18
- json_start_list = text_response.find('[')
19
- json_start_obj = text_response.find('{')
20
-
21
- if json_start_list != -1 and (json_start_obj == -1 or json_start_list < json_start_obj):
22
- json_str = text_response[json_start_list:]
23
- elif json_start_obj != -1:
24
- json_str = text_response[json_start_obj:]
25
- else:
26
- return text_response # Not clearly JSON
27
 
28
- # Remove trailing characters that might break parsing if JSON is incomplete
29
- # This is a bit aggressive, might need refinement
30
- # Find last '}' or ']'
31
- last_bracket = max(json_str.rfind('}'), json_str.rfind(']'))
32
- if last_bracket != -1:
33
- json_str = json_str[:last_bracket+1]
34
- return json_str.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  def _execute_gemini_call(self, prompt_text, expect_json=False):
 
 
37
  try:
38
- response = self.model.generate_content(prompt_text)
39
- text_content = response.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  if expect_json:
41
- cleaned_response = self._clean_json_response(text_content)
42
- # print(f"DEBUG: Cleaned JSON for prompt:\n{prompt_text[:200]}...\nResponse:\n{cleaned_response}") # Debug
43
- return json.loads(cleaned_response)
44
- return text_content.strip()
 
 
 
 
 
 
 
 
 
45
  except json.JSONDecodeError as e:
46
- print(f"Error decoding JSON from Gemini response: {e}")
47
- print(f"Problematic Gemini Raw Response:\n{text_content if 'text_content' in locals() else 'No response object'}")
48
- print(f"Cleaned attempt was:\n{cleaned_response if 'cleaned_response' in locals() else 'N/A'}")
49
- raise # Re-raise to be caught by caller
 
 
 
50
  except Exception as e:
51
- print(f"Error in Gemini call: {e}")
52
- print(f"Problematic Gemini Raw Response (if available):\n{response.text if 'response' in locals() else 'No response object'}")
 
 
53
  raise # Re-raise
54
 
55
  def generate_story_breakdown(self, prompt_text):
 
 
 
56
  return self._execute_gemini_call(prompt_text, expect_json=True)
57
 
58
- def generate_image_prompt(self, prompt_text): # This is for generating a new image prompt string
 
 
 
 
59
  return self._execute_gemini_call(prompt_text, expect_json=False)
60
 
61
- def regenerate_scene_script_details(self, prompt_text): # Expects JSON for a single scene
 
 
 
62
  return self._execute_gemini_call(prompt_text, expect_json=True)
63
 
64
- def regenerate_image_prompt_from_feedback(self, prompt_text): # Expects a string (the new prompt)
65
- return self._execute_gemini_call(prompt_text, expect_json=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import google.generativeai as genai
3
  import json
4
  import re
5
+ import logging # Added logging
6
+
7
+ logger = logging.getLogger(__name__) # Added logger
8
 
9
  class GeminiHandler:
10
  def __init__(self, api_key):
11
  genai.configure(api_key=api_key)
12
+ # For complex JSON and instruction following, 'gemini-1.0-pro' or 'gemini-1.5-pro-latest' might be more robust.
13
+ # 'gemini-1.5-flash-latest' is faster and cheaper but might sometimes struggle with very complex formats.
14
+ self.model_name = 'gemini-1.5-flash-latest' # or 'gemini-1.0-pro' or 'gemini-1.5-pro-latest'
15
+ self.model = genai.GenerativeModel(self.model_name)
16
+ logger.info(f"GeminiHandler initialized with model: {self.model_name}")
17
 
18
  def _clean_json_response(self, text_response):
19
+ """
20
+ Attempts to extract a valid JSON string from Gemini's text response.
21
+ Prioritizes content within ```json ... ``` blocks.
22
+ """
23
+ if not text_response:
24
+ return ""
25
+
26
+ # Attempt 1: Find JSON within markdown code blocks
27
  match = re.search(r"```json\s*([\s\S]*?)\s*```", text_response, re.DOTALL)
28
  if match:
29
  json_str = match.group(1).strip()
30
+ logger.debug("Found JSON in markdown code block.")
31
+ return json_str # Assume this is the intended JSON
 
 
 
 
 
 
 
 
 
32
 
33
+ # Attempt 2: If no markdown block, try to find the start of a JSON list or object directly
34
+ # This is more heuristic and might pick up non-JSON if the model doesn't use code blocks.
35
+ json_str = text_response.strip()
36
+ first_char = next((char for char in json_str if char in ['[', '{']), None)
37
+
38
+ if first_char == '[':
39
+ # Find the first '[' and try to match until the last ']'
40
+ start_index = json_str.find('[')
41
+ end_index = json_str.rfind(']')
42
+ if start_index != -1 and end_index != -1 and end_index > start_index:
43
+ json_str = json_str[start_index : end_index+1]
44
+ logger.debug("Extracted potential JSON list directly.")
45
+ return json_str.strip()
46
+ elif first_char == '{':
47
+ # Find the first '{' and try to match until the last '}'
48
+ start_index = json_str.find('{')
49
+ end_index = json_str.rfind('}')
50
+ if start_index != -1 and end_index != -1 and end_index > start_index:
51
+ json_str = json_str[start_index : end_index+1]
52
+ logger.debug("Extracted potential JSON object directly.")
53
+ return json_str.strip()
54
+
55
+ logger.warning("Could not clearly identify JSON structure in the response. Returning raw attempt.")
56
+ return text_response # Return original if no clear JSON structure found by heuristics
57
 
58
  def _execute_gemini_call(self, prompt_text, expect_json=False):
59
+ raw_text_content = "" # Initialize to ensure it's defined for logging
60
+ cleaned_json_attempt = "" # Initialize
61
  try:
62
+ logger.info(f"Executing Gemini call (expect_json={expect_json}). Prompt starts with: {prompt_text[:150]}...")
63
+ # Safety settings can be adjusted if needed, though defaults are usually fine.
64
+ # generation_config = genai.types.GenerationConfig(
65
+ # # temperature=0.7, # Example: Adjust creativity
66
+ # )
67
+ response = self.model.generate_content(
68
+ prompt_text,
69
+ # generation_config=generation_config
70
+ )
71
+
72
+ # Check for safety ratings or blocks first (if applicable to your SDK version and use case)
73
+ # if response.prompt_feedback and response.prompt_feedback.block_reason:
74
+ # logger.error(f"Gemini call blocked. Reason: {response.prompt_feedback.block_reason_message}")
75
+ # raise Exception(f"Gemini call blocked: {response.prompt_feedback.block_reason_message}")
76
+ # if not response.candidates:
77
+ # logger.error("Gemini call returned no candidates. Check prompt or safety settings.")
78
+ # raise Exception("Gemini call returned no candidates.")
79
+
80
+ raw_text_content = response.text # Assuming .text gives the full string
81
+ logger.debug(f"Gemini raw response text (first 300 chars): {raw_text_content[:300]}")
82
+
83
  if expect_json:
84
+ cleaned_json_attempt = self._clean_json_response(raw_text_content)
85
+ if not cleaned_json_attempt: # If cleaning returned empty
86
+ logger.error("JSON cleaning resulted in an empty string.")
87
+ raise json.JSONDecodeError("Cleaned JSON string is empty", "", 0)
88
+
89
+ logger.debug(f"Attempting to parse cleaned JSON (first 300 chars): {cleaned_json_attempt[:300]}")
90
+ parsed_json = json.loads(cleaned_json_attempt)
91
+ logger.info("Gemini call successful, JSON parsed.")
92
+ return parsed_json
93
+ else:
94
+ logger.info("Gemini call successful, returning text.")
95
+ return raw_text_content.strip()
96
+
97
  except json.JSONDecodeError as e:
98
+ logger.error(f"JSONDecodeError: {e}. Failed to parse JSON from Gemini response.")
99
+ logger.error(f"--- Problematic Gemini Raw Response ---\n{raw_text_content}\n--- End Raw Response ---")
100
+ logger.error(f"--- Cleaned JSON Attempt ---\n{cleaned_json_attempt}\n--- End Cleaned Attempt ---")
101
+ raise # Re-raise for the caller to handle (e.g., show error in UI)
102
+ except AttributeError as ae: # Handles cases where `response.text` might not exist if call failed early
103
+ logger.error(f"AttributeError during Gemini call processing: {ae}. Likely an issue with the response object structure.", exc_info=True)
104
+ raise Exception(f"Gemini API response structure error: {ae}")
105
  except Exception as e:
106
+ # This catches other errors from genai.GenerativeModel.generate_content()
107
+ # e.g., google.api_core.exceptions.PermissionDenied, google.api_core.exceptions.ResourceExhausted
108
+ logger.error(f"General error during Gemini API call: {type(e).__name__} - {e}", exc_info=True)
109
+ logger.error(f"--- Problematic Gemini Raw Response (if available) ---\n{raw_text_content}\n--- End Raw Response ---")
110
  raise # Re-raise
111
 
112
  def generate_story_breakdown(self, prompt_text):
113
+ """
114
+ Generates the full cinematic treatment (list of scene JSON objects).
115
+ """
116
  return self._execute_gemini_call(prompt_text, expect_json=True)
117
 
118
+ def generate_image_prompt(self, prompt_text):
119
+ """
120
+ Generates or refines a DALL-E prompt string (expects text, not JSON).
121
+ Used by `create_visual_regeneration_prompt` from prompt_engineering.
122
+ """
123
  return self._execute_gemini_call(prompt_text, expect_json=False)
124
 
125
+ def regenerate_scene_script_details(self, prompt_text):
126
+ """
127
+ Regenerates the JSON object for a single scene based on feedback.
128
+ """
129
  return self._execute_gemini_call(prompt_text, expect_json=True)
130
 
131
+ # Renamed for clarity, as it refines a DALL-E prompt string based on feedback.
132
+ def refine_image_prompt_from_feedback(self, prompt_text):
133
+ """
134
+ Refines an existing DALL-E prompt string based on user feedback.
135
+ Expects Gemini to return a new string (the refined prompt).
136
+ This method is called by app.py, which uses create_visual_regeneration_prompt.
137
+ """
138
+ return self._execute_gemini_call(prompt_text, expect_json=False)
139
+
140
+ # You might add a new method here if you want Gemini to help construct
141
+ # the text-to-video prompt specifically, though your current `construct_text_to_video_prompt`
142
+ # in prompt_engineering.py does this without a Gemini call.
143
+ # If you did want Gemini to craft it:
144
+ # def generate_text_to_video_prompt_string(self, prompt_text_for_gemini):
145
+ # """
146
+ # Asks Gemini to craft a detailed text-to-video prompt string.
147
+ # """
148
+ # return self._execute_gemini_call(prompt_text_for_gemini, expect_json=False)