| import json |
| import torch |
| import numpy as np |
| import cv2 |
| import os |
| from segment_anything import sam_model_registry, SamPredictor |
| from lvis import LVIS |
| import copy |
| from pathlib import Path |
|
|
|
|
class Objects365SAM:
    """Generate SAM instance masks from Objects365 box annotations.

    For every image in ``json_data['images'][index_low:index_high]`` the
    image's annotated boxes are fed to a SAM predictor as box prompts. The
    resulting masks are written as polygon annotations, one JSON file per
    image, via ``save_annotations_to_json``.
    """

    def __init__(self, index_low, index_high,
                 sam_checkpoint="checkpoints/sam_vit_h_4b8939.pth",
                 model_type="vit_h"):
        """Build the SAM predictor and remember the image index range.

        Args:
            index_low: first index (inclusive) into ``json_data['images']``.
            index_high: end index (exclusive) into ``json_data['images']``;
                used as a slice bound, so ``None`` means "to the end".
            sam_checkpoint: path to the SAM weights file.
            model_type: key into ``sam_model_registry``; must match the
                checkpoint.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=self.device)
        self.predictor = SamPredictor(sam)

        self.index_low = index_low
        self.index_high = index_high

    def load_annotations(self, annotation_file):
        """Load the annotation JSON file into ``self.json_data``.

        The file is expected to contain top-level ``images`` and
        ``annotations`` lists (read by ``process_annotations_with_sam``).
        """
        with open(annotation_file, 'r', encoding='utf-8') as f:
            self.json_data = json.load(f)

    def process_annotations_with_sam(self, images_dir, output_dir):
        """Run SAM on each selected image and save its masks as JSON.

        Must be called after ``load_annotations``.

        Args:
            images_dir: root directory containing ``<subset>/<image_name>``.
            output_dir: root directory; results are written to
                ``<output_dir>/<subset>/<image stem>.json``.
        """
        selected_images = self.json_data['images'][self.index_low:self.index_high]
        # Group annotations once instead of rescanning the full list per image.
        annotations_by_image = self._group_annotations_by_image(
            self.json_data['annotations'],
            {image_info['id'] for image_info in selected_images},
        )

        counter = 0
        for image_info in selected_images:
            image_id = image_info['id']
            # Only the last two path components (subset dir and file name)
            # are used to locate the image under images_dir.
            path_parts = image_info['file_name'].split('/')
            image_name = path_parts[-1]
            image_subset = path_parts[-2]

            output_json_dir = Path(output_dir) / image_subset
            output_json_dir.mkdir(parents=True, exist_ok=True)
            # stem handles any extension length (".jpg", ".jpeg", ...).
            output_file = output_json_dir / (Path(image_name).stem + '.json')

            image_path = os.path.join(images_dir, image_subset, image_name)
            image = cv2.imread(image_path)
            if image is None:
                print(f"Image not found: {image_path}")
                continue

            bounding_boxes = []
            for anno in annotations_by_image.get(image_id, []):
                # Source boxes are [x, y, width, height]; SAM wants [x0, y0, x1, y1].
                xmin, ymin, width, height = anno['bbox']
                bounding_boxes.append([xmin, ymin, xmin + width, ymin + height])

            if bounding_boxes:
                # cv2 decodes to BGR; SamPredictor.set_image expects RGB by default.
                self.predictor.set_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
                mask_annotations = self._predict_mask_annotations(
                    bounding_boxes, image.shape[:2])
            else:
                # No prompts: skip the expensive image encoder, still emit a file.
                mask_annotations = []

            save_annotations_to_json(image_id, mask_annotations, output_file)
            torch.cuda.empty_cache()
            counter += 1
            print('Done image idex: ', counter)

    @staticmethod
    def _group_annotations_by_image(annotations, image_ids):
        """Map each id in ``image_ids`` to its annotations, in file order."""
        grouped = {}
        for anno in annotations:
            if anno['image_id'] in image_ids:
                grouped.setdefault(anno['image_id'], []).append(anno)
        return grouped

    def _predict_mask_annotations(self, bounding_boxes, image_size):
        """Prompt SAM with XYXY boxes (image already set) and convert the masks.

        Args:
            bounding_boxes: list of ``[xmin, ymin, xmax, ymax]`` in pixels.
            image_size: ``(height, width)`` of the original image.

        Returns:
            List of annotation dicts; masks without a usable contour are dropped.
        """
        input_boxes = torch.tensor(bounding_boxes, device=self.device).float()
        transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, image_size)

        with torch.no_grad():
            masks, _, _ = self.predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )

        mask_annotations = []
        for mask in masks:
            # With multimask_output=False each mask is (1, H, W); index rather
            # than squeeze so a 1-pixel-wide/tall image keeps both axes.
            binary_mask = mask[0].cpu().numpy().astype(np.uint8)
            annotation = self._mask_to_annotation(binary_mask)
            if annotation is not None:
                mask_annotations.append(annotation)
        return mask_annotations

    @staticmethod
    def _mask_to_annotation(binary_mask):
        """Turn a uint8 HxW mask into a polygon annotation of its largest outer contour.

        Returns None when there is no contour, or the largest one has fewer
        than 3 vertices (not a valid COCO polygon).
        """
        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return None
        largest_contour = max(contours, key=cv2.contourArea)
        if len(largest_contour) < 3:
            return None
        x, y, box_w, box_h = cv2.boundingRect(largest_contour)
        return {
            "segmentation": [largest_contour.flatten().tolist()],
            "bbox": [x, y, box_w, box_h],
            "area": float(cv2.contourArea(largest_contour)),
            # NOTE: every mask is written with category_id 1; the source
            # annotation's category is not carried over.
            "category_id": 1,
        }
|
|
def save_annotations_to_json(image_id, mask_annotations, output_file):
    """Serialize one image's mask annotations to ``output_file``.

    The written JSON object has exactly two keys, ``image_id`` followed by
    ``annotations``; an existing file at that path is overwritten.
    """
    payload = {}
    payload["image_id"] = image_id
    payload["annotations"] = mask_annotations
    serialized = json.dumps(payload)
    with open(output_file, 'w') as handle:
        handle.write(serialized)
|
|
|
|
if __name__ == "__main__":
    # Objects365 image counts: train 1,742,292 / val 80,000.
    import argparse

    parser = argparse.ArgumentParser(description="Annotate labels with Segment Anything")
    parser.add_argument('--is_train', action='store_true', help="Train/Test")
    parser.add_argument("--index_low", type=int, default=0, help="Lower bound of indexes for processing Objects365 dataset.")
    parser.add_argument("--index_high", type=int, default=1742292, help="Upper bound of indexes for processing Objects365 dataset.")
    args = parser.parse_args()

    # Without --is_train the val split is processed.
    split = 'train' if args.is_train else 'val'
    annotation_file = f'../data/object365/zhiyuan_objv2_{split}.json'
    input_image_dir = f'../data/object365/images/{split}/'
    output_dir = Path(f'../data/object365/labels/{split}/')

    # parents=True so a fresh data tree without labels/ does not raise.
    output_dir.mkdir(parents=True, exist_ok=True)

    sam_annotator = Objects365SAM(args.index_low, args.index_high)
    sam_annotator.load_annotations(annotation_file)
    sam_annotator.process_annotations_with_sam(input_image_dir, output_dir)
|
|
|
|
|
|