├── .github
│   ├── FUNDING.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   └── workflows
│       └── stale.yml
├── .gitignore
├── Detectron2_Balloon_Instance_Segmentation.ipynb
├── Microcontroller_Instance_Segmentation.ipynb
├── Microcontroller_Instance_Segmentation_with_COCO_dataformat.ipynb
├── README.md
├── doc
│   ├── detectron_visualize_segmentations.png
│   ├── label_images.PNG
│   ├── labelme_example.jpg
│   └── prediction_example.PNG
├── labelme2coco.py
├── microcontroller_segmentation_data.zip
└── resize_labelme_dataset.py

--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
# These are supported funding model platforms

github: TannerGilbert
patreon: gilberttanner
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
 - OS: [e.g. iOS]
 - Browser [e.g. chrome, safari]
 - Version [e.g. 22]

**Smartphone (please complete the following information):**
 - Device: [e.g. iPhone6]
 - OS: [e.g. iOS8.1]
 - Browser [e.g. stock browser, safari]
 - Version [e.g. 22]

**Additional context**
Add any other context about the problem here.

--------------------------------------------------------------------------------
/.github/workflows/stale.yml:
--------------------------------------------------------------------------------
name: Mark stale issues and pull requests

on:
  schedule:
    - cron: "0 0 * * *"

jobs:
  stale:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/stale@v1
      with:
        repo-token: ${{ secrets.GITHUB_TOKEN }}
        stale-pr-message: 'Stale pull request message'
        stale-issue-label: 'no-issue-activity'
        stale-pr-label: 'no-pr-activity'
        stale-issue-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days'
        days-before-stale: 30
        days-before-close: 5

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Detectron2: Train a Custom Instance Segmentation Model

![](doc/detectron_visualize_segmentations.png)

## 1. Installation

See the official [installation guide](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md).

## 2. Gathering data

Gathering image data is simple. You can either take pictures yourself using some kind of camera, or you can download images from the internet.

To build a robust model, you need pictures with different backgrounds, varying lighting conditions, and random objects in the scene.

For my microcontroller data-set, I have four different objects (Arduino Nano, ESP8266, Raspberry Pi 3, Heltec ESP32 Lora). I took about 25 pictures of each microcontroller and another 25 containing multiple microcontrollers using my smartphone. After taking the pictures, make sure to transform them to a resolution suitable for training (I used 800x600).
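
Since the polygon labels you will create later store absolute pixel coordinates, it is easiest to resize before labeling. Below is a minimal sketch of one way to batch-resize the raw photos with Pillow; the folder names are placeholders, not part of this repository. If your images are already labeled, use the `resize_labelme_dataset.py` script shipped with this repository instead, since it rescales the annotations as well (see the next section).

```python
import os
from PIL import Image

# Placeholder folder names; adjust them to your setup.
input_dir, output_dir = 'raw_images', 'resized_images'
os.makedirs(output_dir, exist_ok=True)

for name in os.listdir(input_dir):
    if name.lower().endswith(('.jpg', '.jpeg', '.png')):
        im = Image.open(os.path.join(input_dir, name))
        # LANCZOS is a high-quality resampling filter.
        im.resize((800, 600), Image.LANCZOS).save(os.path.join(output_dir, name))
```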

## 3. Labeling data

After you have gathered enough images, it's time to label them so your model knows what to learn. To label the data, you will need to use some kind of labeling software.

For object detection, I used [LabelImg](https://github.com/tzutalin/labelImg), an excellent image annotation tool supporting both PascalVOC and Yolo formats. For image segmentation / instance segmentation, there are multiple great annotation tools available, including the [VGG Image Annotation Tool](http://www.robots.ox.ac.uk/~vgg/software/via/), [labelme](https://github.com/wkentaro/labelme), and [PixelAnnotationTool](https://github.com/abreheret/PixelAnnotationTool). I chose labelme because it is simple to both install and use.

![](doc/labelme_example.jpg)
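
If you only decide to resize after labeling, the `resize_labelme_dataset.py` script in this repository scales the images and the polygon points inside each labelme JSON by the same ratios, so the labels stay aligned. A short sketch of calling it from Python (the directory names are just examples):

```python
import glob
from resize_labelme_dataset import resize_images_and_labels

# Resize every labeled image in the train split to 800x600 and rescale
# the labelme polygon points to match.
labelme_json = glob.glob('Microcontroller Segmentation/train/*.json')
resize_images_and_labels(labelme_json, 'Microcontroller Segmentation/train_resized', (800, 600))
```

The same is available from the command line: `python resize_labelme_dataset.py --input_dir 'Microcontroller Segmentation/train' --output_dir 'Microcontroller Segmentation/train_resized' --size 800 600`.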

## 4. Registering the data-set

Detectron2 gives you multiple options to register your instance segmentation data-set. Which one you use will depend on what data you have. If you labeled your data with labelme or the VGG Image Annotation Tool, I recommend passing the polygons through the `segmentation` key, as shown below for the microcontroller data-set:

```python
import os
import numpy as np
import json
from detectron2.structures import BoxMode

def get_microcontroller_dicts(directory):
    classes = ['Raspberry_Pi_3', 'Arduino_Nano', 'ESP8266', 'Heltec_ESP32_Lora']
    dataset_dicts = []
    for filename in [file for file in os.listdir(directory) if file.endswith('.json')]:
        json_file = os.path.join(directory, filename)
        with open(json_file) as f:
            img_anns = json.load(f)

        record = {}

        filename = os.path.join(directory, img_anns["imagePath"])

        record["file_name"] = filename
        # All images were resized to 800x600 beforehand.
        record["height"] = 600
        record["width"] = 800

        annos = img_anns["shapes"]
        objs = []
        for anno in annos:
            px = [a[0] for a in anno['points']]
            py = [a[1] for a in anno['points']]
            # Flatten the polygon into [x0, y0, x1, y1, ...] as Detectron2 expects.
            poly = [(x, y) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            obj = {
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": classes.index(anno['label']),
                "iscrowd": 0
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts

from detectron2.data import DatasetCatalog, MetadataCatalog
for d in ["train", "test"]:
    DatasetCatalog.register("microcontroller_" + d, lambda d=d: get_microcontroller_dicts('Microcontroller Segmentation/' + d))
    MetadataCatalog.get("microcontroller_" + d).set(thing_classes=['Raspberry_Pi_3', 'Arduino_Nano', 'ESP8266', 'Heltec_ESP32_Lora'])
microcontroller_metadata = MetadataCatalog.get("microcontroller_train")
```

You can also use the `sem_seg_file_name` or `sem_seg` keys if they work better for your data-set.
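
Alternatively, if you convert the labelme annotations to COCO format first (for example with this repository's `labelme2coco.py`: `python labelme2coco.py 'Microcontroller Segmentation/train' --output trainval.json`), you can skip the custom loader and use Detectron2's built-in COCO registration. A minimal sketch, assuming the conversion produced `trainval.json`:

```python
from detectron2.data.datasets import register_coco_instances

# Arguments: dataset name, extra metadata, path to the COCO json, image root.
register_coco_instances("microcontroller_train_coco", {}, "trainval.json", "Microcontroller Segmentation/train")
```

See also the Microcontroller_Instance_Segmentation_with_COCO_dataformat.ipynb notebook in this repository.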

## 5. Training the model

Training the model works just the same as training an object detection model. The only difference is that you now need to use an instance segmentation model instead of an object detection model.

```python
import os
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
from detectron2.config import get_cfg

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("microcontroller_train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 1000
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 4  # four microcontroller classes

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
```
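
Optionally, you can compute COCO metrics (box and mask AP) on the test split after training. Below is a minimal sketch using Detectron2's built-in evaluator (API as in recent Detectron2 versions); it reuses the `microcontroller_test` data-set registered above:

```python
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader

# Converts the registered dataset to COCO format on the fly and reports AP.
evaluator = COCOEvaluator("microcontroller_test", output_dir=cfg.OUTPUT_DIR)
test_loader = build_detection_test_loader(cfg, "microcontroller_test")
print(inference_on_dataset(trainer.model, test_loader, evaluator))
```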

## 6. Using the model for inference

Now we can perform inference on our validation set by creating a predictor object.

```python
from detectron2.engine import DefaultPredictor

cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.DATASETS.TEST = ("microcontroller_test", )
predictor = DefaultPredictor(cfg)
```

```python
import random
import cv2
import matplotlib.pyplot as plt
from detectron2.utils.visualizer import Visualizer, ColorMode

dataset_dicts = get_microcontroller_dicts('Microcontroller Segmentation/test')
for d in random.sample(dataset_dicts, 3):
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1],
                   metadata=microcontroller_metadata,
                   scale=0.8,
                   instance_mode=ColorMode.IMAGE_BW  # remove the colors of unsegmented pixels
    )
    v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    plt.figure(figsize=(14, 10))
    # The Visualizer was fed an RGB image, so its output can be shown directly.
    plt.imshow(v.get_image())
    plt.show()
```

![](doc/prediction_example.PNG)

--------------------------------------------------------------------------------
/doc/detectron_visualize_segmentations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TannerGilbert/Detectron2-Train-a-Instance-Segmentation-Model/7d3cc8f11af0fef6d416125c6738e78f01588136/doc/detectron_visualize_segmentations.png

--------------------------------------------------------------------------------
/doc/label_images.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TannerGilbert/Detectron2-Train-a-Instance-Segmentation-Model/7d3cc8f11af0fef6d416125c6738e78f01588136/doc/label_images.PNG

--------------------------------------------------------------------------------
/doc/labelme_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TannerGilbert/Detectron2-Train-a-Instance-Segmentation-Model/7d3cc8f11af0fef6d416125c6738e78f01588136/doc/labelme_example.jpg

--------------------------------------------------------------------------------
/doc/prediction_example.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TannerGilbert/Detectron2-Train-a-Instance-Segmentation-Model/7d3cc8f11af0fef6d416125c6738e78f01588136/doc/prediction_example.PNG

--------------------------------------------------------------------------------
/labelme2coco.py:
--------------------------------------------------------------------------------
# from https://github.com/Tony607/labelme2coco/blob/master/labelme2coco.py

import os
import json

from labelme import utils
import numpy as np
import glob
import PIL.Image
import PIL.ImageDraw


class labelme2coco(object):
    def __init__(self, labelme_json=[], save_json_path="./coco.json"):
        """
        :param labelme_json: the list of all labelme json file paths
        :param save_json_path: the path to save new json
        """
        self.labelme_json = labelme_json
        self.save_json_path = save_json_path
        self.images = []
        self.categories = []
        self.annotations = []
        self.label = []
        self.annID = 1
        self.height = 0
        self.width = 0

        self.save_json()

    def data_transfer(self):
        for num, json_file in enumerate(self.labelme_json):
            with open(json_file, "r") as fp:
                data = json.load(fp)
                self.images.append(self.image(data, num))
                for shapes in data["shapes"]:
                    # Labels are split on "_" so e.g. "Person_1" and "Person_2"
                    # end up in a single "Person" category.
                    label = shapes["label"].split("_")
                    if label not in self.label:
                        self.label.append(label)
                    points = shapes["points"]
                    self.annotations.append(self.annotation(points, label, num))
                    self.annID += 1

        # Sort all text labels so they are in the same order across data splits.
        self.label.sort()
        for label in self.label:
            self.categories.append(self.category(label))
        for annotation in self.annotations:
            annotation["category_id"] = self.getcatid(annotation["category_id"])

    def image(self, data, num):
        image = {}
        img = utils.img_b64_to_arr(data["imageData"])
        height, width = img.shape[:2]
        img = None
        image["height"] = height
        image["width"] = width
        image["id"] = num
        image["file_name"] = data["imagePath"].split("/")[-1]

        self.height = height
        self.width = width

        return image

    def category(self, label):
        category = {}
        category["supercategory"] = label[0]
        category["id"] = len(self.categories)
        category["name"] = label[0]
        return category

    def annotation(self, points, label, num):
        annotation = {}
        contour = np.array(points)
        x = contour[:, 0]
        y = contour[:, 1]
        # Shoelace formula for the polygon area.
        area = 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
        annotation["segmentation"] = [list(np.asarray(points).flatten())]
        annotation["iscrowd"] = 0
        annotation["area"] = area
        annotation["image_id"] = num

        annotation["bbox"] = list(map(float, self.getbbox(points)))

        annotation["category_id"] = label[0]  # self.getcatid(label)
        annotation["id"] = self.annID
        return annotation

    def getcatid(self, label):
        for category in self.categories:
            if label == category["name"]:
                return category["id"]
        print("label: {} not in categories: {}.".format(label, self.categories))
        exit()
        return -1

    def getbbox(self, points):
        polygons = points
        mask = self.polygons_to_mask([self.height, self.width], polygons)
        return self.mask2box(mask)

    def mask2box(self, mask):

        index = np.argwhere(mask == 1)
        rows = index[:, 0]
        cols = index[:, 1]

        left_top_r = np.min(rows)  # y
        left_top_c = np.min(cols)  # x

        right_bottom_r = np.max(rows)
        right_bottom_c = np.max(cols)

        # [x, y, width, height] as expected by the COCO bbox format.
        return [
            left_top_c,
            left_top_r,
            right_bottom_c - left_top_c,
            right_bottom_r - left_top_r,
        ]

    def polygons_to_mask(self, img_shape, polygons):
        mask = np.zeros(img_shape, dtype=np.uint8)
        mask = PIL.Image.fromarray(mask)
        xy = list(map(tuple, polygons))
        PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
        mask = np.array(mask, dtype=bool)
        return mask

    def data2coco(self):
        data_coco = {}
        data_coco["images"] = self.images
        data_coco["categories"] = self.categories
        data_coco["annotations"] = self.annotations
        return data_coco

    def save_json(self):
        print("save coco json")
        self.data_transfer()
        self.data_coco = self.data2coco()

        print(self.save_json_path)
        os.makedirs(
            os.path.dirname(os.path.abspath(self.save_json_path)), exist_ok=True
        )
        with open(self.save_json_path, "w") as f:
            json.dump(self.data_coco, f, indent=4)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="labelme annotation to coco data json file."
    )
    parser.add_argument(
        "labelme_images",
        help="Directory to labelme images and annotation json files.",
        type=str,
    )
    parser.add_argument(
        "--output", help="Output json file path.", default="trainval.json"
    )
    args = parser.parse_args()
    labelme_json = glob.glob(os.path.join(args.labelme_images, "*.json"))
    labelme2coco(labelme_json, args.output)

--------------------------------------------------------------------------------
/microcontroller_segmentation_data.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TannerGilbert/Detectron2-Train-a-Instance-Segmentation-Model/7d3cc8f11af0fef6d416125c6738e78f01588136/microcontroller_segmentation_data.zip

--------------------------------------------------------------------------------
/resize_labelme_dataset.py:
--------------------------------------------------------------------------------
import os
import glob
from io import BytesIO
import base64
from PIL import Image
import json


def resize_images_and_labels(labelme_json_paths, output_dir, size):
    os.makedirs(output_dir, exist_ok=True)

    for json_file in labelme_json_paths:
        # Open json file
        with open(json_file, 'r') as f:
            data = json.load(f)

        # Load base64 image
        im = Image.open(BytesIO(base64.b64decode(data['imageData'])))

        # Resize image (Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter)
        im_resized = im.resize(size, Image.LANCZOS)

        # Change imageHeight and imageWidth in json
        data['imageWidth'] = size[0]
        data['imageHeight'] = size[1]

        # Change imageData
        buffered = BytesIO()
        im_resized.save(buffered, format="JPEG")
        data['imageData'] = base64.b64encode(buffered.getvalue()).decode()

        # Change datapoints
        width_ratio = im_resized.size[0] / im.size[0]
        height_ratio = im_resized.size[1] / im.size[1]
        for annotation in data['shapes']:
            resized_points = []
            for point in annotation['points']:
                resized_points.append([point[0] * width_ratio, point[1] * height_ratio])
            annotation['points'] = resized_points

        # Save image
        im_resized.save(os.path.join(output_dir, data['imagePath']))

        # Save json file
        with open(os.path.join(output_dir, os.path.basename(json_file)), 'w') as f:
            json.dump(data, f)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Resize images and their labelme labels')
    parser.add_argument('--input_dir', type=str, required=True, help='Directory to labelme images and annotation json files')
    parser.add_argument('--output_dir', type=str, required=True, help='Path to directory where new images and labels will be saved')
    parser.add_argument('--size', type=int, nargs=2, required=True, metavar=('width', 'height'), help='Image size')
    args = parser.parse_args()

    labelme_json = glob.glob(os.path.join(args.input_dir, "*.json"))
    resize_images_and_labels(labelme_json, args.output_dir, args.size)
--------------------------------------------------------------------------------