├── .gitignore ├── LICENSE ├── README.md ├── datasets.py ├── engine.py ├── main.py ├── models ├── __init__.py ├── crossvit.py └── t2t │ ├── __init__.py │ └── t2t.py ├── requirements.txt ├── run_with_submitit.py ├── samplers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *~ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CrossViT 2 | 3 | This repository is the official implementation of CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification. [ArXiv](https://arxiv.org/abs/2103.14899) 4 | 5 | If you use the codes and models from this repo, please cite our work. Thanks! 
6 | 7 | ``` 8 | @inproceedings{ 9 | chen2021crossvit, 10 | title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}}, 11 | author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda}, 12 | booktitle={International Conference on Computer Vision (ICCV)}, 13 | year={2021} 14 | } 15 | ``` 16 | 17 | 18 | ## Installation 19 | 20 | To install requirements: 21 | 22 | ```setup 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | With conda: 27 | 28 | ``` 29 | conda create -n crossvit python=3.8 30 | conda activate crossvit 31 | conda install pytorch=1.7.1 torchvision cudatoolkit=11.0 -c pytorch -c nvidia 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Data preparation 36 | 37 | Download and extract ImageNet train and val images from http://image-net.org/. 38 | The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), and the training and validation data are expected to be in the `train/` and `val/` folders, respectively: 39 | 40 | ``` 41 | /path/to/imagenet/ 42 | train/ 43 | class1/ 44 | img1.jpeg 45 | class2/ 46 | img2.jpeg 47 | val/ 48 | class1/ 49 | img3.jpeg 50 | class2/ 51 | img4.jpeg 52 | ``` 53 | 54 | ## Pretrained models 55 | 56 | We provide models trained on ImageNet1K. You can find the models [here](https://github.com/IBM/CrossViT/releases/tag/weights-0.1). 57 | You can load the pretrained weights into a model by adding the `--pretrained` flag. 58 | 59 | 60 | ## Training 61 | 62 | To train `crossvit_9_dagger_224` on ImageNet on a single node with 8 GPUs for 300 epochs, run: 63 | 64 | ```shell script 65 | 66 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet 67 | ``` 68 | 69 | Other model names can be found in [models/crossvit.py](models/crossvit.py). 70 | 71 | ## Multinode training 72 | 73 | Distributed training is available via Slurm and `submitit`. 74 | 75 | To train a `crossvit_9_dagger_224` model on ImageNet on 4 nodes with 8 GPUs each for 300 epochs: 76 | 77 | ``` 78 | python run_with_submitit.py --nodes 4 --model crossvit_9_dagger_224 --data-path /path/to/imagenet --batch-size 128 --warmup-epochs 30 79 | ``` 80 | 81 | Alternatively, you can start the processes on each machine manually, e.g. for 2 nodes with 8 GPUs each: 82 | 83 | Machine 0: 84 | ```shell script 85 | 86 | python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=0 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet 87 | ``` 88 | 89 | Machine 1: 90 | ```shell script 91 | 92 | python -m torch.distributed.launch --nproc_per_node=8 --master_addr=MACHINE_0_IP --master_port=AVAILABLE_PORT --nnodes=2 --node_rank=1 main.py --model crossvit_9_dagger_224 --batch-size 256 --data-path /path/to/imagenet 93 | ``` 94 | 95 | 96 | Note that some Slurm configurations might need to be changed for your cluster.
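All CrossViT variants in [models/crossvit.py](models/crossvit.py) are registered with timm via `@register_model`, and passing `pretrained=True` downloads the released weights, so the checkpoints can also be used outside of `main.py`. A minimal sketch (assuming it is run from the repository root so that `import models` registers the model names with timm):

```python
# Minimal sketch: load a released CrossViT checkpoint through timm's registry
# and run a dummy forward pass as a sanity check. Assumes the repo root is on
# the path so that `import models` registers the crossvit_* model names.
import torch
from timm.models import create_model

import models  # noqa: F401 -- importing the package registers the models

model = create_model('crossvit_9_dagger_224', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 240, 240))  # 240 is the default --input-size in main.py
print(logits.shape)  # expected: torch.Size([1, 1000])
```

Inputs at other resolutions are handled internally: `forward_features` interpolates the image to each branch's expected size (e.g. 240 and 224 for the `*_224` models).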
97 | 98 | 99 | ## Evaluation 100 | 101 | To evaluate a pretrained model on `crossvit_9_dagger_224`: 102 | 103 | ``` 104 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model crossvit_9_dagger_224 --batch-size 128 --data-path /path/to/imagenet --eval --pretrained 105 | ``` 106 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright IBM All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | Mostly copy-paste from https://github.com/facebookresearch/deit/blob/main/datasets.py 6 | """ 7 | 8 | import os 9 | import json 10 | 11 | from torchvision import datasets, transforms 12 | from torchvision.datasets.folder import ImageFolder, default_loader 13 | 14 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 15 | from timm.data import create_transform 16 | 17 | 18 | class INatDataset(ImageFolder): 19 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 20 | category='name', loader=default_loader): 21 | self.transform = transform 22 | self.loader = loader 23 | self.target_transform = target_transform 24 | self.year = year 25 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 26 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 27 | with open(path_json) as json_file: 28 | data = json.load(json_file) 29 | 30 | with open(os.path.join(root, 'categories.json')) as json_file: 31 | data_catg = json.load(json_file) 32 | 33 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 34 | 35 | with open(path_json_for_targeter) as json_file: 36 | data_for_targeter = json.load(json_file) 37 | 38 | targeter = {} 39 | indexer = 0 40 | for elem in data_for_targeter['annotations']: 41 | king = [] 42 | king.append(data_catg[int(elem['category_id'])][category]) 43 | if king[0] not in targeter.keys(): 44 | targeter[king[0]] = indexer 45 | indexer += 1 46 | self.nb_classes = len(targeter) 47 | 48 | self.samples = [] 49 | for elem in data['images']: 50 | cut = elem['file_name'].split('/') 51 | target_current = int(cut[2]) 52 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 53 | 54 | categors = data_catg[target_current] 55 | target_current_true = targeter[categors[category]] 56 | self.samples.append((path_current, target_current_true)) 57 | 58 | # __getitem__ and __len__ inherited from ImageFolder 59 | 60 | 61 | def build_dataset(is_train, args): 62 | transform = build_transform(is_train, args) 63 | 64 | if args.data_set == 'CIFAR10': 65 | dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform) 66 | nb_classes = 10 67 | elif args.data_set == 'CIFAR100': 68 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 69 | nb_classes = 100 70 | elif args.data_set == 'IMNET': 71 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 72 | dataset = datasets.ImageFolder(root, transform=transform) 73 | nb_classes = 1000 74 | elif args.data_set == 'INAT': 75 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 76 | category=args.inat_category, transform=transform) 77 | nb_classes = dataset.nb_classes 78 | elif args.data_set == 'INAT19': 79 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 80 | category=args.inat_category, transform=transform) 81 | nb_classes = dataset.nb_classes 82 | 83 | return 
dataset, nb_classes 84 | 85 | 86 | def build_transform(is_train, args): 87 | resize_im = args.input_size > 32 88 | if is_train: 89 | # this should always dispatch to transforms_imagenet_train 90 | transform = create_transform( 91 | input_size=args.input_size, 92 | is_training=True, 93 | color_jitter=args.color_jitter, 94 | auto_augment=args.aa, 95 | interpolation=args.train_interpolation, 96 | re_prob=args.reprob, 97 | re_mode=args.remode, 98 | re_count=args.recount, 99 | ) 100 | if not resize_im: 101 | # replace RandomResizedCropAndInterpolation with 102 | # RandomCrop 103 | transform.transforms[0] = transforms.RandomCrop( 104 | args.input_size, padding=4) 105 | return transform 106 | 107 | t = [] 108 | if resize_im: 109 | size = int(args.crop_ratio * args.input_size) 110 | t.append( 111 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 112 | ) 113 | t.append(transforms.CenterCrop(args.input_size)) 114 | 115 | t.append(transforms.ToTensor()) 116 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 117 | return transforms.Compose(t) 118 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright IBM All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | Train and eval functions used in main.py 6 | 7 | Mostly copy-paste from https://github.com/facebookresearch/deit/blob/main/engine.py 8 | """ 9 | 10 | import math 11 | from typing import Iterable, Optional 12 | import torch 13 | 14 | from timm.data import Mixup 15 | from timm.utils import accuracy 16 | from einops import rearrange 17 | 18 | import utils 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 22 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 23 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 24 | mixup_fn: Optional[Mixup] = None, 25 | world_size: int = 1, distributed: bool = True, amp=True, 26 | finetune=False 27 | ): 28 | if finetune: 29 | model.train(not finetune) 30 | else: 31 | model.train() 32 | metric_logger = utils.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 34 | header = 'Epoch: [{}]'.format(epoch) 35 | print_freq = 50 36 | 37 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 38 | batch_size = targets.size(0) 39 | 40 | samples = samples.to(device, non_blocking=True) 41 | targets = targets.to(device, non_blocking=True) 42 | 43 | if mixup_fn is not None: 44 | samples, targets = mixup_fn(samples, targets) 45 | 46 | with torch.cuda.amp.autocast(enabled=amp): 47 | outputs = model(samples) 48 | loss = criterion(outputs, targets) 49 | loss_value = loss.item() 50 | 51 | if not math.isfinite(loss_value): 52 | print("Loss is {}, stopping training".format(loss_value)) 53 | raise ValueError("Loss is {}, stopping training".format(loss_value)) 54 | 55 | optimizer.zero_grad() 56 | 57 | # this attribute is added by timm on one optimizer (adahessian) 58 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 59 | 60 | if amp: 61 | loss_scaler(loss, optimizer, clip_grad=max_norm, 62 | parameters=model.parameters(), create_graph=is_second_order) 63 | else: 64 | loss.backward(create_graph=is_second_order) 65 | if max_norm is not None and max_norm != 0.0: 66 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 67 
| optimizer.step() 68 | 69 | torch.cuda.synchronize() 70 | 71 | metric_logger.update(loss=loss_value) 72 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 73 | # gather the stats from all processes 74 | metric_logger.synchronize_between_processes() 75 | print("Averaged stats:", metric_logger) 76 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 77 | 78 | 79 | @torch.no_grad() 80 | def evaluate(data_loader, model, device, world_size, distributed=True, amp=False): 81 | criterion = torch.nn.CrossEntropyLoss() 82 | 83 | metric_logger = utils.MetricLogger(delimiter=" ") 84 | header = 'Test:' 85 | 86 | # switch to evaluation mode 87 | model.eval() 88 | 89 | outputs = [] 90 | targets = [] 91 | 92 | for images, target in metric_logger.log_every(data_loader, 10, header): 93 | images = images.to(device, non_blocking=True) 94 | target = target.to(device, non_blocking=True) 95 | # compute output 96 | with torch.cuda.amp.autocast(enabled=amp): 97 | output = model(images) 98 | 99 | if distributed: 100 | outputs.append(concat_all_gather(output)) 101 | targets.append(concat_all_gather(target)) 102 | else: 103 | outputs.append(output) 104 | targets.append(target) 105 | 106 | num_data = len(data_loader.dataset) 107 | outputs = torch.cat(outputs, dim=0) 108 | targets = torch.cat(targets, dim=0) 109 | real_acc1, real_acc5 = accuracy(outputs[:num_data], targets[:num_data], topk=(1, 5)) 110 | real_loss = criterion(outputs, targets) 111 | metric_logger.update(loss=real_loss.item()) 112 | metric_logger.meters['acc1'].update(real_acc1.item()) 113 | metric_logger.meters['acc5'].update(real_acc5.item()) 114 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 115 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 116 | 117 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 118 | 119 | 120 | @torch.no_grad() 121 | def concat_all_gather(tensor): 122 | """ 123 | Performs all_gather operation on the provided tensors. 124 | """ 125 | tensors_gather = [torch.ones_like(tensor) 126 | for _ in range(torch.distributed.get_world_size())] 127 | torch.distributed.all_gather(tensors_gather, tensor.contiguous(), async_op=False) 128 | 129 | if tensor.dim() == 1: 130 | output = rearrange(tensors_gather, 'n b -> (b n)') 131 | else: 132 | output = rearrange(tensors_gather, 'n b c -> (b n) c') 133 | return output 134 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright IBM All Rights Reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | Main training and evaluation script 6 | 7 | Mostly copy-paste from https://github.com/facebookresearch/deit/blob/main/main.py 8 | """ 9 | 10 | import argparse 11 | import datetime 12 | import numpy as np 13 | import time 14 | import torch 15 | import torch.backends.cudnn as cudnn 16 | import json 17 | import os 18 | import warnings 19 | 20 | from pathlib import Path 21 | 22 | from timm.data import Mixup 23 | from timm.models import create_model 24 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 25 | from timm.scheduler import create_scheduler 26 | from timm.optim import create_optimizer 27 | from timm.utils import NativeScaler, get_state_dict 28 | 29 | from datasets import build_dataset 30 | from engine import train_one_epoch, evaluate 31 | from samplers import RASampler 32 | import models 33 | import utils 34 | 35 | 36 | warnings.filterwarnings("ignore", category=UserWarning) 37 | 38 | 39 | def get_args_parser(): 40 | parser = argparse.ArgumentParser('CrossViT training and evaluation script', add_help=False) 41 | parser.add_argument('--batch-size', default=64, type=int) 42 | parser.add_argument('--epochs', default=300, type=int) 43 | 44 | # Model parameters 45 | parser.add_argument('--model', default='crossvit_small_224', type=str, metavar='MODEL', 46 | help='Name of model to train') 47 | parser.add_argument('--input-size', default=240, type=int, help='images input size') 48 | 49 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 50 | help='Dropout rate (default: 0.)') 51 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 52 | help='Drop path rate (default: 0.1)') 53 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 54 | help='Drop block rate (default: None)') 55 | parser.add_argument('--pretrained', action='store_true', help='load imagenet1k pretrained model') 56 | 57 | # Optimizer parameters 58 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 59 | help='Optimizer (default: "adamw"') 60 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 61 | help='Optimizer Epsilon (default: 1e-8)') 62 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 63 | help='Optimizer Betas (default: None, use opt default)') 64 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 65 | help='Clip gradient norm (default: None, no clipping)') 66 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 67 | help='SGD momentum (default: 0.9)') 68 | parser.add_argument('--weight-decay', type=float, default=0.05, 69 | help='weight decay (default: 0.05)') 70 | # Learning rate schedule parameters 71 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 72 | help='LR scheduler (default: "cosine"') 73 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 74 | help='learning rate (default: 5e-4)') 75 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 76 | help='learning rate noise on/off epoch percentages') 77 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 78 | help='learning rate noise limit percent (default: 0.67)') 79 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 80 | help='learning rate noise std-dev (default: 1.0)') 81 | parser.add_argument('--warmup-lr', type=float, 
default=1e-6, metavar='LR', 82 | help='warmup learning rate (default: 1e-6)') 83 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 84 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 85 | 86 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 87 | help='epoch interval to decay LR') 88 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 89 | help='epochs to warmup LR, if scheduler supports') 90 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 91 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 92 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 93 | help='patience epochs for Plateau LR scheduler (default: 10') 94 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 95 | help='LR decay rate (default: 0.1)') 96 | 97 | # Augmentation parameters 98 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 99 | help='Color jitter factor (default: 0.4)') 100 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 101 | help='Use AutoAugment policy. "v0" or "original". " + \ 102 | "(default: rand-m9-mstd0.5-inc1)'), 103 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 104 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 105 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 106 | 107 | parser.add_argument('--repeated-aug', action='store_true') 108 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 109 | parser.set_defaults(repeated_aug=True) 110 | parser.add_argument('--crop-ratio', type=float, default=256/224, help='crop ratio for evaluation') 111 | 112 | # * Random Erase params 113 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 114 | help='Random erase prob (default: 0.25)') 115 | parser.add_argument('--remode', type=str, default='pixel', 116 | help='Random erase mode (default: "pixel")') 117 | parser.add_argument('--recount', type=int, default=1, 118 | help='Random erase count (default: 1)') 119 | parser.add_argument('--resplit', action='store_true', default=False, 120 | help='Do not random erase first (clean) augmentation split') 121 | 122 | # * Mixup params 123 | parser.add_argument('--mixup', type=float, default=0.8, 124 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 125 | parser.add_argument('--cutmix', type=float, default=1.0, 126 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 127 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 128 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 129 | parser.add_argument('--mixup-prob', type=float, default=1.0, 130 | help='Probability of performing mixup or cutmix when either/both is enabled') 131 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 132 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 133 | parser.add_argument('--mixup-mode', type=str, default='batch', 134 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 135 | 136 | # Dataset parameters 137 | parser.add_argument('--data-path', default=os.path.join(os.path.expanduser("~"), 'datasets/image_cls/imagenet1k/'), type=str, 138 | help='dataset path') 139 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19'], 140 | type=str, help='Image Net dataset path') 141 | parser.add_argument('--inat-category', default='name', 142 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 143 | type=str, help='semantic granularity') 144 | 145 | parser.add_argument('--output_dir', default='', 146 | help='path where to save, empty for no saving') 147 | parser.add_argument('--device', default='cuda', 148 | help='device to use for training / testing') 149 | parser.add_argument('--seed', default=0, type=int) 150 | parser.add_argument('--resume', default='', help='resume from checkpoint') 151 | parser.add_argument('--no-resume-loss-scaler', action='store_false', dest='resume_loss_scaler') 152 | parser.add_argument('--no-amp', action='store_false', dest='amp', help='disable amp') 153 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 154 | help='start epoch') 155 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 156 | parser.add_argument('--num_workers', default=10, type=int) 157 | parser.add_argument('--pin-mem', action='store_true', 158 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 159 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 160 | help='') 161 | parser.set_defaults(pin_mem=True) 162 | 163 | # distributed training parameters 164 | parser.add_argument('--world_size', default=1, type=int, 165 | help='number of distributed processes') 166 | parser.add_argument("--local_rank", type=int) 167 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 168 | 169 | parser.add_argument('--auto-resume', action='store_true', help='auto resume') 170 | parser.add_argument('--finetune', action='store_true', help='finetune model') 171 | parser.add_argument('--initial_checkpoint', type=str, default='', help='path to the pretrained model') 172 | 173 | return parser 174 | 175 | 176 | def main(args): 177 | utils.init_distributed_mode(args) 178 | print(args) 179 | device = torch.device(args.device) 180 | 181 | # fix the seed for reproducibility 182 | seed = args.seed + utils.get_rank() 183 | torch.manual_seed(seed) 184 | np.random.seed(seed) 185 | # random.seed(seed) 186 | 187 | cudnn.benchmark = True 188 | 189 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 190 | 191 | if True: # args.distributed: 192 | num_tasks = utils.get_world_size() 193 | global_rank = utils.get_rank() 194 | if args.repeated_aug: 195 | sampler_train = RASampler( 196 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 197 | ) 198 | else: 199 | sampler_train = torch.utils.data.DistributedSampler( 200 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 201 | ) 202 | else: 203 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 204 | 205 | data_loader_train = torch.utils.data.DataLoader( 206 | dataset_train, sampler=sampler_train, 207 | batch_size=args.batch_size, 208 | num_workers=args.num_workers, 209 | pin_memory=args.pin_mem, 210 | drop_last=True, 211 | ) 212 | 213 | dataset_val, _ = build_dataset(is_train=False, args=args) 214 | 
val_sampler = torch.utils.data.DistributedSampler(dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 215 | data_loader_val = torch.utils.data.DataLoader( 216 | dataset_val, sampler=val_sampler, batch_size=int(1.5 * args.batch_size), 217 | shuffle=False, num_workers=args.num_workers, 218 | pin_memory=args.pin_mem, drop_last=False 219 | ) 220 | 221 | mixup_fn = None 222 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 223 | if mixup_active: 224 | mixup_fn = Mixup( 225 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 226 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 227 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 228 | 229 | print(f"Creating model: {args.model}") 230 | model = create_model( 231 | args.model, 232 | pretrained=args.pretrained, 233 | num_classes=args.nb_classes, 234 | drop_rate=args.drop, 235 | drop_path_rate=args.drop_path, 236 | drop_block_rate=args.drop_block, 237 | ) 238 | 239 | # TODO: finetuning 240 | 241 | model.to(device) 242 | model_without_ddp = model 243 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 244 | print('number of params:', n_parameters) 245 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 246 | args.lr = linear_scaled_lr 247 | print(f"Scaled learning rate (batch size: {args.batch_size * utils.get_world_size()}): {linear_scaled_lr}") 248 | optimizer = create_optimizer(args, model) 249 | 250 | if args.distributed: 251 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 252 | model_without_ddp = model.module 253 | #optimizer = create_optimizer(args, model) 254 | 255 | loss_scaler = NativeScaler() 256 | lr_scheduler, _ = create_scheduler(args, optimizer) 257 | criterion = LabelSmoothingCrossEntropy() 258 | 259 | if args.mixup > 0.: 260 | # smoothing is handled with mixup label transform 261 | criterion = SoftTargetCrossEntropy() 262 | elif args.smoothing: 263 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 264 | else: 265 | criterion = torch.nn.CrossEntropyLoss() 266 | 267 | max_accuracy = 0.0 268 | output_dir = Path(args.output_dir) 269 | 270 | if args.initial_checkpoint: 271 | print("Loading pretrained model") 272 | checkpoint = torch.load(args.initial_checkpoint, map_location='cpu') 273 | utils.load_checkpoint(model, checkpoint['model']) 274 | 275 | if args.auto_resume: 276 | if args.resume == '': 277 | args.resume = str(output_dir / "checkpoint.pth") 278 | if not os.path.exists(args.resume): 279 | args.resume = '' 280 | 281 | if args.resume: 282 | if args.resume.startswith('https'): 283 | checkpoint = torch.hub.load_state_dict_from_url( 284 | args.resume, map_location='cpu', check_hash=True) 285 | else: 286 | checkpoint = torch.load(args.resume, map_location='cpu') 287 | utils.load_checkpoint(model, checkpoint['model']) 288 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 289 | optimizer.load_state_dict(checkpoint['optimizer']) 290 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 291 | args.start_epoch = checkpoint['epoch'] + 1 292 | if 'scaler' in checkpoint and args.resume_loss_scaler: 293 | print("Resume with previous loss scaler state") 294 | loss_scaler.load_state_dict(checkpoint['scaler']) 295 | max_accuracy = checkpoint['max_accuracy'] 296 | 297 | if args.eval: 298 | test_stats = evaluate(data_loader_val, model, device, 
num_tasks, distributed=True, amp=args.amp) 299 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%") 300 | return 301 | 302 | print(f"Start training, currnet max acc is {max_accuracy:.2f}") 303 | start_time = time.time() 304 | for epoch in range(args.start_epoch, args.epochs): 305 | 306 | if args.distributed: 307 | data_loader_train.sampler.set_epoch(epoch) 308 | 309 | train_stats = train_one_epoch( 310 | model, criterion, data_loader_train, 311 | optimizer, device, epoch, loss_scaler, 312 | args.clip_grad, mixup_fn, num_tasks, True, 313 | amp=args.amp, 314 | finetune=args.finetune 315 | ) 316 | 317 | lr_scheduler.step(epoch) 318 | 319 | test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=True, amp=args.amp) 320 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%") 321 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 322 | print(f'Max accuracy: {max_accuracy:.2f}%') 323 | 324 | if args.output_dir: 325 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 326 | if test_stats["acc1"] == max_accuracy: 327 | checkpoint_paths.append(output_dir / 'model_best.pth') 328 | for checkpoint_path in checkpoint_paths: 329 | state_dict = { 330 | 'model': model_without_ddp.state_dict(), 331 | 'optimizer': optimizer.state_dict(), 332 | 'lr_scheduler': lr_scheduler.state_dict(), 333 | 'epoch': epoch, 334 | 'args': args, 335 | 'scaler': loss_scaler.state_dict(), 336 | 'max_accuracy': max_accuracy 337 | } 338 | utils.save_on_master(state_dict, checkpoint_path) 339 | 340 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 341 | **{f'test_{k}': v for k, v in test_stats.items()}, 342 | 'epoch': epoch, 343 | 'n_parameters': n_parameters} 344 | 345 | if args.output_dir and utils.is_main_process(): 346 | with (output_dir / "log.txt").open("a") as f: 347 | f.write(json.dumps(log_stats) + "\n") 348 | 349 | total_time = time.time() - start_time 350 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 351 | print('Training time {}'.format(total_time_str)) 352 | 353 | 354 | if __name__ == '__main__': 355 | parser = argparse.ArgumentParser('CrossViT training and evaluation script', parents=[get_args_parser()]) 356 | args = parser.parse_args() 357 | if args.output_dir: 358 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 359 | main(args) 360 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .crossvit import * -------------------------------------------------------------------------------- /models/crossvit.py: -------------------------------------------------------------------------------- 1 | # Copyright IBM All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | 5 | """ 6 | Modifed from Timm. 
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 7 | 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.hub 14 | from functools import partial 15 | 16 | 17 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 18 | from timm.models.registry import register_model 19 | from timm.models.vision_transformer import _cfg, Mlp, Block 20 | 21 | _model_urls = { 22 | 'crossvit_15_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth', 23 | 'crossvit_15_dagger_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth', 24 | 'crossvit_15_dagger_384': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth', 25 | 'crossvit_18_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth', 26 | 'crossvit_18_dagger_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth', 27 | 'crossvit_18_dagger_384': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth', 28 | 'crossvit_9_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth', 29 | 'crossvit_9_dagger_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth', 30 | 'crossvit_base_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth', 31 | 'crossvit_small_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth', 32 | 'crossvit_tiny_224': 'https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth', 33 | } 34 | 35 | 36 | class PatchEmbed(nn.Module): 37 | """ Image to Patch Embedding 38 | """ 39 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False): 40 | super().__init__() 41 | img_size = to_2tuple(img_size) 42 | patch_size = to_2tuple(patch_size) 43 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 44 | self.img_size = img_size 45 | self.patch_size = patch_size 46 | self.num_patches = num_patches 47 | if multi_conv: 48 | if patch_size[0] == 12: 49 | self.proj = nn.Sequential( 50 | nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1), 55 | ) 56 | elif patch_size[0] == 16: 57 | self.proj = nn.Sequential( 58 | nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1), 61 | nn.ReLU(inplace=True), 62 | nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), 63 | ) 64 | else: 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 66 | 67 | def forward(self, x): 68 | B, C, H, W = x.shape 69 | # FIXME look at relaxing size constraints 70 | assert H == self.img_size[0] and W == self.img_size[1], \ 71 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
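# (B, C, H, W) -> conv projection -> (B, embed_dim, H/P, W/P) -> flatten + transpose -> (B, num_patches, embed_dim)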
72 | x = self.proj(x).flatten(2).transpose(1, 2) 73 | return x 74 | 75 | 76 | class CrossAttention(nn.Module): 77 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 78 | super().__init__() 79 | self.num_heads = num_heads 80 | head_dim = dim // num_heads 81 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 82 | self.scale = qk_scale or head_dim ** -0.5 83 | 84 | self.wq = nn.Linear(dim, dim, bias=qkv_bias) 85 | self.wk = nn.Linear(dim, dim, bias=qkv_bias) 86 | self.wv = nn.Linear(dim, dim, bias=qkv_bias) 87 | self.attn_drop = nn.Dropout(attn_drop) 88 | self.proj = nn.Linear(dim, dim) 89 | self.proj_drop = nn.Dropout(proj_drop) 90 | 91 | def forward(self, x): 92 | 93 | B, N, C = x.shape 94 | q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B1C -> B1H(C/H) -> BH1(C/H) 95 | k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) 96 | v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) 97 | 98 | attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N 99 | attn = attn.softmax(dim=-1) 100 | attn = self.attn_drop(attn) 101 | 102 | x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C 103 | x = self.proj(x) 104 | x = self.proj_drop(x) 105 | return x 106 | 107 | 108 | class CrossAttentionBlock(nn.Module): 109 | 110 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 111 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=True): 112 | super().__init__() 113 | self.norm1 = norm_layer(dim) 114 | self.attn = CrossAttention( 115 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 116 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 117 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 118 | self.has_mlp = has_mlp 119 | if has_mlp: 120 | self.norm2 = norm_layer(dim) 121 | mlp_hidden_dim = int(dim * mlp_ratio) 122 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 123 | 124 | def forward(self, x): 125 | x = x[:, 0:1, ...] 
+ self.drop_path(self.attn(self.norm1(x))) 126 | if self.has_mlp: 127 | x = x + self.drop_path(self.mlp(self.norm2(x))) 128 | 129 | return x 130 | 131 | 132 | class MultiScaleBlock(nn.Module): 133 | 134 | def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 135 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 136 | super().__init__() 137 | 138 | num_branches = len(dim) 139 | self.num_branches = num_branches 140 | # different branch could have different embedding size, the first one is the base 141 | self.blocks = nn.ModuleList() 142 | for d in range(num_branches): 143 | tmp = [] 144 | for i in range(depth[d]): 145 | tmp.append( 146 | Block(dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, 147 | drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer)) 148 | if len(tmp) != 0: 149 | self.blocks.append(nn.Sequential(*tmp)) 150 | 151 | if len(self.blocks) == 0: 152 | self.blocks = None 153 | 154 | self.projs = nn.ModuleList() 155 | for d in range(num_branches): 156 | if dim[d] == dim[(d+1) % num_branches] and False: 157 | tmp = [nn.Identity()] 158 | else: 159 | tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d+1) % num_branches])] 160 | self.projs.append(nn.Sequential(*tmp)) 161 | 162 | self.fusion = nn.ModuleList() 163 | for d in range(num_branches): 164 | d_ = (d+1) % num_branches 165 | nh = num_heads[d_] 166 | if depth[-1] == 0: # backward capability: 167 | self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, 168 | drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, 169 | has_mlp=False)) 170 | else: 171 | tmp = [] 172 | for _ in range(depth[-1]): 173 | tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, 174 | drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer, 175 | has_mlp=False)) 176 | self.fusion.append(nn.Sequential(*tmp)) 177 | 178 | self.revert_projs = nn.ModuleList() 179 | for d in range(num_branches): 180 | if dim[(d+1) % num_branches] == dim[d] and False: 181 | tmp = [nn.Identity()] 182 | else: 183 | tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])] 184 | self.revert_projs.append(nn.Sequential(*tmp)) 185 | 186 | def forward(self, x): 187 | outs_b = [block(x_) for x_, block in zip(x, self.blocks)] 188 | # only take the cls token out 189 | proj_cls_token = [proj(x[:, 0:1]) for x, proj in zip(outs_b, self.projs)] 190 | # cross attention 191 | outs = [] 192 | for i in range(self.num_branches): 193 | tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1) 194 | tmp = self.fusion[i](tmp) 195 | reverted_proj_cls_token = self.revert_projs[i](tmp[:, 0:1, ...]) 196 | tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1) 197 | outs.append(tmp) 198 | return outs 199 | 200 | 201 | def _compute_num_patches(img_size, patches): 202 | return [i // p * i // p for i, p in zip(img_size,patches)] 203 | 204 | 205 | class VisionTransformer(nn.Module): 206 | """ Vision Transformer with support for patch or hybrid CNN input stage 207 | """ 208 | def __init__(self, img_size=(224, 224), patch_size=(8, 16), in_chans=3, num_classes=1000, embed_dim=(192, 384), depth=([1, 3, 1], [1, 3, 1], [1, 3, 1]), 209 | num_heads=(6, 12), mlp_ratio=(2., 2., 4.), qkv_bias=False, 
qk_scale=None, drop_rate=0., attn_drop_rate=0., 210 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, multi_conv=False): 211 | super().__init__() 212 | 213 | self.num_classes = num_classes 214 | if not isinstance(img_size, list): 215 | img_size = to_2tuple(img_size) 216 | self.img_size = img_size 217 | 218 | num_patches = _compute_num_patches(img_size, patch_size) 219 | self.num_branches = len(patch_size) 220 | 221 | self.patch_embed = nn.ModuleList() 222 | if hybrid_backbone is None: 223 | self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)]) 224 | for im_s, p, d in zip(img_size, patch_size, embed_dim): 225 | self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv)) 226 | else: 227 | self.pos_embed = nn.ParameterList() 228 | from .t2t import T2T, get_sinusoid_encoding 229 | tokens_type = 'transformer' if hybrid_backbone == 't2t' else 'performer' 230 | for idx, (im_s, p, d) in enumerate(zip(img_size, patch_size, embed_dim)): 231 | self.patch_embed.append(T2T(im_s, tokens_type=tokens_type, patch_size=p, embed_dim=d)) 232 | self.pos_embed.append(nn.Parameter(data=get_sinusoid_encoding(n_position=1 + num_patches[idx], d_hid=embed_dim[idx]), requires_grad=False)) 233 | 234 | del self.pos_embed 235 | self.pos_embed = nn.ParameterList([nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])) for i in range(self.num_branches)]) 236 | 237 | self.cls_token = nn.ParameterList([nn.Parameter(torch.zeros(1, 1, embed_dim[i])) for i in range(self.num_branches)]) 238 | self.pos_drop = nn.Dropout(p=drop_rate) 239 | 240 | total_depth = sum([sum(x[-2:]) for x in depth]) 241 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule 242 | dpr_ptr = 0 243 | self.blocks = nn.ModuleList() 244 | for idx, block_cfg in enumerate(depth): 245 | curr_depth = max(block_cfg[:-1]) + block_cfg[-1] 246 | dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth] 247 | blk = MultiScaleBlock(embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio, 248 | qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, 249 | norm_layer=norm_layer) 250 | dpr_ptr += curr_depth 251 | self.blocks.append(blk) 252 | 253 | self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)]) 254 | self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) 255 | 256 | for i in range(self.num_branches): 257 | if self.pos_embed[i].requires_grad: 258 | trunc_normal_(self.pos_embed[i], std=.02) 259 | trunc_normal_(self.cls_token[i], std=.02) 260 | 261 | self.apply(self._init_weights) 262 | 263 | def _init_weights(self, m): 264 | if isinstance(m, nn.Linear): 265 | trunc_normal_(m.weight, std=.02) 266 | if isinstance(m, nn.Linear) and m.bias is not None: 267 | nn.init.constant_(m.bias, 0) 268 | elif isinstance(m, nn.LayerNorm): 269 | nn.init.constant_(m.bias, 0) 270 | nn.init.constant_(m.weight, 1.0) 271 | 272 | @torch.jit.ignore 273 | def no_weight_decay(self): 274 | out = {'cls_token'} 275 | if self.pos_embed[0].requires_grad: 276 | out.add('pos_embed') 277 | return out 278 | 279 | def get_classifier(self): 280 | return self.head 281 | 282 | def reset_classifier(self, num_classes, global_pool=''): 283 | self.num_classes = num_classes 284 | self.head = nn.Linear(self.embed_dim, num_classes) if 
num_classes > 0 else nn.Identity() 285 | 286 | def forward_features(self, x): 287 | B, C, H, W = x.shape 288 | xs = [] 289 | for i in range(self.num_branches): 290 | x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x 291 | tmp = self.patch_embed[i](x_) 292 | cls_tokens = self.cls_token[i].expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 293 | tmp = torch.cat((cls_tokens, tmp), dim=1) 294 | tmp = tmp + self.pos_embed[i] 295 | tmp = self.pos_drop(tmp) 296 | xs.append(tmp) 297 | 298 | for blk in self.blocks: 299 | xs = blk(xs) 300 | 301 | # NOTE: was before branch token section, move to here to assure all branch token are before layer norm 302 | xs = [self.norm[i](x) for i, x in enumerate(xs)] 303 | out = [x[:, 0] for x in xs] 304 | 305 | return out 306 | 307 | def forward(self, x): 308 | xs = self.forward_features(x) 309 | ce_logits = [self.head[i](x) for i, x in enumerate(xs)] 310 | ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0) 311 | return ce_logits 312 | 313 | 314 | 315 | 316 | @register_model 317 | def crossvit_tiny_224(pretrained=False, **kwargs): 318 | model = VisionTransformer(img_size=[240, 224], 319 | patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], 320 | num_heads=[3, 3], mlp_ratio=[4, 4, 1], qkv_bias=True, 321 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 322 | model.default_cfg = _cfg() 323 | if pretrained: 324 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_tiny_224'], map_location='cpu') 325 | model.load_state_dict(state_dict) 326 | return model 327 | 328 | 329 | @register_model 330 | def crossvit_small_224(pretrained=False, **kwargs): 331 | model = VisionTransformer(img_size=[240, 224], 332 | patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], 333 | num_heads=[6, 6], mlp_ratio=[4, 4, 1], qkv_bias=True, 334 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 335 | model.default_cfg = _cfg() 336 | if pretrained: 337 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_small_224'], map_location='cpu') 338 | model.load_state_dict(state_dict) 339 | return model 340 | 341 | 342 | @register_model 343 | def crossvit_base_224(pretrained=False, **kwargs): 344 | model = VisionTransformer(img_size=[240, 224], 345 | patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], 346 | num_heads=[12, 12], mlp_ratio=[4, 4, 1], qkv_bias=True, 347 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 348 | model.default_cfg = _cfg() 349 | if pretrained: 350 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_base_224'], map_location='cpu') 351 | model.load_state_dict(state_dict) 352 | return model 353 | 354 | 355 | @register_model 356 | def crossvit_9_224(pretrained=False, **kwargs): 357 | model = VisionTransformer(img_size=[240, 224], 358 | patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], 359 | num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True, 360 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 361 | model.default_cfg = _cfg() 362 | if pretrained: 363 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_9_224'], map_location='cpu') 364 | model.load_state_dict(state_dict) 365 | return model 366 | 367 | 368 | @register_model 369 | def crossvit_15_224(pretrained=False, **kwargs): 370 | model = VisionTransformer(img_size=[240, 224], 371 | patch_size=[12, 16], 
embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], 372 | num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, 373 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 374 | model.default_cfg = _cfg() 375 | if pretrained: 376 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_15_224'], map_location='cpu') 377 | model.load_state_dict(state_dict) 378 | return model 379 | 380 | 381 | @register_model 382 | def crossvit_18_224(pretrained=False, **kwargs): 383 | model = VisionTransformer(img_size=[240, 224], 384 | patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], 385 | num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, 386 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 387 | model.default_cfg = _cfg() 388 | if pretrained: 389 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_18_224'], map_location='cpu') 390 | model.load_state_dict(state_dict) 391 | return model 392 | 393 | 394 | @register_model 395 | def crossvit_9_dagger_224(pretrained=False, **kwargs): 396 | model = VisionTransformer(img_size=[240, 224], 397 | patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], 398 | num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True, 399 | norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) 400 | model.default_cfg = _cfg() 401 | if pretrained: 402 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_9_dagger_224'], map_location='cpu') 403 | model.load_state_dict(state_dict) 404 | return model 405 | 406 | @register_model 407 | def crossvit_15_dagger_224(pretrained=False, **kwargs): 408 | model = VisionTransformer(img_size=[240, 224], 409 | patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], 410 | num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, 411 | norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) 412 | model.default_cfg = _cfg() 413 | if pretrained: 414 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_15_dagger_224'], map_location='cpu') 415 | model.load_state_dict(state_dict) 416 | return model 417 | 418 | @register_model 419 | def crossvit_15_dagger_384(pretrained=False, **kwargs): 420 | model = VisionTransformer(img_size=[408, 384], 421 | patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], 422 | num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, 423 | norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) 424 | model.default_cfg = _cfg() 425 | if pretrained: 426 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_15_dagger_384'], map_location='cpu') 427 | model.load_state_dict(state_dict) 428 | return model 429 | 430 | @register_model 431 | def crossvit_18_dagger_224(pretrained=False, **kwargs): 432 | model = VisionTransformer(img_size=[240, 224], 433 | patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], 434 | num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, 435 | norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) 436 | model.default_cfg = _cfg() 437 | if pretrained: 438 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_18_dagger_224'], map_location='cpu') 439 | model.load_state_dict(state_dict) 440 | return model 441 | 442 | @register_model 443 | def crossvit_18_dagger_384(pretrained=False, **kwargs): 444 | model = VisionTransformer(img_size=[408, 384], 445 | patch_size=[12, 16], embed_dim=[224, 448], 
depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], 446 | num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, 447 | norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) 448 | model.default_cfg = _cfg() 449 | if pretrained: 450 | state_dict = torch.hub.load_state_dict_from_url(_model_urls['crossvit_18_dagger_384'], map_location='cpu') 451 | model.load_state_dict(state_dict) 452 | return model 453 | -------------------------------------------------------------------------------- /models/t2t/__init__.py: -------------------------------------------------------------------------------- 1 | from .t2t import T2T, get_sinusoid_encoding 2 | -------------------------------------------------------------------------------- /models/t2t/t2t.py: -------------------------------------------------------------------------------- 1 | # Copyright IBM All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | Mostly copy-paste from https://github.com/yitu-opensource/T2T-ViT/blob/main/models/token_transformer.py 6 | """ 7 | 8 | 9 | import math 10 | 11 | import numpy as np 12 | from timm.models.layers import DropPath, to_2tuple 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | def get_sinusoid_encoding(n_position, d_hid): 18 | ''' Sinusoid position encoding table ''' 19 | 20 | def get_position_angle_vec(position): 21 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 22 | 23 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 24 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 25 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 26 | 27 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 28 | 29 | 30 | class Token_performer(nn.Module): 31 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2=0.1): 32 | # def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.0, dp2=0.0): 33 | super().__init__() 34 | self.emb = in_dim * head_cnt # head_cnt is 1 here, so self.emb is simply in_dim 35 | self.kqv = nn.Linear(dim, 3 * self.emb) 36 | self.dp = nn.Dropout(dp1) 37 | self.proj = nn.Linear(self.emb, self.emb) 38 | self.head_cnt = head_cnt 39 | self.norm1 = nn.LayerNorm(dim) 40 | self.norm2 = nn.LayerNorm(self.emb) 41 | self.epsilon = 1e-8 # for numerical stability of the division 42 | 43 | self.mlp = nn.Sequential( 44 | nn.Linear(self.emb, 1 * self.emb), 45 | nn.GELU(), 46 | nn.Linear(1 * self.emb, self.emb), 47 | nn.Dropout(dp2), 48 | ) 49 | 50 | self.m = int(self.emb * kernel_ratio) 51 | self.w = torch.randn(self.m, self.emb) 52 | self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False) 53 | 54 | def prm_exp(self, x): 55 | # part of this function is borrowed from https://github.com/lucidrains/performer-pytorch 56 | # and Simo Ryu (https://github.com/cloneofsimo) 57 | # ==== positive random features for gaussian kernels ==== 58 | # x = (B, T, hs) 59 | # w = (m, hs) 60 | # return : x : B, T, m 61 | # SM(x, y) = E_w[exp(w^T x - |x|^2/2) exp(w^T y - |y|^2/2)] 62 | # therefore return exp(w^T x - |x|^2/2) / sqrt(m) 63 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2 64 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w) 65 | 66 | return torch.exp(wtx - xd) / math.sqrt(self.m) 67 | 68 | def single_attn(self, x): 69 | k, q, v = torch.split(self.kqv(x), self.emb, dim=-1) 70 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, T, m), (B, T, m) 71 | D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * 
(B, m) -> (B, T, 1) 72 | kptv = torch.einsum('bin,bim->bnm', v.float(), kp) # (B, emb, m) 73 | y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag 74 | # skip connection 75 | y = v + self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection 76 | 77 | return y 78 | 79 | def forward(self, x): 80 | x = self.single_attn(self.norm1(x)) 81 | x = x + self.mlp(self.norm2(x)) 82 | return x 83 | 84 | 85 | class Mlp(nn.Module): 86 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 87 | super().__init__() 88 | out_features = out_features or in_features 89 | hidden_features = hidden_features or in_features 90 | self.fc1 = nn.Linear(in_features, hidden_features) 91 | self.act = act_layer() 92 | self.fc2 = nn.Linear(hidden_features, out_features) 93 | self.drop = nn.Dropout(drop) 94 | 95 | def forward(self, x): 96 | x = self.fc1(x) 97 | x = self.act(x) 98 | x = self.drop(x) 99 | x = self.fc2(x) 100 | x = self.drop(x) 101 | return x 102 | 103 | 104 | class Attention(nn.Module): 105 | def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 106 | super().__init__() 107 | self.num_heads = num_heads 108 | self.in_dim = in_dim 109 | head_dim = dim // num_heads 110 | self.scale = qk_scale or head_dim ** -0.5 111 | 112 | self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias) 113 | self.attn_drop = nn.Dropout(attn_drop) 114 | self.proj = nn.Linear(in_dim, in_dim) 115 | self.proj_drop = nn.Dropout(proj_drop) 116 | 117 | def forward(self, x): 118 | B, N, C = x.shape 119 | 120 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim).permute(2, 0, 3, 1, 4) 121 | q, k, v = qkv[0], qkv[1], qkv[2] 122 | 123 | attn = (q @ k.transpose(-2, -1)) * self.scale 124 | attn = attn.softmax(dim=-1) 125 | attn = self.attn_drop(attn) 126 | 127 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim) 128 | x = self.proj(x) 129 | x = self.proj_drop(x) 130 | 131 | # skip connection 132 | x = v.squeeze(1) + x # because the original x has different size with current x, use v to do skip connection 133 | 134 | return x 135 | 136 | 137 | class Token_transformer(nn.Module): 138 | 139 | def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 140 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 141 | super().__init__() 142 | self.norm1 = norm_layer(dim) 143 | self.attn = Attention( 144 | dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 145 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 146 | self.norm2 = norm_layer(in_dim) 147 | self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio), out_features=in_dim, act_layer=act_layer, drop=drop) 148 | 149 | def forward(self, x): 150 | x = self.attn(self.norm1(x)) 151 | x = x + self.drop_path(self.mlp(self.norm2(x))) 152 | return x 153 | 154 | 155 | class T2T(nn.Module): 156 | """ 157 | Tokens-to-Token encoding module 158 | """ 159 | def __init__(self, img_size=224, patch_size=16, tokens_type='transformer', in_chans=3, embed_dim=768, token_dim=64): 160 | super().__init__() 161 | 162 | if patch_size == 12: 163 | kernel_size = ((7, 4, 2), (3, 3, 1), (3, 1, 1)) 164 | elif patch_size == 16: 165 | kernel_size = ((7, 4, 2), (3, 2, 1), (3, 2, 1)) 166 | else: 167 | raise ValueError(f"Unknown patch size {patch_size}") 168 | 169 | self.soft_split0 = nn.Unfold(kernel_size=to_2tuple(kernel_size[0][0]), stride=to_2tuple(kernel_size[0][1]), padding=to_2tuple(kernel_size[0][2])) 170 | self.soft_split1 = nn.Unfold(kernel_size=to_2tuple(kernel_size[1][0]), stride=to_2tuple(kernel_size[1][1]), padding=to_2tuple(kernel_size[1][2])) 171 | self.soft_split2 = nn.Unfold(kernel_size=to_2tuple(kernel_size[2][0]), stride=to_2tuple(kernel_size[2][1]), padding=to_2tuple(kernel_size[2][2])) 172 | 173 | if tokens_type == 'transformer': 174 | # print('adopt transformer encoder for tokens-to-token') 175 | 176 | self.attention1 = Token_transformer(dim=in_chans * (kernel_size[0][0] ** 2), in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 177 | self.attention2 = Token_transformer(dim=token_dim * (kernel_size[1][0] ** 2), in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 178 | self.project = nn.Linear(token_dim * (kernel_size[2][0] ** 2), embed_dim) 179 | 180 | elif tokens_type == 'performer': 181 | # print('adopt performer encoder for tokens-to-token') 182 | # self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 183 | # self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 184 | # self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 185 | 186 | #self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5) 187 | #self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5) 188 | self.attention1 = Token_performer(dim=in_chans * (kernel_size[0][0] ** 2), in_dim=token_dim, kernel_ratio=0.5) 189 | self.attention2 = Token_performer(dim=token_dim * (kernel_size[1][0] ** 2), in_dim=token_dim, kernel_ratio=0.5) 190 | self.project = nn.Linear(token_dim * (kernel_size[2][0] ** 2), embed_dim) 191 | # 192 | # elif tokens_type == 'convolution': # just for comparison with convolution, not our model 193 | # # for this tokens type, forward would need to be rewritten as three convolution operations 194 | # print('adopt convolution layers for tokens-to-token') 195 | # self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution 196 | # self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution 197 | # self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution 198 | 199 | self.num_patches = (img_size // (kernel_size[0][1] * kernel_size[1][1] * kernel_size[2][1])) * (img_size // (kernel_size[0][1] * kernel_size[1][1] * kernel_size[2][1])) # there are 3 soft splits; the strides are 4, 2, 2 for patch_size 16 and 4, 3, 1 for patch_size 12 200 | 201 | def forward(self, x): 202 | # step0: soft split 203 | x = 
self.soft_split0(x).transpose(1, 2) 204 | 205 | # iteration1: re-structurization/reconstruction 206 | x = self.attention1(x) 207 | B, new_HW, C = x.shape 208 | x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 209 | # iteration1: soft split 210 | x = self.soft_split1(x).transpose(1, 2) 211 | 212 | # iteration2: re-structurization/reconstruction 213 | x = self.attention2(x) 214 | B, new_HW, C = x.shape 215 | x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 216 | # iteration2: soft split 217 | x = self.soft_split2(x).transpose(1, 2) 218 | 219 | # final tokens 220 | x = self.project(x) 221 | 222 | return x 223 | 224 | 225 | class SharedT2T(nn.Module): 226 | """ 227 | Tokens-to-Token encoding module 228 | """ 229 | def __init__(self, img_size=224, patch_size=16, tokens_type='transformer', in_chans=3, embed_dim=768, token_dim=64): 230 | super().__init__() 231 | 232 | if patch_size == 12: 233 | kernel_size = ((7, 4, 2), (3, 3, 1), (3, 1, 1)) 234 | elif patch_size == 16: 235 | kernel_size = ((7, 4, 2), (3, 2, 1), (3, 2, 1)) 236 | else: 237 | raise ValueError(f"Unknown patch size {patch_size}") 238 | 239 | 240 | if tokens_type == 'transformer': 241 | # print('adopt transformer encoder for tokens-to-token') 242 | self.soft_split0 = nn.Unfold(kernel_size=to_2tuple(kernel_size[0][0]), stride=to_2tuple(kernel_size[0][1]), padding=to_2tuple(kernel_size[0][2])) 243 | self.soft_split1 = nn.Unfold(kernel_size=to_2tuple(kernel_size[1][0]), stride=to_2tuple(kernel_size[1][1]), padding=to_2tuple(kernel_size[1][2])) 244 | self.soft_split2 = nn.Unfold(kernel_size=to_2tuple(kernel_size[2][0]), stride=to_2tuple(kernel_size[2][1]), padding=to_2tuple(kernel_size[2][2])) 245 | 246 | self.attention1 = Token_transformer(dim=in_chans * (kernel_size[0][0] ** 2), in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 247 | self.attention2 = Token_transformer(dim=token_dim * (kernel_size[1][0] ** 2), in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 248 | self.project = nn.Linear(token_dim * (kernel_size[2][0] ** 2), embed_dim) 249 | 250 | # elif tokens_type == 'performer': 251 | # print('adopt performer encoder for tokens-to-token') 252 | # self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 253 | # self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 254 | # self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 255 | # 256 | # #self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5) 257 | # #self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5) 258 | # self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5) 259 | # self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5) 260 | # self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 261 | # 262 | # elif tokens_type == 'convolution': # just for comparison with conolution, not our model 263 | # # for this tokens type, you need change forward as three convolution operation 264 | # print('adopt convolution layers for tokens-to-token') 265 | # self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution 266 | # self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution 267 | # self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution 268 | 
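# Worked example (illustrative note, derived from the kernel_size/stride settings above):
# the overall spatial reduction is the product of the three soft-split strides, so with the
# CrossViT configurations a 224-pixel branch with patch_size 16 yields (224 // 16) ** 2 = 196
# patch tokens and a 240-pixel branch with patch_size 12 yields (240 // 12) ** 2 = 400 tokens;
# the class token is appended separately in crossvit.py.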
269 | self.num_patches = (img_size // (kernel_size[0][1] * kernel_size[1][1] * kernel_size[2][1])) * (img_size // (kernel_size[0][1] * kernel_size[1][1] * kernel_size[2][1])) # there are 3 soft splits; the strides are 4, 2, 2 for patch_size 16 and 4, 3, 1 for patch_size 12 270 | 271 | def forward(self, x): 272 | # step0: soft split 273 | x = self.soft_split0(x).transpose(1, 2) 274 | 275 | # iteration1: re-structurization/reconstruction 276 | x = self.attention1(x) 277 | B, new_HW, C = x.shape 278 | x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 279 | # iteration1: soft split 280 | x = self.soft_split1(x).transpose(1, 2) 281 | 282 | # iteration2: re-structurization/reconstruction 283 | x = self.attention2(x) 284 | B, new_HW, C = x.shape 285 | x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 286 | # iteration2: soft split 287 | x = self.soft_split2(x).transpose(1, 2) 288 | 289 | # final tokens 290 | x = self.project(x) 291 | 292 | return x 293 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 3 | timm==0.4.12 4 | fvcore 5 | einops 6 | submitit -------------------------------------------------------------------------------- /run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright IBM All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | A script to run multinode training with submitit. 6 | 7 | Mostly copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py 8 | """ 9 | 10 | import argparse 11 | import os 12 | import uuid 13 | from pathlib import Path 14 | 15 | import main as classification 16 | import submitit 17 | 18 | 19 | def parse_args(): 20 | classification_parser = classification.get_args_parser() 21 | parser = argparse.ArgumentParser("Submitit for CrossViT", parents=[classification_parser]) 22 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 23 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 24 | parser.add_argument("--timeout", default=360, type=int, help="Duration of the job (in minutes)") 25 | parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.") 26 | parser.add_argument("--suffix", default="", type=str, help="Suffix appended to the model name to form the job/log folder name.") 27 | 28 | parser.add_argument("--partition", default="npl", type=str, help="Partition where to submit") 29 | return parser.parse_args() 30 | 31 | 32 | def get_shared_folder() -> Path: 33 | 34 | if not (Path.cwd() / "checkpoint").exists(): 35 | (Path.cwd() / "checkpoint").mkdir(exist_ok=True) 36 | 37 | if (Path.cwd() / "checkpoint").is_dir(): 38 | p = Path.cwd() / "checkpoint" / "experiments" 39 | p.mkdir(exist_ok=True) 40 | return p 41 | raise RuntimeError("No shared folder available") 42 | 43 | 44 | def get_init_file(): 45 | # Init file must not exist, but its parent dir must exist. 
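# Descriptive note: the returned path is consumed as a file:// URI (see get_init_file().as_uri()
# in Trainer.checkpoint() and main() below) and passed to torch.distributed as args.dist_url,
# so each submission or requeue gets a fresh rendezvous file.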
46 | os.makedirs(str(get_shared_folder()), exist_ok=True) 47 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 48 | if init_file.exists(): 49 | os.remove(str(init_file)) 50 | return init_file 51 | 52 | 53 | class Trainer(object): 54 | def __init__(self, args): 55 | self.args = args 56 | 57 | def __call__(self): 58 | import main as classification 59 | 60 | self._setup_gpu_args() 61 | classification.main(self.args) 62 | 63 | def checkpoint(self): 64 | import os 65 | import submitit 66 | 67 | self.args.seed = self.args.seed + 1 68 | self.args.dist_url = get_init_file().as_uri() 69 | checkpoint_file = os.path.join(self.args.output_dir, "checkpoint.pth") 70 | if os.path.exists(checkpoint_file): 71 | self.args.resume = checkpoint_file 72 | print("Requeuing ", self.args) 73 | empty_trainer = type(self)(self.args) 74 | return submitit.helpers.DelayedSubmission(empty_trainer) 75 | 76 | def _setup_gpu_args(self): 77 | import submitit 78 | from pathlib import Path 79 | 80 | job_env = submitit.JobEnvironment() 81 | self.args.output_dir = Path(str(self.args.output_dir).replace("%j", str(job_env.job_id))) 82 | self.args.gpu = job_env.local_rank 83 | self.args.rank = job_env.global_rank 84 | self.args.world_size = job_env.num_tasks 85 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 86 | 87 | 88 | def main(): 89 | args = parse_args() 90 | log_folder = args.model if args.suffix == '' else args.model + "-" + args.suffix 91 | if args.job_dir == "": 92 | args.job_dir = get_shared_folder() / Path(log_folder) 93 | else: 94 | if not os.path.exists(args.job_dir): 95 | os.makedirs(args.job_dir) 96 | args.job_dir = Path(os.path.join(args.job_dir, log_folder)) 97 | 98 | # Note that the folder will depend on the job_id, to easily track experiments 99 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 100 | 101 | num_gpus_per_node = args.ngpus 102 | nodes = args.nodes 103 | timeout_min = args.timeout 104 | 105 | partition = args.partition 106 | executor.update_parameters( 107 | mem_gb=40 * num_gpus_per_node, 108 | gpus_per_node=num_gpus_per_node, 109 | tasks_per_node=num_gpus_per_node, # one task per GPU 110 | cpus_per_task=10, 111 | nodes=nodes, 112 | timeout_min=timeout_min, # max is 60 * 72 113 | # Below are cluster dependent parameters 114 | slurm_partition=partition, 115 | slurm_signal_delay_s=120, 116 | slurm_gres=f'gpu:{args.ngpus}' 117 | ) 118 | 119 | job_name = log_folder 120 | executor.update_parameters(name=job_name) 121 | 122 | args.dist_url = get_init_file().as_uri() 123 | args.output_dir = args.job_dir 124 | 125 | trainer = Trainer(args) 126 | job = executor.submit(trainer) 127 | 128 | print("Submitted job_id:", job.job_id, " name: ", job_name) 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright IBM All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | Copy-paste from https://github.com/facebookresearch/deit/blob/main/samplers.py 6 | """ 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import math 11 | 12 | 13 | class RASampler(torch.utils.data.Sampler): 14 | """Sampler that restricts data loading to a subset of the dataset for distributed, 15 | with repeated augmentation. 
16 | It ensures that different each augmented version of a sample will be visible to a 17 | different process (GPU) 18 | Heavily based on torch.utils.data.DistributedSampler 19 | """ 20 | 21 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 22 | if num_replicas is None: 23 | if not dist.is_available(): 24 | raise RuntimeError("Requires distributed package to be available") 25 | num_replicas = dist.get_world_size() 26 | if rank is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | rank = dist.get_rank() 30 | self.dataset = dataset 31 | self.num_replicas = num_replicas 32 | self.rank = rank 33 | self.epoch = 0 34 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 35 | self.total_size = self.num_samples * self.num_replicas 36 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 37 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 38 | self.shuffle = shuffle 39 | 40 | def __iter__(self): 41 | # deterministically shuffle based on epoch 42 | g = torch.Generator() 43 | g.manual_seed(self.epoch) 44 | if self.shuffle: 45 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 46 | else: 47 | indices = list(range(len(self.dataset))) 48 | 49 | # add extra samples to make it evenly divisible 50 | indices = [ele for ele in indices for i in range(3)] 51 | indices += indices[:(self.total_size - len(indices))] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright IBM All Rights Reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | Misc functions, including distributed helpers. 6 | 7 | Mostly copy-paste from https://github.com/facebookresearch/deit/blob/main/utils.py 8 | """ 9 | 10 | import io 11 | import os 12 | import time 13 | from collections import defaultdict, deque 14 | import datetime 15 | 16 | import torch 17 | import torch.distributed as dist 18 | 19 | 20 | import io 21 | import os 22 | import time 23 | from collections import defaultdict, deque 24 | import datetime 25 | import tempfile 26 | import logging 27 | 28 | import torch 29 | import torch.distributed as dist 30 | from fvcore.common.checkpoint import Checkpointer 31 | 32 | 33 | class SmoothedValue(object): 34 | """Track a series of values and provide access to smoothed values over a 35 | window or the global series average. 36 | """ 37 | 38 | def __init__(self, window_size=20, fmt=None): 39 | if fmt is None: 40 | fmt = "{median:.4f} ({global_avg:.4f})" 41 | self.deque = deque(maxlen=window_size) 42 | self.total = 0.0 43 | self.count = 0 44 | self.fmt = fmt 45 | 46 | def update(self, value, n=1): 47 | self.deque.append(value) 48 | self.count += n 49 | self.total += value * n 50 | 51 | def synchronize_between_processes(self): 52 | """ 53 | Warning: does not synchronize the deque! 
54 | """ 55 | if not is_dist_avail_and_initialized(): 56 | return 57 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 58 | dist.barrier() 59 | dist.all_reduce(t) 60 | t = t.tolist() 61 | self.count = int(t[0]) 62 | self.total = t[1] 63 | 64 | @property 65 | def median(self): 66 | d = torch.tensor(list(self.deque)) 67 | return d.median().item() 68 | 69 | @property 70 | def avg(self): 71 | d = torch.tensor(list(self.deque), dtype=torch.float32) 72 | return d.mean().item() 73 | 74 | @property 75 | def global_avg(self): 76 | return self.total / self.count 77 | 78 | @property 79 | def max(self): 80 | return max(self.deque) 81 | 82 | @property 83 | def value(self): 84 | return self.deque[-1] 85 | 86 | def __str__(self): 87 | return self.fmt.format( 88 | median=self.median, 89 | avg=self.avg, 90 | global_avg=self.global_avg, 91 | max=self.max, 92 | value=self.value) 93 | 94 | 95 | class MetricLogger(object): 96 | def __init__(self, delimiter="\t"): 97 | self.meters = defaultdict(SmoothedValue) 98 | self.delimiter = delimiter 99 | 100 | def update(self, **kwargs): 101 | for k, v in kwargs.items(): 102 | if isinstance(v, torch.Tensor): 103 | v = v.item() 104 | assert isinstance(v, (float, int)) 105 | self.meters[k].update(v) 106 | 107 | def __getattr__(self, attr): 108 | if attr in self.meters: 109 | return self.meters[attr] 110 | if attr in self.__dict__: 111 | return self.__dict__[attr] 112 | raise AttributeError("'{}' object has no attribute '{}'".format( 113 | type(self).__name__, attr)) 114 | 115 | def __str__(self): 116 | loss_str = [] 117 | for name, meter in self.meters.items(): 118 | loss_str.append( 119 | "{}: {}".format(name, str(meter)) 120 | ) 121 | return self.delimiter.join(loss_str) 122 | 123 | def synchronize_between_processes(self): 124 | for meter in self.meters.values(): 125 | meter.synchronize_between_processes() 126 | 127 | def add_meter(self, name, meter): 128 | self.meters[name] = meter 129 | 130 | def log_every(self, iterable, print_freq, header=None): 131 | i = 0 132 | if not header: 133 | header = '' 134 | start_time = time.time() 135 | end = time.time() 136 | iter_time = SmoothedValue(fmt='{avg:.4f}') 137 | data_time = SmoothedValue(fmt='{avg:.4f}') 138 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 139 | log_msg = [ 140 | header, 141 | '[{0' + space_fmt + '}/{1}]', 142 | 'eta: {eta}', 143 | '{meters}', 144 | 'time: {time}', 145 | 'data: {data}' 146 | ] 147 | if torch.cuda.is_available(): 148 | log_msg.append('max mem: {memory:.0f}') 149 | log_msg = self.delimiter.join(log_msg) 150 | MB = 1024.0 * 1024.0 151 | for obj in iterable: 152 | data_time.update(time.time() - end) 153 | yield obj 154 | iter_time.update(time.time() - end) 155 | if i % print_freq == 0 or i == len(iterable) - 1: 156 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 157 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 158 | if torch.cuda.is_available(): 159 | print(log_msg.format( 160 | i, len(iterable), eta=eta_string, 161 | meters=str(self), 162 | time=str(iter_time), data=str(data_time), 163 | memory=torch.cuda.max_memory_allocated() / MB)) 164 | else: 165 | print(log_msg.format( 166 | i, len(iterable), eta=eta_string, 167 | meters=str(self), 168 | time=str(iter_time), data=str(data_time))) 169 | i += 1 170 | end = time.time() 171 | total_time = time.time() - start_time 172 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 173 | print('{} Total time: {} ({:.4f} s / it)'.format( 174 | header, total_time_str, 
total_time / len(iterable))) 175 | 176 | 177 | def _load_checkpoint_for_ema(model_ema, checkpoint): 178 | """ 179 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 180 | """ 181 | mem_file = io.BytesIO() 182 | torch.save(checkpoint, mem_file) 183 | mem_file.seek(0) 184 | model_ema._load_checkpoint(mem_file) 185 | 186 | 187 | def load_checkpoint(model, state_dict, mode=None): 188 | 189 | # reuse Checkpointer in fvcore to support flexible loading 190 | ckpt = Checkpointer(model, save_to_disk=False) 191 | logging.basicConfig() 192 | ckpt.logger.setLevel(logging.INFO) 193 | # since Checkpointer requires the weight to be put under `model` field, we need to save it to disk 194 | tmp_path = tempfile.NamedTemporaryFile('w+b') 195 | torch.save({'model': state_dict}, tmp_path.name) 196 | ckpt.load(tmp_path.name) 197 | 198 | def setup_for_distributed(is_master): 199 | """ 200 | This function disables printing when not in master process 201 | """ 202 | import builtins as __builtin__ 203 | builtin_print = __builtin__.print 204 | 205 | def print(*args, **kwargs): 206 | force = kwargs.pop('force', False) 207 | if is_master or force: 208 | builtin_print(*args, **kwargs) 209 | 210 | __builtin__.print = print 211 | 212 | 213 | def is_dist_avail_and_initialized(): 214 | if not dist.is_available(): 215 | return False 216 | if not dist.is_initialized(): 217 | return False 218 | return True 219 | 220 | 221 | def get_world_size(): 222 | if not is_dist_avail_and_initialized(): 223 | return 1 224 | return dist.get_world_size() 225 | 226 | 227 | def get_rank(): 228 | if not is_dist_avail_and_initialized(): 229 | return 0 230 | return dist.get_rank() 231 | 232 | 233 | def is_main_process(): 234 | return get_rank() == 0 235 | 236 | 237 | def save_on_master(*args, **kwargs): 238 | if is_main_process(): 239 | torch.save(*args, **kwargs) 240 | 241 | 242 | def init_distributed_mode(args): 243 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 244 | args.rank = int(os.environ["RANK"]) 245 | args.world_size = int(os.environ['WORLD_SIZE']) 246 | args.gpu = int(os.environ['LOCAL_RANK']) 247 | elif 'SLURM_PROCID' in os.environ: 248 | args.rank = int(os.environ['SLURM_PROCID']) 249 | args.gpu = args.rank % torch.cuda.device_count() 250 | else: 251 | print('Not using distributed mode') 252 | args.distributed = False 253 | return 254 | 255 | args.distributed = True 256 | 257 | torch.cuda.set_device(args.gpu) 258 | args.dist_backend = 'nccl' 259 | print('| distributed init (rank {}): {}'.format( 260 | args.rank, args.dist_url), flush=True) 261 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 262 | world_size=args.world_size, rank=args.rank) 263 | torch.distributed.barrier() 264 | setup_for_distributed(args.rank == 0) 265 | --------------------------------------------------------------------------------
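Usage sketch (illustrative only, not part of the repository): the constructors in models/crossvit.py are registered via @register_model, so a model can be built by calling a constructor directly, or through timm.create_model('crossvit_tiny_224') once the models package has been imported. The example below assumes the repository root is on PYTHONPATH, the packages in requirements.txt are installed, and the default num_classes of 1000 is used.

import torch
from models.crossvit import crossvit_tiny_224

# Build CrossViT-Tiny without pretrained weights and run a dummy forward pass.
# The model is configured with img_size=[240, 224]: the 224x224 input is resized
# (bicubic) to 240x240 for the small-patch branch, and the returned logits are
# the mean of the two branch classification heads.
model = crossvit_tiny_224(pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 1000]) with the default num_classes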