pooyanrg commited on
Commit
7af7b4e
·
1 Parent(s): bbb4a99
Files changed (2) hide show
  1. app.py +26 -20
  2. js/interactive_grid.js +1 -1
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  from PIL import Image, ImageDraw
4
  import torch
@@ -26,28 +27,29 @@ def get_image_data(image_path):
26
  image_input = image_to_tensor(image_path)
27
  return image_input
28
 
29
- def get_intervention_vector(selected_cells_bef, selected_cells_aft):
30
- left = np.reshape(np.zeros((1, 14 * 14)), (14, 14))
31
- right = np.reshape(np.zeros((1, 14 * 14)), (14, 14))
32
-
33
- for (i, j) in selected_cells_bef:
34
- left[i, j] = 1.
35
- for (i, j) in selected_cells_aft:
36
- right[i, j] = 1.
37
 
 
38
 
39
  left_map = np.zeros((1, 14 * 14 + 1))
40
  right_map = np.zeros((1, 14 * 14 + 1))
41
 
42
- left_map[0, 1:] = np.reshape(left, (1, 14 * 14))
43
- right_map[0, 1:] = np.reshape(right, (1, 14 * 14))
44
 
45
 
46
- if len(selected_cells_bef) == 0:
47
- left_map[0, 0] = 0.0
48
 
49
- if len(selected_cells_aft) == 0:
50
- right_map[0, 0] = 0.0
51
 
52
 
53
  return left_map, right_map
@@ -85,7 +87,7 @@ def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map
85
  return input_caption_ids[:, 1:].tolist(), left_map, right_map
86
 
87
  # Dummy prediction function
88
- def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
89
  if image_bef is None:
90
  return "No image provided", "", ""
91
  if image_aft is None:
@@ -98,6 +100,9 @@ def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
98
 
99
  tokenizer = ClipTokenizer()
100
 
 
 
 
101
  left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
102
 
103
  left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
@@ -228,12 +233,12 @@ with gr.Blocks() as demo:
228
  3. Click the 'Predict' button to get model predictions
229
  """)
230
 
231
- selected_cells_bef = gr.State([])
232
- selected_cells_aft = gr.State([])
233
-
234
  height = gr.State(value=320)
235
  width = gr.State(value=480)
236
 
 
 
 
237
 
238
  with gr.Row():
239
  with gr.Column(scale=1):
@@ -282,8 +287,9 @@ with gr.Blocks() as demo:
282
  # Connect the predict button to the prediction function
283
  predict_btn.click(
284
  fn=predict_image,
285
- inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
286
- outputs=[prediction, selected_info_bef, selected_info_aft]
 
287
  )
288
 
289
  image_bef.change(
 
1
  import gradio as gr
2
+ import ast
3
  import numpy as np
4
  from PIL import Image, ImageDraw
5
  import torch
 
27
  image_input = image_to_tensor(image_path)
28
  return image_input
29
 
30
+ def parse_bool_string(s):
31
+ try:
32
+ bool_list = ast.literal_eval(s)
33
+ if not isinstance(bool_list, list):
34
+ raise ValueError("The input string must represent a list.")
35
+ return bool_list
36
+ except (SyntaxError, ValueError) as e:
37
+ raise ValueError(f"Invalid input string: {e}")
38
 
39
+ def get_intervention_vector(selected_cells_bef, selected_cells_aft):
40
 
41
  left_map = np.zeros((1, 14 * 14 + 1))
42
  right_map = np.zeros((1, 14 * 14 + 1))
43
 
44
+ left_map[0, 1:] = np.reshape(selected_cells_bef, (1, 14 * 14))
45
+ right_map[0, 1:] = np.reshape(selected_cells_aft, (1, 14 * 14))
46
 
47
 
48
+ if np.count_nonzero(selected_cells_bef) == 0:
49
+ left_map[0, 0] = 1.0
50
 
51
+ if np.count_nonzero(selected_cells_aft) == 0:
52
+ right_map[0, 0] = 1.0
53
 
54
 
55
  return left_map, right_map
 
87
  return input_caption_ids[:, 1:].tolist(), left_map, right_map
88
 
89
  # Dummy prediction function
90
+ def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
91
  if image_bef is None:
92
  return "No image provided", "", ""
93
  if image_aft is None:
 
100
 
101
  tokenizer = ClipTokenizer()
102
 
103
+ selected_cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
104
+ selected_cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)
105
+
106
  left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
107
 
108
  left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
 
233
  3. Click the 'Predict' button to get model predictions
234
  """)
235
 
 
 
 
236
  height = gr.State(value=320)
237
  width = gr.State(value=480)
238
 
239
+ sel_attn_bef = gr.Textbox("", visible=False)
240
+ sel_attn_aft = gr.Textbox("", visible=False)
241
+
242
 
243
  with gr.Row():
244
  with gr.Column(scale=1):
 
287
  # Connect the predict button to the prediction function
288
  predict_btn.click(
289
  fn=predict_image,
290
+ inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
291
+ outputs=[prediction, selected_info_bef, selected_info_aft],
292
+ _js="(sel_attn_bef, sel_attn_aft) => { return [read_js_Data()]; }"
293
  )
294
 
295
  image_bef.change(
js/interactive_grid.js CHANGED
@@ -304,7 +304,7 @@ function read_js_Data() {
304
  console.log("read_js_Data");
305
  console.log("read_js_Data");
306
  console.log("read_js_Data");
307
- return grid;
308
  }
309
 
310
 
 
304
  console.log("read_js_Data");
305
  console.log("read_js_Data");
306
  console.log("read_js_Data");
307
+ return grid_bef, grid_aft;
308
  }
309
 
310