samu's picture
1st
7c7ef49
from google import genai
from google.genai import types
from typing import Union, List, Generator, Dict, Optional
from PIL import Image
from io import BytesIO
import base64
import requests
import asyncio
import os
from dotenv import load_dotenv
from .category_instructions import get_instruction_for_category
from .category_config import CATEGORY_CONFIGS
# Populate environment variables from a .env file (python-dotenv) before the
# API key is looked up below.
load_dotenv()

# Module-wide Gemini client; both generator functions in this module issue
# their requests through it.
client = genai.Client(api_key=os.environ.get("API_KEY"))
def bytes_to_base64(data: bytes, with_prefix: bool = True, *, mime_type: str = "image/png") -> str:
    """Encode raw bytes as base64 text, optionally wrapped in a ``data:`` URI.

    Args:
        data: Raw bytes to encode.
        with_prefix: When true (default), return ``data:<mime_type>;base64,<payload>``;
            when false, return only the base64 payload.
        mime_type: Media type written into the URI prefix. Defaults to
            ``"image/png"`` (the value previously hard-coded), so existing
            callers are unaffected; pass the real type for JPEG/WebP/etc.

    Returns:
        The base64 text, with or without the data-URI prefix.
    """
    encoded = base64.b64encode(data).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}" if with_prefix else encoded
def decode_base64_image(base64_str: str) -> Image.Image:
    """Decode base64 image text into a PIL image.

    A leading ``data:image...`` URI header is discarded first: only the text
    between the first and (if any) second comma is kept as the payload.
    Any other string is treated as a bare base64 payload.

    Args:
        base64_str: Bare base64 text or a ``data:image/...;base64,...`` URI.

    Returns:
        The result of ``Image.open`` on the decoded bytes.
    """
    payload = base64_str.split(",")[1] if base64_str.startswith("data:image") else base64_str
    buffer = BytesIO(base64.b64decode(payload))
    return Image.open(buffer)
# Numbered prefixes recognised as section headers in model text ("1." .. "5.").
_SECTION_PREFIXES = tuple(f"{n}." for n in range(1, 6))


def _build_enhanced_prompt(instruction, prompt, config: Dict) -> str:
    """Combine the category instruction, optional category guidance and the user prompt.

    Guidance is emitted only for non-empty ``style_guide`` / ``conventions`` /
    ``common_elements`` entries, so an incomplete config no longer produces
    dangling "Style Guide: " or "- " lines. With every entry present the text
    is identical to the previous hand-built prompt.
    """
    if not config:
        return f"{instruction}\n\nUser Request: {prompt}"

    guidance = []
    style_guide = config.get("style_guide", "")
    if style_guide:
        guidance.append(f"Style Guide: {style_guide}")
    conventions = config.get("conventions", [])
    if conventions:
        guidance.append(
            "Drawing Conventions to Follow:\n" + "\n".join(f"- {c}" for c in conventions)
        )
    common_elements = config.get("common_elements", [])
    if common_elements:
        guidance.append(
            "Consider Including These Elements:\n" + "\n".join(f"- {e}" for e in common_elements)
        )

    blocks = [f"{instruction}"]
    if guidance:
        blocks.append("\n".join(guidance))
    blocks.append(f"User Request: {prompt}")
    return "\n\n".join(blocks)


def _split_section_header(line: str) -> Optional[tuple]:
    """Return ``(section_key, inline_text)`` if *line* is a numbered header, else None.

    A header starts with "1." .. "5.", e.g. ``"2. Key Components: a hat"`` ->
    ``("key_components", "a hat")``. The key is lower-cased, stripped of
    markdown bold markers and has spaces replaced with underscores.
    """
    if not line.startswith(_SECTION_PREFIXES):
        return None
    _, _, rest = line.partition(".")
    if rest[:1].isdigit():
        # "1.5 cm wide" is a decimal number, not a section header.
        return None
    title, _, inline = rest.partition(":")
    key = title.strip().strip("*").strip().lower().replace(" ", "_")
    if not key:
        # "1." or "1. :" carry no usable name; treat the line as body text.
        return None
    inline = inline.strip()
    if inline.startswith("**"):
        # Drop the closing bold marker left over from "1. **Title:** text".
        inline = inline[2:].lstrip()
    return key, inline


def _parse_text_sections(text: str) -> Dict[str, str]:
    """Split model text into ``{section_key: body}`` using numbered headers.

    Text before the first header goes to the "overview" section. Text that
    follows a header's colon on the same line is kept as the section's first
    line, and a repeated header appends to (rather than wipes) its section.
    Blank lines are skipped and each kept line is stripped.
    """
    sections: Dict[str, List[str]] = {}
    current = "overview"
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        header = _split_section_header(line)
        if header is None:
            sections.setdefault(current, []).append(line)
            continue
        current, inline = header
        body = sections.setdefault(current, [])
        if inline:
            body.append(inline)
    return {name: "\n".join(body).strip() for name, body in sections.items()}


def _part_to_event(part) -> Optional[Dict]:
    """Convert one response part into an event dict, or None if it has no text/image."""
    text = getattr(part, "text", None)
    if text is not None:
        try:
            sections = _parse_text_sections(text)
        except Exception:
            # Best-effort boundary: never lose the model's text to a parser bug.
            return {'type': 'text', 'data': {'raw_text': text}}
        return {'type': 'text', 'data': sections}

    inline_data = getattr(part, "inline_data", None)
    if inline_data is not None:
        # Label the URI with the type the API reported; fall back to the
        # previous hard-coded PNG label when none is given.
        mime_type = getattr(inline_data, "mime_type", None) or "image/png"
        encoded = bytes_to_base64(inline_data.data, with_prefix=False)
        return {'type': 'image', 'data': f"data:{mime_type};base64,{encoded}"}
    return None


async def _generate_events(contents):
    """Call the model with TEXT+IMAGE output and yield one event per usable part.

    Raises:
        RuntimeError: if the response carries no candidate content parts,
            instead of failing later with an opaque TypeError/IndexError.
    """
    response = await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=contents,
        config=types.GenerateContentConfig(
            response_modalities=['TEXT', 'IMAGE']
        ),
    )
    candidates = response.candidates or []
    content = candidates[0].content if candidates else None
    parts = (content.parts if content is not None else None) or []
    if not parts:
        finish_reason = getattr(candidates[0], "finish_reason", None) if candidates else None
        raise RuntimeError(
            f"Model response contained no content parts (finish_reason={finish_reason})"
        )
    for part in parts:
        event = _part_to_event(part)
        if event is not None:
            yield event


async def async_generate_text_and_image(prompt, category: Optional[str] = None):
    """Generate text and images from a text prompt, yielding event dicts.

    The prompt is prefixed with the category instruction and, when the
    lower-cased category has an entry in ``CATEGORY_CONFIGS``, its style guide,
    conventions and suggested elements.

    Args:
        prompt: The user's request text.
        category: Optional category name used to pick the instruction/config.

    Yields:
        ``{'type': 'text', 'data': {section_key: text}}`` (or
        ``{'type': 'text', 'data': {'raw_text': text}}`` if parsing fails) and
        ``{'type': 'image', 'data': <base64 data URI>}``, in response order.

    Raises:
        RuntimeError: if the model returns no content parts.
    """
    instruction = get_instruction_for_category(category)
    config = CATEGORY_CONFIGS.get(category.lower() if category else "", {})
    enhanced_prompt = _build_enhanced_prompt(instruction, prompt, config)
    async for event in _generate_events(enhanced_prompt):
        yield event


async def async_generate_with_image_input(text: Optional[str], image_path: str, category: Optional[str] = None):
    """Generate text and images from an input image plus optional text.

    Args:
        text: Optional user request; when empty, only the category instruction
            is sent alongside the image.
        image_path: The input image as a base64 data URI (``data:image/...``),
            despite the parameter name.
        category: Optional category name used to pick the instruction.

    Yields:
        The same event dicts as ``async_generate_text_and_image``.

    Raises:
        ValueError: if ``image_path`` is not a ``data:image/`` URI string.
        RuntimeError: if the model returns no content parts.
    """
    if not isinstance(image_path, str) or not image_path.startswith("data:image/"):
        raise ValueError("Invalid image input: expected a base64 Data URI starting with 'data:image/'")
    image = decode_base64_image(image_path)
    instruction = get_instruction_for_category(category)
    prompt_text = f"{instruction}\n\nUser Request: {text}" if text else instruction
    async for event in _generate_events([prompt_text, image]):
        yield event