fix
Files changed:
- app.py (+12 -5)
- js/interactive_grid.js (+10 -2)
app.py (CHANGED)

@@ -38,6 +38,9 @@ def parse_bool_string(s):
 
 def get_intervention_vector(selected_cells_bef, selected_cells_aft):
 
+    first_ = True
+    second_ = True
+
     left_map = np.zeros((1, 14 * 14 + 1))
     right_map = np.zeros((1, 14 * 14 + 1))
 
@@ -47,12 +50,14 @@ def get_intervention_vector(selected_cells_bef, selected_cells_aft):
 
     if np.count_nonzero(selected_cells_bef) == 0:
         left_map[0, 0] = 1.0
+        first_ = False
 
     if np.count_nonzero(selected_cells_aft) == 0:
         right_map[0, 0] = 1.0
+        second_ = False
 
 
-    return left_map, right_map
+    return left_map, right_map, first_, second_
 
 def _get_rawimage(image_path):
     # Pair x L x T x 3 x H x W
@@ -103,7 +108,7 @@ def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
     selected_cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
     selected_cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)
 
-    left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
+    left_map, right_map, first_, second_ = get_intervention_vector(selected_cells_bef, selected_cells_aft)
 
     left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
 
@@ -130,8 +135,10 @@ def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
     pred = f"{decode_text}"
 
     # Include information about selected cells
-    …
-    …
+    i, j = np.nonzero(selected_cells_bef)
+    selected_info_bef = f"{list(zip(i, j))}" if first_ else "No image patch was selected"
+    i, j = np.nonzero(selected_cells_aft)
+    selected_info_aft = f"{list(zip(i, j))}" if second_ else "No image patch was selected"
 
     return pred, selected_info_bef, selected_info_aft
 
@@ -289,7 +296,7 @@ with gr.Blocks() as demo:
         fn=predict_image,
         inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
         outputs=[prediction, selected_info_bef, selected_info_aft],
-        _js="(sel_attn_bef, sel_attn_aft) => { return […
+        _js="(image_bef, image_aft, sel_attn_bef, sel_attn_aft) => { return [image_bef, image_aft, read_js_Data_bef(), read_js_Data_aft()]; }"
     )
 
     image_bef.change(
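For reference, a minimal sketch of what the app.py change does: get_intervention_vector now also returns two flags (first_, second_) saying whether any patch was selected on the before/after image, and predict_image uses them to either list the selected (row, column) pairs or report that nothing was selected. The 14x14 grid shape and the example selection below are assumptions for illustration only.

import numpy as np

# Assumed setup: each image has a 14 x 14 patch grid; nothing is selected on
# the "before" image and exactly one patch is selected on the "after" image.
selected_cells_bef = np.zeros((14, 14), np.int32)
selected_cells_aft = np.zeros((14, 14), np.int32)
selected_cells_aft[3, 7] = 1

# Mirrors the new flags: False means "no patch selected" for that image.
first_ = np.count_nonzero(selected_cells_bef) != 0
second_ = np.count_nonzero(selected_cells_aft) != 0

# Mirrors the new info strings built in predict_image.
i, j = np.nonzero(selected_cells_bef)
selected_info_bef = f"{list(zip(i, j))}" if first_ else "No image patch was selected"
i, j = np.nonzero(selected_cells_aft)
selected_info_aft = f"{list(zip(i, j))}" if second_ else "No image patch was selected"

print(selected_info_bef)  # No image patch was selected
print(selected_info_aft)  # one (row, col) pair, e.g. [(3, 7)] (exact repr depends on the numpy version)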
js/interactive_grid.js (CHANGED)

@@ -298,15 +298,23 @@ function importBackgroundAfter(image_after) {
     }
 }
 
-function …
+function read_js_Data_bef() {
     console.log("read_js_Data");
     console.log("read_js_Data");
     console.log("read_js_Data");
     console.log("read_js_Data");
     console.log("read_js_Data");
-    return grid_bef
+    return grid_bef;
 }
 
+function read_js_Data_aft() {
+    console.log("read_js_Data");
+    console.log("read_js_Data");
+    console.log("read_js_Data");
+    console.log("read_js_Data");
+    console.log("read_js_Data");
+    return grid_aft;
+}
 
 function set_grid_from_data(data) {
     if (data.length !== gridSize || data[0].length !== gridSize) {
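A hedged sketch of how these two helpers are consumed from the Python side: the _js hook on the Gradio event (the name used in Gradio 3.x, which this Space appears to target) runs in the browser before predict_image and swaps the hidden textbox inputs for the grids returned by read_js_Data_bef() and read_js_Data_aft(). The component and function names are taken from the diff above; the run_btn trigger and the stub predict_image are assumptions.

import gradio as gr

def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
    # Stub standing in for the real model call in app.py.
    return "prediction", str(json_data_bef), str(json_data_aft)

with gr.Blocks() as demo:
    image_bef, image_aft = gr.Image(), gr.Image()
    # Hidden carriers for the grid state; the _js hook overwrites their values.
    sel_attn_bef, sel_attn_aft = gr.Textbox(visible=False), gr.Textbox(visible=False)
    prediction = gr.Textbox(label="Prediction")
    selected_info_bef = gr.Textbox(label="Selected patches (before)")
    selected_info_aft = gr.Textbox(label="Selected patches (after)")
    run_btn = gr.Button("Run")  # hypothetical trigger; the diff does not show which event fires the call

    run_btn.click(
        fn=predict_image,
        inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
        outputs=[prediction, selected_info_bef, selected_info_aft],
        # Runs client-side first: keeps both images and replaces the textbox
        # values with the grids read by the new JS helpers.
        _js="(image_bef, image_aft, sel_attn_bef, sel_attn_aft) => "
            "{ return [image_bef, image_aft, read_js_Data_bef(), read_js_Data_aft()]; }",
    )

demo.launch()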