├── .gitmodules
├── CMakeLists.txt
├── LICENSE.md
├── README.md
├── ZeroQ
│   ├── LICENSE
│   ├── README.md
│   ├── distill_data.py
│   ├── reconstruct_data.py
│   ├── requirements.txt
│   ├── run.sh
│   ├── uniform_test.py
│   └── utils
│       ├── __init__.py
│       ├── data_utils.py
│       ├── quantization_utils
│       │   ├── quant_modules.py
│       │   └── quant_utils.py
│       ├── quantize_model.py
│       └── train_utils.py
├── _512_train.txt
├── convert_ncnn.py
├── dataset
│   ├── __init__.py
│   ├── detection
│   │   ├── __init__.py
│   │   ├── open_images.py
│   │   └── voc_dataset.py
│   └── segmentation
│       ├── __init__.py
│       ├── custom_transforms.py
│       ├── pascal.py
│       └── utils.py
├── dfq.py
├── images
│   ├── LE_distill.png
│   ├── graph_cls.png
│   ├── graph_deeplab.png
│   └── graph_ssd.png
├── improve_dfq.py
├── inference_cls.cpp
├── main_cls.py
├── main_seg.py
├── main_ssd.py
├── modeling
│   ├── __init__.py
│   ├── classification
│   │   ├── MobileNetV2.py
│   │   └── mobilenetv2_1.0-f2a8633.pth.tar
│   ├── detection
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   ├── mobilenetv1_ssd_config.py
│   │   │   ├── squeezenet_ssd_config.py
│   │   │   └── vgg_ssd_config.py
│   │   ├── data_preprocessing.py
│   │   ├── fpn_mobilenetv1_ssd.py
│   │   ├── fpn_ssd.py
│   │   ├── mb2-ssd-lite-mp-0_686.pth
│   │   ├── mobilenet_v2_ssd_lite.py
│   │   ├── mobilenetv1_ssd.py
│   │   ├── mobilenetv1_ssd_lite.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── alexnet.py
│   │   │   ├── mobilenet.py
│   │   │   ├── mobilenet_v2.py
│   │   │   ├── multibox_loss.py
│   │   │   ├── scaled_l2_norm.py
│   │   │   ├── squeezenet.py
│   │   │   └── vgg.py
│   │   ├── predictor.py
│   │   ├── squeezenet_ssd_lite.py
│   │   ├── ssd.py
│   │   ├── transforms
│   │   │   ├── __init__.py
│   │   │   └── transforms.py
│   │   ├── vgg_ssd.py
│   │   └── voc-model-labels.txt
│   ├── ncnn
│   │   ├── model_quant_relu_equal.bin
│   │   ├── model_quant_relu_equal.param
│   │   └── model_quant_relu_equal.table
│   └── segmentation
│       ├── __init__.py
│       ├── aspp.py
│       ├── backbone
│       │   ├── __init__.py
│       │   ├── drn.py
│       │   ├── mobilenet.py
│       │   ├── resnet.py
│       │   └── xception.py
│       ├── decoder.py
│       ├── deeplab-mobilenet.pth.tar
│       ├── deeplab.py
│       └── sync_batchnorm
│           ├── __init__.py
│           ├── batchnorm.py
│           ├── comm.py
│           ├── replicate.py
│           └── unittest.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── detection
    │   ├── __init__.py
    │   ├── box_utils.py
    │   ├── measurements.py
    │   └── misc.py
    ├── layer_transform.py
    ├── metrics.py
    ├── quantize.py
    ├── relation.py
    └── segmentation
        ├── __init__.py
        └── utils.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "PyTransformer"]
2 | path = PyTransformer
3 | url = https://github.com/ricky40403/PyTransformer
4 |
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.12)
2 | project(dfq)
3 | include_directories(/home/jakc4103/Documents/ncnn/src)
4 | include_directories(/home/jakc4103/Documents/ncnn/build/src)
5 |
6 | #openmp
7 | FIND_PACKAGE( OpenMP REQUIRED)
8 | if(OPENMP_FOUND)
9 | message("OPENMP FOUND")
10 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}")
11 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
12 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}")
13 | endif()
14 |
15 | #ncnn
16 | set(NCNN_LIBS /home/jakc4103/Documents/ncnn/build/install/lib/libncnn.a)
17 | set(NCNN_INCLUDE_DIRS /home/jakc4103/Documents/ncnn/build/install/include)
18 | include_directories(${NCNN_INCLUDE_DIRS})
19 |
20 | #opencv
21 | find_package( OpenCV REQUIRED )
22 | include_directories( ${OpenCV_INCLUDE_DIRS} )
23 |
24 | add_executable(inference_cls inference_cls.cpp)
25 | target_link_libraries(inference_cls ${NCNN_LIBS})
26 | target_link_libraries( inference_cls ${OpenCV_LIBS} )
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 jakc4103
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DFQ
2 | PyTorch implementation of [Data Free Quantization Through Weight Equalization and Bias Correction](https://arxiv.org/abs/1906.04721) with some ideas from [ZeroQ: A Novel Zero Shot Quantization Framework](https://arxiv.org/abs/2001.00281).
3 |
4 | ## Results
5 | Int8**: Fake quantization; 8-bit weights, 8-bit activations, 16-bit bias
6 | Int8*: Fake quantization; 8-bit weights, 8-bit activations, 8-bit bias
7 | Int8': Fake quantization; 8-bit weights (symmetric), 8-bit activations (symmetric), 32-bit bias
8 | Int8: Int8 inference using [ncnn](https://github.com/Tencent/ncnn); 8-bit weights (symmetric), 8-bit activations (symmetric), 32-bit bias
9 |
10 | ### On classification task
11 | - Tested with [MobileNetV2](https://github.com/tonylins/pytorch-mobilenet-v2) and [ResNet-18](https://pytorch.org/docs/stable/torchvision/models.html)
12 | - ImageNet validation set (Acc.)
13 |
14 | MobileNetV2:
15 |
16 |
17 | model/precision | FP32 | Int8** | Int8* | Int8' | Int8 (FP32-69.19)
18 | -----------|------|------| ------ | ------|------
19 | Original | 71.81 | 0.102 | 0.1 | 0.062 | 0.082
20 | +ReLU | 71.78 | 0.102 | 0.096 | 0.094 | 0.082
21 | +ReLU+LE | 71.78 | 70.32 | 68.78 | 67.5 | 65.21
22 | +ReLU+LE +DR | -- | 70.47 | 68.87 | -- | --
23 | +BC | -- | 57.07 | 0.12 | 26.25 | 5.57
24 | +BC +clip_15 | -- | 65.37 | 0.13 | 65.96 | 45.13
25 | +ReLU+LE+BC | -- | 70.79 | 68.17 | 68.65 | 62.19
26 | +ReLU+LE+BC +DR | -- | 70.9 | 68.41 | -- | --
27 |
28 | ResNet-18:
29 |
30 | model/precision | FP32 | Int8** | Int8*
31 | -----------|------|------|------
32 | Original | 69.76 | 69.13 | 69.09
33 | +ReLU | 69.76 | 69.13 | 69.09
34 | +ReLU+LE | 69.76 | 69.2 | 69.2
35 | +ReLU+LE +DR | -- | 67.74 | 67.75
36 | +BC | -- | 69.04 | 68.56
37 | +BC +clip_15 | -- | 69.04 | 68.56
38 | +ReLU+LE+BC | -- | 69.04 | 68.56
39 | +ReLU+LE+BC +DR | -- | 67.65 | 67.62
40 |
41 |
42 |
43 | ### On segmentation task
44 | - Tested with [Deeplab-v3-plus_mobilenetv2](https://github.com/jfzhang95/pytorch-deeplab-xception)
45 |
46 | Pascal VOC 2012 val set (mIOU):
47 |
48 |
49 | model/precision | FP32 | Int8**| Int8*
50 | ----------------|-------|-------|------
51 | Original | 70.81 | 60.03 | 59.31
52 | +ReLU | 70.72 | 60.0 | 58.98
53 | +ReLU+LE | 70.72 | 66.22 | 66.0
54 | +ReLU+LE +DR | -- | 67.04 | 67.23
55 | +ReLU+BC | -- | 69.04 | 68.42
56 | +ReLU+BC +clip_15 | -- | 66.99 | 66.39
57 | +ReLU+LE+BC | -- | 69.46 | 69.22
58 | +ReLU+LE+BC +DR | -- | 70.12 | 69.7
59 |
60 | Pascal VOC 2007 test set (mIOU):
61 |
62 | model/precision | FP32 | Int8** | Int8*
63 | ----------------|-------|-------|-------
64 | Original | 74.54 | 62.36 | 61.21
65 | +ReLU | 74.35 | 61.66 | 61.04
66 | +ReLU+LE | 74.35 | 69.47 | 69.6
67 | +ReLU+LE +DR | -- | 70.28 | 69.93
68 | +BC | -- | 72.1 | 70.97
69 | +BC +clip_15 | -- | 70.16 | 70.76
70 | +ReLU+LE+BC | -- | 72.84 | 72.58
71 | +ReLU+LE+BC +DR | -- | 73.5 | 73.04
72 |
73 |
74 |
75 | ### On detection task
76 | - Tested with [MobileNetV2 SSD-Lite model](https://github.com/qfgaohao/pytorch-ssd)
77 |
78 |
79 | Pascal VOC 2012 val set (mAP, VOC 12 metric):
80 |
81 |
82 | model/precision | FP32 | Int8**|Int8*
83 | -----------|------|------|------
84 | Original | 78.51 | 77.71 | 77.86
85 | +ReLU | 75.42 | 75.74 | 75.58
86 | +ReLU+LE | 75.42 | 75.32 | 75.37
87 | +ReLU+LE +DR | -- | 74.65 | 74.32
88 | +BC | -- | 77.73 | 77.78
89 | +BC +clip_15 | -- | 77.73 | 77.78
90 | +ReLU+LE+BC | -- | 75.66 | 75.66
91 | +ReLU+LE+BC +DR | -- | 74.92 | 74.65
92 |
93 | Pascal VOC 2007 test set (mAP, VOC 07 metric):
94 |
95 | model/precision | FP32 | Int8** | Int8*
96 | ----------------|-------|-------|-------
97 | Original | 68.70 | 68.47 | 68.49
98 | +ReLU | 65.47 | 65.36 | 65.56
99 | +ReLU+LE | 65.47 | 65.36 | 65.27
100 | +ReLU+LE +DR | -- | 64.53 | 64.46
101 | +BC | -- | 68.32 | 65.33
102 | +BC +clip_15 | -- | 68.32 | 65.33
103 | +ReLU+LE+BC | -- | 65.63 | 65.58
104 | +ReLU+LE+BC +DR | -- | 64.92 | 64.42
105 |
106 |
107 |
108 | ## Usage
109 | There are six arguments, all of which default to False:
110 | 1. quantize: whether to quantize parameters and activations.
111 | 2. relu: whether to replace relu6 with relu.
112 | 3. equalize: whether to perform cross-layer equalization.
113 | 4. correction: whether to apply bias correction.
114 | 5. clip_weight: whether to clip weights to the range [-15, 15] (for convolution and linear layers).
115 | 6. distill_range: whether to use distilled data to set the min/max range for activation quantization.
116 |
117 | run the equalized model by:
118 | ```
119 | python main_cls.py --quantize --relu --equalize
120 | ```
121 |
122 | run the equalized and bias-corrected model by:
123 | ```
124 | python main_cls.py --quantize --relu --equalize --correction
125 | ```
126 |
127 | run the equalized and bias-corrected model with distilled data by:
128 | ```
129 | python main_cls.py --quantize --relu --equalize --correction --distill_range
130 | ```
131 |
132 | Export the equalized and bias-corrected model to ONNX and generate the calibration table file:
133 | ```
134 | python convert_ncnn.py --equalize --correction --quantize --relu --ncnn_build path_to_ncnn_build_folder
135 | ```
136 |
137 | ## Note
138 | ### Distilled Data (2020/02/03 updated)
139 | According to the recent paper [ZeroQ](https://github.com/amirgholami/ZeroQ), we can distill fake data that matches the statistics of the batch-normalization layers, then use it to set the min/max range for activation quantization.
140 | This removes the requirement that every conv layer be followed by a batch-norm layer, and it should produce better and **more stable** results (the range-estimation method from DFQ sometimes fails to find a good enough value range).
141 |
142 | Here are some modifications that differ from the original ZeroQ implementation:
143 | 1. Initialization of distilled data
144 | 2. Early stop criterion
145 |
146 | ~~Also, I think it can be applied to optimize cross-layer equalization and bias correction. The results will be updated once I make it work.~~
147 | Using distilled data for LE or BC did not perform as well as using the estimates from the batch-norm layers, probably because of overfitting.
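
A minimal sketch of the statistic-matching idea (simplified and per-batch; the full version, which registers forward hooks on each conv layer and optimizes per-sample statistics, is in `ZeroQ/reconstruct_data.py` in this repo, so the function below is illustrative only):

```python
import torch
import torch.nn as nn

def bn_matching_loss(conv_out: torch.Tensor, bn: nn.BatchNorm2d, eps: float = 1e-6) -> torch.Tensor:
    """Penalize the gap between the per-channel mean/std of a conv output (N, C, H, W)
    and the running statistics stored in the BatchNorm layer that follows it."""
    mean = conv_out.mean(dim=(0, 2, 3))                # per-channel mean of the distilled batch
    std = conv_out.var(dim=(0, 2, 3)).add(eps).sqrt()  # per-channel std of the distilled batch
    target_std = (bn.running_var + eps).sqrt()
    return ((mean - bn.running_mean) ** 2).mean() + ((std - target_std) ** 2).mean()
```

The distilled batch starts from random noise and is optimized (e.g. with Adam) to minimize the sum of this loss over all conv/BN pairs; it is then fed through the quantized model once to record the activation min/max ranges.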
148 |
149 | ### Fake Quantization
150 | The 'Int8' models in this repo are actually simulations of 8-bit inference: the arithmetic is still carried out in floating point.
151 | This is done by quantizing and then dequantizing the parameters in each layer and the activations between two consecutive layers,
152 | which means every tensor keeps dtype 'float32' but contains at most 256 (2^8) unique values.
153 | ```
154 | Weight_quant(Int8) = Quant(Weight)
155 | Weight_quant(FP32) = Weight_quant(Int8*) = Dequant(Quant(Weight))
156 | ```
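
For reference, here is a minimal sketch of one such fake-quantization step (per-tensor and asymmetric; the function name is mine, and the repo's actual version, e.g. `AsymmetricQuantFunction` in `ZeroQ/utils/quantization_utils/quant_utils.py`, also handles per-channel ranges):

```python
import torch

def fake_quantize(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Quantize then dequantize: the result is float32 but holds at most 2**num_bits unique values."""
    qmin, qmax = 0, 2 ** num_bits - 1
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min).clamp(min=1e-8) / (qmax - qmin)           # asymmetric scale
    zero_point = torch.round(-x_min / scale)                          # maps x_min to qmin
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)  # "Quant"
    return (q - zero_point) * scale                                   # "Dequant"
```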
157 |
158 | ### 16-bit Quantization for Bias
159 | Somehow I cannot make **Bias Correction** work with 8-bit bias quantization in all scenarios (even with data-dependent correction).
160 | I am not sure how the original paper managed to do it with 8-bit bias quantization; I guess they either used some non-uniform quantization technique or used more bits for the bias parameters, as I do.
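
As a purely illustrative sketch of that choice (hypothetical helper, not code from this repo): a symmetric quantizer where the bias simply gets a wider integer grid than the weights, e.g. 16 bits instead of 8, so its rounding error stays small:

```python
import torch

def quantize_symmetric(t: torch.Tensor, num_bits: int) -> torch.Tensor:
    """Symmetric fake quantization; more bits means a finer grid and smaller rounding error."""
    n = 2 ** (num_bits - 1) - 1                 # 127 for 8 bits, 32767 for 16 bits
    scale = t.abs().max().clamp(min=1e-8) / n
    return torch.round(t / scale) * scale

# e.g. bias kept at 16 bits while weights use 8 bits (the bias width used for the Int8** results above);
# `conv` here is a hypothetical nn.Conv2d instance
# w_q = quantize_symmetric(conv.weight.data, 8)
# b_q = quantize_symmetric(conv.bias.data, 16)
```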
161 |
162 | ### Int8 inference
163 | Refer to [ncnn](https://github.com/Tencent/ncnn), [pytorch2ncnn](https://github.com/Tencent/ncnn/wiki/use-ncnn-with-pytorch-or-onnx), [ncnn-quantize](https://github.com/Tencent/ncnn/tree/master/tools/quantize), [ncnn-int8-inference](https://github.com/Tencent/ncnn/wiki/quantized-int8-inference) for more details.
164 | You will need to install/build the following:
165 | [ncnn](https://github.com/Tencent/ncnn)
166 | [onnx-simplifier](https://github.com/daquexian/onnx-simplifier)
167 |
168 | inference_cls.cpp only implements MobileNetV2. The basic steps are:
169 |
170 | 1. Run convert_ncnn.py to convert the PyTorch model (with layer equalization and/or bias correction) to an ncnn int8 model and generate the calibration table file. The name of out_layer will be printed to the console.
171 | ```
172 | python convert_ncnn.py --quantize --relu --equalize --correction
173 | ```
174 |
175 | 2. Compile inference_cls.cpp:
176 | ```
177 | mkdir build
178 | cd build
179 | cmake ..
180 | make
181 | ```
182 | 3. Inference! [link](https://github.com/Tencent/ncnn/wiki/quantized-int8-inference)
183 | ```
184 | ./inference_cls --images=path_to_imagenet_validation_set --param=../modeling/ncnn/model_int8.param --bin=../modeling/ncnn/model_int8.bin --out_layer=name_from_step1
185 | ```
186 |
187 | ## TODO
188 | - [x] cross layer equalization
189 | - [ ] high bias absorption
190 | - [x] data-free bias correction
191 | - [x] test with detection model
192 | - [x] test with classification model
193 | - [x] use distilled data to set min/max activation range
194 | - [ ] ~~use distilled data to find optimal scale matrix~~
195 | - [ ] ~~use distilled data to do bias correction~~
196 | - [x] True Int8 inference
197 |
198 | ## Acknowledgment
199 | - https://github.com/jfzhang95/pytorch-deeplab-xception
200 | - https://github.com/ricky40403/PyTransformer
201 | - https://github.com/qfgaohao/pytorch-ssd
202 | - https://github.com/tonylins/pytorch-mobilenet-v2
203 | - https://github.com/xxradon/PytorchToCaffe
204 | - https://github.com/amirgholami/ZeroQ
205 |
--------------------------------------------------------------------------------
/ZeroQ/README.md:
--------------------------------------------------------------------------------
1 | # ZeroQ: A Novel Zero Shot Quantization Framework
2 |
3 |
4 |
5 | ## Introduction
6 |
7 | This repository contains the PyTorch implementation for the paper [*ZeroQ: A Novel Zero-Shot Quantization Framework*](https://arxiv.org/abs/2001.00281).
8 |
9 | ## TLDR;
10 |
11 | ```bash
12 | # Code is based on PyTorch 1.2 (CUDA 10). Other dependencies can be installed as follows:
13 | pip install -r requirements.txt --user
14 | # Set a symbolic link to ImageNet validation data (used only to evaluate model)
15 | mkdir data
16 | ln -s /path/to/imagenet/ data/
17 | ```
18 |
19 | The folder structure should look like the following:
20 | ```
21 | zeroq
22 | ├── utils
23 | ├── data
24 | │ ├── imagenet
25 | │ │ ├── val
26 | ```
27 | Afterwards, you can test zero-shot quantization with W8A8 by running:
28 |
29 | ```bash
30 | bash run.sh
31 | ```
32 |
33 | Below are the results you should get for 8-bit quantization (**W8A8** refers to quantizing the model to 8-bit weights and 8-bit activations).
34 |
35 |
36 | | Models | Single Precision Top-1 | W8A8 Top-1 |
37 | | ----------------------------------------------- | :--------------------: | :--------: |
38 | | [ResNet18](https://arxiv.org/abs/1512.03385) | 71.47 | 71.43 |
39 | | [ResNet50](https://arxiv.org/abs/1512.03385) | 77.72 | 77.67 |
40 | | [InceptionV3](https://arxiv.org/abs/1512.00567) | 78.88 | 78.72 |
41 | | [MobileNetV2](https://arxiv.org/abs/1801.04381) | 73.03 | 72.91 |
42 | | [ShuffleNet](https://arxiv.org/abs/1707.01083) | 65.07 | 64.94 |
43 | | [SqueezeNext](https://arxiv.org/abs/1803.10615) | 69.38 | 69.17 |
44 |
45 | ## Evaluate
46 |
47 | - You can test a single model using the following command:
48 |
49 | ```bash
50 | export CUDA_VISIBLE_DEVICES=0
51 | python uniform_test.py [--dataset] [--model] [--batch_size] [--test_batch_size]
52 |
53 | optional arguments:
54 | --dataset type of dataset (default: imagenet)
55 | --model model to be quantized (default: resnet18)
56 | --batch_size        batch size of distilled data (default: 32)
57 | --test_batch_size   batch size of test data (default: 128)
58 | ```
59 |
60 |
61 |
62 |
63 | ## Citation
64 | ZeroQ has been developed as part of the following paper. Please cite it if you find the implementation useful for your work:
65 |
66 | Y. Cai, Z. Yao, Z. Dong, A. Gholami, M. W. Mahoney, K. Keutzer. *ZeroQ: A Novel Zero Shot Quantization Framework*, under review [[PDF](https://arxiv.org/pdf/2001.00281.pdf)].
67 |
68 |
--------------------------------------------------------------------------------
/ZeroQ/reconstruct_data.py:
--------------------------------------------------------------------------------
1 | #*
2 | # @file Different utility functions
3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
4 | # All rights reserved.
5 | # This file is part of ZeroQ repository.
6 | #
7 | # ZeroQ is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # ZeroQ is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
19 | #*
20 |
21 | import os
22 | import json
23 | import torch
24 | import torch.nn as nn
25 | import copy
26 | import torch.optim as optim
27 | from utils import *
28 |
29 |
30 | def own_loss(A, B):
31 | """
32 | L-2 loss between A and B normalized by length.
33 | A and B should have the same length
34 | """
35 | return (A - B).norm()**2 / A.size(0)
36 |
37 |
38 | class output_hook(object):
39 | """
40 | Forward_hook used to get the output of intermediate layer.
41 | """
42 | def __init__(self):
43 | super(output_hook, self).__init__()
44 | self.outputs = None
45 |
46 | def hook(self, module, input, output):
47 | self.outputs = output
48 |
49 | def clear(self):
50 | self.outputs = None
51 |
52 |
53 | def getReconData(teacher_model,
54 | dataset,
55 | batch_size,
56 | num_batch=1,
57 | for_inception=False):
58 | """
59 | Generate distilled data according to the BatchNorm statistics in pretrained single-precision model.
60 | Only support single GPU.
61 |
62 | teacher_model: pretrained single-precision model
63 | dataset: the name of dataset
64 | batch_size: the batch size of generated distilled data
65 | num_batch: the number of batch of generated distilled data
66 | for_inception: whether the data is for Inception because inception has input size 299 rather than 224
67 | """
68 |
69 | # initialize distilled data with random noise according to the dataset
70 | dataloader = getRandomData(dataset=dataset,
71 | batch_size=batch_size,
72 | for_inception=for_inception)
73 |
74 | eps = 1e-6
75 | # initialize hooks and single-precision model
76 | hooks, hook_handles, bn_stats, refined_gaussian = [], [], [], []
77 | teacher_model = teacher_model.cuda()
78 | teacher_model = teacher_model.eval()
79 |
80 | # get number of BatchNorm layers in the model
81 | layers = sum([
82 | 1 if isinstance(layer, nn.BatchNorm2d) else 0
83 | for layer in teacher_model.modules()
84 | ])
85 |
86 | for n, m in teacher_model.named_modules():
87 | if isinstance(m, nn.Conv2d) and len(hook_handles) < layers:
88 | # register hooks on the convolutional layers to get the intermediate output after convolution and before BatchNorm.
89 | hook = output_hook()
90 | hooks.append(hook)
91 | hook_handles.append(m.register_forward_hook(hook.hook))
92 | if isinstance(m, nn.BatchNorm2d):
93 | # get the statistics in the BatchNorm layers
94 | bn_stats.append(
95 | (m.running_mean.detach().clone().flatten().cuda(),
96 | torch.sqrt(m.running_var +
97 | eps).detach().clone().flatten().cuda()))
98 | assert len(hooks) == len(bn_stats)
99 |
100 | for i, gaussian_data in enumerate(dataloader):
101 | if i == num_batch:
102 | break
103 | # initialize the criterion, optimizer, and scheduler
104 | gaussian_data = gaussian_data.cuda()
105 | gaussian_data.requires_grad = True
106 | crit = nn.CrossEntropyLoss().cuda()
107 | optimizer = optim.Adam([gaussian_data], lr=0.1)
108 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
109 | min_lr=1e-4,
110 | verbose=False,
111 | patience=100)
112 |
113 | input_mean = torch.zeros(1, 3).cuda()
114 | input_std = torch.ones(1, 3).cuda()
115 |
116 | for it in range(500):
117 | teacher_model.zero_grad()
118 | optimizer.zero_grad()
119 | for hook in hooks:
120 | hook.clear()
121 | output = teacher_model(gaussian_data)
122 | mean_loss = 0
123 | std_loss = 0
124 |
125 | # compute the loss according to the BatchNorm statistics and the statistics of intermediate output
126 | for cnt, (bn_stat, hook) in enumerate(zip(bn_stats, hooks)):
127 | tmp_output = hook.outputs
128 | bn_mean, bn_std = bn_stat[0], bn_stat[1]
129 | tmp_mean = torch.mean(tmp_output.view(tmp_output.size(0),
130 | tmp_output.size(1), -1),
131 | dim=2)
132 | tmp_std = torch.sqrt(
133 | torch.var(tmp_output.view(tmp_output.size(0),
134 | tmp_output.size(1), -1),
135 | dim=2) + eps)
136 | mean_loss += own_loss(bn_mean, tmp_mean)
137 | std_loss += own_loss(bn_std, tmp_std)
138 | tmp_mean = torch.mean(gaussian_data.view(gaussian_data.size(0), 3,
139 | -1),
140 | dim=2)
141 | tmp_std = torch.sqrt(
142 | torch.var(gaussian_data.view(gaussian_data.size(0), 3, -1),
143 | dim=2) + eps)
144 | mean_loss += own_loss(tmp_mean, input_mean)
145 | std_loss += own_loss(tmp_std, input_std)
146 | total_loss = mean_loss + std_loss
147 |
148 | # update the distilled data
149 | total_loss.backward()
150 | optimizer.step()
151 | scheduler.step(total_loss.item())
152 |
153 | # early stop to prevent overfit
154 | if total_loss <= (layers + 1) * 5:
155 | break
156 |
157 | refined_gaussian.append(gaussian_data.detach().clone())
158 |
159 | for handle in hook_handles:
160 | handle.remove()
161 | return refined_gaussian
162 |
--------------------------------------------------------------------------------
/ZeroQ/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorchcv==0.0.51
2 | progressbar>=1.5
3 |
--------------------------------------------------------------------------------
/ZeroQ/run.sh:
--------------------------------------------------------------------------------
1 | for MODEL in resnet18 resnet50 inceptionv3 mobilenetv2_w1 shufflenet_g1_w1 sqnxt23_w2
2 | do
3 | echo Testing $MODEL ...
4 | python uniform_test.py \
5 | --dataset=imagenet \
6 | --model=$MODEL \
7 | --batch_size=64 \
8 | --test_batch_size=512
9 | done
10 |
--------------------------------------------------------------------------------
/ZeroQ/uniform_test.py:
--------------------------------------------------------------------------------
1 | #*
2 | # @file Different utility functions
3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
4 | # All rights reserved.
5 | # This file is part of ZeroQ repository.
6 | #
7 | # ZeroQ is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # ZeroQ is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
19 | #*
20 |
21 | import argparse
22 | import torch
23 | import numpy as np
24 | import torch.nn as nn
25 | from pytorchcv.model_provider import get_model as ptcv_get_model
26 | from utils import *
27 | from distill_data import *
28 |
29 |
30 | # model settings
31 | def arg_parse():
32 | parser = argparse.ArgumentParser(
33 | description='This repository contains the PyTorch implementation for the paper ZeroQ: A Novel Zero-Shot Quantization Framework.')
34 | parser.add_argument('--dataset',
35 | type=str,
36 | default='imagenet',
37 | choices=['imagenet', 'cifar10'],
38 | help='type of dataset')
39 | parser.add_argument('--model',
40 | type=str,
41 | default='resnet18',
42 | choices=[
43 | 'resnet18', 'resnet50', 'inceptionv3',
44 | 'mobilenetv2_w1', 'shufflenet_g1_w1',
45 | 'resnet20_cifar10', 'sqnxt23_w2'
46 | ],
47 | help='model to be quantized')
48 | parser.add_argument('--batch_size',
49 | type=int,
50 | default=32,
51 | help='batch size of distilled data')
52 | parser.add_argument('--test_batch_size',
53 | type=int,
54 | default=128,
55 | help='batch size of test data')
56 | args = parser.parse_args()
57 | return args
58 |
59 |
60 | if __name__ == '__main__':
61 | args = arg_parse()
62 | torch.backends.cudnn.deterministic = True
63 | torch.backends.cudnn.benchmark = False
64 |
65 | # Load pretrained model
66 | model = ptcv_get_model(args.model, pretrained=True)
67 | print('****** Full precision model loaded ******')
68 |
69 | # Load validation data
70 | test_loader = getTestData(args.dataset,
71 | batch_size=args.test_batch_size,
72 | path='./data/imagenet/',
73 | for_inception=args.model.startswith('inception'))
74 | # Generate distilled data
75 | dataloader = getDistilData(
76 | model.cuda(),
77 | args.dataset,
78 | batch_size=args.batch_size,
79 | for_inception=args.model.startswith('inception'))
80 | print('****** Data loaded ******')
81 |
82 | # Quantize single-precision model to 8-bit model
83 | quantized_model = quantize_model(model)
84 | # Freeze BatchNorm statistics
85 | quantized_model.eval()
86 | quantized_model = quantized_model.cuda()
87 |
88 | # Update activation range according to distilled data
89 | update(quantized_model, dataloader)
90 |
91 | # Freeze activation range during test
92 | freeze_model(quantized_model)
93 | quantized_model = nn.DataParallel(quantized_model).cuda()
94 |
95 | # Test the final quantized model
96 | test(quantized_model, test_loader)
97 |
--------------------------------------------------------------------------------
/ZeroQ/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .quantize_model import *
2 | from .data_utils import *
3 | # from .train_utils import *
4 | from .quantization_utils.quant_utils import *
--------------------------------------------------------------------------------
/ZeroQ/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | #*
2 | # @file Different utility functions
3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
4 | # All rights reserved.
5 | # This file is part of ZeroQ repository.
6 | #
7 | # ZeroQ is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # ZeroQ is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
19 | #*
20 |
21 | from torch.utils.data import Dataset, DataLoader
22 | from torchvision import datasets, transforms
23 | import torch
24 |
25 |
26 | class UniformDataset(Dataset):
27 | """
28 | get random uniform samples with mean 0 and variance 1
29 | """
30 | def __init__(self, length, size, transform, max_value):
31 | self.length = length
32 | self.transform = transform
33 | self.size = size
34 | self.max_value = max_value
35 |
36 | def __len__(self):
37 | return self.length
38 |
39 | def __getitem__(self, idx):
40 | # var[U(-128, 127)] = (127 - (-128))**2 / 12 = 5418.75
41 | # sample = (torch.randint(high=255, size=self.size).float() -
42 | # 127.5) / 5418.75
43 | sample = ((torch.randint(high=255, size=self.size).float() - 127.) / 128.) * self.max_value
44 | return sample
45 |
46 |
47 | def getRandomData(dataset='cifar10', batch_size=512, for_inception=False, max_value=3.0, size=[224, 224]):
48 | """
49 | get random sample dataloader
50 | dataset: name of the dataset
51 | batch_size: the batch size of random data
52 | for_inception: whether the data is for Inception because inception has input size 299 rather than 224
53 | """
54 | if dataset == 'cifar10':
55 | size = (3, 32, 32)
56 | num_data = 10000
57 | elif dataset == 'imagenet':
58 | num_data = 10000
59 | # if not for_inception:
60 | # size = (3, 224, 224)
61 | # else:
62 | # size = (3, 299, 299)
63 | size = (3, size[0], size[1])
64 | else:
65 | raise NotImplementedError
66 | dataset = UniformDataset(length=10000, size=size, transform=None, max_value=max_value)
67 | data_loader = DataLoader(dataset,
68 | batch_size=batch_size,
69 | shuffle=False,
70 | num_workers=0)
71 | return data_loader
72 |
73 |
74 | def getTestData(dataset='imagenet',
75 | batch_size=1024,
76 | path='data/imagenet',
77 | for_inception=False):
78 | """
79 | Get dataloader of testset
80 | dataset: name of the dataset
81 | batch_size: the batch size of random data
82 | path: the path to the data
83 | for_inception: whether the data is for Inception because inception has input size 299 rather than 224
84 | """
85 | if dataset == 'imagenet':
86 | input_size = 299 if for_inception else 224
87 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
88 | std=[0.229, 0.224, 0.225])
89 | test_dataset = datasets.ImageFolder(
90 | path + 'val',
91 | transforms.Compose([
92 | transforms.Resize(int(input_size / 0.875)),
93 | transforms.CenterCrop(input_size),
94 | transforms.ToTensor(),
95 | normalize,
96 | ]))
97 | test_loader = DataLoader(test_dataset,
98 | batch_size=batch_size,
99 | shuffle=False,
100 | num_workers=32)
101 | return test_loader
102 | elif dataset == 'cifar10':
103 | data_dir = '/rscratch/yaohuic/data/'
104 | normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
105 | std=(0.2023, 0.1994, 0.2010))
106 | transform_test = transforms.Compose([transforms.ToTensor(), normalize])
107 |
108 | test_dataset = datasets.CIFAR10(root=data_dir,
109 | train=False,
110 | transform=transform_test)
111 | test_loader = DataLoader(test_dataset,
112 | batch_size=batch_size,
113 | shuffle=False,
114 | num_workers=32)
115 | return test_loader
116 |
--------------------------------------------------------------------------------
/ZeroQ/utils/quantization_utils/quant_modules.py:
--------------------------------------------------------------------------------
1 | #*
2 | # @file Different utility functions
3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
4 | # All rights reserved.
5 | # This file is part of ZeroQ repository.
6 | #
7 | # ZeroQ is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # ZeroQ is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
19 | #*
20 |
21 | import torch
22 | import time
23 | import math
24 | import numpy as np
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 | from torch.nn import Module, Parameter
28 | from .quant_utils import *
29 | import sys
30 |
31 |
32 | class QuantAct(Module):
33 | """
34 | Class to quantize given activations
35 | """
36 | def __init__(self,
37 | activation_bit,
38 | full_precision_flag=False,
39 | running_stat=True):
40 | """
41 | activation_bit: bit-setting for activation
42 | full_precision_flag: full precision or not
43 | running_stat: determines whether the activation range is updated or frozen
44 | """
45 | super(QuantAct, self).__init__()
46 | self.activation_bit = activation_bit
47 | self.momentum = 0.99
48 | self.full_precision_flag = full_precision_flag
49 | self.running_stat = running_stat
50 | self.register_buffer('x_min', torch.zeros(1))
51 | self.register_buffer('x_max', torch.zeros(1))
52 | self.act_function = AsymmetricQuantFunction.apply
53 |
54 | def __repr__(self):
55 | return "{0}(activation_bit={1}, full_precision_flag={2}, running_stat={3}, Act_min: {4:.2f}, Act_max: {5:.2f})".format(
56 | self.__class__.__name__, self.activation_bit,
57 | self.full_precision_flag, self.running_stat, self.x_min.item(),
58 | self.x_max.item())
59 |
60 | def fix(self):
61 | """
62 | fix the activation range by setting running stat
63 | """
64 | self.running_stat = False
65 |
66 | def forward(self, x):
67 | """
68 | quantize given activation x
69 | """
70 | if self.running_stat:
71 | x_min = x.data.min()
72 | x_max = x.data.max()
73 | # in-place operation used on multi-gpus
74 | self.x_min += -self.x_min + min(self.x_min, x_min)
75 | self.x_max += -self.x_max + max(self.x_max, x_max)
76 |
77 | if not self.full_precision_flag:
78 | quant_act = self.act_function(x, self.activation_bit, self.x_min,
79 | self.x_max)
80 | return quant_act
81 | else:
82 | return x
83 |
84 |
85 | class Quant_Linear(Module):
86 | """
87 | Class to quantize given linear layer weights
88 | """
89 | def __init__(self, weight_bit, full_precision_flag=False):
90 | """
91 | weight_bit: bit-setting for weights
92 | full_precision_flag: full precision or not
94 | """
95 | super(Quant_Linear, self).__init__()
96 | self.full_precision_flag = full_precision_flag
97 | self.weight_bit = weight_bit
98 | self.weight_function = AsymmetricQuantFunction.apply
99 |
100 | def __repr__(self):
101 | s = super(Quant_Linear, self).__repr__()
102 | s = "(" + s + " weight_bit={}, full_precision_flag={})".format(
103 | self.weight_bit, self.full_precision_flag)
104 | return s
105 |
106 | def set_param(self, linear):
107 | self.in_features = linear.in_features
108 | self.out_features = linear.out_features
109 | self.weight = Parameter(linear.weight.data.clone())
110 | try:
111 | self.bias = Parameter(linear.bias.data.clone())
112 | except AttributeError:
113 | self.bias = None
114 |
115 | def forward(self, x):
116 | """
117 | using quantized weights to forward activation x
118 | """
119 | w = self.weight
120 | x_transform = w.data.detach()
121 | w_min = x_transform.min(dim=1).values
122 | w_max = x_transform.max(dim=1).values
123 | if not self.full_precision_flag:
124 | w = self.weight_function(self.weight, self.weight_bit, w_min,
125 | w_max)
126 | else:
127 | w = self.weight
128 | return F.linear(x, weight=w, bias=self.bias)
129 |
130 |
131 | class Quant_Conv2d(Module):
132 | """
133 | Class to quantize given convolutional layer weights
134 | """
135 | def __init__(self, weight_bit, full_precision_flag=False):
136 | super(Quant_Conv2d, self).__init__()
137 | self.full_precision_flag = full_precision_flag
138 | self.weight_bit = weight_bit
139 | self.weight_function = AsymmetricQuantFunction.apply
140 |
141 | def __repr__(self):
142 | s = super(Quant_Conv2d, self).__repr__()
143 | s = "(" + s + " weight_bit={}, full_precision_flag={})".format(
144 | self.weight_bit, self.full_precision_flag)
145 | return s
146 |
147 | def set_param(self, conv):
148 | self.in_channels = conv.in_channels
149 | self.out_channels = conv.out_channels
150 | self.kernel_size = conv.kernel_size
151 | self.stride = conv.stride
152 | self.padding = conv.padding
153 | self.dilation = conv.dilation
154 | self.groups = conv.groups
155 | self.weight = Parameter(conv.weight.data.clone())
156 | try:
157 | self.bias = Parameter(conv.bias.data.clone())
158 | except AttributeError:
159 | self.bias = None
160 |
161 | def forward(self, x):
162 | """
163 | using quantized weights to forward activation x
164 | """
165 | w = self.weight
166 | x_transform = w.data.contiguous().view(self.out_channels, -1)
167 | w_min = x_transform.min(dim=1).values
168 | w_max = x_transform.max(dim=1).values
169 | if not self.full_precision_flag:
170 | w = self.weight_function(self.weight, self.weight_bit, w_min,
171 | w_max)
172 | else:
173 | w = self.weight
174 |
175 | return F.conv2d(x, w, self.bias, self.stride, self.padding,
176 | self.dilation, self.groups)
177 |
--------------------------------------------------------------------------------
/ZeroQ/utils/quantization_utils/quant_utils.py:
--------------------------------------------------------------------------------
1 | #*
2 | # @file Different utility functions
3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
4 | # All rights reserved.
5 | # This file is part of ZeroQ repository.
6 | #
7 | # ZeroQ is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # ZeroQ is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
19 | #*
20 |
21 | import math
22 | import numpy as np
23 | from torch.autograd import Function, Variable
24 | import torch
25 |
26 |
27 | def clamp(input, min, max, inplace=False):
28 | """
29 | Clamp tensor input to (min, max).
30 | input: input tensor to be clamped
31 | """
32 |
33 | if inplace:
34 | input.clamp_(min, max)
35 | return input
36 | return torch.clamp(input, min, max)
37 |
38 |
39 | def linear_quantize(input, scale, zero_point, inplace=False):
40 | """
41 | Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.
42 | input: single-precision input tensor to be quantized
43 | scale: scaling factor for quantization
44 | zero_point: shift for quantization
45 | """
46 |
47 | # reshape scale and zeropoint for convolutional weights and activation
48 | if len(input.shape) == 4:
49 | scale = scale.view(-1, 1, 1, 1)
50 | zero_point = zero_point.view(-1, 1, 1, 1)
51 | # reshape scale and zeropoint for linear weights
52 | elif len(input.shape) == 2:
53 | scale = scale.view(-1, 1)
54 | zero_point = zero_point.view(-1, 1)
55 | # mapping single-precision input to integer values with the given scale and zeropoint
56 | if inplace:
57 | input.mul_(scale).sub_(zero_point).round_()
58 | return input
59 | return torch.round(scale * input - zero_point)
60 |
61 |
62 | def linear_dequantize(input, scale, zero_point, inplace=False):
63 | """
64 | Map an integer input tensor back to floating point with the given scaling factor and zeropoint.
65 | input: integer input tensor to be mapped
66 | scale: scaling factor for quantization
67 | zero_point: shift for quantization
68 | """
69 |
70 | # reshape scale and zeropoint for convolutional weights and activation
71 | if len(input.shape) == 4:
72 | scale = scale.view(-1, 1, 1, 1)
73 | zero_point = zero_point.view(-1, 1, 1, 1)
74 | # reshape scale and zeropoint for linear weights
75 | elif len(input.shape) == 2:
76 | scale = scale.view(-1, 1)
77 | zero_point = zero_point.view(-1, 1)
79 | # map the integer input back to floating-point values with the given scaling factor and zeropoint
79 | if inplace:
80 | input.add_(zero_point).div_(scale)
81 | return input
82 | return (input + zero_point) / scale
83 |
84 |
85 | def asymmetric_linear_quantization_params(num_bits,
86 | saturation_min,
87 | saturation_max,
88 | integral_zero_point=True,
89 | signed=True):
90 | """
91 | Compute the scaling factor and zeropoint with the given quantization range.
92 | saturation_min: lower bound for quantization range
93 | saturation_max: upper bound for quantization range
94 | """
95 | n = 2**num_bits - 1
96 | scale = n / torch.clamp((saturation_max - saturation_min), min=1e-8)
97 | zero_point = scale * saturation_min
98 |
99 | if integral_zero_point:
100 | if isinstance(zero_point, torch.Tensor):
101 | zero_point = zero_point.round()
102 | else:
103 | zero_point = float(round(zero_point))
104 | if signed:
105 | zero_point += 2**(num_bits - 1)
106 | return scale, zero_point
107 |
108 |
109 | class AsymmetricQuantFunction(Function):
110 | """
111 | Class to quantize the given floating-point values with given range and bit-setting.
112 | Currently only support inference, but not support back-propagation.
113 | """
114 | @staticmethod
115 | def forward(ctx, x, k, x_min=None, x_max=None):
116 | """
117 | x: single-precision value to be quantized
118 | k: bit-setting for x
119 | x_min: lower bound for quantization range
120 | x_max: upper bound for quantization range
121 | """
122 |
123 | if x_min is None or x_max is None or (sum(x_min == x_max) == 1
124 | and x_min.numel() == 1):
125 | x_min, x_max = x.min(), x.max()
126 | scale, zero_point = asymmetric_linear_quantization_params(
127 | k, x_min, x_max)
128 | new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
129 | n = 2**(k - 1)
130 | new_quant_x = torch.clamp(new_quant_x, -n, n - 1)
131 | quant_x = linear_dequantize(new_quant_x,
132 | scale,
133 | zero_point,
134 | inplace=False)
135 | return torch.autograd.Variable(quant_x)
136 |
137 | @staticmethod
138 | def backward(ctx, grad_output):
139 | raise NotImplementedError
140 |
--------------------------------------------------------------------------------
/ZeroQ/utils/quantize_model.py:
--------------------------------------------------------------------------------
1 | #*
2 | # @file Different utility functions
3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
4 | # All rights reserved.
5 | # This file is part of ZeroQ repository.
6 | #
7 | # ZeroQ is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # ZeroQ is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
19 | #*
20 |
21 | import torch
22 | import torch.nn as nn
23 | import copy
24 | from .quantization_utils.quant_modules import *
25 | from pytorchcv.models.common import ConvBlock
26 | from pytorchcv.models.shufflenetv2 import ShuffleUnit, ShuffleInitBlock
27 |
28 |
29 | def quantize_model(model):
30 | """
31 | Recursively quantize a pretrained single-precision model to int8 quantized model
32 | model: pretrained single-precision model
33 | """
34 |
35 | # quantize convolutional and linear layers to 8-bit
36 | if type(model) == nn.Conv2d:
37 | quant_mod = Quant_Conv2d(weight_bit=8)
38 | quant_mod.set_param(model)
39 | return quant_mod
40 | elif type(model) == nn.Linear:
41 | quant_mod = Quant_Linear(weight_bit=8)
42 | quant_mod.set_param(model)
43 | return quant_mod
44 |
45 | # quantize all the activation to 8-bit
46 | elif type(model) == nn.ReLU or type(model) == nn.ReLU6:
47 | return nn.Sequential(*[model, QuantAct(activation_bit=8)])
48 |
49 | # recursively use the quantized module to replace the single-precision module
50 | elif type(model) == nn.Sequential:
51 | mods = []
52 | for n, m in model.named_children():
53 | mods.append(quantize_model(m))
54 | return nn.Sequential(*mods)
55 | else:
56 | q_model = copy.deepcopy(model)
57 | for attr in dir(model):
58 | mod = getattr(model, attr)
59 | if isinstance(mod, nn.Module) and 'norm' not in attr:
60 | setattr(q_model, attr, quantize_model(mod))
61 | return q_model
62 |
63 |
64 | def freeze_model(model):
65 | """
66 | freeze the activation range
67 | """
68 | if type(model) == QuantAct:
69 | model.fix()
70 | elif type(model) == nn.Sequential:
71 | mods = []
72 | for n, m in model.named_children():
73 | freeze_model(m)
74 | else:
75 | for attr in dir(model):
76 | mod = getattr(model, attr)
77 | if isinstance(mod, nn.Module) and 'norm' not in attr:
78 | freeze_model(mod)
79 | return model
80 |
81 |
82 | def unfreeze_model(model):
83 | """
84 | unfreeze the activation range
85 | """
86 | if type(model) == QuantAct:
87 | model.unfix()
88 | elif type(model) == nn.Sequential:
89 | mods = []
90 | for n, m in model.named_children():
91 | unfreeze_model(m)
92 | else:
93 | for attr in dir(model):
94 | mod = getattr(model, attr)
95 | if isinstance(mod, nn.Module) and 'norm' not in attr:
96 | unfreeze_model(mod)
97 | return model
98 |
--------------------------------------------------------------------------------
/ZeroQ/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | #*
2 | # @file Different utility functions
3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
4 | # All rights reserved.
5 | # This file is part of ZeroQ repository.
6 | #
7 | # ZeroQ is free software: you can redistribute it and/or modify
8 | # it under the terms of the GNU General Public License as published by
9 | # the Free Software Foundation, either version 3 of the License, or
10 | # (at your option) any later version.
11 | #
12 | # ZeroQ is distributed in the hope that it will be useful,
13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 | # GNU General Public License for more details.
16 | #
17 | # You should have received a copy of the GNU General Public License
18 | # along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
19 | #*
20 |
21 | import torch
22 | import os
23 | import torch.nn as nn
24 | from progress.bar import Bar
25 |
26 |
27 | def test(model, test_loader):
28 | """
29 | test a model on a given dataset
30 | """
31 | total, correct = 0, 0
32 | bar = Bar('Testing', max=len(test_loader))
33 | model.eval()
34 | with torch.no_grad():
35 | for batch_idx, (inputs, targets) in enumerate(test_loader):
36 | inputs, targets = inputs.cuda(), targets.cuda()
37 | outputs = model(inputs)
38 | _, predicted = outputs.max(1)
39 | total += targets.size(0)
40 | correct += predicted.eq(targets).sum().item()
41 | acc = correct / total
42 |
43 | bar.suffix = f'({batch_idx + 1}/{len(test_loader)}) | ETA: {bar.eta_td} | top1: {acc}'
44 | bar.next()
45 | print('\nFinal acc: %.2f%% (%d/%d)' % (100. * acc, correct, total))
46 | bar.finish()
47 | model.train()
48 | return acc
49 |
50 |
51 | def update(quantized_model, distilD):
52 | """
53 | Update activation range according to distilled data
54 | quantized_model: a quantized model whose activation range to be updated
55 | distilD: distilled data
56 | """
57 | print('****** updating activation range...', end='')
58 | with torch.no_grad():
59 | for batch_idx, inputs in enumerate(distilD):
60 | if isinstance(inputs, list):
61 | inputs = inputs[0]
62 | inputs = inputs.cuda()
63 | outputs = quantized_model(inputs)
64 | print(' Finished******')
65 | return quantized_model
66 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/detection/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/dataset/detection/__init__.py
--------------------------------------------------------------------------------
/dataset/detection/open_images.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pathlib
3 | import cv2
4 | import pandas as pd
5 | import copy
6 |
7 | class OpenImagesDataset:
8 |
9 | def __init__(self, root,
10 | transform=None, target_transform=None,
11 | dataset_type="train", balance_data=False):
12 | self.root = pathlib.Path(root)
13 | self.transform = transform
14 | self.target_transform = target_transform
15 | self.dataset_type = dataset_type.lower()
16 |
17 | self.data, self.class_names, self.class_dict = self._read_data()
18 | self.balance_data = balance_data
19 | self.min_image_num = -1
20 | if self.balance_data:
21 | self.data = self._balance_data()
22 | self.ids = [info['image_id'] for info in self.data]
23 |
24 | self.class_stat = None
25 |
26 | def _getitem(self, index):
27 | image_info = self.data[index]
28 | image = self._read_image(image_info['image_id'])
29 | # duplicate boxes to prevent corruption of dataset
30 | boxes = copy.copy(image_info['boxes'])
31 | boxes[:, 0] *= image.shape[1]
32 | boxes[:, 1] *= image.shape[0]
33 | boxes[:, 2] *= image.shape[1]
34 | boxes[:, 3] *= image.shape[0]
35 | # duplicate labels to prevent corruption of dataset
36 | labels = copy.copy(image_info['labels'])
37 | if self.transform:
38 | image, boxes, labels = self.transform(image, boxes, labels)
39 | if self.target_transform:
40 | boxes, labels = self.target_transform(boxes, labels)
41 | return image_info['image_id'], image, boxes, labels
42 |
43 | def __getitem__(self, index):
44 | _, image, boxes, labels = self._getitem(index)
45 | return image, boxes, labels
46 |
47 | def get_annotation(self, index):
48 | """To conform the eval_ssd implementation that is based on the VOC dataset."""
49 | image_id, image, boxes, labels = self._getitem(index)
50 | is_difficult = np.zeros(boxes.shape[0], dtype=np.uint8)
51 | return image_id, (boxes, labels, is_difficult)
52 |
53 | def get_image(self, index):
54 | image_info = self.data[index]
55 | image = self._read_image(image_info['image_id'])
56 | if self.transform:
57 | image, _ = self.transform(image)
58 | return image
59 |
60 | def _read_data(self):
61 | annotation_file = f"{self.root}/sub-{self.dataset_type}-annotations-bbox.csv"
62 | annotations = pd.read_csv(annotation_file)
63 | class_names = ['BACKGROUND'] + sorted(list(annotations['ClassName'].unique()))
64 | class_dict = {class_name: i for i, class_name in enumerate(class_names)}
65 | data = []
66 | for image_id, group in annotations.groupby("ImageID"):
67 | boxes = group.loc[:, ["XMin", "YMin", "XMax", "YMax"]].values.astype(np.float32)
68 | # make labels 64 bits to satisfy the cross_entropy function
69 | labels = np.array([class_dict[name] for name in group["ClassName"]], dtype='int64')
70 | data.append({
71 | 'image_id': image_id,
72 | 'boxes': boxes,
73 | 'labels': labels
74 | })
75 | return data, class_names, class_dict
76 |
77 | def __len__(self):
78 | return len(self.data)
79 |
80 | def __repr__(self):
81 | if self.class_stat is None:
82 | self.class_stat = {name: 0 for name in self.class_names[1:]}
83 | for example in self.data:
84 | for class_index in example['labels']:
85 | class_name = self.class_names[class_index]
86 | self.class_stat[class_name] += 1
87 | content = ["Dataset Summary:"
88 | f"Number of Images: {len(self.data)}",
89 | f"Minimum Number of Images for a Class: {self.min_image_num}",
90 | "Label Distribution:"]
91 | for class_name, num in self.class_stat.items():
92 | content.append(f"\t{class_name}: {num}")
93 | return "\n".join(content)
94 |
95 | def _read_image(self, image_id):
96 | image_file = self.root / self.dataset_type / f"{image_id}.jpg"
97 | image = cv2.imread(str(image_file))
98 | if image.shape[2] == 1:
99 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
100 | else:
101 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
102 | return image
103 |
104 | def _balance_data(self):
105 | label_image_indexes = [set() for _ in range(len(self.class_names))]
106 | for i, image in enumerate(self.data):
107 | for label_id in image['labels']:
108 | label_image_indexes[label_id].add(i)
109 | label_stat = [len(s) for s in label_image_indexes]
110 | self.min_image_num = min(label_stat[1:])
111 | sample_image_indexes = set()
112 | for image_indexes in label_image_indexes[1:]:
113 | image_indexes = np.array(list(image_indexes))
114 | sub = np.random.permutation(image_indexes)[:self.min_image_num]
115 | sample_image_indexes.update(sub)
116 | sample_data = [self.data[i] for i in sample_image_indexes]
117 | return sample_data
118 |
119 |
120 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
/dataset/detection/voc_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import pathlib
4 | import xml.etree.ElementTree as ET
5 | import cv2
6 | import os
7 |
8 |
9 | class VOCDataset:
10 |
11 | def __init__(self, root, transform=None, target_transform=None, is_test=False, keep_difficult=False, label_file=None):
12 | """Dataset for VOC data.
13 | Args:
14 | root: the root of the VOC2007 or VOC2012 dataset, the directory contains the following sub-directories:
15 | Annotations, ImageSets, JPEGImages, SegmentationClass, SegmentationObject.
16 | """
17 | self.root = pathlib.Path(root)
18 | self.transform = transform
19 | self.target_transform = target_transform
20 | if is_test:
21 | image_sets_file = self.root / "ImageSets/Main/test.txt"
22 | else:
23 | image_sets_file = self.root / "ImageSets/Main/val.txt"
24 | # image_sets_file = self.root / "ImageSets/Main/trainval.txt"
25 | self.ids = VOCDataset._read_image_ids(image_sets_file)
26 | self.keep_difficult = keep_difficult
27 |
28 | # if the labels file exists, read in the class names
29 | label_file_name = self.root / "labels.txt"
30 |
31 | if os.path.isfile(label_file_name):
32 | class_string = ""
33 | with open(label_file_name, 'r') as infile:
34 | for line in infile:
35 | class_string += line.rstrip()
36 |
37 | # classes should be a comma separated list
38 |
39 | classes = class_string.split(',')
40 | # prepend BACKGROUND as first class
41 | classes.insert(0, 'BACKGROUND')
42 | classes = [ elem.replace(" ", "") for elem in classes]
43 | self.class_names = tuple(classes)
44 | logging.info("VOC Labels read from file: " + str(self.class_names))
45 |
46 | else:
47 | logging.info("No labels file, using default VOC classes.")
48 | self.class_names = ('BACKGROUND',
49 | 'aeroplane', 'bicycle', 'bird', 'boat',
50 | 'bottle', 'bus', 'car', 'cat', 'chair',
51 | 'cow', 'diningtable', 'dog', 'horse',
52 | 'motorbike', 'person', 'pottedplant',
53 | 'sheep', 'sofa', 'train', 'tvmonitor')
54 |
55 |
56 | self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)}
57 |
58 | def __getitem__(self, index):
59 | image_id = self.ids[index]
60 | boxes, labels, is_difficult = self._get_annotation(image_id)
61 | if not self.keep_difficult:
62 | boxes = boxes[is_difficult == 0]
63 | labels = labels[is_difficult == 0]
64 | image = self._read_image(image_id)
65 | if self.transform:
66 | image, boxes, labels = self.transform(image, boxes, labels)
67 | if self.target_transform:
68 | boxes, labels = self.target_transform(boxes, labels)
69 | return image, boxes, labels
70 |
71 | def get_image(self, index):
72 | image_id = self.ids[index]
73 | image = self._read_image(image_id)
74 | if self.transform:
75 | image, _ = self.transform(image)
76 | return image
77 |
78 | def get_annotation(self, index):
79 | image_id = self.ids[index]
80 | return image_id, self._get_annotation(image_id)
81 |
82 | def __len__(self):
83 | return len(self.ids)
84 |
85 | @staticmethod
86 | def _read_image_ids(image_sets_file):
87 | ids = []
88 | with open(image_sets_file) as f:
89 | for line in f:
90 | ids.append(line.rstrip())
91 | return ids
92 |
93 | def _get_annotation(self, image_id):
94 | annotation_file = self.root / f"Annotations/{image_id}.xml"
95 | objects = ET.parse(annotation_file).findall("object")
96 | boxes = []
97 | labels = []
98 | is_difficult = []
99 | for object in objects:
100 | class_name = object.find('name').text.lower().strip()
101 | # we're only concerned with classes in our list
102 | if class_name in self.class_dict:
103 | bbox = object.find('bndbox')
104 |
105 | # VOC annotations follow Matlab's 1-based indexing, so subtract 1 to convert to 0-based coordinates
106 | x1 = float(bbox.find('xmin').text) - 1
107 | y1 = float(bbox.find('ymin').text) - 1
108 | x2 = float(bbox.find('xmax').text) - 1
109 | y2 = float(bbox.find('ymax').text) - 1
110 | boxes.append([x1, y1, x2, y2])
111 |
112 | labels.append(self.class_dict[class_name])
113 | is_difficult_str = object.find('difficult').text
114 | is_difficult.append(int(is_difficult_str) if is_difficult_str else 0)
115 |
116 | return (np.array(boxes, dtype=np.float32),
117 | np.array(labels, dtype=np.int64),
118 | np.array(is_difficult, dtype=np.uint8))
119 |
120 | def _read_image(self, image_id):
121 | image_file = self.root / f"JPEGImages/{image_id}.jpg"
122 | image = cv2.imread(str(image_file))
123 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
124 | return image
125 |
126 |
127 |
128 |
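
A minimal usage sketch of the dataset class above (not part of the repository; the dataset path is a placeholder assumption):

from dataset.detection.voc_dataset import VOCDataset

# Placeholder path; point this at a local VOCdevkit/VOC2007 folder.
dataset = VOCDataset("/path/to/VOCdevkit/VOC2007", is_test=True)  # reads ImageSets/Main/test.txt
print(len(dataset), "images,", len(dataset.class_names), "classes")

image, boxes, labels = dataset[0]          # RGB ndarray, (N, 4) float32 boxes, (N,) int64 labels
image_id, (boxes, labels, is_difficult) = dataset.get_annotation(0)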
--------------------------------------------------------------------------------
/dataset/segmentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/dataset/segmentation/__init__.py
--------------------------------------------------------------------------------
/dataset/segmentation/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 |
5 | from PIL import Image, ImageOps, ImageFilter
6 |
7 | class Normalize(object):
8 | """Normalize a tensor image with mean and standard deviation.
9 | Args:
10 | mean (tuple): means for each channel.
11 | std (tuple): standard deviations for each channel.
12 | """
13 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
14 | self.mean = mean
15 | self.std = std
16 |
17 | def __call__(self, sample):
18 | img = sample['image']
19 | mask = sample['label']
20 | img = np.array(img).astype(np.float32)
21 | mask = np.array(mask).astype(np.float32)
22 | img /= 255.0
23 | img -= self.mean
24 | img /= self.std
25 |
26 | return {'image': img,
27 | 'label': mask}
28 |
29 |
30 | class ToTensor(object):
31 | """Convert ndarrays in sample to Tensors."""
32 |
33 | def __call__(self, sample):
34 | # swap color axis because
35 | # numpy image: H x W x C
36 | # torch image: C X H X W
37 | img = sample['image']
38 | mask = sample['label']
39 | img = np.array(img).astype(np.float32).transpose((2, 0, 1))
40 | mask = np.array(mask).astype(np.float32)
41 |
42 | img = torch.from_numpy(img).float()
43 | mask = torch.from_numpy(mask).float()
44 |
45 | return {'image': img,
46 | 'label': mask}
47 |
48 |
49 | class RandomHorizontalFlip(object):
50 | def __call__(self, sample):
51 | img = sample['image']
52 | mask = sample['label']
53 | if random.random() < 0.5:
54 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
55 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
56 |
57 | return {'image': img,
58 | 'label': mask}
59 |
60 |
61 | class RandomRotate(object):
62 | def __init__(self, degree):
63 | self.degree = degree
64 |
65 | def __call__(self, sample):
66 | img = sample['image']
67 | mask = sample['label']
68 | rotate_degree = random.uniform(-1*self.degree, self.degree)
69 | img = img.rotate(rotate_degree, Image.BILINEAR)
70 | mask = mask.rotate(rotate_degree, Image.NEAREST)
71 |
72 | return {'image': img,
73 | 'label': mask}
74 |
75 |
76 | class RandomGaussianBlur(object):
77 | def __call__(self, sample):
78 | img = sample['image']
79 | mask = sample['label']
80 | if random.random() < 0.5:
81 | img = img.filter(ImageFilter.GaussianBlur(
82 | radius=random.random()))
83 |
84 | return {'image': img,
85 | 'label': mask}
86 |
87 |
88 | class RandomScaleCrop(object):
89 | def __init__(self, base_size, crop_size, fill=0):
90 | self.base_size = base_size
91 | self.crop_size = crop_size
92 | self.fill = fill
93 |
94 | def __call__(self, sample):
95 | img = sample['image']
96 | mask = sample['label']
97 | # random scale (short edge)
98 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
99 | w, h = img.size
100 | if h > w:
101 | ow = short_size
102 | oh = int(1.0 * h * ow / w)
103 | else:
104 | oh = short_size
105 | ow = int(1.0 * w * oh / h)
106 | img = img.resize((ow, oh), Image.BILINEAR)
107 | mask = mask.resize((ow, oh), Image.NEAREST)
108 | # pad crop
109 | if short_size < self.crop_size:
110 | padh = self.crop_size - oh if oh < self.crop_size else 0
111 | padw = self.crop_size - ow if ow < self.crop_size else 0
112 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
113 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
114 | # random crop crop_size
115 | w, h = img.size
116 | x1 = random.randint(0, w - self.crop_size)
117 | y1 = random.randint(0, h - self.crop_size)
118 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
119 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
120 |
121 | return {'image': img,
122 | 'label': mask}
123 |
124 |
125 | class FixScaleCrop(object):
126 | def __init__(self, crop_size):
127 | self.crop_size = crop_size
128 |
129 | def __call__(self, sample):
130 | img = sample['image']
131 | mask = sample['label']
132 | w, h = img.size
133 | if w > h:
134 | oh = self.crop_size
135 | ow = int(1.0 * w * oh / h)
136 | else:
137 | ow = self.crop_size
138 | oh = int(1.0 * h * ow / w)
139 | img = img.resize((ow, oh), Image.BILINEAR)
140 | mask = mask.resize((ow, oh), Image.NEAREST)
141 | # center crop
142 | w, h = img.size
143 | x1 = int(round((w - self.crop_size) / 2.))
144 | y1 = int(round((h - self.crop_size) / 2.))
145 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
146 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
147 |
148 | return {'image': img,
149 | 'label': mask}
150 |
151 | class FixedResize(object):
152 | def __init__(self, size):
153 | self.size = (size, size) # size: (h, w)
154 |
155 | def __call__(self, sample):
156 | img = sample['image']
157 | mask = sample['label']
158 |
159 | assert img.size == mask.size
160 |
161 | img = img.resize(self.size, Image.BILINEAR)
162 | mask = mask.resize(self.size, Image.NEAREST)
163 |
164 | return {'image': img,
165 | 'label': mask}
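
A minimal sketch of how these transforms compose on the {'image', 'label'} sample dict (assumption: inputs are PIL Images, as in pascal.py below):

from PIL import Image
from torchvision import transforms
from dataset.segmentation.custom_transforms import FixScaleCrop, Normalize, ToTensor

sample = {'image': Image.new('RGB', (500, 375)),   # dummy PIL inputs
          'label': Image.new('L', (500, 375))}

val_tf = transforms.Compose([
    FixScaleCrop(crop_size=513),    # PIL -> PIL, scale short side then center-crop
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # PIL -> float ndarray
    ToTensor(),                     # ndarray -> torch.Tensor, image becomes (C, H, W)
])
out = val_tf(sample)
print(out['image'].shape, out['label'].shape)   # torch.Size([3, 513, 513]) torch.Size([513, 513])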
--------------------------------------------------------------------------------
/dataset/segmentation/pascal.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | # from mypath import Path
7 | from torchvision import transforms
8 |
9 | import dataset.segmentation.custom_transforms as tr
10 | import cv2
11 |
12 | class VOCSegmentation(Dataset):
13 | """
14 | PascalVoc dataset
15 | """
16 | NUM_CLASSES = 21
17 |
18 | def __init__(self,
19 | args,
20 | base_dir='/media/jakc4103/Toshiba/workspace/dataset/VOCdevkit/VOC2012/',
21 | split='val',
22 | label='SegmentationClass'
23 | ):
24 | """
25 | :param base_dir: path to VOC dataset directory
26 | :param split: train/val
27 | :param transform: transform to apply
28 | :param label: SegmentationObject/SegmentationClass
29 | """
30 | super().__init__()
31 | self._base_dir = base_dir
32 | self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
33 | self._cat_dir = os.path.join(self._base_dir, label)
34 |
35 | if isinstance(split, str):
36 | self.split = [split]
37 | else:
38 | split.sort()
39 | self.split = split
40 |
41 | self.args = args
42 |
43 | _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation')
44 |
45 | self.im_ids = []
46 | self.images = []
47 | self.categories = []
48 |
49 | for splt in self.split:
50 | with open(os.path.join(_splits_dir, splt + '.txt'), "r") as f:
51 | lines = f.read().splitlines()
52 |
53 | for ii, line in enumerate(lines):
54 | _image = os.path.join(self._image_dir, line + ".jpg")
55 | _cat = os.path.join(self._cat_dir, line + ".png")
56 | assert os.path.isfile(_image)
57 | assert os.path.isfile(_cat)
58 | self.im_ids.append(line)
59 | self.images.append(_image)
60 | self.categories.append(_cat)
61 |
62 | assert (len(self.images) == len(self.categories))
63 |
64 | # Display stats
65 | print('Number of images in {}: {:d}'.format(split, len(self.images)))
66 |
67 | def __len__(self):
68 | return len(self.images)
69 |
70 |
71 | def __getitem__(self, index):
72 | _img, _target = self._make_img_gt_point_pair(index)
73 | sample = {'image': _img, 'label': _target}
74 |
75 | for split in self.split:
76 | if split == "train":
77 | return self.transform_tr(sample)
78 | elif split == 'val' or split == 'test':
79 | return self.transform_val(sample)
80 |
81 |
82 | def _make_img_gt_point_pair(self, index):
83 | _img = Image.open(self.images[index]).convert('RGB')
84 | _target = Image.open(self.categories[index])
85 |
86 | # test = np.array(_target)
87 | # print(test.shape)
88 | # print(np.unique(test))
89 | # cv2.imshow('test', test)
90 | # cv2.waitKey(0)
91 |
92 | return _img, _target
93 |
94 | def transform_tr(self, sample):
95 | composed_transforms = transforms.Compose([
96 | tr.RandomHorizontalFlip(),
97 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
98 | tr.RandomGaussianBlur(),
99 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
100 | tr.ToTensor()])
101 |
102 | return composed_transforms(sample)
103 |
104 | def transform_val(self, sample):
105 |
106 | composed_transforms = transforms.Compose([
107 | tr.FixScaleCrop(crop_size=self.args.crop_size),
108 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
109 | tr.ToTensor()])
110 |
111 | return composed_transforms(sample)
112 |
113 | def __str__(self):
114 | return 'VOC2012(split=' + str(self.split) + ')'
115 |
116 |
117 | if __name__ == '__main__':
118 | from utils import decode_segmap
119 | from torch.utils.data import DataLoader
120 | import matplotlib.pyplot as plt
121 | import argparse
122 |
123 | parser = argparse.ArgumentParser()
124 | args = parser.parse_args()
125 | args.base_size = 513
126 | args.crop_size = 513
127 |
128 | voc_train = VOCSegmentation(args, split='val')
129 |
130 | dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)
131 |
132 | for ii, sample in enumerate(dataloader):
133 | for jj in range(sample["image"].size()[0]):
134 | img = sample['image'].numpy()
135 | gt = sample['label'].numpy()
136 | print(np.unique(gt))
137 | tmp = np.array(gt[jj]).astype(np.uint8)
138 | segmap = decode_segmap(tmp, dataset='pascal')
139 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
140 | img_tmp *= (0.229, 0.224, 0.225)
141 | img_tmp += (0.485, 0.456, 0.406)
142 | img_tmp *= 255.0
143 | img_tmp = img_tmp.astype(np.uint8)
144 | plt.figure()
145 | plt.title('display')
146 | plt.subplot(211)
147 | plt.imshow(img_tmp)
148 | plt.subplot(212)
149 | plt.imshow(segmap)
150 |
151 | if ii == 0:
152 | break
153 |
154 | plt.show(block=True)
155 |
156 |
157 |
--------------------------------------------------------------------------------
/dataset/segmentation/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 |
5 | def decode_seg_map_sequence(label_masks, dataset='pascal'):
6 | rgb_masks = []
7 | for label_mask in label_masks:
8 | rgb_mask = decode_segmap(label_mask, dataset)
9 | rgb_masks.append(rgb_mask)
10 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
11 | return rgb_masks
12 |
13 |
14 | def decode_segmap(label_mask, dataset, plot=False):
15 | """Decode segmentation class labels into a color image
16 | Args:
17 | label_mask (np.ndarray): an (M,N) array of integer values denoting
18 | the class label at each spatial location.
19 | plot (bool, optional): whether to show the resulting color image
20 | in a figure.
21 | Returns:
22 | (np.ndarray, optional): the resulting decoded color image.
23 | """
24 | if dataset == 'pascal' or dataset == 'coco':
25 | n_classes = 21
26 | label_colours = get_pascal_labels()
27 | elif dataset == 'cityscapes':
28 | n_classes = 19
29 | label_colours = get_cityscapes_labels()
30 | else:
31 | raise NotImplementedError
32 |
33 | r = label_mask.copy()
34 | g = label_mask.copy()
35 | b = label_mask.copy()
36 | for ll in range(0, n_classes):
37 | r[label_mask == ll] = label_colours[ll, 0]
38 | g[label_mask == ll] = label_colours[ll, 1]
39 | b[label_mask == ll] = label_colours[ll, 2]
40 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
41 | rgb[:, :, 0] = r / 255.0
42 | rgb[:, :, 1] = g / 255.0
43 | rgb[:, :, 2] = b / 255.0
44 | if plot:
45 | plt.imshow(rgb)
46 | plt.show()
47 | else:
48 | return rgb
49 |
50 |
51 | def encode_segmap(mask):
52 | """Encode segmentation label images as pascal classes
53 | Args:
54 | mask (np.ndarray): raw segmentation label image of dimension
55 | (M, N, 3), in which the Pascal classes are encoded as colours.
56 | Returns:
57 | (np.ndarray): class map with dimensions (M,N), where the value at
58 | a given location is the integer denoting the class index.
59 | """
60 | mask = mask.astype(int)
61 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
62 | for ii, label in enumerate(get_pascal_labels()):
63 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
64 | label_mask = label_mask.astype(int)
65 | return label_mask
66 |
67 |
68 | def get_cityscapes_labels():
69 | return np.array([
70 | [128, 64, 128],
71 | [244, 35, 232],
72 | [70, 70, 70],
73 | [102, 102, 156],
74 | [190, 153, 153],
75 | [153, 153, 153],
76 | [250, 170, 30],
77 | [220, 220, 0],
78 | [107, 142, 35],
79 | [152, 251, 152],
80 | [0, 130, 180],
81 | [220, 20, 60],
82 | [255, 0, 0],
83 | [0, 0, 142],
84 | [0, 0, 70],
85 | [0, 60, 100],
86 | [0, 80, 100],
87 | [0, 0, 230],
88 | [119, 11, 32]])
89 |
90 |
91 | def get_pascal_labels():
92 | """Load the mapping that associates pascal classes with label colors
93 | Returns:
94 | np.ndarray with dimensions (21, 3)
95 | """
96 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
97 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
98 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
99 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
100 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
101 | [0, 64, 128]])
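
A small sketch of decode_segmap on a dummy mask (the class index chosen here assumes the VOC ordering listed in voc_dataset.py, where index 15 is 'person'):

import numpy as np
from dataset.segmentation.utils import decode_segmap

mask = np.zeros((4, 4), dtype=np.uint8)
mask[2:, 2:] = 15                              # paint one corner with class 15
rgb = decode_segmap(mask, dataset='pascal')    # float RGB image in [0, 1]
print(rgb.shape)                               # (4, 4, 3)
print(rgb[3, 3] * 255)                         # [192. 128. 128] -> the class-15 colour above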
--------------------------------------------------------------------------------
/images/LE_distill.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/images/LE_distill.png
--------------------------------------------------------------------------------
/images/graph_cls.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/images/graph_cls.png
--------------------------------------------------------------------------------
/images/graph_deeplab.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/images/graph_deeplab.png
--------------------------------------------------------------------------------
/images/graph_ssd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/images/graph_ssd.png
--------------------------------------------------------------------------------
/inference_cls.cpp:
--------------------------------------------------------------------------------
1 | // Tencent is pleased to support the open source community by making ncnn available.
2 | //
3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 | //
5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 | // in compliance with the License. You may obtain a copy of the License at
7 | //
8 | // https://opensource.org/licenses/BSD-3-Clause
9 | //
10 | // Unless required by applicable law or agreed to in writing, software distributed
11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 | // specific language governing permissions and limitations under the License.
14 |
15 | #include <stdio.h>
16 | #include <algorithm>
17 | #include <functional>
18 | #include <vector>
19 | #include <string>
20 | #include <sstream>
21 | #include <opencv2/opencv.hpp>
23 | #include "platform.h"
24 | #include "net.h"
25 |
26 | #if NCNN_VULKAN
27 | #include "gpu.h"
28 | #endif // NCNN_VULKAN
29 |
30 | int parse_images_dir(const std::string& base_path, std::vector<std::string>& file_path)
31 | {
32 | file_path.clear();
33 |
34 | const cv::String base_path_str(base_path);
35 | std::vector<cv::String> image_list;
36 |
37 | cv::glob(base_path_str, image_list, true);
38 |
39 | for (size_t i = 0; i < image_list.size(); i++)
40 | {
41 | const cv::String& image_path = image_list[i];
42 | file_path.push_back(image_path);
43 | }
44 |
45 | return 0;
46 | }
47 |
48 | static int print_topk(const std::vector<float>& cls_scores, int topk)
49 | {
50 | // partial sort topk with index
51 | int size = cls_scores.size();
52 | std::vector< std::pair<float, int> > vec;
53 | vec.resize(size);
54 | for (int i=0; i<size; i++)
55 | {
56 | vec[i] = std::make_pair(cls_scores[i], i);
57 | }
58 |
59 | std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(),
60 | std::greater< std::pair<float, int> >());
61 | int pred_idx;
62 | // print topk and score
63 | for (int i=0; i& image_list, std::vector& cls_scores,
78 | const std::string ncnn_param_file_path, const std::string ncnn_bin_file_path, const std::string out_layer)
79 | {
80 | ncnn::Net net;
81 | size_t size = image_list.size();
82 | printf("Number of images: %lu\n", size);
83 |
84 | #if NCNN_VULKAN
85 | net.opt.use_vulkan_compute = true;
86 | #endif // NCNN_VULKAN
87 |
88 | net.load_param(&ncnn_param_file_path[0]);
89 | net.load_model(&ncnn_bin_file_path[0]);
90 |
91 | const float mean_vals[3] = {0.485f*255.f, 0.456f*255.f, 0.406f*255.f};
92 | const float std_vals[3] = {1/0.229f/255.f, 1/0.224f/255.f, 1/0.225f/255.f};
93 | int correct_count = 0;
94 | int label = -1;
95 | std::string folder_name = "dummy";
96 | for (size_t i = 0; i < image_list.size(); i++)
97 | {
98 |
99 | std::string img_name = image_list[i];
100 |
101 | std::istringstream f(img_name);
102 | std::string s;
103 | while(std::getline(f, s, '/'))
104 | {
105 | if((s.substr(0, 2) == "n0" || s.substr(0, 2) == "n1") && s.size() == 9 && folder_name != s)
106 | {
107 | label++;
108 | folder_name = s;
109 | }
110 | }
111 |
112 | if ((i + 1) % 1000 == 0)
113 | {
114 | fprintf(stderr, " %d/%d, acc:%f\n", static_cast<int>(i + 1), static_cast<int>(size), static_cast<float>(correct_count)/static_cast<float>(i));
115 | }
116 |
117 | #if OpenCV_VERSION_MAJOR > 2
118 | cv::Mat bgr = cv::imread(img_name, cv::IMREAD_COLOR);
119 | #else
120 | cv::Mat bgr = cv::imread(img_name, CV_LOAD_IMAGE_COLOR);
121 | #endif
122 | if (bgr.empty())
123 | {
124 | fprintf(stderr, "cv::imread %s failed\n", img_name.c_str());
125 | return -1;
126 | }
127 |
128 | ncnn::Mat resized = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, bgr.cols, bgr.rows, 256, 256);
129 | ncnn::Mat in;
130 | ncnn::copy_cut_border(resized, in, 16, 16, 16, 16);
131 | in.substract_mean_normalize(mean_vals, std_vals);
132 |
133 | ncnn::Extractor ex = net.create_extractor();
134 | ex.set_num_threads(2);
135 |
136 | ex.input("0", in);
137 |
138 | ncnn::Mat out;
139 | ex.extract(&out_layer[0], out);
140 |
141 | cls_scores.resize(out.w);
142 | for (int j=0; j(correct_count)/static_cast(size));
155 | return 0;
156 | }
157 |
158 | int main(int argc, char** argv)
159 | {
160 | const char* key_map =
161 | "{help h usage ? | | print this message }"
162 | "{param p | | path to ncnn.param file }"
163 | "{bin b | | path to ncnn.bin file }"
164 | "{images i | | path to calibration images folder }"
165 | "{out_layer o | | name of the final layer (innerproduct or softmax) }"
166 | ;
167 |
168 | cv::CommandLineParser parser(argc, argv, key_map);
169 | const std::string image_folder_path = parser.get<cv::String>("images");
170 | const std::string ncnn_param_file_path = parser.get<cv::String>("param");
171 | const std::string ncnn_bin_file_path = parser.get<cv::String>("bin");
172 | const std::string out_layer = parser.get<cv::String>("out_layer");
173 |
174 | // check the input param
175 | if (image_folder_path.empty() || ncnn_param_file_path.empty() || ncnn_bin_file_path.empty())
176 | {
177 | fprintf(stderr, "One or more path may be empty, please check and try again.\n");
178 | return 0;
179 | }
180 |
181 | // parse the image file.
182 | std::vector<std::string> image_file_path_list;
183 | parse_images_dir(image_folder_path, image_file_path_list);
184 |
185 | #if NCNN_VULKAN
186 | ncnn::create_gpu_instance();
187 | #endif // NCNN_VULKAN
188 |
189 | std::vector<float> cls_scores;
190 | detect_net(image_file_path_list, cls_scores, ncnn_param_file_path, ncnn_bin_file_path, out_layer);
191 |
192 | #if NCNN_VULKAN
193 | ncnn::destroy_gpu_instance();
194 | #endif // NCNN_VULKAN
195 |
196 | return 0;
197 | }
198 |
--------------------------------------------------------------------------------
/main_seg.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | import argparse
9 |
10 | from modeling.segmentation.deeplab import DeepLab
11 | from torch.utils.data import DataLoader
12 | from dataset.segmentation.pascal import VOCSegmentation
13 | from utils.metrics import Evaluator
14 |
15 | from utils.relation import create_relation
16 | from dfq import cross_layer_equalization, bias_absorption, bias_correction, clip_weight
17 | from utils.layer_transform import switch_layers, replace_op, restore_op, set_quant_minmax, merge_batchnorm, quantize_targ_layer#, LayerTransform
18 | from PyTransformer.transformers.torchTransformer import TorchTransformer
19 | from utils.quantize import QuantConv2d, QuantNConv2d, QuantMeasure, QConv2d, set_layer_bits
20 | from ZeroQ.distill_data import getDistilData
21 | from improve_dfq import update_scale, transform_quant_layer, set_scale, update_quant_range, set_update_stat, bias_correction_distill
22 |
23 | def get_argument():
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--quantize", action='store_true')
26 | parser.add_argument("--equalize", action='store_true')
27 | parser.add_argument("--correction", action='store_true')
28 | parser.add_argument("--absorption", action='store_true')
29 | parser.add_argument("--distill_range", action='store_true')
30 | parser.add_argument("--log", action='store_true')
31 | parser.add_argument("--relu", action='store_true')
32 | parser.add_argument("--clip_weight", action='store_true')
33 | parser.add_argument("--dataset", type=str, default="voc12")
34 | parser.add_argument("--trainable", action='store_true')
35 | parser.add_argument("--bits_weight", type=int, default=8)
36 | parser.add_argument("--bits_activation", type=int, default=8)
37 | parser.add_argument("--bits_bias", type=int, default=8)
38 | return parser.parse_args()
39 |
40 | def estimate_stats(model, state_dict, data, num_epoch=10, path_save='modeling/data_dependent_QuantConv2dAdd.pth'):
41 | import copy
42 |
43 | # model = DeepLab(sync_bn=False)
44 | model.eval()
45 |
46 | model = model.cuda()
47 |
48 | args = lambda: 0
49 | args.base_size = 513
50 | args.crop_size = 513
51 | voc_val = VOCSegmentation(args, split='train')
52 | dataloader = DataLoader(voc_val, batch_size=32, shuffle=True, num_workers=0)
53 | model.train()
54 |
55 | replace_op()
56 | ss = time.time()
57 | with torch.no_grad():
58 | for epoch in range(num_epoch):
59 | start = time.time()
60 | for sample in dataloader:
61 | image, _ = sample['image'].cuda(), sample['label'].cuda()
62 |
63 | _ = model(image)
64 |
65 | end = time.time()
66 | print("epoch {}: {} sec.".format(epoch, end-start))
67 | print('total time: {} sec'.format(time.time() - ss))
68 | restore_op()
69 |
70 | # load 'running_mean' and 'running_var' of batchnorm back from pre-trained parameters
71 | bn_dict = {}
72 | for key in state_dict:
73 | if 'running' in key:
74 | bn_dict[key] = state_dict[key]
75 |
76 | state = model.state_dict()
77 | state.update(bn_dict)
78 | model.load_state_dict(state)
79 |
80 | torch.save(model.state_dict(), path_save)
81 |
82 | return model
83 |
84 |
85 | def inference_all(model, dataset='voc12', opt=None):
86 | print("Start inference")
87 | from utils.segmentation.utils import forward_all
88 | args = lambda: 0
89 | args.base_size = 513
90 | args.crop_size = 513
91 | if dataset == 'voc12':
92 | voc_val = VOCSegmentation(args, base_dir="/home/jakc4103/WDesktop/dataset/VOCdevkit/VOC2012/", split='val')
93 | elif dataset == 'voc07':
94 | voc_val = VOCSegmentation(args, base_dir="/home/jakc4103/WDesktop/dataset/VOCdevkit/VOC2007/", split='test')
95 | dataloader = DataLoader(voc_val, batch_size=32, shuffle=False, num_workers=2)
96 |
97 | forward_all(model, dataloader, visualize=False, opt=opt)
98 |
99 |
100 | def main():
101 | args = get_argument()
102 | assert args.relu or args.relu == args.equalize, 'must replace relu6 with relu when using equalization'
103 | assert args.equalize or args.absorption == args.equalize, 'bias absorption requires equalization'
104 | data = torch.ones((4, 3, 513, 513))#.cuda()
105 |
106 | model = DeepLab(sync_bn=False)
107 | state_dict = torch.load('modeling/segmentation/deeplab-mobilenet.pth.tar')['state_dict']
108 | model.load_state_dict(state_dict)
109 | model.eval()
110 | if args.distill_range:
111 | import copy
112 | # define FP32 model
113 | model_original = copy.deepcopy(model)
114 | model_original.eval()
115 | transformer = TorchTransformer()
116 | transformer._build_graph(model_original, data, [QuantMeasure])
117 | graph = transformer.log.getGraph()
118 | bottoms = transformer.log.getBottoms()
119 |
120 | data_distill = getDistilData(model_original, 'imagenet', 32, bn_merged=False,\
121 | num_batch=8, gpu=True, value_range=[-2.11790393, 2.64], size=[513, 513], early_break_factor=0.2)
122 |
123 | transformer = TorchTransformer()
124 |
125 | module_dict = {}
126 | if args.quantize:
127 | if args.distill_range:
128 | module_dict[1] = [(nn.Conv2d, QConv2d)]
129 | elif args.trainable:
130 | module_dict[1] = [(nn.Conv2d, QuantConv2d)]
131 | else:
132 | module_dict[1] = [(nn.Conv2d, QuantNConv2d)]
133 |
134 | if args.relu:
135 | module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)]
136 |
137 | # transformer.summary(model, data)
138 | # transformer.visualize(model, data, 'graph_deeplab', graph_size=120)
139 |
140 | model, transformer = switch_layers(model, transformer, data, module_dict, ignore_layer=[QuantMeasure], quant_op=args.quantize)
141 | graph = transformer.log.getGraph()
142 | bottoms = transformer.log.getBottoms()
143 |
144 | if args.quantize:
145 | if args.distill_range:
146 | targ_layer = [QConv2d]
147 | elif args.trainable:
148 | targ_layer = [QuantConv2d]
149 | else:
150 | targ_layer = [QuantNConv2d]
151 | else:
152 | targ_layer = [nn.Conv2d]
153 | if args.quantize:
154 | set_layer_bits(graph, args.bits_weight, args.bits_activation, args.bits_bias, targ_layer)
155 | model = merge_batchnorm(model, graph, bottoms, targ_layer)
156 |
157 | #create relations
158 | if args.equalize or args.distill_range:
159 | res = create_relation(graph, bottoms, targ_layer)
160 | if args.equalize:
161 | cross_layer_equalization(graph, res, targ_layer, visualize_state=False)
162 |
163 | # if args.distill:
164 | # set_scale(res, graph, bottoms, targ_layer)
165 |
166 | if args.absorption:
167 | bias_absorption(graph, res, bottoms, 3)
168 |
169 | if args.clip_weight:
170 | clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer)
171 |
172 | if args.correction:
173 | bias_correction(graph, bottoms, targ_layer)
174 |
175 | if args.quantize:
176 | if not args.trainable and not args.distill_range:
177 | graph = quantize_targ_layer(graph, args.bits_weight, args.bits_bias, targ_layer)
178 |
179 | if args.distill_range:
180 | set_update_stat(model, [QuantMeasure], True)
181 | model = update_quant_range(model.cuda(), data_distill, graph, bottoms)
182 | set_update_stat(model, [QuantMeasure], False)
183 | else:
184 | set_quant_minmax(graph, bottoms)
185 |
186 | torch.cuda.empty_cache()
187 |
188 | model = model.cuda()
189 | model.eval()
190 |
191 | if args.quantize:
192 | replace_op()
193 | inference_all(model, args.dataset, args if args.log else None)
194 | if args.quantize:
195 | restore_op()
196 |
197 |
198 | if __name__ == '__main__':
199 | main()
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/__init__.py
--------------------------------------------------------------------------------
/modeling/classification/MobileNetV2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
6 | def conv_bn(inp, oup, stride):
7 | return nn.Sequential(
8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
9 | nn.BatchNorm2d(oup),
10 | nn.ReLU6(inplace=True)
11 | )
12 |
13 |
14 | def conv_1x1_bn(inp, oup):
15 | return nn.Sequential(
16 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
17 | nn.BatchNorm2d(oup),
18 | nn.ReLU6(inplace=True)
19 | )
20 |
21 |
22 | def make_divisible(x, divisible_by=8):
23 | import numpy as np
24 | return int(np.ceil(x * 1. / divisible_by) * divisible_by)
25 |
26 |
27 | class InvertedResidual(nn.Module):
28 | def __init__(self, inp, oup, stride, expand_ratio):
29 | super(InvertedResidual, self).__init__()
30 | self.stride = stride
31 | assert stride in [1, 2]
32 |
33 | hidden_dim = int(inp * expand_ratio)
34 | self.use_res_connect = self.stride == 1 and inp == oup
35 |
36 | if expand_ratio == 1:
37 | self.conv = nn.Sequential(
38 | # dw
39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
40 | nn.BatchNorm2d(hidden_dim),
41 | nn.ReLU6(inplace=True),
42 | # pw-linear
43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
44 | nn.BatchNorm2d(oup),
45 | )
46 | else:
47 | self.conv = nn.Sequential(
48 | # pw
49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
50 | nn.BatchNorm2d(hidden_dim),
51 | nn.ReLU6(inplace=True),
52 | # dw
53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
54 | nn.BatchNorm2d(hidden_dim),
55 | nn.ReLU6(inplace=True),
56 | # pw-linear
57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
58 | nn.BatchNorm2d(oup),
59 | )
60 |
61 | def forward(self, x):
62 | if self.use_res_connect:
63 | return x + self.conv(x)
64 | else:
65 | return self.conv(x)
66 |
67 |
68 | class MobileNetV2(nn.Module):
69 | def __init__(self, n_class=1000, input_size=224, width_mult=1.):
70 | super(MobileNetV2, self).__init__()
71 | block = InvertedResidual
72 | input_channel = 32
73 | last_channel = 1280
74 | interverted_residual_setting = [
75 | # t, c, n, s
76 | [1, 16, 1, 1],
77 | [6, 24, 2, 2],
78 | [6, 32, 3, 2],
79 | [6, 64, 4, 2],
80 | [6, 96, 3, 1],
81 | [6, 160, 3, 2],
82 | [6, 320, 1, 1],
83 | ]
84 |
85 | # building first layer
86 | assert input_size % 32 == 0
87 | # input_channel = make_divisible(input_channel * width_mult) # first channel is always 32!
88 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
89 | self.features = [conv_bn(3, input_channel, 2)]
90 | # building inverted residual blocks
91 | for t, c, n, s in interverted_residual_setting:
92 | output_channel = make_divisible(c * width_mult) if t > 1 else c
93 | for i in range(n):
94 | if i == 0:
95 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
96 | else:
97 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
98 | input_channel = output_channel
99 | # building last several layers
100 | self.features.append(conv_1x1_bn(input_channel, self.last_channel))
101 | # make it nn.Sequential
102 | self.features = nn.Sequential(*self.features)
103 |
104 | # building classifier
105 | self.classifier = nn.Linear(self.last_channel, n_class)
106 |
107 | self._initialize_weights()
108 |
109 | def forward(self, x):
110 | x = self.features(x)
111 | # x = x.mean(3).mean(2)
112 | x = torch.mean(x.view(x.size(0), x.size(1), -1), -1)
113 | x = self.classifier(x)
114 | return x
115 |
116 | def _initialize_weights(self):
117 | for m in self.modules():
118 | if isinstance(m, nn.Conv2d):
119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
120 | m.weight.data.normal_(0, math.sqrt(2. / n))
121 | if m.bias is not None:
122 | m.bias.data.zero_()
123 | elif isinstance(m, nn.BatchNorm2d):
124 | m.weight.data.fill_(1)
125 | m.bias.data.zero_()
126 | elif isinstance(m, nn.Linear):
127 | n = m.weight.size(1)
128 | m.weight.data.normal_(0, 0.01)
129 | m.bias.data.zero_()
130 |
131 |
132 | def mobilenet_v2(path_weight=None):
133 | model = MobileNetV2(width_mult=1)
134 |
135 | if path_weight is not None:
136 | print("load weight: {}".format(path_weight))
137 | # state_dict = load_state_dict_from_url(
138 | # 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True)
139 | state_dict = torch.load(path_weight)
140 |
141 | model.load_state_dict(state_dict)
142 | return model
143 |
144 |
145 | if __name__ == '__main__':
146 | # 'modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar'
147 | net = mobilenet_v2('./mobilenetv2_1.0-f2a8633.pth.tar')
148 |
149 |
150 |
151 |
152 |
153 |
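
A quick sanity-check sketch for the model above (no pretrained weights; random initialization only):

import torch
from modeling.classification.MobileNetV2 import MobileNetV2, make_divisible

assert make_divisible(91) == 96          # channel counts are rounded up to a multiple of 8
model = MobileNetV2(n_class=1000, input_size=224, width_mult=1.)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)                      # torch.Size([1, 1000])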
--------------------------------------------------------------------------------
/modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar
--------------------------------------------------------------------------------
/modeling/detection/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/__init__.py
--------------------------------------------------------------------------------
/modeling/detection/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/config/__init__.py
--------------------------------------------------------------------------------
/modeling/detection/config/mobilenetv1_ssd_config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from utils.detection.box_utils import SSDSpec, SSDBoxSizes, generate_ssd_priors
4 |
5 |
6 | image_size = 300
7 | image_mean = np.array([127, 127, 127]) # RGB layout
8 | image_std = 128.0
9 | iou_threshold = 0.45
10 | center_variance = 0.1
11 | size_variance = 0.2
12 |
13 | specs = [
14 | SSDSpec(19, 16, SSDBoxSizes(60, 105), [2, 3]),
15 | SSDSpec(10, 32, SSDBoxSizes(105, 150), [2, 3]),
16 | SSDSpec(5, 64, SSDBoxSizes(150, 195), [2, 3]),
17 | SSDSpec(3, 100, SSDBoxSizes(195, 240), [2, 3]),
18 | SSDSpec(2, 150, SSDBoxSizes(240, 285), [2, 3]),
19 | SSDSpec(1, 300, SSDBoxSizes(285, 330), [2, 3])
20 | ]
21 |
22 |
23 | priors = generate_ssd_priors(specs, image_size)
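
For reference, a short sketch of the prior-box count implied by these specs (assuming 6 priors per feature-map location, as the 6 * 4 / 6 * num_classes SSD heads elsewhere in this repo suggest):

feature_map_sizes = [19, 10, 5, 3, 2, 1]
num_priors = sum(s * s * 6 for s in feature_map_sizes)
print(num_priors)   # 3000 prior boxes for a 300x300 input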
--------------------------------------------------------------------------------
/modeling/detection/config/squeezenet_ssd_config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from utils.detection.box_utils import SSDSpec, SSDBoxSizes, generate_ssd_priors
4 |
5 |
6 | image_size = 300
7 | image_mean = np.array([127, 127, 127]) # RGB layout
8 | image_std = 128.0
9 | iou_threshold = 0.45
10 | center_variance = 0.1
11 | size_variance = 0.2
12 |
13 | specs = [
14 | SSDSpec(17, 16, SSDBoxSizes(60, 105), [2, 3]),
15 | SSDSpec(10, 32, SSDBoxSizes(105, 150), [2, 3]),
16 | SSDSpec(5, 64, SSDBoxSizes(150, 195), [2, 3]),
17 | SSDSpec(3, 100, SSDBoxSizes(195, 240), [2, 3]),
18 | SSDSpec(2, 150, SSDBoxSizes(240, 285), [2, 3]),
19 | SSDSpec(1, 300, SSDBoxSizes(285, 330), [2, 3])
20 | ]
21 |
22 |
23 | priors = generate_ssd_priors(specs, image_size)
--------------------------------------------------------------------------------
/modeling/detection/config/vgg_ssd_config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from utils.detection.box_utils import SSDSpec, SSDBoxSizes, generate_ssd_priors
4 |
5 |
6 | image_size = 300
7 | image_mean = np.array([123, 117, 104]) # RGB layout
8 | image_std = 1.0
9 |
10 | iou_threshold = 0.45
11 | center_variance = 0.1
12 | size_variance = 0.2
13 |
14 | specs = [
15 | SSDSpec(38, 8, SSDBoxSizes(30, 60), [2]),
16 | SSDSpec(19, 16, SSDBoxSizes(60, 111), [2, 3]),
17 | SSDSpec(10, 32, SSDBoxSizes(111, 162), [2, 3]),
18 | SSDSpec(5, 64, SSDBoxSizes(162, 213), [2, 3]),
19 | SSDSpec(3, 100, SSDBoxSizes(213, 264), [2]),
20 | SSDSpec(1, 300, SSDBoxSizes(264, 315), [2])
21 | ]
22 |
23 |
24 | priors = generate_ssd_priors(specs, image_size)
--------------------------------------------------------------------------------
/modeling/detection/data_preprocessing.py:
--------------------------------------------------------------------------------
1 | from .transforms.transforms import *
2 |
3 |
4 | class TrainAugmentation:
5 | def __init__(self, size, mean=0, std=1.0):
6 | """
7 | Args:
8 | size: the size of the final image.
9 | mean: mean pixel value per channel.
10 | """
11 | self.mean = mean
12 | self.size = size
13 | self.augment = Compose([
14 | ConvertFromInts(),
15 | PhotometricDistort(),
16 | Expand(self.mean),
17 | RandomSampleCrop(),
18 | RandomMirror(),
19 | ToPercentCoords(),
20 | Resize(self.size),
21 | SubtractMeans(self.mean),
22 | lambda img, boxes=None, labels=None: (img / std, boxes, labels),
23 | ToTensor(),
24 | ])
25 |
26 | def __call__(self, img, boxes, labels):
27 | """
28 |
29 | Args:
30 | img: the output of cv.imread in RGB layout.
31 | boxes: bounding boxes in the form of (x1, y1, x2, y2).
32 | labels: labels of boxes.
33 | """
34 | return self.augment(img, boxes, labels)
35 |
36 |
37 | class TestTransform:
38 | def __init__(self, size, mean=0.0, std=1.0):
39 | self.transform = Compose([
40 | ToPercentCoords(),
41 | Resize(size),
42 | SubtractMeans(mean),
43 | lambda img, boxes=None, labels=None: (img / std, boxes, labels),
44 | ToTensor(),
45 | ])
46 |
47 | def __call__(self, image, boxes, labels):
48 | return self.transform(image, boxes, labels)
49 |
50 |
51 | class PredictionTransform:
52 | def __init__(self, size, mean=0.0, std=1.0):
53 | self.transform = Compose([
54 | Resize(size),
55 | SubtractMeans(mean),
56 | lambda img, boxes=None, labels=None: (img / std, boxes, labels),
57 | ToTensor()
58 | ])
59 |
60 | def __call__(self, image):
61 | image, _, _ = self.transform(image)
62 | return image
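
A minimal sketch of preparing one frame for inference with the class above (the mean/std values match mobilenetv1_ssd_config.py; the underlying transforms module is the vendored pytorch-ssd one, so the exact output layout noted below is an assumption):

import numpy as np
from modeling.detection.data_preprocessing import PredictionTransform

transform = PredictionTransform(size=300, mean=np.array([127, 127, 127]), std=128.0)
rgb_image = np.zeros((480, 640, 3), dtype=np.uint8)   # placeholder RGB frame
tensor = transform(rgb_image)                         # expected: float tensor of shape (3, 300, 300)
batch = tensor.unsqueeze(0)                           # add a batch dimension for the network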
--------------------------------------------------------------------------------
/modeling/detection/fpn_mobilenetv1_ssd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU
3 | from .nn.mobilenet import MobileNetV1
4 |
5 | from .fpn_ssd import FPNSSD
6 | from .predictor import Predictor
7 | from .config import mobilenetv1_ssd_config as config
8 |
9 |
10 | def create_fpn_mobilenetv1_ssd(num_classes):
11 | base_net = MobileNetV1(1001).features # disable dropout layer
12 |
13 | source_layer_indexes = [
14 | (69, Conv2d(in_channels=512, out_channels=256, kernel_size=1)),
15 | (len(base_net), Conv2d(in_channels=1024, out_channels=256, kernel_size=1)),
16 | ]
17 | extras = ModuleList([
18 | Sequential(
19 | Conv2d(in_channels=1024, out_channels=256, kernel_size=1),
20 | ReLU(),
21 | Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
22 | ReLU()
23 | ),
24 | Sequential(
25 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
26 | ReLU(),
27 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
28 | ReLU()
29 | ),
30 | Sequential(
31 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
32 | ReLU(),
33 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
34 | ReLU()
35 | ),
36 | Sequential(
37 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
38 | ReLU(),
39 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
40 | ReLU()
41 | )
42 | ])
43 |
44 | regression_headers = ModuleList([
45 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
46 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
47 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
48 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
49 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
50 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0?
51 | ])
52 |
53 | classification_headers = ModuleList([
54 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
55 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
56 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
57 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
58 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
59 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0?
60 | ])
61 |
62 | return FPNSSD(num_classes, base_net, source_layer_indexes,
63 | extras, classification_headers, regression_headers)
64 |
65 |
66 | def create_fpn_mobilenetv1_ssd_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=torch.device('cpu')):
67 | predictor = Predictor(net, config.image_size, config.image_mean, config.priors,
68 | config.center_variance, config.size_variance,
69 | nms_method=nms_method,
70 | iou_threshold=config.iou_threshold,
71 | candidate_size=candidate_size,
72 | sigma=sigma,
73 | device=device)
74 | return predictor
75 |
--------------------------------------------------------------------------------
/modeling/detection/fpn_ssd.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from typing import List, Tuple
6 |
7 | from utils.detection import box_utils
8 |
9 |
10 | class FPNSSD(nn.Module):
11 | def __init__(self, num_classes: int, base_net: nn.ModuleList, source_layer_indexes: List[int],
12 | extras: nn.ModuleList, classification_headers: nn.ModuleList,
13 | regression_headers: nn.ModuleList, upsample_mode="nearest"):
14 | """Compose a SSD model using the given components.
15 | """
16 | super(FPNSSD, self).__init__()
17 |
18 | self.num_classes = num_classes
19 | self.base_net = base_net
20 | self.source_layer_indexes = source_layer_indexes
21 | self.extras = extras
22 | self.classification_headers = classification_headers
23 | self.regression_headers = regression_headers
24 | self.upsample_mode = upsample_mode
25 |
26 | # register layers in source_layer_indexes by adding them to a module list
27 | self.source_layer_add_ons = nn.ModuleList([t[1] for t in source_layer_indexes if isinstance(t, tuple)])
28 | self.upsamplers = [
29 | nn.Upsample(size=(19, 19), mode='bilinear'),
30 | nn.Upsample(size=(10, 10), mode='bilinear'),
31 | nn.Upsample(size=(5, 5), mode='bilinear'),
32 | nn.Upsample(size=(3, 3), mode='bilinear'),
33 | nn.Upsample(size=(2, 2), mode='bilinear'),
34 | ]
35 |
36 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
37 | confidences = []
38 | locations = []
39 | start_layer_index = 0
40 | header_index = 0
41 | features = []
42 | for end_layer_index in self.source_layer_indexes:
43 |
44 | if isinstance(end_layer_index, tuple):
45 | added_layer = end_layer_index[1]
46 | end_layer_index = end_layer_index[0]
47 | else:
48 | added_layer = None
49 | for layer in self.base_net[start_layer_index: end_layer_index]:
50 | x = layer(x)
51 | start_layer_index = end_layer_index
52 | if added_layer:
53 | y = added_layer(x)
54 | else:
55 | y = x
56 | #confidence, location = self.compute_header(header_index, y)
57 | features.append(y)
58 | header_index += 1
59 | # confidences.append(confidence)
60 | # locations.append(location)
61 |
62 | for layer in self.base_net[end_layer_index:]:
63 | x = layer(x)
64 |
65 | for layer in self.extras:
66 | x = layer(x)
67 | #confidence, location = self.compute_header(header_index, x)
68 | features.append(x)
69 | header_index += 1
70 | # confidences.append(confidence)
71 | # locations.append(location)
72 |
73 | upstream_feature = None
74 | for i in range(len(features) - 1, -1, -1):
75 | feature = features[i]
76 | if upstream_feature is not None:
77 | upstream_feature = self.upsamplers[i](upstream_feature)
78 | upstream_feature += feature
79 | else:
80 | upstream_feature = feature
81 | confidence, location = self.compute_header(i, upstream_feature)
82 | confidences.append(confidence)
83 | locations.append(location)
84 | confidences = torch.cat(confidences, 1)
85 | locations = torch.cat(locations, 1)
86 | return confidences, locations
87 |
88 | def compute_header(self, i, x):
89 | confidence = self.classification_headers[i](x)
90 | confidence = confidence.permute(0, 2, 3, 1).contiguous()
91 | confidence = confidence.view(confidence.size(0), -1, self.num_classes)
92 |
93 | location = self.regression_headers[i](x)
94 | location = location.permute(0, 2, 3, 1).contiguous()
95 | location = location.view(location.size(0), -1, 4)
96 |
97 | return confidence, location
98 |
99 | def init_from_base_net(self, model):
100 | self.base_net.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage), strict=False)
101 | self.source_layer_add_ons.apply(_xavier_init_)
102 | self.extras.apply(_xavier_init_)
103 | self.classification_headers.apply(_xavier_init_)
104 | self.regression_headers.apply(_xavier_init_)
105 |
106 | def init(self):
107 | self.base_net.apply(_xavier_init_)
108 | self.source_layer_add_ons.apply(_xavier_init_)
109 | self.extras.apply(_xavier_init_)
110 | self.classification_headers.apply(_xavier_init_)
111 | self.regression_headers.apply(_xavier_init_)
112 |
113 | def load(self, model):
114 | self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
115 |
116 | def save(self, model_path):
117 | torch.save(self.state_dict(), model_path)
118 |
119 |
120 | class MatchPrior(object):
121 | def __init__(self, center_form_priors, center_variance, size_variance, iou_threshold):
122 | self.center_form_priors = center_form_priors
123 | self.corner_form_priors = box_utils.center_form_to_corner_form(center_form_priors)
124 | self.center_variance = center_variance
125 | self.size_variance = size_variance
126 | self.iou_threshold = iou_threshold
127 |
128 | def __call__(self, gt_boxes, gt_labels):
129 | if type(gt_boxes) is np.ndarray:
130 | gt_boxes = torch.from_numpy(gt_boxes)
131 | if type(gt_labels) is np.ndarray:
132 | gt_labels = torch.from_numpy(gt_labels)
133 | boxes, labels = box_utils.assign_priors(gt_boxes, gt_labels,
134 | self.corner_form_priors, self.iou_threshold)
135 | boxes = box_utils.corner_form_to_center_form(boxes)
136 | locations = box_utils.convert_boxes_to_locations(boxes, self.center_form_priors, self.center_variance, self.size_variance)
137 | return locations, labels
138 |
139 |
140 | def _xavier_init_(m: nn.Module):
141 | if isinstance(m, nn.Conv2d):
142 | nn.init.xavier_uniform_(m.weight)
143 |
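
A short sketch of the reshaping done in compute_header above, using a 19x19 feature map and the 6 * 4 regression head defined in this repo:

import torch

batch, fm = 2, 19
location = torch.randn(batch, 6 * 4, fm, fm)          # raw conv output (B, C, H, W)
location = location.permute(0, 2, 3, 1).contiguous()  # -> (B, H, W, C)
location = location.view(location.size(0), -1, 4)     # -> (B, H*W*6, 4), one row per prior box
print(location.shape)                                 # torch.Size([2, 2166, 4])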
--------------------------------------------------------------------------------
/modeling/detection/mb2-ssd-lite-mp-0_686.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/mb2-ssd-lite-mp-0_686.pth
--------------------------------------------------------------------------------
/modeling/detection/mobilenet_v2_ssd_lite.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Conv2d, Sequential, ModuleList, BatchNorm2d
3 | from torch import nn
4 | from .nn.mobilenet_v2 import MobileNetV2, InvertedResidual
5 |
6 | from .ssd import SSD, GraphPath
7 | from .predictor import Predictor
8 | from .config import mobilenetv1_ssd_config as config
9 |
10 |
11 | def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, onnx_compatible=False):
12 | """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d.
13 | """
14 | ReLU = nn.ReLU if onnx_compatible else nn.ReLU6
15 | return Sequential(
16 | Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
17 | groups=in_channels, stride=stride, padding=padding),
18 | BatchNorm2d(in_channels),
19 | ReLU(),
20 | Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
21 | )
22 |
23 |
24 | def create_mobilenetv2_ssd_lite(num_classes, width_mult=1.0, use_batch_norm=True, onnx_compatible=False, is_test=False, quantize=False):
25 | base_net = MobileNetV2(width_mult=width_mult, use_batch_norm=use_batch_norm,
26 | onnx_compatible=onnx_compatible).features
27 |
28 | source_layer_indexes = [
29 | GraphPath(14, 'conv', 3),
30 | 19,
31 | ]
32 | extras = ModuleList([
33 | InvertedResidual(1280, 512, stride=2, expand_ratio=0.2),
34 | InvertedResidual(512, 256, stride=2, expand_ratio=0.25),
35 | InvertedResidual(256, 256, stride=2, expand_ratio=0.5),
36 | InvertedResidual(256, 64, stride=2, expand_ratio=0.25)
37 | ])
38 |
39 | regression_headers = ModuleList([
40 | SeperableConv2d(in_channels=round(576 * width_mult), out_channels=6 * 4,
41 | kernel_size=3, padding=1, onnx_compatible=False),
42 | SeperableConv2d(in_channels=1280, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
43 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
44 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
45 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False),
46 | Conv2d(in_channels=64, out_channels=6 * 4, kernel_size=1),
47 | ])
48 |
49 | classification_headers = ModuleList([
50 | SeperableConv2d(in_channels=round(576 * width_mult), out_channels=6 * num_classes, kernel_size=3, padding=1),
51 | SeperableConv2d(in_channels=1280, out_channels=6 * num_classes, kernel_size=3, padding=1),
52 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
53 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
54 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
55 | Conv2d(in_channels=64, out_channels=6 * num_classes, kernel_size=1),
56 | ])
57 | if quantize:
58 | from utils.quantize import quantize
59 | config.priors = quantize(config.priors, num_bits=8, min_value=float(config.priors.min()), max_value=float(config.priors.max()))
60 | return SSD(num_classes, base_net, source_layer_indexes,
61 | extras, classification_headers, regression_headers, is_test=is_test, config=config)
62 |
63 |
64 | def create_mobilenetv2_ssd_lite_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=torch.device('cpu')):
65 | predictor = Predictor(net, config.image_size, config.image_mean,
66 | config.image_std,
67 | nms_method=nms_method,
68 | iou_threshold=config.iou_threshold,
69 | candidate_size=candidate_size,
70 | sigma=sigma,
71 | device=device)
72 | return predictor
73 |
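
A hedged usage sketch for the builders above (assumptions: the SSD class exposes a load() helper like FPNSSD shown earlier, the checkpoint shipped at modeling/detection/mb2-ssd-lite-mp-0_686.pth is a plain state dict, and 21 classes = 20 VOC classes + BACKGROUND):

from modeling.detection.mobilenet_v2_ssd_lite import (
    create_mobilenetv2_ssd_lite, create_mobilenetv2_ssd_lite_predictor)

net = create_mobilenetv2_ssd_lite(num_classes=21, is_test=True)
net.load('modeling/detection/mb2-ssd-lite-mp-0_686.pth')   # assumed helper, mirrors FPNSSD.load
net.eval()
predictor = create_mobilenetv2_ssd_lite_predictor(net, candidate_size=200)
# boxes, labels, probs = predictor.predict(rgb_image)      # predict() signature assumed from pytorch-ssd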
--------------------------------------------------------------------------------
/modeling/detection/mobilenetv1_ssd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU
3 | from .nn.mobilenet import MobileNetV1
4 |
5 | from .ssd import SSD
6 | from .predictor import Predictor
7 | from .config import mobilenetv1_ssd_config as config
8 |
9 |
10 | def create_mobilenetv1_ssd(num_classes, is_test=False):
11 | base_net = MobileNetV1(1001).model # disable dropout layer
12 |
13 | source_layer_indexes = [
14 | 12,
15 | 14,
16 | ]
17 | extras = ModuleList([
18 | Sequential(
19 | Conv2d(in_channels=1024, out_channels=256, kernel_size=1),
20 | ReLU(),
21 | Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
22 | ReLU()
23 | ),
24 | Sequential(
25 | Conv2d(in_channels=512, out_channels=128, kernel_size=1),
26 | ReLU(),
27 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
28 | ReLU()
29 | ),
30 | Sequential(
31 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
32 | ReLU(),
33 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
34 | ReLU()
35 | ),
36 | Sequential(
37 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
38 | ReLU(),
39 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
40 | ReLU()
41 | )
42 | ])
43 |
44 | regression_headers = ModuleList([
45 | Conv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1),
46 | Conv2d(in_channels=1024, out_channels=6 * 4, kernel_size=3, padding=1),
47 | Conv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1),
48 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
49 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
50 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0?
51 | ])
52 |
53 | classification_headers = ModuleList([
54 | Conv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
55 | Conv2d(in_channels=1024, out_channels=6 * num_classes, kernel_size=3, padding=1),
56 | Conv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
57 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
58 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
59 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0?
60 | ])
61 |
62 | return SSD(num_classes, base_net, source_layer_indexes,
63 | extras, classification_headers, regression_headers, is_test=is_test, config=config)
64 |
65 |
66 | def create_mobilenetv1_ssd_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=None):
67 | predictor = Predictor(net, config.image_size, config.image_mean,
68 | config.image_std,
69 | nms_method=nms_method,
70 | iou_threshold=config.iou_threshold,
71 | candidate_size=candidate_size,
72 | sigma=sigma,
73 | device=device)
74 | return predictor
75 |
--------------------------------------------------------------------------------
/modeling/detection/mobilenetv1_ssd_lite.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU, BatchNorm2d
3 | from .nn.mobilenet import MobileNetV1
4 |
5 | from .ssd import SSD
6 | from .predictor import Predictor
7 | from .config import mobilenetv1_ssd_config as config
8 |
9 |
10 | def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0):
11 | """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d.
12 | """
13 | return Sequential(
14 | Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
15 | groups=in_channels, stride=stride, padding=padding),
16 | ReLU(),
17 | Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
18 | )
19 |
20 |
21 | def create_mobilenetv1_ssd_lite(num_classes, is_test=False):
22 | base_net = MobileNetV1(1001).model # disable dropout layer
23 |
24 | source_layer_indexes = [
25 | 12,
26 | 14,
27 | ]
28 | extras = ModuleList([
29 | Sequential(
30 | Conv2d(in_channels=1024, out_channels=256, kernel_size=1),
31 | ReLU(),
32 | SeperableConv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
33 | ),
34 | Sequential(
35 | Conv2d(in_channels=512, out_channels=128, kernel_size=1),
36 | ReLU(),
37 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
38 | ),
39 | Sequential(
40 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
41 | ReLU(),
42 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
43 | ),
44 | Sequential(
45 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
46 | ReLU(),
47 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
48 | )
49 | ])
50 |
51 | regression_headers = ModuleList([
52 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1),
53 | SeperableConv2d(in_channels=1024, out_channels=6 * 4, kernel_size=3, padding=1),
54 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1),
55 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
56 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
57 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=1),
58 | ])
59 |
60 | classification_headers = ModuleList([
61 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
62 | SeperableConv2d(in_channels=1024, out_channels=6 * num_classes, kernel_size=3, padding=1),
63 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
64 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
65 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
66 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=1),
67 | ])
68 |
69 | return SSD(num_classes, base_net, source_layer_indexes,
70 | extras, classification_headers, regression_headers, is_test=is_test, config=config)
71 |
72 |
73 | def create_mobilenetv1_ssd_lite_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=None):
74 | predictor = Predictor(net, config.image_size, config.image_mean,
75 | config.image_std,
76 | nms_method=nms_method,
77 | iou_threshold=config.iou_threshold,
78 | candidate_size=candidate_size,
79 | sigma=sigma,
80 | device=device)
81 | return predictor
82 |
--------------------------------------------------------------------------------
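A minimal usage sketch for the two factories above; the checkpoint and image paths are placeholders, and 21 classes means the 20 VOC classes plus BACKGROUND as in voc-model-labels.txt:

import cv2
from modeling.detection.mobilenetv1_ssd_lite import (
    create_mobilenetv1_ssd_lite, create_mobilenetv1_ssd_lite_predictor)

net = create_mobilenetv1_ssd_lite(21, is_test=True)
net.load("mb1-ssd-lite.pth")        # placeholder checkpoint path
predictor = create_mobilenetv1_ssd_lite_predictor(net, candidate_size=200)

image = cv2.imread("example.jpg")   # HWC BGR array, as Predictor.predict expects
boxes, labels, probs = predictor.predict(image, top_k=10, prob_threshold=0.4)
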
/modeling/detection/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/nn/__init__.py
--------------------------------------------------------------------------------
/modeling/detection/nn/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 |
4 | # copied from torchvision (https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py).
5 | # The forward function is modified for model pruning.
6 |
7 | __all__ = ['AlexNet', 'alexnet']
8 |
9 |
10 | model_urls = {
11 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
12 | }
13 |
14 |
15 | class AlexNet(nn.Module):
16 |
17 | def __init__(self, num_classes=1000):
18 | super(AlexNet, self).__init__()
19 | self.features = nn.Sequential(
20 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
21 | nn.ReLU(inplace=True),
22 | nn.MaxPool2d(kernel_size=3, stride=2),
23 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
24 | nn.ReLU(inplace=True),
25 | nn.MaxPool2d(kernel_size=3, stride=2),
26 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
27 | nn.ReLU(inplace=True),
28 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
31 | nn.ReLU(inplace=True),
32 | nn.MaxPool2d(kernel_size=3, stride=2),
33 | )
34 | self.classifier = nn.Sequential(
35 | nn.Dropout(),
36 | nn.Linear(256 * 6 * 6, 4096),
37 | nn.ReLU(inplace=True),
38 | nn.Dropout(),
39 | nn.Linear(4096, 4096),
40 | nn.ReLU(inplace=True),
41 | nn.Linear(4096, num_classes),
42 | )
43 |
44 | def forward(self, x):
45 | x = self.features(x)
46 | x = x.view(x.size(0), -1)
47 | x = self.classifier(x)
48 | return x
49 |
50 |
51 | def alexnet(pretrained=False, **kwargs):
52 | r"""AlexNet model architecture from the
53 | `"One weird trick..." `_ paper.
54 |
55 | Args:
56 | pretrained (bool): If True, returns a model pre-trained on ImageNet
57 | """
58 | model = AlexNet(**kwargs)
59 | if pretrained:
60 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
61 | return model
--------------------------------------------------------------------------------
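A quick shape check of the model above (no pretrained weights are downloaded; a 224x224 input yields the 256*6*6 feature map the classifier expects):

import torch
from modeling.detection.nn.alexnet import alexnet

model = alexnet(pretrained=False, num_classes=1000)
with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)   # torch.Size([1, 1000])
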
/modeling/detection/nn/mobilenet.py:
--------------------------------------------------------------------------------
1 | # borrowed from "https://github.com/marvis/pytorch-mobilenet"
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class MobileNetV1(nn.Module):
8 | def __init__(self, num_classes=1024):
9 | super(MobileNetV1, self).__init__()
10 |
11 | def conv_bn(inp, oup, stride):
12 | return nn.Sequential(
13 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
14 | nn.BatchNorm2d(oup),
15 | nn.ReLU(inplace=True)
16 | )
17 |
18 | def conv_dw(inp, oup, stride):
19 | return nn.Sequential(
20 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
21 | nn.BatchNorm2d(inp),
22 | nn.ReLU(inplace=True),
23 |
24 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
25 | nn.BatchNorm2d(oup),
26 | nn.ReLU(inplace=True),
27 | )
28 |
29 | self.model = nn.Sequential(
30 | conv_bn(3, 32, 2),
31 | conv_dw(32, 64, 1),
32 | conv_dw(64, 128, 2),
33 | conv_dw(128, 128, 1),
34 | conv_dw(128, 256, 2),
35 | conv_dw(256, 256, 1),
36 | conv_dw(256, 512, 2),
37 | conv_dw(512, 512, 1),
38 | conv_dw(512, 512, 1),
39 | conv_dw(512, 512, 1),
40 | conv_dw(512, 512, 1),
41 | conv_dw(512, 512, 1),
42 | conv_dw(512, 1024, 2),
43 | conv_dw(1024, 1024, 1),
44 | )
45 | self.fc = nn.Linear(1024, num_classes)
46 |
47 | def forward(self, x):
48 | x = self.model(x)
49 | x = F.avg_pool2d(x, 7)
50 | x = x.view(-1, 1024)
51 | x = self.fc(x)
52 | return x
--------------------------------------------------------------------------------
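The SSD factories tap this trunk after its 12th and 14th blocks (source_layer_indexes = [12, 14] in mobilenetv1_ssd_lite.py), which is why the first two detection heads expect 512 and 1024 input channels. A small sketch to confirm, assuming the usual 300x300 SSD input:

import torch
from modeling.detection.nn.mobilenet import MobileNetV1

trunk = MobileNetV1(1001).model
x = torch.rand(1, 3, 300, 300)
for i, layer in enumerate(trunk, start=1):
    x = layer(x)
    if i in (12, 14):
        print(i, tuple(x.shape))   # (1, 512, 19, 19) after layer 12, (1, 1024, 10, 10) after layer 14
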
/modeling/detection/nn/mobilenet_v2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 | # Modified from https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py.
5 | # In this version, Relu6 is replaced with Relu to make it ONNX compatible.
6 | # BatchNorm layers are optional to make it easy to do batch norm fusion.
7 |
8 |
9 | def conv_bn(inp, oup, stride, use_batch_norm=True, onnx_compatible=False):
10 | ReLU = nn.ReLU if onnx_compatible else nn.ReLU6
11 |
12 | if use_batch_norm:
13 | return nn.Sequential(
14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
15 | nn.BatchNorm2d(oup),
16 | ReLU(inplace=True)
17 | )
18 | else:
19 | return nn.Sequential(
20 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
21 | ReLU(inplace=True)
22 | )
23 |
24 |
25 | def conv_1x1_bn(inp, oup, use_batch_norm=True, onnx_compatible=False):
26 | ReLU = nn.ReLU if onnx_compatible else nn.ReLU6
27 | if use_batch_norm:
28 | return nn.Sequential(
29 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
30 | nn.BatchNorm2d(oup),
31 | ReLU(inplace=True)
32 | )
33 | else:
34 | return nn.Sequential(
35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
36 | ReLU(inplace=True)
37 | )
38 |
39 |
40 | class InvertedResidual(nn.Module):
41 | def __init__(self, inp, oup, stride, expand_ratio, use_batch_norm=True, onnx_compatible=False):
42 | super(InvertedResidual, self).__init__()
43 | ReLU = nn.ReLU if onnx_compatible else nn.ReLU6
44 |
45 | self.stride = stride
46 | assert stride in [1, 2]
47 |
48 | hidden_dim = round(inp * expand_ratio)
49 | self.use_res_connect = self.stride == 1 and inp == oup
50 |
51 | if expand_ratio == 1:
52 | if use_batch_norm:
53 | self.conv = nn.Sequential(
54 | # dw
55 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
56 | nn.BatchNorm2d(hidden_dim),
57 | ReLU(inplace=True),
58 | # pw-linear
59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
60 | nn.BatchNorm2d(oup),
61 | )
62 | else:
63 | self.conv = nn.Sequential(
64 | # dw
65 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
66 | ReLU(inplace=True),
67 | # pw-linear
68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
69 | )
70 | else:
71 | if use_batch_norm:
72 | self.conv = nn.Sequential(
73 | # pw
74 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
75 | nn.BatchNorm2d(hidden_dim),
76 | ReLU(inplace=True),
77 | # dw
78 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
79 | nn.BatchNorm2d(hidden_dim),
80 | ReLU(inplace=True),
81 | # pw-linear
82 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
83 | nn.BatchNorm2d(oup),
84 | )
85 | else:
86 | self.conv = nn.Sequential(
87 | # pw
88 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
89 | ReLU(inplace=True),
90 | # dw
91 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
92 | ReLU(inplace=True),
93 | # pw-linear
94 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
95 | )
96 |
97 | def forward(self, x):
98 | if self.use_res_connect:
99 | return x + self.conv(x)
100 | else:
101 | return self.conv(x)
102 |
103 |
104 | class MobileNetV2(nn.Module):
105 | def __init__(self, n_class=1000, input_size=224, width_mult=1., dropout_ratio=0.2,
106 | use_batch_norm=True, onnx_compatible=False):
107 | super(MobileNetV2, self).__init__()
108 | block = InvertedResidual
109 | input_channel = 32
110 | last_channel = 1280
111 | interverted_residual_setting = [
112 | # t, c, n, s
113 | [1, 16, 1, 1],
114 | [6, 24, 2, 2],
115 | [6, 32, 3, 2],
116 | [6, 64, 4, 2],
117 | [6, 96, 3, 1],
118 | [6, 160, 3, 2],
119 | [6, 320, 1, 1],
120 | ]
121 |
122 | # building first layer
123 | assert input_size % 32 == 0
124 | input_channel = int(input_channel * width_mult)
125 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
126 | self.features = [conv_bn(3, input_channel, 2, onnx_compatible=onnx_compatible)]
127 | # building inverted residual blocks
128 | for t, c, n, s in interverted_residual_setting:
129 | output_channel = int(c * width_mult)
130 | for i in range(n):
131 | if i == 0:
132 | self.features.append(block(input_channel, output_channel, s,
133 | expand_ratio=t, use_batch_norm=use_batch_norm,
134 | onnx_compatible=onnx_compatible))
135 | else:
136 | self.features.append(block(input_channel, output_channel, 1,
137 | expand_ratio=t, use_batch_norm=use_batch_norm,
138 | onnx_compatible=onnx_compatible))
139 | input_channel = output_channel
140 | # building last several layers
141 | self.features.append(conv_1x1_bn(input_channel, self.last_channel,
142 | use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible))
143 | # make it nn.Sequential
144 | self.features = nn.Sequential(*self.features)
145 |
146 | # building classifier
147 | self.classifier = nn.Sequential(
148 | nn.Dropout(dropout_ratio),
149 | nn.Linear(self.last_channel, n_class),
150 | )
151 |
152 | self._initialize_weights()
153 |
154 | def forward(self, x):
155 | x = self.features(x)
156 | x = x.mean(3).mean(2)
157 | x = self.classifier(x)
158 | return x
159 |
160 | def _initialize_weights(self):
161 | for m in self.modules():
162 | if isinstance(m, nn.Conv2d):
163 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
164 | m.weight.data.normal_(0, math.sqrt(2. / n))
165 | if m.bias is not None:
166 | m.bias.data.zero_()
167 | elif isinstance(m, nn.BatchNorm2d):
168 | m.weight.data.fill_(1)
169 | m.bias.data.zero_()
170 | elif isinstance(m, nn.Linear):
171 | n = m.weight.size(1)
172 | m.weight.data.normal_(0, 0.01)
173 | m.bias.data.zero_()
174 |
--------------------------------------------------------------------------------
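A sketch of the two switches this variant adds: ReLU in place of ReLU6 for ONNX export, and optional BatchNorm so that weights with the batch norms already folded in can be loaded into the bare convolutions:

import torch
from modeling.detection.nn.mobilenet_v2 import MobileNetV2

net = MobileNetV2(n_class=1000, use_batch_norm=False, onnx_compatible=True)
net.eval()
with torch.no_grad():
    logits = net(torch.rand(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 1000])
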
/modeling/detection/nn/multibox_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
6 | from utils.detection import box_utils
7 |
8 |
9 | class MultiboxLoss(nn.Module):
10 | def __init__(self, priors, iou_threshold, neg_pos_ratio,
11 | center_variance, size_variance, device):
12 | """Implement SSD Multibox Loss.
13 |
14 | Basically, Multibox loss combines classification loss
15 | and Smooth L1 regression loss.
16 | """
17 | super(MultiboxLoss, self).__init__()
18 | self.iou_threshold = iou_threshold
19 | self.neg_pos_ratio = neg_pos_ratio
20 | self.center_variance = center_variance
21 | self.size_variance = size_variance
22 | self.priors = priors
23 | self.priors = self.priors.to(device)
24 |
25 | def forward(self, confidence, predicted_locations, labels, gt_locations):
26 | """Compute classification loss and smooth l1 loss.
27 |
28 | Args:
29 | confidence (batch_size, num_priors, num_classes): class predictions.
30 | predicted_locations (batch_size, num_priors, 4): predicted locations.
31 | labels (batch_size, num_priors): real labels of all the priors.
32 | gt_locations (batch_size, num_priors, 4): real boxes corresponding to all the priors.
33 | """
34 | num_classes = confidence.size(2)
35 | with torch.no_grad():
36 | # per-prior background loss -log(p_background), used to rank priors for hard negative mining
37 | loss = -F.log_softmax(confidence, dim=2)[:, :, 0]
38 | mask = box_utils.hard_negative_mining(loss, labels, self.neg_pos_ratio)
39 |
40 | confidence = confidence[mask, :]
41 | classification_loss = F.cross_entropy(confidence.reshape(-1, num_classes), labels[mask], size_average=False)
42 | pos_mask = labels > 0
43 | predicted_locations = predicted_locations[pos_mask, :].reshape(-1, 4)
44 | gt_locations = gt_locations[pos_mask, :].reshape(-1, 4)
45 | smooth_l1_loss = F.smooth_l1_loss(predicted_locations, gt_locations, size_average=False)
46 | num_pos = gt_locations.size(0)
47 | return smooth_l1_loss/num_pos, classification_loss/num_pos
48 |
--------------------------------------------------------------------------------
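A smoke-test sketch with random tensors. The 0.5 IoU threshold, 3:1 negative/positive ratio and 0.1/0.2 variances are the values conventionally paired with this loss rather than constants from this file, the priors are only stored (forward() consumes targets already matched by MatchPrior in ssd.py), and utils.detection.box_utils is assumed to provide the hard_negative_mining helper the forward pass calls:

import torch
from modeling.detection.nn.multibox_loss import MultiboxLoss

num_priors, num_classes = 3000, 21
priors = torch.rand(num_priors, 4)   # stored by the constructor, not used in forward()
criterion = MultiboxLoss(priors, iou_threshold=0.5, neg_pos_ratio=3,
                         center_variance=0.1, size_variance=0.2,
                         device=torch.device("cpu"))

confidence = torch.randn(2, num_priors, num_classes)
predicted_locations = torch.randn(2, num_priors, 4)
labels = torch.randint(0, num_classes, (2, num_priors))
gt_locations = torch.randn(2, num_priors, 4)

regression_loss, classification_loss = criterion(
    confidence, predicted_locations, labels, gt_locations)
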
/modeling/detection/nn/scaled_l2_norm.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | class ScaledL2Norm(nn.Module):
7 | def __init__(self, in_channels, initial_scale):
8 | super(ScaledL2Norm, self).__init__()
9 | self.in_channels = in_channels
10 | self.scale = nn.Parameter(torch.Tensor(in_channels))
11 | self.initial_scale = initial_scale
12 | self.reset_parameters()
13 |
14 | def forward(self, x):
15 | return (F.normalize(x, p=2, dim=1)
16 | * self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3))
17 |
18 | def reset_parameters(self):
19 | self.scale.data.fill_(self.initial_scale)
--------------------------------------------------------------------------------
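A one-line check of the layer above; 20.0 is the scale conventionally used on SSD's conv4_3-style feature map, not a value defined in this file:

import torch
from modeling.detection.nn.scaled_l2_norm import ScaledL2Norm

norm = ScaledL2Norm(in_channels=512, initial_scale=20.0)
y = norm(torch.rand(1, 512, 38, 38))   # L2-normalize each position across channels, then rescale per channel
print(y.shape)   # torch.Size([1, 512, 38, 38])
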
/modeling/detection/nn/squeezenet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | import torch.utils.model_zoo as model_zoo
6 |
7 |
8 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
9 |
10 |
11 | model_urls = {
12 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
13 | 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
14 | }
15 |
16 |
17 | class Fire(nn.Module):
18 |
19 | def __init__(self, inplanes, squeeze_planes,
20 | expand1x1_planes, expand3x3_planes):
21 | super(Fire, self).__init__()
22 | self.inplanes = inplanes
23 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
24 | self.squeeze_activation = nn.ReLU(inplace=True)
25 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
26 | kernel_size=1)
27 | self.expand1x1_activation = nn.ReLU(inplace=True)
28 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
29 | kernel_size=3, padding=1)
30 | self.expand3x3_activation = nn.ReLU(inplace=True)
31 |
32 | def forward(self, x):
33 | x = self.squeeze_activation(self.squeeze(x))
34 | return torch.cat([
35 | self.expand1x1_activation(self.expand1x1(x)),
36 | self.expand3x3_activation(self.expand3x3(x))
37 | ], 1)
38 |
39 |
40 | class SqueezeNet(nn.Module):
41 |
42 | def __init__(self, version=1.0, num_classes=1000):
43 | super(SqueezeNet, self).__init__()
44 | if version not in [1.0, 1.1]:
45 | raise ValueError("Unsupported SqueezeNet version {version}:"
46 | "1.0 or 1.1 expected".format(version=version))
47 | self.num_classes = num_classes
48 | if version == 1.0:
49 | self.features = nn.Sequential(
50 | nn.Conv2d(3, 96, kernel_size=7, stride=2),
51 | nn.ReLU(inplace=True),
52 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
53 | Fire(96, 16, 64, 64),
54 | Fire(128, 16, 64, 64),
55 | Fire(128, 32, 128, 128),
56 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
57 | Fire(256, 32, 128, 128),
58 | Fire(256, 48, 192, 192),
59 | Fire(384, 48, 192, 192),
60 | Fire(384, 64, 256, 256),
61 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
62 | Fire(512, 64, 256, 256),
63 | )
64 | else:
65 | self.features = nn.Sequential(
66 | nn.Conv2d(3, 64, kernel_size=3, stride=2),
67 | nn.ReLU(inplace=True),
68 | nn.MaxPool2d(kernel_size=3, stride=2),
69 | Fire(64, 16, 64, 64),
70 | Fire(128, 16, 64, 64),
71 | nn.MaxPool2d(kernel_size=3, stride=2),
72 | Fire(128, 32, 128, 128),
73 | Fire(256, 32, 128, 128),
74 | nn.MaxPool2d(kernel_size=3, stride=2),
75 | Fire(256, 48, 192, 192),
76 | Fire(384, 48, 192, 192),
77 | Fire(384, 64, 256, 256),
78 | Fire(512, 64, 256, 256),
79 | )
80 | # Final convolution is initialized differently from the rest
81 | final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
82 | self.classifier = nn.Sequential(
83 | nn.Dropout(p=0.5),
84 | final_conv,
85 | nn.ReLU(inplace=True),
86 | nn.AvgPool2d(13, stride=1)
87 | )
88 |
89 | for m in self.modules():
90 | if isinstance(m, nn.Conv2d):
91 | if m is final_conv:
92 | init.normal_(m.weight, mean=0.0, std=0.01)
93 | else:
94 | init.kaiming_uniform_(m.weight)
95 | if m.bias is not None:
96 | init.constant_(m.bias, 0)
97 |
98 | def forward(self, x):
99 | x = self.features(x)
100 | x = self.classifier(x)
101 | return x.view(x.size(0), self.num_classes)
102 |
103 |
104 | def squeezenet1_0(pretrained=False, **kwargs):
105 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level
106 | accuracy with 50x fewer parameters and <0.5MB model size"
107 | <https://arxiv.org/abs/1602.07360>`_ paper.
108 |
109 | Args:
110 | pretrained (bool): If True, returns a model pre-trained on ImageNet
111 | """
112 | model = SqueezeNet(version=1.0, **kwargs)
113 | if pretrained:
114 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0']))
115 | return model
116 |
117 |
118 | def squeezenet1_1(pretrained=False, **kwargs):
119 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo
120 | <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.
121 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters
122 | than SqueezeNet 1.0, without sacrificing accuracy.
123 |
124 | Args:
125 | pretrained (bool): If True, returns a model pre-trained on ImageNet
126 | """
127 | model = SqueezeNet(version=1.1, **kwargs)
128 | if pretrained:
129 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1']))
130 | return model
131 |
--------------------------------------------------------------------------------
/modeling/detection/nn/vgg.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | # borrowed from https://github.com/amdegroot/ssd.pytorch/blob/master/ssd.py
5 | def vgg(cfg, batch_norm=False):
6 | layers = []
7 | in_channels = 3
8 | for v in cfg:
9 | if v == 'M':
10 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
11 | elif v == 'C':
12 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
13 | else:
14 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
15 | if batch_norm:
16 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
17 | else:
18 | layers += [conv2d, nn.ReLU(inplace=True)]
19 | in_channels = v
20 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
21 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)
22 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
23 | layers += [pool5, conv6,
24 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
25 | return layers
--------------------------------------------------------------------------------
/modeling/detection/predictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from utils.detection import box_utils
4 | from .data_preprocessing import PredictionTransform
5 | from utils.detection.misc import Timer
6 |
7 |
8 | class Predictor:
9 | def __init__(self, net, size, mean=0.0, std=1.0, nms_method=None,
10 | iou_threshold=0.45, filter_threshold=0.01, candidate_size=200, sigma=0.5, device=None):
11 | self.net = net
12 | self.transform = PredictionTransform(size, mean, std)
13 | self.iou_threshold = iou_threshold
14 | self.filter_threshold = filter_threshold
15 | self.candidate_size = candidate_size
16 | self.nms_method = nms_method
17 |
18 | self.sigma = sigma
19 | if device:
20 | self.device = device
21 | else:
22 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23 |
24 | self.net.to(self.device)
25 | self.net.eval()
26 |
27 | # self.timer = Timer()
28 |
29 | def predict(self, image, top_k=-1, prob_threshold=None):
30 | cpu_device = torch.device("cpu")
31 | height, width, _ = image.shape
32 | image = self.transform(image)
33 | images = image.unsqueeze(0)
34 | images = images.to(self.device)
35 | with torch.no_grad():
36 | # self.timer.start()
37 | scores, boxes = self.net.forward(images)
38 | boxes = box_utils.convert_locations_to_boxes(*boxes)
39 | boxes = box_utils.center_form_to_corner_form(boxes)
40 | # scores, boxes = self.net.forward(images)
41 | # print("Inference time: ", self.timer.end())
42 | boxes = boxes[0]
43 | scores = scores[0]
44 | if not prob_threshold:
45 | prob_threshold = self.filter_threshold
46 | # this version of nms is slower on GPU, so we move data to CPU.
47 | boxes = boxes.to(cpu_device)
48 | scores = scores.to(cpu_device)
49 | picked_box_probs = []
50 | picked_labels = []
51 | for class_index in range(1, scores.size(1)):
52 | probs = scores[:, class_index]
53 | mask = probs > prob_threshold
54 | probs = probs[mask]
55 | if probs.size(0) == 0:
56 | continue
57 | subset_boxes = boxes[mask, :]
58 | box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1)
59 | box_probs = box_utils.nms(box_probs, self.nms_method,
60 | score_threshold=prob_threshold,
61 | iou_threshold=self.iou_threshold,
62 | sigma=self.sigma,
63 | top_k=top_k,
64 | candidate_size=self.candidate_size)
65 | picked_box_probs.append(box_probs)
66 | picked_labels.extend([class_index] * box_probs.size(0))
67 | if not picked_box_probs:
68 | return torch.tensor([]), torch.tensor([]), torch.tensor([])
69 | picked_box_probs = torch.cat(picked_box_probs)
70 | picked_box_probs[:, 0] *= width
71 | picked_box_probs[:, 1] *= height
72 | picked_box_probs[:, 2] *= width
73 | picked_box_probs[:, 3] *= height
74 | return picked_box_probs[:, :4], torch.tensor(picked_labels), picked_box_probs[:, 4]
--------------------------------------------------------------------------------
/modeling/detection/squeezenet_ssd_lite.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU
3 | from .nn.squeezenet import squeezenet1_1
4 |
5 | from .ssd import SSD
6 | from .predictor import Predictor
7 | from .config import squeezenet_ssd_config as config
8 |
9 |
10 | def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0):
11 | """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d.
12 | """
13 | return Sequential(
14 | Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
15 | groups=in_channels, stride=stride, padding=padding),
16 | ReLU(),
17 | Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1),
18 | )
19 |
20 |
21 | def create_squeezenet_ssd_lite(num_classes, is_test=False):
22 | base_net = squeezenet1_1(False).features # disable dropout layer
23 |
24 | source_layer_indexes = [
25 | 12
26 | ]
27 | extras = ModuleList([
28 | Sequential(
29 | Conv2d(in_channels=512, out_channels=256, kernel_size=1),
30 | ReLU(),
31 | SeperableConv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=2),
32 | ),
33 | Sequential(
34 | Conv2d(in_channels=512, out_channels=256, kernel_size=1),
35 | ReLU(),
36 | SeperableConv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
37 | ),
38 | Sequential(
39 | Conv2d(in_channels=512, out_channels=128, kernel_size=1),
40 | ReLU(),
41 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
42 | ),
43 | Sequential(
44 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
45 | ReLU(),
46 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
47 | ),
48 | Sequential(
49 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
50 | ReLU(),
51 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)
52 | )
53 | ])
54 |
55 | regression_headers = ModuleList([
56 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1),
57 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1),
58 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1),
59 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
60 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
61 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=1),
62 | ])
63 |
64 | classification_headers = ModuleList([
65 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
66 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
67 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
68 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
69 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
70 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=1),
71 | ])
72 |
73 | return SSD(num_classes, base_net, source_layer_indexes,
74 | extras, classification_headers, regression_headers, is_test=is_test, config=config)
75 |
76 |
77 | def create_squeezenet_ssd_lite_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=torch.device('cpu')):
78 | predictor = Predictor(net, config.image_size, config.image_mean,
79 | config.image_std,
80 | nms_method=nms_method,
81 | iou_threshold=config.iou_threshold,
82 | candidate_size=candidate_size,
83 | sigma=sigma,
84 | device=device)
85 | return predictor
--------------------------------------------------------------------------------
/modeling/detection/ssd.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 | from typing import List, Tuple
5 | import torch.nn.functional as F
6 |
7 | from utils.detection import box_utils
8 | from collections import namedtuple
9 | GraphPath = namedtuple("GraphPath", ['s0', 'name', 's1']) #
10 |
11 |
12 | class SSD(nn.Module):
13 | def __init__(self, num_classes: int, base_net: nn.ModuleList, source_layer_indexes: List[int],
14 | extras: nn.ModuleList, classification_headers: nn.ModuleList,
15 | regression_headers: nn.ModuleList, is_test=False, config=None, device=None):
16 | """Compose a SSD model using the given components.
17 | """
18 | super(SSD, self).__init__()
19 |
20 | self.num_classes = num_classes
21 | self.base_net = base_net
22 | self.source_layer_indexes = source_layer_indexes
23 | self.extras = extras
24 | self.classification_headers = classification_headers
25 | self.regression_headers = regression_headers
26 | self.is_test = is_test
27 | self.config = config
28 |
29 | # register layers in source_layer_indexes by adding them to a module list
30 | self.source_layer_add_ons = nn.ModuleList([t[1] for t in source_layer_indexes
31 | if isinstance(t, tuple) and not isinstance(t, GraphPath)])
32 | if device:
33 | self.device = device
34 | else:
35 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
36 | if is_test:
37 | self.config = config
38 | self.priors = config.priors.to(self.device)
39 |
40 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
41 | confidences = []
42 | locations = []
43 | start_layer_index = 0
44 | header_index = 0
45 | for end_layer_index in self.source_layer_indexes:
46 | if isinstance(end_layer_index, GraphPath):
47 | path = end_layer_index
48 | end_layer_index = end_layer_index.s0
49 | added_layer = None
50 | elif isinstance(end_layer_index, tuple):
51 | added_layer = end_layer_index[1]
52 | end_layer_index = end_layer_index[0]
53 | path = None
54 | else:
55 | added_layer = None
56 | path = None
57 | for layer in self.base_net[start_layer_index: end_layer_index]:
58 | x = layer(x)
59 | if added_layer:
60 | y = added_layer(x)
61 | else:
62 | y = x
63 | if path:
64 | sub = getattr(self.base_net[end_layer_index], path.name)
65 | for layer in sub[:path.s1]:
66 | x = layer(x)
67 | y = x
68 | for layer in sub[path.s1:]:
69 | x = layer(x)
70 | end_layer_index += 1
71 | start_layer_index = end_layer_index
72 | confidence, location = self.compute_header(header_index, y)
73 | header_index += 1
74 | confidences.append(confidence)
75 | locations.append(location)
76 |
77 | for layer in self.base_net[end_layer_index:]:
78 | x = layer(x)
79 |
80 | for layer in self.extras:
81 | x = layer(x)
82 | confidence, location = self.compute_header(header_index, x)
83 | header_index += 1
84 | confidences.append(confidence)
85 | locations.append(location)
86 |
87 | confidences = torch.cat(confidences, 1)
88 | locations = torch.cat(locations, 1)
89 |
90 | if self.is_test:
91 | confidences = F.softmax(confidences, dim=2)
92 | return confidences, (locations, self.priors, self.config.center_variance, self.config.size_variance)
93 | # boxes = box_utils.convert_locations_to_boxes(
94 | # locations, self.priors, self.config.center_variance, self.config.size_variance
95 | # )
96 | # boxes = box_utils.center_form_to_corner_form(boxes)
97 | # return confidences, boxes
98 | else:
99 | return confidences, locations
100 |
101 | def compute_header(self, i, x):
102 | confidence = self.classification_headers[i](x)
103 | confidence = confidence.permute(0, 2, 3, 1).contiguous()
104 | confidence = confidence.view(confidence.size(0), -1, self.num_classes)
105 |
106 | location = self.regression_headers[i](x)
107 | location = location.permute(0, 2, 3, 1).contiguous()
108 | location = location.view(location.size(0), -1, 4)
109 |
110 | return confidence, location
111 |
112 | def init_from_base_net(self, model):
113 | self.base_net.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage), strict=True)
114 | self.source_layer_add_ons.apply(_xavier_init_)
115 | self.extras.apply(_xavier_init_)
116 | self.classification_headers.apply(_xavier_init_)
117 | self.regression_headers.apply(_xavier_init_)
118 |
119 | def init_from_pretrained_ssd(self, model):
120 | state_dict = torch.load(model, map_location=lambda storage, loc: storage)
121 | state_dict = {k: v for k, v in state_dict.items() if not (k.startswith("classification_headers") or k.startswith("regression_headers"))}
122 | model_dict = self.state_dict()
123 | model_dict.update(state_dict)
124 | self.load_state_dict(model_dict)
125 | self.classification_headers.apply(_xavier_init_)
126 | self.regression_headers.apply(_xavier_init_)
127 |
128 | def init(self):
129 | self.base_net.apply(_xavier_init_)
130 | self.source_layer_add_ons.apply(_xavier_init_)
131 | self.extras.apply(_xavier_init_)
132 | self.classification_headers.apply(_xavier_init_)
133 | self.regression_headers.apply(_xavier_init_)
134 |
135 | def load(self, model):
136 | self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
137 |
138 | def save(self, model_path):
139 | torch.save(self.state_dict(), model_path)
140 |
141 |
142 | class MatchPrior(object):
143 | def __init__(self, center_form_priors, center_variance, size_variance, iou_threshold):
144 | self.center_form_priors = center_form_priors
145 | self.corner_form_priors = box_utils.center_form_to_corner_form(center_form_priors)
146 | self.center_variance = center_variance
147 | self.size_variance = size_variance
148 | self.iou_threshold = iou_threshold
149 |
150 | def __call__(self, gt_boxes, gt_labels):
151 | if type(gt_boxes) is np.ndarray:
152 | gt_boxes = torch.from_numpy(gt_boxes)
153 | if type(gt_labels) is np.ndarray:
154 | gt_labels = torch.from_numpy(gt_labels)
155 | boxes, labels = box_utils.assign_priors(gt_boxes, gt_labels,
156 | self.corner_form_priors, self.iou_threshold)
157 | boxes = box_utils.corner_form_to_center_form(boxes)
158 | locations = box_utils.convert_boxes_to_locations(boxes, self.center_form_priors, self.center_variance, self.size_variance)
159 | return locations, labels
160 |
161 |
162 | def _xavier_init_(m: nn.Module):
163 | if isinstance(m, nn.Conv2d):
164 | nn.init.xavier_uniform_(m.weight)
165 |
--------------------------------------------------------------------------------
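A sketch of how MatchPrior converts normalized corner-form ground truth into per-prior training targets; 0.5 is the conventional matching threshold, class 15 is "person" in voc-model-labels.txt, and utils.detection.box_utils is assumed to behave as in the upstream pytorch-ssd code:

import numpy as np
from modeling.detection.ssd import MatchPrior
from modeling.detection.config import mobilenetv1_ssd_config as config

matcher = MatchPrior(config.priors, config.center_variance,
                     config.size_variance, iou_threshold=0.5)

gt_boxes = np.array([[0.10, 0.20, 0.60, 0.80]], dtype=np.float32)  # corner form, normalized to [0, 1]
gt_labels = np.array([15], dtype=np.int64)                         # 15 = person
locations, labels = matcher(gt_boxes, gt_labels)
print(locations.shape, labels.shape)   # (num_priors, 4) and (num_priors,)
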
/modeling/detection/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/transforms/__init__.py
--------------------------------------------------------------------------------
/modeling/detection/vgg_ssd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU, BatchNorm2d
3 | from .nn.vgg import vgg
4 |
5 | from .ssd import SSD
6 | from .predictor import Predictor
7 | from .config import vgg_ssd_config as config
8 |
9 |
10 | def create_vgg_ssd(num_classes, is_test=False):
11 | vgg_config = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',
12 | 512, 512, 512]
13 | base_net = ModuleList(vgg(vgg_config))
14 |
15 | source_layer_indexes = [
16 | (23, BatchNorm2d(512)),
17 | len(base_net),
18 | ]
19 | extras = ModuleList([
20 | Sequential(
21 | Conv2d(in_channels=1024, out_channels=256, kernel_size=1),
22 | ReLU(),
23 | Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
24 | ReLU()
25 | ),
26 | Sequential(
27 | Conv2d(in_channels=512, out_channels=128, kernel_size=1),
28 | ReLU(),
29 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
30 | ReLU()
31 | ),
32 | Sequential(
33 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
34 | ReLU(),
35 | Conv2d(in_channels=128, out_channels=256, kernel_size=3),
36 | ReLU()
37 | ),
38 | Sequential(
39 | Conv2d(in_channels=256, out_channels=128, kernel_size=1),
40 | ReLU(),
41 | Conv2d(in_channels=128, out_channels=256, kernel_size=3),
42 | ReLU()
43 | )
44 | ])
45 |
46 | regression_headers = ModuleList([
47 | Conv2d(in_channels=512, out_channels=4 * 4, kernel_size=3, padding=1),
48 | Conv2d(in_channels=1024, out_channels=6 * 4, kernel_size=3, padding=1),
49 | Conv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1),
50 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1),
51 | Conv2d(in_channels=256, out_channels=4 * 4, kernel_size=3, padding=1),
52 | Conv2d(in_channels=256, out_channels=4 * 4, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0?
53 | ])
54 |
55 | classification_headers = ModuleList([
56 | Conv2d(in_channels=512, out_channels=4 * num_classes, kernel_size=3, padding=1),
57 | Conv2d(in_channels=1024, out_channels=6 * num_classes, kernel_size=3, padding=1),
58 | Conv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1),
59 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1),
60 | Conv2d(in_channels=256, out_channels=4 * num_classes, kernel_size=3, padding=1),
61 | Conv2d(in_channels=256, out_channels=4 * num_classes, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0?
62 | ])
63 |
64 | return SSD(num_classes, base_net, source_layer_indexes,
65 | extras, classification_headers, regression_headers, is_test=is_test, config=config)
66 |
67 |
68 | def create_vgg_ssd_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=None):
69 | predictor = Predictor(net, config.image_size, config.image_mean,
70 | nms_method=nms_method,
71 | iou_threshold=config.iou_threshold,
72 | candidate_size=candidate_size,
73 | sigma=sigma,
74 | device=device)
75 | return predictor
76 |
--------------------------------------------------------------------------------
/modeling/detection/voc-model-labels.txt:
--------------------------------------------------------------------------------
1 | BACKGROUND
2 | aeroplane
3 | bicycle
4 | bird
5 | boat
6 | bottle
7 | bus
8 | car
9 | cat
10 | chair
11 | cow
12 | diningtable
13 | dog
14 | horse
15 | motorbike
16 | person
17 | pottedplant
18 | sheep
19 | sofa
20 | train
21 | tvmonitor
--------------------------------------------------------------------------------
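The label list is read as plain text, one class name per line, with index 0 reserved for BACKGROUND; for example:

class_names = [name.strip() for name in open("modeling/detection/voc-model-labels.txt")]
num_classes = len(class_names)   # 21
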
/modeling/ncnn/model_quant_relu_equal.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/ncnn/model_quant_relu_equal.bin
--------------------------------------------------------------------------------
/modeling/ncnn/model_quant_relu_equal.param:
--------------------------------------------------------------------------------
1 | 7767517
2 | 112 122
3 | Input 0 0 1 0 0=224 1=224 2=3
4 | Convolution 619 1 1 0 619 0=32 1=3 3=2 4=1 5=1 6=864 8=2
5 | ReLU 621 1 1 619 621
6 | ConvolutionDepthWise 622 1 1 621 622 0=32 1=3 4=1 5=1 6=288 7=32 8=1
7 | ReLU 624 1 1 622 624
8 | Convolution 625 1 1 624 625 0=16 1=1 5=1 6=512 8=2
9 | Convolution 627 1 1 625 627 0=96 1=1 5=1 6=1536 8=2
10 | ReLU 629 1 1 627 629
11 | ConvolutionDepthWise 630 1 1 629 630 0=96 1=3 3=2 4=1 5=1 6=864 7=96 8=1
12 | ReLU 632 1 1 630 632
13 | Convolution 633 1 1 632 633 0=24 1=1 5=1 6=2304 8=2
14 | Split splitncnn_0 1 2 633 633_splitncnn_0 633_splitncnn_1
15 | Convolution 635 1 1 633_splitncnn_1 635 0=144 1=1 5=1 6=3456 8=2
16 | ReLU 637 1 1 635 637
17 | ConvolutionDepthWise 638 1 1 637 638 0=144 1=3 4=1 5=1 6=1296 7=144 8=1
18 | ReLU 640 1 1 638 640
19 | Convolution 641 1 1 640 641 0=24 1=1 5=1 6=3456 8=2
20 | BinaryOp 643 2 1 633_splitncnn_0 641 643
21 | Convolution 644 1 1 643 644 0=144 1=1 5=1 6=3456 8=2
22 | ReLU 646 1 1 644 646
23 | ConvolutionDepthWise 647 1 1 646 647 0=144 1=3 3=2 4=1 5=1 6=1296 7=144 8=1
24 | ReLU 649 1 1 647 649
25 | Convolution 650 1 1 649 650 0=32 1=1 5=1 6=4608 8=2
26 | Split splitncnn_1 1 2 650 650_splitncnn_0 650_splitncnn_1
27 | Convolution 652 1 1 650_splitncnn_1 652 0=192 1=1 5=1 6=6144 8=2
28 | ReLU 654 1 1 652 654
29 | ConvolutionDepthWise 655 1 1 654 655 0=192 1=3 4=1 5=1 6=1728 7=192 8=1
30 | ReLU 657 1 1 655 657
31 | Convolution 658 1 1 657 658 0=32 1=1 5=1 6=6144 8=2
32 | BinaryOp 660 2 1 650_splitncnn_0 658 660
33 | Split splitncnn_2 1 2 660 660_splitncnn_0 660_splitncnn_1
34 | Convolution 661 1 1 660_splitncnn_1 661 0=192 1=1 5=1 6=6144 8=2
35 | ReLU 663 1 1 661 663
36 | ConvolutionDepthWise 664 1 1 663 664 0=192 1=3 4=1 5=1 6=1728 7=192 8=1
37 | ReLU 666 1 1 664 666
38 | Convolution 667 1 1 666 667 0=32 1=1 5=1 6=6144 8=2
39 | BinaryOp 669 2 1 660_splitncnn_0 667 669
40 | Convolution 670 1 1 669 670 0=192 1=1 5=1 6=6144 8=2
41 | ReLU 672 1 1 670 672
42 | ConvolutionDepthWise 673 1 1 672 673 0=192 1=3 3=2 4=1 5=1 6=1728 7=192 8=1
43 | ReLU 675 1 1 673 675
44 | Convolution 676 1 1 675 676 0=64 1=1 5=1 6=12288 8=2
45 | Split splitncnn_3 1 2 676 676_splitncnn_0 676_splitncnn_1
46 | Convolution 678 1 1 676_splitncnn_1 678 0=384 1=1 5=1 6=24576 8=2
47 | ReLU 680 1 1 678 680
48 | ConvolutionDepthWise 681 1 1 680 681 0=384 1=3 4=1 5=1 6=3456 7=384 8=1
49 | ReLU 683 1 1 681 683
50 | Convolution 684 1 1 683 684 0=64 1=1 5=1 6=24576 8=2
51 | BinaryOp 686 2 1 676_splitncnn_0 684 686
52 | Split splitncnn_4 1 2 686 686_splitncnn_0 686_splitncnn_1
53 | Convolution 687 1 1 686_splitncnn_1 687 0=384 1=1 5=1 6=24576 8=2
54 | ReLU 689 1 1 687 689
55 | ConvolutionDepthWise 690 1 1 689 690 0=384 1=3 4=1 5=1 6=3456 7=384 8=1
56 | ReLU 692 1 1 690 692
57 | Convolution 693 1 1 692 693 0=64 1=1 5=1 6=24576 8=2
58 | BinaryOp 695 2 1 686_splitncnn_0 693 695
59 | Split splitncnn_5 1 2 695 695_splitncnn_0 695_splitncnn_1
60 | Convolution 696 1 1 695_splitncnn_1 696 0=384 1=1 5=1 6=24576 8=2
61 | ReLU 698 1 1 696 698
62 | ConvolutionDepthWise 699 1 1 698 699 0=384 1=3 4=1 5=1 6=3456 7=384 8=1
63 | ReLU 701 1 1 699 701
64 | Convolution 702 1 1 701 702 0=64 1=1 5=1 6=24576 8=2
65 | BinaryOp 704 2 1 695_splitncnn_0 702 704
66 | Convolution 705 1 1 704 705 0=384 1=1 5=1 6=24576 8=2
67 | ReLU 707 1 1 705 707
68 | ConvolutionDepthWise 708 1 1 707 708 0=384 1=3 4=1 5=1 6=3456 7=384 8=1
69 | ReLU 710 1 1 708 710
70 | Convolution 711 1 1 710 711 0=96 1=1 5=1 6=36864 8=2
71 | Split splitncnn_6 1 2 711 711_splitncnn_0 711_splitncnn_1
72 | Convolution 713 1 1 711_splitncnn_1 713 0=576 1=1 5=1 6=55296 8=2
73 | ReLU 715 1 1 713 715
74 | ConvolutionDepthWise 716 1 1 715 716 0=576 1=3 4=1 5=1 6=5184 7=576 8=1
75 | ReLU 718 1 1 716 718
76 | Convolution 719 1 1 718 719 0=96 1=1 5=1 6=55296 8=2
77 | BinaryOp 721 2 1 711_splitncnn_0 719 721
78 | Split splitncnn_7 1 2 721 721_splitncnn_0 721_splitncnn_1
79 | Convolution 722 1 1 721_splitncnn_1 722 0=576 1=1 5=1 6=55296 8=2
80 | ReLU 724 1 1 722 724
81 | ConvolutionDepthWise 725 1 1 724 725 0=576 1=3 4=1 5=1 6=5184 7=576 8=1
82 | ReLU 727 1 1 725 727
83 | Convolution 728 1 1 727 728 0=96 1=1 5=1 6=55296 8=2
84 | BinaryOp 730 2 1 721_splitncnn_0 728 730
85 | Convolution 731 1 1 730 731 0=576 1=1 5=1 6=55296 8=2
86 | ReLU 733 1 1 731 733
87 | ConvolutionDepthWise 734 1 1 733 734 0=576 1=3 3=2 4=1 5=1 6=5184 7=576 8=1
88 | ReLU 736 1 1 734 736
89 | Convolution 737 1 1 736 737 0=160 1=1 5=1 6=92160 8=2
90 | Split splitncnn_8 1 2 737 737_splitncnn_0 737_splitncnn_1
91 | Convolution 739 1 1 737_splitncnn_1 739 0=960 1=1 5=1 6=153600 8=2
92 | ReLU 741 1 1 739 741
93 | ConvolutionDepthWise 742 1 1 741 742 0=960 1=3 4=1 5=1 6=8640 7=960 8=1
94 | ReLU 744 1 1 742 744
95 | Convolution 745 1 1 744 745 0=160 1=1 5=1 6=153600 8=2
96 | BinaryOp 747 2 1 737_splitncnn_0 745 747
97 | Split splitncnn_9 1 2 747 747_splitncnn_0 747_splitncnn_1
98 | Convolution 748 1 1 747_splitncnn_1 748 0=960 1=1 5=1 6=153600 8=2
99 | ReLU 750 1 1 748 750
100 | ConvolutionDepthWise 751 1 1 750 751 0=960 1=3 4=1 5=1 6=8640 7=960 8=1
101 | ReLU 753 1 1 751 753
102 | Convolution 754 1 1 753 754 0=160 1=1 5=1 6=153600 8=2
103 | BinaryOp 756 2 1 747_splitncnn_0 754 756
104 | Convolution 757 1 1 756 757 0=960 1=1 5=1 6=153600 8=2
105 | ReLU 759 1 1 757 759
106 | ConvolutionDepthWise 760 1 1 759 760 0=960 1=3 4=1 5=1 6=8640 7=960 8=1
107 | ReLU 762 1 1 760 762
108 | Convolution 763 1 1 762 763 0=320 1=1 5=1 6=307200 8=2
109 | Convolution 765 1 1 763 765 0=1280 1=1 5=1 6=409600 8=2
110 | ReLU 767 1 1 765 767
111 | Reshape 779 1 1 767 779 0=-1 1=1280
112 | Reduction 780 1 1 779 780 0=3 1=0 -23303=1,-1
113 | InnerProduct 781 1 1 780 781 0=1000 1=1 2=1280000 8=2
114 | Softmax 782 1 1 781 782
115 |
--------------------------------------------------------------------------------
/modeling/segmentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/segmentation/__init__.py
--------------------------------------------------------------------------------
/modeling/segmentation/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class _ASPPModule(nn.Module):
8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 | super(_ASPPModule, self).__init__()
10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 | stride=1, padding=padding, dilation=dilation, bias=False)
12 | self.bn = BatchNorm(planes)
13 | self.relu = nn.ReLU()
14 |
15 | self._init_weight()
16 |
17 | def forward(self, x):
18 | x = self.atrous_conv(x)
19 | x = self.bn(x)
20 |
21 | return self.relu(x)
22 |
23 | def _init_weight(self):
24 | for m in self.modules():
25 | if isinstance(m, nn.Conv2d):
26 | torch.nn.init.kaiming_normal_(m.weight)
27 | elif isinstance(m, SynchronizedBatchNorm2d):
28 | m.weight.data.fill_(1)
29 | m.bias.data.zero_()
30 | elif isinstance(m, nn.BatchNorm2d):
31 | m.weight.data.fill_(1)
32 | m.bias.data.zero_()
33 |
34 | class ASPP(nn.Module):
35 | def __init__(self, backbone, output_stride, BatchNorm):
36 | super(ASPP, self).__init__()
37 | if backbone == 'drn':
38 | inplanes = 512
39 | elif backbone == 'mobilenet':
40 | inplanes = 320
41 | else:
42 | inplanes = 2048
43 | if output_stride == 16:
44 | dilations = [1, 6, 12, 18]
45 | elif output_stride == 8:
46 | dilations = [1, 12, 24, 36]
47 | else:
48 | raise NotImplementedError
49 |
50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
54 |
55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
57 | BatchNorm(256),
58 | nn.ReLU())
59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
60 | self.bn1 = BatchNorm(256)
61 | self.relu = nn.ReLU()
62 | self.dropout = nn.Dropout(0.5)
63 | self._init_weight()
64 |
65 | def forward(self, x):
66 | x1 = self.aspp1(x)
67 | x2 = self.aspp2(x)
68 | x3 = self.aspp3(x)
69 | x4 = self.aspp4(x)
70 | x5 = self.global_avg_pool(x)
71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
73 |
74 | x = self.conv1(x)
75 | x = self.bn1(x)
76 | x = self.relu(x)
77 |
78 | return self.dropout(x)
79 |
80 | def _init_weight(self):
81 | for m in self.modules():
82 | if isinstance(m, nn.Conv2d):
83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
84 | # m.weight.data.normal_(0, math.sqrt(2. / n))
85 | torch.nn.init.kaiming_normal_(m.weight)
86 | elif isinstance(m, SynchronizedBatchNorm2d):
87 | m.weight.data.fill_(1)
88 | m.bias.data.zero_()
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 |
93 |
94 | def build_aspp(backbone, output_stride, BatchNorm):
95 | return ASPP(backbone, output_stride, BatchNorm)
--------------------------------------------------------------------------------
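A shape check for the mobilenet / output_stride=16 configuration above; the 33x33 input roughly corresponds to a 513x513 crop at stride 16, and 320 is the channel count ASPP expects from the mobilenet backbone:

import torch
import torch.nn as nn
from modeling.segmentation.aspp import build_aspp

aspp = build_aspp(backbone='mobilenet', output_stride=16, BatchNorm=nn.BatchNorm2d)
x = torch.rand(2, 320, 33, 33)
print(aspp(x).shape)   # torch.Size([2, 256, 33, 33]): five 256-channel branches concatenated, then projected to 256
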
/modeling/segmentation/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from modeling.segmentation.backbone import resnet, xception, drn, mobilenet
2 |
3 | def build_backbone(backbone, output_stride, BatchNorm):
4 | if backbone == 'resnet':
5 | return resnet.ResNet101(output_stride, BatchNorm)
6 | elif backbone == 'xception':
7 | return xception.AlignedXception(output_stride, BatchNorm)
8 | elif backbone == 'drn':
9 | return drn.drn_d_54(BatchNorm)
10 | elif backbone == 'mobilenet':
11 | return mobilenet.MobileNetV2(output_stride, BatchNorm)
12 | else:
13 | raise NotImplementedError
14 |
--------------------------------------------------------------------------------
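A sketch of the factory above using the mobilenet backbone, which unlike resnet does not download pretrained weights by default. It returns the high-level feature map together with a low-level one; the 24-channel low-level map is what the Decoder below expects for 'mobilenet':

import torch
import torch.nn as nn
from modeling.segmentation.backbone import build_backbone

backbone = build_backbone('mobilenet', output_stride=16, BatchNorm=nn.BatchNorm2d)
high, low = backbone(torch.rand(1, 3, 513, 513))
print(high.shape, low.shape)   # torch.Size([1, 320, 33, 33]) torch.Size([1, 24, 129, 129])
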
/modeling/segmentation/backbone/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | def conv_bn(inp, oup, stride, BatchNorm):
9 | return nn.Sequential(
10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
11 | BatchNorm(oup),
12 | nn.ReLU6(inplace=True)
13 | )
14 |
15 |
16 | def fixed_padding(inputs, kernel_size, dilation):
17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
18 | pad_total = kernel_size_effective - 1
19 | pad_beg = pad_total // 2
20 | pad_end = pad_total - pad_beg
21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
22 | return padded_inputs
23 |
24 |
25 | class InvertedResidual(nn.Module):
26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
27 | super(InvertedResidual, self).__init__()
28 | self.stride = stride
29 | assert stride in [1, 2]
30 |
31 | hidden_dim = round(inp * expand_ratio)
32 | self.use_res_connect = self.stride == 1 and inp == oup
33 | self.kernel_size = 3
34 | self.dilation = dilation
35 |
36 | if expand_ratio == 1:
37 | self.conv = nn.Sequential(
38 | # dw
39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
40 | BatchNorm(hidden_dim),
41 | nn.ReLU6(inplace=True),
42 | # pw-linear
43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
44 | BatchNorm(oup),
45 | )
46 | else:
47 | self.conv = nn.Sequential(
48 | # pw
49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
50 | BatchNorm(hidden_dim),
51 | nn.ReLU6(inplace=True),
52 | # dw
53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
54 | BatchNorm(hidden_dim),
55 | nn.ReLU6(inplace=True),
56 | # pw-linear
57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
58 | BatchNorm(oup),
59 | )
60 |
61 | def forward(self, x):
62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
63 | if self.use_res_connect:
64 | x = x + self.conv(x_pad)
65 | else:
66 | x = self.conv(x_pad)
67 | return x
68 |
69 |
70 | class MobileNetV2(nn.Module):
71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=False):
72 | super(MobileNetV2, self).__init__()
73 | block = InvertedResidual
74 | input_channel = 32
75 | current_stride = 1
76 | rate = 1
77 | interverted_residual_setting = [
78 | # t, c, n, s
79 | [1, 16, 1, 1],
80 | [6, 24, 2, 2],
81 | [6, 32, 3, 2],
82 | [6, 64, 4, 2],
83 | [6, 96, 3, 1],
84 | [6, 160, 3, 2],
85 | [6, 320, 1, 1],
86 | ]
87 |
88 | # building first layer
89 | input_channel = int(input_channel * width_mult)
90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
91 | current_stride *= 2
92 | # building inverted residual blocks
93 | for t, c, n, s in interverted_residual_setting:
94 | if current_stride == output_stride:
95 | stride = 1
96 | dilation = rate
97 | rate *= s
98 | else:
99 | stride = s
100 | dilation = 1
101 | current_stride *= s
102 | output_channel = int(c * width_mult)
103 | for i in range(n):
104 | if i == 0:
105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
106 | else:
107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
108 | input_channel = output_channel
109 | self.features = nn.Sequential(*self.features)
110 | self._initialize_weights()
111 |
112 | if pretrained:
113 | self._load_pretrained_model()
114 |
115 | self.low_level_features = self.features[0:4]
116 | self.high_level_features = self.features[4:]
117 |
118 | def forward(self, x):
119 | low_level_feat = self.low_level_features(x)
120 | x = self.high_level_features(low_level_feat)
121 | return x, low_level_feat
122 |
123 | def _load_pretrained_model(self):
124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
125 | model_dict = {}
126 | state_dict = self.state_dict()
127 | for k, v in pretrain_dict.items():
128 | if k in state_dict:
129 | model_dict[k] = v
130 | state_dict.update(model_dict)
131 | self.load_state_dict(state_dict)
132 |
133 | def _initialize_weights(self):
134 | for m in self.modules():
135 | if isinstance(m, nn.Conv2d):
136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137 | # m.weight.data.normal_(0, math.sqrt(2. / n))
138 | torch.nn.init.kaiming_normal_(m.weight)
139 | elif isinstance(m, SynchronizedBatchNorm2d):
140 | m.weight.data.fill_(1)
141 | m.bias.data.zero_()
142 | elif isinstance(m, nn.BatchNorm2d):
143 | m.weight.data.fill_(1)
144 | m.bias.data.zero_()
145 |
146 | if __name__ == "__main__":
147 | input = torch.rand(1, 3, 512, 512)
148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
149 | output, low_level_feat = model(input)
150 | print(output.size())
151 | print(low_level_feat.size())
152 |
--------------------------------------------------------------------------------
/modeling/segmentation/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 |
6 | class Bottleneck(nn.Module):
7 | expansion = 4
8 |
9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
10 | super(Bottleneck, self).__init__()
11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
12 | self.bn1 = BatchNorm(planes)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
14 | dilation=dilation, padding=dilation, bias=False)
15 | self.bn2 = BatchNorm(planes)
16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
17 | self.bn3 = BatchNorm(planes * 4)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.downsample = downsample
20 | self.stride = stride
21 | self.dilation = dilation
22 |
23 | def forward(self, x):
24 | residual = x
25 |
26 | out = self.conv1(x)
27 | out = self.bn1(out)
28 | out = self.relu(out)
29 |
30 | out = self.conv2(out)
31 | out = self.bn2(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv3(out)
35 | out = self.bn3(out)
36 |
37 | if self.downsample is not None:
38 | residual = self.downsample(x)
39 |
40 | out += residual
41 | out = self.relu(out)
42 |
43 | return out
44 |
45 | class ResNet(nn.Module):
46 |
47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
48 | self.inplanes = 64
49 | super(ResNet, self).__init__()
50 | blocks = [1, 2, 4]
51 | if output_stride == 16:
52 | strides = [1, 2, 2, 1]
53 | dilations = [1, 1, 1, 2]
54 | elif output_stride == 8:
55 | strides = [1, 2, 1, 1]
56 | dilations = [1, 1, 2, 4]
57 | else:
58 | raise NotImplementedError
59 |
60 | # Modules
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = BatchNorm(64)
64 | self.relu = nn.ReLU(inplace=True)
65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
66 |
67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
72 | self._init_weight()
73 |
74 | if pretrained:
75 | self._load_pretrained_model()
76 |
77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
78 | downsample = None
79 | if stride != 1 or self.inplanes != planes * block.expansion:
80 | downsample = nn.Sequential(
81 | nn.Conv2d(self.inplanes, planes * block.expansion,
82 | kernel_size=1, stride=stride, bias=False),
83 | BatchNorm(planes * block.expansion),
84 | )
85 |
86 | layers = []
87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
88 | self.inplanes = planes * block.expansion
89 | for i in range(1, blocks):
90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
91 |
92 | return nn.Sequential(*layers)
93 |
94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
95 | downsample = None
96 | if stride != 1 or self.inplanes != planes * block.expansion:
97 | downsample = nn.Sequential(
98 | nn.Conv2d(self.inplanes, planes * block.expansion,
99 | kernel_size=1, stride=stride, bias=False),
100 | BatchNorm(planes * block.expansion),
101 | )
102 |
103 | layers = []
104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
105 | downsample=downsample, BatchNorm=BatchNorm))
106 | self.inplanes = planes * block.expansion
107 | for i in range(1, len(blocks)):
108 | layers.append(block(self.inplanes, planes, stride=1,
109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
110 |
111 | return nn.Sequential(*layers)
112 |
113 | def forward(self, input):
114 | x = self.conv1(input)
115 | x = self.bn1(x)
116 | x = self.relu(x)
117 | x = self.maxpool(x)
118 |
119 | x = self.layer1(x)
120 | low_level_feat = x
121 | x = self.layer2(x)
122 | x = self.layer3(x)
123 | x = self.layer4(x)
124 | return x, low_level_feat
125 |
126 | def _init_weight(self):
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130 | m.weight.data.normal_(0, math.sqrt(2. / n))
131 | elif isinstance(m, SynchronizedBatchNorm2d):
132 | m.weight.data.fill_(1)
133 | m.bias.data.zero_()
134 | elif isinstance(m, nn.BatchNorm2d):
135 | m.weight.data.fill_(1)
136 | m.bias.data.zero_()
137 |
138 | def _load_pretrained_model(self):
139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
140 | model_dict = {}
141 | state_dict = self.state_dict()
142 | for k, v in pretrain_dict.items():
143 | if k in state_dict:
144 | model_dict[k] = v
145 | state_dict.update(model_dict)
146 | self.load_state_dict(state_dict)
147 |
148 | def ResNet101(output_stride, BatchNorm, pretrained=True):
149 | """Constructs a ResNet-101 model.
150 | Args:
151 | pretrained (bool): If True, returns a model pre-trained on ImageNet
152 | """
153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
154 | return model
155 |
156 | if __name__ == "__main__":
157 | import torch
158 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
159 | input = torch.rand(1, 3, 512, 512)
160 | output, low_level_feat = model(input)
161 | print(output.size())
162 | print(low_level_feat.size())
--------------------------------------------------------------------------------
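A quick sanity check of the backbone (an editor's sketch, not part of the repository): with output_stride=16, layer4 keeps stride 1 and uses the multi-grid dilations blocks=[1, 2, 4] scaled by the base dilation of 2, so a 512x512 input yields a 32x32 high-level map and a 128x128 low-level map. pretrained=False avoids the ImageNet download.

    import torch
    import torch.nn as nn
    from modeling.segmentation.backbone.resnet import ResNet101

    model = ResNet101(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
    model.eval()
    x = torch.rand(1, 3, 512, 512)
    out, low_level_feat = model(x)
    print(out.size())             # torch.Size([1, 2048, 32, 32])  -> overall stride 16
    print(low_level_feat.size())  # torch.Size([1, 256, 128, 128]) -> stride 4 after layer1
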
/modeling/segmentation/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class Decoder(nn.Module):
8 | def __init__(self, num_classes, backbone, BatchNorm):
9 | super(Decoder, self).__init__()
10 | if backbone == 'resnet' or backbone == 'drn':
11 | low_level_inplanes = 256
12 | elif backbone == 'xception':
13 | low_level_inplanes = 128
14 | elif backbone == 'mobilenet':
15 | low_level_inplanes = 24
16 | else:
17 | raise NotImplementedError
18 |
19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
20 | self.bn1 = BatchNorm(48)
21 | self.relu = nn.ReLU()
22 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
23 | BatchNorm(256),
24 | nn.ReLU(),
25 | nn.Dropout(0.5),
26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
27 | BatchNorm(256),
28 | nn.ReLU(),
29 | nn.Dropout(0.1),
30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
31 | self._init_weight()
32 |
33 |
34 | def forward(self, x, low_level_feat):
35 | low_level_feat = self.conv1(low_level_feat)
36 | low_level_feat = self.bn1(low_level_feat)
37 | low_level_feat = self.relu(low_level_feat)
38 |
39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
40 | x = torch.cat((x, low_level_feat), dim=1)
41 | x = self.last_conv(x)
42 |
43 | return x
44 |
45 | def _init_weight(self):
46 | for m in self.modules():
47 | if isinstance(m, nn.Conv2d):
48 | torch.nn.init.kaiming_normal_(m.weight)
49 | elif isinstance(m, SynchronizedBatchNorm2d):
50 | m.weight.data.fill_(1)
51 | m.bias.data.zero_()
52 | elif isinstance(m, nn.BatchNorm2d):
53 | m.weight.data.fill_(1)
54 | m.bias.data.zero_()
55 |
56 | def build_decoder(num_classes, backbone, BatchNorm):
57 | return Decoder(num_classes, backbone, BatchNorm)
--------------------------------------------------------------------------------
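The 304 input channels of last_conv are the 256-channel ASPP output concatenated with the 48-channel projection of the low-level features. A minimal shape check (an illustrative sketch, not part of the repository; the feature-map sizes are assumptions matching a 512x512 input to the MobileNetV2 backbone):

    import torch
    import torch.nn as nn
    from modeling.segmentation.decoder import build_decoder

    decoder = build_decoder(num_classes=21, backbone='mobilenet', BatchNorm=nn.BatchNorm2d)
    decoder.eval()
    aspp_out = torch.rand(1, 256, 32, 32)    # high-level features coming out of ASPP
    low_level = torch.rand(1, 24, 128, 128)  # MobileNetV2 low-level features have 24 channels
    out = decoder(aspp_out, low_level)
    print(out.size())  # torch.Size([1, 21, 128, 128])
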
/modeling/segmentation/deeplab-mobilenet.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/segmentation/deeplab-mobilenet.pth.tar
--------------------------------------------------------------------------------
/modeling/segmentation/deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from modeling.segmentation.aspp import build_aspp
6 | from modeling.segmentation.decoder import build_decoder
7 | from modeling.segmentation.backbone import build_backbone
8 |
9 | class DeepLab(nn.Module):
10 | def __init__(self, backbone='mobilenet', output_stride=16, num_classes=21,
11 | sync_bn=True, freeze_bn=False):
12 | super(DeepLab, self).__init__()
13 | if backbone == 'drn':
14 | output_stride = 8
15 |
16 | if sync_bn:
17 | BatchNorm = SynchronizedBatchNorm2d
18 | else:
19 | BatchNorm = nn.BatchNorm2d
20 |
21 | self.backbone = build_backbone(backbone, output_stride, BatchNorm)
22 | self.aspp = build_aspp(backbone, output_stride, BatchNorm)
23 | self.decoder = build_decoder(num_classes, backbone, BatchNorm)
24 |
25 | if freeze_bn:
26 | self.freeze_bn()
27 |
28 | def forward(self, input):
29 | x, low_level_feat = self.backbone(input)
30 | x = self.aspp(x)
31 | x = self.decoder(x, low_level_feat)
32 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
33 |
34 | return x
35 |
36 | def freeze_bn(self):
37 | for m in self.modules():
38 | if isinstance(m, SynchronizedBatchNorm2d):
39 | m.eval()
40 | elif isinstance(m, nn.BatchNorm2d):
41 | m.eval()
42 |
43 | def get_1x_lr_params(self):
44 | modules = [self.backbone]
45 | for i in range(len(modules)):
46 | for m in modules[i].named_modules():
47 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
48 | or isinstance(m[1], nn.BatchNorm2d):
49 | for p in m[1].parameters():
50 | if p.requires_grad:
51 | yield p
52 |
53 | def get_10x_lr_params(self):
54 | modules = [self.aspp, self.decoder]
55 | for i in range(len(modules)):
56 | for m in modules[i].named_modules():
57 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
58 | or isinstance(m[1], nn.BatchNorm2d):
59 | for p in m[1].parameters():
60 | if p.requires_grad:
61 | yield p
62 |
63 |
64 | if __name__ == "__main__":
65 | model = DeepLab(backbone='mobilenet', output_stride=16)
66 | model.eval()
67 | input = torch.rand(1, 3, 513, 513)
68 | output = model(input)
69 | print(output.size())
70 |
71 |
72 |
--------------------------------------------------------------------------------
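SynchronizedBatchNorm2d only matters when training across multiple GPUs with DataParallelWithCallback; for single-device evaluation, plain nn.BatchNorm2d is enough. A sketch (illustrative, not part of the repository; constructing the MobileNetV2 backbone may trigger a pretrained-weight download):

    import torch
    from modeling.segmentation.deeplab import DeepLab

    model = DeepLab(backbone='mobilenet', output_stride=16, num_classes=21, sync_bn=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.rand(1, 3, 513, 513))
    print(out.size())  # torch.Size([1, 21, 513, 513])
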
/modeling/segmentation/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/modeling/segmentation/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 | assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
59 | call `register(id)` and obtain an `SlavePipe` to communicate with the master.
60 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
61 | and passed to a registered callback.
62 | - After receiving the messages, the master device should gather the information and determine to message passed
63 | back to each slave devices.
64 | """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {'master_callback': self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state['master_callback'])
81 |
82 | def register_slave(self, identifier):
83 | """
84 | Register a slave device.
85 | Args:
86 | identifier: an identifier, usually the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 | The messages are first collected from each device (including the master device), and then
101 | a callback is invoked to compute the message to be sent back to each device
102 | (including the master device).
103 | Args:
104 | master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 | assert results[0][0] == 0, 'The first result should belong to the master.'
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 |
--------------------------------------------------------------------------------
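A minimal single-process sketch of the protocol (illustrative, not part of the repository): with no slaves registered, run_master simply pushes the master's own message through the callback.

    from modeling.segmentation.sync_batchnorm.comm import SyncMaster

    # The callback receives [(copy_id, msg), ...] and must return results in the same
    # form, with the master's entry (copy_id 0) first.
    def master_callback(intermediates):
        total = sum(msg for _, msg in intermediates)
        return [(copy_id, total) for copy_id, _ in intermediates]

    master = SyncMaster(master_callback)
    print(master.run_master(5))  # 5 -- no slaves, so the reduction sees only the master's message
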
/modeling/segmentation/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 | Note that, as all modules are isomorphic, we assign each sub-module a context
32 | (shared among multiple copies of this module on different devices).
33 | Through this context, different copies can share some information.
34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35 | of any slave copies.
36 | """
37 | master_copy = modules[0]
38 | nr_modules = len(list(master_copy.modules()))
39 | ctxs = [CallbackContext() for _ in range(nr_modules)]
40 |
41 | for i, module in enumerate(modules):
42 | for j, m in enumerate(module.modules()):
43 | if hasattr(m, '__data_parallel_replicate__'):
44 | m.__data_parallel_replicate__(ctxs[j], i)
45 |
46 |
47 | class DataParallelWithCallback(DataParallel):
48 | """
49 | Data Parallel with a replication callback.
50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by
51 | the original `replicate` function.
52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 | Examples:
54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 | # sync_bn.__data_parallel_replicate__ will be invoked.
57 | """
58 |
59 | def replicate(self, module, device_ids):
60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 | execute_replication_callbacks(modules)
62 | return modules
63 |
64 |
65 | def patch_replication_callback(data_parallel):
66 | """
67 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 | Useful when you have a customized `DataParallel` implementation.
69 | Examples:
70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 | > patch_replication_callback(sync_bn)
73 | # this is equivalent to
74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 | """
77 |
78 | assert isinstance(data_parallel, DataParallel)
79 |
80 | old_replicate = data_parallel.replicate
81 |
82 | @functools.wraps(old_replicate)
83 | def new_replicate(module, device_ids):
84 | modules = old_replicate(module, device_ids)
85 | execute_replication_callbacks(modules)
86 | return modules
87 |
88 | data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/modeling/segmentation/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol, rtol=rtol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pydot==1.4.1
2 | torch==1.1.0
3 | matplotlib==3.1.0
4 | scipy==1.3.0
5 | numpy==1.16.4
6 | torchvision==0.3.0
7 | graphviz==0.10.1
8 | Pillow==8.1.1
9 | tqdm==4.47.0
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | def visualize_per_layer(param, title='test'):
2 | import matplotlib.pyplot as plt
3 | channel = 0
4 | param_list = []
5 | for idx in range(param.shape[channel]):
6 | # print(idx, param[idx].max(), param[idx].min())
7 | param_list.append(param[idx].cpu().numpy().reshape(-1))
8 |
9 | fig7, ax7 = plt.subplots()
10 | ax7.set_title(title)
11 | ax7.boxplot(param_list, showfliers=False)
12 | # plt.ylim(-70, 70)
13 | plt.show()
--------------------------------------------------------------------------------
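Usage sketch (illustrative, not part of the repository): visualize_per_layer draws one box plot per slice along dimension 0, e.g. per output channel of a convolution weight, which is useful for inspecting per-channel weight ranges, for instance before and after equalization.

    import torch
    from utils import visualize_per_layer

    weight = torch.randn(32, 16, 3, 3)  # e.g. a Conv2d weight with 32 output channels
    visualize_per_layer(weight, title='per-channel weight range')  # opens a matplotlib window
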
/utils/detection/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import *
2 |
--------------------------------------------------------------------------------
/utils/detection/measurements.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def compute_average_precision(precision, recall):
5 | """
6 | Computes average precision as defined by the Pascal VOC competition: the area under the precision-recall curve.
7 | Recall follows the usual definition; precision is made monotonically non-increasing via
8 | pascal_precision[i] = typical_precision[i:].max()
9 | """
10 | # identical but faster version of new_precision[i] = old_precision[i:].max()
11 | precision = np.concatenate([[0.0], precision, [0.0]])
12 | for i in range(len(precision) - 1, 0, -1):
13 | precision[i - 1] = np.maximum(precision[i - 1], precision[i])
14 |
15 | # find the index where the value changes
16 | recall = np.concatenate([[0.0], recall, [1.0]])
17 | changing_points = np.where(recall[1:] != recall[:-1])[0]
18 |
19 | # compute the area under the curve
20 | areas = (recall[changing_points + 1] - recall[changing_points]) * precision[changing_points + 1]
21 | return areas.sum()
22 |
23 |
24 | def compute_voc2007_average_precision(precision, recall):
25 | ap = 0.
26 | for t in np.arange(0., 1.1, 0.1):
27 | if np.sum(recall >= t) == 0:
28 | p = 0
29 | else:
30 | p = np.max(precision[recall >= t])
31 | ap = ap + p / 11.
32 | return ap
33 |
--------------------------------------------------------------------------------
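A worked example (illustrative, not part of the repository): for three detections with precision [1.0, 0.6, 0.5] at recall [0.1, 0.5, 1.0], the interpolated precision is 1.0 up to recall 0.1, 0.6 up to 0.5 and 0.5 up to 1.0, so the area is 0.1*1.0 + 0.4*0.6 + 0.5*0.5 = 0.59.

    import numpy as np
    from utils.detection.measurements import compute_average_precision

    precision = np.array([1.0, 0.6, 0.5])
    recall = np.array([0.1, 0.5, 1.0])
    print(compute_average_precision(precision, recall))  # ~0.59
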
/utils/detection/misc.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 |
4 |
5 | def str2bool(s):
6 | return s.lower() in ('true', '1')
7 |
8 |
9 | class Timer:
10 | def __init__(self):
11 | self.clock = {}
12 |
13 | def start(self, key="default"):
14 | self.clock[key] = time.time()
15 |
16 | def end(self, key="default"):
17 | if key not in self.clock:
18 | raise Exception(f"{key} is not in the clock.")
19 | interval = time.time() - self.clock[key]
20 | del self.clock[key]
21 | return interval
22 |
23 |
24 | def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path):
25 | torch.save({
26 | 'epoch': epoch,
27 | 'model': net_state_dict,
28 | 'optimizer': optimizer_state_dict,
29 | 'best_score': best_score
30 | }, checkpoint_path)
31 | torch.save(net_state_dict, model_path)
32 |
33 |
34 | def load_checkpoint(checkpoint_path):
35 | return torch.load(checkpoint_path)
36 |
37 |
38 | def freeze_net_layers(net):
39 | for param in net.parameters():
40 | param.requires_grad = False
41 |
42 |
43 | def store_labels(path, labels):
44 | with open(path, "w") as f:
45 | f.write("\n".join(labels))
46 |
--------------------------------------------------------------------------------
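Brief usage sketch (illustrative, not part of the repository): Timer keeps one start time per key, and end() returns the elapsed seconds and forgets that key.

    import time
    from utils.detection.misc import Timer, str2bool

    timer = Timer()
    timer.start("inference")
    time.sleep(0.1)
    print("{:.2f}s elapsed".format(timer.end("inference")))  # roughly 0.10s

    print(str2bool("True"), str2bool("0"))  # True False
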
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class Evaluator(object):
5 | def __init__(self, num_class):
6 | self.num_class = num_class
7 | self.confusion_matrix = np.zeros((self.num_class,)*2)
8 |
9 | def Pixel_Accuracy(self):
10 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
11 | return Acc
12 |
13 | def Pixel_Accuracy_Class(self):
14 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
15 | Acc = np.nanmean(Acc)
16 | return Acc
17 |
18 | def Mean_Intersection_over_Union(self):
19 | MIoU = np.diag(self.confusion_matrix) / (
20 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
21 | np.diag(self.confusion_matrix))
22 | MIoU = np.nanmean(MIoU)
23 | return MIoU
24 |
25 | def Frequency_Weighted_Intersection_over_Union(self):
26 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
27 | iu = np.diag(self.confusion_matrix) / (
28 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
29 | np.diag(self.confusion_matrix))
30 |
31 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
32 | return FWIoU
33 |
34 | def _generate_matrix(self, gt_image, pre_image):
35 | mask = (gt_image >= 0) & (gt_image < self.num_class)
36 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
37 | count = np.bincount(label, minlength=self.num_class**2)
38 | confusion_matrix = count.reshape(self.num_class, self.num_class)
39 | return confusion_matrix
40 |
41 | def add_batch(self, gt_image, pre_image):
42 | assert gt_image.shape == pre_image.shape
43 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image)
44 |
45 | def reset(self):
46 | self.confusion_matrix = np.zeros((self.num_class,) * 2)
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
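A small worked example (illustrative, not part of the repository): with three classes and four pixels, a single class-1 pixel predicted as class 2 gives pixel accuracy 3/4 and per-class IoUs of 1.0, 0.5 and 0.5, i.e. an mIoU of about 0.67.

    import numpy as np
    from utils.metrics import Evaluator

    evaluator = Evaluator(num_class=3)
    gt = np.array([[0, 1], [2, 1]])
    pred = np.array([[0, 2], [2, 1]])
    evaluator.add_batch(gt, pred)
    print(evaluator.Pixel_Accuracy())                # 0.75
    print(evaluator.Mean_Intersection_over_Union())  # ~0.667
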
/utils/relation.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from torch.nn import BatchNorm2d, ReLU, Dropout, AvgPool2d
3 | from utils.quantize import QConv2d, QuantMeasure
4 |
5 | class Relation():
6 | def __init__(self, layer_idx_1, layer_idx_2, bn_idx_1):
7 | self.layer_first = layer_idx_1
8 | self.layer_second = layer_idx_2
9 | self.bn_idx = bn_idx_1
10 | self.S = None
11 |
12 |
13 | def __repr__(self):
14 | return '({}, {})'.format(self.layer_first, self.layer_second)
15 |
16 |
17 | def get_idxs(self):
18 | return self.layer_first, self.layer_second, self.bn_idx
19 |
20 | def set_scale_vec(self, S):
21 | if self.S is None:
22 | self.S = S
23 | else:
24 | self.S *= S
25 |
26 | def get_scale_vec(self):
27 | return self.S
28 |
29 |
30 | def create_relation(graph, bottoms, targ_type=[QConv2d], delete_single=False):
31 | relation_dict = OrderedDict()
32 |
33 | def _find_prev(graph, bottoms, layer_idx, targ_type, top_counter): # find previous target layer to form relations
34 | bot = bottoms[layer_idx]
35 | last_bn = None
36 | while len(bot) == 1 and "Data" != bot[0] and top_counter[bot[0]] == 1:
37 | if type(graph[bot[0]]) == BatchNorm2d:
38 | last_bn = bot[0]
39 | if type(graph[bot[0]]) in targ_type:
40 | return bot[0], last_bn
41 |
42 | elif not(type(graph[bot[0]]) in [BatchNorm2d, ReLU, QuantMeasure, AvgPool2d] or
43 | (type(graph[bot[0]]) == str and ("F.pad" in bot[0] or "torch.mean" in bot[0]))):
44 | return None, None
45 |
46 | bot = bottoms[bot[0]]
47 |
48 | return None, None
49 |
50 | top_counter = {} #count the number of output branches of each layer
51 | for layer_idx in graph:
52 | if layer_idx == "Data":
53 | continue
54 | for bot in bottoms[layer_idx]:
55 | if bot in top_counter:
56 | top_counter[bot] += 1
57 | else:
58 | top_counter[bot] = 1
59 |
60 | # find relation pair for each layer
61 | for layer_idx in graph:
62 | if type(graph[layer_idx]) in targ_type:
63 | prev, bn = _find_prev(graph, bottoms, layer_idx, targ_type, top_counter)
64 | if prev in relation_dict:
65 | relation_dict.pop(prev)
66 | elif prev is not None:
67 | rel = Relation(prev, layer_idx, bn)
68 | relation_dict[prev] = rel
69 |
70 | if delete_single:
71 | # only keep relation chains spanning at least 3 target layers, e.g. Conv2d->Conv2d->Conv2d; ignore isolated Conv2d->Conv2d pairs (used in the detection task)
72 | tmp = list(relation_dict.values())
73 | res_group = []
74 | for rr in tmp:
75 | group_idx = -1
76 | for idx, group in enumerate(res_group):
77 | for rr_prev in group:
78 | if rr.get_idxs()[0] == rr_prev.get_idxs()[1]:
79 | group_idx = idx
80 | break
81 | if group_idx != -1:
82 | res_group[group_idx].append(rr)
83 | else:
84 | res_group.append([rr])
85 | res = []
86 | for group in res_group:
87 | if len(group) > 1:
88 | res.extend(group)
89 |
90 | # print(len(res), len(list(relation_dict.values())))
91 | else:
92 | res = list(relation_dict.values())
93 |
94 | return res #list(relation_dict.values())
95 |
--------------------------------------------------------------------------------
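A toy example (illustrative, not part of the repository; in the actual pipeline the graph and bottoms dictionaries come from the PyTransformer trace and targ_type is typically [QConv2d]): two convolutions joined through a BatchNorm and a ReLU form one equalization relation, and the BatchNorm index is kept for later bias absorption. Run from the repository root so that utils.quantize resolves.

    import torch.nn as nn
    from utils.relation import create_relation

    graph = {
        "conv1": nn.Conv2d(3, 8, 3, padding=1),
        "bn1": nn.BatchNorm2d(8),
        "relu1": nn.ReLU(),
        "conv2": nn.Conv2d(8, 8, 3, padding=1),
    }
    bottoms = {
        "conv1": ["Data"],
        "bn1": ["conv1"],
        "relu1": ["bn1"],
        "conv2": ["relu1"],
    }
    relations = create_relation(graph, bottoms, targ_type=[nn.Conv2d])
    print(relations)                # [(conv1, conv2)]
    print(relations[0].get_idxs())  # ('conv1', 'conv2', 'bn1')
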
/utils/segmentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/utils/segmentation/__init__.py
--------------------------------------------------------------------------------
/utils/segmentation/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import torch
4 | import cv2
5 |
6 | from utils.metrics import Evaluator
7 |
8 |
9 | def forward_all(net_inference, dataloader, visualize=False, opt=None):
10 | evaluator = Evaluator(21)
11 | evaluator.reset()
12 | with torch.no_grad():
13 | for ii, sample in enumerate(dataloader):
14 | image, label = sample['image'].cuda(), sample['label'].cuda()
15 |
16 | activations = net_inference(image)
17 |
18 | image = image.cpu().numpy()
19 | label = label.cpu().numpy().astype(np.uint8)
20 |
21 | logits = activations[list(activations.keys())[-1]] if type(activations) != torch.Tensor else activations
22 | pred = torch.max(logits, 1)[1].cpu().numpy().astype(np.uint8)
23 |
24 | evaluator.add_batch(label, pred)
25 |
26 | # print(label.shape, pred.shape)
27 | if visualize:
28 | for jj in range(sample["image"].size()[0]):
29 | segmap_label = decode_segmap(label[jj], dataset='pascal')
30 | segmap_pred = decode_segmap(pred[jj], dataset='pascal')
31 |
32 | img_tmp = np.transpose(image[jj], axes=[1, 2, 0])
33 | img_tmp *= (0.229, 0.224, 0.225)
34 | img_tmp += (0.485, 0.456, 0.406)
35 | img_tmp *= 255.0
36 | img_tmp = img_tmp.astype(np.uint8)
37 |
38 | cv2.imshow('image', img_tmp[:, :, [2,1,0]])
39 | cv2.imshow('gt', segmap_label)
40 | cv2.imshow('pred', segmap_pred)
41 | cv2.waitKey(0)
42 |
43 | Acc = evaluator.Pixel_Accuracy()
44 | Acc_class = evaluator.Pixel_Accuracy_Class()
45 | mIoU = evaluator.Mean_Intersection_over_Union()
46 | FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union()
47 | print("Acc: {}".format(Acc))
48 | print("Acc_class: {}".format(Acc_class))
49 | print("mIoU: {}".format(mIoU))
50 | print("FWIoU: {}".format(FWIoU))
51 | if opt is not None:
52 | with open("seg_result.txt", 'a+') as ww:
53 | ww.write("{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n".format(
54 | opt.dataset, opt.quantize, opt.relu, opt.equalize, opt.absorption, opt.correction, opt.clip_weight, opt.distill_range
55 | ))
56 | ww.write("Acc: {}, Acc_class: {}, mIoU: {}, FWIoU: {}\n\n".format(Acc, Acc_class, mIoU, FWIoU))
57 |
58 |
59 | def decode_seg_map_sequence(label_masks, dataset='pascal'):
60 | rgb_masks = []
61 | for label_mask in label_masks:
62 | rgb_mask = decode_segmap(label_mask, dataset)
63 | rgb_masks.append(rgb_mask)
64 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
65 | return rgb_masks
66 |
67 |
68 | def decode_segmap(label_mask, dataset, plot=False):
69 | """Decode segmentation class labels into a color image
70 | Args:
71 | label_mask (np.ndarray): an (M,N) array of integer values denoting
72 | the class label at each spatial location.
73 | plot (bool, optional): whether to show the resulting color image
74 | in a figure.
75 | Returns:
76 | (np.ndarray, optional): the resulting decoded color image.
77 | """
78 | if dataset == 'pascal' or dataset == 'coco':
79 | n_classes = 21
80 | label_colours = get_pascal_labels()
81 | elif dataset == 'cityscapes':
82 | n_classes = 19
83 | label_colours = get_cityscapes_labels()
84 | else:
85 | raise NotImplementedError
86 |
87 | r = label_mask.copy()
88 | g = label_mask.copy()
89 | b = label_mask.copy()
90 | for ll in range(0, n_classes):
91 | r[label_mask == ll] = label_colours[ll, 0]
92 | g[label_mask == ll] = label_colours[ll, 1]
93 | b[label_mask == ll] = label_colours[ll, 2]
94 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
95 | rgb[:, :, 0] = r / 255.0
96 | rgb[:, :, 1] = g / 255.0
97 | rgb[:, :, 2] = b / 255.0
98 | if plot:
99 | plt.imshow(rgb)
100 | plt.show()
101 | else:
102 | return rgb
103 |
104 |
105 | def encode_segmap(mask):
106 | """Encode segmentation label images as pascal classes
107 | Args:
108 | mask (np.ndarray): raw segmentation label image of dimension
109 | (M, N, 3), in which the Pascal classes are encoded as colours.
110 | Returns:
111 | (np.ndarray): class map with dimensions (M,N), where the value at
112 | a given location is the integer denoting the class index.
113 | """
114 | mask = mask.astype(int)
115 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
116 | for ii, label in enumerate(get_pascal_labels()):
117 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
118 | label_mask = label_mask.astype(int)
119 | return label_mask
120 |
121 |
122 | def get_cityscapes_labels():
123 | return np.array([
124 | [128, 64, 128],
125 | [244, 35, 232],
126 | [70, 70, 70],
127 | [102, 102, 156],
128 | [190, 153, 153],
129 | [153, 153, 153],
130 | [250, 170, 30],
131 | [220, 220, 0],
132 | [107, 142, 35],
133 | [152, 251, 152],
134 | [0, 130, 180],
135 | [220, 20, 60],
136 | [255, 0, 0],
137 | [0, 0, 142],
138 | [0, 0, 70],
139 | [0, 60, 100],
140 | [0, 80, 100],
141 | [0, 0, 230],
142 | [119, 11, 32]])
143 |
144 |
145 | def get_pascal_labels():
146 | """Load the mapping that associates pascal classes with label colors
147 | Returns:
148 | np.ndarray with dimensions (21, 3)
149 | """
150 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
151 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
152 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
153 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
154 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
155 | [0, 64, 128]])
--------------------------------------------------------------------------------
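A small usage sketch (illustrative, not part of the repository): decode_segmap turns an integer class mask into an RGB image with values scaled to [0, 1].

    import numpy as np
    from utils.segmentation.utils import decode_segmap

    mask = np.array([[0, 1], [2, 15]])           # Pascal VOC class indices
    rgb = decode_segmap(mask, dataset='pascal')  # (2, 2, 3) float image in [0, 1]
    print(rgb.shape)  # (2, 2, 3)
    print(rgb[0, 0])  # [0. 0. 0.] -> class 0 (background) maps to black
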