├── LICENSE ├── README.md ├── architecture.png ├── cls ├── README.md ├── configs │ ├── _base_ │ │ └── models │ │ │ └── mobilenet_v1_1x.py │ ├── distillers │ │ └── mgd │ │ │ ├── res34_distill_res18_img.py │ │ │ └── res50_distill_mv1_img.py │ └── mobilenet_v1 │ │ └── mobilenet_v1.py ├── mmcls │ ├── apis │ │ └── train.py │ ├── distillation │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── distillers │ │ │ ├── __init__.py │ │ │ └── classification_distiller.py │ │ └── losses │ │ │ ├── __init__.py │ │ │ └── mgd.py │ └── models │ │ └── backbones │ │ ├── __init__.py │ │ └── mobilenet_v1.py └── tools │ └── train.py ├── det ├── README.md ├── configs │ └── distillers │ │ └── mgd │ │ ├── cascade_mask_rcnn_rx101_32x4d_distill_faster_rcnn_r50_fpn_2x_coco.py │ │ ├── cascade_mask_rcnn_rx101_32x4d_distill_mask_rcnn_r50_fpn_2x_coco.py │ │ ├── reppoints_rx101_64x4d_distill_reppoints_r50_fpn_2x_coco.py │ │ ├── retina_rx101_64x4d_distill_retina_r50_fpn_2x_coco.py │ │ └── solo_r101_ms_distill_solo_r50_coco.py ├── mmdet │ ├── apis │ │ └── train.py │ └── distillation │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── distillers │ │ ├── __init__.py │ │ └── detection_distiller.py │ │ └── losses │ │ ├── __init__.py │ │ └── mgd.py └── tools │ └── train.py ├── pth_transfer.py └── seg ├── README.md ├── configs ├── _base_ │ └── datasets │ │ └── cityscapes_512x512.py ├── deeplabv3 │ └── deeplabv3_r18-d8_512x512_40k_cityscapes.py ├── distillers │ └── mgd │ │ ├── psp_r101_distill_deepv3_r18_40k_512x512_city.py │ │ └── psp_r101_distill_psp_r18_40k_512x512_city.py └── pspnet │ └── pspnet_r18-d8_512x512_40k_cityscapes.py ├── mmseg ├── apis │ └── train.py └── distillation │ ├── __init__.py │ ├── builder.py │ ├── distillers │ ├── __init__.py │ └── segmentation_distiller.py │ └── losses │ ├── __init__.py │ └── mgd.py └── tools └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MGD 2 | ECCV 2022 Paper: [Masked Generative Distillation](https://arxiv.org/abs/2205.01529) 3 | 4 | ![architecture](architecture.png) 5 | ## Image Classification 6 | Please refer to [image classification](https://github.com/yzd-v/MGD/tree/master/cls) 7 | ## Object Detection 8 | Please refer to [object detection](https://github.com/yzd-v/MGD/tree/master/det) 9 | ## Semantic Segmentation 10 | Please refer to [semantic segmentation](https://github.com/yzd-v/MGD/tree/master/seg) 11 | 12 | ## Citation 13 | ``` 14 | @article{yang2022masked, 15 | title={Masked Generative Distillation}, 16 | author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun}, 17 | journal={arXiv preprint arXiv:2205.01529}, 18 | year={2022} 19 | } 20 | ``` -------------------------------------------------------------------------------- /architecture.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yzd-v/MGD/2c9da0b28625eb948db57afc02c824452c3910fe/architecture.png -------------------------------------------------------------------------------- /cls/README.md: -------------------------------------------------------------------------------- 1 | # Image Classification 2 | ## Install 3 | - Our code is based on [MMClassification](https://github.com/open-mmlab/mmclassification). Please follow the MMClassification installation instructions and make sure you can run it successfully. 4 | - This repo uses mmcv-full==1.3.17 and mmcls==0.19.0. 5 | - If you want to use a lower mmcv-full version, you may have to change the optimizer in apis/train.py and build_distiller in tools/train.py. 6 | - For a lower mmcv-full version, you may refer to [FGD](https://github.com/yzd-v/FGD) to change model.init_weights() in [train.py](https://github.com/yzd-v/MGD/tree/master/cls/tools/train.py) and self.student.init_weights() in [distiller.py](https://github.com/yzd-v/MGD/tree/master/cls/mmcls/distillation/distillers/classification_distiller.py). 7 | ## Add and Replace the Code 8 | - Add the configs/. in our code to the configs/ in mmclassification's code. 9 | - Add the mmcls/. in our code to the mmcls/ in mmclassification's code. 10 | - Replace the mmcls/apis/train.py and tools/train.py in mmclassification's code with mmcls/apis/train.py and tools/train.py in our code. 11 | - Add pth_transfer.py to mmclassification's code. 
12 | - Unzip the ImageNet dataset into data/imagenet/ 13 | 14 | ## Train 15 | 16 | ``` 17 | #single GPU 18 | python tools/train.py configs/distillers/mgd/res34_distill_res18_img.py 19 | 20 | #multi GPU 21 | bash tools/dist_train.sh configs/distillers/mgd/res34_distill_res18_img.py 8 22 | ``` 23 | 24 | ## Transfer 25 | ``` 26 | # Transfer the MGD model into an mmcls model 27 | python pth_transfer.py --mgd_path $mgd_ckpt --output_path $new_mmcls_ckpt 28 | ``` 29 | ## Test 30 | 31 | ``` 32 | #single GPU 33 | python tools/test.py configs/resnet/resnet18_8xb32_in1k.py $new_mmcls_ckpt --metrics accuracy 34 | 35 | #multi GPU 36 | bash tools/dist_test.sh configs/resnet/resnet18_8xb32_in1k.py $new_mmcls_ckpt 8 --metrics accuracy 37 | ``` 38 | 39 | ## Results 40 | | Model | Teacher | Baseline(Top-1 Acc) | +MGD(Top-1 Acc) | config | log | weight | 41 | | :------: | :-------: | :----------------: | :------------: | :----------------------------------------------------------: | :------------------------------------------------------: | :--: | 42 | | ResNet18 | ResNet34 | 69.90 | 71.69 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | [baidu](https://pan.baidu.com/s/1PoJeqOzlEb6MKBEQMSvYsw?pwd=27wc) | [baidu](https://pan.baidu.com/s/1VtjqCvFHGh-qUR7wOvojYw?pwd=ehnn) | 43 | | MobileNet | ResNet50 | 69.21 | 72.49 | [config](https://github.com/yzd-v/MGD/tree/master/cls/configs/mobilenet_v1/mobilenet_v1.py) | [baidu](https://pan.baidu.com/s/1m5yuPATnpnfBB1izZc0I3g?pwd=piu8) | [baidu](https://pan.baidu.com/s/1NdoHf0KA3MiIUKC9_gH3ng?pwd=fnii) | 44 | 45 | ## Citation 46 | ``` 47 | @article{yang2022masked, 48 | title={Masked Generative Distillation}, 49 | author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun}, 50 | journal={arXiv preprint arXiv:2205.01529}, 51 | year={2022} 52 | } 53 | ``` 54 | 55 | ## Acknowledgement 56 | 57 | Our code is based on the project 
[MMClassification](https://github.com/open-mmlab/mmclassification). -------------------------------------------------------------------------------- /cls/configs/_base_/models/mobilenet_v1_1x.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='ImageClassifier', 4 | backbone=dict(type='MobileNetV1'), 5 | neck=dict(type='GlobalAveragePooling'), 6 | head=dict( 7 | type='LinearClsHead', 8 | num_classes=1000, 9 | in_channels=1024, 10 | loss=dict(type='CrossEntropyLoss', loss_weight=1.0), 11 | topk=(1, 5), 12 | )) 13 | -------------------------------------------------------------------------------- /cls/configs/distillers/mgd/res34_distill_res18_img.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../resnet/resnet18_b32x8_imagenet.py' 3 | ] 4 | # model settings 5 | find_unused_parameters=True 6 | distiller = dict( 7 | type='ClassificationDistiller', 8 | teacher_pretrained = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth', 9 | distill_cfg = [ 10 | dict(methods=[dict(type='MGDLoss', 11 | name='loss_mgd', 12 | student_channels = 512, 13 | teacher_channels = 512, 14 | alpha_mgd=0.00007, 15 | lambda_mgd=0.15, 16 | ) 17 | ] 18 | ), 19 | ] 20 | ) 21 | 22 | student_cfg = 'configs/resnet/resnet18_b32x8_imagenet.py' 23 | teacher_cfg = 'configs/resnet/resnet34_b32x8_imagenet.py' 24 | -------------------------------------------------------------------------------- /cls/configs/distillers/mgd/res50_distill_mv1_img.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../mobilenet_v1/mobilenet_v1.py' 3 | ] 4 | # model settings 5 | find_unused_parameters=True 6 | distiller = dict( 7 | type='ClassificationDistiller', 8 | teacher_pretrained = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth', 
9 | distill_cfg = [ 10 | dict(methods=[dict(type='MGDLoss', 11 | name='loss_mgd', 12 | student_channels = 1024, 13 | teacher_channels = 2048, 14 | alpha_mgd=0.00007, 15 | lambda_mgd=0.15, 16 | ) 17 | ] 18 | ), 19 | ] 20 | ) 21 | 22 | student_cfg = 'configs/mobilenet_v1/mobilenet_v1.py' 23 | teacher_cfg = 'configs/resnet/resnet50_b32x8_imagenet.py' 24 | optimizer_config = dict(_delete_=True,grad_clip=dict(max_norm=5.0)) -------------------------------------------------------------------------------- /cls/configs/mobilenet_v1/mobilenet_v1.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/mobilenet_v1_1x.py', 3 | '../_base_/datasets/imagenet_bs32.py', 4 | '../_base_/schedules/imagenet_bs256.py', 5 | '../_base_/default_runtime.py' 6 | ] 7 | -------------------------------------------------------------------------------- /cls/mmcls/apis/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import random 3 | import warnings 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 9 | from mmcv.runner import (DistSamplerSeedHook, build_optimizer, build_runner, 10 | get_dist_info) 11 | 12 | from mmcls.core import DistOptimizerHook 13 | from mmcls.datasets import build_dataloader, build_dataset 14 | from mmcls.utils import get_root_logger 15 | 16 | # TODO import eval hooks from mmcv and delete them from mmcls 17 | try: 18 | from mmcv.runner.hooks import EvalHook, DistEvalHook 19 | except ImportError: 20 | warnings.warn('DeprecationWarning: EvalHook and DistEvalHook from mmcls ' 21 | 'will be deprecated.' 
22 | 'Please install mmcv through master branch.') 23 | from mmcls.core import EvalHook, DistEvalHook 24 | 25 | # TODO import optimizer hook from mmcv and delete them from mmcls 26 | try: 27 | from mmcv.runner import Fp16OptimizerHook 28 | except ImportError: 29 | warnings.warn('DeprecationWarning: FP16OptimizerHook from mmcls will be ' 30 | 'deprecated. Please install mmcv>=1.1.4.') 31 | from mmcls.core import Fp16OptimizerHook 32 | 33 | 34 | def init_random_seed(seed=None, device='cuda'): 35 | """Initialize random seed. 36 | 37 | If the seed is not set, the seed will be automatically randomized, 38 | and then broadcast to all processes to prevent some potential bugs. 39 | 40 | Args: 41 | seed (int, Optional): The seed. Default to None. 42 | device (str): The device where the seed will be put on. 43 | Default to 'cuda'. 44 | 45 | Returns: 46 | int: Seed to be used. 47 | """ 48 | if seed is not None: 49 | return seed 50 | 51 | # Make sure all ranks share the same random seed to prevent 52 | # some potential bugs. Please refer to 53 | # https://github.com/open-mmlab/mmdetection/issues/6339 54 | rank, world_size = get_dist_info() 55 | seed = np.random.randint(2**31) 56 | if world_size == 1: 57 | return seed 58 | 59 | if rank == 0: 60 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 61 | else: 62 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 63 | dist.broadcast(random_num, src=0) 64 | return random_num.item() 65 | 66 | 67 | def set_random_seed(seed, deterministic=False): 68 | """Set random seed. 69 | 70 | Args: 71 | seed (int): Seed to be used. 72 | deterministic (bool): Whether to set the deterministic option for 73 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 74 | to True and `torch.backends.cudnn.benchmark` to False. 75 | Default: False. 
76 | """ 77 | random.seed(seed) 78 | np.random.seed(seed) 79 | torch.manual_seed(seed) 80 | torch.cuda.manual_seed_all(seed) 81 | if deterministic: 82 | torch.backends.cudnn.deterministic = True 83 | torch.backends.cudnn.benchmark = False 84 | 85 | 86 | def train_model(model, 87 | dataset, 88 | cfg, 89 | distributed=False, 90 | validate=False, 91 | timestamp=None, 92 | device='cuda', 93 | meta=None): 94 | logger = get_root_logger() 95 | 96 | # prepare data loaders 97 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 98 | 99 | sampler_cfg = cfg.data.get('sampler', None) 100 | 101 | data_loaders = [ 102 | build_dataloader( 103 | ds, 104 | cfg.data.samples_per_gpu, 105 | cfg.data.workers_per_gpu, 106 | # cfg.gpus will be ignored if distributed 107 | num_gpus=len(cfg.gpu_ids), 108 | dist=distributed, 109 | round_up=True, 110 | seed=cfg.seed, 111 | sampler_cfg=sampler_cfg) for ds in dataset 112 | ] 113 | 114 | # put model on gpus 115 | if distributed: 116 | find_unused_parameters = cfg.get('find_unused_parameters', False) 117 | # Sets the `find_unused_parameters` parameter in 118 | # torch.nn.parallel.DistributedDataParallel 119 | model = MMDistributedDataParallel( 120 | model.cuda(), 121 | device_ids=[torch.cuda.current_device()], 122 | broadcast_buffers=False, 123 | find_unused_parameters=find_unused_parameters) 124 | else: 125 | if device == 'cuda': 126 | model = MMDataParallel( 127 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 128 | elif device == 'cpu': 129 | model = model.cpu() 130 | else: 131 | raise ValueError(F'unsupported device name {device}.') 132 | 133 | # build runner 134 | distiller_cfg = cfg.get('distiller',None) 135 | if distiller_cfg is None: 136 | optimizer = build_optimizer(model, cfg.optimizer) 137 | else: 138 | optimizer = build_optimizer(model.module.base_parameters(), cfg.optimizer) 139 | 140 | if cfg.get('runner') is None: 141 | cfg.runner = { 142 | 'type': 'EpochBasedRunner', 143 | 'max_epochs': cfg.total_epochs 
144 | } 145 | warnings.warn( 146 | 'config is now expected to have a `runner` section, ' 147 | 'please set `runner` in your config.', UserWarning) 148 | 149 | runner = build_runner( 150 | cfg.runner, 151 | default_args=dict( 152 | model=model, 153 | batch_processor=None, 154 | optimizer=optimizer, 155 | work_dir=cfg.work_dir, 156 | logger=logger, 157 | meta=meta)) 158 | 159 | # an ugly walkaround to make the .log and .log.json filenames the same 160 | runner.timestamp = timestamp 161 | 162 | # fp16 setting 163 | fp16_cfg = cfg.get('fp16', None) 164 | if fp16_cfg is not None: 165 | optimizer_config = Fp16OptimizerHook( 166 | **cfg.optimizer_config, **fp16_cfg, distributed=distributed) 167 | elif distributed and 'type' not in cfg.optimizer_config: 168 | optimizer_config = DistOptimizerHook(**cfg.optimizer_config) 169 | else: 170 | optimizer_config = cfg.optimizer_config 171 | 172 | # register hooks 173 | runner.register_training_hooks( 174 | cfg.lr_config, 175 | optimizer_config, 176 | cfg.checkpoint_config, 177 | cfg.log_config, 178 | cfg.get('momentum_config', None), 179 | custom_hooks_config=cfg.get('custom_hooks', None)) 180 | if distributed and cfg.runner['type'] == 'EpochBasedRunner': 181 | runner.register_hook(DistSamplerSeedHook()) 182 | 183 | # register eval hooks 184 | if validate: 185 | val_dataset = build_dataset(cfg.data.val, dict(test_mode=True)) 186 | val_dataloader = build_dataloader( 187 | val_dataset, 188 | samples_per_gpu=cfg.data.samples_per_gpu, 189 | workers_per_gpu=cfg.data.workers_per_gpu, 190 | dist=distributed, 191 | shuffle=False, 192 | round_up=True) 193 | eval_cfg = cfg.get('evaluation', {}) 194 | eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner' 195 | eval_hook = DistEvalHook if distributed else EvalHook 196 | # `EvalHook` needs to be executed after `IterTimerHook`. 197 | # Otherwise, it will cause a bug if use `IterBasedRunner`. 
198 | # Refers to https://github.com/open-mmlab/mmcv/issues/1261 199 | runner.register_hook( 200 | eval_hook(val_dataloader, **eval_cfg), priority='LOW') 201 | 202 | if cfg.resume_from: 203 | runner.resume(cfg.resume_from) 204 | elif cfg.load_from: 205 | runner.load_checkpoint(cfg.load_from) 206 | runner.run(data_loaders, cfg.workflow) 207 | -------------------------------------------------------------------------------- /cls/mmcls/distillation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .builder import ( DISTILLER,DISTILL_LOSSES,build_distill_loss,build_distiller) 3 | from .distillers import * 4 | from .losses import * 5 | 6 | 7 | __all__ = [ 8 | 'DISTILLER', 'DISTILL_LOSSES', 'build_distiller' 9 | ] 10 | 11 | 12 | -------------------------------------------------------------------------------- /cls/mmcls/distillation/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import Registry, build_from_cfg 2 | from torch import nn 3 | 4 | DISTILLER = Registry('distiller') 5 | DISTILL_LOSSES = Registry('distill_loss') 6 | DISRUNNERS = Registry('runner') 7 | 8 | def build(cfg, registry, default_args=None): 9 | """Build a module. 10 | 11 | Args: 12 | cfg (dict, list[dict]): The config of modules, is is either a dict 13 | or a list of configs. 14 | registry (:obj:`Registry`): A registry the module belongs to. 15 | default_args (dict, optional): Default arguments to build the module. 16 | Defaults to None. 17 | 18 | Returns: 19 | nn.Module: A built nn module. 
20 | """ 21 | 22 | if isinstance(cfg, list): 23 | modules = [ 24 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 25 | ] 26 | return nn.Sequential(*modules) 27 | else: 28 | return build_from_cfg(cfg, registry, default_args) 29 | 30 | def build_distill_loss(cfg): 31 | """Build distill loss.""" 32 | return build(cfg, DISTILL_LOSSES) 33 | 34 | def build_distiller(cfg,teacher_cfg=None,student_cfg=None): 35 | """Build distiller.""" 36 | 37 | return build(cfg, DISTILLER, dict(teacher_cfg=teacher_cfg,student_cfg=student_cfg)) 38 | 39 | 40 | def build_runner(cfg, default_args=None): 41 | return build_from_cfg(cfg, DISRUNNERS, default_args=default_args) 42 | -------------------------------------------------------------------------------- /cls/mmcls/distillation/distillers/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification_distiller import ClassificationDistiller 2 | 3 | __all__ = [ 4 | 'ClassificationDistiller' 5 | ] -------------------------------------------------------------------------------- /cls/mmcls/distillation/distillers/classification_distiller.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from mmcls.models.classifiers.base import BaseClassifier 5 | from mmcls.models import build_classifier 6 | from mmcv.runner import load_checkpoint, _load_checkpoint, load_state_dict 7 | from ..builder import DISTILLER,build_distill_loss 8 | from collections import OrderedDict 9 | 10 | 11 | @DISTILLER.register_module() 12 | class ClassificationDistiller(BaseClassifier): 13 | """Base distiller for detectors. 14 | 15 | It typically consists of teacher_model and student_model. 
16 | """ 17 | def __init__(self, 18 | teacher_cfg, 19 | student_cfg, 20 | distill_cfg=None, 21 | teacher_pretrained=None): 22 | 23 | super(ClassificationDistiller, self).__init__() 24 | 25 | self.teacher = build_classifier(teacher_cfg.model) 26 | if teacher_pretrained: 27 | self.init_weights_teacher(teacher_pretrained) 28 | self.teacher.eval() 29 | 30 | self.student= build_classifier(student_cfg.model) 31 | self.student.init_weights() 32 | 33 | self.distill_cfg = distill_cfg 34 | self.distill_losses = nn.ModuleDict() 35 | if self.distill_cfg is not None: 36 | for item_loc in distill_cfg: 37 | for item_loss in item_loc.methods: 38 | loss_name = item_loss.name 39 | self.distill_losses[loss_name] = build_distill_loss(item_loss) 40 | 41 | def base_parameters(self): 42 | return nn.ModuleList([self.student,self.distill_losses]) 43 | 44 | def init_weights_teacher(self, path=None): 45 | """Load the pretrained model in teacher detector. 46 | 47 | Args: 48 | pretrained (str, optional): Path to pre-trained weights. 49 | Defaults to None. 50 | """ 51 | checkpoint = load_checkpoint(self.teacher, path, map_location='cpu') 52 | 53 | def forward_train(self, 54 | img, 55 | gt_label, 56 | **kwargs): 57 | 58 | """ 59 | Args: 60 | img (Tensor): Input images of shape (N, C, H, W). 61 | Typically these should be mean centered and std scaled. 62 | 63 | Returns: 64 | dict[str, Tensor]: A dictionary of loss components(student's losses and distiller's losses). 
65 | """ 66 | 67 | if self.student.augments is not None: 68 | img, gt_label = self.student.augments(img, gt_label) 69 | 70 | fea_s = self.student.extract_feat(img, stage='backbone') 71 | x = fea_s 72 | if self.student.with_neck: 73 | x = self.student.neck(x) 74 | if self.student.with_head and hasattr(self.student.head, 'pre_logits'): 75 | x = self.student.head.pre_logits(x) 76 | 77 | logit_s = self.student.head.fc(x) 78 | loss = self.student.head.loss(logit_s, gt_label) 79 | 80 | student_loss = dict() 81 | for key in loss.keys(): 82 | student_loss['ori_'+key] = loss[key] 83 | 84 | with torch.no_grad(): 85 | fea_t = self.teacher.extract_feat(img, stage='backbone') 86 | 87 | loss_name = 'loss_mgd' 88 | student_loss[loss_name] = self.distill_losses[loss_name](fea_s[-1], fea_t[-1]) 89 | 90 | return student_loss 91 | 92 | def simple_test(self, img, img_metas=None, **kwargs): 93 | return self.student.simple_test(img, img_metas, **kwargs) 94 | 95 | def extract_feat(self, imgs, stage='neck'): 96 | """Extract features from images. 97 | 'backbone', 'neck', 'pre_logits' 98 | """ 99 | return self.student.extract_feat(imgs, stage) 100 | 101 | 102 | -------------------------------------------------------------------------------- /cls/mmcls/distillation/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .mgd import MGDLoss 2 | 3 | __all__ = [ 4 | 'MGDLoss' 5 | ] 6 | -------------------------------------------------------------------------------- /cls/mmcls/distillation/losses/mgd.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from ..builder import DISTILL_LOSSES 5 | 6 | @DISTILL_LOSSES.register_module() 7 | class MGDLoss(nn.Module): 8 | 9 | """PyTorch version of `Masked Generative Distillation` 10 | 11 | Args: 12 | student_channels(int): Number of channels in the student's feature map. 
13 | teacher_channels(int): Number of channels in the teacher's feature map. 14 | name (str): the loss name of the layer 15 | alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00007 16 | lambda_mgd (float, optional): masked ratio. Defaults to 0.5 17 | """ 18 | def __init__(self, 19 | student_channels, 20 | teacher_channels, 21 | name, 22 | alpha_mgd=0.00007, 23 | lambda_mgd=0.15, 24 | ): 25 | super(MGDLoss, self).__init__() 26 | self.alpha_mgd = alpha_mgd 27 | self.lambda_mgd = lambda_mgd 28 | 29 | if student_channels != teacher_channels: 30 | self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0) 31 | else: 32 | self.align = None 33 | 34 | self.generation = nn.Sequential( 35 | nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1)) 38 | 39 | 40 | def forward(self, 41 | preds_S, 42 | preds_T): 43 | """Forward function. 
44 | Args: 45 | preds_S(Tensor): Bs*C*H*W, student's feature map 46 | preds_T(Tensor): Bs*C*H*W, teacher's feature map 47 | """ 48 | assert preds_S.shape[-2:] == preds_T.shape[-2:] 49 | 50 | if self.align is not None: 51 | preds_S = self.align(preds_S) 52 | 53 | loss = self.get_dis_loss(preds_S, preds_T)*self.alpha_mgd 54 | 55 | return loss 56 | 57 | def get_dis_loss(self, preds_S, preds_T): 58 | loss_mse = nn.MSELoss(reduction='sum') 59 | N, C, H, W = preds_T.shape 60 | 61 | device = preds_S.device 62 | mat = torch.rand((N,C,1,1)).to(device) 63 | # mat = torch.rand((N,1,H,W)).to(device) 64 | mat = torch.where(mat < self.lambda_mgd, 0, 1).to(device) 65 | 66 | masked_fea = torch.mul(preds_S, mat) 67 | new_fea = self.generation(masked_fea) 68 | 69 | dis_loss = loss_mse(new_fea, preds_T)/N 70 | 71 | return dis_loss -------------------------------------------------------------------------------- /cls/mmcls/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .alexnet import AlexNet 3 | from .conformer import Conformer 4 | from .deit import DistilledVisionTransformer 5 | from .lenet import LeNet5 6 | from .mlp_mixer import MlpMixer 7 | from .mobilenet_v1 import MobileNetV1 8 | from .mobilenet_v2 import MobileNetV2 9 | from .mobilenet_v3 import MobileNetV3 10 | from .regnet import RegNet 11 | from .repvgg import RepVGG 12 | from .res2net import Res2Net 13 | from .resnest import ResNeSt 14 | from .resnet import ResNet, ResNetV1d 15 | from .resnet_cifar import ResNet_CIFAR 16 | from .resnext import ResNeXt 17 | from .seresnet import SEResNet 18 | from .seresnext import SEResNeXt 19 | from .shufflenet_v1 import ShuffleNetV1 20 | from .shufflenet_v2 import ShuffleNetV2 21 | from .swin_transformer import SwinTransformer 22 | from .t2t_vit import T2T_ViT 23 | from .timm_backbone import TIMMBackbone 24 | from .tnt import TNT 25 | from .vgg import VGG 26 | from .vision_transformer import VisionTransformer 27 | 28 | __all__ = [ 29 | 'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 30 | 'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 31 | 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', 32 | 'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG', 33 | 'Conformer', 'MlpMixer', 'DistilledVisionTransformer' 34 | ] 35 | -------------------------------------------------------------------------------- /cls/mmcls/models/backbones/mobilenet_v1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..builder import BACKBONES 4 | from .base_backbone import BaseBackbone 5 | 6 | @BACKBONES.register_module() 7 | class MobileNetV1(BaseBackbone): 8 | def __init__(self, 9 | init_cfg=[ 10 | dict(type='Kaiming', layer=['Conv2d']), 11 | dict( 12 | type='Constant', 13 | val=1, 14 | layer=['_BatchNorm', 'GroupNorm']) 15 | ]): 16 | super(BaseBackbone, self).__init__(init_cfg) 17 | 18 | 
def conv_bn(inp, oup, stride): 19 | return nn.Sequential( 20 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 21 | nn.BatchNorm2d(oup), 22 | nn.ReLU(inplace=True) 23 | ) 24 | 25 | def conv_dw(inp, oup, stride): 26 | return nn.Sequential( 27 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 28 | nn.BatchNorm2d(inp), 29 | nn.ReLU(inplace=True), 30 | 31 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 32 | nn.BatchNorm2d(oup), 33 | nn.ReLU(inplace=True), 34 | ) 35 | 36 | self.model = nn.Sequential( 37 | conv_bn( 3, 32, 2), 38 | conv_dw( 32, 64, 1), 39 | conv_dw( 64, 128, 2), 40 | conv_dw(128, 128, 1), 41 | conv_dw(128, 256, 2), 42 | conv_dw(256, 256, 1), 43 | conv_dw(256, 512, 2), 44 | conv_dw(512, 512, 1), 45 | conv_dw(512, 512, 1), 46 | conv_dw(512, 512, 1), 47 | conv_dw(512, 512, 1), 48 | conv_dw(512, 512, 1), 49 | conv_dw(512, 1024, 2), 50 | conv_dw(1024, 1024, 1), 51 | ) 52 | 53 | def forward(self, x): 54 | outs = [] 55 | x = self.model(x) 56 | outs.append(x) 57 | return tuple(outs) -------------------------------------------------------------------------------- /cls/tools/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
import argparse
import copy
import os
import os.path as osp
import time
import warnings

import mmcv
import torch
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist

from mmcls import __version__
from mmcls.apis import init_random_seed, set_random_seed, train_model
from mmcls.datasets import build_dataset
from mmcls.models import build_classifier
from mmcls.utils import collect_env, get_root_logger
from mmcls.distillation import build_distiller


def parse_args():
    """Parse the command-line arguments of the training script.

    Besides parsing, this exports ``LOCAL_RANK`` into the environment when it
    is not already set, and folds the deprecated ``--options`` into
    ``cfg_options``.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')
    # --device, --gpus and --gpu-ids are mutually exclusive.
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument('--device', help='device used for training')
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')
    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecate), '
        'change to --cfg-options instead.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    # Only fill LOCAL_RANK from the CLI when the environment has not set it.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args


def main():
    """Entry point: load the config, set up the run, build the model, train.

    Steps, in order: merge CLI overrides into the config; resolve the work
    directory and GPU ids; initialise the distributed environment when a
    launcher is given; create the work directory, dump the config and start
    the logger; log environment info; seed the RNGs; build either a plain
    classifier or, when the config has a ``distiller`` entry, a distiller
    from the teacher/student config files; build the datasets; call
    ``train_model``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # In distributed mode gpu_ids is overwritten to span the world size.
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed

    """
    change the code for distillation
    """
    # Without a `distiller` entry the script trains a plain classifier;
    # with one, teacher and student configs are loaded from the paths stored
    # in cfg.teacher_cfg / cfg.student_cfg and wrapped in a distiller.
    distiller_cfg = cfg.get('distiller',None)
    if distiller_cfg is None:
        model = build_classifier(cfg.model)
        model.init_weights()
    else:
        teacher_cfg = Config.fromfile(cfg.teacher_cfg)
        student_cfg = Config.fromfile(cfg.student_cfg)

        model = build_distiller(cfg.distiller,teacher_cfg,student_cfg)


    datasets = [build_dataset(cfg.data.train)]
    # A two-stage workflow adds the validation set, run through the
    # training pipeline.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))

    # save mmcls version, config file content and class names in
    # runner as meta data
    meta.update(
        dict(
            mmcls_version=__version__,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES))

    # add an attribute for visualization convenience
    train_model(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        # Anything other than an explicit --device cpu trains on CUDA.
        device='cpu' if args.device == 'cpu' else 'cuda',
        meta=meta)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/det/README.md:
--------------------------------------------------------------------------------
# Object Detection
## Install
- Our codes are based on [MMDetection](https://github.com/open-mmlab/mmdetection). Please follow the installation of MMDetection and make sure you can run it successfully.
- This repo uses mmcv-full==1.3.17 and mmdet==2.19.0
- If you want to use lower mmcv-full version, you may have to change the optimizer in apis/train.py and build_distiller in tools/train.py.
- For lower mmcv-full, you may refer [FGD](https://github.com/yzd-v/FGD) to change model.init_weights() in [train.py](https://github.com/yzd-v/MGD/tree/master/det/tools/train.py) and self.student.init_weights() in [distiller.py](https://github.com/yzd-v/MGD/tree/master/det/mmdet/distillation/distillers/detection_distiller.py).
## Add and Replace the codes
- Add the configs/. in our codes to the configs/ in mmdetection's codes.
9 | - Add the mmdet/distillation/. in our codes to the mmdet/ in mmdetection's codes. 10 | - Replace the mmdet/apis/train.py and tools/train.py in mmdetection's codes with mmdet/apis/train.py and tools/train.py in our codes. 11 | - Add pth_transfer.py to mmdetection's codes. 12 | - Unzip COCO dataset into data/coco/ 13 | ## Train 14 | 15 | ``` 16 | #single GPU 17 | python tools/train.py configs/distillers/mgd/retina_rx101_64x4d_distill_retina_r50_fpn_2x_coco.py 18 | 19 | #multi GPU 20 | bash tools/dist_train.sh configs/distillers/mgd/retina_rx101_64x4d_distill_retina_r50_fpn_2x_coco.py 8 21 | ``` 22 | 23 | ## Transfer 24 | ``` 25 | # Transfer the MGD model into mmdet model 26 | python pth_transfer.py --mgd_path $mgd_ckpt --output_path $new_mmdet_ckpt 27 | ``` 28 | ## Test 29 | 30 | ``` 31 | #single GPU 32 | python tools/test.py configs/retinanet/retinanet_r50_fpn_2x_coco.py $new_mmdet_ckpt --eval bbox 33 | 34 | #multi GPU 35 | bash tools/dist_test.sh configs/retinanet/retinanet_r50_fpn_2x_coco.py $new_mmdet_ckpt 8 --eval bbox 36 | ``` 37 | ## Results 38 | | Model | Backbone | Baseline(mAP) | +MGD(mAP) | config | log | weight | 39 | | :---------: | :--------: | :-----------: | :-------: | :----------------------------------------------------------: | :------------------------------------------------------: | :--: | 40 | | RetinaNet | ResNet-50 | 37.4 | 41.0 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet/retinanet_r50_fpn_2x_coco.py) | [baidu](https://pan.baidu.com/s/1sBxgi110KtZLSB8NDr7G-g?pwd=n83s) | [baidu](https://pan.baidu.com/s/1Bqv2XNa_TAGZJFUd177WWA?pwd=gu2x) | 41 | | Faster RCNN | ResNet-50 | 38.4 | 42.1 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn/faster_rcnn_r50_fpn_2x_coco.py) | [baidu](https://pan.baidu.com/s/1xrLcE2e9e5qT1nomX4TqTg?pwd=aheq) | [baidu](https://pan.baidu.com/s/1vuZuq06wg3X9SJPNWQSgrw?pwd=2x8z) | 42 | | RepPoints | ResNet-50 | 38.6 | 42.3 | 
[config](https://github.com/open-mmlab/mmdetection/tree/master/configs/reppoints/reppoints_moment_r50_fpn_gn-neck+head_2x_coco.py) | [baidu](https://pan.baidu.com/s/103unzbTgqjIBdYzH8zliEg?pwd=aucd) | [baidu](https://pan.baidu.com/s/1HfqvLoMU57y9NXPq5TNhow?pwd=g79p) | 43 | 44 | | Model | Backbone | Baseline(Mask mAP) | +MGD(Mask mAP) | config | log | weight | 45 | | :------: | :-------: | :----------------: | :------------: | :----------------------------------------------------------: | :------------------------------------------------------: | :--: | 46 | | SOLO | ResNet-50 | 33.1 | 36.2 | [config](https://github.com/open-mmlab/mmdetection/blob/master/configs/solo/solo_r50_fpn_1x_coco.py) | [baidu](https://pan.baidu.com/s/1kl7bSSkToN7atGZdWp9Ntw?pwd=wdpt) | [baidu](https://pan.baidu.com/s/1xZmIj_wP7SXsSxfXxa_4Ww?pwd=ksr1) | 47 | | MaskRCNN | ResNet-50 | 35.4 | 38.1 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn/mask_rcnn_r50_fpn_2x_coco.py) | [baidu](https://pan.baidu.com/s/1uN8Q5Ew57oKUjzh65_TCVw?pwd=nykm) | [baidu](https://pan.baidu.com/s/1B4Bcw6S_sy882SMK2bp7uw?pwd=a7xf) | 48 | 49 | ## Citation 50 | ``` 51 | @article{yang2022masked, 52 | title={Masked Generative Distillation}, 53 | author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun}, 54 | journal={arXiv preprint arXiv:2205.01529}, 55 | year={2022} 56 | } 57 | ``` 58 | 59 | ## Acknowledgement 60 | 61 | Our code is based on the project [MMDetection](https://github.com/open-mmlab/mmdetection). 
-------------------------------------------------------------------------------- /det/configs/distillers/mgd/cascade_mask_rcnn_rx101_32x4d_distill_faster_rcnn_r50_fpn_2x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../faster_rcnn/faster_rcnn_r50_fpn_2x_coco.py' 3 | ] 4 | # model settings 5 | find_unused_parameters=True 6 | alpha_mgd=0.0000005 7 | lambda_mgd=0.45 8 | distiller = dict( 9 | type='DetectionDistiller', 10 | teacher_pretrained = 'https://download.openmmlab.com/mmdetection/v2.0/dcn/cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_coco/cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_coco-e75f90c8.pth', 11 | distill_cfg = [ dict(student_module = 'neck.fpn_convs.3.conv', 12 | teacher_module = 'neck.fpn_convs.3.conv', 13 | output_hook = True, 14 | methods=[dict(type='FeatureLoss', 15 | name='loss_mgd_fpn_3', 16 | student_channels = 256, 17 | teacher_channels = 256, 18 | alpha_mgd=alpha_mgd, 19 | lambda_mgd=lambda_mgd, 20 | ) 21 | ] 22 | ), 23 | dict(student_module = 'neck.fpn_convs.2.conv', 24 | teacher_module = 'neck.fpn_convs.2.conv', 25 | output_hook = True, 26 | methods=[dict(type='FeatureLoss', 27 | name='loss_mgd_fpn_2', 28 | student_channels = 256, 29 | teacher_channels = 256, 30 | alpha_mgd=alpha_mgd, 31 | lambda_mgd=lambda_mgd, 32 | ) 33 | ] 34 | ), 35 | dict(student_module = 'neck.fpn_convs.1.conv', 36 | teacher_module = 'neck.fpn_convs.1.conv', 37 | output_hook = True, 38 | methods=[dict(type='FeatureLoss', 39 | name='loss_mgd_fpn_1', 40 | student_channels = 256, 41 | teacher_channels = 256, 42 | alpha_mgd=alpha_mgd, 43 | lambda_mgd=lambda_mgd, 44 | ) 45 | ] 46 | ), 47 | dict(student_module = 'neck.fpn_convs.0.conv', 48 | teacher_module = 'neck.fpn_convs.0.conv', 49 | output_hook = True, 50 | methods=[dict(type='FeatureLoss', 51 | name='loss_mgd_fpn_0', 52 | student_channels = 256, 53 | teacher_channels = 256, 54 | alpha_mgd=alpha_mgd, 55 | lambda_mgd=lambda_mgd, 56 | ) 57 | ] 58 | 
), 59 | ] 60 | ) 61 | 62 | student_cfg = 'configs/faster_rcnn/faster_rcnn_r50_fpn_2x_coco.py' 63 | teacher_cfg = 'configs/dcn/cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_coco.py' 64 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) -------------------------------------------------------------------------------- /det/configs/distillers/mgd/cascade_mask_rcnn_rx101_32x4d_distill_mask_rcnn_r50_fpn_2x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../mask_rcnn/mask_rcnn_r50_fpn_2x_coco.py' 3 | ] 4 | # model settings 5 | find_unused_parameters=True 6 | alpha_mgd=0.0000005 7 | lambda_mgd=0.45 8 | distiller = dict( 9 | type='DetectionDistiller', 10 | teacher_pretrained = 'https://download.openmmlab.com/mmdetection/v2.0/dcn/cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_coco/cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_coco-e75f90c8.pth', 11 | distill_cfg = [ dict(student_module = 'neck.fpn_convs.3.conv', 12 | teacher_module = 'neck.fpn_convs.3.conv', 13 | output_hook = True, 14 | methods=[dict(type='FeatureLoss', 15 | name='loss_mgd_fpn_3', 16 | student_channels = 256, 17 | teacher_channels = 256, 18 | alpha_mgd=alpha_mgd, 19 | lambda_mgd=lambda_mgd, 20 | ) 21 | ] 22 | ), 23 | dict(student_module = 'neck.fpn_convs.2.conv', 24 | teacher_module = 'neck.fpn_convs.2.conv', 25 | output_hook = True, 26 | methods=[dict(type='FeatureLoss', 27 | name='loss_mgd_fpn_2', 28 | student_channels = 256, 29 | teacher_channels = 256, 30 | alpha_mgd=alpha_mgd, 31 | lambda_mgd=lambda_mgd, 32 | ) 33 | ] 34 | ), 35 | dict(student_module = 'neck.fpn_convs.1.conv', 36 | teacher_module = 'neck.fpn_convs.1.conv', 37 | output_hook = True, 38 | methods=[dict(type='FeatureLoss', 39 | name='loss_mgd_fpn_1', 40 | student_channels = 256, 41 | teacher_channels = 256, 42 | alpha_mgd=alpha_mgd, 43 | lambda_mgd=lambda_mgd, 44 | ) 45 | ] 46 | ), 47 | dict(student_module = 'neck.fpn_convs.0.conv', 48 | 
teacher_module = 'neck.fpn_convs.0.conv', 49 | output_hook = True, 50 | methods=[dict(type='FeatureLoss', 51 | name='loss_mgd_fpn_0', 52 | student_channels = 256, 53 | teacher_channels = 256, 54 | alpha_mgd=alpha_mgd, 55 | lambda_mgd=lambda_mgd, 56 | ) 57 | ] 58 | ), 59 | ] 60 | ) 61 | 62 | 63 | student_cfg = 'configs/mask_rcnn/mask_rcnn_r50_fpn_2x_coco.py' 64 | teacher_cfg = 'configs/dcn/cascade_mask_rcnn_x101_32x4d_fpn_dconv_c3-c5_1x_coco.py' 65 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) 66 | -------------------------------------------------------------------------------- /det/configs/distillers/mgd/reppoints_rx101_64x4d_distill_reppoints_r50_fpn_2x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../reppoints/reppoints_moment_r50_fpn_gn-neck+head_2x_coco.py' 3 | ] 4 | # model settings 5 | find_unused_parameters=True 6 | alpha_mgd=0.00002 7 | lambda_mgd=0.65 8 | distiller = dict( 9 | type='DetectionDistiller', 10 | teacher_pretrained = 'https://download.openmmlab.com/mmdetection/v2.0/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck%2Bhead_2x_coco_20200329-f87da1ea.pth', 11 | init_student = True, 12 | distill_cfg = [ 13 | dict(student_module = 'neck.fpn_convs.4.conv', 14 | teacher_module = 'neck.fpn_convs.4.conv', 15 | output_hook = True, 16 | methods=[dict(type='FeatureLoss', 17 | name='loss_mgd_fpn_4', 18 | student_channels = 256, 19 | teacher_channels = 256, 20 | alpha_mgd=alpha_mgd, 21 | lambda_mgd=lambda_mgd, 22 | ) 23 | ] 24 | ), 25 | dict(student_module = 'neck.fpn_convs.3.conv', 26 | teacher_module = 'neck.fpn_convs.3.conv', 27 | output_hook = True, 28 | methods=[dict(type='FeatureLoss', 29 | name='loss_mgd_fpn_3', 30 | student_channels = 256, 31 | teacher_channels = 256, 32 | alpha_mgd=alpha_mgd, 33 | lambda_mgd=lambda_mgd, 34 | ) 35 | ] 36 | ), 37 | dict(student_module = 
'neck.fpn_convs.2.conv', 38 | teacher_module = 'neck.fpn_convs.2.conv', 39 | output_hook = True, 40 | methods=[dict(type='FeatureLoss', 41 | name='loss_mgd_fpn_2', 42 | student_channels = 256, 43 | teacher_channels = 256, 44 | alpha_mgd=alpha_mgd, 45 | lambda_mgd=lambda_mgd, 46 | ) 47 | ] 48 | ), 49 | dict(student_module = 'neck.fpn_convs.1.conv', 50 | teacher_module = 'neck.fpn_convs.1.conv', 51 | output_hook = True, 52 | methods=[dict(type='FeatureLoss', 53 | name='loss_mgd_fpn_1', 54 | student_channels = 256, 55 | teacher_channels = 256, 56 | alpha_mgd=alpha_mgd, 57 | lambda_mgd=lambda_mgd, 58 | ) 59 | ] 60 | ), 61 | dict(student_module = 'neck.fpn_convs.0.conv', 62 | teacher_module = 'neck.fpn_convs.0.conv', 63 | output_hook = True, 64 | methods=[dict(type='FeatureLoss', 65 | name='loss_mgd_fpn_0', 66 | student_channels = 256, 67 | teacher_channels = 256, 68 | alpha_mgd=alpha_mgd, 69 | lambda_mgd=lambda_mgd, 70 | ) 71 | ] 72 | ), 73 | 74 | ] 75 | ) 76 | 77 | student_cfg = 'configs/reppoints/reppoints_moment_r50_fpn_gn-neck+head_2x_coco.py' 78 | teacher_cfg = 'configs/reppoints/reppoints_moment_x101_fpn_dconv_c3-c5_gn-neck+head_2x_coco.py' 79 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) -------------------------------------------------------------------------------- /det/configs/distillers/mgd/retina_rx101_64x4d_distill_retina_r50_fpn_2x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../retinanet/retinanet_r50_fpn_2x_coco.py' 3 | ] 4 | # model settings 5 | find_unused_parameters=True 6 | alpha_mgd=0.00002 7 | lambda_mgd=0.65 8 | distiller = dict( 9 | type='DetectionDistiller', 10 | teacher_pretrained = 'https://download.openmmlab.com/mmdetection/v2.0/retinanet/retinanet_x101_64x4d_fpn_mstrain_3x_coco/retinanet_x101_64x4d_fpn_mstrain_3x_coco_20210719_051838-022c2187.pth', 11 | init_student = True, 12 | distill_cfg = [ dict(student_module = 
'neck.fpn_convs.4.conv', 13 | teacher_module = 'neck.fpn_convs.4.conv', 14 | output_hook = True, 15 | methods=[dict(type='FeatureLoss', 16 | name='loss_mgd_fpn_4', 17 | student_channels = 256, 18 | teacher_channels = 256, 19 | alpha_mgd=alpha_mgd, 20 | lambda_mgd=lambda_mgd, 21 | ) 22 | ] 23 | ), 24 | dict(student_module = 'neck.fpn_convs.3.conv', 25 | teacher_module = 'neck.fpn_convs.3.conv', 26 | output_hook = True, 27 | methods=[dict(type='FeatureLoss', 28 | name='loss_mgd_fpn_3', 29 | student_channels = 256, 30 | teacher_channels = 256, 31 | alpha_mgd=alpha_mgd, 32 | lambda_mgd=lambda_mgd, 33 | ) 34 | ] 35 | ), 36 | dict(student_module = 'neck.fpn_convs.2.conv', 37 | teacher_module = 'neck.fpn_convs.2.conv', 38 | output_hook = True, 39 | methods=[dict(type='FeatureLoss', 40 | name='loss_mgd_fpn_2', 41 | student_channels = 256, 42 | teacher_channels = 256, 43 | alpha_mgd=alpha_mgd, 44 | lambda_mgd=lambda_mgd, 45 | ) 46 | ] 47 | ), 48 | dict(student_module = 'neck.fpn_convs.1.conv', 49 | teacher_module = 'neck.fpn_convs.1.conv', 50 | output_hook = True, 51 | methods=[dict(type='FeatureLoss', 52 | name='loss_mgd_fpn_1', 53 | student_channels = 256, 54 | teacher_channels = 256, 55 | alpha_mgd=alpha_mgd, 56 | lambda_mgd=lambda_mgd, 57 | ) 58 | ] 59 | ), 60 | dict(student_module = 'neck.fpn_convs.0.conv', 61 | teacher_module = 'neck.fpn_convs.0.conv', 62 | output_hook = True, 63 | methods=[dict(type='FeatureLoss', 64 | name='loss_mgd_fpn_0', 65 | student_channels = 256, 66 | teacher_channels = 256, 67 | alpha_mgd=alpha_mgd, 68 | lambda_mgd=lambda_mgd, 69 | ) 70 | ] 71 | ), 72 | 73 | ] 74 | ) 75 | 76 | student_cfg = 'configs/retinanet/retinanet_r50_fpn_2x_coco.py' 77 | teacher_cfg = 'configs/retinanet/retinanet_x101_64x4d_fpn_1x_coco.py' 78 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) 79 | -------------------------------------------------------------------------------- 
/det/configs/distillers/mgd/solo_r101_ms_distill_solo_r50_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../solo/solo_r50_fpn_1x_coco.py' 3 | ] 4 | # model settings 5 | find_unused_parameters=True 6 | alpha_mgd=0.00002 7 | lambda_mgd=0.65 8 | distiller = dict( 9 | type='DetectionDistiller', 10 | teacher_pretrained = 'Solo_r101_3x.pth', 11 | init_student = True, 12 | distill_cfg = [ 13 | dict(student_module = 'neck.fpn_convs.3.conv', 14 | teacher_module = 'neck.fpn_convs.3.conv', 15 | output_hook = True, 16 | methods=[dict(type='FeatureLoss', 17 | name='loss_mgd_fpn_3', 18 | student_channels = 256, 19 | teacher_channels = 256, 20 | alpha_mgd=alpha_mgd, 21 | lambda_mgd=lambda_mgd, 22 | ) 23 | ] 24 | ), 25 | dict(student_module = 'neck.fpn_convs.2.conv', 26 | teacher_module = 'neck.fpn_convs.2.conv', 27 | output_hook = True, 28 | methods=[dict(type='FeatureLoss', 29 | name='loss_mgd_fpn_2', 30 | student_channels = 256, 31 | teacher_channels = 256, 32 | alpha_mgd=alpha_mgd, 33 | lambda_mgd=lambda_mgd, 34 | ) 35 | ] 36 | ), 37 | dict(student_module = 'neck.fpn_convs.1.conv', 38 | teacher_module = 'neck.fpn_convs.1.conv', 39 | output_hook = True, 40 | methods=[dict(type='FeatureLoss', 41 | name='loss_mgd_fpn_1', 42 | student_channels = 256, 43 | teacher_channels = 256, 44 | alpha_mgd=alpha_mgd, 45 | lambda_mgd=lambda_mgd, 46 | ) 47 | ] 48 | ), 49 | dict(student_module = 'neck.fpn_convs.0.conv', 50 | teacher_module = 'neck.fpn_convs.0.conv', 51 | output_hook = True, 52 | methods=[dict(type='FeatureLoss', 53 | name='loss_mgd_fpn_0', 54 | student_channels = 256, 55 | teacher_channels = 256, 56 | alpha_mgd=alpha_mgd, 57 | lambda_mgd=lambda_mgd, 58 | ) 59 | ] 60 | ), 61 | 62 | ] 63 | ) 64 | 65 | student_cfg = 'configs/solo/solo_r50_fpn_1x_coco.py' 66 | teacher_cfg = 'configs/solo/solo_r101_fpn_3x_coco.py' 67 | optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) 68 | 
-------------------------------------------------------------------------------- /det/mmdet/apis/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import random 3 | import warnings 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 9 | from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, 10 | Fp16OptimizerHook, OptimizerHook, build_optimizer, 11 | build_runner, get_dist_info) 12 | 13 | from mmdet.core import DistEvalHook, EvalHook 14 | from mmdet.datasets import (build_dataloader, build_dataset, 15 | replace_ImageToTensor) 16 | from mmdet.utils import get_root_logger 17 | 18 | 19 | def init_random_seed(seed=None, device='cuda'): 20 | """Initialize random seed. 21 | 22 | If the seed is not set, the seed will be automatically randomized, 23 | and then broadcast to all processes to prevent some potential bugs. 24 | 25 | Args: 26 | seed (int, Optional): The seed. Default to None. 27 | device (str): The device where the seed will be put on. 28 | Default to 'cuda'. 29 | 30 | Returns: 31 | int: Seed to be used. 32 | """ 33 | if seed is not None: 34 | return seed 35 | 36 | # Make sure all ranks share the same random seed to prevent 37 | # some potential bugs. Please refer to 38 | # https://github.com/open-mmlab/mmdetection/issues/6339 39 | rank, world_size = get_dist_info() 40 | seed = np.random.randint(2**31) 41 | if world_size == 1: 42 | return seed 43 | 44 | if rank == 0: 45 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 46 | else: 47 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 48 | dist.broadcast(random_num, src=0) 49 | return random_num.item() 50 | 51 | 52 | def set_random_seed(seed, deterministic=False): 53 | """Set random seed. 54 | 55 | Args: 56 | seed (int): Seed to be used. 
def train_detector(model,
                   dataset,
                   cfg,
                   distributed=False,
                   validate=False,
                   timestamp=None,
                   meta=None):
    """Build loaders, optimizer, runner and hooks from ``cfg`` and train.

    When ``cfg`` contains a ``distiller`` entry the optimizer is built from
    ``model.module.base_parameters()`` instead of the whole wrapped model.

    Args:
        model: The model to train; it is moved to GPU and wrapped in a
            data-parallel wrapper here.
        dataset: A dataset or a list/tuple of datasets, one per workflow.
        cfg: The full training config.
        distributed (bool): Use ``MMDistributedDataParallel`` instead of
            ``MMDataParallel``. Default: False.
        validate (bool): Register an evaluation hook built from
            ``cfg.data.val``. Default: False.
        timestamp: Value stored on the runner so that the ``.log`` and
            ``.log.json`` file names agree.
        meta: Meta information handed to the runner.
    """
    logger = get_root_logger(log_level=cfg.log_level)

    # prepare data loaders
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]
    if 'imgs_per_gpu' in cfg.data:
        logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
                       'Please use "samples_per_gpu" instead')
        if 'samples_per_gpu' not in cfg.data:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f'{cfg.data.imgs_per_gpu} in this experiments')
        else:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f'={cfg.data.imgs_per_gpu} is used in this experiments')
        # In both cases the deprecated key wins.
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    # Read before `cfg.runner` is possibly filled in further below.
    runner_type = cfg.runner['type'] if 'runner' in cfg \
        else 'EpochBasedRunner'
    data_loaders = []
    for ds in dataset:
        loader = build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # `num_gpus` will be ignored if distributed
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            runner_type=runner_type,
            persistent_workers=cfg.data.get('persistent_workers', False))
        data_loaders.append(loader)

    # put model on gpus
    if not distributed:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
    else:
        # Forwarded to torch.nn.parallel.DistributedDataParallel.
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)

    # build runner
    if cfg.get('distiller', None) is None:
        optim_target = model
    else:
        # Distillation run: optimise only what the distiller exposes.
        optim_target = model.module.base_parameters()
    optimizer = build_optimizer(optim_target, cfg.optimizer)

    if 'runner' in cfg:
        if 'total_epochs' in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs
    else:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner_defaults = dict(
        model=model,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    runner = build_runner(cfg.runner, default_args=runner_defaults)

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    fp16_cfg = cfg.get('fp16', None)
    optimizer_config = cfg.optimizer_config
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **optimizer_config, **fp16_cfg, distributed=distributed)
    elif distributed and 'type' not in optimizer_config:
        optimizer_config = OptimizerHook(**optimizer_config)

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get('momentum_config', None),
        custom_hooks_config=cfg.get('custom_hooks', None))

    if distributed and isinstance(runner, EpochBasedRunner):
        runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        val_cfg = cfg.data.val
        # Support batch_size > 1 in validation
        val_samples_per_gpu = val_cfg.pop('samples_per_gpu', 1)
        if val_samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            val_cfg.pipeline = replace_ImageToTensor(val_cfg.pipeline)
        val_dataset = build_dataset(val_cfg, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        hook_cls = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            hook_cls(val_dataloader, **eval_cfg), priority='LOW')

    # Resuming takes precedence over merely loading weights.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules, it is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module. A list of configs yields an
        ``nn.Sequential`` of the built modules.
    """

    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


def build_distill_loss(cfg):
    """Build distill loss."""
    return build(cfg, DISTILL_LOSSES)


def build_distiller(cfg, teacher_cfg=None, student_cfg=None, train_cfg=None, test_cfg=None):
    """Build distiller.

    Args:
        cfg (dict): Config of the distiller, looked up in ``DISTILLER``.
        teacher_cfg: Passed to the distiller as default ``teacher_cfg``.
        student_cfg: Passed to the distiller as default ``student_cfg``.
        train_cfg: Deprecated; only checked against ``cfg``, never forwarded.
        test_cfg: Deprecated; only checked against ``cfg``, never forwarded.

    Returns:
        The distiller built from ``cfg``.

    Raises:
        AssertionError: If ``train_cfg``/``test_cfg`` is given both here and
            inside ``cfg``.
    """
    # Imported locally: this module never imported `warnings`, so the
    # deprecation branch below used to fail with NameError instead of warning.
    import warnings

    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return build(cfg, DISTILLER, dict(teacher_cfg=teacher_cfg, student_cfg=student_cfg))


def build_runner(cfg, default_args=None):
    """Build a runner registered in ``DISRUNNERS`` from ``cfg``."""
    return build_from_cfg(cfg, DISRUNNERS, default_args=default_args)
@DISTILLER.register_module()
class DetectionDistiller(BaseDetector):
    """Distiller holding a teacher detector and a student detector.

    Both detectors are built from their configs. Forward hooks on the modules
    named in ``distill_cfg`` store each module's output on the distiller, and
    ``forward_train`` adds one distillation loss per configured method to the
    student's own losses. Testing is delegated to the student.

    Args:
        teacher_cfg: Teacher config; ``teacher_cfg.model`` is built with
            ``build_detector``.
        student_cfg: Student config; ``student_cfg.model`` is built with
            ``build_detector``.
        distill_cfg: Iterable of items providing ``student_module`` and
            ``teacher_module`` (dotted module names) and ``methods`` (loss
            configs, each with a ``name``). The default of None is not
            usable because the value is iterated unconditionally.
        teacher_pretrained: Checkpoint location loaded into the teacher.
        init_student (bool): If True, load every entry of the teacher
            checkpoint whose name does not start with ``backbone.`` into the
            student.
    """
    def __init__(self,
                 teacher_cfg,
                 student_cfg,
                 distill_cfg=None,
                 teacher_pretrained=None,
                 init_student=False):

        super(DetectionDistiller, self).__init__()

        self.teacher = build_detector(teacher_cfg.model,
                                        train_cfg=teacher_cfg.get('train_cfg'),
                                        test_cfg=teacher_cfg.get('test_cfg'))
        self.init_weights_teacher(teacher_pretrained)
        # NOTE(review): eval mode is set once here; presumably a later
        # `.train()` call on the distiller switches the teacher back to
        # training mode as well -- confirm.
        self.teacher.eval()

        self.student= build_detector(student_cfg.model,
                                        train_cfg=student_cfg.get('train_cfg'),
                                        test_cfg=student_cfg.get('test_cfg'))
        self.student.init_weights()
        if init_student:
            # Reuse the teacher checkpoint for everything but the backbone.
            # NOTE(review): needs `teacher_pretrained` to be set; with the
            # default None this presumably fails inside `_load_checkpoint`.
            t_checkpoint = _load_checkpoint(teacher_pretrained)
            all_name = []
            for name, v in t_checkpoint["state_dict"].items():
                if name.startswith("backbone."):
                    continue
                else:
                    all_name.append((name, v))

            state_dict = OrderedDict(all_name)
            load_state_dict(self.student, state_dict)

        # Keyed by the `name` of each configured method.
        self.distill_losses = nn.ModuleDict()
        self.distill_cfg = distill_cfg

        student_modules = dict(self.student.named_modules())
        teacher_modules = dict(self.teacher.named_modules())
        def regitster_hooks(student_module,teacher_module):
            # Factory so that each hook pair closes over its own buffer names
            # instead of the loop variables below.
            def hook_teacher_forward(module, input, output):
                    self.register_buffer(teacher_module,output)
            def hook_student_forward(module, input, output):
                    self.register_buffer( student_module,output )
            return hook_teacher_forward,hook_student_forward

        for item_loc in distill_cfg:

            # Buffer names: 'student_'/'teacher_' + module path with dots
            # replaced by underscores.
            student_module = 'student_' + item_loc.student_module.replace('.','_')
            teacher_module = 'teacher_' + item_loc.teacher_module.replace('.','_')

            # Placeholders; the hooks overwrite them on every forward pass.
            self.register_buffer(student_module,None)
            self.register_buffer(teacher_module,None)

            hook_teacher_forward,hook_student_forward = regitster_hooks(student_module ,teacher_module )
            teacher_modules[item_loc.teacher_module].register_forward_hook(hook_teacher_forward)
            student_modules[item_loc.student_module].register_forward_hook(hook_student_forward)

            for item_loss in item_loc.methods:
                loss_name = item_loss.name
                self.distill_losses[loss_name] = build_distill_loss(item_loss)
    def base_parameters(self):
        """Return the student and the distillation losses as a ModuleList.

        The teacher is not included. The training code builds its optimizer
        from this return value when a distiller is configured.
        """
        return nn.ModuleList([self.student,self.distill_losses])


    @property
    def with_neck(self):
        """bool: whether the student detector has a neck"""
        return hasattr(self.student, 'neck') and self.student.neck is not None

    @property
    def with_shared_head(self):
        """bool: whether the student detector has a shared head in the RoI Head"""
        return hasattr(self.student, 'roi_head') and self.student.roi_head.with_shared_head

    @property
    def with_bbox(self):
        """bool: whether the student detector has a bbox head"""
        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_bbox)
                or (hasattr(self.student, 'bbox_head') and self.student.bbox_head is not None))

    @property
    def with_mask(self):
        """bool: whether the student detector has a mask head"""
        return ((hasattr(self.student, 'roi_head') and self.student.roi_head.with_mask)
                or (hasattr(self.student, 'mask_head') and self.student.mask_head is not None))

    def init_weights_teacher(self, path=None):
        """Load the pretrained model in teacher detector.

        Args:
            path (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        # Loaded onto CPU; the returned checkpoint is not used afterwards.
        checkpoint = load_checkpoint(self.teacher, path, map_location='cpu')



    def forward_train(self,
                      img,
                      img_metas,
                      **kwargs):

        """Compute the student's losses plus the distillation losses.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas (list[dict]): A List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                :class:`mmdet.datasets.pipelines.Collect`.
            **kwargs: Forwarded unchanged to the student's ``forward_train``.

        Returns:
            dict[str, Tensor]: A dictionary of loss components(student's losses and distiller's losses).
        """

        student_loss = self.student.forward_train(img, img_metas, **kwargs)

        with torch.no_grad():
            # `fea_t` itself is unused; the call runs the teacher so that its
            # forward hooks refresh the teacher buffers (assumes the hooked
            # modules are executed by `extract_feat` -- TODO confirm).
            fea_t = self.teacher.extract_feat(img)

        # Features captured by the hooks during the two forward passes above.
        buffer_dict = dict(self.named_buffers())
        for item_loc in self.distill_cfg:
            student_module = 'student_' + item_loc.student_module.replace('.','_')
            teacher_module = 'teacher_' + item_loc.teacher_module.replace('.','_')
            student_feat = buffer_dict[student_module]
            teacher_feat = buffer_dict[teacher_module]
            for item_loss in item_loc.methods:
                loss_name = item_loss.name
                # Added under the configured name next to the student losses.
                student_loss[loss_name] = self.distill_losses[loss_name](student_feat,teacher_feat)

        return student_loss

    def simple_test(self, img, img_metas, **kwargs):
        """Delegate single-scale testing to the student."""
        return self.student.simple_test(img, img_metas, **kwargs)
    def aug_test(self, imgs, img_metas, **kwargs):
        """Delegate test-time-augmentation testing to the student."""
        return self.student.aug_test(imgs, img_metas, **kwargs)
    def extract_feat(self, imgs):
        """Extract features from images."""
        return self.student.extract_feat(imgs)
@DISTILL_LOSSES.register_module()
class FeatureLoss(nn.Module):

    """PyTorch version of `Masked Generative Distillation`

    The student feature map is randomly masked per spatial position, passed
    through a small generation block, and compared with the teacher feature
    map using a summed squared error divided by the batch size.

    Args:
        student_channels(int): Number of channels in the student's feature map.
        teacher_channels(int): Number of channels in the teacher's feature map.
        name (str): the loss name of the layer
        alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002
        lambda_mgd (float, optional): masked ratio. Defaults to 0.65
    """
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 name,
                 alpha_mgd=0.00002,
                 lambda_mgd=0.65,
                 ):
        super(FeatureLoss, self).__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd
        self.name = name

        # A 1x1 conv maps student channels to teacher channels when they
        # differ; otherwise the student feature is used as is.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None

        # Generation block applied to the masked student feature.
        self.generation = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))


    def forward(self,
                preds_S,
                preds_T):
        """Forward function.

        Args:
            preds_S(Tensor): Bs*C*H*W, student's feature map
            preds_T(Tensor): Bs*C*H*W, teacher's feature map

        Returns:
            Tensor: The distillation loss scaled by ``alpha_mgd``.

        Raises:
            AssertionError: If the two feature maps differ in spatial size.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:]

        if self.align is not None:
            preds_S = self.align(preds_S)

        return self.get_dis_loss(preds_S, preds_T) * self.alpha_mgd

    def get_dis_loss(self, preds_S, preds_T):
        """Return the masked-generation loss, summed and divided by batch size.

        Args:
            preds_S(Tensor): Bs*C*H*W student feature, already channel-aligned.
            preds_T(Tensor): Bs*C*H*W teacher feature.

        Returns:
            Tensor: Scalar loss (not yet scaled by ``alpha_mgd``).
        """
        N, _, H, W = preds_T.shape

        # One mask value per sample and spatial position, shared by all
        # channels. A position is zeroed when its draw exceeds
        # 1 - lambda_mgd, exactly as before. The draw now happens directly on
        # the feature's device, which avoids a host-to-device copy per call;
        # on GPU it therefore consumes the CUDA generator rather than the CPU
        # one, so masks for a given seed differ from earlier runs.
        keep = torch.rand((N, 1, H, W), device=preds_S.device) <= 1 - self.lambda_mgd
        masked_fea = preds_S * keep.to(preds_S.dtype)

        new_fea = self.generation(masked_fea)

        # Functional form: no loss module is constructed on every call.
        return F.mse_loss(new_fea, preds_T, reduction='sum') / N
def parse_args():
    """Parse the command line of the training script.

    Besides parsing, this exports ``LOCAL_RANK`` to the environment when it is
    not set yet and folds the deprecated ``--options`` into ``--cfg-options``.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    add = parser.add_argument
    add('config', help='train config file path')
    add('--work-dir', help='the dir to save logs and models')
    add('--resume-from', help='the checkpoint file to resume from')
    add('--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # --gpus and --gpu-ids exclude each other.
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    group_gpus.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')

    add('--seed', type=int, default=None, help='random seed')
    add('--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')

    # Both override flags take key=value pairs parsed by DictAction.
    override_kwargs = dict(nargs='+', action=DictAction)
    add('--options',
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file (deprecate), '
        'change to --cfg-options instead.',
        **override_kwargs)
    add('--cfg-options',
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.',
        **override_kwargs)
    add('--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    add('--local_rank', type=int, default=0)

    args = parser.parse_args()
    # An already exported LOCAL_RANK takes precedence over the flag.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))

    if args.options:
        if args.cfg_options:
            raise ValueError(
                '--options and --cfg-options cannot be both '
                'specified, --options is deprecated in favor of --cfg-options')
        warnings.warn('--options is deprecated in favor of --cfg-options')
        args.cfg_options = args.options

    return args
| cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 111 | 112 | # init distributed env first, since logger depends on the dist info. 113 | if args.launcher == 'none': 114 | distributed = False 115 | else: 116 | distributed = True 117 | init_dist(args.launcher, **cfg.dist_params) 118 | # re-set gpu_ids with distributed training mode 119 | _, world_size = get_dist_info() 120 | cfg.gpu_ids = range(world_size) 121 | 122 | # create work_dir 123 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 124 | # dump config 125 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 126 | # init the logger before other steps 127 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 128 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 129 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 130 | 131 | # init the meta dict to record some important information such as 132 | # environment info and seed, which will be logged 133 | meta = dict() 134 | # log env info 135 | env_info_dict = collect_env() 136 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 137 | dash_line = '-' * 60 + '\n' 138 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 139 | dash_line) 140 | meta['env_info'] = env_info 141 | meta['config'] = cfg.pretty_text 142 | # log some basic info 143 | logger.info(f'Distributed training: {distributed}') 144 | logger.info(f'Config:\n{cfg.pretty_text}') 145 | 146 | # set random seeds 147 | seed = init_random_seed(args.seed) 148 | logger.info(f'Set random seed to {seed}, ' 149 | f'deterministic: {args.deterministic}') 150 | set_random_seed(seed, deterministic=args.deterministic) 151 | cfg.seed = seed 152 | meta['seed'] = seed 153 | meta['exp_name'] = osp.basename(args.config) 154 | 155 | distiller_cfg = cfg.get('distiller',None) 156 | if distiller_cfg is None: 157 | model = build_detector( 158 | cfg.model, 159 | train_cfg=cfg.get('train_cfg'), 160 | test_cfg=cfg.get('test_cfg')) 161 
def change_model(args):
    """Convert a distiller checkpoint into a plain student checkpoint.

    Only ``state_dict`` entries whose name starts with ``student.`` are kept,
    with that prefix removed; everything else in the state dict is dropped.
    Optimizer state is removed when present; all other top-level checkpoint
    entries are written back unchanged.

    Args:
        args: Object with ``mgd_path`` (checkpoint to read) and
            ``output_path`` (file to write).
    """
    # Map to CPU so the conversion also works on a machine without the GPU
    # the checkpoint was saved from.
    mgd_model = torch.load(args.mgd_path, map_location='cpu')

    prefix = 'student.'
    mgd_model['state_dict'] = OrderedDict(
        (name[len(prefix):], value)
        for name, value in mgd_model['state_dict'].items()
        if name.startswith(prefix))

    # Not every checkpoint carries optimizer state, so do not fail without it.
    mgd_model.pop('optimizer', None)
    torch.save(mgd_model, args.output_path)
type=str, default='work_dirs/mgd_psp_r101_distill_deepv3_r18_40k_512x512_city/latest.pth', 24 | metavar='N',help='mgd_model path') 25 | parser.add_argument('--output_path', type=str, default='deeplabv3_res18_new.pth',metavar='N', 26 | help = 'output path') 27 | args = parser.parse_args() 28 | change_model(args) 29 | -------------------------------------------------------------------------------- /seg/README.md: -------------------------------------------------------------------------------- 1 | # Semantic Segmentation 2 | ## Install 3 | - Our code is based on [MMSegmentation](https://github.com/open-mmlab/mmsegmentation). Please follow the installation of MMSegmentation and make sure you can run it successfully. 4 | - This repo uses mmcv-full==1.4.2 and mmseg==0.20.2 5 | - If you want to use a lower mmcv-full version, you may have to change the optimizer in apis/train.py and build_distiller in tools/train.py. 6 | - For a lower mmcv-full version, you may refer to [FGD](https://github.com/yzd-v/FGD) to change model.init_weights() in [train.py](https://github.com/yzd-v/MGD/tree/master/seg/tools/train.py) and self.student.init_weights() in [distiller.py](https://github.com/yzd-v/MGD/tree/master/seg/mmseg/distillation/distillers/segmentation_distiller.py). 7 | ## Add and Replace the codes 8 | - Add the configs/. in our code to the configs/ in mmsegmentation's code. 9 | - Add the mmseg/distillation/. in our code to the mmseg/ in mmsegmentation's code. 10 | - Replace the mmseg/apis/train.py and tools/train.py in mmsegmentation's code with mmseg/apis/train.py and tools/train.py in our code. 11 | - Add pth_transfer.py to mmsegmentation's code. 
12 | - Unzip the Cityscapes dataset into data/cityscapes/ 13 | ## Train 14 | 15 | ``` 16 | #single GPU 17 | python tools/train.py configs/distillers/mgd/psp_r101_distill_psp_r18_40k_512x512_city.py 18 | 19 | #multi GPU 20 | bash tools/dist_train.sh configs/distillers/mgd/psp_r101_distill_psp_r18_40k_512x512_city.py 8 21 | ``` 22 | 23 | ## Transfer 24 | ``` 25 | # Transfer the MGD model into an mmseg model 26 | python pth_transfer.py --mgd_path $mgd_ckpt --output_path $new_mmseg_ckpt 27 | ``` 28 | ## Test 29 | 30 | ``` 31 | #single GPU 32 | python tools/test.py configs/pspnet/pspnet_r18-d8_512x512_40k_cityscapes.py $new_mmseg_ckpt --eval mIoU 33 | 34 | #multi GPU 35 | bash tools/dist_test.sh configs/pspnet/pspnet_r18-d8_512x512_40k_cityscapes.py $new_mmseg_ckpt 8 --eval mIoU 36 | ``` 37 | ## Results 38 | | Model | Backbone | Baseline(mIoU) | +MGD(mIoU) | config | log | weight | 39 | | :------: | :-------: | :----------------: | :------------: | :----------------------------------------------------------: | :------------------------------------------------------: | :--: | 40 | | PspNet | ResNet-18 | 69.85 | 73.63 | [config](https://github.com/yzd-v/MGD/tree/master/seg/configs/pspnet/pspnet_r18-d8_512x512_40k_cityscapes.py) | [baidu](https://pan.baidu.com/s/15mLdMez1yf4-lrR0u5XUag?pwd=7vqd) | [baidu](https://pan.baidu.com/s/1a2DgN70r-jxl4bpC07NXQQ?pwd=u5ii) | 41 | | DeepLabV3 | ResNet-18 | 73.20 | 76.31 | [config](https://github.com/yzd-v/MGD/tree/master/seg/configs/deeplabv3/deeplabv3_r18-d8_512x512_40k_cityscapes.py) | [baidu](https://pan.baidu.com/s/1xSXxQuIJ52ZihP0g3-0_pw?pwd=h9aw) | [baidu](https://pan.baidu.com/s/1Q8fOKhJWTHOSaEQVIIg4aw?pwd=1m9s) | 42 | 43 | ## Citation 44 | ``` 45 | @article{yang2022masked, 46 | title={Masked Generative Distillation}, 47 | author={Yang, Zhendong and Li, Zhe and Shao, Mingqi and Shi, Dachuan and Yuan, Zehuan and Yuan, Chun}, 48 | journal={arXiv preprint arXiv:2205.01529}, 49 | year={2022} 50 | } 51 | ``` 52 | 53 | ## Acknowledgement 54 | 
# Cityscapes pipelines with 512x512 training crops; everything not set here
# comes from the base dataset config.
_base_ = './cityscapes.py'

img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}
crop_size = (512, 512)

train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'LoadAnnotations'},
    {'type': 'Resize', 'img_scale': (2048, 1024), 'ratio_range': (0.5, 2.0)},
    {'type': 'RandomCrop', 'crop_size': crop_size, 'cat_max_ratio': 0.75},
    {'type': 'RandomFlip', 'prob': 0.5},
    {'type': 'PhotoMetricDistortion'},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'Pad', 'size': crop_size, 'pad_val': 0, 'seg_pad_val': 255},
    {'type': 'DefaultFormatBundle'},
    {'type': 'Collect', 'keys': ['img', 'gt_semantic_seg']},
]

test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'MultiScaleFlipAug',
        'img_scale': (2048, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        'flip': False,
        'transforms': [
            {'type': 'Resize', 'keep_ratio': True},
            {'type': 'RandomFlip'},
            {'type': 'Normalize', **img_norm_cfg},
            {'type': 'ImageToTensor', 'keys': ['img']},
            {'type': 'Collect', 'keys': ['img']},
        ],
    },
]

# Validation and test share the same pipeline object.
data = {
    'train': {'pipeline': train_pipeline},
    'val': {'pipeline': test_pipeline},
    'test': {'pipeline': test_pipeline},
}
# MGD distillation config: PSPNet R-101 teacher checkpoint distilled into the
# DeepLabV3 R-18 student described by the base config.
_base_ = ['../../deeplabv3/deeplabv3_r18-d8_512x512_40k_cityscapes.py']

# model settings
find_unused_parameters = True
alpha_mgd = 0.00002  # weight applied to the distillation loss
lambda_mgd = 0.75  # masked ratio used by the feature loss

distiller = {
    'type': 'SegmentationDistiller',
    # Written as adjacent literals only to keep the line short; the value is
    # the single URL of the teacher checkpoint.
    'teacher_pretrained': (
        'https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/'
        'pspnet_r101-d8_512x1024_40k_cityscapes/'
        'pspnet_r101-d8_512x1024_40k_cityscapes_20200604_232751-467e7cf4.pth'),
    'init_student': False,
    'use_logit': True,
    'distill_cfg': [
        {
            'methods': [
                {
                    'type': 'FeatureLoss',
                    'name': 'loss_mgd_fea',
                    'student_channels': 512,
                    'teacher_channels': 2048,
                    'alpha_mgd': alpha_mgd,
                    'lambda_mgd': lambda_mgd,
                },
            ],
        },
    ],
}

student_cfg = 'configs/deeplabv3/deeplabv3_r18-d8_512x512_40k_cityscapes.py'
teacher_cfg = 'configs/pspnet/pspnet_r101-d8_512x1024_40k_cityscapes.py'
'https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r101-d8_512x1024_40k_cityscapes/pspnet_r101-d8_512x1024_40k_cityscapes_20200604_232751-467e7cf4.pth', 11 | init_student = False, 12 | use_logit = True, 13 | distill_cfg = [ dict(methods=[dict(type='FeatureLoss', 14 | name='loss_mgd_fea', 15 | student_channels = 512, 16 | teacher_channels = 2048, 17 | alpha_mgd=alpha_mgd, 18 | lambda_mgd=lambda_mgd, 19 | ) 20 | ] 21 | ), 22 | ] 23 | ) 24 | 25 | student_cfg = 'configs/pspnet/pspnet_r18-d8_512x512_40k_cityscapes.py' 26 | teacher_cfg = 'configs/pspnet/pspnet_r101-d8_512x1024_40k_cityscapes.py' 27 | -------------------------------------------------------------------------------- /seg/configs/pspnet/pspnet_r18-d8_512x512_40k_cityscapes.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/pspnet_r50-d8.py', '../_base_/datasets/cityscapes_512x512.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' 4 | ] 5 | model = dict( 6 | pretrained='open-mmlab://resnet18_v1c', 7 | backbone=dict(depth=18), 8 | decode_head=dict( 9 | in_channels=512, 10 | channels=128, 11 | ), 12 | auxiliary_head=dict(in_channels=256, channels=64)) 13 | -------------------------------------------------------------------------------- /seg/mmseg/apis/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
def init_random_seed(seed=None, device='cuda'):
    """Pick the random seed for this run.

    An explicit ``seed`` is returned untouched. Otherwise a fresh seed is
    drawn and, when more than one process is running, rank 0's value is
    broadcast so every rank ends up with the same seed (see
    https://github.com/open-mmlab/mmdetection/issues/6339).

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): Device on which the broadcast tensor is created.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    rank, world_size = get_dist_info()
    # The draw happens before the world-size check, exactly as before, so
    # NumPy's global RNG state advances in every case.
    seed = np.random.randint(2**31)
    if world_size == 1:
        return seed

    # Only rank 0 contributes its draw; the other ranks send a placeholder
    # that the broadcast overwrites.
    payload = seed if rank == 0 else 0
    random_num = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()


def set_random_seed(seed, deterministic=False):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def train_segmentor(model,
                    dataset,
                    cfg,
                    distributed=False,
                    validate=False,
                    timestamp=None,
                    meta=None):
    """Build loaders, runner and hooks for ``model`` and run training.

    Args:
        model: Segmentor or distiller to train.
        dataset: One training dataset, or a list/tuple of them.
        cfg: Config object driving data, optimizer, runner and hooks.
        distributed (bool): Wrap the model for distributed training when
            True, otherwise for data-parallel training.
        validate (bool): Register an evaluation hook on ``cfg.data.val``.
        timestamp (str, optional): Stored on the runner as ``timestamp``.
        meta (dict, optional): Passed through to the runner.
    """
    logger = get_root_logger(cfg.log_level)

    # One training dataloader per dataset; a lone dataset is wrapped first.
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]
    data_loaders = []
    for ds in dataset:
        loader = build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            drop_last=True)
        data_loaders.append(loader)

    # Move the model to GPU and wrap it for the chosen parallel mode.
    if not distributed:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
    else:
        # Forwarded to torch.nn.parallel.DistributedDataParallel.
        unused_ok = cfg.get('find_unused_parameters', False)
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=unused_ok)

    # With a distiller configured, the optimizer is built from the wrapped
    # model's `base_parameters()` instead of the whole model.
    if cfg.get('distiller', None) is None:
        optim_target = model
    else:
        optim_target = model.module.base_parameters()
    optimizer = build_optimizer(optim_target, cfg.optimizer)

    if cfg.get('runner') is None:
        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner_defaults = dict(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        work_dir=cfg.work_dir,
        logger=logger,
        meta=meta)
    runner = build_runner(cfg.runner, default_args=runner_defaults)

    # Standard training hooks.
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))

    # Workaround so the .log and .log.json filenames match.
    runner.timestamp = timestamp

    # Evaluation hook.
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_loader = build_dataloader(
            val_dataset,
            samples_per_gpu=1,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        hook_cls = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(hook_cls(val_loader, **eval_cfg), priority='LOW')

    # User-defined hooks.
    if cfg.get('custom_hooks', None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for raw_hook_cfg in cfg.custom_hooks:
            assert isinstance(raw_hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(raw_hook_cfg)}'
            # Work on a copy so popping `priority` leaves the config intact.
            hook_kwargs = raw_hook_cfg.copy()
            hook_priority = hook_kwargs.pop('priority', 'NORMAL')
            runner.register_hook(
                build_from_cfg(hook_kwargs, HOOKS), priority=hook_priority)

    # Resuming takes precedence over plain weight loading.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
def build(cfg, registry, default_args=None):
    """Build a module.

    Args:
        cfg (dict, list[dict]): The config of modules, it is either a dict
            or a list of configs.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module. A list of configs is built item by
        item and wrapped in ``nn.Sequential``.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return nn.Sequential(*modules)
    return build_from_cfg(cfg, registry, default_args)


def build_distill_loss(cfg):
    """Build distill loss."""
    return build(cfg, DISTILL_LOSSES)


def build_distiller(cfg, teacher_cfg=None, student_cfg=None, train_cfg=None,
                    test_cfg=None):
    """Build distiller.

    Args:
        cfg (dict): Distiller config, looked up in the ``DISTILLER`` registry.
        teacher_cfg: Teacher config, forwarded as a default build argument.
        student_cfg: Student config, forwarded as a default build argument.
        train_cfg: Deprecated; triggers a ``UserWarning`` when not None.
        test_cfg: Deprecated; triggers a ``UserWarning`` when not None.

    Returns:
        The distiller built from ``cfg``.
    """
    # Imported here because this module's import block does not import
    # `warnings`; without it the deprecation path raised NameError.
    import warnings

    if train_cfg is not None or test_cfg is not None:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '
    return build(cfg, DISTILLER,
                 dict(teacher_cfg=teacher_cfg, student_cfg=student_cfg))
@DISTILLER.register_module()
class SegmentationDistiller(BaseSegmentor):
    """Distiller that trains a student segmentor with a teacher segmentor.

    It holds a teacher model (always run in eval mode, without gradients),
    a student model, and the distillation losses built from ``distill_cfg``.
    Inference methods delegate to the student.

    Args:
        teacher_cfg: Config of the teacher; ``teacher_cfg.model`` is built.
        student_cfg: Config of the student; ``student_cfg.model`` is built.
        distill_cfg (list, optional): Items whose ``methods`` list holds the
            distillation-loss configs (each with a ``name``). ``None`` means
            no feature losses are built.
        teacher_pretrained (str, optional): Checkpoint loaded into the teacher.
        init_student (bool): If True, load the teacher checkpoint's
            non-backbone weights into the student.
        use_logit (bool): If True, add a logit distillation loss.
    """

    def __init__(self,
                 teacher_cfg,
                 student_cfg,
                 distill_cfg=None,
                 teacher_pretrained=None,
                 init_student=False,
                 use_logit=True):

        super(SegmentationDistiller, self).__init__()

        self.teacher = build_segmentor(
            teacher_cfg.model,
            train_cfg=teacher_cfg.get('train_cfg'),
            test_cfg=teacher_cfg.get('test_cfg'))
        self.init_weights_teacher(teacher_pretrained)
        self.teacher.eval()

        self.use_logit = use_logit
        self.student = build_segmentor(
            student_cfg.model,
            train_cfg=student_cfg.get('train_cfg'),
            test_cfg=student_cfg.get('test_cfg'))
        self.student.init_weights()
        if init_student:
            t_checkpoint = _load_checkpoint(teacher_pretrained)
            # Reuse every teacher weight except the backbone's.
            state_dict = OrderedDict(
                (name, v) for name, v in t_checkpoint['state_dict'].items()
                if not name.startswith('backbone.'))
            load_state_dict(self.student, state_dict)

        self.distill_losses = nn.ModuleDict()
        self.distill_cfg = distill_cfg
        # `distill_cfg` defaults to None; treat that as "no feature losses"
        # instead of failing while iterating None.
        for item_loc in distill_cfg or []:
            for item_loss in item_loc.methods:
                loss_name = item_loss.name
                self.distill_losses[loss_name] = build_distill_loss(item_loss)

    def base_parameters(self):
        """Return the student and the distillation losses as a ModuleList."""
        return nn.ModuleList([self.student, self.distill_losses])

    @property
    def with_neck(self):
        """bool: whether the student segmentor has neck"""
        return hasattr(self.student, 'neck') and self.student.neck is not None

    @property
    def with_auxiliary_head(self):
        """bool: whether the student segmentor has auxiliary head"""
        return hasattr(self.student,
                       'auxiliary_head') and self.student.auxiliary_head is not None

    @property
    def with_decode_head(self):
        """bool: whether the student segmentor has decode head"""
        return hasattr(self.student, 'decode_head') and self.student.decode_head is not None

    def init_weights_teacher(self, path=None):
        """Load the pretrained weights into the teacher segmentor.

        Args:
            path (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        load_checkpoint(self.teacher, path, map_location='cpu')

    def forward_train(self, img, img_metas, gt_semantic_seg):
        """Compute the student's losses plus the distillation losses.

        Args:
            img (Tensor): Input images of shape (N, C, H, W).
                Typically these should be mean centered and std scaled.
            img_metas: Image meta information, forwarded to the decode heads.
            gt_semantic_seg: Ground-truth segmentation, forwarded to the
                student's heads.

        Returns:
            dict[str, Tensor]: A dictionary of loss components (student's
            losses and distiller's losses).
        """
        # The teacher only provides targets: eval mode, no gradients.
        with torch.no_grad():
            self.teacher.eval()
            fea_t = self.teacher.extract_feat(img)
            if self.use_logit:
                logit_t = self.teacher._decode_head_forward_test(fea_t, img_metas)

        student_feat = self.student.extract_feat(img)
        logit_s = self.student._decode_head_forward_test(student_feat, img_metas)
        losses = self.student.decode_head.losses(logit_s, gt_semantic_seg)

        student_loss = dict()
        student_loss.update(add_prefix(losses, 'decode'))

        if self.student.with_auxiliary_head:
            loss_aux = self.student._auxiliary_head_forward_train(
                student_feat, img_metas, gt_semantic_seg)
            student_loss.update(loss_aux)

        # Feature distillation on the last feature map of both models. It is
        # skipped when no loss of that name was configured, so logit-only
        # distillation works.
        loss_name = 'loss_mgd_fea'
        if loss_name in self.distill_losses:
            student_loss[loss_name] = self.distill_losses[loss_name](
                student_feat[-1], fea_t[-1].detach())

        if self.use_logit:
            temperature = 4
            logit_weight = 3
            N, C, H, W = logit_s.shape
            # Softmax runs over the flattened spatial positions of each
            # (sample, class) map.
            soft_t = F.softmax(logit_t.view(-1, W * H) / temperature, dim=1)
            log_t = F.log_softmax(logit_t.view(-1, W * H) / temperature, dim=1)
            log_s = F.log_softmax(logit_s.view(-1, W * H) / temperature, dim=1)
            loss = torch.sum(soft_t * log_t - soft_t * log_s) * (temperature**2)

            student_loss['loss_logit'] = logit_weight * loss / (C * N)

        return student_loss

    def simple_test(self, img, img_metas, **kwargs):
        """Run the student's ``simple_test``."""
        return self.student.simple_test(img, img_metas, **kwargs)

    def aug_test(self, imgs, img_metas, **kwargs):
        """Run the student's ``aug_test``."""
        return self.student.aug_test(imgs, img_metas, **kwargs)

    def extract_feat(self, imgs):
        """Extract features from images."""
        return self.student.extract_feat(imgs)

    def encode_decode(self, img, img_metas):
        """Run the student's ``encode_decode``."""
        return self.student.encode_decode(img, img_metas)
@DISTILL_LOSSES.register_module()
class FeatureLoss(nn.Module):

    """PyTorch version of `Masked Generative Distillation`

    Args:
        student_channels(int): Number of channels in the student's feature map.
        teacher_channels(int): Number of channels in the teacher's feature map.
        name (str): the loss name of the layer
        alpha_mgd (float, optional): Weight of dis_loss. Defaults to 0.00002
        lambda_mgd (float, optional): masked ratio. Defaults to 0.75
    """
    def __init__(self,
                 student_channels,
                 teacher_channels,
                 name,
                 alpha_mgd=0.00002,
                 lambda_mgd=0.75,
                 ):
        super(FeatureLoss, self).__init__()
        self.alpha_mgd = alpha_mgd
        self.lambda_mgd = lambda_mgd
        self.name = name

        # 1x1 conv that maps student channels onto teacher channels; only
        # needed when the two differ.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None

        # Generation block applied to the masked student feature.
        self.generation = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(teacher_channels, teacher_channels, kernel_size=3, padding=1))

    def forward(self,
                preds_S,
                preds_T):
        """Forward function.

        Args:
            preds_S(Tensor): Bs*C*H*W, student's feature map
            preds_T(Tensor): Bs*C*H*W, teacher's feature map

        Returns:
            Tensor: The distillation loss scaled by ``alpha_mgd``.
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:]

        if self.align is not None:
            preds_S = self.align(preds_S)

        loss = self.get_dis_loss(preds_S, preds_T)*self.alpha_mgd

        return loss

    def get_dis_loss(self, preds_S, preds_T):
        """Mask the student feature, regenerate it and compare to the teacher.

        Args:
            preds_S(Tensor): Bs*C*H*W, channel-aligned student feature map
            preds_T(Tensor): Bs*C*H*W, teacher's feature map

        Returns:
            Tensor: Sum of squared errors divided by the batch size.
        """
        N, C, H, W = preds_T.shape

        # Sample the mask directly on the feature's device, avoiding a
        # host-to-device copy on every iteration.
        rand = torch.rand((N, 1, H, W), device=preds_S.device)
        # Positions whose sample exceeds `1 - lambda_mgd` are zeroed; the mask
        # is shared across channels. A comparison plus cast replaces
        # `torch.where(cond, 0, 1)`, whose scalar-only form is rejected by
        # older PyTorch releases.
        keep_mask = (rand <= 1 - self.lambda_mgd).to(preds_S.dtype)

        masked_fea = torch.mul(preds_S, keep_mask)
        new_fea = self.generation(masked_fea)

        dis_loss = F.mse_loss(new_fea, preds_T, reduction='sum') / N

        return dis_loss
def parse_args():
    """Parse the training script's command line.

    Besides parsing, this exports ``LOCAL_RANK`` from ``--local_rank`` when
    the variable is not already set, and folds the deprecated ``--options``
    into ``cfg_options``.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    options_help = (
        "--options is deprecated in favor of --cfg_options' and it will "
        'not be supported in version v0.22.0. Override some settings in the '
        'used config, the key-value pair in xxx=yyy format will be merged '
        'into config file. If the value to be overwritten is a list, it '
        'should be like key="[a,b]" or key=a,b It also allows nested '
        'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
        'marks are necessary and that no white space is allowed.')
    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')

    parser = argparse.ArgumentParser(description='Train a segmentor')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--load-from', help='the checkpoint file to load weights from')
    parser.add_argument(
        '--resume-from', help='the checkpoint file to resume from')
    parser.add_argument(
        '--no-validate',
        action='store_true',
        help='whether not to evaluate the checkpoint during training')

    # --gpus and --gpu-ids cannot be combined.
    gpu_group = parser.add_mutually_exclusive_group()
    gpu_group.add_argument(
        '--gpus',
        type=int,
        help='number of gpus to use '
        '(only applicable to non-distributed training)')
    gpu_group.add_argument(
        '--gpu-ids',
        type=int,
        nargs='+',
        help='ids of gpus to use '
        '(only applicable to non-distributed training)')

    parser.add_argument('--seed', type=int, default=None, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--options', nargs='+', action=DictAction, help=options_help)
    parser.add_argument(
        '--cfg-options', nargs='+', action=DictAction, help=cfg_options_help)
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)

    args = parser.parse_args()
    # An existing LOCAL_RANK in the environment wins over the CLI value.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))

    if args.options:
        if args.cfg_options:
            raise ValueError(
                '--options and --cfg-options cannot be both '
                'specified, --options is deprecated in favor of --cfg-options. '
                '--options will not be supported in version v0.22.0.')
        warnings.warn('--options is deprecated in favor of --cfg-options. '
                      '--options will not be supported in version v0.22.0.')
        args.cfg_options = args.options

    return args
def main():
    """Entry point: load the config, build the model and launch training.

    Steps, in order: parse the CLI, load and patch the config, set up the
    (optional) distributed environment, create the work dir and logger,
    seed the RNGs, build either a plain segmentor or a distiller, build the
    datasets and call ``train_segmentor``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    # CLI checkpoint options override the values from the config file.
    if args.load_from is not None:
        cfg.load_from = args.load_from
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    # Explicit --gpu-ids wins; otherwise use --gpus (one GPU by default).
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == 'none':
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)
        # gpu_ids is used to calculate iter when resuming checkpoint
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
    dash_line = '-' * 60 + '\n'
    logger.info('Environment info:\n' + dash_line + env_info + '\n' +
                dash_line)
    meta['env_info'] = env_info

    # log some basic info
    logger.info(f'Distributed training: {distributed}')
    logger.info(f'Config:\n{cfg.pretty_text}')

    # set random seeds
    seed = init_random_seed(args.seed)
    logger.info(f'Set random seed to {seed}, '
                f'deterministic: {args.deterministic}')
    set_random_seed(seed, deterministic=args.deterministic)
    cfg.seed = seed
    meta['seed'] = seed
    meta['exp_name'] = osp.basename(args.config)

    # A `distiller` section switches from plain training to distillation;
    # the teacher and student are then loaded from their own config files.
    distiller_cfg = cfg.get('distiller',None)
    if distiller_cfg is None:
        model = build_segmentor(
            cfg.model,
            train_cfg=cfg.get('train_cfg'),
            test_cfg=cfg.get('test_cfg'))
        model.init_weights()
    else:
        teacher_cfg = Config.fromfile(cfg.teacher_cfg)
        student_cfg = Config.fromfile(cfg.student_cfg)
        model = build_distiller(cfg.distiller,teacher_cfg,student_cfg,
                                train_cfg=student_cfg.get('train_cfg'),
                                test_cfg=student_cfg.get('test_cfg'))

    # SyncBN is not supported for DP
    if not distributed:
        warnings.warn(
            'SyncBN is only supported with DDP. To be compatible with DP, '
            'we convert SyncBN to BN. Please use dist_train.sh which can '
            'avoid this error.')
        model = revert_sync_batchnorm(model)

    logger.info(model)

    datasets = [build_dataset(cfg.data.train)]
    # A two-stage workflow also needs the validation set, built with the
    # training pipeline.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        val_dataset.pipeline = cfg.data.train.pipeline
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmseg version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
            PALETTE=datasets[0].PALETTE)
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    # passing checkpoint meta for saving best checkpoint
    # NOTE(review): as laid out here this also runs when
    # `cfg.checkpoint_config` is None, where the attribute access would
    # fail -- confirm the intended nesting against the original file.
    meta.update(cfg.checkpoint_config.meta)
    train_segmentor(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta)


if __name__ == '__main__':
    main()