Update core/gemini_handler.py
Browse files- core/gemini_handler.py +119 -36
core/gemini_handler.py
CHANGED
@@ -2,64 +2,147 @@
|
|
2 |
import google.generativeai as genai
|
3 |
import json
|
4 |
import re
|
|
|
|
|
|
|
5 |
|
6 |
class GeminiHandler:
|
7 |
def __init__(self, api_key):
|
8 |
genai.configure(api_key=api_key)
|
9 |
-
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def _clean_json_response(self, text_response):
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
match = re.search(r"```json\s*([\s\S]*?)\s*```", text_response, re.DOTALL)
|
14 |
if match:
|
15 |
json_str = match.group(1).strip()
|
16 |
-
|
17 |
-
#
|
18 |
-
json_start_list = text_response.find('[')
|
19 |
-
json_start_obj = text_response.find('{')
|
20 |
-
|
21 |
-
if json_start_list != -1 and (json_start_obj == -1 or json_start_list < json_start_obj):
|
22 |
-
json_str = text_response[json_start_list:]
|
23 |
-
elif json_start_obj != -1:
|
24 |
-
json_str = text_response[json_start_obj:]
|
25 |
-
else:
|
26 |
-
return text_response # Not clearly JSON
|
27 |
|
28 |
-
#
|
29 |
-
# This is
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
def _execute_gemini_call(self, prompt_text, expect_json=False):
|
|
|
|
|
37 |
try:
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
if expect_json:
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
except json.JSONDecodeError as e:
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
raise # Re-raise to
|
|
|
|
|
|
|
50 |
except Exception as e:
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
raise # Re-raise
|
54 |
|
55 |
def generate_story_breakdown(self, prompt_text):
|
|
|
|
|
|
|
56 |
return self._execute_gemini_call(prompt_text, expect_json=True)
|
57 |
|
58 |
-
def generate_image_prompt(self, prompt_text):
|
|
|
|
|
|
|
|
|
59 |
return self._execute_gemini_call(prompt_text, expect_json=False)
|
60 |
|
61 |
-
def regenerate_scene_script_details(self, prompt_text):
|
|
|
|
|
|
|
62 |
return self._execute_gemini_call(prompt_text, expect_json=True)
|
63 |
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import google.generativeai as genai
|
3 |
import json
|
4 |
import re
|
5 |
+
import logging # Added logging
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__) # Added logger
|
8 |
|
9 |
class GeminiHandler:
    """Thin wrapper around the google.generativeai SDK.

    Sends prompts to a Gemini model and returns either raw text or parsed
    JSON, with heuristics to extract JSON from loosely formatted responses.
    """

    def __init__(self, api_key):
        """Configure the SDK and instantiate the generative model.

        Args:
            api_key: Google AI Studio / Gemini API key.
        """
        genai.configure(api_key=api_key)
        # 'gemini-1.5-flash-latest' is faster and cheaper; for complex JSON
        # and strict instruction following, 'gemini-1.0-pro' or
        # 'gemini-1.5-pro-latest' may be more robust.
        self.model_name = 'gemini-1.5-flash-latest'
        self.model = genai.GenerativeModel(self.model_name)
        logger.info("GeminiHandler initialized with model: %s", self.model_name)

    def _clean_json_response(self, text_response):
        """
        Attempts to extract a valid JSON string from Gemini's text response.

        Prioritizes content within ```json ... ``` markdown fences, then falls
        back to slicing from the first '['/'{' to the last ']'/'}'. Returns
        the original text unchanged when no JSON structure is identifiable,
        and "" for an empty/None input.
        """
        if not text_response:
            return ""

        # Attempt 1: Find JSON within markdown code blocks.
        match = re.search(r"```json\s*([\s\S]*?)\s*```", text_response, re.DOTALL)
        if match:
            logger.debug("Found JSON in markdown code block.")
            return match.group(1).strip()  # Assume this is the intended JSON

        # Attempt 2: No markdown fence — look for the start of a JSON list or
        # object directly. This is heuristic and might pick up non-JSON if the
        # model doesn't use code blocks.
        json_str = text_response.strip()
        first_char = next((char for char in json_str if char in '[{'), None)

        if first_char == '[':
            # Slice from the first '[' to the last ']'.
            start_index = json_str.find('[')
            end_index = json_str.rfind(']')
            if start_index != -1 and end_index > start_index:
                logger.debug("Extracted potential JSON list directly.")
                return json_str[start_index:end_index + 1].strip()
        elif first_char == '{':
            # Slice from the first '{' to the last '}'.
            start_index = json_str.find('{')
            end_index = json_str.rfind('}')
            if start_index != -1 and end_index > start_index:
                logger.debug("Extracted potential JSON object directly.")
                return json_str[start_index:end_index + 1].strip()

        logger.warning("Could not clearly identify JSON structure in the response. Returning raw attempt.")
        return text_response  # Return original if no clear JSON structure found by heuristics

    def _execute_gemini_call(self, prompt_text, expect_json=False):
        """Send *prompt_text* to the model and return the result.

        Args:
            prompt_text: Full prompt string to send.
            expect_json: When True, clean and ``json.loads`` the response and
                return the parsed object; otherwise return stripped text.

        Raises:
            json.JSONDecodeError: The response could not be parsed as JSON
                (only when ``expect_json=True``).
            Exception: API/transport errors from the SDK are re-raised; a
                malformed response object is wrapped and chained.
        """
        # Initialized up-front so the except blocks can always log them.
        raw_text_content = ""
        cleaned_json_attempt = ""
        try:
            logger.info("Executing Gemini call (expect_json=%s). Prompt starts with: %.150s...",
                        expect_json, prompt_text)
            # Safety settings / generation_config (e.g. temperature) can be
            # passed to generate_content() here if the defaults prove
            # insufficient; defaults are usually fine.
            response = self.model.generate_content(prompt_text)

            # NOTE(review): assumes response.text aggregates the candidate
            # text; blocked prompts surface as AttributeError below.
            raw_text_content = response.text
            logger.debug("Gemini raw response text (first 300 chars): %.300s", raw_text_content)

            if expect_json:
                cleaned_json_attempt = self._clean_json_response(raw_text_content)
                if not cleaned_json_attempt:  # If cleaning returned empty
                    logger.error("JSON cleaning resulted in an empty string.")
                    raise json.JSONDecodeError("Cleaned JSON string is empty", "", 0)

                logger.debug("Attempting to parse cleaned JSON (first 300 chars): %.300s",
                             cleaned_json_attempt)
                parsed_json = json.loads(cleaned_json_attempt)
                logger.info("Gemini call successful, JSON parsed.")
                return parsed_json
            else:
                logger.info("Gemini call successful, returning text.")
                return raw_text_content.strip()

        except json.JSONDecodeError as e:
            logger.error("JSONDecodeError: %s. Failed to parse JSON from Gemini response.", e)
            logger.error("--- Problematic Gemini Raw Response ---\n%s\n--- End Raw Response ---",
                         raw_text_content)
            logger.error("--- Cleaned JSON Attempt ---\n%s\n--- End Cleaned Attempt ---",
                         cleaned_json_attempt)
            raise  # Re-raise for the caller to handle (e.g., show error in UI)
        except AttributeError as ae:
            # Handles cases where `response.text` might not exist if the call
            # failed early (e.g. blocked by safety filters).
            logger.error("AttributeError during Gemini call processing: %s. "
                         "Likely an issue with the response object structure.", ae, exc_info=True)
            # Chain the original exception so the causal traceback is kept
            # (previously raised without `from ae`, losing it).
            raise Exception(f"Gemini API response structure error: {ae}") from ae
        except Exception as e:
            # Other errors from generate_content(), e.g.
            # google.api_core.exceptions.PermissionDenied / ResourceExhausted.
            logger.error("General error during Gemini API call: %s - %s",
                         type(e).__name__, e, exc_info=True)
            logger.error("--- Problematic Gemini Raw Response (if available) ---\n%s\n--- End Raw Response ---",
                         raw_text_content)
            raise  # Re-raise

    def generate_story_breakdown(self, prompt_text):
        """
        Generates the full cinematic treatment (list of scene JSON objects).
        """
        return self._execute_gemini_call(prompt_text, expect_json=True)

    def generate_image_prompt(self, prompt_text):
        """
        Generates or refines a DALL-E prompt string (expects text, not JSON).
        Used by `create_visual_regeneration_prompt` from prompt_engineering.
        """
        return self._execute_gemini_call(prompt_text, expect_json=False)

    def regenerate_scene_script_details(self, prompt_text):
        """
        Regenerates the JSON object for a single scene based on feedback.
        """
        return self._execute_gemini_call(prompt_text, expect_json=True)

    def refine_image_prompt_from_feedback(self, prompt_text):
        """
        Refines an existing DALL-E prompt string based on user feedback.
        Expects Gemini to return a new string (the refined prompt).
        This method is called by app.py, which uses create_visual_regeneration_prompt.
        """
        return self._execute_gemini_call(prompt_text, expect_json=False)

    # NOTE: construct_text_to_video_prompt in prompt_engineering.py builds the
    # text-to-video prompt without a Gemini call; add a
    # generate_text_to_video_prompt_string(...) wrapper here only if Gemini
    # should craft that prompt instead.