├── .idea
│   ├── .gitignore
│   ├── FireSmokeDetectionByEfficientNet.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── Conv_visual.py
├── LICENSE
├── README.md
├── cropdata
│   └── train
│       ├── fire
│       │   └── 000016.jpg
│       └── smoke
│           └── 000002.png
├── efficientnet_pytorch
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── fire_smoke_model.cpython-37.pyc
│   │   ├── model.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── fire_smoke_model.py
│   ├── model.py
│   └── utils.py
├── examples
│   ├── imagenet
│   │   ├── README.md
│   │   ├── data
│   │   │   └── README.md
│   │   └── main.py
│   └── simple
│       ├── check.ipynb
│       ├── example.ipynb
│       ├── fire_smoke_map.txt
│       ├── img.jpg
│       ├── img2.jpg
│       └── labels_map.txt
├── featmap
│   ├── 5rd_depthwise_conv_featmap3_7e9ee24563cc31d34de2020e1acaecc5.jpeg
│   ├── 5rd_depthwise_conv_featmap4_7e9ee24563cc31d34de2020e1acaecc5.jpeg
│   └── 5rd_depthwise_conv_featmap5_7e9ee24563cc31d34de2020e1acaecc5.jpeg
├── fire_smoke_demo.py
├── fire_smoke_detection.py
├── model_vusal.py
├── results
│   ├── acc_loss.png
│   ├── det_results000127.jpg
│   ├── result_000127.jpg
│   └── result_7e9ee24563cc31d34de2020e1acaecc5.jpeg
├── setup.py
├── tests
│   ├── 000127.jpg
│   └── 7e9ee24563cc31d34de2020e1acaecc5.jpeg
└── train.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
3 |
--------------------------------------------------------------------------------
/.idea/FireSmokeDetectionByEfficientNet.iml:
--------------------------------------------------------------------------------
(PyCharm module file; XML content was not captured in this dump)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(PyCharm inspection profile settings; XML content was not captured in this dump)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(PyCharm project settings; XML content was not captured in this dump)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(PyCharm module registry; XML content was not captured in this dump)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(PyCharm VCS mapping; XML content was not captured in this dump)
--------------------------------------------------------------------------------
/Conv_visual.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: Conv_visual.py
5 | @Author: kong
6 | @Time: 2020-01-07 09:59
7 | @Description: Visualize the feature maps of the fire & smoke model
8 | '''
9 |
10 |
11 | import cv2
12 | import numpy as np
13 | import torch
14 | from torch.autograd import Variable
15 | import json
16 | from PIL import Image, ImageDraw, ImageFont
17 | from torchvision import transforms
18 | from torchvision import models
19 | from efficientnet_pytorch import FireSmokeEfficientNet
20 | import collections
21 |
22 | def preprocess_image(cv2im, resize_im=True):
23 | """
24 | Processes image for CNNs
25 |
26 | Args:
27 | PIL_img (PIL_img): Image to process
28 | resize_im (bool): Resize to 224 or not
29 | returns:
30 | im_as_var (Pytorch variable): Variable that contains processed float tensor
31 | """
32 |     # mean and std list for channels (ImageNet)
34 | mean = [0.485, 0.456, 0.406]
35 | std = [0.229, 0.224, 0.225]
36 | # Resize image
37 | if resize_im:
38 | cv2im = cv2.resize(cv2im, (224, 224))
39 | im_as_arr = np.float32(cv2im)
40 |     im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])  # BGR -> RGB
41 |     im_as_arr = im_as_arr.transpose(2, 0, 1)  # HWC -> CHW
42 | # Normalize the channels
43 | for channel, _ in enumerate(im_as_arr):
44 | im_as_arr[channel] /= 255
45 | im_as_arr[channel] -= mean[channel]
46 | im_as_arr[channel] /= std[channel]
47 | # Convert to float tensor
48 | im_as_ten = torch.from_numpy(im_as_arr).float()
49 |     # Add a batch dimension: tensor shape becomes 1,3,224,224
50 | im_as_ten.unsqueeze_(0)
51 | # Convert to Pytorch variable
52 | im_as_var = Variable(im_as_ten, requires_grad=True)
53 | return im_as_var
54 |
55 | def getPretrainedModel():
56 | model_para = collections.OrderedDict()
57 | model_tmp = FireSmokeEfficientNet.from_arch('efficientnet-b0')
58 | # out_channels = model._fc.in_features
59 | model_tmp._fc = torch.nn.Linear(1280, 3)
60 | modelpara = torch.load('./checkpoint.pth.tar')
61 | # print(modelpara['state_dict'].keys())
62 | for key in modelpara['state_dict'].keys():
63 | model_para[key[7:]] = modelpara['state_dict'][key]
64 | model_tmp.load_state_dict(model_para)
65 | return model_tmp
66 |
67 |
68 | class FeatureVisualization():
69 | def __init__(self,img_path,selected_layer):
70 | self.img_path=img_path
71 | self.image = cv2.imread(img_path)
72 | self.selected_layer=selected_layer
73 | self.pretrained_model = getPretrainedModel()
74 |
75 | def process_image(self):
76 | img=cv2.imread(self.img_path)
77 | img=preprocess_image(img)
78 | return img
79 |
80 | def get_feature(self):
81 | # input = Variable(torch.randn(1, 3, 224, 224))
82 | input=self.process_image()
83 | print(input.shape)
84 | # x=input
85 |
86 | x = self.pretrained_model._swish(self.pretrained_model._bn0(self.pretrained_model._conv_stem(input)))
87 | for index,layer in enumerate(self.pretrained_model._blocks):
88 | x=layer(x)
89 | if (index == self.selected_layer):
90 | return x
91 | # x = self.pretrained_model._conv_head(x)
92 | # return x
93 |
94 |
95 | def get_single_feature(self):
96 | features=self.get_feature()
98 |         # print('feature shape:', features.shape)
98 |
99 | feature=features[:,6,:,:]
100 | print(feature.shape)
101 |
102 | feature=feature.view(feature.shape[1],feature.shape[2])
103 | print(feature.shape)
104 |
105 | return feature
106 |
107 | def get_all_feature(self):
108 | features=self.get_feature()
109 | # print(':',features.shape)
110 |
111 | feature=features[:,:,:,:]
112 | print(feature.shape)
113 |
114 | feature=feature.view(feature.shape[1],feature.shape[2],feature.shape[3])
115 | print(feature.shape)
116 |
117 | return feature
118 |
119 | def save_feature_to_img(self):
120 | #to numpy
121 | feature=self.get_all_feature()
122 | feature=feature.data.numpy()
123 |
124 |         # squash to [0, 1] with a sigmoid
125 | feature= 1.0/(1+np.exp(-1*feature))
126 |
127 | # to [0,255]
128 | feature=np.round(feature*255)
129 | print(feature[0])
130 | print("image size:",self.image.shape)
131 | print("feature map size:",feature.shape)
132 | Nch = feature.shape[0] #channel num
133 | for i in range(Nch):
134 |             print('-------------- showing channel {} --------------'.format(i))
135 | feature_resize = cv2.resize(feature[i], (self.image.shape[1], self.image.shape[0]))
136 | feature_resize = cv2.cvtColor(feature_resize, cv2.COLOR_GRAY2BGR)
137 | feature_resize = cv2.putText(feature_resize,'first_depthwise_conv {}th feature map'.format(i),(10,30),cv2.FONT_HERSHEY_COMPLEX,1,(0,255,0),3)
138 | image_cont = np.concatenate([self.image,feature_resize],axis=0)
139 | cv2.imwrite('featmap/5rd_depthwise_conv_featmap{}_{}'.format(i,self.img_path.split('/')[-1]),image_cont)
140 |
141 | if __name__=='__main__':
142 | # get class
143 | myClass=FeatureVisualization('./tests/000294.jpg',1)
144 | # print (myClass.pretrained_model)
145 | print("-----------------------------------------------------")
146 | print(myClass.pretrained_model)
147 |
148 | myClass.save_feature_to_img()
149 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FireSmokeDetectionByEfficientNet
2 | [EfficientNet](https://arxiv.org/abs/1905.11946) is an excellent classification network. The EfficientNet implementation here follows [EfficientNet-PyTorch](https://github.com/lukemelas/EfficientNet-PyTorch), and much of the code is adapted from that repository. On top of its feature-extraction layers I add a fully connected layer with 3 output nodes for my fire & smoke detection task (see the construction sketch in section 7 below).
3 |
4 | ## 1. Introduction
5 |
6 | Fire and smoke classification and detection with a trained EfficientNet classifier. Detection is implemented by classifying cropped patches of the input image (a sliding-window sketch is given in section 7 below).
7 |
8 | ## 2. Requirements
9 |
10 | Python 3.7, PyTorch 1.3, ...
11 |
12 | ## 3. Running demos
13 |
14 | Download the pretrained model and the simkai.ttf font from [BaiduNetDisk](https://pan.baidu.com/s/14CM-U6bmVjXG6gNQ2CC8fw) (extraction code: awnf).
15 |
16 | Simply run:
17 |
18 | ```shell
19 | python fire_smoke_demo.py
20 | ```
21 |
22 | This produces classification results like the following:
23 |
24 | 
25 |
26 | Or try the detection demo:
27 |
28 | ```shell
29 | python fire_smoke_detection.py
30 | ```
31 |
32 | which produces:
33 |
34 | 
35 |
36 | ## 4. Visualizing the CNN
37 |
38 | Below are visualizations of the activations of some of the feature maps:
39 |
40 | 
41 |
42 | 
43 |
44 | As we can see, the CNN automatically learns the edges, shapes, and regions of interest of the target classes:
45 | 
46 | some filters respond to fire, while others pick up other kinds of features.
47 |
48 | ## 5. Train Custom Dataset
49 |
50 | Here is a script for training your own classification model with EfficientNet (the expected dataset layout is sketched in section 7 below):
51 |
52 | ```shell
53 | python train.py --data [your dataset path] --arch [an efficientnet model: efficientnet-b0 ... efficientnet-b7] --num_cls [number of classes in your task]
54 | ```
55 |
56 | Refer to the code for details on the arguments.
57 |
58 | For inference, you should change the label map txt file to match your task.
59 |
60 | In my task:
61 |
62 | ```shell
63 | python train.py --data ./cropdata --arch efficientnet-b0 --num_cls 3
64 | ```
65 |
66 | 
67 |
68 | ## 6. Dataset
69 | https://pan.baidu.com/s/1eRXtYVrn6baJ6PRMOTjyzQ (extraction code: srav)
70 |
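71 | ## 7. Code sketches
72 | 
73 | A minimal sketch of how the 3-class model is built and loaded, mirroring `getPretrainedModel()` in `Conv_visual.py`. The checkpoint path and the `map_location` choice are assumptions:
74 | 
75 | ```python
76 | import collections
77 | import torch
78 | from efficientnet_pytorch import FireSmokeEfficientNet
79 | 
80 | # EfficientNet-B0 backbone with the ImageNet head swapped for a
81 | # 3-way fire/smoke classifier.
82 | model = FireSmokeEfficientNet.from_arch('efficientnet-b0')
83 | model._fc = torch.nn.Linear(1280, 3)
84 | 
85 | # Checkpoints saved through nn.DataParallel prefix every key with
86 | # 'module.'; strip the first 7 characters before loading.
87 | ckpt = torch.load('./checkpoint.pth.tar', map_location='cpu')
88 | state = collections.OrderedDict((k[7:], v) for k, v in ckpt['state_dict'].items())
89 | model.load_state_dict(state)
90 | model.eval()
91 | ```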
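92 | 
93 | `fire_smoke_detection.py` is the actual detection implementation; as a rough illustration of the crop-and-classify idea, a hypothetical sliding-window detector could look like the sketch below. Here `preprocess` is assumed to map a BGR crop to a 1x3x224x224 tensor (as `preprocess_image()` in `Conv_visual.py` does), and the label order 0=fire, 1=smoke, 2=neither is an assumption:
94 | 
95 | ```python
96 | import torch
97 | 
98 | def detect_by_patches(model, preprocess, image, win=224, stride=112, thresh=0.9):
99 |     """Slide a window over a BGR image, classify every crop, and keep
100 |     windows confidently predicted as fire or smoke."""
101 |     h, w = image.shape[:2]
102 |     boxes = []
103 |     with torch.no_grad():
104 |         for y in range(0, max(h - win, 0) + 1, stride):
105 |             for x in range(0, max(w - win, 0) + 1, stride):
106 |                 patch = image[y:y + win, x:x + win]
107 |                 probs = model(preprocess(patch)).softmax(dim=1)[0]
108 |                 cls = int(probs.argmax())
109 |                 if cls != 2 and float(probs[cls]) > thresh:  # assumed: 0=fire, 1=smoke, 2=neither
110 |                     boxes.append((x, y, x + win, y + win, cls))
111 |     return boxes
112 | ```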
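113 | 
114 | `train.py` takes a `--data` path like the standard PyTorch ImageNet script, so it is assumed to expect the torchvision `ImageFolder` layout with one sub-folder per class. Judging from `cropdata/` in this repository (the `val` split is an assumption):
115 | 
116 | ```
117 | cropdata
118 | ├── train
119 | │   ├── fire
120 | │   ├── smoke
121 | │   └── ...   # one folder per remaining class
122 | └── val
123 |     ├── fire
124 |     ├── smoke
125 |     └── ...
126 | ```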
--------------------------------------------------------------------------------
/cropdata/train/fire/000016.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/cropdata/train/fire/000016.jpg
--------------------------------------------------------------------------------
/cropdata/train/smoke/000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/cropdata/train/smoke/000002.png
--------------------------------------------------------------------------------
/efficientnet_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.5.1"
2 | from .fire_smoke_model import FireSmokeEfficientNet
3 | from .utils import (
4 | GlobalParams,
5 | BlockArgs,
6 | BlockDecoder,
7 | efficientnet,
8 | get_model_params,
9 | )
10 |
11 |
--------------------------------------------------------------------------------
/efficientnet_pytorch/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/efficientnet_pytorch/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/efficientnet_pytorch/__pycache__/fire_smoke_model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/efficientnet_pytorch/__pycache__/fire_smoke_model.cpython-37.pyc
--------------------------------------------------------------------------------
/efficientnet_pytorch/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/efficientnet_pytorch/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/efficientnet_pytorch/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/efficientnet_pytorch/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/efficientnet_pytorch/fire_smoke_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: fire_smoke_model.py
5 | @Author: konglingran
6 | @Time: 2020-01-02 15:52
7 | @Description: EfficientNet with a custom classification head for the fire & smoke task
8 | '''
9 | import torch
10 | from torch import nn
11 | from torch.nn import functional as F
12 |
13 | from .utils import (
14 | round_filters,
15 | round_repeats,
16 | drop_connect,
17 | get_same_padding_conv2d,
18 | get_model_params,
19 | efficientnet_params,
20 | load_pretrained_weights,
21 | Swish,
22 | MemoryEfficientSwish,
23 | )
24 |
25 |
26 | class MBConvBlock(nn.Module):
27 | """
28 | Mobile Inverted Residual Bottleneck Block
29 |
30 | Args:
31 | block_args (namedtuple): BlockArgs, see above
32 | global_params (namedtuple): GlobalParam, see above
33 |
34 | Attributes:
35 | has_se (bool): Whether the block contains a Squeeze and Excitation layer.
36 | """
37 |
38 | def __init__(self, block_args, global_params):
39 | super().__init__()
40 | self._block_args = block_args
41 | self._bn_mom = 1 - global_params.batch_norm_momentum
42 | self._bn_eps = global_params.batch_norm_epsilon
43 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
44 | self.id_skip = block_args.id_skip # skip connection and drop connect
45 |
46 | # Get static or dynamic convolution depending on image size
47 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
48 |
49 | # Expansion phase
50 | inp = self._block_args.input_filters # number of input channels
51 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
52 | if self._block_args.expand_ratio != 1:
53 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
54 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
55 |
56 | # Depthwise convolution phase
57 | k = self._block_args.kernel_size
58 | s = self._block_args.stride
59 | self._depthwise_conv = Conv2d(
60 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
61 | kernel_size=k, stride=s, bias=False)
62 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
63 |
64 | # Squeeze and Excitation layer, if desired
65 | if self.has_se:
66 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
67 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
68 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
69 |
70 | # Output phase
71 | final_oup = self._block_args.output_filters
72 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
73 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
74 | self._swish = MemoryEfficientSwish()
75 |
76 | def forward(self, inputs, drop_connect_rate=None):
77 | """
78 | :param inputs: input tensor
79 | :param drop_connect_rate: drop connect rate (float, between 0 and 1)
80 | :return: output of block
81 | """
82 |
83 | # Expansion and Depthwise Convolution
84 | x = inputs
85 | if self._block_args.expand_ratio != 1:
86 | x = self._swish(self._bn0(self._expand_conv(inputs)))
87 | x = self._swish(self._bn1(self._depthwise_conv(x)))
88 |
89 | # Squeeze and Excitation
90 | if self.has_se:
91 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
92 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
93 | x = torch.sigmoid(x_squeezed) * x
94 |
95 | x = self._bn2(self._project_conv(x))
96 |
97 | # Skip connection and drop connect
98 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
99 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
100 | if drop_connect_rate:
101 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
102 | x = x + inputs # skip connection
103 | return x
104 |
105 | def set_swish(self, memory_efficient=True):
106 | """Sets swish function as memory efficient (for training) or standard (for export)"""
107 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
108 |
109 |
110 | class FireSmokeEfficientNet(nn.Module):
111 | """
112 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
113 |
114 | Args:
115 | blocks_args (list): A list of BlockArgs to construct blocks
116 | global_params (namedtuple): A set of GlobalParams shared between blocks
117 |
118 | Example:
119 |         model = FireSmokeEfficientNet.from_arch('efficientnet-b0')
120 |
121 | """
122 |
123 | def __init__(self, blocks_args=None, global_params=None):
124 | super().__init__()
125 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
126 | assert len(blocks_args) > 0, 'block args must be greater than 0'
127 | self._global_params = global_params
128 | self._blocks_args = blocks_args
129 |
130 | # Get static or dynamic convolution depending on image size
131 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
132 |
133 | # Batch norm parameters
134 | bn_mom = 1 - self._global_params.batch_norm_momentum
135 | bn_eps = self._global_params.batch_norm_epsilon
136 |
137 | # Stem
138 | in_channels = 3 # rgb
139 | out_channels = round_filters(32, self._global_params) # number of output channels
140 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
141 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
142 |
143 | # Build blocks
144 | self._blocks = nn.ModuleList([])
145 | for block_args in self._blocks_args:
146 |
147 | # Update block input and output filters based on depth multiplier.
148 | block_args = block_args._replace(
149 | input_filters=round_filters(block_args.input_filters, self._global_params),
150 | output_filters=round_filters(block_args.output_filters, self._global_params),
151 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
152 | )
153 |
154 | # The first block needs to take care of stride and filter size increase.
155 | self._blocks.append(MBConvBlock(block_args, self._global_params))
156 | if block_args.num_repeat > 1:
157 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
158 | for _ in range(block_args.num_repeat - 1):
159 | self._blocks.append(MBConvBlock(block_args, self._global_params))
160 |
161 | # Head
162 | in_channels = block_args.output_filters # output of final block
163 | out_channels = round_filters(1280, self._global_params)
164 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
165 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
166 |
167 | # Final linear layer
168 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
169 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
170 | self._fc = nn.Linear(out_channels, self._global_params.num_classes)
171 | self._swish = MemoryEfficientSwish()
172 |
173 | def set_swish(self, memory_efficient=True):
174 | """Sets swish function as memory efficient (for training) or standard (for export)"""
175 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
176 | for block in self._blocks:
177 | block.set_swish(memory_efficient)
178 |
179 | def extract_features(self, inputs):
180 | """ Returns output of the final convolution layer """
181 |
182 | # Stem
183 | x = self._swish(self._bn0(self._conv_stem(inputs)))
184 |
185 | # Blocks
186 | for idx, block in enumerate(self._blocks):
187 | drop_connect_rate = self._global_params.drop_connect_rate
188 | if drop_connect_rate:
189 | drop_connect_rate *= float(idx) / len(self._blocks)
190 | x = block(x, drop_connect_rate=drop_connect_rate)
191 |
192 | # Head
193 | x = self._swish(self._bn1(self._conv_head(x)))
194 |
195 | return x
196 |
197 | def forward(self, inputs):
198 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
199 | bs = inputs.size(0)
200 | # Convolution layers
201 | x = self.extract_features(inputs)
202 |
203 | # Pooling and final linear layer
204 | x = self._avg_pooling(x)
205 | x = x.view(bs, -1)
206 | x = self._dropout(x)
207 | x = self._fc(x)
208 | return x
209 |
210 | @classmethod
211 | def from_name(cls, model_name, override_params=None):
212 | cls._check_model_name_is_valid(model_name)
213 | blocks_args, global_params = get_model_params(model_name, override_params)
214 | return cls(blocks_args, global_params)
215 |
216 | @classmethod
217 | def from_arch(cls, model_name, override_params=None):
218 | cls._check_model_name_is_valid(model_name)
219 | blocks_args, global_params = get_model_params(model_name, override_params)
220 | return cls(blocks_args, global_params)
221 |
222 |     @classmethod
223 |     def from_pretrained(cls, args, num_classes=1000, in_channels=3):
224 |         model = cls.from_name(args.arch, override_params={'num_classes': num_classes})
225 |         load_pretrained_weights(model, args.arch, load_fc=(num_classes == 1000))
226 |         # Rebuild the stem if the input is not 3-channel RGB.
227 |         if in_channels != 3:
228 |             Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
229 |             out_channels = round_filters(32, model._global_params)
230 |             model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
231 |         # Replace the ImageNet head with a classifier sized for this task.
232 |         out_channels = round_filters(1280, model._global_params)
233 |         model._fc = nn.Linear(out_channels, args.num_cls)
234 |         return model
241 |
242 | @classmethod
243 | def get_image_size(cls, model_name):
244 | cls._check_model_name_is_valid(model_name)
245 | _, _, res, _ = efficientnet_params(model_name)
246 | return res
247 |
248 | @classmethod
249 | def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
250 | """ Validates model name. None that pretrained weights are only available for
251 | the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
252 | num_models = 4 if also_need_pretrained_weights else 8
253 | valid_models = ['efficientnet-b' + str(i) for i in range(num_models)]
254 | if model_name not in valid_models:
255 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
256 |
--------------------------------------------------------------------------------
/efficientnet_pytorch/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 | from .utils import (
6 | round_filters,
7 | round_repeats,
8 | drop_connect,
9 | get_same_padding_conv2d,
10 | get_model_params,
11 | efficientnet_params,
12 | load_pretrained_weights,
13 | Swish,
14 | MemoryEfficientSwish,
15 | )
16 |
17 | class MBConvBlock(nn.Module):
18 | """
19 | Mobile Inverted Residual Bottleneck Block
20 |
21 | Args:
22 | block_args (namedtuple): BlockArgs, see above
23 | global_params (namedtuple): GlobalParam, see above
24 |
25 | Attributes:
26 | has_se (bool): Whether the block contains a Squeeze and Excitation layer.
27 | """
28 |
29 | def __init__(self, block_args, global_params):
30 | super().__init__()
31 | self._block_args = block_args
32 | self._bn_mom = 1 - global_params.batch_norm_momentum
33 | self._bn_eps = global_params.batch_norm_epsilon
34 | self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
35 | self.id_skip = block_args.id_skip # skip connection and drop connect
36 |
37 | # Get static or dynamic convolution depending on image size
38 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
39 |
40 | # Expansion phase
41 | inp = self._block_args.input_filters # number of input channels
42 | oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
43 | if self._block_args.expand_ratio != 1:
44 | self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
45 | self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
46 |
47 | # Depthwise convolution phase
48 | k = self._block_args.kernel_size
49 | s = self._block_args.stride
50 | self._depthwise_conv = Conv2d(
51 | in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise
52 | kernel_size=k, stride=s, bias=False)
53 | self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
54 |
55 | # Squeeze and Excitation layer, if desired
56 | if self.has_se:
57 | num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
58 | self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
59 | self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
60 |
61 | # Output phase
62 | final_oup = self._block_args.output_filters
63 | self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
64 | self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
65 | self._swish = MemoryEfficientSwish()
66 |
67 | def forward(self, inputs, drop_connect_rate=None):
68 | """
69 | :param inputs: input tensor
70 | :param drop_connect_rate: drop connect rate (float, between 0 and 1)
71 | :return: output of block
72 | """
73 |
74 | # Expansion and Depthwise Convolution
75 | x = inputs
76 | if self._block_args.expand_ratio != 1:
77 | x = self._swish(self._bn0(self._expand_conv(inputs)))
78 | x = self._swish(self._bn1(self._depthwise_conv(x)))
79 |
80 | # Squeeze and Excitation
81 | if self.has_se:
82 | x_squeezed = F.adaptive_avg_pool2d(x, 1)
83 | x_squeezed = self._se_expand(self._swish(self._se_reduce(x_squeezed)))
84 | x = torch.sigmoid(x_squeezed) * x
85 |
86 | x = self._bn2(self._project_conv(x))
87 |
88 | # Skip connection and drop connect
89 | input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
90 | if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
91 | if drop_connect_rate:
92 | x = drop_connect(x, p=drop_connect_rate, training=self.training)
93 | x = x + inputs # skip connection
94 | return x
95 |
96 | def set_swish(self, memory_efficient=True):
97 | """Sets swish function as memory efficient (for training) or standard (for export)"""
98 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
99 |
100 |
101 | class EfficientNet(nn.Module):
102 | """
103 | An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
104 |
105 | Args:
106 | blocks_args (list): A list of BlockArgs to construct blocks
107 | global_params (namedtuple): A set of GlobalParams shared between blocks
108 |
109 | Example:
110 | model = EfficientNet.from_pretrained('efficientnet-b0')
111 |
112 | """
113 |
114 | def __init__(self, blocks_args=None, global_params=None):
115 | super().__init__()
116 | assert isinstance(blocks_args, list), 'blocks_args should be a list'
117 | assert len(blocks_args) > 0, 'block args must be greater than 0'
118 | self._global_params = global_params
119 | self._blocks_args = blocks_args
120 |
121 | # Get static or dynamic convolution depending on image size
122 | Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
123 |
124 | # Batch norm parameters
125 | bn_mom = 1 - self._global_params.batch_norm_momentum
126 | bn_eps = self._global_params.batch_norm_epsilon
127 |
128 | # Stem
129 | in_channels = 3 # rgb
130 | out_channels = round_filters(32, self._global_params) # number of output channels
131 | self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
132 | self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
133 |
134 | # Build blocks
135 | self._blocks = nn.ModuleList([])
136 | for block_args in self._blocks_args:
137 |
138 | # Update block input and output filters based on depth multiplier.
139 | block_args = block_args._replace(
140 | input_filters=round_filters(block_args.input_filters, self._global_params),
141 | output_filters=round_filters(block_args.output_filters, self._global_params),
142 | num_repeat=round_repeats(block_args.num_repeat, self._global_params)
143 | )
144 |
145 | # The first block needs to take care of stride and filter size increase.
146 | self._blocks.append(MBConvBlock(block_args, self._global_params))
147 | if block_args.num_repeat > 1:
148 | block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
149 | for _ in range(block_args.num_repeat - 1):
150 | self._blocks.append(MBConvBlock(block_args, self._global_params))
151 |
152 | # Head
153 | in_channels = block_args.output_filters # output of final block
154 | out_channels = round_filters(1280, self._global_params)
155 | self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
156 | self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)
157 |
158 | # Final linear layer
159 | self._avg_pooling = nn.AdaptiveAvgPool2d(1)
160 | self._dropout = nn.Dropout(self._global_params.dropout_rate)
161 | self._fc = nn.Linear(out_channels, self._global_params.num_classes)
162 | self._swish = MemoryEfficientSwish()
163 |
164 | def set_swish(self, memory_efficient=True):
165 | """Sets swish function as memory efficient (for training) or standard (for export)"""
166 | self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
167 | for block in self._blocks:
168 | block.set_swish(memory_efficient)
169 |
170 |
171 | def extract_features(self, inputs):
172 | """ Returns output of the final convolution layer """
173 |
174 | # Stem
175 | x = self._swish(self._bn0(self._conv_stem(inputs)))
176 |
177 | # Blocks
178 | for idx, block in enumerate(self._blocks):
179 | drop_connect_rate = self._global_params.drop_connect_rate
180 | if drop_connect_rate:
181 | drop_connect_rate *= float(idx) / len(self._blocks)
182 | x = block(x, drop_connect_rate=drop_connect_rate)
183 |
184 | # Head
185 | x = self._swish(self._bn1(self._conv_head(x)))
186 |
187 | return x
188 |
189 | def forward(self, inputs):
190 | """ Calls extract_features to extract features, applies final linear layer, and returns logits. """
191 | bs = inputs.size(0)
192 | # Convolution layers
193 | x = self.extract_features(inputs)
194 |
195 | # Pooling and final linear layer
196 | x = self._avg_pooling(x)
197 | x = x.view(bs, -1)
198 | x = self._dropout(x)
199 | x = self._fc(x)
200 | return x
201 |
202 | @classmethod
203 | def from_name(cls, model_name, override_params=None):
204 | cls._check_model_name_is_valid(model_name)
205 | blocks_args, global_params = get_model_params(model_name, override_params)
206 | return cls(blocks_args, global_params)
207 |
208 |     @classmethod
209 |     def from_pretrained(cls, model_name, num_classes=1000, in_channels=3):
210 |         model = cls.from_name(model_name, override_params={'num_classes': num_classes})
211 |         load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000))
212 |         # Rebuild the stem if the input is not 3-channel RGB.
213 |         if in_channels != 3:
214 |             Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
215 |             out_channels = round_filters(32, model._global_params)
216 |             model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
217 |         return model
224 |
225 | @classmethod
226 | def get_image_size(cls, model_name):
227 | cls._check_model_name_is_valid(model_name)
228 | _, _, res, _ = efficientnet_params(model_name)
229 | return res
230 |
231 | @classmethod
232 | def _check_model_name_is_valid(cls, model_name, also_need_pretrained_weights=False):
233 | """ Validates model name. None that pretrained weights are only available for
234 | the first four models (efficientnet-b{i} for i in 0,1,2,3) at the moment. """
235 | num_models = 4 if also_need_pretrained_weights else 8
236 | valid_models = ['efficientnet-b'+str(i) for i in range(num_models)]
237 | if model_name not in valid_models:
238 | raise ValueError('model_name should be one of: ' + ', '.join(valid_models))
239 |
--------------------------------------------------------------------------------
/efficientnet_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains helper functions for building the model and for loading model parameters.
3 | These helper functions are built to mirror those in the official TensorFlow implementation.
4 | """
5 |
6 | import re
7 | import math
8 | import collections
9 | from functools import partial
10 | import torch
11 | from torch import nn
12 | from torch.nn import functional as F
13 | from torch.utils import model_zoo
14 |
15 | ########################################################################
16 | ############### HELPERS FUNCTIONS FOR MODEL ARCHITECTURE ###############
17 | ########################################################################
18 |
19 |
20 | # Parameters for the entire model (stem, all blocks, and head)
21 | GlobalParams = collections.namedtuple('GlobalParams', [
22 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate',
23 | 'num_classes', 'width_coefficient', 'depth_coefficient',
24 | 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size'])
25 |
26 | # Parameters for an individual model block
27 | BlockArgs = collections.namedtuple('BlockArgs', [
28 | 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
29 | 'expand_ratio', 'id_skip', 'stride', 'se_ratio'])
30 |
31 | # Change namedtuple defaults
32 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
33 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)
34 |
35 |
36 | class SwishImplementation(torch.autograd.Function):
37 | @staticmethod
38 | def forward(ctx, i):
39 | result = i * torch.sigmoid(i)
40 | ctx.save_for_backward(i)
41 | return result
42 |
43 | @staticmethod
44 | def backward(ctx, grad_output):
45 | i = ctx.saved_variables[0]
46 | sigmoid_i = torch.sigmoid(i)
47 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
48 |
49 |
50 | class MemoryEfficientSwish(nn.Module):
51 | def forward(self, x):
52 | return SwishImplementation.apply(x)
53 |
54 | class Swish(nn.Module):
55 | def forward(self, x):
56 | return x * torch.sigmoid(x)
57 |
58 |
59 | def round_filters(filters, global_params):
60 | """ Calculate and round number of filters based on depth multiplier. """
61 | multiplier = global_params.width_coefficient
62 | if not multiplier:
63 | return filters
64 | divisor = global_params.depth_divisor
65 | min_depth = global_params.min_depth
66 | filters *= multiplier
67 | min_depth = min_depth or divisor
68 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
69 | if new_filters < 0.9 * filters: # prevent rounding by more than 10%
70 | new_filters += divisor
71 | return int(new_filters)
72 |
73 |
74 | def round_repeats(repeats, global_params):
75 | """ Round number of filters based on depth multiplier. """
76 | multiplier = global_params.depth_coefficient
77 | if not multiplier:
78 | return repeats
79 | return int(math.ceil(multiplier * repeats))
80 |
81 |
82 | def drop_connect(inputs, p, training):
83 |     """ Drop connect: randomly zero each sample's residual branch with probability p. """
84 |     if not training: return inputs
85 |     batch_size = inputs.shape[0]
86 |     keep_prob = 1 - p
87 |     random_tensor = keep_prob  # keep_prob + U[0,1) floors to a per-sample Bernoulli(keep_prob) mask
88 |     random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
89 |     binary_tensor = torch.floor(random_tensor)
90 |     output = inputs / keep_prob * binary_tensor  # rescale so the expected value is unchanged
91 |     return output
92 |
93 |
94 | def get_same_padding_conv2d(image_size=None):
95 | """ Chooses static padding if you have specified an image size, and dynamic padding otherwise.
96 | Static padding is necessary for ONNX exporting of models. """
97 | if image_size is None:
98 | return Conv2dDynamicSamePadding
99 | else:
100 | return partial(Conv2dStaticSamePadding, image_size=image_size)
101 |
102 |
103 | class Conv2dDynamicSamePadding(nn.Conv2d):
104 | """ 2D Convolutions like TensorFlow, for a dynamic image size """
105 |
106 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
107 | super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
108 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
109 |
110 | def forward(self, x):
111 | ih, iw = x.size()[-2:]
112 | kh, kw = self.weight.size()[-2:]
113 | sh, sw = self.stride
114 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
115 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
116 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
117 | if pad_h > 0 or pad_w > 0:
118 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
119 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
120 |
121 |
122 | class Conv2dStaticSamePadding(nn.Conv2d):
123 | """ 2D Convolutions like TensorFlow, for a fixed image size"""
124 |
125 | def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
126 | super().__init__(in_channels, out_channels, kernel_size, **kwargs)
127 | self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
128 |
129 | # Calculate padding based on image size and save it
130 | assert image_size is not None
131 | ih, iw = image_size if type(image_size) == list else [image_size, image_size]
132 | kh, kw = self.weight.size()[-2:]
133 | sh, sw = self.stride
134 | oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
135 | pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
136 | pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
137 | if pad_h > 0 or pad_w > 0:
138 | self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
139 | else:
140 | self.static_padding = Identity()
141 |
142 | def forward(self, x):
143 | x = self.static_padding(x)
144 | x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
145 | return x
146 |
147 |
148 | class Identity(nn.Module):
149 | def __init__(self, ):
150 | super(Identity, self).__init__()
151 |
152 | def forward(self, input):
153 | return input
154 |
155 |
156 | ########################################################################
157 | ############## HELPERS FUNCTIONS FOR LOADING MODEL PARAMS ##############
158 | ########################################################################
159 |
160 |
161 | def efficientnet_params(model_name):
162 | """ Map EfficientNet model name to parameter coefficients. """
163 | params_dict = {
164 | # Coefficients: width,depth,res,dropout
165 | 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
166 | 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
167 | 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
168 | 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
169 | 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
170 | 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
171 | 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
172 | 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
173 | }
174 | return params_dict[model_name]
175 |
176 |
177 | class BlockDecoder(object):
178 | """ Block Decoder for readability, straight from the official TensorFlow repository """
179 |
180 | @staticmethod
181 | def _decode_block_string(block_string):
182 | """ Gets a block through a string notation of arguments. """
183 | assert isinstance(block_string, str)
184 |
185 | ops = block_string.split('_')
186 | options = {}
187 | for op in ops:
188 | splits = re.split(r'(\d.*)', op)
189 | if len(splits) >= 2:
190 | key, value = splits[:2]
191 | options[key] = value
192 |
193 | # Check stride
194 | assert (('s' in options and len(options['s']) == 1) or
195 | (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
196 |
197 | return BlockArgs(
198 | kernel_size=int(options['k']),
199 | num_repeat=int(options['r']),
200 | input_filters=int(options['i']),
201 | output_filters=int(options['o']),
202 | expand_ratio=int(options['e']),
203 | id_skip=('noskip' not in block_string),
204 | se_ratio=float(options['se']) if 'se' in options else None,
205 | stride=[int(options['s'][0])])
206 |
207 | @staticmethod
208 | def _encode_block_string(block):
209 | """Encodes a block to a string."""
210 | args = [
211 | 'r%d' % block.num_repeat,
212 | 'k%d' % block.kernel_size,
213 |             's%d%d' % (block.stride[0], block.stride[-1]),  # BlockArgs stores 'stride', not 'strides'
214 | 'e%s' % block.expand_ratio,
215 | 'i%d' % block.input_filters,
216 | 'o%d' % block.output_filters
217 | ]
218 | if 0 < block.se_ratio <= 1:
219 | args.append('se%s' % block.se_ratio)
220 | if block.id_skip is False:
221 | args.append('noskip')
222 | return '_'.join(args)
223 |
224 | @staticmethod
225 | def decode(string_list):
226 | """
227 | Decodes a list of string notations to specify blocks inside the network.
228 |
229 | :param string_list: a list of strings, each string is a notation of block
230 | :return: a list of BlockArgs namedtuples of block args
231 | """
232 | assert isinstance(string_list, list)
233 | blocks_args = []
234 | for block_string in string_list:
235 | blocks_args.append(BlockDecoder._decode_block_string(block_string))
236 | return blocks_args
237 |
238 | @staticmethod
239 | def encode(blocks_args):
240 | """
241 | Encodes a list of BlockArgs to a list of strings.
242 |
243 | :param blocks_args: a list of BlockArgs namedtuples of block args
244 | :return: a list of strings, each string is a notation of block
245 | """
246 | block_strings = []
247 | for block in blocks_args:
248 | block_strings.append(BlockDecoder._encode_block_string(block))
249 | return block_strings
250 |
251 |
252 | def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2,
253 | drop_connect_rate=0.2, image_size=None, num_classes=1000):
254 | """ Creates a efficientnet model. """
255 |
256 | blocks_args = [
257 | 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
258 | 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
259 | 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
260 | 'r1_k3_s11_e6_i192_o320_se0.25',
261 | ]
262 | blocks_args = BlockDecoder.decode(blocks_args)
263 |
264 | global_params = GlobalParams(
265 | batch_norm_momentum=0.99,
266 | batch_norm_epsilon=1e-3,
267 | dropout_rate=dropout_rate,
268 | drop_connect_rate=drop_connect_rate,
269 | # data_format='channels_last', # removed, this is always true in PyTorch
270 | num_classes=num_classes,
271 | width_coefficient=width_coefficient,
272 | depth_coefficient=depth_coefficient,
273 | depth_divisor=8,
274 | min_depth=None,
275 | image_size=image_size,
276 | )
277 |
278 | return blocks_args, global_params
279 |
280 |
281 | def get_model_params(model_name, override_params):
282 | """ Get the block args and global params for a given model """
283 | if model_name.startswith('efficientnet'):
284 | w, d, s, p = efficientnet_params(model_name)
285 | # note: all models have drop connect rate = 0.2
286 | blocks_args, global_params = efficientnet(
287 | width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s)
288 | else:
289 | raise NotImplementedError('model name is not pre-defined: %s' % model_name)
290 | if override_params:
291 | # ValueError will be raised here if override_params has fields not included in global_params.
292 | global_params = global_params._replace(**override_params)
293 | return blocks_args, global_params
294 |
295 |
296 | url_map = {
297 | 'efficientnet-b0': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b0-355c32eb.pth',
298 | 'efficientnet-b1': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b1-f1951068.pth',
299 | 'efficientnet-b2': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b2-8bb594d6.pth',
300 | 'efficientnet-b3': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b3-5fb5a3c3.pth',
301 | 'efficientnet-b4': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b4-6ed6700e.pth',
302 | 'efficientnet-b5': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b5-b6417697.pth',
303 | 'efficientnet-b6': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b6-c76e70fd.pth',
304 | 'efficientnet-b7': 'http://storage.googleapis.com/public-models/efficientnet/efficientnet-b7-dcc49843.pth',
305 | }
306 |
307 |
308 | def load_pretrained_weights(model, model_name, load_fc=True):
309 | """ Loads pretrained weights, and downloads if loading for the first time. """
310 | state_dict = model_zoo.load_url(url_map[model_name])
311 | if load_fc:
312 | model.load_state_dict(state_dict)
313 | else:
314 | state_dict.pop('_fc.weight')
315 | state_dict.pop('_fc.bias')
316 | res = model.load_state_dict(state_dict, strict=False)
317 | assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
318 | print('Loaded pretrained weights for {}'.format(model_name))
319 |
--------------------------------------------------------------------------------
/examples/imagenet/README.md:
--------------------------------------------------------------------------------
1 | ### ImageNet
2 |
3 | This is a preliminary directory for evaluating the model on ImageNet. It is adapted from the standard PyTorch ImageNet example.
4 |
5 | For now, only evaluation is supported, but I am currently building scripts to assist with training new models on ImageNet.
6 |
7 | The evaluation results differ slightly from the original TensorFlow repository because of differences in data preprocessing. For example, with the current preprocessing, `efficientnet-b3` gives a top-1 accuracy of `80.8`, rather than the `81.1` reported in the paper. I am working on porting the TensorFlow preprocessing to PyTorch to address this issue.
8 |
9 | To run on ImageNet, place your `train` and `val` directories in `data`.
10 |
11 | Example commands:
12 | ```bash
13 | # Evaluate small EfficientNet on CPU
14 | python main.py data -e -a 'efficientnet-b0' --pretrained
15 | ```
16 | ```bash
17 | # Evaluate medium EfficientNet on GPU
18 | python main.py data -e -a 'efficientnet-b3' --pretrained --gpu 0 --batch-size 128
19 | ```
20 | ```bash
21 | # Evaluate ResNet-50 for comparison
22 | python main.py data -e -a 'resnet50' --pretrained --gpu 0
23 | ```
24 |
--------------------------------------------------------------------------------
/examples/imagenet/data/README.md:
--------------------------------------------------------------------------------
1 | ### ImageNet
2 |
3 | Download ImageNet and place it into `train` and `val` folders here.
4 |
5 | More details may be found with the official PyTorch ImageNet example [here](https://github.com/pytorch/examples/blob/master/imagenet).
6 |
--------------------------------------------------------------------------------
/examples/imagenet/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate on ImageNet. Note that at the moment, training is not implemented (I am working on it).
3 | that being said, evaluation is working.
4 | """
5 |
6 | import argparse
7 | import os
8 | import random
9 | import shutil
10 | import time
11 | import warnings
12 | import PIL
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.parallel
17 | import torch.backends.cudnn as cudnn
18 | import torch.distributed as dist
19 | import torch.optim
20 | import torch.multiprocessing as mp
21 | import torch.utils.data
22 | import torch.utils.data.distributed
23 | import torchvision.transforms as transforms
24 | import torchvision.datasets as datasets
25 | import torchvision.models as models
26 |
27 | from efficientnet_pytorch import EfficientNet
28 |
29 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
30 | parser.add_argument('data', metavar='DIR',
31 | help='path to dataset')
32 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
33 | help='model architecture (default: resnet18)')
34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
35 | help='number of data loading workers (default: 4)')
36 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
37 | help='number of total epochs to run')
38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
39 | help='manual epoch number (useful on restarts)')
40 | parser.add_argument('-b', '--batch-size', default=256, type=int,
41 | metavar='N',
42 | help='mini-batch size (default: 256), this is the total '
43 | 'batch size of all GPUs on the current node when '
44 | 'using Data Parallel or Distributed Data Parallel')
45 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
46 | metavar='LR', help='initial learning rate', dest='lr')
47 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
48 | help='momentum')
49 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
50 | metavar='W', help='weight decay (default: 1e-4)',
51 | dest='weight_decay')
52 | parser.add_argument('-p', '--print-freq', default=10, type=int,
53 | metavar='N', help='print frequency (default: 10)')
54 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
55 | help='path to latest checkpoint (default: none)')
56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
57 | help='evaluate model on validation set')
58 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
59 | help='use pre-trained model')
60 | parser.add_argument('--world-size', default=-1, type=int,
61 | help='number of nodes for distributed training')
62 | parser.add_argument('--rank', default=-1, type=int,
63 | help='node rank for distributed training')
64 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
65 | help='url used to set up distributed training')
66 | parser.add_argument('--dist-backend', default='nccl', type=str,
67 | help='distributed backend')
68 | parser.add_argument('--seed', default=None, type=int,
69 | help='seed for initializing training. ')
70 | parser.add_argument('--gpu', default=None, type=int,
71 | help='GPU id to use.')
72 | parser.add_argument('--image_size', default=224, type=int,
73 | help='image size')
74 | parser.add_argument('--multiprocessing-distributed', action='store_true',
75 | help='Use multi-processing distributed training to launch '
76 | 'N processes per node, which has N GPUs. This is the '
77 | 'fastest way to use PyTorch for either single node or '
78 | 'multi node data parallel training')
79 |
80 | best_acc1 = 0
81 |
82 |
83 | def main():
84 | args = parser.parse_args()
85 |
86 | if args.seed is not None:
87 | random.seed(args.seed)
88 | torch.manual_seed(args.seed)
89 | cudnn.deterministic = True
90 | warnings.warn('You have chosen to seed training. '
91 | 'This will turn on the CUDNN deterministic setting, '
92 | 'which can slow down your training considerably! '
93 | 'You may see unexpected behavior when restarting '
94 | 'from checkpoints.')
95 |
96 | if args.gpu is not None:
97 | warnings.warn('You have chosen a specific GPU. This will completely '
98 | 'disable data parallelism.')
99 |
100 | if args.dist_url == "env://" and args.world_size == -1:
101 | args.world_size = int(os.environ["WORLD_SIZE"])
102 |
103 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
104 |
105 | ngpus_per_node = torch.cuda.device_count()
106 | if args.multiprocessing_distributed:
107 | # Since we have ngpus_per_node processes per node, the total world_size
108 | # needs to be adjusted accordingly
109 | args.world_size = ngpus_per_node * args.world_size
110 | # Use torch.multiprocessing.spawn to launch distributed processes: the
111 | # main_worker process function
112 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
113 | else:
114 | # Simply call main_worker function
115 | main_worker(args.gpu, ngpus_per_node, args)
116 |
117 |
118 | def main_worker(gpu, ngpus_per_node, args):
119 | global best_acc1
120 | args.gpu = gpu
121 |
122 | if args.gpu is not None:
123 | print("Use GPU: {} for training".format(args.gpu))
124 |
125 | if args.distributed:
126 | if args.dist_url == "env://" and args.rank == -1:
127 | args.rank = int(os.environ["RANK"])
128 | if args.multiprocessing_distributed:
129 | # For multiprocessing distributed training, rank needs to be the
130 | # global rank among all the processes
131 | args.rank = args.rank * ngpus_per_node + gpu
132 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
133 | world_size=args.world_size, rank=args.rank)
134 | # create model
135 | if 'efficientnet' in args.arch: # NEW
136 | if args.pretrained:
137 | model = EfficientNet.from_pretrained(args.arch)
138 | print("=> using pre-trained model '{}'".format(args.arch))
139 | else:
140 | print("=> creating model '{}'".format(args.arch))
141 | model = EfficientNet.from_name(args.arch)
142 |
143 | else:
144 | if args.pretrained:
145 | print("=> using pre-trained model '{}'".format(args.arch))
146 | model = models.__dict__[args.arch](pretrained=True)
147 | else:
148 | print("=> creating model '{}'".format(args.arch))
149 | model = models.__dict__[args.arch]()
150 |
151 | if args.distributed:
152 | # For multiprocessing distributed, DistributedDataParallel constructor
153 | # should always set the single device scope, otherwise,
154 | # DistributedDataParallel will use all available devices.
155 | if args.gpu is not None:
156 | torch.cuda.set_device(args.gpu)
157 | model.cuda(args.gpu)
158 | # When using a single GPU per process and per
159 | # DistributedDataParallel, we need to divide the batch size
160 | # ourselves based on the total number of GPUs we have
161 | args.batch_size = int(args.batch_size / ngpus_per_node)
162 | args.workers = int(args.workers / ngpus_per_node)
163 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
164 | else:
165 | model.cuda()
166 | # DistributedDataParallel will divide and allocate batch_size to all
167 | # available GPUs if device_ids are not set
168 | model = torch.nn.parallel.DistributedDataParallel(model)
169 | elif args.gpu is not None:
170 | torch.cuda.set_device(args.gpu)
171 | model = model.cuda(args.gpu)
172 | else:
173 | # DataParallel will divide and allocate batch_size to all available GPUs
174 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
175 | model.features = torch.nn.DataParallel(model.features)
176 | model.cuda()
177 | else:
178 | model = torch.nn.DataParallel(model).cuda()
179 |
180 | # define loss function (criterion) and optimizer
181 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
182 |
183 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
184 | momentum=args.momentum,
185 | weight_decay=args.weight_decay)
186 |
187 | # optionally resume from a checkpoint
188 | if args.resume:
189 | if os.path.isfile(args.resume):
190 | print("=> loading checkpoint '{}'".format(args.resume))
191 | checkpoint = torch.load(args.resume)
192 | args.start_epoch = checkpoint['epoch']
193 | best_acc1 = checkpoint['best_acc1']
194 | if args.gpu is not None:
195 | # best_acc1 may be from a checkpoint from a different GPU
196 | best_acc1 = best_acc1.to(args.gpu)
197 | model.load_state_dict(checkpoint['state_dict'])
198 | optimizer.load_state_dict(checkpoint['optimizer'])
199 | print("=> loaded checkpoint '{}' (epoch {})"
200 | .format(args.resume, checkpoint['epoch']))
201 | else:
202 | print("=> no checkpoint found at '{}'".format(args.resume))
203 |
204 | cudnn.benchmark = True
205 |
206 | # Data loading code
207 | traindir = os.path.join(args.data, 'train')
208 | valdir = os.path.join(args.data, 'val')
209 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
210 | std=[0.229, 0.224, 0.225])
211 |
212 | train_dataset = datasets.ImageFolder(
213 | traindir,
214 | transforms.Compose([
215 | transforms.RandomResizedCrop(224),
216 | transforms.RandomHorizontalFlip(),
217 | transforms.ToTensor(),
218 | normalize,
219 | ]))
220 |
221 | if args.distributed:
222 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
223 | else:
224 | train_sampler = None
225 |
226 | train_loader = torch.utils.data.DataLoader(
227 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
228 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
229 |
230 | if 'efficientnet' in args.arch:
231 | image_size = EfficientNet.get_image_size(args.arch)
232 | val_transforms = transforms.Compose([
233 | transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
234 | transforms.CenterCrop(image_size),
235 | transforms.ToTensor(),
236 | normalize,
237 | ])
238 | print('Using image size', image_size)
239 | else:
240 | val_transforms = transforms.Compose([
241 | transforms.Resize(256),
242 | transforms.CenterCrop(224),
243 | transforms.ToTensor(),
244 | normalize,
245 | ])
246 | print('Using image size', 224)
247 |
248 | val_loader = torch.utils.data.DataLoader(
249 | datasets.ImageFolder(valdir, val_transforms),
250 | batch_size=args.batch_size, shuffle=False,
251 | num_workers=args.workers, pin_memory=True)
252 |
253 | if args.evaluate:
254 | res = validate(val_loader, model, criterion, args)
255 | with open('res.txt', 'w') as f:
256 | print(res, file=f)
257 | return
258 |
259 | for epoch in range(args.start_epoch, args.epochs):
260 | if args.distributed:
261 | train_sampler.set_epoch(epoch)
262 | adjust_learning_rate(optimizer, epoch, args)
263 |
264 | # train for one epoch
265 | train(train_loader, model, criterion, optimizer, epoch, args)
266 |
267 | # evaluate on validation set
268 | acc1 = validate(val_loader, model, criterion, args)
269 |
270 | # remember best acc@1 and save checkpoint
271 | is_best = acc1 > best_acc1
272 | best_acc1 = max(acc1, best_acc1)
273 |
274 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
275 | and args.rank % ngpus_per_node == 0):
276 | save_checkpoint({
277 | 'epoch': epoch + 1,
278 | 'arch': args.arch,
279 | 'state_dict': model.state_dict(),
280 | 'best_acc1': best_acc1,
281 | 'optimizer' : optimizer.state_dict(),
282 | }, is_best)
283 |
284 |
285 | def train(train_loader, model, criterion, optimizer, epoch, args):
286 | batch_time = AverageMeter('Time', ':6.3f')
287 | data_time = AverageMeter('Data', ':6.3f')
288 | losses = AverageMeter('Loss', ':.4e')
289 | top1 = AverageMeter('Acc@1', ':6.2f')
290 | top5 = AverageMeter('Acc@5', ':6.2f')
291 | progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1,
292 | top5, prefix="Epoch: [{}]".format(epoch))
293 |
294 | # switch to train mode
295 | model.train()
296 |
297 | end = time.time()
298 | for i, (images, target) in enumerate(train_loader):
299 | # measure data loading time
300 | data_time.update(time.time() - end)
301 |
302 | if args.gpu is not None:
303 | images = images.cuda(args.gpu, non_blocking=True)
304 | target = target.cuda(args.gpu, non_blocking=True)
305 |
306 | # compute output
307 | output = model(images)
308 | loss = criterion(output, target)
309 |
310 | # measure accuracy and record loss
311 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
312 | losses.update(loss.item(), images.size(0))
313 | top1.update(acc1[0], images.size(0))
314 | top5.update(acc5[0], images.size(0))
315 |
316 | # compute gradient and do SGD step
317 | optimizer.zero_grad()
318 | loss.backward()
319 | optimizer.step()
320 |
321 | # measure elapsed time
322 | batch_time.update(time.time() - end)
323 | end = time.time()
324 |
325 | if i % args.print_freq == 0:
326 | progress.print(i)
327 |
328 |
329 | def validate(val_loader, model, criterion, args):
330 | batch_time = AverageMeter('Time', ':6.3f')
331 | losses = AverageMeter('Loss', ':.4e')
332 | top1 = AverageMeter('Acc@1', ':6.2f')
333 | top5 = AverageMeter('Acc@5', ':6.2f')
334 | progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
335 | prefix='Test: ')
336 |
337 | # switch to evaluate mode
338 | model.eval()
339 |
340 | with torch.no_grad():
341 | end = time.time()
342 | for i, (images, target) in enumerate(val_loader):
343 | if args.gpu is not None:
344 | images = images.cuda(args.gpu, non_blocking=True)
345 | target = target.cuda(args.gpu, non_blocking=True)
346 |
347 | # compute output
348 | output = model(images)
349 | loss = criterion(output, target)
350 |
351 | # measure accuracy and record loss
352 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
353 | losses.update(loss.item(), images.size(0))
354 | top1.update(acc1[0], images.size(0))
355 | top5.update(acc5[0], images.size(0))
356 |
357 | # measure elapsed time
358 | batch_time.update(time.time() - end)
359 | end = time.time()
360 |
361 | if i % args.print_freq == 0:
362 | progress.print(i)
363 |
364 | # TODO: this should also be done with the ProgressMeter
365 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
366 | .format(top1=top1, top5=top5))
367 |
368 | return top1.avg
369 |
370 |
371 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
372 | torch.save(state, filename)
373 | if is_best:
374 | shutil.copyfile(filename, 'model_best.pth.tar')
375 |
376 |
377 | class AverageMeter(object):
378 | """Computes and stores the average and current value"""
379 | def __init__(self, name, fmt=':f'):
380 | self.name = name
381 | self.fmt = fmt
382 | self.reset()
383 |
384 | def reset(self):
385 | self.val = 0
386 | self.avg = 0
387 | self.sum = 0
388 | self.count = 0
389 |
390 | def update(self, val, n=1):
391 | self.val = val
392 | self.sum += val * n
393 | self.count += n
394 | self.avg = self.sum / self.count
395 |
396 | def __str__(self):
397 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
398 | return fmtstr.format(**self.__dict__)
399 |
400 |
401 | class ProgressMeter(object):
402 | def __init__(self, num_batches, *meters, prefix=""):
403 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
404 | self.meters = meters
405 | self.prefix = prefix
406 |
407 | def print(self, batch):
408 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
409 | entries += [str(meter) for meter in self.meters]
410 | print('\t'.join(entries))
411 |
412 | def _get_batch_fmtstr(self, num_batches):
413 |         num_digits = len(str(num_batches))
414 | fmt = '{:' + str(num_digits) + 'd}'
415 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
416 |
417 |
418 | def adjust_learning_rate(optimizer, epoch, args):
419 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
420 |     lr = args.lr * (0.1 ** (epoch // 30))  # e.g. with lr=0.1: epochs 0-29 use 0.1, 30-59 use 0.01, 60-89 use 0.001
421 | for param_group in optimizer.param_groups:
422 | param_group['lr'] = lr
423 |
424 |
425 | def accuracy(output, target, topk=(1,)):
426 | """Computes the accuracy over the k top predictions for the specified values of k"""
427 | with torch.no_grad():
428 | maxk = max(topk)
429 | batch_size = target.size(0)
430 |
431 | _, pred = output.topk(maxk, 1, True, True)
432 | pred = pred.t()
433 | correct = pred.eq(target.view(1, -1).expand_as(pred))
434 |
435 | res = []
436 | for k in topk:
437 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
438 | res.append(correct_k.mul_(100.0 / batch_size))
439 | return res
440 |
441 |
442 | if __name__ == '__main__':
443 | main()
444 |
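
`accuracy()` counts a sample as correct at k if its true class appears anywhere among the k highest-scoring predictions. A tiny hand-checked batch (a sketch with made-up logits; `accuracy` is the function defined in this file):

```python
import torch

# two samples, three classes; logits chosen so the arithmetic is easy to follow
output = torch.tensor([[0.1, 0.9, 0.0],     # top-1 = class 1
                       [0.8, 0.05, 0.15]])  # top-1 = class 0, top-2 = {0, 2}
target = torch.tensor([1, 2])

acc1, acc2 = accuracy(output, target, topk=(1, 2))
# sample 0 is right at top-1; sample 1 only appears in its top-2,
# so acc@1 = 50.0 and acc@2 = 100.0
print(acc1.item(), acc2.item())
```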
--------------------------------------------------------------------------------
/examples/simple/fire_smoke_map.txt:
--------------------------------------------------------------------------------
1 | {"0": "fire", "1": "negtive", "2": "smoke"}
2 |
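
The demo scripts load this map with `json.load` and flatten it into an index-ordered list; a minimal sketch of that lookup (assuming the working directory is the repository root, as in the scripts below):

```python
import json

with open('examples/simple/fire_smoke_map.txt') as f:
    labels_map = json.load(f)
labels = [labels_map[str(i)] for i in range(3)]
print(labels)  # e.g. ['fire', 'negative', 'smoke']
```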
--------------------------------------------------------------------------------
/examples/simple/img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/examples/simple/img.jpg
--------------------------------------------------------------------------------
/examples/simple/img2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/examples/simple/img2.jpg
--------------------------------------------------------------------------------
/examples/simple/labels_map.txt:
--------------------------------------------------------------------------------
1 | {"0": "tench, Tinca tinca", "1": "goldfish, Carassius auratus", "2": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", "3": "tiger shark, Galeocerdo cuvieri", "4": "hammerhead, hammerhead shark", "5": "electric ray, crampfish, numbfish, torpedo", "6": "stingray", "7": "cock", "8": "hen", "9": "ostrich, Struthio camelus", "10": "brambling, Fringilla montifringilla", "11": "goldfinch, Carduelis carduelis", "12": "house finch, linnet, Carpodacus mexicanus", "13": "junco, snowbird", "14": "indigo bunting, indigo finch, indigo bird, Passerina cyanea", "15": "robin, American robin, Turdus migratorius", "16": "bulbul", "17": "jay", "18": "magpie", "19": "chickadee", "20": "water ouzel, dipper", "21": "kite", "22": "bald eagle, American eagle, Haliaeetus leucocephalus", "23": "vulture", "24": "great grey owl, great gray owl, Strix nebulosa", "25": "European fire salamander, Salamandra salamandra", "26": "common newt, Triturus vulgaris", "27": "eft", "28": "spotted salamander, Ambystoma maculatum", "29": "axolotl, mud puppy, Ambystoma mexicanum", "30": "bullfrog, Rana catesbeiana", "31": "tree frog, tree-frog", "32": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui", "33": "loggerhead, loggerhead turtle, Caretta caretta", "34": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", "35": "mud turtle", "36": "terrapin", "37": "box turtle, box tortoise", "38": "banded gecko", "39": "common iguana, iguana, Iguana iguana", "40": "American chameleon, anole, Anolis carolinensis", "41": "whiptail, whiptail lizard", "42": "agama", "43": "frilled lizard, Chlamydosaurus kingi", "44": "alligator lizard", "45": "Gila monster, Heloderma suspectum", "46": "green lizard, Lacerta viridis", "47": "African chameleon, Chamaeleo chamaeleon", "48": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", "49": "African crocodile, Nile crocodile, Crocodylus niloticus", "50": "American alligator, Alligator mississipiensis", "51": "triceratops", "52": "thunder snake, worm snake, Carphophis amoenus", "53": "ringneck snake, ring-necked snake, ring snake", "54": "hognose snake, puff adder, sand viper", "55": "green snake, grass snake", "56": "king snake, kingsnake", "57": "garter snake, grass snake", "58": "water snake", "59": "vine snake", "60": "night snake, Hypsiglena torquata", "61": "boa constrictor, Constrictor constrictor", "62": "rock python, rock snake, Python sebae", "63": "Indian cobra, Naja naja", "64": "green mamba", "65": "sea snake", "66": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus", "67": "diamondback, diamondback rattlesnake, Crotalus adamanteus", "68": "sidewinder, horned rattlesnake, Crotalus cerastes", "69": "trilobite", "70": "harvestman, daddy longlegs, Phalangium opilio", "71": "scorpion", "72": "black and gold garden spider, Argiope aurantia", "73": "barn spider, Araneus cavaticus", "74": "garden spider, Aranea diademata", "75": "black widow, Latrodectus mactans", "76": "tarantula", "77": "wolf spider, hunting spider", "78": "tick", "79": "centipede", "80": "black grouse", "81": "ptarmigan", "82": "ruffed grouse, partridge, Bonasa umbellus", "83": "prairie chicken, prairie grouse, prairie fowl", "84": "peacock", "85": "quail", "86": "partridge", "87": "African grey, African gray, Psittacus erithacus", "88": "macaw", "89": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita", "90": "lorikeet", "91": "coucal", "92": "bee eater", "93": "hornbill", "94": "hummingbird", 
"95": "jacamar", "96": "toucan", "97": "drake", "98": "red-breasted merganser, Mergus serrator", "99": "goose", "100": "black swan, Cygnus atratus", "101": "tusker", "102": "echidna, spiny anteater, anteater", "103": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", "104": "wallaby, brush kangaroo", "105": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", "106": "wombat", "107": "jellyfish", "108": "sea anemone, anemone", "109": "brain coral", "110": "flatworm, platyhelminth", "111": "nematode, nematode worm, roundworm", "112": "conch", "113": "snail", "114": "slug", "115": "sea slug, nudibranch", "116": "chiton, coat-of-mail shell, sea cradle, polyplacophore", "117": "chambered nautilus, pearly nautilus, nautilus", "118": "Dungeness crab, Cancer magister", "119": "rock crab, Cancer irroratus", "120": "fiddler crab", "121": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", "122": "American lobster, Northern lobster, Maine lobster, Homarus americanus", "123": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", "124": "crayfish, crawfish, crawdad, crawdaddy", "125": "hermit crab", "126": "isopod", "127": "white stork, Ciconia ciconia", "128": "black stork, Ciconia nigra", "129": "spoonbill", "130": "flamingo", "131": "little blue heron, Egretta caerulea", "132": "American egret, great white heron, Egretta albus", "133": "bittern", "134": "crane", "135": "limpkin, Aramus pictus", "136": "European gallinule, Porphyrio porphyrio", "137": "American coot, marsh hen, mud hen, water hen, Fulica americana", "138": "bustard", "139": "ruddy turnstone, Arenaria interpres", "140": "red-backed sandpiper, dunlin, Erolia alpina", "141": "redshank, Tringa totanus", "142": "dowitcher", "143": "oystercatcher, oyster catcher", "144": "pelican", "145": "king penguin, Aptenodytes patagonica", "146": "albatross, mollymawk", "147": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", "148": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca", "149": "dugong, Dugong dugon", "150": "sea lion", "151": "Chihuahua", "152": "Japanese spaniel", "153": "Maltese dog, Maltese terrier, Maltese", "154": "Pekinese, Pekingese, Peke", "155": "Shih-Tzu", "156": "Blenheim spaniel", "157": "papillon", "158": "toy terrier", "159": "Rhodesian ridgeback", "160": "Afghan hound, Afghan", "161": "basset, basset hound", "162": "beagle", "163": "bloodhound, sleuthhound", "164": "bluetick", "165": "black-and-tan coonhound", "166": "Walker hound, Walker foxhound", "167": "English foxhound", "168": "redbone", "169": "borzoi, Russian wolfhound", "170": "Irish wolfhound", "171": "Italian greyhound", "172": "whippet", "173": "Ibizan hound, Ibizan Podenco", "174": "Norwegian elkhound, elkhound", "175": "otterhound, otter hound", "176": "Saluki, gazelle hound", "177": "Scottish deerhound, deerhound", "178": "Weimaraner", "179": "Staffordshire bullterrier, Staffordshire bull terrier", "180": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", "181": "Bedlington terrier", "182": "Border terrier", "183": "Kerry blue terrier", "184": "Irish terrier", "185": "Norfolk terrier", "186": "Norwich terrier", "187": "Yorkshire terrier", "188": "wire-haired fox terrier", "189": "Lakeland terrier", "190": "Sealyham terrier, Sealyham", "191": "Airedale, Airedale terrier", "192": "cairn, cairn terrier", "193": "Australian terrier", "194": "Dandie 
Dinmont, Dandie Dinmont terrier", "195": "Boston bull, Boston terrier", "196": "miniature schnauzer", "197": "giant schnauzer", "198": "standard schnauzer", "199": "Scotch terrier, Scottish terrier, Scottie", "200": "Tibetan terrier, chrysanthemum dog", "201": "silky terrier, Sydney silky", "202": "soft-coated wheaten terrier", "203": "West Highland white terrier", "204": "Lhasa, Lhasa apso", "205": "flat-coated retriever", "206": "curly-coated retriever", "207": "golden retriever", "208": "Labrador retriever", "209": "Chesapeake Bay retriever", "210": "German short-haired pointer", "211": "vizsla, Hungarian pointer", "212": "English setter", "213": "Irish setter, red setter", "214": "Gordon setter", "215": "Brittany spaniel", "216": "clumber, clumber spaniel", "217": "English springer, English springer spaniel", "218": "Welsh springer spaniel", "219": "cocker spaniel, English cocker spaniel, cocker", "220": "Sussex spaniel", "221": "Irish water spaniel", "222": "kuvasz", "223": "schipperke", "224": "groenendael", "225": "malinois", "226": "briard", "227": "kelpie", "228": "komondor", "229": "Old English sheepdog, bobtail", "230": "Shetland sheepdog, Shetland sheep dog, Shetland", "231": "collie", "232": "Border collie", "233": "Bouvier des Flandres, Bouviers des Flandres", "234": "Rottweiler", "235": "German shepherd, German shepherd dog, German police dog, alsatian", "236": "Doberman, Doberman pinscher", "237": "miniature pinscher", "238": "Greater Swiss Mountain dog", "239": "Bernese mountain dog", "240": "Appenzeller", "241": "EntleBucher", "242": "boxer", "243": "bull mastiff", "244": "Tibetan mastiff", "245": "French bulldog", "246": "Great Dane", "247": "Saint Bernard, St Bernard", "248": "Eskimo dog, husky", "249": "malamute, malemute, Alaskan malamute", "250": "Siberian husky", "251": "dalmatian, coach dog, carriage dog", "252": "affenpinscher, monkey pinscher, monkey dog", "253": "basenji", "254": "pug, pug-dog", "255": "Leonberg", "256": "Newfoundland, Newfoundland dog", "257": "Great Pyrenees", "258": "Samoyed, Samoyede", "259": "Pomeranian", "260": "chow, chow chow", "261": "keeshond", "262": "Brabancon griffon", "263": "Pembroke, Pembroke Welsh corgi", "264": "Cardigan, Cardigan Welsh corgi", "265": "toy poodle", "266": "miniature poodle", "267": "standard poodle", "268": "Mexican hairless", "269": "timber wolf, grey wolf, gray wolf, Canis lupus", "270": "white wolf, Arctic wolf, Canis lupus tundrarum", "271": "red wolf, maned wolf, Canis rufus, Canis niger", "272": "coyote, prairie wolf, brush wolf, Canis latrans", "273": "dingo, warrigal, warragal, Canis dingo", "274": "dhole, Cuon alpinus", "275": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus", "276": "hyena, hyaena", "277": "red fox, Vulpes vulpes", "278": "kit fox, Vulpes macrotis", "279": "Arctic fox, white fox, Alopex lagopus", "280": "grey fox, gray fox, Urocyon cinereoargenteus", "281": "tabby, tabby cat", "282": "tiger cat", "283": "Persian cat", "284": "Siamese cat, Siamese", "285": "Egyptian cat", "286": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", "287": "lynx, catamount", "288": "leopard, Panthera pardus", "289": "snow leopard, ounce, Panthera uncia", "290": "jaguar, panther, Panthera onca, Felis onca", "291": "lion, king of beasts, Panthera leo", "292": "tiger, Panthera tigris", "293": "cheetah, chetah, Acinonyx jubatus", "294": "brown bear, bruin, Ursus arctos", "295": "American black bear, black bear, Ursus americanus, Euarctos americanus", "296": "ice bear, 
polar bear, Ursus Maritimus, Thalarctos maritimus", "297": "sloth bear, Melursus ursinus, Ursus ursinus", "298": "mongoose", "299": "meerkat, mierkat", "300": "tiger beetle", "301": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle", "302": "ground beetle, carabid beetle", "303": "long-horned beetle, longicorn, longicorn beetle", "304": "leaf beetle, chrysomelid", "305": "dung beetle", "306": "rhinoceros beetle", "307": "weevil", "308": "fly", "309": "bee", "310": "ant, emmet, pismire", "311": "grasshopper, hopper", "312": "cricket", "313": "walking stick, walkingstick, stick insect", "314": "cockroach, roach", "315": "mantis, mantid", "316": "cicada, cicala", "317": "leafhopper", "318": "lacewing, lacewing fly", "319": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", "320": "damselfly", "321": "admiral", "322": "ringlet, ringlet butterfly", "323": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", "324": "cabbage butterfly", "325": "sulphur butterfly, sulfur butterfly", "326": "lycaenid, lycaenid butterfly", "327": "starfish, sea star", "328": "sea urchin", "329": "sea cucumber, holothurian", "330": "wood rabbit, cottontail, cottontail rabbit", "331": "hare", "332": "Angora, Angora rabbit", "333": "hamster", "334": "porcupine, hedgehog", "335": "fox squirrel, eastern fox squirrel, Sciurus niger", "336": "marmot", "337": "beaver", "338": "guinea pig, Cavia cobaya", "339": "sorrel", "340": "zebra", "341": "hog, pig, grunter, squealer, Sus scrofa", "342": "wild boar, boar, Sus scrofa", "343": "warthog", "344": "hippopotamus, hippo, river horse, Hippopotamus amphibius", "345": "ox", "346": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis", "347": "bison", "348": "ram, tup", "349": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", "350": "ibex, Capra ibex", "351": "hartebeest", "352": "impala, Aepyceros melampus", "353": "gazelle", "354": "Arabian camel, dromedary, Camelus dromedarius", "355": "llama", "356": "weasel", "357": "mink", "358": "polecat, fitch, foulmart, foumart, Mustela putorius", "359": "black-footed ferret, ferret, Mustela nigripes", "360": "otter", "361": "skunk, polecat, wood pussy", "362": "badger", "363": "armadillo", "364": "three-toed sloth, ai, Bradypus tridactylus", "365": "orangutan, orang, orangutang, Pongo pygmaeus", "366": "gorilla, Gorilla gorilla", "367": "chimpanzee, chimp, Pan troglodytes", "368": "gibbon, Hylobates lar", "369": "siamang, Hylobates syndactylus, Symphalangus syndactylus", "370": "guenon, guenon monkey", "371": "patas, hussar monkey, Erythrocebus patas", "372": "baboon", "373": "macaque", "374": "langur", "375": "colobus, colobus monkey", "376": "proboscis monkey, Nasalis larvatus", "377": "marmoset", "378": "capuchin, ringtail, Cebus capucinus", "379": "howler monkey, howler", "380": "titi, titi monkey", "381": "spider monkey, Ateles geoffroyi", "382": "squirrel monkey, Saimiri sciureus", "383": "Madagascar cat, ring-tailed lemur, Lemur catta", "384": "indri, indris, Indri indri, Indri brevicaudatus", "385": "Indian elephant, Elephas maximus", "386": "African elephant, Loxodonta africana", "387": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens", "388": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca", "389": "barracouta, snoek", "390": "eel", "391": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", "392": "rock beauty, 
Holocanthus tricolor", "393": "anemone fish", "394": "sturgeon", "395": "gar, garfish, garpike, billfish, Lepisosteus osseus", "396": "lionfish", "397": "puffer, pufferfish, blowfish, globefish", "398": "abacus", "399": "abaya", "400": "academic gown, academic robe, judge's robe", "401": "accordion, piano accordion, squeeze box", "402": "acoustic guitar", "403": "aircraft carrier, carrier, flattop, attack aircraft carrier", "404": "airliner", "405": "airship, dirigible", "406": "altar", "407": "ambulance", "408": "amphibian, amphibious vehicle", "409": "analog clock", "410": "apiary, bee house", "411": "apron", "412": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", "413": "assault rifle, assault gun", "414": "backpack, back pack, knapsack, packsack, rucksack, haversack", "415": "bakery, bakeshop, bakehouse", "416": "balance beam, beam", "417": "balloon", "418": "ballpoint, ballpoint pen, ballpen, Biro", "419": "Band Aid", "420": "banjo", "421": "bannister, banister, balustrade, balusters, handrail", "422": "barbell", "423": "barber chair", "424": "barbershop", "425": "barn", "426": "barometer", "427": "barrel, cask", "428": "barrow, garden cart, lawn cart, wheelbarrow", "429": "baseball", "430": "basketball", "431": "bassinet", "432": "bassoon", "433": "bathing cap, swimming cap", "434": "bath towel", "435": "bathtub, bathing tub, bath, tub", "436": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", "437": "beacon, lighthouse, beacon light, pharos", "438": "beaker", "439": "bearskin, busby, shako", "440": "beer bottle", "441": "beer glass", "442": "bell cote, bell cot", "443": "bib", "444": "bicycle-built-for-two, tandem bicycle, tandem", "445": "bikini, two-piece", "446": "binder, ring-binder", "447": "binoculars, field glasses, opera glasses", "448": "birdhouse", "449": "boathouse", "450": "bobsled, bobsleigh, bob", "451": "bolo tie, bolo, bola tie, bola", "452": "bonnet, poke bonnet", "453": "bookcase", "454": "bookshop, bookstore, bookstall", "455": "bottlecap", "456": "bow", "457": "bow tie, bow-tie, bowtie", "458": "brass, memorial tablet, plaque", "459": "brassiere, bra, bandeau", "460": "breakwater, groin, groyne, mole, bulwark, seawall, jetty", "461": "breastplate, aegis, egis", "462": "broom", "463": "bucket, pail", "464": "buckle", "465": "bulletproof vest", "466": "bullet train, bullet", "467": "butcher shop, meat market", "468": "cab, hack, taxi, taxicab", "469": "caldron, cauldron", "470": "candle, taper, wax light", "471": "cannon", "472": "canoe", "473": "can opener, tin opener", "474": "cardigan", "475": "car mirror", "476": "carousel, carrousel, merry-go-round, roundabout, whirligig", "477": "carpenter's kit, tool kit", "478": "carton", "479": "car wheel", "480": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", "481": "cassette", "482": "cassette player", "483": "castle", "484": "catamaran", "485": "CD player", "486": "cello, violoncello", "487": "cellular telephone, cellular phone, cellphone, cell, mobile phone", "488": "chain", "489": "chainlink fence", "490": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", "491": "chain saw, chainsaw", "492": "chest", "493": "chiffonier, commode", "494": "chime, bell, gong", "495": "china cabinet, china closet", "496": "Christmas stocking", "497": "church, church building", "498": "cinema, movie theater, movie theatre, movie house, 
picture palace", "499": "cleaver, meat cleaver, chopper", "500": "cliff dwelling", "501": "cloak", "502": "clog, geta, patten, sabot", "503": "cocktail shaker", "504": "coffee mug", "505": "coffeepot", "506": "coil, spiral, volute, whorl, helix", "507": "combination lock", "508": "computer keyboard, keypad", "509": "confectionery, confectionary, candy store", "510": "container ship, containership, container vessel", "511": "convertible", "512": "corkscrew, bottle screw", "513": "cornet, horn, trumpet, trump", "514": "cowboy boot", "515": "cowboy hat, ten-gallon hat", "516": "cradle", "517": "crane", "518": "crash helmet", "519": "crate", "520": "crib, cot", "521": "Crock Pot", "522": "croquet ball", "523": "crutch", "524": "cuirass", "525": "dam, dike, dyke", "526": "desk", "527": "desktop computer", "528": "dial telephone, dial phone", "529": "diaper, nappy, napkin", "530": "digital clock", "531": "digital watch", "532": "dining table, board", "533": "dishrag, dishcloth", "534": "dishwasher, dish washer, dishwashing machine", "535": "disk brake, disc brake", "536": "dock, dockage, docking facility", "537": "dogsled, dog sled, dog sleigh", "538": "dome", "539": "doormat, welcome mat", "540": "drilling platform, offshore rig", "541": "drum, membranophone, tympan", "542": "drumstick", "543": "dumbbell", "544": "Dutch oven", "545": "electric fan, blower", "546": "electric guitar", "547": "electric locomotive", "548": "entertainment center", "549": "envelope", "550": "espresso maker", "551": "face powder", "552": "feather boa, boa", "553": "file, file cabinet, filing cabinet", "554": "fireboat", "555": "fire engine, fire truck", "556": "fire screen, fireguard", "557": "flagpole, flagstaff", "558": "flute, transverse flute", "559": "folding chair", "560": "football helmet", "561": "forklift", "562": "fountain", "563": "fountain pen", "564": "four-poster", "565": "freight car", "566": "French horn, horn", "567": "frying pan, frypan, skillet", "568": "fur coat", "569": "garbage truck, dustcart", "570": "gasmask, respirator, gas helmet", "571": "gas pump, gasoline pump, petrol pump, island dispenser", "572": "goblet", "573": "go-kart", "574": "golf ball", "575": "golfcart, golf cart", "576": "gondola", "577": "gong, tam-tam", "578": "gown", "579": "grand piano, grand", "580": "greenhouse, nursery, glasshouse", "581": "grille, radiator grille", "582": "grocery store, grocery, food market, market", "583": "guillotine", "584": "hair slide", "585": "hair spray", "586": "half track", "587": "hammer", "588": "hamper", "589": "hand blower, blow dryer, blow drier, hair dryer, hair drier", "590": "hand-held computer, hand-held microcomputer", "591": "handkerchief, hankie, hanky, hankey", "592": "hard disc, hard disk, fixed disk", "593": "harmonica, mouth organ, harp, mouth harp", "594": "harp", "595": "harvester, reaper", "596": "hatchet", "597": "holster", "598": "home theater, home theatre", "599": "honeycomb", "600": "hook, claw", "601": "hoopskirt, crinoline", "602": "horizontal bar, high bar", "603": "horse cart, horse-cart", "604": "hourglass", "605": "iPod", "606": "iron, smoothing iron", "607": "jack-o'-lantern", "608": "jean, blue jean, denim", "609": "jeep, landrover", "610": "jersey, T-shirt, tee shirt", "611": "jigsaw puzzle", "612": "jinrikisha, ricksha, rickshaw", "613": "joystick", "614": "kimono", "615": "knee pad", "616": "knot", "617": "lab coat, laboratory coat", "618": "ladle", "619": "lampshade, lamp shade", "620": "laptop, laptop computer", "621": "lawn mower, mower", "622": "lens 
cap, lens cover", "623": "letter opener, paper knife, paperknife", "624": "library", "625": "lifeboat", "626": "lighter, light, igniter, ignitor", "627": "limousine, limo", "628": "liner, ocean liner", "629": "lipstick, lip rouge", "630": "Loafer", "631": "lotion", "632": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", "633": "loupe, jeweler's loupe", "634": "lumbermill, sawmill", "635": "magnetic compass", "636": "mailbag, postbag", "637": "mailbox, letter box", "638": "maillot", "639": "maillot, tank suit", "640": "manhole cover", "641": "maraca", "642": "marimba, xylophone", "643": "mask", "644": "matchstick", "645": "maypole", "646": "maze, labyrinth", "647": "measuring cup", "648": "medicine chest, medicine cabinet", "649": "megalith, megalithic structure", "650": "microphone, mike", "651": "microwave, microwave oven", "652": "military uniform", "653": "milk can", "654": "minibus", "655": "miniskirt, mini", "656": "minivan", "657": "missile", "658": "mitten", "659": "mixing bowl", "660": "mobile home, manufactured home", "661": "Model T", "662": "modem", "663": "monastery", "664": "monitor", "665": "moped", "666": "mortar", "667": "mortarboard", "668": "mosque", "669": "mosquito net", "670": "motor scooter, scooter", "671": "mountain bike, all-terrain bike, off-roader", "672": "mountain tent", "673": "mouse, computer mouse", "674": "mousetrap", "675": "moving van", "676": "muzzle", "677": "nail", "678": "neck brace", "679": "necklace", "680": "nipple", "681": "notebook, notebook computer", "682": "obelisk", "683": "oboe, hautboy, hautbois", "684": "ocarina, sweet potato", "685": "odometer, hodometer, mileometer, milometer", "686": "oil filter", "687": "organ, pipe organ", "688": "oscilloscope, scope, cathode-ray oscilloscope, CRO", "689": "overskirt", "690": "oxcart", "691": "oxygen mask", "692": "packet", "693": "paddle, boat paddle", "694": "paddlewheel, paddle wheel", "695": "padlock", "696": "paintbrush", "697": "pajama, pyjama, pj's, jammies", "698": "palace", "699": "panpipe, pandean pipe, syrinx", "700": "paper towel", "701": "parachute, chute", "702": "parallel bars, bars", "703": "park bench", "704": "parking meter", "705": "passenger car, coach, carriage", "706": "patio, terrace", "707": "pay-phone, pay-station", "708": "pedestal, plinth, footstall", "709": "pencil box, pencil case", "710": "pencil sharpener", "711": "perfume, essence", "712": "Petri dish", "713": "photocopier", "714": "pick, plectrum, plectron", "715": "pickelhaube", "716": "picket fence, paling", "717": "pickup, pickup truck", "718": "pier", "719": "piggy bank, penny bank", "720": "pill bottle", "721": "pillow", "722": "ping-pong ball", "723": "pinwheel", "724": "pirate, pirate ship", "725": "pitcher, ewer", "726": "plane, carpenter's plane, woodworking plane", "727": "planetarium", "728": "plastic bag", "729": "plate rack", "730": "plow, plough", "731": "plunger, plumber's helper", "732": "Polaroid camera, Polaroid Land camera", "733": "pole", "734": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", "735": "poncho", "736": "pool table, billiard table, snooker table", "737": "pop bottle, soda bottle", "738": "pot, flowerpot", "739": "potter's wheel", "740": "power drill", "741": "prayer rug, prayer mat", "742": "printer", "743": "prison, prison house", "744": "projectile, missile", "745": "projector", "746": "puck, hockey puck", "747": "punching bag, punch bag, punching ball, punchball", "748": "purse", "749": "quill, quill pen", "750": "quilt, comforter, 
comfort, puff", "751": "racer, race car, racing car", "752": "racket, racquet", "753": "radiator", "754": "radio, wireless", "755": "radio telescope, radio reflector", "756": "rain barrel", "757": "recreational vehicle, RV, R.V.", "758": "reel", "759": "reflex camera", "760": "refrigerator, icebox", "761": "remote control, remote", "762": "restaurant, eating house, eating place, eatery", "763": "revolver, six-gun, six-shooter", "764": "rifle", "765": "rocking chair, rocker", "766": "rotisserie", "767": "rubber eraser, rubber, pencil eraser", "768": "rugby ball", "769": "rule, ruler", "770": "running shoe", "771": "safe", "772": "safety pin", "773": "saltshaker, salt shaker", "774": "sandal", "775": "sarong", "776": "sax, saxophone", "777": "scabbard", "778": "scale, weighing machine", "779": "school bus", "780": "schooner", "781": "scoreboard", "782": "screen, CRT screen", "783": "screw", "784": "screwdriver", "785": "seat belt, seatbelt", "786": "sewing machine", "787": "shield, buckler", "788": "shoe shop, shoe-shop, shoe store", "789": "shoji", "790": "shopping basket", "791": "shopping cart", "792": "shovel", "793": "shower cap", "794": "shower curtain", "795": "ski", "796": "ski mask", "797": "sleeping bag", "798": "slide rule, slipstick", "799": "sliding door", "800": "slot, one-armed bandit", "801": "snorkel", "802": "snowmobile", "803": "snowplow, snowplough", "804": "soap dispenser", "805": "soccer ball", "806": "sock", "807": "solar dish, solar collector, solar furnace", "808": "sombrero", "809": "soup bowl", "810": "space bar", "811": "space heater", "812": "space shuttle", "813": "spatula", "814": "speedboat", "815": "spider web, spider's web", "816": "spindle", "817": "sports car, sport car", "818": "spotlight, spot", "819": "stage", "820": "steam locomotive", "821": "steel arch bridge", "822": "steel drum", "823": "stethoscope", "824": "stole", "825": "stone wall", "826": "stopwatch, stop watch", "827": "stove", "828": "strainer", "829": "streetcar, tram, tramcar, trolley, trolley car", "830": "stretcher", "831": "studio couch, day bed", "832": "stupa, tope", "833": "submarine, pigboat, sub, U-boat", "834": "suit, suit of clothes", "835": "sundial", "836": "sunglass", "837": "sunglasses, dark glasses, shades", "838": "sunscreen, sunblock, sun blocker", "839": "suspension bridge", "840": "swab, swob, mop", "841": "sweatshirt", "842": "swimming trunks, bathing trunks", "843": "swing", "844": "switch, electric switch, electrical switch", "845": "syringe", "846": "table lamp", "847": "tank, army tank, armored combat vehicle, armoured combat vehicle", "848": "tape player", "849": "teapot", "850": "teddy, teddy bear", "851": "television, television system", "852": "tennis ball", "853": "thatch, thatched roof", "854": "theater curtain, theatre curtain", "855": "thimble", "856": "thresher, thrasher, threshing machine", "857": "throne", "858": "tile roof", "859": "toaster", "860": "tobacco shop, tobacconist shop, tobacconist", "861": "toilet seat", "862": "torch", "863": "totem pole", "864": "tow truck, tow car, wrecker", "865": "toyshop", "866": "tractor", "867": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", "868": "tray", "869": "trench coat", "870": "tricycle, trike, velocipede", "871": "trimaran", "872": "tripod", "873": "triumphal arch", "874": "trolleybus, trolley coach, trackless trolley", "875": "trombone", "876": "tub, vat", "877": "turnstile", "878": "typewriter keyboard", "879": "umbrella", "880": "unicycle, monocycle", "881": "upright, 
upright piano", "882": "vacuum, vacuum cleaner", "883": "vase", "884": "vault", "885": "velvet", "886": "vending machine", "887": "vestment", "888": "viaduct", "889": "violin, fiddle", "890": "volleyball", "891": "waffle iron", "892": "wall clock", "893": "wallet, billfold, notecase, pocketbook", "894": "wardrobe, closet, press", "895": "warplane, military plane", "896": "washbasin, handbasin, washbowl, lavabo, wash-hand basin", "897": "washer, automatic washer, washing machine", "898": "water bottle", "899": "water jug", "900": "water tower", "901": "whiskey jug", "902": "whistle", "903": "wig", "904": "window screen", "905": "window shade", "906": "Windsor tie", "907": "wine bottle", "908": "wing", "909": "wok", "910": "wooden spoon", "911": "wool, woolen, woollen", "912": "worm fence, snake fence, snake-rail fence, Virginia fence", "913": "wreck", "914": "yawl", "915": "yurt", "916": "web site, website, internet site, site", "917": "comic book", "918": "crossword puzzle, crossword", "919": "street sign", "920": "traffic light, traffic signal, stoplight", "921": "book jacket, dust cover, dust jacket, dust wrapper", "922": "menu", "923": "plate", "924": "guacamole", "925": "consomme", "926": "hot pot, hotpot", "927": "trifle", "928": "ice cream, icecream", "929": "ice lolly, lolly, lollipop, popsicle", "930": "French loaf", "931": "bagel, beigel", "932": "pretzel", "933": "cheeseburger", "934": "hotdog, hot dog, red hot", "935": "mashed potato", "936": "head cabbage", "937": "broccoli", "938": "cauliflower", "939": "zucchini, courgette", "940": "spaghetti squash", "941": "acorn squash", "942": "butternut squash", "943": "cucumber, cuke", "944": "artichoke, globe artichoke", "945": "bell pepper", "946": "cardoon", "947": "mushroom", "948": "Granny Smith", "949": "strawberry", "950": "orange", "951": "lemon", "952": "fig", "953": "pineapple, ananas", "954": "banana", "955": "jackfruit, jak, jack", "956": "custard apple", "957": "pomegranate", "958": "hay", "959": "carbonara", "960": "chocolate sauce, chocolate syrup", "961": "dough", "962": "meat loaf, meatloaf", "963": "pizza, pizza pie", "964": "potpie", "965": "burrito", "966": "red wine", "967": "espresso", "968": "cup", "969": "eggnog", "970": "alp", "971": "bubble", "972": "cliff, drop, drop-off", "973": "coral reef", "974": "geyser", "975": "lakeside, lakeshore", "976": "promontory, headland, head, foreland", "977": "sandbar, sand bar", "978": "seashore, coast, seacoast, sea-coast", "979": "valley, vale", "980": "volcano", "981": "ballplayer, baseball player", "982": "groom, bridegroom", "983": "scuba diver", "984": "rapeseed", "985": "daisy", "986": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", "987": "corn", "988": "acorn", "989": "hip, rose hip, rosehip", "990": "buckeye, horse chestnut, conker", "991": "coral fungus", "992": "agaric", "993": "gyromitra", "994": "stinkhorn, carrion fungus", "995": "earthstar", "996": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", "997": "bolete", "998": "ear, spike, capitulum", "999": "toilet tissue, toilet paper, bathroom tissue"}
--------------------------------------------------------------------------------
/featmap/5rd_depthwise_conv_featmap3_7e9ee24563cc31d34de2020e1acaecc5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/featmap/5rd_depthwise_conv_featmap3_7e9ee24563cc31d34de2020e1acaecc5.jpeg
--------------------------------------------------------------------------------
/featmap/5rd_depthwise_conv_featmap4_7e9ee24563cc31d34de2020e1acaecc5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/featmap/5rd_depthwise_conv_featmap4_7e9ee24563cc31d34de2020e1acaecc5.jpeg
--------------------------------------------------------------------------------
/featmap/5rd_depthwise_conv_featmap5_7e9ee24563cc31d34de2020e1acaecc5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/featmap/5rd_depthwise_conv_featmap5_7e9ee24563cc31d34de2020e1acaecc5.jpeg
--------------------------------------------------------------------------------
/fire_smoke_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: fire_smoke_demo.py
5 | @Author: kong
6 | @Time: 2020-01-02 15:45
7 | @Description: single-image fire & smoke classification demo
8 | '''
9 | import json
10 | from PIL import Image, ImageDraw, ImageFont
11 | import torch
12 | from torch import nn
13 | from torchvision import transforms
14 | from efficientnet_pytorch import FireSmokeEfficientNet
15 | import collections
16 | image_dir = './tests/000127.jpg'  # sample image shipped in ./tests
17 | model_para = collections.OrderedDict()
18 | model = FireSmokeEfficientNet.from_arch('efficientnet-b0')
19 | # out_channels = model._fc.in_features
20 | model._fc = nn.Linear(1280, 3)
21 | print(model)
22 | modelpara = torch.load('./checkpoint.pth.tar')
23 | # The checkpoint was saved from a model wrapped in nn.DataParallel, so every
24 | # key carries a 'module.' prefix; len('module.') == 7, hence key[7:] below.
25 | # Stripping the prefix makes the keys match the bare (unwrapped) model.
26 | for key in modelpara['state_dict'].keys():
27 |     model_para[key[7:]] = modelpara['state_dict'][key]
28 |
29 |
30 | # print(model_para.keys())
31 | # trained-model checkpoint conversion (see the loop above)
32 |
33 |
34 |
35 | model.load_state_dict(model_para)
36 |
37 | # Preprocess image
38 | tfms = transforms.Compose([transforms.Resize(224), transforms.ToTensor(),
39 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),])
40 | image = Image.open(image_dir)
41 | img = tfms(image).unsqueeze(0)
42 | print(img.shape) # torch.Size([1, 3, 224, 224])
43 |
44 | # Load ImageNet class names
45 | labels_map = json.load(open('examples/simple/fire_smoke_map.txt'))
46 | labels_map = [labels_map[str(i)] for i in range(3)]
47 |
48 | # Classify
49 | model.eval()
50 | with torch.no_grad():
51 | outputs = model(img)
52 |
53 | draw = ImageDraw.Draw(image)
54 | font = ImageFont.truetype('simkai.ttf', 30)
55 | # Print predictions
56 | print('-----')
57 | cout = 0
58 | for idx in torch.topk(outputs, k=2).indices.squeeze(0).tolist():
59 | cout += 1
60 | prob = torch.softmax(outputs, dim=1)[0, idx].item()
61 | print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob*100))
62 | position = (10, 30*cout - 20)
63 | text = '{label:<5} :{p:.2f}%'.format(label=labels_map[idx], p=prob*100)
64 | draw.text(position, text, font=font, fill="#ff0000", spacing=0, align='left')
65 |
66 | image.save('results/result_{}'.format(image_dir.split('/')[-1]))
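
The key-slicing loop above exists because the checkpoint was saved from a model wrapped in `nn.DataParallel`, which prefixes every state-dict key with `module.`. A slightly more defensive sketch of the same conversion, assuming the `{'state_dict': ...}` checkpoint layout used here (`model` is the network built above):

```python
import collections
import torch

checkpoint = torch.load('./checkpoint.pth.tar', map_location='cpu')
state_dict = collections.OrderedDict(
    # strip the 'module.' prefix only when it is actually present
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in checkpoint['state_dict'].items()
)
model.load_state_dict(state_dict)
```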
--------------------------------------------------------------------------------
/fire_smoke_detection.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: fire_smoke_detection.py
5 | @Author: kong
6 | @Time: 2020-01-03 10:50
7 | @Description: fire & smoke detection with EfficientNet
8 | '''
9 |
10 | import json
11 | from PIL import Image, ImageDraw, ImageFont
12 | import torch
13 | from torchvision import transforms
14 | from efficientnet_pytorch import FireSmokeEfficientNet
15 | import collections
16 |
17 | # from PIL import Image, ImageDraw, ImageFont
18 | image_path = './tests/000127.jpg'
19 | col = 5
20 | row = 4
21 | interCLS = ["smoke","fire"]
22 | model_para = collections.OrderedDict()
23 | model = FireSmokeEfficientNet.from_arch('efficientnet-b0')
24 | # out_channels = model._fc.in_features
25 | model._fc = torch.nn.Linear(1280, 3)
26 | modelpara = torch.load('./checkpoint.pth.tar')
27 | # The checkpoint was saved from a model wrapped in nn.DataParallel, so every
28 | # key carries a 'module.' prefix; key[7:] strips it (len('module.') == 7)
29 | # so the keys match the bare model.
30 | for key in modelpara['state_dict'].keys():
31 |     model_para[key[7:]] = modelpara['state_dict'][key]
32 |
33 |
34 | # print(model_para.keys())
35 | # 训练模型转换
36 | model.load_state_dict(model_para)
37 |
38 | # Preprocess image
39 | tfms = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(),  # each block -> fixed 224x224 input
40 |                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
41 |
42 | # Load the fire/smoke class-index -> name map
43 | labels_map = json.load(open('examples/simple/fire_smoke_map.txt'))
44 | labels_map = [labels_map[str(i)] for i in range(3)]
45 |
46 | image = Image.open(image_path).convert('RGB')  # force 3 channels for the normalizer
47 | width = image.width
48 | height = image.height
49 | w_len = int(width / col)   # width of each grid block
50 | h_len = int(height / row)  # height of each grid block
51 |
52 | draw = ImageDraw.Draw(image)
53 | font = ImageFont.truetype("simkai.ttf", 40, encoding="utf-8")  # font file, size, encoding (assumes SimKai is available locally)
54 | model.eval()  # switch to inference mode once, outside the loop
55 | source = image.copy()  # crop from an unmarked copy so boxes/labels drawn earlier cannot leak into later crops
56 | for r in range(row):
57 |     for c in range(col):
58 |         image_tmp = source.crop((c*w_len, r*h_len, (c+1)*w_len, (r+1)*h_len))
59 |         img_tmp = tfms(image_tmp).unsqueeze(0)
60 |         with torch.no_grad():
61 |             outputs = model(img_tmp)
62 | print('-----')
63 | for idx in torch.topk(outputs, k=1).indices.squeeze(0).tolist():
64 | prob = torch.softmax(outputs, dim=1)[0, idx].item()
65 | print('{label:<75} ({p:.2f}%)'.format(label=labels_map[idx], p=prob * 100))
66 |             # flag only very confident fire/smoke blocks
67 |             if prob > 0.99 and labels_map[idx] in interCLS:
68 |                 draw.rectangle([c*w_len, r*h_len, (c+1)*w_len, (r+1)*h_len], outline=(255, 0, 0), width=2)
69 |                 draw.text(((c+1)*w_len, r*h_len), labels_map[idx], (255, 0, 0), font=font)  # position, text, color, font
70 |
71 | image.save("results/det_results{}".format(image_path.split('/')[-1]))
72 |
73 |
--------------------------------------------------------------------------------
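
fire_smoke_detection.py gets coarse localization from a plain classifier by slicing the image into a row x col grid and classifying each block independently. The crop arithmetic, isolated as a standalone sketch (grid_boxes is a hypothetical helper, not part of the repo):

def grid_boxes(width, height, cols, rows):
    # Yield PIL-style (left, upper, right, lower) boxes for a cols x rows grid.
    w_len, h_len = width // cols, height // rows
    for r in range(rows):
        for c in range(cols):
            yield (c * w_len, r * h_len, (c + 1) * w_len, (r + 1) * h_len)

# A 1000x800 image with col=5, row=4 yields twenty 200x200 blocks:
# the first box is (0, 0, 200, 200), the last (800, 600, 1000, 800).

Because width // cols rounds down, up to cols-1 pixels on the right edge (and rows-1 at the bottom) fall outside every block; the script simply accepts that small loss.
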
/model_vusal.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: model_vusal.py
5 | @Author:kong
6 | @Time: 2020-01-07 19:17
7 | @Description: visualize intermediate feature maps with a forward hook
8 | '''
9 |
10 | import matplotlib.pyplot as plt
11 | import torch
12 | from torchvision import models, transforms
13 | from torch.utils.data import DataLoader, Dataset
14 | import torch.nn as nn
15 | from torch.autograd import Variable
16 | import torch.nn.functional as F
17 | import os
18 | import numpy as np
19 | from torchvision.datasets import ImageFolder
20 |
21 | torch.cuda.set_device(0)  # select GPU 0
22 | is_cuda = True
23 | simple_transform = transforms.Compose([transforms.Resize((224, 224)),
24 |                                        transforms.ToTensor(),  # H,W,C uint8 -> C,H,W float in [0,1] (divides by 255)
25 |                                        transforms.Normalize([0.485, 0.456, 0.406],  # mean
26 |                                                             [0.229, 0.224, 0.225])])  # std
27 | 
28 | # ToTensor scales the input to [0,1]; Normalize then applies (x - mean) / std per channel
29 | # ImageFolder expects one subdirectory per class under the given root
30 | train = ImageFolder("/home/kong/Documents/EfficientNet-PyTorch/cropdata/train", simple_transform)
31 | valid = ImageFolder("/home/kong/Documents/EfficientNet-PyTorch/cropdata/val", simple_transform)
32 | train_loader = DataLoader(train, batch_size=1, shuffle=False, num_workers=5)
33 | val_loader = DataLoader(valid, batch_size=1, shuffle=False, num_workers=5)
34 |
35 | net = models.mobilenet_v2(pretrained=True).cuda()  # pretrained MobileNetV2 backbone used for visualization
36 |
37 |
38 | # Core helper: grab the output of an intermediate layer via a forward hook
39 | class LayerActivations:
40 | features = None
41 |
42 | def __init__(self, model, layer_num):
43 | self.hook = model[layer_num].register_forward_hook(self.hook_fn)
44 |
45 | def hook_fn(self, module, input, output):
46 | self.features = output.cpu()
47 |
48 | def remove(self):
49 | self.hook.remove()
50 |
51 |
52 | conv_out = LayerActivations(net.features, 18)  # hook the output of features[18], the last conv stage
53 | img = next(iter(train_loader))[0]
54 | o = net(img.cuda())  # the forward pass fires the hook (the Variable wrapper is deprecated and unneeded)
55 | conv_out.remove()  # detach the hook
56 | act = conv_out.features  # the feature maps captured from features[18]
57 |
58 | # visualize the captured feature maps
59 | fig = plt.figure(figsize=(20, 50))
60 | fig.subplots_adjust(left=0, right=1, bottom=0, top=0.8, hspace=0, wspace=0.2)
61 | for i in range(30):  # show the first 30 channels of the captured maps
62 | ax = fig.add_subplot(12, 5, i + 1, xticks=[], yticks=[])
63 | ax.imshow(act[0][i].detach().numpy(), cmap="gray")
64 | plt.show()
--------------------------------------------------------------------------------
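
The LayerActivations helper above is a thin wrapper around PyTorch's register_forward_hook, which invokes a callback with (module, input, output) each time the hooked module runs. A self-contained sketch of the same mechanism; the shape in the final comment assumes torchvision's stock MobileNetV2 with a 224x224 input:

import torch
from torchvision import models

net = models.mobilenet_v2(pretrained=True).eval()
captured = {}

def hook_fn(module, inputs, output):
    # stash a detached CPU copy of the activation
    captured['feat'] = output.detach().cpu()

handle = net.features[18].register_forward_hook(hook_fn)
with torch.no_grad():
    net(torch.randn(1, 3, 224, 224))
handle.remove()  # a forgotten hook keeps firing (and holding tensors) on every later forward
print(captured['feat'].shape)  # torch.Size([1, 1280, 7, 7]) for features[18]
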
/results/acc_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/results/acc_loss.png
--------------------------------------------------------------------------------
/results/det_results000127.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/results/det_results000127.jpg
--------------------------------------------------------------------------------
/results/result_000127.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/results/result_000127.jpg
--------------------------------------------------------------------------------
/results/result_7e9ee24563cc31d34de2020e1acaecc5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/results/result_7e9ee24563cc31d34de2020e1acaecc5.jpeg
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Note: To use the 'upload' functionality of this file, you must:
5 | # $ pipenv install twine --dev
6 |
7 | import io
8 | import os
9 | import sys
10 | from shutil import rmtree
11 |
12 | from setuptools import find_packages, setup, Command
13 |
14 | # Package meta-data.
15 | NAME = 'efficientnet_pytorch'
16 | DESCRIPTION = 'EfficientNet implemented in PyTorch.'
17 | URL = 'https://github.com/lukemelas/efficientnet_pytorch'
18 | EMAIL = 'lmelaskyriazi@college.harvard.edu'
19 | AUTHOR = 'Luke'
20 | REQUIRES_PYTHON = '>=3.5.0'
21 | VERSION = '0.5.1'
22 |
23 | # What packages are required for this module to be executed?
24 | REQUIRED = [
25 | 'torch'
26 | ]
27 |
28 | # What packages are optional?
29 | EXTRAS = {
30 | # 'fancy feature': ['django'],
31 | }
32 |
33 | # The rest you shouldn't have to touch too much :)
34 | # ------------------------------------------------
35 | # Except, perhaps the License and Trove Classifiers!
36 | # If you do change the License, remember to change the Trove Classifier for that!
37 |
38 | here = os.path.abspath(os.path.dirname(__file__))
39 |
40 | # Import the README and use it as the long-description.
41 | # Note: this will only work if 'README.md' is present in your MANIFEST.in file!
42 | try:
43 | with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
44 | long_description = '\n' + f.read()
45 | except FileNotFoundError:
46 | long_description = DESCRIPTION
47 |
48 | # Load the package's __version__.py module as a dictionary.
49 | about = {}
50 | if not VERSION:
51 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_")
52 | with open(os.path.join(here, project_slug, '__version__.py')) as f:
53 | exec(f.read(), about)
54 | else:
55 | about['__version__'] = VERSION
56 |
57 |
58 | class UploadCommand(Command):
59 | """Support setup.py upload."""
60 |
61 | description = 'Build and publish the package.'
62 | user_options = []
63 |
64 | @staticmethod
65 | def status(s):
66 | """Prints things in bold."""
67 | print('\033[1m{0}\033[0m'.format(s))
68 |
69 | def initialize_options(self):
70 | pass
71 |
72 | def finalize_options(self):
73 | pass
74 |
75 | def run(self):
76 | try:
77 | self.status('Removing previous builds…')
78 | rmtree(os.path.join(here, 'dist'))
79 | except OSError:
80 | pass
81 |
82 | self.status('Building Source and Wheel (universal) distribution…')
83 | os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable))
84 |
85 | self.status('Uploading the package to PyPI via Twine…')
86 | os.system('twine upload dist/*')
87 |
88 | self.status('Pushing git tags…')
89 | os.system('git tag v{0}'.format(about['__version__']))
90 | os.system('git push --tags')
91 |
92 | sys.exit()
93 |
94 |
95 | # Where the magic happens:
96 | setup(
97 | name=NAME,
98 | version=about['__version__'],
99 | description=DESCRIPTION,
100 | long_description=long_description,
101 | long_description_content_type='text/markdown',
102 | author=AUTHOR,
103 | author_email=EMAIL,
104 | python_requires=REQUIRES_PYTHON,
105 | url=URL,
106 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
107 | # py_modules=['model'], # If your package is a single module, use this instead of 'packages'
108 | install_requires=REQUIRED,
109 | extras_require=EXTRAS,
110 | include_package_data=True,
111 | license='Apache',
112 | classifiers=[
113 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
114 | 'License :: OSI Approved :: Apache Software License',
115 | 'Programming Language :: Python',
116 | 'Programming Language :: Python :: 3',
117 | 'Programming Language :: Python :: 3.6',
118 | ],
119 | # $ setup.py publish support.
120 | cmdclass={
121 | 'upload': UploadCommand,
122 | },
123 | )
124 |
--------------------------------------------------------------------------------
/tests/000127.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/tests/000127.jpg
--------------------------------------------------------------------------------
/tests/7e9ee24563cc31d34de2020e1acaecc5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/FireSmokeDetectionByEfficientNet/c8fc3a1acfca121e4e3ffc3ee9a8ac9060264c42/tests/7e9ee24563cc31d34de2020e1acaecc5.jpeg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # -*- coding: utf-8 -*-
3 | '''
4 | @File: train.py
5 | @Author:kong
6 | @Time: 2020-01-02 13:47
7 | @Description: training script for the fire/smoke EfficientNet classifier
8 | '''
9 | import argparse
10 | import os
11 | import random
12 | import shutil
13 | import time
14 | import warnings
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.parallel
19 | import torch.backends.cudnn as cudnn
20 | import torch.distributed as dist
21 | import torch.optim
22 | import torch.multiprocessing as mp
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 | # import torchvision.models as models
28 | # torchvision.transforms is already imported above as `transforms`
29 | import matplotlib.pyplot as plt
30 | from efficientnet_pytorch import FireSmokeEfficientNet
31 |
32 | parser = argparse.ArgumentParser(description='PyTorch Fire and Smoke Training with EfficientNet')
33 | parser.add_argument('-da','--data', metavar='DIR',default="./cropdata",
34 | help='path to dataset')
35 | parser.add_argument('-a', '--arch', metavar='ARCH', default='efficientnet-b0',
36 |                     help='model architecture (default: efficientnet-b0)')
37 | parser.add_argument('--num_cls', default=3, type=int, help='number of classes in the classification task')
38 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
39 | help='number of data loading workers (default: 4)')
40 | parser.add_argument('--epochs', default=25, type=int, metavar='N',
41 | help='number of total epochs to run')
42 | parser.add_argument('--reduce', default=[15, 20], nargs='+', type=int, help='epochs at which the learning rate decays by 10x')
43 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
44 | help='manual epoch number (useful on restarts)')
45 | parser.add_argument('-b', '--batch-size', default=64, type=int,
46 | metavar='N',
47 |                     help='mini-batch size (default: 64), this is the total '
48 | 'batch size of all GPUs on the current node when '
49 | 'using Data Parallel or Distributed Data Parallel')
50 | parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
51 | metavar='LR', help='initial learning rate', dest='lr')
52 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
53 | help='momentum')
54 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
55 | metavar='W', help='weight decay (default: 1e-4)',
56 | dest='weight_decay')
57 | parser.add_argument('-p', '--print-freq', default=16, type=int,
58 |                     metavar='N', help='print frequency (default: 16)')
59 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
60 | help='path to latest checkpoint (default: none)')
61 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
62 | help='evaluate model on validation set')
63 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
64 | help='use pre-trained model')
65 | parser.add_argument('--world-size', default=-1, type=int,
66 | help='number of nodes for distributed training')
67 | parser.add_argument('--rank', default=-1, type=int,
68 | help='node rank for distributed training')
69 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
70 | help='url used to set up distributed training')
71 | parser.add_argument('--dist-backend', default='nccl', type=str,
72 | help='distributed backend')
73 | parser.add_argument('--seed', default=None, type=int,
74 | help='seed for initializing training. ')
75 | parser.add_argument('--gpu', default=None, type=int,
76 | help='GPU id to use.')
77 | parser.add_argument('--multiprocessing-distributed', action='store_true',
78 | help='Use multi-processing distributed training to launch '
79 | 'N processes per node, which has N GPUs. This is the '
80 | 'fastest way to use PyTorch for either single node or '
81 | 'multi node data parallel training')
82 |
83 | best_acc1 = 0
84 |
85 | def main():
86 | args = parser.parse_args()
87 |
88 | if args.seed is not None:
89 | random.seed(args.seed)
90 | torch.manual_seed(args.seed)
91 | cudnn.deterministic = True
92 | warnings.warn('You have chosen to seed training. '
93 | 'This will turn on the CUDNN deterministic setting, '
94 | 'which can slow down your training considerably! '
95 | 'You may see unexpected behavior when restarting '
96 | 'from checkpoints.')
97 |
98 | if args.gpu is not None:
99 | warnings.warn('You have chosen a specific GPU. This will completely '
100 | 'disable data parallelism.')
101 |
102 | if args.dist_url == "env://" and args.world_size == -1:
103 | args.world_size = int(os.environ["WORLD_SIZE"])
104 |
105 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
106 |
107 | ngpus_per_node = torch.cuda.device_count()
108 | if args.multiprocessing_distributed:
109 | # Since we have ngpus_per_node processes per node, the total world_size
110 | # needs to be adjusted accordingly
111 | args.world_size = ngpus_per_node * args.world_size
112 | # Use torch.multiprocessing.spawn to launch distributed processes: the
113 | # main_worker process function
114 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
115 | else:
116 | # Simply call main_worker function
117 | main_worker(args.gpu, ngpus_per_node, args)
118 |
119 | def get_current_lr(epoch, args):
120 | lr = args.lr
121 |     for lr_decay_epoch in args.reduce:
122 | if epoch >= lr_decay_epoch:
123 | lr *= 0.1
124 | return lr
125 |
126 | def adjust_learning_rate(optimizer, epoch, args):
127 | lr = get_current_lr(epoch, args)
128 | print("current lr is:", lr)
129 | for param_group in optimizer.param_groups:
130 | param_group['lr'] = lr
131 |
132 |
133 | def main_worker(gpu, ngpus_per_node, args):
134 | global best_acc1
135 | args.gpu = gpu
136 |
137 | if args.gpu is not None:
138 | print("Use GPU: {} for training".format(args.gpu))
139 |
140 | if args.distributed:
141 | if args.dist_url == "env://" and args.rank == -1:
142 | args.rank = int(os.environ["RANK"])
143 | if args.multiprocessing_distributed:
144 | # For multiprocessing distributed training, rank needs to be the
145 | # global rank among all the processes
146 | args.rank = args.rank * ngpus_per_node + gpu
147 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
148 | world_size=args.world_size, rank=args.rank)
149 | # create model
150 | if args.pretrained:
151 | print("=> using pre-trained model '{}'".format(args.arch))
152 | model = FireSmokeEfficientNet.from_pretrained(args)
153 | print(model)
154 | else:
155 | print("=> creating model '{}'".format(args.arch))
156 |         model = FireSmokeEfficientNet.from_pretrained(args)  # note: this branch also starts from pretrained weights; there is no from-scratch path
157 | print(model)
158 |
159 | if args.distributed:
160 | # For multiprocessing distributed, DistributedDataParallel constructor
161 | # should always set the single device scope, otherwise,
162 | # DistributedDataParallel will use all available devices.
163 | if args.gpu is not None:
164 | torch.cuda.set_device(args.gpu)
165 | model.cuda(args.gpu)
166 | # When using a single GPU per process and per
167 | # DistributedDataParallel, we need to divide the batch size
168 | # ourselves based on the total number of GPUs we have
169 | args.batch_size = int(args.batch_size / ngpus_per_node)
170 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
171 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
172 | else:
173 | model.cuda()
174 | # DistributedDataParallel will divide and allocate batch_size to all
175 | # available GPUs if device_ids are not set
176 | model = torch.nn.parallel.DistributedDataParallel(model)
177 | elif args.gpu is not None:
178 | torch.cuda.set_device(args.gpu)
179 | model = model.cuda(args.gpu)
180 | else:
181 | # DataParallel will divide and allocate batch_size to all available GPUs
182 | if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
183 | model.features = torch.nn.DataParallel(model.features)
184 | model.cuda()
185 | else:
186 | model = torch.nn.DataParallel(model).cuda()
187 |
188 | # define loss function (criterion) and optimizer
189 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
190 | #
191 | # optimizer = torch.optim.SGD(model.parameters(), args.lr,
192 | # momentum=args.momentum,
193 | # weight_decay=args.weight_decay)
194 | optimizer = torch.optim.Adam(model.parameters(), args.lr)
195 | # optionally resume from a checkpoint
196 | if args.resume:
197 | if os.path.isfile(args.resume):
198 | print("=> loading checkpoint '{}'".format(args.resume))
199 | if args.gpu is None:
200 | checkpoint = torch.load(args.resume)
201 | else:
202 | # Map model to be loaded to specified single gpu.
203 | loc = 'cuda:{}'.format(args.gpu)
204 | checkpoint = torch.load(args.resume, map_location=loc)
205 | args.start_epoch = checkpoint['epoch']
206 | best_acc1 = checkpoint['best_acc1']
207 | if args.gpu is not None:
208 | # best_acc1 may be from a checkpoint from a different GPU
209 | best_acc1 = best_acc1.to(args.gpu)
210 | model.load_state_dict(checkpoint['state_dict'])
211 | optimizer.load_state_dict(checkpoint['optimizer'])
212 | print("=> loaded checkpoint '{}' (epoch {})"
213 | .format(args.resume, checkpoint['epoch']))
214 | else:
215 | print("=> no checkpoint found at '{}'".format(args.resume))
216 |
217 | cudnn.benchmark = True
218 |
219 | # Data loading code
220 | traindir = os.path.join(args.data, 'train')
221 | valdir = os.path.join(args.data, 'val')
222 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
223 | std=[0.229, 0.224, 0.225])
224 |
225 | train_dataset = datasets.ImageFolder(
226 | traindir,
227 | transforms.Compose([
228 | transforms.RandomResizedCrop(224),
229 | transforms.RandomHorizontalFlip(),
230 | transforms.ToTensor(),
231 | normalize,
232 | ]))
233 |
234 | if args.distributed:
235 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
236 | else:
237 | train_sampler = None
238 |
239 | train_loader = torch.utils.data.DataLoader(
240 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
241 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
242 |
243 | val_loader = torch.utils.data.DataLoader(
244 | datasets.ImageFolder(valdir, transforms.Compose([
245 | transforms.Resize(256),
246 | transforms.CenterCrop(224),
247 | transforms.ToTensor(),
248 | normalize,
249 | ])),
250 | batch_size=args.batch_size, shuffle=False,
251 | num_workers=args.workers, pin_memory=True)
252 |
253 | if args.evaluate:
254 | validate(val_loader, model, criterion, args)
255 | return
256 | loss_all = []
257 | acc_all = []
258 | for epoch in range(args.start_epoch, args.epochs):
259 | if args.distributed:
260 | train_sampler.set_epoch(epoch)
261 | adjust_learning_rate(optimizer, epoch, args)
262 |
263 | # train for one epoch
264 | loss_tmp, acc_tmp = train(train_loader, model, criterion, optimizer, epoch, args)
265 | loss_all.extend(loss_tmp)
266 | acc_all.extend(acc_tmp)
267 | # evaluate on validation set
268 | acc1 = validate(val_loader, model, criterion, args)
269 | # remember best acc@1 and save checkpoint
270 | is_best = acc1 > best_acc1
271 | best_acc1 = max(acc1, best_acc1)
272 |
273 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
274 | and args.rank % ngpus_per_node == 0):
275 | save_checkpoint({
276 | 'epoch': epoch + 1,
277 | 'arch': args.arch,
278 | 'state_dict': model.state_dict(),
279 | 'best_acc1': best_acc1,
280 | 'optimizer' : optimizer.state_dict(),
281 | }, is_best)
282 | x1 = list(range(len(acc_all)))
283 | x2 = list(range(len(loss_all)))
284 | y1 = acc_all
285 | y2 = loss_all
286 | plt.subplot(2, 1, 1)
287 | # plt.plot(x1, y1, 'o-',color='r')
288 | plt.plot(x1, y1, 'o-', label="Train_Accuracy")
289 | plt.title('train acc vs. iter')
290 | plt.ylabel('train accuracy')
291 | plt.legend(loc='best')
292 | plt.subplot(2, 1, 2)
293 | plt.plot(x2, y2, '.-', label="Train_Loss")
294 | plt.xlabel('train loss vs. iter')
295 | plt.ylabel('train loss')
296 | plt.legend(loc='best')
297 | plt.savefig("acc_loss.png")
298 | plt.show()
299 |
300 |
301 | def train(train_loader, model, criterion, optimizer, epoch, args):
302 | batch_time = AverageMeter('Time', ':6.3f')
303 | data_time = AverageMeter('Data', ':6.3f')
304 | losses = AverageMeter('Loss', ':.4e')
305 | top1 = AverageMeter('Acc@1', ':6.2f')
306 | top3 = AverageMeter('Acc@3', ':6.2f')
307 | loss_all = []
308 | acc_all = []
309 | progress = ProgressMeter(
310 | len(train_loader),
311 | [batch_time, data_time, losses, top1, top3],
312 | prefix="Epoch: [{}]".format(epoch))
313 |
314 | # switch to train mode
315 | model.train()
316 | end = time.time()
317 | loss_tp = []
318 | acc_tp = []
319 | for i, (images, target) in enumerate(train_loader):
320 | # measure data loading time
321 | data_time.update(time.time() - end)
322 | if args.gpu is not None:
323 | images = images.cuda(args.gpu, non_blocking=True)
324 | target = target.cuda(args.gpu, non_blocking=True)
325 |
326 | # compute output
327 | output = model(images)
328 | loss = criterion(output, target)
329 | # measure accuracy and record loss
330 | acc1, acc3 = accuracy(output, target, topk=(1, 3))
331 | losses.update(loss.item(), images.size(0))
332 | top1.update(acc1[0], images.size(0))
333 | top3.update(acc3[0], images.size(0))
334 | loss_tp.append(loss.item())
335 |         acc_tp.append(acc1[0].item())  # .item() so the value can be averaged/plotted off-GPU
336 | if i % 500 == 0:
337 | loss_all.append(sum(loss_tp) / (len(loss_tp)))
338 | loss_tp = []
339 | acc_all.append(sum(acc_tp) / (len(acc_tp)))
340 | acc_tp = []
341 |
342 | # compute gradient and do SGD step
343 | optimizer.zero_grad()
344 | loss.backward()
345 | optimizer.step()
346 | loss_all.append(loss.item())
347 |         acc_all.append(acc1[0].item())
348 |
349 | # measure elapsed time
350 | batch_time.update(time.time() - end)
351 | end = time.time()
352 |
353 | if i % args.print_freq == 0:
354 | progress.display(i)
355 | return loss_all , acc_all
356 |
357 |
358 | def validate(val_loader, model, criterion, args):
359 | batch_time = AverageMeter('Time', ':6.3f')
360 | losses = AverageMeter('Loss', ':.4e')
361 | top1 = AverageMeter('Acc@1', ':6.2f')
362 | top3 = AverageMeter('Acc@3', ':6.2f')
363 | progress = ProgressMeter(
364 | len(val_loader),
365 | [batch_time, losses, top1, top3],
366 | prefix='Test: ')
367 |
368 | # switch to evaluate mode
369 | model.eval()
370 |
371 | with torch.no_grad():
372 | end = time.time()
373 | for i, (images, target) in enumerate(val_loader):
374 | if args.gpu is not None:
375 | images = images.cuda(args.gpu, non_blocking=True)
376 | target = target.cuda(args.gpu, non_blocking=True)
377 |
378 | # compute output
379 | output = model(images)
380 | loss = criterion(output, target)
381 |
382 | # measure accuracy and record loss
383 | acc1, acc3 = accuracy(output, target, topk=(1, 3))
384 | losses.update(loss.item(), images.size(0))
385 | top1.update(acc1[0], images.size(0))
386 | top3.update(acc3[0], images.size(0))
387 |
388 | # measure elapsed time
389 | batch_time.update(time.time() - end)
390 | end = time.time()
391 |
392 | if i % args.print_freq == 0:
393 | progress.display(i)
394 |
395 | # TODO: this should also be done with the ProgressMeter
396 |         print(' * Acc@1 {top1.avg:.3f} Acc@3 {top3.avg:.3f}'
397 |               .format(top1=top1, top3=top3))
398 |
399 | return top1.avg
400 |
401 |
402 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
403 | torch.save(state, filename)
404 | if is_best:
405 | shutil.copyfile(filename, 'model_best.pth.tar')
406 |
407 |
408 | class AverageMeter(object):
409 | """Computes and stores the average and current value"""
410 | def __init__(self, name, fmt=':f'):
411 | self.name = name
412 | self.fmt = fmt
413 | self.reset()
414 |
415 | def reset(self):
416 | self.val = 0
417 | self.avg = 0
418 | self.sum = 0
419 | self.count = 0
420 |
421 | def update(self, val, n=1):
422 | self.val = val
423 | self.sum += val * n
424 | self.count += n
425 | self.avg = self.sum / self.count
426 |
427 | def __str__(self):
428 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
429 | return fmtstr.format(**self.__dict__)
430 |
431 |
432 | class ProgressMeter(object):
433 | def __init__(self, num_batches, meters, prefix=""):
434 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
435 | self.meters = meters
436 | self.prefix = prefix
437 |
438 | def display(self, batch):
439 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
440 | entries += [str(meter) for meter in self.meters]
441 | print('\t'.join(entries))
442 |
443 | def _get_batch_fmtstr(self, num_batches):
444 |         num_digits = len(str(num_batches))
445 | fmt = '{:' + str(num_digits) + 'd}'
446 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
447 |
448 |
455 |
456 | def accuracy(output, target, topk=(1,)):
457 | """Computes the accuracy over the k top predictions for the specified values of k"""
458 | with torch.no_grad():
459 | maxk = max(topk)
460 | batch_size = target.size(0)
461 |
462 | _, pred = output.topk(maxk, 1, True, True)
463 | pred = pred.t()
464 | correct = pred.eq(target.view(1, -1).expand_as(pred))
465 |
466 | res = []
467 | for k in topk:
468 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: the slice of a transposed tensor may be non-contiguous
469 | res.append(correct_k.mul_(100.0 / batch_size))
470 | return res
471 |
472 |
473 | if __name__ == '__main__':
474 | main()
--------------------------------------------------------------------------------
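
get_current_lr in train.py implements a plain step schedule: the base learning rate is multiplied by 0.1 once for every milestone in --reduce that the current epoch has reached. The arithmetic in isolation (step_lr is an illustrative name; the milestones mirror the script's default [15, 20]):

def step_lr(base_lr, epoch, milestones=(15, 20), gamma=0.1):
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr

print(step_lr(0.001, 0))   # 0.001  -- before any milestone
print(step_lr(0.001, 15))  # ~1e-04 -- decayed once at epoch 15
print(step_lr(0.001, 20))  # ~1e-05 -- decayed twice at epoch 20

torch.optim.lr_scheduler.MultiStepLR provides the same behavior out of the box and could replace the hand-rolled version.
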