import random
import numpy as np
from PIL import Image, ImageOps
import base64
from io import BytesIO
import torch
import torchvision.transforms.functional as F
from transformers import BlipProcessor, BlipForConditionalGeneration
from src.pix2pix_turbo import Pix2Pix_Turbo
import nltk
from nltk import pos_tag
from nltk.tokenize import word_tokenize
import re
import os
import json
import logging
import gc
import gradio as gr
from torch.cuda.amp import autocast

# Set environment variable for better memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

# Function to clear CUDA cache and collect garbage
def clear_memory():
    torch.cuda.empty_cache()
    gc.collect()

# Load the configuration from config.json
with open('config.json', 'r') as config_file:
    config = json.load(config_file)
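
# For reference, config.json is assumed to look roughly like the sketch below.
# The keys mirror the lookups in this file; the concrete values are illustrative,
# not taken from the real config:
#
# {
#   "logging": {"level": "DEBUG", "format": "%(asctime)s %(levelname)s %(message)s"},
#   "file_paths": {"sketch_path": "sketch.png", "output_path": "output.png"},
#   "style_list": [{"name": "Photorealistic", "prompt": "{prompt}, photorealistic, 8k"}],
#   "default_style_name": "Photorealistic",
#   "random_values": ["highly detailed", "soft lighting"],
#   "model_params": {
#     "pix2pix_model_name": "sketch_to_image_stochastic",
#     "device": "cuda",
#     "default_seed": 42,
#     "val_r_default": 0.4,
#     "max_seed": 2147483647
#   },
#   "canvas": {"width": 512, "height": 512}
# }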

# Setup logging as per config
logging.basicConfig(level=config["logging"]["level"], format=config["logging"]["format"])

# Ensure NLTK resources are downloaded
nltk.download('averaged_perceptron_tagger', quiet=True)
nltk.download('punkt', quiet=True)

# File paths for storing sketches and outputs
SKETCH_PATH = config["file_paths"]["sketch_path"]
OUTPUT_PATH = config["file_paths"]["output_path"]

# Global Constants and Configuration
STYLE_LIST = config["style_list"]
STYLES = {style["name"]: style["prompt"] for style in STYLE_LIST}
DEFAULT_STYLE_NAME = config["default_style_name"]
RANDOM_VALUES = config["random_values"]
PIX2PIX_MODEL_NAME = config["model_params"]["pix2pix_model_name"]
DEVICE = config["model_params"]["device"]
DEFAULT_SEED = config["model_params"]["default_seed"]
VAL_R_DEFAULT = config["model_params"]["val_r_default"]
MAX_SEED = config["model_params"]["max_seed"]

# Canvas configuration
CANVAS_WIDTH = config["canvas"]["width"]
CANVAS_HEIGHT = config["canvas"]["height"]

# Preload Models
logging.debug("Loading BLIP and Pix2Pix models...")
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)
pix2pix_model = Pix2Pix_Turbo(PIX2PIX_MODEL_NAME)
logging.debug("Models loaded.")

def pil_image_to_data_uri(img: Image.Image, format: str = "PNG") -> str:
    """Converts a PIL image to a data URI."""
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"

def generate_prompt_from_sketch(image: Image.Image) -> str:
    """Generates a text prompt based on a sketch using the BLIP model."""
    logging.debug("Generating prompt from sketch...")
    
    image = ImageOps.fit(image, (CANVAS_WIDTH, CANVAS_HEIGHT), Image.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(DEVICE)

    out = blip_model.generate(**inputs, max_new_tokens=50)
    text_prompt = processor.decode(out[0], skip_special_tokens=True)
    logging.debug(f"Generated prompt: {text_prompt}")

    # Keep only fragments that actually contain nouns, so the join never produces empty slots
    recognized_items = [words for words in (extract_main_words(item) for item in text_prompt.split(', ')) if words]
    random_prefix = random.choice(RANDOM_VALUES)

    prompt = f"a photo of a {' and '.join(recognized_items)}, {random_prefix}"
    logging.debug(f"Final prompt: {prompt}")
    return prompt

def extract_main_words(item: str) -> str:
    """Extracts all nouns from a given text fragment and returns them as a space-separated string."""
    words = word_tokenize(item.strip())
    tagged = pos_tag(words)
    nouns = [word.capitalize() for word, tag in tagged if tag in ('NN', 'NNP', 'NNPS', 'NNS')]
    return ' '.join(nouns)
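
# Example (illustrative): extract_main_words("a drawing of a cat on a chair")
# typically returns "Drawing Cat Chair", since only NN* POS tags are kept
# (exact output depends on the NLTK tagger's decisions).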

def normalize_image(image, range_from=(-1, 1)):
    """
    Normalize the input image to a specified range.
    
    :param image: The PIL Image to be normalized.
    :param range_from: The target range for normalization, typically (-1, 1) or (0, 1).
    :return: Normalized image tensor.
    """
    # Convert the image to a tensor
    image_t = F.to_tensor(image)
    
    if range_from == (-1, 1):
        # Normalize from [0, 1] to [-1, 1]
        image_t = image_t * 2 - 1
    
    return image_t
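
# Example: normalize_image(img) returns a float tensor in [-1, 1], the input range
# the Pix2Pix model expects; any other range_from value leaves the [0, 1] output
# of F.to_tensor unchanged.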

def run(image, prompt, prompt_template, seed, val_r):
    """Runs the main image processing pipeline."""
    logging.debug("Running model inference...")
    if image is None:
        blank_image = Image.new("L", (CANVAS_WIDTH, CANVAS_HEIGHT), 255)
        blank_image.save(SKETCH_PATH)  # Save a blank sketch so a file always exists at SKETCH_PATH
        logging.debug("No image provided. Saving blank image and returning no output.")
        return None  # The interface has a single image output; None renders as empty

    if not prompt.strip():
        prompt = generate_prompt_from_sketch(image)

    # Save the sketch to a file
    image.save(SKETCH_PATH)

    # Show the original prompt before processing
    original_prompt = f"Original Prompt: {prompt}"
    logging.debug(original_prompt)

    prompt = prompt_template.replace("{prompt}", prompt)
    logging.debug(f"Processing with prompt: {prompt}")
    image = image.convert("RGB")
    image_tensor = normalize_image(image)  # Reuse the helper above to map [0, 1] -> [-1, 1]
    
    clear_memory()  # Clear memory before running the model

    try:
        with torch.no_grad():
            c_t = image_tensor.unsqueeze(0).to(DEVICE).float()
            torch.manual_seed(seed)
            _, _, H, W = c_t.shape  # Only the spatial dims are needed for the latent noise shape
            
            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
            logging.debug("Calling Pix2Pix model...")

            # Enable mixed precision
            with autocast():
                output_image = pix2pix_model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)

            logging.debug("Model inference completed.")
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            logging.warning("CUDA out of memory. Falling back to CPU for this request.")
            clear_memory()  # Release the failed allocation before retrying
            with torch.no_grad():
                c_t = c_t.cpu()
                noise = noise.cpu()
                # Note: the model remains on CPU afterwards; subsequent requests run on CPU too
                pix2pix_model_cpu = pix2pix_model.cpu()
                output_image = pix2pix_model_cpu(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
        else:
            raise

    output_pil = F.to_pil_image((output_image[0].cpu() * 0.5 + 0.5).clamp(0, 1))  # Map [-1, 1] back to [0, 1]
    output_pil.save(OUTPUT_PATH)
    logging.debug("Output image saved.")

    return output_pil

def gradio_interface(image, prompt, style_name, seed, val_r):
    """Gradio interface function to handle inputs and generate outputs."""
    # Endpoint: `image` - Input image from user (Sketch Image)
    # Endpoint: `prompt` - Text prompt (optional)
    # Endpoint: `style_name` - Selected style from dropdown
    # Endpoint: `seed` - Seed for reproducibility
    # Endpoint: `val_r` - Sketch guidance value

    prompt_template = STYLES.get(style_name, STYLES[DEFAULT_STYLE_NAME])
    result_image = run(image, prompt, prompt_template, seed, val_r)
    return result_image

# Create the Gradio Interface
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(source="upload", type="pil", label="Sketch Image"),  # Endpoint: `image`
        gr.Textbox(lines=2, placeholder="Enter a text prompt (optional)", label="Prompt"),  # Endpoint: `prompt`
        gr.Dropdown(choices=list(STYLES.keys()), value=DEFAULT_STYLE_NAME, label="Style"),  # Endpoint: `style_name`
        gr.Slider(minimum=0, maximum=MAX_SEED, step=1, value=DEFAULT_SEED, label="Seed"),  # Endpoint: `seed`
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=VAL_R_DEFAULT, label="Sketch Guidance")  # Endpoint: `val_r`
    ],
    outputs=gr.Image(label="Generated Image"),  # Output endpoint: `result_image`
    title="Sketch to Image Generation",
    description="Upload a sketch and generate an image based on a prompt and style."
)

if __name__ == "__main__":
    # Launch the Gradio interface
    interface.launch(share=True)
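    # Alternative launch (illustrative): bind to a fixed local port instead of
    # creating a public share link.
    # interface.launch(server_name="0.0.0.0", server_port=7860)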