# Copyright 2022 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1" import subprocess subprocess.call(["pip", "install", "."]) import numpy as np from PIL import Image import importlib import ml_collections import tempfile import jax.numpy as jnp import flax from cog import BasePredictor, Path, Input, BaseModel from maxim.run_eval import ( _MODEL_FILENAME, _MODEL_VARIANT_DICT, _MODEL_CONFIGS, get_params, mod_padding_symmetric, make_shape_even, augment_image, ) class Predictor(BasePredictor): def setup(self): self.params = { "Image Denoising": get_params("checkpoints/denoising-SIDD/checkpoint.npz"), "Image Deblurring (GoPro)": get_params( "checkpoints/debluring-GoPro/checkpoint.npz" ), "Image Deblurring (REDS)": get_params( "checkpoints/debluring-REDS/checkpoint.npz" ), "Image Deblurring (RealBlur_R)": get_params( "checkpoints/debluring-Real-Blur-R/checkpoint.npz" ), "Image Deblurring (RealBlur_J)": get_params( "checkpoints/debluring-Real-Blur-J/checkpoint.npz" ), "Image Deraining (Rain streak)": get_params( "checkpoints/deraining-Rain13k/checkpoint.npz" ), "Image Deraining (Rain drop)": get_params( "checkpoints/deraining-Raindrop/checkpoint.npz" ), "Image Dehazing (Indoor)": get_params( "checkpoints/dehazing-RESIDE-Indoor/checkpoint.npz" ), "Image Dehazing (Outdoor)": get_params( "checkpoints/dehazing-RESIDE-Outdoor/checkpoint.npz" ), "Image Enhancement (Low-light)": get_params( "checkpoints/enhancement-LOL/checkpoint.npz" ), "Image Enhancement (Retouching)": get_params( "checkpoints/enhancement-FiveK/checkpoint.npz" ), } model_mod = importlib.import_module(f"maxim.models.{_MODEL_FILENAME}") self.models = {} for task in _MODEL_VARIANT_DICT.keys(): model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS) model_configs.variant = _MODEL_VARIANT_DICT[task] self.models[task] = model_mod.Model(**model_configs) def predict( self, model: str = Input( choices=[ "Image Denoising", "Image Deblurring (GoPro)", "Image Deblurring (REDS)", "Image Deblurring (RealBlur_R)", "Image Deblurring (RealBlur_J)", "Image Deraining (Rain streak)", "Image Deraining (Rain drop)", "Image Dehazing (Indoor)", "Image Dehazing (Outdoor)", "Image Enhancement (Low-light)", "Image Enhancement (Retouching)", ], description="Choose a model.", ), image: Path = Input( description="Input image.", ), ) -> Path: params = self.params[model] task = model.split()[1] model = self.models[task] input_img = ( np.asarray(Image.open(str(image)).convert("RGB"), np.float32) / 255.0 ) # Padding images to have even shapes height, width = input_img.shape[0], input_img.shape[1] input_img = make_shape_even(input_img) height_even, width_even = input_img.shape[0], input_img.shape[1] # padding images to be multiplies of 64 input_img = mod_padding_symmetric(input_img, factor=64) input_img = np.expand_dims(input_img, axis=0) # handle multi-stage outputs, obtain the last scale output of last stage preds = model.apply({"params": flax.core.freeze(params)}, input_img) if isinstance(preds, list): preds = preds[-1] if isinstance(preds, list): preds = preds[-1] preds = np.array(preds[0], np.float32) # unpad images to get the original resolution new_height, new_width = preds.shape[0], preds.shape[1] h_start = new_height // 2 - height_even // 2 h_end = h_start + height w_start = new_width // 2 - width_even // 2 w_end = w_start + width preds = preds[h_start:h_end, w_start:w_end, :] # save files out_path = Path(tempfile.mkdtemp()) / "output.png" Image.fromarray( np.array((np.clip(preds, 0.0, 1.0) * 255.0).astype(jnp.uint8)) ).save(str(out_path)) return out_path