├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── questions-or-general-feedbacks.md └── workflows │ ├── ci.yml │ └── codeql.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .viperlightignore ├── .viperlightrc ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.md ├── THIRD_PARTY ├── cloudformation ├── defect-detection-permissions.yaml ├── defect-detection-sagemaker-notebook-instance.yaml ├── defect-detection.yaml └── solution-assistant │ ├── requirements.in │ ├── requirements.txt │ ├── solution-assistant.yaml │ └── src │ └── lambda_fn.py ├── docs ├── arch.png ├── data.png ├── data_flow.png ├── launch.svg ├── numerical.png ├── sagemaker.png ├── sample1.png ├── sample2.png ├── sample3.png └── train_arch.png ├── manifest.json ├── notebooks ├── 0_demo.ipynb ├── 1_retrain_from_checkpoint.ipynb ├── 2_detection_from_scratch.ipynb ├── 3_classification_from_scratch.ipynb └── 4_finetune.ipynb ├── pyproject.toml ├── requirements.in ├── requirements.txt ├── scripts ├── build.sh ├── find_best_ckpt.py ├── set_kernelspec.py ├── train_classifier.sh └── train_detector.sh ├── setup.cfg ├── setup.py └── src ├── im2rec.py ├── prepare_RecordIO.py ├── prepare_data ├── neu.py └── test_neu.py ├── sagemaker_defect_detection ├── __init__.py ├── classifier.py ├── dataset │ ├── __init__.py │ └── neu.py ├── detector.py ├── models │ ├── __init__.py │ └── ddn.py ├── transforms.py └── utils │ ├── __init__.py │ ├── coco_eval.py │ ├── coco_utils.py │ └── visualize.py ├── utils.py └── xml2json.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.py text diff=python 2 | *.ipynb text 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[Bug]" 5 | labels: bug 6 | assignees: ehsanmok 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior and error details 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Additional context** 20 | Add any other context about the problem here. 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature]" 5 | labels: enhancement 6 | assignees: ehsanmok 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-or-general-feedbacks.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Questions or general feedbacks 3 | about: Question and feedbacks 4 | title: "[General]" 5 | labels: question 6 | assignees: ehsanmok 7 | 8 | --- 9 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [ mainline ] 7 | 8 | jobs: 9 | build: 10 | 11 | strategy: 12 | max-parallel: 4 13 | fail-fast: false 14 | matrix: 15 | python-version: [3.7, 3.8] 16 | platform: [ubuntu-latest] 17 | 18 | runs-on: ${{ matrix.platform }} 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: CloudFormation lint 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install cfn-lint 30 | for y in `find deploy/* -name "*.yaml" -o -name "*.template" -o -name "*.json"`; do 31 | echo "============= $y ================" 32 | cfn-lint --fail-on-warnings $y || ec1=$? 33 | done 34 | if [ "$ec1" -ne "0" ]; then echo 'ERROR-1'; else echo 'SUCCESS-1'; ec1=0; fi 35 | echo "Exit Code 1 `echo $ec1`" 36 | if [ "$ec1" -ne "0" ]; then echo 'ERROR'; ec=1; else echo 'SUCCESS'; ec=0; fi; 37 | echo "Exit Code Final `echo $ec`" 38 | exit $ec 39 | - name: Build the package 40 | run: | 41 | python -m pip install -e '.[dev,test,doc]' 42 | - name: Code style check 43 | run: | 44 | black --check src 45 | - name: Notebook style check 46 | run: | 47 | black-nb notebooks/*.ipynb --check 48 | - name: Type check 49 | run: | 50 | mypy --ignore-missing-imports --allow-redefinition --pretty --show-error-context src/ 51 | - name: Run tests 52 | run: | 53 | pytest --pyargs src/ -s 54 | - name: Update docs 55 | run: | 56 | portray on_github_pages -m "Update gh-pages" -f 57 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yaml: -------------------------------------------------------------------------------- 1 | name: "CodeQL" 2 | 3 | on: 4 | push: 5 | pull_request: 6 | schedule: 7 | - cron: '0 21 * * 1' 8 | 9 | jobs: 10 | CodeQL-Build: 11 | 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | language: ['python'] 17 | 18 | steps: 19 | - name: Checkout repository 20 | uses: actions/checkout@v2 21 | with: 22 | # We must fetch at least the immediate parents so that if this is 23 | # a pull request then we can checkout the head. 24 | fetch-depth: 2 25 | 26 | # If this run was triggered by a pull request event, then checkout 27 | # the head of the pull request instead of the merge commit. 
28 | - run: git checkout HEAD^2 29 | if: ${{ github.event_name == 'pull_request' }} 30 | 31 | - name: Initialize CodeQL 32 | uses: github/codeql-action/init@v1 33 | with: 34 | languages: ${{ matrix.language }} 35 | 36 | - name: Perform CodeQL Analysis 37 | uses: github/codeql-action/analyze@v1 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | __pycache__ 3 | .vscode 4 | .mypy* 5 | .ipynb* 6 | .pytest* 7 | *output 8 | *logs* 9 | *checkpoints* 10 | *.ckpt* 11 | *.tar.gz 12 | dist 13 | build 14 | site-packages 15 | stack_outputs.json 16 | docs/*.py 17 | raw_neu_det 18 | neu_det 19 | raw_neu_cls 20 | neu_cls 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.6 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v2.3.0 7 | hooks: 8 | - id: end-of-file-fixer 9 | - id: trailing-whitespace 10 | - repo: https://github.com/psf/black 11 | rev: 20.8b1 12 | hooks: 13 | - id: black 14 | - repo: https://github.com/kynan/nbstripout 15 | rev: 0.3.7 16 | hooks: 17 | - id: nbstripout 18 | name: nbstripout 19 | description: "nbstripout: strip output from Jupyter and IPython notebooks" 20 | entry: nbstripout notebooks 21 | language: python 22 | types: [jupyter] 23 | - repo: https://github.com/tomcatling/black-nb 24 | rev: 0.3.0 25 | hooks: 26 | - id: black-nb 27 | name: black-nb 28 | entry: black-nb 29 | language: python 30 | args: ["--include", '\.ipynb$'] 31 | -------------------------------------------------------------------------------- /.viperlightignore: -------------------------------------------------------------------------------- 1 | ^dist/ 2 | docs/launch.svg:4 3 | CODE_OF_CONDUCT.md:4 4 | CONTRIBUTING.md:50 5 | build/solution-assistant.zip 6 | cloudformation/defect-detection-endpoint.yaml:42 7 | cloudformation/defect-detection-permissions.yaml:114 8 | docs/NEU_surface_defect_database.pdf 9 | -------------------------------------------------------------------------------- /.viperlightrc: -------------------------------------------------------------------------------- 1 | { 2 | "failOn": "medium", 3 | "all": true 4 | } 5 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 
5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *mainline* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass (see *Running local checks* below). 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes.
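
## Running local checks

A minimal sketch of the checks run by the CI workflow (`.github/workflows/ci.yml` above), so you can run them locally before opening a pull request. It assumes you run the commands from the repository root with one of the Python versions exercised in CI (3.7 or 3.8):

```bash
# install the package together with the development, test and documentation extras
python -m pip install -e '.[dev,test,doc]'

# code style for the package and the notebooks
black --check src
black-nb notebooks/*.ipynb --check

# static type checking
mypy --ignore-missing-imports --allow-redefinition --pretty --show-error-context src/

# unit tests
pytest --pyargs src/ -s
```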
62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 
176 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /THIRD_PARTY: -------------------------------------------------------------------------------- 1 | Modifications Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | SPDX-License-Identifier: Apache-2.0 3 | 4 | Original Copyright 2016 Soumith Chintala. All Rights Reserved. 5 | SPDX-License-Identifier: BSD-3 6 | https://github.com/pytorch/vision/blob/master/LICENSE 7 | -------------------------------------------------------------------------------- /cloudformation/defect-detection-permissions.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0015) - sagemaker-defect-detection permission stack" 3 | Parameters: 4 | SolutionPrefix: 5 | Type: String 6 | SolutionName: 7 | Type: String 8 | S3Bucket: 9 | Type: String 10 | StackVersion: 11 | Type: String 12 | 13 | Mappings: 14 | S3: 15 | release: 16 | BucketPrefix: "sagemaker-solutions-prod" 17 | development: 18 | BucketPrefix: "sagemaker-solutions-devo" 19 | 20 | Resources: 21 | SageMakerIAMRole: 22 | Type: AWS::IAM::Role 23 | Properties: 24 | RoleName: !Sub "${SolutionPrefix}-${AWS::Region}-nb-role" 25 | AssumeRolePolicyDocument: 26 | Version: "2012-10-17" 27 | Statement: 28 | - Effect: Allow 29 | Principal: 30 | AWS: 31 | - !Sub "arn:aws:iam::${AWS::AccountId}:root" 32 | Service: 33 | - sagemaker.amazonaws.com 34 | - lambda.amazonaws.com # for solution assistant resource cleanup 35 | Action: 36 | - "sts:AssumeRole" 37 | Metadata: 38 | cfn_nag: 39 | rules_to_suppress: 40 | - id: W28 41 | reason: Needs to be explicitly named to tighten launch permissions policy 42 | 43 | SageMakerIAMPolicy: 44 | Type: AWS::IAM::Policy 45 | Properties: 46 | PolicyName: !Sub "${SolutionPrefix}-nb-instance-policy" 47 | Roles: 48 | - !Ref SageMakerIAMRole 49 | PolicyDocument: 50 | Version: "2012-10-17" 51 | Statement: 52 | - Effect: Allow 53 | Action: 54 | - sagemaker:CreateTrainingJob 55 | - sagemaker:DescribeTrainingJob 56 | - sagemaker:CreateProcessingJob 57 | - sagemaker:DescribeProcessingJob 58 | - sagemaker:CreateModel 59 | - sagemaker:DescribeEndpointConfig 60 | - sagemaker:DescribeEndpoint 61 | - sagemaker:CreateEndpointConfig 62 | - sagemaker:CreateEndpoint 63 | - sagemaker:DeleteEndpointConfig 64 | - sagemaker:DeleteEndpoint 65 | - sagemaker:DeleteModel 66 | - sagemaker:InvokeEndpoint 67 | Resource: 68 | - !Sub "arn:aws:sagemaker:${AWS::Region}:${AWS::AccountId}:*" 69 | - Effect: Allow 70 | Action: 71 | - cloudwatch:GetMetricData 72 | - cloudwatch:GetMetricStatistics 73 | - cloudwatch:ListMetrics 74 | - cloudwatch:PutMetricData 75 | Resource: 76 | - !Sub "arn:aws:cloudwatch:${AWS::Region}:${AWS::AccountId}:*" 77 | - Effect: Allow 78 | Action: 79 | - logs:CreateLogGroup 80 | - logs:CreateLogStream 81 | - logs:DescribeLogStreams 82 | - logs:GetLogEvents 83 | - logs:PutLogEvents 84 | Resource: 85 | - !Sub "arn:aws:logs:${AWS::Region}:${AWS::AccountId}:log-group:/aws/sagemaker/*" 86 | - Effect: Allow 87 | Action: 
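            # Added note: the statement below lets the notebook role pass itself to SageMaker (for example when
            # creating training or processing jobs); the iam:PassedToService condition keeps it scoped to SageMaker only.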
88 | - iam:PassRole 89 | Resource: 90 | - !GetAtt SageMakerIAMRole.Arn 91 | Condition: 92 | StringEquals: 93 | iam:PassedToService: sagemaker.amazonaws.com 94 | - Effect: Allow 95 | Action: 96 | - iam:GetRole 97 | Resource: 98 | - !GetAtt SageMakerIAMRole.Arn 99 | - Effect: Allow 100 | Action: 101 | - ecr:GetAuthorizationToken 102 | - ecr:GetDownloadUrlForLayer 103 | - ecr:BatchGetImage 104 | - ecr:BatchCheckLayerAvailability 105 | - ecr:CreateRepository 106 | - ecr:DescribeRepositories 107 | - ecr:InitiateLayerUpload 108 | - ecr:CompleteLayerUpload 109 | - ecr:UploadLayerPart 110 | - ecr:TagResource 111 | - ecr:PutImage 112 | - ecr:DescribeImages 113 | - ecr:BatchDeleteImage 114 | Resource: 115 | - "*" 116 | - !Sub "arn:aws:ecr:${AWS::Region}:${AWS::AccountId}:repository/*" 117 | - Effect: Allow 118 | Action: 119 | - s3:ListBucket 120 | Resource: !Sub 121 | - "arn:aws:s3:::${SolutionS3BucketName}-${AWS::Region}" 122 | - SolutionS3BucketName: 123 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 124 | - Effect: Allow 125 | Action: 126 | - s3:GetObject 127 | Resource: !Sub 128 | - "arn:aws:s3:::${SolutionS3BucketName}-${AWS::Region}/${SolutionName}/*" 129 | - SolutionS3BucketName: 130 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 131 | SolutionName: !Ref SolutionName 132 | - Effect: Allow 133 | Action: 134 | - s3:ListBucket 135 | - s3:DeleteBucket 136 | - s3:GetBucketLocation 137 | - s3:ListBucketMultipartUploads 138 | Resource: 139 | - !Sub "arn:aws:s3:::${S3Bucket}" 140 | - Effect: Allow 141 | Action: 142 | - s3:AbortMultipartUpload 143 | - s3:ListObject 144 | - s3:GetObject 145 | - s3:PutObject 146 | - s3:DeleteObject 147 | Resource: 148 | - !Sub "arn:aws:s3:::${S3Bucket}" 149 | - !Sub "arn:aws:s3:::${S3Bucket}/*" 150 | - Effect: Allow 151 | Action: 152 | - s3:CreateBucket 153 | - s3:ListBucket 154 | - s3:GetObject 155 | - s3:GetObjectVersion 156 | - s3:PutObject 157 | - s3:DeleteObject 158 | Resource: 159 | - !Sub "arn:aws:s3:::sagemaker-${AWS::Region}-${AWS::AccountId}" 160 | - !Sub "arn:aws:s3:::sagemaker-${AWS::Region}-${AWS::AccountId}/*" 161 | 162 | Metadata: 163 | cfn_nag: 164 | rules_to_suppress: 165 | - id: W12 166 | reason: ECR GetAuthorizationToken is non resource-specific action 167 | 168 | Outputs: 169 | SageMakerRoleArn: 170 | Description: "SageMaker Execution Role for the solution" 171 | Value: !GetAtt SageMakerIAMRole.Arn 172 | -------------------------------------------------------------------------------- /cloudformation/defect-detection-sagemaker-notebook-instance.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0015) - sagemaker-defect-detection notebook stack" 3 | Parameters: 4 | SolutionPrefix: 5 | Type: String 6 | SolutionName: 7 | Type: String 8 | S3Bucket: 9 | Type: String 10 | SageMakerIAMRoleArn: 11 | Type: String 12 | SageMakerNotebookInstanceType: 13 | Type: String 14 | StackVersion: 15 | Type: String 16 | 17 | Mappings: 18 | S3: 19 | release: 20 | BucketPrefix: "sagemaker-solutions-prod" 21 | development: 22 | BucketPrefix: "sagemaker-solutions-devo" 23 | 24 | Resources: 25 | NotebookInstance: 26 | Type: AWS::SageMaker::NotebookInstance 27 | Properties: 28 | DirectInternetAccess: Enabled 29 | InstanceType: !Ref SageMakerNotebookInstanceType 30 | LifecycleConfigName: !GetAtt LifeCycleConfig.NotebookInstanceLifecycleConfigName 31 | NotebookInstanceName: !Sub "${SolutionPrefix}" 32 | RoleArn: !Sub "${SageMakerIAMRoleArn}" 33 | VolumeSizeInGB: 100 34 | 
Metadata: 35 | cfn_nag: 36 | rules_to_suppress: 37 | - id: W1201 38 | reason: Solution does not have KMS encryption enabled by default 39 | LifeCycleConfig: 40 | Type: AWS::SageMaker::NotebookInstanceLifecycleConfig 41 | Properties: 42 | NotebookInstanceLifecycleConfigName: !Sub "${SolutionPrefix}-nb-lifecycle-config" 43 | OnStart: 44 | - Content: 45 | Fn::Base64: | 46 | set -e 47 | sudo -u ec2-user -i <> stack_outputs.json 69 | echo ' "AccountID": "${AWS::AccountId}",' >> stack_outputs.json 70 | echo ' "AWSRegion": "${AWS::Region}",' >> stack_outputs.json 71 | echo ' "IamRole": "${SageMakerIAMRoleArn}",' >> stack_outputs.json 72 | echo ' "SolutionPrefix": "${SolutionPrefix}",' >> stack_outputs.json 73 | echo ' "SolutionName": "${SolutionName}",' >> stack_outputs.json 74 | echo ' "SolutionS3Bucket": "${SolutionsRefBucketBase}",' >> stack_outputs.json 75 | echo ' "S3Bucket": "${S3Bucket}"' >> stack_outputs.json 76 | echo '}' >> stack_outputs.json 77 | cat stack_outputs.json 78 | sudo chown -R ec2-user:ec2-user * 79 | EOF 80 | - SolutionsRefBucketBase: 81 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 82 | 83 | Outputs: 84 | SourceCode: 85 | Description: "Open Jupyter IDE. This authenticates you against Jupyter." 86 | Value: !Sub "https://${NotebookInstance.NotebookInstanceName}.notebook.${AWS::Region}.sagemaker.aws/" 87 | 88 | SageMakerNotebookInstanceSignOn: 89 | Description: "Link to the SageMaker notebook instance" 90 | Value: !Sub "https://${NotebookInstance.NotebookInstanceName}.notebook.${AWS::Region}.sagemaker.aws/notebooks/notebooks" 91 | -------------------------------------------------------------------------------- /cloudformation/defect-detection.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0015) - sagemaker-defect-detection: 3 | Solution for training and deploying deep learning models for defect detection in images using Amazon SageMaker. 4 | Version 1" 5 | 6 | Parameters: 7 | SolutionPrefix: 8 | Type: String 9 | Default: "sagemaker-soln-dfd-" 10 | Description: | 11 | Used to name resources created as part of this stack (and inside nested stacks too). 12 | Can be the same as the stack name used by AWS CloudFormation, but this field has extra 13 | constraints because it's used to name resources with restrictions (e.g. Amazon S3 bucket 14 | names cannot contain capital letters). 15 | AllowedPattern: '^sagemaker-soln-dfd[a-z0-9\-]{1,20}$' 16 | ConstraintDescription: | 17 | Only allowed to use lowercase letters, hyphens and/or numbers. 18 | Should also start with 'sagemaker-soln-dfd-' for permission management. 19 | 20 | SolutionName: 21 | Description: | 22 | Prefix for the solution name. Needs to be 'sagemaker-defect-detection' 23 | or begin with 'sagemaker-defect-detection-' followed by a set of letters and hyphens. 24 | Used to specify a particular directory on S3, that can correspond to a development branch. 25 | Type: String 26 | Default: "sagemaker-defect-detection/1.1.0" 27 | AllowedPattern: '^sagemaker-defect-detection(/[0-9]+\.[0-9]+\.[0-9]+-?[a-zA-Z-0-9\.+]*)?$' 28 | 29 | IamRole: 30 | Type: String 31 | Default: "" 32 | Description: | 33 | IAM Role that will be attached to the resources created by this CloudFormation to grant them permissions to 34 | perform their required functions. This role should allow SageMaker and Lambda to perform the required actions like 35 | creating training jobs and processing jobs.
If left blank, the template will attempt to create a role for you. 36 | This can cause a stack creation error if you don't have privileges to create new roles. 37 | 38 | SageMakerNotebookInstanceType: 39 | Description: SageMaker notebook instance type. 40 | Type: String 41 | Default: "ml.t3.medium" 42 | 43 | CreateSageMakerNotebookInstance: 44 | Description: Whether to launch sagemaker notebook instance 45 | Type: String 46 | AllowedValues: 47 | - "true" 48 | - "false" 49 | Default: "true" 50 | 51 | StackVersion: 52 | Description: | 53 | CloudFormation Stack version. 54 | Use "release" version unless you are customizing the 55 | CloudFormation templates and the solution artifacts in S3 bucket 56 | Type: String 57 | Default: release 58 | AllowedValues: 59 | - release 60 | - development 61 | 62 | Metadata: 63 | AWS::CloudFormation::Interface: 64 | ParameterGroups: 65 | - Label: 66 | default: "Solution Configuration" 67 | Parameters: 68 | - SolutionPrefix 69 | - SolutionName 70 | - IamRole 71 | - Label: 72 | default: "Advanced Configuration" 73 | Parameters: 74 | - SageMakerNotebookInstanceType 75 | - CreateSageMakerNotebookInstance 76 | 77 | ParameterLabels: 78 | SolutionPrefix: 79 | default: "Solution Resources Name Prefix" 80 | SolutionName: 81 | default: "Name of the solution" 82 | IamRole: 83 | default: "Solution IAM Role Arn" 84 | CreateSageMakerNotebookInstance: 85 | default: "Launch SageMaker Notebook Instance" 86 | SageMakerNotebookInstanceType: 87 | default: "SageMaker Notebook Instance Type" 88 | StackVersion: 89 | default: "Solution Stack Version" 90 | 91 | Conditions: 92 | CreateClassicSageMakerResources: 93 | !Equals [!Ref CreateSageMakerNotebookInstance, "true"] 94 | CreateCustomSolutionRole: !Equals [!Ref IamRole, ""] 95 | 96 | Mappings: 97 | S3: 98 | release: 99 | BucketPrefix: "sagemaker-solutions-prod" 100 | development: 101 | BucketPrefix: "sagemaker-solutions-devo" 102 | 103 | Resources: 104 | S3Bucket: 105 | Type: AWS::S3::Bucket 106 | DeletionPolicy: Retain 107 | UpdateReplacePolicy: "Retain" 108 | Properties: 109 | BucketName: !Sub "${SolutionPrefix}-${AWS::AccountId}-${AWS::Region}" 110 | PublicAccessBlockConfiguration: 111 | BlockPublicAcls: true 112 | BlockPublicPolicy: true 113 | IgnorePublicAcls: true 114 | RestrictPublicBuckets: true 115 | BucketEncryption: 116 | ServerSideEncryptionConfiguration: 117 | - ServerSideEncryptionByDefault: 118 | SSEAlgorithm: AES256 119 | 120 | Metadata: 121 | cfn_nag: 122 | rules_to_suppress: 123 | - id: W35 124 | reason: Configuring logging requires supplying an existing customer S3 bucket to store logs. 125 | - id: W51 126 | reason: Default access policy is sufficient. 
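  # Added note on the nested stacks defined below:
  #   - SageMakerPermissionsStack: only created when no IamRole is supplied (CreateCustomSolutionRole); provisions the SageMaker execution role and its policy.
  #   - SageMakerStack: only created when CreateSageMakerNotebookInstance is "true" (CreateClassicSageMakerResources); launches the notebook instance and its lifecycle configuration.
  #   - SolutionAssistantStack: registers the cleanup Lambda that deletes the demo endpoints and empties the solution S3 bucket when the stack is deleted.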
127 | 128 | SageMakerPermissionsStack: 129 | Type: "AWS::CloudFormation::Stack" 130 | Condition: CreateCustomSolutionRole 131 | Properties: 132 | TemplateURL: !Sub 133 | - "https://${SolutionRefBucketBase}-${AWS::Region}.s3.${AWS::Region}.amazonaws.com/${SolutionName}/cloudformation/defect-detection-permissions.yaml" 134 | - SolutionRefBucketBase: 135 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 136 | Parameters: 137 | SolutionPrefix: !Ref SolutionPrefix 138 | SolutionName: !Ref SolutionName 139 | S3Bucket: !Ref S3Bucket 140 | StackVersion: !Ref StackVersion 141 | 142 | SageMakerStack: 143 | Type: "AWS::CloudFormation::Stack" 144 | Condition: CreateClassicSageMakerResources 145 | Properties: 146 | TemplateURL: !Sub 147 | - "https://${SolutionRefBucketBase}-${AWS::Region}.s3.${AWS::Region}.amazonaws.com/${SolutionName}/cloudformation/defect-detection-sagemaker-notebook-instance.yaml" 148 | - SolutionRefBucketBase: 149 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 150 | Parameters: 151 | SolutionPrefix: !Ref SolutionPrefix 152 | SolutionName: !Ref SolutionName 153 | S3Bucket: !Ref S3Bucket 154 | SageMakerIAMRoleArn: 155 | !If [ 156 | CreateCustomSolutionRole, 157 | !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, 158 | !Ref IamRole, 159 | ] 160 | SageMakerNotebookInstanceType: !Ref SageMakerNotebookInstanceType 161 | StackVersion: !Ref StackVersion 162 | 163 | SolutionAssistantStack: 164 | Type: "AWS::CloudFormation::Stack" 165 | Properties: 166 | TemplateURL: !Sub 167 | - "https://${SolutionRefBucketBase}-${AWS::Region}.s3.${AWS::Region}.amazonaws.com/${SolutionName}/cloudformation/solution-assistant/solution-assistant.yaml" 168 | - SolutionRefBucketBase: 169 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 170 | Parameters: 171 | SolutionPrefix: !Ref SolutionPrefix 172 | SolutionName: !Ref SolutionName 173 | StackName: !Ref AWS::StackName 174 | S3Bucket: !Ref S3Bucket 175 | SolutionS3Bucket: !Sub 176 | - "${SolutionRefBucketBase}-${AWS::Region}" 177 | - SolutionRefBucketBase: 178 | !FindInMap [S3, !Ref StackVersion, BucketPrefix] 179 | RoleArn: 180 | !If [ 181 | CreateCustomSolutionRole, 182 | !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, 183 | !Ref IamRole, 184 | ] 185 | 186 | Outputs: 187 | SolutionName: 188 | Description: "The name for the solution, can be used to deploy different versions of the solution" 189 | Value: !Ref SolutionName 190 | 191 | SourceCode: 192 | Condition: CreateClassicSageMakerResources 193 | Description: "Open Jupyter IDE. 
This authenticates you against Jupyter" 194 | Value: !GetAtt SageMakerStack.Outputs.SourceCode 195 | 196 | NotebookInstance: 197 | Description: "SageMaker Notebook instance to manually orchestrate data preprocessing, model training and deploying an endpoint" 198 | Value: 199 | !If [ 200 | CreateClassicSageMakerResources, 201 | !GetAtt SageMakerStack.Outputs.SageMakerNotebookInstanceSignOn, 202 | "", 203 | ] 204 | 205 | AccountID: 206 | Description: "AWS Account ID to be passed downstream to the notebook instance" 207 | Value: !Ref AWS::AccountId 208 | 209 | AWSRegion: 210 | Description: "AWS Region to be passed downstream to the notebook instance" 211 | Value: !Ref AWS::Region 212 | 213 | IamRole: 214 | Description: "Arn of SageMaker Execution Role" 215 | Value: 216 | !If [ 217 | CreateCustomSolutionRole, 218 | !GetAtt SageMakerPermissionsStack.Outputs.SageMakerRoleArn, 219 | !Ref IamRole, 220 | ] 221 | 222 | SolutionPrefix: 223 | Description: "Solution Prefix for naming SageMaker transient resources" 224 | Value: !Ref SolutionPrefix 225 | 226 | S3Bucket: 227 | Description: "S3 bucket name used in the solution to store artifacts" 228 | Value: !Ref S3Bucket 229 | 230 | SolutionS3Bucket: 231 | Description: "Solution S3 bucket" 232 | Value: !FindInMap [S3, !Ref StackVersion, BucketPrefix] 233 | 234 | SageMakerMode: 235 | Value: !If [CreateClassicSageMakerResources, "NotebookInstance", "Studio"] 236 | 237 | StackName: 238 | Value: !Ref AWS::StackName 239 | -------------------------------------------------------------------------------- /cloudformation/solution-assistant/requirements.in: -------------------------------------------------------------------------------- 1 | crhelper 2 | -------------------------------------------------------------------------------- /cloudformation/solution-assistant/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile 3 | # To update, run: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | crhelper==2.0.6 # via -r requirements.in 8 | -------------------------------------------------------------------------------- /cloudformation/solution-assistant/solution-assistant.yaml: -------------------------------------------------------------------------------- 1 | AWSTemplateFormatVersion: "2010-09-09" 2 | Description: "(SA0015) - sagemaker-defect-detection solution assistant stack" 3 | 4 | Parameters: 5 | SolutionPrefix: 6 | Type: String 7 | SolutionName: 8 | Type: String 9 | StackName: 10 | Type: String 11 | S3Bucket: 12 | Type: String 13 | SolutionS3Bucket: 14 | Type: String 15 | RoleArn: 16 | Type: String 17 | 18 | Mappings: 19 | Function: 20 | SolutionAssistant: 21 | S3Key: "build/solution-assistant.zip" 22 | 23 | Resources: 24 | SolutionAssistant: 25 | Type: "Custom::SolutionAssistant" 26 | Properties: 27 | SolutionPrefix: !Ref SolutionPrefix 28 | ServiceToken: !GetAtt SolutionAssistantLambda.Arn 29 | S3Bucket: !Ref S3Bucket 30 | StackName: !Ref StackName 31 | 32 | SolutionAssistantLambda: 33 | Type: AWS::Lambda::Function 34 | Properties: 35 | Handler: "lambda_fn.handler" 36 | FunctionName: !Sub "${SolutionPrefix}-solution-assistant" 37 | Role: !Ref RoleArn 38 | Runtime: "python3.8" 39 | Code: 40 | S3Bucket: !Ref SolutionS3Bucket 41 | S3Key: !Sub 42 | - "${SolutionName}/${LambdaS3Key}" 43 | - LambdaS3Key: !FindInMap [Function, SolutionAssistant, S3Key] 44 | Timeout: 60 45 | Metadata: 46 | cfn_nag: 47 | rules_to_suppress: 48 | - id: W58 49 | reason: >- 50 | The required
permissions are provided in the permissions stack. 51 | -------------------------------------------------------------------------------- /cloudformation/solution-assistant/src/lambda_fn.py: -------------------------------------------------------------------------------- 1 | import boto3 2 | import sys 3 | 4 | sys.path.append("./site-packages") 5 | from crhelper import CfnResource 6 | 7 | helper = CfnResource() 8 | 9 | 10 | @helper.create 11 | def on_create(_, __): 12 | pass 13 | 14 | 15 | @helper.update 16 | def on_update(_, __): 17 | pass 18 | 19 | 20 | def delete_sagemaker_endpoint(endpoint_name): 21 | sagemaker_client = boto3.client("sagemaker") 22 | try: 23 | sagemaker_client.delete_endpoint(EndpointName=endpoint_name) 24 | print("Successfully deleted endpoint " "called '{}'.".format(endpoint_name)) 25 | except sagemaker_client.exceptions.ClientError as e: 26 | if "Could not find endpoint" in str(e): 27 | print("Could not find endpoint called '{}'. " "Skipping delete.".format(endpoint_name)) 28 | else: 29 | raise e 30 | 31 | 32 | def delete_sagemaker_endpoint_config(endpoint_config_name): 33 | sagemaker_client = boto3.client("sagemaker") 34 | try: 35 | sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name) 36 | print("Successfully deleted endpoint configuration " "called '{}'.".format(endpoint_config_name)) 37 | except sagemaker_client.exceptions.ClientError as e: 38 | if "Could not find endpoint configuration" in str(e): 39 | print( 40 | "Could not find endpoint configuration called '{}'. " "Skipping delete.".format(endpoint_config_name) 41 | ) 42 | else: 43 | raise e 44 | 45 | 46 | def delete_sagemaker_model(model_name): 47 | sagemaker_client = boto3.client("sagemaker") 48 | try: 49 | sagemaker_client.delete_model(ModelName=model_name) 50 | print("Successfully deleted model called '{}'.".format(model_name)) 51 | except sagemaker_client.exceptions.ClientError as e: 52 | if "Could not find model" in str(e): 53 | print("Could not find model called '{}'. " "Skipping delete.".format(model_name)) 54 | else: 55 | raise e 56 | 57 | 58 | def delete_s3_objects(bucket_name): 59 | s3_resource = boto3.resource("s3") 60 | try: 61 | s3_resource.Bucket(bucket_name).objects.all().delete() 62 | print("Successfully deleted objects in bucket " "called '{}'.".format(bucket_name)) 63 | except s3_resource.meta.client.exceptions.NoSuchBucket: 64 | print("Could not find bucket called '{}'. " "Skipping delete.".format(bucket_name)) 65 | 66 | 67 | def delete_s3_bucket(bucket_name): 68 | s3_resource = boto3.resource("s3") 69 | try: 70 | s3_resource.Bucket(bucket_name).delete() 71 | print("Successfully deleted bucket " "called '{}'.".format(bucket_name)) 72 | except s3_resource.meta.client.exceptions.NoSuchBucket: 73 | print("Could not find bucket called '{}'. 
" "Skipping delete.".format(bucket_name)) 74 | 75 | 76 | @helper.delete 77 | def on_delete(event, __): 78 | # remove sagemaker endpoints 79 | solution_prefix = event["ResourceProperties"]["SolutionPrefix"] 80 | endpoint_names = [ 81 | "{}-demo-endpoint".format(solution_prefix), # make sure it is the same as your endpoint name 82 | "{}-demo-model".format(solution_prefix), 83 | "{}-finetuned-endpoint".format(solution_prefix), 84 | "{}-detector-from-scratch-endpoint".format(solution_prefix), 85 | "{}-classification-endpoint".format(solution_prefix), 86 | ] 87 | for endpoint_name in endpoint_names: 88 | delete_sagemaker_model(endpoint_name) 89 | delete_sagemaker_endpoint_config(endpoint_name) 90 | delete_sagemaker_endpoint(endpoint_name) 91 | 92 | # remove files in s3 93 | output_bucket = event["ResourceProperties"]["S3Bucket"] 94 | delete_s3_objects(output_bucket) 95 | 96 | # delete buckets 97 | delete_s3_bucket(output_bucket) 98 | 99 | 100 | def handler(event, context): 101 | helper(event, context) 102 | -------------------------------------------------------------------------------- /docs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/arch.png -------------------------------------------------------------------------------- /docs/data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/data.png -------------------------------------------------------------------------------- /docs/data_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/data_flow.png -------------------------------------------------------------------------------- /docs/launch.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Click to Create Stack 5 | 6 | 7 | -------------------------------------------------------------------------------- /docs/numerical.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/numerical.png -------------------------------------------------------------------------------- /docs/sagemaker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/sagemaker.png -------------------------------------------------------------------------------- /docs/sample1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/sample1.png -------------------------------------------------------------------------------- /docs/sample2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/sample2.png -------------------------------------------------------------------------------- /docs/sample3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/sample3.png -------------------------------------------------------------------------------- /docs/train_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/docs/train_arch.png -------------------------------------------------------------------------------- /manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "1.1.1" 3 | } 4 | -------------------------------------------------------------------------------- /notebooks/0_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Jupyter Kernel**:\n", 8 | "\n", 9 | "* If you are in SageMaker Notebook instance, please make sure you are using **conda_pytorch_latest_p36** kernel\n", 10 | "* If you are on SageMaker Studio, please make sure you are using **SageMaker JumpStart PyTorch 1.0** kernel\n", 11 | "\n", 12 | "**Run All**:\n", 13 | "\n", 14 | "* If you are in SageMaker notebook instance, you can go to *Cell tab -> Run All*\n", 15 | "* If you are in SageMaker Studio, you can go to *Run tab -> Run All Cells*\n", 16 | "\n", 17 | "**Note**: To *Run All* successfully, make sure you have executed the entire demo notebook `0_demo.ipynb` first.\n", 18 | "\n", 19 | "## SageMaker Defect Detection Demo\n", 20 | "\n", 21 | "In this notebook, we deploy an endpoint from a provided pretrained detection model that was already trained on **NEU-DET** dataset. 
Then, we send some image samples with defects for detection and visualize the results" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import json\n", 31 | "\n", 32 | "import numpy as np\n", 33 | "\n", 34 | "import sagemaker\n", 35 | "from sagemaker.s3 import S3Downloader\n", 36 | "\n", 37 | "sagemaker_session = sagemaker.Session()\n", 38 | "sagemaker_config = json.load(open(\"../stack_outputs.json\"))\n", 39 | "role = sagemaker_config[\"IamRole\"]\n", 40 | "solution_bucket = sagemaker_config[\"SolutionS3Bucket\"]\n", 41 | "region = sagemaker_config[\"AWSRegion\"]\n", 42 | "solution_name = sagemaker_config[\"SolutionName\"]\n", 43 | "bucket = sagemaker_config[\"S3Bucket\"]" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "First, we download our **NEU-DET** dataset from our public S3 bucket" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "original_bucket = f\"s3://{solution_bucket}-{region}/{solution_name}\"\n", 60 | "original_data_prefix = \"data/NEU-DET.zip\"\n", 61 | "original_data = f\"{original_bucket}/{original_data_prefix}\"\n", 62 | "original_pretained_checkpoint = f\"{original_bucket}/pretrained\"\n", 63 | "original_sources = f\"{original_bucket}/build/lib/source_dir.tar.gz\"\n", 64 | "print(\"original data: \")\n", 65 | "S3Downloader.list(original_data)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "For easier data processing, depending on the dataset, we unify the class and label names using the scripts from `prepare_data`, which should take less than **5 minutes** to complete.
This is done once throughout all our notebooks" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "%%time\n", 82 | "\n", 83 | "RAW_DATA_PATH = !echo $PWD/raw_neu_det\n", 84 | "RAW_DATA_PATH = RAW_DATA_PATH.n\n", 85 | "DATA_PATH = !echo $PWD/neu_det\n", 86 | "DATA_PATH = DATA_PATH.n\n", 87 | "\n", 88 | "!mkdir -p $RAW_DATA_PATH\n", 89 | "!aws s3 cp $original_data $RAW_DATA_PATH\n", 90 | "\n", 91 | "!mkdir -p $DATA_PATH\n", 92 | "!python ../src/prepare_data/neu.py $RAW_DATA_PATH/NEU-DET.zip $DATA_PATH" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "After data preparation, we need to upload the prepared data to S3 and set up some paths that will be used throughout the notebooks" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "%%time\n", 109 | "prefix = \"neu-det\"\n", 110 | "neu_det_s3 = f\"s3://{bucket}/{prefix}\"\n", 111 | "sources = f\"{neu_det_s3}/code/\"\n", 112 | "train_output = f\"{neu_det_s3}/output/\"\n", 113 | "neu_det_prepared_s3 = f\"{neu_det_s3}/data/\"\n", 114 | "!aws s3 sync $DATA_PATH $neu_det_prepared_s3 --quiet # remove the --quiet flag to view the sync logs\n", 115 | "s3_checkpoint = f\"{neu_det_s3}/checkpoint/\"\n", 116 | "sm_local_checkpoint_dir = \"/opt/ml/checkpoints/\"\n", 117 | "s3_pretrained = f\"{neu_det_s3}/pretrained/\"\n", 118 | "!aws s3 sync $original_pretained_checkpoint $s3_pretrained\n", 119 | "!aws s3 ls $s3_pretrained\n", 120 | "!aws s3 cp $original_sources $sources" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "## Visualization\n", 128 | "\n", 129 | "Let's examine some of the data that we will use later by providing an `ID`" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "import copy\n", 139 | "\n", 140 | "import numpy as np\n", 141 | "import torch\n", 142 | "from PIL import Image\n", 143 | "from torch.utils.data import DataLoader\n", 144 | "\n", 145 | "try:\n", 146 | "    import sagemaker_defect_detection\n", 147 | "except ImportError:\n", 148 | "    import sys\n", 149 | "    from pathlib import Path\n", 150 | "\n", 151 | "    ROOT = Path(\"../src\").resolve()\n", 152 | "    sys.path.insert(0, str(ROOT))\n", 153 | "\n", 154 | "from sagemaker_defect_detection import NEUDET, get_preprocess\n", 155 | "\n", 156 | "SPLIT = \"test\"\n", 157 | "ID = 10\n", 158 | "assert 0 <= ID <= 300\n", 159 | "dataset = NEUDET(DATA_PATH, split=SPLIT, preprocess=get_preprocess())\n", 160 | "images, targets, _ = dataset[ID]\n", 161 | "original_image = copy.deepcopy(images)\n", 162 | "original_boxes = targets[\"boxes\"].numpy().copy()\n", 163 | "original_labels = targets[\"labels\"].numpy().copy()\n", 164 | "print(f\"first images size: {original_image.shape}\")\n", 165 | "print(f\"target bounding boxes: \\n {original_boxes}\")\n", 166 | "print(f\"target labels: {original_labels}\")" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "And we can now visualize it using the provided utilities as follows" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "from sagemaker_defect_detection.utils.visualize import unnormalize_to_hwc, visualize\n", 183 | "\n", 184 |
"original_image_unnorm = unnormalize_to_hwc(original_image)\n", 185 | "\n", 186 | "visualize(\n", 187 | " original_image_unnorm,\n", 188 | " [original_boxes],\n", 189 | " [original_labels],\n", 190 | " colors=[(255, 0, 0)],\n", 191 | " titles=[\"original\", \"ground truth\"],\n", 192 | ")" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "For our demo, we deploy an endpoint using a provided pretrained checkpoint. It takes about **10 minutes** to finish" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "%%time\n", 209 | "from os import path as osp\n", 210 | "\n", 211 | "from sagemaker.pytorch import PyTorchModel\n", 212 | "\n", 213 | "demo_model = PyTorchModel(\n", 214 | " osp.join(s3_pretrained, \"model.tar.gz\"),\n", 215 | " role,\n", 216 | " entry_point=\"detector.py\",\n", 217 | " source_dir=osp.join(sources, \"source_dir.tar.gz\"),\n", 218 | " framework_version=\"1.5\",\n", 219 | " py_version=\"py3\",\n", 220 | " sagemaker_session=sagemaker_session,\n", 221 | " name=sagemaker_config[\"SolutionPrefix\"] + \"-demo-model\",\n", 222 | ")\n", 223 | "\n", 224 | "demo_detector = demo_model.deploy(\n", 225 | " initial_instance_count=1,\n", 226 | " instance_type=\"ml.m5.xlarge\",\n", 227 | " endpoint_name=sagemaker_config[\"SolutionPrefix\"] + \"-demo-endpoint\",\n", 228 | ")" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": {}, 234 | "source": [ 235 | "## Inference\n", 236 | "\n", 237 | "We change the input depending on whether we are providing a list of images or a single image. Also the model requires a four dimensional array / tensor (with the first dimension as batch)" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": {}, 244 | "outputs": [], 245 | "source": [ 246 | "input = list(img.numpy() for img in images) if isinstance(images, list) else images.unsqueeze(0).numpy()" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "Now the input is ready and we can get some results" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "%%time\n", 263 | "# SageMaker 1.x doesn't allow_pickle=True by default\n", 264 | "np_load_old = np.load\n", 265 | "np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)\n", 266 | "demo_predictions = demo_detector.predict(input)\n", 267 | "np.load = np_load_old" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "And finally, we visualize them as follows " 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": null, 280 | "metadata": { 281 | "tags": [] 282 | }, 283 | "outputs": [], 284 | "source": [ 285 | "visualize(\n", 286 | " original_image_unnorm,\n", 287 | " [original_boxes, demo_predictions[0][\"boxes\"]],\n", 288 | " [original_labels, demo_predictions[0][\"labels\"]],\n", 289 | " colors=[(255, 0, 0), (0, 0, 255)],\n", 290 | " titles=[\"original\", \"ground truth\", \"pretrained demo\"],\n", 291 | " dpi=200,\n", 292 | ")" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "metadata": {}, 298 | "source": [ 299 | "## Optional: Delete the endpoint and model\n", 300 | "\n", 301 | "**Note:** to follow all the notebooks, it is required to keep demo model and the demo endpoint. 
It will be automatically deleted when you delete the entire resources/stack. However, if you need to, please uncomment and run the next cell\n", 302 | "\n", 303 | "All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account." 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": null, 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "# demo_detector.delete_model()\n", 313 | "# demo_detector.delete_endpoint()" 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": {}, 319 | "source": [ 320 | "### [Click here to continue](./1_retrain_from_checkpoint.ipynb)" 321 | ] 322 | } 323 | ], 324 | "metadata": { 325 | "kernelspec": { 326 | "display_name": "Python 3.6.10 64-bit ('pytorch_latest_p36': conda)", 327 | "name": "python361064bitpytorchlatestp36conda2dfac45b320c45f3a4d1e89ca46b60d1" 328 | }, 329 | "language_info": { 330 | "codemirror_mode": { 331 | "name": "ipython", 332 | "version": 3 333 | }, 334 | "file_extension": ".py", 335 | "mimetype": "text/x-python", 336 | "name": "python", 337 | "nbconvert_exporter": "python", 338 | "pygments_lexer": "ipython3", 339 | "version": "3.6.10-final" 340 | } 341 | }, 342 | "nbformat": 4, 343 | "nbformat_minor": 2 344 | } 345 | -------------------------------------------------------------------------------- /notebooks/1_retrain_from_checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Jupyter Kernel**:\n", 8 | "\n", 9 | "* If you are in SageMaker Notebook instance, please make sure you are using **conda_pytorch_latest_p36** kernel\n", 10 | "* If you are on SageMaker Studio, please make sure you are using **SageMaker JumpStart PyTorch 1.0** kernel\n", 11 | "\n", 12 | "**Run All**:\n", 13 | "\n", 14 | "* If you are in SageMaker notebook instance, you can go to *Cell tab -> Run All*\n", 15 | "* If you are in SageMaker Studio, you can go to *Run tab -> Run All Cells*\n", 16 | "\n", 17 | "**Note**: To *Run All* successfully, make sure you have executed the entire demo notebook `0_demo.ipynb` first.\n", 18 | "\n", 19 | "## Resume Training\n", 20 | "\n", 21 | "In this notebook, we retrain our pretrained detector for a few more epochs and compare its results. The same process can be applied when finetuning on another dataset. For the purpose of this notebook, we use the same **NEU-DET** dataset.\n", 22 | "\n", 23 | "## Finetuning\n", 24 | "\n", 25 | "Finetuning is one way to do Transfer Learning. Finetuning a Deep Learning model on one particular task, involves using the learned weights from a particular dataset to enhance the performace of the model on usually another dataset. 
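As a rough illustration of this idea only (not the solution's actual training code, which lives in `detector.py`), finetuning in PyTorch typically means loading pretrained weights, optionally freezing early layers, and continuing training with a smaller learning rate; the layer names and class count below are assumptions made for the sketch:

import torch
import torchvision

# start from pretrained weights (here: ImageNet weights shipped with torchvision)
model = torchvision.models.resnet34(pretrained=True)

# optionally freeze the earliest feature-extraction layers
for param in model.layer1.parameters():
    param.requires_grad = False

# replace the task-specific head (6 output classes here, matching the NEU defect classes; illustrative)
model.fc = torch.nn.Linear(model.fc.in_features, 6)

# continue training only the remaining parameters with a small learning rate
optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=1e-4)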
In a sense, finetuning can be done over the same dataset used in the initial training but perhaps with different hyperparameters.\n" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import json\n", 35 | "\n", 36 | "import sagemaker\n", 37 | "from sagemaker.s3 import S3Downloader\n", 38 | "\n", 39 | "sagemaker_session = sagemaker.Session()\n", 40 | "sagemaker_config = json.load(open(\"../stack_outputs.json\"))\n", 41 | "role = sagemaker_config[\"IamRole\"]\n", 42 | "solution_bucket = sagemaker_config[\"SolutionS3Bucket\"]\n", 43 | "region = sagemaker_config[\"AWSRegion\"]\n", 44 | "solution_name = sagemaker_config[\"SolutionName\"]\n", 45 | "bucket = sagemaker_config[\"S3Bucket\"]" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "First, we download our **NEU-DET** dataset from our public S3 bucket" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "original_bucket = f\"s3://{solution_bucket}-{region}/{solution_name}\"\n", 62 | "original_pretained_checkpoint = f\"{original_bucket}/pretrained\"\n", 63 | "original_sources = f\"{original_bucket}/build/lib/source_dir.tar.gz\"" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "Note that for easier data processing, we have already executed `prepare_data` once in our `0_demo.ipynb` and have already uploaded the prepared data to our S3 bucket" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "DATA_PATH = !echo $PWD/neu_det\n", 80 | "DATA_PATH = DATA_PATH.n" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "After data preparation, we need to set up some paths that will be used throughout the notebook" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": { 94 | "tags": [] 95 | }, 96 | "outputs": [], 97 | "source": [ 98 | "prefix = \"neu-det\"\n", 99 | "neu_det_s3 = f\"s3://{bucket}/{prefix}\"\n", 100 | "sources = f\"{neu_det_s3}/code/\"\n", 101 | "train_output = f\"{neu_det_s3}/output/\"\n", 102 | "neu_det_prepared_s3 = f\"{neu_det_s3}/data/\"\n", 103 | "s3_checkpoint = f\"{neu_det_s3}/checkpoint/\"\n", 104 | "sm_local_checkpoint_dir = \"/opt/ml/checkpoints/\"\n", 105 | "s3_pretrained = f\"{neu_det_s3}/pretrained/\"" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "## Visualization\n", 113 | "\n", 114 | "Let's examine some samples from the dataset that we will use later by providing an `ID`" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "import copy\n", 124 | "\n", 125 | "import numpy as np\n", 126 | "import torch\n", 127 | "from PIL import Image\n", 128 | "from torch.utils.data import DataLoader\n", 129 | "\n", 130 | "try:\n", 131 | " import sagemaker_defect_detection\n", 132 | "except ImportError:\n", 133 | " import sys\n", 134 | " from pathlib import Path\n", 135 | "\n", 136 | " ROOT = Path(\"../src\").resolve()\n", 137 | " sys.path.insert(0, str(ROOT))\n", 138 | "\n", 139 | "from sagemaker_defect_detection import NEUDET, get_preprocess\n", 140 | "\n", 141 | "SPLIT = \"test\"\n", 142 | "ID = 30\n", 143 | "assert 0 <= ID <= 300\n", 144 | "dataset = 
NEUDET(DATA_PATH, split=SPLIT, preprocess=get_preprocess())\n", 145 | "images, targets, _ = dataset[ID]\n", 146 | "original_image = copy.deepcopy(images)\n", 147 | "original_boxes = targets[\"boxes\"].numpy().copy()\n", 148 | "original_labels = targets[\"labels\"].numpy().copy()\n", 149 | "print(f\"first images size: {original_image.shape}\")\n", 150 | "print(f\"target bounding boxes: \\n {original_boxes}\")\n", 151 | "print(f\"target labels: {original_labels}\")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "And we can now visualize it using the provided utilities as follows" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "from sagemaker_defect_detection.utils.visualize import unnormalize_to_hwc, visualize\n", 168 | "\n", 169 | "original_image_unnorm = unnormalize_to_hwc(original_image)\n", 170 | "\n", 171 | "visualize(\n", 172 | " original_image_unnorm,\n", 173 | " [original_boxes],\n", 174 | " [original_labels],\n", 175 | " colors=[(255, 0, 0)],\n", 176 | " titles=[\"original\", \"ground truth\"],\n", 177 | ")" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Here we resume from a provided pretrained checkpoint `epoch=294-loss=0.654-main_score=0.349.ckpt` that we have copied into our `s3_pretrained`. This takes about **10 minutes** to complete" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": { 191 | "tags": [ 192 | "outputPrepend" 193 | ] 194 | }, 195 | "outputs": [], 196 | "source": [ 197 | "%%time\n", 198 | "import logging\n", 199 | "from os import path as osp\n", 200 | "\n", 201 | "from sagemaker.pytorch import PyTorch\n", 202 | "\n", 203 | "NUM_CLASSES = 7 # 6 classes + 1 for background\n", 204 | "# Note: resnet34 was used in the pretrained model and it has to match the pretrained model backbone\n", 205 | "# if need resnet50, need to train from scratch\n", 206 | "BACKBONE = \"resnet34\"\n", 207 | "assert BACKBONE in [\n", 208 | " \"resnet34\",\n", 209 | " \"resnet50\",\n", 210 | "], \"either resnet34 or resnet50. 
Make sure to be consistent with model_fn in detector.py\"\n", 211 | "EPOCHS = 5\n", 212 | "LEARNING_RATE = 1e-4\n", 213 | "SEED = 123\n", 214 | "\n", 215 | "hyperparameters = {\n", 216 | " \"backbone\": BACKBONE, # the backbone resnet model for feature extraction\n", 217 | " \"num-classes\": NUM_CLASSES, # number of classes + background\n", 218 | " \"epochs\": EPOCHS, # number of epochs to finetune\n", 219 | " \"learning-rate\": LEARNING_RATE, # learning rate for optimizer\n", 220 | " \"seed\": SEED, # random number generator seed\n", 221 | "}\n", 222 | "\n", 223 | "assert not isinstance(sagemaker_session, sagemaker.LocalSession), \"local session as share memory cannot be altered\"\n", 224 | "\n", 225 | "finetuned_model = PyTorch(\n", 226 | " entry_point=\"detector.py\",\n", 227 | " source_dir=osp.join(sources, \"source_dir.tar.gz\"),\n", 228 | " role=role,\n", 229 | " train_instance_count=1,\n", 230 | " train_instance_type=\"ml.g4dn.2xlarge\",\n", 231 | " hyperparameters=hyperparameters,\n", 232 | " py_version=\"py3\",\n", 233 | " framework_version=\"1.5\",\n", 234 | " sagemaker_session=sagemaker_session,\n", 235 | " output_path=train_output,\n", 236 | " checkpoint_s3_uri=s3_checkpoint,\n", 237 | " checkpoint_local_path=sm_local_checkpoint_dir,\n", 238 | " # container_log_level=logging.DEBUG,\n", 239 | ")\n", 240 | "\n", 241 | "finetuned_model.fit(\n", 242 | " {\n", 243 | " \"training\": neu_det_prepared_s3,\n", 244 | " \"pretrained_checkpoint\": osp.join(s3_pretrained, \"epoch=294-loss=0.654-main_score=0.349.ckpt\"),\n", 245 | " }\n", 246 | ")" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "metadata": {}, 252 | "source": [ 253 | "Then, we deploy our new model which takes about **10 minutes** to complete" 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "%%time\n", 263 | "finetuned_detector = finetuned_model.deploy(\n", 264 | " initial_instance_count=1,\n", 265 | " instance_type=\"ml.m5.xlarge\",\n", 266 | " endpoint_name=sagemaker_config[\"SolutionPrefix\"] + \"-finetuned-endpoint\",\n", 267 | ")" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "## Inference\n", 275 | "\n", 276 | "We change the input depending on whether we are providing a list of images or a single image. 
Also the model requires a four dimensional array / tensor (with the first dimension as batch)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "input = list(img.numpy() for img in images) if isinstance(images, list) else images.unsqueeze(0).numpy()" 286 | ] 287 | }, 288 | { 289 | "cell_type": "markdown", 290 | "metadata": {}, 291 | "source": [ 292 | "Now the input is ready and we can get some results" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "%%time\n", 302 | "# SageMaker 1.x doesn't allow_pickle=True by default\n", 303 | "np_load_old = np.load\n", 304 | "np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)\n", 305 | "finetuned_predictions = finetuned_detector.predict(input)\n", 306 | "np.load = np_load_old" 307 | ] 308 | }, 309 | { 310 | "cell_type": "markdown", 311 | "metadata": {}, 312 | "source": [ 313 | "Here we want to compare the results of the new model and the pretrained model that we already deployed in `0_demo.ipynb` visually by calling our endpoint from SageMaker runtime using `boto3`" 314 | ] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "execution_count": null, 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "import boto3\n", 323 | "import botocore\n", 324 | "\n", 325 | "config = botocore.config.Config(read_timeout=200)\n", 326 | "runtime = boto3.client(\"runtime.sagemaker\", config=config)\n", 327 | "payload = json.dumps(input.tolist() if isinstance(input, np.ndarray) else input)\n", 328 | "response = runtime.invoke_endpoint(\n", 329 | " EndpointName=sagemaker_config[\"SolutionPrefix\"] + \"-demo-endpoint\", ContentType=\"application/json\", Body=payload\n", 330 | ")\n", 331 | "demo_predictions = json.loads(response[\"Body\"].read().decode())" 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "metadata": {}, 337 | "source": [ 338 | "Here comes the slight changes in inference" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": null, 344 | "metadata": {}, 345 | "outputs": [], 346 | "source": [ 347 | "visualize(\n", 348 | " original_image_unnorm,\n", 349 | " [original_boxes, demo_predictions[0][\"boxes\"], finetuned_predictions[0][\"boxes\"]],\n", 350 | " [original_labels, demo_predictions[0][\"labels\"], finetuned_predictions[0][\"labels\"]],\n", 351 | " colors=[(255, 0, 0), (0, 0, 255), (127, 0, 127)],\n", 352 | " titles=[\"original\", \"ground truth\", \"pretrained\", \"finetuned\"],\n", 353 | " dpi=250,\n", 354 | ")" 355 | ] 356 | }, 357 | { 358 | "cell_type": "markdown", 359 | "metadata": {}, 360 | "source": [ 361 | "## Optional: Delete the endpoint and model\n", 362 | "\n", 363 | "When you are done with the endpoint, you should clean it up.\n", 364 | "\n", 365 | "All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account.\n" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "finetuned_detector.delete_model()\n", 375 | "finetuned_detector.delete_endpoint()" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": {}, 381 | "source": [ 382 | "### [Click here to continue](./2_detection_from_scratch.ipynb)" 383 | ] 384 | } 385 | ], 386 | "metadata": { 387 | "kernelspec": { 388 | "display_name": "Python 
3.6.10 64-bit ('pytorch_latest_p36': conda)", 389 | "name": "python361064bitpytorchlatestp36conda2dfac45b320c45f3a4d1e89ca46b60d1" 390 | }, 391 | "language_info": { 392 | "codemirror_mode": { 393 | "name": "ipython", 394 | "version": 3 395 | }, 396 | "file_extension": ".py", 397 | "mimetype": "text/x-python", 398 | "name": "python", 399 | "nbconvert_exporter": "python", 400 | "pygments_lexer": "ipython3", 401 | "version": "3.6.10-final" 402 | } 403 | }, 404 | "nbformat": 4, 405 | "nbformat_minor": 2 406 | } 407 | -------------------------------------------------------------------------------- /notebooks/2_detection_from_scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Jupyter Kernel**:\n", 8 | "\n", 9 | "* If you are in SageMaker Notebook instance, please make sure you are using **conda_pytorch_latest_p36** kernel\n", 10 | "* If you are on SageMaker Studio, please make sure you are using **SageMaker JumpStart PyTorch 1.0** kernel\n", 11 | "\n", 12 | "**Run All**:\n", 13 | "\n", 14 | "* If you are in SageMaker notebook instance, you can go to *Cell tab -> Run All*\n", 15 | "* If you are in SageMaker Studio, you can go to *Run tab -> Run All Cells*\n", 16 | "\n", 17 | "**Note**: To *Run All* successfully, make sure you have executed the entire demo notebook `0_demo.ipynb` first.\n", 18 | "\n", 19 | "## Training our Detector from Scratch\n", 20 | "\n", 21 | "In this notebook, we will see how to train our detector from scratch" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import json\n", 31 | "\n", 32 | "import sagemaker\n", 33 | "from sagemaker.s3 import S3Downloader\n", 34 | "\n", 35 | "sagemaker_session = sagemaker.Session()\n", 36 | "sagemaker_config = json.load(open(\"../stack_outputs.json\"))\n", 37 | "role = sagemaker_config[\"IamRole\"]\n", 38 | "solution_bucket = sagemaker_config[\"SolutionS3Bucket\"]\n", 39 | "region = sagemaker_config[\"AWSRegion\"]\n", 40 | "solution_name = sagemaker_config[\"SolutionName\"]\n", 41 | "bucket = sagemaker_config[\"S3Bucket\"]" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "First, we download our **NEU-DET** dataset from our public S3 bucket" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "original_bucket = f\"s3://{solution_bucket}-{region}/{solution_name}\"\n", 58 | "original_pretained_checkpoint = f\"{original_bucket}/pretrained\"\n", 59 | "original_sources = f\"{original_bucket}/build/lib/source_dir.tar.gz\"" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "Note that for easiler data processing, we have already executed `prepare_data` once in our `0_demo.ipynb` and have already uploaded the prepared data to our S3 bucket" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": { 73 | "tags": [] 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "DATA_PATH = !echo $PWD/neu_det\n", 78 | "DATA_PATH = DATA_PATH.n" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "After data preparation, we need to setup some paths that will be used throughtout the notebook" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | 
"metadata": { 92 | "tags": [] 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "prefix = \"neu-det\"\n", 97 | "neu_det_s3 = f\"s3://{bucket}/{prefix}\"\n", 98 | "sources = f\"{neu_det_s3}/code/\"\n", 99 | "train_output = f\"{neu_det_s3}/output/\"\n", 100 | "neu_det_prepared_s3 = f\"{neu_det_s3}/data/\"\n", 101 | "s3_checkpoint = f\"{neu_det_s3}/checkpoint/\"\n", 102 | "sm_local_checkpoint_dir = \"/opt/ml/checkpoints/\"\n", 103 | "s3_pretrained = f\"{neu_det_s3}/pretrained/\"" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "## Visualization\n", 111 | "\n", 112 | "Let examine some datasets that we will use later by providing an `ID`" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": { 119 | "tags": [] 120 | }, 121 | "outputs": [], 122 | "source": [ 123 | "import copy\n", 124 | "from typing import List\n", 125 | "\n", 126 | "import numpy as np\n", 127 | "import torch\n", 128 | "from PIL import Image\n", 129 | "from torch.utils.data import DataLoader\n", 130 | "\n", 131 | "try:\n", 132 | " import sagemaker_defect_detection\n", 133 | "except ImportError:\n", 134 | " import sys\n", 135 | " from pathlib import Path\n", 136 | "\n", 137 | " ROOT = Path(\"../src\").resolve()\n", 138 | " sys.path.insert(0, str(ROOT))\n", 139 | "\n", 140 | "from sagemaker_defect_detection import NEUDET, get_preprocess\n", 141 | "\n", 142 | "SPLIT = \"test\"\n", 143 | "ID = 50\n", 144 | "dataset = NEUDET(DATA_PATH, split=SPLIT, preprocess=get_preprocess())\n", 145 | "images, targets, _ = dataset[ID]\n", 146 | "original_image = copy.deepcopy(images)\n", 147 | "original_boxes = targets[\"boxes\"].numpy().copy()\n", 148 | "original_labels = targets[\"labels\"].numpy().copy()\n", 149 | "print(f\"first images size: {original_image.shape}\")\n", 150 | "print(f\"target bounding boxes: \\n {original_boxes}\")\n", 151 | "print(f\"target labels: {original_labels}\")" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "And we can now visualize it using the provided utilities as follows" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "from sagemaker_defect_detection.utils.visualize import unnormalize_to_hwc, visualize\n", 168 | "\n", 169 | "original_image_unnorm = unnormalize_to_hwc(original_image)\n", 170 | "\n", 171 | "visualize(\n", 172 | " original_image_unnorm,\n", 173 | " [original_boxes],\n", 174 | " [original_labels],\n", 175 | " colors=[(255, 0, 0)],\n", 176 | " titles=[\"original\", \"ground truth\"],\n", 177 | ")" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "In order to get high **Mean Average Percision (mAP)** and **Mean Average Recall (mAR)** for **Intersection Over Union (IOU)** thresholds of 0.5 when training from scratch, it requires more than **300 epochs**. That is why we have provided a pretrained model and recommend finetuning whenever is possible. 
For demonstration, we train the model from scratch for 10 epochs, which takes about **16 minutes** and results in the following mAP, mAR and accumulated `main_score`:\n", 185 | "\n", 186 | "* `Average Precision (AP) @[ IoU=0.50:0.95 ] ~ 0.048`\n", 187 | "* `Average Recall (AR) @[ IoU=0.50:0.95] ~ 0.153`\n", 188 | "* `main_score=0.0509`\n", 189 | "\n", 190 | "To get higher mAP, mAR and overall `main_score`, you can train for more epochs" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "metadata": { 197 | "tags": [ 198 | "outputPrepend" 199 | ] 200 | }, 201 | "outputs": [], 202 | "source": [ 203 | "%%time\n", 204 | "import logging\n", 205 | "from os import path as osp\n", 206 | "\n", 207 | "from sagemaker.pytorch import PyTorch\n", 208 | "\n", 209 | "NUM_CLASSES = 7 # 6 classes + 1 for background\n", 210 | "BACKBONE = \"resnet34\" # has to match the pretrained model backbone\n", 211 | "assert BACKBONE in [\n", 212 | " \"resnet34\",\n", 213 | " \"resnet50\",\n", 214 | "], \"either resnet34 or resnet50. Make sure to be consistent with model_fn in detector.py\"\n", 215 | "EPOCHS = 10\n", 216 | "LEARNING_RATE = 1e-3\n", 217 | "SEED = 123\n", 218 | "\n", 219 | "hyperparameters = {\n", 220 | " \"backbone\": BACKBONE, # the backbone resnet model for feature extraction\n", 221 | " \"num-classes\": NUM_CLASSES, # number of classes 6 + 1 background\n", 222 | " \"epochs\": EPOCHS, # number of epochs to train\n", 223 | " \"learning-rate\": LEARNING_RATE, # learning rate for optimizer\n", 224 | " \"seed\": SEED, # random number generator seed\n", 225 | "}\n", 226 | "\n", 227 | "model = PyTorch(\n", 228 | " entry_point=\"detector.py\",\n", 229 | " source_dir=osp.join(sources, \"source_dir.tar.gz\"),\n", 230 | " role=role,\n", 231 | " train_instance_count=1,\n", 232 | " train_instance_type=\"ml.g4dn.2xlarge\",\n", 233 | " hyperparameters=hyperparameters,\n", 234 | " py_version=\"py3\",\n", 235 | " framework_version=\"1.5\",\n", 236 | " sagemaker_session=sagemaker_session,\n", 237 | " output_path=train_output,\n", 238 | " checkpoint_s3_uri=s3_checkpoint,\n", 239 | " checkpoint_local_path=sm_local_checkpoint_dir,\n", 240 | " # container_log_level=logging.DEBUG,\n", 241 | ")\n", 242 | "\n", 243 | "model.fit(neu_det_prepared_s3)" 244 | ] 245 | }, 246 | { 247 | "cell_type": "markdown", 248 | "metadata": {}, 249 | "source": [ 250 | "Then, we deploy our model which takes about **10 minutes** to complete" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": { 257 | "tags": [] 258 | }, 259 | "outputs": [], 260 | "source": [ 261 | "%%time\n", 262 | "detector = model.deploy(\n", 263 | " initial_instance_count=1,\n", 264 | " instance_type=\"ml.m5.xlarge\",\n", 265 | " endpoint_name=sagemaker_config[\"SolutionPrefix\"] + \"-detector-from-scratch-endpoint\",\n", 266 | ")" 267 | ] 268 | }, 269 | { 270 | "cell_type": "markdown", 271 | "metadata": {}, 272 | "source": [ 273 | "## Inference\n", 274 | "\n", 275 | "We change the input depending on whether we are providing a list of images or a single image. 
Also the model requires a four dimensional array / tensor (with the first dimension as batch)" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "input = list(img.numpy() for img in images) if isinstance(images, list) else images.unsqueeze(0).numpy()" 285 | ] 286 | }, 287 | { 288 | "cell_type": "markdown", 289 | "metadata": {}, 290 | "source": [ 291 | "Now the input is ready and we can get some results" 292 | ] 293 | }, 294 | { 295 | "cell_type": "code", 296 | "execution_count": null, 297 | "metadata": { 298 | "tags": [] 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "%%time\n", 303 | "# SageMaker 1.x doesn't allow_pickle=True by default\n", 304 | "np_load_old = np.load\n", 305 | "np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)\n", 306 | "predictions = detector.predict(input)\n", 307 | "np.load = np_load_old" 308 | ] 309 | }, 310 | { 311 | "cell_type": "markdown", 312 | "metadata": {}, 313 | "source": [ 314 | "We use our `visualize` utility to check the detection results" 315 | ] 316 | }, 317 | { 318 | "cell_type": "code", 319 | "execution_count": null, 320 | "metadata": {}, 321 | "outputs": [], 322 | "source": [ 323 | "visualize(\n", 324 | " original_image_unnorm,\n", 325 | " [original_boxes, predictions[0][\"boxes\"]],\n", 326 | " [original_labels, predictions[0][\"labels\"]],\n", 327 | " colors=[(255, 0, 0), (0, 0, 255)],\n", 328 | " titles=[\"original\", \"ground truth\", \"trained from scratch\"],\n", 329 | " dpi=200,\n", 330 | ")" 331 | ] 332 | }, 333 | { 334 | "cell_type": "markdown", 335 | "metadata": {}, 336 | "source": [ 337 | "Here we want to compare the results of the new model and the pretrained model that we already deployed in `0_demo.ipynb` visually by calling our endpoint from SageMaker runtime using `boto3`" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": null, 343 | "metadata": {}, 344 | "outputs": [], 345 | "source": [ 346 | "import boto3\n", 347 | "import botocore\n", 348 | "\n", 349 | "config = botocore.config.Config(read_timeout=200)\n", 350 | "runtime = boto3.client(\"runtime.sagemaker\", config=config)\n", 351 | "payload = json.dumps(input.tolist() if isinstance(input, np.ndarray) else input)\n", 352 | "response = runtime.invoke_endpoint(\n", 353 | " EndpointName=sagemaker_config[\"SolutionPrefix\"] + \"-demo-endpoint\", ContentType=\"application/json\", Body=payload\n", 354 | ")\n", 355 | "demo_predictions = json.loads(response[\"Body\"].read().decode())" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "Finally, we compare the results of the provided pretrained model and our trained from scratch" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": null, 368 | "metadata": {}, 369 | "outputs": [], 370 | "source": [ 371 | "visualize(\n", 372 | " original_image_unnorm,\n", 373 | " [original_boxes, demo_predictions[0][\"boxes\"], predictions[0][\"boxes\"]],\n", 374 | " [original_labels, demo_predictions[0][\"labels\"], predictions[0][\"labels\"]],\n", 375 | " colors=[(255, 0, 0), (0, 0, 255), (127, 0, 127)],\n", 376 | " titles=[\"original\", \"ground truth\", \"pretrained\", \"from scratch\"],\n", 377 | " dpi=250,\n", 378 | ")" 379 | ] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": {}, 384 | "source": [ 385 | "## Optional: Delete the endpoint and model\n", 386 | "\n", 387 | "When you are done with the 
endpoint, you should clean it up.\n", 388 | "\n", 389 | "All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account." 390 | ] 391 | }, 392 | { 393 | "cell_type": "code", 394 | "execution_count": null, 395 | "metadata": {}, 396 | "outputs": [], 397 | "source": [ 398 | "detector.delete_model()\n", 399 | "detector.delete_endpoint()" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "### [Click here to continue](./3_classification_from_scratch.ipynb)" 407 | ] 408 | } 409 | ], 410 | "metadata": { 411 | "kernelspec": { 412 | "display_name": "Python 3.6.10 64-bit ('pytorch_latest_p36': conda)", 413 | "metadata": { 414 | "interpreter": { 415 | "hash": "4c1e195df8d07db5ee7a78f454b46c3f2e14214bf8c9489d2db5cf8f372ff2ed" 416 | } 417 | }, 418 | "name": "python3" 419 | }, 420 | "language_info": { 421 | "codemirror_mode": { 422 | "name": "ipython", 423 | "version": 3 424 | }, 425 | "file_extension": ".py", 426 | "mimetype": "text/x-python", 427 | "name": "python", 428 | "nbconvert_exporter": "python", 429 | "pygments_lexer": "ipython3", 430 | "version": "3.6.10-final" 431 | } 432 | }, 433 | "nbformat": 4, 434 | "nbformat_minor": 2 435 | } 436 | -------------------------------------------------------------------------------- /notebooks/3_classification_from_scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "**Jupyter Kernel**:\n", 8 | "\n", 9 | "* If you are in SageMaker Notebook instance, please make sure you are using **conda_pytorch_latest_p36** kernel\n", 10 | "* If you are on SageMaker Studio, please make sure you are using **SageMaker JumpStart PyTorch 1.0** kernel\n", 11 | "\n", 12 | "**Run All**:\n", 13 | "\n", 14 | "* If you are in SageMaker notebook instance, you can go to *Cell tab -> Run All*\n", 15 | "* If you are in SageMaker Studio, you can go to *Run tab -> Run All Cells*\n", 16 | "\n", 17 | "## Training our Classifier from scratch\n", 18 | "\n", 19 | "Depending on an application, sometimes image classification is enough. 
In this notebook, we see how to train and deploy an accurate classifier from scratch on the **NEU-CLS** dataset" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "tags": [] 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "import json\n", 31 | "\n", 32 | "import numpy as np\n", 33 | "\n", 34 | "import sagemaker\n", 35 | "from sagemaker.s3 import S3Downloader\n", 36 | "\n", 37 | "sagemaker_session = sagemaker.Session()\n", 38 | "sagemaker_config = json.load(open(\"../stack_outputs.json\"))\n", 39 | "role = sagemaker_config[\"IamRole\"]\n", 40 | "solution_bucket = sagemaker_config[\"SolutionS3Bucket\"]\n", 41 | "region = sagemaker_config[\"AWSRegion\"]\n", 42 | "solution_name = sagemaker_config[\"SolutionName\"]\n", 43 | "bucket = sagemaker_config[\"S3Bucket\"]" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "metadata": {}, 49 | "source": [ 50 | "First, we download our **NEU-CLS** dataset from our public S3 bucket" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "from sagemaker.s3 import S3Downloader\n", 60 | "\n", 61 | "original_bucket = f\"s3://{solution_bucket}-{region}/{solution_name}\"\n", 62 | "original_data = f\"{original_bucket}/data/NEU-CLS.zip\"\n", 63 | "original_sources = f\"{original_bucket}/build/lib/source_dir.tar.gz\"\n", 64 | "print(\"original data: \")\n", 65 | "S3Downloader.list(original_data)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "For easier data processing, depending on the dataset, we unify the class and label names using the scripts from `prepare_data`" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "tags": [] 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "%%time\n", 84 | "RAW_DATA_PATH= !echo $PWD/raw_neu_cls\n", 85 | "RAW_DATA_PATH = RAW_DATA_PATH.n\n", 86 | "DATA_PATH = !echo $PWD/neu_cls\n", 87 | "DATA_PATH = DATA_PATH.n\n", 88 | "\n", 89 | "!mkdir -p $RAW_DATA_PATH\n", 90 | "!aws s3 cp $original_data $RAW_DATA_PATH\n", 91 | "\n", 92 | "!mkdir -p $DATA_PATH\n", 93 | "!python ../src/prepare_data/neu.py $RAW_DATA_PATH/NEU-CLS.zip $DATA_PATH" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "After data preparation, we need to set up some paths that will be used throughout the notebook" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "prefix = \"neu-cls\"\n", 110 | "neu_cls_s3 = f\"s3://{bucket}/{prefix}\"\n", 111 | "sources = f\"{neu_cls_s3}/code/\"\n", 112 | "train_output = f\"{neu_cls_s3}/output/\"\n", 113 | "neu_cls_prepared_s3 = f\"{neu_cls_s3}/data/\"\n", 114 | "!aws s3 sync $DATA_PATH $neu_cls_prepared_s3 --quiet # remove the --quiet flag to view sync outputs\n", 115 | "s3_checkpoint = f\"{neu_cls_s3}/checkpoint/\"\n", 116 | "sm_local_checkpoint_dir = \"/opt/ml/checkpoints/\"\n", 117 | "!aws s3 cp $original_sources $sources" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "## Visualization\n", 125 | "\n", 126 | "Let's examine some samples from the dataset that we will use later by providing an `ID`" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "tags": [] 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "import matplotlib.pyplot as plt\n", 138 |
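As a quick sanity check on the prepared classification data, one can count how many images fall into each defect class before training. A small sketch (illustrative; it assumes the `NEUCLS` dataset exposes `__len__` and returns `(image, label)` pairs, as the cell below relies on):

from collections import Counter

from sagemaker_defect_detection import NEUCLS

train_dataset = NEUCLS(DATA_PATH, split="train")
label_counts = Counter(train_dataset[i][1] for i in range(len(train_dataset)))
print(label_counts)  # expect the samples to be spread over the 6 defect classes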
"\n", 139 | "%matplotlib inline\n", 140 | "import numpy as np\n", 141 | "import torch\n", 142 | "from PIL import Image\n", 143 | "from torch.utils.data import DataLoader\n", 144 | "\n", 145 | "try:\n", 146 | " import sagemaker_defect_detection\n", 147 | "except ImportError:\n", 148 | " import sys\n", 149 | " from pathlib import Path\n", 150 | "\n", 151 | " ROOT = Path(\"../src\").resolve()\n", 152 | " sys.path.insert(0, str(ROOT))\n", 153 | "\n", 154 | "from sagemaker_defect_detection import NEUCLS\n", 155 | "\n", 156 | "\n", 157 | "def visualize(image, label, predicted=None):\n", 158 | " if not isinstance(image, Image.Image):\n", 159 | " image = Image.fromarray(image)\n", 160 | "\n", 161 | " plt.figure(dpi=120)\n", 162 | " if predicted is not None:\n", 163 | " plt.title(f\"label: {label}, prediction: {predicted}\")\n", 164 | " else:\n", 165 | " plt.title(f\"label: {label}\")\n", 166 | "\n", 167 | " plt.axis(\"off\")\n", 168 | " plt.imshow(image)\n", 169 | " return\n", 170 | "\n", 171 | "\n", 172 | "dataset = NEUCLS(DATA_PATH, split=\"train\")\n", 173 | "ID = 0\n", 174 | "assert 0 <= ID <= 300\n", 175 | "image, label = dataset[ID]\n", 176 | "visualize(image, label)" 177 | ] 178 | }, 179 | { 180 | "cell_type": "markdown", 181 | "metadata": {}, 182 | "source": [ 183 | "We train our model with `resnet34` backbone for **50 epochs** and obtains about **99%** test accuracy, f1-score, precision and recall as follows" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": { 190 | "tags": [ 191 | "outputPrepend" 192 | ] 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "%%time\n", 197 | "import logging\n", 198 | "from os import path as osp\n", 199 | "\n", 200 | "from sagemaker.pytorch import PyTorch\n", 201 | "\n", 202 | "NUM_CLASSES = 6\n", 203 | "BACKBONE = \"resnet34\"\n", 204 | "assert BACKBONE in [\n", 205 | " \"resnet34\",\n", 206 | " \"resnet50\",\n", 207 | "], \"either resnet34 or resnet50. 
Make sure to be consistent with model_fn in classifier.py\"\n", 208 | "EPOCHS = 50\n", 209 | "SEED = 123\n", 210 | "\n", 211 | "hyperparameters = {\n", 212 | " \"backbone\": BACKBONE,\n", 213 | " \"num-classes\": NUM_CLASSES,\n", 214 | " \"epochs\": EPOCHS,\n", 215 | " \"seed\": SEED,\n", 216 | "}\n", 217 | "\n", 218 | "assert not isinstance(sagemaker_session, sagemaker.LocalSession), \"local session as share memory cannot be altered\"\n", 219 | "\n", 220 | "model = PyTorch(\n", 221 | " entry_point=\"classifier.py\",\n", 222 | " source_dir=osp.join(sources, \"source_dir.tar.gz\"),\n", 223 | " role=role,\n", 224 | " train_instance_count=1,\n", 225 | " train_instance_type=\"ml.g4dn.2xlarge\",\n", 226 | " hyperparameters=hyperparameters,\n", 227 | " py_version=\"py3\",\n", 228 | " framework_version=\"1.5\",\n", 229 | " sagemaker_session=sagemaker_session, # Note: Do not use local session as share memory cannot be altered\n", 230 | " output_path=train_output,\n", 231 | " checkpoint_s3_uri=s3_checkpoint,\n", 232 | " checkpoint_local_path=sm_local_checkpoint_dir,\n", 233 | " # container_log_level=logging.DEBUG,\n", 234 | ")\n", 235 | "\n", 236 | "model.fit(neu_cls_prepared_s3)" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "Then, we deploy our model which takes about **8 minutes** to complete" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": null, 249 | "metadata": { 250 | "tags": [] 251 | }, 252 | "outputs": [], 253 | "source": [ 254 | "%%time\n", 255 | "predictor = model.deploy(\n", 256 | " initial_instance_count=1,\n", 257 | " instance_type=\"ml.m5.xlarge\",\n", 258 | " endpoint_name=sagemaker_config[\"SolutionPrefix\"] + \"-classification-endpoint\",\n", 259 | ")" 260 | ] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": {}, 265 | "source": [ 266 | "## Inference\n", 267 | "\n", 268 | "We are ready to test our model by providing some test data and compare the actual labels with the predicted one" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": null, 274 | "metadata": { 275 | "tags": [] 276 | }, 277 | "outputs": [], 278 | "source": [ 279 | "from sagemaker_defect_detection import get_transform\n", 280 | "from sagemaker_defect_detection.utils.visualize import unnormalize_to_hwc\n", 281 | "\n", 282 | "ID = 100\n", 283 | "assert 0 <= ID <= 300\n", 284 | "test_dataset = NEUCLS(DATA_PATH, split=\"test\", transform=get_transform(\"test\"), seed=SEED)\n", 285 | "image, label = test_dataset[ID]\n", 286 | "# SageMaker 1.x doesn't allow_pickle=True by default\n", 287 | "np_load_old = np.load\n", 288 | "np.load = lambda *args, **kwargs: np_load_old(*args, allow_pickle=True, **kwargs)\n", 289 | "outputs = predictor.predict(image.unsqueeze(0).numpy())\n", 290 | "np.load = np_load_old\n", 291 | "\n", 292 | "_, predicted = torch.max(torch.from_numpy(np.array(outputs)), 1)\n", 293 | "\n", 294 | "image_unnorm = unnormalize_to_hwc(image)\n", 295 | "visualize(image_unnorm, label, predicted.item())" 296 | ] 297 | }, 298 | { 299 | "cell_type": "markdown", 300 | "metadata": {}, 301 | "source": [ 302 | "## Optional: Delete the endpoint and model\n", 303 | "\n", 304 | "When you are done with the endpoint, you should clean it up.\n", 305 | "\n", 306 | "All of the training jobs, models and endpoints we created can be viewed through the SageMaker console of your AWS account." 
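For reference, the class prediction above is simply the arg-max over the endpoint's per-class outputs. A short sketch that additionally turns the raw scores into a softmax "confidence" (illustrative only; it assumes `outputs` holds one row of class scores per image, which is what the cell above receives for a single test image):

import numpy as np
import torch
import torch.nn.functional as F

logits = torch.from_numpy(np.array(outputs))    # e.g. shape (1, 6): one row of class scores
probs = F.softmax(logits, dim=1)                # normalize the scores into probabilities
confidence, predicted_class = probs.max(dim=1)  # best class and its probability
print(predicted_class.item(), confidence.item())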
307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": null, 312 | "metadata": {}, 313 | "outputs": [], 314 | "source": [ 315 | "predictor.delete_model()\n", 316 | "predictor.delete_endpoint()" 317 | ] 318 | } 319 | ], 320 | "metadata": { 321 | "kernelspec": { 322 | "display_name": "Python 3.6.10 64-bit", 323 | "metadata": { 324 | "interpreter": { 325 | "hash": "4c1e195df8d07db5ee7a78f454b46c3f2e14214bf8c9489d2db5cf8f372ff2ed" 326 | } 327 | }, 328 | "name": "python3" 329 | }, 330 | "language_info": { 331 | "codemirror_mode": { 332 | "name": "ipython", 333 | "version": 3 334 | }, 335 | "file_extension": ".py", 336 | "mimetype": "text/x-python", 337 | "name": "python", 338 | "nbconvert_exporter": "python", 339 | "pygments_lexer": "ipython3", 340 | "version": "3.6.10-final" 341 | } 342 | }, 343 | "nbformat": 4, 344 | "nbformat_minor": 2 345 | } 346 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | 7 | [tool.black] 8 | line-length = 119 9 | exclude = ''' 10 | ( 11 | /( 12 | \.eggs 13 | | \.git 14 | | \.mypy_cache 15 | | build 16 | | dist 17 | ) 18 | ) 19 | ''' 20 | 21 | [tool.portray] 22 | modules = ["src/sagemaker_defect_detection", "src/prepare_data/neu.py"] 23 | docs_dir = "docs" 24 | 25 | [tool.portray.mkdocs.theme] 26 | favicon = "docs/sagemaker.png" 27 | logo = "docs/sagemaker.png" 28 | name = "material" 29 | palette = {primary = "dark blue"} 30 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | patool 2 | pyunpack 3 | torch 4 | torchvision 5 | albumentations 6 | pytorch_lightning 7 | pycocotools -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with python 3.8 3 | # To update, run: 4 | # 5 | # pip-compile requirements.in 6 | # 7 | absl-py 8 | # via tensorboard 9 | albumentations 10 | # via -r requirements.in 11 | cachetools 12 | # via google-auth 13 | certifi 14 | # via requests 15 | charset-normalizer 16 | # via requests 17 | cycler 18 | # via matplotlib 19 | cython 20 | # via pycocotools 21 | easyprocess 22 | # via pyunpack 23 | entrypoint2 24 | # via pyunpack 25 | fonttools 26 | # via matplotlib 27 | future 28 | # via 29 | # pytorch-lightning 30 | # torch 31 | google-auth 32 | # via 33 | # google-auth-oauthlib 34 | # tensorboard 35 | google-auth-oauthlib 36 | # via tensorboard 37 | grpcio 38 | # via tensorboard 39 | idna 40 | # via requests 41 | imageio 42 | # via 43 | # imgaug 44 | # scikit-image 45 | imgaug 46 | # via albumentations 47 | importlib-metadata 48 | # via markdown 49 | kiwisolver 50 | # via matplotlib 51 | markdown 52 | # via tensorboard 53 | matplotlib 54 | # via 55 | # imgaug 56 | # pycocotools 57 | networkx 58 | # via scikit-image 59 | numpy 60 | # via 61 | # albumentations 62 | # imageio 63 | # imgaug 64 | # matplotlib 65 | # opencv-python 66 | # opencv-python-headless 67 | # pytorch-lightning 68 | # pywavelets 69 | # scikit-image 70 | # scipy 71 | # tensorboard 72 | # tifffile 73 | # torch 74 | # torchvision 75 | oauthlib 76 | # via requests-oauthlib 77 | opencv-python 78 | # via imgaug 79 | opencv-python-headless 80 | # via 
albumentations 81 | packaging 82 | # via 83 | # matplotlib 84 | # scikit-image 85 | patool 86 | # via -r requirements.in 87 | pillow 88 | # via 89 | # imageio 90 | # imgaug 91 | # matplotlib 92 | # scikit-image 93 | # torchvision 94 | protobuf 95 | # via tensorboard 96 | pyasn1 97 | # via 98 | # pyasn1-modules 99 | # rsa 100 | pyasn1-modules 101 | # via google-auth 102 | pycocotools 103 | # via -r requirements.in 104 | pyparsing 105 | # via 106 | # matplotlib 107 | # packaging 108 | python-dateutil 109 | # via matplotlib 110 | pytorch-lightning 111 | # via -r requirements.in 112 | pyunpack 113 | # via -r requirements.in 114 | pywavelets 115 | # via scikit-image 116 | pyyaml 117 | # via 118 | # albumentations 119 | # pytorch-lightning 120 | requests 121 | # via 122 | # requests-oauthlib 123 | # tensorboard 124 | requests-oauthlib 125 | # via google-auth-oauthlib 126 | rsa 127 | # via google-auth 128 | scikit-image 129 | # via imgaug 130 | scipy 131 | # via 132 | # albumentations 133 | # imgaug 134 | # scikit-image 135 | shapely 136 | # via imgaug 137 | six 138 | # via 139 | # google-auth 140 | # grpcio 141 | # imgaug 142 | # python-dateutil 143 | tensorboard 144 | # via pytorch-lightning 145 | tensorboard-data-server 146 | # via tensorboard 147 | tensorboard-plugin-wit 148 | # via tensorboard 149 | tifffile 150 | # via scikit-image 151 | torch 152 | # via 153 | # -r requirements.in 154 | # pytorch-lightning 155 | # torchvision 156 | torchvision 157 | # via -r requirements.in 158 | tqdm 159 | # via pytorch-lightning 160 | urllib3 161 | # via requests 162 | werkzeug 163 | # via tensorboard 164 | wheel 165 | # via tensorboard 166 | zipp 167 | # via importlib-metadata 168 | -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | NOW=$(date +"%x %r %Z") 5 | echo "Time: $NOW" 6 | 7 | if [ $# -lt 3 ]; then 8 | echo "Please provide the solution name as well as the base S3 bucket name and the region to run build script." 9 | echo "For example: bash ./scripts/build.sh trademarked-solution-name sagemaker-solutions-build us-west-2" 10 | exit 1 11 | fi 12 | 13 | REGION=$3 14 | SOURCE_REGION="us-west-2" 15 | echo "Region: $REGION, source region: $SOURCE_REGION" 16 | BASE_DIR="$(dirname "$(dirname "$(readlink -f "$0")")")" 17 | echo "Base dir: $BASE_DIR" 18 | 19 | rm -rf build 20 | 21 | # build python package 22 | echo "Python build and package" 23 | python -m pip install --upgrade pip 24 | python -m pip install --upgrade wheel setuptools 25 | python setup.py build sdist bdist_wheel 26 | 27 | find . | grep -E "(__pycache__|\.pyc|\.pyo$|\.egg*|\lightning_logs)" | xargs rm -rf 28 | 29 | echo "Add requirements for SageMaker" 30 | cp requirements.txt build/lib/ 31 | cd build/lib || exit 32 | mv sagemaker_defect_detection/{classifier.py,detector.py} . 33 | touch source_dir.tar.gz 34 | tar --exclude=source_dir.tar.gz -czvf source_dir.tar.gz . 35 | echo "Only keep source_dir.tar.gz for SageMaker" 36 | find . ! 
-name "source_dir.tar.gz" -type f -exec rm -r {} + 37 | rm -rf sagemaker_defect_detection 38 | 39 | cd - || exit 40 | 41 | mv dist build 42 | 43 | # add notebooks to build 44 | echo "Prepare notebooks and add to build" 45 | cp -r notebooks build 46 | rm -rf build/notebooks/*neu* # remove local datasets for build 47 | for nb in build/notebooks/*.ipynb; do 48 | python "$BASE_DIR"/scripts/set_kernelspec.py --notebook "$nb" --display-name "Python 3 (PyTorch JumpStart)" --kernel "HUB_1P_IMAGE" 49 | done 50 | 51 | echo "Copy src to build" 52 | cp -r src build 53 | 54 | # add solution assistant 55 | echo "Solution assistant lambda function" 56 | cd cloudformation/solution-assistant/ || exit 57 | python -m pip install -r requirements.txt -t ./src/site-packages 58 | 59 | cd - || exit 60 | 61 | echo "Clean up pyc files, needed to avoid security issues. See: https://blog.jse.li/posts/pyc/" 62 | find cloudformation | grep -E "(__pycache__|\.pyc|\.pyo$)" | xargs rm -rf 63 | cp -r cloudformation/solution-assistant build/ 64 | cd build/solution-assistant/src || exit 65 | zip -q -r9 "$BASE_DIR"/build/solution-assistant.zip -- * 66 | 67 | cd - || exit 68 | rm -rf build/solution-assistant 69 | 70 | if [ -z "$4" ] || [ "$4" == 'mainline' ]; then 71 | s3_prefix="s3://$2-$3/$1" 72 | else 73 | s3_prefix="s3://$2-$3/$1-$4" 74 | fi 75 | 76 | # cleanup and copy the build artifacts 77 | echo "Removing the existing objects under $s3_prefix" 78 | aws s3 rm --recursive "$s3_prefix" --region "$REGION" 79 | echo "Copying new objects to $s3_prefix" 80 | aws s3 sync . "$s3_prefix" --delete --region "$REGION" \ 81 | --exclude ".git/*" \ 82 | --exclude ".vscode/*" \ 83 | --exclude ".mypy_cache/*" \ 84 | --exclude "logs/*" \ 85 | --exclude "stack_outputs.json" \ 86 | --exclude "src/sagemaker_defect_detection/lightning_logs/*" \ 87 | --exclude "notebooks/*neu*/*" 88 | 89 | echo "Copying solution artifacts" 90 | aws s3 cp "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/demo/model.tar.gz" "$s3_prefix"/demo/model.tar.gz --source-region "$SOURCE_REGION" 91 | 92 | mkdir -p build/pretrained/ 93 | aws s3 cp "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/pretrained/model.tar.gz" build/pretrained --source-region "$SOURCE_REGION" && 94 | cd build/pretrained/ && tar -xf model.tar.gz && cd .. 
&& 95 | aws s3 sync pretrained "$s3_prefix"/pretrained/ --delete --region "$REGION" 96 | 97 | aws s3 cp "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/data/NEU-CLS.zip" "$s3_prefix"/data/ --source-region "$SOURCE_REGION" 98 | aws s3 cp "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/data/NEU-DET.zip" "$s3_prefix"/data/ --source-region "$SOURCE_REGION" 99 | 100 | echo "Add docs to build" 101 | aws s3 sync "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/docs" "$s3_prefix"/docs --delete --region "$REGION" 102 | aws s3 sync "s3://sagemaker-solutions-artifacts/sagemaker-defect-detection/docs" "$s3_prefix"/build/docs --delete --region "$REGION" 103 | -------------------------------------------------------------------------------- /scripts/find_best_ckpt.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | 4 | 5 | def get_score(s: str) -> float: 6 | """Gets the criterion score from .ckpt formated with ModelCheckpoint 7 | 8 | Parameters 9 | ---------- 10 | s : str 11 | Assumption is the last number is the desired number carved in .ckpt 12 | 13 | Returns 14 | ------- 15 | float 16 | The criterion float 17 | """ 18 | return float(re.findall(r"(\d+.\d+).ckpt", s)[0]) 19 | 20 | 21 | def main(path: str, op: str) -> str: 22 | """Finds the best ckpt path 23 | 24 | Parameters 25 | ---------- 26 | path : str 27 | ckpt path 28 | op : str 29 | "max" (for mAP for example) or "min" (for loss) 30 | 31 | Returns 32 | ------- 33 | str 34 | A ckpt path 35 | """ 36 | ckpts = list(map(str, Path(path).glob("*.ckpt"))) 37 | if not len(ckpts): 38 | return 39 | 40 | ckpt_score_dict = {ckpt: get_score(ckpt) for ckpt in ckpts} 41 | op = max if op == "max" else min 42 | out = op(ckpt_score_dict, key=ckpt_score_dict.get) 43 | print(out) # need to flush for bash 44 | return out 45 | 46 | 47 | if __name__ == "__main__": 48 | import sys 49 | 50 | if len(sys.argv) < 3: 51 | print("provide checkpoint path and op either max or min") 52 | sys.exit(1) 53 | 54 | main(sys.argv[1], sys.argv[2]) 55 | -------------------------------------------------------------------------------- /scripts/set_kernelspec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | 5 | def set_kernel_spec(notebook_filepath, display_name, kernel_name): 6 | with open(notebook_filepath, "r") as openfile: 7 | notebook = json.load(openfile) 8 | kernel_spec = {"display_name": display_name, "language": "python", "name": kernel_name} 9 | if "metadata" not in notebook: 10 | notebook["metadata"] = {} 11 | notebook["metadata"]["kernelspec"] = kernel_spec 12 | with open(notebook_filepath, "w") as openfile: 13 | json.dump(notebook, openfile) 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--notebook") 19 | parser.add_argument("--display-name") 20 | parser.add_argument("--kernel") 21 | args = parser.parse_args() 22 | set_kernel_spec(args.notebook, args.display_name, args.kernel) 23 | -------------------------------------------------------------------------------- /scripts/train_classifier.sh: -------------------------------------------------------------------------------- 1 | CLASSIFICATION_DATA_PATH=$1 2 | BACKBONE=$2 3 | BASE_DIR="$(dirname $(dirname $(readlink -f $0)))" 4 | CLASSIFIER_SCRIPT="${BASE_DIR}/src/sagemaker_defect_detection/classifier.py" 5 | MFN_LOGS="${BASE_DIR}/logs/" 6 | 7 | python ${CLASSIFIER_SCRIPT} 
\ 8 | --data-path=${CLASSIFICATION_DATA_PATH} \ 9 | --save-path=${MFN_LOGS} \ 10 | --backbone=${BACKBONE} \ 11 | --gpus=1 \ 12 | --learning-rate=1e-3 \ 13 | --epochs=50 14 | -------------------------------------------------------------------------------- /scripts/train_detector.sh: -------------------------------------------------------------------------------- 1 | NOW=$(date +"%x %r %Z") 2 | echo "Time: ${NOW}" 3 | DETECTION_DATA_PATH=$1 4 | BACKBONE=$2 5 | BASE_DIR="$(dirname $(dirname $(readlink -f $0)))" 6 | DETECTOR_SCRIPT="${BASE_DIR}/src/sagemaker_defect_detection/detector.py" 7 | LOG_DIR="${BASE_DIR}/logs" 8 | MFN_LOGS="${LOG_DIR}/classification_logs" 9 | RPN_LOGS="${LOG_DIR}/rpn_logs" 10 | ROI_LOGS="${LOG_DIR}/roi_logs" 11 | FINETUNED_RPN_LOGS="${LOG_DIR}/finetune_rpn_logs" 12 | FINETUNED_ROI_LOGS="${LOG_DIR}/finetune_roi_logs" 13 | FINETUNED_FINAL_LOGS="${LOG_DIR}/finetune_final_logs" 14 | EXTRA_FINETUNED_RPN_LOGS="${LOG_DIR}/extra_finetune_rpn_logs" 15 | EXTRA_FINETUNED_ROI_LOGS="${LOG_DIR}/extra_finetune_roi_logs" 16 | EXTRA_FINETUNING_STEPS=3 17 | 18 | function find_best_ckpt() { 19 | python ${BASE_DIR}/scripts/find_best_ckpt.py $1 $2 20 | } 21 | 22 | function train_step() { 23 | echo "training step $1" 24 | case $1 in 25 | "1") 26 | echo "skipping step 1 and use 'train_classifier.sh'" 27 | ;; 28 | "2") # train rpn 29 | python ${DETECTOR_SCRIPT} \ 30 | --data-path=${DETECTION_DATA_PATH} \ 31 | --backbone=${BACKBONE} \ 32 | --train-rpn \ 33 | --pretrained-mfn-ckpt=$(find_best_ckpt "${MFN_LOGS}" "max") \ 34 | --save-path=${RPN_LOGS} \ 35 | --gpus=-1 --distributed-backend=ddp \ 36 | --epochs=100 37 | ;; 38 | "3") # train roi 39 | python ${DETECTOR_SCRIPT} \ 40 | --data-path=${DETECTION_DATA_PATH} \ 41 | --backbone=${BACKBONE} \ 42 | --train-roi \ 43 | --pretrained-rpn-ckpt=$(find_best_ckpt "${RPN_LOGS}" "min") \ 44 | --save-path=${ROI_LOGS} \ 45 | --gpus=-1 --distributed-backend=ddp \ 46 | --epochs=100 47 | ;; 48 | "4") # finetune rpn 49 | python ${DETECTOR_SCRIPT} \ 50 | --data-path=${DETECTION_DATA_PATH} \ 51 | --backbone=${BACKBONE} \ 52 | --finetune-rpn \ 53 | --pretrained-rpn-ckpt=$(find_best_ckpt "${RPN_LOGS}" "min") \ 54 | --pretrained-roi-ckpt=$(find_best_ckpt "${ROI_LOGS}" "min") \ 55 | --save-path=${FINETUNED_RPN_LOGS} \ 56 | --gpus=-1 --distributed-backend=ddp \ 57 | --learning-rate=1e-4 \ 58 | --epochs=100 59 | # --resume-from-checkpoint=$(find_best_ckpt "${FINETUNED_RPN_LOGS}" "max") 60 | ;; 61 | "5") # finetune roi 62 | python ${DETECTOR_SCRIPT} \ 63 | --data-path=${DETECTION_DATA_PATH} \ 64 | --backbone=${BACKBONE} \ 65 | --finetune-roi \ 66 | --finetuned-rpn-ckpt=$(find_best_ckpt "${FINETUNED_RPN_LOGS}" "max") \ 67 | --pretrained-roi-ckpt=$(find_best_ckpt "${ROI_LOGS}" "min") \ 68 | --save-path=${FINETUNED_ROI_LOGS} \ 69 | --gpus=-1 --distributed-backend=ddp \ 70 | --learning-rate=1e-4 \ 71 | --epochs=100 72 | # --resume-from-checkpoint=$(find_best_ckpt "${FINETUNED_ROI_LOGS}" "max") 73 | ;; 74 | "extra_rpn") # initially EXTRA_FINETUNED_*_LOGS is a copy of FINETUNED_*_LOGS 75 | python ${DETECTOR_SCRIPT} \ 76 | --data-path=${DETECTION_DATA_PATH} \ 77 | --backbone=${BACKBONE} \ 78 | --finetune-rpn \ 79 | --finetuned-rpn-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_RPN_LOGS}" "max") \ 80 | --finetuned-roi-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_ROI_LOGS}" "max") \ 81 | --save-path=${EXTRA_FINETUNED_RPN_LOGS} \ 82 | --gpus=-1 --distributed-backend=ddp \ 83 | --learning-rate=1e-4 \ 84 | --epochs=100 85 | ;; 86 | "extra_roi") # initially EXTRA_FINETUNED_*_LOGS is a copy of 
FINETUNED_*_LOGS 87 | python ${DETECTOR_SCRIPT} \ 88 | --data-path=${DETECTION_DATA_PATH} \ 89 | --backbone=${BACKBONE} \ 90 | --finetune-roi \ 91 | --finetuned-rpn-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_RPN_LOGS}" "max") \ 92 | --finetuned-roi-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_ROI_LOGS}" "max") \ 93 | --save-path=${EXTRA_FINETUNED_ROI_LOGS} \ 94 | --gpus=-1 --distributed-backend=ddp \ 95 | --learning-rate=1e-4 \ 96 | --epochs=100 97 | ;; 98 | "joint") # final 99 | python ${DETECTOR_SCRIPT} \ 100 | --data-path=${DETECTION_DATA_PATH} \ 101 | --backbone=${BACKBONE} \ 102 | --finetuned-rpn-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_RPN_LOGS}" "max") \ 103 | --finetuned-roi-ckpt=$(find_best_ckpt "${EXTRA_FINETUNED_ROI_LOGS}" "max") \ 104 | --save-path="${FINETUNED_FINAL_LOGS}" \ 105 | --gpus=-1 --distributed-backend=ddp \ 106 | --learning-rate=1e-3 \ 107 | --epochs=300 108 | # --resume-from-checkpoint=$(find_best_ckpt "${FINETUNED_FINAL_LOGS}" "max") 109 | ;; 110 | 111 | *) ;; 112 | esac 113 | } 114 | 115 | function train_wait_to_finish() { 116 | train_step $1 & 117 | BPID=$! 118 | wait $BPID 119 | } 120 | 121 | function run() { 122 | if [ "$1" != "" ]; then 123 | train_step $1 124 | else 125 | nvidia-smi | grep python | awk '{ print $3 }' | xargs -n1 kill -9 >/dev/null 2>&1 126 | read -p "Training all steps from scratch? (Y/N): " confirm && [[ $confirm == [yY] || $confirm == [yY][eE][sS] ]] 127 | if [ "$confirm" == "Y" ]; then 128 | for i in {1..5}; do 129 | train_wait_to_finish $i 130 | done 131 | echo "finished all the training steps" 132 | mkdir -p "${EXTRA_FINETUNED_RPN_LOGS}" && cp -r "${FINETUNED_RPN_LOGS}/"* "${EXTRA_FINETUNED_RPN_LOGS}" 133 | mkdir -p "${EXTRA_FINETUNED_ROI_LOGS}" && cp -r "${FINETUNED_ROI_LOGS}/"* "${EXTRA_FINETUNED_ROI_LOGS}" 134 | fi 135 | echo "repeating extra finetuning steps ${EXTRA_FINETUNING_STEPS} more times" 136 | for i in {1..${EXTRA_FINETUNING_STEPS}}; do 137 | train_wait_to_finish "extra_rpn" 138 | train_wait_to_finish "extra_roi" 139 | done 140 | echo "final joint training" 141 | train_step "joint" 142 | fi 143 | exit 0 144 | } 145 | 146 | run $1 147 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [tool:pytest] 2 | norecursedirs = 3 | .git 4 | dist 5 | build 6 | python_files = 7 | test_*.py 8 | 9 | [metadata] 10 | # license_files = LICENSE 11 | 12 | [check-manifest] 13 | ignore = 14 | *.yaml 15 | .github 16 | .github/* 17 | build 18 | deploy 19 | notebook 20 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from setuptools import setup, find_packages 4 | 5 | ROOT = Path(__file__).parent.resolve() 6 | 7 | long_description = (ROOT / "README.md").read_text(encoding="utf-8") 8 | 9 | dev_dependencies = ["pre-commit", "mypy==0.781", "black==20.8b1", "nbstripout==0.3.7", "black-nb==0.3.0"] 10 | test_dependencies = ["pytest>=6.0"] 11 | doc_dependencies = ["portray>=1.4.0"] 12 | 13 | setup( 14 | name="sagemaker_defect_detection", 15 | version="0.1", 16 | description="Detect Defects in Products from their Images using Amazon SageMaker ", 17 | long_description=long_description, 18 | author="Ehsan M. 
Kermani", 19 | python_requires=">=3.6", 20 | package_dir={"": "src"}, 21 | packages=find_packages("src", exclude=["tests", "tests/*"]), 22 | install_requires=open(str(ROOT / "requirements.txt"), "r").read(), 23 | extras_require={"dev": dev_dependencies, "test": test_dependencies, "doc": doc_dependencies}, 24 | license="Apache License 2.0", 25 | classifiers=[ 26 | "Development Status :: 3 - Alpha", 27 | "Programming Language :: Python :: 3.6", 28 | "Programming Language :: Python :: 3.7", 29 | "Programming Language :: Python :: 3.8", 30 | "Programming Language :: Python :: 3 :: Only", 31 | ], 32 | ) 33 | -------------------------------------------------------------------------------- /src/im2rec.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Licensed to the Apache Software Foundation (ASF) under one 4 | # or more contributor license agreements. See the NOTICE file 5 | # distributed with this work for additional information 6 | # regarding copyright ownership. The ASF licenses this file 7 | # to you under the Apache License, Version 2.0 (the 8 | # "License"); you may not use this file except in compliance 9 | # with the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, 14 | # software distributed under the License is distributed on an 15 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 16 | # KIND, either express or implied. See the License for the 17 | # specific language governing permissions and limitations 18 | # under the License. 19 | 20 | from __future__ import print_function 21 | 22 | import argparse 23 | import os 24 | import random 25 | import sys 26 | import time 27 | import traceback 28 | 29 | import cv2 30 | import mxnet as mx 31 | 32 | 33 | curr_path = os.path.abspath(os.path.dirname(__file__)) 34 | sys.path.append(os.path.join(curr_path, "../python")) 35 | 36 | 37 | 38 | def list_image(root, recursive, exts): 39 | """Traverses the root of directory that contains images and generates image list iterator. 40 | 41 | Args: 42 | root (str): the directory to explore 43 | recursive (bool): whether subdirectories should be explored recursively 44 | exts (collection[str]): the file extensions to look for 45 | 46 | Returns: 47 | image iterator that contains all the image under the specified path 48 | """ 49 | 50 | i = 0 51 | if recursive: 52 | cat = {} 53 | for path, dirs, files in os.walk(root, followlinks=True): 54 | dirs.sort() 55 | files.sort() 56 | for fname in files: 57 | fpath = os.path.join(path, fname) 58 | suffix = os.path.splitext(fname)[1].lower() 59 | if os.path.isfile(fpath) and (suffix in exts): 60 | if path not in cat: 61 | cat[path] = len(cat) 62 | yield (i, os.path.relpath(fpath, root), cat[path]) 63 | i += 1 64 | for k, v in sorted(cat.items(), key=lambda x: x[1]): 65 | print(os.path.relpath(k, root), v) 66 | else: 67 | for fname in sorted(os.listdir(root)): 68 | fpath = os.path.join(root, fname) 69 | suffix = os.path.splitext(fname)[1].lower() 70 | if os.path.isfile(fpath) and (suffix in exts): 71 | yield (i, os.path.relpath(fpath, root), 0) 72 | i += 1 73 | 74 | 75 | def write_list(path_out, image_list): 76 | r"""Helper function to write image list into the file. 77 | 78 | The format is as below, integer_image_index \t float_label_index \t path_to_image 79 | Note that the blank between number and tab is only used for readability. 
80 | 81 | Args: 82 | path_out (str): the path of the file to write to 83 | image_list (list): objects in the images 84 | """ 85 | with open(path_out, "w") as fout: 86 | for i, item in enumerate(image_list): 87 | line = "%d\t" % item[0] 88 | for j in item[2:]: 89 | line += "%f\t" % j 90 | line += "%s\n" % item[1] 91 | fout.write(line) 92 | 93 | 94 | def make_list(args): 95 | """Generates .lst file. 96 | 97 | Args: 98 | args: object that contains all the arguments 99 | """ 100 | image_list = list_image(args.root, args.recursive, args.exts) 101 | image_list = list(image_list) 102 | if args.shuffle is True: 103 | random.seed(100) 104 | random.shuffle(image_list) 105 | N = len(image_list) 106 | chunk_size = (N + args.chunks - 1) // args.chunks 107 | for i in range(args.chunks): 108 | chunk = image_list[i * chunk_size : (i + 1) * chunk_size] 109 | if args.chunks > 1: 110 | str_chunk = "_%d" % i 111 | else: 112 | str_chunk = "" 113 | sep = int(chunk_size * args.train_ratio) 114 | sep_test = int(chunk_size * args.test_ratio) 115 | if args.train_ratio == 1.0: 116 | write_list(args.prefix + str_chunk + ".lst", chunk) 117 | else: 118 | if args.test_ratio: 119 | write_list(args.prefix + str_chunk + "_test.lst", chunk[:sep_test]) 120 | if args.train_ratio + args.test_ratio < 1.0: 121 | write_list(args.prefix + str_chunk + "_val.lst", chunk[sep_test + sep :]) 122 | write_list(args.prefix + str_chunk + "_train.lst", chunk[sep_test : sep_test + sep]) 123 | 124 | 125 | def read_list(path_in): 126 | """Reads the .lst file and generates corresponding iterator. 127 | 128 | Args: 129 | path_in (str) 130 | 131 | Returns: 132 | item iterator that contains information in .lst file 133 | """ 134 | with open(path_in) as fin: 135 | while True: 136 | line = fin.readline() 137 | if not line: 138 | break 139 | line = [i.strip() for i in line.strip().split("\t")] 140 | line_len = len(line) 141 | # check the data format of .lst file 142 | if line_len < 3: 143 | print("lst should have at least has three parts, but only has %s parts for %s" % (line_len, line)) 144 | continue 145 | try: 146 | item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]] 147 | except Exception as e: 148 | print("Parsing lst met error for %s, detail: %s" % (line, e)) 149 | continue 150 | yield item 151 | 152 | 153 | def image_encode(args, i, item, q_out): 154 | """Reads, preprocesses, packs the image and put it back in output queue. 
155 | 156 | Args" 157 | args (object): image 158 | i (int): image index 159 | item (list): labels 160 | q_out (queue): collection to store the image to 161 | """ 162 | fullpath = os.path.join(args.root, item[1]) 163 | 164 | if len(item) > 3 and args.pack_label: 165 | header = mx.recordio.IRHeader(0, item[2:], item[0], 0) 166 | else: 167 | header = mx.recordio.IRHeader(0, item[2], item[0], 0) 168 | 169 | if args.pass_through: 170 | try: 171 | with open(fullpath, "rb") as fin: 172 | img = fin.read() 173 | s = mx.recordio.pack(header, img) 174 | q_out.put((i, s, item)) 175 | except Exception as e: 176 | traceback.print_exc() 177 | print("pack_img error:", item[1], e) 178 | q_out.put((i, None, item)) 179 | return 180 | 181 | try: 182 | img = cv2.imread(fullpath, args.color) 183 | except Exception: 184 | traceback.print_exc() 185 | print("imread error trying to load file: %s " % fullpath) 186 | q_out.put((i, None, item)) 187 | return 188 | if img is None: 189 | print("imread read blank (None) image for file: %s" % fullpath) 190 | q_out.put((i, None, item)) 191 | return 192 | if args.center_crop: 193 | if img.shape[0] > img.shape[1]: 194 | margin = (img.shape[0] - img.shape[1]) // 2 195 | img = img[margin : margin + img.shape[1], :] 196 | else: 197 | margin = (img.shape[1] - img.shape[0]) // 2 198 | img = img[:, margin : margin + img.shape[0]] 199 | if args.resize: 200 | if img.shape[0] > img.shape[1]: 201 | newsize = (args.resize, img.shape[0] * args.resize // img.shape[1]) 202 | else: 203 | newsize = (img.shape[1] * args.resize // img.shape[0], args.resize) 204 | img = cv2.resize(img, newsize) 205 | 206 | try: 207 | s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding) 208 | q_out.put((i, s, item)) 209 | except Exception as e: 210 | traceback.print_exc() 211 | print("pack_img error on file: %s" % fullpath, e) 212 | q_out.put((i, None, item)) 213 | return 214 | 215 | 216 | def read_worker(args, q_in, q_out): 217 | """Function that will be spawned to fetch the image from the input queue and put it back to output queue. 218 | 219 | Args: 220 | args (object): image 221 | q_in (queue): input queue 222 | q_out (queue): output queue 223 | """ 224 | while True: 225 | deq = q_in.get() 226 | if deq is None: 227 | break 228 | i, item = deq 229 | image_encode(args, i, item, q_out) 230 | 231 | 232 | def write_worker(q_out, fname, working_dir): 233 | """Function that will be spawned to fetch processed image from the output queue and write to the .rec file. 234 | 235 | Args: 236 | q_out (queue): output queue 237 | fname (str): name of the files 238 | working_dir (str): name of directory to write to 239 | """ 240 | pre_time = time.time() 241 | count = 0 242 | fname = os.path.basename(fname) 243 | fname_rec = os.path.splitext(fname)[0] + ".rec" 244 | fname_idx = os.path.splitext(fname)[0] + ".idx" 245 | record = mx.recordio.MXIndexedRecordIO( 246 | os.path.join(working_dir, fname_idx), os.path.join(working_dir, fname_rec), "w" 247 | ) 248 | buf = {} 249 | more = True 250 | while more: 251 | deq = q_out.get() 252 | if deq is not None: 253 | i, s, item = deq 254 | buf[i] = (s, item) 255 | else: 256 | more = False 257 | while count in buf: 258 | s, item = buf[count] 259 | del buf[count] 260 | if s is not None: 261 | record.write_idx(item[0], s) 262 | 263 | if count % 1000 == 0: 264 | cur_time = time.time() 265 | print("time:", cur_time - pre_time, " count:", count) 266 | pre_time = cur_time 267 | count += 1 268 | 269 | 270 | def parse_args(): 271 | """Defines all arguments. 
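    Typical invocations (placeholders in angle brackets; the second form mirrors the
    call made from prepare_RecordIO.py):
        python im2rec.py <prefix> <root> --list --recursive --train-ratio 0.8   # write .lst file(s)
        python im2rec.py <prefix> <root> --pass-through --pack-label            # build .rec/.idx from an existing .lst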
272 | 273 | Returns: 274 | args object that contains all the params 275 | """ 276 | parser = argparse.ArgumentParser( 277 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 278 | description="Create an image list or \ 279 | make a record database by reading from an image list", 280 | ) 281 | parser.add_argument("prefix", help="prefix of input/output lst and rec files.") 282 | parser.add_argument("root", help="path to folder containing images.") 283 | 284 | cgroup = parser.add_argument_group("Options for creating image lists") 285 | cgroup.add_argument( 286 | "--list", 287 | action="store_true", 288 | help="If this is set im2rec will create image list(s) by traversing root folder\ 289 | and output to .lst.\ 290 | Otherwise im2rec will read .lst and create a database at .rec", 291 | ) 292 | cgroup.add_argument( 293 | "--exts", nargs="+", default=[".jpeg", ".jpg", ".png"], help="list of acceptable image extensions." 294 | ) 295 | cgroup.add_argument("--chunks", type=int, default=1, help="number of chunks.") 296 | cgroup.add_argument("--train-ratio", type=float, default=1.0, help="Ratio of images to use for training.") 297 | cgroup.add_argument("--test-ratio", type=float, default=0, help="Ratio of images to use for testing.") 298 | cgroup.add_argument( 299 | "--recursive", 300 | action="store_true", 301 | help="If true recursively walk through subdirs and assign an unique label\ 302 | to images in each folder. Otherwise only include images in the root folder\ 303 | and give them label 0.", 304 | ) 305 | cgroup.add_argument( 306 | "--no-shuffle", 307 | dest="shuffle", 308 | action="store_false", 309 | help="If this is passed, \ 310 | im2rec will not randomize the image order in .lst", 311 | ) 312 | rgroup = parser.add_argument_group("Options for creating database") 313 | rgroup.add_argument( 314 | "--pass-through", action="store_true", help="whether to skip transformation and save image as is" 315 | ) 316 | rgroup.add_argument( 317 | "--resize", 318 | type=int, 319 | default=0, 320 | help="resize the shorter edge of image to the newsize, original images will\ 321 | be packed by default.", 322 | ) 323 | rgroup.add_argument( 324 | "--center-crop", action="store_true", help="specify whether to crop the center image to make it rectangular." 325 | ) 326 | rgroup.add_argument( 327 | "--quality", type=int, default=95, help="JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9" 328 | ) 329 | rgroup.add_argument( 330 | "--num-thread", 331 | type=int, 332 | default=1, 333 | help="number of thread to use for encoding. order of images will be different\ 334 | from the input list if >1. the input list will be modified to match the\ 335 | resulting order.", 336 | ) 337 | rgroup.add_argument( 338 | "--color", 339 | type=int, 340 | default=1, 341 | choices=[-1, 0, 1], 342 | help="specify the color mode of the loaded image.\ 343 | 1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\ 344 | 0: Loads image in grayscale mode.\ 345 | -1:Loads image as such including alpha channel.", 346 | ) 347 | rgroup.add_argument( 348 | "--encoding", type=str, default=".jpg", choices=[".jpg", ".png"], help="specify the encoding of the images." 
349 | ) 350 | rgroup.add_argument( 351 | "--pack-label", action="store_true", help="Whether to also pack multi dimensional label in the record file" 352 | ) 353 | args = parser.parse_args() 354 | 355 | args.prefix = os.path.abspath(args.prefix) 356 | args.root = os.path.abspath(args.root) 357 | return args 358 | 359 | 360 | if __name__ == "__main__": 361 | args = parse_args() 362 | 363 | # if the '--list' is used, it generates .lst file 364 | if args.list: 365 | make_list(args) 366 | # otherwise read .lst file to generates .rec file 367 | else: 368 | if os.path.isdir(args.prefix): 369 | working_dir = args.prefix 370 | else: 371 | working_dir = os.path.dirname(args.prefix) 372 | files = [ 373 | os.path.join(working_dir, fname) 374 | for fname in os.listdir(working_dir) 375 | if os.path.isfile(os.path.join(working_dir, fname)) 376 | ] 377 | count = 0 378 | for fname in files: 379 | if fname.startswith(args.prefix) and fname.endswith(".lst"): 380 | print("Creating .rec file from", fname, "in", working_dir) 381 | count += 1 382 | image_list = read_list(fname) 383 | # -- write_record -- # 384 | try: 385 | import Queue as queue 386 | except ImportError: 387 | import queue 388 | q_out = queue.Queue() 389 | fname = os.path.basename(fname) 390 | fname_rec = os.path.splitext(fname)[0] + ".rec" 391 | fname_idx = os.path.splitext(fname)[0] + ".idx" 392 | record = mx.recordio.MXIndexedRecordIO( 393 | os.path.join(working_dir, fname_idx), os.path.join(working_dir, fname_rec), "w" 394 | ) 395 | cnt = 0 396 | pre_time = time.time() 397 | for i, item in enumerate(image_list): 398 | image_encode(args, i, item, q_out) 399 | if q_out.empty(): 400 | continue 401 | _, s, _ = q_out.get() 402 | record.write_idx(item[0], s) 403 | if cnt % 1000 == 0: 404 | cur_time = time.time() 405 | print("time:", cur_time - pre_time, " count:", cnt) 406 | pre_time = cur_time 407 | cnt += 1 408 | if not count: 409 | print("Did not find and list file with prefix %s" % args.prefix) 410 | -------------------------------------------------------------------------------- /src/prepare_RecordIO.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import argparse 5 | import numpy as np 6 | from collections import defaultdict 7 | from pathlib import Path 8 | 9 | 10 | def write_line(img_path, width, height, boxes, ids, idx): 11 | """Create a line for each image with annotations, width, height and image name.""" 12 | # for header, we use minimal length 2, plus width and height 13 | # with A: 4, B: 5, C: width, D: height 14 | A = 4 15 | B = 5 16 | C = width 17 | D = height 18 | # concat id and bboxes 19 | labels = np.hstack((ids.reshape(-1, 1), boxes)).astype("float") 20 | # normalized bboxes (recommanded) 21 | labels[:, (1, 3)] /= float(width) 22 | labels[:, (2, 4)] /= float(height) 23 | # flatten 24 | labels = labels.flatten().tolist() 25 | str_idx = [str(idx)] 26 | str_header = [str(x) for x in [A, B, C, D]] 27 | str_labels = [str(x) for x in labels] 28 | str_path = [img_path] 29 | line = "\t".join(str_idx + str_header + str_labels + str_path) + "\n" 30 | return line 31 | 32 | 33 | # adapt from __main__ from im2rec.py 34 | def write_lst(output_file, ids, images_annotations): 35 | 36 | all_labels = set() 37 | image_info = {} 38 | for entry in images_annotations['images']: 39 | if entry["id"] in ids: 40 | image_info[entry["id"]] = entry 41 | annotations_info = {} # one annotation for each id (ie., image) 42 | for entry in images_annotations['annotations']: 43 
| image_id = entry['image_id'] 44 | if image_id in ids: 45 | if image_id not in annotations_info: 46 | annotations_info[image_id] = {'boxes': [], 'labels': []} 47 | annotations_info[image_id]['boxes'].append(entry['bbox']) 48 | annotations_info[image_id]['labels'].append(entry['category_id']) 49 | all_labels.add(entry['category_id']) 50 | labels_list = [label for label in all_labels] 51 | class_to_idx_mapping = {label: idx for idx, label in enumerate(labels_list)} 52 | with open(output_file, "w") as fw: 53 | for i, image_id in enumerate(annotations_info): 54 | im_info = image_info[image_id] 55 | image_file = im_info['file_name'] 56 | height = im_info['height'] 57 | width = im_info['width'] 58 | an_info = annotations_info[image_id] 59 | boxes = np.array(an_info['boxes']) 60 | labels = np.array([class_to_idx_mapping[label] for label in an_info['labels']]) 61 | line = write_line(image_file, width, height, boxes, labels, i) 62 | fw.write(line) 63 | 64 | 65 | def create_lst(data_dir, args, rnd_seed=100): 66 | """Generate an lst file based on annotations file which is used to convert the input data to .rec format.""" 67 | with open(os.path.join(data_dir, 'annotations.json')) as f: 68 | images_annotations = json.loads(f.read()) 69 | 70 | # Size of each class 71 | class_ids = defaultdict(list) 72 | for entry in images_annotations['images']: 73 | cls_ = entry['file_name'].split('_')[0] 74 | class_ids[cls_].append(entry['id']) 75 | print('\ncategory\tnum of images') 76 | print('---------------') 77 | for cls_ in class_ids.keys(): 78 | print(f"{cls_}\t{len(class_ids[cls_])}") 79 | 80 | random.seed(rnd_seed) 81 | 82 | # Split train/val/test image ids 83 | if args.test_ratio: 84 | test_ids = [] 85 | if args.train_ratio + args.test_ratio < 1.0: 86 | val_ids = [] 87 | train_ids = [] 88 | for cls_ in class_ids.keys(): 89 | random.shuffle(class_ids[cls_]) 90 | N = len(class_ids[cls_]) 91 | ids = class_ids[cls_] 92 | 93 | sep = int(N * args.train_ratio) 94 | sep_test = int(N * args.test_ratio) 95 | if args.train_ratio == 1.0: 96 | train_ids.extend(ids) 97 | else: 98 | if args.test_ratio: 99 | test_ids.extend(ids[:sep_test]) 100 | if args.train_ratio + args.test_ratio < 1.0: 101 | val_ids.extend(ids[sep_test + sep:]) 102 | train_ids.extend(ids[sep_test: sep_test + sep]) 103 | 104 | write_lst(args.prefix + "_train.lst", train_ids, images_annotations) 105 | lsts = [args.prefix + "_train.lst"] 106 | if args.test_ratio: 107 | write_lst(args.prefix + "_test.lst", test_ids, images_annotations) 108 | lsts.append(args.prefix + "_test.lst") 109 | if args.train_ratio + args.test_ratio < 1.0: 110 | write_lst(args.prefix + "_val.lst", val_ids, images_annotations) 111 | lsts.append(args.prefix + "_val.lst") 112 | 113 | return lsts 114 | 115 | 116 | def parse_args(): 117 | """Defines all arguments. 118 | 119 | Returns: 120 | args object that contains all the params 121 | """ 122 | parser = argparse.ArgumentParser( 123 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 124 | description="Create an image list or \ 125 | make a record database by reading from an image list", 126 | ) 127 | parser.add_argument("prefix", help="prefix of input/output lst and rec files.") 128 | parser.add_argument("root", help="path to folder containing images.") 129 | 130 | cgroup = parser.add_argument_group("Options for creating image lists") 131 | cgroup.add_argument( 132 | "--exts", nargs="+", default=[".jpeg", ".jpg", ".png"], help="list of acceptable image extensions." 
133 | ) 134 | cgroup.add_argument("--train-ratio", type=float, default=0.8, help="Ratio of images to use for training.") 135 | cgroup.add_argument("--test-ratio", type=float, default=0, help="Ratio of images to use for testing.") 136 | cgroup.add_argument( 137 | "--recursive", 138 | action="store_true", 139 | help="If true recursively walk through subdirs and assign an unique label\ 140 | to images in each folder. Otherwise only include images in the root folder\ 141 | and give them label 0.", 142 | ) 143 | args = parser.parse_args() 144 | 145 | args.prefix = os.path.abspath(args.prefix) 146 | args.root = os.path.abspath(args.root) 147 | return args 148 | 149 | 150 | if __name__ == '__main__': 151 | 152 | args = parse_args() 153 | data_dir = Path(args.root).parent 154 | 155 | lsts = create_lst(data_dir, args) 156 | print() 157 | 158 | for lst in lsts: 159 | os.system(f"python3 ../src/im2rec.py {lst} {os.path.join(data_dir, 'images')} --pass-through --pack-label") 160 | print() 161 | -------------------------------------------------------------------------------- /src/prepare_data/neu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dependencies: unzip unrar 3 | python -m pip install patool pyunpack 4 | """ 5 | 6 | from pathlib import Path 7 | import shutil 8 | import re 9 | import os 10 | 11 | try: 12 | from pyunpack import Archive 13 | except ModuleNotFoundError: 14 | print("installing the dependencies `patool` and `pyunpack` for unzipping the data") 15 | import subprocess 16 | 17 | subprocess.run("python -m pip install patool==1.12 pyunpack==0.2.1 -q", shell=True) 18 | from pyunpack import Archive 19 | 20 | CLASSES = { 21 | "crazing": "Cr", 22 | "inclusion": "In", 23 | "pitted_surface": "PS", 24 | "patches": "Pa", 25 | "rolled-in_scale": "RS", 26 | "scratches": "Sc", 27 | } 28 | 29 | 30 | def unpack(path: str) -> None: 31 | path = Path(path) 32 | Archive(str(path)).extractall(str(path.parent)) 33 | return 34 | 35 | 36 | def cp_class_images(data_path: Path, class_name: str, class_path_dest: Path) -> None: 37 | lst = list(data_path.rglob(f"{class_name}_*")) 38 | for img_file in lst: 39 | shutil.copy2(str(img_file), str(class_path_dest / img_file.name)) 40 | 41 | assert len(lst) == len(list(class_path_dest.glob("*"))) 42 | return 43 | 44 | 45 | def cp_image_annotation(data_path: Path, class_name: str, image_path_dest: Path, annotation_path_dest: Path) -> None: 46 | img_lst = sorted(list((data_path / "IMAGES").rglob(f"{class_name}_*"))) 47 | ann_lst = sorted(list((data_path / "ANNOTATIONS").rglob(f"{class_name}_*"))) 48 | assert len(img_lst) == len( 49 | ann_lst 50 | ), f"images count {len(img_lst)} does not match with annotations count {len(ann_lst)} for class {class_name}" 51 | for (img_file, ann_file) in zip(img_lst, ann_lst): 52 | shutil.copy2(str(img_file), str(image_path_dest / img_file.name)) 53 | shutil.copy2(str(ann_file), str(annotation_path_dest / ann_file.name)) 54 | 55 | assert len(list(image_path_dest.glob("*"))) == len(list(annotation_path_dest.glob("*"))) 56 | return 57 | 58 | 59 | def main(data_path: str, output_path: str, archived: bool = True) -> None: 60 | """ 61 | Data preparation 62 | 63 | Parameters 64 | ---------- 65 | data_path : str 66 | Raw data path 67 | output_path : str 68 | Output data path 69 | archived: bool 70 | Whether the file is archived or not (for testing) 71 | 72 | Raises 73 | ------ 74 | ValueError 75 | If the packed data file is different from NEU-CLS or NEU-DET 76 | """ 77 | data_path = Path(data_path) 
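    # When `archived` is True, extract the NEU-CLS/NEU-DET archive in place and then point
    # `data_path` at the extracted folder (the archive file name with its extension stripped).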
78 | if archived: 79 | unpack(data_path) 80 | 81 | data_path = data_path.parent / re.search(r"^[^.]*", str(data_path.name)).group(0) 82 | try: 83 | os.remove(str(data_path / "Thumbs.db")) 84 | except FileNotFoundError: 85 | print(f"Thumbs.db is not found. Continuing ...") 86 | pass 87 | except Exception as e: 88 | print(f"{e}: Unknown error!") 89 | raise e 90 | 91 | output_path = Path(output_path) 92 | if data_path.name == "NEU-CLS": 93 | for cls_ in CLASSES.values(): 94 | cls_path = output_path / cls_ 95 | cls_path.mkdir(exist_ok=True) 96 | cp_class_images(data_path, cls_, cls_path) 97 | elif data_path.name == "NEU-DET": 98 | for cls_ in CLASSES: 99 | cls_path = output_path / CLASSES[cls_] 100 | image_path = cls_path / "images" 101 | image_path.mkdir(parents=True, exist_ok=True) 102 | annotation_path = cls_path / "annotations" 103 | annotation_path.mkdir(exist_ok=True) 104 | cp_image_annotation(data_path, cls_, image_path, annotation_path) 105 | else: 106 | raise ValueError(f"Unknown data. Choose between `NEU-CLS` and `NEU-DET`. Given {data_path.name}") 107 | 108 | return 109 | 110 | 111 | if __name__ == "__main__": 112 | import sys 113 | 114 | if len(sys.argv) < 3: 115 | print("Provide `data_path` and `output_path`") 116 | sys.exit(1) 117 | 118 | main(sys.argv[1], sys.argv[2]) 119 | print("Done") 120 | -------------------------------------------------------------------------------- /src/prepare_data/test_neu.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | import neu 6 | 7 | 8 | def tmp_fill(tmpdir_, classes, ext): 9 | for cls_ in classes.values(): 10 | tmpdir_.join(cls_ + "_0" + ext) 11 | # add one more image 12 | tmpdir_.join(classes["crazing"] + "_1" + ext) 13 | return 14 | 15 | 16 | @pytest.fixture() 17 | def tmp_neu(): 18 | def _create(tmpdir, filename): 19 | tmpneu = tmpdir.mkdir(filename) 20 | if filename == "NEU-CLS": 21 | tmp_fill(tmpneu, neu.CLASSES, ".png") 22 | elif filename == "NEU-DET": 23 | imgs = tmpneu.mkdir("IMAGES") 24 | tmp_fill(imgs, neu.CLASSES, ".png") 25 | anns = tmpneu.mkdir("ANNOTATIONS") 26 | tmp_fill(anns, neu.CLASSES, ".xml") 27 | else: 28 | raise ValueError("Not supported") 29 | return tmpneu 30 | 31 | return _create 32 | 33 | 34 | @pytest.mark.parametrize("filename", ["NEU-CLS", "NEU-DET"]) 35 | def test_main(tmpdir, tmp_neu, filename) -> None: 36 | data_path = tmp_neu(tmpdir, filename) 37 | output_path = tmpdir.mkdir("output_path") 38 | neu.main(data_path, output_path, archived=False) 39 | assert len(os.listdir(output_path)) == len(neu.CLASSES), "failed to match number of classes in output" 40 | for p in output_path.visit(): 41 | assert p.check(), "correct path was not created" 42 | return 43 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import pytorch_lightning 3 | except ModuleNotFoundError: 4 | print("installing the dependencies for sagemaker_defect_detection package ...") 5 | import subprocess 6 | 7 | subprocess.run( 8 | "python -m pip install -q albumentations==0.4.6 pytorch_lightning==0.8.5 pycocotools==2.0.1", shell=True 9 | ) 10 | 11 | from sagemaker_defect_detection.models.ddn import Classification, Detection, RoI, RPN 12 | from sagemaker_defect_detection.dataset.neu import NEUCLS, NEUDET 13 | from sagemaker_defect_detection.transforms import get_transform, get_augmentation, get_preprocess 14 | 
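# Re-export the package's public API so the training entry points (e.g. classifier.py) can
# simply do `from sagemaker_defect_detection import Classification, NEUCLS, get_transform`.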
15 | __all__ = [ 16 | "Classification", 17 | "Detection", 18 | "RoI", 19 | "RPN", 20 | "NEUCLS", 21 | "NEUDET", 22 | "get_transform", 23 | "get_augmentation", 24 | "get_preprocess", 25 | ] 26 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/classifier.py: -------------------------------------------------------------------------------- 1 | # mypy: ignore-errors 2 | from typing import Dict 3 | import os 4 | from collections import OrderedDict 5 | from argparse import ArgumentParser, Namespace 6 | from multiprocessing import cpu_count 7 | 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | 13 | import pytorch_lightning as pl 14 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 15 | import pytorch_lightning.metrics.functional as plm 16 | 17 | from sagemaker_defect_detection import Classification, NEUCLS, get_transform 18 | from sagemaker_defect_detection.utils import load_checkpoint, freeze 19 | 20 | 21 | def metrics(name: str, out: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]: 22 | pred = torch.argmax(out, 1).detach() 23 | target = target.detach() 24 | metrics = {} 25 | metrics[name + "_acc"] = plm.accuracy(pred, target) 26 | metrics[name + "_prec"] = plm.precision(pred, target) 27 | metrics[name + "_recall"] = plm.recall(pred, target) 28 | metrics[name + "_f1_score"] = plm.recall(pred, target) 29 | return metrics 30 | 31 | 32 | class DDNClassification(pl.LightningModule): 33 | def __init__( 34 | self, 35 | data_path: str, 36 | backbone: str, 37 | freeze_backbone: bool, 38 | num_classes: int, 39 | learning_rate: float, 40 | batch_size: int, 41 | momentum: float, 42 | weight_decay: float, 43 | seed: int, 44 | **kwargs 45 | ) -> None: 46 | super().__init__() 47 | self.data_path = data_path 48 | self.backbone = backbone 49 | self.freeze_backbone = freeze_backbone 50 | self.num_classes = num_classes 51 | self.learning_rate = learning_rate 52 | self.batch_size = batch_size 53 | self.momentum = momentum 54 | self.weight_decay = weight_decay 55 | self.seed = seed 56 | 57 | self.train_dataset = NEUCLS(self.data_path, split="train", transform=get_transform("train"), seed=self.seed) 58 | self.val_dataset = NEUCLS(self.data_path, split="val", transform=get_transform("val"), seed=self.seed) 59 | self.test_dataset = NEUCLS(self.data_path, split="test", transform=get_transform("test"), seed=self.seed) 60 | 61 | self.model = Classification(self.backbone, self.num_classes) 62 | if self.freeze_backbone: 63 | for param in self.model.mfn.backbone.parameters(): 64 | param.requires_grad = False 65 | 66 | def forward(self, x): # ignore 67 | return self.model(x) 68 | 69 | def training_step(self, batch, batch_idx): 70 | images, target = batch 71 | output = self(images) 72 | loss_val = F.cross_entropy(output, target) 73 | metrics_dict = metrics("train", output, target) 74 | tqdm_dict = {"train_loss": loss_val, **metrics_dict} 75 | output = OrderedDict({"loss": loss_val, "progress_bar": tqdm_dict, "log": tqdm_dict}) 76 | return output 77 | 78 | def validation_step(self, batch, batch_idx): 79 | images, target = batch 80 | output = self(images) 81 | loss_val = F.cross_entropy(output, target) 82 | metrics_dict = metrics("val", output, target) 83 | output = OrderedDict({"val_loss": loss_val, **metrics_dict}) 84 | return output 85 | 86 | def validation_epoch_end(self, outputs): 87 | log_dict = {} 88 | for metric_name in outputs[0]: 89 
| log_dict[metric_name] = torch.stack([x[metric_name] for x in outputs]).mean() 90 | 91 | return {"log": log_dict, "progress_bar": log_dict, **log_dict} 92 | 93 | def test_step(self, batch, batch_idx): 94 | images, target = batch 95 | output = self(images) 96 | loss_val = F.cross_entropy(output, target) 97 | metrics_dict = metrics("test", output, target) 98 | output = OrderedDict({"test_loss": loss_val, **metrics_dict}) 99 | return output 100 | 101 | def test_epoch_end(self, outputs): 102 | log_dict = {} 103 | for metric_name in outputs[0]: 104 | log_dict[metric_name] = torch.stack([x[metric_name] for x in outputs]).mean() 105 | 106 | return {"log": log_dict, "progress_bar": log_dict, **log_dict} 107 | 108 | def configure_optimizers(self): 109 | optimizer = optim.SGD( 110 | self.parameters(), lr=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay 111 | ) 112 | return optimizer 113 | 114 | def train_dataloader(self): 115 | train_loader = DataLoader( 116 | dataset=self.train_dataset, 117 | batch_size=self.batch_size, 118 | shuffle=True, 119 | num_workers=cpu_count(), 120 | ) 121 | return train_loader 122 | 123 | def val_dataloader(self): 124 | val_loader = DataLoader( 125 | self.val_dataset, 126 | batch_size=self.batch_size, 127 | shuffle=False, 128 | num_workers=cpu_count() // 2, 129 | ) 130 | return val_loader 131 | 132 | def test_dataloader(self): 133 | test_loader = DataLoader( 134 | self.test_dataset, 135 | batch_size=self.batch_size, 136 | shuffle=False, 137 | num_workers=cpu_count(), 138 | ) 139 | return test_loader 140 | 141 | @staticmethod 142 | def add_model_specific_args(parent_parser): # pragma: no-cover 143 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 144 | aa = parser.add_argument 145 | aa( 146 | "--data-path", 147 | metavar="DIR", 148 | type=str, 149 | default=os.getenv("SM_CHANNEL_TRAINING", ""), 150 | ) 151 | aa( 152 | "--backbone", 153 | default="resnet34", 154 | ) 155 | aa( 156 | "--freeze-backbone", 157 | action="store_true", 158 | ) 159 | aa( 160 | "--num-classes", 161 | default=6, 162 | type=int, 163 | metavar="N", 164 | ) 165 | aa( 166 | "-b", 167 | "--batch-size", 168 | default=64, 169 | type=int, 170 | metavar="N", 171 | ) 172 | aa( 173 | "--lr", 174 | "--learning-rate", 175 | default=1e-3, 176 | type=float, 177 | metavar="LR", 178 | dest="learning_rate", 179 | ) 180 | aa("--momentum", default=0.9, type=float, metavar="M", help="momentum") 181 | aa( 182 | "--wd", 183 | "--weight-decay", 184 | default=1e-4, 185 | type=float, 186 | metavar="W", 187 | dest="weight_decay", 188 | ) 189 | aa( 190 | "--seed", 191 | type=int, 192 | default=42, 193 | ) 194 | return parser 195 | 196 | 197 | def get_args() -> Namespace: 198 | parent_parser = ArgumentParser(add_help=False) 199 | aa = parent_parser.add_argument 200 | aa("--epochs", type=int, default=100, help="number of training epochs") 201 | aa("--save-path", metavar="DIR", default=os.getenv("SM_MODEL_DIR", ""), type=str, help="path to save output") 202 | aa("--gpus", type=int, default=os.getenv("SM_NUM_GPUS", 1), help="how many gpus") 203 | aa( 204 | "--distributed-backend", 205 | type=str, 206 | default="", 207 | choices=("dp", "ddp", "ddp2"), 208 | help="supports three options dp, ddp, ddp2", 209 | ) 210 | aa("--use-16bit", dest="use_16bit", action="store_true", help="if true uses 16 bit precision") 211 | 212 | parser = DDNClassification.add_model_specific_args(parent_parser) 213 | return parser.parse_args() 214 | 215 | 216 | def model_fn(model_dir): 217 | # TODO: `model_fn` 
doesn't get more args 218 | # see: https://github.com/aws/sagemaker-inference-toolkit/issues/65 219 | backbone = "resnet34" 220 | num_classes = 6 221 | 222 | model = load_checkpoint(Classification(backbone, num_classes), model_dir, prefix="model") 223 | model = model.eval() 224 | freeze(model) 225 | return model 226 | 227 | 228 | def main(args: Namespace) -> None: 229 | model = DDNClassification(**vars(args)) 230 | 231 | if args.seed is not None: 232 | pl.seed_everything(args.seed) 233 | if torch.cuda.device_count() > 1: 234 | torch.cuda.manual_seed_all(args.seed) 235 | 236 | # TODO: add deterministic training 237 | # torch.backends.cudnn.deterministic = True 238 | 239 | checkpoint_callback = ModelCheckpoint( 240 | filepath=os.path.join(args.save_path, "{epoch}-{val_loss:.3f}-{val_acc:.3f}"), 241 | save_top_k=1, 242 | verbose=True, 243 | monitor="val_acc", 244 | mode="max", 245 | ) 246 | early_stop_callback = EarlyStopping("val_loss", patience=10) 247 | trainer = pl.Trainer( 248 | default_root_dir=args.save_path, 249 | gpus=args.gpus, 250 | max_epochs=args.epochs, 251 | early_stop_callback=early_stop_callback, 252 | checkpoint_callback=checkpoint_callback, 253 | gradient_clip_val=10, 254 | num_sanity_val_steps=0, 255 | distributed_backend=args.distributed_backend or None, 256 | # precision=16 if args.use_16bit else 32, # TODO: amp apex support 257 | ) 258 | 259 | trainer.fit(model) 260 | trainer.test() 261 | return 262 | 263 | 264 | if __name__ == "__main__": 265 | main(get_args()) 266 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/src/sagemaker_defect_detection/dataset/__init__.py -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/dataset/neu.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Optional, Callable 2 | import os 3 | from pathlib import Path 4 | from collections import namedtuple 5 | 6 | from xml.etree.ElementTree import ElementTree 7 | 8 | import numpy as np 9 | 10 | import cv2 11 | 12 | import torch 13 | from torch.utils.data.dataset import Dataset 14 | 15 | from torchvision.datasets import ImageFolder 16 | 17 | 18 | class NEUCLS(ImageFolder): 19 | """ 20 | NEU-CLS dataset processing and loading 21 | """ 22 | 23 | def __init__( 24 | self, 25 | root: str, 26 | split: str, 27 | augmentation: Optional[Callable] = None, 28 | preprocessing: Optional[Callable] = None, 29 | seed: int = 123, 30 | **kwargs, 31 | ) -> None: 32 | """ 33 | NEU-CLS dataset 34 | 35 | Parameters 36 | ---------- 37 | root : str 38 | Dataset root path 39 | split : str 40 | Data split from train, val and test 41 | augmentation : Optional[Callable], optional 42 | Image augmentation function, by default None 43 | preprocess : Optional[Callable], optional 44 | Image preprocessing function, by default None 45 | seed : int, optional 46 | Random number generator seed, by default 123 47 | 48 | Raises 49 | ------ 50 | ValueError 51 | If unsupported split is used 52 | """ 53 | super().__init__(root, **kwargs) 54 | self.samples: List[Tuple[str, int]] 55 | self.split = split 56 | self.augmentation = augmentation 57 | self.preprocessing = preprocessing 58 | n_items = len(self.samples) 59 | np.random.seed(seed) 
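        # Deterministic split: with the RNG seeded above, shuffle the sample indices once,
        # then slice them into 60% train / 20% val / 20% test below.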
60 | perm = np.random.permutation(list(range(n_items))) 61 | # TODO: add split ratios as parameters 62 | train_end = int(0.6 * n_items) 63 | val_end = int(0.2 * n_items) + train_end 64 | if split == "train": 65 | self.samples = [self.samples[i] for i in perm[:train_end]] 66 | elif split == "val": 67 | self.samples = [self.samples[i] for i in perm[train_end:val_end]] 68 | elif split == "test": 69 | self.samples = [self.samples[i] for i in perm[val_end:]] 70 | else: 71 | raise ValueError(f"Unknown split mode. Choose from `train`, `val` or `test`. Given {split}") 72 | 73 | 74 | DetectionSample = namedtuple("DetectionSample", ["image_path", "class_idx", "annotations"]) 75 | 76 | 77 | class NEUDET(Dataset): 78 | """ 79 | NEU-DET dataset processing and loading 80 | """ 81 | 82 | def __init__( 83 | self, 84 | root: str, 85 | split: str, 86 | augmentation: Optional[Callable] = None, 87 | preprocess: Optional[Callable] = None, 88 | seed: int = 123, 89 | ): 90 | """ 91 | NEU-DET dataset 92 | 93 | Parameters 94 | ---------- 95 | root : str 96 | Dataset root path 97 | split : str 98 | Data split from train, val and test 99 | augmentation : Optional[Callable], optional 100 | Image augmentation function, by default None 101 | preprocess : Optional[Callable], optional 102 | Image preprocessing function, by default None 103 | seed : int, optional 104 | Random number generator seed, by default 123 105 | 106 | Raises 107 | ------ 108 | ValueError 109 | If unsupported split is used 110 | """ 111 | super().__init__() 112 | self.root = Path(root) 113 | self.split = split 114 | self.classes, self.class_to_idx = self._find_classes() 115 | self.samples: List[DetectionSample] = self._make_dataset() 116 | self.augmentation = augmentation 117 | self.preprocess = preprocess 118 | n_items = len(self.samples) 119 | np.random.seed(seed) 120 | perm = np.random.permutation(list(range(n_items))) 121 | train_end = int(0.6 * n_items) 122 | val_end = int(0.2 * n_items) + train_end 123 | if split == "train": 124 | self.samples = [self.samples[i] for i in perm[:train_end]] 125 | elif split == "val": 126 | self.samples = [self.samples[i] for i in perm[train_end:val_end]] 127 | elif split == "test": 128 | self.samples = [self.samples[i] for i in perm[val_end:]] 129 | else: 130 | raise ValueError(f"Unknown split mode. Choose from `train`, `val` or `test`. Given {split}") 131 | 132 | def _make_dataset(self) -> List[DetectionSample]: 133 | instances = [] 134 | base_dir = self.root.expanduser() 135 | for target_cls in sorted(self.class_to_idx.keys()): 136 | cls_idx = self.class_to_idx[target_cls] 137 | target_dir = base_dir / target_cls 138 | if not target_dir.is_dir(): 139 | continue 140 | 141 | images = sorted(list((target_dir / "images").glob("*.jpg"))) 142 | annotations = sorted(list((target_dir / "annotations").glob("*.xml"))) 143 | assert len(images) == len(annotations), f"something is wrong. 
Mismatched number of images and annotations" 144 | for path, ann in zip(images, annotations): 145 | instances.append(DetectionSample(str(path), int(cls_idx), str(ann))) 146 | 147 | return instances 148 | 149 | def _find_classes(self): 150 | classes = sorted([d.name for d in os.scandir(str(self.root)) if d.is_dir()]) 151 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes, 1)} # no bg label in NEU 152 | return classes, class_to_idx 153 | 154 | @staticmethod 155 | def _get_bboxes(ann: str) -> List[List[int]]: 156 | tree = ElementTree().parse(ann) 157 | bboxes = [] 158 | for bndbox in tree.iterfind("object/bndbox"): 159 | # should subtract 1 like coco? 160 | bbox = [int(bndbox.findtext(t)) - 1 for t in ("xmin", "ymin", "xmax", "ymax")] # type: ignore 161 | assert bbox[2] > bbox[0] and bbox[3] > bbox[1], f"box size error, given {bbox}" 162 | bboxes.append(bbox) 163 | 164 | return bboxes 165 | 166 | def __len__(self): 167 | return len(self.samples) 168 | 169 | def __getitem__(self, idx: int): 170 | # Note: images are grayscaled BUT resnet needs 3 channels 171 | image = cv2.imread(self.samples[idx].image_path) # BGR channel last 172 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 173 | boxes = self._get_bboxes(self.samples[idx].annotations) 174 | num_objs = len(boxes) 175 | boxes = torch.as_tensor(boxes, dtype=torch.float32) 176 | labels = torch.tensor([self.samples[idx].class_idx] * num_objs, dtype=torch.int64) 177 | image_id = torch.tensor([idx], dtype=torch.int64) 178 | iscrowd = torch.zeros((len(boxes),), dtype=torch.int64) 179 | 180 | target = {} 181 | target["boxes"] = boxes 182 | target["labels"] = labels 183 | target["image_id"] = image_id 184 | target["iscrowd"] = iscrowd 185 | 186 | if self.augmentation is not None: 187 | sample = self.augmentation(**{"image": image, "bboxes": boxes, "labels": labels}) 188 | image = sample["image"] 189 | target["boxes"] = torch.as_tensor(sample["bboxes"], dtype=torch.float32) 190 | # guards against crops that don't pass the min_visibility augmentation threshold 191 | if not target["boxes"].numel(): 192 | return None 193 | 194 | target["labels"] = torch.as_tensor(sample["labels"], dtype=torch.int64) 195 | 196 | if self.preprocess is not None: 197 | image = self.preprocess(image=image)["image"] 198 | 199 | boxes = target["boxes"] 200 | target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 201 | return image, target, image_id 202 | 203 | def collate_fn(self, batch): 204 | batch = filter(lambda x: x is not None, batch) 205 | return tuple(zip(*batch)) 206 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanzhongqiao/sagemaker-defect-detection/f75c9ac0aaf2cb08a6e57f7011c0c43f5d5050bf/src/sagemaker_defect_detection/models/__init__.py -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/models/ddn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision 6 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 7 | from torchvision.models.detection.transform import GeneralizedRCNNTransform 8 | from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN 9 | from torchvision.models.detection.roi_heads import 
RoIHeads 10 | from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork 11 | from torchvision.ops import MultiScaleRoIAlign 12 | 13 | 14 | def get_backbone(name: str) -> nn.Module: 15 | """ 16 | Get official pretrained ResNet34 and ResNet50 as backbones 17 | 18 | Parameters 19 | ---------- 20 | name : str 21 | Either `resnet34` or `resnet50` 22 | 23 | Returns 24 | ------- 25 | nn.Module 26 | resnet34 or resnet50 pytorch modules 27 | 28 | Raises 29 | ------ 30 | ValueError 31 | If unsupported name is used 32 | """ 33 | if name == "resnet34": 34 | return torchvision.models.resnet34(pretrained=True) 35 | elif name == "resnet50": 36 | return torchvision.models.resnet50(pretrained=True) 37 | else: 38 | raise ValueError("Unsupported backbone") 39 | 40 | 41 | def init_weights(m) -> None: 42 | """ 43 | Weight initialization 44 | 45 | Parameters 46 | ---------- 47 | m : [type] 48 | Module used in recursive call 49 | """ 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.xavier_normal_(m.weight) 52 | 53 | elif isinstance(m, nn.Linear): 54 | m.weight.data.normal_(0.0, 0.02) 55 | m.bias.data.fill_(0.0) 56 | 57 | return 58 | 59 | 60 | class MFN(nn.Module): 61 | def __init__(self, backbone: str): 62 | """ 63 | Implementation of MFN model as described in 64 | 65 | Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, 66 | “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” 67 | IEEE Transactions on Instrumentation and Measuremente, 2020,69(4),1493-1504. 68 | 69 | Parameters 70 | ---------- 71 | backbone : str 72 | Either `resnet34` or `resnet50` 73 | """ 74 | super().__init__() 75 | self.backbone = get_backbone(backbone) 76 | # input 224x224 -> conv1 output size 112x112 77 | self.start_layer = nn.Sequential( 78 | self.backbone.conv1, # type: ignore 79 | self.backbone.bn1, # type: ignore 80 | self.backbone.relu, # type: ignore 81 | self.backbone.maxpool, # type: ignore 82 | ) 83 | self.r2 = self.backbone.layer1 # 64/256x56x56 <- (resnet34/resnet50) 84 | self.r3 = self.backbone.layer2 # 128/512x28x28 85 | self.r4 = self.backbone.layer3 # 256/1024x14x14 86 | self.r5 = self.backbone.layer4 # 512/2048x7x7 87 | in_channel = 64 if backbone == "resnet34" else 256 88 | self.b2 = nn.Sequential( 89 | nn.Conv2d( 90 | in_channel, in_channel, kernel_size=3, padding=1, stride=2 91 | ), # 56 -> 28 without Relu or batchnorm not in the paper ??? 
92 | nn.BatchNorm2d(in_channel), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=1, stride=2), # 28 -> 14 95 | nn.BatchNorm2d(in_channel), 96 | nn.ReLU(inplace=True), 97 | nn.Conv2d(in_channel, in_channel * 2, kernel_size=1, padding=0), 98 | nn.BatchNorm2d(in_channel * 2), 99 | nn.ReLU(inplace=True), 100 | ).apply( 101 | init_weights 102 | ) # after r2: 128/512x14x14 <- 103 | self.b3 = nn.MaxPool2d(2) # after r3: 128/512x14x14 <- 104 | in_channel *= 2 # 128/512 105 | self.b4 = nn.Sequential( 106 | nn.Conv2d(in_channel * 2, in_channel, kernel_size=1, padding=0), 107 | nn.BatchNorm2d(in_channel), 108 | nn.ReLU(inplace=True), 109 | ).apply( 110 | init_weights 111 | ) # after r4: 128/512x14x14 112 | in_channel *= 4 # 512 / 2048 113 | self.b5 = nn.Sequential( 114 | nn.ConvTranspose2d( 115 | in_channel, in_channel, kernel_size=3, stride=2, padding=1, output_padding=1 116 | ), # <- after r5 which is 512x7x7 -> 512x14x14 117 | nn.BatchNorm2d(in_channel), 118 | nn.ReLU(inplace=True), 119 | nn.Conv2d(in_channel, in_channel // 4, kernel_size=1, padding=0), 120 | nn.BatchNorm2d(in_channel // 4), 121 | nn.ReLU(inplace=True), 122 | ).apply(init_weights) 123 | 124 | self.out_channels = 512 if backbone == "resnet34" else 2048 # required for FasterRCNN 125 | 126 | def forward(self, x): 127 | x = self.start_layer(x) 128 | x = self.r2(x) 129 | b2_out = self.b2(x) 130 | x = self.r3(x) 131 | b3_out = self.b3(x) 132 | x = self.r4(x) 133 | b4_out = self.b4(x) 134 | x = self.r5(x) 135 | b5_out = self.b5(x) 136 | # BatchNorm works better than L2 normalize 137 | # out = torch.cat([F.normalize(o, p=2, dim=1) for o in (b2_out, b3_out, b4_out, b5_out)], dim=1) 138 | out = torch.cat((b2_out, b3_out, b4_out, b5_out), dim=1) 139 | return out 140 | 141 | 142 | class Classification(nn.Module): 143 | """ 144 | Classification network 145 | 146 | Parameters 147 | ---------- 148 | backbone : str 149 | Either `resnet34` or `resnet50` 150 | 151 | num_classes : int 152 | Number of classes 153 | """ 154 | 155 | def __init__(self, backbone: str, num_classes: int) -> None: 156 | super().__init__() 157 | self.mfn = MFN(backbone) 158 | self.flatten = nn.Flatten() 159 | self.fc = nn.Linear(self.mfn.out_channels * 14 ** 2, num_classes) 160 | 161 | def forward(self, x): 162 | return self.fc(self.flatten(self.mfn(x))) 163 | 164 | 165 | class RPN(nn.Module): 166 | """ 167 | RPN Module as described in 168 | 169 | Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, 170 | “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” 171 | IEEE Transactions on Instrumentation and Measuremente, 2020,69(4),1493-1504. 
172 | """ 173 | 174 | def __init__( 175 | self, 176 | out_channels: int = 512, 177 | rpn_pre_nms_top_n_train: int = 1000, # torchvision default 2000, 178 | rpn_pre_nms_top_n_test: int = 500, # torchvision default 1000, 179 | rpn_post_nms_top_n_train: int = 1000, # torchvision default 2000, 180 | rpn_post_nms_top_n_test: int = 500, # torchvision default 1000, 181 | rpn_nms_thresh: float = 0.7, 182 | rpn_fg_iou_thresh: float = 0.7, 183 | rpn_bg_iou_thresh: float = 0.3, 184 | rpn_batch_size_per_image: int = 256, 185 | rpn_positive_fraction: float = 0.5, 186 | ) -> None: 187 | super().__init__() 188 | rpn_anchor_generator = AnchorGenerator(sizes=((64, 128, 256, 512),), aspect_ratios=((0.5, 1.0, 2.0),)) 189 | rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0]) 190 | rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) 191 | rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) 192 | self.rpn = RegionProposalNetwork( 193 | rpn_anchor_generator, 194 | rpn_head, 195 | rpn_fg_iou_thresh, 196 | rpn_bg_iou_thresh, 197 | rpn_batch_size_per_image, 198 | rpn_positive_fraction, 199 | rpn_pre_nms_top_n, 200 | rpn_post_nms_top_n, 201 | rpn_nms_thresh, 202 | ) 203 | 204 | def forward(self, *args, **kwargs): 205 | return self.rpn(*args, **kwargs) 206 | 207 | 208 | class CustomTwoMLPHead(nn.Module): 209 | def __init__(self, in_channels: int, representation_size: int): 210 | super().__init__() 211 | self.avgpool = nn.AdaptiveAvgPool2d(7) 212 | self.mlp = nn.Sequential( 213 | nn.Linear(in_channels, representation_size), 214 | nn.ReLU(inplace=True), 215 | nn.Linear(representation_size, representation_size), 216 | nn.ReLU(inplace=True), 217 | ) 218 | 219 | def forward(self, x): 220 | x = self.avgpool(x) 221 | x = x.flatten(start_dim=1) 222 | x = self.mlp(x) 223 | return x 224 | 225 | 226 | class RoI(nn.Module): 227 | """ 228 | ROI Module as described in 229 | 230 | Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, 231 | “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” 232 | IEEE Transactions on Instrumentation and Measuremente, 2020,69(4),1493-1504. 233 | """ 234 | 235 | def __init__( 236 | self, 237 | num_classes: int, 238 | box_fg_iou_thresh=0.5, 239 | box_bg_iou_thresh=0.5, 240 | box_batch_size_per_image=512, 241 | box_positive_fraction=0.25, 242 | bbox_reg_weights=None, 243 | box_score_thresh=0.05, 244 | box_nms_thresh=0.5, 245 | box_detections_per_img=100, 246 | ) -> None: 247 | super().__init__() 248 | roi_pooler = MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2) 249 | box_head = CustomTwoMLPHead(512 * 7 ** 2, 1024) 250 | box_predictor = FastRCNNPredictor(1024, num_classes=num_classes) 251 | self.roi_head = RoIHeads( 252 | roi_pooler, 253 | box_head, 254 | box_predictor, 255 | box_fg_iou_thresh, 256 | box_bg_iou_thresh, 257 | box_batch_size_per_image, 258 | box_positive_fraction, 259 | bbox_reg_weights, 260 | box_score_thresh, 261 | box_nms_thresh, 262 | box_detections_per_img, 263 | ) 264 | 265 | def forward(self, *args, **kwargs): 266 | return self.roi_head(*args, **kwargs) 267 | 268 | 269 | class Detection(GeneralizedRCNN): 270 | """ 271 | Detection network as described in 272 | 273 | Yu He, Kechen Song, Qinggang Meng, Yunhui Yan, 274 | “An End-to-end Steel Surface Defect Detection Approach via Fusing Multiple Hierarchical Features,” 275 | IEEE Transactions on Instrumentation and Measuremente, 2020,69(4),1493-1504. 
276 | """ 277 | 278 | def __init__(self, mfn, rpn, roi): 279 | dummy_transform = GeneralizedRCNNTransform(800, 1333, [00.0, 0.0, 0.0], [1.0, 1.0, 1.0]) 280 | super().__init__(mfn, rpn, roi, dummy_transform) 281 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torchvision.transforms as transforms 3 | 4 | import albumentations as albu 5 | import albumentations.pytorch.transforms as albu_transforms 6 | 7 | 8 | PROBABILITY = 0.5 9 | ROTATION_ANGLE = 90 10 | NUM_CHANNELS = 3 11 | # required for resnet 12 | IMAGE_RESIZE_HEIGHT = 256 13 | IMAGE_RESIZE_WIDTH = 256 14 | IMAGE_HEIGHT = 224 15 | IMAGE_WIDTH = 224 16 | # standard imagenet1k mean and standard deviation of RGB channels 17 | MEAN_RED = 0.485 18 | MEAN_GREEN = 0.456 19 | MEAN_BLUE = 0.406 20 | STD_RED = 0.229 21 | STD_GREEN = 0.224 22 | STD_BLUE = 0.225 23 | 24 | 25 | def get_transform(split: str) -> Callable: 26 | """ 27 | Image data transformations such as normalization for train split for classification task 28 | 29 | Parameters 30 | ---------- 31 | split : str 32 | train or else 33 | 34 | Returns 35 | ------- 36 | Callable 37 | Image transformation function 38 | """ 39 | normalize = transforms.Normalize(mean=[MEAN_RED, MEAN_GREEN, MEAN_BLUE], std=[STD_RED, STD_GREEN, STD_BLUE]) 40 | if split == "train": 41 | return transforms.Compose( 42 | [ 43 | transforms.RandomResizedCrop(IMAGE_HEIGHT), 44 | transforms.RandomRotation(ROTATION_ANGLE), 45 | transforms.RandomHorizontalFlip(), 46 | transforms.ToTensor(), 47 | normalize, 48 | ] 49 | ) 50 | 51 | else: 52 | return transforms.Compose( 53 | [ 54 | transforms.Resize(IMAGE_RESIZE_HEIGHT), 55 | transforms.CenterCrop(IMAGE_HEIGHT), 56 | transforms.ToTensor(), 57 | normalize, 58 | ] 59 | ) 60 | 61 | 62 | def get_augmentation(split: str) -> Callable: 63 | """ 64 | Obtains proper image augmentation in train split for detection task. 
65 | We have splitted transformations done for detection task into augmentation and preprocessing 66 | for clarity 67 | 68 | Parameters 69 | ---------- 70 | split : str 71 | train or else 72 | 73 | Returns 74 | ------- 75 | Callable 76 | Image augmentation function 77 | """ 78 | if split == "train": 79 | return albu.Compose( 80 | [ 81 | albu.Resize(IMAGE_RESIZE_HEIGHT, IMAGE_RESIZE_WIDTH, always_apply=True), 82 | albu.RandomCrop(IMAGE_HEIGHT, IMAGE_WIDTH, always_apply=True), 83 | albu.RandomRotate90(p=PROBABILITY), 84 | albu.HorizontalFlip(p=PROBABILITY), 85 | albu.RandomBrightness(p=PROBABILITY), 86 | ], 87 | bbox_params=albu.BboxParams( 88 | format="pascal_voc", 89 | label_fields=["labels"], 90 | min_visibility=0.2, 91 | ), 92 | ) 93 | else: 94 | return albu.Compose( 95 | [albu.Resize(IMAGE_HEIGHT, IMAGE_WIDTH)], 96 | bbox_params=albu.BboxParams(format="pascal_voc", label_fields=["labels"]), 97 | ) 98 | 99 | 100 | def get_preprocess() -> Callable: 101 | """ 102 | Image normalization using albumentation for detection task that aligns well with image augmentation 103 | 104 | Returns 105 | ------- 106 | Callable 107 | Image normalization function 108 | """ 109 | return albu.Compose( 110 | [ 111 | albu.Normalize(mean=[MEAN_RED, MEAN_GREEN, MEAN_BLUE], std=[STD_RED, STD_GREEN, STD_BLUE]), 112 | albu_transforms.ToTensorV2(), 113 | ] 114 | ) 115 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from pathlib import Path 3 | import tarfile 4 | import logging 5 | from logging.config import fileConfig 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | logger.setLevel(logging.INFO) 13 | 14 | 15 | def get_logger(config_path: str) -> logging.Logger: 16 | fileConfig(config_path, disable_existing_loggers=False) 17 | logger = logging.getLogger() 18 | return logger 19 | 20 | 21 | def str2bool(flag: Union[str, bool]) -> bool: 22 | if not isinstance(flag, bool): 23 | if flag.lower() == "false": 24 | flag = False 25 | elif flag.lower() == "true": 26 | flag = True 27 | else: 28 | raise ValueError("Wrong boolean argument!") 29 | return flag 30 | 31 | 32 | def freeze(m: nn.Module) -> None: 33 | assert isinstance(m, nn.Module), "freeze only is applied to modules" 34 | for param in m.parameters(): 35 | param.requires_grad = False 36 | 37 | return 38 | 39 | 40 | def load_checkpoint(model: nn.Module, path: str, prefix: Optional[str]) -> nn.Module: 41 | path = Path(path) 42 | logger.info(f"path: {path}") 43 | if path.is_dir(): 44 | path_str = str(list(path.rglob("*.ckpt"))[0]) 45 | else: 46 | path_str = str(path) 47 | 48 | device = "cuda" if torch.cuda.is_available() else "cpu" 49 | state_dict = torch.load(path_str, map_location=torch.device(device))["state_dict"] 50 | if prefix is not None: 51 | if prefix[-1] != ".": 52 | prefix += "." 53 | 54 | state_dict = {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)} 55 | 56 | model.load_state_dict(state_dict, strict=True) 57 | return model 58 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/utils/coco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) Soumith Chintala 2016, 5 | All rights reserved. 
6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | 32 | 33 | This module is a modified version of https://github.com/pytorch/vision/tree/03b1d38ba3c67703e648fb067570eeb1a1e61265/references/detection 34 | """ 35 | 36 | import json 37 | 38 | import numpy as np 39 | import copy 40 | import torch 41 | import pickle 42 | 43 | import torch.distributed as dist 44 | 45 | from pycocotools.cocoeval import COCOeval 46 | from pycocotools.coco import COCO 47 | import pycocotools.mask as mask_util 48 | 49 | from collections import defaultdict 50 | 51 | 52 | def is_dist_avail_and_initialized(): 53 | if not dist.is_available(): 54 | return False 55 | if not dist.is_initialized(): 56 | return False 57 | return True 58 | 59 | 60 | def get_world_size(): 61 | if not is_dist_avail_and_initialized(): 62 | return 1 63 | return dist.get_world_size() 64 | 65 | 66 | def all_gather(data): 67 | """ 68 | Run all_gather on arbitrary picklable data (not necessarily tensors) 69 | Args: 70 | data: any picklable object 71 | Returns: 72 | list[data]: list of data gathered from each rank 73 | """ 74 | world_size = get_world_size() 75 | if world_size == 1: 76 | return [data] 77 | 78 | # serialized to a Tensor 79 | buffer = pickle.dumps(data) 80 | storage = torch.ByteStorage.from_buffer(buffer) 81 | tensor = torch.ByteTensor(storage).to("cuda") 82 | 83 | # obtain Tensor size of each rank 84 | local_size = torch.tensor([tensor.numel()], device="cuda") 85 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 86 | dist.all_gather(size_list, local_size) 87 | size_list = [int(size.item()) for size in size_list] 88 | max_size = max(size_list) 89 | 90 | # receiving Tensor from all ranks 91 | # we pad the tensor because torch all_gather does not support 92 | # gathering tensors of different shapes 93 | tensor_list = [] 94 | for _ in size_list: 95 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 96 | if local_size != max_size: 97 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, 
device="cuda") 98 | tensor = torch.cat((tensor, padding), dim=0) 99 | dist.all_gather(tensor_list, tensor) 100 | 101 | data_list = [] 102 | for size, tensor in zip(size_list, tensor_list): 103 | buffer = tensor.cpu().numpy().tobytes()[:size] 104 | data_list.append(pickle.loads(buffer)) 105 | 106 | return data_list 107 | 108 | 109 | class CocoEvaluator(object): 110 | def __init__(self, coco_gt, iou_types): 111 | assert isinstance(iou_types, (list, tuple)) 112 | coco_gt = copy.deepcopy(coco_gt) 113 | self.coco_gt = coco_gt 114 | 115 | self.iou_types = iou_types 116 | self.coco_eval = {} 117 | for iou_type in iou_types: 118 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 119 | 120 | self.img_ids = [] 121 | self.eval_imgs = {k: [] for k in iou_types} 122 | 123 | def update(self, predictions): 124 | img_ids = list(np.unique(list(predictions.keys()))) 125 | self.img_ids.extend(img_ids) 126 | 127 | for iou_type in self.iou_types: 128 | results = self.prepare(predictions, iou_type) 129 | coco_dt = loadRes(self.coco_gt, results) if results else COCO() 130 | coco_eval = self.coco_eval[iou_type] 131 | 132 | coco_eval.cocoDt = coco_dt 133 | coco_eval.params.imgIds = list(img_ids) 134 | img_ids, eval_imgs = evaluate(coco_eval) 135 | if isinstance(self.eval_imgs[iou_type], np.ndarray): 136 | self.eval_imgs[iou_type] = self.eval_imgs[iou_type].tolist() 137 | 138 | self.eval_imgs[iou_type].append(eval_imgs) 139 | 140 | def synchronize_between_processes(self): 141 | for iou_type in self.iou_types: 142 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 143 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 144 | 145 | def accumulate(self): 146 | for coco_eval in self.coco_eval.values(): 147 | coco_eval.accumulate() 148 | 149 | def summarize(self): 150 | for iou_type, coco_eval in self.coco_eval.items(): 151 | print("IoU metric: {}".format(iou_type)) 152 | coco_eval.summarize() 153 | 154 | def prepare(self, predictions, iou_type): 155 | return self.prepare_for_coco_detection(predictions) 156 | 157 | def prepare_for_coco_detection(self, predictions): 158 | coco_results = [] 159 | for original_id, prediction in predictions.items(): 160 | if len(prediction) == 0: 161 | continue 162 | 163 | boxes = prediction["boxes"] 164 | boxes = convert_to_xywh(boxes).tolist() 165 | scores = prediction["scores"].tolist() 166 | labels = prediction["labels"].tolist() 167 | 168 | coco_results.extend( 169 | [ 170 | { 171 | "image_id": original_id, 172 | "category_id": labels[k], 173 | "bbox": box, 174 | "score": scores[k], 175 | } 176 | for k, box in enumerate(boxes) 177 | ] 178 | ) 179 | return coco_results 180 | 181 | 182 | def convert_to_xywh(boxes): 183 | xmin, ymin, xmax, ymax = boxes.unbind(1) 184 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 185 | 186 | 187 | def merge(img_ids, eval_imgs): 188 | all_img_ids = all_gather(img_ids) 189 | all_eval_imgs = all_gather(eval_imgs) 190 | 191 | merged_img_ids = [] 192 | for p in all_img_ids: 193 | merged_img_ids.extend(p) 194 | 195 | merged_eval_imgs = [] 196 | for p in all_eval_imgs: 197 | merged_eval_imgs.append(p) 198 | 199 | merged_img_ids = np.array(merged_img_ids) 200 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 201 | 202 | # keep only unique (and in sorted order) images 203 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 204 | merged_eval_imgs = merged_eval_imgs[..., idx] 205 | 206 | return merged_img_ids, merged_eval_imgs 207 | 208 | 209 | def 
create_common_coco_eval(coco_eval, img_ids, eval_imgs): 210 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 211 | img_ids = list(img_ids) 212 | eval_imgs = list(eval_imgs.flatten()) 213 | 214 | coco_eval.evalImgs = eval_imgs 215 | coco_eval.params.imgIds = img_ids 216 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 217 | 218 | 219 | ################################################################# 220 | # From pycocotools, just removed the prints and fixed 221 | # a Python3 bug about unicode not defined 222 | ################################################################# 223 | 224 | # Ideally, pycocotools wouldn't have hard-coded prints 225 | # so that we could avoid copy-pasting those two functions 226 | 227 | 228 | def createIndex(self): 229 | # create index 230 | # print('creating index...') 231 | anns, cats, imgs = {}, {}, {} 232 | imgToAnns, catToImgs = defaultdict(list), defaultdict(list) 233 | if "annotations" in self.dataset: 234 | for ann in self.dataset["annotations"]: 235 | imgToAnns[ann["image_id"]].append(ann) 236 | anns[ann["id"]] = ann 237 | 238 | if "images" in self.dataset: 239 | for img in self.dataset["images"]: 240 | imgs[img["id"]] = img 241 | 242 | if "categories" in self.dataset: 243 | for cat in self.dataset["categories"]: 244 | cats[cat["id"]] = cat 245 | 246 | if "annotations" in self.dataset and "categories" in self.dataset: 247 | for ann in self.dataset["annotations"]: 248 | catToImgs[ann["category_id"]].append(ann["image_id"]) 249 | 250 | # print('index created!') 251 | 252 | # create class members 253 | self.anns = anns 254 | self.imgToAnns = imgToAnns 255 | self.catToImgs = catToImgs 256 | self.imgs = imgs 257 | self.cats = cats 258 | 259 | 260 | maskUtils = mask_util 261 | 262 | 263 | def loadRes(self, resFile): 264 | """ 265 | Load result file and return a result api object. 
266 | :param resFile (str) : file name of result file 267 | :return: res (obj) : result api object 268 | """ 269 | res = COCO() 270 | res.dataset["images"] = [img for img in self.dataset["images"]] 271 | 272 | # print('Loading and preparing results...') 273 | # tic = time.time() 274 | if isinstance(resFile, torch._six.string_classes): 275 | anns = json.load(open(resFile)) 276 | elif type(resFile) == np.ndarray: 277 | anns = self.loadNumpyAnnotations(resFile) 278 | else: 279 | anns = resFile 280 | assert type(anns) == list, "results in not an array of objects" 281 | annsImgIds = [ann["image_id"] for ann in anns] 282 | assert set(annsImgIds) == ( 283 | set(annsImgIds) & set(self.getImgIds()) 284 | ), "Results do not correspond to current coco set" 285 | if "caption" in anns[0]: 286 | imgIds = set([img["id"] for img in res.dataset["images"]]) & set([ann["image_id"] for ann in anns]) 287 | res.dataset["images"] = [img for img in res.dataset["images"] if img["id"] in imgIds] 288 | for id, ann in enumerate(anns): 289 | ann["id"] = id + 1 290 | elif "bbox" in anns[0] and not anns[0]["bbox"] == []: 291 | res.dataset["categories"] = copy.deepcopy(self.dataset["categories"]) 292 | for id, ann in enumerate(anns): 293 | bb = ann["bbox"] 294 | x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] 295 | if "segmentation" not in ann: 296 | ann["segmentation"] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 297 | ann["area"] = bb[2] * bb[3] 298 | ann["id"] = id + 1 299 | ann["iscrowd"] = 0 300 | 301 | res.dataset["annotations"] = anns 302 | createIndex(res) 303 | return res 304 | 305 | 306 | def evaluate(self): 307 | """ 308 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 309 | :return: None 310 | """ 311 | # tic = time.time() 312 | # print('Running per image evaluation...') 313 | p = self.params 314 | # add backward compatibility if useSegm is specified in params 315 | if p.useSegm is not None: 316 | p.iouType = "segm" if p.useSegm == 1 else "bbox" 317 | print("useSegm (deprecated) is not None. 
Running {} evaluation".format(p.iouType)) 318 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 319 | p.imgIds = list(np.unique(p.imgIds)) 320 | if p.useCats: 321 | p.catIds = list(np.unique(p.catIds)) 322 | p.maxDets = sorted(p.maxDets) 323 | self.params = p 324 | 325 | self._prepare() 326 | # loop through images, area range, max detection number 327 | catIds = p.catIds if p.useCats else [-1] 328 | 329 | if p.iouType == "segm" or p.iouType == "bbox": 330 | computeIoU = self.computeIoU 331 | elif p.iouType == "keypoints": 332 | computeIoU = self.computeOks 333 | self.ious = {(imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds} 334 | 335 | evaluateImg = self.evaluateImg 336 | maxDet = p.maxDets[-1] 337 | evalImgs = [ 338 | evaluateImg(imgId, catId, areaRng, maxDet) for catId in catIds for areaRng in p.areaRng for imgId in p.imgIds 339 | ] 340 | # this is NOT in the pycocotools code, but could be done outside 341 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 342 | self._paramsEval = copy.deepcopy(self.params) 343 | # toc = time.time() 344 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 345 | return p.imgIds, evalImgs 346 | 347 | 348 | ################################################################# 349 | # end of straight copy from pycocotools, just removing the prints 350 | ################################################################# 351 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/utils/coco_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | BSD 3-Clause License 3 | 4 | Copyright (c) Soumith Chintala 2016, 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | 33 | This module is a modified version of https://github.com/pytorch/vision/tree/03b1d38ba3c67703e648fb067570eeb1a1e61265/references/detection 34 | """ 35 | 36 | from pycocotools.coco import COCO 37 | 38 | 39 | def convert_to_coco_api(ds): 40 | coco_ds = COCO() 41 | # annotation IDs need to start at 1, not 0, see torchvision issue #1530 42 | ann_id = 1 43 | dataset = {"images": [], "categories": [], "annotations": []} 44 | categories = set() 45 | for img_idx in range(len(ds)): 46 | # find better way to get target 47 | # targets = ds.get_annotations(img_idx) 48 | img, targets, _ = ds[img_idx] 49 | image_id = targets["image_id"].item() 50 | img_dict = {} 51 | img_dict["id"] = image_id 52 | img_dict["height"] = img.shape[-2] 53 | img_dict["width"] = img.shape[-1] 54 | dataset["images"].append(img_dict) 55 | bboxes = targets["boxes"] 56 | bboxes[:, 2:] -= bboxes[:, :2] 57 | bboxes = bboxes.tolist() 58 | labels = targets["labels"].tolist() 59 | areas = targets["area"].tolist() 60 | iscrowd = targets["iscrowd"].tolist() 61 | num_objs = len(bboxes) 62 | for i in range(num_objs): 63 | ann = {} 64 | ann["image_id"] = image_id 65 | ann["bbox"] = bboxes[i] 66 | ann["category_id"] = labels[i] 67 | categories.add(labels[i]) 68 | ann["area"] = areas[i] 69 | ann["iscrowd"] = iscrowd[i] 70 | ann["id"] = ann_id 71 | dataset["annotations"].append(ann) 72 | ann_id += 1 73 | dataset["categories"] = [{"id": i} for i in sorted(categories)] 74 | coco_ds.dataset = dataset 75 | coco_ds.createIndex() 76 | return coco_ds 77 | -------------------------------------------------------------------------------- /src/sagemaker_defect_detection/utils/visualize.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List, Union, Tuple 2 | 3 | import numpy as np 4 | 5 | import torch 6 | 7 | from matplotlib import pyplot as plt 8 | 9 | import cv2 10 | 11 | import torch 12 | 13 | 14 | TEXT_COLOR = (255, 255, 255) # White 15 | CLASSES = { 16 | "crazing": "Cr", 17 | "inclusion": "In", 18 | "pitted_surface": "PS", 19 | "patches": "Pa", 20 | "rolled-in_scale": "RS", 21 | "scratches": "Sc", 22 | } 23 | CATEGORY_ID_TO_NAME = {i: name for i, name in enumerate(CLASSES.keys(), start=1)} 24 | 25 | 26 | def unnormalize_to_hwc( 27 | image: torch.Tensor, mean: List[float] = [0.485, 0.456, 0.406], std: List[float] = [0.229, 0.224, 0.225] 28 | ) -> np.ndarray: 29 | """ 30 | Unnormalizes a normalized image tensor: [0, 1] CHW -> [0, 255] HWC 31 | 32 | Parameters 33 | ---------- 34 | image : torch.Tensor 35 | Normalized image 36 | mean : List[float], optional 37 | RGB averages used in normalization, by default [0.485, 0.456, 0.406] from imagenet1k 38 | std : List[float], optional 39 | RGB standard deviations used in normalization, by default [0.229, 0.224, 0.225] from imagenet1k 40 | 41 | Returns 42 | ------- 43 | np.ndarray 44 | Unnormalized image as numpy array 45 | """ 46 | image = image.numpy().transpose(1, 2, 0) # HWC 47 | image = (image * std + mean).clip(0, 1) 48 | image = (image * 255).astype(np.uint8) 49 | return image 50 | 51 | 52 | def visualize_bbox(img: np.ndarray, bbox: np.ndarray, class_name: str, color, thickness: int = 2) -> np.ndarray: 53 | """ 54 | Uses cv2 to draw colored bounding boxes and class names in an image 55 | 56 | Parameters 57 | ---------- 58 | img : np.ndarray 59 | Image to draw on, as an HWC numpy array 60 | bbox : np.ndarray 61 | Bounding box in [x_min, y_min, x_max, y_max] (pascal_voc) format 62 | class_name : str 63 | Class name 64 | color : tuple 65 | BGR tuple 66 | thickness : int, optional 67 | Bounding box
thickness, by default 2 68 | """ 69 | x_min, y_min, x_max, y_max = tuple(map(int, bbox)) 70 | 71 | cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness) 72 | 73 | ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1) 74 | cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), color, -1) 75 | cv2.putText( 76 | img, 77 | text=class_name, 78 | org=(x_min, y_min - int(0.3 * text_height)), 79 | fontFace=cv2.FONT_HERSHEY_SIMPLEX, 80 | fontScale=0.35, 81 | color=TEXT_COLOR, 82 | lineType=cv2.LINE_AA, 83 | ) 84 | return img 85 | 86 | 87 | def visualize( 88 | image: np.ndarray, 89 | bboxes: Iterable[Union[torch.Tensor, np.ndarray]] = [], 90 | category_ids: Iterable[Union[torch.Tensor, np.ndarray]] = [], 91 | colors: Iterable[Tuple[int, int, int]] = [], 92 | titles: Iterable[str] = [], 93 | category_id_to_name=CATEGORY_ID_TO_NAME, 94 | dpi=150, 95 | ) -> None: 96 | """ 97 | Applies the bounding boxes and category ids to an image 98 | 99 | Parameters 100 | ---------- 101 | image : np.ndarray 102 | Image as numpy array 103 | bboxes : Iterable[Union[torch.Tensor, np.ndarray]], optional 104 | Bouding boxes, by default [] 105 | category_ids : Iterable[Union[torch.Tensor, np.ndarray]], optional 106 | Category ids, by default [] 107 | colors : Iterable[Tuple[int, int, int]], optional 108 | Colors for each bounding box, by default [()] 109 | titles : Iterable[str], optional 110 | Titles for each image, by default [] 111 | category_id_to_name : Dict[str, str], optional 112 | Dictionary of category ids to names, by default CATEGORY_ID_TO_NAME 113 | dpi : int, optional 114 | DPI for clarity, by default 150 115 | """ 116 | bboxes, category_ids, colors, titles = list(map(list, [bboxes, category_ids, colors, titles])) # type: ignore 117 | n = len(bboxes) 118 | assert ( 119 | n == len(category_ids) == len(colors) == len(titles) - 1 120 | ), f"number of bboxes, category ids, colors and titles (minus one) do not match" 121 | 122 | plt.figure(dpi=dpi) 123 | ncols = n + 1 124 | plt.subplot(1, ncols, 1) 125 | img = image.copy() 126 | plt.axis("off") 127 | plt.title(titles[0]) 128 | plt.imshow(image) 129 | if not len(bboxes): 130 | return 131 | 132 | titles = titles[1:] 133 | for i in range(2, ncols + 1): 134 | img = image.copy() 135 | plt.subplot(1, ncols, i) 136 | plt.axis("off") 137 | j = i - 2 138 | plt.title(titles[j]) 139 | for bbox, category_id in zip(bboxes[j], category_ids[j]): # type: ignore 140 | if isinstance(bbox, torch.Tensor): 141 | bbox = bbox.numpy() 142 | 143 | if isinstance(category_id, torch.Tensor): 144 | category_id = category_id.numpy() 145 | 146 | if isinstance(category_id, np.ndarray): 147 | category_id = category_id.item() 148 | 149 | class_name = category_id_to_name[category_id] 150 | img = visualize_bbox(img, bbox, class_name, color=colors[j]) 151 | 152 | plt.imshow(img) 153 | return 154 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import boto3 4 | import copy 5 | import matplotlib.patches as patches 6 | from matplotlib import pyplot as plt 7 | from PIL import Image, ImageColor 8 | 9 | 10 | def query_Type2(image_file_name, endpoint_name, num_predictions=4): 11 | 12 | with open(image_file_name, "rb") as file: 13 | input_img_rb = file.read() 14 | 15 | client = boto3.client("runtime.sagemaker") 16 | query_response = 
client.invoke_endpoint( 17 | EndpointName=endpoint_name, 18 | ContentType="application/x-image", 19 | Body=input_img_rb, 20 | Accept=f'application/json;verbose;n_predictions={num_predictions}' 21 | ) 22 | # If we remove ';n_predictions={}' from Accept, we get all the predicted boxes. 23 | query_response = query_response['Body'].read() 24 | 25 | model_predictions = json.loads(query_response) 26 | normalized_boxes, classes, scores, labels = ( 27 | model_predictions["normalized_boxes"], 28 | model_predictions["classes"], 29 | model_predictions["scores"], 30 | model_predictions["labels"], 31 | ) 32 | # Substitute the classes index with the classes name 33 | class_names = [labels[int(idx)] for idx in classes] 34 | return normalized_boxes, class_names, scores 35 | 36 | 37 | # Copied from albumentations/augmentations/functional.py 38 | # Follow albumentations.Normalize, which is used in sagemaker_defect_detection/detector.py 39 | def normalize(img, mean, std, max_pixel_value=255.0): 40 | mean = np.array(mean, dtype=np.float32) 41 | mean *= max_pixel_value 42 | 43 | std = np.array(std, dtype=np.float32) 44 | std *= max_pixel_value 45 | 46 | denominator = np.reciprocal(std, dtype=np.float32) 47 | 48 | img = img.astype(np.float32) 49 | img -= mean 50 | img *= denominator 51 | return img 52 | 53 | 54 | def query_DDN(image_file_name, endpoint_name, num_predictions=4): 55 | 56 | with Image.open(image_file_name) as im: 57 | image_np = np.array(im) 58 | 59 | # Follow albumentations.Normalize, which is used in sagemaker_defect_detection/detector.py 60 | mean = (0.485, 0.456, 0.406) 61 | std = (0.229, 0.224, 0.225) 62 | max_pixel_value = 255.0 63 | 64 | image_np = normalize(image_np, mean, std, max_pixel_value) 65 | image_np = image_np.transpose(2, 0, 1) 66 | image_np = np.expand_dims(image_np, 0) # CHW 67 | 68 | client = boto3.client("runtime.sagemaker") 69 | query_response = client.invoke_endpoint( 70 | EndpointName=endpoint_name, 71 | ContentType="application/json", 72 | Body=json.dumps(image_np.tolist()) 73 | ) 74 | query_response = query_response["Body"].read() 75 | 76 | model_predictions = json.loads(query_response.decode())[0] 77 | unnormalized_boxes = model_predictions['boxes'][:num_predictions] 78 | class_names = model_predictions['labels'][:num_predictions] 79 | scores = model_predictions['scores'][:num_predictions] 80 | return unnormalized_boxes, class_names, scores 81 | 82 | 83 | 84 | def query_Type1(image_file_name, endpoint_name, num_predictions=4): 85 | 86 | with open(image_file_name, "rb") as file: 87 | input_img_rb = file.read() 88 | 89 | client = boto3.client(service_name="runtime.sagemaker") 90 | query_response = client.invoke_endpoint( 91 | EndpointName=endpoint_name, 92 | ContentType="application/x-image", 93 | Body=input_img_rb 94 | ) 95 | query_response = query_response["Body"].read() 96 | 97 | model_predictions = json.loads(query_response)['prediction'][:num_predictions] 98 | class_names = [int(pred[0])+1 for pred in model_predictions] # +1 for index starts from 1 99 | scores = [pred[1] for pred in model_predictions] 100 | normalized_boxes = [pred[2:] for pred in model_predictions] 101 | return normalized_boxes, class_names, scores 102 | 103 | 104 | def plot_results(image, bboxes, categories, d): 105 | # d - dictionary of endpoint responses 106 | 107 | colors = list(ImageColor.colormap.values()) 108 | with Image.open(image) as im: 109 | image_np = np.array(im) 110 | fig = plt.figure(figsize=(20, 14)) 111 | 112 | n = len(d) 113 | 114 | # Ground truth 115 | ax1 = fig.add_subplot(2, 3, 
1) 116 | plt.axis('off') 117 | plt.title('Ground Truth') 118 | 119 | for bbox in bboxes: 120 | left, bot, right, top = bbox['bbox'] 121 | x, y, w, h = left, bot, right - left, top - bot 122 | 123 | color = colors[hash(bbox['category_id']) % len(colors)] 124 | rect = patches.Rectangle((x, y), w, h, linewidth=3, edgecolor=color, facecolor="none") 125 | ax1.add_patch(rect) 126 | ax1.text(x, y, "{}".format(categories[bbox['category_id']]), 127 | bbox=dict(facecolor="white", alpha=0.5)) 128 | 129 | ax1.imshow(image_np) 130 | 131 | # Predictions 132 | counter = 2 133 | for k, v in d.items(): 134 | axi = fig.add_subplot(2, 3, counter) 135 | counter += 1 136 | 137 | if "Type2-HPO" in k: 138 | k = "Type2-HPO" 139 | elif "Type2" in k: 140 | k = "Type2" 141 | elif "Type1-HPO" in k: 142 | k = "Type1-HPO" 143 | elif "Type1" in k: 144 | k = "Type1" 145 | elif k.endswith("finetuned-endpoint"): 146 | k = "DDN" 147 | 148 | plt.title(f'Prediction: {k}') 149 | plt.axis('off') 150 | 151 | for idx in range(len(v['normalized_boxes'])): 152 | left, bot, right, top = v['normalized_boxes'][idx] 153 | if k == 'DDN': 154 | x, w = left, right - left 155 | y, h = bot, top - bot 156 | else: 157 | x, w = [val * image_np.shape[1] for val in [left, right - left]] 158 | y, h = [val * image_np.shape[0] for val in [bot, top - bot]] 159 | color = colors[hash(v['classes_names'][idx]) % len(colors)] 160 | rect = patches.Rectangle((x, y), w, h, linewidth=3, edgecolor=color, facecolor="none") 161 | axi.add_patch(rect) 162 | axi.text(x, y, 163 | "{} {:.0f}%".format(categories[v['classes_names'][idx]], v['confidences'][idx] * 100), 164 | bbox=dict(facecolor="white", alpha=0.5), 165 | ) 166 | 167 | axi.imshow(image_np) 168 | 169 | plt.tight_layout() 170 | plt.savefig("results/"+ image.split('/')[-1]) 171 | 172 | plt.show() 173 | -------------------------------------------------------------------------------- /src/xml2json.py: -------------------------------------------------------------------------------- 1 | # Use this script to convert annotation xmls to a single annotations.json file that will be taken by Jumpstart OD model 2 | # Reference: XML2JSON.py https://linuxtut.com/en/e391e5e6924945b8a852/ 3 | 4 | import random 5 | import xmltodict 6 | import copy 7 | import json 8 | import glob 9 | import os 10 | from collections import defaultdict 11 | 12 | 13 | categories = [ 14 | {"id": 1, "name": "crazing"}, 15 | {"id": 2, "name": "inclusion"}, 16 | {"id": 3, "name": "pitted_surface"}, 17 | {"id": 4, "name": "patches"}, 18 | {"id": 5, "name": "rolled-in_scale"}, 19 | {"id": 6, "name": "scratches"}, 20 | ] 21 | 22 | 23 | def XML2JSON(xmlFiles, test_ratio=None, rnd_seed=100): 24 | """ Convert all xmls to annotations.json 25 | 26 | If the test_ratio is not None, convert to two annotations.json files, 27 | one for train+val, another one for test. 
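The split is stratified per defect class (inferred from the image file name prefix before the first underscore, e.g. "crazing_1.jpg"), sending roughly test_ratio of every class to test_annotations.json and the rest to annotations.json. Bounding boxes are written in corner format [xmin, ymin, xmax, ymax]; convert_to_pycocotools_ground_truth below converts them to COCO's [x, y, width, height].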
28 | """ 29 | 30 | images = list() 31 | annotations = list() 32 | image_id = 1 33 | annotation_id = 1 34 | for file in xmlFiles: 35 | annotation_path = file 36 | image = dict() 37 | with open(annotation_path) as fd: 38 | doc = xmltodict.parse(fd.read(), force_list=('object')) 39 | filename = str(doc['annotation']['filename']) 40 | image['file_name'] = filename if filename.endswith('.jpg') else filename + '.jpg' 41 | image['height'] = int(doc['annotation']['size']['height']) 42 | image['width'] = int(doc['annotation']['size']['width']) 43 | image['id'] = image_id 44 | # print("File Name: {} and image_id {}".format(file, image_id)) 45 | images.append(image) 46 | if 'object' in doc['annotation']: 47 | for obj in doc['annotation']['object']: 48 | for value in categories: 49 | annotation = dict() 50 | if str(obj['name']) == value["name"]: 51 | annotation["image_id"] = image_id 52 | xmin = int(obj["bndbox"]["xmin"]) 53 | ymin = int(obj["bndbox"]["ymin"]) 54 | xmax = int(obj["bndbox"]["xmax"]) 55 | ymax = int(obj["bndbox"]["ymax"]) 56 | annotation["bbox"] = [xmin, ymin, xmax, ymax] 57 | annotation["category_id"] = value["id"] 58 | annotation["id"] = annotation_id 59 | annotation_id += 1 60 | annotations.append(annotation) 61 | 62 | else: 63 | print("File: {} doesn't have any object".format(file)) 64 | 65 | image_id += 1 66 | 67 | if test_ratio is None: 68 | attrDict = dict() 69 | attrDict["images"] = images 70 | attrDict["annotations"] = annotations 71 | 72 | jsonString = json.dumps(attrDict) 73 | with open("annotations.json", "w") as f: 74 | f.write(jsonString) 75 | else: 76 | assert test_ratio < 1.0 77 | 78 | # Size of each class 79 | category_ids = defaultdict(list) 80 | for img in images: 81 | category = img['file_name'].split('_')[0] 82 | category_ids[category].append(img['id']) 83 | print('\ncategory\tnum of images') 84 | print('-' * 20) 85 | 86 | random.seed(rnd_seed) 87 | 88 | train_val_images = [] 89 | test_images = [] 90 | train_val_annotations = [] 91 | test_annotations = [] 92 | 93 | for category in category_ids.keys(): 94 | print(f"{category}:\t{len(category_ids[category])}") 95 | 96 | random.shuffle(category_ids[category]) 97 | N = len(category_ids[category]) 98 | ids = category_ids[category] 99 | 100 | sep = int(N * test_ratio) 101 | 102 | category_images = [img for img in images if img['id'] in ids[:sep]] 103 | test_images.extend(category_images) 104 | category_images = [img for img in images if img['id'] in ids[sep:]] 105 | train_val_images.extend(category_images) 106 | 107 | category_annotations = [ann for ann in annotations if ann['image_id'] in ids[:sep]] 108 | test_annotations.extend(category_annotations) 109 | category_annotations = [ann for ann in annotations if ann['image_id'] in ids[sep:]] 110 | train_val_annotations.extend(category_annotations) 111 | 112 | print('-' * 20) 113 | 114 | train_val_attrDict = dict() 115 | train_val_attrDict["images"] = train_val_images 116 | train_val_attrDict["annotations"] = train_val_annotations 117 | print(f"\ntrain_val:\t{len(train_val_images)}") 118 | 119 | train_val_jsonString = json.dumps(train_val_attrDict) 120 | with open("annotations.json", "w") as f: 121 | f.write(train_val_jsonString) 122 | 123 | test_attDict = dict() 124 | test_attDict["images"] = test_images 125 | test_attDict["annotations"] = test_annotations 126 | print(f"test:\t{len(test_images)}") 127 | 128 | test_jsonString = json.dumps(test_attDict) 129 | with open("test_annotations.json", "w") as f: 130 | f.write(test_jsonString) 131 | 132 | 133 | 134 | def 
convert_to_pycocotools_ground_truth(annotations_file): 135 | """ 136 | Given the annotation json file for the test data generated during 137 | initial data preparation, convert it to the input format pycocotools 138 | can consume. 139 | """ 140 | 141 | with open(annotations_file) as f: 142 | images_annotations = json.loads(f.read()) 143 | 144 | attrDict = dict() 145 | attrDict["images"] = images_annotations["images"] 146 | attrDict["categories"] = categories 147 | 148 | annotations = [] 149 | for entry in images_annotations['annotations']: 150 | ann = copy.deepcopy(entry) 151 | xmin, ymin, xmax, ymax = ann["bbox"] 152 | ann["bbox"] = [xmin, ymin, xmax-xmin, ymax-ymin] # convert to [x, y, W, H] 153 | ann["area"] = (xmax - xmin) * (ymax - ymin) 154 | ann["iscrowd"] = 0 155 | annotations.append(ann) 156 | 157 | attrDict["annotations"] = annotations 158 | 159 | jsonString = json.dumps(attrDict) 160 | ground_truth_annotations = "results/ground_truth_annotations.json" 161 | 162 | with open(ground_truth_annotations, "w") as f: 163 | f.write(jsonString) 164 | 165 | return ground_truth_annotations 166 | 167 | 168 | if __name__ == "__main__": 169 | data_path = '../../NEU-DET/ANNOTATIONS' 170 | xmlfiles = glob.glob(os.path.join(data_path, '*.xml')) 171 | xmlfiles.sort() 172 | 173 | XML2JSON(xmlfiles, test_ratio=0.2) 174 | 175 | 176 | --------------------------------------------------------------------------------
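A minimal usage sketch (not part of the repository) of how the test_annotations.json produced above could be scored with pycocotools. It assumes the snippet runs next to src/xml2json.py, that a results/ directory exists, and that model predictions have already been collected into results/predictions.json, a hypothetical file in COCO detection-results format using the same image ids as test_annotations.json:

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from xml2json import convert_to_pycocotools_ground_truth

# Rewrite the corner-format test boxes into COCO's [x, y, width, height]
# (also adds "area" and "iscrowd"); returns results/ground_truth_annotations.json
gt_path = convert_to_pycocotools_ground_truth("test_annotations.json")

coco_gt = COCO(gt_path)
# predictions.json: a list of {"image_id", "category_id", "bbox": [x, y, w, h], "score"}
coco_dt = coco_gt.loadRes("results/predictions.json")

coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()  # prints the standard COCO AP/AR table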