├── .gitignore ├── INSTALL.md ├── LICENSE ├── README.md ├── TRAINING.md ├── datasets.py ├── engine.py ├── figures └── fig_1.png ├── fourier_analysis.py ├── get_flops.py ├── main.py ├── models ├── __init__.py ├── rest.py └── rest_v2.py ├── object_detection ├── GETTING_STARTED.md ├── README.md ├── analyze_model.py ├── configs │ ├── Base-RCNN-FPN.yaml │ ├── Base-RetinaNet.yaml │ ├── ResTv1 │ │ ├── mask_rcnn_rest_base_FPN_1x.yaml │ │ ├── mask_rcnn_rest_small_FPN_1x.yaml │ │ ├── retinanet_rest_base_FPN_1x.yaml │ │ └── retinanet_rest_small_FPN_1x.yaml │ └── ResTv2 │ │ ├── mask_rcnn_rest_base_FPN_1x.yaml │ │ ├── mask_rcnn_rest_small_FPN_1x.yaml │ │ ├── mask_rcnn_restv2_base_FPN_3x.yaml │ │ ├── mask_rcnn_restv2_small_FPN_3x.yaml │ │ ├── mask_rcnn_restv2_tiny_FPN_3x.yaml │ │ └── retinanet_restv2_tiny_FPN_3x.yaml ├── convert_to_d2.py ├── datasets │ ├── README.md │ ├── prepare_ade20k_sem_seg.py │ ├── prepare_cocofied_lvis.py │ ├── prepare_for_tests.sh │ └── prepare_panoptic_fpn.py ├── restv2 │ ├── __init__.py │ ├── config.py │ ├── rest.py │ └── restv2.py ├── train_net.py └── utils.py ├── optim_factory.py ├── requirements.txt ├── run_with_submitit.py ├── semantic_segmentation ├── README.md ├── backbone │ ├── __init__.py │ └── rest_v2.py ├── configs │ ├── ResTv2 │ │ ├── upernet_restv2_base_512_160k_ade20k.py │ │ ├── upernet_restv2_small_512_160k_ade20k.py │ │ └── upernet_restv2_tiny_512_160k_ade20k.py │ └── _base_ │ │ └── models │ │ └── upernet_restv2.py └── tools │ ├── align_resize.py │ ├── dist_test.sh │ ├── dist_train.sh │ └── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | outputs 3 | instant_test_output 4 | inference_test_output 5 | 6 | *.png 7 | *.json 8 | *.diff 9 | *.jpg 10 | !/projects/DensePose/doc/images/*.jpg 11 | 12 | # compilation and distribution 13 | __pycache__ 14 | _ext 15 | *.pyc 16 | *.pyd 17 | *.so 18 | *.dll 19 | *.egg-info/ 20 | build/ 21 | dist/ 22 | wheels/ 23 | 24 | # pytorch/python/numpy formats 25 | *.pth 26 | *.pkl 27 | *.npy 28 | *.ts 29 | model_ts*.txt 30 | 31 | # ipython/jupyter notebooks 32 | *.ipynb 33 | **/.ipynb_checkpoints/ 34 | 35 | # Editor temporaries 36 | *.swn 37 | *.swo 38 | *.swp 39 | *~ 40 | 41 | # editor settings 42 | .idea 43 | .vscode 44 | _darcs 45 | 46 | # project dirs 47 | /detectron2/model_zoo/configs 48 | /datasets/* 49 | !/datasets/*.* 50 | /projects/*/datasets 51 | /models 52 | /snippet -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | We provide installation instructions for ImageNet classification experiments here. 4 | 5 | ## Dependency Setup 6 | Create a new conda virtual environment 7 | ``` 8 | conda create -n rest python=3.9 -y 9 | conda activate rest 10 | ``` 11 | 12 | Install [PyTorch](https://pytorch.org/) >= 1.8.0, [torchvision](https://pytorch.org/vision/stable/index.html) >=0.9.0 following official instructions. For example: 13 | ``` 14 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 15 | ``` 16 | 17 | Clone this repo and install required packages: 18 | ``` 19 | pip install timm==0.5.4 tensorboardX six 20 | ``` 21 | 22 | The results in the paper are generated with `torch==1.8.0+cu111 torchvision==0.9.0+cu111 timm==0.5.4`. 
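A quick way to confirm that the environment matches these versions (a minimal sketch; the version strings in the comments are simply the ones listed above):
```
# sanity-check the installed versions against the ones used for the paper
import torch
import torchvision
import timm

print("torch:", torch.__version__)              # expected: 1.8.0+cu111
print("torchvision:", torchvision.__version__)  # expected: 0.9.0+cu111
print("timm:", timm.__version__)                # expected: 0.5.4
print("CUDA available:", torch.cuda.is_available())
```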
23 | 24 | ## Dataset Preparation 25 | 26 | Download the [ImageNet-1K](http://image-net.org/) classification dataset and structure the data as follows: 27 | ``` 28 | /path/to/imagenet-1k/ 29 | train/ 30 | class1/ 31 | img1.jpeg 32 | class2/ 33 | img2.jpeg 34 | val/ 35 | class1/ 36 | img3.jpeg 37 | class2/ 38 | img4.jpeg 39 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Updates 2 | - (2022/05/10) Code of [ResTV2](https://arxiv.org/abs/2204.07366) is released! ResTv2 simplifies the EMSA structure in 3 | [ResTv1](https://arxiv.org/abs/2105.13677) (i.e., eliminating the multi-head interaction part) and employs an upsample 4 | operation to reconstruct the lost medium- and high-frequency information caused by the downsampling operation. 
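The upsample operation mentioned here is implemented in `models/rest_v2.py` (class `Attention`). The sketch below isolates just that branch with toy sizes (`dim=96`, `sr_ratio=8`, the stage-1 setting of ResTv2-T) to show how the downsampled value map is restored to full resolution:
```
# condensed from models/rest_v2.py (Attention): a depth-wise conv expands the
# channels by sr_ratio**2, then PixelShuffle folds them back into spatial size,
# recovering the resolution lost by the strided downsampling.
import torch
import torch.nn as nn

dim, sr_ratio = 96, 8
up = nn.Sequential(
    nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1, padding=1, groups=dim),
    nn.PixelShuffle(upscale_factor=sr_ratio),  # (B, r*r*C, H/r, W/r) -> (B, C, H, W)
)
v = torch.randn(1, dim, 56 // sr_ratio, 56 // sr_ratio)  # downsampled values
print(up(v).shape)  # torch.Size([1, 96, 56, 56])
```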
5 | 6 | # [ResT: An Efficient Transformer for Visual Recognition](https://arxiv.org/abs/2105.13677) 7 | 8 | Official PyTorch implementation of **ResTv1** and **ResTv2**, from the following paper: 9 | 10 | [ResT: An Efficient Transformer for Visual Recognition](https://arxiv.org/abs/2105.13677). NeurIPS 2021.\ 11 | [ResT V2: Simpler, Faster and Stronger](https://arxiv.org/abs/2204.07366). NeurIPS 2022.\ 12 | By Qing-Long Zhang and Yu-Bin Yang \ 13 | State Key Laboratory for Novel Software Technology at Nanjing University 14 | 15 | --- 16 | 17 |

18 | 20 |

21 | 22 | **ResTv1** is initially described in [arxiv](https://arxiv.org/abs/2105.13677), which capably serves as a 23 | general-purpose backbone for computer vision. It can tackle input images with arbitrary size. Besides, 24 | ResT compressed the memory of standard MSA and model the interaction between multi-heads while keeping 25 | the diversity ability. 26 | 27 | ## Catalog 28 | - [x] ImageNet-1K Training Code 29 | - [x] ImageNet-1K Fine-tuning Code 30 | - [x] Downstream Transfer (Detection, Segmentation) Code 31 | 32 | 33 | 34 | ## Results and Pre-trained Models 35 | ### ImageNet-1K trained models 36 | 37 | | name | resolution |acc@1 | #params | FLOPs | Throughput | model | 38 | |:-----------:|:---:|:---:|:-------:|:-----:|:----------:|:---:| 39 | | ResTv1-Lite | 224x224 | 77.2 | 11M | 1.4G | 1246 | [baidu](https://pan.baidu.com/s/1VVzrzZi_tD3yTp_lw9tU9A) 40 | | ResTv1-S | 224x224 | 79.6 | 14M | 1.9G | 1043 | [baidu](https://pan.baidu.com/s/1Y-MIzzzcQnmrbHfGGR0mrw) 41 | | ResTv1-B | 224x224 | 81.6 | 30M | 4.3G | 673 | [baidu](https://pan.baidu.com/s/1HhR9YxtGIhouZ0GEA4LYlw) 42 | | ResTv1-L | 224x224 | 83.6 | 52M | 7.9G | 429 | [baidu](https://pan.baidu.com/s/14c4u_oRoBcKOt1aTlsBBpw) 43 | | ResTv2-T | 224x224 | 82.3 | 30M | 4.1G | 826 | [baidu](https://pan.baidu.com/s/1LHAbsrXnGsjvAE3d5zhaHQ) | 44 | | ResTv2-T | 384x384 | 83.7 | 30M | 12.7G | 319 | [baidu](https://pan.baidu.com/s/1fEMs_OrDa_xF7Cw1DiBU9w) | 45 | | ResTv2-S | 224x224 | 83.2 | 41M | 6.0G | 687 | [baidu](https://pan.baidu.com/s/1nysV5MTtwsDLChrRa7vmZQ) | 46 | | ResTv2-S | 384x384 | 84.5 | 41M | 18.4G | 256 | [baidu](https://pan.baidu.com/s/1S1GERP-lYEJANYr17xk3dA) | 47 | | ResTv2-B | 224x224 | 83.7 | 56M | 7.9G | 582 | [baidu](https://pan.baidu.com/s/1GH3N2_rbZx816mN87UzYgQ) | 48 | | ResTv2-B | 384x384 | 85.1 | 56M | 24.3G | 210 | [baidu](https://pan.baidu.com/s/12RBMZmf6IlJIB3lIkeBH9Q) | 49 | | ResTv2-L | 224x224 | 84.2 | 87M | 13.8G | 415 | [baidu](https://pan.baidu.com/s/1A2huwk_Ii4ZzQllg1iHrEw) | 50 | | ResTv2-L | 384x384 | 85.4 | 87M | 42.4G | 141 | [baidu](https://pan.baidu.com/s/1dlxiWexb9mho63WdWS8nXg) | 51 | 52 | 53 | Note: Access code for `baidu` is `rest`. Pretrained models of ResTv1 is now available in [google drive](https://drive.google.com/drive/folders/1H6QUZsKYbU6LECtxzGHKqEeGbx1E8uQ9). 54 | 55 | ## Installation 56 | Please check [INSTALL.md](INSTALL.md) for installation instructions. 57 | 58 | ## Evaluation 59 | We give an example evaluation command for a ImageNet-1K pre-trained, then ImageNet-1K fine-tuned ResTv2-T: 60 | 61 | Single-GPU 62 | ``` 63 | python main.py --model restv2_tiny --eval true \ 64 | --resume restv2_tiny_384.pth \ 65 | --input_size 384 --drop_path 0.1 \ 66 | --data_path /path/to/imagenet-1k 67 | ``` 68 | 69 | This should give 70 | ``` 71 | * Acc@1 83.708 Acc@5 96.524 loss 0.777 72 | ``` 73 | 74 | - For evaluating other model variants, change `--model`, `--resume`, `--input_size` accordingly. You can get the url to pre-trained models from the tables above. 75 | - Setting model-specific `--drop_path` is not strictly required in evaluation, as the `DropPath` module in timm behaves the same during evaluation; but it is required in training. See [TRAINING.md](TRAINING.md) or our paper for the values used for different models. 76 | 77 | ## Training 78 | See [TRAINING.md](TRAINING.md) for training and fine-tuning instructions. 79 | 80 | ## Acknowledgement 81 | This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library. 
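Because all variants are registered with timm (see `models/__init__.py`), the checkpoints in the table above can also be evaluated programmatically instead of through `main.py`. A minimal sketch, run from the repository root; the checkpoint filename is a placeholder, and the `'model'` state-dict key follows the convention used elsewhere in this repo (e.g. `fourier_analysis.py`):
```
# load a ResTv2 checkpoint and run a single forward pass as a sanity check
import torch
from timm.models import create_model
import models  # registers the rest_* / restv2_* variants with timm

model = create_model("restv2_tiny", pretrained=False, num_classes=1000)
checkpoint = torch.load("restv2_tiny_224.pth", map_location="cpu")  # placeholder path
model.load_state_dict(checkpoint["model"])
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```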
82 | 83 | ## License 84 | This project is released under the Apache License 2.0. Please see the [LICENSE](LICENSE) file for more information. 85 | 86 | ## Citation 87 | If you find this repository helpful, please consider citing: 88 | 89 | **ResTv1** 90 | ``` 91 | @inproceedings{zhang2021rest, 92 | title={ResT: An Efficient Transformer for Visual Recognition}, 93 | author={Qinglong Zhang and Yu-bin Yang}, 94 | booktitle={Advances in Neural Information Processing Systems}, 95 | year={2021}, 96 | url={https://openreview.net/forum?id=6Ab68Ip4Mu} 97 | } 98 | ``` 99 | 100 | **ResTv2** 101 | ``` 102 | @article{zhang2022rest, 103 | title={ResT V2: Simpler, Faster and Stronger}, 104 | author={Zhang, Qing-Long and Yang, Yu-Bin}, 105 | journal={arXiv preprint arXiv:2204.07366}, 106 | year={2022} 107 | ``` 108 | 109 | ## Third-party Implementation 110 | [2022/05/26] ResT and ResT v2 have been integrated into [PaddleViT](https://github.com/BR-IDL/PaddleViT), checkout [here](https://github.com/BR-IDL/PaddleViT/tree/develop/image_classification/ResT) for the 3rd party implementation on Paddle framework! 111 | -------------------------------------------------------------------------------- /TRAINING.md: -------------------------------------------------------------------------------- 1 | # Training 2 | 3 | We provide ImageNet-1K training, and fine-tuning commands here. 4 | Please check [INSTALL.md](INSTALL.md) for installation instructions first. 5 | 6 | ## ImageNet-1K Training 7 | Training on ImageNet-1K on a single machine: 8 | ``` 9 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 10 | --model restv2_tiny --drop_path 0.1 \ 11 | --clip_grad 1.0 --warmup_epochs 50 --epochs 300 \ 12 | --batch_size 256 --lr 1.5e-4 --update_freq 1 \ 13 | --model_ema true --model_ema_eval true \ 14 | --data_path /path/to/imagenet-1k 15 | --output_dir /path/to/save_results 16 | ``` 17 | 18 | ## ImageNet-1K Fine-tuning 19 | ### Finetune from ImageNet-1K pre-training 20 | The training commands given above for ImageNet-1K use the default resolution (224). We also fine-tune these trained models with a larger resolution (384). Please specify the path or url to the checkpoint in `--finetune`. 21 | 22 | Fine-tuning on ImageNet-1K (384x384): 23 | 24 | Single-machine 25 | ``` 26 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 27 | --model restv2_tiny --drop_path 0.1 --input_size 384 \ 28 | --batch_size 64 --lr 1.5e-5 --update_freq 1 \ 29 | --warmup_epochs 0 --epochs 30 --weight_decay 1e-8 \ 30 | --cutmix 0 --mixup 0 --clip_grad 1.0 \ 31 | --finetune /path/to/checkpoint.pth \ 32 | --data_path /path/to/imagenet-1k \ 33 | --output_dir /path/to/save_results 34 | ``` 35 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 
3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import os 8 | from torchvision import datasets, transforms 9 | 10 | from timm.data.constants import \ 11 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 12 | from timm.data import create_transform 13 | 14 | 15 | def build_dataset(is_train, args): 16 | transform = build_transform(is_train, args) 17 | 18 | print("Transform = ") 19 | if isinstance(transform, tuple): 20 | for trans in transform: 21 | print(" - - - - - - - - - - ") 22 | for t in trans.transforms: 23 | print(t) 24 | else: 25 | for t in transform.transforms: 26 | print(t) 27 | print("---------------------------") 28 | 29 | if args.data_set == 'CIFAR': 30 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 31 | nb_classes = 100 32 | elif args.data_set == 'IMNET': 33 | print("reading from datapath", args.data_path) 34 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 35 | dataset = datasets.ImageFolder(root, transform=transform) 36 | nb_classes = 1000 37 | elif args.data_set == "image_folder": 38 | root = args.data_path if is_train else args.eval_data_path 39 | dataset = datasets.ImageFolder(root, transform=transform) 40 | nb_classes = args.nb_classes 41 | assert len(dataset.class_to_idx) == nb_classes 42 | else: 43 | raise NotImplementedError() 44 | print("Number of the class = %d" % nb_classes) 45 | 46 | return dataset, nb_classes 47 | 48 | 49 | def build_transform(is_train, args): 50 | resize_im = args.input_size > 32 51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 54 | 55 | if is_train: 56 | # this should always dispatch to transforms_imagenet_train 57 | transform = create_transform( 58 | input_size=args.input_size, 59 | is_training=True, 60 | color_jitter=args.color_jitter, 61 | auto_augment=args.aa, 62 | interpolation=args.train_interpolation, 63 | re_prob=args.reprob, 64 | re_mode=args.remode, 65 | re_count=args.recount, 66 | mean=mean, 67 | std=std, 68 | ) 69 | if not resize_im: 70 | transform.transforms[0] = transforms.RandomCrop( 71 | args.input_size, padding=4) 72 | return transform 73 | 74 | t = [] 75 | if resize_im: 76 | # warping (no cropping) when evaluated at 384 or larger 77 | if args.input_size >= 384: 78 | t.append( 79 | transforms.Resize( 80 | (args.input_size, args.input_size), 81 | interpolation=transforms.InterpolationMode.BICUBIC), 82 | ) 83 | print(f"Warping {args.input_size} size input images...") 84 | else: 85 | if args.crop_pct is None: 86 | args.crop_pct = 224 / 256 87 | size = int(args.input_size / args.crop_pct) 88 | t.append( 89 | # to maintain same ratio w.r.t. 
224 images 90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 91 | ) 92 | t.append(transforms.CenterCrop(args.input_size)) 93 | 94 | t.append(transforms.ToTensor()) 95 | t.append(transforms.Normalize(mean, std)) 96 | return transforms.Compose(t) 97 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import math 8 | from typing import Iterable, Optional 9 | import torch 10 | from timm.data import Mixup 11 | from timm.utils import accuracy, ModelEma 12 | 13 | import utils 14 | 15 | 16 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 17 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 18 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 19 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 20 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 21 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False): 22 | model.train(True) 23 | metric_logger = utils.MetricLogger(delimiter=" ") 24 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 25 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | header = 'Epoch: [{}]'.format(epoch) 27 | print_freq = 50 28 | 29 | optimizer.zero_grad() 30 | 31 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 32 | # if data_iter_step > 20: 33 | # break 34 | step = data_iter_step // update_freq 35 | if step >= num_training_steps_per_epoch: 36 | continue 37 | it = start_steps + step # global training iteration 38 | # Update LR & WD for the first acc 39 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 40 | for i, param_group in enumerate(optimizer.param_groups): 41 | if lr_schedule_values is not None: 42 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 43 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 44 | param_group["weight_decay"] = wd_schedule_values[it] 45 | 46 | samples = samples.to(device, non_blocking=True) 47 | targets = targets.to(device, non_blocking=True) 48 | 49 | if mixup_fn is not None: 50 | samples, targets = mixup_fn(samples, targets) 51 | 52 | if use_amp: 53 | with torch.cuda.amp.autocast(): 54 | output = model(samples) 55 | loss = criterion(output, targets) 56 | else: # full precision 57 | output = model(samples) 58 | loss = criterion(output, targets) 59 | 60 | loss_value = loss.item() 61 | 62 | if not math.isfinite(loss_value): # this could trigger if using AMP 63 | print("Loss is {}, stopping training".format(loss_value)) 64 | assert math.isfinite(loss_value) 65 | 66 | if use_amp: 67 | # this attribute is added by timm on one optimizer (adahessian) 68 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 69 | loss /= update_freq 70 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 71 | parameters=model.parameters(), create_graph=is_second_order, 72 | update_grad=(data_iter_step + 1) % update_freq == 
0) 73 | if (data_iter_step + 1) % update_freq == 0: 74 | optimizer.zero_grad() 75 | if model_ema is not None: 76 | model_ema.update(model) 77 | else: # full precision 78 | loss /= update_freq 79 | loss.backward() 80 | if (data_iter_step + 1) % update_freq == 0: 81 | optimizer.step() 82 | optimizer.zero_grad() 83 | if model_ema is not None: 84 | model_ema.update(model) 85 | 86 | torch.cuda.synchronize() 87 | 88 | if mixup_fn is None: 89 | class_acc = (output.max(-1)[-1] == targets).float().mean() 90 | else: 91 | class_acc = None 92 | metric_logger.update(loss=loss_value) 93 | metric_logger.update(class_acc=class_acc) 94 | min_lr = 10. 95 | max_lr = 0. 96 | for group in optimizer.param_groups: 97 | min_lr = min(min_lr, group["lr"]) 98 | max_lr = max(max_lr, group["lr"]) 99 | 100 | metric_logger.update(lr=max_lr) 101 | metric_logger.update(min_lr=min_lr) 102 | weight_decay_value = None 103 | for group in optimizer.param_groups: 104 | if group["weight_decay"] > 0: 105 | weight_decay_value = group["weight_decay"] 106 | metric_logger.update(weight_decay=weight_decay_value) 107 | if use_amp: 108 | metric_logger.update(grad_norm=grad_norm) 109 | 110 | if log_writer is not None: 111 | log_writer.update(loss=loss_value, head="loss") 112 | log_writer.update(class_acc=class_acc, head="loss") 113 | log_writer.update(lr=max_lr, head="opt") 114 | log_writer.update(min_lr=min_lr, head="opt") 115 | log_writer.update(weight_decay=weight_decay_value, head="opt") 116 | if use_amp: 117 | log_writer.update(grad_norm=grad_norm, head="opt") 118 | log_writer.set_step() 119 | 120 | # gather the stats from all processes 121 | metric_logger.synchronize_between_processes() 122 | print("Averaged stats:", metric_logger) 123 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 124 | 125 | 126 | @torch.no_grad() 127 | def evaluate(data_loader, model, device, use_amp=False): 128 | criterion = torch.nn.CrossEntropyLoss() 129 | 130 | metric_logger = utils.MetricLogger(delimiter=" ") 131 | header = 'Test:' 132 | 133 | # switch to evaluation mode 134 | model.eval() 135 | i = 0 136 | for batch in metric_logger.log_every(data_loader, 10, header): 137 | i += 1 138 | images = batch[0] 139 | target = batch[-1] 140 | 141 | images = images.to(device, non_blocking=True) 142 | target = target.to(device, non_blocking=True) 143 | 144 | # compute output 145 | if use_amp: 146 | with torch.cuda.amp.autocast(): 147 | output = model(images) 148 | loss = criterion(output, target) 149 | else: 150 | output = model(images) 151 | loss = criterion(output, target) 152 | 153 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 154 | 155 | batch_size = images.shape[0] 156 | metric_logger.update(loss=loss.item()) 157 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 158 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 159 | # gather the stats from all processes 160 | metric_logger.synchronize_between_processes() 161 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 162 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 163 | 164 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 165 | -------------------------------------------------------------------------------- /figures/fig_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wofmanaf/ResT/b9245d3f335e6c311798ce17b31c64990ee6c0e7/figures/fig_1.png 
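A note on `engine.py` above: the `update_freq` argument implements plain gradient accumulation, i.e. the loss is scaled by `1/update_freq` and the optimizer steps only every `update_freq` iterations, which emulates a larger effective batch size. A stripped-down sketch with a toy model, following the full-precision branch (no AMP / loss scaler):
```
# minimal illustration of the accumulation pattern used in train_one_epoch
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
update_freq = 4  # effective batch = update_freq * per-step batch

optimizer.zero_grad()
for step in range(16):
    samples, targets = torch.randn(8, 10), torch.randint(0, 2, (8,))
    loss = criterion(model(samples), targets) / update_freq  # scale as in engine.py
    loss.backward()
    if (step + 1) % update_freq == 0:  # same step condition as engine.py
        optimizer.step()
        optimizer.zero_grad()
```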
-------------------------------------------------------------------------------- /fourier_analysis.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import matplotlib.colors 4 | import matplotlib.pyplot as plt 5 | from matplotlib import rcParams 6 | import matplotlib.cm as cm 7 | from matplotlib.collections import LineCollection 8 | from matplotlib.ticker import FormatStrFormatter 9 | from matplotlib.pyplot import MultipleLocator 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data.sampler import Sampler 14 | import torchvision.datasets as datasets 15 | import torchvision.transforms as transforms 16 | from models.fft_rest_v2 import restv2_tiny 17 | from tqdm import tqdm 18 | 19 | config = {"font.family": "sans-serif", "font.size": 16, "mathtext.fontset": 'stix', 20 | 'font.sans-serif': ['Times New Roman'], } 21 | rcParams.update(config) 22 | 23 | 24 | # Sampler for pytorch loader. Given range r loader will only 25 | # return dataset[r] instead of whole dataset. 26 | class RangeSampler(Sampler): 27 | def __init__(self, r): 28 | self.r = r 29 | 30 | def __iter__(self): 31 | return iter(self.r) 32 | 33 | def __len__(self): 34 | return len(self.r) 35 | 36 | 37 | def fourier(x): # 2D Fourier transform 38 | f = torch.fft.fft2(x) 39 | f = f.abs() + 1e-6 40 | f = f.log() 41 | return f 42 | 43 | 44 | def shift(x): # shift Fourier transformed feature map 45 | b, c, h, w = x.shape 46 | return torch.roll(x, shifts=(int(h / 2), int(w / 2)), dims=(2, 3)) 47 | 48 | 49 | def make_segments(x, y): # make segment for `plot_segment` 50 | points = np.array([x, y]).T.reshape(-1, 1, 2) 51 | segments = np.concatenate([points[:-1], points[1:]], axis=1) 52 | return segments 53 | 54 | 55 | def plot_segment(fig, ax, xs, ys, cmap_name="plasma", marker="o"): # plot with cmap segments 56 | z = np.linspace(0.0, 1.0, len(ys)) 57 | z = np.asarray(z) 58 | 59 | cmap = cm.get_cmap(cmap_name) 60 | norm = plt.Normalize(0.0, 1.0) 61 | segments = make_segments(xs, ys) 62 | lc = LineCollection(segments, array=z, cmap=cmap_name, norm=norm, 63 | linewidth=2.5, alpha=1.0) 64 | ax.add_collection(lc) 65 | 66 | colors = [cmap(x) for x in xs] 67 | sc = ax.scatter(xs, ys, color=colors, marker=marker, zorder=100) 68 | fig.colorbar(sc, ticks=[0, 0.1, 1.]) 69 | 70 | 71 | def plot(latents, name=''): 72 | # latents: list of hidden feature maps in the latent space 73 | fig, ax = plt.subplots(1, 1, figsize=(3.6, 4), dpi=300) 74 | fig.set_tight_layout(True) 75 | 76 | # Fourier transform feature maps 77 | fourier_latents = [] 78 | for latent in latents: 79 | if len(latent.shape) == 3: # for vit 80 | b, n, c = latent.shape 81 | h, w = int(math.sqrt(n)), int(math.sqrt(n)) 82 | latent = latent.permute(0, 2, 1).reshape(b, c, h, w) 83 | elif len(latent.shape) == 4: # for cnn 84 | b, c, h, w = latent.shape 85 | else: 86 | raise Exception("shape: %s" % str(latent.shape)) 87 | latent = fourier(latent) 88 | latent = shift(latent).mean(dim=(0, 1)) 89 | latent = latent.diag()[int(h / 2):] # only use the half-diagonal components 90 | latent = latent - latent[0] # visualize 'relative' log amplitudes (i.e., low_freq - high_freq) 91 | fourier_latents.append(latent) 92 | # plot fourier transformed relative log amplitudes 93 | for i, latent in enumerate(reversed(fourier_latents)): 94 | freq = np.linspace(0, 1, len(latent)) 95 | ax.plot(freq, latent, color=cm.plasma_r(i / len(fourier_latents))) 96 | 97 | ax.set_xlim(left=0, right=1) 98 | 
x_major_locator = MultipleLocator(0.5) 99 | y_major_locator = MultipleLocator(2.0) 100 | 101 | ax.set_xlabel("Frequency") 102 | ax.set_ylabel(r"$\Delta$ Log amplitude") 103 | 104 | ax.xaxis.set_major_locator(x_major_locator) 105 | ax.xaxis.set_major_formatter(FormatStrFormatter('%.1fπ')) 106 | 107 | ax.yaxis.set_major_locator(y_major_locator) 108 | plt.savefig(name) 109 | # plt.show() 110 | 111 | 112 | def main(): 113 | model = restv2_tiny().cuda() 114 | checkpoint = torch.load("output_dir/restv2_tiny_224.pth", map_location='cpu') 115 | model.load_state_dict(checkpoint['model']) 116 | model = model.eval() 117 | 118 | val_dir = '/data/ilsvrc2012/val/' 119 | batch_size = 1 120 | num_workers = 4 121 | batch = 0 122 | 123 | sample_range = range(10000 * batch, 10000 * (batch + 1)) 124 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 125 | std=[0.229, 0.224, 0.225]) 126 | val_loader = DataLoader( 127 | datasets.ImageFolder(val_dir, transforms.Compose([ 128 | transforms.Resize(256), 129 | transforms.CenterCrop(224), 130 | transforms.ToTensor(), 131 | normalize, 132 | ])), 133 | batch_size=batch_size, 134 | shuffle=False, 135 | num_workers=num_workers, 136 | pin_memory=True, 137 | sampler=RangeSampler(sample_range) 138 | ) 139 | 140 | latents_attn = [] 141 | latents_up = [] 142 | latents_com = [] 143 | 144 | for i, (samples, targets) in enumerate(tqdm(val_loader, total=len(val_loader), desc='Loading Images')): 145 | samples = samples.cuda() 146 | with torch.no_grad(): 147 | feats = model(samples) 148 | 149 | for j in range(11): 150 | if i == 0: 151 | latents_attn.append(feats[j]["attn"].cpu()) 152 | latents_up.append(feats[j]["up"].cpu()) 153 | latents_com.append(feats[j]["com"].cpu()) 154 | else: 155 | latents_attn[j] = torch.cat([latents_attn[j], feats[j]["attn"].cpu()], dim=0) 156 | latents_up[j] = torch.cat([latents_up[j], feats[j]["up"].cpu()], dim=0) 157 | latents_com[j] = torch.cat([latents_com[j], feats[j]["com"].cpu()], dim=0) 158 | plot(latents_attn, name="attn.png") 159 | plot(latents_up, name="up.png") 160 | plot(latents_com, name="combine.png") 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /get_flops.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 
3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import argparse 8 | import torch 9 | from timm.models import create_model 10 | import models 11 | 12 | try: 13 | from mmcv.cnn import get_model_complexity_info 14 | from mmcv.cnn.utils.flops_counter import get_model_complexity_info, flops_to_string, params_to_string 15 | except ImportError: 16 | raise ImportError('Please upgrade mmcv to >0.6.2') 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='Get FLOPS of a classification model') 21 | parser.add_argument('--model', default='restv2_tiny', type=str, metavar='MODEL', 22 | help='train config file path') 23 | parser.add_argument( 24 | '--shape', 25 | type=int, 26 | nargs='+', 27 | default=[224, 224], 28 | help='input image size') 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def attn_flops(h, w, r, dim): 34 | return 2 * h * w * (h // r) * (w // r) * dim 35 | 36 | 37 | def get_flops(model, input_shape): 38 | flops, params = get_model_complexity_info(model, input_shape, as_strings=False) 39 | if 'rest' in model.name: 40 | _, H, W = input_shape 41 | # calculate flops of ResTv2 42 | stage1 = attn_flops(H // 4, W // 4, 43 | model.stage1[0].attn.sr_ratio, 44 | model.stage1[0].attn.dim) * model.depths[0] 45 | stage2 = attn_flops(H // 8, W // 8, 46 | model.stage2[0].attn.sr_ratio, 47 | model.stage2[0].attn.dim) * model.depths[1] 48 | stage3 = attn_flops(H // 16, W // 16, 49 | model.stage3[0].attn.sr_ratio, 50 | model.stage3[0].attn.dim) * model.depths[2] 51 | stage4 = attn_flops(H // 32, W // 32, 52 | model.stage4[0].attn.sr_ratio, 53 | model.stage4[0].attn.dim) * model.depths[3] 54 | flops += stage1 + stage2 + stage3 + stage4 55 | return flops_to_string(flops), params_to_string(params) 56 | 57 | 58 | def main(): 59 | args = parse_args() 60 | 61 | if len(args.shape) == 1: 62 | input_shape = (3, args.shape[0], args.shape[0]) 63 | elif len(args.shape) == 2: 64 | input_shape = (3,) + tuple(args.shape) 65 | else: 66 | raise ValueError('invalid input shape') 67 | 68 | model = create_model( 69 | args.model, 70 | pretrained=False, 71 | num_classes=1000 72 | ) 73 | model.name = args.model 74 | if torch.cuda.is_available(): 75 | model.cuda() 76 | model.eval() 77 | 78 | flops, params = get_flops(model, input_shape) 79 | 80 | split_line = '=' * 30 81 | print(f'{split_line}\nInput shape: {input_shape}\n' 82 | f'Flops: {flops}\nParams: {params}\n{split_line}') 83 | print('!!!Please be cautious if you use the results in papers. ' 84 | 'You may need to check if all ops are supported and verify that the ' 85 | 'flops computation is correct.') 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .rest import * 2 | from .rest_v2 import * 3 | -------------------------------------------------------------------------------- /models/rest.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 
3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from timm.models.registry import register_model 12 | 13 | __all__ = ['rest_lite', 'rest_small', 'rest_base', 'rest_large'] 14 | 15 | 16 | def _cfg(url='', **kwargs): 17 | return { 18 | 'url': url, 19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 20 | 'crop_pct': .9, 'interpolation': 'bicubic', 21 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 22 | 'classifier': 'head', 23 | **kwargs 24 | } 25 | 26 | 27 | default_cfgs = { 28 | 'rest_lite': _cfg(), 29 | 'rest_small': _cfg(), 30 | 'rest_base': _cfg(), 31 | 'rest_large': _cfg(), 32 | } 33 | 34 | 35 | class Mlp(nn.Module): 36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 37 | super().__init__() 38 | out_features = out_features or in_features 39 | hidden_features = hidden_features or in_features 40 | self.fc1 = nn.Linear(in_features, hidden_features) 41 | self.act = act_layer() 42 | self.fc2 = nn.Linear(hidden_features, out_features) 43 | self.drop = nn.Dropout(drop) 44 | 45 | def forward(self, x): 46 | x = self.fc1(x) 47 | x = self.act(x) 48 | x = self.drop(x) 49 | x = self.fc2(x) 50 | x = self.drop(x) 51 | return x 52 | 53 | 54 | class Attention(nn.Module): 55 | def __init__(self, 56 | dim, 57 | num_heads=8, 58 | qkv_bias=False, 59 | qk_scale=None, 60 | attn_drop=0., 61 | proj_drop=0., 62 | sr_ratio=1, 63 | apply_transform=False): 64 | super().__init__() 65 | self.num_heads = num_heads 66 | head_dim = dim // num_heads 67 | self.scale = qk_scale or head_dim ** -0.5 68 | 69 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 70 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 71 | self.attn_drop = nn.Dropout(attn_drop) 72 | self.proj = nn.Linear(dim, dim) 73 | self.proj_drop = nn.Dropout(proj_drop) 74 | 75 | self.sr_ratio = sr_ratio 76 | if sr_ratio > 1: 77 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim) 78 | self.sr_norm = nn.LayerNorm(dim) 79 | 80 | self.apply_transform = apply_transform and num_heads > 1 81 | if self.apply_transform: 82 | self.transform_conv = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, stride=1) 83 | self.transform_norm = nn.InstanceNorm2d(self.num_heads) 84 | 85 | def forward(self, x, H, W): 86 | B, N, C = x.shape 87 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 88 | if self.sr_ratio > 1: 89 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 90 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 91 | x_ = self.sr_norm(x_) 92 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 93 | else: 94 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 95 | k, v = kv[0], kv[1] 96 | 97 | attn = (q @ k.transpose(-2, -1)) * self.scale 98 | if self.apply_transform: 99 | attn = self.transform_conv(attn) 100 | attn = attn.softmax(dim=-1) 101 | attn = self.transform_norm(attn) 102 | else: 103 | attn = attn.softmax(dim=-1) 104 | 105 | attn = self.attn_drop(attn) 106 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 107 | x = self.proj(x) 108 | x = self.proj_drop(x) 109 | return x 110 | 111 | 112 | class Block(nn.Module): 113 | def __init__(self, dim, num_heads, 
mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 114 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, apply_transform=False): 115 | super().__init__() 116 | self.norm1 = norm_layer(dim) 117 | self.attn = Attention( 118 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 119 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, apply_transform=apply_transform) 120 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 121 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 122 | self.norm2 = norm_layer(dim) 123 | mlp_hidden_dim = int(dim * mlp_ratio) 124 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 125 | 126 | def forward(self, x, H, W): 127 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 128 | x = x + self.drop_path(self.mlp(self.norm2(x))) 129 | return x 130 | 131 | 132 | class PA(nn.Module): 133 | def __init__(self, dim): 134 | super().__init__() 135 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 136 | self.sigmoid = nn.Sigmoid() 137 | 138 | def forward(self, x): 139 | return x * self.sigmoid(self.pa_conv(x)) 140 | 141 | 142 | class GL(nn.Module): 143 | def __init__(self, dim): 144 | super().__init__() 145 | self.gl_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 146 | 147 | def forward(self, x): 148 | return x + self.gl_conv(x) 149 | 150 | 151 | class PatchEmbed(nn.Module): 152 | """ Image to Patch Embedding""" 153 | 154 | def __init__(self, patch_size=16, in_ch=3, out_ch=768, with_pos=False): 155 | super().__init__() 156 | self.patch_size = to_2tuple(patch_size) 157 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size, padding=patch_size // 2) 158 | self.norm = nn.BatchNorm2d(out_ch) 159 | 160 | self.with_pos = with_pos 161 | if self.with_pos: 162 | self.pos = PA(out_ch) 163 | 164 | def forward(self, x): 165 | B, C, H, W = x.shape 166 | x = self.conv(x) 167 | x = self.norm(x) 168 | if self.with_pos: 169 | x = self.pos(x) 170 | x = x.flatten(2).transpose(1, 2) 171 | H, W = H // self.patch_size[0], W // self.patch_size[1] 172 | return x, (H, W) 173 | 174 | 175 | class BasicStem(nn.Module): 176 | def __init__(self, in_ch=3, out_ch=64, with_pos=False): 177 | super(BasicStem, self).__init__() 178 | hidden_ch = out_ch // 2 179 | self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, stride=2, padding=1, bias=False) 180 | self.norm1 = nn.BatchNorm2d(hidden_ch) 181 | self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, stride=1, padding=1, bias=False) 182 | self.norm2 = nn.BatchNorm2d(hidden_ch) 183 | self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False) 184 | 185 | self.act = nn.ReLU(inplace=True) 186 | self.with_pos = with_pos 187 | if self.with_pos: 188 | self.pos = PA(out_ch) 189 | 190 | def forward(self, x): 191 | x = self.conv1(x) 192 | x = self.norm1(x) 193 | x = self.act(x) 194 | 195 | x = self.conv2(x) 196 | x = self.norm2(x) 197 | x = self.act(x) 198 | 199 | x = self.conv3(x) 200 | if self.with_pos: 201 | x = self.pos(x) 202 | return x 203 | 204 | 205 | class Stem(nn.Module): 206 | def __init__(self, in_ch=3, out_ch=64, with_pos=False): 207 | super(Stem, self).__init__() 208 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=7, stride=2, padding=3, bias=False) 209 | self.norm = nn.BatchNorm2d(out_ch) 210 | self.act = nn.ReLU(inplace=True) 211 | 212 | self.max_pool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 213 | self.with_pos = with_pos 214 | if self.with_pos: 215 | self.pos = PA(out_ch) 216 | 217 | def forward(self, x): 218 | x = self.conv(x) 219 | x = self.norm(x) 220 | x = self.act(x) 221 | x = self.max_pool(x) 222 | 223 | if self.with_pos: 224 | x = self.pos(x) 225 | return x 226 | 227 | 228 | class ResT(nn.Module): 229 | def __init__(self, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 230 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, 231 | qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 232 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], 233 | norm_layer=nn.LayerNorm, apply_transform=False): 234 | super().__init__() 235 | self.num_classes = num_classes 236 | self.depths = depths 237 | self.apply_transform = apply_transform 238 | 239 | self.stem = BasicStem(in_ch=in_chans, out_ch=embed_dims[0], with_pos=True) 240 | 241 | self.patch_embed_2 = PatchEmbed(patch_size=2, in_ch=embed_dims[0], out_ch=embed_dims[1], with_pos=True) 242 | self.patch_embed_3 = PatchEmbed(patch_size=2, in_ch=embed_dims[1], out_ch=embed_dims[2], with_pos=True) 243 | self.patch_embed_4 = PatchEmbed(patch_size=2, in_ch=embed_dims[2], out_ch=embed_dims[3], with_pos=True) 244 | 245 | # transformer encoder 246 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 247 | cur = 0 248 | 249 | self.stage1 = nn.ModuleList([ 250 | Block(embed_dims[0], num_heads[0], mlp_ratios[0], qkv_bias, qk_scale, drop_rate, attn_drop_rate, 251 | drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[0], apply_transform=apply_transform) 252 | for i in range(self.depths[0])]) 253 | 254 | cur += depths[0] 255 | self.stage2 = nn.ModuleList([ 256 | Block(embed_dims[1], num_heads[1], mlp_ratios[1], qkv_bias, qk_scale, drop_rate, attn_drop_rate, 257 | drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[1], apply_transform=apply_transform) 258 | for i in range(self.depths[1])]) 259 | 260 | cur += depths[1] 261 | self.stage3 = nn.ModuleList([ 262 | Block(embed_dims[2], num_heads[2], mlp_ratios[2], qkv_bias, qk_scale, drop_rate, attn_drop_rate, 263 | drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[2], apply_transform=apply_transform) 264 | for i in range(self.depths[2])]) 265 | 266 | cur += depths[2] 267 | self.stage4 = nn.ModuleList([ 268 | Block(embed_dims[3], num_heads[3], mlp_ratios[3], qkv_bias, qk_scale, drop_rate, attn_drop_rate, 269 | drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[3], apply_transform=apply_transform) 270 | for i in range(self.depths[3])]) 271 | 272 | self.norm = norm_layer(embed_dims[3]) 273 | 274 | # classification head 275 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 276 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 277 | 278 | # init weights 279 | self.apply(self._init_weights) 280 | 281 | def _init_weights(self, m): 282 | if isinstance(m, nn.Conv2d): 283 | trunc_normal_(m.weight, std=0.02) 284 | elif isinstance(m, nn.Linear): 285 | trunc_normal_(m.weight, std=0.02) 286 | if m.bias is not None: 287 | nn.init.constant_(m.bias, 0) 288 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 289 | nn.init.constant_(m.weight, 1.0) 290 | nn.init.constant_(m.bias, 0) 291 | 292 | def forward(self, x): 293 | x = self.stem(x) 294 | B, _, H, W = x.shape 295 | x = x.flatten(2).permute(0, 2, 1) 296 | 297 | # stage 1 298 | for blk in self.stage1: 299 | x = blk(x, H, W) 300 | x = x.permute(0, 2, 1).reshape(B, -1, 
H, W) 301 | 302 | # stage 2 303 | x, (H, W) = self.patch_embed_2(x) 304 | for blk in self.stage2: 305 | x = blk(x, H, W) 306 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 307 | 308 | # stage 3 309 | x, (H, W) = self.patch_embed_3(x) 310 | for blk in self.stage3: 311 | x = blk(x, H, W) 312 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 313 | 314 | # stage 4 315 | x, (H, W) = self.patch_embed_4(x) 316 | for blk in self.stage4: 317 | x = blk(x, H, W) 318 | x = self.norm(x) 319 | 320 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 321 | x = self.avg_pool(x).flatten(1) 322 | x = self.head(x) 323 | return x 324 | 325 | 326 | @register_model 327 | def rest_lite(pretrained=False, **kwargs): 328 | model = ResT(embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 329 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], apply_transform=True, **kwargs) 330 | model.default_cfg = _cfg() 331 | return model 332 | 333 | 334 | @register_model 335 | def rest_small(pretrained=False, **kwargs): 336 | model = ResT(embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 337 | depths=[2, 2, 6, 2], sr_ratios=[8, 4, 2, 1], apply_transform=True, **kwargs) 338 | model.default_cfg = _cfg() 339 | return model 340 | 341 | 342 | @register_model 343 | def rest_base(pretrained=False, **kwargs): 344 | model = ResT(embed_dims=[96, 192, 384, 768], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 345 | depths=[2, 2, 6, 2], sr_ratios=[8, 4, 2, 1], apply_transform=True, **kwargs) 346 | model.default_cfg = _cfg() 347 | return model 348 | 349 | 350 | @register_model 351 | def rest_large(pretrained=False, **kwargs): 352 | model = ResT(embed_dims=[96, 192, 384, 768], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True, 353 | depths=[2, 2, 18, 2], sr_ratios=[8, 4, 2, 1], apply_transform=True, **kwargs) 354 | model.default_cfg = _cfg() 355 | return model 356 | -------------------------------------------------------------------------------- /models/rest_v2.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 
3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from timm.models.registry import register_model 12 | 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | self.fc1 = nn.Linear(dim, 4 * dim) 18 | self.act = nn.GELU() 19 | self.fc2 = nn.Linear(4 * dim, dim) 20 | 21 | def forward(self, x): 22 | x = self.fc1(x) 23 | x = self.act(x) 24 | x = self.fc2(x) 25 | return x 26 | 27 | 28 | class Attention(nn.Module): 29 | def __init__(self, 30 | dim, 31 | num_heads=8, 32 | sr_ratio=1): 33 | super().__init__() 34 | self.num_heads = num_heads 35 | head_dim = dim // num_heads 36 | self.scale = head_dim ** -0.5 37 | self.dim = dim 38 | 39 | self.q = nn.Linear(dim, dim, bias=True) 40 | self.kv = nn.Linear(dim, dim * 2, bias=True) 41 | 42 | self.sr_ratio = sr_ratio 43 | if sr_ratio > 1: 44 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim) 45 | self.sr_norm = nn.LayerNorm(dim, eps=1e-6) 46 | 47 | self.up = nn.Sequential( 48 | nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1, padding=1, groups=dim), 49 | nn.PixelShuffle(upscale_factor=sr_ratio) 50 | ) 51 | self.up_norm = nn.LayerNorm(dim, eps=1e-6) 52 | 53 | self.proj = nn.Linear(dim, dim) 54 | 55 | def forward(self, x, H, W): 56 | B, N, C = x.shape 57 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 58 | if self.sr_ratio > 1: 59 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 60 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) 61 | x = self.sr_norm(x) 62 | 63 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 64 | k, v = kv[0], kv[1] 65 | attn = (q @ k.transpose(-2, -1)) * self.scale 66 | attn = attn.softmax(dim=-1) 67 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 68 | 69 | identity = v.transpose(-1, -2).reshape(B, C, H // self.sr_ratio, W // self.sr_ratio) 70 | identity = self.up(identity).flatten(2).transpose(1, 2) 71 | x = self.proj(x + self.up_norm(identity)) 72 | return x 73 | 74 | 75 | class Block(nn.Module): 76 | def __init__(self, dim, num_heads, sr_ratio=1, drop_path=0.): 77 | super().__init__() 78 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 79 | self.attn = Attention(dim, num_heads=num_heads, sr_ratio=sr_ratio) 80 | 81 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 82 | self.mlp = Mlp(dim) 83 | 84 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 85 | 86 | def forward(self, x, H, W): 87 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) # pre_norm 88 | x = x + self.drop_path(self.mlp(self.norm2(x))) 89 | return x 90 | 91 | 92 | class PA(nn.Module): 93 | def __init__(self, dim): 94 | super().__init__() 95 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=True) 96 | self.sigmoid = nn.Sigmoid() 97 | 98 | def forward(self, x): 99 | return x * self.sigmoid(self.pa_conv(x)) 100 | 101 | 102 | class Stem(nn.Module): 103 | def __init__(self, in_dim=3, out_dim=96, patch_size=2): 104 | super().__init__() 105 | self.patch_size = to_2tuple(patch_size) 106 | self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=patch_size, stride=patch_size) 107 | self.norm = nn.LayerNorm(out_dim, eps=1e-6) 108 | 109 | def forward(self, x): 110 | B, C, H, W = x.shape 111 | x = self.proj(x) 112 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 113 | x = self.norm(x) 114 | H, W = H // self.patch_size[0], W // self.patch_size[1] 115 | return x, (H, W) 116 | 117 | 118 | class ConvStem(nn.Module): 119 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True): 120 | super().__init__() 121 | self.patch_size = to_2tuple(patch_size) 122 | stem = [] 123 | in_dim, out_dim = in_ch, out_ch // 2 124 | for i in range(2): 125 | stem.append(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False)) 126 | stem.append(nn.BatchNorm2d(out_dim)) 127 | stem.append(nn.ReLU(inplace=True)) 128 | in_dim, out_dim = out_dim, out_dim * 2 129 | 130 | stem.append(nn.Conv2d(in_dim, out_ch, kernel_size=1, stride=1)) 131 | self.proj = nn.Sequential(*stem) 132 | 133 | self.with_pos = with_pos 134 | if self.with_pos: 135 | self.pos = PA(out_ch) 136 | 137 | self.norm = nn.LayerNorm(out_ch, eps=1e-6) 138 | 139 | def forward(self, x): 140 | B, C, H, W = x.shape 141 | x = self.proj(x) 142 | if self.with_pos: 143 | x = self.pos(x) 144 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 145 | x = self.norm(x) 146 | H, W = H // self.patch_size[0], W // self.patch_size[1] 147 | return x, (H, W) 148 | 149 | 150 | class PatchEmbed(nn.Module): 151 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True): 152 | super().__init__() 153 | self.patch_size = to_2tuple(patch_size) 154 | self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size, padding=patch_size // 2) 155 | 156 | self.with_pos = with_pos 157 | if self.with_pos: 158 | self.pos = PA(out_ch) 159 | 160 | self.norm = nn.LayerNorm(out_ch, eps=1e-6) 161 | 162 | def forward(self, x): 163 | B, C, H, W = x.shape 164 | x = self.proj(x) 165 | if self.with_pos: 166 | x = self.pos(x) 167 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 168 | x = self.norm(x) 169 | H, W = H // self.patch_size[0], W // self.patch_size[1] 170 | return x, (H, W) 171 | 172 | 173 | class ResTV2(nn.Module): 174 | def __init__(self, in_chans=3, num_classes=1000, embed_dims=[96, 192, 384, 768], 175 | num_heads=[1, 2, 4, 8], drop_path_rate=0., 176 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1]): 177 | super().__init__() 178 | self.num_classes = num_classes 179 | self.depths = depths 180 | 181 | self.stem = ConvStem(in_chans, embed_dims[0], patch_size=4) 182 | self.patch_2 = PatchEmbed(embed_dims[0], embed_dims[1], patch_size=2) 183 | self.patch_3 = PatchEmbed(embed_dims[1], embed_dims[2], patch_size=2) 184 | self.patch_4 = PatchEmbed(embed_dims[2], embed_dims[3], patch_size=2) 185 | 186 | # transformer encoder 187 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 
sum(depths))] 188 | cur = 0 189 | self.stage1 = nn.ModuleList([ 190 | Block(embed_dims[0], num_heads[0], sr_ratios[0], dpr[cur + i]) 191 | for i in range(depths[0]) 192 | ]) 193 | 194 | cur += depths[0] 195 | self.stage2 = nn.ModuleList([ 196 | Block(embed_dims[1], num_heads[1], sr_ratios[1], dpr[cur + i]) 197 | for i in range(depths[1]) 198 | ]) 199 | 200 | cur += depths[1] 201 | self.stage3 = nn.ModuleList([ 202 | Block(embed_dims[2], num_heads[2], sr_ratios[2], dpr[cur + i]) 203 | for i in range(depths[2]) 204 | ]) 205 | 206 | cur += depths[2] 207 | self.stage4 = nn.ModuleList([ 208 | Block(embed_dims[3], num_heads[3], sr_ratios[3], dpr[cur + i]) 209 | for i in range(depths[3]) 210 | ]) 211 | 212 | self.norm = nn.LayerNorm(embed_dims[-1], eps=1e-6) # final norm layer 213 | # classification head 214 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 215 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 216 | 217 | # init weights 218 | self.apply(self._init_weights) 219 | 220 | def _init_weights(self, m): 221 | if isinstance(m, (nn.Conv2d, nn.Linear)): 222 | trunc_normal_(m.weight, std=0.02) 223 | if m.bias is not None: 224 | nn.init.constant_(m.bias, 0) 225 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 226 | nn.init.constant_(m.weight, 1.) 227 | nn.init.constant_(m.bias, 0.) 228 | 229 | def forward(self, x): 230 | B, _, H, W = x.shape 231 | x, (H, W) = self.stem(x) 232 | 233 | # stage 1 234 | for blk in self.stage1: 235 | x = blk(x, H, W) 236 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 237 | 238 | # stage 2 239 | x, (H, W) = self.patch_2(x) 240 | for blk in self.stage2: 241 | x = blk(x, H, W) 242 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 243 | 244 | # stage 3 245 | x, (H, W) = self.patch_3(x) 246 | for blk in self.stage3: 247 | x = blk(x, H, W) 248 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 249 | 250 | # stage 4 251 | x, (H, W) = self.patch_4(x) 252 | for blk in self.stage4: 253 | x = blk(x, H, W) 254 | x = self.norm(x) 255 | 256 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 257 | x = self.avg_pool(x).flatten(1) 258 | x = self.head(x) 259 | return x 260 | 261 | 262 | @register_model 263 | def restv2_tiny(pretrained=False, **kwargs): # 82.3|4.7G|24M -> |3.92G|30.37M 4.5G|30.33M 264 | model = ResTV2(embed_dims=[96, 192, 384, 768], depths=[1, 2, 6, 2], **kwargs) 265 | return model 266 | 267 | 268 | @register_model 269 | def restv2_small(pretrained=False, **kwargs): # 83.6|7.0G|35M -> |5.78G|40.94M 270 | model = ResTV2(embed_dims=[96, 192, 384, 768], depths=[1, 2, 12, 2], **kwargs) 271 | return model 272 | 273 | 274 | @register_model 275 | def restv2_base(pretrained=False, **kwargs): # 84.4|10.2G|52M -> |7.25G|55.75M 276 | model = ResTV2(embed_dims=[96, 192, 384, 768], depths=[1, 3, 16, 3], **kwargs) 277 | return model 278 | 279 | 280 | @register_model 281 | def restv2_large(pretrained=False, **kwargs): # 85.3|39.6|218M -> |14.09G|98.61M 282 | model = ResTV2(num_heads=[2, 4, 8, 16], embed_dims=[128, 256, 512, 1024], depths=[2, 3, 16, 2], **kwargs) 283 | return model -------------------------------------------------------------------------------- /object_detection/GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | ## Getting Started with Detectron2 2 | 3 | This document provides a brief intro of the usage of builtin command-line tools in detectron2. 
4 | 
5 | For a tutorial that involves actual coding with the API,
6 | see our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
7 | which covers how to run inference with an
8 | existing model, and how to train a builtin model on a custom dataset.
9 | 
10 | 
11 | ### Inference Demo with Pre-trained Models
12 | 
13 | 1. Pick a model and its config file from
14 | [model zoo](MODEL_ZOO.md),
15 | for example, `mask_rcnn_R_50_FPN_3x.yaml`.
16 | 2. We provide `demo.py` that is able to demo builtin configs. Run it with:
17 | ```
18 | cd demo/
19 | python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
20 |   --input input1.jpg input2.jpg \
21 |   [--other-options]
22 |   --opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl
23 | ```
24 | The configs are made for training, so we need to point `MODEL.WEIGHTS` to a model from the model zoo for evaluation.
25 | This command will run the inference and show visualizations in an OpenCV window.
26 | 
27 | For details of the command line arguments, see `demo.py -h` or look at its source code
28 | to understand its behavior. Some common arguments are:
29 | * To run __on your webcam__, replace `--input files` with `--webcam`.
30 | * To run __on a video__, replace `--input files` with `--video-input video.mp4`.
31 | * To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`.
32 | * To save outputs to a directory (for images) or a file (for webcam or video), use `--output`.
33 | 
34 | 
35 | ### Training & Evaluation in Command Line
36 | 
37 | We provide two scripts, "tools/plain_train_net.py" and "tools/train_net.py",
38 | that are made to train all the configs provided in detectron2. You may want to
39 | use them as references to write your own training script.
40 | 
41 | Compared to "train_net.py", "plain_train_net.py" supports fewer default
42 | features. It also includes fewer abstractions and is therefore easier to extend with custom
43 | logic.
44 | 
45 | To train a model with "train_net.py", first
46 | set up the corresponding datasets following
47 | [datasets/README.md](./datasets/README.md),
48 | then run:
49 | ```
50 | cd tools/
51 | ./train_net.py --num-gpus 8 \
52 |   --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
53 | ```
54 | 
55 | The configs are made for 8-GPU training.
56 | To train on 1 GPU, you may need to [change some parameters](https://arxiv.org/abs/1706.02677), e.g.:
57 | ```
58 | ./train_net.py \
59 |   --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
60 |   --num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025
61 | ```
62 | 
63 | To evaluate a model's performance, use
64 | ```
65 | ./train_net.py \
66 |   --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
67 |   --eval-only MODEL.WEIGHTS /path/to/checkpoint_file
68 | ```
69 | For more options, see `./train_net.py -h`.
70 | 
71 | ### Use Detectron2 APIs in Your Code
72 | 
73 | See our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
74 | to learn how to use detectron2 APIs to:
75 | 1. run inference with an existing model
76 | 2. train a builtin model on a custom dataset
77 | 
78 | See [detectron2/projects](https://github.com/facebookresearch/detectron2/tree/master/projects)
79 | for more ways to build your project on detectron2.
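As a concrete illustration of point 1 above, a minimal inference sketch with the Python API is shown below. It is only a sketch: the config path, weight file, and image name are placeholders, and the `restv2` import assumes the script is run from this repo's `object_detection` directory.

```
# Minimal inference sketch (paths are placeholders; adapt to your setup).
import cv2
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

from restv2 import add_restv2_config  # registers the ResTv2 config keys and backbone builders

cfg = get_cfg()
add_restv2_config(cfg)
cfg.merge_from_file("configs/ResTv2/mask_rcnn_restv2_tiny_FPN_3x.yaml")
cfg.MODEL.WEIGHTS = "/path/to/checkpoint_file"  # a trained detector checkpoint
cfg.MODEL.DEVICE = "cuda"  # or "cpu"

predictor = DefaultPredictor(cfg)   # handles resizing and BGR/RGB conversion per the config
image = cv2.imread("input1.jpg")    # OpenCV loads images in BGR order
outputs = predictor(image)          # dict with an "instances" field
print(outputs["instances"].pred_classes)
print(outputs["instances"].pred_boxes)
```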
80 | -------------------------------------------------------------------------------- /object_detection/README.md: -------------------------------------------------------------------------------- 1 | # Object detection 2 | ResTv1 and ResTv2 for Object Detection by detectron2 3 | 4 | This repo contains the supported code and configuration files to reproduce object detection results of ResTv1 and ResTv2. It is based on [detectron2](https://github.com/facebookresearch/detectron2). 5 | 6 | 7 | ## Results and Models 8 | 9 | ### RetinaNet 10 | 11 | | Backbone | Pretrain | Lr Schd | box mAP | mask mAP | #params | FPS | config | model | 12 | |:------------:| :---: |:-------:|:-------:|:--------:|:-------:| :---: |:-----------------------------------------------------------:|:--------------------------------------------------------:| 13 | | ResTv1-S-FPN | ImageNet-1K | 1x | 40.3 | - | 23.4 | - | [config](configs/ResTv1/retinanet_rest_small_FPN_1x.yaml) | [baidu](https://pan.baidu.com/s/13YXVRQeNcF_3Txns8eJzZw) | 14 | | ResTv1-B-FPN | ImageNet-1K | 1x | 42.0 | - | 40.5 | - | [config](configs/ResTv1/retinanet_rest_base_FPN_1x.yaml) | [baidu](https://pan.baidu.com/s/1hMRM5YEIGsfWfvqbuC7JWA) | 15 | 16 | ### Mask R-CNN 17 | 18 | | Backbone | Pretrain | Lr Schd | box mAP | mask mAP | #params | FPS | config | model | 19 | |:------------:| :---: |:-------:|:-------:|:--------:|:-------:|:----:|:-----------------------------------------------------------:|:--------------------------------------------------------:| 20 | | ResTv1-S-FPN | ImageNet-1K | 1x | 39.6 | 37.2 | 31.2 | - | [config](configs/ResTv1/mask_rcnn_rest_small_FPN_1x.yaml) | [baidu](https://pan.baidu.com/s/1UfDsRGwgZcydXtj56ZFoDg) | 21 | | ResTv1-B-FPN | ImageNet-1K | 1x | 41.6 | 38.7 | 49.8 | - | [config](configs/ResTv1/mask_rcnn_rest_base_FPN_1x.yaml) | [baidu](https://pan.baidu.com/s/1oSdMGTSBK_JDcLEq3XjY8w) | 22 | | ResTv2-T-FPN | ImageNet-1K | 3x | 47.6 | 43.2 | 49.9 | 25.0 | [config](configs/ResTv2/mask_rcnn_restv2_tiny_FPN_3x.yaml) | [baidu](https://pan.baidu.com/s/16fDcEupHBZ1zHyzFFZvM3g) | 23 | | ResTv2-S-FPN | ImageNet-1K | 3x | 48.1 | 43.3 | 60.7 | 21.3 | [config](configs/ResTv2/mask_rcnn_restv2_small_FPN_3x.yaml) | [baidu](https://pan.baidu.com/s/1UfDsRGwgZcydXtj56ZFoDg) | 24 | | ResTv2-B-FPN | ImageNet-1K | 3x | 48.7 | 43.9 | 75.5 | 18.3 | [config](configs/ResTv2/mask_rcnn_restv2_base_FPN_3x.yaml) | [baidu](https://pan.baidu.com/s/1zHQM0KqgtqQzg0-mdtx-Jg) | 25 | 26 | 27 | ## Usage 28 | Please refer to [get_started.md](https://detectron2.readthedocs.io/en/latest/tutorials/getting_started.html) for installation and dataset preparation. 
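A typical workflow then looks roughly like the sketch below. The paths are placeholders, and the `train_net.py` flags assume the standard detectron2 launcher; the weight-conversion step shown first is also called out in the note that follows.

```
# 1) Convert an ImageNet-pretrained ResTv2 checkpoint to detectron2 format (paths are placeholders).
python convert_to_d2.py --source_model restv2_tiny_224.pth --output_model restv2_tiny_224_d2.pth

# 2) Launch training (flags follow the standard detectron2 launcher used by train_net.py).
python train_net.py --num-gpus 8 \
  --config-file configs/ResTv2/mask_rcnn_restv2_tiny_FPN_3x.yaml \
  MODEL.WEIGHTS restv2_tiny_224_d2.pth
```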
29 | 
30 | Note: you need to convert the original pretrained weights to d2 format with [convert_to_d2.py](convert_to_d2.py).
31 | 
32 | ## Citation
33 | If you find this repository helpful, please consider citing:
34 | 
35 | **ResTv1**
36 | ```
37 | @inproceedings{zhang2021rest,
38 |   title={ResT: An Efficient Transformer for Visual Recognition},
39 |   author={Qinglong Zhang and Yu-bin Yang},
40 |   booktitle={Advances in Neural Information Processing Systems},
41 |   year={2021},
42 |   url={https://openreview.net/forum?id=6Ab68Ip4Mu}
43 | }
44 | ```
45 | 
46 | **ResTv2**
47 | ```
48 | @article{zhang2022rest,
49 |   title={ResT V2: Simpler, Faster and Stronger},
50 |   author={Zhang, Qing-Long and Yang, Yu-Bin},
51 |   journal={arXiv preprint arXiv:2204.07366},
52 |   year={2022}
53 | }
54 | ```
--------------------------------------------------------------------------------
/object_detection/analyze_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | 
4 | import logging
5 | import numpy as np
6 | from collections import Counter
7 | import tqdm
8 | from fvcore.nn import flop_count_table  # can also try flop_count_str
9 | 
10 | from detectron2.checkpoint import DetectionCheckpointer
11 | from detectron2.config import get_cfg
12 | from detectron2.data import build_detection_test_loader
13 | from detectron2.engine import default_argument_parser, default_setup
14 | from detectron2.modeling import build_model
15 | from detectron2.utils.analysis import (
16 |     FlopCountAnalysis,
17 |     activation_count_operators,
18 |     parameter_count_table,
19 | )
20 | 
21 | 
22 | from detectron2.utils.logger import setup_logger
23 | from restv2 import add_restv2_config
24 | 
25 | logger = logging.getLogger("detectron2")
26 | 
27 | 
28 | def setup(args):
29 |     """
30 |     Create configs and perform basic setups.
31 | """ 32 | cfg = get_cfg() 33 | add_restv2_config(cfg) 34 | cfg.merge_from_file(args.config_file) 35 | cfg.merge_from_list(args.opts) 36 | cfg.freeze() 37 | default_setup(cfg, args) 38 | 39 | return cfg 40 | 41 | def do_flop(cfg): 42 | data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) 43 | model = build_model(cfg) 44 | DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) 45 | model.eval() 46 | 47 | counts = Counter() 48 | total_flops = [] 49 | for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa 50 | flops = FlopCountAnalysis(model, data) 51 | if idx > 0: 52 | flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False) 53 | counts += flops.by_operator() 54 | total_flops.append(flops.total()) 55 | 56 | logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops)) 57 | logger.info( 58 | "Average GFlops for each type of operators:\n" 59 | + str([(k, v / (idx + 1) / 1e9) for k, v in counts.items()]) 60 | ) 61 | logger.info( 62 | "Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9) 63 | ) 64 | 65 | 66 | def do_activation(cfg): 67 | data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) 68 | model = build_model(cfg) 69 | DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) 70 | model.eval() 71 | 72 | counts = Counter() 73 | total_activations = [] 74 | for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa 75 | count = activation_count_operators(model, data) 76 | counts += count 77 | total_activations.append(sum(count.values())) 78 | logger.info( 79 | "(Million) Activations for Each Type of Operators:\n" 80 | + str([(k, v / idx) for k, v in counts.items()]) 81 | ) 82 | logger.info( 83 | "Total (Million) Activations: {}±{}".format( 84 | np.mean(total_activations), np.std(total_activations) 85 | ) 86 | ) 87 | 88 | 89 | def do_parameter(cfg): 90 | model = build_model(cfg) 91 | logger.info("Parameter Count:\n" + parameter_count_table(model, max_depth=10)) 92 | 93 | 94 | def do_structure(cfg): 95 | model = build_model(cfg) 96 | logger.info("Model Structure:\n" + str(model)) 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = default_argument_parser( 101 | epilog=""" 102 | Examples: 103 | 104 | To show parameters of a model: 105 | $ ./analyze_model.py --tasks parameter \\ 106 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml 107 | 108 | Flops and activations are data-dependent, therefore inputs and model weights 109 | are needed to count them: 110 | 111 | $ ./analyze_model.py --num-inputs 100 --tasks flop \\ 112 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\ 113 | MODEL.WEIGHTS /path/to/model.pkl 114 | """ 115 | ) 116 | parser.add_argument( 117 | "--tasks", 118 | choices=["flop", "activation", "parameter", "structure"], 119 | required=True, 120 | nargs="+", 121 | ) 122 | parser.add_argument( 123 | "-n", 124 | "--num-inputs", 125 | default=100, 126 | type=int, 127 | help="number of inputs used to compute statistics for flops/activations, " 128 | "both are data dependent.", 129 | ) 130 | args = parser.parse_args() 131 | assert not args.eval_only 132 | assert args.num_gpus == 1 133 | 134 | cfg = setup(args) 135 | 136 | for task in args.tasks: 137 | { 138 | "flop": do_flop, 139 | "activation": do_activation, 140 | "parameter": do_parameter, 141 | "structure": do_structure, 142 | }[task](cfg) 143 | -------------------------------------------------------------------------------- 
/object_detection/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.0001 38 | WEIGHT_DECAY: 0.05 39 | OPTIMIZER: "AdamW" 40 | CLIP_GRADIENTS: 41 | ENABLED: True 42 | CLIP_TYPE: "full_model" 43 | CLIP_VALUE: 0.01 44 | NORM_TYPE: 2.0 45 | STEPS: (60000, 80000) 46 | MAX_ITER: 90000 47 | INPUT: 48 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 49 | CROP: 50 | ENABLED: True 51 | TYPE: "relative" 52 | SIZE: [0.9, 0.9] 53 | VERSION: 2 -------------------------------------------------------------------------------- /object_detection/configs/Base-RetinaNet.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "RetinaNet" 3 | BACKBONE: 4 | NAME: "build_retinanet_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res3", "res4", "res5"] 7 | ANCHOR_GENERATOR: 8 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"] 9 | FPN: 10 | IN_FEATURES: ["res3", "res4", "res5"] 11 | RETINANET: 12 | IOU_THRESHOLDS: [0.4, 0.5] 13 | IOU_LABELS: [0, -1, 1] 14 | SMOOTH_L1_LOSS_BETA: 0.0 15 | DATASETS: 16 | TRAIN: ("coco_2017_train",) 17 | TEST: ("coco_2017_val",) 18 | SOLVER: 19 | IMS_PER_BATCH: 16 20 | BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate 21 | STEPS: (60000, 80000) 22 | MAX_ITER: 90000 23 | INPUT: 24 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 25 | VERSION: 2 26 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv1/mask_rcnn_rest_base_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "" 4 | BACKBONE: 5 | NAME: "build_rest_fpn_backbone" 6 | MASK_ON: True 7 | REST: 8 | NAME : "rest_base" 9 | OUT_FEATURES: ["stage1", "stage2", "stage3", "stage4"] 10 | FPN: 11 | IN_FEATURES: ["stage1", "stage2", "stage3", "stage4"] 12 | SOLVER: 13 | OPTIMIZER: "AdamW" 14 | BASE_LR: 0.0001 15 | WEIGHT_DECAY: 0.05 16 | OUTPUT_DIR: "output/mask_rcnn/rest_base_ms_1x" 17 | -------------------------------------------------------------------------------- 
/object_detection/configs/ResTv1/mask_rcnn_rest_small_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "" 4 | BACKBONE: 5 | NAME: "build_rest_fpn_backbone" 6 | MASK_ON: True 7 | REST: 8 | NAME : "rest_small" 9 | OUT_FEATURES: ["stage1", "stage2", "stage3", "stage4"] 10 | FPN: 11 | IN_FEATURES: ["stage1", "stage2", "stage3", "stage4"] 12 | SOLVER: 13 | OPTIMIZER: "AdamW" 14 | BASE_LR: 0.0001 15 | WEIGHT_DECAY: 0.05 16 | OUTPUT_DIR: "output/mask_rcnn/rest_small_ms_1x" 17 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv1/retinanet_rest_base_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RetinaNet.yaml" 2 | MODEL: 3 | WEIGHTS: "" 4 | BACKBONE: 5 | NAME: "build_retinanet_rest_fpn_backbone" 6 | REST: 7 | NAME : "rest_base" 8 | OUT_FEATURES: ["stage2", "stage3", "stage4"] 9 | FPN: 10 | IN_FEATURES: ["stage2", "stage3", "stage4"] 11 | SOLVER: 12 | OPTIMIZER: "AdamW" 13 | BASE_LR: 0.0001 14 | WEIGHT_DECAY: 0.05 15 | OUTPUT_DIR: "output/retinanet/rest_base_ms_1x" 16 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv1/retinanet_rest_small_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RetinaNet.yaml" 2 | MODEL: 3 | WEIGHTS: "" 4 | BACKBONE: 5 | NAME: "build_retinanet_rest_fpn_backbone" 6 | REST: 7 | NAME : "rest_small" 8 | OUT_FEATURES: ["stage2", "stage3", "stage4"] 9 | FPN: 10 | IN_FEATURES: ["stage2", "stage3", "stage4"] 11 | SOLVER: 12 | OPTIMIZER: "AdamW" 13 | BASE_LR: 0.0001 14 | WEIGHT_DECAY: 0.05 15 | OUTPUT_DIR: "output/retinanet/rest_small_ms_1x" 16 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv2/mask_rcnn_rest_base_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "" 4 | BACKBONE: 5 | NAME: "build_rest_fpn_backbone" 6 | MASK_ON: True 7 | REST: 8 | NAME : "rest_base" 9 | OUT_FEATURES: ["stage1", "stage2", "stage3", "stage4"] 10 | FPN: 11 | IN_FEATURES: ["stage1", "stage2", "stage3", "stage4"] 12 | SOLVER: 13 | OPTIMIZER: "AdamW" 14 | BASE_LR: 0.0001 15 | WEIGHT_DECAY: 0.05 16 | OUTPUT_DIR: "output/mask_rcnn/rest_base_ms_1x" 17 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv2/mask_rcnn_rest_small_FPN_1x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "" 4 | BACKBONE: 5 | NAME: "build_rest_fpn_backbone" 6 | MASK_ON: True 7 | REST: 8 | NAME : "rest_small" 9 | OUT_FEATURES: ["stage1", "stage2", "stage3", "stage4"] 10 | FPN: 11 | IN_FEATURES: ["stage1", "stage2", "stage3", "stage4"] 12 | SOLVER: 13 | OPTIMIZER: "AdamW" 14 | BASE_LR: 0.0001 15 | WEIGHT_DECAY: 0.05 16 | OUTPUT_DIR: "output/mask_rcnn/rest_small_ms_1x" 17 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv2/mask_rcnn_restv2_base_FPN_3x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "/home/zhangql/ResTV2/workdirs/restv2_base_224.pth" 4 | PIXEL_MEAN: [ 123.675, 116.28, 103.53 ] 5 
| PIXEL_STD: [ 58.395, 57.12, 57.375 ] 6 | MASK_ON: True 7 | BACKBONE: 8 | NAME: "build_restv2_fpn_backbone" 9 | RESTV2: 10 | NAME: "restv2_base" 11 | OUT_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ] 12 | FPN: 13 | IN_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ] 14 | INPUT: 15 | FORMAT: "RGB" 16 | MIN_SIZE_TRAIN: (480, 512, 534, 576, 640, 672, 704, 736, 768, 800) 17 | SOLVER: 18 | STEPS: (210000, 250000) 19 | MAX_ITER: 270000 20 | WEIGHT_DECAY: 0.05 21 | BASE_LR: 0.0001 22 | AMP: 23 | ENABLED: True 24 | TEST: 25 | EVAL_PERIOD: 20000 26 | 27 | DATASETS: 28 | TRAIN: ("coco_2017_train",) 29 | TEST: ("coco_2017_val",) 30 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv2/mask_rcnn_restv2_small_FPN_3x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "/home/zhangql/ResTV2/workdirs/restv2_small_224.pth" 4 | PIXEL_MEAN: [ 123.675, 116.28, 103.53 ] 5 | PIXEL_STD: [ 58.395, 57.12, 57.375 ] 6 | MASK_ON: True 7 | BACKBONE: 8 | NAME: "build_restv2_fpn_backbone" 9 | RESTV2: 10 | NAME: "restv2_small" 11 | OUT_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ] 12 | FPN: 13 | IN_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ] 14 | INPUT: 15 | FORMAT: "RGB" 16 | MIN_SIZE_TRAIN: (480, 512, 534, 576, 640, 672, 704, 736, 768, 800) 17 | SOLVER: 18 | STEPS: (210000, 250000) 19 | MAX_ITER: 270000 20 | WEIGHT_DECAY: 0.05 21 | BASE_LR: 0.0001 22 | AMP: 23 | ENABLED: True 24 | TEST: 25 | EVAL_PERIOD: 20000 26 | 27 | DATASETS: 28 | TRAIN: ("coco_2017_train",) 29 | TEST: ("coco_2017_val",) 30 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv2/mask_rcnn_restv2_tiny_FPN_3x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN-FPN.yaml" 2 | MODEL: 3 | WEIGHTS: "/home/zhangql/ResTV2/workdirs/restv2_tiny_224.pth" 4 | PIXEL_MEAN: [ 123.675, 116.28, 103.53 ] 5 | PIXEL_STD: [ 58.395, 57.12, 57.375 ] 6 | MASK_ON: True 7 | BACKBONE: 8 | NAME: "build_restv2_fpn_backbone" 9 | RESTV2: 10 | NAME: "restv2_tiny" 11 | OUT_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ] 12 | FPN: 13 | IN_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ] 14 | INPUT: 15 | FORMAT: "RGB" 16 | MIN_SIZE_TRAIN: (480, 512, 534, 576, 640, 672, 704, 736, 768, 800) 17 | SOLVER: 18 | STEPS: (210000, 250000) 19 | MAX_ITER: 270000 20 | WEIGHT_DECAY: 0.05 21 | BASE_LR: 0.0001 22 | AMP: 23 | ENABLED: True 24 | TEST: 25 | EVAL_PERIOD: 20000 26 | 27 | DATASETS: 28 | TRAIN: ("coco_2017_train",) 29 | TEST: ("coco_2017_val",) 30 | -------------------------------------------------------------------------------- /object_detection/configs/ResTv2/retinanet_restv2_tiny_FPN_3x.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RetinaNet.yaml" 2 | MODEL: 3 | WEIGHTS: "restv2_tiny.pth" 4 | PIXEL_MEAN: [123.675, 116.28, 103.53] # use RGB [103.530, 116.280, 123.675] 5 | PIXEL_STD: [58.395, 57.12, 57.375] #[57.375, 57.120, 58.395] # I use the dafault config [1.0, 1.0, 1.0] and BGR format, that is a mistake 6 | RESNETS: 7 | DEPTH: 50 8 | BACKBONE: 9 | NAME: "build_retinanet_swint_fpn_backbone" 10 | SWINT: 11 | OUT_FEATURES: ["stage3", "stage4", "stage5"] 12 | FPN: 13 | IN_FEATURES: ["stage3", "stage4", "stage5"] 14 | INPUT: 15 | FORMAT: "RGB" 16 | MIN_SIZE_TRAIN: (384, 512, 640, 768, 896, 1024) 17 | SOLVER: 18 | 
STEPS: (210000, 250000) 19 | MAX_ITER: 270000 20 | WEIGHT_DECAY: 0.05 21 | BASE_LR: 0.0001 22 | AMP: 23 | ENABLED: True 24 | TEST: 25 | EVAL_PERIOD: 30000 26 | 27 | DATASETS: 28 | TRAIN: ("coco_2017_train",) 29 | TEST: ("coco_2017_val",) 30 | -------------------------------------------------------------------------------- /object_detection/convert_to_d2.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import os 8 | import argparse 9 | 10 | import torch 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser("D2 model converter") 15 | 16 | parser.add_argument("--source_model", default="", type=str, help="Path or url to the model to convert") 17 | parser.add_argument("--output_model", default="", type=str, help="Path where to save the converted model") 18 | return parser.parse_args() 19 | 20 | 21 | def main(): 22 | args = parse_args() 23 | 24 | source_weights = torch.load(args.source_model)["model"] 25 | converted_weights = {} 26 | keys = list(source_weights.keys()) 27 | 28 | prefix = 'backbone.bottom_up.' 29 | for key in keys: 30 | converted_weights[prefix + key] = source_weights[key] 31 | 32 | torch.save(converted_weights, args.output_model) 33 | 34 | 35 | if __name__ == "__main__": 36 | main() 37 | -------------------------------------------------------------------------------- /object_detection/datasets/README.md: -------------------------------------------------------------------------------- 1 | # Use Builtin Datasets 2 | 3 | A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog) 4 | for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc). 5 | This document explains how to setup the builtin datasets so they can be used by the above APIs. 6 | [Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`, 7 | and how to add new datasets to them. 8 | 9 | Detectron2 has builtin support for a few datasets. 10 | The datasets are assumed to exist in a directory specified by the environment variable 11 | `DETECTRON2_DATASETS`. 12 | Under this directory, detectron2 will look for datasets in the structure described below, if needed. 13 | ``` 14 | $DETECTRON2_DATASETS/ 15 | coco/ 16 | lvis/ 17 | cityscapes/ 18 | VOC20{07,12}/ 19 | ``` 20 | 21 | You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`. 22 | If left unset, the default is `./datasets` relative to your current working directory. 23 | 24 | The [model zoo](https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md) 25 | contains configs and models that use these builtin datasets. 26 | 27 | ## Expected dataset structure for [COCO instance/keypoint detection](https://cocodataset.org/#download): 28 | 29 | ``` 30 | coco/ 31 | annotations/ 32 | instances_{train,val}2017.json 33 | person_keypoints_{train,val}2017.json 34 | {train,val}2017/ 35 | # image files that are mentioned in the corresponding json 36 | ``` 37 | 38 | You can use the 2014 version of the dataset as well. 
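Once the files are in place, a quick way to sanity-check the setup is to query the catalogs mentioned above from Python. This is only a sketch: it assumes the builtin `coco_2017_val` split and that `DETECTRON2_DATASETS` points at the directory described here.

```
# Sanity-check that detectron2 can see a builtin dataset (assumes COCO val2017 is prepared).
from detectron2.data import DatasetCatalog, MetadataCatalog

dicts = DatasetCatalog.get("coco_2017_val")   # loads the per-image annotation dicts
meta = MetadataCatalog.get("coco_2017_val")   # metadata: class names, colors, etc.
print(len(dicts), "images,", len(meta.thing_classes), "classes")
```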
39 | 40 | Some of the builtin tests (`dev/run_*_tests.sh`) uses a tiny version of the COCO dataset, 41 | which you can download with `./datasets/prepare_for_tests.sh`. 42 | 43 | ## Expected dataset structure for PanopticFPN: 44 | 45 | Extract panoptic annotations from [COCO website](https://cocodataset.org/#download) 46 | into the following structure: 47 | ``` 48 | coco/ 49 | annotations/ 50 | panoptic_{train,val}2017.json 51 | panoptic_{train,val}2017/ # png annotations 52 | panoptic_stuff_{train,val}2017/ # generated by the script mentioned below 53 | ``` 54 | 55 | Install panopticapi by: 56 | ``` 57 | pip install git+https://github.com/cocodataset/panopticapi.git 58 | ``` 59 | Then, run `python datasets/prepare_panoptic_fpn.py`, to extract semantic annotations from panoptic annotations. 60 | 61 | ## Expected dataset structure for [LVIS instance segmentation](https://www.lvisdataset.org/dataset): 62 | ``` 63 | coco/ 64 | {train,val,test}2017/ 65 | lvis/ 66 | lvis_v0.5_{train,val}.json 67 | lvis_v0.5_image_info_test.json 68 | lvis_v1_{train,val}.json 69 | lvis_v1_image_info_test{,_challenge}.json 70 | ``` 71 | 72 | Install lvis-api by: 73 | ``` 74 | pip install git+https://github.com/lvis-dataset/lvis-api.git 75 | ``` 76 | 77 | To evaluate models trained on the COCO dataset using LVIS annotations, 78 | run `python datasets/prepare_cocofied_lvis.py` to prepare "cocofied" LVIS annotations. 79 | 80 | ## Expected dataset structure for [cityscapes](https://www.cityscapes-dataset.com/downloads/): 81 | ``` 82 | cityscapes/ 83 | gtFine/ 84 | train/ 85 | aachen/ 86 | color.png, instanceIds.png, labelIds.png, polygons.json, 87 | labelTrainIds.png 88 | ... 89 | val/ 90 | test/ 91 | # below are generated Cityscapes panoptic annotation 92 | cityscapes_panoptic_train.json 93 | cityscapes_panoptic_train/ 94 | cityscapes_panoptic_val.json 95 | cityscapes_panoptic_val/ 96 | cityscapes_panoptic_test.json 97 | cityscapes_panoptic_test/ 98 | leftImg8bit/ 99 | train/ 100 | val/ 101 | test/ 102 | ``` 103 | Install cityscapes scripts by: 104 | ``` 105 | pip install git+https://github.com/mcordts/cityscapesScripts.git 106 | ``` 107 | 108 | Note: to create labelTrainIds.png, first prepare the above structure, then run cityscapesescript with: 109 | ``` 110 | CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py 111 | ``` 112 | These files are not needed for instance segmentation. 113 | 114 | Note: to generate Cityscapes panoptic dataset, run cityscapesescript with: 115 | ``` 116 | CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py 117 | ``` 118 | These files are not needed for semantic and instance segmentation. 119 | 120 | ## Expected dataset structure for [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html): 121 | ``` 122 | VOC20{07,12}/ 123 | Annotations/ 124 | ImageSets/ 125 | Main/ 126 | trainval.txt 127 | test.txt 128 | # train.txt or val.txt, if you use these splits 129 | JPEGImages/ 130 | ``` 131 | 132 | ## Expected dataset structure for [ADE20k Scene Parsing](http://sceneparsing.csail.mit.edu/): 133 | ``` 134 | ADEChallengeData2016/ 135 | annotations/ 136 | annotations_detectron2/ 137 | images/ 138 | objectInfo150.txt 139 | ``` 140 | The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`. 
141 | -------------------------------------------------------------------------------- /object_detection/datasets/prepare_ade20k_sem_seg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | import numpy as np 5 | import os 6 | from pathlib import Path 7 | import tqdm 8 | from PIL import Image 9 | 10 | 11 | def convert(input, output): 12 | img = np.asarray(Image.open(input)) 13 | assert img.dtype == np.uint8 14 | img = img - 1 # 0 (ignore) becomes 255. others are shifted by 1 15 | Image.fromarray(img).save(output) 16 | 17 | 18 | if __name__ == "__main__": 19 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016" 20 | for name in ["training", "validation"]: 21 | annotation_dir = dataset_dir / "annotations" / name 22 | output_dir = dataset_dir / "annotations_detectron2" / name 23 | output_dir.mkdir(parents=True, exist_ok=True) 24 | for file in tqdm.tqdm(list(annotation_dir.iterdir())): 25 | output_file = output_dir / file.name 26 | convert(file, output_file) 27 | -------------------------------------------------------------------------------- /object_detection/datasets/prepare_cocofied_lvis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | import copy 6 | import json 7 | import os 8 | from collections import defaultdict 9 | 10 | # This mapping is extracted from the official LVIS mapping: 11 | # https://github.com/lvis-dataset/lvis-api/blob/master/data/coco_to_synset.json 12 | COCO_SYNSET_CATEGORIES = [ 13 | {"synset": "person.n.01", "coco_cat_id": 1}, 14 | {"synset": "bicycle.n.01", "coco_cat_id": 2}, 15 | {"synset": "car.n.01", "coco_cat_id": 3}, 16 | {"synset": "motorcycle.n.01", "coco_cat_id": 4}, 17 | {"synset": "airplane.n.01", "coco_cat_id": 5}, 18 | {"synset": "bus.n.01", "coco_cat_id": 6}, 19 | {"synset": "train.n.01", "coco_cat_id": 7}, 20 | {"synset": "truck.n.01", "coco_cat_id": 8}, 21 | {"synset": "boat.n.01", "coco_cat_id": 9}, 22 | {"synset": "traffic_light.n.01", "coco_cat_id": 10}, 23 | {"synset": "fireplug.n.01", "coco_cat_id": 11}, 24 | {"synset": "stop_sign.n.01", "coco_cat_id": 13}, 25 | {"synset": "parking_meter.n.01", "coco_cat_id": 14}, 26 | {"synset": "bench.n.01", "coco_cat_id": 15}, 27 | {"synset": "bird.n.01", "coco_cat_id": 16}, 28 | {"synset": "cat.n.01", "coco_cat_id": 17}, 29 | {"synset": "dog.n.01", "coco_cat_id": 18}, 30 | {"synset": "horse.n.01", "coco_cat_id": 19}, 31 | {"synset": "sheep.n.01", "coco_cat_id": 20}, 32 | {"synset": "beef.n.01", "coco_cat_id": 21}, 33 | {"synset": "elephant.n.01", "coco_cat_id": 22}, 34 | {"synset": "bear.n.01", "coco_cat_id": 23}, 35 | {"synset": "zebra.n.01", "coco_cat_id": 24}, 36 | {"synset": "giraffe.n.01", "coco_cat_id": 25}, 37 | {"synset": "backpack.n.01", "coco_cat_id": 27}, 38 | {"synset": "umbrella.n.01", "coco_cat_id": 28}, 39 | {"synset": "bag.n.04", "coco_cat_id": 31}, 40 | {"synset": "necktie.n.01", "coco_cat_id": 32}, 41 | {"synset": "bag.n.06", "coco_cat_id": 33}, 42 | {"synset": "frisbee.n.01", "coco_cat_id": 34}, 43 | {"synset": "ski.n.01", "coco_cat_id": 35}, 44 | {"synset": "snowboard.n.01", "coco_cat_id": 36}, 45 | {"synset": "ball.n.06", "coco_cat_id": 37}, 46 | {"synset": "kite.n.03", "coco_cat_id": 38}, 47 | {"synset": "baseball_bat.n.01", "coco_cat_id": 39}, 48 | 
{"synset": "baseball_glove.n.01", "coco_cat_id": 40}, 49 | {"synset": "skateboard.n.01", "coco_cat_id": 41}, 50 | {"synset": "surfboard.n.01", "coco_cat_id": 42}, 51 | {"synset": "tennis_racket.n.01", "coco_cat_id": 43}, 52 | {"synset": "bottle.n.01", "coco_cat_id": 44}, 53 | {"synset": "wineglass.n.01", "coco_cat_id": 46}, 54 | {"synset": "cup.n.01", "coco_cat_id": 47}, 55 | {"synset": "fork.n.01", "coco_cat_id": 48}, 56 | {"synset": "knife.n.01", "coco_cat_id": 49}, 57 | {"synset": "spoon.n.01", "coco_cat_id": 50}, 58 | {"synset": "bowl.n.03", "coco_cat_id": 51}, 59 | {"synset": "banana.n.02", "coco_cat_id": 52}, 60 | {"synset": "apple.n.01", "coco_cat_id": 53}, 61 | {"synset": "sandwich.n.01", "coco_cat_id": 54}, 62 | {"synset": "orange.n.01", "coco_cat_id": 55}, 63 | {"synset": "broccoli.n.01", "coco_cat_id": 56}, 64 | {"synset": "carrot.n.01", "coco_cat_id": 57}, 65 | {"synset": "frank.n.02", "coco_cat_id": 58}, 66 | {"synset": "pizza.n.01", "coco_cat_id": 59}, 67 | {"synset": "doughnut.n.02", "coco_cat_id": 60}, 68 | {"synset": "cake.n.03", "coco_cat_id": 61}, 69 | {"synset": "chair.n.01", "coco_cat_id": 62}, 70 | {"synset": "sofa.n.01", "coco_cat_id": 63}, 71 | {"synset": "pot.n.04", "coco_cat_id": 64}, 72 | {"synset": "bed.n.01", "coco_cat_id": 65}, 73 | {"synset": "dining_table.n.01", "coco_cat_id": 67}, 74 | {"synset": "toilet.n.02", "coco_cat_id": 70}, 75 | {"synset": "television_receiver.n.01", "coco_cat_id": 72}, 76 | {"synset": "laptop.n.01", "coco_cat_id": 73}, 77 | {"synset": "mouse.n.04", "coco_cat_id": 74}, 78 | {"synset": "remote_control.n.01", "coco_cat_id": 75}, 79 | {"synset": "computer_keyboard.n.01", "coco_cat_id": 76}, 80 | {"synset": "cellular_telephone.n.01", "coco_cat_id": 77}, 81 | {"synset": "microwave.n.02", "coco_cat_id": 78}, 82 | {"synset": "oven.n.01", "coco_cat_id": 79}, 83 | {"synset": "toaster.n.02", "coco_cat_id": 80}, 84 | {"synset": "sink.n.01", "coco_cat_id": 81}, 85 | {"synset": "electric_refrigerator.n.01", "coco_cat_id": 82}, 86 | {"synset": "book.n.01", "coco_cat_id": 84}, 87 | {"synset": "clock.n.01", "coco_cat_id": 85}, 88 | {"synset": "vase.n.01", "coco_cat_id": 86}, 89 | {"synset": "scissors.n.01", "coco_cat_id": 87}, 90 | {"synset": "teddy.n.01", "coco_cat_id": 88}, 91 | {"synset": "hand_blower.n.01", "coco_cat_id": 89}, 92 | {"synset": "toothbrush.n.01", "coco_cat_id": 90}, 93 | ] 94 | 95 | 96 | def cocofy_lvis(input_filename, output_filename): 97 | """ 98 | Filter LVIS instance segmentation annotations to remove all categories that are not included in 99 | COCO. The new json files can be used to evaluate COCO AP using `lvis-api`. The category ids in 100 | the output json are the incontiguous COCO dataset ids. 101 | 102 | Args: 103 | input_filename (str): path to the LVIS json file. 104 | output_filename (str): path to the COCOfied json file. 
105 | """ 106 | 107 | with open(input_filename, "r") as f: 108 | lvis_json = json.load(f) 109 | 110 | lvis_annos = lvis_json.pop("annotations") 111 | cocofied_lvis = copy.deepcopy(lvis_json) 112 | lvis_json["annotations"] = lvis_annos 113 | 114 | # Mapping from lvis cat id to coco cat id via synset 115 | lvis_cat_id_to_synset = {cat["id"]: cat["synset"] for cat in lvis_json["categories"]} 116 | synset_to_coco_cat_id = {x["synset"]: x["coco_cat_id"] for x in COCO_SYNSET_CATEGORIES} 117 | # Synsets that we will keep in the dataset 118 | synsets_to_keep = set(synset_to_coco_cat_id.keys()) 119 | coco_cat_id_with_instances = defaultdict(int) 120 | 121 | new_annos = [] 122 | ann_id = 1 123 | for ann in lvis_annos: 124 | lvis_cat_id = ann["category_id"] 125 | synset = lvis_cat_id_to_synset[lvis_cat_id] 126 | if synset not in synsets_to_keep: 127 | continue 128 | coco_cat_id = synset_to_coco_cat_id[synset] 129 | new_ann = copy.deepcopy(ann) 130 | new_ann["category_id"] = coco_cat_id 131 | new_ann["id"] = ann_id 132 | ann_id += 1 133 | new_annos.append(new_ann) 134 | coco_cat_id_with_instances[coco_cat_id] += 1 135 | cocofied_lvis["annotations"] = new_annos 136 | 137 | for image in cocofied_lvis["images"]: 138 | for key in ["not_exhaustive_category_ids", "neg_category_ids"]: 139 | new_category_list = [] 140 | for lvis_cat_id in image[key]: 141 | synset = lvis_cat_id_to_synset[lvis_cat_id] 142 | if synset not in synsets_to_keep: 143 | continue 144 | coco_cat_id = synset_to_coco_cat_id[synset] 145 | new_category_list.append(coco_cat_id) 146 | coco_cat_id_with_instances[coco_cat_id] += 1 147 | image[key] = new_category_list 148 | 149 | coco_cat_id_with_instances = set(coco_cat_id_with_instances.keys()) 150 | 151 | new_categories = [] 152 | for cat in lvis_json["categories"]: 153 | synset = cat["synset"] 154 | if synset not in synsets_to_keep: 155 | continue 156 | coco_cat_id = synset_to_coco_cat_id[synset] 157 | if coco_cat_id not in coco_cat_id_with_instances: 158 | continue 159 | new_cat = copy.deepcopy(cat) 160 | new_cat["id"] = coco_cat_id 161 | new_categories.append(new_cat) 162 | cocofied_lvis["categories"] = new_categories 163 | 164 | with open(output_filename, "w") as f: 165 | json.dump(cocofied_lvis, f) 166 | print("{} is COCOfied and stored in {}.".format(input_filename, output_filename)) 167 | 168 | 169 | if __name__ == "__main__": 170 | dataset_dir = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "lvis") 171 | for s in ["lvis_v0.5_train", "lvis_v0.5_val"]: 172 | print("Start COCOfing {}.".format(s)) 173 | cocofy_lvis( 174 | os.path.join(dataset_dir, "{}.json".format(s)), 175 | os.path.join(dataset_dir, "{}_cocofied.json".format(s)), 176 | ) 177 | -------------------------------------------------------------------------------- /object_detection/datasets/prepare_for_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | # Download some files needed for running tests. 5 | 6 | cd "${0%/*}" 7 | 8 | BASE=https://dl.fbaipublicfiles.com/detectron2 9 | mkdir -p coco/annotations 10 | 11 | for anno in instances_val2017_100 \ 12 | person_keypoints_val2017_100 \ 13 | instances_minival2014_100 \ 14 | person_keypoints_minival2014_100; do 15 | 16 | dest=coco/annotations/$anno.json 17 | [[ -s $dest ]] && { 18 | echo "$dest exists. Skipping ..." 
19 | } || { 20 | wget $BASE/annotations/coco/$anno.json -O $dest 21 | } 22 | done 23 | -------------------------------------------------------------------------------- /object_detection/datasets/prepare_panoptic_fpn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | import functools 6 | import json 7 | import multiprocessing as mp 8 | import numpy as np 9 | import os 10 | import time 11 | from fvcore.common.download import download 12 | from panopticapi.utils import rgb2id 13 | from PIL import Image 14 | 15 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 16 | 17 | 18 | def _process_panoptic_to_semantic(input_panoptic, output_semantic, segments, id_map): 19 | panoptic = np.asarray(Image.open(input_panoptic), dtype=np.uint32) 20 | panoptic = rgb2id(panoptic) 21 | output = np.zeros_like(panoptic, dtype=np.uint8) + 255 22 | for seg in segments: 23 | cat_id = seg["category_id"] 24 | new_cat_id = id_map[cat_id] 25 | output[panoptic == seg["id"]] = new_cat_id 26 | Image.fromarray(output).save(output_semantic) 27 | 28 | 29 | def separate_coco_semantic_from_panoptic(panoptic_json, panoptic_root, sem_seg_root, categories): 30 | """ 31 | Create semantic segmentation annotations from panoptic segmentation 32 | annotations, to be used by PanopticFPN. 33 | 34 | It maps all thing categories to class 0, and maps all unlabeled pixels to class 255. 35 | It maps all stuff categories to contiguous ids starting from 1. 36 | 37 | Args: 38 | panoptic_json (str): path to the panoptic json file, in COCO's format. 39 | panoptic_root (str): a directory with panoptic annotation files, in COCO's format. 40 | sem_seg_root (str): a directory to output semantic annotation files 41 | categories (list[dict]): category metadata. Each dict needs to have: 42 | "id": corresponds to the "category_id" in the json annotations 43 | "isthing": 0 or 1 44 | """ 45 | os.makedirs(sem_seg_root, exist_ok=True) 46 | 47 | stuff_ids = [k["id"] for k in categories if k["isthing"] == 0] 48 | thing_ids = [k["id"] for k in categories if k["isthing"] == 1] 49 | id_map = {} # map from category id to id in the output semantic annotation 50 | assert len(stuff_ids) <= 254 51 | for i, stuff_id in enumerate(stuff_ids): 52 | id_map[stuff_id] = i + 1 53 | for thing_id in thing_ids: 54 | id_map[thing_id] = 0 55 | id_map[0] = 255 56 | 57 | with open(panoptic_json) as f: 58 | obj = json.load(f) 59 | 60 | pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4)) 61 | 62 | def iter_annotations(): 63 | for anno in obj["annotations"]: 64 | file_name = anno["file_name"] 65 | segments = anno["segments_info"] 66 | input = os.path.join(panoptic_root, file_name) 67 | output = os.path.join(sem_seg_root, file_name) 68 | yield input, output, segments 69 | 70 | print("Start writing to {} ...".format(sem_seg_root)) 71 | start = time.time() 72 | pool.starmap( 73 | functools.partial(_process_panoptic_to_semantic, id_map=id_map), 74 | iter_annotations(), 75 | chunksize=100, 76 | ) 77 | print("Finished. 
time: {:.2f}s".format(time.time() - start)) 78 | 79 | 80 | if __name__ == "__main__": 81 | dataset_dir = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "coco") 82 | for s in ["val2017", "train2017"]: 83 | separate_coco_semantic_from_panoptic( 84 | os.path.join(dataset_dir, "annotations/panoptic_{}.json".format(s)), 85 | os.path.join(dataset_dir, "panoptic_{}".format(s)), 86 | os.path.join(dataset_dir, "panoptic_stuff_{}".format(s)), 87 | COCO_CATEGORIES, 88 | ) 89 | 90 | # Prepare val2017_100 for quick testing: 91 | 92 | dest_dir = os.path.join(dataset_dir, "annotations/") 93 | URL_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/" 94 | download(URL_PREFIX + "annotations/coco/panoptic_val2017_100.json", dest_dir) 95 | with open(os.path.join(dest_dir, "panoptic_val2017_100.json")) as f: 96 | obj = json.load(f) 97 | 98 | def link_val100(dir_full, dir_100): 99 | print("Creating " + dir_100 + " ...") 100 | os.makedirs(dir_100, exist_ok=True) 101 | for img in obj["images"]: 102 | basename = os.path.splitext(img["file_name"])[0] 103 | src = os.path.join(dir_full, basename + ".png") 104 | dst = os.path.join(dir_100, basename + ".png") 105 | src = os.path.relpath(src, start=dir_100) 106 | os.symlink(src, dst) 107 | 108 | link_val100( 109 | os.path.join(dataset_dir, "panoptic_val2017"), 110 | os.path.join(dataset_dir, "panoptic_val2017_100"), 111 | ) 112 | 113 | link_val100( 114 | os.path.join(dataset_dir, "panoptic_stuff_val2017"), 115 | os.path.join(dataset_dir, "panoptic_stuff_val2017_100"), 116 | ) 117 | -------------------------------------------------------------------------------- /object_detection/restv2/__init__.py: -------------------------------------------------------------------------------- 1 | from .rest import * 2 | from .restv2 import * 3 | from .config import * 4 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /object_detection/restv2/config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 
3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | from detectron2.config import CfgNode as CN 8 | 9 | 10 | def add_restv2_config(cfg): 11 | # restv2 backbone 12 | cfg.MODEL.RESTV2 = CN() 13 | cfg.MODEL.RESTV2.NAME = "restv2_tiny" 14 | cfg.MODEL.RESTV2.OUT_FEATURES = ["stage1", "stage2", "stage3", "stage4"] 15 | cfg.MODEL.RESTV2.WEIGHTS = None 16 | cfg.MODEL.BACKBONE.FREEZE_AT = 2 17 | 18 | # addition 19 | cfg.MODEL.FPN.TOP_LEVELS = 2 20 | cfg.SOLVER.OPTIMIZER = "AdamW" 21 | 22 | 23 | def add_restv1_config(cfg): 24 | # restv1 backbone 25 | cfg.MODEL.REST = CN() 26 | cfg.MODEL.REST.NAME = "rest_base" 27 | cfg.MODEL.REST.OUT_FEATURES = ["stage1", "stage2", "stage3", "stage4"] 28 | cfg.MODEL.REST.WEIGHTS = None 29 | cfg.MODEL.BACKBONE.FREEZE_AT = 2 30 | 31 | # addition 32 | cfg.MODEL.FPN.TOP_LEVELS = 2 33 | cfg.SOLVER.OPTIMIZER = "AdamW" 34 | -------------------------------------------------------------------------------- /object_detection/restv2/rest.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # ResT V1 3 | # Copyright (c) VCU, Nanjing University 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Qing-Long Zhang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 12 | from detectron2.modeling.backbone import Backbone 13 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY 14 | from detectron2.modeling.backbone.fpn import FPN, LastLevelP6P7, LastLevelMaxPool 15 | from detectron2.layers import ShapeSpec 16 | 17 | 18 | __all__ = [ 19 | "ResT", 20 | "build_rest_backbone", 21 | "build_rest_fpn_backbone", 22 | "build_retinanet_rest_fpn_backbone"] 23 | 24 | 25 | class Mlp(nn.Module): 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, 46 | dim, 47 | num_heads=8, 48 | qkv_bias=False, 49 | qk_scale=None, 50 | attn_drop=0., 51 | proj_drop=0., 52 | sr_ratio=1, 53 | apply_transform=False): 54 | super().__init__() 55 | self.num_heads = num_heads 56 | head_dim = dim // num_heads 57 | self.scale = qk_scale or head_dim ** -0.5 58 | 59 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 60 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 61 | self.attn_drop = nn.Dropout(attn_drop) 62 | self.proj = nn.Linear(dim, dim) 63 | self.proj_drop = nn.Dropout(proj_drop) 64 | 65 | self.sr_ratio = sr_ratio 66 | if sr_ratio > 1: 67 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio+1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim) 68 | self.sr_norm = nn.LayerNorm(dim) 69 | 70 | self.apply_transform = apply_transform and num_heads > 1 71 | if self.apply_transform: 72 | self.transform_conv = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, 
stride=1) 73 | self.transform_norm = nn.InstanceNorm2d(self.num_heads) 74 | 75 | def forward(self, x, H, W): 76 | B, N, C = x.shape 77 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 78 | if self.sr_ratio > 1: 79 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 80 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 81 | x_ = self.sr_norm(x_) 82 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 83 | else: 84 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 85 | k, v = kv[0], kv[1] 86 | 87 | attn = (q @ k.transpose(-2, -1)) * self.scale 88 | if self.apply_transform: 89 | attn = self.transform_conv(attn) 90 | attn = attn.softmax(dim=-1) 91 | attn = self.transform_norm(attn) 92 | else: 93 | attn = attn.softmax(dim=-1) 94 | 95 | attn = self.attn_drop(attn) 96 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 97 | x = self.proj(x) 98 | x = self.proj_drop(x) 99 | return x 100 | 101 | 102 | class Block(nn.Module): 103 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 104 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, apply_transform=False): 105 | super().__init__() 106 | self.norm1 = norm_layer(dim) 107 | self.attn = Attention( 108 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 109 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, apply_transform=apply_transform) 110 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 111 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 112 | self.norm2 = norm_layer(dim) 113 | mlp_hidden_dim = int(dim * mlp_ratio) 114 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 115 | 116 | def forward(self, x, H, W): 117 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 118 | x = x + self.drop_path(self.mlp(self.norm2(x))) 119 | return x 120 | 121 | 122 | class PA(nn.Module): 123 | def __init__(self, dim): 124 | super().__init__() 125 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim) 126 | self.sigmoid = nn.Sigmoid() 127 | 128 | def forward(self, x): 129 | return x * self.sigmoid(self.pa_conv(x)) 130 | 131 | 132 | class PatchEmbed(nn.Module): 133 | """ Image to Patch Embedding""" 134 | def __init__(self, patch_size=16, in_ch=3, out_ch=768, with_pos=False): 135 | super().__init__() 136 | self.patch_size = to_2tuple(patch_size) 137 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size+1, stride=patch_size, padding=patch_size // 2) 138 | self.norm = nn.BatchNorm2d(out_ch) 139 | 140 | self.with_pos = with_pos 141 | if self.with_pos: 142 | self.pos = PA(out_ch) 143 | 144 | def forward(self, x): 145 | B, C, H, W = x.shape 146 | x = self.conv(x) 147 | x = self.norm(x) 148 | if self.with_pos: 149 | x = self.pos(x) 150 | x = x.flatten(2).transpose(1, 2) 151 | H, W = H // self.patch_size[0], W // self.patch_size[1] 152 | return x, (H, W) 153 | 154 | 155 | class BasicStem(nn.Module): 156 | def __init__(self, in_ch=3, out_ch=64, with_pos=False): 157 | super(BasicStem, self).__init__() 158 | hidden_ch = out_ch // 2 159 | self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, stride=2, padding=1, bias=False) 160 | self.norm1 = nn.BatchNorm2d(hidden_ch) 161 | self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, stride=1, padding=1, bias=False) 162 | self.norm2 = nn.BatchNorm2d(hidden_ch) 163 | 
self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False) 164 | 165 | self.act = nn.ReLU(inplace=True) 166 | self.with_pos = with_pos 167 | if self.with_pos: 168 | self.pos = PA(out_ch) 169 | 170 | def forward(self, x): 171 | x = self.conv1(x) 172 | x = self.norm1(x) 173 | x = self.act(x) 174 | 175 | x = self.conv2(x) 176 | x = self.norm2(x) 177 | x = self.act(x) 178 | 179 | x = self.conv3(x) 180 | if self.with_pos: 181 | x = self.pos(x) 182 | return x 183 | 184 | 185 | class ResT(Backbone): 186 | def __init__(self, cfg, in_ch=3, embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], 187 | qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., depths=[2, 2, 2, 2], 188 | sr_ratios=[8, 4, 2, 1], norm_layer=nn.LayerNorm, apply_transform=True, out_features=None): 189 | super().__init__() 190 | self.depths = depths 191 | self.num_layers = len(depths) 192 | self.apply_transform = apply_transform 193 | self._out_features = out_features 194 | 195 | self.stem = BasicStem(in_ch=in_ch, out_ch=embed_dims[0], with_pos=True) 196 | 197 | self.patch_embed_2 = PatchEmbed(patch_size=2, in_ch=embed_dims[0], out_ch=embed_dims[1], with_pos=True) 198 | self.patch_embed_3 = PatchEmbed(patch_size=2, in_ch=embed_dims[1], out_ch=embed_dims[2], with_pos=True) 199 | self.patch_embed_4 = PatchEmbed(patch_size=2, in_ch=embed_dims[2], out_ch=embed_dims[3], with_pos=True) 200 | 201 | # transformer encoder 202 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 203 | cur = 0 204 | 205 | self.stage1 = nn.ModuleList([ 206 | Block(embed_dims[0], num_heads[0], mlp_ratios[0], qkv_bias, qk_scale, drop_rate, attn_drop_rate, 207 | drop_path=dpr[cur+i], norm_layer=norm_layer, sr_ratio=sr_ratios[0], apply_transform=apply_transform) 208 | for i in range(self.depths[0])]) 209 | 210 | cur += depths[0] 211 | self.stage2 = nn.ModuleList([ 212 | Block(embed_dims[1], num_heads[1], mlp_ratios[1], qkv_bias, qk_scale, drop_rate, attn_drop_rate, 213 | drop_path=dpr[cur+i], norm_layer=norm_layer, sr_ratio=sr_ratios[1], apply_transform=apply_transform) 214 | for i in range(self.depths[1])]) 215 | 216 | cur += depths[1] 217 | self.stage3 = nn.ModuleList([ 218 | Block(embed_dims[2], num_heads[2], mlp_ratios[2], qkv_bias, qk_scale, drop_rate, attn_drop_rate, 219 | drop_path=dpr[cur+i], norm_layer=norm_layer, sr_ratio=sr_ratios[2], apply_transform=apply_transform) 220 | for i in range(self.depths[2])]) 221 | 222 | cur += depths[2] 223 | self.stage4 = nn.ModuleList([ 224 | Block(embed_dims[3], num_heads[3], mlp_ratios[3], qkv_bias, qk_scale, drop_rate, attn_drop_rate, 225 | drop_path=dpr[cur+i], norm_layer=norm_layer, sr_ratio=sr_ratios[3], apply_transform=apply_transform) 226 | for i in range(self.depths[3])]) 227 | 228 | # add a norm layer for each output 229 | for i in range(self.num_layers-1): 230 | stage = f'stage{i+1}' 231 | if stage in self._out_features: 232 | layer = norm_layer(embed_dims[i]) 233 | layer_name = f'norm{i+1}' 234 | self.add_module(layer_name, layer) 235 | 236 | self.norm = norm_layer(embed_dims[3]) 237 | 238 | # init weights 239 | self.apply(self._init_weights) 240 | self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_AT) 241 | 242 | def _init_weights(self, m): 243 | if isinstance(m, nn.Conv2d): 244 | trunc_normal_(m.weight, std=0.02) 245 | elif isinstance(m, nn.Linear): 246 | trunc_normal_(m.weight, std=0.02) 247 | if m.bias is not None: 248 | nn.init.constant_(m.bias, 0) 249 | elif isinstance(m, (nn.LayerNorm, 
nn.BatchNorm2d)): 250 | nn.init.constant_(m.weight, 1.0) 251 | nn.init.constant_(m.bias, 0) 252 | 253 | def _freeze_backbone(self, freeze_at): 254 | if freeze_at < 0: 255 | return 256 | for stage_index in range(freeze_at): 257 | if stage_index >= 0: 258 | # stage 0 is the stem 259 | self.stem.eval() 260 | for p in self.stem.parameters(): 261 | p.requires_grad = False 262 | if stage_index == 1: 263 | for p in self.stage1.parameters(): 264 | p.requires_grad = False 265 | else: 266 | patch = getattr(self, "patch_embed_" + str(stage_index)) 267 | stage = getattr(self, "stage" + str(stage_index)) 268 | for p in patch.parameters(): 269 | p.requires_grad = False 270 | for p in stage.parameters(): 271 | p.requires_grad = False 272 | 273 | def forward(self, x): 274 | outputs = {} 275 | x = self.stem(x) 276 | B, _, H, W = x.shape 277 | x = x.flatten(2).permute(0, 2, 1) 278 | 279 | # stage 1 280 | for blk in self.stage1: 281 | x = blk(x, H, W) 282 | if "stage1" in self._out_features: 283 | x = self.norm1(x) 284 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 285 | outputs["stage1"] = x 286 | else: 287 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 288 | 289 | # stage 2 290 | x, (H, W) = self.patch_embed_2(x) 291 | for blk in self.stage2: 292 | x = blk(x, H, W) 293 | if "stage2" in self._out_features: 294 | x = self.norm2(x) 295 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 296 | outputs["stage2"] = x 297 | else: 298 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 299 | 300 | # stage 3 301 | x, (H, W) = self.patch_embed_3(x) 302 | for blk in self.stage3: 303 | x = blk(x, H, W) 304 | if "stage3" in self._out_features: 305 | x = self.norm3(x) 306 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 307 | outputs["stage3"] = x 308 | else: 309 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 310 | 311 | # stage 4 312 | x, (H, W) = self.patch_embed_4(x) 313 | for blk in self.stage4: 314 | x = blk(x, H, W) 315 | x = self.norm(x) 316 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 317 | if "stage4" in self._out_features: 318 | outputs["stage4"] = x 319 | 320 | return outputs 321 | 322 | 323 | @BACKBONE_REGISTRY.register() 324 | def build_rest_backbone(cfg, input_shape): 325 | """ 326 | Create a ResT instance from config. 327 | Returns: 328 | ResT: a :class:`ResT` instance. 
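    The ResT variant (depths, embedding dims, drop path rate) is selected by ``cfg.MODEL.REST.NAME``; supported names are ``rest_lite``, ``rest_small``, ``rest_base`` and ``rest_large``, matching the lookup tables below.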
329 | """ 330 | name = cfg.MODEL.REST.NAME 331 | out_features = cfg.MODEL.REST.OUT_FEATURES 332 | 333 | depths = {"rest_lite": [2, 2, 2, 2], "rest_small": [2, 2, 6, 2], 334 | "rest_base": [2, 2, 6, 2], "rest_large": [2, 2, 18, 2]}[name] 335 | 336 | embed_dims = {"rest_lite": [64, 128, 256, 512], "rest_small": [64, 128, 256, 512], 337 | "rest_base": [96, 192, 384, 768], "rest_large": [96, 192, 384, 768]}[name] 338 | 339 | drop_path_rate = {"rest_lite": 0.1, "rest_small": 0.1, "rest_base": 0.2, "rest_large": 0.2}[name] 340 | 341 | feature_names = ['stage1', 'stage2', 'stage3', 'stage4'] 342 | out_feature_channels = dict(zip(feature_names, embed_dims)) 343 | out_feature_strides = {"stage1": 4, "stage2": 8, "stage3": 16, "stage4": 32} 344 | 345 | model = ResT(cfg, in_ch=3, embed_dims=embed_dims, qkv_bias=True, drop_path_rate=drop_path_rate, 346 | depths=depths, apply_transform=True, out_features=out_features) 347 | model._out_feature_channels = out_feature_channels 348 | model._out_feature_strides = out_feature_strides 349 | return model 350 | 351 | 352 | @BACKBONE_REGISTRY.register() 353 | def build_rest_fpn_backbone(cfg, input_shape: ShapeSpec): 354 | """ 355 | Args: 356 | cfg: a detectron2 CfgNode 357 | Returns: 358 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 359 | """ 360 | bottom_up = build_rest_backbone(cfg, input_shape) 361 | in_features = cfg.MODEL.FPN.IN_FEATURES 362 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 363 | backbone = FPN( 364 | bottom_up=bottom_up, 365 | in_features=in_features, 366 | out_channels=out_channels, 367 | norm=cfg.MODEL.FPN.NORM, 368 | top_block=LastLevelMaxPool(), 369 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 370 | ) 371 | return backbone 372 | 373 | 374 | @BACKBONE_REGISTRY.register() 375 | def build_retinanet_rest_fpn_backbone(cfg, input_shape: ShapeSpec): 376 | """ 377 | Args: 378 | cfg: a detectron2 CfgNode 379 | 380 | Returns: 381 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
382 | """ 383 | bottom_up = build_rest_backbone(cfg, input_shape) 384 | in_features = cfg.MODEL.FPN.IN_FEATURES 385 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 386 | in_channels_p6p7 = bottom_up.output_shape()["stage4"].channels 387 | backbone = FPN( 388 | bottom_up=bottom_up, 389 | in_features=in_features, 390 | out_channels=out_channels, 391 | norm=cfg.MODEL.FPN.NORM, 392 | top_block=LastLevelP6P7(in_channels_p6p7, out_channels, in_feature="stage4"), 393 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 394 | ) 395 | return backbone -------------------------------------------------------------------------------- /object_detection/restv2/restv2.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # ResT V2 3 | # Copyright (c) VCU, Nanjing University 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Qing-Long Zhang 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | import torch.nn as nn 10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 11 | from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec 12 | from detectron2.modeling.backbone.fpn import FPN, LastLevelP6P7, LastLevelMaxPool 13 | from utils import load_state_dict 14 | from utils import LayerNorm 15 | 16 | __all__ = [ 17 | "Attention", 18 | "ResTV2", 19 | "build_restv2_backbone", 20 | "build_restv2_fpn_backbone", 21 | "build_retinanet_restv2_fpn_backbone", 22 | ] 23 | 24 | 25 | class Mlp(nn.Module): 26 | def __init__(self, dim): 27 | super().__init__() 28 | self.fc1 = nn.Linear(dim, 4 * dim) 29 | self.act = nn.GELU() 30 | self.fc2 = nn.Linear(4 * dim, dim) 31 | 32 | def forward(self, x): 33 | x = self.fc1(x) 34 | x = self.act(x) 35 | x = self.fc2(x) 36 | return x 37 | 38 | 39 | class Attention(nn.Module): 40 | def __init__(self, 41 | dim, 42 | num_heads=8, 43 | sr_ratio=1): 44 | super().__init__() 45 | self.num_heads = num_heads 46 | head_dim = dim // num_heads 47 | self.scale = head_dim ** -0.5 48 | 49 | self.q = nn.Linear(dim, dim, bias=True) 50 | self.kv = nn.Linear(dim, dim * 2, bias=True) 51 | 52 | self.sr_ratio = sr_ratio 53 | if sr_ratio > 1: 54 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim) 55 | self.sr_norm = nn.LayerNorm(dim, eps=1e-6) 56 | 57 | self.up = nn.Sequential( 58 | nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1, padding=1, groups=dim), 59 | nn.PixelShuffle(upscale_factor=sr_ratio) 60 | ) 61 | self.up_norm = nn.LayerNorm(dim, eps=1e-6) 62 | 63 | self.proj = nn.Linear(dim, dim) 64 | 65 | def forward(self, x, H, W): 66 | B, N, C = x.shape 67 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 68 | if self.sr_ratio > 1: 69 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 70 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) 71 | x = self.sr_norm(x) 72 | 73 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 74 | k, v = kv[0], kv[1] 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 78 | 79 | identity = v.transpose(-1, -2).reshape(B, C, H // self.sr_ratio, W // self.sr_ratio) 80 | identity = self.up(identity).flatten(2).transpose(1, 2) 81 | x = self.proj(x + self.up_norm(identity)) 82 | return x 83 | 84 | 85 | class Block(nn.Module): 86 | def __init__(self, dim, num_heads, sr_ratio=1, 
drop_path=0.): 87 | super().__init__() 88 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 89 | self.attn = Attention(dim, num_heads, sr_ratio) 90 | 91 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 92 | self.mlp = Mlp(dim) 93 | 94 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 95 | 96 | def forward(self, x, H, W): 97 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) # pre_norm 98 | x = x + self.drop_path(self.mlp(self.norm2(x))) 99 | return x 100 | 101 | 102 | class PA(nn.Module): 103 | def __init__(self, dim): 104 | super().__init__() 105 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=True) 106 | self.sigmoid = nn.Sigmoid() 107 | 108 | def forward(self, x): 109 | return x * self.sigmoid(self.pa_conv(x)) 110 | 111 | 112 | class ConvStem(nn.Module): 113 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True): 114 | super().__init__() 115 | self.patch_size = to_2tuple(patch_size) 116 | stem = [] 117 | in_dim, out_dim = in_ch, out_ch // 2 118 | for i in range(2): 119 | stem.append(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False)) 120 | stem.append(nn.BatchNorm2d(out_dim)) 121 | stem.append(nn.ReLU(inplace=True)) 122 | in_dim, out_dim = out_dim, out_dim * 2 123 | 124 | stem.append(nn.Conv2d(in_dim, out_ch, kernel_size=1, stride=1)) 125 | self.proj = nn.Sequential(*stem) 126 | 127 | self.with_pos = with_pos 128 | if self.with_pos: 129 | self.pos = PA(out_ch) 130 | 131 | self.norm = nn.LayerNorm(out_ch, eps=1e-6) 132 | 133 | def forward(self, x): 134 | B, C, H, W = x.shape 135 | x = self.proj(x) 136 | if self.with_pos: 137 | x = self.pos(x) 138 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 139 | x = self.norm(x) 140 | H, W = H // self.patch_size[0], W // self.patch_size[1] 141 | return x, (H, W) 142 | 143 | 144 | class PatchEmbed(nn.Module): 145 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True): 146 | super().__init__() 147 | self.patch_size = to_2tuple(patch_size) 148 | self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size, padding=patch_size // 2) 149 | 150 | self.with_pos = with_pos 151 | if self.with_pos: 152 | self.pos = PA(out_ch) 153 | 154 | self.norm = nn.LayerNorm(out_ch, eps=1e-6) 155 | 156 | def forward(self, x): 157 | B, C, H, W = x.shape 158 | x = self.proj(x) 159 | if self.with_pos: 160 | x = self.pos(x) 161 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 162 | x = self.norm(x) 163 | H, W = H // self.patch_size[0], W // self.patch_size[1] 164 | return x, (H, W) 165 | 166 | 167 | class ResTV2(Backbone): 168 | def __init__(self, in_chans=3, embed_dims=[96, 192, 384, 768], 169 | num_heads=[1, 2, 4, 8], drop_path_rate=0., 170 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], out_features=None, 171 | frozen_stages=-1, pretrained=None): 172 | super().__init__() 173 | 174 | self.depths = depths 175 | self.frozen_stages = frozen_stages 176 | self.pretrained = pretrained 177 | self._out_features = out_features 178 | 179 | self.stem = ConvStem(in_chans, embed_dims[0], patch_size=4) 180 | self.patch_2 = PatchEmbed(embed_dims[0], embed_dims[1], patch_size=2) 181 | self.patch_3 = PatchEmbed(embed_dims[1], embed_dims[2], patch_size=2) 182 | self.patch_4 = PatchEmbed(embed_dims[2], embed_dims[3], patch_size=2) 183 | 184 | # transformer encoder 185 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 186 | cur = 0 187 | self.stage1 = nn.ModuleList([ 188 | Block(embed_dims[0], num_heads[0], sr_ratios[0], dpr[cur + i]) 189 | for i 
in range(depths[0]) 190 | ]) 191 | 192 | cur += depths[0] 193 | self.stage2 = nn.ModuleList([ 194 | Block(embed_dims[1], num_heads[1], sr_ratios[1], dpr[cur + i]) 195 | for i in range(depths[1]) 196 | ]) 197 | 198 | cur += depths[1] 199 | self.stage3 = nn.ModuleList([ 200 | Block(embed_dims[2], num_heads[2], sr_ratios[2], dpr[cur + i]) 201 | for i in range(depths[2]) 202 | ]) 203 | 204 | cur += depths[2] 205 | self.stage4 = nn.ModuleList([ 206 | Block(embed_dims[3], num_heads[3], sr_ratios[3], dpr[cur + i]) 207 | for i in range(depths[3]) 208 | ]) 209 | 210 | # add a norm layer for each output 211 | for i_layer in out_features: 212 | idx = int(i_layer[-1]) 213 | out_ch = embed_dims[idx - 1] 214 | layer = nn.Sequential( 215 | LayerNorm(out_ch, eps=1e-6, data_format="channels_first"), 216 | ) 217 | layer_name = f"norm_{idx}" 218 | self.add_module(layer_name, layer) 219 | 220 | self._freeze_stages() 221 | self.init_weights() 222 | 223 | def _freeze_stages(self): 224 | if self.frozen_stages >= 0: 225 | self.stem.eval() 226 | for param in self.stem.parameters(): 227 | param.requires_grad = False 228 | 229 | if self.frozen_stages >= 1: 230 | self.stage1.eval() 231 | for param in self.stage1.parameters(): 232 | param.requires_grad = False 233 | 234 | if self.frozen_stages >= 2: 235 | for i in range(0, self.frozen_stages): 236 | patch = getattr(self, "patch_embed_" + str(i)) 237 | stage = getattr(self, "stage" + str(i)) 238 | patch.eval() 239 | stage.eval() 240 | for param in patch.parameters(): 241 | param.requires_grad = False 242 | for param in stage.parameters(): 243 | param.requires_grad = False 244 | 245 | def init_weights(self): 246 | if self.pretrained is None: 247 | for m in self.modules(): 248 | if isinstance(m, (nn.Conv2d, nn.Linear)): 249 | trunc_normal_(m.weight, std=0.02) 250 | if m.bias is not None: 251 | nn.init.constant_(m.bias, 0) 252 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): 253 | nn.init.constant_(m.weight, 1.) 254 | nn.init.constant_(m.bias, 0.) 
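        # Otherwise `self.pretrained` is a checkpoint path: the branch below loads it,
        # unwraps the 'state_dict' / 'model' key if present, strips any 'module.' prefix
        # left by DistributedDataParallel, and copies the weights with the tolerant
        # `load_state_dict` helper from utils.py (missing / unexpected keys are only reported).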
255 | else: 256 | checkpoint = torch.load(self.pretrained, map_location='cpu') 257 | if 'state_dict' in checkpoint: 258 | state_dict = checkpoint['state_dict'] 259 | elif 'model' in checkpoint: 260 | state_dict = checkpoint['model'] 261 | else: 262 | state_dict = checkpoint 263 | # strip prefix of state_dict 264 | if list(state_dict.keys())[0].startswith('module.'): 265 | state_dict = {k[7:]: v for k, v in state_dict.items()} 266 | load_state_dict(self, state_dict) 267 | 268 | def forward(self, x): 269 | outs = {} 270 | B, _, H, W = x.shape 271 | x, (H, W) = self.stem(x) 272 | # stage 1 273 | for blk in self.stage1: 274 | x = blk(x, H, W) 275 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 276 | if "stage1" in self._out_features: 277 | outs["stage1"] = self.norm_1(x) 278 | 279 | # stage 2 280 | x, (H, W) = self.patch_2(x) 281 | for blk in self.stage2: 282 | x = blk(x, H, W) 283 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 284 | if "stage2" in self._out_features: 285 | outs["stage2"] = self.norm_2(x) 286 | 287 | # stage 3 288 | x, (H, W) = self.patch_3(x) 289 | for blk in self.stage3: 290 | x = blk(x, H, W) 291 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 292 | if "stage3" in self._out_features: 293 | outs["stage3"] = self.norm_3(x) 294 | 295 | # stage 4 296 | x, (H, W) = self.patch_4(x) 297 | for blk in self.stage4: 298 | x = blk(x, H, W) 299 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 300 | if "stage4" in self._out_features: 301 | outs["stage4"] = self.norm_4(x) 302 | 303 | return outs 304 | 305 | @property 306 | def size_divisibility(self) -> int: 307 | return 32 308 | 309 | 310 | @BACKBONE_REGISTRY.register() 311 | def build_restv2_backbone(cfg, input_shape: ShapeSpec): 312 | name = cfg.MODEL.RESTV2.NAME 313 | out_features = cfg.MODEL.RESTV2.OUT_FEATURES 314 | settings = { 315 | "restv2_tiny": {"depths": [1, 2, 6, 2], "embed_dims": [96, 192, 384, 768], "drop_path_rate": 0.1}, 316 | "restv2_small": {"depths": [1, 2, 12, 2], "embed_dims": [96, 192, 384, 768], "drop_path_rate": 0.2}, 317 | "restv2_base": {"depths": [1, 3, 16, 3], "embed_dims": [96, 192, 384, 768], "drop_path_rate": 0.3}, 318 | "restv2_large": {"depths": [2, 3, 16, 3], "embed_dims": [128, 256, 512, 1024], "drop_path_rate": 0.5}, 319 | } 320 | feature_names = ['stage1', 'stage2', 'stage3', 'stage4'] 321 | out_feature_channels = dict(zip(feature_names, settings[name]["embed_dims"])) 322 | out_feature_strides = {"stage1": 4, "stage2": 8, "stage3": 16, "stage4": 32} 323 | model = ResTV2(embed_dims=settings[name]["embed_dims"], 324 | depths=settings[name]["depths"], drop_path_rate=settings[name]["drop_path_rate"], 325 | out_features=out_features) 326 | model._out_feature_channels = out_feature_channels 327 | model._out_feature_strides = out_feature_strides 328 | return model 329 | 330 | 331 | @BACKBONE_REGISTRY.register() 332 | def build_restv2_fpn_backbone(cfg, input_shape: ShapeSpec): 333 | """ 334 | Args: 335 | cfg: a detectron2 CfgNode 336 | input_shape: ShapeSpec 337 | Returns: 338 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 
339 | """ 340 | bottom_up = build_restv2_backbone(cfg, input_shape) 341 | in_features = cfg.MODEL.FPN.IN_FEATURES 342 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 343 | backbone = FPN( 344 | bottom_up=bottom_up, 345 | in_features=in_features, 346 | out_channels=out_channels, 347 | norm=cfg.MODEL.FPN.NORM, 348 | top_block=LastLevelMaxPool(), 349 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 350 | ) 351 | return backbone 352 | 353 | 354 | @BACKBONE_REGISTRY.register() 355 | def build_retinanet_restv2_fpn_backbone(cfg, input_shape: ShapeSpec): 356 | """ 357 | Args: 358 | cfg: a detectron2 CfgNode 359 | input_shape: ShapeSpec 360 | Returns: 361 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 362 | """ 363 | bottom_up = build_restv2_backbone(cfg, input_shape) 364 | in_features = cfg.MODEL.FPN.IN_FEATURES 365 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 366 | in_channels_p6p7 = bottom_up.output_shape()["stage4"].channels 367 | backbone = FPN( 368 | bottom_up=bottom_up, 369 | in_features=in_features, 370 | out_channels=out_channels, 371 | norm=cfg.MODEL.FPN.NORM, 372 | top_block=LastLevelP6P7(in_channels_p6p7, out_channels, in_feature="stage4"), 373 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 374 | ) 375 | return backbone 376 | -------------------------------------------------------------------------------- /object_detection/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # ------------------------------------------------------------ 3 | # Copyright (c) VCU, Nanjing University. 4 | # Licensed under the Apache License 2.0 [see LICENSE for details] 5 | # Written by Qing-Long Zhang 6 | # ------------------------------------------------------------ 7 | 8 | """ 9 | Detection Training Script. 10 | This scripts reads a given config file and runs the training or evaluation. 11 | It is an entry point that is made to train standard models in detectron2. 12 | In order to let one script support training of many models, 13 | this script contains logic that are specific to these built-in models and therefore 14 | may not be suitable for your own project. 15 | For example, your research project perhaps only needs a single "evaluator". 16 | Therefore, we recommend you to use detectron2 as an library and take 17 | this file as an example of how to use the library. 18 | You may want to write your own script with your datasets and other customizations. 
19 | """ 20 | import copy 21 | from typing import Any, Dict, List, Set 22 | import itertools 23 | import logging 24 | import os 25 | from collections import OrderedDict 26 | import torch 27 | 28 | import detectron2.utils.comm as comm 29 | from detectron2.checkpoint import DetectionCheckpointer 30 | from detectron2.config import get_cfg 31 | from detectron2.data import MetadataCatalog 32 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch 33 | from detectron2.evaluation import ( 34 | CityscapesInstanceEvaluator, 35 | CityscapesSemSegEvaluator, 36 | COCOEvaluator, 37 | COCOPanopticEvaluator, 38 | DatasetEvaluators, 39 | LVISEvaluator, 40 | PascalVOCDetectionEvaluator, 41 | SemSegEvaluator, 42 | verify_results, 43 | ) 44 | 45 | from detectron2.modeling import GeneralizedRCNNWithTTA 46 | from detectron2.solver.build import maybe_add_gradient_clipping, get_default_optimizer_params 47 | from restv2 import add_restv2_config, add_restv1_config 48 | 49 | 50 | def build_evaluator(cfg, dataset_name, output_folder=None): 51 | """ 52 | Create evaluator(s) for a given dataset. 53 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 54 | For your own dataset, you can simply create an evaluator manually in your 55 | script and do not have to worry about the hacky if-else logic here. 56 | """ 57 | if output_folder is None: 58 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 59 | evaluator_list = [] 60 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 61 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]: 62 | evaluator_list.append( 63 | SemSegEvaluator( 64 | dataset_name, 65 | distributed=True, 66 | output_dir=output_folder, 67 | ) 68 | ) 69 | if evaluator_type in ["coco", "coco_panoptic_seg"]: 70 | evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder)) 71 | if evaluator_type == "coco_panoptic_seg": 72 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder)) 73 | if evaluator_type == "cityscapes_instance": 74 | return CityscapesInstanceEvaluator(dataset_name) 75 | if evaluator_type == "cityscapes_sem_seg": 76 | return CityscapesSemSegEvaluator(dataset_name) 77 | elif evaluator_type == "pascal_voc": 78 | return PascalVOCDetectionEvaluator(dataset_name) 79 | elif evaluator_type == "lvis": 80 | return LVISEvaluator(dataset_name, output_dir=output_folder) 81 | if len(evaluator_list) == 0: 82 | raise NotImplementedError( 83 | "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type) 84 | ) 85 | elif len(evaluator_list) == 1: 86 | return evaluator_list[0] 87 | return DatasetEvaluators(evaluator_list) 88 | 89 | 90 | class Trainer(DefaultTrainer): 91 | """ 92 | We use the "DefaultTrainer" which contains pre-defined default logic for 93 | standard training workflow. They may not work for you, especially if you 94 | are working on a new research project. In that case you can write your 95 | own training loop. You can use "tools/plain_train_net.py" as an example. 96 | """ 97 | 98 | @classmethod 99 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 100 | return build_evaluator(cfg, dataset_name, output_folder) 101 | 102 | @classmethod 103 | def test_with_TTA(cls, cfg, model): 104 | logger = logging.getLogger("detectron2.trainer") 105 | # In the end of training, run an evaluation with TTA 106 | # Only support some R-CNN models. 
107 | logger.info("Running inference with test-time augmentation ...") 108 | model = GeneralizedRCNNWithTTA(cfg, model) 109 | evaluators = [ 110 | cls.build_evaluator( 111 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA") 112 | ) 113 | for name in cfg.DATASETS.TEST 114 | ] 115 | res = cls.test(cfg, model, evaluators) 116 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 117 | return res 118 | 119 | @classmethod 120 | def build_optimizer(cls, cfg, model): 121 | params = get_default_optimizer_params( 122 | model, 123 | base_lr=cfg.SOLVER.BASE_LR, 124 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 125 | weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, 126 | bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR, 127 | weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, 128 | ) 129 | 130 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class 131 | # detectron2 doesn't have full model gradient clipping now 132 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 133 | enable = ( 134 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 135 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 136 | and clip_norm_val > 0.0 137 | ) 138 | 139 | class FullModelGradientClippingOptimizer(optim): 140 | def step(self, closure=None): 141 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 142 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 143 | super().step(closure=closure) 144 | 145 | return FullModelGradientClippingOptimizer if enable else optim 146 | 147 | optimizer_type = cfg.SOLVER.OPTIMIZER 148 | if optimizer_type == "SGD": 149 | optimizer = maybe_add_gradient_clipping(torch.optim.SGD)( 150 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, 151 | nesterov=cfg.SOLVER.NESTEROV, 152 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 153 | ) 154 | elif optimizer_type == "AdamW": 155 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 156 | params, cfg.SOLVER.BASE_LR, betas=(0.9, 0.999), 157 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 158 | ) 159 | else: 160 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 161 | return optimizer 162 | 163 | 164 | def setup(args): 165 | """ 166 | Create configs and perform basic setups. 167 | """ 168 | cfg = get_cfg() 169 | add_restv2_config(cfg) 170 | add_restv1_config(cfg) 171 | cfg.merge_from_file(args.config_file) 172 | cfg.merge_from_list(args.opts) 173 | cfg.freeze() 174 | default_setup(cfg, args) 175 | 176 | return cfg 177 | 178 | 179 | def main(args): 180 | cfg = setup(args) 181 | 182 | if args.eval_only: 183 | model = Trainer.build_model(cfg) 184 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 185 | cfg.MODEL.WEIGHTS, resume=args.resume 186 | ) 187 | res = Trainer.test(cfg, model) 188 | if cfg.TEST.AUG.ENABLED: 189 | res.update(Trainer.test_with_TTA(cfg, model)) 190 | if comm.is_main_process(): 191 | verify_results(cfg, res) 192 | return res 193 | 194 | """ 195 | If you'd like to do anything fancier than the standard training logic, 196 | consider writing your own training loop (see plain_train_net.py) or 197 | subclassing the trainer. 
198 | """ 199 | trainer = Trainer(cfg) 200 | trainer.resume_or_load(resume=args.resume) 201 | if cfg.TEST.AUG.ENABLED: 202 | trainer.register_hooks( 203 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))] 204 | ) 205 | return trainer.train() 206 | 207 | 208 | if __name__ == "__main__": 209 | args = default_argument_parser().parse_args() 210 | print("Command Line Args:", args) 211 | launch( 212 | main, 213 | args.num_gpus, 214 | num_machines=args.num_machines, 215 | machine_rank=args.machine_rank, 216 | dist_url=args.dist_url, 217 | args=(args,), 218 | ) 219 | -------------------------------------------------------------------------------- /object_detection/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.distributed as dist 13 | import torchvision 14 | from torch import Tensor 15 | 16 | 17 | def _max_by_axis(the_list): 18 | # type: (List[List[int]]) -> List[int] 19 | maxes = the_list[0] 20 | for sublist in the_list[1:]: 21 | for index, item in enumerate(sublist): 22 | maxes[index] = max(maxes[index], item) 23 | return maxes 24 | 25 | 26 | class NestedTensor(object): 27 | def __init__(self, tensors, mask: Optional[Tensor]): 28 | self.tensors = tensors 29 | self.mask = mask 30 | 31 | def to(self, device): 32 | # type: (Device) -> NestedTensor # noqa 33 | cast_tensor = self.tensors.to(device) 34 | mask = self.mask 35 | if mask is not None: 36 | assert mask is not None 37 | cast_mask = mask.to(device) 38 | else: 39 | cast_mask = None 40 | return NestedTensor(cast_tensor, cast_mask) 41 | 42 | def decompose(self): 43 | return self.tensors, self.mask 44 | 45 | def __repr__(self): 46 | return str(self.tensors) 47 | 48 | 49 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 50 | # TODO make this more general 51 | if tensor_list[0].ndim == 3: 52 | if torchvision._is_tracing(): 53 | # nested_tensor_from_tensor_list() does not export well to ONNX 54 | # call _onnx_nested_tensor_from_tensor_list() instead 55 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 56 | 57 | # TODO make it support different-sized images 58 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 59 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 60 | batch_shape = [len(tensor_list)] + max_size 61 | b, c, h, w = batch_shape 62 | dtype = tensor_list[0].dtype 63 | device = tensor_list[0].device 64 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 65 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 66 | for img, pad_img, m in zip(tensor_list, tensor, mask): 67 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 68 | m[: img.shape[1], : img.shape[2]] = False 69 | else: 70 | raise ValueError("not supported") 71 | return NestedTensor(tensor, mask) 72 | 73 | 74 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 75 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
76 | @torch.jit.unused 77 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 78 | max_size = [] 79 | for i in range(tensor_list[0].dim()): 80 | max_size_i = torch.max( 81 | torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) 82 | ).to(torch.int64) 83 | max_size.append(max_size_i) 84 | max_size = tuple(max_size) 85 | 86 | # work around for 87 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 88 | # m[: img.shape[1], :img.shape[2]] = False 89 | # which is not yet supported in onnx 90 | padded_imgs = [] 91 | padded_masks = [] 92 | for img in tensor_list: 93 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 94 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 95 | padded_imgs.append(padded_img) 96 | 97 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 98 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 99 | padded_masks.append(padded_mask.to(torch.bool)) 100 | 101 | tensor = torch.stack(padded_imgs) 102 | mask = torch.stack(padded_masks) 103 | 104 | return NestedTensor(tensor, mask=mask) 105 | 106 | 107 | def is_dist_avail_and_initialized(): 108 | if not dist.is_available(): 109 | return False 110 | if not dist.is_initialized(): 111 | return False 112 | return True 113 | 114 | 115 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 116 | missing_keys = [] 117 | unexpected_keys = [] 118 | error_msgs = [] 119 | # copy state_dict so _load_from_state_dict can modify it 120 | metadata = getattr(state_dict, '_metadata', None) 121 | state_dict = state_dict.copy() 122 | if metadata is not None: 123 | state_dict._metadata = metadata 124 | 125 | def load(module, prefix=''): 126 | local_metadata = {} if metadata is None else metadata.get( 127 | prefix[:-1], {}) 128 | module._load_from_state_dict( 129 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 130 | for name, child in module._modules.items(): 131 | if child is not None: 132 | load(child, prefix + name + '.') 133 | 134 | load(model, prefix=prefix) 135 | 136 | warn_missing_keys = [] 137 | ignore_missing_keys = [] 138 | for key in missing_keys: 139 | keep_flag = True 140 | for ignore_key in ignore_missing.split('|'): 141 | if ignore_key in key: 142 | keep_flag = False 143 | break 144 | if keep_flag: 145 | warn_missing_keys.append(key) 146 | else: 147 | ignore_missing_keys.append(key) 148 | 149 | missing_keys = warn_missing_keys 150 | 151 | if len(missing_keys) > 0: 152 | print("Weights of {} not initialized from pretrained model: {}".format( 153 | model.__class__.__name__, missing_keys)) 154 | if len(unexpected_keys) > 0: 155 | print("Weights from pretrained model not used in {}: {}".format( 156 | model.__class__.__name__, unexpected_keys)) 157 | if len(ignore_missing_keys) > 0: 158 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 159 | model.__class__.__name__, ignore_missing_keys)) 160 | if len(error_msgs) > 0: 161 | print('\n'.join(error_msgs)) 162 | 163 | 164 | class LayerNorm(nn.Module): 165 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 166 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 167 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 168 | with shape (batch_size, channels, height, width). 
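    Example (illustrative): ``LayerNorm(96, data_format="channels_first")`` normalizes a
    (batch_size, 96, H, W) feature map over its channel dimension; this matches how the
    per-stage output norms are applied to the backbone feature maps.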
169 | """ 170 | 171 | def __init__(self, dim, eps=1e-6, data_format="channels_last"): 172 | super().__init__() 173 | self.weight = nn.Parameter(torch.ones(dim)) 174 | self.bias = nn.Parameter(torch.zeros(dim)) 175 | self.eps = eps 176 | self.data_format = data_format 177 | if self.data_format not in ["channels_last", "channels_first"]: 178 | raise NotImplementedError 179 | self.dim = (dim,) 180 | 181 | def forward(self, x): 182 | if self.data_format == "channels_last": 183 | return F.layer_norm(x, self.dim, self.weight, self.bias, self.eps) 184 | elif self.data_format == "channels_first": 185 | u = x.mean(1, keepdim=True) 186 | s = (x - u).pow(2).mean(1, keepdim=True) 187 | x = (x - u) / torch.sqrt(s + self.eps) 188 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 189 | return x 190 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import torch 8 | from torch import optim as optim 9 | 10 | from timm.optim.adafactor import Adafactor 11 | from timm.optim.adahessian import Adahessian 12 | from timm.optim.adamp import AdamP 13 | from timm.optim.lookahead import Lookahead 14 | from timm.optim.nadam import Nadam 15 | from timm.optim.nvnovograd import NvNovoGrad 16 | from timm.optim.radam import RAdam 17 | from timm.optim.rmsprop_tf import RMSpropTF 18 | from timm.optim.sgdp import SGDP 19 | from timm.optim.lamb import Lamb 20 | from timm.optim.lars import Lars 21 | 22 | import json 23 | 24 | try: 25 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 26 | 27 | has_apex = True 28 | except ImportError: 29 | has_apex = False 30 | 31 | 32 | def get_layer_id_for_rest(model, name): 33 | """ 34 | Assign a parameter with its layer id 35 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 36 | """ 37 | depths = model.depths 38 | if name in ['cls_token', 'pos_embed']: 39 | return 0 40 | elif name.startswith('stem'): 41 | return 0 42 | elif name.startswith('patch'): 43 | layer_id = int(name.split('.')[0].split('_')[1]) 44 | if layer_id == 2: 45 | layer_id = depths[0] + 1 46 | elif layer_id == 3: 47 | layer_id = sum(depths[:2]) + 2 48 | else: 49 | layer_id = sum(depths[:3]) + 3 50 | return int(layer_id) 51 | elif name.startswith('stage'): 52 | stage_id = int(name.split('.')[0][-1]) 53 | layer_id = int(name.split('.')[1]) 54 | if stage_id == 1: 55 | layer_id = layer_id + 1 56 | elif stage_id == 2: 57 | layer_id = depths[0] + layer_id + 2 58 | elif stage_id == 3: 59 | layer_id = sum(depths[:2]) + layer_id + 3 60 | else: 61 | layer_id = sum(depths[:3]) + layer_id + 4 62 | return int(layer_id) 63 | else: 64 | return sum(depths) + 4 65 | 66 | 67 | class LayerDecayValueAssigner(object): 68 | def __init__(self, values): 69 | self.values = values 70 | 71 | def get_scale(self, layer_id): 72 | return self.values[layer_id] 73 | 74 | def get_layer_id(self, model, var_name): 75 | return get_layer_id_for_rest(model, var_name) 76 | 77 | 78 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 79 | parameter_group_names = {} 80 | parameter_group_vars = {} 81 | 82 | for name, 
param in model.named_parameters(): 83 | if not param.requires_grad: 84 | continue # frozen weights 85 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 86 | group_name = "no_decay" 87 | this_weight_decay = 0. 88 | else: 89 | group_name = "decay" 90 | this_weight_decay = weight_decay 91 | if get_num_layer is not None: 92 | layer_id = get_num_layer(model, name) 93 | print("layer_id: {}", layer_id) 94 | group_name = "layer_%d_%s" % (layer_id, group_name) 95 | else: 96 | layer_id = None 97 | 98 | if group_name not in parameter_group_names: 99 | if get_layer_scale is not None: 100 | scale = get_layer_scale(layer_id) 101 | else: 102 | scale = 1. 103 | 104 | parameter_group_names[group_name] = { 105 | "weight_decay": this_weight_decay, 106 | "params": [], 107 | "lr_scale": scale 108 | } 109 | parameter_group_vars[group_name] = { 110 | "weight_decay": this_weight_decay, 111 | "params": [], 112 | "lr_scale": scale 113 | } 114 | 115 | parameter_group_vars[group_name]["params"].append(param) 116 | parameter_group_names[group_name]["params"].append(name) 117 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 118 | return list(parameter_group_vars.values()) 119 | 120 | 121 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 122 | opt_lower = args.opt.lower() 123 | weight_decay = args.weight_decay 124 | # if weight_decay and filter_bias_and_bn: 125 | if filter_bias_and_bn: 126 | skip = {} 127 | if skip_list is not None: 128 | skip = skip_list 129 | elif hasattr(model, 'no_weight_decay'): 130 | skip = model.no_weight_decay() 131 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 132 | weight_decay = 0. 133 | else: 134 | parameters = model.parameters() 135 | 136 | if 'fused' in opt_lower: 137 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 138 | 139 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 140 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 141 | opt_args['eps'] = args.opt_eps 142 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 143 | opt_args['betas'] = args.opt_betas 144 | 145 | opt_split = opt_lower.split('_') 146 | opt_lower = opt_split[-1] 147 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 148 | opt_args.pop('eps', None) 149 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 150 | elif opt_lower == 'momentum': 151 | opt_args.pop('eps', None) 152 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 153 | elif opt_lower == 'adam': 154 | optimizer = optim.Adam(parameters, **opt_args) 155 | elif opt_lower == 'adamw': 156 | optimizer = optim.AdamW(parameters, **opt_args) 157 | elif opt_lower == 'nadam': 158 | optimizer = Nadam(parameters, **opt_args) 159 | elif opt_lower == 'radam': 160 | optimizer = RAdam(parameters, **opt_args) 161 | elif opt_lower == 'adamp': 162 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 163 | elif opt_lower == 'sgdp': 164 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 165 | elif opt_lower == 'adadelta': 166 | optimizer = optim.Adadelta(parameters, **opt_args) 167 | elif opt_lower == 'adafactor': 168 | if not args.lr: 169 | opt_args['lr'] = None 170 | optimizer = Adafactor(parameters, **opt_args) 171 | elif opt_lower == 'lamb': 172 | optimizer = Lamb(parameters, **opt_args) 173 | elif opt_lower == 
'lambc': 174 | optimizer = Lamb(parameters, trust_clip=True, **opt_args) 175 | elif opt_lower == 'lars': 176 | optimizer = Lars(parameters, momentum=args.momentum, **opt_args) 177 | elif opt_lower == 'adahessian': 178 | optimizer = Adahessian(parameters, **opt_args) 179 | elif opt_lower == 'rmsprop': 180 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 181 | elif opt_lower == 'rmsproptf': 182 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 183 | elif opt_lower == 'nvnovograd': 184 | optimizer = NvNovoGrad(parameters, **opt_args) 185 | elif opt_lower == 'fusedsgd': 186 | opt_args.pop('eps', None) 187 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 188 | elif opt_lower == 'fusedmomentum': 189 | opt_args.pop('eps', None) 190 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 191 | elif opt_lower == 'fusedadam': 192 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 193 | elif opt_lower == 'fusedadamw': 194 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 195 | elif opt_lower == 'fusedlamb': 196 | optimizer = FusedLAMB(parameters, **opt_args) 197 | elif opt_lower == 'fusednovograd': 198 | opt_args.setdefault('betas', (0.95, 0.98)) 199 | optimizer = FusedNovoGrad(parameters, **opt_args) 200 | else: 201 | assert False and "Invalid optimizer" 202 | 203 | if len(opt_split) > 1: 204 | if opt_split[0] == 'lookahead': 205 | optimizer = Lookahead(optimizer) 206 | 207 | return optimizer 208 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.8.0 2 | torchvision>=0.9.0 3 | timm==0.5.4 4 | tensorboardX 5 | six 6 | -------------------------------------------------------------------------------- /run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import argparse 8 | import os 9 | import uuid 10 | from pathlib import Path 11 | 12 | import main as classification 13 | import submitit 14 | 15 | 16 | def parse_args(): 17 | classification_parser = classification.get_args_parser() 18 | parser = argparse.ArgumentParser("Submitit for ResTv2", parents=[classification_parser]) 19 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 20 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 21 | parser.add_argument("--timeout", default=72, type=int, help="Duration of the job, in hours") 22 | parser.add_argument("--job_name", default="restv2", type=str, help="Job name") 23 | parser.add_argument("--job_dir", default="", type=str, help="Job directory; leave empty for default") 24 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") 25 | parser.add_argument("--use_volta32", action='store_true', default=True, help="Big models? Use this") 26 | parser.add_argument('--comment', default="", type=str, 27 | help='Comment to pass to scheduler, e.g. 
priority message') 28 | return parser.parse_args() 29 | 30 | 31 | def get_shared_folder() -> Path: 32 | user = os.getenv("USER") 33 | if Path("/checkpoint/").is_dir(): 34 | p = Path(f"/checkpoint/{user}/restv2") 35 | p.mkdir(exist_ok=True) 36 | return p 37 | raise RuntimeError("No shared folder available") 38 | 39 | 40 | def get_init_file(): 41 | # Init file must not exist, but it's parent dir must exist. 42 | os.makedirs(str(get_shared_folder()), exist_ok=True) 43 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 44 | if init_file.exists(): 45 | os.remove(str(init_file)) 46 | return init_file 47 | 48 | 49 | class Trainer(object): 50 | def __init__(self, args): 51 | self.args = args 52 | 53 | def __call__(self): 54 | import main as classification 55 | 56 | self._setup_gpu_args() 57 | classification.main(self.args) 58 | 59 | def checkpoint(self): 60 | import os 61 | import submitit 62 | 63 | self.args.dist_url = get_init_file().as_uri() 64 | self.args.auto_resume = True 65 | print("Requeuing ", self.args) 66 | empty_trainer = type(self)(self.args) 67 | return submitit.helpers.DelayedSubmission(empty_trainer) 68 | 69 | def _setup_gpu_args(self): 70 | import submitit 71 | from pathlib import Path 72 | 73 | job_env = submitit.JobEnvironment() 74 | self.args.output_dir = Path(self.args.job_dir) 75 | self.args.gpu = job_env.local_rank 76 | self.args.rank = job_env.global_rank 77 | self.args.world_size = job_env.num_tasks 78 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 79 | 80 | 81 | def main(): 82 | args = parse_args() 83 | 84 | if args.job_dir == "": 85 | args.job_dir = get_shared_folder() / "%j" 86 | 87 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 88 | 89 | num_gpus_per_node = args.ngpus 90 | nodes = args.nodes 91 | timeout_min = args.timeout * 60 92 | 93 | partition = args.partition 94 | kwargs = {} 95 | if args.use_volta32: 96 | kwargs['slurm_constraint'] = 'volta32gb' 97 | if args.comment: 98 | kwargs['slurm_comment'] = args.comment 99 | 100 | executor.update_parameters( 101 | mem_gb=40 * num_gpus_per_node, 102 | gpus_per_node=num_gpus_per_node, 103 | tasks_per_node=num_gpus_per_node, # one task per GPU 104 | cpus_per_task=10, 105 | nodes=nodes, 106 | timeout_min=timeout_min, # max is 60 * 72 107 | # Below are cluster dependent parameters 108 | slurm_partition=partition, 109 | slurm_signal_delay_s=120, 110 | **kwargs 111 | ) 112 | 113 | executor.update_parameters(name=args.job_name) 114 | 115 | args.dist_url = get_init_file().as_uri() 116 | args.output_dir = args.job_dir 117 | 118 | trainer = Trainer(args) 119 | job = executor.submit(trainer) 120 | 121 | print("Submitted job_id:", job.job_id) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /semantic_segmentation/README.md: -------------------------------------------------------------------------------- 1 | # ADE20k Semantic segmentation with ResTv2 2 | 3 | ## Getting started 4 | 5 | We add ResTv2 model and config files to the semantic_segmentation. 
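
As a quick sanity check of the backbone itself, the snippet below builds the ResTv2-T backbone directly and prints the feature-map shapes it produces. This is a minimal sketch: it assumes it is run from the `semantic_segmentation` directory with `torch`, `timm`, `mmcv` and `mmsegmentation` installed, and the depths shown follow the ResTv2-T setting.

```python
import torch
from backbone import ResTV2  # importing also registers ResTV2 in mmseg's BACKBONES registry

# ResTv2-T backbone; forward returns one feature map per entry of out_indices.
model = ResTV2(embed_dims=[96, 192, 384, 768], num_heads=[1, 2, 4, 8],
               depths=[1, 2, 6, 2], out_indices=(0, 1, 2, 3)).eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 512, 512))

# Strides 4/8/16/32 -> spatial sizes 128, 64, 32, 16 for a 512x512 input.
print([tuple(f.shape) for f in feats])
```

The UPerNet configs under [`configs/ResTv2`](configs/ResTv2) select this backbone via `backbone=dict(type='ResTV2', ...)`.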
6 | 7 | ## Results and Fine-tuned Models 8 | 9 | | name | Pretrained Model | Method | Crop Size | Lr Schd | mIoU (ms+flip) | #params | FLOPs | FPS | Fine-tuned Model | 10 | |:---:|:---:|:---:|:---:| :---:|:---:|:---:|:---:| :---:| :---:| 11 | | ResTv2-T | ImageNet-1K | UPerNet | 512x512 | 160K | 47.3 | 62.1M | 977G | 22.4 | [baidu](https://pan.baidu.com/s/1X-hAafTLFnwJPQSI2BNOKw) | 12 | | ResTv2-S | ImageNet-1K | UPerNet | 512x512 | 160K | 49.2 | 72.9M | 1035G | 20.0 | [baidu](https://pan.baidu.com/s/1WHiL0Rf9JeOB76yh6WOvLQ) | 13 | | ResTv2-B | ImageNet-1K | UPerNet | 512x512 | 160K | 49.6 | 87.6M | 1095G | 19.2 | [baidu](https://pan.baidu.com/s/1dtkg68j3vCU-dxJxl8VFdg) | 14 | 15 | ### Training 16 | 17 | Command format: 18 | ``` 19 | tools/dist_train.sh <CONFIG_FILE> <GPU_NUM> --work-dir <WORK_DIR> --seed 0 --deterministic --options model.pretrained=<PRETRAINED_MODEL> 20 | ``` 21 | 22 | For example, using a `ResTv2-T` backbone with UPerNet: 23 | ```bash 24 | bash tools/dist_train.sh \ 25 | configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py 8 \ 26 | --work-dir /path/to/save --seed 0 --deterministic \ 27 | --options model.pretrained=ResTv2_tiny_224.pth 28 | ``` 29 | 30 | More config files can be found at [`configs/ResTv2`](configs/ResTv2). 31 | 32 | 33 | ## Evaluation 34 | 35 | Command format: 36 | ``` 37 | tools/dist_test.sh <CONFIG_FILE> <CHECKPOINT_FILE> <GPU_NUM> --eval mIoU --aug-test 38 | ``` 39 | 40 | For example, evaluate a `ResTv2-T` backbone with UPerNet: 41 | ```bash 42 | bash tools/dist_test.sh configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py \ 43 | upernet_restv2_tiny_512_160k_ade20k.pth 4 --eval mIoU --aug-test 44 | ``` 45 | 46 | ## Acknowledgment 47 | 48 | This code is built on the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) and [Timm](https://github.com/rwightman/pytorch-image-models) libraries. -------------------------------------------------------------------------------- /semantic_segmentation/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .rest_v2 import ResTV2 2 | 3 | __all__ = ['ResTV2'] 4 | -------------------------------------------------------------------------------- /semantic_segmentation/backbone/rest_v2.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import warnings 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | 14 | from mmcv.runner import BaseModule, _load_checkpoint 15 | from mmseg.utils import get_root_logger 16 | from mmseg.models.builder import BACKBONES 17 | 18 | 19 | class Mlp(nn.Module): 20 | def __init__(self, dim): 21 | super().__init__() 22 | self.fc1 = nn.Linear(dim, 4 * dim) 23 | self.act = nn.GELU() 24 | self.fc2 = nn.Linear(4 * dim, dim) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.fc2(x) 30 | return x 31 | 32 | 33 | class Attention(nn.Module): 34 | def __init__(self, 35 | dim, 36 | num_heads=8, 37 | sr_ratio=1): 38 | super().__init__() 39 | self.dim = dim 40 | self.num_heads = num_heads 41 | head_dim = dim // num_heads 42 | self.scale = head_dim ** -0.5 43 | 44 | self.q = nn.Linear(dim, dim, bias=True) 45 | self.kv = nn.Linear(dim, dim * 2, bias=True) 46 | 47 | self.sr_ratio = sr_ratio 48 | if sr_ratio > 1: 49 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim) 50 | self.sr_norm = nn.LayerNorm(dim, eps=1e-6) 51 | 52 | self.up = nn.Sequential( 53 | nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1, padding=1, groups=dim), 54 | nn.PixelShuffle(upscale_factor=sr_ratio) 55 | ) 56 | self.up_norm = nn.LayerNorm(dim, eps=1e-6) 57 | 58 | self.proj = nn.Linear(dim, dim) 59 | 60 | def forward(self, x, H, W): 61 | B, N, C = x.shape 62 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 63 | if self.sr_ratio > 1: 64 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 65 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) 66 | x = self.sr_norm(x) 67 | 68 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 69 | k, v = kv[0], kv[1] 70 | attn = (q @ k.transpose(-2, -1)) * self.scale 71 | attn = attn.softmax(dim=-1) 72 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 73 | 74 | identity = v.transpose(-1, -2).reshape(B, C, H // self.sr_ratio, W // self.sr_ratio) 75 | identity = self.up(identity).flatten(2).transpose(1, 2) 76 | x = self.proj(x + self.up_norm(identity)) 77 | return x 78 | 79 | 80 | class Block(nn.Module): 81 | def __init__(self, dim, num_heads, sr_ratio=1, drop_path=0.): 82 | super().__init__() 83 | self.norm1 = nn.LayerNorm(dim, eps=1e-6) 84 | self.attn = Attention(dim, num_heads, sr_ratio) 85 | 86 | self.norm2 = nn.LayerNorm(dim, eps=1e-6) 87 | self.mlp = Mlp(dim) 88 | 89 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 90 | 91 | def forward(self, x, H, W): 92 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) # pre_norm 93 | x = x + self.drop_path(self.mlp(self.norm2(x))) 94 | return x 95 | 96 | 97 | class PA(nn.Module): 98 | def __init__(self, dim): 99 | super().__init__() 100 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=True) 101 | self.sigmoid = nn.Sigmoid() 102 | 103 | def forward(self, x): 104 | return x * self.sigmoid(self.pa_conv(x)) 105 | 106 | 107 | class ConvStem(nn.Module): 108 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True): 109 | super().__init__() 110 | self.patch_size = to_2tuple(patch_size) 111 | stem = [] 112 | in_dim, out_dim = in_ch, out_ch // 2 113 | for i in range(2): 114 | stem.append(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False)) 115 | stem.append(nn.BatchNorm2d(out_dim)) 116 | stem.append(nn.ReLU(inplace=True)) 117 | in_dim, out_dim = out_dim, out_dim * 2 118 | 119 | stem.append(nn.Conv2d(in_dim, out_ch, kernel_size=1, stride=1)) 120 | self.proj = nn.Sequential(*stem) 121 | 122 | self.with_pos = with_pos 123 | if self.with_pos: 124 | self.pos = PA(out_ch) 125 | 126 | self.norm = nn.LayerNorm(out_ch, eps=1e-6) 127 | 128 | def forward(self, x): 129 | B, C, H, W = x.shape 130 | x = self.proj(x) 131 | if self.with_pos: 132 | x = self.pos(x) 133 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 134 | x = self.norm(x) 135 | H, W = H // self.patch_size[0], W // self.patch_size[1] 136 | return x, (H, W) 137 | 138 | 139 | class PatchEmbed(nn.Module): 140 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True): 141 | super().__init__() 142 | self.patch_size = to_2tuple(patch_size) 143 | self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size, padding=patch_size // 2) 144 | 145 | self.with_pos = with_pos 146 | if self.with_pos: 147 | self.pos = PA(out_ch) 148 | 149 | self.norm = nn.LayerNorm(out_ch, eps=1e-6) 150 | 151 | def forward(self, x): 152 | B, C, H, W = x.shape 153 | x = self.proj(x) 154 | if self.with_pos: 155 | x = self.pos(x) 156 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 157 | x = self.norm(x) 158 | H, W = H // self.patch_size[0], W // self.patch_size[1] 159 | return x, (H, W) 160 | 161 | 162 | @BACKBONES.register_module() 163 | class ResTV2(BaseModule): 164 | def __init__(self, in_chans=3, embed_dims=[96, 192, 384, 768], 165 | num_heads=[1, 2, 4, 8], drop_path_rate=0., 166 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], out_indices=(0, 1, 2, 3), 167 | pretrained=None, init_cfg=None): 168 | super().__init__() 169 | if isinstance(pretrained, str) or pretrained is None: 170 | warnings.warn('DeprecationWarning: pretrained is a deprecated, please use "init_cfg" instead') 171 | else: 172 | raise TypeError('pretrained must be a str or None') 173 | self.pretrained = pretrained 174 | self.init_cfg = init_cfg 175 | 176 | self.depths = depths 177 | self.out_indices = out_indices 178 | 179 | self.stem = ConvStem(in_chans, embed_dims[0], patch_size=4) 180 | self.patch_2 = PatchEmbed(embed_dims[0], embed_dims[1], patch_size=2) 181 | self.patch_3 = PatchEmbed(embed_dims[1], embed_dims[2], patch_size=2) 182 | self.patch_4 = PatchEmbed(embed_dims[2], embed_dims[3], patch_size=2) 183 | 184 | # transformer encoder 185 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 186 | cur = 0 187 | self.stage1 = nn.ModuleList([ 188 | Block(embed_dims[0], num_heads[0], sr_ratios[0], dpr[cur + i]) 189 | for i in range(depths[0]) 190 
| ]) 191 | 192 | cur += depths[0] 193 | self.stage2 = nn.ModuleList([ 194 | Block(embed_dims[1], num_heads[1], sr_ratios[1], dpr[cur + i]) 195 | for i in range(depths[1]) 196 | ]) 197 | 198 | cur += depths[1] 199 | self.stage3 = nn.ModuleList([ 200 | Block(embed_dims[2], num_heads[2], sr_ratios[2], dpr[cur + i]) 201 | for i in range(depths[2]) 202 | ]) 203 | 204 | cur += depths[2] 205 | self.stage4 = nn.ModuleList([ 206 | Block(embed_dims[3], num_heads[3], sr_ratios[3], dpr[cur + i]) 207 | for i in range(depths[3]) 208 | ]) 209 | 210 | for idx in out_indices: 211 | out_ch = embed_dims[idx] 212 | layer = LayerNorm(out_ch, eps=1e-6, data_format="channels_first") 213 | layer_name = f"norm_{idx + 1}" 214 | self.add_module(layer_name, layer) 215 | 216 | def init_weights(self): 217 | if self.pretrained is None: 218 | super().init_weights() 219 | for m in self.modules(): 220 | if isinstance(m, (nn.Conv2d, nn.Linear)): 221 | trunc_normal_(m.weight, std=0.02) 222 | if m.bias is not None: 223 | nn.init.constant_(m.bias, 0) 224 | 225 | elif isinstance(self.pretrained, str): 226 | logger = get_root_logger() 227 | ckpt = _load_checkpoint(self.pretrained, logger=logger, map_location='cpu') 228 | if 'state_dict' in ckpt: 229 | state_dict = ckpt['state_dict'] 230 | elif 'model' in ckpt: 231 | state_dict = ckpt['model'] 232 | else: 233 | state_dict = ckpt 234 | 235 | # strip prefix of state_dict 236 | if list(state_dict.keys())[0].startswith('module.'): 237 | state_dict = {k[7:]: v for k, v in state_dict.items()} 238 | self.load_state_dict(state_dict, False) 239 | 240 | def forward(self, x): 241 | outs = [] 242 | B, _, H, W = x.shape 243 | x, (H, W) = self.stem(x) 244 | # stage 1 245 | for blk in self.stage1: 246 | x = blk(x, H, W) 247 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 248 | if 0 in self.out_indices: 249 | outs.append(self.norm_1(x)) 250 | 251 | # stage 2 252 | x, (H, W) = self.patch_2(x) 253 | for blk in self.stage2: 254 | x = blk(x, H, W) 255 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 256 | if 1 in self.out_indices: 257 | outs.append(self.norm_2(x)) 258 | 259 | # stage 3 260 | x, (H, W) = self.patch_3(x) 261 | for blk in self.stage3: 262 | x = blk(x, H, W) 263 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 264 | if 2 in self.out_indices: 265 | outs.append(self.norm_3(x)) 266 | 267 | # stage 4 268 | x, (H, W) = self.patch_4(x) 269 | for blk in self.stage4: 270 | x = blk(x, H, W) 271 | x = x.permute(0, 2, 1).reshape(B, -1, H, W) 272 | if 3 in self.out_indices: 273 | outs.append(self.norm_4(x)) 274 | 275 | return tuple(outs) 276 | 277 | 278 | class LayerNorm(nn.Module): 279 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 280 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 281 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 282 | with shape (batch_size, channels, height, width). 
283 | """ 284 | 285 | def __init__(self, dim, eps=1e-6, data_format="channels_last"): 286 | super().__init__() 287 | self.weight = nn.Parameter(torch.ones(dim)) 288 | self.bias = nn.Parameter(torch.zeros(dim)) 289 | self.eps = eps 290 | self.data_format = data_format 291 | if self.data_format not in ["channels_last", "channels_first"]: 292 | raise NotImplementedError 293 | self.dim = (dim,) 294 | 295 | def forward(self, x): 296 | if self.data_format == "channels_last": 297 | return F.layer_norm(x, self.dim, self.weight, self.bias, self.eps) 298 | elif self.data_format == "channels_first": 299 | u = x.mean(1, keepdim=True) 300 | s = (x - u).pow(2).mean(1, keepdim=True) 301 | x = (x - u) / torch.sqrt(s + self.eps) 302 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 303 | return x 304 | -------------------------------------------------------------------------------- /semantic_segmentation/configs/ResTv2/upernet_restv2_base_512_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | 8 | _base_ = [ 9 | '../_base_/models/upernet_restv2.py', '../_base_/datasets/ade20k.py', 10 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 11 | ] 12 | crop_size = (512, 512) 13 | 14 | model = dict( 15 | pretrained='./restv2_base.pth', 16 | backbone=dict( 17 | type='ResTV2', 18 | in_chans=3, 19 | num_heads=[1, 2, 4, 8], 20 | embed_dims=[96, 192, 384, 768], 21 | depths=[1, 3, 16, 3], 22 | drop_path_rate=0.3, 23 | out_indices=[0, 1, 2, 3], 24 | ), 25 | decode_head=dict( 26 | in_channels=[96, 192, 384, 768], 27 | num_classes=150, 28 | ), 29 | auxiliary_head=dict( 30 | in_channels=384, 31 | num_classes=150 32 | ), 33 | ) 34 | 35 | optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW', 36 | lr=0.00015, betas=(0.9, 0.999), weight_decay=0.05, 37 | paramwise_cfg={'decay_rate': 0.9, 38 | 'decay_type': 'stage_wise', 39 | 'num_layers': (1, 3, 16, 3)}) 40 | 41 | lr_config = dict(_delete_=True, policy='poly', 42 | warmup='linear', 43 | warmup_iters=1500, 44 | warmup_ratio=1e-6, 45 | power=1.0, min_lr=0.0, by_epoch=False) 46 | 47 | # By default, models are trained on 8 GPUs with 2 images per GPU 48 | data = dict(samples_per_gpu=2) 49 | -------------------------------------------------------------------------------- /semantic_segmentation/configs/ResTv2/upernet_restv2_small_512_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 
3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | _base_ = [ 8 | '../_base_/models/upernet_restv2.py', '../_base_/datasets/ade20k.py', 9 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 10 | ] 11 | crop_size = (512, 512) 12 | 13 | model = dict( 14 | pretrained='./restv2_small.pth', 15 | backbone=dict( 16 | type='ResTV2', 17 | in_chans=3, 18 | num_heads=[1, 2, 4, 8], 19 | embed_dims=[96, 192, 384, 768], 20 | depths=[1, 2, 12, 2], 21 | drop_path_rate=0.2, 22 | out_indices=[0, 1, 2, 3], 23 | ), 24 | decode_head=dict( 25 | in_channels=[96, 192, 384, 768], 26 | num_classes=150, 27 | ), 28 | auxiliary_head=dict( 29 | in_channels=384, 30 | num_classes=150 31 | ), 32 | ) 33 | 34 | optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW', 35 | lr=0.00015, betas=(0.9, 0.999), weight_decay=0.05, 36 | paramwise_cfg={'decay_rate': 0.9, 37 | 'decay_type': 'stage_wise', 38 | 'num_layers': (1, 2, 12, 2)}) 39 | 40 | lr_config = dict(_delete_=True, policy='poly', 41 | warmup='linear', 42 | warmup_iters=1500, 43 | warmup_ratio=1e-6, 44 | power=1.0, min_lr=0.0, by_epoch=False) 45 | 46 | # By default, models are trained on 8 GPUs with 2 images per GPU 47 | data = dict(samples_per_gpu=2) 48 | -------------------------------------------------------------------------------- /semantic_segmentation/configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | 8 | _base_ = [ 9 | '../_base_/models/upernet_restv2.py', '../_base_/datasets/ade20k.py', 10 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' 11 | ] 12 | crop_size = (512, 512) 13 | 14 | model = dict( 15 | pretrained="", 16 | backbone=dict( 17 | type='ResTV2', 18 | in_chans=3, 19 | num_heads=[1, 2, 4, 8], 20 | embed_dims=[96, 192, 384, 768], 21 | depths=[1, 2, 6, 2], 22 | drop_path_rate=0.1, 23 | out_indices=(0, 1, 2, 3), 24 | ), 25 | decode_head=dict( 26 | in_channels=[96, 192, 384, 768], 27 | num_classes=150, 28 | ), 29 | auxiliary_head=dict( 30 | in_channels=384, 31 | num_classes=150 32 | ), 33 | ) 34 | 35 | optimizer = dict( 36 | _delete_=True, 37 | type='AdamW', 38 | lr=0.00006, 39 | betas=(0.9, 0.999), 40 | weight_decay=0.01, 41 | paramwise_cfg=dict( 42 | custom_keys={ 43 | 'absolute_pos_embed': dict(decay_mult=0.), 44 | 'relative_position_bias_table': dict(decay_mult=0.), 45 | 'norm': dict(decay_mult=0.) 
46 | })) 47 | 48 | lr_config = dict( 49 | _delete_=True, 50 | policy='poly', 51 | warmup='linear', 52 | warmup_iters=1500, 53 | warmup_ratio=1e-6, 54 | power=1.0, 55 | min_lr=0.0, 56 | by_epoch=False) 57 | 58 | # By default, models are trained on 8 GPUs with 2 images per GPU 59 | data = dict(samples_per_gpu=2) 60 | -------------------------------------------------------------------------------- /semantic_segmentation/configs/_base_/models/upernet_restv2.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | backbone_norm_cfg = dict(type='LN', requires_grad=True) 4 | model = dict( 5 | type='EncoderDecoder', 6 | pretrained=None, 7 | backbone=dict( 8 | type='ResTV2', 9 | in_chans=3, 10 | num_heads=[1, 2, 4, 8], 11 | embed_dims=[96, 192, 384, 768], 12 | depths=[1, 2, 6, 2], 13 | drop_path_rate=0.2, 14 | out_indices=[0, 1, 2, 3], 15 | ), 16 | decode_head=dict( 17 | type='UPerHead', 18 | in_channels=[96, 192, 384, 768], 19 | in_index=[0, 1, 2, 3], 20 | pool_scales=(1, 2, 3, 6), 21 | channels=512, 22 | dropout_ratio=0.1, 23 | num_classes=19, 24 | norm_cfg=norm_cfg, 25 | align_corners=False, 26 | loss_decode=dict( 27 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 28 | auxiliary_head=dict( 29 | type='FCNHead', 30 | in_channels=384, 31 | in_index=2, 32 | channels=256, 33 | num_convs=1, 34 | concat_input=False, 35 | dropout_ratio=0.1, 36 | num_classes=19, 37 | norm_cfg=norm_cfg, 38 | align_corners=False, 39 | loss_decode=dict( 40 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 41 | # model training and testing settings 42 | train_cfg=dict(), 43 | test_cfg=dict(mode='whole')) 44 | -------------------------------------------------------------------------------- /semantic_segmentation/tools/align_resize.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | import mmcv 8 | import numpy as np 9 | from mmcv.utils import deprecated_api_warning, is_tuple_of 10 | from numpy import random 11 | 12 | from mmseg.datasets.builder import PIPELINES 13 | 14 | @PIPELINES.register_module() 15 | class AlignResize(object): 16 | """Resize images & seg. Align 17 | """ 18 | 19 | def __init__(self, 20 | img_scale=None, 21 | multiscale_mode='range', 22 | ratio_range=None, 23 | keep_ratio=True, 24 | size_divisor=32): 25 | if img_scale is None: 26 | self.img_scale = None 27 | else: 28 | if isinstance(img_scale, list): 29 | self.img_scale = img_scale 30 | else: 31 | self.img_scale = [img_scale] 32 | assert mmcv.is_list_of(self.img_scale, tuple) 33 | 34 | if ratio_range is not None: 35 | # mode 1: given img_scale=None and a range of image ratio 36 | # mode 2: given a scale and a range of image ratio 37 | assert self.img_scale is None or len(self.img_scale) == 1 38 | else: 39 | # mode 3 and 4: given multiple scales or a range of scales 40 | assert multiscale_mode in ['value', 'range'] 41 | 42 | self.multiscale_mode = multiscale_mode 43 | self.ratio_range = ratio_range 44 | self.keep_ratio = keep_ratio 45 | self.size_divisor = size_divisor 46 | 47 | @staticmethod 48 | def random_select(img_scales): 49 | """Randomly select an img_scale from given candidates. 
50 | 51 | Args: 52 | img_scales (list[tuple]): Image scales for selection. 53 | 54 | Returns: 55 | (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, 56 | where ``img_scale`` is the selected image scale and 57 | ``scale_idx`` is the selected index in the given candidates. 58 | """ 59 | 60 | assert mmcv.is_list_of(img_scales, tuple) 61 | scale_idx = np.random.randint(len(img_scales)) 62 | img_scale = img_scales[scale_idx] 63 | return img_scale, scale_idx 64 | 65 | @staticmethod 66 | def random_sample(img_scales): 67 | """Randomly sample an img_scale when ``multiscale_mode=='range'``. 68 | 69 | Args: 70 | img_scales (list[tuple]): Image scale range for sampling. 71 | There must be two tuples in img_scales, which specify the lower 72 | and upper bound of image scales. 73 | 74 | Returns: 75 | (tuple, None): Returns a tuple ``(img_scale, None)``, where 76 | ``img_scale`` is the sampled scale and None is just a placeholder 77 | to be consistent with :func:`random_select`. 78 | """ 79 | 80 | assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2 81 | img_scale_long = [max(s) for s in img_scales] 82 | img_scale_short = [min(s) for s in img_scales] 83 | long_edge = np.random.randint( 84 | min(img_scale_long), 85 | max(img_scale_long) + 1) 86 | short_edge = np.random.randint( 87 | min(img_scale_short), 88 | max(img_scale_short) + 1) 89 | img_scale = (long_edge, short_edge) 90 | return img_scale, None 91 | 92 | @staticmethod 93 | def random_sample_ratio(img_scale, ratio_range): 94 | """Randomly sample an img_scale when ``ratio_range`` is specified. 95 | 96 | A ratio will be randomly sampled from the range specified by 97 | ``ratio_range``. Then it would be multiplied with ``img_scale`` to 98 | generate the sampled scale. 99 | 100 | Args: 101 | img_scale (tuple): Image scale base to multiply with the ratio. 102 | ratio_range (tuple[float]): The minimum and maximum ratio to scale 103 | the ``img_scale``. 104 | 105 | Returns: 106 | (tuple, None): Returns a tuple ``(scale, None)``, where 107 | ``scale`` is the sampled ratio multiplied with ``img_scale`` and 108 | None is just a placeholder to be consistent with 109 | :func:`random_select`. 110 | """ 111 | 112 | assert isinstance(img_scale, tuple) and len(img_scale) == 2 113 | min_ratio, max_ratio = ratio_range 114 | assert min_ratio <= max_ratio 115 | ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio 116 | scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio) 117 | return scale, None 118 | 119 | def _random_scale(self, results): 120 | """Randomly sample an img_scale according to ``ratio_range`` and 121 | ``multiscale_mode``. 122 | 123 | If ``ratio_range`` is specified, a ratio will be sampled and 124 | multiplied with ``img_scale``. 125 | If multiple scales are specified by ``img_scale``, a scale will be 126 | sampled according to ``multiscale_mode``. 127 | Otherwise, a single scale will be used. 128 | 129 | Args: 130 | results (dict): Result dict from :obj:`dataset`. 131 | 132 | Returns: 133 | dict: Two new keys ``scale`` and ``scale_idx`` are added into 134 | ``results``, which would be used by subsequent pipelines.
135 | """ 136 | 137 | if self.ratio_range is not None: 138 | if self.img_scale is None: 139 | h, w = results['img'].shape[:2] 140 | scale, scale_idx = self.random_sample_ratio((w, h), 141 | self.ratio_range) 142 | else: 143 | scale, scale_idx = self.random_sample_ratio( 144 | self.img_scale[0], self.ratio_range) 145 | elif len(self.img_scale) == 1: 146 | scale, scale_idx = self.img_scale[0], 0 147 | elif self.multiscale_mode == 'range': 148 | scale, scale_idx = self.random_sample(self.img_scale) 149 | elif self.multiscale_mode == 'value': 150 | scale, scale_idx = self.random_select(self.img_scale) 151 | else: 152 | raise NotImplementedError 153 | 154 | results['scale'] = scale 155 | results['scale_idx'] = scale_idx 156 | 157 | def _align(self, img, size_divisor, interpolation=None): 158 | align_h = int(np.ceil(img.shape[0] / size_divisor)) * size_divisor 159 | align_w = int(np.ceil(img.shape[1] / size_divisor)) * size_divisor 160 | if interpolation is None: 161 | img = mmcv.imresize(img, (align_w, align_h)) 162 | else: 163 | img = mmcv.imresize(img, (align_w, align_h), interpolation=interpolation) 164 | return img 165 | 166 | def _resize_img(self, results): 167 | """Resize images with ``results['scale']``.""" 168 | if self.keep_ratio: 169 | img, scale_factor = mmcv.imrescale( 170 | results['img'], results['scale'], return_scale=True) 171 | #### align #### 172 | img = self._align(img, self.size_divisor) 173 | # the w_scale and h_scale have a minor difference 174 | # a real fix should be done in mmcv.imrescale in the future 175 | new_h, new_w = img.shape[:2] 176 | h, w = results['img'].shape[:2] 177 | w_scale = new_w / w 178 | h_scale = new_h / h 179 | else: 180 | img, w_scale, h_scale = mmcv.imresize( 181 | results['img'], results['scale'], return_scale=True) 182 | 183 | h, w = img.shape[:2] 184 | assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \ 185 | int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \ 186 | "img size not aligned. h:{} w:{}".format(h, w) 187 | scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], 188 | dtype=np.float32) 189 | results['img'] = img 190 | results['img_shape'] = img.shape 191 | results['pad_shape'] = img.shape # in case that there is no padding 192 | results['scale_factor'] = scale_factor 193 | results['keep_ratio'] = self.keep_ratio 194 | 195 | def _resize_seg(self, results): 196 | """Resize semantic segmentation map with ``results['scale']``.""" 197 | for key in results.get('seg_fields', []): 198 | if self.keep_ratio: 199 | gt_seg = mmcv.imrescale( 200 | results[key], results['scale'], interpolation='nearest') 201 | gt_seg = self._align(gt_seg, self.size_divisor, interpolation='nearest') 202 | else: 203 | gt_seg = mmcv.imresize( 204 | results[key], results['scale'], interpolation='nearest') 205 | h, w = gt_seg.shape[:2] 206 | assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \ 207 | int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \ 208 | "gt_seg size not aligned. h:{} w:{}".format(h, w) 209 | results[key] = gt_seg 210 | 211 | def __call__(self, results): 212 | """Call function to resize images and semantic segmentation maps, then 213 | align both to a multiple of ``size_divisor``. 214 | 215 | Args: 216 | results (dict): Result dict from loading pipeline. 217 | 218 | Returns: 219 | dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', 220 | 'keep_ratio' keys are added into result dict.
221 | """ 222 | 223 | if 'scale' not in results: 224 | self._random_scale(results) 225 | self._resize_img(results) 226 | self._resize_seg(results) 227 | return results 228 | 229 | def __repr__(self): 230 | repr_str = self.__class__.__name__ 231 | repr_str += (f'(img_scale={self.img_scale}, ' 232 | f'multiscale_mode={self.multiscale_mode}, ' 233 | f'ratio_range={self.ratio_range}, ' 234 | f'keep_ratio={self.keep_ratio})') 235 | return repr_str 236 | -------------------------------------------------------------------------------- /semantic_segmentation/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} -------------------------------------------------------------------------------- /semantic_segmentation/tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29500} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} -------------------------------------------------------------------------------- /semantic_segmentation/tools/train.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------ 2 | # Copyright (c) VCU, Nanjing University. 3 | # Licensed under the Apache License 2.0 [see LICENSE for details] 4 | # Written by Qing-Long Zhang 5 | # ------------------------------------------------------------ 6 | 7 | 8 | import argparse 9 | import copy 10 | import os 11 | import os.path as osp 12 | import time 13 | import warnings 14 | 15 | import mmcv 16 | import torch 17 | import torch.distributed as dist 18 | from mmcv.cnn.utils import revert_sync_batchnorm 19 | from mmcv.runner import get_dist_info, init_dist 20 | from mmcv.utils import Config, DictAction, get_git_hash 21 | 22 | from mmseg import __version__ 23 | from mmseg.apis import init_random_seed, set_random_seed, train_segmentor 24 | from mmseg.datasets import build_dataset 25 | from mmseg.models import build_segmentor 26 | from mmseg.utils import collect_env, get_root_logger, setup_multi_processes 27 | 28 | from align_resize import AlignResize 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser(description='Train a segmentor') 32 | parser.add_argument('config', help='train config file path') 33 | parser.add_argument('--work-dir', help='the dir to save logs and models') 34 | parser.add_argument( 35 | '--load-from', help='the checkpoint file to load weights from') 36 | parser.add_argument( 37 | '--resume-from', help='the checkpoint file to resume from') 38 | parser.add_argument( 39 | '--no-validate', 40 | action='store_true', 41 | help='whether not to evaluate the checkpoint during training') 42 | group_gpus = parser.add_mutually_exclusive_group() 43 | group_gpus.add_argument( 44 | '--gpus', 45 | type=int, 46 | help='(Deprecated, please use --gpu-id) number of gpus to use ' 47 | '(only applicable to non-distributed training)') 48 | group_gpus.add_argument( 49 | '--gpu-ids', 50 | type=int, 51 | nargs='+', 52 | help='(Deprecated, 
please use --gpu-id) ids of gpus to use ' 53 | '(only applicable to non-distributed training)') 54 | group_gpus.add_argument( 55 | '--gpu-id', 56 | type=int, 57 | default=0, 58 | help='id of gpu to use ' 59 | '(only applicable to non-distributed training)') 60 | parser.add_argument('--seed', type=int, default=None, help='random seed') 61 | parser.add_argument( 62 | '--diff_seed', 63 | action='store_true', 64 | help='Whether or not set different seeds for different ranks') 65 | parser.add_argument( 66 | '--deterministic', 67 | action='store_true', 68 | help='whether to set deterministic options for CUDNN backend.') 69 | parser.add_argument( 70 | '--options', 71 | nargs='+', 72 | action=DictAction, 73 | help="--options is deprecated in favor of --cfg_options' and it will " 74 | 'not be supported in version v0.22.0. Override some settings in the ' 75 | 'used config, the key-value pair in xxx=yyy format will be merged ' 76 | 'into config file. If the value to be overwritten is a list, it ' 77 | 'should be like key="[a,b]" or key=a,b It also allows nested ' 78 | 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' 79 | 'marks are necessary and that no white space is allowed.') 80 | parser.add_argument( 81 | '--cfg-options', 82 | nargs='+', 83 | action=DictAction, 84 | help='override some settings in the used config, the key-value pair ' 85 | 'in xxx=yyy format will be merged into config file. If the value to ' 86 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 87 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 88 | 'Note that the quotation marks are necessary and that no white space ' 89 | 'is allowed.') 90 | parser.add_argument( 91 | '--launcher', 92 | choices=['none', 'pytorch', 'slurm', 'mpi'], 93 | default='none', 94 | help='job launcher') 95 | parser.add_argument('--local_rank', type=int, default=0) 96 | parser.add_argument( 97 | '--auto-resume', 98 | action='store_true', 99 | help='resume from the latest checkpoint automatically.') 100 | args = parser.parse_args() 101 | if 'LOCAL_RANK' not in os.environ: 102 | os.environ['LOCAL_RANK'] = str(args.local_rank) 103 | 104 | if args.options and args.cfg_options: 105 | raise ValueError( 106 | '--options and --cfg-options cannot be both ' 107 | 'specified, --options is deprecated in favor of --cfg-options. ' 108 | '--options will not be supported in version v0.22.0.') 109 | if args.options: 110 | warnings.warn('--options is deprecated in favor of --cfg-options. 
' 111 | '--options will not be supported in version v0.22.0.') 112 | args.cfg_options = args.options 113 | 114 | return args 115 | 116 | 117 | def main(): 118 | args = parse_args() 119 | 120 | cfg = Config.fromfile(args.config) 121 | if args.cfg_options is not None: 122 | cfg.merge_from_dict(args.cfg_options) 123 | 124 | # set cudnn_benchmark 125 | if cfg.get('cudnn_benchmark', False): 126 | torch.backends.cudnn.benchmark = True 127 | 128 | # work_dir is determined in this priority: CLI > segment in file > filename 129 | if args.work_dir is not None: 130 | # update configs according to CLI args if args.work_dir is not None 131 | cfg.work_dir = args.work_dir 132 | elif cfg.get('work_dir', None) is None: 133 | # use config filename as default work_dir if cfg.work_dir is None 134 | cfg.work_dir = osp.join('./work_dirs', 135 | osp.splitext(osp.basename(args.config))[0]) 136 | if args.load_from is not None: 137 | cfg.load_from = args.load_from 138 | if args.resume_from is not None: 139 | cfg.resume_from = args.resume_from 140 | if args.gpus is not None: 141 | cfg.gpu_ids = range(1) 142 | warnings.warn('`--gpus` is deprecated because we only support ' 143 | 'single GPU mode in non-distributed training. ' 144 | 'Use `gpus=1` now.') 145 | if args.gpu_ids is not None: 146 | cfg.gpu_ids = args.gpu_ids[0:1] 147 | warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' 148 | 'Because we only support single GPU mode in ' 149 | 'non-distributed training. Use the first GPU ' 150 | 'in `gpu_ids` now.') 151 | if args.gpus is None and args.gpu_ids is None: 152 | cfg.gpu_ids = [args.gpu_id] 153 | 154 | cfg.auto_resume = args.auto_resume 155 | 156 | # init distributed env first, since logger depends on the dist info. 157 | if args.launcher == 'none': 158 | distributed = False 159 | else: 160 | distributed = True 161 | init_dist(args.launcher, **cfg.dist_params) 162 | # gpu_ids is used to calculate iter when resuming checkpoint 163 | _, world_size = get_dist_info() 164 | cfg.gpu_ids = range(world_size) 165 | 166 | # create work_dir 167 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 168 | # dump config 169 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 170 | # init the logger before other steps 171 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 172 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 173 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 174 | 175 | # set multi-process settings 176 | setup_multi_processes(cfg) 177 | 178 | # init the meta dict to record some important information such as 179 | # environment info and seed, which will be logged 180 | meta = dict() 181 | # log env info 182 | env_info_dict = collect_env() 183 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 184 | dash_line = '-' * 60 + '\n' 185 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 186 | dash_line) 187 | meta['env_info'] = env_info 188 | 189 | # log some basic info 190 | logger.info(f'Distributed training: {distributed}') 191 | logger.info(f'Config:\n{cfg.pretty_text}') 192 | 193 | # set random seeds 194 | seed = init_random_seed(args.seed) 195 | seed = seed + dist.get_rank() if args.diff_seed else seed 196 | logger.info(f'Set random seed to {seed}, ' 197 | f'deterministic: {args.deterministic}') 198 | set_random_seed(seed, deterministic=args.deterministic) 199 | cfg.seed = seed 200 | meta['seed'] = seed 201 | meta['exp_name'] = osp.basename(args.config) 202 | 203 | model = build_segmentor( 204 | cfg.model, 205 | 
train_cfg=cfg.get('train_cfg'), 206 | test_cfg=cfg.get('test_cfg')) 207 | model.init_weights() 208 | 209 | # SyncBN is not supported for DP 210 | if not distributed: 211 | warnings.warn( 212 | 'SyncBN is only supported with DDP. To be compatible with DP, ' 213 | 'we convert SyncBN to BN. Please use dist_train.sh which can ' 214 | 'avoid this error.') 215 | model = revert_sync_batchnorm(model) 216 | 217 | logger.info(model) 218 | 219 | datasets = [build_dataset(cfg.data.train)] 220 | if len(cfg.workflow) == 2: 221 | val_dataset = copy.deepcopy(cfg.data.val) 222 | val_dataset.pipeline = cfg.data.train.pipeline 223 | datasets.append(build_dataset(val_dataset)) 224 | if cfg.checkpoint_config is not None: 225 | # save mmseg version, config file content and class names in 226 | # checkpoints as meta data 227 | cfg.checkpoint_config.meta = dict( 228 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}', 229 | config=cfg.pretty_text, 230 | CLASSES=datasets[0].CLASSES, 231 | PALETTE=datasets[0].PALETTE) 232 | # add an attribute for visualization convenience 233 | model.CLASSES = datasets[0].CLASSES 234 | # passing checkpoint meta for saving best checkpoint 235 | meta.update(cfg.checkpoint_config.meta) 236 | train_segmentor( 237 | model, 238 | datasets, 239 | cfg, 240 | distributed=distributed, 241 | validate=(not args.no_validate), 242 | timestamp=timestamp, 243 | meta=meta) 244 | 245 | 246 | if __name__ == '__main__': 247 | main() --------------------------------------------------------------------------------
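The snippet below is an editor-added, minimal sanity-check sketch and is not a file from this repository. It assumes PyTorch and mmcv/mmseg are installed and that it is run from the semantic_segmentation/ directory so that backbone.rest_v2 is importable; the file name restv2_shape_check.py is hypothetical, and the hyper-parameters mirror the ResTv2-Tiny settings in configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py.

# restv2_shape_check.py -- illustrative sketch only; not part of the repository.
# Assumes PyTorch and mmcv/mmseg are installed and this is run from semantic_segmentation/.
import torch

from backbone.rest_v2 import ResTV2

# ResTv2-Tiny hyper-parameters, mirroring upernet_restv2_tiny_512_160k_ade20k.py.
backbone = ResTV2(in_chans=3,
                  embed_dims=[96, 192, 384, 768],
                  num_heads=[1, 2, 4, 8],
                  depths=[1, 2, 6, 2],
                  drop_path_rate=0.1,
                  out_indices=(0, 1, 2, 3))
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 512, 512))

# ConvStem downsamples by 4 and each PatchEmbed by 2, so the four outputs should
# arrive at strides 4/8/16/32, i.e. shapes (1, 96, 128, 128), (1, 192, 64, 64),
# (1, 384, 32, 32) and (1, 768, 16, 16) -- matching the UPerNet in_channels above.
for feat in feats:
    print(tuple(feat.shape))

The same check applies to the Small and Base variants by swapping in the depths and drop_path_rate values from their respective configs.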