# Page-header residue captured with this file: "Spaces: Build error / Build error".
import gc  # Added for garbage collection
import json
import os
import random

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from ..utils import sr_utils, seg_utils, inpaint_utils
class ImageProcessingPipeline:
    """Base class for image processing pipelines with common functionality.

    Holds the shared ``params`` dict and the resolved random seed, and
    provides the file, mask, inpainting and super-resolution helpers that
    the layer-specific pipelines build on.
    """

    def __init__(self, params):
        """Store the processing parameters and resolve the random seed.

        Args:
            params: Mapping of processing options. ``params['seed']`` is
                read here; other keys (``dilation_size``, ``threshold``,
                ``ratio``, ``strength``, ``cfg_scale``, ``scale``) are read
                lazily by the helper that needs them.
        """
        self.params = params
        self.seed = self._init_seed(params['seed'])

    def _init_seed(self, seed_param):
        """Return the seed to use; ``-1`` requests a random one in [1, 65535]."""
        if seed_param == -1:
            return random.randint(1, 65535)
        return seed_param

    def _prepare_output_dir(self, output_path):
        """Create the output directory if it doesn't exist."""
        os.makedirs(output_path, exist_ok=True)

    def _prepare_image_path(self, img_path, output_path):
        """Write the input image to ``<output_path>/full_image.png``.

        The file is re-encoded and overwritten on every call. The output
        directory is created first so the save cannot fail on a fresh
        directory, and the source file handle is closed deterministically.
        """
        os.makedirs(output_path, exist_ok=True)
        full_image_path = f"{output_path}/full_image.png"
        with Image.open(img_path) as image:
            image.save(full_image_path)

    def _get_image_path(self, base_dir, priority_files):
        """Return the first existing file from ``priority_files`` in ``base_dir``.

        Falls back to ``<base_dir>/full_image.png`` (without checking that it
        exists) when none of the candidates is present.
        """
        for file in priority_files:
            path = os.path.join(base_dir, file)
            if os.path.exists(path):
                return path
        return os.path.join(base_dir, "full_image.png")

    def _process_mask(self, mask_path, base_dir, size, mask_infos_key, edge_padding: int = 20):
        """Load a mask, smooth it, pad its left/right edges and resize it.

        Args:
            mask_path: Mask file name relative to ``base_dir``.
            base_dir: Directory holding the mask and ``<mask_infos_key>.json``.
            size: ``(height, width)`` of the target image.
            mask_infos_key: Stem of the JSON file whose ``"bboxes"`` entry is
                passed to the smoothing helper.
            edge_padding: Number of columns at each horizontal edge that are
                forced to 1. ``0`` disables the padding.

        Returns:
            The smoothed mask resized to ``size``.

        Raises:
            FileNotFoundError: If the mask image cannot be read.
        """
        mask_file = os.path.join(base_dir, mask_path)
        mask_sharp = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
        if mask_sharp is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read mask image: {mask_file}")

        with open(os.path.join(base_dir, f'{mask_infos_key}.json'), encoding="utf-8") as f:
            mask_infos = json.load(f)["bboxes"]

        mask_smooth = inpaint_utils.get_adaptive_smooth_mask_ksize_ctrl(
            mask_sharp, mask_infos,
            basek=self.params['dilation_size'],
            threshold=self.params['threshold'],
            r=self.params['ratio']
        )

        # Apply edge padding. Guarded because `-0:` would select every column
        # and overwrite the whole mask.
        if edge_padding > 0:
            mask_smooth[:, :edge_padding] = 1
            mask_smooth[:, -edge_padding:] = 1

        # cv2.resize's third positional parameter is `dst`, so the
        # interpolation mode has to be passed by keyword (and as a cv2
        # constant, not a PIL one). INTER_LINEAR is cv2's default.
        return cv2.resize(
            mask_smooth, (size[1], size[0]), interpolation=cv2.INTER_LINEAR
        )

    def _run_inpainting(self, image, mask, size, prompt_config, image_info, inpaint_model,
                        *, guidance_scale=30, num_inference_steps=50):
        """Run the inpainting model with scene-dependent prompts.

        Args:
            image: Image handed to ``inpaint_model`` as ``image``.
            mask: Mask handed to ``inpaint_model`` as ``mask_image``.
            size: ``(height, width)`` of the requested output.
            prompt_config: Mapping with ``"indoor"`` and ``"outdoor"`` entries,
                each providing ``"positive_prompt"`` and ``"negative_prompt"``.
            image_info: Per-image metadata; ``"class"`` selects the prompt set
                and the optional ``"labels"`` list is appended to the
                negative prompt.
            inpaint_model: Callable returning an object with an ``images``
                sequence (presumably a diffusers-style pipeline; confirm
                against the caller).
            guidance_scale: Forwarded to the model.
            num_inference_steps: Forwarded to the model.

        Returns:
            The first image produced by the model.
        """
        # process prompt
        scene = "indoor" if self._is_indoor(image_info) else "outdoor"
        prompt = prompt_config[scene]["positive_prompt"]
        negative_prompt = prompt_config[scene]["negative_prompt"]

        # Removed objects are listed in the negative prompt; a missing
        # "labels" key is treated like an empty list.
        labels = image_info.get("labels")
        if labels:
            negative_prompt += ", " + ", ".join(labels)

        result = inpaint_model(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask,
            height=size[0],
            width=size[1],
            strength=self.params['strength'],
            true_cfg_scale=self.params['cfg_scale'],
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(self.seed),
        ).images[0]

        # Clear memory after inpainting
        torch.cuda.empty_cache()
        gc.collect()
        return result

    def _is_indoor(self, img_info):
        """Return True when ``img_info["class"]`` is ``"indoor"`` or ``"[indoor]"``."""
        return img_info["class"] in ("indoor", "[indoor]")

    def _run_super_resolution(self, input_path, output_path, sr_model, suffix='sr'):
        """Run super-resolution on ``input_path`` if that file exists.

        A missing input is skipped silently. Cached GPU memory is released
        afterwards either way.
        """
        if os.path.exists(input_path):
            sr_utils.sr_inference(
                input_path, output_path, sr_model,
                scale=self.params['scale'], ext='auto', suffix=suffix
            )
        # Clear memory after super-resolution
        torch.cuda.empty_cache()
        gc.collect()
class ForegroundPipeline(ImageProcessingPipeline):
    """Pipeline for processing foreground layers (fg1 and fg2).

    Each layer runs three phases over all images: super-resolution,
    segmentation and inpainting. The inpainted result is saved as
    ``remove_<layer_name>_image.png`` in the image's output directory.
    """

    def __init__(self, params, layer):
        """Initialize with parameters and layer type (0 for fg1, 1 for fg2)."""
        super().__init__(params)
        self.layer = layer
        self.layer_name = f"fg{layer+1}"

    def _input_candidates(self):
        """Return the candidate input file names for this layer, best first.

        Layer ``n`` consumes the output of the previous layer
        (``remove_fg<n>_image.png``) when it exists; for layer 0 no such file
        is produced, so the lookup falls through to ``full_image.png``. All
        three phases share this list so they always agree on the input.
        """
        return [f"remove_fg{self.layer}_image.png", "full_image.png"]

    def process(self, img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model):
        """Run the full processing pipeline for this foreground layer.

        Args:
            img_infos: Iterable of per-image dicts. ``"output_path"`` and
                ``"class"`` are read for every image, ``"image_path"`` for
                layer 0, and ``"labels"`` when present.
            sr_model: Model passed to the super-resolution helper.
            zim_predictor, gd_processor, gd_model: Objects forwarded
                unchanged to the segmentation helpers.
            inpaint_model: Callable used for inpainting.
        """
        print(f"============= Now starting {self.layer_name} processing ===============")
        # Phase 1: Super Resolution
        self._process_super_resolution(img_infos, sr_model)
        # Phase 2: Segmentation
        self._process_segmentation(img_infos, zim_predictor, gd_processor, gd_model)
        # Phase 3: Inpainting
        self._process_inpainting(img_infos, inpaint_model)
        torch.cuda.empty_cache()
        gc.collect()

    def _process_super_resolution(self, img_infos, sr_model):
        """Run super-resolution on each image's input for this layer."""
        for img_info in tqdm(img_infos):
            output_path = img_info["output_path"]
            # The directory must exist before full_image.png is written.
            self._prepare_output_dir(output_path)
            # prepare input image
            if self.layer == 0:
                self._prepare_image_path(img_info["image_path"], output_path)
            input_path = self._get_image_path(output_path, self._input_candidates())
            self._run_super_resolution(input_path, output_path, sr_model)

    def _process_segmentation(self, img_infos, zim_predictor, gd_processor, gd_model):
        """Segment the labelled foreground objects of each image.

        Images without labels are skipped. The indoor or outdoor
        segmentation helper is chosen from the image's ``"class"``.
        """
        for img_info in tqdm(img_infos):
            labels = img_info.get("labels")
            if not labels:
                continue
            output_path = img_info["output_path"]
            img_path = self._get_image_path(output_path, self._input_candidates())
            # Only the file extension is rewritten, never a ".png" that
            # happens to appear inside a directory name.
            stem, ext = os.path.splitext(img_path)
            img_sr_path = f"{stem}_sr{ext}"
            text = ". ".join(labels) + "."

            if self._is_indoor(img_info):
                segment = seg_utils.get_fg_pad_indoor
            else:
                segment = seg_utils.get_fg_pad_outdoor
            segment(
                output_path, img_path, img_sr_path,
                zim_predictor, gd_processor, gd_model,
                text, layer=self.layer, scale=self.params['scale']
            )
            # Clear memory after segmentation
            torch.cuda.empty_cache()
            gc.collect()

    def _process_inpainting(self, img_infos, inpaint_model):
        """Inpaint the masked foreground region of each image.

        Images without a ``<layer_name>_mask.png`` in their output directory
        are skipped.
        """
        mask_path = f'{self.layer_name}_mask.png'
        for img_info in tqdm(img_infos):
            base_dir = img_info["output_path"]
            if not os.path.exists(os.path.join(base_dir, mask_path)):
                continue

            source_path = self._get_image_path(base_dir, self._input_candidates())
            # convert() returns an independent image, so the file handle can
            # be closed right away.
            with Image.open(source_path) as source:
                image = source.convert('RGB')
            size = image.height, image.width

            mask_smooth = self._process_mask(
                mask_path, base_dir, size, self.layer_name
            )
            pano_mask_pil = Image.fromarray(mask_smooth*255)

            result = self._run_inpainting(
                image, pano_mask_pil, size,
                self.params['prompt_config'], img_info, inpaint_model
            )
            result.save(f'{base_dir}/remove_{self.layer_name}_image.png')

            # Clear memory after saving result
            del image, mask_smooth, pano_mask_pil, result
            torch.cuda.empty_cache()
            gc.collect()
class SkyPipeline(ImageProcessingPipeline):
    """Pipeline for processing the sky layer.

    Runs super-resolution, sky segmentation, sky inpainting and a final
    super-resolution pass. Images classified as indoor are skipped by the
    segmentation and inpainting phases.
    """

    # Candidate inputs, best first: the most processed image that exists.
    INPUT_CANDIDATES = ("remove_fg2_image.png", "remove_fg1_image.png", "full_image.png")
    # Columns at each horizontal edge of the mask that are forced to 1.
    EDGE_PADDING = 20
    # Sky-specific inpainting parameters
    SKY_GUIDANCE_SCALE = 20
    SKY_PROMPT = "sky-coverage, whole sky image, ultra-high definition stratosphere"
    SKY_NEGATIVE_PROMPT = (
        "object, text, defocus, pure color, low-res, blur, pixelation, foggy, "
        "noise, mosaic, artifacts, low-contrast, low-quality, blurry, tree, "
        "grass, plant, ground, land, mountain, building, lake, river, sea, ocean"
    )

    def process(self, img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model):
        """Run the full processing pipeline for the sky layer.

        After the four phases, every passed-in model that supports it is
        moved to the CPU (see ``_clear_models``); callers that reuse the
        models afterwards must move them back themselves.
        """
        print("============= Now starting sky processing ===============")
        # Phase 1: Super Resolution
        self._process_super_resolution(img_infos, sr_model)
        # Phase 2: Segmentation
        self._process_segmentation(img_infos, zim_predictor, gd_processor, gd_model)
        # Phase 3: Inpainting
        self._process_inpainting(img_infos, inpaint_model)
        # Phase 4: Final Super Resolution
        self._process_final_super_resolution(img_infos, sr_model)
        # Move all models off the GPU after processing
        self._clear_models([sr_model, zim_predictor, gd_processor, gd_model, inpaint_model])

    def _clear_models(self, models):
        """Move models to the CPU and release cached GPU memory.

        The objects themselves are not deleted; only their device placement
        changes. Objects exposing neither ``to`` nor ``cpu`` are left alone.
        """
        for model in models:
            # One move is enough; `to('cpu')` is preferred and `cpu()` is the
            # fallback for objects that only offer that method.
            if hasattr(model, 'to'):
                model.to('cpu')
            elif hasattr(model, 'cpu'):
                model.cpu()
        torch.cuda.empty_cache()
        gc.collect()

    def _process_super_resolution(self, img_infos, sr_model):
        """Run super-resolution on ``remove_fg2_image.png`` where it exists."""
        for img_info in tqdm(img_infos):
            output_path = img_info["output_path"]
            self._prepare_output_dir(output_path)
            input_path = f"{output_path}/remove_fg2_image.png"
            self._run_super_resolution(input_path, output_path, sr_model)

    def _process_segmentation(self, img_infos, zim_predictor, gd_processor, gd_model):
        """Run sky segmentation on every non-indoor image."""
        for img_info in tqdm(img_infos):
            if self._is_indoor(img_info):
                continue
            output_path = img_info["output_path"]
            img_path = self._get_image_path(output_path, list(self.INPUT_CANDIDATES))
            # Only the file extension is rewritten, never a ".png" that
            # happens to appear inside a directory name.
            stem, ext = os.path.splitext(img_path)
            img_sr_path = f"{stem}_sr{ext}"

            seg_utils.get_sky(
                output_path, img_path, img_sr_path,
                zim_predictor, gd_processor, gd_model, "sky."
            )
            # Clear memory after segmentation
            torch.cuda.empty_cache()
            gc.collect()

    def _process_inpainting(self, img_infos, inpaint_model):
        """Inpaint the sky for every non-indoor image that has ``sky_mask.png``.

        The result is saved as ``sky_image.png`` in the image's output
        directory.
        """
        for img_info in tqdm(img_infos):
            if self._is_indoor(img_info):
                continue
            base_dir = img_info["output_path"]
            sky_mask_path = os.path.join(base_dir, 'sky_mask.png')
            if not os.path.exists(sky_mask_path):
                continue

            source_path = self._get_image_path(base_dir, list(self.INPUT_CANDIDATES))
            # convert() returns an independent image, so both file handles
            # can be closed right away.
            with Image.open(source_path) as source:
                image = source.convert('RGB')
            size = image.height, image.width

            with Image.open(sky_mask_path) as mask_file:
                mask_sharp = mask_file.convert('L')
            mask_smooth = inpaint_utils.get_smooth_mask(np.asarray(mask_sharp))

            # Apply edge padding
            mask_smooth[:, :self.EDGE_PADDING] = 1
            mask_smooth[:, -self.EDGE_PADDING:] = 1
            # cv2.resize's third positional parameter is `dst`, so the
            # interpolation mode has to be passed by keyword (and as a cv2
            # constant, not a PIL one). INTER_LINEAR is cv2's default.
            mask_smooth = cv2.resize(
                mask_smooth, (size[1], size[0]), interpolation=cv2.INTER_LINEAR
            )
            pano_mask_pil = Image.fromarray(mask_smooth*255)

            result = inpaint_model(
                prompt=self.SKY_PROMPT,
                negative_prompt=self.SKY_NEGATIVE_PROMPT,
                image=image,
                mask_image=pano_mask_pil,
                height=size[0],
                width=size[1],
                strength=self.params['strength'],
                true_cfg_scale=self.params['cfg_scale'],
                guidance_scale=self.SKY_GUIDANCE_SCALE,
                num_inference_steps=50,
                max_sequence_length=512,
                generator=torch.Generator("cpu").manual_seed(self.seed),
            ).images[0]
            result.save(f'{base_dir}/sky_image.png')

            # Clear memory after saving result
            del image, mask_sharp, mask_smooth, pano_mask_pil, result
            torch.cuda.empty_cache()
            gc.collect()

    def _process_final_super_resolution(self, img_infos, sr_model):
        """Run super-resolution on each ``sky_image.png`` that was produced."""
        for img_info in tqdm(img_infos):
            output_path = img_info["output_path"]
            input_path = f"{output_path}/sky_image.png"
            self._run_super_resolution(input_path, output_path, sr_model)
# Original functions refactored to use the new pipeline classes
def remove_fg1_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model, params):
    """Remove the first foreground layer (fg1) from every image.

    Function-style entry point that builds a ``ForegroundPipeline`` for
    layer 0 and runs it with the given models.
    """
    ForegroundPipeline(params, layer=0).process(
        img_infos,
        sr_model,
        zim_predictor,
        gd_processor,
        gd_model,
        inpaint_model,
    )
def remove_fg2_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model, params):
    """Remove the second foreground layer (fg2) from every image.

    Function-style entry point that builds a ``ForegroundPipeline`` for
    layer 1 and runs it with the given models.
    """
    ForegroundPipeline(params, layer=1).process(
        img_infos,
        sr_model,
        zim_predictor,
        gd_processor,
        gd_model,
        inpaint_model,
    )
def sky_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model, params):
    """Process the sky layer of every image.

    Function-style entry point that builds a ``SkyPipeline`` and runs it
    with the given models.
    """
    SkyPipeline(params).process(
        img_infos,
        sr_model,
        zim_predictor,
        gd_processor,
        gd_model,
        inpaint_model,
    )