├── LICENSE ├── README.md ├── configs └── replknet31_base_224_pt1k.py ├── main_benchmark.py ├── main_imagenet_test.py ├── main_imagenet_train.py └── model_replknet.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright (c) 2014-2022 Megvii Inc. All rights reserved. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RepLKNet (CVPR 2022) 2 | 3 | This is the official MegEngine implementation of **RepLKNet**, from the following CVPR-2022 paper: 4 | 5 | Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs. 6 | 7 | The paper is released on arXiv: https://arxiv.org/abs/2203.06717. 8 | 9 | If you find the paper or this repository helpful, please consider citing 10 | 11 | @article{replknet, 12 | title={Scaling Up Your Kernels to 31x31: Revisiting Large Kernel Design in CNNs}, 13 | author={Ding, Xiaohan and Zhang, Xiangyu and Zhou, Yizhuang and Han, Jungong and Ding, Guiguang and Sun, Jian}, 14 | journal={arXiv preprint arXiv:2203.06717}, 15 | year={2022} 16 | } 17 | 18 | ## Official PyTorch implementation 19 | 20 | Our official PyTorch repository https://github.com/DingXiaoH/RepLKNet-pytorch contains 21 | 22 | 1. All of the pretrained weights and ImageNet-1K weights. 23 | 2. All of the Cityscapes/ADE20K/COCO weights and code. 24 | 3. An example of using our efficient conv implementation with PyTorch. 25 | 4. Training script and reproducible commands. 26 | 5. A script to visualize the Effective Receptive Field and instructions on obtaining the shape bias. 27 | 28 | ## Other implementations 29 | 30 | | framework | link | 31 | |:---:|:---:| 32 | |Tensorflow|https://github.com/shkarupa-alex/tfreplknet| 33 | | ... | | 34 | 35 | More implementations are welcomed. 36 | 37 | ## Catalog 38 | - [x] Model code 39 | - [x] MegEngine pretrained models 40 | - [x] MegEngine training code 41 | - [ ] MegEngine downstream models 42 | - [ ] MegEngine downstream code 43 | 44 | 45 | 46 | ## Results and Pre-trained Models 47 | 48 | ### ImageNet-1K Models 49 | 50 | | name | resolution |acc | #params | FLOPs | download | 51 | |:---:|:---:|:---:|:---:| :---:|:---:| 52 | | RepLKNet-31B | 224x224 | 83.58 | 79M | 15.3G | [0de394](https://data.megengine.org.cn/research/replknet/replknet31_base_224_pt1k_basecls.pkl) | 53 | 54 | 55 | ### ImageNet-22K Models 56 | 57 | | name | resolution |acc | #params | FLOPs | 22K model | 1K model | 58 | |:---:|:---:|:---:|:---:| :---:| :---:|:---:| 59 | 60 | 61 | 62 | ### MegData-73M Models 63 | | name | resolution |acc@1 | #params | FLOPs | MegData-73M model | 1K model | 64 | |:---:|:---:|:---:|:---:| :---:| :---:|:---:| 65 | 66 | 67 | ## Installation of MegEngine 68 | ```bash 69 | pip3 install megengine -f https://megengine.org.cn/whl/mge.html --user 70 | ``` 71 | For more details, please check the [HomePage](https://github.com/MegEngine/MegEngine). 72 | 73 | ## Installation of BaseCls 74 | 75 | [BaseCls](https://github.com/megvii-research/basecls) is an image classification framework built upon MegEngine. 76 | We ultilize BaseCls for ImageNet pretraining and finetuning. 77 | 78 | ```bash 79 | pip3 install basecls --user 80 | ``` 81 | 82 | Training and evaluation are configured through file. All default configurations are listed [here](https://github.com/megvii-research/basecls/blob/main/basecls/configs/base_cfg.py). 83 | 84 | ## Evaluation 85 | ```bash 86 | ./main_imagenet_test.py -f configs/replknet31_base_224_pt1k.py -w [weights] batch_size 64 data.val_path /path/to/imagenet/val 87 | ``` 88 | 89 | ## Training 90 | ```bash 91 | ./main_imagenet_train.py -f configs/replknet31_base_224_pt1k.py data.train_path /path/to/imagenet/train data.val_path /path/to/imagenet/val 92 | ``` 93 | 94 | ## Benchmark large depth-wise kernels 95 | 96 | We can compare the kernel speed of MegEngine against PyTorch. A minimum version of megengine 1.8.2 is required for 97 | optimized large depth-wise convolutions. 98 | 99 | ```bash 100 | ./main_benchmark.py 101 | ``` 102 | 103 | ## License 104 | This project is released under the MIT license. Please see the [LICENSE](LICENSE) file for more information. 105 | -------------------------------------------------------------------------------- /configs/replknet31_base_224_pt1k.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2014-2022 Megvii Inc. All rights reserved. 2 | from basecls.layers import NORM_TYPES 3 | 4 | from basecls.configs.base_cfg import BaseConfig 5 | 6 | 7 | _cfg = dict( 8 | batch_size=128, 9 | output_dir=None, 10 | model=dict( 11 | name="replknet31_base", 12 | drop_path_rate=0.5, 13 | ), 14 | bn=dict( 15 | precise_every_n_epoch=5, 16 | ), 17 | ## MegEngine use BGR as default colorspace 18 | # preprocess=dict( 19 | # img_color_space="RGB", 20 | # # flip from BGR mean & std to RGB 21 | # img_mean=[103.530, 116.280, 123.675][::-1], 22 | # img_std=[57.375, 57.12, 58.395][::-1], 23 | # ), 24 | test=dict( 25 | crop_pct=0.875, 26 | ), 27 | eval_every_n_epoch=5, 28 | loss=dict( 29 | label_smooth=0.1, 30 | ), 31 | augments=dict( 32 | name="RandAugment", 33 | rand_aug=dict( 34 | magnitude=9, 35 | ), 36 | resize=dict( 37 | interpolation="bicubic", 38 | ), 39 | rand_erase=dict( 40 | prob=0.25, 41 | mode="pixel", 42 | ), 43 | mixup=dict( 44 | mixup_alpha=0.8, 45 | cutmix_alpha=1.0, 46 | ), 47 | ), 48 | data=dict( 49 | train_path="/path/to/imagenet/train", 50 | val_path="/path/to/imagenet/val", 51 | num_workers=10, 52 | ), 53 | solver=dict( 54 | optimizer="adamw", 55 | # `basic_lr` is the learning rate for a single GPU 56 | # 4e-3 per 2048 batch size == 2.5e-4 per 128 batch size 57 | basic_lr=2.5e-4, 58 | lr_min_factor=1e-3, 59 | weight_decay=( 60 | (0, "bias"), 61 | (0, NORM_TYPES), 62 | 0.05, 63 | ), 64 | max_epoch=300, 65 | warmup_epochs=10, 66 | warmup_factor=0.1, 67 | lr_schedule="cosine", 68 | ), 69 | model_ema=dict( 70 | enabled=True, 71 | momentum=0.9992, 72 | update_period=8, 73 | ), 74 | fastrun=False, 75 | dtr=False, 76 | amp=dict( 77 | enabled=True, 78 | dynamic_scale=True, 79 | ), 80 | save_every_n_epoch=50, 81 | ) 82 | 83 | 84 | class Cfg(BaseConfig): 85 | def __init__(self, values_or_file=None, **kwargs): 86 | super().__init__(_cfg) 87 | self.merge(values_or_file, **kwargs) 88 | -------------------------------------------------------------------------------- /main_benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2014-2022 Megvii Inc. All rights reserved. 3 | """Benchmark depth-wise large convolution for megengine and pytorch""" 4 | import os 5 | import time 6 | 7 | import megengine 8 | import megengine.functional as F 9 | import torch 10 | 11 | torch.backends.cudnn.enabled = True 12 | torch.backends.cudnn.benchmark = True 13 | 14 | assert megengine.__version__ >= "1.8.2", "please update megengine via " \ 15 | "`pip3 megengine -f https://megengine.org.cn/whl/mge.html -U --user`" 16 | 17 | 18 | def benchmark_megengine(batch_size, resolution, channels, depth, kernel_size, niters=10): 19 | input = F.ones([batch_size, channels, resolution, resolution]) * 1e-3 20 | weight = F.ones([channels, 1, 1, kernel_size, kernel_size]) * 1e-3 21 | diff = [] 22 | for i in range(niters): 23 | x = input 24 | megengine._full_sync() 25 | t = time.perf_counter() 26 | for _ in range(depth): 27 | x = F.conv2d(x, weight, bias=None, padding=kernel_size // 2, groups=channels) 28 | megengine._full_sync() 29 | diff.append((time.perf_counter() - t) * 1000) 30 | diff = sum(sorted(diff)[1:-2]) / (niters - 3) 31 | print(f"benchmark_megeg\tB{batch_size},R{resolution},C{channels},D{depth},K{kernel_size}\t{diff:.3f} ms") 32 | return diff 33 | 34 | 35 | @torch.no_grad() 36 | def benchmark_torch(batch_size, resolution, channels, depth, kernel_size, niters=10): 37 | input = torch.randn(batch_size, channels, resolution, resolution).cuda() * 1e-3 38 | weight = torch.randn(channels, 1, kernel_size, kernel_size).cuda() * 1e-3 39 | diff = [] 40 | for i in range(niters): 41 | x = input.clone() 42 | torch.cuda.synchronize() 43 | t = time.perf_counter() 44 | for _ in range(depth): 45 | x = torch.nn.functional.conv2d(x, weight, bias=None, padding=kernel_size // 2, groups=channels) 46 | torch.cuda.synchronize() 47 | diff.append((time.perf_counter() - t) * 1000) 48 | diff = sum(sorted(diff)[1:-2]) / (niters - 3) 49 | print(f"benchmark_torch\tB{batch_size},R{resolution},C{channels},D{depth},K{kernel_size}\t{diff:.3f} ms") 50 | return diff 51 | 52 | 53 | if __name__ == "__main__": 54 | for resolution in (16, 32): 55 | for kernel_size in (7, 17, 27): 56 | benchmark_torch(64, resolution, 256, 12, kernel_size) 57 | benchmark_megengine(64, resolution, 256, 12, kernel_size) 58 | -------------------------------------------------------------------------------- /main_imagenet_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2014-2022 Megvii Inc. All rights reserved. 3 | """ 4 | ImageNet testing script modifiled from BaseCls 5 | https://github.com/megvii-research/basecls/blob/main/basecls/tools/cls_test.py 6 | """ 7 | import argparse 8 | import importlib 9 | import multiprocessing as mp 10 | import os 11 | import sys 12 | 13 | import megengine as mge 14 | import megengine.distributed as dist 15 | from basecore.config import ConfigDict 16 | from loguru import logger 17 | 18 | from basecls.engine import ClsTester 19 | from basecls.models import build_model, load_model 20 | from basecls.utils import default_logging, registers, set_nccl_env, set_num_threads, setup_logger 21 | 22 | from model_replknet import RepLKNet 23 | 24 | 25 | def make_parser() -> argparse.ArgumentParser: 26 | """Build args parser for testing script. 27 | 28 | Returns: 29 | The args parser. 30 | """ 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("-f", "--file", type=str, help="testing process description file") 33 | parser.add_argument("-w", "--weight_file", default=None, type=str, help="weight file") 34 | parser.add_argument("opts", default=None, nargs=argparse.REMAINDER, help="other options") 35 | return parser 36 | 37 | 38 | @logger.catch 39 | def worker(args: argparse.Namespace): 40 | """Worker function for testing script. 41 | 42 | Args: 43 | args: args for testing script. 44 | """ 45 | logger.info(f"Init process group for gpu{dist.get_rank()} done") 46 | 47 | sys.path.append(os.path.dirname(args.file)) 48 | module_name = os.path.splitext(os.path.basename(args.file))[0] 49 | current_network = importlib.import_module(module_name) 50 | cfg = current_network.Cfg() 51 | if cfg.output_dir is None: 52 | cfg.output_dir = f"./logs_{module_name}" 53 | cfg.output_dir = os.path.abspath(cfg.output_dir) 54 | 55 | if args.weight_file: 56 | cfg.weights = args.weight_file 57 | else: 58 | cfg.weights = os.path.join(cfg.output_dir, "latest.pkl") 59 | 60 | cfg.merge(args.opts) 61 | cfg.set_mode("freeze") 62 | 63 | if dist.get_rank() == 0 and not os.path.exists(cfg.output_dir): 64 | os.makedirs(cfg.output_dir) 65 | dist.group_barrier() 66 | 67 | setup_logger(cfg.output_dir, "test_log.txt", to_loguru=True) 68 | logger.info(f"args: {args}") 69 | 70 | if cfg.fastrun: 71 | logger.info("Using fastrun mode...") 72 | mge.functional.debug_param.set_execution_strategy("PROFILE") 73 | 74 | tester = build(cfg) 75 | tester.test() 76 | 77 | 78 | def build(cfg: ConfigDict): 79 | """Build function for testing script. 80 | 81 | Args: 82 | cfg: config for testing. 83 | 84 | Returns: 85 | A tester. 86 | """ 87 | model = build_model(cfg) 88 | load_model(model, cfg.weights) 89 | 90 | if isinstance(model, RepLKNet): 91 | model = RepLKNet.convert_to_deploy(model) 92 | 93 | default_logging(cfg, model) 94 | 95 | dataloader = registers.dataloaders.get(cfg.data.name).build(cfg, False) 96 | # FIXME: need atomic user_pop, maybe in MegEngine 1.5? 97 | # tester = BaseTester(model, dataloader, AccEvaluator()) 98 | return ClsTester(cfg, model, dataloader) 99 | 100 | 101 | def main(): 102 | """Main function for testing script.""" 103 | parser = make_parser() 104 | args = parser.parse_args() 105 | 106 | mp.set_start_method("spawn") 107 | 108 | set_nccl_env() 109 | set_num_threads() 110 | 111 | if not os.path.exists(args.file): 112 | raise ValueError("Description file does not exist") 113 | 114 | device_count = mge.device.get_device_count("gpu") 115 | 116 | if device_count == 0: 117 | logger.warning("No GPU was found, testing on CPU") 118 | worker(args) 119 | elif device_count > 1: 120 | mp_worker = dist.launcher(worker) 121 | mp_worker(args) 122 | else: 123 | worker(args) 124 | 125 | 126 | if __name__ == "__main__": 127 | main() 128 | -------------------------------------------------------------------------------- /main_imagenet_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2014-2022 Megvii Inc. All rights reserved. 3 | """ 4 | ImageNet training script modifiled from BaseCls 5 | https://github.com/megvii-research/basecls/blob/main/basecls/tools/cls_train.py 6 | """ 7 | import argparse 8 | import importlib 9 | import os 10 | import sys 11 | 12 | import megengine as mge 13 | import megengine.distributed as dist 14 | from basecore.config import ConfigDict 15 | from loguru import logger 16 | 17 | from basecls.models import build_model, load_model, sync_model 18 | from basecls.utils import registers, set_nccl_env, set_num_threads, setup_logger 19 | 20 | import model_replknet 21 | 22 | 23 | def default_parser() -> argparse.ArgumentParser: 24 | """Build args parser for training script. 25 | 26 | Returns: 27 | The args parser. 28 | """ 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("-f", "--file", type=str, help="training process description file") 31 | parser.add_argument( 32 | "--resume", action="store_true", help="resume training from saved checkpoint or not" 33 | ) 34 | parser.add_argument( 35 | "opts", 36 | default=None, 37 | help="Modify config options using the command-line", 38 | nargs=argparse.REMAINDER, 39 | ) 40 | return parser 41 | 42 | 43 | @logger.catch 44 | def worker(args: argparse.Namespace): 45 | """Worker function for training script. 46 | 47 | Args: 48 | args: args for training script. 49 | """ 50 | logger.info(f"Init process group for gpu{dist.get_rank()} done") 51 | 52 | sys.path.append(os.path.dirname(args.file)) 53 | module_name = os.path.splitext(os.path.basename(args.file))[0] 54 | current_network = importlib.import_module(module_name) 55 | cfg = current_network.Cfg() 56 | cfg.merge(args.opts) 57 | cfg.resume = args.resume 58 | if cfg.output_dir is None: 59 | cfg.output_dir = f"./logs_{module_name}" 60 | cfg.output_dir = os.path.abspath(cfg.output_dir) 61 | 62 | cfg.set_mode("freeze") 63 | 64 | if dist.get_rank() == 0 and not os.path.exists(cfg.output_dir): 65 | os.makedirs(cfg.output_dir) 66 | dist.group_barrier() 67 | 68 | setup_logger(cfg.output_dir, "train_log.txt", to_loguru=True) 69 | logger.info(f"args: {args}") 70 | 71 | if cfg.fastrun: 72 | logger.info("Using fastrun mode...") 73 | mge.functional.debug_param.set_execution_strategy("PROFILE") 74 | 75 | if cfg.dtr: 76 | logger.info("Enabling DTR...") 77 | mge.dtr.enable() 78 | 79 | trainer = build(cfg) 80 | trainer.train() 81 | 82 | 83 | def build(cfg: ConfigDict): 84 | """Build function for training script. 85 | 86 | Args: 87 | cfg: config for training. 88 | 89 | Returns: 90 | A trainer. 91 | """ 92 | model = build_model(cfg) 93 | if getattr(cfg, "weights", None) is not None: 94 | load_model(model, cfg.weights, strict=False) 95 | sync_model(model) 96 | model.train() 97 | 98 | logger.info(f"Using augments named {cfg.augments.name}") 99 | augments = registers.augments.get(cfg.augments.name).build(cfg) 100 | logger.info(f"Using dataloader named {cfg.data.name}") 101 | dataloader = registers.dataloaders.get(cfg.data.name).build(cfg, True, augments) 102 | logger.info(f"Using solver named {cfg.solver.name}") 103 | solver = registers.solvers.get(cfg.solver.name).build(cfg, model) 104 | logger.info(f"Using hooks named {cfg.hooks_name}") 105 | hooks = registers.hooks.get(cfg.hooks_name).build(cfg) 106 | 107 | logger.info(f"Using trainer named {cfg.trainer_name}") 108 | TrainerClass = registers.trainers.get(cfg.trainer_name) 109 | return TrainerClass(cfg, model, dataloader, solver, hooks=hooks) 110 | 111 | 112 | def main(): 113 | """Main function for training script.""" 114 | parser = default_parser() 115 | args = parser.parse_args() 116 | 117 | set_nccl_env() 118 | set_num_threads() 119 | 120 | device_count = mge.device.get_device_count("gpu") 121 | launcher = dist.launcher 122 | 123 | if not os.path.exists(args.file): 124 | raise ValueError("Description file does not exist") 125 | 126 | if device_count == 0: 127 | raise ValueError("Number of devices should be greater than 0") 128 | elif device_count > 1 or os.environ.get("RLAUNCH_REPLICA_TOTAL", 0) > 1: 129 | mp_worker = launcher(worker) 130 | mp_worker(args) 131 | else: 132 | worker(args) 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /model_replknet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2014-2022 Megvii Inc. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # 5 | # Unless required by applicable law or agreed to in writing, software 6 | # distributed under the License is distributed on an "AS IS" BASIS, 7 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 8 | import copy 9 | 10 | import megengine 11 | import megengine.functional as F 12 | import megengine.module as nn 13 | import numpy as np 14 | from basecls.layers import DropPath, init_weights 15 | from basecls.utils import registers 16 | 17 | 18 | def _fuse_prebn_conv1x1(bn, conv): 19 | module_output = copy.deepcopy(conv) 20 | module_output.bias = megengine.Parameter(np.zeros(module_output._infer_bias_shape(), dtype=np.float32)) 21 | assert conv.groups == 1 22 | kernel = conv.weight 23 | running_mean = bn.running_mean 24 | running_var = bn.running_var 25 | gamma = bn.weight 26 | beta = bn.bias 27 | eps = bn.eps 28 | std = F.sqrt(running_var + eps) 29 | t = (gamma / std).reshape(1, -1, 1, 1) 30 | module_output.weight[:] = kernel * t 31 | module_output.bias[:] = F.conv2d(beta - running_mean * gamma / std, kernel, conv.bias) 32 | return module_output 33 | 34 | 35 | def _fuse_conv_bn(conv, bn): 36 | module_output = copy.deepcopy(conv) 37 | module_output.bias = megengine.Parameter(np.zeros(module_output._infer_bias_shape(), dtype=np.float32)) 38 | # flatten then reshape in case of group conv 39 | kernel = F.flatten(conv.weight, end_axis=conv.weight.ndim - 4) 40 | running_mean = bn.running_mean 41 | running_var = bn.running_var 42 | gamma = bn.weight 43 | beta = bn.bias 44 | eps = bn.eps 45 | std = F.sqrt(running_var + eps) 46 | t = (gamma / std).reshape(-1, 1, 1, 1) 47 | module_output.weight[:] = (kernel * t).reshape(module_output.weight.shape) 48 | module_output.bias[:] = beta + ((conv.bias if conv.bias is not None else 0) - running_mean) * gamma / std 49 | return module_output 50 | 51 | 52 | class ConvBn2d(nn.ConvBn2d): 53 | def __init__(self, *args, **kwargs): 54 | bias = kwargs.pop("bias", False) and False 55 | super().__init__(*args, bias=bias, **kwargs) 56 | 57 | @classmethod 58 | def fuse_conv_bn(cls, module: nn.Module): 59 | module_output = module 60 | if isinstance(module, ConvBn2d): 61 | return _fuse_conv_bn(module.conv, module.bn) 62 | for name, child in module.named_children(): 63 | setattr(module_output, name, cls.fuse_conv_bn(child)) 64 | del module 65 | return module_output 66 | 67 | 68 | class LargeKernelReparam(nn.Module): 69 | def __init__(self, channels, kernel, small_kernels=()): 70 | super(LargeKernelReparam, self).__init__() 71 | self.dw_large = ConvBn2d(channels, channels, kernel, padding=kernel // 2, groups=channels) 72 | 73 | self.small_kernels = small_kernels 74 | for k in self.small_kernels: 75 | setattr(self, f"dw_small_{k}", ConvBn2d(channels, channels, k, padding=k // 2, groups=channels)) 76 | 77 | def forward(self, inp): 78 | outp = self.dw_large(inp) 79 | for k in self.small_kernels: 80 | outp += getattr(self, f"dw_small_{k}")(inp) 81 | return outp 82 | 83 | @classmethod 84 | def convert_to_deploy(cls, module: nn.Module): 85 | module_output = module 86 | if isinstance(module, LargeKernelReparam): 87 | module = ConvBn2d.fuse_conv_bn(module) 88 | module_output = copy.deepcopy(module.dw_large) 89 | kernel = module_output.kernel_size[0] 90 | for k in module.small_kernels: 91 | dw_small = getattr(module, f"dw_small_{k}") 92 | module_output.weight += F.pad(dw_small.weight, [[0, 0]] * 3 + [[(kernel - k) // 2] * 2] * 2) 93 | module_output.bias += dw_small.bias 94 | return module_output 95 | for name, child in module.named_children(): 96 | setattr(module_output, name, cls.convert_to_deploy(child)) 97 | del module 98 | return module_output 99 | 100 | 101 | class Mlp(nn.Module): 102 | def __init__(self, in_channels, hidden_channels=None, out_channels=None, act_layer=nn.GELU, drop=0.,): 103 | super().__init__() 104 | out_features = out_channels or in_channels 105 | hidden_features = hidden_channels or in_channels 106 | self.fc1 = ConvBn2d(in_channels, hidden_features, 1, stride=1, padding=0) 107 | self.act = act_layer() 108 | self.fc2 = ConvBn2d(hidden_features, out_features, 1, stride=1, padding=0) 109 | self.drop = nn.Dropout(drop) 110 | 111 | def forward(self, x): 112 | x = self.fc1(x) 113 | x = self.act(x) 114 | x = self.drop(x) 115 | x = self.fc2(x) 116 | x = self.drop(x) 117 | return x 118 | 119 | 120 | class RepLKBlock(nn.Module): 121 | 122 | def __init__(self, channels, kernel, small_kernels=(), dw_ratio=1.0, mlp_ratio=4.0, drop_path=0., activation=nn.ReLU): 123 | super().__init__() 124 | 125 | self.pre_bn = nn.BatchNorm2d(channels) 126 | self.pw1 = ConvBn2d(channels, int(channels * dw_ratio), 1, 1, 0) 127 | self.pw1_act = activation() 128 | self.dw = LargeKernelReparam(int(channels * dw_ratio), kernel, small_kernels=small_kernels) 129 | self.dw_act = activation() 130 | self.pw2 = ConvBn2d(int(channels * dw_ratio), channels, 1, 1, 0) 131 | 132 | self.premlp_bn = nn.BatchNorm2d(channels) 133 | self.mlp = Mlp(in_channels=channels, hidden_channels=int(channels * mlp_ratio)) 134 | 135 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 136 | 137 | def forward(self, x): 138 | y = self.pre_bn(x) 139 | y = self.pw1_act(self.pw1(y)) 140 | y = self.dw_act(self.dw(y)) 141 | y = self.pw2(y) 142 | x = x + self.drop_path(y) 143 | 144 | y = self.premlp_bn(x) 145 | y = self.mlp(y) 146 | x = x + self.drop_path(y) 147 | 148 | return x 149 | 150 | @classmethod 151 | def convert_to_deploy(cls, module: nn.Module): 152 | module_output = module 153 | if isinstance(module, RepLKBlock): 154 | LargeKernelReparam.convert_to_deploy(module) 155 | ConvBn2d.fuse_conv_bn(module) 156 | 157 | module.pre_bn, module.pw1 = nn.Identity(), _fuse_prebn_conv1x1(module.pre_bn, module.pw1) 158 | module.premlp_bn, module.mlp.fc1 = nn.Identity(), _fuse_prebn_conv1x1(module.premlp_bn, module.mlp.fc1) 159 | return module_output 160 | for name, child in module.named_children(): 161 | setattr(module_output, name, cls.convert_to_deploy(child)) 162 | del module 163 | return module_output 164 | 165 | 166 | class DownSample(nn.Sequential): 167 | def __init__(self, in_channels, out_channels, activation=nn.ReLU): 168 | super().__init__( 169 | ConvBn2d(in_channels, out_channels, 1), 170 | activation(), 171 | ConvBn2d(out_channels, out_channels, 3, stride=2, padding=1, groups=out_channels), 172 | activation(), 173 | ) 174 | 175 | 176 | class Stem(nn.Sequential): 177 | def __init__(self, in_channels, out_channels, activation=nn.ReLU): 178 | super().__init__( 179 | ConvBn2d(in_channels, out_channels, 3, stride=2, padding=1), 180 | activation(), 181 | ConvBn2d(out_channels, out_channels, 3, padding=1, groups=out_channels), 182 | activation(), 183 | ConvBn2d(out_channels, out_channels, 1), 184 | activation(), 185 | ConvBn2d(out_channels, out_channels, 3, stride=2, padding=1, groups=out_channels), 186 | activation(), 187 | ) 188 | 189 | class RepLKNet(nn.Module): 190 | 191 | def __init__( 192 | self, 193 | in_channels=3, 194 | depths=(2, 2, 18, 2), 195 | dims=(128, 256, 512, 1024), 196 | kernel_sizes=(31, 29, 27, 13), 197 | small_kernels=(5,), 198 | dw_ratio=1.0, 199 | mlp_ratio=4.0, 200 | num_classes=1000, 201 | drop_path_rate=0.5, 202 | ): 203 | super().__init__() 204 | 205 | self.stem = Stem(in_channels, dims[0]) 206 | # stochastic depth 207 | dpr = (x for x in np.linspace(0, drop_path_rate, sum(depths))) # stochastic depth decay rule 208 | 209 | self.blocks = [] 210 | 211 | for stage, (depth, dim, ksize) in enumerate(zip(depths, dims, kernel_sizes)): 212 | for _ in range(depth): 213 | self.blocks.append( 214 | RepLKBlock(dim, ksize, small_kernels=small_kernels, 215 | dw_ratio=dw_ratio, mlp_ratio=mlp_ratio, drop_path=next(dpr)) 216 | ) 217 | if stage < len(depths) - 1: 218 | self.blocks.append(DownSample(dim, dims[stage + 1])) 219 | 220 | self.norm = nn.BatchNorm2d(dims[-1]) 221 | self.avgpool = nn.AdaptiveAvgPool2d(1) 222 | self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() 223 | init_weights(self) 224 | 225 | def forward_features(self, x): 226 | x = self.stem(x) 227 | for blk in self.blocks: 228 | x = blk(x) 229 | x = self.norm(x) 230 | x = self.avgpool(x) 231 | x = F.flatten(x, 1) 232 | return x 233 | 234 | def forward(self, x): 235 | x = self.forward_features(x) 236 | x = self.head(x) 237 | return x 238 | 239 | @classmethod 240 | def convert_to_deploy(cls, module: nn.Module): 241 | module_output = module 242 | if isinstance(module, RepLKNet): 243 | RepLKBlock.convert_to_deploy(module) 244 | ConvBn2d.fuse_conv_bn(module) 245 | return module_output 246 | for name, child in module.named_children(): 247 | setattr(module_output, name, cls.convert_to_deploy(child)) 248 | del module 249 | return module_output 250 | 251 | 252 | @registers.models.register() 253 | def replknet31_base(**kwargs): 254 | kwargs.pop("head", None) 255 | return RepLKNet(dims=(128, 256, 512, 1024), dw_ratio=1.0, **kwargs) 256 | 257 | 258 | @registers.models.register() 259 | def replknet31_large(**kwargs): 260 | kwargs.pop("head", None) 261 | return RepLKNet(dims=(192, 384, 768, 1536), dw_ratio=1.0, **kwargs) 262 | 263 | 264 | @registers.models.register() 265 | def replknet_xlarge(**kwargs): 266 | kwargs.pop("head", None) 267 | return RepLKNet(dims=(256, 512, 1024, 2048), kernel_sizes=(27, 27, 27, 13), small_kernels=(), dw_ratio=1.5, **kwargs) 268 | --------------------------------------------------------------------------------