├── .gitignore ├── CNN.py ├── DataPreprocessing.py ├── Detector_MTCNN.py ├── DirectoryStructure.ipynb ├── FaceExtraction.ipynb ├── Filter ├── filter15.png └── filter9.png ├── GIFS ├── mask.gif └── mask2.gif ├── Graph ├── ResNet9trainingraph.png ├── lr.png ├── res9vsres15.png └── resnet15.png ├── LICENSE ├── Readme.md ├── testingCNNmodels.ipynb └── training.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CNN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """CNN.ipynb 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1bsDqsK60vP02Uk8Baei6mIIlzIGaKhE9 8 | 9 | #CNN 10 | """ 11 | 12 | import torch.nn as nn 13 | 14 | def conv_block(in_channels, out_channels, pooling=False): 15 | ''' 16 | params: in_channels: (int) number of input channels 17 | params: out_channels: (int) number of output channels 18 | params: pooling: (bool) use pooling or not 19 | return: convolutional layers 20 | ''' 21 | conv_layers = nn.Sequential( 22 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=1), 23 | nn.BatchNorm2d(out_channels), 24 | nn.ReLU() 25 | ) 26 | if pooling: 27 | conv_layers.add_module('max_pooling',nn.MaxPool2d(2)) 28 | return conv_layers 29 | 30 | class ResNet9(nn.Module): 31 | def __init__(self, in_channels, num_classes): 32 | super().__init__() 33 | 34 | #1st Block 35 | self.conv1 = conv_block(in_channels, 64)#input size 1*128*128 36 | self.conv2 = conv_block(64, 128, True) #After pooling 64*64*64 37 | #Residual layer 38 | self.res1 = nn.Sequential(conv_block(128,128), conv_block(128,128)) 39 | 40 | #2nd Block 41 | self.conv3 = conv_block(128, 
256, True) #After pooling 256*32*32 42 | self.conv4 = conv_block(256, 512, True) #After pooling 512*16*16 43 | #Residual layer 44 | self.res2 = nn.Sequential(conv_block(512,512), conv_block(512,512)) 45 | 46 | #Linear Network 47 | self.linear = nn.Sequential( 48 | nn.MaxPool2d(16), #After pooling 512*1*1 49 | nn.Flatten(), # 512 50 | nn.Linear(512, num_classes), 51 | nn.LogSoftmax() 52 | ) 53 | 54 | def forward(self,x): 55 | #Block-1 56 | out = self.conv1(x) 57 | out = self.conv2(out) 58 | res1 = self.res1(out) + out 59 | 60 | #Block-2 61 | out = self.conv3(res1) 62 | out = self.conv4(out) 63 | res2 = self.res2(out) + out 64 | 65 | #Linear network 66 | out = self.linear(res2) 67 | return out 68 | 69 | class ResNet15(nn.Module): 70 | def __init__(self, in_channels, num_classes): 71 | super().__init__() 72 | 73 | #1st Block 74 | self.conv1 = conv_block(in_channels, 64) #inputs size 1*128*128 75 | self.conv2 = conv_block(64, 128, True) #After pooling 64*64*64 76 | #Residual layer 77 | self.res1 = nn.Sequential(conv_block(128,128), conv_block(128,128)) 78 | 79 | #2nd Block 80 | self.conv3 = conv_block(128, 256, True) #After pooling 256*32*32 81 | self.conv4 = conv_block(256, 512, True) #After pooling 512*16*16 82 | #Residual layer 83 | self.res2 = nn.Sequential(conv_block(512,512), conv_block(512,512)) 84 | 85 | #3rd Block 86 | self.conv5 = conv_block(512, 512, True) #After pooling 512*8*8 87 | self.conv6 = conv_block(512, 1024, True) #After pooling 1024*4*4 88 | #Residual layer 89 | self.res3 = nn.Sequential(conv_block(1024,1024), conv_block(1024,1024)) 90 | 91 | 92 | #Linear Network 93 | self.linear = nn.Sequential( 94 | nn.MaxPool2d(4), #After pooling 1024*1*1 95 | nn.Flatten(), # 1024 96 | nn.Linear(1024, 512), 97 | nn.ReLU(), 98 | nn.Linear(512,128), 99 | nn.ReLU(), 100 | nn.Linear(128,num_classes), 101 | nn.LogSoftmax() 102 | ) 103 | 104 | def forward(self,x): 105 | #Block-1 106 | out = self.conv1(x) 107 | out = self.conv2(out) 108 | res1 = self.res1(out) + out 
109 | 110 | #Block-2 111 | out = self.conv3(res1) 112 | out = self.conv4(out) 113 | res2 = self.res2(out) + out 114 | 115 | #Block-3 116 | out = self.conv5(res2) 117 | out = self.conv6(out) 118 | res3 = self.res3(out) + out 119 | 120 | #Linear network 121 | out = self.linear(res3) 122 | return out -------------------------------------------------------------------------------- /DataPreprocessing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """DataPreprocessing.ipynb 3 | 4 | Automatically generated by Colaboratory. 5 | 6 | Original file is located at 7 | https://colab.research.google.com/drive/1iZLQ3IqfMJ-DBVLYl4J9KPZInf_xoGxW 8 | 9 | #Data Preprocessing 10 | """ 11 | 12 | #Importing required libraries 13 | 14 | from torchvision.datasets import ImageFolder 15 | from torchvision import transforms 16 | import numpy as np 17 | import torch 18 | 19 | class Preprocessing(): 20 | # This class will transforms every images into the applied transforms. 
21 | # And it returns the dataset as tensor dataset and the categories 22 | 23 | def __init__(self, path=None, img=None): 24 | ''' 25 | params: path: (str) directory to the image folder 26 | params: array: (array) image array 27 | ''' 28 | self.directory = path 29 | self.img = img 30 | 31 | def __image_transformation(self): 32 | ''' 33 | params: None 34 | return: transformations 35 | ''' 36 | transform = transforms.Compose([ 37 | transforms.Resize((130,130)), 38 | transforms.CenterCrop(128), 39 | transforms.Grayscale(1), 40 | transforms.ToTensor(), 41 | transforms.Normalize(0.5,0.5) 42 | ]) 43 | return transform 44 | 45 | 46 | def preprocessed_arrays(self): 47 | #For predicting 48 | ''' 49 | params: None 50 | return: (array) single Tensor data 51 | ''' 52 | img = self.img 53 | transforms = self.__image_transformation() 54 | return torch.tensor(np.expand_dims(transforms(img),0)) 55 | 56 | def preprocessed_dataset(self): 57 | ''' 58 | params: None 59 | return: Tensor dataset 60 | ''' 61 | #Using torch's ImageFolder to get data from directory and applying 62 | #transforms 63 | transformations = transforms.Compose([ 64 | transforms.RandomHorizontalFlip(), 65 | transforms.RandomPerspective(0.2,p=0.5), 66 | self.__image_transformation()]) 67 | 68 | dataset_train = ImageFolder(self.directory['train'], 69 | transform= transformations) 70 | dataset_test = ImageFolder(self.directory['test'], 71 | transform= self.__image_transformation()) 72 | 73 | 74 | return dataset_train, dataset_test 75 | 76 | -------------------------------------------------------------------------------- /Detector_MTCNN.py: -------------------------------------------------------------------------------- 1 | #Importing libraries 2 | from facenet_pytorch import MTCNN 3 | import cv2 as cv 4 | from PIL import Image 5 | import torch 6 | from math import ceil as r 7 | 8 | #Importing modules 9 | from CNN import ResNet15, ResNet9 10 | from DataPreprocessing import Preprocessing 11 | 12 | device = 
torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | #model 15 | model = ResNet15(1,2).to(device) 16 | #model = ResNet15(1,2) 17 | 18 | state_dict = torch.load('state_dict_resnet15.pth', map_location=device) 19 | #loading the state_dict to the model 20 | model.load_state_dict(state_dict) 21 | 22 | #the classifier 23 | mtcnn = MTCNN(select_largest=False, device=device) 24 | 25 | # Load a single image and display 26 | cap = cv.VideoCapture(0) 27 | 28 | labels = { 29 | 0:'with mask', 30 | 1:'without mask' 31 | } 32 | color_dict={ 33 | 0:(255,0,255), 34 | 1:(255,0,0) 35 | } 36 | 37 | def resize(frame, height, width): 38 | if height and width >= 1000: 39 | return cv.resize(frame, (r(height*0.5), r(width*0.5))) 40 | elif height and width >= 2500: 41 | return cv.resize(frame, (r(height*0.7), r(width*0.7))) 42 | else: 43 | return cv.resize(frame, (height, width)) 44 | 45 | while True: 46 | success, frame = cap.read() 47 | 48 | if success: 49 | width, height, _ = frame.shape 50 | # if the video is too big uncomment the below code 51 | #frame = resize(frame, height, width) 52 | 53 | #padding the image to avoid the bounding box going out of the image 54 | #and crashes the program 55 | padding = cv.copyMakeBorder(frame, 50,50,50,50, cv.BORDER_CONSTANT) 56 | #converting numpy array into image 57 | image = Image.fromarray(padding) 58 | 59 | #gives the face co-ordinates 60 | face_coord,_ = mtcnn.detect(image) 61 | 62 | 63 | if face_coord is not None: 64 | for coord in face_coord: 65 | for x1,y1,x2,y2 in [coord]: 66 | x1,y1,x2,y2 = r(x1),r(y1),r(x2),r(y2) 67 | 68 | #face array 69 | face = padding[y1:y2 ,x1:x2] 70 | 71 | #Preprocessing 72 | preprocess = Preprocessing(img=Image.fromarray(face)) 73 | #tensor array 74 | tensor_img_array = preprocess.preprocessed_arrays() 75 | 76 | #Predicting 77 | prob, label = torch.max(torch.exp(model( 78 | tensor_img_array.to(device))),dim=1) 79 | 80 | scale = round((y2-y1)*35/100) 81 | #mini box 82 | cv.rectangle(frame, 
(x1-50,y1-50), (x1-40,y1-40), 83 | color_dict[label.item()],-1) 84 | 85 | #Bounding box 86 | cv.rectangle(frame, (x1-50,y1-50), (x2-50,y2-50), 87 | color_dict[label.item()],1) 88 | 89 | cv.putText(frame,labels[label.item()], 90 | (x1-50,y1-53),cv.FONT_HERSHEY_SIMPLEX, 91 | scale*0.01,(255,255,0),1) 92 | 93 | cv.imshow("Frame", frame) 94 | # im = Image.fromarray(frame) 95 | # im.save('mask/a%s.png'%(a)) 96 | if cv.waitKey(1) & 0xFF == ord('q'): 97 | break 98 | 99 | 100 | else: 101 | print('End') 102 | break 103 | 104 | cap.release() 105 | cv.destroyAllWindows() 106 | 107 | -------------------------------------------------------------------------------- /DirectoryStructure.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "DirectoryStructure.ipynb", 7 | "provenance": [] 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | } 13 | }, 14 | "cells": [ 15 | { 16 | "cell_type": "markdown", 17 | "metadata": { 18 | "id": "dkSk2uFilchj", 19 | "colab_type": "text" 20 | }, 21 | "source": [ 22 | "directory of the dataset before.\n", 23 | "```\n", 24 | ".\n", 25 | "│─── dataset\n", 26 | " └── with mask\n", 27 | " └── without mask\n", 28 | "\n", 29 | "```\n", 30 | "directory of the dataset after(goal).\n", 31 | "```\n", 32 | ".\n", 33 | "│─── dataset\n", 34 | " |──test\n", 35 | " | ├── with mask\n", 36 | " | └── without mask\n", 37 | " |──test\n", 38 | " ├── with mask\n", 39 | " └── without mask\n", 40 | "\n", 41 | "```\n", 42 | "\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "metadata": { 48 | "id": "5fL95hhf53YX", 49 | "colab_type": "code", 50 | "colab": {} 51 | }, 52 | "source": [ 53 | "#Importing libraries\n", 54 | "import os, math\n", 55 | "from google.colab import drive" 56 | ], 57 | "execution_count": null, 58 | "outputs": [] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "metadata": { 63 | "id": 
"bvro868C59Vq", 64 | "colab_type": "code", 65 | "colab": { 66 | "base_uri": "https://localhost:8080/", 67 | "height": 34 68 | }, 69 | "outputId": "992f1842-3c70-45ff-b95e-dda110f4363c" 70 | }, 71 | "source": [ 72 | "drive.mount('/content/gdrive')\n", 73 | "dir_ = '/content/gdrive/My Drive/Colab Notebooks/Face mask/dataset'" 74 | ], 75 | "execution_count": null, 76 | "outputs": [ 77 | { 78 | "output_type": "stream", 79 | "text": [ 80 | "Mounted at /content/gdrive\n" 81 | ], 82 | "name": "stdout" 83 | } 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "metadata": { 89 | "id": "_R0y19SDaszr", 90 | "colab_type": "code", 91 | "colab": {} 92 | }, 93 | "source": [ 94 | "#changing the directory to the dataset\n", 95 | "os.chdir(dir_)" 96 | ], 97 | "execution_count": null, 98 | "outputs": [] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "metadata": { 103 | "id": "33WWDgPs6kOY", 104 | "colab_type": "code", 105 | "colab": {} 106 | }, 107 | "source": [ 108 | "def folder_maker(folder_names):\n", 109 | " '''\n", 110 | " This functions creates a new folder.\n", 111 | "\n", 112 | " params: folders: folder name\n", 113 | " return: None\n", 114 | " '''\n", 115 | " for x in folder_names:\n", 116 | " #checking if the folder exist or not\n", 117 | " if not os.path.isdir(x):\n", 118 | " os.mkdir(x)\n", 119 | " else:\n", 120 | " pass\n", 121 | "\n", 122 | "folder_names = ['test','train','train/with mask','train/without mask',\n", 123 | " 'test/with mask', 'test/without mask']\n", 124 | "\n", 125 | "\n", 126 | "folder_maker(folder_names)" 127 | ], 128 | "execution_count": null, 129 | "outputs": [] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "metadata": { 134 | "id": "aS26T41jKAe4", 135 | "colab_type": "code", 136 | "colab": {} 137 | }, 138 | "source": [ 139 | "#getting the images from the folders\n", 140 | "with_mask_dir = os.listdir('with mask')\n", 141 | "without_mask_dir = os.listdir('without mask')" 142 | ], 143 | "execution_count": null, 144 | "outputs": [] 145 | }, 146 
| { 147 | "cell_type": "code", 148 | "metadata": { 149 | "id": "WTwkEPv0IVS-", 150 | "colab_type": "code", 151 | "colab": {} 152 | }, 153 | "source": [ 154 | "#12% of the data will be for the test dataset\n", 155 | "test_size_with = math.ceil(len(with_mask_dir)*12/100)\n", 156 | "test_size_without = math.ceil(len(without_mask_dir)*12/100)" 157 | ], 158 | "execution_count": null, 159 | "outputs": [] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "metadata": { 164 | "id": "0ssxVk2yIzsJ", 165 | "colab_type": "code", 166 | "colab": {} 167 | }, 168 | "source": [ 169 | "def move_withRange(items,range,from_, to):\n", 170 | " '''\n", 171 | " This function will move the data from one\n", 172 | " directory to other.\n", 173 | "\n", 174 | " params: items: the images\n", 175 | " params: range: number of data to move\n", 176 | " params: from_: directory from where the data to be moved\n", 177 | " params: to: desired directory to move in the data\n", 178 | " '''\n", 179 | " for img_name in items[:range]:\n", 180 | " os.replace('%s/%s'%(from_, img_name), '%s/%s'%(to,img_name))\n", 181 | "\n", 182 | "move_withRange(without_mask_dir, test_size_without, 'without mask', 'test/without mask')\n", 183 | "move_withRange(with_mask_dir, test_size_with, 'with mask', 'test/with mask')" 184 | ], 185 | "execution_count": null, 186 | "outputs": [] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "metadata": { 191 | "id": "MaVFMkPogwvZ", 192 | "colab_type": "code", 193 | "colab": {} 194 | }, 195 | "source": [ 196 | "def move(items,from_, to):\n", 197 | " '''\n", 198 | " This is similar to the other fn but this\n", 199 | " function will move the entier data.\n", 200 | " '''\n", 201 | " for img_name in items:\n", 202 | " os.replace('%s/%s'%(from_, img_name), '%s/%s'%(to,img_name))\n", 203 | "\n", 204 | "move(with_mask_dir, 'with mask', 'train/with mask')\n", 205 | "move(without_mask_dir, 'without mask', 'train/without mask')" 206 | ], 207 | "execution_count": null, 208 | "outputs": [] 209 | }, 
210 | { 211 | "cell_type": "code", 212 | "metadata": { 213 | "id": "GGFR-kR8j9YG", 214 | "colab_type": "code", 215 | "colab": {} 216 | }, 217 | "source": [ 218 | "" 219 | ], 220 | "execution_count": null, 221 | "outputs": [] 222 | } 223 | ] 224 | } -------------------------------------------------------------------------------- /FaceExtraction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "FaceExtraction.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | } 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "g7HmeiKGDjsJ", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "The dataset contains irrelevant things like the background behind the face and \n", 24 | "upper part of the body. Training with those irrelevant stuffs makes the network\n", 25 | "poor during testing." 
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "_1JLukWpDpHM", 32 | "colab_type": "code", 33 | "colab": { 34 | "base_uri": "https://localhost:8080/", 35 | "height": 225 36 | }, 37 | "outputId": "25b84d8e-cad1-4dcd-8168-453a7ba1966b" 38 | }, 39 | "source": [ 40 | "#Importing libraries\n", 41 | "import os\n", 42 | "from PIL import Image\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "import torchvision.transforms as t\n", 45 | "from math import ceil as r\n", 46 | "!pip install facenet_pytorch\n", 47 | "from facenet_pytorch import MTCNN" 48 | ], 49 | "execution_count": null, 50 | "outputs": [ 51 | { 52 | "output_type": "stream", 53 | "text": [ 54 | "Collecting facenet_pytorch\n", 55 | "\u001b[?25l Downloading https://files.pythonhosted.org/packages/58/26/9dbb553500bff164cdcd491785cfe55dcbb34b431d44f655640476db8d82/facenet_pytorch-2.2.9-py3-none-any.whl (1.9MB)\n", 56 | "\r\u001b[K |▏ | 10kB 18.2MB/s eta 0:00:01\r\u001b[K |▍ | 20kB 6.2MB/s eta 0:00:01\r\u001b[K |▌ | 30kB 6.0MB/s eta 0:00:01\r\u001b[K |▊ | 40kB 6.8MB/s eta 0:00:01\r\u001b[K |▉ | 51kB 6.3MB/s eta 0:00:01\r\u001b[K |█ | 61kB 7.0MB/s eta 0:00:01\r\u001b[K |█▏ | 71kB 7.3MB/s eta 0:00:01\r\u001b[K |█▍ | 81kB 7.9MB/s eta 0:00:01\r\u001b[K |█▋ | 92kB 7.4MB/s eta 0:00:01\r\u001b[K |█▊ | 102kB 7.8MB/s eta 0:00:01\r\u001b[K |██ | 112kB 7.8MB/s eta 0:00:01\r\u001b[K |██ | 122kB 7.8MB/s eta 0:00:01\r\u001b[K |██▎ | 133kB 7.8MB/s eta 0:00:01\r\u001b[K |██▍ | 143kB 7.8MB/s eta 0:00:01\r\u001b[K |██▋ | 153kB 7.8MB/s eta 0:00:01\r\u001b[K |██▉ | 163kB 7.8MB/s eta 0:00:01\r\u001b[K |███ | 174kB 7.8MB/s eta 0:00:01\r\u001b[K |███▏ | 184kB 7.8MB/s eta 0:00:01\r\u001b[K |███▎ | 194kB 7.8MB/s eta 0:00:01\r\u001b[K |███▌ | 204kB 7.8MB/s eta 0:00:01\r\u001b[K |███▋ | 215kB 7.8MB/s eta 0:00:01\r\u001b[K |███▉ | 225kB 7.8MB/s eta 0:00:01\r\u001b[K |████ | 235kB 7.8MB/s eta 0:00:01\r\u001b[K |████▏ | 245kB 7.8MB/s eta 0:00:01\r\u001b[K |████▍ | 256kB 7.8MB/s eta 0:00:01\r\u001b[K |████▌ | 266kB 
7.8MB/s eta 0:00:01\r\u001b[K |████▊ | 276kB 7.8MB/s eta 0:00:01\r\u001b[K |████▉ | 286kB 7.8MB/s eta 0:00:01\r\u001b[K |█████ | 296kB 7.8MB/s eta 0:00:01\r\u001b[K |█████▎ | 307kB 7.8MB/s eta 0:00:01\r\u001b[K |█████▍ | 317kB 7.8MB/s eta 0:00:01\r\u001b[K |█████▋ | 327kB 7.8MB/s eta 0:00:01\r\u001b[K |█████▊ | 337kB 7.8MB/s eta 0:00:01\r\u001b[K |██████ | 348kB 7.8MB/s eta 0:00:01\r\u001b[K |██████ | 358kB 7.8MB/s eta 0:00:01\r\u001b[K |██████▎ | 368kB 7.8MB/s eta 0:00:01\r\u001b[K |██████▌ | 378kB 7.8MB/s eta 0:00:01\r\u001b[K |██████▋ | 389kB 7.8MB/s eta 0:00:01\r\u001b[K |██████▉ | 399kB 7.8MB/s eta 0:00:01\r\u001b[K |███████ | 409kB 7.8MB/s eta 0:00:01\r\u001b[K |███████▏ | 419kB 7.8MB/s eta 0:00:01\r\u001b[K |███████▎ | 430kB 7.8MB/s eta 0:00:01\r\u001b[K |███████▌ | 440kB 7.8MB/s eta 0:00:01\r\u001b[K |███████▊ | 450kB 7.8MB/s eta 0:00:01\r\u001b[K |███████▉ | 460kB 7.8MB/s eta 0:00:01\r\u001b[K |████████ | 471kB 7.8MB/s eta 0:00:01\r\u001b[K |████████▏ | 481kB 7.8MB/s eta 0:00:01\r\u001b[K |████████▍ | 491kB 7.8MB/s eta 0:00:01\r\u001b[K |████████▌ | 501kB 7.8MB/s eta 0:00:01\r\u001b[K |████████▊ | 512kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████ | 522kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████ | 532kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████▎ | 542kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████▍ | 552kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████▋ | 563kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████▊ | 573kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████ | 583kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████▏ | 593kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████▎ | 604kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████▌ | 614kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████▋ | 624kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████▉ | 634kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████ | 645kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████▏ | 655kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████▍ | 665kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████▌ | 675kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████▊ | 686kB 7.8MB/s 
eta 0:00:01\r\u001b[K |███████████▉ | 696kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████ | 706kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████▏ | 716kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████▍ | 727kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████▋ | 737kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████▊ | 747kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████ | 757kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████ | 768kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████▎ | 778kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████▍ | 788kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████▋ | 798kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████▉ | 808kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████ | 819kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████▏ | 829kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████▎ | 839kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████▌ | 849kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████▋ | 860kB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████▉ | 870kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████ | 880kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████▏ | 890kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████▍ | 901kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████▌ | 911kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████▊ | 921kB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████▉ | 931kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████ | 942kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████▏ | 952kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████▍ | 962kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████▋ | 972kB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 983kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████ | 993kB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████ | 1.0MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████▎ | 1.0MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████▍ | 1.0MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████▋ | 1.0MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████▉ | 1.0MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████ | 1.1MB 7.8MB/s eta 
0:00:01\r\u001b[K |██████████████████▏ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████▎ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████▌ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████▋ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████▉ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████▏ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████▍ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████▌ | 1.1MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████▊ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████▉ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████▎ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████▍ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████▋ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████▊ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████▎ | 1.2MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████▌ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████▋ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████▉ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████▏ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████▎ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████▌ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████▊ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████▉ | 1.3MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████▏ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████▍ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████▌ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K 
|███████████████████████▊ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████▎ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████▍ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████▋ | 1.4MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████▊ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████▏ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████▎ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████▌ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████▋ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████▉ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████▏ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████▍ | 1.5MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████▌ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████▊ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████▉ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████▏ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████▍ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████▋ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████▊ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████ | 1.6MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████▎ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████▍ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████▋ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████▉ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K 
|█████████████████████████████ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▏ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▎ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▌ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▋ | 1.7MB 7.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▉ | 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████████ | 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▏ | 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▍ | 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▌ | 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▊ | 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |██████████████████████████████▉ | 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████████ | 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▎| 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▍| 1.8MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▋| 1.9MB 7.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▊| 1.9MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 1.9MB 7.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 1.9MB 7.8MB/s \n", 57 | "\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from facenet_pytorch) (2.23.0)\n", 58 | "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from facenet_pytorch) (1.18.5)\n", 59 | "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->facenet_pytorch) (2.9)\n", 60 | "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->facenet_pytorch) (1.24.3)\n", 61 | "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from 
requests->facenet_pytorch) (2020.6.20)\n",
62 | "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->facenet_pytorch) (3.0.4)\n",
63 | "Installing collected packages: facenet-pytorch\n",
64 | "Successfully installed facenet-pytorch-2.2.9\n"
65 | ],
66 | "name": "stdout"
67 | }
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "metadata": {
73 | "id": "ZG5nqIShDaup",
74 | "colab_type": "code",
75 | "colab": {
76 | "base_uri": "https://localhost:8080/",
77 | "height": 34
78 | },
79 | "outputId": "ec2e3d9b-2c0d-4df5-f670-e2c186a4ff27"
80 | },
81 | "source": [
82 | "#mounting\n",
83 | "from google.colab import drive\n",
84 | "drive.mount('/content/drive')"
85 | ],
86 | "execution_count": null,
87 | "outputs": [
88 | {
89 | "output_type": "stream",
90 | "text": [
91 | "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
92 | ],
93 | "name": "stdout"
94 | }
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "metadata": {
100 | "id": "MvzE_G1MSUzM",
101 | "colab_type": "code",
102 | "colab": {}
103 | },
104 | "source": [
105 | "dir_ = '/content/drive/My Drive/Colab Notebooks/Face mask/dataset/new'\n",
106 | "os.chdir(dir_)"
107 | ],
108 | "execution_count": null,
109 | "outputs": []
110 | },
111 | {
112 | "cell_type": "code",
113 | "metadata": {
114 | "id": "NLq3j1sZSat8",
115 | "colab_type": "code",
116 | "colab": {}
117 | },
118 | "source": [
119 | "#images with mask\n",
120 | "imgs_with = os.listdir('with mask')\n",
121 | "#images without mask\n",
122 | "imgs_without = os.listdir('without mask')"
123 | ],
124 | "execution_count": null,
125 | "outputs": []
126 | },
127 | {
128 | "cell_type": "code",
129 | "metadata": {
130 | "id": "ckNOTevdDCEk",
131 | "colab_type": "code",
132 | "colab": {}
133 | },
134 | "source": [
135 | "for sub in ('train/with mask', 'train/without mask'):\n",
136 | "    os.makedirs(sub, exist_ok=True) #face_extractor saves into these subfolders; exist_ok keeps re-runs safe"
137 | ],
138 | "execution_count": null,
139 | "outputs": []
140 | },
141 | {
142 | "cell_type": "code",
143 | "metadata": {
144 | "id": "cTK8ywqNTCIp",
145 | "colab_type": "code",
146 | "colab": {}
147 | },
148 | "source": [
149 | "mtcnn = MTCNN()\n",
150 | "def face_extractor(items, from_, to):\n",
151 | "    '''\n",
152 | "    Detects a face in every image and saves the crop into to under the same file name\n",
153 | "\n",
154 | "    params: items: (list of str) image file names inside from_\n",
155 | "    params: from_: directory the images are read from\n",
156 | "    params: to: directory the face crops are written to (must already exist)\n",
157 | "    '''\n",
158 | "    to_tensor = t.ToTensor()\n",
159 | "    for name in items:\n",
160 | "        try:\n",
161 | "            #convert('RGB'): MTCNN needs 3 channels, so grayscale/RGBA files are normalised first\n",
162 | "            pic = Image.open('%s/%s'%(from_, name)).convert('RGB')\n",
163 | "            boxes, _ = mtcnn.detect(pic)\n",
164 | "            if boxes is None:\n",
165 | "                continue\n",
166 | "            #NOTE(review): the first box is used; with MTCNN's default select_largest=True it should be the largest face - confirm\n",
167 | "            x1, y1, x2, y2 = boxes[0]\n",
168 | "            #clamp the top-left corner: negative indices would wrap around the tensor\n",
169 | "            x1, y1, x2, y2 = max(r(x1), 0), max(r(y1), 0), r(x2), r(y2)\n",
170 | "\n",
171 | "            pil = t.ToPILImage()(to_tensor(pic)[:, y1:y2, x1:x2])\n",
172 | "            #keep only crops that are at least 130 px on both sides\n",
173 | "            #(a tuple >= compares lexicographically, so compare the smaller side)\n",
174 | "            if min(pil.size) >= 130:\n",
175 | "                pil.save('%s/%s'%(to, name))\n",
176 | "        except Exception as err:\n",
177 | "            #best effort: skip the image, but report why instead of silently passing\n",
178 | "            print('skipped %s: %s' % (name, err))\n",
179 | "\n",
180 | "\n",
181 | "\n",
182 | "face_extractor(imgs_with, 'with mask', 'train/with mask')\n",
183 | "face_extractor(imgs_without, 'without mask', 'train/without mask')"
184 | ],
185 | "execution_count": null,
186 | "outputs": []
187 | }
188 | ]
189 | }
190 |
--------------------------------------------------------------------------------
/Filter/filter15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ashborn-SM/Face-Mask-Detection-Pytorch/4eebcabc8ef5889a778097eebeda4dc7fded0c7e/Filter/filter15.png
--------------------------------------------------------------------------------
/Filter/filter9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ashborn-SM/Face-Mask-Detection-Pytorch/4eebcabc8ef5889a778097eebeda4dc7fded0c7e/Filter/filter9.png -------------------------------------------------------------------------------- /GIFS/mask.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ashborn-SM/Face-Mask-Detection-Pytorch/4eebcabc8ef5889a778097eebeda4dc7fded0c7e/GIFS/mask.gif -------------------------------------------------------------------------------- /GIFS/mask2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ashborn-SM/Face-Mask-Detection-Pytorch/4eebcabc8ef5889a778097eebeda4dc7fded0c7e/GIFS/mask2.gif -------------------------------------------------------------------------------- /Graph/ResNet9trainingraph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ashborn-SM/Face-Mask-Detection-Pytorch/4eebcabc8ef5889a778097eebeda4dc7fded0c7e/Graph/ResNet9trainingraph.png -------------------------------------------------------------------------------- /Graph/lr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ashborn-SM/Face-Mask-Detection-Pytorch/4eebcabc8ef5889a778097eebeda4dc7fded0c7e/Graph/lr.png -------------------------------------------------------------------------------- /Graph/res9vsres15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ashborn-SM/Face-Mask-Detection-Pytorch/4eebcabc8ef5889a778097eebeda4dc7fded0c7e/Graph/res9vsres15.png -------------------------------------------------------------------------------- /Graph/resnet15.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ashborn-SM/Face-Mask-Detection-Pytorch/4eebcabc8ef5889a778097eebeda4dc7fded0c7e/Graph/resnet15.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Rahul0128 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | ## About 2 | This is a course project i have created using pytorch ,using all the skills i learned 3 | from [freecodecamp](https://www.freecodecamp.org/) and [Jovian](https://www.jovian.ml/). 4 | 5 | ## How to use: 6 | Run the Detector_MTCNN.py file. 
At present, video is taken from the webcam (live). If you want 7 | to feed in a pre-recorded video, give the path of the file instead of 0 in line 28 *cv.VideoCapture(0)*. 8 | If the video is too large and could freeze the computer, uncomment line 57 *#frame = resize(frame, height, width)* 9 | to resize each frame. 10 | 11 | **Make sure to download the state dict to get correct predictions** 12 | 13 | **state dict - https://drive.google.com/drive/folders/1oRBDw_HmqCaQ2jnT4aSZHyYVBi4ELhSt?usp=sharing, 14 | dataset - https://drive.google.com/drive/folders/1LEKdePxk854r0kT542g42loM1z1UkL4g?usp=sharing** 15 | 16 | I tried both models on different videos; ResNet9 and ResNet15 both performed well. 17 | On some videos ResNet9 performed well but ResNet15 did not, and vice versa. 18 | 19 | Try both models and see which works best. 20 | 21 | ### Note: 22 | The model is trained on certain types of masks, so it may not perform well on other kinds of masks. 23 | 24 | ## Third-Party Libraries used: 25 | 1. Facenet PyTorch 26 | 2. OpenCV 27 | 3. PyTorch 28 | 4. NumPy 29 | 5. Matplotlib 30 | 31 | ## Guide 32 | 1. Guide to MTCNN in facenet-pytorch - https://www.kaggle.com/timesler/guide-to-mtcnn-in-facenet-pytorch 33 | 2.
Facenet implementation in a video - https://github.com/timesler/facenet-pytorch/blob/master/examples/face_tracking.ipynb 34 | 35 | ### Predicted video 36 | ![Mask](https://github.com/Rahul0128/Face-Mask-Detection-Pytorch/blob/master/GIFS/mask.gif), 37 | ![](https://github.com/Rahul0128/Face-Mask-Detection-Pytorch/blob/master/GIFS/mask2.gif) 38 | -------------------------------------------------------------------------------- /testingCNNmodels.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "testingCNNmodels.ipynb", 7 | "provenance": [] 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "accelerator": "GPU" 14 | }, 15 | "cells": [ 16 | { 17 | "cell_type": "markdown", 18 | "metadata": { 19 | "id": "aE_n-CoyFgj_", 20 | "colab_type": "text" 21 | }, 22 | "source": [ 23 | "Testing both the custom built ResNet9 and ResNet15" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "metadata": { 29 | "id": "ekZHJ2hz02An", 30 | "colab_type": "code", 31 | "colab": { 32 | "base_uri": "https://localhost:8080/", 33 | "height": 122 34 | }, 35 | "outputId": "a44728cb-57cd-4191-a4f3-4d6c3f359d43" 36 | }, 37 | "source": [ 38 | "#Importing libraries\n", 39 | "import matplotlib.pyplot as plt\n", 40 | "from google.colab import drive\n", 41 | "import torch\n", 42 | "from torch.utils.data import DataLoader\n", 43 | "drive.mount('/content/gdrive')" 44 | ], 45 | "execution_count": 1, 46 | "outputs": [ 47 | { 48 | "output_type": "stream", 49 | "text": [ 50 | "Go to this URL in a browser: 
https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly\n", 51 | "\n", 52 | "Enter your authorization code:\n", 53 | "··········\n", 54 | "Mounted at /content/gdrive\n" 55 | ], 56 | "name": "stdout" 57 | } 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "metadata": { 63 | "id": "rT7LPviX1Vjf", 64 | "colab_type": "code", 65 | "colab": { 66 | "base_uri": "https://localhost:8080/", 67 | "height": 170 68 | }, 69 | "outputId": "13cc47d0-d20b-4d4b-b93f-ed812840ce1b" 70 | }, 71 | "source": [ 72 | "!pip install import_ipynb\n", 73 | "import import_ipynb" 74 | ], 75 | "execution_count": 2, 76 | "outputs": [ 77 | { 78 | "output_type": "stream", 79 | "text": [ 80 | "Collecting import_ipynb\n", 81 | " Downloading https://files.pythonhosted.org/packages/63/35/495e0021bfdcc924c7cdec4e9fbb87c88dd03b9b9b22419444dc370c8a45/import-ipynb-0.1.3.tar.gz\n", 82 | "Building wheels for collected packages: import-ipynb\n", 83 | " Building wheel for import-ipynb (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", 84 | " Created wheel for import-ipynb: filename=import_ipynb-0.1.3-cp36-none-any.whl size=2976 sha256=fd14f250ef23f4ca18b073df946cb5f6d0772d59424f2d0fb5f74412673e241b\n", 85 | " Stored in directory: /root/.cache/pip/wheels/b4/7b/e9/a3a6e496115dffdb4e3085d0ae39ffe8a814eacc44bbf494b5\n", 86 | "Successfully built import-ipynb\n", 87 | "Installing collected packages: import-ipynb\n", 88 | "Successfully installed import-ipynb-0.1.3\n" 89 | ], 90 | "name": "stdout" 91 | } 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "metadata": { 97 | "id": "eN4tayU_0_Vm", 98 | "colab_type": "code", 99 | "colab": { 100 | "base_uri": "https://localhost:8080/", 101 | "height": 34 102 | }, 103 | "outputId": "b3ea6cd8-02d8-4deb-b576-a8e1836cf030" 104 | }, 105 | "source": [ 106 | "%cd '/content/gdrive/My Drive/Colab Notebooks/Face mask'" 107 | ], 108 | "execution_count": 3, 109 | "outputs": [ 110 | { 111 | "output_type": "stream", 112 | "text": [ 113 | "/content/gdrive/My Drive/Colab Notebooks/Face mask\n" 114 | ], 115 | "name": "stdout" 116 | } 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "metadata": { 122 | "id": "EG9fFAf51MJu", 123 | "colab_type": "code", 124 | "colab": { 125 | "base_uri": "https://localhost:8080/", 126 | "height": 51 127 | }, 128 | "outputId": "ad05f218-3e06-49bf-fcac-4862dd7c7039" 129 | }, 130 | "source": [ 131 | "#importing modules\n", 132 | "from CNN import ResNet9, ResNet15\n", 133 | "from DataPreprocessing import Preprocessing" 134 | ], 135 | "execution_count": 4, 136 | "outputs": [ 137 | { 138 | "output_type": "stream", 139 | "text": [ 140 | "importing Jupyter notebook from CNN.ipynb\n", 141 | "importing Jupyter notebook from DataPreprocessing.ipynb\n" 142 | ], 143 | "name": "stdout" 144 | } 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "metadata": { 150 | "id": "Mpo7k2Zv1jLc", 151 | "colab_type": "code", 152 | "colab": {} 153 | }, 154 | "source": [ 155 | "device = torch.device('cuda')" 156 | ], 157 | 
"execution_count": 5, 158 | "outputs": [] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "metadata": { 163 | "id": "PlSWlDR20yBB", 164 | "colab_type": "code", 165 | "colab": { 166 | "base_uri": "https://localhost:8080/", 167 | "height": 34 168 | }, 169 | "outputId": "7cc4b190-ddbb-4dd3-bc0e-ecffb1a328f7" 170 | }, 171 | "source": [ 172 | "#model and loading the state_dict\n", 173 | "model15 = ResNet15(1,2).to(device)\n", 174 | "model9 = ResNet9(1,2).to(device)\n", 175 | "model15.load_state_dict(torch.load('state_dict_resnet15.pth'))\n", 176 | "model9.load_state_dict(torch.load('state_dict_resnet9.pth'))" 177 | ], 178 | "execution_count": 6, 179 | "outputs": [ 180 | { 181 | "output_type": "execute_result", 182 | "data": { 183 | "text/plain": [ 184 | "" 185 | ] 186 | }, 187 | "metadata": { 188 | "tags": [] 189 | }, 190 | "execution_count": 6 191 | } 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "metadata": { 197 | "id": "g7UoC0s12azM", 198 | "colab_type": "code", 199 | "colab": {} 200 | }, 201 | "source": [ 202 | "data = Preprocessing(path={'test':'dataset/test',\n", 203 | " 'train':'dataset/train'})\n", 204 | "_, test_ds = data.preprocessed_dataset()\n", 205 | "test_dl = DataLoader(test_ds, 10, num_workers=3,\n", 206 | " pin_memory=True)" 207 | ], 208 | "execution_count": 8, 209 | "outputs": [] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "metadata": { 214 | "id": "fGmtIgV62MmX", 215 | "colab_type": "code", 216 | "colab": { 217 | "base_uri": "https://localhost:8080/", 218 | "height": 241 219 | }, 220 | "outputId": "e714dac9-9f07-4843-a96f-46744fe6b907" 221 | }, 222 | "source": [ 223 | "#testing\n", 224 | "accuracy9 = []\n", 225 | "accuracy15 = []\n", 226 | "for x, y in test_dl:\n", 227 | " x, y = x.to(device), y.to(device)\n", 228 | " _, label15 = torch.max(torch.exp(model15(x)), dim=1)\n", 229 | " _, label9 = torch.max(torch.exp(model9(x)), dim=1)\n", 230 | " acc9 = torch.sum(label9 == y).item()/len(y)\n", 231 | " acc15 = torch.sum(label15 == 
y).item()/len(y)\n", 232 | " accuracy9.append(acc9)\n", 233 | " accuracy15.append(acc15)\n", 234 | " print('ResNet9: {}, ResNet15: {}'.format(acc9, acc15))\n", 235 | " " 236 | ], 237 | "execution_count": 9, 238 | "outputs": [ 239 | { 240 | "output_type": "stream", 241 | "text": [ 242 | "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py:100: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", 243 | " input = module(input)\n" 244 | ], 245 | "name": "stderr" 246 | }, 247 | { 248 | "output_type": "stream", 249 | "text": [ 250 | "ResNet9: 0.8, ResNet15: 1.0\n", 251 | "ResNet9: 0.8, ResNet15: 1.0\n", 252 | "ResNet9: 1.0, ResNet15: 1.0\n", 253 | "ResNet9: 0.9, ResNet15: 1.0\n", 254 | "ResNet9: 0.9, ResNet15: 0.9\n", 255 | "ResNet9: 1.0, ResNet15: 0.7\n", 256 | "ResNet9: 0.8, ResNet15: 0.9\n", 257 | "ResNet9: 1.0, ResNet15: 0.9\n", 258 | "ResNet9: 0.9, ResNet15: 0.8\n", 259 | "ResNet9: 0.8888888888888888, ResNet15: 0.7777777777777778\n" 260 | ], 261 | "name": "stdout" 262 | } 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "metadata": { 268 | "id": "eMnJcrGZ3sPC", 269 | "colab_type": "code", 270 | "colab": { 271 | "base_uri": "https://localhost:8080/", 272 | "height": 312 273 | }, 274 | "outputId": "3e2774d7-291e-4e95-9feb-0b4cddcdaa38" 275 | }, 276 | "source": [ 277 | "plt.plot(accuracy9, 'r--x', markersize=8, markeredgewidth=3\n", 278 | " , linewidth=3,label='ResNet9')\n", 279 | "plt.plot(accuracy15, 'c-1', markersize=12, markeredgewidth=3\n", 280 | " , linewidth=2,label='ResNet15')\n", 281 | "plt.grid()\n", 282 | "plt.title('Accuracy: ResNet9 vs ResNet15')\n", 283 | "plt.xlabel('batches')\n", 284 | "plt.ylabel('Accuracy')\n", 285 | "plt.legend(loc='lower right')\n" 286 | ], 287 | "execution_count": 10, 288 | "outputs": [ 289 | { 290 | "output_type": "execute_result", 291 | "data": { 292 | "text/plain": [ 293 | "" 294 | ] 295 | }, 296 | "metadata": { 297 | "tags": [] 
298 | }, 299 | "execution_count": 10 300 | }, 301 | { 302 | "output_type": "display_data", 303 | "data": { 304 | "image/png": "\n", 305 | "text/plain": [ 306 | "
" 307 | ] 308 | }, 309 | "metadata": { 310 | "tags": [], 311 | "needs_background": "light" 312 | } 313 | } 314 | ] 315 | } 316 | ] 317 | } -------------------------------------------------------------------------------- /training.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "training.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "accelerator": "GPU" 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "sVbYgY3Rb8UL", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "#Training\n", 26 | "\n", 27 | "\n" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "metadata": { 33 | "id": "Hv2tuwUsb66A", 34 | "colab_type": "code", 35 | "colab": {} 36 | }, 37 | "source": [ 38 | "#Importing libraries\n", 39 | "import torch\n", 40 | "import torch.nn as nn\n", 41 | "from torchvision.utils import make_grid\n", 42 | "import matplotlib.pyplot as plt\n", 43 | "from torch.utils.data import DataLoader, random_split\n", 44 | "import numpy as np\n", 45 | "import math" 46 | ], 47 | "execution_count": null, 48 | "outputs": [] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "metadata": { 53 | "id": "6fcJ9D9d-LnQ", 54 | "colab_type": "code", 55 | "colab": { 56 | "base_uri": "https://localhost:8080/", 57 | "height": 34 58 | }, 59 | "outputId": "aa881b03-e93f-4c54-84eb-cd383a59069b" 60 | }, 61 | "source": [ 62 | "#Libraries for importing ipynb file\n", 63 | "from google.colab import drive, files\n", 64 | "!pip install import_ipynb\n", 65 | "import import_ipynb" 66 | ], 67 | "execution_count": null, 68 | "outputs": [ 69 | { 70 | "output_type": "stream", 71 | "text": [ 72 | "Requirement already satisfied: import_ipynb in /usr/local/lib/python3.6/dist-packages (0.1.3)\n" 73 | ], 74 | 
"name": "stdout" 75 | } 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "metadata": { 81 | "id": "RF4q-TuDdMn9", 82 | "colab_type": "code", 83 | "colab": { 84 | "base_uri": "https://localhost:8080/", 85 | "height": 34 86 | }, 87 | "outputId": "ae8441c3-9d4a-4d27-f469-1c50de3fc52e" 88 | }, 89 | "source": [ 90 | "#Mounting \n", 91 | "drive.mount('/content/gdrive')" 92 | ], 93 | "execution_count": null, 94 | "outputs": [ 95 | { 96 | "output_type": "stream", 97 | "text": [ 98 | "Mounted at /content/gdrive\n" 99 | ], 100 | "name": "stdout" 101 | } 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "metadata": { 107 | "id": "WAEds0OKelYd", 108 | "colab_type": "code", 109 | "colab": { 110 | "base_uri": "https://localhost:8080/", 111 | "height": 68 112 | }, 113 | "outputId": "38f771a3-aaa1-447a-d4d7-00df7400f28c" 114 | }, 115 | "source": [ 116 | "\n", 117 | "#Location of Datapreprocessing file is in Colab Notebooks folder\n", 118 | "#Changing the directory to Colab Notebooks\n", 119 | "%cd '/content/gdrive/My Drive/Colab Notebooks/Face mask/new'\n", 120 | "from DataPreprocessing import Preprocessing\n", 121 | "from CNN import ResNet15, ResNet9" 122 | ], 123 | "execution_count": null, 124 | "outputs": [ 125 | { 126 | "output_type": "stream", 127 | "text": [ 128 | "/content/gdrive/My Drive/Colab Notebooks/Face mask\n", 129 | "importing Jupyter notebook from DataPreprocessing.ipynb\n", 130 | "importing Jupyter notebook from CNN.ipynb\n" 131 | ], 132 | "name": "stdout" 133 | } 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "metadata": { 139 | "id": "k2mkHfYYlZUK", 140 | "colab_type": "code", 141 | "colab": {} 142 | }, 143 | "source": [ 144 | "#Directory of the dataset\n", 145 | "directory = 'dataset'" 146 | ], 147 | "execution_count": null, 148 | "outputs": [] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "metadata": { 153 | "id": "2FXUNgRgyXRy", 154 | "colab_type": "code", 155 | "colab": {} 156 | }, 157 | "source": [ 158 | "data = Preprocessing(path 
= {'test':directory+'/test',\n", 159 | " 'train':directory+'/train'})\n", 160 | "train_ds, test_ds = data.preprocessed_dataset()" 161 | ], 162 | "execution_count": null, 163 | "outputs": [] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "metadata": { 168 | "id": "i_jtQSbPrVPy", 169 | "colab_type": "code", 170 | "colab": {} 171 | }, 172 | "source": [ 173 | "#Splitting and creating batches of data\n", 174 | "batch_size = 20\n", 175 | "val_size = math.ceil(len(test_ds)* 20/100)\n", 176 | "test_size = len(test_ds) - val_size \n", 177 | "\n", 178 | "#splitting the data\n", 179 | "val_ds, test_ds = random_split(test_ds, [val_size, test_size])\n", 180 | "\n", 181 | "#creating data loader\n", 182 | "train_dl = DataLoader(train_ds, batch_size, num_workers=3, shuffle = True,\n", 183 | " pin_memory=True)\n", 184 | "val_dl = DataLoader(val_ds, batch_size, num_workers=3,\n", 185 | " pin_memory=True)\n" 186 | ], 187 | "execution_count": null, 188 | "outputs": [] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "metadata": { 193 | "id": "6i-i4jQa0t3W", 194 | "colab_type": "code", 195 | "colab": {} 196 | }, 197 | "source": [ 198 | "x,y = next(iter(train_dl))" 199 | ], 200 | "execution_count": null, 201 | "outputs": [] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "metadata": { 206 | "id": "KcDR8q8Bci4-", 207 | "colab_type": "code", 208 | "colab": { 209 | "base_uri": "https://localhost:8080/", 210 | "height": 154 211 | }, 212 | "outputId": "3a641101-630d-44b9-addb-126e83eb97b5" 213 | }, 214 | "source": [ 215 | "\n", 216 | "#Visualizing \n", 217 | "fig, ax = plt.subplots(figsize=(12,12))\n", 218 | "ax.imshow(make_grid(x[:10], nrow=10).permute(1,2,0))\n", 219 | "x.shape" 220 | ], 221 | "execution_count": null, 222 | "outputs": [ 223 | { 224 | "output_type": "stream", 225 | "text": [ 226 | "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n" 227 | ], 228 | "name": "stderr" 229 | }, 230 | { 231 | "output_type": 
"execute_result", 232 | "data": { 233 | "text/plain": [ 234 | "torch.Size([20, 1, 128, 128])" 235 | ] 236 | }, 237 | "metadata": { 238 | "tags": [] 239 | }, 240 | "execution_count": 9 241 | }, 242 | { 243 | "output_type": "display_data", 244 | "data": { 245 | "image/png": "\n", 246 | "text/plain": [ 247 | "
" 248 | ] 249 | }, 250 | "metadata": { 251 | "tags": [], 252 | "needs_background": "light" 253 | } 254 | } 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "metadata": { 260 | "id": "8rLp7Toowj6l", 261 | "colab_type": "code", 262 | "colab": {} 263 | }, 264 | "source": [ 265 | "#Activating cuda\n", 266 | " #Uses cuda if available else cpu\n", 267 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 268 | ], 269 | "execution_count": null, 270 | "outputs": [] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "metadata": { 275 | "id": "iSBT4xHs-dl9", 276 | "colab_type": "code", 277 | "colab": {} 278 | }, 279 | "source": [ 280 | "model = ResNet9(1,2).to(device)" 281 | ], 282 | "execution_count": null, 283 | "outputs": [] 284 | }, 285 | { 286 | "cell_type": "code", 287 | "metadata": { 288 | "id": "s9AI7OZ5G3wJ", 289 | "colab_type": "code", 290 | "colab": {} 291 | }, 292 | "source": [ 293 | "def get_lr(optimizer):\n", 294 | " '''\n", 295 | " params: optimizer: optimizer of the model\n", 296 | " return: learning rate\n", 297 | " '''\n", 298 | " \n", 299 | " for x in optimizer.param_groups:\n", 300 | " return x['lr']" 301 | ], 302 | "execution_count": null, 303 | "outputs": [] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "metadata": { 308 | "id": "gNczvHaq_iWy", 309 | "colab_type": "code", 310 | "colab": {} 311 | }, 312 | "source": [ 313 | "def fit(model, epochs, max_lr, train_dl, val_dl, weight_decay=0, \n", 314 | " optim = torch.optim.Adam, grad_clip = None):\n", 315 | " '''\n", 316 | " Arguments:\n", 317 | "\n", 318 | " model --> CNN model\n", 319 | " epochs --> number of epoch\n", 320 | " max_lr --> max learning rate for One-cycle lr scheduler\n", 321 | " train_dl, val_dl --> training and validation dataloader\n", 322 | " weight_decay --> regularizing parametric(reduces the weights to avoid \n", 323 | " overfitting)\n", 324 | " optim --> optimizer\n", 325 | " grad_clip --> regularizing parametric(limit the values of gradients to a \n", 
326 | " small range)\n", 327 | " \n", 328 | " Return:\n", 329 | "\n", 330 | " history --> {\n", 331 | " training_loss --> training loss for every epochs\n", 332 | " validation_loss --> validation loss for every epochs\n", 333 | " validation_acc --> validation accuracy for every epochs\n", 334 | " }\n", 335 | " '''\n", 336 | " model.train() #training\n", 337 | " torch.cuda.empty_cache() #releases cache memory improves performance\n", 338 | "\n", 339 | " #Defining optimizer and loss\n", 340 | " optimizer = optim(model.parameters(), max_lr, weight_decay = weight_decay)\n", 341 | " criterion = nn.NLLLoss()\n", 342 | "\n", 343 | " #One-cycle learning rate scheduler\n", 344 | " sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr,\n", 345 | " steps_per_epoch = len(train_dl) ,epochs = epochs)\n", 346 | "\n", 347 | " val_loss_min = np.inf #initial value at infinity\n", 348 | "\n", 349 | " history = {\n", 350 | " 'training_loss': [], #training loss for every epochs\n", 351 | " 'validation_loss': [], #validation loss for every epochs\n", 352 | " 'validation_acc': [], #validation accuracy for every epochs\n", 353 | " 'learning_rates': [] #lr for every epochs \n", 354 | " }\n", 355 | " \n", 356 | " for epoch in range(epochs):\n", 357 | "\n", 358 | " batch_train_loss = 0 #training loss for every batch\n", 359 | " batch_val_loss = 0 #validation loss for every batch\n", 360 | " batch_val_acc = 0 #valdation acc for every batch\n", 361 | "\n", 362 | " #looping through training dataloader\n", 363 | " for imgs, labels in train_dl:\n", 364 | "\n", 365 | " #Sending images and labels to GPU\n", 366 | " imgs, labels = imgs.to(device), labels.to(device)\n", 367 | "\n", 368 | " predictions = model(imgs)\n", 369 | " loss = criterion(predictions, labels)\n", 370 | " batch_train_loss += loss\n", 371 | " loss.backward()\n", 372 | "\n", 373 | " #gradient clipping\n", 374 | " if grad_clip:\n", 375 | " nn.utils.clip_grad_value_(model.parameters(), grad_clip)\n", 376 | " \n", 377 | " 
optimizer.step()\n", 378 | " optimizer.zero_grad()\n", 379 | " \n", 380 | " #record and update lr\n", 381 | " history['learning_rates'].append(get_lr(optimizer))\n", 382 | " sched.step()\n", 383 | "\n", 384 | " #shutting down the gradients\n", 385 | " with torch.no_grad():\n", 386 | " model.eval() #for evaluating\n", 387 | "\n", 388 | " #looping through validation dataloader\n", 389 | " for imgs, labels in val_dl:\n", 390 | " #Sending images and labels to GPU\n", 391 | " imgs, labels = imgs.to(device), labels.to(device)\n", 392 | "\n", 393 | " predictions = model(imgs)\n", 394 | " loss = criterion(predictions, labels) \n", 395 | "\n", 396 | " #Converting predictions to probabilities\n", 397 | " prob = torch.exp(predictions)\n", 398 | " #Finding the index of maximum probability\n", 399 | " _, index = torch.max(prob, dim=1)\n", 400 | " \n", 401 | " batch_val_loss += loss.item()\n", 402 | " batch_val_acc += torch.sum(index == labels).item() / len(index)\n", 403 | " \n", 404 | " #loss and accuracy per epoch\n", 405 | " training_loss = batch_train_loss / len(train_dl)\n", 406 | " val_loss = batch_val_loss / len(val_dl)\n", 407 | " val_acc = batch_val_acc / len(val_dl)\n", 408 | "\n", 409 | " history['training_loss'].append(training_loss)\n", 410 | " history['validation_loss'].append(val_loss)\n", 411 | " history['validation_acc'].append(val_acc)\n", 412 | "\n", 413 | " #Verbose\n", 414 | " print('Epochs: {}, training_loss: {:.4f}, val_loss: {:.4f},\\\n", 415 | " val_acc: {:.4f}'.format(epoch, training_loss, val_loss, val_acc))\n", 416 | " \n", 417 | " if val_loss <= val_loss_min:\n", 418 | " '''\n", 419 | " Initailly the val_loss_min is set at infinity and gets \n", 420 | " updated each time val_loss is lesser than val_loss_min.\n", 421 | " Basically it saves the model with least validation loss.\n", 422 | " ''' \n", 423 | " torch.save(model.state_dict(), 'state_dict_resnet9.pth')\n", 424 | " print('{} --> {}. 
Saving ...'.format(val_loss_min, val_loss))\n", 425 | " val_loss_min = val_loss\n", 426 | "\n", 427 | " return history\n", 428 | " " 429 | ], 430 | "execution_count": null, 431 | "outputs": [] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "metadata": { 436 | "id": "ZlGBZSj9CfaI", 437 | "colab_type": "code", 438 | "colab": { 439 | "base_uri": "https://localhost:8080/", 440 | "height": 326 441 | }, 442 | "outputId": "243435e2-dbbd-48f8-e5fe-acbd46f35656" 443 | }, 444 | "source": [ 445 | "history = fit(model, 10, 0.0001, train_dl, val_dl, weight_decay=0.01, \n", 446 | " optim = torch.optim.Adam, grad_clip = 0.1)" 447 | ], 448 | "execution_count": null, 449 | "outputs": [ 450 | { 451 | "output_type": "stream", 452 | "text": [ 453 | "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py:100: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", 454 | " input = module(input)\n" 455 | ], 456 | "name": "stderr" 457 | }, 458 | { 459 | "output_type": "stream", 460 | "text": [ 461 | "Epochs: 0, training_loss: 1.1460, val_loss: 0.2883, val_acc: 0.9000\n", 462 | "inf --> 0.2882944643497467. Saving ...\n", 463 | "Epochs: 1, training_loss: 0.3747, val_loss: 0.0648, val_acc: 1.0000\n", 464 | "0.2882944643497467 --> 0.06478280574083328. Saving ...\n", 465 | "Epochs: 2, training_loss: 0.1320, val_loss: 0.1003, val_acc: 1.0000\n", 466 | "Epochs: 3, training_loss: 0.1161, val_loss: 0.0954, val_acc: 0.9500\n", 467 | "Epochs: 4, training_loss: 0.0595, val_loss: 0.0122, val_acc: 1.0000\n", 468 | "0.06478280574083328 --> 0.012150990776717663. Saving ...\n", 469 | "Epochs: 5, training_loss: 0.0521, val_loss: 0.0022, val_acc: 1.0000\n", 470 | "0.012150990776717663 --> 0.002183937933295965. 
Saving ...\n", 471 | "Epochs: 6, training_loss: 0.0207, val_loss: 0.0057, val_acc: 1.0000\n", 472 | "Epochs: 7, training_loss: 0.0162, val_loss: 0.0021, val_acc: 1.0000\n", 473 | "0.002183937933295965 --> 0.0020512579940259457. Saving ...\n", 474 | "Epochs: 8, training_loss: 0.0102, val_loss: 0.0037, val_acc: 1.0000\n", 475 | "Epochs: 9, training_loss: 0.0083, val_loss: 0.0028, val_acc: 1.0000\n" 476 | ], 477 | "name": "stdout" 478 | } 479 | ] 480 | }, 481 | { 482 | "cell_type": "code", 483 | "metadata": { 484 | "id": "tA52b5f3kX5C", 485 | "colab_type": "code", 486 | "colab": {} 487 | }, 488 | "source": [ 489 | "training_loss = history['training_loss']\n", 490 | "val_loss = history['validation_loss']\n", 491 | "val_acc = history['validation_acc']\n", 492 | "lr = history['learning_rates']" 493 | ], 494 | "execution_count": null, 495 | "outputs": [] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "metadata": { 500 | "id": "T4B79tghpt6z", 501 | "colab_type": "code", 502 | "colab": { 503 | "base_uri": "https://localhost:8080/", 504 | "height": 448 505 | }, 506 | "outputId": "a74fe9d8-0b58-4847-ef4f-a25acc9f70de" 507 | }, 508 | "source": [ 509 | "plt.figure(figsize=(10,7))\n", 510 | "plt.plot(val_loss, 'b-1', markersize=12, markeredgewidth=3\n", 511 | " , linewidth=2,label='val_loss')\n", 512 | "plt.plot(val_acc, 'r--x', markersize=8, markeredgewidth=3\n", 513 | " , linewidth=3,label='val_acc')\n", 514 | "plt.plot(training_loss, 'c-1', markersize=12, markeredgewidth=3\n", 515 | " , linewidth=2,label='training_loss')\n", 516 | "plt.grid()\n", 517 | "plt.legend(loc='center right')" 518 | ], 519 | "execution_count": null, 520 | "outputs": [ 521 | { 522 | "output_type": "execute_result", 523 | "data": { 524 | "text/plain": [ 525 | "" 526 | ] 527 | }, 528 | "metadata": { 529 | "tags": [] 530 | }, 531 | "execution_count": 16 532 | }, 533 | { 534 | "output_type": "display_data", 535 | "data": { 536 | "image/png": "\n", 537 | "text/plain": [ 538 | "
" 539 | ] 540 | }, 541 | "metadata": { 542 | "tags": [], 543 | "needs_background": "light" 544 | } 545 | } 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "metadata": { 551 | "id": "nmnYlGVxp_RM", 552 | "colab_type": "code", 553 | "colab": { 554 | "base_uri": "https://localhost:8080/", 555 | "height": 282 556 | }, 557 | "outputId": "b73f806f-de46-4618-b7a5-f072c7516437" 558 | }, 559 | "source": [ 560 | "plt.plot(lr,'r', label='lr')\n", 561 | "plt.legend()" 562 | ], 563 | "execution_count": null, 564 | "outputs": [ 565 | { 566 | "output_type": "execute_result", 567 | "data": { 568 | "text/plain": [ 569 | "" 570 | ] 571 | }, 572 | "metadata": { 573 | "tags": [] 574 | }, 575 | "execution_count": 17 576 | }, 577 | { 578 | "output_type": "display_data", 579 | "data": { 580 | "image/png": "\n", 581 | "text/plain": [ 582 | "
" 583 | ] 584 | }, 585 | "metadata": { 586 | "tags": [], 587 | "needs_background": "light" 588 | } 589 | } 590 | ] 591 | } 592 | ] 593 | } --------------------------------------------------------------------------------