├── .idea
│   ├── .gitignore
│   ├── FireSmokeDetectionByEfficientNet.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── Conv_visual.py
├── LICENSE
├── README.md
├── cropdata
│   └── train
│       ├── fire
│       │   └── 000016.jpg
│       └── smoke
│           └── 000002.png
├── efficientnet_pytorch
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── fire_smoke_model.cpython-37.pyc
│   │   ├── model.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── fire_smoke_model.py
│   ├── model.py
│   └── utils.py
├── examples
│   ├── imagenet
│   │   ├── README.md
│   │   ├── data
│   │   │   └── README.md
│   │   └── main.py
│   └── simple
│       ├── check.ipynb
│       ├── example.ipynb
│       ├── fire_smoke_map.txt
│       ├── img.jpg
│       ├── img2.jpg
│       └── labels_map.txt
├── featmap
│   ├── 5rd_depthwise_conv_featmap3_7e9ee24563cc31d34de2020e1acaecc5.jpeg
│   ├── 5rd_depthwise_conv_featmap4_7e9ee24563cc31d34de2020e1acaecc5.jpeg
│   └── 5rd_depthwise_conv_featmap5_7e9ee24563cc31d34de2020e1acaecc5.jpeg
├── fire_smoke_demo.py
├── fire_smoke_detection.py
├── model_vusal.py
├── results
│   ├── acc_loss.png
│   ├── det_results000127.jpg
│   ├── result_000127.jpg
│   └── result_7e9ee24563cc31d34de2020e1acaecc5.jpeg
├── setup.py
├── tests
│   ├── 000127.jpg
│   └── 7e9ee24563cc31d34de2020e1acaecc5.jpeg
└── train.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 | 
--------------------------------------------------------------------------------
/Conv_visual.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: Conv_visual.py
5 | @Author: kong
6 | @Time: 2020-01-07 09:59
7 | @Description: Visualize the fire & smoke model
8 | '''
9 | 
10 | 
11 | import cv2
12 | import numpy as np
13 | import torch
14 | from torch.autograd import Variable
15 | import json
16 | from PIL import Image, ImageDraw, ImageFont
17 | from torchvision import transforms
18 | from torchvision import models
19 | from efficientnet_pytorch import FireSmokeEfficientNet
20 | import collections
21 | 
22 | def preprocess_image(cv2im, resize_im=True):
23 |     """
24 |     Processes an image for CNNs
25 | 
26 |     Args:
27 |         cv2im (ndarray): OpenCV BGR image to process
28 |         resize_im (bool): Resize to 224 or not
29 |     returns:
30 |         im_as_var (PyTorch variable): Variable that contains the processed float tensor
31 |     """
32 |     # mean and std list for channels (Imagenet)
34 |     mean = [0.485, 0.456, 0.406]
35 |     std = [0.229, 0.224, 0.225]
36 |     # Resize image
37 |     if resize_im:
38 |         cv2im = cv2.resize(cv2im, (224, 224))
39 |     im_as_arr = np.float32(cv2im)
40 |     im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])  # BGR -> RGB
41 |     im_as_arr = im_as_arr.transpose(2, 0, 1)  # Convert array from (H, W, C) to (C, H, W)
42 |     # Normalize the channels
43 |     for channel, _ in enumerate(im_as_arr):
44 |         im_as_arr[channel] /= 255
45 |         im_as_arr[channel] -= mean[channel]
46 |         im_as_arr[channel] /= std[channel]
47 |     # Convert to float tensor
48 |     im_as_ten = torch.from_numpy(im_as_arr).float()
49 |     # Add a batch dimension. Tensor shape = 1,3,224,224
50 |     im_as_ten.unsqueeze_(0)
51 |     # Convert to PyTorch variable
52 |     im_as_var = Variable(im_as_ten, requires_grad=True)
53 |     return im_as_var
54 | 
55 | def getPretrainedModel():
56 |     model_para = collections.OrderedDict()
57 |     model_tmp = FireSmokeEfficientNet.from_arch('efficientnet-b0')
58 |     # out_channels = model._fc.in_features
59 |     model_tmp._fc = torch.nn.Linear(1280, 3)
60 |     modelpara = torch.load('./checkpoint.pth.tar')
61 |     # print(modelpara['state_dict'].keys())
62 |     for key in modelpara['state_dict'].keys():  # strip the "module." prefix added by nn.DataParallel
63 |         model_para[key[7:]] = modelpara['state_dict'][key]
64 |     model_tmp.load_state_dict(model_para)
65 |     return model_tmp
66 | 
67 | 
68 | class FeatureVisualization():
69 |     def __init__(self, img_path, selected_layer):
70 |         self.img_path = img_path
71 |         self.image = cv2.imread(img_path)
72 |         self.selected_layer = selected_layer
73 |         self.pretrained_model = getPretrainedModel()
74 | 
75 |     def process_image(self):
76 |         img = cv2.imread(self.img_path)
77 |         img = preprocess_image(img)
78 |         return img
79 | 
80 |     def get_feature(self):
81 |         # input = Variable(torch.randn(1, 3, 224, 224))
82 |         input = self.process_image()
83 |         print(input.shape)
84 |         # x = input
85 | 
86 |         x = self.pretrained_model._swish(self.pretrained_model._bn0(self.pretrained_model._conv_stem(input)))
87 |         for index, layer in enumerate(self.pretrained_model._blocks):
88 |             x = layer(x)
89 |             if (index == self.selected_layer):
90 |                 return x
91 |         # x = self.pretrained_model._conv_head(x)
92 |         # return x
93 | 
94 | 
95 |     def get_single_feature(self):
96 |         features = self.get_feature()
97 |         # print('feature shape:', features.shape)
98 | 
99 |         feature = features[:, 6, :, :]
100 |         print(feature.shape)
101 | 
102 |         feature = feature.view(feature.shape[1], feature.shape[2])
103 |         print(feature.shape)
104 | 
105 |         return feature
106 | 
107 |     def get_all_feature(self):
108 |         features = self.get_feature()
109 |         # print(':', features.shape)
110 | 
111 |         feature = features[:, :, :, :]
112 |         print(feature.shape)
113 | 
114 |         feature = feature.view(feature.shape[1], feature.shape[2], feature.shape[3])
115 |         print(feature.shape)
116 | 
117 |         return feature
118 | 
119 |     def save_feature_to_img(self):
120 |         # to numpy
121 |         feature = self.get_all_feature()
122 |         feature = feature.data.numpy()
123 | 
124 |         # use sigmoid to map to [0,1]
125 |         feature = 1.0 / (1 + np.exp(-1 * feature))
126 | 
127 |         # to [0,255]
128 |         feature = np.round(feature * 255)
129 |         print(feature[0])
130 |         print("image size:", self.image.shape)
131 |         print("feature map size:", feature.shape)
132 |         Nch = feature.shape[0]  # channel num
133 |         for i in range(Nch):
134 |             print('--------------show the {}th channels-----------------'.format(i))
135 |             feature_resize = cv2.resize(feature[i], (self.image.shape[1], self.image.shape[0]))
136 |             feature_resize = cv2.cvtColor(feature_resize, cv2.COLOR_GRAY2BGR)
137 |             feature_resize = cv2.putText(feature_resize, 'first_depthwise_conv {}th feature
map'.format(i),(10,30),cv2.FONT_HERSHEY_COMPLEX,1,(0,255,0),3) 138 | image_cont = np.concatenate([self.image,feature_resize],axis=0) 139 | cv2.imwrite('featmap/5rd_depthwise_conv_featmap{}_{}'.format(i,self.img_path.split('/')[-1]),image_cont) 140 | 141 | if __name__=='__main__': 142 | # get class 143 | myClass=FeatureVisualization('./tests/000294.jpg',1) 144 | # print (myClass.pretrained_model) 145 | print("-----------------------------------------------------") 146 | print(myClass.pretrained_model) 147 | 148 | myClass.save_feature_to_img() 149 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FireSmokeDetectionByEfficientNet
2 | [EfficientNet](https://arxiv.org/abs/1905.11946) is a wonderful classification network. The EfficientNet implementation here follows [EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch).
I adapt much of the code from that implementation and build the network from the EfficientNet feature-extraction layers plus a fully connected layer with 3 output nodes for my fire & smoke detection task.
3 | 
4 | ## 1. Introduction
5 | 
6 | Fire and smoke classification and detection by training an EfficientNet classifier. Detection is implemented by classifying cropped patches of the input image.
7 | 
8 | ## 2. Requirements
9 | 
10 | Python 3.7, PyTorch 1.3, ...
11 | 
12 | ## 3. Running demos
13 | 
14 | Download the pretrained model and the simkai.ttf font from [BaiduNetDisk](https://pan.baidu.com/s/14CM-U6bmVjXG6gNQ2CC8fw) (code: awnf).
15 | 
16 | Simply run:
17 | 
18 | ```shell
19 | python fire_smoke_demo.py
20 | ```
21 | 
22 | to get classification results like the following:
23 | 
24 | ![avatar](./results/result_7e9ee24563cc31d34de2020e1acaecc5.jpeg)
25 | 
26 | or try the detection demo:
27 | 
28 | ```shell
29 | python fire_smoke_detection.py
30 | ```
31 | 
32 | to get results like:
33 | 
34 | ![avatar](./results/det_results000127.jpg)
35 | 
36 | ## 4. Visualizing the CNN
37 | 
38 | I visualize the activations of some of the feature maps as follows:
39 | 
40 | ![avatar](./featmap/5rd_depthwise_conv_featmap5_7e9ee24563cc31d34de2020e1acaecc5.jpeg)
41 | 
42 | ![avatar](./featmap/5rd_depthwise_conv_featmap4_7e9ee24563cc31d34de2020e1acaecc5.jpeg)
43 | 
44 | As we can see, the CNN automatically learns edges, shapes, and regions of interest for the target classes;
45 | 
46 | some filters respond to fire, while others pick up other features.
47 | 
48 | ## 5. Train Custom Dataset
49 | 
50 | Here is a script to train your own classification model using EfficientNet:
51 | 
52 | ```shell
53 | python train.py --data [your dataset path] --arch [efficientnet model: efficientnet-b0 to b7] --num_cls [number of classes in your task]
54 | ```
55 | 
56 | Refer to the code for argument details.
57 | 
58 | For inference, you should adapt the label-map txt file to your task; a minimal inference sketch is shown below.
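
A minimal inference sketch (the checkpoint and label-map paths are assumptions; this mirrors the checkpoint-loading code in `Conv_visual.py` and assumes `fire_smoke_map.txt` is a JSON dict of `{"index": "class name"}` in the same format as `labels_map.txt`):

```python
import json
import torch
from PIL import Image
from torchvision import transforms
from efficientnet_pytorch import FireSmokeEfficientNet

# Build the b0 architecture with a 3-class head, then load the trained weights.
model = FireSmokeEfficientNet.from_arch('efficientnet-b0')
model._fc = torch.nn.Linear(1280, 3)
ckpt = torch.load('./checkpoint.pth.tar', map_location='cpu')  # assumed path
# Checkpoints saved through nn.DataParallel prefix every key with "module.": strip it.
state = {k[7:]: v for k, v in ckpt['state_dict'].items()}
model.load_state_dict(state)
model.eval()

# Standard ImageNet mean/std preprocessing, as used elsewhere in this repo.
tfms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img = tfms(Image.open('tests/000127.jpg').convert('RGB')).unsqueeze(0)

labels = json.load(open('examples/simple/fire_smoke_map.txt'))
with torch.no_grad():
    probs = torch.softmax(model(img), dim=1)[0]
idx = int(probs.argmax())
print(labels[str(idx)], float(probs[idx]))
```

The detection demo follows the same idea: it runs this classifier over cropped patches of the input image and keeps the patches scored as fire or smoke.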
59 | 
60 | In my task:
61 | 
62 | ```shell
63 | python train.py --data ./cropdata --arch efficientnet-b0 --num_cls 3
64 | ```
65 | 
66 | ![avatar](./results/acc_loss.png)
67 | 
68 | ## Dataset
69 | https://pan.baidu.com/s/1eRXtYVrn6baJ6PRMOTjyzQ (code: srav)
--------------------------------------------------------------------------------
/cropdata/train/fire/000016.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/cropdata/train/fire/000016.jpg
--------------------------------------------------------------------------------
/cropdata/train/smoke/000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/cropdata/train/smoke/000002.png
--------------------------------------------------------------------------------
/efficientnet_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.5.1"
2 | from .fire_smoke_model import FireSmokeEfficientNet
3 | from .utils import (
4 |     GlobalParams,
5 |     BlockArgs,
6 |     BlockDecoder,
7 |     efficientnet,
8 |     get_model_params,
9 | )
10 | 
--------------------------------------------------------------------------------
/efficientnet_pytorch/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/efficientnet_pytorch/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/efficientnet_pytorch/__pycache__/fire_smoke_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/efficientnet_pytorch/__pycache__/fire_smoke_model.cpython-37.pyc
--------------------------------------------------------------------------------
/efficientnet_pytorch/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/efficientnet_pytorch/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/efficientnet_pytorch/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/efficientnet_pytorch/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/efficientnet_pytorch/fire_smoke_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: fire_smoke_model.py
5 | @Author: konglingran
6 | @Time: 2020-01-02 15:52
7 | @Description:
8 | '''
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 | 
13 | from .utils import (
14 |     round_filters,
15 |     round_repeats,
16 |     drop_connect,
17 |     get_same_padding_conv2d,
18 |     get_model_params,
19 |
efficientnet_params, 20 | load_pretrained_weights, 21 | Swish, 22 | MemoryEfficientSwish, 23 | ) 24 | 25 | 26 | class MBConvBlock(nn.Module): 27 | """ 28 | Mobile Inverted Residual Bottleneck Block 29 | 30 | Args: 31 | block_args (namedtuple): BlockArgs, see above 32 | global_params (namedtuple): GlobalParam, see above 33 | 34 | Attributes: 35 | has_se (bool): Whether the block contains a Squeeze and Excitation layer. 36 | """ 37 | 38 | def __init__(self, block_args, global_params): 39 | super().__init__() 40 | self._block_args = block_args 41 | self._bn_mom = 1 - global_params.batch_norm_momentum 42 | self._bn_eps = global_params.batch_norm_epsilon 43 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 44 | self.id_skip = block_args.id_skip # skip connection and drop connect 45 | 46 | # Get static or dynamic convolution depending on image size 47 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 48 | 49 | # Expansion phase 50 | inp = self._block_args.input_filters # number of input channels 51 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 52 | if self._block_args.expand_ratio != 1: 53 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 54 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 55 | 56 | # Depthwise convolution phase 57 | k = self._block_args.kernel_size 58 | s = self._block_args.stride 59 | self._depthwise_conv = Conv2d( 60 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 61 | kernel_size=k, stride=s, bias=False) 62 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 63 | 64 | # Squeeze and Excitation layer, if desired 65 | if self.has_se: 66 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 67 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 68 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 69 | 70 | # Output phase 71 | final_oup = self._block_args.output_filters 72 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 73 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 74 | self._swish = MemoryEfficientSwish() 75 | 76 | def forward(self, inputs, drop_connect_rate=None): 77 | """ 78 | :param inputs: input tensor 79 | :param drop_connect_rate: drop connect rate (float, between 0 and 1) 80 | :return: output of block 81 | """ 82 | 83 | # Expansion and Depthwise Convolution 84 | x = inputs 85 | if self._block_args.expand_ratio != 1: 86 | x = self._swish(self._bn0(self._expand_conv(inputs))) 87 | x = self._swish(self._bn1(self._depthwise_conv(x))) 88 | 89 | # Squeeze and Excitation 90 | if self.has_se: 91 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 92 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) 93 | x = torch.sigmoid(x_squeezed) * x 94 | 95 | x = self._bn2(self._project_conv(x)) 96 | 97 | # Skip connection and drop connect 98 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 99 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 100 | if drop_connect_rate: 101 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 102 | x = x + inputs # skip connection 103 | return x 104 | 105 
| def set_swish(self, memory_efficient=True): 106 | """Sets swish function as memory efficient (for training) or standard (for export)""" 107 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 108 | 109 | 110 | class FireSmokeEfficientNet(nn.Module): 111 | """ 112 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods 113 | 114 | Args: 115 | blocks_args (list): A list of BlockArgs to construct blocks 116 | global_params (namedtuple): A set of GlobalParams shared between blocks 117 | 118 | Example: 119 | model = EfficientNet.from_pretrained('efficientnet-b0') 120 | 121 | """ 122 | 123 | def __init__(self, blocks_args=None, global_params=None): 124 | super().__init__() 125 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 126 | assert len(blocks_args) > 0, 'block args must be greater than 0' 127 | self._global_params = global_params 128 | self._blocks_args = blocks_args 129 | 130 | # Get static or dynamic convolution depending on image size 131 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 132 | 133 | # Batch norm parameters 134 | bn_mom = 1 - self._global_params.batch_norm_momentum 135 | bn_eps = self._global_params.batch_norm_epsilon 136 | 137 | # Stem 138 | in_channels = 3 # rgb 139 | out_channels = round_filters(32, self._global_params) # number of output channels 140 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 141 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 142 | 143 | # Build blocks 144 | self._blocks = nn.ModuleList([]) 145 | for block_args in self._blocks_args: 146 | 147 | # Update block input and output filters based on depth multiplier. 148 | block_args = block_args._replace( 149 | input_filters=round_filters(block_args.input_filters, self._global_params), 150 | output_filters=round_filters(block_args.output_filters, self._global_params), 151 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 152 | ) 153 | 154 | # The first block needs to take care of stride and filter size increase. 
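# (Only this first copy of the block sees the configured stride and the filter
# increase; the repeated blocks appended below run at stride 1 with
# input_filters == output_filters, which is also the condition under which
# MBConvBlock applies its identity skip connection and drop connect.)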
155 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 156 | if block_args.num_repeat > 1: 157 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 158 | for _ in range(block_args.num_repeat - 1): 159 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 160 | 161 | # Head 162 | in_channels = block_args.output_filters # output of final block 163 | out_channels = round_filters(1280, self._global_params) 164 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 165 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 166 | 167 | # Final linear layer 168 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 169 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 170 | self._fc = nn.Linear(out_channels, self._global_params.num_classes) 171 | self._swish = MemoryEfficientSwish() 172 | 173 | def set_swish(self, memory_efficient=True): 174 | """Sets swish function as memory efficient (for training) or standard (for export)""" 175 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 176 | for block in self._blocks: 177 | block.set_swish(memory_efficient) 178 | 179 | def extract_features(self, inputs): 180 | """ Returns output of the final convolution layer """ 181 | 182 | # Stem 183 | x = self._swish(self._bn0(self._conv_stem(inputs))) 184 | 185 | # Blocks 186 | for idx, block in enumerate(self._blocks): 187 | drop_connect_rate = self._global_params.drop_connect_rate 188 | if drop_connect_rate: 189 | drop_connect_rate *= float(idx) / len(self._blocks) 190 | x = block(x, drop_connect_rate=drop_connect_rate) 191 | 192 | # Head 193 | x = self._swish(self._bn1(self._conv_head(x))) 194 | 195 | return x 196 | 197 | def forward(self, inputs): 198 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. 
""" 199 | bs = inputs.size(0) 200 | # Convolution layers 201 | x = self.extract_features(inputs) 202 | 203 | # Pooling and final linear layer 204 | x = self._avg_pooling(x) 205 | x = x.view(bs, -1) 206 | x = self._dropout(x) 207 | x = self._fc(x) 208 | return x 209 | 210 | @classmethod 211 | def from_name(cls, model_name, override_params=None): 212 | cls._check_model_name_is_valid(model_name) 213 | blocks_args, global_params = get_model_params(model_name, override_params) 214 | return cls(blocks_args, global_params) 215 | 216 | @classmethod 217 | def from_arch(cls, model_name, override_params=None): 218 | cls._check_model_name_is_valid(model_name) 219 | blocks_args, global_params = get_model_params(model_name, override_params) 220 | return cls(blocks_args, global_params) 221 | 222 | @classmethod 223 | def from_pretrained(cls, args, num_classes=1000, in_channels=3): 224 | model = cls.from_name(args.arch, override_params={'num_classes': num_classes}) 225 | load_pretrained_weights(model, args.arch, load_fc=(num_classes == 1000)) 226 | if in_channels != 3: 227 | Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size) 228 | out_channels = round_filters(32, model._global_params) 229 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 230 | out_channels = round_filters(1280, model._global_params) 231 | model._fc = nn.Linear(out_channels, args.num_cls) 232 | return model 233 | 234 | @classmethod 235 | def from_pretrained(cls, args, num_classes=1000): 236 | model = cls.from_name(args.arch, override_params={'num_classes': num_classes}) 237 | load_pretrained_weights(model, args.arch, load_fc=(num_classes == 1000)) 238 | out_channels = round_filters(1280, model._global_params) 239 | model._fc = nn.Linear(out_channels, args.num_cls) 240 | return model 241 | 242 | @classmethod 243 | def get_image_size(cls, model_name): 244 | cls._check_model_name_is_valid(model_name) 245 | _, _, res, _ = efficientnet_params(model_name) 246 | return res 247 | 248 | @classmethod 249 | def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False): 250 | """ Validates model name. None that pretrained weights are only available for 251 | the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """ 252 | num_models = 4 if also_need_pretrained_weights else 8 253 | valid_models = ['efficientnet-b' + str(i) for i in range(num_models)] 254 | if model_name not in valid_models: 255 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 256 | -------------------------------------------------------------------------------- /efficientnet_pytorch/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .utils import ( 6 | round_filters, 7 | round_repeats, 8 | drop_connect, 9 | get_same_padding_conv2d, 10 | get_model_params, 11 | efficientnet_params, 12 | load_pretrained_weights, 13 | Swish, 14 | MemoryEfficientSwish, 15 | ) 16 | 17 | class MBConvBlock(nn.Module): 18 | """ 19 | Mobile Inverted Residual Bottleneck Block 20 | 21 | Args: 22 | block_args (namedtuple): BlockArgs, see above 23 | global_params (namedtuple): GlobalParam, see above 24 | 25 | Attributes: 26 | has_se (bool): Whether the block contains a Squeeze and Excitation layer. 
27 | """ 28 | 29 | def __init__(self, block_args, global_params): 30 | super().__init__() 31 | self._block_args = block_args 32 | self._bn_mom = 1 - global_params.batch_norm_momentum 33 | self._bn_eps = global_params.batch_norm_epsilon 34 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) 35 | self.id_skip = block_args.id_skip # skip connection and drop connect 36 | 37 | # Get static or dynamic convolution depending on image size 38 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 39 | 40 | # Expansion phase 41 | inp = self._block_args.input_filters # number of input channels 42 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels 43 | if self._block_args.expand_ratio != 1: 44 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) 45 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 46 | 47 | # Depthwise convolution phase 48 | k = self._block_args.kernel_size 49 | s = self._block_args.stride 50 | self._depthwise_conv = Conv2d( 51 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise 52 | kernel_size=k, stride=s, bias=False) 53 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) 54 | 55 | # Squeeze and Excitation layer, if desired 56 | if self.has_se: 57 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) 58 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) 59 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) 60 | 61 | # Output phase 62 | final_oup = self._block_args.output_filters 63 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) 64 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) 65 | self._swish = MemoryEfficientSwish() 66 | 67 | def forward(self, inputs, drop_connect_rate=None): 68 | """ 69 | :param inputs: input tensor 70 | :param drop_connect_rate: drop connect rate (float, between 0 and 1) 71 | :return: output of block 72 | """ 73 | 74 | # Expansion and Depthwise Convolution 75 | x = inputs 76 | if self._block_args.expand_ratio != 1: 77 | x = self._swish(self._bn0(self._expand_conv(inputs))) 78 | x = self._swish(self._bn1(self._depthwise_conv(x))) 79 | 80 | # Squeeze and Excitation 81 | if self.has_se: 82 | x_squeezed = F.adaptive_avg_pool2d(x, 1) 83 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed))) 84 | x = torch.sigmoid(x_squeezed) * x 85 | 86 | x = self._bn2(self._project_conv(x)) 87 | 88 | # Skip connection and drop connect 89 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters 90 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: 91 | if drop_connect_rate: 92 | x = drop_connect(x, p=drop_connect_rate, training=self.training) 93 | x = x + inputs # skip connection 94 | return x 95 | 96 | def set_swish(self, memory_efficient=True): 97 | """Sets swish function as memory efficient (for training) or standard (for export)""" 98 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 99 | 100 | 101 | class EfficientNet(nn.Module): 102 | """ 103 | An EfficientNet model. 
Most easily loaded with the .from_name or .from_pretrained methods 104 | 105 | Args: 106 | blocks_args (list): A list of BlockArgs to construct blocks 107 | global_params (namedtuple): A set of GlobalParams shared between blocks 108 | 109 | Example: 110 | model = EfficientNet.from_pretrained('efficientnet-b0') 111 | 112 | """ 113 | 114 | def __init__(self, blocks_args=None, global_params=None): 115 | super().__init__() 116 | assert isinstance(blocks_args, list), 'blocks_args should be a list' 117 | assert len(blocks_args) > 0, 'block args must be greater than 0' 118 | self._global_params = global_params 119 | self._blocks_args = blocks_args 120 | 121 | # Get static or dynamic convolution depending on image size 122 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) 123 | 124 | # Batch norm parameters 125 | bn_mom = 1 - self._global_params.batch_norm_momentum 126 | bn_eps = self._global_params.batch_norm_epsilon 127 | 128 | # Stem 129 | in_channels = 3 # rgb 130 | out_channels = round_filters(32, self._global_params) # number of output channels 131 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 132 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 133 | 134 | # Build blocks 135 | self._blocks = nn.ModuleList([]) 136 | for block_args in self._blocks_args: 137 | 138 | # Update block input and output filters based on depth multiplier. 139 | block_args = block_args._replace( 140 | input_filters=round_filters(block_args.input_filters, self._global_params), 141 | output_filters=round_filters(block_args.output_filters, self._global_params), 142 | num_repeat=round_repeats(block_args.num_repeat, self._global_params) 143 | ) 144 | 145 | # The first block needs to take care of stride and filter size increase. 
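# (The repeats created below differ from this first copy: they keep
# input_filters == output_filters and use stride 1, the same condition under
# which MBConvBlock applies its identity skip connection.)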
146 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 147 | if block_args.num_repeat > 1: 148 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1) 149 | for _ in range(block_args.num_repeat - 1): 150 | self._blocks.append(MBConvBlock(block_args, self._global_params)) 151 | 152 | # Head 153 | in_channels = block_args.output_filters # output of final block 154 | out_channels = round_filters(1280, self._global_params) 155 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 156 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps) 157 | 158 | # Final linear layer 159 | self._avg_pooling = nn.AdaptiveAvgPool2d(1) 160 | self._dropout = nn.Dropout(self._global_params.dropout_rate) 161 | self._fc = nn.Linear(out_channels, self._global_params.num_classes) 162 | self._swish = MemoryEfficientSwish() 163 | 164 | def set_swish(self, memory_efficient=True): 165 | """Sets swish function as memory efficient (for training) or standard (for export)""" 166 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish() 167 | for block in self._blocks: 168 | block.set_swish(memory_efficient) 169 | 170 | 171 | def extract_features(self, inputs): 172 | """ Returns output of the final convolution layer """ 173 | 174 | # Stem 175 | x = self._swish(self._bn0(self._conv_stem(inputs))) 176 | 177 | # Blocks 178 | for idx, block in enumerate(self._blocks): 179 | drop_connect_rate = self._global_params.drop_connect_rate 180 | if drop_connect_rate: 181 | drop_connect_rate *= float(idx) / len(self._blocks) 182 | x = block(x, drop_connect_rate=drop_connect_rate) 183 | 184 | # Head 185 | x = self._swish(self._bn1(self._conv_head(x))) 186 | 187 | return x 188 | 189 | def forward(self, inputs): 190 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. 
""" 191 | bs = inputs.size(0) 192 | # Convolution layers 193 | x = self.extract_features(inputs) 194 | 195 | # Pooling and final linear layer 196 | x = self._avg_pooling(x) 197 | x = x.view(bs, -1) 198 | x = self._dropout(x) 199 | x = self._fc(x) 200 | return x 201 | 202 | @classmethod 203 | def from_name(cls, model_name, override_params=None): 204 | cls._check_model_name_is_valid(model_name) 205 | blocks_args, global_params = get_model_params(model_name, override_params) 206 | return cls(blocks_args, global_params) 207 | 208 | @classmethod 209 | def from_pretrained(cls, model_name, num_classes=1000, in_channels = 3): 210 | model = cls.from_name(model_name, override_params={'num_classes': num_classes}) 211 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) 212 | if in_channels != 3: 213 | Conv2d = get_same_padding_conv2d(image_size = model._global_params.image_size) 214 | out_channels = round_filters(32, model._global_params) 215 | model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False) 216 | return model 217 | 218 | @classmethod 219 | def from_pretrained(cls, model_name, num_classes=1000): 220 | model = cls.from_name(model_name, override_params={'num_classes': num_classes}) 221 | load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000)) 222 | 223 | return model 224 | 225 | @classmethod 226 | def get_image_size(cls, model_name): 227 | cls._check_model_name_is_valid(model_name) 228 | _, _, res, _ = efficientnet_params(model_name) 229 | return res 230 | 231 | @classmethod 232 | def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False): 233 | """ Validates model name. None that pretrained weights are only available for 234 | the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """ 235 | num_models = 4 if also_need_pretrained_weights else 8 236 | valid_models = ['efficientnet-b'+str(i) for i in range(num_models)] 237 | if model_name not in valid_models: 238 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models)) 239 | -------------------------------------------------------------------------------- /efficientnet_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains helper functions for building the model and for loading model parameters. 3 | These helper functions are built to mirror those in the official TensorFlow implementation. 
4 | """ 5 | 6 | import re 7 | import math 8 | import collections 9 | from functools import partial 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | from torch.utils import model_zoo 14 | 15 | ######################################################################## 16 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ############### 17 | ######################################################################## 18 | 19 | 20 | # Parameters for the entire model (stem, all blocks, and head) 21 | GlobalParams = collections.namedtuple('GlobalParams', [ 22 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 23 | 'num_classes', 'width_coefficient', 'depth_coefficient', 24 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) 25 | 26 | # Parameters for an individual model block 27 | BlockArgs = collections.namedtuple('BlockArgs', [ 28 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 29 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) 30 | 31 | # Change namedtuple defaults 32 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 33 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 34 | 35 | 36 | class SwishImplementation(torch.autograd.Function): 37 | @staticmethod 38 | def forward(ctx, i): 39 | result = i * torch.sigmoid(i) 40 | ctx.save_for_backward(i) 41 | return result 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output): 45 | i = ctx.saved_variables[0] 46 | sigmoid_i = torch.sigmoid(i) 47 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 48 | 49 | 50 | class MemoryEfficientSwish(nn.Module): 51 | def forward(self, x): 52 | return SwishImplementation.apply(x) 53 | 54 | class Swish(nn.Module): 55 | def forward(self, x): 56 | return x * torch.sigmoid(x) 57 | 58 | 59 | def round_filters(filters, global_params): 60 | """ Calculate and round number of filters based on depth multiplier. """ 61 | multiplier = global_params.width_coefficient 62 | if not multiplier: 63 | return filters 64 | divisor = global_params.depth_divisor 65 | min_depth = global_params.min_depth 66 | filters *= multiplier 67 | min_depth = min_depth or divisor 68 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 69 | if new_filters < 0.9 * filters: # prevent rounding by more than 10% 70 | new_filters += divisor 71 | return int(new_filters) 72 | 73 | 74 | def round_repeats(repeats, global_params): 75 | """ Round number of filters based on depth multiplier. """ 76 | multiplier = global_params.depth_coefficient 77 | if not multiplier: 78 | return repeats 79 | return int(math.ceil(multiplier * repeats)) 80 | 81 | 82 | def drop_connect(inputs, p, training): 83 | """ Drop connect. """ 84 | if not training: return inputs 85 | batch_size = inputs.shape[0] 86 | keep_prob = 1 - p 87 | random_tensor = keep_prob 88 | random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) 89 | binary_tensor = torch.floor(random_tensor) 90 | output = inputs / keep_prob * binary_tensor 91 | return output 92 | 93 | 94 | def get_same_padding_conv2d(image_size=None): 95 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. 96 | Static padding is necessary for ONNX exporting of models. 
""" 97 | if image_size is None: 98 | return Conv2dDynamicSamePadding 99 | else: 100 | return partial(Conv2dStaticSamePadding, image_size=image_size) 101 | 102 | 103 | class Conv2dDynamicSamePadding(nn.Conv2d): 104 | """ 2D Convolutions like TensorFlow, for a dynamic image size """ 105 | 106 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): 107 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 108 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 109 | 110 | def forward(self, x): 111 | ih, iw = x.size()[-2:] 112 | kh, kw = self.weight.size()[-2:] 113 | sh, sw = self.stride 114 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 115 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 116 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 117 | if pad_h > 0 or pad_w > 0: 118 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 119 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 120 | 121 | 122 | class Conv2dStaticSamePadding(nn.Conv2d): 123 | """ 2D Convolutions like TensorFlow, for a fixed image size""" 124 | 125 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs): 126 | super().__init__(in_channels, out_channels, kernel_size, **kwargs) 127 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 128 | 129 | # Calculate padding based on image size and save it 130 | assert image_size is not None 131 | ih, iw = image_size if type(image_size) == list else [image_size, image_size] 132 | kh, kw = self.weight.size()[-2:] 133 | sh, sw = self.stride 134 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) 135 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) 136 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) 137 | if pad_h > 0 or pad_w > 0: 138 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) 139 | else: 140 | self.static_padding = Identity() 141 | 142 | def forward(self, x): 143 | x = self.static_padding(x) 144 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 145 | return x 146 | 147 | 148 | class Identity(nn.Module): 149 | def __init__(self, ): 150 | super(Identity, self).__init__() 151 | 152 | def forward(self, input): 153 | return input 154 | 155 | 156 | ######################################################################## 157 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ############## 158 | ######################################################################## 159 | 160 | 161 | def efficientnet_params(model_name): 162 | """ Map EfficientNet model name to parameter coefficients. 
163 |     params_dict = {
164 |         # Coefficients:   width,depth,res,dropout
165 |         'efficientnet-b0': (1.0, 1.0, 224, 0.2),
166 |         'efficientnet-b1': (1.0, 1.1, 240, 0.2),
167 |         'efficientnet-b2': (1.1, 1.2, 260, 0.3),
168 |         'efficientnet-b3': (1.2, 1.4, 300, 0.3),
169 |         'efficientnet-b4': (1.4, 1.8, 380, 0.4),
170 |         'efficientnet-b5': (1.6, 2.2, 456, 0.4),
171 |         'efficientnet-b6': (1.8, 2.6, 528, 0.5),
172 |         'efficientnet-b7': (2.0, 3.1, 600, 0.5),
173 |     }
174 |     return params_dict[model_name]
175 | 
176 | 
177 | class BlockDecoder(object):
178 |     """ Block Decoder for readability, straight from the official TensorFlow repository """
179 | 
180 |     @staticmethod
181 |     def _decode_block_string(block_string):
182 |         """ Gets a block through a string notation of arguments. """
183 |         assert isinstance(block_string, str)
184 | 
185 |         ops = block_string.split('_')
186 |         options = {}
187 |         for op in ops:
188 |             splits = re.split(r'(\d.*)', op)
189 |             if len(splits) >= 2:
190 |                 key, value = splits[:2]
191 |                 options[key] = value
192 | 
193 |         # Check stride
194 |         assert (('s' in options and len(options['s']) == 1) or
195 |                 (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
196 | 
197 |         return BlockArgs(
198 |             kernel_size=int(options['k']),
199 |             num_repeat=int(options['r']),
200 |             input_filters=int(options['i']),
201 |             output_filters=int(options['o']),
202 |             expand_ratio=int(options['e']),
203 |             id_skip=('noskip' not in block_string),
204 |             se_ratio=float(options['se']) if 'se' in options else None,
205 |             stride=[int(options['s'][0])])
206 | 
207 |     @staticmethod
208 |     def _encode_block_string(block):
209 |         """Encodes a block to a string."""
210 |         args = [
211 |             'r%d' % block.num_repeat,
212 |             'k%d' % block.kernel_size,
213 |             's%d%d' % (block.stride[0], block.stride[0]),  # stride is stored as a one-element list
214 |             'e%s' % block.expand_ratio,
215 |             'i%d' % block.input_filters,
216 |             'o%d' % block.output_filters
217 |         ]
218 |         if block.se_ratio is not None and 0 < block.se_ratio <= 1:
219 |             args.append('se%s' % block.se_ratio)
220 |         if block.id_skip is False:
221 |             args.append('noskip')
222 |         return '_'.join(args)
223 | 
224 |     @staticmethod
225 |     def decode(string_list):
226 |         """
227 |         Decodes a list of string notations to specify blocks inside the network.
228 | 
229 |         :param string_list: a list of strings, each string is a notation of block
230 |         :return: a list of BlockArgs namedtuples of block args
231 |         """
232 |         assert isinstance(string_list, list)
233 |         blocks_args = []
234 |         for block_string in string_list:
235 |             blocks_args.append(BlockDecoder._decode_block_string(block_string))
236 |         return blocks_args
237 | 
238 |     @staticmethod
239 |     def encode(blocks_args):
240 |         """
241 |         Encodes a list of BlockArgs to a list of strings.
242 | 
243 |         :param blocks_args: a list of BlockArgs namedtuples of block args
244 |         :return: a list of strings, each string is a notation of block
245 |         """
246 |         block_strings = []
247 |         for block in blocks_args:
248 |             block_strings.append(BlockDecoder._encode_block_string(block))
249 |         return block_strings
250 | 
251 | 
252 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
253 |                  drop_connect_rate=0.2, image_size=None, num_classes=1000):
254 |     """ Creates an efficientnet model.
""" 255 | 256 | blocks_args = [ 257 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 258 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 259 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 260 | 'r1_k3_s11_e6_i192_o320_se0.25', 261 | ] 262 | blocks_args = BlockDecoder.decode(blocks_args) 263 | 264 | global_params = GlobalParams( 265 | batch_norm_momentum=0.99, 266 | batch_norm_epsilon=1e-3, 267 | dropout_rate=dropout_rate, 268 | drop_connect_rate=drop_connect_rate, 269 | # data_format='channels_last', # removed, this is always true in PyTorch 270 | num_classes=num_classes, 271 | width_coefficient=width_coefficient, 272 | depth_coefficient=depth_coefficient, 273 | depth_divisor=8, 274 | min_depth=None, 275 | image_size=image_size, 276 | ) 277 | 278 | return blocks_args, global_params 279 | 280 | 281 | def get_model_params(model_name, override_params): 282 | """ Get the block args and global params for a given model """ 283 | if model_name.startswith('efficientnet'): 284 | w, d, s, p = efficientnet_params(model_name) 285 | # note: all models have drop connect rate = 0.2 286 | blocks_args, global_params = efficientnet( 287 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) 288 | else: 289 | raise NotImplementedError('model name is not pre-defined: %s' % model_name) 290 | if override_params: 291 | # ValueError will be raised here if override_params has fields not included in global_params. 292 | global_params = global_params._replace(**override_params) 293 | return blocks_args, global_params 294 | 295 | 296 | url_map = { 297 | 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth', 298 | 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth', 299 | 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth', 300 | 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth', 301 | 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth', 302 | 'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth', 303 | 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth', 304 | 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth', 305 | } 306 | 307 | 308 | def load_pretrained_weights(model, model_name, load_fc=True): 309 | """ Loads pretrained weights, and downloads if loading for the first time. """ 310 | state_dict = model_zoo.load_url(url_map[model_name]) 311 | if load_fc: 312 | model.load_state_dict(state_dict) 313 | else: 314 | state_dict.pop('_fc.weight') 315 | state_dict.pop('_fc.bias') 316 | res = model.load_state_dict(state_dict, strict=False) 317 | assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights' 318 | print('Loaded pretrained weights for {}'.format(model_name)) 319 | -------------------------------------------------------------------------------- /examples/imagenet/README.md: -------------------------------------------------------------------------------- 1 | ### Imagenet 2 | 3 | This is a preliminary directory for evaluating the model on ImageNet. It is adapted from the standard PyTorch Imagenet script. 
4 | 5 | For now, only evaluation is supported, but I am currently building scripts to assist with training new models on Imagenet. 6 | 7 | The evaluation results are slightly different from the original TensorFlow repository, due to differences in data preprocessing. For example, with the current preprocessing, `efficientnet-b3` gives a top-1 accuracy of `80.8`, rather than the `81.1` reported in the paper. I am working on porting the TensorFlow preprocessing into PyTorch to address this issue. 8 | 9 | To run on Imagenet, place your `train` and `val` directories in `data`. 10 | 11 | Example commands: 12 | ```bash 13 | # Evaluate small EfficientNet on CPU 14 | python main.py data -e -a 'efficientnet-b0' --pretrained 15 | ``` 16 | ```bash 17 | # Evaluate medium EfficientNet on GPU 18 | python main.py data -e -a 'efficientnet-b3' --pretrained --gpu 0 --batch-size 128 19 | ``` 20 | ```bash 21 | # Evaluate ResNet-50 for comparison 22 | python main.py data -e -a 'resnet50' --pretrained --gpu 0 23 | ``` 24 | -------------------------------------------------------------------------------- /examples/imagenet/data/README.md: -------------------------------------------------------------------------------- 1 | ### ImageNet 2 | 3 | Download ImageNet and place it into `train` and `val` folders here. 4 | 5 | More details may be found with the official PyTorch ImageNet example [here](https://github.com/pytorch/examples/blob/master/imagenet). 6 | -------------------------------------------------------------------------------- /examples/imagenet/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate on ImageNet. Note that at the moment, training is not implemented (I am working on it). 3 | That being said, evaluation is working.
4 | """ 5 | 6 | import argparse 7 | import os 8 | import random 9 | import shutil 10 | import time 11 | import warnings 12 | import PIL 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.backends.cudnn as cudnn 18 | import torch.distributed as dist 19 | import torch.optim 20 | import torch.multiprocessing as mp 21 | import torch.utils.data 22 | import torch.utils.data.distributed 23 | import torchvision.transforms as transforms 24 | import torchvision.datasets as datasets 25 | import torchvision.models as models 26 | 27 | from efficientnet_pytorch import EfficientNet 28 | 29 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 30 | parser.add_argument('data', metavar='DIR', 31 | help='path to dataset') 32 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 33 | help='model architecture (default: resnet18)') 34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 37 | help='number of total epochs to run') 38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 39 | help='manual epoch number (useful on restarts)') 40 | parser.add_argument('-b', '--batch-size', default=256, type=int, 41 | metavar='N', 42 | help='mini-batch size (default: 256), this is the total ' 43 | 'batch size of all GPUs on the current node when ' 44 | 'using Data Parallel or Distributed Data Parallel') 45 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 46 | metavar='LR', help='initial learning rate', dest='lr') 47 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 48 | help='momentum') 49 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 50 | metavar='W', help='weight decay (default: 1e-4)', 51 | dest='weight_decay') 52 | parser.add_argument('-p', '--print-freq', default=10, type=int, 53 | metavar='N', help='print frequency (default: 10)') 54 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 55 | help='path to latest checkpoint (default: none)') 56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='evaluate model on validation set') 58 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 59 | help='use pre-trained model') 60 | parser.add_argument('--world-size', default=-1, type=int, 61 | help='number of nodes for distributed training') 62 | parser.add_argument('--rank', default=-1, type=int, 63 | help='node rank for distributed training') 64 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 65 | help='url used to set up distributed training') 66 | parser.add_argument('--dist-backend', default='nccl', type=str, 67 | help='distributed backend') 68 | parser.add_argument('--seed', default=None, type=int, 69 | help='seed for initializing training. ') 70 | parser.add_argument('--gpu', default=None, type=int, 71 | help='GPU id to use.') 72 | parser.add_argument('--image_size', default=224, type=int, 73 | help='image size') 74 | parser.add_argument('--multiprocessing-distributed', action='store_true', 75 | help='Use multi-processing distributed training to launch ' 76 | 'N processes per node, which has N GPUs. 
This is the ' 77 | 'fastest way to use PyTorch for either single node or ' 78 | 'multi node data parallel training') 79 | 80 | best_acc1 = 0 81 | 82 | 83 | def main(): 84 | args = parser.parse_args() 85 | 86 | if args.seed is not None: 87 | random.seed(args.seed) 88 | torch.manual_seed(args.seed) 89 | cudnn.deterministic = True 90 | warnings.warn('You have chosen to seed training. ' 91 | 'This will turn on the CUDNN deterministic setting, ' 92 | 'which can slow down your training considerably! ' 93 | 'You may see unexpected behavior when restarting ' 94 | 'from checkpoints.') 95 | 96 | if args.gpu is not None: 97 | warnings.warn('You have chosen a specific GPU. This will completely ' 98 | 'disable data parallelism.') 99 | 100 | if args.dist_url == "env://" and args.world_size == -1: 101 | args.world_size = int(os.environ["WORLD_SIZE"]) 102 | 103 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 104 | 105 | ngpus_per_node = torch.cuda.device_count() 106 | if args.multiprocessing_distributed: 107 | # Since we have ngpus_per_node processes per node, the total world_size 108 | # needs to be adjusted accordingly 109 | args.world_size = ngpus_per_node * args.world_size 110 | # Use torch.multiprocessing.spawn to launch distributed processes: the 111 | # main_worker process function 112 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 113 | else: 114 | # Simply call main_worker function 115 | main_worker(args.gpu, ngpus_per_node, args) 116 | 117 | 118 | def main_worker(gpu, ngpus_per_node, args): 119 | global best_acc1 120 | args.gpu = gpu 121 | 122 | if args.gpu is not None: 123 | print("Use GPU: {} for training".format(args.gpu)) 124 | 125 | if args.distributed: 126 | if args.dist_url == "env://" and args.rank == -1: 127 | args.rank = int(os.environ["RANK"]) 128 | if args.multiprocessing_distributed: 129 | # For multiprocessing distributed training, rank needs to be the 130 | # global rank among all the processes 131 | args.rank = args.rank * ngpus_per_node + gpu 132 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 133 | world_size=args.world_size, rank=args.rank) 134 | # create model 135 | if 'efficientnet' in args.arch: # NEW 136 | if args.pretrained: 137 | model = EfficientNet.from_pretrained(args.arch) 138 | print("=> using pre-trained model '{}'".format(args.arch)) 139 | else: 140 | print("=> creating model '{}'".format(args.arch)) 141 | model = EfficientNet.from_name(args.arch) 142 | 143 | else: 144 | if args.pretrained: 145 | print("=> using pre-trained model '{}'".format(args.arch)) 146 | model = models.__dict__[args.arch](pretrained=True) 147 | else: 148 | print("=> creating model '{}'".format(args.arch)) 149 | model = models.__dict__[args.arch]() 150 | 151 | if args.distributed: 152 | # For multiprocessing distributed, DistributedDataParallel constructor 153 | # should always set the single device scope, otherwise, 154 | # DistributedDataParallel will use all available devices. 
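        # A worked example of the division a few lines below: with the default
        # --batch-size 256 and --workers 4 on an 8-GPU node, each DDP process
        # gets 256 // 8 = 32 images per step but int(4 / 8) = 0 loader workers,
        # so --workers usually needs raising for multi-GPU runs.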
155 | if args.gpu is not None: 156 | torch.cuda.set_device(args.gpu) 157 | model.cuda(args.gpu) 158 | # When using a single GPU per process and per 159 | # DistributedDataParallel, we need to divide the batch size 160 | # ourselves based on the total number of GPUs we have 161 | args.batch_size = int(args.batch_size / ngpus_per_node) 162 | args.workers = int(args.workers / ngpus_per_node) 163 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 164 | else: 165 | model.cuda() 166 | # DistributedDataParallel will divide and allocate batch_size to all 167 | # available GPUs if device_ids are not set 168 | model = torch.nn.parallel.DistributedDataParallel(model) 169 | elif args.gpu is not None: 170 | torch.cuda.set_device(args.gpu) 171 | model = model.cuda(args.gpu) 172 | else: 173 | # DataParallel will divide and allocate batch_size to all available GPUs 174 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 175 | model.features = torch.nn.DataParallel(model.features) 176 | model.cuda() 177 | else: 178 | model = torch.nn.DataParallel(model).cuda() 179 | 180 | # define loss function (criterion) and optimizer 181 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 182 | 183 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 184 | momentum=args.momentum, 185 | weight_decay=args.weight_decay) 186 | 187 | # optionally resume from a checkpoint 188 | if args.resume: 189 | if os.path.isfile(args.resume): 190 | print("=> loading checkpoint '{}'".format(args.resume)) 191 | checkpoint = torch.load(args.resume) 192 | args.start_epoch = checkpoint['epoch'] 193 | best_acc1 = checkpoint['best_acc1'] 194 | if args.gpu is not None: 195 | # best_acc1 may be from a checkpoint from a different GPU 196 | best_acc1 = best_acc1.to(args.gpu) 197 | model.load_state_dict(checkpoint['state_dict']) 198 | optimizer.load_state_dict(checkpoint['optimizer']) 199 | print("=> loaded checkpoint '{}' (epoch {})" 200 | .format(args.resume, checkpoint['epoch'])) 201 | else: 202 | print("=> no checkpoint found at '{}'".format(args.resume)) 203 | 204 | cudnn.benchmark = True 205 | 206 | # Data loading code 207 | traindir = os.path.join(args.data, 'train') 208 | valdir = os.path.join(args.data, 'val') 209 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 210 | std=[0.229, 0.224, 0.225]) 211 | 212 | train_dataset = datasets.ImageFolder( 213 | traindir, 214 | transforms.Compose([ 215 | transforms.RandomResizedCrop(224), 216 | transforms.RandomHorizontalFlip(), 217 | transforms.ToTensor(), 218 | normalize, 219 | ])) 220 | 221 | if args.distributed: 222 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 223 | else: 224 | train_sampler = None 225 | 226 | train_loader = torch.utils.data.DataLoader( 227 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 228 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 229 | 230 | if 'efficientnet' in args.arch: 231 | image_size = EfficientNet.get_image_size(args.arch) 232 | val_transforms = transforms.Compose([ 233 | transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC), 234 | transforms.CenterCrop(image_size), 235 | transforms.ToTensor(), 236 | normalize, 237 | ]) 238 | print('Using image size', image_size) 239 | else: 240 | val_transforms = transforms.Compose([ 241 | transforms.Resize(256), 242 | transforms.CenterCrop(224), 243 | transforms.ToTensor(), 244 | normalize, 245 | ]) 246 | print('Using image size', 224) 247 | 248 | val_loader = 
torch.utils.data.DataLoader( 249 | datasets.ImageFolder(valdir, val_transforms), 250 | batch_size=args.batch_size, shuffle=False, 251 | num_workers=args.workers, pin_memory=True) 252 | 253 | if args.evaluate: 254 | res = validate(val_loader, model, criterion, args) 255 | with open('res.txt', 'w') as f: 256 | print(res, file=f) 257 | return 258 | 259 | for epoch in range(args.start_epoch, args.epochs): 260 | if args.distributed: 261 | train_sampler.set_epoch(epoch) 262 | adjust_learning_rate(optimizer, epoch, args) 263 | 264 | # train for one epoch 265 | train(train_loader, model, criterion, optimizer, epoch, args) 266 | 267 | # evaluate on validation set 268 | acc1 = validate(val_loader, model, criterion, args) 269 | 270 | # remember best acc@1 and save checkpoint 271 | is_best = acc1 > best_acc1 272 | best_acc1 = max(acc1, best_acc1) 273 | 274 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 275 | and args.rank % ngpus_per_node == 0): 276 | save_checkpoint({ 277 | 'epoch': epoch + 1, 278 | 'arch': args.arch, 279 | 'state_dict': model.state_dict(), 280 | 'best_acc1': best_acc1, 281 | 'optimizer' : optimizer.state_dict(), 282 | }, is_best) 283 | 284 | 285 | def train(train_loader, model, criterion, optimizer, epoch, args): 286 | batch_time = AverageMeter('Time', ':6.3f') 287 | data_time = AverageMeter('Data', ':6.3f') 288 | losses = AverageMeter('Loss', ':.4e') 289 | top1 = AverageMeter('Acc@1', ':6.2f') 290 | top5 = AverageMeter('Acc@5', ':6.2f') 291 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1, 292 | top5, prefix="Epoch: [{}]".format(epoch)) 293 | 294 | # switch to train mode 295 | model.train() 296 | 297 | end = time.time() 298 | for i, (images, target) in enumerate(train_loader): 299 | # measure data loading time 300 | data_time.update(time.time() - end) 301 | 302 | if args.gpu is not None: 303 | images = images.cuda(args.gpu, non_blocking=True) 304 | target = target.cuda(args.gpu, non_blocking=True) 305 | 306 | # compute output 307 | output = model(images) 308 | loss = criterion(output, target) 309 | 310 | # measure accuracy and record loss 311 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 312 | losses.update(loss.item(), images.size(0)) 313 | top1.update(acc1[0], images.size(0)) 314 | top5.update(acc5[0], images.size(0)) 315 | 316 | # compute gradient and do SGD step 317 | optimizer.zero_grad() 318 | loss.backward() 319 | optimizer.step() 320 | 321 | # measure elapsed time 322 | batch_time.update(time.time() - end) 323 | end = time.time() 324 | 325 | if i % args.print_freq == 0: 326 | progress.print(i) 327 | 328 | 329 | def validate(val_loader, model, criterion, args): 330 | batch_time = AverageMeter('Time', ':6.3f') 331 | losses = AverageMeter('Loss', ':.4e') 332 | top1 = AverageMeter('Acc@1', ':6.2f') 333 | top5 = AverageMeter('Acc@5', ':6.2f') 334 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5, 335 | prefix='Test: ') 336 | 337 | # switch to evaluate mode 338 | model.eval() 339 | 340 | with torch.no_grad(): 341 | end = time.time() 342 | for i, (images, target) in enumerate(val_loader): 343 | if args.gpu is not None: 344 | images = images.cuda(args.gpu, non_blocking=True) 345 | target = target.cuda(args.gpu, non_blocking=True) 346 | 347 | # compute output 348 | output = model(images) 349 | loss = criterion(output, target) 350 | 351 | # measure accuracy and record loss 352 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 353 | losses.update(loss.item(), images.size(0)) 354 | 
top1.update(acc1[0], images.size(0)) 355 | top5.update(acc5[0], images.size(0)) 356 | 357 | # measure elapsed time 358 | batch_time.update(time.time() - end) 359 | end = time.time() 360 | 361 | if i % args.print_freq == 0: 362 | progress.print(i) 363 | 364 | # TODO: this should also be done with the ProgressMeter 365 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 366 | .format(top1=top1, top5=top5)) 367 | 368 | return top1.avg 369 | 370 | 371 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 372 | torch.save(state, filename) 373 | if is_best: 374 | shutil.copyfile(filename, 'model_best.pth.tar') 375 | 376 | 377 | class AverageMeter(object): 378 | """Computes and stores the average and current value""" 379 | def __init__(self, name, fmt=':f'): 380 | self.name = name 381 | self.fmt = fmt 382 | self.reset() 383 | 384 | def reset(self): 385 | self.val = 0 386 | self.avg = 0 387 | self.sum = 0 388 | self.count = 0 389 | 390 | def update(self, val, n=1): 391 | self.val = val 392 | self.sum += val * n 393 | self.count += n 394 | self.avg = self.sum / self.count 395 | 396 | def __str__(self): 397 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 398 | return fmtstr.format(**self.__dict__) 399 | 400 | 401 | class ProgressMeter(object): 402 | def __init__(self, num_batches, *meters, prefix=""): 403 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 404 | self.meters = meters 405 | self.prefix = prefix 406 | 407 | def print(self, batch): 408 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 409 | entries += [str(meter) for meter in self.meters] 410 | print('\t'.join(entries)) 411 | 412 | def _get_batch_fmtstr(self, num_batches): 413 | num_digits = len(str(num_batches)) 414 | fmt = '{:' + str(num_digits) + 'd}' 415 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 416 | 417 | 418 | def adjust_learning_rate(optimizer, epoch, args): 419 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 420 | lr = args.lr * (0.1 ** (epoch // 30)) 421 | for param_group in optimizer.param_groups: 422 | param_group['lr'] = lr 423 | 424 | 425 | def accuracy(output, target, topk=(1,)): 426 | """Computes the accuracy over the k top predictions for the specified values of k""" 427 | with torch.no_grad(): 428 | maxk = max(topk) 429 | batch_size = target.size(0) 430 | 431 | _, pred = output.topk(maxk, 1, True, True) 432 | pred = pred.t() 433 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 434 | 435 | res = [] 436 | for k in topk: 437 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) # reshape (not view): the slice may be non-contiguous 438 | res.append(correct_k.mul_(100.0 / batch_size)) 439 | return res 440 | 441 | 442 | if __name__ == '__main__': 443 | main() 444 | -------------------------------------------------------------------------------- /examples/simple/fire_smoke_map.txt: -------------------------------------------------------------------------------- 1 | {"0": "fire", "1": "negative", "2": "smoke"} 2 | -------------------------------------------------------------------------------- /examples/simple/img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/examples/simple/img.jpg -------------------------------------------------------------------------------- /examples/simple/img2.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/examples/simple/img2.jpg -------------------------------------------------------------------------------- /examples/simple/labels_map.txt: -------------------------------------------------------------------------------- 1 | {"0": "tench, Tinca tinca", "1": "goldfish, Carassius auratus", "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", "3": "tiger shark, Galeocerdo cuvieri", "4": "hammerhead, hammerhead shark", "5": "electric ray, crampfish, numbfish, torpedo", "6": "stingray", "7": "cock", "8": "hen", "9": "ostrich, Struthio camelus", "10": "brambling, Fringilla montifringilla", "11": "goldfinch, Carduelis carduelis", "12": "house finch, linnet, Carpodacus mexicanus", "13": "junco, snowbird", "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", "15": "robin, American robin, Turdus migratorius", "16": "bulbul", "17": "jay", "18": "magpie", "19": "chickadee", "20": "water ouzel, dipper", "21": "kite", "22": "bald eagle, American eagle, Haliaeetus leucocephalus", "23": "vulture", "24": "great grey owl, great gray owl, Strix nebulosa", "25": "European fire salamander, Salamandra salamandra", "26": "common newt, Triturus vulgaris", "27": "eft", "28": "spotted salamander, Ambystoma maculatum", "29": "axolotl, mud puppy, Ambystoma mexicanum", "30": "bullfrog, Rana catesbeiana", "31": "tree frog, tree-frog", "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", "33": "loggerhead, loggerhead turtle, Caretta caretta", "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", "35": "mud turtle", "36": "terrapin", "37": "box turtle, box tortoise", "38": "banded gecko", "39": "common iguana, iguana, Iguana iguana", "40": "American chameleon, anole, Anolis carolinensis", "41": "whiptail, whiptail lizard", "42": "agama", "43": "frilled lizard, Chlamydosaurus kingi", "44": "alligator lizard", "45": "Gila monster, Heloderma suspectum", "46": "green lizard, Lacerta viridis", "47": "African chameleon, Chamaeleo chamaeleon", "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", "49": "African crocodile, Nile crocodile, Crocodylus niloticus", "50": "American alligator, Alligator mississipiensis", "51": "triceratops", "52": "thunder snake, worm snake, Carphophis amoenus", "53": "ringneck snake, ring-necked snake, ring snake", "54": "hognose snake, puff adder, sand viper", "55": "green snake, grass snake", "56": "king snake, kingsnake", "57": "garter snake, grass snake", "58": "water snake", "59": "vine snake", "60": "night snake, Hypsiglena torquata", "61": "boa constrictor, Constrictor constrictor", "62": "rock python, rock snake, Python sebae", "63": "Indian cobra, Naja naja", "64": "green mamba", "65": "sea snake", "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", "68": "sidewinder, horned rattlesnake, Crotalus cerastes", "69": "trilobite", "70": "harvestman, daddy longlegs, Phalangium opilio", "71": "scorpion", "72": "black and gold garden spider, Argiope aurantia", "73": "barn spider, Araneus cavaticus", "74": "garden spider, Aranea diademata", "75": "black widow, Latrodectus mactans", "76": "tarantula", "77": "wolf spider, hunting spider", "78": "tick", "79": "centipede", "80": "black grouse", "81": "ptarmigan", "82": "ruffed grouse, partridge, Bonasa umbellus", "83": 
"prairie chicken, prairie grouse, prairie fowl", "84": "peacock", "85": "quail", "86": "partridge", "87": "African grey, African gray, Psittacus erithacus", "88": "macaw", "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", "90": "lorikeet", "91": "coucal", "92": "bee eater", "93": "hornbill", "94": "hummingbird", "95": "jacamar", "96": "toucan", "97": "drake", "98": "red-breasted merganser, Mergus serrator", "99": "goose", "100": "black swan, Cygnus atratus", "101": "tusker", "102": "echidna, spiny anteater, anteater", "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", "104": "wallaby, brush kangaroo", "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", "106": "wombat", "107": "jellyfish", "108": "sea anemone, anemone", "109": "brain coral", "110": "flatworm, platyhelminth", "111": "nematode, nematode worm, roundworm", "112": "conch", "113": "snail", "114": "slug", "115": "sea slug, nudibranch", "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", "117": "chambered nautilus, pearly nautilus, nautilus", "118": "Dungeness crab, Cancer magister", "119": "rock crab, Cancer irroratus", "120": "fiddler crab", "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", "124": "crayfish, crawfish, crawdad, crawdaddy", "125": "hermit crab", "126": "isopod", "127": "white stork, Ciconia ciconia", "128": "black stork, Ciconia nigra", "129": "spoonbill", "130": "flamingo", "131": "little blue heron, Egretta caerulea", "132": "American egret, great white heron, Egretta albus", "133": "bittern", "134": "crane", "135": "limpkin, Aramus pictus", "136": "European gallinule, Porphyrio porphyrio", "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", "138": "bustard", "139": "ruddy turnstone, Arenaria interpres", "140": "red-backed sandpiper, dunlin, Erolia alpina", "141": "redshank, Tringa totanus", "142": "dowitcher", "143": "oystercatcher, oyster catcher", "144": "pelican", "145": "king penguin, Aptenodytes patagonica", "146": "albatross, mollymawk", "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", "149": "dugong, Dugong dugon", "150": "sea lion", "151": "Chihuahua", "152": "Japanese spaniel", "153": "Maltese dog, Maltese terrier, Maltese", "154": "Pekinese, Pekingese, Peke", "155": "Shih-Tzu", "156": "Blenheim spaniel", "157": "papillon", "158": "toy terrier", "159": "Rhodesian ridgeback", "160": "Afghan hound, Afghan", "161": "basset, basset hound", "162": "beagle", "163": "bloodhound, sleuthhound", "164": "bluetick", "165": "black-and-tan coonhound", "166": "Walker hound, Walker foxhound", "167": "English foxhound", "168": "redbone", "169": "borzoi, Russian wolfhound", "170": "Irish wolfhound", "171": "Italian greyhound", "172": "whippet", "173": "Ibizan hound, Ibizan Podenco", "174": "Norwegian elkhound, elkhound", "175": "otterhound, otter hound", "176": "Saluki, gazelle hound", "177": "Scottish deerhound, deerhound", "178": "Weimaraner", "179": "Staffordshire bullterrier, Staffordshire bull terrier", "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", "181": "Bedlington terrier", "182": "Border terrier", "183": 
"Kerry blue terrier", "184": "Irish terrier", "185": "Norfolk terrier", "186": "Norwich terrier", "187": "Yorkshire terrier", "188": "wire-haired fox terrier", "189": "Lakeland terrier", "190": "Sealyham terrier, Sealyham", "191": "Airedale, Airedale terrier", "192": "cairn, cairn terrier", "193": "Australian terrier", "194": "Dandie Dinmont, Dandie Dinmont terrier", "195": "Boston bull, Boston terrier", "196": "miniature schnauzer", "197": "giant schnauzer", "198": "standard schnauzer", "199": "Scotch terrier, Scottish terrier, Scottie", "200": "Tibetan terrier, chrysanthemum dog", "201": "silky terrier, Sydney silky", "202": "soft-coated wheaten terrier", "203": "West Highland white terrier", "204": "Lhasa, Lhasa apso", "205": "flat-coated retriever", "206": "curly-coated retriever", "207": "golden retriever", "208": "Labrador retriever", "209": "Chesapeake Bay retriever", "210": "German short-haired pointer", "211": "vizsla, Hungarian pointer", "212": "English setter", "213": "Irish setter, red setter", "214": "Gordon setter", "215": "Brittany spaniel", "216": "clumber, clumber spaniel", "217": "English springer, English springer spaniel", "218": "Welsh springer spaniel", "219": "cocker spaniel, English cocker spaniel, cocker", "220": "Sussex spaniel", "221": "Irish water spaniel", "222": "kuvasz", "223": "schipperke", "224": "groenendael", "225": "malinois", "226": "briard", "227": "kelpie", "228": "komondor", "229": "Old English sheepdog, bobtail", "230": "Shetland sheepdog, Shetland sheep dog, Shetland", "231": "collie", "232": "Border collie", "233": "Bouvier des Flandres, Bouviers des Flandres", "234": "Rottweiler", "235": "German shepherd, German shepherd dog, German police dog, alsatian", "236": "Doberman, Doberman pinscher", "237": "miniature pinscher", "238": "Greater Swiss Mountain dog", "239": "Bernese mountain dog", "240": "Appenzeller", "241": "EntleBucher", "242": "boxer", "243": "bull mastiff", "244": "Tibetan mastiff", "245": "French bulldog", "246": "Great Dane", "247": "Saint Bernard, St Bernard", "248": "Eskimo dog, husky", "249": "malamute, malemute, Alaskan malamute", "250": "Siberian husky", "251": "dalmatian, coach dog, carriage dog", "252": "affenpinscher, monkey pinscher, monkey dog", "253": "basenji", "254": "pug, pug-dog", "255": "Leonberg", "256": "Newfoundland, Newfoundland dog", "257": "Great Pyrenees", "258": "Samoyed, Samoyede", "259": "Pomeranian", "260": "chow, chow chow", "261": "keeshond", "262": "Brabancon griffon", "263": "Pembroke, Pembroke Welsh corgi", "264": "Cardigan, Cardigan Welsh corgi", "265": "toy poodle", "266": "miniature poodle", "267": "standard poodle", "268": "Mexican hairless", "269": "timber wolf, grey wolf, gray wolf, Canis lupus", "270": "white wolf, Arctic wolf, Canis lupus tundrarum", "271": "red wolf, maned wolf, Canis rufus, Canis niger", "272": "coyote, prairie wolf, brush wolf, Canis latrans", "273": "dingo, warrigal, warragal, Canis dingo", "274": "dhole, Cuon alpinus", "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", "276": "hyena, hyaena", "277": "red fox, Vulpes vulpes", "278": "kit fox, Vulpes macrotis", "279": "Arctic fox, white fox, Alopex lagopus", "280": "grey fox, gray fox, Urocyon cinereoargenteus", "281": "tabby, tabby cat", "282": "tiger cat", "283": "Persian cat", "284": "Siamese cat, Siamese", "285": "Egyptian cat", "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", "287": "lynx, catamount", "288": "leopard, Panthera pardus", "289": "snow leopard, 
ounce, Panthera uncia", "290": "jaguar, panther, Panthera onca, Felis onca", "291": "lion, king of beasts, Panthera leo", "292": "tiger, Panthera tigris", "293": "cheetah, chetah, Acinonyx jubatus", "294": "brown bear, bruin, Ursus arctos", "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", "296": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus", "297": "sloth bear, Melursus ursinus, Ursus ursinus", "298": "mongoose", "299": "meerkat, mierkat", "300": "tiger beetle", "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", "302": "ground beetle, carabid beetle", "303": "long-horned beetle, longicorn, longicorn beetle", "304": "leaf beetle, chrysomelid", "305": "dung beetle", "306": "rhinoceros beetle", "307": "weevil", "308": "fly", "309": "bee", "310": "ant, emmet, pismire", "311": "grasshopper, hopper", "312": "cricket", "313": "walking stick, walkingstick, stick insect", "314": "cockroach, roach", "315": "mantis, mantid", "316": "cicada, cicala", "317": "leafhopper", "318": "lacewing, lacewing fly", "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", "320": "damselfly", "321": "admiral", "322": "ringlet, ringlet butterfly", "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", "324": "cabbage butterfly", "325": "sulphur butterfly, sulfur butterfly", "326": "lycaenid, lycaenid butterfly", "327": "starfish, sea star", "328": "sea urchin", "329": "sea cucumber, holothurian", "330": "wood rabbit, cottontail, cottontail rabbit", "331": "hare", "332": "Angora, Angora rabbit", "333": "hamster", "334": "porcupine, hedgehog", "335": "fox squirrel, eastern fox squirrel, Sciurus niger", "336": "marmot", "337": "beaver", "338": "guinea pig, Cavia cobaya", "339": "sorrel", "340": "zebra", "341": "hog, pig, grunter, squealer, Sus scrofa", "342": "wild boar, boar, Sus scrofa", "343": "warthog", "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", "345": "ox", "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", "347": "bison", "348": "ram, tup", "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", "350": "ibex, Capra ibex", "351": "hartebeest", "352": "impala, Aepyceros melampus", "353": "gazelle", "354": "Arabian camel, dromedary, Camelus dromedarius", "355": "llama", "356": "weasel", "357": "mink", "358": "polecat, fitch, foulmart, foumart, Mustela putorius", "359": "black-footed ferret, ferret, Mustela nigripes", "360": "otter", "361": "skunk, polecat, wood pussy", "362": "badger", "363": "armadillo", "364": "three-toed sloth, ai, Bradypus tridactylus", "365": "orangutan, orang, orangutang, Pongo pygmaeus", "366": "gorilla, Gorilla gorilla", "367": "chimpanzee, chimp, Pan troglodytes", "368": "gibbon, Hylobates lar", "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", "370": "guenon, guenon monkey", "371": "patas, hussar monkey, Erythrocebus patas", "372": "baboon", "373": "macaque", "374": "langur", "375": "colobus, colobus monkey", "376": "proboscis monkey, Nasalis larvatus", "377": "marmoset", "378": "capuchin, ringtail, Cebus capucinus", "379": "howler monkey, howler", "380": "titi, titi monkey", "381": "spider monkey, Ateles geoffroyi", "382": "squirrel monkey, Saimiri sciureus", "383": "Madagascar cat, ring-tailed lemur, Lemur catta", "384": "indri, indris, Indri indri, Indri brevicaudatus", "385": "Indian elephant, Elephas maximus", "386": 
"African elephant, Loxodonta africana", "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", "389": "barracouta, snoek", "390": "eel", "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "392": "rock beauty, Holocanthus tricolor", "393": "anemone fish", "394": "sturgeon", "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", "396": "lionfish", "397": "puffer, pufferfish, blowfish, globefish", "398": "abacus", "399": "abaya", "400": "academic gown, academic robe, judge's robe", "401": "accordion, piano accordion, squeeze box", "402": "acoustic guitar", "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", "404": "airliner", "405": "airship, dirigible", "406": "altar", "407": "ambulance", "408": "amphibian, amphibious vehicle", "409": "analog clock", "410": "apiary, bee house", "411": "apron", "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "413": "assault rifle, assault gun", "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", "415": "bakery, bakeshop, bakehouse", "416": "balance beam, beam", "417": "balloon", "418": "ballpoint, ballpoint pen, ballpen, Biro", "419": "Band Aid", "420": "banjo", "421": "bannister, banister, balustrade, balusters, handrail", "422": "barbell", "423": "barber chair", "424": "barbershop", "425": "barn", "426": "barometer", "427": "barrel, cask", "428": "barrow, garden cart, lawn cart, wheelbarrow", "429": "baseball", "430": "basketball", "431": "bassinet", "432": "bassoon", "433": "bathing cap, swimming cap", "434": "bath towel", "435": "bathtub, bathing tub, bath, tub", "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", "437": "beacon, lighthouse, beacon light, pharos", "438": "beaker", "439": "bearskin, busby, shako", "440": "beer bottle", "441": "beer glass", "442": "bell cote, bell cot", "443": "bib", "444": "bicycle-built-for-two, tandem bicycle, tandem", "445": "bikini, two-piece", "446": "binder, ring-binder", "447": "binoculars, field glasses, opera glasses", "448": "birdhouse", "449": "boathouse", "450": "bobsled, bobsleigh, bob", "451": "bolo tie, bolo, bola tie, bola", "452": "bonnet, poke bonnet", "453": "bookcase", "454": "bookshop, bookstore, bookstall", "455": "bottlecap", "456": "bow", "457": "bow tie, bow-tie, bowtie", "458": "brass, memorial tablet, plaque", "459": "brassiere, bra, bandeau", "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", "461": "breastplate, aegis, egis", "462": "broom", "463": "bucket, pail", "464": "buckle", "465": "bulletproof vest", "466": "bullet train, bullet", "467": "butcher shop, meat market", "468": "cab, hack, taxi, taxicab", "469": "caldron, cauldron", "470": "candle, taper, wax light", "471": "cannon", "472": "canoe", "473": "can opener, tin opener", "474": "cardigan", "475": "car mirror", "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", "477": "carpenter's kit, tool kit", "478": "carton", "479": "car wheel", "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", "481": "cassette", "482": "cassette player", "483": "castle", "484": "catamaran", "485": "CD player", "486": "cello, violoncello", "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", "488": "chain", "489": "chainlink fence", "490": "chain 
mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", "491": "chain saw, chainsaw", "492": "chest", "493": "chiffonier, commode", "494": "chime, bell, gong", "495": "china cabinet, china closet", "496": "Christmas stocking", "497": "church, church building", "498": "cinema, movie theater, movie theatre, movie house, picture palace", "499": "cleaver, meat cleaver, chopper", "500": "cliff dwelling", "501": "cloak", "502": "clog, geta, patten, sabot", "503": "cocktail shaker", "504": "coffee mug", "505": "coffeepot", "506": "coil, spiral, volute, whorl, helix", "507": "combination lock", "508": "computer keyboard, keypad", "509": "confectionery, confectionary, candy store", "510": "container ship, containership, container vessel", "511": "convertible", "512": "corkscrew, bottle screw", "513": "cornet, horn, trumpet, trump", "514": "cowboy boot", "515": "cowboy hat, ten-gallon hat", "516": "cradle", "517": "crane", "518": "crash helmet", "519": "crate", "520": "crib, cot", "521": "Crock Pot", "522": "croquet ball", "523": "crutch", "524": "cuirass", "525": "dam, dike, dyke", "526": "desk", "527": "desktop computer", "528": "dial telephone, dial phone", "529": "diaper, nappy, napkin", "530": "digital clock", "531": "digital watch", "532": "dining table, board", "533": "dishrag, dishcloth", "534": "dishwasher, dish washer, dishwashing machine", "535": "disk brake, disc brake", "536": "dock, dockage, docking facility", "537": "dogsled, dog sled, dog sleigh", "538": "dome", "539": "doormat, welcome mat", "540": "drilling platform, offshore rig", "541": "drum, membranophone, tympan", "542": "drumstick", "543": "dumbbell", "544": "Dutch oven", "545": "electric fan, blower", "546": "electric guitar", "547": "electric locomotive", "548": "entertainment center", "549": "envelope", "550": "espresso maker", "551": "face powder", "552": "feather boa, boa", "553": "file, file cabinet, filing cabinet", "554": "fireboat", "555": "fire engine, fire truck", "556": "fire screen, fireguard", "557": "flagpole, flagstaff", "558": "flute, transverse flute", "559": "folding chair", "560": "football helmet", "561": "forklift", "562": "fountain", "563": "fountain pen", "564": "four-poster", "565": "freight car", "566": "French horn, horn", "567": "frying pan, frypan, skillet", "568": "fur coat", "569": "garbage truck, dustcart", "570": "gasmask, respirator, gas helmet", "571": "gas pump, gasoline pump, petrol pump, island dispenser", "572": "goblet", "573": "go-kart", "574": "golf ball", "575": "golfcart, golf cart", "576": "gondola", "577": "gong, tam-tam", "578": "gown", "579": "grand piano, grand", "580": "greenhouse, nursery, glasshouse", "581": "grille, radiator grille", "582": "grocery store, grocery, food market, market", "583": "guillotine", "584": "hair slide", "585": "hair spray", "586": "half track", "587": "hammer", "588": "hamper", "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", "590": "hand-held computer, hand-held microcomputer", "591": "handkerchief, hankie, hanky, hankey", "592": "hard disc, hard disk, fixed disk", "593": "harmonica, mouth organ, harp, mouth harp", "594": "harp", "595": "harvester, reaper", "596": "hatchet", "597": "holster", "598": "home theater, home theatre", "599": "honeycomb", "600": "hook, claw", "601": "hoopskirt, crinoline", "602": "horizontal bar, high bar", "603": "horse cart, horse-cart", "604": "hourglass", "605": "iPod", "606": "iron, smoothing iron", "607": "jack-o'-lantern", "608": "jean, blue jean, denim", "609": "jeep, 
landrover", "610": "jersey, T-shirt, tee shirt", "611": "jigsaw puzzle", "612": "jinrikisha, ricksha, rickshaw", "613": "joystick", "614": "kimono", "615": "knee pad", "616": "knot", "617": "lab coat, laboratory coat", "618": "ladle", "619": "lampshade, lamp shade", "620": "laptop, laptop computer", "621": "lawn mower, mower", "622": "lens cap, lens cover", "623": "letter opener, paper knife, paperknife", "624": "library", "625": "lifeboat", "626": "lighter, light, igniter, ignitor", "627": "limousine, limo", "628": "liner, ocean liner", "629": "lipstick, lip rouge", "630": "Loafer", "631": "lotion", "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", "633": "loupe, jeweler's loupe", "634": "lumbermill, sawmill", "635": "magnetic compass", "636": "mailbag, postbag", "637": "mailbox, letter box", "638": "maillot", "639": "maillot, tank suit", "640": "manhole cover", "641": "maraca", "642": "marimba, xylophone", "643": "mask", "644": "matchstick", "645": "maypole", "646": "maze, labyrinth", "647": "measuring cup", "648": "medicine chest, medicine cabinet", "649": "megalith, megalithic structure", "650": "microphone, mike", "651": "microwave, microwave oven", "652": "military uniform", "653": "milk can", "654": "minibus", "655": "miniskirt, mini", "656": "minivan", "657": "missile", "658": "mitten", "659": "mixing bowl", "660": "mobile home, manufactured home", "661": "Model T", "662": "modem", "663": "monastery", "664": "monitor", "665": "moped", "666": "mortar", "667": "mortarboard", "668": "mosque", "669": "mosquito net", "670": "motor scooter, scooter", "671": "mountain bike, all-terrain bike, off-roader", "672": "mountain tent", "673": "mouse, computer mouse", "674": "mousetrap", "675": "moving van", "676": "muzzle", "677": "nail", "678": "neck brace", "679": "necklace", "680": "nipple", "681": "notebook, notebook computer", "682": "obelisk", "683": "oboe, hautboy, hautbois", "684": "ocarina, sweet potato", "685": "odometer, hodometer, mileometer, milometer", "686": "oil filter", "687": "organ, pipe organ", "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", "689": "overskirt", "690": "oxcart", "691": "oxygen mask", "692": "packet", "693": "paddle, boat paddle", "694": "paddlewheel, paddle wheel", "695": "padlock", "696": "paintbrush", "697": "pajama, pyjama, pj's, jammies", "698": "palace", "699": "panpipe, pandean pipe, syrinx", "700": "paper towel", "701": "parachute, chute", "702": "parallel bars, bars", "703": "park bench", "704": "parking meter", "705": "passenger car, coach, carriage", "706": "patio, terrace", "707": "pay-phone, pay-station", "708": "pedestal, plinth, footstall", "709": "pencil box, pencil case", "710": "pencil sharpener", "711": "perfume, essence", "712": "Petri dish", "713": "photocopier", "714": "pick, plectrum, plectron", "715": "pickelhaube", "716": "picket fence, paling", "717": "pickup, pickup truck", "718": "pier", "719": "piggy bank, penny bank", "720": "pill bottle", "721": "pillow", "722": "ping-pong ball", "723": "pinwheel", "724": "pirate, pirate ship", "725": "pitcher, ewer", "726": "plane, carpenter's plane, woodworking plane", "727": "planetarium", "728": "plastic bag", "729": "plate rack", "730": "plow, plough", "731": "plunger, plumber's helper", "732": "Polaroid camera, Polaroid Land camera", "733": "pole", "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", "735": "poncho", "736": "pool table, billiard table, snooker table", "737": "pop bottle, soda bottle", "738": "pot, 
flowerpot", "739": "potter's wheel", "740": "power drill", "741": "prayer rug, prayer mat", "742": "printer", "743": "prison, prison house", "744": "projectile, missile", "745": "projector", "746": "puck, hockey puck", "747": "punching bag, punch bag, punching ball, punchball", "748": "purse", "749": "quill, quill pen", "750": "quilt, comforter, comfort, puff", "751": "racer, race car, racing car", "752": "racket, racquet", "753": "radiator", "754": "radio, wireless", "755": "radio telescope, radio reflector", "756": "rain barrel", "757": "recreational vehicle, RV, R.V.", "758": "reel", "759": "reflex camera", "760": "refrigerator, icebox", "761": "remote control, remote", "762": "restaurant, eating house, eating place, eatery", "763": "revolver, six-gun, six-shooter", "764": "rifle", "765": "rocking chair, rocker", "766": "rotisserie", "767": "rubber eraser, rubber, pencil eraser", "768": "rugby ball", "769": "rule, ruler", "770": "running shoe", "771": "safe", "772": "safety pin", "773": "saltshaker, salt shaker", "774": "sandal", "775": "sarong", "776": "sax, saxophone", "777": "scabbard", "778": "scale, weighing machine", "779": "school bus", "780": "schooner", "781": "scoreboard", "782": "screen, CRT screen", "783": "screw", "784": "screwdriver", "785": "seat belt, seatbelt", "786": "sewing machine", "787": "shield, buckler", "788": "shoe shop, shoe-shop, shoe store", "789": "shoji", "790": "shopping basket", "791": "shopping cart", "792": "shovel", "793": "shower cap", "794": "shower curtain", "795": "ski", "796": "ski mask", "797": "sleeping bag", "798": "slide rule, slipstick", "799": "sliding door", "800": "slot, one-armed bandit", "801": "snorkel", "802": "snowmobile", "803": "snowplow, snowplough", "804": "soap dispenser", "805": "soccer ball", "806": "sock", "807": "solar dish, solar collector, solar furnace", "808": "sombrero", "809": "soup bowl", "810": "space bar", "811": "space heater", "812": "space shuttle", "813": "spatula", "814": "speedboat", "815": "spider web, spider's web", "816": "spindle", "817": "sports car, sport car", "818": "spotlight, spot", "819": "stage", "820": "steam locomotive", "821": "steel arch bridge", "822": "steel drum", "823": "stethoscope", "824": "stole", "825": "stone wall", "826": "stopwatch, stop watch", "827": "stove", "828": "strainer", "829": "streetcar, tram, tramcar, trolley, trolley car", "830": "stretcher", "831": "studio couch, day bed", "832": "stupa, tope", "833": "submarine, pigboat, sub, U-boat", "834": "suit, suit of clothes", "835": "sundial", "836": "sunglass", "837": "sunglasses, dark glasses, shades", "838": "sunscreen, sunblock, sun blocker", "839": "suspension bridge", "840": "swab, swob, mop", "841": "sweatshirt", "842": "swimming trunks, bathing trunks", "843": "swing", "844": "switch, electric switch, electrical switch", "845": "syringe", "846": "table lamp", "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", "848": "tape player", "849": "teapot", "850": "teddy, teddy bear", "851": "television, television system", "852": "tennis ball", "853": "thatch, thatched roof", "854": "theater curtain, theatre curtain", "855": "thimble", "856": "thresher, thrasher, threshing machine", "857": "throne", "858": "tile roof", "859": "toaster", "860": "tobacco shop, tobacconist shop, tobacconist", "861": "toilet seat", "862": "torch", "863": "totem pole", "864": "tow truck, tow car, wrecker", "865": "toyshop", "866": "tractor", "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", 
"868": "tray", "869": "trench coat", "870": "tricycle, trike, velocipede", "871": "trimaran", "872": "tripod", "873": "triumphal arch", "874": "trolleybus, trolley coach, trackless trolley", "875": "trombone", "876": "tub, vat", "877": "turnstile", "878": "typewriter keyboard", "879": "umbrella", "880": "unicycle, monocycle", "881": "upright, upright piano", "882": "vacuum, vacuum cleaner", "883": "vase", "884": "vault", "885": "velvet", "886": "vending machine", "887": "vestment", "888": "viaduct", "889": "violin, fiddle", "890": "volleyball", "891": "waffle iron", "892": "wall clock", "893": "wallet, billfold, notecase, pocketbook", "894": "wardrobe, closet, press", "895": "warplane, military plane", "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", "897": "washer, automatic washer, washing machine", "898": "water bottle", "899": "water jug", "900": "water tower", "901": "whiskey jug", "902": "whistle", "903": "wig", "904": "window screen", "905": "window shade", "906": "Windsor tie", "907": "wine bottle", "908": "wing", "909": "wok", "910": "wooden spoon", "911": "wool, woolen, woollen", "912": "worm fence, snake fence, snake-rail fence, Virginia fence", "913": "wreck", "914": "yawl", "915": "yurt", "916": "web site, website, internet site, site", "917": "comic book", "918": "crossword puzzle, crossword", "919": "street sign", "920": "traffic light, traffic signal, stoplight", "921": "book jacket, dust cover, dust jacket, dust wrapper", "922": "menu", "923": "plate", "924": "guacamole", "925": "consomme", "926": "hot pot, hotpot", "927": "trifle", "928": "ice cream, icecream", "929": "ice lolly, lolly, lollipop, popsicle", "930": "French loaf", "931": "bagel, beigel", "932": "pretzel", "933": "cheeseburger", "934": "hotdog, hot dog, red hot", "935": "mashed potato", "936": "head cabbage", "937": "broccoli", "938": "cauliflower", "939": "zucchini, courgette", "940": "spaghetti squash", "941": "acorn squash", "942": "butternut squash", "943": "cucumber, cuke", "944": "artichoke, globe artichoke", "945": "bell pepper", "946": "cardoon", "947": "mushroom", "948": "Granny Smith", "949": "strawberry", "950": "orange", "951": "lemon", "952": "fig", "953": "pineapple, ananas", "954": "banana", "955": "jackfruit, jak, jack", "956": "custard apple", "957": "pomegranate", "958": "hay", "959": "carbonara", "960": "chocolate sauce, chocolate syrup", "961": "dough", "962": "meat loaf, meatloaf", "963": "pizza, pizza pie", "964": "potpie", "965": "burrito", "966": "red wine", "967": "espresso", "968": "cup", "969": "eggnog", "970": "alp", "971": "bubble", "972": "cliff, drop, drop-off", "973": "coral reef", "974": "geyser", "975": "lakeside, lakeshore", "976": "promontory, headland, head, foreland", "977": "sandbar, sand bar", "978": "seashore, coast, seacoast, sea-coast", "979": "valley, vale", "980": "volcano", "981": "ballplayer, baseball player", "982": "groom, bridegroom", "983": "scuba diver", "984": "rapeseed", "985": "daisy", "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", "987": "corn", "988": "acorn", "989": "hip, rose hip, rosehip", "990": "buckeye, horse chestnut, conker", "991": "coral fungus", "992": "agaric", "993": "gyromitra", "994": "stinkhorn, carrion fungus", "995": "earthstar", "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", "997": "bolete", "998": "ear, spike, capitulum", "999": "toilet tissue, toilet paper, bathroom tissue"} 
-------------------------------------------------------------------------------- /featmap/5rd_depthwise_conv_featmap3_7e9ee24563cc31d34de2020e1acaecc5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/featmap/5rd_depthwise_conv_featmap3_7e9ee24563cc31d34de2020e1acaecc5.jpeg -------------------------------------------------------------------------------- /featmap/5rd_depthwise_conv_featmap4_7e9ee24563cc31d34de2020e1acaecc5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/featmap/5rd_depthwise_conv_featmap4_7e9ee24563cc31d34de2020e1acaecc5.jpeg -------------------------------------------------------------------------------- /featmap/5rd_depthwise_conv_featmap5_7e9ee24563cc31d34de2020e1acaecc5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/featmap/5rd_depthwise_conv_featmap5_7e9ee24563cc31d34de2020e1acaecc5.jpeg -------------------------------------------------------------------------------- /fire_smoke_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @File: fire_smoke_demo.py 5 | @Author:kong 6 | @Time: 2020-01-02 15:45 7 | @Description: single-image fire/smoke classification demo 8 | ''' 9 | import json 10 | from PIL import Image, ImageDraw, ImageFont 11 | import torch 12 | from torch import nn 13 | from torchvision import transforms 14 | from efficientnet_pytorch import FireSmokeEfficientNet 15 | import collections 16 | image_dir = './tests/000294.jpg' # sample image path (000294.jpg is not shipped under tests/; point this at your own image) 17 | model_para = collections.OrderedDict() 18 | model = FireSmokeEfficientNet.from_arch('efficientnet-b0') 19 | # out_channels = model._fc.in_features 20 | model._fc = nn.Linear(1280, 3) 21 | print(model) 22 | modelpara = torch.load('./checkpoint.pth.tar') 23 | # print(modelpara['state_dict'].keys()) 24 | for key in modelpara['state_dict'].keys(): 25 | # print(key[7:]) 26 | # newkey = model_para[key.split('.',2)[-1]] 27 | # print(newkey) 28 | model_para[key[7:]] = modelpara['state_dict'][key] # drop the 7-char 'module.' prefix added by nn.DataParallel 29 | 30 | # print(model_para.keys()) 31 | # the loop above converts a DataParallel-trained checkpoint for single-device use 32 | 33 | 34 | 35 | model.load_state_dict(model_para) 36 | 37 | # Preprocess image 38 | tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), 39 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),]) 40 | image = Image.open(image_dir) 41 | img = tfms(image).unsqueeze(0) 42 | print(img.shape) # torch.Size([1, 3, 224, 224]) 43 | 44 | # Load fire/smoke class names 45 | labels_map = json.load(open('examples/simple/fire_smoke_map.txt')) 46 | labels_map = [labels_map[str(i)] for i in range(3)] 47 | 48 | # Classify 49 | model.eval() 50 | with torch.no_grad(): 51 | outputs = model(img) 52 | 53 | draw = ImageDraw.Draw(image) 54 | font = ImageFont.truetype('simkai.ttf', 30) # requires the SimKai font file to be available 55 | # Print predictions 56 | print('-----') 57 | cout = 0 58 | for idx in torch.topk(outputs, k=2).indices.squeeze(0).tolist(): 59 | cout += 1 60 | prob = torch.softmax(outputs, dim=1)[0, idx].item() 61 | print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob*100)) 62 | position = (10, 30*cout - 20) 63 | text = '{label:<5} :{p:.2f}%'.format(label=labels_map[idx], p=prob*100) 64 | draw.text(position, text, font=font, fill="#ff0000", spacing=0, align='left') 65 | 66 | image.save('results/result_{}'.format(image_dir.split('/')[-1]))
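This demo and `fire_smoke_detection.py` below repeat the same checkpoint-loading steps. A minimal sketch of a shared helper (hypothetical, not a repo function; `key[7:]` works above because `nn.DataParallel` prefixes every key with the 7 characters of `module.`):

```python
import collections
import torch
from efficientnet_pytorch import FireSmokeEfficientNet

def load_fire_smoke_model(ckpt='./checkpoint.pth.tar', num_classes=3):
    """Hypothetical helper: rebuild the classifier head and load a DataParallel checkpoint."""
    model = FireSmokeEfficientNet.from_arch('efficientnet-b0')
    model._fc = torch.nn.Linear(1280, num_classes)
    state = torch.load(ckpt, map_location='cpu')['state_dict']
    # strip the 'module.' prefix without assuming its length
    state = collections.OrderedDict(
        (k.replace('module.', '', 1), v) for k, v in state.items())
    model.load_state_dict(state)
    return model.eval()
```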
-------------------------------------------------------------------------------- /fire_smoke_detection.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | @File: fire_smoke_detection.py 5 | @Author:kong 6 | @Time: 2020-01-03 10:50 7 | @Description: grid-based fire & smoke detection with EfficientNet 8 | ''' 9 | 10 | import json 11 | from PIL import Image, ImageDraw, ImageFont 12 | import torch 13 | from torchvision import transforms 14 | from efficientnet_pytorch import FireSmokeEfficientNet 15 | import collections 16 | 17 | # from PIL import Image, ImageDraw, ImageFont 18 | image_path = './tests/000127.jpg' 19 | col = 5 20 | row = 4 21 | interCLS = ["smoke","fire"] 22 | model_para = collections.OrderedDict() 23 | model = FireSmokeEfficientNet.from_arch('efficientnet-b0') 24 | # out_channels = model._fc.in_features 25 | model._fc = torch.nn.Linear(1280, 3) 26 | modelpara = torch.load('./checkpoint.pth.tar') 27 | # print(modelpara['state_dict'].keys()) 28 | for key in modelpara['state_dict'].keys(): 29 | # print(key[7:]) 30 | # newkey = model_para[key.split('.',2)[-1]] 31 | # print(newkey) 32 | model_para[key[7:]] = modelpara['state_dict'][key] # drop the 'module.' prefix added by nn.DataParallel 33 | 34 | # print(model_para.keys()) 35 | # the loop above converts a DataParallel-trained checkpoint for single-device use 36 | model.load_state_dict(model_para) 37 | 38 | # Preprocess image 39 | tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(), 40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),]) 41 | 42 | # Load fire/smoke class names 43 | labels_map = json.load(open('examples/simple/fire_smoke_map.txt')) 44 | labels_map = [labels_map[str(i)] for i in range(3)] 45 | 46 | image = Image.open(image_path) 47 | width = image.width 48 | height = image.height 49 | w_len = int(width / col) # width of each grid cell 50 | h_len = int(height / row) # height of each grid cell 51 | 52 | draw = ImageDraw.Draw(image) 53 | font = ImageFont.truetype("simkai.ttf", 40, encoding="utf-8") # font file, text size, encoding 54 | 55 | for r in range(row): 56 | for c in range(col): 57 | image_tmp = image.crop((c*w_len,r*h_len,(c+1)*w_len,(r+1)*h_len)) 58 | img_tmp = tfms(image_tmp).unsqueeze(0) 59 | model.eval() 60 | with torch.no_grad(): 61 | outputs = model(img_tmp) 62 | print('-----') 63 | for idx in torch.topk(outputs, k=1).indices.squeeze(0).tolist(): 64 | prob = torch.softmax(outputs, dim=1)[0, idx].item() 65 | print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob * 100)) 66 | # image_tmp.save('{}_{}_{}.jpg'.format(r, c, labels_map[idx])) 67 | if prob > 0.99 and labels_map[idx] in interCLS: 68 | draw.line([(c*w_len,r*h_len),((c+1)*w_len, r*h_len),((c+1)*w_len, (r+1)*h_len),(c*w_len,(r+1)*h_len),(c*w_len,r*h_len)],fill = (255,0,0), width = 2) 69 | draw.text(((c+1)*w_len, r*h_len), labels_map[idx], (255, 0, 0), font=font) # draw the class label at the cell's top-right corner 70 | 71 | image.save("results/det_results{}".format(image_path.split('/')[-1])) 72 | 73 |
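The loop above runs `row * col` separate forward passes. A sketch of batching all cells into a single pass instead (assumes `model` and `tfms` as defined above; every cell shares one aspect ratio, so the resized tensors stack cleanly):

```python
import torch

def classify_grid(image, model, tfms, row=4, col=5):
    """Return one row of class probabilities per grid cell, in row-major order."""
    w, h = image.width // col, image.height // row
    crops = [tfms(image.crop((c * w, r * h, (c + 1) * w, (r + 1) * h)))
             for r in range(row) for c in range(col)]
    batch = torch.stack(crops)                 # (row*col, 3, H, W)
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)
    return probs
```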
51 | 
52 | draw = ImageDraw.Draw(image)
53 | font = ImageFont.truetype("simkai.ttf", 40, encoding="utf-8")  # font file, point size, encoding
54 | 
55 | for r in range(row):
56 |     for c in range(col):
57 |         image_tmp = image.crop((c*w_len, r*h_len, (c+1)*w_len, (r+1)*h_len))
58 |         img_tmp = tfms(image_tmp).unsqueeze(0)
59 |         model.eval()
60 |         with torch.no_grad():
61 |             outputs = model(img_tmp)
62 |         print('-----')
63 |         for idx in torch.topk(outputs, k=1).indices.squeeze(0).tolist():
64 |             prob = torch.softmax(outputs, dim=1)[0, idx].item()
65 |             print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob * 100))
66 |             # image_tmp.save('{}_{}_{}.jpg'.format(r, c, labels_map[idx]))
67 |             if prob > 0.99 and labels_map[idx] in interCLS:
68 |                 draw.line([(c*w_len, r*h_len), ((c+1)*w_len, r*h_len), ((c+1)*w_len, (r+1)*h_len), (c*w_len, (r+1)*h_len), (c*w_len, r*h_len)], fill=(255, 0, 0), width=2)
69 |                 draw.text(((c+1)*w_len, r*h_len), labels_map[idx], (255, 0, 0), font=font)  # label the block: position, text, color, font
70 | 
71 | image.save("results/det_results{}".format(image_path.split('/')[-1]))
72 | 
73 | 
-------------------------------------------------------------------------------- /model_vusal.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: model_vusal.py
5 | @Author: kong
6 | @Time: 2020-01-07 19:17
7 | @Description: visualize intermediate feature maps of a CNN via forward hooks
8 | '''
9 | 
10 | import matplotlib.pyplot as plt
11 | import torch
12 | from torchvision import models, transforms
13 | from torch.utils.data import DataLoader, Dataset
14 | import torch.nn as nn
15 | from torch.autograd import Variable
16 | import torch.nn.functional as F
17 | import os
18 | import numpy as np
19 | from torchvision.datasets import ImageFolder
20 | 
21 | torch.cuda.set_device(0)  # select GPU 0
22 | is_cuda = True
23 | simple_transform = transforms.Compose([transforms.Resize((224, 224)),
24 |                                        transforms.ToTensor(),  # H, W, C -> C, H, W; scales to (0, 1) by dividing by 255
25 |                                        transforms.Normalize([0.485, 0.456, 0.406],   # mean
26 |                                                             [0.229, 0.224, 0.225])])  # std
27 | 
28 | # Normalize applies (x - mean) / std per channel after ToTensor has mapped the input to (0, 1)
29 | # ImageFolder requires one subdirectory per class
30 | train = ImageFolder("/home/kong/Documents/EfficientNet-PyTorch/cropdata/train", simple_transform)
31 | valid = ImageFolder("/home/kong/Documents/EfficientNet-PyTorch/cropdata/val", simple_transform)
32 | train_loader = DataLoader(train, batch_size=1, shuffle=False, num_workers=5)
33 | val_loader = DataLoader(valid, batch_size=1, shuffle=False, num_workers=5)
34 | 
35 | net = models.mobilenet_v2(pretrained=True).cuda()  # MobileNetV2 backbone for this demo
36 | 
37 | 
38 | # Core helper: capture the output of an intermediate layer with a forward hook
39 | class LayerActivations:
40 |     features = None
41 | 
42 |     def __init__(self, model, layer_num):
43 |         self.hook = model[layer_num].register_forward_hook(self.hook_fn)
44 | 
45 |     def hook_fn(self, module, input, output):
46 |         self.features = output.cpu()
47 | 
48 |     def remove(self):
49 |         self.hook.remove()
50 | 
51 | 
52 | conv_out = LayerActivations(net.features, 18)  # hook the output of features[18], the final 1x1 conv stage
53 | img = next(iter(train_loader))[0]
54 | o = net(Variable(img.cuda()))
55 | conv_out.remove()  # detach the hook once the forward pass is done
56 | act = conv_out.features  # act holds the hooked layer's feature maps
57 | 
58 | # Visualize the captured channels
59 | fig = plt.figure(figsize=(20, 50))
60 | fig.subplots_adjust(left=0, right=1, bottom=0, top=0.8, hspace=0, wspace=0.2)
61 | for i in range(30):
62 |     ax = fig.add_subplot(12, 5, i + 1, xticks=[], yticks=[])
63 |     ax.imshow(act[0][i].detach().numpy(), cmap="gray")
64 | plt.show()
-------------------------------------------------------------------------------- /results/acc_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/results/acc_loss.png -------------------------------------------------------------------------------- /results/det_results000127.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/results/det_results000127.jpg -------------------------------------------------------------------------------- /results/result_000127.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/results/result_000127.jpg -------------------------------------------------------------------------------- /results/result_7e9ee24563cc31d34de2020e1acaecc5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/results/result_7e9ee24563cc31d34de2020e1acaecc5.jpeg -------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | # Note: To use the 'upload'
functionality of this file, you must: 5 | # $ pipenv install twine --dev 6 | 7 | import io 8 | import os 9 | import sys 10 | from shutil import rmtree 11 | 12 | from setuptools import find_packages, setup, Command 13 | 14 | # Package meta-data. 15 | NAME = 'efficientnet_pytorch' 16 | DESCRIPTION = 'EfficientNet implemented in PyTorch.' 17 | URL = 'https://github.com/lukemelas/efficientnet_pytorch' 18 | EMAIL = 'lmelaskyriazi@college.harvard.edu' 19 | AUTHOR = 'Luke' 20 | REQUIRES_PYTHON = '>=3.5.0' 21 | VERSION = '0.5.1' 22 | 23 | # What packages are required for this module to be executed? 24 | REQUIRED = [ 25 | 'torch' 26 | ] 27 | 28 | # What packages are optional? 29 | EXTRAS = { 30 | # 'fancy feature': ['django'], 31 | } 32 | 33 | # The rest you shouldn't have to touch too much :) 34 | # ------------------------------------------------ 35 | # Except, perhaps the License and Trove Classifiers! 36 | # If you do change the License, remember to change the Trove Classifier for that! 37 | 38 | here = os.path.abspath(os.path.dirname(__file__)) 39 | 40 | # Import the README and use it as the long-description. 41 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 42 | try: 43 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: 44 | long_description = '\n' + f.read() 45 | except FileNotFoundError: 46 | long_description = DESCRIPTION 47 | 48 | # Load the package's __version__.py module as a dictionary. 49 | about = {} 50 | if not VERSION: 51 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 52 | with open(os.path.join(here, project_slug, '__version__.py')) as f: 53 | exec(f.read(), about) 54 | else: 55 | about['__version__'] = VERSION 56 | 57 | 58 | class UploadCommand(Command): 59 | """Support setup.py upload.""" 60 | 61 | description = 'Build and publish the package.' 
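    # (note) setuptools custom commands follow a fixed protocol: declare
    # user_options and implement initialize_options / finalize_options / run,
    # as this class does below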
62 | user_options = [] 63 | 64 | @staticmethod 65 | def status(s): 66 | """Prints things in bold.""" 67 | print('\033[1m{0}\033[0m'.format(s)) 68 | 69 | def initialize_options(self): 70 | pass 71 | 72 | def finalize_options(self): 73 | pass 74 | 75 | def run(self): 76 | try: 77 | self.status('Removing previous builds…') 78 | rmtree(os.path.join(here, 'dist')) 79 | except OSError: 80 | pass 81 | 82 | self.status('Building Source and Wheel (universal) distribution…') 83 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) 84 | 85 | self.status('Uploading the package to PyPI via Twine…') 86 | os.system('twine upload dist/*') 87 | 88 | self.status('Pushing git tags…') 89 | os.system('git tag v{0}'.format(about['__version__'])) 90 | os.system('git push --tags') 91 | 92 | sys.exit() 93 | 94 | 95 | # Where the magic happens: 96 | setup( 97 | name=NAME, 98 | version=about['__version__'], 99 | description=DESCRIPTION, 100 | long_description=long_description, 101 | long_description_content_type='text/markdown', 102 | author=AUTHOR, 103 | author_email=EMAIL, 104 | python_requires=REQUIRES_PYTHON, 105 | url=URL, 106 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 107 | # py_modules=['model'], # If your package is a single module, use this instead of 'packages' 108 | install_requires=REQUIRED, 109 | extras_require=EXTRAS, 110 | include_package_data=True, 111 | license='Apache', 112 | classifiers=[ 113 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 114 | 'License :: OSI Approved :: Apache Software License', 115 | 'Programming Language :: Python', 116 | 'Programming Language :: Python :: 3', 117 | 'Programming Language :: Python :: 3.6', 118 | ], 119 | # $ setup.py publish support. 
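    # (illustrative usage) after bumping VERSION above, run: python setup.py upload
    # -> UploadCommand rebuilds dist/, uploads it via twine, then pushes a version git tag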
120 |     cmdclass={
121 |         'upload': UploadCommand,
122 |     },
123 | )
124 | 
-------------------------------------------------------------------------------- /tests/000127.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/tests/000127.jpg -------------------------------------------------------------------------------- /tests/7e9ee24563cc31d34de2020e1acaecc5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/tests/7e9ee24563cc31d34de2020e1acaecc5.jpeg -------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: train.py
5 | @Author: kong
6 | @Time: 2020-01-02 13:47
7 | @Description: training script for the fire/smoke classifier (based on the PyTorch ImageNet example)
8 | '''
9 | import argparse
10 | import os
11 | import random
12 | import shutil
13 | import time
14 | import warnings
15 | 
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.parallel
19 | import torch.backends.cudnn as cudnn
20 | import torch.distributed as dist
21 | import torch.optim
22 | import torch.multiprocessing as mp
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 | # import torchvision.models as models
28 | # from torchvision import transforms  (duplicate of the import above)
29 | import matplotlib.pyplot as plt
30 | from efficientnet_pytorch import FireSmokeEfficientNet
31 | 
32 | parser = argparse.ArgumentParser(description='PyTorch Fire and Smoke Training with EfficientNet')
33 | parser.add_argument('-da', '--data', metavar='DIR', default="./cropdata",
34 |                     help='path to dataset')
35 | parser.add_argument('-a', '--arch', metavar='ARCH', default='efficientnet-b0',
36 |                     help='model architecture (default: efficientnet-b0)')
37 | parser.add_argument('--num_cls', default=3, type=int, help='number of classes in your classification task')
38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
39 |                     help='number of data loading workers (default: 4)')
40 | parser.add_argument('--epochs', default=25, type=int, metavar='N',
41 |                     help='number of total epochs to run')
42 | parser.add_argument('--reduce', default=[15, 20], nargs='+', type=int, help='epochs at which the lr decays by 10x')
43 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
44 |                     help='manual epoch number (useful on restarts)')
45 | parser.add_argument('-b', '--batch-size', default=64, type=int,
46 |                     metavar='N',
47 |                     help='mini-batch size (default: 64), this is the total '
48 |                          'batch size of all GPUs on the current node when '
49 |                          'using Data Parallel or Distributed Data Parallel')
50 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
51 |                     metavar='LR', help='initial learning rate', dest='lr')
52 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
53 |                     help='momentum')
54 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
55 |                     metavar='W', help='weight decay (default: 1e-4)',
56 |                     dest='weight_decay')
57 | parser.add_argument('-p', '--print-freq', default=16, type=int,
58 |                     metavar='N', help='print frequency (default: 16)')
59 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
60 |
help='path to latest checkpoint (default: none)') 61 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 62 | help='evaluate model on validation set') 63 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 64 | help='use pre-trained model') 65 | parser.add_argument('--world-size', default=-1, type=int, 66 | help='number of nodes for distributed training') 67 | parser.add_argument('--rank', default=-1, type=int, 68 | help='node rank for distributed training') 69 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 70 | help='url used to set up distributed training') 71 | parser.add_argument('--dist-backend', default='nccl', type=str, 72 | help='distributed backend') 73 | parser.add_argument('--seed', default=None, type=int, 74 | help='seed for initializing training. ') 75 | parser.add_argument('--gpu', default=None, type=int, 76 | help='GPU id to use.') 77 | parser.add_argument('--multiprocessing-distributed', action='store_true', 78 | help='Use multi-processing distributed training to launch ' 79 | 'N processes per node, which has N GPUs. This is the ' 80 | 'fastest way to use PyTorch for either single node or ' 81 | 'multi node data parallel training') 82 | 83 | best_acc1 = 0 84 | 85 | def main(): 86 | args = parser.parse_args() 87 | 88 | if args.seed is not None: 89 | random.seed(args.seed) 90 | torch.manual_seed(args.seed) 91 | cudnn.deterministic = True 92 | warnings.warn('You have chosen to seed training. ' 93 | 'This will turn on the CUDNN deterministic setting, ' 94 | 'which can slow down your training considerably! ' 95 | 'You may see unexpected behavior when restarting ' 96 | 'from checkpoints.') 97 | 98 | if args.gpu is not None: 99 | warnings.warn('You have chosen a specific GPU. 
This will completely '
100 |                       'disable data parallelism.')
101 | 
102 |     if args.dist_url == "env://" and args.world_size == -1:
103 |         args.world_size = int(os.environ["WORLD_SIZE"])
104 | 
105 |     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
106 | 
107 |     ngpus_per_node = torch.cuda.device_count()
108 |     if args.multiprocessing_distributed:
109 |         # Since we have ngpus_per_node processes per node, the total world_size
110 |         # needs to be adjusted accordingly
111 |         args.world_size = ngpus_per_node * args.world_size
112 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
113 |         # main_worker process function
114 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
115 |     else:
116 |         # Simply call main_worker function
117 |         main_worker(args.gpu, ngpus_per_node, args)
118 | 
119 | def get_current_lr(epoch, args):
120 |     lr = args.lr  # decay the initial lr by 10x for each milestone in args.reduce already passed
121 |     for i, lr_decay_epoch in enumerate(args.reduce):
122 |         if epoch >= lr_decay_epoch:
123 |             lr *= 0.1
124 |     return lr
125 | 
126 | def adjust_learning_rate(optimizer, epoch, args):
127 |     lr = get_current_lr(epoch, args)
128 |     print("current lr is:", lr)
129 |     for param_group in optimizer.param_groups:
130 |         param_group['lr'] = lr
131 | 
132 | 
133 | def main_worker(gpu, ngpus_per_node, args):
134 |     global best_acc1
135 |     args.gpu = gpu
136 | 
137 |     if args.gpu is not None:
138 |         print("Use GPU: {} for training".format(args.gpu))
139 | 
140 |     if args.distributed:
141 |         if args.dist_url == "env://" and args.rank == -1:
142 |             args.rank = int(os.environ["RANK"])
143 |         if args.multiprocessing_distributed:
144 |             # For multiprocessing distributed training, rank needs to be the
145 |             # global rank among all the processes
146 |             args.rank = args.rank * ngpus_per_node + gpu
147 |         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
148 |                                 world_size=args.world_size, rank=args.rank)
149 |     # create model
150 |     if args.pretrained:
151 |         print("=> using pre-trained model '{}'".format(args.arch))
152 |         model = FireSmokeEfficientNet.from_pretrained(args)
153 |         print(model)
154 |     else:
155 |         print("=> creating model '{}'".format(args.arch))
156 |         model = FireSmokeEfficientNet.from_arch(args.arch)  # build the architecture without pretrained weights (cf. fire_smoke_demo.py)
157 |         print(model)
158 | 
159 |     if args.distributed:
160 |         # For multiprocessing distributed, DistributedDataParallel constructor
161 |         # should always set the single device scope, otherwise,
162 |         # DistributedDataParallel will use all available devices.
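        # (illustrative) with the default --batch-size 64 on a 4-GPU node, each
        # DDP process below gets a per-GPU batch of 64 // 4 = 16 images, and the
        # data-loading workers are split across processes the same way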
163 | if args.gpu is not None: 164 | torch.cuda.set_device(args.gpu) 165 | model.cuda(args.gpu) 166 | # When using a single GPU per process and per 167 | # DistributedDataParallel, we need to divide the batch size 168 | # ourselves based on the total number of GPUs we have 169 | args.batch_size = int(args.batch_size / ngpus_per_node) 170 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 171 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 172 | else: 173 | model.cuda() 174 | # DistributedDataParallel will divide and allocate batch_size to all 175 | # available GPUs if device_ids are not set 176 | model = torch.nn.parallel.DistributedDataParallel(model) 177 | elif args.gpu is not None: 178 | torch.cuda.set_device(args.gpu) 179 | model = model.cuda(args.gpu) 180 | else: 181 | # DataParallel will divide and allocate batch_size to all available GPUs 182 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'): 183 | model.features = torch.nn.DataParallel(model.features) 184 | model.cuda() 185 | else: 186 | model = torch.nn.DataParallel(model).cuda() 187 | 188 | # define loss function (criterion) and optimizer 189 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 190 | # 191 | # optimizer = torch.optim.SGD(model.parameters(), args.lr, 192 | # momentum=args.momentum, 193 | # weight_decay=args.weight_decay) 194 | optimizer = torch.optim.Adam(model.parameters(), args.lr) 195 | # optionally resume from a checkpoint 196 | if args.resume: 197 | if os.path.isfile(args.resume): 198 | print("=> loading checkpoint '{}'".format(args.resume)) 199 | if args.gpu is None: 200 | checkpoint = torch.load(args.resume) 201 | else: 202 | # Map model to be loaded to specified single gpu. 203 | loc = 'cuda:{}'.format(args.gpu) 204 | checkpoint = torch.load(args.resume, map_location=loc) 205 | args.start_epoch = checkpoint['epoch'] 206 | best_acc1 = checkpoint['best_acc1'] 207 | if args.gpu is not None: 208 | # best_acc1 may be from a checkpoint from a different GPU 209 | best_acc1 = best_acc1.to(args.gpu) 210 | model.load_state_dict(checkpoint['state_dict']) 211 | optimizer.load_state_dict(checkpoint['optimizer']) 212 | print("=> loaded checkpoint '{}' (epoch {})" 213 | .format(args.resume, checkpoint['epoch'])) 214 | else: 215 | print("=> no checkpoint found at '{}'".format(args.resume)) 216 | 217 | cudnn.benchmark = True 218 | 219 | # Data loading code 220 | traindir = os.path.join(args.data, 'train') 221 | valdir = os.path.join(args.data, 'val') 222 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 223 | std=[0.229, 0.224, 0.225]) 224 | 225 | train_dataset = datasets.ImageFolder( 226 | traindir, 227 | transforms.Compose([ 228 | transforms.RandomResizedCrop(224), 229 | transforms.RandomHorizontalFlip(), 230 | transforms.ToTensor(), 231 | normalize, 232 | ])) 233 | 234 | if args.distributed: 235 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 236 | else: 237 | train_sampler = None 238 | 239 | train_loader = torch.utils.data.DataLoader( 240 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 241 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 242 | 243 | val_loader = torch.utils.data.DataLoader( 244 | datasets.ImageFolder(valdir, transforms.Compose([ 245 | transforms.Resize(256), 246 | transforms.CenterCrop(224), 247 | transforms.ToTensor(), 248 | normalize, 249 | ])), 250 | batch_size=args.batch_size, shuffle=False, 251 | num_workers=args.workers, 
pin_memory=True)
252 | 
253 |     if args.evaluate:
254 |         validate(val_loader, model, criterion, args)
255 |         return
256 |     loss_all = []
257 |     acc_all = []
258 |     for epoch in range(args.start_epoch, args.epochs):
259 |         if args.distributed:
260 |             train_sampler.set_epoch(epoch)
261 |         adjust_learning_rate(optimizer, epoch, args)
262 | 
263 |         # train for one epoch
264 |         loss_tmp, acc_tmp = train(train_loader, model, criterion, optimizer, epoch, args)
265 |         loss_all.extend(loss_tmp)
266 |         acc_all.extend(acc_tmp)
267 |         # evaluate on validation set
268 |         acc1 = validate(val_loader, model, criterion, args)
269 |         # remember best acc@1 and save checkpoint
270 |         is_best = acc1 > best_acc1
271 |         best_acc1 = max(acc1, best_acc1)
272 | 
273 |         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
274 |                                                     and args.rank % ngpus_per_node == 0):
275 |             save_checkpoint({
276 |                 'epoch': epoch + 1,
277 |                 'arch': args.arch,
278 |                 'state_dict': model.state_dict(),
279 |                 'best_acc1': best_acc1,
280 |                 'optimizer': optimizer.state_dict(),
281 |             }, is_best)
282 |     x1 = list(range(len(acc_all)))
283 |     x2 = list(range(len(loss_all)))
284 |     y1 = acc_all
285 |     y2 = loss_all
286 |     plt.subplot(2, 1, 1)
287 |     # plt.plot(x1, y1, 'o-', color='r')
288 |     plt.plot(x1, y1, 'o-', label="Train_Accuracy")
289 |     plt.title('train acc vs. iter')
290 |     plt.ylabel('train accuracy')
291 |     plt.legend(loc='best')
292 |     plt.subplot(2, 1, 2)
293 |     plt.plot(x2, y2, '.-', label="Train_Loss")
294 |     plt.title('train loss vs. iter')
295 |     plt.ylabel('train loss')
296 |     plt.legend(loc='best')
297 |     plt.savefig("acc_loss.png")
298 |     plt.show()
299 | 
300 | 
301 | def train(train_loader, model, criterion, optimizer, epoch, args):
302 |     batch_time = AverageMeter('Time', ':6.3f')
303 |     data_time = AverageMeter('Data', ':6.3f')
304 |     losses = AverageMeter('Loss', ':.4e')
305 |     top1 = AverageMeter('Acc@1', ':6.2f')
306 |     top3 = AverageMeter('Acc@3', ':6.2f')  # note: with only 3 classes, top-3 accuracy is trivially 100%
307 |     loss_all = []
308 |     acc_all = []
309 |     progress = ProgressMeter(
310 |         len(train_loader),
311 |         [batch_time, data_time, losses, top1, top3],
312 |         prefix="Epoch: [{}]".format(epoch))
313 | 
314 |     # switch to train mode
315 |     model.train()
316 |     end = time.time()
317 |     loss_tp = []
318 |     acc_tp = []
319 |     for i, (images, target) in enumerate(train_loader):
320 |         # measure data loading time
321 |         data_time.update(time.time() - end)
322 |         if args.gpu is not None:
323 |             images = images.cuda(args.gpu, non_blocking=True)
324 |             target = target.cuda(args.gpu, non_blocking=True)
325 | 
326 |         # compute output
327 |         output = model(images)
328 |         loss = criterion(output, target)
329 |         # measure accuracy and record loss
330 |         acc1, acc3 = accuracy(output, target, topk=(1, 3))
331 |         losses.update(loss.item(), images.size(0))
332 |         top1.update(acc1[0], images.size(0))
333 |         top3.update(acc3[0], images.size(0))
334 |         loss_tp.append(loss.item())
335 |         acc_tp.append(acc1[0].item())  # acc1[0] is a 0-dim tensor; store the float
336 |         if i % 500 == 0:  # log one smoothed point every 500 iterations
337 |             loss_all.append(sum(loss_tp) / (len(loss_tp)))
338 |             loss_tp = []
339 |             acc_all.append(sum(acc_tp) / (len(acc_tp)))
340 |             acc_tp = []
341 | 
342 |         # compute gradient and do an optimizer (Adam) step
343 |         optimizer.zero_grad()
344 |         loss.backward()
345 |         optimizer.step()
346 |         # (only the smoothed values above are kept; appending the raw per-iteration
347 |         # loss and accuracy here as well would double-count every plotted point)
348 | 
349 |         # measure elapsed time
350 |         batch_time.update(time.time() - end)
351 |         end = time.time()
352 | 
353 |         if i % args.print_freq == 0:
354 |             progress.display(i)
355 |     return loss_all, acc_all
356 | 
357 | 
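# (note) train() returns one smoothed (loss, acc) point per 500 iterations; with
# the default batch size of 64, each point plotted into acc_loss.png therefore
# summarizes roughly 500 * 64 = 32,000 training images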
358 | def validate(val_loader, model, criterion, args):
359 |     batch_time = AverageMeter('Time', ':6.3f')
360 |     losses = AverageMeter('Loss', ':.4e')
361 |     top1 = AverageMeter('Acc@1', ':6.2f')
362 |     top3 = AverageMeter('Acc@3', ':6.2f')
363 |     progress = ProgressMeter(
364 |         len(val_loader),
365 |         [batch_time, losses, top1, top3],
366 |         prefix='Test: ')
367 | 
368 |     # switch to evaluate mode
369 |     model.eval()
370 | 
371 |     with torch.no_grad():
372 |         end = time.time()
373 |         for i, (images, target) in enumerate(val_loader):
374 |             if args.gpu is not None:
375 |                 images = images.cuda(args.gpu, non_blocking=True)
376 |                 target = target.cuda(args.gpu, non_blocking=True)
377 | 
378 |             # compute output
379 |             output = model(images)
380 |             loss = criterion(output, target)
381 | 
382 |             # measure accuracy and record loss
383 |             acc1, acc3 = accuracy(output, target, topk=(1, 3))
384 |             losses.update(loss.item(), images.size(0))
385 |             top1.update(acc1[0], images.size(0))
386 |             top3.update(acc3[0], images.size(0))
387 | 
388 |             # measure elapsed time
389 |             batch_time.update(time.time() - end)
390 |             end = time.time()
391 | 
392 |             if i % args.print_freq == 0:
393 |                 progress.display(i)
394 | 
395 |         # TODO: this should also be done with the ProgressMeter
396 |         print(' * Acc@1 {top1.avg:.3f} Acc@3 {top3.avg:.3f}'
397 |               .format(top1=top1, top3=top3))
398 | 
399 |     return top1.avg
400 | 
401 | 
402 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
403 |     torch.save(state, filename)
404 |     if is_best:
405 |         shutil.copyfile(filename, 'model_best.pth.tar')
406 | 
407 | 
408 | class AverageMeter(object):
409 |     """Computes and stores the average and current value"""
410 |     def __init__(self, name, fmt=':f'):
411 |         self.name = name
412 |         self.fmt = fmt
413 |         self.reset()
414 | 
415 |     def reset(self):
416 |         self.val = 0
417 |         self.avg = 0
418 |         self.sum = 0
419 |         self.count = 0
420 | 
421 |     def update(self, val, n=1):
422 |         self.val = val
423 |         self.sum += val * n
424 |         self.count += n
425 |         self.avg = self.sum / self.count
426 | 
427 |     def __str__(self):
428 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
429 |         return fmtstr.format(**self.__dict__)
430 | 
431 | 
432 | class ProgressMeter(object):
433 |     def __init__(self, num_batches, meters, prefix=""):
434 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
435 |         self.meters = meters
436 |         self.prefix = prefix
437 | 
438 |     def display(self, batch):
439 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
440 |         entries += [str(meter) for meter in self.meters]
441 |         print('\t'.join(entries))
442 | 
443 |     def _get_batch_fmtstr(self, num_batches):
444 |         num_digits = len(str(num_batches))
445 |         fmt = '{:' + str(num_digits) + 'd}'
446 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
447 | 
448 | 
449 | # def adjust_learning_rate(optimizer, epoch, args):
450 | #     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
451 | #     lr = args.lr * (0.1 ** (epoch // 30))
452 | #     for param_group in optimizer.param_groups:
453 | #         param_group['lr'] = lr
454 | 
455 | 
456 | def accuracy(output, target, topk=(1,)):
457 |     """Computes the accuracy over the k top predictions for the specified values of k"""
458 |     with torch.no_grad():
459 |         maxk = max(topk)
460 |         batch_size = target.size(0)
461 | 
462 |         _, pred = output.topk(maxk, 1, True, True)
463 |         pred = pred.t()  # shape: (maxk, batch_size)
464 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
465 | 
466 |         res = []
467 |         for k in topk:
468 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: the slice may be non-contiguous
469 |             res.append(correct_k.mul_(100.0 / batch_size))
470 |         return res
471 | 
472
| 473 | if __name__ == '__main__': 474 | main() --------------------------------------------------------------------------------