├── .gitignore
├── INSTALL.md
├── LICENSE
├── README.md
├── TRAINING.md
├── datasets.py
├── engine.py
├── figures
│   └── fig_1.png
├── fourier_analysis.py
├── get_flops.py
├── main.py
├── models
│   ├── __init__.py
│   ├── rest.py
│   └── rest_v2.py
├── object_detection
│   ├── GETTING_STARTED.md
│   ├── README.md
│   ├── analyze_model.py
│   ├── configs
│   │   ├── Base-RCNN-FPN.yaml
│   │   ├── Base-RetinaNet.yaml
│   │   ├── ResTv1
│   │   │   ├── mask_rcnn_rest_base_FPN_1x.yaml
│   │   │   ├── mask_rcnn_rest_small_FPN_1x.yaml
│   │   │   ├── retinanet_rest_base_FPN_1x.yaml
│   │   │   └── retinanet_rest_small_FPN_1x.yaml
│   │   └── ResTv2
│   │       ├── mask_rcnn_rest_base_FPN_1x.yaml
│   │       ├── mask_rcnn_rest_small_FPN_1x.yaml
│   │       ├── mask_rcnn_restv2_base_FPN_3x.yaml
│   │       ├── mask_rcnn_restv2_small_FPN_3x.yaml
│   │       ├── mask_rcnn_restv2_tiny_FPN_3x.yaml
│   │       └── retinanet_restv2_tiny_FPN_3x.yaml
│   ├── convert_to_d2.py
│   ├── datasets
│   │   ├── README.md
│   │   ├── prepare_ade20k_sem_seg.py
│   │   ├── prepare_cocofied_lvis.py
│   │   ├── prepare_for_tests.sh
│   │   └── prepare_panoptic_fpn.py
│   ├── restv2
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── rest.py
│   │   └── restv2.py
│   ├── train_net.py
│   └── utils.py
├── optim_factory.py
├── requirements.txt
├── run_with_submitit.py
├── semantic_segmentation
│   ├── README.md
│   ├── backbone
│   │   ├── __init__.py
│   │   └── rest_v2.py
│   ├── configs
│   │   ├── ResTv2
│   │   │   ├── upernet_restv2_base_512_160k_ade20k.py
│   │   │   ├── upernet_restv2_small_512_160k_ade20k.py
│   │   │   └── upernet_restv2_tiny_512_160k_ade20k.py
│   │   └── _base_
│   │       └── models
│   │           └── upernet_restv2.py
│   └── tools
│       ├── align_resize.py
│       ├── dist_test.sh
│       ├── dist_train.sh
│       └── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # output dir
2 | outputs
3 | instant_test_output
4 | inference_test_output
5 |
6 | *.png
7 | *.json
8 | *.diff
9 | *.jpg
10 | !/projects/DensePose/doc/images/*.jpg
11 |
12 | # compilation and distribution
13 | __pycache__
14 | _ext
15 | *.pyc
16 | *.pyd
17 | *.so
18 | *.dll
19 | *.egg-info/
20 | build/
21 | dist/
22 | wheels/
23 |
24 | # pytorch/python/numpy formats
25 | *.pth
26 | *.pkl
27 | *.npy
28 | *.ts
29 | model_ts*.txt
30 |
31 | # ipython/jupyter notebooks
32 | *.ipynb
33 | **/.ipynb_checkpoints/
34 |
35 | # Editor temporaries
36 | *.swn
37 | *.swo
38 | *.swp
39 | *~
40 |
41 | # editor settings
42 | .idea
43 | .vscode
44 | _darcs
45 |
46 | # project dirs
47 | /detectron2/model_zoo/configs
48 | /datasets/*
49 | !/datasets/*.*
50 | /projects/*/datasets
51 | /models
52 | /snippet
--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | We provide installation instructions for ImageNet classification experiments here.
4 |
5 | ## Dependency Setup
6 | Create a new conda virtual environment
7 | ```
8 | conda create -n rest python=3.9 -y
9 | conda activate rest
10 | ```
11 |
12 | Install [PyTorch](https://pytorch.org/) >= 1.8.0 and [torchvision](https://pytorch.org/vision/stable/index.html) >= 0.9.0 following the official instructions. For example:
13 | ```
14 | pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
15 | ```
16 |
17 | Clone this repo and install required packages:
18 | ```
19 | pip install timm==0.5.4 tensorboardX six
20 | ```
21 |
22 | The results in the paper are generated with `torch==1.8.0+cu111 torchvision==0.9.0+cu111 timm==0.5.4`.
23 |
24 | ## Dataset Preparation
25 |
26 | Download the [ImageNet-1K](http://image-net.org/) classification dataset and structure the data as follows:
27 | ```
28 | /path/to/imagenet-1k/
29 | train/
30 | class1/
31 | img1.jpeg
32 | class2/
33 | img2.jpeg
34 | val/
35 | class1/
36 | img3.jpeg
37 | class2/
38 | img4.jpeg
39 | ```
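
A quick way to sanity-check the layout (a minimal sketch; replace `/path/to/imagenet-1k` with your actual path — `torchvision` is installed together with PyTorch above):

```
from torchvision import datasets

for split in ("train", "val"):
    ds = datasets.ImageFolder(f"/path/to/imagenet-1k/{split}")
    print(split, len(ds), "images,", len(ds.classes), "classes")  # expect 1000 classes per split
```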
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2020 - present, Facebook, Inc
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Updates
2 | - (2022/05/10) Code of [ResTv2](https://arxiv.org/abs/2204.07366) is released! ResTv2 simplifies the EMSA structure of
3 | [ResTv1](https://arxiv.org/abs/2105.13677) (i.e., it eliminates the multi-head interaction part) and employs an upsample
4 | operation to reconstruct the medium- and high-frequency information lost in the downsampling operation; see the sketch below.
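
A rough sketch of this upsample branch, mirroring the `Attention` module in `models/rest_v2.py` (the sizes below are illustrative, not prescribed):

```
import torch
import torch.nn as nn

dim, sr_ratio = 96, 8  # illustrative: stage-1 width and reduction ratio of ResTv2-T
sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio,
               padding=sr_ratio // 2, groups=dim)          # EMSA downsampling of the k/v tokens
up = nn.Sequential(                                         # upsample branch
    nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1, padding=1, groups=dim),
    nn.PixelShuffle(upscale_factor=sr_ratio),
)

x = torch.randn(1, dim, 56, 56)   # a stage-1 feature map for a 224x224 input (B, C, H, W)
v_low = sr(x)                     # (1, 96, 7, 7): the tokens from which k/v are computed
x_rec = up(v_low)                 # (1, 96, 56, 56): reconstructed full-resolution signal
print(v_low.shape, x_rec.shape)
```

In the actual module, the reconstructed signal is added back to the attention output before the final projection.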
5 |
6 | # [ResT: An Efficient Transformer for Visual Recognition](https://arxiv.org/abs/2105.13677)
7 |
8 | Official PyTorch implementation of **ResTv1** and **ResTv2**, from the following paper:
9 |
10 | [ResT: An Efficient Transformer for Visual Recognition](https://arxiv.org/abs/2105.13677). NeurIPS 2021.\
11 | [ResT V2: Simpler, Faster and Stronger](https://arxiv.org/abs/2204.07366). NeurIPS 2022.\
12 | By Qing-Long Zhang and Yu-Bin Yang \
13 | State Key Laboratory for Novel Software Technology at Nanjing University
14 |
15 | ---
16 |
17 |
18 |
19 | ![fig_1](figures/fig_1.png)
20 |
21 |
22 | **ResTv1**, initially described in [arXiv](https://arxiv.org/abs/2105.13677), capably serves as a
23 | general-purpose backbone for computer vision and can handle input images of arbitrary size. In addition,
24 | ResT compresses the memory footprint of standard MSA and models the interaction between the multiple heads
25 | while preserving their diversity, as sketched below.
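
A rough illustration of that head-interaction step (it mirrors the `apply_transform` branch of `Attention` in `models/rest.py`; the head and token counts below are illustrative):

```
import torch
import torch.nn as nn

num_heads, N = 8, 49                      # e.g., stage 4 of a 224x224 input: 7x7 tokens, 8 heads
attn = torch.randn(1, num_heads, N, N)    # raw attention logits (B, heads, N, N)

# a 1x1 convolution mixes the attention maps across heads; the instance norm
# afterwards restores per-head diversity
transform_conv = nn.Conv2d(num_heads, num_heads, kernel_size=1)
transform_norm = nn.InstanceNorm2d(num_heads)

attn = transform_norm(transform_conv(attn).softmax(dim=-1))
print(attn.shape)                         # torch.Size([1, 8, 49, 49])
```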
26 |
27 | ## Catalog
28 | - [x] ImageNet-1K Training Code
29 | - [x] ImageNet-1K Fine-tuning Code
30 | - [x] Downstream Transfer (Detection, Segmentation) Code
31 |
32 |
33 |
34 | ## Results and Pre-trained Models
35 | ### ImageNet-1K trained models
36 |
37 | | name | resolution | acc@1 | #params | FLOPs | Throughput (images/s) | model |
38 | |:-----------:|:---:|:---:|:-------:|:-----:|:----------:|:---:|
39 | | ResTv1-Lite | 224x224 | 77.2 | 11M | 1.4G | 1246 | [baidu](https://pan.baidu.com/s/1VVzrzZi_tD3yTp_lw9tU9A)
40 | | ResTv1-S | 224x224 | 79.6 | 14M | 1.9G | 1043 | [baidu](https://pan.baidu.com/s/1Y-MIzzzcQnmrbHfGGR0mrw)
41 | | ResTv1-B | 224x224 | 81.6 | 30M | 4.3G | 673 | [baidu](https://pan.baidu.com/s/1HhR9YxtGIhouZ0GEA4LYlw)
42 | | ResTv1-L | 224x224 | 83.6 | 52M | 7.9G | 429 | [baidu](https://pan.baidu.com/s/14c4u_oRoBcKOt1aTlsBBpw)
43 | | ResTv2-T | 224x224 | 82.3 | 30M | 4.1G | 826 | [baidu](https://pan.baidu.com/s/1LHAbsrXnGsjvAE3d5zhaHQ) |
44 | | ResTv2-T | 384x384 | 83.7 | 30M | 12.7G | 319 | [baidu](https://pan.baidu.com/s/1fEMs_OrDa_xF7Cw1DiBU9w) |
45 | | ResTv2-S | 224x224 | 83.2 | 41M | 6.0G | 687 | [baidu](https://pan.baidu.com/s/1nysV5MTtwsDLChrRa7vmZQ) |
46 | | ResTv2-S | 384x384 | 84.5 | 41M | 18.4G | 256 | [baidu](https://pan.baidu.com/s/1S1GERP-lYEJANYr17xk3dA) |
47 | | ResTv2-B | 224x224 | 83.7 | 56M | 7.9G | 582 | [baidu](https://pan.baidu.com/s/1GH3N2_rbZx816mN87UzYgQ) |
48 | | ResTv2-B | 384x384 | 85.1 | 56M | 24.3G | 210 | [baidu](https://pan.baidu.com/s/12RBMZmf6IlJIB3lIkeBH9Q) |
49 | | ResTv2-L | 224x224 | 84.2 | 87M | 13.8G | 415 | [baidu](https://pan.baidu.com/s/1A2huwk_Ii4ZzQllg1iHrEw) |
50 | | ResTv2-L | 384x384 | 85.4 | 87M | 42.4G | 141 | [baidu](https://pan.baidu.com/s/1dlxiWexb9mho63WdWS8nXg) |
51 |
52 |
53 | Note: the access code for the `baidu` links is `rest`. Pretrained models of ResTv1 are now available on [Google Drive](https://drive.google.com/drive/folders/1H6QUZsKYbU6LECtxzGHKqEeGbx1E8uQ9).
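
To load a downloaded checkpoint outside of `main.py`, a minimal sketch (it mirrors `get_flops.py` and `fourier_analysis.py`; the filename below is a placeholder for one of the checkpoints above, which store their weights under the `model` key):

```
import torch
from timm.models import create_model

import models  # registers the rest_* / restv2_* variants with timm

model = create_model("restv2_tiny", pretrained=False, num_classes=1000)
checkpoint = torch.load("restv2_tiny_224.pth", map_location="cpu")  # placeholder path
model.load_state_dict(checkpoint["model"])
model.eval()
```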
54 |
55 | ## Installation
56 | Please check [INSTALL.md](INSTALL.md) for installation instructions.
57 |
58 | ## Evaluation
59 | We give an example evaluation command for an ImageNet-1K pre-trained, then ImageNet-1K fine-tuned ResTv2-T:
60 |
61 | Single-GPU
62 | ```
63 | python main.py --model restv2_tiny --eval true \
64 | --resume restv2_tiny_384.pth \
65 | --input_size 384 --drop_path 0.1 \
66 | --data_path /path/to/imagenet-1k
67 | ```
68 |
69 | This should give
70 | ```
71 | * Acc@1 83.708 Acc@5 96.524 loss 0.777
72 | ```
73 |
74 | - For evaluating other model variants, change `--model`, `--resume`, and `--input_size` accordingly. You can get the URLs of the pre-trained models from the table above.
75 | - Setting a model-specific `--drop_path` is not strictly required for evaluation, as the `DropPath` module in timm is inactive during evaluation; it is, however, required for training. See [TRAINING.md](TRAINING.md) or our paper for the values used for different models.
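
Multi-GPU evaluation should follow the same pattern; a sketch that combines the flags above with the distributed launcher used in [TRAINING.md](TRAINING.md) (assuming 8 GPUs):

```
python -m torch.distributed.launch --nproc_per_node=8 main.py \
--model restv2_tiny --eval true \
--resume restv2_tiny_384.pth \
--input_size 384 --drop_path 0.1 \
--data_path /path/to/imagenet-1k
```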
76 |
77 | ## Training
78 | See [TRAINING.md](TRAINING.md) for training and fine-tuning instructions.
79 |
80 | ## Acknowledgement
81 | This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library.
82 |
83 | ## License
84 | This project is released under the Apache License 2.0. Please see the [LICENSE](LICENSE) file for more information.
85 |
86 | ## Citation
87 | If you find this repository helpful, please consider citing:
88 |
89 | **ResTv1**
90 | ```
91 | @inproceedings{zhang2021rest,
92 | title={ResT: An Efficient Transformer for Visual Recognition},
93 | author={Qinglong Zhang and Yu-bin Yang},
94 | booktitle={Advances in Neural Information Processing Systems},
95 | year={2021},
96 | url={https://openreview.net/forum?id=6Ab68Ip4Mu}
97 | }
98 | ```
99 |
100 | **ResTv2**
101 | ```
102 | @article{zhang2022rest,
103 | title={ResT V2: Simpler, Faster and Stronger},
104 | author={Zhang, Qing-Long and Yang, Yu-Bin},
105 | journal={arXiv preprint arXiv:2204.07366},
106 | year={2022}
107 | }
108 | ```
109 | ## Third-party Implementation
110 | [2022/05/26] ResT and ResTv2 have been integrated into [PaddleViT](https://github.com/BR-IDL/PaddleViT); check out [this page](https://github.com/BR-IDL/PaddleViT/tree/develop/image_classification/ResT) for the third-party implementation in the Paddle framework!
111 |
--------------------------------------------------------------------------------
/TRAINING.md:
--------------------------------------------------------------------------------
1 | # Training
2 |
3 | We provide ImageNet-1K training and fine-tuning commands here.
4 | Please check [INSTALL.md](INSTALL.md) for installation instructions first.
5 |
6 | ## ImageNet-1K Training
7 | Training on ImageNet-1K on a single machine:
8 | ```
9 | python -m torch.distributed.launch --nproc_per_node=8 main.py \
10 | --model restv2_tiny --drop_path 0.1 \
11 | --clip_grad 1.0 --warmup_epochs 50 --epochs 300 \
12 | --batch_size 256 --lr 1.5e-4 --update_freq 1 \
13 | --model_ema true --model_ema_eval true \
14 | --data_path /path/to/imagenet-1k \
15 | --output_dir /path/to/save_results
16 | ```
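
With this command, the effective batch size is 8 (GPUs) × 256 (`--batch_size`) × 1 (`--update_freq`) = 2048. If you train with fewer GPUs, increasing `--update_freq` (gradient accumulation) by the same factor keeps the effective batch size unchanged.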
17 |
18 | ## ImageNet-1K Fine-tuning
19 | ### Fine-tuning from ImageNet-1K pre-training
20 | The training commands given above for ImageNet-1K use the default resolution (224). We also fine-tune these trained models at a larger resolution (384). Please specify the path or URL of the checkpoint via `--finetune`.
21 |
22 | Fine-tuning on ImageNet-1K (384x384):
23 |
24 | Single-machine
25 | ```
26 | python -m torch.distributed.launch --nproc_per_node=8 main.py \
27 | --model restv2_tiny --drop_path 0.1 --input_size 384 \
28 | --batch_size 64 --lr 1.5e-5 --update_freq 1 \
29 | --warmup_epochs 0 --epochs 30 --weight_decay 1e-8 \
30 | --cutmix 0 --mixup 0 --clip_grad 1.0 \
31 | --finetune /path/to/checkpoint.pth \
32 | --data_path /path/to/imagenet-1k \
33 | --output_dir /path/to/save_results
34 | ```
35 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import os
8 | from torchvision import datasets, transforms
9 |
10 | from timm.data.constants import \
11 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
12 | from timm.data import create_transform
13 |
14 |
15 | def build_dataset(is_train, args):
16 | transform = build_transform(is_train, args)
17 |
18 | print("Transform = ")
19 | if isinstance(transform, tuple):
20 | for trans in transform:
21 | print(" - - - - - - - - - - ")
22 | for t in trans.transforms:
23 | print(t)
24 | else:
25 | for t in transform.transforms:
26 | print(t)
27 | print("---------------------------")
28 |
29 | if args.data_set == 'CIFAR':
30 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
31 | nb_classes = 100
32 | elif args.data_set == 'IMNET':
33 | print("reading from datapath", args.data_path)
34 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
35 | dataset = datasets.ImageFolder(root, transform=transform)
36 | nb_classes = 1000
37 | elif args.data_set == "image_folder":
38 | root = args.data_path if is_train else args.eval_data_path
39 | dataset = datasets.ImageFolder(root, transform=transform)
40 | nb_classes = args.nb_classes
41 | assert len(dataset.class_to_idx) == nb_classes
42 | else:
43 | raise NotImplementedError()
44 | print("Number of the class = %d" % nb_classes)
45 |
46 | return dataset, nb_classes
47 |
48 |
49 | def build_transform(is_train, args):
50 | resize_im = args.input_size > 32
51 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
52 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
53 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
54 |
55 | if is_train:
56 | # this should always dispatch to transforms_imagenet_train
57 | transform = create_transform(
58 | input_size=args.input_size,
59 | is_training=True,
60 | color_jitter=args.color_jitter,
61 | auto_augment=args.aa,
62 | interpolation=args.train_interpolation,
63 | re_prob=args.reprob,
64 | re_mode=args.remode,
65 | re_count=args.recount,
66 | mean=mean,
67 | std=std,
68 | )
69 | if not resize_im:
70 | transform.transforms[0] = transforms.RandomCrop(
71 | args.input_size, padding=4)
72 | return transform
73 |
74 | t = []
75 | if resize_im:
76 | # warping (no cropping) when evaluated at 384 or larger
77 | if args.input_size >= 384:
78 | t.append(
79 | transforms.Resize(
80 | (args.input_size, args.input_size),
81 | interpolation=transforms.InterpolationMode.BICUBIC),
82 | )
83 | print(f"Warping {args.input_size} size input images...")
84 | else:
85 | if args.crop_pct is None:
86 | args.crop_pct = 224 / 256
87 | size = int(args.input_size / args.crop_pct)
88 | t.append(
89 | # to maintain same ratio w.r.t. 224 images
90 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
91 | )
92 | t.append(transforms.CenterCrop(args.input_size))
93 |
94 | t.append(transforms.ToTensor())
95 | t.append(transforms.Normalize(mean, std))
96 | return transforms.Compose(t)
97 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import math
8 | from typing import Iterable, Optional
9 | import torch
10 | from timm.data import Mixup
11 | from timm.utils import accuracy, ModelEma
12 |
13 | import utils
14 |
15 |
16 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
17 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
18 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
19 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
20 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
21 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False):
22 | model.train(True)
23 | metric_logger = utils.MetricLogger(delimiter=" ")
24 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
25 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
26 | header = 'Epoch: [{}]'.format(epoch)
27 | print_freq = 50
28 |
29 | optimizer.zero_grad()
30 |
31 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
32 | # if data_iter_step > 20:
33 | # break
34 | step = data_iter_step // update_freq
35 | if step >= num_training_steps_per_epoch:
36 | continue
37 | it = start_steps + step # global training iteration
38 |         # Update LR & WD at the first step of each gradient-accumulation cycle
39 |         if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
40 | for i, param_group in enumerate(optimizer.param_groups):
41 | if lr_schedule_values is not None:
42 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
43 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
44 | param_group["weight_decay"] = wd_schedule_values[it]
45 |
46 | samples = samples.to(device, non_blocking=True)
47 | targets = targets.to(device, non_blocking=True)
48 |
49 | if mixup_fn is not None:
50 | samples, targets = mixup_fn(samples, targets)
51 |
52 | if use_amp:
53 | with torch.cuda.amp.autocast():
54 | output = model(samples)
55 | loss = criterion(output, targets)
56 | else: # full precision
57 | output = model(samples)
58 | loss = criterion(output, targets)
59 |
60 | loss_value = loss.item()
61 |
62 | if not math.isfinite(loss_value): # this could trigger if using AMP
63 | print("Loss is {}, stopping training".format(loss_value))
64 | assert math.isfinite(loss_value)
65 |
66 | if use_amp:
67 | # this attribute is added by timm on one optimizer (adahessian)
68 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
69 | loss /= update_freq
70 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
71 | parameters=model.parameters(), create_graph=is_second_order,
72 | update_grad=(data_iter_step + 1) % update_freq == 0)
73 | if (data_iter_step + 1) % update_freq == 0:
74 | optimizer.zero_grad()
75 | if model_ema is not None:
76 | model_ema.update(model)
77 | else: # full precision
78 | loss /= update_freq
79 | loss.backward()
80 | if (data_iter_step + 1) % update_freq == 0:
81 | optimizer.step()
82 | optimizer.zero_grad()
83 | if model_ema is not None:
84 | model_ema.update(model)
85 |
86 | torch.cuda.synchronize()
87 |
88 | if mixup_fn is None:
89 | class_acc = (output.max(-1)[-1] == targets).float().mean()
90 | else:
91 | class_acc = None
92 | metric_logger.update(loss=loss_value)
93 | metric_logger.update(class_acc=class_acc)
94 | min_lr = 10.
95 | max_lr = 0.
96 | for group in optimizer.param_groups:
97 | min_lr = min(min_lr, group["lr"])
98 | max_lr = max(max_lr, group["lr"])
99 |
100 | metric_logger.update(lr=max_lr)
101 | metric_logger.update(min_lr=min_lr)
102 | weight_decay_value = None
103 | for group in optimizer.param_groups:
104 | if group["weight_decay"] > 0:
105 | weight_decay_value = group["weight_decay"]
106 | metric_logger.update(weight_decay=weight_decay_value)
107 | if use_amp:
108 | metric_logger.update(grad_norm=grad_norm)
109 |
110 | if log_writer is not None:
111 | log_writer.update(loss=loss_value, head="loss")
112 | log_writer.update(class_acc=class_acc, head="loss")
113 | log_writer.update(lr=max_lr, head="opt")
114 | log_writer.update(min_lr=min_lr, head="opt")
115 | log_writer.update(weight_decay=weight_decay_value, head="opt")
116 | if use_amp:
117 | log_writer.update(grad_norm=grad_norm, head="opt")
118 | log_writer.set_step()
119 |
120 | # gather the stats from all processes
121 | metric_logger.synchronize_between_processes()
122 | print("Averaged stats:", metric_logger)
123 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
124 |
125 |
126 | @torch.no_grad()
127 | def evaluate(data_loader, model, device, use_amp=False):
128 | criterion = torch.nn.CrossEntropyLoss()
129 |
130 | metric_logger = utils.MetricLogger(delimiter=" ")
131 | header = 'Test:'
132 |
133 | # switch to evaluation mode
134 | model.eval()
135 | i = 0
136 | for batch in metric_logger.log_every(data_loader, 10, header):
137 | i += 1
138 | images = batch[0]
139 | target = batch[-1]
140 |
141 | images = images.to(device, non_blocking=True)
142 | target = target.to(device, non_blocking=True)
143 |
144 | # compute output
145 | if use_amp:
146 | with torch.cuda.amp.autocast():
147 | output = model(images)
148 | loss = criterion(output, target)
149 | else:
150 | output = model(images)
151 | loss = criterion(output, target)
152 |
153 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
154 |
155 | batch_size = images.shape[0]
156 | metric_logger.update(loss=loss.item())
157 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
158 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
159 | # gather the stats from all processes
160 | metric_logger.synchronize_between_processes()
161 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
162 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
163 |
164 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
165 |
--------------------------------------------------------------------------------
/figures/fig_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wofmanaf/ResT/b9245d3f335e6c311798ce17b31c64990ee6c0e7/figures/fig_1.png
--------------------------------------------------------------------------------
/fourier_analysis.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import matplotlib.colors
4 | import matplotlib.pyplot as plt
5 | from matplotlib import rcParams
6 | import matplotlib.cm as cm
7 | from matplotlib.collections import LineCollection
8 | from matplotlib.ticker import FormatStrFormatter
9 | from matplotlib.pyplot import MultipleLocator
10 | import numpy as np
11 | import torch
12 | from torch.utils.data import DataLoader
13 | from torch.utils.data.sampler import Sampler
14 | import torchvision.datasets as datasets
15 | import torchvision.transforms as transforms
16 | from models.fft_rest_v2 import restv2_tiny
17 | from tqdm import tqdm
18 |
19 | config = {"font.family": "sans-serif", "font.size": 16, "mathtext.fontset": 'stix',
20 | 'font.sans-serif': ['Times New Roman'], }
21 | rcParams.update(config)
22 |
23 |
24 | # Sampler for the pytorch loader. Given a range r, the loader will only
25 | # return dataset[r] instead of the whole dataset.
26 | class RangeSampler(Sampler):
27 | def __init__(self, r):
28 | self.r = r
29 |
30 | def __iter__(self):
31 | return iter(self.r)
32 |
33 | def __len__(self):
34 | return len(self.r)
35 |
36 |
37 | def fourier(x): # 2D Fourier transform
38 | f = torch.fft.fft2(x)
39 | f = f.abs() + 1e-6
40 | f = f.log()
41 | return f
42 |
43 |
44 | def shift(x): # shift Fourier transformed feature map
45 | b, c, h, w = x.shape
46 | return torch.roll(x, shifts=(int(h / 2), int(w / 2)), dims=(2, 3))
47 |
48 |
49 | def make_segments(x, y): # make segment for `plot_segment`
50 | points = np.array([x, y]).T.reshape(-1, 1, 2)
51 | segments = np.concatenate([points[:-1], points[1:]], axis=1)
52 | return segments
53 |
54 |
55 | def plot_segment(fig, ax, xs, ys, cmap_name="plasma", marker="o"): # plot with cmap segments
56 | z = np.linspace(0.0, 1.0, len(ys))
57 | z = np.asarray(z)
58 |
59 | cmap = cm.get_cmap(cmap_name)
60 | norm = plt.Normalize(0.0, 1.0)
61 | segments = make_segments(xs, ys)
62 | lc = LineCollection(segments, array=z, cmap=cmap_name, norm=norm,
63 | linewidth=2.5, alpha=1.0)
64 | ax.add_collection(lc)
65 |
66 | colors = [cmap(x) for x in xs]
67 | sc = ax.scatter(xs, ys, color=colors, marker=marker, zorder=100)
68 | fig.colorbar(sc, ticks=[0, 0.1, 1.])
69 |
70 |
71 | def plot(latents, name=''):
72 | # latents: list of hidden feature maps in the latent space
73 | fig, ax = plt.subplots(1, 1, figsize=(3.6, 4), dpi=300)
74 | fig.set_tight_layout(True)
75 |
76 | # Fourier transform feature maps
77 | fourier_latents = []
78 | for latent in latents:
79 | if len(latent.shape) == 3: # for vit
80 | b, n, c = latent.shape
81 | h, w = int(math.sqrt(n)), int(math.sqrt(n))
82 | latent = latent.permute(0, 2, 1).reshape(b, c, h, w)
83 | elif len(latent.shape) == 4: # for cnn
84 | b, c, h, w = latent.shape
85 | else:
86 | raise Exception("shape: %s" % str(latent.shape))
87 | latent = fourier(latent)
88 | latent = shift(latent).mean(dim=(0, 1))
89 | latent = latent.diag()[int(h / 2):] # only use the half-diagonal components
90 | latent = latent - latent[0] # visualize 'relative' log amplitudes (i.e., low_freq - high_freq)
91 | fourier_latents.append(latent)
92 | # plot fourier transformed relative log amplitudes
93 | for i, latent in enumerate(reversed(fourier_latents)):
94 | freq = np.linspace(0, 1, len(latent))
95 | ax.plot(freq, latent, color=cm.plasma_r(i / len(fourier_latents)))
96 |
97 | ax.set_xlim(left=0, right=1)
98 | x_major_locator = MultipleLocator(0.5)
99 | y_major_locator = MultipleLocator(2.0)
100 |
101 | ax.set_xlabel("Frequency")
102 | ax.set_ylabel(r"$\Delta$ Log amplitude")
103 |
104 | ax.xaxis.set_major_locator(x_major_locator)
105 | ax.xaxis.set_major_formatter(FormatStrFormatter('%.1fπ'))
106 |
107 | ax.yaxis.set_major_locator(y_major_locator)
108 | plt.savefig(name)
109 | # plt.show()
110 |
111 |
112 | def main():
113 | model = restv2_tiny().cuda()
114 | checkpoint = torch.load("output_dir/restv2_tiny_224.pth", map_location='cpu')
115 | model.load_state_dict(checkpoint['model'])
116 | model = model.eval()
117 |
118 | val_dir = '/data/ilsvrc2012/val/'
119 | batch_size = 1
120 | num_workers = 4
121 | batch = 0
122 |
123 | sample_range = range(10000 * batch, 10000 * (batch + 1))
124 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
125 | std=[0.229, 0.224, 0.225])
126 | val_loader = DataLoader(
127 | datasets.ImageFolder(val_dir, transforms.Compose([
128 | transforms.Resize(256),
129 | transforms.CenterCrop(224),
130 | transforms.ToTensor(),
131 | normalize,
132 | ])),
133 | batch_size=batch_size,
134 | shuffle=False,
135 | num_workers=num_workers,
136 | pin_memory=True,
137 | sampler=RangeSampler(sample_range)
138 | )
139 |
140 | latents_attn = []
141 | latents_up = []
142 | latents_com = []
143 |
144 | for i, (samples, targets) in enumerate(tqdm(val_loader, total=len(val_loader), desc='Loading Images')):
145 | samples = samples.cuda()
146 | with torch.no_grad():
147 | feats = model(samples)
148 |
149 | for j in range(11):
150 | if i == 0:
151 | latents_attn.append(feats[j]["attn"].cpu())
152 | latents_up.append(feats[j]["up"].cpu())
153 | latents_com.append(feats[j]["com"].cpu())
154 | else:
155 | latents_attn[j] = torch.cat([latents_attn[j], feats[j]["attn"].cpu()], dim=0)
156 | latents_up[j] = torch.cat([latents_up[j], feats[j]["up"].cpu()], dim=0)
157 | latents_com[j] = torch.cat([latents_com[j], feats[j]["com"].cpu()], dim=0)
158 | plot(latents_attn, name="attn.png")
159 | plot(latents_up, name="up.png")
160 | plot(latents_com, name="combine.png")
161 |
162 |
163 | if __name__ == '__main__':
164 | main()
165 |
--------------------------------------------------------------------------------
/get_flops.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import argparse
8 | import torch
9 | from timm.models import create_model
10 | import models
11 |
12 | try:
13 | from mmcv.cnn import get_model_complexity_info
14 | from mmcv.cnn.utils.flops_counter import get_model_complexity_info, flops_to_string, params_to_string
15 | except ImportError:
16 | raise ImportError('Please upgrade mmcv to >0.6.2')
17 |
18 |
19 | def parse_args():
20 |     parser = argparse.ArgumentParser(description='Get FLOPs of a classification model')
21 |     parser.add_argument('--model', default='restv2_tiny', type=str, metavar='MODEL',
22 |                         help='model name')
23 | parser.add_argument(
24 | '--shape',
25 | type=int,
26 | nargs='+',
27 | default=[224, 224],
28 | help='input image size')
29 | args = parser.parse_args()
30 | return args
31 |
32 |
33 | def attn_flops(h, w, r, dim):
34 | return 2 * h * w * (h // r) * (w // r) * dim
35 |
36 |
37 | def get_flops(model, input_shape):
38 | flops, params = get_model_complexity_info(model, input_shape, as_strings=False)
39 | if 'rest' in model.name:
40 | _, H, W = input_shape
41 | # calculate flops of ResTv2
42 | stage1 = attn_flops(H // 4, W // 4,
43 | model.stage1[0].attn.sr_ratio,
44 | model.stage1[0].attn.dim) * model.depths[0]
45 | stage2 = attn_flops(H // 8, W // 8,
46 | model.stage2[0].attn.sr_ratio,
47 | model.stage2[0].attn.dim) * model.depths[1]
48 | stage3 = attn_flops(H // 16, W // 16,
49 | model.stage3[0].attn.sr_ratio,
50 | model.stage3[0].attn.dim) * model.depths[2]
51 | stage4 = attn_flops(H // 32, W // 32,
52 | model.stage4[0].attn.sr_ratio,
53 | model.stage4[0].attn.dim) * model.depths[3]
54 | flops += stage1 + stage2 + stage3 + stage4
55 | return flops_to_string(flops), params_to_string(params)
56 |
57 |
58 | def main():
59 | args = parse_args()
60 |
61 | if len(args.shape) == 1:
62 | input_shape = (3, args.shape[0], args.shape[0])
63 | elif len(args.shape) == 2:
64 | input_shape = (3,) + tuple(args.shape)
65 | else:
66 | raise ValueError('invalid input shape')
67 |
68 | model = create_model(
69 | args.model,
70 | pretrained=False,
71 | num_classes=1000
72 | )
73 | model.name = args.model
74 | if torch.cuda.is_available():
75 | model.cuda()
76 | model.eval()
77 |
78 | flops, params = get_flops(model, input_shape)
79 |
80 | split_line = '=' * 30
81 | print(f'{split_line}\nInput shape: {input_shape}\n'
82 | f'Flops: {flops}\nParams: {params}\n{split_line}')
83 | print('!!!Please be cautious if you use the results in papers. '
84 | 'You may need to check if all ops are supported and verify that the '
85 | 'flops computation is correct.')
86 |
87 |
88 | if __name__ == '__main__':
89 | main()
90 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .rest import *
2 | from .rest_v2 import *
3 |
--------------------------------------------------------------------------------
/models/rest.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
11 | from timm.models.registry import register_model
12 |
13 | __all__ = ['rest_lite', 'rest_small', 'rest_base', 'rest_large']
14 |
15 |
16 | def _cfg(url='', **kwargs):
17 | return {
18 | 'url': url,
19 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
20 | 'crop_pct': .9, 'interpolation': 'bicubic',
21 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
22 | 'classifier': 'head',
23 | **kwargs
24 | }
25 |
26 |
27 | default_cfgs = {
28 | 'rest_lite': _cfg(),
29 | 'rest_small': _cfg(),
30 | 'rest_base': _cfg(),
31 | 'rest_large': _cfg(),
32 | }
33 |
34 |
35 | class Mlp(nn.Module):
36 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
37 | super().__init__()
38 | out_features = out_features or in_features
39 | hidden_features = hidden_features or in_features
40 | self.fc1 = nn.Linear(in_features, hidden_features)
41 | self.act = act_layer()
42 | self.fc2 = nn.Linear(hidden_features, out_features)
43 | self.drop = nn.Dropout(drop)
44 |
45 | def forward(self, x):
46 | x = self.fc1(x)
47 | x = self.act(x)
48 | x = self.drop(x)
49 | x = self.fc2(x)
50 | x = self.drop(x)
51 | return x
52 |
53 |
54 | class Attention(nn.Module):
55 | def __init__(self,
56 | dim,
57 | num_heads=8,
58 | qkv_bias=False,
59 | qk_scale=None,
60 | attn_drop=0.,
61 | proj_drop=0.,
62 | sr_ratio=1,
63 | apply_transform=False):
64 | super().__init__()
65 | self.num_heads = num_heads
66 | head_dim = dim // num_heads
67 | self.scale = qk_scale or head_dim ** -0.5
68 |
69 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
70 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
71 | self.attn_drop = nn.Dropout(attn_drop)
72 | self.proj = nn.Linear(dim, dim)
73 | self.proj_drop = nn.Dropout(proj_drop)
74 |
75 | self.sr_ratio = sr_ratio
76 | if sr_ratio > 1:
77 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim)
78 | self.sr_norm = nn.LayerNorm(dim)
79 |
80 | self.apply_transform = apply_transform and num_heads > 1
81 | if self.apply_transform:
82 | self.transform_conv = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, stride=1)
83 | self.transform_norm = nn.InstanceNorm2d(self.num_heads)
84 |
85 | def forward(self, x, H, W):
86 | B, N, C = x.shape
87 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
88 | if self.sr_ratio > 1:
89 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
90 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
91 | x_ = self.sr_norm(x_)
92 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
93 | else:
94 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
95 | k, v = kv[0], kv[1]
96 |
97 | attn = (q @ k.transpose(-2, -1)) * self.scale
98 | if self.apply_transform:
99 | attn = self.transform_conv(attn)
100 | attn = attn.softmax(dim=-1)
101 | attn = self.transform_norm(attn)
102 | else:
103 | attn = attn.softmax(dim=-1)
104 |
105 | attn = self.attn_drop(attn)
106 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
107 | x = self.proj(x)
108 | x = self.proj_drop(x)
109 | return x
110 |
111 |
112 | class Block(nn.Module):
113 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
114 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, apply_transform=False):
115 | super().__init__()
116 | self.norm1 = norm_layer(dim)
117 | self.attn = Attention(
118 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
119 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, apply_transform=apply_transform)
120 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
121 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
122 | self.norm2 = norm_layer(dim)
123 | mlp_hidden_dim = int(dim * mlp_ratio)
124 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
125 |
126 | def forward(self, x, H, W):
127 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
128 | x = x + self.drop_path(self.mlp(self.norm2(x)))
129 | return x
130 |
131 |
132 | class PA(nn.Module):
133 | def __init__(self, dim):
134 | super().__init__()
135 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
136 | self.sigmoid = nn.Sigmoid()
137 |
138 | def forward(self, x):
139 | return x * self.sigmoid(self.pa_conv(x))
140 |
141 |
142 | class GL(nn.Module):
143 | def __init__(self, dim):
144 | super().__init__()
145 | self.gl_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
146 |
147 | def forward(self, x):
148 | return x + self.gl_conv(x)
149 |
150 |
151 | class PatchEmbed(nn.Module):
152 | """ Image to Patch Embedding"""
153 |
154 | def __init__(self, patch_size=16, in_ch=3, out_ch=768, with_pos=False):
155 | super().__init__()
156 | self.patch_size = to_2tuple(patch_size)
157 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size, padding=patch_size // 2)
158 | self.norm = nn.BatchNorm2d(out_ch)
159 |
160 | self.with_pos = with_pos
161 | if self.with_pos:
162 | self.pos = PA(out_ch)
163 |
164 | def forward(self, x):
165 | B, C, H, W = x.shape
166 | x = self.conv(x)
167 | x = self.norm(x)
168 | if self.with_pos:
169 | x = self.pos(x)
170 | x = x.flatten(2).transpose(1, 2)
171 | H, W = H // self.patch_size[0], W // self.patch_size[1]
172 | return x, (H, W)
173 |
174 |
175 | class BasicStem(nn.Module):
176 | def __init__(self, in_ch=3, out_ch=64, with_pos=False):
177 | super(BasicStem, self).__init__()
178 | hidden_ch = out_ch // 2
179 | self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, stride=2, padding=1, bias=False)
180 | self.norm1 = nn.BatchNorm2d(hidden_ch)
181 | self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, stride=1, padding=1, bias=False)
182 | self.norm2 = nn.BatchNorm2d(hidden_ch)
183 | self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False)
184 |
185 | self.act = nn.ReLU(inplace=True)
186 | self.with_pos = with_pos
187 | if self.with_pos:
188 | self.pos = PA(out_ch)
189 |
190 | def forward(self, x):
191 | x = self.conv1(x)
192 | x = self.norm1(x)
193 | x = self.act(x)
194 |
195 | x = self.conv2(x)
196 | x = self.norm2(x)
197 | x = self.act(x)
198 |
199 | x = self.conv3(x)
200 | if self.with_pos:
201 | x = self.pos(x)
202 | return x
203 |
204 |
205 | class Stem(nn.Module):
206 | def __init__(self, in_ch=3, out_ch=64, with_pos=False):
207 | super(Stem, self).__init__()
208 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=7, stride=2, padding=3, bias=False)
209 | self.norm = nn.BatchNorm2d(out_ch)
210 | self.act = nn.ReLU(inplace=True)
211 |
212 | self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
213 | self.with_pos = with_pos
214 | if self.with_pos:
215 | self.pos = PA(out_ch)
216 |
217 | def forward(self, x):
218 | x = self.conv(x)
219 | x = self.norm(x)
220 | x = self.act(x)
221 | x = self.max_pool(x)
222 |
223 | if self.with_pos:
224 | x = self.pos(x)
225 | return x
226 |
227 |
228 | class ResT(nn.Module):
229 | def __init__(self, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
230 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False,
231 | qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
232 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
233 | norm_layer=nn.LayerNorm, apply_transform=False):
234 | super().__init__()
235 | self.num_classes = num_classes
236 | self.depths = depths
237 | self.apply_transform = apply_transform
238 |
239 | self.stem = BasicStem(in_ch=in_chans, out_ch=embed_dims[0], with_pos=True)
240 |
241 | self.patch_embed_2 = PatchEmbed(patch_size=2, in_ch=embed_dims[0], out_ch=embed_dims[1], with_pos=True)
242 | self.patch_embed_3 = PatchEmbed(patch_size=2, in_ch=embed_dims[1], out_ch=embed_dims[2], with_pos=True)
243 | self.patch_embed_4 = PatchEmbed(patch_size=2, in_ch=embed_dims[2], out_ch=embed_dims[3], with_pos=True)
244 |
245 | # transformer encoder
246 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
247 | cur = 0
248 |
249 | self.stage1 = nn.ModuleList([
250 | Block(embed_dims[0], num_heads[0], mlp_ratios[0], qkv_bias, qk_scale, drop_rate, attn_drop_rate,
251 | drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[0], apply_transform=apply_transform)
252 | for i in range(self.depths[0])])
253 |
254 | cur += depths[0]
255 | self.stage2 = nn.ModuleList([
256 | Block(embed_dims[1], num_heads[1], mlp_ratios[1], qkv_bias, qk_scale, drop_rate, attn_drop_rate,
257 | drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[1], apply_transform=apply_transform)
258 | for i in range(self.depths[1])])
259 |
260 | cur += depths[1]
261 | self.stage3 = nn.ModuleList([
262 | Block(embed_dims[2], num_heads[2], mlp_ratios[2], qkv_bias, qk_scale, drop_rate, attn_drop_rate,
263 | drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[2], apply_transform=apply_transform)
264 | for i in range(self.depths[2])])
265 |
266 | cur += depths[2]
267 | self.stage4 = nn.ModuleList([
268 | Block(embed_dims[3], num_heads[3], mlp_ratios[3], qkv_bias, qk_scale, drop_rate, attn_drop_rate,
269 | drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[3], apply_transform=apply_transform)
270 | for i in range(self.depths[3])])
271 |
272 | self.norm = norm_layer(embed_dims[3])
273 |
274 | # classification head
275 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
276 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
277 |
278 | # init weights
279 | self.apply(self._init_weights)
280 |
281 | def _init_weights(self, m):
282 | if isinstance(m, nn.Conv2d):
283 | trunc_normal_(m.weight, std=0.02)
284 | elif isinstance(m, nn.Linear):
285 | trunc_normal_(m.weight, std=0.02)
286 | if m.bias is not None:
287 | nn.init.constant_(m.bias, 0)
288 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
289 | nn.init.constant_(m.weight, 1.0)
290 | nn.init.constant_(m.bias, 0)
291 |
292 | def forward(self, x):
293 | x = self.stem(x)
294 | B, _, H, W = x.shape
295 | x = x.flatten(2).permute(0, 2, 1)
296 |
297 | # stage 1
298 | for blk in self.stage1:
299 | x = blk(x, H, W)
300 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
301 |
302 | # stage 2
303 | x, (H, W) = self.patch_embed_2(x)
304 | for blk in self.stage2:
305 | x = blk(x, H, W)
306 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
307 |
308 | # stage 3
309 | x, (H, W) = self.patch_embed_3(x)
310 | for blk in self.stage3:
311 | x = blk(x, H, W)
312 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
313 |
314 | # stage 4
315 | x, (H, W) = self.patch_embed_4(x)
316 | for blk in self.stage4:
317 | x = blk(x, H, W)
318 | x = self.norm(x)
319 |
320 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
321 | x = self.avg_pool(x).flatten(1)
322 | x = self.head(x)
323 | return x
324 |
325 |
326 | @register_model
327 | def rest_lite(pretrained=False, **kwargs):
328 | model = ResT(embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
329 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], apply_transform=True, **kwargs)
330 | model.default_cfg = _cfg()
331 | return model
332 |
333 |
334 | @register_model
335 | def rest_small(pretrained=False, **kwargs):
336 | model = ResT(embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
337 | depths=[2, 2, 6, 2], sr_ratios=[8, 4, 2, 1], apply_transform=True, **kwargs)
338 | model.default_cfg = _cfg()
339 | return model
340 |
341 |
342 | @register_model
343 | def rest_base(pretrained=False, **kwargs):
344 | model = ResT(embed_dims=[96, 192, 384, 768], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
345 | depths=[2, 2, 6, 2], sr_ratios=[8, 4, 2, 1], apply_transform=True, **kwargs)
346 | model.default_cfg = _cfg()
347 | return model
348 |
349 |
350 | @register_model
351 | def rest_large(pretrained=False, **kwargs):
352 | model = ResT(embed_dims=[96, 192, 384, 768], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
353 | depths=[2, 2, 18, 2], sr_ratios=[8, 4, 2, 1], apply_transform=True, **kwargs)
354 | model.default_cfg = _cfg()
355 | return model
356 |
--------------------------------------------------------------------------------
/models/rest_v2.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
11 | from timm.models.registry import register_model
12 |
13 |
14 | class Mlp(nn.Module):
15 | def __init__(self, dim):
16 | super().__init__()
17 | self.fc1 = nn.Linear(dim, 4 * dim)
18 | self.act = nn.GELU()
19 | self.fc2 = nn.Linear(4 * dim, dim)
20 |
21 | def forward(self, x):
22 | x = self.fc1(x)
23 | x = self.act(x)
24 | x = self.fc2(x)
25 | return x
26 |
27 |
28 | class Attention(nn.Module):
29 | def __init__(self,
30 | dim,
31 | num_heads=8,
32 | sr_ratio=1):
33 | super().__init__()
34 | self.num_heads = num_heads
35 | head_dim = dim // num_heads
36 | self.scale = head_dim ** -0.5
37 | self.dim = dim
38 |
39 | self.q = nn.Linear(dim, dim, bias=True)
40 | self.kv = nn.Linear(dim, dim * 2, bias=True)
41 |
42 | self.sr_ratio = sr_ratio
43 | if sr_ratio > 1:
44 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim)
45 | self.sr_norm = nn.LayerNorm(dim, eps=1e-6)
46 |
47 | self.up = nn.Sequential(
48 | nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1, padding=1, groups=dim),
49 | nn.PixelShuffle(upscale_factor=sr_ratio)
50 | )
51 | self.up_norm = nn.LayerNorm(dim, eps=1e-6)
52 |
53 | self.proj = nn.Linear(dim, dim)
54 |
55 | def forward(self, x, H, W):
56 | B, N, C = x.shape
57 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
58 | if self.sr_ratio > 1:
59 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
60 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
61 | x = self.sr_norm(x)
62 |
63 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
64 | k, v = kv[0], kv[1]
65 | attn = (q @ k.transpose(-2, -1)) * self.scale
66 | attn = attn.softmax(dim=-1)
67 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
68 |
69 | identity = v.transpose(-1, -2).reshape(B, C, H // self.sr_ratio, W // self.sr_ratio)
70 | identity = self.up(identity).flatten(2).transpose(1, 2)
71 | x = self.proj(x + self.up_norm(identity))
72 | return x
73 |
74 |
75 | class Block(nn.Module):
76 | def __init__(self, dim, num_heads, sr_ratio=1, drop_path=0.):
77 | super().__init__()
78 | self.norm1 = nn.LayerNorm(dim, eps=1e-6)
79 | self.attn = Attention(dim, num_heads=num_heads, sr_ratio=sr_ratio)
80 |
81 | self.norm2 = nn.LayerNorm(dim, eps=1e-6)
82 | self.mlp = Mlp(dim)
83 |
84 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
85 |
86 | def forward(self, x, H, W):
87 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) # pre_norm
88 | x = x + self.drop_path(self.mlp(self.norm2(x)))
89 | return x
90 |
91 |
92 | class PA(nn.Module):
93 | def __init__(self, dim):
94 | super().__init__()
95 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=True)
96 | self.sigmoid = nn.Sigmoid()
97 |
98 | def forward(self, x):
99 | return x * self.sigmoid(self.pa_conv(x))
100 |
101 |
102 | class Stem(nn.Module):
103 | def __init__(self, in_dim=3, out_dim=96, patch_size=2):
104 | super().__init__()
105 | self.patch_size = to_2tuple(patch_size)
106 | self.proj = nn.Conv2d(in_dim, out_dim, kernel_size=patch_size, stride=patch_size)
107 | self.norm = nn.LayerNorm(out_dim, eps=1e-6)
108 |
109 | def forward(self, x):
110 | B, C, H, W = x.shape
111 | x = self.proj(x)
112 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
113 | x = self.norm(x)
114 | H, W = H // self.patch_size[0], W // self.patch_size[1]
115 | return x, (H, W)
116 |
117 |
118 | class ConvStem(nn.Module):
119 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True):
120 | super().__init__()
121 | self.patch_size = to_2tuple(patch_size)
122 | stem = []
123 | in_dim, out_dim = in_ch, out_ch // 2
124 | for i in range(2):
125 | stem.append(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False))
126 | stem.append(nn.BatchNorm2d(out_dim))
127 | stem.append(nn.ReLU(inplace=True))
128 | in_dim, out_dim = out_dim, out_dim * 2
129 |
130 | stem.append(nn.Conv2d(in_dim, out_ch, kernel_size=1, stride=1))
131 | self.proj = nn.Sequential(*stem)
132 |
133 | self.with_pos = with_pos
134 | if self.with_pos:
135 | self.pos = PA(out_ch)
136 |
137 | self.norm = nn.LayerNorm(out_ch, eps=1e-6)
138 |
139 | def forward(self, x):
140 | B, C, H, W = x.shape
141 | x = self.proj(x)
142 | if self.with_pos:
143 | x = self.pos(x)
144 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
145 | x = self.norm(x)
146 | H, W = H // self.patch_size[0], W // self.patch_size[1]
147 | return x, (H, W)
148 |
149 |
150 | class PatchEmbed(nn.Module):
151 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True):
152 | super().__init__()
153 | self.patch_size = to_2tuple(patch_size)
154 | self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size, padding=patch_size // 2)
155 |
156 | self.with_pos = with_pos
157 | if self.with_pos:
158 | self.pos = PA(out_ch)
159 |
160 | self.norm = nn.LayerNorm(out_ch, eps=1e-6)
161 |
162 | def forward(self, x):
163 | B, C, H, W = x.shape
164 | x = self.proj(x)
165 | if self.with_pos:
166 | x = self.pos(x)
167 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
168 | x = self.norm(x)
169 | H, W = H // self.patch_size[0], W // self.patch_size[1]
170 | return x, (H, W)
171 |
172 |
173 | class ResTV2(nn.Module):
174 | def __init__(self, in_chans=3, num_classes=1000, embed_dims=[96, 192, 384, 768],
175 | num_heads=[1, 2, 4, 8], drop_path_rate=0.,
176 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1]):
177 | super().__init__()
178 | self.num_classes = num_classes
179 | self.depths = depths
180 |
181 | self.stem = ConvStem(in_chans, embed_dims[0], patch_size=4)
182 | self.patch_2 = PatchEmbed(embed_dims[0], embed_dims[1], patch_size=2)
183 | self.patch_3 = PatchEmbed(embed_dims[1], embed_dims[2], patch_size=2)
184 | self.patch_4 = PatchEmbed(embed_dims[2], embed_dims[3], patch_size=2)
185 |
186 | # transformer encoder
187 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
188 | cur = 0
189 | self.stage1 = nn.ModuleList([
190 | Block(embed_dims[0], num_heads[0], sr_ratios[0], dpr[cur + i])
191 | for i in range(depths[0])
192 | ])
193 |
194 | cur += depths[0]
195 | self.stage2 = nn.ModuleList([
196 | Block(embed_dims[1], num_heads[1], sr_ratios[1], dpr[cur + i])
197 | for i in range(depths[1])
198 | ])
199 |
200 | cur += depths[1]
201 | self.stage3 = nn.ModuleList([
202 | Block(embed_dims[2], num_heads[2], sr_ratios[2], dpr[cur + i])
203 | for i in range(depths[2])
204 | ])
205 |
206 | cur += depths[2]
207 | self.stage4 = nn.ModuleList([
208 | Block(embed_dims[3], num_heads[3], sr_ratios[3], dpr[cur + i])
209 | for i in range(depths[3])
210 | ])
211 |
212 | self.norm = nn.LayerNorm(embed_dims[-1], eps=1e-6) # final norm layer
213 | # classification head
214 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
215 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
216 |
217 | # init weights
218 | self.apply(self._init_weights)
219 |
220 | def _init_weights(self, m):
221 | if isinstance(m, (nn.Conv2d, nn.Linear)):
222 | trunc_normal_(m.weight, std=0.02)
223 | if m.bias is not None:
224 | nn.init.constant_(m.bias, 0)
225 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
226 | nn.init.constant_(m.weight, 1.)
227 | nn.init.constant_(m.bias, 0.)
228 |
229 | def forward(self, x):
230 | B, _, H, W = x.shape
231 | x, (H, W) = self.stem(x)
232 |
233 | # stage 1
234 | for blk in self.stage1:
235 | x = blk(x, H, W)
236 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
237 |
238 | # stage 2
239 | x, (H, W) = self.patch_2(x)
240 | for blk in self.stage2:
241 | x = blk(x, H, W)
242 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
243 |
244 | # stage 3
245 | x, (H, W) = self.patch_3(x)
246 | for blk in self.stage3:
247 | x = blk(x, H, W)
248 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
249 |
250 | # stage 4
251 | x, (H, W) = self.patch_4(x)
252 | for blk in self.stage4:
253 | x = blk(x, H, W)
254 | x = self.norm(x)
255 |
256 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
257 | x = self.avg_pool(x).flatten(1)
258 | x = self.head(x)
259 | return x
260 |
261 |
262 | @register_model
263 | def restv2_tiny(pretrained=False, **kwargs): # 82.3|4.7G|24M -> |3.92G|30.37M 4.5G|30.33M
264 | model = ResTV2(embed_dims=[96, 192, 384, 768], depths=[1, 2, 6, 2], **kwargs)
265 | return model
266 |
267 |
268 | @register_model
269 | def restv2_small(pretrained=False, **kwargs): # 83.6|7.0G|35M -> |5.78G|40.94M
270 | model = ResTV2(embed_dims=[96, 192, 384, 768], depths=[1, 2, 12, 2], **kwargs)
271 | return model
272 |
273 |
274 | @register_model
275 | def restv2_base(pretrained=False, **kwargs): # 84.4|10.2G|52M -> |7.25G|55.75M
276 | model = ResTV2(embed_dims=[96, 192, 384, 768], depths=[1, 3, 16, 3], **kwargs)
277 | return model
278 |
279 |
280 | @register_model
281 | def restv2_large(pretrained=False, **kwargs): # 85.3|39.6|218M -> |14.09G|98.61M
282 | model = ResTV2(num_heads=[2, 4, 8, 16], embed_dims=[128, 256, 512, 1024], depths=[2, 3, 16, 2], **kwargs)
283 | return model
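

# Example usage: a minimal smoke test for the models registered above. This is a
# sketch that assumes torch and timm are installed; it only checks the output shape
# of a dummy forward pass.
if __name__ == "__main__":
    model = restv2_tiny(num_classes=1000).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # stem downsamples 4x; each later stage adds 2x
    print(logits.shape)  # expected: torch.Size([1, 1000])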
--------------------------------------------------------------------------------
/object_detection/GETTING_STARTED.md:
--------------------------------------------------------------------------------
1 | ## Getting Started with Detectron2
2 |
3 | This document provides a brief introduction to the usage of the builtin command-line tools in detectron2.
4 |
5 | For a tutorial that involves actual coding with the API,
6 | see our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
7 | which covers how to run inference with an
8 | existing model, and how to train a builtin model on a custom dataset.
9 |
10 |
11 | ### Inference Demo with Pre-trained Models
12 |
13 | 1. Pick a model and its config file from
14 | [model zoo](MODEL_ZOO.md),
15 | for example, `mask_rcnn_R_50_FPN_3x.yaml`.
16 | 2. We provide `demo.py`, which can run inference and visualization with the builtin configs. Run it with:
17 | ```
18 | cd demo/
19 | python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
20 | --input input1.jpg input2.jpg \
21 | [--other-options]
22 | --opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl
23 | ```
24 | The configs are made for training; for evaluation, we therefore need to point `MODEL.WEIGHTS` to a model from the model zoo.
25 | This command will run the inference and show visualizations in an OpenCV window.
26 |
27 | For details of the command line arguments, see `demo.py -h` or look at its source code
28 | to understand its behavior. Some common arguments are:
29 | * To run __on your webcam__, replace `--input files` with `--webcam`.
30 | * To run __on a video__, replace `--input files` with `--video-input video.mp4`.
31 | * To run __on cpu__, add `MODEL.DEVICE cpu` after `--opts`.
32 | * To save outputs to a directory (for images) or a file (for webcam or video), use `--output`.
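
For example, to run the demo on a video on the CPU and save the visualized result, the options above can be combined as follows (the input/output paths are placeholders):
```
cd demo/
python demo.py --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml \
  --video-input video.mp4 --output output.mp4 \
  --opts MODEL.WEIGHTS detectron2://COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x/137849600/model_final_f10217.pkl MODEL.DEVICE cpu
```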
33 |
34 |
35 | ### Training & Evaluation in Command Line
36 |
37 | We provide two scripts, "tools/plain_train_net.py" and "tools/train_net.py",
38 | that are made to train all the configs provided in detectron2. You may want to
39 | use them as a reference to write your own training script.
40 |
41 | Compared to "train_net.py", "plain_train_net.py" supports fewer default
42 | features. It also includes fewer abstractions and is therefore easier to extend
43 | with custom logic.
44 |
45 | To train a model with "train_net.py", first
46 | set up the corresponding datasets following
47 | [datasets/README.md](./datasets/README.md),
48 | then run:
49 | ```
50 | cd tools/
51 | ./train_net.py --num-gpus 8 \
52 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
53 | ```
54 |
55 | The configs are made for 8-GPU training.
56 | To train on 1 GPU, you may need to [change some parameters](https://arxiv.org/abs/1706.02677), e.g.:
57 | ```
58 | ./train_net.py \
59 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
60 | --num-gpus 1 SOLVER.IMS_PER_BATCH 2 SOLVER.BASE_LR 0.0025
61 | ```
62 |
63 | To evaluate a model's performance, use
64 | ```
65 | ./train_net.py \
66 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \
67 | --eval-only MODEL.WEIGHTS /path/to/checkpoint_file
68 | ```
69 | For more options, see `./train_net.py -h`.
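
The same workflow applies to the ResT configs in this repository. A sketch, run from the `object_detection` directory (the checkpoint path is a placeholder for weights converted with `convert_to_d2.py`):
```
python train_net.py --num-gpus 8 \
  --config-file configs/ResTv2/mask_rcnn_restv2_tiny_FPN_3x.yaml \
  MODEL.WEIGHTS /path/to/restv2_tiny_224_d2.pth
```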
70 |
71 | ### Use Detectron2 APIs in Your Code
72 |
73 | See our [Colab Notebook](https://colab.research.google.com/drive/16jcaJoc6bCFAQ96jDe2HwtXj7BMD_-m5)
74 | to learn how to use detectron2 APIs to:
75 | 1. run inference with an existing model
76 | 2. train a builtin model on a custom dataset
77 |
78 | See [detectron2/projects](https://github.com/facebookresearch/detectron2/tree/master/projects)
79 | for more ways to build your project on detectron2.
80 |
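As a minimal sketch of the first point (not specific to ResT; it assumes detectron2 is installed and uses a COCO-pretrained Mask R-CNN from the detectron2 model zoo):
```
import cv2
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor

cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # keep only reasonably confident detections

predictor = DefaultPredictor(cfg)
outputs = predictor(cv2.imread("input1.jpg"))  # BGR image, as expected by the default config
print(outputs["instances"].pred_classes)
```
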
--------------------------------------------------------------------------------
/object_detection/README.md:
--------------------------------------------------------------------------------
1 | # Object detection
2 | ResTv1 and ResTv2 for object detection with detectron2
3 |
4 | This repo contains the supported code and configuration files to reproduce the object detection results of ResTv1 and ResTv2. It is based on [detectron2](https://github.com/facebookresearch/detectron2).
5 |
6 |
7 | ## Results and Models
8 |
9 | ### RetinaNet
10 |
11 | | Backbone | Pretrain | Lr Schd | box mAP | mask mAP | #params (M) | FPS | config | model |
12 | |:------------:| :---: |:-------:|:-------:|:--------:|:-------:| :---: |:-----------------------------------------------------------:|:--------------------------------------------------------:|
13 | | ResTv1-S-FPN | ImageNet-1K | 1x | 40.3 | - | 23.4 | - | [config](configs/ResTv1/retinanet_rest_small_FPN_1x.yaml) | [baidu](https://pan.baidu.com/s/13YXVRQeNcF_3Txns8eJzZw) |
14 | | ResTv1-B-FPN | ImageNet-1K | 1x | 42.0 | - | 40.5 | - | [config](configs/ResTv1/retinanet_rest_base_FPN_1x.yaml) | [baidu](https://pan.baidu.com/s/1hMRM5YEIGsfWfvqbuC7JWA) |
15 |
16 | ### Mask R-CNN
17 |
18 | | Backbone | Pretrain | Lr Schd | box mAP | mask mAP | #params (M) | FPS | config | model |
19 | |:------------:| :---: |:-------:|:-------:|:--------:|:-------:|:----:|:-----------------------------------------------------------:|:--------------------------------------------------------:|
20 | | ResTv1-S-FPN | ImageNet-1K | 1x | 39.6 | 37.2 | 31.2 | - | [config](configs/ResTv1/mask_rcnn_rest_small_FPN_1x.yaml) | [baidu](https://pan.baidu.com/s/1UfDsRGwgZcydXtj56ZFoDg) |
21 | | ResTv1-B-FPN | ImageNet-1K | 1x | 41.6 | 38.7 | 49.8 | - | [config](configs/ResTv1/mask_rcnn_rest_base_FPN_1x.yaml) | [baidu](https://pan.baidu.com/s/1oSdMGTSBK_JDcLEq3XjY8w) |
22 | | ResTv2-T-FPN | ImageNet-1K | 3x | 47.6 | 43.2 | 49.9 | 25.0 | [config](configs/ResTv2/mask_rcnn_restv2_tiny_FPN_3x.yaml) | [baidu](https://pan.baidu.com/s/16fDcEupHBZ1zHyzFFZvM3g) |
23 | | ResTv2-S-FPN | ImageNet-1K | 3x | 48.1 | 43.3 | 60.7 | 21.3 | [config](configs/ResTv2/mask_rcnn_restv2_small_FPN_3x.yaml) | [baidu](https://pan.baidu.com/s/1UfDsRGwgZcydXtj56ZFoDg) |
24 | | ResTv2-B-FPN | ImageNet-1K | 3x | 48.7 | 43.9 | 75.5 | 18.3 | [config](configs/ResTv2/mask_rcnn_restv2_base_FPN_3x.yaml) | [baidu](https://pan.baidu.com/s/1zHQM0KqgtqQzg0-mdtx-Jg) |
25 |
26 |
27 | ## Usage
28 | Please refer to [get_started.md](https://detectron2.readthedocs.io/en/latest/tutorials/getting_started.html) for installation and dataset preparation.
29 |
30 | Note: you need to convert the original pretrained weights to the d2 format with [convert_to_d2.py](convert_to_d2.py); an example is shown below.
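
For example (a sketch; the file names are placeholders for the downloaded ImageNet-pretrained checkpoint and the converted output):
```
python convert_to_d2.py --source_model restv2_tiny_224.pth --output_model restv2_tiny_224_d2.pth
```
The converted checkpoint can then be passed to the configs via `MODEL.WEIGHTS`.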
31 |
32 | ## Citation
33 | If you find this repository helpful, please consider citing:
34 |
35 | **ResTv1**
36 | ```
37 | @inproceedings{zhang2021rest,
38 | title={ResT: An Efficient Transformer for Visual Recognition},
39 | author={Qinglong Zhang and Yu-bin Yang},
40 | booktitle={Advances in Neural Information Processing Systems},
41 | year={2021},
42 | url={https://openreview.net/forum?id=6Ab68Ip4Mu}
43 | }
44 | ```
45 |
46 | **ResTv2**
47 | ```
48 | @article{zhang2022rest,
49 | title={ResT V2: Simpler, Faster and Stronger},
50 | author={Zhang, Qing-Long and Yang, Yu-Bin},
51 | journal={arXiv preprint arXiv:2204.07366},
52 | year={2022}
53 | }
54 | ```
--------------------------------------------------------------------------------
/object_detection/analyze_model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | import logging
5 | import numpy as np
6 | from collections import Counter
7 | import tqdm
8 | from fvcore.nn import flop_count_table # can also try flop_count_str
9 |
10 | from detectron2.checkpoint import DetectionCheckpointer
11 | from detectron2.config import get_cfg
12 | from detectron2.data import build_detection_test_loader
13 | from detectron2.engine import default_argument_parser, default_setup
14 | from detectron2.modeling import build_model
15 | from detectron2.utils.analysis import (
16 | FlopCountAnalysis,
17 | activation_count_operators,
18 | parameter_count_table,
19 | )
20 |
21 |
22 | from detectron2.utils.logger import setup_logger
23 | from restv2 import add_restv2_config
24 |
25 | logger = logging.getLogger("detectron2")
26 |
27 |
28 | def setup(args):
29 | """
30 | Create configs and perform basic setups.
31 | """
32 | cfg = get_cfg()
33 | add_restv2_config(cfg)
34 | cfg.merge_from_file(args.config_file)
35 | cfg.merge_from_list(args.opts)
36 | cfg.freeze()
37 | default_setup(cfg, args)
38 |
39 | return cfg
40 |
41 | def do_flop(cfg):
42 | data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
43 | model = build_model(cfg)
44 | DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
45 | model.eval()
46 |
47 | counts = Counter()
48 | total_flops = []
49 | for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa
50 | flops = FlopCountAnalysis(model, data)
51 | if idx > 0:
52 | flops.unsupported_ops_warnings(False).uncalled_modules_warnings(False)
53 | counts += flops.by_operator()
54 | total_flops.append(flops.total())
55 |
56 | logger.info("Flops table computed from only one input sample:\n" + flop_count_table(flops))
57 | logger.info(
58 | "Average GFlops for each type of operators:\n"
59 | + str([(k, v / (idx + 1) / 1e9) for k, v in counts.items()])
60 | )
61 | logger.info(
62 | "Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9)
63 | )
64 |
65 |
66 | def do_activation(cfg):
67 | data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
68 | model = build_model(cfg)
69 | DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
70 | model.eval()
71 |
72 | counts = Counter()
73 | total_activations = []
74 | for idx, data in zip(tqdm.trange(args.num_inputs), data_loader): # noqa
75 | count = activation_count_operators(model, data)
76 | counts += count
77 | total_activations.append(sum(count.values()))
78 | logger.info(
79 | "(Million) Activations for Each Type of Operators:\n"
80 | + str([(k, v / (idx + 1)) for k, v in counts.items()])
81 | )
82 | logger.info(
83 | "Total (Million) Activations: {}±{}".format(
84 | np.mean(total_activations), np.std(total_activations)
85 | )
86 | )
87 |
88 |
89 | def do_parameter(cfg):
90 | model = build_model(cfg)
91 | logger.info("Parameter Count:\n" + parameter_count_table(model, max_depth=10))
92 |
93 |
94 | def do_structure(cfg):
95 | model = build_model(cfg)
96 | logger.info("Model Structure:\n" + str(model))
97 |
98 |
99 | if __name__ == "__main__":
100 | parser = default_argument_parser(
101 | epilog="""
102 | Examples:
103 |
104 | To show parameters of a model:
105 | $ ./analyze_model.py --tasks parameter \\
106 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml
107 |
108 | Flops and activations are data-dependent, therefore inputs and model weights
109 | are needed to count them:
110 |
111 | $ ./analyze_model.py --num-inputs 100 --tasks flop \\
112 | --config-file ../configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml \\
113 | MODEL.WEIGHTS /path/to/model.pkl
114 | """
115 | )
116 | parser.add_argument(
117 | "--tasks",
118 | choices=["flop", "activation", "parameter", "structure"],
119 | required=True,
120 | nargs="+",
121 | )
122 | parser.add_argument(
123 | "-n",
124 | "--num-inputs",
125 | default=100,
126 | type=int,
127 | help="number of inputs used to compute statistics for flops/activations, "
128 | "both are data dependent.",
129 | )
130 | args = parser.parse_args()
131 | assert not args.eval_only
132 | assert args.num_gpus == 1
133 |
134 | cfg = setup(args)
135 |
136 | for task in args.tasks:
137 | {
138 | "flop": do_flop,
139 | "activation": do_activation,
140 | "parameter": do_parameter,
141 | "structure": do_structure,
142 | }[task](cfg)
143 |
--------------------------------------------------------------------------------
/object_detection/configs/Base-RCNN-FPN.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "GeneralizedRCNN"
3 | BACKBONE:
4 | NAME: "build_resnet_fpn_backbone"
5 | RESNETS:
6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"]
7 | FPN:
8 | IN_FEATURES: ["res2", "res3", "res4", "res5"]
9 | ANCHOR_GENERATOR:
10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
12 | RPN:
13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level
16 | # Detectron1 uses 2000 proposals per-batch,
17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
19 | POST_NMS_TOPK_TRAIN: 1000
20 | POST_NMS_TOPK_TEST: 1000
21 | ROI_HEADS:
22 | NAME: "StandardROIHeads"
23 | IN_FEATURES: ["p2", "p3", "p4", "p5"]
24 | ROI_BOX_HEAD:
25 | NAME: "FastRCNNConvFCHead"
26 | NUM_FC: 2
27 | POOLER_RESOLUTION: 7
28 | ROI_MASK_HEAD:
29 | NAME: "MaskRCNNConvUpsampleHead"
30 | NUM_CONV: 4
31 | POOLER_RESOLUTION: 14
32 | DATASETS:
33 | TRAIN: ("coco_2017_train",)
34 | TEST: ("coco_2017_val",)
35 | SOLVER:
36 | IMS_PER_BATCH: 16
37 | BASE_LR: 0.0001
38 | WEIGHT_DECAY: 0.05
39 | OPTIMIZER: "AdamW"
40 | CLIP_GRADIENTS:
41 | ENABLED: True
42 | CLIP_TYPE: "full_model"
43 | CLIP_VALUE: 0.01
44 | NORM_TYPE: 2.0
45 | STEPS: (60000, 80000)
46 | MAX_ITER: 90000
47 | INPUT:
48 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
49 | CROP:
50 | ENABLED: True
51 | TYPE: "relative"
52 | SIZE: [0.9, 0.9]
53 | VERSION: 2
--------------------------------------------------------------------------------
/object_detection/configs/Base-RetinaNet.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "RetinaNet"
3 | BACKBONE:
4 | NAME: "build_retinanet_resnet_fpn_backbone"
5 | RESNETS:
6 | OUT_FEATURES: ["res3", "res4", "res5"]
7 | ANCHOR_GENERATOR:
8 | SIZES: !!python/object/apply:eval ["[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [32, 64, 128, 256, 512 ]]"]
9 | FPN:
10 | IN_FEATURES: ["res3", "res4", "res5"]
11 | RETINANET:
12 | IOU_THRESHOLDS: [0.4, 0.5]
13 | IOU_LABELS: [0, -1, 1]
14 | SMOOTH_L1_LOSS_BETA: 0.0
15 | DATASETS:
16 | TRAIN: ("coco_2017_train",)
17 | TEST: ("coco_2017_val",)
18 | SOLVER:
19 | IMS_PER_BATCH: 16
20 | BASE_LR: 0.01 # Note that RetinaNet uses a different default learning rate
21 | STEPS: (60000, 80000)
22 | MAX_ITER: 90000
23 | INPUT:
24 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
25 | VERSION: 2
26 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv1/mask_rcnn_rest_base_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: ""
4 | BACKBONE:
5 | NAME: "build_rest_fpn_backbone"
6 | MASK_ON: True
7 | REST:
8 | NAME : "rest_base"
9 | OUT_FEATURES: ["stage1", "stage2", "stage3", "stage4"]
10 | FPN:
11 | IN_FEATURES: ["stage1", "stage2", "stage3", "stage4"]
12 | SOLVER:
13 | OPTIMIZER: "AdamW"
14 | BASE_LR: 0.0001
15 | WEIGHT_DECAY: 0.05
16 | OUTPUT_DIR: "output/mask_rcnn/rest_base_ms_1x"
17 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv1/mask_rcnn_rest_small_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: ""
4 | BACKBONE:
5 | NAME: "build_rest_fpn_backbone"
6 | MASK_ON: True
7 | REST:
8 | NAME : "rest_small"
9 | OUT_FEATURES: ["stage1", "stage2", "stage3", "stage4"]
10 | FPN:
11 | IN_FEATURES: ["stage1", "stage2", "stage3", "stage4"]
12 | SOLVER:
13 | OPTIMIZER: "AdamW"
14 | BASE_LR: 0.0001
15 | WEIGHT_DECAY: 0.05
16 | OUTPUT_DIR: "output/mask_rcnn/rest_small_ms_1x"
17 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv1/retinanet_rest_base_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RetinaNet.yaml"
2 | MODEL:
3 | WEIGHTS: ""
4 | BACKBONE:
5 | NAME: "build_retinanet_rest_fpn_backbone"
6 | REST:
7 | NAME : "rest_base"
8 | OUT_FEATURES: ["stage2", "stage3", "stage4"]
9 | FPN:
10 | IN_FEATURES: ["stage2", "stage3", "stage4"]
11 | SOLVER:
12 | OPTIMIZER: "AdamW"
13 | BASE_LR: 0.0001
14 | WEIGHT_DECAY: 0.05
15 | OUTPUT_DIR: "output/retinanet/rest_base_ms_1x"
16 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv1/retinanet_rest_small_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RetinaNet.yaml"
2 | MODEL:
3 | WEIGHTS: ""
4 | BACKBONE:
5 | NAME: "build_retinanet_rest_fpn_backbone"
6 | REST:
7 | NAME : "rest_small"
8 | OUT_FEATURES: ["stage2", "stage3", "stage4"]
9 | FPN:
10 | IN_FEATURES: ["stage2", "stage3", "stage4"]
11 | SOLVER:
12 | OPTIMIZER: "AdamW"
13 | BASE_LR: 0.0001
14 | WEIGHT_DECAY: 0.05
15 | OUTPUT_DIR: "output/retinanet/rest_small_ms_1x"
16 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv2/mask_rcnn_rest_base_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: ""
4 | BACKBONE:
5 | NAME: "build_rest_fpn_backbone"
6 | MASK_ON: True
7 | REST:
8 | NAME : "rest_base"
9 | OUT_FEATURES: ["stage1", "stage2", "stage3", "stage4"]
10 | FPN:
11 | IN_FEATURES: ["stage1", "stage2", "stage3", "stage4"]
12 | SOLVER:
13 | OPTIMIZER: "AdamW"
14 | BASE_LR: 0.0001
15 | WEIGHT_DECAY: 0.05
16 | OUTPUT_DIR: "output/mask_rcnn/rest_base_ms_1x"
17 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv2/mask_rcnn_rest_small_FPN_1x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: ""
4 | BACKBONE:
5 | NAME: "build_rest_fpn_backbone"
6 | MASK_ON: True
7 | REST:
8 | NAME : "rest_small"
9 | OUT_FEATURES: ["stage1", "stage2", "stage3", "stage4"]
10 | FPN:
11 | IN_FEATURES: ["stage1", "stage2", "stage3", "stage4"]
12 | SOLVER:
13 | OPTIMIZER: "AdamW"
14 | BASE_LR: 0.0001
15 | WEIGHT_DECAY: 0.05
16 | OUTPUT_DIR: "output/mask_rcnn/rest_small_ms_1x"
17 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv2/mask_rcnn_restv2_base_FPN_3x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: "/home/zhangql/ResTV2/workdirs/restv2_base_224.pth"
4 | PIXEL_MEAN: [ 123.675, 116.28, 103.53 ]
5 | PIXEL_STD: [ 58.395, 57.12, 57.375 ]
6 | MASK_ON: True
7 | BACKBONE:
8 | NAME: "build_restv2_fpn_backbone"
9 | RESTV2:
10 | NAME: "restv2_base"
11 | OUT_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ]
12 | FPN:
13 | IN_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ]
14 | INPUT:
15 | FORMAT: "RGB"
16 | MIN_SIZE_TRAIN: (480, 512, 534, 576, 640, 672, 704, 736, 768, 800)
17 | SOLVER:
18 | STEPS: (210000, 250000)
19 | MAX_ITER: 270000
20 | WEIGHT_DECAY: 0.05
21 | BASE_LR: 0.0001
22 | AMP:
23 | ENABLED: True
24 | TEST:
25 | EVAL_PERIOD: 20000
26 |
27 | DATASETS:
28 | TRAIN: ("coco_2017_train",)
29 | TEST: ("coco_2017_val",)
30 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv2/mask_rcnn_restv2_small_FPN_3x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: "/home/zhangql/ResTV2/workdirs/restv2_small_224.pth"
4 | PIXEL_MEAN: [ 123.675, 116.28, 103.53 ]
5 | PIXEL_STD: [ 58.395, 57.12, 57.375 ]
6 | MASK_ON: True
7 | BACKBONE:
8 | NAME: "build_restv2_fpn_backbone"
9 | RESTV2:
10 | NAME: "restv2_small"
11 | OUT_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ]
12 | FPN:
13 | IN_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ]
14 | INPUT:
15 | FORMAT: "RGB"
16 | MIN_SIZE_TRAIN: (480, 512, 534, 576, 640, 672, 704, 736, 768, 800)
17 | SOLVER:
18 | STEPS: (210000, 250000)
19 | MAX_ITER: 270000
20 | WEIGHT_DECAY: 0.05
21 | BASE_LR: 0.0001
22 | AMP:
23 | ENABLED: True
24 | TEST:
25 | EVAL_PERIOD: 20000
26 |
27 | DATASETS:
28 | TRAIN: ("coco_2017_train",)
29 | TEST: ("coco_2017_val",)
30 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv2/mask_rcnn_restv2_tiny_FPN_3x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RCNN-FPN.yaml"
2 | MODEL:
3 | WEIGHTS: "/home/zhangql/ResTV2/workdirs/restv2_tiny_224.pth"
4 | PIXEL_MEAN: [ 123.675, 116.28, 103.53 ]
5 | PIXEL_STD: [ 58.395, 57.12, 57.375 ]
6 | MASK_ON: True
7 | BACKBONE:
8 | NAME: "build_restv2_fpn_backbone"
9 | RESTV2:
10 | NAME: "restv2_tiny"
11 | OUT_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ]
12 | FPN:
13 | IN_FEATURES: [ "stage1", "stage2", "stage3", "stage4" ]
14 | INPUT:
15 | FORMAT: "RGB"
16 | MIN_SIZE_TRAIN: (480, 512, 534, 576, 640, 672, 704, 736, 768, 800)
17 | SOLVER:
18 | STEPS: (210000, 250000)
19 | MAX_ITER: 270000
20 | WEIGHT_DECAY: 0.05
21 | BASE_LR: 0.0001
22 | AMP:
23 | ENABLED: True
24 | TEST:
25 | EVAL_PERIOD: 20000
26 |
27 | DATASETS:
28 | TRAIN: ("coco_2017_train",)
29 | TEST: ("coco_2017_val",)
30 |
--------------------------------------------------------------------------------
/object_detection/configs/ResTv2/retinanet_restv2_tiny_FPN_3x.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "../Base-RetinaNet.yaml"
2 | MODEL:
3 | WEIGHTS: "restv2_tiny.pth"
4 | PIXEL_MEAN: [123.675, 116.28, 103.53] # use RGB [103.530, 116.280, 123.675]
5 | PIXEL_STD: [58.395, 57.12, 57.375] #[57.375, 57.120, 58.395] # an earlier run used the default [1.0, 1.0, 1.0] and BGR format, which was a mistake
6 | RESNETS:
7 | DEPTH: 50
8 | BACKBONE:
9 | NAME: "build_retinanet_swint_fpn_backbone"
10 | SWINT:
11 | OUT_FEATURES: ["stage3", "stage4", "stage5"]
12 | FPN:
13 | IN_FEATURES: ["stage3", "stage4", "stage5"]
14 | INPUT:
15 | FORMAT: "RGB"
16 | MIN_SIZE_TRAIN: (384, 512, 640, 768, 896, 1024)
17 | SOLVER:
18 | STEPS: (210000, 250000)
19 | MAX_ITER: 270000
20 | WEIGHT_DECAY: 0.05
21 | BASE_LR: 0.0001
22 | AMP:
23 | ENABLED: True
24 | TEST:
25 | EVAL_PERIOD: 30000
26 |
27 | DATASETS:
28 | TRAIN: ("coco_2017_train",)
29 | TEST: ("coco_2017_val",)
30 |
--------------------------------------------------------------------------------
/object_detection/convert_to_d2.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import os
8 | import argparse
9 |
10 | import torch
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser("D2 model converter")
15 |
16 | parser.add_argument("--source_model", default="", type=str, help="Path or url to the model to convert")
17 | parser.add_argument("--output_model", default="", type=str, help="Path where to save the converted model")
18 | return parser.parse_args()
19 |
20 |
21 | def main():
22 | args = parse_args()
23 |
24 | source_weights = torch.load(args.source_model)["model"]
25 | converted_weights = {}
26 | keys = list(source_weights.keys())
27 |
28 | prefix = 'backbone.bottom_up.'
29 | for key in keys:
30 | converted_weights[prefix + key] = source_weights[key]
31 |
32 | torch.save(converted_weights, args.output_model)
33 |
34 |
35 | if __name__ == "__main__":
36 | main()
37 |
--------------------------------------------------------------------------------
/object_detection/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Use Builtin Datasets
2 |
3 | A dataset can be used by accessing [DatasetCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.DatasetCatalog)
4 | for its data, or [MetadataCatalog](https://detectron2.readthedocs.io/modules/data.html#detectron2.data.MetadataCatalog) for its metadata (class names, etc).
5 | This document explains how to set up the builtin datasets so they can be used by the above APIs.
6 | [Use Custom Datasets](https://detectron2.readthedocs.io/tutorials/datasets.html) gives a deeper dive on how to use `DatasetCatalog` and `MetadataCatalog`,
7 | and how to add new datasets to them.
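
For instance, once one of the builtin datasets described below has been set up, both catalogs can be queried directly. A minimal sketch (`coco_2017_val` is just an example name):
```
from detectron2.data import DatasetCatalog, MetadataCatalog

dataset_dicts = DatasetCatalog.get("coco_2017_val")  # list[dict] in detectron2's standard format
metadata = MetadataCatalog.get("coco_2017_val")      # class names and other metadata
print(len(dataset_dicts), metadata.thing_classes[:5])
```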
8 |
9 | Detectron2 has builtin support for a few datasets.
10 | The datasets are assumed to exist in a directory specified by the environment variable
11 | `DETECTRON2_DATASETS`.
12 | Under this directory, detectron2 will look for datasets in the structure described below, if needed.
13 | ```
14 | $DETECTRON2_DATASETS/
15 | coco/
16 | lvis/
17 | cityscapes/
18 | VOC20{07,12}/
19 | ```
20 |
21 | You can set the location for builtin datasets by `export DETECTRON2_DATASETS=/path/to/datasets`.
22 | If left unset, the default is `./datasets` relative to your current working directory.
23 |
24 | The [model zoo](https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md)
25 | contains configs and models that use these builtin datasets.
26 |
27 | ## Expected dataset structure for [COCO instance/keypoint detection](https://cocodataset.org/#download):
28 |
29 | ```
30 | coco/
31 | annotations/
32 | instances_{train,val}2017.json
33 | person_keypoints_{train,val}2017.json
34 | {train,val}2017/
35 | # image files that are mentioned in the corresponding json
36 | ```
37 |
38 | You can use the 2014 version of the dataset as well.
39 |
40 | Some of the builtin tests (`dev/run_*_tests.sh`) use a tiny version of the COCO dataset,
41 | which you can download with `./datasets/prepare_for_tests.sh`.
42 |
43 | ## Expected dataset structure for PanopticFPN:
44 |
45 | Extract panoptic annotations from [COCO website](https://cocodataset.org/#download)
46 | into the following structure:
47 | ```
48 | coco/
49 | annotations/
50 | panoptic_{train,val}2017.json
51 | panoptic_{train,val}2017/ # png annotations
52 | panoptic_stuff_{train,val}2017/ # generated by the script mentioned below
53 | ```
54 |
55 | Install panopticapi by:
56 | ```
57 | pip install git+https://github.com/cocodataset/panopticapi.git
58 | ```
59 | Then run `python datasets/prepare_panoptic_fpn.py` to extract semantic annotations from the panoptic annotations.
60 |
61 | ## Expected dataset structure for [LVIS instance segmentation](https://www.lvisdataset.org/dataset):
62 | ```
63 | coco/
64 | {train,val,test}2017/
65 | lvis/
66 | lvis_v0.5_{train,val}.json
67 | lvis_v0.5_image_info_test.json
68 | lvis_v1_{train,val}.json
69 | lvis_v1_image_info_test{,_challenge}.json
70 | ```
71 |
72 | Install lvis-api by:
73 | ```
74 | pip install git+https://github.com/lvis-dataset/lvis-api.git
75 | ```
76 |
77 | To evaluate models trained on the COCO dataset using LVIS annotations,
78 | run `python datasets/prepare_cocofied_lvis.py` to prepare "cocofied" LVIS annotations.
79 |
80 | ## Expected dataset structure for [cityscapes](https://www.cityscapes-dataset.com/downloads/):
81 | ```
82 | cityscapes/
83 | gtFine/
84 | train/
85 | aachen/
86 | color.png, instanceIds.png, labelIds.png, polygons.json,
87 | labelTrainIds.png
88 | ...
89 | val/
90 | test/
91 | # below are generated Cityscapes panoptic annotation
92 | cityscapes_panoptic_train.json
93 | cityscapes_panoptic_train/
94 | cityscapes_panoptic_val.json
95 | cityscapes_panoptic_val/
96 | cityscapes_panoptic_test.json
97 | cityscapes_panoptic_test/
98 | leftImg8bit/
99 | train/
100 | val/
101 | test/
102 | ```
103 | Install cityscapes scripts by:
104 | ```
105 | pip install git+https://github.com/mcordts/cityscapesScripts.git
106 | ```
107 |
108 | Note: to create labelTrainIds.png, first prepare the above structure, then run the cityscapesScripts tool with:
109 | ```
110 | CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py
111 | ```
112 | These files are not needed for instance segmentation.
113 |
114 | Note: to generate the Cityscapes panoptic dataset, run the cityscapesScripts tool with:
115 | ```
116 | CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py
117 | ```
118 | These files are not needed for semantic and instance segmentation.
119 |
120 | ## Expected dataset structure for [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/index.html):
121 | ```
122 | VOC20{07,12}/
123 | Annotations/
124 | ImageSets/
125 | Main/
126 | trainval.txt
127 | test.txt
128 | # train.txt or val.txt, if you use these splits
129 | JPEGImages/
130 | ```
131 |
132 | ## Expected dataset structure for [ADE20k Scene Parsing](http://sceneparsing.csail.mit.edu/):
133 | ```
134 | ADEChallengeData2016/
135 | annotations/
136 | annotations_detectron2/
137 | images/
138 | objectInfo150.txt
139 | ```
140 | The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`.
141 |
--------------------------------------------------------------------------------
/object_detection/datasets/prepare_ade20k_sem_seg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | import numpy as np
5 | import os
6 | from pathlib import Path
7 | import tqdm
8 | from PIL import Image
9 |
10 |
11 | def convert(input, output):
12 | img = np.asarray(Image.open(input))
13 | assert img.dtype == np.uint8
14 | img = img - 1 # 0 (ignore) becomes 255. others are shifted by 1
15 | Image.fromarray(img).save(output)
16 |
17 |
18 | if __name__ == "__main__":
19 | dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ADEChallengeData2016"
20 | for name in ["training", "validation"]:
21 | annotation_dir = dataset_dir / "annotations" / name
22 | output_dir = dataset_dir / "annotations_detectron2" / name
23 | output_dir.mkdir(parents=True, exist_ok=True)
24 | for file in tqdm.tqdm(list(annotation_dir.iterdir())):
25 | output_file = output_dir / file.name
26 | convert(file, output_file)
27 |
--------------------------------------------------------------------------------
/object_detection/datasets/prepare_cocofied_lvis.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | import copy
6 | import json
7 | import os
8 | from collections import defaultdict
9 |
10 | # This mapping is extracted from the official LVIS mapping:
11 | # https://github.com/lvis-dataset/lvis-api/blob/master/data/coco_to_synset.json
12 | COCO_SYNSET_CATEGORIES = [
13 | {"synset": "person.n.01", "coco_cat_id": 1},
14 | {"synset": "bicycle.n.01", "coco_cat_id": 2},
15 | {"synset": "car.n.01", "coco_cat_id": 3},
16 | {"synset": "motorcycle.n.01", "coco_cat_id": 4},
17 | {"synset": "airplane.n.01", "coco_cat_id": 5},
18 | {"synset": "bus.n.01", "coco_cat_id": 6},
19 | {"synset": "train.n.01", "coco_cat_id": 7},
20 | {"synset": "truck.n.01", "coco_cat_id": 8},
21 | {"synset": "boat.n.01", "coco_cat_id": 9},
22 | {"synset": "traffic_light.n.01", "coco_cat_id": 10},
23 | {"synset": "fireplug.n.01", "coco_cat_id": 11},
24 | {"synset": "stop_sign.n.01", "coco_cat_id": 13},
25 | {"synset": "parking_meter.n.01", "coco_cat_id": 14},
26 | {"synset": "bench.n.01", "coco_cat_id": 15},
27 | {"synset": "bird.n.01", "coco_cat_id": 16},
28 | {"synset": "cat.n.01", "coco_cat_id": 17},
29 | {"synset": "dog.n.01", "coco_cat_id": 18},
30 | {"synset": "horse.n.01", "coco_cat_id": 19},
31 | {"synset": "sheep.n.01", "coco_cat_id": 20},
32 | {"synset": "beef.n.01", "coco_cat_id": 21},
33 | {"synset": "elephant.n.01", "coco_cat_id": 22},
34 | {"synset": "bear.n.01", "coco_cat_id": 23},
35 | {"synset": "zebra.n.01", "coco_cat_id": 24},
36 | {"synset": "giraffe.n.01", "coco_cat_id": 25},
37 | {"synset": "backpack.n.01", "coco_cat_id": 27},
38 | {"synset": "umbrella.n.01", "coco_cat_id": 28},
39 | {"synset": "bag.n.04", "coco_cat_id": 31},
40 | {"synset": "necktie.n.01", "coco_cat_id": 32},
41 | {"synset": "bag.n.06", "coco_cat_id": 33},
42 | {"synset": "frisbee.n.01", "coco_cat_id": 34},
43 | {"synset": "ski.n.01", "coco_cat_id": 35},
44 | {"synset": "snowboard.n.01", "coco_cat_id": 36},
45 | {"synset": "ball.n.06", "coco_cat_id": 37},
46 | {"synset": "kite.n.03", "coco_cat_id": 38},
47 | {"synset": "baseball_bat.n.01", "coco_cat_id": 39},
48 | {"synset": "baseball_glove.n.01", "coco_cat_id": 40},
49 | {"synset": "skateboard.n.01", "coco_cat_id": 41},
50 | {"synset": "surfboard.n.01", "coco_cat_id": 42},
51 | {"synset": "tennis_racket.n.01", "coco_cat_id": 43},
52 | {"synset": "bottle.n.01", "coco_cat_id": 44},
53 | {"synset": "wineglass.n.01", "coco_cat_id": 46},
54 | {"synset": "cup.n.01", "coco_cat_id": 47},
55 | {"synset": "fork.n.01", "coco_cat_id": 48},
56 | {"synset": "knife.n.01", "coco_cat_id": 49},
57 | {"synset": "spoon.n.01", "coco_cat_id": 50},
58 | {"synset": "bowl.n.03", "coco_cat_id": 51},
59 | {"synset": "banana.n.02", "coco_cat_id": 52},
60 | {"synset": "apple.n.01", "coco_cat_id": 53},
61 | {"synset": "sandwich.n.01", "coco_cat_id": 54},
62 | {"synset": "orange.n.01", "coco_cat_id": 55},
63 | {"synset": "broccoli.n.01", "coco_cat_id": 56},
64 | {"synset": "carrot.n.01", "coco_cat_id": 57},
65 | {"synset": "frank.n.02", "coco_cat_id": 58},
66 | {"synset": "pizza.n.01", "coco_cat_id": 59},
67 | {"synset": "doughnut.n.02", "coco_cat_id": 60},
68 | {"synset": "cake.n.03", "coco_cat_id": 61},
69 | {"synset": "chair.n.01", "coco_cat_id": 62},
70 | {"synset": "sofa.n.01", "coco_cat_id": 63},
71 | {"synset": "pot.n.04", "coco_cat_id": 64},
72 | {"synset": "bed.n.01", "coco_cat_id": 65},
73 | {"synset": "dining_table.n.01", "coco_cat_id": 67},
74 | {"synset": "toilet.n.02", "coco_cat_id": 70},
75 | {"synset": "television_receiver.n.01", "coco_cat_id": 72},
76 | {"synset": "laptop.n.01", "coco_cat_id": 73},
77 | {"synset": "mouse.n.04", "coco_cat_id": 74},
78 | {"synset": "remote_control.n.01", "coco_cat_id": 75},
79 | {"synset": "computer_keyboard.n.01", "coco_cat_id": 76},
80 | {"synset": "cellular_telephone.n.01", "coco_cat_id": 77},
81 | {"synset": "microwave.n.02", "coco_cat_id": 78},
82 | {"synset": "oven.n.01", "coco_cat_id": 79},
83 | {"synset": "toaster.n.02", "coco_cat_id": 80},
84 | {"synset": "sink.n.01", "coco_cat_id": 81},
85 | {"synset": "electric_refrigerator.n.01", "coco_cat_id": 82},
86 | {"synset": "book.n.01", "coco_cat_id": 84},
87 | {"synset": "clock.n.01", "coco_cat_id": 85},
88 | {"synset": "vase.n.01", "coco_cat_id": 86},
89 | {"synset": "scissors.n.01", "coco_cat_id": 87},
90 | {"synset": "teddy.n.01", "coco_cat_id": 88},
91 | {"synset": "hand_blower.n.01", "coco_cat_id": 89},
92 | {"synset": "toothbrush.n.01", "coco_cat_id": 90},
93 | ]
94 |
95 |
96 | def cocofy_lvis(input_filename, output_filename):
97 | """
98 | Filter LVIS instance segmentation annotations to remove all categories that are not included in
99 | COCO. The new json files can be used to evaluate COCO AP using `lvis-api`. The category ids in
100 | the output json are the incontiguous COCO dataset ids.
101 |
102 | Args:
103 | input_filename (str): path to the LVIS json file.
104 | output_filename (str): path to the COCOfied json file.
105 | """
106 |
107 | with open(input_filename, "r") as f:
108 | lvis_json = json.load(f)
109 |
110 | lvis_annos = lvis_json.pop("annotations")
111 | cocofied_lvis = copy.deepcopy(lvis_json)
112 | lvis_json["annotations"] = lvis_annos
113 |
114 | # Mapping from lvis cat id to coco cat id via synset
115 | lvis_cat_id_to_synset = {cat["id"]: cat["synset"] for cat in lvis_json["categories"]}
116 | synset_to_coco_cat_id = {x["synset"]: x["coco_cat_id"] for x in COCO_SYNSET_CATEGORIES}
117 | # Synsets that we will keep in the dataset
118 | synsets_to_keep = set(synset_to_coco_cat_id.keys())
119 | coco_cat_id_with_instances = defaultdict(int)
120 |
121 | new_annos = []
122 | ann_id = 1
123 | for ann in lvis_annos:
124 | lvis_cat_id = ann["category_id"]
125 | synset = lvis_cat_id_to_synset[lvis_cat_id]
126 | if synset not in synsets_to_keep:
127 | continue
128 | coco_cat_id = synset_to_coco_cat_id[synset]
129 | new_ann = copy.deepcopy(ann)
130 | new_ann["category_id"] = coco_cat_id
131 | new_ann["id"] = ann_id
132 | ann_id += 1
133 | new_annos.append(new_ann)
134 | coco_cat_id_with_instances[coco_cat_id] += 1
135 | cocofied_lvis["annotations"] = new_annos
136 |
137 | for image in cocofied_lvis["images"]:
138 | for key in ["not_exhaustive_category_ids", "neg_category_ids"]:
139 | new_category_list = []
140 | for lvis_cat_id in image[key]:
141 | synset = lvis_cat_id_to_synset[lvis_cat_id]
142 | if synset not in synsets_to_keep:
143 | continue
144 | coco_cat_id = synset_to_coco_cat_id[synset]
145 | new_category_list.append(coco_cat_id)
146 | coco_cat_id_with_instances[coco_cat_id] += 1
147 | image[key] = new_category_list
148 |
149 | coco_cat_id_with_instances = set(coco_cat_id_with_instances.keys())
150 |
151 | new_categories = []
152 | for cat in lvis_json["categories"]:
153 | synset = cat["synset"]
154 | if synset not in synsets_to_keep:
155 | continue
156 | coco_cat_id = synset_to_coco_cat_id[synset]
157 | if coco_cat_id not in coco_cat_id_with_instances:
158 | continue
159 | new_cat = copy.deepcopy(cat)
160 | new_cat["id"] = coco_cat_id
161 | new_categories.append(new_cat)
162 | cocofied_lvis["categories"] = new_categories
163 |
164 | with open(output_filename, "w") as f:
165 | json.dump(cocofied_lvis, f)
166 | print("{} is COCOfied and stored in {}.".format(input_filename, output_filename))
167 |
168 |
169 | if __name__ == "__main__":
170 | dataset_dir = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "lvis")
171 | for s in ["lvis_v0.5_train", "lvis_v0.5_val"]:
172 | print("Start COCOfing {}.".format(s))
173 | cocofy_lvis(
174 | os.path.join(dataset_dir, "{}.json".format(s)),
175 | os.path.join(dataset_dir, "{}_cocofied.json".format(s)),
176 | )
177 |
--------------------------------------------------------------------------------
/object_detection/datasets/prepare_for_tests.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | # Download some files needed for running tests.
5 |
6 | cd "${0%/*}"
7 |
8 | BASE=https://dl.fbaipublicfiles.com/detectron2
9 | mkdir -p coco/annotations
10 |
11 | for anno in instances_val2017_100 \
12 | person_keypoints_val2017_100 \
13 | instances_minival2014_100 \
14 | person_keypoints_minival2014_100; do
15 |
16 | dest=coco/annotations/$anno.json
17 | [[ -s $dest ]] && {
18 | echo "$dest exists. Skipping ..."
19 | } || {
20 | wget $BASE/annotations/coco/$anno.json -O $dest
21 | }
22 | done
23 |
--------------------------------------------------------------------------------
/object_detection/datasets/prepare_panoptic_fpn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 |
5 | import functools
6 | import json
7 | import multiprocessing as mp
8 | import numpy as np
9 | import os
10 | import time
11 | from fvcore.common.download import download
12 | from panopticapi.utils import rgb2id
13 | from PIL import Image
14 |
15 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
16 |
17 |
18 | def _process_panoptic_to_semantic(input_panoptic, output_semantic, segments, id_map):
19 | panoptic = np.asarray(Image.open(input_panoptic), dtype=np.uint32)
20 | panoptic = rgb2id(panoptic)
21 | output = np.zeros_like(panoptic, dtype=np.uint8) + 255
22 | for seg in segments:
23 | cat_id = seg["category_id"]
24 | new_cat_id = id_map[cat_id]
25 | output[panoptic == seg["id"]] = new_cat_id
26 | Image.fromarray(output).save(output_semantic)
27 |
28 |
29 | def separate_coco_semantic_from_panoptic(panoptic_json, panoptic_root, sem_seg_root, categories):
30 | """
31 | Create semantic segmentation annotations from panoptic segmentation
32 | annotations, to be used by PanopticFPN.
33 |
34 | It maps all thing categories to class 0, and maps all unlabeled pixels to class 255.
35 | It maps all stuff categories to contiguous ids starting from 1.
36 |
37 | Args:
38 | panoptic_json (str): path to the panoptic json file, in COCO's format.
39 | panoptic_root (str): a directory with panoptic annotation files, in COCO's format.
40 | sem_seg_root (str): a directory to output semantic annotation files
41 | categories (list[dict]): category metadata. Each dict needs to have:
42 | "id": corresponds to the "category_id" in the json annotations
43 | "isthing": 0 or 1
44 | """
45 | os.makedirs(sem_seg_root, exist_ok=True)
46 |
47 | stuff_ids = [k["id"] for k in categories if k["isthing"] == 0]
48 | thing_ids = [k["id"] for k in categories if k["isthing"] == 1]
49 | id_map = {} # map from category id to id in the output semantic annotation
50 | assert len(stuff_ids) <= 254
51 | for i, stuff_id in enumerate(stuff_ids):
52 | id_map[stuff_id] = i + 1
53 | for thing_id in thing_ids:
54 | id_map[thing_id] = 0
55 | id_map[0] = 255
56 |
57 | with open(panoptic_json) as f:
58 | obj = json.load(f)
59 |
60 | pool = mp.Pool(processes=max(mp.cpu_count() // 2, 4))
61 |
62 | def iter_annotations():
63 | for anno in obj["annotations"]:
64 | file_name = anno["file_name"]
65 | segments = anno["segments_info"]
66 | input = os.path.join(panoptic_root, file_name)
67 | output = os.path.join(sem_seg_root, file_name)
68 | yield input, output, segments
69 |
70 | print("Start writing to {} ...".format(sem_seg_root))
71 | start = time.time()
72 | pool.starmap(
73 | functools.partial(_process_panoptic_to_semantic, id_map=id_map),
74 | iter_annotations(),
75 | chunksize=100,
76 | )
77 | print("Finished. time: {:.2f}s".format(time.time() - start))
78 |
79 |
80 | if __name__ == "__main__":
81 | dataset_dir = os.path.join(os.getenv("DETECTRON2_DATASETS", "datasets"), "coco")
82 | for s in ["val2017", "train2017"]:
83 | separate_coco_semantic_from_panoptic(
84 | os.path.join(dataset_dir, "annotations/panoptic_{}.json".format(s)),
85 | os.path.join(dataset_dir, "panoptic_{}".format(s)),
86 | os.path.join(dataset_dir, "panoptic_stuff_{}".format(s)),
87 | COCO_CATEGORIES,
88 | )
89 |
90 | # Prepare val2017_100 for quick testing:
91 |
92 | dest_dir = os.path.join(dataset_dir, "annotations/")
93 | URL_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"
94 | download(URL_PREFIX + "annotations/coco/panoptic_val2017_100.json", dest_dir)
95 | with open(os.path.join(dest_dir, "panoptic_val2017_100.json")) as f:
96 | obj = json.load(f)
97 |
98 | def link_val100(dir_full, dir_100):
99 | print("Creating " + dir_100 + " ...")
100 | os.makedirs(dir_100, exist_ok=True)
101 | for img in obj["images"]:
102 | basename = os.path.splitext(img["file_name"])[0]
103 | src = os.path.join(dir_full, basename + ".png")
104 | dst = os.path.join(dir_100, basename + ".png")
105 | src = os.path.relpath(src, start=dir_100)
106 | os.symlink(src, dst)
107 |
108 | link_val100(
109 | os.path.join(dataset_dir, "panoptic_val2017"),
110 | os.path.join(dataset_dir, "panoptic_val2017_100"),
111 | )
112 |
113 | link_val100(
114 | os.path.join(dataset_dir, "panoptic_stuff_val2017"),
115 | os.path.join(dataset_dir, "panoptic_stuff_val2017_100"),
116 | )
117 |
--------------------------------------------------------------------------------
/object_detection/restv2/__init__.py:
--------------------------------------------------------------------------------
1 | from .rest import *
2 | from .restv2 import *
3 | from .config import *
4 | __all__ = [k for k in globals().keys() if not k.startswith("_")]
--------------------------------------------------------------------------------
/object_detection/restv2/config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | from detectron2.config import CfgNode as CN
8 |
9 |
10 | def add_restv2_config(cfg):
11 | # restv2 backbone
12 | cfg.MODEL.RESTV2 = CN()
13 | cfg.MODEL.RESTV2.NAME = "restv2_tiny"
14 | cfg.MODEL.RESTV2.OUT_FEATURES = ["stage1", "stage2", "stage3", "stage4"]
15 | cfg.MODEL.RESTV2.WEIGHTS = None
16 | cfg.MODEL.BACKBONE.FREEZE_AT = 2
17 |
18 | # addition
19 | cfg.MODEL.FPN.TOP_LEVELS = 2
20 | cfg.SOLVER.OPTIMIZER = "AdamW"
21 |
22 |
23 | def add_restv1_config(cfg):
24 | # restv1 backbone
25 | cfg.MODEL.REST = CN()
26 | cfg.MODEL.REST.NAME = "rest_base"
27 | cfg.MODEL.REST.OUT_FEATURES = ["stage1", "stage2", "stage3", "stage4"]
28 | cfg.MODEL.REST.WEIGHTS = None
29 | cfg.MODEL.BACKBONE.FREEZE_AT = 2
30 |
31 | # addition
32 | cfg.MODEL.FPN.TOP_LEVELS = 2
33 | cfg.SOLVER.OPTIMIZER = "AdamW"
34 |
--------------------------------------------------------------------------------
/object_detection/restv2/rest.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # ResT V1
3 | # Copyright (c) VCU, Nanjing University
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Qing-Long Zhang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
12 | from detectron2.modeling.backbone import Backbone
13 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
14 | from detectron2.modeling.backbone.fpn import FPN, LastLevelP6P7, LastLevelMaxPool
15 | from detectron2.layers import ShapeSpec
16 |
17 |
18 | __all__ = [
19 | "ResT",
20 | "build_rest_backbone",
21 | "build_rest_fpn_backbone",
22 | "build_retinanet_rest_fpn_backbone"]
23 |
24 |
25 | class Mlp(nn.Module):
26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
27 | super().__init__()
28 | out_features = out_features or in_features
29 | hidden_features = hidden_features or in_features
30 | self.fc1 = nn.Linear(in_features, hidden_features)
31 | self.act = act_layer()
32 | self.fc2 = nn.Linear(hidden_features, out_features)
33 | self.drop = nn.Dropout(drop)
34 |
35 | def forward(self, x):
36 | x = self.fc1(x)
37 | x = self.act(x)
38 | x = self.drop(x)
39 | x = self.fc2(x)
40 | x = self.drop(x)
41 | return x
42 |
43 |
44 | class Attention(nn.Module):
45 | def __init__(self,
46 | dim,
47 | num_heads=8,
48 | qkv_bias=False,
49 | qk_scale=None,
50 | attn_drop=0.,
51 | proj_drop=0.,
52 | sr_ratio=1,
53 | apply_transform=False):
54 | super().__init__()
55 | self.num_heads = num_heads
56 | head_dim = dim // num_heads
57 | self.scale = qk_scale or head_dim ** -0.5
58 |
59 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
60 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
61 | self.attn_drop = nn.Dropout(attn_drop)
62 | self.proj = nn.Linear(dim, dim)
63 | self.proj_drop = nn.Dropout(proj_drop)
64 |
65 | self.sr_ratio = sr_ratio
66 | if sr_ratio > 1:
67 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio+1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim)
68 | self.sr_norm = nn.LayerNorm(dim)
69 |
70 | self.apply_transform = apply_transform and num_heads > 1
71 | if self.apply_transform:
72 | self.transform_conv = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, stride=1)
73 | self.transform_norm = nn.InstanceNorm2d(self.num_heads)
74 |
75 | def forward(self, x, H, W):
76 | B, N, C = x.shape
77 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
78 | if self.sr_ratio > 1:
79 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
80 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
81 | x_ = self.sr_norm(x_)
82 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83 | else:
84 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
85 | k, v = kv[0], kv[1]
86 |
87 | attn = (q @ k.transpose(-2, -1)) * self.scale
88 | if self.apply_transform:
89 | attn = self.transform_conv(attn)
90 | attn = attn.softmax(dim=-1)
91 | attn = self.transform_norm(attn)
92 | else:
93 | attn = attn.softmax(dim=-1)
94 |
95 | attn = self.attn_drop(attn)
96 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
97 | x = self.proj(x)
98 | x = self.proj_drop(x)
99 | return x
100 |
101 |
102 | class Block(nn.Module):
103 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
104 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, apply_transform=False):
105 | super().__init__()
106 | self.norm1 = norm_layer(dim)
107 | self.attn = Attention(
108 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
109 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, apply_transform=apply_transform)
110 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
111 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
112 | self.norm2 = norm_layer(dim)
113 | mlp_hidden_dim = int(dim * mlp_ratio)
114 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
115 |
116 | def forward(self, x, H, W):
117 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
118 | x = x + self.drop_path(self.mlp(self.norm2(x)))
119 | return x
120 |
121 |
122 | class PA(nn.Module):
123 | def __init__(self, dim):
124 | super().__init__()
125 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim)
126 | self.sigmoid = nn.Sigmoid()
127 |
128 | def forward(self, x):
129 | return x * self.sigmoid(self.pa_conv(x))
130 |
131 |
132 | class PatchEmbed(nn.Module):
133 | """ Image to Patch Embedding"""
134 | def __init__(self, patch_size=16, in_ch=3, out_ch=768, with_pos=False):
135 | super().__init__()
136 | self.patch_size = to_2tuple(patch_size)
137 | self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size+1, stride=patch_size, padding=patch_size // 2)
138 | self.norm = nn.BatchNorm2d(out_ch)
139 |
140 | self.with_pos = with_pos
141 | if self.with_pos:
142 | self.pos = PA(out_ch)
143 |
144 | def forward(self, x):
145 | B, C, H, W = x.shape
146 | x = self.conv(x)
147 | x = self.norm(x)
148 | if self.with_pos:
149 | x = self.pos(x)
150 | x = x.flatten(2).transpose(1, 2)
151 | H, W = H // self.patch_size[0], W // self.patch_size[1]
152 | return x, (H, W)
153 |
154 |
155 | class BasicStem(nn.Module):
156 | def __init__(self, in_ch=3, out_ch=64, with_pos=False):
157 | super(BasicStem, self).__init__()
158 | hidden_ch = out_ch // 2
159 | self.conv1 = nn.Conv2d(in_ch, hidden_ch, kernel_size=3, stride=2, padding=1, bias=False)
160 | self.norm1 = nn.BatchNorm2d(hidden_ch)
161 | self.conv2 = nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, stride=1, padding=1, bias=False)
162 | self.norm2 = nn.BatchNorm2d(hidden_ch)
163 | self.conv3 = nn.Conv2d(hidden_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False)
164 |
165 | self.act = nn.ReLU(inplace=True)
166 | self.with_pos = with_pos
167 | if self.with_pos:
168 | self.pos = PA(out_ch)
169 |
170 | def forward(self, x):
171 | x = self.conv1(x)
172 | x = self.norm1(x)
173 | x = self.act(x)
174 |
175 | x = self.conv2(x)
176 | x = self.norm2(x)
177 | x = self.act(x)
178 |
179 | x = self.conv3(x)
180 | if self.with_pos:
181 | x = self.pos(x)
182 | return x
183 |
184 |
185 | class ResT(Backbone):
186 | def __init__(self, cfg, in_ch=3, embed_dims=[64, 128, 256, 512], num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4],
187 | qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., depths=[2, 2, 2, 2],
188 | sr_ratios=[8, 4, 2, 1], norm_layer=nn.LayerNorm, apply_transform=True, out_features=None):
189 | super().__init__()
190 | self.depths = depths
191 | self.num_layers = len(depths)
192 | self.apply_transform = apply_transform
193 | self._out_features = out_features
194 |
195 | self.stem = BasicStem(in_ch=in_ch, out_ch=embed_dims[0], with_pos=True)
196 |
197 | self.patch_embed_2 = PatchEmbed(patch_size=2, in_ch=embed_dims[0], out_ch=embed_dims[1], with_pos=True)
198 | self.patch_embed_3 = PatchEmbed(patch_size=2, in_ch=embed_dims[1], out_ch=embed_dims[2], with_pos=True)
199 | self.patch_embed_4 = PatchEmbed(patch_size=2, in_ch=embed_dims[2], out_ch=embed_dims[3], with_pos=True)
200 |
201 | # transformer encoder
202 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
203 | cur = 0
204 |
205 | self.stage1 = nn.ModuleList([
206 | Block(embed_dims[0], num_heads[0], mlp_ratios[0], qkv_bias, qk_scale, drop_rate, attn_drop_rate,
207 | drop_path=dpr[cur+i], norm_layer=norm_layer, sr_ratio=sr_ratios[0], apply_transform=apply_transform)
208 | for i in range(self.depths[0])])
209 |
210 | cur += depths[0]
211 | self.stage2 = nn.ModuleList([
212 | Block(embed_dims[1], num_heads[1], mlp_ratios[1], qkv_bias, qk_scale, drop_rate, attn_drop_rate,
213 | drop_path=dpr[cur+i], norm_layer=norm_layer, sr_ratio=sr_ratios[1], apply_transform=apply_transform)
214 | for i in range(self.depths[1])])
215 |
216 | cur += depths[1]
217 | self.stage3 = nn.ModuleList([
218 | Block(embed_dims[2], num_heads[2], mlp_ratios[2], qkv_bias, qk_scale, drop_rate, attn_drop_rate,
219 | drop_path=dpr[cur+i], norm_layer=norm_layer, sr_ratio=sr_ratios[2], apply_transform=apply_transform)
220 | for i in range(self.depths[2])])
221 |
222 | cur += depths[2]
223 | self.stage4 = nn.ModuleList([
224 | Block(embed_dims[3], num_heads[3], mlp_ratios[3], qkv_bias, qk_scale, drop_rate, attn_drop_rate,
225 | drop_path=dpr[cur+i], norm_layer=norm_layer, sr_ratio=sr_ratios[3], apply_transform=apply_transform)
226 | for i in range(self.depths[3])])
227 |
228 | # add a norm layer for each output
229 | for i in range(self.num_layers-1):
230 | stage = f'stage{i+1}'
231 | if stage in self._out_features:
232 | layer = norm_layer(embed_dims[i])
233 | layer_name = f'norm{i+1}'
234 | self.add_module(layer_name, layer)
235 |
236 | self.norm = norm_layer(embed_dims[3])
237 |
238 | # init weights
239 | self.apply(self._init_weights)
240 | self._freeze_backbone(cfg.MODEL.BACKBONE.FREEZE_AT)
241 |
242 | def _init_weights(self, m):
243 | if isinstance(m, nn.Conv2d):
244 | trunc_normal_(m.weight, std=0.02)
245 | elif isinstance(m, nn.Linear):
246 | trunc_normal_(m.weight, std=0.02)
247 | if m.bias is not None:
248 | nn.init.constant_(m.bias, 0)
249 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
250 | nn.init.constant_(m.weight, 1.0)
251 | nn.init.constant_(m.bias, 0)
252 |
253 | def _freeze_backbone(self, freeze_at):
254 | if freeze_at < 0:
255 | return
256 | for stage_index in range(freeze_at):
257 |             if stage_index == 0:
258 |                 # stage 0 is the stem
259 |                 self.stem.eval()
260 |                 for p in self.stem.parameters():
261 |                     p.requires_grad = False
262 |             elif stage_index == 1:
263 |                 for p in self.stage1.parameters():
264 |                     p.requires_grad = False
265 |             else:
266 |                 patch = getattr(self, "patch_embed_" + str(stage_index))
267 |                 stage = getattr(self, "stage" + str(stage_index))
268 |                 for p in patch.parameters():
269 |                     p.requires_grad = False
270 |                 for p in stage.parameters():
271 |                     p.requires_grad = False
272 |
273 | def forward(self, x):
274 | outputs = {}
275 | x = self.stem(x)
276 | B, _, H, W = x.shape
277 | x = x.flatten(2).permute(0, 2, 1)
278 |
279 | # stage 1
280 | for blk in self.stage1:
281 | x = blk(x, H, W)
282 | if "stage1" in self._out_features:
283 | x = self.norm1(x)
284 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
285 | outputs["stage1"] = x
286 | else:
287 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
288 |
289 | # stage 2
290 | x, (H, W) = self.patch_embed_2(x)
291 | for blk in self.stage2:
292 | x = blk(x, H, W)
293 | if "stage2" in self._out_features:
294 | x = self.norm2(x)
295 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
296 | outputs["stage2"] = x
297 | else:
298 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
299 |
300 | # stage 3
301 | x, (H, W) = self.patch_embed_3(x)
302 | for blk in self.stage3:
303 | x = blk(x, H, W)
304 | if "stage3" in self._out_features:
305 | x = self.norm3(x)
306 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
307 | outputs["stage3"] = x
308 | else:
309 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
310 |
311 | # stage 4
312 | x, (H, W) = self.patch_embed_4(x)
313 | for blk in self.stage4:
314 | x = blk(x, H, W)
315 | x = self.norm(x)
316 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
317 | if "stage4" in self._out_features:
318 | outputs["stage4"] = x
319 |
320 | return outputs
321 |
322 |
323 | @BACKBONE_REGISTRY.register()
324 | def build_rest_backbone(cfg, input_shape):
325 | """
326 | Create a ResT instance from config.
327 | Returns:
328 | ResT: a :class:`ResT` instance.
329 | """
330 | name = cfg.MODEL.REST.NAME
331 | out_features = cfg.MODEL.REST.OUT_FEATURES
332 |
333 | depths = {"rest_lite": [2, 2, 2, 2], "rest_small": [2, 2, 6, 2],
334 | "rest_base": [2, 2, 6, 2], "rest_large": [2, 2, 18, 2]}[name]
335 |
336 | embed_dims = {"rest_lite": [64, 128, 256, 512], "rest_small": [64, 128, 256, 512],
337 | "rest_base": [96, 192, 384, 768], "rest_large": [96, 192, 384, 768]}[name]
338 |
339 | drop_path_rate = {"rest_lite": 0.1, "rest_small": 0.1, "rest_base": 0.2, "rest_large": 0.2}[name]
340 |
341 | feature_names = ['stage1', 'stage2', 'stage3', 'stage4']
342 | out_feature_channels = dict(zip(feature_names, embed_dims))
343 | out_feature_strides = {"stage1": 4, "stage2": 8, "stage3": 16, "stage4": 32}
344 |
345 | model = ResT(cfg, in_ch=3, embed_dims=embed_dims, qkv_bias=True, drop_path_rate=drop_path_rate,
346 | depths=depths, apply_transform=True, out_features=out_features)
347 | model._out_feature_channels = out_feature_channels
348 | model._out_feature_strides = out_feature_strides
349 | return model
350 |
351 |
352 | @BACKBONE_REGISTRY.register()
353 | def build_rest_fpn_backbone(cfg, input_shape: ShapeSpec):
354 | """
355 | Args:
356 | cfg: a detectron2 CfgNode
357 | Returns:
358 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
359 | """
360 | bottom_up = build_rest_backbone(cfg, input_shape)
361 | in_features = cfg.MODEL.FPN.IN_FEATURES
362 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
363 | backbone = FPN(
364 | bottom_up=bottom_up,
365 | in_features=in_features,
366 | out_channels=out_channels,
367 | norm=cfg.MODEL.FPN.NORM,
368 | top_block=LastLevelMaxPool(),
369 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
370 | )
371 | return backbone
372 |
373 |
374 | @BACKBONE_REGISTRY.register()
375 | def build_retinanet_rest_fpn_backbone(cfg, input_shape: ShapeSpec):
376 | """
377 | Args:
378 | cfg: a detectron2 CfgNode
379 |
380 | Returns:
381 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
382 | """
383 | bottom_up = build_rest_backbone(cfg, input_shape)
384 | in_features = cfg.MODEL.FPN.IN_FEATURES
385 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
386 | in_channels_p6p7 = bottom_up.output_shape()["stage4"].channels
387 | backbone = FPN(
388 | bottom_up=bottom_up,
389 | in_features=in_features,
390 | out_channels=out_channels,
391 | norm=cfg.MODEL.FPN.NORM,
392 | top_block=LastLevelP6P7(in_channels_p6p7, out_channels, in_feature="stage4"),
393 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
394 | )
395 | return backbone
--------------------------------------------------------------------------------
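A minimal config sketch for selecting the ResTv1 builders registered above (an illustration only, assuming `add_restv1_config`, imported from the `restv2` package in `train_net.py`, defines the `MODEL.REST.*` keys; the chosen values are examples, not the repo's defaults):

```python
# Hypothetical config fragment for the ResTv1 detectron2 builders above.
from detectron2.config import get_cfg
from restv2 import add_restv1_config  # assumed to register the MODEL.REST.* keys

cfg = get_cfg()
add_restv1_config(cfg)
cfg.MODEL.BACKBONE.NAME = "build_rest_fpn_backbone"      # or "build_retinanet_rest_fpn_backbone"
cfg.MODEL.REST.NAME = "rest_small"                       # rest_lite / rest_small / rest_base / rest_large
cfg.MODEL.REST.OUT_FEATURES = ["stage1", "stage2", "stage3", "stage4"]
cfg.MODEL.FPN.IN_FEATURES = ["stage1", "stage2", "stage3", "stage4"]
```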
/object_detection/restv2/restv2.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # ResT V2
3 | # Copyright (c) VCU, Nanjing University
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Qing-Long Zhang
6 | # --------------------------------------------------------
7 |
8 | import torch
9 | import torch.nn as nn
10 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
11 | from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
12 | from detectron2.modeling.backbone.fpn import FPN, LastLevelP6P7, LastLevelMaxPool
13 | from utils import load_state_dict
14 | from utils import LayerNorm
15 |
16 | __all__ = [
17 | "Attention",
18 | "ResTV2",
19 | "build_restv2_backbone",
20 | "build_restv2_fpn_backbone",
21 | "build_retinanet_restv2_fpn_backbone",
22 | ]
23 |
24 |
25 | class Mlp(nn.Module):
26 | def __init__(self, dim):
27 | super().__init__()
28 | self.fc1 = nn.Linear(dim, 4 * dim)
29 | self.act = nn.GELU()
30 | self.fc2 = nn.Linear(4 * dim, dim)
31 |
32 | def forward(self, x):
33 | x = self.fc1(x)
34 | x = self.act(x)
35 | x = self.fc2(x)
36 | return x
37 |
38 |
39 | class Attention(nn.Module):
40 | def __init__(self,
41 | dim,
42 | num_heads=8,
43 | sr_ratio=1):
44 | super().__init__()
45 | self.num_heads = num_heads
46 | head_dim = dim // num_heads
47 | self.scale = head_dim ** -0.5
48 |
49 | self.q = nn.Linear(dim, dim, bias=True)
50 | self.kv = nn.Linear(dim, dim * 2, bias=True)
51 |
52 | self.sr_ratio = sr_ratio
53 | if sr_ratio > 1:
54 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim)
55 | self.sr_norm = nn.LayerNorm(dim, eps=1e-6)
56 |
57 | self.up = nn.Sequential(
58 | nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1, padding=1, groups=dim),
59 | nn.PixelShuffle(upscale_factor=sr_ratio)
60 | )
61 | self.up_norm = nn.LayerNorm(dim, eps=1e-6)
62 |
63 | self.proj = nn.Linear(dim, dim)
64 |
65 | def forward(self, x, H, W):
66 | B, N, C = x.shape
67 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
68 | if self.sr_ratio > 1:
69 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
70 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
71 | x = self.sr_norm(x)
72 |
73 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
74 | k, v = kv[0], kv[1]
75 | attn = (q @ k.transpose(-2, -1)) * self.scale
76 | attn = attn.softmax(dim=-1)
77 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
78 |
79 |         identity = v.transpose(-1, -2).reshape(B, C, H // self.sr_ratio, W // self.sr_ratio)  # value tokens as a spatial map
80 |         identity = self.up(identity).flatten(2).transpose(1, 2)  # pixel-shuffle back to the full H x W resolution
81 |         x = self.proj(x + self.up_norm(identity))  # fuse the attention output with the up-sampled value branch
82 | return x
83 |
84 |
85 | class Block(nn.Module):
86 | def __init__(self, dim, num_heads, sr_ratio=1, drop_path=0.):
87 | super().__init__()
88 | self.norm1 = nn.LayerNorm(dim, eps=1e-6)
89 | self.attn = Attention(dim, num_heads, sr_ratio)
90 |
91 | self.norm2 = nn.LayerNorm(dim, eps=1e-6)
92 | self.mlp = Mlp(dim)
93 |
94 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
95 |
96 | def forward(self, x, H, W):
97 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) # pre_norm
98 | x = x + self.drop_path(self.mlp(self.norm2(x)))
99 | return x
100 |
101 |
102 | class PA(nn.Module):
103 | def __init__(self, dim):
104 | super().__init__()
105 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=True)
106 | self.sigmoid = nn.Sigmoid()
107 |
108 | def forward(self, x):
109 | return x * self.sigmoid(self.pa_conv(x))
110 |
111 |
112 | class ConvStem(nn.Module):
113 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True):
114 | super().__init__()
115 | self.patch_size = to_2tuple(patch_size)
116 | stem = []
117 | in_dim, out_dim = in_ch, out_ch // 2
118 | for i in range(2):
119 | stem.append(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False))
120 | stem.append(nn.BatchNorm2d(out_dim))
121 | stem.append(nn.ReLU(inplace=True))
122 | in_dim, out_dim = out_dim, out_dim * 2
123 |
124 | stem.append(nn.Conv2d(in_dim, out_ch, kernel_size=1, stride=1))
125 | self.proj = nn.Sequential(*stem)
126 |
127 | self.with_pos = with_pos
128 | if self.with_pos:
129 | self.pos = PA(out_ch)
130 |
131 | self.norm = nn.LayerNorm(out_ch, eps=1e-6)
132 |
133 | def forward(self, x):
134 | B, C, H, W = x.shape
135 | x = self.proj(x)
136 | if self.with_pos:
137 | x = self.pos(x)
138 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
139 | x = self.norm(x)
140 | H, W = H // self.patch_size[0], W // self.patch_size[1]
141 | return x, (H, W)
142 |
143 |
144 | class PatchEmbed(nn.Module):
145 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True):
146 | super().__init__()
147 | self.patch_size = to_2tuple(patch_size)
148 | self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size, padding=patch_size // 2)
149 |
150 | self.with_pos = with_pos
151 | if self.with_pos:
152 | self.pos = PA(out_ch)
153 |
154 | self.norm = nn.LayerNorm(out_ch, eps=1e-6)
155 |
156 | def forward(self, x):
157 | B, C, H, W = x.shape
158 | x = self.proj(x)
159 | if self.with_pos:
160 | x = self.pos(x)
161 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
162 | x = self.norm(x)
163 | H, W = H // self.patch_size[0], W // self.patch_size[1]
164 | return x, (H, W)
165 |
166 |
167 | class ResTV2(Backbone):
168 | def __init__(self, in_chans=3, embed_dims=[96, 192, 384, 768],
169 | num_heads=[1, 2, 4, 8], drop_path_rate=0.,
170 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], out_features=None,
171 | frozen_stages=-1, pretrained=None):
172 | super().__init__()
173 |
174 | self.depths = depths
175 | self.frozen_stages = frozen_stages
176 | self.pretrained = pretrained
177 | self._out_features = out_features
178 |
179 | self.stem = ConvStem(in_chans, embed_dims[0], patch_size=4)
180 | self.patch_2 = PatchEmbed(embed_dims[0], embed_dims[1], patch_size=2)
181 | self.patch_3 = PatchEmbed(embed_dims[1], embed_dims[2], patch_size=2)
182 | self.patch_4 = PatchEmbed(embed_dims[2], embed_dims[3], patch_size=2)
183 |
184 | # transformer encoder
185 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
186 | cur = 0
187 | self.stage1 = nn.ModuleList([
188 | Block(embed_dims[0], num_heads[0], sr_ratios[0], dpr[cur + i])
189 | for i in range(depths[0])
190 | ])
191 |
192 | cur += depths[0]
193 | self.stage2 = nn.ModuleList([
194 | Block(embed_dims[1], num_heads[1], sr_ratios[1], dpr[cur + i])
195 | for i in range(depths[1])
196 | ])
197 |
198 | cur += depths[1]
199 | self.stage3 = nn.ModuleList([
200 | Block(embed_dims[2], num_heads[2], sr_ratios[2], dpr[cur + i])
201 | for i in range(depths[2])
202 | ])
203 |
204 | cur += depths[2]
205 | self.stage4 = nn.ModuleList([
206 | Block(embed_dims[3], num_heads[3], sr_ratios[3], dpr[cur + i])
207 | for i in range(depths[3])
208 | ])
209 |
210 | # add a norm layer for each output
211 | for i_layer in out_features:
212 | idx = int(i_layer[-1])
213 | out_ch = embed_dims[idx - 1]
214 | layer = nn.Sequential(
215 | LayerNorm(out_ch, eps=1e-6, data_format="channels_first"),
216 | )
217 | layer_name = f"norm_{idx}"
218 | self.add_module(layer_name, layer)
219 |
220 | self._freeze_stages()
221 | self.init_weights()
222 |
223 | def _freeze_stages(self):
224 | if self.frozen_stages >= 0:
225 | self.stem.eval()
226 | for param in self.stem.parameters():
227 | param.requires_grad = False
228 |
229 | if self.frozen_stages >= 1:
230 | self.stage1.eval()
231 | for param in self.stage1.parameters():
232 | param.requires_grad = False
233 |
234 | if self.frozen_stages >= 2:
235 |             for i in range(2, self.frozen_stages + 1):
236 |                 patch = getattr(self, "patch_" + str(i))
237 |                 stage = getattr(self, "stage" + str(i))
238 | patch.eval()
239 | stage.eval()
240 | for param in patch.parameters():
241 | param.requires_grad = False
242 | for param in stage.parameters():
243 | param.requires_grad = False
244 |
245 | def init_weights(self):
246 | if self.pretrained is None:
247 | for m in self.modules():
248 | if isinstance(m, (nn.Conv2d, nn.Linear)):
249 | trunc_normal_(m.weight, std=0.02)
250 | if m.bias is not None:
251 | nn.init.constant_(m.bias, 0)
252 | elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
253 | nn.init.constant_(m.weight, 1.)
254 | nn.init.constant_(m.bias, 0.)
255 | else:
256 | checkpoint = torch.load(self.pretrained, map_location='cpu')
257 | if 'state_dict' in checkpoint:
258 | state_dict = checkpoint['state_dict']
259 | elif 'model' in checkpoint:
260 | state_dict = checkpoint['model']
261 | else:
262 | state_dict = checkpoint
263 | # strip prefix of state_dict
264 | if list(state_dict.keys())[0].startswith('module.'):
265 | state_dict = {k[7:]: v for k, v in state_dict.items()}
266 | load_state_dict(self, state_dict)
267 |
268 | def forward(self, x):
269 | outs = {}
270 | B, _, H, W = x.shape
271 | x, (H, W) = self.stem(x)
272 | # stage 1
273 | for blk in self.stage1:
274 | x = blk(x, H, W)
275 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
276 | if "stage1" in self._out_features:
277 | outs["stage1"] = self.norm_1(x)
278 |
279 | # stage 2
280 | x, (H, W) = self.patch_2(x)
281 | for blk in self.stage2:
282 | x = blk(x, H, W)
283 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
284 | if "stage2" in self._out_features:
285 | outs["stage2"] = self.norm_2(x)
286 |
287 | # stage 3
288 | x, (H, W) = self.patch_3(x)
289 | for blk in self.stage3:
290 | x = blk(x, H, W)
291 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
292 | if "stage3" in self._out_features:
293 | outs["stage3"] = self.norm_3(x)
294 |
295 | # stage 4
296 | x, (H, W) = self.patch_4(x)
297 | for blk in self.stage4:
298 | x = blk(x, H, W)
299 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
300 | if "stage4" in self._out_features:
301 | outs["stage4"] = self.norm_4(x)
302 |
303 | return outs
304 |
305 | @property
306 | def size_divisibility(self) -> int:
307 | return 32
308 |
309 |
310 | @BACKBONE_REGISTRY.register()
311 | def build_restv2_backbone(cfg, input_shape: ShapeSpec):
312 | name = cfg.MODEL.RESTV2.NAME
313 | out_features = cfg.MODEL.RESTV2.OUT_FEATURES
314 | settings = {
315 | "restv2_tiny": {"depths": [1, 2, 6, 2], "embed_dims": [96, 192, 384, 768], "drop_path_rate": 0.1},
316 | "restv2_small": {"depths": [1, 2, 12, 2], "embed_dims": [96, 192, 384, 768], "drop_path_rate": 0.2},
317 | "restv2_base": {"depths": [1, 3, 16, 3], "embed_dims": [96, 192, 384, 768], "drop_path_rate": 0.3},
318 | "restv2_large": {"depths": [2, 3, 16, 3], "embed_dims": [128, 256, 512, 1024], "drop_path_rate": 0.5},
319 | }
320 | feature_names = ['stage1', 'stage2', 'stage3', 'stage4']
321 | out_feature_channels = dict(zip(feature_names, settings[name]["embed_dims"]))
322 | out_feature_strides = {"stage1": 4, "stage2": 8, "stage3": 16, "stage4": 32}
323 | model = ResTV2(embed_dims=settings[name]["embed_dims"],
324 | depths=settings[name]["depths"], drop_path_rate=settings[name]["drop_path_rate"],
325 | out_features=out_features)
326 | model._out_feature_channels = out_feature_channels
327 | model._out_feature_strides = out_feature_strides
328 | return model
329 |
330 |
331 | @BACKBONE_REGISTRY.register()
332 | def build_restv2_fpn_backbone(cfg, input_shape: ShapeSpec):
333 | """
334 | Args:
335 | cfg: a detectron2 CfgNode
336 | input_shape: ShapeSpec
337 | Returns:
338 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
339 | """
340 | bottom_up = build_restv2_backbone(cfg, input_shape)
341 | in_features = cfg.MODEL.FPN.IN_FEATURES
342 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
343 | backbone = FPN(
344 | bottom_up=bottom_up,
345 | in_features=in_features,
346 | out_channels=out_channels,
347 | norm=cfg.MODEL.FPN.NORM,
348 | top_block=LastLevelMaxPool(),
349 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
350 | )
351 | return backbone
352 |
353 |
354 | @BACKBONE_REGISTRY.register()
355 | def build_retinanet_restv2_fpn_backbone(cfg, input_shape: ShapeSpec):
356 | """
357 | Args:
358 | cfg: a detectron2 CfgNode
359 | input_shape: ShapeSpec
360 | Returns:
361 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
362 | """
363 | bottom_up = build_restv2_backbone(cfg, input_shape)
364 | in_features = cfg.MODEL.FPN.IN_FEATURES
365 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
366 | in_channels_p6p7 = bottom_up.output_shape()["stage4"].channels
367 | backbone = FPN(
368 | bottom_up=bottom_up,
369 | in_features=in_features,
370 | out_channels=out_channels,
371 | norm=cfg.MODEL.FPN.NORM,
372 | top_block=LastLevelP6P7(in_channels_p6p7, out_channels, in_feature="stage4"),
373 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
374 | )
375 | return backbone
376 |
--------------------------------------------------------------------------------
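A usage sketch for the ResTv2 builders above, run from the `object_detection` directory as `train_net.py` is. It assumes `add_restv2_config` registers the `MODEL.RESTV2.*` keys referenced in `build_restv2_backbone`; the values below are illustrative, not the repo's defaults:

```python
# Build the ResTv2-FPN backbone straight from a detectron2 config (illustrative sketch).
from detectron2.config import get_cfg
from detectron2.modeling import build_backbone
from restv2 import add_restv2_config  # assumed to register the MODEL.RESTV2.* keys

cfg = get_cfg()
add_restv2_config(cfg)
cfg.MODEL.BACKBONE.NAME = "build_restv2_fpn_backbone"
cfg.MODEL.RESTV2.NAME = "restv2_tiny"                    # one of the keys in `settings`
cfg.MODEL.RESTV2.OUT_FEATURES = ["stage1", "stage2", "stage3", "stage4"]
cfg.MODEL.FPN.IN_FEATURES = ["stage1", "stage2", "stage3", "stage4"]

backbone = build_backbone(cfg)   # dispatches through BACKBONE_REGISTRY via MODEL.BACKBONE.NAME
print(backbone.output_shape())   # FPN levels p2..p6 with their channels and strides
```

`build_backbone` resolves the name against `BACKBONE_REGISTRY`, so switching to `build_restv2_backbone` or `build_retinanet_restv2_fpn_backbone` only changes `MODEL.BACKBONE.NAME`.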
/object_detection/train_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # ------------------------------------------------------------
3 | # Copyright (c) VCU, Nanjing University.
4 | # Licensed under the Apache License 2.0 [see LICENSE for details]
5 | # Written by Qing-Long Zhang
6 | # ------------------------------------------------------------
7 |
8 | """
9 | Detection Training Script.
10 | This script reads a given config file and runs training or evaluation.
11 | It is an entry point that is made to train standard models in detectron2.
12 | In order to let one script support training of many models,
13 | it contains logic that is specific to these built-in models and therefore
14 | may not be suitable for your own project.
15 | For example, your research project perhaps only needs a single "evaluator".
16 | Therefore, we recommend you use detectron2 as a library and take
17 | this file as an example of how to use the library.
18 | You may want to write your own script with your datasets and other customizations.
19 | """
20 | import copy
21 | from typing import Any, Dict, List, Set
22 | import itertools
23 | import logging
24 | import os
25 | from collections import OrderedDict
26 | import torch
27 |
28 | import detectron2.utils.comm as comm
29 | from detectron2.checkpoint import DetectionCheckpointer
30 | from detectron2.config import get_cfg
31 | from detectron2.data import MetadataCatalog
32 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
33 | from detectron2.evaluation import (
34 | CityscapesInstanceEvaluator,
35 | CityscapesSemSegEvaluator,
36 | COCOEvaluator,
37 | COCOPanopticEvaluator,
38 | DatasetEvaluators,
39 | LVISEvaluator,
40 | PascalVOCDetectionEvaluator,
41 | SemSegEvaluator,
42 | verify_results,
43 | )
44 |
45 | from detectron2.modeling import GeneralizedRCNNWithTTA
46 | from detectron2.solver.build import maybe_add_gradient_clipping, get_default_optimizer_params
47 | from restv2 import add_restv2_config, add_restv1_config
48 |
49 |
50 | def build_evaluator(cfg, dataset_name, output_folder=None):
51 | """
52 | Create evaluator(s) for a given dataset.
53 | This uses the special metadata "evaluator_type" associated with each builtin dataset.
54 | For your own dataset, you can simply create an evaluator manually in your
55 | script and do not have to worry about the hacky if-else logic here.
56 | """
57 | if output_folder is None:
58 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
59 | evaluator_list = []
60 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
61 | if evaluator_type in ["sem_seg", "coco_panoptic_seg"]:
62 | evaluator_list.append(
63 | SemSegEvaluator(
64 | dataset_name,
65 | distributed=True,
66 | output_dir=output_folder,
67 | )
68 | )
69 | if evaluator_type in ["coco", "coco_panoptic_seg"]:
70 | evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
71 | if evaluator_type == "coco_panoptic_seg":
72 | evaluator_list.append(COCOPanopticEvaluator(dataset_name, output_folder))
73 | if evaluator_type == "cityscapes_instance":
74 | return CityscapesInstanceEvaluator(dataset_name)
75 | if evaluator_type == "cityscapes_sem_seg":
76 | return CityscapesSemSegEvaluator(dataset_name)
77 | elif evaluator_type == "pascal_voc":
78 | return PascalVOCDetectionEvaluator(dataset_name)
79 | elif evaluator_type == "lvis":
80 | return LVISEvaluator(dataset_name, output_dir=output_folder)
81 | if len(evaluator_list) == 0:
82 | raise NotImplementedError(
83 | "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type)
84 | )
85 | elif len(evaluator_list) == 1:
86 | return evaluator_list[0]
87 | return DatasetEvaluators(evaluator_list)
88 |
89 |
90 | class Trainer(DefaultTrainer):
91 | """
92 | We use the "DefaultTrainer" which contains pre-defined default logic for
93 | standard training workflow. They may not work for you, especially if you
94 | are working on a new research project. In that case you can write your
95 | own training loop. You can use "tools/plain_train_net.py" as an example.
96 | """
97 |
98 | @classmethod
99 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
100 | return build_evaluator(cfg, dataset_name, output_folder)
101 |
102 | @classmethod
103 | def test_with_TTA(cls, cfg, model):
104 | logger = logging.getLogger("detectron2.trainer")
105 |         # At the end of training, run an evaluation with TTA.
106 |         # Only supports some R-CNN models.
107 | logger.info("Running inference with test-time augmentation ...")
108 | model = GeneralizedRCNNWithTTA(cfg, model)
109 | evaluators = [
110 | cls.build_evaluator(
111 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
112 | )
113 | for name in cfg.DATASETS.TEST
114 | ]
115 | res = cls.test(cfg, model, evaluators)
116 | res = OrderedDict({k + "_TTA": v for k, v in res.items()})
117 | return res
118 |
119 | @classmethod
120 | def build_optimizer(cls, cfg, model):
121 | params = get_default_optimizer_params(
122 | model,
123 | base_lr=cfg.SOLVER.BASE_LR,
124 | weight_decay=cfg.SOLVER.WEIGHT_DECAY,
125 | weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
126 | bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
127 | weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
128 | )
129 |
130 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
131 | # detectron2 doesn't have full model gradient clipping now
132 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
133 | enable = (
134 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED
135 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
136 | and clip_norm_val > 0.0
137 | )
138 |
139 | class FullModelGradientClippingOptimizer(optim):
140 | def step(self, closure=None):
141 | all_params = itertools.chain(*[x["params"] for x in self.param_groups])
142 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
143 | super().step(closure=closure)
144 |
145 | return FullModelGradientClippingOptimizer if enable else optim
146 |
147 | optimizer_type = cfg.SOLVER.OPTIMIZER
148 | if optimizer_type == "SGD":
149 | optimizer = maybe_add_gradient_clipping(torch.optim.SGD)(
150 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
151 | nesterov=cfg.SOLVER.NESTEROV,
152 | weight_decay=cfg.SOLVER.WEIGHT_DECAY,
153 | )
154 | elif optimizer_type == "AdamW":
155 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
156 | params, cfg.SOLVER.BASE_LR, betas=(0.9, 0.999),
157 | weight_decay=cfg.SOLVER.WEIGHT_DECAY,
158 | )
159 | else:
160 | raise NotImplementedError(f"no optimizer type {optimizer_type}")
161 | return optimizer
162 |
163 |
164 | def setup(args):
165 | """
166 | Create configs and perform basic setups.
167 | """
168 | cfg = get_cfg()
169 | add_restv2_config(cfg)
170 | add_restv1_config(cfg)
171 | cfg.merge_from_file(args.config_file)
172 | cfg.merge_from_list(args.opts)
173 | cfg.freeze()
174 | default_setup(cfg, args)
175 |
176 | return cfg
177 |
178 |
179 | def main(args):
180 | cfg = setup(args)
181 |
182 | if args.eval_only:
183 | model = Trainer.build_model(cfg)
184 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
185 | cfg.MODEL.WEIGHTS, resume=args.resume
186 | )
187 | res = Trainer.test(cfg, model)
188 | if cfg.TEST.AUG.ENABLED:
189 | res.update(Trainer.test_with_TTA(cfg, model))
190 | if comm.is_main_process():
191 | verify_results(cfg, res)
192 | return res
193 |
194 | """
195 | If you'd like to do anything fancier than the standard training logic,
196 | consider writing your own training loop (see plain_train_net.py) or
197 | subclassing the trainer.
198 | """
199 | trainer = Trainer(cfg)
200 | trainer.resume_or_load(resume=args.resume)
201 | if cfg.TEST.AUG.ENABLED:
202 | trainer.register_hooks(
203 | [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
204 | )
205 | return trainer.train()
206 |
207 |
208 | if __name__ == "__main__":
209 | args = default_argument_parser().parse_args()
210 | print("Command Line Args:", args)
211 | launch(
212 | main,
213 | args.num_gpus,
214 | num_machines=args.num_machines,
215 | machine_rank=args.machine_rank,
216 | dist_url=args.dist_url,
217 | args=(args,),
218 | )
219 |
--------------------------------------------------------------------------------
/object_detection/utils.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.distributed as dist
13 | import torchvision
14 | from torch import Tensor
15 |
16 |
17 | def _max_by_axis(the_list):
18 | # type: (List[List[int]]) -> List[int]
19 | maxes = the_list[0]
20 | for sublist in the_list[1:]:
21 | for index, item in enumerate(sublist):
22 | maxes[index] = max(maxes[index], item)
23 | return maxes
24 |
25 |
26 | class NestedTensor(object):
27 | def __init__(self, tensors, mask: Optional[Tensor]):
28 | self.tensors = tensors
29 | self.mask = mask
30 |
31 | def to(self, device):
32 | # type: (Device) -> NestedTensor # noqa
33 | cast_tensor = self.tensors.to(device)
34 | mask = self.mask
35 | if mask is not None:
36 | assert mask is not None
37 | cast_mask = mask.to(device)
38 | else:
39 | cast_mask = None
40 | return NestedTensor(cast_tensor, cast_mask)
41 |
42 | def decompose(self):
43 | return self.tensors, self.mask
44 |
45 | def __repr__(self):
46 | return str(self.tensors)
47 |
48 |
49 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
50 | # TODO make this more general
51 | if tensor_list[0].ndim == 3:
52 | if torchvision._is_tracing():
53 | # nested_tensor_from_tensor_list() does not export well to ONNX
54 | # call _onnx_nested_tensor_from_tensor_list() instead
55 | return _onnx_nested_tensor_from_tensor_list(tensor_list)
56 |
57 | # TODO make it support different-sized images
58 | max_size = _max_by_axis([list(img.shape) for img in tensor_list])
59 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
60 | batch_shape = [len(tensor_list)] + max_size
61 | b, c, h, w = batch_shape
62 | dtype = tensor_list[0].dtype
63 | device = tensor_list[0].device
64 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
65 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
66 | for img, pad_img, m in zip(tensor_list, tensor, mask):
67 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
68 | m[: img.shape[1], : img.shape[2]] = False
69 | else:
70 | raise ValueError("not supported")
71 | return NestedTensor(tensor, mask)
72 |
73 |
74 | # _onnx_nested_tensor_from_tensor_list() is an implementation of
75 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
76 | @torch.jit.unused
77 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
78 | max_size = []
79 | for i in range(tensor_list[0].dim()):
80 | max_size_i = torch.max(
81 | torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
82 | ).to(torch.int64)
83 | max_size.append(max_size_i)
84 | max_size = tuple(max_size)
85 |
86 | # work around for
87 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
88 | # m[: img.shape[1], :img.shape[2]] = False
89 | # which is not yet supported in onnx
90 | padded_imgs = []
91 | padded_masks = []
92 | for img in tensor_list:
93 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
94 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
95 | padded_imgs.append(padded_img)
96 |
97 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
98 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
99 | padded_masks.append(padded_mask.to(torch.bool))
100 |
101 | tensor = torch.stack(padded_imgs)
102 | mask = torch.stack(padded_masks)
103 |
104 | return NestedTensor(tensor, mask=mask)
105 |
106 |
107 | def is_dist_avail_and_initialized():
108 | if not dist.is_available():
109 | return False
110 | if not dist.is_initialized():
111 | return False
112 | return True
113 |
114 |
115 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
116 | missing_keys = []
117 | unexpected_keys = []
118 | error_msgs = []
119 | # copy state_dict so _load_from_state_dict can modify it
120 | metadata = getattr(state_dict, '_metadata', None)
121 | state_dict = state_dict.copy()
122 | if metadata is not None:
123 | state_dict._metadata = metadata
124 |
125 | def load(module, prefix=''):
126 | local_metadata = {} if metadata is None else metadata.get(
127 | prefix[:-1], {})
128 | module._load_from_state_dict(
129 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
130 | for name, child in module._modules.items():
131 | if child is not None:
132 | load(child, prefix + name + '.')
133 |
134 | load(model, prefix=prefix)
135 |
136 | warn_missing_keys = []
137 | ignore_missing_keys = []
138 | for key in missing_keys:
139 | keep_flag = True
140 | for ignore_key in ignore_missing.split('|'):
141 | if ignore_key in key:
142 | keep_flag = False
143 | break
144 | if keep_flag:
145 | warn_missing_keys.append(key)
146 | else:
147 | ignore_missing_keys.append(key)
148 |
149 | missing_keys = warn_missing_keys
150 |
151 | if len(missing_keys) > 0:
152 | print("Weights of {} not initialized from pretrained model: {}".format(
153 | model.__class__.__name__, missing_keys))
154 | if len(unexpected_keys) > 0:
155 | print("Weights from pretrained model not used in {}: {}".format(
156 | model.__class__.__name__, unexpected_keys))
157 | if len(ignore_missing_keys) > 0:
158 | print("Ignored weights of {} not initialized from pretrained model: {}".format(
159 | model.__class__.__name__, ignore_missing_keys))
160 | if len(error_msgs) > 0:
161 | print('\n'.join(error_msgs))
162 |
163 |
164 | class LayerNorm(nn.Module):
165 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
166 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
167 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
168 | with shape (batch_size, channels, height, width).
169 | """
170 |
171 | def __init__(self, dim, eps=1e-6, data_format="channels_last"):
172 | super().__init__()
173 | self.weight = nn.Parameter(torch.ones(dim))
174 | self.bias = nn.Parameter(torch.zeros(dim))
175 | self.eps = eps
176 | self.data_format = data_format
177 | if self.data_format not in ["channels_last", "channels_first"]:
178 | raise NotImplementedError
179 | self.dim = (dim,)
180 |
181 | def forward(self, x):
182 | if self.data_format == "channels_last":
183 | return F.layer_norm(x, self.dim, self.weight, self.bias, self.eps)
184 | elif self.data_format == "channels_first":
185 | u = x.mean(1, keepdim=True)
186 | s = (x - u).pow(2).mean(1, keepdim=True)
187 | x = (x - u) / torch.sqrt(s + self.eps)
188 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
189 | return x
190 |
--------------------------------------------------------------------------------
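As a sanity check on the `LayerNorm` above, the `channels_first` branch should match standard layer normalization applied over the channel axis after a `BCHW -> BHWC` permute. A small sketch (not part of the repo; it assumes it is run from `object_detection/` so that `from utils import LayerNorm` resolves the same way it does in `restv2/restv2.py`):

```python
# Numerical check: channels_first LayerNorm == permute + F.layer_norm + permute back.
import torch
import torch.nn.functional as F

from utils import LayerNorm  # the class defined above

dim = 8
x = torch.randn(2, dim, 5, 7)                       # BCHW feature map
ln = LayerNorm(dim, eps=1e-6, data_format="channels_first")

ref = F.layer_norm(x.permute(0, 2, 3, 1), (dim,), ln.weight, ln.bias, ln.eps)
ref = ref.permute(0, 3, 1, 2)                       # back to BCHW

assert torch.allclose(ln(x), ref, atol=1e-5)
```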
/optim_factory.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import torch
8 | from torch import optim as optim
9 |
10 | from timm.optim.adafactor import Adafactor
11 | from timm.optim.adahessian import Adahessian
12 | from timm.optim.adamp import AdamP
13 | from timm.optim.lookahead import Lookahead
14 | from timm.optim.nadam import Nadam
15 | from timm.optim.nvnovograd import NvNovoGrad
16 | from timm.optim.radam import RAdam
17 | from timm.optim.rmsprop_tf import RMSpropTF
18 | from timm.optim.sgdp import SGDP
19 | from timm.optim.lamb import Lamb
20 | from timm.optim.lars import Lars
21 |
22 | import json
23 |
24 | try:
25 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
26 |
27 | has_apex = True
28 | except ImportError:
29 | has_apex = False
30 |
31 |
32 | def get_layer_id_for_rest(model, name):
33 | """
34 |     Assign each parameter a layer id,
35 |     following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
36 | """
37 | depths = model.depths
38 | if name in ['cls_token', 'pos_embed']:
39 | return 0
40 | elif name.startswith('stem'):
41 | return 0
42 | elif name.startswith('patch'):
43 | layer_id = int(name.split('.')[0].split('_')[1])
44 | if layer_id == 2:
45 | layer_id = depths[0] + 1
46 | elif layer_id == 3:
47 | layer_id = sum(depths[:2]) + 2
48 | else:
49 | layer_id = sum(depths[:3]) + 3
50 | return int(layer_id)
51 | elif name.startswith('stage'):
52 | stage_id = int(name.split('.')[0][-1])
53 | layer_id = int(name.split('.')[1])
54 | if stage_id == 1:
55 | layer_id = layer_id + 1
56 | elif stage_id == 2:
57 | layer_id = depths[0] + layer_id + 2
58 | elif stage_id == 3:
59 | layer_id = sum(depths[:2]) + layer_id + 3
60 | else:
61 | layer_id = sum(depths[:3]) + layer_id + 4
62 | return int(layer_id)
63 | else:
64 | return sum(depths) + 4
65 |
66 |
67 | class LayerDecayValueAssigner(object):
68 | def __init__(self, values):
69 | self.values = values
70 |
71 | def get_scale(self, layer_id):
72 | return self.values[layer_id]
73 |
74 | def get_layer_id(self, model, var_name):
75 | return get_layer_id_for_rest(model, var_name)
76 |
77 |
78 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
79 | parameter_group_names = {}
80 | parameter_group_vars = {}
81 |
82 | for name, param in model.named_parameters():
83 | if not param.requires_grad:
84 | continue # frozen weights
85 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
86 | group_name = "no_decay"
87 | this_weight_decay = 0.
88 | else:
89 | group_name = "decay"
90 | this_weight_decay = weight_decay
91 | if get_num_layer is not None:
92 | layer_id = get_num_layer(model, name)
93 | print("layer_id: {}", layer_id)
94 | group_name = "layer_%d_%s" % (layer_id, group_name)
95 | else:
96 | layer_id = None
97 |
98 | if group_name not in parameter_group_names:
99 | if get_layer_scale is not None:
100 | scale = get_layer_scale(layer_id)
101 | else:
102 | scale = 1.
103 |
104 | parameter_group_names[group_name] = {
105 | "weight_decay": this_weight_decay,
106 | "params": [],
107 | "lr_scale": scale
108 | }
109 | parameter_group_vars[group_name] = {
110 | "weight_decay": this_weight_decay,
111 | "params": [],
112 | "lr_scale": scale
113 | }
114 |
115 | parameter_group_vars[group_name]["params"].append(param)
116 | parameter_group_names[group_name]["params"].append(name)
117 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
118 | return list(parameter_group_vars.values())
119 |
120 |
121 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
122 | opt_lower = args.opt.lower()
123 | weight_decay = args.weight_decay
124 | # if weight_decay and filter_bias_and_bn:
125 | if filter_bias_and_bn:
126 | skip = {}
127 | if skip_list is not None:
128 | skip = skip_list
129 | elif hasattr(model, 'no_weight_decay'):
130 | skip = model.no_weight_decay()
131 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
132 | weight_decay = 0.
133 | else:
134 | parameters = model.parameters()
135 |
136 | if 'fused' in opt_lower:
137 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
138 |
139 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
140 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
141 | opt_args['eps'] = args.opt_eps
142 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
143 | opt_args['betas'] = args.opt_betas
144 |
145 | opt_split = opt_lower.split('_')
146 | opt_lower = opt_split[-1]
147 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
148 | opt_args.pop('eps', None)
149 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
150 | elif opt_lower == 'momentum':
151 | opt_args.pop('eps', None)
152 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
153 | elif opt_lower == 'adam':
154 | optimizer = optim.Adam(parameters, **opt_args)
155 | elif opt_lower == 'adamw':
156 | optimizer = optim.AdamW(parameters, **opt_args)
157 | elif opt_lower == 'nadam':
158 | optimizer = Nadam(parameters, **opt_args)
159 | elif opt_lower == 'radam':
160 | optimizer = RAdam(parameters, **opt_args)
161 | elif opt_lower == 'adamp':
162 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
163 | elif opt_lower == 'sgdp':
164 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
165 | elif opt_lower == 'adadelta':
166 | optimizer = optim.Adadelta(parameters, **opt_args)
167 | elif opt_lower == 'adafactor':
168 | if not args.lr:
169 | opt_args['lr'] = None
170 | optimizer = Adafactor(parameters, **opt_args)
171 | elif opt_lower == 'lamb':
172 | optimizer = Lamb(parameters, **opt_args)
173 | elif opt_lower == 'lambc':
174 | optimizer = Lamb(parameters, trust_clip=True, **opt_args)
175 | elif opt_lower == 'lars':
176 | optimizer = Lars(parameters, momentum=args.momentum, **opt_args)
177 | elif opt_lower == 'adahessian':
178 | optimizer = Adahessian(parameters, **opt_args)
179 | elif opt_lower == 'rmsprop':
180 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
181 | elif opt_lower == 'rmsproptf':
182 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
183 | elif opt_lower == 'nvnovograd':
184 | optimizer = NvNovoGrad(parameters, **opt_args)
185 | elif opt_lower == 'fusedsgd':
186 | opt_args.pop('eps', None)
187 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
188 | elif opt_lower == 'fusedmomentum':
189 | opt_args.pop('eps', None)
190 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
191 | elif opt_lower == 'fusedadam':
192 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
193 | elif opt_lower == 'fusedadamw':
194 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
195 | elif opt_lower == 'fusedlamb':
196 | optimizer = FusedLAMB(parameters, **opt_args)
197 | elif opt_lower == 'fusednovograd':
198 | opt_args.setdefault('betas', (0.95, 0.98))
199 | optimizer = FusedNovoGrad(parameters, **opt_args)
200 | else:
201 |         assert False, "Invalid optimizer"
202 |
203 | if len(opt_split) > 1:
204 | if opt_split[0] == 'lookahead':
205 | optimizer = Lookahead(optimizer)
206 |
207 | return optimizer
208 |
--------------------------------------------------------------------------------
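A sketch of how `LayerDecayValueAssigner` and `create_optimizer` are wired together for layer-wise learning-rate decay. The decay rate, optimizer settings, and the `model` variable below are placeholders rather than values taken from this repo; the model only needs `named_parameters()` and a `depths` attribute, as the backbones in this repo provide:

```python
# Illustrative layer-wise lr decay setup (hypothetical hyper-parameters).
from types import SimpleNamespace

from optim_factory import LayerDecayValueAssigner, create_optimizer

# `model`: any backbone from this repo with a `depths` attribute (placeholder, not defined here).
depths = model.depths                  # e.g. [1, 2, 6, 2] for restv2_tiny
num_ids = sum(depths) + 5              # get_layer_id_for_rest returns ids in [0, sum(depths) + 4]
layer_decay = 0.9
values = [layer_decay ** (num_ids - 1 - i) for i in range(num_ids)]  # later layers get scales closer to 1
assigner = LayerDecayValueAssigner(values)

args = SimpleNamespace(opt="adamw", lr=1e-3, weight_decay=0.05,
                       opt_eps=1e-8, opt_betas=(0.9, 0.999), momentum=0.9)
optimizer = create_optimizer(args, model,
                             get_num_layer=assigner.get_layer_id,
                             get_layer_scale=assigner.get_scale)
```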
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.8.0
2 | torchvision>=0.9.0
3 | timm==0.5.4
4 | tensorboardX
5 | six
6 |
--------------------------------------------------------------------------------
/run_with_submitit.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import uuid
10 | from pathlib import Path
11 |
12 | import main as classification
13 | import submitit
14 |
15 |
16 | def parse_args():
17 | classification_parser = classification.get_args_parser()
18 | parser = argparse.ArgumentParser("Submitit for ResTv2", parents=[classification_parser])
19 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
20 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
21 | parser.add_argument("--timeout", default=72, type=int, help="Duration of the job, in hours")
22 | parser.add_argument("--job_name", default="restv2", type=str, help="Job name")
23 | parser.add_argument("--job_dir", default="", type=str, help="Job directory; leave empty for default")
24 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit")
25 | parser.add_argument("--use_volta32", action='store_true', default=True, help="Big models? Use this")
26 | parser.add_argument('--comment', default="", type=str,
27 | help='Comment to pass to scheduler, e.g. priority message')
28 | return parser.parse_args()
29 |
30 |
31 | def get_shared_folder() -> Path:
32 | user = os.getenv("USER")
33 | if Path("/checkpoint/").is_dir():
34 | p = Path(f"/checkpoint/{user}/restv2")
35 | p.mkdir(exist_ok=True)
36 | return p
37 | raise RuntimeError("No shared folder available")
38 |
39 |
40 | def get_init_file():
41 |     # Init file must not exist, but its parent dir must exist.
42 | os.makedirs(str(get_shared_folder()), exist_ok=True)
43 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
44 | if init_file.exists():
45 | os.remove(str(init_file))
46 | return init_file
47 |
48 |
49 | class Trainer(object):
50 | def __init__(self, args):
51 | self.args = args
52 |
53 | def __call__(self):
54 | import main as classification
55 |
56 | self._setup_gpu_args()
57 | classification.main(self.args)
58 |
59 | def checkpoint(self):
60 | import os
61 | import submitit
62 |
63 | self.args.dist_url = get_init_file().as_uri()
64 | self.args.auto_resume = True
65 | print("Requeuing ", self.args)
66 | empty_trainer = type(self)(self.args)
67 | return submitit.helpers.DelayedSubmission(empty_trainer)
68 |
69 | def _setup_gpu_args(self):
70 | import submitit
71 | from pathlib import Path
72 |
73 | job_env = submitit.JobEnvironment()
74 | self.args.output_dir = Path(self.args.job_dir)
75 | self.args.gpu = job_env.local_rank
76 | self.args.rank = job_env.global_rank
77 | self.args.world_size = job_env.num_tasks
78 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
79 |
80 |
81 | def main():
82 | args = parse_args()
83 |
84 | if args.job_dir == "":
85 | args.job_dir = get_shared_folder() / "%j"
86 |
87 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
88 |
89 | num_gpus_per_node = args.ngpus
90 | nodes = args.nodes
91 | timeout_min = args.timeout * 60
92 |
93 | partition = args.partition
94 | kwargs = {}
95 | if args.use_volta32:
96 | kwargs['slurm_constraint'] = 'volta32gb'
97 | if args.comment:
98 | kwargs['slurm_comment'] = args.comment
99 |
100 | executor.update_parameters(
101 | mem_gb=40 * num_gpus_per_node,
102 | gpus_per_node=num_gpus_per_node,
103 | tasks_per_node=num_gpus_per_node, # one task per GPU
104 | cpus_per_task=10,
105 | nodes=nodes,
106 | timeout_min=timeout_min, # max is 60 * 72
107 | # Below are cluster dependent parameters
108 | slurm_partition=partition,
109 | slurm_signal_delay_s=120,
110 | **kwargs
111 | )
112 |
113 | executor.update_parameters(name=args.job_name)
114 |
115 | args.dist_url = get_init_file().as_uri()
116 | args.output_dir = args.job_dir
117 |
118 | trainer = Trainer(args)
119 | job = executor.submit(trainer)
120 |
121 | print("Submitted job_id:", job.job_id)
122 |
123 |
124 | if __name__ == "__main__":
125 | main()
126 |
--------------------------------------------------------------------------------
/semantic_segmentation/README.md:
--------------------------------------------------------------------------------
1 | # ADE20k Semantic segmentation with ResTv2
2 |
3 | ## Getting started
4 |
5 | We add the ResTv2 model and config files to the `semantic_segmentation` directory.
6 |
7 | ## Results and Fine-tuned Models
8 |
9 | | name | Pretrained Model | Method | Crop Size | Lr Schd | mIoU (ms+flip) | #params | FLOPs | FPS | Fine-tuned Model |
10 | |:---:|:---:|:---:|:---:| :---:|:---:|:---:|:---:| :---:| :---:|
11 | | ResTv2-T | ImageNet-1K | UPerNet | 512x512 | 160K | 47.3 | 62.1M | 977G | 22.4 | [baidu](https://pan.baidu.com/s/1X-hAafTLFnwJPQSI2BNOKw) |
12 | | ResTv2-S | ImageNet-1K | UPerNet | 512x512 | 160K | 49.2 | 72.9M | 1035G | 20.0 | [baidu](https://pan.baidu.com/s/1WHiL0Rf9JeOB76yh6WOvLQ) |
13 | | ResTv2-B | ImageNet-1K | UPerNet | 512x512 | 160K | 49.6 | 87.6M | 1095G | 19.2 | [baidu](https://pan.baidu.com/s/1dtkg68j3vCU-dxJxl8VFdg) |
14 |
15 | ## Training
16 |
17 | Command format:
18 | ```
19 | tools/dist_train.sh <CONFIG_PATH> <NUM_GPUS> --work-dir <SAVE_PATH> --seed 0 --deterministic --options model.pretrained=<PRETRAIN_MODEL>
20 | ```
21 |
22 | For example, to train a `ResTv2-T` backbone with UPerNet:
23 | ```bash
24 | bash tools/dist_train.sh \
25 | configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py 8 \
26 | --work-dir /path/to/save --seed 0 --deterministic \
27 | --options model.pretrained=ResTv2_tiny_224.pth
28 | ```
29 |
30 | More config files can be found at [`configs/ResTv2`](configs/ResTv2).
31 |
32 |
33 | ## Evaluation
34 |
35 | Command format:
36 | ```
37 | tools/dist_test.sh <CONFIG_PATH> <CHECKPOINT_PATH> <NUM_GPUS> --eval mIoU --aug-test
38 | ```
39 |
40 | For example, to evaluate a `ResTv2-T` backbone with UPerNet:
41 | ```bash
42 | bash tools/dist_test.sh configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py \
43 | upernet_restv2_tiny_512_160k_ade20k.pth 4 --eval mIoU --aug-test
44 | ```
45 |
46 | ## Acknowledgment
47 |
48 | This code is built using the [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) library and the [Timm](https://github.com/rwightman/pytorch-image-models) library.
--------------------------------------------------------------------------------
/semantic_segmentation/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .rest_v2 import ResTV2
2 |
3 | __all__ = ['ResTV2']
4 |
--------------------------------------------------------------------------------
/semantic_segmentation/backbone/rest_v2.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import warnings
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
13 |
14 | from mmcv.runner import BaseModule, _load_checkpoint
15 | from mmseg.utils import get_root_logger
16 | from mmseg.models.builder import BACKBONES
17 |
18 |
19 | class Mlp(nn.Module):
20 | def __init__(self, dim):
21 | super().__init__()
22 | self.fc1 = nn.Linear(dim, 4 * dim)
23 | self.act = nn.GELU()
24 | self.fc2 = nn.Linear(4 * dim, dim)
25 |
26 | def forward(self, x):
27 | x = self.fc1(x)
28 | x = self.act(x)
29 | x = self.fc2(x)
30 | return x
31 |
32 |
33 | class Attention(nn.Module):
34 | def __init__(self,
35 | dim,
36 | num_heads=8,
37 | sr_ratio=1):
38 | super().__init__()
39 | self.dim = dim
40 | self.num_heads = num_heads
41 | head_dim = dim // num_heads
42 | self.scale = head_dim ** -0.5
43 |
44 | self.q = nn.Linear(dim, dim, bias=True)
45 | self.kv = nn.Linear(dim, dim * 2, bias=True)
46 |
47 | self.sr_ratio = sr_ratio
48 | if sr_ratio > 1:
49 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio + 1, stride=sr_ratio, padding=sr_ratio // 2, groups=dim)
50 | self.sr_norm = nn.LayerNorm(dim, eps=1e-6)
51 |
52 | self.up = nn.Sequential(
53 | nn.Conv2d(dim, sr_ratio * sr_ratio * dim, kernel_size=3, stride=1, padding=1, groups=dim),
54 | nn.PixelShuffle(upscale_factor=sr_ratio)
55 | )
56 | self.up_norm = nn.LayerNorm(dim, eps=1e-6)
57 |
58 | self.proj = nn.Linear(dim, dim)
59 |
60 | def forward(self, x, H, W):
61 | B, N, C = x.shape
62 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
63 | if self.sr_ratio > 1:
64 | x = x.permute(0, 2, 1).reshape(B, C, H, W)
65 | x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
66 | x = self.sr_norm(x)
67 |
68 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
69 | k, v = kv[0], kv[1]
70 | attn = (q @ k.transpose(-2, -1)) * self.scale
71 | attn = attn.softmax(dim=-1)
72 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
73 |
74 | identity = v.transpose(-1, -2).reshape(B, C, H // self.sr_ratio, W // self.sr_ratio)
75 | identity = self.up(identity).flatten(2).transpose(1, 2)
76 | x = self.proj(x + self.up_norm(identity))
77 | return x
78 |
79 |
80 | class Block(nn.Module):
81 | def __init__(self, dim, num_heads, sr_ratio=1, drop_path=0.):
82 | super().__init__()
83 | self.norm1 = nn.LayerNorm(dim, eps=1e-6)
84 | self.attn = Attention(dim, num_heads, sr_ratio)
85 |
86 | self.norm2 = nn.LayerNorm(dim, eps=1e-6)
87 | self.mlp = Mlp(dim)
88 |
89 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
90 |
91 | def forward(self, x, H, W):
92 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) # pre_norm
93 | x = x + self.drop_path(self.mlp(self.norm2(x)))
94 | return x
95 |
96 |
97 | class PA(nn.Module):
98 | def __init__(self, dim):
99 | super().__init__()
100 | self.pa_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=True)
101 | self.sigmoid = nn.Sigmoid()
102 |
103 | def forward(self, x):
104 | return x * self.sigmoid(self.pa_conv(x))
105 |
106 |
107 | class ConvStem(nn.Module):
108 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True):
109 | super().__init__()
110 | self.patch_size = to_2tuple(patch_size)
111 | stem = []
112 | in_dim, out_dim = in_ch, out_ch // 2
113 | for i in range(2):
114 | stem.append(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=2, padding=1, bias=False))
115 | stem.append(nn.BatchNorm2d(out_dim))
116 | stem.append(nn.ReLU(inplace=True))
117 | in_dim, out_dim = out_dim, out_dim * 2
118 |
119 | stem.append(nn.Conv2d(in_dim, out_ch, kernel_size=1, stride=1))
120 | self.proj = nn.Sequential(*stem)
121 |
122 | self.with_pos = with_pos
123 | if self.with_pos:
124 | self.pos = PA(out_ch)
125 |
126 | self.norm = nn.LayerNorm(out_ch, eps=1e-6)
127 |
128 | def forward(self, x):
129 | B, C, H, W = x.shape
130 | x = self.proj(x)
131 | if self.with_pos:
132 | x = self.pos(x)
133 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
134 | x = self.norm(x)
135 | H, W = H // self.patch_size[0], W // self.patch_size[1]
136 | return x, (H, W)
137 |
138 |
139 | class PatchEmbed(nn.Module):
140 | def __init__(self, in_ch=3, out_ch=96, patch_size=2, with_pos=True):
141 | super().__init__()
142 | self.patch_size = to_2tuple(patch_size)
143 | self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=patch_size + 1, stride=patch_size, padding=patch_size // 2)
144 |
145 | self.with_pos = with_pos
146 | if self.with_pos:
147 | self.pos = PA(out_ch)
148 |
149 | self.norm = nn.LayerNorm(out_ch, eps=1e-6)
150 |
151 | def forward(self, x):
152 | B, C, H, W = x.shape
153 | x = self.proj(x)
154 | if self.with_pos:
155 | x = self.pos(x)
156 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
157 | x = self.norm(x)
158 | H, W = H // self.patch_size[0], W // self.patch_size[1]
159 | return x, (H, W)
160 |
161 |
162 | @BACKBONES.register_module()
163 | class ResTV2(BaseModule):
164 | def __init__(self, in_chans=3, embed_dims=[96, 192, 384, 768],
165 | num_heads=[1, 2, 4, 8], drop_path_rate=0.,
166 | depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1], out_indices=(0, 1, 2, 3),
167 | pretrained=None, init_cfg=None):
168 | super().__init__()
169 | if isinstance(pretrained, str) or pretrained is None:
170 |             warnings.warn('DeprecationWarning: pretrained is deprecated, please use "init_cfg" instead')
171 | else:
172 | raise TypeError('pretrained must be a str or None')
173 | self.pretrained = pretrained
174 | self.init_cfg = init_cfg
175 |
176 | self.depths = depths
177 | self.out_indices = out_indices
178 |
179 | self.stem = ConvStem(in_chans, embed_dims[0], patch_size=4)
180 | self.patch_2 = PatchEmbed(embed_dims[0], embed_dims[1], patch_size=2)
181 | self.patch_3 = PatchEmbed(embed_dims[1], embed_dims[2], patch_size=2)
182 | self.patch_4 = PatchEmbed(embed_dims[2], embed_dims[3], patch_size=2)
183 |
184 | # transformer encoder
185 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
186 | cur = 0
187 | self.stage1 = nn.ModuleList([
188 | Block(embed_dims[0], num_heads[0], sr_ratios[0], dpr[cur + i])
189 | for i in range(depths[0])
190 | ])
191 |
192 | cur += depths[0]
193 | self.stage2 = nn.ModuleList([
194 | Block(embed_dims[1], num_heads[1], sr_ratios[1], dpr[cur + i])
195 | for i in range(depths[1])
196 | ])
197 |
198 | cur += depths[1]
199 | self.stage3 = nn.ModuleList([
200 | Block(embed_dims[2], num_heads[2], sr_ratios[2], dpr[cur + i])
201 | for i in range(depths[2])
202 | ])
203 |
204 | cur += depths[2]
205 | self.stage4 = nn.ModuleList([
206 | Block(embed_dims[3], num_heads[3], sr_ratios[3], dpr[cur + i])
207 | for i in range(depths[3])
208 | ])
209 |
210 | for idx in out_indices:
211 | out_ch = embed_dims[idx]
212 | layer = LayerNorm(out_ch, eps=1e-6, data_format="channels_first")
213 | layer_name = f"norm_{idx + 1}"
214 | self.add_module(layer_name, layer)
215 |
216 | def init_weights(self):
217 | if self.pretrained is None:
218 | super().init_weights()
219 | for m in self.modules():
220 | if isinstance(m, (nn.Conv2d, nn.Linear)):
221 | trunc_normal_(m.weight, std=0.02)
222 | if m.bias is not None:
223 | nn.init.constant_(m.bias, 0)
224 |
225 | elif isinstance(self.pretrained, str):
226 | logger = get_root_logger()
227 | ckpt = _load_checkpoint(self.pretrained, logger=logger, map_location='cpu')
228 | if 'state_dict' in ckpt:
229 | state_dict = ckpt['state_dict']
230 | elif 'model' in ckpt:
231 | state_dict = ckpt['model']
232 | else:
233 | state_dict = ckpt
234 |
235 | # strip prefix of state_dict
236 | if list(state_dict.keys())[0].startswith('module.'):
237 | state_dict = {k[7:]: v for k, v in state_dict.items()}
238 |             self.load_state_dict(state_dict, strict=False)
239 |
240 | def forward(self, x):
241 | outs = []
242 | B, _, H, W = x.shape
243 | x, (H, W) = self.stem(x)
244 | # stage 1
245 | for blk in self.stage1:
246 | x = blk(x, H, W)
247 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
248 | if 0 in self.out_indices:
249 | outs.append(self.norm_1(x))
250 |
251 | # stage 2
252 | x, (H, W) = self.patch_2(x)
253 | for blk in self.stage2:
254 | x = blk(x, H, W)
255 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
256 | if 1 in self.out_indices:
257 | outs.append(self.norm_2(x))
258 |
259 | # stage 3
260 | x, (H, W) = self.patch_3(x)
261 | for blk in self.stage3:
262 | x = blk(x, H, W)
263 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
264 | if 2 in self.out_indices:
265 | outs.append(self.norm_3(x))
266 |
267 | # stage 4
268 | x, (H, W) = self.patch_4(x)
269 | for blk in self.stage4:
270 | x = blk(x, H, W)
271 | x = x.permute(0, 2, 1).reshape(B, -1, H, W)
272 | if 3 in self.out_indices:
273 | outs.append(self.norm_4(x))
274 |
275 | return tuple(outs)
276 |
277 |
278 | class LayerNorm(nn.Module):
279 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
280 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
281 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
282 | with shape (batch_size, channels, height, width).
283 | """
284 |
285 | def __init__(self, dim, eps=1e-6, data_format="channels_last"):
286 | super().__init__()
287 | self.weight = nn.Parameter(torch.ones(dim))
288 | self.bias = nn.Parameter(torch.zeros(dim))
289 | self.eps = eps
290 | self.data_format = data_format
291 | if self.data_format not in ["channels_last", "channels_first"]:
292 | raise NotImplementedError
293 | self.dim = (dim,)
294 |
295 | def forward(self, x):
296 | if self.data_format == "channels_last":
297 | return F.layer_norm(x, self.dim, self.weight, self.bias, self.eps)
298 | elif self.data_format == "channels_first":
299 | u = x.mean(1, keepdim=True)
300 | s = (x - u).pow(2).mean(1, keepdim=True)
301 | x = (x - u) / torch.sqrt(s + self.eps)
302 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
303 | return x
304 |
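The backbone can be smoke-tested on its own. A minimal sketch, assuming `mmcv`/`mmseg` are installed and the interpreter is started from the `semantic_segmentation` directory (shapes correspond to a 512x512 crop with the ResTv2-T depths):

```python
import torch
from backbone.rest_v2 import ResTV2

# ResTv2-T settings from the tiny config: depths (1, 2, 6, 2), feature strides 4/8/16/32.
model = ResTV2(embed_dims=[96, 192, 384, 768], num_heads=[1, 2, 4, 8],
               depths=[1, 2, 6, 2], sr_ratios=[8, 4, 2, 1],
               out_indices=(0, 1, 2, 3)).eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 512, 512))

for f in feats:
    print(tuple(f.shape))
# (1, 96, 128, 128), (1, 192, 64, 64), (1, 384, 32, 32), (1, 768, 16, 16)
```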
--------------------------------------------------------------------------------
/semantic_segmentation/configs/ResTv2/upernet_restv2_base_512_160k_ade20k.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 |
8 | _base_ = [
9 | '../_base_/models/upernet_restv2.py', '../_base_/datasets/ade20k.py',
10 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
11 | ]
12 | crop_size = (512, 512)
13 |
14 | model = dict(
15 | pretrained='./restv2_base.pth',
16 | backbone=dict(
17 | type='ResTV2',
18 | in_chans=3,
19 | num_heads=[1, 2, 4, 8],
20 | embed_dims=[96, 192, 384, 768],
21 | depths=[1, 3, 16, 3],
22 | drop_path_rate=0.3,
23 | out_indices=[0, 1, 2, 3],
24 | ),
25 | decode_head=dict(
26 | in_channels=[96, 192, 384, 768],
27 | num_classes=150,
28 | ),
29 | auxiliary_head=dict(
30 | in_channels=384,
31 | num_classes=150
32 | ),
33 | )
34 |
35 | optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
36 | lr=0.00015, betas=(0.9, 0.999), weight_decay=0.05,
37 | paramwise_cfg={'decay_rate': 0.9,
38 | 'decay_type': 'stage_wise',
39 | 'num_layers': (1, 3, 16, 3)})
40 |
41 | lr_config = dict(_delete_=True, policy='poly',
42 | warmup='linear',
43 | warmup_iters=1500,
44 | warmup_ratio=1e-6,
45 | power=1.0, min_lr=0.0, by_epoch=False)
46 |
47 | # By default, models are trained on 8 GPUs with 2 images per GPU
48 | data = dict(samples_per_gpu=2)
49 |
--------------------------------------------------------------------------------
/semantic_segmentation/configs/ResTv2/upernet_restv2_small_512_160k_ade20k.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | _base_ = [
8 | '../_base_/models/upernet_restv2.py', '../_base_/datasets/ade20k.py',
9 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
10 | ]
11 | crop_size = (512, 512)
12 |
13 | model = dict(
14 | pretrained='./restv2_small.pth',
15 | backbone=dict(
16 | type='ResTV2',
17 | in_chans=3,
18 | num_heads=[1, 2, 4, 8],
19 | embed_dims=[96, 192, 384, 768],
20 | depths=[1, 2, 12, 2],
21 | drop_path_rate=0.2,
22 | out_indices=[0, 1, 2, 3],
23 | ),
24 | decode_head=dict(
25 | in_channels=[96, 192, 384, 768],
26 | num_classes=150,
27 | ),
28 | auxiliary_head=dict(
29 | in_channels=384,
30 | num_classes=150
31 | ),
32 | )
33 |
34 | optimizer = dict(constructor='LearningRateDecayOptimizerConstructor', _delete_=True, type='AdamW',
35 | lr=0.00015, betas=(0.9, 0.999), weight_decay=0.05,
36 | paramwise_cfg={'decay_rate': 0.9,
37 | 'decay_type': 'stage_wise',
38 | 'num_layers': (1, 2, 12, 2)})
39 |
40 | lr_config = dict(_delete_=True, policy='poly',
41 | warmup='linear',
42 | warmup_iters=1500,
43 | warmup_ratio=1e-6,
44 | power=1.0, min_lr=0.0, by_epoch=False)
45 |
46 | # By default, models are trained on 8 GPUs with 2 images per GPU
47 | data = dict(samples_per_gpu=2)
48 |
--------------------------------------------------------------------------------
/semantic_segmentation/configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 |
8 | _base_ = [
9 | '../_base_/models/upernet_restv2.py', '../_base_/datasets/ade20k.py',
10 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
11 | ]
12 | crop_size = (512, 512)
13 |
14 | model = dict(
15 | pretrained="",
16 | backbone=dict(
17 | type='ResTV2',
18 | in_chans=3,
19 | num_heads=[1, 2, 4, 8],
20 | embed_dims=[96, 192, 384, 768],
21 | depths=[1, 2, 6, 2],
22 | drop_path_rate=0.1,
23 | out_indices=(0, 1, 2, 3),
24 | ),
25 | decode_head=dict(
26 | in_channels=[96, 192, 384, 768],
27 | num_classes=150,
28 | ),
29 | auxiliary_head=dict(
30 | in_channels=384,
31 | num_classes=150
32 | ),
33 | )
34 |
35 | optimizer = dict(
36 | _delete_=True,
37 | type='AdamW',
38 | lr=0.00006,
39 | betas=(0.9, 0.999),
40 | weight_decay=0.01,
41 | paramwise_cfg=dict(
42 | custom_keys={
43 | 'absolute_pos_embed': dict(decay_mult=0.),
44 | 'relative_position_bias_table': dict(decay_mult=0.),
45 | 'norm': dict(decay_mult=0.)
46 | }))
47 |
48 | lr_config = dict(
49 | _delete_=True,
50 | policy='poly',
51 | warmup='linear',
52 | warmup_iters=1500,
53 | warmup_ratio=1e-6,
54 | power=1.0,
55 | min_lr=0.0,
56 | by_epoch=False)
57 |
58 | # By default, models are trained on 8 GPUs with 2 images per GPU
59 | data = dict(samples_per_gpu=2)
60 |
--------------------------------------------------------------------------------
/semantic_segmentation/configs/_base_/models/upernet_restv2.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | norm_cfg = dict(type='SyncBN', requires_grad=True)
3 | backbone_norm_cfg = dict(type='LN', requires_grad=True)
4 | model = dict(
5 | type='EncoderDecoder',
6 | pretrained=None,
7 | backbone=dict(
8 | type='ResTV2',
9 | in_chans=3,
10 | num_heads=[1, 2, 4, 8],
11 | embed_dims=[96, 192, 384, 768],
12 | depths=[1, 2, 6, 2],
13 | drop_path_rate=0.2,
14 | out_indices=[0, 1, 2, 3],
15 | ),
16 | decode_head=dict(
17 | type='UPerHead',
18 | in_channels=[96, 192, 384, 768],
19 | in_index=[0, 1, 2, 3],
20 | pool_scales=(1, 2, 3, 6),
21 | channels=512,
22 | dropout_ratio=0.1,
23 | num_classes=19,
24 | norm_cfg=norm_cfg,
25 | align_corners=False,
26 | loss_decode=dict(
27 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
28 | auxiliary_head=dict(
29 | type='FCNHead',
30 | in_channels=384,
31 | in_index=2,
32 | channels=256,
33 | num_convs=1,
34 | concat_input=False,
35 | dropout_ratio=0.1,
36 | num_classes=19,
37 | norm_cfg=norm_cfg,
38 | align_corners=False,
39 | loss_decode=dict(
40 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
41 | # model training and testing settings
42 | train_cfg=dict(),
43 | test_cfg=dict(mode='whole'))
44 |
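To see how the stage-specific configs above override this base model file (e.g. `num_classes` 19 -> 150 for ADE20K), the merged config can be inspected with `mmcv`. A sketch, assuming the standard mmsegmentation `_base_` dataset/runtime/schedule files are present alongside this repo's configs:

```python
from mmcv.utils import Config

# Merge the tiny ADE20K config with its _base_ files and inspect the result.
cfg = Config.fromfile('configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py')
print(cfg.model.backbone.depths)          # [1, 2, 6, 2]
print(cfg.model.decode_head.num_classes)  # 150, overriding the 19 set in this base file
print(cfg.model.backbone.drop_path_rate)  # 0.1, overriding the 0.2 set in this base file
```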
--------------------------------------------------------------------------------
/semantic_segmentation/tools/align_resize.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 | import mmcv
8 | import numpy as np
9 | from mmcv.utils import deprecated_api_warning, is_tuple_of
10 | from numpy import random
11 |
12 | from mmseg.datasets.builder import PIPELINES
13 |
14 | @PIPELINES.register_module()
15 | class AlignResize(object):
16 | """Resize images & seg. Align
17 | """
18 |
19 | def __init__(self,
20 | img_scale=None,
21 | multiscale_mode='range',
22 | ratio_range=None,
23 | keep_ratio=True,
24 | size_divisor=32):
25 | if img_scale is None:
26 | self.img_scale = None
27 | else:
28 | if isinstance(img_scale, list):
29 | self.img_scale = img_scale
30 | else:
31 | self.img_scale = [img_scale]
32 | assert mmcv.is_list_of(self.img_scale, tuple)
33 |
34 | if ratio_range is not None:
35 | # mode 1: given img_scale=None and a range of image ratio
36 | # mode 2: given a scale and a range of image ratio
37 | assert self.img_scale is None or len(self.img_scale) == 1
38 | else:
39 | # mode 3 and 4: given multiple scales or a range of scales
40 | assert multiscale_mode in ['value', 'range']
41 |
42 | self.multiscale_mode = multiscale_mode
43 | self.ratio_range = ratio_range
44 | self.keep_ratio = keep_ratio
45 | self.size_divisor = size_divisor
46 |
47 | @staticmethod
48 | def random_select(img_scales):
49 | """Randomly select an img_scale from given candidates.
50 |
51 | Args:
52 | img_scales (list[tuple]): Images scales for selection.
53 |
54 | Returns:
55 |             (tuple, int): Returns a tuple ``(img_scale, scale_idx)``,
56 | where ``img_scale`` is the selected image scale and
57 | ``scale_idx`` is the selected index in the given candidates.
58 | """
59 |
60 | assert mmcv.is_list_of(img_scales, tuple)
61 | scale_idx = np.random.randint(len(img_scales))
62 | img_scale = img_scales[scale_idx]
63 | return img_scale, scale_idx
64 |
65 | @staticmethod
66 | def random_sample(img_scales):
67 | """Randomly sample an img_scale when ``multiscale_mode=='range'``.
68 |
69 | Args:
70 | img_scales (list[tuple]): Images scale range for sampling.
71 | There must be two tuples in img_scales, which specify the lower
72 |                 and upper bound of image scales.
73 |
74 | Returns:
75 | (tuple, None): Returns a tuple ``(img_scale, None)``, where
76 | ``img_scale`` is sampled scale and None is just a placeholder
77 | to be consistent with :func:`random_select`.
78 | """
79 |
80 | assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
81 | img_scale_long = [max(s) for s in img_scales]
82 | img_scale_short = [min(s) for s in img_scales]
83 | long_edge = np.random.randint(
84 | min(img_scale_long),
85 | max(img_scale_long) + 1)
86 | short_edge = np.random.randint(
87 | min(img_scale_short),
88 | max(img_scale_short) + 1)
89 | img_scale = (long_edge, short_edge)
90 | return img_scale, None
91 |
92 | @staticmethod
93 | def random_sample_ratio(img_scale, ratio_range):
94 | """Randomly sample an img_scale when ``ratio_range`` is specified.
95 |
96 | A ratio will be randomly sampled from the range specified by
97 | ``ratio_range``. Then it would be multiplied with ``img_scale`` to
98 | generate sampled scale.
99 |
100 | Args:
101 | img_scale (tuple): Images scale base to multiply with ratio.
102 | ratio_range (tuple[float]): The minimum and maximum ratio to scale
103 | the ``img_scale``.
104 |
105 | Returns:
106 | (tuple, None): Returns a tuple ``(scale, None)``, where
107 | ``scale`` is sampled ratio multiplied with ``img_scale`` and
108 | None is just a placeholder to be consistent with
109 | :func:`random_select`.
110 | """
111 |
112 | assert isinstance(img_scale, tuple) and len(img_scale) == 2
113 | min_ratio, max_ratio = ratio_range
114 | assert min_ratio <= max_ratio
115 | ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
116 | scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
117 | return scale, None
118 |
119 | def _random_scale(self, results):
120 | """Randomly sample an img_scale according to ``ratio_range`` and
121 | ``multiscale_mode``.
122 |
123 | If ``ratio_range`` is specified, a ratio will be sampled and be
124 | multiplied with ``img_scale``.
125 | If multiple scales are specified by ``img_scale``, a scale will be
126 | sampled according to ``multiscale_mode``.
127 | Otherwise, single scale will be used.
128 |
129 | Args:
130 | results (dict): Result dict from :obj:`dataset`.
131 |
132 | Returns:
133 |             dict: Two new keys ``scale`` and ``scale_idx`` are added into
134 | ``results``, which would be used by subsequent pipelines.
135 | """
136 |
137 | if self.ratio_range is not None:
138 | if self.img_scale is None:
139 | h, w = results['img'].shape[:2]
140 | scale, scale_idx = self.random_sample_ratio((w, h),
141 | self.ratio_range)
142 | else:
143 | scale, scale_idx = self.random_sample_ratio(
144 | self.img_scale[0], self.ratio_range)
145 | elif len(self.img_scale) == 1:
146 | scale, scale_idx = self.img_scale[0], 0
147 | elif self.multiscale_mode == 'range':
148 | scale, scale_idx = self.random_sample(self.img_scale)
149 | elif self.multiscale_mode == 'value':
150 | scale, scale_idx = self.random_select(self.img_scale)
151 | else:
152 | raise NotImplementedError
153 |
154 | results['scale'] = scale
155 | results['scale_idx'] = scale_idx
156 |
157 | def _align(self, img, size_divisor, interpolation=None):
158 | align_h = int(np.ceil(img.shape[0] / size_divisor)) * size_divisor
159 | align_w = int(np.ceil(img.shape[1] / size_divisor)) * size_divisor
160 |         if interpolation is None:
161 | img = mmcv.imresize(img, (align_w, align_h))
162 | else:
163 | img = mmcv.imresize(img, (align_w, align_h), interpolation=interpolation)
164 | return img
165 |
166 | def _resize_img(self, results):
167 | """Resize images with ``results['scale']``."""
168 | if self.keep_ratio:
169 | img, scale_factor = mmcv.imrescale(
170 | results['img'], results['scale'], return_scale=True)
171 | #### align ####
172 | img = self._align(img, self.size_divisor)
173 |             # the w_scale and h_scale have a minor difference;
174 |             # a real fix should be done in mmcv.imrescale in the future
175 | new_h, new_w = img.shape[:2]
176 | h, w = results['img'].shape[:2]
177 | w_scale = new_w / w
178 | h_scale = new_h / h
179 | else:
180 | img, w_scale, h_scale = mmcv.imresize(
181 | results['img'], results['scale'], return_scale=True)
182 |
183 | h, w = img.shape[:2]
184 | assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \
185 | int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \
186 | "img size not align. h:{} w:{}".format(h,w)
187 | scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
188 | dtype=np.float32)
189 | results['img'] = img
190 | results['img_shape'] = img.shape
191 | results['pad_shape'] = img.shape # in case that there is no padding
192 | results['scale_factor'] = scale_factor
193 | results['keep_ratio'] = self.keep_ratio
194 |
195 | def _resize_seg(self, results):
196 | """Resize semantic segmentation map with ``results['scale']``."""
197 | for key in results.get('seg_fields', []):
198 | if self.keep_ratio:
199 | gt_seg = mmcv.imrescale(
200 | results[key], results['scale'], interpolation='nearest')
201 | gt_seg = self._align(gt_seg, self.size_divisor, interpolation='nearest')
202 | else:
203 | gt_seg = mmcv.imresize(
204 | results[key], results['scale'], interpolation='nearest')
205 | h, w = gt_seg.shape[:2]
206 | assert int(np.ceil(h / self.size_divisor)) * self.size_divisor == h and \
207 | int(np.ceil(w / self.size_divisor)) * self.size_divisor == w, \
208 | "gt_seg size not align. h:{} w:{}".format(h, w)
209 | results[key] = gt_seg
210 |
211 | def __call__(self, results):
212 | """Call function to resize images, bounding boxes, masks, semantic
213 | segmentation map.
214 |
215 | Args:
216 | results (dict): Result dict from loading pipeline.
217 |
218 | Returns:
219 | dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor',
220 | 'keep_ratio' keys are added into result dict.
221 | """
222 |
223 | if 'scale' not in results:
224 | self._random_scale(results)
225 | self._resize_img(results)
226 | self._resize_seg(results)
227 | return results
228 |
229 | def __repr__(self):
230 | repr_str = self.__class__.__name__
231 | repr_str += (f'(img_scale={self.img_scale}, '
232 | f'multiscale_mode={self.multiscale_mode}, '
233 | f'ratio_range={self.ratio_range}, '
234 | f'keep_ratio={self.keep_ratio})')
235 | return repr_str
236 |
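The alignment step simply rounds each side up to the nearest multiple of `size_divisor`. A small worked example of the arithmetic used in `_align`:

```python
import numpy as np

size_divisor = 32          # the default used by AlignResize
h, w = 520, 780            # e.g. an image size after mmcv.imrescale

align_h = int(np.ceil(h / size_divisor)) * size_divisor
align_w = int(np.ceil(w / size_divisor)) * size_divisor
print(align_h, align_w)    # 544 800 -> both sides are now divisible by 32
```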
--------------------------------------------------------------------------------
/semantic_segmentation/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | PORT=${PORT:-29500}
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
--------------------------------------------------------------------------------
/semantic_segmentation/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | PORT=${PORT:-29500}
6 |
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
--------------------------------------------------------------------------------
/semantic_segmentation/tools/train.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------
2 | # Copyright (c) VCU, Nanjing University.
3 | # Licensed under the Apache License 2.0 [see LICENSE for details]
4 | # Written by Qing-Long Zhang
5 | # ------------------------------------------------------------
6 |
7 |
8 | import argparse
9 | import copy
10 | import os
11 | import os.path as osp
12 | import time
13 | import warnings
14 |
15 | import mmcv
16 | import torch
17 | import torch.distributed as dist
18 | from mmcv.cnn.utils import revert_sync_batchnorm
19 | from mmcv.runner import get_dist_info, init_dist
20 | from mmcv.utils import Config, DictAction, get_git_hash
21 |
22 | from mmseg import __version__
23 | from mmseg.apis import init_random_seed, set_random_seed, train_segmentor
24 | from mmseg.datasets import build_dataset
25 | from mmseg.models import build_segmentor
26 | from mmseg.utils import collect_env, get_root_logger, setup_multi_processes
27 |
28 | from align_resize import AlignResize
29 |
30 | def parse_args():
31 | parser = argparse.ArgumentParser(description='Train a segmentor')
32 | parser.add_argument('config', help='train config file path')
33 | parser.add_argument('--work-dir', help='the dir to save logs and models')
34 | parser.add_argument(
35 | '--load-from', help='the checkpoint file to load weights from')
36 | parser.add_argument(
37 | '--resume-from', help='the checkpoint file to resume from')
38 | parser.add_argument(
39 | '--no-validate',
40 | action='store_true',
41 | help='whether not to evaluate the checkpoint during training')
42 | group_gpus = parser.add_mutually_exclusive_group()
43 | group_gpus.add_argument(
44 | '--gpus',
45 | type=int,
46 | help='(Deprecated, please use --gpu-id) number of gpus to use '
47 | '(only applicable to non-distributed training)')
48 | group_gpus.add_argument(
49 | '--gpu-ids',
50 | type=int,
51 | nargs='+',
52 | help='(Deprecated, please use --gpu-id) ids of gpus to use '
53 | '(only applicable to non-distributed training)')
54 | group_gpus.add_argument(
55 | '--gpu-id',
56 | type=int,
57 | default=0,
58 | help='id of gpu to use '
59 | '(only applicable to non-distributed training)')
60 | parser.add_argument('--seed', type=int, default=None, help='random seed')
61 | parser.add_argument(
62 | '--diff_seed',
63 | action='store_true',
64 | help='Whether or not set different seeds for different ranks')
65 | parser.add_argument(
66 | '--deterministic',
67 | action='store_true',
68 | help='whether to set deterministic options for CUDNN backend.')
69 | parser.add_argument(
70 | '--options',
71 | nargs='+',
72 | action=DictAction,
73 | help="--options is deprecated in favor of --cfg_options' and it will "
74 | 'not be supported in version v0.22.0. Override some settings in the '
75 | 'used config, the key-value pair in xxx=yyy format will be merged '
76 | 'into config file. If the value to be overwritten is a list, it '
77 | 'should be like key="[a,b]" or key=a,b It also allows nested '
78 | 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
79 | 'marks are necessary and that no white space is allowed.')
80 | parser.add_argument(
81 | '--cfg-options',
82 | nargs='+',
83 | action=DictAction,
84 | help='override some settings in the used config, the key-value pair '
85 | 'in xxx=yyy format will be merged into config file. If the value to '
86 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
87 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
88 | 'Note that the quotation marks are necessary and that no white space '
89 | 'is allowed.')
90 | parser.add_argument(
91 | '--launcher',
92 | choices=['none', 'pytorch', 'slurm', 'mpi'],
93 | default='none',
94 | help='job launcher')
95 | parser.add_argument('--local_rank', type=int, default=0)
96 | parser.add_argument(
97 | '--auto-resume',
98 | action='store_true',
99 | help='resume from the latest checkpoint automatically.')
100 | args = parser.parse_args()
101 | if 'LOCAL_RANK' not in os.environ:
102 | os.environ['LOCAL_RANK'] = str(args.local_rank)
103 |
104 | if args.options and args.cfg_options:
105 | raise ValueError(
106 | '--options and --cfg-options cannot be both '
107 | 'specified, --options is deprecated in favor of --cfg-options. '
108 | '--options will not be supported in version v0.22.0.')
109 | if args.options:
110 | warnings.warn('--options is deprecated in favor of --cfg-options. '
111 | '--options will not be supported in version v0.22.0.')
112 | args.cfg_options = args.options
113 |
114 | return args
115 |
116 |
117 | def main():
118 | args = parse_args()
119 |
120 | cfg = Config.fromfile(args.config)
121 | if args.cfg_options is not None:
122 | cfg.merge_from_dict(args.cfg_options)
123 |
124 | # set cudnn_benchmark
125 | if cfg.get('cudnn_benchmark', False):
126 | torch.backends.cudnn.benchmark = True
127 |
128 |     # work_dir is determined in this priority: CLI > config file > filename
129 | if args.work_dir is not None:
130 | # update configs according to CLI args if args.work_dir is not None
131 | cfg.work_dir = args.work_dir
132 | elif cfg.get('work_dir', None) is None:
133 | # use config filename as default work_dir if cfg.work_dir is None
134 | cfg.work_dir = osp.join('./work_dirs',
135 | osp.splitext(osp.basename(args.config))[0])
136 | if args.load_from is not None:
137 | cfg.load_from = args.load_from
138 | if args.resume_from is not None:
139 | cfg.resume_from = args.resume_from
140 | if args.gpus is not None:
141 | cfg.gpu_ids = range(1)
142 | warnings.warn('`--gpus` is deprecated because we only support '
143 | 'single GPU mode in non-distributed training. '
144 | 'Use `gpus=1` now.')
145 | if args.gpu_ids is not None:
146 | cfg.gpu_ids = args.gpu_ids[0:1]
147 | warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
148 | 'Because we only support single GPU mode in '
149 | 'non-distributed training. Use the first GPU '
150 | 'in `gpu_ids` now.')
151 | if args.gpus is None and args.gpu_ids is None:
152 | cfg.gpu_ids = [args.gpu_id]
153 |
154 | cfg.auto_resume = args.auto_resume
155 |
156 | # init distributed env first, since logger depends on the dist info.
157 | if args.launcher == 'none':
158 | distributed = False
159 | else:
160 | distributed = True
161 | init_dist(args.launcher, **cfg.dist_params)
162 | # gpu_ids is used to calculate iter when resuming checkpoint
163 | _, world_size = get_dist_info()
164 | cfg.gpu_ids = range(world_size)
165 |
166 | # create work_dir
167 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
168 | # dump config
169 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
170 | # init the logger before other steps
171 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
172 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
173 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
174 |
175 | # set multi-process settings
176 | setup_multi_processes(cfg)
177 |
178 | # init the meta dict to record some important information such as
179 | # environment info and seed, which will be logged
180 | meta = dict()
181 | # log env info
182 | env_info_dict = collect_env()
183 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
184 | dash_line = '-' * 60 + '\n'
185 | logger.info('Environment info:\n' + dash_line + env_info + '\n' +
186 | dash_line)
187 | meta['env_info'] = env_info
188 |
189 | # log some basic info
190 | logger.info(f'Distributed training: {distributed}')
191 | logger.info(f'Config:\n{cfg.pretty_text}')
192 |
193 | # set random seeds
194 | seed = init_random_seed(args.seed)
195 | seed = seed + dist.get_rank() if args.diff_seed else seed
196 | logger.info(f'Set random seed to {seed}, '
197 | f'deterministic: {args.deterministic}')
198 | set_random_seed(seed, deterministic=args.deterministic)
199 | cfg.seed = seed
200 | meta['seed'] = seed
201 | meta['exp_name'] = osp.basename(args.config)
202 |
203 | model = build_segmentor(
204 | cfg.model,
205 | train_cfg=cfg.get('train_cfg'),
206 | test_cfg=cfg.get('test_cfg'))
207 | model.init_weights()
208 |
209 |     # SyncBN is not supported for DP
210 | if not distributed:
211 | warnings.warn(
212 | 'SyncBN is only supported with DDP. To be compatible with DP, '
213 | 'we convert SyncBN to BN. Please use dist_train.sh which can '
214 | 'avoid this error.')
215 | model = revert_sync_batchnorm(model)
216 |
217 | logger.info(model)
218 |
219 | datasets = [build_dataset(cfg.data.train)]
220 | if len(cfg.workflow) == 2:
221 | val_dataset = copy.deepcopy(cfg.data.val)
222 | val_dataset.pipeline = cfg.data.train.pipeline
223 | datasets.append(build_dataset(val_dataset))
224 | if cfg.checkpoint_config is not None:
225 | # save mmseg version, config file content and class names in
226 | # checkpoints as meta data
227 | cfg.checkpoint_config.meta = dict(
228 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
229 | config=cfg.pretty_text,
230 | CLASSES=datasets[0].CLASSES,
231 | PALETTE=datasets[0].PALETTE)
232 | # add an attribute for visualization convenience
233 | model.CLASSES = datasets[0].CLASSES
234 | # passing checkpoint meta for saving best checkpoint
235 | meta.update(cfg.checkpoint_config.meta)
236 | train_segmentor(
237 | model,
238 | datasets,
239 | cfg,
240 | distributed=distributed,
241 | validate=(not args.no_validate),
242 | timestamp=timestamp,
243 | meta=meta)
244 |
245 |
246 | if __name__ == '__main__':
247 | main()
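Distributed runs should go through `tools/dist_train.sh`, which prepends the project root to `PYTHONPATH` and passes `--launcher pytorch`. For quick single-GPU debugging the script can, in principle, be invoked directly from the `semantic_segmentation` directory (a sketch; the work directory is a placeholder, the custom backbone/pipeline modules must be importable, and the SyncBN-to-BN conversion warned about above will kick in):

```bash
python tools/train.py configs/ResTv2/upernet_restv2_tiny_512_160k_ade20k.py \
    --gpu-id 0 --work-dir work_dirs/debug
```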
--------------------------------------------------------------------------------