# fluxhdupscaler / app.py
import random
import warnings
import os
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxImg2ImgPipeline
from transformers import AutoProcessor, AutoModelForCausalLM
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
import requests
# For ESRGAN (requires `pip install basicsr`)
try:
    from basicsr.archs.rrdbnet_arch import RRDBNet
    from basicsr.utils import img2tensor, tensor2img
    USE_ESRGAN = True
except ImportError:
    USE_ESRGAN = False
    warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
.main-header {
text-align: center;
margin-bottom: 2rem;
}
"""
# Device setup
if torch.cuda.is_available():
    power_device = "GPU"
    device = "cuda"
else:
    power_device = "CPU"
    device = "cpu"
# Get the Hugging Face token (FLUX.1-dev is a gated repo)
huggingface_token = os.getenv("HF_TOKEN")

# Download FLUX model
print("📥 Downloading FLUX model...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
# Load Florence-2 model for image captioning
print("📥 Loading Florence-2 model...")
florence_dtype = torch.float16 if device == "cuda" else torch.float32  # fp16 inference needs CUDA
florence_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=florence_dtype,
    trust_remote_code=True,
    attn_implementation="eager",  # Fix for SDPA compatibility issue
).to(device)
florence_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large",
    trust_remote_code=True,
)
# Load FLUX Img2Img pipeline
print("📥 Loading FLUX Img2Img...")
pipe = FluxImg2ImgPipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
)
pipe.to(device)
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
print("✅ All models loaded successfully!")
# Download the ESRGAN weights if basicsr is available
if USE_ESRGAN:
    esrgan_path = "4x-UltraSharp.pth"
    if not os.path.exists(esrgan_path):
        url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
        with open(esrgan_path, "wb") as f:
            f.write(requests.get(url, timeout=60).content)
    esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    # Some checkpoints wrap the weights in 'params_ema'; fall back to the raw state dict
    state_dict = torch.load(esrgan_path, map_location="cpu")
    state_dict = state_dict.get("params_ema", state_dict)
    esrgan_model.load_state_dict(state_dict)
    esrgan_model.eval()
    esrgan_model.to(device)
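# Shape sanity check (illustrative comment only): with scale=4 the RRDBNet above
# maps a (1, 3, H, W) float tensor in [0, 1] to (1, 3, 4*H, 4*W), so a 256x256
# crop comes back as 1024x1024.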
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
def generate_caption(image):
    """Generate a detailed caption using Florence-2."""
    try:
        task_prompt = "<MORE_DETAILED_CAPTION>"
        # Cast float inputs to the model dtype, or fp16 models reject fp32 pixel_values
        inputs = florence_processor(text=task_prompt, images=image, return_tensors="pt").to(device, florence_dtype)
        generated_ids = florence_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=True,
        )
        generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = florence_processor.post_process_generation(
            generated_text, task=task_prompt, image_size=(image.width, image.height)
        )
        return parsed_answer[task_prompt]
    except Exception as e:
        print(f"Caption generation failed: {e}")
        return "a high quality detailed image"
def process_input(input_image, upscale_factor):
    """Resize the input if the requested output would exceed the pixel budget."""
    w, h = input_image.size
    w_original, h_original = w, h
    was_resized = False
    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        warnings.warn(
            f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
        )
        gr.Info("Requested output image is too large. Resizing input to fit within the pixel budget.")
        target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
        scale = (target_input_pixels / (w * h)) ** 0.5
        # Round each side down to a multiple of 8, as the diffusion pipeline expects
        new_w = int(w * scale) - int(w * scale) % 8
        new_h = int(h * scale) - int(h * scale) % 8
        input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
        was_resized = True
    return input_image, w_original, h_original, was_resized
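# Worked example (hypothetical input): a 6000x4000 image at 4x would need
# 6000 * 4000 * 4**2 = 384 MP, over the 8192 * 8192 ≈ 67.1 MP budget, so
# process_input scales it by sqrt((67.1e6 / 16) / 24e6) ≈ 0.418, i.e. to
# roughly 2504x1672 (each side rounded down to a multiple of 8).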
def load_image_from_url(url):
    """Load an image from a URL."""
    try:
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        # Convert to RGB so palette/RGBA images don't break the pipelines downstream
        return Image.open(response.raw).convert("RGB")
    except Exception as e:
        raise gr.Error(f"Failed to load image from URL: {e}")
def esrgan_upscale(image, scale=4):
    """4x ESRGAN upscale; falls back to LANCZOS when basicsr is unavailable."""
    if not USE_ESRGAN:
        return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
    img = img2tensor(np.array(image) / 255., bgr2rgb=False, float32=True)
    with torch.no_grad():
        # Move the input to the same device as the model
        output = esrgan_model(img.unsqueeze(0).to(device)).squeeze()
    output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
    return Image.fromarray(output_img)
def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
    """Tiled Img2Img to mimic Ultimate SD Upscaler tiling."""
    w, h = image.size
    output = image.copy()  # Start with the control image
    for x in range(0, w, tile_size - overlap):
        for y in range(0, h, tile_size - overlap):
            tile_w = min(tile_size, w - x)
            tile_h = min(tile_size, h - y)
            tile = image.crop((x, y, x + tile_w, y + tile_h))
            # Run FLUX on the tile
            gen_tile = pipe(
                prompt=prompt,
                image=tile,
                strength=strength,
                num_inference_steps=steps,
                guidance_scale=guidance,
                height=tile_h,
                width=tile_w,
                generator=generator,
            ).images[0]
            # The pipeline may round dimensions internally; snap back so the paste boxes match
            if gen_tile.size != (tile_w, tile_h):
                gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
            # Paste with a linear alpha ramp across the left/top overlap regions
            if overlap > 0 and (x > 0 or y > 0):
                mask_arr = np.full((tile_h, tile_w), 255, dtype=np.float32)
                if x > 0:
                    ow = min(overlap, tile_w)
                    ramp = 255.0 * np.arange(ow) / overlap
                    mask_arr[:, :ow] = np.minimum(mask_arr[:, :ow], ramp[np.newaxis, :])
                if y > 0:
                    oh = min(overlap, tile_h)
                    ramp = 255.0 * np.arange(oh) / overlap
                    mask_arr[:oh, :] = np.minimum(mask_arr[:oh, :], ramp[:, np.newaxis])
                mask = Image.fromarray(mask_arr.astype(np.uint8), mode="L")
                output.paste(gen_tile, (x, y, x + tile_w, y + tile_h), mask)
            else:
                output.paste(gen_tile, (x, y))
    return output
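# Illustrative helper (not called by the app): the offsets the tiler visits along
# one axis. With the defaults tile_size=1024, overlap=32 the stride is 992, so a
# 2048 px side yields starts [0, 992, 1984], and the final tile is cropped to the
# remaining 64 px at the image edge.
def _tile_starts(length, tile_size=1024, overlap=32):
    return list(range(0, length, tile_size - overlap))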
@spaces.GPU(duration=120)
def enhance_image(
    image_input,
    image_url,
    seed,
    randomize_seed,
    num_inference_steps,
    upscale_factor,
    denoising_strength,
    use_generated_caption,
    custom_prompt,
    progress=gr.Progress(track_tqdm=True),
):
    """Main enhancement function."""
    # Handle image input
    if image_input is not None:
        input_image = image_input.convert("RGB")
    elif image_url:
        input_image = load_image_from_url(image_url)
    else:
        raise gr.Error("Please provide an image (upload or URL)")
    # Sliders can deliver floats; PIL and manual_seed need ints
    upscale_factor = int(upscale_factor)
    num_inference_steps = int(num_inference_steps)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    true_input_image = input_image
    # Process input image
    input_image, w_original, h_original, was_resized = process_input(
        input_image, upscale_factor
    )
    # Generate caption if requested
    if use_generated_caption:
        gr.Info("🔍 Generating image caption...")
        prompt = generate_caption(input_image)
    else:
        prompt = (custom_prompt or "").strip()
    generator = torch.Generator().manual_seed(seed)
    gr.Info("🚀 Upscaling image...")
    # Initial upscale: ESRGAN when available at 4x, otherwise LANCZOS
    if USE_ESRGAN and upscale_factor == 4:
        control_image = esrgan_upscale(input_image, upscale_factor)
    else:
        w, h = input_image.size
        control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
    # Tiled FLUX Img2Img for refinement
    image = tiled_flux_img2img(
        pipe,
        prompt,
        control_image,
        denoising_strength,
        num_inference_steps,
        1.0,  # Hardcoded guidance_scale to 1
        generator,
        tile_size=1024,
        overlap=32,
    )
    if was_resized:
        gr.Info(f"📏 Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
        image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
    # Resize the input to match the output so the before/after slider aligns
    resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
    return [resized_input, image]
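# Hypothetical offline smoke test (paths, prompt, and settings are placeholders):
#   img = Image.open("sample.jpg").convert("RGB")
#   before, after = enhance_image(img, "", 42, False, 20, 2, 0.3, False, "a photo")
#   after.save("sample_2x.jpg")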
# Create Gradio interface
with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Florence-2 + FLUX") as demo:
    gr.HTML("""
        <div class="main-header">
            <h1>🎨 AI Image Upscaler</h1>
            <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX upscaling</p>
            <p>Currently running on <strong>{}</strong></p>
        </div>
    """.format(power_device))
    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<h3>📤 Input</h3>")
            with gr.Tabs():
                with gr.TabItem("📁 Upload Image"):
                    input_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=200,  # Kept small; the results column gets the space
                    )
                with gr.TabItem("🔗 Image URL"):
                    image_url = gr.Textbox(
                        label="Image URL",
                        placeholder="https://example.com/image.jpg",
                        value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg",
                    )
            gr.HTML("<h3>🎛️ Caption Settings</h3>")
            use_generated_caption = gr.Checkbox(
                label="Use AI-generated caption (Florence-2)",
                value=True,
                info="Generate a detailed caption automatically",
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt (optional)",
                placeholder="Enter a custom prompt, or leave empty for the generated caption",
                lines=2,
            )
            gr.HTML("<h3>⚙️ Upscaling Settings</h3>")
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                minimum=1,
                maximum=4,
                step=1,
                value=2,
                info="How much to upscale the image",
            )
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                minimum=8,
                maximum=50,
                step=1,
                value=25,
                info="More steps = better quality but slower",
            )
            denoising_strength = gr.Slider(
                label="Denoising Strength",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.3,
                info="Controls how much the image is transformed",
            )
            with gr.Row():
                randomize_seed = gr.Checkbox(
                    label="Randomize seed",
                    value=True,
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=42,
                    interactive=True,
                )
            enhance_btn = gr.Button(
                "🚀 Upscale Image",
                variant="primary",
                size="lg",
            )
        with gr.Column(scale=2):  # Larger column for results
            gr.HTML("<h3>📊 Results</h3>")
            result_slider = ImageSlider(
                type="pil",
                interactive=False,  # Disable interactivity to prevent uploads
                height=600,
                elem_id="result_slider",
                label=None,  # Remove the default label
            )
    # Event handler
    enhance_btn.click(
        fn=enhance_image,
        inputs=[
            input_image,
            image_url,
            seed,
            randomize_seed,
            num_inference_steps,
            upscale_factor,
            denoising_strength,
            use_generated_caption,
            custom_prompt,
        ],
        outputs=[result_slider],
    )
gr.HTML("""
<div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
<p><strong>Note:</strong> This upscaler uses the Flux dev model. Users are responsible for obtaining commercial rights if used commercially under their license.</p>
</div>
""")
    # Custom CSS for the before/after slider
    gr.HTML("""
    <style>
    #result_slider .slider {
        width: 100% !important;
        max-width: inherit !important;
    }
    #result_slider img {
        object-fit: contain !important;
        width: 100% !important;
        height: auto !important;
    }
    #result_slider .gr-button-tool,
    #result_slider .gr-button-undo,
    #result_slider .gr-button-clear {
        display: none !important;
    }
    #result_slider .badge-container .badge {
        display: none !important;
    }
    #result_slider .badge-container::before {
        content: "Before";
        position: absolute;
        top: 10px;
        left: 10px;
        background: rgba(0,0,0,0.5);
        color: white;
        padding: 5px;
        border-radius: 5px;
        z-index: 10;
    }
    #result_slider .badge-container::after {
        content: "After";
        position: absolute;
        top: 10px;
        right: 10px;
        background: rgba(0,0,0,0.5);
        color: white;
        padding: 5px;
        border-radius: 5px;
        z-index: 10;
    }
    #result_slider .fullscreen img {
        object-fit: contain !important;
        width: 100vw !important;
        height: 100vh !important;
    }
    </style>
    """)
    # JS to set the slider's default position to the middle. Note: Gradio may
    # sanitize <script> tags inside gr.HTML, so this is best-effort; the
    # js parameter of gr.Blocks is the supported alternative if it doesn't run.
    gr.HTML("""
    <script>
    document.addEventListener('DOMContentLoaded', function() {
        const sliderInput = document.querySelector('#result_slider input[type="range"]');
        if (sliderInput) {
            sliderInput.value = 50;
            sliderInput.dispatchEvent(new Event('input'));
        }
    });
    </script>
    """)
if __name__ == "__main__":
    demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)