├── floyd_requirements.txt ├── requirements.txt ├── images ├── vgg19.png ├── alexnet.png ├── ant_bee.png ├── cnntsne.jpeg ├── densenet.png ├── resnet.png ├── squeezenet.png └── inceptionv3.png ├── test └── images │ ├── test1.jpg │ ├── test2.jpg │ ├── test4.jpg │ ├── test5.jpg │ ├── test6.jpg │ ├── test7.jpg │ └── test3.jpeg ├── .floydignore ├── LICENSE ├── .gitignore ├── app.py ├── imagenet_models.py ├── README.md └── main.py /floyd_requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | -------------------------------------------------------------------------------- /images/vgg19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/images/vgg19.png -------------------------------------------------------------------------------- /images/alexnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/images/alexnet.png -------------------------------------------------------------------------------- /images/ant_bee.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/images/ant_bee.png -------------------------------------------------------------------------------- /images/cnntsne.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/images/cnntsne.jpeg -------------------------------------------------------------------------------- /images/densenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/images/densenet.png -------------------------------------------------------------------------------- /images/resnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/images/resnet.png -------------------------------------------------------------------------------- /images/squeezenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/images/squeezenet.png -------------------------------------------------------------------------------- /test/images/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/test/images/test1.jpg -------------------------------------------------------------------------------- /test/images/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/test/images/test2.jpg -------------------------------------------------------------------------------- /test/images/test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/test/images/test4.jpg -------------------------------------------------------------------------------- /test/images/test5.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/test/images/test5.jpg -------------------------------------------------------------------------------- /test/images/test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/test/images/test6.jpg -------------------------------------------------------------------------------- /test/images/test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/test/images/test7.jpg -------------------------------------------------------------------------------- /images/inceptionv3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/images/inceptionv3.png -------------------------------------------------------------------------------- /test/images/test3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/imagenet/master/test/images/test3.jpeg -------------------------------------------------------------------------------- /.floydignore: -------------------------------------------------------------------------------- 1 | 2 | # Directories and files to ignore when uploading code to floyd 3 | 4 | FLOYD_README.md 5 | .git 6 | .eggs 7 | eggs 8 | lib 9 | lib64 10 | parts 11 | sdist 12 | core 13 | var 14 | *.pyc 15 | *.swp 16 | .DS_Store 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # Floyd things 11 | FLOYD_README.md 12 | .floydexpt 13 | .DS_Store 14 | core 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | 113 | # End of https://www.gitignore.io/api/python 114 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | """ 2 | Flask Serving 3 | This file is a sample Flask app that can be used to test your model with a REST API. 4 | This app does the following: 5 | - Looks for an image and then processes it to be ImageNet compliant 6 | - Returns the classification result 7 | 8 | POST req: 9 | parameter: 10 | - file, required, an image to classify 11 | 12 | """ 13 | import os 14 | import torch 15 | from flask import Flask, send_file, request 16 | from werkzeug.exceptions import BadRequest 17 | from werkzeug.utils import secure_filename 18 | from imagenet_models import ConvNet 19 | 20 | ALLOWED_EXTENSIONS = set(['jpg', 'png', 'jpeg']) 21 | 22 | MODEL_PATH = '/model' 23 | print('Loading model from path: %s' % MODEL_PATH) 24 | 25 | EVAL_PATH = '/eval' 26 | TRAIN_PATH = '/input/train' 27 | MODEL = "resnet18" 28 | 29 | # Make sure EVAL_PATH exists
30 | try: 31 | os.makedirs(EVAL_PATH) 32 | except OSError: 33 | pass 34 | 35 | app = Flask('ImageNet-Classifier') 36 | 37 | # Build the model once at startup to improve per-request performance 38 | checkpoint = os.path.join(MODEL_PATH, "model_best.pth.tar") # FIX to 39 | Model = ConvNet(ckp=checkpoint, train_dir=TRAIN_PATH, arch=MODEL) 40 | Model.build_model() 41 | 42 | # Classify an uploaded image 43 | @app.route('/', methods=['POST']) 44 | def generator_handler(): 45 | """Upload an image file, then 46 | preprocess and classify""" 47 | # check if the post request has the file part 48 | if 'file' not in request.files: 49 | return BadRequest("File not present in request") 50 | file = request.files['file'] 51 | if file.filename == '': 52 | return BadRequest("File name is not present in request") 53 | if not allowed_file(file.filename): 54 | return BadRequest("Invalid file type") 55 | filename = secure_filename(file.filename) 56 | image_folder = os.path.join(EVAL_PATH, "images") 57 | # Create dir /eval/images 58 | try: 59 | os.makedirs(image_folder) 60 | except OSError: 61 | pass 62 | # Save Image to process 63 | input_filepath = os.path.join(image_folder, filename) 64 | file.save(input_filepath) 65 | # Preprocess and Evaluate 66 | Model.image_preprocessing() 67 | pred = Model.classify() 68 | # Return classification and remove uploaded file 69 | output = "Images: {file}, Classified as {pred}\n".format(file=file.filename, 70 | pred=pred) 71 | os.remove(input_filepath) 72 | return output 73 | 74 | 75 | def allowed_file(filename): 76 | return '.' in filename and \ 77 | filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS 78 | 79 | if __name__ == '__main__': 80 | app.run(host='0.0.0.0') 81 | -------------------------------------------------------------------------------- /imagenet_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy 3 | import os 4 | import shutil 5 | import time 6 | from PIL import Image 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import torchvision.models as models 18 | 19 | 20 | class ConvNet(object): 21 | """ConvNet Model class""" 22 | def __init__(self, 23 | arch="resnet18", 24 | ckp="/model/model_best.pth.tar", 25 | train_dir="/input/train", 26 | evalf="/eval"): 27 | """ImageNet ConvNet Builder 28 | Args: 29 | ckp: path to model checkpoint file (to continue training). 30 | evalf: path to evaluate sample. 31 | """ 32 | # Path to model weight 33 | self._ckp = ckp 34 | # Use CUDA?
35 | self._cuda = torch.cuda.is_available() 36 | # os.path.isfile() returns a bool and never raises, so check it explicitly 37 | if os.path.isfile(ckp): 38 | self._ckp = ckp 39 | else: 40 | # Does not exist OR no read permissions 41 | print ("Unable to open ckp file") 42 | self._evalf = evalf 43 | self._arch = arch 44 | # Input size depends on the model architecture 45 | if arch.startswith('inception'): 46 | self._size = (299, 299) 47 | else: 48 | self._size = (224, 256) 49 | # Get labels 50 | self._labels = self._get_label(train_dir) 51 | 52 | 53 | # Build the model and load the weights 54 | def build_model(self): 55 | # Create model from scratch or use a pretrained one 56 | print("=> using model '{}'".format(self._arch)) 57 | self._model = models.__dict__[self._arch](num_classes=len(self._labels)) 58 | print("=> loading checkpoint '{}'".format(self._ckp)) 59 | if self._cuda: 60 | checkpoint = torch.load(self._ckp) 61 | else: 62 | # Load GPU model on CPU 63 | checkpoint = torch.load(self._ckp, map_location=lambda storage, loc: storage) 64 | # Load weights 65 | self._model.load_state_dict(checkpoint['state_dict']) 66 | 67 | if self._cuda: 68 | self._model.cuda() 69 | else: 70 | self._model.cpu() 71 | 72 | 73 | # Preprocess Images to be ImageNet-compliant 74 | def image_preprocessing(self): 75 | """Take images from the evaluation folder and preprocess them 76 | to be ImageNet compliant for the chosen ConvNet model""" 77 | # Normalize on RGB Value 78 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 79 | std=[0.229, 0.224, 0.225]) 80 | def pil_loader(path): 81 | """Load an image from the /eval/ subfolder and resize it to a square""" 82 | with open(path, 'rb') as f: 83 | with Image.open(f) as img: 84 | sqrWidth = numpy.ceil(numpy.sqrt(img.size[0]*img.size[1])).astype(int) 85 | return img.resize((sqrWidth, sqrWidth)) 86 | 87 | self._test_loader = torch.utils.data.DataLoader( 88 | datasets.ImageFolder(self._evalf, transforms.Compose([ 89 | transforms.Scale(self._size[1]), # 256 90 | transforms.CenterCrop(self._size[0]), # 224 , 299 91 | transforms.ToTensor(), 92 | normalize, 93 | ]), loader=pil_loader), 94 | batch_size=1, shuffle=False, 95 | num_workers=1, pin_memory=False) 96 | 97 | 98 | def _get_label(self, train_dir): 99 | # Normalize on RGB Value 100 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 101 | std=[0.229, 0.224, 0.225]) 102 | # Train -> Preprocessing -> Tensor 103 | train_dataset = datasets.ImageFolder( 104 | train_dir, 105 | transforms.Compose([ 106 | transforms.RandomSizedCrop(self._size[0]), #224 , 299 107 | transforms.RandomHorizontalFlip(), 108 | transforms.ToTensor(), 109 | normalize, 110 | ])) 111 | 112 | # Return the list of class labels 113 | return train_dataset.classes 114 | 115 | 116 | def classify(self): 117 | """Classify the current test batch""" 118 | self._model.eval() 119 | for data, _ in self._test_loader: 120 | if self._cuda: 121 | data = data.cuda() 122 | data = torch.autograd.Variable(data, volatile=True) 123 | output = self._model(data) 124 | # Take last layer output 125 | if isinstance(output, tuple): 126 | output = output[len(output)-1] 127 | 128 | lab = self._labels[numpy.asscalar(output.data.max(1, keepdim=True)[1].cpu().numpy())] 129 | print (self._labels, lab) 130 | return lab -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Training and Transfer Learning ImageNet model in PyTorch 2 | 3 | This project implements: 4 | - [Training](#imagenet-training-in-pytorch) of popular model architectures, such as
ResNet, AlexNet, and VGG on the ImageNet dataset; 5 | - [Transfer learning](#transfer-learning) from the most popular model architectures above, fine-tuning only the last fully connected layer. 6 | 7 | *Note*: 8 | 9 | - **ImageNet training** will be documented in the next release. 10 | - **Transfer-learning** was fully tested on alexnet, densenet121, inception_v3, resnet18 and vgg19. The other models will be tested in the next release. 11 | 12 | ## Usage 13 | 14 | ```bash 15 | usage: main.py [-h] [--data DIR] [--outf OUTF] [--evalf EVALF] [--arch ARCH] 16 | [-j N] [--epochs N] [--start-epoch N] [-b N] [--lr LR] 17 | [--momentum M] [--weight-decay W] [--print-freq N] 18 | [--resume PATH] [-e] [--train] [--test] [-t] [--pretrained] 19 | [--world-size WORLD_SIZE] [--dist-url DIST_URL] 20 | [--dist-backend DIST_BACKEND] 21 | 22 | PyTorch ImageNet Training 23 | 24 | optional arguments: 25 | -h, --help show this help message and exit 26 | --data DIR path to dataset 27 | --outf OUTF folder to output model checkpoints 28 | --evalf EVALF path to evaluate sample 29 | --arch ARCH, -a ARCH model architecture: alexnet | densenet121 | 30 | densenet161 | densenet169 | densenet201 | inception_v3 31 | | resnet101 | resnet152 | resnet18 | resnet34 | 32 | resnet50 | squeezenet1_0 | squeezenet1_1 | vgg11 | 33 | vgg11_bn | vgg13 | vgg13_bn | vgg16 | vgg16_bn | vgg19 34 | | vgg19_bn (default: resnet18) 35 | -j N, --workers N number of data loading workers (default: 4) 36 | --epochs N number of total epochs to run 37 | --start-epoch N manual epoch number (useful on restarts) 38 | -b N, --batch-size N mini-batch size (default: 256) 39 | --lr LR, --learning-rate LR 40 | initial learning rate 41 | --momentum M momentum 42 | --weight-decay W, --wd W 43 | weight decay (default: 1e-4) 44 | --print-freq N, -p N print frequency (default: 10) 45 | --resume PATH path to latest checkpoint (default: none) 46 | -e, --evaluate evaluate model on validation set 47 | --train train the model 48 | --test test a [pre]trained model on new images 49 | -t, --fine-tuning 50 | transfer learning enabled + fine tuning - train only the last FC 51 | layer.
52 | --pretrained use pre-trained model 53 | --world-size WORLD_SIZE 54 | number of distributed processes 55 | --dist-url DIST_URL url used to set up distributed training 56 | --dist-backend DIST_BACKEND 57 | distributed backend 58 | ``` 59 | 60 | 61 | ## ImageNet Model Architectures 62 | 63 | #### Alexnet 64 | 65 | ![alexnet architecture](images/alexnet.png) 66 | 67 | Credit: [Imagenet classification with deep convolutional neural networks paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) 68 | 69 | #### Densenet 70 | 71 | ![densenet architecture](images/densenet.png) 72 | 73 | Credit: [Densely Connected Convolutional Networks](https://arxiv.org/pdf/1608.06993.pdf) 74 | 75 | #### Inception_v3 76 | 77 | ![inception_v3 architecture](images/inceptionv3.png) 78 | 79 | Credit: [Rethinking the Inception Architecture for Computer Vision paper](https://arxiv.org/pdf/1512.00567.pdf) (image taken from [google research blog](https://research.googleblog.com/2016/08/improving-inception-and-image.html)) 80 | 81 | #### Resnet34 82 | 83 | ![Resnet architecture](images/resnet.png) 84 | 85 | Credit: [Deep Residual Learning for Image Recognition paper](https://arxiv.org/pdf/1512.03385v1.pdf) 86 | 87 | #### Squeezenet1_0 88 | 89 | ![Squeezenet architecture](images/squeezenet.png) 90 | 91 | Credit: [Squeezenet: Alexnet-level accuracy with 50x fewer parameters and <0.5MB model size](https://arxiv.org/pdf/1602.07360.pdf) 92 | 93 | #### Vgg19net 94 | 95 | ![vgg19 architecture](images/vgg19.png) 96 | 97 | Credit: [Very Deep Convolutional Networks For Large-Scale Image Recognition paper](https://arxiv.org/pdf/1409.1556v6.pdf) (image taken from Resnet paper) 98 | 99 | ## ImageNet training in PyTorch 100 | 101 | ![imagenet dataset tsne visualization](images/cnntsne.jpeg) 102 | 103 | *Credit: [karpathy.github.io](http://karpathy.github.io/2014/09/02/what-i-learned-from-competing-against-a-convnet-on-imagenet/)* 104 | 105 | This project implements the ImageNet classification task on the [ImageNet](http://www.image-net.org/) dataset with several well-known Convolutional Neural Network (CNN or ConvNet) models. This is a port of [pytorch/examples/imagenet](https://github.com/pytorch/examples/tree/master/imagenet) adapted to run on [FloydHub](https://www.floydhub.com). 106 | 107 | ### Requirements 108 | 109 | Download the ImageNet dataset and move validation images to labeled subfolders. Unfortunately, ImageNet is not yet fully supported as a [torchvision.dataset](http://pytorch.org/docs/master/torchvision/datasets.html#imagenet-12), so we use the [ImageFolder API](http://pytorch.org/docs/master/torchvision/datasets.html#imagefolder), which expects the dataset to be laid out as follows (see also the note at the end of the listing below): 110 | 111 | ```bash 112 | ls /dataset 113 | 114 | train 115 | val 116 | test 117 | 118 | # Train 119 | ls /dataset/train 120 | cat 121 | dog 122 | tiger 123 | plane 124 | ... 125 | 126 | ls /dataset/train/cat 127 | cat01.png 128 | cat02.png 129 | ... 130 | 131 | ls /dataset/train/dog 132 | dog01.jpg 133 | dog02.jpg 134 | ... 135 | ...[other class folders] 136 | 137 | # Val 138 | ls /dataset/val 139 | cat 140 | dog 141 | tiger 142 | plane 143 | ... 144 | 145 | ls /dataset/val/cat 146 | cat01.png 147 | cat02.png 148 | ... 149 | 150 | ls /dataset/val/dog 151 | dog01.jpg 152 | dog02.jpg 153 | ... 154 | 155 | # Test 156 | ls /dataset/test 157 | images 158 | 159 | ls /dataset/test/images 160 | test01.png 161 | test02.png 162 | ...
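# Note: the extra "images" subfolder under /dataset/test is required because
# torchvision's ImageFolder expects at least one class subdirectory; at test
# time it simply acts as a single dummy class.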
163 | 164 | ``` 165 | 166 | Once you have built the dataset following the steps above, upload it as a FloydHub dataset following this guide: [create and upload FloydHub dataset](https://docs.floydhub.com/guides/create_and_upload_dataset/). 167 | 168 | ### Run on FloydHub 169 | 170 | Here are the commands to train and evaluate your [pre-trained] model on FloydHub (this section will be improved in the next release): 171 | 172 | #### Project Setup 173 | 174 | Before you start, log in on FloydHub with the [floyd login](http://docs.floydhub.com/commands/login/) command, then fork and init 175 | the project: 176 | 177 | ```bash 178 | $ git clone https://github.com/floydhub/imagenet.git 179 | $ cd imagenet 180 | $ floyd init imagenet 181 | ``` 182 | 183 | #### Training 184 | 185 | To train a model, run `main.py` with the desired model architecture and the path to the ImageNet dataset: 186 | 187 | ```bash 188 | floyd run --gpu --data /datasets/imagenet/:input "python main.py -a resnet18 [other params]" 189 | ``` 190 | 191 | The default learning rate schedule starts at 0.1 and decays by a factor of 10 every 30 epochs. This is appropriate for ResNet and models with batch normalization, but too high for AlexNet and VGG. Use 0.01 as the initial learning rate for AlexNet or VGG: 192 | 193 | ```bash 194 | floyd run --gpu --data /datasets/imagenet/:input "python main.py -a --lr 0.01 [other params]" 195 | ``` 196 | 197 | **Note**: 198 | 199 | A full training on ImageNet *can take weeks*, depending on the selected model. 200 | 201 | #### Evaluating 202 | 203 | It's time to evaluate our model with some images (put the images you want to classify in the `test/images` folder): 204 | ```bash 205 | floyd run --gpu --env pytorch-0.2 --data /datasets/imagenet/:input --data :model "python main.py -a --test --evalf test/ --resume /model/model_best.pth.tar" 206 | ``` 207 | #### Try a PyTorch Pretrained Model 208 | 209 | PyTorch provides pretrained weights for several architectures; if you want to evaluate your dataset with one of these models, run: 210 | 211 | ```bash 212 | floyd run --gpu --data /datasets//:input "python main.py -a [arch] --pretrained --data /input/test [other params]" 213 | ``` 214 | 215 | 216 | #### Serve model through REST API 217 | 218 | FloydHub supports serving mode for demo and testing purposes. Before serving your model through a REST API, you need to create a `floyd_requirements.txt` and declare the flask requirement in it. If you run a job with the `--mode serve` flag, FloydHub will run the `app.py` file in your project and attach it to a dynamic service endpoint: 219 | 220 | ```bash 221 | floyd run --gpu --mode serve --env pytorch-0.2 --data /datasets//:input --data :model 222 | ``` 223 | 224 | Note: 225 | The script retrieves the number of classes from the dataset mounted with `--data /datasets//`. This behavior will be fixed in the next release. 226 | 227 | The above command will print out a service endpoint for this job in your terminal console. 228 | 229 | The service endpoint will take a couple of minutes to become ready. Once it's up, you can interact with the model by sending an image file with a POST request, which the model will classify (according to ImageNet labels): 230 | 231 | ```bash 232 | # Template 233 | # curl -X POST -F "file=@" 234 | 235 | # Example POST request 236 | curl -X POST -F "file=@./test/images/test01.png" https://www.floydlabs.com/expose/BhZCFAKom6Z8RptVKskHZW 237 | ``` 238 | 239 | Any job running in serving mode will stay up until it reaches its maximum runtime.
So once you are done testing, **remember to shut down the job!** 240 | 241 | *Note that this feature is in preview mode and is not production ready yet.* 242 | 243 | ## Transfer Learning 244 | 245 | ![Bees Vs Ants dataset](images/ant_bee.png) 246 | 247 | This project implements a transfer learning classification task on the [Bees Vs Ants](https://download.pytorch.org/tutorial/hymenoptera_data.zip) toy dataset (train: 124 images of ants and 121 images of bees; val: 70 images of ants and 83 images of bees) with different Convolutional Neural Network (CNN or ConvNet) models. This is a port of the [transfer learning tutorial from the official PyTorch Docs](http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html) adapted to run on [FloydHub](https://www.floydhub.com). 248 | 249 | Credit: [Sasank Chilamkurthy](https://chsasank.github.io/), who wrote the excellent tutorial on transfer learning in the PyTorch docs. 250 | 251 | ### Run on FloydHub 252 | 253 | Here are the commands to train and evaluate your [pre-trained] model on FloydHub (this section will be improved in the next release): 254 | 255 | #### Project Setup 256 | 257 | Before you start, log in on FloydHub with the [floyd login](http://docs.floydhub.com/commands/login/) command, then fork and init 258 | the project: 259 | 260 | ```bash 261 | $ git clone https://github.com/floydhub/imagenet.git 262 | $ cd imagenet 263 | $ floyd init imagenet 264 | ``` 265 | 266 | #### Training 267 | 268 | I have already uploaded the dataset to FloydHub so that you can familiarize yourself with the `--data` parameter, which mounts the specified volume (dataset/model) inside the container of your FloydHub instance. Now it's time to run our training on FloydHub. In this example we will train the model for 10 epochs on a GPU instance. 269 | **Note**: If you want to mount/create a dataset, [look at the docs](https://docs.floydhub.com/guides/create_and_upload_dataset/). 270 | 271 | ```bash 272 | floyd run --gpu --env pytorch-0.2 --data redeipirati/datasets/pytorch-hymenoptera/1:input "python main.py -a resnet18 --train --fine-tuning --pretrained --epochs 10 -b 4" 273 | ``` 274 | 275 | Note: 276 | 277 | - `--gpu` runs your job on a FloydHub GPU instance 278 | - `--env pytorch-0.2` prepares a PyTorch environment for Python 3. 279 | - `--data redeipirati/datasets/pytorch-hymenoptera/1` mounts the PyTorch hymenoptera dataset (bees vs ants) in the /input folder inside the container for our job so that we do not need to download it at training time. 280 | 281 | #### Evaluating 282 | 283 | It's time to evaluate our model with some images: 284 | ```bash 285 | floyd run --gpu --env pytorch-0.2 --data redeipirati/datasets/pytorch-hymenoptera/1:input --data :model "python main.py -a resnet18 --test --fine-tuning --evalf test/ --resume /model/model_best.pth.tar" 286 | ``` 287 | 288 | Notes: 289 | 290 | - I've prepared some images in the `test` folder that you can use to evaluate your model. Feel free to add a bunch of bee/ant images downloaded from the web. 291 | - Remember to evaluate images that come from a distribution similar to the training data, otherwise performance will suffer because of the distribution mismatch. 292 | 293 | #### Try our pre-trained model 294 | We provide a pre-trained model, trained for 30 epochs, with an accuracy of about 95%.
295 | 296 | ```bash 297 | floyd run --gpu --env pytorch-0.2 --data redeipirati/datasets/pytorch-hymenoptera/1:input --data redeipirati/datasets/pytorch-hymenoptera-30-epochs-resnet18-model/1:model "python main.py -a resnet18 --test --fine-tuning --evalf test/ --resume /model/model_best.pth.tar" 298 | ``` 299 | 300 | #### Serve model through REST API 301 | 302 | FloydHub supports serving mode for demo and testing purposes. Before serving your model through a REST API, you need to create a `floyd_requirements.txt` and declare the flask requirement in it. If you run a job with the `--mode serve` flag, FloydHub will run the `app.py` file in your project and attach it to a dynamic service endpoint: 303 | 304 | ```bash 305 | floyd run --gpu --mode serve --env pytorch-0.2 --data redeipirati/datasets/pytorch-hymenoptera/1:input --data :model 306 | ``` 307 | 308 | Note: 309 | The script retrieves the number of classes from the dataset mounted with `--data redeipirati/datasets/pytorch-hymenoptera/1`. This behavior will be fixed in the next release. 310 | 311 | The above command will print out a service endpoint for this job in your terminal console. 312 | 313 | The service endpoint will take a couple of minutes to become ready. Once it's up, you can interact with the model by sending an image file (of an ant or a bee) with a POST request, which the model will classify: 314 | 315 | ```bash 316 | # Template 317 | # curl -X POST -F "file=@" 318 | 319 | # Example POST request 320 | curl -X POST -F "file=@./test/images/test01.png" https://www.floydlabs.com/expose/BhZCFAKom6Z8RptVKskHZW 321 | ``` 322 | 323 | Any job running in serving mode will stay up until it reaches its maximum runtime. So once you are done testing, **remember to shut down the job!** 324 | 325 | *Note that this feature is in preview mode and is not production ready yet.* 326 | 327 | ## More resources 328 | 329 | Some useful resources on ImageNet and the famous ConvNet models: 330 | 331 | - [ILSVRC (ImageNet Large Scale Visual Recognition Challenge)](http://www.image-net.org/challenges/LSVRC/) 332 | - [Karpathy CNN and ImageNet](http://karpathy.github.io/2014/09/02/what-i-learned-from-competing-against-a-convnet-on-imagenet/) 333 | - [CS231n CNN](http://cs231n.github.io/convolutional-networks/) 334 | - [CS231n understanding cnn](http://cs231n.github.io/understanding-cnn/) 335 | - [CS231n transfer learning](http://cs231n.github.io/transfer-learning/) 336 | - [An Intuitive Explanation of Convolutional Neural Networks](https://ujjwalkarn.me/2016/08/11/intuitive-explanation-convnets/) 337 | - [FloydHub Building your first CNN](https://blog.floydhub.com/building-your-first-convnet/) 338 | - [Inception v3](https://research.googleblog.com/2016/08/improving-inception-and-image.html) 339 | - [How does Deep Residual Net work?](https://www.quora.com/How-does-deep-residual-learning-work) 340 | - [How does the Inception module work?](https://www.quora.com/How-does-the-Inception-module-work-in-GoogLeNet-deep-architecture) 341 | - [Squeezenet](http://www.kdnuggets.com/2016/09/deep-learning-reading-group-squeezenet.html) 342 | - [Famous CNN models explained on KDnuggets](http://www.kdnuggets.com/2016/09/9-key-deep-learning-papers-explained.html/) 343 | 344 | ## Contributing 345 | 346 | For any questions, bugs (even typos), and/or feature requests, do not hesitate to contact me or open an issue!
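## Example: Calling the REST endpoint from Python

If you prefer Python over `curl` for testing the serving endpoints described above, here is a minimal sketch of the same POST request. It assumes the `requests` package is installed and that the endpoint URL is replaced with the one printed by your own serve job; it is only an illustration and is not part of the project code.

```python
import requests

# Replace with the service endpoint printed by `floyd run --mode serve ...`
SERVICE_ENDPOINT = "https://www.floydlabs.com/expose/BhZCFAKom6Z8RptVKskHZW"

# Send one test image as multipart form data under the "file" field,
# which is the field name app.py expects.
with open("test/images/test1.jpg", "rb") as f:
    response = requests.post(SERVICE_ENDPOINT, files={"file": f})

# Prints something like: "Images: test1.jpg, Classified as <label>"
print(response.text)
```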
347 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy 3 | import os 4 | import shutil 5 | import time 6 | from PIL import Image 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.parallel 10 | import torch.backends.cudnn as cudnn 11 | import torch.distributed as dist 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import torchvision.models as models 18 | 19 | # Load all model arch available on Pytorch 20 | model_names = sorted(name for name in models.__dict__ 21 | if name.islower() and not name.startswith("__") 22 | and callable(models.__dict__[name])) 23 | 24 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 25 | parser.add_argument('--data', default='/input', metavar='DIR', 26 | help='path to dataset') 27 | parser.add_argument('--outf', default='/output', 28 | help='folder to output model checkpoints') 29 | parser.add_argument('--evalf', default="/eval" ,help='path to evaluate sample') 30 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', 31 | choices=model_names, 32 | help='model architecture: ' + 33 | ' | '.join(model_names) + 34 | ' (default: resnet18)') 35 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 36 | help='number of data loading workers (default: 4)') 37 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 40 | help='manual epoch number (useful on restarts)') 41 | parser.add_argument('-b', '--batch-size', default=256, type=int, 42 | metavar='N', help='mini-batch size (default: 256)') 43 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 44 | metavar='LR', help='initial learning rate') 45 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 46 | help='momentum') 47 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 48 | metavar='W', help='weight decay (default: 1e-4)') 49 | parser.add_argument('--print-freq', '-p', default=10, type=int, 50 | metavar='N', help='print frequency (default: 10)') 51 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 52 | help='path to latest checkpoint (default: none)') 53 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 54 | help='evaluate model on validation set') 55 | parser.add_argument('--train', action='store_true', 56 | help='train the model') 57 | parser.add_argument('--test', action='store_true', 58 | help='test a [pre]trained model on new images') 59 | parser.add_argument('-t', '--fine-tuning', action='store_true', 60 | help='transfer learning + fine tuning - train only the last FC layer.') 61 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 62 | help='use pre-trained model') 63 | parser.add_argument('--world-size', default=1, type=int, 64 | help='number of distributed processes') 65 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 66 | help='url used to set up distributed training') 67 | parser.add_argument('--dist-backend', default='gloo', type=str, 68 | help='distributed backend') 69 | 70 | best_prec1 = torch.FloatTensor([0]) 71 | 72 | def 
get_images_name(folder): 73 | """Create a generator to list images name at evaluation time""" 74 | onlyfiles = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))] 75 | for f in onlyfiles: 76 | yield f 77 | 78 | def pil_loader(path): 79 | """Load images from /eval/ subfolder and resized it as squared""" 80 | with open(path, 'rb') as f: 81 | with Image.open(f) as img: 82 | sqrWidth = numpy.ceil(numpy.sqrt(img.size[0]*img.size[1])).astype(int) 83 | return img.resize((sqrWidth, sqrWidth)) 84 | 85 | def main(): 86 | global args, best_prec1, cuda, labels 87 | args = parser.parse_args() 88 | 89 | try: 90 | os.makedirs(args.outf) 91 | # os.makedirs(opt.outf+"/model") 92 | except OSError: 93 | pass 94 | 95 | # can we use CUDA? 96 | cuda = False #torch.cuda.is_available() 97 | print ("=> using cuda: {cuda}".format(cuda=cuda)) 98 | # Not working on FloydHub 99 | # if torch.cuda.device_count() is not None: 100 | # print ("=> available cuda devices: {dev}").format(dev=torch.cuda.device_count()) 101 | # Distributed Training? 102 | args.distributed = args.world_size > 1 103 | print ("=> distributed training: {dist}".format(dist=args.distributed)) 104 | 105 | ############ DATA PREPROCESSING ############ 106 | # Data loading code 107 | traindir = os.path.join(args.data, 'train') 108 | valdir = os.path.join(args.data, 'val') 109 | testdir = args.evalf 110 | # Normalize on RGB Value 111 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 112 | std=[0.229, 0.224, 0.225]) 113 | 114 | # Size on model 115 | if args.arch.startswith('inception'): 116 | size = (299, 299) 117 | else: 118 | size = (224, 256) 119 | 120 | # Train -> Preprocessing -> Tensor 121 | train_dataset = datasets.ImageFolder( 122 | traindir, 123 | transforms.Compose([ 124 | transforms.RandomSizedCrop(size[0]), #224 , 299 125 | transforms.RandomHorizontalFlip(), 126 | transforms.ToTensor(), 127 | normalize, 128 | ])) 129 | 130 | #print (train_dataset.classes) 131 | # Get number of labels 132 | labels = len(train_dataset.classes) 133 | 134 | if args.distributed: 135 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 136 | else: 137 | train_sampler = None 138 | 139 | # Pin memory 140 | if cuda: 141 | pin_memory = True 142 | else: 143 | pin_memory = False 144 | 145 | train_loader = torch.utils.data.DataLoader( 146 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 147 | num_workers=args.workers, pin_memory=pin_memory, sampler=train_sampler) 148 | 149 | # Validate -> Preprocessing -> Tensor 150 | val_loader = torch.utils.data.DataLoader( 151 | datasets.ImageFolder(valdir, transforms.Compose([ 152 | transforms.Scale(size[1]), # 256 153 | transforms.CenterCrop(size[0]), # 224 , 299 154 | transforms.ToTensor(), 155 | normalize, 156 | ])), 157 | batch_size=args.batch_size, shuffle=False, 158 | num_workers=args.workers, pin_memory=pin_memory) 159 | 160 | if args.test: 161 | # Testing -> Preprocessing -> Tensor 162 | test_loader = torch.utils.data.DataLoader( 163 | datasets.ImageFolder(testdir, transforms.Compose([ 164 | transforms.Scale(size[1]), # 256 165 | transforms.CenterCrop(size[0]), # 224 , 299 166 | transforms.ToTensor(), 167 | normalize, 168 | ]), loader=pil_loader), 169 | batch_size=1, shuffle=False, 170 | num_workers=args.workers, pin_memory=pin_memory) 171 | ############ BUILD MODEL ############ 172 | if args.distributed: 173 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 174 | world_size=args.world_size) 175 | 176 | # 
Create model from scratch or use a pretrained one 177 | if args.pretrained: 178 | print("=> using pre-trained model '{}'".format(args.arch)) 179 | model = models.__dict__[args.arch](pretrained=True) 180 | # print(model) 181 | # quit() 182 | else: 183 | print("=> creating model '{}'".format(args.arch)) 184 | model = models.__dict__[args.arch](num_classes=labels) 185 | # print(model) 186 | 187 | # Freeze model, train only the last FC layer for the transfered task 188 | if args.fine_tuning: 189 | print("=> transfer-learning mode + fine-tuning (train only the last FC layer)") 190 | # Freeze Previous Layers(now we are using them as features extractor) 191 | for param in model.parameters(): 192 | param.requires_grad = False 193 | # Fine Tuning the last Layer For the new task 194 | # RESNET 195 | if args.arch == 'resnet18': 196 | num_ftrs = model.fc.in_features 197 | model.fc = nn.Linear(num_ftrs, labels) 198 | parameters = model.fc.parameters() 199 | # print(model) 200 | # quit() 201 | # ALEXNET & VGG 202 | elif args.arch == 'alexnet' or args.arch == 'vgg19': 203 | model.classifier._modules['6'] = nn.Linear(4096, labels) 204 | parameters = model.classifier._modules['6'].parameters() 205 | # print(model) 206 | # quit() 207 | elif args.arch == 'densenet121': # DENSENET 208 | model.classifier = nn.Linear(1024, labels) 209 | parameters = model.classifier.parameters() 210 | # print(model) 211 | # quit() 212 | # INCEPTION 213 | elif args.arch == 'inception_v3': 214 | # Auxiliary Fc layer 215 | num_ftrs = model.AuxLogits.fc.in_features 216 | model.AuxLogits.fc = nn.Linear(num_ftrs, labels) 217 | # parameters = model.AuxLogits.fc.parameters() 218 | # print (parameters) 219 | # Last layer 220 | num_ftrs = model.fc.in_features 221 | model.fc = nn.Linear(num_ftrs, labels) 222 | parameters = model.fc.parameters() 223 | # print(model) 224 | # quit() 225 | else: 226 | print("Error: Fine-tuning is not supported on this architecture.") 227 | exit(-1) 228 | else: 229 | parameters = model.parameters() 230 | 231 | # Not working on FloydHub 232 | # if torch.cuda.device_count() is not None: 233 | # # Set [Distributed]DataParallel only on more than 1 cuda(GPUs) devices 234 | # if cuda and torch.cuda.device_count() > 1: 235 | # # Local or Distributed Enviroment 236 | # if not args.distributed: 237 | # print("=> load model on '{}' cuda devices".format(torch.cuda.device_count())) 238 | # if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 239 | # model.features = torch.nn.DataParallel(model.features) 240 | # else: 241 | # model = torch.nn.DataParallel(model) 242 | # else: 243 | # print("=> load model on on distributed enviroment".format(torch.cuda.device_count())) 244 | # model = torch.nn.parallel.DistributedDataParallel(model) 245 | 246 | # Define loss function (criterion) and optimizer 247 | criterion = nn.CrossEntropyLoss() 248 | if cuda: 249 | criterion.cuda() 250 | 251 | # Set SGD + Momentum 252 | optimizer = torch.optim.SGD(parameters, args.lr, 253 | momentum=args.momentum, 254 | weight_decay=args.weight_decay) 255 | 256 | # optionally resume from a checkpoint 257 | if args.resume: 258 | if os.path.isfile(args.resume): 259 | print("=> loading checkpoint '{}'".format(args.resume)) 260 | if cuda: 261 | checkpoint = torch.load(args.resume) 262 | else: 263 | # Load GPU model on CPU 264 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage) 265 | args.start_epoch = checkpoint['epoch'] 266 | best_prec1 = checkpoint['best_prec1'] 267 | 
model.load_state_dict(checkpoint['state_dict']) 268 | optimizer.load_state_dict(checkpoint['optimizer']) 269 | print("=> loaded checkpoint '{}' (epoch {})" 270 | .format(args.resume, checkpoint['epoch'])) 271 | else: 272 | print("=> no checkpoint found at '{}'".format(args.resume)) 273 | 274 | # Load model on GPU or CPU 275 | if cuda: 276 | model.cuda() 277 | else: 278 | model.cpu() 279 | ############ TRAIN/EVAL/TEST ############ 280 | cudnn.benchmark = True 281 | 282 | # Evaluate? 283 | if args.evaluate: 284 | print("=> evaluating...") 285 | validate(val_loader, model, criterion) 286 | return 287 | 288 | # Testing? 289 | if args.test: 290 | print("=> testing...") 291 | # Name generator 292 | names = get_images_name(os.path.join(testdir, 'images')) 293 | test(test_loader, model, names, train_dataset.classes) 294 | return 295 | 296 | # Training 297 | if args.train: 298 | print("=> training...") 299 | for epoch in range(args.start_epoch, args.epochs): 300 | if args.distributed: 301 | train_sampler.set_epoch(epoch) 302 | adjust_learning_rate(optimizer, epoch) 303 | 304 | # Train for one epoch 305 | train(train_loader, model, criterion, optimizer, epoch) 306 | 307 | # Evaluate on validation set 308 | prec1 = validate(val_loader, model, criterion) 309 | # print (prec1) 310 | 311 | # Remember best prec@1 and save checkpoint 312 | if cuda: 313 | prec1 = prec1.cpu() # Load on CPU if CUDA 314 | # Get bool not ByteTensor 315 | is_best = bool(prec1.numpy() > best_prec1.numpy()) 316 | # Get greater Tensor 317 | best_prec1 = torch.FloatTensor(max(prec1.numpy(), best_prec1.numpy())) 318 | save_checkpoint({ 319 | 'epoch': epoch + 1, 320 | 'arch': args.arch, 321 | 'state_dict': model.state_dict(), 322 | 'best_prec1': best_prec1, 323 | 'optimizer' : optimizer.state_dict(), 324 | }, is_best) 325 | 326 | 327 | def train(train_loader, model, criterion, optimizer, epoch): 328 | """Train the model on Training Set""" 329 | batch_time = AverageMeter() 330 | data_time = AverageMeter() 331 | losses = AverageMeter() 332 | top1 = AverageMeter() 333 | top5 = AverageMeter() 334 | 335 | # switch to train mode 336 | model.train() 337 | 338 | end = time.time() 339 | for i, (input, target) in enumerate(train_loader): 340 | # measure data loading time 341 | data_time.update(time.time() - end) 342 | if cuda: 343 | input, target = input.cuda(async=True), target.cuda(async=True) 344 | 345 | input_var = torch.autograd.Variable(input) 346 | target_var = torch.autograd.Variable(target) 347 | 348 | # compute output 349 | output = model(input_var) 350 | #topk = (1,5) if labels >= 100 else (1,) # TO FIX 351 | # For nets that have multiple outputs such as Inception 352 | if isinstance(output, tuple): 353 | loss = sum((criterion(o,target_var) for o in output)) 354 | # print (output) 355 | for o in output: 356 | prec1 = accuracy(o.data, target, topk=(1,)) 357 | top1.update(prec1[0], input.size(0)) 358 | losses.update(loss.data[0], input.size(0)*len(output)) 359 | else: 360 | loss = criterion(output, target_var) 361 | prec1 = accuracy(output.data, target, topk=(1,)) 362 | top1.update(prec1[0], input.size(0)) 363 | losses.update(loss.data[0], input.size(0)) 364 | 365 | # compute gradient and do SGD step 366 | optimizer.zero_grad() 367 | loss.backward() 368 | optimizer.step() 369 | 370 | # measure elapsed time 371 | batch_time.update(time.time() - end) 372 | end = time.time() 373 | 374 | # Info log every args.print_freq 375 | if i % args.print_freq == 0: 376 | print('Epoch: [{0}][{1}/{2}]\t' 377 | 'Time {batch_time.val:.3f} 
({batch_time.avg:.3f})\t' 378 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 379 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 380 | 'Prec@1 {top1_val} ({top1_avg})'.format( 381 | epoch, i, len(train_loader), batch_time=batch_time, 382 | data_time=data_time, loss=losses, 383 | top1_val=numpy.asscalar(top1.val.cpu().numpy()), 384 | top1_avg=numpy.asscalar(top1.avg.cpu().numpy()))) 385 | 386 | 387 | def validate(val_loader, model, criterion): 388 | """Validate the model on Validation Set""" 389 | batch_time = AverageMeter() 390 | losses = AverageMeter() 391 | top1 = AverageMeter() 392 | top5 = AverageMeter() 393 | 394 | # switch to evaluate mode 395 | model.eval() 396 | 397 | end = time.time() 398 | # Evaluate all the validation set 399 | for i, (input, target) in enumerate(val_loader): 400 | if cuda: 401 | input, target = input.cuda(async=True), target.cuda(async=True) 402 | input_var = torch.autograd.Variable(input, volatile=True) 403 | target_var = torch.autograd.Variable(target, volatile=True) 404 | 405 | # compute output 406 | output = model(input_var) 407 | # print ("Output: ", output) 408 | #topk = (1,5) if labels >= 100 else (1,) # TODO: add more topk evaluation 409 | # For nets that have multiple outputs such as Inception 410 | if isinstance(output, tuple): 411 | loss = sum((criterion(o,target_var) for o in output)) 412 | # print (output) 413 | for o in output: 414 | prec1 = accuracy(o.data, target, topk=(1,)) 415 | top1.update(prec1[0], input.size(0)) 416 | losses.update(loss.data[0], input.size(0)*len(output)) 417 | else: 418 | loss = criterion(output, target_var) 419 | prec1 = accuracy(output.data, target, topk=(1,)) 420 | top1.update(prec1[0], input.size(0)) 421 | losses.update(loss.data[0], input.size(0)) 422 | 423 | # measure elapsed time 424 | batch_time.update(time.time() - end) 425 | end = time.time() 426 | 427 | # Info log every args.print_freq 428 | if i % args.print_freq == 0: 429 | print('Test: [{0}/{1}]\t' 430 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 431 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 432 | 'Prec@1 {top1_val} ({top1_avg})'.format( 433 | i, len(val_loader), batch_time=batch_time, 434 | loss=losses, 435 | top1_val=numpy.asscalar(top1.val.cpu().numpy()), 436 | top1_avg=numpy.asscalar(top1.avg.cpu().numpy()))) 437 | 438 | print(' * Prec@1 {top1}' 439 | .format(top1=numpy.asscalar(top1.avg.cpu().numpy()))) 440 | return top1.avg 441 | 442 | 443 | def test(test_loader, model, names, classes): 444 | """Test the model on the Evaluation Folder 445 | 446 | Args: 447 | - classes: is a list with the class name 448 | - names: is a generator to retrieve the filename that is classified 449 | """ 450 | # switch to evaluate mode 451 | model.eval() 452 | # Evaluate all the validation set 453 | for i, (input, _) in enumerate(test_loader): 454 | if cuda: 455 | input = input.cuda(async=True) 456 | input_var = torch.autograd.Variable(input, volatile=True) 457 | 458 | # compute output 459 | output = model(input_var) 460 | # Take last layer output 461 | if isinstance(output, tuple): 462 | output = output[len(output)-1] 463 | 464 | # print (output.data.max(1, keepdim=True)[1]) 465 | lab = classes[numpy.asscalar(output.data.max(1, keepdim=True)[1].cpu().numpy())] 466 | print ("Images: " + next(names) + ", Classified as: " + lab) 467 | 468 | 469 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 470 | torch.save(state, os.path.join(args.outf, filename)) 471 | if is_best: 472 | shutil.copyfile(os.path.join(args.outf, filename), 
os.path.join(args.outf,'model_best.pth.tar')) 473 | 474 | 475 | class AverageMeter(object): 476 | """Computes and stores the average and current value""" 477 | def __init__(self): 478 | self.reset() 479 | 480 | def reset(self): 481 | self.val = 0 482 | self.avg = 0 483 | self.sum = 0 484 | self.count = 0 485 | 486 | def update(self, val, n=1): 487 | self.val = val 488 | self.sum += val * n 489 | self.count += n 490 | self.avg = self.sum / self.count 491 | 492 | 493 | def adjust_learning_rate(optimizer, epoch): 494 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 495 | lr = args.lr * (0.1 ** (epoch // 30)) 496 | for param_group in optimizer.param_groups: 497 | param_group['lr'] = lr 498 | 499 | 500 | def accuracy(output, target, topk=(1,)): 501 | """Computes the precision@k for the specified values of k""" 502 | maxk = max(topk) 503 | batch_size = target.size(0) 504 | 505 | _, pred = output.topk(maxk, 1, True, True) 506 | pred = pred.t() 507 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 508 | 509 | res = [] 510 | for k in topk: 511 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 512 | res.append(correct_k.mul_(100.0 / batch_size)) 513 | return res 514 | 515 | 516 | if __name__ == '__main__': 517 | main() 518 | --------------------------------------------------------------------------------