├── .github └── workflows │ └── jekyll-gh-pages.yml ├── README.md ├── hierarchical_architecture ├── README.md ├── checkpoints │ ├── spectformer_b │ │ └── log.txt │ ├── spectformer_l │ │ └── log.txt │ └── spectformer_s │ │ └── log.txt ├── configs │ └── spectformer │ │ ├── spectformer_b.py │ │ ├── spectformer_l.py │ │ └── spectformer_s.py ├── datasets.py ├── engine.py ├── loss │ ├── __init__.py │ └── cross_entropy.py ├── losses.py ├── main.py ├── main.sh ├── mcloader │ ├── __init__.py │ ├── classification.py │ ├── data_prefetcher.py │ ├── image_list.py │ ├── imagenet.py │ └── mcloader.py ├── samplers.py ├── spectformer.py ├── util │ ├── __init__.py │ ├── checkpoint_saver.py │ ├── flops_counter.py │ └── util.py └── utils.py └── vanilla_architecture ├── README.md ├── datasets.py ├── engine.py ├── figs ├── GFNet_filter.jpg ├── SpectFormer.png ├── SpectFormer_filter.jpg ├── SpectFormer_main.png ├── inference.png └── sota.jpg ├── infer.py ├── logs ├── spectformer-b │ └── log.txt ├── spectformer-s │ └── log.txt ├── spectformer-t │ └── log.txt └── spectformer-xs │ └── log.txt ├── losses.py ├── main.sh ├── main_spectformer.py ├── main_spectformer_transfer.py ├── samplers.py ├── spectformer.py └── utils.py /.github/workflows/jekyll-gh-pages.yml: -------------------------------------------------------------------------------- 1 | # Sample workflow for building and deploying a Jekyll site to GitHub Pages 2 | name: Deploy Jekyll with GitHub Pages dependencies preinstalled 3 | 4 | on: 5 | # Runs on pushes targeting the default branch 6 | push: 7 | branches: ["main"] 8 | 9 | # Allows you to run this workflow manually from the Actions tab 10 | workflow_dispatch: 11 | 12 | # Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages 13 | permissions: 14 | contents: read 15 | pages: write 16 | id-token: write 17 | 18 | # Allow only one concurrent deployment, skipping runs queued between the run in-progress and latest queued. 19 | # However, do NOT cancel in-progress runs as we want to allow these production deployments to complete. 20 | concurrency: 21 | group: "pages" 22 | cancel-in-progress: false 23 | 24 | jobs: 25 | # Build job 26 | build: 27 | runs-on: ubuntu-latest 28 | steps: 29 | - name: Checkout 30 | uses: actions/checkout@v3 31 | - name: Setup Pages 32 | uses: actions/configure-pages@v3 33 | - name: Build with Jekyll 34 | uses: actions/jekyll-build-pages@v1 35 | with: 36 | source: ./ 37 | destination: ./_site 38 | - name: Upload artifact 39 | uses: actions/upload-pages-artifact@v1 40 | 41 | # Deployment job 42 | deploy: 43 | environment: 44 | name: github-pages 45 | url: ${{ steps.deployment.outputs.page_url }} 46 | runs-on: ubuntu-latest 47 | needs: build 48 | steps: 49 | - name: Deploy to GitHub Pages 50 | id: deployment 51 | uses: actions/deploy-pages@v2 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpectFormer: Frequency and Attention is what you need in a Vision Transformer 2 | 3 | 4 | ![Intro](vanilla_architecture/figs/SpectFormer.png) 5 | 6 | [![Paper](http://img.shields.io/badge/Paper-arxiv.2304.06446-B31B1B.svg)](https://arxiv.org/abs/2304.06446) 7 | [![Project Page](https://img.shields.io/badge/Project%20Page-SpectFormer-B31B1B.svg)](https://badripatro.github.io/SpectFormers/) 8 | 9 | ## Abstract 10 | 11 | ''' 12 | Vision transformers have been applied successfully for image recognition tasks. 
There have been either multi-headed self-attention based approaches (ViT \cite{dosovitskiy2020image}, DeiT \cite{touvron2021training}), similar to the original work on textual models, or, more recently, approaches based on spectral layers (FNet \cite{lee2021fnet}, GFNet \cite{rao2021global}, AFNO \cite{guibas2021efficient}). We hypothesize that both spectral and multi-headed attention play a major role. We investigate this hypothesis in this work and observe that combining spectral and multi-headed attention layers indeed provides a better transformer architecture. We thus propose the novel SpectFormer architecture for transformers, which combines spectral and multi-headed attention layers. We believe that the resulting representation allows the transformer to capture the feature representation appropriately, and it yields improved performance over other transformer representations. For instance, it improves the top-1 accuracy by 2\% on ImageNet compared to both GFNet-H and LiT. SpectFormer-S reaches 84.25\% top-1 accuracy on ImageNet-1K (state of the art for the small version). Further, SpectFormer-L achieves 85.7\%, which is the state of the art for the comparable base version of the transformers. We further ensure that we obtain reasonable results in other scenarios such as transfer learning on standard datasets such as CIFAR-10, CIFAR-100, Oxford-IIIT-flower, and Stanford Cars datasets. We then investigate its use in downstream tasks such as object detection and instance segmentation on the MS-COCO dataset and observe that SpectFormer shows consistent performance that is comparable to the best backbones and can be further optimized and improved. Hence, we believe that combined spectral and attention layers are what is needed for vision transformers. 13 | 14 | ''' 15 | 16 | ![Main Model](vanilla_architecture/figs/SpectFormer_main.png) 17 | 18 | 19 | 20 | ## SOTA Performance on the ImageNet-1K Dataset (Image Size 224 x 224) for the Image Recognition Task 21 | 22 | ![SOTA](vanilla_architecture/figs/sota.jpg) 23 | 24 | 25 | 26 | ## Training 27 | 28 | ### Train SpectFormer for Vanilla Architecture 29 | ``` 30 | bash vanilla_architecture/main.sh 31 | ``` 32 | 33 | 34 | ### Train SpectFormer for Hierarchical Architecture 35 | ``` 36 | bash hierarchical_architecture/main.sh 37 | ``` 38 | 39 | ## Inference Results 40 | ![Inference](vanilla_architecture/figs/inference.png) 41 | 42 | 43 | ## Citation 44 | 45 | ``` 46 | @article{patro2023spectformer, 47 | title={SpectFormer: Frequency and Attention is what you need in a Vision Transformer}, 48 | author={Patro, Badri N. and Namboodiri, Vinay P. and Agneeswaran, Vijay Srinivas}, 49 | journal={arXiv preprint arXiv:2304.06446}, 50 | year = {2023} 51 | } 52 | ``` 53 | 54 | # Acknowledgements 55 | Thanks to the contributions of [DeiT](https://github.com/facebookresearch/deit), [WaveVit](https://github.com/YehLi/ImageNetModel) and [GFNet](https://github.com/raoyongming/GFNet). -------------------------------------------------------------------------------- /hierarchical_architecture/README.md: -------------------------------------------------------------------------------- 1 | # SpectFormer Hierarchical Model 2 | 3 | ### Requirements: 4 | * PyTorch 1.10.0+ 5 | * Python 3.8 6 | * CUDA 10.1+ 7 | * [timm](https://github.com/rwightman/pytorch-image-models)==0.4.5 8 | * [tlt](https://github.com/zihangJiang/TokenLabeling)==0.1.0 9 | * pyyaml 10 | * apex-amp 11 | 12 | 13 | ## Data Preparation 14 | 15 | Download and extract ImageNet images from http://image-net.org/.
The directory structure should be 16 | 17 | ``` 18 | 19 | │ILSVRC2012/ 20 | ├──train/ 21 | │ ├── n01440764 22 | │ │ ├── n01440764_10026.JPEG 23 | │ │ ├── n01440764_10027.JPEG 24 | │ │ ├── ...... 25 | │ ├── ...... 26 | ├──val/ 27 | │ ├── n01440764 28 | │ │ ├── ILSVRC2012_val_00000293.JPEG 29 | │ │ ├── ILSVRC2012_val_00002138.JPEG 30 | │ │ ├── ...... 31 | │ ├── ...... 32 | 33 | ``` 34 | 35 | 36 | ### Model Zoo 37 | 38 | We provide baseline SVT Hierarchical models pre-trained on ImageNet1k 2012, using the distilled version of our method: 39 | 40 | | name | resolution | #params | FLOPs | Top-1 Acc. | Top-5 Acc. | 41 | | :---: | :---: | :---: | :---: | :---: | :---: | 42 | | SpectFormer-H-S | 224 | 22.2M | 3.9 | 84.3 | 96.9 | 43 | | SpectFormer-H-B | 224 | 33.1M | 6.3 | 85.0 | 97.1 | 44 | | SpectFormer-H-L | 224 | 54.7M | 12.7 | 85.7 | 97.3 | 45 | 46 | 47 | ### Train SpectFormer small model 48 | ``` 49 | python3 -m torch.distributed.launch \ 50 | --nproc_per_node=8 \ 51 | --nnodes=1 \ 52 | --node_rank=0 \ 53 | --master_addr="localhost" \ 54 | --master_port=12346 \ 55 | --use_env main.py --config configs/spectformer/spectformer_s.py --data-path /export/home/dataset/imagenet --epochs 310 --batch-size 128 \ 56 | --token-label --token-label-size 7 --token-label-data /export/home/dataset/imagenet/label_top5_train_nfnet 57 | ``` 58 | 59 | 60 | ### Train SpectFormer Base model 61 | ``` 62 | python3 -m torch.distributed.launch \ 63 | --nproc_per_node=8 \ 64 | --nnodes=1 \ 65 | --node_rank=0 \ 66 | --master_addr="localhost" \ 67 | --master_port=12346 \ 68 | --use_env main.py --config configs/spectformer/spectformer_b.py --data-path /export/home/dataset/imagenet --epochs 310 --batch-size 128 \ 69 | --token-label --token-label-size 7 --token-label-data /export/home/dataset/imagenet/label_top5_train_nfnet 70 | ``` 71 | 72 | 73 | 74 | ### Train SpectFormer Large model 75 | ``` 76 | python3 -m torch.distributed.launch \ 77 | --nproc_per_node=8 \ 78 | --nnodes=1 \ 79 | --node_rank=0 \ 80 | --master_addr="localhost" \ 81 | --master_port=12346 \ 82 | --use_env main.py --config configs/spectformer/spectformer_l.py --data-path /export/home/dataset/imagenet --epochs 310 --batch-size 128 \ 83 | --token-label --token-label-size 7 --token-label-data /export/home/dataset/imagenet/label_top5_train_nfnet 84 | ``` 85 | 86 | 87 | ## Citation 88 | 89 | If you find this repo helpful, please consider citing us. 90 | 91 | ``` 92 | @article{patro2023spectformer, 93 | title={SpectFormer: Frequency and Attention is what you need in a Vision Transformer}, 94 | author={Patro, Badri N and Namboodiri, Vinay P and Agneeswaran, Vijay Srinivas}, 95 | journal={arXiv preprint arXiv:2304.06446}, 96 | year={2023} 97 | } 98 | 99 | ``` 100 | 101 | 102 | # Acknowledgements 103 | Our code is based on [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), [DeiT](https://github.com/facebookresearch/deit), [WaveVit](https://github.com/YehLi/ImageNetModel) and [GFNet](https://github.com/raoyongming/GFNet). 
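For reference, below is a minimal sketch (not part of the original training scripts) of one way to restore a saved hierarchical checkpoint for evaluation. The checkpoint filename and the `state_dict` key are assumptions based on the format written by `util/checkpoint_saver.py`; adjust them to your actual run.

```
import torch
from spectformer import spectformer_s

# build the small hierarchical model (token-labeling head enabled by default)
model = spectformer_s(num_classes=1000)

# 'last.pth.tar' and the 'state_dict' key follow the CheckpointSaver2 format (assumed path)
checkpoint = torch.load('checkpoints/spectformer_s/last.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'], strict=False)
model.eval()

# random input only to check shapes; in eval mode the aux head is folded into the logits
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```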
104 | -------------------------------------------------------------------------------- /hierarchical_architecture/configs/spectformer/spectformer_b.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='spectformer_b', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/spectformer_b', 6 | ) -------------------------------------------------------------------------------- /hierarchical_architecture/configs/spectformer/spectformer_l.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='spectformer_l', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/spectformer_l', 6 | ) -------------------------------------------------------------------------------- /hierarchical_architecture/configs/spectformer/spectformer_s.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='spectformer_s', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/spectformer_s', 6 | ) -------------------------------------------------------------------------------- /hierarchical_architecture/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torchvision import datasets, transforms 5 | from torchvision.datasets.folder import ImageFolder, default_loader 6 | 7 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 8 | from timm.data import create_transform 9 | from mcloader import ClassificationDataset 10 | 11 | 12 | class INatDataset(ImageFolder): 13 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 14 | category='name', loader=default_loader): 15 | self.transform = transform 16 | self.loader = loader 17 | self.target_transform = target_transform 18 | self.year = year 19 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 20 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 21 | with open(path_json) as json_file: 22 | data = json.load(json_file) 23 | 24 | with open(os.path.join(root, 'categories.json')) as json_file: 25 | data_catg = json.load(json_file) 26 | 27 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 28 | 29 | with open(path_json_for_targeter) as json_file: 30 | data_for_targeter = json.load(json_file) 31 | 32 | targeter = {} 33 | indexer = 0 34 | for elem in data_for_targeter['annotations']: 35 | king = [] 36 | king.append(data_catg[int(elem['category_id'])][category]) 37 | if king[0] not in targeter.keys(): 38 | targeter[king[0]] = indexer 39 | indexer += 1 40 | self.nb_classes = len(targeter) 41 | 42 | self.samples = [] 43 | for elem in data['images']: 44 | cut = elem['file_name'].split('/') 45 | target_current = int(cut[2]) 46 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 47 | 48 | categors = data_catg[target_current] 49 | target_current_true = targeter[categors[category]] 50 | self.samples.append((path_current, target_current_true)) 51 | 52 | # __getitem__ and __len__ inherited from ImageFolder 53 | 54 | 55 | def build_dataset(is_train, args): 56 | transform = build_transform(is_train, args) 57 | 58 | if args.data_set == 'CIFAR': 59 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 60 | nb_classes = 100 61 | elif args.data_set == 'IMNET': 62 | if not args.use_mcloader: 63 | root = os.path.join(args.data_path, 
'train' if is_train else 'val') 64 | dataset = datasets.ImageFolder(root, transform=transform) 65 | else: 66 | dataset = ClassificationDataset( 67 | 'train' if is_train else 'val', 68 | pipeline=transform 69 | ) 70 | nb_classes = 1000 71 | elif args.data_set == 'INAT': 72 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 73 | category=args.inat_category, transform=transform) 74 | nb_classes = dataset.nb_classes 75 | elif args.data_set == 'INAT19': 76 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 77 | category=args.inat_category, transform=transform) 78 | nb_classes = dataset.nb_classes 79 | 80 | return dataset, nb_classes 81 | 82 | 83 | def build_transform(is_train, args): 84 | resize_im = args.input_size > 32 85 | if is_train: 86 | # this should always dispatch to transforms_imagenet_train 87 | transform = create_transform( 88 | input_size=args.input_size, 89 | is_training=True, 90 | color_jitter=args.color_jitter, 91 | auto_augment=args.aa, 92 | interpolation=args.train_interpolation, 93 | re_prob=args.reprob, 94 | re_mode=args.remode, 95 | re_count=args.recount, 96 | ) 97 | if not resize_im: 98 | # replace RandomResizedCropAndInterpolation with 99 | # RandomCrop 100 | transform.transforms[0] = transforms.RandomCrop( 101 | args.input_size, padding=4) 102 | return transform 103 | 104 | t = [] 105 | if resize_im: 106 | #size = int((256 / 224) * args.input_size) 107 | size = int((1.0 / 0.96) * args.input_size) 108 | t.append( 109 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 110 | ) 111 | t.append(transforms.CenterCrop(args.input_size)) 112 | 113 | t.append(transforms.ToTensor()) 114 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 115 | return transforms.Compose(t) 116 | -------------------------------------------------------------------------------- /hierarchical_architecture/engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | """ 4 | import math 5 | import sys 6 | from typing import Iterable, Optional 7 | import logging 8 | import torch 9 | 10 | from timm.data import Mixup 11 | from timm.utils import accuracy, ModelEma 12 | from tlt.data import create_token_label_target 13 | from losses import DistillationLoss 14 | import utils 15 | 16 | 17 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 18 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 19 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 20 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 21 | set_training_mode=True, 22 | fp32=False, args=None): 23 | model.train(set_training_mode) 24 | metric_logger = utils.MetricLogger(delimiter=" ") 25 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | header = 'Epoch: [{}]'.format(epoch) 27 | print_freq = 10 28 | 29 | _logger = logging.getLogger('train') 30 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 31 | if not args.prefetcher: 32 | samples = samples.to(device, non_blocking=True) 33 | targets = targets.to(device, non_blocking=True) 34 | 35 | if mixup_fn is not None: 36 | samples, targets = mixup_fn(samples, targets) 37 | else: 38 | if args.token_label and args.token_label_data and not data_loader.mixup_enabled: 39 | targets = create_token_label_target( 40 | targets, 41 | num_classes=args.nb_classes, 42 | smoothing=args.smoothing, 43 | 
label_size=args.token_label_size) 44 | 45 | with torch.cuda.amp.autocast(enabled=not fp32): 46 | outputs = model(samples) 47 | 48 | if args.token_label: 49 | loss = criterion(outputs, targets) 50 | else: 51 | loss = criterion(samples, outputs, targets) 52 | 53 | loss_value = loss.item() 54 | 55 | if not math.isfinite(loss_value): 56 | _logger.info("Loss is {}, stopping training".format(loss_value)) 57 | sys.exit(1) 58 | 59 | optimizer.zero_grad() 60 | 61 | # this attribute is added by timm on one optimizer (adahessian) 62 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 63 | loss_scaler(loss, optimizer, clip_grad=max_norm, 64 | parameters=model.parameters(), create_graph=is_second_order) 65 | 66 | torch.cuda.synchronize() 67 | if model_ema is not None: 68 | model_ema.update(model) 69 | 70 | metric_logger.update(loss=loss_value) 71 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 72 | # gather the stats from all processes 73 | metric_logger.synchronize_between_processes() 74 | _logger.info("Averaged stats:" + str(metric_logger)) 75 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 76 | 77 | 78 | @torch.no_grad() 79 | def evaluate(data_loader, model, device): 80 | criterion = torch.nn.CrossEntropyLoss() 81 | 82 | metric_logger = utils.MetricLogger(delimiter=" ") 83 | _logger = logging.getLogger('train') 84 | header = 'Test:' 85 | 86 | # switch to evaluation mode 87 | model.eval() 88 | 89 | for images, target in metric_logger.log_every(data_loader, 10, header): 90 | images = images.to(device, non_blocking=True) 91 | target = target.to(device, non_blocking=True) 92 | 93 | # compute output 94 | with torch.cuda.amp.autocast(): 95 | output = model(images) 96 | loss = criterion(output, target) 97 | 98 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 99 | 100 | batch_size = images.shape[0] 101 | metric_logger.update(loss=loss.item()) 102 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 103 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 104 | # gather the stats from all processes 105 | metric_logger.synchronize_between_processes() 106 | _logger.info('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 107 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 108 | 109 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 110 | -------------------------------------------------------------------------------- /hierarchical_architecture/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy import TokenLabelGTCrossEntropy, TokenLabelSoftTargetCrossEntropy, TokenLabelCrossEntropy -------------------------------------------------------------------------------- /hierarchical_architecture/loss/cross_entropy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Sea Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Loss functions for VOLO 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | 20 | class SoftTargetCrossEntropy(nn.Module): 21 | """ 22 | The native CE loss with soft target 23 | input: x is output of model, target is ground truth 24 | return: loss 25 | """ 26 | def __init__(self): 27 | super(SoftTargetCrossEntropy, self).__init__() 28 | 29 | def forward(self, x, target): 30 | N_rep = x.shape[0] 31 | N = target.shape[0] 32 | if not N == N_rep: 33 | target = target.repeat(N_rep // N, 1) 34 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 35 | return loss.mean() 36 | 37 | 38 | class TokenLabelGTCrossEntropy(nn.Module): 39 | """ 40 | Token labeling dense loss with ground gruth, see more from token labeling 41 | input: x is output of model, target is ground truth 42 | return: loss 43 | """ 44 | def __init__(self, 45 | dense_weight=1.0, 46 | cls_weight=1.0, 47 | mixup_active=True, 48 | smoothing=0.1, 49 | classes=1000): 50 | super(TokenLabelGTCrossEntropy, self).__init__() 51 | 52 | self.CE = SoftTargetCrossEntropy() 53 | 54 | self.dense_weight = dense_weight 55 | self.smoothing = smoothing 56 | self.mixup_active = mixup_active 57 | self.classes = classes 58 | self.cls_weight = cls_weight 59 | assert dense_weight + cls_weight > 0 60 | 61 | def forward(self, x, target): 62 | 63 | output, aux_output, bb = x 64 | bbx1, bby1, bbx2, bby2 = bb 65 | 66 | B, N, C = aux_output.shape 67 | if len(target.shape) == 2: 68 | target_cls = target 69 | target_aux = target.repeat(1, N).reshape(B * N, C) 70 | else: 71 | ground_truth = target[:, :, 0] 72 | target_cls = target[:, :, 1] 73 | ratio = (0.9 - 0.4 * 74 | (ground_truth.max(-1)[1] == target_cls.max(-1)[1]) 75 | ).unsqueeze(-1) 76 | target_cls = target_cls * ratio + ground_truth * (1 - ratio) 77 | target_aux = target[:, :, 2:] 78 | target_aux = target_aux.transpose(1, 2).reshape(-1, C) 79 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / N) 80 | if lam < 1: 81 | target_cls = lam * target_cls + (1 - lam) * target_cls.flip(0) 82 | 83 | aux_output = aux_output.reshape(-1, C) 84 | 85 | loss_cls = self.CE(output, target_cls) 86 | loss_aux = self.CE(aux_output, target_aux) 87 | 88 | return self.cls_weight * loss_cls + self.dense_weight * loss_aux 89 | 90 | 91 | class TokenLabelSoftTargetCrossEntropy(nn.Module): 92 | """ 93 | Token labeling dense loss with soft target, see more from token labeling 94 | input: x is output of model, target is ground truth 95 | return: loss 96 | """ 97 | def __init__(self): 98 | super(TokenLabelSoftTargetCrossEntropy, self).__init__() 99 | 100 | def forward(self, x, target): 101 | N_rep = x.shape[0] 102 | N = target.shape[0] 103 | if not N == N_rep: 104 | target = target.repeat(N_rep // N, 1) 105 | if len(target.shape) == 3 and target.shape[-1] == 2: 106 | target = target[:, :, 1] 107 | loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1) 108 | return loss.mean() 109 | 110 | 111 | class TokenLabelCrossEntropy(nn.Module): 112 | """ 113 | Token labeling loss without ground truth 114 | input: x is output of model, target is ground truth 115 | return: loss 116 | """ 117 | def __init__(self, 118 | dense_weight=1.0, 119 | cls_weight=1.0, 120 | mixup_active=True, 121 | classes=1000): 122 | """ 123 | Constructor Token labeling loss. 
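        Args (documented here for clarity; semantics follow the forward() implementation below):
            dense_weight: weight on the dense per-token loss computed from aux_output.
            cls_weight: weight on the class-token loss computed from the main output.
            mixup_active: flag indicating whether mixup/cutmix-style targets are expected
                (stored on the module but not referenced directly in this loss).
            classes: number of target classes.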
124 | """ 125 | super(TokenLabelCrossEntropy, self).__init__() 126 | 127 | self.CE = SoftTargetCrossEntropy() 128 | 129 | self.dense_weight = dense_weight 130 | self.mixup_active = mixup_active 131 | self.classes = classes 132 | self.cls_weight = cls_weight 133 | assert dense_weight + cls_weight > 0 134 | 135 | def forward(self, x, target): 136 | 137 | output, aux_output, bb = x 138 | bbx1, bby1, bbx2, bby2 = bb 139 | 140 | B, N, C = aux_output.shape 141 | if len(target.shape) == 2: 142 | target_cls = target 143 | target_aux = target.repeat(1, N).reshape(B * N, C) 144 | else: 145 | target_cls = target[:, :, 1] 146 | target_aux = target[:, :, 2:] 147 | target_aux = target_aux.transpose(1, 2).reshape(-1, C) 148 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / N) 149 | if lam < 1: 150 | target_cls = lam * target_cls + (1 - lam) * target_cls.flip(0) 151 | 152 | aux_output = aux_output.reshape(-1, C) 153 | loss_cls = self.CE(output, target_cls) 154 | loss_aux = self.CE(aux_output, target_aux) 155 | return self.cls_weight * loss_cls + self.dense_weight * loss_aux 156 | -------------------------------------------------------------------------------- /hierarchical_architecture/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the knowledge distillation loss 3 | """ 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | class DistillationLoss(torch.nn.Module): 9 | """ 10 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 11 | taking a teacher model prediction and using it as additional supervision. 12 | """ 13 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 14 | distillation_type: str, alpha: float, tau: float): 15 | super().__init__() 16 | self.base_criterion = base_criterion 17 | self.teacher_model = teacher_model 18 | assert distillation_type in ['none', 'soft', 'hard'] 19 | self.distillation_type = distillation_type 20 | self.alpha = alpha 21 | self.tau = tau 22 | 23 | def forward(self, inputs, outputs, labels): 24 | """ 25 | Args: 26 | inputs: The original inputs that are feed to the teacher model 27 | outputs: the outputs of the model to be trained. 
It is expected to be 28 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 29 | in the first position and the distillation predictions as the second output 30 | labels: the labels for the base criterion 31 | """ 32 | outputs_kd = None 33 | if not isinstance(outputs, torch.Tensor): 34 | # assume that the model outputs a tuple of [outputs, outputs_kd] 35 | outputs, outputs_kd = outputs 36 | base_loss = self.base_criterion(outputs, labels) 37 | if self.distillation_type == 'none': 38 | return base_loss 39 | 40 | if outputs_kd is None: 41 | raise ValueError("When knowledge distillation is enabled, the model is " 42 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 43 | "class_token and the dist_token") 44 | # don't backprop throught the teacher 45 | with torch.no_grad(): 46 | teacher_outputs = self.teacher_model(inputs) 47 | 48 | if self.distillation_type == 'soft': 49 | T = self.tau 50 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 51 | # with slight modifications 52 | distillation_loss = F.kl_div( 53 | F.log_softmax(outputs_kd / T, dim=1), 54 | F.log_softmax(teacher_outputs / T, dim=1), 55 | reduction='sum', 56 | log_target=True 57 | ) * (T * T) / outputs_kd.numel() 58 | elif self.distillation_type == 'hard': 59 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 60 | 61 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 62 | return loss 63 | -------------------------------------------------------------------------------- /hierarchical_architecture/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Current path is $PATH" 4 | echo "Running" 5 | nvidia-smi 6 | echo $CUDA_VISIBLE_DEVICES 7 | 8 | # # spectformer_s 9 | # python3 -m torch.distributed.launch \ 10 | # --nproc_per_node=8 \ 11 | # --nnodes=1 \ 12 | # --node_rank=0 \ 13 | # --master_addr="localhost" \ 14 | # --master_port=12346 \ 15 | # --use_env main.py --config configs/spectformer/spectformer_s.py --data-path ../../../../dataset/Image_net/imagenet --epochs 310 --batch-size 128 \ 16 | # --token-label --token-label-size 7 --token-label-data ../../../../dataset/Image_net/imagenet_efficientnet_l2_sz475_top5/ 17 | 18 | 19 | # spectformer_l 20 | python3 -m torch.distributed.launch \ 21 | --nproc_per_node=8 \ 22 | --nnodes=1 \ 23 | --node_rank=0 \ 24 | --master_addr="localhost" \ 25 | --master_port=12346 \ 26 | --use_env main.py --config configs/spectformer/spectformer_l.py --data-path ../../../../dataset/Image_net/imagenet --epochs 310 --batch-size 128 \ 27 | --token-label --token-label-size 7 --token-label-data ../../../../dataset/Image_net/imagenet_efficientnet_l2_sz475_top5/ -------------------------------------------------------------------------------- /hierarchical_architecture/mcloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import ClassificationDataset 2 | from .data_prefetcher import DataPrefetcher -------------------------------------------------------------------------------- /hierarchical_architecture/mcloader/classification.py: -------------------------------------------------------------------------------- 1 | # code from PVT(https://github.com/whai362/PVT) 2 | import torch 3 | from torch.utils.data import Dataset 4 | from .imagenet import ImageNet 5 | 6 | 7 | class ClassificationDataset(Dataset): 8 | """Dataset for classification. 
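    Wraps the memcached-backed ImageNet reader (mcloader.ImageNet): `split`
    selects the train or val file list and `pipeline` is an optional
    torchvision-style transform applied to each image in __getitem__.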
9 | """ 10 | 11 | def __init__(self, split='train', pipeline=None): 12 | if split == 'train': 13 | self.data_source = ImageNet(root='data/imagenet/train', 14 | list_file='data/imagenet/meta/train.txt', 15 | memcached=True, 16 | mclient_path='/mnt/lustre/share/memcached_client') 17 | else: 18 | self.data_source = ImageNet(root='data/imagenet/val', 19 | list_file='data/imagenet/meta/val.txt', 20 | memcached=True, 21 | mclient_path='/mnt/lustre/share/memcached_client') 22 | self.pipeline = pipeline 23 | 24 | def __len__(self): 25 | return self.data_source.get_length() 26 | 27 | def __getitem__(self, idx): 28 | img, target = self.data_source.get_sample(idx) 29 | if self.pipeline is not None: 30 | img = self.pipeline(img) 31 | 32 | return img, target 33 | -------------------------------------------------------------------------------- /hierarchical_architecture/mcloader/data_prefetcher.py: -------------------------------------------------------------------------------- 1 | # code from PVT(https://github.com/whai362/PVT) 2 | import torch 3 | 4 | 5 | class DataPrefetcher: 6 | def __init__(self, loader): 7 | self.loader = iter(loader) 8 | self.stream = torch.cuda.Stream() 9 | self.preload() 10 | 11 | def preload(self): 12 | try: 13 | self.next_input, self.next_target = next(self.loader) 14 | except StopIteration: 15 | self.next_input = None 16 | self.next_target = None 17 | return 18 | 19 | with torch.cuda.stream(self.stream): 20 | self.next_input = self.next_input.cuda(non_blocking=True) 21 | self.next_target = self.next_target.cuda(non_blocking=True) 22 | 23 | def next(self): 24 | torch.cuda.current_stream().wait_stream(self.stream) 25 | input = self.next_input 26 | target = self.next_target 27 | if input is not None: 28 | self.preload() 29 | return input, target 30 | -------------------------------------------------------------------------------- /hierarchical_architecture/mcloader/image_list.py: -------------------------------------------------------------------------------- 1 | # code from PVT(https://github.com/whai362/PVT) 2 | import os 3 | from PIL import Image 4 | 5 | from .mcloader import McLoader 6 | 7 | 8 | class ImageList(object): 9 | 10 | def __init__(self, root, list_file, memcached=False, mclient_path=None): 11 | with open(list_file, 'r') as f: 12 | lines = f.readlines() 13 | self.has_labels = len(lines[0].split()) == 2 14 | if self.has_labels: 15 | self.fns, self.labels = zip(*[l.strip().split() for l in lines]) 16 | self.labels = [int(l) for l in self.labels] 17 | else: 18 | self.fns = [l.strip() for l in lines] 19 | self.fns = [os.path.join(root, fn) for fn in self.fns] 20 | self.memcached = memcached 21 | self.mclient_path = mclient_path 22 | self.initialized = False 23 | 24 | def _init_memcached(self): 25 | if not self.initialized: 26 | assert self.mclient_path is not None 27 | self.mc_loader = McLoader(self.mclient_path) 28 | self.initialized = True 29 | 30 | def get_length(self): 31 | return len(self.fns) 32 | 33 | def get_sample(self, idx): 34 | if self.memcached: 35 | self._init_memcached() 36 | if self.memcached: 37 | img = self.mc_loader(self.fns[idx]) 38 | else: 39 | img = Image.open(self.fns[idx]) 40 | img = img.convert('RGB') 41 | if self.has_labels: 42 | target = self.labels[idx] 43 | return img, target 44 | else: 45 | return img 46 | -------------------------------------------------------------------------------- /hierarchical_architecture/mcloader/imagenet.py: -------------------------------------------------------------------------------- 1 | # code from 
PVT(https://github.com/whai362/PVT) 2 | from .image_list import ImageList 3 | 4 | 5 | class ImageNet(ImageList): 6 | 7 | def __init__(self, root, list_file, memcached, mclient_path): 8 | super(ImageNet, self).__init__( 9 | root, list_file, memcached, mclient_path) 10 | -------------------------------------------------------------------------------- /hierarchical_architecture/mcloader/mcloader.py: -------------------------------------------------------------------------------- 1 | # code from PVT(https://github.com/whai362/PVT) 2 | 3 | import io 4 | from PIL import Image 5 | try: 6 | import mc 7 | except ImportError as E: 8 | pass 9 | 10 | 11 | def pil_loader(img_str): 12 | buff = io.BytesIO(img_str) 13 | return Image.open(buff) 14 | 15 | 16 | class McLoader(object): 17 | 18 | def __init__(self, mclient_path): 19 | assert mclient_path is not None, \ 20 | "Please specify 'data_mclient_path' in the config." 21 | self.mclient_path = mclient_path 22 | server_list_config_file = "{}/server_list.conf".format( 23 | self.mclient_path) 24 | client_config_file = "{}/client.conf".format(self.mclient_path) 25 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, 26 | client_config_file) 27 | 28 | def __call__(self, fn): 29 | try: 30 | img_value = mc.pyvector() 31 | self.mclient.Get(fn, img_value) 32 | img_value_str = mc.ConvertBuffer(img_value) 33 | img = pil_loader(img_value_str) 34 | except: 35 | print('Read image failed ({})'.format(fn)) 36 | return None 37 | else: 38 | return img -------------------------------------------------------------------------------- /hierarchical_architecture/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import math 4 | 5 | 6 | class RASampler(torch.utils.data.Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset for distributed, 8 | with repeated augmentation. 
9 | It ensures that different each augmented version of a sample will be visible to a 10 | different process (GPU) 11 | Heavily based on torch.utils.data.DistributedSampler 12 | """ 13 | 14 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 15 | if num_replicas is None: 16 | if not dist.is_available(): 17 | raise RuntimeError("Requires distributed package to be available") 18 | num_replicas = dist.get_world_size() 19 | if rank is None: 20 | if not dist.is_available(): 21 | raise RuntimeError("Requires distributed package to be available") 22 | rank = dist.get_rank() 23 | self.dataset = dataset 24 | self.num_replicas = num_replicas 25 | self.rank = rank 26 | self.epoch = 0 27 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 28 | self.total_size = self.num_samples * self.num_replicas 29 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 30 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 31 | self.shuffle = shuffle 32 | 33 | def __iter__(self): 34 | # deterministically shuffle based on epoch 35 | g = torch.Generator() 36 | g.manual_seed(self.epoch) 37 | if self.shuffle: 38 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 39 | else: 40 | indices = list(range(len(self.dataset))) 41 | 42 | # add extra samples to make it evenly divisible 43 | indices = [ele for ele in indices for i in range(3)] 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | indices = indices[self.rank:self.total_size:self.num_replicas] 49 | assert len(indices) == self.num_samples 50 | 51 | return iter(indices[:self.num_selected_samples]) 52 | 53 | def __len__(self): 54 | return self.num_selected_samples 55 | 56 | def set_epoch(self, epoch): 57 | self.epoch = epoch 58 | -------------------------------------------------------------------------------- /hierarchical_architecture/spectformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | import math 10 | import numpy as np 11 | from pytorch_wavelets import DWTForward, DWTInverse # (or import DWT, IDWT) 12 | 13 | 14 | class SpectralGatingNetwork(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | # this weights are valid for h=14 and w=8 18 | if dim == 64: #96 for large model, 64 for small and base model 19 | self.h = 56 #H 20 | self.w = 29 #(W/2)+1 21 | self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02) 22 | if dim ==128: 23 | self.h = 28 #H 24 | self.w = 15 #(W/2)+1, this is due to rfft2 25 | self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02) 26 | if dim == 96: #96 for large model, 64 for small and base model 27 | self.h = 56 #H 28 | self.w = 29 #(W/2)+1 29 | self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02) 30 | if dim ==192: 31 | self.h = 28 #H 32 | self.w = 15 #(W/2)+1, this is due to rfft2 33 | self.complex_weight = nn.Parameter(torch.randn(self.h, self.w, dim, 2, dtype=torch.float32) * 0.02) 34 | 35 | def forward(self, x, H, W): 36 | # 
print('wno',x.shape) #CIFAR100 image :[128, 196, 384] 37 | B, N, C = x.shape 38 | # print('wno B, N, C',B, N, C) #CIFAR100 image : 128 196 384 39 | x = x.view(B, H, W, C) 40 | # B, H, W, C=x.shape 41 | x = x.to(torch.float32) 42 | # print(x.dtype) 43 | # Add above for this error, RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same 44 | x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') 45 | # print('wno',x.shape) 46 | weight = torch.view_as_complex(self.complex_weight) 47 | # print('weight',weight.shape) 48 | x = x * weight 49 | x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho') 50 | # print('wno',x.shape) 51 | x = x.reshape(B, N, C)# permute is not same as reshape or view 52 | return x 53 | 54 | 55 | def rand_bbox(size, lam, scale=1): 56 | W = size[1] // scale 57 | H = size[2] // scale 58 | cut_rat = np.sqrt(1. - lam) 59 | cut_w = np.int(W * cut_rat) 60 | cut_h = np.int(H * cut_rat) 61 | 62 | # uniform 63 | cx = np.random.randint(W) 64 | cy = np.random.randint(H) 65 | 66 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 67 | bby1 = np.clip(cy - cut_h // 2, 0, H) 68 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 69 | bby2 = np.clip(cy + cut_h // 2, 0, H) 70 | 71 | return bbx1, bby1, bbx2, bby2 72 | 73 | class ClassAttention(nn.Module): 74 | def __init__(self, dim, num_heads): 75 | super().__init__() 76 | self.num_heads = num_heads 77 | head_dim = dim // num_heads 78 | self.head_dim = head_dim 79 | self.scale = head_dim**-0.5 80 | self.kv = nn.Linear(dim, dim * 2) 81 | self.q = nn.Linear(dim, dim) 82 | self.proj = nn.Linear(dim, dim) 83 | self.apply(self._init_weights) 84 | 85 | def _init_weights(self, m): 86 | if isinstance(m, nn.Linear): 87 | trunc_normal_(m.weight, std=.02) 88 | if isinstance(m, nn.Linear) and m.bias is not None: 89 | nn.init.constant_(m.bias, 0) 90 | elif isinstance(m, nn.LayerNorm): 91 | nn.init.constant_(m.bias, 0) 92 | nn.init.constant_(m.weight, 1.0) 93 | elif isinstance(m, nn.Conv2d): 94 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 95 | fan_out //= m.groups 96 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 97 | if m.bias is not None: 98 | m.bias.data.zero_() 99 | 100 | def forward(self, x): 101 | B, N, C = x.shape 102 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 103 | k, v = kv[0], kv[1] 104 | q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) 105 | attn = ((q * self.scale) @ k.transpose(-2, -1)) 106 | attn = attn.softmax(dim=-1) 107 | cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads) 108 | cls_embed = self.proj(cls_embed) 109 | return cls_embed 110 | 111 | class FFN(nn.Module): 112 | def __init__(self, in_features, hidden_features): 113 | super().__init__() 114 | self.fc1 = nn.Linear(in_features, hidden_features) 115 | self.act = nn.GELU() 116 | self.fc2 = nn.Linear(hidden_features, in_features) 117 | self.apply(self._init_weights) 118 | 119 | def _init_weights(self, m): 120 | if isinstance(m, nn.Linear): 121 | trunc_normal_(m.weight, std=.02) 122 | if isinstance(m, nn.Linear) and m.bias is not None: 123 | nn.init.constant_(m.bias, 0) 124 | elif isinstance(m, nn.LayerNorm): 125 | nn.init.constant_(m.bias, 0) 126 | nn.init.constant_(m.weight, 1.0) 127 | elif isinstance(m, nn.Conv2d): 128 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | fan_out //= m.groups 130 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 131 | if m.bias is not None: 132 | m.bias.data.zero_() 133 | 134 
| def forward(self, x): 135 | x = self.fc1(x) 136 | x = self.act(x) 137 | x = self.fc2(x) 138 | return x 139 | 140 | class ClassBlock(nn.Module): 141 | def __init__(self, dim, num_heads, mlp_ratio, norm_layer=nn.LayerNorm): 142 | super().__init__() 143 | self.norm1 = norm_layer(dim) 144 | self.norm2 = norm_layer(dim) 145 | self.attn = ClassAttention(dim, num_heads) 146 | self.mlp = FFN(dim, int(dim * mlp_ratio)) 147 | self.apply(self._init_weights) 148 | 149 | def _init_weights(self, m): 150 | if isinstance(m, nn.Linear): 151 | trunc_normal_(m.weight, std=.02) 152 | if isinstance(m, nn.Linear) and m.bias is not None: 153 | nn.init.constant_(m.bias, 0) 154 | elif isinstance(m, nn.LayerNorm): 155 | nn.init.constant_(m.bias, 0) 156 | nn.init.constant_(m.weight, 1.0) 157 | elif isinstance(m, nn.Conv2d): 158 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 159 | fan_out //= m.groups 160 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 161 | if m.bias is not None: 162 | m.bias.data.zero_() 163 | 164 | def forward(self, x): 165 | cls_embed = x[:, :1] 166 | cls_embed = cls_embed + self.attn(self.norm1(x)) 167 | cls_embed = cls_embed + self.mlp(self.norm2(cls_embed)) 168 | return torch.cat([cls_embed, x[:, 1:]], dim=1) 169 | 170 | class PVT2FFN(nn.Module): 171 | def __init__(self, in_features, hidden_features): 172 | super().__init__() 173 | self.fc1 = nn.Linear(in_features, hidden_features) 174 | self.dwconv = DWConv(hidden_features) 175 | self.act = nn.GELU() 176 | self.fc2 = nn.Linear(hidden_features, in_features) 177 | self.apply(self._init_weights) 178 | 179 | def _init_weights(self, m): 180 | if isinstance(m, nn.Linear): 181 | trunc_normal_(m.weight, std=.02) 182 | if isinstance(m, nn.Linear) and m.bias is not None: 183 | nn.init.constant_(m.bias, 0) 184 | elif isinstance(m, nn.LayerNorm): 185 | nn.init.constant_(m.bias, 0) 186 | nn.init.constant_(m.weight, 1.0) 187 | elif isinstance(m, nn.Conv2d): 188 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 189 | fan_out //= m.groups 190 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 191 | if m.bias is not None: 192 | m.bias.data.zero_() 193 | 194 | def forward(self, x, H, W): 195 | x = self.fc1(x) 196 | x = self.dwconv(x, H, W) 197 | x = self.act(x) 198 | x = self.fc2(x) 199 | return x 200 | 201 | class Attention(nn.Module): 202 | def __init__(self, dim, num_heads): 203 | super().__init__() 204 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
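        # Vanilla multi-head self-attention. In the hierarchical SpectFormer, Block
        # instantiates this branch ('std_att') for the last two stages, while the
        # first two stages use the FFT-based SpectralGatingNetwork instead.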
205 | 206 | self.dim = dim 207 | self.num_heads = num_heads 208 | head_dim = dim // num_heads 209 | self.scale = head_dim ** -0.5 210 | 211 | self.q = nn.Linear(dim, dim) 212 | self.kv = nn.Linear(dim, dim * 2) 213 | self.proj = nn.Linear(dim, dim) 214 | self.apply(self._init_weights) 215 | 216 | def _init_weights(self, m): 217 | if isinstance(m, nn.Linear): 218 | trunc_normal_(m.weight, std=.02) 219 | if isinstance(m, nn.Linear) and m.bias is not None: 220 | nn.init.constant_(m.bias, 0) 221 | elif isinstance(m, nn.LayerNorm): 222 | nn.init.constant_(m.bias, 0) 223 | nn.init.constant_(m.weight, 1.0) 224 | elif isinstance(m, nn.Conv2d): 225 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 226 | fan_out //= m.groups 227 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 228 | if m.bias is not None: 229 | m.bias.data.zero_() 230 | 231 | def forward(self, x, H, W): 232 | B, N, C = x.shape 233 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 234 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 235 | k, v = kv[0], kv[1] 236 | attn = (q @ k.transpose(-2, -1)) * self.scale 237 | attn = attn.softmax(dim=-1) 238 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 239 | x = self.proj(x) 240 | return x 241 | 242 | class Block(nn.Module): 243 | def __init__(self, 244 | dim, 245 | num_heads, 246 | mlp_ratio, 247 | drop_path=0., 248 | norm_layer=nn.LayerNorm, 249 | sr_ratio=1, 250 | block_type = 'wave' 251 | ): 252 | super().__init__() 253 | self.norm1 = norm_layer(dim) 254 | self.norm2 = norm_layer(dim) 255 | 256 | if block_type == 'std_att': 257 | self.attn = Attention(dim, num_heads) 258 | else: 259 | self.attn = SpectralGatingNetwork (dim) 260 | self.mlp = PVT2FFN(in_features=dim, hidden_features=int(dim * mlp_ratio)) 261 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 262 | self.apply(self._init_weights) 263 | 264 | def _init_weights(self, m): 265 | if isinstance(m, nn.Linear): 266 | trunc_normal_(m.weight, std=.02) 267 | if isinstance(m, nn.Linear) and m.bias is not None: 268 | nn.init.constant_(m.bias, 0) 269 | elif isinstance(m, nn.LayerNorm): 270 | nn.init.constant_(m.bias, 0) 271 | nn.init.constant_(m.weight, 1.0) 272 | elif isinstance(m, nn.Conv2d): 273 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 274 | fan_out //= m.groups 275 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 276 | if m.bias is not None: 277 | m.bias.data.zero_() 278 | 279 | def forward(self, x, H, W): 280 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 281 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 282 | return x 283 | 284 | class DownSamples(nn.Module): 285 | def __init__(self, in_channels, out_channels): 286 | super().__init__() 287 | self.proj = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1) 288 | self.norm = nn.LayerNorm(out_channels) 289 | self.apply(self._init_weights) 290 | 291 | def _init_weights(self, m): 292 | if isinstance(m, nn.Linear): 293 | trunc_normal_(m.weight, std=.02) 294 | if isinstance(m, nn.Linear) and m.bias is not None: 295 | nn.init.constant_(m.bias, 0) 296 | elif isinstance(m, nn.LayerNorm): 297 | nn.init.constant_(m.bias, 0) 298 | nn.init.constant_(m.weight, 1.0) 299 | elif isinstance(m, nn.Conv2d): 300 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 301 | fan_out //= m.groups 302 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 303 | if m.bias is not None: 304 | m.bias.data.zero_() 305 | 306 | def forward(self, x): 307 | x = self.proj(x) 308 | _, _, H, W = x.shape 309 | x = x.flatten(2).transpose(1, 2) 310 | x = self.norm(x) 311 | return x, H, W 312 | 313 | class Stem(nn.Module): 314 | def __init__(self, in_channels, stem_hidden_dim, out_channels): 315 | super().__init__() 316 | hidden_dim = stem_hidden_dim 317 | self.conv = nn.Sequential( 318 | nn.Conv2d(in_channels, hidden_dim, kernel_size=7, stride=2, 319 | padding=3, bias=False), # 112x112 320 | nn.BatchNorm2d(hidden_dim), 321 | nn.ReLU(inplace=True), 322 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, 323 | padding=1, bias=False), # 112x112 324 | nn.BatchNorm2d(hidden_dim), 325 | nn.ReLU(inplace=True), 326 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, 327 | padding=1, bias=False), # 112x112 328 | nn.BatchNorm2d(hidden_dim), 329 | nn.ReLU(inplace=True), 330 | ) 331 | self.proj = nn.Conv2d(hidden_dim, 332 | out_channels, 333 | kernel_size=3, 334 | stride=2, 335 | padding=1) 336 | self.norm = nn.LayerNorm(out_channels) 337 | 338 | self.apply(self._init_weights) 339 | 340 | def _init_weights(self, m): 341 | if isinstance(m, nn.Linear): 342 | trunc_normal_(m.weight, std=.02) 343 | if isinstance(m, nn.Linear) and m.bias is not None: 344 | nn.init.constant_(m.bias, 0) 345 | elif isinstance(m, nn.LayerNorm): 346 | nn.init.constant_(m.bias, 0) 347 | nn.init.constant_(m.weight, 1.0) 348 | elif isinstance(m, nn.Conv2d): 349 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 350 | fan_out //= m.groups 351 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 352 | if m.bias is not None: 353 | m.bias.data.zero_() 354 | 355 | def forward(self, x): 356 | x = self.conv(x) 357 | x = self.proj(x) 358 | _, _, H, W = x.shape 359 | x = x.flatten(2).transpose(1, 2) 360 | x = self.norm(x) 361 | return x, H, W 362 | 363 | class SpectFormer(nn.Module): 364 | def 
__init__(self, 365 | in_chans=3, 366 | num_classes=1000, 367 | stem_hidden_dim = 32, 368 | embed_dims=[64, 128, 320, 448], 369 | num_heads=[2, 4, 10, 14], 370 | mlp_ratios=[8, 8, 4, 4], 371 | drop_path_rate=0., 372 | norm_layer=nn.LayerNorm, 373 | depths=[3, 4, 6, 3], 374 | sr_ratios=[4, 2, 1, 1], 375 | num_stages=4, 376 | token_label=True, 377 | **kwargs 378 | ): 379 | super().__init__() 380 | self.num_classes = num_classes 381 | self.depths = depths 382 | self.num_stages = num_stages 383 | 384 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 385 | cur = 0 386 | 387 | for i in range(num_stages): 388 | if i == 0: 389 | patch_embed = Stem(in_chans, stem_hidden_dim, embed_dims[i]) 390 | else: 391 | patch_embed = DownSamples(embed_dims[i - 1], embed_dims[i]) 392 | 393 | block = nn.ModuleList([Block( 394 | dim = embed_dims[i], 395 | num_heads = num_heads[i], 396 | mlp_ratio = mlp_ratios[i], 397 | drop_path=dpr[cur + j], 398 | norm_layer=norm_layer, 399 | sr_ratio = sr_ratios[i], 400 | block_type='wave' if i < 2 else 'std_att') 401 | for j in range(depths[i])]) 402 | 403 | norm = norm_layer(embed_dims[i]) 404 | cur += depths[i] 405 | 406 | setattr(self, f"patch_embed{i + 1}", patch_embed) 407 | setattr(self, f"block{i + 1}", block) 408 | setattr(self, f"norm{i + 1}", norm) 409 | 410 | post_layers = ['ca'] 411 | self.post_network = nn.ModuleList([ 412 | ClassBlock( 413 | dim = embed_dims[-1], 414 | num_heads = num_heads[-1], 415 | mlp_ratio = mlp_ratios[-1], 416 | norm_layer=norm_layer) 417 | for _ in range(len(post_layers)) 418 | ]) 419 | 420 | # classification head 421 | self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() 422 | ##################################### token_label ##################################### 423 | self.return_dense = token_label 424 | self.mix_token = token_label 425 | self.beta = 1.0 426 | self.pooling_scale = 8 427 | if self.return_dense: 428 | self.aux_head = nn.Linear( 429 | embed_dims[-1], 430 | num_classes) if num_classes > 0 else nn.Identity() 431 | ##################################### token_label ##################################### 432 | 433 | self.apply(self._init_weights) 434 | 435 | def _init_weights(self, m): 436 | if isinstance(m, nn.Linear): 437 | trunc_normal_(m.weight, std=.02) 438 | if isinstance(m, nn.Linear) and m.bias is not None: 439 | nn.init.constant_(m.bias, 0) 440 | elif isinstance(m, nn.LayerNorm): 441 | nn.init.constant_(m.bias, 0) 442 | nn.init.constant_(m.weight, 1.0) 443 | elif isinstance(m, nn.Conv2d): 444 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 445 | fan_out //= m.groups 446 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 447 | if m.bias is not None: 448 | m.bias.data.zero_() 449 | 450 | def forward_cls(self, x): 451 | B, N, C = x.shape 452 | cls_tokens = x.mean(dim=1, keepdim=True) 453 | x = torch.cat((cls_tokens, x), dim=1) 454 | for block in self.post_network: 455 | x = block(x) 456 | return x 457 | 458 | def forward_features(self, x): 459 | B = x.shape[0] 460 | for i in range(self.num_stages): 461 | patch_embed = getattr(self, f"patch_embed{i + 1}") 462 | block = getattr(self, f"block{i + 1}") 463 | x, H, W = patch_embed(x) 464 | for blk in block: 465 | x = blk(x, H, W) 466 | 467 | if i != self.num_stages - 1: 468 | norm = getattr(self, f"norm{i + 1}") 469 | x = norm(x) 470 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 471 | 472 | x = self.forward_cls(x)[:, 0] 473 | norm = getattr(self, 
f"norm{self.num_stages}") 474 | x = norm(x) 475 | return x 476 | 477 | def forward(self, x): 478 | if not self.return_dense: 479 | x = self.forward_features(x) 480 | x = self.head(x) 481 | return x 482 | else: 483 | x, H, W = self.forward_embeddings(x) 484 | # mix token, see token labeling for details. 485 | if self.mix_token and self.training: 486 | lam = np.random.beta(self.beta, self.beta) 487 | patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[ 488 | 2] // self.pooling_scale 489 | bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale) 490 | temp_x = x.clone() 491 | sbbx1,sbby1,sbbx2,sbby2=self.pooling_scale*bbx1,self.pooling_scale*bby1,\ 492 | self.pooling_scale*bbx2,self.pooling_scale*bby2 493 | temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :] 494 | x = temp_x 495 | else: 496 | bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0 497 | 498 | x = self.forward_tokens(x, H, W) 499 | x_cls = self.head(x[:, 0]) 500 | x_aux = self.aux_head( 501 | x[:, 1:] 502 | ) # generate classes in all feature tokens, see token labeling 503 | 504 | if not self.training: 505 | return x_cls + 0.5 * x_aux.max(1)[0] 506 | 507 | if self.mix_token and self.training: # reverse "mix token", see token labeling for details. 508 | x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1]) 509 | 510 | temp_x = x_aux.clone() 511 | temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :] 512 | x_aux = temp_x 513 | 514 | x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1]) 515 | 516 | return x_cls, x_aux, (bbx1, bby1, bbx2, bby2) 517 | 518 | def forward_tokens(self, x, H, W): 519 | B = x.shape[0] 520 | x = x.view(B, -1, x.size(-1)) 521 | 522 | for i in range(self.num_stages): 523 | if i != 0: 524 | patch_embed = getattr(self, f"patch_embed{i + 1}") 525 | x, H, W = patch_embed(x) 526 | 527 | block = getattr(self, f"block{i + 1}") 528 | for blk in block: 529 | x = blk(x, H, W) 530 | 531 | if i != self.num_stages - 1: 532 | norm = getattr(self, f"norm{i + 1}") 533 | x = norm(x) 534 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 535 | 536 | x = self.forward_cls(x) 537 | norm = getattr(self, f"norm{self.num_stages}") 538 | x = norm(x) 539 | return x 540 | 541 | def forward_embeddings(self, x): 542 | patch_embed = getattr(self, f"patch_embed{0 + 1}") 543 | x, H, W = patch_embed(x) 544 | x = x.view(x.size(0), H, W, -1) 545 | return x, H, W 546 | 547 | 548 | class DWConv(nn.Module): 549 | def __init__(self, dim=768): 550 | super(DWConv, self).__init__() 551 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 552 | 553 | def forward(self, x, H, W): 554 | B, N, C = x.shape 555 | x = x.transpose(1, 2).view(B, C, H, W) 556 | x = self.dwconv(x) 557 | x = x.flatten(2).transpose(1, 2) 558 | return x 559 | 560 | @register_model 561 | def spectformer_s(pretrained=False, **kwargs): 562 | model = SpectFormer( 563 | stem_hidden_dim = 32, 564 | embed_dims = [64, 128, 320, 448], 565 | num_heads = [2, 4, 10, 14], 566 | mlp_ratios = [8, 8, 4, 4], 567 | norm_layer = partial(nn.LayerNorm, eps=1e-6), 568 | depths = [3, 4, 6, 3], 569 | sr_ratios = [4, 2, 1, 1], 570 | **kwargs) 571 | model.default_cfg = _cfg() 572 | return model 573 | 574 | @register_model 575 | def spectformer_b(pretrained=False, **kwargs): 576 | model = SpectFormer( 577 | stem_hidden_dim = 64, 578 | embed_dims = [64, 128, 320, 512], 579 | num_heads = [2, 4, 10, 16], 580 | mlp_ratios = [8, 8, 4, 4], 581 | norm_layer = partial(nn.LayerNorm, 
eps=1e-6), 582 | depths = [3, 4, 12, 3], 583 | sr_ratios = [4, 2, 1, 1], 584 | **kwargs) 585 | model.default_cfg = _cfg() 586 | return model 587 | 588 | @register_model 589 | def spectformer_l(pretrained=False, **kwargs): 590 | model = SpectFormer( 591 | stem_hidden_dim = 64, 592 | embed_dims = [96, 192, 384, 512], 593 | num_heads = [3, 6, 12, 16], 594 | mlp_ratios = [8, 8, 4, 4], 595 | norm_layer = partial(nn.LayerNorm, eps=1e-6), 596 | depths = [3, 6, 18, 3], 597 | sr_ratios = [4, 2, 1, 1], 598 | **kwargs) 599 | model.default_cfg = _cfg() 600 | return model -------------------------------------------------------------------------------- /hierarchical_architecture/util/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import load_pretrained_weights -------------------------------------------------------------------------------- /hierarchical_architecture/util/checkpoint_saver.py: -------------------------------------------------------------------------------- 1 | 2 | import glob 3 | import operator 4 | import os 5 | import logging 6 | 7 | import torch 8 | 9 | from timm.utils.model import unwrap_model, get_state_dict 10 | 11 | 12 | _logger = logging.getLogger(__name__) 13 | 14 | 15 | class CheckpointSaver2: 16 | def __init__( 17 | self, 18 | model, 19 | optimizer, 20 | args=None, 21 | model_ema=None, 22 | amp_scaler=None, 23 | checkpoint_prefix='checkpoint', 24 | recovery_prefix='recovery', 25 | checkpoint_dir='', 26 | recovery_dir='', 27 | decreasing=False, 28 | max_history=10, 29 | unwrap_fn=unwrap_model): 30 | 31 | # objects to save state_dicts of 32 | self.model = model 33 | self.optimizer = optimizer 34 | self.args = args 35 | self.model_ema = model_ema 36 | self.amp_scaler = amp_scaler 37 | 38 | # state 39 | self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness 40 | self.best_epoch = None 41 | self.best_metric = None 42 | self.curr_recovery_file = '' 43 | self.last_recovery_file = '' 44 | 45 | # config 46 | self.checkpoint_dir = checkpoint_dir 47 | self.recovery_dir = recovery_dir 48 | self.save_prefix = checkpoint_prefix 49 | self.recovery_prefix = recovery_prefix 50 | self.extension = '.pth.tar' 51 | self.decreasing = decreasing # a lower metric is better if True 52 | self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs 53 | self.max_history = max_history 54 | self.unwrap_fn = unwrap_fn 55 | assert self.max_history >= 1 56 | 57 | def save_checkpoint(self, epoch, metric=None): 58 | assert epoch >= 0 59 | tmp_save_path = os.path.join(self.checkpoint_dir, str(epoch) + self.extension) 60 | last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) 61 | self._save(tmp_save_path, epoch, metric) 62 | if os.path.exists(last_save_path): 63 | #os.unlink(last_save_path) # required for Windows support. 
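# os.rename() cannot overwrite an existing file on Windows, so the previous 'last' checkpoint is removed before renaming.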
64 | os.remove(last_save_path) ########################################################### 65 | os.rename(tmp_save_path, last_save_path) 66 | worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None 67 | if (len(self.checkpoint_files) < self.max_history 68 | or metric is None or self.cmp(metric, worst_file[1])): 69 | if len(self.checkpoint_files) >= self.max_history: 70 | self._cleanup_checkpoints(1) 71 | filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension 72 | save_path = os.path.join(self.checkpoint_dir, filename) 73 | 74 | if os.path.exists(save_path): 75 | os.remove(save_path) 76 | 77 | #os.link(last_save_path, save_path) 78 | os.rename(last_save_path, save_path) 79 | self.checkpoint_files.append((save_path, metric)) 80 | self.checkpoint_files = sorted( 81 | self.checkpoint_files, key=lambda x: x[1], 82 | reverse=not self.decreasing) # sort in descending order if a lower metric is not better 83 | 84 | checkpoints_str = "Current checkpoints:\n" 85 | for c in self.checkpoint_files: 86 | checkpoints_str += ' {}\n'.format(c) 87 | _logger.info(checkpoints_str) 88 | 89 | if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): 90 | self.best_epoch = epoch 91 | self.best_metric = metric 92 | #best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) 93 | #if os.path.exists(best_save_path): 94 | # os.unlink(best_save_path) 95 | #os.link(last_save_path, best_save_path) 96 | 97 | return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) 98 | 99 | def _save(self, save_path, epoch, metric=None): 100 | save_state = { 101 | 'epoch': epoch, 102 | 'arch': type(self.model).__name__.lower(), 103 | 'state_dict': get_state_dict(self.model, self.unwrap_fn), 104 | 'optimizer': self.optimizer.state_dict(), 105 | 'version': 2, # version < 2 increments epoch before save 106 | } 107 | if self.args is not None: 108 | save_state['arch'] = self.args.model 109 | save_state['args'] = self.args 110 | if self.amp_scaler is not None: 111 | save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() 112 | if self.model_ema is not None: 113 | save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) 114 | if metric is not None: 115 | save_state['metric'] = metric 116 | torch.save(save_state, save_path) 117 | 118 | def _cleanup_checkpoints(self, trim=0): 119 | trim = min(len(self.checkpoint_files), trim) 120 | delete_index = self.max_history - trim 121 | if delete_index < 0 or len(self.checkpoint_files) <= delete_index: 122 | return 123 | to_delete = self.checkpoint_files[delete_index:] 124 | for d in to_delete: 125 | try: 126 | _logger.debug("Cleaning checkpoint: {}".format(d)) 127 | os.remove(d[0]) 128 | except Exception as e: 129 | _logger.error("Exception '{}' while deleting checkpoint".format(e)) 130 | self.checkpoint_files = self.checkpoint_files[:delete_index] 131 | 132 | def save_recovery(self, epoch, batch_idx=0): 133 | assert epoch >= 0 134 | filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension 135 | save_path = os.path.join(self.recovery_dir, filename) 136 | self._save(save_path, epoch) 137 | if os.path.exists(self.last_recovery_file): 138 | try: 139 | _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) 140 | os.remove(self.last_recovery_file) 141 | except Exception as e: 142 | _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) 143 | self.last_recovery_file = 
self.curr_recovery_file 144 | self.curr_recovery_file = save_path 145 | 146 | def find_recovery(self): 147 | recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) 148 | files = glob.glob(recovery_path + '*' + self.extension) 149 | files = sorted(files) 150 | return files[0] if len(files) else '' -------------------------------------------------------------------------------- /hierarchical_architecture/util/flops_counter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2019 Sovrasov V. - All Rights Reserved 3 | * You may use, distribute and modify this code under the 4 | * terms of the MIT license. 5 | * You should have received a copy of the MIT license with 6 | * this file. If not visit https://opensource.org/licenses/MIT 7 | ''' 8 | 9 | import sys 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | 16 | def get_model_complexity_info(model, input_res, 17 | print_per_layer_stat=True, 18 | as_strings=True, 19 | input_constructor=None, ost=sys.stdout, 20 | verbose=False, ignore_modules=[], 21 | custom_modules_hooks={}): 22 | assert type(input_res) is tuple 23 | assert len(input_res) >= 1 24 | assert isinstance(model, nn.Module) 25 | global CUSTOM_MODULES_MAPPING 26 | CUSTOM_MODULES_MAPPING = custom_modules_hooks 27 | flops_model = add_flops_counting_methods(model) 28 | flops_model.eval() 29 | flops_model.start_flops_count(ost=ost, verbose=verbose, ignore_list=ignore_modules) 30 | if input_constructor: 31 | input = input_constructor(input_res) 32 | _ = flops_model(**input) 33 | else: 34 | try: 35 | batch = torch.ones(()).new_empty((1, *input_res), 36 | dtype=next(flops_model.parameters()).dtype, 37 | device=next(flops_model.parameters()).device) 38 | except StopIteration: 39 | batch = torch.ones(()).new_empty((1, *input_res)) 40 | 41 | _ = flops_model(batch) 42 | 43 | flops_count, params_count = flops_model.compute_average_flops_cost() 44 | if print_per_layer_stat: 45 | print_model_with_flops(flops_model, flops_count, params_count, ost=ost) 46 | flops_model.stop_flops_count() 47 | CUSTOM_MODULES_MAPPING = {} 48 | 49 | if as_strings: 50 | return flops_to_string(flops_count), params_to_string(params_count) 51 | 52 | return flops_count, params_count 53 | 54 | 55 | def flops_to_string(flops, units='GMac', precision=2): 56 | if units is None: 57 | if flops // 10**9 > 0: 58 | return str(round(flops / 10.**9, precision)) + ' GMac' 59 | elif flops // 10**6 > 0: 60 | return str(round(flops / 10.**6, precision)) + ' MMac' 61 | elif flops // 10**3 > 0: 62 | return str(round(flops / 10.**3, precision)) + ' KMac' 63 | else: 64 | return str(flops) + ' Mac' 65 | else: 66 | if units == 'GMac': 67 | return str(round(flops / 10.**9, precision)) + ' ' + units 68 | elif units == 'MMac': 69 | return str(round(flops / 10.**6, precision)) + ' ' + units 70 | elif units == 'KMac': 71 | return str(round(flops / 10.**3, precision)) + ' ' + units 72 | else: 73 | return str(flops) + ' Mac' 74 | 75 | 76 | def params_to_string(params_num, units=None, precision=2): 77 | if units is None: 78 | if params_num // 10 ** 6 > 0: 79 | return str(round(params_num / 10 ** 6, 2)) + ' M' 80 | elif params_num // 10 ** 3: 81 | return str(round(params_num / 10 ** 3, 2)) + ' k' 82 | else: 83 | return str(params_num) 84 | else: 85 | if units == 'M': 86 | return str(round(params_num / 10.**6, precision)) + ' ' + units 87 | elif units == 'K': 88 | return str(round(params_num / 10.**3, precision)) + ' ' + 
units 89 | else: 90 | return str(params_num) 91 | 92 | 93 | def print_model_with_flops(model, total_flops, total_params, units='GMac', 94 | precision=3, ost=sys.stdout): 95 | 96 | def accumulate_params(self): 97 | if is_supported_instance(self): 98 | return self.__params__ 99 | else: 100 | sum = 0 101 | for m in self.children(): 102 | sum += m.accumulate_params() 103 | return sum 104 | 105 | def accumulate_flops(self): 106 | if is_supported_instance(self): 107 | return self.__flops__ / model.__batch_counter__ 108 | else: 109 | sum = 0 110 | for m in self.children(): 111 | sum += m.accumulate_flops() 112 | return sum 113 | 114 | def flops_repr(self): 115 | accumulated_params_num = self.accumulate_params() 116 | accumulated_flops_cost = self.accumulate_flops() 117 | return ', '.join([params_to_string(accumulated_params_num, units='M', precision=precision), 118 | '{:.3%} Params'.format(accumulated_params_num / total_params), 119 | flops_to_string(accumulated_flops_cost, units=units, precision=precision), 120 | '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), 121 | self.original_extra_repr()]) 122 | 123 | def add_extra_repr(m): 124 | m.accumulate_flops = accumulate_flops.__get__(m) 125 | m.accumulate_params = accumulate_params.__get__(m) 126 | flops_extra_repr = flops_repr.__get__(m) 127 | if m.extra_repr != flops_extra_repr: 128 | m.original_extra_repr = m.extra_repr 129 | m.extra_repr = flops_extra_repr 130 | assert m.extra_repr != m.original_extra_repr 131 | 132 | def del_extra_repr(m): 133 | if hasattr(m, 'original_extra_repr'): 134 | m.extra_repr = m.original_extra_repr 135 | del m.original_extra_repr 136 | if hasattr(m, 'accumulate_flops'): 137 | del m.accumulate_flops 138 | 139 | model.apply(add_extra_repr) 140 | print(model, file=ost) 141 | model.apply(del_extra_repr) 142 | 143 | 144 | def get_model_parameters_number(model): 145 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 146 | return params_num 147 | 148 | 149 | def add_flops_counting_methods(net_main_module): 150 | # adding additional methods to the existing module object, 151 | # this is done this way so that each function has access to self object 152 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 153 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 154 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 155 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) 156 | 157 | net_main_module.reset_flops_count() 158 | 159 | return net_main_module 160 | 161 | 162 | def compute_average_flops_cost(self): 163 | """ 164 | A method that will be available after add_flops_counting_methods() is called 165 | on a desired net object. 166 | 167 | Returns current mean flops consumption per image. 168 | 169 | """ 170 | 171 | batches_count = self.__batch_counter__ 172 | flops_sum = 0 173 | params_sum = 0 174 | for module in self.modules(): 175 | if is_supported_instance(module): 176 | flops_sum += module.__flops__ 177 | params_sum = get_model_parameters_number(self) 178 | return flops_sum / batches_count, params_sum 179 | 180 | 181 | def start_flops_count(self, **kwargs): 182 | """ 183 | A method that will be available after add_flops_counting_methods() is called 184 | on a desired net object. 185 | 186 | Activates the computation of mean flops consumption per image. 187 | Call it before you run the network. 
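Typical usage, mirroring get_model_complexity_info above (dummy_input stands for any
    correctly shaped input tensor):

        model = add_flops_counting_methods(model)
        model.eval()
        model.start_flops_count(ost=sys.stdout, verbose=False, ignore_list=[])
        _ = model(dummy_input)
        flops, params = model.compute_average_flops_cost()
        model.stop_flops_count()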
188 | 189 | """ 190 | add_batch_counter_hook_function(self) 191 | 192 | seen_types = set() 193 | def add_flops_counter_hook_function(module, ost, verbose, ignore_list): 194 | if type(module) in ignore_list: 195 | seen_types.add(type(module)) 196 | if is_supported_instance(module): 197 | module.__params__ = 0 198 | elif is_supported_instance(module): 199 | if hasattr(module, '__flops_handle__'): 200 | return 201 | if type(module) in CUSTOM_MODULES_MAPPING: 202 | handle = module.register_forward_hook(CUSTOM_MODULES_MAPPING[type(module)]) 203 | else: 204 | handle = module.register_forward_hook(MODULES_MAPPING[type(module)]) 205 | module.__flops_handle__ = handle 206 | seen_types.add(type(module)) 207 | else: 208 | if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and not type(module) in seen_types: 209 | print('Warning: module ' + type(module).__name__ + ' is treated as a zero-op.', file=ost) 210 | seen_types.add(type(module)) 211 | 212 | self.apply(partial(add_flops_counter_hook_function, **kwargs)) 213 | 214 | 215 | def stop_flops_count(self): 216 | """ 217 | A method that will be available after add_flops_counting_methods() is called 218 | on a desired net object. 219 | 220 | Stops computing the mean flops consumption per image. 221 | Call whenever you want to pause the computation. 222 | 223 | """ 224 | remove_batch_counter_hook_function(self) 225 | self.apply(remove_flops_counter_hook_function) 226 | 227 | 228 | def reset_flops_count(self): 229 | """ 230 | A method that will be available after add_flops_counting_methods() is called 231 | on a desired net object. 232 | 233 | Resets statistics computed so far. 234 | 235 | """ 236 | add_batch_counter_variables_or_reset(self) 237 | self.apply(add_flops_counter_variable_or_reset) 238 | 239 | 240 | # ---- Internal functions 241 | def empty_flops_counter_hook(module, input, output): 242 | module.__flops__ += 0 243 | 244 | 245 | def upsample_flops_counter_hook(module, input, output): 246 | output_size = output[0] 247 | batch_size = output_size.shape[0] 248 | output_elements_count = batch_size 249 | for val in output_size.shape[1:]: 250 | output_elements_count *= val 251 | module.__flops__ += int(output_elements_count) 252 | 253 | 254 | def relu_flops_counter_hook(module, input, output): 255 | active_elements_count = output.numel() 256 | module.__flops__ += int(active_elements_count) 257 | 258 | 259 | def linear_flops_counter_hook(module, input, output): 260 | input = input[0] 261 | output_last_dim = output.shape[-1] # pytorch checks dimensions, so here we don't care much 262 | module.__flops__ += int(np.prod(input.shape) * output_last_dim) 263 | 264 | 265 | def pool_flops_counter_hook(module, input, output): 266 | input = input[0] 267 | module.__flops__ += int(np.prod(input.shape)) 268 | 269 | 270 | def bn_flops_counter_hook(module, input, output): 271 | module.affine 272 | input = input[0] 273 | 274 | batch_flops = np.prod(input.shape) 275 | if module.affine: 276 | batch_flops *= 2 277 | module.__flops__ += int(batch_flops) 278 | 279 | 280 | def deconv_flops_counter_hook(conv_module, input, output): 281 | # Can have multiple inputs, getting the first one 282 | input = input[0] 283 | 284 | batch_size = input.shape[0] 285 | input_height, input_width = input.shape[2:] 286 | 287 | kernel_height, kernel_width = conv_module.kernel_size 288 | in_channels = conv_module.in_channels 289 | out_channels = conv_module.out_channels 290 | groups = conv_module.groups 291 | 292 | filters_per_channel = out_channels // groups 293 | 
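# MACs per output position of a (possibly grouped) transposed convolution: k_h * k_w * in_channels * (out_channels / groups)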
conv_per_position_flops = kernel_height * kernel_width * in_channels * filters_per_channel 294 | 295 | active_elements_count = batch_size * input_height * input_width 296 | overall_conv_flops = conv_per_position_flops * active_elements_count 297 | bias_flops = 0 298 | if conv_module.bias is not None: 299 | output_height, output_width = output.shape[2:] 300 | bias_flops = out_channels * batch_size * output_height * output_height 301 | overall_flops = overall_conv_flops + bias_flops 302 | 303 | conv_module.__flops__ += int(overall_flops) 304 | 305 | 306 | def conv_flops_counter_hook(conv_module, input, output): 307 | # Can have multiple inputs, getting the first one 308 | input = input[0] 309 | 310 | batch_size = input.shape[0] 311 | output_dims = list(output.shape[2:]) 312 | 313 | kernel_dims = list(conv_module.kernel_size) 314 | in_channels = conv_module.in_channels 315 | out_channels = conv_module.out_channels 316 | groups = conv_module.groups 317 | 318 | filters_per_channel = out_channels // groups 319 | conv_per_position_flops = int(np.prod(kernel_dims)) * in_channels * filters_per_channel 320 | 321 | active_elements_count = batch_size * int(np.prod(output_dims)) 322 | 323 | overall_conv_flops = conv_per_position_flops * active_elements_count 324 | 325 | bias_flops = 0 326 | 327 | if conv_module.bias is not None: 328 | 329 | bias_flops = out_channels * active_elements_count 330 | 331 | overall_flops = overall_conv_flops + bias_flops 332 | 333 | conv_module.__flops__ += int(overall_flops) 334 | 335 | def batch_counter_hook(module, input, output): 336 | batch_size = 1 337 | if len(input) > 0: 338 | # Can have multiple inputs, getting the first one 339 | input = input[0] 340 | batch_size = len(input) 341 | else: 342 | pass 343 | print('Warning! No positional inputs found for a module, assuming batch size is 1.') 344 | module.__batch_counter__ += batch_size 345 | 346 | 347 | def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): 348 | # matrix matrix mult ih state and internal state 349 | flops += w_ih.shape[0]*w_ih.shape[1] 350 | # matrix matrix mult hh state and internal state 351 | flops += w_hh.shape[0]*w_hh.shape[1] 352 | if isinstance(rnn_module, (nn.RNN, nn.RNNCell)): 353 | # add both operations 354 | flops += rnn_module.hidden_size 355 | elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)): 356 | # hadamard of r 357 | flops += rnn_module.hidden_size 358 | # adding operations from both states 359 | flops += rnn_module.hidden_size*3 360 | # last two hadamard product and add 361 | flops += rnn_module.hidden_size*3 362 | elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)): 363 | # adding operations from both states 364 | flops += rnn_module.hidden_size*4 365 | # two hadamard product and add for C state 366 | flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size 367 | # final hadamard 368 | flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size 369 | return flops 370 | 371 | 372 | def rnn_flops_counter_hook(rnn_module, input, output): 373 | """ 374 | Takes into account batch goes at first position, contrary 375 | to pytorch common rule (but actually it doesn't matter). 
376 | IF sigmoid and tanh are made hard, only a comparison FLOPS should be accurate 377 | """ 378 | flops = 0 379 | inp = input[0] # input is a tuble containing a sequence to process and (optionally) hidden state 380 | batch_size = inp.shape[0] 381 | seq_length = inp.shape[1] 382 | num_layers = rnn_module.num_layers 383 | 384 | for i in range(num_layers): 385 | w_ih = rnn_module.__getattr__('weight_ih_l' + str(i)) 386 | w_hh = rnn_module.__getattr__('weight_hh_l' + str(i)) 387 | if i == 0: 388 | input_size = rnn_module.input_size 389 | else: 390 | input_size = rnn_module.hidden_size 391 | flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size) 392 | if rnn_module.bias: 393 | b_ih = rnn_module.__getattr__('bias_ih_l' + str(i)) 394 | b_hh = rnn_module.__getattr__('bias_hh_l' + str(i)) 395 | flops += b_ih.shape[0] + b_hh.shape[0] 396 | 397 | flops *= batch_size 398 | flops *= seq_length 399 | if rnn_module.bidirectional: 400 | flops *= 2 401 | rnn_module.__flops__ += int(flops) 402 | 403 | 404 | def rnn_cell_flops_counter_hook(rnn_cell_module, input, output): 405 | flops = 0 406 | inp = input[0] 407 | batch_size = inp.shape[0] 408 | w_ih = rnn_cell_module.__getattr__('weight_ih') 409 | w_hh = rnn_cell_module.__getattr__('weight_hh') 410 | input_size = inp.shape[1] 411 | flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size) 412 | if rnn_cell_module.bias: 413 | b_ih = rnn_cell_module.__getattr__('bias_ih') 414 | b_hh = rnn_cell_module.__getattr__('bias_hh') 415 | flops += b_ih.shape[0] + b_hh.shape[0] 416 | 417 | flops *= batch_size 418 | rnn_cell_module.__flops__ += int(flops) 419 | 420 | 421 | def add_batch_counter_variables_or_reset(module): 422 | 423 | module.__batch_counter__ = 0 424 | 425 | 426 | def add_batch_counter_hook_function(module): 427 | if hasattr(module, '__batch_counter_handle__'): 428 | return 429 | 430 | handle = module.register_forward_hook(batch_counter_hook) 431 | module.__batch_counter_handle__ = handle 432 | 433 | 434 | def remove_batch_counter_hook_function(module): 435 | if hasattr(module, '__batch_counter_handle__'): 436 | module.__batch_counter_handle__.remove() 437 | del module.__batch_counter_handle__ 438 | 439 | 440 | def add_flops_counter_variable_or_reset(module): 441 | if is_supported_instance(module): 442 | if hasattr(module, '__flops__') or hasattr(module, '__params__'): 443 | print('Warning: variables __flops__ or __params__ are already ' 444 | 'defined for the module' + type(module).__name__ + 445 | ' ptflops can affect your code!') 446 | module.__flops__ = 0 447 | module.__params__ = get_model_parameters_number(module) 448 | 449 | CUSTOM_MODULES_MAPPING = { 450 | 451 | } 452 | 453 | MODULES_MAPPING = { 454 | # convolutions 455 | nn.Conv1d: conv_flops_counter_hook, 456 | nn.Conv2d: conv_flops_counter_hook, 457 | nn.Conv3d: conv_flops_counter_hook, 458 | # activations 459 | nn.ReLU: relu_flops_counter_hook, 460 | nn.PReLU: relu_flops_counter_hook, 461 | nn.ELU: relu_flops_counter_hook, 462 | nn.LeakyReLU: relu_flops_counter_hook, 463 | nn.ReLU6: relu_flops_counter_hook, 464 | nn.SiLU: relu_flops_counter_hook, 465 | nn.Sigmoid: relu_flops_counter_hook, 466 | # poolings 467 | nn.MaxPool1d: pool_flops_counter_hook, 468 | nn.AvgPool1d: pool_flops_counter_hook, 469 | nn.AvgPool2d: pool_flops_counter_hook, 470 | nn.MaxPool2d: pool_flops_counter_hook, 471 | nn.MaxPool3d: pool_flops_counter_hook, 472 | nn.AvgPool3d: pool_flops_counter_hook, 473 | nn.AdaptiveMaxPool1d: pool_flops_counter_hook, 474 | nn.AdaptiveAvgPool1d: pool_flops_counter_hook, 475 
| nn.AdaptiveMaxPool2d: pool_flops_counter_hook, 476 | nn.AdaptiveAvgPool2d: pool_flops_counter_hook, 477 | nn.AdaptiveMaxPool3d: pool_flops_counter_hook, 478 | nn.AdaptiveAvgPool3d: pool_flops_counter_hook, 479 | # BNs 480 | nn.BatchNorm1d: bn_flops_counter_hook, 481 | nn.BatchNorm2d: bn_flops_counter_hook, 482 | nn.BatchNorm3d: bn_flops_counter_hook, 483 | # FC 484 | nn.Linear: linear_flops_counter_hook, 485 | # Upscale 486 | nn.Upsample: upsample_flops_counter_hook, 487 | # Deconvolution 488 | nn.ConvTranspose2d: deconv_flops_counter_hook, 489 | # RNN 490 | nn.RNN: rnn_flops_counter_hook, 491 | nn.GRU: rnn_flops_counter_hook, 492 | nn.LSTM: rnn_flops_counter_hook, 493 | nn.RNNCell: rnn_cell_flops_counter_hook, 494 | nn.LSTMCell: rnn_cell_flops_counter_hook, 495 | nn.GRUCell: rnn_cell_flops_counter_hook, 496 | } 497 | 498 | 499 | def is_supported_instance(module): 500 | if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING: 501 | return True 502 | return False 503 | 504 | 505 | def remove_flops_counter_hook_function(module): 506 | if is_supported_instance(module): 507 | if hasattr(module, '__flops_handle__'): 508 | module.__flops_handle__.remove() 509 | del module.__flops_handle__ 510 | -------------------------------------------------------------------------------- /hierarchical_architecture/util/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Sea Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Adapted for VOLO 15 | ''' 16 | - load_pretrained_weights: load pretrained paramters to model in transfer learning 17 | - resize_pos_embed: resize position embedding 18 | - get_mean_and_std: calculate the mean and std value of dataset. 
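- setup_logger: build a logger that streams to stdout on rank 0 and writes per-rank log files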
19 | ''' 20 | import torch 21 | import math 22 | import functools 23 | import logging 24 | import os 25 | import sys 26 | from collections import OrderedDict 27 | import torch.nn.functional as F 28 | 29 | _logger = logging.getLogger(__name__) 30 | 31 | 32 | def resize_pos_embed(posemb, posemb_new): 33 | ''' 34 | resize position embedding with class token 35 | example: 224:(14x14+1)-> 384: (24x24+1) 36 | return: new position embedding 37 | ''' 38 | ntok_new = posemb_new.shape[1] 39 | 40 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0,1:] # posemb_tok is for cls token, posemb_grid for the following tokens 41 | ntok_new -= 1 42 | gs_old = int(math.sqrt(len(posemb_grid))) # 14 43 | gs_new = int(math.sqrt(ntok_new)) # 24 44 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) 45 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute( 46 | 0, 3, 1, 2) # [1, 196, dim]->[1, 14, 14, dim]->[1, dim, 14, 14] 47 | posemb_grid = F.interpolate( 48 | posemb_grid, size=(gs_new, gs_new), 49 | mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24] 50 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape( 51 | 1, gs_new * gs_new, -1) # [1, dim, 24, 24] -> [1, 24*24, dim] 52 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) # [1, 24*24+1, dim] 53 | return posemb 54 | 55 | 56 | def resize_pos_embed_without_cls(posemb, posemb_new): 57 | ''' 58 | resize position embedding without class token 59 | example: 224:(14x14)-> 384: (24x24) 60 | return new position embedding 61 | ''' 62 | ntok_new = posemb_new.shape[1] 63 | posemb_grid = posemb[0] 64 | gs_old = int(math.sqrt(len(posemb_grid))) # 14 65 | gs_new = int(math.sqrt(ntok_new)) # 24 66 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) 67 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute( 68 | 0, 3, 1, 2) # [1, 196, dim]->[1, 14, 14, dim]->[1, dim, 14, 14] 69 | posemb_grid = F.interpolate( 70 | posemb_grid, size=(gs_new, gs_new), 71 | mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24] 72 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape( 73 | 1, gs_new * gs_new, -1) # [1, dim, 24, 24] -> [1, 24*24, dim] 74 | return posemb_grid 75 | 76 | 77 | def resize_pos_embed_4d(posemb, posemb_new): 78 | '''return new position embedding''' 79 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 80 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 81 | gs_old = posemb.shape[1] # 14 82 | gs_new = posemb_new.shape[1] # 24 83 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) 84 | posemb_grid = posemb 85 | posemb_grid = posemb_grid.permute(0, 3, 1, 86 | 2) # [1, 14, 14, dim]->[1, dim, 14, 14] 87 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24] 88 | posemb_grid = posemb_grid.permute(0, 2, 3, 1) # [1, dim, 24, 24]->[1, 24, 24, dim] 89 | return posemb_grid 90 | 91 | def load_state_dict(checkpoint_path, model, use_ema=False, num_classes=1000): 92 | # load state_dict 93 | if checkpoint_path and os.path.isfile(checkpoint_path): 94 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 95 | state_dict_key = 'state_dict' 96 | if isinstance(checkpoint, dict): 97 | if use_ema and 'state_dict_ema' in checkpoint: 98 | state_dict_key = 'state_dict_ema' 99 | if state_dict_key and state_dict_key in checkpoint: 100 | new_state_dict = OrderedDict() 101 | for k, v in checkpoint[state_dict_key].items(): 102 | # strip `module.` prefix 103 | name = k[7:] if k.startswith('module') else k 104 | new_state_dict[name] = v 105 | state_dict = new_state_dict 106 | else: 107 | state_dict = checkpoint 108 | _logger.info("Loaded {} from checkpoint '{}'".format( 109 | state_dict_key, checkpoint_path)) 110 | if num_classes != 1000: 111 | # completely discard fully connected for all other differences between pretrained and created model 112 | del state_dict['head' + '.weight'] 113 | del state_dict['head' + '.bias'] 114 | old_aux_head_weight = state_dict.pop('aux_head.weight', None) 115 | old_aux_head_bias = state_dict.pop('aux_head.bias', None) 116 | 117 | old_posemb = state_dict['pos_embed'] 118 | if model.pos_embed.shape != old_posemb.shape: # need resize the position embedding by interpolate 119 | if len(old_posemb.shape) == 3: 120 | if int(math.sqrt( 121 | old_posemb.shape[1]))**2 == old_posemb.shape[1]: 122 | new_posemb = resize_pos_embed_without_cls( 123 | old_posemb, model.pos_embed) 124 | else: 125 | new_posemb = resize_pos_embed(old_posemb, model.pos_embed) 126 | elif len(old_posemb.shape) == 4: 127 | new_posemb = resize_pos_embed_4d(old_posemb, model.pos_embed) 128 | state_dict['pos_embed'] = new_posemb 129 | 130 | return state_dict 131 | else: 132 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 133 | raise FileNotFoundError() 134 | 135 | 136 | def load_pretrained_weights(model, 137 | checkpoint_path, 138 | use_ema=False, 139 | strict=True, 140 | num_classes=1000): 141 | '''load pretrained weight for VOLO models''' 142 | state_dict = load_state_dict(checkpoint_path, model, use_ema, num_classes) 143 | model.load_state_dict(state_dict, strict=strict) 144 | 145 | 146 | def get_mean_and_std(dataset): 147 | '''Compute the mean and std value of dataset.''' 148 | dataloader = torch.utils.data.DataLoader(dataset, 149 | batch_size=1, 150 | shuffle=True, 151 | num_workers=2) 152 | mean = torch.zeros(3) 153 | std = torch.zeros(3) 154 | print('==> Computing mean and std..') 155 | for inputs, targets in dataloader: 156 | for i in range(3): 157 | mean[i] += inputs[:, i, :, :].mean() 158 | std[i] += inputs[:, i, :, :].std() 159 | mean.div_(len(dataset)) 160 | std.div_(len(dataset)) 161 | return mean, std 162 | 163 | @functools.lru_cache() 164 | def setup_logger( 165 | output=None, distributed_rank=0, *, 
color=True, name="train", abbrev_name=None 166 | ): 167 | logger = logging.getLogger(name) 168 | logger.setLevel(logging.DEBUG) 169 | logger.propagate = False 170 | 171 | if abbrev_name is None: 172 | abbrev_name = "xl" if name == "train" else name 173 | 174 | plain_formatter = logging.Formatter( 175 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 176 | ) 177 | # stdout logging: master only 178 | if distributed_rank == 0: 179 | ch = logging.StreamHandler(stream=sys.stdout) 180 | ch.setLevel(logging.DEBUG) 181 | formatter = plain_formatter 182 | ch.setFormatter(formatter) 183 | logger.addHandler(ch) 184 | 185 | # file logging: all workers 186 | if output is not None: 187 | if output.endswith(".txt") or output.endswith(".log"): 188 | filename = output 189 | else: 190 | filename = os.path.join(output, "log.txt") 191 | if distributed_rank > 0: 192 | filename = filename + ".rank{}".format(distributed_rank) 193 | #PathManager.mkdirs(os.path.dirname(filename)) 194 | 195 | fh = logging.FileHandler(filename) #logging.StreamHandler(filename) 196 | fh.setLevel(logging.DEBUG) 197 | fh.setFormatter(plain_formatter) 198 | logger.addHandler(fh) 199 | 200 | return logger -------------------------------------------------------------------------------- /hierarchical_architecture/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers. 3 | 4 | Mostly copy-paste from torchvision references. 5 | """ 6 | import io 7 | import os 8 | import sys 9 | import time 10 | from collections import defaultdict, deque 11 | import datetime 12 | import logging 13 | import functools 14 | import torch 15 | import torch.distributed as dist 16 | import mmcv 17 | 18 | _logger = logging.getLogger('train') 19 | 20 | class SmoothedValue(object): 21 | """Track a series of values and provide access to smoothed values over a 22 | window or the global series average. 23 | """ 24 | 25 | def __init__(self, window_size=20, fmt=None): 26 | if fmt is None: 27 | fmt = "{median:.4f} ({global_avg:.4f})" 28 | self.deque = deque(maxlen=window_size) 29 | self.total = 0.0 30 | self.count = 0 31 | self.fmt = fmt 32 | 33 | def update(self, value, n=1): 34 | self.deque.append(value) 35 | self.count += n 36 | self.total += value * n 37 | 38 | def synchronize_between_processes(self): 39 | """ 40 | Warning: does not synchronize the deque! 
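Only `count` and `total` (and therefore `global_avg`) are all-reduced across
        processes; the windowed statistics (median, avg, max, value) stay local to each worker.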
41 | """ 42 | if not is_dist_avail_and_initialized(): 43 | return 44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 45 | dist.barrier() 46 | dist.all_reduce(t) 47 | t = t.tolist() 48 | self.count = int(t[0]) 49 | self.total = t[1] 50 | 51 | @property 52 | def median(self): 53 | d = torch.tensor(list(self.deque)) 54 | return d.median().item() 55 | 56 | @property 57 | def avg(self): 58 | d = torch.tensor(list(self.deque), dtype=torch.float32) 59 | return d.mean().item() 60 | 61 | @property 62 | def global_avg(self): 63 | return self.total / self.count 64 | 65 | @property 66 | def max(self): 67 | return max(self.deque) 68 | 69 | @property 70 | def value(self): 71 | return self.deque[-1] 72 | 73 | def __str__(self): 74 | return self.fmt.format( 75 | median=self.median, 76 | avg=self.avg, 77 | global_avg=self.global_avg, 78 | max=self.max, 79 | value=self.value) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if isinstance(v, torch.Tensor): 90 | v = v.item() 91 | assert isinstance(v, (float, int)) 92 | self.meters[k].update(v) 93 | 94 | def __getattr__(self, attr): 95 | if attr in self.meters: 96 | return self.meters[attr] 97 | if attr in self.__dict__: 98 | return self.__dict__[attr] 99 | raise AttributeError("'{}' object has no attribute '{}'".format( 100 | type(self).__name__, attr)) 101 | 102 | def __str__(self): 103 | loss_str = [] 104 | for name, meter in self.meters.items(): 105 | loss_str.append( 106 | "{}: {}".format(name, str(meter)) 107 | ) 108 | return self.delimiter.join(loss_str) 109 | 110 | def synchronize_between_processes(self): 111 | for meter in self.meters.values(): 112 | meter.synchronize_between_processes() 113 | 114 | def add_meter(self, name, meter): 115 | self.meters[name] = meter 116 | 117 | def log_every(self, iterable, print_freq, header=None): 118 | i = 0 119 | if not header: 120 | header = '' 121 | start_time = time.time() 122 | end = time.time() 123 | iter_time = SmoothedValue(fmt='{avg:.4f}') 124 | data_time = SmoothedValue(fmt='{avg:.4f}') 125 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 126 | log_msg = [ 127 | header, 128 | '[{0' + space_fmt + '}/{1}]', 129 | 'eta: {eta}', 130 | '{meters}', 131 | 'time: {time}', 132 | 'data: {data}' 133 | ] 134 | if torch.cuda.is_available(): 135 | log_msg.append('max mem: {memory:.0f}') 136 | log_msg = self.delimiter.join(log_msg) 137 | MB = 1024.0 * 1024.0 138 | for obj in iterable: 139 | data_time.update(time.time() - end) 140 | yield obj 141 | iter_time.update(time.time() - end) 142 | if i % print_freq == 0 or i == len(iterable) - 1: 143 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 144 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 145 | if torch.cuda.is_available(): 146 | #print 147 | _logger.info(log_msg.format( 148 | i, len(iterable), eta=eta_string, 149 | meters=str(self), 150 | time=str(iter_time), data=str(data_time), 151 | memory=torch.cuda.max_memory_allocated() / MB)) 152 | else: 153 | #print 154 | _logger.info(log_msg.format( 155 | i, len(iterable), eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), data=str(data_time))) 158 | i += 1 159 | end = time.time() 160 | total_time = time.time() - start_time 161 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 162 | #print 163 | _logger.info('{} Total time: {} ({:.4f} s / 
it)'.format( 164 | header, total_time_str, total_time / len(iterable))) 165 | 166 | 167 | def _load_checkpoint_for_ema(model_ema, checkpoint): 168 | """ 169 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 170 | """ 171 | mem_file = io.BytesIO() 172 | torch.save(checkpoint, mem_file) 173 | mem_file.seek(0) 174 | model_ema._load_checkpoint(mem_file) 175 | 176 | 177 | def setup_for_distributed(is_master): 178 | """ 179 | This function disables printing when not in master process 180 | """ 181 | import builtins as __builtin__ 182 | builtin_print = __builtin__.print 183 | 184 | def print(*args, **kwargs): 185 | force = kwargs.pop('force', False) 186 | if is_master or force: 187 | builtin_print(*args, **kwargs) 188 | 189 | __builtin__.print = print 190 | 191 | 192 | def is_dist_avail_and_initialized(): 193 | if not dist.is_available(): 194 | return False 195 | if not dist.is_initialized(): 196 | return False 197 | return True 198 | 199 | 200 | def get_world_size(): 201 | if not is_dist_avail_and_initialized(): 202 | return 1 203 | return dist.get_world_size() 204 | 205 | 206 | def get_rank(): 207 | if not is_dist_avail_and_initialized(): 208 | return 0 209 | return dist.get_rank() 210 | 211 | 212 | def is_main_process(): 213 | return get_rank() == 0 214 | 215 | 216 | def save_on_master(*args, **kwargs): 217 | if is_main_process(): 218 | torch.save(*args, **kwargs) 219 | 220 | 221 | def init_distributed_mode(args): 222 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 223 | args.rank = int(os.environ["RANK"]) 224 | args.world_size = int(os.environ['WORLD_SIZE']) 225 | args.gpu = int(os.environ['LOCAL_RANK']) 226 | elif 'SLURM_PROCID' in os.environ: 227 | args.rank = int(os.environ['SLURM_PROCID']) 228 | args.gpu = args.rank % torch.cuda.device_count() 229 | else: 230 | print('Not using distributed mode') 231 | args.distributed = False 232 | return 233 | 234 | args.distributed = True 235 | 236 | torch.cuda.set_device(args.gpu) 237 | args.dist_backend = 'nccl' 238 | print('| distributed init (rank {}): {}'.format( 239 | args.rank, args.dist_url), flush=True) 240 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 241 | world_size=args.world_size, rank=args.rank) 242 | torch.distributed.barrier() 243 | setup_for_distributed(args.rank == 0) 244 | 245 | 246 | def update_from_config(args): 247 | cfg = mmcv.Config.fromfile(args.config) 248 | for _, cfg_item in cfg._cfg_dict.items(): 249 | for k, v in cfg_item.items(): 250 | setattr(args, k, v) 251 | return args 252 | 253 | @functools.lru_cache() 254 | def setup_logger( 255 | output=None, distributed_rank=0, *, color=True, name="train", abbrev_name=None 256 | ): 257 | logger = logging.getLogger(name) 258 | logger.setLevel(logging.DEBUG) 259 | logger.propagate = False 260 | 261 | if abbrev_name is None: 262 | abbrev_name = "xl" if name == "train" else name 263 | 264 | plain_formatter = logging.Formatter( 265 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" 266 | ) 267 | # stdout logging: master only 268 | if distributed_rank == 0: 269 | ch = logging.StreamHandler(stream=sys.stdout) 270 | ch.setLevel(logging.DEBUG) 271 | formatter = plain_formatter 272 | ch.setFormatter(formatter) 273 | logger.addHandler(ch) 274 | 275 | # file logging: all workers 276 | if output is not None: 277 | if output.endswith(".txt") or output.endswith(".log"): 278 | filename = output 279 | else: 280 | filename = os.path.join(output, "log.txt") 281 | if distributed_rank > 
0: 282 | filename = filename + ".rank{}".format(distributed_rank) 283 | #PathManager.mkdirs(os.path.dirname(filename)) 284 | 285 | fh = logging.FileHandler(filename) #logging.StreamHandler(filename) 286 | fh.setLevel(logging.DEBUG) 287 | fh.setFormatter(plain_formatter) 288 | logger.addHandler(fh) 289 | 290 | return logger -------------------------------------------------------------------------------- /vanilla_architecture/README.md: -------------------------------------------------------------------------------- 1 | # SpectFormer Model for Image Classification 2 | 3 | Created by [Badri N. Patro](https://badripatro.github.io/), [Vinay P. Namboodiri](https://vinaypn.github.io/), [Vijay Srinivas Agneeswaran](https://in.linkedin.com/in/vijaysrinivasagneeswaran) 4 | 5 | ## Abstract 6 | 7 | ''' 8 | Vision transformers have been applied successfully for image recognition tasks. There have been either multi-headed self-attention based (ViT \cite{dosovitskiy2020image}, DeIT, \cite{touvron2021training}) similar to the original work in textual models or more recently based on spectral layers (Fnet\cite{lee2021fnet}, GFNet\cite{rao2021global}, AFNO\cite{guibas2021efficient}). We hypothesize that both spectral and multi-headed attention plays a major role. We investigate this hypothesis through this work and observe that indeed combining spectral and multi-headed attention layers provides a better transformer architecture. We thus propose the novel Spectformer architecture for transformers that combines spectral and multi-headed attention layers. We believe that the resulting representation allows the transformer to capture the feature representation appropriately and it yields improved performance over other transformer representations. For instance, it improves the top-1 accuracy by 2\% on ImageNet compared to both GFNet-H and LiT. SpectFormer-S reaches 84.25\% top-1 accuracy on ImageNet-1K (state of the art for small version). Further, Spectformer-L achieves 85.7\% that is the state of the art for the comparable base version of the transformers. We further ensure that we obtain reasonable results in other scenarios such as transfer learning on standard datasets such as CIFAR-10, CIFAR-100, Oxford-IIIT-flower, and Standford Car datasets. We then investigate its use in downstream tasks such of object detection and instance segmentation on MS-COCO dataset and observe that Spectformer shows consistent performance that is comparable to the best backbones and can be further optimized and improved. Hence, we believe that combined spectral and attention layers are what are needed for vision transformers. 9 | 10 | ''' 11 | 12 | ![Main Model](figs/SpectFormer_main.png) 13 | 14 | Our code is based on [pytorch-image-models](https://github.com/rwightman/pytorch-image-models), [DeiT](https://github.com/facebookresearch/deit), and [GFNet](https://github.com/raoyongming/GFNet). 15 | 16 | 17 | 18 | 19 | ## Model Zoo 20 | 21 | We provide variants of our SpectFormer models trained on ImageNet-1K: 22 | | name | Params | FLOPs | acc@1 | acc@5 | 23 | | --- | --- | --- | --- | --- | 24 | | SpectFormer-T | 9M | 1.8G | 76.9 | 93.4 | 25 | | SpectFormer-XS | 22M | 4.0G | 80.2 | 94.7 | 26 | | SpectFormer-S | 32M | 6.6G | 81.7 | 95.6 | 27 | | SpectFormer-B | 57M | 11.5G | 82.1 | 95.7 | 28 | 29 | 30 | 31 | 32 | ## Filter Visualization 33 | 34 | We visualize the learned filters initial layers of GFNet-XS and SpectFormer-XS . 
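As a rough way to reproduce such plots from a trained checkpoint, the minimal sketch below loads a vanilla SpectFormer checkpoint and plots the magnitude of its frequency-domain filters. It is not the authors' plotting script: it assumes the `{'model': state_dict}` checkpoint layout used by `infer.py`, and that the learnable filters live under a parameter name containing `complex_weight` (a GFNet-style convention; adjust to the actual name in `spectformer.py`).

```python
import torch
import matplotlib.pyplot as plt

# Assumed checkpoint layout: {'model': state_dict}, as in infer.py.
state = torch.load('path/to/checkpoint.pth', map_location='cpu')['model']

# Assumed parameter name for the learnable frequency-domain filters (GFNet-style);
# change 'complex_weight' if spectformer.py uses a different name.
filter_keys = [k for k in state if 'complex_weight' in k]

for k in filter_keys[:4]:  # first few layers only
    # Assumed shape (H, W_freq, C, 2): trailing dim holds real/imag parts.
    w = torch.view_as_complex(state[k].float())
    mag = w.abs().mean(dim=-1)  # mean filter magnitude over channels
    plt.figure()
    plt.imshow(mag.numpy(), cmap='viridis')
    plt.title(k)
    plt.colorbar()
plt.show()
```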
35 | 36 | ![GFNet](figs/GFNet_filter.jpg) 37 | ![SpectFormer](figs/SpectFormer_filter.jpg) 38 | 39 | 40 | 41 | ### Requirements 42 | 43 | - python =3.8 (conda create -y --name spectformer python=3.8) 44 | - torch>=1.10.0 (conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch) 45 | - timm (pip install timm) 46 | 47 | 48 | **Data preparation**: download and extract ImageNet images from http://image-net.org/. The directory structure should be 49 | 50 | ``` 51 | │ILSVRC2012/ 52 | ├──train/ 53 | │ ├── n01440764 54 | │ │ ├── n01440764_10026.JPEG 55 | │ │ ├── n01440764_10027.JPEG 56 | │ │ ├── ...... 57 | │ ├── ...... 58 | ├──val/ 59 | │ ├── n01440764 60 | │ │ ├── ILSVRC2012_val_00000293.JPEG 61 | │ │ ├── ILSVRC2012_val_00002138.JPEG 62 | │ │ ├── ...... 63 | │ ├── ...... 64 | ``` 65 | 66 | ### Evaluation 67 | 68 | To evaluate a pre-trained spectformer model on the ImageNet validation set with a single GPU, run: 69 | 70 | ``` 71 | python infer.py --data-path /path/to/ILSVRC2012/ --arch arch_name --model-path /path/to/model 72 | ``` 73 | 74 | 75 | ### Training 76 | 77 | #### ImageNet 78 | 79 | To train spectformer models on ImageNet from scratch, run: 80 | 81 | ``` 82 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer.py --output_dir logs/spectformer-xs --arch spectformer-xs --batch-size 128 --data-path /path/to/ILSVRC2012/ 83 | ``` 84 | 85 | To finetune a pre-trained model at higher resolution, run: 86 | 87 | ``` 88 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer.py --output_dir logs/spectformer-xs-img384 --arch spectformer-xs --input-size 384 --batch-size 64 --data-path /path/to/ILSVRC2012/ --lr 5e-6 --weight-decay 1e-8 --min-lr 5e-6 --epochs 30 --finetune /path/to/model 89 | ``` 90 | 91 | #### Transfer Learning Datasets 92 | 93 | To finetune a pre-trained model on a transfer learning dataset, run: 94 | ``` 95 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer_transfer.py --output_dir logs/spectformer-xs-cars --arch spectformer-xs --batch-size 64 --data-set CARS --data-path /path/to/stanford_cars --epochs 1000 --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune /path/to/model 96 | ``` 97 | 98 | -------------------------------------------------------------------------------- /vanilla_architecture/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | from torchvision import datasets, transforms 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets.folder import ImageFolder, default_loader 7 | 8 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 9 | from timm.data import create_transform 10 | """ Stanford Cars (Car) Dataset 11 | Created: Nov 15,2019 - Yuchong Gu 12 | Revised: Nov 15,2019 - Yuchong Gu 13 | """ 14 | import os 15 | # import pdb 16 | from PIL import Image 17 | import pickle 18 | # from scipy.io import loadmat 19 | 20 | 21 | 22 | class CarsDataset(Dataset): 23 | """ 24 | # Description: 25 | Dataset for retrieving Stanford Cars images and labels 26 | # Member Functions: 27 | __init__(self, phase, resize): initializes a dataset 28 | phase: a string in ['train', 'val', 'test'] 29 | resize: output shape/size of an image 30 | __getitem__(self, item): returns an image 31 | item: the idex of image in the whole dataset 32 | __len__(self): returns the length of dataset 33 | """ 34 | 35 | def __init__(self, root, 
train=True, transform=None): 36 | self.root = root 37 | self.phase = 'train' if train else 'test' 38 | # self.resize = resize 39 | self.num_classes = 196 40 | 41 | self.images = [] 42 | self.labels = [] 43 | 44 | list_path = os.path.join(root, 'cars_anno.pkl') 45 | 46 | list_mat = pickle.load(open(list_path, 'rb')) 47 | num_inst = len(list_mat['annotations']['relative_im_path'][0]) 48 | for i in range(num_inst): 49 | if self.phase == 'train' and list_mat['annotations']['test'][0][i].item() == 0: 50 | path = list_mat['annotations']['relative_im_path'][0][i].item() 51 | label = list_mat['annotations']['class'][0][i].item() 52 | self.images.append(path) 53 | self.labels.append(label) 54 | elif self.phase != 'train' and list_mat['annotations']['test'][0][i].item() == 1: 55 | path = list_mat['annotations']['relative_im_path'][0][i].item() 56 | label = list_mat['annotations']['class'][0][i].item() 57 | self.images.append(path) 58 | self.labels.append(label) 59 | 60 | print('Car Dataset with {} instances for {} phase'.format(len(self.images), self.phase)) 61 | 62 | # transform 63 | self.transform = transform 64 | 65 | def __getitem__(self, item): 66 | # image 67 | image = Image.open(os.path.join(self.root, self.images[item])).convert('RGB') # (C, H, W) 68 | image = self.transform(image) 69 | 70 | # return image and label 71 | return image, self.labels[item] - 1 # count begin from zero 72 | 73 | def __len__(self): 74 | return len(self.images) 75 | 76 | 77 | class INatDataset(ImageFolder): 78 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 79 | category='name', loader=default_loader): 80 | self.transform = transform 81 | self.loader = loader 82 | self.target_transform = target_transform 83 | self.year = year 84 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 85 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 86 | with open(path_json) as json_file: 87 | data = json.load(json_file) 88 | 89 | with open(os.path.join(root, 'categories.json')) as json_file: 90 | data_catg = json.load(json_file) 91 | 92 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 93 | 94 | with open(path_json_for_targeter) as json_file: 95 | data_for_targeter = json.load(json_file) 96 | 97 | targeter = {} 98 | indexer = 0 99 | for elem in data_for_targeter['annotations']: 100 | king = [] 101 | king.append(data_catg[int(elem['category_id'])][category]) 102 | if king[0] not in targeter.keys(): 103 | targeter[king[0]] = indexer 104 | indexer += 1 105 | self.nb_classes = len(targeter) 106 | 107 | self.samples = [] 108 | for elem in data['images']: 109 | cut = elem['file_name'].split('/') 110 | target_current = int(cut[2]) 111 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 112 | 113 | categors = data_catg[target_current] 114 | target_current_true = targeter[categors[category]] 115 | self.samples.append((path_current, target_current_true)) 116 | 117 | # __getitem__ and __len__ inherited from ImageFolder 118 | 119 | 120 | def build_dataset(is_train, args, infer_no_resize=False): 121 | transform = build_transform(is_train, args, infer_no_resize) 122 | 123 | if args.data_set == 'CIFAR100': 124 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 125 | nb_classes = 100 126 | elif args.data_set == 'CIFAR10': 127 | dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform, download=True) 128 | nb_classes = 10 129 | elif 
args.data_set == 'CARS': 130 | dataset = CarsDataset(args.data_path, train=is_train, transform=transform) 131 | nb_classes = 196 132 | elif args.data_set == 'FLOWERS': 133 | root = os.path.join(args.data_path, 'train' if is_train else 'test') 134 | dataset = datasets.ImageFolder(root, transform=transform) 135 | nb_classes = 102 136 | elif args.data_set == 'IMNET': 137 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 138 | dataset = datasets.ImageFolder(root, transform=transform) 139 | nb_classes = 1000 140 | elif args.data_set == 'INAT': 141 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 142 | category=args.inat_category, transform=transform) 143 | nb_classes = dataset.nb_classes 144 | elif args.data_set == 'INAT19': 145 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 146 | category=args.inat_category, transform=transform) 147 | nb_classes = dataset.nb_classes 148 | 149 | return dataset, nb_classes 150 | 151 | 152 | def build_transform(is_train, args, infer_no_resize=False): 153 | if hasattr(args, 'arch'): 154 | if 'cait' in args.arch and not is_train: 155 | print('# using cait eval transform') 156 | transformations = {} 157 | transformations= transforms.Compose( 158 | [transforms.Resize(args.input_size, interpolation=3), 159 | transforms.CenterCrop(args.input_size), 160 | transforms.ToTensor(), 161 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 162 | return transformations 163 | 164 | if infer_no_resize: 165 | print('# using cait eval transform') 166 | transformations = {} 167 | transformations= transforms.Compose( 168 | [transforms.Resize(args.input_size, interpolation=3), 169 | transforms.CenterCrop(args.input_size), 170 | transforms.ToTensor(), 171 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 172 | return transformations 173 | 174 | resize_im = args.input_size > 32 175 | if is_train: 176 | # this should always dispatch to transforms_imagenet_train 177 | transform = create_transform( 178 | input_size=args.input_size, 179 | is_training=True, 180 | color_jitter=args.color_jitter, 181 | auto_augment=args.aa, 182 | interpolation=args.train_interpolation, 183 | re_prob=args.reprob, 184 | re_mode=args.remode, 185 | re_count=args.recount, 186 | ) 187 | if not resize_im: 188 | # replace RandomResizedCropAndInterpolation with 189 | # RandomCrop 190 | transform.transforms[0] = transforms.RandomCrop( 191 | args.input_size, padding=4) 192 | return transform 193 | 194 | t = [] 195 | if resize_im: 196 | size = int((256 / 224) * args.input_size) 197 | t.append( 198 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 
224 images 199 | ) 200 | t.append(transforms.CenterCrop(args.input_size)) 201 | 202 | t.append(transforms.ToTensor()) 203 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 204 | return transforms.Compose(t) 205 | -------------------------------------------------------------------------------- /vanilla_architecture/engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | """ 4 | import math 5 | import sys 6 | from typing import Iterable, Optional 7 | 8 | import torch 9 | 10 | from timm.data import Mixup 11 | from timm.utils import accuracy, ModelEma 12 | 13 | from losses import DistillationLoss 14 | import utils 15 | 16 | import random 17 | 18 | 19 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 20 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 21 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 22 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 23 | set_training_mode=True): 24 | model.train(set_training_mode) 25 | metric_logger = utils.MetricLogger(delimiter=" ") 26 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 27 | header = 'Epoch: [{}]'.format(epoch) 28 | print_freq = 10 29 | 30 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 31 | samples = samples.to(device, non_blocking=True) 32 | targets = targets.to(device, non_blocking=True) 33 | 34 | if mixup_fn is not None: 35 | samples, targets = mixup_fn(samples, targets) 36 | 37 | with torch.cuda.amp.autocast(): 38 | outputs = model(samples) 39 | loss = criterion(samples, outputs, targets) 40 | 41 | loss_value = loss.item() 42 | 43 | if not math.isfinite(loss_value): 44 | print("Loss is {}, stopping training".format(loss_value)) 45 | sys.exit(1) 46 | 47 | optimizer.zero_grad() 48 | 49 | # this attribute is added by timm on one optimizer (adahessian) 50 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 51 | loss_scaler(loss, optimizer, clip_grad=max_norm, 52 | parameters=model.parameters(), create_graph=is_second_order) 53 | 54 | torch.cuda.synchronize() 55 | if model_ema is not None: 56 | model_ema.update(model) 57 | 58 | metric_logger.update(loss=loss_value) 59 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 60 | # gather the stats from all processes 61 | metric_logger.synchronize_between_processes() 62 | print("Averaged stats:", metric_logger) 63 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 64 | 65 | 66 | @torch.no_grad() 67 | def evaluate(data_loader, model, device): 68 | criterion = torch.nn.CrossEntropyLoss() 69 | 70 | metric_logger = utils.MetricLogger(delimiter=" ") 71 | header = 'Test:' 72 | 73 | # switch to evaluation mode 74 | model.eval() 75 | 76 | for images, target in metric_logger.log_every(data_loader, 10, header): 77 | images = images.to(device, non_blocking=True) 78 | target = target.to(device, non_blocking=True) 79 | 80 | # compute output 81 | with torch.cuda.amp.autocast(): 82 | output = model(images) 83 | loss = criterion(output, target) 84 | 85 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 86 | 87 | batch_size = images.shape[0] 88 | metric_logger.update(loss=loss.item()) 89 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 90 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 91 | # gather the stats from all processes 92 | 
metric_logger.synchronize_between_processes() 93 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 94 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 95 | 96 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 97 | -------------------------------------------------------------------------------- /vanilla_architecture/figs/GFNet_filter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/badripatro/SpectFormers/b6cf487a0ea71eeb252d5ba0fc7410b3f5dbd256/vanilla_architecture/figs/GFNet_filter.jpg -------------------------------------------------------------------------------- /vanilla_architecture/figs/SpectFormer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/badripatro/SpectFormers/b6cf487a0ea71eeb252d5ba0fc7410b3f5dbd256/vanilla_architecture/figs/SpectFormer.png -------------------------------------------------------------------------------- /vanilla_architecture/figs/SpectFormer_filter.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/badripatro/SpectFormers/b6cf487a0ea71eeb252d5ba0fc7410b3f5dbd256/vanilla_architecture/figs/SpectFormer_filter.jpg -------------------------------------------------------------------------------- /vanilla_architecture/figs/SpectFormer_main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/badripatro/SpectFormers/b6cf487a0ea71eeb252d5ba0fc7410b3f5dbd256/vanilla_architecture/figs/SpectFormer_main.png -------------------------------------------------------------------------------- /vanilla_architecture/figs/inference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/badripatro/SpectFormers/b6cf487a0ea71eeb252d5ba0fc7410b3f5dbd256/vanilla_architecture/figs/inference.png -------------------------------------------------------------------------------- /vanilla_architecture/figs/sota.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/badripatro/SpectFormers/b6cf487a0ea71eeb252d5ba0fc7410b3f5dbd256/vanilla_architecture/figs/sota.jpg -------------------------------------------------------------------------------- /vanilla_architecture/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.backends.cudnn as cudnn 8 | import json 9 | 10 | from pathlib import Path 11 | 12 | from timm.data import Mixup 13 | from timm.models import create_model 14 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 15 | from timm.scheduler import create_scheduler 16 | from timm.optim import create_optimizer 17 | from timm.utils import NativeScaler, get_state_dict, ModelEma 18 | 19 | from datasets import build_dataset 20 | from engine import train_one_epoch, evaluate 21 | from losses import DistillationLoss 22 | from samplers import RASampler 23 | import utils 24 | from functools import partial 25 | 26 | from spectformer import SpectFormer, _cfg 27 | 28 | def get_args_parser(): 29 | parser = argparse.ArgumentParser('spectformer evaluation script', add_help=False) 30 | 
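    # NOTE (annotation, not part of the original script): a minimal sketch of how this
    # evaluation-only entry point is typically invoked; the checkpoint and dataset
    # paths are placeholders, not files shipped with the repository:
    #
    #   python infer.py --arch spectformer-s \
    #       --model-path logs/spectformer-s/checkpoint_best.pth \
    #       --data-path /path/to/imagenet --data-set IMNET
    #
    # The flags defined below are the authoritative list of supported options.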
parser.add_argument('--batch-size', default=128, type=int) 31 | parser.add_argument('--arch', default='deit_small', type=str, help='Name of model to train') 32 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 33 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 34 | help='dataset path') 35 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 36 | type=str, help='Image Net dataset path') 37 | parser.add_argument('--inat-category', default='name', 38 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 39 | type=str, help='semantic granularity') 40 | parser.add_argument('--seed', default=0, type=int) 41 | parser.add_argument('--model-path', default='', help='resume from checkpoint') 42 | parser.add_argument('--num_workers', default=10, type=int) 43 | parser.add_argument('--pin-mem', action='store_true', 44 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 45 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 46 | help='') 47 | parser.set_defaults(pin_mem=True) 48 | return parser 49 | 50 | 51 | def main(args): 52 | 53 | cudnn.benchmark = True 54 | dataset_val, _ = build_dataset(is_train=False, args=args) 55 | 56 | data_loader_val = torch.utils.data.DataLoader( 57 | dataset_val, 58 | batch_size=128, 59 | num_workers=args.num_workers, 60 | pin_memory=args.pin_mem, 61 | drop_last=False 62 | ) 63 | 64 | if args.arch == 'spectformer-xs': 65 | model = SpectFormer( 66 | img_size=args.input_size, 67 | patch_size=16, embed_dim=384, depth=12, mlp_ratio=4, 68 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 69 | ) 70 | elif args.arch == 'spectformer-ti': 71 | model = SpectFormer( 72 | img_size=args.input_size, 73 | patch_size=16, embed_dim=256, depth=12, mlp_ratio=4, 74 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 75 | ) 76 | elif args.arch == 'spectformer-s': 77 | model = SpectFormer( 78 | img_size=args.input_size, 79 | patch_size=16, embed_dim=384, depth=19, mlp_ratio=4, 80 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 81 | ) 82 | elif args.arch == 'spectformer-b': 83 | model = SpectFormer( 84 | img_size=args.input_size, 85 | patch_size=16, embed_dim=512, depth=19, mlp_ratio=4, 86 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 87 | ) 88 | else: 89 | raise NotImplementedError 90 | 91 | model_path = args.model_path 92 | model.default_cfg = _cfg() 93 | 94 | checkpoint = torch.load(model_path, map_location="cpu") 95 | model.load_state_dict(checkpoint["model"]) 96 | 97 | print('## model has been successfully loaded') 98 | 99 | model = model.cuda() 100 | 101 | n_parameters = sum(p.numel() for p in model.parameters()) 102 | print('number of params:', n_parameters) 103 | 104 | criterion = torch.nn.CrossEntropyLoss().cuda() 105 | validate(data_loader_val, model, criterion) 106 | 107 | class AverageMeter(object): 108 | """Computes and stores the average and current value""" 109 | def __init__(self, name, fmt=':f'): 110 | self.name = name 111 | self.fmt = fmt 112 | self.reset() 113 | 114 | def reset(self): 115 | self.val = 0 116 | self.avg = 0 117 | self.sum = 0 118 | self.count = 0 119 | 120 | def update(self, val, n=1): 121 | self.val = val 122 | self.sum += val * n 123 | self.count += n 124 | self.avg = self.sum / self.count 125 | 126 | def __str__(self): 127 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 128 | return fmtstr.format(**self.__dict__) 129 | 130 | class 
ProgressMeter(object): 131 | def __init__(self, num_batches, meters, prefix=""): 132 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 133 | self.meters = meters 134 | self.prefix = prefix 135 | 136 | def display(self, batch): 137 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 138 | entries += [str(meter) for meter in self.meters] 139 | print('\t'.join(entries)) 140 | 141 | def _get_batch_fmtstr(self, num_batches): 142 | num_digits = len(str(num_batches // 1)) 143 | fmt = '{:' + str(num_digits) + 'd}' 144 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 145 | 146 | def accuracy(output, target, topk=(1,)): 147 | """Computes the accuracy over the k top predictions for the specified values of k""" 148 | with torch.no_grad(): 149 | maxk = max(topk) 150 | batch_size = target.size(0) 151 | 152 | _, pred = output.topk(maxk, 1, True, True) 153 | pred = pred.t() 154 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 155 | 156 | res = [] 157 | for k in topk: 158 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 159 | res.append(correct_k.mul_(100.0 / batch_size)) 160 | return res 161 | 162 | def validate(val_loader, model, criterion): 163 | batch_time = AverageMeter('Time', ':6.3f') 164 | losses = AverageMeter('Loss', ':.4e') 165 | top1 = AverageMeter('Acc@1', ':6.2f') 166 | top5 = AverageMeter('Acc@5', ':6.2f') 167 | model.eval() 168 | 169 | progress = ProgressMeter( 170 | len(val_loader), 171 | [batch_time, losses, top1, top5], 172 | prefix='Test: ') 173 | 174 | 175 | with torch.no_grad(): 176 | end = time.time() 177 | for i, (images, target) in enumerate(val_loader): 178 | images = images.cuda() 179 | target = target.cuda() 180 | 181 | # compute output 182 | output = model(images) 183 | loss = criterion(output, target) 184 | 185 | # measure accuracy and record loss 186 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 187 | losses.update(loss.item(), images.size(0)) 188 | top1.update(acc1[0], images.size(0)) 189 | top5.update(acc5[0], images.size(0)) 190 | 191 | # measure elapsed time 192 | batch_time.update(time.time() - end) 193 | end = time.time() 194 | 195 | if i % 20 == 0: 196 | progress.display(i) 197 | 198 | # TODO: this should also be done with the ProgressMeter 199 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 200 | .format(top1=top1, top5=top5)) 201 | 202 | return top1.avg 203 | 204 | 205 | 206 | if __name__ == '__main__': 207 | parser = argparse.ArgumentParser('spectformer evaluation script', parents=[get_args_parser()]) 208 | args = parser.parse_args() 209 | main(args) 210 | -------------------------------------------------------------------------------- /vanilla_architecture/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | from abc import get_cache_token 7 | import torch 8 | from torch.nn import functional as F 9 | from torch.nn.modules.loss import MSELoss, BCEWithLogitsLoss, CrossEntropyLoss 10 | from utils import batch_index_select 11 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 12 | import math 13 | 14 | class DistillationLoss(torch.nn.Module): 15 | """ 16 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 17 | taking a teacher model prediction and using it as additional supervision. 
18 | """ 19 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 20 | distillation_type: str, alpha: float, tau: float): 21 | super().__init__() 22 | self.base_criterion = base_criterion 23 | self.teacher_model = teacher_model 24 | assert distillation_type in ['none', 'soft', 'hard'] 25 | self.distillation_type = distillation_type 26 | self.alpha = alpha 27 | self.tau = tau 28 | 29 | def forward(self, inputs, outputs, labels): 30 | """ 31 | Args: 32 | inputs: The original inputs that are feed to the teacher model 33 | outputs: the outputs of the model to be trained. It is expected to be 34 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 35 | in the first position and the distillation predictions as the second output 36 | labels: the labels for the base criterion 37 | """ 38 | outputs_kd = None 39 | if not isinstance(outputs, torch.Tensor): 40 | # assume that the model outputs a tuple of [outputs, outputs_kd] 41 | outputs, outputs_kd = outputs 42 | base_loss = self.base_criterion(outputs, labels) 43 | if self.distillation_type == 'none': 44 | return base_loss 45 | 46 | if outputs_kd is None: 47 | raise ValueError("When knowledge distillation is enabled, the model is " 48 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 49 | "class_token and the dist_token") 50 | # don't backprop throught the teacher 51 | with torch.no_grad(): 52 | teacher_outputs = self.teacher_model(inputs) 53 | 54 | if self.distillation_type == 'soft': 55 | T = self.tau 56 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 57 | # with slight modifications 58 | distillation_loss = F.kl_div( 59 | F.log_softmax(outputs_kd / T, dim=1), 60 | F.log_softmax(teacher_outputs / T, dim=1), 61 | reduction='sum', 62 | log_target=True 63 | ) * (T * T) / outputs_kd.numel() 64 | elif self.distillation_type == 'hard': 65 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 66 | 67 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 68 | return loss 69 | -------------------------------------------------------------------------------- /vanilla_architecture/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Current path is $PATH" 4 | echo "Running" 5 | nvidia-smi 6 | echo $CUDA_VISIBLE_DEVICES 7 | 8 | # CUDA_LAUNCH_BLOCKING=1 9 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer.py --output_dir logs/spectformer-xs --arch spectformer-xs --batch-size 128 --data-path ../../../dataset/Image_net/imagenet/ --data-set IMNET --num_workers 12 --epochs 320 10 | 11 | # Transfer Learning 12 | # python -m torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer_transfer.py --output_dir logs/spectformer-b-cifar10 --arch spectformer-b --batch-size 64 --data-set CIFAR10 --data-path ../../../dataset/Image_net/ --epochs 1000 --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune ../spectformer_mvt_b/logs/spectformer-b/checkpoint_best.pth 13 | # python -m torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer_transfer.py --output_dir logs/spectformer-b-cifar100 --arch spectformer-b --batch-size 64 --data-set CIFAR100 --data-path ../../../dataset/Image_net/ --epochs 1000 --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune ../spectformer_mvt_b/logs/spectformer-b/checkpoint_best.pth 14 | # python -m 
torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer_transfer.py --output_dir logs/spectformer-b-FLOWERS --arch spectformer-b --batch-size 64 --data-set FLOWERS --data-path ../../../dataset/Image_net/flowers --epochs 1000 --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune ../spectformer_mvt_b/logs/spectformer-b/checkpoint_best.pth 15 | # python -m torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer_transfer.py --output_dir logs/spectformer-b-pet --arch spectformer-b --batch-size 64 --data-set PET --data-path ../../../dataset/Image_net/pets --epochs 1000 --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune ../spectformer_mvt_b/logs/spectformer-b/checkpoint_best.pth 16 | # python -m torch.distributed.launch --nproc_per_node=8 --use_env main_spectformer_transfer.py --output_dir logs/spectformer-b-car --arch spectformer-b --batch-size 64 --data-set CARS --data-path ../../../dataset/Image_net/cars --epochs 1000 --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --finetune ../spectformer_mvt_b/logs/spectformer-b/checkpoint_best.pth 17 | -------------------------------------------------------------------------------- /vanilla_architecture/main_spectformer.py: -------------------------------------------------------------------------------- 1 | # Code Adapted from GFNet(https://github.com/raoyongming/GFNet) 2 | import argparse 3 | import datetime 4 | import numpy as np 5 | import time 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import json 9 | 10 | from pathlib import Path 11 | 12 | from timm.data import Mixup 13 | from timm.models import create_model 14 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 15 | from timm.scheduler import create_scheduler 16 | from timm.optim import create_optimizer 17 | from timm.utils import NativeScaler, get_state_dict, ModelEma 18 | from functools import partial 19 | import torch.nn as nn 20 | import matplotlib 21 | matplotlib.use('agg') 22 | import matplotlib.pyplot as plt 23 | 24 | from datasets import build_dataset 25 | from engine import train_one_epoch, evaluate 26 | from losses import DistillationLoss 27 | from samplers import RASampler 28 | import utils 29 | from spectformer import SpectFormer 30 | 31 | import warnings 32 | warnings.filterwarnings("ignore", message="Argument interpolation should be") 33 | 34 | def get_args_parser(): 35 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 36 | parser.add_argument('--batch-size', default=64, type=int) 37 | parser.add_argument('--epochs', default=300, type=int) 38 | 39 | # Model parameters 40 | parser.add_argument('--arch', default='deit_small', type=str, 41 | help='Name of model to train') 42 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 43 | 44 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 45 | help='Dropout rate (default: 0.)') 46 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 47 | help='Drop path rate (default: 0.1)') 48 | 49 | parser.add_argument('--model-ema', action='store_true') 50 | parser.set_defaults(model_ema=False) 51 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 52 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 53 | 54 | # Optimizer parameters 55 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 56 | help='Optimizer 
(default: "adamw"') 57 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 58 | help='Optimizer Epsilon (default: 1e-8)') 59 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 60 | help='Optimizer Betas (default: None, use opt default)') 61 | parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', 62 | help='Clip gradient norm (default: None, no clipping)') 63 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 64 | help='SGD momentum (default: 0.9)') 65 | parser.add_argument('--weight-decay', type=float, default=0.05, 66 | help='weight decay (default: 0.05)') 67 | # Learning rate schedule parameters 68 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 69 | help='LR scheduler (default: "cosine"') 70 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 71 | help='learning rate (default: 5e-4)') 72 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 73 | help='learning rate noise on/off epoch percentages') 74 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 75 | help='learning rate noise limit percent (default: 0.67)') 76 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 77 | help='learning rate noise std-dev (default: 1.0)') 78 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 79 | help='warmup learning rate (default: 1e-6)') 80 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 81 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 82 | 83 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 84 | help='epoch interval to decay LR') 85 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 86 | help='epochs to warmup LR, if scheduler supports') 87 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 88 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 89 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 90 | help='patience epochs for Plateau LR scheduler (default: 10') 91 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 92 | help='LR decay rate (default: 0.1)') 93 | 94 | # Augmentation parameters 95 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 96 | help='Color jitter factor (default: 0.4)') 97 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 98 | help='Use AutoAugment policy. "v0" or "original". 
" + \ 99 | "(default: rand-m9-mstd0.5-inc1)'), 100 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 101 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 102 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 103 | 104 | parser.add_argument('--repeated-aug', action='store_true') 105 | parser.set_defaults(repeated_aug=False) 106 | 107 | # * Random Erase params 108 | parser.add_argument('--reprob', type=float, default=0, metavar='PCT', 109 | help='Random erase prob (default: 0.25)') 110 | parser.add_argument('--remode', type=str, default='pixel', 111 | help='Random erase mode (default: "pixel")') 112 | parser.add_argument('--recount', type=int, default=1, 113 | help='Random erase count (default: 1)') 114 | parser.add_argument('--resplit', action='store_true', default=False, 115 | help='Do not random erase first (clean) augmentation split') 116 | 117 | # * Mixup params 118 | parser.add_argument('--mixup', type=float, default=0.8, 119 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 120 | parser.add_argument('--cutmix', type=float, default=1.0, 121 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 122 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 123 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 124 | parser.add_argument('--mixup-prob', type=float, default=1.0, 125 | help='Probability of performing mixup or cutmix when either/both is enabled') 126 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 127 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 128 | parser.add_argument('--mixup-mode', type=str, default='batch', 129 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 130 | 131 | # Distillation parameters 132 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 133 | help='Name of teacher model to train (default: "regnety_160"') 134 | parser.add_argument('--teacher-path', type=str, default='') 135 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 136 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 137 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 138 | 139 | # * Finetuning params 140 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 141 | 142 | # Dataset parameters 143 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 144 | help='dataset path') 145 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 146 | type=str, help='Image Net dataset path') 147 | parser.add_argument('--inat-category', default='name', 148 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 149 | type=str, help='semantic granularity') 150 | 151 | parser.add_argument('--output_dir', default='', 152 | help='path where to save, empty for no saving') 153 | parser.add_argument('--device', default='cuda', 154 | help='device to use for training / testing') 155 | parser.add_argument('--seed', default=0, type=int) 156 | parser.add_argument('--resume', default='', help='resume from checkpoint') 157 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 158 | help='start epoch') 159 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 160 | parser.add_argument('--dist-eval', action='store_true', default=True, help='Enabling distributed evaluation') 161 | parser.add_argument('--num_workers', default=10, type=int) 162 | parser.add_argument('--pin-mem', action='store_true', 163 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 164 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 165 | help='') 166 | parser.set_defaults(pin_mem=True) 167 | 168 | # distributed training parameters 169 | parser.add_argument('--world_size', default=1, type=int, 170 | help='number of distributed processes') 171 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 172 | return parser 173 | 174 | 175 | def main(args): 176 | utils.init_distributed_mode(args) 177 | 178 | print(args) 179 | 180 | if args.distillation_type != 'none' and args.finetune and not args.eval: 181 | raise NotImplementedError("Finetuning with distillation not yet supported") 182 | 183 | device = torch.device(args.device) 184 | 185 | # fix the seed for reproducibility 186 | seed = args.seed + utils.get_rank() 187 | torch.manual_seed(seed) 188 | np.random.seed(seed) 189 | # random.seed(seed) 190 | train_loss_history=[] 191 | val_loss_history=[] 192 | val_acc1_history=[] 193 | val_acc5_history=[] 194 | 195 | cudnn.benchmark = True 196 | 197 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 198 | dataset_val, _ = build_dataset(is_train=False, args=args) 199 | 200 | if True: # args.distributed: 201 | num_tasks = utils.get_world_size() 202 | global_rank = utils.get_rank() 203 | if args.repeated_aug: 204 | sampler_train = RASampler( 205 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 206 | ) 
207 | else: 208 | sampler_train = torch.utils.data.DistributedSampler( 209 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 210 | ) 211 | if args.dist_eval: 212 | if len(dataset_val) % num_tasks != 0: 213 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 214 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 215 | 'equal num of samples per-process.') 216 | sampler_val = torch.utils.data.DistributedSampler( 217 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 218 | else: 219 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 220 | else: 221 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 222 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 223 | 224 | data_loader_train = torch.utils.data.DataLoader( 225 | dataset_train, sampler=sampler_train, 226 | batch_size=args.batch_size, 227 | num_workers=args.num_workers, 228 | pin_memory=args.pin_mem, 229 | drop_last=True, 230 | ) 231 | 232 | data_loader_val = torch.utils.data.DataLoader( 233 | dataset_val, sampler=sampler_val, 234 | batch_size=int(1.5 * args.batch_size), 235 | num_workers=args.num_workers, 236 | pin_memory=args.pin_mem, 237 | drop_last=False 238 | ) 239 | 240 | mixup_fn = None 241 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 242 | if mixup_active: 243 | print('standard mix up') 244 | mixup_fn = Mixup( 245 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 246 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 247 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 248 | else: 249 | print('mix up is not used') 250 | 251 | print(f"Creating model: {args.arch}") 252 | 253 | if args.arch == 'spectformer-xs': 254 | model = SpectFormer( 255 | img_size=args.input_size, 256 | patch_size=16, embed_dim=384, depth=12, mlp_ratio=4, 257 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 258 | ) 259 | elif args.arch == 'spectformer-ti': 260 | model = SpectFormer( 261 | img_size=args.input_size, 262 | patch_size=16, embed_dim=256, depth=12, mlp_ratio=4, 263 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 264 | ) 265 | elif args.arch == 'spectformer-s': 266 | model = SpectFormer( 267 | img_size=args.input_size, 268 | patch_size=16, embed_dim=384, depth=19, mlp_ratio=4, drop_path_rate=0.15, 269 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 270 | ) 271 | elif args.arch == 'spectformer-b': 272 | model = SpectFormer( 273 | img_size=args.input_size, 274 | patch_size=16, embed_dim=512, depth=19, mlp_ratio=4, drop_path_rate=0.25, 275 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 276 | ) 277 | else: 278 | raise NotImplementedError 279 | 280 | if args.finetune: 281 | if args.finetune.startswith('https'): 282 | checkpoint = torch.hub.load_state_dict_from_url( 283 | args.finetune, map_location='cpu', check_hash=True) 284 | else: 285 | checkpoint = torch.load(args.finetune, map_location='cpu') 286 | 287 | checkpoint_model = checkpoint['model'] 288 | state_dict = model.state_dict() 289 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 290 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 291 | print(f"Removing key {k} from pretrained checkpoint") 292 | del checkpoint_model[k] 293 | 294 | # interpolate position embedding 295 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 296 | embedding_size = 
pos_embed_checkpoint.shape[-1] 297 | 298 | if args.arch in ['spectformer-ti', 'spectformer-xs', 'spectformer-s', 'spectformer-b']: 299 | num_patches = (args.input_size // 16) ** 2 300 | elif args.arch in ['spectformer-h-ti', 'spectformer-h-s', 'spectformer-h-b']: 301 | num_patches = (args.input_size // 4) ** 2 302 | else: 303 | raise NotImplementedError 304 | 305 | num_extra_tokens = 0 306 | # height (== width) for the checkpoint position embedding 307 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 308 | # height (== width) for the new position embedding 309 | new_size = int(num_patches ** 0.5) 310 | 311 | scale_up_ratio = new_size / orig_size 312 | # class_token and dist_token are kept unchanged 313 | # only the position tokens are interpolated 314 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 315 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 316 | pos_tokens = torch.nn.functional.interpolate( 317 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 318 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 319 | checkpoint_model['pos_embed'] = pos_tokens 320 | 321 | for name in checkpoint_model.keys(): 322 | if 'complex_weight' in name: 323 | h, w, num_heads = checkpoint_model[name].shape[0:3] # h, w, c, 2 324 | origin_weight = checkpoint_model[name] 325 | upsample_h = h * new_size // orig_size 326 | upsample_w = upsample_h // 2 + 1 327 | origin_weight = origin_weight.reshape(1, h, w, num_heads * 2).permute(0, 3, 1, 2) 328 | new_weight = torch.nn.functional.interpolate( 329 | origin_weight, size=(upsample_h, upsample_w), mode='bicubic', align_corners=True).permute(0, 2, 3, 1).reshape(upsample_h, upsample_w, num_heads, 2) 330 | checkpoint_model[name] = new_weight 331 | model.load_state_dict(checkpoint_model, strict=True) 332 | 333 | model.to(device) 334 | 335 | model_ema = None 336 | if args.model_ema: 337 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 338 | model_ema = ModelEma( 339 | model, 340 | decay=args.model_ema_decay, 341 | device='cpu' if args.model_ema_force_cpu else '', 342 | resume='') 343 | 344 | model_without_ddp = model 345 | if args.distributed: 346 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 347 | model_without_ddp = model.module 348 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 349 | print('number of params:', n_parameters) 350 | 351 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 352 | args.lr = linear_scaled_lr 353 | optimizer = create_optimizer(args, model_without_ddp) 354 | loss_scaler = NativeScaler() 355 | 356 | lr_scheduler, _ = create_scheduler(args, optimizer) 357 | 358 | criterion = LabelSmoothingCrossEntropy() 359 | 360 | if args.mixup > 0.: 361 | # smoothing is handled with mixup label transform 362 | criterion = SoftTargetCrossEntropy() 363 | elif args.smoothing: 364 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 365 | else: 366 | criterion = torch.nn.CrossEntropyLoss() 367 | 368 | teacher_model = None 369 | if args.distillation_type != 'none': 370 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 371 | print(f"Creating teacher model: {args.teacher_model}") 372 | teacher_model = create_model( 373 | args.teacher_model, 374 | pretrained=False, 375 | num_classes=args.nb_classes, 376 | global_pool='avg', 377 | ) 378 | if 
args.teacher_path.startswith('https'): 379 | checkpoint = torch.hub.load_state_dict_from_url( 380 | args.teacher_path, map_location='cpu', check_hash=True) 381 | else: 382 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 383 | teacher_model.load_state_dict(checkpoint['model']) 384 | teacher_model.to(device) 385 | teacher_model.eval() 386 | 387 | # wrap the criterion in our custom DistillationLoss, which 388 | # just dispatches to the original criterion if args.distillation_type is 'none' 389 | 390 | criterion = DistillationLoss( 391 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 392 | ) 393 | 394 | output_dir = Path(args.output_dir) 395 | if args.resume: 396 | if args.resume.startswith('https'): 397 | checkpoint = torch.hub.load_state_dict_from_url( 398 | args.resume, map_location='cpu', check_hash=True) 399 | else: 400 | checkpoint = torch.load(args.resume, map_location='cpu') 401 | model_without_ddp.load_state_dict(checkpoint['model']) 402 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 403 | optimizer.load_state_dict(checkpoint['optimizer']) 404 | print('lr scheduler will not be updated') 405 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 406 | args.start_epoch = checkpoint['epoch'] + 1 407 | if args.model_ema: 408 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 409 | if 'scaler' in checkpoint: 410 | loss_scaler.load_state_dict(checkpoint['scaler']) 411 | 412 | if args.eval: 413 | test_stats = evaluate(data_loader_val, model, device) 414 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 415 | return 416 | 417 | print(f"Start training for {args.epochs} epochs") 418 | start_time = time.time() 419 | max_accuracy = 0.0 420 | for epoch in range(args.start_epoch, args.epochs): 421 | if args.distributed: 422 | data_loader_train.sampler.set_epoch(epoch) 423 | 424 | train_stats = train_one_epoch( 425 | model, criterion, data_loader_train, 426 | optimizer, device, epoch, loss_scaler, 427 | args.clip_grad, model_ema, mixup_fn, 428 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning 429 | ) 430 | 431 | lr_scheduler.step(epoch) 432 | 433 | if args.output_dir: 434 | checkpoint_paths = [output_dir / 'checkpoint_last.pth'] 435 | for checkpoint_path in checkpoint_paths: 436 | if model_ema is not None: 437 | utils.save_on_master({ 438 | 'model': model_without_ddp.state_dict(), 439 | 'optimizer': optimizer.state_dict(), 440 | 'lr_scheduler': lr_scheduler.state_dict(), 441 | 'epoch': epoch, 442 | 'model_ema': get_state_dict(model_ema), 443 | 'scaler': loss_scaler.state_dict(), 444 | 'args': args, 445 | }, checkpoint_path) 446 | else: 447 | utils.save_on_master({ 448 | 'model': model_without_ddp.state_dict(), 449 | 'optimizer': optimizer.state_dict(), 450 | 'lr_scheduler': lr_scheduler.state_dict(), 451 | 'epoch': epoch, 452 | 'scaler': loss_scaler.state_dict(), 453 | 'args': args, 454 | }, checkpoint_path) 455 | 456 | if (epoch + 1) % 100 == 0: #modified 20 epoch to 100 epoch 457 | file_name = 'checkpoint_epoch%d.pth' % epoch 458 | checkpoint_path = output_dir / file_name 459 | if model_ema is not None: 460 | utils.save_on_master({ 461 | 'model': model_without_ddp.state_dict(), 462 | 'optimizer': optimizer.state_dict(), 463 | 'lr_scheduler': lr_scheduler.state_dict(), 464 | 'epoch': epoch, 465 | 'model_ema': get_state_dict(model_ema), 466 | 'scaler': 
loss_scaler.state_dict(), 467 | 'args': args, 468 | }, checkpoint_path) 469 | else: 470 | utils.save_on_master({ 471 | 'model': model_without_ddp.state_dict(), 472 | 'optimizer': optimizer.state_dict(), 473 | 'lr_scheduler': lr_scheduler.state_dict(), 474 | 'epoch': epoch, 475 | 'scaler': loss_scaler.state_dict(), 476 | 'args': args, 477 | }, checkpoint_path) 478 | 479 | test_stats = evaluate(data_loader_val, model, device) 480 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 481 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 482 | print(f'Max accuracy: {max_accuracy:.2f}%') 483 | ######################################### 484 | # #added by badri 485 | # print('Test accuracy log Acc@1', test_stats["acc1"], 'Acc@5',test_stats["acc5"]) 486 | # print('Loss log test loss', test_stats["loss"],'train loss', train_stats["loss"]) 487 | train_loss_history.append(train_stats["loss"]) 488 | val_loss_history.append(test_stats["loss"]) 489 | val_acc1_history.append(test_stats["acc1"]) 490 | val_acc5_history.append(test_stats["acc5"]) 491 | 492 | plt.figure() 493 | plt.plot(train_loss_history,label='train') 494 | plt.plot(val_loss_history,label='val') 495 | plot_loss_path = output_dir / 'loss_plot.png' 496 | plt.savefig(plot_loss_path) 497 | plt.figure() 498 | plt.plot(val_acc1_history,label='acc1') 499 | plt.plot(val_acc5_history,label='acc5') 500 | plot_acc_path = output_dir / 'acc_plot.png' 501 | plt.savefig(plot_acc_path) 502 | ######################################### 503 | 504 | if max_accuracy == test_stats["acc1"]: 505 | checkpoint_path = output_dir / 'checkpoint_best.pth' 506 | if model_ema is not None: 507 | utils.save_on_master({ 508 | 'model': model_without_ddp.state_dict(), 509 | 'optimizer': optimizer.state_dict(), 510 | 'lr_scheduler': lr_scheduler.state_dict(), 511 | 'epoch': epoch, 512 | 'model_ema': get_state_dict(model_ema), 513 | 'scaler': loss_scaler.state_dict(), 514 | 'args': args, 515 | }, checkpoint_path) 516 | else: 517 | utils.save_on_master({ 518 | 'model': model_without_ddp.state_dict(), 519 | 'optimizer': optimizer.state_dict(), 520 | 'lr_scheduler': lr_scheduler.state_dict(), 521 | 'epoch': epoch, 522 | 'scaler': loss_scaler.state_dict(), 523 | 'args': args, 524 | }, checkpoint_path) 525 | 526 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 527 | **{f'test_{k}': v for k, v in test_stats.items()}, 528 | 'epoch': epoch, 529 | 'n_parameters': n_parameters} 530 | 531 | if args.output_dir and utils.is_main_process(): 532 | with (output_dir / "log.txt").open("a") as f: 533 | f.write(json.dumps(log_stats) + "\n") 534 | 535 | total_time = time.time() - start_time 536 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 537 | print('Training time {}'.format(total_time_str)) 538 | 539 | 540 | if __name__ == '__main__': 541 | parser = argparse.ArgumentParser('spectformer training and evaluation script', parents=[get_args_parser()]) 542 | args = parser.parse_args() 543 | if args.output_dir: 544 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 545 | main(args) 546 | -------------------------------------------------------------------------------- /vanilla_architecture/main_spectformer_transfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | 9 | from pathlib import Path 10 | 11 | from timm.data 
import Mixup 12 | from timm.models import create_model 13 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 14 | from timm.scheduler import create_scheduler 15 | from timm.optim import create_optimizer 16 | from timm.utils import NativeScaler, get_state_dict, ModelEma 17 | from functools import partial 18 | import torch.nn as nn 19 | 20 | from datasets import build_dataset 21 | from engine import train_one_epoch, evaluate 22 | from losses import DistillationLoss 23 | from samplers import RASampler 24 | import utils 25 | from spectformer import SpectFormer 26 | 27 | import warnings 28 | warnings.filterwarnings("ignore", message="Argument interpolation should be") 29 | 30 | 31 | def get_args_parser(): 32 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 33 | parser.add_argument('--batch-size', default=64, type=int) 34 | parser.add_argument('--epochs', default=300, type=int) 35 | 36 | # Model parameters 37 | parser.add_argument('--arch', default='deit_small', type=str, 38 | help='Name of model to train') 39 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 40 | 41 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 42 | help='Dropout rate (default: 0.)') 43 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 44 | help='Drop path rate (default: 0.1)') 45 | 46 | parser.add_argument('--model-ema', action='store_true') 47 | parser.set_defaults(model_ema=False) 48 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 49 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 50 | 51 | # Optimizer parameters 52 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 53 | help='Optimizer (default: "adamw"') 54 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 55 | help='Optimizer Epsilon (default: 1e-8)') 56 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 57 | help='Optimizer Betas (default: None, use opt default)') 58 | parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', 59 | help='Clip gradient norm (default: None, no clipping)') 60 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 61 | help='SGD momentum (default: 0.9)') 62 | parser.add_argument('--weight-decay', type=float, default=0.05, 63 | help='weight decay (default: 0.05)') 64 | # Learning rate schedule parameters 65 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 66 | help='LR scheduler (default: "cosine"') 67 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 68 | help='learning rate (default: 5e-4)') 69 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 70 | help='learning rate noise on/off epoch percentages') 71 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 72 | help='learning rate noise limit percent (default: 0.67)') 73 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 74 | help='learning rate noise std-dev (default: 1.0)') 75 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 76 | help='warmup learning rate (default: 1e-6)') 77 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 78 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 79 | 80 | 
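    # NOTE (annotation, not part of the original parser): for the transfer-learning
    # runs, main.sh overrides several of these defaults, e.g.
    # --lr 0.0001 --weight-decay 1e-4 --clip-grad 1 --warmup-epochs 5 --epochs 1000,
    # together with --finetune pointing at an ImageNet-pretrained checkpoint.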
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 81 | help='epoch interval to decay LR') 82 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 83 | help='epochs to warmup LR, if scheduler supports') 84 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 85 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 86 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 87 | help='patience epochs for Plateau LR scheduler (default: 10') 88 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 89 | help='LR decay rate (default: 0.1)') 90 | 91 | # Augmentation parameters 92 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 93 | help='Color jitter factor (default: 0.4)') 94 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 95 | help='Use AutoAugment policy. "v0" or "original". " + \ 96 | "(default: rand-m9-mstd0.5-inc1)'), 97 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 98 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 99 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 100 | 101 | parser.add_argument('--repeated-aug', action='store_true') 102 | parser.set_defaults(repeated_aug=False) 103 | 104 | # * Random Erase params 105 | parser.add_argument('--reprob', type=float, default=0, metavar='PCT', 106 | help='Random erase prob (default: 0.25)') 107 | parser.add_argument('--remode', type=str, default='pixel', 108 | help='Random erase mode (default: "pixel")') 109 | parser.add_argument('--recount', type=int, default=1, 110 | help='Random erase count (default: 1)') 111 | parser.add_argument('--resplit', action='store_true', default=False, 112 | help='Do not random erase first (clean) augmentation split') 113 | 114 | # * Mixup params 115 | parser.add_argument('--mixup', type=float, default=0.8, 116 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 117 | parser.add_argument('--cutmix', type=float, default=1.0, 118 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 119 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 120 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 121 | parser.add_argument('--mixup-prob', type=float, default=1.0, 122 | help='Probability of performing mixup or cutmix when either/both is enabled') 123 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 124 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 125 | parser.add_argument('--mixup-mode', type=str, default='batch', 126 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 127 | 128 | # Distillation parameters 129 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 130 | help='Name of teacher model to train (default: "regnety_160"') 131 | parser.add_argument('--teacher-path', type=str, default='') 132 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 133 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 134 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 135 | 136 | # * Finetuning params 137 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 138 | 139 | # Dataset parameters 140 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 141 | help='dataset path') 142 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19', 'CARS', 'FLOWERS'], 143 | type=str, help='Image Net dataset path') 144 | parser.add_argument('--inat-category', default='name', 145 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 146 | type=str, help='semantic granularity') 147 | 148 | parser.add_argument('--output_dir', default='', 149 | help='path where to save, empty for no saving') 150 | parser.add_argument('--device', default='cuda', 151 | help='device to use for training / testing') 152 | parser.add_argument('--seed', default=0, type=int) 153 | parser.add_argument('--resume', default='', help='resume from checkpoint') 154 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 155 | help='start epoch') 156 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 157 | parser.add_argument('--dist-eval', action='store_true', default=True, help='Enabling distributed evaluation') 158 | parser.add_argument('--num_workers', default=10, type=int) 159 | parser.add_argument('--pin-mem', action='store_true', 160 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 161 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 162 | help='') 163 | parser.set_defaults(pin_mem=True) 164 | 165 | # distributed training parameters 166 | parser.add_argument('--world_size', default=1, type=int, 167 | help='number of distributed processes') 168 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 169 | return parser 170 | 171 | 172 | def main(args): 173 | utils.init_distributed_mode(args) 174 | 175 | print(args) 176 | 177 | if args.distillation_type != 'none' and args.finetune and not args.eval: 178 | raise NotImplementedError("Finetuning with distillation not yet supported") 179 | 180 | device = torch.device(args.device) 181 | 182 | # fix the seed for reproducibility 183 | seed = args.seed + utils.get_rank() 184 | torch.manual_seed(seed) 185 | np.random.seed(seed) 186 | # random.seed(seed) 187 | 188 | cudnn.benchmark = True 189 | 190 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 191 | dataset_val, _ = build_dataset(is_train=False, args=args) 192 | 193 | if True: # args.distributed: 194 | num_tasks = utils.get_world_size() 195 | global_rank = utils.get_rank() 196 | if args.repeated_aug: 197 | sampler_train = RASampler( 198 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 199 | ) 200 | else: 201 | sampler_train = torch.utils.data.DistributedSampler( 
202 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 203 | ) 204 | if args.dist_eval: 205 | if len(dataset_val) % num_tasks != 0: 206 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 207 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 208 | 'equal num of samples per-process.') 209 | sampler_val = torch.utils.data.DistributedSampler( 210 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 211 | else: 212 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 213 | else: 214 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 215 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 216 | 217 | data_loader_train = torch.utils.data.DataLoader( 218 | dataset_train, sampler=sampler_train, 219 | batch_size=args.batch_size, 220 | num_workers=args.num_workers, 221 | pin_memory=args.pin_mem, 222 | drop_last=True, 223 | ) 224 | 225 | data_loader_val = torch.utils.data.DataLoader( 226 | dataset_val, sampler=sampler_val, 227 | batch_size=int(1.5 * args.batch_size), 228 | num_workers=args.num_workers, 229 | pin_memory=args.pin_mem, 230 | drop_last=False 231 | ) 232 | 233 | mixup_fn = None 234 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 235 | if mixup_active: 236 | print('standard mix up') 237 | mixup_fn = Mixup( 238 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 239 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 240 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 241 | else: 242 | print('mix up is not used') 243 | 244 | print(f"Creating model: {args.arch}") 245 | 246 | if args.arch == 'spectformer-xs': 247 | model = SpectFormer( 248 | img_size=args.input_size, 249 | patch_size=16, embed_dim=384, depth=12, mlp_ratio=4, 250 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 251 | ) 252 | elif args.arch == 'spectformer-ti': 253 | model = SpectFormer( 254 | img_size=args.input_size, 255 | patch_size=16, embed_dim=256, depth=12, mlp_ratio=4, 256 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 257 | ) 258 | elif args.arch == 'spectformer-s': 259 | model = SpectFormer( 260 | img_size=args.input_size, 261 | patch_size=16, embed_dim=384, depth=19, mlp_ratio=4, drop_path_rate=0.15, 262 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 263 | ) 264 | elif args.arch == 'spectformer-b': 265 | model = SpectFormer( 266 | img_size=args.input_size, 267 | patch_size=16, embed_dim=512, depth=19, mlp_ratio=4, drop_path_rate=0.25, 268 | norm_layer=partial(nn.LayerNorm, eps=1e-6) 269 | ) 270 | else: 271 | raise NotImplementedError 272 | 273 | if args.finetune: 274 | if args.finetune.startswith('https'): 275 | checkpoint = torch.hub.load_state_dict_from_url( 276 | args.finetune, map_location='cpu', check_hash=True) 277 | else: 278 | checkpoint = torch.load(args.finetune, map_location='cpu') 279 | 280 | checkpoint_model = checkpoint['model'] 281 | state_dict = model.state_dict() 282 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias', 'aux_head.weight', 'aux_head.bias']: 283 | if k in checkpoint_model: 284 | print(f"Removing key {k} from pretrained checkpoint") 285 | del checkpoint_model[k] 286 | a, b = model.load_state_dict(checkpoint_model, strict=False) 287 | print(a, b) 288 | 289 | model.to(device) 290 | 291 | model_ema = None 292 | if args.model_ema: 293 | # Important to create EMA model after cuda(), DP 
wrapper, and AMP but before SyncBN and DDP wrapper 294 | model_ema = ModelEma( 295 | model, 296 | decay=args.model_ema_decay, 297 | device='cpu' if args.model_ema_force_cpu else '', 298 | resume='') 299 | 300 | model_without_ddp = model 301 | if args.distributed: 302 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 303 | model_without_ddp = model.module 304 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 305 | print('number of params:', n_parameters) 306 | 307 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 308 | args.lr = linear_scaled_lr 309 | optimizer = create_optimizer(args, model_without_ddp) 310 | loss_scaler = NativeScaler() 311 | 312 | lr_scheduler, _ = create_scheduler(args, optimizer) 313 | 314 | criterion = LabelSmoothingCrossEntropy() 315 | 316 | if args.mixup > 0.: 317 | # smoothing is handled with mixup label transform 318 | criterion = SoftTargetCrossEntropy() 319 | elif args.smoothing: 320 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 321 | else: 322 | criterion = torch.nn.CrossEntropyLoss() 323 | 324 | teacher_model = None 325 | if args.distillation_type != 'none': 326 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 327 | print(f"Creating teacher model: {args.teacher_model}") 328 | teacher_model = create_model( 329 | args.teacher_model, 330 | pretrained=False, 331 | num_classes=args.nb_classes, 332 | global_pool='avg', 333 | ) 334 | if args.teacher_path.startswith('https'): 335 | checkpoint = torch.hub.load_state_dict_from_url( 336 | args.teacher_path, map_location='cpu', check_hash=True) 337 | else: 338 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 339 | teacher_model.load_state_dict(checkpoint['model']) 340 | teacher_model.to(device) 341 | teacher_model.eval() 342 | 343 | # wrap the criterion in our custom DistillationLoss, which 344 | # just dispatches to the original criterion if args.distillation_type is 'none' 345 | 346 | 347 | criterion = DistillationLoss( 348 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 349 | ) 350 | 351 | output_dir = Path(args.output_dir) 352 | if args.resume: 353 | if args.resume.startswith('https'): 354 | checkpoint = torch.hub.load_state_dict_from_url( 355 | args.resume, map_location='cpu', check_hash=True) 356 | else: 357 | checkpoint = torch.load(args.resume, map_location='cpu') 358 | model_without_ddp.load_state_dict(checkpoint['model']) 359 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 360 | optimizer.load_state_dict(checkpoint['optimizer']) 361 | print('lr scheduler will not be updated') 362 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 363 | args.start_epoch = checkpoint['epoch'] + 1 364 | if args.model_ema: 365 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 366 | if 'scaler' in checkpoint: 367 | loss_scaler.load_state_dict(checkpoint['scaler']) 368 | 369 | if args.eval: 370 | test_stats = evaluate(data_loader_val, model, device) 371 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 372 | return 373 | 374 | print(f"Start training for {args.epochs} epochs") 375 | start_time = time.time() 376 | max_accuracy = 0.0 377 | for epoch in range(args.start_epoch, args.epochs): 378 | if args.distributed: 379 | data_loader_train.sampler.set_epoch(epoch) 380 | 381 | train_stats = 
train_one_epoch( 382 | model, criterion, data_loader_train, 383 | optimizer, device, epoch, loss_scaler, 384 | args.clip_grad, model_ema, mixup_fn, 385 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning 386 | ) 387 | 388 | lr_scheduler.step(epoch) 389 | 390 | if (epoch + 1) % 100 == 0: 391 | file_name = 'checkpoint_epoch%d.pth' % epoch 392 | checkpoint_path = output_dir / file_name 393 | if model_ema is not None: 394 | utils.save_on_master({ 395 | 'model': model_without_ddp.state_dict(), 396 | 'optimizer': optimizer.state_dict(), 397 | 'lr_scheduler': lr_scheduler.state_dict(), 398 | 'epoch': epoch, 399 | 'model_ema': get_state_dict(model_ema), 400 | 'scaler': loss_scaler.state_dict(), 401 | 'args': args, 402 | }, checkpoint_path) 403 | else: 404 | utils.save_on_master({ 405 | 'model': model_without_ddp.state_dict(), 406 | 'optimizer': optimizer.state_dict(), 407 | 'lr_scheduler': lr_scheduler.state_dict(), 408 | 'epoch': epoch, 409 | 'scaler': loss_scaler.state_dict(), 410 | 'args': args, 411 | }, checkpoint_path) 412 | 413 | test_stats = evaluate(data_loader_val, model, device) 414 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 415 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 416 | print(f'Max accuracy: {max_accuracy:.2f}%') 417 | 418 | if max_accuracy == test_stats["acc1"]: 419 | checkpoint_path = output_dir / 'checkpoint_best.pth' 420 | if model_ema is not None: 421 | utils.save_on_master({ 422 | 'model': model_without_ddp.state_dict(), 423 | 'optimizer': optimizer.state_dict(), 424 | 'lr_scheduler': lr_scheduler.state_dict(), 425 | 'epoch': epoch, 426 | 'model_ema': get_state_dict(model_ema), 427 | 'scaler': loss_scaler.state_dict(), 428 | 'args': args, 429 | }, checkpoint_path) 430 | else: 431 | utils.save_on_master({ 432 | 'model': model_without_ddp.state_dict(), 433 | 'optimizer': optimizer.state_dict(), 434 | 'lr_scheduler': lr_scheduler.state_dict(), 435 | 'epoch': epoch, 436 | 'scaler': loss_scaler.state_dict(), 437 | 'args': args, 438 | }, checkpoint_path) 439 | 440 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 441 | **{f'test_{k}': v for k, v in test_stats.items()}, 442 | 'epoch': epoch, 443 | 'n_parameters': n_parameters} 444 | 445 | if args.output_dir and utils.is_main_process(): 446 | with (output_dir / "log.txt").open("a") as f: 447 | f.write(json.dumps(log_stats) + "\n") 448 | 449 | total_time = time.time() - start_time 450 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 451 | print('Training time {}'.format(total_time_str)) 452 | 453 | 454 | if __name__ == '__main__': 455 | parser = argparse.ArgumentParser('spectformer training and evaluation script', parents=[get_args_parser()]) 456 | args = parser.parse_args() 457 | if args.output_dir: 458 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 459 | main(args) 460 | -------------------------------------------------------------------------------- /vanilla_architecture/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
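    Each index is repeated 3 times per epoch; only the first `num_selected_samples` indices are
    actually yielded per rank, so the effective epoch length stays close to len(dataset) / num_replicas.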
11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | self.epoch = 0 29 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 30 | self.total_size = self.num_samples * self.num_replicas 31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | if self.shuffle: 40 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 41 | else: 42 | indices = list(range(len(self.dataset))) 43 | 44 | # add extra samples to make it evenly divisible 45 | indices = [ele for ele in indices for i in range(3)] 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | indices = indices[self.rank:self.total_size:self.num_replicas] 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices[:self.num_selected_samples]) 54 | 55 | def __len__(self): 56 | return self.num_selected_samples 57 | 58 | def set_epoch(self, epoch): 59 | self.epoch = epoch 60 | -------------------------------------------------------------------------------- /vanilla_architecture/spectformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from copy import Error, deepcopy 6 | from re import S 7 | from numpy.lib.arraypad import pad 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 15 | import torch.fft 16 | from torch.nn.modules.container import Sequential 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | 21 | def _cfg(url='', **kwargs): 22 | return { 23 | 'url': url, 24 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 25 | 'crop_pct': .9, 'interpolation': 'bicubic', 26 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 27 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 28 | **kwargs 29 | } 30 | 31 | class Attention(nn.Module): 32 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 33 | super().__init__() 34 | self.num_heads = num_heads 35 | self.dim = dim 36 | head_dim = dim // num_heads 37 | self.scale = qk_scale or head_dim ** -0.5 38 | 39 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 40 | self.attn_drop = nn.Dropout(attn_drop) 41 | self.proj = nn.Linear(dim, dim) 42 | self.proj_drop = nn.Dropout(proj_drop) 43 | 44 | def forward(self, x): 45 | B, N, C = 
x.shape 46 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 47 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 48 | 49 | attn = (q @ k.transpose(-2, -1)) * self.scale 50 | attn = attn.softmax(dim=-1) 51 | attn = self.attn_drop(attn) 52 | 53 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 54 | x = self.proj(x) 55 | x = self.proj_drop(x) 56 | return x 57 | 58 | class Mlp(nn.Module): 59 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 60 | super().__init__() 61 | out_features = out_features or in_features 62 | hidden_features = hidden_features or in_features 63 | self.fc1 = nn.Linear(in_features, hidden_features) 64 | self.act = act_layer() 65 | self.fc2 = nn.Linear(hidden_features, out_features) 66 | self.drop = nn.Dropout(drop) 67 | 68 | def forward(self, x): 69 | x = self.fc1(x) 70 | x = self.act(x) 71 | x = self.drop(x) 72 | x = self.fc2(x) 73 | x = self.drop(x) 74 | return x 75 | 76 | class SpectralGatingNetwork(nn.Module): 77 | def __init__(self, dim, h=14, w=8): 78 | super().__init__() 79 | self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02) 80 | self.w = w 81 | self.h = h 82 | 83 | def forward(self, x, spatial_size=None): 84 | B, N, C = x.shape 85 | if spatial_size is None: 86 | a = b = int(math.sqrt(N)) 87 | else: 88 | a, b = spatial_size 89 | 90 | x = x.view(B, a, b, C) 91 | 92 | x = x.to(torch.float32) 93 | 94 | x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') 95 | weight = torch.view_as_complex(self.complex_weight) 96 | x = x * weight 97 | x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho') 98 | 99 | x = x.reshape(B, N, C) 100 | 101 | return x 102 | 103 | class Block(nn.Module): 104 | 105 | def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8): 106 | super().__init__() 107 | self.norm1 = norm_layer(dim) 108 | self.filter = SpectralGatingNetwork(dim, h=h, w=w) 109 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 110 | self.norm2 = norm_layer(dim) 111 | mlp_hidden_dim = int(dim * mlp_ratio) 112 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 113 | 114 | def forward(self, x): 115 | x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x))))) 116 | return x 117 | 118 | class Block_attention(nn.Module): 119 | 120 | def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8): 121 | super().__init__() 122 | num_heads= 6 # 4 for tiny, 6 for small and 12 for base 123 | self.norm1 = norm_layer(dim) 124 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 125 | self.norm2 = norm_layer(dim) 126 | mlp_hidden_dim = int(dim * mlp_ratio) 127 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 128 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=True, qk_scale=False, attn_drop=drop, proj_drop=drop) 129 | 130 | def forward(self, x): 131 | # x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x))))) 132 | x = x + self.drop_path(self.attn(self.norm1(x))) 133 | x = x + self.drop_path(self.mlp(self.norm2(x))) 134 | return x 135 | 136 | 137 | class PatchEmbed(nn.Module): 138 | """ Image to Patch Embedding 139 | """ 140 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 141 | super().__init__() 142 | img_size = to_2tuple(img_size) 143 | patch_size = to_2tuple(patch_size) 144 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 145 | self.img_size = img_size 146 | self.patch_size = patch_size 147 | self.num_patches = num_patches 148 | 149 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 150 | 151 | def forward(self, x): 152 | B, C, H, W = x.shape 153 | # FIXME look at relaxing size constraints 154 | assert H == self.img_size[0] and W == self.img_size[1], \ 155 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 156 | x = self.proj(x).flatten(2).transpose(1, 2) 157 | return x 158 | 159 | 160 | class DownLayer(nn.Module): 161 | """ Image to Patch Embedding 162 | """ 163 | def __init__(self, img_size=56, dim_in=64, dim_out=128): 164 | super().__init__() 165 | self.img_size = img_size 166 | self.dim_in = dim_in 167 | self.dim_out = dim_out 168 | self.proj = nn.Conv2d(dim_in, dim_out, kernel_size=2, stride=2) 169 | self.num_patches = img_size * img_size // 4 170 | 171 | def forward(self, x): 172 | B, N, C = x.size() 173 | x = x.view(B, self.img_size, self.img_size, C).permute(0, 3, 1, 2) 174 | x = self.proj(x).permute(0, 2, 3, 1) 175 | x = x.reshape(B, -1, self.dim_out) 176 | return x 177 | 178 | class SpectFormer(nn.Module): 179 | 180 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 181 | mlp_ratio=4., representation_size=None, uniform_drop=False, 182 | drop_rate=0., drop_path_rate=0., norm_layer=None, 183 | dropcls=0): 184 | """ 185 | Args: 186 | img_size (int, tuple): input image size 187 | patch_size (int, tuple): patch size 188 | in_chans (int): number of input channels 189 | num_classes (int): number of classes for classification head 190 | embed_dim (int): embedding dimension 191 | depth (int): depth of transformer 192 | num_heads (int): number of attention heads 193 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 194 | qkv_bias (bool): enable bias for qkv if True 195 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 196 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 197 | drop_rate (float): dropout rate 198 | attn_drop_rate (float): attention dropout rate 199 | drop_path_rate (float): stochastic depth rate 200 | hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module 201 | norm_layer: (nn.Module): normalization layer 202 | """ 203 | super().__init__() 204 | self.num_classes = num_classes 205 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 206 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 207 | 
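        # With the default img_size=224 and patch_size=16, the lines below build a 14x14 token grid
        # (196 patch tokens); h = 14 and w = 14 // 2 + 1 = 8 are the rFFT dimensions consumed by
        # SpectralGatingNetwork's complex_weight (default h=14, w=8).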
208 |         self.patch_embed = PatchEmbed(
209 |             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
210 |         num_patches = self.patch_embed.num_patches
211 | 
212 |         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
213 |         self.pos_drop = nn.Dropout(p=drop_rate)
214 | 
215 |         h = img_size // patch_size
216 |         w = h // 2 + 1
217 | 
218 |         if uniform_drop:
219 |             print('using uniform droppath with expect rate', drop_path_rate)
220 |             dpr = [drop_path_rate for _ in range(depth)]  # stochastic depth decay rule
221 |         else:
222 |             print('using linear droppath with expect rate', drop_path_rate * 0.5)
223 |             dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
224 |             # dpr = [drop_path_rate for _ in range(depth)]  # stochastic depth decay rule
225 | 
226 |         alpha = 4  # the first `alpha` blocks use spectral gating, the remaining blocks use self-attention
227 |         self.blocks = nn.ModuleList()
228 |         for i in range(depth):
229 |             if i < alpha:
230 |                 layer = Block(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate,
231 |                               drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w)
232 |             else:
233 |                 layer = Block_attention(dim=embed_dim, mlp_ratio=mlp_ratio, drop=drop_rate,
234 |                                         drop_path=dpr[i], norm_layer=norm_layer, h=h, w=w)
235 |             self.blocks.append(layer)
236 | 
237 |         self.norm = norm_layer(embed_dim)
238 | 
239 |         # Representation layer (pre-logits), kept for compatibility with ViT-style checkpoints
240 |         if representation_size:
241 |             self.num_features = representation_size
242 |             self.pre_logits = nn.Sequential(OrderedDict([
243 |                 ('fc', nn.Linear(embed_dim, representation_size)),
244 |                 ('act', nn.Tanh())
245 |             ]))
246 |         else:
247 |             self.pre_logits = nn.Identity()
248 | 
249 |         # Classifier head
250 |         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
251 |         if dropcls > 0:
252 |             print('dropout %.2f before classifier' % dropcls)
253 |             self.final_dropout = nn.Dropout(p=dropcls)
254 |         else:
255 |             self.final_dropout = nn.Identity()
256 | 
257 |         trunc_normal_(self.pos_embed, std=.02)
258 |         self.apply(self._init_weights)
259 | 
260 |     def _init_weights(self, m):
261 |         if isinstance(m, nn.Linear):
262 |             trunc_normal_(m.weight, std=.02)
263 |             if isinstance(m, nn.Linear) and m.bias is not None:
264 |                 nn.init.constant_(m.bias, 0)
265 |         elif isinstance(m, nn.LayerNorm):
266 |             nn.init.constant_(m.bias, 0)
267 |             nn.init.constant_(m.weight, 1.0)
268 | 
269 |     @torch.jit.ignore
270 |     def no_weight_decay(self):
271 |         return {'pos_embed', 'cls_token'}
272 | 
273 |     def get_classifier(self):
274 |         return self.head
275 | 
276 |     def reset_classifier(self, num_classes, global_pool=''):
277 |         self.num_classes = num_classes
278 |         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
279 | 
280 |     def forward_features(self, x):
281 |         B = x.shape[0]
282 |         x = self.patch_embed(x)
283 |         x = x + self.pos_embed
284 |         x = self.pos_drop(x)
285 | 
286 |         for blk in self.blocks:
287 |             x = blk(x)
288 | 
289 |         x = self.norm(x).mean(1)
290 |         return x
291 | 
292 |     def forward(self, x):
293 |         x = self.forward_features(x)
294 |         x = self.final_dropout(x)
295 |         x = self.head(x)
296 |         return x
297 | 
298 | 
299 | def resize_pos_embed(posemb, posemb_new):
300 |     # Rescale the grid of position embeddings when loading from state_dict.
Adapted from 301 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 302 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 303 | ntok_new = posemb_new.shape[1] 304 | if True: 305 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 306 | ntok_new -= 1 307 | else: 308 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 309 | gs_old = int(math.sqrt(len(posemb_grid))) 310 | gs_new = int(math.sqrt(ntok_new)) 311 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) 312 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 313 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') 314 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) 315 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 316 | return posemb 317 | 318 | 319 | def checkpoint_filter_fn(state_dict, model): 320 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 321 | out_dict = {} 322 | if 'model' in state_dict: 323 | # For deit models 324 | state_dict = state_dict['model'] 325 | for k, v in state_dict.items(): 326 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 327 | # For old models that I trained prior to conv based patchification 328 | O, I, H, W = model.patch_embed.proj.weight.shape 329 | v = v.reshape(O, -1, H, W) 330 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 331 | # To resize pos embedding when using model at different size from pretrained weights 332 | v = resize_pos_embed(v, model.pos_embed) 333 | out_dict[k] = v 334 | return out_dict 335 | -------------------------------------------------------------------------------- /vanilla_architecture/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 
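        Only `count` and `total` (and therefore `global_avg`) are all-reduced; the windowed
        statistics (median, avg, max, value) stay local to each process.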
39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if isinstance(v, torch.Tensor): 88 | v = v.item() 89 | assert isinstance(v, (float, int)) 90 | self.meters[k].update(v) 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | '{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq == 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print('{} Total time: {} ({:.4f} s / it)'.format( 159 | header, total_time_str, total_time / 
len(iterable))) 160 | 161 | 162 | def _load_checkpoint_for_ema(model_ema, checkpoint): 163 | """ 164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 165 | """ 166 | mem_file = io.BytesIO() 167 | torch.save(checkpoint, mem_file) 168 | mem_file.seek(0) 169 | model_ema._load_checkpoint(mem_file) 170 | 171 | 172 | def setup_for_distributed(is_master): 173 | """ 174 | This function disables printing when not in master process 175 | """ 176 | import builtins as __builtin__ 177 | builtin_print = __builtin__.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | if is_master or force: 182 | builtin_print(*args, **kwargs) 183 | 184 | __builtin__.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 218 | args.rank = int(os.environ["RANK"]) 219 | args.world_size = int(os.environ['WORLD_SIZE']) 220 | args.gpu = int(os.environ['LOCAL_RANK']) 221 | elif 'SLURM_PROCID' in os.environ: 222 | args.rank = int(os.environ['SLURM_PROCID']) 223 | args.gpu = args.rank % torch.cuda.device_count() 224 | else: 225 | print('Not using distributed mode') 226 | args.distributed = False 227 | return 228 | 229 | args.distributed = True 230 | 231 | torch.cuda.set_device(args.gpu) 232 | args.dist_backend = 'nccl' 233 | print('| distributed init (rank {}): {}'.format( 234 | args.rank, args.dist_url), flush=True) 235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 236 | world_size=args.world_size, rank=args.rank) 237 | torch.distributed.barrier() 238 | setup_for_distributed(args.rank == 0) 239 | 240 | 241 | def batch_index_select(x, idx): 242 | if len(x.size()) == 3: 243 | B, N, C = x.size() 244 | N_new = idx.size(1) 245 | offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N 246 | idx = idx + offset 247 | out = x.reshape(B*N, C)[idx.reshape(-1)].reshape(B, N_new, C) 248 | return out 249 | elif len(x.size()) == 2: 250 | B, N = x.size() 251 | N_new = idx.size(1) 252 | offset = torch.arange(B, dtype=torch.long, device=x.device).view(B, 1) * N 253 | idx = idx + offset 254 | out = x.reshape(B*N)[idx.reshape(-1)].reshape(B, N_new) 255 | return out 256 | else: 257 | raise NotImplementedError --------------------------------------------------------------------------------