├── .dockerignore ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile.sku110kserving ├── Dockerfile.sku110ktraining ├── LICENSE ├── README.md ├── build_and_push.sh ├── container_serving └── predict_sku110k.py ├── container_training └── sku-110k │ ├── __init__.py │ ├── datasets │ └── catalog.py │ ├── engine │ ├── __init__.py │ ├── custom_trainer.py │ └── hooks.py │ ├── evaluation │ ├── __init__.py │ ├── coco.py │ ├── factory.py │ ├── run.py │ └── utils.py │ └── training.py └── d2_custom_sku110k.ipynb /.dockerignore: -------------------------------------------------------------------------------- 1 | cache 2 | .git 3 | .DS_Store 4 | .vscode 5 | .idea 6 | .pylintrc 7 | .env 8 | ipython_config.py 9 | profile_default/ 10 | .ipynb_checkpoints 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *checkpoints* 6 | *.pkl 7 | 8 | # Local config 9 | .DS_Store 10 | .vscode 11 | .idea 12 | .pylintrc 13 | .env 14 | ipython_config.py 15 | profile_default/ 16 | 17 | # Notebook 18 | .ipynb_checkpoints 19 | 20 | # Data 21 | cache -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. 
Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /Dockerfile.sku110kserving: -------------------------------------------------------------------------------- 1 | # Build an image of Detectron2 with Sagemaker Multi Model Server: https://github.com/awslabs/multi-model-server 2 | 3 | # using Sagemaker PyTorch container as base image 4 | # from https://github.com/aws/sagemaker-pytorch-serving-container/ 5 | 6 | ARG REGION 7 | FROM 763104351884.dkr.ecr.$REGION.amazonaws.com/pytorch-inference:1.5.1-gpu-py36-cu101-ubuntu16.04 8 | LABEL author="pirrera@amazon.com" 9 | 10 | ############# Installing latest builds ############ 11 | RUN pip install --upgrade torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 12 | 13 | ENV FORCE_CUDA="1" 14 | # Build D2 only for Turing (G4) and Volta (P3) architectures. Use P3 for batch transforms and G4 for inference on endpoints 15 | ENV TORCH_CUDA_ARCH_LIST="Turing;Volta" 16 | 17 | # Install Detectron2 18 | RUN pip install \ 19 | --no-cache-dir pycocotools~=2.0.0 \ 20 | --no-cache-dir https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.6/detectron2-0.4%2Bcu101-cp36-cp36m-linux_x86_64.whl 21 | 22 | # Set a fixed model cache directory. 
Detectron2 requirement 23 | ENV FVCORE_CACHE="/tmp" 24 | -------------------------------------------------------------------------------- /Dockerfile.sku110ktraining: -------------------------------------------------------------------------------- 1 | # Build an image of Detectron2 that can do distributed training on Amazon SageMaker 2 | 3 | # using the SageMaker PyTorch container as base image 4 | # from https://github.com/aws/sagemaker-pytorch-container 5 | 6 | ARG REGION 7 | FROM 763104351884.dkr.ecr.$REGION.amazonaws.com/pytorch-training:1.6.0-gpu-py36-cu101-ubuntu16.04 8 | LABEL author="pirrera@amazon.com" 9 | 10 | ############# Detectron2 pre-built binaries Pytorch default install ############ 11 | RUN pip install --upgrade torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 12 | 13 | ############# Detectron2 section ############## 14 | RUN pip install \ 15 | --no-cache-dir pycocotools~=2.0.0 \ 16 | --no-cache-dir https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.6/detectron2-0.4%2Bcu101-cp36-cp36m-linux_x86_64.whl 17 | 18 | ENV FORCE_CUDA="1" 19 | # Build D2 only for Volta architecture - V100 chips (ml.p3 AWS instances) 20 | ENV TORCH_CUDA_ARCH_LIST="Volta" 21 | 22 | # Set a fixed model cache directory. Detectron2 requirement 23 | ENV FVCORE_CACHE="/tmp" 24 | 25 | ############# SageMaker section ############## 26 | 27 | COPY container_training/sku-110k /opt/ml/code 28 | WORKDIR /opt/ml/code 29 | 30 | ENV SAGEMAKER_SUBMIT_DIRECTORY /opt/ml/code 31 | ENV SAGEMAKER_PROGRAM training.py 32 | 33 | WORKDIR / 34 | 35 | # Starts PyTorch distributed framework 36 | ENTRYPOINT ["bash", "-m", "start_with_right_hostname.sh"] 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 10 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 11 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 12 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 13 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 14 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 15 | 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Object Detection with Detectron2 on Amazon SageMaker 2 | 3 | ### Overview 4 | 5 | In this repository, we use [Amazon SageMaker](https://aws.amazon.com/sagemaker/) to build, train and deploy [Faster-RCNN](https://arxiv.org/abs/1506.01497) and [RetinaNet](https://arxiv.org/abs/1708.02002) models using [Detectron2](https://github.com/facebookresearch/detectron2). 6 | Detectron2 is an open-source project released by Facebook AI Research and built on top of the PyTorch deep learning framework. 
Detectron2 makes it easy to build, train, and deploy state-of-the-art object detection algorithms. Moreover, Detectron2's design makes it easy to implement cutting-edge research projects without having to fork the entire codebase. Detectron2 also provides a [Model Zoo](https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md), a collection of pre-trained detection models that we can use to accelerate our endeavour. 7 | 8 | This repository shows how to do the following: 9 | 10 | * Build Detectron2 Docker images and push them to [Amazon ECR](https://aws.amazon.com/ecr/) to run training and inference jobs on Amazon SageMaker. 11 | * Register a dataset in the Detectron2 catalog from annotations in augmented manifest files, the output format of [Amazon SageMaker Ground Truth](https://aws.amazon.com/sagemaker/groundtruth/) annotation jobs. 12 | * Run a SageMaker Training job to fine-tune pre-trained model weights on a custom dataset. 13 | * Configure SageMaker Hyperparameter Optimization jobs to tune hyperparameters. 14 | * Run a SageMaker Batch Transform job to predict bounding boxes on a large batch of images. 15 | 16 | ### Get Started 17 | 18 | Create a SageMaker notebook instance with an EBS volume of at least 30 GB and add the following lines to the **start notebook** section of your lifecycle configuration: 19 | 20 | ``` 21 | service docker stop 22 | sudo mv /var/lib/docker /home/ec2-user/SageMaker/docker 23 | sudo ln -s /home/ec2-user/SageMaker/docker /var/lib/docker 24 | service docker start 25 | ``` 26 | 27 | This ensures that Docker builds images in a folder that is mounted on the EBS volume. Once the instance is running, open JupyterLab, launch a terminal, and clone this repository: 28 | 29 | ``` 30 | cd SageMaker 31 | git clone https://github.com/aws-samples/amazon-sagemaker-pytorch-detectron2.git 32 | cd amazon-sagemaker-pytorch-detectron2 33 | ``` 34 | Open the [notebook](d2_custom_sku110k.ipynb). Follow the instructions in the notebook and use the `conda_pytorch_p36` kernel to execute the code cells. 35 | 36 | You can also test the content of this repository on an EC2 instance running the [AWS Deep Learning AMI](https://docs.aws.amazon.com/dlami/latest/devguide/what-is-dlami.html). 37 | ### Instructions 38 | 39 | You will train a Detectron2 object detection model to recognize objects in densely packed scenes, using the SKU-110k dataset. Be aware that the authors of the dataset provided it solely for academic and non-commercial purposes. Please refer to the following [paper](https://arxiv.org/abs/1904.00853) for further details on the dataset: 40 | 41 | ``` 42 | @inproceedings{goldman2019dense, 43 | author = {Eran Goldman and Roei Herzig and Aviv Eisenschtat and Jacob Goldberger and Tal Hassner}, 44 | title = {Precise Detection in Densely Packed Scenes}, 45 | booktitle = {Proc. Conf. Comput. Vision Pattern Recognition (CVPR)}, 46 | year = {2019} 47 | } 48 | ``` 49 | 50 | If you want details on the code used for [training](container_training/sku-110k) and [prediction](container_serving), please refer to the code documentation in the respective source directories. 51 | 52 | ## Security 53 | 54 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 55 | 56 | ## License 57 | 58 | This library is licensed under the MIT-0 License. See the LICENSE file. 
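The sketch below shows how the pieces described above can be wired together from the notebook once the training image has been built and pushed with `build_and_push.sh`. It is an illustration rather than part of the repository: the ECR repository name, S3 prefixes, label attribute name, and class list are placeholders, while the hyperparameter names mirror the CLI of `container_training/sku-110k/training.py`.

```python
# Illustrative only: repository name, bucket prefix, label name and classes are placeholders.
import sagemaker
from sagemaker.estimator import Estimator

session = sagemaker.Session()
role = sagemaker.get_execution_role()
region = session.boto_region_name
account = session.boto_session.client("sts").get_caller_identity()["Account"]

# Image built beforehand with: ./build_and_push.sh d2-sku110k-training latest Dockerfile.sku110ktraining
training_image = f"{account}.dkr.ecr.{region}.amazonaws.com/d2-sku110k-training:latest"

bucket = session.default_bucket()   # assumes images and augmented manifests were staged here
prefix = "detectron2/sku-110k"

estimator = Estimator(
    image_uri=training_image,
    role=role,
    instance_count=1,
    instance_type="ml.p3.8xlarge",   # training.py requires CUDA devices
    base_job_name="d2-sku110k",
    hyperparameters={
        "model-type": "faster_rcnn",
        "backbone": "R_50_FPN",
        "lr-schedule": 1,
        "num-iter": 3000,
        "classes": "['object']",     # parsed with ast.literal_eval by training.py
        "dataset-name": "sku110k",
        "label-name": "sku",         # label attribute name of the Ground Truth job (placeholder)
        "evaluation-type": "fast",
    },
)

# Channel names match the SM_CHANNEL_* variables read by training.py
estimator.fit(
    {
        "training": f"s3://{bucket}/{prefix}/images/training",
        "validation": f"s3://{bucket}/{prefix}/images/validation",
        "test": f"s3://{bucket}/{prefix}/images/test",
        "annotation": f"s3://{bucket}/{prefix}/annotations",
    }
)
```

The same pattern applies to the serving image built from `Dockerfile.sku110kserving`, which is used for batch transform jobs and endpoint deployment.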
59 | 60 | -------------------------------------------------------------------------------- /build_and_push.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script shows how to build the Docker image and push it to ECR to be ready for use 4 | # by SageMaker. 5 | 6 | # There are 3 arguments in this script: 7 | # - image - required, this will be used as the image on the local machine and combined with the account and region to form the repository name for ECR; 8 | # - tag - optional, if provided, it will be used as ":tag" of your image; otherwise, ":latest" will be used; 9 | # - Dockerfile - optional, if provided, then docker will try to build the image from the provided dockerfile (e.g. "Dockerfile.serving"); otherwise, the default "Dockerfile" will be used. 10 | # Usage examples: 11 | # 1. "./build_and_push.sh d2-sm-coco-serving debug Dockerfile.serving" 12 | # 2. "./build_and_push.sh d2-sm-coco v2" 13 | 14 | image=$1 15 | tag=$2 16 | dockerfile=$3 17 | 18 | if [ "$image" == "" ] 19 | then 20 | echo "Usage: $0 <image-name> [tag] [dockerfile]" 21 | exit 1 22 | fi 23 | 24 | # Get the account number associated with the current IAM credentials 25 | account=$(aws sts get-caller-identity --query Account --output text) 26 | 27 | if [ $? -ne 0 ] 28 | then 29 | exit 255 30 | fi 31 | 32 | 33 | # Get the region defined in the current configuration 34 | region=$(aws configure get region) 35 | 36 | echo "Working in region $region" 37 | 38 | if [ "$tag" == "" ] 39 | then 40 | fullname="${account}.dkr.ecr.${region}.amazonaws.com/${image}:latest" 41 | else 42 | fullname="${account}.dkr.ecr.${region}.amazonaws.com/${image}:${tag}" 43 | fi 44 | 45 | # If the repository doesn't exist in ECR, create it. 46 | 47 | aws ecr describe-repositories --repository-names "${image}" > /dev/null 2>&1 48 | 49 | if [ $? -ne 0 ] 50 | then 51 | aws ecr create-repository --repository-name "${image}" > /dev/null 52 | fi 53 | 54 | # Get the login command from ECR and execute it directly 55 | $(aws ecr get-login --region ${region} --no-include-email) 56 | 57 | # Build the docker image locally with the image name and then push it to ECR 58 | # with the full name. 59 | 60 | if [ "$dockerfile" == "" ] 61 | then 62 | docker build -t ${image} . --build-arg REGION=${region} 63 | else 64 | docker build -t ${image} . 
-f ${dockerfile} --build-arg REGION=${region} 65 | fi 66 | 67 | docker tag ${image} ${fullname} 68 | docker push ${fullname} 69 | -------------------------------------------------------------------------------- /container_serving/predict_sku110k.py: -------------------------------------------------------------------------------- 1 | """Code used for sagemaker batch transform jobs""" 2 | from typing import BinaryIO, Mapping 3 | import json 4 | import logging 5 | import sys 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import cv2 10 | import torch 11 | 12 | from detectron2.engine import DefaultPredictor 13 | from detectron2.config import CfgNode 14 | 15 | ############## 16 | # Macros 17 | ############## 18 | 19 | LOGGER = logging.Logger("InferenceScript", level=logging.INFO) 20 | HANDLER = logging.StreamHandler(sys.stdout) 21 | HANDLER.setFormatter(logging.Formatter("%(levelname)s | %(name)s | %(message)s")) 22 | LOGGER.addHandler(HANDLER) 23 | 24 | ########## 25 | # Deploy 26 | ########## 27 | def _load_from_bytearray(request_body: BinaryIO) -> np.ndarray: 28 | npimg = np.frombuffer(request_body, np.uint8) 29 | return cv2.imdecode(npimg, cv2.IMREAD_COLOR) 30 | 31 | 32 | def model_fn(model_dir: str) -> DefaultPredictor: 33 | r"""Load trained model 34 | 35 | Parameters 36 | ---------- 37 | model_dir : str 38 | S3 location of the model directory 39 | 40 | Returns 41 | ------- 42 | DefaultPredictor 43 | PyTorch model created by using Detectron2 API 44 | """ 45 | path_cfg, path_model = None, None 46 | for p_file in Path(model_dir).iterdir(): 47 | if p_file.suffix == ".json": 48 | path_cfg = p_file 49 | if p_file.suffix == ".pth": 50 | path_model = p_file 51 | 52 | LOGGER.info(f"Using configuration specified in {path_cfg}") 53 | LOGGER.info(f"Using model saved at {path_model}") 54 | 55 | if path_model is None: 56 | err_msg = "Missing model PTH file" 57 | LOGGER.error(err_msg) 58 | raise RuntimeError(err_msg) 59 | if path_cfg is None: 60 | err_msg = "Missing configuration JSON file" 61 | LOGGER.error(err_msg) 62 | raise RuntimeError(err_msg) 63 | 64 | with open(str(path_cfg)) as fid: 65 | cfg = CfgNode(json.load(fid)) 66 | 67 | cfg.MODEL.WEIGHTS = str(path_model) 68 | cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 69 | 70 | return DefaultPredictor(cfg) 71 | 72 | 73 | def input_fn(request_body: BinaryIO, request_content_type: str) -> np.ndarray: 74 | r"""Parse input data 75 | 76 | Parameters 77 | ---------- 78 | request_body : BinaryIO 79 | encoded input image 80 | request_content_type : str 81 | type of content 82 | 83 | Returns 84 | ------- 85 | np.ndarray 86 | input image 87 | 88 | Raises 89 | ------ 90 | ValueError 91 | ValueError if the content type is not `application/x-image` 92 | """ 93 | if request_content_type == "application/x-image": 94 | np_image = _load_from_bytearray(request_body) 95 | else: 96 | err_msg = f"Type [{request_content_type}] not support this type yet" 97 | LOGGER.error(err_msg) 98 | raise ValueError(err_msg) 99 | return np_image 100 | 101 | 102 | def predict_fn(input_object: np.ndarray, predictor: DefaultPredictor) -> Mapping: 103 | r"""Run Detectron2 prediction 104 | 105 | Parameters 106 | ---------- 107 | input_object : np.ndarray 108 | input image 109 | predictor : DefaultPredictor 110 | Detectron2 default predictor (see Detectron2 documentation for details) 111 | 112 | Returns 113 | ------- 114 | Mapping 115 | a dictionary that contains: the image shape (`image_height`, `image_width`), the predicted 116 | bounding boxes in format 
x1y1x2y2 (`pred_boxes`), the confidence scores (`scores`) and the 117 | labels associated with the bounding boxes (`pred_boxes`) 118 | """ 119 | LOGGER.info(f"Prediction on image of shape {input_object.shape}") 120 | outputs = predictor(input_object) 121 | fmt_out = { 122 | "image_height": input_object.shape[0], 123 | "image_width": input_object.shape[1], 124 | "pred_boxes": outputs["instances"].pred_boxes.tensor.tolist(), 125 | "scores": outputs["instances"].scores.tolist(), 126 | "pred_classes": outputs["instances"].pred_classes.tolist(), 127 | } 128 | LOGGER.info(f"Number of detected boxes: {len(fmt_out['pred_boxes'])}") 129 | return fmt_out 130 | 131 | 132 | # pylint: disable=unused-argument 133 | def output_fn(predictions, response_content_type): 134 | r"""Serialize the prediction result into the desired response content type""" 135 | return json.dumps(predictions) 136 | -------------------------------------------------------------------------------- /container_training/sku-110k/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train FasterRCNN and Retinanet models on SKU-110k dataset: 3 | 4 | * Build deep learning models by using Detectron2 (https://detectron2.readthedocs.io/) 5 | * Pretrained models at https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md 6 | * This module has an entry_point to be used with SageMaker training jobs (see 'training.py') 7 | """ 8 | -------------------------------------------------------------------------------- /container_training/sku-110k/datasets/catalog.py: -------------------------------------------------------------------------------- 1 | """Add to Detectron2 catalog SKU-110k dataset""" 2 | from typing import Sequence, Mapping, Tuple 3 | from pathlib import Path 4 | import json 5 | from functools import partial 6 | from dataclasses import dataclass 7 | import sys 8 | import logging 9 | 10 | from detectron2.structures import BoxMode 11 | from detectron2.data import MetadataCatalog, DatasetCatalog, Metadata 12 | 13 | LOGGER = logging.Logger(name="Catalog", level=logging.INFO) 14 | HANDLER = logging.StreamHandler(sys.stdout) 15 | HANDLER.setFormatter(logging.Formatter("%(levelname)s | %(name)s | %(message)s")) 16 | LOGGER.addHandler(HANDLER) 17 | 18 | 19 | @dataclass 20 | class DataSetMeta: 21 | r"""Dataset metadata 22 | 23 | Attributes 24 | ---------- 25 | name : str 26 | dataset name 27 | classes : Sequence[str] 28 | class of objects to detect 29 | """ 30 | 31 | name: str 32 | classes: Sequence[str] 33 | 34 | def __str__(self): 35 | """Print dataset name and class names""" 36 | return ( 37 | f"The object detection dataset {self.name} " 38 | f"can detect {len(self.classes)} type(s) of objects: " 39 | f"{self.classes}" 40 | ) 41 | 42 | 43 | def remove_dataset(ds_name: str): 44 | r"""Remove a previously registered dataset 45 | 46 | Parameters 47 | ---------- 48 | ds_name : str 49 | the dataset to be removed 50 | """ 51 | for channel in ("training", "validation"): 52 | DatasetCatalog.remove(f"{ds_name}_{channel}") 53 | 54 | 55 | def aws_file_mode( 56 | path_imgs: str, path_annotation: str, label_name: str 57 | ) -> Sequence[Mapping]: 58 | r"""Add dataset to Detectron by using the schema used by AWS for object detection 59 | 60 | Parameters 61 | ---------- 62 | path_imgs : str 63 | path to folder that contains the images 64 | path_annotation : str 65 | path to the augmented manifest file that contains the annotations 66 | label_name : str 67 | label name used for object detection 
GT job 68 | 69 | Returns 70 | ------- 71 | Sequence[Mapping] 72 | list of annotations 73 | 74 | Raises 75 | ------ 76 | FileNotFoundError 77 | if the images to which the manifest file points to are not in path_imgs 78 | """ 79 | dataset_dicts = [] 80 | 81 | with open(path_annotation, "r") as ann_fid: 82 | for img_id, jsonline in enumerate(ann_fid): 83 | annotations = json.loads(jsonline) 84 | if "source-ref" not in annotations: 85 | err_msg = f"{path_annotation} is not a valid manifest file" 86 | LOGGER.error(err_msg) 87 | raise ValueError(err_msg) 88 | 89 | path_s3_img = Path(annotations["source-ref"]) 90 | if path_s3_img.suffix.lower() not in (".png", ".jpg"): 91 | LOGGER.warning( 92 | f"{path_s3_img.name} is not an image and it will be ignore" 93 | ) 94 | continue 95 | local_path_to_img = Path(path_imgs) / path_s3_img.name 96 | if not local_path_to_img.exists(): 97 | LOGGER.warning( 98 | f"{path_s3_img.name} not found in image channel: annotations are neglected" 99 | ) 100 | continue 101 | 102 | record = { 103 | "file_name": str(local_path_to_img), 104 | "height": annotations[label_name]["image_size"][0]["height"], 105 | "width": annotations[label_name]["image_size"][0]["width"], 106 | "image_id": img_id, 107 | } 108 | 109 | objs = [] 110 | for bbox in annotations[label_name]["annotations"]: 111 | objs.append( 112 | { 113 | "bbox": [ 114 | bbox["left"], 115 | bbox["top"], 116 | bbox["width"], 117 | bbox["height"], 118 | ], 119 | "bbox_mode": BoxMode.XYWH_ABS, 120 | "category_id": bbox["class_id"], 121 | } 122 | ) 123 | record["annotations"] = objs 124 | dataset_dicts.append(record) 125 | 126 | return dataset_dicts 127 | 128 | 129 | # pylint: disable=too-many-arguments 130 | def register_dataset( 131 | metadata: DataSetMeta, 132 | label_name: str, 133 | channel_to_dataset: Mapping[str, Tuple[str, str]], 134 | ) -> Metadata: 135 | r"""Register training and validation datasets to detectron2 136 | 137 | Parameters 138 | ---------- 139 | metadata : DataSetMeta 140 | metadata of the datasets to register 141 | label_name : str 142 | label name used for object detection GT job 143 | channel_to_dataset: Mapping[str, Tuple[str, str]] 144 | map channel name to dataset, each dataset is a 2D-tuple that contains path to images and 145 | path to augmented manifest file 146 | 147 | Returns 148 | ------- 149 | Metadata 150 | Metadata file 151 | """ 152 | 153 | for channel, datasets in channel_to_dataset.items(): 154 | detectron_ds_name = f"{metadata.name}_{channel}" 155 | DatasetCatalog.register( 156 | detectron_ds_name, 157 | partial(aws_file_mode, datasets[0], datasets[1], label_name), 158 | ) 159 | MetadataCatalog.get(detectron_ds_name).set(thing_classes=metadata.classes) 160 | LOGGER.info(f"{detectron_ds_name} dataset added to catalog") 161 | return MetadataCatalog.get(f"{metadata.name}_{list(channel_to_dataset.keys())[0]}") 162 | -------------------------------------------------------------------------------- /container_training/sku-110k/engine/__init__.py: -------------------------------------------------------------------------------- 1 | """Custom Trainer""" 2 | -------------------------------------------------------------------------------- /container_training/sku-110k/engine/custom_trainer.py: -------------------------------------------------------------------------------- 1 | """Implementation of custom trainer""" 2 | from detectron2.engine import DefaultTrainer 3 | from detectron2.data import ( 4 | build_detection_test_loader, 5 | build_detection_train_loader, 6 | DatasetMapper, 7 | ) 8 | 
from detectron2.config import CfgNode 9 | import detectron2.data.transforms as T 10 | from detectron2.utils import comm 11 | from detectron2.engine import hooks 12 | from detectron2.utils.events import CommonMetricPrinter 13 | 14 | from engine.hooks import ValidationLoss 15 | 16 | ######################################################## 17 | # MACROs that define training and validation transforms 18 | ######################################################## 19 | 20 | TRAIN_TRANSF = [ 21 | T.ResizeShortestEdge( 22 | short_edge_length=(800,), max_size=1333, sample_style="choice", 23 | ), 24 | ] 25 | VAL_TRANSF = [ 26 | T.ResizeShortestEdge(short_edge_length=(800,), max_size=1333, sample_style="choice") 27 | ] 28 | 29 | ######################################################## 30 | ######################################################## 31 | 32 | 33 | class Trainer(DefaultTrainer): 34 | r""" Use a custom trainer 35 | 36 | The main differences compared with DefaultTrainer are: 37 | 38 | * Use custom transforms rather than default ones defined in the config 39 | * Use custom hooks 40 | """ 41 | 42 | @classmethod 43 | def build_test_loader(cls, cfg: CfgNode, dataset_name: str): 44 | return build_detection_test_loader( 45 | cfg, 46 | dataset_name, 47 | # pylint:disable=redundant-keyword-arg,missing-kwoa 48 | mapper=DatasetMapper(cfg, is_train=False, augmentations=VAL_TRANSF), 49 | ) 50 | 51 | @classmethod 52 | def build_train_loader(cls, cfg: CfgNode): 53 | return build_detection_train_loader( 54 | cfg, 55 | # pylint:disable=redundant-keyword-arg,missing-kwoa 56 | mapper=DatasetMapper(cfg, is_train=True, augmentations=TRAIN_TRANSF), 57 | ) 58 | 59 | @classmethod 60 | def build_evaluator(cls, cfg, dataset_name): 61 | r"""Use Validation loss for evaluation+. Coco evaluation is executed outside of training""" 62 | return None 63 | 64 | def build_hooks(self): 65 | r"""Build hooks 66 | 67 | We use: timing, lr scheduling, checkpointing, lr scheduling, ValidationLoss, writing events 68 | """ 69 | cfg = self.cfg.clone() 70 | cfg.defrost() 71 | cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN 72 | 73 | ret = [ 74 | hooks.IterationTimer(), 75 | hooks.LRScheduler(self.optimizer, self.scheduler), 76 | ] 77 | 78 | # Do PreciseBN before checkpointer, because it updates the model and need to 79 | # be saved by checkpointer. 80 | # This is not always the best: if checkpointing has a different frequency, 81 | # some checkpoints may have more precise statistics than others. 82 | if comm.is_main_process(): 83 | ret.append( 84 | hooks.PeriodicCheckpointer( 85 | self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD 86 | ) 87 | ) 88 | 89 | ret.append(ValidationLoss(cfg, VAL_TRANSF, cfg.VAL_LOG_PERIOD)) 90 | 91 | if comm.is_main_process(): 92 | # run writers in the end, so that evaluation metrics are written 93 | ret.append( 94 | hooks.PeriodicWriter(self.build_writers(), period=cfg.VAL_LOG_PERIOD) 95 | ) 96 | return ret 97 | 98 | def build_writers(self): 99 | r"""Metric to print. 
This is used by `PeriodicWriter` hook""" 100 | return [ 101 | CommonMetricPrinter(self.cfg.SOLVER.MAX_ITER), 102 | ] 103 | -------------------------------------------------------------------------------- /container_training/sku-110k/engine/hooks.py: -------------------------------------------------------------------------------- 1 | """Implementation of Hooks to be used in training loop""" 2 | from typing import Sequence 3 | 4 | import torch 5 | from detectron2.engine.hooks import HookBase 6 | from detectron2.data import ( 7 | build_detection_train_loader, 8 | DatasetMapper, 9 | ) 10 | from detectron2.utils import comm 11 | from detectron2.config import CfgNode 12 | from detectron2.data.transforms import Augmentation 13 | 14 | class ValidationLoss(HookBase): 15 | r"""Hook that computes validation loss during training 16 | 17 | Parameters 18 | ---------- 19 | cfg : CfgNode 20 | Training configuration 21 | val_augmentation : Sequence[Augmentation] 22 | Data augmentation functions applied to validation data 23 | period : int 24 | The validation loss values are updated each `period` iterations 25 | 26 | Attributes 27 | ---------- 28 | cfg : CfgNode 29 | Clone of `cfg` parameters 30 | _loader : detectron2.data.DataLoader 31 | Validation data loader 32 | _period : int 33 | See `period` parameter 34 | num_steps : int 35 | It keeps track of the current iteration id 36 | """ 37 | def __init__(self, cfg: CfgNode, val_augmentation: Sequence[Augmentation], period: int): 38 | super().__init__() 39 | self.cfg = cfg.clone() 40 | self.cfg.DATASETS.TRAIN = cfg.DATASETS.TEST 41 | self._loader = iter( 42 | build_detection_train_loader( 43 | self.cfg, 44 | mapper=DatasetMapper( 45 | self.cfg, is_train=True, augmentations=val_augmentation 46 | ), 47 | ) 48 | ) 49 | self._period = period 50 | self.num_steps = 0 51 | 52 | def after_step(self): 53 | """Run after every iteration, see parent for details""" 54 | self.num_steps += 1 55 | if self.num_steps % self._period == 0: 56 | data = next(self._loader) 57 | 58 | if torch.cuda.is_available(): 59 | torch.cuda.synchronize() 60 | 61 | with torch.no_grad(): 62 | loss_dict = self.trainer.model(data) 63 | 64 | losses = sum(loss_dict.values()) 65 | assert torch.isfinite(losses).all(), loss_dict 66 | 67 | loss_dict_reduced = { 68 | "val_" + k: v.item() for k, v in comm.reduce_dict(loss_dict).items() 69 | } 70 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 71 | if comm.is_main_process(): 72 | self.trainer.storage.put_scalars( 73 | total_val_loss=losses_reduced, **loss_dict_reduced 74 | ) 75 | comm.synchronize() 76 | else: 77 | pass 78 | -------------------------------------------------------------------------------- /container_training/sku-110k/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | """Coco Style Evaluation""" 2 | -------------------------------------------------------------------------------- /container_training/sku-110k/evaluation/coco.py: -------------------------------------------------------------------------------- 1 | """Run COCO evaluation on arbitrary number of predicted bounding boxes""" 2 | import itertools 3 | import os 4 | import json 5 | import copy 6 | from typing import Tuple 7 | import time 8 | 9 | import numpy as np 10 | from pycocotools.cocoeval import COCOeval 11 | from detectron2 import _C 12 | from detectron2.evaluation import COCOEvaluator 13 | from detectron2.utils.file_io import PathManager 14 | 15 | 16 | class EvaluateObjectDetection(COCOeval): 17 | r"""Run 
COCO evaluation on an arbitrary number of bounding boxes""" 18 | 19 | def summarize(self): 20 | """Compute and display summary metrics for evaluation results""" 21 | 22 | def _summarize(use_ap: bool, iou_thr=None, area_rng="all", max_dets=100): 23 | params = self.params 24 | out_str = ( 25 | " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" 26 | ) 27 | title_str = "Average Precision" if use_ap else "Average Recall" 28 | type_str = "(AP)" if use_ap else "(AR)" 29 | iou_str = ( 30 | "{:0.2f}:{:0.2f}".format(params.iouThrs[0], params.iouThrs[-1]) 31 | if iou_thr is None 32 | else "{:0.2f}".format(iou_thr) 33 | ) 34 | 35 | aind = [i for i, aRng in enumerate(params.areaRngLbl) if aRng == area_rng] 36 | mind = [i for i, mDet in enumerate(params.maxDets) if mDet == max_dets] 37 | if use_ap: 38 | # dimension of precision: [TxRxKxAxM] 39 | metric_val = self.eval["precision"] 40 | # IoU 41 | if iou_thr is not None: 42 | metric_id = np.where(iou_thr == params.iouThrs)[0] 43 | metric_val = metric_val[metric_id] 44 | metric_val = metric_val[:, :, :, aind, mind] 45 | else: 46 | # dimension of recall: [TxKxAxM] 47 | metric_val = self.eval["recall"] 48 | if iou_thr is not None: 49 | metric_id = np.where(iou_thr == params.iouThrs)[0] 50 | metric_val = metric_val[metric_id] 51 | metric_val = metric_val[:, :, aind, mind] 52 | if len(metric_val[metric_val > -1]) == 0: 53 | mean_s = -1 54 | else: 55 | mean_s = np.mean(metric_val[metric_val > -1]) 56 | print( 57 | out_str.format(title_str, type_str, iou_str, area_rng, max_dets, mean_s) 58 | ) 59 | return mean_s 60 | 61 | def _summarize_detections(): 62 | stats = np.zeros((12,)) 63 | stats[0] = _summarize(True, max_dets=self.params.maxDets[0]) 64 | stats[1] = _summarize(True, iou_thr=0.5, max_dets=self.params.maxDets[0]) 65 | stats[2] = _summarize(True, iou_thr=0.75, max_dets=self.params.maxDets[0]) 66 | stats[3] = _summarize( 67 | True, area_rng="small", max_dets=self.params.maxDets[0] 68 | ) 69 | stats[4] = _summarize( 70 | True, area_rng="medium", max_dets=self.params.maxDets[0] 71 | ) 72 | stats[5] = _summarize( 73 | True, area_rng="large", max_dets=self.params.maxDets[0] 74 | ) 75 | stats[6] = _summarize(False, max_dets=self.params.maxDets[0]) 76 | stats[9] = _summarize( 77 | False, area_rng="small", max_dets=self.params.maxDets[0] 78 | ) 79 | stats[10] = _summarize( 80 | False, area_rng="medium", max_dets=self.params.maxDets[0] 81 | ) 82 | stats[11] = _summarize( 83 | False, area_rng="large", max_dets=self.params.maxDets[0] 84 | ) 85 | return stats 86 | 87 | if not self.eval: 88 | raise Exception("Please run accumulate() first") 89 | 90 | if self.params.iouType != "bbox": 91 | raise ValueError( 92 | f"{type(self).__name__} supports object detection evaluation only" 93 | ) 94 | self.stats = _summarize_detections() 95 | 96 | 97 | class FastEvaluateObjectDetection(EvaluateObjectDetection): 98 | r"""This class is the same as Detecron2's `COCOeval_opt` 99 | 100 | The only change is that this class inherits from `EvaluateObjectDetection` instead that from 101 | COCOeval 102 | """ 103 | 104 | def evaluate(self): 105 | """ 106 | Run per image evaluation on given images and store results in self.evalImgs_cpp, a 107 | datastructure that isn't readable from Python but is used by a c++ implementation of 108 | accumulate(). Unlike the original COCO PythonAPI, we don't populate the datastructure 109 | self.evalImgs because this datastructure is a computational bottleneck. 
110 | :return: None 111 | """ 112 | tic = time.time() 113 | 114 | print("Running per image evaluation...") 115 | params = self.params 116 | # add backward compatibility if useSegm is specified in params 117 | if params.useSegm is not None: 118 | params.iouType = "segm" if params.useSegm == 1 else "bbox" 119 | print( 120 | "useSegm (deprecated) is not None. Running {} evaluation".format( 121 | params.iouType 122 | ) 123 | ) 124 | print("Evaluate annotation type *{}*".format(params.iouType)) 125 | params.imgIds = list(np.unique(params.imgIds)) 126 | if params.useCats: 127 | params.catIds = list(np.unique(params.catIds)) 128 | params.maxDets = sorted(params.maxDets) 129 | self.params = params 130 | 131 | self._prepare() 132 | 133 | # loop through images, area range, max detection number 134 | cat_ids = params.catIds if params.useCats else [-1] 135 | 136 | if params.iouType == "segm" or params.iouType == "bbox": 137 | compute_IoU = self.computeIoU 138 | elif params.iouType == "keypoints": 139 | compute_IoU = self.computeOks 140 | else: 141 | assert False, f"Add implementation for {params.iouType}" 142 | self.ious = { 143 | (imgId, catId): compute_IoU(imgId, catId) 144 | for imgId in params.imgIds 145 | for catId in cat_ids 146 | } 147 | 148 | maxDet = params.maxDets[-1] 149 | 150 | # <<<< Beginning of code differences with original COCO API 151 | def convert_instances_to_cpp(instances, is_det=False): 152 | # Convert annotations for a list of instances in an image to a format that's fast 153 | # to access in C++ 154 | instances_cpp = [] 155 | for instance in instances: 156 | instance_cpp = _C.InstanceAnnotation( 157 | int(instance["id"]), 158 | instance["score"] if is_det else instance.get("score", 0.0), 159 | instance["area"], 160 | bool(instance.get("iscrowd", 0)), 161 | bool(instance.get("ignore", 0)), 162 | ) 163 | instances_cpp.append(instance_cpp) 164 | return instances_cpp 165 | 166 | # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++ 167 | ground_truth_instances = [ 168 | [ 169 | convert_instances_to_cpp(self._gts[imgId, catId]) 170 | for catId in params.catIds 171 | ] 172 | for imgId in params.imgIds 173 | ] 174 | detected_instances = [ 175 | [ 176 | convert_instances_to_cpp(self._dts[imgId, catId], is_det=True) 177 | for catId in params.catIds 178 | ] 179 | for imgId in params.imgIds 180 | ] 181 | ious = [ 182 | [self.ious[imgId, catId] for catId in cat_ids] for imgId in params.imgIds 183 | ] 184 | 185 | if not params.useCats: 186 | # For each image, flatten per-category lists into a single list 187 | ground_truth_instances = [ 188 | [[o for c in i for o in c]] for i in ground_truth_instances 189 | ] 190 | detected_instances = [ 191 | [[o for c in i for o in c]] for i in detected_instances 192 | ] 193 | 194 | # Call C++ implementation of self.evaluateImgs() 195 | self._evalImgs_cpp = _C.COCOevalEvaluateImages( 196 | params.areaRng, 197 | maxDet, 198 | params.iouThrs, 199 | ious, 200 | ground_truth_instances, 201 | detected_instances, 202 | ) 203 | self._evalImgs = None 204 | 205 | self._paramsEval = copy.deepcopy(self.params) 206 | toc = time.time() 207 | print("COCOeval_opt.evaluate() finished in {:0.2f} seconds.".format(toc - tic)) 208 | # >>>> End of code differences with original COCO API 209 | 210 | def accumulate(self): 211 | """ 212 | Accumulate per image evaluation results and store the result in self.eval. 
Does not 213 | support changing parameter settings from those used by self.evaluate() 214 | """ 215 | print("Accumulating evaluation results...") 216 | tic = time.time() 217 | if not hasattr(self, "_evalImgs_cpp"): 218 | print("Please run evaluate() first") 219 | 220 | self.eval = _C.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp) 221 | 222 | # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections 223 | self.eval["recall"] = np.array(self.eval["recall"]).reshape( 224 | self.eval["counts"][:1] + self.eval["counts"][2:] 225 | ) 226 | 227 | # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X 228 | # num_area_ranges X num_max_detections 229 | self.eval["precision"] = np.array(self.eval["precision"]).reshape( 230 | self.eval["counts"] 231 | ) 232 | self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"]) 233 | toc = time.time() 234 | print( 235 | "COCOeval_opt.accumulate() finished in {:0.2f} seconds.".format(toc - tic) 236 | ) 237 | 238 | 239 | class D2CocoEvaluator(COCOEvaluator): 240 | r"""Detectron2 COCO evaluator that allows setting the maximum number of detections""" 241 | 242 | def __init__( 243 | self, 244 | dataset_name: str, 245 | tasks: Tuple[str, ...], 246 | distributed: bool, 247 | output_dir: str, 248 | use_fast_impl: bool, 249 | nb_max_preds: int, 250 | ): 251 | super().__init__( 252 | dataset_name=dataset_name, 253 | tasks=tasks, 254 | distributed=distributed, 255 | output_dir=output_dir, 256 | use_fast_impl=use_fast_impl, 257 | ) 258 | self._nb_max_preds = nb_max_preds 259 | 260 | def _eval_predictions(self, predictions, img_ids=None): 261 | """ 262 | Evaluate predictions on the given tasks. 263 | Fill self._results with the metrics of the tasks. 264 | """ 265 | self._logger.info("Preparing results for COCO format ...") 266 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) 267 | tasks = self._tasks or self._tasks_from_predictions(coco_results) 268 | 269 | # unmap the category ids for COCO 270 | if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): 271 | dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id 272 | all_contiguous_ids = list(dataset_id_to_contiguous_id.values()) 273 | num_classes = len(all_contiguous_ids) 274 | assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1 275 | 276 | reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()} 277 | for result in coco_results: 278 | category_id = result["category_id"] 279 | assert category_id < num_classes, ( 280 | f"A prediction has class={category_id}, " 281 | f"but the dataset only has {num_classes} classes and " 282 | f"predicted class id should be in [0, {num_classes - 1}]." 
283 | ) 284 | result["category_id"] = reverse_id_mapping[category_id] 285 | 286 | if self._output_dir: 287 | file_path = os.path.join(self._output_dir, "coco_instances_results.json") 288 | self._logger.info("Saving results to {}".format(file_path)) 289 | with PathManager.open(file_path, "w") as f: 290 | f.write(json.dumps(coco_results)) 291 | f.flush() 292 | 293 | if not self._do_evaluation: 294 | self._logger.info("Annotations are not available for evaluation.") 295 | return 296 | 297 | self._logger.info( 298 | "Evaluating predictions with {} COCO API...".format( 299 | "unofficial" if self._use_fast_impl else "official" 300 | ) 301 | ) 302 | for task in sorted(tasks): 303 | assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!" 304 | coco_eval = ( 305 | _evaluate_on_coco_impl( 306 | self._coco_api, 307 | coco_results, 308 | task, 309 | max_nb_preds=self._nb_max_preds, 310 | kpt_oks_sigmas=self._kpt_oks_sigmas, 311 | use_fast_impl=self._use_fast_impl, 312 | img_ids=img_ids, 313 | ) 314 | if len(coco_results) > 0 315 | else None # cocoapi does not handle empty results very well 316 | ) 317 | 318 | res = self._derive_coco_results( 319 | coco_eval, task, class_names=self._metadata.get("thing_classes") 320 | ) 321 | self._results[task] = res 322 | 323 | 324 | def _evaluate_on_coco_impl( 325 | coco_gt, 326 | coco_results, 327 | iou_type, 328 | max_nb_preds, 329 | kpt_oks_sigmas=None, 330 | use_fast_impl=True, 331 | img_ids=None, 332 | ): 333 | """ 334 | Evaluate the coco results using COCOEval API. 335 | """ 336 | assert len(coco_results) > 0 337 | 338 | if iou_type == "segm": 339 | coco_results = copy.deepcopy(coco_results) 340 | # When evaluating mask AP, if the results contain bbox, cocoapi will 341 | # use the box area as the area of the instance, instead of the mask area. 342 | # This leads to a different definition of small/medium/large. 343 | # We remove the bbox field to let mask AP use mask area. 344 | for c in coco_results: 345 | c.pop("bbox", None) 346 | 347 | coco_dt = coco_gt.loadRes(coco_results) 348 | coco_eval = ( 349 | FastEvaluateObjectDetection if use_fast_impl else EvaluateObjectDetection 350 | )(coco_gt, coco_dt, iou_type) 351 | coco_eval.params.maxDets = [max_nb_preds] 352 | if img_ids is not None: 353 | coco_eval.params.imgIds = img_ids 354 | 355 | if iou_type == "keypoints": 356 | # Use the COCO default keypoint OKS sigmas unless overrides are specified 357 | if kpt_oks_sigmas: 358 | assert hasattr( 359 | coco_eval.params, "kpt_oks_sigmas" 360 | ), "pycocotools is too old!" 361 | coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas) 362 | # COCOAPI requires every detection and every gt to have keypoints, so 363 | # we just take the first entry from both 364 | num_keypoints_dt = len(coco_results[0]["keypoints"]) // 3 365 | num_keypoints_gt = len(next(iter(coco_gt.anns.values()))["keypoints"]) // 3 366 | num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas) 367 | assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, ( 368 | f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. " 369 | f"Ground truth contains {num_keypoints_gt} keypoints. " 370 | f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. " 371 | "They have to agree with each other. For meaning of OKS, please refer to " 372 | "http://cocodataset.org/#keypoints-eval." 
373 | ) 374 | 375 | coco_eval.evaluate() 376 | coco_eval.accumulate() 377 | coco_eval.summarize() 378 | 379 | return coco_eval 380 | -------------------------------------------------------------------------------- /container_training/sku-110k/evaluation/factory.py: -------------------------------------------------------------------------------- 1 | """Create annotation files""" 2 | from typing import Sequence 3 | import json 4 | 5 | from pycocotools.coco import COCO 6 | 7 | from evaluation.utils import annotation_to_coco 8 | 9 | 10 | def generate_ground_truth( 11 | dataset_description: str, categories: Sequence[str], s3_objs, fname: str 12 | ): 13 | """TODO add doc""" 14 | info = { 15 | "year": 2020, 16 | "version": "0.0", 17 | "description": "SKU-110k", 18 | } 19 | cat = [ 20 | {"id": cat_id, "name": cat_name} for cat_id, cat_name in enumerate(categories) 21 | ] 22 | 23 | images, annotations = annotation_to_coco(s3_objs) 24 | 25 | coco_annotation = { 26 | "info": info, 27 | "images": images, 28 | "annotations": annotations, 29 | "categories": cat, 30 | } 31 | 32 | with open(fname, "w") as fid: 33 | json.dump(coco_annotation, fid) 34 | 35 | return COCO(fname), coco_annotation 36 | 37 | 38 | def generate_predictions(s3_pred_objects: Sequence, image_ids: Sequence[int], coco_gt, fname: str): 39 | 40 | assert len(s3_pred_objects) == len(image_ids), f"Mismatch nb of objects ({len(s3_pred_objects)}) vs nb of identifiers ({len(image_ids)})" 41 | 42 | predictions = [] 43 | for pred_obj, elem in zip(s3_pred_objects, image_ids): 44 | preds = json.loads(pred_obj.get()["Body"].read().decode("utf-8")) 45 | predictions += instances_to_coco_json(convert_to_d2_preds(preds), elem["id"]) 46 | 47 | with open(fname, 'w') as fid: 48 | json.dump(predictions, fid) 49 | 50 | return coco_gt.loadRes(fname) -------------------------------------------------------------------------------- /container_training/sku-110k/evaluation/run.py: -------------------------------------------------------------------------------- 1 | """Run COCO evaluation with custom # of max detections""" -------------------------------------------------------------------------------- /container_training/sku-110k/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | """General Utils""" 2 | from pathlib import Path 3 | import json 4 | from typing import Mapping, Tuple 5 | 6 | import torch 7 | from detectron2.evaluation.coco_evaluation import instances_to_coco_json 8 | from detectron2.structures import Instances, Boxes 9 | 10 | 11 | def _annotation_from_single_img( 12 | p_data: Mapping, p_img_id: int, label_name: str 13 | ) -> Tuple[Mapping, Mapping, int]: 14 | """ 15 | Convert annotations on a single image from detectron2 format to COCO one 16 | 17 | Args: 18 | p_data ([type]): raw annotation 19 | p_img_id ([type]): current image id 20 | label_name (str): label name in raw annotations 21 | 22 | Returns: 23 | Tuple[Mapping, Mapping, int]: COCO image metadata, COCO annotation, next image id 24 | """ 25 | out_images = { 26 | "id": p_img_id, 27 | "width": p_data[label_name]["image_size"][0]["width"], 28 | "height": p_data[label_name]["image_size"][0]["height"], 29 | "file_name": p_data["source-ref"], 30 | } 31 | out_annotations = [] 32 | ann_id = p_img_id 33 | for elem in p_data["sku"]["annotations"]: 34 | out_annotations.append( 35 | { 36 | "id": ann_id, 37 | "image_id": p_img_id, 38 | "category_id": elem["class_id"], 39 | "bbox": [elem["left"], elem["top"], elem["width"], elem["height"]], 40 | 
"iscrowd": 0, 41 | "area": float(elem["width"] * elem["height"]), 42 | } 43 | ) 44 | ann_id += 1 45 | 46 | return out_images, out_annotations, ann_id 47 | 48 | 49 | def annotation_to_coco(s3_obj_iter): 50 | """Convert all the annotations to COCO style""" 51 | out_images = [] 52 | out_annotations = [] 53 | img_id = 0 54 | for ann_obj in s3_obj_iter: 55 | if Path(ann_obj.key).suffix != ".json": 56 | continue 57 | 58 | data = json.loads(ann_obj.get()["Body"].read().decode("utf-8")) 59 | 60 | elem_img, elem_ann, img_id = _annotation_from_single_img(data, img_id) 61 | out_images.append(elem_img) 62 | out_annotations += elem_ann 63 | return out_images, out_annotations 64 | 65 | 66 | def convert_to_d2_preds(json_data) -> Instances: 67 | """convert detectron2 raw predictions to Detectron2 format""" 68 | f_out = Instances((json_data["image_height"], json_data["image_width"])) 69 | f_out.pred_boxes = Boxes(torch.tensor(json_data["pred_boxes"])) 70 | f_out.scores = torch.tensor(json_data["scores"]) 71 | f_out.pred_classes = torch.tensor(json_data["pred_classes"]) 72 | return f_out 73 | -------------------------------------------------------------------------------- /container_training/sku-110k/training.py: -------------------------------------------------------------------------------- 1 | """Entry point of the Detectron2 container that is used to train models on SKU-110k dataset""" 2 | import os 3 | import argparse 4 | import logging 5 | import sys 6 | import ast 7 | import json 8 | from pathlib import Path 9 | 10 | from detectron2.engine import launch 11 | from detectron2.config import get_cfg, CfgNode 12 | from detectron2 import model_zoo 13 | from detectron2.checkpoint import DetectionCheckpointer 14 | 15 | from datasets.catalog import register_dataset, DataSetMeta 16 | from engine.custom_trainer import Trainer 17 | from evaluation.coco import D2CocoEvaluator 18 | 19 | ############## 20 | # Macros 21 | ############## 22 | LOGGER = logging.Logger("TrainingScript", level=logging.INFO) 23 | HANDLER = logging.StreamHandler(sys.stdout) 24 | HANDLER.setFormatter(logging.Formatter("%(levelname)s | %(name)s | %(message)s")) 25 | LOGGER.addHandler(HANDLER) 26 | 27 | ######################## 28 | # Implementation Details 29 | ######################## 30 | 31 | 32 | def _config_training(args: argparse.Namespace) -> CfgNode: 33 | r"""Create a configuration node from the script arguments. 34 | 35 | In this application we consider object detection use case only. We finetune object detection 36 | networks trained on COCO dataset to a custom use case 37 | 38 | Parameters 39 | ---------- 40 | args : argparse.Namespace 41 | training script arguments, see :py:meth:`_parse_args()` 42 | 43 | Returns 44 | ------- 45 | CfgNode 46 | configuration that is used by Detectron2 to train a model 47 | 48 | Raises: 49 | RuntimeError: if the combination of `model_type`, `backbone`, `lr_schedule` is not valid. 50 | Please refer to Detectron2 model zoo for valid options. 
51 | """ 52 | cfg = get_cfg() 53 | pretrained_model = ( 54 | f"COCO-Detection/{args.model_type}_{args.backbone}_{args.lr_schedule}x.yaml" 55 | ) 56 | LOGGER.info(f"Loooking for the pretrained model {pretrained_model}...") 57 | try: 58 | cfg.merge_from_file(model_zoo.get_config_file(pretrained_model)) 59 | except RuntimeError as err: 60 | LOGGER.error(f"{err}: check model backbone and lr schedule combination") 61 | raise 62 | cfg.DATASETS.TRAIN = (f"{args.dataset_name}_training",) 63 | cfg.DATASETS.TEST = (f"{args.dataset_name}_validation",) 64 | cfg.DATALOADER.NUM_WORKERS = args.num_workers 65 | # Let training initialize from model zoo 66 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(pretrained_model) 67 | LOGGER.info(f"{pretrained_model} correctly loaded") 68 | 69 | cfg.SOLVER.CHECKPOINT_PERIOD = 20000 70 | cfg.SOLVER.BASE_LR = args.lr 71 | cfg.SOLVER.MAX_ITER = args.num_iter 72 | cfg.SOLVER.IMS_PER_BATCH = args.batch_size 73 | cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = args.num_rpn 74 | if args.model_type == "faster_rcnn": 75 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(args.classes) 76 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.pred_thr 77 | cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST = args.nms_thr 78 | cfg.MODEL.RPN.BBOX_REG_LOSS_TYPE = args.reg_loss_type 79 | cfg.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = args.bbox_reg_loss_weight 80 | cfg.MODEL.RPN.POSITIVE_FRACTION = args.bbox_rpn_pos_fraction 81 | cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION = args.bbox_head_pos_fraction 82 | elif args.model_type == "retinanet": 83 | cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.pred_thr 84 | cfg.MODEL.RETINANET.NMS_THRESH_TEST = args.nms_thr 85 | cfg.MODEL.RETINANET.NUM_CLASSES = len(args.classes) 86 | cfg.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = args.reg_loss_type 87 | cfg.MODEL.RETINANET.FOCAL_LOSS_GAMMA = args.focal_loss_gamma 88 | cfg.MODEL.RETINANET.FOCAL_LOSS_ALPHA = args.focal_loss_alpha 89 | else: 90 | assert False, f"Add implementation for model {args.model_type}" 91 | cfg.MODEL.DEVICE = "cuda" if args.num_gpus else "cpu" 92 | 93 | cfg.TEST.DETECTIONS_PER_IMAGE = args.det_per_img 94 | 95 | cfg.OUTPUT_DIR = args.model_dir 96 | os.makedirs(cfg.OUTPUT_DIR, exist_ok=True) 97 | return cfg 98 | 99 | 100 | def _train_impl(args) -> None: 101 | r"""Training implementation executes the following steps: 102 | 103 | * Register the dataset to Detectron2 catalog 104 | * Create the configuration node for training 105 | * Launch training 106 | * Serialize the training configuration to a JSON file as it is required for prediction 107 | """ 108 | 109 | dataset = DataSetMeta(name=args.dataset_name, classes=args.classes) 110 | 111 | for ds_type in ( 112 | ("training", "validation", "test") 113 | if args.evaluation_type 114 | else ("training", "validation",) 115 | ): 116 | if not Path(args.annotation_channel) / f"{ds_type}.manifest": 117 | err_msg = f"{ds_type} dataset annotations not found" 118 | LOGGER.error(err_msg) 119 | raise FileNotFoundError(err_msg) 120 | 121 | channel_to_ds = { 122 | "training": ( 123 | args.training_channel, 124 | f"{args.annotation_channel}/training.manifest", 125 | ), 126 | "validation": ( 127 | args.validation_channel, 128 | f"{args.annotation_channel}/validation.manifest", 129 | ), 130 | } 131 | if args.evaluation_type: 132 | channel_to_ds["test"] = ( 133 | args.test_channel, 134 | f"{args.annotation_channel}/test.manifest", 135 | ) 136 | 137 | register_dataset( 138 | metadata=dataset, label_name=args.label_name, channel_to_dataset=channel_to_ds, 139 | ) 140 | 141 | cfg = _config_training(args) 142 | 
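# VAL_LOG_PERIOD is a custom key that is not part of Detectron2's default config:
# Trainer.build_hooks() in engine/custom_trainer.py reads cfg.VAL_LOG_PERIOD to schedule
# both the ValidationLoss hook and the PeriodicWriter that logs training and validation metrics.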
143 | cfg.setdefault("VAL_LOG_PERIOD", args.log_period) 144 | 145 | trainer = Trainer(cfg) 146 | trainer.resume_or_load(resume=False) 147 | 148 | if cfg.MODEL.DEVICE != "cuda": 149 | err = RuntimeError("A CUDA device is required to launch training") 150 | LOGGER.error(err) 151 | raise err 152 | trainer.train() 153 | 154 | # If in the master process: save config and run COCO evaluation on test set 155 | if args.current_host == args.hosts[0]: 156 | with open(f"{cfg.OUTPUT_DIR}/config.json", "w") as fid: 157 | json.dump(cfg, fid, indent=2) 158 | 159 | if args.evaluation_type: 160 | LOGGER.info(f"Running {args.evaluation_type} evaluation on the test set") 161 | evaluator = D2CocoEvaluator( 162 | dataset_name=f"{dataset.name}_test", 163 | tasks=("bbox",), 164 | distributed=len(args.hosts)==1 and args.num_gpus > 1, 165 | output_dir=f"{cfg.OUTPUT_DIR}/eval", 166 | use_fast_impl=args.evaluation_type == "fast", 167 | nb_max_preds=cfg.TEST.DETECTIONS_PER_IMAGE, 168 | ) 169 | cfg.DATASETS.TEST = (f"{args.dataset_name}_test",) 170 | model = Trainer.build_model(cfg) 171 | DetectionCheckpointer(model).load(f"{cfg.OUTPUT_DIR}/model_final.pth") 172 | Trainer.test(cfg, model, evaluator) 173 | else: 174 | LOGGER.info("Evaluation on the test set skipped") 175 | 176 | 177 | ########## 178 | # Training 179 | ########## 180 | 181 | 182 | def train(args: argparse.Namespace) -> None: 183 | r"""Launch distributed training by using Detecton2's `launch()` function 184 | 185 | Parameters 186 | ---------- 187 | args : argparse.Namespace 188 | training script arguments, see :py:meth:`_parse_args()` 189 | """ 190 | args.classes = ast.literal_eval(args.classes) 191 | 192 | machine_rank = args.hosts.index(args.current_host) 193 | LOGGER.info(f"Machine rank: {machine_rank}") 194 | master_addr = args.hosts[0] 195 | master_port = "55555" 196 | 197 | url = "auto" if len(args.hosts) == 1 else f"tcp://{master_addr}:{master_port}" 198 | LOGGER.info(f"Device URL: {url}") 199 | 200 | launch( 201 | _train_impl, 202 | num_gpus_per_machine=args.num_gpus, 203 | num_machines=len(args.hosts), 204 | dist_url=url, 205 | machine_rank=machine_rank, 206 | args=(args,), 207 | ) 208 | 209 | 210 | ############# 211 | # Script API 212 | ############# 213 | 214 | 215 | def _parse_args() -> argparse.Namespace: 216 | r"""Define training script API according to the argument that are parsed from the CLI 217 | 218 | Returns 219 | ------- 220 | argparse.Namespace 221 | training script arguments, execute $(python $thisfile --help) for detailed documentation 222 | """ 223 | 224 | parser = argparse.ArgumentParser() 225 | 226 | # Pretrained model 227 | parser.add_argument( 228 | "--model-type", 229 | type=str, 230 | default="faster_rcnn", 231 | choices=["faster_rcnn", "retinanet"], 232 | metavar="MT", 233 | help=( 234 | "Type of architecture to be used for object detection; " 235 | "two options are supported: 'faster_rccn' and 'retinanet' " 236 | "(default: faster_rcnn)" 237 | ), 238 | ) 239 | parser.add_argument( 240 | "--backbone", 241 | type=str, 242 | default="R_50_C4", 243 | choices=[ 244 | "R_50_C4", 245 | "R_50_DC5", 246 | "R_50_FPN", 247 | "R_101_C4", 248 | "R_101_DC5", 249 | "R_101_FPN", 250 | "X_101_32x8d_FPN", 251 | ], 252 | metavar="B", 253 | help=( 254 | "Encoder backbone, how to read this field: " 255 | "R50 (RetinaNet-50), R100 (RetinaNet-100), X101 (ResNeXt-101); " 256 | "C4 (Use a ResNet conv4 backbone with conv5 head), " 257 | "DC5 (ResNet conv5 backbone with dilations in conv5) " 258 | "FPN (Use a FPN on top of resnet) ;" 259 | 
"Attention! Only some combinations are supported, please refer to the original doc " 260 | "(https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md) " 261 | "(default: R_50_C4)" 262 | ), 263 | ) 264 | parser.add_argument( 265 | "--lr-schedule", 266 | type=int, 267 | default=1, 268 | choices=[1, 3], 269 | metavar="LRS", 270 | help=( 271 | "Length of the training schedule, two values are supported: 1 or 3. " 272 | "1x = 16 images / it * 90,000 iterations in total with the LR reduced at 60k and 80k." 273 | "3x = 16 images / it * 270,000 iterations in total with the LR reduced at 210k and 250k" 274 | "(default: 1)" 275 | ), 276 | ) 277 | # Hyper-parameters 278 | parser.add_argument( 279 | "--num-workers", 280 | type=int, 281 | default=2, 282 | metavar="NW", 283 | help="Number of workers used to by the data loader (default: 2)", 284 | ) 285 | parser.add_argument( 286 | "--lr", 287 | type=float, 288 | default=0.00025, 289 | metavar="LR", 290 | help="Base learning rate value (default: 0.00025)", 291 | ) 292 | parser.add_argument( 293 | "--num-iter", 294 | type=int, 295 | default=1000, 296 | metavar="I", 297 | help="Maximum number of iterations (default: 1000)", 298 | ) 299 | parser.add_argument( 300 | "--batch-size", 301 | type=int, 302 | default=16, 303 | metavar="B", 304 | help="Number of images per batch across all machines (default: 16)", 305 | ) 306 | parser.add_argument( 307 | "--num-rpn", 308 | type=int, 309 | default=100, 310 | metavar="R", 311 | help="Total number of RPN examples per image (default: 100)", 312 | ) 313 | parser.add_argument( 314 | "--reg-loss-type", 315 | type=str, 316 | default="smooth_l1", 317 | choices=["smooth_l1", "giou"], 318 | metavar="RLT", 319 | help=("Loss type used for regression subnet " "(default: smooth_l1)"), 320 | ) 321 | 322 | # RetinaNet Specific 323 | parser.add_argument( 324 | "--focal-loss-gamma", 325 | type=float, 326 | default=2.0, 327 | metavar="FLG", 328 | help="Focal loss gamma, used in RetinaNet (default: 2.0)", 329 | ) 330 | parser.add_argument( 331 | "--focal-loss-alpha", 332 | type=float, 333 | default=0.25, 334 | metavar="FLA", 335 | help="Focal loss alpha, used in RetinaNet. It must be in [0.1,1] (default: 0.25)", 336 | ) 337 | 338 | # Faster-RCNN Specific 339 | parser.add_argument( 340 | "--bbox-reg-loss-weight", 341 | type=float, 342 | default=1.0, 343 | help="Weight regression loss (default: 0.1)", 344 | ) 345 | parser.add_argument( 346 | "--bbox-rpn-pos-fraction", 347 | type=float, 348 | default=0.5, 349 | help="Target fraction of foreground (positive) examples per RPN minibatch (default: 0.5)", 350 | ) 351 | parser.add_argument( 352 | "--bbox-head-pos-fraction", 353 | type=float, 354 | default=0.25, 355 | help=( 356 | "Target fraction of RoI minibatch that is labeled foreground (i.e. 
class > 0) " 357 | "(default: 0.25)" 358 | ), 359 | ) 360 | parser.add_argument( 361 | "--log-period", 362 | type=int, 363 | default=40, 364 | help="Occurence in number of iterations at which loss values are logged", 365 | ) 366 | 367 | # Inference Parameters 368 | parser.add_argument( 369 | "--det-per-img", 370 | type=int, 371 | default=200, 372 | metavar="R", 373 | help="Maximum number of detections to return per image during inference (default: 200)", 374 | ) 375 | parser.add_argument( 376 | "--nms-thr", 377 | type=float, 378 | default=0.5, 379 | metavar="NMS", 380 | help="If IoU is bigger than this value, only more confident pred is kept " 381 | "(default: 0.5)", 382 | ) 383 | parser.add_argument( 384 | "--pred-thr", 385 | type=float, 386 | default=0.5, 387 | metavar="PT", 388 | help="Minimum confidence score to retain prediction (default: 0.5)", 389 | ) 390 | parser.add_argument( 391 | "--evaluation-type", 392 | choices=["fast", "coco"], 393 | type=str, 394 | default=None, 395 | help=( 396 | "Evaluation to run on the test set after the training loop on the test. " 397 | "Valid options are: `fast` (Detectron2 boosted COCO eval) and " 398 | "`coco` (default COCO evaluation). " 399 | "This value is by default None, which means that no evaluation is executed" 400 | ), 401 | ) 402 | 403 | # Mandatory parameters 404 | parser.add_argument( 405 | "--classes", type=str, metavar="C", help="List of classes of objects" 406 | ) 407 | parser.add_argument( 408 | "--dataset-name", type=str, metavar="DS", help="Name of the dataset" 409 | ) 410 | parser.add_argument( 411 | "--label-name", 412 | type=str, 413 | metavar="DS", 414 | help="Name of category of objects to detect (e.g. 'object')", 415 | ) 416 | 417 | # Container Environment 418 | parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"]) 419 | 420 | parser.add_argument( 421 | "--training-channel", 422 | type=str, 423 | default=os.environ["SM_CHANNEL_TRAINING"], 424 | help="Path folder that contains training images (File mode)", 425 | ) 426 | parser.add_argument( 427 | "--validation-channel", 428 | type=str, 429 | default=os.environ["SM_CHANNEL_VALIDATION"], 430 | help="Path folder that contains validation images (File mode)", 431 | ) 432 | parser.add_argument( 433 | "--test-channel", 434 | type=str, 435 | default=os.environ["SM_CHANNEL_TEST"], 436 | help=( 437 | "Path folder that contains test images, " 438 | "these are used to evaluate the model but not to drive hparam tuning" 439 | ), 440 | ) 441 | parser.add_argument( 442 | "--annotation-channel", 443 | type=str, 444 | default=os.environ["SM_CHANNEL_ANNOTATION"], 445 | help="Path to folder that contains augumented manifest files with annotations", 446 | ) 447 | 448 | parser.add_argument("--num-gpus", type=int, default=os.environ["SM_NUM_GPUS"]) 449 | parser.add_argument( 450 | "--hosts", type=str, default=ast.literal_eval(os.environ["SM_HOSTS"]) 451 | ) 452 | parser.add_argument( 453 | "--current-host", type=str, default=os.environ["SM_CURRENT_HOST"] 454 | ) 455 | return parser.parse_args() 456 | 457 | 458 | if __name__ == "__main__": 459 | train(_parse_args()) 460 | -------------------------------------------------------------------------------- /d2_custom_sku110k.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Detectron2 on SKU-110K dataset\n", 8 | "\n", 9 | "**Index**\n", 10 | "\n", 11 | "1. 
[Background](#Background)\n", 12 | "1. [Setup](#Setup)\n", 13 | "1. [Data](#Data)\n", 14 | "1. [Training](#Training)\n", 15 | "1. [Hyperparameter Tuning Jobs](#HPO)\n", 16 | "1. [Deploy: Batch Transform](#Deploy)\n", 17 | "1. [Visualization](#Visualization)" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "## Background\n", 25 | "\n", 26 | "Detectron2 is a Computer Vision framework which implements Object Detection algorithms. It is developed by Facebook AI Research team. While its ancestor, Detectron, was completely written in Caffe, Detecton2 was refactored in PyTorch to enable fast experiments and iterations from. Detectron2 has a rich model zoo that contains State-of-the-Art models for object detection, semantic segmentation and pose estimation, to cite a few. A modular design makes Detectron2 easily extensible, and, hence, cutting-edge research projects can be implemented on top of it. \n", 27 | "\n", 28 | "We use Detectron2 to train and evaluate models on the [SKU110k-dataset](https://github.com/eg4000/SKU110K_CVPR19). This open source dataset contains images of retail shelves. Each image contains about 150 objects, which makes it suitable to test dense scene object detection algortihms. Bounding boxes are associated with SKUs without distinguishing between categories of product.\n", 29 | "\n", 30 | "In this noteboook we use Object Detection models from Detectron2's model zoo. We then leverage Amazon SageMaker ML platform to finetune pre-trained models on SKU110k dataset and deploy trained model for inference." 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "## Setup\n", 38 | "\n", 39 | "#### Precondition\n", 40 | "If you are executing this notebook using Sagemaker Notebook instance or Sagemaker Studio instance, please make sure that it has IAM role used with `AmazonSageMakerFullAccess` policy.\n", 41 | "\n", 42 | "We start by importing required Python libraries and configuring some common parameters" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import boto3\n", 52 | "import sagemaker\n", 53 | "\n", 54 | "assert (\n", 55 | " sagemaker.__version__.split(\".\")[0] == \"2\"\n", 56 | "), \"Please upgrade SageMaker Python SDK to version 2\"" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "bucket = \"FILL WITH UNIQUE BUCKET NAME\" # TODO: update this value\n", 66 | "prefix_data = \"detectron2/data\"\n", 67 | "prefix_model = \"detectron2/training_artefacts\"\n", 68 | "prefix_code = \"detectron2/model\"\n", 69 | "prefix_predictions = \"detectron2/predictions\"\n", 70 | "local_folder = \"cache\" # cache folder used to store downloaded data - not versioned\n", 71 | "\n", 72 | "\n", 73 | "sm_session = sagemaker.Session(default_bucket=bucket)\n", 74 | "role = sagemaker.get_execution_role()\n", 75 | "region = sm_session.boto_region_name\n", 76 | "account = sm_session.account_id()\n", 77 | "\n", 78 | "# if bucket doesn't exist, create one\n", 79 | "s3_resource = boto3.resource(\"s3\")\n", 80 | "if not s3_resource.Bucket(bucket) in s3_resource.buckets.all():\n", 81 | " s3_resource.create_bucket(\n", 82 | " Bucket=bucket, CreateBucketConfiguration={\"LocationConstraint\": region}\n", 83 | " )" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": {}, 89 | "source": [ 90 | "## Dataset Preparation\n", 91 | 
"\n", 92 | "To prepare SKU110K for training, we need to do following:\n", 93 | "* download and unzip SKU-110K dataset;\n", 94 | "* split images into three channels (training, validation and test) according to the filename prefix;\n", 95 | "* remove images (and the corresponding annotations) that are corrupted, i.e. cannot be loaded by PIL.Image.load();\n", 96 | "* upload image channels to the S3 bucket;\n", 97 | "* reorganize annotations into augmented manifest files and upload these files to S3." 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "import json\n", 107 | "import os\n", 108 | "import tarfile\n", 109 | "import tempfile\n", 110 | "from datetime import datetime\n", 111 | "from pathlib import Path\n", 112 | "from typing import Mapping, Optional, Sequence\n", 113 | "from urllib import request\n", 114 | "\n", 115 | "import boto3\n", 116 | "import numpy as np\n", 117 | "import pandas as pd\n", 118 | "from tqdm import tqdm" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "### Download SKU-110K dataset\n", 126 | "\n", 127 | "The total size of the unzipped dataset is 12.2 GB. Please make sure to set the volume size of your notebook instance accordingly. We suggest a volume size equal to 30 GB.\n", 128 | "\n", 129 | "⚠️ dataset download and extraction will take ~15-20 minutes" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "! wget -P cache http://trax-geometry.s3.amazonaws.com/cvpr_challenge/SKU110K_fixed.tar.gz" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "sku_dataset_dirname = \"SKU110K_fixed\"\n", 148 | "assert Path(\n", 149 | " local_folder\n", 150 | ").exists(), f\"Set wget directory-prefix to {local_folder} in the previous cell\"\n", 151 | "\n", 152 | "\n", 153 | "def track_progress(members):\n", 154 | " i = 0\n", 155 | " for member in members:\n", 156 | " if i % 100 == 0:\n", 157 | " print(\".\", end=\"\")\n", 158 | " i += 1\n", 159 | " yield member\n", 160 | "\n", 161 | "\n", 162 | "if not (Path(local_folder) / sku_dataset_dirname).exists():\n", 163 | " compressed_file = tarfile.open(\n", 164 | " name=os.path.join(local_folder, sku_dataset_dirname + \".tar.gz\")\n", 165 | " )\n", 166 | " compressed_file.extractall(\n", 167 | " path=local_folder, members=track_progress(compressed_file)\n", 168 | " )\n", 169 | "else:\n", 170 | " print(f\"Using the data in `{local_folder}` folder\")" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": {}, 176 | "source": [ 177 | "### Reorganize images" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": null, 183 | "metadata": {}, 184 | "outputs": [], 185 | "source": [ 186 | "path_images = Path(local_folder) / sku_dataset_dirname / \"images\"\n", 187 | "assert path_images.exists(), f\"{path_images} not found\"\n", 188 | "\n", 189 | "prefix_to_channel = {\n", 190 | " \"train\": \"training\",\n", 191 | " \"val\": \"validation\",\n", 192 | " \"test\": \"test\",\n", 193 | "}\n", 194 | "for channel_name in prefix_to_channel.values():\n", 195 | " if not (path_images.parent / channel_name).exists():\n", 196 | " (path_images.parent / channel_name).mkdir()\n", 197 | "\n", 198 | "for path_img in path_images.iterdir():\n", 199 | " for prefix in 
prefix_to_channel:\n", 200 | " if path_img.name.startswith(prefix):\n", 201 | " path_img.replace(\n", 202 | " path_images.parent / prefix_to_channel[prefix] / path_img.name\n", 203 | " )" 204 | ] 205 | }, 206 | { 207 | "cell_type": "markdown", 208 | "metadata": {}, 209 | "source": [ 210 | "Detectron2 uses Pillow library to read images. We found out that some images in the SKU dataset are corrupted, which causes the dataloader to raise an IOError exception. Therefore, we remove them from the dataset. " 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "CORRUPTED_IMAGES = {\n", 220 | " \"training\": (\"train_4222.jpg\", \"train_5822.jpg\", \"train_882.jpg\", \"train_924.jpg\"),\n", 221 | " \"validation\": tuple(),\n", 222 | " \"test\": (\"test_274.jpg\", \"test_2924.jpg\"),\n", 223 | "}" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "for channel_name in prefix_to_channel.values():\n", 233 | " for img_name in CORRUPTED_IMAGES[channel_name]:\n", 234 | " try:\n", 235 | " (path_images.parent / channel_name / img_name).unlink()\n", 236 | " print(f\"{img_name} removed from channel {channel_name} \")\n", 237 | " except FileNotFoundError:\n", 238 | " print(f\"{img_name} not in channel {channel_name}\")" 239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": {}, 245 | "outputs": [], 246 | "source": [ 247 | "for channel_name in prefix_to_channel.values():\n", 248 | " print(\n", 249 | " f\"Number of {channel_name} images = {sum(1 for x in (path_images.parent / channel_name).glob('*.jpg'))}\"\n", 250 | " )" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "Upload dataset to S3. ⚠️ this operation will take some time (~10-15 minutes)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "channel_to_s3_imgs = {}\n", 267 | "\n", 268 | "for channel_name in prefix_to_channel.values():\n", 269 | " inputs = sm_session.upload_data(\n", 270 | " path=str(path_images.parent / channel_name),\n", 271 | " bucket=bucket,\n", 272 | " key_prefix=f\"{prefix_data}/{channel_name}\",\n", 273 | " )\n", 274 | " print(f\"{channel_name} images uploaded to {inputs}\")\n", 275 | " channel_to_s3_imgs[channel_name] = inputs" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": {}, 281 | "source": [ 282 | "### Reorganise annotations\n", 283 | "\n", 284 | "The annotations in SKU-110K dataset are stored in csv files. They are here reorganised into [augmented manifest files](https://docs.aws.amazon.com/sagemaker/latest/dg/augmented-manifest.html). See SageMaker documentation for specification on [bounding box annotations](https://docs.aws.amazon.com/sagemaker/latest/dg/sms-data-output.html#sms-output-box)." 
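For reference, each line of the augmented manifest built in the next cells has the shape sketched below. The keys mirror the output of a SageMaker Ground Truth bounding-box labeling job; the numeric values here are invented for illustration only.

```python
# Illustrative only: the structure of one augmented-manifest line (values are made up)
example_line = {
    "source-ref": "s3://<bucket>/detectron2/data/training/train_0.jpg",
    "sku": {
        "annotations": [
            {"class_id": 0, "left": 1020, "top": 1430, "width": 110, "height": 380}
        ],
        "image_size": [{"width": 2448, "depth": 3, "height": 3264}],
    },
    "sku-metadata": {
        "job_name": "labeling-job/sku-110k-training",
        "class-map": {"0": "SKU"},
        "human-annotated": "yes",
        "objects": [{"confidence": 0.0}],
        "type": "groundtruth/object-detection",
        "creation-date": "2021-01-01T00:00:00",
    },
}
```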
285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": null, 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "def create_annotation_channel(\n", 294 | " channel_id: str,\n", 295 | " path_to_annotation: Path,\n", 296 | " bucket_name: str,\n", 297 | " data_prefix: str,\n", 298 | " img_annotation_to_ignore: Optional[Sequence[str]] = None,\n", 299 | ") -> Sequence[Mapping]:\n", 300 | " r\"\"\"Change format from original to augmented manifest files\n", 301 | "\n", 302 | " Parameters\n", 303 | " ----------\n", 304 | " channel_id : str\n", 305 | " name of the channel, i.e. training, validation or test\n", 306 | " path_to_annotation : Path\n", 307 | " path to annotation file\n", 308 | " bucket_name : str\n", 309 | " bucket where the data are uploaded\n", 310 | " data_prefix : str\n", 311 | " bucket prefix\n", 312 | " img_annotation_to_ignore : Optional[Sequence[str]]\n", 313 | " annotation from these images are ignore because the corresponding images are corrupted, default to None\n", 314 | "\n", 315 | " Returns\n", 316 | " -------\n", 317 | " Sequence[Mapping]\n", 318 | " List of json lines, each lines contains the annotations for a single. This recreates the\n", 319 | " format of augmented manifest files that are generated by Amazon SageMaker GroundTruth\n", 320 | " labeling jobs\n", 321 | " \"\"\"\n", 322 | " if channel_id not in (\"training\", \"validation\", \"test\"):\n", 323 | " raise ValueError(\n", 324 | " f\"Channel identifier must be training, validation or test. The passed values is {channel_id}\"\n", 325 | " )\n", 326 | " if not path_to_annotation.exists():\n", 327 | " raise FileNotFoundError(f\"Annotation file {path_to_annotation} not found\")\n", 328 | "\n", 329 | " df_annotation = pd.read_csv(\n", 330 | " path_to_annotation,\n", 331 | " header=0,\n", 332 | " names=(\n", 333 | " \"image_name\",\n", 334 | " \"x1\",\n", 335 | " \"y1\",\n", 336 | " \"x2\",\n", 337 | " \"y2\",\n", 338 | " \"class\",\n", 339 | " \"image_width\",\n", 340 | " \"image_height\",\n", 341 | " ),\n", 342 | " )\n", 343 | "\n", 344 | " df_annotation[\"left\"] = df_annotation[\"x1\"]\n", 345 | " df_annotation[\"top\"] = df_annotation[\"y1\"]\n", 346 | " df_annotation[\"width\"] = df_annotation[\"x2\"] - df_annotation[\"x1\"]\n", 347 | " df_annotation[\"height\"] = df_annotation[\"y2\"] - df_annotation[\"y1\"]\n", 348 | " df_annotation.drop(columns=[\"x1\", \"x2\", \"y1\", \"y2\"], inplace=True)\n", 349 | "\n", 350 | " jsonlines = []\n", 351 | " for img_id in df_annotation[\"image_name\"].unique():\n", 352 | " if img_annotation_to_ignore and img_id in img_annotation_to_ignore:\n", 353 | " print(\n", 354 | " f\"Annotations for image {img_id} are neglected as the image is corrupted\"\n", 355 | " )\n", 356 | " continue\n", 357 | " img_annotations = df_annotation.loc[df_annotation[\"image_name\"] == img_id, :]\n", 358 | " annotations = []\n", 359 | " for (\n", 360 | " _,\n", 361 | " _,\n", 362 | " img_width,\n", 363 | " img_heigh,\n", 364 | " bbox_l,\n", 365 | " bbox_t,\n", 366 | " bbox_w,\n", 367 | " bbox_h,\n", 368 | " ) in img_annotations.itertuples(index=False):\n", 369 | " annotations.append(\n", 370 | " {\n", 371 | " \"class_id\": 0,\n", 372 | " \"width\": bbox_w,\n", 373 | " \"top\": bbox_t,\n", 374 | " \"left\": bbox_l,\n", 375 | " \"height\": bbox_h,\n", 376 | " }\n", 377 | " )\n", 378 | " jsonline = {\n", 379 | " \"sku\": {\n", 380 | " \"annotations\": annotations,\n", 381 | " \"image_size\": [{\"width\": img_width, \"depth\": 3, \"height\": img_heigh,}],\n", 382 | " 
},\n", 383 | " \"sku-metadata\": {\n", 384 | " \"job_name\": f\"labeling-job/sku-110k-{channel_id}\",\n", 385 | " \"class-map\": {\"0\": \"SKU\"},\n", 386 | " \"human-annotated\": \"yes\",\n", 387 | " \"objects\": len(annotations) * [{\"confidence\": 0.0}],\n", 388 | " \"type\": \"groundtruth/object-detection\",\n", 389 | " \"creation-date\": datetime.now()\n", 390 | " .replace(second=0, microsecond=0)\n", 391 | " .isoformat(),\n", 392 | " },\n", 393 | " \"source-ref\": f\"s3://{bucket_name}/{data_prefix}/{channel_id}/{img_id}\",\n", 394 | " }\n", 395 | " jsonlines.append(jsonline)\n", 396 | " return jsonlines" 397 | ] 398 | }, 399 | { 400 | "cell_type": "code", 401 | "execution_count": null, 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "annotation_folder = Path(local_folder) / sku_dataset_dirname / \"annotations\"\n", 406 | "channel_to_annotation_path = {\n", 407 | " \"training\": annotation_folder / \"annotations_train.csv\",\n", 408 | " \"validation\": annotation_folder / \"annotations_val.csv\",\n", 409 | " \"test\": annotation_folder / \"annotations_test.csv\",\n", 410 | "}\n", 411 | "channel_to_annotation = {}\n", 412 | "\n", 413 | "for channel in channel_to_annotation_path:\n", 414 | " annotations = create_annotation_channel(\n", 415 | " channel,\n", 416 | " channel_to_annotation_path[channel],\n", 417 | " bucket,\n", 418 | " prefix_data,\n", 419 | " CORRUPTED_IMAGES[channel],\n", 420 | " )\n", 421 | " print(f\"Number of {channel} annotations: {len(annotations)}\")\n", 422 | " channel_to_annotation[channel] = annotations" 423 | ] 424 | }, 425 | { 426 | "cell_type": "code", 427 | "execution_count": null, 428 | "metadata": {}, 429 | "outputs": [], 430 | "source": [ 431 | "def upload_annotations(p_annotations, p_channel: str):\n", 432 | " rsc_bucket = boto3.resource(\"s3\").Bucket(bucket)\n", 433 | "\n", 434 | " json_lines = [json.dumps(elem) for elem in p_annotations]\n", 435 | " to_write = \"\\n\".join(json_lines)\n", 436 | "\n", 437 | " with tempfile.NamedTemporaryFile(mode=\"w\") as fid:\n", 438 | " fid.write(to_write)\n", 439 | " rsc_bucket.upload_file(\n", 440 | " fid.name, f\"{prefix_data}/annotations/{p_channel}.manifest\"\n", 441 | " )" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [ 450 | "for channel_id, annotations in channel_to_annotation.items():\n", 451 | " upload_annotations(annotations, channel_id)" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "metadata": {}, 457 | "source": [ 458 | "Let's check on expected number of images in training, validation and test sets, so that any failures on upload or preprocessing are caught before user starts training" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": null, 464 | "metadata": {}, 465 | "outputs": [], 466 | "source": [ 467 | "channel_to_expected_size = {\n", 468 | " \"training\": 8215,\n", 469 | " \"validation\": 588,\n", 470 | " \"test\": 2934,\n", 471 | "}\n", 472 | "\n", 473 | "prefix_data = \"detectron2/data\"\n", 474 | "bucket_rsr = boto3.resource(\"s3\").Bucket(bucket)\n", 475 | "for channel_name, exp_nb in channel_to_expected_size.items():\n", 476 | " nb_objs = len(\n", 477 | " list(bucket_rsr.objects.filter(Prefix=f\"{prefix_data}/{channel_name}\"))\n", 478 | " )\n", 479 | " assert (\n", 480 | " nb_objs == exp_nb\n", 481 | " ), f\"The {channel_name} set should have {exp_nb} images but it contains {nb_objs} images\"" 482 | ] 483 | }, 484 | { 485 | 
"cell_type": "markdown", 486 | "metadata": {}, 487 | "source": [ 488 | "## Training using Amazon SageMaker \n", 489 | "\n", 490 | "To run training job on SageMaker we will:\n", 491 | "* build training container and push it to Amazon Elastic Container Registry (\"ECR\"), container includes all runtime dependencies and training script;\n", 492 | "* define training job configuration which includes training cluster configuration and model hyperparameters;\n", 493 | "* schedule training job, observe its progress.\n", 494 | "\n", 495 | "\n", 496 | "### Building training container\n", 497 | "Before we can build training container, we need to authethicate in shared ECR repo to retrieve Pytorch base image and in private ECR repository. Enter your region and account id below, and then execute the following cell to do it." 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "%%bash\n", 507 | "\n", 508 | "REGION=YOUR_REGION\n", 509 | "ACCOUNT=YOUR_ACCOUNT_ID\n", 510 | "\n", 511 | "aws ecr get-login-password --region $REGION | docker login --username AWS --password-stdin 763104351884.dkr.ecr.$REGION.amazonaws.com\n", 512 | "# loging to your private ECR\n", 513 | "aws ecr get-login-password --region $REGION | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.$REGION.amazonaws.com" 514 | ] 515 | }, 516 | { 517 | "cell_type": "markdown", 518 | "metadata": {}, 519 | "source": [ 520 | "Our build container uses AWS-authored Pytorch container as a base image. We extend base image with Detecton2 dependencies and copy training script. Execute cell below to review Dockerfile content." 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": null, 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "%%bash\n", 530 | "\n", 531 | "# execute this cell to review Docker container\n", 532 | "pygmentize -l docker Dockerfile.sku110ktraining" 533 | ] 534 | }, 535 | { 536 | "cell_type": "markdown", 537 | "metadata": {}, 538 | "source": [ 539 | "Next, we build the Docker container locally and then push it to ECR repository, so SageMaker can deploy this container on compute nodes at training time. Run command bellow to build and push container. The size of the resulting Docker image is approximately 5GB." 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "%%bash\n", 549 | "./build_and_push.sh sagemaker-d2-train-sku110k latest Dockerfile.sku110ktraining" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "metadata": {}, 555 | "source": [ 556 | "### Configure SageMaker training job\n", 557 | "\n", 558 | "Configuration includes following components:\n", 559 | "* data configuration defines where train/test/val datasets are stored;\n", 560 | "* container configuration;\n", 561 | "* model hyperparameters;\n", 562 | "* training job parameters such as size of cluster and instance type, metrics to monitor, etc." 
563 | ] 564 | }, 565 | { 566 | "cell_type": "code", 567 | "execution_count": null, 568 | "metadata": {}, 569 | "outputs": [], 570 | "source": [ 571 | "import json\n", 572 | "\n", 573 | "import boto3\n", 574 | "from sagemaker.estimator import Estimator" 575 | ] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "execution_count": null, 580 | "metadata": {}, 581 | "outputs": [], 582 | "source": [ 583 | "# Data configuration\n", 584 | "\n", 585 | "training_channel = f\"s3://{bucket}/{prefix_data}/training/\"\n", 586 | "validation_channel = f\"s3://{bucket}/{prefix_data}/validation/\"\n", 587 | "test_channel = f\"s3://{bucket}/{prefix_data}/test/\"\n", 588 | "\n", 589 | "annotation_channel = f\"s3://{bucket}/{prefix_data}/annotations/\"\n", 590 | "\n", 591 | "classes = [\n", 592 | " \"SKU\",\n", 593 | "]" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": null, 599 | "metadata": {}, 600 | "outputs": [], 601 | "source": [ 602 | "# Container configuration\n", 603 | "\n", 604 | "container_name = \"sagemaker-d2-train-sku110k\"\n", 605 | "container_version = \"latest\"\n", 606 | "training_image_uri = (\n", 607 | " f\"{account}.dkr.ecr.{region}.amazonaws.com/{container_name}:{container_version}\"\n", 608 | ")" 609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "# Metrics to monitor during training, each metric is scraped from container Stdout\n", 618 | "\n", 619 | "metrics = [\n", 620 | " {\"Name\": \"training:loss\", \"Regex\": \"total_loss: ([0-9\\\\.]+)\",},\n", 621 | " {\"Name\": \"training:loss_cls\", \"Regex\": \"loss_cls: ([0-9\\\\.]+)\",},\n", 622 | " {\"Name\": \"training:loss_box_reg\", \"Regex\": \"loss_box_reg: ([0-9\\\\.]+)\",},\n", 623 | " {\"Name\": \"training:loss_rpn_cls\", \"Regex\": \"loss_rpn_cls: ([0-9\\\\.]+)\",},\n", 624 | " {\"Name\": \"training:loss_rpn_loc\", \"Regex\": \"loss_rpn_loc: ([0-9\\\\.]+)\",},\n", 625 | " {\"Name\": \"validation:loss\", \"Regex\": \"total_val_loss: ([0-9\\\\.]+)\",},\n", 626 | " {\"Name\": \"validation:loss_cls\", \"Regex\": \"val_loss_cls: ([0-9\\\\.]+)\",},\n", 627 | " {\"Name\": \"validation:loss_box_reg\", \"Regex\": \"val_loss_box_reg: ([0-9\\\\.]+)\",},\n", 628 | " {\"Name\": \"validation:loss_rpn_cls\", \"Regex\": \"val_loss_rpn_cls: ([0-9\\\\.]+)\",},\n", 629 | " {\"Name\": \"validation:loss_rpn_loc\", \"Regex\": \"val_loss_rpn_loc: ([0-9\\\\.]+)\",},\n", 630 | "]" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": null, 636 | "metadata": {}, 637 | "outputs": [], 638 | "source": [ 639 | "# Training instance type\n", 640 | "\n", 641 | "training_instance = \"ml.p3.8xlarge\"\n", 642 | "if training_instance.startswith(\"local\"):\n", 643 | " training_session = sagemaker.LocalSession()\n", 644 | " training_session.config = {\"local\": {\"local_code\": True}}\n", 645 | "else:\n", 646 | " training_session = sm_session" 647 | ] 648 | }, 649 | { 650 | "cell_type": "markdown", 651 | "metadata": {}, 652 | "source": [ 653 | "The following hyper-parameters are used in the training job. Feel free to change them and experiment." 
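As a rough sanity check, the number of passes over the training set implied by the defaults in the next cell is `num-iter * batch-size / number of training images` (8,215 images after removing the corrupted files):

```python
# Back-of-the-envelope epoch estimate for the default hyper-parameters below
num_iter, batch_size, n_train_images = 900, 16, 8215
print(f"~{num_iter * batch_size / n_train_images:.2f} epochs")  # ~1.75 epochs
```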
654 | ] 655 | }, 656 | { 657 | "cell_type": "code", 658 | "execution_count": null, 659 | "metadata": {}, 660 | "outputs": [], 661 | "source": [ 662 | "# Model Hyperparameters\n", 663 | "\n", 664 | "od_algorithm = \"faster_rcnn\" # choose one in (\"faster_rcnn\", \"retinanet\")\n", 665 | "training_job_hp = {\n", 666 | " # Dataset\n", 667 | " \"classes\": json.dumps(classes),\n", 668 | " \"dataset-name\": json.dumps(\"sku110k\"),\n", 669 | " \"label-name\": json.dumps(\"sku\"),\n", 670 | " # Algo specs\n", 671 | " \"model-type\": json.dumps(od_algorithm),\n", 672 | " \"backbone\": json.dumps(\"R_101_FPN\"),\n", 673 | " # Data loader\n", 674 | " \"num-iter\": 900,\n", 675 | " \"log-period\": 500,\n", 676 | " \"batch-size\": 16,\n", 677 | " \"num-workers\": 8,\n", 678 | " # Optimization\n", 679 | " \"lr\": 0.005,\n", 680 | " \"lr-schedule\": 3,\n", 681 | " # Faster-RCNN specific\n", 682 | " \"num-rpn\": 517,\n", 683 | " \"bbox-head-pos-fraction\": 0.2,\n", 684 | " \"bbox-rpn-pos-fraction\": 0.4,\n", 685 | " # Prediction specific\n", 686 | " \"nms-thr\": 0.2,\n", 687 | " \"pred-thr\": 0.1,\n", 688 | " \"det-per-img\": 300,\n", 689 | " # Evaluation\n", 690 | " \"evaluation-type\": \"fast\",\n", 691 | "}" 692 | ] 693 | }, 694 | { 695 | "cell_type": "code", 696 | "execution_count": null, 697 | "metadata": {}, 698 | "outputs": [], 699 | "source": [ 700 | "# Compile Sagemaker Training job object and start training\n", 701 | "\n", 702 | "d2_estimator = Estimator(\n", 703 | " image_uri=training_image_uri,\n", 704 | " role=role,\n", 705 | " sagemaker_session=training_session,\n", 706 | " instance_count=2,\n", 707 | " instance_type=training_instance,\n", 708 | " hyperparameters=training_job_hp,\n", 709 | " metric_definitions=metrics,\n", 710 | " output_path=f\"s3://{bucket}/{prefix_model}\",\n", 711 | " base_job_name=f\"detectron2-{od_algorithm.replace('_', '-')}\",\n", 712 | ")\n", 713 | "\n", 714 | "d2_estimator.fit(\n", 715 | " {\n", 716 | " \"training\": training_channel,\n", 717 | " \"validation\": validation_channel,\n", 718 | " \"test\": test_channel,\n", 719 | " \"annotation\": annotation_channel,\n", 720 | " },\n", 721 | " wait=False,\n", 722 | ")" 723 | ] 724 | }, 725 | { 726 | "cell_type": "markdown", 727 | "metadata": {}, 728 | "source": [ 729 | "## HyperParameter Optimization with Amazon SageMaker\n", 730 | "\n", 731 | "SageMaker SDK comes with the `tuner` module that can be used to search for the optimal hyper-parameters (see more details [here](https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning.html)). Let's run several experiment with different model hyperparameters with aim to minize the validation loss. \n", 732 | "\n", 733 | "`hparams_range` dictionary that defines the hyper-parameters to be optimized. Feel free to modify it. ⚠️ Please note, a tuning job runs multiple training job. Therefore, be aware of the amount of computational resources that a tuner job requires." 
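To make that warning concrete, here is a rough upper bound on the compute used by the tuner configured in the following cells; the numbers are taken from the estimator and tuner settings below, so update them if you change those cells.

```python
# Rough upper bound on compute for the tuning job configured below
max_jobs, max_parallel_jobs = 2, 2   # HyperparameterTuner settings
max_run_hours = 2                    # max_run=2 * 60 * 60 seconds on the HPO estimator
print(
    f"<= {max_jobs * max_run_hours} ml.p3.8xlarge instance-hours in total, "
    f"up to {max_parallel_jobs} training jobs in parallel"
)
```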
734 | ] 735 | }, 736 | { 737 | "cell_type": "code", 738 | "execution_count": null, 739 | "metadata": {}, 740 | "outputs": [], 741 | "source": [ 742 | "from sagemaker.tuner import (\n", 743 | " CategoricalParameter,\n", 744 | " ContinuousParameter,\n", 745 | " HyperparameterTuner,\n", 746 | " IntegerParameter,\n", 747 | ")\n", 748 | "\n", 749 | "od_algorithm = \"retinanet\" # choose one in (\"faster_rcnn\", \"retinanet\")" 750 | ] 751 | }, 752 | { 753 | "cell_type": "code", 754 | "execution_count": null, 755 | "metadata": {}, 756 | "outputs": [], 757 | "source": [ 758 | "hparams_range = {\n", 759 | " \"lr\": ContinuousParameter(0.0005, 0.1),\n", 760 | "}\n", 761 | "if od_algorithm == \"faster_rcnn\":\n", 762 | " hparams_range.update(\n", 763 | " {\n", 764 | " \"bbox-rpn-pos-fraction\": ContinuousParameter(0.1, 0.5),\n", 765 | " \"bbox-head-pos-fraction\": ContinuousParameter(0.1, 0.5),\n", 766 | " }\n", 767 | " )\n", 768 | "elif od_algorithm == \"retinanet\":\n", 769 | " hparams_range.update(\n", 770 | " {\n", 771 | " \"focal-loss-gamma\": ContinuousParameter(2.5, 5.0),\n", 772 | " \"focal-loss-alpha\": ContinuousParameter(0.3, 1.0),\n", 773 | " }\n", 774 | " )\n", 775 | "else:\n", 776 | " assert False, f\"{od_algorithm} not supported\"" 777 | ] 778 | }, 779 | { 780 | "cell_type": "code", 781 | "execution_count": null, 782 | "metadata": {}, 783 | "outputs": [], 784 | "source": [ 785 | "obj_metric_name = \"validation:loss\"\n", 786 | "obj_type = \"Minimize\"\n", 787 | "metric_definitions = [\n", 788 | " {\"Name\": \"training:loss\", \"Regex\": \"total_loss: ([0-9\\\\.]+)\",},\n", 789 | " {\"Name\": \"training:loss_cls\", \"Regex\": \"loss_cls: ([0-9\\\\.]+)\",},\n", 790 | " {\"Name\": \"training:loss_box_reg\", \"Regex\": \"loss_box_reg: ([0-9\\\\.]+)\",},\n", 791 | " {\"Name\": obj_metric_name, \"Regex\": \"total_val_loss: ([0-9\\\\.]+)\",},\n", 792 | " {\"Name\": \"validation:loss_cls\", \"Regex\": \"val_loss_cls: ([0-9\\\\.]+)\",},\n", 793 | " {\"Name\": \"validation:loss_box_reg\", \"Regex\": \"val_loss_box_reg: ([0-9\\\\.]+)\",},\n", 794 | "]" 795 | ] 796 | }, 797 | { 798 | "cell_type": "code", 799 | "execution_count": null, 800 | "metadata": {}, 801 | "outputs": [], 802 | "source": [ 803 | "fixed_hparams = {\n", 804 | " # Dataset\n", 805 | " \"classes\": json.dumps(classes),\n", 806 | " \"dataset-name\": json.dumps(\"sku110k\"),\n", 807 | " \"label-name\": json.dumps(\"sku\"),\n", 808 | " # Algo specs\n", 809 | " \"model-type\": json.dumps(od_algorithm),\n", 810 | " \"backbone\": json.dumps(\"R_101_FPN\"),\n", 811 | " # Data loader\n", 812 | " \"num-iter\": 9000,\n", 813 | " \"log-period\": 500,\n", 814 | " \"batch-size\": 16,\n", 815 | " \"num-workers\": 8,\n", 816 | " # Optimization\n", 817 | " \"lr-schedule\": 3,\n", 818 | " # Prediction specific\n", 819 | " \"nms-thr\": 0.2,\n", 820 | " \"pred-thr\": 0.1,\n", 821 | " \"det-per-img\": 300,\n", 822 | " # Evaluation\n", 823 | " \"evaluation-type\": \"fast\",\n", 824 | "}\n", 825 | "\n", 826 | "hpo_estimator = Estimator(\n", 827 | " image_uri=training_image_uri,\n", 828 | " role=role,\n", 829 | " sagemaker_session=sm_session,\n", 830 | " instance_count=1,\n", 831 | " instance_type=\"ml.p3.8xlarge\",\n", 832 | " hyperparameters=fixed_hparams,\n", 833 | " output_path=f\"s3://{bucket}/{prefix_model}\",\n", 834 | " use_spot_instances=True, # Use spot instances to spare a\n", 835 | " max_run=2 * 60 * 60,\n", 836 | " max_wait=3 * 60 * 60,\n", 837 | ")" 838 | ] 839 | }, 840 | { 841 | "cell_type": "code", 842 | "execution_count": null, 843 | 
"metadata": {}, 844 | "outputs": [], 845 | "source": [ 846 | "tuner = HyperparameterTuner(\n", 847 | " hpo_estimator,\n", 848 | " obj_metric_name,\n", 849 | " hparams_range,\n", 850 | " metric_definitions,\n", 851 | " objective_type=obj_type,\n", 852 | " max_jobs=2,\n", 853 | " max_parallel_jobs=2,\n", 854 | " base_tuning_job_name=f\"hpo-d2-{od_algorithm.replace('_', '-')}\",\n", 855 | ")" 856 | ] 857 | }, 858 | { 859 | "cell_type": "code", 860 | "execution_count": null, 861 | "metadata": {}, 862 | "outputs": [], 863 | "source": [ 864 | "tuner.fit(\n", 865 | " inputs={\n", 866 | " \"training\": training_channel,\n", 867 | " \"validation\": validation_channel,\n", 868 | " \"test\": test_channel,\n", 869 | " \"annotation\": annotation_channel,\n", 870 | " },\n", 871 | " wait=False,\n", 872 | ")" 873 | ] 874 | }, 875 | { 876 | "cell_type": "code", 877 | "execution_count": null, 878 | "metadata": {}, 879 | "outputs": [], 880 | "source": [ 881 | "# Let's review outcomes of HyperParameter search\n", 882 | "\n", 883 | "hpo_tuning_job_name = tuner.latest_tuning_job.name\n", 884 | "bayes_metrics = sagemaker.HyperparameterTuningJobAnalytics(\n", 885 | " hpo_tuning_job_name\n", 886 | ").dataframe()\n", 887 | "bayes_metrics.sort_values([\"FinalObjectiveValue\"], ascending=True)" 888 | ] 889 | }, 890 | { 891 | "cell_type": "markdown", 892 | "metadata": {}, 893 | "source": [ 894 | "## Model Deployment on Amazon SageMaker\n", 895 | "\n", 896 | "Just like with model training, SageMaker is using containers to run inference. Hence, we start by preparing serving container which will be then deployed with on Amazon SageMaker Hosting platform." 897 | ] 898 | }, 899 | { 900 | "cell_type": "code", 901 | "execution_count": null, 902 | "metadata": {}, 903 | "outputs": [], 904 | "source": [ 905 | "%%bash\n", 906 | "\n", 907 | "# execute this cell to review Docker container\n", 908 | "pygmentize -l docker Dockerfile.sku110kserving" 909 | ] 910 | }, 911 | { 912 | "cell_type": "markdown", 913 | "metadata": {}, 914 | "source": [ 915 | "Run cell below to build the Docker container defined in the image `Dockerfile.sku110kserving` and push it to ECR. The size of the resulting Docker image is approximately 5GB." 916 | ] 917 | }, 918 | { 919 | "cell_type": "code", 920 | "execution_count": null, 921 | "metadata": {}, 922 | "outputs": [], 923 | "source": [ 924 | "%%bash\n", 925 | "\n", 926 | "./build_and_push.sh sagemaker-d2-serve latest Dockerfile.sku110kserving" 927 | ] 928 | }, 929 | { 930 | "cell_type": "markdown", 931 | "metadata": {}, 932 | "source": [ 933 | "We will run batch inference, i.e. running inference against large chunk of images. We use [SageMaker Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-batch.html) to do it. " 934 | ] 935 | }, 936 | { 937 | "cell_type": "code", 938 | "execution_count": null, 939 | "metadata": {}, 940 | "outputs": [], 941 | "source": [ 942 | "from sagemaker.pytorch import PyTorchModel" 943 | ] 944 | }, 945 | { 946 | "cell_type": "markdown", 947 | "metadata": {}, 948 | "source": [ 949 | "Here we assume that a HPO job was executed. 
We attach the tuning job and fetch the best model" 950 | ] 951 | }, 952 | { 953 | "cell_type": "code", 954 | "execution_count": null, 955 | "metadata": {}, 956 | "outputs": [], 957 | "source": [ 958 | "from sagemaker.tuner import HyperparameterTuner\n", 959 | "\n", 960 | "tuning_job_id = \"Insert tuning job id\"\n", 961 | "attached_tuner = HyperparameterTuner.attach(tuning_job_id)\n", 962 | "\n", 963 | "best_estimator = attached_tuner.best_estimator()\n", 964 | "\n", 965 | "best_estimator.latest_training_job.describe()\n", 966 | "training_job_artifact = best_estimator.latest_training_job.describe()[\"ModelArtifacts\"][\n", 967 | " \"S3ModelArtifacts\"\n", 968 | "]" 969 | ] 970 | }, 971 | { 972 | "cell_type": "markdown", 973 | "metadata": {}, 974 | "source": [ 975 | "You can also specify the S3 URI of model artifact. Uncomment the following code if you want to use this option:" 976 | ] 977 | }, 978 | { 979 | "cell_type": "code", 980 | "execution_count": null, 981 | "metadata": {}, 982 | "outputs": [], 983 | "source": [ 984 | "# training_job_artifact = \"Your model artifacts\"" 985 | ] 986 | }, 987 | { 988 | "cell_type": "code", 989 | "execution_count": null, 990 | "metadata": {}, 991 | "outputs": [], 992 | "source": [ 993 | "# Define parameters of inference container\n", 994 | "\n", 995 | "serve_container_name = \"sagemaker-d2-serve\"\n", 996 | "serve_container_version = \"latest\"\n", 997 | "serve_image_uri = f\"{account}.dkr.ecr.{region}.amazonaws.com/{serve_container_name}:{serve_container_version}\"\n", 998 | "\n", 999 | "inference_output = f\"s3://{bucket}/{prefix_predictions}/{serve_container_name}/{Path(test_channel).name}_channel/{training_job_artifact.split('/')[-3]}\"\n", 1000 | "inference_output" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "code", 1005 | "execution_count": null, 1006 | "metadata": {}, 1007 | "outputs": [], 1008 | "source": [ 1009 | "# Compile SageMaker model object and configure Batch Transform job\n", 1010 | "\n", 1011 | "model = PyTorchModel(\n", 1012 | " name=\"d2-sku110k-model\",\n", 1013 | " model_data=training_job_artifact,\n", 1014 | " role=role,\n", 1015 | " sagemaker_session=sm_session,\n", 1016 | " entry_point=\"predict_sku110k.py\",\n", 1017 | " source_dir=\"container_serving\",\n", 1018 | " image_uri=serve_image_uri,\n", 1019 | " framework_version=\"1.6.0\",\n", 1020 | " code_location=f\"s3://{bucket}/{prefix_code}\",\n", 1021 | ")\n", 1022 | "\n", 1023 | "transformer = model.transformer(\n", 1024 | " instance_count=1,\n", 1025 | " instance_type=\"ml.p3.2xlarge\", # \"ml.p2.xlarge\", #\n", 1026 | " output_path=inference_output,\n", 1027 | " max_payload=16,\n", 1028 | ")" 1029 | ] 1030 | }, 1031 | { 1032 | "cell_type": "code", 1033 | "execution_count": null, 1034 | "metadata": {}, 1035 | "outputs": [], 1036 | "source": [ 1037 | "# Start Batch Transform job\n", 1038 | "\n", 1039 | "transformer.transform(\n", 1040 | " data=test_channel,\n", 1041 | " data_type=\"S3Prefix\",\n", 1042 | " content_type=\"application/x-image\",\n", 1043 | " wait=False,\n", 1044 | ")" 1045 | ] 1046 | }, 1047 | { 1048 | "cell_type": "markdown", 1049 | "metadata": {}, 1050 | "source": [ 1051 | "## Visualization\n", 1052 | "\n", 1053 | "Once our batch inference job is completed, let's visualize predictions. We'll use single random image for visualization. Feel free to re-run it many times." 
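The Batch Transform job above was started with `wait=False`. If you reach this point before it has completed, you can block until it finishes (the call also streams the job logs) before listing the prediction objects:

```python
# Wait for the Batch Transform job launched above to finish before reading its output from S3
transformer.wait()
```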
1054 | ] 1055 | }, 1056 | { 1057 | "cell_type": "code", 1058 | "execution_count": null, 1059 | "metadata": {}, 1060 | "outputs": [], 1061 | "source": [ 1062 | "import io\n", 1063 | "\n", 1064 | "import matplotlib\n", 1065 | "import matplotlib.patches as patches\n", 1066 | "import numpy as np\n", 1067 | "from matplotlib import pyplot as plt\n", 1068 | "from PIL import Image" 1069 | ] 1070 | }, 1071 | { 1072 | "cell_type": "code", 1073 | "execution_count": null, 1074 | "metadata": {}, 1075 | "outputs": [], 1076 | "source": [ 1077 | "def key_from_uri(s3_uri: str) -> str:\n", 1078 | " \"\"\"Get S3 object key from its URI\"\"\"\n", 1079 | " return \"/\".join(Path(s3_uri).parts[2:])\n", 1080 | "\n", 1081 | "\n", 1082 | "bucket_rsr = boto3.resource(\"s3\").Bucket(bucket)\n", 1083 | "predict_objs = list(\n", 1084 | " bucket_rsr.objects.filter(Prefix=key_from_uri(inference_output) + \"/\")\n", 1085 | ")\n", 1086 | "img_objs = list(bucket_rsr.objects.filter(Prefix=key_from_uri(test_channel)))" 1087 | ] 1088 | }, 1089 | { 1090 | "cell_type": "code", 1091 | "execution_count": null, 1092 | "metadata": {}, 1093 | "outputs": [], 1094 | "source": [ 1095 | "COLORS = [\n", 1096 | " (0, 200, 0),\n", 1097 | "]\n", 1098 | "\n", 1099 | "\n", 1100 | "def plot_predictions_on_image(\n", 1101 | " p_img: np.ndarray, p_preds: Mapping, score_thr: float = 0.5, show=True\n", 1102 | ") -> plt.Figure:\n", 1103 | " r\"\"\"Plot bounding boxes predicted by an inference job on the corresponding image\n", 1104 | "\n", 1105 | " Parameters\n", 1106 | " ----------\n", 1107 | " p_img : np.ndarray\n", 1108 | " input image used for prediction\n", 1109 | " p_preds : Mapping\n", 1110 | " dictionary with bounding boxes, predicted classes and confidence scores\n", 1111 | " score_thr : float, optional\n", 1112 | " show bounding boxes whose confidence score is bigger than `score_thr`, by default 0.5\n", 1113 | " show : bool, optional\n", 1114 | " show figure if True do not otherwise, by default True\n", 1115 | "\n", 1116 | " Returns\n", 1117 | " -------\n", 1118 | " plt.Figure\n", 1119 | " figure handler\n", 1120 | "\n", 1121 | " Raises\n", 1122 | " ------\n", 1123 | " IOError\n", 1124 | " If the prediction dictionary `p_preds` does not contain one of the required keys:\n", 1125 | " `pred_classes`, `pred_boxes` and `scores`\n", 1126 | " \"\"\"\n", 1127 | " for required_key in (\"pred_classes\", \"pred_boxes\", \"scores\"):\n", 1128 | " if required_key not in p_preds:\n", 1129 | " raise IOError(f\"Missing required key: {required_key}\")\n", 1130 | "\n", 1131 | " fig, fig_axis = plt.subplots(1)\n", 1132 | " fig_axis.imshow(p_img)\n", 1133 | " for class_id, bbox, score in zip(\n", 1134 | " p_preds[\"pred_classes\"], p_preds[\"pred_boxes\"], p_preds[\"scores\"]\n", 1135 | " ):\n", 1136 | " if score < score_thr:\n", 1137 | " break # bounding boxes are sorted by confidence score in descending order\n", 1138 | " rect = patches.Rectangle(\n", 1139 | " (bbox[0], bbox[1]),\n", 1140 | " bbox[2] - bbox[0],\n", 1141 | " bbox[3] - bbox[1],\n", 1142 | " linewidth=1,\n", 1143 | " edgecolor=[float(val) / 255 for val in COLORS[class_id]],\n", 1144 | " facecolor=\"none\",\n", 1145 | " )\n", 1146 | " fig_axis.add_patch(rect)\n", 1147 | " plt.axis(\"off\")\n", 1148 | " if show:\n", 1149 | " plt.show()\n", 1150 | " return fig" 1151 | ] 1152 | }, 1153 | { 1154 | "cell_type": "code", 1155 | "execution_count": null, 1156 | "metadata": {}, 1157 | "outputs": [], 1158 | "source": [ 1159 | "matplotlib.rcParams[\"figure.dpi\"] = 300\n", 1160 | "\n", 1161 | "sample_id 
= np.random.randint(0, len(img_objs), 1)[0]\n", 1162 | "\n", 1163 | "img_obj = img_objs[sample_id]\n", 1164 | "pred_obj = predict_objs[sample_id]\n", 1165 | "\n", 1166 | "img = np.asarray(Image.open(io.BytesIO(img_obj.get()[\"Body\"].read())))\n", 1167 | "preds = json.loads(pred_obj.get()[\"Body\"].read().decode(\"utf-8\"))\n", 1168 | "\n", 1169 | "sample_fig = plot_predictions_on_image(img, preds, 0.40, True)" 1170 | ] 1171 | } 1172 | ], 1173 | "metadata": { 1174 | "kernelspec": { 1175 | "display_name": "conda_pytorch_p36", 1176 | "language": "python", 1177 | "name": "conda_pytorch_p36" 1178 | }, 1179 | "language_info": { 1180 | "codemirror_mode": { 1181 | "name": "ipython", 1182 | "version": 3 1183 | }, 1184 | "file_extension": ".py", 1185 | "mimetype": "text/x-python", 1186 | "name": "python", 1187 | "nbconvert_exporter": "python", 1188 | "pygments_lexer": "ipython3", 1189 | "version": "3.6.13" 1190 | } 1191 | }, 1192 | "nbformat": 4, 1193 | "nbformat_minor": 4 1194 | } 1195 | --------------------------------------------------------------------------------