fix
- app.py +26 -20
- js/interactive_grid.js +1 -1
app.py
CHANGED
@@ -1,4 +1,5 @@
 import gradio as gr
+import ast
 import numpy as np
 from PIL import Image, ImageDraw
 import torch
@@ -26,28 +27,29 @@ def get_image_data(image_path):
     image_input = image_to_tensor(image_path)
     return image_input
 
-def
-
-
-
-
-
-
-
+def parse_bool_string(s):
+    try:
+        bool_list = ast.literal_eval(s)
+        if not isinstance(bool_list, list):
+            raise ValueError("The input string must represent a list.")
+        return bool_list
+    except (SyntaxError, ValueError) as e:
+        raise ValueError(f"Invalid input string: {e}")
 
+def get_intervention_vector(selected_cells_bef, selected_cells_aft):
 
     left_map = np.zeros((1, 14 * 14 + 1))
     right_map = np.zeros((1, 14 * 14 + 1))
 
-    left_map[0, 1:] = np.reshape(
-    right_map[0, 1:] = np.reshape(
+    left_map[0, 1:] = np.reshape(selected_cells_bef, (1, 14 * 14))
+    right_map[0, 1:] = np.reshape(selected_cells_aft, (1, 14 * 14))
 
 
-    if
-        left_map[0, 0] =
+    if np.count_nonzero(selected_cells_bef) == 0:
+        left_map[0, 0] = 1.0
 
-    if
-        right_map[0, 0] =
+    if np.count_nonzero(selected_cells_aft) == 0:
+        right_map[0, 0] = 1.0
 
 
     return left_map, right_map
@@ -85,7 +87,7 @@ def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map
     return input_caption_ids[:, 1:].tolist(), left_map, right_map
 
 # Dummy prediction function
-def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
+def predict_image(image_bef, image_aft, json_data_bef, json_data_aft):
     if image_bef is None:
         return "No image provided", "", ""
     if image_aft is None:
@@ -98,6 +100,9 @@ def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
 
     tokenizer = ClipTokenizer()
 
+    selected_cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
+    selected_cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)
+
     left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
 
     left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
@@ -228,12 +233,12 @@ with gr.Blocks() as demo:
     3. Click the 'Predict' button to get model predictions
     """)
 
-    selected_cells_bef = gr.State([])
-    selected_cells_aft = gr.State([])
-
     height = gr.State(value=320)
     width = gr.State(value=480)
 
+    sel_attn_bef = gr.Textbox("", visible=False)
+    sel_attn_aft = gr.Textbox("", visible=False)
+
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -282,8 +287,9 @@
     # Connect the predict button to the prediction function
     predict_btn.click(
        fn=predict_image,
-        inputs=[image_bef, image_aft,
-        outputs=[prediction, selected_info_bef, selected_info_aft]
+        inputs=[image_bef, image_aft, sel_attn_bef, sel_attn_aft],
+        outputs=[prediction, selected_info_bef, selected_info_aft],
+        _js="(sel_attn_bef, sel_attn_aft) => { return [read_js_Data()]; }"
     )
 
     image_bef.change(
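For reference, a minimal standalone sketch of the selection-parsing path this commit introduces: the two helpers are copied from the diff above, while the example grid strings (a 14x14 boolean list serialized with str(), standing in for whatever the hidden sel_attn_bef / sel_attn_aft textboxes would actually carry) are hypothetical.

import ast
import numpy as np

# Copied from the diff above: turn the hidden-textbox payload back into a Python list.
def parse_bool_string(s):
    try:
        bool_list = ast.literal_eval(s)
        if not isinstance(bool_list, list):
            raise ValueError("The input string must represent a list.")
        return bool_list
    except (SyntaxError, ValueError) as e:
        raise ValueError(f"Invalid input string: {e}")

# Copied from the diff above: slots 1..196 hold the 14x14 grid, slot 0 is set only when no cell is selected.
def get_intervention_vector(selected_cells_bef, selected_cells_aft):
    left_map = np.zeros((1, 14 * 14 + 1))
    right_map = np.zeros((1, 14 * 14 + 1))
    left_map[0, 1:] = np.reshape(selected_cells_bef, (1, 14 * 14))
    right_map[0, 1:] = np.reshape(selected_cells_aft, (1, 14 * 14))
    if np.count_nonzero(selected_cells_bef) == 0:
        left_map[0, 0] = 1.0
    if np.count_nonzero(selected_cells_aft) == 0:
        right_map[0, 0] = 1.0
    return left_map, right_map

# Hypothetical payloads: "before" has one cell selected, "after" has none.
grid_bef = [False] * (14 * 14)
grid_bef[0] = True
json_data_bef = str(grid_bef)               # stand-in for sel_attn_bef's value
json_data_aft = str([False] * (14 * 14))    # stand-in for sel_attn_aft's value

# Same conversion as in predict_image.
cells_bef = np.asarray(parse_bool_string(json_data_bef), np.int32)
cells_aft = np.asarray(parse_bool_string(json_data_aft), np.int32)
left_map, right_map = get_intervention_vector(cells_bef, cells_aft)

print(left_map[0, :3])   # [0. 1. 0.] -> the selected patch marks its own slot
print(right_map[0, :3])  # [1. 0. 0.] -> an empty selection sets slot 0 to 1.0

The printed slices show the convention the maps encode: a selected patch sets its corresponding slot, while an empty selection falls back to marking slot 0.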
js/interactive_grid.js
CHANGED
@@ -304,7 +304,7 @@ function read_js_Data() {
     console.log("read_js_Data");
     console.log("read_js_Data");
     console.log("read_js_Data");
-    return
+    return grid_bef, grid_aft;
 }
 
 