├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
├── fsa1x1-08082020.chkpt
└── fsavar-09082020.chkpt
├── data
└── .gitignore
├── extras
└── headpose-demo.gif
├── pretrained
├── fsanet-1x1-iter-688590.onnx
├── fsanet-var-iter-688590.onnx
├── res10_300x300_ssd_iter_140000.caffemodel
└── resnet10_ssd.prototxt
├── requirements.txt
└── src
├── 1-Explore Dataset.ipynb
├── 2-Train Model.ipynb
├── 3-Test Model.ipynb
├── 4-Export to Onnx.ipynb
├── dataset.py
├── demo.py
├── face_detector.py
├── model.py
├── transforms.py
├── utils.py
└── zoom_transform.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # Custom Folders
132 | **/others/
133 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Omar Hassan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # headpose-fsanet-pytorch
2 | PyTorch implementation of FSA-Net: Learning Fine-Grained Structure Aggregation for Head Pose Estimation from a Single Image[2](#references).
3 |
4 | ## Demo
5 | 
6 |
7 | A video file or a camera index can be provided to the demo script. If no argument is provided, the default camera index is used.
8 |
9 | ### Video File Usage
10 |
11 | For any video format that OpenCV supports (`mp4`, `avi`, etc.):
12 |
13 | ```bash
14 | python3 demo.py --video /path/to/video.mp4
15 | ```
16 |
17 | ### Camera Usage
18 |
19 | ```bash
20 | python3 demo.py --cam 0
21 | ```
22 |
23 | ## Results
24 |
25 | | Model | Dataset Type | Yaw (MAE) | Pitch (MAE) | Roll (MAE) |
26 | | --- | --- | --- | --- | --- |
27 | | FSA-Caps (1x1) | 1 | 4.85 | 6.27 | 4.96 |
28 | | FSA-Caps (Var) | 1 | 5.06 | 6.46 | 5.00 |
29 | | FSA-Caps (1x1 + Var) | 1 | **4.64** | **6.10** | **4.79** |
30 |
31 | **Note:** My results are slightly worse than the original author's results. For the best results, please refer to the official repository[1](#acknowledgements).
32 |
33 |
34 | ## Dependencies
35 |
36 | ```
37 | Name Version
38 | python 3.7.6
39 | numpy 1.18.5
40 | opencv 4.2.0
41 | scipy 1.5.0
42 | matplotlib-base 3.2.2
43 | pytorch 1.5.1
44 | torchvision 0.6.1
45 | onnx 1.7.0
46 | onnxruntime 1.2.0
47 | ```
48 |
49 |
50 | Installation with pip
51 | ```bash
52 | pip3 install -r requirements.txt
53 | ```
54 |
55 |
56 | You may also need to install Jupyter to open the notebooks (.ipynb). It is recommended that you use Anaconda to install packages.
57 |
58 | The code has been tested on Ubuntu 18.04.
59 |
60 | ## Important Files Overview
61 |
62 | - **src/dataset.py:** Our PyTorch dataset class is defined here.
63 | - **src/model.py:** The PyTorch FSA-Net model is defined here.
64 | - **src/transforms.py:** Augmentation Transforms are defined here
65 | - **src/1-Explore Dataset.ipynb:** To explore training data, refer to this notebook
66 | - **src/2-Train Model.ipynb:** For model training, refer to this notebook
67 | - **src/3-Test Model.ipynb:** For model testing, refer to this notebook
68 | - **src/4-Export to Onnx.ipynb:** For exporting model, refer to this notebook
69 | - **src/demo.py:** Demo script is defined here
70 |
71 | ## Download Dataset
72 | For model training and testing, download the preprocessed dataset from the author's official Git repository[1](#acknowledgements) and place it inside the `data/` directory. I am only using type1 data for training and testing. Your dataset hierarchy should look like:
73 |
74 | ```
75 | data/
76 | type1/
77 | test/
78 | AFLW2000.npz
79 | train/
80 | AFW.npz
81 | AFW_Flip.npz
82 | HELEN.npz
83 | HELEN_Flip.npz
84 | IBUG.npz
85 | IBUG_Flip.npz
86 | LFPW.npz
87 | LFPW_Flip.npz
88 | ```
89 |
90 | ## License
91 | Copyright (c) 2020, Omar Hassan. (MIT License)
92 |
93 | ## Acknowledgements
94 | Special thanks to Mr. Tsun-Yi Yang for providing excellent code for his paper. Please refer to the official repository for detailed information and the best results for the model:
95 |
96 | \[1] T. Yang, FSA-Net, (2019), [GitHub repository](https://github.com/shamangary/FSA-Net)
97 |
98 | The models are trained and tested with various public datasets, which have their own licenses. Please refer to them before using the code:
99 |
100 | - 300W-LP: http://www.cbsr.ia.ac.cn/users/xiangyuzhu/projects/3DDFA/main.htm
101 | - LFPW: https://neerajkumar.org/databases/lfpw/
102 | - HELEN: http://www.ifp.illinois.edu/~vuongle2/helen/
103 | - AFW: https://www.ics.uci.edu/~xzhu/face/
104 | - IBUG: https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/
105 | - AFLW2000: http://cvlab.cse.msu.edu/lfw-and-aflw2000-datasets.html
106 |
107 | ## References
108 | \[2] T. Yang, Y. Chen, Y. Lin and Y. Chuang, "FSA-Net: Learning Fine-Grained Structure Aggregation for Head Pose Estimation From a Single Image," 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), Long Beach, CA, USA, 2019, pp. 1087-1096, doi: 10.1109/CVPR.2019.00118. [IEEE-Xplore link](https://ieeexplore.ieee.org/document/8954346)
109 |
110 | \[3] Tal Hassner, Shai Harel, Eran Paz, and Roee Enbar. Effective face frontalization in unconstrained images. In CVPR, 2015
111 |
112 | \[4] Xiangyu Zhu, Zhen Lei, Junjie Yan, Dong Yi, and Stan Z. Li. High-fidelity pose and expression normalization for face recognition in the wild. In CVPR, 2015.
113 |
--------------------------------------------------------------------------------
/checkpoints/fsa1x1-08082020.chkpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omasaht/headpose-fsanet-pytorch/002549c981b607f722e0810d999afcde7e5aaa66/checkpoints/fsa1x1-08082020.chkpt
--------------------------------------------------------------------------------
/checkpoints/fsavar-09082020.chkpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omasaht/headpose-fsanet-pytorch/002549c981b607f722e0810d999afcde7e5aaa66/checkpoints/fsavar-09082020.chkpt
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 |
--------------------------------------------------------------------------------
/extras/headpose-demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omasaht/headpose-fsanet-pytorch/002549c981b607f722e0810d999afcde7e5aaa66/extras/headpose-demo.gif
--------------------------------------------------------------------------------
/pretrained/fsanet-1x1-iter-688590.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omasaht/headpose-fsanet-pytorch/002549c981b607f722e0810d999afcde7e5aaa66/pretrained/fsanet-1x1-iter-688590.onnx
--------------------------------------------------------------------------------
/pretrained/fsanet-var-iter-688590.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omasaht/headpose-fsanet-pytorch/002549c981b607f722e0810d999afcde7e5aaa66/pretrained/fsanet-var-iter-688590.onnx
--------------------------------------------------------------------------------
/pretrained/res10_300x300_ssd_iter_140000.caffemodel:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/omasaht/headpose-fsanet-pytorch/002549c981b607f722e0810d999afcde7e5aaa66/pretrained/res10_300x300_ssd_iter_140000.caffemodel
--------------------------------------------------------------------------------
/pretrained/resnet10_ssd.prototxt:
--------------------------------------------------------------------------------
1 | input: "data"
2 | input_shape {
3 | dim: 1
4 | dim: 3
5 | dim: 300
6 | dim: 300
7 | }
8 |
9 | layer {
10 | name: "data_bn"
11 | type: "BatchNorm"
12 | bottom: "data"
13 | top: "data_bn"
14 | param {
15 | lr_mult: 0.0
16 | }
17 | param {
18 | lr_mult: 0.0
19 | }
20 | param {
21 | lr_mult: 0.0
22 | }
23 | }
24 | layer {
25 | name: "data_scale"
26 | type: "Scale"
27 | bottom: "data_bn"
28 | top: "data_bn"
29 | param {
30 | lr_mult: 1.0
31 | decay_mult: 1.0
32 | }
33 | param {
34 | lr_mult: 2.0
35 | decay_mult: 1.0
36 | }
37 | scale_param {
38 | bias_term: true
39 | }
40 | }
41 | layer {
42 | name: "conv1_h"
43 | type: "Convolution"
44 | bottom: "data_bn"
45 | top: "conv1_h"
46 | param {
47 | lr_mult: 1.0
48 | decay_mult: 1.0
49 | }
50 | param {
51 | lr_mult: 2.0
52 | decay_mult: 1.0
53 | }
54 | convolution_param {
55 | num_output: 32
56 | pad: 3
57 | kernel_size: 7
58 | stride: 2
59 | weight_filler {
60 | type: "msra"
61 | variance_norm: FAN_OUT
62 | }
63 | bias_filler {
64 | type: "constant"
65 | value: 0.0
66 | }
67 | }
68 | }
69 | layer {
70 | name: "conv1_bn_h"
71 | type: "BatchNorm"
72 | bottom: "conv1_h"
73 | top: "conv1_h"
74 | param {
75 | lr_mult: 0.0
76 | }
77 | param {
78 | lr_mult: 0.0
79 | }
80 | param {
81 | lr_mult: 0.0
82 | }
83 | }
84 | layer {
85 | name: "conv1_scale_h"
86 | type: "Scale"
87 | bottom: "conv1_h"
88 | top: "conv1_h"
89 | param {
90 | lr_mult: 1.0
91 | decay_mult: 1.0
92 | }
93 | param {
94 | lr_mult: 2.0
95 | decay_mult: 1.0
96 | }
97 | scale_param {
98 | bias_term: true
99 | }
100 | }
101 | layer {
102 | name: "conv1_relu"
103 | type: "ReLU"
104 | bottom: "conv1_h"
105 | top: "conv1_h"
106 | }
107 | layer {
108 | name: "conv1_pool"
109 | type: "Pooling"
110 | bottom: "conv1_h"
111 | top: "conv1_pool"
112 | pooling_param {
113 | kernel_size: 3
114 | stride: 2
115 | }
116 | }
117 | layer {
118 | name: "layer_64_1_conv1_h"
119 | type: "Convolution"
120 | bottom: "conv1_pool"
121 | top: "layer_64_1_conv1_h"
122 | param {
123 | lr_mult: 1.0
124 | decay_mult: 1.0
125 | }
126 | convolution_param {
127 | num_output: 32
128 | bias_term: false
129 | pad: 1
130 | kernel_size: 3
131 | stride: 1
132 | weight_filler {
133 | type: "msra"
134 | }
135 | bias_filler {
136 | type: "constant"
137 | value: 0.0
138 | }
139 | }
140 | }
141 | layer {
142 | name: "layer_64_1_bn2_h"
143 | type: "BatchNorm"
144 | bottom: "layer_64_1_conv1_h"
145 | top: "layer_64_1_conv1_h"
146 | param {
147 | lr_mult: 0.0
148 | }
149 | param {
150 | lr_mult: 0.0
151 | }
152 | param {
153 | lr_mult: 0.0
154 | }
155 | }
156 | layer {
157 | name: "layer_64_1_scale2_h"
158 | type: "Scale"
159 | bottom: "layer_64_1_conv1_h"
160 | top: "layer_64_1_conv1_h"
161 | param {
162 | lr_mult: 1.0
163 | decay_mult: 1.0
164 | }
165 | param {
166 | lr_mult: 2.0
167 | decay_mult: 1.0
168 | }
169 | scale_param {
170 | bias_term: true
171 | }
172 | }
173 | layer {
174 | name: "layer_64_1_relu2"
175 | type: "ReLU"
176 | bottom: "layer_64_1_conv1_h"
177 | top: "layer_64_1_conv1_h"
178 | }
179 | layer {
180 | name: "layer_64_1_conv2_h"
181 | type: "Convolution"
182 | bottom: "layer_64_1_conv1_h"
183 | top: "layer_64_1_conv2_h"
184 | param {
185 | lr_mult: 1.0
186 | decay_mult: 1.0
187 | }
188 | convolution_param {
189 | num_output: 32
190 | bias_term: false
191 | pad: 1
192 | kernel_size: 3
193 | stride: 1
194 | weight_filler {
195 | type: "msra"
196 | }
197 | bias_filler {
198 | type: "constant"
199 | value: 0.0
200 | }
201 | }
202 | }
203 | layer {
204 | name: "layer_64_1_sum"
205 | type: "Eltwise"
206 | bottom: "layer_64_1_conv2_h"
207 | bottom: "conv1_pool"
208 | top: "layer_64_1_sum"
209 | }
210 | layer {
211 | name: "layer_128_1_bn1_h"
212 | type: "BatchNorm"
213 | bottom: "layer_64_1_sum"
214 | top: "layer_128_1_bn1_h"
215 | param {
216 | lr_mult: 0.0
217 | }
218 | param {
219 | lr_mult: 0.0
220 | }
221 | param {
222 | lr_mult: 0.0
223 | }
224 | }
225 | layer {
226 | name: "layer_128_1_scale1_h"
227 | type: "Scale"
228 | bottom: "layer_128_1_bn1_h"
229 | top: "layer_128_1_bn1_h"
230 | param {
231 | lr_mult: 1.0
232 | decay_mult: 1.0
233 | }
234 | param {
235 | lr_mult: 2.0
236 | decay_mult: 1.0
237 | }
238 | scale_param {
239 | bias_term: true
240 | }
241 | }
242 | layer {
243 | name: "layer_128_1_relu1"
244 | type: "ReLU"
245 | bottom: "layer_128_1_bn1_h"
246 | top: "layer_128_1_bn1_h"
247 | }
248 | layer {
249 | name: "layer_128_1_conv1_h"
250 | type: "Convolution"
251 | bottom: "layer_128_1_bn1_h"
252 | top: "layer_128_1_conv1_h"
253 | param {
254 | lr_mult: 1.0
255 | decay_mult: 1.0
256 | }
257 | convolution_param {
258 | num_output: 128
259 | bias_term: false
260 | pad: 1
261 | kernel_size: 3
262 | stride: 2
263 | weight_filler {
264 | type: "msra"
265 | }
266 | bias_filler {
267 | type: "constant"
268 | value: 0.0
269 | }
270 | }
271 | }
272 | layer {
273 | name: "layer_128_1_bn2"
274 | type: "BatchNorm"
275 | bottom: "layer_128_1_conv1_h"
276 | top: "layer_128_1_conv1_h"
277 | param {
278 | lr_mult: 0.0
279 | }
280 | param {
281 | lr_mult: 0.0
282 | }
283 | param {
284 | lr_mult: 0.0
285 | }
286 | }
287 | layer {
288 | name: "layer_128_1_scale2"
289 | type: "Scale"
290 | bottom: "layer_128_1_conv1_h"
291 | top: "layer_128_1_conv1_h"
292 | param {
293 | lr_mult: 1.0
294 | decay_mult: 1.0
295 | }
296 | param {
297 | lr_mult: 2.0
298 | decay_mult: 1.0
299 | }
300 | scale_param {
301 | bias_term: true
302 | }
303 | }
304 | layer {
305 | name: "layer_128_1_relu2"
306 | type: "ReLU"
307 | bottom: "layer_128_1_conv1_h"
308 | top: "layer_128_1_conv1_h"
309 | }
310 | layer {
311 | name: "layer_128_1_conv2"
312 | type: "Convolution"
313 | bottom: "layer_128_1_conv1_h"
314 | top: "layer_128_1_conv2"
315 | param {
316 | lr_mult: 1.0
317 | decay_mult: 1.0
318 | }
319 | convolution_param {
320 | num_output: 128
321 | bias_term: false
322 | pad: 1
323 | kernel_size: 3
324 | stride: 1
325 | weight_filler {
326 | type: "msra"
327 | }
328 | bias_filler {
329 | type: "constant"
330 | value: 0.0
331 | }
332 | }
333 | }
334 | layer {
335 | name: "layer_128_1_conv_expand_h"
336 | type: "Convolution"
337 | bottom: "layer_128_1_bn1_h"
338 | top: "layer_128_1_conv_expand_h"
339 | param {
340 | lr_mult: 1.0
341 | decay_mult: 1.0
342 | }
343 | convolution_param {
344 | num_output: 128
345 | bias_term: false
346 | pad: 0
347 | kernel_size: 1
348 | stride: 2
349 | weight_filler {
350 | type: "msra"
351 | }
352 | bias_filler {
353 | type: "constant"
354 | value: 0.0
355 | }
356 | }
357 | }
358 | layer {
359 | name: "layer_128_1_sum"
360 | type: "Eltwise"
361 | bottom: "layer_128_1_conv2"
362 | bottom: "layer_128_1_conv_expand_h"
363 | top: "layer_128_1_sum"
364 | }
365 | layer {
366 | name: "layer_256_1_bn1"
367 | type: "BatchNorm"
368 | bottom: "layer_128_1_sum"
369 | top: "layer_256_1_bn1"
370 | param {
371 | lr_mult: 0.0
372 | }
373 | param {
374 | lr_mult: 0.0
375 | }
376 | param {
377 | lr_mult: 0.0
378 | }
379 | }
380 | layer {
381 | name: "layer_256_1_scale1"
382 | type: "Scale"
383 | bottom: "layer_256_1_bn1"
384 | top: "layer_256_1_bn1"
385 | param {
386 | lr_mult: 1.0
387 | decay_mult: 1.0
388 | }
389 | param {
390 | lr_mult: 2.0
391 | decay_mult: 1.0
392 | }
393 | scale_param {
394 | bias_term: true
395 | }
396 | }
397 | layer {
398 | name: "layer_256_1_relu1"
399 | type: "ReLU"
400 | bottom: "layer_256_1_bn1"
401 | top: "layer_256_1_bn1"
402 | }
403 | layer {
404 | name: "layer_256_1_conv1"
405 | type: "Convolution"
406 | bottom: "layer_256_1_bn1"
407 | top: "layer_256_1_conv1"
408 | param {
409 | lr_mult: 1.0
410 | decay_mult: 1.0
411 | }
412 | convolution_param {
413 | num_output: 256
414 | bias_term: false
415 | pad: 1
416 | kernel_size: 3
417 | stride: 2
418 | weight_filler {
419 | type: "msra"
420 | }
421 | bias_filler {
422 | type: "constant"
423 | value: 0.0
424 | }
425 | }
426 | }
427 | layer {
428 | name: "layer_256_1_bn2"
429 | type: "BatchNorm"
430 | bottom: "layer_256_1_conv1"
431 | top: "layer_256_1_conv1"
432 | param {
433 | lr_mult: 0.0
434 | }
435 | param {
436 | lr_mult: 0.0
437 | }
438 | param {
439 | lr_mult: 0.0
440 | }
441 | }
442 | layer {
443 | name: "layer_256_1_scale2"
444 | type: "Scale"
445 | bottom: "layer_256_1_conv1"
446 | top: "layer_256_1_conv1"
447 | param {
448 | lr_mult: 1.0
449 | decay_mult: 1.0
450 | }
451 | param {
452 | lr_mult: 2.0
453 | decay_mult: 1.0
454 | }
455 | scale_param {
456 | bias_term: true
457 | }
458 | }
459 | layer {
460 | name: "layer_256_1_relu2"
461 | type: "ReLU"
462 | bottom: "layer_256_1_conv1"
463 | top: "layer_256_1_conv1"
464 | }
465 | layer {
466 | name: "layer_256_1_conv2"
467 | type: "Convolution"
468 | bottom: "layer_256_1_conv1"
469 | top: "layer_256_1_conv2"
470 | param {
471 | lr_mult: 1.0
472 | decay_mult: 1.0
473 | }
474 | convolution_param {
475 | num_output: 256
476 | bias_term: false
477 | pad: 1
478 | kernel_size: 3
479 | stride: 1
480 | weight_filler {
481 | type: "msra"
482 | }
483 | bias_filler {
484 | type: "constant"
485 | value: 0.0
486 | }
487 | }
488 | }
489 | layer {
490 | name: "layer_256_1_conv_expand"
491 | type: "Convolution"
492 | bottom: "layer_256_1_bn1"
493 | top: "layer_256_1_conv_expand"
494 | param {
495 | lr_mult: 1.0
496 | decay_mult: 1.0
497 | }
498 | convolution_param {
499 | num_output: 256
500 | bias_term: false
501 | pad: 0
502 | kernel_size: 1
503 | stride: 2
504 | weight_filler {
505 | type: "msra"
506 | }
507 | bias_filler {
508 | type: "constant"
509 | value: 0.0
510 | }
511 | }
512 | }
513 | layer {
514 | name: "layer_256_1_sum"
515 | type: "Eltwise"
516 | bottom: "layer_256_1_conv2"
517 | bottom: "layer_256_1_conv_expand"
518 | top: "layer_256_1_sum"
519 | }
520 | layer {
521 | name: "layer_512_1_bn1"
522 | type: "BatchNorm"
523 | bottom: "layer_256_1_sum"
524 | top: "layer_512_1_bn1"
525 | param {
526 | lr_mult: 0.0
527 | }
528 | param {
529 | lr_mult: 0.0
530 | }
531 | param {
532 | lr_mult: 0.0
533 | }
534 | }
535 | layer {
536 | name: "layer_512_1_scale1"
537 | type: "Scale"
538 | bottom: "layer_512_1_bn1"
539 | top: "layer_512_1_bn1"
540 | param {
541 | lr_mult: 1.0
542 | decay_mult: 1.0
543 | }
544 | param {
545 | lr_mult: 2.0
546 | decay_mult: 1.0
547 | }
548 | scale_param {
549 | bias_term: true
550 | }
551 | }
552 | layer {
553 | name: "layer_512_1_relu1"
554 | type: "ReLU"
555 | bottom: "layer_512_1_bn1"
556 | top: "layer_512_1_bn1"
557 | }
558 | layer {
559 | name: "layer_512_1_conv1_h"
560 | type: "Convolution"
561 | bottom: "layer_512_1_bn1"
562 | top: "layer_512_1_conv1_h"
563 | param {
564 | lr_mult: 1.0
565 | decay_mult: 1.0
566 | }
567 | convolution_param {
568 | num_output: 128
569 | bias_term: false
570 | pad: 1
571 | kernel_size: 3
572 | stride: 1 # 2
573 | weight_filler {
574 | type: "msra"
575 | }
576 | bias_filler {
577 | type: "constant"
578 | value: 0.0
579 | }
580 | }
581 | }
582 | layer {
583 | name: "layer_512_1_bn2_h"
584 | type: "BatchNorm"
585 | bottom: "layer_512_1_conv1_h"
586 | top: "layer_512_1_conv1_h"
587 | param {
588 | lr_mult: 0.0
589 | }
590 | param {
591 | lr_mult: 0.0
592 | }
593 | param {
594 | lr_mult: 0.0
595 | }
596 | }
597 | layer {
598 | name: "layer_512_1_scale2_h"
599 | type: "Scale"
600 | bottom: "layer_512_1_conv1_h"
601 | top: "layer_512_1_conv1_h"
602 | param {
603 | lr_mult: 1.0
604 | decay_mult: 1.0
605 | }
606 | param {
607 | lr_mult: 2.0
608 | decay_mult: 1.0
609 | }
610 | scale_param {
611 | bias_term: true
612 | }
613 | }
614 | layer {
615 | name: "layer_512_1_relu2"
616 | type: "ReLU"
617 | bottom: "layer_512_1_conv1_h"
618 | top: "layer_512_1_conv1_h"
619 | }
620 | layer {
621 | name: "layer_512_1_conv2_h"
622 | type: "Convolution"
623 | bottom: "layer_512_1_conv1_h"
624 | top: "layer_512_1_conv2_h"
625 | param {
626 | lr_mult: 1.0
627 | decay_mult: 1.0
628 | }
629 | convolution_param {
630 | num_output: 256
631 | bias_term: false
632 | pad: 2 # 1
633 | kernel_size: 3
634 | stride: 1
635 | dilation: 2
636 | weight_filler {
637 | type: "msra"
638 | }
639 | bias_filler {
640 | type: "constant"
641 | value: 0.0
642 | }
643 | }
644 | }
645 | layer {
646 | name: "layer_512_1_conv_expand_h"
647 | type: "Convolution"
648 | bottom: "layer_512_1_bn1"
649 | top: "layer_512_1_conv_expand_h"
650 | param {
651 | lr_mult: 1.0
652 | decay_mult: 1.0
653 | }
654 | convolution_param {
655 | num_output: 256
656 | bias_term: false
657 | pad: 0
658 | kernel_size: 1
659 | stride: 1 # 2
660 | weight_filler {
661 | type: "msra"
662 | }
663 | bias_filler {
664 | type: "constant"
665 | value: 0.0
666 | }
667 | }
668 | }
669 | layer {
670 | name: "layer_512_1_sum"
671 | type: "Eltwise"
672 | bottom: "layer_512_1_conv2_h"
673 | bottom: "layer_512_1_conv_expand_h"
674 | top: "layer_512_1_sum"
675 | }
676 | layer {
677 | name: "last_bn_h"
678 | type: "BatchNorm"
679 | bottom: "layer_512_1_sum"
680 | top: "layer_512_1_sum"
681 | param {
682 | lr_mult: 0.0
683 | }
684 | param {
685 | lr_mult: 0.0
686 | }
687 | param {
688 | lr_mult: 0.0
689 | }
690 | }
691 | layer {
692 | name: "last_scale_h"
693 | type: "Scale"
694 | bottom: "layer_512_1_sum"
695 | top: "layer_512_1_sum"
696 | param {
697 | lr_mult: 1.0
698 | decay_mult: 1.0
699 | }
700 | param {
701 | lr_mult: 2.0
702 | decay_mult: 1.0
703 | }
704 | scale_param {
705 | bias_term: true
706 | }
707 | }
708 | layer {
709 | name: "last_relu"
710 | type: "ReLU"
711 | bottom: "layer_512_1_sum"
712 | top: "fc7"
713 | }
714 |
715 | layer {
716 | name: "conv6_1_h"
717 | type: "Convolution"
718 | bottom: "fc7"
719 | top: "conv6_1_h"
720 | param {
721 | lr_mult: 1
722 | decay_mult: 1
723 | }
724 | param {
725 | lr_mult: 2
726 | decay_mult: 0
727 | }
728 | convolution_param {
729 | num_output: 128
730 | pad: 0
731 | kernel_size: 1
732 | stride: 1
733 | weight_filler {
734 | type: "xavier"
735 | }
736 | bias_filler {
737 | type: "constant"
738 | value: 0
739 | }
740 | }
741 | }
742 | layer {
743 | name: "conv6_1_relu"
744 | type: "ReLU"
745 | bottom: "conv6_1_h"
746 | top: "conv6_1_h"
747 | }
748 | layer {
749 | name: "conv6_2_h"
750 | type: "Convolution"
751 | bottom: "conv6_1_h"
752 | top: "conv6_2_h"
753 | param {
754 | lr_mult: 1
755 | decay_mult: 1
756 | }
757 | param {
758 | lr_mult: 2
759 | decay_mult: 0
760 | }
761 | convolution_param {
762 | num_output: 256
763 | pad: 1
764 | kernel_size: 3
765 | stride: 2
766 | weight_filler {
767 | type: "xavier"
768 | }
769 | bias_filler {
770 | type: "constant"
771 | value: 0
772 | }
773 | }
774 | }
775 | layer {
776 | name: "conv6_2_relu"
777 | type: "ReLU"
778 | bottom: "conv6_2_h"
779 | top: "conv6_2_h"
780 | }
781 | layer {
782 | name: "conv7_1_h"
783 | type: "Convolution"
784 | bottom: "conv6_2_h"
785 | top: "conv7_1_h"
786 | param {
787 | lr_mult: 1
788 | decay_mult: 1
789 | }
790 | param {
791 | lr_mult: 2
792 | decay_mult: 0
793 | }
794 | convolution_param {
795 | num_output: 64
796 | pad: 0
797 | kernel_size: 1
798 | stride: 1
799 | weight_filler {
800 | type: "xavier"
801 | }
802 | bias_filler {
803 | type: "constant"
804 | value: 0
805 | }
806 | }
807 | }
808 | layer {
809 | name: "conv7_1_relu"
810 | type: "ReLU"
811 | bottom: "conv7_1_h"
812 | top: "conv7_1_h"
813 | }
814 | layer {
815 | name: "conv7_2_h"
816 | type: "Convolution"
817 | bottom: "conv7_1_h"
818 | top: "conv7_2_h"
819 | param {
820 | lr_mult: 1
821 | decay_mult: 1
822 | }
823 | param {
824 | lr_mult: 2
825 | decay_mult: 0
826 | }
827 | convolution_param {
828 | num_output: 128
829 | pad: 1
830 | kernel_size: 3
831 | stride: 2
832 | weight_filler {
833 | type: "xavier"
834 | }
835 | bias_filler {
836 | type: "constant"
837 | value: 0
838 | }
839 | }
840 | }
841 | layer {
842 | name: "conv7_2_relu"
843 | type: "ReLU"
844 | bottom: "conv7_2_h"
845 | top: "conv7_2_h"
846 | }
847 | layer {
848 | name: "conv8_1_h"
849 | type: "Convolution"
850 | bottom: "conv7_2_h"
851 | top: "conv8_1_h"
852 | param {
853 | lr_mult: 1
854 | decay_mult: 1
855 | }
856 | param {
857 | lr_mult: 2
858 | decay_mult: 0
859 | }
860 | convolution_param {
861 | num_output: 64
862 | pad: 0
863 | kernel_size: 1
864 | stride: 1
865 | weight_filler {
866 | type: "xavier"
867 | }
868 | bias_filler {
869 | type: "constant"
870 | value: 0
871 | }
872 | }
873 | }
874 | layer {
875 | name: "conv8_1_relu"
876 | type: "ReLU"
877 | bottom: "conv8_1_h"
878 | top: "conv8_1_h"
879 | }
880 | layer {
881 | name: "conv8_2_h"
882 | type: "Convolution"
883 | bottom: "conv8_1_h"
884 | top: "conv8_2_h"
885 | param {
886 | lr_mult: 1
887 | decay_mult: 1
888 | }
889 | param {
890 | lr_mult: 2
891 | decay_mult: 0
892 | }
893 | convolution_param {
894 | num_output: 128
895 | pad: 1
896 | kernel_size: 3
897 | stride: 1
898 | weight_filler {
899 | type: "xavier"
900 | }
901 | bias_filler {
902 | type: "constant"
903 | value: 0
904 | }
905 | }
906 | }
907 | layer {
908 | name: "conv8_2_relu"
909 | type: "ReLU"
910 | bottom: "conv8_2_h"
911 | top: "conv8_2_h"
912 | }
913 | layer {
914 | name: "conv9_1_h"
915 | type: "Convolution"
916 | bottom: "conv8_2_h"
917 | top: "conv9_1_h"
918 | param {
919 | lr_mult: 1
920 | decay_mult: 1
921 | }
922 | param {
923 | lr_mult: 2
924 | decay_mult: 0
925 | }
926 | convolution_param {
927 | num_output: 64
928 | pad: 0
929 | kernel_size: 1
930 | stride: 1
931 | weight_filler {
932 | type: "xavier"
933 | }
934 | bias_filler {
935 | type: "constant"
936 | value: 0
937 | }
938 | }
939 | }
940 | layer {
941 | name: "conv9_1_relu"
942 | type: "ReLU"
943 | bottom: "conv9_1_h"
944 | top: "conv9_1_h"
945 | }
946 | layer {
947 | name: "conv9_2_h"
948 | type: "Convolution"
949 | bottom: "conv9_1_h"
950 | top: "conv9_2_h"
951 | param {
952 | lr_mult: 1
953 | decay_mult: 1
954 | }
955 | param {
956 | lr_mult: 2
957 | decay_mult: 0
958 | }
959 | convolution_param {
960 | num_output: 128
961 | pad: 1
962 | kernel_size: 3
963 | stride: 1
964 | weight_filler {
965 | type: "xavier"
966 | }
967 | bias_filler {
968 | type: "constant"
969 | value: 0
970 | }
971 | }
972 | }
973 | layer {
974 | name: "conv9_2_relu"
975 | type: "ReLU"
976 | bottom: "conv9_2_h"
977 | top: "conv9_2_h"
978 | }
979 | layer {
980 | name: "conv4_3_norm"
981 | type: "Normalize"
982 | bottom: "layer_256_1_bn1"
983 | top: "conv4_3_norm"
984 | norm_param {
985 | across_spatial: false
986 | scale_filler {
987 | type: "constant"
988 | value: 20
989 | }
990 | channel_shared: false
991 | }
992 | }
993 | layer {
994 | name: "conv4_3_norm_mbox_loc"
995 | type: "Convolution"
996 | bottom: "conv4_3_norm"
997 | top: "conv4_3_norm_mbox_loc"
998 | param {
999 | lr_mult: 1
1000 | decay_mult: 1
1001 | }
1002 | param {
1003 | lr_mult: 2
1004 | decay_mult: 0
1005 | }
1006 | convolution_param {
1007 | num_output: 16
1008 | pad: 1
1009 | kernel_size: 3
1010 | stride: 1
1011 | weight_filler {
1012 | type: "xavier"
1013 | }
1014 | bias_filler {
1015 | type: "constant"
1016 | value: 0
1017 | }
1018 | }
1019 | }
1020 | layer {
1021 | name: "conv4_3_norm_mbox_loc_perm"
1022 | type: "Permute"
1023 | bottom: "conv4_3_norm_mbox_loc"
1024 | top: "conv4_3_norm_mbox_loc_perm"
1025 | permute_param {
1026 | order: 0
1027 | order: 2
1028 | order: 3
1029 | order: 1
1030 | }
1031 | }
1032 | layer {
1033 | name: "conv4_3_norm_mbox_loc_flat"
1034 | type: "Flatten"
1035 | bottom: "conv4_3_norm_mbox_loc_perm"
1036 | top: "conv4_3_norm_mbox_loc_flat"
1037 | flatten_param {
1038 | axis: 1
1039 | }
1040 | }
1041 | layer {
1042 | name: "conv4_3_norm_mbox_conf"
1043 | type: "Convolution"
1044 | bottom: "conv4_3_norm"
1045 | top: "conv4_3_norm_mbox_conf"
1046 | param {
1047 | lr_mult: 1
1048 | decay_mult: 1
1049 | }
1050 | param {
1051 | lr_mult: 2
1052 | decay_mult: 0
1053 | }
1054 | convolution_param {
1055 | num_output: 8 # 84
1056 | pad: 1
1057 | kernel_size: 3
1058 | stride: 1
1059 | weight_filler {
1060 | type: "xavier"
1061 | }
1062 | bias_filler {
1063 | type: "constant"
1064 | value: 0
1065 | }
1066 | }
1067 | }
1068 | layer {
1069 | name: "conv4_3_norm_mbox_conf_perm"
1070 | type: "Permute"
1071 | bottom: "conv4_3_norm_mbox_conf"
1072 | top: "conv4_3_norm_mbox_conf_perm"
1073 | permute_param {
1074 | order: 0
1075 | order: 2
1076 | order: 3
1077 | order: 1
1078 | }
1079 | }
1080 | layer {
1081 | name: "conv4_3_norm_mbox_conf_flat"
1082 | type: "Flatten"
1083 | bottom: "conv4_3_norm_mbox_conf_perm"
1084 | top: "conv4_3_norm_mbox_conf_flat"
1085 | flatten_param {
1086 | axis: 1
1087 | }
1088 | }
1089 | layer {
1090 | name: "conv4_3_norm_mbox_priorbox"
1091 | type: "PriorBox"
1092 | bottom: "conv4_3_norm"
1093 | bottom: "data"
1094 | top: "conv4_3_norm_mbox_priorbox"
1095 | prior_box_param {
1096 | min_size: 30.0
1097 | max_size: 60.0
1098 | aspect_ratio: 2
1099 | flip: true
1100 | clip: false
1101 | variance: 0.1
1102 | variance: 0.1
1103 | variance: 0.2
1104 | variance: 0.2
1105 | step: 8
1106 | offset: 0.5
1107 | }
1108 | }
1109 | layer {
1110 | name: "fc7_mbox_loc"
1111 | type: "Convolution"
1112 | bottom: "fc7"
1113 | top: "fc7_mbox_loc"
1114 | param {
1115 | lr_mult: 1
1116 | decay_mult: 1
1117 | }
1118 | param {
1119 | lr_mult: 2
1120 | decay_mult: 0
1121 | }
1122 | convolution_param {
1123 | num_output: 24
1124 | pad: 1
1125 | kernel_size: 3
1126 | stride: 1
1127 | weight_filler {
1128 | type: "xavier"
1129 | }
1130 | bias_filler {
1131 | type: "constant"
1132 | value: 0
1133 | }
1134 | }
1135 | }
1136 | layer {
1137 | name: "fc7_mbox_loc_perm"
1138 | type: "Permute"
1139 | bottom: "fc7_mbox_loc"
1140 | top: "fc7_mbox_loc_perm"
1141 | permute_param {
1142 | order: 0
1143 | order: 2
1144 | order: 3
1145 | order: 1
1146 | }
1147 | }
1148 | layer {
1149 | name: "fc7_mbox_loc_flat"
1150 | type: "Flatten"
1151 | bottom: "fc7_mbox_loc_perm"
1152 | top: "fc7_mbox_loc_flat"
1153 | flatten_param {
1154 | axis: 1
1155 | }
1156 | }
1157 | layer {
1158 | name: "fc7_mbox_conf"
1159 | type: "Convolution"
1160 | bottom: "fc7"
1161 | top: "fc7_mbox_conf"
1162 | param {
1163 | lr_mult: 1
1164 | decay_mult: 1
1165 | }
1166 | param {
1167 | lr_mult: 2
1168 | decay_mult: 0
1169 | }
1170 | convolution_param {
1171 | num_output: 12 # 6 priors x 2 classes; 126 (= 6 x 21) looks like the earlier 21-class value
1172 | pad: 1
1173 | kernel_size: 3
1174 | stride: 1
1175 | weight_filler {
1176 | type: "xavier"
1177 | }
1178 | bias_filler {
1179 | type: "constant"
1180 | value: 0
1181 | }
1182 | }
1183 | }
1184 | layer {
1185 | name: "fc7_mbox_conf_perm"
1186 | type: "Permute"
1187 | bottom: "fc7_mbox_conf"
1188 | top: "fc7_mbox_conf_perm"
1189 | permute_param {
1190 | order: 0
1191 | order: 2
1192 | order: 3
1193 | order: 1
1194 | }
1195 | }
1196 | layer {
1197 | name: "fc7_mbox_conf_flat"
1198 | type: "Flatten"
1199 | bottom: "fc7_mbox_conf_perm"
1200 | top: "fc7_mbox_conf_flat"
1201 | flatten_param {
1202 | axis: 1
1203 | }
1204 | }
1205 | layer {
1206 | name: "fc7_mbox_priorbox"
1207 | type: "PriorBox"
1208 | bottom: "fc7"
1209 | bottom: "data"
1210 | top: "fc7_mbox_priorbox"
1211 | prior_box_param {
1212 | min_size: 60.0
1213 | max_size: 111.0
1214 | aspect_ratio: 2
1215 | aspect_ratio: 3
1216 | flip: true
1217 | clip: false
1218 | variance: 0.1
1219 | variance: 0.1
1220 | variance: 0.2
1221 | variance: 0.2
1222 | step: 16
1223 | offset: 0.5
1224 | }
1225 | }
1226 | layer {
1227 | name: "conv6_2_mbox_loc"
1228 | type: "Convolution"
1229 | bottom: "conv6_2_h"
1230 | top: "conv6_2_mbox_loc"
1231 | param {
1232 | lr_mult: 1
1233 | decay_mult: 1
1234 | }
1235 | param {
1236 | lr_mult: 2
1237 | decay_mult: 0
1238 | }
1239 | convolution_param {
1240 | num_output: 24
1241 | pad: 1
1242 | kernel_size: 3
1243 | stride: 1
1244 | weight_filler {
1245 | type: "xavier"
1246 | }
1247 | bias_filler {
1248 | type: "constant"
1249 | value: 0
1250 | }
1251 | }
1252 | }
1253 | layer {
1254 | name: "conv6_2_mbox_loc_perm"
1255 | type: "Permute"
1256 | bottom: "conv6_2_mbox_loc"
1257 | top: "conv6_2_mbox_loc_perm"
1258 | permute_param {
1259 | order: 0
1260 | order: 2
1261 | order: 3
1262 | order: 1
1263 | }
1264 | }
1265 | layer {
1266 | name: "conv6_2_mbox_loc_flat"
1267 | type: "Flatten"
1268 | bottom: "conv6_2_mbox_loc_perm"
1269 | top: "conv6_2_mbox_loc_flat"
1270 | flatten_param {
1271 | axis: 1
1272 | }
1273 | }
1274 | layer {
1275 | name: "conv6_2_mbox_conf"
1276 | type: "Convolution"
1277 | bottom: "conv6_2_h"
1278 | top: "conv6_2_mbox_conf"
1279 | param {
1280 | lr_mult: 1
1281 | decay_mult: 1
1282 | }
1283 | param {
1284 | lr_mult: 2
1285 | decay_mult: 0
1286 | }
1287 | convolution_param {
1288 | num_output: 12 # 6 priors x 2 classes; 126 (= 6 x 21) looks like the earlier 21-class value
1289 | pad: 1
1290 | kernel_size: 3
1291 | stride: 1
1292 | weight_filler {
1293 | type: "xavier"
1294 | }
1295 | bias_filler {
1296 | type: "constant"
1297 | value: 0
1298 | }
1299 | }
1300 | }
1301 | layer {
1302 | name: "conv6_2_mbox_conf_perm"
1303 | type: "Permute"
1304 | bottom: "conv6_2_mbox_conf"
1305 | top: "conv6_2_mbox_conf_perm"
1306 | permute_param {
1307 | order: 0
1308 | order: 2
1309 | order: 3
1310 | order: 1
1311 | }
1312 | }
1313 | layer {
1314 | name: "conv6_2_mbox_conf_flat"
1315 | type: "Flatten"
1316 | bottom: "conv6_2_mbox_conf_perm"
1317 | top: "conv6_2_mbox_conf_flat"
1318 | flatten_param {
1319 | axis: 1
1320 | }
1321 | }
1322 | layer {
1323 | name: "conv6_2_mbox_priorbox"
1324 | type: "PriorBox"
1325 | bottom: "conv6_2_h"
1326 | bottom: "data"
1327 | top: "conv6_2_mbox_priorbox"
1328 | prior_box_param {
1329 | min_size: 111.0
1330 | max_size: 162.0
1331 | aspect_ratio: 2
1332 | aspect_ratio: 3
1333 | flip: true
1334 | clip: false
1335 | variance: 0.1
1336 | variance: 0.1
1337 | variance: 0.2
1338 | variance: 0.2
1339 | step: 32
1340 | offset: 0.5
1341 | }
1342 | }
1343 | layer {
1344 | name: "conv7_2_mbox_loc"
1345 | type: "Convolution"
1346 | bottom: "conv7_2_h"
1347 | top: "conv7_2_mbox_loc"
1348 | param {
1349 | lr_mult: 1
1350 | decay_mult: 1
1351 | }
1352 | param {
1353 | lr_mult: 2
1354 | decay_mult: 0
1355 | }
1356 | convolution_param {
1357 | num_output: 24
1358 | pad: 1
1359 | kernel_size: 3
1360 | stride: 1
1361 | weight_filler {
1362 | type: "xavier"
1363 | }
1364 | bias_filler {
1365 | type: "constant"
1366 | value: 0
1367 | }
1368 | }
1369 | }
1370 | layer {
1371 | name: "conv7_2_mbox_loc_perm"
1372 | type: "Permute"
1373 | bottom: "conv7_2_mbox_loc"
1374 | top: "conv7_2_mbox_loc_perm"
1375 | permute_param {
1376 | order: 0
1377 | order: 2
1378 | order: 3
1379 | order: 1
1380 | }
1381 | }
1382 | layer {
1383 | name: "conv7_2_mbox_loc_flat"
1384 | type: "Flatten"
1385 | bottom: "conv7_2_mbox_loc_perm"
1386 | top: "conv7_2_mbox_loc_flat"
1387 | flatten_param {
1388 | axis: 1
1389 | }
1390 | }
1391 | layer {
1392 | name: "conv7_2_mbox_conf"
1393 | type: "Convolution"
1394 | bottom: "conv7_2_h"
1395 | top: "conv7_2_mbox_conf"
1396 | param {
1397 | lr_mult: 1
1398 | decay_mult: 1
1399 | }
1400 | param {
1401 | lr_mult: 2
1402 | decay_mult: 0
1403 | }
1404 | convolution_param {
1405 | num_output: 12 # 6 priors x 2 classes; 126 (= 6 x 21) looks like the earlier 21-class value
1406 | pad: 1
1407 | kernel_size: 3
1408 | stride: 1
1409 | weight_filler {
1410 | type: "xavier"
1411 | }
1412 | bias_filler {
1413 | type: "constant"
1414 | value: 0
1415 | }
1416 | }
1417 | }
1418 | layer {
1419 | name: "conv7_2_mbox_conf_perm"
1420 | type: "Permute"
1421 | bottom: "conv7_2_mbox_conf"
1422 | top: "conv7_2_mbox_conf_perm"
1423 | permute_param {
1424 | order: 0
1425 | order: 2
1426 | order: 3
1427 | order: 1
1428 | }
1429 | }
1430 | layer {
1431 | name: "conv7_2_mbox_conf_flat"
1432 | type: "Flatten"
1433 | bottom: "conv7_2_mbox_conf_perm"
1434 | top: "conv7_2_mbox_conf_flat"
1435 | flatten_param {
1436 | axis: 1
1437 | }
1438 | }
1439 | layer {
1440 | name: "conv7_2_mbox_priorbox"
1441 | type: "PriorBox"
1442 | bottom: "conv7_2_h"
1443 | bottom: "data"
1444 | top: "conv7_2_mbox_priorbox"
1445 | prior_box_param {
1446 | min_size: 162.0
1447 | max_size: 213.0
1448 | aspect_ratio: 2
1449 | aspect_ratio: 3
1450 | flip: true
1451 | clip: false
1452 | variance: 0.1
1453 | variance: 0.1
1454 | variance: 0.2
1455 | variance: 0.2
1456 | step: 64
1457 | offset: 0.5
1458 | }
1459 | }
1460 | layer {
1461 | name: "conv8_2_mbox_loc"
1462 | type: "Convolution"
1463 | bottom: "conv8_2_h"
1464 | top: "conv8_2_mbox_loc"
1465 | param {
1466 | lr_mult: 1
1467 | decay_mult: 1
1468 | }
1469 | param {
1470 | lr_mult: 2
1471 | decay_mult: 0
1472 | }
1473 | convolution_param {
1474 | num_output: 16
1475 | pad: 1
1476 | kernel_size: 3
1477 | stride: 1
1478 | weight_filler {
1479 | type: "xavier"
1480 | }
1481 | bias_filler {
1482 | type: "constant"
1483 | value: 0
1484 | }
1485 | }
1486 | }
1487 | layer {
1488 | name: "conv8_2_mbox_loc_perm"
1489 | type: "Permute"
1490 | bottom: "conv8_2_mbox_loc"
1491 | top: "conv8_2_mbox_loc_perm"
1492 | permute_param {
1493 | order: 0
1494 | order: 2
1495 | order: 3
1496 | order: 1
1497 | }
1498 | }
1499 | layer {
1500 | name: "conv8_2_mbox_loc_flat"
1501 | type: "Flatten"
1502 | bottom: "conv8_2_mbox_loc_perm"
1503 | top: "conv8_2_mbox_loc_flat"
1504 | flatten_param {
1505 | axis: 1
1506 | }
1507 | }
1508 | layer {
1509 | name: "conv8_2_mbox_conf"
1510 | type: "Convolution"
1511 | bottom: "conv8_2_h"
1512 | top: "conv8_2_mbox_conf"
1513 | param {
1514 | lr_mult: 1
1515 | decay_mult: 1
1516 | }
1517 | param {
1518 | lr_mult: 2
1519 | decay_mult: 0
1520 | }
1521 | convolution_param {
1522 | num_output: 8 # 4 priors x 2 classes; 84 (= 4 x 21) looks like the earlier 21-class value
1523 | pad: 1
1524 | kernel_size: 3
1525 | stride: 1
1526 | weight_filler {
1527 | type: "xavier"
1528 | }
1529 | bias_filler {
1530 | type: "constant"
1531 | value: 0
1532 | }
1533 | }
1534 | }
1535 | layer {
1536 | name: "conv8_2_mbox_conf_perm"
1537 | type: "Permute"
1538 | bottom: "conv8_2_mbox_conf"
1539 | top: "conv8_2_mbox_conf_perm"
1540 | permute_param {
1541 | order: 0
1542 | order: 2
1543 | order: 3
1544 | order: 1
1545 | }
1546 | }
1547 | layer {
1548 | name: "conv8_2_mbox_conf_flat"
1549 | type: "Flatten"
1550 | bottom: "conv8_2_mbox_conf_perm"
1551 | top: "conv8_2_mbox_conf_flat"
1552 | flatten_param {
1553 | axis: 1
1554 | }
1555 | }
1556 | layer {
1557 | name: "conv8_2_mbox_priorbox"
1558 | type: "PriorBox"
1559 | bottom: "conv8_2_h"
1560 | bottom: "data"
1561 | top: "conv8_2_mbox_priorbox"
1562 | prior_box_param {
1563 | min_size: 213.0
1564 | max_size: 264.0
1565 | aspect_ratio: 2
1566 | flip: true
1567 | clip: false
1568 | variance: 0.1
1569 | variance: 0.1
1570 | variance: 0.2
1571 | variance: 0.2
1572 | step: 100
1573 | offset: 0.5
1574 | }
1575 | }
1576 | layer {
1577 | name: "conv9_2_mbox_loc"
1578 | type: "Convolution"
1579 | bottom: "conv9_2_h"
1580 | top: "conv9_2_mbox_loc"
1581 | param {
1582 | lr_mult: 1
1583 | decay_mult: 1
1584 | }
1585 | param {
1586 | lr_mult: 2
1587 | decay_mult: 0
1588 | }
1589 | convolution_param {
1590 | num_output: 16
1591 | pad: 1
1592 | kernel_size: 3
1593 | stride: 1
1594 | weight_filler {
1595 | type: "xavier"
1596 | }
1597 | bias_filler {
1598 | type: "constant"
1599 | value: 0
1600 | }
1601 | }
1602 | }
1603 | layer {
1604 | name: "conv9_2_mbox_loc_perm"
1605 | type: "Permute"
1606 | bottom: "conv9_2_mbox_loc"
1607 | top: "conv9_2_mbox_loc_perm"
1608 | permute_param {
1609 | order: 0
1610 | order: 2
1611 | order: 3
1612 | order: 1
1613 | }
1614 | }
1615 | layer {
1616 | name: "conv9_2_mbox_loc_flat"
1617 | type: "Flatten"
1618 | bottom: "conv9_2_mbox_loc_perm"
1619 | top: "conv9_2_mbox_loc_flat"
1620 | flatten_param {
1621 | axis: 1
1622 | }
1623 | }
1624 | layer {
1625 | name: "conv9_2_mbox_conf"
1626 | type: "Convolution"
1627 | bottom: "conv9_2_h"
1628 | top: "conv9_2_mbox_conf"
1629 | param {
1630 | lr_mult: 1
1631 | decay_mult: 1
1632 | }
1633 | param {
1634 | lr_mult: 2
1635 | decay_mult: 0
1636 | }
1637 | convolution_param {
1638 | num_output: 8 # 4 priors x 2 classes; 84 (= 4 x 21) looks like the earlier 21-class value
1639 | pad: 1
1640 | kernel_size: 3
1641 | stride: 1
1642 | weight_filler {
1643 | type: "xavier"
1644 | }
1645 | bias_filler {
1646 | type: "constant"
1647 | value: 0
1648 | }
1649 | }
1650 | }
1651 | layer {
1652 | name: "conv9_2_mbox_conf_perm"
1653 | type: "Permute"
1654 | bottom: "conv9_2_mbox_conf"
1655 | top: "conv9_2_mbox_conf_perm"
1656 | permute_param {
1657 | order: 0
1658 | order: 2
1659 | order: 3
1660 | order: 1
1661 | }
1662 | }
1663 | layer {
1664 | name: "conv9_2_mbox_conf_flat"
1665 | type: "Flatten"
1666 | bottom: "conv9_2_mbox_conf_perm"
1667 | top: "conv9_2_mbox_conf_flat"
1668 | flatten_param {
1669 | axis: 1
1670 | }
1671 | }
1672 | layer {
1673 | name: "conv9_2_mbox_priorbox"
1674 | type: "PriorBox"
1675 | bottom: "conv9_2_h"
1676 | bottom: "data"
1677 | top: "conv9_2_mbox_priorbox"
1678 | prior_box_param {
1679 | min_size: 264.0
1680 | max_size: 315.0
1681 | aspect_ratio: 2
1682 | flip: true
1683 | clip: false
1684 | variance: 0.1
1685 | variance: 0.1
1686 | variance: 0.2
1687 | variance: 0.2
1688 | step: 300
1689 | offset: 0.5
1690 | }
1691 | }
1692 | layer {
1693 | name: "mbox_loc"
1694 | type: "Concat"
1695 | bottom: "conv4_3_norm_mbox_loc_flat"
1696 | bottom: "fc7_mbox_loc_flat"
1697 | bottom: "conv6_2_mbox_loc_flat"
1698 | bottom: "conv7_2_mbox_loc_flat"
1699 | bottom: "conv8_2_mbox_loc_flat"
1700 | bottom: "conv9_2_mbox_loc_flat"
1701 | top: "mbox_loc"
1702 | concat_param {
1703 | axis: 1
1704 | }
1705 | }
1706 | layer {
1707 | name: "mbox_conf"
1708 | type: "Concat"
1709 | bottom: "conv4_3_norm_mbox_conf_flat"
1710 | bottom: "fc7_mbox_conf_flat"
1711 | bottom: "conv6_2_mbox_conf_flat"
1712 | bottom: "conv7_2_mbox_conf_flat"
1713 | bottom: "conv8_2_mbox_conf_flat"
1714 | bottom: "conv9_2_mbox_conf_flat"
1715 | top: "mbox_conf"
1716 | concat_param {
1717 | axis: 1
1718 | }
1719 | }
1720 | layer {
1721 | name: "mbox_priorbox"
1722 | type: "Concat"
1723 | bottom: "conv4_3_norm_mbox_priorbox"
1724 | bottom: "fc7_mbox_priorbox"
1725 | bottom: "conv6_2_mbox_priorbox"
1726 | bottom: "conv7_2_mbox_priorbox"
1727 | bottom: "conv8_2_mbox_priorbox"
1728 | bottom: "conv9_2_mbox_priorbox"
1729 | top: "mbox_priorbox"
1730 | concat_param {
1731 | axis: 2
1732 | }
1733 | }
1734 |
1735 | layer {
1736 | name: "mbox_conf_reshape"
1737 | type: "Reshape"
1738 | bottom: "mbox_conf"
1739 | top: "mbox_conf_reshape"
1740 | reshape_param {
1741 | shape {
1742 | dim: 0
1743 | dim: -1
1744 | dim: 2
1745 | }
1746 | }
1747 | }
1748 | layer {
1749 | name: "mbox_conf_softmax"
1750 | type: "Softmax"
1751 | bottom: "mbox_conf_reshape"
1752 | top: "mbox_conf_softmax"
1753 | softmax_param {
1754 | axis: 2
1755 | }
1756 | }
1757 | layer {
1758 | name: "mbox_conf_flatten"
1759 | type: "Flatten"
1760 | bottom: "mbox_conf_softmax"
1761 | top: "mbox_conf_flatten"
1762 | flatten_param {
1763 | axis: 1
1764 | }
1765 | }
1766 |
1767 | layer {
1768 | name: "detection_out"
1769 | type: "DetectionOutput"
1770 | bottom: "mbox_loc"
1771 | bottom: "mbox_conf_flatten"
1772 | bottom: "mbox_priorbox"
1773 | top: "detection_out"
1774 | include {
1775 | phase: TEST
1776 | }
1777 | detection_output_param {
1778 | num_classes: 2
1779 | share_location: true
1780 | background_label_id: 0
1781 | nms_param {
1782 | nms_threshold: 0.45
1783 | top_k: 400
1784 | }
1785 | code_type: CENTER_SIZE
1786 | keep_top_k: 200
1787 | confidence_threshold: 0.01
1788 | }
1789 | }
1790 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.5
2 | opencv-python==4.2.0.32
3 | opencv-contrib-python==4.2.0.32
4 | scipy==1.5.0
5 | matplotlib==3.2.2
6 | torch==1.5.1
7 | torchvision==0.6.1
8 | onnx==1.7.0
9 | onnxruntime==1.2.0
10 |
--------------------------------------------------------------------------------
/src/1-Explore Dataset.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from torchvision import transforms\n",
10 | "import torch\n",
11 | "import torch.optim as optim\n",
12 | "from torch.utils.data import random_split, DataLoader\n",
13 | "import time\n",
14 | "import cv2\n",
15 | "import numpy as np\n",
16 | "import os\n",
17 | "import matplotlib.pyplot as plt\n",
18 | "#Local Imports\n",
19 | "from transforms import Normalize,SequenceRandomTransform,ToTensor\n",
20 | "from dataset import HeadposeDataset, DatasetFromSubset\n",
21 | "from model import FSANet\n",
22 | "from utils import draw_axis\n",
23 | "\n",
24 | "%matplotlib inline"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "metadata": {},
31 | "outputs": [
32 | {
33 | "name": "stdout",
34 | "output_type": "stream",
35 | "text": [
36 | "x (images) shape: (122415, 64, 64, 3)\n",
37 | "y (poses) shape: (122415, 3)\n"
38 | ]
39 | }
40 | ],
41 | "source": [
42 | "augmentation = SequenceRandomTransform()\n",
43 | "\n",
44 | "data_path = '../data/type1/train'\n",
45 | "\n",
46 | "hdb = HeadposeDataset(data_path,transform=None)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 4,
52 | "metadata": {},
53 | "outputs": [
54 | {
55 | "name": "stdout",
56 | "output_type": "stream",
57 | "text": [
58 | "Yaw: 21.16, Pitch: -6.89, Roll: -11.83\n"
59 | ]
60 | },
61 | {
62 | "data": {
63 | "text/plain": [
64 | ""
65 | ]
66 | },
67 | "execution_count": 4,
68 | "metadata": {},
69 | "output_type": "execute_result"
70 | },
71 | {
72 | "data": {
73 | "image/png": "\n",
74 | "text/plain": [
75 | ""
76 | ]
77 | },
78 | "metadata": {},
79 | "output_type": "display_data"
80 | }
81 | ],
82 | "source": [
83 | "#Choose Input Image Index from Dataset\n",
84 | "idx = 26\n",
85 | "x,y = hdb[idx]\n",
86 | "print(f'Yaw: {y[0]:.2f}, Pitch: {y[1]:.2f}, Roll: {y[2]:.2f}')\n",
87 | "x_real = x.copy()\n",
88 | "x_aug = augmentation(x_real).copy()\n",
89 | "\n",
90 | "draw_axis(x_real,y[0],y[1],y[2],size=20)\n",
91 | "draw_axis(x_aug,y[0],y[1],y[2],size=20)\n",
92 | "\n",
93 | "fig=plt.figure(figsize=(10, 10), dpi= 80, facecolor='w', edgecolor='k')\n",
94 | "#Draw Original Input x\n",
95 | "plt.subplot(121)\n",
96 | "plt.title('Original')\n",
97 | "plt.imshow(x_real[:,:,::-1]) #show image as rgb\n",
98 | "#Draw Augmented Input x\n",
99 | "plt.subplot(122)\n",
100 | "plt.title('Augmented')\n",
101 | "plt.imshow(x_aug[:,:,::-1])"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": null,
107 | "metadata": {},
108 | "outputs": [],
109 | "source": []
110 | }
111 | ],
112 | "metadata": {
113 | "kernelspec": {
114 | "display_name": "Python 3",
115 | "language": "python",
116 | "name": "python3"
117 | },
118 | "language_info": {
119 | "codemirror_mode": {
120 | "name": "ipython",
121 | "version": 3
122 | },
123 | "file_extension": ".py",
124 | "mimetype": "text/x-python",
125 | "name": "python",
126 | "nbconvert_exporter": "python",
127 | "pygments_lexer": "ipython3",
128 | "version": "3.7.6"
129 | }
130 | },
131 | "nbformat": 4,
132 | "nbformat_minor": 4
133 | }
134 |
--------------------------------------------------------------------------------
/src/3-Test Model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Test Model\n",
8 | "\n",
9 | "In this notebook, we evaluate our model's performance on the test dataset.\n",
10 | "\n",
11 | "**1. Import Required Libraries:-** "
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 1,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from torchvision import transforms\n",
21 | "import torch\n",
22 | "import torch.optim as optim\n",
23 | "from torch.utils.data import random_split, DataLoader\n",
24 | "import time\n",
25 | "import cv2\n",
26 | "import numpy as np\n",
27 | "import os\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "#Local Imports\n",
30 | "from transforms import Normalize,ToTensor\n",
31 | "from dataset import HeadposeDataset\n",
32 | "from model import FSANet\n",
33 | "\n",
34 | "%matplotlib inline"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "**2. Compose Preprocessing Transform and Create Dataset:-**"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 2,
47 | "metadata": {},
48 | "outputs": [
49 | {
50 | "name": "stdout",
51 | "output_type": "stream",
52 | "text": [
53 | "x (images) shape: (1969, 64, 64, 3)\n",
54 | "y (poses) shape: (1969, 3)\n"
55 | ]
56 | }
57 | ],
58 | "source": [
59 | "transform = transforms.Compose([\n",
60 | " Normalize(mean=127.5,std=128),\n",
61 | " ToTensor()\n",
62 | " ])\n",
63 | "\n",
64 | "data_path = '../data/type1/test'\n",
65 | "\n",
66 | "hdb = HeadposeDataset(data_path,transform=transform)"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {},
72 | "source": [
73 | "**3. Create Dataloader to batch over test dataset:-**"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 3,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "#Setup dataloader for the test dataset\n",
83 | "batch_size = 64\n",
84 | "\n",
85 | "test_loader = DataLoader(hdb, \n",
86 | " batch_size=batch_size,\n",
87 | " shuffle=False)"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "metadata": {},
93 | "source": [
94 | "**4. Define Model and Load from Saved Checkpoint:-**"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 4,
100 | "metadata": {},
101 | "outputs": [
102 | {
103 | "data": {
104 | "text/plain": [
105 | ""
106 | ]
107 | },
108 | "execution_count": 4,
109 | "metadata": {},
110 | "output_type": "execute_result"
111 | }
112 | ],
113 | "source": [
114 | "model = FSANet(var=False)\n",
115 | "#Load Model Checkpoint\n",
116 | "chkpt_dic = torch.load('../checkpoints/fsa1x1-08082020.chkpt')\n",
117 | "model.load_state_dict(chkpt_dic['best_states']['model'])"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "**5. Place Model on GPU:-**"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 5,
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "data": {
134 | "text/plain": [
135 | "FSANet(\n",
136 | " (msms): MultiStreamMultiStage(\n",
137 | " (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
138 | " (s0_conv0): SepConvBlock(\n",
139 | " (conv): SepConv2d(\n",
140 | " (depthwise): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3)\n",
141 | " (pointwise): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))\n",
142 | " )\n",
143 | " (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
144 | " (act): ReLU()\n",
145 | " )\n",
146 | " (s0_conv1_0): SepConvBlock(\n",
147 | " (conv): SepConv2d(\n",
148 | " (depthwise): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16)\n",
149 | " (pointwise): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
150 | " )\n",
151 | " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
152 | " (act): ReLU()\n",
153 | " )\n",
154 | " (s0_conv1_1): SepConvBlock(\n",
155 | " (conv): SepConv2d(\n",
156 | " (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n",
157 | " (pointwise): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
158 | " )\n",
159 | " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
160 | " (act): ReLU()\n",
161 | " )\n",
162 | " (s0_conv1_out): Conv2dAct(\n",
163 | " (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n",
164 | " (act): ReLU()\n",
165 | " )\n",
166 | " (s0_conv2_0): SepConvBlock(\n",
167 | " (conv): SepConv2d(\n",
168 | " (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n",
169 | " (pointwise): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n",
170 | " )\n",
171 | " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
172 | " (act): ReLU()\n",
173 | " )\n",
174 | " (s0_conv2_1): SepConvBlock(\n",
175 | " (conv): SepConv2d(\n",
176 | " (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n",
177 | " (pointwise): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
178 | " )\n",
179 | " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
180 | " (act): ReLU()\n",
181 | " )\n",
182 | " (s0_conv2_out): Conv2dAct(\n",
183 | " (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
184 | " (act): ReLU()\n",
185 | " )\n",
186 | " (s0_conv3_0): SepConvBlock(\n",
187 | " (conv): SepConv2d(\n",
188 | " (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n",
189 | " (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n",
190 | " )\n",
191 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
192 | " (act): ReLU()\n",
193 | " )\n",
194 | " (s0_conv3_1): SepConvBlock(\n",
195 | " (conv): SepConv2d(\n",
196 | " (depthwise): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)\n",
197 | " (pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
198 | " )\n",
199 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
200 | " (act): ReLU()\n",
201 | " )\n",
202 | " (s0_conv3_out): Conv2dAct(\n",
203 | " (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))\n",
204 | " (act): ReLU()\n",
205 | " )\n",
206 | " (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
207 | " (s1_conv0): SepConvBlock(\n",
208 | " (conv): SepConv2d(\n",
209 | " (depthwise): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3)\n",
210 | " (pointwise): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))\n",
211 | " )\n",
212 | " (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
213 | " (act): ReLU()\n",
214 | " )\n",
215 | " (s1_conv1_0): SepConvBlock(\n",
216 | " (conv): SepConv2d(\n",
217 | " (depthwise): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16)\n",
218 | " (pointwise): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
219 | " )\n",
220 | " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
221 | " (act): Tanh()\n",
222 | " )\n",
223 | " (s1_conv1_1): SepConvBlock(\n",
224 | " (conv): SepConv2d(\n",
225 | " (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n",
226 | " (pointwise): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
227 | " )\n",
228 | " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
229 | " (act): Tanh()\n",
230 | " )\n",
231 | " (s1_conv1_out): Conv2dAct(\n",
232 | " (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n",
233 | " (act): Tanh()\n",
234 | " )\n",
235 | " (s1_conv2_0): SepConvBlock(\n",
236 | " (conv): SepConv2d(\n",
237 | " (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n",
238 | " (pointwise): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n",
239 | " )\n",
240 | " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
241 | " (act): Tanh()\n",
242 | " )\n",
243 | " (s1_conv2_1): SepConvBlock(\n",
244 | " (conv): SepConv2d(\n",
245 | " (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n",
246 | " (pointwise): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
247 | " )\n",
248 | " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
249 | " (act): Tanh()\n",
250 | " )\n",
251 | " (s1_conv2_out): Conv2dAct(\n",
252 | " (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
253 | " (act): Tanh()\n",
254 | " )\n",
255 | " (s1_conv3_0): SepConvBlock(\n",
256 | " (conv): SepConv2d(\n",
257 | " (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n",
258 | " (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n",
259 | " )\n",
260 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
261 | " (act): Tanh()\n",
262 | " )\n",
263 | " (s1_conv3_1): SepConvBlock(\n",
264 | " (conv): SepConv2d(\n",
265 | " (depthwise): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)\n",
266 | " (pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
267 | " )\n",
268 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
269 | " (act): Tanh()\n",
270 | " )\n",
271 | " (s1_conv3_out): Conv2dAct(\n",
272 | " (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))\n",
273 | " (act): Tanh()\n",
274 | " )\n",
275 | " )\n",
276 | " (fgsm): FineGrainedStructureMapping(\n",
277 | " (attention_maps): ScoringFunction(\n",
278 | " (reduce_channel): Conv2dAct(\n",
279 | " (conv): Conv2d(64, 1, kernel_size=(1, 1), stride=(1, 1))\n",
280 | " (act): Sigmoid()\n",
281 | " )\n",
282 | " )\n",
283 | " (fm): Linear(in_features=64, out_features=960, bias=True)\n",
284 | " (fc): Linear(in_features=192, out_features=35, bias=True)\n",
285 | " )\n",
286 | " (caps_layer): CapsuleLayer1d()\n",
287 | " (eaf): ExtractAggregatedFeatures()\n",
288 | " (esp_s1): ExtractSSRParams(\n",
289 | " (shift_fc): Linear(in_features=4, out_features=3, bias=True)\n",
290 | " (scale_fc): Linear(in_features=4, out_features=3, bias=True)\n",
291 | " (pred_fc): Linear(in_features=8, out_features=9, bias=True)\n",
292 | " )\n",
293 | " (esp_s2): ExtractSSRParams(\n",
294 | " (shift_fc): Linear(in_features=4, out_features=3, bias=True)\n",
295 | " (scale_fc): Linear(in_features=4, out_features=3, bias=True)\n",
296 | " (pred_fc): Linear(in_features=8, out_features=9, bias=True)\n",
297 | " )\n",
298 | " (esp_s3): ExtractSSRParams(\n",
299 | " (shift_fc): Linear(in_features=4, out_features=3, bias=True)\n",
300 | " (scale_fc): Linear(in_features=4, out_features=3, bias=True)\n",
301 | " (pred_fc): Linear(in_features=8, out_features=9, bias=True)\n",
302 | " )\n",
303 | " (ssr): SSRLayer()\n",
304 | ")"
305 | ]
306 | },
307 | "execution_count": 5,
308 | "metadata": {},
309 | "output_type": "execute_result"
310 | }
311 | ],
312 | "source": [
313 | "#Move the model to the GPU\n",
314 | "device = torch.device(\"cuda\")\n",
315 | "model.to(device)"
316 | ]
317 | },
318 | {
319 | "cell_type": "markdown",
320 | "metadata": {},
321 | "source": [
322 | "**6. Define Testing Function:-**"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 6,
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "def test_net():\n",
332 | " yaw_loss = []\n",
333 | " pitch_loss = []\n",
334 | " roll_loss = []\n",
335 | " model.eval()\n",
336 | " with torch.no_grad():\n",
337 | " for batch_i, data in enumerate(test_loader):\n",
338 | " # get the input images and their corresponding poses\n",
339 | " images,gt_poses = data\n",
340 | "\n",
341 | " # put data inside gpu\n",
342 | " images = images.float().to(device)\n",
343 | " gt_poses = gt_poses.float().to(device)\n",
344 | "\n",
345 | " # call model forward pass\n",
346 | " predicted_poses = model(images)\n",
347 | "\n",
348 | " abs_loss = torch.abs(gt_poses-predicted_poses)\n",
349 | " \n",
350 | " abs_loss = abs_loss.cpu().numpy().mean(axis=0)\n",
351 | " \n",
352 | " yaw_loss.append(abs_loss[0])\n",
353 | " pitch_loss.append(abs_loss[1])\n",
354 | " roll_loss.append(abs_loss[2])\n",
355 | " \n",
356 | " yaw_loss = np.mean(yaw_loss)\n",
357 | " pitch_loss = np.mean(pitch_loss) \n",
358 | " roll_loss = np.mean(roll_loss)\n",
359 | " print('Mean Absolute Error:-')\n",
360 | " print(f'Yaw: {yaw_loss:.2f}, Pitch: {pitch_loss:.2f}, Roll: {roll_loss:.2f}')"
361 | ]
362 | },
363 | {
364 | "cell_type": "markdown",
365 | "metadata": {},
366 | "source": [
367 | "**7. Test Model and print Mean Absolute Error:-**"
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": 7,
373 | "metadata": {},
374 | "outputs": [
375 | {
376 | "name": "stdout",
377 | "output_type": "stream",
378 | "text": [
379 | "Mean Absolute Error:-\n",
380 | "Yaw: 4.85, Pitch: 6.27, Roll: 4.96\n"
381 | ]
382 | }
383 | ],
384 | "source": [
385 | "try:\n",
386 | " # test your network\n",
387 | " test_net()\n",
388 | "except KeyboardInterrupt:\n",
389 | " print('Stopping Testing...')"
390 | ]
391 | },
392 | {
393 | "cell_type": "code",
394 | "execution_count": null,
395 | "metadata": {},
396 | "outputs": [],
397 | "source": []
398 | }
399 | ],
400 | "metadata": {
401 | "kernelspec": {
402 | "display_name": "Python 3",
403 | "language": "python",
404 | "name": "python3"
405 | },
406 | "language_info": {
407 | "codemirror_mode": {
408 | "name": "ipython",
409 | "version": 3
410 | },
411 | "file_extension": ".py",
412 | "mimetype": "text/x-python",
413 | "name": "python",
414 | "nbconvert_exporter": "python",
415 | "pygments_lexer": "ipython3",
416 | "version": "3.7.6"
417 | }
418 | },
419 | "nbformat": 4,
420 | "nbformat_minor": 2
421 | }
422 |
--------------------------------------------------------------------------------
/src/4-Export to Onnx.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Export to ONNX\n",
8 | "\n",
9 | "In this notebook, we export our PyTorch model to ONNX so that it can later be used for inference.\n",
10 | "\n",
11 | "**1. Import Required Libraries:-** "
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 1,
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from torchvision import transforms\n",
21 | "import torch\n",
22 | "import torch.optim as optim\n",
23 | "from torch.utils.data import random_split, DataLoader\n",
24 | "import time\n",
25 | "import cv2\n",
26 | "import numpy as np\n",
27 | "import os\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "#Local Imports\n",
30 | "from dataset import HeadposeDataset\n",
31 | "from model import FSANet\n",
32 | "import onnx\n",
33 | "import onnxruntime\n"
34 | ]
35 | },
36 | {
37 | "cell_type": "markdown",
38 | "metadata": {},
39 | "source": [
40 | "**2. Define Model and Load from Saved Checkpoint:-**"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 2,
46 | "metadata": {},
47 | "outputs": [
48 | {
49 | "data": {
50 | "text/plain": [
51 | "FSANet(\n",
52 | " (msms): MultiStreamMultiStage(\n",
53 | " (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n",
54 | " (s0_conv0): SepConvBlock(\n",
55 | " (conv): SepConv2d(\n",
56 | " (depthwise): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3)\n",
57 | " (pointwise): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))\n",
58 | " )\n",
59 | " (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
60 | " (act): ReLU()\n",
61 | " )\n",
62 | " (s0_conv1_0): SepConvBlock(\n",
63 | " (conv): SepConv2d(\n",
64 | " (depthwise): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16)\n",
65 | " (pointwise): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
66 | " )\n",
67 | " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
68 | " (act): ReLU()\n",
69 | " )\n",
70 | " (s0_conv1_1): SepConvBlock(\n",
71 | " (conv): SepConv2d(\n",
72 | " (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n",
73 | " (pointwise): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
74 | " )\n",
75 | " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
76 | " (act): ReLU()\n",
77 | " )\n",
78 | " (s0_conv1_out): Conv2dAct(\n",
79 | " (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n",
80 | " (act): ReLU()\n",
81 | " )\n",
82 | " (s0_conv2_0): SepConvBlock(\n",
83 | " (conv): SepConv2d(\n",
84 | " (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n",
85 | " (pointwise): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n",
86 | " )\n",
87 | " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
88 | " (act): ReLU()\n",
89 | " )\n",
90 | " (s0_conv2_1): SepConvBlock(\n",
91 | " (conv): SepConv2d(\n",
92 | " (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n",
93 | " (pointwise): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
94 | " )\n",
95 | " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
96 | " (act): ReLU()\n",
97 | " )\n",
98 | " (s0_conv2_out): Conv2dAct(\n",
99 | " (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
100 | " (act): ReLU()\n",
101 | " )\n",
102 | " (s0_conv3_0): SepConvBlock(\n",
103 | " (conv): SepConv2d(\n",
104 | " (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n",
105 | " (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n",
106 | " )\n",
107 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
108 | " (act): ReLU()\n",
109 | " )\n",
110 | " (s0_conv3_1): SepConvBlock(\n",
111 | " (conv): SepConv2d(\n",
112 | " (depthwise): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)\n",
113 | " (pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
114 | " )\n",
115 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
116 | " (act): ReLU()\n",
117 | " )\n",
118 | " (s0_conv3_out): Conv2dAct(\n",
119 | " (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))\n",
120 | " (act): ReLU()\n",
121 | " )\n",
122 | " (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
123 | " (s1_conv0): SepConvBlock(\n",
124 | " (conv): SepConv2d(\n",
125 | " (depthwise): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3)\n",
126 | " (pointwise): Conv2d(3, 16, kernel_size=(1, 1), stride=(1, 1))\n",
127 | " )\n",
128 | " (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
129 | " (act): ReLU()\n",
130 | " )\n",
131 | " (s1_conv1_0): SepConvBlock(\n",
132 | " (conv): SepConv2d(\n",
133 | " (depthwise): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16)\n",
134 | " (pointwise): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))\n",
135 | " )\n",
136 | " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
137 | " (act): Tanh()\n",
138 | " )\n",
139 | " (s1_conv1_1): SepConvBlock(\n",
140 | " (conv): SepConv2d(\n",
141 | " (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n",
142 | " (pointwise): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))\n",
143 | " )\n",
144 | " (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
145 | " (act): Tanh()\n",
146 | " )\n",
147 | " (s1_conv1_out): Conv2dAct(\n",
148 | " (conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n",
149 | " (act): Tanh()\n",
150 | " )\n",
151 | " (s1_conv2_0): SepConvBlock(\n",
152 | " (conv): SepConv2d(\n",
153 | " (depthwise): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)\n",
154 | " (pointwise): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))\n",
155 | " )\n",
156 | " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
157 | " (act): Tanh()\n",
158 | " )\n",
159 | " (s1_conv2_1): SepConvBlock(\n",
160 | " (conv): SepConv2d(\n",
161 | " (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n",
162 | " (pointwise): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
163 | " )\n",
164 | " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
165 | " (act): Tanh()\n",
166 | " )\n",
167 | " (s1_conv2_out): Conv2dAct(\n",
168 | " (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))\n",
169 | " (act): Tanh()\n",
170 | " )\n",
171 | " (s1_conv3_0): SepConvBlock(\n",
172 | " (conv): SepConv2d(\n",
173 | " (depthwise): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)\n",
174 | " (pointwise): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))\n",
175 | " )\n",
176 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
177 | " (act): Tanh()\n",
178 | " )\n",
179 | " (s1_conv3_1): SepConvBlock(\n",
180 | " (conv): SepConv2d(\n",
181 | " (depthwise): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)\n",
182 | " (pointwise): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
183 | " )\n",
184 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
185 | " (act): Tanh()\n",
186 | " )\n",
187 | " (s1_conv3_out): Conv2dAct(\n",
188 | " (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))\n",
189 | " (act): Tanh()\n",
190 | " )\n",
191 | " )\n",
192 | " (fgsm): FineGrainedStructureMapping(\n",
193 | " (attention_maps): ScoringFunction(\n",
194 | " (reduce_channel): VarianceC()\n",
195 | " )\n",
196 | " (fm): Linear(in_features=64, out_features=960, bias=True)\n",
197 | " (fc): Linear(in_features=192, out_features=35, bias=True)\n",
198 | " )\n",
199 | " (caps_layer): CapsuleLayer1d()\n",
200 | " (eaf): ExtractAggregatedFeatures()\n",
201 | " (esp_s1): ExtractSSRParams(\n",
202 | " (shift_fc): Linear(in_features=4, out_features=3, bias=True)\n",
203 | " (scale_fc): Linear(in_features=4, out_features=3, bias=True)\n",
204 | " (pred_fc): Linear(in_features=8, out_features=9, bias=True)\n",
205 | " )\n",
206 | " (esp_s2): ExtractSSRParams(\n",
207 | " (shift_fc): Linear(in_features=4, out_features=3, bias=True)\n",
208 | " (scale_fc): Linear(in_features=4, out_features=3, bias=True)\n",
209 | " (pred_fc): Linear(in_features=8, out_features=9, bias=True)\n",
210 | " )\n",
211 | " (esp_s3): ExtractSSRParams(\n",
212 | " (shift_fc): Linear(in_features=4, out_features=3, bias=True)\n",
213 | " (scale_fc): Linear(in_features=4, out_features=3, bias=True)\n",
214 | " (pred_fc): Linear(in_features=8, out_features=9, bias=True)\n",
215 | " )\n",
216 | " (ssr): SSRLayer()\n",
217 | ")"
218 | ]
219 | },
220 | "execution_count": 2,
221 | "metadata": {},
222 | "output_type": "execute_result"
223 | }
224 | ],
225 | "source": [
226 | "device = torch.device(\"cuda\")\n",
227 | "model = FSANet(var=True).to(device)\n",
228 | "#Load Model Checkpoint\n",
229 | "chkpt_dic = torch.load('checkpoints/fsavar-09082020.chkpt')\n",
230 | "\n",
231 | "model.load_state_dict(chkpt_dic['best_states']['model'])\n",
232 | "#set model to evaluation (inference) mode\n",
233 | "model.eval()"
234 | ]
235 | },
236 | {
237 | "cell_type": "markdown",
238 | "metadata": {},
239 | "source": [
240 | "**3. Export model to ONNX:-**"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 4,
246 | "metadata": {},
247 | "outputs": [],
248 | "source": [
249 | "#Export to ONNX\n",
250 | "x = torch.randn(1,3,64,64).to(device)\n",
251 | "model_out = model(x)\n",
252 | "save_path = \"pretrained/fsanet-var-iter-688590.onnx\"\n",
253 | "\n",
254 | "torch.onnx.export(model, # model being run\n",
255 | " x, # model input (or a tuple for multiple inputs)\n",
256 | " save_path, # where to save the model (can be a file or file-like object)\n",
257 | " export_params=True, # store the trained parameter weights inside the model file\n",
258 | " opset_version=9, # the ONNX opset version to export the model to\n",
259 | " do_constant_folding=True, # whether to execute constant folding for optimization\n",
260 | " input_names = ['input'], # the model's input names\n",
261 | " output_names = ['output']) # the model's output names"
262 | ]
263 | },
264 | {
265 | "cell_type": "markdown",
266 | "metadata": {},
267 | "source": [
268 | "**4. Reload model from ONNX:-**"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 5,
274 | "metadata": {},
275 | "outputs": [
276 | {
277 | "name": "stdout",
278 | "output_type": "stream",
279 | "text": [
280 | "graph torch-jit-export (\n",
281 | " %input[FLOAT, 1x3x64x64]\n",
282 | ") initializers (\n",
283 | " %595[INT64, 1]\n",
284 | " %596[INT64, 1]\n",
285 | " %597[INT64, 1]\n",
286 | " %598[INT64, 1]\n",
287 | " %599[INT64, 1]\n",
288 | " %600[INT64, 1]\n",
289 | " %601[INT64, 1]\n",
290 | " %602[INT64, 1]\n",
291 | " %603[INT64, 1]\n",
292 | " %604[INT64, 1]\n",
293 | " %605[INT64, 1]\n",
294 | " %606[INT64, 1]\n",
295 | " %607[INT64, 1]\n",
296 | " %608[INT64, 1]\n",
297 | " %609[FLOAT, 3x21x64x16]\n",
298 | " %610[FLOAT, scalar]\n",
299 | " %611[FLOAT, scalar]\n",
300 | " %612[INT64, 1]\n",
301 | " %613[INT64, 1]\n",
302 | " %614[INT64, 1]\n",
303 | " %615[INT64, 1]\n",
304 | " %616[INT64, 1]\n",
305 | " %617[INT64, 1]\n",
306 | " %618[INT64, 1]\n",
307 | " %619[INT64, 1]\n",
308 | " %620[INT64, 1]\n",
309 | " %caps_layer.affine_w[FLOAT, 3x21x16x64]\n",
310 | " %esp_s1.pred_fc.bias[FLOAT, 9]\n",
311 | " %esp_s1.pred_fc.weight[FLOAT, 9x8]\n",
312 | " %esp_s1.scale_fc.bias[FLOAT, 3]\n",
313 | " %esp_s1.scale_fc.weight[FLOAT, 3x4]\n",
314 | " %esp_s1.shift_fc.bias[FLOAT, 3]\n",
315 | " %esp_s1.shift_fc.weight[FLOAT, 3x4]\n",
316 | " %esp_s2.pred_fc.bias[FLOAT, 9]\n",
317 | " %esp_s2.pred_fc.weight[FLOAT, 9x8]\n",
318 | " %esp_s2.scale_fc.bias[FLOAT, 3]\n",
319 | " %esp_s2.scale_fc.weight[FLOAT, 3x4]\n",
320 | " %esp_s2.shift_fc.bias[FLOAT, 3]\n",
321 | " %esp_s2.shift_fc.weight[FLOAT, 3x4]\n",
322 | " %esp_s3.pred_fc.bias[FLOAT, 9]\n",
323 | " %esp_s3.pred_fc.weight[FLOAT, 9x8]\n",
324 | " %esp_s3.scale_fc.bias[FLOAT, 3]\n",
325 | " %esp_s3.scale_fc.weight[FLOAT, 3x4]\n",
326 | " %esp_s3.shift_fc.bias[FLOAT, 3]\n",
327 | " %esp_s3.shift_fc.weight[FLOAT, 3x4]\n",
328 | " %fgsm.fc.bias[FLOAT, 35]\n",
329 | " %fgsm.fc.weight[FLOAT, 35x192]\n",
330 | " %fgsm.fm.bias[FLOAT, 960]\n",
331 | " %fgsm.fm.weight[FLOAT, 960x64]\n",
332 | " %msms.s0_conv0.bn.bias[FLOAT, 16]\n",
333 | " %msms.s0_conv0.bn.running_mean[FLOAT, 16]\n",
334 | " %msms.s0_conv0.bn.running_var[FLOAT, 16]\n",
335 | " %msms.s0_conv0.bn.weight[FLOAT, 16]\n",
336 | " %msms.s0_conv0.conv.depthwise.bias[FLOAT, 3]\n",
337 | " %msms.s0_conv0.conv.depthwise.weight[FLOAT, 3x1x3x3]\n",
338 | " %msms.s0_conv0.conv.pointwise.bias[FLOAT, 16]\n",
339 | " %msms.s0_conv0.conv.pointwise.weight[FLOAT, 16x3x1x1]\n",
340 | " %msms.s0_conv1_0.bn.bias[FLOAT, 32]\n",
341 | " %msms.s0_conv1_0.bn.running_mean[FLOAT, 32]\n",
342 | " %msms.s0_conv1_0.bn.running_var[FLOAT, 32]\n",
343 | " %msms.s0_conv1_0.bn.weight[FLOAT, 32]\n",
344 | " %msms.s0_conv1_0.conv.depthwise.bias[FLOAT, 16]\n",
345 | " %msms.s0_conv1_0.conv.depthwise.weight[FLOAT, 16x1x3x3]\n",
346 | " %msms.s0_conv1_0.conv.pointwise.bias[FLOAT, 32]\n",
347 | " %msms.s0_conv1_0.conv.pointwise.weight[FLOAT, 32x16x1x1]\n",
348 | " %msms.s0_conv1_1.bn.bias[FLOAT, 32]\n",
349 | " %msms.s0_conv1_1.bn.running_mean[FLOAT, 32]\n",
350 | " %msms.s0_conv1_1.bn.running_var[FLOAT, 32]\n",
351 | " %msms.s0_conv1_1.bn.weight[FLOAT, 32]\n",
352 | " %msms.s0_conv1_1.conv.depthwise.bias[FLOAT, 32]\n",
353 | " %msms.s0_conv1_1.conv.depthwise.weight[FLOAT, 32x1x3x3]\n",
354 | " %msms.s0_conv1_1.conv.pointwise.bias[FLOAT, 32]\n",
355 | " %msms.s0_conv1_1.conv.pointwise.weight[FLOAT, 32x32x1x1]\n",
356 | " %msms.s0_conv1_out.conv.bias[FLOAT, 64]\n",
357 | " %msms.s0_conv1_out.conv.weight[FLOAT, 64x32x1x1]\n",
358 | " %msms.s0_conv2_0.bn.bias[FLOAT, 64]\n",
359 | " %msms.s0_conv2_0.bn.running_mean[FLOAT, 64]\n",
360 | " %msms.s0_conv2_0.bn.running_var[FLOAT, 64]\n",
361 | " %msms.s0_conv2_0.bn.weight[FLOAT, 64]\n",
362 | " %msms.s0_conv2_0.conv.depthwise.bias[FLOAT, 32]\n",
363 | " %msms.s0_conv2_0.conv.depthwise.weight[FLOAT, 32x1x3x3]\n",
364 | " %msms.s0_conv2_0.conv.pointwise.bias[FLOAT, 64]\n",
365 | " %msms.s0_conv2_0.conv.pointwise.weight[FLOAT, 64x32x1x1]\n",
366 | " %msms.s0_conv2_1.bn.bias[FLOAT, 64]\n",
367 | " %msms.s0_conv2_1.bn.running_mean[FLOAT, 64]\n",
368 | " %msms.s0_conv2_1.bn.running_var[FLOAT, 64]\n",
369 | " %msms.s0_conv2_1.bn.weight[FLOAT, 64]\n",
370 | " %msms.s0_conv2_1.conv.depthwise.bias[FLOAT, 64]\n",
371 | " %msms.s0_conv2_1.conv.depthwise.weight[FLOAT, 64x1x3x3]\n",
372 | " %msms.s0_conv2_1.conv.pointwise.bias[FLOAT, 64]\n",
373 | " %msms.s0_conv2_1.conv.pointwise.weight[FLOAT, 64x64x1x1]\n",
374 | " %msms.s0_conv2_out.conv.bias[FLOAT, 64]\n",
375 | " %msms.s0_conv2_out.conv.weight[FLOAT, 64x64x1x1]\n",
376 | " %msms.s0_conv3_0.bn.bias[FLOAT, 128]\n",
377 | " %msms.s0_conv3_0.bn.running_mean[FLOAT, 128]\n",
378 | " %msms.s0_conv3_0.bn.running_var[FLOAT, 128]\n",
379 | " %msms.s0_conv3_0.bn.weight[FLOAT, 128]\n",
380 | " %msms.s0_conv3_0.conv.depthwise.bias[FLOAT, 64]\n",
381 | " %msms.s0_conv3_0.conv.depthwise.weight[FLOAT, 64x1x3x3]\n",
382 | " %msms.s0_conv3_0.conv.pointwise.bias[FLOAT, 128]\n",
383 | " %msms.s0_conv3_0.conv.pointwise.weight[FLOAT, 128x64x1x1]\n",
384 | " %msms.s0_conv3_1.bn.bias[FLOAT, 128]\n",
385 | " %msms.s0_conv3_1.bn.running_mean[FLOAT, 128]\n",
386 | " %msms.s0_conv3_1.bn.running_var[FLOAT, 128]\n",
387 | " %msms.s0_conv3_1.bn.weight[FLOAT, 128]\n",
388 | " %msms.s0_conv3_1.conv.depthwise.bias[FLOAT, 128]\n",
389 | " %msms.s0_conv3_1.conv.depthwise.weight[FLOAT, 128x1x3x3]\n",
390 | " %msms.s0_conv3_1.conv.pointwise.bias[FLOAT, 128]\n",
391 | " %msms.s0_conv3_1.conv.pointwise.weight[FLOAT, 128x128x1x1]\n",
392 | " %msms.s0_conv3_out.conv.bias[FLOAT, 64]\n",
393 | " %msms.s0_conv3_out.conv.weight[FLOAT, 64x128x1x1]\n",
394 | " %msms.s1_conv0.bn.bias[FLOAT, 16]\n",
395 | " %msms.s1_conv0.bn.running_mean[FLOAT, 16]\n",
396 | " %msms.s1_conv0.bn.running_var[FLOAT, 16]\n",
397 | " %msms.s1_conv0.bn.weight[FLOAT, 16]\n",
398 | " %msms.s1_conv0.conv.depthwise.bias[FLOAT, 3]\n",
399 | " %msms.s1_conv0.conv.depthwise.weight[FLOAT, 3x1x3x3]\n",
400 | " %msms.s1_conv0.conv.pointwise.bias[FLOAT, 16]\n",
401 | " %msms.s1_conv0.conv.pointwise.weight[FLOAT, 16x3x1x1]\n",
402 | " %msms.s1_conv1_0.bn.bias[FLOAT, 32]\n",
403 | " %msms.s1_conv1_0.bn.running_mean[FLOAT, 32]\n",
404 | " %msms.s1_conv1_0.bn.running_var[FLOAT, 32]\n",
405 | " %msms.s1_conv1_0.bn.weight[FLOAT, 32]\n",
406 | " %msms.s1_conv1_0.conv.depthwise.bias[FLOAT, 16]\n",
407 | " %msms.s1_conv1_0.conv.depthwise.weight[FLOAT, 16x1x3x3]\n",
408 | " %msms.s1_conv1_0.conv.pointwise.bias[FLOAT, 32]\n",
409 | " %msms.s1_conv1_0.conv.pointwise.weight[FLOAT, 32x16x1x1]\n",
410 | " %msms.s1_conv1_1.bn.bias[FLOAT, 32]\n",
411 | " %msms.s1_conv1_1.bn.running_mean[FLOAT, 32]\n",
412 | " %msms.s1_conv1_1.bn.running_var[FLOAT, 32]\n",
413 | " %msms.s1_conv1_1.bn.weight[FLOAT, 32]\n",
414 | " %msms.s1_conv1_1.conv.depthwise.bias[FLOAT, 32]\n",
415 | " %msms.s1_conv1_1.conv.depthwise.weight[FLOAT, 32x1x3x3]\n",
416 | " %msms.s1_conv1_1.conv.pointwise.bias[FLOAT, 32]\n",
417 | " %msms.s1_conv1_1.conv.pointwise.weight[FLOAT, 32x32x1x1]\n",
418 | " %msms.s1_conv1_out.conv.bias[FLOAT, 64]\n",
419 | " %msms.s1_conv1_out.conv.weight[FLOAT, 64x32x1x1]\n",
420 | " %msms.s1_conv2_0.bn.bias[FLOAT, 64]\n",
421 | " %msms.s1_conv2_0.bn.running_mean[FLOAT, 64]\n",
422 | " %msms.s1_conv2_0.bn.running_var[FLOAT, 64]\n",
423 | " %msms.s1_conv2_0.bn.weight[FLOAT, 64]\n",
424 | " %msms.s1_conv2_0.conv.depthwise.bias[FLOAT, 32]\n",
425 | " %msms.s1_conv2_0.conv.depthwise.weight[FLOAT, 32x1x3x3]\n",
426 | " %msms.s1_conv2_0.conv.pointwise.bias[FLOAT, 64]\n",
427 | " %msms.s1_conv2_0.conv.pointwise.weight[FLOAT, 64x32x1x1]\n",
428 | " %msms.s1_conv2_1.bn.bias[FLOAT, 64]\n",
429 | " %msms.s1_conv2_1.bn.running_mean[FLOAT, 64]\n",
430 | " %msms.s1_conv2_1.bn.running_var[FLOAT, 64]\n",
431 | " %msms.s1_conv2_1.bn.weight[FLOAT, 64]\n",
432 | " %msms.s1_conv2_1.conv.depthwise.bias[FLOAT, 64]\n",
433 | " %msms.s1_conv2_1.conv.depthwise.weight[FLOAT, 64x1x3x3]\n",
434 | " %msms.s1_conv2_1.conv.pointwise.bias[FLOAT, 64]\n",
435 | " %msms.s1_conv2_1.conv.pointwise.weight[FLOAT, 64x64x1x1]\n",
436 | " %msms.s1_conv2_out.conv.bias[FLOAT, 64]\n",
437 | " %msms.s1_conv2_out.conv.weight[FLOAT, 64x64x1x1]\n",
438 | " %msms.s1_conv3_0.bn.bias[FLOAT, 128]\n",
439 | " %msms.s1_conv3_0.bn.running_mean[FLOAT, 128]\n",
440 | " %msms.s1_conv3_0.bn.running_var[FLOAT, 128]\n",
441 | " %msms.s1_conv3_0.bn.weight[FLOAT, 128]\n",
442 | " %msms.s1_conv3_0.conv.depthwise.bias[FLOAT, 64]\n",
443 | " %msms.s1_conv3_0.conv.depthwise.weight[FLOAT, 64x1x3x3]\n",
444 | " %msms.s1_conv3_0.conv.pointwise.bias[FLOAT, 128]\n",
445 | " %msms.s1_conv3_0.conv.pointwise.weight[FLOAT, 128x64x1x1]\n",
446 | " %msms.s1_conv3_1.bn.bias[FLOAT, 128]\n",
447 | " %msms.s1_conv3_1.bn.running_mean[FLOAT, 128]\n",
448 | " %msms.s1_conv3_1.bn.running_var[FLOAT, 128]\n",
449 | " %msms.s1_conv3_1.bn.weight[FLOAT, 128]\n",
450 | " %msms.s1_conv3_1.conv.depthwise.bias[FLOAT, 128]\n",
451 | " %msms.s1_conv3_1.conv.depthwise.weight[FLOAT, 128x1x3x3]\n",
452 | " %msms.s1_conv3_1.conv.pointwise.bias[FLOAT, 128]\n",
453 | " %msms.s1_conv3_1.conv.pointwise.weight[FLOAT, 128x128x1x1]\n",
454 | " %msms.s1_conv3_out.conv.bias[FLOAT, 64]\n",
455 | " %msms.s1_conv3_out.conv.weight[FLOAT, 64x128x1x1]\n",
456 | ") {\n",
457 | " %162 = Conv[dilations = [1, 1], group = 3, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%input, %msms.s0_conv0.conv.depthwise.weight, %msms.s0_conv0.conv.depthwise.bias)\n",
458 | " %163 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%162, %msms.s0_conv0.conv.pointwise.weight, %msms.s0_conv0.conv.pointwise.bias)\n",
459 | " %164 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%163, %msms.s0_conv0.bn.weight, %msms.s0_conv0.bn.bias, %msms.s0_conv0.bn.running_mean, %msms.s0_conv0.bn.running_var)\n",
460 | " %165 = Relu(%164)\n",
461 | " %166 = Pad[mode = 'constant', pads = [0, 0, 0, 0, 0, 0, 0, 0], value = 0](%165)\n",
462 | " %167 = AveragePool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%166)\n",
463 | " %168 = Conv[dilations = [1, 1], group = 3, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%input, %msms.s1_conv0.conv.depthwise.weight, %msms.s1_conv0.conv.depthwise.bias)\n",
464 | " %169 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%168, %msms.s1_conv0.conv.pointwise.weight, %msms.s1_conv0.conv.pointwise.bias)\n",
465 | " %170 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%169, %msms.s1_conv0.bn.weight, %msms.s1_conv0.bn.bias, %msms.s1_conv0.bn.running_mean, %msms.s1_conv0.bn.running_var)\n",
466 | " %171 = Relu(%170)\n",
467 | " %172 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%171)\n",
468 | " %173 = Conv[dilations = [1, 1], group = 16, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%167, %msms.s0_conv1_0.conv.depthwise.weight, %msms.s0_conv1_0.conv.depthwise.bias)\n",
469 | " %174 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%173, %msms.s0_conv1_0.conv.pointwise.weight, %msms.s0_conv1_0.conv.pointwise.bias)\n",
470 | " %175 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%174, %msms.s0_conv1_0.bn.weight, %msms.s0_conv1_0.bn.bias, %msms.s0_conv1_0.bn.running_mean, %msms.s0_conv1_0.bn.running_var)\n",
471 | " %176 = Relu(%175)\n",
472 | " %177 = Conv[dilations = [1, 1], group = 32, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%176, %msms.s0_conv1_1.conv.depthwise.weight, %msms.s0_conv1_1.conv.depthwise.bias)\n",
473 | " %178 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%177, %msms.s0_conv1_1.conv.pointwise.weight, %msms.s0_conv1_1.conv.pointwise.bias)\n",
474 | " %179 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%178, %msms.s0_conv1_1.bn.weight, %msms.s0_conv1_1.bn.bias, %msms.s0_conv1_1.bn.running_mean, %msms.s0_conv1_1.bn.running_var)\n",
475 | " %180 = Relu(%179)\n",
476 | " %181 = Pad[mode = 'constant', pads = [0, 0, 0, 0, 0, 0, 0, 0], value = 0](%180)\n",
477 | " %182 = AveragePool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%181)\n",
478 | " %183 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%182, %msms.s0_conv1_out.conv.weight, %msms.s0_conv1_out.conv.bias)\n",
479 | " %184 = Relu(%183)\n",
480 | " %185 = Conv[dilations = [1, 1], group = 16, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%172, %msms.s1_conv1_0.conv.depthwise.weight, %msms.s1_conv1_0.conv.depthwise.bias)\n",
481 | " %186 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%185, %msms.s1_conv1_0.conv.pointwise.weight, %msms.s1_conv1_0.conv.pointwise.bias)\n",
482 | " %187 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%186, %msms.s1_conv1_0.bn.weight, %msms.s1_conv1_0.bn.bias, %msms.s1_conv1_0.bn.running_mean, %msms.s1_conv1_0.bn.running_var)\n",
483 | " %188 = Tanh(%187)\n",
484 | " %189 = Conv[dilations = [1, 1], group = 32, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%188, %msms.s1_conv1_1.conv.depthwise.weight, %msms.s1_conv1_1.conv.depthwise.bias)\n",
485 | " %190 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%189, %msms.s1_conv1_1.conv.pointwise.weight, %msms.s1_conv1_1.conv.pointwise.bias)\n",
486 | " %191 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%190, %msms.s1_conv1_1.bn.weight, %msms.s1_conv1_1.bn.bias, %msms.s1_conv1_1.bn.running_mean, %msms.s1_conv1_1.bn.running_var)\n",
487 | " %192 = Tanh(%191)\n",
488 | " %193 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%192)\n",
489 | " %194 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%193, %msms.s1_conv1_out.conv.weight, %msms.s1_conv1_out.conv.bias)\n",
490 | " %195 = Tanh(%194)\n",
491 | " %196 = Mul(%184, %195)\n",
492 | " %197 = Pad[mode = 'constant', pads = [0, 0, 0, 0, 0, 0, 0, 0], value = 0](%196)\n",
493 | " %198 = AveragePool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%197)\n",
494 | " %199 = Conv[dilations = [1, 1], group = 32, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%182, %msms.s0_conv2_0.conv.depthwise.weight, %msms.s0_conv2_0.conv.depthwise.bias)\n",
495 | " %200 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%199, %msms.s0_conv2_0.conv.pointwise.weight, %msms.s0_conv2_0.conv.pointwise.bias)\n",
496 | " %201 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%200, %msms.s0_conv2_0.bn.weight, %msms.s0_conv2_0.bn.bias, %msms.s0_conv2_0.bn.running_mean, %msms.s0_conv2_0.bn.running_var)\n",
497 | " %202 = Relu(%201)\n",
498 | " %203 = Conv[dilations = [1, 1], group = 64, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%202, %msms.s0_conv2_1.conv.depthwise.weight, %msms.s0_conv2_1.conv.depthwise.bias)\n",
499 | " %204 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%203, %msms.s0_conv2_1.conv.pointwise.weight, %msms.s0_conv2_1.conv.pointwise.bias)\n",
500 | " %205 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%204, %msms.s0_conv2_1.bn.weight, %msms.s0_conv2_1.bn.bias, %msms.s0_conv2_1.bn.running_mean, %msms.s0_conv2_1.bn.running_var)\n",
501 | " %206 = Relu(%205)\n",
502 | " %207 = Pad[mode = 'constant', pads = [0, 0, 0, 0, 0, 0, 0, 0], value = 0](%206)\n",
503 | " %208 = AveragePool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%207)\n",
504 | " %209 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%208, %msms.s0_conv2_out.conv.weight, %msms.s0_conv2_out.conv.bias)\n",
505 | " %210 = Relu(%209)\n",
506 | " %211 = Conv[dilations = [1, 1], group = 32, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%193, %msms.s1_conv2_0.conv.depthwise.weight, %msms.s1_conv2_0.conv.depthwise.bias)\n",
507 | " %212 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%211, %msms.s1_conv2_0.conv.pointwise.weight, %msms.s1_conv2_0.conv.pointwise.bias)\n",
508 | " %213 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%212, %msms.s1_conv2_0.bn.weight, %msms.s1_conv2_0.bn.bias, %msms.s1_conv2_0.bn.running_mean, %msms.s1_conv2_0.bn.running_var)\n",
509 | " %214 = Tanh(%213)\n",
510 | " %215 = Conv[dilations = [1, 1], group = 64, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%214, %msms.s1_conv2_1.conv.depthwise.weight, %msms.s1_conv2_1.conv.depthwise.bias)\n",
511 | " %216 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%215, %msms.s1_conv2_1.conv.pointwise.weight, %msms.s1_conv2_1.conv.pointwise.bias)\n",
512 | " %217 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%216, %msms.s1_conv2_1.bn.weight, %msms.s1_conv2_1.bn.bias, %msms.s1_conv2_1.bn.running_mean, %msms.s1_conv2_1.bn.running_var)\n",
513 | " %218 = Tanh(%217)\n",
514 | " %219 = MaxPool[kernel_shape = [2, 2], pads = [0, 0, 0, 0], strides = [2, 2]](%218)\n",
515 | " %220 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%219, %msms.s1_conv2_out.conv.weight, %msms.s1_conv2_out.conv.bias)\n",
516 | " %221 = Tanh(%220)\n",
517 | " %222 = Mul(%210, %221)\n",
518 | " %223 = Conv[dilations = [1, 1], group = 64, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%208, %msms.s0_conv3_0.conv.depthwise.weight, %msms.s0_conv3_0.conv.depthwise.bias)\n",
519 | " %224 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%223, %msms.s0_conv3_0.conv.pointwise.weight, %msms.s0_conv3_0.conv.pointwise.bias)\n",
520 | " %225 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%224, %msms.s0_conv3_0.bn.weight, %msms.s0_conv3_0.bn.bias, %msms.s0_conv3_0.bn.running_mean, %msms.s0_conv3_0.bn.running_var)\n",
521 | " %226 = Relu(%225)\n",
522 | " %227 = Conv[dilations = [1, 1], group = 128, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%226, %msms.s0_conv3_1.conv.depthwise.weight, %msms.s0_conv3_1.conv.depthwise.bias)\n",
523 | " %228 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%227, %msms.s0_conv3_1.conv.pointwise.weight, %msms.s0_conv3_1.conv.pointwise.bias)\n",
524 | " %229 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%228, %msms.s0_conv3_1.bn.weight, %msms.s0_conv3_1.bn.bias, %msms.s0_conv3_1.bn.running_mean, %msms.s0_conv3_1.bn.running_var)\n",
525 | " %230 = Relu(%229)\n",
526 | " %231 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%230, %msms.s0_conv3_out.conv.weight, %msms.s0_conv3_out.conv.bias)\n",
527 | " %232 = Relu(%231)\n",
528 | " %233 = Conv[dilations = [1, 1], group = 64, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%219, %msms.s1_conv3_0.conv.depthwise.weight, %msms.s1_conv3_0.conv.depthwise.bias)\n",
529 | " %234 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%233, %msms.s1_conv3_0.conv.pointwise.weight, %msms.s1_conv3_0.conv.pointwise.bias)\n",
530 | " %235 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%234, %msms.s1_conv3_0.bn.weight, %msms.s1_conv3_0.bn.bias, %msms.s1_conv3_0.bn.running_mean, %msms.s1_conv3_0.bn.running_var)\n",
531 | " %236 = Tanh(%235)\n",
532 | " %237 = Conv[dilations = [1, 1], group = 128, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%236, %msms.s1_conv3_1.conv.depthwise.weight, %msms.s1_conv3_1.conv.depthwise.bias)\n",
533 | " %238 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%237, %msms.s1_conv3_1.conv.pointwise.weight, %msms.s1_conv3_1.conv.pointwise.bias)\n",
534 | " %239 = BatchNormalization[epsilon = 9.99999974737875e-06, momentum = 0.899999976158142](%238, %msms.s1_conv3_1.bn.weight, %msms.s1_conv3_1.bn.bias, %msms.s1_conv3_1.bn.running_mean, %msms.s1_conv3_1.bn.running_var)\n",
535 | " %240 = Tanh(%239)\n",
536 | " %241 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [1, 1]](%240, %msms.s1_conv3_out.conv.weight, %msms.s1_conv3_out.conv.bias)\n",
537 | " %242 = Tanh(%241)\n",
538 | " %243 = Mul(%232, %242)\n",
539 | " %244 = ReduceMean[axes = [1], keepdims = 1](%243)\n",
540 | " %245 = Sub(%243, %244)\n",
541 | " %246 = Mul(%245, %245)\n",
542 | " %247 = ReduceMean[axes = [1], keepdims = 1](%246)\n",
543 | " %248 = Shape(%247)\n",
544 | " %249 = Constant[value = ]()\n",
545 | " %250 = Gather[axis = 0](%248, %249)\n",
546 | " %252 = Unsqueeze[axes = [0]](%250)\n",
547 | " %254 = Concat[axis = 0](%252, %595)\n",
548 | " %255 = Reshape(%247, %254)\n",
549 | " %256 = ReduceMean[axes = [1], keepdims = 1](%222)\n",
550 | " %257 = Sub(%222, %256)\n",
551 | " %258 = Mul(%257, %257)\n",
552 | " %259 = ReduceMean[axes = [1], keepdims = 1](%258)\n",
553 | " %260 = Shape(%259)\n",
554 | " %261 = Constant[value = ]()\n",
555 | " %262 = Gather[axis = 0](%260, %261)\n",
556 | " %264 = Unsqueeze[axes = [0]](%262)\n",
557 | " %266 = Concat[axis = 0](%264, %596)\n",
558 | " %267 = Reshape(%259, %266)\n",
559 | " %268 = ReduceMean[axes = [1], keepdims = 1](%198)\n",
560 | " %269 = Sub(%198, %268)\n",
561 | " %270 = Mul(%269, %269)\n",
562 | " %271 = ReduceMean[axes = [1], keepdims = 1](%270)\n",
563 | " %272 = Shape(%271)\n",
564 | " %273 = Constant[value = ]()\n",
565 | " %274 = Gather[axis = 0](%272, %273)\n",
566 | " %276 = Unsqueeze[axes = [0]](%274)\n",
567 | " %278 = Concat[axis = 0](%276, %597)\n",
568 | " %279 = Reshape(%271, %278)\n",
569 | " %280 = Concat[axis = 1](%255, %267, %279)\n",
570 | " %281 = Gemm[alpha = 1, beta = 1, transB = 1](%280, %fgsm.fc.weight, %fgsm.fc.bias)\n",
571 | " %282 = Sigmoid(%281)\n",
572 | " %283 = Shape(%282)\n",
573 | " %284 = Constant[value = ]()\n",
574 | " %285 = Gather[axis = 0](%283, %284)\n",
575 | " %288 = Unsqueeze[axes = [0]](%285)\n",
576 | " %291 = Concat[axis = 0](%288, %598, %599)\n",
577 | " %292 = Reshape(%282, %291)\n",
578 | " %293 = Gemm[alpha = 1, beta = 1, transB = 1](%255, %fgsm.fm.weight, %fgsm.fm.bias)\n",
579 | " %294 = Sigmoid(%293)\n",
580 | " %295 = Shape(%294)\n",
581 | " %296 = Constant[value = ]()\n",
582 | " %297 = Gather[axis = 0](%295, %296)\n",
583 | " %300 = Unsqueeze[axes = [0]](%297)\n",
584 | " %303 = Concat[axis = 0](%300, %600, %601)\n",
585 | " %304 = Reshape(%294, %303)\n",
586 | " %305 = Gemm[alpha = 1, beta = 1, transB = 1](%267, %fgsm.fm.weight, %fgsm.fm.bias)\n",
587 | " %306 = Sigmoid(%305)\n",
588 | " %307 = Shape(%306)\n",
589 | " %308 = Constant[value = ]()\n",
590 | " %309 = Gather[axis = 0](%307, %308)\n",
591 | " %312 = Unsqueeze[axes = [0]](%309)\n",
592 | " %315 = Concat[axis = 0](%312, %602, %603)\n",
593 | " %316 = Reshape(%306, %315)\n",
594 | " %317 = Gemm[alpha = 1, beta = 1, transB = 1](%279, %fgsm.fm.weight, %fgsm.fm.bias)\n",
595 | " %318 = Sigmoid(%317)\n",
596 | " %319 = Shape(%318)\n",
597 | " %320 = Constant[value = ]()\n",
598 | " %321 = Gather[axis = 0](%319, %320)\n",
599 | " %324 = Unsqueeze[axes = [0]](%321)\n",
600 | " %327 = Concat[axis = 0](%324, %604, %605)\n",
601 | " %328 = Reshape(%318, %327)\n",
602 | " %329 = MatMul(%292, %304)\n",
603 | " %330 = MatMul(%292, %316)\n",
604 | " %331 = MatMul(%292, %328)\n",
605 | " %332 = Shape(%243)\n",
606 | " %333 = Constant[value = ]()\n",
607 | " %334 = Gather[axis = 0](%332, %333)\n",
608 | " %335 = Shape(%243)\n",
609 | " %336 = Constant[value = ]()\n",
610 | " %337 = Gather[axis = 0](%335, %336)\n",
611 | " %338 = Shape(%243)\n",
612 | " %339 = Constant[value = ]()\n",
613 | " %340 = Gather[axis = 0](%338, %339)\n",
614 | " %341 = Mul(%337, %340)\n",
615 | " %344 = Unsqueeze[axes = [0]](%341)\n",
616 | " %345 = Unsqueeze[axes = [0]](%334)\n",
617 | " %346 = Concat[axis = 0](%606, %344, %345)\n",
618 | " %347 = Reshape(%243, %346)\n",
619 | " %348 = Mul(%337, %340)\n",
620 | " %351 = Unsqueeze[axes = [0]](%348)\n",
621 | " %352 = Unsqueeze[axes = [0]](%334)\n",
622 | " %353 = Concat[axis = 0](%607, %351, %352)\n",
623 | " %354 = Reshape(%222, %353)\n",
624 | " %355 = Mul(%337, %340)\n",
625 | " %358 = Unsqueeze[axes = [0]](%355)\n",
626 | " %359 = Unsqueeze[axes = [0]](%334)\n",
627 | " %360 = Concat[axis = 0](%608, %358, %359)\n",
628 | " %361 = Reshape(%198, %360)\n",
629 | " %362 = Concat[axis = 1](%347, %354, %361)\n",
630 | " %363 = MatMul(%329, %362)\n",
631 | " %364 = MatMul(%330, %362)\n",
632 | " %365 = MatMul(%331, %362)\n",
633 | " %366 = ReduceSum[axes = [-1], keepdims = 1](%329)\n",
634 | " %367 = Constant[value = ]()\n",
635 | " %368 = Add(%366, %367)\n",
636 | " %369 = ReduceSum[axes = [-1], keepdims = 1](%330)\n",
637 | " %370 = Constant[value = ]()\n",
638 | " %371 = Add(%369, %370)\n",
639 | " %372 = ReduceSum[axes = [-1], keepdims = 1](%331)\n",
640 | " %373 = Constant[value = ]()\n",
641 | " %374 = Add(%372, %373)\n",
642 | " %375 = Div(%363, %368)\n",
643 | " %376 = Div(%364, %371)\n",
644 | " %377 = Div(%365, %374)\n",
645 | " %378 = Concat[axis = 1](%375, %376, %377)\n",
646 | " %379 = Unsqueeze[axes = [1]](%378)\n",
647 | " %380 = Unsqueeze[axes = [3]](%379)\n",
648 | " %382 = MatMul(%380, %609)\n",
649 | " %383 = Squeeze[axes = [3]](%382)\n",
650 | " %384 = Shape(%caps_layer.affine_w)\n",
651 | " %385 = Constant[value = ]()\n",
652 | " %386 = Gather[axis = 0](%384, %385)\n",
653 | " %387 = Shape(%caps_layer.affine_w)\n",
654 | " %388 = Constant[value = ]()\n",
655 | " %389 = Gather[axis = 0](%387, %388)\n",
656 | " %390 = Shape(%383)\n",
657 | " %391 = Constant[value = ]()\n",
658 | " %392 = Gather[axis = 0](%390, %391)\n",
659 | " %393 = Unsqueeze[axes = [0]](%392)\n",
660 | " %394 = Unsqueeze[axes = [0]](%386)\n",
661 | " %395 = Unsqueeze[axes = [0]](%389)\n",
662 | " %396 = Concat[axis = 0](%393, %394, %395)\n",
663 | " %397 = ConstantOfShape[value = ](%396)\n",
664 | " %398 = Cast[to = 1](%397)\n",
665 | " %399 = Exp(%398)\n",
666 | " %400 = ReduceSum[axes = [1]](%399)\n",
667 | " %401 = Div(%399, %400)\n",
668 | " %402 = Unsqueeze[axes = [2]](%401)\n",
669 | " %403 = MatMul(%402, %383)\n",
670 | " %406 = Pow(%403, %610)\n",
671 | " %407 = ReduceSum[axes = [-1], keepdims = 1](%406)\n",
672 | " %408 = Constant[value = ]()\n",
673 | " %409 = Add(%407, %408)\n",
674 | " %410 = Div(%407, %409)\n",
675 | " %411 = Mul(%410, %403)\n",
676 | " %412 = Sqrt(%407)\n",
677 | " %413 = Constant[value = ]()\n",
678 | " %414 = Add(%412, %413)\n",
679 | " %415 = Div(%411, %414)\n",
680 | " %416 = Transpose[perm = [0, 1, 3, 2]](%415)\n",
681 | " %417 = MatMul(%383, %416)\n",
682 | " %418 = Squeeze[axes = [3]](%417)\n",
683 | " %419 = Add(%398, %418)\n",
684 | " %420 = Exp(%419)\n",
685 | " %421 = ReduceSum[axes = [1]](%420)\n",
686 | " %422 = Div(%420, %421)\n",
687 | " %423 = Unsqueeze[axes = [2]](%422)\n",
688 | " %424 = MatMul(%423, %383)\n",
689 | " %427 = Pow(%424, %611)\n",
690 | " %428 = ReduceSum[axes = [-1], keepdims = 1](%427)\n",
691 | " %429 = Constant[value = ]()\n",
692 | " %430 = Add(%428, %429)\n",
693 | " %431 = Div(%428, %430)\n",
694 | " %432 = Mul(%431, %424)\n",
695 | " %433 = Sqrt(%428)\n",
696 | " %434 = Constant[value = ]()\n",
697 | " %435 = Add(%433, %434)\n",
698 | " %436 = Div(%432, %435)\n",
699 | " %437 = Squeeze[axes = [2]](%436)\n",
700 | " %438 = Shape(%437)\n",
701 | " %439 = Constant[value = ]()\n",
702 | " %440 = Gather[axis = 0](%438, %439)\n",
703 | " %441 = Slice[axes = [1], ends = [1], starts = [0]](%437)\n",
704 | " %443 = Unsqueeze[axes = [0]](%440)\n",
705 | " %445 = Concat[axis = 0](%443, %612)\n",
706 | " %446 = Reshape(%441, %445)\n",
707 | " %447 = Slice[axes = [1], ends = [2], starts = [1]](%437)\n",
708 | " %449 = Unsqueeze[axes = [0]](%440)\n",
709 | " %451 = Concat[axis = 0](%449, %613)\n",
710 | " %452 = Reshape(%447, %451)\n",
711 | " %453 = Slice[axes = [1], ends = [3], starts = [2]](%437)\n",
712 | " %455 = Unsqueeze[axes = [0]](%440)\n",
713 | " %457 = Concat[axis = 0](%455, %614)\n",
714 | " %458 = Reshape(%453, %457)\n",
715 | " %459 = Slice[axes = [1], ends = [4], starts = [0]](%446)\n",
716 | " %460 = Gemm[alpha = 1, beta = 1, transB = 1](%459, %esp_s1.shift_fc.weight, %esp_s1.shift_fc.bias)\n",
717 | " %461 = Tanh(%460)\n",
718 | " %462 = Slice[axes = [1], ends = [8], starts = [4]](%446)\n",
719 | " %463 = Gemm[alpha = 1, beta = 1, transB = 1](%462, %esp_s1.scale_fc.weight, %esp_s1.scale_fc.bias)\n",
720 | " %464 = Tanh(%463)\n",
721 | " %465 = Slice[axes = [1], ends = [9223372036854775807], starts = [8]](%446)\n",
722 | " %466 = Gemm[alpha = 1, beta = 1, transB = 1](%465, %esp_s1.pred_fc.weight, %esp_s1.pred_fc.bias)\n",
723 | " %467 = Relu(%466)\n",
724 | " %468 = Shape(%467)\n",
725 | " %469 = Constant[value = ]()\n",
726 | " %470 = Gather[axis = 0](%468, %469)\n",
727 | " %473 = Unsqueeze[axes = [0]](%470)\n",
728 | " %476 = Concat[axis = 0](%473, %615, %616)\n",
729 | " %477 = Reshape(%467, %476)\n",
730 | " %478 = Slice[axes = [1], ends = [4], starts = [0]](%452)\n",
731 | " %479 = Gemm[alpha = 1, beta = 1, transB = 1](%478, %esp_s2.shift_fc.weight, %esp_s2.shift_fc.bias)\n",
732 | " %480 = Tanh(%479)\n",
733 | " %481 = Slice[axes = [1], ends = [8], starts = [4]](%452)\n",
734 | " %482 = Gemm[alpha = 1, beta = 1, transB = 1](%481, %esp_s2.scale_fc.weight, %esp_s2.scale_fc.bias)\n",
735 | " %483 = Tanh(%482)\n",
736 | " %484 = Slice[axes = [1], ends = [9223372036854775807], starts = [8]](%452)\n",
737 | " %485 = Gemm[alpha = 1, beta = 1, transB = 1](%484, %esp_s2.pred_fc.weight, %esp_s2.pred_fc.bias)\n",
738 | " %486 = Relu(%485)\n",
739 | " %487 = Shape(%486)\n",
740 | " %488 = Constant[value = ]()\n",
741 | " %489 = Gather[axis = 0](%487, %488)\n",
742 | " %492 = Unsqueeze[axes = [0]](%489)\n",
743 | " %495 = Concat[axis = 0](%492, %617, %618)\n",
744 | " %496 = Reshape(%486, %495)\n",
745 | " %497 = Slice[axes = [1], ends = [4], starts = [0]](%458)\n",
746 | " %498 = Gemm[alpha = 1, beta = 1, transB = 1](%497, %esp_s3.shift_fc.weight, %esp_s3.shift_fc.bias)\n",
747 | " %499 = Tanh(%498)\n",
748 | " %500 = Slice[axes = [1], ends = [8], starts = [4]](%458)\n",
749 | " %501 = Gemm[alpha = 1, beta = 1, transB = 1](%500, %esp_s3.scale_fc.weight, %esp_s3.scale_fc.bias)\n",
750 | " %502 = Tanh(%501)\n",
751 | " %503 = Slice[axes = [1], ends = [9223372036854775807], starts = [8]](%458)\n",
752 | " %504 = Gemm[alpha = 1, beta = 1, transB = 1](%503, %esp_s3.pred_fc.weight, %esp_s3.pred_fc.bias)\n",
753 | " %505 = Relu(%504)\n",
754 | " %506 = Shape(%505)\n",
755 | " %507 = Constant[value = ]()\n",
756 | " %508 = Gather[axis = 0](%506, %507)\n",
757 | " %511 = Unsqueeze[axes = [0]](%508)\n",
758 | " %514 = Concat[axis = 0](%511, %619, %620)\n",
759 | " %515 = Reshape(%505, %514)\n",
760 | " %516 = Constant[value = ]()\n",
761 | " %517 = Add(%461, %516)\n",
762 | " %518 = Constant[value = ]()\n",
763 | " %519 = Gather[axis = 2](%477, %518)\n",
764 | " %520 = Mul(%517, %519)\n",
765 | " %521 = Constant[value = ]()\n",
766 | " %522 = Gather[axis = 2](%477, %521)\n",
767 | " %523 = Mul(%461, %522)\n",
768 | " %524 = Add(%520, %523)\n",
769 | " %525 = Constant[value = ]()\n",
770 | " %526 = Add(%461, %525)\n",
771 | " %527 = Constant[value = ]()\n",
772 | " %528 = Gather[axis = 2](%477, %527)\n",
773 | " %529 = Mul(%526, %528)\n",
774 | " %530 = Add(%524, %529)\n",
775 | " %531 = Constant[value = ]()\n",
776 | " %532 = Add(%464, %531)\n",
777 | " %533 = Constant[value = ]()\n",
778 | " %534 = Mul(%532, %533)\n",
779 | " %535 = Div(%530, %534)\n",
780 | " %536 = Constant[value = ]()\n",
781 | " %537 = Add(%480, %536)\n",
782 | " %538 = Constant[value = ]()\n",
783 | " %539 = Gather[axis = 2](%496, %538)\n",
784 | " %540 = Mul(%537, %539)\n",
785 | " %541 = Constant[value = ]()\n",
786 | " %542 = Gather[axis = 2](%496, %541)\n",
787 | " %543 = Mul(%480, %542)\n",
788 | " %544 = Add(%540, %543)\n",
789 | " %545 = Constant[value = ]()\n",
790 | " %546 = Add(%480, %545)\n",
791 | " %547 = Constant[value = ]()\n",
792 | " %548 = Gather[axis = 2](%496, %547)\n",
793 | " %549 = Mul(%546, %548)\n",
794 | " %550 = Add(%544, %549)\n",
795 | " %551 = Constant[value = ]()\n",
796 | " %552 = Add(%464, %551)\n",
797 | " %553 = Constant[value = ]()\n",
798 | " %554 = Mul(%552, %553)\n",
799 | " %555 = Div(%550, %554)\n",
800 | " %556 = Constant[value = ]()\n",
801 | " %557 = Add(%483, %556)\n",
802 | " %558 = Constant[value = ]()\n",
803 | " %559 = Mul(%557, %558)\n",
804 | " %560 = Div(%555, %559)\n",
805 | " %561 = Constant[value = ]()\n",
806 | " %562 = Add(%499, %561)\n",
807 | " %563 = Constant[value = ]()\n",
808 | " %564 = Gather[axis = 2](%515, %563)\n",
809 | " %565 = Mul(%562, %564)\n",
810 | " %566 = Constant[value = ]()\n",
811 | " %567 = Gather[axis = 2](%515, %566)\n",
812 | " %568 = Mul(%499, %567)\n",
813 | " %569 = Add(%565, %568)\n",
814 | " %570 = Constant[value = ]()\n",
815 | " %571 = Add(%499, %570)\n",
816 | " %572 = Constant[value = ]()\n",
817 | " %573 = Gather[axis = 2](%515, %572)\n",
818 | " %574 = Mul(%571, %573)\n",
819 | " %575 = Add(%569, %574)\n",
820 | " %576 = Constant[value = ]()\n",
821 | " %577 = Add(%464, %576)\n",
822 | " %578 = Constant[value = ]()\n",
823 | " %579 = Mul(%577, %578)\n",
824 | " %580 = Div(%575, %579)\n",
825 | " %581 = Constant[value = ]()\n",
826 | " %582 = Add(%483, %581)\n",
827 | " %583 = Constant[value = ]()\n",
828 | " %584 = Mul(%582, %583)\n",
829 | " %585 = Div(%580, %584)\n",
830 | " %586 = Constant[value = ]()\n",
831 | " %587 = Add(%502, %586)\n",
832 | " %588 = Constant[value = ]()\n",
833 | " %589 = Mul(%587, %588)\n",
834 | " %590 = Div(%585, %589)\n",
835 | " %591 = Add(%535, %560)\n",
836 | " %592 = Add(%591, %590)\n",
837 | " %593 = Constant[value = ]()\n",
838 | " %output = Mul(%592, %593)\n",
839 | " return %output\n",
840 | "}\n"
841 | ]
842 | }
843 | ],
844 | "source": [
845 | "#Verify ONNX model\n",
846 | "model = onnx.load(save_path)\n",
847 | "\n",
848 | "# Check that the IR is well formed\n",
849 | "onnx.checker.check_model(model)\n",
850 | "\n",
851 | "# Print a human readable representation of the graph\n",
852 | "print(onnx.helper.printable_graph(model.graph))"
853 | ]
854 | },
855 | {
856 | "cell_type": "markdown",
857 | "metadata": {},
858 | "source": [
859 | "**5. Compare ONNXRuntime and PyTorch Exported Model Output:**"
860 | ]
861 | },
862 | {
863 | "cell_type": "code",
864 | "execution_count": 6,
865 | "metadata": {},
866 | "outputs": [
867 | {
868 | "name": "stdout",
869 | "output_type": "stream",
870 | "text": [
871 | "Model Testing was Successful, ONNXRuntime Model Output matches with Pytorch Model Output!\n"
872 | ]
873 | }
874 | ],
875 | "source": [
876 | "ort_session = onnxruntime.InferenceSession(save_path)\n",
877 | "\n",
878 | "def to_numpy(tensor):\n",
879 | " return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()\n",
880 | "\n",
881 | "# compute ONNX Runtime output prediction\n",
882 | "ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}\n",
883 | "ort_outs = ort_session.run(None, ort_inputs)\n",
884 | "\n",
885 | "# compare ONNX Runtime and PyTorch results\n",
886 | "np.testing.assert_allclose(to_numpy(model_out), ort_outs[0], rtol=1e-03, atol=1e-05)\n",
887 | "\n",
888 | "print(\"Model Testing was Successful, ONNXRuntime Model Output matches with Pytorch Model Output!\")"
889 | ]
890 | },
891 | {
892 | "cell_type": "code",
893 | "execution_count": null,
894 | "metadata": {},
895 | "outputs": [],
896 | "source": []
897 | }
898 | ],
899 | "metadata": {
900 | "kernelspec": {
901 | "display_name": "Python 3",
902 | "language": "python",
903 | "name": "python3"
904 | },
905 | "language_info": {
906 | "codemirror_mode": {
907 | "name": "ipython",
908 | "version": 3
909 | },
910 | "file_extension": ".py",
911 | "mimetype": "text/x-python",
912 | "name": "python",
913 | "nbconvert_exporter": "python",
914 | "pygments_lexer": "ipython3",
915 | "version": "3.7.6"
916 | }
917 | },
918 | "nbformat": 4,
919 | "nbformat_minor": 2
920 | }
921 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset Class for FSANet Training
3 | Implemented by Omar Hassan
4 | August, 2020
5 | """
6 |
7 | from torch.utils.data import Dataset
8 | import torch
9 | import numpy as np
10 | import glob
11 |
class HeadposeDataset(Dataset):
    """In-memory head-pose dataset assembled from ``*.npz`` archives.

    Every ``.npz`` file directly inside ``data_path`` is read for its
    ``"image"`` and ``"pose"`` arrays. The arrays of all files are
    concatenated along axis 0, and every sample whose pose has a component
    outside ``[-99, 99]`` is discarded.

    Args:
        data_path: Directory that contains the ``.npz`` files.
        transform: Optional callable applied to the image in
            ``__getitem__``; the pose is returned untouched.

    Raises:
        ValueError: If ``data_path`` contains no ``.npz`` file.
    """

    def __init__(self, data_path,
                 transform=None):

        self.transform = transform

        # Since the data is not much, we can load it entirely in RAM.
        # Sorting makes the sample order independent of the filesystem's
        # directory listing order, so index-based splits are reproducible.
        files_path = sorted(glob.glob(f'{data_path}/*.npz'))
        if not files_path:
            raise ValueError(f'No .npz files found in {data_path!r}')

        image = []
        pose = []
        for path in files_path:
            # NpzFile keeps the archive open until closed; the context
            # manager releases the file handle once the arrays are read.
            with np.load(path) as data:
                image.append(data["image"])
                pose.append(data["pose"])

        image = np.concatenate(image, 0)
        pose = np.concatenate(pose, 0)

        # Exclude examples with any pose component outside [-99, 99].
        # Reducing over every axis except the first gives one min/max per
        # sample, which is what the per-sample loop used to compute.
        sample_axes = tuple(range(1, pose.ndim))
        keep = ((pose.max(axis=sample_axes) <= 99.0)
                & (pose.min(axis=sample_axes) >= -99.0))

        self.x_data = image[keep]
        self.y_data = pose[keep]

        print('x (images) shape: ', self.x_data.shape)
        print('y (poses) shape: ', self.y_data.shape)

    def set_transform(self, transform):
        """Replace the callable applied to images in ``__getitem__``."""
        self.transform = transform

    def __len__(self):
        """Return the number of samples that survived the pose filter."""
        return self.y_data.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, pose)`` for ``idx``, with the image transformed if a transform is set."""
        x = self.x_data[idx]
        y = self.y_data[idx]

        if self.transform:
            x = self.transform(x)

        return x, y
61 |
62 | #used to apply different transforms to train,validation dataset
class DatasetFromSubset(Dataset):
    """Wrap an indexable collection of ``(x, y)`` pairs with its own transform.

    Lets differently transformed views (e.g. train vs. validation) be built
    on top of subsets of one underlying dataset.

    Args:
        subset: Any object supporting ``len()`` and integer indexing that
            yields ``(x, y)`` pairs.
        transform: Optional callable applied to ``x`` only.
    """

    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if not self.transform:
            return sample, target
        return self.transform(sample), target
76 |
--------------------------------------------------------------------------------
/src/demo.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import numpy as np
3 | import cv2
4 | import onnxruntime
5 | import sys
6 | from pathlib import Path
7 | #local imports
8 | from face_detector import FaceDetector
9 | from utils import draw_axis
10 |
# Repository root: two directory levels above this file. Stored as a string
# because it is interpolated into f-string model paths.
root_path = str(Path(__file__).absolute().parents[1])
12 |
def _main(cap_src):
    """Run the head-pose demo on a camera index or a video file.

    For every frame the face detector is queried for bounding boxes; each
    face crop is resized to 64x64, normalised and fed to both ONNX models.
    The two predictions are averaged into ``yaw, pitch, roll`` and passed to
    ``draw_axis`` together with the box centre. The loop ends when a frame
    cannot be read or when ``q`` is pressed.

    Args:
        cap_src: Anything ``cv2.VideoCapture`` accepts (the caller passes a
            camera index or a video path).
    """
    cap = cv2.VideoCapture(cap_src)
    try:
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)

        face_d = FaceDetector()

        sess = onnxruntime.InferenceSession(f'{root_path}/pretrained/fsanet-1x1-iter-688590.onnx')

        sess2 = onnxruntime.InferenceSession(f'{root_path}/pretrained/fsanet-var-iter-688590.onnx')

        print('Processing frames, press q to exit application...')
        while True:
            ret, frame = cap.read()
            if not ret:
                print('Could not capture a valid frame from video source, check your cam/video value...')
                break

            frame_h, frame_w = frame.shape[:2]

            # get face bounding boxes from frame
            for (x1, y1, x2, y2) in face_d.get(frame):
                # Clamp the box to the frame: a negative start index would
                # wrap around in numpy slicing and yield an empty or wrong
                # crop, which makes cv2.resize raise.
                x1, y1 = max(int(x1), 0), max(int(y1), 0)
                x2, y2 = min(int(x2), frame_w - 1), min(int(y2), frame_h - 1)
                if x2 <= x1 or y2 <= y1:
                    continue

                face_roi = frame[y1:y2+1, x1:x2+1]

                # preprocess headpose model input: HWC -> 1xCxHxW, scaled
                # with (pixel - 127.5) / 128, as float32
                face_roi = cv2.resize(face_roi, (64, 64))
                face_roi = face_roi.transpose((2, 0, 1))
                face_roi = np.expand_dims(face_roi, axis=0)
                face_roi = (face_roi-127.5)/128
                face_roi = face_roi.astype(np.float32)

                # get headpose from both models and average them
                res1 = sess.run(["output"], {"input": face_roi})[0]
                res2 = sess2.run(["output"], {"input": face_roi})[0]

                yaw, pitch, roll = np.mean(np.vstack((res1, res2)), axis=0)

                draw_axis(frame, yaw, pitch, roll,
                          tdx=(x2-x1)//2+x1, tdy=(y2-y1)//2+y1, size=50)

            cv2.imshow('Frame', frame)

            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
    finally:
        # Always hand the camera/video back and close the preview window,
        # including when model loading or inference raises.
        cap.release()
        cv2.destroyAllWindows()
59 |
60 |
61 |
62 |
if __name__ == '__main__':
    # Command-line entry point: choose the capture source and start the demo.
    parser = ArgumentParser()
    parser.add_argument("--video", type=str, default=None,
                        help="Path of video to process i.e. /path/to/vid.mp4")
    parser.add_argument("--cam", type=int, default=None,
                        help="Specify camera index i.e. 0,1,2...")
    args = parser.parse_args()

    # A camera index takes precedence over a video path; with neither given,
    # fall back to camera 0.
    if args.cam is not None:
        cap_src = args.cam
    elif args.video is not None:
        cap_src = args.video
    else:
        print('Camera or video not specified as argument, selecting default camera node (0) as input...')
        cap_src = 0

    _main(cap_src)
75 |
--------------------------------------------------------------------------------
/src/face_detector.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | from pathlib import Path
3 | import glob
4 | import time
5 | import numpy as np
6 | import cv2
7 |
# Top-level directory of the repository (the parent of the directory holding
# this file). resolve() makes the path absolute first: with a relative
# __file__ such as "face_detector.py", .parent.parent would otherwise collapse
# to "." and point at the wrong directory.
root_path = Path(__file__).resolve().parent.parent
10 |
class FaceDetector:
    """Face detector backed by the ResNet-10 SSD Caffe model under ``pretrained/``.

    Attributes are set in ``__init__``:
        detector: the network returned by ``cv2.dnn.readNetFromCaffe``.
        confidence: minimum detection score for a box to be kept.
    """

    def __init__(self):
        """Load the Caffe network from ``<root_path>/pretrained``."""
        config_path = root_path.joinpath("pretrained/",
                                         "resnet10_ssd.prototxt")
        face_model_path = root_path.joinpath("pretrained/",
                                             "res10_300x300_ssd_iter_140000.caffemodel")

        self.detector = cv2.dnn.readNetFromCaffe(str(config_path),
                                                 str(face_model_path))

        # detector prediction threshold
        self.confidence = 0.7

    def get(self, img):
        """Detect faces in ``img`` and return their bounding boxes.

        Returns:
            An integer array of shape ``(num_faces, 4)`` holding
            ``(start_x, start_y, end_x, end_y)`` rows, or an empty list when
            no face passes the filters.
        """
        return self._detect_face_ResNet10_SSD(img)

    def _detect_face_ResNet10_SSD(self, img):
        """Run the SSD network on ``img`` and post-process its detections.

        The image is resized to 300x300 and mean-subtracted with
        ``(104, 177, 123)`` without swapping channels.
        NOTE(review): presumably ``img`` is a BGR frame as read by OpenCV,
        confirm against callers.
        """
        (h, w) = img.shape[:2]
        # construct a blob from the image
        img_blob = cv2.dnn.blobFromImage(
            cv2.resize(img, (300, 300)), 1.0, (300, 300),
            (104.0, 177.0, 123.0), swapRB=False, crop=False)

        self.detector.setInput(img_blob)
        detections = self.detector.forward()

        return self._select_boxes(detections, w, h, self.confidence)

    @staticmethod
    def _select_boxes(detections, w, h, confidence, min_size=20):
        """Turn raw SSD detections into clamped pixel boxes.

        Args:
            detections: array indexed as ``[0, 0, i, :]`` where column 2 is
                the score and columns 3..6 are the box corners scaled to
                the unit square.
            w, h: width and height of the original image in pixels.
            confidence: boxes with a score not greater than this are dropped.
            min_size: boxes narrower or shorter than this (after clamping)
                are dropped.

        Returns:
            ``np.ndarray`` of shape ``(num_faces, 4)`` or ``[]`` if empty.
        """
        faces_bb = []
        # Every detection above the threshold is kept, so several faces per
        # image are supported.
        for i in range(detections.shape[2]):
            score = detections[0, 0, i, 2]
            if score <= confidence:
                continue

            # compute the (x, y)-coordinates of the bounding box for the face
            box = (detections[0, 0, i, 3:7] * np.array([w, h, w, h])).astype("int")

            # Clamp to the image. Without this a negative start coordinate
            # wraps around when used as a slice index, so faces touching the
            # left/top border were measured as empty and silently dropped.
            box[[0, 2]] = np.clip(box[[0, 2]], 0, w)
            box[[1, 3]] = np.clip(box[[1, 3]], 0, h)
            (start_x, start_y, end_x, end_y) = box

            # ensure the face width and height are sufficiently large
            if end_x - start_x < min_size or end_y - start_y < min_size:
                continue

            faces_bb.append(box)

        if faces_bb:
            faces_bb = np.array(faces_bb)

        return faces_bb
91 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | """
2 | FSANet Model
3 | Implemented by Omar Hassan
4 | August, 2020
5 | """
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | #SeparableConv2d
class SepConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel (``groups=nin``) ``ksize`` x ``ksize`` convolution followed
    by a 1x1 pointwise convolution that maps ``nin`` to ``nout`` channels.

    Args:
        nin: Number of input channels.
        nout: Number of output channels.
        ksize: Kernel size of the depthwise convolution.
        padding: Padding of the depthwise convolution. ``None`` (default)
            uses ``ksize // 2``, which keeps the spatial size for odd
            kernels and equals the previous fixed value of 1 for the default
            ``ksize=3``.
    """

    def __init__(self, nin, nout, ksize=3, padding=None):
        super(SepConv2d, self).__init__()
        if padding is None:
            # The padding used to be hard-coded to 1, which only preserves
            # the spatial size for 3x3 kernels.
            padding = ksize // 2
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=ksize, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise then the pointwise convolution to ``x``."""
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
23 |
24 | #Conv2d with Activation Layer
class Conv2dAct(nn.Module):
    """2-D convolution followed by an activation layer.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        ksize: Convolution kernel size.
        activation: One of ``'sigmoid'``, ``'relu'`` or ``'tanh'``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    # activation name -> layer class
    _ACTIVATIONS = {
        'sigmoid': nn.Sigmoid,
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
    }

    def __init__(self, in_channels, out_channels, ksize=1, activation='relu'):
        super(Conv2dAct, self).__init__()

        # Fail at construction time: an unknown name used to leave
        # ``self.act`` undefined and only surfaced as an AttributeError on
        # the first forward pass.
        if activation not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(self._ACTIVATIONS)}")

        self.conv = nn.Conv2d(in_channels, out_channels, ksize)
        self.act = self._ACTIVATIONS[activation]()

    def forward(self, x):
        """Return ``act(conv(x))``."""
        x = self.conv(x)
        x = self.act(x)

        return x
42 |
class SepConvBlock(nn.Module):
    """Separable convolution, batch normalisation and activation in sequence.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        activation: ``'relu'`` or ``'tanh'``.
        ksize: Kernel size forwarded to ``SepConv2d``.

    Raises:
        ValueError: If ``activation`` is not ``'relu'`` or ``'tanh'``.
    """

    def __init__(self, in_channels, out_channels, activation='relu', ksize=3):
        super(SepConvBlock, self).__init__()

        # Validate up front: an unknown name used to leave ``self.act``
        # undefined and only failed with AttributeError inside forward().
        if activation not in ('relu', 'tanh'):
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'relu' or 'tanh'")

        self.conv = SepConv2d(in_channels, out_channels, ksize)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU() if activation == 'relu' else nn.Tanh()

    def forward(self, x):
        """Return ``act(bn(conv(x)))``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)

        return x
61 |
62 |
class MultiStreamMultiStage(nn.Module):
    """Two-stream, three-stage feature extractor.

    Stream ``s0`` uses ReLU activations with average pooling, stream ``s1``
    uses tanh activations with max pooling (both stems use ReLU). At each
    stage the two streams are projected to 64 channels by a 1x1
    ``Conv2dAct`` and multiplied element-wise.

    ``forward`` returns ``[stage3_out, stage2_out, stage1_out]``. The extra
    average pooling on the stage-1 product brings it to the same spatial
    size as the later stages (8x8 for a 64x64 input).
    """

    # (stage index, input channels, output channels) of each stage's blocks
    _STAGES = ((1, 16, 32), (2, 32, 64), (3, 64, 128))

    def __init__(self,in_channels):
        super(MultiStreamMultiStage, self).__init__()

        # Layers are created in the same order as before (s0 first, then s1)
        # so parameter order and seeded initialisation are unchanged.
        self.avgpool = nn.AvgPool2d(2)
        self._add_stream('s0', in_channels, 'relu')

        self.maxpool = nn.MaxPool2d(2)
        self._add_stream('s1', in_channels, 'tanh')

    def _add_stream(self, prefix, in_channels, activation):
        """Register ``<prefix>_conv0`` and the three stages of one stream."""
        # the stem always uses ReLU, whatever the stream's activation
        setattr(self, f'{prefix}_conv0', SepConvBlock(in_channels, 16, 'relu'))
        for stage, cin, cout in self._STAGES:
            setattr(self, f'{prefix}_conv{stage}_0', SepConvBlock(cin, cout, activation))
            setattr(self, f'{prefix}_conv{stage}_1', SepConvBlock(cout, cout, activation))
            setattr(self, f'{prefix}_conv{stage}_out', Conv2dAct(cout, 64, 1, activation))

    def forward(self,x):
        # Stage 0: stem block followed by 2x downsampling in each stream
        avg_x = self.avgpool(self.s0_conv0(x))
        max_x = self.maxpool(self.s1_conv0(x))

        # Stage 1
        avg_x = self.avgpool(self.s0_conv1_1(self.s0_conv1_0(avg_x)))
        max_x = self.maxpool(self.s1_conv1_1(self.s1_conv1_0(max_x)))
        fused1 = self.s0_conv1_out(avg_x) * self.s1_conv1_out(max_x)
        # pooled once more so that stage 1 matches the later stages' size
        stage1_out = self.avgpool(fused1)

        # Stage 2
        avg_x = self.avgpool(self.s0_conv2_1(self.s0_conv2_0(avg_x)))
        max_x = self.maxpool(self.s1_conv2_1(self.s1_conv2_0(max_x)))
        stage2_out = self.s0_conv2_out(avg_x) * self.s1_conv2_out(max_x)

        # Stage 3 (no pooling)
        avg_x = self.s0_conv3_1(self.s0_conv3_0(avg_x))
        max_x = self.s1_conv3_1(self.s1_conv3_0(max_x))
        stage3_out = self.s0_conv3_out(avg_x) * self.s1_conv3_out(max_x)

        return [stage3_out, stage2_out, stage1_out]
150 |
151 | #Channel-Wise Variance
class VarianceC(nn.Module):
    """Biased (population) variance across the channel axis (``dim=1``).

    Equivalent to ``torch.var(x, dim=1, keepdim=True, unbiased=False)`` but
    spelled out with mean/sub/mul, because the original author found that
    ONNX export did not support the ``var`` operator.
    """

    def forward(self,x):
        centered = x - torch.mean(x, dim=1, keepdim=True)
        # mean of squared deviations; the channel axis is kept with size 1
        return torch.mean(centered * centered, dim=1, keepdim=True)
166 |
167 |
class ScoringFunction(nn.Module):
    """Collapse the channel axis to one map and flatten it per sample.

    Args:
        in_channels: Channels of the incoming feature map (only used by the
            learned 1x1 convolution variant).
        var: If truthy, reduce channels with ``VarianceC``; otherwise with a
            1x1 ``Conv2dAct`` followed by a sigmoid.
    """

    def __init__(self,in_channels,var=False):
        super(ScoringFunction, self).__init__()
        self.reduce_channel = (
            VarianceC() if var else Conv2dAct(in_channels, 1, 1, 'sigmoid')
        )

    def forward(self,x):
        scores = self.reduce_channel(x)
        # flatten everything but the batch axis
        return scores.view(scores.size(0), -1)
185 |
186 |
187 |
class FineGrainedStructureMapping(nn.Module):
    """Map three stage feature maps to a set of primary-capsule features.

    For each stage ``k`` an attention map ``A_k`` is scored, a mapping
    ``S_k = C @ M_k`` is built from the attention maps, and the concatenated
    features ``U`` are aggregated as ``S_k @ U``, normalised by the row sums
    of ``S_k``. The three results are concatenated along ``dim=1``.

    Args:
        in_channels: Channels of each stage feature map.
        num_primcaps: Total number of output rows; a third of it (``n'`` in
            the paper) is produced per stage.
        mdim: Inner dimension ``m`` shared by ``C`` and ``M_k``.
        var: Forwarded to ``ScoringFunction`` to select variance scoring.
    """

    def __init__(self,in_channels,num_primcaps,mdim,var=False):
        super(FineGrainedStructureMapping, self).__init__()

        pixels_per_stage = 8 * 8  # fixed spatial size of each stage map
        self.n = 3 * pixels_per_stage
        self.n_new = int(num_primcaps / 3)  # this is n' in paper
        self.m = mdim

        self.attention_maps = ScoringFunction(in_channels, var)

        # used for calculating Mk in paper
        self.fm = nn.Linear(pixels_per_stage, self.n * self.m)
        # used for calculating C in paper
        self.fc = nn.Linear(self.n, self.n_new * self.m)

    #input is list of stage outputs in batches
    def forward(self,x):
        U1, U2, U3 = x
        stage_feats = (U1, U2, U3)

        # Attention maps A_k, one flattened map per stage
        attn = [self.attention_maps(feat) for feat in stage_feats]

        # C matrix from the concatenated attention maps
        C = torch.sigmoid(self.fc(torch.cat(attn, dim=1)))
        C = C.view(C.size(0), self.n_new, self.m)

        # S_k = matmul(C, M_k), with M_k computed from A_k alone
        S = []
        for A_k in attn:
            M_k = torch.sigmoid(self.fm(A_k))
            M_k = M_k.view(M_k.size(0), self.m, self.n)
            S.append(torch.matmul(C, M_k))

        # Concatenated feature maps U = [U1, U2, U3], each viewed as
        # (batch, w*h, channels); sizes are taken from U1 for all three.
        # NOTE(review): view() reinterprets memory and does not move the
        # channel axis last the way permute() would; kept exactly as in the
        # original. Confirm whether this is intended.
        _, ch, uh, uw = U1.size()
        U = torch.cat([feat.view(-1, uh * uw, ch) for feat in stage_feats], dim=1)

        # Ubar_k = (S_k @ U) / L1-norm of S_k's rows. S_k is non-negative
        # (products of sigmoids), so a plain sum serves as the L1 norm; the
        # 1e-8 is for numerical stability.
        Ubar_parts = [
            torch.matmul(S_k, U) / (torch.sum(S_k, dim=-1, keepdim=True) + 1e-8)
            for S_k in S
        ]

        # Concatenate Ubar_k along dim=1, which is self.n_new per stage
        return torch.cat(Ubar_parts, dim=1)
262 |
263 | #1d CapsuleLayer similar to nn.Linear (which outputs scalar neurons),
264 | #here, we output vectored neurons
class CapsuleLayer1d(nn.Module):
    """Fully connected capsule layer with routing-by-agreement.

    Plays the role of ``nn.Linear`` for capsules: instead of scalar neurons,
    every output unit is a vector ("capsule").

    Args:
        num_in_capsule: number of input capsules.
        in_capsule_dim: length of each input capsule vector.
        num_out_capsule: number of output capsules.
        out_capsule_dim: length of each output capsule vector.
        routings: number of routing iterations (must be >= 1).

    Raises:
        ValueError: if ``routings`` is smaller than 1, because ``forward``
            would then have no output to return.
    """

    def __init__(self, num_in_capsule, in_capsule_dim, num_out_capsule, out_capsule_dim, routings=3):
        super(CapsuleLayer1d, self).__init__()
        if routings < 1:
            raise ValueError('`routings` must be at least 1. Received: %s' % (routings,))
        self.routings = routings

        # Affine transformation weights mapping every input capsule to every
        # output capsule:
        # [num_out_capsule, num_in_capsule, out_capsule_dim, in_capsule_dim]
        weight_tensor = torch.empty(
            num_out_capsule,
            num_in_capsule,
            out_capsule_dim,
            in_capsule_dim)
        init_weight = torch.nn.init.xavier_uniform_(weight_tensor)
        self.affine_w = nn.Parameter(init_weight)

    def squash(self, s, dim=-1):
        """Shrink vectors along ``dim`` so their length lies in [0, 1).

        The direction of ``s`` is kept; the small constant in the denominator
        avoids a division by zero for all-zero vectors.
        """
        # `norm` holds the SQUARED length of each vector.
        norm = torch.sum(s**2, dim=dim, keepdim=True)
        return norm / (1 + norm) * s / (torch.sqrt(norm) + 1e-8)

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: tensor of shape [batch, num_in_capsule, in_capsule_dim].

        Returns:
            Tensor of shape [batch, num_out_capsule, out_capsule_dim].
        """
        # Expand to [batch, 1, num_in_capsule, 1, in_capsule_dim] so that the
        # last two dims take part in a batched matrix multiply with the
        # weights.
        x = x.unsqueeze(1)
        x = x.unsqueeze(3)

        # u_hat = x * W^T
        # -> [batch, num_out_capsule, num_in_capsule, 1, out_capsule_dim]
        # -> [batch, num_out_capsule, num_in_capsule, out_capsule_dim]
        u_hat = torch.matmul(x, torch.transpose(self.affine_w, 2, 3))
        u_hat = u_hat.squeeze(3)

        num_out_capsule, num_in_capsule = self.affine_w.shape[:2]

        # Routing logits start at zero. They are created directly with the
        # dtype and device of u_hat: a float32/CPU tensor would break
        # double/half models in the matmul below and cost a device copy on
        # every forward pass.
        b = torch.zeros(
            u_hat.shape[0],
            num_out_capsule,
            num_in_capsule,
            dtype=u_hat.dtype,
            device=u_hat.device)

        for i in range(self.routings):
            # Coupling coefficients: for each input capsule, the softmax over
            # the output capsules sums to 1.
            # c: [batch, num_out_capsule, 1, num_in_capsule]
            c = F.softmax(b, dim=1).unsqueeze(2)

            # Weighted sum of the predictions followed by the non-linearity.
            # outputs: [batch, num_out_capsule, 1, out_capsule_dim]
            outputs = self.squash(torch.matmul(c, u_hat))

            if i < self.routings - 1:
                # Agreement between predictions and current outputs
                # (u_hat * outputs^T) raises the matching logits.
                # [batch, num_out_capsule, num_in_capsule, 1] -> squeeze
                b = b + torch.matmul(u_hat, torch.transpose(outputs, 2, 3)).squeeze(3)

        # [batch, num_out_capsule, out_capsule_dim]
        return outputs.squeeze(2)
336 |
class ExtractAggregatedFeatures(nn.Module):
    """Split the capsule axis into three groups and flatten each group.

    The capsule axis (dim 1) is cut into three consecutive slices of
    ``num_capsule // 3`` capsules; the last slice runs up to ``num_capsule``
    and therefore absorbs any remainder.
    """

    def __init__(self, num_capsule):
        super(ExtractAggregatedFeatures, self).__init__()
        self.num_capsule = num_capsule

    def forward(self, x):
        """Return ``[feat_s1, feat_s2, feat_s3]``, each of shape [batch, -1]."""
        n_samples = x.shape[0]
        step = self.num_capsule // 3

        # (start, stop) bounds of the three groups along the capsule axis
        bounds = (
            (0, step),
            (step, 2 * step),
            (2 * step, self.num_capsule),
        )

        # flatten every group to one feature vector per sample
        return [x[:, lo:hi, :].view(n_samples, -1) for lo, hi in bounds]
356 |
357 |
class ExtractSSRParams(nn.Module):
    """Turn one stage feature vector into SSR parameters.

    Args:
        bins: number of bins per class for this stage.
        classes: number of regressed quantities (the original comments name
            them pitch, roll, yaw).
    """

    def __init__(self, bins, classes):
        super(ExtractSSRParams, self).__init__()
        self.bins = bins
        self.classes = classes

        # first 4 features -> per-class bin shift
        self.shift_fc = nn.Linear(4, classes)
        # next 4 features -> per-class bin scale
        self.scale_fc = nn.Linear(4, classes)
        # remaining 8 features -> one bin distribution per class,
        # hence bins*classes outputs
        self.pred_fc = nn.Linear(8, bins * classes)

    def forward(self, x):
        """Compute ``[pred_param, shift_param, scale_param]``.

        Args:
            x: batch of feature vectors, shape [batch, 16].

        Returns:
            List of three tensors: bin predictions [batch, classes, bins]
            (ReLU, so non-negative), shift [batch, classes] and scale
            [batch, classes] (both tanh, so within [-1, 1]).
        """
        shift_in = x[:, :4]
        scale_in = x[:, 4:8]
        pred_in = x[:, 8:]

        shift_param = torch.tanh(self.shift_fc(shift_in))
        scale_param = torch.tanh(self.scale_fc(scale_in))

        flat_pred = F.relu(self.pred_fc(pred_in))
        pred_param = flat_pred.view(flat_pred.size(0), self.classes, self.bins)

        return [pred_param, shift_param, scale_param]
384 |
385 |
class SSRLayer(nn.Module):
    """Three-stage soft stagewise regression (SSR) head.

    Args:
        bins: number of bins used in every stage.
    """

    def __init__(self, bins):
        super(SSRLayer, self).__init__()
        self.bins_per_stage = bins

    def _stage_sum(self, stage_params):
        """Sum ``(bin index - centre + shift) * bin prediction`` over all bins."""
        probs, shift = stage_params[0], stage_params[1]
        centre = self.bins_per_stage // 2

        total = 0
        for idx in range(self.bins_per_stage):
            total = total + (idx - centre + shift) * probs[:, :, idx]
        return total

    def forward(self, x):
        """Combine the per-stage parameters into the final prediction.

        Args:
            x: sequence of three stage parameter lists, each laid out as
                ``[pred, shift, scale]``.

        Returns:
            Tensor with one regressed value per class for every sample.
        """
        stage1, stage2, stage3 = x

        n_bins = self.bins_per_stage
        max_width = 99  # max bin width

        # Unrolled product of the SSR equation: stage k is divided by the
        # (scaled) bin counts of stages 1..k.
        coarse = self._stage_sum(stage1)
        coarse = coarse / (n_bins * (1 + stage1[2]))

        middle = self._stage_sum(stage2)
        middle = middle / (n_bins * (1 + stage1[2])) / (n_bins * (1 + stage2[2]))

        fine = self._stage_sum(stage3)
        fine = (fine
                / (n_bins * (1 + stage1[2]))
                / (n_bins * (1 + stage2[2]))
                / (n_bins * (1 + stage3[2])))

        return (coarse + middle + fine) * max_width
428 |
class FSANet(nn.Module):
    """Full FSA-Net pipeline: feature extraction, capsules and SSR regression.

    Args:
        var: forwarded unchanged to ``FineGrainedStructureMapping``.
    """

    def __init__(self, var=False):
        super(FSANet, self).__init__()

        # capsule configuration
        num_primcaps = 7 * 3
        primcaps_dim = 64
        num_out_capsule = 3
        out_capsule_dim = 16
        routings = 2
        mdim = 5

        # feature extraction on rgb input (3 channels)
        self.msms = MultiStreamMultiStage(3)
        # grouping of the 64-channel feature maps
        self.fgsm = FineGrainedStructureMapping(64, num_primcaps, mdim, var)
        self.caps_layer = CapsuleLayer1d(
            num_primcaps,
            primcaps_dim,
            num_out_capsule,
            out_capsule_dim,
            routings)
        self.eaf = ExtractAggregatedFeatures(num_out_capsule)

        # one SSR parameter extractor per stage (3 bins, 3 classes)
        self.esp_s1 = ExtractSSRParams(3, 3)
        self.esp_s2 = ExtractSSRParams(3, 3)
        self.esp_s3 = ExtractSSRParams(3, 3)
        self.ssr = SSRLayer(3)

    def forward(self, x):
        """Predict the pose for a batch of RGB images.

        Shapes below are the ones given in the original inline comments.

        Args:
            x: batch of RGB image tensors, [batch, 3, 64, 64].

        Returns:
            SSR pose prediction, [batch, 3].
        """
        # list of feature maps [U1, U2, U3], each [batch, 64, 8, 8]
        feature_maps = self.msms(x)

        # grouped feature maps Ubar, [batch, 21, 64]
        grouped = self.fgsm(feature_maps)

        # 3 capsules with shortened dims, one per stage, [batch, 3, 16]
        capsules = self.caps_layer(grouped)

        # each stage capsule separated as a 1d vector, 3 x [batch, 16]
        stage_vectors = self.eaf(capsules)

        # per-stage ssr params [preds, shift, scale] with
        # preds [batch, 3, 3], shift [batch, 3], scale [batch, 3]
        extractors = (self.esp_s1, self.esp_s2, self.esp_s3)
        ssr_params = [
            extractor(vector)
            for extractor, vector in zip(extractors, stage_vectors)
        ]

        # final pose prediction from the SSR layer
        return self.ssr(ssr_params)
489 |
if __name__ == '__main__':
    # Smoke test: build the model, run one random batch through it and print
    # the module structure.
    torch.random.manual_seed(10)
    # Use the GPU when one is present, but do not crash on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = FSANet(var=True).to(device)
    print('##############PyTorch################')
    x = torch.randn((1, 3, 64, 64)).to(device)
    # The result is not inspected; the forward pass only has to complete.
    y = model(x)
    print(model)
497 |
--------------------------------------------------------------------------------
/src/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Augmentation Transforms for FSANet Training
3 | Implemented by Omar Hassan to match Original Author's Implementation
4 | https://github.com/shamangary/FSA-Net/blob/master/training_and_testing/TYY_generators.py
5 | August, 2020
6 | """
7 |
8 | import torch
9 | import cv2
10 | import numpy as np
11 | from zoom_transform import _apply_random_zoom
12 |
class Normalize(object):
    """Standardise an image: ``out = (img - mean) / std``."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, img):
        centered = img - self.mean
        return centered / self.std
24 |
class RandomCrop(object):
    """Cut a random crop out of the image and resize it back to the input size."""

    def __init__(self):
        pass

    def __call__(self, img):
        """Return a randomly cropped copy of ``img`` with the original height/width.

        Args:
            img: image array indexed as [row, column, channel].
        """
        # total number of pixels removed per axis: 1..15
        dn = np.random.randint(15, size=1)[0] + 1

        # how many of those pixels are removed at the left / top edge
        dx = np.random.randint(dn, size=1)[0]
        dy = np.random.randint(dn, size=1)[0]

        h = img.shape[0]
        w = img.shape[1]
        out = img[0+dy:h-(dn-dy), 0+dx:w-(dn-dx), :]

        # cv2.resize expects dsize as (width, height); passing (h, w) would
        # swap the dimensions of non-square images.
        out = cv2.resize(out, (w, h), interpolation=cv2.INTER_CUBIC)

        return out
42 |
class RandomCropBlack(object):
    """
    Take a random crop of the input image and paste it, at a random
    offset, onto a black canvas with the same shape as the input.
    """

    @staticmethod
    def _draw(upper):
        """Draw one integer from [0, upper)."""
        return np.random.randint(upper, size=1)[0]

    def __call__(self, img):
        # total number of pixels removed per axis: 1..15
        margin = self._draw(15) + 1

        # position of the crop window inside the source image
        src_x = self._draw(margin)
        src_y = self._draw(margin)

        rows, cols = img.shape[0], img.shape[1]

        # position of the pasted window inside the canvas
        dst_x = self._draw(margin)
        dst_y = self._draw(margin)

        src_rows = slice(src_y, rows - (margin - src_y))
        src_cols = slice(src_x, cols - (margin - src_x))
        dst_rows = slice(dst_y, rows - (margin - dst_y))
        dst_cols = slice(dst_x, cols - (margin - dst_x))

        canvas = np.zeros_like(img)
        canvas[dst_rows, dst_cols, :] = img[src_rows, src_cols, :]

        return canvas
66 |
class RandomCropWhite(object):
    """
    Take a random crop of the input image and paste it, at a random
    offset, onto a white (255) canvas with the same shape as the input.
    """

    @staticmethod
    def _draw(upper):
        """Draw one integer from [0, upper)."""
        return np.random.randint(upper, size=1)[0]

    def __call__(self, img):
        # total number of pixels removed per axis: 1..15
        margin = self._draw(15) + 1

        # position of the crop window inside the source image
        src_x = self._draw(margin)
        src_y = self._draw(margin)

        rows, cols = img.shape[0], img.shape[1]

        # position of the pasted window inside the canvas
        dst_x = self._draw(margin)
        dst_y = self._draw(margin)

        src_rows = slice(src_y, rows - (margin - src_y))
        src_cols = slice(src_x, cols - (margin - src_x))
        dst_rows = slice(dst_y, rows - (margin - dst_y))
        dst_cols = slice(dst_x, cols - (margin - dst_x))

        canvas = np.ones_like(img)*255
        canvas[dst_rows, dst_cols, :] = img[src_rows, src_cols, :]

        return canvas
90 |
class RandomZoom(object):
    """Apply a random zoom transformation.

    Args:
        zoom_range: pair ``(low, high)`` of zoom bounds handed to
            ``_apply_random_zoom``.
    """

    def __init__(self, zoom_range=(0.8, 1.2)):
        # Keep a private copy: a shared mutable default (or a caller-owned
        # list) must not be able to change this instance's range afterwards.
        self.zoom_range = list(zoom_range)

    def __call__(self, img):
        out = _apply_random_zoom(img, self.zoom_range)

        return out
100 |
class SequenceRandomTransform(object):
    """
    Randomly pick at most one crop-style transform and then, with a
    separate random draw, optionally apply a random zoom, following the
    sequencing of the original author's implementation.
    """

    def __init__(self, zoom_range=[0.8,1.2]):
        self.rc = RandomCrop()
        self.rcb = RandomCropBlack()
        self.rcw = RandomCropWhite()
        self.rz = RandomZoom(zoom_range=zoom_range)

    def __call__(self, img):
        crop_draw = np.random.random()

        # Each crop variant owns one quarter of [0, 0.75); a draw of 0.75 or
        # above leaves the image uncropped.
        crop_choices = (
            (0.25, self.rc),
            (0.5, self.rcb),
            (0.75, self.rcw),
        )
        for upper, transform in crop_choices:
            if crop_draw < upper:
                img = transform(img)
                break

        # The zoom draw happens after the crop step; zoom is applied when
        # the draw exceeds 0.3.
        zoom_draw = np.random.random()
        if zoom_draw > 0.3:
            img = self.rz(img)

        return img
127 |
class ToTensor(object):
    """Turn an H x W x C ndarray into a C x H x W torch Tensor."""

    def __call__(self, img):
        # numpy images keep the channel last, torch images keep it first
        channels_first = np.transpose(img, axes=(2, 0, 1))
        tensor = torch.from_numpy(channels_first)
        return tensor
137 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size = 50,thickness=(2,2,2)):
    """
    Draw the head pose (yaw, pitch, roll) as three axis lines on ``img``.
    Implemented by: shamangary
    https://github.com/shamangary/FSA-Net/blob/master/demo/demo_FSANET.py
    Modified by: Omar Hassan

    :param img: image to draw on; the lines are drawn onto this array
    :param yaw: yaw angle in degrees
    :param pitch: pitch angle in degrees
    :param roll: roll angle in degrees
    :param tdx: x coordinate of the axes origin
    :param tdy: y coordinate of the axes origin; unless BOTH tdx and tdy are
        given, the image centre is used for both
    :param size: length of each drawn axis in pixels
    :param thickness: line thickness of the (x, y, z) axis lines
    :return: ``img`` with the axes drawn
    """
    # degrees -> radians (yaw is negated)
    pitch = pitch * np.pi / 180
    yaw = -(yaw * np.pi / 180)
    roll = roll * np.pi / 180

    # Identity checks instead of `!= None`: the comparison stays a plain bool
    # even for array-like coordinates.
    if tdx is None or tdy is None:
        height, width = img.shape[:2]
        tdx = width / 2
        tdy = height / 2

    # X-Axis pointing to right. drawn in red
    x1 = size * (np.cos(yaw) * np.cos(roll)) + tdx
    y1 = size * (np.cos(pitch) * np.sin(roll) + np.cos(roll) * np.sin(pitch) * np.sin(yaw)) + tdy

    # Y-Axis pointing down. drawn in green
    x2 = size * (-np.cos(yaw) * np.sin(roll)) + tdx
    y2 = size * (np.cos(pitch) * np.cos(roll) - np.sin(pitch) * np.sin(yaw) * np.sin(roll)) + tdy

    # Z-Axis (out of the screen) drawn in blue
    x3 = size * (np.sin(yaw)) + tdx
    y3 = size * (-np.cos(yaw) * np.sin(pitch)) + tdy

    # colours are given in BGR order
    cv2.line(img, (int(tdx), int(tdy)), (int(x1),int(y1)),(0,0,255),thickness[0])
    cv2.line(img, (int(tdx), int(tdy)), (int(x2),int(y2)),(0,255,0),thickness[1])
    cv2.line(img, (int(tdx), int(tdy)), (int(x3),int(y3)),(255,0,0),thickness[2])

    return img
41 |
--------------------------------------------------------------------------------
/src/zoom_transform.py:
--------------------------------------------------------------------------------
1 | """
2 | Zoom Transformation
3 | Original Implementation from: keras-preprocessing
4 | https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/image/affine_transformations.py
5 | ------------------------------------------------
6 | Modified by: Omar Hassan
7 | August, 2020
8 | """
9 |
10 | import numpy as np
11 | import scipy
12 | from scipy import ndimage
13 |
14 |
15 | def _apply_random_zoom(x, zoom_range):
16 | """
17 | Apply zoom transformation given a set of range and input image.
18 | :param x: input image
19 | :param zoom_range: list of zoom range i.e. [0.8,1.2]
20 | :return: zoom augmented image
21 | """
22 | if len(zoom_range) != 2:
23 | raise ValueError('`zoom_range` should be a tuple or list of two'
24 | ' floats. Received: %s' % (zoom_range,))
25 |
26 | zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
27 | x = _apply_affine_transform(x, zx=zx, zy=zy)
28 | return x
29 |
30 |
31 | def _transform_matrix_offset_center(matrix, x, y):
32 | o_x = float(x) / 2 + 0.5
33 | o_y = float(y) / 2 + 0.5
34 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
35 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
36 | transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)
37 | return transform_matrix
38 |
def _apply_affine_transform(x,zx, zy):
    """
    Applies an affine transformation whose scale entries are set to the
    zoom parameters, centred on the image.
    :param x: input image, indexed as [row, column, channel]
    :param zx: zoom scale placed on the first (row) axis of the matrix
    :param zy: zoom scale placed on the second (column) axis of the matrix
    :return: affine transformed input image, same axis layout as the input
    """
    channel_axis = 2

    zoom_matrix = np.array([[zx, 0, 0],
                            [0, zy, 0],
                            [0, 0, 1]])

    # make the zoom act about the image centre instead of the array origin
    h, w = x.shape[:2]
    transform_matrix = _transform_matrix_offset_center(
        zoom_matrix, h, w)
    final_affine_matrix = transform_matrix[:2, :2]
    final_offset = transform_matrix[:2, 2]

    # Transform every channel on its own. `ndimage.affine_transform` is the
    # public entry point; the `ndimage.interpolation` namespace used before
    # is deprecated. Linear interpolation (order=1); samples outside the
    # image repeat the nearest edge value, so `cval` has no visible effect.
    x = np.moveaxis(x, channel_axis, 0) #bring channel to first axis
    channel_images = [ndimage.affine_transform(
        x_channel,
        final_affine_matrix,
        final_offset,
        order=1,
        mode='nearest',
        cval=0) for x_channel in x]
    x = np.stack(channel_images, axis=0)
    x = np.moveaxis(x, 0, channel_axis) #bring channel to last axis

    return x
78 |
--------------------------------------------------------------------------------