Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Update app.py
Browse fileschanged strategy descriptor to download
app.py
CHANGED
@@ -214,26 +214,58 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
|
|
214 |
|
215 |
print(f"CSV file '{filename}' created successfully.")
|
216 |
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
-
# Filter the data
|
223 |
-
filtered_data = selected_test_info.iloc[matching_indices]
|
224 |
-
# new data contains etalon instead of 0/1 for ER/ME
|
225 |
-
filtered_data = filtered_data[filtered_data[8] == task_type] # Ensure test_info[6] matches
|
226 |
-
|
227 |
-
# Define filename dynamically
|
228 |
-
task_type_map = {0: "ER", 1: "ME"}
|
229 |
-
label_map = {0: "unsuccessful", 1: "successful"}
|
230 |
-
|
231 |
-
filename = f"{task_type_map[task_type]}-{label_map[label]}-strategies.csv"
|
232 |
-
|
233 |
-
|
234 |
-
# Write to CSV
|
235 |
-
process_and_write_csv(filtered_data, filename)
|
236 |
-
|
237 |
with open("fileHandler/roc_data2.pkl", 'rb') as file:
|
238 |
data = pickle.load(file)
|
239 |
t_label=data[0]
|
@@ -539,7 +571,8 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
|
|
539 |
textinfo='percent+label',
|
540 |
textposition='auto',
|
541 |
marker=dict(colors=colors),
|
542 |
-
sort=False
|
|
|
543 |
|
544 |
)])
|
545 |
|
@@ -577,7 +610,8 @@ def process_file(model_name,inc_slider,progress=Progress(track_tqdm=True)):
|
|
577 |
textinfo='percent+label',
|
578 |
textposition='auto',
|
579 |
marker=dict(colors=colors),
|
580 |
-
sort=False
|
|
|
581 |
# pull=[0, 0.2, 0, 0] # for pulling part of pie chart out (depends on position)
|
582 |
|
583 |
)])
|
@@ -1142,31 +1176,82 @@ button, select, .slider-percentage {
|
|
1142 |
margin-bottom: 1rem !important;
|
1143 |
text-align: center !important;
|
1144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1145 |
|
|
|
|
|
|
|
|
|
|
|
1146 |
}
|
1147 |
|
1148 |
|
1149 |
'''
|
|
|
1150 |
# Define the file directory
|
1151 |
FILE_DIR = "fileHandler"
|
1152 |
|
1153 |
# Function to get list of files
|
1154 |
def list_files():
|
1155 |
-
return ['Unsuccessful Strategies (ER)', 'Successful Strategies (ER)', 'Unsuccessful Strategies (ME)', 'Successful Strategies (ME)']
|
1156 |
label_to_filename = {
|
1157 |
-
|
1158 |
-
'Successful Strategies (ER)': 'ER-successful
|
1159 |
-
'Unsuccessful Strategies (
|
1160 |
-
'Successful Strategies (ME)': 'ME-successful
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1161 |
}
|
|
|
1162 |
# Function to provide the selected file path
|
1163 |
-
def
|
1164 |
-
|
1165 |
-
|
1166 |
-
|
1167 |
-
|
1168 |
-
|
1169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1170 |
|
1171 |
|
1172 |
with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:
|
@@ -1205,31 +1290,53 @@ with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:
|
|
1205 |
opt2_pie = gr.Plot(label="ME")
|
1206 |
|
1207 |
with gr.Row():
|
1208 |
-
gr.
|
1209 |
-
|
1210 |
-
|
1211 |
-
|
1212 |
-
|
1213 |
-
|
1214 |
-
|
1215 |
-
|
1216 |
-
|
1217 |
-
|
1218 |
-
|
1219 |
-
|
1220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1221 |
|
1222 |
|
|
|
|
|
|
|
|
|
|
|
1223 |
|
1224 |
btn.click(
|
1225 |
fn=lambda model, increment: (
|
1226 |
*process_file(model, increment), # Unpack all outputs from process_file
|
1227 |
-
gr.update(value=None),
|
|
|
1228 |
None, # Clear file output
|
1229 |
gr.update(visible=False) # Hide visualize markdown
|
1230 |
),
|
1231 |
inputs=[model_dropdown, increment_slider],
|
1232 |
-
outputs=[output_text, plot_output, opt1_pie, opt2_pie,
|
1233 |
)
|
1234 |
|
1235 |
|
|
|
214 |
|
215 |
print(f"CSV file '{filename}' created successfully.")
|
216 |
|
217 |
+
task_type_map = {0: "ER", 1: "ME"}
|
218 |
+
label_map = {0: "unsuccessful", 1: "successful"}
|
219 |
+
|
220 |
+
# -------------------------------
|
221 |
+
# 1. Where tlb == plb
|
222 |
+
# -------------------------------
|
223 |
+
for label in [0, 1]:
|
224 |
+
# All strategies
|
225 |
+
matching_indices = [i for i in range(len(tlb)) if tlb[i] == plb[i] == label]
|
226 |
+
filtered_data = selected_test_info.iloc[matching_indices]
|
227 |
+
filename = f"allstrategies-match-{label_map[label]}.csv"
|
228 |
+
process_and_write_csv(filtered_data, filename)
|
229 |
+
|
230 |
+
# Per task type
|
231 |
+
for task_type in [0, 1]:
|
232 |
+
task_data = filtered_data[filtered_data[8] == task_type]
|
233 |
+
filename = f"{task_type_map[task_type]}-match-{label_map[label]}.csv"
|
234 |
+
process_and_write_csv(task_data, filename)
|
235 |
+
|
236 |
+
# -------------------------------
|
237 |
+
# 2. Where tlb only
|
238 |
+
# -------------------------------
|
239 |
+
for label in [0, 1]:
|
240 |
+
# All strategies
|
241 |
+
matching_indices = [i for i in range(len(tlb)) if tlb[i] == label]
|
242 |
+
filtered_data = selected_test_info.iloc[matching_indices]
|
243 |
+
filename = f"allstrategies-groundtruth-{label_map[label]}.csv"
|
244 |
+
process_and_write_csv(filtered_data, filename)
|
245 |
+
|
246 |
+
# Per task type
|
247 |
+
for task_type in [0, 1]:
|
248 |
+
task_data = filtered_data[filtered_data[8] == task_type]
|
249 |
+
filename = f"{task_type_map[task_type]}-groundtruth-{label_map[label]}.csv"
|
250 |
+
process_and_write_csv(task_data, filename)
|
251 |
+
|
252 |
+
# -------------------------------
|
253 |
+
# 3. All data by task type (no label filtering)
|
254 |
+
# -------------------------------
|
255 |
+
# ER
|
256 |
+
task_data = selected_test_info[selected_test_info[8] == 0]
|
257 |
+
filename = f"ER-all.csv"
|
258 |
+
process_and_write_csv(task_data, filename)
|
259 |
+
|
260 |
+
# ME
|
261 |
+
task_data = selected_test_info[selected_test_info[8] == 1]
|
262 |
+
filename = f"ME-all.csv"
|
263 |
+
process_and_write_csv(task_data, filename)
|
264 |
+
|
265 |
+
# All strategies
|
266 |
+
filename = "allstrategies-all.csv"
|
267 |
+
process_and_write_csv(selected_test_info, filename)
|
268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
with open("fileHandler/roc_data2.pkl", 'rb') as file:
|
270 |
data = pickle.load(file)
|
271 |
t_label=data[0]
|
|
|
571 |
textinfo='percent+label',
|
572 |
textposition='auto',
|
573 |
marker=dict(colors=colors),
|
574 |
+
sort=False,
|
575 |
+
hole=0.4
|
576 |
|
577 |
)])
|
578 |
|
|
|
610 |
textinfo='percent+label',
|
611 |
textposition='auto',
|
612 |
marker=dict(colors=colors),
|
613 |
+
sort=False,
|
614 |
+
hole=0.4
|
615 |
# pull=[0, 0.2, 0, 0] # for pulling part of pie chart out (depends on position)
|
616 |
|
617 |
)])
|
|
|
1176 |
margin-bottom: 1rem !important;
|
1177 |
text-align: center !important;
|
1178 |
|
1179 |
+
#file-box {
|
1180 |
+
border: 1px solid #ccc;
|
1181 |
+
border-radius: 6px;
|
1182 |
+
padding: 10px;
|
1183 |
+
margin-top: 12px;
|
1184 |
+
background-color: #f9f9f9;
|
1185 |
+
}
|
1186 |
|
1187 |
+
.file-download {
|
1188 |
+
margin-bottom: 5px !important;
|
1189 |
+
padding: 4px !important;
|
1190 |
+
height: 10px;
|
1191 |
+
}
|
1192 |
}
|
1193 |
|
1194 |
|
1195 |
'''
|
1196 |
+
|
1197 |
# Define the file directory
|
1198 |
FILE_DIR = "fileHandler"
|
1199 |
|
1200 |
# Function to get list of files
|
1201 |
def list_files():
|
1202 |
+
return ['Unsuccessful Strategies (ER)', 'Successful Strategies (ER)', 'Unsuccessful Strategies (ME)', 'Successful Strategies (ME)','Ground Truth Unsuccessful Strategies (ER)','Ground Truth Successful Strategies (ER)','Ground Truth Unsuccessful Strategies (ME)','Ground Truth Successful Strategies (ME)']
|
1203 |
label_to_filename = {
|
1204 |
+
# Predicted (tlb == plb)
|
1205 |
+
'Predicted Successful Strategies (ER)': 'ER-match-successful.csv',
|
1206 |
+
'Predicted Unsuccessful Strategies (ER)': 'ER-match-unsuccessful.csv',
|
1207 |
+
'Predicted Successful Strategies (ME)': 'ME-match-successful.csv',
|
1208 |
+
'Predicted Unsuccessful Strategies (ME)': 'ME-match-unsuccessful.csv',
|
1209 |
+
'Predicted Successful Strategies (All)': 'allstrategies-match-successful.csv',
|
1210 |
+
'Predicted Unsuccessful Strategies (All)': 'allstrategies-match-unsuccessful.csv',
|
1211 |
+
|
1212 |
+
# Ground Truth (tlb only)
|
1213 |
+
'Ground Truth Successful Strategies (ER)': 'ER-groundtruth-successful.csv',
|
1214 |
+
'Ground Truth Unsuccessful Strategies (ER)': 'ER-groundtruth-unsuccessful.csv',
|
1215 |
+
'Ground Truth Successful Strategies (ME)': 'ME-groundtruth-successful.csv',
|
1216 |
+
'Ground Truth Unsuccessful Strategies (ME)': 'ME-groundtruth-unsuccessful.csv',
|
1217 |
+
'Ground Truth Successful Strategies (All)': 'allstrategies-groundtruth-successful.csv',
|
1218 |
+
'Ground Truth Unsuccessful Strategies (All)': 'allstrategies-groundtruth-unsuccessful.csv',
|
1219 |
+
|
1220 |
+
# All data
|
1221 |
+
'All Strategies (ER)': 'ER-all.csv',
|
1222 |
+
'All Strategies (ME)': 'ME-all.csv',
|
1223 |
+
'All Strategies (All)': 'allstrategies-all.csv'
|
1224 |
}
|
1225 |
+
|
1226 |
# Function to provide the selected file path
|
1227 |
+
def provide_file_paths(task_type, source):
|
1228 |
+
if not task_type or not source:
|
1229 |
+
return None, None, gr.update(visible=False)
|
1230 |
+
|
1231 |
+
# Handle "All" case for combined strategies
|
1232 |
+
if source == "All":
|
1233 |
+
label_success = f"All Strategies ({task_type})"
|
1234 |
+
label_unsuccess = f"All Strategies ({task_type})"
|
1235 |
+
else:
|
1236 |
+
label_success = f"{source} Successful Strategies ({task_type})"
|
1237 |
+
label_unsuccess = f"{source} Unsuccessful Strategies ({task_type})"
|
1238 |
+
label_all=f"All Strategies ({task_type})"
|
1239 |
+
|
1240 |
+
file_success = label_to_filename.get(label_success)
|
1241 |
+
file_unsuccess = label_to_filename.get(label_unsuccess)
|
1242 |
+
file_all=label_to_filename.get(label_all)
|
1243 |
+
|
1244 |
+
file_success_path = f"{FILE_DIR}/{file_success}" if file_success else None
|
1245 |
+
file_unsuccess_path = f"{FILE_DIR}/{file_unsuccess}" if file_unsuccess else None
|
1246 |
+
file_all_path = f"{FILE_DIR}/{file_all}" if file_all else None
|
1247 |
+
|
1248 |
+
dynamic_text = "🔍 [Visualize the strategies](https://path-analysis.vercel.app/)"
|
1249 |
+
if file_success and file_unsuccess and file_all:
|
1250 |
+
return file_success_path, file_unsuccess_path,file_all_path, gr.update(value=dynamic_text, visible=True)
|
1251 |
+
|
1252 |
+
return None, None,None, gr.update(visible=False)
|
1253 |
+
|
1254 |
+
|
1255 |
|
1256 |
|
1257 |
with gr.Blocks(theme='gstaff/sketch', css=custom_css) as demo:
|
|
|
1290 |
opt2_pie = gr.Plot(label="ME")
|
1291 |
|
1292 |
with gr.Row():
|
1293 |
+
with gr.Column():
|
1294 |
+
# gr.Markdown("Select strategy filters and click Generate")
|
1295 |
+
task_type_radio = gr.Dropdown(
|
1296 |
+
choices=["ER", "ME", "All"],
|
1297 |
+
label="Filter by Problem Type",
|
1298 |
+
interactive=True
|
1299 |
+
)
|
1300 |
+
source_radio = gr.Checkbox(
|
1301 |
+
label="Predicted Labels",
|
1302 |
+
value=True
|
1303 |
+
)
|
1304 |
+
generate_button = gr.Button("Generate Strategies")
|
1305 |
+
|
1306 |
+
# with gr.Row():
|
1307 |
+
with gr.Column():
|
1308 |
+
with gr.Group(visible=False) as file_output_group:
|
1309 |
+
|
1310 |
+
gr.Markdown("**Download strategy descriptor files**")
|
1311 |
+
file_output_success = gr.File(label=" ")
|
1312 |
+
file_output_unsuccess = gr.File(label=" ")
|
1313 |
+
file_output_all = gr.File(label=" ")
|
1314 |
+
visualize_markdown = gr.Markdown(visible=False)
|
1315 |
+
|
1316 |
+
|
1317 |
+
def handle_generate(task_type_dropdown, use_predicted):
|
1318 |
+
label_source = "Predicted" if use_predicted else "Ground Truth"
|
1319 |
+
file_success_path, file_unsuccess_path,file_all_path, viz_link = provide_file_paths(task_type_dropdown, label_source)
|
1320 |
+
|
1321 |
+
return file_success_path, file_unsuccess_path,file_all_path, viz_link,gr.update(visible=True)
|
1322 |
|
1323 |
|
1324 |
+
generate_button.click(
|
1325 |
+
fn=handle_generate,
|
1326 |
+
inputs=[task_type_radio, source_radio],
|
1327 |
+
outputs=[file_output_success, file_output_unsuccess,file_output_all, visualize_markdown,file_output_group]
|
1328 |
+
)
|
1329 |
|
1330 |
btn.click(
|
1331 |
fn=lambda model, increment: (
|
1332 |
*process_file(model, increment), # Unpack all outputs from process_file
|
1333 |
+
gr.update(value=None), # update outcome_radio
|
1334 |
+
gr.update(value=None), # Reset dropdown to first item
|
1335 |
None, # Clear file output
|
1336 |
gr.update(visible=False) # Hide visualize markdown
|
1337 |
),
|
1338 |
inputs=[model_dropdown, increment_slider],
|
1339 |
+
outputs=[output_text, plot_output, opt1_pie, opt2_pie, task_type_radio,source_radio,file_output_success,file_output_unsuccess, visualize_markdown]
|
1340 |
)
|
1341 |
|
1342 |
|