├── .gitattributes ├── pretrained ├── cat_examples.jpg ├── cutlery_example.jpg ├── pizza_topping_example.jpg ├── model_mask_rcnn_cat_bird_30.pth ├── model_mask_rcnn_cutlery_30.pth └── model_mask_rcnn_pizza_30.pth ├── .style.yapf ├── enviroment.yml ├── object_detection ├── __init__.py ├── metrics.py ├── dataloaders.py ├── visualisation.py ├── model.py ├── engines.py └── dataset.py ├── LICENSE ├── check_dataset.py ├── evaluate_images.py ├── .gitignore ├── train.py ├── readme.md └── extract_from_coco.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pth filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /pretrained/cat_examples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WillBrennan/ObjectDetection/HEAD/pretrained/cat_examples.jpg -------------------------------------------------------------------------------- /pretrained/cutlery_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WillBrennan/ObjectDetection/HEAD/pretrained/cutlery_example.jpg -------------------------------------------------------------------------------- /pretrained/pizza_topping_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WillBrennan/ObjectDetection/HEAD/pretrained/pizza_topping_example.jpg -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = facebook 3 | spaces_before_comment = 4 4 | split_before_logical_operator = true 5 | column_limit = 120 -------------------------------------------------------------------------------- /pretrained/model_mask_rcnn_cat_bird_30.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d5f0a621501ddef3dd706e494d0904e50e81ef6a5a038253e2d1ec37048d5f5a 3 | size 176181730 4 | -------------------------------------------------------------------------------- /pretrained/model_mask_rcnn_cutlery_30.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9582e3ee313a3c8ee16a2293454b8097b5054a5aa4e5c16dbb2578ff7a5b7d2c 3 | size 176311852 4 | -------------------------------------------------------------------------------- /pretrained/model_mask_rcnn_pizza_30.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d7b2b21f4c238006010a6554c308dd4873477f6d7253777a26ff4477577c5ee4 3 | size 176355234 4 | -------------------------------------------------------------------------------- /enviroment.yml: -------------------------------------------------------------------------------- 1 | name: object_detection 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - python>=3.7 7 | - pip 8 | - pytorch 9 | - torchvision 10 | - cudatoolkit>=11.1 11 | - numpy 12 | - pip: 13 | - pytorch-ignite 14 | - opencv-python 15 | - pycocotools 16 | - tensorboard 17 | - albumentations 18 | - yapf 19 | - pytest 20 | - labelme 21 | -------------------------------------------------------------------------------- 
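A note on the tooling config above: yapf (pinned in `enviroment.yml`) picks up `.style.yapf` automatically, so with the conda environment active, something like this should reformat the whole project:

```bash
yapf --in-place --recursive object_detection/ *.py
```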
/object_detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import LabelMeDataset 2 | from .dataloaders import create_data_loaders 3 | from .engines import create_mask_rcnn_trainer 4 | from .engines import create_mask_rcnn_evaluator 5 | 6 | from .model import MaskRCNN 7 | from .model import filter_by_threshold 8 | 9 | from .engines import attach_lr_scheduler 10 | from .engines import attach_training_logger 11 | from .engines import attach_model_checkpoint 12 | from .engines import attach_metric_logger 13 | 14 | from .metrics import LossAverager 15 | 16 | from .visualisation import draw_results 17 | -------------------------------------------------------------------------------- /object_detection/metrics.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from ignite import metrics 3 | 4 | 5 | class LossAverager(metrics.Metric): 6 | def __init__(self, output_transform=lambda x: x, device=None): 7 | super().__init__(output_transform=output_transform, device=device) 8 | 9 | def reset(self): 10 | self.count = 0 11 | self.summation = defaultdict(int) 12 | 13 | def update(self, output): 14 | losses, batch_size = output 15 | 16 | self.count += batch_size 17 | for key, value in losses.items(): 18 | self.summation[key] += batch_size * value 19 | 20 | def compute(self): 21 | results = {k: v / self.count for k, v in self.summation.items()} 22 | return results 23 | -------------------------------------------------------------------------------- /object_detection/dataloaders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import functools 3 | 4 | from torch.utils import data 5 | 6 | 7 | def collate_fn(batch): 8 | return tuple(zip(*batch)) 9 | 10 | 11 | def create_data_loaders(train_dataset: data.Dataset, val_dataset: data.Dataset, num_workers: int, batch_size: int): 12 | logging.info(f'creating dataloaders with {num_workers} workers and a batch-size of {batch_size}') 13 | fn_dataloader = functools.partial( 14 | data.DataLoader, 15 | batch_size=batch_size, 16 | num_workers=num_workers, 17 | collate_fn=collate_fn, 18 | pin_memory=True, 19 | ) 20 | 21 | train_loader = fn_dataloader(train_dataset, shuffle=True) 22 | 23 | train_metrics_sampler = data.RandomSampler(train_dataset, replacement=True, num_samples=len(val_dataset)) 24 | train_metrics_loader = fn_dataloader(train_dataset, sampler=train_metrics_sampler) 25 | 26 | val_metrics_loader = fn_dataloader(val_dataset) 27 | 28 | return train_loader, train_metrics_loader, val_metrics_loader 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Will Brennan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /check_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import cv2 5 | 6 | from object_detection import LabelMeDataset 7 | from object_detection import draw_results 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--dataset', type=str, required=True) 13 | parser.add_argument('--use-augmentation', action='store_true') 14 | parser.add_argument('--debug', action='store_true') 15 | return parser.parse_args() 16 | 17 | 18 | def log_stats(name, data): 19 | data_type = data.dtype 20 | data = data.float() 21 | logging.info(f'{name} - {data.shape} - {data_type} - min: {data.min()} mean: {data.mean()} max: {data.max()}') 22 | 23 | 24 | if __name__ == '__main__': 25 | args = parse_args() 26 | level = logging.DEBUG if args.debug else logging.INFO 27 | logging.basicConfig(level=level) 28 | 29 | dataset = LabelMeDataset(args.dataset, args.use_augmentation) 30 | 31 | num_samples = len(dataset) 32 | for idx in range(num_samples): 33 | logging.info(f'showing {(idx + 1)} of {num_samples} samples') 34 | image, target = dataset[idx] 35 | 36 | for k, v in target.items(): 37 | log_stats(k, v) 38 | 39 | result_image = draw_results(image, target, categories=dataset.categories) 40 | cv2.imshow('result', result_image) 41 | 42 | if cv2.waitKey(0) == ord('q'): 43 | logging.info('exiting...') 44 | exit() 45 | -------------------------------------------------------------------------------- /object_detection/visualisation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from typing import Dict 3 | import itertools 4 | 5 | import torch 6 | import cv2 7 | import numpy 8 | 9 | 10 | def draw_results(image: torch.Tensor, target: Dict[str, torch.Tensor], categories: List[str]): 11 | image = (255 * image).to(torch.uint8).cpu().numpy() 12 | image = numpy.transpose(image, (1, 2, 0)) 13 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 14 | 15 | colours = ( 16 | (0, 0, 255), (0, 255, 0), (255, 0, 0), (255, 255, 0), (0, 255, 255), (255, 0, 255), (0, 128, 255), 17 | (0, 255, 128), (128, 0, 255) 18 | ) 19 | 20 | for label, mask in zip(target['labels'], target['masks']): 21 | label = label.item() 22 | mask = mask.cpu().bool().numpy() 23 | 24 | category = categories[label] 25 | colour = colours[label % len(colours)] 26 | 27 | image[mask] = 0.5 * image[mask] + 0.5 * numpy.array(colour) 28 | 29 | for label, bbox in zip(target['labels'], target['boxes']): 30 | label = label.item() 31 | 32 | category = categories[label] 33 | colour = colours[label % len(colours)] 34 | 35 | bbox = bbox.cpu().numpy() 36 | bbox = numpy.round(bbox).astype(int).tolist() 37 | bbox_tl = tuple(bbox[:2]) 38 | bbox_br = tuple(bbox[2:]) 39 | cv2.rectangle(image, bbox_tl, bbox_br, colour, 3) 40 | 41 | text_point = (bbox_tl[0], bbox_tl[1] - 10) 42 | cv2.putText(image, category, text_point, 
cv2.FONT_HERSHEY_PLAIN, 2, colour) 43 | 44 | return image 45 | -------------------------------------------------------------------------------- /object_detection/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch import nn 5 | from torchvision.models import detection 6 | 7 | 8 | class MaskRCNN(nn.Module): 9 | @staticmethod 10 | def load(state_dict): 11 | # todo(will.brennan) - improve this... might want to save a categories file with this instead 12 | category_prefix = '_categories.' 13 | categories = [k for k in state_dict.keys() if k.startswith(category_prefix)] 14 | categories = [k[len(category_prefix):] for k in categories] 15 | 16 | model = MaskRCNN(categories) 17 | model.load_state_dict(state_dict) 18 | return model 19 | 20 | def __init__(self, categories): 21 | super().__init__() 22 | logging.info(f'creating model with categories: {categories}') 23 | 24 | # todo(will.brennan) - find a nicer way of saving the categories in the state dict... 25 | self._categories = nn.ParameterDict({i: nn.Parameter(torch.Tensor(0)) for i in categories}) 26 | num_categories = len(self._categories) 27 | 28 | self.model = detection.maskrcnn_resnet50_fpn(pretrained=True) 29 | 30 | logging.debug('changing num_categories for bbox predictor') 31 | 32 | in_features = self.model.roi_heads.box_predictor.cls_score.in_features 33 | self.model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(in_features, num_categories) 34 | 35 | logging.debug('changing num_categories for mask predictor') 36 | 37 | in_features_mask = self.model.roi_heads.mask_predictor.conv5_mask.in_channels 38 | self.model.roi_heads.mask_predictor = detection.mask_rcnn.MaskRCNNPredictor( 39 | in_features_mask, 256, num_categories 40 | ) 41 | 42 | @property 43 | def categories(self): 44 | return list(self._categories.keys()) 45 | 46 | def forward(self, *args, **kwargs): 47 | return self.model(*args, **kwargs) 48 | 49 | 50 | def filter_by_threshold(result, bbox_thresh: float, mask_thresh: float): 51 | scores_mask = result['scores'] > bbox_thresh 52 | result = {k: v[scores_mask] for k, v in result.items()} 53 | 54 | result['masks'] = result['masks'][:, 0] >= mask_thresh 55 | 56 | return result -------------------------------------------------------------------------------- /evaluate_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pathlib 4 | import functools 5 | 6 | import cv2 7 | import torch 8 | from torchvision.transforms import functional as F 9 | 10 | from object_detection import MaskRCNN 11 | from object_detection import filter_by_threshold 12 | from object_detection import draw_results 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--images', type=str, required=True) 18 | parser.add_argument('--model', type=str, required=True) 19 | 20 | parser.add_argument('--threshold', type=float, default=0.5) 21 | 22 | parser.add_argument('--save', action='store_true') 23 | parser.add_argument('--display', action='store_true') 24 | 25 | return parser.parse_args() 26 | 27 | 28 | def find_files(dir_path: pathlib.Path, file_exts): 29 | assert dir_path.exists() 30 | assert dir_path.is_dir() 31 | 32 | for file_ext in file_exts: 33 | yield from dir_path.rglob(f'*{file_ext}') 34 | 35 | 36 | if __name__ == '__main__': 37 | logging.basicConfig(level=logging.INFO) 38 | args = parse_args() 39 | 40 | device = 
'cuda' if torch.cuda.is_available() else 'cpu' 41 | logging.info(f'running inference on {device}') 42 | 43 | assert args.display or args.save 44 | 45 | logging.info(f'loading model from {args.model}') 46 | model = MaskRCNN.load(torch.load(args.model, map_location=device)) 47 | model.to(device).eval() 48 | 49 | image_dir = pathlib.Path(args.images) 50 | 51 | fn_filter = functools.partial(filter_by_threshold, bbox_thresh=args.threshold, mask_thresh=args.threshold) 52 | 53 | for image_file in find_files(image_dir, ['.png', '.jpg', '.jpeg']): 54 | logging.info(f'finding objects in {image_file} with threshold of {args.threshold}') 55 | 56 | image = cv2.imread(str(image_file)) 57 | assert image is not None 58 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 59 | 60 | with torch.no_grad(): 61 | image = F.to_tensor(image) 62 | image = image.to(device).unsqueeze(0) 63 | 64 | results = model(image) 65 | results = [fn_filter(i) for i in results] 66 | 67 | image = draw_results(image[0], results[0], categories=model.categories) 68 | 69 | if args.save: 70 | output_name = f'results_{image_file.name}' 71 | logging.info(f'writing output to {output_name}') 72 | cv2.imwrite(str(output_name), image) 73 | 74 | if args.display: 75 | cv2.imshow('image', image) 76 | 77 | if cv2.waitKey(0) == ord('q'): 78 | logging.info('exiting...') 79 | exit() 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # machine learning 141 | *.png 142 | *.jpg 143 | *.jpeg 144 | *.json 145 | *.pth 146 | events.* -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import torch 5 | from torch import nn 6 | from torch import optim 7 | from torch.utils import data 8 | from ignite import engine 9 | from torch.utils import tensorboard 10 | 11 | from object_detection import LabelMeDataset 12 | from object_detection import create_data_loaders 13 | from object_detection import MaskRCNN 14 | from object_detection import create_mask_rcnn_trainer 15 | from object_detection import create_mask_rcnn_evaluator 16 | from object_detection import LossAverager 17 | 18 | from object_detection import attach_lr_scheduler 19 | from object_detection import attach_training_logger 20 | from object_detection import attach_model_checkpoint 21 | from object_detection import attach_metric_logger 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--train', type=str, required=True) 27 | parser.add_argument('--val', type=str, required=True) 28 | parser.add_argument('--model-tag', type=str, required=True) 29 | 30 | parser.add_argument('--debug', action='store_true') 31 | 32 | parser.add_argument('--num-workers', type=int, default=16) 33 | parser.add_argument('--initial-lr', type=float, default=1e-4) 34 | parser.add_argument('--num-epochs', type=int, default=30) 35 | parser.add_argument('--batch-size', type=int, default=2) 36 | 37 | return parser.parse_args() 38 | 39 | 40 | if __name__ == '__main__': 41 | args = parse_args() 42 | logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO) 43 | 44 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | logging.info(f'running training on {device}') 46 | 47 | logging.info('creating dataset and data loaders') 48 | 49 | # assert args.train != args.val 50 | train_dataset = LabelMeDataset(args.train, use_augmentation=True) 51 | val_dataset = LabelMeDataset(args.val, use_augmentation=False) 52 | assert train_dataset.categories == val_dataset.categories 53 | 54 | train_loader, train_metrics_loader, val_metrics_loader = create_data_loaders( 55 | train_dataset=train_dataset, 56 | val_dataset=val_dataset, 57 | num_workers=args.num_workers, 58 | batch_size=args.batch_size, 59 | ) 60 | 61 | logging.info(f'creating model and optimizer with initial lr of {args.initial_lr}') 62 | model = MaskRCNN(train_dataset.categories) 63 | model = nn.DataParallel(model).to(device) 64 | optimizer = optim.RMSprop(params=[p for p in model.parameters() if p.requires_grad], lr=args.initial_lr) 65 | 66 | logging.info('creating trainer and evaluator engines') 67 | trainer = 
create_mask_rcnn_trainer(model=model, optimizer=optimizer, device=device, non_blocking=True) 68 | # note(will.brennan) - our evaluator just reports losses! we only want to see if its overfitting! 69 | evaluator = create_mask_rcnn_evaluator( 70 | model, metrics={ 71 | 'losses': LossAverager(device=device), 72 | }, device=device, non_blocking=True 73 | ) 74 | 75 | logging.info(f'creating summary writer with tag {args.model_tag}') 76 | writer = tensorboard.SummaryWriter(log_dir=f'logs/{args.model_tag}') 77 | 78 | logging.info('attaching lr scheduler') 79 | lr_scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 80 | attach_lr_scheduler(trainer, lr_scheduler, writer) 81 | 82 | logging.info('attaching event driven calls') 83 | attach_model_checkpoint(trainer, {args.model_tag: model.module}) 84 | attach_training_logger(trainer, writer=writer) 85 | 86 | attach_metric_logger(trainer, evaluator, 'train', train_metrics_loader, writer) 87 | attach_metric_logger(trainer, evaluator, 'val', val_metrics_loader, writer) 88 | 89 | logging.info('training...') 90 | trainer.run(train_loader, max_epochs=args.num_epochs) 91 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## Object Detection 2 | This project lets you fine-tune Mask R-CNN on masks annotated with labelme, so you can train it on any categories you want to annotate! It comes with several pretrained models trained on either custom datasets or subsets of COCO. 3 | 4 | ## Getting Started 5 | The pretrained models are stored in the repo with git-lfs; when you clone, make sure you've pulled the files by calling, 6 | 7 | ```bash 8 | git lfs pull 9 | ``` 10 | or by downloading them from GitHub directly. This project uses conda to manage its environment; once conda is installed, create the environment and activate it, 11 | ```bash 12 | conda env create -f enviroment.yml 13 | conda activate object_detection 14 | ``` 15 | On Windows, PowerShell needs to be initialised and the execution policy needs to be modified. 16 | ```bash 17 | conda init powershell 18 | Set-ExecutionPolicy -ExecutionPolicy RemoteSigned -Scope CurrentUser 19 | ``` 20 | 21 | ## Pre-Trained Projects 22 | This project comes bundled with several pretrained models, which can be found in the `pretrained` directory. To infer objects and instance masks on your images run `evaluate_images.py`; detections can be tightened or loosened with `--threshold` (default 0.5). 23 | ```bash 24 | # to display the output 25 | python evaluate_images.py --images ~/Pictures/ --model pretrained/model_mask_rcnn_cat_bird_30.pth --display 26 | # to save the output 27 | python evaluate_images.py --images ~/Pictures/ --model pretrained/model_mask_rcnn_cat_bird_30.pth --save 28 | ``` 29 | 30 | ### Pizza Topping Segmentation 31 | This was trained on a custom dataset of 89 images taken from COCO, to which pizza-topping annotations were added. There are very few images for each type of topping, so this model performs very badly and needs quite a few more images to behave well! 32 | 33 | - 'chilli', 'ham', 'jalapenos', 'mozzarella', 'mushrooms', 'olive', 'pepperoni', 'pineapple', 'salad', 'tomato' 34 | 35 | ![Pizza Toppings](https://raw.githubusercontent.com/WillBrennan/ObjectDetection/master/pretrained/pizza_topping_example.jpg) 36 | 37 | ### Cat and Bird Detection 38 | Annotated images of cats and birds were extracted from COCO using the `extract_from_coco` script and used for training. 39 | 40 | - cat, bird 41 | 42 | ![Demo on Cat & Birds](https://raw.githubusercontent.com/WillBrennan/ObjectDetection/master/pretrained/cat_examples.jpg) 43 | 44 | ### Cutlery Detection 45 | Annotated images of knives, forks, spoons and other cutlery were extracted from COCO using the `extract_from_coco` script and used for training. 46 | 47 | - knife, bowl, cup, bottle, wine glass, fork, spoon, dining table 48 | 49 | ![Demo on Cutlery](https://raw.githubusercontent.com/WillBrennan/ObjectDetection/master/pretrained/cutlery_example.jpg) 50 | 51 | ## Training New Projects 52 | To train a new project, you can annotate your own images with labelme; to launch it, run, 53 | 54 | ```bash 55 | labelme 56 | ``` 57 | and start annotating your images! You'll need a couple of hundred annotated images. Alternatively, if your category is already in COCO, you can run the conversion tool to create labelme annotations from the COCO ones. 58 | 59 | ```bash 60 | python extract_from_coco.py --images ~/datasets/coco/val2017 --annotations ~/datasets/coco/annotations/instances_val2017.json --output ~/datasets/my_cat_images_val --categories cat 61 | ```
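Whichever route you take, each training image ends up next to a labelme JSON file with the same stem. A minimal sketch of the structure `LabelMeDataset` expects (the field values here are invented for illustration):

```json
{
    "version": "4.2.10",
    "flags": {},
    "shapes": [
        {
            "label": "cat",
            "points": [[101.0, 152.0], [320.0, 148.0], [310.0, 400.0]],
            "group_id": null,
            "shape_type": "polygon",
            "flags": {}
        }
    ],
    "imagePath": "my_cat_01.jpg",
    "imageHeight": 480,
    "imageWidth": 640,
    "imageData": "<base64-encoded image bytes>"
}
```

`LabelMeDataset` asserts that all of these keys are present, drops shapes with fewer than three points, requires the remaining shapes to be polygons, and decodes the image from `imageData` rather than reading `imagePath` from disk.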
62 | 63 | Once you've got a directory of labelme annotations, you can check how the images will be shown to the model during training by running, 64 | 65 | ```bash 66 | python check_dataset.py --dataset ~/datasets/my_cat_images_val 67 | # to show our dataset with training augmentation 68 | python check_dataset.py --dataset ~/datasets/my_cat_images_val --use-augmentation 69 | ``` 70 | If you're happy with the images and how they'll appear in training, then train the model using, 71 | 72 | ```bash 73 | python train.py --train ~/datasets/my_cat_images_train --val ~/datasets/my_cat_images_val --model-tag mask_rcnn_cat 74 | ``` 75 | This may take some time depending on how many images you have. TensorBoard logs are available in the `logs` directory. 
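While training runs, the loss curves can be watched by pointing TensorBoard (installed as part of the conda environment) at that directory:

```bash
tensorboard --logdir logs
```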
To run your trained model on a directory of images run 76 | 77 | ```bash 78 | # to display the output 79 | python evaluate_images.py --images ~/Pictures/my_cat_imgs --model models/model_mask_rcnn_cat_30.pth --display 80 | # to save the output 81 | python evaluate_images.py --images ~/Pictures/my_cat_imgs --model models/model_mask_rcnn_cat_30.pth --save 82 | ``` 83 | -------------------------------------------------------------------------------- /extract_from_coco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pathlib 4 | import json 5 | import collections 6 | import multiprocessing as mp 7 | import functools 8 | import shutil 9 | import base64 10 | 11 | import cv2 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--images', type=str, required=True) 17 | parser.add_argument('--annotations', type=str, required=True) 18 | parser.add_argument('--output', type=str, required=True) 19 | 20 | parser.add_argument('--categories', type=str, nargs='+', required=True) 21 | 22 | parser.add_argument('--num-workers', type=int, default=None) 23 | 24 | return parser.parse_args() 25 | 26 | 27 | def images_with_categories(annotations, categories): 28 | all_categories = {i['name']: i['id'] for i in annotations['categories']} 29 | category_names = {i['id']: i['name'] for i in annotations['categories']} 30 | categories = [i for i in categories if i in all_categories] 31 | category_ids = [all_categories[i] for i in categories] 32 | 33 | logging.info(f'categories: {categories}') 34 | logging.info(f'all available categories: {all_categories.keys()}') 35 | 36 | anns_by_image = collections.defaultdict(list) 37 | 38 | for ann in annotations['annotations']: 39 | category_id = ann['category_id'] 40 | image_id = ann['image_id'] 41 | 42 | if category_id not in category_ids: 43 | continue 44 | 45 | ann['category'] = category_names[category_id] 46 | anns_by_image[image_id].append(ann) 47 | 48 | for image_info in annotations['images']: 49 | image_id = image_info['id'] 50 | if image_id in anns_by_image: 51 | anns_for_image = anns_by_image[image_id] 52 | yield image_info, anns_for_image 53 | 54 | 55 | def ann_to_shape(ann): 56 | points = ann['segmentation'][0] 57 | points = [[x, y] for x, y in zip(points[0::2], points[1::2])] 58 | 59 | return { 60 | 'label': ann['category'], 61 | 'points': points, 62 | 'group_id': None, 63 | 'shape_type': 'polygon', 64 | 'flags': {}, 65 | } 66 | 67 | 68 | def save_labelme(image_info, anns, images_dir: pathlib.Path, output_dir: pathlib.Path): 69 | image_path = images_dir / image_info['file_name'] 70 | output_image_path = output_dir / image_info['file_name'] 71 | 72 | logging.info(f'reading image from {image_path}') 73 | 74 | assert image_path.exists() 75 | shutil.copyfile(image_path, output_image_path) 76 | 77 | with open(image_path, 'rb') as image_file: 78 | image_data = base64.b64encode(image_file.read()).decode() 79 | 80 | # warning(will.brennan): 81 | # currently we're only handling 'segmentation' being points... 82 | # maybe use pycocotools despite it not working on windows... 
83 | anns = [ann for ann in anns if len(ann['segmentation']) >= 1 and isinstance(ann['segmentation'], list)]    # keep non-empty polygon lists; RLE dicts are dropped 84 | shapes = [ann_to_shape(ann) for ann in anns] 85 | 86 | labelme_data = { 87 | 'version': '4.2.10', 88 | 'flags': {}, 89 | 'shapes': shapes, 90 | 'imagePath': image_info['file_name'], 91 | 'imageHeight': image_info['height'], 92 | 'imageWidth': image_info['width'], 93 | 'imageData': image_data, 94 | } 95 | 96 | output_json_path = output_image_path.with_suffix('.json') 97 | with open(output_json_path, 'w') as json_file: 98 | json.dump(labelme_data, json_file) 99 | 100 | 101 | if __name__ == '__main__': 102 | # note(will.brennan) - not using pycocotools because the authors refuse to support windows 103 | args = parse_args() 104 | logging.basicConfig(level=logging.INFO) 105 | 106 | logging.info(f'using annotations from {args.annotations}') 107 | with open(args.annotations, 'r') as annotations_file: 108 | annotations = json.load(annotations_file) 109 | 110 | gn_anns = images_with_categories(annotations, args.categories) 111 | 112 | num_workers = mp.cpu_count() if args.num_workers is None else args.num_workers 113 | images_dir = pathlib.Path(args.images) 114 | output_dir = pathlib.Path(args.output) 115 | output_dir.mkdir(exist_ok=True, parents=True) 116 | logging.info(f'saving labelme data to {output_dir} with {num_workers} workers') 117 | 118 | fn_save_labelme = functools.partial(save_labelme, images_dir=images_dir, output_dir=output_dir) 119 | 120 | pool = mp.Pool(num_workers) 121 | res = pool.starmap_async(fn_save_labelme, gn_anns) 122 | res.get() 123 | -------------------------------------------------------------------------------- /object_detection/engines.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.utils import data 8 | from torch.utils import tensorboard 9 | from ignite import engine 10 | from ignite import metrics 11 | from ignite import handlers 12 | from ignite.contrib import handlers as contrib_handlers 13 | 14 | 15 | def create_mask_rcnn_trainer(model: nn.Module, optimizer: optim.Optimizer, device=None, non_blocking: bool = False): 16 | if device: 17 | model.to(device) 18 | 19 | fn_prepare_batch = lambda batch: engine._prepare_batch(batch, device=device, non_blocking=non_blocking) 20 | 21 | def _update(engine, batch): 22 | model.train() 23 | optimizer.zero_grad() 24 | 25 | image, targets = fn_prepare_batch(batch) 26 | losses = model(image, targets) 27 | 28 | loss = sum(loss for loss in losses.values()) 29 | 30 | loss.backward() 31 | optimizer.step() 32 | 33 | losses = {k: v.item() for k, v in losses.items()} 34 | losses['loss'] = loss.item() 35 | return losses 36 | 37 | return engine.Engine(_update) 38 | 39 | 40 | def create_mask_rcnn_evaluator(model: nn.Module, metrics, device=None, non_blocking: bool = False): 41 | if device: 42 | model.to(device) 43 | 44 | fn_prepare_batch = lambda batch: engine._prepare_batch(batch, device=device, non_blocking=non_blocking) 45 | 46 | def _update(engine, batch): 47 | # warning(will.brennan) - not putting model in eval mode because we want the losses! 48 | with torch.no_grad(): 49 | image, targets = fn_prepare_batch(batch) 50 | losses = model(image, targets) 51 | 52 | losses = {k: v.item() for k, v in losses.items()} 53 | losses['loss'] = sum(losses.values()) 54 | 55 | # note(will.brennan) - an ugly hack for metrics... 
56 | return (losses, len(image)) 57 | 58 | evaluator = engine.Engine(_update) 59 | 60 | for name, metric in metrics.items(): 61 | metric.attach(evaluator, name) 62 | 63 | return evaluator 64 | 65 | 66 | def attach_lr_scheduler( 67 | trainer: engine.Engine, 68 | lr_scheduler: optim.lr_scheduler._LRScheduler, 69 | writer: tensorboard.SummaryWriter, 70 | ): 71 | @trainer.on(engine.Events.EPOCH_COMPLETED) 72 | def update_lr(engine: engine.Engine): 73 | current_lr = lr_scheduler.get_last_lr()[0] 74 | logging.info(f'epoch: {engine.state.epoch} - current lr: {current_lr}') 75 | writer.add_scalar('learning_rate', current_lr, engine.state.epoch) 76 | 77 | lr_scheduler.step() 78 | 79 | 80 | def attach_training_logger( 81 | trainer: engine.Engine, 82 | writer: tensorboard.SummaryWriter, 83 | log_interval: int = 10, 84 | ): 85 | @trainer.on(engine.Events.ITERATION_COMPLETED) 86 | def log_training_loss(engine: engine.Engine): 87 | epoch_length = engine.state.epoch_length 88 | epoch = engine.state.epoch 89 | output = engine.state.output 90 | 91 | idx = engine.state.iteration 92 | idx_in_epoch = (engine.state.iteration - 1) % epoch_length + 1 93 | 94 | if idx_in_epoch % log_interval != 0: 95 | return 96 | 97 | msg = '' 98 | for name, value in output.items(): 99 | msg += f'{name}: {value:.4f} ' 100 | writer.add_scalar(f'training/{name}', value, idx) 101 | logging.info(f'epoch[{epoch}] - iteration[{idx_in_epoch}/{epoch_length}] ' + msg) 102 | 103 | 104 | def attach_metric_logger( 105 | trainer: engine.Engine, 106 | evaluator: engine.Engine, 107 | data_name: str, 108 | data_loader: data.DataLoader, 109 | writer: tensorboard.SummaryWriter, 110 | ): 111 | @trainer.on(engine.Events.EPOCH_COMPLETED) 112 | def log_metrics(engine): 113 | evaluator.run(data_loader) 114 | 115 | def _to_message(metrics): 116 | message = '' 117 | 118 | for metric_name, metric_value in metrics.items(): 119 | if isinstance(metric_value, dict): 120 | message += _to_message(metric_value) 121 | else: 122 | writer.add_scalar(f'{data_name}/mean_{metric_name}', metric_value, engine.state.epoch) 123 | message += f'{metric_name}: {metric_value:.3f} ' 124 | 125 | return message 126 | 127 | message = _to_message(evaluator.state.metrics) 128 | logging.info(message) 129 | 130 | 131 | def attach_model_checkpoint(trainer: engine.Engine, models: Dict[str, nn.Module]): 132 | def to_epoch(trainer: engine.Engine, event_name: str): 133 | return trainer.state.epoch 134 | 135 | handler = handlers.ModelCheckpoint( 136 | './models', 137 | 'model', 138 | create_dir=True, 139 | require_empty=False, 140 | n_saved=None, 141 | global_step_transform=to_epoch, 142 | ) 143 | trainer.add_event_handler(engine.Events.EPOCH_COMPLETED, handler, models) 144 | -------------------------------------------------------------------------------- /object_detection/dataset.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import base64 4 | import pathlib 5 | import logging 6 | import collections 7 | 8 | import cv2 9 | import numpy 10 | import torch 11 | import torch.utils.data as data 12 | import torchvision.transforms as transforms 13 | import albumentations as alb 14 | 15 | 16 | def _load_image(image_data_b64): 17 | # note(will.brennan) - from https://github.com/wkentaro/labelme/blob/f20a9425698f1ac9b48b622e0140016e9b73601a/labelme/utils/image.py#L17 18 | image_data = base64.b64decode(image_data_b64) 19 | image_data = numpy.frombuffer(image_data, dtype=numpy.uint8)    # frombuffer: fromstring is deprecated 20 | image = cv2.imdecode(image_data, cv2.IMREAD_COLOR) 
21 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 22 | return image 23 | 24 | 25 | def _create_masks(shapes, image_width: int, image_height: int): 26 | for shape in shapes: 27 | mask = numpy.zeros((image_height, image_width), dtype=numpy.uint8) 28 | points = numpy.array(shape['points']).reshape((-1, 1, 2)) 29 | points = numpy.round(points).astype(numpy.int32) 30 | 31 | cv2.fillPoly(mask, [points], (1, )) 32 | mask = mask.astype(numpy.uint8) 33 | yield mask 34 | 35 | 36 | def _create_bboxs(shapes): 37 | for shape in shapes: 38 | points = numpy.array(shape['points']) 39 | xmin, ymin = numpy.min(points, axis=0) 40 | xmax, ymax = numpy.max(points, axis=0) 41 | 42 | yield [xmin, ymin, xmax, ymax] 43 | 44 | 45 | class ToTensor(alb.ImageOnlyTransform): 46 | def __init__(self): 47 | super().__init__(always_apply=True) 48 | 49 | def apply(self, image, **params): 50 | return transforms.ToTensor()(image) 51 | 52 | def get_params(self): 53 | return {} 54 | 55 | 56 | class LabelMeDataset(data.Dataset): 57 | def __init__(self, directory: str, use_augmentation: bool): 58 | self.directory = pathlib.Path(directory) 59 | self.use_augmentation = use_augmentation 60 | assert self.directory.exists() 61 | assert self.directory.is_dir() 62 | 63 | self.labelme_paths = [] 64 | self.categories = collections.defaultdict(list) 65 | 66 | for labelme_path in self.directory.rglob('*.json'): 67 | with open(labelme_path, 'r') as labelme_file: 68 | labelme_json = json.load(labelme_file) 69 | 70 | required_keys = ['version', 'flags', 'shapes', 'imagePath', 'imageData', 'imageHeight', 'imageWidth'] 71 | assert all(key in labelme_json for key in required_keys), (required_keys, labelme_json.keys()) 72 | 73 | self.labelme_paths += [labelme_path] 74 | 75 | for shape in labelme_json['shapes']: 76 | label = shape['label'] 77 | self.categories[label] += [labelme_path] 78 | 79 | for category, paths in self.categories.items(): 80 | for path in paths: 81 | logging.debug(f'{category} - {path}') 82 | self.categories = sorted(list(self.categories.keys())) 83 | 84 | logging.info(f'loaded {len(self)} annotations from {self.directory}') 85 | logging.info(f'use augmentation: {self.use_augmentation}') 86 | logging.info(f'categories: {self.categories}') 87 | 88 | aug_transforms = [ToTensor()] 89 | if self.use_augmentation: 90 | aug_transforms = [ 91 | alb.HueSaturationValue(always_apply=True), 92 | alb.RandomBrightnessContrast(always_apply=True), 93 | alb.HorizontalFlip(), 94 | alb.RandomGamma(always_apply=True), 95 | ] + aug_transforms 96 | bbox_params = alb.BboxParams(format='pascal_voc', min_area=0.0, min_visibility=0.0, label_fields=['labels']) 97 | self.transforms = alb.Compose(transforms=aug_transforms, bbox_params=bbox_params) 98 | 99 | def __len__(self): 100 | return len(self.labelme_paths) 101 | 102 | def __getitem__(self, idx: int): 103 | labelme_path = self.labelme_paths[idx] 104 | logging.debug('parsing labelme json') 105 | 106 | with open(labelme_path, 'r') as labelme_file: 107 | labelme_json = json.load(labelme_file) 108 | 109 | image_width = labelme_json['imageWidth'] 110 | image_height = labelme_json['imageHeight'] 111 | 112 | image = _load_image(labelme_json['imageData']) 113 | assert image.shape == (image_height, image_width, 3) 114 | 115 | labelme_shapes = labelme_json['shapes'] 116 | labelme_shapes = [i for i in labelme_json['shapes'] if len(i['points']) > 2] 117 | assert all(i['shape_type'] == 'polygon' for i in labelme_shapes) 118 | 119 | masks = list(_create_masks(labelme_shapes, image_width, image_height)) 120 | 121 | 
bboxes = list(_create_bboxs(labelme_shapes)) 122 | 123 | labels = [self.categories.index(shape['label']) for shape in labelme_shapes] 124 | 125 | logging.debug('applying transforms to image and targets') 126 | 127 | target = self.transforms(image=image, bboxes=bboxes, labels=labels, masks=masks) 128 | 129 | image = target.pop('image') 130 | 131 | target['masks'] = torch.as_tensor(numpy.stack(target['masks']), dtype=torch.uint8) 132 | target['labels'] = torch.as_tensor(target['labels'], dtype=torch.int64) 133 | target['iscrowd'] = torch.zeros_like(target['labels'], dtype=torch.int64) 134 | target['image_id'] = torch.tensor([idx], dtype=torch.int64) 135 | 136 | bboxes = torch.as_tensor(target.pop('bboxes'), dtype=torch.float32) 137 | 138 | target['area'] = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0]) 139 | target['boxes'] = bboxes 140 | 141 | return image, target 142 | --------------------------------------------------------------------------------
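The conda environment pins pytest but the repository ships no tests. A minimal smoke-test sketch for the two pure helpers, `collate_fn` and `filter_by_threshold`, could look like this (the file name is hypothetical, not part of the repo):

```python
# test_object_detection.py - smoke tests for the pure helper functions
import torch

from object_detection import filter_by_threshold
from object_detection.dataloaders import collate_fn


def test_collate_fn_groups_images_and_targets():
    # a batch of (image, target) pairs becomes a pair of tuples
    batch = [('image_0', {'labels': 0}), ('image_1', {'labels': 1})]
    images, targets = collate_fn(batch)
    assert images == ('image_0', 'image_1')
    assert targets == ({'labels': 0}, {'labels': 1})


def test_filter_by_threshold_drops_low_scores_and_binarises_masks():
    result = {
        'scores': torch.tensor([0.9, 0.2]),
        'boxes': torch.zeros(2, 4),
        'labels': torch.tensor([1, 2]),
        'masks': torch.rand(2, 1, 16, 16),
    }
    result = filter_by_threshold(result, bbox_thresh=0.5, mask_thresh=0.5)
    assert len(result['scores']) == 1    # the 0.2 detection is gone
    assert result['masks'].dtype == torch.bool    # masks come back binarised
```

Running `pytest` from the repo root would pick these up; neither test needs a GPU or any annotated data.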