├── .gitmodules ├── CMakeLists.txt ├── LICENSE.md ├── README.md ├── ZeroQ ├── LICENSE ├── README.md ├── distill_data.py ├── reconstruct_data.py ├── requirements.txt ├── run.sh ├── uniform_test.py └── utils │ ├── __init__.py │ ├── data_utils.py │ ├── quantization_utils │ ├── quant_modules.py │ └── quant_utils.py │ ├── quantize_model.py │ └── train_utils.py ├── _512_train.txt ├── convert_ncnn.py ├── dataset ├── __init__.py ├── detection │ ├── __init__.py │ ├── open_images.py │ └── voc_dataset.py └── segmentation │ ├── __init__.py │ ├── custom_transforms.py │ ├── pascal.py │ └── utils.py ├── dfq.py ├── images ├── LE_distill.png ├── graph_cls.png ├── graph_deeplab.png └── graph_ssd.png ├── improve_dfq.py ├── inference_cls.cpp ├── main_cls.py ├── main_seg.py ├── main_ssd.py ├── modeling ├── __init__.py ├── classification │ ├── MobileNetV2.py │ └── mobilenetv2_1.0-f2a8633.pth.tar ├── detection │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── mobilenetv1_ssd_config.py │ │ ├── squeezenet_ssd_config.py │ │ └── vgg_ssd_config.py │ ├── data_preprocessing.py │ ├── fpn_mobilenetv1_ssd.py │ ├── fpn_ssd.py │ ├── mb2-ssd-lite-mp-0_686.pth │ ├── mobilenet_v2_ssd_lite.py │ ├── mobilenetv1_ssd.py │ ├── mobilenetv1_ssd_lite.py │ ├── nn │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── mobilenet.py │ │ ├── mobilenet_v2.py │ │ ├── multibox_loss.py │ │ ├── scaled_l2_norm.py │ │ ├── squeezenet.py │ │ └── vgg.py │ ├── predictor.py │ ├── squeezenet_ssd_lite.py │ ├── ssd.py │ ├── transforms │ │ ├── __init__.py │ │ └── transforms.py │ ├── vgg_ssd.py │ └── voc-model-labels.txt ├── ncnn │ ├── model_quant_relu_equal.bin │ ├── model_quant_relu_equal.param │ └── model_quant_relu_equal.table └── segmentation │ ├── __init__.py │ ├── aspp.py │ ├── backbone │ ├── __init__.py │ ├── drn.py │ ├── mobilenet.py │ ├── resnet.py │ └── xception.py │ ├── decoder.py │ ├── deeplab-mobilenet.pth.tar │ ├── deeplab.py │ └── sync_batchnorm │ ├── __init__.py │ ├── batchnorm.py │ ├── comm.py │ ├── replicate.py │ └── unittest.py ├── requirements.txt └── utils ├── __init__.py ├── detection ├── __init__.py ├── box_utils.py ├── measurements.py └── misc.py ├── layer_transform.py ├── metrics.py ├── quantize.py ├── relation.py └── segmentation ├── __init__.py └── utils.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "PyTransformer"] 2 | path = PyTransformer 3 | url = https://github.com/ricky40403/PyTransformer 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.12) 2 | project(dfq) 3 | include_directories(/home/jakc4103/Documents/ncnn/src) 4 | include_directories(/home/jakc4103/Documents/ncnn/build/src) 5 | 6 | #openmp 7 | FIND_PACKAGE( OpenMP REQUIRED) 8 | if(OPENMP_FOUND) 9 | message("OPENMP FOUND") 10 | set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") 11 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") 12 | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") 13 | endif() 14 | 15 | #ncnn 16 | set(NCNN_LIBS /home/jakc4103/Documents/ncnn/build/install/lib/libncnn.a) 17 | set(NCNN_INCLUDE_DIRS /home/jakc4103/Documents/ncnn/build/install/include) 18 | include_directories(${NCNN_INCLUDE_DIRS}) 19 | 20 | #opencv 21 | find_package( OpenCV REQUIRED ) 22 | include_directories( ${OpenCV_INCLUDE_DIRS} ) 23 | 24 | add_executable(inference_cls inference_cls.cpp) 25 | 
target_link_libraries(inference_cls ${NCNN_LIBS}) 26 | target_link_libraries( inference_cls ${OpenCV_LIBS} ) -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 jakc4103 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DFQ 2 | PyTorch implementation of [Data Free Quantization Through Weight Equalization and Bias Correction](https://arxiv.org/abs/1906.04721) with some ideas from [ZeroQ: A Novel Zero Shot Quantization Framework](https://arxiv.org/abs/2001.00281). 3 | 4 | ## Results 5 | Int8**: Fake quantization; 8 bits weight, 8 bits activation, 16 bits bias 6 | Int8*: Fake quantization; 8 bits weight, 8 bits activation, 8 bits bias 7 | Int8': Fake quantization; 8 bits weight(symmetric), 8 bits activation(symmetric), 32 bits bias 8 | Int8: Int8 Inference using [ncnn](https://github.com/Tencent/ncnn); 8 bits weight(symmetric), 8 bits activation(symmetric), 32 bits bias 9 | 10 | ### On classification task 11 | - Tested with [MobileNetV2](https://github.com/tonylins/pytorch-mobilenet-v2) and [ResNet-18](https://pytorch.org/docs/stable/torchvision/models.html) 12 | - ImageNet validation set (Acc.) 13 | 14 | 15 |
MobileNetV2 (first table below) / ResNet-18 (second table below)
16 | 17 | model/precision | FP32 | Int8** | Int8* | Int8' | Int8 (FP32-69.19) 18 | -----------|------|------| ------ | ------|------ 19 | Original | 71.81 | 0.102 | 0.1 | 0.062 | 0.082 20 | +ReLU | 71.78 | 0.102 | 0.096 | 0.094 | 0.082 21 | +ReLU+LE | 71.78 | 70.32 | 68.78 | 67.5 | 65.21 22 | +ReLU+LE +DR | -- | 70.47 | 68.87 | -- | -- 23 | +BC | -- | 57.07 | 0.12 | 26.25 | 5.57 24 | +BC +clip_15 | -- | 65.37 | 0.13 | 65.96 | 45.13 25 | +ReLU+LE+BC | -- | 70.79 | 68.17 | 68.65 | 62.19 26 | +ReLU+LE+BC +DR | -- | 70.9 | 68.41 | -- | -- 27 | 28 |
29 | 30 | model/precision | FP32 | Int8** | Int8* 31 | -----------|------|------|------ 32 | Original | 69.76 | 69.13 | 69.09 33 | +ReLU | 69.76 | 69.13 | 69.09 34 | +ReLU+LE | 69.76 | 69.2 | 69.2 35 | +ReLU+LE +DR | -- | 67.74 | 67.75 36 | +BC | -- | 69.04 | 68.56 37 | +BC +clip_15 | -- | 69.04 | 68.56 38 | +ReLU+LE+BC | -- | 69.04 | 68.56 39 | +ReLU+LE+BC +DR | -- | 67.65 | 67.62 40 | 41 |
42 | 43 | ### On segmentation task 44 | - Tested with [Deeplab-v3-plus_mobilenetv2](https://github.com/jfzhang95/pytorch-deeplab-xception) 45 | 46 | 47 |
Pascal VOC 2012 val set (mIOU): first table below. Pascal VOC 2007 test set (mIOU): second table below.
48 | 49 | model/precision | FP32 | Int8**| Int8* 50 | ----------------|-------|-------|------ 51 | Original | 70.81 | 60.03 | 59.31 52 | +ReLU | 70.72 | 60.0 | 58.98 53 | +ReLU+LE | 70.72 | 66.22 | 66.0 54 | +ReLU+LE +DR | -- | 67.04 | 67.23 55 | +ReLU+BC | -- | 69.04 | 68.42 56 | +ReLU+BC +clip_15 | -- | 66.99 | 66.39 57 | +ReLU+LE+BC | -- | 69.46 | 69.22 58 | +ReLU+LE+BC +DR | -- | 70.12 | 69.7 59 | 60 | 61 | 62 | model/precision | FP32 | Int8** | Int8* 63 | ----------------|-------|-------|------- 64 | Original | 74.54 | 62.36 | 61.21 65 | +ReLU | 74.35 | 61.66 | 61.04 66 | +ReLU+LE | 74.35 | 69.47 | 69.6 67 | +ReLU+LE +DR | -- | 70.28 | 69.93 68 | +BC | -- | 72.1 | 70.97 69 | +BC +clip_15 | -- | 70.16 | 70.76 70 | +ReLU+LE+BC | -- | 72.84 | 72.58 71 | +ReLU+LE+BC +DR | -- | 73.5 | 73.04 72 | 73 |
74 | 75 | ### On detection task 76 | - Tested with [MobileNetV2 SSD-Lite model](https://github.com/qfgaohao/pytorch-ssd) 77 | 78 | 79 | 80 |
Pascal VOC 2012 val set (mAP, VOC2012 metric): first table below. Pascal VOC 2007 test set (mAP, VOC2007 metric): second table below.
81 | 82 | model/precision | FP32 | Int8**|Int8* 83 | -----------|------|------|------ 84 | Original | 78.51 | 77.71 | 77.86 85 | +ReLU | 75.42 | 75.74 | 75.58 86 | +ReLU+LE | 75.42 | 75.32 | 75.37 87 | +ReLU+LE +DR | -- | 74.65 | 74.32 88 | +BC | -- | 77.73 | 77.78 89 | +BC +clip_15 | -- | 77.73 | 77.78 90 | +ReLU+LE+BC | -- | 75.66 | 75.66 91 | +ReLU+LE+BC +DR | -- | 74.92 | 74.65 92 | 93 | 94 | 95 | model/precision | FP32 | Int8** | Int8* 96 | ----------------|-------|-------|------- 97 | Original | 68.70 | 68.47 | 68.49 98 | +ReLU | 65.47 | 65.36 | 65.56 99 | +ReLU+LE | 65.47 | 65.36 | 65.27 100 | +ReLU+LE +DR | -- | 64.53 | 64.46 101 | +BC | -- | 68.32 | 65.33 102 | +BC +clip_15 | -- | 68.32 | 65.33 103 | +ReLU+LE+BC | -- | 65.63 | 65.58 104 | +ReLU+LE+BC +DR | -- | 64.92 | 64.42 105 | 106 |
107 | 108 | ## Usage 109 | There are 6 arguments, all defaulting to False: 110 | 1. quantize: whether to quantize parameters and activations. 111 | 2. relu: whether to replace relu6 with relu. 112 | 3. equalize: whether to perform cross layer equalization. 113 | 4. correction: whether to apply bias correction. 114 | 5. clip_weight: whether to clip weights to the range [-15, 15] (for convolution and linear layers). 115 | 6. distill_range: whether to use distilled data to set the min/max range for activation quantization. 116 | 117 | Run the equalized model by: 118 | ``` 119 | python main_cls.py --quantize --relu --equalize 120 | ``` 121 | 122 | Run the equalized and bias-corrected model by: 123 | ``` 124 | python main_cls.py --quantize --relu --equalize --correction 125 | ``` 126 | 127 | Run the equalized and bias-corrected model with distilled data by: 128 | ``` 129 | python main_cls.py --quantize --relu --equalize --correction --distill_range 130 | ``` 131 | 132 | Export the equalized and bias-corrected model to ONNX and generate the calibration table file: 133 | ``` 134 | python convert_ncnn.py --equalize --correction --quantize --relu --ncnn_build path_to_ncnn_build_folder 135 | ``` 136 | 137 | ## Note 138 | ### Distilled Data (2020/02/03 updated) 139 | According to the recent paper [ZeroQ](https://github.com/amirgholami/ZeroQ), we can distill fake data that matches the statistics of the batch-normalization layers, then use it to set the min/max value range for activation quantization. 140 | This does not require every conv to be followed by a batch-norm layer, and distilled data should produce better and **more stable** results (the method from DFQ sometimes fails to find a good enough value range). 141 | 142 | Here are some modifications that differ from the original ZeroQ implementation: 143 | 1. Initialization of distilled data 144 | 2. Early stop criterion 145 | 146 | ~~Also, I think it can be applied to optimizing cross layer equalization and bias correction. The results will be updated once I make it work.~~ 147 | Using distilled data for LE or BC did not perform as well as using the estimates from batch-norm layers, probably because of overfitting. 148 | 149 | ### Fake Quantization 150 | The 'Int8' model in this repo is actually a simulation of 8-bit computation; the actual calculation is done in floating point. 151 | This is done by quantizing and dequantizing the parameters in each layer and the activations between 2 consecutive layers, 152 | which means each tensor keeps dtype 'float32' but contains at most 256 (2^8) unique values. (A minimal PyTorch sketch of this quantize-dequantize round trip is included just before the build steps in the Int8 inference section below.) 153 | ``` 154 | Weight_quant(Int8) = Quant(Weight) 155 | Weight_quant(FP32) = Weight_quant(Int8*) = Dequant(Quant(Weight)) 156 | ``` 157 | 158 | ### 16-bit Quantization for Bias 159 | Somehow I cannot make **Bias Correction** work with 8-bit bias quantization in all scenarios (even with data-dependent correction). 160 | I am not sure how the original paper managed to do it with 8-bit quantization; I guess they either use some non-uniform quantization technique or use more bits for the bias parameters, as I do. 161 | 162 | ### Int8 inference 163 | Refer to [ncnn](https://github.com/Tencent/ncnn), [pytorch2ncnn](https://github.com/Tencent/ncnn/wiki/use-ncnn-with-pytorch-or-onnx), [ncnn-quantize](https://github.com/Tencent/ncnn/tree/master/tools/quantize), [ncnn-int8-inference](https://github.com/Tencent/ncnn/wiki/quantized-int8-inference) for more details.
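Before the ncnn build steps below, here is a minimal, self-contained PyTorch sketch of the Quant/Dequant round trip described in the Fake Quantization note above. It is only an illustration under simple assumptions (per-tensor asymmetric 8-bit quantization; the `fake_quantize` helper is hypothetical) and is not the implementation used by this repo's utils/quantize.py.

```python
# Minimal sketch of the fake-quantization round trip (illustrative only, not this repo's code).
# The tensor stays float32, but after the round trip it holds at most 2^8 distinct values.
import torch

def fake_quantize(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    qmin, qmax = 0, 2 ** num_bits - 1
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min).clamp(min=1e-8) / (qmax - qmin)  # assumed per-tensor scale
    zero_point = torch.round(-x_min / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)  # Quant(Weight)
    return (q - zero_point) * scale                                   # Dequant(Quant(Weight))

w = torch.randn(64, 32)
w_fq = fake_quantize(w)
print(w_fq.dtype, w_fq.unique().numel())  # torch.float32, at most 256 unique values
```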
164 | You will need to install/build the following: 165 | [ncnn](https://github.com/Tencent/ncnn) 166 | [onnx-simplifier](https://github.com/daquexian/onnx-simplifier) 167 | 168 | inference_cls.cpp only implements MobileNetV2. The basic steps are: 169 | 170 | 1. Run convert_ncnn.py to convert the PyTorch model (with layer equalization and/or bias correction) to an ncnn int8 model and generate the calibration table file. The name of out_layer will be printed to the console. 171 | ``` 172 | python convert_ncnn.py --quantize --relu --equalize --correction 173 | ``` 174 | 175 | 2. Compile inference_cls.cpp: 176 | ``` 177 | mkdir build 178 | cd build 179 | cmake .. 180 | make 181 | ``` 182 | 3. Run inference! [link](https://github.com/Tencent/ncnn/wiki/quantized-int8-inference) 183 | ``` 184 | ./inference_cls --images=path_to_imagenet_validation_set --param=../modeling/ncnn/model_int8.param --bin=../modeling/ncnn/model_int8.bin --out_layer=name_from_step1 185 | ``` 186 | 187 | ## TODO 188 | - [x] cross layer equalization 189 | - [ ] high bias absorption 190 | - [x] data-free bias correction 191 | - [x] test with detection model 192 | - [x] test with classification model 193 | - [x] use distilled data to set min/max activation range 194 | - [ ] ~~use distilled data to find optimal scale matrix~~ 195 | - [ ] ~~use distilled data to do bias correction~~ 196 | - [x] True Int8 inference 197 | 198 | ## Acknowledgment 199 | - https://github.com/jfzhang95/pytorch-deeplab-xception 200 | - https://github.com/ricky40403/PyTransformer 201 | - https://github.com/qfgaohao/pytorch-ssd 202 | - https://github.com/tonylins/pytorch-mobilenet-v2 203 | - https://github.com/xxradon/PytorchToCaffe 204 | - https://github.com/amirgholami/ZeroQ 205 | -------------------------------------------------------------------------------- /ZeroQ/README.md: -------------------------------------------------------------------------------- 1 | # ZeroQ: A Novel Zero Shot Quantization Framework 2 | 3 | 4 | 5 | ## Introduction 6 | 7 | This repository contains the PyTorch implementation for the paper [*ZeroQ: A Novel Zero-Shot Quantization Framework*](https://arxiv.org/abs/2001.00281). 8 | 9 | ## TLDR; 10 | 11 | ```bash 12 | # Code is based on PyTorch 1.2 (CUDA 10). Other dependencies can be installed as follows: 13 | pip install -r requirements.txt --user 14 | # Set a symbolic link to ImageNet validation data (used only to evaluate the model) 15 | mkdir data 16 | ln -s /path/to/imagenet/ data/ 17 | ``` 18 | 19 | The folder structure should be the same as the following: 20 | ``` 21 | zeroq 22 | ├── utils 23 | ├── data 24 | │ ├── imagenet 25 | │ │ ├── val 26 | ``` 27 | Afterwards you can test zero-shot quantization with W8A8 by running: 28 | 29 | ```bash 30 | bash run.sh 31 | ``` 32 | 33 | Below are the results that you should get for 8-bit quantization (**W8A8** refers to quantizing the model to 8-bit weights and 8-bit activations).
34 | 35 | 36 | | Models | Single Precision Top-1 | W8A8 Top-1 | 37 | | ----------------------------------------------- | :--------------------: | :--------: | 38 | | [ResNet18](https://arxiv.org/abs/1512.03385) | 71.47 | 71.43 | 39 | | [ResNet50](https://arxiv.org/abs/1512.03385) | 77.72 | 77.67 | 40 | | [InceptionV3](https://arxiv.org/abs/1512.00567) | 78.88 | 78.72 | 41 | | [MobileNetV2](https://arxiv.org/abs/1801.04381) | 73.03 | 72.91 | 42 | | [ShuffleNet](https://arxiv.org/abs/1707.01083) | 65.07 | 64.94 | 43 | | [SqueezeNext](https://arxiv.org/abs/1803.10615) | 69.38 | 69.17 | 44 | 45 | ## Evaluate 46 | 47 | - You can test a single model using the following command: 48 | 49 | ```bash 50 | export CUDA_VISIBLE_DEVICES=0 51 | python uniform_test.py [--dataset] [--model] [--batch_size] [--test_batch_size] 52 | 53 | optional arguments: 54 | --dataset type of dataset (default: imagenet) 55 | --model model to be quantized (default: resnet18) 56 | --batch-size batch size of distilled data (default: 64) 57 | --test-batch-size batch size of test data (default: 512) 58 | ``` 59 | 60 | 61 | 62 | 63 | ## Citation 64 | ZeroQ has been developed as part of the following paper. We appreciate it if you would please cite the following paper if you found the implementation useful for your work: 65 | 66 | Y. Cai, Z. Yao, Z. Dong, A. Gholami, M. W. Mahoney, K. Keutzer. *ZeroQ: A Novel Zero Shot Quantization Framework*, under review [[PDF](https://arxiv.org/pdf/2001.00281.pdf)]. 67 | 68 | -------------------------------------------------------------------------------- /ZeroQ/reconstruct_data.py: -------------------------------------------------------------------------------- 1 | #* 2 | # @file Different utility functions 3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami 4 | # All rights reserved. 5 | # This file is part of ZeroQ repository. 6 | # 7 | # ZeroQ is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # ZeroQ is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with ZeroQ repository. If not, see . 19 | #* 20 | 21 | import os 22 | import json 23 | import torch 24 | import torch.nn as nn 25 | import copy 26 | import torch.optim as optim 27 | from utils import * 28 | 29 | 30 | def own_loss(A, B): 31 | """ 32 | L-2 loss between A and B normalized by length. 33 | A and B should have the same length 34 | """ 35 | return (A - B).norm()**2 / A.size(0) 36 | 37 | 38 | class output_hook(object): 39 | """ 40 | Forward_hook used to get the output of intermediate layer. 41 | """ 42 | def __init__(self): 43 | super(output_hook, self).__init__() 44 | self.outputs = None 45 | 46 | def hook(self, module, input, output): 47 | self.outputs = output 48 | 49 | def clear(self): 50 | self.outputs = None 51 | 52 | 53 | def getReconData(teacher_model, 54 | dataset, 55 | batch_size, 56 | num_batch=1, 57 | for_inception=False): 58 | """ 59 | Generate distilled data according to the BatchNorm statistics in pretrained single-precision model. 60 | Only support single GPU. 
61 | 62 | teacher_model: pretrained single-precision model 63 | dataset: the name of dataset 64 | batch_size: the batch size of generated distilled data 65 | num_batch: the number of batch of generated distilled data 66 | for_inception: whether the data is for Inception because inception has input size 299 rather than 224 67 | """ 68 | 69 | # initialize distilled data with random noise according to the dataset 70 | dataloader = getRandomData(dataset=dataset, 71 | batch_size=batch_size, 72 | for_inception=for_inception) 73 | 74 | eps = 1e-6 75 | # initialize hooks and single-precision model 76 | hooks, hook_handles, bn_stats, refined_gaussian = [], [], [], [] 77 | teacher_model = teacher_model.cuda() 78 | teacher_model = teacher_model.eval() 79 | 80 | # get number of BatchNorm layers in the model 81 | layers = sum([ 82 | 1 if isinstance(layer, nn.BatchNorm2d) else 0 83 | for layer in teacher_model.modules() 84 | ]) 85 | 86 | for n, m in teacher_model.named_modules(): 87 | if isinstance(m, nn.Conv2d) and len(hook_handles) < layers: 88 | # register hooks on the convolutional layers to get the intermediate output after convolution and before BatchNorm. 89 | hook = output_hook() 90 | hooks.append(hook) 91 | hook_handles.append(m.register_forward_hook(hook.hook)) 92 | if isinstance(m, nn.BatchNorm2d): 93 | # get the statistics in the BatchNorm layers 94 | bn_stats.append( 95 | (m.running_mean.detach().clone().flatten().cuda(), 96 | torch.sqrt(m.running_var + 97 | eps).detach().clone().flatten().cuda())) 98 | assert len(hooks) == len(bn_stats) 99 | 100 | for i, gaussian_data in enumerate(dataloader): 101 | if i == num_batch: 102 | break 103 | # initialize the criterion, optimizer, and scheduler 104 | gaussian_data = gaussian_data.cuda() 105 | gaussian_data.requires_grad = True 106 | crit = nn.CrossEntropyLoss().cuda() 107 | optimizer = optim.Adam([gaussian_data], lr=0.1) 108 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 109 | min_lr=1e-4, 110 | verbose=False, 111 | patience=100) 112 | 113 | input_mean = torch.zeros(1, 3).cuda() 114 | input_std = torch.ones(1, 3).cuda() 115 | 116 | for it in range(500): 117 | teacher_model.zero_grad() 118 | optimizer.zero_grad() 119 | for hook in hooks: 120 | hook.clear() 121 | output = teacher_model(gaussian_data) 122 | mean_loss = 0 123 | std_loss = 0 124 | 125 | # compute the loss according to the BatchNorm statistics and the statistics of intermediate output 126 | for cnt, (bn_stat, hook) in enumerate(zip(bn_stats, hooks)): 127 | tmp_output = hook.outputs 128 | bn_mean, bn_std = bn_stat[0], bn_stat[1] 129 | tmp_mean = torch.mean(tmp_output.view(tmp_output.size(0), 130 | tmp_output.size(1), -1), 131 | dim=2) 132 | tmp_std = torch.sqrt( 133 | torch.var(tmp_output.view(tmp_output.size(0), 134 | tmp_output.size(1), -1), 135 | dim=2) + eps) 136 | mean_loss += own_loss(bn_mean, tmp_mean) 137 | std_loss += own_loss(bn_std, tmp_std) 138 | tmp_mean = torch.mean(gaussian_data.view(gaussian_data.size(0), 3, 139 | -1), 140 | dim=2) 141 | tmp_std = torch.sqrt( 142 | torch.var(gaussian_data.view(gaussian_data.size(0), 3, -1), 143 | dim=2) + eps) 144 | mean_loss += own_loss(tmp_mean, input_mean) 145 | std_loss += own_loss(tmp_std, input_std) 146 | total_loss = mean_loss + std_loss 147 | 148 | # update the distilled data 149 | total_loss.backward() 150 | optimizer.step() 151 | scheduler.step(total_loss.item()) 152 | 153 | # early stop to prevent overfit 154 | if total_loss <= (layers + 1) * 5: 155 | break 156 | 157 | 
refined_gaussian.append(gaussian_data.detach().clone()) 158 | 159 | for handle in hook_handles: 160 | handle.remove() 161 | return refined_gaussian 162 | -------------------------------------------------------------------------------- /ZeroQ/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorchcv==0.0.51 2 | progressbar>=1.5 3 | -------------------------------------------------------------------------------- /ZeroQ/run.sh: -------------------------------------------------------------------------------- 1 | for MODEL in resnet18 resnet50 inceptionv3 mobilenetv2_w1 shufflenet_g1_w1 sqnxt23_w2 2 | do 3 | echo Testing $MODEL ... 4 | python uniform_test.py \ 5 | --dataset=imagenet \ 6 | --model=$MODEL \ 7 | --batch_size=64 \ 8 | --test_batch_size=512 9 | done 10 | -------------------------------------------------------------------------------- /ZeroQ/uniform_test.py: -------------------------------------------------------------------------------- 1 | #* 2 | # @file Different utility functions 3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami 4 | # All rights reserved. 5 | # This file is part of ZeroQ repository. 6 | # 7 | # ZeroQ is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # ZeroQ is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with ZeroQ repository. If not, see . 
19 | #* 20 | 21 | import argparse 22 | import torch 23 | import numpy as np 24 | import torch.nn as nn 25 | from pytorchcv.model_provider import get_model as ptcv_get_model 26 | from utils import * 27 | from distill_data import * 28 | 29 | 30 | # model settings 31 | def arg_parse(): 32 | parser = argparse.ArgumentParser( 33 | description='This repository contains the PyTorch implementation for the paper ZeroQ: A Novel Zero-Shot Quantization Framework.') 34 | parser.add_argument('--dataset', 35 | type=str, 36 | default='imagenet', 37 | choices=['imagenet', 'cifar10'], 38 | help='type of dataset') 39 | parser.add_argument('--model', 40 | type=str, 41 | default='resnet18', 42 | choices=[ 43 | 'resnet18', 'resnet50', 'inceptionv3', 44 | 'mobilenetv2_w1', 'shufflenet_g1_w1', 45 | 'resnet20_cifar10', 'sqnxt23_w2' 46 | ], 47 | help='model to be quantized') 48 | parser.add_argument('--batch_size', 49 | type=int, 50 | default=32, 51 | help='batch size of distilled data') 52 | parser.add_argument('--test_batch_size', 53 | type=int, 54 | default=128, 55 | help='batch size of test data') 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | if __name__ == '__main__': 61 | args = arg_parse() 62 | torch.backends.cudnn.deterministic = True 63 | torch.backends.cudnn.benchmark = False 64 | 65 | # Load pretrained model 66 | model = ptcv_get_model(args.model, pretrained=True) 67 | print('****** Full precision model loaded ******') 68 | 69 | # Load validation data 70 | test_loader = getTestData(args.dataset, 71 | batch_size=args.test_batch_size, 72 | path='./data/imagenet/', 73 | for_inception=args.model.startswith('inception')) 74 | # Generate distilled data 75 | dataloader = getDistilData( 76 | model.cuda(), 77 | args.dataset, 78 | batch_size=args.batch_size, 79 | for_inception=args.model.startswith('inception')) 80 | print('****** Data loaded ******') 81 | 82 | # Quantize single-precision model to 8-bit model 83 | quantized_model = quantize_model(model) 84 | # Freeze BatchNorm statistics 85 | quantized_model.eval() 86 | quantized_model = quantized_model.cuda() 87 | 88 | # Update activation range according to distilled data 89 | update(quantized_model, dataloader) 90 | 91 | # Freeze activation range during test 92 | freeze_model(quantized_model) 93 | quantized_model = nn.DataParallel(quantized_model).cuda() 94 | 95 | # Test the final quantized model 96 | test(quantized_model, test_loader) 97 | -------------------------------------------------------------------------------- /ZeroQ/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantize_model import * 2 | from .data_utils import * 3 | # from .train_utils import * 4 | from .quantization_utils.quant_utils import * -------------------------------------------------------------------------------- /ZeroQ/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | #* 2 | # @file Different utility functions 3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami 4 | # All rights reserved. 5 | # This file is part of ZeroQ repository. 6 | # 7 | # ZeroQ is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 
11 | # 12 | # ZeroQ is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with ZeroQ repository. If not, see . 19 | #* 20 | 21 | from torch.utils.data import Dataset, DataLoader 22 | from torchvision import datasets, transforms 23 | import torch 24 | 25 | 26 | class UniformDataset(Dataset): 27 | """ 28 | get random uniform samples with mean 0 and variance 1 29 | """ 30 | def __init__(self, length, size, transform, max_value): 31 | self.length = length 32 | self.transform = transform 33 | self.size = size 34 | self.max_value = max_value 35 | 36 | def __len__(self): 37 | return self.length 38 | 39 | def __getitem__(self, idx): 40 | # var[U(-128, 127)] = (127 - (-128))**2 / 12 = 5418.75 41 | # sample = (torch.randint(high=255, size=self.size).float() - 42 | # 127.5) / 5418.75 43 | sample = ((torch.randint(high=255, size=self.size).float() - 127.) / 128.) * self.max_value 44 | return sample 45 | 46 | 47 | def getRandomData(dataset='cifar10', batch_size=512, for_inception=False, max_value=3.0, size=[224, 224]): 48 | """ 49 | get random sample dataloader 50 | dataset: name of the dataset 51 | batch_size: the batch size of random data 52 | for_inception: whether the data is for Inception because inception has input size 299 rather than 224 53 | """ 54 | if dataset == 'cifar10': 55 | size = (3, 32, 32) 56 | num_data = 10000 57 | elif dataset == 'imagenet': 58 | num_data = 10000 59 | # if not for_inception: 60 | # size = (3, 224, 224) 61 | # else: 62 | # size = (3, 299, 299) 63 | size = (3, size[0], size[1]) 64 | else: 65 | raise NotImplementedError 66 | dataset = UniformDataset(length=10000, size=size, transform=None, max_value=max_value) 67 | data_loader = DataLoader(dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=0) 71 | return data_loader 72 | 73 | 74 | def getTestData(dataset='imagenet', 75 | batch_size=1024, 76 | path='data/imagenet', 77 | for_inception=False): 78 | """ 79 | Get dataloader of testset 80 | dataset: name of the dataset 81 | batch_size: the batch size of random data 82 | path: the path to the data 83 | for_inception: whether the data is for Inception because inception has input size 299 rather than 224 84 | """ 85 | if dataset == 'imagenet': 86 | input_size = 299 if for_inception else 224 87 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 88 | std=[0.229, 0.224, 0.225]) 89 | test_dataset = datasets.ImageFolder( 90 | path + 'val', 91 | transforms.Compose([ 92 | transforms.Resize(int(input_size / 0.875)), 93 | transforms.CenterCrop(input_size), 94 | transforms.ToTensor(), 95 | normalize, 96 | ])) 97 | test_loader = DataLoader(test_dataset, 98 | batch_size=batch_size, 99 | shuffle=False, 100 | num_workers=32) 101 | return test_loader 102 | elif dataset == 'cifar10': 103 | data_dir = '/rscratch/yaohuic/data/' 104 | normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 105 | std=(0.2023, 0.1994, 0.2010)) 106 | transform_test = transforms.Compose([transforms.ToTensor(), normalize]) 107 | 108 | test_dataset = datasets.CIFAR10(root=data_dir, 109 | train=False, 110 | transform=transform_test) 111 | test_loader = DataLoader(test_dataset, 112 | batch_size=batch_size, 113 | shuffle=False, 114 | num_workers=32) 115 | return test_loader 116 | 
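For orientation, here is a small usage sketch of the two data helpers defined above. It assumes the snippet is run from the ZeroQ directory (so `utils.data_utils` is importable) and that ImageNet validation data is linked under ./data/imagenet/; the batch sizes are arbitrary placeholders and the snippet is not part of the original repository.

```python
# Hypothetical usage of getRandomData / getTestData (not part of the original repo).
# getRandomData builds a DataLoader of uniform noise images, which ZeroQ uses as the
# starting point for distilled data; getTestData builds the ImageNet validation loader.
from utils.data_utils import getRandomData, getTestData

noise_loader = getRandomData(dataset='imagenet', batch_size=32, max_value=3.0)
test_loader = getTestData(dataset='imagenet', batch_size=128, path='./data/imagenet/')

noise_batch = next(iter(noise_loader))
print(noise_batch.shape)  # torch.Size([32, 3, 224, 224]) with the default size argument
```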
-------------------------------------------------------------------------------- /ZeroQ/utils/quantization_utils/quant_modules.py: -------------------------------------------------------------------------------- 1 | #* 2 | # @file Different utility functions 3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami 4 | # All rights reserved. 5 | # This file is part of ZeroQ repository. 6 | # 7 | # ZeroQ is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # ZeroQ is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with ZeroQ repository. If not, see . 19 | #* 20 | 21 | import torch 22 | import time 23 | import math 24 | import numpy as np 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | from torch.nn import Module, Parameter 28 | from .quant_utils import * 29 | import sys 30 | 31 | 32 | class QuantAct(Module): 33 | """ 34 | Class to quantize given activations 35 | """ 36 | def __init__(self, 37 | activation_bit, 38 | full_precision_flag=False, 39 | running_stat=True): 40 | """ 41 | activation_bit: bit-setting for activation 42 | full_precision_flag: full precision or not 43 | running_stat: determines whether the activation range is updated or froze 44 | """ 45 | super(QuantAct, self).__init__() 46 | self.activation_bit = activation_bit 47 | self.momentum = 0.99 48 | self.full_precision_flag = full_precision_flag 49 | self.running_stat = running_stat 50 | self.register_buffer('x_min', torch.zeros(1)) 51 | self.register_buffer('x_max', torch.zeros(1)) 52 | self.act_function = AsymmetricQuantFunction.apply 53 | 54 | def __repr__(self): 55 | return "{0}(activation_bit={1}, full_precision_flag={2}, running_stat={3}, Act_min: {4:.2f}, Act_max: {5:.2f})".format( 56 | self.__class__.__name__, self.activation_bit, 57 | self.full_precision_flag, self.running_stat, self.x_min.item(), 58 | self.x_max.item()) 59 | 60 | def fix(self): 61 | """ 62 | fix the activation range by setting running stat 63 | """ 64 | self.running_stat = False 65 | 66 | def forward(self, x): 67 | """ 68 | quantize given activation x 69 | """ 70 | if self.running_stat: 71 | x_min = x.data.min() 72 | x_max = x.data.max() 73 | # in-place operation used on multi-gpus 74 | self.x_min += -self.x_min + min(self.x_min, x_min) 75 | self.x_max += -self.x_max + max(self.x_max, x_max) 76 | 77 | if not self.full_precision_flag: 78 | quant_act = self.act_function(x, self.activation_bit, self.x_min, 79 | self.x_max) 80 | return quant_act 81 | else: 82 | return x 83 | 84 | 85 | class Quant_Linear(Module): 86 | """ 87 | Class to quantize given linear layer weights 88 | """ 89 | def __init__(self, weight_bit, full_precision_flag=False): 90 | """ 91 | weight: bit-setting for weight 92 | full_precision_flag: full precision or not 93 | running_stat: determines whether the activation range is updated or froze 94 | """ 95 | super(Quant_Linear, self).__init__() 96 | self.full_precision_flag = full_precision_flag 97 | self.weight_bit = weight_bit 98 | self.weight_function = AsymmetricQuantFunction.apply 99 | 100 | def __repr__(self): 101 | s = 
super(Quant_Linear, self).__repr__() 102 | s = "(" + s + " weight_bit={}, full_precision_flag={})".format( 103 | self.weight_bit, self.full_precision_flag) 104 | return s 105 | 106 | def set_param(self, linear): 107 | self.in_features = linear.in_features 108 | self.out_features = linear.out_features 109 | self.weight = Parameter(linear.weight.data.clone()) 110 | try: 111 | self.bias = Parameter(linear.bias.data.clone()) 112 | except AttributeError: 113 | self.bias = None 114 | 115 | def forward(self, x): 116 | """ 117 | using quantized weights to forward activation x 118 | """ 119 | w = self.weight 120 | x_transform = w.data.detach() 121 | w_min = x_transform.min(dim=1).values 122 | w_max = x_transform.max(dim=1).values 123 | if not self.full_precision_flag: 124 | w = self.weight_function(self.weight, self.weight_bit, w_min, 125 | w_max) 126 | else: 127 | w = self.weight 128 | return F.linear(x, weight=w, bias=self.bias) 129 | 130 | 131 | class Quant_Conv2d(Module): 132 | """ 133 | Class to quantize given convolutional layer weights 134 | """ 135 | def __init__(self, weight_bit, full_precision_flag=False): 136 | super(Quant_Conv2d, self).__init__() 137 | self.full_precision_flag = full_precision_flag 138 | self.weight_bit = weight_bit 139 | self.weight_function = AsymmetricQuantFunction.apply 140 | 141 | def __repr__(self): 142 | s = super(Quant_Conv2d, self).__repr__() 143 | s = "(" + s + " weight_bit={}, full_precision_flag={})".format( 144 | self.weight_bit, self.full_precision_flag) 145 | return s 146 | 147 | def set_param(self, conv): 148 | self.in_channels = conv.in_channels 149 | self.out_channels = conv.out_channels 150 | self.kernel_size = conv.kernel_size 151 | self.stride = conv.stride 152 | self.padding = conv.padding 153 | self.dilation = conv.dilation 154 | self.groups = conv.groups 155 | self.weight = Parameter(conv.weight.data.clone()) 156 | try: 157 | self.bias = Parameter(conv.bias.data.clone()) 158 | except AttributeError: 159 | self.bias = None 160 | 161 | def forward(self, x): 162 | """ 163 | using quantized weights to forward activation x 164 | """ 165 | w = self.weight 166 | x_transform = w.data.contiguous().view(self.out_channels, -1) 167 | w_min = x_transform.min(dim=1).values 168 | w_max = x_transform.max(dim=1).values 169 | if not self.full_precision_flag: 170 | w = self.weight_function(self.weight, self.weight_bit, w_min, 171 | w_max) 172 | else: 173 | w = self.weight 174 | 175 | return F.conv2d(x, w, self.bias, self.stride, self.padding, 176 | self.dilation, self.groups) 177 | -------------------------------------------------------------------------------- /ZeroQ/utils/quantization_utils/quant_utils.py: -------------------------------------------------------------------------------- 1 | #* 2 | # @file Different utility functions 3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami 4 | # All rights reserved. 5 | # This file is part of ZeroQ repository. 6 | # 7 | # ZeroQ is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # ZeroQ is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 
16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with ZeroQ repository. If not, see . 19 | #* 20 | 21 | import math 22 | import numpy as np 23 | from torch.autograd import Function, Variable 24 | import torch 25 | 26 | 27 | def clamp(input, min, max, inplace=False): 28 | """ 29 | Clamp tensor input to (min, max). 30 | input: input tensor to be clamped 31 | """ 32 | 33 | if inplace: 34 | input.clamp_(min, max) 35 | return input 36 | return torch.clamp(input, min, max) 37 | 38 | 39 | def linear_quantize(input, scale, zero_point, inplace=False): 40 | """ 41 | Quantize single-precision input tensor to integers with the given scaling factor and zeropoint. 42 | input: single-precision input tensor to be quantized 43 | scale: scaling factor for quantization 44 | zero_pint: shift for quantization 45 | """ 46 | 47 | # reshape scale and zeropoint for convolutional weights and activation 48 | if len(input.shape) == 4: 49 | scale = scale.view(-1, 1, 1, 1) 50 | zero_point = zero_point.view(-1, 1, 1, 1) 51 | # reshape scale and zeropoint for linear weights 52 | elif len(input.shape) == 2: 53 | scale = scale.view(-1, 1) 54 | zero_point = zero_point.view(-1, 1) 55 | # mapping single-precision input to integer values with the given scale and zeropoint 56 | if inplace: 57 | input.mul_(scale).sub_(zero_point).round_() 58 | return input 59 | return torch.round(scale * input - zero_point) 60 | 61 | 62 | def linear_dequantize(input, scale, zero_point, inplace=False): 63 | """ 64 | Map integer input tensor to fixed point float point with given scaling factor and zeropoint. 65 | input: integer input tensor to be mapped 66 | scale: scaling factor for quantization 67 | zero_pint: shift for quantization 68 | """ 69 | 70 | # reshape scale and zeropoint for convolutional weights and activation 71 | if len(input.shape) == 4: 72 | scale = scale.view(-1, 1, 1, 1) 73 | zero_point = zero_point.view(-1, 1, 1, 1) 74 | # reshape scale and zeropoint for linear weights 75 | elif len(input.shape) == 2: 76 | scale = scale.view(-1, 1) 77 | zero_point = zero_point.view(-1, 1) 78 | # mapping integer input to fixed point float point value with given scaling factor and zeropoint 79 | if inplace: 80 | input.add_(zero_point).div_(scale) 81 | return input 82 | return (input + zero_point) / scale 83 | 84 | 85 | def asymmetric_linear_quantization_params(num_bits, 86 | saturation_min, 87 | saturation_max, 88 | integral_zero_point=True, 89 | signed=True): 90 | """ 91 | Compute the scaling factor and zeropoint with the given quantization range. 92 | saturation_min: lower bound for quantization range 93 | saturation_max: upper bound for quantization range 94 | """ 95 | n = 2**num_bits - 1 96 | scale = n / torch.clamp((saturation_max - saturation_min), min=1e-8) 97 | zero_point = scale * saturation_min 98 | 99 | if integral_zero_point: 100 | if isinstance(zero_point, torch.Tensor): 101 | zero_point = zero_point.round() 102 | else: 103 | zero_point = float(round(zero_point)) 104 | if signed: 105 | zero_point += 2**(num_bits - 1) 106 | return scale, zero_point 107 | 108 | 109 | class AsymmetricQuantFunction(Function): 110 | """ 111 | Class to quantize the given floating-point values with given range and bit-setting. 112 | Currently only support inference, but not support back-propagation. 
113 | """ 114 | @staticmethod 115 | def forward(ctx, x, k, x_min=None, x_max=None): 116 | """ 117 | x: single-precision value to be quantized 118 | k: bit-setting for x 119 | x_min: lower bound for quantization range 120 | x_max=None 121 | """ 122 | 123 | if x_min is None or x_max is None or (sum(x_min == x_max) == 1 124 | and x_min.numel() == 1): 125 | x_min, x_max = x.min(), x.max() 126 | scale, zero_point = asymmetric_linear_quantization_params( 127 | k, x_min, x_max) 128 | new_quant_x = linear_quantize(x, scale, zero_point, inplace=False) 129 | n = 2**(k - 1) 130 | new_quant_x = torch.clamp(new_quant_x, -n, n - 1) 131 | quant_x = linear_dequantize(new_quant_x, 132 | scale, 133 | zero_point, 134 | inplace=False) 135 | return torch.autograd.Variable(quant_x) 136 | 137 | @staticmethod 138 | def backward(ctx, grad_output): 139 | raise NotImplementedError 140 | -------------------------------------------------------------------------------- /ZeroQ/utils/quantize_model.py: -------------------------------------------------------------------------------- 1 | #* 2 | # @file Different utility functions 3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami 4 | # All rights reserved. 5 | # This file is part of ZeroQ repository. 6 | # 7 | # ZeroQ is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # ZeroQ is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with ZeroQ repository. If not, see . 
19 | #* 20 | 21 | import torch 22 | import torch.nn as nn 23 | import copy 24 | from .quantization_utils.quant_modules import * 25 | from pytorchcv.models.common import ConvBlock 26 | from pytorchcv.models.shufflenetv2 import ShuffleUnit, ShuffleInitBlock 27 | 28 | 29 | def quantize_model(model): 30 | """ 31 | Recursively quantize a pretrained single-precision model to int8 quantized model 32 | model: pretrained single-precision model 33 | """ 34 | 35 | # quantize convolutional and linear layers to 8-bit 36 | if type(model) == nn.Conv2d: 37 | quant_mod = Quant_Conv2d(weight_bit=8) 38 | quant_mod.set_param(model) 39 | return quant_mod 40 | elif type(model) == nn.Linear: 41 | quant_mod = Quant_Linear(weight_bit=8) 42 | quant_mod.set_param(model) 43 | return quant_mod 44 | 45 | # quantize all the activation to 8-bit 46 | elif type(model) == nn.ReLU or type(model) == nn.ReLU6: 47 | return nn.Sequential(*[model, QuantAct(activation_bit=8)]) 48 | 49 | # recursively use the quantized module to replace the single-precision module 50 | elif type(model) == nn.Sequential: 51 | mods = [] 52 | for n, m in model.named_children(): 53 | mods.append(quantize_model(m)) 54 | return nn.Sequential(*mods) 55 | else: 56 | q_model = copy.deepcopy(model) 57 | for attr in dir(model): 58 | mod = getattr(model, attr) 59 | if isinstance(mod, nn.Module) and 'norm' not in attr: 60 | setattr(q_model, attr, quantize_model(mod)) 61 | return q_model 62 | 63 | 64 | def freeze_model(model): 65 | """ 66 | freeze the activation range 67 | """ 68 | if type(model) == QuantAct: 69 | model.fix() 70 | elif type(model) == nn.Sequential: 71 | mods = [] 72 | for n, m in model.named_children(): 73 | freeze_model(m) 74 | else: 75 | for attr in dir(model): 76 | mod = getattr(model, attr) 77 | if isinstance(mod, nn.Module) and 'norm' not in attr: 78 | freeze_model(mod) 79 | return model 80 | 81 | 82 | def unfreeze_model(model): 83 | """ 84 | unfreeze the activation range 85 | """ 86 | if type(model) == QuantAct: 87 | model.unfix() 88 | elif type(model) == nn.Sequential: 89 | mods = [] 90 | for n, m in model.named_children(): 91 | unfreeze_model(m) 92 | else: 93 | for attr in dir(model): 94 | mod = getattr(model, attr) 95 | if isinstance(mod, nn.Module) and 'norm' not in attr: 96 | unfreeze_model(mod) 97 | return model 98 | -------------------------------------------------------------------------------- /ZeroQ/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | #* 2 | # @file Different utility functions 3 | # Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami 4 | # All rights reserved. 5 | # This file is part of ZeroQ repository. 6 | # 7 | # ZeroQ is free software: you can redistribute it and/or modify 8 | # it under the terms of the GNU General Public License as published by 9 | # the Free Software Foundation, either version 3 of the License, or 10 | # (at your option) any later version. 11 | # 12 | # ZeroQ is distributed in the hope that it will be useful, 13 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | # GNU General Public License for more details. 16 | # 17 | # You should have received a copy of the GNU General Public License 18 | # along with ZeroQ repository. If not, see . 
19 | #* 20 | 21 | import torch 22 | import os 23 | import torch.nn as nn 24 | from progress.bar import Bar 25 | 26 | 27 | def test(model, test_loader): 28 | """ 29 | test a model on a given dataset 30 | """ 31 | total, correct = 0, 0 32 | bar = Bar('Testing', max=len(test_loader)) 33 | model.eval() 34 | with torch.no_grad(): 35 | for batch_idx, (inputs, targets) in enumerate(test_loader): 36 | inputs, targets = inputs.cuda(), targets.cuda() 37 | outputs = model(inputs) 38 | _, predicted = outputs.max(1) 39 | total += targets.size(0) 40 | correct += predicted.eq(targets).sum().item() 41 | acc = correct / total 42 | 43 | bar.suffix = f'({batch_idx + 1}/{len(test_loader)}) | ETA: {bar.eta_td} | top1: {acc}' 44 | bar.next() 45 | print('\nFinal acc: %.2f%% (%d/%d)' % (100. * acc, correct, total)) 46 | bar.finish() 47 | model.train() 48 | return acc 49 | 50 | 51 | def update(quantized_model, distilD): 52 | """ 53 | Update activation range according to distilled data 54 | quantized_model: a quantized model whose activation range to be updated 55 | distilD: distilled data 56 | """ 57 | print('******updateing BN stats...', end='') 58 | with torch.no_grad(): 59 | for batch_idx, inputs in enumerate(distilD): 60 | if isinstance(inputs, list): 61 | inputs = inputs[0] 62 | inputs = inputs.cuda() 63 | outputs = quantized_model(inputs) 64 | print(' Finished******') 65 | return quantized_model 66 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/dataset/detection/__init__.py -------------------------------------------------------------------------------- /dataset/detection/open_images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pathlib 3 | import cv2 4 | import pandas as pd 5 | import copy 6 | 7 | class OpenImagesDataset: 8 | 9 | def __init__(self, root, 10 | transform=None, target_transform=None, 11 | dataset_type="train", balance_data=False): 12 | self.root = pathlib.Path(root) 13 | self.transform = transform 14 | self.target_transform = target_transform 15 | self.dataset_type = dataset_type.lower() 16 | 17 | self.data, self.class_names, self.class_dict = self._read_data() 18 | self.balance_data = balance_data 19 | self.min_image_num = -1 20 | if self.balance_data: 21 | self.data = self._balance_data() 22 | self.ids = [info['image_id'] for info in self.data] 23 | 24 | self.class_stat = None 25 | 26 | def _getitem(self, index): 27 | image_info = self.data[index] 28 | image = self._read_image(image_info['image_id']) 29 | # duplicate boxes to prevent corruption of dataset 30 | boxes = copy.copy(image_info['boxes']) 31 | boxes[:, 0] *= image.shape[1] 32 | boxes[:, 1] *= image.shape[0] 33 | boxes[:, 2] *= image.shape[1] 34 | boxes[:, 3] *= image.shape[0] 35 | # duplicate labels to prevent corruption of dataset 36 | labels = copy.copy(image_info['labels']) 37 | if self.transform: 38 | image, boxes, labels = self.transform(image, boxes, labels) 39 | if self.target_transform: 40 | boxes, labels = 
self.target_transform(boxes, labels) 41 | return image_info['image_id'], image, boxes, labels 42 | 43 | def __getitem__(self, index): 44 | _, image, boxes, labels = self._getitem(index) 45 | return image, boxes, labels 46 | 47 | def get_annotation(self, index): 48 | """To conform the eval_ssd implementation that is based on the VOC dataset.""" 49 | image_id, image, boxes, labels = self._getitem(index) 50 | is_difficult = np.zeros(boxes.shape[0], dtype=np.uint8) 51 | return image_id, (boxes, labels, is_difficult) 52 | 53 | def get_image(self, index): 54 | image_info = self.data[index] 55 | image = self._read_image(image_info['image_id']) 56 | if self.transform: 57 | image, _ = self.transform(image) 58 | return image 59 | 60 | def _read_data(self): 61 | annotation_file = f"{self.root}/sub-{self.dataset_type}-annotations-bbox.csv" 62 | annotations = pd.read_csv(annotation_file) 63 | class_names = ['BACKGROUND'] + sorted(list(annotations['ClassName'].unique())) 64 | class_dict = {class_name: i for i, class_name in enumerate(class_names)} 65 | data = [] 66 | for image_id, group in annotations.groupby("ImageID"): 67 | boxes = group.loc[:, ["XMin", "YMin", "XMax", "YMax"]].values.astype(np.float32) 68 | # make labels 64 bits to satisfy the cross_entropy function 69 | labels = np.array([class_dict[name] for name in group["ClassName"]], dtype='int64') 70 | data.append({ 71 | 'image_id': image_id, 72 | 'boxes': boxes, 73 | 'labels': labels 74 | }) 75 | return data, class_names, class_dict 76 | 77 | def __len__(self): 78 | return len(self.data) 79 | 80 | def __repr__(self): 81 | if self.class_stat is None: 82 | self.class_stat = {name: 0 for name in self.class_names[1:]} 83 | for example in self.data: 84 | for class_index in example['labels']: 85 | class_name = self.class_names[class_index] 86 | self.class_stat[class_name] += 1 87 | content = ["Dataset Summary:" 88 | f"Number of Images: {len(self.data)}", 89 | f"Minimum Number of Images for a Class: {self.min_image_num}", 90 | "Label Distribution:"] 91 | for class_name, num in self.class_stat.items(): 92 | content.append(f"\t{class_name}: {num}") 93 | return "\n".join(content) 94 | 95 | def _read_image(self, image_id): 96 | image_file = self.root / self.dataset_type / f"{image_id}.jpg" 97 | image = cv2.imread(str(image_file)) 98 | if image.shape[2] == 1: 99 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 100 | else: 101 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 102 | return image 103 | 104 | def _balance_data(self): 105 | label_image_indexes = [set() for _ in range(len(self.class_names))] 106 | for i, image in enumerate(self.data): 107 | for label_id in image['labels']: 108 | label_image_indexes[label_id].add(i) 109 | label_stat = [len(s) for s in label_image_indexes] 110 | self.min_image_num = min(label_stat[1:]) 111 | sample_image_indexes = set() 112 | for image_indexes in label_image_indexes[1:]: 113 | image_indexes = np.array(list(image_indexes)) 114 | sub = np.random.permutation(image_indexes)[:self.min_image_num] 115 | sample_image_indexes.update(sub) 116 | sample_data = [self.data[i] for i in sample_image_indexes] 117 | return sample_data 118 | 119 | 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /dataset/detection/voc_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import pathlib 4 | import xml.etree.ElementTree as ET 5 | import cv2 6 | import os 7 | 8 | 9 | class VOCDataset: 10 | 11 | 
def __init__(self, root, transform=None, target_transform=None, is_test=False, keep_difficult=False, label_file=None): 12 | """Dataset for VOC data. 13 | Args: 14 | root: the root of the VOC2007 or VOC2012 dataset, the directory contains the following sub-directories: 15 | Annotations, ImageSets, JPEGImages, SegmentationClass, SegmentationObject. 16 | """ 17 | self.root = pathlib.Path(root) 18 | self.transform = transform 19 | self.target_transform = target_transform 20 | if is_test: 21 | image_sets_file = self.root / "ImageSets/Main/test.txt" 22 | else: 23 | image_sets_file = self.root / "ImageSets/Main/val.txt" 24 | # image_sets_file = self.root / "ImageSets/Main/trainval.txt" 25 | self.ids = VOCDataset._read_image_ids(image_sets_file) 26 | self.keep_difficult = keep_difficult 27 | 28 | # if the labels file exists, read in the class names 29 | label_file_name = self.root / "labels.txt" 30 | 31 | if os.path.isfile(label_file_name): 32 | class_string = "" 33 | with open(label_file_name, 'r') as infile: 34 | for line in infile: 35 | class_string += line.rstrip() 36 | 37 | # classes should be a comma separated list 38 | 39 | classes = class_string.split(',') 40 | # prepend BACKGROUND as first class 41 | classes.insert(0, 'BACKGROUND') 42 | classes = [ elem.replace(" ", "") for elem in classes] 43 | self.class_names = tuple(classes) 44 | logging.info("VOC Labels read from file: " + str(self.class_names)) 45 | 46 | else: 47 | logging.info("No labels file, using default VOC classes.") 48 | self.class_names = ('BACKGROUND', 49 | 'aeroplane', 'bicycle', 'bird', 'boat', 50 | 'bottle', 'bus', 'car', 'cat', 'chair', 51 | 'cow', 'diningtable', 'dog', 'horse', 52 | 'motorbike', 'person', 'pottedplant', 53 | 'sheep', 'sofa', 'train', 'tvmonitor') 54 | 55 | 56 | self.class_dict = {class_name: i for i, class_name in enumerate(self.class_names)} 57 | 58 | def __getitem__(self, index): 59 | image_id = self.ids[index] 60 | boxes, labels, is_difficult = self._get_annotation(image_id) 61 | if not self.keep_difficult: 62 | boxes = boxes[is_difficult == 0] 63 | labels = labels[is_difficult == 0] 64 | image = self._read_image(image_id) 65 | if self.transform: 66 | image, boxes, labels = self.transform(image, boxes, labels) 67 | if self.target_transform: 68 | boxes, labels = self.target_transform(boxes, labels) 69 | return image, boxes, labels 70 | 71 | def get_image(self, index): 72 | image_id = self.ids[index] 73 | image = self._read_image(image_id) 74 | if self.transform: 75 | image, _ = self.transform(image) 76 | return image 77 | 78 | def get_annotation(self, index): 79 | image_id = self.ids[index] 80 | return image_id, self._get_annotation(image_id) 81 | 82 | def __len__(self): 83 | return len(self.ids) 84 | 85 | @staticmethod 86 | def _read_image_ids(image_sets_file): 87 | ids = [] 88 | with open(image_sets_file) as f: 89 | for line in f: 90 | ids.append(line.rstrip()) 91 | return ids 92 | 93 | def _get_annotation(self, image_id): 94 | annotation_file = self.root / f"Annotations/{image_id}.xml" 95 | objects = ET.parse(annotation_file).findall("object") 96 | boxes = [] 97 | labels = [] 98 | is_difficult = [] 99 | for object in objects: 100 | class_name = object.find('name').text.lower().strip() 101 | # we're only concerned with clases in our list 102 | if class_name in self.class_dict: 103 | bbox = object.find('bndbox') 104 | 105 | # VOC dataset format follows Matlab, in which indexes start from 0 106 | x1 = float(bbox.find('xmin').text) - 1 107 | y1 = float(bbox.find('ymin').text) - 1 108 | x2 = 
float(bbox.find('xmax').text) - 1 109 | y2 = float(bbox.find('ymax').text) - 1 110 | boxes.append([x1, y1, x2, y2]) 111 | 112 | labels.append(self.class_dict[class_name]) 113 | is_difficult_str = object.find('difficult').text 114 | is_difficult.append(int(is_difficult_str) if is_difficult_str else 0) 115 | 116 | return (np.array(boxes, dtype=np.float32), 117 | np.array(labels, dtype=np.int64), 118 | np.array(is_difficult, dtype=np.uint8)) 119 | 120 | def _read_image(self, image_id): 121 | image_file = self.root / f"JPEGImages/{image_id}.jpg" 122 | image = cv2.imread(str(image_file)) 123 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 124 | return image 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /dataset/segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/dataset/segmentation/__init__.py -------------------------------------------------------------------------------- /dataset/segmentation/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | from PIL import Image, ImageOps, ImageFilter 6 | 7 | class Normalize(object): 8 | """Normalize a tensor image with mean and standard deviation. 9 | Args: 10 | mean (tuple): means for each channel. 11 | std (tuple): standard deviations for each channel. 12 | """ 13 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 14 | self.mean = mean 15 | self.std = std 16 | 17 | def __call__(self, sample): 18 | img = sample['image'] 19 | mask = sample['label'] 20 | img = np.array(img).astype(np.float32) 21 | mask = np.array(mask).astype(np.float32) 22 | img /= 255.0 23 | img -= self.mean 24 | img /= self.std 25 | 26 | return {'image': img, 27 | 'label': mask} 28 | 29 | 30 | class ToTensor(object): 31 | """Convert ndarrays in sample to Tensors.""" 32 | 33 | def __call__(self, sample): 34 | # swap color axis because 35 | # numpy image: H x W x C 36 | # torch image: C X H X W 37 | img = sample['image'] 38 | mask = sample['label'] 39 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 40 | mask = np.array(mask).astype(np.float32) 41 | 42 | img = torch.from_numpy(img).float() 43 | mask = torch.from_numpy(mask).float() 44 | 45 | return {'image': img, 46 | 'label': mask} 47 | 48 | 49 | class RandomHorizontalFlip(object): 50 | def __call__(self, sample): 51 | img = sample['image'] 52 | mask = sample['label'] 53 | if random.random() < 0.5: 54 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 55 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 56 | 57 | return {'image': img, 58 | 'label': mask} 59 | 60 | 61 | class RandomRotate(object): 62 | def __init__(self, degree): 63 | self.degree = degree 64 | 65 | def __call__(self, sample): 66 | img = sample['image'] 67 | mask = sample['label'] 68 | rotate_degree = random.uniform(-1*self.degree, self.degree) 69 | img = img.rotate(rotate_degree, Image.BILINEAR) 70 | mask = mask.rotate(rotate_degree, Image.NEAREST) 71 | 72 | return {'image': img, 73 | 'label': mask} 74 | 75 | 76 | class RandomGaussianBlur(object): 77 | def __call__(self, sample): 78 | img = sample['image'] 79 | mask = sample['label'] 80 | if random.random() < 0.5: 81 | img = img.filter(ImageFilter.GaussianBlur( 82 | radius=random.random())) 83 | 84 | return {'image': img, 85 | 'label': mask} 86 | 87 | 88 | class RandomScaleCrop(object): 89 | def 
__init__(self, base_size, crop_size, fill=0): 90 | self.base_size = base_size 91 | self.crop_size = crop_size 92 | self.fill = fill 93 | 94 | def __call__(self, sample): 95 | img = sample['image'] 96 | mask = sample['label'] 97 | # random scale (short edge) 98 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 99 | w, h = img.size 100 | if h > w: 101 | ow = short_size 102 | oh = int(1.0 * h * ow / w) 103 | else: 104 | oh = short_size 105 | ow = int(1.0 * w * oh / h) 106 | img = img.resize((ow, oh), Image.BILINEAR) 107 | mask = mask.resize((ow, oh), Image.NEAREST) 108 | # pad crop 109 | if short_size < self.crop_size: 110 | padh = self.crop_size - oh if oh < self.crop_size else 0 111 | padw = self.crop_size - ow if ow < self.crop_size else 0 112 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 113 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 114 | # random crop crop_size 115 | w, h = img.size 116 | x1 = random.randint(0, w - self.crop_size) 117 | y1 = random.randint(0, h - self.crop_size) 118 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 119 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 120 | 121 | return {'image': img, 122 | 'label': mask} 123 | 124 | 125 | class FixScaleCrop(object): 126 | def __init__(self, crop_size): 127 | self.crop_size = crop_size 128 | 129 | def __call__(self, sample): 130 | img = sample['image'] 131 | mask = sample['label'] 132 | w, h = img.size 133 | if w > h: 134 | oh = self.crop_size 135 | ow = int(1.0 * w * oh / h) 136 | else: 137 | ow = self.crop_size 138 | oh = int(1.0 * h * ow / w) 139 | img = img.resize((ow, oh), Image.BILINEAR) 140 | mask = mask.resize((ow, oh), Image.NEAREST) 141 | # center crop 142 | w, h = img.size 143 | x1 = int(round((w - self.crop_size) / 2.)) 144 | y1 = int(round((h - self.crop_size) / 2.)) 145 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 146 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 147 | 148 | return {'image': img, 149 | 'label': mask} 150 | 151 | class FixedResize(object): 152 | def __init__(self, size): 153 | self.size = (size, size) # size: (h, w) 154 | 155 | def __call__(self, sample): 156 | img = sample['image'] 157 | mask = sample['label'] 158 | 159 | assert img.size == mask.size 160 | 161 | img = img.resize(self.size, Image.BILINEAR) 162 | mask = mask.resize(self.size, Image.NEAREST) 163 | 164 | return {'image': img, 165 | 'label': mask} -------------------------------------------------------------------------------- /dataset/segmentation/pascal.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | # from mypath import Path 7 | from torchvision import transforms 8 | 9 | import dataset.segmentation.custom_transforms as tr 10 | import cv2 11 | 12 | class VOCSegmentation(Dataset): 13 | """ 14 | PascalVoc dataset 15 | """ 16 | NUM_CLASSES = 21 17 | 18 | def __init__(self, 19 | args, 20 | base_dir='/media/jakc4103/Toshiba/workspace/dataset/VOCdevkit/VOC2012/', 21 | split='val', 22 | label='SegmentationClass' 23 | ): 24 | """ 25 | :param base_dir: path to VOC dataset directory 26 | :param split: train/val 27 | :param transform: transform to apply 28 | :param label: SegmentationObject/SegmentationClass 29 | """ 30 | super().__init__() 31 | self._base_dir = base_dir 32 
| self._image_dir = os.path.join(self._base_dir, 'JPEGImages') 33 | self._cat_dir = os.path.join(self._base_dir, label) 34 | 35 | if isinstance(split, str): 36 | self.split = [split] 37 | else: 38 | split.sort() 39 | self.split = split 40 | 41 | self.args = args 42 | 43 | _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation') 44 | 45 | self.im_ids = [] 46 | self.images = [] 47 | self.categories = [] 48 | 49 | for splt in self.split: 50 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: 51 | lines = f.read().splitlines() 52 | 53 | for ii, line in enumerate(lines): 54 | _image = os.path.join(self._image_dir, line + ".jpg") 55 | _cat = os.path.join(self._cat_dir, line + ".png") 56 | assert os.path.isfile(_image) 57 | assert os.path.isfile(_cat) 58 | self.im_ids.append(line) 59 | self.images.append(_image) 60 | self.categories.append(_cat) 61 | 62 | assert (len(self.images) == len(self.categories)) 63 | 64 | # Display stats 65 | print('Number of images in {}: {:d}'.format(split, len(self.images))) 66 | 67 | def __len__(self): 68 | return len(self.images) 69 | 70 | 71 | def __getitem__(self, index): 72 | _img, _target = self._make_img_gt_point_pair(index) 73 | sample = {'image': _img, 'label': _target} 74 | 75 | for split in self.split: 76 | if split == "train": 77 | return self.transform_tr(sample) 78 | elif split == 'val' or split == 'test': 79 | return self.transform_val(sample) 80 | 81 | 82 | def _make_img_gt_point_pair(self, index): 83 | _img = Image.open(self.images[index]).convert('RGB') 84 | _target = Image.open(self.categories[index]) 85 | 86 | # test = np.array(_target) 87 | # print(test.shape) 88 | # print(np.unique(test)) 89 | # cv2.imshow('test', test) 90 | # cv2.waitKey(0) 91 | 92 | return _img, _target 93 | 94 | def transform_tr(self, sample): 95 | composed_transforms = transforms.Compose([ 96 | tr.RandomHorizontalFlip(), 97 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 98 | tr.RandomGaussianBlur(), 99 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 100 | tr.ToTensor()]) 101 | 102 | return composed_transforms(sample) 103 | 104 | def transform_val(self, sample): 105 | 106 | composed_transforms = transforms.Compose([ 107 | tr.FixScaleCrop(crop_size=self.args.crop_size), 108 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 109 | tr.ToTensor()]) 110 | 111 | return composed_transforms(sample) 112 | 113 | def __str__(self): 114 | return 'VOC2012(split=' + str(self.split) + ')' 115 | 116 | 117 | if __name__ == '__main__': 118 | from utils import decode_segmap 119 | from torch.utils.data import DataLoader 120 | import matplotlib.pyplot as plt 121 | import argparse 122 | 123 | parser = argparse.ArgumentParser() 124 | args = parser.parse_args() 125 | args.base_size = 513 126 | args.crop_size = 513 127 | 128 | voc_train = VOCSegmentation(args, split='val') 129 | 130 | dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0) 131 | 132 | for ii, sample in enumerate(dataloader): 133 | for jj in range(sample["image"].size()[0]): 134 | img = sample['image'].numpy() 135 | gt = sample['label'].numpy() 136 | print(np.unique(gt)) 137 | tmp = np.array(gt[jj]).astype(np.uint8) 138 | segmap = decode_segmap(tmp, dataset='pascal') 139 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 140 | img_tmp *= (0.229, 0.224, 0.225) 141 | img_tmp += (0.485, 0.456, 0.406) 142 | img_tmp *= 255.0 143 | img_tmp = img_tmp.astype(np.uint8) 144 | plt.figure() 145 | 
plt.title('display') 146 | plt.subplot(211) 147 | plt.imshow(img_tmp) 148 | plt.subplot(212) 149 | plt.imshow(segmap) 150 | 151 | if ii == 0: 152 | break 153 | 154 | plt.show(block=True) 155 | 156 | 157 | -------------------------------------------------------------------------------- /dataset/segmentation/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 6 | rgb_masks = [] 7 | for label_mask in label_masks: 8 | rgb_mask = decode_segmap(label_mask, dataset) 9 | rgb_masks.append(rgb_mask) 10 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 11 | return rgb_masks 12 | 13 | 14 | def decode_segmap(label_mask, dataset, plot=False): 15 | """Decode segmentation class labels into a color image 16 | Args: 17 | label_mask (np.ndarray): an (M,N) array of integer values denoting 18 | the class label at each spatial location. 19 | plot (bool, optional): whether to show the resulting color image 20 | in a figure. 21 | Returns: 22 | (np.ndarray, optional): the resulting decoded color image. 23 | """ 24 | if dataset == 'pascal' or dataset == 'coco': 25 | n_classes = 21 26 | label_colours = get_pascal_labels() 27 | elif dataset == 'cityscapes': 28 | n_classes = 19 29 | label_colours = get_cityscapes_labels() 30 | else: 31 | raise NotImplementedError 32 | 33 | r = label_mask.copy() 34 | g = label_mask.copy() 35 | b = label_mask.copy() 36 | for ll in range(0, n_classes): 37 | r[label_mask == ll] = label_colours[ll, 0] 38 | g[label_mask == ll] = label_colours[ll, 1] 39 | b[label_mask == ll] = label_colours[ll, 2] 40 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 41 | rgb[:, :, 0] = r / 255.0 42 | rgb[:, :, 1] = g / 255.0 43 | rgb[:, :, 2] = b / 255.0 44 | if plot: 45 | plt.imshow(rgb) 46 | plt.show() 47 | else: 48 | return rgb 49 | 50 | 51 | def encode_segmap(mask): 52 | """Encode segmentation label images as pascal classes 53 | Args: 54 | mask (np.ndarray): raw segmentation label image of dimension 55 | (M, N, 3), in which the Pascal classes are encoded as colours. 56 | Returns: 57 | (np.ndarray): class map with dimensions (M,N), where the value at 58 | a given location is the integer denoting the class index. 
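        Example (derived from get_pascal_labels() below): a pixel coloured
          (128, 0, 0) -- the VOC 'aeroplane' colour at index 1 of the palette --
          is mapped to class index 1; colours not found in the palette stay 0
          (background).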
59 | """ 60 | mask = mask.astype(int) 61 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 62 | for ii, label in enumerate(get_pascal_labels()): 63 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 64 | label_mask = label_mask.astype(int) 65 | return label_mask 66 | 67 | 68 | def get_cityscapes_labels(): 69 | return np.array([ 70 | [128, 64, 128], 71 | [244, 35, 232], 72 | [70, 70, 70], 73 | [102, 102, 156], 74 | [190, 153, 153], 75 | [153, 153, 153], 76 | [250, 170, 30], 77 | [220, 220, 0], 78 | [107, 142, 35], 79 | [152, 251, 152], 80 | [0, 130, 180], 81 | [220, 20, 60], 82 | [255, 0, 0], 83 | [0, 0, 142], 84 | [0, 0, 70], 85 | [0, 60, 100], 86 | [0, 80, 100], 87 | [0, 0, 230], 88 | [119, 11, 32]]) 89 | 90 | 91 | def get_pascal_labels(): 92 | """Load the mapping that associates pascal classes with label colors 93 | Returns: 94 | np.ndarray with dimensions (21, 3) 95 | """ 96 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 97 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 98 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 99 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 100 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 101 | [0, 64, 128]]) -------------------------------------------------------------------------------- /images/LE_distill.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/images/LE_distill.png -------------------------------------------------------------------------------- /images/graph_cls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/images/graph_cls.png -------------------------------------------------------------------------------- /images/graph_deeplab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/images/graph_deeplab.png -------------------------------------------------------------------------------- /images/graph_ssd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/images/graph_ssd.png -------------------------------------------------------------------------------- /inference_cls.cpp: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the open source community by making ncnn available. 2 | // 3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 
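//
// inference_cls.cpp -- standalone accuracy check for a quantized ncnn model.
// It globs every image under --images (an ImageNet-style layout with one
// synset folder per class, e.g. n01440764/), derives the ground-truth label
// from the folder order, runs each image through the network given by
// --param/--bin, and prints the running top-1 accuracy taken from the layer
// named by --out_layer (see detect_net below).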
14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "platform.h" 24 | #include "net.h" 25 | 26 | #if NCNN_VULKAN 27 | #include "gpu.h" 28 | #endif // NCNN_VULKAN 29 | 30 | int parse_images_dir(const std::string& base_path, std::vector& file_path) 31 | { 32 | file_path.clear(); 33 | 34 | const cv::String base_path_str(base_path); 35 | std::vector image_list; 36 | 37 | cv::glob(base_path_str, image_list, true); 38 | 39 | for (size_t i = 0; i < image_list.size(); i++) 40 | { 41 | const cv::String& image_path = image_list[i]; 42 | file_path.push_back(image_path); 43 | } 44 | 45 | return 0; 46 | } 47 | 48 | static int print_topk(const std::vector& cls_scores, int topk) 49 | { 50 | // partial sort topk with index 51 | int size = cls_scores.size(); 52 | std::vector< std::pair > vec; 53 | vec.resize(size); 54 | for (int i=0; i >()); 61 | int pred_idx; 62 | // print topk and score 63 | for (int i=0; i& image_list, std::vector& cls_scores, 78 | const std::string ncnn_param_file_path, const std::string ncnn_bin_file_path, const std::string out_layer) 79 | { 80 | ncnn::Net net; 81 | size_t size = image_list.size(); 82 | printf("Number of images: %lu\n", size); 83 | 84 | #if NCNN_VULKAN 85 | net.opt.use_vulkan_compute = true; 86 | #endif // NCNN_VULKAN 87 | 88 | net.load_param(&ncnn_param_file_path[0]); 89 | net.load_model(&ncnn_bin_file_path[0]); 90 | 91 | const float mean_vals[3] = {0.485f*255.f, 0.456f*255.f, 0.406f*255.f}; 92 | const float std_vals[3] = {1/0.229f/255.f, 1/0.224f/255.f, 1/0.225f/255.f}; 93 | int correct_count = 0; 94 | int label = -1; 95 | std::string folder_name = "dummy"; 96 | for (size_t i = 0; i < image_list.size(); i++) 97 | { 98 | 99 | std::string img_name = image_list[i]; 100 | 101 | std::istringstream f(img_name); 102 | std::string s; 103 | while(std::getline(f, s, '/')) 104 | { 105 | if((s.substr(0, 2) == "n0" || s.substr(0, 2) == "n1") && s.size() == 9 && folder_name != s) 106 | { 107 | label++; 108 | folder_name = s; 109 | } 110 | } 111 | 112 | if ((i + 1) % 1000 == 0) 113 | { 114 | fprintf(stderr, " %d/%d, acc:%f\n", static_cast(i + 1), static_cast(size), static_cast(correct_count)/static_cast(i)); 115 | } 116 | 117 | #if OpenCV_VERSION_MAJOR > 2 118 | cv::Mat bgr = cv::imread(img_name, cv::IMREAD_COLOR); 119 | #else 120 | cv::Mat bgr = cv::imread(img_name, CV_LOAD_IMAGE_COLOR); 121 | #endif 122 | if (bgr.empty()) 123 | { 124 | fprintf(stderr, "cv::imread %s failed\n", img_name.c_str()); 125 | return -1; 126 | } 127 | 128 | ncnn::Mat resized = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, bgr.cols, bgr.rows, 256, 256); 129 | ncnn::Mat in; 130 | ncnn::copy_cut_border(resized, in, 16, 16, 16, 16); 131 | in.substract_mean_normalize(mean_vals, std_vals); 132 | 133 | ncnn::Extractor ex = net.create_extractor(); 134 | ex.set_num_threads(2); 135 | 136 | ex.input("0", in); 137 | 138 | ncnn::Mat out; 139 | ex.extract(&out_layer[0], out); 140 | 141 | cls_scores.resize(out.w); 142 | for (int j=0; j(correct_count)/static_cast(size)); 155 | return 0; 156 | } 157 | 158 | int main(int argc, char** argv) 159 | { 160 | const char* key_map = 161 | "{help h usage ? 
| | print this message }" 162 | "{param p | | path to ncnn.param file }" 163 | "{bin b | | path to ncnn.bin file }" 164 | "{images i | | path to calibration images folder }" 165 | "{out_layer o | | name of the final layer (innerproduct or softmax) }" 166 | ; 167 | 168 | cv::CommandLineParser parser(argc, argv, key_map); 169 | const std::string image_folder_path = parser.get("images"); 170 | const std::string ncnn_param_file_path = parser.get("param"); 171 | const std::string ncnn_bin_file_path = parser.get("bin"); 172 | const std::string out_layer = parser.get("out_layer"); 173 | 174 | // check the input param 175 | if (image_folder_path.empty() || ncnn_param_file_path.empty() || ncnn_bin_file_path.empty()) 176 | { 177 | fprintf(stderr, "One or more path may be empty, please check and try again.\n"); 178 | return 0; 179 | } 180 | 181 | // parse the image file. 182 | std::vector image_file_path_list; 183 | parse_images_dir(image_folder_path, image_file_path_list); 184 | 185 | #if NCNN_VULKAN 186 | ncnn::create_gpu_instance(); 187 | #endif // NCNN_VULKAN 188 | 189 | std::vector cls_scores; 190 | detect_net(image_file_path_list, cls_scores, ncnn_param_file_path, ncnn_bin_file_path, out_layer); 191 | 192 | #if NCNN_VULKAN 193 | ncnn::destroy_gpu_instance(); 194 | #endif // NCNN_VULKAN 195 | 196 | return 0; 197 | } 198 | -------------------------------------------------------------------------------- /main_seg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | import argparse 9 | 10 | from modeling.segmentation.deeplab import DeepLab 11 | from torch.utils.data import DataLoader 12 | from dataset.segmentation.pascal import VOCSegmentation 13 | from utils.metrics import Evaluator 14 | 15 | from utils.relation import create_relation 16 | from dfq import cross_layer_equalization, bias_absorption, bias_correction, clip_weight 17 | from utils.layer_transform import switch_layers, replace_op, restore_op, set_quant_minmax, merge_batchnorm, quantize_targ_layer#, LayerTransform 18 | from PyTransformer.transformers.torchTransformer import TorchTransformer 19 | from utils.quantize import QuantConv2d, QuantNConv2d, QuantMeasure, QConv2d, set_layer_bits 20 | from ZeroQ.distill_data import getDistilData 21 | from improve_dfq import update_scale, transform_quant_layer, set_scale, update_quant_range, set_update_stat, bias_correction_distill 22 | 23 | def get_argument(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--quantize", action='store_true') 26 | parser.add_argument("--equalize", action='store_true') 27 | parser.add_argument("--correction", action='store_true') 28 | parser.add_argument("--absorption", action='store_true') 29 | parser.add_argument("--distill_range", action='store_true') 30 | parser.add_argument("--log", action='store_true') 31 | parser.add_argument("--relu", action='store_true') 32 | parser.add_argument("--clip_weight", action='store_true') 33 | parser.add_argument("--dataset", type=str, default="voc12") 34 | parser.add_argument("--trainable", action='store_true') 35 | parser.add_argument("--bits_weight", type=int, default=8) 36 | parser.add_argument("--bits_activation", type=int, default=8) 37 | parser.add_argument("--bits_bias", type=int, default=8) 38 | return parser.parse_args() 39 | 40 | def estimate_stats(model, state_dict, data, num_epoch=10, 
path_save='modeling/data_dependent_QuantConv2dAdd.pth'): 41 | import copy 42 | 43 | # model = DeepLab(sync_bn=False) 44 | model.eval() 45 | 46 | model = model.cuda() 47 | 48 | args = lambda: 0 49 | args.base_size = 513 50 | args.crop_size = 513 51 | voc_val = VOCSegmentation(args, split='train') 52 | dataloader = DataLoader(voc_val, batch_size=32, shuffle=True, num_workers=0) 53 | model.train() 54 | 55 | replace_op() 56 | ss = time.time() 57 | with torch.no_grad(): 58 | for epoch in range(num_epoch): 59 | start = time.time() 60 | for sample in dataloader: 61 | image, _ = sample['image'].cuda(), sample['label'].cuda() 62 | 63 | _ = model(image) 64 | 65 | end = time.time() 66 | print("epoch {}: {} sec.".format(epoch, end-start)) 67 | print('total time: {} sec'.format(time.time() - ss)) 68 | restore_op() 69 | 70 | # load 'running_mean' and 'running_var' of batchnorm back from pre-trained parameters 71 | bn_dict = {} 72 | for key in state_dict: 73 | if 'running' in key: 74 | bn_dict[key] = state_dict[key] 75 | 76 | state = model.state_dict() 77 | state.update(bn_dict) 78 | model.load_state_dict(state) 79 | 80 | torch.save(model.state_dict(), path_save) 81 | 82 | return model 83 | 84 | 85 | def inference_all(model, dataset='voc12', opt=None): 86 | print("Start inference") 87 | from utils.segmentation.utils import forward_all 88 | args = lambda: 0 89 | args.base_size = 513 90 | args.crop_size = 513 91 | if dataset == 'voc12': 92 | voc_val = VOCSegmentation(args, base_dir="/home/jakc4103/WDesktop/dataset/VOCdevkit/VOC2012/", split='val') 93 | elif dataset == 'voc07': 94 | voc_val = VOCSegmentation(args, base_dir="/home/jakc4103/WDesktop/dataset/VOCdevkit/VOC2007/", split='test') 95 | dataloader = DataLoader(voc_val, batch_size=32, shuffle=False, num_workers=2) 96 | 97 | forward_all(model, dataloader, visualize=False, opt=opt) 98 | 99 | 100 | def main(): 101 | args = get_argument() 102 | assert args.relu or args.relu == args.equalize, 'must replace relu6 to relu while equalization' 103 | assert args.equalize or args.absorption == args.equalize, 'must use absorption with equalize' 104 | data = torch.ones((4, 3, 513, 513))#.cuda() 105 | 106 | model = DeepLab(sync_bn=False) 107 | state_dict = torch.load('modeling/segmentation/deeplab-mobilenet.pth.tar')['state_dict'] 108 | model.load_state_dict(state_dict) 109 | model.eval() 110 | if args.distill_range: 111 | import copy 112 | # define FP32 model 113 | model_original = copy.deepcopy(model) 114 | model_original.eval() 115 | transformer = TorchTransformer() 116 | transformer._build_graph(model_original, data, [QuantMeasure]) 117 | graph = transformer.log.getGraph() 118 | bottoms = transformer.log.getBottoms() 119 | 120 | data_distill = getDistilData(model_original, 'imagenet', 32, bn_merged=False,\ 121 | num_batch=8, gpu=True, value_range=[-2.11790393, 2.64], size=[513, 513], early_break_factor=0.2) 122 | 123 | transformer = TorchTransformer() 124 | 125 | module_dict = {} 126 | if args.quantize: 127 | if args.distill_range: 128 | module_dict[1] = [(nn.Conv2d, QConv2d)] 129 | elif args.trainable: 130 | module_dict[1] = [(nn.Conv2d, QuantConv2d)] 131 | else: 132 | module_dict[1] = [(nn.Conv2d, QuantNConv2d)] 133 | 134 | if args.relu: 135 | module_dict[0] = [(torch.nn.ReLU6, torch.nn.ReLU)] 136 | 137 | # transformer.summary(model, data) 138 | # transformer.visualize(model, data, 'graph_deeplab', graph_size=120) 139 | 140 | model, transformer = switch_layers(model, transformer, data, module_dict, ignore_layer=[QuantMeasure], quant_op=args.quantize) 141 | 
graph = transformer.log.getGraph() 142 | bottoms = transformer.log.getBottoms() 143 | 144 | if args.quantize: 145 | if args.distill_range: 146 | targ_layer = [QConv2d] 147 | elif args.trainable: 148 | targ_layer = [QuantConv2d] 149 | else: 150 | targ_layer = [QuantNConv2d] 151 | else: 152 | targ_layer = [nn.Conv2d] 153 | if args.quantize: 154 | set_layer_bits(graph, args.bits_weight, args.bits_activation, args.bits_bias, targ_layer) 155 | model = merge_batchnorm(model, graph, bottoms, targ_layer) 156 | 157 | #create relations 158 | if args.equalize or args.distill_range: 159 | res = create_relation(graph, bottoms, targ_layer) 160 | if args.equalize: 161 | cross_layer_equalization(graph, res, targ_layer, visualize_state=False) 162 | 163 | # if args.distill: 164 | # set_scale(res, graph, bottoms, targ_layer) 165 | 166 | if args.absorption: 167 | bias_absorption(graph, res, bottoms, 3) 168 | 169 | if args.clip_weight: 170 | clip_weight(graph, range_clip=[-15, 15], targ_type=targ_layer) 171 | 172 | if args.correction: 173 | bias_correction(graph, bottoms, targ_layer) 174 | 175 | if args.quantize: 176 | if not args.trainable and not args.distill_range: 177 | graph = quantize_targ_layer(graph, args.bits_weight, args.bits_bias, targ_layer) 178 | 179 | if args.distill_range: 180 | set_update_stat(model, [QuantMeasure], True) 181 | model = update_quant_range(model.cuda(), data_distill, graph, bottoms) 182 | set_update_stat(model, [QuantMeasure], False) 183 | else: 184 | set_quant_minmax(graph, bottoms) 185 | 186 | torch.cuda.empty_cache() 187 | 188 | model = model.cuda() 189 | model.eval() 190 | 191 | if args.quantize: 192 | replace_op() 193 | inference_all(model, args.dataset, args if args.log else None) 194 | if args.quantize: 195 | restore_op() 196 | 197 | 198 | if __name__ == '__main__': 199 | main() -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/__init__.py -------------------------------------------------------------------------------- /modeling/classification/MobileNetV2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | def conv_bn(inp, oup, stride): 7 | return nn.Sequential( 8 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 9 | nn.BatchNorm2d(oup), 10 | nn.ReLU6(inplace=True) 11 | ) 12 | 13 | 14 | def conv_1x1_bn(inp, oup): 15 | return nn.Sequential( 16 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 17 | nn.BatchNorm2d(oup), 18 | nn.ReLU6(inplace=True) 19 | ) 20 | 21 | 22 | def make_divisible(x, divisible_by=8): 23 | import numpy as np 24 | return int(np.ceil(x * 1. 
/ divisible_by) * divisible_by) 25 | 26 | 27 | class InvertedResidual(nn.Module): 28 | def __init__(self, inp, oup, stride, expand_ratio): 29 | super(InvertedResidual, self).__init__() 30 | self.stride = stride 31 | assert stride in [1, 2] 32 | 33 | hidden_dim = int(inp * expand_ratio) 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 40 | nn.BatchNorm2d(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 50 | nn.BatchNorm2d(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 54 | nn.BatchNorm2d(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 58 | nn.BatchNorm2d(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | if self.use_res_connect: 63 | return x + self.conv(x) 64 | else: 65 | return self.conv(x) 66 | 67 | 68 | class MobileNetV2(nn.Module): 69 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 70 | super(MobileNetV2, self).__init__() 71 | block = InvertedResidual 72 | input_channel = 32 73 | last_channel = 1280 74 | interverted_residual_setting = [ 75 | # t, c, n, s 76 | [1, 16, 1, 1], 77 | [6, 24, 2, 2], 78 | [6, 32, 3, 2], 79 | [6, 64, 4, 2], 80 | [6, 96, 3, 1], 81 | [6, 160, 3, 2], 82 | [6, 320, 1, 1], 83 | ] 84 | 85 | # building first layer 86 | assert input_size % 32 == 0 87 | # input_channel = make_divisible(input_channel * width_mult) # first channel is always 32! 88 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 89 | self.features = [conv_bn(3, input_channel, 2)] 90 | # building inverted residual blocks 91 | for t, c, n, s in interverted_residual_setting: 92 | output_channel = make_divisible(c * width_mult) if t > 1 else c 93 | for i in range(n): 94 | if i == 0: 95 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 96 | else: 97 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 98 | input_channel = output_channel 99 | # building last several layers 100 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 101 | # make it nn.Sequential 102 | self.features = nn.Sequential(*self.features) 103 | 104 | # building classifier 105 | self.classifier = nn.Linear(self.last_channel, n_class) 106 | 107 | self._initialize_weights() 108 | 109 | def forward(self, x): 110 | x = self.features(x) 111 | # x = x.mean(3).mean(2) 112 | x = torch.mean(x.view(x.size(0), x.size(1), -1), -1) 113 | x = self.classifier(x) 114 | return x 115 | 116 | def _initialize_weights(self): 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 121 | if m.bias is not None: 122 | m.bias.data.zero_() 123 | elif isinstance(m, nn.BatchNorm2d): 124 | m.weight.data.fill_(1) 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.Linear): 127 | n = m.weight.size(1) 128 | m.weight.data.normal_(0, 0.01) 129 | m.bias.data.zero_() 130 | 131 | 132 | def mobilenet_v2(path_weight=None): 133 | model = MobileNetV2(width_mult=1) 134 | 135 | if path_weight is not None: 136 | print("load weight: {}".format(path_weight)) 137 | # state_dict = load_state_dict_from_url( 138 | # 'https://www.dropbox.com/s/47tyzpofuuyyv1b/mobilenetv2_1.0-f2a8633.pth.tar?dl=1', progress=True) 139 | state_dict = torch.load(path_weight) 140 | 141 | model.load_state_dict(state_dict) 142 | return model 143 | 144 | 145 | if __name__ == '__main__': 146 | # 'modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar' 147 | net = mobilenet_v2('./mobilenetv2_1.0-f2a8633.pth.tar') 148 | 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/classification/mobilenetv2_1.0-f2a8633.pth.tar -------------------------------------------------------------------------------- /modeling/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/__init__.py -------------------------------------------------------------------------------- /modeling/detection/config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/config/__init__.py -------------------------------------------------------------------------------- /modeling/detection/config/mobilenetv1_ssd_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.detection.box_utils import SSDSpec, SSDBoxSizes, generate_ssd_priors 4 | 5 | 6 | image_size = 300 7 | image_mean = np.array([127, 127, 127]) # RGB layout 8 | image_std = 128.0 9 | iou_threshold = 0.45 10 | center_variance = 0.1 11 | size_variance = 0.2 12 | 13 | specs = [ 14 | SSDSpec(19, 16, SSDBoxSizes(60, 105), [2, 3]), 15 | SSDSpec(10, 32, SSDBoxSizes(105, 150), [2, 3]), 16 | SSDSpec(5, 64, SSDBoxSizes(150, 195), [2, 3]), 17 | SSDSpec(3, 100, SSDBoxSizes(195, 240), [2, 3]), 18 | SSDSpec(2, 150, SSDBoxSizes(240, 285), [2, 3]), 19 | SSDSpec(1, 300, SSDBoxSizes(285, 330), [2, 3]) 20 | ] 21 | 22 | 23 | priors = generate_ssd_priors(specs, image_size) -------------------------------------------------------------------------------- /modeling/detection/config/squeezenet_ssd_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.detection.box_utils import SSDSpec, SSDBoxSizes, generate_ssd_priors 4 | 5 | 6 | image_size = 300 7 | image_mean = np.array([127, 127, 127]) # RGB layout 8 | image_std = 128.0 9 | iou_threshold = 0.45 10 | center_variance = 0.1 11 | size_variance = 0.2 12 | 13 | specs = [ 14 | SSDSpec(17, 16, SSDBoxSizes(60, 105), [2, 3]), 15 | SSDSpec(10, 32, SSDBoxSizes(105, 150), [2, 3]), 16 | SSDSpec(5, 64, SSDBoxSizes(150, 195), [2, 3]), 17 | 
SSDSpec(3, 100, SSDBoxSizes(195, 240), [2, 3]), 18 | SSDSpec(2, 150, SSDBoxSizes(240, 285), [2, 3]), 19 | SSDSpec(1, 300, SSDBoxSizes(285, 330), [2, 3]) 20 | ] 21 | 22 | 23 | priors = generate_ssd_priors(specs, image_size) -------------------------------------------------------------------------------- /modeling/detection/config/vgg_ssd_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from utils.detection.box_utils import SSDSpec, SSDBoxSizes, generate_ssd_priors 4 | 5 | 6 | image_size = 300 7 | image_mean = np.array([123, 117, 104]) # RGB layout 8 | image_std = 1.0 9 | 10 | iou_threshold = 0.45 11 | center_variance = 0.1 12 | size_variance = 0.2 13 | 14 | specs = [ 15 | SSDSpec(38, 8, SSDBoxSizes(30, 60), [2]), 16 | SSDSpec(19, 16, SSDBoxSizes(60, 111), [2, 3]), 17 | SSDSpec(10, 32, SSDBoxSizes(111, 162), [2, 3]), 18 | SSDSpec(5, 64, SSDBoxSizes(162, 213), [2, 3]), 19 | SSDSpec(3, 100, SSDBoxSizes(213, 264), [2]), 20 | SSDSpec(1, 300, SSDBoxSizes(264, 315), [2]) 21 | ] 22 | 23 | 24 | priors = generate_ssd_priors(specs, image_size) -------------------------------------------------------------------------------- /modeling/detection/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | from .transforms.transforms import * 2 | 3 | 4 | class TrainAugmentation: 5 | def __init__(self, size, mean=0, std=1.0): 6 | """ 7 | Args: 8 | size: the size the of final image. 9 | mean: mean pixel value per channel. 10 | """ 11 | self.mean = mean 12 | self.size = size 13 | self.augment = Compose([ 14 | ConvertFromInts(), 15 | PhotometricDistort(), 16 | Expand(self.mean), 17 | RandomSampleCrop(), 18 | RandomMirror(), 19 | ToPercentCoords(), 20 | Resize(self.size), 21 | SubtractMeans(self.mean), 22 | lambda img, boxes=None, labels=None: (img / std, boxes, labels), 23 | ToTensor(), 24 | ]) 25 | 26 | def __call__(self, img, boxes, labels): 27 | """ 28 | 29 | Args: 30 | img: the output of cv.imread in RGB layout. 31 | boxes: boundding boxes in the form of (x1, y1, x2, y2). 32 | labels: labels of boxes. 
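            Returns:
                (image, boxes, labels) after augmentation: the image as a
                mean/std-normalized float tensor and the boxes rescaled to
                percent coordinates by ToPercentCoords.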
33 | """ 34 | return self.augment(img, boxes, labels) 35 | 36 | 37 | class TestTransform: 38 | def __init__(self, size, mean=0.0, std=1.0): 39 | self.transform = Compose([ 40 | ToPercentCoords(), 41 | Resize(size), 42 | SubtractMeans(mean), 43 | lambda img, boxes=None, labels=None: (img / std, boxes, labels), 44 | ToTensor(), 45 | ]) 46 | 47 | def __call__(self, image, boxes, labels): 48 | return self.transform(image, boxes, labels) 49 | 50 | 51 | class PredictionTransform: 52 | def __init__(self, size, mean=0.0, std=1.0): 53 | self.transform = Compose([ 54 | Resize(size), 55 | SubtractMeans(mean), 56 | lambda img, boxes=None, labels=None: (img / std, boxes, labels), 57 | ToTensor() 58 | ]) 59 | 60 | def __call__(self, image): 61 | image, _, _ = self.transform(image) 62 | return image -------------------------------------------------------------------------------- /modeling/detection/fpn_mobilenetv1_ssd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU 3 | from .nn.mobilenet import MobileNetV1 4 | 5 | from .fpn_ssd import FPNSSD 6 | from .predictor import Predictor 7 | from .config import mobilenetv1_ssd_config as config 8 | 9 | 10 | def create_fpn_mobilenetv1_ssd(num_classes): 11 | base_net = MobileNetV1(1001).features # disable dropout layer 12 | 13 | source_layer_indexes = [ 14 | (69, Conv2d(in_channels=512, out_channels=256, kernel_size=1)), 15 | (len(base_net), Conv2d(in_channels=1024, out_channels=256, kernel_size=1)), 16 | ] 17 | extras = ModuleList([ 18 | Sequential( 19 | Conv2d(in_channels=1024, out_channels=256, kernel_size=1), 20 | ReLU(), 21 | Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1), 22 | ReLU() 23 | ), 24 | Sequential( 25 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 26 | ReLU(), 27 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 28 | ReLU() 29 | ), 30 | Sequential( 31 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 32 | ReLU(), 33 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 34 | ReLU() 35 | ), 36 | Sequential( 37 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 38 | ReLU(), 39 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 40 | ReLU() 41 | ) 42 | ]) 43 | 44 | regression_headers = ModuleList([ 45 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 46 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 47 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 48 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 49 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 50 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0? 51 | ]) 52 | 53 | classification_headers = ModuleList([ 54 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 55 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 56 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 57 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 58 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 59 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0? 
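        # one classification header per feature map, matching the six SSDSpec
        # scales defined in mobilenetv1_ssd_config.specs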
60 | ]) 61 | 62 | return FPNSSD(num_classes, base_net, source_layer_indexes, 63 | extras, classification_headers, regression_headers) 64 | 65 | 66 | def create_fpn_mobilenetv1_ssd_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=torch.device('cpu')): 67 | predictor = Predictor(net, config.image_size, config.image_mean, config.priors, 68 | config.center_variance, config.size_variance, 69 | nms_method=nms_method, 70 | iou_threshold=config.iou_threshold, 71 | candidate_size=candidate_size, 72 | sigma=sigma, 73 | device=device) 74 | return predictor 75 | -------------------------------------------------------------------------------- /modeling/detection/fpn_ssd.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from typing import List, Tuple 6 | 7 | from utils.detection import box_utils 8 | 9 | 10 | class FPNSSD(nn.Module): 11 | def __init__(self, num_classes: int, base_net: nn.ModuleList, source_layer_indexes: List[int], 12 | extras: nn.ModuleList, classification_headers: nn.ModuleList, 13 | regression_headers: nn.ModuleList, upsample_mode="nearest"): 14 | """Compose a SSD model using the given components. 15 | """ 16 | super(FPNSSD, self).__init__() 17 | 18 | self.num_classes = num_classes 19 | self.base_net = base_net 20 | self.source_layer_indexes = source_layer_indexes 21 | self.extras = extras 22 | self.classification_headers = classification_headers 23 | self.regression_headers = regression_headers 24 | self.upsample_mode = upsample_mode 25 | 26 | # register layers in source_layer_indexes by adding them to a module list 27 | self.source_layer_add_ons = nn.ModuleList([t[1] for t in source_layer_indexes if isinstance(t, tuple)]) 28 | self.upsamplers = [ 29 | nn.Upsample(size=(19, 19), mode='bilinear'), 30 | nn.Upsample(size=(10, 10), mode='bilinear'), 31 | nn.Upsample(size=(5, 5), mode='bilinear'), 32 | nn.Upsample(size=(3, 3), mode='bilinear'), 33 | nn.Upsample(size=(2, 2), mode='bilinear'), 34 | ] 35 | 36 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 37 | confidences = [] 38 | locations = [] 39 | start_layer_index = 0 40 | header_index = 0 41 | features = [] 42 | for end_layer_index in self.source_layer_indexes: 43 | 44 | if isinstance(end_layer_index, tuple): 45 | added_layer = end_layer_index[1] 46 | end_layer_index = end_layer_index[0] 47 | else: 48 | added_layer = None 49 | for layer in self.base_net[start_layer_index: end_layer_index]: 50 | x = layer(x) 51 | start_layer_index = end_layer_index 52 | if added_layer: 53 | y = added_layer(x) 54 | else: 55 | y = x 56 | #confidence, location = self.compute_header(header_index, y) 57 | features.append(y) 58 | header_index += 1 59 | # confidences.append(confidence) 60 | # locations.append(location) 61 | 62 | for layer in self.base_net[end_layer_index:]: 63 | x = layer(x) 64 | 65 | for layer in self.extras: 66 | x = layer(x) 67 | #confidence, location = self.compute_header(header_index, x) 68 | features.append(x) 69 | header_index += 1 70 | # confidences.append(confidence) 71 | # locations.append(location) 72 | 73 | upstream_feature = None 74 | for i in range(len(features) - 1, -1, -1): 75 | feature = features[i] 76 | if upstream_feature is not None: 77 | upstream_feature = self.upsamplers[i](upstream_feature) 78 | upstream_feature += feature 79 | else: 80 | upstream_feature = feature 81 | confidence, location = self.compute_header(i, upstream_feature) 
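                # (this loop runs deepest-first -- top-down FPN fusion via
                # self.upsamplers -- so the confidence/location pairs below are
                # appended in reverse order of the source layers)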
82 | confidences.append(confidence) 83 | locations.append(location) 84 | confidences = torch.cat(confidences, 1) 85 | locations = torch.cat(locations, 1) 86 | return confidences, locations 87 | 88 | def compute_header(self, i, x): 89 | confidence = self.classification_headers[i](x) 90 | confidence = confidence.permute(0, 2, 3, 1).contiguous() 91 | confidence = confidence.view(confidence.size(0), -1, self.num_classes) 92 | 93 | location = self.regression_headers[i](x) 94 | location = location.permute(0, 2, 3, 1).contiguous() 95 | location = location.view(location.size(0), -1, 4) 96 | 97 | return confidence, location 98 | 99 | def init_from_base_net(self, model): 100 | self.base_net.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage), strict=False) 101 | self.source_layer_add_ons.apply(_xavier_init_) 102 | self.extras.apply(_xavier_init_) 103 | self.classification_headers.apply(_xavier_init_) 104 | self.regression_headers.apply(_xavier_init_) 105 | 106 | def init(self): 107 | self.base_net.apply(_xavier_init_) 108 | self.source_layer_add_ons.apply(_xavier_init_) 109 | self.extras.apply(_xavier_init_) 110 | self.classification_headers.apply(_xavier_init_) 111 | self.regression_headers.apply(_xavier_init_) 112 | 113 | def load(self, model): 114 | self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage)) 115 | 116 | def save(self, model_path): 117 | torch.save(self.state_dict(), model_path) 118 | 119 | 120 | class MatchPrior(object): 121 | def __init__(self, center_form_priors, center_variance, size_variance, iou_threshold): 122 | self.center_form_priors = center_form_priors 123 | self.corner_form_priors = box_utils.center_form_to_corner_form(center_form_priors) 124 | self.center_variance = center_variance 125 | self.size_variance = size_variance 126 | self.iou_threshold = iou_threshold 127 | 128 | def __call__(self, gt_boxes, gt_labels): 129 | if type(gt_boxes) is np.ndarray: 130 | gt_boxes = torch.from_numpy(gt_boxes) 131 | if type(gt_labels) is np.ndarray: 132 | gt_labels = torch.from_numpy(gt_labels) 133 | boxes, labels = box_utils.assign_priors(gt_boxes, gt_labels, 134 | self.corner_form_priors, self.iou_threshold) 135 | boxes = box_utils.corner_form_to_center_form(boxes) 136 | locations = box_utils.convert_boxes_to_locations(boxes, self.center_form_priors, self.center_variance, self.size_variance) 137 | return locations, labels 138 | 139 | 140 | def _xavier_init_(m: nn.Module): 141 | if isinstance(m, nn.Conv2d): 142 | nn.init.xavier_uniform_(m.weight) 143 | -------------------------------------------------------------------------------- /modeling/detection/mb2-ssd-lite-mp-0_686.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/mb2-ssd-lite-mp-0_686.pth -------------------------------------------------------------------------------- /modeling/detection/mobilenet_v2_ssd_lite.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Conv2d, Sequential, ModuleList, BatchNorm2d 3 | from torch import nn 4 | from .nn.mobilenet_v2 import MobileNetV2, InvertedResidual 5 | 6 | from .ssd import SSD, GraphPath 7 | from .predictor import Predictor 8 | from .config import mobilenetv1_ssd_config as config 9 | 10 | 11 | def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, onnx_compatible=False): 12 | """Replace Conv2d 
with a depthwise Conv2d and Pointwise Conv2d. 13 | """ 14 | ReLU = nn.ReLU if onnx_compatible else nn.ReLU6 15 | return Sequential( 16 | Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, 17 | groups=in_channels, stride=stride, padding=padding), 18 | BatchNorm2d(in_channels), 19 | ReLU(), 20 | Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1), 21 | ) 22 | 23 | 24 | def create_mobilenetv2_ssd_lite(num_classes, width_mult=1.0, use_batch_norm=True, onnx_compatible=False, is_test=False, quantize=False): 25 | base_net = MobileNetV2(width_mult=width_mult, use_batch_norm=use_batch_norm, 26 | onnx_compatible=onnx_compatible).features 27 | 28 | source_layer_indexes = [ 29 | GraphPath(14, 'conv', 3), 30 | 19, 31 | ] 32 | extras = ModuleList([ 33 | InvertedResidual(1280, 512, stride=2, expand_ratio=0.2), 34 | InvertedResidual(512, 256, stride=2, expand_ratio=0.25), 35 | InvertedResidual(256, 256, stride=2, expand_ratio=0.5), 36 | InvertedResidual(256, 64, stride=2, expand_ratio=0.25) 37 | ]) 38 | 39 | regression_headers = ModuleList([ 40 | SeperableConv2d(in_channels=round(576 * width_mult), out_channels=6 * 4, 41 | kernel_size=3, padding=1, onnx_compatible=False), 42 | SeperableConv2d(in_channels=1280, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False), 43 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False), 44 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False), 45 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1, onnx_compatible=False), 46 | Conv2d(in_channels=64, out_channels=6 * 4, kernel_size=1), 47 | ]) 48 | 49 | classification_headers = ModuleList([ 50 | SeperableConv2d(in_channels=round(576 * width_mult), out_channels=6 * num_classes, kernel_size=3, padding=1), 51 | SeperableConv2d(in_channels=1280, out_channels=6 * num_classes, kernel_size=3, padding=1), 52 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 53 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 54 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 55 | Conv2d(in_channels=64, out_channels=6 * num_classes, kernel_size=1), 56 | ]) 57 | if quantize: 58 | from utils.quantize import quantize 59 | config.priors = quantize(config.priors, num_bits=8, min_value=float(config.priors.min()), max_value=float(config.priors.max())) 60 | return SSD(num_classes, base_net, source_layer_indexes, 61 | extras, classification_headers, regression_headers, is_test=is_test, config=config) 62 | 63 | 64 | def create_mobilenetv2_ssd_lite_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=torch.device('cpu')): 65 | predictor = Predictor(net, config.image_size, config.image_mean, 66 | config.image_std, 67 | nms_method=nms_method, 68 | iou_threshold=config.iou_threshold, 69 | candidate_size=candidate_size, 70 | sigma=sigma, 71 | device=device) 72 | return predictor 73 | -------------------------------------------------------------------------------- /modeling/detection/mobilenetv1_ssd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU 3 | from .nn.mobilenet import MobileNetV1 4 | 5 | from .ssd import SSD 6 | from .predictor import Predictor 7 | from .config import mobilenetv1_ssd_config as config 8 | 
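# Usage sketch (illustrative only -- the checkpoint path and the predict()
# arguments below are assumptions, not defined in this file):
#   net = create_mobilenetv1_ssd(num_classes=21, is_test=True)
#   net.load('path/to/mobilenetv1-ssd.pth')
#   predictor = create_mobilenetv1_ssd_predictor(net, candidate_size=200)
#   boxes, labels, probs = predictor.predict(rgb_image, top_k=10, prob_threshold=0.4)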
9 | 10 | def create_mobilenetv1_ssd(num_classes, is_test=False): 11 | base_net = MobileNetV1(1001).model # disable dropout layer 12 | 13 | source_layer_indexes = [ 14 | 12, 15 | 14, 16 | ] 17 | extras = ModuleList([ 18 | Sequential( 19 | Conv2d(in_channels=1024, out_channels=256, kernel_size=1), 20 | ReLU(), 21 | Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1), 22 | ReLU() 23 | ), 24 | Sequential( 25 | Conv2d(in_channels=512, out_channels=128, kernel_size=1), 26 | ReLU(), 27 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 28 | ReLU() 29 | ), 30 | Sequential( 31 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 32 | ReLU(), 33 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 34 | ReLU() 35 | ), 36 | Sequential( 37 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 38 | ReLU(), 39 | Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 40 | ReLU() 41 | ) 42 | ]) 43 | 44 | regression_headers = ModuleList([ 45 | Conv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), 46 | Conv2d(in_channels=1024, out_channels=6 * 4, kernel_size=3, padding=1), 47 | Conv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), 48 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 49 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 50 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0? 51 | ]) 52 | 53 | classification_headers = ModuleList([ 54 | Conv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 55 | Conv2d(in_channels=1024, out_channels=6 * num_classes, kernel_size=3, padding=1), 56 | Conv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 57 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 58 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 59 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0? 60 | ]) 61 | 62 | return SSD(num_classes, base_net, source_layer_indexes, 63 | extras, classification_headers, regression_headers, is_test=is_test, config=config) 64 | 65 | 66 | def create_mobilenetv1_ssd_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=None): 67 | predictor = Predictor(net, config.image_size, config.image_mean, 68 | config.image_std, 69 | nms_method=nms_method, 70 | iou_threshold=config.iou_threshold, 71 | candidate_size=candidate_size, 72 | sigma=sigma, 73 | device=device) 74 | return predictor 75 | -------------------------------------------------------------------------------- /modeling/detection/mobilenetv1_ssd_lite.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU, BatchNorm2d 3 | from .nn.mobilenet import MobileNetV1 4 | 5 | from .ssd import SSD 6 | from .predictor import Predictor 7 | from .config import mobilenetv1_ssd_config as config 8 | 9 | 10 | def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0): 11 | """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d. 
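    For a k x k convolution mapping C_in to C_out channels, this factorization
    uses roughly k*k*C_in + C_in*C_out weights instead of k*k*C_in*C_out;
    e.g. a 3x3, 512->512 layer drops from ~2.36M to ~0.27M parameters.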
12 | """ 13 | return Sequential( 14 | Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, 15 | groups=in_channels, stride=stride, padding=padding), 16 | ReLU(), 17 | Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1), 18 | ) 19 | 20 | 21 | def create_mobilenetv1_ssd_lite(num_classes, is_test=False): 22 | base_net = MobileNetV1(1001).model # disable dropout layer 23 | 24 | source_layer_indexes = [ 25 | 12, 26 | 14, 27 | ] 28 | extras = ModuleList([ 29 | Sequential( 30 | Conv2d(in_channels=1024, out_channels=256, kernel_size=1), 31 | ReLU(), 32 | SeperableConv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1), 33 | ), 34 | Sequential( 35 | Conv2d(in_channels=512, out_channels=128, kernel_size=1), 36 | ReLU(), 37 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 38 | ), 39 | Sequential( 40 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 41 | ReLU(), 42 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 43 | ), 44 | Sequential( 45 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 46 | ReLU(), 47 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1) 48 | ) 49 | ]) 50 | 51 | regression_headers = ModuleList([ 52 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), 53 | SeperableConv2d(in_channels=1024, out_channels=6 * 4, kernel_size=3, padding=1), 54 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), 55 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 56 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 57 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=1), 58 | ]) 59 | 60 | classification_headers = ModuleList([ 61 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 62 | SeperableConv2d(in_channels=1024, out_channels=6 * num_classes, kernel_size=3, padding=1), 63 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 64 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 65 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 66 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=1), 67 | ]) 68 | 69 | return SSD(num_classes, base_net, source_layer_indexes, 70 | extras, classification_headers, regression_headers, is_test=is_test, config=config) 71 | 72 | 73 | def create_mobilenetv1_ssd_lite_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=None): 74 | predictor = Predictor(net, config.image_size, config.image_mean, 75 | config.image_std, 76 | nms_method=nms_method, 77 | iou_threshold=config.iou_threshold, 78 | candidate_size=candidate_size, 79 | sigma=sigma, 80 | device=device) 81 | return predictor 82 | -------------------------------------------------------------------------------- /modeling/detection/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/nn/__init__.py -------------------------------------------------------------------------------- /modeling/detection/nn/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import 
torch.utils.model_zoo as model_zoo 3 | 4 | # copied from torchvision (https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py). 5 | # The forward function is modified for model pruning. 6 | 7 | __all__ = ['AlexNet', 'alexnet'] 8 | 9 | 10 | model_urls = { 11 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 12 | } 13 | 14 | 15 | class AlexNet(nn.Module): 16 | 17 | def __init__(self, num_classes=1000): 18 | super(AlexNet, self).__init__() 19 | self.features = nn.Sequential( 20 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 21 | nn.ReLU(inplace=True), 22 | nn.MaxPool2d(kernel_size=3, stride=2), 23 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 24 | nn.ReLU(inplace=True), 25 | nn.MaxPool2d(kernel_size=3, stride=2), 26 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 31 | nn.ReLU(inplace=True), 32 | nn.MaxPool2d(kernel_size=3, stride=2), 33 | ) 34 | self.classifier = nn.Sequential( 35 | nn.Dropout(), 36 | nn.Linear(256 * 6 * 6, 4096), 37 | nn.ReLU(inplace=True), 38 | nn.Dropout(), 39 | nn.Linear(4096, 4096), 40 | nn.ReLU(inplace=True), 41 | nn.Linear(4096, num_classes), 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.features(x) 46 | x = x.view(x.size(0), -1) 47 | x = self.classifier(x) 48 | return x 49 | 50 | 51 | def alexnet(pretrained=False, **kwargs): 52 | r"""AlexNet model architecture from the 53 | `"One weird trick..." `_ paper. 54 | 55 | Args: 56 | pretrained (bool): If True, returns a model pre-trained on ImageNet 57 | """ 58 | model = AlexNet(**kwargs) 59 | if pretrained: 60 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 61 | return model -------------------------------------------------------------------------------- /modeling/detection/nn/mobilenet.py: -------------------------------------------------------------------------------- 1 | # borrowed from "https://github.com/marvis/pytorch-mobilenet" 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class MobileNetV1(nn.Module): 8 | def __init__(self, num_classes=1024): 9 | super(MobileNetV1, self).__init__() 10 | 11 | def conv_bn(inp, oup, stride): 12 | return nn.Sequential( 13 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 14 | nn.BatchNorm2d(oup), 15 | nn.ReLU(inplace=True) 16 | ) 17 | 18 | def conv_dw(inp, oup, stride): 19 | return nn.Sequential( 20 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 21 | nn.BatchNorm2d(inp), 22 | nn.ReLU(inplace=True), 23 | 24 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 25 | nn.BatchNorm2d(oup), 26 | nn.ReLU(inplace=True), 27 | ) 28 | 29 | self.model = nn.Sequential( 30 | conv_bn(3, 32, 2), 31 | conv_dw(32, 64, 1), 32 | conv_dw(64, 128, 2), 33 | conv_dw(128, 128, 1), 34 | conv_dw(128, 256, 2), 35 | conv_dw(256, 256, 1), 36 | conv_dw(256, 512, 2), 37 | conv_dw(512, 512, 1), 38 | conv_dw(512, 512, 1), 39 | conv_dw(512, 512, 1), 40 | conv_dw(512, 512, 1), 41 | conv_dw(512, 512, 1), 42 | conv_dw(512, 1024, 2), 43 | conv_dw(1024, 1024, 1), 44 | ) 45 | self.fc = nn.Linear(1024, num_classes) 46 | 47 | def forward(self, x): 48 | x = self.model(x) 49 | x = F.avg_pool2d(x, 7) 50 | x = x.view(-1, 1024) 51 | x = self.fc(x) 52 | return x -------------------------------------------------------------------------------- /modeling/detection/nn/mobilenet_v2.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | # Modified from https://github.com/tonylins/pytorch-mobilenet-v2/blob/master/MobileNetV2.py. 5 | # In this version, Relu6 is replaced with Relu to make it ONNX compatible. 6 | # BatchNorm Layer is optional to make it easy do batch norm confusion. 7 | 8 | 9 | def conv_bn(inp, oup, stride, use_batch_norm=True, onnx_compatible=False): 10 | ReLU = nn.ReLU if onnx_compatible else nn.ReLU6 11 | 12 | if use_batch_norm: 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | ReLU(inplace=True) 17 | ) 18 | else: 19 | return nn.Sequential( 20 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 21 | ReLU(inplace=True) 22 | ) 23 | 24 | 25 | def conv_1x1_bn(inp, oup, use_batch_norm=True, onnx_compatible=False): 26 | ReLU = nn.ReLU if onnx_compatible else nn.ReLU6 27 | if use_batch_norm: 28 | return nn.Sequential( 29 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 30 | nn.BatchNorm2d(oup), 31 | ReLU(inplace=True) 32 | ) 33 | else: 34 | return nn.Sequential( 35 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 36 | ReLU(inplace=True) 37 | ) 38 | 39 | 40 | class InvertedResidual(nn.Module): 41 | def __init__(self, inp, oup, stride, expand_ratio, use_batch_norm=True, onnx_compatible=False): 42 | super(InvertedResidual, self).__init__() 43 | ReLU = nn.ReLU if onnx_compatible else nn.ReLU6 44 | 45 | self.stride = stride 46 | assert stride in [1, 2] 47 | 48 | hidden_dim = round(inp * expand_ratio) 49 | self.use_res_connect = self.stride == 1 and inp == oup 50 | 51 | if expand_ratio == 1: 52 | if use_batch_norm: 53 | self.conv = nn.Sequential( 54 | # dw 55 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 56 | nn.BatchNorm2d(hidden_dim), 57 | ReLU(inplace=True), 58 | # pw-linear 59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 60 | nn.BatchNorm2d(oup), 61 | ) 62 | else: 63 | self.conv = nn.Sequential( 64 | # dw 65 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 66 | ReLU(inplace=True), 67 | # pw-linear 68 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 69 | ) 70 | else: 71 | if use_batch_norm: 72 | self.conv = nn.Sequential( 73 | # pw 74 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 75 | nn.BatchNorm2d(hidden_dim), 76 | ReLU(inplace=True), 77 | # dw 78 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 79 | nn.BatchNorm2d(hidden_dim), 80 | ReLU(inplace=True), 81 | # pw-linear 82 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 83 | nn.BatchNorm2d(oup), 84 | ) 85 | else: 86 | self.conv = nn.Sequential( 87 | # pw 88 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 89 | ReLU(inplace=True), 90 | # dw 91 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 92 | ReLU(inplace=True), 93 | # pw-linear 94 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 95 | ) 96 | 97 | def forward(self, x): 98 | if self.use_res_connect: 99 | return x + self.conv(x) 100 | else: 101 | return self.conv(x) 102 | 103 | 104 | class MobileNetV2(nn.Module): 105 | def __init__(self, n_class=1000, input_size=224, width_mult=1., dropout_ratio=0.2, 106 | use_batch_norm=True, onnx_compatible=False): 107 | super(MobileNetV2, self).__init__() 108 | block = InvertedResidual 109 | input_channel = 32 110 | last_channel = 1280 111 | interverted_residual_setting = [ 112 | # t, c, n, s 113 | [1, 16, 1, 1], 114 | [6, 24, 2, 2], 115 | [6, 32, 3, 
2], 116 | [6, 64, 4, 2], 117 | [6, 96, 3, 1], 118 | [6, 160, 3, 2], 119 | [6, 320, 1, 1], 120 | ] 121 | 122 | # building first layer 123 | assert input_size % 32 == 0 124 | input_channel = int(input_channel * width_mult) 125 | self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel 126 | self.features = [conv_bn(3, input_channel, 2, onnx_compatible=onnx_compatible)] 127 | # building inverted residual blocks 128 | for t, c, n, s in interverted_residual_setting: 129 | output_channel = int(c * width_mult) 130 | for i in range(n): 131 | if i == 0: 132 | self.features.append(block(input_channel, output_channel, s, 133 | expand_ratio=t, use_batch_norm=use_batch_norm, 134 | onnx_compatible=onnx_compatible)) 135 | else: 136 | self.features.append(block(input_channel, output_channel, 1, 137 | expand_ratio=t, use_batch_norm=use_batch_norm, 138 | onnx_compatible=onnx_compatible)) 139 | input_channel = output_channel 140 | # building last several layers 141 | self.features.append(conv_1x1_bn(input_channel, self.last_channel, 142 | use_batch_norm=use_batch_norm, onnx_compatible=onnx_compatible)) 143 | # make it nn.Sequential 144 | self.features = nn.Sequential(*self.features) 145 | 146 | # building classifier 147 | self.classifier = nn.Sequential( 148 | nn.Dropout(dropout_ratio), 149 | nn.Linear(self.last_channel, n_class), 150 | ) 151 | 152 | self._initialize_weights() 153 | 154 | def forward(self, x): 155 | x = self.features(x) 156 | x = x.mean(3).mean(2) 157 | x = self.classifier(x) 158 | return x 159 | 160 | def _initialize_weights(self): 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 164 | m.weight.data.normal_(0, math.sqrt(2. / n)) 165 | if m.bias is not None: 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | elif isinstance(m, nn.Linear): 171 | n = m.weight.size(1) 172 | m.weight.data.normal_(0, 0.01) 173 | m.bias.data.zero_() 174 | -------------------------------------------------------------------------------- /modeling/detection/nn/multibox_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | from ..utils import box_utils 7 | 8 | 9 | class MultiboxLoss(nn.Module): 10 | def __init__(self, priors, iou_threshold, neg_pos_ratio, 11 | center_variance, size_variance, device): 12 | """Implement SSD Multibox Loss. 13 | 14 | Basically, Multibox loss combines classification loss 15 | and Smooth L1 regression loss. 16 | """ 17 | super(MultiboxLoss, self).__init__() 18 | self.iou_threshold = iou_threshold 19 | self.neg_pos_ratio = neg_pos_ratio 20 | self.center_variance = center_variance 21 | self.size_variance = size_variance 22 | self.priors = priors 23 | self.priors.to(device) 24 | 25 | def forward(self, confidence, predicted_locations, labels, gt_locations): 26 | """Compute classification loss and smooth l1 loss. 27 | 28 | Args: 29 | confidence (batch_size, num_priors, num_classes): class predictions. 30 | locations (batch_size, num_priors, 4): predicted locations. 31 | labels (batch_size, num_priors): real labels of all the priors. 32 | boxes (batch_size, num_priors, 4): real boxes corresponding all the priors. 
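Note: hard_negative_mining keeps every positive prior and only the hardest negatives (those with the highest background loss), at most neg_pos_ratio negatives per positive, so easy background priors do not dominate the classification term; both returned losses are normalized by the number of positive priors.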
33 | """ 34 | num_classes = confidence.size(2) 35 | with torch.no_grad(): 36 | # derived from cross_entropy=sum(log(p)) 37 | loss = -F.log_softmax(confidence, dim=2)[:, :, 0] 38 | mask = box_utils.hard_negative_mining(loss, labels, self.neg_pos_ratio) 39 | 40 | confidence = confidence[mask, :] 41 | classification_loss = F.cross_entropy(confidence.reshape(-1, num_classes), labels[mask], size_average=False) 42 | pos_mask = labels > 0 43 | predicted_locations = predicted_locations[pos_mask, :].reshape(-1, 4) 44 | gt_locations = gt_locations[pos_mask, :].reshape(-1, 4) 45 | smooth_l1_loss = F.smooth_l1_loss(predicted_locations, gt_locations, size_average=False) 46 | num_pos = gt_locations.size(0) 47 | return smooth_l1_loss/num_pos, classification_loss/num_pos 48 | -------------------------------------------------------------------------------- /modeling/detection/nn/scaled_l2_norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class ScaledL2Norm(nn.Module): 7 | def __init__(self, in_channels, initial_scale): 8 | super(ScaledL2Norm, self).__init__() 9 | self.in_channels = in_channels 10 | self.scale = nn.Parameter(torch.Tensor(in_channels)) 11 | self.initial_scale = initial_scale 12 | self.reset_parameters() 13 | 14 | def forward(self, x): 15 | return (F.normalize(x, p=2, dim=1) 16 | * self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3)) 17 | 18 | def reset_parameters(self): 19 | self.scale.data.fill_(self.initial_scale) -------------------------------------------------------------------------------- /modeling/detection/nn/squeezenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | 8 | __all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] 9 | 10 | 11 | model_urls = { 12 | 'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth', 13 | 'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth', 14 | } 15 | 16 | 17 | class Fire(nn.Module): 18 | 19 | def __init__(self, inplanes, squeeze_planes, 20 | expand1x1_planes, expand3x3_planes): 21 | super(Fire, self).__init__() 22 | self.inplanes = inplanes 23 | self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1) 24 | self.squeeze_activation = nn.ReLU(inplace=True) 25 | self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, 26 | kernel_size=1) 27 | self.expand1x1_activation = nn.ReLU(inplace=True) 28 | self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, 29 | kernel_size=3, padding=1) 30 | self.expand3x3_activation = nn.ReLU(inplace=True) 31 | 32 | def forward(self, x): 33 | x = self.squeeze_activation(self.squeeze(x)) 34 | return torch.cat([ 35 | self.expand1x1_activation(self.expand1x1(x)), 36 | self.expand3x3_activation(self.expand3x3(x)) 37 | ], 1) 38 | 39 | 40 | class SqueezeNet(nn.Module): 41 | 42 | def __init__(self, version=1.0, num_classes=1000): 43 | super(SqueezeNet, self).__init__() 44 | if version not in [1.0, 1.1]: 45 | raise ValueError("Unsupported SqueezeNet version {version}:" 46 | "1.0 or 1.1 expected".format(version=version)) 47 | self.num_classes = num_classes 48 | if version == 1.0: 49 | self.features = nn.Sequential( 50 | nn.Conv2d(3, 96, kernel_size=7, stride=2), 51 | nn.ReLU(inplace=True), 52 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 53 
| Fire(96, 16, 64, 64), 54 | Fire(128, 16, 64, 64), 55 | Fire(128, 32, 128, 128), 56 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 57 | Fire(256, 32, 128, 128), 58 | Fire(256, 48, 192, 192), 59 | Fire(384, 48, 192, 192), 60 | Fire(384, 64, 256, 256), 61 | nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), 62 | Fire(512, 64, 256, 256), 63 | ) 64 | else: 65 | self.features = nn.Sequential( 66 | nn.Conv2d(3, 64, kernel_size=3, stride=2), 67 | nn.ReLU(inplace=True), 68 | nn.MaxPool2d(kernel_size=3, stride=2), 69 | Fire(64, 16, 64, 64), 70 | Fire(128, 16, 64, 64), 71 | nn.MaxPool2d(kernel_size=3, stride=2), 72 | Fire(128, 32, 128, 128), 73 | Fire(256, 32, 128, 128), 74 | nn.MaxPool2d(kernel_size=3, stride=2), 75 | Fire(256, 48, 192, 192), 76 | Fire(384, 48, 192, 192), 77 | Fire(384, 64, 256, 256), 78 | Fire(512, 64, 256, 256), 79 | ) 80 | # Final convolution is initialized differently from the rest 81 | final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1) 82 | self.classifier = nn.Sequential( 83 | nn.Dropout(p=0.5), 84 | final_conv, 85 | nn.ReLU(inplace=True), 86 | nn.AvgPool2d(13, stride=1) 87 | ) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | if m is final_conv: 92 | init.normal_(m.weight, mean=0.0, std=0.01) 93 | else: 94 | init.kaiming_uniform_(m.weight) 95 | if m.bias is not None: 96 | init.constant_(m.bias, 0) 97 | 98 | def forward(self, x): 99 | x = self.features(x) 100 | x = self.classifier(x) 101 | return x.view(x.size(0), self.num_classes) 102 | 103 | 104 | def squeezenet1_0(pretrained=False, **kwargs): 105 | r"""SqueezeNet model architecture from the `"SqueezeNet: AlexNet-level 106 | accuracy with 50x fewer parameters and <0.5MB model size" 107 | <https://arxiv.org/abs/1602.07360>`_ paper. 108 | 109 | Args: 110 | pretrained (bool): If True, returns a model pre-trained on ImageNet 111 | """ 112 | model = SqueezeNet(version=1.0, **kwargs) 113 | if pretrained: 114 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_0'])) 115 | return model 116 | 117 | 118 | def squeezenet1_1(pretrained=False, **kwargs): 119 | r"""SqueezeNet 1.1 model from the `official SqueezeNet repo 120 | <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_. 121 | SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters 122 | than SqueezeNet 1.0, without sacrificing accuracy.
123 | 124 | Args: 125 | pretrained (bool): If True, returns a model pre-trained on ImageNet 126 | """ 127 | model = SqueezeNet(version=1.1, **kwargs) 128 | if pretrained: 129 | model.load_state_dict(model_zoo.load_url(model_urls['squeezenet1_1'])) 130 | return model 131 | -------------------------------------------------------------------------------- /modeling/detection/nn/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # borrowed from https://github.com/amdegroot/ssd.pytorch/blob/master/ssd.py 5 | def vgg(cfg, batch_norm=False): 6 | layers = [] 7 | in_channels = 3 8 | for v in cfg: 9 | if v == 'M': 10 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 11 | elif v == 'C': 12 | layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] 13 | else: 14 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 15 | if batch_norm: 16 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 17 | else: 18 | layers += [conv2d, nn.ReLU(inplace=True)] 19 | in_channels = v 20 | pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 21 | conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) 22 | conv7 = nn.Conv2d(1024, 1024, kernel_size=1) 23 | layers += [pool5, conv6, 24 | nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] 25 | return layers -------------------------------------------------------------------------------- /modeling/detection/predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from utils.detection import box_utils 4 | from .data_preprocessing import PredictionTransform 5 | from utils.detection.misc import Timer 6 | 7 | 8 | class Predictor: 9 | def __init__(self, net, size, mean=0.0, std=1.0, nms_method=None, 10 | iou_threshold=0.45, filter_threshold=0.01, candidate_size=200, sigma=0.5, device=None): 11 | self.net = net 12 | self.transform = PredictionTransform(size, mean, std) 13 | self.iou_threshold = iou_threshold 14 | self.filter_threshold = filter_threshold 15 | self.candidate_size = candidate_size 16 | self.nms_method = nms_method 17 | 18 | self.sigma = sigma 19 | if device: 20 | self.device = device 21 | else: 22 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 23 | 24 | self.net.to(self.device) 25 | self.net.eval() 26 | 27 | # self.timer = Timer() 28 | 29 | def predict(self, image, top_k=-1, prob_threshold=None): 30 | cpu_device = torch.device("cpu") 31 | height, width, _ = image.shape 32 | image = self.transform(image) 33 | images = image.unsqueeze(0) 34 | images = images.to(self.device) 35 | with torch.no_grad(): 36 | # self.timer.start() 37 | scores, boxes = self.net.forward(images) 38 | boxes = box_utils.convert_locations_to_boxes(*boxes) 39 | boxes = box_utils.center_form_to_corner_form(boxes) 40 | # scores, boxes = self.net.forward(images) 41 | # print("Inference time: ", self.timer.end()) 42 | boxes = boxes[0] 43 | scores = scores[0] 44 | if not prob_threshold: 45 | prob_threshold = self.filter_threshold 46 | # this version of nms is slower on GPU, so we move data to CPU. 
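# Below: per-class score thresholding, then NMS on the surviving corner-form boxes; the kept boxes are finally rescaled from normalized [0, 1] coordinates to pixel units of the input image.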
47 | boxes = boxes.to(cpu_device) 48 | scores = scores.to(cpu_device) 49 | picked_box_probs = [] 50 | picked_labels = [] 51 | for class_index in range(1, scores.size(1)): 52 | probs = scores[:, class_index] 53 | mask = probs > prob_threshold 54 | probs = probs[mask] 55 | if probs.size(0) == 0: 56 | continue 57 | subset_boxes = boxes[mask, :] 58 | box_probs = torch.cat([subset_boxes, probs.reshape(-1, 1)], dim=1) 59 | box_probs = box_utils.nms(box_probs, self.nms_method, 60 | score_threshold=prob_threshold, 61 | iou_threshold=self.iou_threshold, 62 | sigma=self.sigma, 63 | top_k=top_k, 64 | candidate_size=self.candidate_size) 65 | picked_box_probs.append(box_probs) 66 | picked_labels.extend([class_index] * box_probs.size(0)) 67 | if not picked_box_probs: 68 | return torch.tensor([]), torch.tensor([]), torch.tensor([]) 69 | picked_box_probs = torch.cat(picked_box_probs) 70 | picked_box_probs[:, 0] *= width 71 | picked_box_probs[:, 1] *= height 72 | picked_box_probs[:, 2] *= width 73 | picked_box_probs[:, 3] *= height 74 | return picked_box_probs[:, :4], torch.tensor(picked_labels), picked_box_probs[:, 4] -------------------------------------------------------------------------------- /modeling/detection/squeezenet_ssd_lite.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU 3 | from .nn.squeezenet import squeezenet1_1 4 | 5 | from .ssd import SSD 6 | from .predictor import Predictor 7 | from .config import squeezenet_ssd_config as config 8 | 9 | 10 | def SeperableConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0): 11 | """Replace Conv2d with a depthwise Conv2d and Pointwise Conv2d. 12 | """ 13 | return Sequential( 14 | Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, 15 | groups=in_channels, stride=stride, padding=padding), 16 | ReLU(), 17 | Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1), 18 | ) 19 | 20 | 21 | def create_squeezenet_ssd_lite(num_classes, is_test=False): 22 | base_net = squeezenet1_1(False).features # disable dropout layer 23 | 24 | source_layer_indexes = [ 25 | 12 26 | ] 27 | extras = ModuleList([ 28 | Sequential( 29 | Conv2d(in_channels=512, out_channels=256, kernel_size=1), 30 | ReLU(), 31 | SeperableConv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=2), 32 | ), 33 | Sequential( 34 | Conv2d(in_channels=512, out_channels=256, kernel_size=1), 35 | ReLU(), 36 | SeperableConv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1), 37 | ), 38 | Sequential( 39 | Conv2d(in_channels=512, out_channels=128, kernel_size=1), 40 | ReLU(), 41 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 42 | ), 43 | Sequential( 44 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 45 | ReLU(), 46 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), 47 | ), 48 | Sequential( 49 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 50 | ReLU(), 51 | SeperableConv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1) 52 | ) 53 | ]) 54 | 55 | regression_headers = ModuleList([ 56 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), 57 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), 58 | SeperableConv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), 59 | SeperableConv2d(in_channels=256, 
out_channels=6 * 4, kernel_size=3, padding=1), 60 | SeperableConv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 61 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=1), 62 | ]) 63 | 64 | classification_headers = ModuleList([ 65 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 66 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 67 | SeperableConv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 68 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 69 | SeperableConv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 70 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=1), 71 | ]) 72 | 73 | return SSD(num_classes, base_net, source_layer_indexes, 74 | extras, classification_headers, regression_headers, is_test=is_test, config=config) 75 | 76 | 77 | def create_squeezenet_ssd_lite_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=torch.device('cpu')): 78 | predictor = Predictor(net, config.image_size, config.image_mean, 79 | config.image_std, 80 | nms_method=nms_method, 81 | iou_threshold=config.iou_threshold, 82 | candidate_size=candidate_size, 83 | sigma=sigma, 84 | device=device) 85 | return predictor -------------------------------------------------------------------------------- /modeling/detection/ssd.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | from typing import List, Tuple 5 | import torch.nn.functional as F 6 | 7 | from utils.detection import box_utils 8 | from collections import namedtuple 9 | GraphPath = namedtuple("GraphPath", ['s0', 'name', 's1']) # 10 | 11 | 12 | class SSD(nn.Module): 13 | def __init__(self, num_classes: int, base_net: nn.ModuleList, source_layer_indexes: List[int], 14 | extras: nn.ModuleList, classification_headers: nn.ModuleList, 15 | regression_headers: nn.ModuleList, is_test=False, config=None, device=None): 16 | """Compose a SSD model using the given components. 
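source_layer_indexes marks where feature maps are tapped from base_net, each Sequential in extras adds one further-downsampled feature map, and classification_headers / regression_headers hold one conv head per tapped map, in the same order.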
17 | """ 18 | super(SSD, self).__init__() 19 | 20 | self.num_classes = num_classes 21 | self.base_net = base_net 22 | self.source_layer_indexes = source_layer_indexes 23 | self.extras = extras 24 | self.classification_headers = classification_headers 25 | self.regression_headers = regression_headers 26 | self.is_test = is_test 27 | self.config = config 28 | 29 | # register layers in source_layer_indexes by adding them to a module list 30 | self.source_layer_add_ons = nn.ModuleList([t[1] for t in source_layer_indexes 31 | if isinstance(t, tuple) and not isinstance(t, GraphPath)]) 32 | if device: 33 | self.device = device 34 | else: 35 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 36 | if is_test: 37 | self.config = config 38 | self.priors = config.priors.to(self.device) 39 | 40 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 41 | confidences = [] 42 | locations = [] 43 | start_layer_index = 0 44 | header_index = 0 45 | for end_layer_index in self.source_layer_indexes: 46 | if isinstance(end_layer_index, GraphPath): 47 | path = end_layer_index 48 | end_layer_index = end_layer_index.s0 49 | added_layer = None 50 | elif isinstance(end_layer_index, tuple): 51 | added_layer = end_layer_index[1] 52 | end_layer_index = end_layer_index[0] 53 | path = None 54 | else: 55 | added_layer = None 56 | path = None 57 | for layer in self.base_net[start_layer_index: end_layer_index]: 58 | x = layer(x) 59 | if added_layer: 60 | y = added_layer(x) 61 | else: 62 | y = x 63 | if path: 64 | sub = getattr(self.base_net[end_layer_index], path.name) 65 | for layer in sub[:path.s1]: 66 | x = layer(x) 67 | y = x 68 | for layer in sub[path.s1:]: 69 | x = layer(x) 70 | end_layer_index += 1 71 | start_layer_index = end_layer_index 72 | confidence, location = self.compute_header(header_index, y) 73 | header_index += 1 74 | confidences.append(confidence) 75 | locations.append(location) 76 | 77 | for layer in self.base_net[end_layer_index:]: 78 | x = layer(x) 79 | 80 | for layer in self.extras: 81 | x = layer(x) 82 | confidence, location = self.compute_header(header_index, x) 83 | header_index += 1 84 | confidences.append(confidence) 85 | locations.append(location) 86 | 87 | confidences = torch.cat(confidences, 1) 88 | locations = torch.cat(locations, 1) 89 | 90 | if self.is_test: 91 | confidences = F.softmax(confidences, dim=2) 92 | return confidences, (locations, self.priors, self.config.center_variance, self.config.size_variance) 93 | # boxes = box_utils.convert_locations_to_boxes( 94 | # locations, self.priors, self.config.center_variance, self.config.size_variance 95 | # ) 96 | # boxes = box_utils.center_form_to_corner_form(boxes) 97 | # return confidences, boxes 98 | else: 99 | return confidences, locations 100 | 101 | def compute_header(self, i, x): 102 | confidence = self.classification_headers[i](x) 103 | confidence = confidence.permute(0, 2, 3, 1).contiguous() 104 | confidence = confidence.view(confidence.size(0), -1, self.num_classes) 105 | 106 | location = self.regression_headers[i](x) 107 | location = location.permute(0, 2, 3, 1).contiguous() 108 | location = location.view(location.size(0), -1, 4) 109 | 110 | return confidence, location 111 | 112 | def init_from_base_net(self, model): 113 | self.base_net.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage), strict=True) 114 | self.source_layer_add_ons.apply(_xavier_init_) 115 | self.extras.apply(_xavier_init_) 116 | self.classification_headers.apply(_xavier_init_) 117 | 
self.regression_headers.apply(_xavier_init_) 118 | 119 | def init_from_pretrained_ssd(self, model): 120 | state_dict = torch.load(model, map_location=lambda storage, loc: storage) 121 | state_dict = {k: v for k, v in state_dict.items() if not (k.startswith("classification_headers") or k.startswith("regression_headers"))} 122 | model_dict = self.state_dict() 123 | model_dict.update(state_dict) 124 | self.load_state_dict(model_dict) 125 | self.classification_headers.apply(_xavier_init_) 126 | self.regression_headers.apply(_xavier_init_) 127 | 128 | def init(self): 129 | self.base_net.apply(_xavier_init_) 130 | self.source_layer_add_ons.apply(_xavier_init_) 131 | self.extras.apply(_xavier_init_) 132 | self.classification_headers.apply(_xavier_init_) 133 | self.regression_headers.apply(_xavier_init_) 134 | 135 | def load(self, model): 136 | self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage)) 137 | 138 | def save(self, model_path): 139 | torch.save(self.state_dict(), model_path) 140 | 141 | 142 | class MatchPrior(object): 143 | def __init__(self, center_form_priors, center_variance, size_variance, iou_threshold): 144 | self.center_form_priors = center_form_priors 145 | self.corner_form_priors = box_utils.center_form_to_corner_form(center_form_priors) 146 | self.center_variance = center_variance 147 | self.size_variance = size_variance 148 | self.iou_threshold = iou_threshold 149 | 150 | def __call__(self, gt_boxes, gt_labels): 151 | if type(gt_boxes) is np.ndarray: 152 | gt_boxes = torch.from_numpy(gt_boxes) 153 | if type(gt_labels) is np.ndarray: 154 | gt_labels = torch.from_numpy(gt_labels) 155 | boxes, labels = box_utils.assign_priors(gt_boxes, gt_labels, 156 | self.corner_form_priors, self.iou_threshold) 157 | boxes = box_utils.corner_form_to_center_form(boxes) 158 | locations = box_utils.convert_boxes_to_locations(boxes, self.center_form_priors, self.center_variance, self.size_variance) 159 | return locations, labels 160 | 161 | 162 | def _xavier_init_(m: nn.Module): 163 | if isinstance(m, nn.Conv2d): 164 | nn.init.xavier_uniform_(m.weight) 165 | -------------------------------------------------------------------------------- /modeling/detection/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/detection/transforms/__init__.py -------------------------------------------------------------------------------- /modeling/detection/vgg_ssd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Conv2d, Sequential, ModuleList, ReLU, BatchNorm2d 3 | from .nn.vgg import vgg 4 | 5 | from .ssd import SSD 6 | from .predictor import Predictor 7 | from .config import vgg_ssd_config as config 8 | 9 | 10 | def create_vgg_ssd(num_classes, is_test=False): 11 | vgg_config = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 12 | 512, 512, 512] 13 | base_net = ModuleList(vgg(vgg_config)) 14 | 15 | source_layer_indexes = [ 16 | (23, BatchNorm2d(512)), 17 | len(base_net), 18 | ] 19 | extras = ModuleList([ 20 | Sequential( 21 | Conv2d(in_channels=1024, out_channels=256, kernel_size=1), 22 | ReLU(), 23 | Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1), 24 | ReLU() 25 | ), 26 | Sequential( 27 | Conv2d(in_channels=512, out_channels=128, kernel_size=1), 28 | ReLU(), 29 | Conv2d(in_channels=128, out_channels=256, 
kernel_size=3, stride=2, padding=1), 30 | ReLU() 31 | ), 32 | Sequential( 33 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 34 | ReLU(), 35 | Conv2d(in_channels=128, out_channels=256, kernel_size=3), 36 | ReLU() 37 | ), 38 | Sequential( 39 | Conv2d(in_channels=256, out_channels=128, kernel_size=1), 40 | ReLU(), 41 | Conv2d(in_channels=128, out_channels=256, kernel_size=3), 42 | ReLU() 43 | ) 44 | ]) 45 | 46 | regression_headers = ModuleList([ 47 | Conv2d(in_channels=512, out_channels=4 * 4, kernel_size=3, padding=1), 48 | Conv2d(in_channels=1024, out_channels=6 * 4, kernel_size=3, padding=1), 49 | Conv2d(in_channels=512, out_channels=6 * 4, kernel_size=3, padding=1), 50 | Conv2d(in_channels=256, out_channels=6 * 4, kernel_size=3, padding=1), 51 | Conv2d(in_channels=256, out_channels=4 * 4, kernel_size=3, padding=1), 52 | Conv2d(in_channels=256, out_channels=4 * 4, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0? 53 | ]) 54 | 55 | classification_headers = ModuleList([ 56 | Conv2d(in_channels=512, out_channels=4 * num_classes, kernel_size=3, padding=1), 57 | Conv2d(in_channels=1024, out_channels=6 * num_classes, kernel_size=3, padding=1), 58 | Conv2d(in_channels=512, out_channels=6 * num_classes, kernel_size=3, padding=1), 59 | Conv2d(in_channels=256, out_channels=6 * num_classes, kernel_size=3, padding=1), 60 | Conv2d(in_channels=256, out_channels=4 * num_classes, kernel_size=3, padding=1), 61 | Conv2d(in_channels=256, out_channels=4 * num_classes, kernel_size=3, padding=1), # TODO: change to kernel_size=1, padding=0? 62 | ]) 63 | 64 | return SSD(num_classes, base_net, source_layer_indexes, 65 | extras, classification_headers, regression_headers, is_test=is_test, config=config) 66 | 67 | 68 | def create_vgg_ssd_predictor(net, candidate_size=200, nms_method=None, sigma=0.5, device=None): 69 | predictor = Predictor(net, config.image_size, config.image_mean, 70 | nms_method=nms_method, 71 | iou_threshold=config.iou_threshold, 72 | candidate_size=candidate_size, 73 | sigma=sigma, 74 | device=device) 75 | return predictor 76 | -------------------------------------------------------------------------------- /modeling/detection/voc-model-labels.txt: -------------------------------------------------------------------------------- 1 | BACKGROUND 2 | aeroplane 3 | bicycle 4 | bird 5 | boat 6 | bottle 7 | bus 8 | car 9 | cat 10 | chair 11 | cow 12 | diningtable 13 | dog 14 | horse 15 | motorbike 16 | person 17 | pottedplant 18 | sheep 19 | sofa 20 | train 21 | tvmonitor -------------------------------------------------------------------------------- /modeling/ncnn/model_quant_relu_equal.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/ncnn/model_quant_relu_equal.bin -------------------------------------------------------------------------------- /modeling/ncnn/model_quant_relu_equal.param: -------------------------------------------------------------------------------- 1 | 7767517 2 | 112 122 3 | Input 0 0 1 0 0=224 1=224 2=3 4 | Convolution 619 1 1 0 619 0=32 1=3 3=2 4=1 5=1 6=864 8=2 5 | ReLU 621 1 1 619 621 6 | ConvolutionDepthWise 622 1 1 621 622 0=32 1=3 4=1 5=1 6=288 7=32 8=1 7 | ReLU 624 1 1 622 624 8 | Convolution 625 1 1 624 625 0=16 1=1 5=1 6=512 8=2 9 | Convolution 627 1 1 625 627 0=96 1=1 5=1 6=1536 8=2 10 | ReLU 629 1 1 627 629 11 | ConvolutionDepthWise 630 1 1 629 630 0=96 1=3 3=2 4=1 5=1 6=864 7=96 8=1 12 | ReLU 632 1 1 
630 632 13 | Convolution 633 1 1 632 633 0=24 1=1 5=1 6=2304 8=2 14 | Split splitncnn_0 1 2 633 633_splitncnn_0 633_splitncnn_1 15 | Convolution 635 1 1 633_splitncnn_1 635 0=144 1=1 5=1 6=3456 8=2 16 | ReLU 637 1 1 635 637 17 | ConvolutionDepthWise 638 1 1 637 638 0=144 1=3 4=1 5=1 6=1296 7=144 8=1 18 | ReLU 640 1 1 638 640 19 | Convolution 641 1 1 640 641 0=24 1=1 5=1 6=3456 8=2 20 | BinaryOp 643 2 1 633_splitncnn_0 641 643 21 | Convolution 644 1 1 643 644 0=144 1=1 5=1 6=3456 8=2 22 | ReLU 646 1 1 644 646 23 | ConvolutionDepthWise 647 1 1 646 647 0=144 1=3 3=2 4=1 5=1 6=1296 7=144 8=1 24 | ReLU 649 1 1 647 649 25 | Convolution 650 1 1 649 650 0=32 1=1 5=1 6=4608 8=2 26 | Split splitncnn_1 1 2 650 650_splitncnn_0 650_splitncnn_1 27 | Convolution 652 1 1 650_splitncnn_1 652 0=192 1=1 5=1 6=6144 8=2 28 | ReLU 654 1 1 652 654 29 | ConvolutionDepthWise 655 1 1 654 655 0=192 1=3 4=1 5=1 6=1728 7=192 8=1 30 | ReLU 657 1 1 655 657 31 | Convolution 658 1 1 657 658 0=32 1=1 5=1 6=6144 8=2 32 | BinaryOp 660 2 1 650_splitncnn_0 658 660 33 | Split splitncnn_2 1 2 660 660_splitncnn_0 660_splitncnn_1 34 | Convolution 661 1 1 660_splitncnn_1 661 0=192 1=1 5=1 6=6144 8=2 35 | ReLU 663 1 1 661 663 36 | ConvolutionDepthWise 664 1 1 663 664 0=192 1=3 4=1 5=1 6=1728 7=192 8=1 37 | ReLU 666 1 1 664 666 38 | Convolution 667 1 1 666 667 0=32 1=1 5=1 6=6144 8=2 39 | BinaryOp 669 2 1 660_splitncnn_0 667 669 40 | Convolution 670 1 1 669 670 0=192 1=1 5=1 6=6144 8=2 41 | ReLU 672 1 1 670 672 42 | ConvolutionDepthWise 673 1 1 672 673 0=192 1=3 3=2 4=1 5=1 6=1728 7=192 8=1 43 | ReLU 675 1 1 673 675 44 | Convolution 676 1 1 675 676 0=64 1=1 5=1 6=12288 8=2 45 | Split splitncnn_3 1 2 676 676_splitncnn_0 676_splitncnn_1 46 | Convolution 678 1 1 676_splitncnn_1 678 0=384 1=1 5=1 6=24576 8=2 47 | ReLU 680 1 1 678 680 48 | ConvolutionDepthWise 681 1 1 680 681 0=384 1=3 4=1 5=1 6=3456 7=384 8=1 49 | ReLU 683 1 1 681 683 50 | Convolution 684 1 1 683 684 0=64 1=1 5=1 6=24576 8=2 51 | BinaryOp 686 2 1 676_splitncnn_0 684 686 52 | Split splitncnn_4 1 2 686 686_splitncnn_0 686_splitncnn_1 53 | Convolution 687 1 1 686_splitncnn_1 687 0=384 1=1 5=1 6=24576 8=2 54 | ReLU 689 1 1 687 689 55 | ConvolutionDepthWise 690 1 1 689 690 0=384 1=3 4=1 5=1 6=3456 7=384 8=1 56 | ReLU 692 1 1 690 692 57 | Convolution 693 1 1 692 693 0=64 1=1 5=1 6=24576 8=2 58 | BinaryOp 695 2 1 686_splitncnn_0 693 695 59 | Split splitncnn_5 1 2 695 695_splitncnn_0 695_splitncnn_1 60 | Convolution 696 1 1 695_splitncnn_1 696 0=384 1=1 5=1 6=24576 8=2 61 | ReLU 698 1 1 696 698 62 | ConvolutionDepthWise 699 1 1 698 699 0=384 1=3 4=1 5=1 6=3456 7=384 8=1 63 | ReLU 701 1 1 699 701 64 | Convolution 702 1 1 701 702 0=64 1=1 5=1 6=24576 8=2 65 | BinaryOp 704 2 1 695_splitncnn_0 702 704 66 | Convolution 705 1 1 704 705 0=384 1=1 5=1 6=24576 8=2 67 | ReLU 707 1 1 705 707 68 | ConvolutionDepthWise 708 1 1 707 708 0=384 1=3 4=1 5=1 6=3456 7=384 8=1 69 | ReLU 710 1 1 708 710 70 | Convolution 711 1 1 710 711 0=96 1=1 5=1 6=36864 8=2 71 | Split splitncnn_6 1 2 711 711_splitncnn_0 711_splitncnn_1 72 | Convolution 713 1 1 711_splitncnn_1 713 0=576 1=1 5=1 6=55296 8=2 73 | ReLU 715 1 1 713 715 74 | ConvolutionDepthWise 716 1 1 715 716 0=576 1=3 4=1 5=1 6=5184 7=576 8=1 75 | ReLU 718 1 1 716 718 76 | Convolution 719 1 1 718 719 0=96 1=1 5=1 6=55296 8=2 77 | BinaryOp 721 2 1 711_splitncnn_0 719 721 78 | Split splitncnn_7 1 2 721 721_splitncnn_0 721_splitncnn_1 79 | Convolution 722 1 1 721_splitncnn_1 722 0=576 1=1 5=1 6=55296 8=2 80 | ReLU 724 1 1 722 724 81 | 
ConvolutionDepthWise 725 1 1 724 725 0=576 1=3 4=1 5=1 6=5184 7=576 8=1 82 | ReLU 727 1 1 725 727 83 | Convolution 728 1 1 727 728 0=96 1=1 5=1 6=55296 8=2 84 | BinaryOp 730 2 1 721_splitncnn_0 728 730 85 | Convolution 731 1 1 730 731 0=576 1=1 5=1 6=55296 8=2 86 | ReLU 733 1 1 731 733 87 | ConvolutionDepthWise 734 1 1 733 734 0=576 1=3 3=2 4=1 5=1 6=5184 7=576 8=1 88 | ReLU 736 1 1 734 736 89 | Convolution 737 1 1 736 737 0=160 1=1 5=1 6=92160 8=2 90 | Split splitncnn_8 1 2 737 737_splitncnn_0 737_splitncnn_1 91 | Convolution 739 1 1 737_splitncnn_1 739 0=960 1=1 5=1 6=153600 8=2 92 | ReLU 741 1 1 739 741 93 | ConvolutionDepthWise 742 1 1 741 742 0=960 1=3 4=1 5=1 6=8640 7=960 8=1 94 | ReLU 744 1 1 742 744 95 | Convolution 745 1 1 744 745 0=160 1=1 5=1 6=153600 8=2 96 | BinaryOp 747 2 1 737_splitncnn_0 745 747 97 | Split splitncnn_9 1 2 747 747_splitncnn_0 747_splitncnn_1 98 | Convolution 748 1 1 747_splitncnn_1 748 0=960 1=1 5=1 6=153600 8=2 99 | ReLU 750 1 1 748 750 100 | ConvolutionDepthWise 751 1 1 750 751 0=960 1=3 4=1 5=1 6=8640 7=960 8=1 101 | ReLU 753 1 1 751 753 102 | Convolution 754 1 1 753 754 0=160 1=1 5=1 6=153600 8=2 103 | BinaryOp 756 2 1 747_splitncnn_0 754 756 104 | Convolution 757 1 1 756 757 0=960 1=1 5=1 6=153600 8=2 105 | ReLU 759 1 1 757 759 106 | ConvolutionDepthWise 760 1 1 759 760 0=960 1=3 4=1 5=1 6=8640 7=960 8=1 107 | ReLU 762 1 1 760 762 108 | Convolution 763 1 1 762 763 0=320 1=1 5=1 6=307200 8=2 109 | Convolution 765 1 1 763 765 0=1280 1=1 5=1 6=409600 8=2 110 | ReLU 767 1 1 765 767 111 | Reshape 779 1 1 767 779 0=-1 1=1280 112 | Reduction 780 1 1 779 780 0=3 1=0 -23303=1,-1 113 | InnerProduct 781 1 1 780 781 0=1000 1=1 2=1280000 8=2 114 | Softmax 782 1 1 781 782 115 | -------------------------------------------------------------------------------- /modeling/segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/segmentation/__init__.py -------------------------------------------------------------------------------- /modeling/segmentation/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight() 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | torch.nn.init.kaiming_normal_(m.weight) 27 | elif isinstance(m, SynchronizedBatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, backbone, output_stride, BatchNorm): 36 | super(ASPP, self).__init__() 37 | if backbone == 'drn': 38 | inplanes = 512 39 | elif backbone == 'mobilenet': 40 | inplanes = 320 41 | else: 42 | inplanes = 2048 43 | if output_stride 
== 16: 44 | dilations = [1, 6, 12, 18] 45 | elif output_stride == 8: 46 | dilations = [1, 12, 24, 36] 47 | else: 48 | raise NotImplementedError 49 | 50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | BatchNorm(256), 58 | nn.ReLU()) 59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 60 | self.bn1 = BatchNorm(256) 61 | self.relu = nn.ReLU() 62 | self.dropout = nn.Dropout(0.5) 63 | self._init_weight() 64 | 65 | def forward(self, x): 66 | x1 = self.aspp1(x) 67 | x2 = self.aspp2(x) 68 | x3 = self.aspp3(x) 69 | x4 = self.aspp4(x) 70 | x5 = self.global_avg_pool(x) 71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 73 | 74 | x = self.conv1(x) 75 | x = self.bn1(x) 76 | x = self.relu(x) 77 | 78 | return self.dropout(x) 79 | 80 | def _init_weight(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 85 | torch.nn.init.kaiming_normal_(m.weight) 86 | elif isinstance(m, SynchronizedBatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | 94 | def build_aspp(backbone, output_stride, BatchNorm): 95 | return ASPP(backbone, output_stride, BatchNorm) -------------------------------------------------------------------------------- /modeling/segmentation/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from modeling.segmentation.backbone import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /modeling/segmentation/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | def conv_bn(inp, oup, stride, BatchNorm): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | BatchNorm(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, 
(pad_beg, pad_end, pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 27 | super(InvertedResidual, self).__init__() 28 | self.stride = stride 29 | assert stride in [1, 2] 30 | 31 | hidden_dim = round(inp * expand_ratio) 32 | self.use_res_connect = self.stride == 1 and inp == oup 33 | self.kernel_size = 3 34 | self.dilation = dilation 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 40 | BatchNorm(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 44 | BatchNorm(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 50 | BatchNorm(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 58 | BatchNorm(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 63 | if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class MobileNetV2(nn.Module): 71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=False): 72 | super(MobileNetV2, self).__init__() 73 | block = InvertedResidual 74 | input_channel = 32 75 | current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 | stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, 
nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 138 | torch.nn.init.kaiming_normal_(m.weight) 139 | elif isinstance(m, SynchronizedBatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | 146 | if __name__ == "__main__": 147 | input = torch.rand(1, 3, 512, 512) 148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 149 | output, low_level_feat = model(input) 150 | print(output.size()) 151 | print(low_level_feat.size()) 152 | -------------------------------------------------------------------------------- /modeling/segmentation/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = BatchNorm(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class ResNet(nn.Module): 46 | 47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 48 | self.inplanes = 64 49 | super(ResNet, self).__init__() 50 | blocks = [1, 2, 4] 51 | if output_stride == 16: 52 | strides = [1, 2, 2, 1] 53 | dilations = [1, 1, 1, 2] 54 | elif output_stride == 8: 55 | strides = [1, 2, 1, 1] 56 | dilations = [1, 1, 2, 4] 57 | else: 58 | raise NotImplementedError 59 | 60 | # Modules 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = BatchNorm(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 72 | self._init_weight() 73 | 74 | if pretrained: 75 | self._load_pretrained_model() 76 | 77 | def 
_make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * block.expansion: 80 | downsample = nn.Sequential( 81 | nn.Conv2d(self.inplanes, planes * block.expansion, 82 | kernel_size=1, stride=stride, bias=False), 83 | BatchNorm(planes * block.expansion), 84 | ) 85 | 86 | layers = [] 87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 88 | self.inplanes = planes * block.expansion 89 | for i in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 95 | downsample = None 96 | if stride != 1 or self.inplanes != planes * block.expansion: 97 | downsample = nn.Sequential( 98 | nn.Conv2d(self.inplanes, planes * block.expansion, 99 | kernel_size=1, stride=stride, bias=False), 100 | BatchNorm(planes * block.expansion), 101 | ) 102 | 103 | layers = [] 104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 105 | downsample=downsample, BatchNorm=BatchNorm)) 106 | self.inplanes = planes * block.expansion 107 | for i in range(1, len(blocks)): 108 | layers.append(block(self.inplanes, planes, stride=1, 109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, input): 114 | x = self.conv1(input) 115 | x = self.bn1(x) 116 | x = self.relu(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | low_level_feat = x 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | return x, low_level_feat 125 | 126 | def _init_weight(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. / n)) 131 | elif isinstance(m, SynchronizedBatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _load_pretrained_model(self): 139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 140 | model_dict = {} 141 | state_dict = self.state_dict() 142 | for k, v in pretrain_dict.items(): 143 | if k in state_dict: 144 | model_dict[k] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def ResNet101(output_stride, BatchNorm, pretrained=True): 149 | """Constructs a ResNet-101 model. 
150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 154 | return model 155 | 156 | if __name__ == "__main__": 157 | import torch 158 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 159 | input = torch.rand(1, 3, 512, 512) 160 | output, low_level_feat = model(input) 161 | print(output.size()) 162 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /modeling/segmentation/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | else: 17 | raise NotImplementedError 18 | 19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 20 | self.bn1 = BatchNorm(48) 21 | self.relu = nn.ReLU() 22 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 23 | BatchNorm(256), 24 | nn.ReLU(), 25 | nn.Dropout(0.5), 26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 27 | BatchNorm(256), 28 | nn.ReLU(), 29 | nn.Dropout(0.1), 30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 31 | self._init_weight() 32 | 33 | 34 | def forward(self, x, low_level_feat): 35 | low_level_feat = self.conv1(low_level_feat) 36 | low_level_feat = self.bn1(low_level_feat) 37 | low_level_feat = self.relu(low_level_feat) 38 | 39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 40 | x = torch.cat((x, low_level_feat), dim=1) 41 | x = self.last_conv(x) 42 | 43 | return x 44 | 45 | def _init_weight(self): 46 | for m in self.modules(): 47 | if isinstance(m, nn.Conv2d): 48 | torch.nn.init.kaiming_normal_(m.weight) 49 | elif isinstance(m, SynchronizedBatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | elif isinstance(m, nn.BatchNorm2d): 53 | m.weight.data.fill_(1) 54 | m.bias.data.zero_() 55 | 56 | def build_decoder(num_classes, backbone, BatchNorm): 57 | return Decoder(num_classes, backbone, BatchNorm) -------------------------------------------------------------------------------- /modeling/segmentation/deeplab-mobilenet.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/modeling/segmentation/deeplab-mobilenet.pth.tar -------------------------------------------------------------------------------- /modeling/segmentation/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modeling.segmentation.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | from modeling.segmentation.aspp import build_aspp 6 | from modeling.segmentation.decoder import build_decoder 7 | from modeling.segmentation.backbone import build_backbone 8 | 9 | class DeepLab(nn.Module): 10 | def 
__init__(self, backbone='mobilenet', output_stride=16, num_classes=21, 11 | sync_bn=True, freeze_bn=False): 12 | super(DeepLab, self).__init__() 13 | if backbone == 'drn': 14 | output_stride = 8 15 | 16 | if sync_bn == True: 17 | BatchNorm = SynchronizedBatchNorm2d 18 | else: 19 | BatchNorm = nn.BatchNorm2d 20 | 21 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 22 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 23 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 24 | 25 | if freeze_bn: 26 | self.freeze_bn() 27 | 28 | def forward(self, input): 29 | x, low_level_feat = self.backbone(input) 30 | x = self.aspp(x) 31 | x = self.decoder(x, low_level_feat) 32 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 33 | 34 | return x 35 | 36 | def freeze_bn(self): 37 | for m in self.modules(): 38 | if isinstance(m, SynchronizedBatchNorm2d): 39 | m.eval() 40 | elif isinstance(m, nn.BatchNorm2d): 41 | m.eval() 42 | 43 | def get_1x_lr_params(self): 44 | modules = [self.backbone] 45 | for i in range(len(modules)): 46 | for m in modules[i].named_modules(): 47 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 48 | or isinstance(m[1], nn.BatchNorm2d): 49 | for p in m[1].parameters(): 50 | if p.requires_grad: 51 | yield p 52 | 53 | def get_10x_lr_params(self): 54 | modules = [self.aspp, self.decoder] 55 | for i in range(len(modules)): 56 | for m in modules[i].named_modules(): 57 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 58 | or isinstance(m[1], nn.BatchNorm2d): 59 | for p in m[1].parameters(): 60 | if p.requires_grad: 61 | yield p 62 | 63 | 64 | if __name__ == "__main__": 65 | model = DeepLab(backbone='mobilenet', output_stride=16) 66 | model.eval() 67 | input = torch.rand(1, 3, 513, 513) 68 | output = model(input) 69 | print(output.size()) 70 | 71 | 72 | -------------------------------------------------------------------------------- /modeling/segmentation/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /modeling/segmentation/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. 
Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During replication, as data parallel will trigger a callback on each module, all slave devices should 59 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 60 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine the message passed 63 | back to each slave device. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register a slave device. 85 | Args: 86 | identifier: an identifier, usually the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | Messages are first collected from each device (including the master device), and then 101 | a callback is invoked to compute the message to be sent back to each device 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device.
107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belong to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /modeling/segmentation/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 31 | Note that, as all modules are isomorphic, we assign each sub-module a context 32 | (shared among multiple copies of this module on different devices). 33 | Through this context, different copies can share some information. 34 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 35 | of any slave copies. 36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 51 | the original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have a customized `DataParallel` implementation.
69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /modeling/segmentation/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pydot==1.4.1 2 | torch==1.1.0 3 | matplotlib==3.1.0 4 | scipy==1.3.0 5 | numpy==1.16.4 6 | torchvision==0.3.0 7 | graphviz==0.10.1 8 | Pillow==8.1.1 9 | tqdm==4.47.0 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | def visualize_per_layer(param, title='test'): 2 | import matplotlib.pyplot as plt 3 | channel = 0 4 | param_list = [] 5 | for idx in range(param.shape[channel]): 6 | # print(idx, param[idx].max(), param[idx].min()) 7 | param_list.append(param[idx].cpu().numpy().reshape(-1)) 8 | 9 | fig7, ax7 = plt.subplots() 10 | ax7.set_title(title) 11 | ax7.boxplot(param_list, showfliers=False) 12 | # plt.ylim(-70, 70) 13 | plt.show() -------------------------------------------------------------------------------- /utils/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import * 2 | -------------------------------------------------------------------------------- /utils/detection/measurements.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_average_precision(precision, recall): 5 | """ 6 | Computes average precision as defined by the Pascal VOC competition: the area under the 7 | precision-recall curve. Recall follows the normal definition; precision is a variant:
8 | pascal_precision[i] = typical_precision[i:].max() 9 | """ 10 | # identical but faster version of new_precision[i] = old_precision[i:].max() 11 | precision = np.concatenate([[0.0], precision, [0.0]]) 12 | for i in range(len(precision) - 1, 0, -1): 13 | precision[i - 1] = np.maximum(precision[i - 1], precision[i]) 14 | 15 | # find the index where the value changes 16 | recall = np.concatenate([[0.0], recall, [1.0]]) 17 | changing_points = np.where(recall[1:] != recall[:-1])[0] 18 | 19 | # compute under curve area 20 | areas = (recall[changing_points + 1] - recall[changing_points]) * precision[changing_points + 1] 21 | return areas.sum() 22 | 23 | 24 | def compute_voc2007_average_precision(precision, recall): 25 | ap = 0. 26 | for t in np.arange(0., 1.1, 0.1): 27 | if np.sum(recall >= t) == 0: 28 | p = 0 29 | else: 30 | p = np.max(precision[recall >= t]) 31 | ap = ap + p / 11. 32 | return ap 33 | -------------------------------------------------------------------------------- /utils/detection/misc.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | 5 | def str2bool(s): 6 | return s.lower() in ('true', '1') 7 | 8 | 9 | class Timer: 10 | def __init__(self): 11 | self.clock = {} 12 | 13 | def start(self, key="default"): 14 | self.clock[key] = time.time() 15 | 16 | def end(self, key="default"): 17 | if key not in self.clock: 18 | raise Exception(f"{key} is not in the clock.") 19 | interval = time.time() - self.clock[key] 20 | del self.clock[key] 21 | return interval 22 | 23 | 24 | def save_checkpoint(epoch, net_state_dict, optimizer_state_dict, best_score, checkpoint_path, model_path): 25 | torch.save({ 26 | 'epoch': epoch, 27 | 'model': net_state_dict, 28 | 'optimizer': optimizer_state_dict, 29 | 'best_score': best_score 30 | }, checkpoint_path) 31 | torch.save(net_state_dict, model_path) 32 | 33 | 34 | def load_checkpoint(checkpoint_path): 35 | return torch.load(checkpoint_path) 36 | 37 | 38 | def freeze_net_layers(net): 39 | for param in net.parameters(): 40 | param.requires_grad = False 41 | 42 | 43 | def store_labels(path, labels): 44 | with open(path, "w") as f: 45 | f.write("\n".join(labels)) 46 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Evaluator(object): 5 | def __init__(self, num_class): 6 | self.num_class = num_class 7 | self.confusion_matrix = np.zeros((self.num_class,)*2) 8 | 9 | def Pixel_Accuracy(self): 10 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 11 | return Acc 12 | 13 | def Pixel_Accuracy_Class(self): 14 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 15 | Acc = np.nanmean(Acc) 16 | return Acc 17 | 18 | def Mean_Intersection_over_Union(self): 19 | MIoU = np.diag(self.confusion_matrix) / ( 20 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 21 | np.diag(self.confusion_matrix)) 22 | MIoU = np.nanmean(MIoU) 23 | return MIoU 24 | 25 | def Frequency_Weighted_Intersection_over_Union(self): 26 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 27 | iu = np.diag(self.confusion_matrix) / ( 28 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 29 | np.diag(self.confusion_matrix)) 30 | 31 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 32 | return FWIoU 33 | 34 | def 
_generate_matrix(self, gt_image, pre_image): 35 | mask = (gt_image >= 0) & (gt_image < self.num_class) 36 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] 37 | count = np.bincount(label, minlength=self.num_class**2) 38 | confusion_matrix = count.reshape(self.num_class, self.num_class) 39 | return confusion_matrix 40 | 41 | def add_batch(self, gt_image, pre_image): 42 | assert gt_image.shape == pre_image.shape 43 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 44 | 45 | def reset(self): 46 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /utils/relation.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from torch.nn import BatchNorm2d, ReLU, Dropout, AvgPool2d 3 | from utils.quantize import QConv2d, QuantMeasure 4 | 5 | class Relation():  # a pair of consecutive target layers (plus the BatchNorm attached to the first one) used for cross-layer weight equalization; S holds the per-channel scale vector 6 | def __init__(self, layer_idx_1, layer_idx_2, bn_idx_1): 7 | self.layer_first = layer_idx_1 8 | self.layer_second = layer_idx_2 9 | self.bn_idx = bn_idx_1 10 | self.S = None 11 | 12 | 13 | def __repr__(self): 14 | return '({}, {})'.format(self.layer_first, self.layer_second) 15 | 16 | 17 | def get_idxs(self): 18 | return self.layer_first, self.layer_second, self.bn_idx 19 | 20 | def set_scale_vec(self, S): 21 | if self.S is None: 22 | self.S = S 23 | else: 24 | self.S *= S 25 | 26 | def get_scale_vec(self): 27 | return self.S 28 | 29 | 30 | def create_relation(graph, bottoms, targ_type=[QConv2d], delete_single=False): 31 | relation_dict = OrderedDict() 32 | 33 | def _find_prev(graph, bottoms, layer_idx, targ_type, top_counter): # find the previous target layer to form a relation 34 | bot = bottoms[layer_idx] 35 | last_bn = None 36 | while len(bot) == 1 and "Data" != bot[0] and top_counter[bot[0]] == 1: 37 | if type(graph[bot[0]]) == BatchNorm2d: 38 | last_bn = bot[0] 39 | if type(graph[bot[0]]) in targ_type: 40 | return bot[0], last_bn 41 | 42 | elif not(type(graph[bot[0]]) in [BatchNorm2d, ReLU, QuantMeasure, AvgPool2d] or 43 | (type(graph[bot[0]]) == str and ("F.pad" in bot[0] or "torch.mean" in bot[0]))): 44 | return None, None 45 | 46 | bot = bottoms[bot[0]] 47 | 48 | return None, None 49 | 50 | top_counter = {} # count the number of output branches of each layer 51 | for layer_idx in graph: 52 | if layer_idx == "Data": 53 | continue 54 | for bot in bottoms[layer_idx]: 55 | if bot in top_counter: 56 | top_counter[bot] += 1 57 | else: 58 | top_counter[bot] = 1 59 | 60 | # find relation pair for each layer 61 | for layer_idx in graph: 62 | if type(graph[layer_idx]) in targ_type: 63 | prev, bn = _find_prev(graph, bottoms, layer_idx, targ_type, top_counter) 64 | if prev in relation_dict: 65 | relation_dict.pop(prev) 66 | elif prev is not None: 67 | rel = Relation(prev, layer_idx, bn) 68 | relation_dict[prev] = rel 69 | 70 | if delete_single: 71 | # only keep relations that chain at least 3 target layers, e.g. Conv2d->Conv2d->Conv2d; ignore isolated Conv2d->Conv2d pairs (used in the detection task) 72 | tmp = list(relation_dict.values()) 73 | res_group = [] 74 | for rr in tmp: 75 | group_idx = -1 76 | for idx, group in enumerate(res_group): 77 | for rr_prev in group: 78 | if rr.get_idxs()[0] == rr_prev.get_idxs()[1]: 79 | group_idx = idx 80 | break 81 | if group_idx != -1: 82 | res_group[group_idx].append(rr) 83 | else: 84 | res_group.append([rr]) 85 | res = [] 86 | for group in res_group: 87 | if len(group) > 1: 88 | res.extend(group) 89 | 90 | #
print(len(res), len(list(relation_dict.values()))) 91 | else: 92 | res = list(relation_dict.values()) 93 | 94 | return res #list(relation_dict.values()) 95 | -------------------------------------------------------------------------------- /utils/segmentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jakc4103/DFQ/6f15805cfdbf2769275defd54728df0a5d30dbc6/utils/segmentation/__init__.py -------------------------------------------------------------------------------- /utils/segmentation/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import cv2 5 | 6 | from utils.metrics import Evaluator 7 | 8 | 9 | def forward_all(net_inference, dataloader, visualize=False, opt=None): 10 | evaluator = Evaluator(21) 11 | evaluator.reset() 12 | with torch.no_grad(): 13 | for ii, sample in enumerate(dataloader): 14 | image, label = sample['image'].cuda(), sample['label'].cuda() 15 | 16 | activations = net_inference(image) 17 | 18 | image = image.cpu().numpy() 19 | label = label.cpu().numpy().astype(np.uint8) 20 | 21 | logits = activations[list(activations.keys())[-1]] if type(activations) != torch.Tensor else activations 22 | pred = torch.max(logits, 1)[1].cpu().numpy().astype(np.uint8) 23 | 24 | evaluator.add_batch(label, pred) 25 | 26 | # print(label.shape, pred.shape) 27 | if visualize: 28 | for jj in range(sample["image"].size()[0]): 29 | segmap_label = decode_segmap(label[jj], dataset='pascal') 30 | segmap_pred = decode_segmap(pred[jj], dataset='pascal') 31 | 32 | img_tmp = np.transpose(image[jj], axes=[1, 2, 0]) 33 | img_tmp *= (0.229, 0.224, 0.225) 34 | img_tmp += (0.485, 0.456, 0.406) 35 | img_tmp *= 255.0 36 | img_tmp = img_tmp.astype(np.uint8) 37 | 38 | cv2.imshow('image', img_tmp[:, :, [2,1,0]]) 39 | cv2.imshow('gt', segmap_label) 40 | cv2.imshow('pred', segmap_pred) 41 | cv2.waitKey(0) 42 | 43 | Acc = evaluator.Pixel_Accuracy() 44 | Acc_class = evaluator.Pixel_Accuracy_Class() 45 | mIoU = evaluator.Mean_Intersection_over_Union() 46 | FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union() 47 | print("Acc: {}".format(Acc)) 48 | print("Acc_class: {}".format(Acc_class)) 49 | print("mIoU: {}".format(mIoU)) 50 | print("FWIoU: {}".format(FWIoU)) 51 | if opt is not None: 52 | with open("seg_result.txt", 'a+') as ww: 53 | ww.write("{}, quant: {}, relu: {}, equalize: {}, absorption: {}, correction: {}, clip: {}, distill_range: {}\n".format( 54 | opt.dataset, opt.quantize, opt.relu, opt.equalize, opt.absorption, opt.correction, opt.clip_weight, opt.distill_range 55 | )) 56 | ww.write("Acc: {}, Acc_class: {}, mIoU: {}, FWIoU: {}\n\n".format(Acc, Acc_class, mIoU, FWIoU)) 57 | 58 | 59 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 60 | rgb_masks = [] 61 | for label_mask in label_masks: 62 | rgb_mask = decode_segmap(label_mask, dataset) 63 | rgb_masks.append(rgb_mask) 64 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 65 | return rgb_masks 66 | 67 | 68 | def decode_segmap(label_mask, dataset, plot=False): 69 | """Decode segmentation class labels into a color image 70 | Args: 71 | label_mask (np.ndarray): an (M,N) array of integer values denoting 72 | the class label at each spatial location. 73 | plot (bool, optional): whether to show the resulting color image 74 | in a figure. 75 | Returns: 76 | (np.ndarray, optional): the resulting decoded color image. 
77 | """ 78 | if dataset == 'pascal' or dataset == 'coco': 79 | n_classes = 21 80 | label_colours = get_pascal_labels() 81 | elif dataset == 'cityscapes': 82 | n_classes = 19 83 | label_colours = get_cityscapes_labels() 84 | else: 85 | raise NotImplementedError 86 | 87 | r = label_mask.copy() 88 | g = label_mask.copy() 89 | b = label_mask.copy() 90 | for ll in range(0, n_classes): 91 | r[label_mask == ll] = label_colours[ll, 0] 92 | g[label_mask == ll] = label_colours[ll, 1] 93 | b[label_mask == ll] = label_colours[ll, 2] 94 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 95 | rgb[:, :, 0] = r / 255.0 96 | rgb[:, :, 1] = g / 255.0 97 | rgb[:, :, 2] = b / 255.0 98 | if plot: 99 | plt.imshow(rgb) 100 | plt.show() 101 | else: 102 | return rgb 103 | 104 | 105 | def encode_segmap(mask): 106 | """Encode segmentation label images as pascal classes 107 | Args: 108 | mask (np.ndarray): raw segmentation label image of dimension 109 | (M, N, 3), in which the Pascal classes are encoded as colours. 110 | Returns: 111 | (np.ndarray): class map with dimensions (M,N), where the value at 112 | a given location is the integer denoting the class index. 113 | """ 114 | mask = mask.astype(int) 115 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 116 | for ii, label in enumerate(get_pascal_labels()): 117 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 118 | label_mask = label_mask.astype(int) 119 | return label_mask 120 | 121 | 122 | def get_cityscapes_labels(): 123 | return np.array([ 124 | [128, 64, 128], 125 | [244, 35, 232], 126 | [70, 70, 70], 127 | [102, 102, 156], 128 | [190, 153, 153], 129 | [153, 153, 153], 130 | [250, 170, 30], 131 | [220, 220, 0], 132 | [107, 142, 35], 133 | [152, 251, 152], 134 | [0, 130, 180], 135 | [220, 20, 60], 136 | [255, 0, 0], 137 | [0, 0, 142], 138 | [0, 0, 70], 139 | [0, 60, 100], 140 | [0, 80, 100], 141 | [0, 0, 230], 142 | [119, 11, 32]]) 143 | 144 | 145 | def get_pascal_labels(): 146 | """Load the mapping that associates pascal classes with label colors 147 | Returns: 148 | np.ndarray with dimensions (21, 3) 149 | """ 150 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 151 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 152 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 153 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 154 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 155 | [0, 64, 128]]) --------------------------------------------------------------------------------
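The two average-precision helpers in utils/detection/measurements.py implement the Pascal VOC conventions described in their docstrings: the area under the interpolated precision-recall curve, and the older 11-point VOC2007 average. A minimal sketch of calling them on a single class, assuming the repository root is on PYTHONPATH; the precision/recall values below are made-up placeholders, not results produced by this repository:

    import numpy as np

    from utils.detection.measurements import (
        compute_average_precision,
        compute_voc2007_average_precision,
    )

    # Made-up precision/recall points for one class, ordered by increasing recall.
    precision = np.array([1.0, 1.0, 0.67, 0.75, 0.60])
    recall = np.array([0.2, 0.4, 0.4, 0.6, 0.6])

    print(compute_average_precision(precision, recall))          # area under the interpolated PR curve
    print(compute_voc2007_average_precision(precision, recall))  # 11-point VOC2007 average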
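The segmentation numbers printed by forward_all in utils/segmentation/utils.py all come from the confusion matrix accumulated by the Evaluator class in utils/metrics.py. A minimal, self-contained sketch of that flow, again assuming the repository root is on PYTHONPATH and using random label maps in place of real PASCAL VOC masks (the batch size, mask shape, and class count below are illustrative assumptions only):

    import numpy as np

    from utils.metrics import Evaluator

    # Placeholder "ground truth" and "prediction" label maps: a batch of two
    # 513x513 masks with 21 PASCAL VOC classes (integer values 0..20).
    gt = np.random.randint(0, 21, size=(2, 513, 513))
    pred = np.random.randint(0, 21, size=(2, 513, 513))

    evaluator = Evaluator(num_class=21)
    evaluator.reset()
    evaluator.add_batch(gt, pred)  # accumulates a 21x21 confusion matrix

    print("Acc:  ", evaluator.Pixel_Accuracy())
    print("mIoU: ", evaluator.Mean_Intersection_over_Union())
    print("FWIoU:", evaluator.Frequency_Weighted_Intersection_over_Union())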