# GoogleMaxim / app.py
# NightRaven109's picture
# Update app.py
# 2661499 verified
import os
# Force CPU-only mode
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Suppress TensorFlow warnings
# Suppress JAX compilation warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="jax")
import gradio as gr
import numpy as np
from PIL import Image
import importlib
import ml_collections
import jax.numpy as jnp
import flax
import requests
import tempfile
import spaces
# Optional dependency: the MAXIM helpers are pulled in when the package can
# be imported; otherwise the app keeps running and reports the problem later.
try:
    from maxim.run_eval import (
        _MODEL_FILENAME,
        _MODEL_VARIANT_DICT,
        _MODEL_CONFIGS,
        get_params,
        mod_padding_symmetric,
        make_shape_even,
    )
except ImportError as e:
    print(f"MAXIM import failed: {e}")
    MAXIM_AVAILABLE = False
else:
    # Only reached when every requested name was imported successfully.
    MAXIM_AVAILABLE = True
# Model configurations with direct download URLs.
# Each row: (display name, task, checkpoint URL, local checkpoint filename).
_MODEL_ROWS = (
    (
        "Image Enhancement (Retouching)",
        "Enhancement",
        "https://storage.googleapis.com/gresearch/maxim/ckpt/Enhancement/FiveK/checkpoint.npz",
        "enhancement_retouching.npz",
    ),
    (
        "Image Enhancement (Low-light)",
        "Enhancement",
        "https://storage.googleapis.com/gresearch/maxim/ckpt/Enhancement/LOL/checkpoint.npz",
        "enhancement_lowlight.npz",
    ),
    (
        "Image Denoising",
        "Denoising",
        "https://storage.googleapis.com/gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz",
        "denoising.npz",
    ),
    (
        "Image Deblurring",
        "Deblurring",
        "https://storage.googleapis.com/gresearch/maxim/ckpt/Deblurring/GoPro/checkpoint.npz",
        "deblurring.npz",
    ),
    (
        "Image Deraining",
        "Deraining",
        "https://storage.googleapis.com/gresearch/maxim/ckpt/Deraining/Rain13k/checkpoint.npz",
        "deraining.npz",
    ),
    (
        "Image Dehazing",
        "Dehazing",
        "https://storage.googleapis.com/gresearch/maxim/ckpt/Dehazing/SOTS-Indoor/checkpoint.npz",
        "dehazing_indoor.npz",
    ),
)

# Display name -> {"task", "url", "filename"}, in the order listed above.
MODELS = {
    display_name: {"task": task, "url": url, "filename": filename}
    for display_name, task, url, filename in _MODEL_ROWS
}
class SimpleMAXIMPredictor:
    """Lazy wrapper around the MAXIM models used by this demo.

    Model objects are built once per task on first use. Checkpoint
    parameters are downloaded and loaded on demand and cached per entry of
    ``MODELS``.
    """

    def __init__(self):
        # Task name -> model instance; filled by initialize().
        self.models = {}
        # MODELS key -> parameters returned by get_params().
        self.params = {}
        self.initialized = False

    def initialize(self):
        """Build one model per task the first time it is needed.

        Returns:
            True when the models are ready; False when MAXIM is unavailable
            or building the models failed.
        """
        if self.initialized or not MAXIM_AVAILABLE:
            return self.initialized
        try:
            model_mod = importlib.import_module(f'maxim.models.{_MODEL_FILENAME}')
            # Build into a scratch dict so a failure part-way through does
            # not leave self.models half-populated.
            built = {}
            for task, variant in _MODEL_VARIANT_DICT.items():
                model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS)
                model_configs.variant = variant
                built[task] = model_mod.Model(**model_configs)
        except Exception as e:
            print(f"Initialization failed: {e}")
            return False
        self.models.update(built)
        self.initialized = True
        return True

    @staticmethod
    def _fetch_checkpoint(url, filename, timeout=(10, 60)):
        """Stream ``url`` into ``filename``, publishing it only when complete.

        The data is written to a temporary file in the destination directory
        and moved into place with ``os.replace``. An interrupted download
        therefore never leaves a truncated file under ``filename``, which
        later calls would otherwise mistake for a finished checkpoint.

        Args:
            url: Address of the checkpoint to download.
            filename: Final path of the checkpoint on disk.
            timeout: ``(connect, read)`` timeout in seconds passed to
                ``requests.get``.

        Raises:
            requests.RequestException: On connection, timeout or HTTP errors.
            OSError: If the file cannot be written or moved into place.
        """
        target_dir = os.path.dirname(os.path.abspath(filename))
        fd, partial_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        try:
            with os.fdopen(fd, 'wb') as out, \
                    requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        out.write(chunk)
            os.replace(partial_path, filename)
        finally:
            # After a successful os.replace the temporary path is gone, so
            # this only cleans up after a failure.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    def download_model(self, model_name):
        """Download the checkpoint if it is missing and load its parameters.

        Args:
            model_name: Key of ``MODELS`` identifying the checkpoint.

        Returns:
            ``(True, "Success")`` when the parameters are cached in
            ``self.params``; otherwise ``(False, message)`` describing the
            failure.
        """
        if model_name not in MODELS:
            return False, f"Model {model_name} not found"
        model_info = MODELS[model_name]
        filename = model_info["filename"]
        if not os.path.exists(filename):
            try:
                print(f"Downloading {filename}...")
                self._fetch_checkpoint(model_info["url"], filename)
                print(f"Downloaded {filename}")
            except Exception as e:
                return False, f"Failed to download {filename}: {str(e)}"
        # Load parameters once per model name.
        if model_name not in self.params:
            try:
                self.params[model_name] = get_params(filename)
            except Exception as e:
                return False, f"Failed to load parameters: {str(e)}"
        return True, "Success"

    def preprocess_image(self, image):
        """Turn a PIL image into a padded, batched float array.

        Args:
            image: Input image; it is converted to RGB first.

        Returns:
            Tuple ``(input_img, height, width, height_even, width_even)``:
            the batched array, the original size, and the size after
            ``make_shape_even``. The sizes are what ``postprocess_image``
            needs to crop the padding away again.
        """
        # Convert to a float array scaled to [0, 1].
        input_img = np.asarray(image.convert('RGB'), np.float32) / 255.0
        # Store original dimensions.
        height, width = input_img.shape[0], input_img.shape[1]
        # Make shape even.
        input_img = make_shape_even(input_img)
        height_even, width_even = input_img.shape[0], input_img.shape[1]
        # Pad to multiples of 64.
        input_img = mod_padding_symmetric(input_img, factor=64)
        # Add the leading batch axis.
        input_img = np.expand_dims(input_img, axis=0)
        return input_img, height, width, height_even, width_even

    def postprocess_image(self, preds, height, width, height_even, width_even):
        """Convert the model output back into a PIL image of the input size.

        Args:
            preds: Model output; either an array or a (possibly nested) list
                of outputs, of which the last entry is used.
            height: Original image height.
            width: Original image width.
            height_even: Image height after ``make_shape_even``.
            width_even: Image width after ``make_shape_even``.

        Returns:
            The cropped prediction as an 8-bit PIL image.
        """
        # Handle multi-stage outputs: take the last entry, up to two levels deep.
        if isinstance(preds, list):
            preds = preds[-1]
            if isinstance(preds, list):
                preds = preds[-1]
        # Drop the batch axis.
        preds = np.array(preds[0], np.float32)
        # Unpad to original resolution: offsets are measured from the centre
        # of the padded output using the even-padded size.
        new_height, new_width = preds.shape[0], preds.shape[1]
        h_start = new_height // 2 - height_even // 2
        h_end = h_start + height
        w_start = new_width // 2 - width_even // 2
        w_end = w_start + width
        preds = preds[h_start:h_end, w_start:w_end, :]
        # Clip to [0, 1] and scale to 8-bit before building the PIL image.
        output_img = (np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)
        return Image.fromarray(output_img)

    # NOTE(review): the module forces CPU-only execution through environment
    # variables and process_image carries the same decorator - confirm that
    # decorating this method as well is intended.
    @spaces.GPU
    def predict(self, image, model_name):
        """Run the selected model on ``image``.

        Args:
            image: PIL image to process.
            model_name: Key of ``MODELS`` selecting task and checkpoint.

        Returns:
            ``(output_image, message)``. ``output_image`` is ``None`` when
            initialization, download or inference failed, and ``message``
            then describes the error.
        """
        if not self.initialize():
            return None, "Error: Could not initialize model"
        success, message = self.download_model(model_name)
        if not success:
            return None, message
        try:
            import jax
            # Report which backend JAX actually selected.
            device_info = f"Using device: {jax.default_backend()} (CPU-only mode)"
            print(device_info)
            # Get model and parameters.
            task = MODELS[model_name]["task"]
            model = self.models[task]
            params = self.params[model_name]
            # Preprocess.
            input_img, height, width, height_even, width_even = self.preprocess_image(image)
            # Predict.
            preds = model.apply({'params': flax.core.freeze(params)}, input_img)
            # Postprocess.
            output_image = self.postprocess_image(preds, height, width, height_even, width_even)
            return output_image, f"Success - {device_info}"
        except Exception as e:
            # Boundary for the UI: report the failure instead of raising.
            return None, f"Error: {str(e)}"
# Module-wide predictor instance, created on first request.
predictor = None


def get_predictor():
    """Return the shared predictor, creating it on the first call."""
    global predictor
    if predictor is not None:
        return predictor
    predictor = SimpleMAXIMPredictor()
    return predictor
@spaces.GPU
def process_image(image, model_name):
    """Gradio callback: run ``model_name`` on ``image``.

    Returns an ``(image, status message)`` pair; the image is ``None``
    whenever nothing could be produced.
    """
    # Guard clauses for the cases where no prediction is attempted.
    if image is None:
        return None, "Please upload an image"
    if not MAXIM_AVAILABLE:
        return None, "Error: MAXIM library not available"

    try:
        result_image, message = get_predictor().predict(image, model_name)
    except Exception as e:
        return None, f"Processing error: {str(e)}"
    return result_image, message
# Create Gradio interface

# Introductory text shown above the tabs.
_HEADER_MARKDOWN = """
# MAXIM: Multi-Axis MLP for Image Processing
This Space demonstrates the MAXIM model for various image processing tasks.
**Paper**: [MAXIM: Multi-Axis MLP for Image Processing](https://arxiv.org/abs/2201.02973) (CVPR 2022 Oral)
"""

# File extensions accepted as example images.
_EXAMPLE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

# One row per tab:
# (model name, also used as tab label; button text; output label;
#  directory with example inputs, or None for a tab without examples).
_TAB_SPECS = (
    ("Image Enhancement (Retouching)", "Enhance Image", "Enhanced Image",
     "maxim/images/Enhancement/input"),
    ("Image Enhancement (Low-light)", "Enhance Low-light", "Enhanced Image",
     None),
    ("Image Denoising", "Denoise Image", "Denoised Image",
     "maxim/images/Denoising/input"),
    ("Image Deblurring", "Deblur Image", "Deblurred Image",
     "maxim/images/Deblurring/input"),
    ("Image Deraining", "Remove Rain", "Derained Image",
     "maxim/images/Deraining/input"),
    ("Image Dehazing", "Remove Haze", "Dehazed Image",
     "maxim/images/Dehazing/input"),
)


def _list_example_images(directory, limit=3):
    """Return up to ``limit`` image paths from ``directory``, sorted by name.

    Returns an empty list when ``directory`` is not an existing directory.
    Sorting makes the selection deterministic, because ``os.listdir`` returns
    entries in arbitrary order.
    """
    if not os.path.isdir(directory):
        return []
    names = sorted(
        name for name in os.listdir(directory)
        if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    )
    return [os.path.join(directory, name) for name in names[:limit]]


def _build_task_tab(model_name, button_text, output_label, examples_dir=None):
    """Add one tab wired to ``process_image`` for ``model_name``.

    Must be called inside an active ``gr.Tabs()`` context.
    """
    with gr.TabItem(model_name):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                run_button = gr.Button(button_text, variant="primary")
            with gr.Column():
                output_image = gr.Image(type="pil", label=output_label)
                status = gr.Textbox(label="Status", interactive=False)
        # model_name is a parameter of this call, so each tab's callback
        # keeps its own value.
        run_button.click(
            fn=lambda img: process_image(img, model_name),
            inputs=[input_image],
            outputs=[output_image, status],
        )
        example_files = _list_example_images(examples_dir) if examples_dir else []
        if example_files:
            gr.Examples(examples=[[f] for f in example_files], inputs=[input_image])


with gr.Blocks(title="MAXIM: Multi-Axis MLP for Image Processing") as demo:
    gr.Markdown(_HEADER_MARKDOWN)
    with gr.Tabs():
        for _tab_spec in _TAB_SPECS:
            _build_task_tab(*_tab_spec)

if __name__ == "__main__":
    demo.launch()