# --- Combined Imports ------------------------------------
import io
import os
import re
import base64
import glob
import logging
import random
import shutil
import time
import zipfile
import json
import asyncio
import aiofiles
from datetime import datetime
from collections import Counter
from dataclasses import dataclass
from io import BytesIO
from typing import Optional
import pandas as pd
import pytz
import streamlit as st
from PIL import Image, ImageDraw, UnidentifiedImageError
from reportlab.pdfgen import canvas
from reportlab.lib.utils import ImageReader
import fitz  # PyMuPDF
# Conditional imports for optional/heavy libraries
# Warnings raised here are queued and shown after st.set_page_config(),
# which must be the first Streamlit command executed in the script.
_startup_warnings = []
try:
    import torch
    from diffusers import StableDiffusionPipeline
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
    _ai_libs_available = True
except ImportError:
    _ai_libs_available = False
    _startup_warnings.append("AI/ML libraries (torch, transformers, diffusers) not found. Some AI features disabled.")
client = None
try:
    from openai import OpenAI
    if not os.getenv('OPENAI_API_KEY'):
        _openai_available = False
        _startup_warnings.append("OPENAI_API_KEY not found in environment variables. GPT features disabled.")
    else:
        client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'), organization=os.getenv('OPENAI_ORG_ID'))
        _openai_available = True
except ImportError:
    _openai_available = False
    _startup_warnings.append("OpenAI library not found. GPT features disabled.")
except Exception as e:
    _openai_available = False
    _startup_warnings.append(f"OpenAI client error: {e}. GPT features disabled.")
import requests  # Keep requests import
# --- Logging Setup (from App 2) -------------------------- | |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") | |
logger = logging.getLogger(__name__) | |
log_records = [] | |
class LogCaptureHandler(logging.Handler): | |
def emit(self, record): | |
log_records.append(record) | |
logger.addHandler(LogCaptureHandler()) | |
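# Captured records accumulate in log_records so they can optionally be surfaced in the UI later.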
# --- App Configuration (Choose one, adapted from App 2) ---
st.set_page_config(
    page_title="Vision & Layout Titans 🚀🖼️",
    page_icon="🤖",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://huggingface.co/awacke1',
        'Report a Bug': 'https://huggingface.co/spaces/awacke1',
        'About': "Combined App: Image->PDF Layout + AI Vision & SFT Titans 🚀"
    }
)
# Surface any warnings that were queued during the import phase above
for _warning in _startup_warnings:
    st.sidebar.warning(_warning)
# --- Session State Initialization (Combined) ------------- | |
# From App 1 | |
st.session_state.setdefault('layout_snapshots', []) # Renamed to avoid potential conflict | |
# From App 2 | |
st.session_state.setdefault('history', []) | |
st.session_state.setdefault('builder', None) | |
st.session_state.setdefault('model_loaded', False) | |
st.session_state.setdefault('processing', {}) | |
st.session_state.setdefault('asset_checkboxes', {}) | |
st.session_state.setdefault('downloaded_pdfs', {}) | |
st.session_state.setdefault('unique_counter', 0) | |
st.session_state.setdefault('selected_model_type', "Causal LM") | |
st.session_state.setdefault('selected_model', "None") | |
st.session_state.setdefault('cam0_file', None) | |
st.session_state.setdefault('cam1_file', None) | |
st.session_state.setdefault('characters', []) | |
st.session_state.setdefault('char_form_reset', False) | |
if 'asset_gallery_container' not in st.session_state: | |
st.session_state['asset_gallery_container'] = st.sidebar.empty() | |
st.session_state.setdefault('gallery_size', 2) # From App 2 gallery settings | |
# --- Dataclasses (from App 2) ----------------------------
@dataclass
class ModelConfig:
    name: str
    base_model: str
    size: str
    domain: Optional[str] = None
    model_type: str = "causal_lm"
    def model_path(self):
        return f"models/{self.name}"

@dataclass
class DiffusionConfig:
    name: str
    base_model: str
    size: str
    domain: Optional[str] = None
    def model_path(self):
        return f"diffusion_models/{self.name}"
# --- Class Definitions (from App 2) ----------------------- | |
# Simplified ModelBuilder and DiffusionBuilder if libraries are missing | |
if _ai_libs_available: | |
class ModelBuilder: | |
def __init__(self): | |
self.config = None | |
self.model = None | |
self.tokenizer = None | |
self.jokes = [ | |
"Why did the AI go to therapy? Too many layers to unpack! ๐", | |
"Training complete! Time for a binary coffee break. โ", | |
"I told my neural network a joke; it couldn't stop dropping bits! ๐ค", | |
"I asked the AI for a pun, and it said, 'I'm punning on parallel processing!' ๐", | |
"Debugging my code is like a stand-up routineโalways a series of exceptions! ๐" | |
] | |
def load_model(self, model_path: str, config: Optional[ModelConfig] = None): | |
with st.spinner(f"Loading {model_path}... โณ"): | |
self.model = AutoModelForCausalLM.from_pretrained(model_path) | |
self.tokenizer = AutoTokenizer.from_pretrained(model_path) | |
if self.tokenizer.pad_token is None: | |
self.tokenizer.pad_token = self.tokenizer.eos_token | |
if config: | |
self.config = config | |
self.model.to("cuda" if torch.cuda.is_available() else "cpu") | |
st.success(f"Model loaded! ๐ {random.choice(self.jokes)}") | |
return self | |
def save_model(self, path: str): | |
with st.spinner("Saving model... ๐พ"): | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
self.model.save_pretrained(path) | |
self.tokenizer.save_pretrained(path) | |
st.success(f"Model saved at {path}! โ ") | |
class DiffusionBuilder: | |
def __init__(self): | |
self.config = None | |
self.pipeline = None | |
def load_model(self, model_path: str, config: Optional[DiffusionConfig] = None): | |
with st.spinner(f"Loading diffusion model {model_path}... โณ"): | |
# Use float32 for broader compatibility, esp. CPU | |
self.pipeline = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cuda" if torch.cuda.is_available() else "cpu") | |
if config: | |
self.config = config | |
st.success("Diffusion model loaded! ๐จ") | |
return self | |
def save_model(self, path: str): | |
with st.spinner("Saving diffusion model... ๐พ"): | |
os.makedirs(os.path.dirname(path), exist_ok=True) | |
self.pipeline.save_pretrained(path) | |
st.success(f"Diffusion model saved at {path}! โ ") | |
def generate(self, prompt: str): | |
# Adjust steps for CPU if needed | |
steps = 10 if torch.cuda.is_available() else 5 # Fewer steps for CPU demo | |
with st.spinner(f"Generating image with {steps} steps..."): | |
image = self.pipeline(prompt, num_inference_steps=steps).images[0] | |
return image | |
else: # Placeholder classes if AI libs are missing | |
class ModelBuilder: | |
def __init__(self): st.error("AI Libraries not available.") | |
def load_model(self, *args, **kwargs): pass | |
def save_model(self, *args, **kwargs): pass | |
class DiffusionBuilder: | |
def __init__(self): st.error("AI Libraries not available.") | |
def load_model(self, *args, **kwargs): pass | |
def save_model(self, *args, **kwargs): pass | |
def generate(self, *args, **kwargs): return Image.new("RGB", (64,64), "gray") | |
# --- Helper Functions (Combined and refined) ------------- | |
def generate_filename(sequence, ext="png"): | |
# Use App 2's more robust version | |
timestamp = time.strftime('%Y%m%d_%H%M%S') | |
# Sanitize sequence name for filename | |
safe_sequence = re.sub(r'[^\w\-]+', '_', str(sequence)) | |
return f"{safe_sequence}_{timestamp}.{ext}" | |
def pdf_url_to_filename(url): | |
# Use App 2's version | |
# Further sanitize - remove http(s) prefix and limit length | |
name = re.sub(r'^https?://', '', url) | |
name = re.sub(r'[<>:"/\\|?*]', '_', name) | |
return name[:100] + ".pdf" # Limit length | |
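# e.g. pdf_url_to_filename("https://arxiv.org/pdf/1706.03762") -> "arxiv.org_pdf_1706.03762.pdf"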
def get_download_link(file_path, mime_type="application/octet-stream", label="Download"): | |
# Use App 2's version, ensure file exists | |
if not os.path.exists(file_path): | |
return f"{label} (File not found)" | |
try: | |
with open(file_path, "rb") as f: | |
file_bytes = f.read() | |
b64 = base64.b64encode(file_bytes).decode() | |
return f'<a href="data:{mime_type};base64,{b64}" download="{os.path.basename(file_path)}">{label}</a>' | |
except Exception as e: | |
logger.error(f"Error creating download link for {file_path}: {e}") | |
return f"{label} (Error)" | |
def zip_directory(directory_path, zip_path): | |
# Use App 2's version | |
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: | |
for root, _, files in os.walk(directory_path): | |
for file in files: | |
file_path = os.path.join(root, file) | |
zipf.write(file_path, os.path.relpath(file_path, os.path.dirname(directory_path))) | |
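# Example usage (hypothetical paths): zip_directory("models/titan-small", "titan-small.zip") archives the whole
# model folder, keeping paths relative to the parent of the directory being zipped.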
def get_model_files(model_type="causal_lm"): | |
# Use App 2's version | |
pattern = "models/*" if model_type == "causal_lm" else "diffusion_models/*" | |
dirs = [d for d in glob.glob(pattern) if os.path.isdir(d)] | |
return dirs if dirs else ["None"] | |
def get_gallery_files(file_types=("png", "pdf", "jpg", "jpeg", "md", "txt")): # Expanded types | |
# Use App 2's version; match both lowercase and uppercase extensions
all_files = set() | |
for ext in file_types: | |
all_files.update(glob.glob(f"*.{ext.lower()}")) | |
all_files.update(glob.glob(f"*.{ext.upper()}")) # Include uppercase extensions too | |
return sorted(list(all_files)) | |
def get_pdf_files(): | |
# Use App 2's version | |
return sorted(glob.glob("*.pdf") + glob.glob("*.PDF")) | |
def download_pdf(url, output_path): | |
# Use App 2's version | |
try: | |
headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'} | |
response = requests.get(url, stream=True, timeout=20, headers=headers) # Added user-agent, longer timeout | |
response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx) | |
with open(output_path, "wb") as f: | |
for chunk in response.iter_content(chunk_size=8192): | |
f.write(chunk) | |
logger.info(f"Successfully downloaded {url} to {output_path}") | |
return True | |
except requests.exceptions.RequestException as e: | |
logger.error(f"Failed to download {url}: {e}") | |
# Attempt to remove partially downloaded file | |
if os.path.exists(output_path): | |
try: | |
os.remove(output_path) | |
logger.info(f"Removed partially downloaded file: {output_path}") | |
except OSError as remove_error: | |
logger.error(f"Error removing partial file {output_path}: {remove_error}") | |
return False | |
except Exception as e: | |
logger.error(f"An unexpected error occurred during download of {url}: {e}") | |
if os.path.exists(output_path): | |
try: os.remove(output_path) | |
except: pass | |
return False | |
async def process_pdf_snapshot(pdf_path, mode="single", resolution_factor=2.0): | |
# Use App 2's version, added resolution control | |
start_time = time.time() | |
# Use a placeholder within the main app area for status during async operations | |
status_placeholder = st.empty() | |
status_placeholder.text(f"Processing PDF Snapshot ({mode}, Res: {resolution_factor}x)... (0s)") | |
output_files = [] | |
try: | |
doc = fitz.open(pdf_path) | |
matrix = fitz.Matrix(resolution_factor, resolution_factor) | |
num_pages_to_process = 0 | |
if mode == "single": | |
num_pages_to_process = min(1, len(doc)) | |
elif mode == "twopage": | |
num_pages_to_process = min(2, len(doc)) | |
elif mode == "allpages": | |
num_pages_to_process = len(doc) | |
for i in range(num_pages_to_process): | |
page_start_time = time.time() | |
page = doc[i] | |
pix = page.get_pixmap(matrix=matrix) | |
# Use PDF name and page number in filename for clarity | |
base_name = os.path.splitext(os.path.basename(pdf_path))[0] | |
output_file = generate_filename(f"{base_name}_pg{i+1}_{mode}", "png") | |
await asyncio.to_thread(pix.save, output_file) # Run sync save in thread | |
output_files.append(output_file) | |
elapsed_page = int(time.time() - page_start_time) | |
status_placeholder.text(f"Processing PDF Snapshot ({mode}, Res: {resolution_factor}x)... Page {i+1}/{num_pages_to_process} done ({elapsed_page}s)") | |
await asyncio.sleep(0.01) # Yield control briefly | |
doc.close() | |
elapsed = int(time.time() - start_time) | |
status_placeholder.success(f"PDF Snapshot ({mode}, {len(output_files)} files) completed in {elapsed}s!") | |
return output_files | |
except Exception as e: | |
logger.error(f"Failed to process PDF snapshot for {pdf_path}: {e}") | |
status_placeholder.error(f"Failed to process PDF {os.path.basename(pdf_path)}: {e}") | |
# Clean up any files created before the error | |
for f in output_files: | |
if os.path.exists(f): os.remove(f) | |
return [] | |
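# Note: this coroutine is driven from synchronous Streamlit code via asyncio.run(process_pdf_snapshot(...)),
# as done in the "Download PDFs" tab below.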
async def process_gpt4o_ocr(image: Image.Image, output_file: str): | |
# Use App 2's version, check for OpenAI availability | |
if not _openai_available: | |
st.error("OpenAI OCR requires API key and library.") | |
return "" | |
start_time = time.time() | |
status_placeholder = st.empty() | |
status_placeholder.text("Processing GPT-4o OCR... (0s)") | |
buffered = BytesIO() | |
# Ensure image is in a compatible format (e.g., PNG, JPEG) | |
save_format = "PNG" if image.format != "JPEG" else "JPEG" | |
image.save(buffered, format=save_format) | |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
messages = [{ | |
"role": "user", | |
"content": [ | |
{"type": "text", "text": "Extract text content from the image. Provide only the extracted text."}, # More specific prompt | |
{"type": "image_url", "image_url": {"url": f"data:image/{save_format.lower()};base64,{img_str}", "detail": "auto"}} | |
] | |
}] | |
try: | |
# Run OpenAI call in a separate thread to avoid blocking Streamlit's event loop | |
response = await asyncio.to_thread( | |
client.chat.completions.create, | |
model="gpt-4o", messages=messages, max_tokens=4000 # Increased tokens | |
) | |
result = response.choices[0].message.content or "" # Handle potential None result | |
elapsed = int(time.time() - start_time) | |
status_placeholder.success(f"GPT-4o OCR completed in {elapsed}s!") | |
async with aiofiles.open(output_file, "w", encoding='utf-8') as f: # Specify encoding | |
await f.write(result) | |
logger.info(f"GPT-4o OCR successful for {output_file}") | |
return result | |
except Exception as e: | |
logger.error(f"Failed to process image with GPT-4o: {e}") | |
status_placeholder.error(f"GPT-4o OCR Failed: {e}") | |
return f"Error during OCR: {str(e)}" | |
async def process_image_gen(prompt: str, output_file: str): | |
# Use App 2's version, check AI lib availability | |
if not _ai_libs_available: | |
st.error("Image Generation requires AI libraries.") | |
img = Image.new("RGB", (256, 256), "lightgray") | |
draw = ImageDraw.Draw(img) | |
draw.text((10, 10), "AI libs missing", fill="black") | |
img.save(output_file) | |
return img | |
start_time = time.time() | |
status_placeholder = st.empty() | |
status_placeholder.text("Processing Image Gen... (0s)") | |
# Ensure a pipeline is loaded, default to small one if necessary | |
pipeline = None | |
if st.session_state.get('builder') and isinstance(st.session_state['builder'], DiffusionBuilder) and st.session_state['builder'].pipeline: | |
pipeline = st.session_state['builder'].pipeline | |
else: | |
try: | |
with st.spinner("Loading default small diffusion model..."): | |
pipeline = StableDiffusionPipeline.from_pretrained("OFA-Sys/small-stable-diffusion-v0", torch_dtype=torch.float32).to("cuda" if torch.cuda.is_available() else "cpu") | |
st.info("Loaded default small diffusion model for image generation.") | |
except Exception as e: | |
logger.error(f"Failed to load default diffusion model: {e}") | |
status_placeholder.error(f"Failed to load default diffusion model: {e}") | |
img = Image.new("RGB", (256, 256), "lightgray") | |
draw = ImageDraw.Draw(img) | |
draw.text((10, 10), "Model load error", fill="black") | |
img.save(output_file) | |
return img | |
try: | |
# Run generation in a thread | |
gen_image = await asyncio.to_thread(pipeline, prompt, num_inference_steps=15) # Slightly more steps | |
gen_image = gen_image.images[0] # Extract image from list | |
elapsed = int(time.time() - start_time) | |
status_placeholder.success(f"Image Gen completed in {elapsed}s!") | |
await asyncio.to_thread(gen_image.save, output_file) # Save in thread | |
logger.info(f"Image generation successful for {output_file}") | |
return gen_image | |
except Exception as e: | |
logger.error(f"Image generation failed: {e}") | |
status_placeholder.error(f"Image generation failed: {e}") | |
# Create placeholder error image | |
img = Image.new("RGB", (256, 256), "lightgray") | |
from PIL import ImageDraw | |
draw = ImageDraw.Draw(img) | |
draw.text((10, 10), f"Generation Error:\n{e}", fill="red") | |
await asyncio.to_thread(img.save, output_file) | |
return img | |
# --- GPT Processing Functions (from App 2, with checks) --- | |
def process_image_with_prompt(image: Image.Image, prompt: str, model="gpt-4o-mini", detail="auto"): | |
if not _openai_available: return "Error: OpenAI features disabled." | |
status_placeholder = st.empty() | |
status_placeholder.info(f"Processing image with GPT ({model})...") | |
buffered = BytesIO() | |
save_format = "PNG" if image.format != "JPEG" else "JPEG" | |
image.save(buffered, format=save_format) | |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
messages = [{ | |
"role": "user", | |
"content": [ | |
{"type": "text", "text": prompt}, | |
{"type": "image_url", "image_url": {"url": f"data:image/{save_format.lower()};base64,{img_str}", "detail": detail}} | |
] | |
}] | |
try: | |
response = client.chat.completions.create(model=model, messages=messages, max_tokens=1000) # Increased tokens | |
result = response.choices[0].message.content or "" | |
status_placeholder.success(f"GPT ({model}) image processing complete.") | |
logger.info(f"GPT ({model}) image processing successful.") | |
return result | |
except Exception as e: | |
logger.error(f"Error processing image with GPT ({model}): {e}") | |
status_placeholder.error(f"Error processing image with GPT ({model}): {e}") | |
return f"Error processing image with GPT: {str(e)}" | |
def process_text_with_prompt(text: str, prompt: str, model="gpt-4o-mini"): | |
if not _openai_available: return "Error: OpenAI features disabled." | |
status_placeholder = st.empty() | |
status_placeholder.info(f"Processing text with GPT ({model})...") | |
messages = [{"role": "user", "content": f"{prompt}\n\n---\n\n{text}"}] # Added separator | |
try: | |
response = client.chat.completions.create(model=model, messages=messages, max_tokens=2000) # Increased tokens | |
result = response.choices[0].message.content or "" | |
status_placeholder.success(f"GPT ({model}) text processing complete.") | |
logger.info(f"GPT ({model}) text processing successful.") | |
return result | |
except Exception as e: | |
logger.error(f"Error processing text with GPT ({model}): {e}") | |
status_placeholder.error(f"Error processing text with GPT ({model}): {e}") | |
return f"Error processing text with GPT: {str(e)}" | |
# --- Character Functions (from App 2) -------------------- | |
def randomize_character_content(): | |
# Use App 2's version | |
intro_templates = [ | |
"{char} is a valiant knight who is silent and reserved, he looks handsome but aloof.", | |
"{char} is a mischievous thief with a heart of gold, always sneaking around but helping those in need.", | |
"{char} is a wise scholar who loves books more than people, often lost in thought.", | |
"{char} is a fiery warrior with a short temper, but fiercely loyal to friends.", | |
"{char} is a gentle healer who speaks softly, always carrying herbs and a warm smile." | |
] | |
greeting_templates = [ | |
"You were startled by the sudden intrusion of a man into your home. 'I am from the knight's guild, and I have been ordered to arrest you.'", | |
"A shadowy figure steps into the light. 'I heard you needed helpโnameโs {char}, best thief in town.'", | |
"A voice calls from behind a stack of books. 'Oh, hello! Iโm {char}, didnโt see you thereโtoo many scrolls!'", | |
"A booming voice echoes, 'Iโm {char}, and Iโm here to fight for justiceโor at least a good brawl!'", | |
"A soft hand touches your shoulder. 'Iโm {char}, here to heal your woundsโdonโt worry, Iโve got you.'" | |
] | |
name = f"Character_{random.randint(1000, 9999)}" | |
gender = random.choice(["Male", "Female"]) | |
intro = random.choice(intro_templates).format(char=name) | |
greeting = random.choice(greeting_templates).format(char=name) | |
return name, gender, intro, greeting | |
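# Returns a (name, gender, intro, greeting) tuple of randomized placeholder content for the character form.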
def save_character(character_data): | |
# Use App 2's version | |
characters = st.session_state.get('characters', []) | |
# Prevent duplicate names | |
if any(c['name'] == character_data['name'] for c in characters): | |
st.error(f"Character name '{character_data['name']}' already exists.") | |
return False | |
characters.append(character_data) | |
st.session_state['characters'] = characters | |
try: | |
with open("characters.json", "w", encoding='utf-8') as f: | |
json.dump(characters, f, indent=2) # Added indent for readability | |
logger.info(f"Saved character: {character_data['name']}") | |
return True | |
except IOError as e: | |
logger.error(f"Failed to save characters.json: {e}") | |
st.error(f"Failed to save character file: {e}") | |
return False | |
def load_characters(): | |
# Use App 2's version | |
if not os.path.exists("characters.json"): | |
st.session_state['characters'] = [] | |
return | |
try: | |
with open("characters.json", "r", encoding='utf-8') as f: | |
characters = json.load(f) | |
# Basic validation | |
if isinstance(characters, list): | |
st.session_state['characters'] = characters | |
logger.info(f"Loaded {len(characters)} characters.") | |
else: | |
st.session_state['characters'] = [] | |
logger.warning("characters.json is not a list, resetting.") | |
os.remove("characters.json") # Remove invalid file | |
except (json.JSONDecodeError, IOError) as e: | |
logger.error(f"Failed to load or decode characters.json: {e}") | |
st.error(f"Error loading character file: {e}. Starting fresh.") | |
st.session_state['characters'] = [] | |
# Attempt to backup corrupted file | |
try: | |
corrupt_filename = f"characters_corrupt_{int(time.time())}.json" | |
shutil.copy("characters.json", corrupt_filename) | |
logger.info(f"Backed up corrupted character file to {corrupt_filename}") | |
os.remove("characters.json") | |
except Exception as backup_e: | |
logger.error(f"Could not backup corrupted character file: {backup_e}") | |
# --- Utility: Clean stems (from App 1, needed for Image->PDF tab) --- | |
def clean_stem(fn: str) -> str: | |
# Make it slightly more robust | |
name = os.path.splitext(os.path.basename(fn))[0] | |
name = name.replace('-', ' ').replace('_', ' ') | |
# Remove common prefixes/suffixes if desired (optional) | |
# name = re.sub(r'^(scan|img|image)_?', '', name, flags=re.IGNORECASE) | |
# name = re.sub(r'_?\d+$', '', name) # Remove trailing numbers | |
return name.strip().title() # Title case | |
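# e.g. clean_stem("weekly-report_final.png") -> "Weekly Report Final"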
# --- PDF Creation: Image Sized + Captions (from App 1) --- | |
def make_image_sized_pdf(sources): | |
if not sources: | |
st.warning("No image sources provided for PDF generation.") | |
return None | |
buf = io.BytesIO() | |
# Use A4 size initially, will be overridden per page | |
c = canvas.Canvas(buf, pagesize=(595.27, 841.89)) # Default A4 | |
try: | |
for idx, src in enumerate(sources, start=1): | |
status_placeholder = st.empty() | |
status_placeholder.info(f"Adding page {idx}/{len(sources)}: {os.path.basename(str(src))}...") | |
try: | |
# Handle both file paths and uploaded file objects | |
if isinstance(src, str): # path | |
if not os.path.exists(src): | |
logger.warning(f"Image file not found: {src}. Skipping.") | |
status_placeholder.warning(f"Skipping missing file: {os.path.basename(src)}") | |
continue | |
img_obj = Image.open(src) | |
filename = os.path.basename(src) | |
else: # uploaded file object (BytesIO wrapper) | |
src.seek(0) # Ensure reading from start | |
img_obj = Image.open(src) | |
filename = getattr(src, 'name', f'uploaded_image_{idx}') | |
src.seek(0) # Reset again just in case needed later | |
with img_obj: # Use context manager for PIL Image | |
iw, ih = img_obj.size | |
if iw <= 0 or ih <= 0: | |
logger.warning(f"Invalid image dimensions ({iw}x{ih}) for {filename}. Skipping.") | |
status_placeholder.warning(f"Skipping invalid image: {filename}") | |
continue | |
cap_h = 30 # Increased caption height | |
# Set page size based on image + caption height | |
pw, ph = iw, ih + cap_h | |
c.setPageSize((pw, ph)) | |
# Draw image, ensuring it fits within iw, ih space above caption | |
# Use ImageReader for efficiency with ReportLab | |
img_reader = ImageReader(img_obj) | |
c.drawImage(img_reader, 0, cap_h, width=iw, height=ih, preserveAspectRatio=True, anchor='c', mask='auto') | |
# Draw Caption (cleaned filename) | |
caption = clean_stem(filename) | |
c.setFont('Helvetica', 12) | |
c.setFillColorRGB(0, 0, 0) # Black text | |
c.drawCentredString(pw / 2, cap_h / 2 + 3, caption) # Center vertically too | |
# Draw Page Number | |
c.setFont('Helvetica', 8) | |
c.setFillColorRGB(0.5, 0.5, 0.5) # Gray text | |
c.drawRightString(pw - 10, 8, f"Page {idx}") | |
c.showPage() # Finalize the page | |
status_placeholder.success(f"Added page {idx}/{len(sources)}: {filename}") | |
except (IOError, OSError, UnidentifiedImageError) as img_err: | |
logger.error(f"Error processing image {src}: {img_err}") | |
status_placeholder.error(f"Error adding page {idx}: {img_err}") | |
except Exception as e: | |
logger.error(f"Unexpected error adding page {idx} ({src}): {e}") | |
status_placeholder.error(f"Unexpected error on page {idx}: {e}") | |
c.save() | |
buf.seek(0) | |
if buf.getbuffer().nbytes < 100: # Check if PDF is basically empty | |
st.error("PDF generation resulted in an empty file. Check image files.") | |
return None | |
return buf.getvalue() | |
except Exception as e: | |
logger.error(f"Fatal error during PDF generation: {e}") | |
st.error(f"PDF Generation Failed: {e}") | |
return None | |
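# Returns the raw PDF bytes (or None on failure); the layout tab below passes the result
# straight to st.download_button.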
# --- Sidebar Gallery Update Function (from App 2) -------- | |
def update_gallery(): | |
container = st.session_state['asset_gallery_container'] | |
container.empty() # Clear previous gallery rendering | |
with container.container(): # Use a container to manage layout | |
st.markdown("### Asset Gallery ๐ธ๐") | |
st.session_state['gallery_size'] = st.slider("Max Items Shown", 2, 50, st.session_state.get('gallery_size', 10), key="gallery_size_slider") | |
cols = st.columns(2) # Use 2 columns in the sidebar | |
all_files = get_gallery_files() # Get currently available files | |
if not all_files: | |
st.info("No assets (images, PDFs, text files) found yet.") | |
return | |
files_to_display = all_files[:st.session_state['gallery_size']] | |
for idx, file in enumerate(files_to_display): | |
with cols[idx % 2]: | |
st.session_state['unique_counter'] += 1 | |
unique_id = st.session_state['unique_counter'] | |
basename = os.path.basename(file) | |
st.caption(basename) # Show filename as caption above preview | |
try: | |
file_ext = os.path.splitext(file)[1].lower() | |
if file_ext in ['.png', '.jpg', '.jpeg']: | |
st.image(Image.open(file), use_container_width=True) | |
elif file_ext == '.pdf': | |
doc = fitz.open(file) | |
# Generate preview only if file opens successfully | |
if len(doc) > 0: | |
pix = doc[0].get_pixmap(matrix=fitz.Matrix(0.5, 0.5)) # Smaller preview | |
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) | |
st.image(img, use_container_width=True) | |
else: | |
st.warning("Empty PDF") | |
doc.close() | |
elif file_ext in ['.md', '.txt']: | |
with open(file, 'r', encoding='utf-8', errors='ignore') as f: | |
content_preview = f.read(200) # Show first 200 chars | |
st.code(content_preview + "...", language='markdown' if file_ext == '.md' else 'text') | |
# Actions for the file | |
checkbox_key = f"asset_cb_{file}_{unique_id}" | |
# Use get to safely access potentially missing keys after deletion | |
st.session_state['asset_checkboxes'][file] = st.checkbox( | |
"Select", | |
value=st.session_state['asset_checkboxes'].get(file, False), | |
key=checkbox_key | |
) | |
mime_map = {'.png': 'image/png', '.jpg': 'image/jpeg', '.jpeg': 'image/jpeg', '.pdf': 'application/pdf', '.txt': 'text/plain', '.md': 'text/markdown'} | |
mime_type = mime_map.get(file_ext, "application/octet-stream") | |
st.markdown(get_download_link(file, mime_type, "๐ฅ"), unsafe_allow_html=True) | |
delete_key = f"delete_btn_{file}_{unique_id}" | |
if st.button("๐๏ธ", key=delete_key, help=f"Delete {basename}"): | |
try: | |
os.remove(file) | |
st.session_state['asset_checkboxes'].pop(file, None) # Remove from selection state | |
# Remove from layout_snapshots if present | |
if file in st.session_state.get('layout_snapshots', []): | |
st.session_state['layout_snapshots'].remove(file) | |
logger.info(f"Deleted asset: {file}") | |
st.success(f"Deleted {basename}") | |
st.rerun() # Rerun to refresh the gallery immediately | |
except OSError as e: | |
logger.error(f"Error deleting file {file}: {e}") | |
st.error(f"Could not delete {basename}") | |
except (fitz.fitz.FileNotFoundError, FileNotFoundError): | |
st.error(f"File not found: {basename}") | |
# Clean up state if file is missing | |
st.session_state['asset_checkboxes'].pop(file, None) | |
if file in st.session_state.get('layout_snapshots', []): | |
st.session_state['layout_snapshots'].remove(file) | |
except (fitz.fitz.FileDataError, fitz.fitz.RuntimeException) as pdf_err: | |
st.error(f"Corrupt PDF: {basename}") | |
logger.warning(f"Error opening PDF {file}: {pdf_err}") | |
except UnidentifiedImageError: | |
st.error(f"Invalid Image: {basename}") | |
logger.warning(f"Cannot identify image file {file}") | |
except Exception as e: | |
st.error(f"Error: {basename}") | |
logger.error(f"Error displaying asset {file}: {e}") | |
st.markdown("---") # Separator between items | |
if len(all_files) > st.session_state['gallery_size']: | |
st.caption(f"Showing {st.session_state['gallery_size']} of {len(all_files)} assets.") | |
# --- App Title -------------------------------------------
st.title("Vision & Layout Titans 🚀🖼️📄")
st.markdown("Combined App: AI Vision/SFT Tools + Image-to-PDF Layout Generator")
# --- Main Application Tabs ------------------------------- | |
tab_list = [ | |
"Image->PDF Layout ๐ผ๏ธโก๏ธ๐", # Added from App 1 | |
"Camera Snap ๐ท", | |
"Download PDFs ๐ฅ", | |
"PDF Process ๐", | |
"Image Process ๐ผ๏ธ", | |
"Test OCR ๐", | |
"MD Gallery & Process ๐", | |
"Build Titan ๐ฑ", | |
"Test Image Gen ๐จ", | |
"Character Editor ๐งโ๐จ", | |
"Character Gallery ๐ผ๏ธ" | |
] | |
tabs = st.tabs(tab_list) | |
# --- Tab 1: Image -> PDF Layout (from App 1) ------------- | |
with tabs[0]: | |
st.header("Image to PDF Layout Generator") | |
st.markdown("Upload or scan images, reorder them, and generate a PDF where each page matches the image dimensions and includes a simple caption.") | |
col1, col2 = st.columns(2) | |
with col1: | |
st.subheader("A. Scan or Upload Images") | |
# Camera scan specific to this tab | |
layout_cam = st.camera_input("๐ธ Scan Document for Layout PDF", key="layout_cam") | |
if layout_cam: | |
central = pytz.timezone("US/Central") # Consider making timezone configurable | |
now = datetime.now(central) | |
# Use generate_filename helper | |
scan_name = generate_filename(f"layout_scan_{now.strftime('%a').upper()}", "png") | |
try: | |
# Save the uploaded file content | |
with open(scan_name, "wb") as f: | |
f.write(layout_cam.getvalue()) | |
st.image(Image.open(scan_name), caption=f"Scanned: {scan_name}", use_container_width=True) | |
if scan_name not in st.session_state['layout_snapshots']: | |
st.session_state['layout_snapshots'].append(scan_name) | |
st.success(f"Scan saved as {scan_name}") | |
# No rerun needed, handled by Streamlit's camera widget update | |
except Exception as e: | |
st.error(f"Failed to save scan: {e}") | |
logger.error(f"Failed to save camera scan {scan_name}: {e}") | |
# File uploader specific to this tab | |
layout_uploads = st.file_uploader( | |
"๐ Upload PNG/JPG Images for Layout PDF", type=["png","jpg","jpeg"], | |
accept_multiple_files=True, key="layout_uploader" | |
) | |
# Display uploaded images immediately | |
if layout_uploads: | |
st.write(f"Uploaded {len(layout_uploads)} images:") | |
# Keep track of newly uploaded file objects for the DataFrame | |
st.session_state['layout_new_uploads'] = layout_uploads | |
with col2: | |
st.subheader("B. Review and Reorder") | |
# --- Build combined list for this tab's purpose --- | |
layout_records = [] | |
# From layout-specific snapshots | |
processed_snapshots = set() # Keep track to avoid duplicates if script reruns | |
for idx, path in enumerate(st.session_state.get('layout_snapshots', [])): | |
if path not in processed_snapshots and os.path.exists(path): | |
try: | |
with Image.open(path) as im: | |
w, h = im.size | |
ar = round(w / h, 2) if h > 0 else 0 | |
orient = "Square" if 0.9 <= ar <= 1.1 else ("Landscape" if ar > 1.1 else "Portrait") | |
layout_records.append({ | |
"filename": os.path.basename(path), | |
"source": path, # Store path for snapshots | |
"width": w, | |
"height": h, | |
"aspect_ratio": ar, | |
"orientation": orient, | |
"order": idx, # Initial order based on addition | |
"type": "Scan" | |
}) | |
processed_snapshots.add(path) | |
except Exception as e: | |
logger.warning(f"Could not process snapshot {path}: {e}") | |
st.warning(f"Skipping invalid snapshot: {os.path.basename(path)}") | |
# From layout-specific uploads (use the file objects directly) | |
# Access the newly uploaded files from session state if they exist | |
current_uploads = st.session_state.get('layout_new_uploads', []) | |
if current_uploads: | |
start_idx = len(layout_records) | |
for jdx, f_obj in enumerate(current_uploads, start=start_idx): | |
try: | |
f_obj.seek(0) # Reset pointer | |
with Image.open(f_obj) as im: | |
w, h = im.size | |
ar = round(w / h, 2) if h > 0 else 0 | |
orient = "Square" if 0.9 <= ar <= 1.1 else ("Landscape" if ar > 1.1 else "Portrait") | |
layout_records.append({ | |
"filename": f_obj.name, | |
"source": f_obj, # Store file object for uploads | |
"width": w, | |
"height": h, | |
"aspect_ratio": ar, | |
"orientation": orient, | |
"order": jdx, # Initial order | |
"type": "Upload" | |
}) | |
f_obj.seek(0) # Reset pointer again for potential later use | |
except Exception as e: | |
logger.warning(f"Could not process uploaded file {f_obj.name}: {e}") | |
st.warning(f"Skipping invalid upload: {f_obj.name}") | |
if not layout_records: | |
st.info("Scan or upload images using the controls on the left.") | |
else: | |
# Create DataFrame | |
layout_df = pd.DataFrame(layout_records) | |
# Filter Options (moved here for clarity) | |
st.markdown("Filter by Orientation:") | |
dims = st.multiselect( | |
"Include orientations:", options=["Landscape","Portrait","Square"], | |
default=["Landscape","Portrait","Square"], key="layout_dims_filter" | |
) | |
if dims: # Apply filter only if options are selected | |
filtered_df = layout_df[layout_df['orientation'].isin(dims)].copy() # Use copy to avoid SettingWithCopyWarning | |
else: | |
filtered_df = layout_df.copy() # No filter applied | |
# Ensure 'order' column is integer for editing/sorting | |
filtered_df['order'] = filtered_df['order'].astype(int) | |
filtered_df = filtered_df.sort_values('order').reset_index(drop=True) | |
st.markdown("Edit 'Order' column or drag rows to set PDF page sequence:") | |
# Use st.data_editor for reordering | |
edited_df = st.data_editor( | |
filtered_df, | |
column_config={ | |
"filename": st.column_config.TextColumn("Filename", disabled=True), | |
"source": None, # Hide source column | |
"width": st.column_config.NumberColumn("Width", disabled=True), | |
"height": st.column_config.NumberColumn("Height", disabled=True), | |
"aspect_ratio": st.column_config.NumberColumn("Aspect Ratio", format="%.2f", disabled=True), | |
"orientation": st.column_config.TextColumn("Orientation", disabled=True), | |
"type": st.column_config.TextColumn("Source Type", disabled=True), | |
"order": st.column_config.NumberColumn("Order", min_value=0, step=1, required=True), | |
}, | |
hide_index=True, | |
use_container_width=True, | |
num_rows="dynamic", # Allow sorting/reordering by drag-and-drop | |
key="layout_editor" | |
) | |
# Sort by the edited 'order' column to get the final sequence | |
ordered_layout_df = edited_df.sort_values('order').reset_index(drop=True) | |
# Extract the sources in the correct order for PDF generation | |
# Need to handle both file paths (str) and uploaded file objects | |
ordered_sources_for_pdf = ordered_layout_df['source'].tolist() | |
# --- Generate & Download --- | |
st.subheader("C. Generate & Download PDF") | |
if st.button("๐๏ธ Generate Image-Sized PDF", key="generate_layout_pdf"): | |
if not ordered_sources_for_pdf: | |
st.warning("No images selected or available after filtering.") | |
else: | |
with st.spinner("Generating PDF... This might take a while for many images."): | |
pdf_bytes = make_image_sized_pdf(ordered_sources_for_pdf) | |
if pdf_bytes: | |
# Create filename for the PDF | |
central = pytz.timezone("US/Central") # Use same timezone | |
now = datetime.now(central) | |
prefix = now.strftime("%Y%m%d-%H%M%p") | |
# Create a basename from first few image names | |
stems = [] | |
for src in ordered_sources_for_pdf[:4]: # Limit to first 4 | |
if isinstance(src, str): stems.append(clean_stem(src)) | |
else: stems.append(clean_stem(getattr(src, 'name', 'upload'))) | |
basename = " - ".join(stems) | |
if not basename: basename = "Layout" # Fallback name | |
pdf_fname = f"{prefix}_{basename}.pdf" | |
pdf_fname = re.sub(r'[^\w\- \.]', '_', pdf_fname) # Sanitize filename | |
st.success(f"โ PDF ready: **{pdf_fname}**") | |
st.download_button( | |
"โฌ๏ธ Download PDF", | |
data=pdf_bytes, | |
file_name=pdf_fname, | |
mime="application/pdf", | |
key="download_layout_pdf" | |
) | |
# Add PDF Preview | |
st.markdown("#### Preview First Page") | |
try: | |
doc = fitz.open(stream=pdf_bytes, filetype='pdf') | |
if len(doc) > 0: | |
pix = doc[0].get_pixmap(matrix=fitz.Matrix(1.0, 1.0)) # Standard resolution preview | |
preview_img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) | |
st.image(preview_img, caption=f"Preview of {pdf_fname} (Page 1)", use_container_width=True) | |
else: | |
st.warning("Generated PDF appears empty.") | |
doc.close() | |
except ImportError: | |
st.info("Install PyMuPDF (`pip install pymupdf`) to enable PDF previews.") | |
except Exception as preview_err: | |
st.warning(f"Could not generate PDF preview: {preview_err}") | |
logger.warning(f"PDF preview error for {pdf_fname}: {preview_err}") | |
else: | |
st.error("PDF generation failed. Check logs or image files.") | |
# --- Remaining Tabs (from App 2, adapted) ---------------- | |
# --- Tab: Camera Snap --- | |
with tabs[1]: | |
st.header("Camera Snap ๐ท") | |
st.subheader("Single Capture (Adds to General Gallery)") | |
cols = st.columns(2) | |
with cols[0]: | |
cam0_img = st.camera_input("Take a picture - Cam 0", key="main_cam0") | |
if cam0_img: | |
# Use generate_filename helper | |
filename = generate_filename("cam0_snap") | |
# Remove previous file for this camera if it exists | |
if st.session_state.get('cam0_file') and os.path.exists(st.session_state['cam0_file']): | |
try: os.remove(st.session_state['cam0_file']) | |
except OSError: pass # Ignore error if file is already gone | |
try: | |
with open(filename, "wb") as f: f.write(cam0_img.getvalue()) | |
st.session_state['cam0_file'] = filename | |
st.session_state['history'].append(f"Snapshot from Cam 0: {filename}") | |
st.image(Image.open(filename), caption="Camera 0 Snap", use_container_width=True) | |
logger.info(f"Saved snapshot from Camera 0: {filename}") | |
st.success(f"Saved {filename}") | |
update_gallery() # Update sidebar gallery | |
st.rerun() # Rerun to reflect change immediately in gallery | |
except Exception as e: | |
st.error(f"Failed to save Cam 0 snap: {e}") | |
logger.error(f"Failed to save Cam 0 snap {filename}: {e}") | |
with cols[1]: | |
cam1_img = st.camera_input("Take a picture - Cam 1", key="main_cam1") | |
if cam1_img: | |
filename = generate_filename("cam1_snap") | |
if st.session_state.get('cam1_file') and os.path.exists(st.session_state['cam1_file']): | |
try: os.remove(st.session_state['cam1_file']) | |
except OSError: pass | |
try: | |
with open(filename, "wb") as f: f.write(cam1_img.getvalue()) | |
st.session_state['cam1_file'] = filename | |
st.session_state['history'].append(f"Snapshot from Cam 1: {filename}") | |
st.image(Image.open(filename), caption="Camera 1 Snap", use_container_width=True) | |
logger.info(f"Saved snapshot from Camera 1: {filename}") | |
st.success(f"Saved {filename}") | |
update_gallery() # Update sidebar gallery | |
st.rerun() | |
except Exception as e: | |
st.error(f"Failed to save Cam 1 snap: {e}") | |
logger.error(f"Failed to save Cam 1 snap {filename}: {e}") | |
# --- Tab: Download PDFs --- | |
with tabs[2]: | |
st.header("Download PDFs ๐ฅ") | |
st.markdown("Download PDFs from URLs and optionally create image snapshots.") | |
if st.button("Load Example arXiv URLs ๐", key="load_examples"): | |
example_urls = [
    "https://arxiv.org/pdf/2308.03892",  # Example paper 1
    "https://arxiv.org/pdf/1706.03762",  # Attention is All You Need
    "https://arxiv.org/pdf/2402.17764",  # Example paper 2
    # Add more diverse examples if needed
    "https://www.un.org/esa/sustdev/publications/publications.html",  # Example non-PDF page (will fail download)
    "https://www.clickdimensions.com/links/ACCERL/"  # Example direct PDF link
]
st.session_state['pdf_urls_input'] = "\n".join(example_urls) | |
url_input = st.text_area( | |
"Enter PDF URLs (one per line)", | |
value=st.session_state.get('pdf_urls_input', ""), | |
height=150, | |
key="pdf_urls_textarea" | |
) | |
if st.button("Robo-Download PDFs ๐ค", key="download_pdfs_button"): | |
urls = [url.strip() for url in url_input.strip().split("\n") if url.strip()] | |
if not urls: | |
st.warning("Please enter at least one URL.") | |
else: | |
progress_bar = st.progress(0) | |
status_text = st.empty() | |
total_urls = len(urls) | |
download_count = 0 | |
existing_pdfs = get_pdf_files() # Get current list once | |
for idx, url in enumerate(urls): | |
output_path = pdf_url_to_filename(url) | |
status_text.text(f"Processing {idx + 1}/{total_urls}: {os.path.basename(output_path)}...") | |
progress_bar.progress((idx + 1) / total_urls) | |
if output_path in existing_pdfs: | |
st.info(f"Already exists: {os.path.basename(output_path)}") | |
st.session_state['downloaded_pdfs'][url] = output_path # Still track it | |
# Ensure it's selectable in the gallery if it exists | |
if os.path.exists(output_path): | |
st.session_state['asset_checkboxes'][output_path] = st.session_state['asset_checkboxes'].get(output_path, False) | |
else: | |
if download_pdf(url, output_path): | |
st.session_state['downloaded_pdfs'][url] = output_path | |
logger.info(f"Downloaded PDF from {url} to {output_path}") | |
st.session_state['history'].append(f"Downloaded PDF: {output_path}") | |
st.session_state['asset_checkboxes'][output_path] = False # Default to unselected | |
download_count += 1 | |
existing_pdfs.append(output_path) # Add to current list | |
else: | |
st.error(f"Failed to download: {url}") | |
status_text.success(f"Download process complete! Successfully downloaded {download_count} new PDFs.") | |
if download_count > 0: | |
update_gallery() # Update sidebar only if new files were added | |
st.rerun() | |
st.subheader("Create Snapshots from Gallery PDFs") | |
snapshot_mode = st.selectbox( | |
"Snapshot Mode", | |
["First Page (High-Res)", "First Two Pages (High-Res)", "All Pages (High-Res)", "First Page (Low-Res Preview)"], | |
key="pdf_snapshot_mode" | |
) | |
resolution_map = { | |
"First Page (High-Res)": 2.0, | |
"First Two Pages (High-Res)": 2.0, | |
"All Pages (High-Res)": 2.0, | |
"First Page (Low-Res Preview)": 1.0 | |
} | |
mode_key_map = { | |
"First Page (High-Res)": "single", | |
"First Two Pages (High-Res)": "twopage", | |
"All Pages (High-Res)": "allpages", | |
"First Page (Low-Res Preview)": "single" | |
} | |
resolution = resolution_map[snapshot_mode] | |
mode_key = mode_key_map[snapshot_mode] | |
if st.button("Snapshot Selected PDFs ๐ธ", key="snapshot_selected_pdfs"): | |
selected_pdfs = [ | |
path for path in get_gallery_files(['pdf']) # Only get PDFs | |
if st.session_state['asset_checkboxes'].get(path, False) | |
] | |
if not selected_pdfs: | |
st.warning("No PDFs selected in the sidebar gallery! Tick the 'Select' box for PDFs you want to snapshot.") | |
else: | |
st.info(f"Starting snapshot process for {len(selected_pdfs)} selected PDF(s)...") | |
snapshot_count = 0 | |
total_snapshots_generated = 0 | |
for pdf_path in selected_pdfs: | |
if not os.path.exists(pdf_path): | |
st.warning(f"File not found: {pdf_path}. Skipping.") | |
continue | |
# Run the async snapshot function | |
# Need to run asyncio event loop properly in Streamlit | |
new_snapshots = asyncio.run(process_pdf_snapshot(pdf_path, mode_key, resolution)) | |
if new_snapshots: | |
snapshot_count += 1 | |
total_snapshots_generated += len(new_snapshots) | |
# Display the generated snapshots | |
st.write(f"Snapshots for {os.path.basename(pdf_path)}:") | |
cols = st.columns(3) | |
for i, snap_path in enumerate(new_snapshots): | |
with cols[i % 3]: | |
st.image(Image.open(snap_path), caption=os.path.basename(snap_path), use_container_width=True) | |
st.session_state['asset_checkboxes'][snap_path] = False # Add to gallery, unselected | |
if total_snapshots_generated > 0: | |
st.success(f"Generated {total_snapshots_generated} snapshots from {snapshot_count} PDFs.") | |
update_gallery() # Refresh sidebar | |
st.rerun() | |
else: | |
st.warning("No snapshots were generated. Check logs or PDF files.") | |
# --- Tab: PDF Process --- | |
with tabs[3]: | |
st.header("PDF Process with GPT ๐") | |
st.markdown("Upload PDFs, view pages, and extract text using GPT vision models.") | |
if not _openai_available: | |
st.error("OpenAI features are disabled. Cannot process PDFs with GPT.") | |
else: | |
gpt_models = ["gpt-4o", "gpt-4o-mini"] # Add more if needed | |
selected_gpt_model = st.selectbox("Select GPT Model", gpt_models, key="pdf_process_gpt_model") | |
detail_level = st.selectbox("Image Detail Level for GPT", ["auto", "low", "high"], key="pdf_process_detail_level", help="Affects how GPT 'sees' the image. 'high' costs more.") | |
uploaded_pdfs_process = st.file_uploader("Upload PDF files to process", type=["pdf"], accept_multiple_files=True, key="pdf_process_uploader") | |
if uploaded_pdfs_process: | |
process_button = st.button("Process Uploaded PDFs with GPT", key="process_uploaded_pdfs_gpt") | |
if process_button: | |
combined_text_output = f"# GPT ({selected_gpt_model}) PDF Processing Results\n\n" | |
total_pages_processed = 0 | |
output_placeholder = st.container() # Container for dynamic updates | |
for pdf_file in uploaded_pdfs_process: | |
output_placeholder.markdown(f"--- \n### Processing: {pdf_file.name}") | |
pdf_bytes = pdf_file.read() | |
temp_pdf_path = f"temp_process_{pdf_file.name}" | |
# Save temporary file | |
with open(temp_pdf_path, "wb") as f: f.write(pdf_bytes) | |
try: | |
doc = fitz.open(temp_pdf_path) | |
num_pages = len(doc) | |
output_placeholder.info(f"Found {num_pages} pages. Processing with {selected_gpt_model}...") | |
doc_text = f"## File: {pdf_file.name}\n\n" | |
for i, page in enumerate(doc): | |
page_start_time = time.time() | |
page_placeholder = output_placeholder.empty() | |
page_placeholder.info(f"Processing Page {i + 1}/{num_pages}...") | |
# Generate image from page | |
pix = page.get_pixmap(matrix=fitz.Matrix(2.0, 2.0)) # Standard high-res for OCR | |
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) | |
# Display image being processed | |
# cols = output_placeholder.columns(2) | |
# cols[0].image(img, caption=f"Page {i+1}", use_container_width=True) | |
# Process with GPT | |
prompt_pdf = "Extract all text content visible on this page. Maintain formatting like paragraphs and lists if possible." | |
gpt_text = process_image_with_prompt(img, prompt_pdf, model=selected_gpt_model, detail=detail_level) | |
doc_text += f"### Page {i + 1}\n\n{gpt_text}\n\n---\n\n" | |
total_pages_processed += 1 | |
elapsed_page = int(time.time() - page_start_time) | |
page_placeholder.success(f"Page {i + 1}/{num_pages} processed in {elapsed_page}s.") | |
# cols[1].text_area(f"GPT Output (Page {i+1})", gpt_text, height=200, key=f"pdf_gpt_out_{pdf_file.name}_{i}") | |
combined_text_output += doc_text | |
doc.close() | |
except (fitz.fitz.FileDataError, fitz.fitz.RuntimeException) as pdf_err: | |
output_placeholder.error(f"Error opening PDF {pdf_file.name}: {pdf_err}. Skipping.") | |
logger.warning(f"Error opening PDF {pdf_file.name}: {pdf_err}") | |
except Exception as e: | |
output_placeholder.error(f"Error processing {pdf_file.name}: {str(e)}") | |
logger.error(f"Error processing PDF {pdf_file.name}: {e}") | |
finally: | |
# Clean up temporary file | |
if os.path.exists(temp_pdf_path): | |
try: os.remove(temp_pdf_path) | |
except OSError: pass | |
if total_pages_processed > 0: | |
st.markdown("--- \n### Combined Processing Results") | |
st.markdown(f"Processed a total of {total_pages_processed} pages.") | |
st.text_area("Full GPT Output", combined_text_output, height=400, key="combined_pdf_gpt_output") | |
# Save combined output to a file | |
output_filename = generate_filename("gpt_processed_pdfs", "md") | |
try: | |
with open(output_filename, "w", encoding="utf-8") as f: | |
f.write(combined_text_output) | |
st.success(f"Combined output saved to {output_filename}") | |
st.markdown(get_download_link(output_filename, "text/markdown", "Download Combined MD"), unsafe_allow_html=True) | |
# Add to gallery automatically | |
st.session_state['asset_checkboxes'][output_filename] = False | |
update_gallery() | |
except IOError as e: | |
st.error(f"Failed to save combined output file: {e}") | |
logger.error(f"Failed to save {output_filename}: {e}") | |
else: | |
st.warning("No pages were processed.") | |
# --- Tab: Image Process --- | |
with tabs[4]: | |
st.header("Image Process with GPT ๐ผ๏ธ") | |
st.markdown("Upload images and process them using custom prompts with GPT vision models.") | |
if not _openai_available: | |
st.error("OpenAI features are disabled. Cannot process images with GPT.") | |
else: | |
gpt_models_img = ["gpt-4o", "gpt-4o-mini"] | |
selected_gpt_model_img = st.selectbox("Select GPT Model", gpt_models_img, key="img_process_gpt_model") | |
detail_level_img = st.selectbox("Image Detail Level", ["auto", "low", "high"], key="img_process_detail_level") | |
prompt_img_process = st.text_area( | |
"Enter prompt for image processing", | |
"Describe this image in detail. What is happening? What objects are present?", | |
key="img_process_prompt_area" | |
) | |
uploaded_images_process = st.file_uploader( | |
"Upload image files to process", type=["png", "jpg", "jpeg"], | |
accept_multiple_files=True, key="image_process_uploader" | |
) | |
if uploaded_images_process: | |
process_img_button = st.button("Process Uploaded Images with GPT", key="process_uploaded_images_gpt") | |
if process_img_button: | |
combined_img_text_output = f"# GPT ({selected_gpt_model_img}) Image Processing Results\n\n**Prompt:** {prompt_img_process}\n\n---\n\n" | |
images_processed_count = 0 | |
output_img_placeholder = st.container() | |
for img_file in uploaded_images_process: | |
output_img_placeholder.markdown(f"### Processing: {img_file.name}") | |
img_placeholder = output_img_placeholder.empty() | |
try: | |
img = Image.open(img_file) | |
cols_img = output_img_placeholder.columns(2) | |
cols_img[0].image(img, caption=f"Input: {img_file.name}", use_container_width=True) | |
# Process with GPT | |
gpt_img_text = process_image_with_prompt(img, prompt_img_process, model=selected_gpt_model_img, detail=detail_level_img) | |
cols_img[1].text_area(f"GPT Output", gpt_img_text, height=300, key=f"img_gpt_out_{img_file.name}") | |
combined_img_text_output += f"## Image: {img_file.name}\n\n{gpt_img_text}\n\n---\n\n" | |
images_processed_count += 1 | |
output_img_placeholder.success(f"Processed {img_file.name}.") | |
except UnidentifiedImageError: | |
output_img_placeholder.error(f"Cannot identify image file: {img_file.name}. Skipping.") | |
logger.warning(f"Cannot identify image file {img_file.name}") | |
except Exception as e: | |
output_img_placeholder.error(f"Error processing image {img_file.name}: {str(e)}") | |
logger.error(f"Error processing image {img_file.name}: {e}") | |
if images_processed_count > 0: | |
st.markdown("--- \n### Combined Image Processing Results") | |
st.markdown(f"Processed a total of {images_processed_count} images.") | |
st.text_area("Full GPT Output (Images)", combined_img_text_output, height=400, key="combined_img_gpt_output") | |
# Save combined output | |
output_filename_img = generate_filename("gpt_processed_images", "md") | |
try: | |
with open(output_filename_img, "w", encoding="utf-8") as f: | |
f.write(combined_img_text_output) | |
st.success(f"Combined image processing output saved to {output_filename_img}") | |
st.markdown(get_download_link(output_filename_img, "text/markdown", "Download Combined MD"), unsafe_allow_html=True) | |
st.session_state['asset_checkboxes'][output_filename_img] = False | |
update_gallery() | |
except IOError as e: | |
st.error(f"Failed to save combined image output file: {e}") | |
logger.error(f"Failed to save {output_filename_img}: {e}") | |
else: | |
st.warning("No images were processed.") | |
# --- Tab: Test OCR --- | |
with tabs[5]: | |
st.header("Test OCR with GPT-4o ๐") | |
st.markdown("Select an image or PDF from the gallery and run GPT-4o OCR.") | |
if not _openai_available: | |
st.error("OpenAI features are disabled. Cannot perform OCR.") | |
else: | |
gallery_files_ocr = get_gallery_files(['png', 'jpg', 'jpeg', 'pdf']) | |
if not gallery_files_ocr: | |
st.warning("No images or PDFs in the gallery. Use Camera Snap or Download PDFs first.") | |
else: | |
selected_file_ocr = st.selectbox( | |
"Select Image or PDF from Gallery for OCR", | |
options=[""] + gallery_files_ocr, # Add empty option | |
format_func=lambda x: os.path.basename(x) if x else "Select a file...", | |
key="ocr_select_file" | |
) | |
if selected_file_ocr: | |
st.write(f"Selected: {os.path.basename(selected_file_ocr)}") | |
file_ext_ocr = os.path.splitext(selected_file_ocr)[1].lower() | |
image_to_ocr = None | |
page_info = "" | |
try: | |
if file_ext_ocr in ['.png', '.jpg', '.jpeg']: | |
image_to_ocr = Image.open(selected_file_ocr) | |
elif file_ext_ocr == '.pdf': | |
doc = fitz.open(selected_file_ocr) | |
if len(doc) > 0: | |
# Use first page for single OCR test | |
pix = doc[0].get_pixmap(matrix=fitz.Matrix(2.0, 2.0)) # High-res for OCR | |
image_to_ocr = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) | |
page_info = " (Page 1)" | |
else: | |
st.warning("Selected PDF is empty.") | |
doc.close() | |
if image_to_ocr: | |
st.image(image_to_ocr, caption=f"Image for OCR{page_info}", use_container_width=True) | |
if st.button("Run GPT-4o OCR on this Image ๐", key="ocr_run_button"): | |
output_ocr_file = generate_filename(f"ocr_{os.path.splitext(os.path.basename(selected_file_ocr))[0]}", "txt") | |
st.session_state['processing']['ocr'] = True # Indicate processing | |
# Run async OCR function | |
ocr_result = asyncio.run(process_gpt4o_ocr(image_to_ocr, output_ocr_file)) | |
st.session_state['processing']['ocr'] = False # Clear processing flag | |
if ocr_result and not ocr_result.startswith("Error"): | |
entry = f"OCR Test: {selected_file_ocr}{page_info} -> {output_ocr_file}" | |
st.session_state['history'].append(entry) | |
st.text_area("OCR Result", ocr_result, height=300, key="ocr_result_display") | |
if len(ocr_result) > 10: # Basic check if result seems valid | |
st.success(f"OCR output saved to {output_ocr_file}") | |
st.markdown(get_download_link(output_ocr_file, "text/plain", "Download OCR Text"), unsafe_allow_html=True) | |
# Add txt file to gallery | |
st.session_state['asset_checkboxes'][output_ocr_file] = False | |
update_gallery() | |
else: | |
st.warning("OCR output seems short or empty; file may not contain useful text.") | |
if os.path.exists(output_ocr_file): os.remove(output_ocr_file) # Clean up empty file | |
else: | |
st.error(f"OCR failed. {ocr_result}") | |
if os.path.exists(output_ocr_file): os.remove(output_ocr_file) # Clean up failed file | |
# Option for multi-page PDF OCR | |
if file_ext_ocr == '.pdf': | |
if st.button("Run OCR on All Pages of PDF ๐๐", key="ocr_all_pages_button"): | |
st.info("Starting full PDF OCR... This may take time.") | |
try: | |
doc = fitz.open(selected_file_ocr) | |
num_pages_pdf = len(doc) | |
if num_pages_pdf == 0: | |
st.warning("PDF is empty.") | |
else: | |
full_text_ocr = f"# Full OCR Results for {os.path.basename(selected_file_ocr)}\n\n" | |
total_pages_ocr_processed = 0 | |
ocr_output_placeholder = st.container() | |
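# OCR each page sequentially: render at 2x zoom, send to GPT-4o, append the text to one Markdown document, and report per-page progress and timing | |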
for i in range(num_pages_pdf): | |
page_ocr_start_time = time.time() | |
page_ocr_placeholder = ocr_output_placeholder.empty() | |
page_ocr_placeholder.info(f"OCR - Processing Page {i + 1}/{num_pages_pdf}...") | |
pix_ocr = doc[i].get_pixmap(matrix=fitz.Matrix(2.0, 2.0)) | |
image_page_ocr = Image.frombytes("RGB", [pix_ocr.width, pix_ocr.height], pix_ocr.samples) | |
output_page_ocr_file = generate_filename(f"ocr_{os.path.splitext(os.path.basename(selected_file_ocr))[0]}_p{i+1}", "txt") | |
page_ocr_result = asyncio.run(process_gpt4o_ocr(image_page_ocr, output_page_ocr_file)) | |
if page_ocr_result and not page_ocr_result.startswith("Error"): | |
full_text_ocr += f"## Page {i + 1}\n\n{page_ocr_result}\n\n---\n\n" | |
entry_page = f"OCR Multi: {selected_file_ocr} Page {i + 1} -> {output_page_ocr_file}" | |
st.session_state['history'].append(entry_page) | |
# Don't add individual page txt files to gallery to avoid clutter | |
if os.path.exists(output_page_ocr_file): os.remove(output_page_ocr_file) | |
total_pages_ocr_processed += 1 | |
elapsed_ocr_page = int(time.time() - page_ocr_start_time) | |
page_ocr_placeholder.success(f"OCR - Page {i + 1}/{num_pages_pdf} done ({elapsed_ocr_page}s).") | |
else: | |
page_ocr_placeholder.error(f"OCR failed for Page {i+1}. Skipping.") | |
full_text_ocr += f"## Page {i + 1}\n\n[OCR FAILED]\n\n---\n\n" | |
if os.path.exists(output_page_ocr_file): os.remove(output_page_ocr_file) | |
doc.close() | |
if total_pages_ocr_processed > 0: | |
md_output_file_ocr = generate_filename(f"full_ocr_{os.path.splitext(os.path.basename(selected_file_ocr))[0]}", "md") | |
try: | |
with open(md_output_file_ocr, "w", encoding='utf-8') as f: | |
f.write(full_text_ocr) | |
st.success(f"Full PDF OCR complete. Combined output saved to {md_output_file_ocr}") | |
st.markdown(get_download_link(md_output_file_ocr, "text/markdown", "Download Full OCR Markdown"), unsafe_allow_html=True) | |
st.session_state['asset_checkboxes'][md_output_file_ocr] = False | |
update_gallery() | |
except IOError as e: | |
st.error(f"Failed to save combined OCR file: {e}") | |
else: | |
st.warning("No pages were successfully OCR'd from the PDF.") | |
except Exception as e: | |
st.error(f"Error during full PDF OCR: {e}") | |
logger.error(f"Full PDF OCR failed for {selected_file_ocr}: {e}") | |
except (fitz.FileDataError, RuntimeError) as pdf_err: # PyMuPDF raises FileDataError/RuntimeError for corrupt or unreadable PDFs | |
st.error(f"Cannot open PDF {os.path.basename(selected_file_ocr)}: {pdf_err}") | |
except Image.UnidentifiedImageError: # Pillow raises this when the file is not a recognizable image | |
st.error(f"Cannot identify image file: {os.path.basename(selected_file_ocr)}") | |
except FileNotFoundError: | |
st.error(f"File not found: {os.path.basename(selected_file_ocr)}. Refresh the gallery.") | |
except Exception as e: | |
st.error(f"An error occurred: {e}") | |
logger.error(f"Error in OCR tab for {selected_file_ocr}: {e}") | |
# --- Tab: MD Gallery & Process --- | |
with tabs[6]: | |
st.header("MD & Text File Gallery / GPT Processing ๐") | |
st.markdown("View, process, and combine Markdown (.md) and Text (.txt) files from the gallery using GPT.") | |
if not _openai_available: | |
st.error("OpenAI features are disabled. Cannot process text files with GPT.") | |
else: | |
gpt_models_md = ["gpt-4o", "gpt-4o-mini"] | |
selected_gpt_model_md = st.selectbox("Select GPT Model for Text Processing", gpt_models_md, key="md_process_gpt_model") | |
md_txt_files = get_gallery_files(['md', 'txt']) | |
if not md_txt_files: | |
st.warning("No Markdown (.md) or Text (.txt) files found in the gallery.") | |
else: | |
st.subheader("Individual File Processing") | |
selected_file_md = st.selectbox( | |
"Select MD/TXT File to Process", | |
options=[""] + md_txt_files, | |
format_func=lambda x: os.path.basename(x) if x else "Select a file...", | |
key="md_select_individual" | |
) | |
if selected_file_md: | |
st.write(f"Selected: {os.path.basename(selected_file_md)}") | |
try: | |
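# Read the file (ignoring undecodable bytes) and show a truncated preview before prompting GPT | |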
with open(selected_file_md, "r", encoding="utf-8", errors='ignore') as f: | |
content_md = f.read() | |
st.text_area("File Content Preview", content_md[:1000] + ("..." if len(content_md) > 1000 else ""), height=200, key="md_content_preview") | |
prompt_md_individual = st.text_area( | |
"Enter Prompt for this File", | |
"Summarize the key points of this text into a bulleted list.", | |
key="md_individual_prompt" | |
) | |
if st.button(f"Process {os.path.basename(selected_file_md)} with GPT", key=f"process_md_ind_{selected_file_md}"): | |
with st.spinner("Processing text with GPT..."): | |
result_text_md = process_text_with_prompt(content_md, prompt_md_individual, model=selected_gpt_model_md) | |
st.markdown("### GPT Processing Result") | |
st.markdown(result_text_md) # Display result as Markdown | |
# Save the result | |
output_filename_md = generate_filename(f"gpt_processed_{os.path.splitext(os.path.basename(selected_file_md))[0]}", "md") | |
try: | |
with open(output_filename_md, "w", encoding="utf-8") as f: | |
f.write(result_text_md) | |
st.success(f"Processing result saved to {output_filename_md}") | |
st.markdown(get_download_link(output_filename_md, "text/markdown", "Download Processed MD"), unsafe_allow_html=True) | |
st.session_state['asset_checkboxes'][output_filename_md] = False | |
update_gallery() | |
except IOError as e: | |
st.error(f"Failed to save processed MD file: {e}") | |
except FileNotFoundError: | |
st.error("Selected file not found. It might have been deleted.") | |
except Exception as e: | |
st.error(f"Error reading or processing file: {e}") | |
st.markdown("---") | |
st.subheader("Combine and Process Multiple Files") | |
st.write("Select MD/TXT files from the gallery to combine:") | |
selected_md_combine = {} | |
cols_md = st.columns(3) | |
for idx, md_file in enumerate(md_txt_files): | |
with cols_md[idx % 3]: | |
selected_md_combine[md_file] = st.checkbox( | |
f"{os.path.basename(md_file)}", | |
key=f"checkbox_md_combine_{md_file}" | |
) | |
prompt_md_combine = st.text_area( | |
"Enter Prompt for Combined Content", | |
"Synthesize the following texts into a cohesive summary. Identify the main themes and provide supporting details from the different sources.", | |
key="md_combine_prompt" | |
) | |
if st.button("Process Selected MD/TXT Files with GPT", key="process_combine_md"): | |
files_to_combine = [f for f, selected in selected_md_combine.items() if selected] | |
if not files_to_combine: | |
st.warning("No files selected for combination.") | |
else: | |
st.info(f"Combining {len(files_to_combine)} files...") | |
combined_content = "" | |
for md_file in files_to_combine: | |
try: | |
with open(md_file, "r", encoding="utf-8", errors='ignore') as f: | |
combined_content += f"\n\n## --- Source: {os.path.basename(md_file)} ---\n\n" + f.read() | |
except Exception as e: | |
st.error(f"Error reading {md_file}: {str(e)}. Skipping.") | |
logger.warning(f"Error reading {md_file} for combination: {e}") | |
if combined_content: | |
st.text_area("Preview Combined Content (First 2000 chars)", combined_content[:2000]+"...", height=200) | |
with st.spinner("Processing combined text with GPT..."): | |
result_text_combine = process_text_with_prompt(combined_content, prompt_md_combine, model=selected_gpt_model_md) | |
st.markdown("### Combined Processing Result") | |
st.markdown(result_text_combine) | |
# Save the combined result | |
output_filename_combine = generate_filename("gpt_combined_md_txt", "md") | |
try: | |
with open(output_filename_combine, "w", encoding="utf-8") as f: | |
f.write(f"# Combined Processing Result\n\n**Prompt:** {prompt_md_combine}\n\n**Sources:** {', '.join([os.path.basename(f) for f in files_to_combine])}\n\n---\n\n{result_text_combine}") | |
st.success(f"Combined processing result saved to {output_filename_combine}") | |
st.markdown(get_download_link(output_filename_combine, "text/markdown", "Download Combined MD"), unsafe_allow_html=True) | |
st.session_state['asset_checkboxes'][output_filename_combine] = False | |
update_gallery() | |
except IOError as e: | |
st.error(f"Failed to save combined processed file: {e}") | |
else: | |
st.error("Failed to read content from selected files.") | |
# --- Tab: Build Titan --- | |
with tabs[7]: | |
st.header("Build Titan Model ๐ฑ") | |
st.markdown("Download and save base models for Causal LM or Diffusion tasks.") | |
if not _ai_libs_available: | |
st.error("AI/ML libraries (torch, transformers, diffusers) are required for this feature.") | |
else: | |
build_model_type = st.selectbox("Model Type to Build", ["Causal LM", "Diffusion"], key="build_type_select") | |
if build_model_type == "Causal LM": | |
default_causal = "HuggingFaceTB/SmolLM-135M" #"Qwen/Qwen1.5-0.5B-Chat" is larger | |
causal_models = [default_causal, "gpt2", "distilgpt2"] # Add more small options | |
base_model_select = st.selectbox( | |
"Select Base Causal LM", causal_models, index=causal_models.index(default_causal), | |
key="causal_model_select" | |
) | |
else: # Diffusion | |
default_diffusion = "OFA-Sys/small-stable-diffusion-v0" #"stabilityai/stable-diffusion-2-base" is large | |
diffusion_models = [default_diffusion, "google/ddpm-cat-256", "google/ddpm-celebahq-256"] # Add more small options | |
base_model_select = st.selectbox( | |
"Select Base Diffusion Model", diffusion_models, index=diffusion_models.index(default_diffusion), | |
key="diffusion_model_select" | |
) | |
model_name_build = st.text_input("Local Model Name", f"{build_model_type.lower().replace(' ','')}-titan-{os.path.basename(base_model_select)}-{int(time.time()) % 10000}", key="build_model_name") | |
domain_build = st.text_input("Optional: Target Domain Tag", "general", key="build_domain") | |
if st.button(f"Download & Save {build_model_type} Model โฌ๏ธ", key="download_build_model"): | |
if not model_name_build: | |
st.error("Please provide a local model name.") | |
else: | |
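# Pair the chosen base model with the matching config/builder: ModelConfig + ModelBuilder for Causal LM, DiffusionConfig + DiffusionBuilder for Diffusion | |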
if build_model_type == "Causal LM": | |
config = ModelConfig( | |
name=model_name_build, base_model=base_model_select, size="small", domain=domain_build # Size is illustrative | |
) | |
builder = ModelBuilder() | |
else: | |
config = DiffusionConfig( | |
name=model_name_build, base_model=base_model_select, size="small", domain=domain_build | |
) | |
builder = DiffusionBuilder() | |
try: | |
builder.load_model(base_model_select, config) | |
builder.save_model(config.model_path) # Save to ./models/ or ./diffusion_models/ | |
st.session_state['builder'] = builder # Store the loaded builder instance | |
st.session_state['model_loaded'] = True | |
st.session_state['selected_model_type'] = build_model_type | |
st.session_state['selected_model'] = config.model_path # Store path to local copy | |
st.session_state['history'].append(f"Built {build_model_type} model: {model_name_build} from {base_model_select}") | |
st.success(f"{build_model_type} model downloaded from {base_model_select} and saved locally to {config.model_path}! ๐") | |
# No automatic rerun, let user proceed | |
except Exception as e: | |
st.error(f"Failed to build model: {e}") | |
logger.error(f"Failed to build model {model_name_build} from {base_model_select}: {e}") | |
# --- Tab: Test Image Gen --- | |
with tabs[8]: | |
st.header("Test Image Generation ๐จ") | |
st.markdown("Generate images using a loaded Diffusion model.") | |
if not _ai_libs_available: | |
st.error("AI/ML libraries (torch, transformers, diffusers) are required for image generation.") | |
else: | |
# Check if a diffusion model is loaded in session state or select one | |
available_diffusion_models = get_model_files("diffusion") | |
loaded_diffusion_model_path = None | |
# Check if the currently loaded builder is diffusion | |
current_builder = st.session_state.get('builder') | |
if current_builder and isinstance(current_builder, DiffusionBuilder) and current_builder.pipeline: | |
loaded_diffusion_model_path = current_builder.config.model_path if current_builder.config else "Loaded Model" | |
# Prepare options for selection, prioritizing loaded model | |
model_options = ["Load Default Small Model"] + available_diffusion_models | |
current_selection_index = 0 # Default to loading small model | |
if loaded_diffusion_model_path and loaded_diffusion_model_path != "Loaded Model": | |
if loaded_diffusion_model_path not in model_options: | |
model_options.insert(1, loaded_diffusion_model_path) # Add if not already listed | |
current_selection_index = model_options.index(loaded_diffusion_model_path) | |
elif loaded_diffusion_model_path == "Loaded Model": | |
# A model is loaded, but we don't have its path (e.g., loaded directly) | |
model_options.insert(1, "Currently Loaded Model") | |
current_selection_index = 1 | |
selected_diffusion_model = st.selectbox( | |
"Select Diffusion Model for Generation", | |
options=model_options, | |
index=current_selection_index, | |
key="imggen_model_select", | |
help="Select a locally saved model, or load the default small one." | |
) | |
# Button to explicitly load the selected model if it's not the active one | |
load_needed = False | |
if selected_diffusion_model == "Load Default Small Model": | |
load_needed = not (current_builder and isinstance(current_builder, DiffusionBuilder) and current_builder.config and current_builder.config.base_model == "OFA-Sys/small-stable-diffusion-v0") | |
elif selected_diffusion_model == "Currently Loaded Model": | |
load_needed = False # Already loaded | |
else: # A specific path is selected | |
load_needed = not (current_builder and isinstance(current_builder, DiffusionBuilder) and current_builder.config and current_builder.config.model_path == selected_diffusion_model) | |
if load_needed: | |
if st.button(f"Load '{os.path.basename(selected_diffusion_model)}' Model", key="imggen_load_sel"): | |
try: | |
if selected_diffusion_model == "Load Default Small Model": | |
model_to_load = "OFA-Sys/small-stable-diffusion-v0" | |
config = DiffusionConfig(name="default-small", base_model=model_to_load, size="small") | |
builder = DiffusionBuilder().load_model(model_to_load, config) | |
st.session_state['builder'] = builder | |
st.session_state['model_loaded'] = True | |
st.session_state['selected_model_type'] = "Diffusion" | |
st.session_state['selected_model'] = config.model_path # This isn't saved, just track base | |
st.success("Default small diffusion model loaded.") | |
st.rerun() | |
else: # Load from local path | |
config = DiffusionConfig(name=os.path.basename(selected_diffusion_model), base_model="local", size="unknown", model_path=selected_diffusion_model) | |
builder = DiffusionBuilder().load_model(selected_diffusion_model, config) | |
st.session_state['builder'] = builder | |
st.session_state['model_loaded'] = True | |
st.session_state['selected_model_type'] = "Diffusion" | |
st.session_state['selected_model'] = config.model_path | |
st.success(f"Loaded local model: {config.name}") | |
st.rerun() | |
except Exception as e: | |
st.error(f"Failed to load model {selected_diffusion_model}: {e}") | |
logger.error(f"Failed loading diffusion model {selected_diffusion_model}: {e}") | |
# Image Generation Prompt | |
prompt_imggen = st.text_area("Image Generation Prompt", "A futuristic cityscape at sunset, neon lights, flying cars", key="imggen_prompt") | |
if st.button("Generate Image ๐", key="imggen_run_button"): | |
# Check again if a model is effectively loaded and ready | |
current_builder = st.session_state.get('builder') | |
if not (current_builder and isinstance(current_builder, DiffusionBuilder) and current_builder.pipeline): | |
st.error("No diffusion model is loaded. Please select and load a model first.") | |
elif not prompt_imggen: | |
st.warning("Please enter a prompt.") | |
else: | |
output_imggen_file = generate_filename("image_gen", "png") | |
st.session_state['processing']['gen'] = True | |
# Run async generation | |
generated_image = asyncio.run(process_image_gen(prompt_imggen, output_imggen_file)) | |
st.session_state['processing']['gen'] = False | |
if generated_image and os.path.exists(output_imggen_file): | |
entry = f"Image Gen: '{prompt_imggen[:30]}...' -> {output_imggen_file}" | |
st.session_state['history'].append(entry) | |
st.image(generated_image, caption=f"Generated: {os.path.basename(output_imggen_file)}", use_container_width=True) | |
st.success(f"Image saved to {output_imggen_file}") | |
st.markdown(get_download_link(output_imggen_file, "image/png", "Download Generated Image"), unsafe_allow_html=True) | |
# Add to gallery | |
st.session_state['asset_checkboxes'][output_imggen_file] = False | |
update_gallery() | |
# Consider st.rerun() if immediate gallery update is critical | |
else: | |
st.error("Image generation failed. Check logs.") | |
# --- Tab: Character Editor --- | |
with tabs[9]: | |
st.header("Character Editor ๐งโ๐จ") | |
st.subheader("Create or Modify Your Character") | |
# Load existing characters for potential editing (optional) | |
load_characters() | |
existing_char_names = [c['name'] for c in st.session_state.get('characters', [])] | |
# Use a unique key for the form to allow reset | |
form_key = f"character_form_{st.session_state.get('char_form_reset_key', 0)}" | |
with st.form(key=form_key): | |
st.markdown("**Create New Character**") | |
# Randomize button inside the form | |
if st.form_submit_button("Randomize Content ๐ฒ"): | |
# Increment key to force form reset with new random values | |
st.session_state['char_form_reset_key'] = st.session_state.get('char_form_reset_key', 0) + 1 | |
st.rerun() # Rerun to get new random defaults in the reset form | |
# Get random defaults only once per form rendering cycle unless reset | |
rand_name, rand_gender, rand_intro, rand_greeting = randomize_character_content() | |
name_char = st.text_input( | |
"Name (3-25 chars, letters, numbers, underscore, hyphen, space)", | |
value=rand_name, max_chars=25, key="char_name_input" | |
) | |
gender_char = st.radio( | |
"Gender", ["Male", "Female"], index=["Male", "Female"].index(rand_gender), | |
key="char_gender_radio" | |
) | |
intro_char = st.text_area( | |
"Intro (Public description)", value=rand_intro, max_chars=300, height=100, | |
key="char_intro_area" | |
) | |
greeting_char = st.text_area( | |
"Greeting (First message)", value=rand_greeting, max_chars=300, height=100, | |
key="char_greeting_area" | |
) | |
tags_char = st.text_input("Tags (comma-separated)", "OC, friendly", key="char_tags_input") | |
submitted = st.form_submit_button("Create Character ✨") | |
if submitted: | |
# Validation | |
error = False | |
if not (3 <= len(name_char) <= 25): | |
st.error("Name must be between 3 and 25 characters.") | |
error = True | |
if not re.match(r'^[a-zA-Z0-9 _-]+$', name_char): | |
st.error("Name contains invalid characters.") | |
error = True | |
if name_char in existing_char_names: | |
st.error(f"Character name '{name_char}' already exists!") | |
error = True | |
if not intro_char or not greeting_char: | |
st.error("Intro and Greeting cannot be empty.") | |
error = True | |
if not error: | |
tag_list = [tag.strip() for tag in tags_char.split(',') if tag.strip()] | |
character_data = { | |
"name": name_char, | |
"gender": gender_char, | |
"intro": intro_char, | |
"greeting": greeting_char, | |
"created_at": datetime.now(pytz.timezone("US/Central")).strftime('%Y-%m-%d %H:%M:%S %Z'), # Added timezone | |
"tags": tag_list | |
} | |
if save_character(character_data): | |
st.success(f"Character '{name_char}' created successfully!") | |
# Increment key to reset form for next creation | |
st.session_state['char_form_reset_key'] = st.session_state.get('char_form_reset_key', 0) + 1 | |
st.rerun() # Rerun to clear form and update gallery tab | |
# --- Tab: Character Gallery --- | |
with tabs[10]: | |
st.header("Character Gallery ๐ผ๏ธ") | |
# Load characters every time the tab is viewed | |
load_characters() | |
characters_list = st.session_state.get('characters', []) | |
if not characters_list: | |
st.warning("No characters created yet. Use the Character Editor tab!") | |
else: | |
st.subheader(f"Your Characters ({len(characters_list)})") | |
st.markdown("View and manage your created characters.") | |
# Search/Filter (Optional Enhancement) | |
search_term = st.text_input("Search Characters by Name", key="char_gallery_search") | |
if search_term: | |
characters_list = [c for c in characters_list if search_term.lower() in c['name'].lower()] | |
cols_char_gallery = st.columns(3) # Adjust number of columns as needed | |
chars_to_delete = [] # Store names to delete after iteration | |
for idx, char in enumerate(characters_list): | |
with cols_char_gallery[idx % 3]: | |
with st.container(border=True): # Add border to each character card | |
st.markdown(f"**{char['name']}**") | |
st.caption(f"Gender: {char.get('gender', 'N/A')}") # Use .get for safety | |
st.markdown("**Intro:**") | |
st.markdown(f"> {char.get('intro', '')}") # Blockquote style | |
st.markdown("**Greeting:**") | |
st.markdown(f"> {char.get('greeting', '')}") | |
st.caption(f"Tags: {', '.join(char.get('tags', ['N/A']))}") | |
st.caption(f"Created: {char.get('created_at', 'N/A')}") | |
# Delete Button | |
delete_key_char = f"delete_char_{char['name']}_{idx}" # More unique key | |
if st.button(f"Delete {char['name']}", key=delete_key_char, type="primary"): | |
chars_to_delete.append(char['name']) # Mark for deletion | |
# Process deletions after iterating | |
if chars_to_delete: | |
current_characters = st.session_state.get('characters', []) | |
updated_characters = [c for c in current_characters if c['name'] not in chars_to_delete] | |
st.session_state['characters'] = updated_characters | |
try: | |
with open("characters.json", "w", encoding='utf-8') as f: | |
json.dump(updated_characters, f, indent=2) | |
logger.info(f"Deleted characters: {', '.join(chars_to_delete)}") | |
st.success(f"Deleted characters: {', '.join(chars_to_delete)}") | |
st.rerun() # Rerun to reflect changes | |
except IOError as e: | |
logger.error(f"Failed to save characters.json after deletion: {e}") | |
st.error("Failed to update character file after deletion.") | |
# --- Footer and Persistent Sidebar Elements ------------ | |
# Update Sidebar Gallery (Call this at the end to reflect all changes) | |
update_gallery() | |
# Action Logs in Sidebar | |
st.sidebar.subheader("Action Logs ๐") | |
log_expander = st.sidebar.expander("View Logs", expanded=False) | |
with log_expander: | |
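# Format raw LogRecords by hand: asctime/message are only populated by a logging.Formatter, so use record.created and record.getMessage() instead | |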
log_text = "\n".join([f"{datetime.fromtimestamp(record.created).strftime('%H:%M:%S')} - {record.levelname} - {record.getMessage()}" for record in log_records[-20:]]) # Show last 20 logs | |
st.code(log_text, language='log') | |
# History in Sidebar | |
st.sidebar.subheader("Session History ๐") | |
history_expander = st.sidebar.expander("View History", expanded=False) | |
with history_expander: | |
# Display history in reverse chronological order | |
for entry in reversed(st.session_state.get("history", [])): | |
if entry: st.write(f"- {entry}") | |
st.sidebar.markdown("---") | |
st.sidebar.info("App combines Image Layout PDF generation with AI Vision/SFT tools.") | |
st.sidebar.caption("Combined App by AI Assistant for User") |