pooyanrg's picture
fix
ad6a1d7
raw
history blame
11.4 kB
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
# import torch
# from torchvision.transforms import Compose, Resize, ToTensor, Normalize
# from utils.model import init_model
# from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer
from fastapi.staticfiles import StaticFiles
from fileservice import app
# Raw HTML injected into the page via gr.HTML: a 512x512 <canvas> plus a hidden
# background <img>.
# NOTE(review): nothing in this file draws on the canvas; the JS hooks that would
# (initializeEditor / importBackground) are in commented-out .change handlers.
# Presumably the scripts served from the /js static mount implement them -- confirm.
html_text = """
<div id="container">
<canvas id="canvas" width="512" height="512"></canvas><img id="canvas-background" style="display:none;"/>
</div>
"""
# def image_to_tensor(image_path):
# image = Image.open(image_path).convert('RGB')
# preprocess = Compose([
# Resize([224, 224], interpolation=Image.BICUBIC),
# lambda image: image.convert("RGB"),
# ToTensor(),
# Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
# ])
# image_data = preprocess(image)
# return {'image': image_data}
# def get_image_data(image_path):
# image_input = image_to_tensor(image_path)
# return image_input
def get_intervention_vector(selected_cells_bef, selected_cells_aft, grid_size=14):
    """Build flat patch-selection masks for the before and after images.

    Each selected ``(i, j)`` cell is set to 1.0 in a ``grid_size x grid_size``
    grid. The grid is flattened row-major (flat index ``i * grid_size + j``) and
    shifted right by one, so slot 0 of the result is left at 0.0 (presumably a
    CLS-token position -- TODO confirm against the model).

    Args:
        selected_cells_bef: iterable of ``(i, j)`` integer cell indices for the
            "before" image.
        selected_cells_aft: same, for the "after" image.
        grid_size: number of cells per side; defaults to 14 to match the UI grid.

    Returns:
        ``(left_map, right_map)``: float64 numpy arrays of shape
        ``(1, grid_size * grid_size + 1)``.
    """
    def _to_map(cells):
        # Mark the selected cells on a square grid, then lay it out after slot 0.
        grid = np.zeros((grid_size, grid_size))
        for i, j in cells:
            grid[i, j] = 1.0
        flat = np.zeros((1, grid_size * grid_size + 1))
        flat[0, 1:] = grid.reshape(-1)
        # Slot 0 is never written, so it stays 0.0 whether or not any cells
        # were selected.
        return flat

    return _to_map(selected_cells_bef), _to_map(selected_cells_aft)
# def _get_rawimage(image_path):
# # Pair x L x T x 3 x H x W
# image = np.zeros((1, 3, 224,
# 224), dtype=np.float)
# for i in range(1):
# raw_image_data = get_image_data(image_path)
# raw_image_data = raw_image_data['image']
# image[i] = raw_image_data
# return image
# def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
# visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
# gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())
# video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
# input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
# input_caption_ids = input_caption_ids.long().unsqueeze(1)
# decoder_mask = torch.ones_like(input_caption_ids)
# for i in range(32):
# decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
# next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
# input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
# next_mask = torch.ones_like(next_words)
# decoder_mask = torch.cat([decoder_mask, next_mask], 1)
# return input_caption_ids[:, 1:].tolist(), left_map, right_map
# Prediction function (currently disabled): loads the model and greedily decodes a caption for a before/after image pair
# def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
# if image_bef is None:
# return "No image provided", "", ""
# if image_aft is None:
# return "No image provided", "", ""
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = init_model('data/pytorch_model.pt', device)
# tokenizer = ClipTokenizer()
# left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
# left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
# bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
# aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)
# image_pair = torch.cat([bef_image, aft_image], 1)
# image_mask = torch.from_numpy(np.ones(2, dtype=np.long)).unsqueeze(0)
# result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)
# decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
# if "<|endoftext|>" in decode_text_list:
# SEP_index = decode_text_list.index("<|endoftext|>")
# decode_text_list = decode_text_list[:SEP_index]
# if "!" in decode_text_list:
# PAD_index = decode_text_list.index("!")
# decode_text_list = decode_text_list[:PAD_index]
# decode_text = decode_text_list.strip()
# # Generate dummy predictions
# pred = f"{decode_text}"
# # Include information about selected cells
# selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
# selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"
# return pred, selected_info_bef, selected_info_aft
# Add grid to the image
def add_grid_to_image(image_path, grid_size=14):
    """Open an image and draw a white ``grid_size x grid_size`` grid over it.

    Args:
        image_path: path to an image file readable by PIL, or None.
        grid_size: number of cells per side.

    Returns:
        ``(image, h, w)``: the gridded RGBA ``PIL.Image`` plus its height and
        width in pixels. When ``image_path`` is None, returns
        ``(None, None, None)`` so callers that unpack three values keep working.
    """
    if image_path is None:
        # Every caller unpacks three values; a bare None would raise TypeError.
        return None, None, None
    # Close the underlying file promptly; convert() returns an independent image.
    with Image.open(image_path) as source:
        image = source.convert('RGBA')
    w, h = image.size
    draw = ImageDraw.Draw(image)
    x_positions = np.linspace(0, w, grid_size + 1)
    y_positions = np.linspace(0, h, grid_size + 1)
    # Only interior lines are drawn; the outer border is the image edge itself.
    for x in x_positions[1:-1]:
        draw.line(((x, 0), (x, h)), fill='white')
    for y in y_positions[1:-1]:
        draw.line(((0, y), (w, y)), fill='white')
    return image, h, w
# Handle cell selection
def _cell_index(coord, edges):
    """Return the index of the grid cell containing ``coord``.

    ``edges`` holds the ``grid_size + 1`` cell boundaries. A coordinate lying
    exactly on an interior boundary belongs to the later cell. Coordinates
    outside ``[edges[0], edges[-1]]`` are clamped to the first or last cell
    instead of leaving the index undefined.
    """
    idx = int(np.searchsorted(edges, coord, side='right')) - 1
    return min(max(idx, 0), len(edges) - 2)


def handle_click(image, evt: gr.SelectData, selected_cells, image_path):
    """Toggle the grid cell under a click and redraw the selection overlay.

    Args:
        image: the currently displayed image; it is only used to detect the
            "nothing displayed" case. The drawing is rebuilt from ``image_path``.
        evt: Gradio select event; ``evt.index`` is the clicked ``(x, y)`` pixel.
        selected_cells: current list of selected ``(row, col)`` cells.
        image_path: path of the original image the grid is drawn over.

    Returns:
        ``(result_img, selected_cells)``: the gridded image with selected cells
        shaded light green, and the updated selection list. Returns
        ``(None, [])`` when there is no image to work with.
    """
    if image is None or image_path is None:
        return None, []
    grid_size = 14
    image, h, w = add_grid_to_image(image_path, grid_size)
    x_positions = np.linspace(0, w, grid_size + 1)
    y_positions = np.linspace(0, h, grid_size + 1)

    # Cells are (row, col): row from the vertical (y) coordinate and col from
    # the horizontal (x) one. get_intervention_vector flattens row-major, so
    # this yields patch index row * grid_size + col.
    cell_idx = (_cell_index(evt.index[1], y_positions),
                _cell_index(evt.index[0], x_positions))

    # Work on a copy instead of mutating the incoming state list in place.
    selected_cells = list(selected_cells)
    if cell_idx in selected_cells:
        selected_cells.remove(cell_idx)
    else:
        selected_cells.append(cell_idx)

    # Shade the selected cells on a transparent layer, then composite it.
    highlight_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))
    highlight_draw = ImageDraw.Draw(highlight_layer)
    light_green = (144, 238, 144, 102)  # alpha 102 is about 40% opacity
    for row, col in selected_cells:
        top_left = (x_positions[col], y_positions[row])
        bottom_right = (x_positions[col + 1], y_positions[row + 1])
        highlight_draw.rectangle([top_left, bottom_right], fill=light_green, outline='white')
    result_img = Image.alpha_composite(image.convert('RGBA'), highlight_layer)
    return result_img, selected_cells
# Process example images
def process_example(image_path_bef, image_path_aft):
    """Return gridded copies of a before/after example image pair.

    Either path may be None, in which case the corresponding output is None.
    Selected-cell state is not touched here.

    Returns:
        ``(image_bef_grid, image_aft_grid)``: gridded PIL images or None.
    """
    def _gridded(path):
        # Skip the grid helper entirely for a missing path.
        if path is None:
            return None
        grid_image, _, _ = add_grid_to_image(path, 14)
        return grid_image

    return _gridded(image_path_bef), _gridded(image_path_aft)
def display_image(image_path):
    """Return the gridded version of an image and an empty selection list.

    The second value is meant for a selected-cells ``gr.State`` output, so that
    loading a new image starts with no cells selected.

    Returns:
        ``(image_grid, [])``: ``image_grid`` is None when ``image_path`` is None
        (for example, when an upload is cleared).
    """
    if image_path is None:
        return None, []
    image_grid, _, _ = add_grid_to_image(image_path, 14)
    return image_grid, []
# Gradio UI: before/after file inputs, gridded display copies whose cells can be
# clicked to toggle selection, example pairs, and output text boxes.
with gr.Blocks() as demo:
    gr.Markdown("# TAB: Transformer Attention Bottleneck")
    # Instructions
    gr.Markdown("""
## Instructions:
1. Upload an image or select one from the examples
2. Click on grid cells to select/deselect them
3. Click the 'Predict' button to get model predictions
""")
    # Per-session lists of selected grid cells; updated by handle_click.
    selected_cells_bef = gr.State([])
    selected_cells_aft = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            # Raw inputs; handlers receive them as file paths (type="filepath").
            image_bef = gr.Image(type="filepath", visible=True)
            image_aft = gr.Image(type="filepath", visible=True)
            # NOTE(review): predict_btn has no click handler -- the
            # predict_btn.click wiring below is commented out, so this button
            # currently does nothing.
            predict_btn = gr.Button("Predict")
        with gr.Column(scale=1):
            # Gridded copies of the inputs; these are the click targets.
            # NOTE(review): the image_*.change -> display_image handlers below
            # are commented out, so nothing in this file populates these.
            image_display_with_grid_bef = gr.Image(type="pil", label="Before Image with Grid")
            image_display_with_grid_aft = gr.Image(type="pil", label="After Image with Grid")
    # Clicking a displayed image toggles the cell under the cursor. handle_click
    # redraws the grid from the original file path and returns the new selection.
    image_display_with_grid_bef.select(
        handle_click,
        inputs=[image_display_with_grid_bef, selected_cells_bef, image_bef],
        outputs=[image_display_with_grid_bef, selected_cells_bef]
    )
    image_display_with_grid_aft.select(
        handle_click,
        inputs=[image_display_with_grid_aft, selected_cells_aft, image_aft],
        outputs=[image_display_with_grid_aft, selected_cells_aft]
    )
    with gr.Row():
        with gr.Column(scale=1):
            # Example before/after pairs; selecting one fills image_bef/image_aft.
            examples = gr.Examples(
                examples=[["data/images/CLEVR_default_000572.png", "data/images/CLEVR_semantic_000572.png"],
                          ["data/images/CLEVR_default_003339.png", "data/images/CLEVR_semantic_003339.png"]],
                inputs=[image_bef, image_aft],
                # outputs=[image_display_with_grid_bef, image_display_with_grid_aft],
                label="Example Images",
                # fn=process_example,
                examples_per_page=5
            )
            # image_bef.change(
            # fn=display_image,
            # inputs=[image_bef],
            # outputs=[image_display_with_grid_bef, selected_cells_bef]
            # )
            # image_aft.change(
            # fn=display_image,
            # inputs=[image_aft],
            # outputs=[image_display_with_grid_aft, selected_cells_aft]
            # )
        with gr.Column(scale=1):
            # Output components
            prediction = gr.Textbox(label="Predicted caption")
            selected_info_bef = gr.Textbox(label="Selected patches on before")
            selected_info_aft = gr.Textbox(label="Selected patches on after")
            # Canvas markup from html_text (see its definition near the top).
            html = gr.HTML(html_text)
    # Connect the predict button to the prediction function (currently disabled)
    # predict_btn.click(
    # fn=predict_image,
    # inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
    # outputs=[prediction, selected_info_bef, selected_info_aft]
    # )
    # image_bef.change(
    # fn=None,
    # inputs=[image_bef],
    # outputs=[],
    # js="(image) => { initializeEditor(); importBackground(image); return []; }",
    # )
    # image_aft.change(
    # fn=None,
    # inputs=[image_aft],
    # outputs=[],
    # js="(image) => { initializeEditor(); importBackground(image); return []; }",
    # )
# Serve the local "js" directory as static files under /js on the FastAPI app
# imported from fileservice, then mount the Gradio UI at the root path.
# NOTE(review): keep the /js mount before the "/" mount -- Starlette matches
# routes in registration order, so a catch-all "/" registered first would shadow /js.
app.mount("/js", StaticFiles(directory="js"), name="js")
gr.mount_gradio_app(app, demo, path="/")