import json
import os
import sys

import cv2
import numpy as np
from shapely.geometry import Polygon
from tabulate import tabulate


def get_image_path(image_dir, file_name_wo_ext):
    """Return the first existing path for an image name, or None.

    The bare name is tried first (so a name that already carries its
    extension resolves unchanged), then each known suffix is appended in turn.

    Args:
        image_dir: Directory that is joined with the candidate file names.
        file_name_wo_ext: Image file name, with or without extension.

    Returns:
        The first candidate path that exists on disk, or None if none does.
    """
    for ext in ("", ".jpg", ".JPG", ".png", ".PNG", ".jpeg"):
        candidate = os.path.join(image_dir, file_name_wo_ext + ext)
        if os.path.exists(candidate):
            return candidate
    return None


def _draw_annotations(img, items, tag, color, anchor_index, font):
    """Draw each item's polygon and a ``tag + category_id`` caption onto ``img`` in place."""
    for item in items:
        pts = np.array(item["poly"]).reshape((1, -1, 2)).astype(np.int32)
        cv2.polylines(img, pts, isClosed=True, color=color, thickness=3)
        # The caption is anchored at one polygon vertex; gt and pred use
        # different vertices so the two captions do not land on the same spot.
        anchor = tuple(pts[0][anchor_index].tolist())
        cv2.putText(img, tag + str(item["category_id"]), anchor, font, 1, color, 2)


def visual_badcase(image_path, pred_list, label_list, output_dir="visual_badcase", info=None, prefix=""):
    """Render ground-truth and predicted polygons on an image and save it.

    Args:
        image_path: Path of the source image. May be None (for example when
            ``get_image_path`` found nothing); the call is then skipped.
        pred_list: Predictions, each a dict with ``"poly"`` and ``"category_id"``.
        label_list: Ground-truth items with the same two keys.
        output_dir: Directory for the rendered image; created when missing.
        info: Optional value rendered as text near the top-left corner.
        prefix: String prepended to the output file name.

    Returns:
        Path of the written image, or None if the source image could not be
        read.
    """
    # `os.path.exists(...)` returns a bool, so the former `is not None` test
    # was always true, and a None path crashed inside os.path.exists.
    img = None
    if image_path is not None and os.path.exists(image_path):
        img = cv2.imread(image_path)
    if img is None:
        print("--> Warning: skip, given iamge NOT exists: {}".format(image_path))
        return None

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    font = cv2.FONT_HERSHEY_SIMPLEX
    _draw_annotations(img, label_list, "gt:", (0, 255, 0), 0, font)
    _draw_annotations(img, pred_list, "pred:", (255, 0, 0), -1, font)
    if info is not None:
        cv2.putText(img, str(info), (40, 40), font, 1, (0, 0, 255), 2)

    output_path = os.path.join(output_dir, prefix + os.path.basename(image_path) + "_vis.jpg")
    cv2.imwrite(output_path, img)
    return output_path


def _bbox_to_poly(annotation_item):
    """Turn an ``[x, y, w, h]`` ``"bbox"`` into a rounded 8-value corner array."""
    box_xywh = annotation_item["bbox"]
    x1, y1 = box_xywh[0], box_xywh[1]
    x2, y2 = box_xywh[0] + box_xywh[2], box_xywh[1] + box_xywh[3]
    # Corner order: (x1, y1), (x2, y1), (x2, y2), (x1, y2).
    return np.round([x1, y1, x2, y1, x2, y2, x1, y2])


def _poly_from_annotation(annotation_item):
    """Return the annotation's own ``"poly"`` field."""
    return annotation_item["poly"]


def _load_grouped_annotations(json_path, poly_getter):
    """Load a ground-truth JSON and group its annotations by group and image.

    ``poly_getter`` maps one annotation dict to the value stored under
    ``"poly"``. Only images referenced by at least one annotation appear in
    the result.
    """
    with open(json_path, encoding="utf-8") as f:
        gt_info = json.load(f)
    gt_image_list = gt_info["images"]
    gt_anno_list = gt_info["annotations"]

    id_to_image_info = {
        image_item["id"]: {
            "file_name": image_item["file_name"],
            "group_name": image_item.get("group_name", "huntie"),
        }
        for image_item in gt_image_list
    }

    group_info = {}
    for annotation_item in gt_anno_list:
        image_info = id_to_image_info[annotation_item["image_id"]]
        image_name, group_name = image_info["file_name"], image_info["group_name"]
        anno_info = {
            "category_id": annotation_item["category_id"],
            "poly": poly_getter(annotation_item),
            "secondary_id": annotation_item.get("secondary_id", -1),
            "direction_id": annotation_item.get("direction_id", -1),
        }
        group_info.setdefault(group_name, {}).setdefault(image_name, []).append(anno_info)

    group_info_str = ", ".join(["{}[{}]".format(k, len(v)) for k, v in group_info.items()])
    print("--> load {} groups: {}".format(len(group_info.keys()), group_info_str))
    return group_info


def pub_load_gt_from_json(json_path):
    """Load ground truth whose annotations carry an ``[x, y, w, h]`` ``"bbox"``.

    Each box is converted to a rounded four-corner polygon so it can be
    consumed by ``calc_iou``.

    Args:
        json_path: Path to a JSON file with ``"images"`` and ``"annotations"``.

    Returns:
        ``{group_name: {file_name: [annotation, ...]}}`` where each annotation
        has ``category_id``, ``poly``, ``secondary_id`` and ``direction_id``
        (the last two default to -1). Images without a ``group_name`` fall
        into the ``"huntie"`` group.
    """
    return _load_grouped_annotations(json_path, _bbox_to_poly)


def load_gt_from_json(json_path):
    """Load ground truth whose annotations carry an explicit ``"poly"`` field.

    Args:
        json_path: Path to a JSON file with ``"images"`` and ``"annotations"``.

    Returns:
        ``{group_name: {file_name: [annotation, ...]}}`` where each annotation
        has ``category_id``, ``poly``, ``secondary_id`` and ``direction_id``
        (the last two default to -1). Images without a ``group_name`` fall
        into the ``"huntie"`` group.
    """
    return _load_grouped_annotations(json_path, _poly_from_annotation)


def _quad_shape_and_area(item, what):
    """Build the 4-point quad, its Shapely polygon and its area for one item.

    Only the first eight values of ``item["poly"]`` are used.
    """
    flat = item["poly"]
    quad = [[flat[2 * k], flat[2 * k + 1]] for k in range(4)]
    try:
        shape = Polygon(quad)
        area = shape.area
    except Exception as err:
        # Raise instead of exit(): SystemExit raised inside a multiprocessing
        # worker is not reported to the parent, which then waits forever.
        raise ValueError("invalid {}: {}".format(what, flat)) from err
    return quad, shape, area


def calc_iou(label, detect):
    """Compute a per-item best overlap score between labels and detections.

    For a label/detection pair with the same integer ``category_id`` the
    score is ``min(overlap / detection_area, overlap / label_area)``; pairs
    with different categories score 0. Note this is the smaller of the two
    coverage ratios, not intersection-over-union in the strict sense.

    Args:
        label: Ground-truth items, each with ``"poly"`` and ``"category_id"``.
        detect: Detected items with the same two keys.

    Returns:
        ``(l_ious, d_ious)``: the best score of each label and of each
        detection, in input order.

    Raises:
        ValueError: If a polygon cannot be constructed from an item.
    """
    det_info = [_quad_shape_and_area(item, "detects") for item in detect]
    gt_info = [_quad_shape_and_area(item, "labels") for item in label]

    # Overlap area for every (detection, label) pair; a failed intersection
    # is reported and counted as zero overlap.
    ol_areas = []
    for det_quad, det_shape, _ in det_info:
        row = []
        for gt_quad, gt_shape, _ in gt_info:
            try:
                row.append(gt_shape.intersection(det_shape).area)
            except Exception:
                print("invaild pair", det_quad, gt_quad)
                row.append(0.0)
        ol_areas.append(row)

    d_ious = [0.0] * len(det_info)
    l_ious = [0.0] * len(gt_info)
    for i, (_, _, det_area) in enumerate(det_info):
        for j, (_, _, gt_area) in enumerate(gt_info):
            if int(label[j]["category_id"]) == int(detect[i]["category_id"]):
                overlap = ol_areas[i][j]
                # The small epsilon keeps zero-area polygons from dividing by zero.
                iou = min(overlap / (det_area + 1e-10), overlap / (gt_area + 1e-10))
            else:
                iou = 0
            d_ious[i] = max(d_ious[i], iou)
            l_ious[j] = max(l_ious[j], iou)
    return l_ious, d_ious


def eval(instance_info):
    """Score one image; shaped to be mapped over a task list by ``multiproc``.

    The name shadows the ``eval`` builtin but is kept for existing callers.

    Args:
        instance_info: ``(img_name, {"gt": [...], "det": [...]})``.

    Returns:
        ``[img_name, d_ious, l_ious, detect, label]``.
    """
    img_name, label_info = instance_info
    label = label_info["gt"]
    detect = label_info["det"]
    l_ious, d_ious = calc_iou(label, detect)
    return [img_name, d_ious, l_ious, detect, label]


def static_with_class(
    rets, iou_thresh=0.7, is_verbose=True, map_info=None, src_image_dir=None, visualization_dir=None
):
    """Aggregate per-image scores into a per-class precision/recall table.

    Counter keys per class: ``dv``/``lv`` are the number of detections and
    labels, ``dm``/``lm`` the number of those whose score reaches
    ``iou_thresh``.

    Args:
        rets: List of ``[img_name, d_ious, l_ious, detects, labels]`` entries
            as produced by ``eval``.
        iou_thresh: Score threshold for a hit; anything ``float()`` accepts.
        is_verbose: Include the raw hit/total counts in the table and print
            where visualizations are saved.
        map_info: Optional mapping used to look up class display names.
        src_image_dir: When given, images whose per-image F-score is below
            0.97 are rendered with ``visual_badcase``.
        visualization_dir: Output directory for those renderings; defaults to
            ``./visualization_badcase``.

    Returns:
        ``[table_head] + rows``, the rows sorted by integer class id and
        followed by an overall ``average`` row. The table is also printed.
    """
    # The command line hands the threshold over as a string, and comparing a
    # float score with a str raises TypeError.
    iou_thresh = float(iou_thresh)

    if is_verbose:
        table_head = ["Class_id", "Class_name", "Pre_hit", "Pre_num", "GT_hit", "GT_num", "Precision", "Recall", "F-score"]
    else:
        table_head = ["Class_id", "Class_name", "Precision", "Recall", "F-score"]

    # Register every ground-truth class up front. Filling this while iterating
    # made the "not in gt" skip below depend on image order: a detection was
    # dropped when its class only showed up in a later image's labels.
    class_dict = {}
    for _, _, _, _, labels in rets:
        for label in labels:
            class_dict.setdefault(label["category_id"], {"dm": 0, "dv": 0, "lm": 0, "lv": 0})

    for img_name, d_ious, l_ious, detects, labels in rets:
        item_lv, item_dv = len(labels), len(detects)
        for label in labels:
            class_dict[label["category_id"]]["lv"] += 1
        for det in detects:
            category_id = det["category_id"]
            if category_id not in class_dict:
                print("--> category_id not exists in gt: {}".format(category_id))
                continue
            class_dict[category_id]["dv"] += 1

        item_dm = 0
        for idx, iou in enumerate(d_ious):
            if iou >= iou_thresh:
                item_dm += 1
                class_dict[detects[idx]["category_id"]]["dm"] += 1
        item_lm = 0
        for idx, iou in enumerate(l_ious):
            if iou >= iou_thresh:
                item_lm += 1
                class_dict[labels[idx]["category_id"]]["lm"] += 1

        item_p = item_dm / (item_dv + 1e-6)
        item_r = item_lm / (item_lv + 1e-6)
        item_f = 2 * item_p * item_r / (item_p + item_r + 1e-6)

        # Images scoring below this fixed F-score are dumped for inspection.
        if item_f < 0.97 and src_image_dir is not None:
            image_path = get_image_path(src_image_dir, os.path.basename(img_name))
            visualization_output = visualization_dir if visualization_dir is not None else "./visualization_badcase"
            item_info = "IOU{}, {}, {}, {}".format(iou_thresh, item_r, item_p, item_f)
            vis_path = visual_badcase(
                image_path,
                detects,
                labels,
                output_dir=visualization_output,
                info=item_info,
                prefix="{:02d}_".format(int(item_f * 100)),
            )
            if is_verbose:
                print("--> info: save visualization at: {}".format(vis_path))

    dm, dv, lm, lv = 0, 0, 0, 0
    map_info = {} if map_info is None else map_info
    table_body = []
    for key, counts in class_dict.items():
        dm += counts["dm"]
        dv += counts["dv"]
        lm += counts["lm"]
        lv += counts["lv"]
        p = counts["dm"] / (counts["dv"] + 1e-6)
        r = counts["lm"] / (counts["lv"] + 1e-6)
        fscore = 2 * p * r / (p + r + 1e-6)
        if is_verbose:
            class_name = map_info.get("primary_map", {}).get(str(key), str(key))
            table_body.append((key, class_name, counts["dm"], counts["dv"], counts["lm"], counts["lv"], p, r, fscore))
        else:
            # NOTE(review): this branch reads the name from map_info directly
            # while the verbose branch goes through "primary_map" - confirm
            # which lookup is intended.
            table_body.append((key, map_info.get(str(key), str(key)), p, r, fscore))

    p = dm / (dv + 1e-6)
    r = lm / (lv + 1e-6)
    f = 2 * p * r / (p + r + 1e-6)

    table_body_sorted = sorted(table_body, key=lambda row: int(row[0]))
    if is_verbose:
        table_body_sorted.append(("IOU_{}".format(iou_thresh), "average", dm, dv, lm, lv, p, r, f))
    else:
        table_body_sorted.append(("IOU_{}".format(iou_thresh), "average", p, r, f))
    print(tabulate(table_body_sorted, headers=table_head, tablefmt="pipe"))
    return [table_head] + table_body_sorted


def multiproc(func, task_list, proc_num=30, retv=True, progress_bar=False):
    """Map ``func`` over ``task_list`` in a process pool, preserving order.

    Args:
        func: Picklable callable taking one task.
        task_list: Sequence of tasks.
        proc_num: Number of worker processes.
        retv: Return the collected results when true, otherwise None.
        progress_bar: Show a ``tqdm`` progress bar while collecting.

    Returns:
        List of results in task order, or None when ``retv`` is false.
    """
    from multiprocessing import Pool

    rets = []
    # The context manager terminates the workers if collecting results
    # raises, so a failing task no longer leaks the pool.
    with Pool(proc_num) as pool:
        results = pool.imap(func, task_list)
        if progress_bar:
            import tqdm

            with tqdm.tqdm(total=len(task_list)) as t:
                for ret in results:
                    rets.append(ret)
                    t.update(1)
        else:
            rets.extend(results)
        pool.close()
        pool.join()
    if retv:
        return rets
    return None


def eval_and_show(
    label_dict, detect_dict, output_dir, iou_thresh=0.7, map_info=None, src_image_dir=None, visualization_dir=None
):
    """Evaluate detections group by group and save the tables as JSON.

    Args:
        label_dict: ``{group_name: {file_name: [gt, ...]}}`` as returned by
            the ground-truth loaders in this file.
        detect_dict: ``{file_name: [detection, ...]}``.
        output_dir: Directory receiving ``results_val.json``; created when
            missing.
        iou_thresh: Threshold forwarded to ``static_with_class``.
        map_info: Optional ``{group_name: mapping}``; the group's entry is
            forwarded to ``static_with_class``.
        src_image_dir: Forwarded to ``static_with_class``.
        visualization_dir: Forwarded to ``static_with_class``.
    """
    evaluation_group_info = {}
    for group_name, gt_info in label_dict.items():
        group_pair_list = []
        for file_name, gt_list in gt_info.items():
            if file_name not in detect_dict:
                print("--> missing pred:", file_name)
                continue
            group_pair_list.append([file_name, {"gt": gt_list, "det": detect_dict[file_name]}])
        evaluation_group_info[group_name] = group_pair_list

    res_info_all = {}
    for group_name, group_pair_list in evaluation_group_info.items():
        print(" ------- group name: {} -----------".format(group_name))
        rets = multiproc(eval, group_pair_list, proc_num=16)
        group_name_map_info = map_info.get(group_name, None) if map_info is not None else None
        res_info = static_with_class(
            rets,
            iou_thresh=iou_thresh,
            map_info=group_name_map_info,
            src_image_dir=src_image_dir,
            visualization_dir=visualization_dir,
        )
        res_info_all[group_name] = res_info

    # Without this, a fresh output directory made the final write fail after
    # the whole evaluation had already run.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    evaluation_res_info_path = os.path.join(output_dir, "results_val.json")
    # ensure_ascii=False emits raw non-ASCII text, so the encoding is pinned.
    with open(evaluation_res_info_path, "w", encoding="utf-8") as f:
        json.dump(res_info_all, f, ensure_ascii=False, indent=4)
    print("--> info: evaluation result is saved at {}".format(evaluation_res_info_path))


if __name__ == "__main__":
    if len(sys.argv) != 5:
        print("Usage: python {} gt_json_path pred_json_path output_dir iou_thresh".format(__file__))
        sys.exit(-1)
    else:
        print("--> info: {}".format(sys.argv))

    gt_json_path, pred_json_path, output_dir = sys.argv[1], sys.argv[2], sys.argv[3]
    # argv values are strings; the threshold is compared against float scores.
    iou_thresh = float(sys.argv[4])

    label_dict = load_gt_from_json(gt_json_path)
    with open(pred_json_path, "r", encoding="utf-8") as f:
        detect_dict = json.load(f)

    src_image_dir = None
    eval_and_show(
        label_dict,
        detect_dict,
        output_dir,
        iou_thresh=iou_thresh,
        map_info=None,
        src_image_dir=src_image_dir,
        visualization_dir=None,
    )