├── test └── images │ ├── 1.png │ ├── 3.jpg │ └── 9.jpg ├── images ├── mnist_convet.png └── MNIST_samples.png ├── .floydignore ├── LICENSE ├── .gitignore ├── app.py ├── ConvNet.py ├── README.md └── main.py /test/images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/mnist/master/test/images/1.png -------------------------------------------------------------------------------- /test/images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/mnist/master/test/images/3.jpg -------------------------------------------------------------------------------- /test/images/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/mnist/master/test/images/9.jpg -------------------------------------------------------------------------------- /images/mnist_convet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/mnist/master/images/mnist_convet.png -------------------------------------------------------------------------------- /images/MNIST_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/floydhub/mnist/master/images/MNIST_samples.png -------------------------------------------------------------------------------- /.floydignore: -------------------------------------------------------------------------------- 1 | 2 | # Directories and files to ignore when uploading code to floyd 3 | 4 | .git 5 | .eggs 6 | eggs 7 | lib 8 | lib64 9 | parts 10 | sdist 11 | var 12 | *.pyc 13 | *.swp 14 | .DS_Store 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python 3 | 4 | ### Python ### 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # Floyd things 11 | FLOYD_README.md 12 | .floydexpt 13 | .DS_Store 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | /site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | 112 | # End of https://www.gitignore.io/api/python 113 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | """ 2 | Flask Serving 3 | This file is a sample flask app that can be used to test your model with an REST API. 4 | This app does the following: 5 | - Look for a Image and then process it to be MNIST compliant 6 | - Returns the evaluation 7 | Additional configuration: 8 | - You can also choose the checkpoint file name to use as a request parameter 9 | - Parameter name: ckp 10 | - It is loaded from /model 11 | 12 | POST req: 13 | parameter: 14 | - file, required, a handwritten digit in [0-9] range 15 | - ckp, optional, load a specific chekcpoint from /model 16 | 17 | """ 18 | import os 19 | import torch 20 | from flask import Flask, send_file, request 21 | from werkzeug.exceptions import BadRequest 22 | from werkzeug.utils import secure_filename 23 | from ConvNet import ConvNet 24 | 25 | ALLOWED_EXTENSIONS = set(['jpg', 'png', 'jpeg']) 26 | 27 | MODEL_PATH = '/input' 28 | print('Loading model from path: %s' % MODEL_PATH) 29 | 30 | EVAL_PATH = '/eval' 31 | # Is there the EVAL_PATH? 32 | try: 33 | os.makedirs(EVAL_PATH) 34 | except OSError: 35 | pass 36 | 37 | app = Flask('MNIST-Classifier') 38 | 39 | 40 | # Return an Image 41 | @app.route('/', methods=['POST']) 42 | def geneator_handler(): 43 | """Upload an handwrittend digit image in range [0-9], then 44 | preprocess and classify""" 45 | # check if the post request has the file part 46 | if 'file' not in request.files: 47 | return BadRequest("File not present in request") 48 | file = request.files['file'] 49 | if file.filename == '': 50 | return BadRequest("File name is not present in request") 51 | if not allowed_file(file.filename): 52 | return BadRequest("Invalid file type") 53 | filename = secure_filename(file.filename) 54 | image_folder = os.path.join(EVAL_PATH, "images") 55 | # Create dir /eval/images 56 | try: 57 | os.makedirs(image_folder) 58 | except OSError: 59 | pass 60 | # Save Image to process 61 | input_filepath = os.path.join(image_folder, filename) 62 | file.save(input_filepath) 63 | # Get ckp 64 | checkpoint = request.form.get("ckp") or "/input/mnist_convnet_model_epoch_10.pth" # FIX to 65 | 66 | # Preprocess, Build and Evaluate 67 | Model = ConvNet(ckp=checkpoint) 68 | Model.image_preprocessing() 69 | Model.build_model() 70 | pred = Model.classify() 71 | 72 | output = "Images: {file}, Classified as {pred}\n".format(file=file.filename, 73 | pred=int(pred)) 74 | os.remove(input_filepath) 75 | return output 76 | 77 | 78 | def allowed_file(filename): 79 | return '.' in filename and \ 80 | filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS 81 | 82 | if __name__ == "__main__": 83 | print("* Starting web server... please wait until server has fully started") 84 | app.run(host='0.0.0.0', threaded=False) 85 | -------------------------------------------------------------------------------- /ConvNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import numpy 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision.datasets import ImageFolder 9 | from PIL import Image 10 | import numpy as np 11 | from torchvision import datasets, transforms 12 | from torch.autograd import Variable 13 | 14 | class Net(nn.Module): 15 | """ConvNet -> Max_Pool -> RELU -> ConvNet -> Max_Pool -> RELU -> FC -> RELU -> FC -> SOFTMAX""" 16 | def __init__(self): 17 | super(Net, self).__init__() 18 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 19 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 20 | self.fc1 = nn.Linear(4*4*50, 500) 21 | self.fc2 = nn.Linear(500, 10) 22 | 23 | def forward(self, x): 24 | x = F.relu(self.conv1(x)) 25 | x = F.max_pool2d(x, 2, 2) 26 | x = F.relu(self.conv2(x)) 27 | x = F.max_pool2d(x, 2, 2) 28 | x = x.view(-1, 4*4*50) 29 | x = F.relu(self.fc1(x)) 30 | x = self.fc2(x) 31 | return F.log_softmax(x, dim=1) 32 | 33 | 34 | class ConvNet(object): 35 | """MNIST ConvNet Model class""" 36 | def __init__(self, 37 | ckp="/input/mnist_convnet_model_epoch_10.pth", 38 | evalf="/eval"): 39 | """MNIST ConvNet Builder 40 | 41 | Args: 42 | ckp: path to model checkpoint file (to continue training). 43 | evalf: path to evaluate sample. 44 | """ 45 | # Path to model weight 46 | self._ckp = ckp 47 | # Use CUDA? 48 | self._cuda = torch.cuda.is_available() 49 | self._device = torch.device("cuda" if self._cuda else "cpu") 50 | try: 51 | os.path.isfile(ckp) 52 | self.ckp = ckp 53 | except IOError as e: 54 | # Does not exist OR no read permissions 55 | print ("Unable to open ckp file") 56 | self._evalf = evalf 57 | 58 | 59 | # Build the model loading the weights 60 | def build_model(self): 61 | self._model = Net().to(self._device) 62 | 63 | # Load Weights 64 | if self._cuda: 65 | self._model.load_state_dict(torch.load(self._ckp)) 66 | else: 67 | # Load GPU model on CPU 68 | self._model.load_state_dict(torch.load(self._ckp, map_location=lambda storage, loc: storage)) 69 | 70 | 71 | # Preprocess Images to be MNIST-compliant 72 | def image_preprocessing(self): 73 | """Take images from args.evalf, process to be MNIST compliant 74 | and classify them with MNIST ConvNet model""" 75 | def pil_loader(path): 76 | """Load images from /eval/ subfolder, convert to greyscale and resized it as squared""" 77 | with open(path, 'rb') as f: 78 | with Image.open(f) as img: 79 | sqrWidth = np.ceil(np.sqrt(img.size[0]*img.size[1])).astype(int) 80 | return img.convert('L').resize((sqrWidth, sqrWidth)) 81 | 82 | kwargs = {'num_workers': 1, 'pin_memory': True} if self._cuda else {} 83 | self._eval_loader = torch.utils.data.DataLoader(ImageFolder(root=self._evalf, 84 | transform=transforms.Compose([ 85 | transforms.Resize(28), 86 | transforms.CenterCrop(28), 87 | transforms.ToTensor(), 88 | transforms.Normalize((0.1307,), (0.3081,)) 89 | ]), loader=pil_loader), batch_size=1, **kwargs) 90 | 91 | def classify(self): 92 | """Classify the current eval batch""" 93 | self._model.eval() 94 | with torch.no_grad(): 95 | for data, target in self._eval_loader: 96 | data = data.to(self._device) 97 | output = self._model(data) 98 | label = output.argmax(dim=1, keepdim=True).item() 99 | return label 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Basic MNIST Example 2 | 3 | ![MNIST samples](images/MNIST_samples.png) 4 | 5 | This project implements a beginner classification task on [MNIST](http://yann.lecun.com/exdb/mnist/) dataset with a [Convolutional Neural Network(CNN or ConvNet)](https://en.wikipedia.org/wiki/Convolutional_neural_network) model. This is a porting of [pytorch/examples/mnist](https://github.com/pytorch/examples/tree/master/mnist) making it usables on [FloydHub](https://www.floydhub.com/). 6 | 7 | ## Usage 8 | 9 | Training/Evaluating script: 10 | 11 | ```bash 12 | usage: main.py [-h] [--dataroot DATAROOT] [--evalf EVALF] [--outf OUTF] 13 | [--ckpf CKPF] [--batch-size N] [--test-batch-size N] 14 | [--epochs N] [--lr LR] [--momentum M] [--no-cuda] [--seed S] 15 | [--log-interval N] [--train] [--evaluate] 16 | 17 | PyTorch MNIST Example 18 | 19 | optional arguments: 20 | -h, --help show this help message and exit 21 | --dataroot DATAROOT path to dataset 22 | --evalf EVALF path to evaluate sample 23 | --outf OUTF folder to output images and model checkpoints 24 | --ckpf CKPF path to model checkpoint file (to continue training) 25 | --batch-size N input batch size for training (default: 64) 26 | --test-batch-size N input batch size for testing (default: 1000) 27 | --epochs N number of epochs to train (default: 10) 28 | --lr LR learning rate (default: 0.01) 29 | --momentum M SGD momentum (default: 0.5) 30 | --no-cuda disables CUDA training 31 | --seed S random seed (default: 1) 32 | --log-interval N how many batches to wait before logging training status 33 | --train training a ConvNet model on MNIST dataset 34 | --evaluate evaluate a [pre]trained model 35 | ``` 36 | 37 | If you want to use more GPUs set `CUDA_VISIBLE_DEVICES` as bash variable then run your script: 38 | 39 | ```bash 40 | # CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id to ex. 2 41 | ``` 42 | 43 | ## MNIST CNN Architecture 44 | 45 | ![MNIST CNN](images/mnist_convet.png) 46 | 47 | ## Run on FloydHub 48 | 49 | Here's the commands to training, evaluating and serving your MNIST ConvNet model on [FloydHub](ttps://www.floydhub.com/). 50 | 51 | ### Project Setup 52 | 53 | Before you start, log in on FloydHub with the [floyd login](http://docs.floydhub.com/commands/login/) command, then fork and init the project (make sure you have already [created the project on FloydHub](https://docs.floydhub.com/guides/basics/create_new/)): 54 | 55 | ```bash 56 | $ git clone https://github.com/floydhub/mnist.git 57 | $ cd mnist 58 | $ floyd init mnist 59 | ``` 60 | 61 | ### Training 62 | 63 | This project will automatically dowload and process the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset for you, moreover I have already [uploaded it as FloydHub dataset](https://www.floydhub.com/redeipirati/datasets/pytorch-mnist) so that you can try and familiarize with `--data` parameter which mount the specified volume(datasets/model) inside the container of your FloydHub instance. 64 | 65 | Now it's time to run our training on FloydHub. In this example we will train the model for 10 epochs with a gpu instance and with cuda enabled. 66 | **Note**: If you want to mount/create a dataset look at the [docs](http://docs.floydhub.com/guides/basics/create_new/#create-a-new-dataset). 67 | 68 | ```bash 69 | $ floyd run --gpu --env pytorch-1.0 --data redeipirati/datasets/pytorch-mnist/1:input "python main.py --train" 70 | ``` 71 | 72 | Note: 73 | - `--gpu` run your job on a FloydHub GPU instance 74 | - `--env pytorch-1.0`, PyTorch 1.0 on Python3 75 | - `--data redeipirati/datasets/pytorch-mnist/1` mounts the pytorch mnist dataset in the `/input` folder inside the container for our job so that we do not need to dowload it at training time. 76 | 77 | 78 | You can follow along the progress by using the [logs](http://docs.floydhub.com/commands/logs/) command. 79 | The training should take about 2 minutes on a GPU instance and about 15 minutes on a CPU one. 80 | 81 | ### Evaluating 82 | 83 | It's time to evaluate our model with some images: 84 | 85 | ```bash 86 | floyd run --gpu --env pytorch-1.0 --data :resume "python main.py --evaluate --ckpf /resume/ --evalf ./test" 87 | ``` 88 | 89 | Notes: 90 | 91 | - I've prepared for you some images in the `test` folder that you can use to evaluate your model. Feel free to add on it a bunch of handwritten images download from the web or created by you. 92 | - Remember to evaluate images which are taken from a similar distribution, otherwise you will have bad performance due to distribution mismatch. 93 | 94 | ### Try our pre-trained model 95 | 96 | We have provided to you a pre-trained model trained for 10 epochs with an accuracy of 98%. 97 | 98 | ```bash 99 | floyd run --gpu --env pytorch-1.0 --data redeipirati/datasets/pytorch-mnist-10-epochs-model/2:/model "python main.py --evaluate --ckpf /model/mnist_convnet_model_epoch_10.pth --evalf ./test" 100 | ``` 101 | 102 | ### Serve model through REST API 103 | 104 | FloydHub supports seving mode for demo and testing purpose. If you run a job 105 | with `--mode serve` flag, FloydHub will run the `app.py` file in your project 106 | and attach it to a dynamic service endpoint: 107 | 108 | ```bash 109 | floyd run --gpu --mode serve --env pytorch-1.0 --data :input 110 | ``` 111 | 112 | The above command will print out a service endpoint for this job in your terminal console. Or you can use the more name-friendly (static) serving URL that you will find in the Model API tab of your project(`https://www.floydlabs.com/serve//projects/`) 113 | 114 | The service endpoint will take a couple minutes to become ready. Once it's up, you can interact with the model by sending an handwritten image file with a POST request that the model will classify: 115 | ```bash 116 | # Template 117 | # curl -X POST -F "file=@" -F "ckp=" 118 | 119 | # e.g. of a POST req 120 | curl -X POST -F "file=@./test/images/1.png" https://www.floydlabs.com/serve/BhZCFAKom6Z8RptVKskHZW 121 | ``` 122 | 123 | Any job running in serving mode will stay up until it reaches maximum runtime. So 124 | once you are done testing, **remember to shutdown the job!** 125 | 126 | ## More resources 127 | 128 | Some useful resources on MNIST and ConvNet: 129 | 130 | - [MNIST](http://yann.lecun.com/exdb/mnist/) 131 | - [Colah's blog](https://colah.github.io/posts/2014-10-Visualizing-MNIST/) 132 | - [FloydHub Building your first ConvNet](https://blog.floydhub.com/building-your-first-convnet/) 133 | - [How Convolutional Neural Networks work - Brandon Rohrer](https://youtu.be/FmpDIaiMIeA) 134 | - [An Intuitive Explanation of Convolutional Neural Networks](https://ujjwalkarn.me/2016/08/11/intuitive-explanation-convnets/) 135 | - [Stanford CS231n](https://cs231n.github.io/convolutional-networks/) 136 | - [Stanford CS231n Winter 2016 - Karpathy](https://youtu.be/NfnWJUyUJYU) 137 | 138 | ## Contributing 139 | 140 | For any questions, bug(even typos) and/or features requests do not hesitate to contact me or open an issue! 141 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import numpy 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torchvision.datasets import ImageFolder 10 | from PIL import Image 11 | import numpy as np 12 | from torchvision import datasets, transforms 13 | from torch.autograd import Variable 14 | 15 | # Training settings 16 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 17 | parser.add_argument('--dataroot', default="/input/" ,help='path to dataset') 18 | parser.add_argument('--evalf', default="/eval/" ,help='path to evaluate sample') 19 | parser.add_argument('--outf', default='models', 20 | help='folder to output images and model checkpoints') 21 | parser.add_argument('--ckpf', default='', 22 | help="path to model checkpoint file (to continue training)") 23 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 24 | help='input batch size for training (default: 64)') 25 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 26 | help='input batch size for testing (default: 1000)') 27 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 28 | help='number of epochs to train (default: 10)') 29 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 30 | help='learning rate (default: 0.01)') 31 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 32 | help='SGD momentum (default: 0.5)') 33 | parser.add_argument('--no-cuda', action='store_true', default=False, 34 | help='disables CUDA training') 35 | parser.add_argument('--seed', type=int, default=1, metavar='S', 36 | help='random seed (default: 1)') 37 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 38 | help='how many batches to wait before logging training status') 39 | parser.add_argument('--train', action='store_true', 40 | help='training a ConvNet model on MNIST dataset') 41 | parser.add_argument('--evaluate', action='store_true', 42 | help='evaluate a [pre]trained model') 43 | 44 | 45 | args = parser.parse_args() 46 | # use CUDA? 47 | use_cuda = not args.no_cuda and torch.cuda.is_available() 48 | device = torch.device("cuda" if use_cuda else "cpu") 49 | 50 | # Is there the outf? 51 | try: 52 | os.makedirs(args.outf) 53 | except OSError: 54 | pass 55 | 56 | torch.manual_seed(args.seed) 57 | if use_cuda: 58 | torch.cuda.manual_seed(args.seed) 59 | 60 | # From MNIST to Tensor 61 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 62 | 63 | # Load MNIST only if training 64 | if args.train: 65 | train_loader = torch.utils.data.DataLoader( 66 | datasets.MNIST(root=args.dataroot, train=True, download=True, 67 | transform=transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize((0.1307,), (0.3081,)) 70 | ])), 71 | batch_size=args.batch_size, shuffle=True, **kwargs) 72 | test_loader = torch.utils.data.DataLoader( 73 | datasets.MNIST(root=args.dataroot, train=False, transform=transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.1307,), (0.3081,)) 76 | ])), 77 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 78 | 79 | 80 | class Net(nn.Module): 81 | """ConvNet -> Max_Pool -> RELU -> ConvNet -> Max_Pool -> RELU -> FC -> RELU -> FC -> SOFTMAX""" 82 | def __init__(self): 83 | super(Net, self).__init__() 84 | self.conv1 = nn.Conv2d(1, 20, 5, 1) 85 | self.conv2 = nn.Conv2d(20, 50, 5, 1) 86 | self.fc1 = nn.Linear(4*4*50, 500) 87 | self.fc2 = nn.Linear(500, 10) 88 | 89 | def forward(self, x): 90 | x = F.relu(self.conv1(x)) 91 | x = F.max_pool2d(x, 2, 2) 92 | x = F.relu(self.conv2(x)) 93 | x = F.max_pool2d(x, 2, 2) 94 | x = x.view(-1, 4*4*50) 95 | x = F.relu(self.fc1(x)) 96 | x = self.fc2(x) 97 | return F.log_softmax(x, dim=1) 98 | 99 | 100 | model = Net().to(device) 101 | 102 | # Load checkpoint 103 | if args.ckpf != '': 104 | if use_cuda: 105 | model.load_state_dict(torch.load(args.ckpf)) 106 | else: 107 | # Load GPU model on CPU 108 | model.load_state_dict(torch.load(args.ckpf, map_location=lambda storage, loc: storage)) 109 | 110 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 111 | 112 | 113 | def train(args, model, device, train_loader, optimizer, epoch): 114 | """Training""" 115 | model.train() 116 | for batch_idx, (data, target) in enumerate(train_loader): 117 | data, target = data.to(device), target.to(device) 118 | optimizer.zero_grad() 119 | output = model(data) 120 | loss = F.nll_loss(output, target) 121 | loss.backward() 122 | optimizer.step() 123 | if batch_idx % args.log_interval == 0: 124 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 125 | epoch, batch_idx * len(data), len(train_loader.dataset), 126 | 100. * batch_idx / len(train_loader), loss.item())) 127 | print('{{"metric": "Train - NLL Loss", "value": {}}}'.format( 128 | loss.item())) 129 | 130 | 131 | def test(args, model, device, test_loader, epoch): 132 | """Testing""" 133 | model.eval() 134 | test_loss = 0 135 | correct = 0 136 | with torch.no_grad(): 137 | for data, target in test_loader: 138 | data, target = data.to(device), target.to(device) 139 | output = model(data) 140 | test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss 141 | pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability 142 | correct += pred.eq(target.view_as(pred)).sum().item() 143 | 144 | test_loss /= len(test_loader.dataset) 145 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 146 | test_loss, correct, len(test_loader.dataset), 147 | 100. * correct / len(test_loader.dataset))) 148 | print('{{"metric": "Eval - NLL Loss", "value": {}, "epoch": {}}}'.format( 149 | test_loss, epoch)) 150 | print('{{"metric": "Eval - Accuracy", "value": {}, "epoch": {}}}'.format( 151 | 100. * correct / len(test_loader.dataset), epoch)) 152 | 153 | 154 | def test_image(): 155 | """Take images from args.evalf, process to be MNIST compliant 156 | and classify them with MNIST ConvNet model""" 157 | def get_images_name(folder): 158 | """Create a generator to list images name at evaluation time""" 159 | onlyfiles = [f for f in os.listdir(folder) if os.path.isfile(os.path.join(folder, f))] 160 | for f in onlyfiles: 161 | yield f 162 | 163 | def pil_loader(path): 164 | """Load images from /eval/ subfolder, convert to greyscale and resized it as squared""" 165 | with open(path, 'rb') as f: 166 | with Image.open(f) as img: 167 | sqrWidth = np.ceil(np.sqrt(img.size[0]*img.size[1])).astype(int) 168 | return img.convert('L').resize((sqrWidth, sqrWidth)) 169 | 170 | eval_loader = torch.utils.data.DataLoader(ImageFolder(root=args.evalf, transform=transforms.Compose([ 171 | transforms.Resize(28), 172 | transforms.CenterCrop(28), 173 | transforms.ToTensor(), 174 | transforms.Normalize((0.1307,), (0.3081,)) 175 | ]), loader=pil_loader), batch_size=1, **kwargs) 176 | 177 | # Name generator 178 | names = get_images_name(os.path.join(args.evalf, "images")) 179 | model.eval() 180 | with torch.no_grad(): 181 | for data, target in eval_loader: 182 | data, target = data.to(device), target.to(device) 183 | output = model(data) 184 | label = output.argmax(dim=1, keepdim=True).item() 185 | print ("Images: " + next(names) + ", Classified as: " + str(label)) 186 | 187 | # Train? 188 | if args.train: 189 | # Train + Test per epoch 190 | for epoch in range(1, args.epochs + 1): 191 | train(args, model, device, train_loader, optimizer, epoch) 192 | test(args, model, device, test_loader, epoch) 193 | 194 | # Do checkpointing - Is saved in outf 195 | torch.save(model.state_dict(), '%s/mnist_convnet_model_epoch_%d.pth' % (args.outf, args.epochs)) 196 | 197 | # Evaluate? 198 | if args.evaluate: 199 | test_image() 200 | --------------------------------------------------------------------------------