# Hosting page banner captured with the file (not part of the module): Spaces - status "Sleeping".
from google import genai | |
from google.genai import types | |
from typing import Union, List, Generator, Dict, Optional | |
from PIL import Image | |
from io import BytesIO | |
import base64 | |
import requests | |
import asyncio | |
import os | |
from dotenv import load_dotenv | |
from .category_instructions import get_instruction_for_category | |
from .category_config import CATEGORY_CONFIGS | |
# Load variables from a local .env file into the process environment before
# anything below reads them.
load_dotenv()

# Module-wide Google GenAI client; the key is read from the API_KEY environment
# variable (None is passed through if it is unset).
client = genai.Client(api_key=os.environ.get("API_KEY"))
def bytes_to_base64(
    data: bytes,
    with_prefix: bool = True,
    *,
    mime_type: str = "image/png",
) -> str:
    """Encode raw bytes as base64 text, optionally wrapped as a data URI.

    Args:
        data: The raw bytes to encode.
        with_prefix: When True, return ``data:<mime_type>;base64,<payload>``;
            when False, return only the base64 payload.
        mime_type: MIME type written into the data-URI prefix. Defaults to
            ``image/png``, which is what this function always emitted before
            the parameter existed.

    Returns:
        The base64 payload, with or without the data-URI prefix.
    """
    # Base64 output is pure ASCII, so the decode can never fail.
    encoded = base64.b64encode(data).decode("ascii")
    if not with_prefix:
        return encoded
    return f"data:{mime_type};base64,{encoded}"
def decode_base64_image(base64_str: str) -> "Image.Image":
    """Decode base64 text, or a ``data:image...`` URI, into a PIL image.

    Args:
        base64_str: Either a bare base64 payload, or a data URI whose payload
            follows the first comma (e.g. ``data:image/png;base64,<payload>``).

    Returns:
        The image returned by ``Image.open`` on the decoded bytes.

    Raises:
        ValueError: If the input looks like a data URI but has no comma or an
            empty payload. ``base64.b64decode`` reports malformed payloads
            with ``binascii.Error``, which is itself a ``ValueError``.
    """
    payload = base64_str
    if base64_str.startswith("data:image"):
        # Everything after the first comma is the payload; base64 text never
        # contains a comma, so a single split is enough.
        _, separator, payload = base64_str.partition(",")
        if not separator or not payload:
            raise ValueError(
                "Invalid image data URI: missing base64 payload after ','"
            )
    image_data = base64.b64decode(payload)
    return Image.open(BytesIO(image_data))
# Line prefixes ("1." through "5.") that open a new section in a text reply.
_SECTION_PREFIXES = tuple(f"{number}." for number in range(1, 6))


def _parse_text_sections(text: str) -> Dict[str, str]:
    """Split a numbered text reply into a mapping of section name -> body."""
    collected: Dict[str, List[str]] = {}
    # Lines that appear before any numbered header are filed under "overview".
    current = "overview"
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith(_SECTION_PREFIXES):
            # "2. Key Features: intro" -> name "key_features", body starts
            # with "intro". Text after the colon used to be dropped.
            heading, _, inline_text = line.split(".", 1)[1].partition(":")
            current = heading.strip().lower().replace(" ", "_")
            inline_text = inline_text.strip()
            # A repeated header restarts its section, as before.
            collected[current] = [inline_text] if inline_text else []
        else:
            collected.setdefault(current, []).append(line)
    return {name: "\n".join(body).strip() for name, body in collected.items()}


def _build_enhanced_prompt(instruction, prompt, config) -> str:
    """Combine the category instruction, optional style guidance and request."""
    if not config:
        return f"{instruction}\n\nUser Request: {prompt}"
    style_guide = config.get("style_guide", "")
    conventions = config.get("conventions", [])
    common_elements = config.get("common_elements", [])
    # Explicit pieces instead of mixing implicit literal concatenation with
    # "+"; the resulting string is identical.
    return "".join(
        [
            f"{instruction}\n\n",
            f"Style Guide: {style_guide}\n",
            "Drawing Conventions to Follow:\n- ",
            "\n- ".join(conventions),
            "\n",
            "Consider Including These Elements:\n- ",
            "\n- ".join(common_elements),
            "\n\n",
            f"User Request: {prompt}",
        ]
    )


async def _request_text_and_image(contents):
    """Send ``contents`` to the configured model, asking for text and image."""
    return await client.aio.models.generate_content(
        model=os.getenv("MODEL"),
        contents=contents,
        config=types.GenerateContentConfig(
            response_modalities=['TEXT', 'IMAGE']
        ),
    )


def _iter_response_parts(response):
    """Return the first candidate's parts, or an empty list if any are absent."""
    candidates = getattr(response, "candidates", None)
    if not candidates:
        return []
    content = getattr(candidates[0], "content", None)
    return getattr(content, "parts", None) or []


def _part_to_event(part) -> Optional[Dict]:
    """Convert one response part into a text/image event dict, or None."""
    text = getattr(part, "text", None)
    if text is not None:
        try:
            sections = _parse_text_sections(text)
        except Exception:
            # Deliberate best-effort: never lose the reply because it could
            # not be split into sections; hand back the raw text instead.
            sections = {'raw_text': text}
        return {'type': 'text', 'data': sections}
    inline_data = getattr(part, "inline_data", None)
    if inline_data is not None:
        return {'type': 'image', 'data': bytes_to_base64(inline_data.data)}
    return None


async def async_generate_text_and_image(prompt, category: Optional[str] = None):
    """Generate text and image output for a text-only request.

    The prompt sent to the model is the category instruction, followed by the
    category's style guide, conventions and common elements when a config
    entry exists for the lower-cased category, followed by the user request.

    Args:
        prompt: The user's request text.
        category: Optional category name used to look up the instruction and
            the entry in ``CATEGORY_CONFIGS``.

    Yields:
        ``{'type': 'text', 'data': {section_name: body, ...}}`` for each text
        part (or ``{'raw_text': ...}`` if section parsing fails), and
        ``{'type': 'image', 'data': <base64 data URI>}`` for each inline-data
        part. Nothing is yielded when the response has no candidates, content
        or parts.
    """
    instruction = get_instruction_for_category(category)
    config = CATEGORY_CONFIGS.get(category.lower() if category else "", {})
    enhanced_prompt = _build_enhanced_prompt(instruction, prompt, config)

    response = await _request_text_and_image(enhanced_prompt)
    for part in _iter_response_parts(response):
        event = _part_to_event(part)
        if event is not None:
            yield event


async def async_generate_with_image_input(text: Optional[str], image_path: str, category: Optional[str] = None):
    """Generate text and image output from an input image plus optional text.

    Args:
        text: Optional user request; when falsy, only the category instruction
            is sent alongside the image.
        image_path: Despite the name, a base64 data URI starting with
            ``data:image/`` (not a filesystem path).
        category: Optional category name used to look up the instruction.

    Yields:
        The same event dicts as ``async_generate_text_and_image``.

    Raises:
        ValueError: If ``image_path`` is not a string starting with
            ``data:image/``.
    """
    if not isinstance(image_path, str) or not image_path.startswith("data:image/"):
        raise ValueError("Invalid image input: expected a base64 Data URI starting with 'data:image/'")

    image = decode_base64_image(image_path)
    instruction = get_instruction_for_category(category)

    # NOTE(review): unlike the text-only path, the CATEGORY_CONFIGS style
    # guidance is not added to the prompt here - confirm this is intentional.
    if text:
        lead_text = f"{instruction}\n\nUser Request: {text}"
    else:
        lead_text = instruction
    contents = [lead_text, image]

    response = await _request_text_and_image(contents)
    for part in _iter_response_parts(response):
        event = _part_to_event(part)
        if event is not None:
            yield event