├── .gitignore
├── README.md
├── config
│   ├── e2e_mask_rcnn_X_101_32x8d_FPN_1x.yaml
│   └── e2e_mask_rcnn_X_101_32x8d_FPN_1x_1gpu.yaml
├── create_dataset.py
├── create_submission.py
├── integrate_results.py
├── requirements.txt
├── test.py
├── train.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### https://raw.github.com/github/gitignore/f57304e9762876ae4c9b02867ed0cb887316387e/Python.gitignore
2 | 
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 | 
8 | # C extensions
9 | *.so
10 | 
11 | # Distribution / packaging
12 | .Python
13 | env/
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *,cover
49 | .hypothesis/
50 | 
51 | # Translations
52 | *.mo
53 | *.pot
54 | 
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # dotenv
85 | .env
86 | 
87 | # virtualenv
88 | .venv
89 | venv/
90 | ENV/
91 | 
92 | # Spyder project settings
93 | .spyderproject
94 | 
95 | # Rope project settings
96 | .ropeproject
97 | 
98 | # mkdocs documentation
99 | /site
100 | 
101 | /.idea/
102 | 
103 | .DS_Store
104 | 
105 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # open-images-2019-instance-segmentation
2 | 
3 | Code for the [Open Images 2019 - Instance Segmentation competition](https://www.kaggle.com/c/open-images-2019-instance-segmentation)
4 | using [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark).
5 | 
6 | The result is not outstanding, but the solution may still be worth sharing
7 | because it uses the well-known maskrcnn-benchmark library as is, and also uses its outputs as is, without TTA or any post-processing.
8 | 
9 | The detailed solution can be found in this [Kaggle discussion](https://www.kaggle.com/c/open-images-2019-instance-segmentation/discussion/110908#latest-638554).
10 | 
11 | ## Preparation
12 | 
13 | 1. Install maskrcnn_benchmark according to the [official guide](https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/INSTALL.md).
14 | 2. Download the [Open Images dataset](https://storage.googleapis.com/openimages/web/download.html) to the project root directory (or create a symbolic link to it).
15 | 
16 | ```
17 | PROJECT_ROOT
18 | ├── README.md
19 | ├── config
20 | │   └── e2e_mask_rcnn_X_101_32x8d_FPN_1x.yaml
21 | ├── create_dataset.py
22 | ├── create_submission.py
23 | ├── datasets
24 | │   ├── challenge-2019-label300-segmentable-hierarchy.json
25 | │   ├── challenge-2019-train-segmentation-masks.csv
26 | │   ├── test
27 | │   ├── train
28 | │   └── train_masks
29 | ├── test.py
30 | ├── train.py
31 | └── util.py
32 | ```
33 | 
34 | - `test`: test images (99,999 images)
35 | - `train`: train images
36 | - `train_masks`: train mask images
37 | 
38 | Trained models are available from this [Kaggle Dataset](https://www.kaggle.com/ren4yu/openimages2019instancesegmentationmodels).
39 | If you use the trained models, please skip to the 'Test for Layer 0 Classes' section.
40 | 
41 | ## Train on Open Images Dataset
42 | 
43 | ### Create Dataset for Layer 0 Classes
44 | 
45 | Create a COCO-format dataset for the layer 0 classes.
46 | 
47 | ```bash
48 | python create_dataset.py -l 0
49 | ```
50 | 
51 | The COCO-format dataset is created as:
52 | 
53 | ```
54 | PROJECT_ROOT
55 | ├── datasets
56 | │   └── coco
57 | │       ├── annotations
58 | │       └── train2017
59 | ```
60 | 
61 | Because this follows the standard COCO format, it can also be used with other libraries such as [mmdetection](https://github.com/open-mmlab/mmdetection) (not tested, though).
62 | 
63 | ### Train for Layer 0 Classes
64 | 
65 | Train on 8 GPUs. This requires only about 14 hours using eight V100 GPUs:
66 | 
67 | ```bash
68 | python -m torch.distributed.launch --nproc_per_node=8 train.py OUTPUT_DIR "layer0"
69 | ```
70 | 
71 | Alternatively, train on a single GPU. This requires about 4 days using a V100 GPU:
72 | 
73 | ```bash
74 | python train.py --config-file config/e2e_mask_rcnn_X_101_32x8d_FPN_1x_1gpu.yaml OUTPUT_DIR "layer0"
75 | ```
76 | 
77 | The number of training steps can be reduced without a large degradation in accuracy.
78 | The following should require only about a day of training on a single GPU:
79 | 
80 | ```bash
81 | python train.py --config-file config/e2e_mask_rcnn_X_101_32x8d_FPN_1x_1gpu.yaml OUTPUT_DIR "layer0" SOLVER.STEPS "(70000, 100000)" SOLVER.MAX_ITER 120000
82 | ```
83 | 
84 | ### Test for Layer 0 Classes
85 | 
86 | ```bash
87 | python test.py -l 0 --weight [TRAINED_WEIGHT_PATH (e.g. layer0/model_0060000.pth)]
88 | ```
89 | 
90 | The resulting files will be created as:
91 | 
92 | ```
93 | PROJECT_ROOT
94 | ├── datasets
95 | │   └── test_0_results
96 | ```
97 | 
98 | ### Create Submission File for Layer 0 Classes
99 | 
100 | ```bash
101 | python create_submission.py -l 0
102 | ```
103 | 
104 | ### Create Submission File for Layer 1 Classes
105 | 
106 | Repeat the same procedure for the layer 1 classes:
107 | 
108 | ```bash
109 | python create_dataset.py -l 1 # this overwrites the layer 0 dataset; move it beforehand if you still need it
110 | python -m torch.distributed.launch --nproc_per_node=8 train.py OUTPUT_DIR "layer1"
111 | python test.py -l 1 --weight [TRAINED_WEIGHT_PATH (e.g. layer1/model_0060000.pth)]
112 | python create_submission.py -l 1
113 | ```
114 | 
115 | ### Integrate Two Submission Files
116 | 
117 | ```bash
118 | python integrate_results.py --input1 output_0.csv --input2 output_1.csv
119 | ```
120 | 
121 | OK, let's submit the resulting file `integrated_result.csv`!
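For reference, each row of `output_0.csv` / `output_1.csv` (and of the merged `integrated_result.csv`) holds `ImageID`, `ImageWidth`, `ImageHeight`, and a `PredictionString` made of space-separated `class score encoded_mask` triplets. The minimal sketch below (the file name, class, and score are made-up values for illustration) mirrors how `create_submission.py` builds one such triplet from a mask PNG written by `test.py`, using `encode_binary_mask` from `util.py`:

```python
import cv2

from util import encode_binary_mask

# hypothetical example values; the real ones come from the per-image JSON/PNG results
mask_png = "datasets/test_0_results/0000048549557964_0.png"  # mask saved by test.py
class_string = "/m/01g317"                                    # predicted class label
score = 0.87                                                  # its confidence score

predicted_mask = cv2.imread(mask_png, 0)       # grayscale mask heatmap in [0, 255]
binary_mask = predicted_mask > 128             # same default threshold as --th
mask_string = encode_binary_mask(binary_mask)  # COCO RLE -> zlib -> base64 (bytes)

# one "class score mask" entry of the PredictionString column
triplet = "{} {} {}".format(class_string, score, mask_string.decode())
```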
122 | 123 | -------------------------------------------------------------------------------- /config/e2e_mask_rcnn_X_101_32x8d_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d" 4 | BACKBONE: 5 | CONV_BODY: "R-101-FPN" 6 | RESNETS: 7 | BACKBONE_OUT_CHANNELS: 256 8 | STRIDE_IN_1X1: False 9 | NUM_GROUPS: 32 10 | WIDTH_PER_GROUP: 8 11 | RPN: 12 | USE_FPN: True 13 | ANCHOR_STRIDE: (4, 8, 16, 32, 64) 14 | PRE_NMS_TOP_N_TRAIN: 2000 15 | PRE_NMS_TOP_N_TEST: 1000 16 | POST_NMS_TOP_N_TEST: 1000 17 | FPN_POST_NMS_TOP_N_TEST: 1000 18 | ROI_HEADS: 19 | USE_FPN: True 20 | ROI_BOX_HEAD: 21 | POOLER_RESOLUTION: 7 22 | POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) 23 | POOLER_SAMPLING_RATIO: 2 24 | FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" 25 | NUM_CLASSES: 220 26 | PREDICTOR: "FPNPredictor" 27 | ROI_MASK_HEAD: 28 | POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) 29 | FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor" 30 | PREDICTOR: "MaskRCNNC4Predictor" 31 | POOLER_RESOLUTION: 14 32 | POOLER_SAMPLING_RATIO: 2 33 | RESOLUTION: 28 34 | SHARE_BOX_FEATURE_EXTRACTOR: False 35 | MASK_ON: True 36 | DATASETS: 37 | TRAIN: ("coco_2017_train",) 38 | DATALOADER: 39 | SIZE_DIVISIBILITY: 32 40 | SOLVER: 41 | BASE_LR: 0.01 42 | WEIGHT_DECAY: 0.0001 43 | STEPS: (35000, 50000) 44 | MAX_ITER: 60000 45 | CHECKPOINT_PERIOD: 200 46 | IMS_PER_BATCH: 16 47 | OUTPUT_DIR: "my" 48 | -------------------------------------------------------------------------------- /config/e2e_mask_rcnn_X_101_32x8d_FPN_1x_1gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | WEIGHT: "catalog://ImageNetPretrained/FAIR/20171220/X-101-32x8d" 4 | BACKBONE: 5 | CONV_BODY: "R-101-FPN" 6 | RESNETS: 7 | BACKBONE_OUT_CHANNELS: 256 8 | STRIDE_IN_1X1: False 9 | NUM_GROUPS: 32 10 | WIDTH_PER_GROUP: 8 11 | RPN: 12 | USE_FPN: True 13 | ANCHOR_STRIDE: (4, 8, 16, 32, 64) 14 | PRE_NMS_TOP_N_TRAIN: 2000 15 | PRE_NMS_TOP_N_TEST: 1000 16 | POST_NMS_TOP_N_TEST: 1000 17 | FPN_POST_NMS_TOP_N_TEST: 1000 18 | ROI_HEADS: 19 | USE_FPN: True 20 | ROI_BOX_HEAD: 21 | POOLER_RESOLUTION: 7 22 | POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) 23 | POOLER_SAMPLING_RATIO: 2 24 | FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" 25 | NUM_CLASSES: 220 26 | PREDICTOR: "FPNPredictor" 27 | ROI_MASK_HEAD: 28 | POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) 29 | FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor" 30 | PREDICTOR: "MaskRCNNC4Predictor" 31 | POOLER_RESOLUTION: 14 32 | POOLER_SAMPLING_RATIO: 2 33 | RESOLUTION: 28 34 | SHARE_BOX_FEATURE_EXTRACTOR: False 35 | MASK_ON: True 36 | DATASETS: 37 | TRAIN: ("coco_2017_train",) 38 | DATALOADER: 39 | SIZE_DIVISIBILITY: 32 40 | SOLVER: 41 | BASE_LR: 0.01 42 | WEIGHT_DECAY: 0.0001 43 | STEPS: (280000, 400000) 44 | MAX_ITER: 480000 45 | CHECKPOINT_PERIOD: 200 46 | IMS_PER_BATCH: 2 47 | OUTPUT_DIR: "my" 48 | -------------------------------------------------------------------------------- /create_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import better_exceptions 3 | from pathlib import Path 4 | from collections import defaultdict 5 | from itertools import chain 6 | import json 7 | import pandas as pd 8 | from tqdm import tqdm 9 | import cv2 10 | 11 | from util import get_hierarchy, find_contour 12 | 13 | 14 | def get_args(): 15 
| parser = argparse.ArgumentParser(description="This script creates coco format dataset for maskrcnn-benchmark", 16 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 17 | parser.add_argument("--layer", "-l", type=int, default=0, 18 | help="target layer; 0 or 1") 19 | parser.add_argument("--mode", "-m", type=str, default="train", 20 | help="target dataset; train or validation") 21 | parser.add_argument("--img_num", type=int, default=1500, 22 | help="max image num for each class") 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def calc_overlap_rate(rect1, rect2): 28 | x_left = max(rect1.x1, rect2.x1) 29 | y_top = max(rect1.y1, rect2.y1) 30 | x_right = min(rect1.x2, rect2.x2) 31 | y_bottom = min(rect1.y2, rect2.y2) 32 | intersection = max(0, x_right - x_left) * max(0, y_bottom - y_top) 33 | iou = intersection / rect1.area 34 | 35 | return iou 36 | 37 | 38 | class Rect: 39 | def __init__(self, x1, y1, x2, y2): 40 | self.x1 = x1 41 | self.y1 = y1 42 | self.x2 = x2 43 | self.y2 = y2 44 | self.cx = (x1 + x2) / 2 45 | self.cy = (y1 + y2) / 2 46 | self.area = (self.x2 - self.x1) * (self.y2 - self.y1) 47 | 48 | def is_inside(self, x, y): 49 | return self.x1 <= x <= self.x2 and self.y1 <= y <= self.y2 50 | 51 | 52 | def main(): 53 | args = get_args() 54 | layer = args.layer 55 | mode = args.mode 56 | images = [] 57 | annotations = [] 58 | layer0_class_strings, layer1_class_strings, class_string_to_parent = get_hierarchy() 59 | target_class_strings = get_hierarchy()[layer] 60 | target_class_string_to_class_id = {class_string: i + 1 for i, class_string in 61 | enumerate(sorted(target_class_strings))} 62 | parent_class_strings = list(set(class_string_to_parent.values())) 63 | layer0_independent_class_strings = [class_string for class_string in layer0_class_strings if 64 | class_string not in parent_class_strings] 65 | data_dir = Path(__file__).parent.joinpath("datasets") 66 | img_dir = data_dir.joinpath(f"{mode}") 67 | mask_dir = data_dir.joinpath(f"{mode}_masks") 68 | mask_csv_path = data_dir.joinpath(f"challenge-2019-{mode}-segmentation-masks.csv") 69 | df = pd.read_csv(str(mask_csv_path)) 70 | 71 | output_dir = data_dir.joinpath("coco") 72 | output_dir.mkdir(parents=True, exist_ok=True) 73 | output_annotation_dir = output_dir.joinpath("annotations") 74 | output_annotation_dir.mkdir(parents=True, exist_ok=True) 75 | output_img_dir = output_dir.joinpath("train2017") 76 | output_img_dir.mkdir(parents=True, exist_ok=True) 77 | 78 | class_string_to_img_ids = defaultdict(list) 79 | img_id_to_meta = defaultdict(list) 80 | 81 | print("=> parsing {}".format(mask_csv_path.name)) 82 | 83 | for i, row in tqdm(df.iterrows(), total=len(df)): 84 | mask_path, img_id, label_name, _, xp1, xp2, yp1, yp2, _, _ = row.values 85 | class_string_to_img_ids[label_name].append(img_id) 86 | img_id_to_meta[img_id].append({"mask_path": mask_path, "class_string": label_name, 87 | "bbox": [xp1, xp2, yp1, yp2]}) 88 | 89 | # use only args.img_num images for each class 90 | target_img_ids = list( 91 | set(chain.from_iterable( 92 | [class_string_to_img_ids[class_string][:args.img_num] for class_string in target_class_strings]))) 93 | print("=> use {} images for training".format(len(target_img_ids))) 94 | bbox_id = 0 95 | 96 | for i, img_id in enumerate(tqdm(target_img_ids)): 97 | added = False 98 | img_path = img_dir.joinpath(img_id + ".jpg") 99 | img = cv2.imread(str(img_path), 1) 100 | h, w, _ = img.shape 101 | target_rects = [] 102 | 103 | # collect target bboxes 104 | for m in img_id_to_meta[img_id]: 105 | 
class_string = m["class_string"] 106 | 107 | # non target 108 | if class_string not in target_class_strings: 109 | continue 110 | 111 | xp1, xp2, yp1, yp2 = m["bbox"] 112 | target_rects.append(Rect(xp1, yp1, xp2, yp2)) 113 | 114 | for m in img_id_to_meta[img_id]: 115 | class_string = m["class_string"] 116 | xp1, xp2, yp1, yp2 = m["bbox"] 117 | x1, x2, y1, y2 = xp1 * w, xp2 * w, yp1 * h, yp2 * h 118 | 119 | # for layer1: remove layer0 classes with no child class 120 | if layer == 1 and class_string in layer0_independent_class_strings: 121 | continue 122 | 123 | # for both layer0 and layer1: non-target object is removed if it occludes target bbox over 25% 124 | if class_string not in target_class_strings: 125 | curr_rect = Rect(xp1, yp1, xp2, yp2) 126 | overlap_rate = max([calc_overlap_rate(r, curr_rect) for r in target_rects]) 127 | 128 | if overlap_rate > 0.25: 129 | continue 130 | 131 | # layer0: convert layer1 and layer2 classes to their parent layer0 classes 132 | if layer == 0: 133 | if class_string in class_string_to_parent.keys(): 134 | class_string = class_string_to_parent[class_string] 135 | 136 | if class_string in class_string_to_parent.keys(): # needed for layer2 classes 137 | class_string = class_string_to_parent[class_string] 138 | 139 | if class_string in target_class_strings: 140 | mask_path = mask_dir.joinpath(m["mask_path"]) 141 | mask_img = cv2.imread(str(mask_path), 0) 142 | mask_img = cv2.resize(mask_img, (w, h), cv2.INTER_NEAREST) 143 | contour = find_contour(mask_img) 144 | contour = [p for p in contour if len(p) > 4] 145 | 146 | if not contour: 147 | continue 148 | 149 | class_id = target_class_string_to_class_id[class_string] 150 | gt_bbox = [x1, y1, x2 - x1, y2 - y1] 151 | box_w, box_h = x2 - x1, y2 - y1 152 | 153 | if box_w < 10 or box_h < 10: 154 | print(box_w, box_h) 155 | 156 | annotations.append({# "area": gt_h * gt_h, 157 | "segmentation": contour, 158 | "iscrowd": 0, 159 | "image_id": i, 160 | "bbox": gt_bbox, 161 | "category_id": class_id, 162 | "id": bbox_id}) 163 | bbox_id += 1 164 | added = True 165 | 166 | # for layer1: fill non-target bbox with gray; this class has its child class in layer1 167 | else: 168 | x1, x2, y1, y2 = int(x1), int(x2), int(y1), int(y2) 169 | img[y1:y2, x1:x2] = 128 170 | 171 | if not added: 172 | continue 173 | 174 | images.append({"file_name": img_path.name, 175 | "height": h, 176 | "width": w, 177 | "id": i}) 178 | 179 | output_img_path = output_img_dir.joinpath(img_path.name) 180 | cv2.imwrite(str(output_img_path), img) 181 | 182 | categories = [{"supercategory": "object", "id": class_id, "name": class_string} for class_string, class_id in 183 | target_class_string_to_class_id.items()] 184 | 185 | with output_annotation_dir.joinpath("instances_train2017.json").open("w") as f: 186 | json.dump({"images": images, "annotations": annotations, "categories": categories}, f) 187 | 188 | 189 | if __name__ == '__main__': 190 | main() 191 | -------------------------------------------------------------------------------- /create_submission.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import better_exceptions 3 | from pathlib import Path 4 | import json 5 | from tqdm import tqdm 6 | import numpy as np 7 | import pandas as pd 8 | import cv2 9 | from joblib import Parallel, delayed 10 | from util import encode_binary_mask, get_hierarchy 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser(description="This script creates submission file from maskrcnn results.", 
15 |                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
16 |     parser.add_argument("--th", type=int, default=128,
17 |                         help="threshold for mask")
18 |     parser.add_argument("--layer", "-l", type=int, default=0,
19 |                         help="target layer; 0 or 1")
20 |     args = parser.parse_args()
21 |     return args
22 | 
23 | 
24 | def create_result_per_img(json_path, results_dir, th, class_strings):
25 |     with json_path.open("r") as f:
26 |         d = json.load(f)
27 | 
28 |     new_w, new_h = d["new_img_size"]
29 |     boxes = d["boxes"]
30 |     scores = d["scores"]
31 |     labels = d["labels"]
32 |     img_id = json_path.stem
33 |     predicted_strings = []
34 | 
35 |     for i, (box, score, label) in enumerate(zip(boxes, scores, labels)):
36 |         mask_path = results_dir.joinpath("{}_{}.png".format(json_path.stem, i))
37 |         predicted_mask_img = cv2.imread(str(mask_path), 0)
38 |         mask_img = (predicted_mask_img > th)
39 |         mask_string = encode_binary_mask(mask_img)
40 |         class_string = class_strings[label - 1]
41 |         predicted_string = "{} {} {}".format(class_string, score, mask_string.decode())
42 |         predicted_strings.append(predicted_string)
43 | 
44 |     predicted_strings = " ".join(predicted_strings)
45 | 
46 |     return img_id, new_w, new_h, predicted_strings
47 | 
48 | 
49 | def main():
50 |     args = get_args()
51 |     th = args.th
52 |     layer = args.layer
53 |     data_dir = Path(__file__).parent.joinpath("datasets")
54 |     results_dir = data_dir.joinpath(f"test_{layer}_results")
55 |     layer0, layer1, class_string_to_parent = get_hierarchy()
56 |     class_strings = layer0 if layer == 0 else layer1
57 | 
58 |     r = Parallel(n_jobs=-1, verbose=10)(
59 |         [delayed(create_result_per_img)(json_path, results_dir, th, class_strings) for json_path in
60 |          results_dir.glob("*.json")])
61 | 
62 |     df = pd.DataFrame(data=r, columns=["ImageID", "ImageWidth", "ImageHeight", "PredictionString"])
63 |     df.dropna(inplace=True)
64 |     df["ImageWidth"] = df["ImageWidth"].astype(int)
65 |     df["ImageHeight"] = df["ImageHeight"].astype(int)
66 |     df.to_csv("output_{}.csv".format(layer), index=False)
67 | 
68 | 
69 | if __name__ == '__main__':
70 |     main()
71 | 
--------------------------------------------------------------------------------
/integrate_results.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import better_exceptions
3 | import pandas as pd
4 | 
5 | 
6 | def get_args():
7 |     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
8 |     parser.add_argument("--input1", type=str, required=True,
9 |                         help="input layer 0 submission file")
10 |     parser.add_argument("--input2", type=str, required=True,
11 |                         help="input layer 1 submission file")
12 |     args = parser.parse_args()
13 |     return args
14 | 
15 | 
16 | def main():
17 |     args = get_args()
18 |     df1 = pd.read_csv(args.input1, index_col="ImageID")
19 |     df1.sort_index(axis=0, inplace=True)
20 |     df2 = pd.read_csv(args.input2, index_col="ImageID")
21 |     df2.sort_index(axis=0, inplace=True)
22 |     results = []
23 | 
24 |     for i, ((index1, row1), (index2, row2)) in enumerate(zip(df1.iterrows(), df2.iterrows())):
25 |         assert(index1 == index2)
26 |         s1 = row1["PredictionString"]
27 |         s2 = row2["PredictionString"]
28 | 
29 |         if isinstance(s1, float) and isinstance(s2, float):
30 |             results.append("")
31 |         elif isinstance(s1, float):
32 |             results.append(s2)
33 |         elif isinstance(s2, float):
34 |             results.append(s1)
35 |         else:
36 |             results.append(" ".join([s1, s2]))
37 | 
38 |     df1.PredictionString = results
39 |     df1.to_csv("integrated_result.csv")
40 | 
41 | 
42 | if __name__ == 
'__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | better_exceptions 2 | tqdm 3 | numpy 4 | pandas 5 | torch 6 | opencv-python 7 | torchvision 8 | yacs 9 | scipy 10 | future 11 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import json 4 | import math 5 | from tqdm import tqdm 6 | import cv2 7 | import torch 8 | from torchvision import transforms as T 9 | from maskrcnn_benchmark.modeling.detector import build_detection_model 10 | from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer 11 | from maskrcnn_benchmark.structures.image_list import to_image_list 12 | from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker 13 | from maskrcnn_benchmark import layers as L 14 | from maskrcnn_benchmark.utils import cv2_util 15 | from maskrcnn_benchmark.config import cfg 16 | 17 | 18 | def get_args(): 19 | parser = argparse.ArgumentParser(description="This script detects objects and store results", 20 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 21 | parser.add_argument("--weight", type=str, required=True) 22 | parser.add_argument("--layer", "-l", type=int, default=0, 23 | help="target layer; 0 or 1") 24 | parser.add_argument("--mode", "-m", type=str, default="test", 25 | help="target dataset; train or validation") 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | class Predictor(object): 31 | # COCO categories for pretty print 32 | CATEGORIES = [ 33 | "__background", 34 | ] 35 | 36 | def __init__(self, cfg, confidence_threshold=0.5): 37 | self.cfg = cfg.clone() 38 | self.model = build_detection_model(cfg) 39 | self.model.eval() 40 | self.device = torch.device(cfg.MODEL.DEVICE) 41 | self.model.to(self.device) 42 | checkpointer = DetectronCheckpointer(cfg, self.model) 43 | _ = checkpointer.load(cfg.MODEL.WEIGHT) 44 | self.transforms = self.build_transform() 45 | self.cpu_device = torch.device("cpu") 46 | self.confidence_threshold = confidence_threshold 47 | 48 | show_mask_heatmaps = True 49 | mask_threshold = -1 if show_mask_heatmaps else 0.5 50 | self.masker = Masker(threshold=mask_threshold, padding=1) 51 | self.show_mask_heatmaps = show_mask_heatmaps 52 | 53 | self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1]) 54 | 55 | masks_per_dim = 2 56 | self.cpu_device = torch.device("cpu") 57 | self.confidence_threshold = confidence_threshold 58 | self.show_mask_heatmaps = show_mask_heatmaps 59 | self.masks_per_dim = masks_per_dim 60 | 61 | def build_transform(self): 62 | """ 63 | Creates a basic transformation that was used to train the models 64 | """ 65 | cfg = self.cfg 66 | 67 | # we are loading images with OpenCV, so we don't need to convert them 68 | # to BGR, they are already! So all we need to do is to normalize 69 | # by 255 if we want to convert to BGR255 format, or flip the channels 70 | # if we want it to be in RGB in [0-1] range. 
71 | if cfg.INPUT.TO_BGR255: 72 | to_bgr_transform = T.Lambda(lambda x: x * 255) 73 | else: 74 | to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]]) 75 | 76 | normalize_transform = T.Normalize( 77 | mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD 78 | ) 79 | 80 | transform = T.Compose( 81 | [ 82 | T.ToPILImage(), 83 | T.ToTensor(), 84 | to_bgr_transform, 85 | normalize_transform, 86 | ] 87 | ) 88 | return transform 89 | 90 | def run_on_opencv_image(self, image): 91 | """ 92 | Arguments: 93 | image (np.ndarray): an image as returned by OpenCV 94 | 95 | Returns: 96 | prediction (BoxList): the detected objects. Additional information 97 | of the detection properties can be found in the fields of 98 | the BoxList via `prediction.fields()` 99 | """ 100 | predictions = self.compute_prediction(image) 101 | top_predictions = self.select_top_predictions(predictions) 102 | 103 | # return top_predictions 104 | result = image.copy() 105 | if self.show_mask_heatmaps: 106 | return self.create_mask_montage(result, top_predictions) 107 | result = self.overlay_boxes(result, top_predictions) 108 | if self.cfg.MODEL.MASK_ON: 109 | result = self.overlay_mask(result, top_predictions) 110 | if self.cfg.MODEL.KEYPOINT_ON: 111 | result = self.overlay_keypoints(result, top_predictions) 112 | result = self.overlay_class_names(result, top_predictions) 113 | 114 | return result 115 | 116 | def compute_prediction(self, original_image): 117 | """ 118 | Arguments: 119 | original_image (np.ndarray): an image as returned by OpenCV 120 | 121 | Returns: 122 | prediction (BoxList): the detected objects. Additional information 123 | of the detection properties can be found in the fields of 124 | the BoxList via `prediction.fields()` 125 | """ 126 | # apply pre-processing to image 127 | image = self.transforms(original_image) 128 | # convert to an ImageList, padded so that it is divisible by 129 | # cfg.DATALOADER.SIZE_DIVISIBILITY 130 | image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY) 131 | image_list = image_list.to(self.device) 132 | # compute predictions 133 | with torch.no_grad(): 134 | predictions = self.model(image_list) 135 | predictions = [o.to(self.cpu_device) for o in predictions] 136 | 137 | # always single image is passed at a time 138 | prediction = predictions[0] 139 | 140 | # reshape prediction (a BoxList) into the original image size 141 | height, width = original_image.shape[:-1] 142 | prediction = prediction.resize((width, height)) 143 | 144 | if prediction.has_field("mask"): 145 | # if we have masks, paste the masks in the right position 146 | # in the image, as defined by the bounding boxes 147 | masks = prediction.get_field("mask") 148 | # always single image is passed at a time 149 | masks = self.masker([masks], [prediction])[0] 150 | prediction.add_field("mask", masks) 151 | return prediction 152 | 153 | def select_top_predictions(self, predictions): 154 | """ 155 | Select only predictions which have a `score` > self.confidence_threshold, 156 | and returns the predictions in descending order of score 157 | 158 | Arguments: 159 | predictions (BoxList): the result of the computation by the model. 160 | It should contain the field `scores`. 161 | 162 | Returns: 163 | prediction (BoxList): the detected objects. 
Additional information 164 | of the detection properties can be found in the fields of 165 | the BoxList via `prediction.fields()` 166 | """ 167 | scores = predictions.get_field("scores") 168 | keep = torch.nonzero(scores > self.confidence_threshold).squeeze(1) 169 | predictions = predictions[keep] 170 | scores = predictions.get_field("scores") 171 | _, idx = scores.sort(0, descending=True) 172 | return predictions[idx] 173 | 174 | def compute_colors_for_labels(self, labels): 175 | """ 176 | Simple function that adds fixed colors depending on the class 177 | """ 178 | colors = labels[:, None] * self.palette 179 | colors = (colors % 255).numpy().astype("uint8") 180 | return colors 181 | 182 | def overlay_mask(self, image, predictions): 183 | """ 184 | Adds the instances contours for each predicted object. 185 | Each label has a different color. 186 | 187 | Arguments: 188 | image (np.ndarray): an image as returned by OpenCV 189 | predictions (BoxList): the result of the computation by the model. 190 | It should contain the field `mask` and `labels`. 191 | """ 192 | masks = predictions.get_field("mask").numpy() 193 | labels = predictions.get_field("labels") 194 | 195 | colors = self.compute_colors_for_labels(labels).tolist() 196 | 197 | for mask, color in zip(masks, colors): 198 | thresh = mask[0, :, :, None] 199 | contours, hierarchy = cv2_util.findContours( 200 | thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 201 | ) 202 | image = cv2.drawContours(image, contours, -1, color, 3) 203 | 204 | composite = image 205 | 206 | return composite 207 | 208 | def overlay_boxes(self, image, predictions): 209 | """ 210 | Adds the predicted boxes on top of the image 211 | 212 | Arguments: 213 | image (np.ndarray): an image as returned by OpenCV 214 | predictions (BoxList): the result of the computation by the model. 215 | It should contain the field `labels`. 216 | """ 217 | labels = predictions.get_field("labels") 218 | boxes = predictions.bbox 219 | 220 | colors = self.compute_colors_for_labels(labels).tolist() 221 | 222 | for box, color in zip(boxes, colors): 223 | box = box.to(torch.int64) 224 | top_left, bottom_right = box[:2].tolist(), box[2:].tolist() 225 | image = cv2.rectangle( 226 | image, tuple(top_left), tuple(bottom_right), tuple(color), 1 227 | ) 228 | 229 | return image 230 | 231 | def overlay_class_names(self, image, predictions): 232 | """ 233 | Adds detected class names and scores in the positions defined by the 234 | top-left corner of the predicted bounding box 235 | 236 | Arguments: 237 | image (np.ndarray): an image as returned by OpenCV 238 | predictions (BoxList): the result of the computation by the model. 239 | It should contain the field `scores` and `labels`. 
240 | """ 241 | scores = predictions.get_field("scores").tolist() 242 | labels = predictions.get_field("labels").tolist() 243 | labels = [self.CATEGORIES[i] for i in labels] 244 | # labels = [self.CATEGORIES[1] for i in labels] 245 | boxes = predictions.bbox 246 | 247 | template = "{}: {:.2f}" 248 | for box, score, label in zip(boxes, scores, labels): 249 | x, y = box[:2] 250 | s = template.format(label, score) 251 | cv2.putText( 252 | image, s, (x, y), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1 253 | ) 254 | 255 | return image 256 | 257 | 258 | def main(): 259 | args = get_args() 260 | weight_path = args.weight 261 | mode = args.mode 262 | layer = args.layer 263 | 264 | cfg_path = str(Path(__file__).parent.joinpath("config", "e2e_mask_rcnn_X_101_32x8d_FPN_1x.yaml")) 265 | cfg.merge_from_file(cfg_path) 266 | device = "cuda" if torch.cuda.is_available() else "cpu" 267 | cfg.merge_from_list(["MODEL.DEVICE", device, "MODEL.WEIGHT", weight_path]) 268 | predictor = Predictor(cfg, confidence_threshold=0.1) 269 | data_dir = Path(__file__).parent.joinpath("datasets") 270 | img_dir = data_dir.joinpath(f"{mode}") 271 | output_dir = data_dir.joinpath(f"{mode}_{layer}_results") 272 | output_dir.mkdir(parents=True, exist_ok=True) 273 | 274 | for img_path in tqdm(list(img_dir.glob("*.jpg"))): 275 | img = cv2.imread(str(img_path), 1) 276 | h, w, _ = img.shape 277 | 278 | if h > w and h > 1024: 279 | new_h = 1024 280 | new_w = int(new_h * w / h) 281 | img = cv2.resize(img, (new_w, new_h)) 282 | 283 | if w > h and w > 1024: 284 | new_w = 1024 285 | new_h = int(new_w * h / w) 286 | img = cv2.resize(img, (new_w, new_h)) 287 | 288 | new_h, new_w, _ = img.shape 289 | predictions = predictor.compute_prediction(img) 290 | predictions = predictor.select_top_predictions(predictions) 291 | boxes = predictions.bbox.tolist() 292 | scores = predictions.get_field("scores").tolist() 293 | labels = predictions.get_field("labels").tolist() 294 | masks = predictions.get_field("mask").numpy() 295 | json_path = output_dir.joinpath(img_path.stem + ".json") 296 | 297 | with json_path.open("w") as f: 298 | json.dump({"boxes": boxes, "scores": scores, "labels": labels, "original_img_size": [w, h], 299 | "new_img_size": [new_w, new_h]}, f) 300 | 301 | for i, mask in enumerate(masks): 302 | output_mask_path = output_dir.joinpath("{}_{}.png".format(json_path.stem, i)) 303 | cv2.imwrite(str(output_mask_path), mask[0]) 304 | 305 | 306 | if __name__ == '__main__': 307 | main() 308 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | r""" 3 | Basic training script for PyTorch 4 | """ 5 | 6 | # Set up custom environment before nearly anything else is imported 7 | # NOTE: this should be the first import (no not reorder) 8 | from maskrcnn_benchmark.utils.env import setup_environment # noqa F401 isort:skip 9 | 10 | import argparse 11 | import os 12 | from pathlib import Path 13 | import torch 14 | from maskrcnn_benchmark.config import cfg 15 | from maskrcnn_benchmark.data import make_data_loader 16 | from maskrcnn_benchmark.solver import make_lr_scheduler 17 | from maskrcnn_benchmark.solver import make_optimizer 18 | from maskrcnn_benchmark.engine.inference import inference 19 | from maskrcnn_benchmark.engine.trainer import do_train 20 | from maskrcnn_benchmark.modeling.detector import build_detection_model 21 | from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer 22 | from maskrcnn_benchmark.utils.collect_env import collect_env_info 23 | from maskrcnn_benchmark.utils.comm import synchronize, get_rank 24 | from maskrcnn_benchmark.utils.imports import import_file 25 | from maskrcnn_benchmark.utils.logger import setup_logger 26 | from maskrcnn_benchmark.utils.miscellaneous import mkdir 27 | 28 | try: 29 | from apex import amp 30 | except ImportError: 31 | raise ImportError('Use APEX for multi-precision via apex.amp') 32 | 33 | 34 | def train(cfg, local_rank, distributed): 35 | model = build_detection_model(cfg) 36 | device = torch.device(cfg.MODEL.DEVICE) 37 | model.to(device) 38 | 39 | optimizer = make_optimizer(cfg, model) 40 | scheduler = make_lr_scheduler(cfg, optimizer) 41 | 42 | # Initialize mixed-precision training 43 | use_mixed_precision = cfg.DTYPE == "float16" 44 | amp_opt_level = 'O1' if use_mixed_precision else 'O0' 45 | model, optimizer = amp.initialize(model, optimizer, opt_level=amp_opt_level) 46 | 47 | if distributed: 48 | model = torch.nn.parallel.DistributedDataParallel( 49 | model, device_ids=[local_rank], output_device=local_rank, 50 | # this should be removed if we update BatchNorm stats 51 | broadcast_buffers=False, 52 | ) 53 | 54 | arguments = {} 55 | arguments["iteration"] = 0 56 | 57 | output_dir = cfg.OUTPUT_DIR 58 | 59 | save_to_disk = get_rank() == 0 60 | checkpointer = DetectronCheckpointer( 61 | cfg, model, optimizer, scheduler, output_dir, save_to_disk 62 | ) 63 | extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT) 64 | arguments.update(extra_checkpoint_data) 65 | 66 | data_loader = make_data_loader( 67 | cfg, 68 | is_train=True, 69 | is_distributed=distributed, 70 | start_iter=arguments["iteration"], 71 | ) 72 | 73 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 74 | 75 | do_train( 76 | cfg, 77 | model, 78 | data_loader, 79 | None, 80 | optimizer, 81 | scheduler, 82 | checkpointer, 83 | device, 84 | checkpoint_period, 85 | None, 86 | arguments, 87 | ) 88 | 89 | return model 90 | 91 | 92 | def run_test(cfg, model, distributed): 93 | if distributed: 94 | model = model.module 95 | torch.cuda.empty_cache() # TODO check if it helps 96 | iou_types = ("bbox",) 97 | if cfg.MODEL.MASK_ON: 98 | iou_types = iou_types + ("segm",) 99 | if cfg.MODEL.KEYPOINT_ON: 100 | iou_types = iou_types + ("keypoints",) 101 | output_folders = [None] * len(cfg.DATASETS.TEST) 102 | dataset_names = cfg.DATASETS.TEST 103 | if cfg.OUTPUT_DIR: 104 | for idx, dataset_name in enumerate(dataset_names): 105 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) 106 | mkdir(output_folder) 107 | output_folders[idx] = output_folder 108 | data_loaders_val = make_data_loader(cfg, 
is_train=False, is_distributed=distributed) 109 | for output_folder, dataset_name, data_loader_val in zip(output_folders, dataset_names, data_loaders_val): 110 | inference( 111 | model, 112 | data_loader_val, 113 | dataset_name=dataset_name, 114 | iou_types=iou_types, 115 | box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY, 116 | device=cfg.MODEL.DEVICE, 117 | expected_results=cfg.TEST.EXPECTED_RESULTS, 118 | expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL, 119 | output_folder=output_folder, 120 | ) 121 | synchronize() 122 | 123 | 124 | def main(): 125 | parser = argparse.ArgumentParser(description="PyTorch Object Detection Training") 126 | parser.add_argument( 127 | "--config-file", 128 | default=None, 129 | metavar="FILE", 130 | help="path to config file", 131 | type=str, 132 | ) 133 | parser.add_argument("--local_rank", type=int, default=0) 134 | parser.add_argument( 135 | "--skip-test", 136 | dest="skip_test", 137 | help="Do not test the final model", 138 | action="store_true", 139 | ) 140 | parser.add_argument( 141 | "opts", 142 | help="Modify config options using the command-line", 143 | default=None, 144 | nargs=argparse.REMAINDER, 145 | ) 146 | 147 | args = parser.parse_args() 148 | 149 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 150 | args.distributed = num_gpus > 1 151 | 152 | if args.distributed: 153 | torch.cuda.set_device(args.local_rank) 154 | torch.distributed.init_process_group( 155 | backend="nccl", init_method="env://" 156 | ) 157 | synchronize() 158 | 159 | if args.config_file is None: 160 | args.config_file = str(Path(__file__).parent.joinpath("config", "e2e_mask_rcnn_X_101_32x8d_FPN_1x.yaml")) 161 | 162 | cfg.merge_from_file(args.config_file) 163 | cfg.merge_from_list(args.opts) 164 | cfg.freeze() 165 | 166 | output_dir = cfg.OUTPUT_DIR 167 | if output_dir: 168 | mkdir(output_dir) 169 | 170 | logger = setup_logger("maskrcnn_benchmark", output_dir, get_rank()) 171 | logger.info("Using {} GPUs".format(num_gpus)) 172 | logger.info(args) 173 | 174 | logger.info("Collecting env info (might take some time)") 175 | logger.info("\n" + collect_env_info()) 176 | 177 | logger.info("Loaded configuration file {}".format(args.config_file)) 178 | with open(args.config_file, "r") as cf: 179 | config_str = "\n" + cf.read() 180 | logger.info(config_str) 181 | logger.info("Running with config:\n{}".format(cfg)) 182 | 183 | model = train(cfg, args.local_rank, args.distributed) 184 | 185 | if not args.skip_test: 186 | run_test(cfg, model, args.distributed) 187 | 188 | 189 | if __name__ == "__main__": 190 | main() 191 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import better_exceptions 3 | from pathlib import Path 4 | import json 5 | from collections import defaultdict 6 | import pandas as pd 7 | import cv2 8 | import base64 9 | import numpy as np 10 | from pycocotools import _mask as coco_mask 11 | import typing as t 12 | import zlib 13 | 14 | 15 | def get_class_mapping(): 16 | csv_path = Path(__file__).parent.joinpath("datasets", "classes-segmentation.txt") 17 | df = pd.read_csv(str(csv_path), header=None, names=["class_string"]) 18 | class_string_to_class_id = dict(zip(df.class_string, df.index)) 19 | class_id_to_class_string = dict(zip(df.index, df.class_string)) 20 | return class_string_to_class_id, class_id_to_class_string 21 | 22 | 23 | def get_string_to_name(): 24 | 
csv_path = Path(__file__).parent.joinpath("datasets", "challenge-2019-classes-description-segmentable.csv") 25 | df = pd.read_csv(str(csv_path), header=None, names=["class_string", "class_name"]) 26 | class_string_to_class_name = dict(zip(df.class_string, df.class_name)) 27 | return class_string_to_class_name 28 | 29 | 30 | def get_layer_names(): 31 | layer0, layer1, class_string_to_parent = get_hierarchy() 32 | class_string_to_class_name = get_string_to_name() 33 | layer0_names = [class_string_to_class_name[s] for s in sorted(layer0)] 34 | layer1_names = [class_string_to_class_name[s] for s in sorted(layer1)] 35 | return layer0_names, layer1_names 36 | 37 | 38 | def get_hierarchy(): 39 | json_path = Path(__file__).parent.joinpath("datasets", "challenge-2019-label300-segmentable-hierarchy.json") 40 | 41 | with json_path.open("r") as f: 42 | d = json.load(f) 43 | 44 | level_to_class_strings = defaultdict(list) 45 | class_string_to_parent = {} 46 | 47 | def register(c, level): 48 | class_string = c["LabelName"] 49 | level_to_class_strings[level].append(class_string) 50 | 51 | if "Subcategory" in c.keys(): 52 | for sub_c in c["Subcategory"]: 53 | class_string_to_parent[sub_c["LabelName"]] = class_string 54 | register(sub_c, level + 1) 55 | 56 | for c in d["Subcategory"]: 57 | register(c, 0) 58 | 59 | class_string_to_parent["/m/0kmg4"] = "/m/0138tl" # teddy bear is toy not bear... 60 | layer0 = level_to_class_strings[0] 61 | layer1 = list(set(level_to_class_strings[1] + level_to_class_strings[2])) 62 | return sorted(layer0), sorted(layer1), class_string_to_parent 63 | 64 | 65 | def find_contour(mask): 66 | mask = cv2.UMat(mask) 67 | contour, hierarchy = cv2.findContours( 68 | mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_L1 69 | ) 70 | 71 | reshaped_contour = [] 72 | for entity in contour: 73 | assert len(entity.shape) == 3 74 | assert entity.shape[1] == 1, \ 75 | 'Hierarchical contours are not allowed' 76 | reshaped_contour.append(entity.reshape(-1).tolist()) 77 | return reshaped_contour 78 | 79 | 80 | def encode_binary_mask(mask: np.ndarray) -> t.Text: 81 | """Converts a binary mask into OID challenge encoding ascii text.""" 82 | 83 | # check input mask -- 84 | if mask.dtype != np.bool: 85 | raise ValueError( 86 | "encode_binary_mask expects a binary mask, received dtype == %s" % 87 | mask.dtype) 88 | 89 | mask = np.squeeze(mask) 90 | if len(mask.shape) != 2: 91 | raise ValueError( 92 | "encode_binary_mask expects a 2d mask, received shape == %s" % 93 | mask.shape) 94 | 95 | # convert input mask to expected COCO API input -- 96 | mask_to_encode = mask.reshape(mask.shape[0], mask.shape[1], 1) 97 | mask_to_encode = mask_to_encode.astype(np.uint8) 98 | mask_to_encode = np.asfortranarray(mask_to_encode) 99 | 100 | # RLE encode mask -- 101 | encoded_mask = coco_mask.encode(mask_to_encode)[0]["counts"] 102 | 103 | # compress and base64 encoding -- 104 | binary_str = zlib.compress(encoded_mask, zlib.Z_BEST_COMPRESSION) 105 | base64_str = base64.b64encode(binary_str) 106 | return base64_str 107 | 108 | 109 | def main(): 110 | get_hierarchy() 111 | print(get_layer_names()[1]) 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | --------------------------------------------------------------------------------