import gc  # Added for garbage collection
import json
import os
import random

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from ..utils import sr_utils, seg_utils, inpaint_utils


class ImageProcessingPipeline:
    """Base class for image processing pipelines with common functionality.

    Subclasses implement a ``process`` method that runs super-resolution,
    segmentation and inpainting phases over a list of ``img_info`` dicts.
    The keys read from each ``img_info`` in this module are ``"output_path"``,
    ``"image_path"``, ``"labels"`` and ``"class"``.
    """

    def __init__(self, params):
        """Initialize pipeline with processing parameters.

        Args:
            params: Mapping of processing parameters. Keys read in this module:
                ``seed``, ``dilation_size``, ``threshold``, ``ratio``,
                ``strength``, ``cfg_scale``, ``scale`` and ``prompt_config``.
        """
        self.params = params
        self.seed = self._init_seed(params['seed'])

    def _init_seed(self, seed_param):
        """Return the seed to use; ``-1`` requests a random one in [1, 65535]."""
        if seed_param == -1:
            return random.randint(1, 65535)
        return seed_param

    @staticmethod
    def _free_memory():
        """Run the garbage collector, then release cached CUDA memory.

        Collecting first lets tensors that are only kept alive by reference
        cycles be destroyed before the CUDA caching allocator is asked to
        give back its unused blocks.
        """
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _sr_path(img_path):
        """Return ``img_path`` with ``_sr`` inserted before the file extension.

        Only the final extension is touched, so a ``.png`` appearing in a
        directory name is left alone.
        NOTE(review): assumes ``sr_utils.sr_inference`` names its output
        ``<stem>_sr<ext>`` for ``suffix='sr'`` — confirm against sr_utils.
        """
        root, ext = os.path.splitext(img_path)
        return f"{root}_sr{ext}"

    @staticmethod
    def _pad_and_resize_mask(mask, size, edge_padding=20):
        """Force the left/right borders of ``mask`` to 1 and resize it.

        Args:
            mask: 2-D array, modified in place for the border padding.
            size: ``(height, width)`` of the target image.
            edge_padding: Width in pixels of each border strip set to 1.
                Values ``<= 0`` disable padding.

        Returns:
            The mask resized to ``size`` with bilinear interpolation.
        """
        # Guard against 0: ``mask[:, -0:]`` would select every column.
        if edge_padding > 0:
            mask[:, :edge_padding] = 1
            mask[:, -edge_padding:] = 1
        # cv2.resize's third positional argument is ``dst``; interpolation
        # must be passed by keyword. INTER_LINEAR is cv2's bilinear mode.
        return cv2.resize(mask, (size[1], size[0]), interpolation=cv2.INTER_LINEAR)

    def _prepare_output_dir(self, output_path):
        """Create output directory if it doesn't exist."""
        os.makedirs(output_path, exist_ok=True)

    def _prepare_image_path(self, img_path, output_path):
        """Save the input image as ``full_image.png`` inside ``output_path``.

        The output directory is created if needed and any existing
        ``full_image.png`` is overwritten.
        """
        os.makedirs(output_path, exist_ok=True)
        full_image_path = os.path.join(output_path, "full_image.png")
        with Image.open(img_path) as image:
            image.save(full_image_path)

    def _get_image_path(self, base_dir, priority_files):
        """Return the first existing file of ``priority_files`` in ``base_dir``.

        Falls back to ``<base_dir>/full_image.png`` (whether or not it exists)
        when none of the candidates is present.
        """
        for file_name in priority_files:
            path = os.path.join(base_dir, file_name)
            if os.path.exists(path):
                return path
        return os.path.join(base_dir, "full_image.png")

    def _process_mask(self, mask_path, base_dir, size, mask_infos_key, edge_padding: int = 20):
        """Load a mask, smooth it, pad its side edges and resize it.

        Args:
            mask_path: Mask file name relative to ``base_dir``.
            base_dir: Directory holding the mask and its ``.json`` sidecar.
            size: ``(height, width)`` of the image the mask applies to.
            mask_infos_key: Stem of the JSON file whose ``"bboxes"`` entry is
                passed to the smoothing helper.
            edge_padding: Width of the left/right strips forced to 1.

        Returns:
            The smoothed mask resized to ``size``.

        Raises:
            FileNotFoundError: If the mask image cannot be read.
        """
        mask_file = os.path.join(base_dir, mask_path)
        mask_sharp = cv2.imread(mask_file, 0)  # flag 0: load as grayscale
        if mask_sharp is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read mask image: {mask_file}")

        info_file = os.path.join(base_dir, f'{mask_infos_key}.json')
        with open(info_file, encoding="utf-8") as f:
            mask_infos = json.load(f)["bboxes"]

        mask_smooth = inpaint_utils.get_adaptive_smooth_mask_ksize_ctrl(
            mask_sharp,
            mask_infos,
            basek=self.params['dilation_size'],
            threshold=self.params['threshold'],
            r=self.params['ratio'],
        )
        return self._pad_and_resize_mask(mask_smooth, size, edge_padding)

    def _inpaint(self, inpaint_model, image, mask, size, prompt, negative_prompt,
                 guidance_scale):
        """Call the inpainting model with the shared fixed settings.

        ``size`` is ``(height, width)``. Strength and true-CFG scale come from
        ``self.params``; the generator is a CPU generator seeded with
        ``self.seed`` so runs are reproducible.

        Returns:
            The first image of the model output.
        """
        return inpaint_model(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask,
            height=size[0],
            width=size[1],
            strength=self.params['strength'],
            true_cfg_scale=self.params['cfg_scale'],
            guidance_scale=guidance_scale,
            num_inference_steps=50,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(self.seed),
        ).images[0]

    def _run_inpainting(self, image, mask, size, prompt_config, image_info, inpaint_model):
        """Run inpainting with prompts chosen by the image's scene class.

        The ``"indoor"`` or ``"outdoor"`` entry of ``prompt_config`` supplies
        the positive and negative prompts; any labels in ``image_info`` are
        appended to the negative prompt.

        Returns:
            The inpainted image.
        """
        scene = "indoor" if self._is_indoor(image_info) else "outdoor"
        prompt = prompt_config[scene]["positive_prompt"]
        negative_prompt = prompt_config[scene]["negative_prompt"]

        labels = image_info.get("labels")
        if labels:
            negative_prompt += ", " + ", ".join(labels)

        result = self._inpaint(
            inpaint_model, image, mask, size, prompt, negative_prompt,
            guidance_scale=30,
        )
        self._free_memory()
        return result

    def _is_indoor(self, img_info):
        """Check if image is classified as indoor."""
        return img_info["class"] in ("indoor", "[indoor]")

    def _run_super_resolution(self, input_path, output_path, sr_model, suffix='sr'):
        """Run super-resolution on ``input_path`` if that file exists.

        A missing input is skipped silently: not every stage produces an
        image for every item.
        """
        if not os.path.exists(input_path):
            return
        sr_utils.sr_inference(
            input_path,
            output_path,
            sr_model,
            scale=self.params['scale'],
            ext='auto',
            suffix=suffix,
        )
        self._free_memory()


class ForegroundPipeline(ImageProcessingPipeline):
    """Pipeline for processing foreground layers (fg1 and fg2)."""

    def __init__(self, params, layer):
        """Initialize with parameters and layer type (0 for fg1, 1 for fg2)."""
        super().__init__(params)
        self.layer = layer
        self.layer_name = f"fg{layer+1}"

    def _source_candidates(self):
        """Return the prioritized input image names for this layer.

        Layer ``n`` works on the output of the previous layer
        (``remove_fg<n>_image.png``) when it exists, else on
        ``full_image.png``. Every phase uses this one list so they all
        operate on the same image.
        """
        return [f"remove_fg{self.layer}_image.png", "full_image.png"]

    def process(self, img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model):
        """Run full processing pipeline for foreground layer."""
        print(f"============= Now starting {self.layer_name} processing ===============")

        # Phase 1: Super Resolution
        self._process_super_resolution(img_infos, sr_model)

        # Phase 2: Segmentation
        self._process_segmentation(img_infos, zim_predictor, gd_processor, gd_model)

        # Phase 3: Inpainting
        self._process_inpainting(img_infos, inpaint_model)

        self._free_memory()

    def _process_super_resolution(self, img_infos, sr_model):
        """Process super-resolution phase."""
        for img_info in tqdm(img_infos):
            output_path = img_info["output_path"]
            # The directory must exist before anything is written into it.
            self._prepare_output_dir(output_path)

            # The first layer seeds the output directory with the input image.
            if self.layer == 0:
                self._prepare_image_path(img_info["image_path"], output_path)

            input_path = self._get_image_path(output_path, self._source_candidates())
            self._run_super_resolution(input_path, output_path, sr_model)

    def _process_segmentation(self, img_infos, zim_predictor, gd_processor, gd_model):
        """Process segmentation phase; items without labels are skipped."""
        for img_info in tqdm(img_infos):
            labels = img_info.get("labels")
            if not labels:
                continue

            output_path = img_info["output_path"]
            img_path = self._get_image_path(output_path, self._source_candidates())
            img_sr_path = self._sr_path(img_path)
            # Text prompt for the detector: labels separated by ". " with a
            # trailing period.
            text = ". ".join(labels) + "."

            if self._is_indoor(img_info):
                segment = seg_utils.get_fg_pad_indoor
            else:
                segment = seg_utils.get_fg_pad_outdoor
            segment(
                output_path,
                img_path,
                img_sr_path,
                zim_predictor,
                gd_processor,
                gd_model,
                text,
                layer=self.layer,
                scale=self.params['scale'],
            )

            self._free_memory()

    def _process_inpainting(self, img_infos, inpaint_model):
        """Process inpainting phase; items without a layer mask are skipped.

        Writes ``remove_<layer_name>_image.png`` into each item's output dir.
        """
        mask_name = f'{self.layer_name}_mask.png'
        for img_info in tqdm(img_infos):
            base_dir = img_info["output_path"]
            if not os.path.exists(os.path.join(base_dir, mask_name)):
                continue

            source = self._get_image_path(base_dir, self._source_candidates())
            with Image.open(source) as src:
                image = src.convert('RGB')
            size = image.height, image.width

            mask_smooth = self._process_mask(mask_name, base_dir, size, self.layer_name)
            # NOTE(review): assumes the smoothed mask is in [0, 1] with a dtype
            # Image.fromarray accepts after scaling — confirm in inpaint_utils.
            pano_mask_pil = Image.fromarray(mask_smooth*255)

            result = self._run_inpainting(
                image,
                pano_mask_pil,
                size,
                self.params['prompt_config'],
                img_info,
                inpaint_model,
            )
            result.save(os.path.join(base_dir, f'remove_{self.layer_name}_image.png'))

            # Drop per-image buffers before the next iteration allocates more.
            del image, mask_smooth, pano_mask_pil, result
            self._free_memory()


class SkyPipeline(ImageProcessingPipeline):
    """Pipeline for processing sky layer."""

    # Input image candidates, most-processed first.
    SOURCE_CANDIDATES = ("remove_fg2_image.png", "remove_fg1_image.png", "full_image.png")

    # Sky-specific inpainting prompts.
    SKY_PROMPT = "sky-coverage, whole sky image, ultra-high definition stratosphere"
    SKY_NEGATIVE_PROMPT = (
        "object, text, defocus, pure color, low-res, blur, pixelation, foggy, "
        "noise, mosaic, artifacts, low-contrast, low-quality, blurry, tree, "
        "grass, plant, ground, land, mountain, building, lake, river, sea, ocean"
    )

    def process(self, img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model):
        """Run full processing pipeline for sky layer."""
        print("============= Now starting sky processing ===============")

        # Phase 1: Super Resolution
        self._process_super_resolution(img_infos, sr_model)

        # Phase 2: Segmentation
        self._process_segmentation(img_infos, zim_predictor, gd_processor, gd_model)

        # Phase 3: Inpainting
        self._process_inpainting(img_infos, inpaint_model)

        # Phase 4: Final Super Resolution
        self._process_final_super_resolution(img_infos, sr_model)

        # Move all models off the accelerator after processing
        self._clear_models([sr_model, zim_predictor, gd_processor, gd_model, inpaint_model])

    def _clear_models(self, models):
        """Move every model that supports it to the CPU, then free memory.

        Objects exposing neither ``to`` nor ``cpu`` are left untouched.
        """
        for model in models:
            if hasattr(model, 'to'):
                model.to('cpu')
            elif hasattr(model, 'cpu'):
                model.cpu()
        self._free_memory()

    def _process_super_resolution(self, img_infos, sr_model):
        """Process initial super-resolution phase on ``remove_fg2_image.png``."""
        for img_info in tqdm(img_infos):
            output_path = img_info["output_path"]
            self._prepare_output_dir(output_path)
            input_path = os.path.join(output_path, "remove_fg2_image.png")
            self._run_super_resolution(input_path, output_path, sr_model)

    def _process_segmentation(self, img_infos, zim_predictor, gd_processor, gd_model):
        """Process segmentation phase for sky; indoor images are skipped."""
        for img_info in tqdm(img_infos):
            if self._is_indoor(img_info):
                continue

            output_path = img_info["output_path"]
            img_path = self._get_image_path(output_path, self.SOURCE_CANDIDATES)
            img_sr_path = self._sr_path(img_path)

            seg_utils.get_sky(
                output_path,
                img_path,
                img_sr_path,
                zim_predictor,
                gd_processor,
                gd_model,
                "sky.",
            )

            self._free_memory()

    def _process_inpainting(self, img_infos, inpaint_model):
        """Process inpainting phase for sky.

        Indoor images and items without ``sky_mask.png`` are skipped. Writes
        ``sky_image.png`` into each processed item's output dir.
        """
        for img_info in tqdm(img_infos):
            if self._is_indoor(img_info):
                continue

            base_dir = img_info["output_path"]
            mask_file = os.path.join(base_dir, 'sky_mask.png')
            if not os.path.exists(mask_file):
                continue

            source = self._get_image_path(base_dir, self.SOURCE_CANDIDATES)
            with Image.open(source) as src:
                image = src.convert('RGB')
            size = image.height, image.width

            with Image.open(mask_file) as mask_img:
                mask_sharp = mask_img.convert('L')
            mask_smooth = inpaint_utils.get_smooth_mask(np.asarray(mask_sharp))
            mask_smooth = self._pad_and_resize_mask(mask_smooth, size)
            # NOTE(review): assumes the smoothed mask is in [0, 1] with a dtype
            # Image.fromarray accepts after scaling — confirm in inpaint_utils.
            pano_mask_pil = Image.fromarray(mask_smooth*255)

            result = self._inpaint(
                inpaint_model,
                image,
                pano_mask_pil,
                size,
                self.SKY_PROMPT,
                self.SKY_NEGATIVE_PROMPT,
                guidance_scale=20,
            )
            result.save(os.path.join(base_dir, 'sky_image.png'))

            # Drop per-image buffers before the next iteration allocates more.
            del image, mask_sharp, mask_smooth, pano_mask_pil, result
            self._free_memory()

    def _process_final_super_resolution(self, img_infos, sr_model):
        """Process final super-resolution phase on ``sky_image.png``."""
        for img_info in tqdm(img_infos):
            output_path = img_info["output_path"]
            input_path = os.path.join(output_path, "sky_image.png")
            self._run_super_resolution(input_path, output_path, sr_model)


# Original functions refactored to use the new pipeline classes
def remove_fg1_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model,
                        inpaint_model, params):
    """Process the first foreground layer (fg1)."""
    pipeline = ForegroundPipeline(params, layer=0)
    pipeline.process(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model)


def remove_fg2_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model,
                        inpaint_model, params):
    """Process the second foreground layer (fg2)."""
    pipeline = ForegroundPipeline(params, layer=1)
    pipeline.process(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model)


def sky_pipeline(img_infos, sr_model, zim_predictor, gd_processor, gd_model,
                 inpaint_model, params):
    """Process the sky layer."""
    pipeline = SkyPipeline(params)
    pipeline.process(img_infos, sr_model, zim_predictor, gd_processor, gd_model, inpaint_model)