├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── ps_vit_b_14.yaml
│   ├── ps_vit_b_16.yaml
│   ├── ps_vit_b_18.yaml
│   └── ps_vit_ti_14.yaml
├── datasets
│   ├── __init__.py
│   └── list_dataset.py
├── imgs
│   └── overview.png
├── layers
│   ├── __init__.py
│   ├── cinclude
│   │   ├── cpp_helper.hpp
│   │   ├── cuda_helper.hpp
│   │   └── progressive_sampling_cuda_kernel.cuh
│   ├── csrc
│   │   ├── info.cpp
│   │   ├── progressive_sampling.cpp
│   │   ├── progressive_sampling_cuda.cu
│   │   └── pybind.cpp
│   └── progressive_sample.py
├── main.py
├── models
│   ├── __init__.py
│   ├── ps_vit.py
│   └── transformer_block.py
├── scripts
│   ├── train_distributed.sh
│   └── train_slurm.sh
├── setup.py
└── utils
    ├── __init__.py
    ├── distributed_utils.py
    ├── ext_loader.py
    ├── flop_count
    │   ├── __init__.py
    │   ├── flop_count.py
    │   ├── jit_analysis.py
    │   └── jit_handles.py
    ├── loader.py
    └── sampler.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | # Byte-compiled / optimized / DLL files
  2 | __pycache__/
  3 | *.py[cod]
  4 | *$py.class
  5 | 
  6 | # C extensions
  7 | *.so
  8 | 
  9 | # Distribution / packaging
 10 | .Python
 11 | build/
 12 | develop-eggs/
 13 | dist/
 14 | downloads/
 15 | eggs/
 16 | .eggs/
 17 | lib/
 18 | lib64/
 19 | parts/
 20 | sdist/
 21 | var/
 22 | wheels/
 23 | *.egg-info/
 24 | .installed.cfg
 25 | *.egg
 26 | MANIFEST
 27 | 
 28 | # PyInstaller
 29 | # Usually these files are written by a python script from a template
 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
 31 | *.manifest
 32 | *.spec
 33 | 
 34 | # Installer logs
 35 | pip-log.txt
 36 | pip-delete-this-directory.txt
 37 | 
 38 | # Unit test / coverage reports
 39 | htmlcov/
 40 | .tox/
 41 | .coverage
 42 | .coverage.*
 43 | .cache
 44 | nosetests.xml
 45 | coverage.xml
 46 | *.cover
 47 | .hypothesis/
 48 | .pytest_cache/
 49 | 
 50 | # Translations
 51 | *.mo
 52 | *.pot
 53 | 
 54 | # Django stuff:
 55 | *.log
 56 | local_settings.py
 57 | db.sqlite3
 58 | 
 59 | # Flask stuff:
 60 | instance/
 61 | .webassets-cache
 62 | 
 63 | # Scrapy stuff:
 64 | .scrapy
 65 | 
 66 | # Sphinx documentation
 67 | docs/_build/
 68 | 
 69 | # PyBuilder
 70 | target/
 71 | 
 72 | # Jupyter Notebook
 73 | .ipynb_checkpoints
 74 | 
 75 | # pyenv
 76 | .python-version
 77 | 
 78 | # celery beat schedule file
 79 | celerybeat-schedule
 80 | 
 81 | # SageMath parsed files
 82 | *.sage.py
 83 | 
 84 | # Environments
 85 | .env
 86 | .venv
 87 | env/
 88 | venv/
 89 | ENV/
 90 | env.bak/
 91 | venv.bak/
 92 | 
 93 | # Spyder project settings
 94 | .spyderproject
 95 | .spyproject
 96 | 
 97 | # Rope project settings
 98 | .ropeproject
 99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | # editors and IDEs
107 | .idea/
108 | .vscode/
109 | 
110 | # custom
111 | .DS_Store
112 | weights/
113 | output/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2021 Xiaoyu Yue
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Vision Transformer with Progressive Sampling
 2 | 
 3 | This is the official implementation of the ICCV 2021 paper [Vision Transformer with Progressive Sampling](https://arxiv.org/abs/2108.01684).
 4 | 
 5 | 
 6 | ![PS-ViT overview](imgs/overview.png)
 7 | 
 8 | 
 9 | ## Installation Instructions
10 | - Clone this repo:
11 | 
12 | ```bash
13 | git clone git@github.com:yuexy/PS-ViT.git
14 | cd PS-ViT
15 | ```
16 | 
17 | - Create a conda virtual environment and activate it:
18 | 
19 | ```bash
20 | conda create -n ps_vit python=3.7 -y
21 | conda activate ps_vit
22 | ```
23 | 
24 | - Install `CUDA==10.1` with `cudnn7` following
25 |   the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
26 | - Install `PyTorch==1.7.1` and `torchvision==0.8.2` with `CUDA==10.1`:
27 | 
28 | ```bash
29 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
30 | ```
31 | 
32 | - Install `timm==0.3.4`, `einops` and `pyyaml`:
33 | 
34 | ```bash
35 | pip3 install timm==0.3.4 einops pyyaml
36 | ```
37 | 
38 | - Install `Apex`:
39 | 
40 | ```bash
41 | git clone https://github.com/NVIDIA/apex
42 | cd apex
43 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
44 | ```
45 | 
46 | - Install `PS-ViT` (builds the progressive-sampling CUDA extension in place):
47 | 
48 | ```bash
49 | python setup.py build_ext --inplace
50 | ```
51 | 
52 | 
53 | ## Results and Models
54 | *All models listed below are evaluated with input size 224x224.*
55 | 
56 | | Model | Top1 Acc | #params | FLOPS | Download |
57 | | :--- | :---: | :---: | :---: | :---: |
58 | | PS-ViT-Ti/14 | 75.6 | 4.8M  | 1.6G | Coming Soon |
59 | | PS-ViT-B/10  | 80.6 | 21.3M | 3.1G | Coming Soon |
60 | | PS-ViT-B/14  | 81.7 | 21.3M | 5.4G | [Google Drive](https://drive.google.com/file/d/1FAAOCbpgPKlSe3dWIzQLg8JK6okvZkC5/view?usp=sharing) |
61 | | PS-ViT-B/18  | 82.3 | 21.3M | 8.8G | [Google Drive](https://drive.google.com/file/d/1KG4TdrfbNNdbNImCPCdSeQ5Y-gkDMlnr/view?usp=sharing) |
62 | 
63 | ## Evaluation
64 | 
65 | To evaluate a pre-trained `PS-ViT` on ImageNet val, run:
66 | 
67 | ```bash
68 | python3 main.py --model <model-name> -b <batch-size> --eval_checkpoint <path-to-checkpoint>
69 | ```
70 | 
71 | ## Training from scratch
72 | 
73 | To train a `PS-ViT` on ImageNet from scratch, run:
74 | 
75 | ```bash
76 | bash ./scripts/train_distributed.sh
77 | ```
78 | 
79 | ## Citing PS-ViT
80 | ```
81 | @inproceedings{psvit,
82 |   title={Vision Transformer with Progressive Sampling},
83 |   author={Yue, Xiaoyu and Sun, Shuyang and Kuang, Zhanghui and Wei, Meng and Torr, Philip and Zhang, Wayne and Lin, Dahua},
84 |   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
85 |   year={2021}
86 | }
87 | ```
88 | 
89 | ## Contact
90 | If you have any questions, don't hesitate to contact Xiaoyu Yue.
91 | You can reach him by email at yuexiaoyu002@gmail.com.
92 | 
--------------------------------------------------------------------------------
/configs/ps_vit_b_14.yaml:
--------------------------------------------------------------------------------
1 | model: "ps_vit_b_14"
2 | batch_size: 64
3 | lr: 5e-4
4 | weight_decay: 0.05
5 | img_size: 224
6 | 
--------------------------------------------------------------------------------
/configs/ps_vit_b_16.yaml:
--------------------------------------------------------------------------------
1 | model: "ps_vit_b_16"
2 | batch_size: 64
3 | lr: 5e-4
4 | weight_decay: 0.05
5 | img_size: 224
6 | 
--------------------------------------------------------------------------------
/configs/ps_vit_b_18.yaml:
--------------------------------------------------------------------------------
1 | model: "ps_vit_b_18"
2 | batch_size: 64
3 | lr: 5e-4
4 | weight_decay: 0.05
5 | img_size: 224
6 | 
--------------------------------------------------------------------------------
/configs/ps_vit_ti_14.yaml:
--------------------------------------------------------------------------------
1 | model: "ps_vit_ti_14"
2 | batch_size: 64
3 | lr: 0.001
4 | weight_decay: 0.03
5 | img_size: 224
6 | drop: 0.0
7 | mixup: 0.8
8 | reprob: 0.25
9 | warmup_epochs: 5
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .list_dataset import ListDataset
2 | 
3 | __all__ = ['ListDataset']
4 | 
--------------------------------------------------------------------------------
/datasets/list_dataset.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | from PIL import Image
 4 | from torch.utils.data import Dataset
 5 | 
 6 | 
 7 | class ListDataset(Dataset):
 8 |     def __init__(self,
 9 |                  root_dir,
10 |                  meta_file,
11 |                  transform=None):
12 |         self.root_dir = root_dir
13 |         self.transform = transform
14 | 
15 |         label_file = open(meta_file, 'r')
16 |         lines = label_file.readlines()
17 |         label_file.close()
18 | 
19 |         self.num = len(lines)
20 |         self.metas = []
21 |         for line in lines:
22 |             img_path, cls_label = line.rstrip().split()
23 |             self.metas.append((img_path, int(cls_label)))
24 | 
25 |     def __len__(self):
26 |         return self.num
27 | 
28 |     def __getitem__(self, idx):
29 |         img_name, cls_label = self.metas[idx]
30 | 
31 |         img_path = os.path.join(self.root_dir, img_name)
32 | 
33 |         img = Image.open(img_path).convert('RGB')
34 | 
35 |         # transform
36 |         if self.transform is not None:
37 |             img = self.transform(img)
38 | 
39 |         return img, cls_label
40 | 
41 |     def filename(self, index, basename=False, absolute=False):
42 |         filename = self.metas[index][0]
43 |         if basename:
44 |             filename = os.path.basename(filename)
45 |         elif not absolute:
46 |             filename = os.path.relpath(filename, self.root_dir)
47 |         return filename
48 | 
49 |     def filenames(self, basename=False, absolute=False):
50 |         fn = lambda x: x
51 |         if basename:
52 |             fn = os.path.basename
53 |         elif not absolute:
54 |             fn = lambda x: os.path.relpath(x, self.root_dir)
55 |         return [fn(x[0]) for x in self.metas]
56 | 
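
`ListDataset` above expects a plain-text metafile with one `relative/image/path integer-label` pair per line. A minimal usage sketch (the `data/train` paths and the transform are illustrative, not part of the repo):

```python
# Minimal usage sketch for ListDataset; paths and transform are illustrative.
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets import ListDataset

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])

# Each metafile line: "<relative/image/path> <integer-label>", e.g.
# "n01440764/n01440764_10026.JPEG 0"
dataset = ListDataset(root_dir='data/train',
                      meta_file='data/train.txt',
                      transform=transform)
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=8)

img, label = dataset[0]  # (3, 224, 224) float tensor, int label
```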
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuexy/PS-ViT/a63a397e64d07a238a9fcaf392dce4c4596f6636/imgs/overview.png
--------------------------------------------------------------------------------
/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .progressive_sample import ProgressiveSample
2 | 
3 | __all__ = ['ProgressiveSample']
4 | 
--------------------------------------------------------------------------------
/layers/cinclude/cpp_helper.hpp:
--------------------------------------------------------------------------------
 1 | #ifndef CPP_HELPER
 2 | #define CPP_HELPER
 3 | 
 4 | #include <torch/extension.h>
 5 | 
 6 | #include <vector>
 7 | 
 8 | using namespace at;
 9 | 
10 | #define CHECK_CUDA(x) \
11 |     TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CPU(x) \
13 |     TORCH_CHECK(!x.device().is_cuda(), #x " must be a CPU tensor")
14 | #define CHECK_CONTIGUOUS(x) \
15 |     TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
16 | #define CHECK_CUDA_INPUT(x) \
17 |     CHECK_CUDA(x);          \
18 |     CHECK_CONTIGUOUS(x)
19 | #define CHECK_CPU_INPUT(x) \
20 |     CHECK_CPU(x);          \
21 |     CHECK_CONTIGUOUS(x)
22 | 
23 | #endif  // CPP_HELPER
24 | 
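
The `CHECK_*` macros fail fast when a tensor is on the wrong device or is non-contiguous, which matters because the kernels walk raw `data_ptr()` memory. A Python-side sketch of the same guard (a hypothetical helper, not part of the repo):

```python
# Hypothetical Python-side guard mirroring CHECK_CUDA_INPUT; not in the repo.
import torch

def check_cuda_input(name: str, tensor: torch.Tensor) -> None:
    # The kernels index raw data_ptr() memory, so inputs must be CUDA-resident
    # and contiguous -- exactly what CHECK_CUDA_INPUT enforces in C++.
    if not tensor.is_cuda:
        raise TypeError(f'{name} must be a CUDA tensor')
    if not tensor.is_contiguous():
        raise TypeError(f'{name} must be contiguous')
```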
--------------------------------------------------------------------------------
/layers/cinclude/cuda_helper.hpp:
--------------------------------------------------------------------------------
  1 | #ifndef CUDA_HELPER
  2 | #define CUDA_HELPER
  3 | 
  4 | #include <ATen/ATen.h>
  5 | 
  6 | #include <ATen/cuda/CUDAContext.h>
  7 | #include <c10/cuda/CUDAGuard.h>
  8 | #include <cuda.h>
  9 | 
 10 | #include <THC/THCAtomics.cuh>
 11 | 
 12 | using at::Half;
 13 | using at::Tensor;
 14 | using phalf = at::Half;
 15 | 
 16 | #define __PHALF(x) (x)
 17 | 
 18 | #define CUDA_1D_KERNEL_LOOP(i, n)                                \
 19 |     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
 20 |          i += blockDim.x * gridDim.x)
 21 | 
 22 | #define THREADS_PER_BLOCK 512
 23 | 
 24 | inline int GET_BLOCKS(const int N)
 25 | {
 26 |     int optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
 27 |     int max_block_num = 4096;
 28 |     return min(optimal_block_num, max_block_num);
 29 | }
 30 | 
 31 | template <typename T>
 32 | __device__ T bilinear_interpolate(const T *input,
 33 |                                   const int height,
 34 |                                   const int width,
 35 |                                   T h,
 36 |                                   T w)
 37 | {
 38 |     if (h <= -1 || h >= height || w <= -1 || w >= width)
 39 |     {
 40 |         return 0;
 41 |     }
 42 | 
 43 |     int h_low = floor(h);
 44 |     int w_low = floor(w);
 45 |     int h_high = h_low + 1;
 46 |     int w_high = w_low + 1;
 47 | 
 48 |     T lh = h - h_low;
 49 |     T lw = w - w_low;
 50 |     T hh = 1. - lh;
 51 |     T hw = 1. - lw;
 52 | 
 53 |     T v1 = 0;
 54 |     if (h_low >= 0 && w_low >= 0)
 55 |         v1 = input[h_low * width + w_low];
 56 |     T v2 = 0;
 57 |     if (h_low >= 0 && w_high <= width - 1)
 58 |         v2 = input[h_low * width + w_high];
 59 |     T v3 = 0;
 60 |     if (h_high <= height - 1 && w_low >= 0)
 61 |         v3 = input[h_high * width + w_low];
 62 |     T v4 = 0;
 63 |     if (h_high <= height - 1 && w_high <= width - 1)
 64 |         v4 = input[h_high * width + w_high];
 65 | 
 66 |     T w1 = hh * hw;
 67 |     T w2 = hh * lw;
 68 |     T w3 = lh * hw;
 69 |     T w4 = lh * lw;
 70 | 
 71 |     T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
 72 |     return val;
 73 | }
 74 | 
 75 | template <typename T>
 76 | __device__ void bilinear_interpolate_gradient(const int height,
 77 |                                               const int width,
 78 |                                               T y,
 79 |                                               T x,
 80 |                                               T& w1,
 81 |                                               T& w2,
 82 |                                               T& w3,
 83 |                                               T& w4,
 84 |                                               int& y_low,
 85 |                                               int& y_high,
 86 |                                               int& x_low,
 87 |                                               int& x_high)
 88 | {
 89 |     if (y <= -1. || y >= height || x <= -1. || x >= width)
 90 |     {
 91 |         w1 = w2 = w3 = w4 = 0.;
 92 |         x_low = x_high = y_low = y_high = -1;
 93 |         return;
 94 |     }
 95 | 
 96 |     if (y <= 0) y = 0;
 97 |     if (x <= 0) x = 0;
 98 | 
 99 |     y_low = (int) y;
100 |     x_low = (int) x;
101 | 
102 |     if (y_low >= height - 1)
103 |     {
104 |         y_high = y_low = height - 1;
105 |         y = (T) y_low;
106 |     }
107 |     else
108 |     {
109 |         y_high = y_low + 1;
110 |     }
111 | 
112 |     if (x_low >= width - 1)
113 |     {
114 |         x_high = x_low = width - 1;
115 |         x = (T) x_low;
116 |     }
117 |     else
118 |     {
119 |         x_high = x_low + 1;
120 |     }
121 | 
122 |     T ly = y - y_low;
123 |     T lx = x - x_low;
124 |     T hy = 1. - ly;
125 |     T hx = 1. - lx;
126 | 
127 |     w1 = hy * hx;
128 |     w2 = hy * lx;
129 |     w3 = ly * hx;
130 |     w4 = ly * lx;
131 | }
132 | 
133 | #endif  // CUDA_HELPER
134 | 
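
For unit-testing the device code, `bilinear_interpolate` can be mirrored on the CPU. The sketch below reproduces its border convention (zero outside (-1, H) x (-1, W), with out-of-range corners contributing zero); it is a test-only reference, not part of the repo:

```python
# CPU reference for the CUDA bilinear_interpolate above; test-only sketch.
import math

import torch

def bilinear_interpolate_ref(feat: torch.Tensor, h: float, w: float) -> float:
    """feat: (H, W) tensor; (h, w) is an unnormalized sample location."""
    H, W = feat.shape
    # Same early-out as the kernel: zero outside (-1, H) x (-1, W).
    if h <= -1 or h >= H or w <= -1 or w >= W:
        return 0.0

    h_low, w_low = math.floor(h), math.floor(w)
    h_high, w_high = h_low + 1, w_low + 1
    lh, lw = h - h_low, w - w_low

    def corner(y: int, x: int) -> float:
        # Out-of-range corners contribute zero, as in the kernel.
        return float(feat[y, x]) if 0 <= y < H and 0 <= x < W else 0.0

    return ((1 - lh) * (1 - lw) * corner(h_low, w_low) +
            (1 - lh) * lw * corner(h_low, w_high) +
            lh * (1 - lw) * corner(h_high, w_low) +
            lh * lw * corner(h_high, w_high))
```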
--------------------------------------------------------------------------------
/layers/cinclude/progressive_sampling_cuda_kernel.cuh:
--------------------------------------------------------------------------------
  1 | #ifndef PROGRESSIVE_SAMPLING_CUDA_KERNEL
  2 | #define PROGRESSIVE_SAMPLING_CUDA_KERNEL
  3 | 
  4 | #include "cuda_helper.hpp"
  5 | 
  6 | 
  7 | template <typename T>
  8 | __global__ void progressive_sampling_forward_cuda_kernel(const int nthreads,
  9 |                                                          const T* input,
 10 |                                                          const T* point,
 11 |                                                          const T* offset,
 12 |                                                          T* output,
 13 |                                                          const int channels,
 14 |                                                          const int point_num,
 15 |                                                          const int height,
 16 |                                                          const int width,
 17 |                                                          const T gamma)
 18 | {
 19 |     CUDA_1D_KERNEL_LOOP(index, nthreads)
 20 |     {
 21 |         int c = index % channels;
 22 |         int p = (index / channels) % point_num;
 23 |         int n = index / channels / point_num;
 24 | 
 25 |         const T* current_point = point + (n * point_num + p) * 2;
 26 |         const T* current_offset = offset + (n * point_num + p) * 2;
 27 |         const T* current_input = input + (n * channels + c) * height * width;
 28 | 
 29 |         const T y = current_point[0] + current_offset[0] * gamma;
 30 |         const T x = current_point[1] + current_offset[1] * gamma;
 31 | 
 32 |         output[index] = bilinear_interpolate(current_input, height, width, y, x);
 33 |     }
 34 | }
 35 | 
 36 | 
 37 | template <typename T>
 38 | __global__ void progressive_sampling_backward_cuda_kernel(const int nthreads,
 39 |                                                           const T* grad_output,
 40 |                                                           const T* input,
 41 |                                                           const T* point,
 42 |                                                           const T* offset,
 43 |                                                           T* grad_input,
 44 |                                                           T* grad_offset,
 45 |                                                           int channels,
 46 |                                                           int point_num,
 47 |                                                           int height,
 48 |                                                           int width,
 49 |                                                           const T gamma)
 50 | {
 51 |     CUDA_1D_KERNEL_LOOP(index, nthreads)
 52 |     {
 53 |         int c = index % channels;
 54 |         int p = (index / channels) % point_num;
 55 |         int n = index / channels / point_num;
 56 | 
 57 |         const T* current_point = point + (n * point_num + p) * 2;
 58 |         const T* current_offset = offset + (n * point_num + p) * 2;
 59 |         const T* current_input = input + (n * channels + c) * height * width;
 60 | 
 61 |         const T y = current_point[0] + current_offset[0] * gamma;
 62 |         const T x = current_point[1] + current_offset[1] * gamma;
 63 | 
 64 |         const T grad_current_output = grad_output[index];
 65 | 
 66 |         T* grad_current_input = grad_input + (n * channels + c) * height * width;
 67 |         T* grad_current_offset = grad_offset + (n * point_num + p) * 2;
 68 | 
 69 |         T w1, w2, w3, w4;
 70 |         int x_low, x_high, y_low, y_high;
 71 | 
 72 |         bilinear_interpolate_gradient(height,
 73 |                                       width,
 74 |                                       y,
 75 |                                       x,
 76 |                                       w1, w2, w3, w4,
 77 |                                       y_low, y_high,
 78 |                                       x_low, x_high);
 79 | 
 80 |         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0)
 81 |         {
 82 |             atomicAdd(grad_current_input + y_low * width + x_low,
 83 |                       grad_current_output * w1);
 84 |             atomicAdd(grad_current_input + y_low * width + x_high,
 85 |                       grad_current_output * w2);
 86 |             atomicAdd(grad_current_input + y_high * width + x_low,
 87 |                       grad_current_output * w3);
 88 |             atomicAdd(grad_current_input + y_high * width + x_high,
 89 |                       grad_current_output * w4);
 90 | 
 91 |             T input_00 = current_input[y_low * width + x_low];
 92 |             T input_10 = current_input[y_low * width + x_high];
 93 |             T input_01 = current_input[y_high * width + x_low];
 94 |             T input_11 = current_input[y_high * width + x_high];
 95 |             T ogx = gamma * grad_current_output *
 96 |                     (input_11 * (y - y_low) + input_10 * (y_high - y) +
 97 |                      input_01 * (y_low - y) + input_00 * (y - y_high));
 98 |             T ogy = gamma * grad_current_output *
 99 |                     (input_11 * (x - x_low) + input_01 * (x_high - x) +
100 |                      input_10 * (x_low - x) + input_00 * (x - x_high));
101 |             atomicAdd(grad_current_offset, ogy);
102 |             atomicAdd(grad_current_offset + 1, ogx);
103 |         }
104 |     }
105 | }
106 | 
107 | #endif  // PROGRESSIVE_SAMPLING_CUDA_KERNEL
108 | 
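
The forward kernel samples each feature channel at `point + gamma * offset` with the bilinear scheme above, which is the same computation as `torch.nn.functional.grid_sample` with `align_corners=True` and zero padding at unnormalized coordinates. A sketch of an equivalence check against the `ProgressiveSample` wrapper (shapes and `gamma` are arbitrary; requires the compiled `_ext` and a GPU):

```python
# Sketch: cross-check ProgressiveSample against F.grid_sample.
import torch
import torch.nn.functional as F

from layers import ProgressiveSample

n, c, h, w, p = 2, 8, 14, 14, 196
feat = torch.randn(n, c, h, w, device='cuda')
point = torch.rand(n, p, 2, device='cuda') * (h - 1)  # (y, x), unnormalized
offset = torch.randn(n, p, 2, device='cuda')
gamma = 0.1

out = ProgressiveSample(gamma)(feat, point, offset)   # (n, p, c)

# Same sampling via grid_sample: normalize (x, y) to [-1, 1], align_corners=True.
loc = point + gamma * offset
grid_x = loc[..., 1] / (w - 1) * 2 - 1
grid_y = loc[..., 0] / (h - 1) * 2 - 1
grid = torch.stack((grid_x, grid_y), dim=-1).view(n, p, 1, 2)
ref = F.grid_sample(feat, grid, mode='bilinear',
                    padding_mode='zeros', align_corners=True)
ref = ref.view(n, c, p).permute(0, 2, 1)              # (n, p, c)
assert torch.allclose(out, ref, atol=1e-5)
```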
--------------------------------------------------------------------------------
/layers/csrc/info.cpp:
--------------------------------------------------------------------------------
 1 | #include "cpp_helper.hpp"
 2 | 
 3 | #include <cuda_runtime_api.h>
 4 | 
 5 | int get_cudart_version()
 6 | {
 7 |     return CUDART_VERSION;
 8 | }
 9 | 
10 | std::string get_compiling_cuda_version()
11 | {
12 |     std::ostringstream oss;
13 |     // copied from
14 |     // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231
15 |     auto printCudaStyleVersion = [&](int v) {
16 |         oss << (v / 1000) << "." << (v / 10 % 100);
17 |         if (v % 10 != 0) {
18 |             oss << "." << (v % 10);
19 |         }
20 |     };
21 |     printCudaStyleVersion(get_cudart_version());
22 |     return oss.str();
23 | }
24 | 
25 | std::string get_compiler_version()
26 | {
27 |     std::ostringstream ss;
28 | #if defined(__GNUC__)
29 | #ifndef __clang__
30 |     { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; }
31 | #endif
32 | #endif
33 | 
34 | #if defined(__clang_major__)
35 |     {
36 |         ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
37 |            << __clang_patchlevel__;
38 |     }
39 | #endif
40 | 
41 | #if defined(_MSC_VER)
42 |     { ss << "MSVC " << _MSC_FULL_VER; }
43 | #endif
44 |     return ss.str();
45 | }
46 | 
--------------------------------------------------------------------------------
/layers/csrc/progressive_sampling.cpp:
--------------------------------------------------------------------------------
 1 | #include "cpp_helper.hpp"
 2 | 
 3 | 
 4 | void ProgressiveSamplingForwardCUDAKernelLauncher(Tensor input,
 5 |                                                   Tensor point,
 6 |                                                   Tensor offset,
 7 |                                                   Tensor output,
 8 |                                                   float gamma);
 9 | 
10 | void ProgressiveSamplingBackwardCUDAKernelLauncher(Tensor grad_output,
11 |                                                    Tensor input,
12 |                                                    Tensor point,
13 |                                                    Tensor offset,
14 |                                                    Tensor grad_input,
15 |                                                    Tensor grad_offset,
16 |                                                    float gamma);
17 | 
18 | 
19 | void progressive_sampling_forward(Tensor input,
20 |                                   Tensor point,
21 |                                   Tensor offset,
22 |                                   Tensor output,
23 |                                   float gamma)
24 | {
25 |     ProgressiveSamplingForwardCUDAKernelLauncher(input,
26 |                                                  point,
27 |                                                  offset,
28 |                                                  output,
29 |                                                  gamma);
30 | }
31 | 
32 | void progressive_sampling_backward(Tensor grad_output,
33 |                                    Tensor input,
34 |                                    Tensor point,
35 |                                    Tensor offset,
36 |                                    Tensor grad_input,
37 |                                    Tensor grad_offset,
38 |                                    float gamma)
39 | {
40 |     ProgressiveSamplingBackwardCUDAKernelLauncher(grad_output,
41 |                                                   input,
42 |                                                   point,
43 |                                                   offset,
44 |                                                   grad_input,
45 |                                                   grad_offset,
46 |                                                   gamma);
47 | }
48 | 
--------------------------------------------------------------------------------
/layers/csrc/progressive_sampling_cuda.cu:
--------------------------------------------------------------------------------
 1 | #include "progressive_sampling_cuda_kernel.cuh"
 2 | #include "cuda_helper.hpp"
 3 | 
 4 | 
 5 | void ProgressiveSamplingForwardCUDAKernelLauncher(Tensor input,
 6 |                                                   Tensor point,
 7 |                                                   Tensor offset,
 8 |                                                   Tensor output,
 9 |                                                   float gamma)
10 | {
11 |     int output_size = output.numel();
12 |     int channels = input.size(1);
13 |     int height = input.size(2);
14 |     int width = input.size(3);
15 |     int point_num = point.size(1);
16 | 
17 |     at::cuda::CUDAGuard device_guard(input.device());
18 |     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
19 | 
20 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
21 |         input.scalar_type(), "progressive_sampling_forward_cuda_kernel", [&] {
22 |             progressive_sampling_forward_cuda_kernel<scalar_t>
23 |                 <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
24 |                     output_size,
25 |                     input.data_ptr<scalar_t>(),
26 |                     point.data_ptr<scalar_t>(),
27 |                     offset.data_ptr<scalar_t>(),
28 |                     output.data_ptr<scalar_t>(),
29 |                     channels,
30 |                     point_num,
31 |                     height,
32 |                     width,
33 |                     gamma
34 |                 );
35 |         }
36 |     );
37 | 
38 |     AT_CUDA_CHECK(cudaGetLastError());
39 | }
40 | 
41 | void ProgressiveSamplingBackwardCUDAKernelLauncher(Tensor grad_output,
42 |                                                    Tensor input,
43 |                                                    Tensor point,
44 |                                                    Tensor offset,
45 |                                                    Tensor grad_input,
46 |                                                    Tensor grad_offset,
47 |                                                    float gamma)
48 | {
49 |     int output_size = grad_output.numel();
50 |     int channels = grad_input.size(1);
51 |     int height = grad_input.size(2);
52 |     int width = grad_input.size(3);
53 |     int point_num = grad_offset.size(1);
54 | 
55 |     at::cuda::CUDAGuard device_guard(grad_output.device());
56 |     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
57 | 
58 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(
59 |         grad_output.scalar_type(), "progressive_sampling_backward_cuda_kernel", [&] {
60 |             progressive_sampling_backward_cuda_kernel<scalar_t>
61 |                 <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
62 |                     output_size,
63 |                     grad_output.data_ptr<scalar_t>(),
64 |                     input.data_ptr<scalar_t>(),
65 |                     point.data_ptr<scalar_t>(),
66 |                     offset.data_ptr<scalar_t>(),
67 |                     grad_input.data_ptr<scalar_t>(),
68 |                     grad_offset.data_ptr<scalar_t>(),
69 |                     channels,
70 |                     point_num,
71 |                     height,
72 |                     width,
73 |                     gamma
74 |                 );
75 |         }
76 |     );
77 | 
78 | 
AT_CUDA_CHECK(cudaGetLastError()); 79 | } 80 | -------------------------------------------------------------------------------- /layers/csrc/pybind.cpp: -------------------------------------------------------------------------------- 1 | #include "cpp_helper.hpp" 2 | 3 | std::string get_compiler_version(); 4 | std::string get_compiling_cuda_version(); 5 | 6 | void progressive_sampling_forward(Tensor input, 7 | Tensor point, 8 | Tensor offset, 9 | Tensor output, 10 | float gamma); 11 | 12 | void progressive_sampling_backward(Tensor grad_output, 13 | Tensor input, 14 | Tensor point, 15 | Tensor offset, 16 | Tensor grad_input, 17 | Tensor grad_offset, 18 | float gamma); 19 | 20 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 21 | { 22 | m.def("progressive_sampling_forward", &progressive_sampling_forward, 23 | "progressive sampling forward", 24 | py::arg("input"), 25 | py::arg("point"), 26 | py::arg("offset"), 27 | py::arg("output"), 28 | py::arg("gamma")); 29 | m.def("progressive_sampling_backward", &progressive_sampling_backward, 30 | "progressive sampling backward", 31 | py::arg("grad_output"), 32 | py::arg("input"), 33 | py::arg("point"), 34 | py::arg("offset"), 35 | py::arg("grad_input"), 36 | py::arg("grad_offset"), 37 | py::arg("gamma")); 38 | } 39 | -------------------------------------------------------------------------------- /layers/progressive_sample.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from torch.autograd import Function 4 | from torch.autograd.function import once_differentiable 5 | 6 | from utils import ext_loader 7 | 8 | ext_module = ext_loader.load_ext('_ext', ['progressive_sampling_forward', 9 | 'progressive_sampling_backward']) 10 | 11 | 12 | class ProgressiveSamplingFunction(Function): 13 | @staticmethod 14 | def forward(ctx, 15 | input, 16 | point, 17 | offset=None, 18 | gamma=1.0): 19 | ctx.gamma = float(gamma) 20 | if offset is None: 21 | offset = torch.zeros_like(point) 22 | 23 | output_shape = (point.size(0), 24 | point.size(1), 25 | input.size(1)) 26 | 27 | output = input.new_zeros(output_shape) 28 | 29 | ext_module.progressive_sampling_forward(input, 30 | point, 31 | offset, 32 | output, 33 | ctx.gamma) 34 | 35 | ctx.save_for_backward(input, point, offset) 36 | return output 37 | 38 | @staticmethod 39 | @once_differentiable 40 | def backward(ctx, grad_output): 41 | input, point, offset = ctx.saved_tensors 42 | grad_input = grad_output.new_zeros(input.shape) 43 | grad_offset = grad_output.new_zeros(offset.shape) 44 | 45 | ext_module.progressive_sampling_backward(grad_output, 46 | input, 47 | point, 48 | offset, 49 | grad_input, 50 | grad_offset, 51 | ctx.gamma) 52 | return grad_input, None, grad_offset, None 53 | 54 | 55 | progressive_sampling = ProgressiveSamplingFunction.apply 56 | 57 | 58 | class ProgressiveSample(nn.Module): 59 | def __init__(self, 60 | gamma=1.0): 61 | super(ProgressiveSample, self).__init__() 62 | self.gamma = gamma 63 | 64 | def forward(self, input, point, offset): 65 | """ 66 | :param input: [n, c, h, w] 67 | :param point: [n, point_num, 2] (y, x) 68 | :param offset: [n, point_num, 2] (y, x) 69 | :output: [n, point_num, c] 70 | """ 71 | return progressive_sampling(input, point, offset, self.gamma) 72 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import yaml 4 | import os 5 | import logging 
6 | from collections import OrderedDict 7 | from contextlib import suppress 8 | from datetime import datetime 9 | import models 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.utils 14 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 15 | 16 | from timm.data import Dataset, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 17 | from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model 18 | from timm.utils import * 19 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy 20 | from timm.optim import create_optimizer 21 | from timm.scheduler import create_scheduler 22 | from timm.utils import ApexScaler, NativeScaler 23 | 24 | from datasets import ListDataset 25 | from utils import distributed_utils 26 | 27 | from utils.loader import create_loader 28 | 29 | torch.backends.cudnn.benchmark = True 30 | _logger = logging.getLogger('train') 31 | 32 | config_parser = parser = argparse.ArgumentParser( 33 | description='Training Config', add_help=False) 34 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 35 | help='YAML config file specifying default arguments') 36 | 37 | parser = argparse.ArgumentParser(description='Training and Evaluating') 38 | 39 | # Dataset / Model parameters 40 | parser.add_argument('--data', metavar='DIR', 41 | default='data/', help='path to dataset') 42 | parser.add_argument('--data_train_root', default=None) 43 | parser.add_argument('--data_train_label', default=None) 44 | parser.add_argument('--data_val_root', default=None) 45 | parser.add_argument('--data_val_label', default=None) 46 | 47 | parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', 48 | help='Name of model to train (default: "countception"') 49 | parser.add_argument('--pretrained', action='store_true', default=False, 50 | help='Start with pretrained version of specified network (if avail)') 51 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 52 | help='Initialize model from this checkpoint (default: none)') 53 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 54 | help='Resume full model and optimizer state from checkpoint (default: none)') 55 | parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH', 56 | help='path to eval checkpoint (default: none)') 57 | parser.add_argument('--no-resume-opt', action='store_true', default=False, 58 | help='prevent resume of optimizer state when resuming model') 59 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N', 60 | help='number of label classes (default: 1000)') 61 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 62 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.') 63 | parser.add_argument('--img-size', type=int, default=224, metavar='N', 64 | help='Image patch size (default: None => model default)') 65 | parser.add_argument('--crop-pct', default=None, type=float, 66 | metavar='N', help='Input image center crop percent (for validation only)') 67 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 68 | help='Override mean pixel value of dataset') 69 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 70 | help='Override std deviation of of dataset') 71 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 72 | help='Image resize interpolation type (overrides model)') 73 | parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', 74 | help='input batch size for training (default: 32)') 75 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', 76 | help='ratio of validation batch size to training batch size (default: 1)') 77 | 78 | # Optimizer parameters 79 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 80 | help='Optimizer (default: "sgd"') 81 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 82 | help='Optimizer Epsilon (default: None, use opt default)') 83 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 84 | help='Optimizer Betas (default: None, use opt default)') 85 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 86 | help='Optimizer momentum (default: 0.9)') 87 | parser.add_argument('--weight-decay', type=float, default=0.05, 88 | help='weight decay (default: 0.0001)') 89 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 90 | help='Clip gradient norm (default: None, no clipping)') 91 | 92 | # Learning rate schedule parameters 93 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 94 | help='LR scheduler (default: "step"') 95 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 96 | help='learning rate (default: 0.01)') 97 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 98 | help='learning rate noise on/off epoch percentages') 99 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 100 | help='learning rate noise limit percent (default: 0.67)') 101 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 102 | help='learning rate noise std-dev (default: 1.0)') 103 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 104 | help='learning rate cycle len multiplier (default: 1.0)') 105 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 106 | help='learning rate cycle limit') 107 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 108 | help='warmup learning rate (default: 0.0001)') 109 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 110 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 111 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 112 | help='number of epochs to train (default: 2)') 113 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N', 114 | help='manual epoch number (useful on restarts)') 115 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 116 | help='epoch interval to decay 
LR') 117 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', 118 | help='epochs to warmup LR, if scheduler supports') 119 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 120 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 121 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 122 | help='patience epochs for Plateau LR scheduler (default: 10') 123 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 124 | help='LR decay rate (default: 0.1)') 125 | 126 | # Augmentation & regularization parameters 127 | parser.add_argument('--no-aug', action='store_true', default=False, 128 | help='Disable all training augmentation, override other train aug args') 129 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 130 | help='Random resize scale (default: 0.08 1.0)') 131 | parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', 132 | help='Random resize aspect ratio (default: 0.75 1.33)') 133 | parser.add_argument('--hflip', type=float, default=0.5, 134 | help='Horizontal flip training aug probability') 135 | parser.add_argument('--vflip', type=float, default=0., 136 | help='Vertical flip training aug probability') 137 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 138 | help='Color jitter factor (default: 0.4)') 139 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 140 | help='Use AutoAugment policy. "v0" or "original". (default: None)'), 141 | parser.add_argument('--aug-splits', type=int, default=0, 142 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 143 | parser.add_argument('--jsd', action='store_true', default=False, 144 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 145 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT', 146 | help='Random erase prob (default: 0.)') 147 | parser.add_argument('--remode', type=str, default='pixel', 148 | help='Random erase mode (default: "const")') 149 | parser.add_argument('--recount', type=int, default=1, 150 | help='Random erase count (default: 1)') 151 | parser.add_argument('--resplit', action='store_true', default=False, 152 | help='Do not random erase first (clean) augmentation split') 153 | parser.add_argument('--mixup', type=float, default=0.2, 154 | help='mixup alpha, mixup enabled if > 0. (default: 0.)') 155 | parser.add_argument('--cutmix', type=float, default=1.0, 156 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') 157 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 158 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 159 | parser.add_argument('--mixup-prob', type=float, default=1.0, 160 | help='Probability of performing mixup or cutmix when either/both is enabled') 161 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 162 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 163 | parser.add_argument('--mixup-mode', type=str, default='batch', 164 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 165 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 166 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 167 | parser.add_argument('--smoothing', type=float, default=0.1, 168 | help='Label smoothing (default: 0.1)') 169 | parser.add_argument('--train-interpolation', type=str, default='random', 170 | help='Training interpolation (random, bilinear, bicubic default: "random")') 171 | parser.add_argument('--drop', type=float, default=0.1, metavar='PCT', 172 | help='Dropout rate (default: 0.1)') 173 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 174 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 175 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 176 | help='Drop path rate (default: None)') 177 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 178 | help='Drop block rate (default: None)') 179 | 180 | # Batch norm parameters (only works with gen_efficientnet based models currently) 181 | parser.add_argument('--bn-tf', action='store_true', default=False, 182 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') 183 | parser.add_argument('--bn-momentum', type=float, default=None, 184 | help='BatchNorm momentum override (if not None)') 185 | parser.add_argument('--bn-eps', type=float, default=None, 186 | help='BatchNorm epsilon override (if not None)') 187 | parser.add_argument('--sync-bn', action='store_true', 188 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 189 | parser.add_argument('--dist-bn', type=str, default='', 190 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 191 | parser.add_argument('--split-bn', action='store_true', 192 | help='Enable separate BN layers per augmentation split.') 193 | 194 | # Model Exponential Moving Average 195 | parser.add_argument('--model-ema', action='store_true', default=True, 196 | help='Enable tracking moving average of model weights') 197 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 198 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 199 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, 200 | help='decay factor for model weights moving average (default: 0.9998)') 201 | 202 | # Misc 203 | parser.add_argument('--seed', type=int, default=233, metavar='S', 204 | help='random seed (default: 233)') 205 | parser.add_argument('--log-interval', type=int, default=50, metavar='N', 206 | help='how many batches to wait before logging training status') 207 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', 208 | help='how many batches to wait before writing recovery checkpoint') 209 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N', 210 | help='how many training processes to use (default: 1)') 211 | parser.add_argument('--num-gpu', type=int, default=1, 212 | help='Number of GPUS to use') 213 | parser.add_argument('--save-images', action='store_true', default=False, 214 | help='save images of input bathes every log interval for debugging') 215 | parser.add_argument('--amp', action='store_true', default=False, 216 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 217 | parser.add_argument('--apex-amp', action='store_true', default=False, 218 | help='Use NVIDIA Apex AMP mixed precision') 219 | parser.add_argument('--native-amp', action='store_true', default=False, 220 | help='Use Native Torch AMP mixed precision') 221 | parser.add_argument('--channels-last', action='store_true', default=False, 222 | help='Use channels_last memory layout') 223 | parser.add_argument('--pin-mem', action='store_true', default=False, 224 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 225 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 226 | help='disable fast prefetcher') 227 | parser.add_argument('--output', default='', type=str, metavar='PATH', 228 | help='path to output folder (default: none, current dir)') 229 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 230 | help='Best metric (default: "top1"') 231 | parser.add_argument('--tta', type=int, default=0, metavar='N', 232 | help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)') 233 | parser.add_argument("--local_rank", default=0, type=int) 234 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 235 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 236 | # multi-node 237 | parser.add_argument("--distributed", default=True) 238 | parser.add_argument("--port", default='2333') 239 | 240 | # repeated aug 241 | parser.add_argument('--repeated_aug', default=False) 242 | 243 | try: 244 | from apex import amp 245 | from apex.parallel import DistributedDataParallel as ApexDDP 246 | from apex.parallel import convert_syncbn_model 247 | 248 | has_apex = True 249 | except ImportError: 250 | has_apex = False 251 | 252 | has_native_amp = False 253 | try: 254 | if getattr(torch.cuda.amp, 'autocast') is not None: 255 | has_native_amp = True 256 | except AttributeError: 257 | pass 258 | 259 | 260 | def _parse_args(): 261 | args_config, remaining = config_parser.parse_known_args() 262 | if args_config.config: 263 | with open(args_config.config, 'r') as f: 264 | cfg = yaml.safe_load(f) 265 | parser.set_defaults(**cfg) 266 | 267 | args = parser.parse_args(remaining) 268 | 269 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 270 | return args, args_text 271 | 272 | 273 | def main(): 274 | setup_default_logging() 275 | args, args_text = _parse_args() 276 | 277 | args.prefetcher = not args.no_prefetcher 278 | 279 | args.device = 'cuda:0' 280 | args.world_size = 1 281 | args.rank = 0 # global rank 282 | if args.distributed: 283 | if 'SLURM_NTASKS' in os.environ: 284 | args.rank, args.world_size, args.local_rank = distributed_utils.dist_init( 285 | port=args.port) 286 | args.num_gpu = 1 287 | args.device = 'cuda:%d' % args.local_rank 288 | args.world_size = torch.distributed.get_world_size() 289 | args.rank = torch.distributed.get_rank() 290 | _logger.info('[SLURM] Training in distributed mode. Process %d, local %d, total %d.' % ( 291 | args.rank, args.local_rank, args.world_size)) 292 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 293 | args.num_gpu = 1 294 | args.device = 'cuda:%d' % args.local_rank 295 | torch.cuda.set_device(args.local_rank) 296 | torch.distributed.init_process_group(backend='nccl') 297 | args.world_size = torch.distributed.get_world_size() 298 | args.rank = torch.distributed.get_rank() 299 | _logger.info('[TORCH] Training in distributed mode. Process %d, local %d, total %d.' % ( 300 | args.rank, args.local_rank, args.world_size)) 301 | else: 302 | _logger.error('Unsupported!') 303 | return 304 | assert args.rank >= 0 305 | 306 | if args.distributed: 307 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 308 | % (args.rank, args.world_size)) 309 | else: 310 | _logger.info('Training with a single process on %d GPUs.' 
% 311 | args.num_gpu) 312 | 313 | torch.manual_seed(args.seed) 314 | 315 | model = create_model( 316 | args.model, 317 | pretrained=args.pretrained, 318 | num_classes=args.num_classes, 319 | drop_rate=args.drop, 320 | drop_connect_rate=args.drop_connect, 321 | drop_path_rate=args.drop_path, 322 | drop_block_rate=args.drop_block, 323 | global_pool=args.gp, 324 | bn_tf=args.bn_tf, 325 | bn_momentum=args.bn_momentum, 326 | bn_eps=args.bn_eps) 327 | 328 | if args.initial_checkpoint: 329 | load_checkpoint(model, args.initial_checkpoint, args.model_ema) 330 | 331 | if args.rank == 0: 332 | _logger.info(model) 333 | _logger.info('Model %s created, param count: %d' % 334 | (args.model, sum([m.numel() for m in model.parameters()]))) 335 | 336 | data_config = resolve_data_config( 337 | vars(args), model=model, verbose=args.rank == 0) 338 | 339 | num_aug_splits = 0 340 | if args.aug_splits > 0: 341 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 342 | num_aug_splits = args.aug_splits 343 | 344 | if args.split_bn: 345 | assert num_aug_splits > 1 or args.resplit 346 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 347 | 348 | use_amp = None 349 | if args.amp: 350 | if has_apex: 351 | args.apex_amp = True 352 | elif has_native_amp: 353 | args.native_amp = True 354 | if args.apex_amp and has_apex: 355 | use_amp = 'apex' 356 | elif args.native_amp and has_native_amp: 357 | use_amp = 'native' 358 | elif args.apex_amp or args.native_amp: 359 | _logger.warning("Neither APEX or native Torch AMP is available, using float32. " 360 | "Install NVIDA apex or upgrade to PyTorch 1.6") 361 | 362 | if args.num_gpu > 1: 363 | if use_amp == 'apex': 364 | _logger.warning( 365 | 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') 366 | use_amp = None 367 | model = nn.DataParallel( 368 | model, device_ids=list(range(args.num_gpu))).cuda() 369 | assert not args.channels_last, "Channels last not supported with DP, use DDP." 370 | else: 371 | model.cuda() 372 | if args.channels_last: 373 | model = model.to(memory_format=torch.channels_last) 374 | 375 | optimizer = create_optimizer(args, model) 376 | 377 | amp_autocast = suppress 378 | loss_scaler = None 379 | if use_amp == 'apex': 380 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 381 | loss_scaler = ApexScaler() 382 | if args.rank == 0: 383 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 384 | elif use_amp == 'native': 385 | amp_autocast = torch.cuda.amp.autocast 386 | loss_scaler = NativeScaler() 387 | if args.rank == 0: 388 | _logger.info( 389 | 'Using native Torch AMP. Training in mixed precision.') 390 | else: 391 | if args.rank == 0: 392 | _logger.info('AMP not enabled. 
Training in float32.') 393 | 394 | resume_epoch = None 395 | if args.resume: 396 | resume_epoch = resume_checkpoint( 397 | model, args.resume, 398 | optimizer=None if args.no_resume_opt else optimizer, 399 | loss_scaler=None if args.no_resume_opt else loss_scaler, 400 | log_info=args.rank == 0) 401 | 402 | model_ema = None 403 | if args.model_ema: 404 | model_ema = ModelEma( 405 | model, 406 | decay=args.model_ema_decay, 407 | device='cpu' if args.model_ema_force_cpu else '', 408 | resume=args.resume) 409 | 410 | if args.distributed: 411 | if args.sync_bn: 412 | assert not args.split_bn 413 | try: 414 | if has_apex and use_amp != 'native': 415 | model = convert_syncbn_model(model) 416 | else: 417 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm( 418 | model) 419 | if args.rank == 0: 420 | _logger.info( 421 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 422 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 423 | except Exception as e: 424 | _logger.error( 425 | 'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') 426 | if has_apex and use_amp != 'native': 427 | if args.rank == 0: 428 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 429 | model = ApexDDP(model, delay_allreduce=True) 430 | else: 431 | if args.rank == 0: 432 | _logger.info("Using native Torch DistributedDataParallel.") 433 | model = NativeDDP(model, device_ids=[args.local_rank]) 434 | 435 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 436 | start_epoch = 0 437 | if args.start_epoch is not None: 438 | start_epoch = args.start_epoch 439 | elif resume_epoch is not None: 440 | start_epoch = resume_epoch 441 | if lr_scheduler is not None and start_epoch > 0: 442 | lr_scheduler.step(start_epoch) 443 | 444 | if args.rank == 0: 445 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 446 | 447 | if args.data is not None: 448 | train_dir = os.path.join(args.data, 'train') 449 | if not os.path.exists(train_dir): 450 | _logger.error( 451 | 'Training folder does not exist at: {}'.format(train_dir)) 452 | exit(1) 453 | dataset_train = Dataset(train_dir) 454 | else: 455 | dataset_train = ListDataset(args.data_train_root, 456 | args.data_train_label) 457 | _logger.info('Loaded %d imgs from %s with ListDataset for train' % 458 | (len(dataset_train), args.data_train_label)) 459 | 460 | collate_fn = None 461 | mixup_fn = None 462 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 463 | if mixup_active: 464 | mixup_args = dict( 465 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 466 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 467 | label_smoothing=args.smoothing, num_classes=args.num_classes) 468 | if args.prefetcher: 469 | assert not num_aug_splits 470 | collate_fn = FastCollateMixup(**mixup_args) 471 | else: 472 | mixup_fn = Mixup(**mixup_args) 473 | 474 | if num_aug_splits > 1: 475 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 476 | 477 | train_interpolation = args.train_interpolation 478 | if args.no_aug or not train_interpolation: 479 | train_interpolation = data_config['interpolation'] 480 | loader_train = create_loader( 481 | dataset_train, 482 | input_size=data_config['input_size'], 483 | batch_size=args.batch_size, 484 | is_training=True, 485 | use_prefetcher=args.prefetcher, 486 | no_aug=args.no_aug, 487 | re_prob=args.reprob, 488 | re_mode=args.remode, 489 | re_count=args.recount, 490 | re_split=args.resplit, 491 | scale=args.scale, 492 | ratio=args.ratio, 493 | hflip=args.hflip, 494 | vflip=args.vflip, 495 | color_jitter=args.color_jitter, 496 | auto_augment=args.aa, 497 | num_aug_splits=num_aug_splits, 498 | interpolation=train_interpolation, 499 | mean=data_config['mean'], 500 | std=data_config['std'], 501 | num_workers=args.workers, 502 | distributed=args.distributed, 503 | collate_fn=collate_fn, 504 | pin_memory=args.pin_mem, 505 | use_multi_epochs_loader=args.use_multi_epochs_loader, 506 | repeated_aug=args.repeated_aug 507 | ) 508 | 509 | if args.data is not None: 510 | eval_dir = os.path.join(args.data, 'val') 511 | if not os.path.isdir(eval_dir): 512 | eval_dir = os.path.join(args.data, 'validation') 513 | if not os.path.isdir(eval_dir): 514 | _logger.error( 515 | 'Validation folder does not exist at: {}'.format(eval_dir)) 516 | exit(1) 517 | dataset_eval = Dataset(eval_dir) 518 | else: 519 | dataset_eval = ListDataset(args.data_val_root, 520 | args.data_val_label) 521 | _logger.info('Loaded %d imgs from %s with ListDataset for eval' % 522 | (len(dataset_eval), args.data_val_label)) 523 | 524 | loader_eval = create_loader( 525 | dataset_eval, 526 | input_size=data_config['input_size'], 527 | batch_size=args.validation_batch_size_multiplier * args.batch_size, 528 | is_training=False, 529 | use_prefetcher=args.prefetcher, 530 | interpolation=data_config['interpolation'], 531 | mean=data_config['mean'], 532 | std=data_config['std'], 533 | num_workers=args.workers, 534 | distributed=args.distributed, 535 | crop_pct=data_config['crop_pct'], 536 | pin_memory=args.pin_mem, 537 | ) 538 | 539 | if args.jsd: 540 | assert num_aug_splits > 1 541 | train_loss_fn = JsdCrossEntropy( 542 | num_splits=num_aug_splits, smoothing=args.smoothing).cuda() 543 | elif mixup_active: 544 | train_loss_fn = SoftTargetCrossEntropy().cuda() 545 | elif args.smoothing: 546 | train_loss_fn = LabelSmoothingCrossEntropy( 547 | smoothing=args.smoothing).cuda() 548 | else: 549 | train_loss_fn = nn.CrossEntropyLoss().cuda() 550 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 551 | 552 | eval_metric = args.eval_metric 553 | best_metric = None 554 | best_epoch = None 555 | 556 | if args.eval_checkpoint: 557 | load_checkpoint(model.module, args.eval_checkpoint, args.model_ema) 558 | val_metrics = validate(model, loader_eval, validate_loss_fn, args) 559 | print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%") 560 | return 561 | 562 | 
saver = None 563 | output_dir = '' 564 | if args.rank == 0: 565 | output_base = args.output if args.output else './output' 566 | exp_name = '-'.join([ 567 | datetime.now().strftime("%Y%m%d-%H%M%S"), 568 | args.model, 569 | str(data_config['input_size'][-1]) 570 | ]) 571 | output_dir = get_outdir(output_base, 'train', exp_name) 572 | decreasing = True if eval_metric == 'loss' else False 573 | saver = CheckpointSaver( 574 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, 575 | checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) 576 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 577 | f.write(args_text) 578 | 579 | try: 580 | for epoch in range(start_epoch, num_epochs): 581 | if args.distributed: 582 | loader_train.sampler.set_epoch(epoch) 583 | 584 | train_metrics = train_epoch( 585 | epoch, model, loader_train, optimizer, train_loss_fn, args, 586 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, 587 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) 588 | 589 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 590 | if args.rank == 0: 591 | _logger.info( 592 | "Distributing BatchNorm running means and vars") 593 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 594 | 595 | eval_metrics = validate( 596 | model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) 597 | 598 | if model_ema is not None and not args.model_ema_force_cpu: 599 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 600 | distribute_bn(model_ema, args.world_size, 601 | args.dist_bn == 'reduce') 602 | ema_eval_metrics = validate( 603 | model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') 604 | eval_metrics = ema_eval_metrics 605 | 606 | if lr_scheduler is not None: 607 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) 608 | 609 | update_summary( 610 | epoch, train_metrics, eval_metrics, os.path.join( 611 | output_dir, 'summary.csv'), 612 | write_header=best_metric is None) 613 | 614 | if saver is not None: 615 | save_metric = eval_metrics[eval_metric] 616 | best_metric, best_epoch = saver.save_checkpoint( 617 | epoch, metric=save_metric) 618 | 619 | except KeyboardInterrupt: 620 | pass 621 | if best_metric is not None: 622 | _logger.info( 623 | '*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) 624 | 625 | 626 | def train_epoch( 627 | epoch, model, loader, optimizer, loss_fn, args, 628 | lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, 629 | loss_scaler=None, model_ema=None, mixup_fn=None): 630 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 631 | if args.prefetcher and loader.mixup_enabled: 632 | loader.mixup_enabled = False 633 | elif mixup_fn is not None: 634 | mixup_fn.mixup_enabled = False 635 | 636 | second_order = hasattr( 637 | optimizer, 'is_second_order') and optimizer.is_second_order 638 | batch_time_m = AverageMeter() 639 | data_time_m = AverageMeter() 640 | losses_m = AverageMeter() 641 | top1_m = AverageMeter() 642 | top5_m = AverageMeter() 643 | 644 | model.train() 645 | 646 | end = time.time() 647 | last_idx = len(loader) - 1 648 | num_updates = epoch * len(loader) 649 | for batch_idx, (input, target) in enumerate(loader): 650 | last_batch = batch_idx == last_idx 651 | data_time_m.update(time.time() - end) 652 | if not args.prefetcher: 653 | input, target = input.cuda(), target.cuda() 654 | if mixup_fn is not 
None: 655 | input, target = mixup_fn(input, target) 656 | if args.channels_last: 657 | input = input.contiguous(memory_format=torch.channels_last) 658 | 659 | with amp_autocast(): 660 | output = model(input) 661 | loss = loss_fn(output, target) 662 | 663 | if not args.distributed: 664 | losses_m.update(loss.item(), input.size(0)) 665 | 666 | optimizer.zero_grad() 667 | if loss_scaler is not None: 668 | loss_scaler( 669 | loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) 670 | else: 671 | loss.backward(create_graph=second_order) 672 | if args.clip_grad is not None: 673 | torch.nn.utils.clip_grad_norm_( 674 | model.parameters(), args.clip_grad) 675 | optimizer.step() 676 | 677 | torch.cuda.synchronize() 678 | if model_ema is not None: 679 | model_ema.update(model) 680 | num_updates += 1 681 | 682 | batch_time_m.update(time.time() - end) 683 | if last_batch or batch_idx % args.log_interval == 0: 684 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 685 | lr = sum(lrl) / len(lrl) 686 | 687 | if args.distributed: 688 | reduced_loss = reduce_tensor(loss.data, args.world_size) 689 | losses_m.update(reduced_loss.item(), input.size(0)) 690 | 691 | if args.rank == 0: 692 | _logger.info( 693 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 694 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' 695 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 696 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 697 | 'LR: {lr:.3e} ' 698 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 699 | epoch, 700 | batch_idx, len(loader), 701 | 100. * batch_idx / last_idx, 702 | loss=losses_m, 703 | batch_time=batch_time_m, 704 | rate=input.size(0) * args.world_size / 705 | batch_time_m.val, 706 | rate_avg=input.size( 707 | 0) * args.world_size / batch_time_m.avg, 708 | lr=lr, 709 | data_time=data_time_m)) 710 | 711 | if args.save_images and output_dir: 712 | torchvision.utils.save_image( 713 | input, 714 | os.path.join( 715 | output_dir, 'train-batch-%d.jpg' % batch_idx), 716 | padding=0, 717 | normalize=True) 718 | 719 | if saver is not None and args.recovery_interval and ( 720 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 721 | saver.save_recovery(epoch, batch_idx=batch_idx) 722 | 723 | if lr_scheduler is not None: 724 | lr_scheduler.step_update( 725 | num_updates=num_updates, metric=losses_m.avg) 726 | 727 | end = time.time() 728 | 729 | if hasattr(optimizer, 'sync_lookahead'): 730 | optimizer.sync_lookahead() 731 | 732 | return OrderedDict([('loss', losses_m.avg)]) 733 | 734 | 735 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): 736 | batch_time_m = AverageMeter() 737 | losses_m = AverageMeter() 738 | top1_m = AverageMeter() 739 | top5_m = AverageMeter() 740 | 741 | model.eval() 742 | 743 | end = time.time() 744 | last_idx = len(loader) - 1 745 | with torch.no_grad(): 746 | for batch_idx, (input, target) in enumerate(loader): 747 | last_batch = batch_idx == last_idx 748 | if not args.prefetcher: 749 | input = input.cuda() 750 | target = target.cuda() 751 | if args.channels_last: 752 | input = input.contiguous(memory_format=torch.channels_last) 753 | 754 | with amp_autocast(): 755 | output = model(input) 756 | if isinstance(output, (tuple, list)): 757 | output = output[0] 758 | 759 | reduce_factor = args.tta 760 | if reduce_factor > 1: 761 | output = output.unfold( 762 | 0, reduce_factor, reduce_factor).mean(dim=2) 763 | target = target[0:target.size(0):reduce_factor] 764 | 765 | loss = loss_fn(output, 
target) 766 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 767 | 768 | if args.distributed: 769 | reduced_loss = reduce_tensor(loss.data, args.world_size) 770 | acc1 = reduce_tensor(acc1, args.world_size) 771 | acc5 = reduce_tensor(acc5, args.world_size) 772 | else: 773 | reduced_loss = loss.data 774 | 775 | torch.cuda.synchronize() 776 | 777 | losses_m.update(reduced_loss.item(), input.size(0)) 778 | top1_m.update(acc1.item(), output.size(0)) 779 | top5_m.update(acc5.item(), output.size(0)) 780 | 781 | batch_time_m.update(time.time() - end) 782 | end = time.time() 783 | if args.rank == 0 and (last_batch or batch_idx % args.log_interval == 0): 784 | log_name = 'Test' + log_suffix 785 | _logger.info( 786 | '{0}: [{1:>4d}/{2}] ' 787 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 788 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 789 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 790 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 791 | log_name, batch_idx, last_idx, batch_time=batch_time_m, 792 | loss=losses_m, top1=top1_m, top5=top5_m)) 793 | 794 | metrics = OrderedDict( 795 | [('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 796 | if args.rank == 0: 797 | _logger.info('[VAL] Acc@1: %.7f Acc@5: %.7f' % 798 | (top1_m.avg, top5_m.avg)) 799 | 800 | return metrics 801 | 802 | 803 | if __name__ == '__main__': 804 | main() 805 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ps_vit import * 2 | -------------------------------------------------------------------------------- /models/ps_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from timm.models.helpers import load_pretrained 5 | from timm.models.registry import register_model 6 | from timm.models.layers import trunc_normal_ 7 | 8 | from layers import ProgressiveSample 9 | from .transformer_block import TransformerEncoderLayer 10 | 11 | 12 | def conv3x3(in_planes, 13 | out_planes, 14 | stride=1, 15 | groups=1, 16 | dilation=1): 17 | return nn.Conv2d(in_planes, 18 | out_planes, 19 | kernel_size=3, 20 | stride=stride, 21 | padding=dilation, 22 | groups=groups, 23 | bias=False, 24 | dilation=dilation) 25 | 26 | 27 | def conv1x1(in_planes, 28 | out_planes, 29 | stride=1): 30 | return nn.Conv2d(in_planes, 31 | out_planes, 32 | kernel_size=1, 33 | stride=stride, 34 | bias=False) 35 | 36 | 37 | class BottleneckLayer(nn.Module): 38 | def __init__(self, 39 | in_channels, 40 | inter_channels, 41 | out_channels): 42 | super().__init__() 43 | self.conv1 = conv1x1(in_channels, 44 | inter_channels) 45 | self.bn1 = nn.BatchNorm2d(inter_channels) 46 | 47 | self.conv2 = conv3x3(inter_channels, 48 | inter_channels) 49 | self.bn2 = nn.BatchNorm2d(inter_channels) 50 | 51 | self.conv3 = conv1x1(inter_channels, 52 | out_channels) 53 | self.bn3 = nn.BatchNorm2d(out_channels) 54 | 55 | self.relu = nn.ReLU(inplace=True) 56 | 57 | self.downsample = None 58 | if in_channels != out_channels: 59 | self.downsample = nn.Sequential(conv1x1(in_channels, out_channels), 60 | nn.BatchNorm2d(out_channels)) 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | 76 | if self.downsample is 
not None: 77 | identity = self.downsample(x) 78 | 79 | out += identity 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class PSViTLayer(nn.Module): 86 | def __init__(self, 87 | feat_size, 88 | dim, 89 | num_heads, 90 | mlp_ratio=4., 91 | qkv_bias=False, 92 | qk_scale=None, 93 | drop=0., 94 | attn_drop=0., 95 | drop_path=0., 96 | act_layer=nn.GELU, 97 | norm_layer=nn.LayerNorm, 98 | position_layer=None, 99 | pred_offset=True, 100 | gamma=0.1, 101 | offset_bias=False): 102 | super().__init__() 103 | 104 | self.feat_size = float(feat_size) 105 | 106 | self.transformer_layer = TransformerEncoderLayer(dim, 107 | num_heads, 108 | mlp_ratio, 109 | qkv_bias, 110 | qk_scale, 111 | drop, 112 | attn_drop, 113 | drop_path, 114 | act_layer, 115 | norm_layer) 116 | self.sampler = ProgressiveSample(gamma) 117 | 118 | self.position_layer = position_layer 119 | if self.position_layer is None: 120 | self.position_layer = nn.Linear(2, dim) 121 | 122 | self.offset_layer = None 123 | if pred_offset: 124 | self.offset_layer = nn.Linear(dim, 2, bias=offset_bias) 125 | 126 | def reset_offset_weight(self): 127 | if self.offset_layer is None: 128 | return 129 | nn.init.constant_(self.offset_layer.weight, 0) 130 | if self.offset_layer.bias is not None: 131 | nn.init.constant_(self.offset_layer.bias, 0) 132 | 133 | def forward(self, 134 | x, 135 | point, 136 | offset=None, 137 | pre_out=None): 138 | """ 139 | :param x: [n, dim, h, w] 140 | :param point: [n, point_num, 2] 141 | :param offset: [n, point_num, 2] 142 | :param pre_out: [n, point_num, dim] 143 | """ 144 | if offset is None: 145 | offset = torch.zeros_like(point) 146 | 147 | sample_feat = self.sampler(x, point, offset) 148 | sample_point = point + offset.detach() 149 | 150 | pos_feat = self.position_layer(sample_point / self.feat_size) 151 | 152 | attn_feat = sample_feat + pos_feat 153 | if pre_out is not None: 154 | attn_feat = attn_feat + pre_out 155 | 156 | attn_feat = self.transformer_layer(attn_feat) 157 | 158 | out_offset = None 159 | if self.offset_layer is not None: 160 | out_offset = self.offset_layer(attn_feat) 161 | 162 | return attn_feat, out_offset, sample_point 163 | 164 | 165 | class PSViT(nn.Module): 166 | def __init__(self, 167 | img_size=224, 168 | num_point_w=14, 169 | num_point_h=14, 170 | in_chans=3, 171 | downsample_ratio=4, 172 | num_classes=1000, 173 | num_iters=4, 174 | depth=14, 175 | embed_dim=384, 176 | num_heads=12, 177 | mlp_ratio=4., 178 | qkv_bias=False, 179 | qk_scale=None, 180 | drop_rate=0., 181 | attn_drop_rate=0., 182 | drop_path_rate=0., 183 | norm_layer=nn.LayerNorm, 184 | stem_layer=None, 185 | offset_gamma=0.1, 186 | offset_bias=False, 187 | with_cls_token=False): 188 | super().__init__() 189 | self.num_classes = num_classes 190 | self.embed_dim = embed_dim 191 | assert num_iters >= 1 192 | 193 | self.img_size = img_size 194 | self.feat_size = img_size // downsample_ratio 195 | 196 | self.num_point_w = num_point_w 197 | self.num_point_h = num_point_h 198 | 199 | self.register_buffer('point_coord', self._get_initial_point()) 200 | 201 | self.pos_layer = nn.Linear(2, self.embed_dim) 202 | 203 | self.stem = stem_layer 204 | if self.stem is None: 205 | self.stem = nn.Sequential(nn.Conv2d(in_chans, 206 | 64, 207 | kernel_size=7, 208 | padding=3, 209 | stride=2, 210 | bias=False), 211 | nn.BatchNorm2d(64), 212 | nn.ReLU(inplace=True), 213 | nn.MaxPool2d(kernel_size=3, 214 | stride=2, 215 | padding=1), 216 | BottleneckLayer(64, 64, self.embed_dim), 217 | BottleneckLayer(self.embed_dim, 64, self.embed_dim)) 218 
| 219 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 220 | self.ps_layers = nn.ModuleList() 221 | for i in range(num_iters): 222 | self.ps_layers.append(PSViTLayer(feat_size=self.feat_size, 223 | dim=self.embed_dim, 224 | num_heads=num_heads, 225 | mlp_ratio=mlp_ratio, 226 | qkv_bias=qkv_bias, 227 | qk_scale=qk_scale, 228 | drop=drop_rate, 229 | attn_drop=attn_drop_rate, 230 | drop_path=dpr[i], 231 | norm_layer=norm_layer, 232 | position_layer=self.pos_layer, 233 | pred_offset=i < num_iters - 1, 234 | gamma=offset_gamma, 235 | offset_bias=offset_bias)) 236 | 237 | self.trans_layers = nn.ModuleList() 238 | trans_depth = depth - num_iters 239 | for i in range(trans_depth): 240 | self.trans_layers.append(TransformerEncoderLayer(dim=self.embed_dim, 241 | num_heads=num_heads, 242 | mlp_ratio=mlp_ratio, 243 | qkv_bias=qkv_bias, 244 | qk_scale=qk_scale, 245 | drop=drop_rate, 246 | attn_drop=attn_drop_rate, 247 | drop_path=dpr[i + 248 | num_iters], 249 | norm_layer=norm_layer)) 250 | 251 | self.norm = norm_layer(embed_dim) 252 | 253 | self.head = nn.Linear(embed_dim, num_classes) 254 | 255 | self.cls_token = None 256 | if with_cls_token: 257 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 258 | trunc_normal_(self.cls_token, std=.02) 259 | else: 260 | self.avgpool = nn.AdaptiveAvgPool1d(1) 261 | 262 | self.apply(self._init_weights) 263 | for layer in self.ps_layers: 264 | layer.reset_offset_weight() 265 | 266 | def _init_weights(self, m): 267 | if isinstance(m, nn.Linear): 268 | trunc_normal_(m.weight, std=.02) 269 | if isinstance(m, nn.Linear) and m.bias is not None: 270 | nn.init.constant_(m.bias, 0) 271 | elif isinstance(m, nn.LayerNorm): 272 | nn.init.constant_(m.bias, 0) 273 | nn.init.constant_(m.weight, 1.0) 274 | 275 | def _get_initial_point(self): 276 | patch_size_w = self.feat_size / self.num_point_w 277 | patch_size_h = self.feat_size / self.num_point_h 278 | coord_w = torch.Tensor( 279 | [i * patch_size_w for i in range(self.num_point_w)]) 280 | coord_w += patch_size_w / 2 281 | coord_h = torch.Tensor( 282 | [i * patch_size_h for i in range(self.num_point_h)]) 283 | coord_h += patch_size_h / 2 284 | 285 | grid_x, grid_y = torch.meshgrid(coord_w, coord_h) 286 | grid_x = grid_x.unsqueeze(0) 287 | grid_y = grid_y.unsqueeze(0) 288 | point_coord = torch.cat([grid_y, grid_x], dim=0) 289 | point_coord = point_coord.view(2, -1) 290 | point_coord = point_coord.permute(1, 0).contiguous().unsqueeze(0) 291 | 292 | return point_coord 293 | 294 | @torch.jit.ignore 295 | def no_weight_decay(self): 296 | if self.cls_token is not None: 297 | return {'cls_token'} 298 | else: 299 | return {} 300 | 301 | def get_classifier(self): 302 | return self.head 303 | 304 | def reset_classifier(self, num_classes, global_pool=''): 305 | self.num_classes = num_classes 306 | self.head = nn.Linear(self.embed_dim, num_classes) 307 | 308 | def forward_feature(self, x): 309 | batch_size = x.size(0) 310 | point = self.point_coord.repeat(batch_size, 1, 1) 311 | 312 | x = self.stem(x) 313 | 314 | ps_out = None 315 | offset = None 316 | 317 | for layer in self.ps_layers: 318 | ps_out, offset, point = layer(x, 319 | point, 320 | offset, 321 | ps_out) 322 | 323 | if self.cls_token is not None: 324 | cls_token = self.cls_token.expand(batch_size, -1, -1) 325 | trans_out = torch.cat((cls_token, ps_out), dim=1) 326 | else: 327 | trans_out = ps_out 328 | 329 | for layer in self.trans_layers: 330 | trans_out = layer(trans_out) 331 | 332 | trans_out = self.norm(trans_out) 333 | 334 | if self.cls_token 
is not None: 335 | out_feat = trans_out[:, 0] 336 | else: 337 | trans_out = trans_out.permute(0, 2, 1) 338 | out_feat = self.avgpool(trans_out).view(batch_size, self.embed_dim) 339 | 340 | return out_feat 341 | 342 | def forward(self, x): 343 | assert x.shape[-1] == self.img_size and x.shape[-2] == self.img_size 344 | x = self.forward_feature(x) 345 | 346 | out = self.head(x) 347 | 348 | return out 349 | 350 | 351 | def _default_cfg(url='', **kwargs): 352 | return { 353 | 'url': url, 354 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 355 | 'crop_pct': .9, 'interpolation': 'bicubic', 356 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 357 | 'classifier': 'head', 358 | **kwargs 359 | } 360 | 361 | 362 | @register_model 363 | def ps_vit_b_14(pretrained=False, **kwargs): 364 | if pretrained: 365 | kwargs.setdefault('qk_scale', 384 ** -0.5) 366 | 367 | stem = nn.Sequential(nn.Conv2d(kwargs.get('in_chans', 3), 368 | 64, 369 | kernel_size=7, 370 | padding=3, 371 | stride=2, 372 | bias=False), 373 | nn.BatchNorm2d(64), 374 | nn.ReLU(inplace=True), 375 | nn.MaxPool2d(kernel_size=3, 376 | stride=2, 377 | padding=1), 378 | BottleneckLayer(64, 64, 256), 379 | BottleneckLayer(256, 64, 384)) 380 | 381 | model = PSViT(embed_dim=384, 382 | num_iters=4, 383 | depth=14, 384 | num_heads=6, 385 | mlp_ratio=3., 386 | stem_layer=stem, 387 | downsample_ratio=4, 388 | offset_gamma=1.0, 389 | offset_bias=True, 390 | with_cls_token=True, 391 | **kwargs) 392 | model.default_cfg = _default_cfg() 393 | if pretrained: 394 | load_pretrained(model, num_classes=model.num_classes, 395 | in_chans=kwargs.get('in_chans', 3)) 396 | return model 397 | 398 | 399 | @register_model 400 | def ps_vit_b_16(pretrained=False, **kwargs): 401 | if pretrained: 402 | kwargs.setdefault('qk_scale', 384 ** -0.5) 403 | 404 | stem = nn.Sequential(nn.Conv2d(kwargs.get('in_chans', 3), 405 | 64, 406 | kernel_size=7, 407 | padding=3, 408 | stride=2, 409 | bias=False), 410 | nn.BatchNorm2d(64), 411 | nn.ReLU(inplace=True), 412 | nn.MaxPool2d(kernel_size=3, 413 | stride=2, 414 | padding=1), 415 | BottleneckLayer(64, 64, 256), 416 | BottleneckLayer(256, 64, 384)) 417 | 418 | model = PSViT(embed_dim=384, 419 | num_iters=4, 420 | num_point_h=16, 421 | num_point_w=16, 422 | depth=14, 423 | num_heads=6, 424 | mlp_ratio=3., 425 | stem_layer=stem, 426 | downsample_ratio=4, 427 | offset_gamma=1.0, 428 | offset_bias=True, 429 | with_cls_token=True, 430 | **kwargs) 431 | model.default_cfg = _default_cfg() 432 | if pretrained: 433 | load_pretrained(model, num_classes=model.num_classes, 434 | in_chans=kwargs.get('in_chans', 3)) 435 | return model 436 | 437 | 438 | @register_model 439 | def ps_vit_b_18(pretrained=False, **kwargs): 440 | if pretrained: 441 | kwargs.setdefault('qk_scale', 384 ** -0.5) 442 | 443 | stem = nn.Sequential(nn.Conv2d(kwargs.get('in_chans', 3), 444 | 64, 445 | kernel_size=7, 446 | padding=3, 447 | stride=2, 448 | bias=False), 449 | nn.BatchNorm2d(64), 450 | nn.ReLU(inplace=True), 451 | nn.MaxPool2d(kernel_size=3, 452 | stride=2, 453 | padding=1), 454 | BottleneckLayer(64, 64, 256), 455 | BottleneckLayer(256, 64, 384)) 456 | 457 | model = PSViT(embed_dim=384, 458 | num_iters=4, 459 | num_point_h=18, 460 | num_point_w=18, 461 | depth=14, 462 | num_heads=6, 463 | mlp_ratio=3., 464 | stem_layer=stem, 465 | downsample_ratio=4, 466 | offset_gamma=1.0, 467 | offset_bias=True, 468 | with_cls_token=True, 469 | **kwargs) 470 | model.default_cfg = _default_cfg() 471 | if pretrained: 472 | load_pretrained(model, 
num_classes=model.num_classes, 473 | in_chans=kwargs.get('in_chans', 3)) 474 | return model 475 | 476 | 477 | @register_model 478 | def ps_vit_ti_14(pretrained=False, **kwargs): 479 | if pretrained: 480 | kwargs.setdefault('qk_scale', 192 ** -0.5) 481 | 482 | stem = nn.Sequential(nn.Conv2d(kwargs.get('in_chans', 3), 483 | 64, 484 | kernel_size=7, 485 | padding=3, 486 | stride=2, 487 | bias=False), 488 | nn.BatchNorm2d(64), 489 | nn.ReLU(inplace=True), 490 | nn.MaxPool2d(kernel_size=3, 491 | stride=2, 492 | padding=1), 493 | BottleneckLayer(64, 64, 192), 494 | BottleneckLayer(192, 64, 192)) 495 | 496 | model = PSViT(embed_dim=192, 497 | num_iters=4, 498 | depth=12, 499 | num_heads=3, 500 | mlp_ratio=3., 501 | stem_layer=stem, 502 | downsample_ratio=4, 503 | offset_gamma=1.0, 504 | offset_bias=True, 505 | with_cls_token=True, 506 | **kwargs) 507 | model.default_cfg = _default_cfg() 508 | if pretrained: 509 | load_pretrained(model, num_classes=model.num_classes, 510 | in_chans=kwargs.get('in_chans', 3)) 511 | return model 512 | -------------------------------------------------------------------------------- /models/transformer_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | 6 | 7 | class Mlp(nn.Module): 8 | def __init__(self, 9 | in_features, 10 | hidden_features=None, 11 | out_features=None, 12 | act_layer=nn.GELU, 13 | drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | self.fc1 = nn.Linear(in_features, hidden_features) 18 | self.act = act_layer() 19 | self.fc2 = nn.Linear(hidden_features, out_features) 20 | self.drop = nn.Dropout(drop) 21 | 22 | def forward(self, x): 23 | x = self.fc1(x) 24 | x = self.act(x) 25 | x = self.drop(x) 26 | x = self.fc2(x) 27 | x = self.drop(x) 28 | return x 29 | 30 | 31 | class Attention(nn.Module): 32 | def __init__(self, 33 | dim, 34 | num_heads=8, 35 | qkv_bias=False, 36 | qk_scale=None, 37 | attn_drop=0., 38 | proj_drop=0.): 39 | super().__init__() 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | 43 | self.scale = qk_scale or head_dim ** -0.5 44 | 45 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 46 | self.attn_drop = nn.Dropout(attn_drop) 47 | self.proj = nn.Linear(dim, dim) 48 | self.proj_drop = nn.Dropout(proj_drop) 49 | 50 | def forward(self, x): 51 | B, N, C = x.shape 52 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 53 | q, k, v = qkv[0], qkv[1], qkv[2] 54 | 55 | attn = (q @ k.transpose(-2, -1)) * self.scale 56 | attn = attn.softmax(dim=-1) 57 | attn = self.attn_drop(attn) 58 | 59 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 60 | x = self.proj(x) 61 | x = self.proj_drop(x) 62 | return x 63 | 64 | 65 | class TransformerEncoderLayer(nn.Module): 66 | def __init__(self, 67 | dim, 68 | num_heads, 69 | mlp_ratio=4., 70 | qkv_bias=False, 71 | qk_scale=None, 72 | drop=0., 73 | attn_drop=0., 74 | drop_path=0., 75 | act_layer=nn.GELU, 76 | norm_layer=nn.LayerNorm): 77 | super().__init__() 78 | self.norm1 = norm_layer(dim) 79 | self.attn = Attention(dim, 80 | num_heads=num_heads, 81 | qkv_bias=qkv_bias, 82 | qk_scale=qk_scale, 83 | attn_drop=attn_drop, 84 | proj_drop=drop) 85 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 86 | self.norm2 = norm_layer(dim) 87 | mlp_hidden_dim = int(dim * mlp_ratio) 88 | self.mlp = Mlp(in_features=dim, 89 | hidden_features=mlp_hidden_dim, 90 | act_layer=act_layer, 91 | drop=drop) 92 | 93 | def forward(self, x): 94 | x = x + self.drop_path(self.attn(self.norm1(x))) 95 | x = x + self.drop_path(self.mlp(self.norm2(x))) 96 | return x 97 | -------------------------------------------------------------------------------- /scripts/train_distributed.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | NOW="`date +%Y%m%d%H%M%S`" 4 | JOB_NAME=$1 5 | CONFIG=$2 6 | NUM_PROC=$3 7 | MASTER_PORT=2333 8 | 9 | if [ ! -d "output" ];then 10 | mkdir output 11 | fi 12 | 13 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC --master_port=${MASTER_PORT} main.py --config=${CONFIG} --distributed=True 2>&1 | tee output/${JOB_NAME}_${NOW}.log 14 | -------------------------------------------------------------------------------- /scripts/train_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | NOW="`date +%Y%m%d%H%M%S`" 4 | PARTITION=$1 5 | JOB_NAME=$2 6 | CONFIG=$3 7 | NUM_PROC=$4 8 | 9 | if [ ! -d "output" ];then 10 | mkdir output 11 | fi 12 | 13 | srun --mpi=pmi2 -n${NUM_PROC} -p ${PARTITION} --gres=gpu:8 \ 14 | --ntasks-per-node=8 --cpus-per-task=5 --job-name=${JOB_NAME} \ 15 | python main.py --config ${CONFIG} \ 16 | 2>&1 | tee output/${JOB_NAME}_${NOW}.log 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from setuptools import find_packages, setup 4 | 5 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 6 | 7 | 8 | def get_extensions(): 9 | extensions = [] 10 | ext_name = '_ext' 11 | op_files = glob.glob('./layers/csrc/*') 12 | print(op_files) 13 | include_path = os.path.abspath('./layers/cinclude') 14 | 15 | extensions.append(CUDAExtension( 16 | name=ext_name, 17 | sources=op_files, 18 | include_dirs=[include_path] 19 | )) 20 | 21 | return extensions 22 | 23 | 24 | if __name__ == "__main__": 25 | setup( 26 | name='ps_vit', 27 | version='0.0.1', 28 | description='vision transformer with progressive sampling', 29 | packages=find_packages(), 30 | ext_modules=get_extensions(), 31 | cmdclass={'build_ext': BuildExtension}, 32 | zip_safe=False 33 | ) 34 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuexy/PS-ViT/a63a397e64d07a238a9fcaf392dce4c4596f6636/utils/__init__.py -------------------------------------------------------------------------------- /utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import multiprocessing 4 | 5 | import torch.distributed as t_dist 6 | 7 | 8 | def dist_init(port=2333): 9 | if multiprocessing.get_start_method(allow_none=True) != 'spawn': 10 | multiprocessing.set_start_method('spawn', force=True) 11 | 12 | rank = int(os.environ['SLURM_PROCID']) 13 | world_size = os.environ['SLURM_NTASKS'] 14 | node_list = os.environ['SLURM_NODELIST'] 15 | num_gpus = torch.cuda.device_count() 16 | gpu_id = rank % num_gpus 17 | torch.cuda.set_device(gpu_id) 18 | 19 | if '[' in node_list: 
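# Editor's note (a sketch of the cluster assumptions, not part of the original file):
# SLURM_NODELIST may be a bracket expression such as 'node-10-5-36-[140,141-143]';
# the block below keeps only the first hostname. The later `addr = node_list[8:]`
# assumes this cluster's hostnames embed the node IP after an 8-character site
# prefix (e.g. a hypothetical 'SH-IDC1-10-5-36-140' -> '10.5.36.140'); adjust the
# slice for a different naming scheme.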
20 | beg = node_list.find('[') 21 | pos1 = node_list.find('-', beg) 22 | if pos1 < 0: 23 | pos1 = 1000 24 | pos2 = node_list.find(',', beg) 25 | if pos2 < 0: 26 | pos2 = 1000 27 | node_list = node_list[:min(pos1, pos2)].replace('[', '') 28 | addr = node_list[8:].replace('-', '.') 29 | 30 | os.environ['MASTER_PORT'] = str(port) 31 | os.environ['MASTER_ADDR'] = addr 32 | os.environ['WORLD_SIZE'] = world_size 33 | os.environ['RANK'] = str(rank) 34 | 35 | t_dist.init_process_group(backend='nccl') 36 | 37 | return rank, int(world_size), gpu_id 38 | -------------------------------------------------------------------------------- /utils/ext_loader.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | def load_ext(name, funcs): 5 | ext = importlib.import_module(name) 6 | for fun in funcs: 7 | assert hasattr(ext, fun), f'{fun} is missing from module {name}' 8 | return ext 9 | -------------------------------------------------------------------------------- /utils/flop_count/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuexy/PS-ViT/a63a397e64d07a238a9fcaf392dce4c4596f6636/utils/flop_count/__init__.py -------------------------------------------------------------------------------- /utils/flop_count/flop_count.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union 3 | 4 | import torch.nn as nn 5 | from torch import Tensor 6 | 7 | from .jit_analysis import JitModelAnalysis 8 | from .jit_handles import ( 9 | Handle, 10 | addmm_flop_jit, 11 | bmm_flop_jit, 12 | conv_flop_jit, 13 | einsum_flop_jit, 14 | elementwise_flop_counter, 15 | linear_flop_jit, 16 | matmul_flop_jit, 17 | norm_flop_counter, 18 | ) 19 | 20 | 21 | # A dictionary that maps supported operations to their flop count jit handles. 22 | _DEFAULT_SUPPORTED_OPS: Dict[str, Handle] = { 23 | "aten::addmm": addmm_flop_jit, 24 | "aten::bmm": bmm_flop_jit, 25 | "aten::_convolution": conv_flop_jit, 26 | "aten::einsum": einsum_flop_jit, 27 | "aten::matmul": matmul_flop_jit, 28 | "aten::linear": linear_flop_jit, 29 | # You might want to ignore BN flops due to inference-time fusion. 30 | # Use `set_op_handle("aten::batch_norm", None)`. 31 | "aten::batch_norm": norm_flop_counter("batch_norm", 1), 32 | "aten::group_norm": norm_flop_counter("group_norm", 2), 33 | "aten::layer_norm": norm_flop_counter("layer_norm", 2), 34 | "aten::instance_norm": norm_flop_counter("instance_norm", 1), 35 | "aten::upsample_nearest2d": elementwise_flop_counter("upsample_nearest2d", 0, 1), 36 | "aten::upsample_bilinear2d": elementwise_flop_counter("upsample_bilinear2d", 0, 4), 37 | "aten::adaptive_avg_pool2d": elementwise_flop_counter("adaptive_avg_pool2d", 1, 0), 38 | "aten::grid_sampler": elementwise_flop_counter( 39 | "grid_sampler", 0, 4 40 | ), # assume bilinear 41 | } 42 | 43 | 44 | class FlopCountAnalysis(JitModelAnalysis): 45 | """ 46 | Provides access to per-submodule model flop count obtained by 47 | tracing a model with pytorch's jit tracing functionality. By default, 48 | comes with standard flop counters for a few common operators. 49 | Note that: 50 | 1. Flop is not a well-defined concept. We just produce our best estimate. 51 | 2. We count one fused multiply-add as one flop. 
52 | Handles for additional operators may be added, or the default ones 53 | overwritten, using the ``.set_op_handle(name, func)`` method. 54 | See the method documentation for details. 55 | Flop counts can be obtained as: 56 | * ``.total(module_name="")``: total flop count for the module 57 | * ``.by_operator(module_name="")``: flop counts for the module, as a Counter 58 | over different operator types 59 | * ``.by_module()``: Counter of flop counts for all submodules 60 | * ``.by_module_and_operator()``: dictionary indexed by submodule name, of Counters 61 | over different operator types 62 | An operator is treated as within a module if it is executed inside the 63 | module's ``__call__`` method. Note that this does not include calls to 64 | other methods of the module or explicit calls to ``module.forward(...)``. 65 | Example usage: 66 | >>> import torch.nn as nn 67 | >>> import torch 68 | >>> class TestModel(nn.Module): 69 | ... def __init__(self): 70 | ... super().__init__() 71 | ... self.fc = nn.Linear(in_features=1000, out_features=10) 72 | ... self.conv = nn.Conv2d( 73 | ... in_channels=3, out_channels=10, kernel_size=1 74 | ... ) 75 | ... self.act = nn.ReLU() 76 | ... def forward(self, x): 77 | ... return self.fc(self.act(self.conv(x)).flatten(1)) 78 | >>> model = TestModel() 79 | >>> inputs = (torch.randn((1,3,10,10)),) 80 | >>> flops = FlopCountAnalysis(model, inputs) 81 | >>> flops.total() 82 | 13000 83 | >>> flops.total("fc") 84 | 10000 85 | >>> flops.by_operator() 86 | Counter({"addmm" : 10000, "conv" : 3000}) 87 | >>> flops.by_module() 88 | Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0}) 89 | >>> flops.by_module_and_operator() 90 | {"" : Counter({"addmm" : 10000, "conv" : 3000}), 91 | "fc" : Counter({"addmm" : 10000}), 92 | "conv" : Counter({"conv" : 3000}), 93 | "act" : Counter() 94 | } 95 | """ 96 | 97 | def __init__( 98 | self, 99 | model: nn.Module, 100 | inputs: Union[Tensor, Tuple[Tensor, ...]], 101 | ) -> None: 102 | super().__init__(model=model, inputs=inputs) 103 | self.set_op_handle(**_DEFAULT_SUPPORTED_OPS) 104 | 105 | __init__.__doc__ = JitModelAnalysis.__init__.__doc__ 106 | 107 | 108 | def flop_count( 109 | model: nn.Module, 110 | inputs: Tuple[Any, ...], 111 | supported_ops: Optional[Dict[str, Handle]] = None, 112 | ) -> Tuple[DefaultDict[str, float], Counter[str]]: 113 | """ 114 | Given a model and an input to the model, compute the Gflops of the given 115 | model. 116 | Args: 117 | model (nn.Module): The model to compute flop counts. 118 | inputs (tuple): Inputs that are passed to `model` to count flops. 119 | Inputs need to be in a tuple. 120 | supported_ops (dict(str,Callable) or None) : provide additional 121 | handlers for extra ops, or overwrite the existing handlers for 122 | convolution, matmul and einsum. The key is the operator name and the value 123 | is a function that takes (inputs, outputs) of the op. We count 124 | one Multiply-Add as one FLOP. 125 | Returns: 126 | tuple[defaultdict, Counter]: A dictionary that records the number of 127 | gflops for each operation and a Counter that records the number of 128 | unsupported operations. 
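Example (editor's sketch, not in the original docstring; assumes ``model``
        is an ``nn.Module`` and ``torch`` is imported):
            >>> inputs = (torch.randn(1, 3, 224, 224),)
            >>> gflops, unsupported = flop_count(model, inputs)
            >>> total_gflops = sum(gflops.values())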
129 | """ 130 | if supported_ops is None: 131 | supported_ops = {} 132 | flop_counter = FlopCountAnalysis(model, inputs).set_op_handle(**supported_ops) 133 | giga_flops = defaultdict(float) 134 | for op, flop in flop_counter.by_operator().items(): 135 | giga_flops[op] = flop / 1e9 136 | return giga_flops, flop_counter.unsupported_ops() 137 | -------------------------------------------------------------------------------- /utils/flop_count/jit_analysis.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | import warnings 4 | from collections import Counter 5 | from copy import copy 6 | from dataclasses import dataclass 7 | from typing import Any, Dict, List, Optional, Set, Tuple, Union, Iterable 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from torch import Tensor 13 | from torch.jit import TracerWarning, _get_trace_graph 14 | 15 | from .jit_handles import Handle 16 | 17 | 18 | def _named_modules_with_dup( 19 | model: nn.Module, prefix: str = "" 20 | ) -> Iterable[Tuple[str, nn.Module]]: 21 | """ 22 | The same as `model.named_modules()`, except that it includes 23 | duplicated modules that have more than one name. 24 | """ 25 | yield prefix, model 26 | for name, module in model._modules.items(): # pyre-ignore 27 | if module is None: 28 | continue 29 | submodule_prefix = prefix + ("." if prefix else "") + name 30 | yield from _named_modules_with_dup(module, submodule_prefix) 31 | 32 | 33 | # Only ignore ops that are technically truly 0 flops: 34 | # shape-manipulation ops, integer ops, memory copy ops 35 | _IGNORED_OPS: Set[str] = { 36 | "aten::Int", 37 | "aten::ScalarImplicit", 38 | "aten::__and__", 39 | "aten::arange", 40 | "aten::cat", 41 | "aten::chunk", 42 | "aten::clamp", 43 | "aten::clamp_", 44 | "aten::constant_pad_nd", 45 | "aten::contiguous", 46 | "aten::copy_", 47 | "aten::detach", 48 | "aten::dropout", 49 | "aten::empty", 50 | "aten::eq", 51 | "aten::expand", 52 | "aten::flatten", 53 | "aten::floor", 54 | "aten::floor_divide", 55 | "aten::full", 56 | "aten::ge", 57 | "aten::gt", 58 | "aten::index", 59 | "aten::index_put_", 60 | "aten::max", 61 | "aten::nonzero", 62 | "aten::permute", 63 | "aten::relu", 64 | "aten::relu_", 65 | "aten::remainder", 66 | "aten::reshape", 67 | "aten::select", 68 | "aten::size", 69 | "aten::slice", 70 | "aten::split", 71 | "aten::split_with_sizes", 72 | "aten::squeeze", 73 | "aten::narrow", 74 | "aten::unbind", 75 | "aten::full_like", 76 | "aten::stack", 77 | "aten::t", 78 | "aten::to", 79 | "aten::transpose", 80 | "aten::unsqueeze", 81 | "aten::unsqueeze_", 82 | "aten::view", 83 | "aten::zeros", 84 | "aten::zeros_like", 85 | } 86 | 87 | 88 | @dataclass 89 | class Statistics: 90 | """ 91 | For keeping track of the various model statistics recorded during 92 | analysis. 93 | """ 94 | 95 | counts: "Dict[str, Counter[str]]" 96 | unsupported_ops: "Dict[str, Counter[str]]" 97 | uncalled_mods: "Set[str]" 98 | 99 | 100 | def _get_scoped_trace_graph( 101 | module: nn.Module, 102 | inputs: Union[Tensor, Tuple[Tensor, ...]], 103 | aliases: Dict[Union[str, nn.Module], str], 104 | ) -> torch._C.Graph: # pyre-ignore[11] 105 | """ 106 | Traces the provided module using torch.jit._get_trace_graph, but adds 107 | submodule scope information to each graph node. The resulting graph 108 | is in-lined and has all model parameters treated as inputs. The input 109 | model has the scope name '', while its descendants have names of the 110 | form 'child.grandchild.grandgrandchild...'. 
111 | Args: 112 | module (nn.Module) : The module to trace 113 | inputs (tuple) : Inputs used during the trace of the model 114 | aliases (dict(str or nn.Module, str)) : maps modules and module 115 | names to the canonical name to be used as the scope for 116 | that module. 117 | Returns: 118 | graph (torch._C.Graph) : The pytorch JIT trace of the model 119 | """ 120 | 121 | class ScopePushHook: 122 | def __init__(self, name: str) -> None: 123 | self.name = name 124 | 125 | def __call__(self, module: nn.Module, inputs: Any) -> Any: 126 | tracing_state = torch._C._get_tracing_state() 127 | if tracing_state: 128 | tracing_state.push_scope(self.name) 129 | return inputs 130 | 131 | class ScopePopHook: 132 | def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any: 133 | tracing_state = torch._C._get_tracing_state() 134 | if tracing_state: 135 | tracing_state.pop_scope() 136 | return outputs 137 | 138 | seen = set() 139 | hook_handles: List[Any] = [] 140 | 141 | def register_hooks(mod: nn.Module, name: str) -> None: 142 | prehook = mod.register_forward_pre_hook(ScopePushHook(name)) # pyre-ignore[16] 143 | posthook = mod.register_forward_hook(ScopePopHook()) # pyre-ignore[16] 144 | hook_handles.append(prehook) 145 | hook_handles.append(posthook) 146 | 147 | # Torch script does not support parallel torch models, but we still 148 | # want the scope names to be correct for the complete module. 149 | if isinstance( 150 | module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel) 151 | ): 152 | 153 | # Since DataParallel just wraps the model, add an extra set of hooks 154 | # to the model it wraps to account for the wrapper. Then trace it. 155 | root_name = aliases[module] 156 | module = module.module 157 | register_hooks(module, root_name) 158 | 159 | # We don't need the duplication here, but self._model.named_modules() 160 | # gives slightly different results for some wrapped models. 161 | for name, mod in _named_modules_with_dup(module): 162 | if mod not in seen: 163 | name = aliases[mod] 164 | register_hooks(mod, name) 165 | seen.add(mod) 166 | 167 | if hasattr(torch.jit, "get_trace_graph"): 168 | trace, _ = torch.jit.get_trace_graph(module, inputs) 169 | graph = trace.graph() 170 | else: 171 | graph, _ = _get_trace_graph(module, inputs) 172 | 173 | for handle in hook_handles: 174 | handle.remove() 175 | 176 | return graph 177 | 178 | 179 | class JitModelAnalysis: 180 | """ 181 | Provides access to per-submodule model statistics obtained by 182 | tracing a model with pytorch's jit tracing functionality. Calculates 183 | a statistic on a per-operator basis using the provided set of functions 184 | that acts on the inputs and outputs to the operator, then aggregates 185 | this over modules in the model. Can return the aggregate statistic for 186 | any submodule in the model. Is lazily evaluated, and will perform the 187 | trace when a statistic is first requested. Changing the operator handles 188 | will cause the trace to be rerun on the next request. 189 | Submodules may be referred to using the module's name. The input model has 190 | name "", while its descendants have names of the form 191 | "child.grandchild.grandgrandchild...". 192 | An operator is treated as within the scope of a module if calling that 193 | module directly resulted in that operator being run. In particular, 194 | this means that calls to other functions owned by a module or explicit 195 | calls to module.forward(...) 
will not register resulting operators as 196 | contributing statistics to that module. 197 | """ 198 | 199 | def __init__( 200 | self, 201 | model: nn.Module, 202 | inputs: Union[Tensor, Tuple[Tensor, ...]], 203 | ) -> None: 204 | """ 205 | Args: 206 | model: The model to analyze. 207 | inputs: The inputs to the model for analysis. 208 | """ 209 | self._model = model 210 | self._inputs = inputs 211 | self._op_handles: Dict[str, Handle] = {} 212 | # Mapping from names to submodules 213 | self._named_modules: Dict[str, nn.Module] = dict(_named_modules_with_dup(model)) 214 | # Mapping from submodules and their aliases to the canonical name of each submodule 215 | self._aliases: Dict[Union[nn.Module, str], str] = self._get_aliases(model) 216 | self._stats: Optional[Statistics] = None 217 | self._enable_warn_unsupported_ops = True 218 | self._enable_warn_uncalled_mods = True 219 | self._warn_trace = "no_tracer_warning" 220 | self._ignored_ops: Set[str] = copy(_IGNORED_OPS) 221 | 222 | def total(self, module_name: str = "") -> int: 223 | """ 224 | Returns the total aggregated statistic across all operators 225 | for the requested module. 226 | Args: 227 | module_name (str) : The submodule to get data for. Defaults to 228 | the entire model. 229 | Returns: 230 | int : The aggregated statistic. 231 | """ 232 | stats = self._analyze() 233 | module_name = self.canonical_module_name(module_name) 234 | total_count = sum(stats.counts[module_name].values()) 235 | return total_count 236 | 237 | def by_operator(self, module_name: str = "") -> typing.Counter[str]: 238 | """ 239 | Returns the statistics for a requested module, grouped by operator 240 | type. The operator handle determines the name associated with each 241 | operator type. 242 | Args: 243 | module_name (str) : The submodule to get data for. Defaults 244 | to the entire model. 245 | Returns: 246 | Counter(str) : The statistics for each operator. 247 | """ 248 | stats = self._analyze() 249 | module_name = self.canonical_module_name(module_name) 250 | return stats.counts[module_name] 251 | 252 | def by_module_and_operator(self) -> Dict[str, typing.Counter[str]]: 253 | """ 254 | Returns the statistics for all submodules, separated out by 255 | operator type for each submodule. The operator handle determines 256 | the name associated with each operator type. 257 | Returns: 258 | dict(str, Counter(str)): 259 | The statistics for each submodule and each operator. 260 | Grouped by submodule names, then by operator name. 261 | """ 262 | stats = self._analyze() 263 | return stats.counts 264 | 265 | def by_module(self) -> typing.Counter[str]: 266 | """ 267 | Returns the statistics for all submodules, aggregated over 268 | all operators. 269 | Returns: 270 | Counter(str): statistics counter grouped by submodule names 271 | """ 272 | stats = self._analyze() 273 | summed_counts = Counter() 274 | for mod, results in stats.counts.items(): 275 | summed_counts[mod] = sum(results.values()) 276 | return summed_counts 277 | 278 | def unsupported_ops(self, module_name: str = "") -> typing.Counter[str]: 279 | """ 280 | Lists the number of operators that were encountered but unsupported 281 | because no operator handle is available for them. Does not include 282 | operators that are explicitly ignored. 283 | Args: 284 | module_name (str) : The submodule to list unsupported ops. 285 | Defaults to the entire model. 286 | Returns: 287 | Counter(str) : The number of occurrences of each unsupported operator. 
288 | """ 289 | if self._stats is None: 290 | raise RuntimeError( 291 | "Analysis results should be computed " 292 | "before calling unsupported_ops()" 293 | ) 294 | module_name = self.canonical_module_name(module_name) 295 | return self._stats.unsupported_ops[module_name] # pyre-fixme 296 | 297 | def uncalled_modules(self) -> Set[str]: 298 | """ 299 | Returns a set of submodules that were never called during the 300 | trace of the graph. This may be because they were unused, or 301 | because they were accessed via direct calls .forward() or with 302 | other python methods. In the latter case, statistics will not be 303 | attributed to the submodule, though the statistics will be included 304 | in the parent module. 305 | Returns: 306 | set(str) : The set of submodule names that were never called 307 | during the trace of the model. 308 | """ 309 | stats = self._analyze() 310 | return stats.uncalled_mods 311 | 312 | def set_op_handle(self, *args, **kwargs: Optional[Handle]) -> "JitModelAnalysis": 313 | """ 314 | Sets additional operator handles, or replaces existing ones. 315 | Args: 316 | args: (str, Handle) pairs of operator names and handles. 317 | kwargs: mapping from operator names to handles. 318 | If a handle is ``None``, the op will be explicitly ignored. Otherwise, 319 | handle should be a function that calculates the desirable statistic 320 | from an operator. The function must take two arguments, which are the 321 | inputs and outputs of the operator, in the form of ``list(torch._C.Value)``. 322 | The function should return a counter object with per-operator statistics. 323 | Examples 324 | :: 325 | handlers = {"aten::linear": my_handler} 326 | counter.set_op_handle("aten::matmul", None, "aten::bmm", my_handler2) 327 | .set_op_handle(**handlers) 328 | """ 329 | self._stats = None 330 | if len(args) % 2 != 0: 331 | raise TypeError( 332 | "set_op_handle should be called with pairs of names and handles!" 333 | ) 334 | for name, handle in zip(args[::2], args[1::2]): 335 | kwargs[name] = handle 336 | for name, handle in kwargs.items(): 337 | if handle is None: 338 | self._ignored_ops.add(name) 339 | else: 340 | self._op_handles[name] = handle 341 | return self 342 | 343 | def clear_op_handles(self) -> "JitModelAnalysis": 344 | """ 345 | Clears all operator handles currently set. 346 | """ 347 | self._op_handles = {} 348 | self._ignored_ops = copy(_IGNORED_OPS) 349 | self._stats = None 350 | return self 351 | 352 | def canonical_module_name(self, name: str) -> str: 353 | """ 354 | Returns the canonical module name of the given ``name``, which might be 355 | different from the given ``name`` if the module is shared. 356 | This is the name that will be used as a key when statistics are 357 | output using .by_module() and .by_module_and_operator(). 358 | Args: 359 | name (str) : The name of the module to find the canonical name for. 360 | Returns: 361 | str : The canonical name of the module. 362 | """ 363 | # Blocks access by a direct module reference 364 | assert isinstance(name, str), "Module name must be a string." 365 | if name in self._aliases: 366 | return self._aliases[name] 367 | else: 368 | raise KeyError( 369 | "Requested module name is not among " 370 | "the descendants of the analyzed model." 
371 | ) 372 | 373 | def copy( 374 | self, 375 | new_model: Optional[nn.Module] = None, 376 | new_inputs: Union[None, Tensor, Tuple[Tensor, ...]] = None, 377 | ) -> "JitModelAnalysis": 378 | """ 379 | Returns a copy of the :class:`JitModelAnalysis` object, keeping all 380 | settings, but on a new model or new inputs. 381 | Args: 382 | new_model (nn.Module or None) : a new model for the new 383 | JitModelAnalysis. If None, uses the original model. 384 | new_inputs (typing.Tuple[object, ...] or None) : new inputs 385 | for the new JitModelAnalysis. If None, uses the original 386 | inputs. 387 | Returns: 388 | JitModelAnalysis : the new model analysis object 389 | """ 390 | model = self._model if new_model is None else new_model 391 | inputs = self._inputs if new_inputs is None else new_inputs 392 | return ( 393 | JitModelAnalysis(model=model, inputs=inputs) 394 | .set_op_handle(**self._op_handles) 395 | .unsupported_ops_warnings(self._enable_warn_unsupported_ops) 396 | .uncalled_modules_warnings(self._enable_warn_uncalled_mods) 397 | .tracer_warnings(self._warn_trace) 398 | ) 399 | 400 | def tracer_warnings(self, mode: str) -> "JitModelAnalysis": 401 | """ 402 | Sets which warnings to print when tracing the graph to calculate 403 | statistics. There are three modes. Defaults to 'no_tracer_warning'. 404 | Allowed values are: 405 | * 'all' : keeps all warnings raised while tracing 406 | * 'no_tracer_warning' : suppress torch.jit.TracerWarning only 407 | * 'none' : suppress all warnings raised while tracing 408 | Args: 409 | mode (str) : warning mode in one of the above values. 410 | """ 411 | assert mode in [ 412 | "all", 413 | "no_tracer_warning", 414 | "none", 415 | ], "Unrecognized trace warning mode." 416 | self._warn_trace = mode 417 | return self 418 | 419 | def unsupported_ops_warnings(self, enabled: bool) -> "JitModelAnalysis": 420 | """ 421 | Sets if warnings for unsupported operators are shown. Defaults 422 | to True. Counts of unsupported operators may be obtained from 423 | :meth:`unsupported_ops` regardless of this setting. 424 | Args: 425 | enabled (bool) : Set to 'True' to show unsupported operator 426 | warnings. 427 | """ 428 | self._enable_warn_unsupported_ops = enabled 429 | return self 430 | 431 | def uncalled_modules_warnings(self, enabled: bool) -> "JitModelAnalysis": 432 | """ 433 | Sets if warnings from uncalled submodules are shown. Defaults to true. 434 | A submodule is considered "uncalled" if it is never called during 435 | tracing. This may be because it is actually unused, or because it is 436 | accessed via calls to ``.forward()`` or other methods of the module. 437 | The set of uncalled modules may be obtained from 438 | :meth:`uncalled_modules` regardless of this setting. 439 | Args: 440 | enabled (bool) : Set to 'True' to show warnings. 441 | """ 442 | self._enable_warn_uncalled_mods = enabled 443 | return self 444 | 445 | def _warn_unsupported_ops(self, ops: typing.Counter[str]) -> None: 446 | if not self._enable_warn_unsupported_ops: 447 | return 448 | logger = logging.getLogger(__name__) 449 | for op, freq in ops.items(): 450 | logger.warning( 451 | "Unsupported operator {} encountered {} time(s)".format(op, freq) 452 | ) 453 | 454 | def _warn_uncalled_mods(self, uncalled_mods: Set[str]) -> None: 455 | if not self._enable_warn_uncalled_mods or not uncalled_mods: 456 | return 457 | 458 | logger = logging.getLogger(__name__) 459 | logger.warning( 460 | "The following submodules of the model were never " 461 | "called during the trace of the graph. 
They may be " 462 | "unused, or they were accessed by direct calls to " 463 | ".forward() or via other python methods. In the latter " 464 | "case they will have zeros for statistics, though their " 465 | "statistics will still contribute to their parent calling " 466 | "module." 467 | ) 468 | for mod in uncalled_mods: 469 | logger.warning("Module never called: {}".format(mod)) 470 | 471 | def _get_aliases(self, model: nn.Module) -> Dict[Union[str, nn.Module], str]: 472 | aliases = {} 473 | for name, module in _named_modules_with_dup(model): 474 | if module not in aliases: 475 | aliases[module] = name 476 | aliases[name] = aliases[module] 477 | return aliases 478 | 479 | def _analyze(self) -> "Statistics": 480 | # Don't calculate if results are already stored. 481 | stats = self._stats 482 | if stats is not None: 483 | return stats 484 | 485 | with warnings.catch_warnings(): 486 | if self._warn_trace == "none": 487 | warnings.simplefilter("ignore") 488 | elif self._warn_trace == "no_tracer_warning": 489 | warnings.filterwarnings("ignore", category=TracerWarning) 490 | graph = _get_scoped_trace_graph(self._model, self._inputs, self._aliases) 491 | 492 | # Assures even modules not in the trace graph are initialized to zero count 493 | counts = {} 494 | unsupported_ops = {} 495 | # We don't need the duplication here, but self._model.named_modules() 496 | # gives slightly different results for some wrapped models. 497 | for _, mod in _named_modules_with_dup(self._model): 498 | name = self._aliases[mod] 499 | counts[name] = Counter() 500 | unsupported_ops[name] = Counter() 501 | 502 | all_seen = set() 503 | for node in graph.nodes(): 504 | kind = node.kind() 505 | scope_names = node.scopeName().split("/") 506 | all_seen.update(scope_names) 507 | if kind not in self._op_handles: 508 | # ignore all prim:: operators 509 | if kind in self._ignored_ops or kind.startswith("prim::"): 510 | continue 511 | 512 | for name in set(scope_names): 513 | unsupported_ops[name][kind] += 1 514 | else: 515 | inputs, outputs = list(node.inputs()), list(node.outputs()) 516 | op_counts = self._op_handles[kind](inputs, outputs) 517 | 518 | # Assures an op contributes at most once to a module 519 | for name in set(scope_names): 520 | counts[name] += op_counts 521 | 522 | uncalled_mods = set(self._aliases.values()) - all_seen 523 | 524 | def has_forward(module_type) -> bool: 525 | # Containers are not meant to be called anyway (they don't have forward) 526 | no_forward_mods = {nn.ModuleList, nn.ModuleDict, nn.Module} 527 | for mod in no_forward_mods: 528 | if module_type.forward is mod.forward: 529 | return False 530 | return True 531 | 532 | uncalled_mods = { 533 | m for m in uncalled_mods if has_forward(type(self._named_modules.get(m))) 534 | } 535 | 536 | stats = Statistics( 537 | counts=counts, unsupported_ops=unsupported_ops, uncalled_mods=uncalled_mods 538 | ) 539 | self._stats = stats 540 | self._warn_unsupported_ops(unsupported_ops[""]) 541 | self._warn_uncalled_mods(uncalled_mods) 542 | return stats -------------------------------------------------------------------------------- /utils/flop_count/jit_handles.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from collections import Counter, OrderedDict 3 | from typing import Any, Callable, List, Optional 4 | 5 | from numpy import prod 6 | 7 | 8 | Handle = Callable[[List[Any], List[Any]], typing.Counter[str]] 9 | 10 | 11 | def generic_activation_jit( 12 | op_name: str, 13 | ) -> Callable[[List[Any], 
List[Any]], typing.Counter[str]]: 14 | """ 15 | This method returns a handle that counts the number of activations from the 16 | output shape for the specified operation. 17 | Args: 18 | op_name (str): The name of the operation. 19 | Returns: 20 | Callable: An activation handle for the given operation. 21 | """ 22 | 23 | def _generic_activation_jit(outputs: List[Any]) -> int: 24 | """ 25 | This is a generic jit handle that counts the number of activations for any 26 | operation given the output shape. 27 | Args: 28 | outputs (list(torch._C.Value)): The output shape in the form of a list 29 | of jit objects. 30 | Returns: 31 | int: Total number of activations for each operation. 32 | """ 33 | out_shape = get_shape(outputs[0]) 34 | ac_count = prod(out_shape) 35 | return ac_count 36 | 37 | return lambda inputs, outputs: Counter({op_name: _generic_activation_jit(outputs)}) 38 | 39 | 40 | def get_shape(val: Any) -> Optional[List[int]]: 41 | """ 42 | Get the shapes from a jit value object. 43 | Args: 44 | val (torch._C.Value): jit value object. 45 | Returns: 46 | list(int): returns a list of ints. 47 | """ 48 | if val.isCompleteTensor(): 49 | return val.type().sizes() 50 | else: 51 | return None 52 | 53 | 54 | """ 55 | Below are flop counters for various ops. Every counter has the following signature: 56 | Args: 57 | inputs (list(torch._C.Value)): The inputs of the op in the form of a list of jit objects. 58 | outputs (list(torch._C.Value)): The outputs of the op in the form of a list of jit objects. 59 | Returns: 60 | Counter: A Counter dictionary that records the number of flops for each operation. 61 | """ 62 | 63 | 64 | def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 65 | """ 66 | Count flops for fully connected layers. 67 | """ 68 | # Count flop for nn.Linear 69 | # inputs is a list of length 3. 70 | input_shapes = [get_shape(v) for v in inputs[1:3]] 71 | # input_shapes[0]: [batch size, input feature dimension] 72 | # input_shapes[1]: [input feature dimension, output feature dimension] 73 | assert len(input_shapes[0]) == 2, input_shapes[0] 74 | assert len(input_shapes[1]) == 2, input_shapes[1] 75 | batch_size, input_dim = input_shapes[0] 76 | output_dim = input_shapes[1][1] 77 | flop = batch_size * input_dim * output_dim 78 | flop_counter = Counter({"addmm": flop}) 79 | return flop_counter 80 | 81 | 82 | def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 83 | """ 84 | Count flops for the aten::linear operator. 85 | """ 86 | # Inputs is a list of length 3; unlike aten::addmm, it is the first 87 | # two elements that are relevant. 88 | input_shapes = [get_shape(v) for v in inputs[0:2]] 89 | # input_shapes[0]: [dim0, dim1, ..., input_feature_dim] 90 | # input_shapes[1]: [output_feature_dim, input_feature_dim] 91 | assert input_shapes[0][-1] == input_shapes[1][-1] 92 | flops = prod(input_shapes[0]) * input_shapes[1][0] 93 | flop_counter = Counter({"linear": flops}) 94 | return flop_counter 95 | 96 | 97 | def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 98 | """ 99 | Count flops for the bmm operation. 100 | """ 101 | # Inputs should be a list of length 2. 102 | # Inputs contains the shapes of the two tensors. 
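# (Editor's note, a worked example of the count below: bmm of [n, c, t]
#  with [n, t, d] costs n * c * t * d multiply-adds, so [8, 197, 64] @
#  [8, 64, 197] gives 8 * 197 * 64 * 197 ~= 19.9M flops.)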
103 | assert len(inputs) == 2, len(inputs) 104 | input_shapes = [get_shape(v) for v in inputs] 105 | n, c, t = input_shapes[0] 106 | d = input_shapes[-1][-1] 107 | flop = n * c * t * d 108 | flop_counter = Counter({"bmm": flop}) 109 | return flop_counter 110 | 111 | 112 | def conv_flop_count( 113 | x_shape: List[int], w_shape: List[int], out_shape: List[int] 114 | ) -> typing.Counter[str]: 115 | """ 116 | Count flops for convolution. Note only multiplication is 117 | counted. Computation for addition and bias is ignored. 118 | Args: 119 | x_shape (list(int)): The input shape before convolution. 120 | w_shape (list(int)): The filter shape. 121 | out_shape (list(int)): The output shape after convolution. 122 | Returns: 123 | Counter: A Counter dictionary that records the number of flops for each 124 | operation. 125 | """ 126 | batch_size, Cin_dim, Cout_dim = x_shape[0], w_shape[1], out_shape[1] 127 | out_size = prod(out_shape[2:]) 128 | kernel_size = prod(w_shape[2:]) 129 | flop = batch_size * out_size * Cout_dim * Cin_dim * kernel_size 130 | flop_counter = Counter({"conv": flop}) 131 | return flop_counter 132 | 133 | 134 | def conv_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 135 | """ 136 | Count flops for convolution. 137 | """ 138 | # Inputs of Convolution should be a list of length 12 or 13. They represent: 139 | # 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding, 140 | # 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn, 141 | # 10) deterministic_cudnn and 11) user_enabled_cudnn. 142 | # starting with #40737 it will be 12) user_enabled_tf32 143 | assert len(inputs) == 12 or len(inputs) == 13, len(inputs) 144 | x, w = inputs[:2] 145 | x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0])) 146 | return conv_flop_count(x_shape, w_shape, out_shape) 147 | 148 | 149 | def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 150 | """ 151 | Count flops for the einsum operation. We currently support 152 | two einsum operations: "nct,ncp->ntp" and "ntg,ncg->nct". 153 | """ 154 | # Inputs of einsum should be a list of length 2. 155 | # Inputs[0] stores the equation used for einsum. 156 | # Inputs[1] stores the list of input shapes. 157 | assert len(inputs) == 2, len(inputs) 158 | equation = inputs[0].toIValue() 159 | # Get rid of white space in the equation string. 160 | equation = equation.replace(" ", "") 161 | # Re-map equation so that same equation with different alphabet 162 | # representations will look the same. 163 | letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys() 164 | mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)} 165 | equation = equation.translate(mapping) 166 | input_shapes_jit = inputs[1].node().inputs() 167 | input_shapes = [get_shape(v) for v in input_shapes_jit] 168 | 169 | if equation == "abc,abd->acd": 170 | n, c, t = input_shapes[0] 171 | p = input_shapes[-1][-1] 172 | flop = n * c * t * p 173 | flop_counter = Counter({"einsum": flop}) 174 | return flop_counter 175 | 176 | elif equation == "abc,adc->adb": 177 | n, t, g = input_shapes[0] 178 | c = input_shapes[-1][1] 179 | flop = n * t * g * c 180 | flop_counter = Counter({"einsum": flop}) 181 | return flop_counter 182 | 183 | else: 184 | raise NotImplementedError("Unsupported einsum operation.") 185 | 186 | 187 | def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 188 | """ 189 | Count flops for matmul. 
190 | """ 191 | # Inputs should be a list of length 2. 192 | # Inputs contains the shapes of two matrices. 193 | input_shapes = [get_shape(v) for v in inputs] 194 | assert len(input_shapes) == 2, input_shapes 195 | assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes 196 | flop = prod(input_shapes[0]) * input_shapes[-1][-1] 197 | flop_counter = Counter({"matmul": flop}) 198 | return flop_counter 199 | 200 | 201 | def norm_flop_counter(name: str, affine_arg_index: int) -> Handle: 202 | """ 203 | Args: 204 | name: name to return in the counter 205 | affine_arg_index: index of the affine argument in inputs 206 | """ 207 | 208 | def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 209 | """ 210 | Count flops for norm layers. 211 | """ 212 | # Inputs[0] contains the shape of the input. 213 | input_shape = get_shape(inputs[0]) 214 | has_affine = get_shape(inputs[affine_arg_index]) is not None 215 | assert 2 <= len(input_shape) <= 5, input_shape 216 | # 5 is just a rough estimate 217 | flop = prod(input_shape) * (5 if has_affine else 4) 218 | return Counter({name: flop}) 219 | 220 | return norm_flop_jit 221 | 222 | 223 | def elementwise_flop_counter( 224 | name: str, input_scale: float = 1, output_scale: float = 0 225 | ) -> Handle: 226 | """ 227 | Count flops by 228 | input_tensor.numel() * input_scale + output_tensor.numel() * output_scale 229 | Args: 230 | name: name to return in the counter 231 | input_scale: scale of the input tensor (first argument) 232 | output_scale: scale of the output tensor (first element in outputs) 233 | """ 234 | 235 | def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> typing.Counter[str]: 236 | ret = 0 237 | if input_scale != 0: 238 | shape = get_shape(inputs[0]) 239 | ret += input_scale * prod(shape) 240 | if output_scale != 0: 241 | shape = get_shape(outputs[0]) 242 | ret += output_scale * prod(shape) 243 | return Counter({name: ret}) 244 | 245 | return -------------------------------------------------------------------------------- /utils/loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.data.transforms_factory import create_transform 3 | from timm.data.loader import (fast_collate, 4 | MultiEpochsDataLoader, 5 | PrefetchLoader) 6 | from timm.data.distributed_sampler import OrderedDistributedSampler 7 | from timm.data.constants import (IMAGENET_DEFAULT_MEAN, 8 | IMAGENET_DEFAULT_STD) 9 | from .sampler import RepeatAugmentSampler 10 | 11 | def create_loader(dataset, 12 | input_size, 13 | batch_size, 14 | is_training=False, 15 | use_prefetcher=True, 16 | no_aug=False, 17 | re_prob=0., 18 | re_mode='const', 19 | re_count=1, 20 | re_split=False, 21 | scale=None, 22 | ratio=None, 23 | hflip=0.5, 24 | vflip=0., 25 | color_jitter=0.4, 26 | auto_augment=None, 27 | num_aug_splits=0, 28 | interpolation='bilinear', 29 | mean=IMAGENET_DEFAULT_MEAN, 30 | std=IMAGENET_DEFAULT_STD, 31 | num_workers=1, 32 | distributed=False, 33 | crop_pct=None, 34 | collate_fn=None, 35 | pin_memory=False, 36 | fp16=False, 37 | tf_preprocessing=False, 38 | use_multi_epochs_loader=False, 39 | repeated_aug=False): 40 | re_num_splits = 0 41 | if re_split: 42 | re_num_splits = num_aug_splits or 2 43 | 44 | dataset.transform = create_transform( 45 | input_size, 46 | is_training=is_training, 47 | use_prefetcher=use_prefetcher, 48 | no_aug=no_aug, 49 | scale=scale, 50 | ratio=ratio, 51 | hflip=hflip, 52 | vflip=vflip, 53 | color_jitter=color_jitter, 54 | auto_augment=auto_augment, 
55 |         interpolation=interpolation,
56 |         mean=mean,
57 |         std=std,
58 |         crop_pct=crop_pct,
59 |         tf_preprocessing=tf_preprocessing,
60 |         re_prob=re_prob,
61 |         re_mode=re_mode,
62 |         re_count=re_count,
63 |         re_num_splits=re_num_splits,
64 |         separate=num_aug_splits > 0,
65 |     )
66 | 
67 |     sampler = None
68 | 
69 |     if distributed:
70 |         if is_training:
71 |             if repeated_aug:
72 |                 sampler = RepeatAugmentSampler(dataset)
73 |                 print('=> using repeated augmentation')
74 |             else:
75 |                 sampler = torch.utils.data.distributed.DistributedSampler(dataset)
76 |         else:
77 |             sampler = OrderedDistributedSampler(dataset)
78 | 
79 |     if collate_fn is None:
80 |         collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
81 | 
82 |     loader_class = torch.utils.data.DataLoader
83 | 
84 |     if use_multi_epochs_loader:
85 |         loader_class = MultiEpochsDataLoader
86 | 
87 |     loader = loader_class(
88 |         dataset,
89 |         batch_size=batch_size,
90 |         shuffle=sampler is None and is_training,
91 |         num_workers=num_workers,
92 |         sampler=sampler,
93 |         collate_fn=collate_fn,
94 |         pin_memory=pin_memory,
95 |         drop_last=is_training,
96 |     )
97 | 
98 |     if use_prefetcher:
99 |         prefetch_re_prob = re_prob if is_training and not no_aug else 0.
100 |         loader = PrefetchLoader(
101 |             loader,
102 |             mean=mean,
103 |             std=std,
104 |             fp16=fp16,
105 |             re_prob=prefetch_re_prob,
106 |             re_mode=re_mode,
107 |             re_count=re_count,
108 |             re_num_splits=re_num_splits
109 |         )
110 | 
111 |     return loader
112 | 
--------------------------------------------------------------------------------
/utils/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import math
4 | 
5 | 
6 | class RepeatAugmentSampler(torch.utils.data.Sampler):  # distributed sampler that repeats each index 3 times (repeated augmentation)
7 |     def __init__(self,
8 |                  dataset,
9 |                  num_replicas=None,
10 |                  rank=None,
11 |                  shuffle=True):
12 |         if num_replicas is None:
13 |             if not dist.is_available():
14 |                 raise RuntimeError("Requires distributed package to be available")
15 |             num_replicas = dist.get_world_size()
16 |         if rank is None:
17 |             if not dist.is_available():
18 |                 raise RuntimeError("Requires distributed package to be available")
19 |             rank = dist.get_rank()
20 |         self.dataset = dataset
21 |         self.num_replicas = num_replicas
22 |         self.rank = rank
23 |         self.epoch = 0
24 |         self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
25 |         self.total_size = self.num_samples * self.num_replicas
26 |         # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
27 |         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))  # floor dataset length to a multiple of 256, then split across replicas
28 |         self.shuffle = shuffle
29 | 
30 |     def __iter__(self):
31 |         # deterministically shuffle based on epoch
32 |         g = torch.Generator()
33 |         g.manual_seed(self.epoch)
34 |         if self.shuffle:
35 |             indices = torch.randperm(len(self.dataset), generator=g).tolist()
36 |         else:
37 |             indices = list(range(len(self.dataset)))
38 | 
39 |         # repeat each index 3 times, then pad so the total is evenly divisible
40 |         indices = [ele for ele in indices for _ in range(3)]
41 |         indices += indices[:(self.total_size - len(indices))]
42 |         assert len(indices) == self.total_size
43 | 
44 |         # subsample: rank r takes indices r, r + num_replicas, r + 2 * num_replicas, ...
45 |         indices = indices[self.rank:self.total_size:self.num_replicas]
46 |         assert len(indices) == self.num_samples
47 | 
48 |         return iter(indices[:self.num_selected_samples])  # truncate so the epoch keeps ~len(dataset)/num_replicas samples despite the 3x repetition
49 | 
50 |     def __len__(self):
51 |         return self.num_selected_samples
52 | 
53 |     def set_epoch(self, epoch):
54 |         self.epoch = epoch
55 | 
--------------------------------------------------------------------------------
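
The `*_flop_counter` factories in `utils/flop_count/jit_handles.py` return closures keyed by op name; `flop_count.py` dispatches traced JIT ops to these handles. Below is a minimal sketch of how a custom handle might be registered. It assumes the vendored counter keeps fvcore's `flop_count(model, inputs, supported_ops)` signature and its `aten::`-prefixed op names; neither appears in this listing, so treat both as assumptions.

```python
# Sketch only: registering a custom elementwise handle with the
# vendored flop counter. Assumes utils/flop_count exposes
# flop_count(model, inputs, supported_ops) the way fvcore does.
import torch
import torch.nn as nn

from utils.flop_count.flop_count import flop_count  # assumed import path
from utils.flop_count.jit_handles import elementwise_flop_counter

model = nn.Sequential(nn.Linear(64, 64), nn.GELU())
inputs = (torch.randn(1, 64),)

# Count GELU as one flop per output element -- a deliberate
# simplification, in the same spirit as the "rough estimate"
# inside norm_flop_counter.
ops = {"aten::gelu": elementwise_flop_counter("gelu", 0, 1)}
gflop_counts, skipped = flop_count(model, inputs, supported_ops=ops)
print(gflop_counts)  # per-op totals (fvcore reports these in GFlops)
```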
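
For context on `utils/loader.py`: `create_loader` overwrites `dataset.transform` and only engages `RepeatAugmentSampler` when both `distributed=True` and `repeated_aug=True`; otherwise a plain (or ordered distributed) sampler is used. A minimal single-process sketch follows, using `torchvision.datasets.FakeData` purely as a stand-in (the real pipeline reads ImageNet through `datasets/list_dataset.py`), with batch size and worker count chosen only for illustration.

```python
# Sketch only: exercising utils/loader.py without CUDA or DDP.
import torchvision

from utils.loader import create_loader

# FakeData is a stand-in dataset; create_loader replaces its transform.
dataset = torchvision.datasets.FakeData(
    size=1024, image_size=(3, 224, 224), num_classes=1000)

loader = create_loader(
    dataset,
    input_size=(3, 224, 224),  # PS-ViT models are evaluated at 224x224
    batch_size=64,
    is_training=True,
    use_prefetcher=False,      # PrefetchLoader expects a CUDA device
    distributed=False,         # so sampler stays None and shuffle=True
    num_workers=2,
)

for images, targets in loader:
    assert images.shape[1:] == (3, 224, 224)
    break
```

Under DDP with `repeated_aug=True`, each epoch iterates `num_selected_samples` indices per rank, so the effective epoch length stays near `len(dataset) / num_replicas` even though every index is repeated three times in the shuffled stream.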