Spaces:
Runtime error
Runtime error
| import matplotlib as mpl | |
| mpl.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import matplotlib.patches as patches | |
| import numpy as np | |
| from PIL import Image | |
| from zipfile import ZipFile | |
| import gradio as gr | |
class SampleClass:
    """Browse per-sample predictions of two models ('baseline' and 'extended').

    Loads the test/val prediction tables (JSON) and the zipped SAIAPR TC-12
    image archive once, then serves one sample at a time per annotator. Each
    sample is returned as an image with the ground-truth and predicted boxes
    drawn on it, a textual summary, and the segmented image.
    """

    # Annotators among whom the filtered samples are split, in this order; each
    # one receives one contiguous chunk produced by np.array_split.
    ANNOTATORS = ('luciana', 'mauri', 'jorge', 'nano')

    def __init__(self,
                 test_path="data/full_pred_test_w_plurals_w_iou.json",
                 val_path="data/full_pred_val_w_plurals_w_iou.json",
                 zip_path="data/saiapr_tc-12.zip"):
        """Load both prediction tables and open the image archive.

        The defaults are the paths the app has always used; they are
        parameters only so the data can be stored elsewhere.
        """
        self.test_df = pd.read_json(test_path)
        self.val_df = pd.read_json(val_path)
        # Kept open for the object's lifetime; every image request reads from it.
        self.zip_file = ZipFile(zip_path, 'r')
        # Last selection computed by explorateSamples (kept for external readers).
        self.filtered_df = None

    def __get(self, img_path):
        """Return the image stored at `img_path` inside the archive, fully loaded.

        Image.open is lazy, so the pixels are loaded while the archive member
        is still open; the member is then closed instead of leaking a handle
        on every call.
        """
        with self.zip_file.open(img_path) as member:
            img = Image.open(member)
            img.load()
        return img

    def __loadPredictions(self, split, model):
        """Return the `split` table with `model`'s columns renamed to generic ones.

        The chosen model's '<model>_hit', '<model>_pred' and '<model>_iou'
        columns become 'hit', 'predictions' and 'iou'. The other model's become
        'hit_other', 'predictions_other' and 'iou_other'.

        Raises:
            ValueError: if `split` is not 'test'/'val' or `model` is not
                'baseline'/'extended'.
        """
        tables = {'test': self.test_df, 'val': self.val_df}
        if split not in tables:
            raise ValueError(f"Unknown split {split!r}; expected 'test' or 'val'")
        if model not in ('baseline', 'extended'):
            raise ValueError(
                f"Unknown model {model!r}; expected 'baseline' or 'extended'")
        other = 'extended' if model == 'baseline' else 'baseline'
        return tables[split].rename(columns={
            f'{model}_hit': 'hit',
            f'{model}_pred': 'predictions',
            f'{model}_iou': 'iou',
            f'{other}_hit': 'hit_other',
            f'{other}_pred': 'predictions_other',
            f'{other}_iou': 'iou_other',
        })

    @staticmethod
    def _render_boxes(image, boxes, title):
        """Draw `boxes` over `image` and return the rendering as (H, W, 3) uint8 RGB.

        Args:
            image: anything accepted by Axes.imshow (here a PIL image).
            boxes: iterable of ((x1, y1, x2, y2), edgecolor) pairs, given as
                corner coordinates.
            title: text shown above the image.
        """
        fig, ax = plt.subplots(1)
        try:
            ax.imshow(image, interpolation='bilinear')
            for (x1, y1, x2, y2), color in boxes:
                ax.add_patch(patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                               linewidth=2, edgecolor=color,
                                               facecolor='none'))
            ax.axis('off')
            ax.set_title(title, fontsize=12)
            fig.tight_layout()
            fig.canvas.draw()
            # buffer_rgba() replaces tostring_rgb(), which was deprecated in
            # Matplotlib 3.8 and later removed. Drop alpha to keep RGB output,
            # and copy so the array outlives the figure.
            return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        finally:
            # Release pyplot's reference to the figure even if drawing fails.
            plt.close(fig)

    def __getSample(self, id, df=None):
        """Render sample `id` and summarise it.

        Args:
            id: value of the 'sample_idx' column to show.
            df: table to look the sample up in; defaults to self.filtered_df.
                explorateSamples passes its own selection so that concurrent
                requests cannot swap the table between filtering and lookup.

        Returns:
            (info, img, img_seg):
                info: dict of display fields whose first key is 'Expresion'.
                img: (H, W, 3) uint8 RGB rendering with the ground-truth box
                    (blue), this model's prediction (lightgreen) and the other
                    model's prediction (red).
                img_seg: the segmented image read from the archive.

        Raises:
            IndexError: if no row of the table has sample_idx == id.
        """
        table = self.filtered_df if df is None else df
        sample = table[table.sample_idx == id]
        if sample.empty:
            raise IndexError(f"Sample {id} not found in the current selection")

        def first(col):
            # Value of column `col` in the first matching row.
            return sample[col].values[0]

        # Category labels equal the table's category column names.
        category_names = ('intrinsic', 'spatial', 'ordinal', 'relational', 'plural')
        active = [name for name in category_names if first(name)]
        outcome = {0: ' FAIL ', 1: ' CORRECT '}

        # The stored file_path is rewritten to be relative to the archive root.
        img_path = "saiapr_tc-12" + first('file_path').split("saiapr_tc-12")[1]
        img_seg_path = img_path.replace("images", "segmented_images")

        img = self._render_boxes(
            self.__get(img_path),
            [(first('bbox'), 'blue'),
             (first('predictions'), 'lightgreen'),
             (first('predictions_other'), 'red')],
            title=first('sent'),
        )

        info = {
            'Expresion': first('sent'),
            'Idx Sample': str(id),
            'IoU': str(round(first('iou'), 2)) + f"({outcome[first('hit')]})",
            'IoU other': (str(round(first('iou_other'), 2))
                          + f"({outcome[first('hit_other')]})"),
            'Pos Tags': str(first('pos_tags')),
            'PluralTks ': first('plural_tks'),
            'Categories': ",".join(active),
        }
        return info, img, self.__get(img_seg_path)

    @classmethod
    def _assign_samples(cls, sample_ids):
        """Split `sample_ids` into len(ANNOTATORS) contiguous chunks, one per annotator.

        Chunk sizes differ by at most one (np.array_split). Ids are returned
        as plain Python scalars (via tolist) rather than NumPy scalars, so the
        id handed back to the UI is a native int.
        """
        chunks = np.array_split(list(sample_ids), len(cls.ANNOTATORS))
        return {name: chunk.tolist() for name, chunk in zip(cls.ANNOTATORS, chunks)}

    @staticmethod
    def _navigate(user_samples, current_idx):
        """Locate `current_idx` in `user_samples` and compute navigation state.

        Returns:
            (position, next_sample, progress):
                position: index of current_idx in user_samples, or 0 if absent.
                next_sample: the id after it, clamped to the last id.
                progress: one-entry dict {"position/last": fraction}.

        Raises:
            ValueError: if `user_samples` is empty.
        """
        if not user_samples:
            raise ValueError(
                "No samples assigned to this user for the selected filters")
        try:
            position = user_samples.index(current_idx)
        except ValueError:
            position = 0
        last = len(user_samples) - 1
        next_sample = user_samples[min(position + 1, last)]
        # With a single sample, that sample is also the last one: report 100%.
        fraction = position / last if last else 1.0
        return position, next_sample, {f"{position}/{last}": fraction}

    def explorateSamples(self,
                         username,
                         predictions,
                         category,
                         model,
                         split,
                         next_idx_sample):
        """Gradio callback: show `username`'s sample `next_idx_sample` and queue the next.

        Rows of `split` are kept when the `category` column equals 1 and the
        chosen `model`'s hit matches `predictions` ('fail' -> 0,
        'correct' -> 1). The kept sample ids are split into one contiguous
        chunk per entry of ANNOTATORS. If `next_idx_sample` is not in the
        user's chunk, the chunk's first sample is shown.

        Returns:
            (update for the index Number, progress dict, rendered image,
            info text, segmented image).

        Raises:
            KeyError: unknown `username` or `predictions` value.
            ValueError: invalid split/model, or the user's chunk is empty.
        """
        next_idx_sample = int(next_idx_sample)
        hit = {'fail': 0, 'correct': 1}
        df = self.__loadPredictions(split, model)
        filtered = df[(df[category] == 1) & (df.hit == hit[predictions])]
        # Still published on the instance for existing readers, but the lookup
        # below uses the local table so concurrent requests cannot interfere.
        self.filtered_df = filtered

        user_samples = self._assign_samples(filtered.sample_idx.to_list())[username]
        id_, next_idx_sample, progress = self._navigate(user_samples, next_idx_sample)
        info, img, img_seg = self.__getSample(user_samples[id_], filtered)
        # Every field except the expression, which is already the image title.
        info = "\n".join(f"{k!s}:\t{v!s}"
                         for k, v in info.items() if k != 'Expresion').strip()
        # gr.update works on Gradio 3 and 4; gr.Number.update was removed in 4.
        return (gr.update(value=next_idx_sample), progress, img, info, img_seg)