├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── dataset.py ├── default_config.yaml ├── demo.ipynb ├── eval.py ├── nfnets ├── __init__.py ├── model.py ├── optim.py └── pretrained.py ├── pretrained └── README.md ├── pyproject.toml ├── requirements.txt ├── setup.cfg └── train.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | __pycache__ 3 | .pytest_cache 4 | .vscode 5 | .ipynb_checkpoints 6 | checkpoints 7 | pretrained/*.npz 8 | pretrained/*.pth 9 | runs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NFNet Pytorch Implementation 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/benjs/nfnets_pytorch/blob/master/demo.ipynb) 4 | 5 | This repo contains pretrained NFNet models F0-F6 with high ImageNet accuracy from the paper *High-Performance Large-Scale Image Recognition Without Normalization*. 
The small models are as accurate as an EfficientNet-B7, but train 8.7 times faster. The large models set a new SOTA top-1 accuracy on ImageNet. 6 | 7 | | NFNet | F0 | F1 | F2 | F3 | F4 | F5 | F6+SAM | 8 | |:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:| 9 | | Top-1 accuracy Brock et al. | 83.6 | 84.7 | 85.1 | 85.7 | 85.9 | 86.0 | 86.5 | 10 | | Top-1 accuracy this implementation | 82.82 | 84.63 | 84.90 | 85.46 | 85.66 | 85.62 | TBD | 11 | 12 | All credits go to the authors of the [original paper](https://arxiv.org/abs/2102.06171). This repo is heavily inspired by their nice JAX implementation in the [official repository](https://github.com/deepmind/deepmind-research/blob/master/nfnets/). Visit their repo for citation information. 13 | 14 | ## Get started 15 | ``` 16 | git clone https://github.com/benjs/nfnets_pytorch.git 17 | pip3 install -r requirements.txt 18 | ``` 19 | or, if you don't need the eval and training scripts, 20 | ``` 21 | pip install git+https://github.com/benjs/nfnets_pytorch 22 | ``` 23 | Download pretrained weights from the [official repository](https://github.com/deepmind/deepmind-research/blob/master/nfnets/) and call 24 | 25 | ```python 26 | from nfnets import pretrained_nfnet 27 | model_F0 = pretrained_nfnet('pretrained/F0_haiku.npz') 28 | model_F1 = pretrained_nfnet('pretrained/F1_haiku.npz') 29 | # ... 30 | ``` 31 | 32 | The model variant is automatically derived from the parameter count in the pretrained weights file. 33 | 34 | ## Validate yourself 35 | ``` 36 | python3 eval.py --pretrained pretrained/F0_haiku.npz --dataset path/to/imagenet/valset/ 37 | ``` 38 | 39 | You can download the ImageNet validation set from the [ILSVRC2012 challenge site](http://www.image-net.org/challenges/LSVRC/2012/downloads.php#images) (after requesting access with, for instance, your .edu mail address) or from [AcademicTorrents](https://academictorrents.com/). 40 | 41 | ## Scaled weight standardization convolutions in your own model 42 | Simply replace all your `nn.Conv2d` with `WSConv2D` and all your `nn.ReLU` with `VPReLU` or `VPGELU` (variance-preserving ReLU/GELU). 43 | 44 | ```python 45 | import torch.nn as nn 46 | from nfnets import WSConv2D, VPReLU, VPGELU 47 | 48 | # Simply replace your nn.Conv2d layers 49 | class MyNet(nn.Module): 50 | def __init__(self): 51 | super(MyNet, self).__init__() 52 | 53 | self.activation = VPReLU(inplace=True) # or VPGELU 54 | self.conv0 = WSConv2D(in_channels=128, out_channels=256, kernel_size=1, ...) 55 | # ... 56 | 57 | def forward(self, x): 58 | out = self.activation(self.conv0(x)) 59 | # ... 60 | ``` 61 | 62 | ## SGD with adaptive gradient clipping in your own model 63 | Simply replace your `SGD` optimizer with `SGD_AGC`. 64 | ```python 65 | from nfnets import SGD_AGC 66 | 67 | optimizer = SGD_AGC( 68 | named_params=model.named_parameters(), # Pass named parameters 69 | lr=1e-3, 70 | momentum=0.9, 71 | clipping=0.1, # New clipping parameter 72 | weight_decay=2e-5, 73 | nesterov=True) 74 | ``` 75 | 76 | It is important to exclude certain layers from clipping or weight decay. The authors recommend excluding the final fully-connected (linear) layer from clipping and the bias/gain parameters from weight decay: 77 | ```python 78 | import re 79 | 80 | for group in optimizer.param_groups: 81 | name = group['name'] 82 | 83 | # Exclude from weight decay 84 | if len(re.findall('stem.*(bias|gain)|conv.*(bias|gain)|skip_gain', name)) > 0: 85 | group['weight_decay'] = 0 86 | 87 | # Exclude from clipping 88 | if name.startswith('linear'): 89 | group['clipping'] = None 90 | 91 | ``` 92 | 
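If you train the provided `NFNet` model itself rather than your own network, you don't have to write the regex by hand: the model exposes `exclude_from_weight_decay` and `exclude_from_clipping` helpers that encode the same rules, and [train.py](train.py) uses exactly this pattern. A minimal sketch (the constructor and optimizer values below are simply the defaults from [default_config.yaml](default_config.yaml), not required settings):

```python
from nfnets import NFNet, SGD_AGC

model = NFNet(num_classes=1000, variant='F0', stochdepth_rate=0.25,
              alpha=0.2, se_ratio=0.5, activation='gelu')

optimizer = SGD_AGC(
    named_params=model.named_parameters(),
    lr=0.1, momentum=0.9, clipping=0.1, weight_decay=2e-5, nesterov=True)

# The model knows which of its own parameters should be excluded
# from weight decay and from adaptive gradient clipping.
for group in optimizer.param_groups:
    name = group['name']
    if model.exclude_from_weight_decay(name):
        group['weight_decay'] = 0
    if model.exclude_from_clipping(name):
        group['clipping'] = None
```
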
93 | ## Train your own NFNet 94 | Adjust your desired parameters in [default_config.yaml](default_config.yaml) and start training. 95 | ``` 96 | python3 train.py --dataset /path/to/imagenet/ 97 | ``` 98 | 99 | There are still some parts missing for complete training from scratch: 100 | - Multi-GPU training 101 | - Data augmentations 102 | - FP16 activations and gradients 103 | 104 | ## Contribute 105 | 106 | The implementation is still at an early stage in terms of usability and testing. 107 | If you have an idea to improve this repo, open an issue, start a discussion or submit a pull request. 108 | 109 | The current development status can be seen on [this](https://github.com/benjs/nfnets_pytorch/projects/1) project board. 110 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable 3 | from torchvision import transforms 4 | from torch.utils.data.dataset import Dataset 5 | from torchvision.datasets import ImageNet 6 | 7 | def get_dataset(path:Path, transforms:Callable=None) -> Dataset: 8 | return ImageNet(str(path), split='val', transform=transforms) -------------------------------------------------------------------------------- /default_config.yaml: -------------------------------------------------------------------------------- 1 | # This file contains the default train settings 2 | device: 'cuda:0' # or 'cpu' 3 | amp: False # Enable automatic mixed precision 4 | 5 | # Model 6 | variant: 'F0' # F0 - F7 7 | num_classes: 1000 # Number of classes 8 | activation: 'gelu' # or 'relu' 9 | stochdepth_rate: 0.25 # 0-1, the probability that a layer is dropped during one step 10 | alpha: 0.2 # Scaling factor at the end of each block 11 | se_ratio: 0.5 # Squeeze-Excite expansion ratio 12 | use_fp16: False # Use 16bit floats, which lowers the memory footprint. This currently sets 13 | # the complete model to FP16 (will be changed to match the FP16 ops from the paper) 14 | 15 | # Dataset 16 | dataset: '/media/benjs/ext/' # Dataset root directory 17 | num_workers: 8 # Number of workers in the dataloader 18 | pin_memory: True # This can speed up or slow down data loading depending on your hardware 19 | 20 | # Training 21 | batch_size: 64 # Batch size 22 | epochs: 360 # Number of epochs 23 | overfit: False # Train on a single small subset only (overfitting sanity check) 24 | 25 | learning_rate: 0.1 # Learning rate 26 | scale_lr: True # Scale learning rate with batch size. 
lr = lr*batch_size/256 27 | momentum: 0.9 # Contribution of earlier gradient to gradient update 28 | weight_decay: 0.00002 # Factor with which weights are added to gradient 29 | nesterov: True # Enable nesterov correction 30 | 31 | do_clip: True # Enable adaptive gradient clipping 32 | clipping: 0.1 # Adaptive gradient clipping parameter -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import PIL 4 | from pathlib import Path 5 | from PIL.Image import Image 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision.transforms.functional as tF 10 | import torchvision.transforms.functional_pil as tF_pil 11 | from torch.utils.data.dataloader import DataLoader 12 | from torchvision.transforms.transforms import Compose, Normalize, Resize, ToTensor 13 | 14 | from dataset import get_dataset 15 | from nfnets import NFNet, pretrained_nfnet 16 | 17 | # Evaluation method used in the paper 18 | # This seems to perform slightly worse than a simple resize 19 | class Pad32CenterCrop(nn.Module): 20 | def __init__(self, size:int): 21 | super().__init__() 22 | self.size = size 23 | self.scaled_size = (size+32, size+32) 24 | 25 | def forward(self, img:Image): 26 | img = tF_pil.resize(img=img, size=self.scaled_size, interpolation=PIL.Image.BICUBIC) 27 | return tF.center_crop(img, self.size) 28 | 29 | def evaluate_on_imagenet(model:NFNet, dataset_dir:Path, batch_size=50, device='cuda:0'): 30 | transforms = Compose([ 31 | #Pad32CenterCrop(model.test_imsize), 32 | ToTensor(), 33 | Resize((model.test_imsize, model.test_imsize), PIL.Image.BICUBIC), 34 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 35 | ]) 36 | 37 | print(f"Starting evaluation from {dataset_dir}") 38 | dataset = get_dataset(dataset_dir, transforms=transforms) 39 | 40 | dataloader = DataLoader( 41 | dataset=dataset, 42 | batch_size=batch_size, # F0: 120, F1: 100, F2: 80 43 | shuffle=False, 44 | pin_memory=False, 45 | num_workers=8 46 | ) 47 | 48 | print(f"Validation set contains {len(dataset)} images.") 49 | 50 | model.to(device) 51 | model.eval() 52 | 53 | processed_imgs = 0 54 | correct_labels = 0 55 | for step, data in enumerate(dataloader): 56 | with torch.no_grad(): 57 | inputs = data[0].to(device) 58 | targets = data[1].to(device) 59 | 60 | output = model(inputs).type(torch.float32) 61 | 62 | processed_imgs += targets.size(0) 63 | _, predicted = torch.max(output, 1) 64 | correct_labels += (predicted == targets).sum().item() 65 | 66 | batch_padding = int(math.log10(len(dataloader.dataset)) + 1) 67 | print(f"\rProcessing {processed_imgs:{batch_padding}d}/{len(dataloader.dataset)}. Accuracy: {100.0*correct_labels/processed_imgs:6.4f}", sep=' ', end='', flush=True) 68 | 69 | print(f"\nFinished eval. Accuracy: {100.0*correct_labels/processed_imgs:6.4f}") 70 | 71 | 72 | if __name__=='__main__': 73 | parser = argparse.ArgumentParser(description='Evaluate NFNets.') 74 | parser.add_argument('--dataset', type=Path, help='Path to dataset root directory', required=True) 75 | parser.add_argument('--pretrained', type=Path, help='Path to pre-trained weights in haiku format', required=True) 76 | parser.add_argument('--batch-size', type=int, help='Validation batch size', default=50) 77 | parser.add_argument('--device', type=str, help='Validation device. 
Either \'cuda:0\' or \'cpu\'', default='cuda:0') 78 | args = parser.parse_args() 79 | 80 | if not args.pretrained.exists(): 81 | raise FileNotFoundError(f"Could not find file {args.pretrained.absolute()}") 82 | 83 | model = pretrained_nfnet(args.pretrained) 84 | 85 | evaluate_on_imagenet(model, dataset_dir=args.dataset, batch_size=args.batch_size, device=args.device) 86 | -------------------------------------------------------------------------------- /nfnets/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * 2 | from .pretrained import * 3 | from .optim import * -------------------------------------------------------------------------------- /nfnets/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import re 5 | 6 | nfnet_params = { 7 | 'F0': { 8 | 'width': [256, 512, 1536, 1536], 'depth': [1, 2, 6, 3], 9 | 'train_imsize': 192, 'test_imsize': 256, 10 | 'RA_level': '405', 'drop_rate': 0.2}, 11 | 'F1': { 12 | 'width': [256, 512, 1536, 1536], 'depth': [2, 4, 12, 6], 13 | 'train_imsize': 224, 'test_imsize': 320, 14 | 'RA_level': '410', 'drop_rate': 0.3}, 15 | 'F2': { 16 | 'width': [256, 512, 1536, 1536], 'depth': [3, 6, 18, 9], 17 | 'train_imsize': 256, 'test_imsize': 352, 18 | 'RA_level': '410', 'drop_rate': 0.4}, 19 | 'F3': { 20 | 'width': [256, 512, 1536, 1536], 'depth': [4, 8, 24, 12], 21 | 'train_imsize': 320, 'test_imsize': 416, 22 | 'RA_level': '415', 'drop_rate': 0.4}, 23 | 'F4': { 24 | 'width': [256, 512, 1536, 1536], 'depth': [5, 10, 30, 15], 25 | 'train_imsize': 384, 'test_imsize': 512, 26 | 'RA_level': '415', 'drop_rate': 0.5}, 27 | 'F5': { 28 | 'width': [256, 512, 1536, 1536], 'depth': [6, 12, 36, 18], 29 | 'train_imsize': 416, 'test_imsize': 544, 30 | 'RA_level': '415', 'drop_rate': 0.5}, 31 | 'F6': { 32 | 'width': [256, 512, 1536, 1536], 'depth': [7, 14, 42, 21], 33 | 'train_imsize': 448, 'test_imsize': 576, 34 | 'RA_level': '415', 'drop_rate': 0.5}, 35 | 'F7': { 36 | 'width': [256, 512, 1536, 1536], 'depth': [8, 16, 48, 24], 37 | 'train_imsize': 480, 'test_imsize': 608, 38 | 'RA_level': '415', 'drop_rate': 0.5}, 39 | } 40 | 41 | # These extra constant values ensure that the activations 42 | # are variance preserving 43 | class VPGELU(nn.Module): 44 | def forward(self, input: torch.Tensor) -> torch.Tensor: 45 | return F.gelu(input) * 1.7015043497085571 46 | 47 | class VPReLU(nn.Module): 48 | __constants__ = ['inplace'] 49 | inplace: bool 50 | 51 | def __init__(self, inplace: bool = False): 52 | super(VPReLU, self).__init__() 53 | self.inplace = inplace 54 | 55 | def forward(self, input: torch.Tensor) -> torch.Tensor: 56 | return F.relu(input, inplace=self.inplace) * 1.7139588594436646 57 | 58 | def extra_repr(self) -> str: 59 | inplace_str = 'inplace=True' if self.inplace else '' 60 | return inplace_str 61 | 62 | activations_dict = { 63 | 'gelu': VPGELU(), 64 | 'relu': VPReLU(inplace=True) 65 | } 66 | 67 | class NFNet(nn.Module): 68 | def __init__(self, num_classes:int, variant:str='F0', stochdepth_rate:float=None, 69 | alpha:float=0.2, se_ratio:float=0.5, activation:str='gelu'): 70 | super(NFNet, self).__init__() 71 | 72 | if not variant in nfnet_params: 73 | raise RuntimeError(f"Variant {variant} does not exist and could not be loaded.") 74 | 75 | block_params = nfnet_params[variant] 76 | 77 | self.train_imsize = block_params['train_imsize'] 78 | self.test_imsize = 
block_params['test_imsize'] 79 | self.activation = activations_dict[activation] 80 | self.drop_rate = block_params['drop_rate'] 81 | self.num_classes = num_classes 82 | 83 | self.stem = Stem(activation=activation) 84 | 85 | num_blocks, index = sum(block_params['depth']), 0 86 | 87 | blocks = [] 88 | expected_std = 1.0 89 | in_channels = block_params['width'][0] // 2 90 | 91 | block_args = zip( 92 | block_params['width'], 93 | block_params['depth'], 94 | [0.5] * 4, # bottleneck pattern 95 | [128] * 4, # group pattern. Original groups [128] * 4 96 | [1, 2, 2, 2] # stride pattern 97 | ) 98 | 99 | for (block_width, stage_depth, expand_ratio, group_size, stride) in block_args: 100 | for block_index in range(stage_depth): 101 | beta = 1. / expected_std 102 | 103 | block_sd_rate = stochdepth_rate * index / num_blocks 104 | out_channels = block_width 105 | 106 | blocks.append(NFBlock( 107 | in_channels=in_channels, 108 | out_channels=out_channels, 109 | stride=stride if block_index == 0 else 1, 110 | alpha=alpha, 111 | beta=beta, 112 | se_ratio=se_ratio, 113 | group_size=group_size, 114 | stochdepth_rate=block_sd_rate, 115 | activation=activation)) 116 | 117 | in_channels = out_channels 118 | index += 1 119 | 120 | if block_index == 0: 121 | expected_std = 1.0 122 | 123 | expected_std = (expected_std **2 + alpha**2)**0.5 124 | 125 | self.body = nn.Sequential(*blocks) 126 | 127 | final_conv_channels = 2*in_channels 128 | self.final_conv = WSConv2D(in_channels=out_channels, out_channels=final_conv_channels, kernel_size=1) 129 | self.pool = nn.AvgPool2d(1) 130 | 131 | if self.drop_rate > 0.: 132 | self.dropout = nn.Dropout(self.drop_rate) 133 | 134 | self.linear = nn.Linear(final_conv_channels, self.num_classes) 135 | nn.init.normal_(self.linear.weight, 0, 0.01) 136 | 137 | def forward(self, x): 138 | out = self.stem(x) 139 | out = self.body(out) 140 | out = self.activation(self.final_conv(out)) 141 | pool = torch.mean(out, dim=(2,3)) 142 | 143 | if self.training and self.drop_rate > 0.: 144 | pool = self.dropout(pool) 145 | 146 | return self.linear(pool) 147 | 148 | def exclude_from_weight_decay(self, name:str) -> bool: 149 | # Regex to find layer names like 150 | # "stem.6.bias", "stem.6.gain", "body.0.skip_gain", 151 | # "body.0.conv0.bias", "body.0.conv0.gain" 152 | regex = re.compile('stem.*(bias|gain)|conv.*(bias|gain)|skip_gain') 153 | return len(regex.findall(name)) > 0 154 | 155 | def exclude_from_clipping(self, name: str) -> bool: 156 | # Last layer should not be clipped 157 | return name.startswith('linear') 158 | 159 | class Stem(nn.Module): 160 | def __init__(self, activation:str='gelu'): 161 | super(Stem, self).__init__() 162 | 163 | self.activation = activations_dict[activation] 164 | self.conv0 = WSConv2D(in_channels=3, out_channels=16, kernel_size=3, stride=2) 165 | self.conv1 = WSConv2D(in_channels=16, out_channels=32, kernel_size=3, stride=1) 166 | self.conv2 = WSConv2D(in_channels=32, out_channels=64, kernel_size=3, stride=1) 167 | self.conv3 = WSConv2D(in_channels=64, out_channels=128, kernel_size=3, stride=2) 168 | 169 | def forward(self, x): 170 | out = self.activation(self.conv0(x)) 171 | out = self.activation(self.conv1(out)) 172 | out = self.activation(self.conv2(out)) 173 | out = self.conv3(out) 174 | return out 175 | 176 | class NFBlock(nn.Module): 177 | def __init__(self, in_channels:int, out_channels:int, expansion:float=0.5, 178 | se_ratio:float=0.5, stride:int=1, beta:float=1.0, alpha:float=0.2, 179 | group_size:int=1, stochdepth_rate:float=None, 
activation:str='gelu'): 180 | 181 | super(NFBlock, self).__init__() 182 | 183 | self.in_channels = in_channels 184 | self.out_channels = out_channels 185 | self.expansion = expansion 186 | self.se_ratio = se_ratio 187 | self.activation = activations_dict[activation] 188 | self.beta, self.alpha = beta, alpha 189 | self.group_size = group_size 190 | 191 | width = int(self.out_channels * expansion) 192 | self.groups = width // group_size 193 | self.width = group_size * self.groups 194 | self.stride = stride 195 | 196 | self.conv0 = WSConv2D(in_channels=self.in_channels, out_channels=self.width, kernel_size=1) 197 | self.conv1 = WSConv2D(in_channels=self.width, out_channels=self.width, kernel_size=3, stride=stride, padding=1, groups=self.groups) 198 | self.conv1b = WSConv2D(in_channels=self.width, out_channels=self.width, kernel_size=3, stride=1, padding=1, groups=self.groups) 199 | self.conv2 = WSConv2D(in_channels=self.width, out_channels=self.out_channels, kernel_size=1) 200 | 201 | self.use_projection = self.stride > 1 or self.in_channels != self.out_channels 202 | if self.use_projection: 203 | if stride > 1: 204 | self.shortcut_avg_pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0 if self.in_channels==1536 else 1) 205 | self.conv_shortcut = WSConv2D(self.in_channels, self.out_channels, kernel_size=1) 206 | 207 | self.squeeze_excite = SqueezeExcite(self.out_channels, self.out_channels, se_ratio=self.se_ratio, activation=activation) 208 | self.skip_gain = nn.Parameter(torch.zeros(())) 209 | 210 | self.use_stochdepth = stochdepth_rate is not None and stochdepth_rate > 0. and stochdepth_rate < 1. 211 | if self.use_stochdepth: 212 | self.stoch_depth = StochDepth(stochdepth_rate) 213 | 214 | def forward(self, x): 215 | out = self.activation(x) * self.beta 216 | 217 | if self.stride > 1: 218 | shortcut = self.shortcut_avg_pool(out) 219 | shortcut = self.conv_shortcut(shortcut) 220 | elif self.use_projection: 221 | shortcut = self.conv_shortcut(out) 222 | else: 223 | shortcut = x 224 | 225 | out = self.activation(self.conv0(out)) 226 | out = self.activation(self.conv1(out)) 227 | out = self.activation(self.conv1b(out)) 228 | out = self.conv2(out) 229 | out = (self.squeeze_excite(out)*2) * out 230 | 231 | if self.use_stochdepth: 232 | out = self.stoch_depth(out) 233 | 234 | return out * self.alpha * self.skip_gain + shortcut 235 | 236 | # Implementation mostly from https://arxiv.org/abs/2101.08692 237 | # Implemented changes from https://arxiv.org/abs/2102.06171 and 238 | # https://github.com/deepmind/deepmind-research/tree/master/nfnets 239 | class WSConv2D(nn.Conv2d): 240 | def __init__(self, in_channels: int, out_channels: int, kernel_size, stride = 1, padding = 0, 241 | dilation = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros'): 242 | 243 | super(WSConv2D, self).__init__(in_channels, out_channels, kernel_size, stride, 244 | padding, dilation, groups, bias, padding_mode) 245 | 246 | nn.init.xavier_normal_(self.weight) 247 | self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) 248 | self.register_buffer('eps', torch.tensor(1e-4, requires_grad=False), persistent=False) 249 | self.register_buffer('fan_in', torch.tensor(self.weight.shape[1:].numel(), requires_grad=False).type_as(self.weight), persistent=False) 250 | 251 | def standardized_weights(self): 252 | # Original code: HWCN 253 | mean = torch.mean(self.weight, axis=[1,2,3], keepdims=True) 254 | var = torch.var(self.weight, axis=[1,2,3], keepdims=True) 255 | scale = torch.rsqrt(torch.maximum(var * 
self.fan_in, self.eps)) 256 | return (self.weight - mean) * scale * self.gain 257 | 258 | def forward(self, x): 259 | return F.conv2d( 260 | input=x, 261 | weight=self.standardized_weights(), 262 | bias=self.bias, 263 | stride=self.stride, 264 | padding=self.padding, 265 | dilation=self.dilation, 266 | groups=self.groups 267 | ) 268 | 269 | class SqueezeExcite(nn.Module): 270 | def __init__(self, in_channels:int, out_channels:int, se_ratio:float=0.5, activation:str='gelu'): 271 | super(SqueezeExcite, self).__init__() 272 | 273 | self.in_channels = in_channels 274 | self.out_channels = out_channels 275 | self.se_ratio = se_ratio 276 | 277 | self.hidden_channels = max(1, int(self.in_channels * self.se_ratio)) 278 | 279 | self.activation = activations_dict[activation] 280 | self.linear = nn.Linear(self.in_channels, self.hidden_channels) 281 | self.linear_1 = nn.Linear(self.hidden_channels, self.out_channels) 282 | self.sigmoid = nn.Sigmoid() 283 | 284 | def forward(self, x): 285 | out = torch.mean(x, (2,3)) 286 | out = self.linear_1(self.activation(self.linear(out))) 287 | out = self.sigmoid(out) 288 | 289 | b,c,_,_ = x.size() 290 | return out.view(b,c,1,1).expand_as(x) 291 | 292 | class StochDepth(nn.Module): 293 | def __init__(self, stochdepth_rate:float): 294 | super(StochDepth, self).__init__() 295 | 296 | self.drop_rate = stochdepth_rate 297 | 298 | def forward(self, x): 299 | if not self.training: 300 | return x 301 | 302 | batch_size = x.shape[0] 303 | rand_tensor = torch.rand(batch_size, 1, 1, 1).type_as(x).to(x.device) 304 | keep_prob = 1 - self.drop_rate 305 | binary_tensor = torch.floor(rand_tensor + keep_prob) 306 | 307 | return x * binary_tensor 308 | -------------------------------------------------------------------------------- /nfnets/optim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import Optimizer 3 | 4 | # Compute norm depending on the shape of x 5 | def unitwise_norm(x): 6 | if (len(torch.squeeze(x).shape)) <= 1: # Scalars, vectors 7 | axis = 0 8 | keepdims = False 9 | elif len(x.shape) in [2,3]: # Linear layers 10 | # Original code: IO 11 | # Pytorch: OI 12 | axis = 1 13 | keepdims = True 14 | elif len(x.shape) == 4: # Conv kernels 15 | # Original code: HWIO 16 | # Pytorch: OIHW 17 | axis = [1, 2, 3] 18 | keepdims = True 19 | else: 20 | raise ValueError(f'Got a parameter with len(shape) not in [1, 2, 3, 4]! 
{x}') 21 | 22 | return torch.sqrt(torch.sum(torch.square(x), axis=axis, keepdim=keepdims)) 23 | 24 | 25 | # This is a copy of the pytorch SGD implementation 26 | # enhanced with gradient clipping 27 | class SGD_AGC(Optimizer): 28 | def __init__(self, named_params, lr:float, momentum=0, dampening=0, 29 | weight_decay=0, nesterov=False, clipping:float=None, eps:float=1e-3): 30 | if lr < 0.0: 31 | raise ValueError("Invalid learning rate: {}".format(lr)) 32 | if momentum < 0.0: 33 | raise ValueError("Invalid momentum value: {}".format(momentum)) 34 | if weight_decay < 0.0: 35 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 36 | 37 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 38 | weight_decay=weight_decay, nesterov=nesterov, 39 | # Extra defaults 40 | clipping=clipping, 41 | eps=eps 42 | ) 43 | 44 | if nesterov and (momentum <= 0 or dampening != 0): 45 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 46 | 47 | # Put params in list so each one gets its own group 48 | params = [] 49 | for name, param in named_params: 50 | params.append({'params': param, 'name': name}) 51 | 52 | super(SGD_AGC, self).__init__(params, defaults) 53 | 54 | def __setstate__(self, state): 55 | super(SGD_AGC, self).__setstate__(state) 56 | for group in self.param_groups: 57 | group.setdefault('nesterov', False) 58 | 59 | @torch.no_grad() 60 | def step(self, closure=None): 61 | loss = None 62 | if closure is not None: 63 | with torch.enable_grad(): 64 | loss = closure() 65 | 66 | for group in self.param_groups: 67 | weight_decay = group['weight_decay'] 68 | momentum = group['momentum'] 69 | dampening = group['dampening'] 70 | nesterov = group['nesterov'] 71 | 72 | # Extra values for clipping 73 | clipping = group['clipping'] 74 | eps = group['eps'] 75 | 76 | for p in group['params']: 77 | if p.grad is None: 78 | continue 79 | d_p = p.grad 80 | 81 | # ========================= 82 | # Gradient clipping 83 | if clipping is not None: 84 | param_norm = torch.maximum(unitwise_norm(p), torch.tensor(eps).to(p.device)) 85 | grad_norm = unitwise_norm(d_p) 86 | max_norm = param_norm * group['clipping'] 87 | 88 | trigger_mask = grad_norm > max_norm 89 | clipped_grad = p.grad * (max_norm / torch.maximum(grad_norm, torch.tensor(1e-6).to(p.device))) 90 | d_p = torch.where(trigger_mask, clipped_grad, d_p) 91 | # ========================= 92 | 93 | if weight_decay != 0: 94 | d_p = d_p.add(p, alpha=weight_decay) 95 | if momentum != 0: 96 | param_state = self.state[p] 97 | if 'momentum_buffer' not in param_state: 98 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 99 | else: 100 | buf = param_state['momentum_buffer'] 101 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 102 | if nesterov: 103 | d_p = d_p.add(buf, alpha=momentum) 104 | else: 105 | d_p = buf 106 | 107 | p.add_(d_p, alpha=-group['lr']) 108 | 109 | return loss -------------------------------------------------------------------------------- /nfnets/pretrained.py: -------------------------------------------------------------------------------- 1 | import re 2 | import dill 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from pathlib import Path 7 | 8 | from nfnets import NFNet 9 | 10 | def pretrained_nfnet(path, stochdepth_rate:float=0.5, alpha:float=0.2, activation:str='gelu') -> NFNet: 11 | if isinstance(path, str): 12 | path = Path(path) 13 | 14 | with path.open('rb') as f: 15 | params = dill.load(f) 16 | 17 | layers_to_variant = { 18 | 94: 'F0', 19 | 178: 'F1', 20 
| 262: 'F2', 21 | 346: 'F3', 22 | 430: 'F4', 23 | 514: 'F5' 24 | } 25 | 26 | if len(params) not in layers_to_variant: 27 | raise RuntimeError(f"Cannot load file {path.absolute()}." 28 | f" File contains an invalid parameter count of {len(params)}!") 29 | 30 | model = NFNet( 31 | variant=layers_to_variant[len(params)], 32 | num_classes=1000, 33 | alpha=alpha, 34 | stochdepth_rate=stochdepth_rate, 35 | se_ratio=0.5, 36 | activation=activation) 37 | 38 | state_dict = {} 39 | 40 | for layer_name in params: 41 | for param_name in params[layer_name]: 42 | l = layer_name 43 | l = l.replace("NFNet/~/", "") 44 | l = re.sub(r"(nf_block_(\d*))", r"body.\2", l) 45 | l = re.sub("(nf_block)", r"body.0", l) 46 | l = re.sub("stem_*", "stem.", l) 47 | l = l.replace("/~/", ".") 48 | 49 | p = str(param_name) 50 | p = "weight" if p == "w" else p 51 | p = "bias" if p == "b" else p 52 | 53 | param = params[layer_name][param_name] 54 | 55 | if len(param.shape) == 4: 56 | # Conv layers, HWIO -> OIHW 57 | param = param.swapaxes(0,3).swapaxes(1,2).swapaxes(2,3) 58 | 59 | elif len(param.shape) == 2: 60 | # Linear layers, OI -> IO 61 | param = param.swapaxes(0,1) 62 | 63 | if p == 'gain': 64 | param = np.expand_dims(param, axis=(1,2,3)) 65 | 66 | #if "conv" in l: 67 | # state_dict[f"{l}.eps"] = torch.tensor(1e-4, requires_grad=False) 68 | 69 | with torch.no_grad(): 70 | t = torch.from_numpy(param) 71 | complete_name = f'{l}.{p}' 72 | if complete_name not in model.state_dict(): 73 | raise ValueError( 74 | f"Parameter {complete_name} not found in state dict!" 75 | " Please report an issue.") 76 | 77 | state_dict[complete_name] = t 78 | 79 | model.load_state_dict(state_dict, strict=True) 80 | return model 81 | 82 | if __name__=='__main__': 83 | parser = argparse.ArgumentParser(description='Load haiku weights and convert them to a .pth file.') 84 | parser.add_argument('--pretrained', type=Path, help='Path to pre-trained weights in haiku format') 85 | args = parser.parse_args() 86 | 87 | if not args.pretrained.exists(): 88 | raise FileNotFoundError(f"Could not find file {args.pretrained.absolute()}") 89 | 90 | model = pretrained_nfnet(args.pretrained) 91 | 92 | torch.save({ 93 | 'model': model.state_dict() 94 | }, str(args.pretrained.with_suffix('.pth'))) -------------------------------------------------------------------------------- /pretrained/README.md: -------------------------------------------------------------------------------- 1 | # Pretrained weights 2 | 3 | Download the pretrained weights from the [official repository](https://github.com/deepmind/deepmind-research/tree/master/nfnets#pre-trained-weights) and place them inside this folder. 
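If you prefer a plain PyTorch checkpoint over the haiku `.npz` files, the converted model can simply be re-saved; this mirrors what the `__main__` block of [nfnets/pretrained.py](../nfnets/pretrained.py) does, and `pretrained/*.pth` files are already git-ignored. A minimal sketch (the file names are examples):

```python
import torch
from nfnets import pretrained_nfnet

# Load the downloaded haiku weights and store them as a regular .pth state dict.
model = pretrained_nfnet('pretrained/F0_haiku.npz')
torch.save({'model': model.state_dict()}, 'pretrained/F0_haiku.pth')
```
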
4 | Then start training with 5 | ``` 6 | python3 train.py --pretrained pretrained/F0_haiku.npz 7 | ``` 8 | 9 | or evaluation with 10 | ``` 11 | python3 eval.py --pretrained pretrained/F0_haiku.npz --dataset /path/to/imagenet/val/ 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel" 5 | ] 6 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --find-links https://download.pytorch.org/whl/torch_stable.html 2 | --find-links https://download.pytorch.org/whl/cu110/torch_stable.html 3 | 4 | dill 5 | git+https://github.com/deepmind/dm-haiku 6 | jax 7 | jaxlib 8 | matplotlib 9 | numpy 10 | pillow-simd 11 | pyyaml 12 | requests 13 | tensorboard 14 | torch>=1.7.1+cu110 15 | torchvision>=0.8.2+cu110 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name=nfnets_pytorch 3 | version=0.0.1 4 | author=Benjamin Schmidt 5 | author_email = webmaster@benjs.de 6 | license=Apache 2.0 7 | license_file=LICENSE 8 | description=Implementation of the paper "High-Performance Large-Scale Image Recognition Without Normalization" by Brock et al. 9 | long_description=file:README.md 10 | long_description_content_type=text/markdown 11 | url=https://github.com/benjs/nfnets_pytorch 12 | project_urls = 13 | Bug Tracker = https://github.com/benjs/nfnets_pytorch/issues 14 | classifiers = 15 | Programming Language :: Python :: 3 16 | License :: OSI Approved :: Apache Software License 17 | Operating System :: OS Independent 18 | Natural Language :: English 19 | 20 | [options] 21 | packages = nfnets 22 | python_requires = >=3.7 23 | install_requires = 24 | dm-haiku 25 | requests 26 | dill 27 | jax 28 | jaxlib 29 | numpy 30 | requests 31 | torch>=1.7 32 | 33 | dependency_links= 34 | git+https://github.com/deepmind/dm-haiku 35 | https://download.pytorch.org/whl/torch_stable.html -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import PIL 4 | import time 5 | import yaml 6 | from pathlib import Path 7 | from PIL.Image import Image 8 | 9 | import matplotlib.pyplot as plt 10 | import torch 11 | import torch.cuda.amp as amp 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.utils.data import Subset 15 | from torch.utils.data.dataloader import DataLoader 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torchvision.transforms.transforms import Compose, Normalize, Resize, ToTensor, RandomHorizontalFlip, RandomCrop 18 | 19 | from dataset import get_dataset 20 | from nfnets import NFNet, SGD_AGC, pretrained_nfnet 21 | 22 | def train(config:dict) -> None: 23 | if config['device'].startswith('cuda'): 24 | if torch.cuda.is_available(): 25 | print(f"Using CUDA{torch.version.cuda} with cuDNN{torch.backends.cudnn.version()}") 26 | else: 27 | raise ValueError("You specified to use cuda device, but cuda is not available.") 28 | 29 | if config['pretrained'] is not None: 30 | model = pretrained_nfnet( 31 | path=config['pretrained'], 32 | 
stochdepth_rate=config['stochdepth_rate'], 33 | alpha=config['alpha'], 34 | activation=config['activation'] 35 | ) 36 | else: 37 | model = NFNet( 38 | num_classes=config['num_classes'], 39 | variant=config['variant'], 40 | stochdepth_rate=config['stochdepth_rate'], 41 | alpha=config['alpha'], 42 | se_ratio=config['se_ratio'], 43 | activation=config['activation'] 44 | ) 45 | 46 | transforms = Compose([ 47 | RandomHorizontalFlip(), 48 | Resize((model.train_imsize, model.train_imsize), PIL.Image.BICUBIC), 49 | ToTensor(), 50 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 51 | ]) 52 | 53 | device = config['device'] 54 | dataset = get_dataset(path=config['dataset'], transforms=transforms) 55 | 56 | if config['overfit']: 57 | dataset = Subset(dataset, [i*50 for i in range(0,1000)] ) 58 | 59 | dataloader = DataLoader( 60 | dataset=dataset, 61 | batch_size=config['batch_size'], 62 | shuffle=True, 63 | num_workers=config['num_workers'], 64 | pin_memory=config['pin_memory']) 65 | 66 | if config['scale_lr']: 67 | learning_rate = config['learning_rate']*config['batch_size']/256 68 | else: 69 | learning_rate = config['learning_rate'] 70 | 71 | if not config['do_clip']: 72 | config['clipping'] = None 73 | 74 | if config['use_fp16']: 75 | model.half() 76 | 77 | model.to(device) # "memory_format=torch.channels_last" TBD 78 | 79 | optimizer = SGD_AGC( 80 | # The optimizer needs all parameter names 81 | # to filter them by hand later 82 | named_params=model.named_parameters(), 83 | lr=learning_rate, 84 | momentum=config['momentum'], 85 | clipping=config['clipping'], 86 | weight_decay=config['weight_decay'], 87 | nesterov=config['nesterov'] 88 | ) 89 | 90 | # Find desired parameters and exclude them 91 | # from weight decay and clipping 92 | for group in optimizer.param_groups: 93 | name = group['name'] 94 | 95 | if model.exclude_from_weight_decay(name): 96 | group['weight_decay'] = 0 97 | 98 | if model.exclude_from_clipping(name): 99 | group['clipping'] = None 100 | 101 | criterion = nn.CrossEntropyLoss() 102 | 103 | runs_dir = Path('runs') 104 | run_index = 0 105 | while (runs_dir / ('run' + str(run_index))).exists(): 106 | run_index += 1 107 | runs_dir = runs_dir / ('run' + str(run_index)) 108 | runs_dir.mkdir(exist_ok=False, parents=True) 109 | checkpoints_dir = runs_dir / 'checkpoints' 110 | checkpoints_dir.mkdir() 111 | 112 | writer = SummaryWriter(str(runs_dir)) 113 | scaler = amp.GradScaler() 114 | 115 | for epoch in range(config['epochs']): 116 | model.train() 117 | running_loss = 0.0 118 | processed_imgs = 0 119 | correct_labels = 0 120 | epoch_time = time.time() 121 | 122 | for step, data in enumerate(dataloader): 123 | inputs = data[0].half().to(device) if config['use_fp16'] else data[0].to(device) 124 | targets = data[1].to(device) 125 | 126 | optimizer.zero_grad() 127 | 128 | with amp.autocast(enabled=config['amp']): 129 | output = model(inputs) 130 | loss = criterion(output, targets) 131 | 132 | # Gradient scaling 133 | # https://www.youtube.com/watch?v=OqCrNkjN_PM 134 | scaler.scale(loss).backward() 135 | scaler.step(optimizer) 136 | scaler.update() 137 | 138 | running_loss += loss.item() 139 | processed_imgs += targets.size(0) 140 | _, predicted = torch.max(output, 1) 141 | correct_labels += (predicted == targets).sum().item() 142 | 143 | epoch_padding = int(math.log10(config['epochs']) + 1) 144 | batch_padding = int(math.log10(len(dataloader.dataset)) + 1) 145 | print(f"\rEpoch {epoch+1:0{epoch_padding}d}/{config['epochs']}" 146 | f"\tImg 
{processed_imgs:{batch_padding}d}/{len(dataloader.dataset)}" 147 | f"\tLoss {running_loss / (step+1):6.4f}" 148 | f"\tAcc {100.0*correct_labels/processed_imgs:5.3f}%\t", 149 | sep=' ', end='', flush=True) 150 | 151 | elapsed = time.time() - epoch_time 152 | print (f"({elapsed:.3f}s, {elapsed/len(dataloader):.3}s/step, {elapsed/len(dataset):.3}s/img)") 153 | 154 | global_step = epoch*len(dataloader) + step 155 | writer.add_scalar('training/loss', running_loss/(step+1), global_step) 156 | writer.add_scalar('training/accuracy', 100.0*correct_labels/processed_imgs, global_step) 157 | 158 | #if not config['overfit']: 159 | if epoch % 10 == 0 and epoch != 0: 160 | cp_path = checkpoints_dir / ("checkpoint_epoch" + str(epoch+1) + ".pth") 161 | 162 | torch.save({ 163 | 'epoch': epoch, 164 | 'model': model.state_dict(), 165 | 'optim': optimizer.state_dict(), 166 | 'loss': loss 167 | }, str(cp_path)) 168 | 169 | print(f"Saved checkpoint to {str(cp_path)}") 170 | 171 | if __name__=='__main__': 172 | parser = argparse.ArgumentParser(description='Train NFNets.') 173 | parser.add_argument('--config', type=Path, help='Path to config.yaml', default='default_config.yaml') 174 | parser.add_argument('--batch-size', type=int, help='Training batch size', default=None) 175 | parser.add_argument('--overfit', const=True, default=False, nargs='?', help='Crop the dataset to the batch size and force model to (hopefully) overfit') 176 | parser.add_argument('--variant', type=str, help='NFNet variant to train', default=None) 177 | parser.add_argument('--pretrained', type=Path, help='Path to pre-trained weights in haiku format', default=None) 178 | args = parser.parse_args() 179 | 180 | if not args.config.exists(): 181 | print(f"Config file \"{args.config}\" does not exist!\n") 182 | exit() 183 | 184 | with args.config.open() as file: 185 | config = yaml.safe_load(file) 186 | 187 | # Override config.yaml settings with command line settings 188 | for arg in vars(args): 189 | if getattr(args, arg) is not None and arg in config: 190 | config[arg] = getattr(args, arg) 191 | 192 | config['pretrained'] = args.pretrained 193 | 194 | train(config=config) 195 | --------------------------------------------------------------------------------
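A note on the checkpoints written by train.py: they are stored under `runs/run<i>/checkpoints/checkpoint_epoch<N>.pth` and contain the keys `epoch`, `model`, `optim` and `loss`. Loading one back for evaluation might look like the following sketch (the path is an example, and the constructor arguments have to match the configuration used for training; shown here with the defaults from default_config.yaml):

```python
import torch
from nfnets import NFNet

# Checkpoint produced by train.py after epoch 10 of the first run.
checkpoint = torch.load('runs/run0/checkpoints/checkpoint_epoch11.pth', map_location='cpu')

model = NFNet(num_classes=1000, variant='F0', stochdepth_rate=0.25,
              alpha=0.2, se_ratio=0.5, activation='gelu')
model.load_state_dict(checkpoint['model'])
model.eval()
```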