"""Gradio playground app: upload an image and get it echoed back.

At start-up the script downloads two example pictures (a cat and a dog) from
Wikimedia Commons into the working directory, builds a single-input /
single-output ``gr.Interface`` around :func:`do_process`, and launches it.

NOTE: the title/description advertise Integrated Gradients, but the
attribution pipeline is currently commented out, so the app only returns the
uploaded image.
"""

import copy
import datetime
import io
import os
import pickle
import time
import urllib.request

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image, ImageFilter

# Disabled experiment: MMOCR text detection/recognition set-up.
# from mim import install
# install('mmcv-full')
# install('mmengine')
# install('mmdet')
# from mmocr.apis import MMOCRInferencer
# ocr = MMOCRInferencer(det='TextSnake', rec='ABINet_Vision')

# (remote URL, local path) pairs for the example images offered in the UI.
# The local paths must match the entries of ``examples`` below.
_EXAMPLE_IMAGES = (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    (
        "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
        "./dog.jpg",
    ),
)

# Fetch the example images on every start; any download error propagates and
# aborts start-up, exactly as with the two explicit calls this loop replaces.
# ``url`` and ``path_input`` stay bound to the last pair afterwards.
for url, path_input in _EXAMPLE_IMAGES:
    urllib.request.urlretrieve(url, filename=path_input)

# Disabled: model and Integrated Gradients explainer set-up.
# model = keras_model(weights="imagenet")
# n_steps = 50
# method = "gausslegendre"
# internal_batch_size = 50
# ig = IntegratedGradients(
#     model, n_steps=n_steps, method=method, internal_batch_size=internal_batch_size
# )


def do_process(img):
    """Return the uploaded image unchanged.

    Args:
        img: The image handed over by the Gradio input component. The
            component is declared with ``type="pil"``, so this is expected to
            be a PIL image (TODO confirm against the installed Gradio version).

    Returns:
        The same ``img`` object that was passed in.
    """
    return img
    # Disabled Integrated Gradients pipeline, kept for reference. It relies on
    # names that are not defined in this file (``image``, ``preprocess_input``,
    # ``decode_predictions``, ``visualize_image_attr``, ``baseline``, ...).
    #
    # instance = image.img_to_array(img)
    # instance = np.expand_dims(instance, axis=0)
    # instance = preprocess_input(instance)
    # preds = model.predict(instance)
    # lstPreds = decode_predictions(preds, top=3)[0]
    # dctPreds = {
    #     lstPreds[i][1]: round(float(lstPreds[i][2]), 2) for i in range(len(lstPreds))
    # }
    # predictions = preds.argmax(axis=1)
    # if baseline == "white":
    #     baselines = bls = np.ones(instance.shape).astype(instance.dtype)
    #     img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
    # elif baseline == "black":
    #     baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
    #     img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
    # elif baseline == "blur":
    #     img_flt = img.filter(ImageFilter.GaussianBlur(5))
    #     baselines = image.img_to_array(img_flt)
    #     baselines = np.expand_dims(baselines, axis=0)
    #     baselines = preprocess_input(baselines)
    # else:
    #     baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
    #     img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
    # explanation = ig.explain(instance, baselines=baselines, target=predictions)
    # attrs = explanation.attributions[0]
    # fig, ax = visualize_image_attr(
    #     attr=attrs.squeeze(),
    #     original_image=img,
    #     method="blended_heat_map",
    #     sign="all",
    #     show_colorbar=True,
    #     title=baseline,
    #     plt_fig_axis=None,
    #     use_pyplot=False,
    # )
    # fig.tight_layout()
    # buf = io.BytesIO()
    # fig.savefig(buf)
    # buf.seek(0)
    # img_res = Image.open(buf)
    # return img_res, img_flt, dctPreds


# Input widget: an uploaded RGB image, requested as a 224x224 PIL image.
input_im = gr.inputs.Image(
    shape=(224, 224), image_mode="RGB", invert_colors=False, source="upload", type="pil"
)
# Disabled: baseline selector for the Integrated Gradients pipeline.
# input_drop = gr.inputs.Dropdown(
#     label="Baseline (default: random)",
#     choices=["random", "black", "white", "blur"],
#     default="random",
#     type="value",
# )

# Output widget: a single PIL image (currently the echoed input).
output_img = gr.outputs.Image(label="Output of Integrated Gradients", type="pil")
# Disabled: extra outputs of the Integrated Gradients pipeline.
# output_base = gr.outputs.Image(label="Baseline image", type="pil")
# output_label = gr.outputs.Label(label="Classification results", num_top_classes=3)

title = "XAI - Integrated gradients"
description = "Playground: Integrated gradients for a ResNet model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
examples = [["./cat.jpg"], ["./dog.jpg"]]
# Footer text; the surrounding blank lines are written as explicit "\n"
# escapes because a single-quoted literal cannot span physical lines.
article = "\n\nBy Dr. Mohamed Elawady\n\n"

iface = gr.Interface(
    fn=do_process,
    inputs=[input_im],
    outputs=[output_img],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# debug=True is passed through to Gradio's launcher unchanged.
iface.launch(debug=True)