├── .gitignore ├── README.md ├── coco_eval.py ├── coco_utils.py ├── config ├── train_real_data.yaml └── train_syn_data.yaml ├── demo_stream.py ├── engine.py ├── inference.py ├── inference_video.py ├── loader ├── __init__.py ├── real_loader.py ├── synthetic_loader.py ├── transforms.py └── unimib.py ├── models ├── __init__.py └── maskrcnn.py ├── resources ├── demo_unseen_food_segmentation.gif └── example.png ├── test_multi_models.py ├── test_single_model.py ├── train.py ├── utils ├── __init__.py └── logger.py ├── vis_UNIMIB.py ├── vis_by_loader.py └── vis_by_loader_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | *.pth 3 | .vscode/ 4 | results/ 5 | data_sample/ 6 | demo/ 7 | tmp/ 8 | tmp_test.py 9 | *.pptx 10 | script.sh 11 | example_*.yaml 12 | ckp/ 13 | 14 | copy_data_sample.py 15 | count_data.py 16 | inspect_data.py 17 | 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | cover/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | .pybuilder/ 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | # For a library or package, you might want to ignore these files since the code is 105 | # intended to run in multiple environments; otherwise, check them in: 106 | # .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Food Instance Segmentation
2 |
3 | This is the implementation of the paper published as
4 |
5 | > D. Park, J. Lee, J. Lee and K. Lee. **Deep Learning based Food Instance Segmentation using Synthetic Data.** 2021 Ubiquitous Robots (UR), available online at
6 | [arXiv](https://arxiv.org/abs/2107.07191) or [IEEE](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9494704).
7 |
8 | ![Demo](./resources/demo_unseen_food_segmentation.gif)
9 |
10 |
11 | # Usage
12 | - Training
13 | ```Shell
14 | # Training from scratch
15 | python train.py --config {PATH/TO/CONFIG}
16 |
17 | # Resuming training
18 | python train.py \
19 |     --config {PATH/TO/CONFIG} \
20 |     --resume --resume_ckp PATH/TO/TRAINED/WEIGHT
21 | ```
22 |
23 | - Evaluation
24 | ```Shell
25 | # Evaluate a single weight
26 | python test_single_model.py \
27 |     --config {PATH/TO/CONFIG} \
28 |     --trained_ckp PATH/TO/TRAINED/WEIGHT
29 |
30 | # Evaluate multiple weights saved in the same directory
31 | python test_multi_models.py \
32 |     --config {PATH/TO/CONFIG} \
33 |     --ckp_dir PATH/TO/TRAINED/DIRECTORY
34 | ```
35 |
36 | - Inference demo
37 | ```Shell
38 | # Predict and draw food instances from input images
39 | python inference.py \
40 |     --input_dir PATH/TO/IMAGES/FOR/INFERENCE \
41 |     --output_dir PATH/TO/SAVE/VISUALIZED/IMAGES \
42 |     --ckp_path PATH/TO/TRAINED/WEIGHT
43 | ```
44 | We also provide Mask R-CNN weights pre-trained on our synthetic dataset. [Download](https://drive.google.com/file/d/1JaZKzYiOZ29HZhV6wcbzHCWcJbZK_r35/view?usp=sharing)
45 |
46 | # Dataset
47 | Synthetic dataset generated by Blender.
[Download](https://drive.google.com/file/d/1YKHi8RhjgH-LcoVjI5yIwkbfRNtgSVdE/view?usp=sharing) 48 | ![Data](./resources/example.png) 49 | -------------------------------------------------------------------------------- /coco_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | 4 | import numpy as np 5 | import copy 6 | import time 7 | import torch 8 | import torch._six 9 | 10 | from pycocotools.cocoeval import COCOeval 11 | from pycocotools.coco import COCO 12 | import pycocotools.mask as mask_util 13 | 14 | from collections import defaultdict 15 | 16 | import utils 17 | 18 | 19 | class CocoEvaluator(object): 20 | def __init__(self, coco_gt, iou_types): 21 | assert isinstance(iou_types, (list, tuple)) 22 | coco_gt = copy.deepcopy(coco_gt) 23 | self.coco_gt = coco_gt 24 | 25 | self.iou_types = iou_types 26 | self.coco_eval = {} 27 | for iou_type in iou_types: 28 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 29 | 30 | self.img_ids = [] 31 | self.eval_imgs = {k: [] for k in iou_types} 32 | 33 | def update(self, predictions): 34 | img_ids = list(np.unique(list(predictions.keys()))) 35 | self.img_ids.extend(img_ids) 36 | 37 | for iou_type in self.iou_types: 38 | results = self.prepare(predictions, iou_type) 39 | coco_dt = loadRes(self.coco_gt, results) if results else COCO() 40 | coco_eval = self.coco_eval[iou_type] 41 | 42 | coco_eval.cocoDt = coco_dt 43 | coco_eval.params.imgIds = list(img_ids) 44 | img_ids, eval_imgs = evaluate(coco_eval) 45 | 46 | self.eval_imgs[iou_type].append(eval_imgs) 47 | 48 | def synchronize_between_processes(self): 49 | for iou_type in self.iou_types: 50 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 51 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 52 | 53 | def accumulate(self): 54 | for coco_eval in self.coco_eval.values(): 55 | coco_eval.accumulate() 56 | 57 | def summarize(self): 58 | for iou_type, coco_eval in self.coco_eval.items(): 59 | print("IoU metric: {}".format(iou_type)) 60 | coco_eval.summarize() 61 | 62 | def prepare(self, predictions, iou_type): 63 | if iou_type == "bbox": 64 | return self.prepare_for_coco_detection(predictions) 65 | elif iou_type == "segm": 66 | return self.prepare_for_coco_segmentation(predictions) 67 | elif iou_type == "keypoints": 68 | return self.prepare_for_coco_keypoint(predictions) 69 | else: 70 | raise ValueError("Unknown iou type {}".format(iou_type)) 71 | 72 | def prepare_for_coco_detection(self, predictions): 73 | coco_results = [] 74 | for original_id, prediction in predictions.items(): 75 | if len(prediction) == 0: 76 | continue 77 | 78 | boxes = prediction["boxes"] 79 | boxes = convert_to_xywh(boxes).tolist() 80 | scores = prediction["scores"].tolist() 81 | labels = prediction["labels"].tolist() 82 | 83 | coco_results.extend( 84 | [ 85 | { 86 | "image_id": original_id, 87 | "category_id": labels[k], 88 | "bbox": box, 89 | "score": scores[k], 90 | } 91 | for k, box in enumerate(boxes) 92 | ] 93 | ) 94 | return coco_results 95 | 96 | def prepare_for_coco_segmentation(self, predictions): 97 | coco_results = [] 98 | for original_id, prediction in predictions.items(): 99 | if len(prediction) == 0: 100 | continue 101 | 102 | scores = prediction["scores"] 103 | labels = prediction["labels"] 104 | masks = prediction["masks"] 105 | 106 | masks = masks > 0.5 107 | 108 | scores = prediction["scores"].tolist() 109 | labels = prediction["labels"].tolist() 110 | 111 | 
rles = [ 112 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 113 | for mask in masks 114 | ] 115 | for rle in rles: 116 | rle["counts"] = rle["counts"].decode("utf-8") 117 | 118 | coco_results.extend( 119 | [ 120 | { 121 | "image_id": original_id, 122 | "category_id": labels[k], 123 | "segmentation": rle, 124 | "score": scores[k], 125 | } 126 | for k, rle in enumerate(rles) 127 | ] 128 | ) 129 | return coco_results 130 | 131 | def prepare_for_coco_keypoint(self, predictions): 132 | coco_results = [] 133 | for original_id, prediction in predictions.items(): 134 | if len(prediction) == 0: 135 | continue 136 | 137 | boxes = prediction["boxes"] 138 | boxes = convert_to_xywh(boxes).tolist() 139 | scores = prediction["scores"].tolist() 140 | labels = prediction["labels"].tolist() 141 | keypoints = prediction["keypoints"] 142 | keypoints = keypoints.flatten(start_dim=1).tolist() 143 | 144 | coco_results.extend( 145 | [ 146 | { 147 | "image_id": original_id, 148 | "category_id": labels[k], 149 | 'keypoints': keypoint, 150 | "score": scores[k], 151 | } 152 | for k, keypoint in enumerate(keypoints) 153 | ] 154 | ) 155 | return coco_results 156 | 157 | 158 | def convert_to_xywh(boxes): 159 | xmin, ymin, xmax, ymax = boxes.unbind(1) 160 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 161 | 162 | 163 | def merge(img_ids, eval_imgs): 164 | all_img_ids = utils.all_gather(img_ids) 165 | all_eval_imgs = utils.all_gather(eval_imgs) 166 | 167 | merged_img_ids = [] 168 | for p in all_img_ids: 169 | merged_img_ids.extend(p) 170 | 171 | merged_eval_imgs = [] 172 | for p in all_eval_imgs: 173 | merged_eval_imgs.append(p) 174 | 175 | merged_img_ids = np.array(merged_img_ids) 176 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 177 | 178 | # keep only unique (and in sorted order) images 179 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 180 | merged_eval_imgs = merged_eval_imgs[..., idx] 181 | 182 | return merged_img_ids, merged_eval_imgs 183 | 184 | 185 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 186 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 187 | img_ids = list(img_ids) 188 | eval_imgs = list(eval_imgs.flatten()) 189 | 190 | coco_eval.evalImgs = eval_imgs 191 | coco_eval.params.imgIds = img_ids 192 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 193 | 194 | 195 | ################################################################# 196 | # From pycocotools, just removed the prints and fixed 197 | # a Python3 bug about unicode not defined 198 | ################################################################# 199 | 200 | # Ideally, pycocotools wouldn't have hard-coded prints 201 | # so that we could avoid copy-pasting those two functions 202 | 203 | def createIndex(self): 204 | # create index 205 | # print('creating index...') 206 | anns, cats, imgs = {}, {}, {} 207 | imgToAnns, catToImgs = defaultdict(list), defaultdict(list) 208 | if 'annotations' in self.dataset: 209 | for ann in self.dataset['annotations']: 210 | imgToAnns[ann['image_id']].append(ann) 211 | anns[ann['id']] = ann 212 | 213 | if 'images' in self.dataset: 214 | for img in self.dataset['images']: 215 | imgs[img['id']] = img 216 | 217 | if 'categories' in self.dataset: 218 | for cat in self.dataset['categories']: 219 | cats[cat['id']] = cat 220 | 221 | if 'annotations' in self.dataset and 'categories' in self.dataset: 222 | for ann in self.dataset['annotations']: 223 | catToImgs[ann['category_id']].append(ann['image_id']) 224 | 
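        # The populated lookup tables now give direct access:
        #   anns / imgs / cats   map an id to its annotation / image / category record,
        #   imgToAnns            maps an image id to all annotations on that image,
        #   catToImgs            maps a category id to the image ids containing it.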
225 | # print('index created!') 226 | 227 | # create class members 228 | self.anns = anns 229 | self.imgToAnns = imgToAnns 230 | self.catToImgs = catToImgs 231 | self.imgs = imgs 232 | self.cats = cats 233 | 234 | 235 | maskUtils = mask_util 236 | 237 | 238 | def loadRes(self, resFile): 239 | """ 240 | Load result file and return a result api object. 241 | :param resFile (str) : file name of result file 242 | :return: res (obj) : result api object 243 | """ 244 | res = COCO() 245 | res.dataset['images'] = [img for img in self.dataset['images']] 246 | 247 | # print('Loading and preparing results...') 248 | # tic = time.time() 249 | if isinstance(resFile, torch._six.string_classes): 250 | anns = json.load(open(resFile)) 251 | elif type(resFile) == np.ndarray: 252 | anns = self.loadNumpyAnnotations(resFile) 253 | else: 254 | anns = resFile 255 | assert type(anns) == list, 'results in not an array of objects' 256 | annsImgIds = [ann['image_id'] for ann in anns] 257 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 258 | 'Results do not correspond to current coco set' 259 | if 'caption' in anns[0]: 260 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) 261 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] 262 | for id, ann in enumerate(anns): 263 | ann['id'] = id + 1 264 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: 265 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 266 | for id, ann in enumerate(anns): 267 | bb = ann['bbox'] 268 | x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] 269 | if 'segmentation' not in ann: 270 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 271 | ann['area'] = bb[2] * bb[3] 272 | ann['id'] = id + 1 273 | ann['iscrowd'] = 0 274 | elif 'segmentation' in anns[0]: 275 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 276 | for id, ann in enumerate(anns): 277 | # now only support compressed RLE format as segmentation results 278 | ann['area'] = maskUtils.area(ann['segmentation']) 279 | if 'bbox' not in ann: 280 | ann['bbox'] = maskUtils.toBbox(ann['segmentation']) 281 | ann['id'] = id + 1 282 | ann['iscrowd'] = 0 283 | elif 'keypoints' in anns[0]: 284 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 285 | for id, ann in enumerate(anns): 286 | s = ann['keypoints'] 287 | x = s[0::3] 288 | y = s[1::3] 289 | x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y) 290 | ann['area'] = (x2 - x1) * (y2 - y1) 291 | ann['id'] = id + 1 292 | ann['bbox'] = [x1, y1, x2 - x1, y2 - y1] 293 | # print('DONE (t={:0.2f}s)'.format(time.time()- tic)) 294 | 295 | res.dataset['annotations'] = anns 296 | createIndex(res) 297 | return res 298 | 299 | 300 | def evaluate(self): 301 | ''' 302 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 303 | :return: None 304 | ''' 305 | # tic = time.time() 306 | # print('Running per image evaluation...') 307 | p = self.params 308 | # add backward compatibility if useSegm is specified in params 309 | if p.useSegm is not None: 310 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 311 | print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) 312 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 313 | p.imgIds = list(np.unique(p.imgIds)) 314 | if p.useCats: 315 | p.catIds = list(np.unique(p.catIds)) 316 | p.maxDets = sorted(p.maxDets) 317 | self.params = p 318 | 319 | self._prepare() 320 | # loop through images, area range, max detection number 321 | catIds = p.catIds if p.useCats else [-1] 322 | 323 | if p.iouType == 'segm' or p.iouType == 'bbox': 324 | computeIoU = self.computeIoU 325 | elif p.iouType == 'keypoints': 326 | computeIoU = self.computeOks 327 | self.ious = { 328 | (imgId, catId): computeIoU(imgId, catId) 329 | for imgId in p.imgIds 330 | for catId in catIds} 331 | 332 | evaluateImg = self.evaluateImg 333 | maxDet = p.maxDets[-1] 334 | evalImgs = [ 335 | evaluateImg(imgId, catId, areaRng, maxDet) 336 | for catId in catIds 337 | for areaRng in p.areaRng 338 | for imgId in p.imgIds 339 | ] 340 | # this is NOT in the pycocotools code, but could be done outside 341 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 342 | self._paramsEval = copy.deepcopy(self.params) 343 | # toc = time.time() 344 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 345 | return p.imgIds, evalImgs 346 | 347 | ################################################################# 348 | # end of straight copy from pycocotools, just removing the prints 349 | ################################################################# 350 | -------------------------------------------------------------------------------- /coco_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import torch.utils.data 10 | import torchvision 11 | 12 | from pycocotools import mask as coco_mask 13 | from pycocotools.coco import COCO 14 | 15 | import loader.transforms as T 16 | 17 | class FilterAndRemapCocoCategories(object): 18 | def __init__(self, categories, remap=True): 19 | self.categories = categories 20 | self.remap = remap 21 | 22 | def __call__(self, image, target): 23 | anno = target["annotations"] 24 | anno = [obj for obj in anno if obj["category_id"] in self.categories] 25 | if not self.remap: 26 | target["annotations"] = anno 27 | return image, target 28 | anno = copy.deepcopy(anno) 29 | for obj in anno: 30 | obj["category_id"] = self.categories.index(obj["category_id"]) 31 | target["annotations"] = anno 32 | return image, target 33 | 34 | 35 | def convert_coco_poly_to_mask(segmentations, height, width): 36 | masks = [] 37 | for polygons in segmentations: 38 | rles = coco_mask.frPyObjects(polygons, height, width) 39 | mask = coco_mask.decode(rles) 40 | if len(mask.shape) < 3: 41 | mask = mask[..., None] 42 | mask = torch.as_tensor(mask, dtype=torch.uint8) 43 | mask = mask.any(dim=2) 44 | masks.append(mask) 45 | if masks: 46 | masks = torch.stack(masks, dim=0) 47 | else: 48 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 49 | return masks 50 | 51 | 52 | class ConvertCocoPolysToMask(object): 53 | def __call__(self, image, target): 54 | w, h = image.size 55 | 56 | image_id = target["image_id"] 57 | image_id = torch.tensor([image_id]) 58 | 59 | anno = target["annotations"] 60 | 61 | anno = [obj for obj in anno if obj['iscrowd'] == 0] 62 | 63 | boxes = [obj["bbox"] for obj in anno] 64 | # guard against no boxes via resizing 65 | boxes = torch.as_tensor(boxes, 
dtype=torch.float32).reshape(-1, 4) 66 | boxes[:, 2:] += boxes[:, :2] 67 | boxes[:, 0::2].clamp_(min=0, max=w) 68 | boxes[:, 1::2].clamp_(min=0, max=h) 69 | 70 | classes = [obj["category_id"] for obj in anno] 71 | classes = torch.tensor(classes, dtype=torch.int64) 72 | 73 | segmentations = [obj["segmentation"] for obj in anno] 74 | masks = convert_coco_poly_to_mask(segmentations, h, w) 75 | 76 | keypoints = None 77 | if anno and "keypoints" in anno[0]: 78 | keypoints = [obj["keypoints"] for obj in anno] 79 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 80 | num_keypoints = keypoints.shape[0] 81 | if num_keypoints: 82 | keypoints = keypoints.view(num_keypoints, -1, 3) 83 | 84 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 85 | boxes = boxes[keep] 86 | classes = classes[keep] 87 | masks = masks[keep] 88 | if keypoints is not None: 89 | keypoints = keypoints[keep] 90 | 91 | target = {} 92 | target["boxes"] = boxes 93 | target["labels"] = classes 94 | target["masks"] = masks 95 | target["image_id"] = image_id 96 | if keypoints is not None: 97 | target["keypoints"] = keypoints 98 | 99 | # for conversion to coco api 100 | area = torch.tensor([obj["area"] for obj in anno]) 101 | iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) 102 | target["area"] = area 103 | target["iscrowd"] = iscrowd 104 | 105 | return image, target 106 | 107 | 108 | def _coco_remove_images_without_annotations(dataset, cat_list=None): 109 | def _has_only_empty_bbox(anno): 110 | return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) 111 | 112 | def _count_visible_keypoints(anno): 113 | return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) 114 | 115 | min_keypoints_per_image = 10 116 | 117 | def _has_valid_annotation(anno): 118 | # if it's empty, there is no annotation 119 | if len(anno) == 0: 120 | return False 121 | # if all boxes have close to zero area, there is no annotation 122 | if _has_only_empty_bbox(anno): 123 | return False 124 | # keypoints task have a slight different critera for considering 125 | # if an annotation is valid 126 | if "keypoints" not in anno[0]: 127 | return True 128 | # for keypoint detection tasks, only consider valid images those 129 | # containing at least min_keypoints_per_image 130 | if _count_visible_keypoints(anno) >= min_keypoints_per_image: 131 | return True 132 | return False 133 | 134 | assert isinstance(dataset, torchvision.datasets.CocoDetection) 135 | ids = [] 136 | for ds_idx, img_id in enumerate(dataset.ids): 137 | ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) 138 | anno = dataset.coco.loadAnns(ann_ids) 139 | if cat_list: 140 | anno = [obj for obj in anno if obj["category_id"] in cat_list] 141 | if _has_valid_annotation(anno): 142 | ids.append(ds_idx) 143 | 144 | dataset = torch.utils.data.Subset(dataset, ids) 145 | return dataset 146 | 147 | 148 | def convert_to_coco_api(ds): 149 | coco_ds = COCO() 150 | 151 | # annotation IDs need to start at 1, not 0, see torchvision issue #1530 152 | ann_id = 1 153 | dataset = {'images': [], 'categories': [], 'annotations': []} 154 | categories = set() 155 | for img_idx in tqdm(range(len(ds))): 156 | # find better way to get target 157 | # targets = ds.get_annotations(img_idx) 158 | img, targets = ds[img_idx] 159 | image_id = targets["image_id"].item() 160 | img_dict = {} 161 | img_dict['id'] = image_id 162 | img_dict['height'] = img.shape[-2] 163 | img_dict['width'] = img.shape[-1] 164 | dataset['images'].append(img_dict) 165 | bboxes = targets["boxes"] 
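        # The dataset stores boxes as [xmin, ymin, xmax, ymax]; the line below converts them
        # in place to COCO's [x, y, width, height], e.g. [10, 20, 50, 80] -> [10, 20, 40, 60].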
166 | bboxes[:, 2:] -= bboxes[:, :2] 167 | bboxes = bboxes.tolist() 168 | labels = targets['labels'].tolist() 169 | areas = targets['area'].tolist() 170 | iscrowd = targets['iscrowd'].tolist() 171 | if 'masks' in targets: 172 | masks = targets['masks'] 173 | # make masks Fortran contiguous for coco_mask 174 | masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) 175 | if 'keypoints' in targets: 176 | keypoints = targets['keypoints'] 177 | keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() 178 | num_objs = len(bboxes) 179 | for i in range(num_objs): 180 | ann = {} 181 | ann['image_id'] = image_id 182 | ann['bbox'] = bboxes[i] 183 | ann['category_id'] = labels[i] 184 | categories.add(labels[i]) 185 | ann['area'] = areas[i] 186 | ann['iscrowd'] = iscrowd[i] 187 | ann['id'] = ann_id 188 | if 'masks' in targets: 189 | ann["segmentation"] = coco_mask.encode(masks[i].numpy()) 190 | if 'keypoints' in targets: 191 | ann['keypoints'] = keypoints[i] 192 | ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) 193 | dataset['annotations'].append(ann) 194 | ann_id += 1 195 | dataset['categories'] = [{'id': i} for i in sorted(categories)] 196 | coco_ds.dataset = dataset 197 | coco_ds.createIndex() 198 | return coco_ds 199 | 200 | 201 | def get_coco_api_from_dataset(dataset): 202 | for _ in range(10): 203 | if isinstance(dataset, torchvision.datasets.CocoDetection): 204 | break 205 | if isinstance(dataset, torch.utils.data.Subset): 206 | dataset = dataset.dataset 207 | if isinstance(dataset, torchvision.datasets.CocoDetection): 208 | return dataset.coco 209 | return convert_to_coco_api(dataset) 210 | 211 | 212 | class CocoDetection(torchvision.datasets.CocoDetection): 213 | def __init__(self, img_folder, ann_file, transforms): 214 | super(CocoDetection, self).__init__(img_folder, ann_file) 215 | self._transforms = transforms 216 | 217 | def __getitem__(self, idx): 218 | img, target = super(CocoDetection, self).__getitem__(idx) 219 | image_id = self.ids[idx] 220 | target = dict(image_id=image_id, annotations=target) 221 | if self._transforms is not None: 222 | img, target = self._transforms(img, target) 223 | return img, target 224 | 225 | 226 | def get_coco(root, image_set, transforms, mode='instances'): 227 | anno_file_template = "{}_{}2017.json" 228 | PATHS = { 229 | "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), 230 | "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))), 231 | # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) 232 | } 233 | 234 | t = [ConvertCocoPolysToMask()] 235 | 236 | if transforms is not None: 237 | t.append(transforms) 238 | transforms = T.Compose(t) 239 | 240 | img_folder, ann_file = PATHS[image_set] 241 | img_folder = os.path.join(root, img_folder) 242 | ann_file = os.path.join(root, ann_file) 243 | 244 | dataset = CocoDetection(img_folder, ann_file, transforms=transforms) 245 | 246 | if image_set == "train": 247 | dataset = _coco_remove_images_without_annotations(dataset) 248 | 249 | # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) 250 | return dataset 251 | 252 | def get_coco_kp(root, image_set, transforms): 253 | return get_coco(root, image_set, transforms, mode="person_keypoints") 254 | 255 | 256 | def coco_to_excel(coco_evaluator, epoch, output_folder, eval_data): 257 | excel_file = os.path.join(output_folder, "coco_result_{}.xlsx".format(eval_data)) 258 | if not os.path.isfile(excel_file): 259 | df = 
_get_init_coco_result() 260 | else: 261 | df = pd.read_excel(excel_file) 262 | results = [] 263 | for iou_type, coco_res in coco_evaluator.coco_eval.items(): 264 | result = ['%.3f' % val for val in coco_res.stats] 265 | results += result 266 | df["Epoch_{}".format(epoch)] = results 267 | for col_name in df: 268 | if col_name[:7] == "Unnamed": 269 | del df[col_name] 270 | df.to_excel(excel_file, encoding='utf-8') 271 | 272 | 273 | def _get_init_coco_result(): 274 | init_data = {} 275 | init_data["IoU metric"] = ["BBOX"] * 12 + ["Segmentation"] * 12 276 | init_data["Type"] = ["Average Precision (AP)"] * 6 + ["Average Recall (AR)"] * 6 \ 277 | + ["Average Precision (AP)"] * 6 + ["Average Recall (AR)"] * 6 278 | init_data["IoU"] = ["IoU=0.50:0.95", "IoU=0.50", "IoU=0.75"] \ 279 | + ["IoU=0.50:0.95"] * 9 \ 280 | + ["IoU=0.50:0.95", "IoU=0.50", "IoU=0.75"] \ 281 | + ["IoU=0.50:0.95"] * 9 282 | init_data["Area"] = ["area=all"] * 3 + ["area=samll", "area=medium", "area=large"] \ 283 | + ["area=all"] * 3 + ["area=samll", "area=medium", "area=large"] \ 284 | + ["area=all"] * 3 + ["area=samll", "area=medium", "area=large"] \ 285 | + ["area=all"] * 3 + ["area=samll", "area=medium", "area=large"] 286 | init_data["Max Det"] = ["maxDets=100"] * 6 + ["maxDets=1", "maxDets=10"] + ["maxDets=100"] * 4 \ 287 | + ["maxDets=100"] * 6 + ["maxDets=1", "maxDets=10"] + ["maxDets=100"] * 4 288 | df = pd.DataFrame(data=init_data) 289 | return df 290 | -------------------------------------------------------------------------------- /config/train_real_data.yaml: -------------------------------------------------------------------------------- 1 | # dataset 2 | dataset: real_tray 3 | label_type: unseen_food 4 | dataset_path: 5 | - PATH/TO/DATASET 6 | 7 | num_test: all 8 | 9 | img_resize_w: 640 10 | img_resize_h: 480 11 | 12 | num_workers: 4 13 | -------------------------------------------------------------------------------- /config/train_syn_data.yaml: -------------------------------------------------------------------------------- 1 | # dataset 2 | dataset: synthetic 3 | label_type: unseen_food 4 | dataset_path: 5 | - PATH/TO/DATASET 6 | 7 | num_test: all 8 | 9 | img_resize_w: 640 10 | img_resize_h: 480 11 | 12 | num_workers: 4 13 | -------------------------------------------------------------------------------- /demo_stream.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import cv2 4 | import numpy as np 5 | import time 6 | from PIL import Image, ImageDraw, ImageFont 7 | 8 | import torch 9 | import torchvision.transforms as T 10 | 11 | from models import get_instance_segmentation_model 12 | 13 | 14 | import warnings 15 | warnings.filterwarnings(action='ignore') 16 | 17 | def main(args): 18 | # get device (GPU or CPU) 19 | if torch.cuda.is_available(): 20 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 21 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 22 | torch.backends.cudnn.benchmark = True 23 | torch.backends.cudnn.enabled = True 24 | device = torch.device("cuda") 25 | else: 26 | device = torch.device("cpu") 27 | 28 | # load model (MASKRCNN) 29 | print("... 
loading model") 30 | model = get_instance_segmentation_model(num_classes=2) 31 | model.to(device) 32 | model.load_state_dict(torch.load(args.ckp_path, map_location=torch.device('cpu'))) 33 | model.eval() 34 | thres = float(args.thres) 35 | 36 | # load transform 37 | transform = T.Compose([T.ToTensor(), 38 | T.Normalize( 39 | mean=[0.485, 0.456, 0.406], 40 | std=[0.229, 0.224, 0.225]), 41 | ]) 42 | print("... loading", end=' ') 43 | 44 | # load camera 45 | cam_id = args.cam_id 46 | cam = cv2.VideoCapture(cam_id) 47 | assert cam.isOpened(), 'Cannot capture source' 48 | 49 | # inference 50 | print("+++ Start inference !") 51 | while cam.isOpened(): 52 | start_time = time.time() 53 | # load and transform image 54 | start_time = time.time() 55 | ret, img_arr = cam.read() 56 | IMG_H, IMG_W, IMG_C = img_arr.shape 57 | img_data = Image.fromarray(img_arr).convert("RGB") 58 | img_tensor = transform(img_data) 59 | img_tensor = img_tensor.unsqueeze(0).to(device) 60 | 61 | # forward and post-process results 62 | pred_result = model(img_tensor, None)[0] 63 | pred_mask = pred_result['masks'].cpu().detach().numpy().transpose(0, 2, 3, 1) 64 | pred_mask[pred_mask >= 0.5] = 1 65 | pred_mask[pred_mask < 0.5] = 0 66 | pred_mask = np.repeat(pred_mask, 3, 3) 67 | pred_scores = pred_result['scores'].cpu().detach().numpy() 68 | pred_boxes = pred_result['boxes'].cpu().detach().numpy() 69 | # pred_labels = pred_result['labels'] 70 | 71 | # draw predictions 72 | ids = np.where(pred_scores > thres)[0] 73 | colors = np.random.randint(0, 255, (len(ids), 3)) 74 | # set colors considering location and size of bbox 75 | colors = [] 76 | for (x1, y1, x2, y2) in pred_boxes: 77 | w = max(x1, x2) - min(x1, x2) 78 | h = max(y1, y2) - min(y1, y2) 79 | x = (x1 + x2) / 2 80 | y = (y1 + y2) / 2 81 | ratio_x, ratio_y = x / IMG_W, y / IMG_H 82 | ratio_s = min(w, h) / max(w, h) 83 | ratio_s = 1 + ratio_s if ratio_s < 0 else ratio_s 84 | ratio_x, ratio_y, ratio_s = int(ratio_x*255), int(ratio_y*255), int(ratio_s*255) 85 | colors.append([ratio_x, ratio_y, ratio_s]) 86 | 87 | for color_i, pred_i in enumerate(ids): 88 | color = tuple(map(int, colors[color_i])) 89 | # draw segmentation 90 | mask = pred_mask[pred_i] 91 | mask = mask * color 92 | img_arr = cv2.addWeighted(img_arr, 1, mask.astype(np.uint8), 0.5, 0) 93 | # draw bbox 94 | x1, y1, x2, y2 = map(int, pred_boxes[pred_i]) 95 | cv2.rectangle(img_arr, (x1, y1), (x2, y2), color, 2) 96 | # put text 97 | vis_text = "FOOD({:.2f})".format(pred_scores[pred_i]) 98 | cv2.putText(img_arr, vis_text, (x1+5, y1+15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, [255, 255, 255], 2) 99 | cv2.putText(img_arr, vis_text, (x1+5, y1+15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) 100 | 101 | cv2.imshow('frame', img_arr) 102 | key = cv2.waitKey(1) 103 | if key == 27: 104 | break 105 | 106 | print("FPS is {:.2f} | Image Size:{}\t\t".format( 107 | 1/(time.time()-start_time), img_arr.shape), end='\r') 108 | cam.release() 109 | cv2.destroyAllWindows() 110 | 111 | def get_args_parser(): 112 | parser = argparse.ArgumentParser('Set Visualization of unseen food segmentation', add_help=False) 113 | parser.add_argument("--gpu", type=str, default="0", help="gpu ID number to use.") 114 | parser.add_argument("--cpu", action="store_true") 115 | parser.add_argument("--ckp_path", type=str, default="ckps/MASKRCNN_SIM.tar", help='path/to/trained/weight') 116 | parser.add_argument("--cam_id", type=int, default=0, help="CAMERA ID to use.") 117 | parser.add_argument("--thres", type=float, default=0.5, help='threshold for instance 
segmentation') 118 | return parser 119 | 120 | if __name__ == '__main__': 121 | parser = argparse.ArgumentParser('Visualizing Food Image Recognition', parents=[get_args_parser()]) 122 | args = parser.parse_args() 123 | main(args) 124 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import math 3 | import time 4 | 5 | import torch 6 | import torchvision 7 | from coco_eval import CocoEvaluator 8 | 9 | 10 | def train_one_epoch(epoch, model, data_loader, optimizer, device, 11 | lr_update=None, lr_scheduler=None, print_freq=100): 12 | model.train() 13 | # logger 14 | metric_logger = utils.MetricLogger(delimiter=" ") 15 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 16 | header = 'Epoch: [{}]'.format(epoch) 17 | print('----' * 20) 18 | print("[Epoch {}] The number of batch: {}".format(epoch, len(data_loader))) 19 | for idx, (images, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 20 | # get data 21 | images = [image.to(device) for image in images] 22 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 23 | # forward 24 | loss_dict = model(images, targets) 25 | losses = loss_dict['loss_classifier'] + loss_dict['loss_box_reg'] \ 26 | + loss_dict['loss_mask'] + loss_dict['loss_objectness'] \ 27 | + loss_dict['loss_rpn_box_reg'] 28 | # backporp 29 | optimizer.zero_grad() 30 | losses.backward() 31 | optimizer.step() 32 | # sum loss 33 | loss_dict_reduced = utils.reduce_dict(loss_dict) 34 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 35 | loss_value = losses_reduced.item() 36 | if not math.isfinite(loss_value): 37 | print("Loss is {}, stopping training".format(loss_value)) 38 | print(loss_dict_reduced) 39 | sys.exit(1) 40 | # logging 41 | metric_logger.update(loss=losses_reduced, **loss_dict_reduced) 42 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 43 | # learning rate scheduler 44 | curr_itr = idx + epoch*len(data_loader) + 1 45 | if lr_scheduler is not None and lr_update is not None: 46 | if curr_itr % lr_update == 0: 47 | print("+++ LR Update !") 48 | lr_scheduler.step() 49 | 50 | def _get_iou_types(model): 51 | model_without_ddp = model 52 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 53 | model_without_ddp = model.module 54 | iou_types = ["bbox"] 55 | if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): 56 | iou_types.append("segm") 57 | if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN): 58 | iou_types.append("keypoints") 59 | return iou_types 60 | 61 | @torch.no_grad() 62 | def evaluate(coco, model, data_loader, device): 63 | n_threads = torch.get_num_threads() 64 | # FIXME remove this and make paste_masks_in_image run on the GPU 65 | torch.set_num_threads(1) 66 | cpu_device = torch.device("cpu") 67 | model.eval() 68 | metric_logger = utils.MetricLogger(delimiter=" ") 69 | header = 'Test:' 70 | 71 | iou_types = _get_iou_types(model) 72 | coco_evaluator = CocoEvaluator(coco, iou_types) 73 | 74 | for image, targets in metric_logger.log_every(data_loader, 100, header): 75 | image = list(img.to(device) for img in image) 76 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 77 | 78 | torch.cuda.synchronize() 79 | model_time = time.time() 80 | outputs = model(image) 81 | 82 | outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in 
outputs] 83 | model_time = time.time() - model_time 84 | 85 | res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} 86 | evaluator_time = time.time() 87 | coco_evaluator.update(res) 88 | evaluator_time = time.time() - evaluator_time 89 | metric_logger.update(model_time=model_time, evaluator_time=evaluator_time) 90 | 91 | # gather the stats from all processes 92 | metric_logger.synchronize_between_processes() 93 | print("Averaged stats:", metric_logger) 94 | coco_evaluator.synchronize_between_processes() 95 | 96 | # accumulate predictions from all images 97 | coco_evaluator.accumulate() 98 | coco_evaluator.summarize() 99 | torch.set_num_threads(n_threads) 100 | return coco_evaluator 101 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import cv2 4 | import numpy as np 5 | from PIL import Image, ImageDraw, ImageFont 6 | 7 | import torch 8 | import torchvision.transforms as T 9 | 10 | from models import get_instance_segmentation_model 11 | 12 | 13 | import warnings 14 | warnings.filterwarnings(action='ignore') 15 | 16 | def main(args): 17 | # get device (GPU or CPU) 18 | if torch.cuda.is_available(): 19 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 21 | torch.backends.cudnn.benchmark = True 22 | torch.backends.cudnn.enabled = True 23 | device = torch.device("cuda") 24 | else: 25 | device = torch.device("cpu") 26 | 27 | # load model (MASKRCNN) 28 | print("... loading model") 29 | model = get_instance_segmentation_model(num_classes=2) 30 | model.to(device) 31 | model.load_state_dict(torch.load(args.ckp_path)) 32 | model.eval() 33 | 34 | # load images and transform 35 | print("... loading", end=' ') 36 | img_list = sorted(os.listdir(args.input_dir)) 37 | transform = T.Compose([T.ToTensor(), 38 | T.Normalize( 39 | mean=[0.485, 0.456, 0.406], 40 | std=[0.229, 0.224, 0.225]), 41 | ]) 42 | print("{} images".format(len(img_list))) 43 | 44 | # visualization setting 45 | if not os.path.isdir(args.output_dir): 46 | os.makedirs(args.output_dir, exist_ok=True) 47 | thres = float(args.thres) 48 | 49 | # inference 50 | print("+++ Start inference !") 51 | for i, img_name in enumerate(img_list): 52 | print("... 
inference ({}/{}) _ {}".format(i+1, len(img_list), img_name)) 53 | # load and transform image 54 | img_file = os.path.join(args.input_dir, img_name) 55 | img_data = Image.open(img_file).convert("RGB") 56 | # img_tensor = img_data.resize(img_resize, Img.BICUBIC) 57 | img_tensor = transform(img_data) 58 | img_tensor = img_tensor.unsqueeze(0).to(device) 59 | img_arr = np.array(img_data).astype(np.uint8) 60 | 61 | # forward and post-process results 62 | pred_result = model(img_tensor, None)[0] 63 | pred_mask = pred_result['masks'].cpu().detach().numpy().transpose(0, 2, 3, 1) 64 | pred_mask[pred_mask >= 0.5] = 1 65 | pred_mask[pred_mask < 0.5] = 0 66 | pred_mask = np.repeat(pred_mask, 3, 3) 67 | pred_scores = pred_result['scores'].cpu().detach().numpy() 68 | pred_boxes = pred_result['boxes'].cpu().detach().numpy() 69 | # pred_labels = pred_result['labels'] 70 | 71 | # draw predictions 72 | # print("[{} Scores]:".format(pred_scores.shape[0]), list(pred_scores)) 73 | ids = np.where(pred_scores > thres)[0] 74 | colors = np.random.randint(0, 255, (len(ids), 3)) 75 | for color_i, pred_i in enumerate(ids): 76 | color = tuple(map(int, colors[color_i])) 77 | # draw segmentation 78 | mask = pred_mask[pred_i] 79 | mask = mask * color 80 | img_arr = cv2.addWeighted(img_arr, 1, mask.astype(np.uint8), 0.5, 0) 81 | # draw bbox and text 82 | x1, y1, x2, y2 = map(int, pred_boxes[pred_i]) 83 | cv2.rectangle(img_arr, (x1, y1), (x2, y2), color, 2) 84 | vis_text = "FOOD({:.2f})".format(pred_scores[pred_i]) 85 | cv2.putText(img_arr, vis_text, (x1+5, y1+15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, [255, 255, 255], 2) 86 | cv2.putText(img_arr, vis_text, (x1+5, y1+15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) 87 | # # save for debugging 88 | # cv2.imwrite("tmp_{}.png".format(color_i), img_arr) 89 | # save visualized image 90 | img_arr = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB) 91 | save_name = os.path.join(args.output_dir, img_name) 92 | cv2.imwrite(save_name, img_arr) 93 | 94 | 95 | def get_args_parser(): 96 | parser = argparse.ArgumentParser('Set Visualization of unseen food segmentation', add_help=False) 97 | parser.add_argument('--input_dir', type=str, help='path/to/images/for/inference') 98 | parser.add_argument('--output_dir', type=str, help='path/to/save/visualized/images') 99 | parser.add_argument("--gpu", type=str, default="0", help="gpu ID number to use.") 100 | parser.add_argument("--ckp_path", type=str, default="ckps/MASKRCNN_SIM.tar", help='path/to/trained/weight') 101 | parser.add_argument("--thres", type=float, default=0.5, help='threshold for instance segmentation') 102 | return parser 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser('Visualizing Food Image Recognition', parents=[get_args_parser()]) 106 | args = parser.parse_args() 107 | main(args) 108 | -------------------------------------------------------------------------------- /inference_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import cv2 4 | import numpy as np 5 | from PIL import Image, ImageDraw, ImageFont 6 | 7 | import torch 8 | import torchvision.transforms as T 9 | 10 | from models import get_instance_segmentation_model 11 | 12 | 13 | import warnings 14 | warnings.filterwarnings(action='ignore') 15 | 16 | def main(args): 17 | # get device (GPU or CPU) 18 | if torch.cuda.is_available(): 19 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 21 | torch.backends.cudnn.benchmark = True 22 | 
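        # cuDNN autotuning (benchmark=True) is safe here because every decoded frame
        # has the same resolution, so the fastest convolution algorithms can be reused.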
torch.backends.cudnn.enabled = True 23 | device = torch.device("cuda") 24 | else: 25 | device = torch.device("cpu") 26 | 27 | # load model (MASKRCNN) 28 | print("... loading model") 29 | model = get_instance_segmentation_model(num_classes=2) 30 | model.to(device) 31 | model.load_state_dict(torch.load(args.ckp_path)) 32 | model.eval() 33 | 34 | # load transform 35 | transform = T.Compose([T.ToTensor(), 36 | T.Normalize( 37 | mean=[0.485, 0.456, 0.406], 38 | std=[0.229, 0.224, 0.225]), 39 | ]) 40 | print("... loading", end=' ') 41 | cap = cv2.VideoCapture(args.input_dir) 42 | num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 43 | IMG_W, IMG_H = int(cap.get(3)), int(cap.get(4)) 44 | print("{} images".format(num_frames)) 45 | 46 | # visualization setting 47 | if not os.path.isdir(args.output_dir): 48 | os.makedirs(args.output_dir, exist_ok=True) 49 | thres = float(args.thres) 50 | 51 | # save as video 52 | fourcc = cv2.VideoWriter_fourcc(*'DIVX') 53 | out = cv2.VideoWriter(os.path.join(args.output_dir, "demo.mp4"), fourcc, 30.0, (IMG_W, IMG_H)) 54 | 55 | # inference 56 | print("+++ Start inference !") 57 | for fr_idx in range(num_frames): 58 | print("... inference ({}/{})".format(fr_idx+1, num_frames)) 59 | # load and transform image 60 | ret, img_arr = cap.read() 61 | IMG_H, IMG_W, IMG_C = img_arr.shape 62 | img_data = Image.fromarray(img_arr).convert("RGB") 63 | img_tensor = transform(img_data) 64 | img_tensor = img_tensor.unsqueeze(0).to(device) 65 | 66 | # forward and post-process results 67 | pred_result = model(img_tensor, None)[0] 68 | pred_mask = pred_result['masks'].cpu().detach().numpy().transpose(0, 2, 3, 1) 69 | pred_mask[pred_mask >= 0.5] = 1 70 | pred_mask[pred_mask < 0.5] = 0 71 | pred_mask = np.repeat(pred_mask, 3, 3) 72 | pred_scores = pred_result['scores'].cpu().detach().numpy() 73 | pred_boxes = pred_result['boxes'].cpu().detach().numpy() 74 | # pred_labels = pred_result['labels'] 75 | 76 | # draw predictions 77 | ids = np.where(pred_scores > thres)[0] 78 | colors = np.random.randint(0, 255, (len(ids), 3)) 79 | # set colors considering location and size of bbox 80 | colors = [] 81 | for (x1, y1, x2, y2) in pred_boxes: 82 | w = max(x1, x2) - min(x1, x2) 83 | h = max(y1, y2) - min(y1, y2) 84 | x = (x1 + x2) / 2 85 | y = (y1 + y2) / 2 86 | ratio_x, ratio_y = x / IMG_W, y / IMG_H 87 | ratio_s = min(w, h) / max(w, h) 88 | ratio_s = 1 + ratio_s if ratio_s < 0 else ratio_s 89 | ratio_x, ratio_y, ratio_s = int(ratio_x*255), int(ratio_y*255), int(ratio_s*255) 90 | colors.append([ratio_x, ratio_y, ratio_s]) 91 | 92 | for color_i, pred_i in enumerate(ids): 93 | color = tuple(map(int, colors[color_i])) 94 | # draw segmentation 95 | mask = pred_mask[pred_i] 96 | mask = mask * color 97 | img_arr = cv2.addWeighted(img_arr, 1, mask.astype(np.uint8), 0.5, 0) 98 | # draw bbox and text 99 | x1, y1, x2, y2 = map(int, pred_boxes[pred_i]) 100 | cv2.rectangle(img_arr, (x1, y1), (x2, y2), color, 2) 101 | # vis_text = "FOOD({:.2f})".format(pred_scores[pred_i]) 102 | # cv2.putText(img_arr, vis_text, (x1+5, y1+15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, [255, 255, 255], 2) 103 | # cv2.putText(img_arr, vis_text, (x1+5, y1+15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1) 104 | # # save for debugging 105 | # cv2.imwrite("tmp_{}.png".format(color_i), img_arr) 106 | # save visualized image 107 | # img_arr = cv2.cvtColor(img_arr, cv2.COLOR_BGR2RGB) 108 | out.write(img_arr) 109 | 110 | save_name = os.path.join(args.output_dir, "{}.png".format(fr_idx)) 111 | cv2.imwrite(save_name, img_arr) 112 | 113 | cap.release() 
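    # release the input capture first; releasing the VideoWriter below flushes and finalizes demo.mp4 on disk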
114 | out.release() 115 | cv2.destroyAllWindows() 116 | 117 | def get_args_parser(): 118 | parser = argparse.ArgumentParser('Set Visualization of unseen food segmentation', add_help=False) 119 | parser.add_argument('--input_dir', type=str, help='path/to/images/for/inference') 120 | parser.add_argument('--output_dir', type=str, help='path/to/save/visualized/images') 121 | parser.add_argument("--gpu", type=str, default="0", help="gpu ID number to use.") 122 | parser.add_argument("--ckp_path", type=str, default="ckps/MASKRCNN_SIM.tar", help='path/to/trained/weight') 123 | parser.add_argument("--thres", type=float, default=0.5, help='threshold for instance segmentation') 124 | return parser 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser('Visualizing Food Image Recognition', parents=[get_args_parser()]) 128 | args = parser.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .synthetic_loader import * 2 | from .real_loader import * 3 | from .unimib import * 4 | 5 | def get_dataset(config, mode='train'): 6 | """ mode = 'train' or 'val' 7 | """ 8 | if config["dataset"] == 'synthetic': 9 | mode = 'val' if mode == 'test' else mode 10 | dataset = SyntheticDataset(config=config, mode=mode) 11 | elif config["dataset"] == 'real_tray': 12 | mode = 'val' if mode == 'test' else mode 13 | dataset = RealTrayDataset(config=config, mode=mode) 14 | elif config["dataset"] == 'unimib2016': 15 | mode = 'test' if mode == 'val' else mode 16 | dataset = UNIMIB2016Dataset(config=config, mode=mode) 17 | elif config["dataset"] == 'unimib2016_fake': 18 | dataset = UNIMIB2016DatasetFake(config=config, mode=mode) 19 | else: 20 | raise ValueError("Wrong Dataset Name in CONFIG {}".format(config["dataset"])) 21 | return dataset 22 | -------------------------------------------------------------------------------- /loader/real_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import glob 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | from .transforms import get_transform 10 | 11 | class RealTrayDataset(Dataset): 12 | def __init__(self, config, mode): 13 | assert mode in ['train', 'val'] 14 | self.mode = mode # train mode or validation mode 15 | self.label_type = config["label_type"] 16 | # load rgb and segmentation mask images 17 | self.rgb_list = [] 18 | self.seg_list = [] 19 | self.slot_label_list = [] 20 | data_roots = config["dataset_path"] 21 | if isinstance(data_roots, str): data_roots = [data_roots] 22 | for data_root in data_roots: 23 | self.rgb_path = os.path.join(data_root, "RGBImages") 24 | self.seg_path = os.path.join(data_root, "Annotations", "Annotations_all") 25 | rgb_list = list(sorted(glob.glob(os.path.join(self.rgb_path, '*')))) 26 | if 'Thumbs.db' in rgb_list: rgb_list.remove("Thumbs.db") 27 | seg_list = list(sorted(glob.glob(os.path.join(self.seg_path, '*.mask')))) 28 | slot_label_list = list(sorted(glob.glob(os.path.join(self.seg_path, '*.txt')))) 29 | 30 | if config["num_test"] != "all": 31 | if self.mode == 'train': 32 | rgb_list = rgb_list[config["num_test"]:] 33 | seg_list = seg_list[config["num_test"]:] 34 | slot_label_list = slot_label_list[config["num_test"]:] 35 | else: 36 | rgb_list = rgb_list[:config["num_test"]] 37 | seg_list = seg_list[:config["num_test"]] 38 | 
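                # when num_test is an integer, the first num_test samples form the
                # validation split and training mode (above) uses everything after them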
slot_label_list = slot_label_list[:config["num_test"]] 39 | self.rgb_list += rgb_list 40 | self.seg_list += seg_list 41 | self.slot_label_list += slot_label_list 42 | assert len(self.rgb_list) > 0 43 | assert len(self.rgb_list) == len(self.seg_list) 44 | assert len(self.rgb_list) == len(self.slot_label_list) 45 | 46 | # load rgb transform 47 | self.transform_ver = config["transform"] if "transform" in config else "torchvision" 48 | self.rgb_transform, self.transform_ver = get_transform(self.transform_ver, self.mode) 49 | self.width = config["img_resize_w"] if "img_resize_w" in config else 640 50 | self.height = config["img_resize_h"] if "img_resize_h" in config else 480 51 | 52 | # labels in mask images 53 | self.slot_cat2label = { 54 | '1_slot': 1, 55 | '2_slot': 2, 56 | '3_slot': 3, 57 | '4_slot': 4, 58 | '5_slot': 5 59 | } 60 | 61 | # # TODO: delete after debugging 62 | # self.rgb_list = self.rgb_list[:10] 63 | # self.seg_list = self.seg_list[:10] 64 | 65 | def __len__(self): 66 | return len(self.rgb_list) 67 | 68 | def __getitem__(self, idx): 69 | # load rgb image 70 | rgb = Image.open(self.rgb_list[idx]).convert("RGB") 71 | rgb = rgb.resize((self.width, self.height), Image.BICUBIC) 72 | # load mask image 73 | make_arr = Image.open(self.seg_list[idx]).convert("L") 74 | make_arr = make_arr.resize((self.width, self.height), Image.NEAREST) 75 | make_arr = np.array(make_arr) 76 | 77 | # extrack masks 78 | slot_labels = open(self.slot_label_list[idx]).readlines() 79 | slot_id2category, slot_valid_ids = {}, [] 80 | for slot_label in slot_labels: 81 | label, category = slot_label.split(' ') 82 | category = category[:-1] if '\n' in category else category 83 | slot_id2category[int(label)] = category 84 | if category in self.slot_cat2label: slot_valid_ids.append(int(label)) 85 | 86 | obj_ids = np.array([mask_id for mask_id in np.unique(make_arr) if mask_id in slot_valid_ids]) 87 | masks = make_arr == obj_ids[:, None, None] 88 | 89 | # get bounding box coordinates for each mask 90 | num_objs = len(obj_ids) 91 | temp_masks = [] 92 | boxes = [] 93 | labels = [] 94 | for i in range(num_objs): 95 | pos = np.where(masks[i]) 96 | xmin = np.min(pos[1]) 97 | xmax = np.max(pos[1]) 98 | ymin = np.min(pos[0]) 99 | ymax = np.max(pos[0]) 100 | if int(xmax-xmin) < 1 or int(ymax-ymin) < 1: 101 | continue 102 | temp_masks.append(masks[i]) 103 | boxes.append([xmin, ymin, xmax, ymax]) 104 | labels.append(1) 105 | masks = np.array(temp_masks) 106 | labels = np.array(labels) 107 | boxes = np.array(boxes) 108 | 109 | # tensor format data 110 | labels = torch.as_tensor(labels, dtype=torch.int64) 111 | masks = torch.as_tensor(masks, dtype=torch.uint8) 112 | boxes = torch.as_tensor(boxes, dtype=torch.float32) 113 | image_id = torch.tensor([idx]) 114 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 115 | iscrowd = torch.zeros((len(labels),), dtype=torch.int64) 116 | 117 | target = {} 118 | target["boxes"] = boxes 119 | target["masks"] = masks 120 | target["image_id"] = image_id 121 | target["area"] = area 122 | target["iscrowd"] = iscrowd 123 | target["labels"] = labels 124 | 125 | # RGB transform 126 | if self.transform_ver == "torchvision": 127 | rgb = self.rgb_transform(rgb) 128 | elif self.transform_ver == "albumentation": 129 | rgb = self.rgb_transform(image=np.array(rgb))['image'] 130 | else: 131 | raise ValueError("Wrong transform version {}".format(self.transform_ver)) 132 | 133 | return rgb, target 134 | -------------------------------------------------------------------------------- 
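# A minimal wiring sketch (an assumption for illustration, not one of the repository's files):
# it shows how a dataset built by loader.get_dataset might be batched and fed to the Mask R-CNN
# model for one training step. The config keys mirror config/train_*.yaml; the inline values,
# the tuple collate function, and the use of 'val' mode (chosen only because it selects the
# plain ToTensor+Normalize transform) are assumptions rather than the project's own settings.
import torch
from torch.utils.data import DataLoader

from loader import get_dataset
from models import get_instance_segmentation_model

config = {
    "dataset": "synthetic",              # same keys as config/train_syn_data.yaml
    "label_type": "unseen_food",
    "dataset_path": ["PATH/TO/DATASET"],
    "num_test": 100,                     # first 100 samples are held out for validation
    "img_resize_w": 640,
    "img_resize_h": 480,
}

dataset = get_dataset(config, mode="val")
# torchvision detection models expect lists of images and targets,
# so batch items are kept as tuples instead of being stacked into tensors
data_loader = DataLoader(dataset, batch_size=2, shuffle=True,
                         collate_fn=lambda batch: tuple(zip(*batch)))

model = get_instance_segmentation_model(num_classes=2)
model.train()
images, targets = next(iter(data_loader))
loss_dict = model(list(images), list(targets))   # classifier, box_reg, mask, rpn losses
sum(loss_dict.values()).backward()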
/loader/synthetic_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import glob 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | from .transforms import get_transform 10 | 11 | class SyntheticDataset(Dataset): 12 | def __init__(self, config, mode): 13 | assert mode in ['train', 'val'] 14 | self.mode = mode # train mode or validation mode 15 | self.label_type = config["label_type"] 16 | # load rgb and segmentation mask images 17 | self.rgb_list = [] 18 | self.seg_list = [] 19 | data_roots = config["dataset_path"] 20 | if isinstance(data_roots, str): data_roots = [data_roots] 21 | for data_root in data_roots: 22 | self.rgb_path = os.path.join(data_root, "image") 23 | self.seg_path = os.path.join(data_root, "mask_obj") 24 | rgb_list = list(sorted(glob.glob(os.path.join(self.rgb_path, '*')))) 25 | seg_list = list(sorted(glob.glob(os.path.join(self.seg_path, '*')))) 26 | if 'Thumbs.db' in rgb_list: rgb_list.remove("Thumbs.db") 27 | if 'Thumbs.db' in seg_list: seg_list.remove("Thumbs.db") 28 | 29 | if self.mode == 'train': 30 | rgb_list = rgb_list[config["num_test"]:] 31 | seg_list = seg_list[config["num_test"]:] 32 | else: 33 | rgb_list = rgb_list[:config["num_test"]] 34 | seg_list = seg_list[:config["num_test"]] 35 | self.rgb_list += rgb_list 36 | self.seg_list += seg_list 37 | assert len(self.rgb_list) > 0 38 | assert len(self.rgb_list) == len(self.seg_list) 39 | 40 | # load rgb transform 41 | self.transform_ver = config["transform"] if "transform" in config else "torchvision" 42 | self.rgb_transform, self.transform_ver = get_transform(self.transform_ver, self.mode) 43 | self.width = config["img_resize_w"] if "img_resize_w" in config else 640 44 | self.height = config["img_resize_h"] if "img_resize_h" in config else 480 45 | 46 | # colors of label in mask images 47 | self.id2color = {1: (255, 0, 0), 48 | 2: (0, 255, 0), 49 | 3: (0, 0, 255), 50 | 4: (255, 0, 255), 51 | 5: (0, 255, 255)} 52 | 53 | # # TODO: delete after debugging 54 | # self.rgb_list = self.rgb_list[:10] 55 | # self.seg_list = self.seg_list[:10] 56 | 57 | def __len__(self): 58 | return len(self.rgb_list) 59 | 60 | def __getitem__(self, idx): 61 | # load rgb image 62 | rgb = Image.open(self.rgb_list[idx]).convert("RGB") 63 | rgb = rgb.resize((self.width, self.height), Image.BICUBIC) 64 | # load mask image 65 | make_arr = Image.open(self.seg_list[idx]).convert("RGB") 66 | make_arr = make_arr.resize((self.width, self.height), Image.NEAREST) 67 | make_arr = np.array(make_arr) 68 | # extrack masks 69 | labels, masks, boxes = [], [], [] 70 | for slot_id, (r,g,b) in self.id2color.items(): 71 | cnd_r = make_arr[:,:, 0] == r 72 | cnd_g = make_arr[:,:, 1] == g 73 | cnd_b = make_arr[:,:, 2] == b 74 | mask = cnd_r*cnd_g*cnd_b 75 | if len(np.unique(mask)) == 1: continue 76 | # extract bbox 77 | pos = np.where(mask) 78 | xmin, xmax = np.min(pos[1]), np.max(pos[1]) 79 | ymin, ymax = np.min(pos[0]), np.max(pos[0]) 80 | # skip small boxes 81 | if int(xmax-xmin) < 1 or int(ymax-ymin) < 1: continue 82 | if self.label_type == "slot_food": label = slot_id 83 | elif self.label_type == "unseen_food": label = 1 84 | labels.append(label) 85 | masks.append(mask) 86 | boxes.append([xmin, ymin, xmax, ymax]) 87 | labels = np.array(labels) 88 | masks = np.array(masks) 89 | boxes = np.array(boxes) 90 | 91 | # tensor format data 92 | labels = torch.as_tensor(labels, dtype=torch.int64) 93 | masks = torch.as_tensor(masks, 
dtype=torch.uint8) 94 | boxes = torch.as_tensor(boxes, dtype=torch.float32) 95 | image_id = torch.tensor([idx]) 96 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 97 | iscrowd = torch.zeros((len(labels),), dtype=torch.int64) 98 | 99 | target = {} 100 | target["boxes"] = boxes 101 | target["masks"] = masks 102 | target["image_id"] = image_id 103 | target["area"] = area 104 | target["iscrowd"] = iscrowd 105 | target["labels"] = labels 106 | 107 | # RGB transform 108 | if self.transform_ver == "torchvision": 109 | rgb = self.rgb_transform(rgb) 110 | elif self.transform_ver == "albumentation": 111 | rgb = self.rgb_transform(image=np.array(rgb))['image'] 112 | else: 113 | raise ValueError("Wrong transform version {}".format(self.transform_ver)) 114 | 115 | return rgb, target 116 | -------------------------------------------------------------------------------- /loader/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torchvision.transforms import functional as F 3 | import albumentations 4 | import albumentations.pytorch.transforms as A_torch 5 | 6 | 7 | def get_transform(transform_ver, mode): 8 | if mode in ["val", "test"]: 9 | transform = transforms.Compose([ 10 | transforms.ToTensor(), 11 | transforms.Normalize( 12 | mean=[0.485, 0.456, 0.406], 13 | std=[0.229, 0.224, 0.225]), 14 | ]) 15 | transform_ver = "torchvision" 16 | elif mode == "train": 17 | if transform_ver == "albumentation": 18 | transform = albumentation() 19 | elif transform_ver == "torchvision": 20 | transform = torchvision() 21 | else: 22 | raise ValueError("Wrong transform mode {}".format(mode)) 23 | return transform, transform_ver 24 | 25 | 26 | def torchvision(): 27 | transform = transforms.Compose([ 28 | transforms.ColorJitter(brightness=0.2, 29 | contrast=0.4, 30 | saturation=0.3, 31 | hue=0.25), 32 | transforms.ToTensor(), 33 | transforms.Normalize( 34 | mean=[0.485, 0.456, 0.406], 35 | std=[0.229, 0.224, 0.225]), 36 | T.AddGaussianNoise(mean=0., std=0.05) 37 | ]) 38 | return transform 39 | 40 | def albumentation(): 41 | transform = albumentations.Compose([ 42 | albumentations.OneOf([ 43 | albumentations.GaussNoise(), 44 | albumentations.IAAAdditiveGaussianNoise() 45 | ]), 46 | albumentations.OneOf([ 47 | albumentations.MotionBlur(blur_limit=3, p=0.2), 48 | albumentations.MedianBlur(blur_limit=3, p=0.1), 49 | albumentations.Blur(blur_limit=2, p=0.1) 50 | ]), 51 | albumentations.OneOf([ 52 | albumentations.RandomBrightness(limit=(0.1, 0.4)), 53 | albumentations.HueSaturationValue(hue_shift_limit=(0, 128), sat_shift_limit=(0, 60), val_shift_limit=(0, 20)), 54 | albumentations.RGBShift(r_shift_limit=30, g_shift_limit=30, b_shift_limit=30) 55 | ]), 56 | albumentations.OneOf([ 57 | albumentations.CLAHE(), 58 | albumentations.ChannelShuffle(), 59 | albumentations.IAASharpen(), 60 | albumentations.IAAEmboss(), 61 | albumentations.RandomBrightnessContrast(), 62 | ]), 63 | albumentations.OneOf([ 64 | albumentations.RandomGamma(gamma_limit=(35,255)), 65 | albumentations.OpticalDistortion(), 66 | albumentations.GridDistortion(), 67 | albumentations.IAAPiecewiseAffine() 68 | ]), 69 | A_torch.ToTensor(normalize={ 70 | "mean": [0.485, 0.456, 0.406], 71 | "std" : [0.229, 0.224, 0.225]}) 72 | ]) 73 | return transform 74 | -------------------------------------------------------------------------------- /loader/unimib.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
numpy as np 3 | import cv2 4 | import glob 5 | from PIL import Image 6 | import json 7 | import skimage 8 | 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | 13 | from .transforms import get_transform 14 | 15 | class UNIMIB2016Dataset(Dataset): 16 | def __init__(self, config, mode): 17 | assert mode in ['train', 'test'] 18 | self.mode = mode # train mode or validation mode 19 | self.label_type = config["label_type"] 20 | # load datas from annotation json 21 | self.data_root = config["dataset_path"] 22 | self.ann = os.path.join(self.data_root, "annotations", "{}.json".format(mode)) 23 | self.ann = json.load(open(self.ann)) 24 | self.rgb_list = list(self.ann.keys()) 25 | # load rgb transform 26 | self.transform_ver = config["transform"] if "transform" in config else "torchvision" 27 | self.rgb_transform, self.transform_ver = get_transform(self.transform_ver, self.mode) 28 | self.width = config["img_resize_w"] if "img_resize_w" in config else 640 29 | self.height = config["img_resize_h"] if "img_resize_h" in config else 480 30 | 31 | # # TODO: delete after debugging 32 | # self.rgb_list = self.rgb_list[:10] 33 | # self.seg_list = self.seg_list[:10] 34 | # self.rgb_list = list(reversed(self.rgb_list)) 35 | 36 | def __len__(self): 37 | return len(self.rgb_list) 38 | 39 | def __getitem__(self, idx): 40 | # load rgb image 41 | img_name = self.rgb_list[idx] 42 | img_file = os.path.join(self.data_root, "images", img_name+'.jpg') 43 | rgb = Image.open(img_file).convert("RGB") 44 | img_w, img_h = rgb.size 45 | rgb = rgb.resize((self.width, self.height), Image.BICUBIC) 46 | # extract mask from annotation 47 | ann = self.ann[img_name] 48 | num_instance = len(ann) 49 | masks = [] 50 | boxes = [] 51 | labels = [] 52 | for inst_idx, poly_pts in enumerate(ann): 53 | # # extract mask 54 | # poly_dict = list(ann_inst.values())[0] 55 | # poly_pts = poly_dict["BR"] 56 | rr, cc = skimage.draw.polygon(poly_pts[1::2], poly_pts[::2]) 57 | # rr = (img_h-1) - rr 58 | # cc = (img_w-1) - cc 59 | # fill mask 60 | mask = np.zeros((img_h, img_w), dtype=np.uint8) 61 | mask[rr, cc] = 1 62 | mask = cv2.resize(mask, (self.width, self.height), cv2.INTER_AREA) 63 | # extract bbox 64 | pos = np.where(mask) 65 | xmin = np.min(pos[1]) 66 | xmax = np.max(pos[1]) 67 | ymin = np.min(pos[0]) 68 | ymax = np.max(pos[0]) 69 | # skip small boxes 70 | if int(xmax-xmin) < 1 or int(ymax-ymin) < 1: continue 71 | # same label for unseen food segmentation 72 | label = 1 73 | labels.append(label) 74 | masks.append(mask) 75 | boxes.append([xmin, ymin, xmax, ymax]) 76 | 77 | labels = np.array(labels) 78 | masks = np.array(masks) 79 | boxes = np.array(boxes) 80 | 81 | # tensor format data 82 | labels = torch.as_tensor(labels, dtype=torch.int64) 83 | masks = torch.as_tensor(masks, dtype=torch.uint8) 84 | boxes = torch.as_tensor(boxes, dtype=torch.float32) 85 | image_id = torch.tensor([idx]) 86 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 87 | iscrowd = torch.zeros((len(labels),), dtype=torch.int64) 88 | 89 | target = {} 90 | target["boxes"] = boxes 91 | target["masks"] = masks 92 | target["image_id"] = image_id 93 | target["area"] = area 94 | target["iscrowd"] = iscrowd 95 | target["labels"] = labels 96 | 97 | # RGB transform 98 | if self.transform_ver == "torchvision": 99 | rgb = self.rgb_transform(rgb) 100 | elif self.transform_ver == "albumentation": 101 | rgb = self.rgb_transform(image=np.array(rgb))['image'] 102 | else: 103 | raise ValueError("Wrong transform version {}".format(self.transform_ver)) 104 
| 105 | return rgb, target 106 | 107 | 108 | class UNIMIB2016DatasetFake(Dataset): 109 | def __init__(self, config, mode): 110 | assert mode in ['train', 'val', 'test'] 111 | self.mode = mode # train mode or validation mode 112 | self.label_type = config["label_type"] 113 | # load datas from annotation json 114 | self.data_root = config["dataset_path"] 115 | self.ann = os.path.join(self.data_root, "annotations", "{}.json".format(mode)) 116 | self.ann = json.load(open(self.ann)) 117 | self.rgb_list = list(self.ann.keys()) 118 | # load rgb transform 119 | self.transform_ver = config["transform"] if "transform" in config else "torchvision" 120 | self.rgb_transform, self.transform_ver = get_transform(self.transform_ver, self.mode) 121 | self.width = config["img_resize_w"] if "img_resize_w" in config else 640 122 | self.height = config["img_resize_h"] if "img_resize_h" in config else 480 123 | 124 | # # TODO: delete after debugging 125 | # self.rgb_list = self.rgb_list[:10] 126 | # self.seg_list = self.seg_list[:10] 127 | # self.rgb_list = list(reversed(self.rgb_list)) 128 | 129 | def __len__(self): 130 | return len(self.rgb_list) 131 | 132 | def __getitem__(self, idx): 133 | # load rgb image 134 | img_name = self.rgb_list[idx] 135 | img_file = os.path.join(self.data_root, "images", img_name+'.jpg') 136 | rgb = Image.open(img_file).convert("RGB") 137 | img_w, img_h = rgb.size 138 | rgb = rgb.resize((self.width, self.height), Image.BICUBIC) 139 | # extract mask from annotation 140 | ann = self.ann[img_name] 141 | num_instance = len(ann) 142 | masks = [] 143 | boxes = [] 144 | labels = [] 145 | for inst_idx, ann_inst in enumerate(ann): 146 | # extract mask 147 | poly_dict = list(ann_inst.values())[0] 148 | poly_pts = poly_dict["BR"] 149 | rr, cc = skimage.draw.polygon(poly_pts[1::2], poly_pts[::2]) 150 | # rr = (img_h-1) - rr 151 | # cc = (img_w-1) - cc 152 | # fill mask 153 | mask = np.zeros((img_h, img_w), dtype=np.uint8) 154 | mask[rr, cc] = 1 155 | mask = cv2.resize(mask, (self.width, self.height), cv2.INTER_AREA) 156 | # extract bbox 157 | pos = np.where(mask) 158 | xmin = np.min(pos[1]) 159 | xmax = np.max(pos[1]) 160 | ymin = np.min(pos[0]) 161 | ymax = np.max(pos[0]) 162 | # skip small boxes 163 | if int(xmax-xmin) < 1 or int(ymax-ymin) < 1: continue 164 | # same label for unseen food segmentation 165 | label = 1 166 | labels.append(label) 167 | masks.append(mask) 168 | boxes.append([xmin, ymin, xmax, ymax]) 169 | 170 | labels = np.array(labels) 171 | masks = np.array(masks) 172 | boxes = np.array(boxes) 173 | 174 | # tensor format data 175 | labels = torch.as_tensor(labels, dtype=torch.int64) 176 | masks = torch.as_tensor(masks, dtype=torch.uint8) 177 | boxes = torch.as_tensor(boxes, dtype=torch.float32) 178 | image_id = torch.tensor([idx]) 179 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 180 | iscrowd = torch.zeros((len(labels),), dtype=torch.int64) 181 | 182 | target = {} 183 | target["boxes"] = boxes 184 | target["masks"] = masks 185 | target["image_id"] = image_id 186 | target["area"] = area 187 | target["iscrowd"] = iscrowd 188 | target["labels"] = labels 189 | 190 | # RGB transform 191 | if self.transform_ver == "torchvision": 192 | rgb = self.rgb_transform(rgb) 193 | elif self.transform_ver == "albumentation": 194 | rgb = self.rgb_transform(image=np.array(rgb))['image'] 195 | else: 196 | raise ValueError("Wrong transform version {}".format(self.transform_ver)) 197 | 198 | return rgb, target 199 | 200 | 
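The UNIMIB loader above returns `(image, target)` pairs in the torchvision detection format, so it must be batched with a zip-style `collate_fn` rather than the default stacking collate. Below is a minimal usage sketch (not part of the repository); the config keys mirror what `UNIMIB2016Dataset.__init__` reads, and the dataset path is a placeholder.

```python
import torch
from loader.unimib import UNIMIB2016Dataset

def collate_fn(batch):
    # detection targets are dicts of variable-sized tensors, so keep them as tuples
    return tuple(zip(*batch))

# hypothetical config; keys follow UNIMIB2016Dataset.__init__
config = {
    "dataset_path": "/path/to/UNIMIB2016",   # placeholder path
    "label_type": "food",                    # a single "food" class is used for unseen-food segmentation
    "transform": "torchvision",
    "img_resize_w": 640,
    "img_resize_h": 480,
}

dataset = UNIMIB2016Dataset(config, mode="train")
loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True,
                                     num_workers=0, collate_fn=collate_fn)

images, targets = next(iter(loader))
print(len(images), targets[0]["boxes"].shape, targets[0]["masks"].shape)
```
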
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .maskrcnn import get_instance_segmentation_model 2 | -------------------------------------------------------------------------------- /models/maskrcnn.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torchvision.transforms as T 3 | from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor 4 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 5 | 6 | 7 | def get_instance_segmentation_model(num_classes): 8 | model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=False) 9 | in_features = model.roi_heads.box_predictor.cls_score.in_features 10 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 11 | in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels 12 | hidden_layer = 256 13 | model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, 14 | hidden_layer, 15 | num_classes) 16 | return model -------------------------------------------------------------------------------- /resources/demo_unseen_food_segmentation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gist-ailab/Food-Instance-Segmentation/fc9a1a7ee4d5800d48bed33d19daedfc97be581b/resources/demo_unseen_food_segmentation.gif -------------------------------------------------------------------------------- /resources/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gist-ailab/Food-Instance-Segmentation/fc9a1a7ee4d5800d48bed33d19daedfc97be581b/resources/example.png -------------------------------------------------------------------------------- /test_multi_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import glob 4 | import yaml 5 | 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | 9 | import torch 10 | from models import get_instance_segmentation_model 11 | from loader import get_dataset 12 | from coco_utils import get_coco_api_from_dataset, coco_to_excel 13 | from engine import evaluate 14 | 15 | 16 | def collate_fn(batch): 17 | return tuple(zip(*batch)) 18 | 19 | def main(args): 20 | # get device (GPU or CPU) 21 | if torch.cuda.is_available(): 22 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 23 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 24 | torch.backends.cudnn.benchmark = True 25 | torch.backends.cudnn.enabled = True 26 | device = torch.device("cuda") 27 | else: 28 | device = torch.device("cpu") 29 | 30 | # fix seed for reproducibility 31 | torch.manual_seed(7777) 32 | 33 | # load config 34 | if args.config[-4:] != '.yaml': args.config += '.yaml' 35 | with open(args.config) as cfg_file: 36 | config = yaml.safe_load(cfg_file) 37 | print(config) 38 | 39 | # load dataset 40 | val_dataset = get_dataset(config, mode=args.mode) 41 | val_loader = torch.utils.data.DataLoader( 42 | dataset=val_dataset, num_workers=config["num_workers"], 43 | batch_size=1, shuffle=False, collate_fn=collate_fn) 44 | print("... 
Get COCO Dataloader for evaluation") 45 | coco = get_coco_api_from_dataset(val_loader.dataset) 46 | 47 | ckp_paths = glob.glob(os.path.join(args.ckp_dir, "*.tar")) 48 | for ckp_idx, ckp_path in enumerate(ckp_paths): 49 | print("[CKP {} / {}]".format(ckp_idx, len(ckp_paths)), "-----" * 10) 50 | # load model 51 | model = get_instance_segmentation_model(num_classes=2) 52 | model.load_state_dict(torch.load(ckp_path)) 53 | model.to(device) 54 | 55 | coco_evaluator = evaluate(coco, model, val_loader, device) 56 | 57 | if args.write_excel: 58 | os.makedirs(args.excel_save_dir, exist_ok=True) 59 | epoch = int(os.path.basename(ckp_path)[6:-4]) 60 | coco_to_excel( 61 | coco_evaluator, epoch, args.excel_save_dir, 62 | "{}_{}".format(config["dataset"], config["label_type"])) 63 | 64 | def get_args_parser(): 65 | parser = argparse.ArgumentParser('Set training of unseen food segmentation', add_help=False) 66 | parser.add_argument("--gpu", type=str, default="0", help="gpu number to use. 0, 1") 67 | parser.add_argument("--mode", type=str, default="test", help="test, val, train") 68 | parser.add_argument("--config", type=str, help="path/to/configfile/.yaml") 69 | parser.add_argument("--ckp_dir", type=str, default=None, help="path/to/trained/weight/directory") 70 | parser.add_argument("--write_excel", action="store_true", help='write COCO results into excel file') 71 | parser.add_argument("--excel_save_dir", type=str, help="path/to/save/result/excel") 72 | return parser 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser('Training Unseen Food Segmentation', parents=[get_args_parser()]) 76 | args = parser.parse_args() 77 | main(args) 78 | -------------------------------------------------------------------------------- /test_single_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import yaml 4 | 5 | import warnings 6 | warnings.filterwarnings('ignore') 7 | 8 | import torch 9 | from models import get_instance_segmentation_model 10 | from loader import get_dataset 11 | from coco_utils import get_coco_api_from_dataset, coco_to_excel 12 | from engine import evaluate 13 | 14 | 15 | def collate_fn(batch): 16 | return tuple(zip(*batch)) 17 | 18 | def main(args): 19 | # # get device (GPU or CPU) 20 | # if torch.cuda.is_available(): 21 | # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 22 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 23 | # torch.backends.cudnn.benchmark = True 24 | # torch.backends.cudnn.enabled = True 25 | # device = torch.device("cuda") 26 | # else: 27 | # device = torch.device("cpu") 28 | 29 | # fix seed for reproducibility 30 | torch.manual_seed(7777) 31 | 32 | # load config 33 | if args.config[-4:] != '.yaml': args.config += '.yaml' 34 | with open(args.config) as cfg_file: 35 | config = yaml.safe_load(cfg_file) 36 | print(config) 37 | 38 | # load dataset 39 | val_dataset = get_dataset(config, mode=args.mode) 40 | val_loader = torch.utils.data.DataLoader( 41 | dataset=val_dataset, num_workers=config["num_workers"], 42 | batch_size=1, shuffle=False, collate_fn=collate_fn) 43 | print("... 
Get COCO Dataloader for evaluation") 44 | coco = get_coco_api_from_dataset(val_loader.dataset) 45 | 46 | # load model 47 | model = get_instance_segmentation_model(num_classes=2) 48 | model.load_state_dict(torch.load(args.trained_ckp)) 49 | model.to(device) 50 | 51 | coco_evaluator = evaluate(coco, model, val_loader, device) 52 | # if args.write_excel: 53 | # coco_to_excel(coco_evaluator, epoch, logging_folder, args.eval_data) 54 | 55 | def get_args_parser(): 56 | parser = argparse.ArgumentParser('Set training of unseen food segmentation', add_help=False) 57 | parser.add_argument("--gpu", type=str, default="0", help="gpu number to use. 0, 1") 58 | parser.add_argument("--mode", type=str, default="test", help="test, val, train") 59 | parser.add_argument("--config", type=str, help="path/to/configfile/.yaml") 60 | parser.add_argument("--trained_ckp", type=str, default=None, help="path/to/trained/weight") 61 | parser.add_argument("--write_excel", action="store_true", help='write COCO results into excel file') 62 | return parser 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser('Training Unseen Food Segmentation', parents=[get_args_parser()]) 66 | args = parser.parse_args() 67 | 68 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 69 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 70 | torch.backends.cudnn.benchmark = True 71 | torch.backends.cudnn.enabled = True 72 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 73 | 74 | main(args) 75 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import yaml 4 | 5 | import warnings 6 | warnings.filterwarnings('ignore') 7 | 8 | import torch 9 | from models import get_instance_segmentation_model 10 | from loader import get_dataset 11 | from coco_utils import get_coco_api_from_dataset, coco_to_excel 12 | from engine import train_one_epoch, evaluate 13 | 14 | 15 | def collate_fn(batch): 16 | return tuple(zip(*batch)) 17 | 18 | def main(args): 19 | # get device (GPU or CPU) 20 | if torch.cuda.is_available(): 21 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 22 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 23 | torch.backends.cudnn.benchmark = True 24 | torch.backends.cudnn.enabled = True 25 | device = torch.device("cuda") 26 | else: 27 | device = torch.device("cpu") 28 | 29 | # fix seed for reproducibility 30 | torch.manual_seed(7777) 31 | 32 | # load config 33 | if args.config[-4:] != '.yaml': args.config += '.yaml' 34 | with open(args.config) as cfg_file: 35 | config = yaml.safe_load(cfg_file) 36 | print(config) 37 | 38 | # load dataset 39 | train_dataset = get_dataset(config) 40 | train_loader = torch.utils.data.DataLoader( 41 | dataset=train_dataset, num_workers=config["num_workers"], 42 | batch_size=config["batch_size"], shuffle=True, collate_fn=collate_fn) 43 | val_dataset = get_dataset(config, mode="val") 44 | val_loader = torch.utils.data.DataLoader( 45 | dataset=val_dataset, num_workers=config["num_workers"], 46 | batch_size=1, shuffle=False, collate_fn=collate_fn) 47 | print("... 
Get COCO Dataloader for evaluation") 48 | coco = get_coco_api_from_dataset(val_loader.dataset) 49 | 50 | # load model 51 | model = get_instance_segmentation_model(num_classes=2) 52 | if args.resume: 53 | if args.resume_ckp: 54 | resume_ckp = args.resume_ckp 55 | elif "resume_ckp" in config: 56 | resume_ckp = config["resume_ckp"] 57 | else: 58 | raise ValueError("Wrong resume setting: no trained weight given in config or args") 59 | model.load_state_dict(torch.load(resume_ckp)) 60 | model.to(device) 61 | 62 | # construct an optimizer 63 | params = [p for p in model.parameters() if p.requires_grad] 64 | optimizer = torch.optim.Adam(params, lr=config["lr"], weight_decay=config["wd"]) 65 | lr_update = config["lr_update"] if "lr_update" in config else None 66 | 67 | # set training epoch 68 | start_epoch = args.resume_epoch if args.resume_epoch else 0 69 | if args.max_epoch: 70 | max_epoch = args.max_epoch 71 | else: 72 | max_epoch = config['max_epoch'] if "max_epoch" in config else 100 73 | assert start_epoch < max_epoch 74 | save_interval = config["save_interval"] if "save_interval" in config else 1 75 | 76 | # logging 77 | output_folder = config["save_dir"] 78 | os.makedirs(output_folder, exist_ok=True) 79 | 80 | print("+++ Start Training @start:{} @max: {}".format(start_epoch, max_epoch)) 81 | for epoch in range(start_epoch, max_epoch): 82 | # train 83 | train_one_epoch(epoch, model, train_loader, optimizer, device, lr_update) 84 | # validate and write results 85 | coco_evaluator = evaluate(coco, model, val_loader, device) 86 | # save weight 87 | if epoch % save_interval == 0: 88 | torch.save(model.state_dict(), '{}/epoch_{}.tar'.format(output_folder, epoch)) 89 | if args.write_excel: 90 | coco_to_excel( 91 | coco_evaluator, epoch, output_folder, 92 | "{}_{}".format(config["dataset"], config["label_type"])) 93 | 94 | def get_args_parser(): 95 | parser = argparse.ArgumentParser('Set training of unseen food segmentation', add_help=False) 96 | parser.add_argument("--gpu", type=str, default="0", help="gpu number to use. 
0, 1") 97 | parser.add_argument("--config", type=str, help="path/to/configfile/.yaml") 98 | parser.add_argument("--max_epoch", type=int, default=None, help="maximum epoch for training") 99 | parser.add_argument("--resume", action="store_true", help='resume training if true') 100 | parser.add_argument("--resume_ckp", type=str, default=None, help="path/to/trained/weight") 101 | parser.add_argument("--resume_epoch", type=int, default=None, help="epoch when resuming") 102 | parser.add_argument("--write_excel", action="store_true", help='write COCO results into excel file') 103 | return parser 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser('Training Unseen Food Segmentation', parents=[get_args_parser()]) 107 | args = parser.parse_args() 108 | main(args) 109 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import * 2 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, deque 2 | import datetime 3 | import time 4 | import pickle 5 | 6 | import torch 7 | import torch.distributed as dist 8 | 9 | import errno 10 | import os 11 | 12 | class SmoothedValue(object): 13 | """Track a series of values and provide access to smoothed values over a 14 | window or the global series average. 15 | """ 16 | def __init__(self, window_size=20, fmt=None): 17 | if fmt is None: 18 | fmt = "{median:.4f} ({global_avg:.4f})" 19 | self.deque = deque(maxlen=window_size) 20 | self.total = 0.0 21 | self.count = 0 22 | self.fmt = fmt 23 | 24 | def update(self, value, n=1): 25 | self.deque.append(value) 26 | self.count += n 27 | self.total += value * n 28 | 29 | def synchronize_between_processes(self): 30 | """ 31 | Warning: does not synchronize the deque! 
32 | """ 33 | if not is_dist_avail_and_initialized(): 34 | return 35 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 36 | dist.barrier() 37 | dist.all_reduce(t) 38 | t = t.tolist() 39 | self.count = int(t[0]) 40 | self.total = t[1] 41 | 42 | @property 43 | def median(self): 44 | d = torch.tensor(list(self.deque)) 45 | return d.median().item() 46 | 47 | @property 48 | def avg(self): 49 | d = torch.tensor(list(self.deque), dtype=torch.float32) 50 | return d.mean().item() 51 | 52 | @property 53 | def global_avg(self): 54 | return self.total / self.count 55 | 56 | @property 57 | def max(self): 58 | return max(self.deque) 59 | 60 | @property 61 | def value(self): 62 | return self.deque[-1] 63 | 64 | def __str__(self): 65 | return self.fmt.format( 66 | median=self.median, 67 | avg=self.avg, 68 | global_avg=self.global_avg, 69 | max=self.max, 70 | value=self.value) 71 | 72 | 73 | class MetricLogger(object): 74 | def __init__(self, delimiter="\t"): 75 | self.meters = defaultdict(SmoothedValue) 76 | self.delimiter = delimiter 77 | 78 | def update(self, **kwargs): 79 | for k, v in kwargs.items(): 80 | if isinstance(v, torch.Tensor): 81 | v = v.item() 82 | assert isinstance(v, (float, int)) 83 | self.meters[k].update(v) 84 | 85 | def __getattr__(self, attr): 86 | if attr in self.meters: 87 | return self.meters[attr] 88 | if attr in self.__dict__: 89 | return self.__dict__[attr] 90 | raise AttributeError("'{}' object has no attribute '{}'".format( 91 | type(self).__name__, attr)) 92 | 93 | def __str__(self): 94 | loss_str = [] 95 | for name, meter in self.meters.items(): 96 | loss_str.append( 97 | "{}: {}".format(name, str(meter)) 98 | ) 99 | return self.delimiter.join(loss_str) 100 | 101 | def synchronize_between_processes(self): 102 | for meter in self.meters.values(): 103 | meter.synchronize_between_processes() 104 | 105 | def add_meter(self, name, meter): 106 | self.meters[name] = meter 107 | 108 | def log_every(self, iterable, print_freq, header=None): 109 | i = 0 110 | if not header: 111 | header = '' 112 | start_time = time.time() 113 | end = time.time() 114 | iter_time = SmoothedValue(fmt='{avg:.4f}') 115 | data_time = SmoothedValue(fmt='{avg:.4f}') 116 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 117 | if torch.cuda.is_available(): 118 | log_msg = self.delimiter.join([ 119 | header, 120 | '[{0' + space_fmt + '}/{1}]', 121 | 'eta: {eta}', 122 | '{meters}', 123 | 'time: {time}', 124 | 'data: {data}', 125 | 'max mem: {memory:.0f}' 126 | ]) 127 | else: 128 | log_msg = self.delimiter.join([ 129 | header, 130 | '[{0' + space_fmt + '}/{1}]', 131 | 'eta: {eta}', 132 | '{meters}', 133 | 'time: {time}', 134 | 'data: {data}' 135 | ]) 136 | MB = 1024.0 * 1024.0 137 | for obj in iterable: 138 | data_time.update(time.time() - end) 139 | yield obj 140 | iter_time.update(time.time() - end) 141 | if i % print_freq == 0 or i == len(iterable) - 1: 142 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 143 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 144 | if torch.cuda.is_available(): 145 | print(log_msg.format( 146 | i, len(iterable), eta=eta_string, 147 | meters=str(self), 148 | time=str(iter_time), data=str(data_time), 149 | memory=torch.cuda.max_memory_allocated() / MB)) 150 | else: 151 | print(log_msg.format( 152 | i, len(iterable), eta=eta_string, 153 | meters=str(self), 154 | time=str(iter_time), data=str(data_time))) 155 | i += 1 156 | end = time.time() 157 | total_time = time.time() - start_time 158 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 159 | print('{} Total time: {} ({:.4f} s / it)'.format( 160 | header, total_time_str, total_time / len(iterable))) 161 | 162 | 163 | def reduce_dict(input_dict, average=True): 164 | """ 165 | Args: 166 | input_dict (dict): all the values will be reduced 167 | average (bool): whether to do average or sum 168 | Reduce the values in the dictionary from all processes so that all processes 169 | have the averaged results. Returns a dict with the same fields as 170 | input_dict, after reduction. 171 | """ 172 | world_size = get_world_size() 173 | if world_size < 2: 174 | return input_dict 175 | with torch.no_grad(): 176 | names = [] 177 | values = [] 178 | # sort the keys so that they are consistent across processes 179 | for k in sorted(input_dict.keys()): 180 | names.append(k) 181 | values.append(input_dict[k]) 182 | values = torch.stack(values, dim=0) 183 | dist.all_reduce(values) 184 | if average: 185 | values /= world_size 186 | reduced_dict = {k: v for k, v in zip(names, values)} 187 | return reduced_dict 188 | 189 | def get_world_size(): 190 | if not is_dist_avail_and_initialized(): 191 | return 1 192 | return dist.get_world_size() 193 | 194 | def is_dist_avail_and_initialized(): 195 | if not dist.is_available(): 196 | return False 197 | if not dist.is_initialized(): 198 | return False 199 | return True 200 | 201 | 202 | def all_gather(data): 203 | """ 204 | Run all_gather on arbitrary picklable data (not necessarily tensors) 205 | Args: 206 | data: any picklable object 207 | Returns: 208 | list[data]: list of data gathered from each rank 209 | """ 210 | world_size = get_world_size() 211 | if world_size == 1: 212 | return [data] 213 | 214 | # serialized to a Tensor 215 | buffer = pickle.dumps(data) 216 | storage = torch.ByteStorage.from_buffer(buffer) 217 | tensor = torch.ByteTensor(storage).to("cuda") 218 | 219 | # obtain Tensor size of each rank 220 | local_size = torch.tensor([tensor.numel()], device="cuda") 221 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 222 | dist.all_gather(size_list, local_size) 223 | size_list = [int(size.item()) for size in size_list] 224 | max_size = max(size_list) 225 | 226 | # receiving Tensor from all ranks 227 | # we pad the tensor because torch all_gather does not support 228 | # gathering tensors of different shapes 229 | tensor_list = [] 230 | for _ in size_list: 231 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 232 | if local_size != max_size: 233 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 234 | tensor = torch.cat((tensor, padding), dim=0) 235 | dist.all_gather(tensor_list, tensor) 236 | 237 | data_list = [] 238 | for size, tensor in zip(size_list, tensor_list): 239 | buffer = tensor.cpu().numpy().tobytes()[:size] 240 | data_list.append(pickle.loads(buffer)) 241 | 242 | return data_list 243 | -------------------------------------------------------------------------------- /vis_UNIMIB.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import numpy as np 5 | 6 | from skimage.draw import polygon2mask 7 | import skimage 8 | 9 | def cvt_str(x): 10 | return str(x[0][0]) 11 | 12 | # LOAD MAT fiel and convert into List[Str] 13 | print("---" * 15) 14 | from scipy.io import loadmat 15 | data_root = "/data/joo/dataset/food/UNIMIB2016/split" 16 | ann_file = os.path.join(data_root, 'TestSet.mat') 17 | anns = 
loadmat(ann_file)['TestSet'] 18 | test_set = list(map(cvt_str, anns)) 19 | print("TestSet", len(test_set)) 20 | 21 | ann_file = os.path.join(data_root, 'TrainingSet.mat') 22 | anns = loadmat(ann_file)['TrainingSet'] 23 | train_set = list(map(cvt_str, anns)) 24 | print("TrainingSet", len(train_set)) 25 | 26 | exit() 27 | 28 | # read txt file and save as JSON 29 | def remove_bracket(ann): 30 | return ann.replace('[', '') 31 | def cvt_list(ann): 32 | if ann[0] == ',': 33 | ann = ann[1:] 34 | return list(map(int, ann.split(','))) 35 | save_root = "/data/joo/dataset/food/UNIMIB2016/annotations" 36 | txt_root = "/data/joo/dataset/food/UNIMIB2016/annotations_txt" 37 | # load trainset 38 | poly_dict = {} 39 | for img_name in train_set: 40 | print("IMG: ", img_name) 41 | txt_name = os.path.join(txt_root, "{}.txt".format(img_name)) 42 | ann = open(txt_name, 'r').readline() 43 | # print(ann) 44 | ann = ann.split(']')[:-2] 45 | ann = list(map(remove_bracket, ann)) 46 | ann = list(map(cvt_list, ann)) 47 | print("Instance: ", len(ann)) 48 | for inst_ann in ann: 49 | print(len(inst_ann), inst_ann[:10], "...") 50 | print("-----" * 25) 51 | poly_dict[img_name] = ann 52 | # save as json file 53 | save_name = os.path.join(save_root, "train.json") 54 | with open(save_name, "w") as json_file: 55 | json.dump(poly_dict, json_file) 56 | 57 | # load testset 58 | poly_dict = {} 59 | for img_name in test_set: 60 | print("IMG: ", img_name) 61 | txt_name = os.path.join(txt_root, "{}.txt".format(img_name)) 62 | ann = open(txt_name, 'r').readline() 63 | # print(ann) 64 | ann = ann.split(']')[:-2] 65 | ann = list(map(remove_bracket, ann)) 66 | ann = list(map(cvt_list, ann)) 67 | print("Instance: ", len(ann)) 68 | for inst_ann in ann: 69 | print(len(inst_ann), inst_ann[:10], "...") 70 | print("-----" * 25) 71 | poly_dict[img_name] = ann 72 | # save as json file 73 | save_name = os.path.join(save_root, "test.json") 74 | with open(save_name, "w") as json_file: 75 | json.dump(poly_dict, json_file) 76 | exit() 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | # data_root = "/data/joo/dataset/food/UNIMIB2016" 86 | # ann_file = os.path.join(data_root, "annotations/val.json") 87 | 88 | # anns = json.load(open(ann_file)) 89 | # print(type(anns)) 90 | 91 | # # print(anns.keys()) 92 | # # print(len(anns)) 93 | 94 | 95 | # data_root = "/data/joo/dataset/food/UNIMIB2016" 96 | # modes = ['train', 'val', 'test'] 97 | # for mode in modes: 98 | # ann_file = os.path.join(data_root, "annotations/{}.json".format(mode)) 99 | # anns = json.load(open(ann_file)) 100 | # print(mode, len(anns)) 101 | 102 | # for img_name, ann in anns.items(): 103 | # num_instance = len(ann) 104 | # img_file = os.path.join(data_root, "images", "{}.jpg".format(img_name)) 105 | # img_arr = cv2.imread(img_file) 106 | # img_h, img_w, img_c = img_arr.shape 107 | # cv2.imwrite("tmp_IMG.png", img_arr) 108 | # # mask_arr = np.zeros((img_h, img_w, num_instance), dtype=np.uint8) 109 | # mask_arr = np.zeros((img_h, img_w), dtype=np.uint8) 110 | # print("IMG: ", img_name, img_arr.shape) 111 | # for ann_inst in ann: 112 | # poly_dict = list(ann_inst.values())[0] 113 | # poly_pts = poly_dict["BR"] 114 | # rr, cc = skimage.draw.polygon(poly_pts[1::2], poly_pts[::2]) 115 | # rr = (img_h-1) - rr 116 | # cc = (img_w-1) - cc 117 | # mask_arr[rr, cc] = 150 118 | # cv2.imwrite("tmp_{}.png".format(food), mask_arr) 119 | # exit() 120 | -------------------------------------------------------------------------------- /vis_by_loader.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import yaml 5 | from tqdm import tqdm 6 | 7 | from loader import get_dataset 8 | import torch 9 | 10 | 11 | class UnNormalize(object): 12 | def __init__(self, mean, std): 13 | self.mean = mean 14 | self.std = std 15 | def __call__(self, tensor): 16 | """ 17 | Args: 18 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 19 | Returns: 20 | Tensor: Normalized image. 21 | """ 22 | # The normalize code -> t.sub_(m).div_(s) 23 | for t, m, s in zip(tensor, self.mean, self.std): 24 | t.mul_(s).add_(m) 25 | return tensor 26 | 27 | draw_box = True 28 | 29 | # config for dataset 30 | cfg_name = "./config/eval_real_data_MediHard" 31 | mode = 'test' 32 | save_dir = "./tmp/UNIMIB_vis/OurReal_MediHard" 33 | 34 | cfg_name = "./config/eval_unimib" 35 | mode = 'test' 36 | save_dir = "./tmp/Inference/UNIMIB_Testset_GroundTruth" 37 | os.makedirs(save_dir, exist_ok=True) 38 | 39 | 40 | if cfg_name[-4:] != '.yaml': cfg_name += '.yaml' 41 | with open(cfg_name) as cfg_file: 42 | config = yaml.safe_load(cfg_file) 43 | print(config) 44 | 45 | dataset = get_dataset(config, mode=mode) 46 | dataloader = torch.utils.data.DataLoader( 47 | dataset=dataset, num_workers=config["num_workers"], 48 | batch_size=1, shuffle=False) 49 | unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 50 | 51 | print("+++ Visualize {} data".format(len(dataloader))) 52 | for idx, (image, target) in tqdm(enumerate(dataloader)): 53 | # de-normalize image tensor 54 | image = unorm(image[0]).cpu().detach().numpy().transpose(1,2,0) 55 | image = np.uint8(image * 255) 56 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 57 | # extract mask 58 | masks = target["masks"][0].cpu().detach().numpy() 59 | boxes = target["boxes"][0].cpu().detach().numpy() 60 | colors = np.random.randint(0, 255, (len(masks), 3)) 61 | for i, mask in enumerate(masks): 62 | color = tuple(map(int, colors[i])) 63 | mask = np.expand_dims(mask, -1) 64 | mask = np.repeat(mask, 3, 2) 65 | mask = mask * color 66 | image = cv2.addWeighted(image, 1, mask.astype(np.uint8), 0.5, 0) 67 | 68 | if draw_box: 69 | x1, y1, x2, y2 = map(int, boxes[i]) 70 | cv2.rectangle(image, (x1, y1), (x2, y2), color, 2) 71 | save_name = os.path.join(save_dir, "{}.jpg".format(idx)) 72 | cv2.imwrite(save_name, image) 73 | -------------------------------------------------------------------------------- /vis_by_loader_model.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import yaml 5 | from tqdm import tqdm 6 | 7 | from loader import get_dataset 8 | import torch 9 | from models import get_instance_segmentation_model 10 | 11 | 12 | class UnNormalize(object): 13 | def __init__(self, mean, std): 14 | self.mean = mean 15 | self.std = std 16 | def __call__(self, tensor): 17 | """ 18 | Args: 19 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 20 | Returns: 21 | Tensor: Normalized image. 
22 | """ 23 | # The normalize code -> t.sub_(m).div_(s) 24 | for t, m, s in zip(tensor, self.mean, self.std): 25 | t.mul_(s).add_(m) 26 | return tensor 27 | 28 | 29 | thres = 0.5 30 | # device = torch.device("cpu") 31 | device = torch.device("cuda") 32 | 33 | # config for dataset 34 | cfg_name = "./config/eval_unimib" 35 | cfg_name = "./config/eval_real_data_MediHard" 36 | mode = 'test' 37 | 38 | # syn only 39 | ckp_path = "/data/joo/food/maskrcnn/1108_seg_tray_SIM_all/epoch_1.tar" 40 | save_dir = "./tmp/Inference/OurReal_MediHard_SynOnly" 41 | save_dir = "./tmp/Inference/UNIMIB_Testset_SynOnly" 42 | 43 | # fine-tune 44 | ckp_path = "/data/joo/food/maskrcnn/210327_real_easy_finetuning/epoch_18.tar" 45 | save_dir = "./tmp/Inference/OurReal_MediHard_FineTune" 46 | # ckp_path = "/data/joo/food/maskrcnn/210326_UNIMIB_FineTuning/epoch_49.tar" 47 | # save_dir = "./tmp/Inference/UNIMIB_Testset_FineTune" 48 | 49 | # # # from-scratch 50 | # ckp_path = "/data/joo/food/maskrcnn/210327_real_easy/epoch_22.tar" 51 | # save_dir = "./tmp/Inference/OurReal_MediHard_FromScr" 52 | # ckp_path = "/data/joo/food/maskrcnn/210326_UNIMIB_FronScratch/epoch_44.tar" 53 | # save_dir = "./tmp/Inference/UNIMIB_Testset_FromScr" 54 | 55 | os.makedirs(save_dir, exist_ok=True) 56 | 57 | # load trained model 58 | print("... loading model") 59 | model = get_instance_segmentation_model(num_classes=2) 60 | model.to(device) 61 | model.load_state_dict(torch.load(ckp_path)) 62 | model.eval() 63 | 64 | if cfg_name[-4:] != '.yaml': cfg_name += '.yaml' 65 | with open(cfg_name) as cfg_file: 66 | config = yaml.safe_load(cfg_file) 67 | print(config) 68 | 69 | dataset = get_dataset(config, mode=mode) 70 | dataloader = torch.utils.data.DataLoader( 71 | dataset=dataset, num_workers=config["num_workers"], 72 | batch_size=1, shuffle=False) 73 | unorm = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 74 | 75 | print("+++ Visualize {} data".format(len(dataloader))) 76 | for idx, (image, target) in tqdm(enumerate(dataloader)): 77 | # forward and post-process results 78 | pred_result = model(image.to(device), None)[0] 79 | pred_mask = pred_result['masks'].cpu().detach().numpy().transpose(0, 2, 3, 1) 80 | pred_mask[pred_mask >= 0.5] = 1 81 | pred_mask[pred_mask < 0.5] = 0 82 | pred_mask = np.repeat(pred_mask, 3, 3) 83 | pred_scores = pred_result['scores'].cpu().detach().numpy() 84 | pred_boxes = pred_result['boxes'].cpu().detach().numpy() 85 | 86 | # de-normalize image tensor 87 | image = unorm(image[0]).cpu().detach().numpy().transpose(1,2,0) 88 | image = np.uint8(image * 255) 89 | 90 | ids = np.where(pred_scores > thres)[0] 91 | colors = np.random.randint(0, 255, (len(ids), 3)) 92 | for color_i, pred_i in enumerate(ids): 93 | color = tuple(map(int, colors[color_i])) 94 | # draw segmentation 95 | mask = pred_mask[pred_i] 96 | mask = mask * color 97 | image = cv2.addWeighted(image, 1, mask.astype(np.uint8), 0.5, 0) 98 | # draw bbox and text 99 | x1, y1, x2, y2 = map(int, pred_boxes[pred_i]) 100 | cv2.rectangle(image, (x1, y1), (x2, y2), color, 2) 101 | 102 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 103 | save_name = os.path.join(save_dir, "{}.jpg".format(idx)) 104 | cv2.imwrite(save_name, image) 105 | --------------------------------------------------------------------------------
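For completeness, here is a minimal standalone inference sketch in the spirit of `inference.py` (whose source is not reproduced above). It uses the repository's `get_instance_segmentation_model` factory and the same ImageNet normalization as the loaders; the checkpoint path, input image, and score threshold are placeholders and may differ from the actual script.

```python
import torch
from PIL import Image
from torchvision import transforms

from models import get_instance_segmentation_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# two classes: background + food, as in train.py and the test scripts
model = get_instance_segmentation_model(num_classes=2)
model.load_state_dict(torch.load("epoch_0.tar", map_location=device))  # placeholder checkpoint
model.to(device).eval()

# same normalization as the val/test transform in loader/transforms.py
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open("example.jpg").convert("RGB")  # placeholder input image
with torch.no_grad():
    pred = model([tf(img).to(device)])[0]

keep = pred["scores"] > 0.5                             # placeholder score threshold
masks = (pred["masks"][keep, 0] > 0.5).cpu().numpy()    # (N, H, W) binary masks
boxes = pred["boxes"][keep].cpu().numpy().astype(int)   # (N, 4) xyxy boxes
print("detected {} food instance(s)".format(len(boxes)))
```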