# jdavis's picture
# Update app.py
# 0e1643e verified
import os
import sys

# Set critical environment variables first, before any library that reads
# them at import time is loaded.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
os.environ["WATCHDOG_OPTIONAL"] = "1"
os.environ["PYTORCH_JIT"] = "0"
# Configure the Hugging Face cache location. huggingface_hub computes its
# cache paths from HF_HOME when it is first imported, so this must be set
# before huggingface_hub (or diffusers, which imports it) is loaded —
# setting it afterwards has no effect on where models are cached.
os.environ["HF_HOME"] = os.path.join(os.getcwd(), ".cache/huggingface")

# Import third party modules
import streamlit as st
import numpy as np
import random
from PIL import Image
import io
import time

# Streamlit has historically required st.set_page_config to be the first
# Streamlit command of a run, so it is issued before anything below (e.g.
# the import-error report) can render an element.
st.set_page_config(
    page_title="FLUX.1 Fill [dev]",
    layout="wide"
)

# Import what we can from huggingface_hub; report a failure in the UI
# instead of crashing the app.
try:
    from huggingface_hub import HfApi, HfFolder, login
except ImportError as e:
    st.error(f"Error importing from huggingface_hub: {e}")

# Import PyTorch / diffusers after environment setup
import torch
from diffusers import FluxFillPipeline
import warnings
warnings.filterwarnings("ignore", message=".*add_prefix_space.*")

# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
# Title and description. The [model] link points at the Fill checkpoint this
# app actually loads (it previously linked the base FLUX.1-dev model).
st.markdown("""
# FLUX.1 Fill [dev]
12B param rectified flow transformer structural conditioning tuned, guidance-distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/)
[[non-commercial license](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md)] [[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev)]
""")
# Sidebar setup instructions: the model repository requires approved access,
# and the app expects a token in an environment variable such as HF_TOKEN
# (one of the names checked by get_hf_token()).
st.sidebar.markdown("""
## Important Setup Information
This app uses the FLUX.1-Fill-dev model which requires special access:
1. Sign up/login at [Hugging Face](https://huggingface.co/)
2. Request access to [FLUX.1-Fill-dev](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) by clicking 'Access repository'
3. Wait for approval from model owners
### For Hugging Face Spaces Setup:
1. Go to your Space settings > Secrets
2. Add a new secret with the name `HF_TOKEN`
3. Set its value to your Hugging Face API token (found in your account settings)
""")
def get_hf_token():
    """Return the first non-blank Hugging Face token found in the environment.

    Several conventional variable names are probed in a fixed priority
    order. The sidebar reports which variable supplied the token, or shows a
    warning when none of them holds a non-whitespace value.

    Returns:
        The token with surrounding whitespace stripped, or ``None``.
    """
    candidate_names = (
        'HF_TOKEN',
        'HUGGINGFACE_TOKEN',
        'HUGGING_FACE_HUB_TOKEN',
        'HF_API_TOKEN',
        'HUGGINGFACE_API_TOKEN',
        'HUGGINGFACE_HUB_TOKEN',
    )
    for name in candidate_names:
        # Unset and whitespace-only variables both collapse to "" here.
        value = os.environ.get(name, "").strip()
        if value:
            st.sidebar.success(f"Found token in {name}")
            return value
    st.sidebar.warning("No Hugging Face token found in environment variables")
    return None
# Help text shown when loading fails with what looks like an authentication /
# authorization problem for the gated model repository.
_ACCESS_DENIED_HELP = """
Access Denied: You need to:
1. Request access to the model at https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev
2. Set up your Hugging Face token in Spaces:
- Go to your Space settings > Secrets
- Add a new secret with name 'HF_TOKEN'
- Set its value to your Hugging Face API token
3. Wait for approval from model owners
Note: You can find your token at https://huggingface.co/settings/tokens
"""

# Help text shown for PyTorch "Tried to instantiate class" errors.
_TORCH_CLASS_INIT_HELP = """
PyTorch class initialization error. Try restarting the app.
If the error persists, try accessing the app from a different browser.
"""


def _looks_like_access_error(message):
    """Return True if an error message suggests an auth/permission failure."""
    lowered = message.lower()
    return "401" in message or "access" in lowered or "denied" in lowered


@st.cache_resource(show_spinner=False)
def load_model():
    """Load the FLUX.1-Fill-dev pipeline and move it to the best device.

    The pipeline is fetched from ``black-forest-labs/FLUX.1-Fill-dev``
    (revision ``main``) in bfloat16, authenticated with the token returned
    by ``get_hf_token()``, then moved to CUDA when available, else CPU.
    ``st.cache_resource`` caches the result so reruns reuse the pipeline.

    Returns:
        The loaded ``FluxFillPipeline`` placed on the selected device.

    On any failure the error is shown (with targeted help for access
    problems and PyTorch class-initialization errors) and ``st.stop()`` is
    called to end the script run.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    st.info(f"Using device: {device}")

    token = get_hf_token()
    st.info(f"Token available: {'Yes' if token else 'No'}")

    try:
        model = FluxFillPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-Fill-dev",
            token=token,
            torch_dtype=torch.bfloat16,
            revision="main"
        )
        st.success("Model loaded successfully!")
        return model.to(device)
    except Exception as e:
        # Single handler: previously a second, identical `except Exception`
        # clause followed this one, so its PyTorch-specific hint never ran.
        message = str(e)
        st.error(f"Failed to load model: {message}")
        if _looks_like_access_error(message):
            st.error(_ACCESS_DENIED_HELP)
        elif "Tried to instantiate class" in message:
            st.error(_TORCH_CLASS_INIT_HELP)
        st.stop()
# Obtain the (cached) inpainting pipeline while a spinner is displayed; if it
# cannot be obtained, surface the error and halt this script run so the UI
# below never executes without a model.
with st.spinner("Loading model..."):
    try:
        pipe = load_model()
        st.success("Model loaded successfully!")
    except Exception as load_err:
        st.error(f"Failed to load model: {load_err!s}")
        st.stop()
def calculate_optimal_dimensions(image: "Image.Image"):
    """Choose a generation size that follows *image*'s aspect ratio.

    The longer side is fixed at 1024 px and the shorter side is scaled to
    match the source aspect ratio, rounded down to a multiple of 8. The
    aspect ratio is first clamped to [9:16, 16:9]. Out-of-range images
    therefore get a short side of 576 px, and the long side stays at 1024.
    Previously, clamping shrank the long side, so very wide or tall images
    got float dimensions well below 576.

    Args:
        image: Source image; only its ``size`` (width, height) is read.

    Returns:
        ``(width, height)`` as ints, both multiples of 8 and within
        [576, 1024].

    Raises:
        ValueError: if the image width or height is not positive.
    """
    # Extract the original dimensions
    original_width, original_height = image.size
    if original_width <= 0 or original_height <= 0:
        raise ValueError(f"Image has invalid size {image.size!r}")

    # Set constants
    MIN_ASPECT_RATIO = 9 / 16
    MAX_ASPECT_RATIO = 16 / 9
    FIXED_DIMENSION = 1024

    # Clamp the aspect ratio up front. Because MIN < 1 < MAX, clamping never
    # changes which side is the long one.
    original_aspect_ratio = original_width / original_height
    aspect_ratio = min(max(original_aspect_ratio, MIN_ASPECT_RATIO), MAX_ASPECT_RATIO)

    # Fix the longer side and derive the shorter one from the aspect ratio.
    if aspect_ratio > 1:  # Wider than tall
        width = FIXED_DIMENSION
        height = round(FIXED_DIMENSION / aspect_ratio)
    else:  # Taller than wide (or square)
        height = FIXED_DIMENSION
        width = round(FIXED_DIMENSION * aspect_ratio)

    # Ensure dimensions are multiples of 8. The clamp keeps the short side
    # at least 1024 * 9/16 = 576, which is already a multiple of 8.
    width = (width // 8) * 8
    height = (height // 8) * 8

    return width, height
# Two equal-width columns: inputs on the left, generated result on the right.
col1, col2 = st.columns([1, 1])
with col1:
    # Upload image
    uploaded_file = st.file_uploader("Upload an image for inpainting", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # Display the uploaded image, normalized to RGB for the pipeline.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_container_width=True)
        # Simple approach to create a mask - select a rectangular area
        st.write("Select an area to inpaint:")
        img_width, img_height = image.size
        # Sliders for the rectangle corners, in source-image pixels. The
        # right/bottom sliders use the left/top value as their minimum so
        # the rectangle cannot be inverted.
        col_sliders1, col_sliders2 = st.columns(2)
        with col_sliders1:
            x1 = st.slider("Left edge (X1)", 0, img_width, img_width // 4)
            y1 = st.slider("Top edge (Y1)", 0, img_height, img_height // 4)
        with col_sliders2:
            x2 = st.slider("Right edge (X2)", x1, img_width, min(x1 + img_width // 2, img_width))
            y2 = st.slider("Bottom edge (Y2)", y1, img_height, min(y1 + img_height // 2, img_height))
        # Draw a white (255) rectangle on a black mask; white marks the area
        # to inpaint. ImageDraw is imported at module scope here because the
        # result column also uses it.
        from PIL import ImageDraw
        preview_mask = Image.new("L", image.size, 0)
        draw = ImageDraw.Draw(preview_mask)
        draw.rectangle([(x1, y1), (x2, y2)], fill=255)
        # Show the selection as a semi-transparent white overlay: blend the
        # image 50% toward white and keep those pixels only inside the mask.
        # (Image.paste ignores an RGBA source's alpha when an explicit mask is
        # given, so the previous paste painted the area solid white.)
        white = Image.new("RGB", image.size, (255, 255, 255))
        masked_preview = Image.composite(Image.blend(image, white, 0.5), image, preview_mask)
        st.image(masked_preview, caption="Area to inpaint (white overlay)", use_container_width=True)
    # Prompt input
    prompt = st.text_input("Enter your prompt")
    # Example prompts; a selected example is used only when the text box is empty.
    examples = [
        "a tiny astronaut hatching from an egg on the moon",
        "a cat holding a sign that says hello world",
        "an anime illustration of a wiener schnitzel",
    ]
    example_prompt = st.selectbox("Or select an example prompt", [""] + examples)
    if example_prompt and not prompt:
        prompt = example_prompt
    # Advanced settings with expander
    with st.expander("Advanced Settings"):
        randomize_seed = st.checkbox("Randomize seed", value=True)
        if not randomize_seed:
            seed = st.slider("Seed", 0, MAX_SEED, 0)
        else:
            # A fresh random seed is drawn on every script rerun.
            seed = random.randint(0, MAX_SEED)
        guidance_scale = st.slider("Guidance Scale", 1.0, 30.0, 3.5, 0.5)
        num_inference_steps = st.slider("Number of inference steps", 1, 50, 28)
    # Run button
    run_button = st.button("Generate")
with col2:
    if uploaded_file is not None:
        st.write("Result will appear here")
        if run_button and not prompt:
            # Previously a click with no prompt silently did nothing.
            st.warning("Please enter a prompt or select an example prompt first.")
        if run_button and prompt:
            with st.spinner("Generating image..."):
                # Rebuild the inpainting mask from the slider rectangle
                # (white = area to regenerate), matching the left-hand preview.
                mask = Image.new("L", image.size, 0)
                draw = ImageDraw.Draw(mask)
                draw.rectangle([(x1, y1), (x2, y2)], fill=255)
                # Calculate dimensions for generation
                width, height = calculate_optimal_dimensions(image)
                # Progress bar
                progress_bar = st.progress(0)
                # Generate the image
                try:
                    # Placeholders that are overwritten as generation advances.
                    progress_text = st.empty()
                    debug_info = st.empty()
                    debug_info.info(f"Model type: {pipe.__class__.__name__}")
                    progress_bar.progress(0.1)
                    progress_text.text("Preparing image and mask...")
                    # Pass the pipeline an independent single-channel ("L")
                    # mask where white (255) is the area to inpaint.
                    mask_img = mask.convert("L")
                    common_params = {
                        "prompt": prompt,
                        "image": image,
                        "mask_image": mask_img,
                        "num_inference_steps": num_inference_steps,
                        "generator": torch.Generator("cpu").manual_seed(seed),
                        "guidance_scale": guidance_scale,
                    }
                    try:
                        progress_text.text("Running generation...")
                        progress_bar.progress(0.2)
                        # First try with the computed dimensions
                        common_params["height"] = int(height)
                        common_params["width"] = int(width)
                        result = pipe(**common_params)
                    except Exception as e:
                        debug_info.warning(f"First attempt failed: {str(e)}")
                        progress_text.text("Retrying with adjusted parameters...")
                        # Drop the explicit size and let the pipeline choose
                        # one; pop() tolerates a failure before the keys were set.
                        common_params.pop("height", None)
                        common_params.pop("width", None)
                        # The failed attempt may have advanced the generator's
                        # RNG state; reseed so the shown seed reproduces the result.
                        common_params["generator"] = torch.Generator("cpu").manual_seed(seed)
                        result = pipe(**common_params)
                    # Get the result image
                    result_image = result.images[0]
                    # Update final progress
                    progress_bar.progress(1.0)
                    progress_text.text("Complete!")
                    debug_info.empty()  # Clear debug info
                    # Display the result (use_container_width, as in the input column)
                    st.image(result_image, caption="Generated Result", use_container_width=True)
                    # Offer the result as an in-memory PNG download
                    buf = io.BytesIO()
                    result_image.save(buf, format="PNG")
                    st.download_button(
                        label="Download result",
                        data=buf.getvalue(),
                        file_name="inpaint_result.png",
                        mime="image/png",
                    )
                    # Display used seed
                    st.write(f"Seed used: {seed}")
                except Exception as e:
                    st.error(f"An error occurred during generation: {str(e)}")
                    st.error("Try adjusting the parameters or using a different image.")
# Nothing uploaded yet: put a hint in the result column (container methods on
# col2 render into that column, like writing inside `with col2:`).
if uploaded_file is None:
    col2.write("Please upload an image first")