import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
# import torch
# from torchvision.transforms import Compose, Resize, ToTensor, Normalize
# from utils.model import init_model
# from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer
from fastapi.staticfiles import StaticFiles

from fileservice import app

# Raw HTML passed to gr.HTML in the UI further down this file.
html_text = """
"""


# def image_to_tensor(image_path):
#     image = Image.open(image_path).convert('RGB')
#     preprocess = Compose([
#         Resize([224, 224], interpolation=Image.BICUBIC),
#         lambda image: image.convert("RGB"),
#         ToTensor(),
#         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
#     ])
#     image_data = preprocess(image)
#     return {'image': image_data}


# def get_image_data(image_path):
#     image_input = image_to_tensor(image_path)
#     return image_input


def _cells_to_token_map(cells, grid_size):
    """Return a (1, grid_size**2 + 1) float64 map with 1.0 at every selected cell.

    Cell ``(i, j)`` is written to ``grid[i, j]`` of a ``grid_size`` x ``grid_size``
    grid, which is flattened row-major into positions ``1 .. grid_size**2``, so the
    cell lands at flat position ``1 + i * grid_size + j``. Position 0 is always 0.0.
    ``None`` is treated as an empty selection.
    """
    grid = np.zeros((grid_size, grid_size))
    if cells is not None:
        for i, j in cells:
            grid[i, j] = 1.0
    token_map = np.zeros((1, grid_size * grid_size + 1))
    token_map[0, 1:] = grid.reshape(-1)
    return token_map


def get_intervention_vector(selected_cells_bef, selected_cells_aft, grid_size=14):
    """Build binary patch-selection maps for the before and after images.

    Args:
        selected_cells_bef: iterable of integer ``(i, j)`` cell pairs selected on
            the before image (``None`` is treated as no selection).
        selected_cells_aft: same, for the after image.
        grid_size: cells per image side; defaults to 14, the value previously
            hard-coded here.

    Returns:
        ``(left_map, right_map)``, each a float64 array of shape
        ``(1, grid_size**2 + 1)``. Position 0 is 0.0, presumably reserved for a
        class/CLS token (TODO confirm against the model). Position
        ``1 + i * grid_size + j`` is 1.0 iff cell ``(i, j)`` was selected.

    Raises:
        IndexError: if a cell index is ``>= grid_size``. Negative indices wrap
            around, as with any NumPy indexing.
    """
    # NOTE(review): handle_click in this file stores cells as
    # (x_index, y_index), so the flat position is x * grid_size + y. If the
    # model orders patch tokens row by row (y * grid_size + x), these maps are
    # transposed -- TODO verify against the model's patch ordering.
    # NOTE(review): an earlier revision explicitly set position 0 to 0.0 when a
    # selection was empty. That was a no-op, because position 0 is always 0.0.
    # If the intent was to fall back to position 0 (1.0) when nothing is
    # selected, that still needs implementing -- TODO confirm.
    left_map = _cells_to_token_map(selected_cells_bef, grid_size)
    right_map = _cells_to_token_map(selected_cells_aft, grid_size)
    return left_map, right_map


# def _get_rawimage(image_path):
#     # Pair x L x T x 3 x H x W
#     # NOTE(review): np.float was removed in NumPy 1.24; use float / np.float64
#     # if this is re-enabled.
#     image = np.zeros((1, 3, 224,
#                       224), dtype=np.float)
#     for i in range(1):
#         raw_image_data = get_image_data(image_path)
#         raw_image_data = raw_image_data['image']
#         image[i] = raw_image_data
#     return image


# def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
#     visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
#                                                                           gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())
#     video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
#     input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
#     input_caption_ids = input_caption_ids.long().unsqueeze(1)
#     decoder_mask = torch.ones_like(input_caption_ids)
#     for i in range(32):
#         decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
#         next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
#         input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
#         next_mask = torch.ones_like(next_words)
#         decoder_mask = torch.cat([decoder_mask, next_mask], 1)
#     return input_caption_ids[:, 1:].tolist(), left_map, right_map


# Caption prediction with the model (disabled: the torch/model imports at the
# top of this file are commented out).
# def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
#     if image_bef is None:
#         return "No image provided", "", ""
#     if image_aft is None:
#         return "No image provided", "", ""
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     # NOTE(review): this loads the checkpoint on every call; consider loading
#     # it once at startup when re-enabling.
#     model = init_model('data/pytorch_model.pt', device)
#     tokenizer = ClipTokenizer()
#     left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
#     left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
#     bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
#     aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)
#     image_pair = torch.cat([bef_image, aft_image], 1)
#     # NOTE(review): np.long was removed in NumPy 1.24 (NumPy 2.x re-adds it as
#     # the C long type); np.int64 is the portable spelling.
#     image_mask = torch.from_numpy(np.ones(2, dtype=np.long)).unsqueeze(0)
#     result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)
#     decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
#     if "<|endoftext|>" in decode_text_list:
#         SEP_index = decode_text_list.index("<|endoftext|>")
#         decode_text_list = decode_text_list[:SEP_index]
#     if "!" in decode_text_list:
#         PAD_index = decode_text_list.index("!")
#         decode_text_list = decode_text_list[:PAD_index]
#     # NOTE(review): if convert_ids_to_tokens returns a list (as the variable
#     # name suggests), .strip() below will fail; join the tokens into a string
#     # first when re-enabling -- TODO confirm against the tokenizer.
#     decode_text = decode_text_list.strip()
#     # Format the decoded caption for display
#     pred = f"{decode_text}"
#     # Include information about selected cells
#     selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
#     selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"
#     return pred, selected_info_bef, selected_info_aft


# Number of grid cells per image side used by the overlay and click handling.
GRID_SIZE = 14


def add_grid_to_image(image_path, grid_size=GRID_SIZE):
    """Load an image and draw a ``grid_size`` x ``grid_size`` white grid on it.

    Args:
        image_path: path to the image file, or None.
        grid_size: number of cells per side.

    Returns:
        ``(rgba_image, height, width)``, or ``None`` when ``image_path`` is None.
    """
    if image_path is None:
        return None
    # Open inside a context manager so the file handle is released; convert()
    # returns a new, fully loaded image that outlives the handle.
    with Image.open(image_path) as source:
        image = source.convert('RGBA')
    w, h = image.size
    draw = ImageDraw.Draw(image)

    x_positions = np.linspace(0, w, grid_size + 1)
    y_positions = np.linspace(0, h, grid_size + 1)
    # Only interior lines are drawn; the image edge serves as the outer border.
    for x in x_positions[1:-1]:
        draw.line(((x, 0), (x, h)), fill='white')
    for y in y_positions[1:-1]:
        draw.line(((0, y), (w, y)), fill='white')
    return image, h, w


def _cell_index(coord, extent, grid_size):
    """Return the grid cell (0 .. grid_size - 1) containing pixel ``coord`` on one axis.

    ``[0, extent]`` is split into ``grid_size`` equal spans. A coordinate lying
    exactly on an interior boundary belongs to the later cell. Coordinates
    outside the image are clamped to the nearest edge cell instead of failing.
    """
    if extent <= 0:
        return 0
    index = int(coord * grid_size // extent)
    return min(max(index, 0), grid_size - 1)


def handle_click(image, evt: gr.SelectData, selected_cells, image_path):
    """Toggle the grid cell under a click and return the re-highlighted image.

    The overlay is redrawn from ``image_path`` every time rather than on top of
    ``image``, so a deselected cell loses its highlight.

    Args:
        image: current content of the grid display; only checked for None.
        evt: Gradio select event; ``evt.index`` is read as the clicked pixel (x, y).
        selected_cells: session list of selected ``(x_index, y_index)`` tuples.
        image_path: path of the original image the grid is drawn on.

    Returns:
        ``(highlighted_rgba_image, updated_selection)``, or ``(None, [])`` when
        ``image`` or ``image_path`` is missing (this clears the selection).
    """
    # Without a source path the grid cannot be redrawn: add_grid_to_image would
    # return None and the unpacking below would raise.
    if image is None or image_path is None:
        return None, []

    grid_size = GRID_SIZE
    image, h, w = add_grid_to_image(image_path, grid_size)
    x_positions = np.linspace(0, w, grid_size + 1)
    y_positions = np.linspace(0, h, grid_size + 1)

    # Cells are stored as (x_index, y_index), i.e. (column, row) in image terms.
    # get_intervention_vector indexes its grid in this order, so it is kept.
    cell_idx = (_cell_index(evt.index[0], w, grid_size),
                _cell_index(evt.index[1], h, grid_size))

    # Work on a copy so the incoming state object is not mutated in place.
    selected_cells = list(selected_cells or [])
    if cell_idx in selected_cells:
        selected_cells.remove(cell_idx)
    else:
        selected_cells.append(cell_idx)

    # Semi-transparent overlay: light green RGB (144, 238, 144) at alpha 102 (40% of 255).
    light_green = (144, 238, 144, 102)
    highlight_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))  # fully transparent
    highlight_draw = ImageDraw.Draw(highlight_layer)
    for x_idx, y_idx in selected_cells:
        highlight_draw.rectangle(
            [(x_positions[x_idx], y_positions[y_idx]),
             (x_positions[x_idx + 1], y_positions[y_idx + 1])],
            fill=light_green,
            outline='white',
        )

    # add_grid_to_image already returns RGBA, as alpha_composite requires.
    result_img = Image.alpha_composite(image, highlight_layer)
    return result_img, selected_cells


def _grid_image_or_none(image_path):
    """Return the grid-overlaid image for ``image_path``, or None when there is no image."""
    gridded = add_grid_to_image(image_path, GRID_SIZE)
    return None if gridded is None else gridded[0]


def process_example(image_path_bef, image_path_aft):
    """Return grid-overlaid versions of a before/after example pair (None for a missing image)."""
    return _grid_image_or_none(image_path_bef), _grid_image_or_none(image_path_aft)


def display_image(image_path):
    """Show the grid overlay for a newly loaded image and reset its cell selection."""
    return _grid_image_or_none(image_path), []


with gr.Blocks() as demo:
    gr.Markdown("# TAB: Transformer Attention Bottleneck")

    # Instructions
    gr.Markdown("""
## Instructions:
1. Upload an image or select one from the examples
2. Click on grid cells to select/deselect them
3. Click the 'Predict' button to get model predictions
""")

    selected_cells_bef = gr.State([])
    selected_cells_aft = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            # Input components with grid overlay
            image_bef = gr.Image(type="filepath", visible=True)
            image_aft = gr.Image(type="filepath", visible=True)
            predict_btn = gr.Button("Predict")

        with gr.Column(scale=1):
            image_display_with_grid_bef = gr.Image(type="pil", label="Before Image with Grid")
            image_display_with_grid_aft = gr.Image(type="pil", label="After Image with Grid")

            # Clicking a grid display toggles the cell under the cursor.
            image_display_with_grid_bef.select(
                handle_click,
                inputs=[image_display_with_grid_bef, selected_cells_bef, image_bef],
                outputs=[image_display_with_grid_bef, selected_cells_bef]
            )
            image_display_with_grid_aft.select(
                handle_click,
                inputs=[image_display_with_grid_aft, selected_cells_aft, image_aft],
                outputs=[image_display_with_grid_aft, selected_cells_aft]
            )

    with gr.Row():
        with gr.Column(scale=1):
            # Example images
            examples = gr.Examples(
                examples=[["data/images/CLEVR_default_000572.png", "data/images/CLEVR_semantic_000572.png"],
                          ["data/images/CLEVR_default_003339.png", "data/images/CLEVR_semantic_003339.png"]],
                inputs=[image_bef, image_aft],
                # outputs=[image_display_with_grid_bef, image_display_with_grid_aft],
                label="Example Images",
                # fn=process_example,
                examples_per_page=5
            )
            # NOTE(review): with the .change handlers below commented out, no
            # handler in this file fills the grid displays from image_bef /
            # image_aft, and the selection state is not reset when an image
            # changes. TODO confirm this is intentional.
            # image_bef.change(
            #     fn=display_image,
            #     inputs=[image_bef],
            #     outputs=[image_display_with_grid_bef, selected_cells_bef]
            # )
            # image_aft.change(
            #     fn=display_image,
            #     inputs=[image_aft],
            #     outputs=[image_display_with_grid_aft, selected_cells_aft]
            # )

        with gr.Column(scale=1):
            # Output components
            prediction = gr.Textbox(label="Predicted caption")
            selected_info_bef = gr.Textbox(label="Selected patches on before")
            selected_info_aft = gr.Textbox(label="Selected patches on after")

    html = gr.HTML(html_text)

    # Connect the predict button to the prediction function
    # predict_btn.click(
    #     fn=predict_image,
    #     inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
    #     outputs=[prediction, selected_info_bef, selected_info_aft]
    # )
    # image_bef.change(
    #     fn=None,
    #     inputs=[image_bef],
    #     outputs=[],
    #     js="(image) => { initializeEditor(); importBackground(image); return []; }",
    # )
    # image_aft.change(
    #     fn=None,
    #     inputs=[image_aft],
    #     outputs=[],
    #     js="(image) => { initializeEditor(); importBackground(image); return []; }",
    # )

# Serve the static JS directory before mounting the Gradio UI at the root path.
app.mount("/js", StaticFiles(directory="js"), name="js")
gr.mount_gradio_app(app, demo, path="/")