├── .github └── PULL_REQUEST_TEMPLATE.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Image_Classification ├── Image_Classification_Tutorial.ipynb ├── cnn.png ├── lenet-mxnet.png ├── mlp.png ├── mxnet_cnn.py ├── signnames.csv └── traffic.png ├── LICENSE ├── NOTICE ├── README.md └── Segmentation_SageMaker ├── Segmentation_Module.ipynb └── segmentation.py /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | *Issue #, if available:* 2 | 3 | *Description of changes:* 4 | 5 | 6 | By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. 7 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 
13 | 14 | When filing an issue, please check [existing open](https://github.com/aws-samples/aws-ml-vision-end2end/issues), or [recently closed](https://github.com/aws-samples/aws-ml-vision-end2end/issues?utf8=%E2%9C%93&q=is%3Aissue%20is%3Aclosed%20), issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 
41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any ['help wanted'](https://github.com/aws-samples/aws-ml-vision-end2end/labels/help%20wanted) issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](https://github.com/aws-samples/aws-ml-vision-end2end/blob/master/LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | 61 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 
62 | -------------------------------------------------------------------------------- /Image_Classification/cnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-ml-vision-end2end/954f5c6a4449f73ac9c702028427e1e1e60b002e/Image_Classification/cnn.png -------------------------------------------------------------------------------- /Image_Classification/lenet-mxnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-ml-vision-end2end/954f5c6a4449f73ac9c702028427e1e1e60b002e/Image_Classification/lenet-mxnet.png -------------------------------------------------------------------------------- /Image_Classification/mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-ml-vision-end2end/954f5c6a4449f73ac9c702028427e1e1e60b002e/Image_Classification/mlp.png -------------------------------------------------------------------------------- /Image_Classification/mxnet_cnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mxnet as mx 3 | import os 4 | import logging 5 | 6 | ############################### 7 | ### Model Building ### 8 | ############################### 9 | 10 | def conv_layer(x, nf, k): 11 | #convolution and activation 12 | x = mx.sym.Convolution(x, num_filter=nf, kernel=k) 13 | x = mx.sym.Activation(x, act_type='relu') 14 | #max pooling reduces spatial dimension by half, increasing receptive field 15 | x = mx.sym.Pooling(x, kernel=(2,2), stride=(2,2), pool_type='max') 16 | return x 17 | 18 | def build_cnn(conv_params, num_hidden, num_classes): 19 | 20 | #Input and label placeholders 21 | data = mx.sym.Variable('data') 22 | label = mx.sym.Variable('softmax_label') 23 | 24 | #build conv layers 25 | x = data 26 | for i, conv_param in 
enumerate(conv_params): 27 | x = conv_layer(x, conv_param[0], conv_param[1]) 28 | 29 | #flatten to fully dense layer 30 | x = mx.sym.Flatten(x, name='flat_1') 31 | #hidden layer 32 | x = mx.sym.FullyConnected(x, num_hidden=num_hidden, name='fc_1') 33 | x = mx.sym.Activation(x, act_type='relu', name='relu_3') 34 | #dense to vector of class length 35 | output = mx.sym.FullyConnected(x, num_hidden=num_classes, name='fc_2') 36 | #categorical cross-entropy 37 | loss = mx.sym.SoftmaxOutput(output, label, name='softmax') 38 | return loss 39 | 40 | ############################### 41 | ### Data Loading ### 42 | ############################### 43 | 44 | def get_data(f_path): 45 | train_X = np.load(os.path.join(f_path,'train_X.npy')) 46 | train_Y = np.load(os.path.join(f_path,'train_Y.npy')) 47 | validation_X = np.load(os.path.join(f_path,'validation_X.npy')) 48 | validation_Y = np.load(os.path.join(f_path,'validation_Y.npy')) 49 | return train_X, train_Y, validation_X, validation_Y 50 | 51 | ############################### 52 | ### Training Loop ### 53 | ############################### 54 | 55 | def train(channel_input_dirs, hyperparameters, hosts): 56 | conv_params = hyperparameters.get('conv_params', [[20, (5,5)], [50, (5,5)]]) 57 | num_fc = hyperparameters.get('num_fc', 128) 58 | num_classes = hyperparameters.get('num_classes', 43) 59 | batch_size = hyperparameters.get('batch_size', 64) 60 | epochs = hyperparameters.get('epochs', 10) 61 | learning_rate = hyperparameters.get('learning_rate', 1E-3) 62 | num_gpus = hyperparameters.get('num_gpus', 0) 63 | # set logging 64 | logging.getLogger().setLevel(logging.DEBUG) 65 | 66 | if len(hosts) == 1: 67 | kvstore = 'device' if num_gpus > 0 else 'local' 68 | else: 69 | kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync' 70 | ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()] 71 | 72 | f_path = channel_input_dirs['training'] 73 | train_X, train_Y, validation_X, validation_Y = get_data(f_path) 74 | 
75 | train_iter = mx.io.NDArrayIter(data=train_X, label=train_Y, batch_size=batch_size, shuffle=True) 76 | validation_iter = mx.io.NDArrayIter(data=validation_X, label=validation_Y, batch_size=batch_size, shuffle=False) 77 | sym = build_cnn(conv_params, num_fc, num_classes) 78 | net = mx.mod.Module(sym, context=ctx) 79 | net.fit(train_iter, 80 | eval_data=validation_iter, 81 | initializer = mx.initializer.Xavier(magnitude=2.24), 82 | optimizer='adam', 83 | optimizer_params={'learning_rate':learning_rate}, 84 | eval_metric='acc', 85 | num_epoch=epochs, kvstore=kvstore) 86 | 87 | return net 88 | 89 | -------------------------------------------------------------------------------- /Image_Classification/signnames.csv: -------------------------------------------------------------------------------- 1 | a,b 2 | 0,Speed limit (20km/h) 3 | 1,Speed limit (30km/h) 4 | 2,Speed limit (50km/h) 5 | 3,Speed limit (60km/h) 6 | 4,Speed limit (70km/h) 7 | 5,Speed limit (80km/h) 8 | 6,End of speed limit (80km/h) 9 | 7,Speed limit (100km/h) 10 | 8,Speed limit (120km/h) 11 | 9,No passing 12 | 10,No passing for vehicles over 3.5 metric tons 13 | 11,Right-of-way at the next intersection 14 | 12,Priority road 15 | 13,Yield 16 | 14,Stop 17 | 15,No vehicles 18 | 16,Vehicles over 3.5 metric tons prohibited 19 | 17,No entry 20 | 18,General caution 21 | 19,Dangerous curve to the left 22 | 20,Dangerous curve to the right 23 | 21,Double curve 24 | 22,Bumpy road 25 | 23,Slippery road 26 | 24,Road narrows on the right 27 | 25,Road work 28 | 26,Traffic signals 29 | 27,Pedestrians 30 | 28,Children crossing 31 | 29,Bicycles crossing 32 | 30,Beware of ice/snow 33 | 31,Wild animals crossing 34 | 32,End of all speed and passing limits 35 | 33,Turn right ahead 36 | 34,Turn left ahead 37 | 35,Ahead only 38 | 36,Go straight or right 39 | 37,Go straight or left 40 | 38,Keep right 41 | 39,Keep left 42 | 40,Roundabout mandatory 43 | 41,End of no passing 44 | 42,End of no passing by vehicles over 3.5 metric 
tons 45 | -------------------------------------------------------------------------------- /Image_Classification/traffic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-samples/aws-ml-vision-end2end/954f5c6a4449f73ac9c702028427e1e1e60b002e/Image_Classification/traffic.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | AWS Ml Vision End2end 2 | Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AWS Ml Vision End2end 2 | 3 | This repository contains Jupyter Notebook tutorials for computer vision use-cases. The tutorials take you end-to-end through the process of developing a deep-learning model for computer vision: 4 | * Load, Explore, and Understand the data relevant to your computer vision task 5 | * Prototype deep learning models in the MXNet Framework, using both the MXNet Symbolic API and the imperative Gluon interface. 6 | * Port prototype code to run scalable training jobs using the Amazon SageMaker platform 7 | * Deploy trained models to inference endpoints using Amazon SageMaker 8 | * Deploy trained models to the Edge using AWS DeepLens. 
9 | 10 | ## Tutorials Available 11 | * [Image Classification](https://nbviewer.jupyter.org/github/aws-samples/aws-ml-vision-end2end/blob/master/Image_Classification/Image_Classification_Tutorial.ipynb): A level-101 Intro to Computer Vision with Deep Learning, this tutorial covers a very simple image classification task for traffic sign classification. 12 | 13 | ## Coming Soon 14 | * Transfer Learning 15 | * Object-Detection 16 | * Semantic Segmentation 17 | 18 | ## Data Sets Used 19 | * [Traffic Sign Dataset](http://benchmark.ini.rub.de/?section=gtsrb&subsection=dataset#Imageformat) 20 | 21 | ## License 22 | 23 | This library is licensed under the Apache 2.0 License. 24 | -------------------------------------------------------------------------------- /Segmentation_SageMaker/segmentation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import mxnet as mx 3 | from mxnet import ndarray as F 4 | from mxnet.io import DataBatch, DataDesc 5 | import os 6 | import numpy as np 7 | import logging 8 | import urllib 9 | import zipfile 10 | import tarfile 11 | import shutil 12 | import gzip 13 | from glob import glob 14 | import random 15 | import json 16 | 17 | ############################### 18 | ### Loss Functions ### 19 | ############################### 20 | 21 | def dice_coef(y_true, y_pred): 22 | intersection = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=(1, 2, 3)) 23 | return mx.sym.broadcast_div((2. * intersection + 1.),(mx.sym.sum(y_true, axis=(1, 2, 3)) + mx.sym.sum(y_pred, axis=(1, 2, 3)) + 1.)) 24 | 25 | def dice_coef_loss(y_true, y_pred): 26 | intersection = mx.sym.sum(mx.sym.broadcast_mul(y_true, y_pred), axis=1, ) 27 | return -mx.sym.broadcast_div((2. 
* intersection + 1.),(mx.sym.broadcast_add(mx.sym.sum(y_true, axis=1), mx.sym.sum(y_pred, axis=1)) + 1.)) 28 | 29 | ############################### 30 | ### UNet Architecture ### 31 | ############################### 32 | 33 | def conv_block(inp, num_filter, kernel, pad, block, conv_block): 34 | conv = mx.sym.Convolution(inp, num_filter=num_filter, kernel=kernel, pad=pad, name='conv%i_%i' % (block, conv_block)) 35 | conv = mx.sym.BatchNorm(conv, fix_gamma=True, name='bn%i_%i' % (block, conv_block)) 36 | conv = mx.sym.Activation(conv, act_type='relu', name='relu%i_%i' % (block, conv_block)) 37 | return conv 38 | 39 | def down_block(inp, num_filter, kernel, pad, block, pool=True): 40 | conv = conv_block(inp, num_filter, kernel, pad, block, 1) 41 | conv = conv_block(conv, num_filter, kernel, pad, block, 2) 42 | if pool: 43 | pool = mx.sym.Pooling(conv, kernel=(2,2), stride=(2,2), pool_type='max', name='pool_%i' % block) 44 | return pool, conv 45 | return conv 46 | 47 | def down_branch(inp): 48 | pool1, conv1 = down_block(inp, num_filter=32, kernel=(3,3), pad=(1,1), block=1) 49 | pool2, conv2 = down_block(pool1, num_filter=64, kernel=(3,3), pad=(1,1), block=2) 50 | pool3, conv3 = down_block(pool2, num_filter=128, kernel=(3,3), pad=(1,1), block=3) 51 | pool4, conv4 = down_block(pool3, num_filter=256, kernel=(3,3), pad=(1,1), block=4) 52 | conv5 = down_block(pool4, num_filter=512, kernel=(3,3), pad=(1,1), block=5, pool=False) 53 | return [conv5, conv4, conv3, conv2, conv1] 54 | 55 | def up_block(inp, down_feature, num_filter, kernel, pad, block): 56 | trans_conv = mx.sym.Deconvolution(inp, num_filter=num_filter, kernel=(2,2), stride=(2,2), no_bias=True, 57 | name='trans_conv_%i' % block) 58 | up = mx.sym.concat(*[trans_conv, down_feature], dim=1, name='concat_%i' % block) 59 | conv = conv_block(up, num_filter, kernel, pad, block, 1) 60 | conv = conv_block(conv, num_filter, kernel, pad, block, 2) 61 | return conv 62 | 63 | def up_branch(down_features): 64 | conv6 = 
up_block(down_features[0], down_features[1], num_filter=256, kernel=(3,3), pad=(1,1), block=6) 65 | conv7 = up_block(conv6, down_features[2], num_filter=128, kernel=(3,3), pad=(1,1), block=7) 66 | conv8 = up_block(conv7, down_features[3], num_filter=64, kernel=(3,3), pad=(1,1), block=8) 67 | conv9 = up_block(conv8, down_features[4], num_filter=64, kernel=(3,3), pad=(1,1), block=9) 68 | conv10 = mx.sym.Convolution(conv9, num_filter=1, kernel=(1,1), name='conv10_1') 69 | return conv10 70 | 71 | # dice_coef_loss is defined once above, in the Loss Functions section 72 | 73 | 74 | 75 | def build_unet(inference=False): 76 | data = mx.sym.Variable(name='data') 77 | down_features = down_branch(data) 78 | decoded = up_branch(down_features) 79 | decoded = mx.sym.sigmoid(decoded, name='softmax') 80 | if inference: 81 | return decoded 82 | else: 83 | 84 | net = mx.sym.Flatten(decoded) 85 | label = mx.sym.Variable(name='label') 86 | label = mx.sym.Flatten(label, name='label_flatten') 87 | loss = mx.sym.MakeLoss(dice_coef_loss(label, net), normalization='batch') 88 | mask_output = mx.sym.BlockGrad(decoded, 'mask') 89 | out = mx.sym.Group([loss, mask_output]) 90 | return out 91 | 92 | ############################### 93 | ### Training Script ### 94 | ############################### 95 | 96 | def get_data(f_path): 97 | train_X = np.load(os.path.join(f_path,'train_X_crops.npy')) 98 | train_Y = np.load(os.path.join(f_path,'train_Y_crops.npy')) 99 | validation_X = np.load(os.path.join(f_path,'validation_X_crops.npy')) 100 | validation_Y = np.load(os.path.join(f_path,'validation_Y_crops.npy')) 101 | return train_X, train_Y, validation_X, validation_Y 102 | 103 | def train(channel_input_dirs, hyperparameters, hosts, **kwargs): 104 | # retrieve the hyperparameters we set in notebook (with some defaults) 105 | 
batch_size = hyperparameters.get('batch_size', 128) 106 | epochs = hyperparameters.get('epochs', 100) 107 | learning_rate = hyperparameters.get('learning_rate', 0.1) 108 | beta1 = hyperparameters.get('beta1', 0.9) 109 | beta2 = hyperparameters.get('beta2', 0.99) 110 | num_gpus = hyperparameters.get('num_gpus', 0) 111 | burn_in = hyperparameters.get('burn_in', 5) 112 | # set logging 113 | logging.getLogger().setLevel(logging.DEBUG) 114 | 115 | if len(hosts) == 1: 116 | kvstore = 'device' if num_gpus > 0 else 'local' 117 | else: 118 | kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync' 119 | 120 | ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()] 121 | print (ctx) 122 | f_path = channel_input_dirs['training'] 123 | train_X, train_Y, validation_X, validation_Y = get_data(f_path) 124 | 125 | print ('loaded data') 126 | 127 | train_iter = mx.io.NDArrayIter(data=train_X, label=train_Y, batch_size=batch_size, shuffle=True) 128 | validation_iter = mx.io.NDArrayIter(data=validation_X, label=validation_Y, batch_size=batch_size, shuffle=False) 129 | data_shape = (batch_size,) + train_X.shape[1:] 130 | label_shape = (batch_size,) + train_Y.shape[1:] 131 | 132 | print ('created iters') 133 | 134 | sym = build_unet() 135 | net = mx.mod.Module(sym, context=ctx, data_names=('data',), label_names=('label',)) 136 | net.bind(data_shapes=[['data', data_shape]], label_shapes=[['label', label_shape]]) 137 | net.init_params(mx.initializer.Xavier(magnitude=6)) 138 | net.init_optimizer(kvstore=kvstore, optimizer='adam', 139 | optimizer_params=( 140 | ('learning_rate', learning_rate), 141 | ('beta1', beta1), 142 | ('beta2', beta2) 143 | )) 144 | print ('start training') 145 | smoothing_constant = .01 146 | curr_losses = [] 147 | moving_losses = [] 148 | i = 0 149 | best_val_loss = np.inf 150 | for e in range(epochs): 151 | while True: 152 | try: 153 | batch = next(train_iter) 154 | except StopIteration: 155 | train_iter.reset() 156 | break 157 | 
net.forward_backward(batch) 158 | loss = net.get_outputs()[0] 159 | net.update() 160 | curr_loss = F.mean(loss).asscalar() 161 | curr_losses.append(curr_loss) 162 | moving_loss = (curr_loss if ((i == 0) and (e == 0)) 163 | else (1 - smoothing_constant) * moving_loss + (smoothing_constant) * curr_loss) 164 | moving_losses.append(moving_loss) 165 | i += 1 166 | val_losses = [] 167 | for batch in validation_iter: 168 | net.forward(batch) 169 | loss = net.get_outputs()[0] 170 | val_losses.append(F.mean(loss).asscalar()) 171 | validation_iter.reset() 172 | # early stopping 173 | val_loss = np.mean(val_losses) 174 | if e > burn_in and val_loss < best_val_loss: 175 | best_val_loss = val_loss 176 | net.save_checkpoint('best_net', 0) 177 | print("Best model at Epoch %i" %(e+1)) 178 | print("Epoch %i: Moving Training Loss %0.5f, Validation Loss %0.5f" % (e+1, moving_loss, val_loss)) 179 | inference_sym = build_unet(inference=True) 180 | net = mx.mod.Module(inference_sym, context=ctx, data_names=('data',)) 181 | net.bind(data_shapes=[['data', data_shape]]) 182 | net.load_params('best_net-0000.params') 183 | return net --------------------------------------------------------------------------------
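A note on the smoothed dice formula used by `segmentation.py`: it can be sanity-checked outside of MXNet. The sketch below is hypothetical (`dice_coef_np` is not part of this repository) and reimplements the same quantity the symbolic `dice_coef` builds, `(2*|A∩B| + 1) / (|A| + |B| + 1)` per sample, in plain NumPy:

```python
import numpy as np

def dice_coef_np(y_true, y_pred, smooth=1.0):
    # Sum over every axis except the batch axis, mirroring
    # mx.sym.sum(..., axis=(1, 2, 3)) in the symbolic version.
    axes = tuple(range(1, y_true.ndim))
    intersection = np.sum(y_true * y_pred, axis=axes)
    return (2.0 * intersection + smooth) / (
        np.sum(y_true, axis=axes) + np.sum(y_pred, axis=axes) + smooth)

# Two 4x4 binary masks with 4 positive pixels each: the first prediction
# matches the ground truth exactly, the second predicts nothing.
truth = np.zeros((2, 1, 4, 4))
truth[:, 0, :2, :2] = 1.0
pred = np.stack([truth[0], np.zeros((1, 4, 4))])

scores = dice_coef_np(truth, pred)
# perfect overlap: (2*4 + 1) / (4 + 4 + 1) = 1.0
# no overlap:      (0 + 1)   / (4 + 0 + 1) = 0.2
```

The `+1` smoothing term keeps the ratio defined when a mask is empty, and `dice_coef_loss` simply negates this quantity so that minimizing the loss maximizes overlap.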