├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── questions-or-general-feedbacks.md └── workflows │ ├── ci.yml │ └── codeql.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .viperlightignore ├── .viperlightrc ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.md ├── THIRD_PARTY ├── cloudformation ├── defect-detection-permissions.yaml ├── defect-detection-sagemaker-notebook-instance.yaml ├── defect-detection.yaml └── solution-assistant │ ├── requirements.in │ ├── requirements.txt │ ├── solution-assistant.yaml │ └── src │ └── lambda_fn.py ├── docs ├── arch.png ├── data.png ├── data_flow.png ├── launch.svg ├── numerical.png ├── sagemaker.png ├── sample1.png ├── sample2.png ├── sample3.png └── train_arch.png ├── manifest.json ├── notebooks ├── 0_demo.ipynb ├── 1_retrain_from_checkpoint.ipynb ├── 2_detection_from_scratch.ipynb ├── 3_classification_from_scratch.ipynb └── 4_finetune.ipynb ├── pyproject.toml ├── requirements.in ├── requirements.txt ├── scripts ├── build.sh ├── find_best_ckpt.py ├── set_kernelspec.py ├── train_classifier.sh └── train_detector.sh ├── setup.cfg ├── setup.py └── src ├── im2rec.py ├── prepare_RecordIO.py ├── prepare_data ├── neu.py └── test_neu.py ├── sagemaker_defect_detection ├── __init__.py ├── classifier.py ├── dataset │ ├── __init__.py │ └── neu.py ├── detector.py ├── models │ ├── __init__.py │ └── ddn.py ├── transforms.py └── utils │ ├── __init__.py │ ├── coco_eval.py │ ├── coco_utils.py │ └── visualize.py ├── utils.py └── xml2json.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.py text diff=python 2 | *.ipynb text 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug]" 5 | labels: bug 6 | assignees: ehsanmok 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior and error details 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Additional context** 20 | Add any other context about the problem here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature]" 5 | labels: enhancement 6 | assignees: ehsanmok 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-or-general-feedbacks.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Questions or general feedbacks 3 | about: Question and feedbacks 4 | title: "[General]" 5 | labels: question 6 | assignees: ehsanmok 7 | 8 | --- 9 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [ mainline ] 7 | 8 | jobs: 9 | build: 10 | 11 | strategy: 12 | max-parallel: 4 13 | fail-fast: false 14 | matrix: 15 | python-version: [3.7, 3.8] 16 | platform: [ubuntu-latest] 17 | 18 | runs-on: ${{ matrix.platform }} 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: CloudFormation lint 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install cfn-lint 30 | for y in `find deploy/* -name "*.yaml" -o -name "*.template" -o -name "*.json"`; do 31 | echo "============= $y ================" 32 | cfn-lint --fail-on-warnings $y || ec1=$? 33 | done 34 | if [ "$ec1" -ne "0" ]; then echo 'ERROR-1'; else echo 'SUCCESS-1'; ec1=0; fi 35 | echo "Exit Code 1 `echo $ec1`" 36 | if [ "$ec1" -ne "0" ]; then echo 'ERROR'; ec=1; else echo 'SUCCESS'; ec=0; fi; 37 | echo "Exit Code Final `echo $ec`" 38 | exit $ec 39 | - name: Build the package 40 | run: | 41 | python -m pip install -e '.[dev,test,doc]' 42 | - name: Code style check 43 | run: | 44 | black --check src 45 | - name: Notebook style check 46 | run: | 47 | black-nb notebooks/*.ipynb --check 48 | - name: Type check 49 | run: | 50 | mypy --ignore-missing-imports --allow-redefinition --pretty --show-error-context src/ 51 | - name: Run tests 52 | run: | 53 | pytest --pyargs src/ -s 54 | - name: Update docs 55 | run: | 56 | portray on_github_pages -m "Update gh-pages" -f 57 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yaml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | pull_request: 6 | schedule: 7 | - cron: '0 21 * * 1' 8 | 9 | jobs: 10 | CodeQL-Build: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | language: ['python'] 17 | 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v2 21 | with: 22 | # We must fetch at least the immediate parents so that if this is 23 | # a pull request then we can checkout the head. 24 | fetch-depth: 2 25 | 26 | # If this run was triggered by a pull request event, then checkout 27 | # the head of the pull request instead of the merge commit. 
28 | - run: git checkout HEAD^2 29 | if: ${{ github.event_name == 'pull_request' }} 30 | 31 | - name: Initialize CodeQL 32 | uses: github/codeql-action/init@v1 33 | with: 34 | languages: ${{ matrix.language }} 35 | 36 | - name: Perform CodeQL Analysis 37 | uses: github/codeql-action/analyze@v1 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | __pycache__ 3 | .vscode 4 | .mypy* 5 | .ipynb* 6 | .pytest* 7 | *output 8 | *logs* 9 | *checkpoints* 10 | *.ckpt* 11 | *.tar.gz 12 | dist 13 | build 14 | site-packages 15 | stack_outputs.json 16 | docs/*.py 17 | raw_neu_det 18 | neu_det 19 | raw_neu_cls 20 | neu_cls 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.6 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v2.3.0 7 | hooks: 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - repo: https://github.com/psf/black 11 | rev: 20.8b1 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/kynan/nbstripout 15 | rev: 0.3.7 16 | hooks: 17 | - id: nbstripout 18 | name: nbstripout 19 | description: "nbstripout: strip output from Jupyter and IPython notebooks" 20 | entry: nbstripout notebooks 21 | language: python 22 | types: [jupyter] 23 | - repo: https://github.com/tomcatling/black-nb 24 | rev: 0.3.0 25 | hooks: 26 | - id: black-nb 27 | name: black-nb 28 | entry: black-nb 29 | language: python 30 | args: ["--include", '\.ipynb$'] 31 | -------------------------------------------------------------------------------- /.viperlightignore: -------------------------------------------------------------------------------- 1 | ^dist/ 2 | docs/launch.svg:4 3 | CODE_OF_CONDUCT.md:4 4 | CONTRIBUTING.md:50 5 | build/solution-assistant.zip 6 | cloudformation/defect-detection-endpoint.yaml:42 7 | cloudformation/defect-detection-permissions.yaml:114 8 | docs/NEU_surface_defect_database.pdf 9 | -------------------------------------------------------------------------------- /.viperlightrc: -------------------------------------------------------------------------------- 1 | { 2 | "failOn": "medium", 3 | "all": true 4 | } 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 
5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *mainline* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 
62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 
176 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Inspection Automation with Amazon SageMaker 2 | 3 |

4 | Actions Status 5 | CodeQL Status 6 | Code style: black 7 |
8 | License: Apache-2.0 9 | Maintenance 10 | AMA 11 |

12 | 13 | This solution detects product defects with an end-to-end Deep Learning workflow for quality control in manufacturing processes. The solution takes product images as input and identifies defect regions with bounding boxes. In particular, this solution takes two distinct approaches: 14 | 1. Use an implementation of the *Defect Detection Network (DDN)* algorithm following [An End-to-End Steel Surface Defect Detection](https://ieeexplore.ieee.org/document/8709818) on the [NEU surface defect database](http://faculty.neu.edu.cn/yunhyan/NEU_surface_defect_database.html) (see [references](#references)) in PyTorch using [PyTorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning). 15 | 2. Use a pre-trained SageMaker object detection model and fine-tune it on the target dataset. 16 | 17 | This solution demonstrates the immense advantage of fine-tuning a high-quality pre-trained model on the target dataset, both visually and numerically. 18 | 19 | ### Contents 20 | 1. [Overview](#overview) 21 | 1. [What Does the Input Data Look Like?](#input) 22 | 2. [How to Prepare Your Data to Feed into the Model?](#preparedata) 23 | 3. [What are the Outputs?](#output) 24 | 4. [What is the Estimated Cost?](#cost) 25 | 5. [What Algorithms & Models are Used?](#algorithms) 26 | 6. [What Does the Data Flow Look Like?](#dataflow) 27 | 2. [Solution Details](#solution) 28 | 1. [Background](#background) 29 | 2. [What is Visual Inspection?](#inspection) 30 | 3. [What are the Problems?](#problems) 31 | 4. [What Does this Solution Offer?](#offer) 32 | 3. [Architecture Overview](#architecture) 33 | 4. [Cleaning up](#cleaning-up) 34 | 5. [Customization](#customization) 35 | 36 | 37 | ## 1. Overview 38 | 39 | ### 1.1. What Does the Input Data Look Like? 40 | 41 | The input is an image of a defective / non-defective product. The training data should have relatively balanced classes, with annotations for ground-truth defects (locations and defect types) per image. Here are examples of annotations used in the demo; they show some "inclusion" defects on the surface: 42 | 43 | !["sample2"](https://sagemaker-solutions-prod-us-east-2.s3.us-east-2.amazonaws.com/sagemaker-defect-detection/docs/sample2.png) 44 | 45 | The NEU surface defect database (see [references](#references)) is a *balanced* dataset which contains 46 | 47 | > Six kinds of typical surface defects of the hot-rolled steel strip are collected, i.e., rolled-in scale (RS), patches (Pa), crazing (Cr), pitted surface (PS), inclusion (In) and scratches (Sc). The database includes 1,800 grayscale images: 300 samples each of six different kinds of typical surface defects 48 | 49 | Here is a sample image of the six classes: 50 | 51 | !["data sample"](https://sagemaker-solutions-prod-us-east-2.s3.us-east-2.amazonaws.com/sagemaker-defect-detection/docs/data.png) 52 | 53 | ### 1.2. How to Prepare Your Data to Feed into the Model? 54 | 55 | There are data preparation and preprocessing steps that should be followed in the notebooks. It is critical to prepare your image annotations beforehand. 56 | * For training the DDN model, please prepare one XML annotation file per image with its defect annotations (a minimal parsing sketch follows below). Check notebooks 0, 1, 2, and 3 for details. 57 | * For fine-tuning the pretrained SageMaker models, you need to prepare either a single `annotation.json` for all data, or a single `RecordIO` file containing all images and all annotations. Check notebook 4 for details. 58 |
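The NEU-DET archive ships its defect annotations as one XML file per image in a Pascal VOC-style layout. As a rough illustration of the kind of per-image annotation the DDN notebooks consume, the sketch below reads one such file into label/box tuples. The file path and exact tag layout are assumptions made for this example; notebooks 0-3 and the data preparation code under `src/` remain the authoritative reference.

```python
# Illustrative sketch only: parse a single Pascal VOC-style annotation file.
# Assumes the usual <object> / <name> / <bndbox> layout used by NEU-DET.
import xml.etree.ElementTree as ET


def parse_annotation(xml_path: str):
    """Return a list of (label, xmin, ymin, xmax, ymax) tuples."""
    root = ET.parse(xml_path).getroot()
    boxes = []
    for obj in root.iter("object"):
        label = obj.find("name").text  # e.g. "inclusion"
        bndbox = obj.find("bndbox")
        xmin, ymin, xmax, ymax = (
            int(float(bndbox.find(tag).text)) for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        boxes.append((label, xmin, ymin, xmax, ymax))
    return boxes


if __name__ == "__main__":
    # Hypothetical path; NEU-DET typically pairs inclusion_1.jpg with inclusion_1.xml.
    print(parse_annotation("ANNOTATIONS/inclusion_1.xml"))
```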
59 | ### 1.3. What are the Outputs? 60 | 61 | * For each image, the trained model will produce bounding boxes of detected visual defects (if any), the predicted defect type, and a prediction confidence score (between 0 and 1). 62 | * If you have a labeled test dataset, you can obtain the mean Average Precision (mAP) score for each model and compare it across all the models. 63 | * For example, here are the mAP scores on a test set of the NEU dataset: 64 | 65 | | | DDN | Type1 | Type1+HPO | Type2 | Type2+HPO| 66 | | --- | --- | --- | --- | --- | ---| 67 | | mAP | 0.08 | 0.067 | 0.226 | 0.371 | 0.375| 68 | 69 | 70 | ### 1.4. What is the Estimated Cost? 71 | 72 | * Running solution notebooks 0-3 end-to-end costs around $8 USD and takes less than an hour, assuming a p3.2xlarge EC2 instance at the $3.06 on-demand hourly rate in US East. These notebooks only train DDN models for a few iterations for demonstration purposes, which is **far from convergence**. It would take around 8 hours to train from scratch to convergence and cost $25+. 73 | * Running solution notebook 4 costs around $130-140 USD. This notebook provides advanced material, including fine-tuning two types of pretrained SageMaker models **to convergence**, with and without hyperparameter optimization (HPO), resulting in four models for inference. You can choose to train one model or all four, according to your budget and requirements. The cost and runtime for training each model are: 74 | 75 | | Model | Cost (USD) | Runtime (Hours) | Billable time (Hours)| 76 | |:----------:|:---------------:|:----:|:-----:| 77 | |Type 1| 1.5 | 0.5 | 0.5| 78 | |Type 1 with HPO (20 jobs)| 30.6 | 1* | 10| 79 | |Type 2| 4.6 | 1.5 | 1.5| 80 | |Type 2 with HPO (20 jobs)| 92 | 3* | 30| 81 | (*) HPO tasks in this solution run 20 jobs in total with 10 jobs in parallel, so 1 hour of actual runtime amounts to 10 billable instance-hours (for example, Type 2 with HPO runs for about 3 hours but bills roughly 30 instance-hours, i.e., 30 x $3.06, which is about $92). 82 | * Please make sure you have read the cleaning-up steps in [Section 4](#cleaning-up) after training, to avoid incurring costs from deployed models. 83 | 84 | 85 | 86 | ### 1.5. What Algorithms & Models are Used? 87 | 88 | * The DDN model is based on [Faster R-CNN](https://arxiv.org/abs/1506.01497). For more details, please read the paper [An End-to-End Steel Surface Defect Detection](https://ieeexplore.ieee.org/document/8709818). 89 | * The pretrained SageMaker models include SSD and Faster R-CNN models, using VGG, ResNet, or MobileNet as the backbone (some variants with an FPN), pretrained on the ImageNet, COCO, or VOC datasets. 90 | 91 | ### 1.6. What Does the Data Flow Look Like? 92 | 93 | ![Data flow](https://sagemaker-solutions-prod-us-east-2.s3.us-east-2.amazonaws.com/sagemaker-defect-detection/docs/data_flow.png) 94 | 95 | ## 2. Solution Details 96 | 97 | ### 2.1. Background 98 | 99 | According to the [Gartner study on the top 10 strategic tech trends for 2020](https://www.gartner.com/smarterwithgartner/gartner-top-10-strategic-technology-trends-for-2020/), hyper-automation is the number one trend for 2020 and will continue to advance. In manufacturing, one of the main barriers to hyper-automation is in areas where human involvement is still hard to reduce and intelligent systems struggle to reach parity with human visual recognition and become mainstream, despite the great advances of Deep Learning in Computer Vision. This is mainly due to the lack of sufficient annotated data (or sparse data) in areas such as _Quality Control_, where trained human eyes still dominate. 100 | 101 | 102 | ### 2.2. What is Visual Inspection?
103 | 104 | Visual inspection is the **analysis of products on the production line for the purpose of Quality Control**. According to [Everything you need to know about Visual Inspection with AI](https://nanonets.com/blog/ai-visual-inspection/), visual inspection can also be used for internal and external assessment of equipment in a production facility, such as storage tanks, pressure vessels, and piping, and it extends to many industries, from electronics and medical devices to food and raw materials. 105 | 106 | ### 2.3. What are the Problems? 107 | 108 | * *Human visual inspection error* is a major factor in this area. According to the report [The Role of Visual Inspection in the 21st Century](https://www.osti.gov/servlets/purl/1476816) 109 | 110 | > Most inspection tasks are much more complex and typically exhibit error rates of 20% to 30% (Drury & Fox, 1975) 111 | 112 | which directly translates to *cost*. 113 | * Cost: according to a [glassdoor estimate](https://www.glassdoor.co.in/Salaries/us-quality-control-inspector-salary-SRCH_IL.0,2_IN1_KO3,28.htm), a trained quality inspector's salary varies between $29K and $64K (USD) per year. 114 | 115 | ### 2.4. What Does this Solution Offer? 116 | 117 | This solution offers: 118 | 1. an implementation of the state-of-the-art Deep Learning approach for automatic *Steel Surface Defect Detection* using **Amazon SageMaker**. The [model](https://ieeexplore.ieee.org/document/8709818) enhances [Faster R-CNN](https://arxiv.org/abs/1506.01497) and outputs possible defects in a steel surface image. 119 | This solution trains a classifier on the **NEU-CLS** dataset as well as a detector on the **NEU-DET** dataset. 120 | 2. a complete solution using high-quality pretrained SageMaker models to fine-tune on the target dataset, with and without hyperparameter optimization (HPO). 121 | 122 | The **most important** information this solution delivers is that training a deep learning model from scratch on a small dataset can be both time-consuming and less effective, whereas fine-tuning a high-quality pretrained model, which was trained on a large-scale dataset, can be both cost- and runtime-efficient and highly performant. Here are sample detection results: 123 | 124 | drawing 125 | 126 | ## 3. Architecture Overview 127 | 128 | The following illustration shows the architecture for the end-to-end training and deployment process: 129 | 130 | !["Solution Architecture"](https://sagemaker-solutions-prod-us-east-2.s3.us-east-2.amazonaws.com/sagemaker-defect-detection/docs/train_arch.png) 131 | 132 | 1. The input data is located in an [Amazon S3](https://aws.amazon.com/s3/) bucket 133 | 2. The provided [SageMaker notebooks](notebooks/) get the input data and launch the later stages below 134 | 3. **Training Classifier and Detector models** and evaluating their results using Amazon SageMaker. If desired, one can deploy the trained models and create SageMaker endpoints 135 | 4. The **SageMaker endpoint** created in the previous step is an [HTTPS endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-hosting.html) capable of producing predictions (a minimal invocation sketch follows below) 136 | 5. Monitoring the training jobs and the deployed model via [Amazon CloudWatch](https://aws.amazon.com/cloudwatch/) 137 |
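Once an endpoint is deployed (step 4 above), it can be invoked over HTTPS. The sketch below is illustrative only: the endpoint name, content type, and response schema are assumptions, and the demo notebook (`notebooks/0_demo.ipynb`) shows the exact request and response format used by this solution.

```python
# Illustrative sketch only: invoke a deployed defect-detection endpoint with boto3.
# The endpoint name, content type, and response schema below are assumptions.
import json

import boto3

runtime = boto3.client("sagemaker-runtime")

with open("sample_image.jpg", "rb") as f:  # hypothetical test image
    payload = f.read()

response = runtime.invoke_endpoint(
    EndpointName="defect-detection-demo-endpoint",  # hypothetical; use the endpoint created by the notebooks
    ContentType="application/x-image",
    Body=payload,
)

predictions = json.loads(response["Body"].read())
# Assumed schema: a list of {"class": str, "score": float, "bbox": [xmin, ymin, xmax, ymax]}
for detection in predictions:
    print(detection)
```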
138 | ## 4. Cleaning up 139 | 140 | When you've finished with this solution, make sure that you delete all unwanted AWS resources. AWS CloudFormation can be used to automatically delete all standard resources that have been created by the solution and the notebooks. Go to the AWS CloudFormation Console and delete the parent stack; deleting the parent stack will automatically delete the nested stacks. 141 | 142 | **Caution:** You need to manually delete any extra resources that you may have created in the notebooks. Some examples include extra Amazon S3 buckets (in addition to the solution's default bucket) and extra Amazon SageMaker endpoints (created with a custom name). 143 | 144 | ## 5. Customization 145 | 146 | To use your own data, make sure it is labeled and forms a *relatively* balanced dataset. Also make sure the image annotations follow the required format. 147 | 148 | 149 | 150 | ### Useful Links 151 | 152 | * [Amazon SageMaker Getting Started](https://aws.amazon.com/sagemaker/getting-started/) 153 | * [Amazon SageMaker Developer Guide](https://docs.aws.amazon.com/sagemaker/latest/dg/whatis.html) 154 | * [Amazon SageMaker Python SDK Documentation](https://sagemaker.readthedocs.io/en/stable/) 155 | * [AWS CloudFormation User Guide](https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/Welcome.html) 156 | 157 | ### References 158 | 159 | * K. Song and Y. Yan, “A noise robust method based on completed local binary patterns for hot-rolled steel strip surface defects,” Applied Surface Science, vol. 285, pp. 858-864, Nov. 2013. 160 | 161 | * Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” IEEE Transactions on Instrumentation and Measurement, 2020, 69(4), 1493-1504. 162 | 163 | * Hongwen Dong, Kechen Song, Yu He, Jing Xu, Yunhui Yan, Qinggang Meng, “PGA-Net: Pyramid Feature Fusion and Global Context Attention Network for Automated Surface Defect Detection,” IEEE Transactions on Industrial Informatics, 2020. 164 | 165 | ### Security 166 | 167 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 168 | 169 | ### License 170 | 171 | This project is licensed under the Apache-2.0 License. 172 | -------------------------------------------------------------------------------- /THIRD_PARTY: -------------------------------------------------------------------------------- 1 | Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | SPDX-License-Identifier: Apache-2.0 3 | 4 | Original Copyright 2016 Soumith Chintala. All Rights Reserved.
5 | SPDX-License-Identifier: BSD-3 6 | https://github.com/pytorch/vision/blob/master/LICENSE 7 | -------------------------------------------------------------------------------- /cloudformation/defect-detection-permissions.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0015) - sagemaker-defect-detection permission stack" 3 | Parameters: 4 | SolutionPrefix: 5 | Type: String 6 | SolutionName: 7 | Type: String 8 | S3Bucket: 9 | Type: String 10 | StackVersion: 11 | Type: String 12 | 13 | Mappings: 14 | S3: 15 | release: 16 | BucketPrefix: "sagemaker-solutions-prod" 17 | development: 18 | BucketPrefix: "sagemaker-solutions-devo" 19 | 20 | Resources: 21 | SageMakerIAMRole: 22 | Type: AWS::IAM::Role 23 | Properties: 24 | RoleName: !Sub "${SolutionPrefix}-${AWS::Region}-nb-role" 25 | AssumeRolePolicyDocument: 26 | Version: "2012-10-17" 27 | Statement: 28 | - Effect: Allow 29 | Principal: 30 | AWS: 31 | - !Sub "arn:aws:iam::${AWS::AccountId}:root" 32 | Service: 33 | - sagemaker.amazonaws.com 34 | - lambda.amazonaws.com # for solution assistant resource cleanup 35 | Action: 36 | - "sts:AssumeRole" 37 | Metadata: 38 | cfn_nag: 39 | rules_to_suppress: 40 | - id: W28 41 | reason: Needs to be explicitly named to tighten launch permissions policy 42 | 43 | SageMakerIAMPolicy: 44 | Type: AWS::IAM::Policy 45 | Properties: 46 | PolicyName: !Sub "${SolutionPrefix}-nb-instance-policy" 47 | Roles: 48 | - !Ref SageMakerIAMRole 49 | PolicyDocument: 50 | Version: "2012-10-17" 51 | Statement: 52 | - Effect: Allow 53 | Action: 54 | - sagemaker:CreateTrainingJob 55 | - sagemaker:DescribeTrainingJob 56 | - sagemaker:CreateProcessingJob 57 | - sagemaker:DescribeProcessingJob 58 | - sagemaker:CreateModel 59 | - sagemaker:DescribeEndpointConfig 60 | - sagemaker:DescribeEndpoint 61 | - sagemaker:CreateEndpointConfig 62 | - sagemaker:CreateEndpoint 63 | - sagemaker:DeleteEndpointConfig 64 | - sagemaker:DeleteEndpoint 65 | - sagemaker:DeleteModel 66 | - sagemaker:InvokeEndpoint 67 | Resource: 68 | - !Sub "arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:*" 69 | - Effect: Allow 70 | Action: 71 | - cloudwatch:GetMetricData 72 | - cloudwatch:GetMetricStatistics 73 | - cloudwatch:ListMetrics 74 | - cloudwatch:PutMetricData 75 | Resource: 76 | - !Sub "arn:aws:cloudwatch:${AWS::Region}:${AWS::AccountId}:*" 77 | - Effect: Allow 78 | Action: 79 | - logs:CreateLogGroup 80 | - logs:CreateLogStream 81 | - logs:DescribeLogStreams 82 | - logs:GetLogEvents 83 | - logs:PutLogEvents 84 | Resource: 85 | - !Sub "arn:aws:logs:${AWS::Region}:${AWS::AccountId}:log-group:/aws/sagemaker/*" 86 | - Effect: Allow 87 | Action: 88 | - iam:PassRole 89 | Resource: 90 | - !GetAtt SageMakerIAMRole.Arn 91 | Condition: 92 | StringEquals: 93 | iam:PassedToService: sagemaker.amazonaws.com 94 | - Effect: Allow 95 | Action: 96 | - iam:GetRole 97 | Resource: 98 | - !GetAtt SageMakerIAMRole.Arn 99 | - Effect: Allow 100 | Action: 101 | - ecr:GetAuthorizationToken 102 | - ecr:GetDownloadUrlForLayer 103 | - ecr:BatchGetImage 104 | - ecr:BatchCheckLayerAvailability 105 | - ecr:CreateRepository 106 | - ecr:DescribeRepositories 107 | - ecr:InitiateLayerUpload 108 | - ecr:CompleteLayerUpload 109 | - ecr:UploadLayerPart 110 | - ecr:TagResource 111 | - ecr:PutImage 112 | - ecr:DescribeImages 113 | - ecr:BatchDeleteImage 114 | Resource: 115 | - "*" 116 | - !Sub "arn:aws:ecr:${AWS::Region}:${AWS::AccountId}:repository/*" 117 | - Effect: Allow 118 | Action: 119 | - 
s3:ListBucket 120 | Resource: !Sub 121 | - "arn:aws:s3:::${SolutionS3BucketName}-${AWS::Region}" 122 | - SolutionS3BucketName: 123 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 124 | - Effect: Allow 125 | Action: 126 | - s3:GetObject 127 | Resource: !Sub 128 | - "arn:aws:s3:::${SolutionS3BucketName}-${AWS::Region}/${SolutionName}/*" 129 | - SolutionS3BucketName: 130 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 131 | SolutionName: !Ref SolutionName 132 | - Effect: Allow 133 | Action: 134 | - s3:ListBucket 135 | - s3:DeleteBucket 136 | - s3:GetBucketLocation 137 | - s3:ListBucketMultipartUploads 138 | Resource: 139 | - !Sub "arn:aws:s3:::${S3Bucket}" 140 | - Effect: Allow 141 | Action: 142 | - s3:AbortMultipartUpload 143 | - s3:ListObject 144 | - s3:GetObject 145 | - s3:PutObject 146 | - s3:DeleteObject 147 | Resource: 148 | - !Sub "arn:aws:s3:::${S3Bucket}" 149 | - !Sub "arn:aws:s3:::${S3Bucket}/*" 150 | - Effect: Allow 151 | Action: 152 | - s3:CreateBucket 153 | - s3:ListBucket 154 | - s3:GetObject 155 | - s3:GetObjectVersion 156 | - s3:PutObject 157 | - s3:DeleteObject 158 | Resource: 159 | - !Sub "arn:aws:s3:::sagemaker-${AWS::Region}-${AWS::AccountId}" 160 | - !Sub "arn:aws:s3:::sagemaker-${AWS::Region}-${AWS::AccountId}/*" 161 | 162 | Metadata: 163 | cfn_nag: 164 | rules_to_suppress: 165 | - id: W12 166 | reason: ECR GetAuthorizationToken is non resource-specific action 167 | 168 | Outputs: 169 | SageMakerRoleArn: 170 | Description: "SageMaker Execution Role for the solution" 171 | Value: !GetAtt SageMakerIAMRole.Arn 172 | -------------------------------------------------------------------------------- /cloudformation/defect-detection-sagemaker-notebook-instance.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0015) - sagemaker-defect-detection notebook stack" 3 | Parameters: 4 | SolutionPrefix: 5 | Type: String 6 | SolutionName: 7 | Type: String 8 | S3Bucket: 9 | Type: String 10 | SageMakerIAMRoleArn: 11 | Type: String 12 | SageMakerNotebookInstanceType: 13 | Type: String 14 | StackVersion: 15 | Type: String 16 | 17 | Mappings: 18 | S3: 19 | release: 20 | BucketPrefix: "sagemaker-solutions-prod" 21 | development: 22 | BucketPrefix: "sagemaker-solutions-devo" 23 | 24 | Resources: 25 | NotebookInstance: 26 | Type: AWS::SageMaker::NotebookInstance 27 | Properties: 28 | DirectInternetAccess: Enabled 29 | InstanceType: !Ref SageMakerNotebookInstanceType 30 | LifecycleConfigName: !GetAtt LifeCycleConfig.NotebookInstanceLifecycleConfigName 31 | NotebookInstanceName: !Sub "${SolutionPrefix}" 32 | RoleArn: !Sub "${SageMakerIAMRoleArn}" 33 | VolumeSizeInGB: 100 34 | Metadata: 35 | cfn_nag: 36 | rules_to_suppress: 37 | - id: W1201 38 | reason: Solution does not have KMS encryption enabled by default 39 | LifeCycleConfig: 40 | Type: AWS::SageMaker::NotebookInstanceLifecycleConfig 41 | Properties: 42 | NotebookInstanceLifecycleConfigName: !Sub "${SolutionPrefix}-nb-lifecycle-config" 43 | OnStart: 44 | - Content: 45 | Fn::Base64: | 46 | set -e 47 | sudo -u ec2-user -i <> stack_outputs.json 69 | echo ' "AccountID": "${AWS::AccountId}",' >> stack_outputs.json 70 | echo ' "AWSRegion": "${AWS::Region}",' >> stack_outputs.json 71 | echo ' "IamRole": "${SageMakerIAMRoleArn}",' >> stack_outputs.json 72 | echo ' "SolutionPrefix": "${SolutionPrefix}",' >> stack_outputs.json 73 | echo ' "SolutionName": "${SolutionName}",' >> stack_outputs.json 74 | echo ' "SolutionS3Bucket": 
"${SolutionsRefBucketBase}",' >> stack_outputs.json 75 | echo ' "S3Bucket": "${S3Bucket}"' >> stack_outputs.json 76 | echo '}' >> stack_outputs.json 77 | cat stack_outputs.json 78 | sudo chown -R ec2-user:ec2-user * 79 | EOF 80 | - SolutionsRefBucketBase: 81 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 82 | 83 | Outputs: 84 | SourceCode: 85 | Description: "Open Jupyter IDE. This authenticate you against Jupyter." 86 | Value: !Sub "https://${NotebookInstance.NotebookInstanceName}.notebook.${AWS::Region}.sagemaker.aws/" 87 | 88 | SageMakerNotebookInstanceSignOn: 89 | Description: "Link to the SageMaker notebook instance" 90 | Value: !Sub "https://${NotebookInstance.NotebookInstanceName}.notebook.${AWS::Region}.sagemaker.aws/notebooks/notebooks" 91 | -------------------------------------------------------------------------------- /cloudformation/defect-detection.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0015) - sagemaker-defect-detection: 3 | Solution for training and deploying deep learning models for defect detection in images using Amazon SageMaker. 4 | Version 1" 5 | 6 | Parameters: 7 | SolutionPrefix: 8 | Type: String 9 | Default: "sagemaker-soln-dfd-" 10 | Description: | 11 | Used to name resources created as part of this stack (and inside nested stacks too). 12 | Can be the same as the stack name used by AWS CloudFormation, but this field has extra 13 | constraints because it's used to name resources with restrictions (e.g. Amazon S3 bucket 14 | names cannot contain capital letters). 15 | AllowedPattern: '^sagemaker-soln-dfd[a-z0-9\-]{1,20}$' 16 | ConstraintDescription: | 17 | Only allowed to use lowercase letters, hyphens and/or numbers. 18 | Should also start with 'sagemaker-soln-dfd-' for permission management. 19 | 20 | SolutionName: 21 | Description: | 22 | Prefix for the solution name. Needs to be 'sagemaker-defect-detection' 23 | or begin with 'sagemaker-defect-detection-' followed by a set of letters and hyphens. 24 | Used to specify a particular directory on S3, that can correspond to a development branch. 25 | Type: String 26 | Default: "sagemaker-defect-detection/1.1.0" 27 | AllowedPattern: '^sagemaker-defect-detection(/[0-9]+\.[0-9]+\.[0-9]+-?[a-zA-Z-0-9\.+]*)?$' 28 | 29 | IamRole: 30 | Type: String 31 | Default: "" 32 | Description: | 33 | IAM Role that will be attached to the resources created by this CloudFormation to grant them permissions to 34 | perform their required functions. This role should allow SageMaker and Lambda perform the required actions like 35 | creating training jobs and processing jobs. If left blank, the template will attempt to create a role for you. 36 | This can cause a stack creation error if you don't have privileges to create new roles. 37 | 38 | SageMakerNotebookInstanceType: 39 | Description: SageMaker notebook instance type. 40 | Type: String 41 | Default: "ml.t3.medium" 42 | 43 | CreateSageMakerNotebookInstance: 44 | Description: Whether to launch sagemaker notebook instance 45 | Type: String 46 | AllowedValues: 47 | - "true" 48 | - "false" 49 | Default: "true" 50 | 51 | StackVersion: 52 | Description: | 53 | CloudFormation Stack version. 
54 | Use "release" version unless you are customizing the 55 | CloudFormation templates and the solution artifacts in S3 bucket 56 | Type: String 57 | Default: release 58 | AllowedValues: 59 | - release 60 | - development 61 | 62 | Metadata: 63 | AWS::CloudFormation::Interface: 64 | ParameterGroups: 65 | - Label: 66 | default: "Solution Configuration" 67 | Parameters: 68 | - SolutionPrefix 69 | - SolutionName 70 | - IamRole 71 | - Label: 72 | default: "Advanced Configuration" 73 | Parameters: 74 | - SageMakerNotebookInstanceType 75 | - CreateSageMakerNotebookInstance 76 | 77 | ParameterLabels: 78 | SolutionPrefix: 79 | default: "Solution Resources Name Prefix" 80 | SolutionName: 81 | default: "Name of the solution" 82 | IamRole: 83 | default: "Solution IAM Role Arn" 84 | CreateSageMakerNotebookInstance: 85 | default: "Launch SageMaker Notebook Instance" 86 | SageMakerNotebookInstanceType: 87 | default: "SageMaker Notebook Instance Type" 88 | StackVersion: 89 | default: "Solution Stack Version" 90 | 91 | Conditions: 92 | CreateClassicSageMakerResources: 93 | !Equals [!Ref CreateSageMakerNotebookInstance, "true"] 94 | CreateCustomSolutionRole: !Equals [!Ref IamRole, ""] 95 | 96 | Mappings: 97 | S3: 98 | release: 99 | BucketPrefix: "sagemaker-solutions-prod" 100 | development: 101 | BucketPrefix: "sagemaker-solutions-devo" 102 | 103 | Resources: 104 | S3Bucket: 105 | Type: AWS::S3::Bucket 106 | DeletionPolicy: Retain 107 | UpdateReplacePolicy: "Retain" 108 | Properties: 109 | BucketName: !Sub "${SolutionPrefix}-${AWS::AccountId}-${AWS::Region}" 110 | PublicAccessBlockConfiguration: 111 | BlockPublicAcls: true 112 | BlockPublicPolicy: true 113 | IgnorePublicAcls: true 114 | RestrictPublicBuckets: true 115 | BucketEncryption: 116 | ServerSideEncryptionConfiguration: 117 | - ServerSideEncryptionByDefault: 118 | SSEAlgorithm: AES256 119 | 120 | Metadata: 121 | cfn_nag: 122 | rules_to_suppress: 123 | - id: W35 124 | reason: Configuring logging requires supplying an existing customer S3 bucket to store logs. 125 | - id: W51 126 | reason: Default access policy is sufficient. 
127 | 128 | SageMakerPermissionsStack: 129 | Type: "AWS::CloudFormation::Stack" 130 | Condition: CreateCustomSolutionRole 131 | Properties: 132 | TemplateURL: !Sub 133 | - "https://${SolutionRefBucketBase}-${AWS::Region}.s3.${AWS::Region}.amazonaws.com/${SolutionName}/cloudformation/defect-detection-permissions.yaml" 134 | - SolutionRefBucketBase: 135 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 136 | Parameters: 137 | SolutionPrefix: !Ref SolutionPrefix 138 | SolutionName: !Ref SolutionName 139 | S3Bucket: !Ref S3Bucket 140 | StackVersion: !Ref StackVersion 141 | 142 | SageMakerStack: 143 | Type: "AWS::CloudFormation::Stack" 144 | Condition: CreateClassicSageMakerResources 145 | Properties: 146 | TemplateURL: !Sub 147 | - "https://${SolutionRefBucketBase}-${AWS::Region}.s3.${AWS::Region}.amazonaws.com/${SolutionName}/cloudformation/defect-detection-sagemaker-notebook-instance.yaml" 148 | - SolutionRefBucketBase: 149 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 150 | Parameters: 151 | SolutionPrefix: !Ref SolutionPrefix 152 | SolutionName: !Ref SolutionName 153 | S3Bucket: !Ref S3Bucket 154 | SageMakerIAMRoleArn: 155 | !If [ 156 | CreateCustomSolutionRole, 157 | !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, 158 | !Ref IamRole, 159 | ] 160 | SageMakerNotebookInstanceType: !Ref SageMakerNotebookInstanceType 161 | StackVersion: !Ref StackVersion 162 | 163 | SolutionAssistantStack: 164 | Type: "AWS::CloudFormation::Stack" 165 | Properties: 166 | TemplateURL: !Sub 167 | - "https://${SolutionRefBucketBase}-${AWS::Region}.s3.${AWS::Region}.amazonaws.com/${SolutionName}/cloudformation/solution-assistant/solution-assistant.yaml" 168 | - SolutionRefBucketBase: 169 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 170 | Parameters: 171 | SolutionPrefix: !Ref SolutionPrefix 172 | SolutionName: !Ref SolutionName 173 | StackName: !Ref AWS::StackName 174 | S3Bucket: !Ref S3Bucket 175 | SolutionS3Bucket: !Sub 176 | - "${SolutionRefBucketBase}-${AWS::Region}" 177 | - SolutionRefBucketBase: 178 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 179 | RoleArn: 180 | !If [ 181 | CreateCustomSolutionRole, 182 | !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, 183 | !Ref IamRole, 184 | ] 185 | 186 | Outputs: 187 | SolutionName: 188 | Description: "The name for the solution, can be used to deploy different versions of the solution" 189 | Value: !Ref SolutionName 190 | 191 | SourceCode: 192 | Condition: CreateClassicSageMakerResources 193 | Description: "Open Jupyter IDE. 
This authenticate you against Jupyter" 194 | Value: !GetAtt SageMakerStack.Outputs.SourceCode 195 | 196 | NotebookInstance: 197 | Description: "SageMaker Notebook instance to manually orchestrate data preprocessing, model training and deploying an endpoint" 198 | Value: 199 | !If [ 200 | CreateClassicSageMakerResources, 201 | !GetAtt SageMakerStack.Outputs.SageMakerNotebookInstanceSignOn, 202 | "", 203 | ] 204 | 205 | AccountID: 206 | Description: "AWS Account ID to be passed downstream to the notebook instance" 207 | Value: !Ref AWS::AccountId 208 | 209 | AWSRegion: 210 | Description: "AWS Region to be passed downstream to the notebook instance" 211 | Value: !Ref AWS::Region 212 | 213 | IamRole: 214 | Description: "Arn of SageMaker Execution Role" 215 | Value: 216 | !If [ 217 | CreateCustomSolutionRole, 218 | !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, 219 | !Ref IamRole, 220 | ] 221 | 222 | SolutionPrefix: 223 | Description: "Solution Prefix for naming SageMaker transient resources" 224 | Value: !Ref SolutionPrefix 225 | 226 | S3Bucket: 227 | Description: "S3 bucket name used in the solution to store artifacts" 228 | Value: !Ref S3Bucket 229 | 230 | SolutionS3Bucket: 231 | Description: "Solution S3 bucket" 232 | Value: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 233 | 234 | SageMakerMode: 235 | Value: !If [CreateClassicSageMakerResources, "NotebookInstance", "Studio"] 236 | 237 | StackName: 238 | Value: !Ref AWS::StackName 239 | -------------------------------------------------------------------------------- /cloudformation/solution-assistant/requirements.in: -------------------------------------------------------------------------------- 1 | crhelper 2 | -------------------------------------------------------------------------------- /cloudformation/solution-assistant/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | crhelper==2.0.6 # via -r requirements.in 8 | -------------------------------------------------------------------------------- /cloudformation/solution-assistant/solution-assistant.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0015) - sagemaker-defect-detection solution assistant stack" 3 | 4 | Parameters: 5 | SolutionPrefix: 6 | Type: String 7 | SolutionName: 8 | Type: String 9 | StackName: 10 | Type: String 11 | S3Bucket: 12 | Type: String 13 | SolutionS3Bucket: 14 | Type: String 15 | RoleArn: 16 | Type: String 17 | 18 | Mappings: 19 | Function: 20 | SolutionAssistant: 21 | S3Key: "build/solution-assistant.zip" 22 | 23 | Resources: 24 | SolutionAssistant: 25 | Type: "Custom::SolutionAssistant" 26 | Properties: 27 | SolutionPrefix: !Ref SolutionPrefix 28 | ServiceToken: !GetAtt SolutionAssistantLambda.Arn 29 | S3Bucket: !Ref S3Bucket 30 | StackName: !Ref StackName 31 | 32 | SolutionAssistantLambda: 33 | Type: AWS::Lambda::Function 34 | Properties: 35 | Handler: "lambda_fn.handler" 36 | FunctionName: !Sub "${SolutionPrefix}-solution-assistant" 37 | Role: !Ref RoleArn 38 | Runtime: "python3.8" 39 | Code: 40 | S3Bucket: !Ref SolutionS3Bucket 41 | S3Key: !Sub 42 | - "${SolutionName}/${LambdaS3Key}" 43 | - LambdaS3Key: !FindInMap [Function, SolutionAssistant, S3Key] 44 | Timeout: 60 45 | Metadata: 46 | cfn_nag: 47 | rules_to_suppress: 48 | - id: W58 49 | reason: >- 50 | The required 
permissions are provided in the permissions stack. 51 | -------------------------------------------------------------------------------- /cloudformation/solution-assistant/src/lambda_fn.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import sys 3 | 4 | sys.path.append("./site-packages") 5 | from crhelper import CfnResource 6 | 7 | helper = CfnResource() 8 | 9 | 10 | @helper.create 11 | def on_create(_, __): 12 | pass 13 | 14 | 15 | @helper.update 16 | def on_update(_, __): 17 | pass 18 | 19 | 20 | def delete_sagemaker_endpoint(endpoint_name): 21 | sagemaker_client = boto3.client("sagemaker") 22 | try: 23 | sagemaker_client.delete_endpoint(EndpointName=endpoint_name) 24 | print("Successfully deleted endpoint " "called '{}'.".format(endpoint_name)) 25 | except sagemaker_client.exceptions.ClientError as e: 26 | if "Could not find endpoint" in str(e): 27 | print("Could not find endpoint called '{}'. " "Skipping delete.".format(endpoint_name)) 28 | else: 29 | raise e 30 | 31 | 32 | def delete_sagemaker_endpoint_config(endpoint_config_name): 33 | sagemaker_client = boto3.client("sagemaker") 34 | try: 35 | sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) 36 | print("Successfully deleted endpoint configuration " "called '{}'.".format(endpoint_config_name)) 37 | except sagemaker_client.exceptions.ClientError as e: 38 | if "Could not find endpoint configuration" in str(e): 39 | print( 40 | "Could not find endpoint configuration called '{}'. " "Skipping delete.".format(endpoint_config_name) 41 | ) 42 | else: 43 | raise e 44 | 45 | 46 | def delete_sagemaker_model(model_name): 47 | sagemaker_client = boto3.client("sagemaker") 48 | try: 49 | sagemaker_client.delete_model(ModelName=model_name) 50 | print("Successfully deleted model called '{}'.".format(model_name)) 51 | except sagemaker_client.exceptions.ClientError as e: 52 | if "Could not find model" in str(e): 53 | print("Could not find model called '{}'. " "Skipping delete.".format(model_name)) 54 | else: 55 | raise e 56 | 57 | 58 | def delete_s3_objects(bucket_name): 59 | s3_resource = boto3.resource("s3") 60 | try: 61 | s3_resource.Bucket(bucket_name).objects.all().delete() 62 | print("Successfully deleted objects in bucket " "called '{}'.".format(bucket_name)) 63 | except s3_resource.meta.client.exceptions.NoSuchBucket: 64 | print("Could not find bucket called '{}'. " "Skipping delete.".format(bucket_name)) 65 | 66 | 67 | def delete_s3_bucket(bucket_name): 68 | s3_resource = boto3.resource("s3") 69 | try: 70 | s3_resource.Bucket(bucket_name).delete() 71 | print("Successfully deleted bucket " "called '{}'.".format(bucket_name)) 72 | except s3_resource.meta.client.exceptions.NoSuchBucket: 73 | print("Could not find bucket called '{}'. 
" "Skipping delete.".format(bucket_name)) 74 | 75 | 76 | @helper.delete 77 | def on_delete(event, __): 78 | # remove sagemaker endpoints 79 | solution_prefix = event["ResourceProperties"]["SolutionPrefix"] 80 | endpoint_names = [ 81 | "{}-demo-endpoint".format(solution_prefix), # make sure it is the same as your endpoint name 82 | "{}-demo-model".format(solution_prefix), 83 | "{}-finetuned-endpoint".format(solution_prefix), 84 | "{}-detector-from-scratch-endpoint".format(solution_prefix), 85 | "{}-classification-endpoint".format(solution_prefix), 86 | ] 87 | for endpoint_name in endpoint_names: 88 | delete_sagemaker_model(endpoint_name) 89 | delete_sagemaker_endpoint_config(endpoint_name) 90 | delete_sagemaker_endpoint(endpoint_name) 91 | 92 | # remove files in s3 93 | output_bucket = event["ResourceProperties"]["S3Bucket"] 94 | delete_s3_objects(output_bucket) 95 | 96 | # delete buckets 97 | delete_s3_bucket(output_bucket) 98 | 99 | 100 | def handler(event, context): 101 | helper(event, context) 102 | -------------------------------------------------------------------------------- /docs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/arch.png -------------------------------------------------------------------------------- /docs/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/data.png -------------------------------------------------------------------------------- /docs/data_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/data_flow.png -------------------------------------------------------------------------------- /docs/launch.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Click to Create Stack 5 | 6 | 7 | -------------------------------------------------------------------------------- /docs/numerical.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/numerical.png -------------------------------------------------------------------------------- /docs/sagemaker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/sagemaker.png -------------------------------------------------------------------------------- /docs/sample1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/sample1.png -------------------------------------------------------------------------------- /docs/sample2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/sample2.png -------------------------------------------------------------------------------- /docs/sample3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/sample3.png -------------------------------------------------------------------------------- /docs/train_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/docs/train_arch.png -------------------------------------------------------------------------------- /manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.1.1" 3 | } 4 | -------------------------------------------------------------------------------- /notebooks/0_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Jupyter Kernel**:\n", 8 | "\n", 9 | "* If you are in SageMaker Notebook instance, please make sure you are using **conda_pytorch_latest_p36** kernel\n", 10 | "* If you are on SageMaker Studio, please make sure you are using **SageMaker JumpStart PyTorch 1.0** kernel\n", 11 | "\n", 12 | "**Run All**:\n", 13 | "\n", 14 | "* If you are in SageMaker notebook instance, you can go to *Cell tab -> Run All*\n", 15 | "* If you are in SageMaker Studio, you can go to *Run tab -> Run All Cells*\n", 16 | "\n", 17 | "**Note**: To *Run All* successfully, make sure you have executed the entire demo notebook `0_demo.ipynb` first.\n", 18 | "\n", 19 | "## SageMaker Defect Detection Demo\n", 20 | "\n", 21 | "In this notebook, we deploy an endpoint from a provided pretrained detection model that was already trained on **NEU-DET** dataset. 
Then, we send some image samples with defects for detection and visualize the results" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import json\n", 31 | "\n", 32 | "import numpy as np\n", 33 | "\n", 34 | "import sagemaker\n", 35 | "from sagemaker.s3 import S3Downloader\n", 36 | "\n", 37 | "sagemaker_session = sagemaker.Session()\n", 38 | "sagemaker_config = json.load(open(\"../stack_outputs.json\"))\n", 39 | "role = sagemaker_config[\"IamRole\"]\n", 40 | "solution_bucket = sagemaker_config[\"SolutionS3Bucket\"]\n", 41 | "region = sagemaker_config[\"AWSRegion\"]\n", 42 | "solution_name = sagemaker_config[\"SolutionName\"]\n", 43 | "bucket = sagemaker_config[\"S3Bucket\"]" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "First, we download our **NEU-DET** dataset from our public S3 bucket" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "original_bucket = f\"s3://{solution_bucket}-{region}/{solution_name}\"\n", 60 | "original_data_prefix = \"data/NEU-DET.zip\"\n", 61 | "original_data = f\"{original_bucket}/{original_data_prefix}\"\n", 62 | "original_pretained_checkpoint = f\"{original_bucket}/pretrained\"\n", 63 | "original_sources = f\"{original_bucket}/build/lib/source_dir.tar.gz\"\n", 64 | "print(\"original data: \")\n", 65 | "S3Downloader.list(original_data)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "For easier data processing, depending on the dataset, we unify the class and label names using the scripts from `prepare_data`, which should take less than **5 minutes** to complete.
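To give a rough idea of what this unification involves (an illustrative sketch only; the authoritative logic lives in `../src/prepare_data/neu.py`), the raw NEU class prefixes are mapped to consistent label names shared by the detection and classification notebooks:

```python
# Illustrative sketch only -- see ../src/prepare_data/neu.py for the actual preparation logic.
# The raw-prefix keys below are assumptions for illustration, not taken verbatim from the scripts.
LABEL_MAP = {
    "Cr": "crazing",
    "In": "inclusion",
    "Pa": "patches",
    "PS": "pitted_surface",
    "RS": "rolled-in_scale",
    "Sc": "scratches",
}


def unify_label(raw_name: str) -> str:
    """Map a raw NEU class prefix to a canonical label name (hypothetical helper)."""
    return LABEL_MAP.get(raw_name, raw_name.lower())
```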
This is done once throughout all our notebooks" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "%%time\n", 82 | "\n", 83 | "RAW_DATA_PATH = !echo $PWD/raw_neu_det\n", 84 | "RAW_DATA_PATH = RAW_DATA_PATH.n\n", 85 | "DATA_PATH = !echo $PWD/neu_det\n", 86 | "DATA_PATH = DATA_PATH.n\n", 87 | "\n", 88 | "!mkdir -p $RAW_DATA_PATH\n", 89 | "!aws s3 cp $original_data $RAW_DATA_PATH\n", 90 | "\n", 91 | "!mkdir -p $DATA_PATH\n", 92 | "!python ../src/prepare_data/neu.py $RAW_DATA_PATH/NEU-DET.zip $DATA_PATH" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "After data preparation, we need upload the prepare data to S3 and setup some paths that will be used throughtout the notebooks" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "%%time\n", 109 | "prefix = \"neu-det\"\n", 110 | "neu_det_s3 = f\"s3://{bucket}/{prefix}\"\n", 111 | "sources = f\"{neu_det_s3}/code/\"\n", 112 | "train_output = f\"{neu_det_s3}/output/\"\n", 113 | "neu_det_prepared_s3 = f\"{neu_det_s3}/data/\"\n", 114 | "!aws s3 sync $DATA_PATH $neu_det_prepared_s3 --quiet # remove the --quiet flag to view the sync logs\n", 115 | "s3_checkpoint = f\"{neu_det_s3}/checkpoint/\"\n", 116 | "sm_local_checkpoint_dir = \"/opt/ml/checkpoints/\"\n", 117 | "s3_pretrained = f\"{neu_det_s3}/pretrained/\"\n", 118 | "!aws s3 sync $original_pretained_checkpoint $s3_pretrained\n", 119 | "!aws s3 ls $s3_pretrained\n", 120 | "!aws s3 cp $original_sources $sources" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## Visualization\n", 128 | "\n", 129 | "Let examine some datasets that we will use later by providing an `ID`" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "import copy\n", 139 | "\n", 140 | "import numpy as np\n", 141 | "import torch\n", 142 | "from PIL import Image\n", 143 | "from torch.utils.data import DataLoader\n", 144 | "\n", 145 | "try:\n", 146 | " import sagemaker_defect_detection\n", 147 | "except ImportError:\n", 148 | " import sys\n", 149 | " from pathlib import Path\n", 150 | "\n", 151 | " ROOT = Path(\"../src\").resolve()\n", 152 | " sys.path.insert(0, str(ROOT))\n", 153 | "\n", 154 | "from sagemaker_defect_detection import NEUDET, get_preprocess\n", 155 | "\n", 156 | "SPLIT = \"test\"\n", 157 | "ID = 10\n", 158 | "assert 0 <= ID <= 300\n", 159 | "dataset = NEUDET(DATA_PATH, split=SPLIT, preprocess=get_preprocess())\n", 160 | "images, targets, _ = dataset[ID]\n", 161 | "original_image = copy.deepcopy(images)\n", 162 | "original_boxes = targets[\"boxes\"].numpy().copy()\n", 163 | "original_labels = targets[\"labels\"].numpy().copy()\n", 164 | "print(f\"first images size: {original_image.shape}\")\n", 165 | "print(f\"target bounding boxes: \\n {original_boxes}\")\n", 166 | "print(f\"target labels: {original_labels}\")" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "And we can now visualize it using the provided utilities as follows" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "from sagemaker_defect_detection.utils.visualize import unnormalize_to_hwc, visualize\n", 183 | "\n", 184 | 
"original_image_unnorm = unnormalize_to_hwc(original_image)\n", 185 | "\n", 186 | "visualize(\n", 187 | " original_image_unnorm,\n", 188 | " [original_boxes],\n", 189 | " [original_labels],\n", 190 | " colors=[(255, 0, 0)],\n", 191 | " titles=[\"original\", \"ground truth\"],\n", 192 | ")" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "For our demo, we deploy an endpoint using a provided pretrained checkpoint. It takes about **10 minutes** to finish" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "%%time\n", 209 | "from os import path as osp\n", 210 | "\n", 211 | "from sagemaker.pytorch import PyTorchModel\n", 212 | "\n", 213 | "demo_model = PyTorchModel(\n", 214 | " osp.join(s3_pretrained, \"model.tar.gz\"),\n", 215 | " role,\n", 216 | " entry_point=\"detector.py\",\n", 217 | " source_dir=osp.join(sources, \"source_dir.tar.gz\"),\n", 218 | " framework_version=\"1.5\",\n", 219 | " py_version=\"py3\",\n", 220 | " sagemaker_session=sagemaker_session,\n", 221 | " name=sagemaker_config[\"SolutionPrefix\"] + \"-demo-model\",\n", 222 | ")\n", 223 | "\n", 224 | "demo_detector = demo_model.deploy(\n", 225 | " initial_instance_count=1,\n", 226 | " instance_type=\"ml.m5.xlarge\",\n", 227 | " endpoint_name=sagemaker_config[\"SolutionPrefix\"] + \"-demo-endpoint\",\n", 228 | ")" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "## Inference\n", 236 | "\n", 237 | "We change the input depending on whether we are providing a list of images or a single image. Also the model requires a four dimensional array / tensor (with the first dimension as batch)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "input = list(img.numpy() for img in images) if isinstance(images, list) else images.unsqueeze(0).numpy()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "Now the input is ready and we can get some results" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "%%time\n", 263 | "# SageMaker 1.x doesn't allow_pickle=True by default\n", 264 | "np_load_old = np.load\n", 265 | "np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)\n", 266 | "demo_predictions = demo_detector.predict(input)\n", 267 | "np.load = np_load_old" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "And finally, we visualize them as follows " 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "tags": [] 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "visualize(\n", 286 | " original_image_unnorm,\n", 287 | " [original_boxes, demo_predictions[0][\"boxes\"]],\n", 288 | " [original_labels, demo_predictions[0][\"labels\"]],\n", 289 | " colors=[(255, 0, 0), (0, 0, 255)],\n", 290 | " titles=[\"original\", \"ground truth\", \"pretrained demo\"],\n", 291 | " dpi=200,\n", 292 | ")" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "metadata": {}, 298 | "source": [ 299 | "## Optional: Delete the endpoint and model\n", 300 | "\n", 301 | "**Note:** to follow all the notebooks, it is required to keep demo model and the demo endpoint. 
They will be automatically deleted when you delete the entire stack of resources. However, if you need to, please uncomment and run the next cell\n", 302 | "\n", 303 | "All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account." 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "# demo_detector.delete_model()\n", 313 | "# demo_detector.delete_endpoint()" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "### [Click here to continue](./1_retrain_from_checkpoint.ipynb)" 321 | ] 322 | } 323 | ], 324 | "metadata": { 325 | "kernelspec": { 326 | "display_name": "Python 3.6.10 64-bit ('pytorch_latest_p36': conda)", 327 | "name": "python361064bitpytorchlatestp36conda2dfac45b320c45f3a4d1e89ca46b60d1" 328 | }, 329 | "language_info": { 330 | "codemirror_mode": { 331 | "name": "ipython", 332 | "version": 3 333 | }, 334 | "file_extension": ".py", 335 | "mimetype": "text/x-python", 336 | "name": "python", 337 | "nbconvert_exporter": "python", 338 | "pygments_lexer": "ipython3", 339 | "version": "3.6.10-final" 340 | } 341 | }, 342 | "nbformat": 4, 343 | "nbformat_minor": 2 344 | } 345 | -------------------------------------------------------------------------------- /notebooks/1_retrain_from_checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Jupyter Kernel**:\n", 8 | "\n", 9 | "* If you are in SageMaker Notebook instance, please make sure you are using **conda_pytorch_latest_p36** kernel\n", 10 | "* If you are on SageMaker Studio, please make sure you are using **SageMaker JumpStart PyTorch 1.0** kernel\n", 11 | "\n", 12 | "**Run All**:\n", 13 | "\n", 14 | "* If you are in SageMaker notebook instance, you can go to *Cell tab -> Run All*\n", 15 | "* If you are in SageMaker Studio, you can go to *Run tab -> Run All Cells*\n", 16 | "\n", 17 | "**Note**: To *Run All* successfully, make sure you have executed the entire demo notebook `0_demo.ipynb` first.\n", 18 | "\n", 19 | "## Resume Training\n", 20 | "\n", 21 | "In this notebook, we retrain our pretrained detector for a few more epochs and compare its results. The same process can be applied when finetuning on another dataset. For the purpose of this notebook, we use the same **NEU-DET** dataset.\n", 22 | "\n", 23 | "## Finetuning\n", 24 | "\n", 25 | "Finetuning is one way to do Transfer Learning. Finetuning a Deep Learning model on one particular task involves using the learned weights from a particular dataset to enhance the performance of the model, usually on another dataset.
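The mechanics behind this are standard PyTorch; as a minimal sketch (using torchvision's generic Faster R-CNN as a stand-in, not the solution's own model trained through `detector.py`), finetuning boils down to reusing learned weights, adapting the prediction head, and continuing training with a small learning rate:

```python
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Stand-in model for illustration only; the solution trains its own detector via detector.py.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Reuse the learned backbone, but size the box predictor for 7 classes (6 defect types + background).
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=7)

# Continue training with a small learning rate for a few epochs.
optimizer = torch.optim.SGD((p for p in model.parameters() if p.requires_grad), lr=1e-4, momentum=0.9)
model.train()
# for images, targets in data_loader:     # data_loader over the (new) dataset -- not defined here
#     loss_dict = model(images, targets)  # in train mode the model returns a dict of losses
#     loss = sum(loss_dict.values())
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
```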
In a sense, finetuning can be done over the same dataset used in the intial training but perhaps with different hyperparameters.\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import json\n", 35 | "\n", 36 | "import sagemaker\n", 37 | "from sagemaker.s3 import S3Downloader\n", 38 | "\n", 39 | "sagemaker_session = sagemaker.Session()\n", 40 | "sagemaker_config = json.load(open(\"../stack_outputs.json\"))\n", 41 | "role = sagemaker_config[\"IamRole\"]\n", 42 | "solution_bucket = sagemaker_config[\"SolutionS3Bucket\"]\n", 43 | "region = sagemaker_config[\"AWSRegion\"]\n", 44 | "solution_name = sagemaker_config[\"SolutionName\"]\n", 45 | "bucket = sagemaker_config[\"S3Bucket\"]" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "First, we download our **NEU-DET** dataset from our public S3 bucket" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "original_bucket = f\"s3://{solution_bucket}-{region}/{solution_name}\"\n", 62 | "original_pretained_checkpoint = f\"{original_bucket}/pretrained\"\n", 63 | "original_sources = f\"{original_bucket}/build/lib/source_dir.tar.gz\"" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "Note that for easiler data processing, we have already executed `prepare_data` once in our `0_demo.ipynb` and have already uploaded the prepared data to our S3 bucket" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "DATA_PATH = !echo $PWD/neu_det\n", 80 | "DATA_PATH = DATA_PATH.n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "After data preparation, we need to setup some paths that will be used throughtout the notebook" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "tags": [] 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "prefix = \"neu-det\"\n", 99 | "neu_det_s3 = f\"s3://{bucket}/{prefix}\"\n", 100 | "sources = f\"{neu_det_s3}/code/\"\n", 101 | "train_output = f\"{neu_det_s3}/output/\"\n", 102 | "neu_det_prepared_s3 = f\"{neu_det_s3}/data/\"\n", 103 | "s3_checkpoint = f\"{neu_det_s3}/checkpoint/\"\n", 104 | "sm_local_checkpoint_dir = \"/opt/ml/checkpoints/\"\n", 105 | "s3_pretrained = f\"{neu_det_s3}/pretrained/\"" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "## Visualization\n", 113 | "\n", 114 | "Let examine some datasets that we will use later by providing an `ID`" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "import copy\n", 124 | "\n", 125 | "import numpy as np\n", 126 | "import torch\n", 127 | "from PIL import Image\n", 128 | "from torch.utils.data import DataLoader\n", 129 | "\n", 130 | "try:\n", 131 | " import sagemaker_defect_detection\n", 132 | "except ImportError:\n", 133 | " import sys\n", 134 | " from pathlib import Path\n", 135 | "\n", 136 | " ROOT = Path(\"../src\").resolve()\n", 137 | " sys.path.insert(0, str(ROOT))\n", 138 | "\n", 139 | "from sagemaker_defect_detection import NEUDET, get_preprocess\n", 140 | "\n", 141 | "SPLIT = \"test\"\n", 142 | "ID = 30\n", 143 | "assert 0 <= ID <= 300\n", 144 | "dataset = 
NEUDET(DATA_PATH, split=SPLIT, preprocess=get_preprocess())\n", 145 | "images, targets, _ = dataset[ID]\n", 146 | "original_image = copy.deepcopy(images)\n", 147 | "original_boxes = targets[\"boxes\"].numpy().copy()\n", 148 | "original_labels = targets[\"labels\"].numpy().copy()\n", 149 | "print(f\"first images size: {original_image.shape}\")\n", 150 | "print(f\"target bounding boxes: \\n {original_boxes}\")\n", 151 | "print(f\"target labels: {original_labels}\")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "And we can now visualize it using the provided utilities as follows" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "from sagemaker_defect_detection.utils.visualize import unnormalize_to_hwc, visualize\n", 168 | "\n", 169 | "original_image_unnorm = unnormalize_to_hwc(original_image)\n", 170 | "\n", 171 | "visualize(\n", 172 | " original_image_unnorm,\n", 173 | " [original_boxes],\n", 174 | " [original_labels],\n", 175 | " colors=[(255, 0, 0)],\n", 176 | " titles=[\"original\", \"ground truth\"],\n", 177 | ")" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Here we resume from a provided pretrained checkpoint `epoch=294-loss=0.654-main_score=0.349.ckpt` that we have copied into our `s3_pretrained`. This takes about **10 minutes** to complete" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": { 191 | "tags": [ 192 | "outputPrepend" 193 | ] 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "%%time\n", 198 | "import logging\n", 199 | "from os import path as osp\n", 200 | "\n", 201 | "from sagemaker.pytorch import PyTorch\n", 202 | "\n", 203 | "NUM_CLASSES = 7 # 6 classes + 1 for background\n", 204 | "# Note: resnet34 was used in the pretrained model and it has to match the pretrained model backbone\n", 205 | "# if need resnet50, need to train from scratch\n", 206 | "BACKBONE = \"resnet34\"\n", 207 | "assert BACKBONE in [\n", 208 | " \"resnet34\",\n", 209 | " \"resnet50\",\n", 210 | "], \"either resnet34 or resnet50. 
Make sure to be consistent with model_fn in detector.py\"\n", 211 | "EPOCHS = 5\n", 212 | "LEARNING_RATE = 1e-4\n", 213 | "SEED = 123\n", 214 | "\n", 215 | "hyperparameters = {\n", 216 | " \"backbone\": BACKBONE, # the backbone resnet model for feature extraction\n", 217 | " \"num-classes\": NUM_CLASSES, # number of classes + background\n", 218 | " \"epochs\": EPOCHS, # number of epochs to finetune\n", 219 | " \"learning-rate\": LEARNING_RATE, # learning rate for optimizer\n", 220 | " \"seed\": SEED, # random number generator seed\n", 221 | "}\n", 222 | "\n", 223 | "assert not isinstance(sagemaker_session, sagemaker.LocalSession), \"local session as share memory cannot be altered\"\n", 224 | "\n", 225 | "finetuned_model = PyTorch(\n", 226 | " entry_point=\"detector.py\",\n", 227 | " source_dir=osp.join(sources, \"source_dir.tar.gz\"),\n", 228 | " role=role,\n", 229 | " train_instance_count=1,\n", 230 | " train_instance_type=\"ml.g4dn.2xlarge\",\n", 231 | " hyperparameters=hyperparameters,\n", 232 | " py_version=\"py3\",\n", 233 | " framework_version=\"1.5\",\n", 234 | " sagemaker_session=sagemaker_session,\n", 235 | " output_path=train_output,\n", 236 | " checkpoint_s3_uri=s3_checkpoint,\n", 237 | " checkpoint_local_path=sm_local_checkpoint_dir,\n", 238 | " # container_log_level=logging.DEBUG,\n", 239 | ")\n", 240 | "\n", 241 | "finetuned_model.fit(\n", 242 | " {\n", 243 | " \"training\": neu_det_prepared_s3,\n", 244 | " \"pretrained_checkpoint\": osp.join(s3_pretrained, \"epoch=294-loss=0.654-main_score=0.349.ckpt\"),\n", 245 | " }\n", 246 | ")" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "Then, we deploy our new model which takes about **10 minutes** to complete" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "%%time\n", 263 | "finetuned_detector = finetuned_model.deploy(\n", 264 | " initial_instance_count=1,\n", 265 | " instance_type=\"ml.m5.xlarge\",\n", 266 | " endpoint_name=sagemaker_config[\"SolutionPrefix\"] + \"-finetuned-endpoint\",\n", 267 | ")" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "## Inference\n", 275 | "\n", 276 | "We change the input depending on whether we are providing a list of images or a single image. 
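As a quick illustration of the expected shapes (a standalone sketch, not a notebook cell):

```python
import torch

single_image = torch.rand(3, 200, 200)                    # one (C, H, W) image tensor
batched = single_image.unsqueeze(0).numpy()               # add the batch dimension -> (1, C, H, W)
print(batched.shape)                                      # (1, 3, 200, 200)

image_list = [torch.rand(3, 200, 200) for _ in range(2)]  # a list of images instead of one
as_arrays = [img.numpy() for img in image_list]           # converted element-wise to (C, H, W) arrays
print(len(as_arrays), as_arrays[0].shape)                 # 2 (3, 200, 200)
```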
Also the model requires a four dimensional array / tensor (with the first dimension as batch)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "input = list(img.numpy() for img in images) if isinstance(images, list) else images.unsqueeze(0).numpy()" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "Now the input is ready and we can get some results" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "%%time\n", 302 | "# SageMaker 1.x doesn't allow_pickle=True by default\n", 303 | "np_load_old = np.load\n", 304 | "np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)\n", 305 | "finetuned_predictions = finetuned_detector.predict(input)\n", 306 | "np.load = np_load_old" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "Here we want to compare the results of the new model and the pretrained model that we already deployed in `0_demo.ipynb` visually by calling our endpoint from SageMaker runtime using `boto3`" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "import boto3\n", 323 | "import botocore\n", 324 | "\n", 325 | "config = botocore.config.Config(read_timeout=200)\n", 326 | "runtime = boto3.client(\"runtime.sagemaker\", config=config)\n", 327 | "payload = json.dumps(input.tolist() if isinstance(input, np.ndarray) else input)\n", 328 | "response = runtime.invoke_endpoint(\n", 329 | " EndpointName=sagemaker_config[\"SolutionPrefix\"] + \"-demo-endpoint\", ContentType=\"application/json\", Body=payload\n", 330 | ")\n", 331 | "demo_predictions = json.loads(response[\"Body\"].read().decode())" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "Here comes the slight changes in inference" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "visualize(\n", 348 | " original_image_unnorm,\n", 349 | " [original_boxes, demo_predictions[0][\"boxes\"], finetuned_predictions[0][\"boxes\"]],\n", 350 | " [original_labels, demo_predictions[0][\"labels\"], finetuned_predictions[0][\"labels\"]],\n", 351 | " colors=[(255, 0, 0), (0, 0, 255), (127, 0, 127)],\n", 352 | " titles=[\"original\", \"ground truth\", \"pretrained\", \"finetuned\"],\n", 353 | " dpi=250,\n", 354 | ")" 355 | ] 356 | }, 357 | { 358 | "cell_type": "markdown", 359 | "metadata": {}, 360 | "source": [ 361 | "## Optional: Delete the endpoint and model\n", 362 | "\n", 363 | "When you are done with the endpoint, you should clean it up.\n", 364 | "\n", 365 | "All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account.\n" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "finetuned_detector.delete_model()\n", 375 | "finetuned_detector.delete_endpoint()" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": {}, 381 | "source": [ 382 | "### [Click here to continue](./2_detection_from_scratch.ipynb)" 383 | ] 384 | } 385 | ], 386 | "metadata": { 387 | "kernelspec": { 388 | "display_name": "Python 
3.6.10 64-bit ('pytorch_latest_p36': conda)", 389 | "name": "python361064bitpytorchlatestp36conda2dfac45b320c45f3a4d1e89ca46b60d1" 390 | }, 391 | "language_info": { 392 | "codemirror_mode": { 393 | "name": "ipython", 394 | "version": 3 395 | }, 396 | "file_extension": ".py", 397 | "mimetype": "text/x-python", 398 | "name": "python", 399 | "nbconvert_exporter": "python", 400 | "pygments_lexer": "ipython3", 401 | "version": "3.6.10-final" 402 | } 403 | }, 404 | "nbformat": 4, 405 | "nbformat_minor": 2 406 | } 407 | -------------------------------------------------------------------------------- /notebooks/2_detection_from_scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Jupyter Kernel**:\n", 8 | "\n", 9 | "* If you are in SageMaker Notebook instance, please make sure you are using **conda_pytorch_latest_p36** kernel\n", 10 | "* If you are on SageMaker Studio, please make sure you are using **SageMaker JumpStart PyTorch 1.0** kernel\n", 11 | "\n", 12 | "**Run All**:\n", 13 | "\n", 14 | "* If you are in SageMaker notebook instance, you can go to *Cell tab -> Run All*\n", 15 | "* If you are in SageMaker Studio, you can go to *Run tab -> Run All Cells*\n", 16 | "\n", 17 | "**Note**: To *Run All* successfully, make sure you have executed the entire demo notebook `0_demo.ipynb` first.\n", 18 | "\n", 19 | "## Training our Detector from Scratch\n", 20 | "\n", 21 | "In this notebook, we will see how to train our detector from scratch" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import json\n", 31 | "\n", 32 | "import sagemaker\n", 33 | "from sagemaker.s3 import S3Downloader\n", 34 | "\n", 35 | "sagemaker_session = sagemaker.Session()\n", 36 | "sagemaker_config = json.load(open(\"../stack_outputs.json\"))\n", 37 | "role = sagemaker_config[\"IamRole\"]\n", 38 | "solution_bucket = sagemaker_config[\"SolutionS3Bucket\"]\n", 39 | "region = sagemaker_config[\"AWSRegion\"]\n", 40 | "solution_name = sagemaker_config[\"SolutionName\"]\n", 41 | "bucket = sagemaker_config[\"S3Bucket\"]" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "First, we download our **NEU-DET** dataset from our public S3 bucket" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "original_bucket = f\"s3://{solution_bucket}-{region}/{solution_name}\"\n", 58 | "original_pretained_checkpoint = f\"{original_bucket}/pretrained\"\n", 59 | "original_sources = f\"{original_bucket}/build/lib/source_dir.tar.gz\"" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "Note that for easiler data processing, we have already executed `prepare_data` once in our `0_demo.ipynb` and have already uploaded the prepared data to our S3 bucket" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": { 73 | "tags": [] 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "DATA_PATH = !echo $PWD/neu_det\n", 78 | "DATA_PATH = DATA_PATH.n" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "After data preparation, we need to setup some paths that will be used throughtout the notebook" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | 
"metadata": { 92 | "tags": [] 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "prefix = \"neu-det\"\n", 97 | "neu_det_s3 = f\"s3://{bucket}/{prefix}\"\n", 98 | "sources = f\"{neu_det_s3}/code/\"\n", 99 | "train_output = f\"{neu_det_s3}/output/\"\n", 100 | "neu_det_prepared_s3 = f\"{neu_det_s3}/data/\"\n", 101 | "s3_checkpoint = f\"{neu_det_s3}/checkpoint/\"\n", 102 | "sm_local_checkpoint_dir = \"/opt/ml/checkpoints/\"\n", 103 | "s3_pretrained = f\"{neu_det_s3}/pretrained/\"" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "## Visualization\n", 111 | "\n", 112 | "Let examine some datasets that we will use later by providing an `ID`" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": { 119 | "tags": [] 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "import copy\n", 124 | "from typing import List\n", 125 | "\n", 126 | "import numpy as np\n", 127 | "import torch\n", 128 | "from PIL import Image\n", 129 | "from torch.utils.data import DataLoader\n", 130 | "\n", 131 | "try:\n", 132 | " import sagemaker_defect_detection\n", 133 | "except ImportError:\n", 134 | " import sys\n", 135 | " from pathlib import Path\n", 136 | "\n", 137 | " ROOT = Path(\"../src\").resolve()\n", 138 | " sys.path.insert(0, str(ROOT))\n", 139 | "\n", 140 | "from sagemaker_defect_detection import NEUDET, get_preprocess\n", 141 | "\n", 142 | "SPLIT = \"test\"\n", 143 | "ID = 50\n", 144 | "dataset = NEUDET(DATA_PATH, split=SPLIT, preprocess=get_preprocess())\n", 145 | "images, targets, _ = dataset[ID]\n", 146 | "original_image = copy.deepcopy(images)\n", 147 | "original_boxes = targets[\"boxes\"].numpy().copy()\n", 148 | "original_labels = targets[\"labels\"].numpy().copy()\n", 149 | "print(f\"first images size: {original_image.shape}\")\n", 150 | "print(f\"target bounding boxes: \\n {original_boxes}\")\n", 151 | "print(f\"target labels: {original_labels}\")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "And we can now visualize it using the provided utilities as follows" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "from sagemaker_defect_detection.utils.visualize import unnormalize_to_hwc, visualize\n", 168 | "\n", 169 | "original_image_unnorm = unnormalize_to_hwc(original_image)\n", 170 | "\n", 171 | "visualize(\n", 172 | " original_image_unnorm,\n", 173 | " [original_boxes],\n", 174 | " [original_labels],\n", 175 | " colors=[(255, 0, 0)],\n", 176 | " titles=[\"original\", \"ground truth\"],\n", 177 | ")" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "In order to get high **Mean Average Percision (mAP)** and **Mean Average Recall (mAR)** for **Intersection Over Union (IOU)** thresholds of 0.5 when training from scratch, it requires more than **300 epochs**. That is why we have provided a pretrained model and recommend finetuning whenever is possible. 
For demostration, we train the model from scratch for 10 epochs which takes about **16 minutes** and it results in the following mAP, mAR and the accumulated `main_score` of\n", 185 | "\n", 186 | "* `Average Precision (AP) @[ IoU=0.50:0.95 ] ~ 0.048`\n", 187 | "* `Average Recall (AR) @[ IoU=0.50:0.95] ~ 0.153`\n", 188 | "* `main_score=0.0509`\n", 189 | "\n", 190 | "To get higher mAP, mAR and overall `main_score`, you can train for more epochs" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "tags": [ 198 | "outputPrepend" 199 | ] 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "%%time\n", 204 | "import logging\n", 205 | "from os import path as osp\n", 206 | "\n", 207 | "from sagemaker.pytorch import PyTorch\n", 208 | "\n", 209 | "NUM_CLASSES = 7 # 6 classes + 1 for background\n", 210 | "BACKBONE = \"resnet34\" # has to match the pretrained model backbone\n", 211 | "assert BACKBONE in [\n", 212 | " \"resnet34\",\n", 213 | " \"resnet50\",\n", 214 | "], \"either resnet34 or resnet50. Make sure to be consistent with model_fn in detector.py\"\n", 215 | "EPOCHS = 10\n", 216 | "LEARNING_RATE = 1e-3\n", 217 | "SEED = 123\n", 218 | "\n", 219 | "hyperparameters = {\n", 220 | " \"backbone\": BACKBONE, # the backbone resnet model for feature extraction\n", 221 | " \"num-classes\": NUM_CLASSES, # number of classes 6 + 1 background\n", 222 | " \"epochs\": EPOCHS, # number of epochs to train\n", 223 | " \"learning-rate\": LEARNING_RATE, # learning rate for optimizer\n", 224 | " \"seed\": SEED, # random number generator seed\n", 225 | "}\n", 226 | "\n", 227 | "model = PyTorch(\n", 228 | " entry_point=\"detector.py\",\n", 229 | " source_dir=osp.join(sources, \"source_dir.tar.gz\"),\n", 230 | " role=role,\n", 231 | " train_instance_count=1,\n", 232 | " train_instance_type=\"ml.g4dn.2xlarge\",\n", 233 | " hyperparameters=hyperparameters,\n", 234 | " py_version=\"py3\",\n", 235 | " framework_version=\"1.5\",\n", 236 | " sagemaker_session=sagemaker_session,\n", 237 | " output_path=train_output,\n", 238 | " checkpoint_s3_uri=s3_checkpoint,\n", 239 | " checkpoint_local_path=sm_local_checkpoint_dir,\n", 240 | " # container_log_level=logging.DEBUG,\n", 241 | ")\n", 242 | "\n", 243 | "model.fit(neu_det_prepared_s3)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "Then, we deploy our model which takes about **10 minutes** to complete" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": { 257 | "tags": [] 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "%%time\n", 262 | "detector = model.deploy(\n", 263 | " initial_instance_count=1,\n", 264 | " instance_type=\"ml.m5.xlarge\",\n", 265 | " endpoint_name=sagemaker_config[\"SolutionPrefix\"] + \"-detector-from-scratch-endpoint\",\n", 266 | ")" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "## Inference\n", 274 | "\n", 275 | "We change the input depending on whether we are providing a list of images or a single image. 
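A side note on the `np.load` monkeypatching used around the `predict` calls in these notebooks: the same idea can be wrapped in a context manager so the patch cannot leak past the call (an alternative sketch, not what the notebooks ship with):

```python
import contextlib

import numpy as np


@contextlib.contextmanager
def numpy_allow_pickle():
    """Temporarily force allow_pickle=True for np.load, then restore the original."""
    original_load = np.load
    np.load = lambda *args, **kwargs: original_load(*args, allow_pickle=True, **kwargs)
    try:
        yield
    finally:
        np.load = original_load


# Hypothetical usage with the predictor and input defined in the cells below:
# with numpy_allow_pickle():
#     predictions = detector.predict(input)
```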
Also the model requires a four dimensional array / tensor (with the first dimension as batch)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "input = list(img.numpy() for img in images) if isinstance(images, list) else images.unsqueeze(0).numpy()" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "Now the input is ready and we can get some results" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": { 298 | "tags": [] 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "%%time\n", 303 | "# SageMaker 1.x doesn't allow_pickle=True by default\n", 304 | "np_load_old = np.load\n", 305 | "np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)\n", 306 | "predictions = detector.predict(input)\n", 307 | "np.load = np_load_old" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "We use our `visualize` utility to check the detection results" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "visualize(\n", 324 | " original_image_unnorm,\n", 325 | " [original_boxes, predictions[0][\"boxes\"]],\n", 326 | " [original_labels, predictions[0][\"labels\"]],\n", 327 | " colors=[(255, 0, 0), (0, 0, 255)],\n", 328 | " titles=[\"original\", \"ground truth\", \"trained from scratch\"],\n", 329 | " dpi=200,\n", 330 | ")" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "Here we want to compare the results of the new model and the pretrained model that we already deployed in `0_demo.ipynb` visually by calling our endpoint from SageMaker runtime using `boto3`" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "import boto3\n", 347 | "import botocore\n", 348 | "\n", 349 | "config = botocore.config.Config(read_timeout=200)\n", 350 | "runtime = boto3.client(\"runtime.sagemaker\", config=config)\n", 351 | "payload = json.dumps(input.tolist() if isinstance(input, np.ndarray) else input)\n", 352 | "response = runtime.invoke_endpoint(\n", 353 | " EndpointName=sagemaker_config[\"SolutionPrefix\"] + \"-demo-endpoint\", ContentType=\"application/json\", Body=payload\n", 354 | ")\n", 355 | "demo_predictions = json.loads(response[\"Body\"].read().decode())" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "Finally, we compare the results of the provided pretrained model and our trained from scratch" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": null, 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [ 371 | "visualize(\n", 372 | " original_image_unnorm,\n", 373 | " [original_boxes, demo_predictions[0][\"boxes\"], predictions[0][\"boxes\"]],\n", 374 | " [original_labels, demo_predictions[0][\"labels\"], predictions[0][\"labels\"]],\n", 375 | " colors=[(255, 0, 0), (0, 0, 255), (127, 0, 127)],\n", 376 | " titles=[\"original\", \"ground truth\", \"pretrained\", \"from scratch\"],\n", 377 | " dpi=250,\n", 378 | ")" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": {}, 384 | "source": [ 385 | "## Optional: Delete the endpoint and model\n", 386 | "\n", 387 | "When you are done with the 
endpoint, you should clean it up.\n", 388 | "\n", 389 | "All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account." 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "detector.delete_model()\n", 399 | "detector.delete_endpoint()" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "### [Click here to continue](./3_classification_from_scratch.ipynb)" 407 | ] 408 | } 409 | ], 410 | "metadata": { 411 | "kernelspec": { 412 | "display_name": "Python 3.6.10 64-bit ('pytorch_latest_p36': conda)", 413 | "metadata": { 414 | "interpreter": { 415 | "hash": "4c1e195df8d07db5ee7a78f454b46c3f2e14214bf8c9489d2db5cf8f372ff2ed" 416 | } 417 | }, 418 | "name": "python3" 419 | }, 420 | "language_info": { 421 | "codemirror_mode": { 422 | "name": "ipython", 423 | "version": 3 424 | }, 425 | "file_extension": ".py", 426 | "mimetype": "text/x-python", 427 | "name": "python", 428 | "nbconvert_exporter": "python", 429 | "pygments_lexer": "ipython3", 430 | "version": "3.6.10-final" 431 | } 432 | }, 433 | "nbformat": 4, 434 | "nbformat_minor": 2 435 | } 436 | -------------------------------------------------------------------------------- /notebooks/3_classification_from_scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Jupyter Kernel**:\n", 8 | "\n", 9 | "* If you are in SageMaker Notebook instance, please make sure you are using **conda_pytorch_latest_p36** kernel\n", 10 | "* If you are on SageMaker Studio, please make sure you are using **SageMaker JumpStart PyTorch 1.0** kernel\n", 11 | "\n", 12 | "**Run All**:\n", 13 | "\n", 14 | "* If you are in SageMaker notebook instance, you can go to *Cell tab -> Run All*\n", 15 | "* If you are in SageMaker Studio, you can go to *Run tab -> Run All Cells*\n", 16 | "\n", 17 | "## Training our Classifier from scratch\n", 18 | "\n", 19 | "Depending on an application, sometimes image classification is enough. 
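Unlike the detector endpoints above, which return boxes and labels per image, a classifier returns one score per class for each image; picking the label is a simple argmax, as in this toy sketch (illustrative numbers only):

```python
import numpy as np

# Hypothetical raw scores for one image over the six NEU-CLS classes.
scores = np.array([[0.1, 0.2, 2.5, 0.3, 0.1, 0.4]])
predicted_class = int(scores.argmax(axis=1)[0])
print(predicted_class)  # 2
```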
In this notebook, we see how to train and deploy an accurate classifier from scratch on **NEU-CLS** dataset" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "tags": [] 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "import json\n", 31 | "\n", 32 | "import numpy as np\n", 33 | "\n", 34 | "import sagemaker\n", 35 | "from sagemaker.s3 import S3Downloader\n", 36 | "\n", 37 | "sagemaker_session = sagemaker.Session()\n", 38 | "sagemaker_config = json.load(open(\"../stack_outputs.json\"))\n", 39 | "role = sagemaker_config[\"IamRole\"]\n", 40 | "solution_bucket = sagemaker_config[\"SolutionS3Bucket\"]\n", 41 | "region = sagemaker_config[\"AWSRegion\"]\n", 42 | "solution_name = sagemaker_config[\"SolutionName\"]\n", 43 | "bucket = sagemaker_config[\"S3Bucket\"]" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "First, we download our **NEU-CLS** dataset from our public S3 bucket" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "from sagemaker.s3 import S3Downloader\n", 60 | "\n", 61 | "original_bucket = f\"s3://{solution_bucket}-{region}/{solution_name}\"\n", 62 | "original_data = f\"{original_bucket}/data/NEU-CLS.zip\"\n", 63 | "original_sources = f\"{original_bucket}/build/lib/source_dir.tar.gz\"\n", 64 | "print(\"original data: \")\n", 65 | "S3Downloader.list(original_data)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "For easiler data processing, depending on the dataset, we unify the class and label names using the scripts from `prepare_data`" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "tags": [] 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "%%time\n", 84 | "RAW_DATA_PATH= !echo $PWD/raw_neu_cls\n", 85 | "RAW_DATA_PATH = RAW_DATA_PATH.n\n", 86 | "DATA_PATH = !echo $PWD/neu_cls\n", 87 | "DATA_PATH = DATA_PATH.n\n", 88 | "\n", 89 | "!mkdir -p $RAW_DATA_PATH\n", 90 | "!aws s3 cp $original_data $RAW_DATA_PATH\n", 91 | "\n", 92 | "!mkdir -p $DATA_PATH\n", 93 | "!python ../src/prepare_data/neu.py $RAW_DATA_PATH/NEU-CLS.zip $DATA_PATH" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "After data preparation, we need to setup some paths that will be used throughtout the notebook" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "prefix = \"neu-cls\"\n", 110 | "neu_cls_s3 = f\"s3://{bucket}/{prefix}\"\n", 111 | "sources = f\"{neu_cls_s3}/code/\"\n", 112 | "train_output = f\"{neu_cls_s3}/output/\"\n", 113 | "neu_cls_prepared_s3 = f\"{neu_cls_s3}/data/\"\n", 114 | "!aws s3 sync $DATA_PATH $neu_cls_prepared_s3 --quiet # remove the --quiet flag to view sync outputs\n", 115 | "s3_checkpoint = f\"{neu_cls_s3}/checkpoint/\"\n", 116 | "sm_local_checkpoint_dir = \"/opt/ml/checkpoints/\"\n", 117 | "!aws s3 cp $original_sources $sources" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "## Visualization\n", 125 | "\n", 126 | "Let examine some datasets that we will use later by providing an `ID`" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "tags": [] 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "import matplotlib.pyplot as plt\n", 138 | 
"\n", 139 | "%matplotlib inline\n", 140 | "import numpy as np\n", 141 | "import torch\n", 142 | "from PIL import Image\n", 143 | "from torch.utils.data import DataLoader\n", 144 | "\n", 145 | "try:\n", 146 | " import sagemaker_defect_detection\n", 147 | "except ImportError:\n", 148 | " import sys\n", 149 | " from pathlib import Path\n", 150 | "\n", 151 | " ROOT = Path(\"../src\").resolve()\n", 152 | " sys.path.insert(0, str(ROOT))\n", 153 | "\n", 154 | "from sagemaker_defect_detection import NEUCLS\n", 155 | "\n", 156 | "\n", 157 | "def visualize(image, label, predicted=None):\n", 158 | " if not isinstance(image, Image.Image):\n", 159 | " image = Image.fromarray(image)\n", 160 | "\n", 161 | " plt.figure(dpi=120)\n", 162 | " if predicted is not None:\n", 163 | " plt.title(f\"label: {label}, prediction: {predicted}\")\n", 164 | " else:\n", 165 | " plt.title(f\"label: {label}\")\n", 166 | "\n", 167 | " plt.axis(\"off\")\n", 168 | " plt.imshow(image)\n", 169 | " return\n", 170 | "\n", 171 | "\n", 172 | "dataset = NEUCLS(DATA_PATH, split=\"train\")\n", 173 | "ID = 0\n", 174 | "assert 0 <= ID <= 300\n", 175 | "image, label = dataset[ID]\n", 176 | "visualize(image, label)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "We train our model with `resnet34` backbone for **50 epochs** and obtains about **99%** test accuracy, f1-score, precision and recall as follows" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": { 190 | "tags": [ 191 | "outputPrepend" 192 | ] 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "%%time\n", 197 | "import logging\n", 198 | "from os import path as osp\n", 199 | "\n", 200 | "from sagemaker.pytorch import PyTorch\n", 201 | "\n", 202 | "NUM_CLASSES = 6\n", 203 | "BACKBONE = \"resnet34\"\n", 204 | "assert BACKBONE in [\n", 205 | " \"resnet34\",\n", 206 | " \"resnet50\",\n", 207 | "], \"either resnet34 or resnet50. 
Make sure to be consistent with model_fn in classifier.py\"\n", 208 | "EPOCHS = 50\n", 209 | "SEED = 123\n", 210 | "\n", 211 | "hyperparameters = {\n", 212 | " \"backbone\": BACKBONE,\n", 213 | " \"num-classes\": NUM_CLASSES,\n", 214 | " \"epochs\": EPOCHS,\n", 215 | " \"seed\": SEED,\n", 216 | "}\n", 217 | "\n", 218 | "assert not isinstance(sagemaker_session, sagemaker.LocalSession), \"local session as share memory cannot be altered\"\n", 219 | "\n", 220 | "model = PyTorch(\n", 221 | " entry_point=\"classifier.py\",\n", 222 | " source_dir=osp.join(sources, \"source_dir.tar.gz\"),\n", 223 | " role=role,\n", 224 | " train_instance_count=1,\n", 225 | " train_instance_type=\"ml.g4dn.2xlarge\",\n", 226 | " hyperparameters=hyperparameters,\n", 227 | " py_version=\"py3\",\n", 228 | " framework_version=\"1.5\",\n", 229 | " sagemaker_session=sagemaker_session, # Note: Do not use local session as share memory cannot be altered\n", 230 | " output_path=train_output,\n", 231 | " checkpoint_s3_uri=s3_checkpoint,\n", 232 | " checkpoint_local_path=sm_local_checkpoint_dir,\n", 233 | " # container_log_level=logging.DEBUG,\n", 234 | ")\n", 235 | "\n", 236 | "model.fit(neu_cls_prepared_s3)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "Then, we deploy our model which takes about **8 minutes** to complete" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "tags": [] 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "%%time\n", 255 | "predictor = model.deploy(\n", 256 | " initial_instance_count=1,\n", 257 | " instance_type=\"ml.m5.xlarge\",\n", 258 | " endpoint_name=sagemaker_config[\"SolutionPrefix\"] + \"-classification-endpoint\",\n", 259 | ")" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": {}, 265 | "source": [ 266 | "## Inference\n", 267 | "\n", 268 | "We are ready to test our model by providing some test data and compare the actual labels with the predicted one" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": { 275 | "tags": [] 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "from sagemaker_defect_detection import get_transform\n", 280 | "from sagemaker_defect_detection.utils.visualize import unnormalize_to_hwc\n", 281 | "\n", 282 | "ID = 100\n", 283 | "assert 0 <= ID <= 300\n", 284 | "test_dataset = NEUCLS(DATA_PATH, split=\"test\", transform=get_transform(\"test\"), seed=SEED)\n", 285 | "image, label = test_dataset[ID]\n", 286 | "# SageMaker 1.x doesn't allow_pickle=True by default\n", 287 | "np_load_old = np.load\n", 288 | "np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)\n", 289 | "outputs = predictor.predict(image.unsqueeze(0).numpy())\n", 290 | "np.load = np_load_old\n", 291 | "\n", 292 | "_, predicted = torch.max(torch.from_numpy(np.array(outputs)), 1)\n", 293 | "\n", 294 | "image_unnorm = unnormalize_to_hwc(image)\n", 295 | "visualize(image_unnorm, label, predicted.item())" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "## Optional: Delete the endpoint and model\n", 303 | "\n", 304 | "When you are done with the endpoint, you should clean it up.\n", 305 | "\n", 306 | "All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account." 
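Before deleting the endpoint, the reported test metrics can be sanity-checked with a rough sketch like the following (not a notebook cell), reusing `test_dataset`, `predictor` and the `np.load` workaround from the cells above; one request per image keeps it simple but slow:

```python
import numpy as np
import torch

# Temporarily allow pickled arrays in np.load, as in the inference cell above.
np_load_old = np.load
np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)

correct = 0
for idx in range(len(test_dataset)):
    image, label = test_dataset[idx]
    outputs = predictor.predict(image.unsqueeze(0).numpy())
    _, predicted = torch.max(torch.from_numpy(np.array(outputs)), 1)
    correct += int(predicted.item() == label)

np.load = np_load_old
print(f"test accuracy: {correct / len(test_dataset):.3f}")
```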
307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "predictor.delete_model()\n", 316 | "predictor.delete_endpoint()" 317 | ] 318 | } 319 | ], 320 | "metadata": { 321 | "kernelspec": { 322 | "display_name": "Python 3.6.10 64-bit", 323 | "metadata": { 324 | "interpreter": { 325 | "hash": "4c1e195df8d07db5ee7a78f454b46c3f2e14214bf8c9489d2db5cf8f372ff2ed" 326 | } 327 | }, 328 | "name": "python3" 329 | }, 330 | "language_info": { 331 | "codemirror_mode": { 332 | "name": "ipython", 333 | "version": 3 334 | }, 335 | "file_extension": ".py", 336 | "mimetype": "text/x-python", 337 | "name": "python", 338 | "nbconvert_exporter": "python", 339 | "pygments_lexer": "ipython3", 340 | "version": "3.6.10-final" 341 | } 342 | }, 343 | "nbformat": 4, 344 | "nbformat_minor": 2 345 | } 346 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | 7 | [tool.black] 8 | line-length = 119 9 | exclude = ''' 10 | ( 11 | /( 12 | \.eggs 13 | | \.git 14 | | \.mypy_cache 15 | | build 16 | | dist 17 | ) 18 | ) 19 | ''' 20 | 21 | [tool.portray] 22 | modules = ["src/sagemaker_defect_detection", "src/prepare_data/neu.py"] 23 | docs_dir = "docs" 24 | 25 | [tool.portray.mkdocs.theme] 26 | favicon = "docs/sagemaker.png" 27 | logo = "docs/sagemaker.png" 28 | name = "material" 29 | palette = {primary = "dark blue"} 30 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | patool 2 | pyunpack 3 | torch 4 | torchvision 5 | albumentations 6 | pytorch_lightning 7 | pycocotools -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with python 3.8 3 | # To update, run: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | absl-py 8 | # via tensorboard 9 | albumentations 10 | # via -r requirements.in 11 | cachetools 12 | # via google-auth 13 | certifi 14 | # via requests 15 | charset-normalizer 16 | # via requests 17 | cycler 18 | # via matplotlib 19 | cython 20 | # via pycocotools 21 | easyprocess 22 | # via pyunpack 23 | entrypoint2 24 | # via pyunpack 25 | fonttools 26 | # via matplotlib 27 | future 28 | # via 29 | # pytorch-lightning 30 | # torch 31 | google-auth 32 | # via 33 | # google-auth-oauthlib 34 | # tensorboard 35 | google-auth-oauthlib 36 | # via tensorboard 37 | grpcio 38 | # via tensorboard 39 | idna 40 | # via requests 41 | imageio 42 | # via 43 | # imgaug 44 | # scikit-image 45 | imgaug 46 | # via albumentations 47 | importlib-metadata 48 | # via markdown 49 | kiwisolver 50 | # via matplotlib 51 | markdown 52 | # via tensorboard 53 | matplotlib 54 | # via 55 | # imgaug 56 | # pycocotools 57 | networkx 58 | # via scikit-image 59 | numpy 60 | # via 61 | # albumentations 62 | # imageio 63 | # imgaug 64 | # matplotlib 65 | # opencv-python 66 | # opencv-python-headless 67 | # pytorch-lightning 68 | # pywavelets 69 | # scikit-image 70 | # scipy 71 | # tensorboard 72 | # tifffile 73 | # torch 74 | # torchvision 75 | oauthlib 76 | # via requests-oauthlib 77 | opencv-python 78 | # via imgaug 79 | opencv-python-headless 80 | # via 
albumentations 81 | packaging 82 | # via 83 | # matplotlib 84 | # scikit-image 85 | patool 86 | # via -r requirements.in 87 | pillow 88 | # via 89 | # imageio 90 | # imgaug 91 | # matplotlib 92 | # scikit-image 93 | # torchvision 94 | protobuf 95 | # via tensorboard 96 | pyasn1 97 | # via 98 | # pyasn1-modules 99 | # rsa 100 | pyasn1-modules 101 | # via google-auth 102 | pycocotools 103 | # via -r requirements.in 104 | pyparsing 105 | # via 106 | # matplotlib 107 | # packaging 108 | python-dateutil 109 | # via matplotlib 110 | pytorch-lightning 111 | # via -r requirements.in 112 | pyunpack 113 | # via -r requirements.in 114 | pywavelets 115 | # via scikit-image 116 | pyyaml 117 | # via 118 | # albumentations 119 | # pytorch-lightning 120 | requests 121 | # via 122 | # requests-oauthlib 123 | # tensorboard 124 | requests-oauthlib 125 | # via google-auth-oauthlib 126 | rsa 127 | # via google-auth 128 | scikit-image 129 | # via imgaug 130 | scipy 131 | # via 132 | # albumentations 133 | # imgaug 134 | # scikit-image 135 | shapely 136 | # via imgaug 137 | six 138 | # via 139 | # google-auth 140 | # grpcio 141 | # imgaug 142 | # python-dateutil 143 | tensorboard 144 | # via pytorch-lightning 145 | tensorboard-data-server 146 | # via tensorboard 147 | tensorboard-plugin-wit 148 | # via tensorboard 149 | tifffile 150 | # via scikit-image 151 | torch 152 | # via 153 | # -r requirements.in 154 | # pytorch-lightning 155 | # torchvision 156 | torchvision 157 | # via -r requirements.in 158 | tqdm 159 | # via pytorch-lightning 160 | urllib3 161 | # via requests 162 | werkzeug 163 | # via tensorboard 164 | wheel 165 | # via tensorboard 166 | zipp 167 | # via importlib-metadata 168 | -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | NOW=$(date +"%x %r %Z") 5 | echo "Time: $NOW" 6 | 7 | if [ $# -lt 3 ]; then 8 | echo "Please provide the solution name as well as the base S3 bucket name and the region to run build script." 9 | echo "For example: bash ./scripts/build.sh trademarked-solution-name sagemaker-solutions-build us-west-2" 10 | exit 1 11 | fi 12 | 13 | REGION=$3 14 | SOURCE_REGION="us-west-2" 15 | echo "Region: $REGION, source region: $SOURCE_REGION" 16 | BASE_DIR="$(dirname "$(dirname "$(readlink -f "$0")")")" 17 | echo "Base dir: $BASE_DIR" 18 | 19 | rm -rf build 20 | 21 | # build python package 22 | echo "Python build and package" 23 | python -m pip install --upgrade pip 24 | python -m pip install --upgrade wheel setuptools 25 | python setup.py build sdist bdist_wheel 26 | 27 | find . | grep -E "(__pycache__|\.pyc|\.pyo$|\.egg*|\lightning_logs)" | xargs rm -rf 28 | 29 | echo "Add requirements for SageMaker" 30 | cp requirements.txt build/lib/ 31 | cd build/lib || exit 32 | mv sagemaker_defect_detection/{classifier.py,detector.py} . 33 | touch source_dir.tar.gz 34 | tar --exclude=source_dir.tar.gz -czvf source_dir.tar.gz . 35 | echo "Only keep source_dir.tar.gz for SageMaker" 36 | find . ! 
-name "source_dir.tar.gz" -type f -exec rm -r {} + 37 | rm -rf sagemaker_defect_detection 38 | 39 | cd - || exit 40 | 41 | mv dist build 42 | 43 | # add notebooks to build 44 | echo "Prepare notebooks and add to build" 45 | cp -r notebooks build 46 | rm -rf build/notebooks/*neu* # remove local datasets for build 47 | for nb in build/notebooks/*.ipynb; do 48 | python "$BASE_DIR"/scripts/set_kernelspec.py --notebook "$nb" --display-name "Python 3 (PyTorch JumpStart)" --kernel "HUB_1P_IMAGE" 49 | done 50 | 51 | echo "Copy src to build" 52 | cp -r src build 53 | 54 | # add solution assistant 55 | echo "Solution assistant lambda function" 56 | cd cloudformation/solution-assistant/ || exit 57 | python -m pip install -r requirements.txt -t ./src/site-packages 58 | 59 | cd - || exit 60 | 61 | echo "Clean up pyc files, needed to avoid security issues. See: https://blog.jse.li/posts/pyc/" 62 | find cloudformation | grep -E "(__pycache__|\.pyc|\.pyo$)" | xargs rm -rf 63 | cp -r cloudformation/solution-assistant build/ 64 | cd build/solution-assistant/src || exit 65 | zip -q -r9 "$BASE_DIR"/build/solution-assistant.zip -- * 66 | 67 | cd - || exit 68 | rm -rf build/solution-assistant 69 | 70 | if [ -z "$4" ] || [ "$4" == 'mainline' ]; then 71 | s3_prefix="s3://$2-$3/$1" 72 | else 73 | s3_prefix="s3://$2-$3/$1-$4" 74 | fi 75 | 76 | # cleanup and copy the build artifacts 77 | echo "Removing the existing objects under $s3_prefix" 78 | aws s3 rm --recursive "$s3_prefix" --region "$REGION" 79 | echo "Copying new objects to $s3_prefix" 80 | aws s3 sync . "$s3_prefix" --delete --region "$REGION" \ 81 | --exclude ".git/*" \ 82 | --exclude ".vscode/*" \ 83 | --exclude ".mypy_cache/*" \ 84 | --exclude "logs/*" \ 85 | --exclude "stack_outputs.json" \ 86 | --exclude "src/sagemaker_defect_detection/lightning_logs/*" \ 87 | --exclude "notebooks/*neu*/*" 88 | 89 | echo "Copying solution artifacts" 90 | aws s3 cp "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/demo/model.tar.gz" "$s3_prefix"/demo/model.tar.gz --source-region "$SOURCE_REGION" 91 | 92 | mkdir -p build/pretrained/ 93 | aws s3 cp "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/pretrained/model.tar.gz" build/pretrained --source-region "$SOURCE_REGION" && 94 | cd build/pretrained/ && tar -xf model.tar.gz && cd .. 
&& 95 | aws s3 sync pretrained "$s3_prefix"/pretrained/ --delete --region "$REGION" 96 | 97 | aws s3 cp "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/data/NEU-CLS.zip" "$s3_prefix"/data/ --source-region "$SOURCE_REGION" 98 | aws s3 cp "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/data/NEU-DET.zip" "$s3_prefix"/data/ --source-region "$SOURCE_REGION" 99 | 100 | echo "Add docs to build" 101 | aws s3 sync "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/docs" "$s3_prefix"/docs --delete --region "$REGION" 102 | aws s3 sync "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/docs" "$s3_prefix"/build/docs --delete --region "$REGION" 103 | -------------------------------------------------------------------------------- /scripts/find_best_ckpt.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | 5 | def get_score(s: str) -> float: 6 | """Gets the criterion score from .ckpt formated with ModelCheckpoint 7 | 8 | Parameters 9 | ---------- 10 | s : str 11 | Assumption is the last number is the desired number carved in .ckpt 12 | 13 | Returns 14 | ------- 15 | float 16 | The criterion float 17 | """ 18 | return float(re.findall(r"(\d+.\d+).ckpt", s)[0]) 19 | 20 | 21 | def main(path: str, op: str) -> str: 22 | """Finds the best ckpt path 23 | 24 | Parameters 25 | ---------- 26 | path : str 27 | ckpt path 28 | op : str 29 | "max" (for mAP for example) or "min" (for loss) 30 | 31 | Returns 32 | ------- 33 | str 34 | A ckpt path 35 | """ 36 | ckpts = list(map(str, Path(path).glob("*.ckpt"))) 37 | if not len(ckpts): 38 | return 39 | 40 | ckpt_score_dict = {ckpt: get_score(ckpt) for ckpt in ckpts} 41 | op = max if op == "max" else min 42 | out = op(ckpt_score_dict, key=ckpt_score_dict.get) 43 | print(out) # need to flush for bash 44 | return out 45 | 46 | 47 | if __name__ == "__main__": 48 | import sys 49 | 50 | if len(sys.argv) < 3: 51 | print("provide checkpoint path and op either max or min") 52 | sys.exit(1) 53 | 54 | main(sys.argv[1], sys.argv[2]) 55 | -------------------------------------------------------------------------------- /scripts/set_kernelspec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | 5 | def set_kernel_spec(notebook_filepath, display_name, kernel_name): 6 | with open(notebook_filepath, "r") as openfile: 7 | notebook = json.load(openfile) 8 | kernel_spec = {"display_name": display_name, "language": "python", "name": kernel_name} 9 | if "metadata" not in notebook: 10 | notebook["metadata"] = {} 11 | notebook["metadata"]["kernelspec"] = kernel_spec 12 | with open(notebook_filepath, "w") as openfile: 13 | json.dump(notebook, openfile) 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--notebook") 19 | parser.add_argument("--display-name") 20 | parser.add_argument("--kernel") 21 | args = parser.parse_args() 22 | set_kernel_spec(args.notebook, args.display_name, args.kernel) 23 | -------------------------------------------------------------------------------- /scripts/train_classifier.sh: -------------------------------------------------------------------------------- 1 | CLASSIFICATION_DATA_PATH=$1 2 | BACKBONE=$2 3 | BASE_DIR="$(dirname $(dirname $(readlink -f $0)))" 4 | CLASSIFIER_SCRIPT="${BASE_DIR}/src/sagemaker_defect_detection/classifier.py" 5 | MFN_LOGS="${BASE_DIR}/logs/" 6 | 7 | python ${CLASSIFIER_SCRIPT} 
\ 8 | --data-path=${CLASSIFICATION_DATA_PATH} \ 9 | --save-path=${MFN_LOGS} \ 10 | --backbone=${BACKBONE} \ 11 | --gpus=1 \ 12 | --learning-rate=1e-3 \ 13 | --epochs=50 14 | -------------------------------------------------------------------------------- /scripts/train_detector.sh: -------------------------------------------------------------------------------- 1 | NOW=$(date +"%x %r %Z") 2 | echo "Time: ${NOW}" 3 | DETECTION_DATA_PATH=$1 4 | BACKBONE=$2 5 | BASE_DIR="$(dirname $(dirname $(readlink -f $0)))" 6 | DETECTOR_SCRIPT="${BASE_DIR}/src/sagemaker_defect_detection/detector.py" 7 | LOG_DIR="${BASE_DIR}/logs" 8 | MFN_LOGS="${LOG_DIR}/classification_logs" 9 | RPN_LOGS="${LOG_DIR}/rpn_logs" 10 | ROI_LOGS="${LOG_DIR}/roi_logs" 11 | FINETUNED_RPN_LOGS="${LOG_DIR}/finetune_rpn_logs" 12 | FINETUNED_ROI_LOGS="${LOG_DIR}/finetune_roi_logs" 13 | FINETUNED_FINAL_LOGS="${LOG_DIR}/finetune_final_logs" 14 | EXTRA_FINETUNED_RPN_LOGS="${LOG_DIR}/extra_finetune_rpn_logs" 15 | EXTRA_FINETUNED_ROI_LOGS="${LOG_DIR}/extra_finetune_roi_logs" 16 | EXTRA_FINETUNING_STEPS=3 17 | 18 | function find_best_ckpt() { 19 | python ${BASE_DIR}/scripts/find_best_ckpt.py $1 $2 20 | } 21 | 22 | function train_step() { 23 | echo "training step $1" 24 | case $1 in 25 | "1") 26 | echo "skipping step 1 and use 'train_classifier.sh'" 27 | ;; 28 | "2") # train rpn 29 | python ${DETECTOR_SCRIPT} \ 30 | --data-path=${DETECTION_DATA_PATH} \ 31 | --backbone=${BACKBONE} \ 32 | --train-rpn \ 33 | --pretrained-mfn-ckpt=$(find_best_ckpt "${MFN_LOGS}" "max") \ 34 | --save-path=${RPN_LOGS} \ 35 | --gpus=-1 --distributed-backend=ddp \ 36 | --epochs=100 37 | ;; 38 | "3") # train roi 39 | python ${DETECTOR_SCRIPT} \ 40 | --data-path=${DETECTION_DATA_PATH} \ 41 | --backbone=${BACKBONE} \ 42 | --train-roi \ 43 | --pretrained-rpn-ckpt=$(find_best_ckpt "${RPN_LOGS}" "min") \ 44 | --save-path=${ROI_LOGS} \ 45 | --gpus=-1 --distributed-backend=ddp \ 46 | --epochs=100 47 | ;; 48 | "4") # finetune rpn 49 | python ${DETECTOR_SCRIPT} \ 50 | --data-path=${DETECTION_DATA_PATH} \ 51 | --backbone=${BACKBONE} \ 52 | --finetune-rpn \ 53 | --pretrained-rpn-ckpt=$(find_best_ckpt "${RPN_LOGS}" "min") \ 54 | --pretrained-roi-ckpt=$(find_best_ckpt "${ROI_LOGS}" "min") \ 55 | --save-path=${FINETUNED_RPN_LOGS} \ 56 | --gpus=-1 --distributed-backend=ddp \ 57 | --learning-rate=1e-4 \ 58 | --epochs=100 59 | # --resume-from-checkpoint=$(find_best_ckpt "${FINETUNED_RPN_LOGS}" "max") 60 | ;; 61 | "5") # finetune roi 62 | python ${DETECTOR_SCRIPT} \ 63 | --data-path=${DETECTION_DATA_PATH} \ 64 | --backbone=${BACKBONE} \ 65 | --finetune-roi \ 66 | --finetuned-rpn-ckpt=$(find_best_ckpt "${FINETUNED_RPN_LOGS}" "max") \ 67 | --pretrained-roi-ckpt=$(find_best_ckpt "${ROI_LOGS}" "min") \ 68 | --save-path=${FINETUNED_ROI_LOGS} \ 69 | --gpus=-1 --distributed-backend=ddp \ 70 | --learning-rate=1e-4 \ 71 | --epochs=100 72 | # --resume-from-checkpoint=$(find_best_ckpt "${FINETUNED_ROI_LOGS}" "max") 73 | ;; 74 | "extra_rpn") # initially EXTRA_FINETUNED_*_LOGS is a copy of FINETUNED_*_LOGS 75 | python ${DETECTOR_SCRIPT} \ 76 | --data-path=${DETECTION_DATA_PATH} \ 77 | --backbone=${BACKBONE} \ 78 | --finetune-rpn \ 79 | --finetuned-rpn-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_RPN_LOGS}" "max") \ 80 | --finetuned-roi-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_ROI_LOGS}" "max") \ 81 | --save-path=${EXTRA_FINETUNED_RPN_LOGS} \ 82 | --gpus=-1 --distributed-backend=ddp \ 83 | --learning-rate=1e-4 \ 84 | --epochs=100 85 | ;; 86 | "extra_roi") # initially EXTRA_FINETUNED_*_LOGS is a copy of 
FINETUNED_*_LOGS 87 | python ${DETECTOR_SCRIPT} \ 88 | --data-path=${DETECTION_DATA_PATH} \ 89 | --backbone=${BACKBONE} \ 90 | --finetune-roi \ 91 | --finetuned-rpn-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_RPN_LOGS}" "max") \ 92 | --finetuned-roi-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_ROI_LOGS}" "max") \ 93 | --save-path=${EXTRA_FINETUNED_ROI_LOGS} \ 94 | --gpus=-1 --distributed-backend=ddp \ 95 | --learning-rate=1e-4 \ 96 | --epochs=100 97 | ;; 98 | "joint") # final 99 | python ${DETECTOR_SCRIPT} \ 100 | --data-path=${DETECTION_DATA_PATH} \ 101 | --backbone=${BACKBONE} \ 102 | --finetuned-rpn-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_RPN_LOGS}" "max") \ 103 | --finetuned-roi-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_ROI_LOGS}" "max") \ 104 | --save-path="${FINETUNED_FINAL_LOGS}" \ 105 | --gpus=-1 --distributed-backend=ddp \ 106 | --learning-rate=1e-3 \ 107 | --epochs=300 108 | # --resume-from-checkpoint=$(find_best_ckpt "${FINETUNED_FINAL_LOGS}" "max") 109 | ;; 110 | 111 | *) ;; 112 | esac 113 | } 114 | 115 | function train_wait_to_finish() { 116 | train_step $1 & 117 | BPID=$! 118 | wait $BPID 119 | } 120 | 121 | function run() { 122 | if [ "$1" != "" ]; then 123 | train_step $1 124 | else 125 | nvidia-smi | grep python | awk '{ print $3 }' | xargs -n1 kill -9 >/dev/null 2>&1 126 | read -p "Training all steps from scratch? (Y/N): " confirm && [[ $confirm == [yY] || $confirm == [yY][eE][sS] ]] 127 | if [ "$confirm" == "Y" ]; then 128 | for i in {1..5}; do 129 | train_wait_to_finish $i 130 | done 131 | echo "finished all the training steps" 132 | mkdir -p "${EXTRA_FINETUNED_RPN_LOGS}" && cp -r "${FINETUNED_RPN_LOGS}/"* "${EXTRA_FINETUNED_RPN_LOGS}" 133 | mkdir -p "${EXTRA_FINETUNED_ROI_LOGS}" && cp -r "${FINETUNED_ROI_LOGS}/"* "${EXTRA_FINETUNED_ROI_LOGS}" 134 | fi 135 | echo "repeating extra finetuning steps ${EXTRA_FINETUNING_STEPS} more times" 136 | for i in {1..${EXTRA_FINETUNING_STEPS}}; do 137 | train_wait_to_finish "extra_rpn" 138 | train_wait_to_finish "extra_roi" 139 | done 140 | echo "final joint training" 141 | train_step "joint" 142 | fi 143 | exit 0 144 | } 145 | 146 | run $1 147 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | norecursedirs = 3 | .git 4 | dist 5 | build 6 | python_files = 7 | test_*.py 8 | 9 | [metadata] 10 | # license_files = LICENSE 11 | 12 | [check-manifest] 13 | ignore = 14 | *.yaml 15 | .github 16 | .github/* 17 | build 18 | deploy 19 | notebook 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import setup, find_packages 4 | 5 | ROOT = Path(__file__).parent.resolve() 6 | 7 | long_description = (ROOT / "README.md").read_text(encoding="utf-8") 8 | 9 | dev_dependencies = ["pre-commit", "mypy==0.781", "black==20.8b1", "nbstripout==0.3.7", "black-nb==0.3.0"] 10 | test_dependencies = ["pytest>=6.0"] 11 | doc_dependencies = ["portray>=1.4.0"] 12 | 13 | setup( 14 | name="sagemaker_defect_detection", 15 | version="0.1", 16 | description="Detect Defects in Products from their Images using Amazon SageMaker ", 17 | long_description=long_description, 18 | author="Ehsan M. 
Kermani", 19 | python_requires=">=3.6", 20 | package_dir={"": "src"}, 21 | packages=find_packages("src", exclude=["tests", "tests/*"]), 22 | install_requires=open(str(ROOT / "requirements.txt"), "r").read(), 23 | extras_require={"dev": dev_dependencies, "test": test_dependencies, "doc": doc_dependencies}, 24 | license="Apache License 2.0", 25 | classifiers=[ 26 | "Development Status :: 3 - Alpha", 27 | "Programming Language :: Python :: 3.6", 28 | "Programming Language :: Python :: 3.7", 29 | "Programming Language :: Python :: 3.8", 30 | "Programming Language :: Python :: 3 :: Only", 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /src/prepare_RecordIO.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import argparse 5 | import numpy as np 6 | from collections import defaultdict 7 | from pathlib import Path 8 | 9 | 10 | def write_line(img_path, width, height, boxes, ids, idx): 11 | """Create a line for each image with annotations, width, height and image name.""" 12 | # for header, we use minimal length 2, plus width and height 13 | # with A: 4, B: 5, C: width, D: height 14 | A = 4 15 | B = 5 16 | C = width 17 | D = height 18 | # concat id and bboxes 19 | labels = np.hstack((ids.reshape(-1, 1), boxes)).astype("float") 20 | # normalized bboxes (recommanded) 21 | labels[:, (1, 3)] /= float(width) 22 | labels[:, (2, 4)] /= float(height) 23 | # flatten 24 | labels = labels.flatten().tolist() 25 | str_idx = [str(idx)] 26 | str_header = [str(x) for x in [A, B, C, D]] 27 | str_labels = [str(x) for x in labels] 28 | str_path = [img_path] 29 | line = "\t".join(str_idx + str_header + str_labels + str_path) + "\n" 30 | return line 31 | 32 | 33 | # adapt from __main__ from im2rec.py 34 | def write_lst(output_file, ids, images_annotations): 35 | 36 | all_labels = set() 37 | image_info = {} 38 | for entry in images_annotations['images']: 39 | if entry["id"] in ids: 40 | image_info[entry["id"]] = entry 41 | annotations_info = {} # one annotation for each id (ie., image) 42 | for entry in images_annotations['annotations']: 43 | image_id = entry['image_id'] 44 | if image_id in ids: 45 | if image_id not in annotations_info: 46 | annotations_info[image_id] = {'boxes': [], 'labels': []} 47 | annotations_info[image_id]['boxes'].append(entry['bbox']) 48 | annotations_info[image_id]['labels'].append(entry['category_id']) 49 | all_labels.add(entry['category_id']) 50 | labels_list = [label for label in all_labels] 51 | class_to_idx_mapping = {label: idx for idx, label in enumerate(labels_list)} 52 | with open(output_file, "w") as fw: 53 | for i, image_id in enumerate(annotations_info): 54 | im_info = image_info[image_id] 55 | image_file = im_info['file_name'] 56 | height = im_info['height'] 57 | width = im_info['width'] 58 | an_info = annotations_info[image_id] 59 | boxes = np.array(an_info['boxes']) 60 | labels = np.array([class_to_idx_mapping[label] for label in an_info['labels']]) 61 | line = write_line(image_file, width, height, boxes, labels, i) 62 | fw.write(line) 63 | 64 | 65 | def create_lst(data_dir, args, rnd_seed=100): 66 | """Generate an lst file based on annotations file which is used to convert the input data to .rec format.""" 67 | with open(os.path.join(data_dir, 'annotations.json')) as f: 68 | images_annotations = json.loads(f.read()) 69 | 70 | # Size of each class 71 | class_ids = defaultdict(list) 72 | for entry in images_annotations['images']: 73 | cls_ = 
entry['file_name'].split('_')[0] 74 | class_ids[cls_].append(entry['id']) 75 | print('\ncategory\tnum of images') 76 | print('---------------') 77 | for cls_ in class_ids.keys(): 78 | print(f"{cls_}\t{len(class_ids[cls_])}") 79 | 80 | random.seed(rnd_seed) 81 | 82 | # Split train/val/test image ids 83 | if args.test_ratio: 84 | test_ids = [] 85 | if args.train_ratio + args.test_ratio < 1.0: 86 | val_ids = [] 87 | train_ids = [] 88 | for cls_ in class_ids.keys(): 89 | random.shuffle(class_ids[cls_]) 90 | N = len(class_ids[cls_]) 91 | ids = class_ids[cls_] 92 | 93 | sep = int(N * args.train_ratio) 94 | sep_test = int(N * args.test_ratio) 95 | if args.train_ratio == 1.0: 96 | train_ids.extend(ids) 97 | else: 98 | if args.test_ratio: 99 | test_ids.extend(ids[:sep_test]) 100 | if args.train_ratio + args.test_ratio < 1.0: 101 | val_ids.extend(ids[sep_test + sep:]) 102 | train_ids.extend(ids[sep_test: sep_test + sep]) 103 | 104 | write_lst(args.prefix + "_train.lst", train_ids, images_annotations) 105 | lsts = [args.prefix + "_train.lst"] 106 | if args.test_ratio: 107 | write_lst(args.prefix + "_test.lst", test_ids, images_annotations) 108 | lsts.append(args.prefix + "_test.lst") 109 | if args.train_ratio + args.test_ratio < 1.0: 110 | write_lst(args.prefix + "_val.lst", val_ids, images_annotations) 111 | lsts.append(args.prefix + "_val.lst") 112 | 113 | return lsts 114 | 115 | 116 | def parse_args(): 117 | """Defines all arguments. 118 | 119 | Returns: 120 | args object that contains all the params 121 | """ 122 | parser = argparse.ArgumentParser( 123 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 124 | description="Create an image list or \ 125 | make a record database by reading from an image list", 126 | ) 127 | parser.add_argument("prefix", help="prefix of input/output lst and rec files.") 128 | parser.add_argument("root", help="path to folder containing images.") 129 | 130 | cgroup = parser.add_argument_group("Options for creating image lists") 131 | cgroup.add_argument( 132 | "--exts", nargs="+", default=[".jpeg", ".jpg", ".png"], help="list of acceptable image extensions." 133 | ) 134 | cgroup.add_argument("--train-ratio", type=float, default=0.8, help="Ratio of images to use for training.") 135 | cgroup.add_argument("--test-ratio", type=float, default=0, help="Ratio of images to use for testing.") 136 | cgroup.add_argument( 137 | "--recursive", 138 | action="store_true", 139 | help="If true recursively walk through subdirs and assign an unique label\ 140 | to images in each folder. 
Otherwise only include images in the root folder\ 141 | and give them label 0.", 142 | ) 143 | args = parser.parse_args() 144 | 145 | args.prefix = os.path.abspath(args.prefix) 146 | args.root = os.path.abspath(args.root) 147 | return args 148 | 149 | 150 | if __name__ == '__main__': 151 | 152 | args = parse_args() 153 | data_dir = Path(args.root).parent 154 | 155 | lsts = create_lst(data_dir, args) 156 | print() 157 | 158 | for lst in lsts: 159 | os.system(f"python3 ../src/im2rec.py {lst} {os.path.join(data_dir, 'images')} --pass-through --pack-label") 160 | print() 161 | -------------------------------------------------------------------------------- /src/prepare_data/neu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dependencies: unzip unrar 3 | python -m pip install patool pyunpack 4 | """ 5 | 6 | from pathlib import Path 7 | import shutil 8 | import re 9 | import os 10 | 11 | try: 12 | from pyunpack import Archive 13 | except ModuleNotFoundError: 14 | print("installing the dependencies `patool` and `pyunpack` for unzipping the data") 15 | import subprocess 16 | 17 | subprocess.run("python -m pip install patool==1.12 pyunpack==0.2.1 -q", shell=True) 18 | from pyunpack import Archive 19 | 20 | CLASSES = { 21 | "crazing": "Cr", 22 | "inclusion": "In", 23 | "pitted_surface": "PS", 24 | "patches": "Pa", 25 | "rolled-in_scale": "RS", 26 | "scratches": "Sc", 27 | } 28 | 29 | 30 | def unpack(path: str) -> None: 31 | path = Path(path) 32 | Archive(str(path)).extractall(str(path.parent)) 33 | return 34 | 35 | 36 | def cp_class_images(data_path: Path, class_name: str, class_path_dest: Path) -> None: 37 | lst = list(data_path.rglob(f"{class_name}_*")) 38 | for img_file in lst: 39 | shutil.copy2(str(img_file), str(class_path_dest / img_file.name)) 40 | 41 | assert len(lst) == len(list(class_path_dest.glob("*"))) 42 | return 43 | 44 | 45 | def cp_image_annotation(data_path: Path, class_name: str, image_path_dest: Path, annotation_path_dest: Path) -> None: 46 | img_lst = sorted(list((data_path / "IMAGES").rglob(f"{class_name}_*"))) 47 | ann_lst = sorted(list((data_path / "ANNOTATIONS").rglob(f"{class_name}_*"))) 48 | assert len(img_lst) == len( 49 | ann_lst 50 | ), f"images count {len(img_lst)} does not match with annotations count {len(ann_lst)} for class {class_name}" 51 | for (img_file, ann_file) in zip(img_lst, ann_lst): 52 | shutil.copy2(str(img_file), str(image_path_dest / img_file.name)) 53 | shutil.copy2(str(ann_file), str(annotation_path_dest / ann_file.name)) 54 | 55 | assert len(list(image_path_dest.glob("*"))) == len(list(annotation_path_dest.glob("*"))) 56 | return 57 | 58 | 59 | def main(data_path: str, output_path: str, archived: bool = True) -> None: 60 | """ 61 | Data preparation 62 | 63 | Parameters 64 | ---------- 65 | data_path : str 66 | Raw data path 67 | output_path : str 68 | Output data path 69 | archived: bool 70 | Whether the file is archived or not (for testing) 71 | 72 | Raises 73 | ------ 74 | ValueError 75 | If the packed data file is different from NEU-CLS or NEU-DET 76 | """ 77 | data_path = Path(data_path) 78 | if archived: 79 | unpack(data_path) 80 | 81 | data_path = data_path.parent / re.search(r"^[^.]*", str(data_path.name)).group(0) 82 | try: 83 | os.remove(str(data_path / "Thumbs.db")) 84 | except FileNotFoundError: 85 | print(f"Thumbs.db is not found. 
Continuing ...") 86 | pass 87 | except Exception as e: 88 | print(f"{e}: Unknown error!") 89 | raise e 90 | 91 | output_path = Path(output_path) 92 | if data_path.name == "NEU-CLS": 93 | for cls_ in CLASSES.values(): 94 | cls_path = output_path / cls_ 95 | cls_path.mkdir(exist_ok=True) 96 | cp_class_images(data_path, cls_, cls_path) 97 | elif data_path.name == "NEU-DET": 98 | for cls_ in CLASSES: 99 | cls_path = output_path / CLASSES[cls_] 100 | image_path = cls_path / "images" 101 | image_path.mkdir(parents=True, exist_ok=True) 102 | annotation_path = cls_path / "annotations" 103 | annotation_path.mkdir(exist_ok=True) 104 | cp_image_annotation(data_path, cls_, image_path, annotation_path) 105 | else: 106 | raise ValueError(f"Unknown data. Choose between `NEU-CLS` and `NEU-DET`. Given {data_path.name}") 107 | 108 | return 109 | 110 | 111 | if __name__ == "__main__": 112 | import sys 113 | 114 | if len(sys.argv) < 3: 115 | print("Provide `data_path` and `output_path`") 116 | sys.exit(1) 117 | 118 | main(sys.argv[1], sys.argv[2]) 119 | print("Done") 120 | -------------------------------------------------------------------------------- /src/prepare_data/test_neu.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | import neu 6 | 7 | 8 | def tmp_fill(tmpdir_, classes, ext): 9 | for cls_ in classes.values(): 10 | tmpdir_.join(cls_ + "_0" + ext) 11 | # add one more image 12 | tmpdir_.join(classes["crazing"] + "_1" + ext) 13 | return 14 | 15 | 16 | @pytest.fixture() 17 | def tmp_neu(): 18 | def _create(tmpdir, filename): 19 | tmpneu = tmpdir.mkdir(filename) 20 | if filename == "NEU-CLS": 21 | tmp_fill(tmpneu, neu.CLASSES, ".png") 22 | elif filename == "NEU-DET": 23 | imgs = tmpneu.mkdir("IMAGES") 24 | tmp_fill(imgs, neu.CLASSES, ".png") 25 | anns = tmpneu.mkdir("ANNOTATIONS") 26 | tmp_fill(anns, neu.CLASSES, ".xml") 27 | else: 28 | raise ValueError("Not supported") 29 | return tmpneu 30 | 31 | return _create 32 | 33 | 34 | @pytest.mark.parametrize("filename", ["NEU-CLS", "NEU-DET"]) 35 | def test_main(tmpdir, tmp_neu, filename) -> None: 36 | data_path = tmp_neu(tmpdir, filename) 37 | output_path = tmpdir.mkdir("output_path") 38 | neu.main(data_path, output_path, archived=False) 39 | assert len(os.listdir(output_path)) == len(neu.CLASSES), "failed to match number of classes in output" 40 | for p in output_path.visit(): 41 | assert p.check(), "correct path was not created" 42 | return 43 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import pytorch_lightning 3 | except ModuleNotFoundError: 4 | print("installing the dependencies for sagemaker_defect_detection package ...") 5 | import subprocess 6 | 7 | subprocess.run( 8 | "python -m pip install -q albumentations==0.4.6 pytorch_lightning==0.8.5 pycocotools==2.0.1", shell=True 9 | ) 10 | 11 | from sagemaker_defect_detection.models.ddn import Classification, Detection, RoI, RPN 12 | from sagemaker_defect_detection.dataset.neu import NEUCLS, NEUDET 13 | from sagemaker_defect_detection.transforms import get_transform, get_augmentation, get_preprocess 14 | 15 | __all__ = [ 16 | "Classification", 17 | "Detection", 18 | "RoI", 19 | "RPN", 20 | "NEUCLS", 21 | "NEUDET", 22 | "get_transform", 23 | "get_augmentation", 24 | "get_preprocess", 25 | ] 26 | 
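
The package __init__ above lazily installs the heavier dependencies (albumentations, pytorch_lightning, pycocotools) on first import and then re-exports the model, dataset and transform entry points. Below is a minimal usage sketch of those re-exports for the classification data; it assumes the NEU-CLS archive has already been unpacked by src/prepare_data/neu.py, and the "data/NEU-CLS" path, seed and batch size are illustrative only, not part of the source.

from torch.utils.data import DataLoader

from sagemaker_defect_detection import NEUCLS, get_transform

DATA_ROOT = "data/NEU-CLS"  # hypothetical path produced by src/prepare_data/neu.py

# NEUCLS extends torchvision's ImageFolder, so the usual `transform` kwarg applies
train_dataset = NEUCLS(DATA_ROOT, split="train", transform=get_transform("train"), seed=42)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

images, targets = next(iter(train_loader))
print(images.shape, targets.shape)  # e.g. torch.Size([64, 3, 224, 224]) and torch.Size([64])
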
-------------------------------------------------------------------------------- /src/sagemaker_defect_detection/classifier.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | from typing import Dict 3 | import os 4 | from collections import OrderedDict 5 | from argparse import ArgumentParser, Namespace 6 | from multiprocessing import cpu_count 7 | 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | 13 | import pytorch_lightning as pl 14 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 15 | import pytorch_lightning.metrics.functional as plm 16 | 17 | from sagemaker_defect_detection import Classification, NEUCLS, get_transform 18 | from sagemaker_defect_detection.utils import load_checkpoint, freeze 19 | 20 | 21 | def metrics(name: str, out: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]: 22 | pred = torch.argmax(out, 1).detach() 23 | target = target.detach() 24 | metrics = {} 25 | metrics[name + "_acc"] = plm.accuracy(pred, target) 26 | metrics[name + "_prec"] = plm.precision(pred, target) 27 | metrics[name + "_recall"] = plm.recall(pred, target) 28 | metrics[name + "_f1_score"] = plm.recall(pred, target) 29 | return metrics 30 | 31 | 32 | class DDNClassification(pl.LightningModule): 33 | def __init__( 34 | self, 35 | data_path: str, 36 | backbone: str, 37 | freeze_backbone: bool, 38 | num_classes: int, 39 | learning_rate: float, 40 | batch_size: int, 41 | momentum: float, 42 | weight_decay: float, 43 | seed: int, 44 | **kwargs 45 | ) -> None: 46 | super().__init__() 47 | self.data_path = data_path 48 | self.backbone = backbone 49 | self.freeze_backbone = freeze_backbone 50 | self.num_classes = num_classes 51 | self.learning_rate = learning_rate 52 | self.batch_size = batch_size 53 | self.momentum = momentum 54 | self.weight_decay = weight_decay 55 | self.seed = seed 56 | 57 | self.train_dataset = NEUCLS(self.data_path, split="train", transform=get_transform("train"), seed=self.seed) 58 | self.val_dataset = NEUCLS(self.data_path, split="val", transform=get_transform("val"), seed=self.seed) 59 | self.test_dataset = NEUCLS(self.data_path, split="test", transform=get_transform("test"), seed=self.seed) 60 | 61 | self.model = Classification(self.backbone, self.num_classes) 62 | if self.freeze_backbone: 63 | for param in self.model.mfn.backbone.parameters(): 64 | param.requires_grad = False 65 | 66 | def forward(self, x): # ignore 67 | return self.model(x) 68 | 69 | def training_step(self, batch, batch_idx): 70 | images, target = batch 71 | output = self(images) 72 | loss_val = F.cross_entropy(output, target) 73 | metrics_dict = metrics("train", output, target) 74 | tqdm_dict = {"train_loss": loss_val, **metrics_dict} 75 | output = OrderedDict({"loss": loss_val, "progress_bar": tqdm_dict, "log": tqdm_dict}) 76 | return output 77 | 78 | def validation_step(self, batch, batch_idx): 79 | images, target = batch 80 | output = self(images) 81 | loss_val = F.cross_entropy(output, target) 82 | metrics_dict = metrics("val", output, target) 83 | output = OrderedDict({"val_loss": loss_val, **metrics_dict}) 84 | return output 85 | 86 | def validation_epoch_end(self, outputs): 87 | log_dict = {} 88 | for metric_name in outputs[0]: 89 | log_dict[metric_name] = torch.stack([x[metric_name] for x in outputs]).mean() 90 | 91 | return {"log": log_dict, "progress_bar": log_dict, **log_dict} 92 | 93 | def test_step(self, batch, 
batch_idx): 94 | images, target = batch 95 | output = self(images) 96 | loss_val = F.cross_entropy(output, target) 97 | metrics_dict = metrics("test", output, target) 98 | output = OrderedDict({"test_loss": loss_val, **metrics_dict}) 99 | return output 100 | 101 | def test_epoch_end(self, outputs): 102 | log_dict = {} 103 | for metric_name in outputs[0]: 104 | log_dict[metric_name] = torch.stack([x[metric_name] for x in outputs]).mean() 105 | 106 | return {"log": log_dict, "progress_bar": log_dict, **log_dict} 107 | 108 | def configure_optimizers(self): 109 | optimizer = optim.SGD( 110 | self.parameters(), lr=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay 111 | ) 112 | return optimizer 113 | 114 | def train_dataloader(self): 115 | train_loader = DataLoader( 116 | dataset=self.train_dataset, 117 | batch_size=self.batch_size, 118 | shuffle=True, 119 | num_workers=cpu_count(), 120 | ) 121 | return train_loader 122 | 123 | def val_dataloader(self): 124 | val_loader = DataLoader( 125 | self.val_dataset, 126 | batch_size=self.batch_size, 127 | shuffle=False, 128 | num_workers=cpu_count() // 2, 129 | ) 130 | return val_loader 131 | 132 | def test_dataloader(self): 133 | test_loader = DataLoader( 134 | self.test_dataset, 135 | batch_size=self.batch_size, 136 | shuffle=False, 137 | num_workers=cpu_count(), 138 | ) 139 | return test_loader 140 | 141 | @staticmethod 142 | def add_model_specific_args(parent_parser): # pragma: no-cover 143 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 144 | aa = parser.add_argument 145 | aa( 146 | "--data-path", 147 | metavar="DIR", 148 | type=str, 149 | default=os.getenv("SM_CHANNEL_TRAINING", ""), 150 | ) 151 | aa( 152 | "--backbone", 153 | default="resnet34", 154 | ) 155 | aa( 156 | "--freeze-backbone", 157 | action="store_true", 158 | ) 159 | aa( 160 | "--num-classes", 161 | default=6, 162 | type=int, 163 | metavar="N", 164 | ) 165 | aa( 166 | "-b", 167 | "--batch-size", 168 | default=64, 169 | type=int, 170 | metavar="N", 171 | ) 172 | aa( 173 | "--lr", 174 | "--learning-rate", 175 | default=1e-3, 176 | type=float, 177 | metavar="LR", 178 | dest="learning_rate", 179 | ) 180 | aa("--momentum", default=0.9, type=float, metavar="M", help="momentum") 181 | aa( 182 | "--wd", 183 | "--weight-decay", 184 | default=1e-4, 185 | type=float, 186 | metavar="W", 187 | dest="weight_decay", 188 | ) 189 | aa( 190 | "--seed", 191 | type=int, 192 | default=42, 193 | ) 194 | return parser 195 | 196 | 197 | def get_args() -> Namespace: 198 | parent_parser = ArgumentParser(add_help=False) 199 | aa = parent_parser.add_argument 200 | aa("--epochs", type=int, default=100, help="number of training epochs") 201 | aa("--save-path", metavar="DIR", default=os.getenv("SM_MODEL_DIR", ""), type=str, help="path to save output") 202 | aa("--gpus", type=int, default=os.getenv("SM_NUM_GPUS", 1), help="how many gpus") 203 | aa( 204 | "--distributed-backend", 205 | type=str, 206 | default="", 207 | choices=("dp", "ddp", "ddp2"), 208 | help="supports three options dp, ddp, ddp2", 209 | ) 210 | aa("--use-16bit", dest="use_16bit", action="store_true", help="if true uses 16 bit precision") 211 | 212 | parser = DDNClassification.add_model_specific_args(parent_parser) 213 | return parser.parse_args() 214 | 215 | 216 | def model_fn(model_dir): 217 | # TODO: `model_fn` doesn't get more args 218 | # see: https://github.com/aws/sagemaker-inference-toolkit/issues/65 219 | backbone = "resnet34" 220 | num_classes = 6 221 | 222 | model = 
load_checkpoint(Classification(backbone, num_classes), model_dir, prefix="model") 223 | model = model.eval() 224 | freeze(model) 225 | return model 226 | 227 | 228 | def main(args: Namespace) -> None: 229 | model = DDNClassification(**vars(args)) 230 | 231 | if args.seed is not None: 232 | pl.seed_everything(args.seed) 233 | if torch.cuda.device_count() > 1: 234 | torch.cuda.manual_seed_all(args.seed) 235 | 236 | # TODO: add deterministic training 237 | # torch.backends.cudnn.deterministic = True 238 | 239 | checkpoint_callback = ModelCheckpoint( 240 | filepath=os.path.join(args.save_path, "{epoch}-{val_loss:.3f}-{val_acc:.3f}"), 241 | save_top_k=1, 242 | verbose=True, 243 | monitor="val_acc", 244 | mode="max", 245 | ) 246 | early_stop_callback = EarlyStopping("val_loss", patience=10) 247 | trainer = pl.Trainer( 248 | default_root_dir=args.save_path, 249 | gpus=args.gpus, 250 | max_epochs=args.epochs, 251 | early_stop_callback=early_stop_callback, 252 | checkpoint_callback=checkpoint_callback, 253 | gradient_clip_val=10, 254 | num_sanity_val_steps=0, 255 | distributed_backend=args.distributed_backend or None, 256 | # precision=16 if args.use_16bit else 32, # TODO: amp apex support 257 | ) 258 | 259 | trainer.fit(model) 260 | trainer.test() 261 | return 262 | 263 | 264 | if __name__ == "__main__": 265 | main(get_args()) 266 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/src/sagemaker_defect_detection/dataset/__init__.py -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/dataset/neu.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Optional, Callable 2 | import os 3 | from pathlib import Path 4 | from collections import namedtuple 5 | 6 | from xml.etree.ElementTree import ElementTree 7 | 8 | import numpy as np 9 | 10 | import cv2 11 | 12 | import torch 13 | from torch.utils.data.dataset import Dataset 14 | 15 | from torchvision.datasets import ImageFolder 16 | 17 | 18 | class NEUCLS(ImageFolder): 19 | """ 20 | NEU-CLS dataset processing and loading 21 | """ 22 | 23 | def __init__( 24 | self, 25 | root: str, 26 | split: str, 27 | augmentation: Optional[Callable] = None, 28 | preprocessing: Optional[Callable] = None, 29 | seed: int = 123, 30 | **kwargs, 31 | ) -> None: 32 | """ 33 | NEU-CLS dataset 34 | 35 | Parameters 36 | ---------- 37 | root : str 38 | Dataset root path 39 | split : str 40 | Data split from train, val and test 41 | augmentation : Optional[Callable], optional 42 | Image augmentation function, by default None 43 | preprocess : Optional[Callable], optional 44 | Image preprocessing function, by default None 45 | seed : int, optional 46 | Random number generator seed, by default 123 47 | 48 | Raises 49 | ------ 50 | ValueError 51 | If unsupported split is used 52 | """ 53 | super().__init__(root, **kwargs) 54 | self.samples: List[Tuple[str, int]] 55 | self.split = split 56 | self.augmentation = augmentation 57 | self.preprocessing = preprocessing 58 | n_items = len(self.samples) 59 | np.random.seed(seed) 60 | perm = np.random.permutation(list(range(n_items))) 61 | # TODO: add split ratios as parameters 62 | train_end = int(0.6 * n_items) 63 | val_end = int(0.2 * n_items) + 
train_end 64 | if split == "train": 65 | self.samples = [self.samples[i] for i in perm[:train_end]] 66 | elif split == "val": 67 | self.samples = [self.samples[i] for i in perm[train_end:val_end]] 68 | elif split == "test": 69 | self.samples = [self.samples[i] for i in perm[val_end:]] 70 | else: 71 | raise ValueError(f"Unknown split mode. Choose from `train`, `val` or `test`. Given {split}") 72 | 73 | 74 | DetectionSample = namedtuple("DetectionSample", ["image_path", "class_idx", "annotations"]) 75 | 76 | 77 | class NEUDET(Dataset): 78 | """ 79 | NEU-DET dataset processing and loading 80 | """ 81 | 82 | def __init__( 83 | self, 84 | root: str, 85 | split: str, 86 | augmentation: Optional[Callable] = None, 87 | preprocess: Optional[Callable] = None, 88 | seed: int = 123, 89 | ): 90 | """ 91 | NEU-DET dataset 92 | 93 | Parameters 94 | ---------- 95 | root : str 96 | Dataset root path 97 | split : str 98 | Data split from train, val and test 99 | augmentation : Optional[Callable], optional 100 | Image augmentation function, by default None 101 | preprocess : Optional[Callable], optional 102 | Image preprocessing function, by default None 103 | seed : int, optional 104 | Random number generator seed, by default 123 105 | 106 | Raises 107 | ------ 108 | ValueError 109 | If unsupported split is used 110 | """ 111 | super().__init__() 112 | self.root = Path(root) 113 | self.split = split 114 | self.classes, self.class_to_idx = self._find_classes() 115 | self.samples: List[DetectionSample] = self._make_dataset() 116 | self.augmentation = augmentation 117 | self.preprocess = preprocess 118 | n_items = len(self.samples) 119 | np.random.seed(seed) 120 | perm = np.random.permutation(list(range(n_items))) 121 | train_end = int(0.6 * n_items) 122 | val_end = int(0.2 * n_items) + train_end 123 | if split == "train": 124 | self.samples = [self.samples[i] for i in perm[:train_end]] 125 | elif split == "val": 126 | self.samples = [self.samples[i] for i in perm[train_end:val_end]] 127 | elif split == "test": 128 | self.samples = [self.samples[i] for i in perm[val_end:]] 129 | else: 130 | raise ValueError(f"Unknown split mode. Choose from `train`, `val` or `test`. Given {split}") 131 | 132 | def _make_dataset(self) -> List[DetectionSample]: 133 | instances = [] 134 | base_dir = self.root.expanduser() 135 | for target_cls in sorted(self.class_to_idx.keys()): 136 | cls_idx = self.class_to_idx[target_cls] 137 | target_dir = base_dir / target_cls 138 | if not target_dir.is_dir(): 139 | continue 140 | 141 | images = sorted(list((target_dir / "images").glob("*.jpg"))) 142 | annotations = sorted(list((target_dir / "annotations").glob("*.xml"))) 143 | assert len(images) == len(annotations), f"something is wrong. Mismatched number of images and annotations" 144 | for path, ann in zip(images, annotations): 145 | instances.append(DetectionSample(str(path), int(cls_idx), str(ann))) 146 | 147 | return instances 148 | 149 | def _find_classes(self): 150 | classes = sorted([d.name for d in os.scandir(str(self.root)) if d.is_dir()]) 151 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes, 1)} # no bg label in NEU 152 | return classes, class_to_idx 153 | 154 | @staticmethod 155 | def _get_bboxes(ann: str) -> List[List[int]]: 156 | tree = ElementTree().parse(ann) 157 | bboxes = [] 158 | for bndbox in tree.iterfind("object/bndbox"): 159 | # should subtract 1 like coco? 
160 | bbox = [int(bndbox.findtext(t)) - 1 for t in ("xmin", "ymin", "xmax", "ymax")] # type: ignore 161 | assert bbox[2] > bbox[0] and bbox[3] > bbox[1], f"box size error, given {bbox}" 162 | bboxes.append(bbox) 163 | 164 | return bboxes 165 | 166 | def __len__(self): 167 | return len(self.samples) 168 | 169 | def __getitem__(self, idx: int): 170 | # Note: images are grayscaled BUT resnet needs 3 channels 171 | image = cv2.imread(self.samples[idx].image_path) # BGR channel last 172 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 173 | boxes = self._get_bboxes(self.samples[idx].annotations) 174 | num_objs = len(boxes) 175 | boxes = torch.as_tensor(boxes, dtype=torch.float32) 176 | labels = torch.tensor([self.samples[idx].class_idx] * num_objs, dtype=torch.int64) 177 | image_id = torch.tensor([idx], dtype=torch.int64) 178 | iscrowd = torch.zeros((len(boxes),), dtype=torch.int64) 179 | 180 | target = {} 181 | target["boxes"] = boxes 182 | target["labels"] = labels 183 | target["image_id"] = image_id 184 | target["iscrowd"] = iscrowd 185 | 186 | if self.augmentation is not None: 187 | sample = self.augmentation(**{"image": image, "bboxes": boxes, "labels": labels}) 188 | image = sample["image"] 189 | target["boxes"] = torch.as_tensor(sample["bboxes"], dtype=torch.float32) 190 | # guards against crops that don't pass the min_visibility augmentation threshold 191 | if not target["boxes"].numel(): 192 | return None 193 | 194 | target["labels"] = torch.as_tensor(sample["labels"], dtype=torch.int64) 195 | 196 | if self.preprocess is not None: 197 | image = self.preprocess(image=image)["image"] 198 | 199 | boxes = target["boxes"] 200 | target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 201 | return image, target, image_id 202 | 203 | def collate_fn(self, batch): 204 | batch = filter(lambda x: x is not None, batch) 205 | return tuple(zip(*batch)) 206 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awslabs/sagemaker-defect-detection/1fe21ddd3b9be5f227728344cd33c6613c3ccccb/src/sagemaker_defect_detection/models/__init__.py -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/models/ddn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision 6 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 7 | from torchvision.models.detection.transform import GeneralizedRCNNTransform 8 | from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN 9 | from torchvision.models.detection.roi_heads import RoIHeads 10 | from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork 11 | from torchvision.ops import MultiScaleRoIAlign 12 | 13 | 14 | def get_backbone(name: str) -> nn.Module: 15 | """ 16 | Get official pretrained ResNet34 and ResNet50 as backbones 17 | 18 | Parameters 19 | ---------- 20 | name : str 21 | Either `resnet34` or `resnet50` 22 | 23 | Returns 24 | ------- 25 | nn.Module 26 | resnet34 or resnet50 pytorch modules 27 | 28 | Raises 29 | ------ 30 | ValueError 31 | If unsupported name is used 32 | """ 33 | if name == "resnet34": 34 | return torchvision.models.resnet34(pretrained=True) 35 | elif name == "resnet50": 36 | 
return torchvision.models.resnet50(pretrained=True) 37 | else: 38 | raise ValueError("Unsupported backbone") 39 | 40 | 41 | def init_weights(m) -> None: 42 | """ 43 | Weight initialization 44 | 45 | Parameters 46 | ---------- 47 | m : [type] 48 | Module used in recursive call 49 | """ 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.xavier_normal_(m.weight) 52 | 53 | elif isinstance(m, nn.Linear): 54 | m.weight.data.normal_(0.0, 0.02) 55 | m.bias.data.fill_(0.0) 56 | 57 | return 58 | 59 | 60 | class MFN(nn.Module): 61 | def __init__(self, backbone: str): 62 | """ 63 | Implementation of MFN model as described in 64 | 65 | Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, 66 | “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” 67 | IEEE Transactions on Instrumentation and Measuremente, 2020,69(4),1493-1504. 68 | 69 | Parameters 70 | ---------- 71 | backbone : str 72 | Either `resnet34` or `resnet50` 73 | """ 74 | super().__init__() 75 | self.backbone = get_backbone(backbone) 76 | # input 224x224 -> conv1 output size 112x112 77 | self.start_layer = nn.Sequential( 78 | self.backbone.conv1, # type: ignore 79 | self.backbone.bn1, # type: ignore 80 | self.backbone.relu, # type: ignore 81 | self.backbone.maxpool, # type: ignore 82 | ) 83 | self.r2 = self.backbone.layer1 # 64/256x56x56 <- (resnet34/resnet50) 84 | self.r3 = self.backbone.layer2 # 128/512x28x28 85 | self.r4 = self.backbone.layer3 # 256/1024x14x14 86 | self.r5 = self.backbone.layer4 # 512/2048x7x7 87 | in_channel = 64 if backbone == "resnet34" else 256 88 | self.b2 = nn.Sequential( 89 | nn.Conv2d( 90 | in_channel, in_channel, kernel_size=3, padding=1, stride=2 91 | ), # 56 -> 28 without Relu or batchnorm not in the paper ??? 92 | nn.BatchNorm2d(in_channel), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1, stride=2), # 28 -> 14 95 | nn.BatchNorm2d(in_channel), 96 | nn.ReLU(inplace=True), 97 | nn.Conv2d(in_channel, in_channel * 2, kernel_size=1, padding=0), 98 | nn.BatchNorm2d(in_channel * 2), 99 | nn.ReLU(inplace=True), 100 | ).apply( 101 | init_weights 102 | ) # after r2: 128/512x14x14 <- 103 | self.b3 = nn.MaxPool2d(2) # after r3: 128/512x14x14 <- 104 | in_channel *= 2 # 128/512 105 | self.b4 = nn.Sequential( 106 | nn.Conv2d(in_channel * 2, in_channel, kernel_size=1, padding=0), 107 | nn.BatchNorm2d(in_channel), 108 | nn.ReLU(inplace=True), 109 | ).apply( 110 | init_weights 111 | ) # after r4: 128/512x14x14 112 | in_channel *= 4 # 512 / 2048 113 | self.b5 = nn.Sequential( 114 | nn.ConvTranspose2d( 115 | in_channel, in_channel, kernel_size=3, stride=2, padding=1, output_padding=1 116 | ), # <- after r5 which is 512x7x7 -> 512x14x14 117 | nn.BatchNorm2d(in_channel), 118 | nn.ReLU(inplace=True), 119 | nn.Conv2d(in_channel, in_channel // 4, kernel_size=1, padding=0), 120 | nn.BatchNorm2d(in_channel // 4), 121 | nn.ReLU(inplace=True), 122 | ).apply(init_weights) 123 | 124 | self.out_channels = 512 if backbone == "resnet34" else 2048 # required for FasterRCNN 125 | 126 | def forward(self, x): 127 | x = self.start_layer(x) 128 | x = self.r2(x) 129 | b2_out = self.b2(x) 130 | x = self.r3(x) 131 | b3_out = self.b3(x) 132 | x = self.r4(x) 133 | b4_out = self.b4(x) 134 | x = self.r5(x) 135 | b5_out = self.b5(x) 136 | # BatchNorm works better than L2 normalize 137 | # out = torch.cat([F.normalize(o, p=2, dim=1) for o in (b2_out, b3_out, b4_out, b5_out)], dim=1) 138 | out = torch.cat((b2_out, b3_out, b4_out, b5_out), dim=1) 139 | return out 140 | 141 | 142 | 
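
For a 224x224 input, each of the four MFN branches above ends at a 14x14 feature map, so the channel-wise concatenation in MFN.forward carries self.out_channels channels (512 for resnet34, 2048 for resnet50); this is exactly the shape the Classification head and the detection modules below rely on. A quick shape check, purely illustrative (it downloads the pretrained torchvision ResNet-34 weights on first run):

import torch

from sagemaker_defect_detection.models.ddn import MFN

mfn = MFN("resnet34").eval()
with torch.no_grad():
    fused = mfn(torch.randn(1, 3, 224, 224))
# expected: torch.Size([1, 512, 14, 14]), i.e. (1, mfn.out_channels, 14, 14)
print(fused.shape, mfn.out_channels)
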
class Classification(nn.Module): 143 | """ 144 | Classification network 145 | 146 | Parameters 147 | ---------- 148 | backbone : str 149 | Either `resnet34` or `resnet50` 150 | 151 | num_classes : int 152 | Number of classes 153 | """ 154 | 155 | def __init__(self, backbone: str, num_classes: int) -> None: 156 | super().__init__() 157 | self.mfn = MFN(backbone) 158 | self.flatten = nn.Flatten() 159 | self.fc = nn.Linear(self.mfn.out_channels * 14 ** 2, num_classes) 160 | 161 | def forward(self, x): 162 | return self.fc(self.flatten(self.mfn(x))) 163 | 164 | 165 | class RPN(nn.Module): 166 | """ 167 | RPN Module as described in 168 | 169 | Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, 170 | “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” 171 | IEEE Transactions on Instrumentation and Measuremente, 2020,69(4),1493-1504. 172 | """ 173 | 174 | def __init__( 175 | self, 176 | out_channels: int = 512, 177 | rpn_pre_nms_top_n_train: int = 1000, # torchvision default 2000, 178 | rpn_pre_nms_top_n_test: int = 500, # torchvision default 1000, 179 | rpn_post_nms_top_n_train: int = 1000, # torchvision default 2000, 180 | rpn_post_nms_top_n_test: int = 500, # torchvision default 1000, 181 | rpn_nms_thresh: float = 0.7, 182 | rpn_fg_iou_thresh: float = 0.7, 183 | rpn_bg_iou_thresh: float = 0.3, 184 | rpn_batch_size_per_image: int = 256, 185 | rpn_positive_fraction: float = 0.5, 186 | ) -> None: 187 | super().__init__() 188 | rpn_anchor_generator = AnchorGenerator(sizes=((64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)) 189 | rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0]) 190 | rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) 191 | rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) 192 | self.rpn = RegionProposalNetwork( 193 | rpn_anchor_generator, 194 | rpn_head, 195 | rpn_fg_iou_thresh, 196 | rpn_bg_iou_thresh, 197 | rpn_batch_size_per_image, 198 | rpn_positive_fraction, 199 | rpn_pre_nms_top_n, 200 | rpn_post_nms_top_n, 201 | rpn_nms_thresh, 202 | ) 203 | 204 | def forward(self, *args, **kwargs): 205 | return self.rpn(*args, **kwargs) 206 | 207 | 208 | class CustomTwoMLPHead(nn.Module): 209 | def __init__(self, in_channels: int, representation_size: int): 210 | super().__init__() 211 | self.avgpool = nn.AdaptiveAvgPool2d(7) 212 | self.mlp = nn.Sequential( 213 | nn.Linear(in_channels, representation_size), 214 | nn.ReLU(inplace=True), 215 | nn.Linear(representation_size, representation_size), 216 | nn.ReLU(inplace=True), 217 | ) 218 | 219 | def forward(self, x): 220 | x = self.avgpool(x) 221 | x = x.flatten(start_dim=1) 222 | x = self.mlp(x) 223 | return x 224 | 225 | 226 | class RoI(nn.Module): 227 | """ 228 | ROI Module as described in 229 | 230 | Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, 231 | “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” 232 | IEEE Transactions on Instrumentation and Measuremente, 2020,69(4),1493-1504. 
233 | """ 234 | 235 | def __init__( 236 | self, 237 | num_classes: int, 238 | box_fg_iou_thresh=0.5, 239 | box_bg_iou_thresh=0.5, 240 | box_batch_size_per_image=512, 241 | box_positive_fraction=0.25, 242 | bbox_reg_weights=None, 243 | box_score_thresh=0.05, 244 | box_nms_thresh=0.5, 245 | box_detections_per_img=100, 246 | ) -> None: 247 | super().__init__() 248 | roi_pooler = MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2) 249 | box_head = CustomTwoMLPHead(512 * 7 ** 2, 1024) 250 | box_predictor = FastRCNNPredictor(1024, num_classes=num_classes) 251 | self.roi_head = RoIHeads( 252 | roi_pooler, 253 | box_head, 254 | box_predictor, 255 | box_fg_iou_thresh, 256 | box_bg_iou_thresh, 257 | box_batch_size_per_image, 258 | box_positive_fraction, 259 | bbox_reg_weights, 260 | box_score_thresh, 261 | box_nms_thresh, 262 | box_detections_per_img, 263 | ) 264 | 265 | def forward(self, *args, **kwargs): 266 | return self.roi_head(*args, **kwargs) 267 | 268 | 269 | class Detection(GeneralizedRCNN): 270 | """ 271 | Detection network as described in 272 | 273 | Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, 274 | “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” 275 | IEEE Transactions on Instrumentation and Measuremente, 2020,69(4),1493-1504. 276 | """ 277 | 278 | def __init__(self, mfn, rpn, roi): 279 | dummy_transform = GeneralizedRCNNTransform(800, 1333, [00.0, 0.0, 0.0], [1.0, 1.0, 1.0]) 280 | super().__init__(mfn, rpn, roi, dummy_transform) 281 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torchvision.transforms as transforms 3 | 4 | import albumentations as albu 5 | import albumentations.pytorch.transforms as albu_transforms 6 | 7 | 8 | PROBABILITY = 0.5 9 | ROTATION_ANGLE = 90 10 | NUM_CHANNELS = 3 11 | # required for resnet 12 | IMAGE_RESIZE_HEIGHT = 256 13 | IMAGE_RESIZE_WIDTH = 256 14 | IMAGE_HEIGHT = 224 15 | IMAGE_WIDTH = 224 16 | # standard imagenet1k mean and standard deviation of RGB channels 17 | MEAN_RED = 0.485 18 | MEAN_GREEN = 0.456 19 | MEAN_BLUE = 0.406 20 | STD_RED = 0.229 21 | STD_GREEN = 0.224 22 | STD_BLUE = 0.225 23 | 24 | 25 | def get_transform(split: str) -> Callable: 26 | """ 27 | Image data transformations such as normalization for train split for classification task 28 | 29 | Parameters 30 | ---------- 31 | split : str 32 | train or else 33 | 34 | Returns 35 | ------- 36 | Callable 37 | Image transformation function 38 | """ 39 | normalize = transforms.Normalize(mean=[MEAN_RED, MEAN_GREEN, MEAN_BLUE], std=[STD_RED, STD_GREEN, STD_BLUE]) 40 | if split == "train": 41 | return transforms.Compose( 42 | [ 43 | transforms.RandomResizedCrop(IMAGE_HEIGHT), 44 | transforms.RandomRotation(ROTATION_ANGLE), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | normalize, 48 | ] 49 | ) 50 | 51 | else: 52 | return transforms.Compose( 53 | [ 54 | transforms.Resize(IMAGE_RESIZE_HEIGHT), 55 | transforms.CenterCrop(IMAGE_HEIGHT), 56 | transforms.ToTensor(), 57 | normalize, 58 | ] 59 | ) 60 | 61 | 62 | def get_augmentation(split: str) -> Callable: 63 | """ 64 | Obtains proper image augmentation in train split for detection task. 
65 | We have splitted transformations done for detection task into augmentation and preprocessing 66 | for clarity 67 | 68 | Parameters 69 | ---------- 70 | split : str 71 | train or else 72 | 73 | Returns 74 | ------- 75 | Callable 76 | Image augmentation function 77 | """ 78 | if split == "train": 79 | return albu.Compose( 80 | [ 81 | albu.Resize(IMAGE_RESIZE_HEIGHT, IMAGE_RESIZE_WIDTH, always_apply=True), 82 | albu.RandomCrop(IMAGE_HEIGHT, IMAGE_WIDTH, always_apply=True), 83 | albu.RandomRotate90(p=PROBABILITY), 84 | albu.HorizontalFlip(p=PROBABILITY), 85 | albu.RandomBrightness(p=PROBABILITY), 86 | ], 87 | bbox_params=albu.BboxParams( 88 | format="pascal_voc", 89 | label_fields=["labels"], 90 | min_visibility=0.2, 91 | ), 92 | ) 93 | else: 94 | return albu.Compose( 95 | [albu.Resize(IMAGE_HEIGHT, IMAGE_WIDTH)], 96 | bbox_params=albu.BboxParams(format="pascal_voc", label_fields=["labels"]), 97 | ) 98 | 99 | 100 | def get_preprocess() -> Callable: 101 | """ 102 | Image normalization using albumentation for detection task that aligns well with image augmentation 103 | 104 | Returns 105 | ------- 106 | Callable 107 | Image normalization function 108 | """ 109 | return albu.Compose( 110 | [ 111 | albu.Normalize(mean=[MEAN_RED, MEAN_GREEN, MEAN_BLUE], std=[STD_RED, STD_GREEN, STD_BLUE]), 112 | albu_transforms.ToTensorV2(), 113 | ] 114 | ) 115 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from pathlib import Path 3 | import tarfile 4 | import logging 5 | from logging.config import fileConfig 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | logger.setLevel(logging.INFO) 13 | 14 | 15 | def get_logger(config_path: str) -> logging.Logger: 16 | fileConfig(config_path, disable_existing_loggers=False) 17 | logger = logging.getLogger() 18 | return logger 19 | 20 | 21 | def str2bool(flag: Union[str, bool]) -> bool: 22 | if not isinstance(flag, bool): 23 | if flag.lower() == "false": 24 | flag = False 25 | elif flag.lower() == "true": 26 | flag = True 27 | else: 28 | raise ValueError("Wrong boolean argument!") 29 | return flag 30 | 31 | 32 | def freeze(m: nn.Module) -> None: 33 | assert isinstance(m, nn.Module), "freeze only is applied to modules" 34 | for param in m.parameters(): 35 | param.requires_grad = False 36 | 37 | return 38 | 39 | 40 | def load_checkpoint(model: nn.Module, path: str, prefix: Optional[str]) -> nn.Module: 41 | path = Path(path) 42 | logger.info(f"path: {path}") 43 | if path.is_dir(): 44 | path_str = str(list(path.rglob("*.ckpt"))[0]) 45 | else: 46 | path_str = str(path) 47 | 48 | device = "cuda" if torch.cuda.is_available() else "cpu" 49 | state_dict = torch.load(path_str, map_location=torch.device(device))["state_dict"] 50 | if prefix is not None: 51 | if prefix[-1] != ".": 52 | prefix += "." 53 | 54 | state_dict = {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)} 55 | 56 | model.load_state_dict(state_dict, strict=True) 57 | return model 58 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/utils/coco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) Soumith Chintala 2016, 5 | All rights reserved. 
6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | 33 | This module is a modified version of https://github.com/pytorch/vision/tree/03b1d38ba3c67703e648fb067570eeb1a1e61265/references/detection 34 | """ 35 | 36 | import json 37 | 38 | import numpy as np 39 | import copy 40 | import torch 41 | import pickle 42 | 43 | import torch.distributed as dist 44 | 45 | from pycocotools.cocoeval import COCOeval 46 | from pycocotools.coco import COCO 47 | import pycocotools.mask as mask_util 48 | 49 | from collections import defaultdict 50 | 51 | 52 | def is_dist_avail_and_initialized(): 53 | if not dist.is_available(): 54 | return False 55 | if not dist.is_initialized(): 56 | return False 57 | return True 58 | 59 | 60 | def get_world_size(): 61 | if not is_dist_avail_and_initialized(): 62 | return 1 63 | return dist.get_world_size() 64 | 65 | 66 | def all_gather(data): 67 | """ 68 | Run all_gather on arbitrary picklable data (not necessarily tensors) 69 | Args: 70 | data: any picklable object 71 | Returns: 72 | list[data]: list of data gathered from each rank 73 | """ 74 | world_size = get_world_size() 75 | if world_size == 1: 76 | return [data] 77 | 78 | # serialized to a Tensor 79 | buffer = pickle.dumps(data) 80 | storage = torch.ByteStorage.from_buffer(buffer) 81 | tensor = torch.ByteTensor(storage).to("cuda") 82 | 83 | # obtain Tensor size of each rank 84 | local_size = torch.tensor([tensor.numel()], device="cuda") 85 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 86 | dist.all_gather(size_list, local_size) 87 | size_list = [int(size.item()) for size in size_list] 88 | max_size = max(size_list) 89 | 90 | # receiving Tensor from all ranks 91 | # we pad the tensor because torch all_gather does not support 92 | # gathering tensors of different shapes 93 | tensor_list = [] 94 | for _ in size_list: 95 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 96 | if local_size != max_size: 97 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, 
device="cuda") 98 | tensor = torch.cat((tensor, padding), dim=0) 99 | dist.all_gather(tensor_list, tensor) 100 | 101 | data_list = [] 102 | for size, tensor in zip(size_list, tensor_list): 103 | buffer = tensor.cpu().numpy().tobytes()[:size] 104 | data_list.append(pickle.loads(buffer)) 105 | 106 | return data_list 107 | 108 | 109 | class CocoEvaluator(object): 110 | def __init__(self, coco_gt, iou_types): 111 | assert isinstance(iou_types, (list, tuple)) 112 | coco_gt = copy.deepcopy(coco_gt) 113 | self.coco_gt = coco_gt 114 | 115 | self.iou_types = iou_types 116 | self.coco_eval = {} 117 | for iou_type in iou_types: 118 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 119 | 120 | self.img_ids = [] 121 | self.eval_imgs = {k: [] for k in iou_types} 122 | 123 | def update(self, predictions): 124 | img_ids = list(np.unique(list(predictions.keys()))) 125 | self.img_ids.extend(img_ids) 126 | 127 | for iou_type in self.iou_types: 128 | results = self.prepare(predictions, iou_type) 129 | coco_dt = loadRes(self.coco_gt, results) if results else COCO() 130 | coco_eval = self.coco_eval[iou_type] 131 | 132 | coco_eval.cocoDt = coco_dt 133 | coco_eval.params.imgIds = list(img_ids) 134 | img_ids, eval_imgs = evaluate(coco_eval) 135 | if isinstance(self.eval_imgs[iou_type], np.ndarray): 136 | self.eval_imgs[iou_type] = self.eval_imgs[iou_type].tolist() 137 | 138 | self.eval_imgs[iou_type].append(eval_imgs) 139 | 140 | def synchronize_between_processes(self): 141 | for iou_type in self.iou_types: 142 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 143 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 144 | 145 | def accumulate(self): 146 | for coco_eval in self.coco_eval.values(): 147 | coco_eval.accumulate() 148 | 149 | def summarize(self): 150 | for iou_type, coco_eval in self.coco_eval.items(): 151 | print("IoU metric: {}".format(iou_type)) 152 | coco_eval.summarize() 153 | 154 | def prepare(self, predictions, iou_type): 155 | return self.prepare_for_coco_detection(predictions) 156 | 157 | def prepare_for_coco_detection(self, predictions): 158 | coco_results = [] 159 | for original_id, prediction in predictions.items(): 160 | if len(prediction) == 0: 161 | continue 162 | 163 | boxes = prediction["boxes"] 164 | boxes = convert_to_xywh(boxes).tolist() 165 | scores = prediction["scores"].tolist() 166 | labels = prediction["labels"].tolist() 167 | 168 | coco_results.extend( 169 | [ 170 | { 171 | "image_id": original_id, 172 | "category_id": labels[k], 173 | "bbox": box, 174 | "score": scores[k], 175 | } 176 | for k, box in enumerate(boxes) 177 | ] 178 | ) 179 | return coco_results 180 | 181 | 182 | def convert_to_xywh(boxes): 183 | xmin, ymin, xmax, ymax = boxes.unbind(1) 184 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 185 | 186 | 187 | def merge(img_ids, eval_imgs): 188 | all_img_ids = all_gather(img_ids) 189 | all_eval_imgs = all_gather(eval_imgs) 190 | 191 | merged_img_ids = [] 192 | for p in all_img_ids: 193 | merged_img_ids.extend(p) 194 | 195 | merged_eval_imgs = [] 196 | for p in all_eval_imgs: 197 | merged_eval_imgs.append(p) 198 | 199 | merged_img_ids = np.array(merged_img_ids) 200 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 201 | 202 | # keep only unique (and in sorted order) images 203 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 204 | merged_eval_imgs = merged_eval_imgs[..., idx] 205 | 206 | return merged_img_ids, merged_eval_imgs 207 | 208 | 209 | def 
create_common_coco_eval(coco_eval, img_ids, eval_imgs): 210 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 211 | img_ids = list(img_ids) 212 | eval_imgs = list(eval_imgs.flatten()) 213 | 214 | coco_eval.evalImgs = eval_imgs 215 | coco_eval.params.imgIds = img_ids 216 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 217 | 218 | 219 | ################################################################# 220 | # From pycocotools, just removed the prints and fixed 221 | # a Python3 bug about unicode not defined 222 | ################################################################# 223 | 224 | # Ideally, pycocotools wouldn't have hard-coded prints 225 | # so that we could avoid copy-pasting those two functions 226 | 227 | 228 | def createIndex(self): 229 | # create index 230 | # print('creating index...') 231 | anns, cats, imgs = {}, {}, {} 232 | imgToAnns, catToImgs = defaultdict(list), defaultdict(list) 233 | if "annotations" in self.dataset: 234 | for ann in self.dataset["annotations"]: 235 | imgToAnns[ann["image_id"]].append(ann) 236 | anns[ann["id"]] = ann 237 | 238 | if "images" in self.dataset: 239 | for img in self.dataset["images"]: 240 | imgs[img["id"]] = img 241 | 242 | if "categories" in self.dataset: 243 | for cat in self.dataset["categories"]: 244 | cats[cat["id"]] = cat 245 | 246 | if "annotations" in self.dataset and "categories" in self.dataset: 247 | for ann in self.dataset["annotations"]: 248 | catToImgs[ann["category_id"]].append(ann["image_id"]) 249 | 250 | # print('index created!') 251 | 252 | # create class members 253 | self.anns = anns 254 | self.imgToAnns = imgToAnns 255 | self.catToImgs = catToImgs 256 | self.imgs = imgs 257 | self.cats = cats 258 | 259 | 260 | maskUtils = mask_util 261 | 262 | 263 | def loadRes(self, resFile): 264 | """ 265 | Load result file and return a result api object. 
266 | :param resFile (str) : file name of result file 267 | :return: res (obj) : result api object 268 | """ 269 | res = COCO() 270 | res.dataset["images"] = [img for img in self.dataset["images"]] 271 | 272 | # print('Loading and preparing results...') 273 | # tic = time.time() 274 | if isinstance(resFile, torch._six.string_classes): 275 | anns = json.load(open(resFile)) 276 | elif type(resFile) == np.ndarray: 277 | anns = self.loadNumpyAnnotations(resFile) 278 | else: 279 | anns = resFile 280 | assert type(anns) == list, "results in not an array of objects" 281 | annsImgIds = [ann["image_id"] for ann in anns] 282 | assert set(annsImgIds) == ( 283 | set(annsImgIds) & set(self.getImgIds()) 284 | ), "Results do not correspond to current coco set" 285 | if "caption" in anns[0]: 286 | imgIds = set([img["id"] for img in res.dataset["images"]]) & set([ann["image_id"] for ann in anns]) 287 | res.dataset["images"] = [img for img in res.dataset["images"] if img["id"] in imgIds] 288 | for id, ann in enumerate(anns): 289 | ann["id"] = id + 1 290 | elif "bbox" in anns[0] and not anns[0]["bbox"] == []: 291 | res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) 292 | for id, ann in enumerate(anns): 293 | bb = ann["bbox"] 294 | x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] 295 | if "segmentation" not in ann: 296 | ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 297 | ann["area"] = bb[2] * bb[3] 298 | ann["id"] = id + 1 299 | ann["iscrowd"] = 0 300 | 301 | res.dataset["annotations"] = anns 302 | createIndex(res) 303 | return res 304 | 305 | 306 | def evaluate(self): 307 | """ 308 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 309 | :return: None 310 | """ 311 | # tic = time.time() 312 | # print('Running per image evaluation...') 313 | p = self.params 314 | # add backward compatibility if useSegm is specified in params 315 | if p.useSegm is not None: 316 | p.iouType = "segm" if p.useSegm == 1 else "bbox" 317 | print("useSegm (deprecated) is not None. 
Running {} evaluation".format(p.iouType)) 318 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 319 | p.imgIds = list(np.unique(p.imgIds)) 320 | if p.useCats: 321 | p.catIds = list(np.unique(p.catIds)) 322 | p.maxDets = sorted(p.maxDets) 323 | self.params = p 324 | 325 | self._prepare() 326 | # loop through images, area range, max detection number 327 | catIds = p.catIds if p.useCats else [-1] 328 | 329 | if p.iouType == "segm" or p.iouType == "bbox": 330 | computeIoU = self.computeIoU 331 | elif p.iouType == "keypoints": 332 | computeIoU = self.computeOks 333 | self.ious = {(imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds} 334 | 335 | evaluateImg = self.evaluateImg 336 | maxDet = p.maxDets[-1] 337 | evalImgs = [ 338 | evaluateImg(imgId, catId, areaRng, maxDet) for catId in catIds for areaRng in p.areaRng for imgId in p.imgIds 339 | ] 340 | # this is NOT in the pycocotools code, but could be done outside 341 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 342 | self._paramsEval = copy.deepcopy(self.params) 343 | # toc = time.time() 344 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 345 | return p.imgIds, evalImgs 346 | 347 | 348 | ################################################################# 349 | # end of straight copy from pycocotools, just removing the prints 350 | ################################################################# 351 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/utils/coco_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) Soumith Chintala 2016, 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | 33 | This module is a modified version of https://github.com/pytorch/vision/tree/03b1d38ba3c67703e648fb067570eeb1a1e61265/references/detection 34 | """ 35 | 36 | from pycocotools.coco import COCO 37 | 38 | 39 | def convert_to_coco_api(ds): 40 | coco_ds = COCO() 41 | # annotation IDs need to start at 1, not 0, see torchvision issue #1530 42 | ann_id = 1 43 | dataset = {"images": [], "categories": [], "annotations": []} 44 | categories = set() 45 | for img_idx in range(len(ds)): 46 | # find better way to get target 47 | # targets = ds.get_annotations(img_idx) 48 | img, targets, _ = ds[img_idx] 49 | image_id = targets["image_id"].item() 50 | img_dict = {} 51 | img_dict["id"] = image_id 52 | img_dict["height"] = img.shape[-2] 53 | img_dict["width"] = img.shape[-1] 54 | dataset["images"].append(img_dict) 55 | bboxes = targets["boxes"] 56 | bboxes[:, 2:] -= bboxes[:, :2] 57 | bboxes = bboxes.tolist() 58 | labels = targets["labels"].tolist() 59 | areas = targets["area"].tolist() 60 | iscrowd = targets["iscrowd"].tolist() 61 | num_objs = len(bboxes) 62 | for i in range(num_objs): 63 | ann = {} 64 | ann["image_id"] = image_id 65 | ann["bbox"] = bboxes[i] 66 | ann["category_id"] = labels[i] 67 | categories.add(labels[i]) 68 | ann["area"] = areas[i] 69 | ann["iscrowd"] = iscrowd[i] 70 | ann["id"] = ann_id 71 | dataset["annotations"].append(ann) 72 | ann_id += 1 73 | dataset["categories"] = [{"id": i} for i in sorted(categories)] 74 | coco_ds.dataset = dataset 75 | coco_ds.createIndex() 76 | return coco_ds 77 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/utils/visualize.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List, Union, Tuple 2 | 3 | import numpy as np 4 | 5 | import torch 6 | 7 | from matplotlib import pyplot as plt 8 | 9 | import cv2 10 | 11 | 12 | 13 | 14 | TEXT_COLOR = (255, 255, 255) # White 15 | CLASSES = { 16 | "crazing": "Cr", 17 | "inclusion": "In", 18 | "pitted_surface": "PS", 19 | "patches": "Pa", 20 | "rolled-in_scale": "RS", 21 | "scratches": "Sc", 22 | } 23 | CATEGORY_ID_TO_NAME = {i: name for i, name in enumerate(CLASSES.keys(), start=1)} 24 | 25 | 26 | def unnormalize_to_hwc( 27 | image: torch.Tensor, mean: List[float] = [0.485, 0.456, 0.406], std: List[float] = [0.229, 0.224, 0.225] 28 | ) -> np.ndarray: 29 | """ 30 | Unnormalizes a normalized image tensor: [0, 1] CHW -> [0, 255] HWC 31 | 32 | Parameters 33 | ---------- 34 | image : torch.Tensor 35 | Normalized image 36 | mean : List[float], optional 37 | RGB averages used in normalization, by default [0.485, 0.456, 0.406] from imagenet1k 38 | std : List[float], optional 39 | RGB standard deviations used in normalization, by default [0.229, 0.224, 0.225] from imagenet1k 40 | 41 | Returns 42 | ------- 43 | np.ndarray 44 | Unnormalized image as numpy array 45 | """ 46 | image = image.numpy().transpose(1, 2, 0) # HWC 47 | image = (image * std + mean).clip(0, 1) 48 | image = (image * 255).astype(np.uint8) 49 | return image 50 | 51 | 52 | def visualize_bbox(img: np.ndarray, bbox: np.ndarray, class_name: str, color, thickness: int = 2) -> np.ndarray: 53 | """ 54 | Uses cv2 to draw colored bounding boxes and class names in an image 55 | 56 | Parameters 57 | ---------- 58 | img : np.ndarray 59 | Image as numpy array in HWC format 60 | bbox : np.ndarray 61 | Bounding box coordinates as [x_min, y_min, x_max, y_max] 62 | class_name : str 63 | Class name 64 | color : tuple 65 | BGR tuple 66 | thickness : int, optional 67 | Bounding box
thickness, by default 2 68 | """ 69 | x_min, y_min, x_max, y_max = tuple(map(int, bbox)) 70 | 71 | cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness) 72 | 73 | ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1) 74 | cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1) 75 | cv2.putText( 76 | img, 77 | text=class_name, 78 | org=(x_min, y_min - int(0.3 * text_height)), 79 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, 80 | fontScale=0.35, 81 | color=TEXT_COLOR, 82 | lineType=cv2.LINE_AA, 83 | ) 84 | return img 85 | 86 | 87 | def visualize( 88 | image: np.ndarray, 89 | bboxes: Iterable[Union[torch.Tensor, np.ndarray]] = [], 90 | category_ids: Iterable[Union[torch.Tensor, np.ndarray]] = [], 91 | colors: Iterable[Tuple[int, int, int]] = [], 92 | titles: Iterable[str] = [], 93 | category_id_to_name=CATEGORY_ID_TO_NAME, 94 | dpi=150, 95 | ) -> None: 96 | """ 97 | Plots the image alongside copies annotated with each set of bounding boxes and category ids 98 | 99 | Parameters 100 | ---------- 101 | image : np.ndarray 102 | Image as numpy array 103 | bboxes : Iterable[Union[torch.Tensor, np.ndarray]], optional 104 | Bounding boxes, by default [] 105 | category_ids : Iterable[Union[torch.Tensor, np.ndarray]], optional 106 | Category ids, by default [] 107 | colors : Iterable[Tuple[int, int, int]], optional 108 | Colors for each bounding box, by default [] 109 | titles : Iterable[str], optional 110 | Titles for each image, by default [] 111 | category_id_to_name : Dict[int, str], optional 112 | Dictionary of category ids to names, by default CATEGORY_ID_TO_NAME 113 | dpi : int, optional 114 | DPI for clarity, by default 150 115 | """ 116 | bboxes, category_ids, colors, titles = list(map(list, [bboxes, category_ids, colors, titles])) # type: ignore 117 | n = len(bboxes) 118 | assert ( 119 | n == len(category_ids) == len(colors) == len(titles) - 1 120 | ), "number of bboxes, category ids, colors and titles (minus one) do not match" 121 | 122 | plt.figure(dpi=dpi) 123 | ncols = n + 1 124 | plt.subplot(1, ncols, 1) 125 | img = image.copy() 126 | plt.axis("off") 127 | plt.title(titles[0]) 128 | plt.imshow(image) 129 | if not len(bboxes): 130 | return 131 | 132 | titles = titles[1:] 133 | for i in range(2, ncols + 1): 134 | img = image.copy() 135 | plt.subplot(1, ncols, i) 136 | plt.axis("off") 137 | j = i - 2 138 | plt.title(titles[j]) 139 | for bbox, category_id in zip(bboxes[j], category_ids[j]): # type: ignore 140 | if isinstance(bbox, torch.Tensor): 141 | bbox = bbox.numpy() 142 | 143 | if isinstance(category_id, torch.Tensor): 144 | category_id = category_id.numpy() 145 | 146 | if isinstance(category_id, np.ndarray): 147 | category_id = category_id.item() 148 | 149 | class_name = category_id_to_name[category_id] 150 | img = visualize_bbox(img, bbox, class_name, color=colors[j]) 151 | 152 | plt.imshow(img) 153 | return 154 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import boto3 4 | import copy 5 | import matplotlib.patches as patches 6 | from matplotlib import pyplot as plt 7 | from PIL import Image, ImageColor 8 | 9 | 10 | def query_Type2(image_file_name, endpoint_name, num_predictions=4): 11 | 12 | with open(image_file_name, "rb") as file: 13 | input_img_rb = file.read() 14 | 15 | client = boto3.client("runtime.sagemaker") 16 | query_response =
client.invoke_endpoint( 17 | EndpointName=endpoint_name, 18 | ContentType="application/x-image", 19 | Body=input_img_rb, 20 | Accept=f'application/json;verbose;n_predictions={num_predictions}' 21 | ) 22 | # If we remove ';n_predictions={}' from Accept, we get all the predicted boxes. 23 | query_response = query_response['Body'].read() 24 | 25 | model_predictions = json.loads(query_response) 26 | normalized_boxes, classes, scores, labels = ( 27 | model_predictions["normalized_boxes"], 28 | model_predictions["classes"], 29 | model_predictions["scores"], 30 | model_predictions["labels"], 31 | ) 32 | # Substitute each class index with its class name 33 | class_names = [labels[int(idx)] for idx in classes] 34 | return normalized_boxes, class_names, scores 35 | 36 | 37 | # Copied from albumentations/augmentations/functional.py 38 | # Follow albumentations.Normalize, which is used in sagemaker_defect_detection/detector.py 39 | def normalize(img, mean, std, max_pixel_value=255.0): 40 | mean = np.array(mean, dtype=np.float32) 41 | mean *= max_pixel_value 42 | 43 | std = np.array(std, dtype=np.float32) 44 | std *= max_pixel_value 45 | 46 | denominator = np.reciprocal(std, dtype=np.float32) 47 | 48 | img = img.astype(np.float32) 49 | img -= mean 50 | img *= denominator 51 | return img 52 | 53 | 54 | def query_DDN(image_file_name, endpoint_name, num_predictions=4): 55 | 56 | with Image.open(image_file_name) as im: 57 | image_np = np.array(im) 58 | 59 | # Follow albumentations.Normalize, which is used in sagemaker_defect_detection/detector.py 60 | mean = (0.485, 0.456, 0.406) 61 | std = (0.229, 0.224, 0.225) 62 | max_pixel_value = 255.0 63 | 64 | image_np = normalize(image_np, mean, std, max_pixel_value) 65 | image_np = image_np.transpose(2, 0, 1) # HWC -> CHW 66 | image_np = np.expand_dims(image_np, 0) # add batch dimension -> NCHW 67 | 68 | client = boto3.client("runtime.sagemaker") 69 | query_response = client.invoke_endpoint( 70 | EndpointName=endpoint_name, 71 | ContentType="application/json", 72 | Body=json.dumps(image_np.tolist()) 73 | ) 74 | query_response = query_response["Body"].read() 75 | 76 | model_predictions = json.loads(query_response.decode())[0] 77 | unnormalized_boxes = model_predictions['boxes'][:num_predictions] 78 | class_names = model_predictions['labels'][:num_predictions] 79 | scores = model_predictions['scores'][:num_predictions] 80 | return unnormalized_boxes, class_names, scores 81 | 82 | 83 | 84 | def query_Type1(image_file_name, endpoint_name, num_predictions=4): 85 | 86 | with open(image_file_name, "rb") as file: 87 | input_img_rb = file.read() 88 | 89 | client = boto3.client(service_name="runtime.sagemaker") 90 | query_response = client.invoke_endpoint( 91 | EndpointName=endpoint_name, 92 | ContentType="application/x-image", 93 | Body=input_img_rb 94 | ) 95 | query_response = query_response["Body"].read() 96 | 97 | model_predictions = json.loads(query_response)['prediction'][:num_predictions] 98 | class_names = [int(pred[0])+1 for pred in model_predictions] # +1 because class indices start from 1 99 | scores = [pred[1] for pred in model_predictions] 100 | normalized_boxes = [pred[2:] for pred in model_predictions] 101 | return normalized_boxes, class_names, scores 102 | 103 | 104 | def plot_results(image, bboxes, categories, d): 105 | # d - dictionary of endpoint responses keyed by endpoint name 106 | 107 | colors = list(ImageColor.colormap.values()) 108 | with Image.open(image) as im: 109 | image_np = np.array(im) 110 | fig = plt.figure(figsize=(20, 14)) 111 | 112 | n = len(d) 113 | 114 | # Ground truth 115 | ax1 = fig.add_subplot(2, 3,
1) 116 | plt.axis('off') 117 | plt.title('Ground Truth') 118 | 119 | for bbox in bboxes: 120 | left, bot, right, top = bbox['bbox'] 121 | x, y, w, h = left, bot, right - left, top - bot 122 | 123 | color = colors[hash(bbox['category_id']) % len(colors)] 124 | rect = patches.Rectangle((x, y), w, h, linewidth=3, edgecolor=color, facecolor="none") 125 | ax1.add_patch(rect) 126 | ax1.text(x, y, "{}".format(categories[bbox['category_id']]), 127 | bbox=dict(facecolor="white", alpha=0.5)) 128 | 129 | ax1.imshow(image_np) 130 | 131 | # Predictions 132 | counter = 2 133 | for k, v in d.items(): 134 | axi = fig.add_subplot(2, 3, counter) 135 | counter += 1 136 | 137 | if "Type2-HPO" in k: 138 | k = "Type2-HPO" 139 | elif "Type2" in k: 140 | k = "Type2" 141 | elif "Type1-HPO" in k: 142 | k = "Type1-HPO" 143 | elif "Type1" in k: 144 | k = "Type1" 145 | elif k.endswith("finetuned-endpoint"): 146 | k = "DDN" 147 | 148 | plt.title(f'Prediction: {k}') 149 | plt.axis('off') 150 | 151 | for idx in range(len(v['normalized_boxes'])): 152 | left, bot, right, top = v['normalized_boxes'][idx] 153 | if k == 'DDN': 154 | x, w = left, right - left 155 | y, h = bot, top - bot 156 | else: 157 | x, w = [val * image_np.shape[1] for val in [left, right - left]] 158 | y, h = [val * image_np.shape[0] for val in [bot, top - bot]] 159 | color = colors[hash(v['classes_names'][idx]) % len(colors)] 160 | rect = patches.Rectangle((x, y), w, h, linewidth=3, edgecolor=color, facecolor="none") 161 | axi.add_patch(rect) 162 | axi.text(x, y, 163 | "{} {:.0f}%".format(categories[v['classes_names'][idx]], v['confidences'][idx] * 100), 164 | bbox=dict(facecolor="white", alpha=0.5), 165 | ) 166 | 167 | axi.imshow(image_np) 168 | 169 | plt.tight_layout() 170 | plt.savefig("results/" + image.split('/')[-1]) 171 | 172 | plt.show() 173 | -------------------------------------------------------------------------------- /src/xml2json.py: -------------------------------------------------------------------------------- 1 | # Use this script to convert the annotation XML files into a single annotations.json file consumed by the JumpStart object detection model 2 | # Reference: XML2JSON.py https://linuxtut.com/en/e391e5e6924945b8a852/ 3 | 4 | import random 5 | import xmltodict 6 | import copy 7 | import json 8 | import glob 9 | import os 10 | from collections import defaultdict 11 | 12 | 13 | categories = [ 14 | {"id": 1, "name": "crazing"}, 15 | {"id": 2, "name": "inclusion"}, 16 | {"id": 3, "name": "pitted_surface"}, 17 | {"id": 4, "name": "patches"}, 18 | {"id": 5, "name": "rolled-in_scale"}, 19 | {"id": 6, "name": "scratches"}, 20 | ] 21 | 22 | 23 | def XML2JSON(xmlFiles, test_ratio=None, rnd_seed=100): 24 | """ Convert all XML annotation files to a single annotations.json 25 | 26 | If test_ratio is not None, two files are written instead: 27 | annotations.json for train+val and test_annotations.json for test.
28 | """ 29 | 30 | images = list() 31 | annotations = list() 32 | image_id = 1 33 | annotation_id = 1 34 | for file in xmlFiles: 35 | annotation_path = file 36 | image = dict() 37 | with open(annotation_path) as fd: 38 | doc = xmltodict.parse(fd.read(), force_list=('object')) 39 | filename = str(doc['annotation']['filename']) 40 | image['file_name'] = filename if filename.endswith('.jpg') else filename + '.jpg' 41 | image['height'] = int(doc['annotation']['size']['height']) 42 | image['width'] = int(doc['annotation']['size']['width']) 43 | image['id'] = image_id 44 | # print("File Name: {} and image_id {}".format(file, image_id)) 45 | images.append(image) 46 | if 'object' in doc['annotation']: 47 | for obj in doc['annotation']['object']: 48 | for value in categories: 49 | annotation = dict() 50 | if str(obj['name']) == value["name"]: 51 | annotation["image_id"] = image_id 52 | xmin = int(obj["bndbox"]["xmin"]) 53 | ymin = int(obj["bndbox"]["ymin"]) 54 | xmax = int(obj["bndbox"]["xmax"]) 55 | ymax = int(obj["bndbox"]["ymax"]) 56 | annotation["bbox"] = [xmin, ymin, xmax, ymax] 57 | annotation["category_id"] = value["id"] 58 | annotation["id"] = annotation_id 59 | annotation_id += 1 60 | annotations.append(annotation) 61 | 62 | else: 63 | print("File: {} doesn't have any object".format(file)) 64 | 65 | image_id += 1 66 | 67 | if test_ratio is None: 68 | attrDict = dict() 69 | attrDict["images"] = images 70 | attrDict["annotations"] = annotations 71 | 72 | jsonString = json.dumps(attrDict) 73 | with open("annotations.json", "w") as f: 74 | f.write(jsonString) 75 | else: 76 | assert test_ratio < 1.0 77 | 78 | # Size of each class 79 | category_ids = defaultdict(list) 80 | for img in images: 81 | category = img['file_name'].split('_')[0] 82 | category_ids[category].append(img['id']) 83 | print('\ncategory\tnum of images') 84 | print('-' * 20) 85 | 86 | random.seed(rnd_seed) 87 | 88 | train_val_images = [] 89 | test_images = [] 90 | train_val_annotations = [] 91 | test_annotations = [] 92 | 93 | for category in category_ids.keys(): 94 | print(f"{category}:\t{len(category_ids[category])}") 95 | 96 | random.shuffle(category_ids[category]) 97 | N = len(category_ids[category]) 98 | ids = category_ids[category] 99 | 100 | sep = int(N * test_ratio) 101 | 102 | category_images = [img for img in images if img['id'] in ids[:sep]] 103 | test_images.extend(category_images) 104 | category_images = [img for img in images if img['id'] in ids[sep:]] 105 | train_val_images.extend(category_images) 106 | 107 | category_annotations = [ann for ann in annotations if ann['image_id'] in ids[:sep]] 108 | test_annotations.extend(category_annotations) 109 | category_annotations = [ann for ann in annotations if ann['image_id'] in ids[sep:]] 110 | train_val_annotations.extend(category_annotations) 111 | 112 | print('-' * 20) 113 | 114 | train_val_attrDict = dict() 115 | train_val_attrDict["images"] = train_val_images 116 | train_val_attrDict["annotations"] = train_val_annotations 117 | print(f"\ntrain_val:\t{len(train_val_images)}") 118 | 119 | train_val_jsonString = json.dumps(train_val_attrDict) 120 | with open("annotations.json", "w") as f: 121 | f.write(train_val_jsonString) 122 | 123 | test_attDict = dict() 124 | test_attDict["images"] = test_images 125 | test_attDict["annotations"] = test_annotations 126 | print(f"test:\t{len(test_images)}") 127 | 128 | test_jsonString = json.dumps(test_attDict) 129 | with open("test_annotations.json", "w") as f: 130 | f.write(test_jsonString) 131 | 132 | 133 | 134 | def 
convert_to_pycocotools_ground_truth(annotations_file): 135 | """ 136 | Given the annotation json file for the test data generated during 137 | initial data preparation, convert it to the input format pycocotools 138 | can consume. 139 | """ 140 | 141 | with open(annotations_file) as f: 142 | images_annotations = json.loads(f.read()) 143 | 144 | attrDict = dict() 145 | attrDict["images"] = images_annotations["images"] 146 | attrDict["categories"] = categories 147 | 148 | annotations = [] 149 | for entry in images_annotations['annotations']: 150 | ann = copy.deepcopy(entry) 151 | xmin, ymin, xmax, ymax = ann["bbox"] 152 | ann["bbox"] = [xmin, ymin, xmax-xmin, ymax-ymin] # convert to [x, y, w, h] 153 | ann["area"] = (xmax - xmin) * (ymax - ymin) 154 | ann["iscrowd"] = 0 155 | annotations.append(ann) 156 | 157 | attrDict["annotations"] = annotations 158 | 159 | jsonString = json.dumps(attrDict) 160 | ground_truth_annotations = "results/ground_truth_annotations.json" 161 | 162 | with open(ground_truth_annotations, "w") as f: 163 | f.write(jsonString) 164 | 165 | return ground_truth_annotations 166 | 167 | 168 | if __name__ == "__main__": 169 | data_path = '../../NEU-DET/ANNOTATIONS' 170 | xmlfiles = glob.glob(os.path.join(data_path, '*.xml')) 171 | xmlfiles.sort() 172 | 173 | XML2JSON(xmlfiles, test_ratio=0.2) 174 | 175 | 176 | --------------------------------------------------------------------------------
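A note on how the last two modules fit together: the sketch below shows one way a deployed detection endpoint could be scored against the held-out split with pycocotools. It is a minimal sketch and not part of the repository: the endpoint name and image directory are placeholders, test_annotations.json is assumed to come from XML2JSON(..., test_ratio=0.2), a results/ directory must already exist because convert_to_pycocotools_ground_truth writes into it, and query_DDN and convert_to_pycocotools_ground_truth are the helpers defined above in src/utils.py and src/xml2json.py (so the snippet is assumed to run from src/).

# Minimal evaluation sketch (not part of the repository); the endpoint name,
# image directory and file locations below are placeholders/assumptions.
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from utils import query_DDN                                # src/utils.py above
from xml2json import convert_to_pycocotools_ground_truth   # src/xml2json.py above

ENDPOINT_NAME = "sagemaker-defect-detection-endpoint"  # placeholder endpoint name
IMAGE_DIR = "../../NEU-DET/IMAGES"                     # placeholder image directory

# Convert the held-out annotations to pycocotools format ([x, y, w, h] boxes).
gt_file = convert_to_pycocotools_ground_truth("test_annotations.json")
coco_gt = COCO(gt_file)

# Query the endpoint for every test image and collect detections in COCO result format.
detections = []
for image_id, img in coco_gt.imgs.items():
    boxes, class_ids, scores = query_DDN(f"{IMAGE_DIR}/{img['file_name']}", ENDPOINT_NAME)
    for (x1, y1, x2, y2), category_id, score in zip(boxes, class_ids, scores):
        detections.append(
            {
                "image_id": image_id,
                "category_id": int(category_id),
                "bbox": [x1, y1, x2 - x1, y2 - y1],  # corners -> [x, y, w, h]
                "score": float(score),
            }
        )

# Standard COCO mAP summary over the collected detections.
coco_dt = coco_gt.loadRes(detections)
coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()

Note that query_DDN only returns its top num_predictions boxes (4 by default), so passing a larger value may be appropriate when computing mAP over all detections.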