import random
import numpy as np
from PIL import Image, ImageOps
import base64
from io import BytesIO
import torch
import torchvision.transforms.functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
from src.pix2pix_turbo import Pix2Pix_Turbo
import nltk
from nltk import pos_tag
from nltk.tokenize import word_tokenize
import re
import os
import json
import logging
import gc
import gradio as gr
from torch.cuda.amp import autocast
# Set environment variable for better memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
# Function to clear CUDA cache and collect garbage
def clear_memory():
    torch.cuda.empty_cache()
    gc.collect()
# Load the configuration from config.json
with open('config.json', 'r') as config_file:
    config = json.load(config_file)
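# For reference, an illustrative config.json layout matching the keys read
# below (the values here are placeholders, not the actual repo config):
# {
#   "logging": {"level": "DEBUG", "format": "%(asctime)s %(levelname)s %(message)s"},
#   "file_paths": {"sketch_path": "sketch.png", "output_path": "output.png"},
#   "style_list": [{"name": "Photorealistic", "prompt": "{prompt}, photorealistic"}],
#   "default_style_name": "Photorealistic",
#   "random_values": ["high quality", "detailed"],
#   "model_params": {
#     "pix2pix_model_name": "sketch_to_image_stochastic",
#     "device": "cuda",
#     "default_seed": 42,
#     "val_r_default": 0.4,
#     "max_seed": 2147483647
#   },
#   "canvas": {"width": 512, "height": 512}
# }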
# Setup logging as per config
logging.basicConfig(level=config["logging"]["level"], format=config["logging"]["format"])
# Ensure NLTK resources are downloaded
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')
# File paths for storing sketches and outputs
SKETCH_PATH = config["file_paths"]["sketch_path"]
OUTPUT_PATH = config["file_paths"]["output_path"]
# Global Constants and Configuration
STYLE_LIST = config["style_list"]
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
DEFAULT_STYLE_NAME = config["default_style_name"]
RANDOM_VALUES = config["random_values"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
DEVICE = config["model_params"]["device"]
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]
MAX_SEED = config["model_params"]["max_seed"]
# Canvas configuration
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]
# Preload Models
logging.debug("Loading BLIP and Pix2Pix models...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)
logging.debug("Models loaded.")
def pil_image_to_data_uri(img: Image.Image, format="PNG") -> str:
    """Converts a PIL image to a data URI."""
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
def generate_prompt_from_sketch(image: Image.Image) -> str:
    """Generates a text prompt based on a sketch using the BLIP model."""
    logging.debug("Generating prompt from sketch...")
    image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(DEVICE)
    out = blip_model.generate(**inputs, max_new_tokens=50)
    text_prompt = processor.decode(out[0], skip_special_tokens=True)
    logging.debug(f"Generated prompt: {text_prompt}")
    recognized_items = [extract_main_words(item) for item in text_prompt.split(', ') if item.strip()]
    random_prefix = random.choice(RANDOM_VALUES)
    prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}"
    logging.debug(f"Final prompt: {prompt}")
    return prompt
def extract_main_words(item: str) -> str:
    """Extracts all nouns from a given text fragment and returns them as a space-separated string."""
    words = word_tokenize(item.strip())
    tagged = pos_tag(words)
    nouns = [word.capitalize() for word, tag in tagged if tag in ('NN', 'NNP', 'NNPS', 'NNS')]
    return ' '.join(nouns)
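# Illustrative example (exact tags depend on the NLTK tagger model):
#   extract_main_words("a small dog on a couch") -> "Dog Couch"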
def normalize_image(image, range_from=(-1, 1)):
    """
    Normalize the input image to a specified range.

    :param image: The PIL Image to be normalized.
    :param range_from: The target range for normalization, typically (-1, 1) or (0, 1).
    :return: Normalized image tensor.
    """
    # Convert the image to a tensor (values in [0, 1])
    image_t = F.to_tensor(image)
    if range_from == (-1, 1):
        # Normalize from [0, 1] to [-1, 1]
        image_t = image_t * 2 - 1
    return image_t
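# Note: run() below applies the same [0, 1] -> [-1, 1] scaling inline via
# `F.to_tensor(image) * 2 - 1`; normalize_image is not called elsewhere in this file.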
def run(image, prompt, prompt_template, style_name, seed, val_r):
    """Runs the main image processing pipeline."""
    logging.debug("Running model inference...")
    if image is None:
        blank_image = Image.new("L", (CANVAS_WIDTH, CANVAS_HEIGHT), 255)
        blank_image.save(SKETCH_PATH)  # Save blank image as sketch
        logging.debug("No image provided. Saving blank image.")
        return None  # Nothing to generate; the Gradio output stays empty
    if not prompt.strip():
        prompt = generate_prompt_from_sketch(image)
    # Save the sketch to a file
    image.save(SKETCH_PATH)
    # Show the original prompt before processing
    original_prompt = f"Original Prompt: {prompt}"
    logging.debug(original_prompt)
    prompt = prompt_template.replace("{prompt}", prompt)
    logging.debug(f"Processing with prompt: {prompt}")
    image = image.convert("RGB")
    image_tensor = F.to_tensor(image) * 2 - 1  # Normalize to [-1, 1]
    clear_memory()  # Clear memory before running the model
    try:
        with torch.no_grad():
            c_t = image_tensor.unsqueeze(0).to(DEVICE).float()
            torch.manual_seed(seed)
            B, C, H, W = c_t.shape
            # Noise map in the latent space expected by the model (4 channels, 1/8 spatial resolution)
            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
            logging.debug("Calling Pix2Pix model...")
            # Enable mixed precision
            with autocast():
                output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
            logging.debug("Model inference completed.")
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            logging.warning("CUDA out of memory error. Falling back to CPU.")
            with torch.no_grad():
                c_t = c_t.cpu()
                noise = noise.cpu()
                pix2pix_model_cpu = pix2pix_model.cpu()  # Move the model to CPU
                output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
        else:
            raise
    # Denormalize from [-1, 1] back to [0, 1] before converting to a PIL image
    output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    output_pil.save(OUTPUT_PATH)
    logging.debug("Output image saved.")
    return output_pil
def gradio_interface(image, prompt, style_name, seed, val_r):
    """Gradio interface function to handle inputs and generate outputs."""
    # Endpoint: `image` - Input image from user (Sketch Image)
    # Endpoint: `prompt` - Text prompt (optional)
    # Endpoint: `style_name` - Selected style from dropdown
    # Endpoint: `seed` - Seed for reproducibility
    # Endpoint: `val_r` - Sketch guidance value
    prompt_template = STYLES.get(style_name, STYLES[DEFAULT_STYLE_NAME])
    result_image = run(image, prompt, prompt_template, style_name, seed, val_r)
    return result_image
# Create the Gradio Interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(source="upload", type="pil", label="Sketch Image"),  # Endpoint: `image`
        gr.Textbox(lines=2, placeholder="Enter a text prompt (optional)", label="Prompt"),  # Endpoint: `prompt`
        gr.Dropdown(choices=list(STYLES.keys()), value=DEFAULT_STYLE_NAME, label="Style"),  # Endpoint: `style_name`
        gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED, label="Seed"),  # Endpoint: `seed`
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance")  # Endpoint: `val_r`
    ],
    outputs=gr.Image(label="Generated Image"),  # Output endpoint: `result_image`
    title="Sketch to Image Generation",
    description="Upload a sketch and generate an image based on a prompt and style."
)
if __name__ == "__main__":
    # Launch the Gradio interface
    interface.launch(share=True)