pooyanrg commited on
Commit
ae5e903
·
1 Parent(s): 7af7b4e
Files changed (2) hide show
  1. app.py +12 -5
  2. js/interactive_grid.js +10 -2
app.py CHANGED
@@ -38,6 +38,9 @@ def parse_bool_string(s):
38
 
39
  def get_intervention_vector(selected_cells_bef, selected_cells_aft):
40
 
 
 
 
41
  left_map = np.zeros((1, 14 * 14 + 1))
42
  right_map = np.zeros((1, 14 * 14 + 1))
43
 
@@ -47,12 +50,14 @@ def get_intervention_vector(selected_cells_bef, selected_cells_aft):
47
 
48
  if np.count_nonzero(selected_cells_bef) == 0:
49
  left_map[0, 0] = 1.0
 
50
 
51
  if np.count_nonzero(selected_cells_aft) == 0:
52
  right_map[0, 0] = 1.0
 
53
 
54
 
55
- return left_map, right_map
56
 
57
  def _get_rawimage(image_path):
58
  # Pair x L x T x 3 x H x W
@@ -103,7 +108,7 @@ def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
103
  selected_cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
104
  selected_cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)
105
 
106
- left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
107
 
108
  left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
109
 
@@ -130,8 +135,10 @@ def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
130
  pred = f"{decode_text}"
131
 
132
  # Include information about selected cells
133
- selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
134
- selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"
 
 
135
 
136
  return pred, selected_info_bef, selected_info_aft
137
 
@@ -289,7 +296,7 @@ with gr.Blocks() as demo:
289
  fn=predict_image,
290
  inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
291
  outputs=[prediction, selected_info_bef, selected_info_aft],
292
- _js="(sel_attn_bef, sel_attn_aft) => { return [read_js_Data()]; }"
293
  )
294
 
295
  image_bef.change(
 
38
 
39
  def get_intervention_vector(selected_cells_bef, selected_cells_aft):
40
 
41
+ first_ = True
42
+ second_ = True
43
+
44
  left_map = np.zeros((1, 14 * 14 + 1))
45
  right_map = np.zeros((1, 14 * 14 + 1))
46
 
 
50
 
51
  if np.count_nonzero(selected_cells_bef) == 0:
52
  left_map[0, 0] = 1.0
53
+ first_ = False
54
 
55
  if np.count_nonzero(selected_cells_aft) == 0:
56
  right_map[0, 0] = 1.0
57
+ second_ = False
58
 
59
 
60
+ return left_map, right_map, first_, second_
61
 
62
  def _get_rawimage(image_path):
63
  # Pair x L x T x 3 x H x W
 
108
  selected_cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
109
  selected_cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)
110
 
111
+ left_map, right_map, first_, second_ = get_intervention_vector(selected_cells_bef, selected_cells_aft)
112
 
113
  left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
114
 
 
135
  pred = f"{decode_text}"
136
 
137
  # Include information about selected cells
138
+ i, j = np.nonzero(selected_cells_bef)
139
+ selected_info_bef = f"{list(zip(i, j))}" if first_ else "No image patch was selected"
140
+ i, j = np.nonzero(selected_cells_aft)
141
+ selected_info_aft = f"{list(zip(i, j))}" if second_ else "No image patch was selected"
142
 
143
  return pred, selected_info_bef, selected_info_aft
144
 
 
296
  fn=predict_image,
297
  inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
298
  outputs=[prediction, selected_info_bef, selected_info_aft],
299
+ _js="(image_bef, image_aft, sel_attn_bef, sel_attn_aft) => { return [image_bef, image_aft, read_js_Data_bef(), read_js_Data_aft()]; }"
300
  )
301
 
302
  image_bef.change(
js/interactive_grid.js CHANGED
@@ -298,15 +298,23 @@ function importBackgroundAfter(image_after) {
298
  }
299
  }
300
 
301
- function read_js_Data() {
302
  console.log("read_js_Data");
303
  console.log("read_js_Data");
304
  console.log("read_js_Data");
305
  console.log("read_js_Data");
306
  console.log("read_js_Data");
307
- return grid_bef, grid_aft;
308
  }
309
 
 
 
 
 
 
 
 
 
310
 
311
  function set_grid_from_data(data) {
312
  if (data.length !== gridSize || data[0].length !== gridSize) {
 
298
  }
299
  }
300
 
301
+ function read_js_Data_bef() {
302
  console.log("read_js_Data");
303
  console.log("read_js_Data");
304
  console.log("read_js_Data");
305
  console.log("read_js_Data");
306
  console.log("read_js_Data");
307
+ return grid_bef;
308
  }
309
 
310
+ function read_js_Data_aft() {
311
+ console.log("read_js_Data");
312
+ console.log("read_js_Data");
313
+ console.log("read_js_Data");
314
+ console.log("read_js_Data");
315
+ console.log("read_js_Data");
316
+ return grid_aft;
317
+ }
318
 
319
  function set_grid_from_data(data) {
320
  if (data.length !== gridSize || data[0].length !== gridSize) {