"""Gradio demo for running MAXIM image-processing checkpoints.

The module builds one tab per entry in ``MODELS``. Checkpoints are downloaded
on first use into the current working directory and the loaded parameters are
cached in memory on the shared ``SimpleMAXIMPredictor`` instance.
"""

import os

# Force CPU-only mode. Set ahead of the JAX import below so the settings are
# already in place when the library loads.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppress TensorFlow warnings

# Suppress JAX compilation warnings
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="jax")

import gradio as gr
import numpy as np
from PIL import Image
import importlib
import ml_collections
import jax.numpy as jnp
import flax
import requests
import tempfile
import spaces

# Try to import MAXIM components
try:
    from maxim.run_eval import (
        _MODEL_FILENAME,
        _MODEL_VARIANT_DICT,
        _MODEL_CONFIGS,
        get_params,
        mod_padding_symmetric,
        make_shape_even,
    )
    MAXIM_AVAILABLE = True
except ImportError as e:
    print(f"MAXIM import failed: {e}")
    MAXIM_AVAILABLE = False

# Model configurations with direct download URLs
MODELS = {
    "Image Enhancement (Retouching)": {
        "task": "Enhancement",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Enhancement/FiveK/checkpoint.npz",
        "filename": "enhancement_retouching.npz",
    },
    "Image Enhancement (Low-light)": {
        "task": "Enhancement",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Enhancement/LOL/checkpoint.npz",
        "filename": "enhancement_lowlight.npz",
    },
    "Image Denoising": {
        "task": "Denoising",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz",
        "filename": "denoising.npz",
    },
    "Image Deblurring": {
        "task": "Deblurring",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Deblurring/GoPro/checkpoint.npz",
        "filename": "deblurring.npz",
    },
    "Image Deraining": {
        "task": "Deraining",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Deraining/Rain13k/checkpoint.npz",
        "filename": "deraining.npz",
    },
    "Image Dehazing": {
        "task": "Dehazing",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Dehazing/SOTS-Indoor/checkpoint.npz",
        "filename": "dehazing_indoor.npz",
    },
}

# (connect, read) timeouts in seconds for checkpoint downloads, so a stalled
# connection surfaces as an error instead of hanging the request forever.
_DOWNLOAD_TIMEOUT = (10, 60)

# File extensions accepted when collecting example images for a tab.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


class SimpleMAXIMPredictor:
    """Lazily builds MAXIM models and runs single-image inference.

    Attributes are plain dicts used as caches:
    ``models`` maps a task name to a model object, ``params`` maps a
    ``MODELS`` key to the parameters loaded from its checkpoint.
    """

    def __init__(self):
        self.models = {}
        self.params = {}
        self.initialized = False

    def initialize(self):
        """Initialize models when first needed.

        Returns:
            True when the models are built (now or on an earlier call),
            False when MAXIM is unavailable or construction failed.
        """
        if self.initialized or not MAXIM_AVAILABLE:
            return self.initialized

        try:
            # Build models for each task
            model_mod = importlib.import_module(f'maxim.models.{_MODEL_FILENAME}')

            for task in _MODEL_VARIANT_DICT.keys():
                model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS)
                model_configs.variant = _MODEL_VARIANT_DICT[task]
                self.models[task] = model_mod.Model(**model_configs)

            self.initialized = True
            return True
        except Exception as e:
            print(f"Initialization failed: {e}")
            return False

    @staticmethod
    def _download_file(url, destination):
        """Stream ``url`` to ``destination`` without exposing partial files.

        The payload is written to a temporary file in the destination's
        directory and moved into place with ``os.replace`` only after the
        whole body has been received. An interrupted download therefore never
        leaves a truncated checkpoint that a later ``os.path.exists`` check
        would mistake for a complete one.

        Raises:
            Exception: whatever ``requests`` or the file system raises; the
                temporary file is removed before re-raising.
        """
        target_dir = os.path.dirname(os.path.abspath(destination))
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        try:
            with os.fdopen(fd, 'wb') as f:
                with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as response:
                    response.raise_for_status()
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(tmp_path, destination)
        except BaseException:
            # Best-effort cleanup; the original error is what matters.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def download_model(self, model_name):
        """Download model checkpoint if not already downloaded.

        Args:
            model_name: a key of ``MODELS``.

        Returns:
            ``(ok, message)`` where ``ok`` is a bool and ``message`` is
            ``"Success"`` or a human-readable failure description.
        """
        if model_name not in MODELS:
            return False, f"Model {model_name} not found"

        model_info = MODELS[model_name]
        filename = model_info["filename"]

        if not os.path.exists(filename):
            try:
                print(f"Downloading {filename}...")
                self._download_file(model_info["url"], filename)
                print(f"Downloaded {filename}")
            except Exception as e:
                return False, f"Failed to download {filename}: {str(e)}"

        # Load parameters
        if model_name not in self.params:
            try:
                self.params[model_name] = get_params(filename)
            except Exception as e:
                return False, f"Failed to load parameters: {str(e)}"

        return True, "Success"

    def preprocess_image(self, image):
        """Preprocess image for model input.

        Args:
            image: a PIL image (anything exposing ``convert('RGB')``).

        Returns:
            ``(batch, height, width, height_even, width_even)``: the padded
            float32 array with a leading batch axis, the original image size,
            and the size after ``make_shape_even``.
        """
        # Convert to numpy array scaled to [0, 1]
        input_img = np.asarray(image.convert('RGB'), np.float32) / 255.0

        # Store original dimensions
        height, width = input_img.shape[0], input_img.shape[1]

        # Make shape even
        input_img = make_shape_even(input_img)
        height_even, width_even = input_img.shape[0], input_img.shape[1]

        # Pad to multiples of 64
        input_img = mod_padding_symmetric(input_img, factor=64)
        input_img = np.expand_dims(input_img, axis=0)

        return input_img, height, width, height_even, width_even

    def postprocess_image(self, preds, height, width, height_even, width_even):
        """Postprocess model output to get final image.

        Args:
            preds: model output; either an array with a leading batch axis or
                a list (possibly a list of lists) whose last element is one.
            height, width: size of the original input image.
            height_even, width_even: size after the even-shape adjustment.

        Returns:
            A PIL image cropped back to ``height`` x ``width``.
        """
        # Handle multi-stage outputs: keep the last entry, up to two levels deep
        if isinstance(preds, list):
            preds = preds[-1]
            if isinstance(preds, list):
                preds = preds[-1]

        preds = np.array(preds[0], np.float32)

        # Unpad to original resolution. The crop is centred using the even
        # dimensions, then extends by the original height/width.
        new_height, new_width = preds.shape[0], preds.shape[1]
        h_start = new_height // 2 - height_even // 2
        h_end = h_start + height
        w_start = new_width // 2 - width_even // 2
        w_end = w_start + width
        preds = preds[h_start:h_end, w_start:w_end, :]

        # Convert to PIL Image
        output_img = (np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)
        return Image.fromarray(output_img)

    # NOTE(review): the environment at the top of the file forces CPU-only
    # execution while this decorator is applied here and again on
    # process_image - confirm that combination is intended.
    @spaces.GPU
    def predict(self, image, model_name):
        """Main prediction function.

        Returns:
            ``(output_image, status)``; ``output_image`` is ``None`` whenever
            ``status`` describes an error.
        """
        if not self.initialize():
            return None, "Error: Could not initialize model"

        success, message = self.download_model(model_name)
        if not success:
            return None, message

        try:
            import jax

            # Force CPU mode - device info
            device_info = f"Using device: {jax.default_backend()} (CPU-only mode)"
            print(device_info)

            # Get model and parameters
            task = MODELS[model_name]["task"]
            model = self.models[task]
            params = self.params[model_name]

            # Preprocess
            input_img, height, width, height_even, width_even = self.preprocess_image(image)

            # Predict
            preds = model.apply({'params': flax.core.freeze(params)}, input_img)

            # Postprocess
            output_image = self.postprocess_image(preds, height, width, height_even, width_even)

            return output_image, f"Success - {device_info}"

        except Exception as e:
            # UI boundary: report the failure in the status box instead of raising.
            return None, f"Error: {str(e)}"


# Global predictor
predictor = None


def get_predictor():
    """Lazy initialization of predictor."""
    global predictor
    if predictor is None:
        predictor = SimpleMAXIMPredictor()
    return predictor


@spaces.GPU
def process_image(image, model_name):
    """Gradio interface function.

    Args:
        image: the uploaded PIL image, or ``None`` when nothing was uploaded.
        model_name: a key of ``MODELS``.

    Returns:
        ``(result_image, status_message)`` for the output image and status box.
    """
    if image is None:
        return None, "Please upload an image"

    if not MAXIM_AVAILABLE:
        return None, "Error: MAXIM library not available"

    try:
        pred = get_predictor()
        result_image, message = pred.predict(image, model_name)
        return result_image, message
    except Exception as e:
        return None, f"Processing error: {str(e)}"


# One entry per tab: (model name / tab title, button label, output label,
# directory of example images or None when the tab has no examples).
_TABS = (
    ("Image Enhancement (Retouching)", "Enhance Image", "Enhanced Image",
     "maxim/images/Enhancement/input"),
    ("Image Enhancement (Low-light)", "Enhance Low-light", "Enhanced Image",
     None),
    ("Image Denoising", "Denoise Image", "Denoised Image",
     "maxim/images/Denoising/input"),
    ("Image Deblurring", "Deblur Image", "Deblurred Image",
     "maxim/images/Deblurring/input"),
    ("Image Deraining", "Remove Rain", "Derained Image",
     "maxim/images/Deraining/input"),
    ("Image Dehazing", "Remove Haze", "Dehazed Image",
     "maxim/images/Dehazing/input"),
)


def _list_example_images(directory, limit=3):
    """Return up to ``limit`` image paths found directly inside ``directory``."""
    if not os.path.exists(directory):
        return []
    return [
        os.path.join(directory, f)
        for f in os.listdir(directory)
        if f.lower().endswith(_IMAGE_EXTENSIONS)
    ][:limit]


def _make_handler(model_name):
    """Return a one-argument click handler bound to ``model_name``."""
    # A factory (rather than a lambda created in a loop) binds the model name
    # per tab and avoids the late-binding closure pitfall.
    def handler(img):
        return process_image(img, model_name)

    return handler


def _build_task_tab(model_name, button_label, output_label, examples_dir):
    """Create one tab: input + button on the left, output + status on the right."""
    with gr.TabItem(model_name):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                run_button = gr.Button(button_label, variant="primary")
            with gr.Column():
                output_image = gr.Image(type="pil", label=output_label)
                status_box = gr.Textbox(label="Status", interactive=False)

        run_button.click(
            fn=_make_handler(model_name),
            inputs=[input_image],
            outputs=[output_image, status_box],
        )

        if examples_dir is not None:
            example_files = _list_example_images(examples_dir)
            if example_files:
                gr.Examples(examples=[[f] for f in example_files], inputs=[input_image])


# Create Gradio interface
with gr.Blocks(title="MAXIM: Multi-Axis MLP for Image Processing") as demo:
    gr.Markdown("""
    # MAXIM: Multi-Axis MLP for Image Processing

    This Space demonstrates the MAXIM model for various image processing tasks.

    **Paper**: [MAXIM: Multi-Axis MLP for Image Processing](https://arxiv.org/abs/2201.02973) (CVPR 2022 Oral)
    """)

    with gr.Tabs():
        for _tab_spec in _TABS:
            _build_task_tab(*_tab_spec)


if __name__ == "__main__":
    demo.launch()