# Page header captured along with the source: "Spaces: Running on Zero".
# It is hosting-page text, not part of the program.
import os

# Force CPU-only mode.
# These variables are set before JAX / Flax are imported below so that they
# are already in the environment when those libraries load.
os.environ["JAX_PLATFORM_NAME"] = "cpu"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppress TensorFlow warnings

# Suppress JAX compilation warnings
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="jax")

import gradio as gr
import numpy as np
from PIL import Image
import importlib
import ml_collections
import jax.numpy as jnp
import flax
import requests
import tempfile
import spaces
# Try to import MAXIM components.
# MAXIM_AVAILABLE records whether the import succeeded; when it is False the
# names imported below are undefined, so callers must check the flag first.
try:
    from maxim.run_eval import (
        _MODEL_FILENAME,
        _MODEL_VARIANT_DICT,
        _MODEL_CONFIGS,
        get_params,
        mod_padding_symmetric,
        make_shape_even,
    )
    MAXIM_AVAILABLE = True
except ImportError as e:
    print(f"MAXIM import failed: {e}")
    MAXIM_AVAILABLE = False
# Model configurations with direct download URLs.
# Keyed by the user-facing model name. Each entry holds:
#   "task":     key used to look up the model built for that task
#   "url":      location the checkpoint is downloaded from
#   "filename": local path the checkpoint is stored under
MODELS = {
    "Image Enhancement (Retouching)": {
        "task": "Enhancement",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Enhancement/FiveK/checkpoint.npz",
        "filename": "enhancement_retouching.npz",
    },
    "Image Enhancement (Low-light)": {
        "task": "Enhancement",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Enhancement/LOL/checkpoint.npz",
        "filename": "enhancement_lowlight.npz",
    },
    "Image Denoising": {
        "task": "Denoising",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Denoising/SIDD/checkpoint.npz",
        "filename": "denoising.npz",
    },
    "Image Deblurring": {
        "task": "Deblurring",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Deblurring/GoPro/checkpoint.npz",
        "filename": "deblurring.npz",
    },
    "Image Deraining": {
        "task": "Deraining",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Deraining/Rain13k/checkpoint.npz",
        "filename": "deraining.npz",
    },
    "Image Dehazing": {
        "task": "Dehazing",
        "url": "https://storage.googleapis.com/gresearch/maxim/ckpt/Dehazing/SOTS-Indoor/checkpoint.npz",
        "filename": "dehazing_indoor.npz",
    },
}
class SimpleMAXIMPredictor:
    """Run MAXIM models on PIL images, building models and loading weights lazily.

    Attributes:
        models: maps a task name (a key of ``_MODEL_VARIANT_DICT``) to the
            model object built for it.
        params: maps a model name (a key of ``MODELS``) to the parameters
            returned by ``get_params`` for its checkpoint.
        initialized: True once all models have been built successfully.
    """

    # (connect, read) timeouts in seconds, so a stalled server cannot hang
    # a request forever.
    _DOWNLOAD_TIMEOUT = (10, 60)
    _DOWNLOAD_CHUNK_SIZE = 8192

    def __init__(self):
        self.models = {}
        self.params = {}
        self.initialized = False

    def initialize(self):
        """Build one model per task the first time it is needed.

        Returns:
            True if the models are ready, False if MAXIM is unavailable or
            building the models failed.
        """
        if self.initialized or not MAXIM_AVAILABLE:
            return self.initialized
        try:
            # Build models for each task
            model_mod = importlib.import_module(f'maxim.models.{_MODEL_FILENAME}')
            for task, variant in _MODEL_VARIANT_DICT.items():
                model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS)
                model_configs.variant = variant
                self.models[task] = model_mod.Model(**model_configs)
            self.initialized = True
            return True
        except Exception as e:
            print(f"Initialization failed: {e}")
            return False

    @classmethod
    def _fetch_to_file(cls, url, filename):
        """Stream ``url`` into ``filename``, never leaving a partial file there.

        The data is written to a temporary file in the destination directory
        and moved into place with ``os.replace`` only after the download has
        completed. An interrupted download therefore cannot be mistaken for a
        finished checkpoint by a later ``os.path.exists`` check.

        Raises:
            Whatever ``requests`` or the file operations raise; the temporary
            file is removed before the exception propagates.
        """
        target_dir = os.path.dirname(os.path.abspath(filename))
        # Same directory as the target so os.replace stays on one filesystem.
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        try:
            with os.fdopen(fd, 'wb') as f:
                with requests.get(url, stream=True, timeout=cls._DOWNLOAD_TIMEOUT) as response:
                    response.raise_for_status()
                    for chunk in response.iter_content(chunk_size=cls._DOWNLOAD_CHUNK_SIZE):
                        if chunk:
                            f.write(chunk)
            os.replace(tmp_path, filename)
        except BaseException:
            # Best-effort cleanup; the original error is what matters.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def download_model(self, model_name):
        """Download the checkpoint for ``model_name`` if needed and load it.

        Args:
            model_name: a key of ``MODELS``.

        Returns:
            ``(True, "Success")`` when the parameters are available in
            ``self.params``, otherwise ``(False, <error message>)``.
        """
        if model_name not in MODELS:
            return False, f"Model {model_name} not found"
        model_info = MODELS[model_name]
        filename = model_info["filename"]
        if not os.path.exists(filename):
            try:
                print(f"Downloading {filename}...")
                self._fetch_to_file(model_info["url"], filename)
                print(f"Downloaded {filename}")
            except Exception as e:
                return False, f"Failed to download {filename}: {str(e)}"
        # Load parameters
        if model_name not in self.params:
            try:
                self.params[model_name] = get_params(filename)
            except Exception as e:
                return False, f"Failed to load parameters: {str(e)}"
        return True, "Success"

    def preprocess_image(self, image):
        """Convert a PIL image into a padded, batched float array.

        Args:
            image: a PIL image; it is converted to RGB.

        Returns:
            ``(input_img, height, width, height_even, width_even)`` where
            ``height``/``width`` are the original dimensions and
            ``height_even``/``width_even`` are the dimensions after
            ``make_shape_even``. ``postprocess_image`` needs both pairs to
            crop the output back.
        """
        # Convert to numpy array scaled to [0, 1]
        input_img = np.asarray(image.convert('RGB'), np.float32) / 255.0
        # Store original dimensions
        height, width = input_img.shape[0], input_img.shape[1]
        # Make shape even
        input_img = make_shape_even(input_img)
        height_even, width_even = input_img.shape[0], input_img.shape[1]
        # Pad to multiples of 64
        input_img = mod_padding_symmetric(input_img, factor=64)
        # Add a leading batch axis
        input_img = np.expand_dims(input_img, axis=0)
        return input_img, height, width, height_even, width_even

    def postprocess_image(self, preds, height, width, height_even, width_even):
        """Turn raw model output into a PIL image at the original size.

        Args:
            preds: model output; may be a list, or a list of lists, in which
                case the last element of the last entry is used.
            height, width: original image dimensions.
            height_even, width_even: dimensions after ``make_shape_even``.

        Returns:
            A PIL image of size ``width`` x ``height``.
        """
        # Handle multi-stage outputs
        if isinstance(preds, list):
            preds = preds[-1]
            if isinstance(preds, list):
                preds = preds[-1]
        # Drop the batch axis
        preds = np.array(preds[0], np.float32)
        # Unpad to original resolution: the offsets come from the even-sized
        # image, while the crop extent is the original size.
        new_height, new_width = preds.shape[0], preds.shape[1]
        h_start = new_height // 2 - height_even // 2
        h_end = h_start + height
        w_start = new_width // 2 - width_even // 2
        w_end = w_start + width
        preds = preds[h_start:h_end, w_start:w_end, :]
        # Convert to PIL Image
        output_img = (np.clip(preds, 0.0, 1.0) * 255.0).astype(np.uint8)
        return Image.fromarray(output_img)

    def predict(self, image, model_name):
        """Run the model named ``model_name`` on ``image``.

        Args:
            image: a PIL image.
            model_name: a key of ``MODELS``.

        Returns:
            ``(output_image, status_message)``; ``output_image`` is None when
            initialization, download, or inference fails.
        """
        if not self.initialize():
            return None, "Error: Could not initialize model"
        success, message = self.download_model(model_name)
        if not success:
            return None, message
        try:
            import jax
            # Force CPU mode - device info
            device_info = f"Using device: {jax.default_backend()} (CPU-only mode)"
            print(device_info)
            # Get model and parameters
            task = MODELS[model_name]["task"]
            model = self.models[task]
            params = self.params[model_name]
            # Preprocess
            input_img, height, width, height_even, width_even = self.preprocess_image(image)
            # Predict
            preds = model.apply({'params': flax.core.freeze(params)}, input_img)
            # Postprocess
            output_image = self.postprocess_image(preds, height, width, height_even, width_even)
            return output_image, f"Success - {device_info}"
        except Exception as e:
            return None, f"Error: {str(e)}"
# Global predictor, created on first use by get_predictor().
predictor = None


def get_predictor():
    """Return the shared SimpleMAXIMPredictor, creating it on first call."""
    global predictor
    if predictor is None:
        predictor = SimpleMAXIMPredictor()
    return predictor
def process_image(image, model_name):
    """Gradio callback: run ``model_name`` on ``image``.

    Args:
        image: the uploaded PIL image, or None when nothing was uploaded.
        model_name: name of the model to apply.

    Returns:
        ``(result_image, status_message)``. ``result_image`` is None whenever
        the image is missing, MAXIM is unavailable, or processing fails; the
        message then says why.
    """
    if image is None:
        return None, "Please upload an image"
    if not MAXIM_AVAILABLE:
        return None, "Error: MAXIM library not available"
    try:
        pred = get_predictor()
        result_image, message = pred.predict(image, model_name)
        return result_image, message
    except Exception as e:
        # Report the failure in the status box instead of raising into the UI.
        return None, f"Processing error: {str(e)}"
# Create Gradio interface
_HEADER_MARKDOWN = """
# MAXIM: Multi-Axis MLP for Image Processing
This Space demonstrates the MAXIM model for various image processing tasks.
**Paper**: [MAXIM: Multi-Axis MLP for Image Processing](https://arxiv.org/abs/2201.02973) (CVPR 2022 Oral)
"""

_EXAMPLE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
_MAX_EXAMPLES = 3

# One entry per tab:
# (model name, also used as tab title; button label; output label;
#  directory holding example inputs, or None for a tab without examples)
_TAB_SPECS = [
    ("Image Enhancement (Retouching)", "Enhance Image", "Enhanced Image",
     "maxim/images/Enhancement/input"),
    ("Image Enhancement (Low-light)", "Enhance Low-light", "Enhanced Image",
     None),
    ("Image Denoising", "Denoise Image", "Denoised Image",
     "maxim/images/Denoising/input"),
    ("Image Deblurring", "Deblur Image", "Deblurred Image",
     "maxim/images/Deblurring/input"),
    ("Image Deraining", "Remove Rain", "Derained Image",
     "maxim/images/Deraining/input"),
    ("Image Dehazing", "Remove Haze", "Dehazed Image",
     "maxim/images/Dehazing/input"),
]


def _list_example_images(directory):
    """Return up to _MAX_EXAMPLES image paths from ``directory`` (may be empty).

    File names are sorted first so the chosen examples do not depend on the
    arbitrary order in which ``os.listdir`` returns entries.
    """
    if directory is None or not os.path.isdir(directory):
        return []
    names = sorted(
        f for f in os.listdir(directory)
        if f.lower().endswith(_EXAMPLE_EXTENSIONS)
    )
    return [os.path.join(directory, f) for f in names[:_MAX_EXAMPLES]]


def _make_handler(model_name):
    """Return a one-argument click handler bound to ``model_name``."""
    def handler(img):
        return process_image(img, model_name)
    return handler


def _build_tab(model_name, button_label, output_label, examples_dir):
    """Add one tab (input, button, output, status, optional examples)."""
    with gr.TabItem(model_name):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="pil", label="Input Image")
                run_button = gr.Button(button_label, variant="primary")
            with gr.Column():
                output_image = gr.Image(type="pil", label=output_label)
                status = gr.Textbox(label="Status", interactive=False)
        run_button.click(
            fn=_make_handler(model_name),
            inputs=[input_image],
            outputs=[output_image, status],
        )
        example_files = _list_example_images(examples_dir)
        if example_files:
            gr.Examples(examples=[[f] for f in example_files], inputs=[input_image])


with gr.Blocks(title="MAXIM: Multi-Axis MLP for Image Processing") as demo:
    gr.Markdown(_HEADER_MARKDOWN)
    with gr.Tabs():
        for _spec in _TAB_SPECS:
            _build_tab(*_spec)

if __name__ == "__main__":
    demo.launch()