├── LICENSE.txt ├── MANIFEST.in ├── Makefile ├── README.md ├── build_helper.py ├── docs ├── Makefile ├── make.bat └── source │ ├── build_helper.rst │ ├── conf.py │ ├── index.rst │ ├── ncnnqat.rst │ └── setup.rst ├── ncnnqat ├── __init__.py └── quantize.py ├── pyproject.toml ├── setup.cfg ├── setup.py ├── src ├── fake_quantize.cpp ├── fake_quantize.cu └── fake_quantize.h └── tests ├── ssd300 ├── main.py └── src │ ├── __init__.py │ ├── coco.py │ ├── coco_pipeline.py │ ├── data.py │ ├── distributed.py │ ├── evaluate.py │ ├── model.py │ ├── train.py │ └── utils.py └── test_cifar10.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Shisen Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 9 | of the Software, and to permit persons to whom the Software is furnished to do 10 | so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenShisen/ncnnqat/253a413264507cf90089d1aa0e30c0ef30087cfe/MANIFEST.in -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Uncomment for debugging 2 | # DEBUG := 1 3 | # Pretty build 4 | # Q ?= @ 5 | 6 | CXX := g++ 7 | python := python3 8 | PYTHON_HEADER_DIR := $(shell python -c 'from distutils.sysconfig import get_python_inc; print(get_python_inc())') 9 | PYTORCH_INCLUDES := $(shell python -c 'from torch.utils.cpp_extension import include_paths; [print(p) for p in include_paths()]') 10 | PYTORCH_LIBRARIES := $(shell python -c 'from torch.utils.cpp_extension import library_paths; [print(p) for p in library_paths()]') 11 | 12 | CUDA_DIR := $(shell python -c 'from torch.utils.cpp_extension import _find_cuda_home; print(_find_cuda_home())') 13 | WITH_ABI := $(shell python -c 'import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))') 14 | INCLUDE_DIRS := ./ $(CUDA_DIR)/include 15 | INCLUDE_DIRS += $(PYTHON_HEADER_DIR) 16 | INCLUDE_DIRS += $(PYTORCH_INCLUDES) 17 | 18 | # Custom (MKL/ATLAS/OpenBLAS) include and lib directories. 
19 | # BLAS_INCLUDE := /path/to/your/blas 20 | # BLAS_LIB := /path/to/your/blas 21 | 22 | SRC_DIR := ./src 23 | OBJ_DIR := ./obj 24 | CPP_SRCS := $(wildcard $(SRC_DIR)/*.cpp) 25 | CU_SRCS := $(wildcard $(SRC_DIR)/*.cu) 26 | OBJS := $(patsubst $(SRC_DIR)/%.cpp,$(OBJ_DIR)/%.o,$(CPP_SRCS)) 27 | CU_OBJS := $(patsubst $(SRC_DIR)/%.cu,$(OBJ_DIR)/cuda/%.o,$(CU_SRCS)) 28 | #STATIC_LIB := $(OBJ_DIR)/libquant_impl.a 29 | STATIC_LIB := $(OBJ_DIR)/libquant_cuda.a 30 | 31 | 32 | CUDA_ARCH := -gencode arch=compute_50,code=sm_50 \ 33 | -gencode arch=compute_52,code=sm_52 \ 34 | -gencode arch=compute_60,code=sm_60 \ 35 | -gencode arch=compute_61,code=sm_61 \ 36 | -gencode arch=compute_70,code=sm_70 \ 37 | -gencode arch=compute_75,code=sm_75 \ 38 | -gencode arch=compute_75,code=compute_75 39 | 40 | 41 | LIBRARIES += stdc++ cudart c10 caffe2 torch torch_python caffe2_gpu 42 | 43 | 44 | ifeq ($(DEBUG), 1) 45 | COMMON_FLAGS += -DDEBUG -g -O0 46 | NVCCFLAGS += -g -G # -rdc true 47 | else 48 | COMMON_FLAGS += -DNDEBUG -O3 49 | endif 50 | 51 | WARNINGS := -Wall -Wno-sign-compare -Wcomment 52 | INCLUDE_DIRS += $(BLAS_INCLUDE) 53 | CXXFLAGS += -MMD -MP 54 | COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir)) \ 55 | -DTORCH_API_INCLUDE_EXTENSION_H -D_GLIBCXX_USE_CXX11_ABI=$(WITH_ABI) 56 | CXXFLAGS += -pthread -fPIC -fwrapv -std=c++14 $(COMMON_FLAGS) $(WARNINGS) 57 | NVCCFLAGS += -std=c++14 -ccbin=$(CXX) -Xcompiler -fPIC -use_fast_math $(COMMON_FLAGS) 58 | 59 | default: $(STATIC_LIB) 60 | 61 | $(OBJ_DIR): 62 | @ mkdir -p $@ 63 | @ mkdir -p $@/cuda 64 | 65 | $(OBJ_DIR)/%.o: $(SRC_DIR)/%.cpp | $(OBJ_DIR) 66 | @ echo CXX $< 67 | $(Q)$(CXX) $< $(CXXFLAGS) -c -o $@ 68 | 69 | $(OBJ_DIR)/cuda/%.o: $(SRC_DIR)/%.cu | $(OBJ_DIR) 70 | @ echo NVCC $< 71 | $(Q)nvcc $(NVCCFLAGS) $(CUDA_ARCH) -M $< -o ${@:.o=.d} \ 72 | -odir $(@D) 73 | $(Q)nvcc $(NVCCFLAGS) $(CUDA_ARCH) -c $< -o $@ 74 | 75 | $(STATIC_LIB): $(OBJS) $(CU_OBJS) | $(OBJ_DIR) 76 | $(RM) -f $(STATIC_LIB) 77 | $(RM) -rf build dist 78 | @ echo LD -o $@ 79 | ar rc $(STATIC_LIB) $(OBJS) $(CU_OBJS) 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | build: 88 | $(python) setup.py build 89 | 90 | upload: 91 | $(python) setup.py sdist bdist_wheel 92 | 93 | clean: 94 | $(RM) -rf build dist ncnnqat.egg-info 95 | 96 | test: 97 | nosetests -s tests/test_merge_freeze_bn.py --nologcapture 98 | 99 | lint: 100 | pylint ncnnqat --reports=n 101 | 102 | lintfull: 103 | pylint ncnnqat 104 | 105 | install: 106 | $(python) setup.py install 107 | 108 | uninstall: 109 | $(python) setup.py install --record install.log 110 | cat install.log | xargs rm -rf 111 | $(RM) install.log 112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # ncnnqat 4 | 5 | ncnnqat is a quantization-aware training (QAT) package for NCNN on PyTorch. 6 | 7 |
8 | 9 | ## Table of Contents 10 | 11 | - [ncnnqat](#ncnnqat) 12 | - [Table of Contents](#table-of-contents) 13 | - [Installation](#installation) 14 | - [Usage](#usage) 15 | - [Code Examples](#code-examples) 16 | - [Results](#results) 17 | - [Todo](#todo) 18 | 19 | 20 |
21 | 22 | ## Installation 23 | 24 | * Supported Platforms: Linux 25 | * Accelerators and GPUs: NVIDIA GPUs via CUDA driver ***10.1***. 26 | * Dependencies: 27 | * python >= 3.5, < 4 28 | * pytorch >= 1.6 29 | * numpy >= 1.18.1 30 | * onnx >= 1.7.0 31 | * onnx-simplifier >= 0.3.6 32 | 33 | * Install ncnnqat from PyPI: 34 | ```shell 35 | $ pip install ncnnqat (to do....) 36 | ``` 37 | Installing from source is currently recommended. 38 | * or install ncnnqat from the repository: 39 | ```shell 40 | $ git clone https://github.com/ChenShisen/ncnnqat 41 | $ cd ncnnqat 42 | $ make install 43 | ``` 44 | 45 |
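After `make install`, you can optionally sanity-check that the CUDA extension was built and is importable. A minimal sketch (it assumes a visible NVIDIA GPU; `quant_cuda` is the extension module name used by this repo's `setup.py`):

```python
# Quick post-install check: the Python package and the compiled CUDA op should both import.
import torch
import quant_cuda                      # compiled from src/fake_quantize.cpp / .cu
from ncnnqat import register_quantization_hook, merge_freeze_bn, save_table

print(torch.cuda.is_available())             # expected: True
print(hasattr(quant_cuda, "fake_quantize"))  # expected: True (op bound in fake_quantize.cpp)
```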
46 | 47 | ## Usage 48 | 49 | 50 | * register_quantization_hook and merge_freeze_bn 51 | 52 | (Fine-tuning from a well-trained model is suggested; otherwise, enable quantization only after a few epochs of float training.) 53 | 54 | ```python 55 | from ncnnqat import unquant_weight, merge_freeze_bn, register_quantization_hook 56 | ... 57 | ... 58 | for epoch in range(epoch_train): 59 | model.train() 60 | if epoch==well_epoch: 61 | register_quantization_hook(model) 62 | if epoch>=well_epoch: 63 | model = merge_freeze_bn(model) # switches BN to eval() mode during training 64 | ... 65 | ``` 66 | 67 | * Unquantize weights before updating them 68 | 69 | ```python 70 | ... 71 | ... 72 | if epoch>=well_epoch: 73 | model.apply(unquant_weight) # restore the original float weights for the optimizer update 74 | optimizer.step() 75 | ... 76 | ``` 77 | 78 | * Save the weights and the NCNN quantization table after training 79 | 80 | 81 | ```python 82 | ... 83 | ... 84 | onnx_path = "./xxx/model.onnx" 85 | table_path="./xxx/model.table" 86 | dummy_input = torch.randn(1, 3, img_size, img_size, device='cuda') 87 | input_names = [ "input" ] 88 | output_names = [ "fc" ] 89 | torch.onnx.export(model, dummy_input, onnx_path, verbose=False, input_names=input_names, output_names=output_names) 90 | save_table(model,onnx_path=onnx_path,table=table_path) 91 | 92 | ... 93 | ``` 94 | If you wrap the model with `model = nn.DataParallel(model)`, PyTorch cannot export it with `torch.onnx.export` directly. Save the `state_dict` first, load it into a fresh single-GPU model, and then export the ONNX model. 95 | 96 | ```python 97 | ... 98 | ... 99 | model_s = new_net() # a fresh single-GPU instance of the network 100 | model_s.cuda() 101 | register_quantization_hook(model_s) 102 | #model_s = merge_freeze_bn(model_s) 103 | onnx_path = "./xxx/model.onnx" 104 | table_path="./xxx/model.table" 105 | dummy_input = torch.randn(1, 3, img_size, img_size, device='cuda') 106 | input_names = [ "input" ] 107 | output_names = [ "fc" ] 108 | model_s.load_state_dict({k.replace('module.',''):v for k,v in model.state_dict().items()}) # strip the 'module.' prefix added by nn.DataParallel 109 | 110 | torch.onnx.export(model_s, dummy_input, onnx_path, verbose=False, input_names=input_names, output_names=output_names) 111 | save_table(model_s,onnx_path=onnx_path,table=table_path) 112 | 113 | 114 | ... 115 | ``` 116 | A consolidated end-to-end sketch is shown below. 117 | 118 |
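Putting the fragments above together, a minimal end-to-end training loop could look like the following sketch. It is illustrative only: `Net`, `train_loader`, `epoch_train`, `well_epoch`, and `img_size` are placeholders for your own model, data, and schedule, not part of this package.

```python
import torch
from ncnnqat import unquant_weight, merge_freeze_bn, register_quantization_hook, save_table

model = Net().cuda()                                    # placeholder: your own network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(epoch_train):
    model.train()
    if epoch == well_epoch:                             # start QAT once the float model is reasonably trained
        register_quantization_hook(model)
    if epoch >= well_epoch:
        model = merge_freeze_bn(model)                  # fold BN into conv and keep BN in eval() mode
    for data, target in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(data.cuda()), target.cuda())
        loss.backward()
        if epoch >= well_epoch:
            model.apply(unquant_weight)                 # restore float weights before the optimizer update
        optimizer.step()

# After training: export ONNX and write the NCNN quantization table.
dummy_input = torch.randn(1, 3, img_size, img_size, device='cuda')
torch.onnx.export(model, dummy_input, "model.onnx", verbose=False,
                  input_names=["input"], output_names=["fc"])
save_table(model, onnx_path="model.onnx", table="model.table")
```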
119 | 120 | ## Code Examples 121 | 122 | CIFAR-10 quantization-aware training example: 123 | 124 | ```python tests/test_cifar10.py``` 125 | 126 | SSD300 quantization-aware training example: 127 | 128 | ``` 129 | ln -s /your_coco_path/coco ./tests/ssd300/data 130 | ``` 131 | ``` 132 | python -m torch.distributed.launch \ 133 | --nproc_per_node=4 \ 134 | --nnodes=1 \ 135 | --node_rank=0 \ 136 | ./tests/ssd300/main.py \ 137 | -d ./tests/ssd300/data/coco 138 | ``` 139 | ``` 140 | python ./tests/ssd300/main.py --onnx_save # load the model state dict, export the ONNX model and the NCNN table 141 | ``` 142 | 143 |
144 | 145 | ## Results 146 | 147 | * Cifar10 148 | 149 | 150 | result: 151 | 152 | | net | fp32(onnx) | ncnnqat | ncnn aciq | ncnn kl | 153 | | -------- | -------- | -------- | -------- | -------- | 154 | | mobilenet_v2 | 0.91 | 0.9066 | 0.9033 | 0.9066 | 155 | | resnet18 | 0.94 | 0.93333 | 0.9367 | 0.937| 156 | 157 | 158 | * SSD300(resnet18|coco) 159 | 160 | 161 | ``` 162 | fp32: 163 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.193 164 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.344 165 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.191 166 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.042 167 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.195 168 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.328 169 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.199 170 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.293 171 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.309 172 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.084 173 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.326 174 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.501 175 | Current AP: 0.19269 176 | 177 | ncnnqat: 178 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.192 179 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.342 180 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.194 181 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.041 182 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.194 183 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.327 184 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.197 185 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.291 186 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.307 187 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.082 188 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.325 189 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.497 190 | Current AP: 0.19202 191 | ``` 192 | 193 | 194 |
195 | 196 | ## Todo 197 | 198 | .... 199 | -------------------------------------------------------------------------------- /build_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | import sys 5 | import tempfile 6 | from distutils import ccompiler 7 | 8 | 9 | def print_warning(*lines): 10 | print('**************************************************') 11 | for line in lines: 12 | print('*** WARNING: %s' % line) 13 | print('**************************************************') 14 | 15 | 16 | def get_path(key): 17 | return os.environ.get(key, '').split(os.pathsep) 18 | 19 | 20 | def search_on_path(filenames): 21 | for p in get_path('PATH'): 22 | for filename in filenames: 23 | full = os.path.join(p, filename) 24 | if os.path.exists(full): 25 | return os.path.abspath(full) 26 | return None 27 | 28 | 29 | minimum_cuda_version = 10010 30 | maxinum_cuda_version = 10030 31 | minimum_cudnn_version = 7000 32 | 33 | 34 | def get_compiler_setting(): 35 | nvcc_path = search_on_path(('nvcc', 'nvcc.exe')) 36 | cuda_path_default = None 37 | if nvcc_path is None: 38 | print_warning('nvcc not in path.', 'Please set path to nvcc.') 39 | else: 40 | cuda_path_default = os.path.normpath( 41 | os.path.join(os.path.dirname(nvcc_path), '..')) 42 | 43 | cuda_path = os.environ.get('CUDA_PATH', '') # Nvidia default on Windows 44 | if len(cuda_path) > 0 and cuda_path != cuda_path_default: 45 | print_warning('nvcc path != CUDA_PATH', 46 | 'nvcc path: %s' % cuda_path_default, 47 | 'CUDA_PATH: %s' % cuda_path) 48 | 49 | if not os.path.exists(cuda_path): 50 | cuda_path = cuda_path_default 51 | 52 | if not cuda_path and os.path.exists('/usr/local/cuda'): 53 | cuda_path = '/usr/local/cuda' 54 | 55 | include_dirs = [] 56 | library_dirs = [] 57 | define_macros = [] 58 | 59 | if cuda_path: 60 | include_dirs.append(os.path.join(cuda_path, 'include')) 61 | if sys.platform == 'win32': 62 | library_dirs.append(os.path.join(cuda_path, 'bin')) 63 | library_dirs.append(os.path.join(cuda_path, 'lib', 'x64')) 64 | else: 65 | library_dirs.append(os.path.join(cuda_path, 'lib64')) 66 | library_dirs.append(os.path.join(cuda_path, 'lib')) 67 | if sys.platform == 'darwin': 68 | library_dirs.append('/usr/local/cuda/lib') 69 | 70 | return { 71 | 'include_dirs': include_dirs, 72 | 'library_dirs': library_dirs, 73 | 'define_macros': define_macros, 74 | 'language': 'c++', 75 | } 76 | 77 | 78 | def check_cuda_version(): 79 | compiler = ccompiler.new_compiler() 80 | settings = get_compiler_setting() 81 | try: 82 | out = build_and_run(compiler, 83 | ''' 84 | #include 85 | #include 86 | int main(int argc, char* argv[]) { 87 | printf("%d", CUDA_VERSION); 88 | return 0; 89 | } 90 | ''', 91 | include_dirs=settings['include_dirs']) 92 | 93 | except Exception as e: 94 | print_warning('Cannot check CUDA version', str(e)) 95 | return False 96 | 97 | cuda_version = int(out) 98 | if cuda_version < minimum_cuda_version: 99 | print_warning('CUDA version is too old: %d' % cuda_version, 100 | 'CUDA v10.1 or CUDA v10.2 is required') 101 | return False 102 | if cuda_version > maxinum_cuda_version: 103 | print_warning('CUDA version is too new: %d' % cuda_version, 104 | 'CUDA v10.1 or CUDA v10.2 is required') 105 | 106 | return True 107 | 108 | 109 | def check_cudnn_version(): 110 | compiler = ccompiler.new_compiler() 111 | settings = get_compiler_setting() 112 | try: 113 | out = build_and_run(compiler, 114 | ''' 115 | #include 116 | #include 117 | int main(int argc, 
char* argv[]) { 118 | printf("%d", CUDNN_VERSION); 119 | return 0; 120 | } 121 | ''', 122 | include_dirs=settings['include_dirs']) 123 | 124 | except Exception as e: 125 | print_warning('Cannot check cuDNN version\n{0}'.format(e)) 126 | return False 127 | 128 | cudnn_version = int(out) 129 | if cudnn_version < minimum_cudnn_version: 130 | print_warning('cuDNN version is too old: %d' % cudnn_version, 131 | 'cuDNN v7 or newer is required') 132 | return False 133 | 134 | return True 135 | 136 | 137 | def build_and_run(compiler, 138 | source, 139 | libraries=(), 140 | include_dirs=(), 141 | library_dirs=()): 142 | temp_dir = tempfile.mkdtemp() 143 | 144 | try: 145 | fname = os.path.join(temp_dir, 'a.cpp') 146 | with open(fname, 'w') as f: 147 | f.write(source) 148 | 149 | objects = compiler.compile([fname], 150 | output_dir=temp_dir, 151 | include_dirs=include_dirs) 152 | 153 | try: 154 | postargs = ['/MANIFEST'] if sys.platform == 'win32' else [] 155 | compiler.link_executable(objects, 156 | os.path.join(temp_dir, 'a'), 157 | libraries=libraries, 158 | library_dirs=library_dirs, 159 | extra_postargs=postargs, 160 | target_lang='c++') 161 | except Exception as e: 162 | msg = 'Cannot build a stub file.\nOriginal error: {0}'.format(e) 163 | raise Exception(msg) 164 | 165 | try: 166 | out = subprocess.check_output(os.path.join(temp_dir, 'a')) 167 | return out 168 | 169 | except Exception as e: 170 | msg = 'Cannot execute a stub file.\nOriginal error: {0}'.format(e) 171 | raise Exception(msg) 172 | 173 | finally: 174 | shutil.rmtree(temp_dir, ignore_errors=True) 175 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/build_helper.rst: -------------------------------------------------------------------------------- 1 | build\_helper module 2 | ==================== 3 | 4 | .. automodule:: build_helper 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | import os 4 | import sys 5 | sys.path.insert(0, os.path.abspath('./../../')) 6 | 7 | 8 | # -- Project information ----------------------------------------------------- 9 | 10 | project = 'ncnnqat' 11 | copyright = '2021, Shisen Chen' 12 | author = 'Shisen Chen' 13 | 14 | # The short X.Y version 15 | version = '' 16 | # The full version, including alpha/beta/rc tags 17 | release = '0.1.0' 18 | 19 | 20 | # -- General configuration --------------------------------------------------- 21 | 22 | # If your documentation needs a minimal Sphinx version, state it here. 23 | # 24 | # needs_sphinx = '1.0' 25 | 26 | # Add any Sphinx extension module names here, as strings. They can be 27 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 28 | # ones. 29 | extensions = [ 30 | 'sphinx.ext.todo', 31 | 'sphinx.ext.githubpages', 32 | 'sphinx.ext.autodoc', 33 | ] 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ['_templates'] 37 | 38 | # The suffix(es) of source filenames. 39 | # You can specify multiple suffix as a list of string: 40 | # 41 | # source_suffix = ['.rst', '.md'] 42 | source_suffix = '.rst' 43 | 44 | # The master toctree document. 45 | master_doc = 'index' 46 | 47 | # The language for content autogenerated by Sphinx. Refer to documentation 48 | # for a list of supported languages. 49 | # 50 | # This is also used if you do content translation via gettext catalogs. 51 | # Usually you set "language" from the command line for these cases. 52 | language = None 53 | 54 | # List of patterns, relative to source directory, that match files and 55 | # directories to ignore when looking for source files. 56 | # This pattern also affects html_static_path and html_extra_path . 57 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 58 | 59 | # The name of the Pygments (syntax highlighting) style to use. 60 | pygments_style = 'sphinx' 61 | 62 | 63 | # -- Options for HTML output ------------------------------------------------- 64 | 65 | # The theme to use for HTML and HTML Help pages. See the documentation for 66 | # a list of builtin themes. 67 | # 68 | 69 | # Theme options are theme-specific and customize the look and feel of a theme 70 | # further. For a list of options available for each theme, see the 71 | # documentation. 72 | # 73 | # html_theme_options = {} 74 | 75 | # Add any paths that contain custom static files (such as style sheets) here, 76 | # relative to this directory. They are copied after the builtin static files, 77 | # so a file named "default.css" will overwrite the builtin "default.css". 
78 | html_static_path = ['_static'] 79 | 80 | # Custom sidebar templates, must be a dictionary that maps document names 81 | # to template names. 82 | # 83 | # The default sidebars (for documents that don't match any pattern) are 84 | # defined by theme itself. Builtin themes are using these templates by 85 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 86 | # 'searchbox.html']``. 87 | # 88 | # html_sidebars = {} 89 | html_theme = 'sphinx_rtd_theme' 90 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. ncnnqat documentation master file, created by 2 | sphinx-quickstart on Fri Aug 21 03:52:34 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to ncnnqat's documentation! 7 | =================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | 14 | 15 | Indices and tables 16 | ================== 17 | 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /docs/source/ncnnqat.rst: -------------------------------------------------------------------------------- 1 | ncnnqat package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | ncnnqat 10 | 11 | Module contents 12 | --------------- 13 | 14 | .. automodule:: ncnnqat 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | -------------------------------------------------------------------------------- /docs/source/setup.rst: -------------------------------------------------------------------------------- 1 | setup module 2 | ============ 3 | 4 | .. automodule:: setup 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /ncnnqat/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | try: 3 | from .quantize import unquant_weight, freeze_bn, \ 4 | merge_freeze_bn, register_quantization_hook,save_table 5 | except: 6 | raise 7 | __all__ = [ 8 | "unquant_weight", "freeze_bn", "merge_freeze_bn", \ 9 | "register_quantization_hook","save_table"] 10 | 11 | -------------------------------------------------------------------------------- /ncnnqat/quantize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import logging 4 | import torch 5 | import numpy as np 6 | import onnx 7 | 8 | from quant_cuda import fake_quantize 9 | 10 | class FakeQuantCuda(): 11 | r""" 12 | """ 13 | def __init__(self, 14 | bit_width=8, 15 | type=1, 16 | c=1 17 | ): 18 | 19 | self._bit_width = bit_width 20 | self._type = type 21 | self._c = c 22 | 23 | 24 | def __call__(self, tensor,tensor_scale,tensor_movMax=None, aciq=0): #type=0,1,2=pre_conv_activate,w,after_conv_activate 25 | r""" Converts float weights to quantized weights. 
26 | 27 | Args: 28 | - tensor: input data (weight or activation tensor). 29 | - tensor_scale: quantization scale of the tensor (updated from the kernel output). 30 | - tensor_movMax: moving maximum of the tensor (used for activations, type 0). 31 | - aciq: use the ACIQ clipping method when set; off by default (max-based scaling is used). 32 | """ 33 | 34 | #print(self._type,self._bit_width) 35 | #tensor.data = fake_quantize_c(tensor.data.detach().clone(),tensor_s.data.detach().clone(),self._bit_width,self._type) 36 | 37 | out = fake_quantize(tensor.data.detach().clone(),self._bit_width,self._type,self._c,aciq) 38 | tensor.data = out[0] 39 | tensor_scale.data = out[1] 40 | if self._type==0: 41 | tensor_movMax.data = out[2] 42 | #print("tensor_scale",tensor_scale) 43 | 44 | return tensor,tensor_scale,tensor_movMax 45 | 46 | 47 | 48 | 49 | def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b): 50 | """ Fuse convolution and batch norm weights. 51 | 52 | Args: 53 | conv_w (torch.nn.Parameter): convolution weight. 54 | conv_b (torch.nn.Parameter): convolution bias. 55 | bn_rm (torch.nn.Parameter): batch norm running mean. 56 | bn_rv (torch.nn.Parameter): batch norm running variance. 57 | bn_eps (torch.nn.Parameter): batch norm epsilon. 58 | bn_w (torch.nn.Parameter): batch norm weight. 59 | bn_b (torch.nn.Parameter): batch norm bias. 60 | 61 | Returns: 62 | conv_w(torch.nn.Parameter): fused convolution weight. 63 | conv_b(torch.nn.Parameter): fused convolution bias. 64 | """ 65 | 66 | if conv_b is None: 67 | conv_b = bn_rm.new_zeros(bn_rm.shape) 68 | bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) 69 | 70 | conv_w = conv_w * \ 71 | (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1)) 72 | conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b 73 | 74 | return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b) 75 | 76 | 77 | def _fuse_conv_bn(conv, bn): 78 | conv.weight, conv.bias = \ 79 | _fuse_conv_bn_weights(conv.weight, conv.bias, 80 | bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias) 81 | return conv 82 | 83 | 84 | def _fuse_modules(model): 85 | r"""Fuses conv/bn module pairs at the weight level. 86 | 87 | Fuses only the following sequence of modules: 88 | conv, bn 89 | All other sequences are left unchanged. 90 | For these sequences, modules are fused at the weight level; the model structure is kept unchanged. 91 | 92 | Arguments: 93 | model: Model containing the modules to be fused 94 | 95 | Returns: 96 | model with fused modules.
97 | 98 | """ 99 | children = list(model.named_children()) 100 | conv_module = None 101 | conv_name = None 102 | 103 | for name, child in children: 104 | if isinstance(child, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, 105 | torch.nn.BatchNorm3d)): 106 | if isinstance(conv_module, (torch.nn.Conv2d, torch.nn.Conv3d)): 107 | conv_module = _fuse_conv_bn(conv_module, child) 108 | model._modules[conv_name] = conv_module 109 | child.eval() 110 | child.running_mean = child.running_mean.new_full( 111 | child.running_mean.shape, 0) 112 | child.running_var = child.running_var.new_full( 113 | child.running_var.shape, 1) 114 | 115 | if child.weight is not None: 116 | child.weight.data = child.weight.data.new_full( 117 | child.weight.shape, 1) 118 | if child.bias is not None: 119 | child.bias.data = child.bias.data.new_full( 120 | child.bias.shape, 0) 121 | #print(child,child.bias) 122 | child.track_running_stats = False 123 | child.momentum = 0 124 | child.eps = 0 125 | #child.affine = False 126 | conv_module = None 127 | elif isinstance(child, (torch.nn.Conv2d, torch.nn.Conv3d)): 128 | conv_module = child 129 | conv_name = name 130 | else: 131 | _fuse_modules(child) 132 | return model 133 | 134 | 135 | def freeze_bn(m, freeze_bn_affine=True): 136 | """Freeze batch normalization. 137 | reference: https://arxiv.org/abs/1806.08342 138 | 139 | 140 | Args: 141 | - m (nn.module): torch module 142 | - freeze_bn_affine (bool, optional): Freeze affine scale and 143 | translation factor or not. Defaults: True. 144 | """ 145 | 146 | if isinstance( 147 | m, 148 | (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)): 149 | 150 | m.eval() 151 | if freeze_bn_affine: 152 | m.weight.requires_grad = False 153 | m.bias.requires_grad = False 154 | 155 | 156 | def merge_freeze_bn(model): 157 | """merge batch norm's weight into convolution, then freeze it. 158 | 159 | Args: 160 | model (nn.module): model. 161 | 162 | Returns: 163 | [nn.module]: model. 164 | """ 165 | model = _fuse_modules(model) #merge conv bn ; mean 0 std 1 gama 1 beta 0 166 | model.apply(freeze_bn) # bn backward = false,bn not train 167 | return model 168 | 169 | 170 | def unquant_weight(m): 171 | """ unquantize weight before update weight, avoid training turbulence. 172 | 173 | Args: 174 | - m (nn.module): torch module. 175 | """ 176 | try: 177 | if hasattr(m, "weight_origin") and m.weight is not None: 178 | m.weight.data.copy_(m.weight_origin.data) 179 | except AttributeError: 180 | pass 181 | except TypeError: 182 | pass 183 | 184 | ''' 185 | def quant_dequant_weight(m): 186 | """ quant weight manually. 187 | 188 | Args: 189 | - m (nn.module): torch module. 190 | """ 191 | quant_handle = FakeQuantCuda() 192 | try: 193 | if hasattr(m, "weight_origin") and m.weight is not None: 194 | m.weight_origin.data.copy_(m.weight.data) 195 | m.weight.data = quant_handle(m.weight.data.detach().clone()) 196 | except AttributeError: 197 | pass 198 | except TypeError: 199 | pass 200 | ''' 201 | 202 | def _quantizing_activation_ncnn(module, input): 203 | """ quantize per-layer activation(input of layer) before layer calculate. 204 | 205 | Args: 206 | - module (nn.module): torch module. 207 | - input : layer input(tuple) ,torch tensor (nchw or n**). 
208 | """ 209 | #GOOGLE QAT movMax = movMax*momenta + max(abs(tensor))*(1-momenta) momenta = 0.95 210 | #print("input.shape",input[0].shape) 211 | aciq = 0 212 | quant_handle = FakeQuantCuda(type=0,bit_width=8,c=1) 213 | list_modified = [] 214 | if isinstance(input, tuple): 215 | for item in input: 216 | aciq = 0 217 | item_type = item.dtype 218 | if item.numel()/item.shape[0]>8000: 219 | aciq = 1 220 | #quant_tuple = quant_handle(item.float(),module.activation_scale.data.detach().clone()) 221 | quant_tuple = quant_handle(item.float(),module.activation_scale.data.detach().clone(),tensor_movMax=module.activation_movMax.data.detach().clone(),aciq=aciq) 222 | item = quant_tuple[0] 223 | if item.dtype!=item_type: 224 | #print(item.dtype,item_type) 225 | item.to(item_type) 226 | module.activation_scale.data = quant_tuple[1] 227 | module.activation_movMax.data = quant_tuple[2] 228 | #print(quant_tuple[2]) 229 | list_modified.append(item) 230 | 231 | else: 232 | input_type = input.dtype 233 | if input.numel()/input.shape[0]>8000: 234 | aciq = 1 235 | #quant_tuple = quant_handle(input.float(),module.activation_scale.data.detach().clone()) 236 | quant_tuple = quant_handle(input.float(),module.activation_scale.data.detach().clone(),tensor_movMax=module.activation_movMax.data.detach().clone(),aciq=aciq) 237 | input = quant_tuple[0] 238 | module.activation_scale.data = quant_tuple[1] 239 | module.activation_movMax.data = quant_tuple[2] 240 | if input.dtype!=input_type: 241 | input.to(input_type) 242 | list_modified.append(input) 243 | tuple_input = tuple(list_modified) 244 | return tuple_input 245 | def _quantizing_weight_ncnn(module, input): 246 | """ quantize per-channel weight before layer calculate. 247 | 248 | Args: 249 | - module (nn.module): torch module. 250 | - input : layer input(tuple) ,torch tensor (nchw or n**). 251 | """ 252 | module_shape = module.weight.shape 253 | #print("module_shape",module_shape) 254 | channel = module_shape[0] #oikk 255 | if isinstance(module,(torch.nn.Conv2d)) and module.groups!=1: #depthwise 256 | channel = module.groups 257 | bit_width = 8 258 | if isinstance(module,(torch.nn.Conv2d)) and module.stride==(1,1) and module.dilation==(1,1) and module.kernel_size==(3,3) and module.groups==1: #winnograd f(4,3) 259 | bit_width=6 260 | 261 | aciq = 0 262 | weight_numel = module.weight.numel() 263 | if weight_numel/channel>8000: #when > 8000 , max_var > threshold 264 | aciq = 1 265 | #print("aciq",aciq,module) 266 | 267 | 268 | quant_handle = FakeQuantCuda(type=1,bit_width=bit_width,c=channel) 269 | # print("quantizing weight.") 270 | # print(module.weight[0][0][0]) 271 | module.weight_origin.data.copy_(module.weight.data) #copy float data to a new place 272 | 273 | quant_tuple = quant_handle(module.weight.data.detach().clone(),module.weight_scale.data.detach().clone(),aciq=aciq)#把原始数据 quant——dequant 此时数据是有损的,计算损失后,把备份数据考回原处做梯度计算 274 | module.weight.data = quant_tuple[0] 275 | module.weight_scale.data = quant_tuple[1] 276 | # print(module.weight[0][0][0]) 277 | #print(module.weight_scale) 278 | 279 | 280 | def register_quantization_hook(model, 281 | quant_weight=True, 282 | quant_activation=True, 283 | ): 284 | """register quantization hook for model. 285 | 286 | Args: 287 | model (:class:`Module`): Module. 288 | 289 | Returns: 290 | Module: self 291 | """ 292 | 293 | # weight quantizing. 
294 | logger = logging.getLogger(__name__) 295 | logger.setLevel(logging.INFO) 296 | 297 | for _, module in model._modules.items(): 298 | #print("module",module) 299 | if len(list(module.children())) > 0: 300 | register_quantization_hook(module, quant_weight, quant_activation) 301 | else: 302 | if quant_weight and hasattr(module,"weight") and module.weight is not None and isinstance( 303 | module, (torch.nn.Conv2d,torch.nn.Linear)): 304 | module.register_buffer('weight_origin', module.weight.detach().clone()) #数据备份空间 305 | #module.register_buffer("weight_scale", torch.ones([1,model._modules["conv1"].weight.shape[0]], dtype=torch.float).cuda()) #weight scale 306 | #module.register_buffer("weight_scale", torch.ones([1,module.weight.shape[0]], dtype=torch.float).cuda()) #weight scale module.weight.shape =[o,i,k,k] 307 | module.register_buffer("weight_scale", torch.ones([module.weight.shape[0]], dtype=torch.float).cuda()) #weight scale module.weight.shape =[o,i,k,k] 308 | 309 | 310 | module.register_forward_pre_hook(_quantizing_weight_ncnn) 311 | logger.info("Quantizing weight of %s", str(module)) 312 | 313 | 314 | module.register_buffer("activation_scale", torch.tensor([1], dtype=torch.float).cuda()) 315 | module.register_buffer("activation_movMax", torch.tensor([1], dtype=torch.float).cuda()) 316 | #module.register_buffer("activation_momenta", torch.tensor([1], dtype=torch.float).cuda()) 317 | module.register_forward_pre_hook(_quantizing_activation_ncnn) 318 | logger.info("Quantizing activation of %s", str(module)) 319 | 320 | return model 321 | 322 | def save_table(torch_model,onnx_path="model.onnx",table="model.table"): 323 | f = open(table,"w",encoding='utf8') 324 | static_dict_org = torch_model.state_dict() 325 | static_dict = {k.replace('module.',''):v for k,v in static_dict_org.items()} 326 | 327 | 328 | model = onnx.load(onnx_path) 329 | node = model.graph.node 330 | node_num = len(node) 331 | 332 | tail_layer = "_param_0" 333 | split_char = " " 334 | tab_char = "\n" 335 | tail_len = 6 336 | for each in range(node_num): 337 | if node[each].op_type not in ["Conv","Gemm"]: 338 | continue 339 | #print(node[each].op_type) 340 | pre_name = node[each].input[1] 341 | #print(pre_name) 342 | #print(pre_name.replace(pre_name.split(".")[-1],"weight_scale")) 343 | scale_data = static_dict[pre_name.replace(pre_name.split(".")[-1],"weight_scale")] 344 | list_scale = scale_data.cpu().numpy().flatten().tolist() 345 | #print(node[each].name,node[each].op_type,node[each].input) 346 | f.write(node[each].name + tail_layer) 347 | for d in list_scale: 348 | d = float(d) 349 | f.write(split_char + "{:.6f}".format(d)) 350 | f.write(tab_char) 351 | for each in range(node_num): 352 | if node[each].op_type not in ["Conv","Gemm"]: 353 | continue 354 | pre_name = node[each].input[1] 355 | scale_data = static_dict[pre_name.replace(pre_name.split(".")[-1],"activation_scale")] 356 | list_scale = scale_data.cpu().numpy().flatten().tolist() 357 | #print(node[each].name,node[each].op_type,node[each].input) 358 | f.write(node[each].name) 359 | for d in list_scale: 360 | d = float(d) 361 | f.write(split_char + "{:.6f}".format(d)) 362 | 363 | f.write(tab_char) 364 | f.close() 365 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=40.8.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | license_files = LICENSE.txt 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import pathlib 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | from build_helper import check_cuda_version 6 | assert(check_cuda_version()) 7 | 8 | import os 9 | os.system('make -j%d' % os.cpu_count()) 10 | 11 | here = pathlib.Path(__file__).parent.resolve() 12 | long_description = (here / 'README.md').read_text(encoding='utf-8') 13 | 14 | setup( 15 | name='ncnnqat', 16 | version='0.1.0', 17 | description='A ncnn quantization aware training tool on pytorch.', 18 | long_description=long_description, 19 | long_description_content_type='text/markdown', 20 | url='https://github.com/ChenShisen/ncnnqat', 21 | author='Shisen Chen', 22 | author_email='napoleo54css@gmail.com', 23 | license='MIT', 24 | classifiers=[ 25 | 'Development Status :: 5 - Production/Stable', 26 | "Intended Audience :: Science/Research", 27 | 'Intended Audience :: Developers', 28 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 29 | "Topic :: Software Development :: Libraries :: Python Modules", 30 | 'License :: OSI Approved :: MIT License', 31 | 'Programming Language :: Python :: 3', 32 | 'Programming Language :: Python :: 3.5', 33 | 'Programming Language :: Python :: 3.6', 34 | 'Programming Language :: Python :: 3.7', 35 | 'Programming Language :: Python :: 3.8', 36 | 'Programming Language :: Python :: 3 :: Only', 37 | ], 38 | keywords=[ 39 | "ncnn" 40 | "quantization aware training", 41 | "deep learning", 42 | "neural network", 43 | "CNN", 44 | "machine learning", 45 | ], 46 | packages=find_packages(), 47 | 48 | python_requires='>=3.5, <4', 49 | install_requires=[ 50 | "torch >= 1.5", 51 | "numpy >= 1.18.1", 52 | "onnx >= 1.7.0", 53 | "onnx-simplifier >= 0.3.6" 54 | ], 55 | extras_require={ 56 | 'test': ["torchvision>=0.4", 57 | "nose", 58 | "ddt" 59 | ], 60 | 'docs': [ 61 | 'sphinx==2.4.4', 62 | 'sphinx_rtd_theme' 63 | ] 64 | }, 65 | ext_modules=[ 66 | CUDAExtension( 67 | #name="quant_impl", 68 | name="quant_cuda", 69 | sources=[ 70 | "./src/fake_quantize.cpp", 71 | ], 72 | libraries=['quant_cuda'], 73 | library_dirs=['obj'], 74 | ) 75 | ], 76 | cmdclass={'build_ext': BuildExtension}, 77 | #test_suite="ncnnqat.test.test_cifar10", 78 | ) 79 | -------------------------------------------------------------------------------- /src/fake_quantize.cpp: -------------------------------------------------------------------------------- 1 | #include "fake_quantize.h" 2 | 3 | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") 4 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 5 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 6 | 7 | std::vector fake_quantize(Tensor a, int bit_width,int type,int c,int aciq){ 8 | CHECK_INPUT(a); 9 | return fake_quantize_cuda(a, bit_width,type,c,aciq); 10 | } 11 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ 12 | m.def("fake_quantize", &fake_quantize, "NCNN Fake Quantization (CUDA)"); 13 | } 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /src/fake_quantize.cu: 
-------------------------------------------------------------------------------- 1 | #include "fake_quantize.h" 2 | 3 | 4 | __global__ void max_reduce(float* __restrict__ data,float* out_ptr,int width,int lg_n) //preset data[i] >=0 5 | { 6 | __shared__ float* middleware[blockSize]; 7 | const float min_positive_float = 1e-6; 8 | int row = blockIdx.x * width + threadIdx.x; 9 | int bid = blockIdx.x; 10 | int tid = threadIdx.x; 11 | int tid_tmp = threadIdx.x; 12 | 13 | //if(tid*(middleware[tid])) middleware[tid] = &(data[Row]); 21 | if(fabs(data[Row])>fabs(*(middleware[tid]))) middleware[tid] = data+row; 22 | row+=blockSize; 23 | tid_tmp +=blockSize; 24 | } 25 | __syncthreads(); 26 | 27 | //for(int i=blockSize/2; i>0; i/=2) 28 | for(int i=lg_n/2; i>0; i/=2) 29 | { 30 | if(tidfabs(*(middleware[tid]))) middleware[tid]=middleware[tid+i]; 33 | } 34 | __syncthreads(); 35 | } 36 | 37 | if(tid==0) out_ptr[bid] = fabs(*(middleware[0])); 38 | } 39 | __global__ void fake_quantize_layer_google(float* __restrict__ a, 40 | float* o, 41 | float* o1, 42 | float* mov_max, 43 | int size, 44 | int bit_width, 45 | float* max_entry) 46 | { 47 | int index = blockIdx.x * blockDim.x + threadIdx.x; 48 | if (index < size) 49 | { 50 | const float momenta = 0.95; 51 | float mov_max_tmp = mov_max[0]; 52 | if(mov_max_tmp<1e-6) mov_max_tmp=fabs(*max_entry); //movMax dafault 0 ,now first step set it a non zero data 53 | else mov_max_tmp= mov_max_tmp * momenta + fabs(*max_entry) * (1.-momenta); // #GOOGLE QAT : movMax = movMax*momenta + max(abs(tensor))*(1-momenta) momenta = 0.95 54 | float data_scale = __powf(2.,bit_width-1.)-1; 55 | 56 | float scale; 57 | if(mov_max_tmp < 1e-6) scale = __fdividef(data_scale,1e-6); 58 | else scale = __fdividef(data_scale,mov_max_tmp); 59 | 60 | int o_int = round(a[index]*scale); 61 | //o[index] = __fdividef(round(a[index]*scale),scale); 62 | if(o_int>data_scale) o_int=(int)data_scale; 63 | else if(o_int<-data_scale) o_int=(int)(-data_scale); 64 | else {}; 65 | o[index] = __fdividef(o_int*1.,scale); 66 | 67 | if(index==0) 68 | { 69 | o1[0] = scale; 70 | mov_max[0] = mov_max_tmp; 71 | } 72 | } 73 | } 74 | 75 | 76 | __global__ void fake_quantize_layer_aciq(float* __restrict__ a, 77 | float* o, 78 | float* o1, 79 | float* mov_max, 80 | int feature_pixl_num, 81 | int size, 82 | int bit_width, 83 | float* max_entry) 84 | { 85 | int index = blockIdx.x * blockDim.x + threadIdx.x; 86 | if (index < size) 87 | { 88 | const float momenta = 0.95; 89 | float mov_max_tmp = mov_max[0]; 90 | if(mov_max_tmp<1e-6) mov_max_tmp=fabs(*max_entry); //movMax dafault 0 ,now first step set it a non zero data 91 | else mov_max_tmp= fabs(*max_entry);//mov_max_tmp * momenta + fabs(*max_entry) * (1.-momenta); // #GOOGLE QAT : movMax = movMax*momenta + max(abs(tensor))*(1-momenta) momenta = 0.95 92 | float data_scale = __powf(2.,bit_width-1.)-1; 93 | 94 | const float alpha_gaussian[8] = {0, 1.71063519, 2.15159277, 2.55913646, 2.93620062, 3.28691474, 3.6151146, 3.92403714}; 95 | const double gaussian_const = (0.5 * 0.35) * (1 + sqrt(3.14159265358979323846 * __logf(4.))); 96 | double std = (mov_max_tmp * 2 * gaussian_const) / sqrt(2 * __logf(feature_pixl_num)); 97 | float threshold = (float)(alpha_gaussian[bit_width - 1] * std); 98 | 99 | float scale; 100 | if(threshold < 1e-6) scale = __fdividef(data_scale,1e-6); 101 | else scale = __fdividef(data_scale,threshold); 102 | //float o_index = __fdividef(round(a[index]*scale),scale); 103 | int o_int = round(a[index]*scale); 104 | //o[index] = 
__fdividef(round(a[index]*scale),scale); 105 | if(o_int>data_scale) o_int=(int)data_scale; 106 | else if(o_int<-data_scale) o_int=(int)(-data_scale); 107 | else {}; 108 | o[index] = __fdividef(o_int*1.,scale); 109 | 110 | if(index==0) 111 | { 112 | o1[0] = scale; 113 | mov_max[0] = mov_max_tmp; 114 | } 115 | } 116 | } 117 | 118 | __global__ void fake_quantize_channel_aciq(float* __restrict__ a, 119 | float* o, 120 | float* o1, 121 | int size, 122 | int bit_width, 123 | float* max_entry_arr, //max_entry_arr already>0 124 | int channel_num) 125 | { 126 | int index = blockIdx.x * blockDim.x + threadIdx.x; 127 | if (index < size) 128 | { 129 | int channel = index/channel_num; 130 | float* max_entry = max_entry_arr+channel; 131 | float data_scale = __powf(2.,bit_width-1.)-1; 132 | if((*max_entry) < 1e-6) 133 | { 134 | //if(index%channel_num==0) o1[channel] = scale; 135 | *max_entry = 1e-6; 136 | //return; 137 | } 138 | const float alpha_gaussian[8] = {0, 1.71063519, 2.15159277, 2.55913646, 2.93620062, 3.28691474, 3.6151146, 3.92403714}; 139 | const double gaussian_const = (0.5 * 0.35) * (1 + sqrt(3.14159265358979323846 * __logf(4.))); 140 | double std = ((*max_entry) * 2 * gaussian_const) / sqrt(2 * __logf(channel_num)); 141 | float threshold = (float)(alpha_gaussian[bit_width - 1] * std); 142 | 143 | float scale = __fdividef(data_scale,threshold); 144 | int o_int = round(a[index]*scale); 145 | if(o_int>data_scale) o_int=(int)data_scale; 146 | else if(o_int<-data_scale) o_int=(int)(-data_scale); 147 | else {}; 148 | o[index] = __fdividef(o_int*1.,scale); 149 | if(index%channel_num==0) o1[channel] = scale; 150 | } 151 | } 152 | __global__ void fake_quantize_channel_cuda(float* __restrict__ a, 153 | float* o, 154 | float* o1, 155 | int size, 156 | int bit_width, 157 | float* max_entry_arr, //max_entry_arr already>0 158 | int channel_num) 159 | { 160 | int index = blockIdx.x * blockDim.x + threadIdx.x; 161 | if (index < size) 162 | { 163 | int channel = index/channel_num; 164 | float* max_entry = max_entry_arr+channel; 165 | float data_scale = __powf(2.,bit_width-1.)-1; 166 | if((*max_entry) < 1e-6) 167 | { 168 | //if(index%channel_num==0) o1[channel] = scale; 169 | *max_entry = 1e-6; 170 | //return; 171 | } 172 | float scale = __fdividef(data_scale,*max_entry); 173 | o[index] = __fdividef(round(a[index]*scale),scale); 174 | if(index%channel_num==0) o1[channel] = scale; 175 | } 176 | } 177 | std::vector fake_quantize_activate_cuda(Tensor a, int bit_width ,int aciq) 178 | { 179 | auto o = at::zeros_like(a); //q out 180 | auto o1 = at::zeros({1}, a.options()); //scale 181 | auto mov_max = at::zeros({1}, a.options()); //max of tensor #GOOGLE QAT movMax = movMax*momenta + max(abs(tensor))*(1-momenta) momenta = 0.95 182 | int64_t size = a.numel(); 183 | 184 | int batch_size = a.size(0);//batchsize 185 | int feature_pixl_num = size/batch_size; 186 | 187 | Tensor max_entry = at::max(at::abs(a)); 188 | int blockNums = (size + blockSize - 1) / blockSize; 189 | 190 | if(aciq==0) //movmax 191 | { 192 | //printf("layer_max...."); 193 | fake_quantize_layer_google<<>>(a.data_ptr(), 194 | o.data_ptr(), 195 | o1.data_ptr(), 196 | mov_max.data_ptr(), 197 | size, 198 | bit_width, 199 | max_entry.data_ptr()); 200 | } 201 | else // aciq 202 | { 203 | //printf("layer_aciq...."); 204 | fake_quantize_layer_aciq<<>>(a.data_ptr(), 205 | o.data_ptr(), 206 | o1.data_ptr(), 207 | mov_max.data_ptr(), 208 | feature_pixl_num, 209 | size, 210 | bit_width, 211 | max_entry.data_ptr()); 212 | } 213 | return {o,o1,mov_max}; 214 | } 
215 | 216 | 217 | std::vector fake_quantize_weight_cuda(Tensor a, int bit_width,int c ,int aciq) 218 | { 219 | auto o = at::zeros_like(a); //q out 220 | auto o1 = at::zeros({c}, a.options()); //scale 221 | int64_t size = a.numel(); 222 | 223 | int blockNums = (size + blockSize - 1) / blockSize; 224 | int channel_num = size/c; 225 | auto max_entry_arr = at::zeros({c}, a.options()); 226 | 227 | int lg_n = ceil(log2(channel_num*1.)); //2^x - channel_num >0 228 | lg_n = pow(2,lg_n); //2^x 229 | if(lg_n>blockSize) lg_n=blockSize; // 230 | 231 | max_reduce <<>> (a.data_ptr(), 232 | max_entry_arr.data_ptr(), 233 | channel_num, 234 | lg_n); //c block , each block get a max value 235 | 236 | if(aciq==0) 237 | { 238 | //printf("weight_max...."); 239 | fake_quantize_channel_cuda<<>>(a.data_ptr(), 240 | o.data_ptr(), 241 | o1.data_ptr(), 242 | size, 243 | bit_width, 244 | max_entry_arr.data_ptr(), //max_entry_arr already>0 245 | channel_num); 246 | } 247 | else 248 | { 249 | //printf("weight_aciq...."); 250 | fake_quantize_channel_aciq<<>>(a.data_ptr(), 251 | o.data_ptr(), 252 | o1.data_ptr(), 253 | size, 254 | bit_width, 255 | max_entry_arr.data_ptr(), //max_entry_arr already>0 256 | channel_num); 257 | 258 | } 259 | return {o,o1}; 260 | } 261 | 262 | 263 | std::vector fake_quantize_cuda(Tensor a, int bit_width,int type,int c,int aciq) 264 | { 265 | /* 266 | https://arxiv.org/pdf/1806.08342.pdf 2.5 267 | For weights,we use the actual minimum and maximum values to determine the quantizer parameters. 268 | For activations, we use the moving average of the minimum and maximum values across batches to determine the quantizer parameters. 269 | float 6 7 ,double 15 16 270 | */ 271 | if(type==0) return fake_quantize_activate_cuda(a,bit_width,aciq); //type==0 per layer 272 | else return fake_quantize_weight_cuda(a,bit_width,c,aciq); //type==1 perchannel 273 | } 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /src/fake_quantize.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | const int blockSize = 1024; 13 | //#define blockSize 1024 14 | 15 | using namespace at; 16 | 17 | 18 | 19 | std::vector fake_quantize_cuda(Tensor a, int bit_width=8,int type=1,int c=1,int aciq=0); 20 | 21 | std::vector fake_quantize_activate_cuda(Tensor a, int bit_width ,int aciq); 22 | std::vector fake_quantize_weight_cuda(Tensor a, int bit_width,int c,int aciq); 23 | 24 | 25 | __global__ void max_reduce(float* __restrict__ data,float* out_ptr,int width,int lg_n); 26 | 27 | 28 | 29 | 30 | __global__ void fake_quantize_layer_google(float* __restrict__ a, 31 | float* o, 32 | float* o1, 33 | float* mov_max, 34 | int size, 35 | int bit_width, 36 | float* max_entry); 37 | __global__ void fake_quantize_layer_aciq(float* __restrict__ a, 38 | float* o, 39 | float* o1, 40 | float* mov_max, 41 | int feature_pixl_num, 42 | int size, 43 | int bit_width, 44 | float* max_entry); 45 | 46 | __global__ void fake_quantize_channel_cuda(float* __restrict__ a, 47 | float* o, 48 | float* o1, 49 | int size, 50 | int bit_width, 51 | float* max_entry_arr, 52 | int channel_num); 53 | __global__ void fake_quantize_channel_aciq(float* __restrict__ a, 54 | float* o, 55 | float* o1, 56 | int size, 57 | int bit_width, 58 | float* max_entry_arr, 59 | int channel_num); 60 | 61 | 
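For readability, here is a rough pure-PyTorch sketch of the symmetric, per-output-channel, max-based fake quantization that `fake_quantize_channel_cuda` above implements (the ACIQ clipping and the moving-max activation paths are omitted). It is a reference illustration only, not part of the package:

```python
import torch

def fake_quantize_per_channel(weight: torch.Tensor, bit_width: int = 8):
    """Quantize-dequantize a weight tensor with one scale per output channel.

    scale = (2**(bit_width-1) - 1) / max(|w|) per channel, mirroring the
    max-based path of fake_quantize_channel_cuda; returns (dequantized, scale).
    """
    qmax = 2.0 ** (bit_width - 1) - 1.0                      # 127 for 8-bit, 31 for 6-bit
    flat = weight.reshape(weight.shape[0], -1)               # [out_channels, rest]
    max_abs = flat.abs().max(dim=1).values.clamp_min(1e-6)   # avoid division by zero
    scale = qmax / max_abs                                   # one scale per output channel
    q = torch.round(flat * scale.unsqueeze(1))               # quantize
    dq = (q / scale.unsqueeze(1)).reshape_as(weight)         # dequantize back to float
    return dq, scale
```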
-------------------------------------------------------------------------------- /tests/ssd300/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from argparse import ArgumentParser 5 | import math 6 | import numpy as np 7 | import time 8 | import torch 9 | from torch.optim.lr_scheduler import MultiStepLR 10 | import torch.utils.data.distributed 11 | from torchsummary import summary 12 | 13 | from ncnnqat import merge_freeze_bn, register_quantization_hook,save_table 14 | 15 | from src.model import model, Loss 16 | from src.utils import dboxes300_coco, Encoder 17 | 18 | from src.evaluate import evaluate 19 | from src.train import train_loop, tencent_trick 20 | from src.data import * 21 | 22 | 23 | #os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5' 24 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 25 | #os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 26 | 27 | 28 | 29 | class Logger: 30 | def __init__(self, batch_size, local_rank, n_gpu, print_freq=20): 31 | self.batch_size = batch_size 32 | self.local_rank = local_rank 33 | self.n_gpu = n_gpu 34 | self.print_freq = print_freq 35 | 36 | self.processed_samples = 0 37 | self.epochs_times = [] 38 | self.epochs_speeds = [] 39 | 40 | 41 | def update_iter(self, epoch, iteration, loss): 42 | if self.local_rank != 0: 43 | return 44 | 45 | if iteration % self.print_freq == 0: 46 | print('Epoch: {:2d}, Iteration: {}, Loss: {}'.format(epoch, iteration, loss)) 47 | 48 | self.processed_samples = self.processed_samples + self.batch_size 49 | 50 | def start_epoch(self): 51 | self.epoch_start = time.time() 52 | 53 | def end_epoch(self): 54 | epoch_time = time.time() - self.epoch_start 55 | epoch_speed = self.processed_samples / epoch_time 56 | 57 | self.epochs_times.append(epoch_time) 58 | self.epochs_speeds.append(epoch_speed) 59 | self.processed_samples = 0 60 | 61 | if self.local_rank == 0: 62 | print('Epoch {:2d} finished. 
Time: {:4f} s, Speed: {:4f} img/sec, Average speed: {:4f}' 63 | .format(len(self.epochs_times)-1, epoch_time, epoch_speed * self.n_gpu, self.average_speed() * self.n_gpu)) 64 | 65 | def average_speed(self): 66 | return sum(self.epochs_speeds) / len(self.epochs_speeds) 67 | 68 | 69 | def make_parser(): 70 | epoch_all = 65 71 | epoch_qat = epoch_all-5 if epoch_all-5>0 else epoch_all 72 | 73 | eval_list = [0,epoch_all-1] if epoch_all-1>0 else [0] 74 | 75 | parser = ArgumentParser( 76 | description="Train Single Shot MultiBox Detector on COCO") 77 | parser.add_argument( 78 | '--data', '-d', type=str, default='./data/coco', required=False, 79 | help='path to test and training data files') 80 | parser.add_argument( 81 | '--epochs', '-e', type=int, default=epoch_all, #65 82 | help='number of epochs for training') 83 | parser.add_argument( 84 | '--qat-epoch', '-q', type=int, default=epoch_qat, 85 | help='epoch of qat begaining') 86 | parser.add_argument( 87 | '--batch-size', '--bs', type=int, default=32, 88 | help='number of examples for each iteration') 89 | parser.add_argument( 90 | '--eval-batch-size', '--ebs', type=int, default=32, 91 | help='number of examples for each evaluation iteration') 92 | parser.add_argument( 93 | '--seed', '-s', type=int, default=0, 94 | help='manually set random seed for torch') 95 | parser.add_argument( 96 | '--evaluation', nargs='*', type=int, 97 | default=eval_list,#[0, 48, 53, 59,63, 64,65], 98 | help='epochs at which to evaluate') 99 | parser.add_argument( 100 | '--multistep', nargs='*', type=int, default=[43, 54], 101 | help='epochs at which to decay learning rate') 102 | parser.add_argument( 103 | '--target', type=float, default=None, 104 | help='target mAP to assert against at the end') 105 | 106 | #save model 107 | parser.add_argument('--check-save', '--s', type=bool, default=True) 108 | parser.add_argument( 109 | '--check-point', '-c', type=str, default='./models', required=False, 110 | help='path to model save files') 111 | parser.add_argument('--onnx_save', action='store_true') 112 | 113 | # Hyperparameters 114 | parser.add_argument( 115 | '--learning-rate', '--lr', type=float, default=2.6e-3, help='learning rate') 116 | parser.add_argument( 117 | '--momentum', '-m', type=float, default=0.9, 118 | help='momentum argument for SGD optimizer') 119 | parser.add_argument( 120 | '--weight-decay', '--wd', type=float, default=0.0005, 121 | help='momentum argument for SGD optimizer') 122 | parser.add_argument('--warmup', type=int, default=None) 123 | parser.add_argument( 124 | '--backbone', type=str, default='resnet18', 125 | choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']) 126 | parser.add_argument('--num-workers', type=int, default=4) 127 | parser.add_argument('--fp16-mode', type=str, default='static', choices=['off', 'static', 'amp'], 128 | help='Half precission mode to use') 129 | 130 | # Distributed 131 | parser.add_argument('--local_rank', default=0, type=int, 132 | help='Used for multi-process training. 
Can either be manually set ' + 133 | 'or automatically set by using \'python -m multiproc\'.') 134 | 135 | # Pipeline control 136 | parser.add_argument( 137 | '--data_pipeline', type=str, default='dali', choices=['dali', 'no_dali'], 138 | help='data preprocessing pipline to use') 139 | 140 | return parser 141 | 142 | 143 | def train(args): 144 | 145 | args.distributed = False 146 | if 'WORLD_SIZE' in os.environ: 147 | print('WORLD_SIZE in os.environ',os.environ['WORLD_SIZE'],args.local_rank) 148 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 149 | print(args.distributed) 150 | if args.distributed: 151 | torch.cuda.set_device(args.local_rank) 152 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 153 | args.N_gpu = torch.distributed.get_world_size() 154 | else: 155 | args.N_gpu = 1 156 | 157 | dboxes = dboxes300_coco() 158 | encoder = Encoder(dboxes) 159 | cocoGt = get_coco_ground_truth(args) 160 | 161 | ssd300 = model(args) 162 | 163 | loss_func = Loss(dboxes) 164 | loss_func.cuda() 165 | 166 | 167 | args.learning_rate = args.learning_rate * args.N_gpu * (args.batch_size / 32) 168 | iteration = 0 169 | 170 | optimizer = torch.optim.SGD( 171 | tencent_trick(ssd300), 172 | lr=args.learning_rate, 173 | momentum=args.momentum, 174 | weight_decay=args.weight_decay) 175 | 176 | scheduler = MultiStepLR( 177 | optimizer=optimizer, 178 | milestones=args.multistep, 179 | gamma=0.1) 180 | 181 | 182 | 183 | val_dataloader, inv_map = get_val_dataloader(args) 184 | train_loader = get_train_loader(args, dboxes) 185 | 186 | #print(inv_map) 187 | #print(val_dataset.label_info) 188 | 189 | acc = 0 190 | acc_best = 0 191 | epoch_check = 0 192 | logger = Logger(args.batch_size, args.local_rank, args.N_gpu) 193 | 194 | for epoch in range(epoch_check, args.epochs): 195 | logger.start_epoch() 196 | #scheduler.step() 197 | #print(ssd300) 198 | '''qat''' 199 | if epoch==args.qat_epoch: 200 | register_quantization_hook(ssd300) 201 | ssd300 = merge_freeze_bn(ssd300) 202 | print("qat hook...") 203 | if epoch>args.qat_epoch: 204 | ssd300 = merge_freeze_bn(ssd300) 205 | print("merge bn ...") 206 | '''qat''' 207 | 208 | iteration = train_loop( 209 | ssd300, loss_func, epoch, optimizer, 210 | train_loader, iteration, logger, args) 211 | scheduler.step() 212 | logger.end_epoch() 213 | 214 | if epoch in args.evaluation: 215 | acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args) 216 | if args.local_rank == 0: 217 | print('Epoch {:2d}, Accuracy: {:4f} mAP'.format(epoch, acc)) 218 | 219 | 220 | if acc>=acc_best and args.local_rank == 0: 221 | acc_best = acc 222 | 223 | if args.distributed: 224 | model_dict = ssd300.module.state_dict() 225 | else: 226 | model_dict = ssd300.state_dict() 227 | torch.save({ 228 | 'epoch': epoch+1, 229 | 'model_state_dict': model_dict, 230 | 'optimizer_state_dict': optimizer.state_dict(), 231 | 'val_acc':acc_best, 232 | "inv_map":inv_map, 233 | "scheduler":scheduler.state_dict(), 234 | }, args.checkpoint) 235 | 236 | if args.data_pipeline == 'dali': 237 | train_loader.reset() 238 | 239 | return acc, logger.average_speed() 240 | 241 | 242 | if __name__ == "__main__": 243 | parser = make_parser() 244 | args = parser.parse_args() 245 | 246 | 247 | if args.onnx_save: #after train ,load model , save onnx model and ncnn table 248 | #python main.py --onnx_save 249 | onnx_path = os.path.join(args.check_point,"model.onnx") 250 | table_path = os.path.join(args.check_point,"model.table") 251 | #print(onnx_path) 252 | checkpoint = 
os.path.join(args.check_point,"model.pt") 253 | #print(checkpoint) 254 | ssd300 = model(args,onnx_save=args.onnx_save) 255 | summary(ssd300, input_size=(3, 300, 300), device='cpu') 256 | ssd300.cuda() 257 | '''qat''' 258 | register_quantization_hook(ssd300) 259 | ssd300 = merge_freeze_bn(ssd300) 260 | '''qat''' 261 | if os.path.exists(checkpoint): 262 | print("loadmodel from checkpoint...") 263 | checkpoint_load = torch.load(checkpoint,map_location='cpu') 264 | #ssd300.module.load_state_dict(checkpoint_load['model_state_dict']) #donot know ssd300 is distributed 265 | ssd300.load_state_dict({k.replace('module.',''):v for k,v in checkpoint_load['model_state_dict'].items()}) 266 | print("loadmodel from checkpoint end...") 267 | ssd300.eval() 268 | input_names = [ "input" ] 269 | #output_names = [ "SSD300-184" ] 270 | output_names = [ "Conv2d-93" ] 271 | dummy_input = torch.ones([1, 3, 300, 300]).cuda() 272 | #dummy_input = torch.randn(1, 3, 300, 300, device='cuda') 273 | torch.onnx.export(ssd300, dummy_input, onnx_path, verbose=False, input_names=input_names, output_names=output_names) 274 | save_table(ssd300,onnx_path=onnx_path,table=table_path) 275 | else: 276 | args.checkpoint = os.path.join(args.check_point,"model.pt") 277 | if args.local_rank == 0: 278 | os.makedirs(args.check_point, exist_ok=True) 279 | 280 | torch.backends.cudnn.benchmark = True 281 | 282 | if args.fp16_mode != 'off': 283 | args.fp16 = True 284 | else: 285 | args.fp16 = False 286 | #print(args) 287 | start_time = time.time() 288 | acc, avg_speed = train(args) 289 | # avg_speed is reported per node, adjust for the global speed 290 | try: 291 | num_shards = torch.distributed.get_world_size() 292 | except RuntimeError: 293 | num_shards = 1 294 | avg_speed = num_shards * avg_speed 295 | training_time = time.time() - start_time 296 | 297 | if args.local_rank == 0: 298 | print("Training end: Average speed: {:3f} img/sec, Total time: {:3f} sec, Final accuracy: {:3f} mAP" 299 | .format(avg_speed, training_time, acc)) 300 | 301 | if args.target is not None: 302 | if args.target > acc: 303 | print('Target mAP of {} not met. 
Possible regression'.format(args.target)) 304 | sys.exit(1) 305 | ''' 306 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.253 307 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.429 308 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.262 309 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.075 310 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.273 311 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.397 312 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.240 313 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.349 314 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.367 315 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.122 316 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.406 317 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.551 318 | 214img/sec 319 | 320 | dali 218img/sec 321 | 322 | 323 | warmup 200 + dali + fp16 324 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.256 325 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.434 326 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.263 327 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.078 328 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.274 329 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.408 330 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.240 331 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.352 332 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.368 333 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.126 334 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.403 335 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.563 336 | Current AP: 0.25628 337 | Epoch 64, Accuracy: 0.256285 mAP 338 | DONE (t=9.45s). 
339 | Training end: Average speed: 232.580538 img/sec, Total time: 35018.003625 sec, Final accuracy: 0.256285 mAP 340 | 341 | 342 | 343 | not qat 344 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.193 345 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.344 346 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.191 347 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.042 348 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.195 349 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.328 350 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.199 351 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.293 352 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.309 353 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.084 354 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.326 355 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.501 356 | Current AP: 0.19269 357 | 358 | qat resnet18 359 | Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.192 360 | Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.342 361 | Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.194 362 | Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.041 363 | Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.194 364 | Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.327 365 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.197 366 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.291 367 | Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.307 368 | Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.082 369 | Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.325 370 | Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.497 371 | Current AP: 0.19202 372 | ''' -------------------------------------------------------------------------------- /tests/ssd300/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenShisen/ncnnqat/253a413264507cf90089d1aa0e30c0ef30087cfe/tests/ssd300/src/__init__.py -------------------------------------------------------------------------------- /tests/ssd300/src/coco.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | __version__ = '2.0' 3 | # Interface for accessing the Microsoft COCO dataset. 4 | 5 | # Microsoft COCO is a large image dataset designed for object detection, 6 | # segmentation, and caption generation. pycocotools is a Python API that 7 | # assists in loading, parsing and visualizing the annotations in COCO. 8 | # Please visit http://mscoco.org/ for more information on COCO, including 9 | # for the data, paper, and tutorials. The exact format of the annotations 10 | # is also described on the COCO website. For example usage of the pycocotools 11 | # please see pycocotools_demo.ipynb. In addition to this API, please download both 12 | # the COCO images and annotations in order to run the demo. 
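# A minimal usage sketch (illustrative only; the annotation path below is a
# placeholder, not a file shipped with this repository):
#   coco = COCO('annotations/instances_val2017.json')
#   cat_ids = coco.getCatIds(catNms=['person'])
#   img_ids = coco.getImgIds(catIds=cat_ids)
#   anns = coco.loadAnns(coco.getAnnIds(imgIds=img_ids[:1]))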
13 | 14 | # An alternative to using the API is to load the annotations directly 15 | # into Python dictionary 16 | # Using the API provides additional utility functions. Note that this API 17 | # supports both *instance* and *caption* annotations. In the case of 18 | # captions not all functions are defined (e.g. categories are undefined). 19 | 20 | # The following API functions are defined: 21 | # COCO - COCO api class that loads COCO annotation file and prepare data structures. 22 | # decodeMask - Decode binary mask M encoded via run-length encoding. 23 | # encodeMask - Encode binary mask M using run-length encoding. 24 | # getAnnIds - Get ann ids that satisfy given filter conditions. 25 | # getCatIds - Get cat ids that satisfy given filter conditions. 26 | # getImgIds - Get img ids that satisfy given filter conditions. 27 | # loadAnns - Load anns with the specified ids. 28 | # loadCats - Load cats with the specified ids. 29 | # loadImgs - Load imgs with the specified ids. 30 | # annToMask - Convert segmentation in an annotation to binary mask. 31 | # showAnns - Display the specified annotations. 32 | # loadRes - Load algorithm results and create API for accessing them. 33 | # download - Download COCO images from mscoco.org server. 34 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 35 | # Help on each functions can be accessed by: "help COCO>function". 36 | 37 | # See also COCO>decodeMask, 38 | # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds, 39 | # COCO>getImgIds, COCO>loadAnns, COCO>loadCats, 40 | # COCO>loadImgs, COCO>annToMask, COCO>showAnns 41 | 42 | # Microsoft COCO Toolbox. version 2.0 43 | # Data, paper, and tutorials available at: http://mscoco.org/ 44 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2014. 45 | # Licensed under the Simplified BSD License [see bsd.txt] 46 | 47 | import json 48 | import time 49 | import matplotlib.pyplot as plt 50 | from matplotlib.collections import PatchCollection 51 | from matplotlib.patches import Polygon 52 | import numpy as np 53 | import copy 54 | import itertools 55 | from pycocotools import mask as maskUtils 56 | import os 57 | from collections import defaultdict 58 | import sys 59 | PYTHON_VERSION = sys.version_info[0] 60 | if PYTHON_VERSION == 2: 61 | from urllib import urlretrieve 62 | elif PYTHON_VERSION == 3: 63 | from urllib.request import urlretrieve 64 | 65 | 66 | def _isArrayLike(obj): 67 | return hasattr(obj, '__iter__') and hasattr(obj, '__len__') 68 | 69 | 70 | class COCO: 71 | def __init__(self, annotation_file=None): 72 | """ 73 | Constructor of Microsoft COCO helper class for reading and visualizing annotations. 74 | :param annotation_file (str): location of annotation file 75 | :param image_folder (str): location to the folder that hosts images. 
76 | :return: 77 | """ 78 | # load dataset 79 | self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict() 80 | self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) 81 | if not annotation_file == None: 82 | print('loading annotations into memory...') 83 | tic = time.time() 84 | dataset = json.load(open(annotation_file, 'r')) 85 | assert type(dataset)==dict, 'annotation file format {} not supported'.format(type(dataset)) 86 | print('Done (t={:0.2f}s)'.format(time.time()- tic)) 87 | self.dataset = dataset 88 | self.createIndex() 89 | 90 | def createIndex(self): 91 | # create index 92 | print('creating index...') 93 | anns, cats, imgs = {}, {}, {} 94 | imgToAnns,catToImgs = defaultdict(list),defaultdict(list) 95 | if 'annotations' in self.dataset: 96 | for ann in self.dataset['annotations']: 97 | imgToAnns[ann['image_id']].append(ann) 98 | anns[ann['id']] = ann 99 | 100 | if 'images' in self.dataset: 101 | for img in self.dataset['images']: 102 | imgs[img['id']] = img 103 | 104 | if 'categories' in self.dataset: 105 | for cat in self.dataset['categories']: 106 | cats[cat['id']] = cat 107 | 108 | if 'annotations' in self.dataset and 'categories' in self.dataset: 109 | for ann in self.dataset['annotations']: 110 | catToImgs[ann['category_id']].append(ann['image_id']) 111 | 112 | print('index created!') 113 | 114 | # create class members 115 | self.anns = anns 116 | self.imgToAnns = imgToAnns 117 | self.catToImgs = catToImgs 118 | self.imgs = imgs 119 | self.cats = cats 120 | 121 | def info(self): 122 | """ 123 | Print information about the annotation file. 124 | :return: 125 | """ 126 | for key, value in self.dataset['info'].items(): 127 | print('{}: {}'.format(key, value)) 128 | 129 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): 130 | """ 131 | Get ann ids that satisfy given filter conditions. default skips that filter 132 | :param imgIds (int array) : get anns for given imgs 133 | catIds (int array) : get anns for given cats 134 | areaRng (float array) : get anns for given area range (e.g. [0 inf]) 135 | iscrowd (boolean) : get anns for given crowd label (False or True) 136 | :return: ids (int array) : integer array of ann ids 137 | """ 138 | imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] 139 | catIds = catIds if _isArrayLike(catIds) else [catIds] 140 | 141 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 142 | anns = self.dataset['annotations'] 143 | else: 144 | if not len(imgIds) == 0: 145 | lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns] 146 | anns = list(itertools.chain.from_iterable(lists)) 147 | else: 148 | anns = self.dataset['annotations'] 149 | anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds] 150 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] 151 | if not iscrowd == None: 152 | ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] 153 | else: 154 | ids = [ann['id'] for ann in anns] 155 | return ids 156 | 157 | def getCatIds(self, catNms=[], supNms=[], catIds=[]): 158 | """ 159 | filtering parameters. default skips that filter. 
160 | :param catNms (str array) : get cats for given cat names 161 | :param supNms (str array) : get cats for given supercategory names 162 | :param catIds (int array) : get cats for given cat ids 163 | :return: ids (int array) : integer array of cat ids 164 | """ 165 | catNms = catNms if _isArrayLike(catNms) else [catNms] 166 | supNms = supNms if _isArrayLike(supNms) else [supNms] 167 | catIds = catIds if _isArrayLike(catIds) else [catIds] 168 | 169 | if len(catNms) == len(supNms) == len(catIds) == 0: 170 | cats = self.dataset['categories'] 171 | else: 172 | cats = self.dataset['categories'] 173 | cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms] 174 | cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms] 175 | cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds] 176 | ids = [cat['id'] for cat in cats] 177 | return ids 178 | 179 | def getImgIds(self, imgIds=[], catIds=[]): 180 | ''' 181 | Get img ids that satisfy given filter conditions. 182 | :param imgIds (int array) : get imgs for given ids 183 | :param catIds (int array) : get imgs with all given cats 184 | :return: ids (int array) : integer array of img ids 185 | ''' 186 | imgIds = imgIds if _isArrayLike(imgIds) else [imgIds] 187 | catIds = catIds if _isArrayLike(catIds) else [catIds] 188 | 189 | if len(imgIds) == len(catIds) == 0: 190 | ids = self.imgs.keys() 191 | else: 192 | ids = set(imgIds) 193 | for i, catId in enumerate(catIds): 194 | if i == 0 and len(ids) == 0: 195 | ids = set(self.catToImgs[catId]) 196 | else: 197 | ids &= set(self.catToImgs[catId]) 198 | return list(ids) 199 | 200 | def loadAnns(self, ids=[]): 201 | """ 202 | Load anns with the specified ids. 203 | :param ids (int array) : integer ids specifying anns 204 | :return: anns (object array) : loaded ann objects 205 | """ 206 | if _isArrayLike(ids): 207 | return [self.anns[id] for id in ids] 208 | elif type(ids) == int: 209 | return [self.anns[ids]] 210 | 211 | def loadCats(self, ids=[]): 212 | """ 213 | Load cats with the specified ids. 214 | :param ids (int array) : integer ids specifying cats 215 | :return: cats (object array) : loaded cat objects 216 | """ 217 | if _isArrayLike(ids): 218 | return [self.cats[id] for id in ids] 219 | elif type(ids) == int: 220 | return [self.cats[ids]] 221 | 222 | def loadImgs(self, ids=[]): 223 | """ 224 | Load anns with the specified ids. 225 | :param ids (int array) : integer ids specifying img 226 | :return: imgs (object array) : loaded img objects 227 | """ 228 | if _isArrayLike(ids): 229 | return [self.imgs[id] for id in ids] 230 | elif type(ids) == int: 231 | return [self.imgs[ids]] 232 | 233 | def showAnns(self, anns): 234 | """ 235 | Display the specified annotations. 
236 | :param anns (array of object): annotations to display 237 | :return: None 238 | """ 239 | if len(anns) == 0: 240 | return 0 241 | if 'segmentation' in anns[0] or 'keypoints' in anns[0]: 242 | datasetType = 'instances' 243 | elif 'caption' in anns[0]: 244 | datasetType = 'captions' 245 | else: 246 | raise Exception('datasetType not supported') 247 | if datasetType == 'instances': 248 | ax = plt.gca() 249 | ax.set_autoscale_on(False) 250 | polygons = [] 251 | color = [] 252 | for ann in anns: 253 | c = (np.random.random((1, 3))*0.6+0.4).tolist()[0] 254 | if 'segmentation' in ann: 255 | if type(ann['segmentation']) == list: 256 | # polygon 257 | for seg in ann['segmentation']: 258 | poly = np.array(seg).reshape((int(len(seg)/2), 2)) 259 | polygons.append(Polygon(poly)) 260 | color.append(c) 261 | else: 262 | # mask 263 | t = self.imgs[ann['image_id']] 264 | if type(ann['segmentation']['counts']) == list: 265 | rle = maskUtils.frPyObjects([ann['segmentation']], t['height'], t['width']) 266 | else: 267 | rle = [ann['segmentation']] 268 | m = maskUtils.decode(rle) 269 | img = np.ones( (m.shape[0], m.shape[1], 3) ) 270 | if ann['iscrowd'] == 1: 271 | color_mask = np.array([2.0,166.0,101.0])/255 272 | if ann['iscrowd'] == 0: 273 | color_mask = np.random.random((1, 3)).tolist()[0] 274 | for i in range(3): 275 | img[:,:,i] = color_mask[i] 276 | ax.imshow(np.dstack( (img, m*0.5) )) 277 | if 'keypoints' in ann and type(ann['keypoints']) == list: 278 | # turn skeleton into zero-based index 279 | sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1 280 | kp = np.array(ann['keypoints']) 281 | x = kp[0::3] 282 | y = kp[1::3] 283 | v = kp[2::3] 284 | for sk in sks: 285 | if np.all(v[sk]>0): 286 | plt.plot(x[sk],y[sk], linewidth=3, color=c) 287 | plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2) 288 | plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2) 289 | p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4) 290 | ax.add_collection(p) 291 | p = PatchCollection(polygons, facecolor='none', edgecolors=color, linewidths=2) 292 | ax.add_collection(p) 293 | elif datasetType == 'captions': 294 | for ann in anns: 295 | print(ann['caption']) 296 | 297 | def loadRes(self, resFile): 298 | """ 299 | Load result file and return a result api object. 
300 | :param resFile (str) : file name of result file 301 | :return: res (obj) : result api object 302 | """ 303 | res = COCO() 304 | res.dataset['images'] = [img for img in self.dataset['images']] 305 | 306 | print('Loading and preparing results...') 307 | tic = time.time() 308 | if type(resFile) == str: #or type(resFile) == unicode: 309 | anns = json.load(open(resFile)) 310 | elif type(resFile) == np.ndarray: 311 | anns = self.loadNumpyAnnotations(resFile) 312 | else: 313 | anns = resFile 314 | assert type(anns) == list, 'results in not an array of objects' 315 | annsImgIds = [ann['image_id'] for ann in anns] 316 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 317 | 'Results do not correspond to current coco set' 318 | if 'caption' in anns[0]: 319 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) 320 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] 321 | for id, ann in enumerate(anns): 322 | ann['id'] = id+1 323 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: 324 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 325 | for id, ann in enumerate(anns): 326 | bb = ann['bbox'] 327 | x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]] 328 | if not 'segmentation' in ann: 329 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 330 | ann['area'] = bb[2]*bb[3] 331 | ann['id'] = id+1 332 | ann['iscrowd'] = 0 333 | elif 'segmentation' in anns[0]: 334 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 335 | for id, ann in enumerate(anns): 336 | # now only support compressed RLE format as segmentation results 337 | ann['area'] = maskUtils.area(ann['segmentation']) 338 | if not 'bbox' in ann: 339 | ann['bbox'] = maskUtils.toBbox(ann['segmentation']) 340 | ann['id'] = id+1 341 | ann['iscrowd'] = 0 342 | elif 'keypoints' in anns[0]: 343 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 344 | for id, ann in enumerate(anns): 345 | s = ann['keypoints'] 346 | x = s[0::3] 347 | y = s[1::3] 348 | x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y) 349 | ann['area'] = (x1-x0)*(y1-y0) 350 | ann['id'] = id + 1 351 | ann['bbox'] = [x0,y0,x1-x0,y1-y0] 352 | print('DONE (t={:0.2f}s)'.format(time.time()- tic)) 353 | 354 | res.dataset['annotations'] = anns 355 | res.createIndex() 356 | return res 357 | 358 | def download(self, tarDir = None, imgIds = [] ): 359 | ''' 360 | Download COCO images from mscoco.org server. 
361 | :param tarDir (str): COCO results directory name 362 | imgIds (list): images to be downloaded 363 | :return: 364 | ''' 365 | if tarDir is None: 366 | print('Please specify target directory') 367 | return -1 368 | if len(imgIds) == 0: 369 | imgs = self.imgs.values() 370 | else: 371 | imgs = self.loadImgs(imgIds) 372 | N = len(imgs) 373 | if not os.path.exists(tarDir): 374 | os.makedirs(tarDir) 375 | for i, img in enumerate(imgs): 376 | tic = time.time() 377 | fname = os.path.join(tarDir, img['file_name']) 378 | if not os.path.exists(fname): 379 | urlretrieve(img['coco_url'], fname) 380 | print('downloaded {}/{} images (t={:0.1f}s)'.format(i, N, time.time()- tic)) 381 | 382 | def loadNumpyAnnotations(self, data): 383 | """ 384 | Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class} 385 | :param data (numpy.ndarray) 386 | :return: annotations (python nested list) 387 | """ 388 | print('Converting ndarray to lists...') 389 | assert(type(data) == np.ndarray) 390 | print(data.shape) 391 | assert(data.shape[1] == 7) 392 | N = data.shape[0] 393 | ann = [] 394 | for i in range(N): 395 | if i % 1000000 == 0: 396 | print('{}/{}'.format(i,N)) 397 | ann += [{ 398 | 'image_id' : int(data[i, 0]), 399 | 'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ], 400 | 'score' : data[i, 5], 401 | 'category_id': int(data[i, 6]), 402 | }] 403 | return ann 404 | 405 | def annToRLE(self, ann): 406 | """ 407 | Convert annotation which can be polygons, uncompressed RLE to RLE. 408 | :return: binary mask (numpy 2D array) 409 | """ 410 | t = self.imgs[ann['image_id']] 411 | h, w = t['height'], t['width'] 412 | segm = ann['segmentation'] 413 | if type(segm) == list: 414 | # polygon -- a single object might consist of multiple parts 415 | # we merge all parts into one mask rle code 416 | rles = maskUtils.frPyObjects(segm, h, w) 417 | rle = maskUtils.merge(rles) 418 | elif type(segm['counts']) == list: 419 | # uncompressed RLE 420 | rle = maskUtils.frPyObjects(segm, h, w) 421 | else: 422 | # rle 423 | rle = ann['segmentation'] 424 | return rle 425 | 426 | def annToMask(self, ann): 427 | """ 428 | Convert annotation which can be polygons, uncompressed RLE, or RLE to binary mask. 429 | :return: binary mask (numpy 2D array) 430 | """ 431 | rle = self.annToRLE(ann) 432 | m = maskUtils.decode(rle) 433 | return m 434 | -------------------------------------------------------------------------------- /tests/ssd300/src/coco_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
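# Pipeline summary: fn.readers.coco loads images and ltrb boxes (sharded per
# rank), fn.random_bbox_crop applies the SSD-style random crop, the cropped
# image is decoded on GPU, resized to 300x300, color-jittered, randomly
# flipped and normalized, and fn.box_encoder matches the boxes against the
# default anchors before the batch is returned as (images, bboxes, labels).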
14 | 15 | 16 | import torch 17 | from nvidia.dali.pipeline import pipeline_def 18 | import nvidia.dali.types as types 19 | import nvidia.dali.fn as fn 20 | 21 | 22 | @pipeline_def 23 | def create_coco_pipeline(default_boxes, args): 24 | try: 25 | shard_id = torch.distributed.get_rank() 26 | num_shards = torch.distributed.get_world_size() 27 | except RuntimeError: 28 | shard_id = 0 29 | num_shards = 1 30 | 31 | images, bboxes, labels = fn.readers.coco(file_root=args.train_coco_root, 32 | annotations_file=args.train_annotate, 33 | skip_empty=True, 34 | shard_id=shard_id, 35 | num_shards=num_shards, 36 | ratio=True, 37 | ltrb=True, 38 | random_shuffle=False, 39 | shuffle_after_epoch=True, 40 | name="Reader") 41 | 42 | crop_begin, crop_size, bboxes, labels = fn.random_bbox_crop(bboxes, labels, 43 | device="cpu", 44 | aspect_ratio=[0.5, 2.0], 45 | thresholds=[0, 0.1, 0.3, 0.5, 0.7, 0.9], 46 | scaling=[0.3, 1.0], 47 | bbox_layout="xyXY", 48 | allow_no_crop=True, 49 | num_attempts=50) 50 | images = fn.decoders.image_slice(images, crop_begin, crop_size, device="mixed", output_type=types.RGB) 51 | flip_coin = fn.random.coin_flip(probability=0.5) 52 | images = fn.resize(images, 53 | resize_x=300, 54 | resize_y=300, 55 | min_filter=types.DALIInterpType.INTERP_TRIANGULAR) 56 | 57 | #saturation = fn.uniform(range=[0.5, 1.5]) 58 | #contrast = fn.uniform(range=[0.5, 1.5]) 59 | #brightness = fn.uniform(range=[0.875, 1.125]) 60 | #hue = fn.uniform(range=[-0.5, 0.5]) 61 | 62 | saturation = fn.random.uniform(range=[0.5, 1.5]) 63 | contrast = fn.random.uniform(range=[0.5, 1.5]) 64 | brightness = fn.random.uniform(range=[0.875, 1.125]) 65 | hue = fn.random.uniform(range=[-0.5, 0.5]) 66 | 67 | images = fn.hsv(images, dtype=types.FLOAT, hue=hue, saturation=saturation) # use float to avoid clipping and 68 | # quantizing the intermediate result 69 | images = fn.brightness_contrast(images, 70 | contrast_center = 128, # input is in float, but in 0..255 range 71 | dtype = types.UINT8, 72 | brightness = brightness, 73 | contrast = contrast) 74 | 75 | dtype = types.FLOAT16 if args.fp16 else types.FLOAT 76 | 77 | bboxes = fn.bb_flip(bboxes, ltrb=True, horizontal=flip_coin) 78 | images = fn.crop_mirror_normalize(images, 79 | crop=(300, 300), 80 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 81 | std=[0.229 * 255, 0.224 * 255, 0.225 * 255], 82 | mirror=flip_coin, 83 | dtype=dtype, 84 | output_layout="CHW", 85 | pad_output=False) 86 | 87 | bboxes, labels = fn.box_encoder(bboxes, labels, 88 | criteria=0.5, 89 | anchors=default_boxes.as_ltrb_list()) 90 | 91 | labels=labels.gpu() 92 | bboxes=bboxes.gpu() 93 | 94 | return images, bboxes, labels 95 | -------------------------------------------------------------------------------- /tests/ssd300/src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from src.utils import dboxes300_coco, COCODetection, SSDTransformer 8 | from src.coco import COCO 9 | from src.coco_pipeline import create_coco_pipeline 10 | from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy 11 | 12 | 13 | def set_seeds(args): 14 | torch.cuda.set_device(args.local_rank) 15 | device = torch.device('cuda') 16 | 17 | if args.distributed: 18 | args.seed = broadcast_seeds(args.seed, device) 19 | local_seed = (args.seed + torch.distributed.get_rank()) % 2**32 20 | local_rank = torch.distributed.get_rank() 21 | 22 | #local_seed = args.seed % 2**32 23 | 
#local_rank = 0 24 | else: 25 | local_seed = args.seed % 2**32 26 | local_rank = 0 27 | 28 | print("Rank", local_rank, "using seed = {}".format(local_seed)) 29 | 30 | torch.manual_seed(local_seed) 31 | np.random.seed(seed=local_seed) 32 | 33 | return local_seed 34 | 35 | 36 | def broadcast_seeds(seed, device): 37 | if torch.distributed.is_initialized(): 38 | seeds_tensor = torch.LongTensor([seed]).to(device) 39 | torch.distributed.broadcast(seeds_tensor, 0) 40 | seed = seeds_tensor.item() 41 | return seed 42 | 43 | 44 | def get_train_pytorch_loader(args, num_workers, default_boxes): 45 | dataset = COCODetection( 46 | args.train_coco_root, 47 | args.train_annotate, 48 | SSDTransformer(default_boxes, args, (300, 300), val=False)) 49 | 50 | if args.distributed: 51 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 52 | #train_sampler = None 53 | else: 54 | train_sampler = None 55 | 56 | train_dataloader = DataLoader( 57 | dataset, 58 | batch_size=args.batch_size, 59 | shuffle=(train_sampler is None), 60 | sampler=train_sampler, 61 | drop_last=True, 62 | num_workers=num_workers) 63 | 64 | return train_dataloader 65 | 66 | 67 | def get_train_dali_loader(args, default_boxes, local_seed): 68 | train_pipe = create_coco_pipeline( 69 | default_boxes, 70 | args, 71 | batch_size=args.batch_size, 72 | num_threads=args.num_workers, 73 | device_id=args.local_rank, 74 | seed=local_seed) 75 | 76 | train_loader = DALIGenericIterator( 77 | train_pipe, 78 | ["images", "boxes", "labels"], 79 | reader_name="Reader", 80 | last_batch_policy=LastBatchPolicy.FILL) 81 | 82 | return train_loader 83 | 84 | 85 | def get_train_loader(args, dboxes): 86 | args.train_annotate = os.path.join( 87 | args.data, "annotations/instances_train2017.json") 88 | args.train_coco_root = os.path.join(args.data, "train2017") 89 | 90 | local_seed = set_seeds(args) 91 | 92 | if args.data_pipeline == 'no_dali': 93 | return get_train_pytorch_loader(args, args.num_workers, dboxes) 94 | elif args.data_pipeline == 'dali': 95 | return get_train_dali_loader(args, dboxes, local_seed) 96 | 97 | 98 | def get_val_dataset(args): 99 | dboxes = dboxes300_coco() 100 | val_trans = SSDTransformer(dboxes, args,(300, 300), val=True) 101 | 102 | val_annotate = os.path.join(args.data, "annotations/instances_val2017.json") 103 | val_coco_root = os.path.join(args.data, "val2017") 104 | 105 | val_coco = COCODetection(val_coco_root, val_annotate, val_trans) 106 | return val_coco 107 | 108 | 109 | def get_val_dataloader(args): 110 | dataset = get_val_dataset(args) 111 | inv_map = {v: k for k, v in dataset.label_map.items()} 112 | print(dataset.label_info) 113 | 114 | if args.distributed: 115 | val_sampler = torch.utils.data.distributed.DistributedSampler(dataset) 116 | #val_sampler = None 117 | else: 118 | val_sampler = None 119 | 120 | val_dataloader = DataLoader( 121 | dataset, 122 | batch_size=args.eval_batch_size, 123 | shuffle=False, # Note: distributed sampler is shuffled :( 124 | sampler=val_sampler, 125 | num_workers=args.num_workers) 126 | 127 | return val_dataloader, inv_map 128 | 129 | 130 | def get_coco_ground_truth(args): 131 | val_annotate = os.path.join(args.data, "annotations/instances_val2017.json") 132 | cocoGt = COCO(annotation_file=val_annotate) 133 | return cocoGt 134 | -------------------------------------------------------------------------------- /tests/ssd300/src/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._utils import 
_flatten_dense_tensors, _unflatten_dense_tensors 3 | import torch.distributed as dist 4 | from torch.nn.modules import Module 5 | 6 | ''' 7 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py 8 | launcher included with this example. It assumes that your run is using multiprocess with 1 9 | GPU/process, that the model is on the correct device, and that torch.set_device has been 10 | used to set the device. 11 | 12 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, 13 | and will be allreduced at the finish of the backward pass. 14 | ''' 15 | class DistributedDataParallel(Module): 16 | 17 | def __init__(self, module): 18 | super(DistributedDataParallel, self).__init__() 19 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 20 | 21 | self.module = module 22 | 23 | for p in self.module.state_dict().values(): 24 | if not torch.is_tensor(p): 25 | continue 26 | if dist._backend == dist.dist_backend.NCCL: 27 | assert p.is_cuda, "NCCL backend only supports model parameters to be on GPU." 28 | dist.broadcast(p, 0) 29 | 30 | def allreduce_params(): 31 | if(self.needs_reduction): 32 | self.needs_reduction = False 33 | buckets = {} 34 | for param in self.module.parameters(): 35 | if param.requires_grad and param.grad is not None: 36 | tp = param.data.type() 37 | if tp not in buckets: 38 | buckets[tp] = [] 39 | buckets[tp].append(param) 40 | if self.warn_on_half: 41 | if torch.cuda.HalfTensor in buckets: 42 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 43 | " It is recommended to use the NCCL backend in this case.") 44 | self.warn_on_half = False 45 | 46 | for tp in buckets: 47 | bucket = buckets[tp] 48 | grads = [param.grad.data for param in bucket] 49 | coalesced = _flatten_dense_tensors(grads) 50 | dist.all_reduce(coalesced) 51 | coalesced /= dist.get_world_size() 52 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 53 | buf.copy_(synced) 54 | 55 | for param in list(self.module.parameters()): 56 | def allreduce_hook(*unused): 57 | param._execution_engine.queue_callback(allreduce_params) 58 | if param.requires_grad: 59 | param.register_hook(allreduce_hook) 60 | 61 | def forward(self, *inputs, **kwargs): 62 | self.needs_reduction = True 63 | return self.module(*inputs, **kwargs) 64 | 65 | ''' 66 | def _sync_buffers(self): 67 | buffers = list(self.module._all_buffers()) 68 | if len(buffers) > 0: 69 | # cross-node buffer sync 70 | flat_buffers = _flatten_dense_tensors(buffers) 71 | dist.broadcast(flat_buffers, 0) 72 | for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): 73 | buf.copy_(synced) 74 | def train(self, mode=True): 75 | # Clear NCCL communicator and CUDA event cache of the default group ID, 76 | # These cache will be recreated at the later call. This is currently a 77 | # work-around for a potential NCCL deadlock. 
78 | if dist._backend == dist.dist_backend.NCCL: 79 | dist._clear_group_cache() 80 | super(DistributedDataParallel, self).train(mode) 81 | self.module.train(mode) 82 | ''' 83 | -------------------------------------------------------------------------------- /tests/ssd300/src/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import numpy as np 4 | from contextlib import redirect_stdout 5 | import io 6 | 7 | from pycocotools.cocoeval import COCOeval 8 | 9 | 10 | def evaluate(model, coco, cocoGt, encoder, inv_map, args): 11 | if args.distributed: 12 | N_gpu = torch.distributed.get_world_size() 13 | else: 14 | N_gpu = 1 15 | 16 | model.eval() 17 | model.cuda() 18 | 19 | ret = [] 20 | start = time.time() 21 | 22 | # for idx, image_id in enumerate(coco.img_keys): 23 | for nbatch, (img, img_id, img_size, _, _) in enumerate(coco): 24 | print("Parsing batch: {}/{}".format(nbatch, len(coco)), end='\r') 25 | with torch.no_grad(): 26 | inp = img.cuda() 27 | if args.fp16: 28 | #inp = inp.half() 29 | pass 30 | 31 | # Get predictions 32 | ploc, plabel = model(inp) 33 | ploc, plabel = ploc.float(), plabel.float() 34 | 35 | # Handle the batch of predictions produced 36 | # This is slow, but consistent with old implementation. 37 | for idx in range(ploc.shape[0]): 38 | # ease-of-use for specific predictions 39 | ploc_i = ploc[idx, :, :].unsqueeze(0) 40 | plabel_i = plabel[idx, :, :].unsqueeze(0) 41 | 42 | try: 43 | result = encoder.decode_batch(ploc_i, plabel_i, 0.50, 200)[0] 44 | except: 45 | # raise 46 | print("") 47 | print("No object detected in idx: {}".format(idx)) 48 | continue 49 | 50 | htot, wtot = img_size[0][idx].item(), img_size[1][idx].item() 51 | loc, label, prob = [r.cpu().numpy() for r in result] 52 | for loc_, label_, prob_ in zip(loc, label, prob): 53 | ret.append([img_id[idx], loc_[0] * wtot, \ 54 | loc_[1] * htot, 55 | (loc_[2] - loc_[0]) * wtot, 56 | (loc_[3] - loc_[1]) * htot, 57 | prob_, 58 | inv_map[label_]]) 59 | 60 | # Now we have all predictions from this rank, gather them all together 61 | # if necessary 62 | ret = np.array(ret).astype(np.float32) 63 | 64 | # Multi-GPU eval 65 | if args.distributed: 66 | # NCCL backend means we can only operate on GPU tensors 67 | ret_copy = torch.tensor(ret).cuda() 68 | # Everyone exchanges the size of their results 69 | ret_sizes = [torch.tensor(0).cuda() for _ in range(N_gpu)] 70 | 71 | torch.cuda.synchronize() 72 | torch.distributed.all_gather(ret_sizes, torch.tensor(ret_copy.shape[0]).cuda()) 73 | torch.cuda.synchronize() 74 | 75 | # Get the maximum results size, as all tensors must be the same shape for 76 | # the all_gather call we need to make 77 | max_size = 0 78 | sizes = [] 79 | for s in ret_sizes: 80 | max_size = max(max_size, s.item()) 81 | sizes.append(s.item()) 82 | 83 | # Need to pad my output to max_size in order to use in all_gather 84 | ret_pad = torch.cat([ret_copy, torch.zeros(max_size - ret_copy.shape[0], 7, dtype=torch.float32).cuda()]) 85 | 86 | # allocate storage for results from all other processes 87 | other_ret = [torch.zeros(max_size, 7, dtype=torch.float32).cuda() for i in range(N_gpu)] 88 | # Everyone exchanges (padded) results 89 | 90 | torch.cuda.synchronize() 91 | torch.distributed.all_gather(other_ret, ret_pad) 92 | torch.cuda.synchronize() 93 | 94 | # Now need to reconstruct the _actual_ results from the padded set using slices. 
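        # sizes[i] holds how many rows rank i actually produced before padding
        # to max_size, so slicing other_ret[i][:sizes[i]] drops the zero rows
        # again and the concatenation below contains exactly one row per
        # detection gathered from all ranks.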
95 | cat_tensors = [] 96 | for i in range(N_gpu): 97 | cat_tensors.append(other_ret[i][:sizes[i]][:]) 98 | 99 | final_results = torch.cat(cat_tensors).cpu().numpy() 100 | else: 101 | # Otherwise full results are just our results 102 | final_results = ret 103 | 104 | if args.local_rank == 0: 105 | print("") 106 | print("Predicting Ended, total time: {:.2f} s".format(time.time() - start)) 107 | 108 | cocoDt = cocoGt.loadRes(final_results) 109 | 110 | E = COCOeval(cocoGt, cocoDt, iouType='bbox') 111 | E.evaluate() 112 | E.accumulate() 113 | if args.local_rank == 0: 114 | E.summarize() 115 | print("Current AP: {:.5f}".format(E.stats[0])) 116 | else: 117 | # fix for cocoeval indiscriminate prints 118 | with redirect_stdout(io.StringIO()): 119 | E.summarize() 120 | 121 | # put your model in training mode back on 122 | model.train() 123 | 124 | return E.stats[0] # Average Precision (AP) @[ IoU=050:0.95 | area= all | maxDets=100 ] 125 | 126 | -------------------------------------------------------------------------------- /tests/ssd300/src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152 4 | from torch.nn.parallel import DistributedDataParallel as DDP 5 | 6 | 7 | 8 | class ResNet(nn.Module): 9 | def __init__(self, backbone='resnet50'): 10 | super().__init__() 11 | if backbone == 'resnet18': 12 | backbone = resnet18(pretrained=True) 13 | self.out_channels = [256, 512, 512, 256, 256, 128] 14 | elif backbone == 'resnet34': 15 | backbone = resnet34(pretrained=True) 16 | self.out_channels = [256, 512, 512, 256, 256, 256] 17 | elif backbone == 'resnet50': 18 | backbone = resnet50(pretrained=True) 19 | self.out_channels = [1024, 512, 512, 256, 256, 256] 20 | elif backbone == 'resnet101': 21 | backbone = resnet101(pretrained=True) 22 | self.out_channels = [1024, 512, 512, 256, 256, 256] 23 | else: # backbone == 'resnet152': 24 | backbone = resnet152(pretrained=True) 25 | self.out_channels = [1024, 512, 512, 256, 256, 256] 26 | 27 | 28 | self.feature_extractor = nn.Sequential(*list(backbone.children())[:7]) 29 | 30 | conv4_block1 = self.feature_extractor[-1][0] 31 | 32 | conv4_block1.conv1.stride = (1, 1) 33 | conv4_block1.conv2.stride = (1, 1) 34 | conv4_block1.downsample[0].stride = (1, 1) 35 | 36 | def forward(self, x): 37 | x = self.feature_extractor(x) 38 | return x 39 | 40 | 41 | class SSD300(nn.Module): 42 | def __init__(self, backbone='resnet50'): 43 | super().__init__() 44 | 45 | self.feature_extractor = ResNet(backbone=backbone) 46 | 47 | self.label_num = 81 # number of COCO classes 48 | self._build_additional_features(self.feature_extractor.out_channels) 49 | self.num_defaults = [4, 6, 6, 6, 4, 4] 50 | self.loc = [] 51 | self.conf = [] 52 | 53 | for nd, oc in zip(self.num_defaults, self.feature_extractor.out_channels): 54 | self.loc.append(nn.Conv2d(oc, nd * 4, kernel_size=3, padding=1)) 55 | self.conf.append(nn.Conv2d(oc, nd * self.label_num, kernel_size=3, padding=1)) 56 | 57 | self.loc = nn.ModuleList(self.loc) 58 | self.conf = nn.ModuleList(self.conf) 59 | self._init_weights() 60 | 61 | def _build_additional_features(self, input_size): 62 | self.additional_blocks = [] 63 | for i, (input_size, output_size, channels) in enumerate(zip(input_size[:-1], input_size[1:], [256, 256, 128, 128, 128])): 64 | if i < 3: 65 | layer = nn.Sequential( 66 | nn.Conv2d(input_size, channels, kernel_size=1, bias=False), 67 | 
nn.BatchNorm2d(channels), 68 | nn.ReLU(inplace=True), 69 | nn.Conv2d(channels, output_size, kernel_size=3, padding=1, stride=2, bias=False), 70 | nn.BatchNorm2d(output_size), 71 | nn.ReLU(inplace=True), 72 | ) 73 | else: 74 | layer = nn.Sequential( 75 | nn.Conv2d(input_size, channels, kernel_size=1, bias=False), 76 | nn.BatchNorm2d(channels), 77 | nn.ReLU(inplace=True), 78 | nn.Conv2d(channels, output_size, kernel_size=3, bias=False), 79 | nn.BatchNorm2d(output_size), 80 | nn.ReLU(inplace=True), 81 | ) 82 | 83 | self.additional_blocks.append(layer) 84 | 85 | self.additional_blocks = nn.ModuleList(self.additional_blocks) 86 | 87 | def _init_weights(self): 88 | layers = [*self.additional_blocks, *self.loc, *self.conf] 89 | for layer in layers: 90 | for param in layer.parameters(): 91 | if param.dim() > 1: nn.init.xavier_uniform_(param) 92 | 93 | # Shape the classifier to the view of bboxes 94 | def bbox_view(self, src, loc, conf): 95 | ret = [] 96 | for s, l, c in zip(src, loc, conf): 97 | ret.append((l(s).view(s.size(0), 4, -1), c(s).view(s.size(0), self.label_num, -1))) 98 | 99 | locs, confs = list(zip(*ret)) 100 | locs, confs = torch.cat(locs, 2).contiguous(), torch.cat(confs, 2).contiguous() 101 | return locs, confs 102 | 103 | def forward(self, x): 104 | x = self.feature_extractor(x) 105 | 106 | detection_feed = [x] 107 | for l in self.additional_blocks: 108 | x = l(x) 109 | detection_feed.append(x) 110 | 111 | # Feature Map 38x38x4, 19x19x6, 10x10x6, 5x5x6, 3x3x4, 1x1x4 112 | locs, confs = self.bbox_view(detection_feed, self.loc, self.conf) 113 | 114 | # For SSD 300, shall return nbatch x 8732 x {nlabels, nlocs} results 115 | return locs, confs 116 | 117 | def model(args,onnx_save=False): 118 | ssd300 = SSD300(backbone=args.backbone) 119 | 120 | if onnx_save: 121 | return ssd300 122 | 123 | ssd300.cuda() 124 | 125 | if args.distributed: 126 | ssd300 = DDP(ssd300,device_ids=[args.local_rank],output_device=args.local_rank,find_unused_parameters=True) 127 | 128 | return ssd300 129 | 130 | 131 | class Loss(nn.Module): 132 | """ 133 | Implements the loss as the sum of the followings: 134 | 1. Confidence Loss: All labels, with hard negative mining 135 | 2. 
Localization Loss: Only on positive labels 136 | Suppose input dboxes has the shape 8732x4 137 | """ 138 | def __init__(self, dboxes): 139 | super(Loss, self).__init__() 140 | self.scale_xy = 1.0/dboxes.scale_xy 141 | self.scale_wh = 1.0/dboxes.scale_wh 142 | 143 | #self.sl1_loss = nn.SmoothL1Loss(reduce=False) 144 | #self.sl1_loss = nn.SmoothL1Loss(reduce=None) 145 | self.sl1_loss = nn.SmoothL1Loss(reduction='none',reduce=None) 146 | self.dboxes = nn.Parameter(dboxes(order="xywh").transpose(0, 1).unsqueeze(dim = 0), 147 | requires_grad=False) 148 | # Two factor are from following links 149 | # http://jany.st/post/2017-11-05-single-shot-detector-ssd-from-scratch-in-tensorflow.html 150 | #self.con_loss = nn.CrossEntropyLoss(reduce=False) 151 | self.con_loss = nn.CrossEntropyLoss(reduce=None,reduction='none') 152 | 153 | def _loc_vec(self, loc): 154 | """ 155 | Generate Location Vectors 156 | """ 157 | gxy = self.scale_xy*(loc[:, :2, :] - self.dboxes[:, :2, :])/self.dboxes[:, 2:, ] 158 | gwh = self.scale_wh*(loc[:, 2:, :]/self.dboxes[:, 2:, :]).log() 159 | return torch.cat((gxy, gwh), dim=1).contiguous() 160 | 161 | def forward(self, ploc, plabel, gloc, glabel): 162 | """ 163 | ploc, plabel: Nx4x8732, Nxlabel_numx8732 164 | predicted location and labels 165 | 166 | gloc, glabel: Nx4x8732, Nx8732 167 | ground truth location and labels 168 | """ 169 | mask = glabel > 0 170 | pos_num = mask.sum(dim=1) 171 | 172 | vec_gd = self._loc_vec(gloc) 173 | 174 | # sum on four coordinates, and mask 175 | #print(ploc.shape, vec_gd.shape) 176 | 177 | 178 | #org: 179 | #sl1 = self.sl1_loss(ploc, vec_gd).sum(dim=1) 180 | #change: 181 | sl1 = self.sl1_loss(ploc, vec_gd)#.sum(dim=1) 182 | sl1 = torch.sum(sl1,dim=1) 183 | 184 | #print(sl1.shape,mask.shape) 185 | sl1 = (mask.float()*sl1).sum(dim=1) 186 | 187 | # hard negative mining 188 | con = self.con_loss(plabel, glabel) 189 | 190 | # postive mask will never selected 191 | #print(con.shape) 192 | con_neg = con.clone() 193 | #print(mask.shape) 194 | #print(con_neg.shape) 195 | con_neg[mask] = 0 196 | _, con_idx = con_neg.sort(dim=1, descending=True) 197 | _, con_rank = con_idx.sort(dim=1) 198 | 199 | # number of negative three times positive 200 | neg_num = torch.clamp(3*pos_num, max=mask.size(1)).unsqueeze(-1) 201 | neg_mask = con_rank < neg_num 202 | 203 | #print(con.shape, mask.shape, neg_mask.shape) 204 | closs = (con*(mask.float() + neg_mask.float())).sum(dim=1) 205 | 206 | # avoid no object detected 207 | total_loss = sl1 + closs 208 | num_mask = (pos_num > 0).float() 209 | pos_num = pos_num.float().clamp(min=1e-6) 210 | ret = (total_loss*num_mask/pos_num).mean(dim=0) 211 | return ret 212 | -------------------------------------------------------------------------------- /tests/ssd300/src/train.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | import torch 3 | import time 4 | from torch.cuda.amp import autocast as autocast, GradScaler 5 | from ncnnqat import unquant_weight 6 | scaler = GradScaler() 7 | def train_loop(model, loss_func, epoch, optim, train_loader, iteration, logger, args): 8 | for nbatch, data in enumerate(train_loader): 9 | if args.data_pipeline == 'no_dali': 10 | (img, _, img_size, bbox, label) = data 11 | img = img.cuda() 12 | bbox = bbox.cuda() 13 | label = label.cuda() 14 | else: 15 | img = data[0]["images"] 16 | bbox = data[0]["boxes"] 17 | label = data[0]["labels"] 18 | label = label.type(torch.cuda.LongTensor) 19 | 20 | #print(img.dtype) 21 | 
#print(bbox.dtype) 22 | #print(label.dtype) 23 | #print("====================================") 24 | #boxes_in_batch = len(label.nonzero()) 25 | boxes_in_batch = len(label.nonzero(as_tuple=True)) 26 | 27 | 28 | if boxes_in_batch != 0: 29 | 30 | 31 | trans_bbox = bbox.transpose(1, 2).contiguous().cuda() 32 | 33 | label = label.cuda() 34 | gloc = Variable(trans_bbox, requires_grad=False) 35 | glabel = Variable(label, requires_grad=False) 36 | 37 | with autocast(): 38 | ploc, plabel = model(img) 39 | ploc, plabel = ploc.float(), plabel.float() 40 | loss = loss_func(ploc, plabel, gloc, glabel) 41 | 42 | 43 | logger.update_iter(epoch, iteration, loss.item()) 44 | 45 | if args.fp16: 46 | scaler.scale(loss).backward() 47 | else: 48 | loss.backward() 49 | 50 | if args.warmup is not None: 51 | warmup(optim, args.warmup, iteration, args.learning_rate) 52 | if args.fp16: 53 | scaler.step(optim) 54 | scaler.update() 55 | else: 56 | '''qat''' 57 | if epoch >= args.qat_epoch: 58 | model.apply(unquant_weight) 59 | '''qat''' 60 | optim.step() 61 | optim.zero_grad() 62 | iteration += 1 63 | 64 | return iteration 65 | 66 | 67 | def warmup(optim, warmup_iters, iteration, base_lr): 68 | if iteration < warmup_iters: 69 | new_lr = 1. * base_lr / warmup_iters * iteration 70 | for param_group in optim.param_groups: 71 | param_group['lr'] = new_lr 72 | 73 | 74 | def tencent_trick(model): 75 | """ 76 | Divide parameters into 2 groups. 77 | First group is BNs and all biases. 78 | Second group is the remaining model's parameters. 79 | Weight decay will be disabled in first group (aka tencent trick). 80 | """ 81 | decay, no_decay = [], [] 82 | for name, param in model.named_parameters(): 83 | if not param.requires_grad: 84 | continue # frozen weights 85 | if len(param.shape) == 1 or name.endswith(".bias"): 86 | no_decay.append(param) 87 | else: 88 | decay.append(param) 89 | return [{'params': no_decay, 'weight_decay': 0.0}, 90 | {'params': decay}] 91 | -------------------------------------------------------------------------------- /tests/ssd300/src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torch.utils.data as data 4 | from PIL import Image 5 | import os 6 | import numpy as np 7 | import random 8 | import itertools 9 | import torch.nn.functional as F 10 | import json 11 | import time 12 | import bz2 13 | import pickle 14 | from math import sqrt 15 | # from src.coco_pipeline import COCOReaderPipeline 16 | 17 | 18 | # This function is from https://github.com/kuangliu/pytorch-ssd. 
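# For boxes in ltrb format the function below evaluates, for every pair,
#   IoU(A, B) = area(A intersect B) / (area(A) + area(B) - area(A intersect B))
# by broadcasting box1 (N, 4) against box2 (M, 4) to produce an (N, M) result.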
19 | def calc_iou_tensor(box1, box2): 20 | """ Calculation of IoU based on two boxes tensor, 21 | Reference to https://github.com/kuangliu/pytorch-src 22 | input: 23 | box1 (N, 4) 24 | box2 (M, 4) 25 | output: 26 | IoU (N, M) 27 | """ 28 | N = box1.size(0) 29 | M = box2.size(0) 30 | 31 | be1 = box1.unsqueeze(1).expand(-1, M, -1) 32 | be2 = box2.unsqueeze(0).expand(N, -1, -1) 33 | 34 | # Left Top & Right Bottom 35 | lt = torch.max(be1[:,:,:2], be2[:,:,:2]) 36 | #mask1 = (be1[:,:, 0] < be2[:,:, 0]) ^ (be1[:,:, 1] < be2[:,:, 1]) 37 | #mask1 = ~mask1 38 | rb = torch.min(be1[:,:,2:], be2[:,:,2:]) 39 | #mask2 = (be1[:,:, 2] < be2[:,:, 2]) ^ (be1[:,:, 3] < be2[:,:, 3]) 40 | #mask2 = ~mask2 41 | 42 | delta = rb - lt 43 | delta[delta < 0] = 0 44 | intersect = delta[:,:,0]*delta[:,:,1] 45 | #*mask1.float()*mask2.float() 46 | 47 | delta1 = be1[:,:,2:] - be1[:,:,:2] 48 | area1 = delta1[:,:,0]*delta1[:,:,1] 49 | delta2 = be2[:,:,2:] - be2[:,:,:2] 50 | area2 = delta2[:,:,0]*delta2[:,:,1] 51 | 52 | iou = intersect/(area1 + area2 - intersect) 53 | return iou 54 | 55 | 56 | # This function is from https://github.com/kuangliu/pytorch-ssd. 57 | class Encoder(object): 58 | """ 59 | Inspired by https://github.com/kuangliu/pytorch-src 60 | Transform between (bboxes, lables) <-> SSD output 61 | 62 | dboxes: default boxes in size 8732 x 4, 63 | encoder: input ltrb format, output xywh format 64 | decoder: input xywh format, output ltrb format 65 | 66 | encode: 67 | input : bboxes_in (Tensor nboxes x 4), labels_in (Tensor nboxes) 68 | output : bboxes_out (Tensor 8732 x 4), labels_out (Tensor 8732) 69 | criteria : IoU threshold of bboexes 70 | 71 | decode: 72 | input : bboxes_in (Tensor 8732 x 4), scores_in (Tensor 8732 x nitems) 73 | output : bboxes_out (Tensor nboxes x 4), labels_out (Tensor nboxes) 74 | criteria : IoU threshold of bboexes 75 | max_output : maximum number of output bboxes 76 | """ 77 | 78 | def __init__(self, dboxes): 79 | self.dboxes = dboxes(order="ltrb") 80 | self.dboxes_xywh = dboxes(order="xywh").unsqueeze(dim=0) 81 | self.nboxes = self.dboxes.size(0) 82 | #print("# Bounding boxes: {}".format(self.nboxes)) 83 | self.scale_xy = dboxes.scale_xy 84 | self.scale_wh = dboxes.scale_wh 85 | 86 | def encode(self, bboxes_in, labels_in, criteria = 0.5): 87 | 88 | ious = calc_iou_tensor(bboxes_in, self.dboxes) 89 | best_dbox_ious, best_dbox_idx = ious.max(dim=0) 90 | best_bbox_ious, best_bbox_idx = ious.max(dim=1) 91 | 92 | # set best ious 2.0 93 | best_dbox_ious.index_fill_(0, best_bbox_idx, 2.0) 94 | 95 | idx = torch.arange(0, best_bbox_idx.size(0), dtype=torch.int64) 96 | best_dbox_idx[best_bbox_idx[idx]] = idx 97 | 98 | # filter IoU > 0.5 99 | masks = best_dbox_ious > criteria 100 | labels_out = torch.zeros(self.nboxes, dtype=torch.long) 101 | #print(maxloc.shape, labels_in.shape, labels_out.shape) 102 | labels_out[masks] = labels_in[best_dbox_idx[masks]] 103 | bboxes_out = self.dboxes.clone() 104 | bboxes_out[masks, :] = bboxes_in[best_dbox_idx[masks], :] 105 | # Transform format to xywh format 106 | x, y, w, h = 0.5*(bboxes_out[:, 0] + bboxes_out[:, 2]), \ 107 | 0.5*(bboxes_out[:, 1] + bboxes_out[:, 3]), \ 108 | -bboxes_out[:, 0] + bboxes_out[:, 2], \ 109 | -bboxes_out[:, 1] + bboxes_out[:, 3] 110 | bboxes_out[:, 0] = x 111 | bboxes_out[:, 1] = y 112 | bboxes_out[:, 2] = w 113 | bboxes_out[:, 3] = h 114 | return bboxes_out, labels_out 115 | 116 | def scale_back_batch(self, bboxes_in, scores_in): 117 | """ 118 | Do scale and transform from xywh to ltrb 119 | suppose input Nx4xnum_bbox 
Nxlabel_numxnum_bbox 120 | """ 121 | if bboxes_in.device == torch.device("cpu"): 122 | self.dboxes = self.dboxes.cpu() 123 | self.dboxes_xywh = self.dboxes_xywh.cpu() 124 | else: 125 | self.dboxes = self.dboxes.cuda() 126 | self.dboxes_xywh = self.dboxes_xywh.cuda() 127 | 128 | bboxes_in = bboxes_in.permute(0, 2, 1) 129 | scores_in = scores_in.permute(0, 2, 1) 130 | #print(bboxes_in.device, scores_in.device, self.dboxes_xywh.device) 131 | 132 | bboxes_in[:, :, :2] = self.scale_xy*bboxes_in[:, :, :2] 133 | bboxes_in[:, :, 2:] = self.scale_wh*bboxes_in[:, :, 2:] 134 | 135 | bboxes_in[:, :, :2] = bboxes_in[:, :, :2]*self.dboxes_xywh[:, :, 2:] + self.dboxes_xywh[:, :, :2] 136 | bboxes_in[:, :, 2:] = bboxes_in[:, :, 2:].exp()*self.dboxes_xywh[:, :, 2:] 137 | 138 | # Transform format to ltrb 139 | l, t, r, b = bboxes_in[:, :, 0] - 0.5*bboxes_in[:, :, 2],\ 140 | bboxes_in[:, :, 1] - 0.5*bboxes_in[:, :, 3],\ 141 | bboxes_in[:, :, 0] + 0.5*bboxes_in[:, :, 2],\ 142 | bboxes_in[:, :, 1] + 0.5*bboxes_in[:, :, 3] 143 | 144 | bboxes_in[:, :, 0] = l 145 | bboxes_in[:, :, 1] = t 146 | bboxes_in[:, :, 2] = r 147 | bboxes_in[:, :, 3] = b 148 | 149 | return bboxes_in, F.softmax(scores_in, dim=-1) 150 | 151 | def decode_batch(self, bboxes_in, scores_in, criteria = 0.45, max_output=200): 152 | bboxes, probs = self.scale_back_batch(bboxes_in, scores_in) 153 | 154 | output = [] 155 | for bbox, prob in zip(bboxes.split(1, 0), probs.split(1, 0)): 156 | bbox = bbox.squeeze(0) 157 | prob = prob.squeeze(0) 158 | output.append(self.decode_single(bbox, prob, criteria, max_output)) 159 | #print(output[-1]) 160 | return output 161 | 162 | # perform non-maximum suppression 163 | def decode_single(self, bboxes_in, scores_in, criteria, max_output, max_num=200): 164 | # Reference to https://github.com/amdegroot/ssd.pytorch 165 | 166 | bboxes_out = [] 167 | scores_out = [] 168 | labels_out = [] 169 | 170 | for i, score in enumerate(scores_in.split(1, 1)): 171 | # skip background 172 | # print(score[score>0.90]) 173 | if i == 0: continue 174 | # print(i) 175 | 176 | score = score.squeeze(1) 177 | mask = score > 0.05 178 | 179 | bboxes, score = bboxes_in[mask, :], score[mask] 180 | if score.size(0) == 0: continue 181 | 182 | score_sorted, score_idx_sorted = score.sort(dim=0) 183 | 184 | # select max_output indices 185 | score_idx_sorted = score_idx_sorted[-max_num:] 186 | candidates = [] 187 | #maxdata, maxloc = scores_in.sort() 188 | 189 | while score_idx_sorted.numel() > 0: 190 | idx = score_idx_sorted[-1].item() 191 | bboxes_sorted = bboxes[score_idx_sorted, :] 192 | bboxes_idx = bboxes[idx, :].unsqueeze(dim=0) 193 | iou_sorted = calc_iou_tensor(bboxes_sorted, bboxes_idx).squeeze() 194 | # we only need iou < criteria 195 | score_idx_sorted = score_idx_sorted[iou_sorted < criteria] 196 | candidates.append(idx) 197 | 198 | bboxes_out.append(bboxes[candidates, :]) 199 | scores_out.append(score[candidates]) 200 | labels_out.extend([i]*len(candidates)) 201 | 202 | bboxes_out, labels_out, scores_out = torch.cat(bboxes_out, dim=0), \ 203 | torch.tensor(labels_out, dtype=torch.long), \ 204 | torch.cat(scores_out, dim=0) 205 | 206 | 207 | _, max_ids = scores_out.sort(dim=0) 208 | max_ids = max_ids[-max_output:] 209 | return bboxes_out[max_ids, :], labels_out[max_ids], scores_out[max_ids] 210 | 211 | 212 | class DefaultBoxes(object): 213 | def __init__(self, fig_size, feat_size, steps, scales, aspect_ratios, \ 214 | scale_xy=0.1, scale_wh=0.2): 215 | 216 | self.feat_size = feat_size 217 | self.fig_size = fig_size 218 | 219 | 
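# scale_xy / scale_wh (0.1 / 0.2 by default) are the SSD box-encoding scales:
# scale_back_batch above decodes predictions as
#   cx = dbox_cx + scale_xy * t_x * dbox_w      (same for cy, with dbox_h)
#   w  = dbox_w * exp(scale_wh * t_w)           (same for h)
# They are exposed through the scale_xy / scale_wh properties below so the
# Encoder can reuse the same values.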
self.scale_xy_ = scale_xy 220 | self.scale_wh_ = scale_wh 221 | 222 | # According to https://github.com/weiliu89/caffe 223 | # Calculation method slightly different from paper 224 | self.steps = steps 225 | self.scales = scales 226 | 227 | fk = fig_size/np.array(steps) 228 | self.aspect_ratios = aspect_ratios 229 | 230 | self.default_boxes = [] 231 | # size of feature and number of feature 232 | for idx, sfeat in enumerate(self.feat_size): 233 | 234 | sk1 = scales[idx]/fig_size 235 | sk2 = scales[idx+1]/fig_size 236 | sk3 = sqrt(sk1*sk2) 237 | all_sizes = [(sk1, sk1), (sk3, sk3)] 238 | 239 | for alpha in aspect_ratios[idx]: 240 | w, h = sk1*sqrt(alpha), sk1/sqrt(alpha) 241 | all_sizes.append((w, h)) 242 | all_sizes.append((h, w)) 243 | for w, h in all_sizes: 244 | for i, j in itertools.product(range(sfeat), repeat=2): 245 | cx, cy = (j+0.5)/fk[idx], (i+0.5)/fk[idx] 246 | self.default_boxes.append((cx, cy, w, h)) 247 | 248 | self.dboxes = torch.tensor(self.default_boxes, dtype = torch.float) 249 | self.dboxes.clamp_(min=0, max=1) 250 | # For IoU calculation 251 | self.dboxes_ltrb = self.dboxes.clone() 252 | self.dboxes_ltrb[:, 0] = self.dboxes[:, 0] - 0.5 * self.dboxes[:, 2] 253 | self.dboxes_ltrb[:, 1] = self.dboxes[:, 1] - 0.5 * self.dboxes[:, 3] 254 | self.dboxes_ltrb[:, 2] = self.dboxes[:, 0] + 0.5 * self.dboxes[:, 2] 255 | self.dboxes_ltrb[:, 3] = self.dboxes[:, 1] + 0.5 * self.dboxes[:, 3] 256 | 257 | @property 258 | def scale_xy(self): 259 | return self.scale_xy_ 260 | 261 | @property 262 | def scale_wh(self): 263 | return self.scale_wh_ 264 | 265 | def as_ltrb_list(self): 266 | return [x for x in self.dboxes_ltrb.view(-1).numpy()] 267 | 268 | def __call__(self, order="ltrb"): 269 | if order == "ltrb": return self.dboxes_ltrb 270 | if order == "xywh": return self.dboxes 271 | 272 | 273 | def dboxes300_coco(): 274 | figsize = 300 275 | feat_size = [38, 19, 10, 5, 3, 1] 276 | steps = [8, 16, 32, 64, 100, 300] 277 | # use the scales here: https://github.com/amdegroot/ssd.pytorch/blob/master/data/config.py 278 | scales = [21, 45, 99, 153, 207, 261, 315] 279 | aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2], [2]] 280 | dboxes = DefaultBoxes(figsize, feat_size, steps, scales, aspect_ratios) 281 | return dboxes 282 | 283 | 284 | # This function is from https://github.com/chauhan-utk/ssd.DomainAdaptation. 285 | class SSDCropping(object): 286 | """ Cropping for SSD, according to original paper 287 | Choose between following 3 conditions: 288 | 1. Preserve the original image 289 | 2. Random crop minimum IoU is among 0.1, 0.3, 0.5, 0.7, 0.9 290 | 3. 
Random crop 291 | Reference to https://github.com/chauhan-utk/src.DomainAdaptation 292 | """ 293 | def __init__(self): 294 | 295 | self.sample_options = ( 296 | # Do nothing 297 | None, 298 | # min IoU, max IoU 299 | (0.1, None), 300 | (0.3, None), 301 | (0.5, None), 302 | (0.7, None), 303 | (0.9, None), 304 | # no IoU requirements 305 | (None, None), 306 | ) 307 | 308 | def __call__(self, img, img_size, bboxes, labels): 309 | 310 | # Ensure always return cropped image 311 | while True: 312 | mode = random.choice(self.sample_options) 313 | 314 | if mode is None: 315 | return img, img_size, bboxes, labels 316 | 317 | htot, wtot = img_size 318 | 319 | min_iou, max_iou = mode 320 | min_iou = float("-inf") if min_iou is None else min_iou 321 | max_iou = float("+inf") if max_iou is None else max_iou 322 | 323 | # Implementation use 50 iteration to find possible candidate 324 | for _ in range(1): 325 | # suze of each sampled path in [0.1, 1] 0.3*0.3 approx. 0.1 326 | w = random.uniform(0.3 , 1.0) 327 | h = random.uniform(0.3 , 1.0) 328 | 329 | if w/h < 0.5 or w/h > 2: 330 | continue 331 | 332 | # left 0 ~ wtot - w, top 0 ~ htot - h 333 | left = random.uniform(0, 1.0 - w) 334 | top = random.uniform(0, 1.0 - h) 335 | 336 | right = left + w 337 | bottom = top + h 338 | 339 | ious = calc_iou_tensor(bboxes, torch.tensor([[left, top, right, bottom]])) 340 | 341 | # tailor all the bboxes and return 342 | if not ((ious > min_iou) & (ious < max_iou)).all(): 343 | continue 344 | 345 | # discard any bboxes whose center not in the cropped image 346 | xc = 0.5*(bboxes[:, 0] + bboxes[:, 2]) 347 | yc = 0.5*(bboxes[:, 1] + bboxes[:, 3]) 348 | 349 | masks = (xc > left) & (xc < right) & (yc > top) & (yc < bottom) 350 | 351 | # if no such boxes, continue searching again 352 | if not masks.any(): 353 | continue 354 | 355 | bboxes[bboxes[:, 0] < left, 0] = left 356 | bboxes[bboxes[:, 1] < top, 1] = top 357 | bboxes[bboxes[:, 2] > right, 2] = right 358 | bboxes[bboxes[:, 3] > bottom, 3] = bottom 359 | 360 | #print(left, top, right, bottom) 361 | #print(labels, bboxes, masks) 362 | bboxes = bboxes[masks, :] 363 | labels = labels[masks] 364 | 365 | left_idx = int(left*wtot) 366 | top_idx = int(top*htot) 367 | right_idx = int(right*wtot) 368 | bottom_idx = int(bottom*htot) 369 | #print(left_idx,top_idx,right_idx,bottom_idx) 370 | #img = img[:, top_idx:bottom_idx, left_idx:right_idx] 371 | img = img.crop((left_idx, top_idx, right_idx, bottom_idx)) 372 | 373 | bboxes[:, 0] = (bboxes[:, 0] - left)/w 374 | bboxes[:, 1] = (bboxes[:, 1] - top)/h 375 | bboxes[:, 2] = (bboxes[:, 2] - left)/w 376 | bboxes[:, 3] = (bboxes[:, 3] - top)/h 377 | 378 | htot = bottom_idx - top_idx 379 | wtot = right_idx - left_idx 380 | return img, (htot, wtot), bboxes, labels 381 | 382 | 383 | class RandomHorizontalFlip(object): 384 | def __init__(self, p=0.5): 385 | self.p = p 386 | 387 | def __call__(self, image, bboxes): 388 | if random.random() < self.p: 389 | bboxes[:, 0], bboxes[:, 2] = 1.0 - bboxes[:, 2], 1.0 - bboxes[:, 0] 390 | return image.transpose(Image.FLIP_LEFT_RIGHT), bboxes 391 | return image, bboxes 392 | 393 | 394 | # Do data augumentation 395 | class SSDTransformer(object): 396 | """ SSD Data Augumentation, according to original paper 397 | Composed by several steps: 398 | Cropping 399 | Resize 400 | Flipping 401 | Jittering 402 | """ 403 | def __init__(self, dboxes, args, size = (300, 300), val=False): 404 | 405 | self.args = args 406 | self.size = size 407 | self.val = val 408 | 409 | self.dboxes_ = dboxes 410 | self.encoder = 
Encoder(self.dboxes_) 411 | self.crop = SSDCropping() 412 | 413 | train_trans = [transforms.Resize(self.size)] 414 | train_trans.append(transforms.ColorJitter( 415 | brightness=0.125, 416 | contrast=0.5, 417 | saturation=0.5, 418 | hue=0.05)) 419 | train_trans.append(transforms.ToTensor()) 420 | 421 | self.img_trans = transforms.Compose(train_trans) 422 | self.hflip = RandomHorizontalFlip() 423 | 424 | # All PyTorch Tensor will be normalized 425 | # https://discuss.pytorch.org/t/how-to-preprocess-input-for-pre-trained-networks/683 426 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 427 | std=[0.229, 0.224, 0.225]) 428 | 429 | self.trans_val = transforms.Compose([ 430 | transforms.Resize(self.size), 431 | transforms.ToTensor(), 432 | self.normalize]) 433 | 434 | @property 435 | def dboxes(self): 436 | return self.dboxes_ 437 | 438 | def __call__(self, img, img_size, bbox=None, label=None, max_num=200): 439 | if self.val: 440 | bbox_out = torch.zeros(max_num, 4) 441 | label_out = torch.zeros(max_num, dtype=torch.long) 442 | bbox_out[:bbox.size(0), :] = bbox 443 | label_out[:label.size(0)] = label 444 | return self.trans_val(img), img_size, bbox_out, label_out 445 | 446 | img, img_size, bbox, label = self.crop(img, img_size, bbox, label) 447 | img, bbox = self.hflip(img, bbox) 448 | img = self.img_trans(img).contiguous() 449 | img = self.normalize(img) 450 | bbox, label = self.encoder.encode(bbox, label) 451 | 452 | return img, img_size, bbox, label 453 | 454 | # Implement a datareader for COCO dataset 455 | class COCODetection(data.Dataset): 456 | def __init__(self, img_folder, annotate_file, transform=None): 457 | self.img_folder = img_folder 458 | self.annotate_file = annotate_file 459 | 460 | # Start processing annotation 461 | with open(annotate_file) as fin: 462 | self.data = json.load(fin) 463 | 464 | self.images = {} 465 | 466 | self.label_map = {} 467 | self.label_info = {} 468 | #print("Parsing COCO data...") 469 | start_time = time.time() 470 | # 0 stand for the background 471 | cnt = 0 472 | self.label_info[cnt] = "background" 473 | for cat in self.data["categories"]: 474 | cnt += 1 475 | self.label_map[cat["id"]] = cnt 476 | self.label_info[cnt] = cat["name"] 477 | 478 | # build inference for images 479 | for img in self.data["images"]: 480 | img_id = img["id"] 481 | img_name = img["file_name"] 482 | img_size = (img["height"],img["width"]) 483 | #print(img_name) 484 | if img_id in self.images: raise Exception("dulpicated image record") 485 | self.images[img_id] = (img_name, img_size, []) 486 | 487 | # read bboxes 488 | for bboxes in self.data["annotations"]: 489 | img_id = bboxes["image_id"] 490 | category_id = bboxes["category_id"] 491 | bbox = bboxes["bbox"] 492 | bbox_label = self.label_map[bboxes["category_id"]] 493 | self.images[img_id][2].append((bbox, bbox_label)) 494 | 495 | for k, v in list(self.images.items()): 496 | if len(v[2]) == 0: 497 | #print("empty image: {}".format(k)) 498 | self.images.pop(k) 499 | 500 | self.img_keys = list(self.images.keys()) 501 | self.transform = transform 502 | 503 | @property 504 | def labelnum(self): 505 | return len(self.label_info) 506 | 507 | @staticmethod 508 | def load(pklfile): 509 | #print("Loading from {}".format(pklfile)) 510 | with bz2.open(pklfile, "rb") as fin: 511 | ret = pickle.load(fin) 512 | return ret 513 | 514 | def save(self, pklfile): 515 | #print("Saving to {}".format(pklfile)) 516 | with bz2.open(pklfile, "wb") as fout: 517 | pickle.dump(self, fout) 518 | 519 | 520 | def __len__(self): 521 | 
return len(self.images) 522 | 523 | def __getitem__(self, idx): 524 | img_id = self.img_keys[idx] 525 | img_data = self.images[img_id] 526 | fn = img_data[0] 527 | img_path = os.path.join(self.img_folder, fn) 528 | img = Image.open(img_path).convert("RGB") 529 | 530 | htot, wtot = img_data[1] 531 | bbox_sizes = [] 532 | bbox_labels = [] 533 | 534 | #for (xc, yc, w, h), bbox_label in img_data[2]: 535 | for (l,t,w,h), bbox_label in img_data[2]: 536 | r = l + w 537 | b = t + h 538 | #l, t, r, b = xc - 0.5*w, yc - 0.5*h, xc + 0.5*w, yc + 0.5*h 539 | bbox_size = (l/wtot, t/htot, r/wtot, b/htot) 540 | bbox_sizes.append(bbox_size) 541 | bbox_labels.append(bbox_label) 542 | 543 | bbox_sizes = torch.tensor(bbox_sizes) 544 | bbox_labels = torch.tensor(bbox_labels) 545 | 546 | 547 | if self.transform != None: 548 | img, (htot, wtot), bbox_sizes, bbox_labels = \ 549 | self.transform(img, (htot, wtot), bbox_sizes, bbox_labels) 550 | else: 551 | pass 552 | 553 | return img, img_id, (htot, wtot), bbox_sizes, bbox_labels 554 | 555 | 556 | def draw_patches(img, bboxes, labels, order="xywh", label_map={}): 557 | 558 | import matplotlib.pyplot as plt 559 | import matplotlib.patches as patches 560 | # Suppose bboxes in fractional coordinate: 561 | # cx, cy, w, h 562 | # img = img.numpy() 563 | img = np.array(img) 564 | labels = np.array(labels) 565 | bboxes = bboxes.numpy() 566 | 567 | if label_map: 568 | labels = [label_map.get(l) for l in labels] 569 | 570 | if order == "ltrb": 571 | xmin, ymin, xmax, ymax = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] 572 | cx, cy, w, h = (xmin + xmax)/2, (ymin + ymax)/2, xmax - xmin, ymax - ymin 573 | else: 574 | cx, cy, w, h = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3] 575 | 576 | htot, wtot,_ = img.shape 577 | cx *= wtot 578 | cy *= htot 579 | w *= wtot 580 | h *= htot 581 | 582 | bboxes = zip(cx, cy, w, h) 583 | 584 | plt.imshow(img) 585 | ax = plt.gca() 586 | for (cx, cy, w, h), label in zip(bboxes, labels): 587 | if label == "background": continue 588 | ax.add_patch(patches.Rectangle((cx-0.5*w, cy-0.5*h), 589 | w, h, fill=False, color="r")) 590 | bbox_props = dict(boxstyle="round", fc="y", ec="0.5", alpha=0.3) 591 | ax.text(cx-0.5*w, cy-0.5*h, label, ha="center", va="center", size=15, bbox=bbox_props) 592 | plt.show() 593 | -------------------------------------------------------------------------------- /tests/test_cifar10.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | import os 3 | import copy 4 | from ncnnqat import unquant_weight, merge_freeze_bn, register_quantization_hook,save_table 5 | import unittest 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.optim import lr_scheduler 11 | from torchvision import models 12 | from torchvision import datasets,utils 13 | from torch.autograd import Variable 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | from torchsummary import summary 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3' 25 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5' 26 | #os.environ['CUDA_VISIBLE_DEVICES'] = '0' 27 | 28 | 29 | 30 | 31 | def net_builder(class_num,net_name="mobile_netv2"): 32 | if net_name == "mobile_netv2": 33 | net = models.mobilenet_v2(pretrained=True) 34 | net.classifier = nn.Sequential(nn.Linear(1280, 1000), nn.ReLU(True),nn.Dropout(0.5),nn.Linear(1000, class_num)) 35 | elif net_name == "resnet18": 
36 | net = models.resnet18(pretrained=True) 37 | num_ftrs = net.fc.in_features 38 | net.fc = nn.Linear(num_ftrs, class_num) 39 | else: 40 | raise ValueError("net_name not in(mobile_netv2,resnet18)") 41 | return net 42 | 43 | 44 | class Mbnet(unittest.TestCase): 45 | def test(self): 46 | num_workers = 10 47 | 48 | 49 | net_name="resnet18" 50 | net_name="mobile_netv2" 51 | 52 | class_num = 10 53 | 54 | img_size = 224 55 | batch_size = 128 56 | epoch_all = 50 57 | epoch_merge_bn = epoch_all-5 58 | 59 | #maybe cuda out of memery,set test epoch in a small count 60 | epoch_all = 4 61 | epoch_merge_bn = epoch_all-2 62 | 63 | checkpoint = "./model.pt" 64 | pre_list = ["train","val"] 65 | dataloaders = {} 66 | 67 | transform = transforms.Compose([ 68 | transforms.Resize(img_size), 69 | transforms.ToTensor(), 70 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 71 | ]) 72 | trainset = torchvision.datasets.CIFAR10(root='./data', 73 | train=True, 74 | download=True, 75 | transform=transform) 76 | dataloaders['train'] = torch.utils.data.DataLoader(trainset, 77 | batch_size=batch_size, 78 | shuffle=True, 79 | num_workers=2) 80 | testset = torchvision.datasets.CIFAR10(root='./data', 81 | train=False, 82 | download=True, 83 | transform=transform) 84 | dataloaders['val'] = torch.utils.data.DataLoader(testset, 85 | batch_size=batch_size, 86 | shuffle=True, 87 | num_workers=2) 88 | 89 | 90 | 91 | dummy_input = torch.randn(1, 3, img_size, img_size, device='cuda') 92 | input_names = [ "input" ] 93 | output_names = [ "fc" ] #mobilenet 94 | 95 | 96 | 97 | net = net_builder(10,net_name=net_name) 98 | 99 | if torch.cuda.device_count() > 1: 100 | print("Let's use", torch.cuda.device_count(), "GPUs!") 101 | net = nn.DataParallel(net) 102 | net.cuda() 103 | criterion = nn.CrossEntropyLoss() 104 | optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) 105 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1) 106 | 107 | best_model_wts = copy.deepcopy(net.state_dict()) 108 | best_acc = 0.0 109 | best_acc_org = 0.0 110 | print("training:") 111 | state_dict_merge = False 112 | for epoch in range(epoch_all): 113 | net.train() 114 | if epoch == epoch_merge_bn: 115 | best_acc_org = best_acc 116 | #save not use qat model 117 | if torch.cuda.device_count() > 1: 118 | net_t = net_builder(class_num,net_name=net_name) 119 | net_t.cuda() 120 | net_t.load_state_dict({k.replace('module.',''):v for k,v in best_model_wts.items()}) 121 | torch.onnx.export(net_t, dummy_input, "mobilenet_org.onnx", verbose=False, input_names=input_names, output_names=output_names) 122 | print("export org onnx") 123 | else: 124 | torch.onnx.export(net, dummy_input, "mobilenet_org.onnx", verbose=False, input_names=input_names, output_names=output_names) 125 | print("export org onnx") 126 | register_quantization_hook(net) 127 | net = merge_freeze_bn(net) 128 | 129 | best_acc = 0. 
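# QAT schedule in this loop (summarizing the branches around this point):
#   epoch == epoch_merge_bn     -> export the float baseline ONNX, then call
#                                  register_quantization_hook() + merge_freeze_bn()
#   epoch == epoch_merge_bn + 1 -> merge/freeze BN again and snapshot these first
#                                  QAT weights as the new baseline
#   epoch >  epoch_merge_bn + 1 -> keep re-applying merge_freeze_bn() every epoch
# Inside the batch loop, net.apply(unquant_weight) runs right before
# optimizer.step() once epoch >= epoch_merge_bn (the same pattern used in
# tests/ssd300/src/train.py), so the gradient update is applied to the
# unquantized float copy of the weights.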
130 | if epoch == epoch_merge_bn+1: 131 | net = merge_freeze_bn(net) 132 | print("merge bn") 133 | best_model_wts = copy.deepcopy(net.state_dict()) #first epoch of qat ,save model as baseline 134 | if epoch > epoch_merge_bn+1: 135 | print("merge bn") 136 | net = merge_freeze_bn(net) 137 | 138 | running_loss = 0.0 139 | bath_term = 20 140 | for index, data in enumerate(dataloaders['train']): 141 | inputs, labels = data 142 | inputs, labels = Variable(inputs.cuda()), Variable( 143 | labels.cuda()) 144 | optimizer.zero_grad() 145 | outputs = net(inputs) 146 | loss = criterion(outputs, labels) 147 | loss.backward() 148 | if epoch >= epoch_merge_bn: 149 | net.apply(unquant_weight) 150 | 151 | optimizer.step() 152 | 153 | running_loss += loss.item() 154 | if index % bath_term == 100: 155 | print(' epoch %3d, Iter %5d, loss: %.3f' % (epoch + 1, index + 1, running_loss / bath_term)) 156 | running_loss = 0.0 157 | exp_lr_scheduler.step() 158 | 159 | net.eval() 160 | correct = total = 0 161 | for data in dataloaders['val']: 162 | images, labels = data 163 | outputs = net(Variable(images.cuda())) 164 | _, predicted = torch.max(outputs.data, 1) 165 | correct += (predicted == labels.cuda()).sum() 166 | total += labels.size(0) 167 | print('Epoch: {} Accuracy: {}'.format(str(epoch),str(100.0 * correct.cpu().numpy() / total))) 168 | epoch_acc = 100.0 * correct / total 169 | if epoch_acc >= best_acc: 170 | best_acc = epoch_acc 171 | best_model_wts = copy.deepcopy(net.state_dict()) 172 | print("get best ....") 173 | net.load_state_dict(best_model_wts) 174 | print('Finished Training.') 175 | 176 | net.eval() 177 | correct = total = 0 178 | for data in dataloaders['val']: 179 | images, labels = data 180 | outputs = net(Variable(images.cuda())) 181 | _, predicted = torch.max(outputs.data, 1) 182 | correct += (predicted == labels.cuda()).sum() 183 | total += labels.size(0) 184 | print('Accuracy: {}'.format(str(100.0 * correct.cpu().numpy() / total))) 185 | 186 | if torch.cuda.device_count() > 1: 187 | net_t = net_builder(class_num,net_name=net_name) 188 | net_t.cuda() 189 | register_quantization_hook(net_t) 190 | net_t = merge_freeze_bn(net_t) 191 | 192 | net_t.load_state_dict({k.replace('module.',''):v for k,v in net.state_dict().items()}) 193 | torch.onnx.export(net_t, dummy_input, "mobilenet.onnx", verbose=False, input_names=input_names, output_names=output_names) #保存模型 194 | save_table(net_t,onnx_path="mobilenet.onnx",table="mobilenet.table") 195 | print("export qat onnx") 196 | else: 197 | torch.onnx.export(net, dummy_input, "mobilenet.onnx", verbose=False, input_names=input_names, output_names=output_names) 198 | save_table(net,onnx_path="mobilenet.onnx",table="mobilenet.table") 199 | print("export qat onnx") 200 | print(best_acc_org,best_acc) 201 | if __name__ == "__main__": 202 | suite = unittest.TestSuite() 203 | suite.addTest(Mbnet("test")) 204 | runner = unittest.TextTestRunner() 205 | runner.run(suite) 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | --------------------------------------------------------------------------------
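The QAT workflow exercised by tests/test_cifar10.py boils down to the sketch below. The network, data, epoch count and output file names are illustrative placeholders; only the ncnnqat calls (register_quantization_hook, merge_freeze_bn, unquant_weight, save_table) and the order they are applied in are taken from the test and from tests/ssd300/src/train.py. The test also trains in float for a few epochs before switching QAT on; that warm-up is omitted here.

import torch
import torch.nn as nn
from torchvision import models
from ncnnqat import register_quantization_hook, merge_freeze_bn, unquant_weight, save_table

# Placeholder model and synthetic data; any CUDA model with a standard training loop works.
net = models.resnet18().cuda()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(8, 3, 224, 224),
                                   torch.randint(0, 10, (8,))),
    batch_size=4)

register_quantization_hook(net)          # attach fake-quantization hooks
net = merge_freeze_bn(net)               # merge and freeze BatchNorm

for epoch in range(2):                   # placeholder QAT epoch count
    net = merge_freeze_bn(net)           # re-merge BN each epoch, as in the test
    net.train()
    for images, labels in loader:
        optimizer.zero_grad()
        loss = criterion(net(images.cuda()), labels.cuda())
        loss.backward()
        net.apply(unquant_weight)        # restore float weights before the update
        optimizer.step()

# Export the quantization-aware model and the ncnn calibration table.
dummy_input = torch.randn(1, 3, 224, 224, device="cuda")
torch.onnx.export(net, dummy_input, "model_qat.onnx",
                  input_names=["input"], output_names=["fc"])
save_table(net, onnx_path="model_qat.onnx", table="model_qat.table")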